xref: /linux/rust/macros/kunit.rs (revision 23b0f90ba871f096474e1c27c3d14f455189d2d9)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 //! Procedural macro to run KUnit tests using a user-space like syntax.
4 //!
5 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>
6 
7 use std::ffi::CString;
8 
9 use proc_macro2::TokenStream;
10 use quote::{
11     format_ident,
12     quote,
13     ToTokens, //
14 };
15 use syn::{
16     parse_quote,
17     Error,
18     Ident,
19     Item,
20     ItemMod,
21     LitCStr,
22     Result, //
23 };
24 
25 pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> {
26     if test_suite.to_string().len() > 255 {
27         return Err(Error::new_spanned(
28             test_suite,
29             "test suite names cannot exceed the maximum length of 255 bytes",
30         ));
31     }
32 
33     // We cannot handle modules that defer to another file (e.g. `mod foo;`).
34     let Some((module_brace, module_items)) = module.content.take() else {
35         Err(Error::new_spanned(
36             module,
37             "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules",
38         ))?
39     };
40 
41     // Make the entire module gated behind `CONFIG_KUNIT`.
42     module
43         .attrs
44         .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")]));
45 
46     let mut processed_items = Vec::new();
47     let mut test_cases = Vec::new();
48 
49     // Generate the test KUnit test suite and a test case for each `#[test]`.
50     //
51     // The code generated for the following test module:
52     //
53     // ```
54     // #[kunit_tests(kunit_test_suit_name)]
55     // mod tests {
56     //     #[test]
57     //     fn foo() {
58     //         assert_eq!(1, 1);
59     //     }
60     //
61     //     #[test]
62     //     fn bar() {
63     //         assert_eq!(2, 2);
64     //     }
65     // }
66     // ```
67     //
68     // Looks like:
69     //
70     // ```
71     // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); }
72     // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); }
73     //
74     // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [
75     //     ::kernel::kunit::kunit_case(c"foo", kunit_rust_wrapper_foo),
76     //     ::kernel::kunit::kunit_case(c"bar", kunit_rust_wrapper_bar),
77     //     ::pin_init::zeroed(),
78     // ];
79     //
80     // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
81     // ```
82     //
83     // Non-function items (e.g. imports) are preserved.
84     for item in module_items {
85         let Item::Fn(mut f) = item else {
86             processed_items.push(item);
87             continue;
88         };
89 
90         // TODO: Replace below with `extract_if` when MSRV is bumped above 1.85.
91         let before_len = f.attrs.len();
92         f.attrs.retain(|attr| !attr.path().is_ident("test"));
93         if f.attrs.len() == before_len {
94             processed_items.push(Item::Fn(f));
95             continue;
96         }
97 
98         let test = f.sig.ident.clone();
99 
100         // Retrieve `#[cfg]` applied on the function which needs to be present on derived items too.
101         let cfg_attrs: Vec<_> = f
102             .attrs
103             .iter()
104             .filter(|attr| attr.path().is_ident("cfg"))
105             .cloned()
106             .collect();
107 
108         // Before the test, override usual `assert!` and `assert_eq!` macros with ones that call
109         // KUnit instead.
110         let test_str = test.to_string();
111         let path = CString::new(crate::helpers::file()).expect("file path cannot contain NUL");
112         processed_items.push(parse_quote! {
113             #[allow(unused)]
114             macro_rules! assert {
115                 ($cond:expr $(,)?) => {{
116                     kernel::kunit_assert!(#test_str, #path, 0, $cond);
117                 }}
118             }
119         });
120         processed_items.push(parse_quote! {
121             #[allow(unused)]
122             macro_rules! assert_eq {
123                 ($left:expr, $right:expr $(,)?) => {{
124                     kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right);
125                 }}
126             }
127         });
128 
129         // Add back the test item.
130         processed_items.push(Item::Fn(f));
131 
132         let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}");
133         let test_cstr = LitCStr::new(
134             &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"),
135             test.span(),
136         );
137         processed_items.push(parse_quote! {
138             unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) {
139                 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
140 
141                 // Append any `cfg` attributes the user might have written on their tests so we
142                 // don't attempt to call them when they are `cfg`'d out. An extra `use` is used
143                 // here to reduce the length of the assert message.
144                 #(#cfg_attrs)*
145                 {
146                     (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
147                     use ::kernel::kunit::is_test_result_ok;
148                     assert!(is_test_result_ok(#test()));
149                 }
150             }
151         });
152 
153         test_cases.push(quote!(
154             ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name)
155         ));
156     }
157 
158     let num_tests_plus_1 = test_cases.len() + 1;
159     processed_items.push(parse_quote! {
160         static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [
161             #(#test_cases,)*
162             ::pin_init::zeroed(),
163         ];
164     });
165     processed_items.push(parse_quote! {
166         ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES);
167     });
168 
169     module.content = Some((module_brace, processed_items));
170     Ok(module.to_token_stream())
171 }
172