1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 use proc_macro2::{Span, TokenStream}; 4 use quote::{format_ident, quote, quote_spanned}; 5 use syn::{ 6 braced, 7 parse::{End, Parse}, 8 parse_quote, 9 punctuated::Punctuated, 10 spanned::Spanned, 11 token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, 12 }; 13 14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; 15 16 pub(crate) struct Initializer { 17 attrs: Vec<InitializerAttribute>, 18 this: Option<This>, 19 path: Path, 20 brace_token: token::Brace, 21 fields: Punctuated<InitializerField, Token![,]>, 22 rest: Option<(Token![..], Expr)>, 23 error: Option<(Token![?], Type)>, 24 } 25 26 struct This { 27 _and_token: Token![&], 28 ident: Ident, 29 _in_token: Token![in], 30 } 31 32 struct InitializerField { 33 attrs: Vec<Attribute>, 34 kind: InitializerKind, 35 } 36 37 enum InitializerKind { 38 Value { 39 ident: Ident, 40 value: Option<(Token![:], Expr)>, 41 }, 42 Init { 43 ident: Ident, 44 _left_arrow_token: Token![<-], 45 value: Expr, 46 }, 47 Code { 48 _underscore_token: Token![_], 49 _colon_token: Token![:], 50 block: Block, 51 }, 52 } 53 54 impl InitializerKind { 55 fn ident(&self) -> Option<&Ident> { 56 match self { 57 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), 58 Self::Code { .. } => None, 59 } 60 } 61 } 62 63 enum InitializerAttribute { 64 DefaultError(DefaultErrorAttribute), 65 } 66 67 struct DefaultErrorAttribute { 68 ty: Box<Type>, 69 } 70 71 pub(crate) fn expand( 72 Initializer { 73 attrs, 74 this, 75 path, 76 brace_token, 77 fields, 78 rest, 79 error, 80 }: Initializer, 81 default_error: Option<&'static str>, 82 pinned: bool, 83 dcx: &mut DiagCtxt, 84 ) -> Result<TokenStream, ErrorGuaranteed> { 85 let error = error.map_or_else( 86 || { 87 if let Some(default_error) = attrs.iter().fold(None, |acc, attr| { 88 #[expect(irrefutable_let_patterns)] 89 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr { 90 Some(ty.clone()) 91 } else { 92 acc 93 } 94 }) { 95 default_error 96 } else if let Some(default_error) = default_error { 97 syn::parse_str(default_error).unwrap() 98 } else { 99 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); 100 parse_quote!(::core::convert::Infallible) 101 } 102 }, 103 |(_, err)| Box::new(err), 104 ); 105 let slot = format_ident!("slot"); 106 let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned { 107 ( 108 format_ident!("HasPinData"), 109 format_ident!("PinData"), 110 format_ident!("__pin_data"), 111 format_ident!("pin_init_from_closure"), 112 ) 113 } else { 114 ( 115 format_ident!("HasInitData"), 116 format_ident!("InitData"), 117 format_ident!("__init_data"), 118 format_ident!("init_from_closure"), 119 ) 120 }; 121 let init_kind = get_init_kind(rest, dcx); 122 let zeroable_check = match init_kind { 123 InitKind::Normal => quote!(), 124 InitKind::Zeroing => quote! { 125 // The user specified `..Zeroable::zeroed()` at the end of the list of fields. 126 // Therefore we check if the struct implements `Zeroable` and then zero the memory. 127 // This allows us to also remove the check that all fields are present (since we 128 // already set the memory to zero and that is a valid bit pattern). 129 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) 130 where T: ::pin_init::Zeroable 131 {} 132 // Ensure that the struct is indeed `Zeroable`. 133 assert_zeroable(#slot); 134 // SAFETY: The type implements `Zeroable` by the check above. 135 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; 136 }, 137 }; 138 let this = match this { 139 None => quote!(), 140 Some(This { ident, .. }) => quote! { 141 // Create the `this` so it can be referenced by the user inside of the 142 // expressions creating the individual fields. 143 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; 144 }, 145 }; 146 // `mixed_site` ensures that the data is not accessible to the user-controlled code. 147 let data = Ident::new("__data", Span::mixed_site()); 148 let init_fields = init_fields(&fields, pinned, &data, &slot); 149 let field_check = make_field_check(&fields, init_kind, &path); 150 Ok(quote! {{ 151 // Get the data about fields from the supplied type. 152 // SAFETY: TODO 153 let #data = unsafe { 154 use ::pin_init::__internal::#has_data_trait; 155 // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit 156 // generics (which need to be present with that syntax). 157 #path::#get_data() 158 }; 159 // Ensure that `#data` really is of type `#data` and help with type inference: 160 let init = ::pin_init::__internal::#data_trait::make_closure::<_, #error>( 161 #data, 162 move |slot| { 163 #zeroable_check 164 #this 165 #init_fields 166 #field_check 167 // SAFETY: we are the `init!` macro that is allowed to call this. 168 Ok(unsafe { ::pin_init::__internal::InitOk::new() }) 169 } 170 ); 171 let init = move |slot| -> ::core::result::Result<(), #error> { 172 init(slot).map(|__InitOk| ()) 173 }; 174 // SAFETY: TODO 175 let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }; 176 // FIXME: this let binding is required to avoid a compiler error (cycle when computing the 177 // opaque type returned by this function) before Rust 1.81. Remove after MSRV bump. 178 #[allow( 179 clippy::let_and_return, 180 reason = "some clippy versions warn about the let binding" 181 )] 182 init 183 }}) 184 } 185 186 enum InitKind { 187 Normal, 188 Zeroing, 189 } 190 191 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { 192 let Some((dotdot, expr)) = rest else { 193 return InitKind::Normal; 194 }; 195 match &expr { 196 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { 197 Expr::Path(ExprPath { 198 attrs, 199 qself: None, 200 path: 201 Path { 202 leading_colon: None, 203 segments, 204 }, 205 }) if attrs.is_empty() 206 && segments.len() == 2 207 && segments[0].ident == "Zeroable" 208 && segments[0].arguments.is_none() 209 && segments[1].ident == "init_zeroed" 210 && segments[1].arguments.is_none() => 211 { 212 return InitKind::Zeroing; 213 } 214 _ => {} 215 }, 216 _ => {} 217 } 218 dcx.error( 219 dotdot.span().join(expr.span()).unwrap_or(expr.span()), 220 "expected nothing or `..Zeroable::init_zeroed()`.", 221 ); 222 InitKind::Normal 223 } 224 225 /// Generate the code that initializes the fields of the struct using the initializers in `field`. 226 fn init_fields( 227 fields: &Punctuated<InitializerField, Token![,]>, 228 pinned: bool, 229 data: &Ident, 230 slot: &Ident, 231 ) -> TokenStream { 232 let mut guards = vec![]; 233 let mut guard_attrs = vec![]; 234 let mut res = TokenStream::new(); 235 for InitializerField { attrs, kind } in fields { 236 let cfgs = { 237 let mut cfgs = attrs.clone(); 238 cfgs.retain(|attr| attr.path().is_ident("cfg")); 239 cfgs 240 }; 241 let init = match kind { 242 InitializerKind::Value { ident, value } => { 243 let mut value_ident = ident.clone(); 244 let value_prep = value.as_ref().map(|value| &value.1).map(|value| { 245 // Setting the span of `value_ident` to `value`'s span improves error messages 246 // when the type of `value` is wrong. 247 value_ident.set_span(value.span()); 248 quote!(let #value_ident = #value;) 249 }); 250 // Again span for better diagnostics 251 let write = quote_spanned!(ident.span()=> ::core::ptr::write); 252 quote! { 253 #(#attrs)* 254 { 255 #value_prep 256 // SAFETY: TODO 257 unsafe { #write(&raw mut (*#slot).#ident, #value_ident) }; 258 } 259 } 260 } 261 InitializerKind::Init { ident, value, .. } => { 262 // Again span for better diagnostics 263 let init = format_ident!("init", span = value.span()); 264 let value_init = if pinned { 265 quote! { 266 // SAFETY: 267 // - `slot` is valid, because we are inside of an initializer closure, we 268 // return when an error/panic occurs. 269 // - We also use `#data` to require the correct trait (`Init` or `PinInit`) 270 // for `#ident`. 271 unsafe { #data.#ident(&raw mut (*#slot).#ident, #init)? }; 272 } 273 } else { 274 quote! { 275 // SAFETY: `slot` is valid, because we are inside of an initializer 276 // closure, we return when an error/panic occurs. 277 unsafe { 278 ::pin_init::Init::__init( 279 #init, 280 &raw mut (*#slot).#ident, 281 )? 282 }; 283 } 284 }; 285 quote! { 286 #(#attrs)* 287 { 288 let #init = #value; 289 #value_init 290 } 291 } 292 } 293 InitializerKind::Code { block: value, .. } => quote! { 294 #(#attrs)* 295 #[allow(unused_braces)] 296 #value 297 }, 298 }; 299 res.extend(init); 300 if let Some(ident) = kind.ident() { 301 // `mixed_site` ensures that the guard is not accessible to the user-controlled code. 302 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); 303 304 // NOTE: The reference is derived from the guard so that it only lives as long as the 305 // guard does and cannot escape the scope. If it's created via `&mut (*#slot).#ident` 306 // like the unaligned field guard, it will become effectively `'static`. 307 let accessor = if pinned { 308 let project_ident = format_ident!("__project_{ident}"); 309 quote! { 310 // SAFETY: the initialization is pinned. 311 unsafe { #data.#project_ident(#guard.let_binding()) } 312 } 313 } else { 314 quote! { 315 #guard.let_binding() 316 } 317 }; 318 319 res.extend(quote! { 320 #(#cfgs)* 321 // Create the drop guard. 322 // 323 // SAFETY: 324 // - `&raw mut (*slot).#ident` is valid. 325 // - `make_field_check` checks that `&raw mut (*slot).#ident` is properly aligned. 326 // - `(*slot).#ident` has been initialized above. 327 // - We only need the ownership to the pointee back when initialization has 328 // succeeded, where we `forget` the guard. 329 let mut #guard = unsafe { 330 ::pin_init::__internal::DropGuard::new( 331 &raw mut (*slot).#ident 332 ) 333 }; 334 335 #(#cfgs)* 336 #[allow(unused_variables)] 337 let #ident = #accessor; 338 }); 339 guards.push(guard); 340 guard_attrs.push(cfgs); 341 } 342 } 343 quote! { 344 #res 345 // If execution reaches this point, all fields have been initialized. Therefore we can now 346 // dismiss the guards by forgetting them. 347 #( 348 #(#guard_attrs)* 349 ::core::mem::forget(#guards); 350 )* 351 } 352 } 353 354 /// Generate the check for ensuring that every field has been initialized and aligned. 355 fn make_field_check( 356 fields: &Punctuated<InitializerField, Token![,]>, 357 init_kind: InitKind, 358 path: &Path, 359 ) -> TokenStream { 360 let field_attrs: Vec<_> = fields 361 .iter() 362 .filter_map(|f| f.kind.ident().map(|_| &f.attrs)) 363 .collect(); 364 let field_name: Vec<_> = fields.iter().filter_map(|f| f.kind.ident()).collect(); 365 let zeroing_trailer = match init_kind { 366 InitKind::Normal => None, 367 InitKind::Zeroing => Some(quote! { 368 ..::core::mem::zeroed() 369 }), 370 }; 371 quote! { 372 #[allow(unreachable_code, clippy::diverging_sub_expression)] 373 // We use unreachable code to perform field checks. They're still checked by the compiler. 374 // SAFETY: this code is never executed. 375 let _ = || unsafe { 376 // Create references to ensure that the initialized field is properly aligned. 377 // Unaligned fields will cause the compiler to emit E0793. We do not support 378 // unaligned fields since `Init::__init` requires an aligned pointer; the call to 379 // `ptr::write` for value-initialization case has the same requirement. 380 #( 381 #(#field_attrs)* 382 let _ = &(*slot).#field_name; 383 )* 384 385 // If the zeroing trailer is not present, this checks that all fields have been 386 // mentioned exactly once. If the zeroing trailer is present, all missing fields will be 387 // zeroed, so this checks that all fields have been mentioned at most once. The use of 388 // struct initializer will still generate very natural error messages for any misuse. 389 ::core::ptr::write(slot, #path { 390 #( 391 #(#field_attrs)* 392 #field_name: ::core::panic!(), 393 )* 394 #zeroing_trailer 395 }) 396 }; 397 } 398 } 399 400 impl Parse for Initializer { 401 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 402 let attrs = input.call(Attribute::parse_outer)?; 403 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; 404 let path = input.parse()?; 405 let content; 406 let brace_token = braced!(content in input); 407 let mut fields = Punctuated::new(); 408 loop { 409 let lh = content.lookahead1(); 410 if lh.peek(End) || lh.peek(Token![..]) { 411 break; 412 } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) { 413 fields.push_value(content.parse()?); 414 let lh = content.lookahead1(); 415 if lh.peek(End) { 416 break; 417 } else if lh.peek(Token![,]) { 418 fields.push_punct(content.parse()?); 419 } else { 420 return Err(lh.error()); 421 } 422 } else { 423 return Err(lh.error()); 424 } 425 } 426 let rest = content 427 .peek(Token![..]) 428 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) 429 .transpose()?; 430 let error = input 431 .peek(Token![?]) 432 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) 433 .transpose()?; 434 let attrs = attrs 435 .into_iter() 436 .map(|a| { 437 if a.path().is_ident("default_error") { 438 a.parse_args::<DefaultErrorAttribute>() 439 .map(InitializerAttribute::DefaultError) 440 } else { 441 Err(syn::Error::new_spanned(a, "unknown initializer attribute")) 442 } 443 }) 444 .collect::<Result<Vec<_>, _>>()?; 445 Ok(Self { 446 attrs, 447 this, 448 path, 449 brace_token, 450 fields, 451 rest, 452 error, 453 }) 454 } 455 } 456 457 impl Parse for DefaultErrorAttribute { 458 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 459 Ok(Self { ty: input.parse()? }) 460 } 461 } 462 463 impl Parse for This { 464 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 465 Ok(Self { 466 _and_token: input.parse()?, 467 ident: input.parse()?, 468 _in_token: input.parse()?, 469 }) 470 } 471 } 472 473 impl Parse for InitializerField { 474 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 475 let attrs = input.call(Attribute::parse_outer)?; 476 Ok(Self { 477 attrs, 478 kind: input.parse()?, 479 }) 480 } 481 } 482 483 impl Parse for InitializerKind { 484 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 485 let lh = input.lookahead1(); 486 if lh.peek(Token![_]) { 487 Ok(Self::Code { 488 _underscore_token: input.parse()?, 489 _colon_token: input.parse()?, 490 block: input.parse()?, 491 }) 492 } else if lh.peek(Ident) { 493 let ident = input.parse()?; 494 let lh = input.lookahead1(); 495 if lh.peek(Token![<-]) { 496 Ok(Self::Init { 497 ident, 498 _left_arrow_token: input.parse()?, 499 value: input.parse()?, 500 }) 501 } else if lh.peek(Token![:]) { 502 Ok(Self::Value { 503 ident, 504 value: Some((input.parse()?, input.parse()?)), 505 }) 506 } else if lh.peek(Token![,]) || lh.peek(End) { 507 Ok(Self::Value { ident, value: None }) 508 } else { 509 Err(lh.error()) 510 } 511 } else { 512 Err(lh.error()) 513 } 514 } 515 } 516