xref: /linux/rust/macros/kunit.rs (revision 7a9b709e7cc5ce1ffb84ce07bf6d157e1de758df)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 //! Procedural macro to run KUnit tests using a user-space like syntax.
4 //!
5 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>
6 
7 use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
8 use std::fmt::Write;
9 
/// Implements the `#[kunit_tests(suite_name)]` attribute.
///
/// `attr` holds the attribute arguments (the suite name) and `ts` the annotated item, which must
/// be a module. The module is returned behind `#[cfg(CONFIG_KUNIT)]`, with its `#[test]`
/// attributes removed and with the KUnit glue code produced by `generate_kunit_glue` appended
/// to its body.
///
/// # Panics
///
/// A panic in a procedural macro is reported as a compile error. This panics if the suite name
/// is missing or longer than 255 bytes, if `ts` does not contain the `mod` keyword among its
/// top-level tokens, or if the last token tree of `ts` is not a brace-delimited group.
pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
    let suite_name = attr.to_string();
    validate_suite_name(&suite_name);

    let mut tokens: Vec<TokenTree> = ts.into_iter().collect();

    // Only top-level tokens are inspected: attribute and body groups are not descended into,
    // so a `mod` nested inside them does not count.
    let is_module = tokens
        .iter()
        .any(|token| matches!(token, TokenTree::Ident(ident) if ident.to_string() == "mod"));
    if !is_module {
        panic!("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
    }

    // The module body is expected to be the last token tree of the item.
    let body = match tokens.pop() {
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
        _ => panic!("Cannot locate main body of module"),
    };

    let tests = collect_test_names(&body);

    // Only build the test module when KUnit is enabled. The attribute tokens are wrapped in an
    // invisible (`Delimiter::None`) group so they can be inserted as a single token tree.
    let config_kunit: TokenStream = "#[cfg(CONFIG_KUNIT)]".parse().unwrap();
    tokens.insert(
        0,
        TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
    );

    // Rebuild the body: the original items (minus `#[test]`) followed by the generated glue.
    let glue: TokenStream = generate_kunit_glue(&suite_name, &tests).parse().unwrap();
    let new_body: TokenStream = strip_test_attributes(&body)
        .into_iter()
        .chain(glue)
        .collect();
    tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body)));

    tokens.into_iter().collect()
}

/// Checks that `name`, the argument of `#[kunit_tests(...)]`, is usable as a KUnit suite name.
///
/// # Panics
///
/// Panics if `name` is empty or longer than 255 bytes.
fn validate_suite_name(name: &str) {
    if name.is_empty() {
        panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
    }

    // NOTE(review): 255 presumably leaves room for the NUL terminator of KUnit's fixed-size
    // suite name buffer — confirm against `struct kunit_suite`.
    if name.len() > 255 {
        panic!(
            "The test suite name `{}` exceeds the maximum length of 255 bytes",
            name
        )
    }
}

/// Returns the names of the test functions declared directly in `body`.
///
/// A function is recognised as a test when a `[test]` bracket group (the group of a `#[test]`
/// attribute) is immediately followed by `fn <name>`. Nested items are not inspected.
///
/// NOTE(review): if anything else follows `#[test]` (another attribute, `pub`, ...), the item is
/// not registered, yet `strip_test_attributes` still removes its `#[test]`, so that test is
/// silently never run.
fn collect_test_names(body: &Group) -> Vec<String> {
    let mut tests = Vec::new();
    let mut body_it = body.stream().into_iter();

    while let Some(token) = body_it.next() {
        let is_test_attr =
            matches!(&token, TokenTree::Group(group) if group.to_string() == "[test]");
        if !is_test_attr {
            continue;
        }

        // The tokens consumed here are not re-examined by the outer loop.
        match body_it.next() {
            Some(TokenTree::Ident(keyword)) if keyword.to_string() == "fn" => {
                if let Some(TokenTree::Ident(name)) = body_it.next() {
                    tests.push(name.to_string());
                }
            }
            _ => {}
        }
    }

    tests
}

/// Generates the KUnit glue code for the functions `tests` of the suite `suite_name`.
///
/// The result is Rust source text: one `extern "C"` wrapper per test, the `TEST_CASES` array,
/// and the `kernel::kunit_unsafe_test_suite!` invocation registering the suite.
fn generate_kunit_glue(suite_name: &str, tests: &[String]) -> String {
    // For the following test module:
    //
    // ```
    // #[kunit_tests(kunit_test_suite_name)]
    // mod tests {
    //     #[test]
    //     fn foo() {
    //         assert_eq!(1, 1);
    //     }
    //
    //     #[test]
    //     fn bar() {
    //         assert_eq!(2, 2);
    //     }
    // }
    // ```
    //
    // the generated code looks like:
    //
    // ```
    // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut kernel::bindings::kunit) { foo(); }
    // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut kernel::bindings::kunit) { bar(); }
    //
    // static mut TEST_CASES: [kernel::bindings::kunit_case; 3] = [
    //     kernel::kunit::kunit_case(kernel::c_str!("foo"), kunit_rust_wrapper_foo),
    //     kernel::kunit::kunit_case(kernel::c_str!("bar"), kunit_rust_wrapper_bar),
    //     kernel::kunit::kunit_case_null(),
    // ];
    //
    // kernel::kunit_unsafe_test_suite!(kunit_test_suite_name, TEST_CASES);
    // ```
    let mut kunit_macros = String::new();
    let mut test_cases = String::new();

    for test in tests {
        // A raw identifier such as `r#type` is a valid function name, but `r#` cannot appear in
        // the middle of an identifier and is not part of the name itself. So the bare name is
        // used for the wrapper and the KUnit case name, while the call keeps the raw form.
        let bare_name = test.strip_prefix("r#").unwrap_or(test.as_str());
        let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{bare_name}");

        writeln!(
            kunit_macros,
            "unsafe extern \"C\" fn {kunit_wrapper_fn_name}(_test: *mut kernel::bindings::kunit) {{ {test}(); }}"
        )
        .unwrap();
        writeln!(
            test_cases,
            "    kernel::kunit::kunit_case(kernel::c_str!(\"{bare_name}\"), {kunit_wrapper_fn_name}),"
        )
        .unwrap();
    }

    // The array ends with a `kunit_case_null()` sentinel entry (presumably the terminator KUnit
    // expects), hence the extra slot.
    let num_cases = tests.len() + 1;
    writeln!(kunit_macros).unwrap();
    writeln!(
        kunit_macros,
        "static mut TEST_CASES: [kernel::bindings::kunit_case; {num_cases}] = [\n{test_cases}    kernel::kunit::kunit_case_null(),\n];"
    )
    .unwrap();

    writeln!(
        kunit_macros,
        "kernel::kunit_unsafe_test_suite!({suite_name}, TEST_CASES);"
    )
    .unwrap();

    kunit_macros
}

/// Returns the tokens of `body` with every `#[test]` attribute (`#` followed by a `[test]`
/// group) removed.
///
/// This works at the token level so that every remaining token keeps its original span.
fn strip_test_attributes(body: &Group) -> Vec<TokenTree> {
    let mut stripped = Vec::new();
    let mut body_it = body.stream().into_iter();

    while let Some(token) = body_it.next() {
        match token {
            TokenTree::Punct(ref punct) if punct.as_char() == '#' => match body_it.next() {
                // Drop both the `#` and its `[test]` group.
                Some(TokenTree::Group(group)) if group.to_string() == "[test]" => {}
                Some(next) => stripped.extend([token, next]),
                None => stripped.push(token),
            },
            _ => stripped.push(token),
        }
    }

    stripped
}
162