1 // SPDX-License-Identifier: GPL-2.0 2 3 //! Procedural macro to run KUnit tests using a user-space like syntax. 4 //! 5 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com> 6 7 use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; 8 use std::fmt::Write; 9 10 pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { 11 let attr = attr.to_string(); 12 13 if attr.is_empty() { 14 panic!("Missing test name in `#[kunit_tests(test_name)]` macro") 15 } 16 17 if attr.len() > 255 { 18 panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes") 19 } 20 21 let mut tokens: Vec<_> = ts.into_iter().collect(); 22 23 // Scan for the `mod` keyword. 24 tokens 25 .iter() 26 .find_map(|token| match token { 27 TokenTree::Ident(ident) => match ident.to_string().as_str() { 28 "mod" => Some(true), 29 _ => None, 30 }, 31 _ => None, 32 }) 33 .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules"); 34 35 // Retrieve the main body. The main body should be the last token tree. 36 let body = match tokens.pop() { 37 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, 38 _ => panic!("Cannot locate main body of module"), 39 }; 40 41 // Get the functions set as tests. Search for `[test]` -> `fn`. 42 let mut body_it = body.stream().into_iter(); 43 let mut tests = Vec::new(); 44 while let Some(token) = body_it.next() { 45 match token { 46 TokenTree::Group(ident) if ident.to_string() == "[test]" => match body_it.next() { 47 Some(TokenTree::Ident(ident)) if ident.to_string() == "fn" => { 48 let test_name = match body_it.next() { 49 Some(TokenTree::Ident(ident)) => ident.to_string(), 50 _ => continue, 51 }; 52 tests.push(test_name); 53 } 54 _ => continue, 55 }, 56 _ => (), 57 } 58 } 59 60 // Add `#[cfg(CONFIG_KUNIT)]` before the module declaration. 61 let config_kunit = "#[cfg(CONFIG_KUNIT)]".to_owned().parse().unwrap(); 62 tokens.insert( 63 0, 64 TokenTree::Group(Group::new(Delimiter::None, config_kunit)), 65 ); 66 67 // Generate the test KUnit test suite and a test case for each `#[test]`. 68 // The code generated for the following test module: 69 // 70 // ``` 71 // #[kunit_tests(kunit_test_suit_name)] 72 // mod tests { 73 // #[test] 74 // fn foo() { 75 // assert_eq!(1, 1); 76 // } 77 // 78 // #[test] 79 // fn bar() { 80 // assert_eq!(2, 2); 81 // } 82 // } 83 // ``` 84 // 85 // Looks like: 86 // 87 // ``` 88 // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut kernel::bindings::kunit) { foo(); } 89 // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut kernel::bindings::kunit) { bar(); } 90 // 91 // static mut TEST_CASES: [kernel::bindings::kunit_case; 3] = [ 92 // kernel::kunit::kunit_case(kernel::c_str!("foo"), kunit_rust_wrapper_foo), 93 // kernel::kunit::kunit_case(kernel::c_str!("bar"), kunit_rust_wrapper_bar), 94 // kernel::kunit::kunit_case_null(), 95 // ]; 96 // 97 // kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); 98 // ``` 99 let mut kunit_macros = "".to_owned(); 100 let mut test_cases = "".to_owned(); 101 for test in &tests { 102 let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}"); 103 let kunit_wrapper = format!( 104 "unsafe extern \"C\" fn {kunit_wrapper_fn_name}(_test: *mut kernel::bindings::kunit) {{ {test}(); }}" 105 ); 106 writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); 107 writeln!( 108 test_cases, 109 " kernel::kunit::kunit_case(kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name})," 110 ) 111 .unwrap(); 112 } 113 114 writeln!(kunit_macros).unwrap(); 115 writeln!( 116 kunit_macros, 117 "static mut TEST_CASES: [kernel::bindings::kunit_case; {}] = [\n{test_cases} kernel::kunit::kunit_case_null(),\n];", 118 tests.len() + 1 119 ) 120 .unwrap(); 121 122 writeln!( 123 kunit_macros, 124 "kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" 125 ) 126 .unwrap(); 127 128 // Remove the `#[test]` macros. 129 // We do this at a token level, in order to preserve span information. 130 let mut new_body = vec![]; 131 let mut body_it = body.stream().into_iter(); 132 133 while let Some(token) = body_it.next() { 134 match token { 135 TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() { 136 Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (), 137 Some(next) => { 138 new_body.extend([token, next]); 139 } 140 _ => { 141 new_body.push(token); 142 } 143 }, 144 _ => { 145 new_body.push(token); 146 } 147 } 148 } 149 150 let mut new_body = TokenStream::from_iter(new_body); 151 new_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); 152 153 tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body))); 154 155 tokens.into_iter().collect() 156 } 157