xref: /linux/rust/macros/zeroable.rs (revision 621cde16e49b3ecf7d59a8106a20aaebfb4a59a9)
1071cedc8SBenno Lossin // SPDX-License-Identifier: GPL-2.0
2071cedc8SBenno Lossin 
3071cedc8SBenno Lossin use crate::helpers::{parse_generics, Generics};
4071cedc8SBenno Lossin use proc_macro::{TokenStream, TokenTree};
5071cedc8SBenno Lossin 
derive(input: TokenStream) -> TokenStream6071cedc8SBenno Lossin pub(crate) fn derive(input: TokenStream) -> TokenStream {
7071cedc8SBenno Lossin     let (
8071cedc8SBenno Lossin         Generics {
9071cedc8SBenno Lossin             impl_generics,
10*9762dca5SBenno Lossin             decl_generics: _,
11071cedc8SBenno Lossin             ty_generics,
12071cedc8SBenno Lossin         },
13071cedc8SBenno Lossin         mut rest,
14071cedc8SBenno Lossin     ) = parse_generics(input);
15071cedc8SBenno Lossin     // This should be the body of the struct `{...}`.
16071cedc8SBenno Lossin     let last = rest.pop();
17071cedc8SBenno Lossin     // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
18071cedc8SBenno Lossin     let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
19071cedc8SBenno Lossin     // Are we inside of a generic where we want to add `Zeroable`?
20071cedc8SBenno Lossin     let mut in_generic = !impl_generics.is_empty();
21071cedc8SBenno Lossin     // Have we already inserted `Zeroable`?
22071cedc8SBenno Lossin     let mut inserted = false;
23071cedc8SBenno Lossin     // Level of `<>` nestings.
24071cedc8SBenno Lossin     let mut nested = 0;
25071cedc8SBenno Lossin     for tt in impl_generics {
26071cedc8SBenno Lossin         match &tt {
27071cedc8SBenno Lossin             // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
28071cedc8SBenno Lossin             TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
29071cedc8SBenno Lossin                 if in_generic && !inserted {
30071cedc8SBenno Lossin                     new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
31071cedc8SBenno Lossin                 }
32071cedc8SBenno Lossin                 in_generic = true;
33071cedc8SBenno Lossin                 inserted = false;
34071cedc8SBenno Lossin                 new_impl_generics.push(tt);
35071cedc8SBenno Lossin             }
36071cedc8SBenno Lossin             // If we find `'`, then we are entering a lifetime.
37071cedc8SBenno Lossin             TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
38071cedc8SBenno Lossin                 in_generic = false;
39071cedc8SBenno Lossin                 new_impl_generics.push(tt);
40071cedc8SBenno Lossin             }
41071cedc8SBenno Lossin             TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
42071cedc8SBenno Lossin                 new_impl_generics.push(tt);
43071cedc8SBenno Lossin                 if in_generic {
44071cedc8SBenno Lossin                     new_impl_generics.extend(quote! { ::kernel::init::Zeroable + });
45071cedc8SBenno Lossin                     inserted = true;
46071cedc8SBenno Lossin                 }
47071cedc8SBenno Lossin             }
48071cedc8SBenno Lossin             TokenTree::Punct(p) if p.as_char() == '<' => {
49071cedc8SBenno Lossin                 nested += 1;
50071cedc8SBenno Lossin                 new_impl_generics.push(tt);
51071cedc8SBenno Lossin             }
52071cedc8SBenno Lossin             TokenTree::Punct(p) if p.as_char() == '>' => {
53071cedc8SBenno Lossin                 assert!(nested > 0);
54071cedc8SBenno Lossin                 nested -= 1;
55071cedc8SBenno Lossin                 new_impl_generics.push(tt);
56071cedc8SBenno Lossin             }
57071cedc8SBenno Lossin             _ => new_impl_generics.push(tt),
58071cedc8SBenno Lossin         }
59071cedc8SBenno Lossin     }
60071cedc8SBenno Lossin     assert_eq!(nested, 0);
61071cedc8SBenno Lossin     if in_generic && !inserted {
62071cedc8SBenno Lossin         new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
63071cedc8SBenno Lossin     }
64071cedc8SBenno Lossin     quote! {
65071cedc8SBenno Lossin         ::kernel::__derive_zeroable!(
66071cedc8SBenno Lossin             parse_input:
67071cedc8SBenno Lossin                 @sig(#(#rest)*),
68071cedc8SBenno Lossin                 @impl_generics(#(#new_impl_generics)*),
69071cedc8SBenno Lossin                 @ty_generics(#(#ty_generics)*),
70071cedc8SBenno Lossin                 @body(#last),
71071cedc8SBenno Lossin         );
72071cedc8SBenno Lossin     }
73071cedc8SBenno Lossin }
74