1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 use proc_macro2::{Span, TokenStream}; 4 use quote::{format_ident, quote, quote_spanned}; 5 use syn::{ 6 braced, 7 parse::{End, Parse}, 8 parse_quote, 9 punctuated::Punctuated, 10 spanned::Spanned, 11 token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, 12 }; 13 14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; 15 16 pub(crate) struct Initializer { 17 attrs: Vec<InitializerAttribute>, 18 this: Option<This>, 19 path: Path, 20 brace_token: token::Brace, 21 fields: Punctuated<InitializerField, Token![,]>, 22 rest: Option<(Token![..], Expr)>, 23 error: Option<(Token![?], Type)>, 24 } 25 26 struct This { 27 _and_token: Token![&], 28 ident: Ident, 29 _in_token: Token![in], 30 } 31 32 struct InitializerField { 33 attrs: Vec<Attribute>, 34 kind: InitializerKind, 35 } 36 37 enum InitializerKind { 38 Value { 39 ident: Ident, 40 value: Option<(Token![:], Expr)>, 41 }, 42 Init { 43 ident: Ident, 44 _left_arrow_token: Token![<-], 45 value: Expr, 46 }, 47 Code { 48 _underscore_token: Token![_], 49 _colon_token: Token![:], 50 block: Block, 51 }, 52 } 53 54 impl InitializerKind { 55 fn ident(&self) -> Option<&Ident> { 56 match self { 57 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), 58 Self::Code { .. } => None, 59 } 60 } 61 } 62 63 enum InitializerAttribute { 64 DefaultError(DefaultErrorAttribute), 65 } 66 67 struct DefaultErrorAttribute { 68 ty: Box<Type>, 69 } 70 71 pub(crate) fn expand( 72 Initializer { 73 attrs, 74 this, 75 path, 76 brace_token, 77 fields, 78 rest, 79 error, 80 }: Initializer, 81 default_error: Option<&'static str>, 82 pinned: bool, 83 dcx: &mut DiagCtxt, 84 ) -> Result<TokenStream, ErrorGuaranteed> { 85 let error = error.map_or_else( 86 || { 87 if let Some(default_error) = attrs.iter().fold(None, |acc, attr| { 88 #[expect(irrefutable_let_patterns)] 89 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr { 90 Some(ty.clone()) 91 } else { 92 acc 93 } 94 }) { 95 default_error 96 } else if let Some(default_error) = default_error { 97 syn::parse_str(default_error).unwrap() 98 } else { 99 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); 100 parse_quote!(::core::convert::Infallible) 101 } 102 }, 103 |(_, err)| Box::new(err), 104 ); 105 let slot = format_ident!("slot"); 106 let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned { 107 ( 108 format_ident!("HasPinData"), 109 format_ident!("PinData"), 110 format_ident!("__pin_data"), 111 format_ident!("pin_init_from_closure"), 112 ) 113 } else { 114 ( 115 format_ident!("HasInitData"), 116 format_ident!("InitData"), 117 format_ident!("__init_data"), 118 format_ident!("init_from_closure"), 119 ) 120 }; 121 let init_kind = get_init_kind(rest, dcx); 122 let zeroable_check = match init_kind { 123 InitKind::Normal => quote!(), 124 InitKind::Zeroing => quote! { 125 // The user specified `..Zeroable::zeroed()` at the end of the list of fields. 126 // Therefore we check if the struct implements `Zeroable` and then zero the memory. 127 // This allows us to also remove the check that all fields are present (since we 128 // already set the memory to zero and that is a valid bit pattern). 129 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) 130 where T: ::pin_init::Zeroable 131 {} 132 // Ensure that the struct is indeed `Zeroable`. 133 assert_zeroable(#slot); 134 // SAFETY: The type implements `Zeroable` by the check above. 135 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; 136 }, 137 }; 138 let this = match this { 139 None => quote!(), 140 Some(This { ident, .. }) => quote! { 141 // Create the `this` so it can be referenced by the user inside of the 142 // expressions creating the individual fields. 143 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; 144 }, 145 }; 146 // `mixed_site` ensures that the data is not accessible to the user-controlled code. 147 let data = Ident::new("__data", Span::mixed_site()); 148 let init_fields = init_fields(&fields, pinned, &data, &slot); 149 let field_check = make_field_check(&fields, init_kind, &path); 150 Ok(quote! {{ 151 // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return 152 // type and shadow it later when we insert the arbitrary user code. That way there will be 153 // no possibility of returning without `unsafe`. 154 struct __InitOk; 155 156 // Get the data about fields from the supplied type. 157 // SAFETY: TODO 158 let #data = unsafe { 159 use ::pin_init::__internal::#has_data_trait; 160 // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit 161 // generics (which need to be present with that syntax). 162 #path::#get_data() 163 }; 164 // Ensure that `#data` really is of type `#data` and help with type inference: 165 let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>( 166 #data, 167 move |slot| { 168 { 169 // Shadow the structure so it cannot be used to return early. 170 struct __InitOk; 171 #zeroable_check 172 #this 173 #init_fields 174 #field_check 175 } 176 Ok(__InitOk) 177 } 178 ); 179 let init = move |slot| -> ::core::result::Result<(), #error> { 180 init(slot).map(|__InitOk| ()) 181 }; 182 // SAFETY: TODO 183 let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }; 184 init 185 }}) 186 } 187 188 enum InitKind { 189 Normal, 190 Zeroing, 191 } 192 193 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { 194 let Some((dotdot, expr)) = rest else { 195 return InitKind::Normal; 196 }; 197 match &expr { 198 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { 199 Expr::Path(ExprPath { 200 attrs, 201 qself: None, 202 path: 203 Path { 204 leading_colon: None, 205 segments, 206 }, 207 }) if attrs.is_empty() 208 && segments.len() == 2 209 && segments[0].ident == "Zeroable" 210 && segments[0].arguments.is_none() 211 && segments[1].ident == "init_zeroed" 212 && segments[1].arguments.is_none() => 213 { 214 return InitKind::Zeroing; 215 } 216 _ => {} 217 }, 218 _ => {} 219 } 220 dcx.error( 221 dotdot.span().join(expr.span()).unwrap_or(expr.span()), 222 "expected nothing or `..Zeroable::init_zeroed()`.", 223 ); 224 InitKind::Normal 225 } 226 227 /// Generate the code that initializes the fields of the struct using the initializers in `field`. 228 fn init_fields( 229 fields: &Punctuated<InitializerField, Token![,]>, 230 pinned: bool, 231 data: &Ident, 232 slot: &Ident, 233 ) -> TokenStream { 234 let mut guards = vec![]; 235 let mut guard_attrs = vec![]; 236 let mut res = TokenStream::new(); 237 for InitializerField { attrs, kind } in fields { 238 let cfgs = { 239 let mut cfgs = attrs.clone(); 240 cfgs.retain(|attr| attr.path().is_ident("cfg")); 241 cfgs 242 }; 243 let init = match kind { 244 InitializerKind::Value { ident, value } => { 245 let mut value_ident = ident.clone(); 246 let value_prep = value.as_ref().map(|value| &value.1).map(|value| { 247 // Setting the span of `value_ident` to `value`'s span improves error messages 248 // when the type of `value` is wrong. 249 value_ident.set_span(value.span()); 250 quote!(let #value_ident = #value;) 251 }); 252 // Again span for better diagnostics 253 let write = quote_spanned!(ident.span()=> ::core::ptr::write); 254 let accessor = if pinned { 255 let project_ident = format_ident!("__project_{ident}"); 256 quote! { 257 // SAFETY: TODO 258 unsafe { #data.#project_ident(&mut (*#slot).#ident) } 259 } 260 } else { 261 quote! { 262 // SAFETY: TODO 263 unsafe { &mut (*#slot).#ident } 264 } 265 }; 266 quote! { 267 #(#attrs)* 268 { 269 #value_prep 270 // SAFETY: TODO 271 unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) }; 272 } 273 #(#cfgs)* 274 #[allow(unused_variables)] 275 let #ident = #accessor; 276 } 277 } 278 InitializerKind::Init { ident, value, .. } => { 279 // Again span for better diagnostics 280 let init = format_ident!("init", span = value.span()); 281 if pinned { 282 let project_ident = format_ident!("__project_{ident}"); 283 quote! { 284 #(#attrs)* 285 { 286 let #init = #value; 287 // SAFETY: 288 // - `slot` is valid, because we are inside of an initializer closure, we 289 // return when an error/panic occurs. 290 // - We also use `#data` to require the correct trait (`Init` or `PinInit`) 291 // for `#ident`. 292 unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? }; 293 } 294 #(#cfgs)* 295 // SAFETY: TODO 296 #[allow(unused_variables)] 297 let #ident = unsafe { #data.#project_ident(&mut (*#slot).#ident) }; 298 } 299 } else { 300 quote! { 301 #(#attrs)* 302 { 303 let #init = #value; 304 // SAFETY: `slot` is valid, because we are inside of an initializer 305 // closure, we return when an error/panic occurs. 306 unsafe { 307 ::pin_init::Init::__init( 308 #init, 309 ::core::ptr::addr_of_mut!((*#slot).#ident), 310 )? 311 }; 312 } 313 #(#cfgs)* 314 // SAFETY: TODO 315 #[allow(unused_variables)] 316 let #ident = unsafe { &mut (*#slot).#ident }; 317 } 318 } 319 } 320 InitializerKind::Code { block: value, .. } => quote! { 321 #(#attrs)* 322 #[allow(unused_braces)] 323 #value 324 }, 325 }; 326 res.extend(init); 327 if let Some(ident) = kind.ident() { 328 // `mixed_site` ensures that the guard is not accessible to the user-controlled code. 329 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); 330 res.extend(quote! { 331 #(#cfgs)* 332 // Create the drop guard: 333 // 334 // We rely on macro hygiene to make it impossible for users to access this local 335 // variable. 336 // SAFETY: We forget the guard later when initialization has succeeded. 337 let #guard = unsafe { 338 ::pin_init::__internal::DropGuard::new( 339 ::core::ptr::addr_of_mut!((*slot).#ident) 340 ) 341 }; 342 }); 343 guards.push(guard); 344 guard_attrs.push(cfgs); 345 } 346 } 347 quote! { 348 #res 349 // If execution reaches this point, all fields have been initialized. Therefore we can now 350 // dismiss the guards by forgetting them. 351 #( 352 #(#guard_attrs)* 353 ::core::mem::forget(#guards); 354 )* 355 } 356 } 357 358 /// Generate the check for ensuring that every field has been initialized. 359 fn make_field_check( 360 fields: &Punctuated<InitializerField, Token![,]>, 361 init_kind: InitKind, 362 path: &Path, 363 ) -> TokenStream { 364 let field_attrs = fields 365 .iter() 366 .filter_map(|f| f.kind.ident().map(|_| &f.attrs)); 367 let field_name = fields.iter().filter_map(|f| f.kind.ident()); 368 match init_kind { 369 InitKind::Normal => quote! { 370 // We use unreachable code to ensure that all fields have been mentioned exactly once, 371 // this struct initializer will still be type-checked and complain with a very natural 372 // error message if a field is forgotten/mentioned more than once. 373 #[allow(unreachable_code, clippy::diverging_sub_expression)] 374 // SAFETY: this code is never executed. 375 let _ = || unsafe { 376 ::core::ptr::write(slot, #path { 377 #( 378 #(#field_attrs)* 379 #field_name: ::core::panic!(), 380 )* 381 }) 382 }; 383 }, 384 InitKind::Zeroing => quote! { 385 // We use unreachable code to ensure that all fields have been mentioned at most once. 386 // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will 387 // be zeroed. This struct initializer will still be type-checked and complain with a 388 // very natural error message if a field is mentioned more than once, or doesn't exist. 389 #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)] 390 // SAFETY: this code is never executed. 391 let _ = || unsafe { 392 let mut zeroed = ::core::mem::zeroed(); 393 // We have to use type inference here to make zeroed have the correct type. This 394 // does not get executed, so it has no effect. 395 ::core::ptr::write(slot, zeroed); 396 zeroed = ::core::mem::zeroed(); 397 ::core::ptr::write(slot, #path { 398 #( 399 #(#field_attrs)* 400 #field_name: ::core::panic!(), 401 )* 402 ..zeroed 403 }) 404 }; 405 }, 406 } 407 } 408 409 impl Parse for Initializer { 410 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 411 let attrs = input.call(Attribute::parse_outer)?; 412 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; 413 let path = input.parse()?; 414 let content; 415 let brace_token = braced!(content in input); 416 let mut fields = Punctuated::new(); 417 loop { 418 let lh = content.lookahead1(); 419 if lh.peek(End) || lh.peek(Token![..]) { 420 break; 421 } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) { 422 fields.push_value(content.parse()?); 423 let lh = content.lookahead1(); 424 if lh.peek(End) { 425 break; 426 } else if lh.peek(Token![,]) { 427 fields.push_punct(content.parse()?); 428 } else { 429 return Err(lh.error()); 430 } 431 } else { 432 return Err(lh.error()); 433 } 434 } 435 let rest = content 436 .peek(Token![..]) 437 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) 438 .transpose()?; 439 let error = input 440 .peek(Token![?]) 441 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) 442 .transpose()?; 443 let attrs = attrs 444 .into_iter() 445 .map(|a| { 446 if a.path().is_ident("default_error") { 447 a.parse_args::<DefaultErrorAttribute>() 448 .map(InitializerAttribute::DefaultError) 449 } else { 450 Err(syn::Error::new_spanned(a, "unknown initializer attribute")) 451 } 452 }) 453 .collect::<Result<Vec<_>, _>>()?; 454 Ok(Self { 455 attrs, 456 this, 457 path, 458 brace_token, 459 fields, 460 rest, 461 error, 462 }) 463 } 464 } 465 466 impl Parse for DefaultErrorAttribute { 467 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 468 Ok(Self { ty: input.parse()? }) 469 } 470 } 471 472 impl Parse for This { 473 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 474 Ok(Self { 475 _and_token: input.parse()?, 476 ident: input.parse()?, 477 _in_token: input.parse()?, 478 }) 479 } 480 } 481 482 impl Parse for InitializerField { 483 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 484 let attrs = input.call(Attribute::parse_outer)?; 485 Ok(Self { 486 attrs, 487 kind: input.parse()?, 488 }) 489 } 490 } 491 492 impl Parse for InitializerKind { 493 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 494 let lh = input.lookahead1(); 495 if lh.peek(Token![_]) { 496 Ok(Self::Code { 497 _underscore_token: input.parse()?, 498 _colon_token: input.parse()?, 499 block: input.parse()?, 500 }) 501 } else if lh.peek(Ident) { 502 let ident = input.parse()?; 503 let lh = input.lookahead1(); 504 if lh.peek(Token![<-]) { 505 Ok(Self::Init { 506 ident, 507 _left_arrow_token: input.parse()?, 508 value: input.parse()?, 509 }) 510 } else if lh.peek(Token![:]) { 511 Ok(Self::Value { 512 ident, 513 value: Some((input.parse()?, input.parse()?)), 514 }) 515 } else if lh.peek(Token![,]) || lh.peek(End) { 516 Ok(Self::Value { ident, value: None }) 517 } else { 518 Err(lh.error()) 519 } 520 } else { 521 Err(lh.error()) 522 } 523 } 524 } 525