xref: /linux/rust/macros/kunit.rs (revision 07fdad3a93756b872da7b53647715c48d0f4a2d0)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 //! Procedural macro to run KUnit tests using a user-space like syntax.
4 //!
5 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>
6 
7 use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
8 use std::collections::HashMap;
9 use std::fmt::Write;
10 
11 pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
12     let attr = attr.to_string();
13 
14     if attr.is_empty() {
15         panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
16     }
17 
18     if attr.len() > 255 {
19         panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes")
20     }
21 
22     let mut tokens: Vec<_> = ts.into_iter().collect();
23 
24     // Scan for the `mod` keyword.
25     tokens
26         .iter()
27         .find_map(|token| match token {
28             TokenTree::Ident(ident) => match ident.to_string().as_str() {
29                 "mod" => Some(true),
30                 _ => None,
31             },
32             _ => None,
33         })
34         .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
35 
36     // Retrieve the main body. The main body should be the last token tree.
37     let body = match tokens.pop() {
38         Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
39         _ => panic!("Cannot locate main body of module"),
40     };
41 
42     // Get the functions set as tests. Search for `[test]` -> `fn`.
43     let mut body_it = body.stream().into_iter();
44     let mut tests = Vec::new();
45     let mut attributes: HashMap<String, TokenStream> = HashMap::new();
46     while let Some(token) = body_it.next() {
47         match token {
48             TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {
49                 Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
50                     if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {
51                         // Collect attributes because we need to find which are tests. We also
52                         // need to copy `cfg` attributes so tests can be conditionally enabled.
53                         attributes
54                             .entry(name.to_string())
55                             .or_default()
56                             .extend([token, TokenTree::Group(g)]);
57                     }
58                     continue;
59                 }
60                 _ => (),
61             },
62             TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => {
63                 if let Some(TokenTree::Ident(test_name)) = body_it.next() {
64                     tests.push((test_name, attributes.remove("cfg").unwrap_or_default()))
65                 }
66             }
67 
68             _ => (),
69         }
70         attributes.clear();
71     }
72 
73     // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration.
74     let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap();
75     tokens.insert(
76         0,
77         TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
78     );
79 
80     // Generate the test KUnit test suite and a test case for each `#[test]`.
81     // The code generated for the following test module:
82     //
83     // ```
84     // #[kunit_tests(kunit_test_suit_name)]
85     // mod tests {
86     //     #[test]
87     //     fn foo() {
88     //         assert_eq!(1, 1);
89     //     }
90     //
91     //     #[test]
92     //     fn bar() {
93     //         assert_eq!(2, 2);
94     //     }
95     // }
96     // ```
97     //
98     // Looks like:
99     //
100     // ```
101     // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); }
102     // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); }
103     //
104     // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [
105     //     ::kernel::kunit::kunit_case(::kernel::c_str!("foo"), kunit_rust_wrapper_foo),
106     //     ::kernel::kunit::kunit_case(::kernel::c_str!("bar"), kunit_rust_wrapper_bar),
107     //     ::kernel::kunit::kunit_case_null(),
108     // ];
109     //
110     // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
111     // ```
112     let mut kunit_macros = "".to_owned();
113     let mut test_cases = "".to_owned();
114     let mut assert_macros = "".to_owned();
115     let path = crate::helpers::file();
116     let num_tests = tests.len();
117     for (test, cfg_attr) in tests {
118         let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");
119         // Append any `cfg` attributes the user might have written on their tests so we don't
120         // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce
121         // the length of the assert message.
122         let kunit_wrapper = format!(
123             r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit)
124             {{
125                 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
126                 {cfg_attr} {{
127                     (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
128                     use ::kernel::kunit::is_test_result_ok;
129                     assert!(is_test_result_ok({test}()));
130                 }}
131             }}"#,
132         );
133         writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
134         writeln!(
135             test_cases,
136             "    ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name}),"
137         )
138         .unwrap();
139         writeln!(
140             assert_macros,
141             r#"
142 /// Overrides the usual [`assert!`] macro with one that calls KUnit instead.
143 #[allow(unused)]
144 macro_rules! assert {{
145     ($cond:expr $(,)?) => {{{{
146         kernel::kunit_assert!("{test}", "{path}", 0, $cond);
147     }}}}
148 }}
149 
150 /// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead.
151 #[allow(unused)]
152 macro_rules! assert_eq {{
153     ($left:expr, $right:expr $(,)?) => {{{{
154         kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right);
155     }}}}
156 }}
157         "#
158         )
159         .unwrap();
160     }
161 
162     writeln!(kunit_macros).unwrap();
163     writeln!(
164         kunit_macros,
165         "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases}    ::kernel::kunit::kunit_case_null(),\n];",
166         num_tests + 1
167     )
168     .unwrap();
169 
170     writeln!(
171         kunit_macros,
172         "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"
173     )
174     .unwrap();
175 
176     // Remove the `#[test]` macros.
177     // We do this at a token level, in order to preserve span information.
178     let mut new_body = vec![];
179     let mut body_it = body.stream().into_iter();
180 
181     while let Some(token) = body_it.next() {
182         match token {
183             TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {
184                 Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),
185                 Some(next) => {
186                     new_body.extend([token, next]);
187                 }
188                 _ => {
189                     new_body.push(token);
190                 }
191             },
192             _ => {
193                 new_body.push(token);
194             }
195         }
196     }
197 
198     let mut final_body = TokenStream::new();
199     final_body.extend::<TokenStream>(assert_macros.parse().unwrap());
200     final_body.extend(new_body);
201     final_body.extend::<TokenStream>(kunit_macros.parse().unwrap());
202 
203     tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body)));
204 
205     tokens.into_iter().collect()
206 }
207