// SPDX-License-Identifier: Apache-2.0 OR MIT

//! Parsing and expansion of the struct-initializer syntax
//! `[#[default_error(Type)]] [&this in] Path { fields, [..Zeroable::init_zeroed()] } [? Error]`.
//!
//! [`Initializer`] is the parsed form (see its [`Parse`] impl) and [`expand`] turns it into the
//! token stream of the resulting initializer expression.

use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{
    braced,
    parse::{End, Parse},
    parse_quote,
    punctuated::Punctuated,
    spanned::Spanned,
    token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
};

use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};

/// Parsed input of an initializer invocation.
///
/// The fields appear in the order in which they are parsed from the input.
pub(crate) struct Initializer {
    /// Outer attributes in front of the initializer; only `#[default_error(..)]` is accepted.
    attrs: Vec<InitializerAttribute>,
    /// Optional `&ident in` prefix naming a handle to the slot that is being initialized.
    this: Option<This>,
    /// Path of the struct that is being initialized.
    path: Path,
    /// The braces surrounding the field list; its closing span is used for diagnostics.
    brace_token: token::Brace,
    /// The comma-separated field initializers inside the braces.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional trailing `..expr` inside the braces (validated by `get_init_kind`).
    rest: Option<(Token![..], Expr)>,
    /// Optional `? Type` after the closing brace, naming the error type of the initializer.
    error: Option<(Token![?], Type)>,
}

/// The `&ident in` prefix of an initializer.
struct This {
    _and_token: Token![&],
    /// Name of the binding that `expand` creates for the user.
    ident: Ident,
    _in_token: Token![in],
}

