1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 use proc_macro2::{Span, TokenStream}; 4 use quote::{format_ident, quote, quote_spanned}; 5 use syn::{ 6 braced, 7 parse::{End, Parse}, 8 parse_quote, 9 punctuated::Punctuated, 10 spanned::Spanned, 11 token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, 12 }; 13 14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; 15 16 pub(crate) struct Initializer { 17 attrs: Vec<InitializerAttribute>, 18 this: Option<This>, 19 path: Path, 20 brace_token: token::Brace, 21 fields: Punctuated<InitializerField, Token![,]>, 22 rest: Option<(Token![..], Expr)>, 23 error: Option<(Token![?], Type)>, 24 } 25 26 struct This { 27 _and_token: Token![&], 28 ident: Ident, 29 _in_token: Token![in], 30 } 31 32 struct InitializerField { 33 attrs: Vec<Attribute>, 34 kind: InitializerKind, 35 } 36 37 enum InitializerKind { 38 Value { 39 ident: Ident, 40 value: Option<(Token![:], Expr)>, 41 }, 42 Init { 43 ident: Ident, 44 _left_arrow_token: Token![<-], 45 value: Expr, 46 }, 47 Code { 48 _underscore_token: Token![_], 49 _colon_token: Token![:], 50 block: Block, 51 }, 52 } 53 54 impl InitializerKind { 55 fn ident(&self) -> Option<&Ident> { 56 match self { 57 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), 58 Self::Code { .. } => None, 59 } 60 } 61 } 62 63 enum InitializerAttribute { 64 DefaultError(DefaultErrorAttribute), 65 DisableInitializedFieldAccess, 66 } 67 68 struct DefaultErrorAttribute { 69 ty: Box<Type>, 70 } 71 72 pub(crate) fn expand( 73 Initializer { 74 attrs, 75 this, 76 path, 77 brace_token, 78 fields, 79 rest, 80 error, 81 }: Initializer, 82 default_error: Option<&'static str>, 83 pinned: bool, 84 dcx: &mut DiagCtxt, 85 ) -> Result<TokenStream, ErrorGuaranteed> { 86 let error = error.map_or_else( 87 || { 88 if let Some(default_error) = attrs.iter().fold(None, |acc, attr| { 89 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr { 90 Some(ty.clone()) 91 } else { 92 acc 93 } 94 }) { 95 default_error 96 } else if let Some(default_error) = default_error { 97 syn::parse_str(default_error).unwrap() 98 } else { 99 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); 100 parse_quote!(::core::convert::Infallible) 101 } 102 }, 103 |(_, err)| Box::new(err), 104 ); 105 let slot = format_ident!("slot"); 106 let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned { 107 ( 108 format_ident!("HasPinData"), 109 format_ident!("PinData"), 110 format_ident!("__pin_data"), 111 format_ident!("pin_init_from_closure"), 112 ) 113 } else { 114 ( 115 format_ident!("HasInitData"), 116 format_ident!("InitData"), 117 format_ident!("__init_data"), 118 format_ident!("init_from_closure"), 119 ) 120 }; 121 let init_kind = get_init_kind(rest, dcx); 122 let zeroable_check = match init_kind { 123 InitKind::Normal => quote!(), 124 InitKind::Zeroing => quote! { 125 // The user specified `..Zeroable::zeroed()` at the end of the list of fields. 126 // Therefore we check if the struct implements `Zeroable` and then zero the memory. 127 // This allows us to also remove the check that all fields are present (since we 128 // already set the memory to zero and that is a valid bit pattern). 129 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) 130 where T: ::pin_init::Zeroable 131 {} 132 // Ensure that the struct is indeed `Zeroable`. 133 assert_zeroable(#slot); 134 // SAFETY: The type implements `Zeroable` by the check above. 135 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; 136 }, 137 }; 138 let this = match this { 139 None => quote!(), 140 Some(This { ident, .. }) => quote! { 141 // Create the `this` so it can be referenced by the user inside of the 142 // expressions creating the individual fields. 143 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; 144 }, 145 }; 146 // `mixed_site` ensures that the data is not accessible to the user-controlled code. 147 let data = Ident::new("__data", Span::mixed_site()); 148 let init_fields = init_fields( 149 &fields, 150 pinned, 151 !attrs 152 .iter() 153 .any(|attr| matches!(attr, InitializerAttribute::DisableInitializedFieldAccess)), 154 &data, 155 &slot, 156 ); 157 let field_check = make_field_check(&fields, init_kind, &path); 158 Ok(quote! {{ 159 // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return 160 // type and shadow it later when we insert the arbitrary user code. That way there will be 161 // no possibility of returning without `unsafe`. 162 struct __InitOk; 163 164 // Get the data about fields from the supplied type. 165 // SAFETY: TODO 166 let #data = unsafe { 167 use ::pin_init::__internal::#has_data_trait; 168 // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit 169 // generics (which need to be present with that syntax). 170 #path::#get_data() 171 }; 172 // Ensure that `#data` really is of type `#data` and help with type inference: 173 let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>( 174 #data, 175 move |slot| { 176 { 177 // Shadow the structure so it cannot be used to return early. 178 struct __InitOk; 179 #zeroable_check 180 #this 181 #init_fields 182 #field_check 183 } 184 Ok(__InitOk) 185 } 186 ); 187 let init = move |slot| -> ::core::result::Result<(), #error> { 188 init(slot).map(|__InitOk| ()) 189 }; 190 // SAFETY: TODO 191 let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }; 192 init 193 }}) 194 } 195 196 enum InitKind { 197 Normal, 198 Zeroing, 199 } 200 201 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { 202 let Some((dotdot, expr)) = rest else { 203 return InitKind::Normal; 204 }; 205 match &expr { 206 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { 207 Expr::Path(ExprPath { 208 attrs, 209 qself: None, 210 path: 211 Path { 212 leading_colon: None, 213 segments, 214 }, 215 }) if attrs.is_empty() 216 && segments.len() == 2 217 && segments[0].ident == "Zeroable" 218 && segments[0].arguments.is_none() 219 && segments[1].ident == "init_zeroed" 220 && segments[1].arguments.is_none() => 221 { 222 return InitKind::Zeroing; 223 } 224 _ => {} 225 }, 226 _ => {} 227 } 228 dcx.error( 229 dotdot.span().join(expr.span()).unwrap_or(expr.span()), 230 "expected nothing or `..Zeroable::init_zeroed()`.", 231 ); 232 InitKind::Normal 233 } 234 235 /// Generate the code that initializes the fields of the struct using the initializers in `field`. 236 fn init_fields( 237 fields: &Punctuated<InitializerField, Token![,]>, 238 pinned: bool, 239 generate_initialized_accessors: bool, 240 data: &Ident, 241 slot: &Ident, 242 ) -> TokenStream { 243 let mut guards = vec![]; 244 let mut guard_attrs = vec![]; 245 let mut res = TokenStream::new(); 246 for InitializerField { attrs, kind } in fields { 247 let cfgs = { 248 let mut cfgs = attrs.clone(); 249 cfgs.retain(|attr| attr.path().is_ident("cfg")); 250 cfgs 251 }; 252 let init = match kind { 253 InitializerKind::Value { ident, value } => { 254 let mut value_ident = ident.clone(); 255 let value_prep = value.as_ref().map(|value| &value.1).map(|value| { 256 // Setting the span of `value_ident` to `value`'s span improves error messages 257 // when the type of `value` is wrong. 258 value_ident.set_span(value.span()); 259 quote!(let #value_ident = #value;) 260 }); 261 // Again span for better diagnostics 262 let write = quote_spanned!(ident.span()=> ::core::ptr::write); 263 let accessor = if pinned { 264 let project_ident = format_ident!("__project_{ident}"); 265 quote! { 266 // SAFETY: TODO 267 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 268 } 269 } else { 270 quote! { 271 // SAFETY: TODO 272 unsafe { &mut (*#slot).#ident } 273 } 274 }; 275 let accessor = generate_initialized_accessors.then(|| { 276 quote! { 277 #(#cfgs)* 278 #[allow(unused_variables)] 279 let #ident = #accessor; 280 } 281 }); 282 quote! { 283 #(#attrs)* 284 { 285 #value_prep 286 // SAFETY: TODO 287 unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) }; 288 } 289 #accessor 290 } 291 } 292 InitializerKind::Init { ident, value, .. } => { 293 // Again span for better diagnostics 294 let init = format_ident!("init", span = value.span()); 295 let (value_init, accessor) = if pinned { 296 let project_ident = format_ident!("__project_{ident}"); 297 ( 298 quote! { 299 // SAFETY: 300 // - `slot` is valid, because we are inside of an initializer closure, we 301 // return when an error/panic occurs. 302 // - We also use `#data` to require the correct trait (`Init` or `PinInit`) 303 // for `#ident`. 304 unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? }; 305 }, 306 quote! { 307 // SAFETY: TODO 308 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 309 }, 310 ) 311 } else { 312 ( 313 quote! { 314 // SAFETY: `slot` is valid, because we are inside of an initializer 315 // closure, we return when an error/panic occurs. 316 unsafe { 317 ::pin_init::Init::__init( 318 #init, 319 ::core::ptr::addr_of_mut!((*#slot).#ident), 320 )? 321 }; 322 }, 323 quote! { 324 // SAFETY: TODO 325 unsafe { &mut (*#slot).#ident } 326 }, 327 ) 328 }; 329 let accessor = generate_initialized_accessors.then(|| { 330 quote! { 331 #(#cfgs)* 332 #[allow(unused_variables)] 333 let #ident = #accessor; 334 } 335 }); 336 quote! { 337 #(#attrs)* 338 { 339 let #init = #value; 340 #value_init 341 } 342 #accessor 343 } 344 } 345 InitializerKind::Code { block: value, .. } => quote! { 346 #(#attrs)* 347 #[allow(unused_braces)] 348 #value 349 }, 350 }; 351 res.extend(init); 352 if let Some(ident) = kind.ident() { 353 // `mixed_site` ensures that the guard is not accessible to the user-controlled code. 354 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); 355 res.extend(quote! { 356 #(#cfgs)* 357 // Create the drop guard: 358 // 359 // We rely on macro hygiene to make it impossible for users to access this local 360 // variable. 361 // SAFETY: We forget the guard later when initialization has succeeded. 362 let #guard = unsafe { 363 ::pin_init::__internal::DropGuard::new( 364 ::core::ptr::addr_of_mut!((*slot).#ident) 365 ) 366 }; 367 }); 368 guards.push(guard); 369 guard_attrs.push(cfgs); 370 } 371 } 372 quote! { 373 #res 374 // If execution reaches this point, all fields have been initialized. Therefore we can now 375 // dismiss the guards by forgetting them. 376 #( 377 #(#guard_attrs)* 378 ::core::mem::forget(#guards); 379 )* 380 } 381 } 382 383 /// Generate the check for ensuring that every field has been initialized. 384 fn make_field_check( 385 fields: &Punctuated<InitializerField, Token![,]>, 386 init_kind: InitKind, 387 path: &Path, 388 ) -> TokenStream { 389 let field_attrs = fields 390 .iter() 391 .filter_map(|f| f.kind.ident().map(|_| &f.attrs)); 392 let field_name = fields.iter().filter_map(|f| f.kind.ident()); 393 match init_kind { 394 InitKind::Normal => quote! { 395 // We use unreachable code to ensure that all fields have been mentioned exactly once, 396 // this struct initializer will still be type-checked and complain with a very natural 397 // error message if a field is forgotten/mentioned more than once. 398 #[allow(unreachable_code, clippy::diverging_sub_expression)] 399 // SAFETY: this code is never executed. 400 let _ = || unsafe { 401 ::core::ptr::write(slot, #path { 402 #( 403 #(#field_attrs)* 404 #field_name: ::core::panic!(), 405 )* 406 }) 407 }; 408 }, 409 InitKind::Zeroing => quote! { 410 // We use unreachable code to ensure that all fields have been mentioned at most once. 411 // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will 412 // be zeroed. This struct initializer will still be type-checked and complain with a 413 // very natural error message if a field is mentioned more than once, or doesn't exist. 414 #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)] 415 // SAFETY: this code is never executed. 416 let _ = || unsafe { 417 ::core::ptr::write(slot, #path { 418 #( 419 #(#field_attrs)* 420 #field_name: ::core::panic!(), 421 )* 422 ..::core::mem::zeroed() 423 }) 424 }; 425 }, 426 } 427 } 428 429 impl Parse for Initializer { 430 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 431 let attrs = input.call(Attribute::parse_outer)?; 432 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; 433 let path = input.parse()?; 434 let content; 435 let brace_token = braced!(content in input); 436 let mut fields = Punctuated::new(); 437 loop { 438 let lh = content.lookahead1(); 439 if lh.peek(End) || lh.peek(Token![..]) { 440 break; 441 } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) { 442 fields.push_value(content.parse()?); 443 let lh = content.lookahead1(); 444 if lh.peek(End) { 445 break; 446 } else if lh.peek(Token![,]) { 447 fields.push_punct(content.parse()?); 448 } else { 449 return Err(lh.error()); 450 } 451 } else { 452 return Err(lh.error()); 453 } 454 } 455 let rest = content 456 .peek(Token![..]) 457 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) 458 .transpose()?; 459 let error = input 460 .peek(Token![?]) 461 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) 462 .transpose()?; 463 let attrs = attrs 464 .into_iter() 465 .map(|a| { 466 if a.path().is_ident("default_error") { 467 a.parse_args::<DefaultErrorAttribute>() 468 .map(InitializerAttribute::DefaultError) 469 } else if a.path().is_ident("disable_initialized_field_access") { 470 a.meta 471 .require_path_only() 472 .map(|_| InitializerAttribute::DisableInitializedFieldAccess) 473 } else { 474 Err(syn::Error::new_spanned(a, "unknown initializer attribute")) 475 } 476 }) 477 .collect::<Result<Vec<_>, _>>()?; 478 Ok(Self { 479 attrs, 480 this, 481 path, 482 brace_token, 483 fields, 484 rest, 485 error, 486 }) 487 } 488 } 489 490 impl Parse for DefaultErrorAttribute { 491 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 492 Ok(Self { ty: input.parse()? }) 493 } 494 } 495 496 impl Parse for This { 497 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 498 Ok(Self { 499 _and_token: input.parse()?, 500 ident: input.parse()?, 501 _in_token: input.parse()?, 502 }) 503 } 504 } 505 506 impl Parse for InitializerField { 507 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 508 let attrs = input.call(Attribute::parse_outer)?; 509 Ok(Self { 510 attrs, 511 kind: input.parse()?, 512 }) 513 } 514 } 515 516 impl Parse for InitializerKind { 517 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 518 let lh = input.lookahead1(); 519 if lh.peek(Token![_]) { 520 Ok(Self::Code { 521 _underscore_token: input.parse()?, 522 _colon_token: input.parse()?, 523 block: input.parse()?, 524 }) 525 } else if lh.peek(Ident) { 526 let ident = input.parse()?; 527 let lh = input.lookahead1(); 528 if lh.peek(Token![<-]) { 529 Ok(Self::Init { 530 ident, 531 _left_arrow_token: input.parse()?, 532 value: input.parse()?, 533 }) 534 } else if lh.peek(Token![:]) { 535 Ok(Self::Value { 536 ident, 537 value: Some((input.parse()?, input.parse()?)), 538 }) 539 } else if lh.peek(Token![,]) || lh.peek(End) { 540 Ok(Self::Value { ident, value: None }) 541 } else { 542 Err(lh.error()) 543 } 544 } else { 545 Err(lh.error()) 546 } 547 } 548 } 549