xref: /linux/rust/pin-init/internal/src/zeroable.rs (revision 514e4ed2c9da9112fcd378b1bd9b9d120edf7ca9)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 use crate::helpers::{parse_generics, Generics};
4 use proc_macro2::{TokenStream, TokenTree};
5 use quote::quote;
6 
7 pub(crate) fn parse_zeroable_derive_input(
8     input: TokenStream,
9 ) -> (
10     Vec<TokenTree>,
11     Vec<TokenTree>,
12     Vec<TokenTree>,
13     Option<TokenTree>,
14 ) {
15     let (
16         Generics {
17             impl_generics,
18             decl_generics: _,
19             ty_generics,
20         },
21         mut rest,
22     ) = parse_generics(input);
23     // This should be the body of the struct `{...}`.
24     let last = rest.pop();
25     // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
26     let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
27     // Are we inside of a generic where we want to add `Zeroable`?
28     let mut in_generic = !impl_generics.is_empty();
29     // Have we already inserted `Zeroable`?
30     let mut inserted = false;
31     // Level of `<>` nestings.
32     let mut nested = 0;
33     for tt in impl_generics {
34         match &tt {
35             // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
36             TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
37                 if in_generic && !inserted {
38                     new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
39                 }
40                 in_generic = true;
41                 inserted = false;
42                 new_impl_generics.push(tt);
43             }
44             // If we find `'`, then we are entering a lifetime.
45             TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
46                 in_generic = false;
47                 new_impl_generics.push(tt);
48             }
49             TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
50                 new_impl_generics.push(tt);
51                 if in_generic {
52                     new_impl_generics.extend(quote! { ::pin_init::Zeroable + });
53                     inserted = true;
54                 }
55             }
56             TokenTree::Punct(p) if p.as_char() == '<' => {
57                 nested += 1;
58                 new_impl_generics.push(tt);
59             }
60             TokenTree::Punct(p) if p.as_char() == '>' => {
61                 assert!(nested > 0);
62                 nested -= 1;
63                 new_impl_generics.push(tt);
64             }
65             _ => new_impl_generics.push(tt),
66         }
67     }
68     assert_eq!(nested, 0);
69     if in_generic && !inserted {
70         new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
71     }
72     (rest, new_impl_generics, ty_generics, last)
73 }
74 
75 pub(crate) fn derive(input: TokenStream) -> TokenStream {
76     let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input);
77     quote! {
78         ::pin_init::__derive_zeroable!(
79             parse_input:
80                 @sig(#(#rest)*),
81                 @impl_generics(#(#new_impl_generics)*),
82                 @ty_generics(#(#ty_generics)*),
83                 @body(#last),
84         );
85     }
86 }
87 
88 pub(crate) fn maybe_derive(input: TokenStream) -> TokenStream {
89     let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input);
90     quote! {
91         ::pin_init::__maybe_derive_zeroable!(
92             parse_input:
93                 @sig(#(#rest)*),
94                 @impl_generics(#(#new_impl_generics)*),
95                 @ty_generics(#(#ty_generics)*),
96                 @body(#last),
97         );
98     }
99 }
100