/// A single entry in the braces of an initializer.
enum InitializerField {
    /// `ident` or `ident: expr`: the field is written by value.
    Value {
        ident: Ident,
        /// `None` for the shorthand form `ident`.
        value: Option<(Token![:], Expr)>,
    },
    /// `ident <- expr`: the field is initialized in place by the initializer `expr`.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { ... }`: a block of user code that is emitted between field initializations.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}

impl InitializerField {
    /// Returns the name of the field this entry initializes, or `None` for a `_: { ... }` code
    /// block, which does not initialize any field.
    fn ident(&self) -> Option<&Ident> {
        match self {
            Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
            Self::Code { .. } => None,
        }
    }
}

/// Attributes that are accepted in front of an initializer.
enum InitializerAttribute {
    /// `#[default_error(Type)]`
    DefaultError(DefaultErrorAttribute),
}

/// Argument of the `#[default_error(Type)]` attribute.
struct DefaultErrorAttribute {
    /// Error type to use when the initializer has no explicit `? Type` suffix.
    ty: Box<Type>,
}

/// Expands a parsed [`Initializer`] into the token stream of the initializer expression.
///
/// # Parameters
///
/// - `default_error`: textual fallback error type; it is parsed with `syn::parse_str` and only
///   consulted when the input has neither a `? Type` suffix nor a `#[default_error(..)]`
///   attribute.
/// - `pinned`: selects the pinned (`HasPinData`/`PinData`/`__pin_data`/`pin_init_from_closure`)
///   or the non-pinned (`HasInitData`/`InitData`/`__init_data`/`init_from_closure`) set of
///   helper items in the generated code.
/// - `dcx`: receives the diagnostics emitted during expansion.
///
/// # Errors
///
/// As written, every path through this function returns `Ok`; problems are reported through
/// `dcx` and expansion continues with a fallback.
///
/// # Panics
///
/// Panics if `default_error` is consulted and does not parse as a type.
pub(crate) fn expand(
    Initializer {
        attrs,
        this,
        path,
        brace_token,
        fields,
        rest,
        error,
    }: Initializer,
    default_error: Option<&'static str>,
    pinned: bool,
    dcx: &mut DiagCtxt,
) -> Result<TokenStream, ErrorGuaranteed> {
    // Resolve the error type. Precedence: explicit `? Type`, then the `#[default_error(..)]`
    // attribute, then the `default_error` argument. If none is available, report an error and
    // continue with `Infallible` so that expansion can still produce output.
    let error = error.map_or_else(
        || {
            // The fold keeps the type of the last `#[default_error(..)]` attribute.
            if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
                // The pattern is irrefutable only because the enum currently has one variant.
                #[expect(irrefutable_let_patterns)]
                if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
                    Some(ty.clone())
                } else {
                    acc
                }
            }) {
                default_error
            } else if let Some(default_error) = default_error {
                syn::parse_str(default_error).unwrap()
            } else {
                dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
                parse_quote!(::core::convert::Infallible)
            }
        },
        |(_, err)| Box::new(err),
    );
    let slot = format_ident!("slot");
    // Names of the helper trait/function set, depending on whether this is a pinned initializer.
    let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
        (
            format_ident!("HasPinData"),
            format_ident!("PinData"),
            format_ident!("__pin_data"),
            format_ident!("pin_init_from_closure"),
        )
    } else {
        (
            format_ident!("HasInitData"),
            format_ident!("InitData"),
            format_ident!("__init_data"),
            format_ident!("init_from_closure"),
        )
    };
    let init_kind = get_init_kind(rest, dcx);
    let zeroable_check = match init_kind {
        InitKind::Normal => quote!(),
        InitKind::Zeroing => quote! {
            // The user specified `..Zeroable::init_zeroed()` at the end of the list of fields.
            // Therefore we check if the struct implements `Zeroable` and then zero the memory.
            // This allows us to also remove the check that all fields are present (since we
            // already set the memory to zero and that is a valid bit pattern).
            fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
            where T: ::pin_init::Zeroable
            {}
            // Ensure that the struct is indeed `Zeroable`.
            assert_zeroable(#slot);
            // SAFETY: The type implements `Zeroable` by the check above.
            unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
        },
    };
    let this = match this {
        None => quote!(),
        Some(This { ident, .. }) => quote! {
            // Create the `this` so it can be referenced by the user inside of the
            // expressions creating the individual fields.
            // NOTE(review): this `unsafe` has no SAFETY comment; it relies on `slot` being
            // non-null, presumably guaranteed by the initializer closure contract -- confirm.
            let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
        },
    };
    // `mixed_site` ensures that the data is not accessible to the user-controlled code.
    let data = Ident::new("__data", Span::mixed_site());
    let init_fields = init_fields(&fields, pinned, &data, &slot);
    let field_check = make_field_check(&fields, init_kind, &path);
    Ok(quote! {{
        // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return
        // type and shadow it later when we insert the arbitrary user code. That way there will be
        // no possibility of returning without `unsafe`.
        struct __InitOk;

        // Get the data about fields from the supplied type.
        // SAFETY: TODO
        let #data = unsafe {
            use ::pin_init::__internal::#has_data_trait;
            // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
            // generics (which need to be present with that syntax).
            #path::#get_data()
        };
        // Pass `#data` through `#data_trait::make_closure` to ensure that it really implements
        // `#data_trait` and to help with type inference:
        let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>(
            #data,
            move |slot| {
                {
                    // Shadow the structure so it cannot be used to return early.
                    struct __InitOk;
                    #zeroable_check
                    #this
                    #init_fields
                    #field_check
                }
                Ok(__InitOk)
            }
        );
        // Map the internal `__InitOk` marker to `()` for the final closure.
        let init = move |slot| -> ::core::result::Result<(), #error> {
            init(slot).map(|__InitOk| ())
        };
        // SAFETY: TODO
        let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
        init
    }})
}

/// How fields that are not mentioned in the initializer are treated.
enum InitKind {
    /// Every field has to be mentioned explicitly.
    Normal,
    /// The initializer ends in `..Zeroable::init_zeroed()`: the memory is zeroed first, so
    /// fields may be omitted.
    Zeroing,
}

/// Classifies the optional `..expr` tail of an initializer.
///
/// Returns [`InitKind::Normal`] when there is no tail and [`InitKind::Zeroing`] when the tail is
/// exactly the call `Zeroable::init_zeroed()`: no arguments, no attributes, no qualified self,
/// no leading `::` and no generic arguments on either path segment. Any other expression is
/// reported through `dcx` and treated as [`InitKind::Normal`].
fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
    let Some((dotdot, expr)) = rest else {
        return InitKind::Normal;
    };
    match &expr {
        Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
            Expr::Path(ExprPath {
                attrs,
                qself: None,
                path:
                    Path {
                        leading_colon: None,
                        segments,
                    },
            }) if attrs.is_empty()
                && segments.len() == 2
                && segments[0].ident == "Zeroable"
                && segments[0].arguments.is_none()
                && segments[1].ident == "init_zeroed"
                && segments[1].arguments.is_none() =>
            {
                return InitKind::Zeroing;
            }
            _ => {}
        },
        _ => {}
    }
    // Point the error at `..expr` as a whole; fall back to the expression's span if the two
    // spans cannot be joined.
    dcx.error(
        dotdot.span().join(expr.span()).unwrap_or(expr.span()),
        "expected nothing or `..Zeroable::init_zeroed()`.",
    );
    InitKind::Normal
}

