xref: /linux/rust/pin-init/internal/src/zeroable.rs (revision ec7714e4947909190ffb3041a03311a975350fe0)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #[cfg(not(kernel))]
4 use proc_macro2 as proc_macro;
5 
6 use crate::helpers::{parse_generics, Generics};
7 use proc_macro::{TokenStream, TokenTree};
8 
parse_zeroable_derive_input( input: TokenStream, ) -> ( Vec<TokenTree>, Vec<TokenTree>, Vec<TokenTree>, Option<TokenTree>, )9 pub(crate) fn parse_zeroable_derive_input(
10     input: TokenStream,
11 ) -> (
12     Vec<TokenTree>,
13     Vec<TokenTree>,
14     Vec<TokenTree>,
15     Option<TokenTree>,
16 ) {
17     let (
18         Generics {
19             impl_generics,
20             decl_generics: _,
21             ty_generics,
22         },
23         mut rest,
24     ) = parse_generics(input);
25     // This should be the body of the struct `{...}`.
26     let last = rest.pop();
27     // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
28     let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
29     // Are we inside of a generic where we want to add `Zeroable`?
30     let mut in_generic = !impl_generics.is_empty();
31     // Have we already inserted `Zeroable`?
32     let mut inserted = false;
33     // Level of `<>` nestings.
34     let mut nested = 0;
35     for tt in impl_generics {
36         match &tt {
37             // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
38             TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
39                 if in_generic && !inserted {
40                     new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
41                 }
42                 in_generic = true;
43                 inserted = false;
44                 new_impl_generics.push(tt);
45             }
46             // If we find `'`, then we are entering a lifetime.
47             TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
48                 in_generic = false;
49                 new_impl_generics.push(tt);
50             }
51             TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
52                 new_impl_generics.push(tt);
53                 if in_generic {
54                     new_impl_generics.extend(quote! { ::pin_init::Zeroable + });
55                     inserted = true;
56                 }
57             }
58             TokenTree::Punct(p) if p.as_char() == '<' => {
59                 nested += 1;
60                 new_impl_generics.push(tt);
61             }
62             TokenTree::Punct(p) if p.as_char() == '>' => {
63                 assert!(nested > 0);
64                 nested -= 1;
65                 new_impl_generics.push(tt);
66             }
67             _ => new_impl_generics.push(tt),
68         }
69     }
70     assert_eq!(nested, 0);
71     if in_generic && !inserted {
72         new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
73     }
74     (rest, new_impl_generics, ty_generics, last)
75 }
76 
derive(input: TokenStream) -> TokenStream77 pub(crate) fn derive(input: TokenStream) -> TokenStream {
78     let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input);
79     quote! {
80         ::pin_init::__derive_zeroable!(
81             parse_input:
82                 @sig(#(#rest)*),
83                 @impl_generics(#(#new_impl_generics)*),
84                 @ty_generics(#(#ty_generics)*),
85                 @body(#last),
86         );
87     }
88 }
89 
maybe_derive(input: TokenStream) -> TokenStream90 pub(crate) fn maybe_derive(input: TokenStream) -> TokenStream {
91     let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input);
92     quote! {
93         ::pin_init::__maybe_derive_zeroable!(
94             parse_input:
95                 @sig(#(#rest)*),
96                 @impl_generics(#(#new_impl_generics)*),
97                 @ty_generics(#(#ty_generics)*),
98                 @body(#last),
99         );
100     }
101 }
102