1 // SPDX-License-Identifier: GPL-2.0 2 3 //! Procedural macro to run KUnit tests using a user-space like syntax. 4 //! 5 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com> 6 7 use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; 8 use std::fmt::Write; 9 10 pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { 11 let attr = attr.to_string(); 12 13 if attr.is_empty() { 14 panic!("Missing test name in `#[kunit_tests(test_name)]` macro") 15 } 16 17 if attr.len() > 255 { 18 panic!( 19 "The test suite name `{}` exceeds the maximum length of 255 bytes", 20 attr 21 ) 22 } 23 24 let mut tokens: Vec<_> = ts.into_iter().collect(); 25 26 // Scan for the `mod` keyword. 27 tokens 28 .iter() 29 .find_map(|token| match token { 30 TokenTree::Ident(ident) => match ident.to_string().as_str() { 31 "mod" => Some(true), 32 _ => None, 33 }, 34 _ => None, 35 }) 36 .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules"); 37 38 // Retrieve the main body. The main body should be the last token tree. 39 let body = match tokens.pop() { 40 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, 41 _ => panic!("Cannot locate main body of module"), 42 }; 43 44 // Get the functions set as tests. Search for `[test]` -> `fn`. 45 let mut body_it = body.stream().into_iter(); 46 let mut tests = Vec::new(); 47 while let Some(token) = body_it.next() { 48 match token { 49 TokenTree::Group(ident) if ident.to_string() == "[test]" => match body_it.next() { 50 Some(TokenTree::Ident(ident)) if ident.to_string() == "fn" => { 51 let test_name = match body_it.next() { 52 Some(TokenTree::Ident(ident)) => ident.to_string(), 53 _ => continue, 54 }; 55 tests.push(test_name); 56 } 57 _ => continue, 58 }, 59 _ => (), 60 } 61 } 62 63 // Add `#[cfg(CONFIG_KUNIT)]` before the module declaration. 64 let config_kunit = "#[cfg(CONFIG_KUNIT)]".to_owned().parse().unwrap(); 65 tokens.insert( 66 0, 67 TokenTree::Group(Group::new(Delimiter::None, config_kunit)), 68 ); 69 70 // Generate the test KUnit test suite and a test case for each `#[test]`. 71 // The code generated for the following test module: 72 // 73 // ``` 74 // #[kunit_tests(kunit_test_suit_name)] 75 // mod tests { 76 // #[test] 77 // fn foo() { 78 // assert_eq!(1, 1); 79 // } 80 // 81 // #[test] 82 // fn bar() { 83 // assert_eq!(2, 2); 84 // } 85 // } 86 // ``` 87 // 88 // Looks like: 89 // 90 // ``` 91 // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut kernel::bindings::kunit) { foo(); } 92 // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut kernel::bindings::kunit) { bar(); } 93 // 94 // static mut TEST_CASES: [kernel::bindings::kunit_case; 3] = [ 95 // kernel::kunit::kunit_case(kernel::c_str!("foo"), kunit_rust_wrapper_foo), 96 // kernel::kunit::kunit_case(kernel::c_str!("bar"), kunit_rust_wrapper_bar), 97 // kernel::kunit::kunit_case_null(), 98 // ]; 99 // 100 // kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); 101 // ``` 102 let mut kunit_macros = "".to_owned(); 103 let mut test_cases = "".to_owned(); 104 for test in &tests { 105 let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{}", test); 106 let kunit_wrapper = format!( 107 "unsafe extern \"C\" fn {}(_test: *mut kernel::bindings::kunit) {{ {}(); }}", 108 kunit_wrapper_fn_name, test 109 ); 110 writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); 111 writeln!( 112 test_cases, 113 " kernel::kunit::kunit_case(kernel::c_str!(\"{}\"), {}),", 114 test, kunit_wrapper_fn_name 115 ) 116 .unwrap(); 117 } 118 119 writeln!(kunit_macros).unwrap(); 120 writeln!( 121 kunit_macros, 122 "static mut TEST_CASES: [kernel::bindings::kunit_case; {}] = [\n{test_cases} kernel::kunit::kunit_case_null(),\n];", 123 tests.len() + 1 124 ) 125 .unwrap(); 126 127 writeln!( 128 kunit_macros, 129 "kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" 130 ) 131 .unwrap(); 132 133 // Remove the `#[test]` macros. 134 // We do this at a token level, in order to preserve span information. 135 let mut new_body = vec![]; 136 let mut body_it = body.stream().into_iter(); 137 138 while let Some(token) = body_it.next() { 139 match token { 140 TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() { 141 Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (), 142 Some(next) => { 143 new_body.extend([token, next]); 144 } 145 _ => { 146 new_body.push(token); 147 } 148 }, 149 _ => { 150 new_body.push(token); 151 } 152 } 153 } 154 155 let mut new_body = TokenStream::from_iter(new_body); 156 new_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); 157 158 tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body))); 159 160 tokens.into_iter().collect() 161 } 162