1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 use proc_macro2::{Span, TokenStream}; 4 use quote::{format_ident, quote, quote_spanned}; 5 use syn::{ 6 braced, 7 parse::{End, Parse}, 8 parse_quote, 9 punctuated::Punctuated, 10 spanned::Spanned, 11 token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, 12 }; 13 14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; 15 16 pub(crate) struct Initializer { 17 attrs: Vec<InitializerAttribute>, 18 this: Option<This>, 19 path: Path, 20 brace_token: token::Brace, 21 fields: Punctuated<InitializerField, Token![,]>, 22 rest: Option<(Token![..], Expr)>, 23 error: Option<(Token![?], Type)>, 24 } 25 26 struct This { 27 _and_token: Token![&], 28 ident: Ident, 29 _in_token: Token![in], 30 } 31 32 struct InitializerField { 33 attrs: Vec<Attribute>, 34 kind: InitializerKind, 35 } 36 37 enum InitializerKind { 38 Value { 39 ident: Ident, 40 value: Option<(Token![:], Expr)>, 41 }, 42 Init { 43 ident: Ident, 44 _left_arrow_token: Token![<-], 45 value: Expr, 46 }, 47 Code { 48 _underscore_token: Token![_], 49 _colon_token: Token![:], 50 block: Block, 51 }, 52 } 53 54 impl InitializerKind { 55 fn ident(&self) -> Option<&Ident> { 56 match self { 57 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), 58 Self::Code { .. } => None, 59 } 60 } 61 } 62 63 enum InitializerAttribute { 64 DefaultError(DefaultErrorAttribute), 65 DisableInitializedFieldAccess, 66 } 67 68 struct DefaultErrorAttribute { 69 ty: Box<Type>, 70 } 71 72 pub(crate) fn expand( 73 Initializer { 74 attrs, 75 this, 76 path, 77 brace_token, 78 fields, 79 rest, 80 error, 81 }: Initializer, 82 default_error: Option<&'static str>, 83 pinned: bool, 84 dcx: &mut DiagCtxt, 85 ) -> Result<TokenStream, ErrorGuaranteed> { 86 let error = error.map_or_else( 87 || { 88 if let Some(default_error) = attrs.iter().fold(None, |acc, attr| { 89 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr { 90 Some(ty.clone()) 91 } else { 92 acc 93 } 94 }) { 95 default_error 96 } else if let Some(default_error) = default_error { 97 syn::parse_str(default_error).unwrap() 98 } else { 99 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); 100 parse_quote!(::core::convert::Infallible) 101 } 102 }, 103 |(_, err)| Box::new(err), 104 ); 105 let slot = format_ident!("slot"); 106 let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned { 107 ( 108 format_ident!("HasPinData"), 109 format_ident!("PinData"), 110 format_ident!("__pin_data"), 111 format_ident!("pin_init_from_closure"), 112 ) 113 } else { 114 ( 115 format_ident!("HasInitData"), 116 format_ident!("InitData"), 117 format_ident!("__init_data"), 118 format_ident!("init_from_closure"), 119 ) 120 }; 121 let init_kind = get_init_kind(rest, dcx); 122 let zeroable_check = match init_kind { 123 InitKind::Normal => quote!(), 124 InitKind::Zeroing => quote! { 125 // The user specified `..Zeroable::zeroed()` at the end of the list of fields. 126 // Therefore we check if the struct implements `Zeroable` and then zero the memory. 127 // This allows us to also remove the check that all fields are present (since we 128 // already set the memory to zero and that is a valid bit pattern). 129 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) 130 where T: ::pin_init::Zeroable 131 {} 132 // Ensure that the struct is indeed `Zeroable`. 133 assert_zeroable(#slot); 134 // SAFETY: The type implements `Zeroable` by the check above. 135 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; 136 }, 137 }; 138 let this = match this { 139 None => quote!(), 140 Some(This { ident, .. }) => quote! { 141 // Create the `this` so it can be referenced by the user inside of the 142 // expressions creating the individual fields. 143 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; 144 }, 145 }; 146 // `mixed_site` ensures that the data is not accessible to the user-controlled code. 147 let data = Ident::new("__data", Span::mixed_site()); 148 let init_fields = init_fields( 149 &fields, 150 pinned, 151 !attrs 152 .iter() 153 .any(|attr| matches!(attr, InitializerAttribute::DisableInitializedFieldAccess)), 154 &data, 155 &slot, 156 ); 157 let field_check = make_field_check(&fields, init_kind, &path); 158 Ok(quote! {{ 159 // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return 160 // type and shadow it later when we insert the arbitrary user code. That way there will be 161 // no possibility of returning without `unsafe`. 162 struct __InitOk; 163 164 // Get the data about fields from the supplied type. 165 // SAFETY: TODO 166 let #data = unsafe { 167 use ::pin_init::__internal::#has_data_trait; 168 // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit 169 // generics (which need to be present with that syntax). 170 #path::#get_data() 171 }; 172 // Ensure that `#data` really is of type `#data` and help with type inference: 173 let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>( 174 #data, 175 move |slot| { 176 { 177 // Shadow the structure so it cannot be used to return early. 178 struct __InitOk; 179 #zeroable_check 180 #this 181 #init_fields 182 #field_check 183 } 184 Ok(__InitOk) 185 } 186 ); 187 let init = move |slot| -> ::core::result::Result<(), #error> { 188 init(slot).map(|__InitOk| ()) 189 }; 190 // SAFETY: TODO 191 let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }; 192 init 193 }}) 194 } 195 196 enum InitKind { 197 Normal, 198 Zeroing, 199 } 200 201 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { 202 let Some((dotdot, expr)) = rest else { 203 return InitKind::Normal; 204 }; 205 match &expr { 206 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { 207 Expr::Path(ExprPath { 208 attrs, 209 qself: None, 210 path: 211 Path { 212 leading_colon: None, 213 segments, 214 }, 215 }) if attrs.is_empty() 216 && segments.len() == 2 217 && segments[0].ident == "Zeroable" 218 && segments[0].arguments.is_none() 219 && segments[1].ident == "init_zeroed" 220 && segments[1].arguments.is_none() => 221 { 222 return InitKind::Zeroing; 223 } 224 _ => {} 225 }, 226 _ => {} 227 } 228 dcx.error( 229 dotdot.span().join(expr.span()).unwrap_or(expr.span()), 230 "expected nothing or `..Zeroable::init_zeroed()`.", 231 ); 232 InitKind::Normal 233 } 234 235 /// Generate the code that initializes the fields of the struct using the initializers in `field`. 236 fn init_fields( 237 fields: &Punctuated<InitializerField, Token![,]>, 238 pinned: bool, 239 generate_initialized_accessors: bool, 240 data: &Ident, 241 slot: &Ident, 242 ) -> TokenStream { 243 let mut guards = vec![]; 244 let mut guard_attrs = vec![]; 245 let mut res = TokenStream::new(); 246 for InitializerField { attrs, kind } in fields { 247 let cfgs = { 248 let mut cfgs = attrs.clone(); 249 cfgs.retain(|attr| attr.path().is_ident("cfg")); 250 cfgs 251 }; 252 let init = match kind { 253 InitializerKind::Value { ident, value } => { 254 let mut value_ident = ident.clone(); 255 let value_prep = value.as_ref().map(|value| &value.1).map(|value| { 256 // Setting the span of `value_ident` to `value`'s span improves error messages 257 // when the type of `value` is wrong. 258 value_ident.set_span(value.span()); 259 quote!(let #value_ident = #value;) 260 }); 261 // Again span for better diagnostics 262 let write = quote_spanned!(ident.span()=> ::core::ptr::write); 263 let accessor = if pinned { 264 let project_ident = format_ident!("__project_{ident}"); 265 quote! { 266 // SAFETY: TODO 267 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 268 } 269 } else { 270 quote! { 271 // SAFETY: TODO 272 unsafe { &mut (*#slot).#ident } 273 } 274 }; 275 let accessor = generate_initialized_accessors.then(|| { 276 quote! { 277 #(#cfgs)* 278 #[allow(unused_variables)] 279 let #ident = #accessor; 280 } 281 }); 282 quote! { 283 #(#attrs)* 284 { 285 #value_prep 286 // SAFETY: TODO 287 unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) }; 288 } 289 #accessor 290 } 291 } 292 InitializerKind::Init { ident, value, .. } => { 293 // Again span for better diagnostics 294 let init = format_ident!("init", span = value.span()); 295 let (value_init, accessor) = if pinned { 296 let project_ident = format_ident!("__project_{ident}"); 297 ( 298 quote! { 299 // SAFETY: 300 // - `slot` is valid, because we are inside of an initializer closure, we 301 // return when an error/panic occurs. 302 // - We also use `#data` to require the correct trait (`Init` or `PinInit`) 303 // for `#ident`. 304 unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? }; 305 }, 306 quote! { 307 // SAFETY: TODO 308 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 309 }, 310 ) 311 } else { 312 ( 313 quote! { 314 // SAFETY: `slot` is valid, because we are inside of an initializer 315 // closure, we return when an error/panic occurs. 316 unsafe { 317 ::pin_init::Init::__init( 318 #init, 319 ::core::ptr::addr_of_mut!((*#slot).#ident), 320 )? 321 }; 322 }, 323 quote! { 324 // SAFETY: TODO 325 unsafe { &mut (*#slot).#ident } 326 }, 327 ) 328 }; 329 let accessor = generate_initialized_accessors.then(|| { 330 quote! { 331 #(#cfgs)* 332 #[allow(unused_variables)] 333 let #ident = #accessor; 334 } 335 }); 336 quote! { 337 #(#attrs)* 338 { 339 let #init = #value; 340 #value_init 341 } 342 #accessor 343 } 344 } 345 InitializerKind::Code { block: value, .. } => quote! { 346 #(#attrs)* 347 #[allow(unused_braces)] 348 #value 349 }, 350 }; 351 res.extend(init); 352 if let Some(ident) = kind.ident() { 353 // `mixed_site` ensures that the guard is not accessible to the user-controlled code. 354 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); 355 res.extend(quote! { 356 #(#cfgs)* 357 // Create the drop guard: 358 // 359 // We rely on macro hygiene to make it impossible for users to access this local 360 // variable. 361 // SAFETY: We forget the guard later when initialization has succeeded. 362 let #guard = unsafe { 363 ::pin_init::__internal::DropGuard::new( 364 ::core::ptr::addr_of_mut!((*slot).#ident) 365 ) 366 }; 367 }); 368 guards.push(guard); 369 guard_attrs.push(cfgs); 370 } 371 } 372 quote! { 373 #res 374 // If execution reaches this point, all fields have been initialized. Therefore we can now 375 // dismiss the guards by forgetting them. 376 #( 377 #(#guard_attrs)* 378 ::core::mem::forget(#guards); 379 )* 380 } 381 } 382 383 /// Generate the check for ensuring that every field has been initialized. 384 fn make_field_check( 385 fields: &Punctuated<InitializerField, Token![,]>, 386 init_kind: InitKind, 387 path: &Path, 388 ) -> TokenStream { 389 let field_attrs = fields 390 .iter() 391 .filter_map(|f| f.kind.ident().map(|_| &f.attrs)); 392 let field_name = fields.iter().filter_map(|f| f.kind.ident()); 393 match init_kind { 394 InitKind::Normal => quote! { 395 // We use unreachable code to ensure that all fields have been mentioned exactly once, 396 // this struct initializer will still be type-checked and complain with a very natural 397 // error message if a field is forgotten/mentioned more than once. 398 #[allow(unreachable_code, clippy::diverging_sub_expression)] 399 // SAFETY: this code is never executed. 400 let _ = || unsafe { 401 ::core::ptr::write(slot, #path { 402 #( 403 #(#field_attrs)* 404 #field_name: ::core::panic!(), 405 )* 406 }) 407 }; 408 }, 409 InitKind::Zeroing => quote! { 410 // We use unreachable code to ensure that all fields have been mentioned at most once. 411 // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will 412 // be zeroed. This struct initializer will still be type-checked and complain with a 413 // very natural error message if a field is mentioned more than once, or doesn't exist. 414 #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)] 415 // SAFETY: this code is never executed. 416 let _ = || unsafe { 417 let mut zeroed = ::core::mem::zeroed(); 418 // We have to use type inference here to make zeroed have the correct type. This 419 // does not get executed, so it has no effect. 420 ::core::ptr::write(slot, zeroed); 421 zeroed = ::core::mem::zeroed(); 422 ::core::ptr::write(slot, #path { 423 #( 424 #(#field_attrs)* 425 #field_name: ::core::panic!(), 426 )* 427 ..zeroed 428 }) 429 }; 430 }, 431 } 432 } 433 434 impl Parse for Initializer { 435 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 436 let attrs = input.call(Attribute::parse_outer)?; 437 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; 438 let path = input.parse()?; 439 let content; 440 let brace_token = braced!(content in input); 441 let mut fields = Punctuated::new(); 442 loop { 443 let lh = content.lookahead1(); 444 if lh.peek(End) || lh.peek(Token![..]) { 445 break; 446 } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) { 447 fields.push_value(content.parse()?); 448 let lh = content.lookahead1(); 449 if lh.peek(End) { 450 break; 451 } else if lh.peek(Token![,]) { 452 fields.push_punct(content.parse()?); 453 } else { 454 return Err(lh.error()); 455 } 456 } else { 457 return Err(lh.error()); 458 } 459 } 460 let rest = content 461 .peek(Token![..]) 462 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) 463 .transpose()?; 464 let error = input 465 .peek(Token![?]) 466 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) 467 .transpose()?; 468 let attrs = attrs 469 .into_iter() 470 .map(|a| { 471 if a.path().is_ident("default_error") { 472 a.parse_args::<DefaultErrorAttribute>() 473 .map(InitializerAttribute::DefaultError) 474 } else if a.path().is_ident("disable_initialized_field_access") { 475 a.meta 476 .require_path_only() 477 .map(|_| InitializerAttribute::DisableInitializedFieldAccess) 478 } else { 479 Err(syn::Error::new_spanned(a, "unknown initializer attribute")) 480 } 481 }) 482 .collect::<Result<Vec<_>, _>>()?; 483 Ok(Self { 484 attrs, 485 this, 486 path, 487 brace_token, 488 fields, 489 rest, 490 error, 491 }) 492 } 493 } 494 495 impl Parse for DefaultErrorAttribute { 496 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 497 Ok(Self { ty: input.parse()? }) 498 } 499 } 500 501 impl Parse for This { 502 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 503 Ok(Self { 504 _and_token: input.parse()?, 505 ident: input.parse()?, 506 _in_token: input.parse()?, 507 }) 508 } 509 } 510 511 impl Parse for InitializerField { 512 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 513 let attrs = input.call(Attribute::parse_outer)?; 514 Ok(Self { 515 attrs, 516 kind: input.parse()?, 517 }) 518 } 519 } 520 521 impl Parse for InitializerKind { 522 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 523 let lh = input.lookahead1(); 524 if lh.peek(Token![_]) { 525 Ok(Self::Code { 526 _underscore_token: input.parse()?, 527 _colon_token: input.parse()?, 528 block: input.parse()?, 529 }) 530 } else if lh.peek(Ident) { 531 let ident = input.parse()?; 532 let lh = input.lookahead1(); 533 if lh.peek(Token![<-]) { 534 Ok(Self::Init { 535 ident, 536 _left_arrow_token: input.parse()?, 537 value: input.parse()?, 538 }) 539 } else if lh.peek(Token![:]) { 540 Ok(Self::Value { 541 ident, 542 value: Some((input.parse()?, input.parse()?)), 543 }) 544 } else if lh.peek(Token![,]) || lh.peek(End) { 545 Ok(Self::Value { ident, value: None }) 546 } else { 547 Err(lh.error()) 548 } 549 } else { 550 Err(lh.error()) 551 } 552 } 553 } 554