1 // SPDX-License-Identifier: GPL-2.0 2 3 //! Procedural macro to run KUnit tests using a user-space like syntax. 4 //! 5 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com> 6 7 use std::ffi::CString; 8 9 use proc_macro2::TokenStream; 10 use quote::{ 11 format_ident, 12 quote, 13 ToTokens, // 14 }; 15 use syn::{ 16 parse_quote, 17 Error, 18 Ident, 19 Item, 20 ItemMod, 21 LitCStr, 22 Result, // 23 }; 24 25 pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> { 26 if test_suite.to_string().len() > 255 { 27 return Err(Error::new_spanned( 28 test_suite, 29 "test suite names cannot exceed the maximum length of 255 bytes", 30 )); 31 } 32 33 // We cannot handle modules that defer to another file (e.g. `mod foo;`). 34 let Some((module_brace, module_items)) = module.content.take() else { 35 Err(Error::new_spanned( 36 module, 37 "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules", 38 ))? 39 }; 40 41 // Make the entire module gated behind `CONFIG_KUNIT`. 42 module 43 .attrs 44 .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")])); 45 46 let mut processed_items = Vec::new(); 47 let mut test_cases = Vec::new(); 48 49 // Generate the test KUnit test suite and a test case for each `#[test]`. 50 // 51 // The code generated for the following test module: 52 // 53 // ``` 54 // #[kunit_tests(kunit_test_suit_name)] 55 // mod tests { 56 // #[test] 57 // fn foo() { 58 // assert_eq!(1, 1); 59 // } 60 // 61 // #[test] 62 // fn bar() { 63 // assert_eq!(2, 2); 64 // } 65 // } 66 // ``` 67 // 68 // Looks like: 69 // 70 // ``` 71 // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); } 72 // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); } 73 // 74 // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [ 75 // ::kernel::kunit::kunit_case(c"foo", kunit_rust_wrapper_foo), 76 // ::kernel::kunit::kunit_case(c"bar", kunit_rust_wrapper_bar), 77 // ::pin_init::zeroed(), 78 // ]; 79 // 80 // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); 81 // ``` 82 // 83 // Non-function items (e.g. imports) are preserved. 84 for item in module_items { 85 let Item::Fn(mut f) = item else { 86 processed_items.push(item); 87 continue; 88 }; 89 90 if f.attrs 91 .extract_if(.., |attr| attr.path().is_ident("test")) 92 .count() 93 == 0 94 { 95 processed_items.push(Item::Fn(f)); 96 continue; 97 } 98 99 let test = f.sig.ident.clone(); 100 101 // Retrieve `#[cfg]` applied on the function which needs to be present on derived items too. 102 let cfg_attrs: Vec<_> = f 103 .attrs 104 .iter() 105 .filter(|attr| attr.path().is_ident("cfg")) 106 .cloned() 107 .collect(); 108 109 // Before the test, override usual `assert!` and `assert_eq!` macros with ones that call 110 // KUnit instead. 111 let test_str = test.to_string(); 112 let path = CString::new(crate::helpers::file()).expect("file path cannot contain NUL"); 113 processed_items.push(parse_quote! { 114 #[allow(unused)] 115 macro_rules! assert { 116 ($cond:expr $(,)?) => {{ 117 kernel::kunit_assert!(#test_str, #path, 0, $cond); 118 }} 119 } 120 }); 121 processed_items.push(parse_quote! { 122 #[allow(unused)] 123 macro_rules! assert_eq { 124 ($left:expr, $right:expr $(,)?) => {{ 125 kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right); 126 }} 127 } 128 }); 129 130 // Add back the test item. 131 processed_items.push(Item::Fn(f)); 132 133 let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}"); 134 let test_cstr = LitCStr::new( 135 &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"), 136 test.span(), 137 ); 138 processed_items.push(parse_quote! { 139 unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) { 140 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; 141 142 // Append any `cfg` attributes the user might have written on their tests so we 143 // don't attempt to call them when they are `cfg`'d out. An extra `use` is used 144 // here to reduce the length of the assert message. 145 #(#cfg_attrs)* 146 { 147 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; 148 use ::kernel::kunit::is_test_result_ok; 149 assert!(is_test_result_ok(#test())); 150 } 151 } 152 }); 153 154 test_cases.push(quote!( 155 ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name) 156 )); 157 } 158 159 let num_tests_plus_1 = test_cases.len() + 1; 160 processed_items.push(parse_quote! { 161 static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [ 162 #(#test_cases,)* 163 ::pin_init::zeroed(), 164 ]; 165 }); 166 processed_items.push(parse_quote! { 167 ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES); 168 }); 169 170 module.content = Some((module_brace, processed_items)); 171 Ok(module.to_token_stream()) 172 } 173