xref: /linux/rust/macros/kunit.rs (revision f637bafe1ff15fa356c1e0576c32f077b9e6e46a)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 //! Procedural macro to run KUnit tests using a user-space like syntax.
4 //!
5 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>
6 
7 use std::collections::HashMap;
8 use std::fmt::Write;
9 
10 use proc_macro2::{Delimiter, Group, TokenStream, TokenTree};
11 
12 pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
13     let attr = attr.to_string();
14 
15     if attr.is_empty() {
16         panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
17     }
18 
19     if attr.len() > 255 {
20         panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes")
21     }
22 
23     let mut tokens: Vec<_> = ts.into_iter().collect();
24 
25     // Scan for the `mod` keyword.
26     tokens
27         .iter()
28         .find_map(|token| match token {
29             TokenTree::Ident(ident) => match ident.to_string().as_str() {
30                 "mod" => Some(true),
31                 _ => None,
32             },
33             _ => None,
34         })
35         .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
36 
37     // Retrieve the main body. The main body should be the last token tree.
38     let body = match tokens.pop() {
39         Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
40         _ => panic!("Cannot locate main body of module"),
41     };
42 
43     // Get the functions set as tests. Search for `[test]` -> `fn`.
44     let mut body_it = body.stream().into_iter();
45     let mut tests = Vec::new();
46     let mut attributes: HashMap<String, TokenStream> = HashMap::new();
47     while let Some(token) = body_it.next() {
48         match token {
49             TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {
50                 Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
51                     if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {
52                         // Collect attributes because we need to find which are tests. We also
53                         // need to copy `cfg` attributes so tests can be conditionally enabled.
54                         attributes
55                             .entry(name.to_string())
56                             .or_default()
57                             .extend([token, TokenTree::Group(g)]);
58                     }
59                     continue;
60                 }
61                 _ => (),
62             },
63             TokenTree::Ident(i) if i == "fn" && attributes.contains_key("test") => {
64                 if let Some(TokenTree::Ident(test_name)) = body_it.next() {
65                     tests.push((test_name, attributes.remove("cfg").unwrap_or_default()))
66                 }
67             }
68 
69             _ => (),
70         }
71         attributes.clear();
72     }
73 
74     // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration.
75     let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap();
76     tokens.insert(
77         0,
78         TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
79     );
80 
81     // Generate the test KUnit test suite and a test case for each `#[test]`.
82     // The code generated for the following test module:
83     //
84     // ```
85     // #[kunit_tests(kunit_test_suit_name)]
86     // mod tests {
87     //     #[test]
88     //     fn foo() {
89     //         assert_eq!(1, 1);
90     //     }
91     //
92     //     #[test]
93     //     fn bar() {
94     //         assert_eq!(2, 2);
95     //     }
96     // }
97     // ```
98     //
99     // Looks like:
100     //
101     // ```
102     // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); }
103     // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); }
104     //
105     // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [
106     //     ::kernel::kunit::kunit_case(c"foo", kunit_rust_wrapper_foo),
107     //     ::kernel::kunit::kunit_case(c"bar", kunit_rust_wrapper_bar),
108     //     ::kernel::kunit::kunit_case_null(),
109     // ];
110     //
111     // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
112     // ```
113     let mut kunit_macros = "".to_owned();
114     let mut test_cases = "".to_owned();
115     let mut assert_macros = "".to_owned();
116     let path = crate::helpers::file();
117     let num_tests = tests.len();
118     for (test, cfg_attr) in tests {
119         let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");
120         // Append any `cfg` attributes the user might have written on their tests so we don't
121         // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce
122         // the length of the assert message.
123         let kunit_wrapper = format!(
124             r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit)
125             {{
126                 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
127                 {cfg_attr} {{
128                     (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
129                     use ::kernel::kunit::is_test_result_ok;
130                     assert!(is_test_result_ok({test}()));
131                 }}
132             }}"#,
133         );
134         writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
135         writeln!(
136             test_cases,
137             "    ::kernel::kunit::kunit_case(c\"{test}\", {kunit_wrapper_fn_name}),"
138         )
139         .unwrap();
140         writeln!(
141             assert_macros,
142             r#"
143 /// Overrides the usual [`assert!`] macro with one that calls KUnit instead.
144 #[allow(unused)]
145 macro_rules! assert {{
146     ($cond:expr $(,)?) => {{{{
147         kernel::kunit_assert!("{test}", c"{path}", 0, $cond);
148     }}}}
149 }}
150 
151 /// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead.
152 #[allow(unused)]
153 macro_rules! assert_eq {{
154     ($left:expr, $right:expr $(,)?) => {{{{
155         kernel::kunit_assert_eq!("{test}", c"{path}", 0, $left, $right);
156     }}}}
157 }}
158         "#
159         )
160         .unwrap();
161     }
162 
163     writeln!(kunit_macros).unwrap();
164     writeln!(
165         kunit_macros,
166         "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases}    ::kernel::kunit::kunit_case_null(),\n];",
167         num_tests + 1
168     )
169     .unwrap();
170 
171     writeln!(
172         kunit_macros,
173         "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"
174     )
175     .unwrap();
176 
177     // Remove the `#[test]` macros.
178     // We do this at a token level, in order to preserve span information.
179     let mut new_body = vec![];
180     let mut body_it = body.stream().into_iter();
181 
182     while let Some(token) = body_it.next() {
183         match token {
184             TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {
185                 Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),
186                 Some(next) => {
187                     new_body.extend([token, next]);
188                 }
189                 _ => {
190                     new_body.push(token);
191                 }
192             },
193             _ => {
194                 new_body.push(token);
195             }
196         }
197     }
198 
199     let mut final_body = TokenStream::new();
200     final_body.extend::<TokenStream>(assert_macros.parse().unwrap());
201     final_body.extend(new_body);
202     final_body.extend::<TokenStream>(kunit_macros.parse().unwrap());
203 
204     tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body)));
205 
206     tokens.into_iter().collect()
207 }
208