/// Generate the code that initializes the fields of the struct using the initializers in `fields`.
///
/// For every entry, in source order, this emits the initialization code. For entries that name a
/// field it additionally emits a drop guard binding `__<field>_guard`. After all entries, every
/// guard is passed to `::core::mem::forget`.
///
/// `data` is the identifier of the local holding the pin/init data of the struct and `slot` the
/// identifier of the raw pointer to the struct being initialized.
fn init_fields(
    fields: &Punctuated<InitializerField, Token![,]>,
    pinned: bool,
    data: &Ident,
    slot: &Ident,
) -> TokenStream {
    let mut guards = vec![];
    let mut res = TokenStream::new();
    for field in fields {
        let init = match field {
            InitializerField::Value { ident, value } => {
                let mut value_ident = ident.clone();
                // For `ident: expr`, evaluate `expr` into a local named like the field. For the
                // shorthand `ident`, nothing is emitted here and the user's variable is used.
                let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
                    // Setting the span of `value_ident` to `value`'s span improves error messages
                    // when the type of `value` is wrong.
                    value_ident.set_span(value.span());
                    quote!(let #value_ident = #value;)
                });
                // Again span for better diagnostics
                let write = quote_spanned!(ident.span()=> ::core::ptr::write);
                // Expression that gives the user access to the already initialized field.
                let accessor = if pinned {
                    let project_ident = format_ident!("__project_{ident}");
                    quote! {
                        // SAFETY: TODO
                        unsafe { #data.#project_ident(&mut (*#slot).#ident) }
                    }
                } else {
                    quote! {
                        // SAFETY: TODO
                        unsafe { &mut (*#slot).#ident }
                    }
                };
                quote! {
                    {
                        #value_prep
                        // SAFETY: TODO
                        unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
                    }
                    #[allow(unused_variables)]
                    let #ident = #accessor;
                }
            }
            InitializerField::Init { ident, value, .. } => {
                // Again span for better diagnostics
                let init = format_ident!("init", span = value.span());
                if pinned {
                    let project_ident = format_ident!("__project_{ident}");
                    quote! {
                        {
                            let #init = #value;
                            // SAFETY:
                            // - `slot` is valid, because we are inside of an initializer closure, we
                            //   return when an error/panic occurs.
                            // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
                            //   for `#ident`.
                            unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
                        }
                        // SAFETY: TODO
                        #[allow(unused_variables)]
                        let #ident = unsafe { #data.#project_ident(&mut (*#slot).#ident) };
                    }
                } else {
                    quote! {
                        {
                            let #init = #value;
                            // SAFETY: `slot` is valid, because we are inside of an initializer
                            // closure, we return when an error/panic occurs.
                            unsafe {
                                ::pin_init::Init::__init(
                                    #init,
                                    ::core::ptr::addr_of_mut!((*#slot).#ident),
                                )?
                            };
                        }
                        // SAFETY: TODO
                        #[allow(unused_variables)]
                        let #ident = unsafe { &mut (*#slot).#ident };
                    }
                }
            }
            // User code blocks are emitted verbatim.
            InitializerField::Code { block: value, .. } => quote!(#[allow(unused_braces)] #value),
        };
        res.extend(init);
        // Only entries that initialize a field get a guard; code blocks have no field name.
        if let Some(ident) = field.ident() {
            // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
            let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
            guards.push(guard.clone());
            res.extend(quote! {
                // Create the drop guard:
                //
                // We rely on macro hygiene to make it impossible for users to access this local
                // variable.
                // SAFETY: We forget the guard later when initialization has succeeded.
                let #guard = unsafe {
                    ::pin_init::__internal::DropGuard::new(
                        ::core::ptr::addr_of_mut!((*slot).#ident)
                    )
                };
            });
        }
    }
    quote! {
        #res
        // If execution reaches this point, all fields have been initialized. Therefore we can now
        // dismiss the guards by forgetting them.
        #(::core::mem::forget(#guards);)*
    }
}

/// Generate the check for ensuring that every field has been initialized.
///
/// The check is a never-called closure containing a struct literal of `path` that mentions every
/// named field in `fields`. For [`InitKind::Zeroing`] the literal additionally ends in a
/// `..zeroed` struct update, so fields may be omitted.
fn make_field_check(
    fields: &Punctuated<InitializerField, Token![,]>,
    init_kind: InitKind,
    path: &Path,
) -> TokenStream {
    // Code blocks (`_: { ... }`) do not name a field and are skipped.
    let fields = fields.iter().filter_map(|f| f.ident());
    match init_kind {
        InitKind::Normal => quote! {
            // We use unreachable code to ensure that all fields have been mentioned exactly once,
            // this struct initializer will still be type-checked and complain with a very natural
            // error message if a field is forgotten/mentioned more than once.
            #[allow(unreachable_code, clippy::diverging_sub_expression)]
            // SAFETY: this code is never executed.
            let _ = || unsafe {
                ::core::ptr::write(slot, #path {
                    #(
                        #fields: ::core::panic!(),
                    )*
                })
            };
        },
        InitKind::Zeroing => quote! {
            // We use unreachable code to ensure that all fields have been mentioned at most once.
            // Since the user specified `..Zeroable::init_zeroed()` at the end, all missing fields
            // will be zeroed. This struct initializer will still be type-checked and complain with
            // a very natural error message if a field is mentioned more than once, or doesn't
            // exist.
            #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
            // SAFETY: this code is never executed.
            let _ = || unsafe {
                let mut zeroed = ::core::mem::zeroed();
                // We have to use type inference here to make zeroed have the correct type. This
                // does not get executed, so it has no effect.
                ::core::ptr::write(slot, zeroed);
                zeroed = ::core::mem::zeroed();
                ::core::ptr::write(slot, #path {
                    #(
                        #fields: ::core::panic!(),
                    )*
                    ..zeroed
                })
            };
        },
    }
}

/// Parses `[#[attr]]* [&this in] Path { field, ... [..expr] } [? Type]`.
impl Parse for Initializer {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let attrs = input.call(Attribute::parse_outer)?;
        // A leading `&` introduces the optional `&this in` prefix.
        let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
        let path = input.parse()?;
        let content;
        let brace_token = braced!(content in input);
        let mut fields = Punctuated::new();
        // Parse comma-separated fields until the end of the braces or a `..` tail is reached.
        loop {
            let lh = content.lookahead1();
            if lh.peek(End) || lh.peek(Token![..]) {
                break;
            } else if lh.peek(Ident) || lh.peek(Token![_]) {
                fields.push_value(content.parse()?);
                // After a field, only the end of the braces or a separating `,` may follow.
                let lh = content.lookahead1();
                if lh.peek(End) {
                    break;
                } else if lh.peek(Token![,]) {
                    fields.push_punct(content.parse()?);
                } else {
                    return Err(lh.error());
                }
            } else {
                return Err(lh.error());
            }
        }
        // Optional `..expr` tail inside the braces.
        let rest = content
            .peek(Token![..])
            .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
            .transpose()?;
        // Optional `? Type` after the closing brace.
        let error = input
            .peek(Token![?])
            .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
            .transpose()?;
        // Convert the raw attributes; anything other than `default_error` is rejected.
        let attrs = attrs
            .into_iter()
            .map(|a| {
                if a.path().is_ident("default_error") {
                    a.parse_args::<DefaultErrorAttribute>()
                        .map(InitializerAttribute::DefaultError)
                } else {
                    Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
                }
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Self {
            attrs,
            this,
            path,
            brace_token,
            fields,
            rest,
            error,
        })
    }
}

/// Parses the argument of `#[default_error(..)]`: a single type.
impl Parse for DefaultErrorAttribute {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        Ok(Self { ty: input.parse()? })
    }
}

