xref: /linux/rust/macros/for_lt.rs (revision e189bdb687a56bcf389798f1d3a2f261fff2ef54)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::{
4     Span,
5     TokenStream, //
6 };
7 use quote::{
8     format_ident,
9     quote, //
10 };
11 use syn::{
12     parse::{
13         Parse,
14         ParseStream, //
15     },
16     visit::Visit,
17     visit_mut::VisitMut,
18     Lifetime,
19     Result,
20     Token,
21     Type, //
22 };
23 
/// Input accepted by the `ForLt!` macro: a type, optionally preceded by a single-lifetime
/// `for<'lt>` binder.
pub(crate) enum HigherRankedType {
    /// `for<'lt> Type`: the lifetime the type is generic over is named explicitly.
    Explicit {
        _for_token: Token![for],
        _lt_token: Token![<],
        lifetime: Lifetime,
        _gt_token: Token![>],
        ty: Type,
    },
    /// A bare `Type` with no binder; `for_lt` binds its elided lifetime sites to a synthetic
    /// lifetime.
    Implicit {
        ty: Type,
    },
}
35 }
36 
37 impl Parse for HigherRankedType {
38     fn parse(input: ParseStream<'_>) -> Result<Self> {
39         if input.peek(Token![for]) {
40             Ok(Self::Explicit {
41                 _for_token: input.parse()?,
42                 _lt_token: input.parse()?,
43                 lifetime: input.parse()?,
44                 _gt_token: input.parse()?,
45                 ty: input.parse()?,
46             })
47         } else {
48             Ok(Self::Implicit { ty: input.parse()? })
49         }
50     }
51 }
52 
/// Lifetime-related helpers on `syn::Type` used by the `ForLt!` expansion.
trait TypeExt {
    /// Returns a copy of the type in which explicit `'_` lifetimes and references written
    /// without a lifetime use `explicit_lt` instead.
    fn expand_elided_lifetime(&self, explicit_lt: &Lifetime) -> Type;
    /// Returns a copy of the type with every lifetime named like `src` replaced by `dst`.
    fn replace_lifetime(&self, src: &Lifetime, dst: &Lifetime) -> Type;
    /// Returns whether `lt` may occur in the type (conservatively `true` for opaque parts).
    fn has_lifetime(&self, lt: &Lifetime) -> bool;
}
57 }
58 
59 impl TypeExt for Type {
60     fn expand_elided_lifetime(&self, explicit_lt: &Lifetime) -> Type {
61         struct ElidedLifetimeExpander<'a>(&'a Lifetime);
62 
63         impl VisitMut for ElidedLifetimeExpander<'_> {
64             fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) {
65                 // Expand explicit `'_`
66                 if lifetime.ident == "_" {
67                     *lifetime = self.0.clone();
68                 }
69             }
70 
71             fn visit_type_reference_mut(&mut self, reference: &mut syn::TypeReference) {
72                 syn::visit_mut::visit_type_reference_mut(self, reference);
73 
74                 if reference.lifetime.is_none() {
75                     reference.lifetime = Some(self.0.clone());
76                 }
77             }
78         }
79 
80         let mut ret = self.clone();
81         ElidedLifetimeExpander(explicit_lt).visit_type_mut(&mut ret);
82         ret
83     }
84 
85     fn replace_lifetime(&self, src: &Lifetime, dst: &Lifetime) -> Type {
86         struct LifetimeReplacer<'a>(&'a Lifetime, &'a Lifetime);
87 
88         impl VisitMut for LifetimeReplacer<'_> {
89             fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) {
90                 if lifetime.ident == self.0.ident {
91                     *lifetime = self.1.clone();
92                 }
93             }
94         }
95 
96         let mut ret = self.clone();
97         LifetimeReplacer(src, dst).visit_type_mut(&mut ret);
98         ret
99     }
100 
101     fn has_lifetime(&self, lt: &Lifetime) -> bool {
102         struct HasLifetime<'a>(&'a Lifetime, bool);
103 
104         impl Visit<'_> for HasLifetime<'_> {
105             fn visit_lifetime(&mut self, lifetime: &Lifetime) {
106                 if lifetime.ident == self.0.ident {
107                     self.1 = true;
108                 }
109             }
110 
111             // Macro invocations are opaque; conservatively assume they may
112             // reference the lifetime.
113             fn visit_macro(&mut self, _: &syn::Macro) {
114                 self.1 = true;
115             }
116         }
117 
118         let mut visitor = HasLifetime(lt, false);
119         visitor.visit_type(self);
120         visitor.1
121     }
122 }
123 
/// Syntactic covariance prover used by `for_lt`.
///
/// Field `0` is the lifetime being checked. Field `1` accumulates the sub-types whose
/// covariance `prove` could not establish syntactically; `for_lt` emits compiler-checked
/// proofs for them.
struct Prover<'a>(&'a Lifetime, Vec<&'a Type>);
125 
126 impl<'a> Prover<'a> {
127     /// Prove that `ty` is covariant over `'lt`.
128     ///
129     /// This also needs to prove that it'll be wellformed for any instance of `'lt`.
130     /// It can be assumed that `ty` will be wellformed if `'lt` is substituted to `'static`.
131     fn prove(&mut self, ty: &'a Type) {
132         match ty {
133             Type::Paren(ty) => self.prove(&ty.elem),
134             Type::Group(ty) => self.prove(&ty.elem),
135 
136             // No lifetime involved
137             Type::Never(_) => {}
138 
139             // `[T; N]` and `[T]` is covariant over `T`.
140             Type::Array(ty) => self.prove(&ty.elem),
141             Type::Slice(ty) => self.prove(&ty.elem),
142 
143             Type::Tuple(ty) => {
144                 for elem in &ty.elems {
145                     self.prove(elem);
146                 }
147             }
148 
149             // `*const T` is covariant over `T`
150             Type::Ptr(ty) if ty.const_token.is_some() => self.prove(&ty.elem),
151 
152             // `&T` is covariant over `T` and lifetime.
153             //
154             // Note that if we encounter `&'other_lt T`, then we still need to make sure the type
155             // is wellformed if `T` involves `&'lt`, so we defer to the compiler.
156             //
157             // This is to block cases like `ForLt!(for<'a> &'static &'a u32)`, as the presence of
158             // the type implies `'a: 'static` but this is unsound.
159             Type::Reference(ty)
160                 if ty.mutability.is_none() && ty.lifetime.as_ref() == Some(self.0) =>
161             {
162                 self.prove(&ty.elem)
163             }
164 
165             // `&[mut] T` is covariant over lifetime.
166             // In case we have `&[mut] NoLifetime`, we don't need to do additional checks.
167             Type::Reference(ty) if !ty.elem.has_lifetime(self.0) => (),
168 
169             // No mention of lifetime at all, no need to perform compiler check.
170             ty if !ty.has_lifetime(self.0) => (),
171 
172             // Otherwise, we need to emit checks so that compiler can determine if the types are
173             // actually covariant.
174             ty => self.1.push(ty),
175         }
176     }
177 }
178 
179 pub(crate) fn for_lt(input: HigherRankedType) -> TokenStream {
180     let (ty, lifetime) = match input {
181         HigherRankedType::Explicit { lifetime, ty, .. } => (ty, lifetime),
182         HigherRankedType::Implicit { ty } => {
183             // If there's no explicit `for<'a>` binder, inject a synthetic `'__elided` lifetime
184             // and expand elided sites.
185             let lifetime = Lifetime {
186                 apostrophe: Span::mixed_site(),
187                 ident: format_ident!("__elided", span = Span::mixed_site()),
188             };
189             (ty.expand_elided_lifetime(&lifetime), lifetime)
190         }
191     };
192 
193     let mut prover = Prover(&lifetime, Vec::new());
194     prover.prove(&ty);
195 
196     let mut proof = Vec::new();
197 
198     // Emit proofs for every type that requires additional compiler help in proving covariance.
199     for (idx, required_proof) in prover.1.into_iter().enumerate() {
200         // Insert a proof that the type is well-formed.
201         //
202         // This is intended to workaround a Rust compiler soundness bug related to HRTB.
203         // https://github.com/rust-lang/rust/issues/152489
204         //
205         // This needs to be a struct instead of fn to avoid the implied WF bounds.
206         let wf_proof_name = format_ident!("ProveWf{idx}");
207         proof.push(quote!(
208             struct #wf_proof_name<#lifetime>(
209                 ::core::marker::PhantomData<&#lifetime ()>, #required_proof
210             );
211         ));
212 
213         // Insert a proof that the type is covariant.
214         let cov_proof_name = format_ident!("prove_covariant_{idx}");
215         proof.push(quote!(
216             fn #cov_proof_name<'__short, '__long: '__short>(
217                 long: #wf_proof_name<'__long>
218             ) -> #wf_proof_name<'__short> {
219                 long
220             }
221         ));
222     }
223 
224     // Make sure that the type is wellformed when substituting lifetime with `'static`.
225     //
226     // Currently the Rust compiler doesn't check this, see the above `ProveWf` documentation.
227     //
228     // We prefer to use this way of proving WF-ness as it can work when generics are involved.
229     let ty_static = ty.replace_lifetime(
230         &lifetime,
231         &Lifetime {
232             apostrophe: Span::mixed_site(),
233             ident: format_ident!("static"),
234         },
235     );
236 
237     quote!(
238         ::kernel::types::for_lt::UnsafeForLtImpl::<
239             dyn for<#lifetime> ::kernel::types::for_lt::WithLt<#lifetime, Of = #ty>,
240             #ty_static,
241             {
242                 #(#proof)*
243 
244                 0
245             }
246         >
247     )
248 }
249