xref: /linux/rust/macros/zeroable.rs (revision 001821b0e79716c4e17c71d8e053a23599a7a508)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 use crate::helpers::{parse_generics, Generics};
4 use proc_macro::{TokenStream, TokenTree};
5 
6 pub(crate) fn derive(input: TokenStream) -> TokenStream {
7     let (
8         Generics {
9             impl_generics,
10             decl_generics: _,
11             ty_generics,
12         },
13         mut rest,
14     ) = parse_generics(input);
15     // This should be the body of the struct `{...}`.
16     let last = rest.pop();
17     // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
18     let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
19     // Are we inside of a generic where we want to add `Zeroable`?
20     let mut in_generic = !impl_generics.is_empty();
21     // Have we already inserted `Zeroable`?
22     let mut inserted = false;
23     // Level of `<>` nestings.
24     let mut nested = 0;
25     for tt in impl_generics {
26         match &tt {
27             // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
28             TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
29                 if in_generic && !inserted {
30                     new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
31                 }
32                 in_generic = true;
33                 inserted = false;
34                 new_impl_generics.push(tt);
35             }
36             // If we find `'`, then we are entering a lifetime.
37             TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
38                 in_generic = false;
39                 new_impl_generics.push(tt);
40             }
41             TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
42                 new_impl_generics.push(tt);
43                 if in_generic {
44                     new_impl_generics.extend(quote! { ::kernel::init::Zeroable + });
45                     inserted = true;
46                 }
47             }
48             TokenTree::Punct(p) if p.as_char() == '<' => {
49                 nested += 1;
50                 new_impl_generics.push(tt);
51             }
52             TokenTree::Punct(p) if p.as_char() == '>' => {
53                 assert!(nested > 0);
54                 nested -= 1;
55                 new_impl_generics.push(tt);
56             }
57             _ => new_impl_generics.push(tt),
58         }
59     }
60     assert_eq!(nested, 0);
61     if in_generic && !inserted {
62         new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
63     }
64     quote! {
65         ::kernel::__derive_zeroable!(
66             parse_input:
67                 @sig(#(#rest)*),
68                 @impl_generics(#(#new_impl_generics)*),
69                 @ty_generics(#(#ty_generics)*),
70                 @body(#last),
71         );
72     }
73 }
74