xref: /linux/rust/macros/kunit.rs (revision e2683c8868d03382da7e1ce8453b543a043066d1)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 //! Procedural macro to run KUnit tests using a user-space like syntax.
4 //!
5 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>
6 
7 use std::ffi::CString;
8 
9 use proc_macro2::TokenStream;
10 use quote::{
11     format_ident,
12     quote,
13     ToTokens, //
14 };
15 use syn::{
16     parse_quote,
17     Error,
18     Ident,
19     Item,
20     ItemMod,
21     LitCStr,
22     Result, //
23 };
24 
25 pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> {
26     if test_suite.to_string().len() > 255 {
27         return Err(Error::new_spanned(
28             test_suite,
29             "test suite names cannot exceed the maximum length of 255 bytes",
30         ));
31     }
32 
33     // We cannot handle modules that defer to another file (e.g. `mod foo;`).
34     let Some((module_brace, module_items)) = module.content.take() else {
35         Err(Error::new_spanned(
36             module,
37             "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules",
38         ))?
39     };
40 
41     // Make the entire module gated behind `CONFIG_KUNIT`.
42     module
43         .attrs
44         .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")]));
45 
46     let mut processed_items = Vec::new();
47     let mut test_cases = Vec::new();
48 
49     // Generate the test KUnit test suite and a test case for each `#[test]`.
50     //
51     // The code generated for the following test module:
52     //
53     // ```
54     // #[kunit_tests(kunit_test_suit_name)]
55     // mod tests {
56     //     #[test]
57     //     fn foo() {
58     //         assert_eq!(1, 1);
59     //     }
60     //
61     //     #[test]
62     //     fn bar() {
63     //         assert_eq!(2, 2);
64     //     }
65     // }
66     // ```
67     //
68     // Looks like:
69     //
70     // ```
71     // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); }
72     // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); }
73     //
74     // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [
75     //     ::kernel::kunit::kunit_case(c"foo", kunit_rust_wrapper_foo),
76     //     ::kernel::kunit::kunit_case(c"bar", kunit_rust_wrapper_bar),
77     //     ::pin_init::zeroed(),
78     // ];
79     //
80     // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
81     // ```
82     //
83     // Non-function items (e.g. imports) are preserved.
84     for item in module_items {
85         let Item::Fn(mut f) = item else {
86             processed_items.push(item);
87             continue;
88         };
89 
90         if f.attrs
91             .extract_if(.., |attr| attr.path().is_ident("test"))
92             .count()
93             == 0
94         {
95             processed_items.push(Item::Fn(f));
96             continue;
97         }
98 
99         let test = f.sig.ident.clone();
100 
101         // Retrieve `#[cfg]` applied on the function which needs to be present on derived items too.
102         let cfg_attrs: Vec<_> = f
103             .attrs
104             .iter()
105             .filter(|attr| attr.path().is_ident("cfg"))
106             .cloned()
107             .collect();
108 
109         // Before the test, override usual `assert!` and `assert_eq!` macros with ones that call
110         // KUnit instead.
111         let test_str = test.to_string();
112         let path = CString::new(crate::helpers::file()).expect("file path cannot contain NUL");
113         processed_items.push(parse_quote! {
114             #[allow(unused)]
115             macro_rules! assert {
116                 ($cond:expr $(,)?) => {{
117                     kernel::kunit_assert!(#test_str, #path, 0, $cond);
118                 }}
119             }
120         });
121         processed_items.push(parse_quote! {
122             #[allow(unused)]
123             macro_rules! assert_eq {
124                 ($left:expr, $right:expr $(,)?) => {{
125                     kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right);
126                 }}
127             }
128         });
129 
130         // Add back the test item.
131         processed_items.push(Item::Fn(f));
132 
133         let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}");
134         let test_cstr = LitCStr::new(
135             &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"),
136             test.span(),
137         );
138         processed_items.push(parse_quote! {
139             unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) {
140                 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
141 
142                 // Append any `cfg` attributes the user might have written on their tests so we
143                 // don't attempt to call them when they are `cfg`'d out. An extra `use` is used
144                 // here to reduce the length of the assert message.
145                 #(#cfg_attrs)*
146                 {
147                     (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
148                     use ::kernel::kunit::is_test_result_ok;
149                     assert!(is_test_result_ok(#test()));
150                 }
151             }
152         });
153 
154         test_cases.push(quote!(
155             ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name)
156         ));
157     }
158 
159     let num_tests_plus_1 = test_cases.len() + 1;
160     processed_items.push(parse_quote! {
161         static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [
162             #(#test_cases,)*
163             ::pin_init::zeroed(),
164         ];
165     });
166     processed_items.push(parse_quote! {
167         ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES);
168     });
169 
170     module.content = Some((module_brace, processed_items));
171     Ok(module.to_token_stream())
172 }
173