1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 use proc_macro2::{Span, TokenStream}; 4 use quote::{format_ident, quote, quote_spanned}; 5 use syn::{ 6 braced, 7 parse::{End, Parse}, 8 parse_quote, 9 punctuated::Punctuated, 10 spanned::Spanned, 11 token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, 12 }; 13 14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; 15 16 pub(crate) struct Initializer { 17 attrs: Vec<InitializerAttribute>, 18 this: Option<This>, 19 path: Path, 20 brace_token: token::Brace, 21 fields: Punctuated<InitializerField, Token![,]>, 22 rest: Option<(Token![..], Expr)>, 23 error: Option<(Token![?], Type)>, 24 } 25 26 struct This { 27 _and_token: Token![&], 28 ident: Ident, 29 _in_token: Token![in], 30 } 31 32 struct InitializerField { 33 attrs: Vec<Attribute>, 34 kind: InitializerKind, 35 } 36 37 enum InitializerKind { 38 Value { 39 ident: Ident, 40 value: Option<(Token![:], Expr)>, 41 }, 42 Init { 43 ident: Ident, 44 _left_arrow_token: Token![<-], 45 value: Expr, 46 }, 47 Code { 48 _underscore_token: Token![_], 49 _colon_token: Token![:], 50 block: Block, 51 }, 52 } 53 54 impl InitializerKind { 55 fn ident(&self) -> Option<&Ident> { 56 match self { 57 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), 58 Self::Code { .. } => None, 59 } 60 } 61 } 62 63 enum InitializerAttribute { 64 DefaultError(DefaultErrorAttribute), 65 } 66 67 struct DefaultErrorAttribute { 68 ty: Box<Type>, 69 } 70 71 pub(crate) fn expand( 72 Initializer { 73 attrs, 74 this, 75 path, 76 brace_token, 77 fields, 78 rest, 79 error, 80 }: Initializer, 81 default_error: Option<&'static str>, 82 pinned: bool, 83 dcx: &mut DiagCtxt, 84 ) -> Result<TokenStream, ErrorGuaranteed> { 85 let error = error.map_or_else( 86 || { 87 if let Some(default_error) = attrs.iter().fold(None, |acc, attr| { 88 #[expect(irrefutable_let_patterns)] 89 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr { 90 Some(ty.clone()) 91 } else { 92 acc 93 } 94 }) { 95 default_error 96 } else if let Some(default_error) = default_error { 97 syn::parse_str(default_error).unwrap() 98 } else { 99 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); 100 parse_quote!(::core::convert::Infallible) 101 } 102 }, 103 |(_, err)| Box::new(err), 104 ); 105 let slot = format_ident!("slot"); 106 let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned { 107 ( 108 format_ident!("HasPinData"), 109 format_ident!("PinData"), 110 format_ident!("__pin_data"), 111 format_ident!("pin_init_from_closure"), 112 ) 113 } else { 114 ( 115 format_ident!("HasInitData"), 116 format_ident!("InitData"), 117 format_ident!("__init_data"), 118 format_ident!("init_from_closure"), 119 ) 120 }; 121 let init_kind = get_init_kind(rest, dcx); 122 let zeroable_check = match init_kind { 123 InitKind::Normal => quote!(), 124 InitKind::Zeroing => quote! { 125 // The user specified `..Zeroable::zeroed()` at the end of the list of fields. 126 // Therefore we check if the struct implements `Zeroable` and then zero the memory. 127 // This allows us to also remove the check that all fields are present (since we 128 // already set the memory to zero and that is a valid bit pattern). 129 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) 130 where T: ::pin_init::Zeroable 131 {} 132 // Ensure that the struct is indeed `Zeroable`. 133 assert_zeroable(#slot); 134 // SAFETY: The type implements `Zeroable` by the check above. 135 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; 136 }, 137 }; 138 let this = match this { 139 None => quote!(), 140 Some(This { ident, .. }) => quote! { 141 // Create the `this` so it can be referenced by the user inside of the 142 // expressions creating the individual fields. 143 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; 144 }, 145 }; 146 // `mixed_site` ensures that the data is not accessible to the user-controlled code. 147 let data = Ident::new("__data", Span::mixed_site()); 148 let init_fields = init_fields(&fields, pinned, &data, &slot); 149 let field_check = make_field_check(&fields, init_kind, &path); 150 Ok(quote! {{ 151 // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return 152 // type and shadow it later when we insert the arbitrary user code. That way there will be 153 // no possibility of returning without `unsafe`. 154 struct __InitOk; 155 156 // Get the data about fields from the supplied type. 157 // SAFETY: TODO 158 let #data = unsafe { 159 use ::pin_init::__internal::#has_data_trait; 160 // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit 161 // generics (which need to be present with that syntax). 162 #path::#get_data() 163 }; 164 // Ensure that `#data` really is of type `#data` and help with type inference: 165 let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>( 166 #data, 167 move |slot| { 168 { 169 // Shadow the structure so it cannot be used to return early. 170 struct __InitOk; 171 #zeroable_check 172 #this 173 #init_fields 174 #field_check 175 } 176 Ok(__InitOk) 177 } 178 ); 179 let init = move |slot| -> ::core::result::Result<(), #error> { 180 init(slot).map(|__InitOk| ()) 181 }; 182 // SAFETY: TODO 183 let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }; 184 init 185 }}) 186 } 187 188 enum InitKind { 189 Normal, 190 Zeroing, 191 } 192 193 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { 194 let Some((dotdot, expr)) = rest else { 195 return InitKind::Normal; 196 }; 197 match &expr { 198 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { 199 Expr::Path(ExprPath { 200 attrs, 201 qself: None, 202 path: 203 Path { 204 leading_colon: None, 205 segments, 206 }, 207 }) if attrs.is_empty() 208 && segments.len() == 2 209 && segments[0].ident == "Zeroable" 210 && segments[0].arguments.is_none() 211 && segments[1].ident == "init_zeroed" 212 && segments[1].arguments.is_none() => 213 { 214 return InitKind::Zeroing; 215 } 216 _ => {} 217 }, 218 _ => {} 219 } 220 dcx.error( 221 dotdot.span().join(expr.span()).unwrap_or(expr.span()), 222 "expected nothing or `..Zeroable::init_zeroed()`.", 223 ); 224 InitKind::Normal 225 } 226 227 /// Generate the code that initializes the fields of the struct using the initializers in `field`. 228 fn init_fields( 229 fields: &Punctuated<InitializerField, Token![,]>, 230 pinned: bool, 231 data: &Ident, 232 slot: &Ident, 233 ) -> TokenStream { 234 let mut guards = vec![]; 235 let mut guard_attrs = vec![]; 236 let mut res = TokenStream::new(); 237 for InitializerField { attrs, kind } in fields { 238 let cfgs = { 239 let mut cfgs = attrs.clone(); 240 cfgs.retain(|attr| attr.path().is_ident("cfg")); 241 cfgs 242 }; 243 let init = match kind { 244 InitializerKind::Value { ident, value } => { 245 let mut value_ident = ident.clone(); 246 let value_prep = value.as_ref().map(|value| &value.1).map(|value| { 247 // Setting the span of `value_ident` to `value`'s span improves error messages 248 // when the type of `value` is wrong. 249 value_ident.set_span(value.span()); 250 quote!(let #value_ident = #value;) 251 }); 252 // Again span for better diagnostics 253 let write = quote_spanned!(ident.span()=> ::core::ptr::write); 254 let accessor = if pinned { 255 let project_ident = format_ident!("__project_{ident}"); 256 quote! { 257 // SAFETY: TODO 258 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 259 } 260 } else { 261 quote! { 262 // SAFETY: TODO 263 unsafe { &mut (*#slot).#ident } 264 } 265 }; 266 quote! { 267 #(#attrs)* 268 { 269 #value_prep 270 // SAFETY: TODO 271 unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) }; 272 } 273 #(#cfgs)* 274 #[allow(unused_variables)] 275 let #ident = #accessor; 276 } 277 } 278 InitializerKind::Init { ident, value, .. } => { 279 // Again span for better diagnostics 280 let init = format_ident!("init", span = value.span()); 281 let (value_init, accessor) = if pinned { 282 let project_ident = format_ident!("__project_{ident}"); 283 ( 284 quote! { 285 // SAFETY: 286 // - `slot` is valid, because we are inside of an initializer closure, we 287 // return when an error/panic occurs. 288 // - We also use `#data` to require the correct trait (`Init` or `PinInit`) 289 // for `#ident`. 290 unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? }; 291 }, 292 quote! { 293 // SAFETY: TODO 294 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 295 }, 296 ) 297 } else { 298 ( 299 quote! { 300 // SAFETY: `slot` is valid, because we are inside of an initializer 301 // closure, we return when an error/panic occurs. 302 unsafe { 303 ::pin_init::Init::__init( 304 #init, 305 ::core::ptr::addr_of_mut!((*#slot).#ident), 306 )? 307 }; 308 }, 309 quote! { 310 // SAFETY: TODO 311 unsafe { &mut (*#slot).#ident } 312 }, 313 ) 314 }; 315 quote! { 316 #(#attrs)* 317 { 318 let #init = #value; 319 #value_init 320 } 321 #(#cfgs)* 322 #[allow(unused_variables)] 323 let #ident = #accessor; 324 } 325 } 326 InitializerKind::Code { block: value, .. } => quote! { 327 #(#attrs)* 328 #[allow(unused_braces)] 329 #value 330 }, 331 }; 332 res.extend(init); 333 if let Some(ident) = kind.ident() { 334 // `mixed_site` ensures that the guard is not accessible to the user-controlled code. 335 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); 336 res.extend(quote! { 337 #(#cfgs)* 338 // Create the drop guard: 339 // 340 // We rely on macro hygiene to make it impossible for users to access this local 341 // variable. 342 // SAFETY: We forget the guard later when initialization has succeeded. 343 let #guard = unsafe { 344 ::pin_init::__internal::DropGuard::new( 345 ::core::ptr::addr_of_mut!((*slot).#ident) 346 ) 347 }; 348 }); 349 guards.push(guard); 350 guard_attrs.push(cfgs); 351 } 352 } 353 quote! { 354 #res 355 // If execution reaches this point, all fields have been initialized. Therefore we can now 356 // dismiss the guards by forgetting them. 357 #( 358 #(#guard_attrs)* 359 ::core::mem::forget(#guards); 360 )* 361 } 362 } 363 364 /// Generate the check for ensuring that every field has been initialized. 365 fn make_field_check( 366 fields: &Punctuated<InitializerField, Token![,]>, 367 init_kind: InitKind, 368 path: &Path, 369 ) -> TokenStream { 370 let field_attrs = fields 371 .iter() 372 .filter_map(|f| f.kind.ident().map(|_| &f.attrs)); 373 let field_name = fields.iter().filter_map(|f| f.kind.ident()); 374 match init_kind { 375 InitKind::Normal => quote! { 376 // We use unreachable code to ensure that all fields have been mentioned exactly once, 377 // this struct initializer will still be type-checked and complain with a very natural 378 // error message if a field is forgotten/mentioned more than once. 379 #[allow(unreachable_code, clippy::diverging_sub_expression)] 380 // SAFETY: this code is never executed. 381 let _ = || unsafe { 382 ::core::ptr::write(slot, #path { 383 #( 384 #(#field_attrs)* 385 #field_name: ::core::panic!(), 386 )* 387 }) 388 }; 389 }, 390 InitKind::Zeroing => quote! { 391 // We use unreachable code to ensure that all fields have been mentioned at most once. 392 // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will 393 // be zeroed. This struct initializer will still be type-checked and complain with a 394 // very natural error message if a field is mentioned more than once, or doesn't exist. 395 #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)] 396 // SAFETY: this code is never executed. 397 let _ = || unsafe { 398 ::core::ptr::write(slot, #path { 399 #( 400 #(#field_attrs)* 401 #field_name: ::core::panic!(), 402 )* 403 ..::core::mem::zeroed() 404 }) 405 }; 406 }, 407 } 408 } 409 410 impl Parse for Initializer { 411 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 412 let attrs = input.call(Attribute::parse_outer)?; 413 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; 414 let path = input.parse()?; 415 let content; 416 let brace_token = braced!(content in input); 417 let mut fields = Punctuated::new(); 418 loop { 419 let lh = content.lookahead1(); 420 if lh.peek(End) || lh.peek(Token![..]) { 421 break; 422 } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) { 423 fields.push_value(content.parse()?); 424 let lh = content.lookahead1(); 425 if lh.peek(End) { 426 break; 427 } else if lh.peek(Token![,]) { 428 fields.push_punct(content.parse()?); 429 } else { 430 return Err(lh.error()); 431 } 432 } else { 433 return Err(lh.error()); 434 } 435 } 436 let rest = content 437 .peek(Token![..]) 438 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) 439 .transpose()?; 440 let error = input 441 .peek(Token![?]) 442 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) 443 .transpose()?; 444 let attrs = attrs 445 .into_iter() 446 .map(|a| { 447 if a.path().is_ident("default_error") { 448 a.parse_args::<DefaultErrorAttribute>() 449 .map(InitializerAttribute::DefaultError) 450 } else { 451 Err(syn::Error::new_spanned(a, "unknown initializer attribute")) 452 } 453 }) 454 .collect::<Result<Vec<_>, _>>()?; 455 Ok(Self { 456 attrs, 457 this, 458 path, 459 brace_token, 460 fields, 461 rest, 462 error, 463 }) 464 } 465 } 466 467 impl Parse for DefaultErrorAttribute { 468 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 469 Ok(Self { ty: input.parse()? }) 470 } 471 } 472 473 impl Parse for This { 474 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 475 Ok(Self { 476 _and_token: input.parse()?, 477 ident: input.parse()?, 478 _in_token: input.parse()?, 479 }) 480 } 481 } 482 483 impl Parse for InitializerField { 484 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 485 let attrs = input.call(Attribute::parse_outer)?; 486 Ok(Self { 487 attrs, 488 kind: input.parse()?, 489 }) 490 } 491 } 492 493 impl Parse for InitializerKind { 494 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 495 let lh = input.lookahead1(); 496 if lh.peek(Token![_]) { 497 Ok(Self::Code { 498 _underscore_token: input.parse()?, 499 _colon_token: input.parse()?, 500 block: input.parse()?, 501 }) 502 } else if lh.peek(Ident) { 503 let ident = input.parse()?; 504 let lh = input.lookahead1(); 505 if lh.peek(Token![<-]) { 506 Ok(Self::Init { 507 ident, 508 _left_arrow_token: input.parse()?, 509 value: input.parse()?, 510 }) 511 } else if lh.peek(Token![:]) { 512 Ok(Self::Value { 513 ident, 514 value: Some((input.parse()?, input.parse()?)), 515 }) 516 } else if lh.peek(Token![,]) || lh.peek(End) { 517 Ok(Self::Value { ident, value: None }) 518 } else { 519 Err(lh.error()) 520 } 521 } else { 522 Err(lh.error()) 523 } 524 } 525 } 526