1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 use proc_macro2::{ 4 Span, 5 TokenStream, // 6 }; 7 use quote::{ 8 format_ident, 9 quote, // 10 }; 11 use syn::{ 12 parse::{ 13 Parse, 14 ParseStream, // 15 }, 16 visit::Visit, 17 visit_mut::VisitMut, 18 Lifetime, 19 Result, 20 Token, 21 Type, // 22 }; 23 24 pub(crate) enum HigherRankedType { 25 Explicit { 26 _for_token: Token![for], 27 _lt_token: Token![<], 28 lifetime: Lifetime, 29 _gt_token: Token![>], 30 ty: Type, 31 }, 32 Implicit { 33 ty: Type, 34 }, 35 } 36 37 impl Parse for HigherRankedType { 38 fn parse(input: ParseStream<'_>) -> Result<Self> { 39 if input.peek(Token![for]) { 40 Ok(Self::Explicit { 41 _for_token: input.parse()?, 42 _lt_token: input.parse()?, 43 lifetime: input.parse()?, 44 _gt_token: input.parse()?, 45 ty: input.parse()?, 46 }) 47 } else { 48 Ok(Self::Implicit { ty: input.parse()? }) 49 } 50 } 51 } 52 53 trait TypeExt { 54 fn expand_elided_lifetime(&self, explicit_lt: &Lifetime) -> Type; 55 fn replace_lifetime(&self, src: &Lifetime, dst: &Lifetime) -> Type; 56 fn has_lifetime(&self, lt: &Lifetime) -> bool; 57 } 58 59 impl TypeExt for Type { 60 fn expand_elided_lifetime(&self, explicit_lt: &Lifetime) -> Type { 61 struct ElidedLifetimeExpander<'a>(&'a Lifetime); 62 63 impl VisitMut for ElidedLifetimeExpander<'_> { 64 fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) { 65 // Expand explicit `'_` 66 if lifetime.ident == "_" { 67 *lifetime = self.0.clone(); 68 } 69 } 70 71 fn visit_type_reference_mut(&mut self, reference: &mut syn::TypeReference) { 72 syn::visit_mut::visit_type_reference_mut(self, reference); 73 74 if reference.lifetime.is_none() { 75 reference.lifetime = Some(self.0.clone()); 76 } 77 } 78 } 79 80 let mut ret = self.clone(); 81 ElidedLifetimeExpander(explicit_lt).visit_type_mut(&mut ret); 82 ret 83 } 84 85 fn replace_lifetime(&self, src: &Lifetime, dst: &Lifetime) -> Type { 86 struct LifetimeReplacer<'a>(&'a Lifetime, &'a Lifetime); 87 88 impl VisitMut for LifetimeReplacer<'_> { 89 fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) { 90 if lifetime.ident == self.0.ident { 91 *lifetime = self.1.clone(); 92 } 93 } 94 } 95 96 let mut ret = self.clone(); 97 LifetimeReplacer(src, dst).visit_type_mut(&mut ret); 98 ret 99 } 100 101 fn has_lifetime(&self, lt: &Lifetime) -> bool { 102 struct HasLifetime<'a>(&'a Lifetime, bool); 103 104 impl Visit<'_> for HasLifetime<'_> { 105 fn visit_lifetime(&mut self, lifetime: &Lifetime) { 106 if lifetime.ident == self.0.ident { 107 self.1 = true; 108 } 109 } 110 111 // Macro invocations are opaque; conservatively assume they may 112 // reference the lifetime. 113 fn visit_macro(&mut self, _: &syn::Macro) { 114 self.1 = true; 115 } 116 } 117 118 let mut visitor = HasLifetime(lt, false); 119 visitor.visit_type(self); 120 visitor.1 121 } 122 } 123 124 struct Prover<'a>(&'a Lifetime, Vec<&'a Type>); 125 126 impl<'a> Prover<'a> { 127 /// Prove that `ty` is covariant over `'lt`. 128 /// 129 /// This also needs to prove that it'll be wellformed for any instance of `'lt`. 130 /// It can be assumed that `ty` will be wellformed if `'lt` is substituted to `'static`. 131 fn prove(&mut self, ty: &'a Type) { 132 match ty { 133 Type::Paren(ty) => self.prove(&ty.elem), 134 Type::Group(ty) => self.prove(&ty.elem), 135 136 // No lifetime involved 137 Type::Never(_) => {} 138 139 // `[T; N]` and `[T]` is covariant over `T`. 140 Type::Array(ty) => self.prove(&ty.elem), 141 Type::Slice(ty) => self.prove(&ty.elem), 142 143 Type::Tuple(ty) => { 144 for elem in &ty.elems { 145 self.prove(elem); 146 } 147 } 148 149 // `*const T` is covariant over `T` 150 Type::Ptr(ty) if ty.const_token.is_some() => self.prove(&ty.elem), 151 152 // `&T` is covariant over `T` and lifetime. 153 // 154 // Note that if we encounter `&'other_lt T`, then we still need to make sure the type 155 // is wellformed if `T` involves `&'lt`, so we defer to the compiler. 156 // 157 // This is to block cases like `ForLt!(for<'a> &'static &'a u32)`, as the presence of 158 // the type implies `'a: 'static` but this is unsound. 159 Type::Reference(ty) 160 if ty.mutability.is_none() && ty.lifetime.as_ref() == Some(self.0) => 161 { 162 self.prove(&ty.elem) 163 } 164 165 // `&[mut] T` is covariant over lifetime. 166 // In case we have `&[mut] NoLifetime`, we don't need to do additional checks. 167 Type::Reference(ty) if !ty.elem.has_lifetime(self.0) => (), 168 169 // No mention of lifetime at all, no need to perform compiler check. 170 ty if !ty.has_lifetime(self.0) => (), 171 172 // Otherwise, we need to emit checks so that compiler can determine if the types are 173 // actually covariant. 174 ty => self.1.push(ty), 175 } 176 } 177 } 178 179 pub(crate) fn for_lt(input: HigherRankedType) -> TokenStream { 180 let (ty, lifetime) = match input { 181 HigherRankedType::Explicit { lifetime, ty, .. } => (ty, lifetime), 182 HigherRankedType::Implicit { ty } => { 183 // If there's no explicit `for<'a>` binder, inject a synthetic `'__elided` lifetime 184 // and expand elided sites. 185 let lifetime = Lifetime { 186 apostrophe: Span::mixed_site(), 187 ident: format_ident!("__elided", span = Span::mixed_site()), 188 }; 189 (ty.expand_elided_lifetime(&lifetime), lifetime) 190 } 191 }; 192 193 let mut prover = Prover(&lifetime, Vec::new()); 194 prover.prove(&ty); 195 196 let mut proof = Vec::new(); 197 198 // Emit proofs for every type that requires additional compiler help in proving covariance. 199 for (idx, required_proof) in prover.1.into_iter().enumerate() { 200 // Insert a proof that the type is well-formed. 201 // 202 // This is intended to workaround a Rust compiler soundness bug related to HRTB. 203 // https://github.com/rust-lang/rust/issues/152489 204 // 205 // This needs to be a struct instead of fn to avoid the implied WF bounds. 206 let wf_proof_name = format_ident!("ProveWf{idx}"); 207 proof.push(quote!( 208 struct #wf_proof_name<#lifetime>( 209 ::core::marker::PhantomData<&#lifetime ()>, #required_proof 210 ); 211 )); 212 213 // Insert a proof that the type is covariant. 214 let cov_proof_name = format_ident!("prove_covariant_{idx}"); 215 proof.push(quote!( 216 fn #cov_proof_name<'__short, '__long: '__short>( 217 long: #wf_proof_name<'__long> 218 ) -> #wf_proof_name<'__short> { 219 long 220 } 221 )); 222 } 223 224 // Make sure that the type is wellformed when substituting lifetime with `'static`. 225 // 226 // Currently the Rust compiler doesn't check this, see the above `ProveWf` documentation. 227 // 228 // We prefer to use this way of proving WF-ness as it can work when generics are involved. 229 let ty_static = ty.replace_lifetime( 230 &lifetime, 231 &Lifetime { 232 apostrophe: Span::mixed_site(), 233 ident: format_ident!("static"), 234 }, 235 ); 236 237 quote!( 238 ::kernel::types::for_lt::UnsafeForLtImpl::< 239 dyn for<#lifetime> ::kernel::types::for_lt::WithLt<#lifetime, Of = #ty>, 240 #ty_static, 241 { 242 #(#proof)* 243 244 0 245 } 246 > 247 ) 248 } 249