xref: /linux/rust/pin-init/internal/src/helpers.rs (revision 84837cf6fa541150a3012ea233225a7ecfa8771a)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
/// The generic parameters of an item, split into the three forms needed when generating code.
///
/// Every field is a flat list of tokens. The surrounding `<` and `>` are written by the caller
/// around the interpolated tokens, as in the example below. The field docs explain where each
/// form belongs.
///
/// # Examples
///
/// ```rust,ignore
/// # let input = todo!();
/// let (Generics { decl_generics, impl_generics, ty_generics }, rest) = parse_generics(input);
/// quote! {
///     struct Foo<$($decl_generics)*> {
///         // ...
///     }
///
///     impl<$impl_generics> Foo<$ty_generics> {
///         fn foo() {
///             // ...
///         }
///     }
/// }
/// ```
pub(crate) struct Generics {
    /// The complete parameter list, including bounds and default values
    /// (e.g. `T: Clone, const N: usize = 0`).
    ///
    /// Meant for the definition of the item itself, e.g. `struct Foo<$decl_generics> ...`
    /// (the same holds for `union` and `enum` definitions).
    pub(crate) decl_generics: Vec<TokenTree>,
    /// The parameter list with bounds but with the default values removed
    /// (e.g. `T: Clone, const N: usize`).
    ///
    /// Meant for the parameter list of an `impl` block, e.g. `impl<$impl_generics> Trait for ...`.
    pub(crate) impl_generics: Vec<TokenTree>,
    /// Just the parameter names, with neither bounds nor default values (e.g. `T, N`).
    ///
    /// Meant for naming the declared type with its parameters applied, e.g. `Foo<$ty_generics>`.
    pub(crate) ty_generics: Vec<TokenTree>,
}
39 
40 /// Parses the given `TokenStream` into `Generics` and the rest.
41 ///
42 /// The generics are not present in the rest, but a where clause might remain.
43 pub(crate) fn parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>) {
44     // The generics with bounds and default values.
45     let mut decl_generics = vec![];
46     // `impl_generics`, the declared generics with their bounds.
47     let mut impl_generics = vec![];
48     // Only the names of the generics, without any bounds.
49     let mut ty_generics = vec![];
50     // Tokens not related to the generics e.g. the `where` token and definition.
51     let mut rest = vec![];
52     // The current level of `<`.
53     let mut nesting = 0;
54     let mut toks = input.into_iter();
55     // If we are at the beginning of a generic parameter.
56     let mut at_start = true;
57     let mut skip_until_comma = false;
58     while let Some(tt) = toks.next() {
59         if nesting == 1 && matches!(&tt, TokenTree::Punct(p) if p.as_char() == '>') {
60             // Found the end of the generics.
61             break;
62         } else if nesting >= 1 {
63             decl_generics.push(tt.clone());
64         }
65         match tt.clone() {
66             TokenTree::Punct(p) if p.as_char() == '<' => {
67                 if nesting >= 1 && !skip_until_comma {
68                     // This is inside of the generics and part of some bound.
69                     impl_generics.push(tt);
70                 }
71                 nesting += 1;
72             }
73             TokenTree::Punct(p) if p.as_char() == '>' => {
74                 // This is a parsing error, so we just end it here.
75                 if nesting == 0 {
76                     break;
77                 } else {
78                     nesting -= 1;
79                     if nesting >= 1 && !skip_until_comma {
80                         // We are still inside of the generics and part of some bound.
81                         impl_generics.push(tt);
82                     }
83                 }
84             }
85             TokenTree::Punct(p) if skip_until_comma && p.as_char() == ',' => {
86                 if nesting == 1 {
87                     impl_generics.push(tt.clone());
88                     impl_generics.push(tt);
89                     skip_until_comma = false;
90                 }
91             }
92             _ if !skip_until_comma => {
93                 match nesting {
94                     // If we haven't entered the generics yet, we still want to keep these tokens.
95                     0 => rest.push(tt),
96                     1 => {
97                         // Here depending on the token, it might be a generic variable name.
98                         match tt.clone() {
99                             TokenTree::Ident(i) if at_start && i.to_string() == "const" => {
100                                 let Some(name) = toks.next() else {
101                                     // Parsing error.
102                                     break;
103                                 };
104                                 impl_generics.push(tt);
105                                 impl_generics.push(name.clone());
106                                 ty_generics.push(name.clone());
107                                 decl_generics.push(name);
108                                 at_start = false;
109                             }
110                             TokenTree::Ident(_) if at_start => {
111                                 impl_generics.push(tt.clone());
112                                 ty_generics.push(tt);
113                                 at_start = false;
114                             }
115                             TokenTree::Punct(p) if p.as_char() == ',' => {
116                                 impl_generics.push(tt.clone());
117                                 ty_generics.push(tt);
118                                 at_start = true;
119                             }
120                             // Lifetimes begin with `'`.
121                             TokenTree::Punct(p) if p.as_char() == '\'' && at_start => {
122                                 impl_generics.push(tt.clone());
123                                 ty_generics.push(tt);
124                             }
125                             // Generics can have default values, we skip these.
126                             TokenTree::Punct(p) if p.as_char() == '=' => {
127                                 skip_until_comma = true;
128                             }
129                             _ => impl_generics.push(tt),
130                         }
131                     }
132                     _ => impl_generics.push(tt),
133                 }
134             }
135             _ => {}
136         }
137     }
138     rest.extend(toks);
139     (
140         Generics {
141             impl_generics,
142             decl_generics,
143             ty_generics,
144         },
145         rest,
146     )
147 }
148