1 // SPDX-License-Identifier: GPL-2.0 2 3 //! Procedural macro to run KUnit tests using a user-space like syntax. 4 //! 5 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com> 6 7 use std::collections::HashMap; 8 use std::fmt::Write; 9 10 use proc_macro2::{Delimiter, Group, TokenStream, TokenTree}; 11 12 pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { 13 let attr = attr.to_string(); 14 15 if attr.is_empty() { 16 panic!("Missing test name in `#[kunit_tests(test_name)]` macro") 17 } 18 19 if attr.len() > 255 { 20 panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes") 21 } 22 23 let mut tokens: Vec<_> = ts.into_iter().collect(); 24 25 // Scan for the `mod` keyword. 26 tokens 27 .iter() 28 .find_map(|token| match token { 29 TokenTree::Ident(ident) => match ident.to_string().as_str() { 30 "mod" => Some(true), 31 _ => None, 32 }, 33 _ => None, 34 }) 35 .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules"); 36 37 // Retrieve the main body. The main body should be the last token tree. 38 let body = match tokens.pop() { 39 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, 40 _ => panic!("Cannot locate main body of module"), 41 }; 42 43 // Get the functions set as tests. Search for `[test]` -> `fn`. 44 let mut body_it = body.stream().into_iter(); 45 let mut tests = Vec::new(); 46 let mut attributes: HashMap<String, TokenStream> = HashMap::new(); 47 while let Some(token) = body_it.next() { 48 match token { 49 TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() { 50 Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { 51 if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() { 52 // Collect attributes because we need to find which are tests. We also 53 // need to copy `cfg` attributes so tests can be conditionally enabled. 54 attributes 55 .entry(name.to_string()) 56 .or_default() 57 .extend([token, TokenTree::Group(g)]); 58 } 59 continue; 60 } 61 _ => (), 62 }, 63 TokenTree::Ident(i) if i == "fn" && attributes.contains_key("test") => { 64 if let Some(TokenTree::Ident(test_name)) = body_it.next() { 65 tests.push((test_name, attributes.remove("cfg").unwrap_or_default())) 66 } 67 } 68 69 _ => (), 70 } 71 attributes.clear(); 72 } 73 74 // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration. 75 let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap(); 76 tokens.insert( 77 0, 78 TokenTree::Group(Group::new(Delimiter::None, config_kunit)), 79 ); 80 81 // Generate the test KUnit test suite and a test case for each `#[test]`. 82 // The code generated for the following test module: 83 // 84 // ``` 85 // #[kunit_tests(kunit_test_suit_name)] 86 // mod tests { 87 // #[test] 88 // fn foo() { 89 // assert_eq!(1, 1); 90 // } 91 // 92 // #[test] 93 // fn bar() { 94 // assert_eq!(2, 2); 95 // } 96 // } 97 // ``` 98 // 99 // Looks like: 100 // 101 // ``` 102 // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); } 103 // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); } 104 // 105 // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [ 106 // ::kernel::kunit::kunit_case(c"foo", kunit_rust_wrapper_foo), 107 // ::kernel::kunit::kunit_case(c"bar", kunit_rust_wrapper_bar), 108 // ::kernel::kunit::kunit_case_null(), 109 // ]; 110 // 111 // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); 112 // ``` 113 let mut kunit_macros = "".to_owned(); 114 let mut test_cases = "".to_owned(); 115 let mut assert_macros = "".to_owned(); 116 let path = crate::helpers::file(); 117 let num_tests = tests.len(); 118 for (test, cfg_attr) in tests { 119 let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}"); 120 // Append any `cfg` attributes the user might have written on their tests so we don't 121 // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce 122 // the length of the assert message. 123 let kunit_wrapper = format!( 124 r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit) 125 {{ 126 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; 127 {cfg_attr} {{ 128 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; 129 use ::kernel::kunit::is_test_result_ok; 130 assert!(is_test_result_ok({test}())); 131 }} 132 }}"#, 133 ); 134 writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); 135 writeln!( 136 test_cases, 137 " ::kernel::kunit::kunit_case(c\"{test}\", {kunit_wrapper_fn_name})," 138 ) 139 .unwrap(); 140 writeln!( 141 assert_macros, 142 r#" 143 /// Overrides the usual [`assert!`] macro with one that calls KUnit instead. 144 #[allow(unused)] 145 macro_rules! assert {{ 146 ($cond:expr $(,)?) => {{{{ 147 kernel::kunit_assert!("{test}", c"{path}", 0, $cond); 148 }}}} 149 }} 150 151 /// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead. 152 #[allow(unused)] 153 macro_rules! assert_eq {{ 154 ($left:expr, $right:expr $(,)?) => {{{{ 155 kernel::kunit_assert_eq!("{test}", c"{path}", 0, $left, $right); 156 }}}} 157 }} 158 "# 159 ) 160 .unwrap(); 161 } 162 163 writeln!(kunit_macros).unwrap(); 164 writeln!( 165 kunit_macros, 166 "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];", 167 num_tests + 1 168 ) 169 .unwrap(); 170 171 writeln!( 172 kunit_macros, 173 "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" 174 ) 175 .unwrap(); 176 177 // Remove the `#[test]` macros. 178 // We do this at a token level, in order to preserve span information. 179 let mut new_body = vec![]; 180 let mut body_it = body.stream().into_iter(); 181 182 while let Some(token) = body_it.next() { 183 match token { 184 TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() { 185 Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (), 186 Some(next) => { 187 new_body.extend([token, next]); 188 } 189 _ => { 190 new_body.push(token); 191 } 192 }, 193 _ => { 194 new_body.push(token); 195 } 196 } 197 } 198 199 let mut final_body = TokenStream::new(); 200 final_body.extend::<TokenStream>(assert_macros.parse().unwrap()); 201 final_body.extend(new_body); 202 final_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); 203 204 tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body))); 205 206 tokens.into_iter().collect() 207 } 208