1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 use proc_macro2::{Span, TokenStream}; 4 use quote::{format_ident, quote, quote_spanned}; 5 use syn::{ 6 braced, 7 parse::{End, Parse}, 8 parse_quote, 9 punctuated::Punctuated, 10 spanned::Spanned, 11 token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, 12 }; 13 14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; 15 16 pub(crate) struct Initializer { 17 attrs: Vec<InitializerAttribute>, 18 this: Option<This>, 19 path: Path, 20 brace_token: token::Brace, 21 fields: Punctuated<InitializerField, Token![,]>, 22 rest: Option<(Token![..], Expr)>, 23 error: Option<(Token![?], Type)>, 24 } 25 26 struct This { 27 _and_token: Token![&], 28 ident: Ident, 29 _in_token: Token![in], 30 } 31 32 struct InitializerField { 33 attrs: Vec<Attribute>, 34 kind: InitializerKind, 35 } 36 37 enum InitializerKind { 38 Value { 39 ident: Ident, 40 value: Option<(Token![:], Expr)>, 41 }, 42 Init { 43 ident: Ident, 44 _left_arrow_token: Token![<-], 45 value: Expr, 46 }, 47 Code { 48 _underscore_token: Token![_], 49 _colon_token: Token![:], 50 block: Block, 51 }, 52 } 53 54 impl InitializerKind { 55 fn ident(&self) -> Option<&Ident> { 56 match self { 57 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), 58 Self::Code { .. } => None, 59 } 60 } 61 } 62 63 enum InitializerAttribute { 64 DefaultError(DefaultErrorAttribute), 65 } 66 67 struct DefaultErrorAttribute { 68 ty: Box<Type>, 69 } 70 71 pub(crate) fn expand( 72 Initializer { 73 attrs, 74 this, 75 path, 76 brace_token, 77 fields, 78 rest, 79 error, 80 }: Initializer, 81 default_error: Option<&'static str>, 82 pinned: bool, 83 dcx: &mut DiagCtxt, 84 ) -> Result<TokenStream, ErrorGuaranteed> { 85 let error = error.map_or_else( 86 || { 87 if let Some(default_error) = attrs.iter().fold(None, |acc, attr| { 88 #[expect(irrefutable_let_patterns)] 89 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr { 90 Some(ty.clone()) 91 } else { 92 acc 93 } 94 }) { 95 default_error 96 } else if let Some(default_error) = default_error { 97 syn::parse_str(default_error).unwrap() 98 } else { 99 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); 100 parse_quote!(::core::convert::Infallible) 101 } 102 }, 103 |(_, err)| Box::new(err), 104 ); 105 let slot = format_ident!("slot"); 106 let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned { 107 ( 108 format_ident!("HasPinData"), 109 format_ident!("PinData"), 110 format_ident!("__pin_data"), 111 format_ident!("pin_init_from_closure"), 112 ) 113 } else { 114 ( 115 format_ident!("HasInitData"), 116 format_ident!("InitData"), 117 format_ident!("__init_data"), 118 format_ident!("init_from_closure"), 119 ) 120 }; 121 let init_kind = get_init_kind(rest, dcx); 122 let zeroable_check = match init_kind { 123 InitKind::Normal => quote!(), 124 InitKind::Zeroing => quote! { 125 // The user specified `..Zeroable::zeroed()` at the end of the list of fields. 126 // Therefore we check if the struct implements `Zeroable` and then zero the memory. 127 // This allows us to also remove the check that all fields are present (since we 128 // already set the memory to zero and that is a valid bit pattern). 129 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) 130 where T: ::pin_init::Zeroable 131 {} 132 // Ensure that the struct is indeed `Zeroable`. 133 assert_zeroable(#slot); 134 // SAFETY: The type implements `Zeroable` by the check above. 135 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; 136 }, 137 }; 138 let this = match this { 139 None => quote!(), 140 Some(This { ident, .. }) => quote! { 141 // Create the `this` so it can be referenced by the user inside of the 142 // expressions creating the individual fields. 143 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; 144 }, 145 }; 146 // `mixed_site` ensures that the data is not accessible to the user-controlled code. 147 let data = Ident::new("__data", Span::mixed_site()); 148 let init_fields = init_fields(&fields, pinned, &data, &slot); 149 let field_check = make_field_check(&fields, init_kind, &path); 150 Ok(quote! {{ 151 // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return 152 // type and shadow it later when we insert the arbitrary user code. That way there will be 153 // no possibility of returning without `unsafe`. 154 struct __InitOk; 155 156 // Get the data about fields from the supplied type. 157 // SAFETY: TODO 158 let #data = unsafe { 159 use ::pin_init::__internal::#has_data_trait; 160 // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit 161 // generics (which need to be present with that syntax). 162 #path::#get_data() 163 }; 164 // Ensure that `#data` really is of type `#data` and help with type inference: 165 let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>( 166 #data, 167 move |slot| { 168 { 169 // Shadow the structure so it cannot be used to return early. 170 struct __InitOk; 171 #zeroable_check 172 #this 173 #init_fields 174 #field_check 175 } 176 Ok(__InitOk) 177 } 178 ); 179 let init = move |slot| -> ::core::result::Result<(), #error> { 180 init(slot).map(|__InitOk| ()) 181 }; 182 // SAFETY: TODO 183 let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }; 184 init 185 }}) 186 } 187 188 enum InitKind { 189 Normal, 190 Zeroing, 191 } 192 193 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { 194 let Some((dotdot, expr)) = rest else { 195 return InitKind::Normal; 196 }; 197 match &expr { 198 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { 199 Expr::Path(ExprPath { 200 attrs, 201 qself: None, 202 path: 203 Path { 204 leading_colon: None, 205 segments, 206 }, 207 }) if attrs.is_empty() 208 && segments.len() == 2 209 && segments[0].ident == "Zeroable" 210 && segments[0].arguments.is_none() 211 && segments[1].ident == "init_zeroed" 212 && segments[1].arguments.is_none() => 213 { 214 return InitKind::Zeroing; 215 } 216 _ => {} 217 }, 218 _ => {} 219 } 220 dcx.error( 221 dotdot.span().join(expr.span()).unwrap_or(expr.span()), 222 "expected nothing or `..Zeroable::init_zeroed()`.", 223 ); 224 InitKind::Normal 225 } 226 227 /// Generate the code that initializes the fields of the struct using the initializers in `field`. 228 fn init_fields( 229 fields: &Punctuated<InitializerField, Token![,]>, 230 pinned: bool, 231 data: &Ident, 232 slot: &Ident, 233 ) -> TokenStream { 234 let mut guards = vec![]; 235 let mut guard_attrs = vec![]; 236 let mut res = TokenStream::new(); 237 for InitializerField { attrs, kind } in fields { 238 let cfgs = { 239 let mut cfgs = attrs.clone(); 240 cfgs.retain(|attr| attr.path().is_ident("cfg")); 241 cfgs 242 }; 243 let init = match kind { 244 InitializerKind::Value { ident, value } => { 245 let mut value_ident = ident.clone(); 246 let value_prep = value.as_ref().map(|value| &value.1).map(|value| { 247 // Setting the span of `value_ident` to `value`'s span improves error messages 248 // when the type of `value` is wrong. 249 value_ident.set_span(value.span()); 250 quote!(let #value_ident = #value;) 251 }); 252 // Again span for better diagnostics 253 let write = quote_spanned!(ident.span()=> ::core::ptr::write); 254 // NOTE: the field accessor ensures that the initialized field is properly aligned. 255 // Unaligned fields will cause the compiler to emit E0793. We do not support 256 // unaligned fields since `Init::__init` requires an aligned pointer; the call to 257 // `ptr::write` below has the same requirement. 258 let accessor = if pinned { 259 let project_ident = format_ident!("__project_{ident}"); 260 quote! { 261 // SAFETY: TODO 262 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 263 } 264 } else { 265 quote! { 266 // SAFETY: TODO 267 unsafe { &mut (*#slot).#ident } 268 } 269 }; 270 quote! { 271 #(#attrs)* 272 { 273 #value_prep 274 // SAFETY: TODO 275 unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) }; 276 } 277 #(#cfgs)* 278 #[allow(unused_variables)] 279 let #ident = #accessor; 280 } 281 } 282 InitializerKind::Init { ident, value, .. } => { 283 // Again span for better diagnostics 284 let init = format_ident!("init", span = value.span()); 285 // NOTE: the field accessor ensures that the initialized field is properly aligned. 286 // Unaligned fields will cause the compiler to emit E0793. We do not support 287 // unaligned fields since `Init::__init` requires an aligned pointer; the call to 288 // `ptr::write` below has the same requirement. 289 let (value_init, accessor) = if pinned { 290 let project_ident = format_ident!("__project_{ident}"); 291 ( 292 quote! { 293 // SAFETY: 294 // - `slot` is valid, because we are inside of an initializer closure, we 295 // return when an error/panic occurs. 296 // - We also use `#data` to require the correct trait (`Init` or `PinInit`) 297 // for `#ident`. 298 unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? }; 299 }, 300 quote! { 301 // SAFETY: TODO 302 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 303 }, 304 ) 305 } else { 306 ( 307 quote! { 308 // SAFETY: `slot` is valid, because we are inside of an initializer 309 // closure, we return when an error/panic occurs. 310 unsafe { 311 ::pin_init::Init::__init( 312 #init, 313 ::core::ptr::addr_of_mut!((*#slot).#ident), 314 )? 315 }; 316 }, 317 quote! { 318 // SAFETY: TODO 319 unsafe { &mut (*#slot).#ident } 320 }, 321 ) 322 }; 323 quote! { 324 #(#attrs)* 325 { 326 let #init = #value; 327 #value_init 328 } 329 #(#cfgs)* 330 #[allow(unused_variables)] 331 let #ident = #accessor; 332 } 333 } 334 InitializerKind::Code { block: value, .. } => quote! { 335 #(#attrs)* 336 #[allow(unused_braces)] 337 #value 338 }, 339 }; 340 res.extend(init); 341 if let Some(ident) = kind.ident() { 342 // `mixed_site` ensures that the guard is not accessible to the user-controlled code. 343 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); 344 res.extend(quote! { 345 #(#cfgs)* 346 // Create the drop guard: 347 // 348 // We rely on macro hygiene to make it impossible for users to access this local 349 // variable. 350 // SAFETY: We forget the guard later when initialization has succeeded. 351 let #guard = unsafe { 352 ::pin_init::__internal::DropGuard::new( 353 ::core::ptr::addr_of_mut!((*slot).#ident) 354 ) 355 }; 356 }); 357 guards.push(guard); 358 guard_attrs.push(cfgs); 359 } 360 } 361 quote! { 362 #res 363 // If execution reaches this point, all fields have been initialized. Therefore we can now 364 // dismiss the guards by forgetting them. 365 #( 366 #(#guard_attrs)* 367 ::core::mem::forget(#guards); 368 )* 369 } 370 } 371 372 /// Generate the check for ensuring that every field has been initialized. 373 fn make_field_check( 374 fields: &Punctuated<InitializerField, Token![,]>, 375 init_kind: InitKind, 376 path: &Path, 377 ) -> TokenStream { 378 let field_attrs = fields 379 .iter() 380 .filter_map(|f| f.kind.ident().map(|_| &f.attrs)); 381 let field_name = fields.iter().filter_map(|f| f.kind.ident()); 382 match init_kind { 383 InitKind::Normal => quote! { 384 // We use unreachable code to ensure that all fields have been mentioned exactly once, 385 // this struct initializer will still be type-checked and complain with a very natural 386 // error message if a field is forgotten/mentioned more than once. 387 #[allow(unreachable_code, clippy::diverging_sub_expression)] 388 // SAFETY: this code is never executed. 389 let _ = || unsafe { 390 ::core::ptr::write(slot, #path { 391 #( 392 #(#field_attrs)* 393 #field_name: ::core::panic!(), 394 )* 395 }) 396 }; 397 }, 398 InitKind::Zeroing => quote! { 399 // We use unreachable code to ensure that all fields have been mentioned at most once. 400 // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will 401 // be zeroed. This struct initializer will still be type-checked and complain with a 402 // very natural error message if a field is mentioned more than once, or doesn't exist. 403 #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)] 404 // SAFETY: this code is never executed. 405 let _ = || unsafe { 406 ::core::ptr::write(slot, #path { 407 #( 408 #(#field_attrs)* 409 #field_name: ::core::panic!(), 410 )* 411 ..::core::mem::zeroed() 412 }) 413 }; 414 }, 415 } 416 } 417 418 impl Parse for Initializer { 419 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 420 let attrs = input.call(Attribute::parse_outer)?; 421 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; 422 let path = input.parse()?; 423 let content; 424 let brace_token = braced!(content in input); 425 let mut fields = Punctuated::new(); 426 loop { 427 let lh = content.lookahead1(); 428 if lh.peek(End) || lh.peek(Token![..]) { 429 break; 430 } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) { 431 fields.push_value(content.parse()?); 432 let lh = content.lookahead1(); 433 if lh.peek(End) { 434 break; 435 } else if lh.peek(Token![,]) { 436 fields.push_punct(content.parse()?); 437 } else { 438 return Err(lh.error()); 439 } 440 } else { 441 return Err(lh.error()); 442 } 443 } 444 let rest = content 445 .peek(Token![..]) 446 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) 447 .transpose()?; 448 let error = input 449 .peek(Token![?]) 450 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) 451 .transpose()?; 452 let attrs = attrs 453 .into_iter() 454 .map(|a| { 455 if a.path().is_ident("default_error") { 456 a.parse_args::<DefaultErrorAttribute>() 457 .map(InitializerAttribute::DefaultError) 458 } else { 459 Err(syn::Error::new_spanned(a, "unknown initializer attribute")) 460 } 461 }) 462 .collect::<Result<Vec<_>, _>>()?; 463 Ok(Self { 464 attrs, 465 this, 466 path, 467 brace_token, 468 fields, 469 rest, 470 error, 471 }) 472 } 473 } 474 475 impl Parse for DefaultErrorAttribute { 476 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 477 Ok(Self { ty: input.parse()? }) 478 } 479 } 480 481 impl Parse for This { 482 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 483 Ok(Self { 484 _and_token: input.parse()?, 485 ident: input.parse()?, 486 _in_token: input.parse()?, 487 }) 488 } 489 } 490 491 impl Parse for InitializerField { 492 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 493 let attrs = input.call(Attribute::parse_outer)?; 494 Ok(Self { 495 attrs, 496 kind: input.parse()?, 497 }) 498 } 499 } 500 501 impl Parse for InitializerKind { 502 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 503 let lh = input.lookahead1(); 504 if lh.peek(Token![_]) { 505 Ok(Self::Code { 506 _underscore_token: input.parse()?, 507 _colon_token: input.parse()?, 508 block: input.parse()?, 509 }) 510 } else if lh.peek(Ident) { 511 let ident = input.parse()?; 512 let lh = input.lookahead1(); 513 if lh.peek(Token![<-]) { 514 Ok(Self::Init { 515 ident, 516 _left_arrow_token: input.parse()?, 517 value: input.parse()?, 518 }) 519 } else if lh.peek(Token![:]) { 520 Ok(Self::Value { 521 ident, 522 value: Some((input.parse()?, input.parse()?)), 523 }) 524 } else if lh.peek(Token![,]) || lh.peek(End) { 525 Ok(Self::Value { ident, value: None }) 526 } else { 527 Err(lh.error()) 528 } 529 } else { 530 Err(lh.error()) 531 } 532 } 533 } 534