/// Parses `&ident in`.
impl Parse for This {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        Ok(Self {
            _and_token: input.parse()?,
            ident: input.parse()?,
            _in_token: input.parse()?,
        })
    }
}

/// Parses one of `_: { ... }`, `ident <- expr`, `ident: expr` or the shorthand `ident`.
impl Parse for InitializerField {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let lh = input.lookahead1();
        if lh.peek(Token![_]) {
            Ok(Self::Code {
                _underscore_token: input.parse()?,
                _colon_token: input.parse()?,
                block: input.parse()?,
            })
        } else if lh.peek(Ident) {
            let ident = input.parse()?;
            // The token after the field name decides which kind of field this is.
            let lh = input.lookahead1();
            if lh.peek(Token![<-]) {
                Ok(Self::Init {
                    ident,
                    _left_arrow_token: input.parse()?,
                    value: input.parse()?,
                })
            } else if lh.peek(Token![:]) {
                Ok(Self::Value {
                    ident,
                    value: Some((input.parse()?, input.parse()?)),
                })
            } else if lh.peek(Token![,]) || lh.peek(End) {
                // Shorthand: the `,` (if any) is left for the caller to consume.
                Ok(Self::Value { ident, value: None })
            } else {
                Err(lh.error())
            }
        } else {
            Err(lh.error())
        }
    }
}