1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 use proc_macro2::{Span, TokenStream}; 4 use quote::{format_ident, quote, quote_spanned}; 5 use syn::{ 6 braced, 7 parse::{End, Parse}, 8 parse_quote, 9 punctuated::Punctuated, 10 spanned::Spanned, 11 token, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, 12 }; 13 14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; 15 16 pub(crate) struct Initializer { 17 this: Option<This>, 18 path: Path, 19 brace_token: token::Brace, 20 fields: Punctuated<InitializerField, Token![,]>, 21 rest: Option<(Token![..], Expr)>, 22 error: Option<(Token![?], Type)>, 23 } 24 25 struct This { 26 _and_token: Token![&], 27 ident: Ident, 28 _in_token: Token![in], 29 } 30 31 enum InitializerField { 32 Value { 33 ident: Ident, 34 value: Option<(Token![:], Expr)>, 35 }, 36 Init { 37 ident: Ident, 38 _left_arrow_token: Token![<-], 39 value: Expr, 40 }, 41 Code { 42 _underscore_token: Token![_], 43 _colon_token: Token![:], 44 block: Block, 45 }, 46 } 47 48 impl InitializerField { 49 fn ident(&self) -> Option<&Ident> { 50 match self { 51 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), 52 Self::Code { .. } => None, 53 } 54 } 55 } 56 57 pub(crate) fn expand( 58 Initializer { 59 this, 60 path, 61 brace_token, 62 fields, 63 rest, 64 error, 65 }: Initializer, 66 default_error: Option<&'static str>, 67 pinned: bool, 68 dcx: &mut DiagCtxt, 69 ) -> Result<TokenStream, ErrorGuaranteed> { 70 let error = error.map_or_else( 71 || { 72 if let Some(default_error) = default_error { 73 syn::parse_str(default_error).unwrap() 74 } else { 75 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); 76 parse_quote!(::core::convert::Infallible) 77 } 78 }, 79 |(_, err)| err, 80 ); 81 let slot = format_ident!("slot"); 82 let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned { 83 ( 84 format_ident!("HasPinData"), 85 format_ident!("PinData"), 86 format_ident!("__pin_data"), 87 format_ident!("pin_init_from_closure"), 88 ) 89 } else { 90 ( 91 format_ident!("HasInitData"), 92 format_ident!("InitData"), 93 format_ident!("__init_data"), 94 format_ident!("init_from_closure"), 95 ) 96 }; 97 let init_kind = get_init_kind(rest, dcx); 98 let zeroable_check = match init_kind { 99 InitKind::Normal => quote!(), 100 InitKind::Zeroing => quote! { 101 // The user specified `..Zeroable::zeroed()` at the end of the list of fields. 102 // Therefore we check if the struct implements `Zeroable` and then zero the memory. 103 // This allows us to also remove the check that all fields are present (since we 104 // already set the memory to zero and that is a valid bit pattern). 105 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) 106 where T: ::pin_init::Zeroable 107 {} 108 // Ensure that the struct is indeed `Zeroable`. 109 assert_zeroable(#slot); 110 // SAFETY: The type implements `Zeroable` by the check above. 111 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; 112 }, 113 }; 114 let this = match this { 115 None => quote!(), 116 Some(This { ident, .. }) => quote! { 117 // Create the `this` so it can be referenced by the user inside of the 118 // expressions creating the individual fields. 119 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; 120 }, 121 }; 122 // `mixed_site` ensures that the data is not accessible to the user-controlled code. 123 let data = Ident::new("__data", Span::mixed_site()); 124 let init_fields = init_fields(&fields, pinned, &data, &slot); 125 let field_check = make_field_check(&fields, init_kind, &path); 126 Ok(quote! {{ 127 // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return 128 // type and shadow it later when we insert the arbitrary user code. That way there will be 129 // no possibility of returning without `unsafe`. 130 struct __InitOk; 131 132 // Get the data about fields from the supplied type. 133 // SAFETY: TODO 134 let #data = unsafe { 135 use ::pin_init::__internal::#has_data_trait; 136 // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit 137 // generics (which need to be present with that syntax). 138 #path::#get_data() 139 }; 140 // Ensure that `#data` really is of type `#data` and help with type inference: 141 let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>( 142 #data, 143 move |slot| { 144 { 145 // Shadow the structure so it cannot be used to return early. 146 struct __InitOk; 147 #zeroable_check 148 #this 149 #init_fields 150 #field_check 151 } 152 Ok(__InitOk) 153 } 154 ); 155 let init = move |slot| -> ::core::result::Result<(), #error> { 156 init(slot).map(|__InitOk| ()) 157 }; 158 // SAFETY: TODO 159 let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }; 160 init 161 }}) 162 } 163 164 enum InitKind { 165 Normal, 166 Zeroing, 167 } 168 169 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { 170 let Some((dotdot, expr)) = rest else { 171 return InitKind::Normal; 172 }; 173 match &expr { 174 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { 175 Expr::Path(ExprPath { 176 attrs, 177 qself: None, 178 path: 179 Path { 180 leading_colon: None, 181 segments, 182 }, 183 }) if attrs.is_empty() 184 && segments.len() == 2 185 && segments[0].ident == "Zeroable" 186 && segments[0].arguments.is_none() 187 && segments[1].ident == "init_zeroed" 188 && segments[1].arguments.is_none() => 189 { 190 return InitKind::Zeroing; 191 } 192 _ => {} 193 }, 194 _ => {} 195 } 196 dcx.error( 197 dotdot.span().join(expr.span()).unwrap_or(expr.span()), 198 "expected nothing or `..Zeroable::init_zeroed()`.", 199 ); 200 InitKind::Normal 201 } 202 203 /// Generate the code that initializes the fields of the struct using the initializers in `field`. 204 fn init_fields( 205 fields: &Punctuated<InitializerField, Token![,]>, 206 pinned: bool, 207 data: &Ident, 208 slot: &Ident, 209 ) -> TokenStream { 210 let mut guards = vec![]; 211 let mut res = TokenStream::new(); 212 for field in fields { 213 let init = match field { 214 InitializerField::Value { ident, value } => { 215 let mut value_ident = ident.clone(); 216 let value_prep = value.as_ref().map(|value| &value.1).map(|value| { 217 // Setting the span of `value_ident` to `value`'s span improves error messages 218 // when the type of `value` is wrong. 219 value_ident.set_span(value.span()); 220 quote!(let #value_ident = #value;) 221 }); 222 // Again span for better diagnostics 223 let write = quote_spanned!(ident.span()=> ::core::ptr::write); 224 let accessor = if pinned { 225 let project_ident = format_ident!("__project_{ident}"); 226 quote! { 227 // SAFETY: TODO 228 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 229 } 230 } else { 231 quote! { 232 // SAFETY: TODO 233 unsafe { &mut (*#slot).#ident } 234 } 235 }; 236 quote! { 237 { 238 #value_prep 239 // SAFETY: TODO 240 unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) }; 241 } 242 #[allow(unused_variables)] 243 let #ident = #accessor; 244 } 245 } 246 InitializerField::Init { ident, value, .. } => { 247 // Again span for better diagnostics 248 let init = format_ident!("init", span = value.span()); 249 if pinned { 250 let project_ident = format_ident!("__project_{ident}"); 251 quote! { 252 { 253 let #init = #value; 254 // SAFETY: 255 // - `slot` is valid, because we are inside of an initializer closure, we 256 // return when an error/panic occurs. 257 // - We also use `#data` to require the correct trait (`Init` or `PinInit`) 258 // for `#ident`. 259 unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? }; 260 } 261 // SAFETY: TODO 262 #[allow(unused_variables)] 263 let #ident = unsafe { #data.#project_ident(&mut (*#slot).#ident) }; 264 } 265 } else { 266 quote! { 267 { 268 let #init = #value; 269 // SAFETY: `slot` is valid, because we are inside of an initializer 270 // closure, we return when an error/panic occurs. 271 unsafe { 272 ::pin_init::Init::__init( 273 #init, 274 ::core::ptr::addr_of_mut!((*#slot).#ident), 275 )? 276 }; 277 } 278 // SAFETY: TODO 279 #[allow(unused_variables)] 280 let #ident = unsafe { &mut (*#slot).#ident }; 281 } 282 } 283 } 284 InitializerField::Code { block: value, .. } => quote!(#[allow(unused_braces)] #value), 285 }; 286 res.extend(init); 287 if let Some(ident) = field.ident() { 288 // `mixed_site` ensures that the guard is not accessible to the user-controlled code. 289 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); 290 guards.push(guard.clone()); 291 res.extend(quote! { 292 // Create the drop guard: 293 // 294 // We rely on macro hygiene to make it impossible for users to access this local 295 // variable. 296 // SAFETY: We forget the guard later when initialization has succeeded. 297 let #guard = unsafe { 298 ::pin_init::__internal::DropGuard::new( 299 ::core::ptr::addr_of_mut!((*slot).#ident) 300 ) 301 }; 302 }); 303 } 304 } 305 quote! { 306 #res 307 // If execution reaches this point, all fields have been initialized. Therefore we can now 308 // dismiss the guards by forgetting them. 309 #(::core::mem::forget(#guards);)* 310 } 311 } 312 313 /// Generate the check for ensuring that every field has been initialized. 314 fn make_field_check( 315 fields: &Punctuated<InitializerField, Token![,]>, 316 init_kind: InitKind, 317 path: &Path, 318 ) -> TokenStream { 319 let fields = fields.iter().filter_map(|f| f.ident()); 320 match init_kind { 321 InitKind::Normal => quote! { 322 // We use unreachable code to ensure that all fields have been mentioned exactly once, 323 // this struct initializer will still be type-checked and complain with a very natural 324 // error message if a field is forgotten/mentioned more than once. 325 #[allow(unreachable_code, clippy::diverging_sub_expression)] 326 // SAFETY: this code is never executed. 327 let _ = || unsafe { 328 ::core::ptr::write(slot, #path { 329 #( 330 #fields: ::core::panic!(), 331 )* 332 }) 333 }; 334 }, 335 InitKind::Zeroing => quote! { 336 // We use unreachable code to ensure that all fields have been mentioned at most once. 337 // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will 338 // be zeroed. This struct initializer will still be type-checked and complain with a 339 // very natural error message if a field is mentioned more than once, or doesn't exist. 340 #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)] 341 // SAFETY: this code is never executed. 342 let _ = || unsafe { 343 let mut zeroed = ::core::mem::zeroed(); 344 // We have to use type inference here to make zeroed have the correct type. This 345 // does not get executed, so it has no effect. 346 ::core::ptr::write(slot, zeroed); 347 zeroed = ::core::mem::zeroed(); 348 ::core::ptr::write(slot, #path { 349 #( 350 #fields: ::core::panic!(), 351 )* 352 ..zeroed 353 }) 354 }; 355 }, 356 } 357 } 358 359 impl Parse for Initializer { 360 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 361 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; 362 let path = input.parse()?; 363 let content; 364 let brace_token = braced!(content in input); 365 let mut fields = Punctuated::new(); 366 loop { 367 let lh = content.lookahead1(); 368 if lh.peek(End) || lh.peek(Token![..]) { 369 break; 370 } else if lh.peek(Ident) || lh.peek(Token![_]) { 371 fields.push_value(content.parse()?); 372 let lh = content.lookahead1(); 373 if lh.peek(End) { 374 break; 375 } else if lh.peek(Token![,]) { 376 fields.push_punct(content.parse()?); 377 } else { 378 return Err(lh.error()); 379 } 380 } else { 381 return Err(lh.error()); 382 } 383 } 384 let rest = content 385 .peek(Token![..]) 386 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) 387 .transpose()?; 388 let error = input 389 .peek(Token![?]) 390 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) 391 .transpose()?; 392 Ok(Self { 393 this, 394 path, 395 brace_token, 396 fields, 397 rest, 398 error, 399 }) 400 } 401 } 402 403 impl Parse for This { 404 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 405 Ok(Self { 406 _and_token: input.parse()?, 407 ident: input.parse()?, 408 _in_token: input.parse()?, 409 }) 410 } 411 } 412 413 impl Parse for InitializerField { 414 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 415 let lh = input.lookahead1(); 416 if lh.peek(Token![_]) { 417 Ok(Self::Code { 418 _underscore_token: input.parse()?, 419 _colon_token: input.parse()?, 420 block: input.parse()?, 421 }) 422 } else if lh.peek(Ident) { 423 let ident = input.parse()?; 424 let lh = input.lookahead1(); 425 if lh.peek(Token![<-]) { 426 Ok(Self::Init { 427 ident, 428 _left_arrow_token: input.parse()?, 429 value: input.parse()?, 430 }) 431 } else if lh.peek(Token![:]) { 432 Ok(Self::Value { 433 ident, 434 value: Some((input.parse()?, input.parse()?)), 435 }) 436 } else if lh.peek(Token![,]) || lh.peek(End) { 437 Ok(Self::Value { ident, value: None }) 438 } else { 439 Err(lh.error()) 440 } 441 } else { 442 Err(lh.error()) 443 } 444 } 445 } 446