1 // SPDX-License-Identifier: GPL-2.0 2 3 //! Procedural macro to run KUnit tests using a user-space like syntax. 4 //! 5 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com> 6 7 use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; 8 use std::collections::HashMap; 9 use std::fmt::Write; 10 11 pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { 12 let attr = attr.to_string(); 13 14 if attr.is_empty() { 15 panic!("Missing test name in `#[kunit_tests(test_name)]` macro") 16 } 17 18 if attr.len() > 255 { 19 panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes") 20 } 21 22 let mut tokens: Vec<_> = ts.into_iter().collect(); 23 24 // Scan for the `mod` keyword. 25 tokens 26 .iter() 27 .find_map(|token| match token { 28 TokenTree::Ident(ident) => match ident.to_string().as_str() { 29 "mod" => Some(true), 30 _ => None, 31 }, 32 _ => None, 33 }) 34 .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules"); 35 36 // Retrieve the main body. The main body should be the last token tree. 37 let body = match tokens.pop() { 38 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, 39 _ => panic!("Cannot locate main body of module"), 40 }; 41 42 // Get the functions set as tests. Search for `[test]` -> `fn`. 43 let mut body_it = body.stream().into_iter(); 44 let mut tests = Vec::new(); 45 let mut attributes: HashMap<String, TokenStream> = HashMap::new(); 46 while let Some(token) = body_it.next() { 47 match token { 48 TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() { 49 Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { 50 if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() { 51 // Collect attributes because we need to find which are tests. We also 52 // need to copy `cfg` attributes so tests can be conditionally enabled. 53 attributes 54 .entry(name.to_string()) 55 .or_default() 56 .extend([token, TokenTree::Group(g)]); 57 } 58 continue; 59 } 60 _ => (), 61 }, 62 TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => { 63 if let Some(TokenTree::Ident(test_name)) = body_it.next() { 64 tests.push((test_name, attributes.remove("cfg").unwrap_or_default())) 65 } 66 } 67 68 _ => (), 69 } 70 attributes.clear(); 71 } 72 73 // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration. 74 let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap(); 75 tokens.insert( 76 0, 77 TokenTree::Group(Group::new(Delimiter::None, config_kunit)), 78 ); 79 80 // Generate the test KUnit test suite and a test case for each `#[test]`. 81 // The code generated for the following test module: 82 // 83 // ``` 84 // #[kunit_tests(kunit_test_suit_name)] 85 // mod tests { 86 // #[test] 87 // fn foo() { 88 // assert_eq!(1, 1); 89 // } 90 // 91 // #[test] 92 // fn bar() { 93 // assert_eq!(2, 2); 94 // } 95 // } 96 // ``` 97 // 98 // Looks like: 99 // 100 // ``` 101 // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); } 102 // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); } 103 // 104 // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [ 105 // ::kernel::kunit::kunit_case(::kernel::c_str!("foo"), kunit_rust_wrapper_foo), 106 // ::kernel::kunit::kunit_case(::kernel::c_str!("bar"), kunit_rust_wrapper_bar), 107 // ::kernel::kunit::kunit_case_null(), 108 // ]; 109 // 110 // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); 111 // ``` 112 let mut kunit_macros = "".to_owned(); 113 let mut test_cases = "".to_owned(); 114 let mut assert_macros = "".to_owned(); 115 let path = crate::helpers::file(); 116 let num_tests = tests.len(); 117 for (test, cfg_attr) in tests { 118 let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}"); 119 // Append any `cfg` attributes the user might have written on their tests so we don't 120 // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce 121 // the length of the assert message. 122 let kunit_wrapper = format!( 123 r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit) 124 {{ 125 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; 126 {cfg_attr} {{ 127 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; 128 use ::kernel::kunit::is_test_result_ok; 129 assert!(is_test_result_ok({test}())); 130 }} 131 }}"#, 132 ); 133 writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); 134 writeln!( 135 test_cases, 136 " ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name})," 137 ) 138 .unwrap(); 139 writeln!( 140 assert_macros, 141 r#" 142 /// Overrides the usual [`assert!`] macro with one that calls KUnit instead. 143 #[allow(unused)] 144 macro_rules! assert {{ 145 ($cond:expr $(,)?) => {{{{ 146 kernel::kunit_assert!("{test}", "{path}", 0, $cond); 147 }}}} 148 }} 149 150 /// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead. 151 #[allow(unused)] 152 macro_rules! assert_eq {{ 153 ($left:expr, $right:expr $(,)?) => {{{{ 154 kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right); 155 }}}} 156 }} 157 "# 158 ) 159 .unwrap(); 160 } 161 162 writeln!(kunit_macros).unwrap(); 163 writeln!( 164 kunit_macros, 165 "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];", 166 num_tests + 1 167 ) 168 .unwrap(); 169 170 writeln!( 171 kunit_macros, 172 "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" 173 ) 174 .unwrap(); 175 176 // Remove the `#[test]` macros. 177 // We do this at a token level, in order to preserve span information. 178 let mut new_body = vec![]; 179 let mut body_it = body.stream().into_iter(); 180 181 while let Some(token) = body_it.next() { 182 match token { 183 TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() { 184 Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (), 185 Some(next) => { 186 new_body.extend([token, next]); 187 } 188 _ => { 189 new_body.push(token); 190 } 191 }, 192 _ => { 193 new_body.push(token); 194 } 195 } 196 } 197 198 let mut final_body = TokenStream::new(); 199 final_body.extend::<TokenStream>(assert_macros.parse().unwrap()); 200 final_body.extend(new_body); 201 final_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); 202 203 tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body))); 204 205 tokens.into_iter().collect() 206 } 207