1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 use proc_macro2::{Span, TokenStream}; 4 use quote::{format_ident, quote}; 5 use syn::{ 6 braced, 7 parse::{End, Parse}, 8 parse_quote, 9 punctuated::Punctuated, 10 spanned::Spanned, 11 token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, 12 }; 13 14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; 15 16 pub(crate) struct Initializer { 17 attrs: Vec<InitializerAttribute>, 18 this: Option<This>, 19 path: Path, 20 brace_token: token::Brace, 21 fields: Punctuated<InitializerField, Token![,]>, 22 rest: Option<(Token![..], Expr)>, 23 error: Option<(Token![?], Type)>, 24 } 25 26 struct This { 27 _and_token: Token![&], 28 ident: Ident, 29 _in_token: Token![in], 30 } 31 32 struct InitializerField { 33 attrs: Vec<Attribute>, 34 kind: InitializerKind, 35 } 36 37 enum InitializerKind { 38 Value { 39 ident: Ident, 40 value: Option<(Token![:], Expr)>, 41 }, 42 Init { 43 ident: Ident, 44 _left_arrow_token: Token![<-], 45 value: Expr, 46 }, 47 Code { 48 _underscore_token: Token![_], 49 _colon_token: Token![:], 50 block: Block, 51 }, 52 } 53 54 impl InitializerKind { 55 fn ident(&self) -> Option<&Ident> { 56 match self { 57 Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), 58 Self::Code { .. } => None, 59 } 60 } 61 } 62 63 enum InitializerAttribute { 64 DefaultError(DefaultErrorAttribute), 65 } 66 67 struct DefaultErrorAttribute { 68 ty: Box<Type>, 69 } 70 71 pub(crate) fn expand( 72 Initializer { 73 attrs, 74 this, 75 path, 76 brace_token, 77 fields, 78 rest, 79 error, 80 }: Initializer, 81 default_error: Option<&'static str>, 82 pinned: bool, 83 dcx: &mut DiagCtxt, 84 ) -> Result<TokenStream, ErrorGuaranteed> { 85 let error = error.map_or_else( 86 || { 87 if let Some(default_error) = attrs.iter().fold(None, |acc, attr| { 88 #[expect(irrefutable_let_patterns)] 89 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr { 90 Some(ty.clone()) 91 } else { 92 acc 93 } 94 }) { 95 default_error 96 } else if let Some(default_error) = default_error { 97 syn::parse_str(default_error).unwrap() 98 } else { 99 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); 100 parse_quote!(::core::convert::Infallible) 101 } 102 }, 103 |(_, err)| Box::new(err), 104 ); 105 let slot = format_ident!("slot"); 106 let (has_data_trait, get_data, init_from_closure) = if pinned { 107 ( 108 format_ident!("HasPinData"), 109 format_ident!("__pin_data"), 110 format_ident!("pin_init_from_closure"), 111 ) 112 } else { 113 ( 114 format_ident!("HasInitData"), 115 format_ident!("__init_data"), 116 format_ident!("init_from_closure"), 117 ) 118 }; 119 let init_kind = get_init_kind(rest, dcx); 120 let zeroable_check = match init_kind { 121 InitKind::Normal => quote!(), 122 InitKind::Zeroing => quote! { 123 // The user specified `..Zeroable::zeroed()` at the end of the list of fields. 124 // Therefore we check if the struct implements `Zeroable` and then zero the memory. 125 // This allows us to also remove the check that all fields are present (since we 126 // already set the memory to zero and that is a valid bit pattern). 127 fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) 128 where T: ::pin_init::Zeroable 129 {} 130 // Ensure that the struct is indeed `Zeroable`. 131 assert_zeroable(#slot); 132 // SAFETY: The type implements `Zeroable` by the check above. 133 unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; 134 }, 135 }; 136 let this = match this { 137 None => quote!(), 138 Some(This { ident, .. }) => quote! { 139 // Create the `this` so it can be referenced by the user inside of the 140 // expressions creating the individual fields. 141 let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; 142 }, 143 }; 144 // `mixed_site` ensures that the data is not accessible to the user-controlled code. 145 let data = Ident::new("__data", Span::mixed_site()); 146 let init_fields = init_fields(&fields, pinned, &data, &slot); 147 let field_check = make_field_check(&fields, init_kind, &path); 148 Ok(quote! {{ 149 // Get the data about fields from the supplied type. 150 // SAFETY: TODO 151 let #data = unsafe { 152 use ::pin_init::__internal::#has_data_trait; 153 // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit 154 // generics (which need to be present with that syntax). 155 #path::#get_data() 156 }; 157 // Ensure that `#data` really is of type `#data` and help with type inference: 158 let init = #data.__make_closure::<_, #error>( 159 move |slot| { 160 #zeroable_check 161 #this 162 #init_fields 163 #field_check 164 // SAFETY: we are the `init!` macro that is allowed to call this. 165 Ok(unsafe { ::pin_init::__internal::InitOk::new() }) 166 } 167 ); 168 let init = move |slot| -> ::core::result::Result<(), #error> { 169 init(slot).map(|__InitOk| ()) 170 }; 171 // SAFETY: TODO 172 unsafe { ::pin_init::#init_from_closure::<_, #error>(init) } 173 }}) 174 } 175 176 enum InitKind { 177 Normal, 178 Zeroing, 179 } 180 181 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { 182 let Some((dotdot, expr)) = rest else { 183 return InitKind::Normal; 184 }; 185 match &expr { 186 Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { 187 Expr::Path(ExprPath { 188 attrs, 189 qself: None, 190 path: 191 Path { 192 leading_colon: None, 193 segments, 194 }, 195 }) if attrs.is_empty() 196 && segments.len() == 2 197 && segments[0].ident == "Zeroable" 198 && segments[0].arguments.is_none() 199 && segments[1].ident == "init_zeroed" 200 && segments[1].arguments.is_none() => 201 { 202 return InitKind::Zeroing; 203 } 204 _ => {} 205 }, 206 _ => {} 207 } 208 dcx.error( 209 dotdot.span().join(expr.span()).unwrap_or(expr.span()), 210 "expected nothing or `..Zeroable::init_zeroed()`.", 211 ); 212 InitKind::Normal 213 } 214 215 /// Generate the code that initializes the fields of the struct using the initializers in `field`. 216 fn init_fields( 217 fields: &Punctuated<InitializerField, Token![,]>, 218 pinned: bool, 219 data: &Ident, 220 slot: &Ident, 221 ) -> TokenStream { 222 let mut guards = vec![]; 223 let mut guard_attrs = vec![]; 224 let mut res = TokenStream::new(); 225 for InitializerField { attrs, kind } in fields { 226 let cfgs = { 227 let mut cfgs = attrs.clone(); 228 cfgs.retain(|attr| attr.path().is_ident("cfg")); 229 cfgs 230 }; 231 232 let ident = match kind { 233 InitializerKind::Value { ident, .. } => ident, 234 InitializerKind::Init { ident, .. } => ident, 235 InitializerKind::Code { block, .. } => { 236 res.extend(quote! { 237 #(#attrs)* 238 #[allow(unused_braces)] 239 #block 240 }); 241 continue; 242 } 243 }; 244 245 let slot = if pinned { 246 quote! { 247 // SAFETY: 248 // - `slot` is valid and properly aligned. 249 // - `make_field_check` checks that `&raw mut (*slot).#ident` is properly aligned. 250 // - `make_field_check` prevents `#ident` from being used twice, therefore 251 // `(*slot).#ident` is exclusively accessed and has not been initialized. 252 (unsafe { #data.#ident(#slot) }) 253 } 254 } else { 255 quote! { 256 // For `init!()` macro, everything is unpinned. 257 // SAFETY: 258 // - `&raw mut (*slot).#ident` is valid. 259 // - `make_field_check` checks that `&raw mut (*slot).#ident` is properly aligned. 260 // - `make_field_check` prevents `#ident` from being used twice, therefore 261 // `(*slot).#ident` is exclusively accessed and has not been initialized. 262 (unsafe { 263 ::pin_init::__internal::Slot::<::pin_init::__internal::Unpinned, _>::new( 264 &raw mut (*#slot).#ident 265 ) 266 }) 267 } 268 }; 269 270 // `mixed_site` ensures that the guard is not accessible to the user-controlled code. 271 let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); 272 273 let init = match kind { 274 InitializerKind::Value { ident, value } => { 275 let value = value 276 .as_ref() 277 .map(|(_, value)| quote!(#value)) 278 .unwrap_or_else(|| quote!(#ident)); 279 280 quote! { 281 #(#attrs)* 282 let mut #guard = #slot.write(#value); 283 284 } 285 } 286 InitializerKind::Init { value, .. } => { 287 quote! { 288 #(#attrs)* 289 let mut #guard = #slot.init(#value)?; 290 } 291 } 292 InitializerKind::Code { .. } => unreachable!(), 293 }; 294 295 res.extend(quote! { 296 #init 297 298 #(#cfgs)* 299 #[allow(unused_variables)] 300 let #ident = #guard.let_binding(); 301 }); 302 303 guards.push(guard); 304 guard_attrs.push(cfgs); 305 } 306 quote! { 307 #res 308 // If execution reaches this point, all fields have been initialized. Therefore we can now 309 // dismiss the guards by forgetting them. 310 #( 311 #(#guard_attrs)* 312 ::core::mem::forget(#guards); 313 )* 314 } 315 } 316 317 /// Generate the check for ensuring that every field has been initialized and aligned. 318 fn make_field_check( 319 fields: &Punctuated<InitializerField, Token![,]>, 320 init_kind: InitKind, 321 path: &Path, 322 ) -> TokenStream { 323 let field_attrs: Vec<_> = fields 324 .iter() 325 .filter_map(|f| f.kind.ident().map(|_| &f.attrs)) 326 .collect(); 327 let field_name: Vec<_> = fields.iter().filter_map(|f| f.kind.ident()).collect(); 328 let zeroing_trailer = match init_kind { 329 InitKind::Normal => None, 330 InitKind::Zeroing => Some(quote! { 331 ..::core::mem::zeroed() 332 }), 333 }; 334 quote! { 335 #[allow(unreachable_code, clippy::diverging_sub_expression)] 336 // We use unreachable code to perform field checks. They're still checked by the compiler. 337 // SAFETY: this code is never executed. 338 let _ = || unsafe { 339 // Create references to ensure that the initialized field is properly aligned. 340 // Unaligned fields will cause the compiler to emit E0793. We do not support 341 // unaligned fields since `Init::__init` requires an aligned pointer; the call to 342 // `ptr::write` for value-initialization case has the same requirement. 343 #( 344 #(#field_attrs)* 345 let _ = &(*slot).#field_name; 346 )* 347 348 // If the zeroing trailer is not present, this checks that all fields have been 349 // mentioned exactly once. If the zeroing trailer is present, all missing fields will be 350 // zeroed, so this checks that all fields have been mentioned at most once. The use of 351 // struct initializer will still generate very natural error messages for any misuse. 352 ::core::ptr::write(slot, #path { 353 #( 354 #(#field_attrs)* 355 #field_name: ::core::panic!(), 356 )* 357 #zeroing_trailer 358 }) 359 }; 360 } 361 } 362 363 impl Parse for Initializer { 364 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 365 let attrs = input.call(Attribute::parse_outer)?; 366 let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; 367 let path = input.parse()?; 368 let content; 369 let brace_token = braced!(content in input); 370 let mut fields = Punctuated::new(); 371 loop { 372 let lh = content.lookahead1(); 373 if lh.peek(End) || lh.peek(Token![..]) { 374 break; 375 } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) { 376 fields.push_value(content.parse()?); 377 let lh = content.lookahead1(); 378 if lh.peek(End) { 379 break; 380 } else if lh.peek(Token![,]) { 381 fields.push_punct(content.parse()?); 382 } else { 383 return Err(lh.error()); 384 } 385 } else { 386 return Err(lh.error()); 387 } 388 } 389 let rest = content 390 .peek(Token![..]) 391 .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) 392 .transpose()?; 393 let error = input 394 .peek(Token![?]) 395 .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) 396 .transpose()?; 397 let attrs = attrs 398 .into_iter() 399 .map(|a| { 400 if a.path().is_ident("default_error") { 401 a.parse_args::<DefaultErrorAttribute>() 402 .map(InitializerAttribute::DefaultError) 403 } else { 404 Err(syn::Error::new_spanned(a, "unknown initializer attribute")) 405 } 406 }) 407 .collect::<Result<Vec<_>, _>>()?; 408 Ok(Self { 409 attrs, 410 this, 411 path, 412 brace_token, 413 fields, 414 rest, 415 error, 416 }) 417 } 418 } 419 420 impl Parse for DefaultErrorAttribute { 421 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 422 Ok(Self { ty: input.parse()? }) 423 } 424 } 425 426 impl Parse for This { 427 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 428 Ok(Self { 429 _and_token: input.parse()?, 430 ident: input.parse()?, 431 _in_token: input.parse()?, 432 }) 433 } 434 } 435 436 impl Parse for InitializerField { 437 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 438 let attrs = input.call(Attribute::parse_outer)?; 439 Ok(Self { 440 attrs, 441 kind: input.parse()?, 442 }) 443 } 444 } 445 446 impl Parse for InitializerKind { 447 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { 448 let lh = input.lookahead1(); 449 if lh.peek(Token![_]) { 450 Ok(Self::Code { 451 _underscore_token: input.parse()?, 452 _colon_token: input.parse()?, 453 block: input.parse()?, 454 }) 455 } else if lh.peek(Ident) { 456 let ident = input.parse()?; 457 let lh = input.lookahead1(); 458 if lh.peek(Token![<-]) { 459 Ok(Self::Init { 460 ident, 461 _left_arrow_token: input.parse()?, 462 value: input.parse()?, 463 }) 464 } else if lh.peek(Token![:]) { 465 Ok(Self::Value { 466 ident, 467 value: Some((input.parse()?, input.parse()?)), 468 }) 469 } else if lh.peek(Token![,]) || lh.peek(End) { 470 Ok(Self::Value { ident, value: None }) 471 } else { 472 Err(lh.error()) 473 } 474 } else { 475 Err(lh.error()) 476 } 477 } 478 } 479