xref: /linux/rust/pin-init/internal/src/init.rs (revision fdbaa9d2b78e0da9e1aeb303bbdc3adfe6d8e749)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::{format_ident, quote, quote_spanned};
5 use syn::{
6     braced,
7     parse::{End, Parse},
8     parse_quote,
9     punctuated::Punctuated,
10     spanned::Spanned,
11     token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12 };
13 
14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15 
/// A parsed initializer of the form
/// `#[attrs] [&this in] Path { fields, [..rest] } [? ErrorType]`.
pub(crate) struct Initializer {
    /// Outer attributes of the initializer; only `#[default_error(<type>)]` is accepted.
    attrs: Vec<InitializerAttribute>,
    /// Optional `&<ident> in` prefix that names a pointer to the slot being initialized.
    this: Option<This>,
    /// Path of the struct that is being initialized.
    path: Path,
    /// The braces around the field list; the closing brace is used as the span of the
    /// "missing error type" diagnostic.
    brace_token: token::Brace,
    /// The comma-separated field entries.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional `..<expr>` after the fields; only `..Zeroable::init_zeroed()` is accepted.
    rest: Option<(Token![..], Expr)>,
    /// Optional `? <type>` after the closing brace, specifying the error type.
    error: Option<(Token![?], Type)>,
}
25 
/// The `&<ident> in` prefix of an initializer. The generated code binds `ident` to a `NonNull`
/// pointer to the slot, so that the field expressions can refer to it.
struct This {
    _and_token: Token![&],
    /// Name under which the pointer to the slot is made available.
    ident: Ident,
    _in_token: Token![in],
}
31 
/// A single entry of the field list: its outer attributes followed by the entry itself.
struct InitializerField {
    /// Attributes written in front of the entry. All of them are applied to the generated
    /// initialization code and to the field in the field check; `#[cfg]` attributes are also
    /// applied to the field's local binding and drop guard.
    attrs: Vec<Attribute>,
    /// The entry itself.
    kind: InitializerKind,
}
36 
/// The kinds of entries allowed in an initializer's field list.
enum InitializerKind {
    /// `ident` (shorthand for `ident: ident`) or `ident: expr`; the value is moved into the
    /// field with `ptr::write`.
    Value {
        ident: Ident,
        value: Option<(Token![:], Expr)>,
    },
    /// `ident <- expr`; `expr` is an initializer that is run in place on the field.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { ... }`; the block is emitted as-is, in order with the field initializations.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}
53 
54 impl InitializerKind {
55     fn ident(&self) -> Option<&Ident> {
56         match self {
57             Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58             Self::Code { .. } => None,
59         }
60     }
61 }
62 
/// An attribute accepted on an initializer; any other attribute is rejected while parsing.
enum InitializerAttribute {
    /// `#[default_error(<type>)]`.
    DefaultError(DefaultErrorAttribute),
}
66 
/// Argument of `#[default_error(<type>)]`: the error type used when the initializer has no
/// `? <type>` suffix. It takes precedence over the default error supplied by the invoking macro.
struct DefaultErrorAttribute {
    /// The error type.
    ty: Box<Type>,
}
70 
71 pub(crate) fn expand(
72     Initializer {
73         attrs,
74         this,
75         path,
76         brace_token,
77         fields,
78         rest,
79         error,
80     }: Initializer,
81     default_error: Option<&'static str>,
82     pinned: bool,
83     dcx: &mut DiagCtxt,
84 ) -> Result<TokenStream, ErrorGuaranteed> {
85     let error = error.map_or_else(
86         || {
87             if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
88                 #[expect(irrefutable_let_patterns)]
89                 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90                     Some(ty.clone())
91                 } else {
92                     acc
93                 }
94             }) {
95                 default_error
96             } else if let Some(default_error) = default_error {
97                 syn::parse_str(default_error).unwrap()
98             } else {
99                 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100                 parse_quote!(::core::convert::Infallible)
101             }
102         },
103         |(_, err)| Box::new(err),
104     );
105     let slot = format_ident!("slot");
106     let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
107         (
108             format_ident!("HasPinData"),
109             format_ident!("PinData"),
110             format_ident!("__pin_data"),
111             format_ident!("pin_init_from_closure"),
112         )
113     } else {
114         (
115             format_ident!("HasInitData"),
116             format_ident!("InitData"),
117             format_ident!("__init_data"),
118             format_ident!("init_from_closure"),
119         )
120     };
121     let init_kind = get_init_kind(rest, dcx);
122     let zeroable_check = match init_kind {
123         InitKind::Normal => quote!(),
124         InitKind::Zeroing => quote! {
125             // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
126             // Therefore we check if the struct implements `Zeroable` and then zero the memory.
127             // This allows us to also remove the check that all fields are present (since we
128             // already set the memory to zero and that is a valid bit pattern).
129             fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
130             where T: ::pin_init::Zeroable
131             {}
132             // Ensure that the struct is indeed `Zeroable`.
133             assert_zeroable(#slot);
134             // SAFETY: The type implements `Zeroable` by the check above.
135             unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
136         },
137     };
138     let this = match this {
139         None => quote!(),
140         Some(This { ident, .. }) => quote! {
141             // Create the `this` so it can be referenced by the user inside of the
142             // expressions creating the individual fields.
143             let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
144         },
145     };
146     // `mixed_site` ensures that the data is not accessible to the user-controlled code.
147     let data = Ident::new("__data", Span::mixed_site());
148     let init_fields = init_fields(&fields, pinned, &data, &slot);
149     let field_check = make_field_check(&fields, init_kind, &path);
150     Ok(quote! {{
151         // Get the data about fields from the supplied type.
152         // SAFETY: TODO
153         let #data = unsafe {
154             use ::pin_init::__internal::#has_data_trait;
155             // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
156             // generics (which need to be present with that syntax).
157             #path::#get_data()
158         };
159         // Ensure that `#data` really is of type `#data` and help with type inference:
160         let init = ::pin_init::__internal::#data_trait::make_closure::<_, #error>(
161             #data,
162             move |slot| {
163                 #zeroable_check
164                 #this
165                 #init_fields
166                 #field_check
167                 // SAFETY: we are the `init!` macro that is allowed to call this.
168                 Ok(unsafe { ::pin_init::__internal::InitOk::new() })
169             }
170         );
171         let init = move |slot| -> ::core::result::Result<(), #error> {
172             init(slot).map(|__InitOk| ())
173         };
174         // SAFETY: TODO
175         let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
176         init
177     }})
178 }
179 
/// How the memory of the struct is prepared before the fields are initialized, as determined by
/// `get_init_kind` from the optional `..<expr>` at the end of the field list.
enum InitKind {
    /// No `..` was given: every field has to be mentioned exactly once.
    Normal,
    /// `..Zeroable::init_zeroed()` was given: the struct is checked to be `Zeroable` and its
    /// memory is zeroed first, so fields may be omitted.
    Zeroing,
}
184 
185 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
186     let Some((dotdot, expr)) = rest else {
187         return InitKind::Normal;
188     };
189     match &expr {
190         Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
191             Expr::Path(ExprPath {
192                 attrs,
193                 qself: None,
194                 path:
195                     Path {
196                         leading_colon: None,
197                         segments,
198                     },
199             }) if attrs.is_empty()
200                 && segments.len() == 2
201                 && segments[0].ident == "Zeroable"
202                 && segments[0].arguments.is_none()
203                 && segments[1].ident == "init_zeroed"
204                 && segments[1].arguments.is_none() =>
205             {
206                 return InitKind::Zeroing;
207             }
208             _ => {}
209         },
210         _ => {}
211     }
212     dcx.error(
213         dotdot.span().join(expr.span()).unwrap_or(expr.span()),
214         "expected nothing or `..Zeroable::init_zeroed()`.",
215     );
216     InitKind::Normal
217 }
218 
219 /// Generate the code that initializes the fields of the struct using the initializers in `field`.
220 fn init_fields(
221     fields: &Punctuated<InitializerField, Token![,]>,
222     pinned: bool,
223     data: &Ident,
224     slot: &Ident,
225 ) -> TokenStream {
226     let mut guards = vec![];
227     let mut guard_attrs = vec![];
228     let mut res = TokenStream::new();
229     for InitializerField { attrs, kind } in fields {
230         let cfgs = {
231             let mut cfgs = attrs.clone();
232             cfgs.retain(|attr| attr.path().is_ident("cfg"));
233             cfgs
234         };
235         let init = match kind {
236             InitializerKind::Value { ident, value } => {
237                 let mut value_ident = ident.clone();
238                 let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
239                     // Setting the span of `value_ident` to `value`'s span improves error messages
240                     // when the type of `value` is wrong.
241                     value_ident.set_span(value.span());
242                     quote!(let #value_ident = #value;)
243                 });
244                 // Again span for better diagnostics
245                 let write = quote_spanned!(ident.span()=> ::core::ptr::write);
246                 // NOTE: the field accessor ensures that the initialized field is properly aligned.
247                 // Unaligned fields will cause the compiler to emit E0793. We do not support
248                 // unaligned fields since `Init::__init` requires an aligned pointer; the call to
249                 // `ptr::write` below has the same requirement.
250                 let accessor = if pinned {
251                     let project_ident = format_ident!("__project_{ident}");
252                     quote! {
253                         // SAFETY: TODO
254                         unsafe { #data.#project_ident(&mut (*#slot).#ident) }
255                     }
256                 } else {
257                     quote! {
258                         // SAFETY: TODO
259                         unsafe { &mut (*#slot).#ident }
260                     }
261                 };
262                 quote! {
263                     #(#attrs)*
264                     {
265                         #value_prep
266                         // SAFETY: TODO
267                         unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
268                     }
269                     #(#cfgs)*
270                     #[allow(unused_variables)]
271                     let #ident = #accessor;
272                 }
273             }
274             InitializerKind::Init { ident, value, .. } => {
275                 // Again span for better diagnostics
276                 let init = format_ident!("init", span = value.span());
277                 // NOTE: the field accessor ensures that the initialized field is properly aligned.
278                 // Unaligned fields will cause the compiler to emit E0793. We do not support
279                 // unaligned fields since `Init::__init` requires an aligned pointer; the call to
280                 // `ptr::write` below has the same requirement.
281                 let (value_init, accessor) = if pinned {
282                     let project_ident = format_ident!("__project_{ident}");
283                     (
284                         quote! {
285                             // SAFETY:
286                             // - `slot` is valid, because we are inside of an initializer closure, we
287                             //   return when an error/panic occurs.
288                             // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
289                             //   for `#ident`.
290                             unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
291                         },
292                         quote! {
293                             // SAFETY: TODO
294                             unsafe { #data.#project_ident(&mut (*#slot).#ident) }
295                         },
296                     )
297                 } else {
298                     (
299                         quote! {
300                             // SAFETY: `slot` is valid, because we are inside of an initializer
301                             // closure, we return when an error/panic occurs.
302                             unsafe {
303                                 ::pin_init::Init::__init(
304                                     #init,
305                                     ::core::ptr::addr_of_mut!((*#slot).#ident),
306                                 )?
307                             };
308                         },
309                         quote! {
310                             // SAFETY: TODO
311                             unsafe { &mut (*#slot).#ident }
312                         },
313                     )
314                 };
315                 quote! {
316                     #(#attrs)*
317                     {
318                         let #init = #value;
319                         #value_init
320                     }
321                     #(#cfgs)*
322                     #[allow(unused_variables)]
323                     let #ident = #accessor;
324                 }
325             }
326             InitializerKind::Code { block: value, .. } => quote! {
327                 #(#attrs)*
328                 #[allow(unused_braces)]
329                 #value
330             },
331         };
332         res.extend(init);
333         if let Some(ident) = kind.ident() {
334             // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
335             let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
336             res.extend(quote! {
337                 #(#cfgs)*
338                 // Create the drop guard:
339                 //
340                 // We rely on macro hygiene to make it impossible for users to access this local
341                 // variable.
342                 // SAFETY: We forget the guard later when initialization has succeeded.
343                 let #guard = unsafe {
344                     ::pin_init::__internal::DropGuard::new(
345                         ::core::ptr::addr_of_mut!((*slot).#ident)
346                     )
347                 };
348             });
349             guards.push(guard);
350             guard_attrs.push(cfgs);
351         }
352     }
353     quote! {
354         #res
355         // If execution reaches this point, all fields have been initialized. Therefore we can now
356         // dismiss the guards by forgetting them.
357         #(
358             #(#guard_attrs)*
359             ::core::mem::forget(#guards);
360         )*
361     }
362 }
363 
364 /// Generate the check for ensuring that every field has been initialized.
365 fn make_field_check(
366     fields: &Punctuated<InitializerField, Token![,]>,
367     init_kind: InitKind,
368     path: &Path,
369 ) -> TokenStream {
370     let field_attrs = fields
371         .iter()
372         .filter_map(|f| f.kind.ident().map(|_| &f.attrs));
373     let field_name = fields.iter().filter_map(|f| f.kind.ident());
374     match init_kind {
375         InitKind::Normal => quote! {
376             // We use unreachable code to ensure that all fields have been mentioned exactly once,
377             // this struct initializer will still be type-checked and complain with a very natural
378             // error message if a field is forgotten/mentioned more than once.
379             #[allow(unreachable_code, clippy::diverging_sub_expression)]
380             // SAFETY: this code is never executed.
381             let _ = || unsafe {
382                 ::core::ptr::write(slot, #path {
383                     #(
384                         #(#field_attrs)*
385                         #field_name: ::core::panic!(),
386                     )*
387                 })
388             };
389         },
390         InitKind::Zeroing => quote! {
391             // We use unreachable code to ensure that all fields have been mentioned at most once.
392             // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
393             // be zeroed. This struct initializer will still be type-checked and complain with a
394             // very natural error message if a field is mentioned more than once, or doesn't exist.
395             #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
396             // SAFETY: this code is never executed.
397             let _ = || unsafe {
398                 ::core::ptr::write(slot, #path {
399                     #(
400                         #(#field_attrs)*
401                         #field_name: ::core::panic!(),
402                     )*
403                     ..::core::mem::zeroed()
404                 })
405             };
406         },
407     }
408 }
409 
410 impl Parse for Initializer {
411     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
412         let attrs = input.call(Attribute::parse_outer)?;
413         let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
414         let path = input.parse()?;
415         let content;
416         let brace_token = braced!(content in input);
417         let mut fields = Punctuated::new();
418         loop {
419             let lh = content.lookahead1();
420             if lh.peek(End) || lh.peek(Token![..]) {
421                 break;
422             } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
423                 fields.push_value(content.parse()?);
424                 let lh = content.lookahead1();
425                 if lh.peek(End) {
426                     break;
427                 } else if lh.peek(Token![,]) {
428                     fields.push_punct(content.parse()?);
429                 } else {
430                     return Err(lh.error());
431                 }
432             } else {
433                 return Err(lh.error());
434             }
435         }
436         let rest = content
437             .peek(Token![..])
438             .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
439             .transpose()?;
440         let error = input
441             .peek(Token![?])
442             .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
443             .transpose()?;
444         let attrs = attrs
445             .into_iter()
446             .map(|a| {
447                 if a.path().is_ident("default_error") {
448                     a.parse_args::<DefaultErrorAttribute>()
449                         .map(InitializerAttribute::DefaultError)
450                 } else {
451                     Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
452                 }
453             })
454             .collect::<Result<Vec<_>, _>>()?;
455         Ok(Self {
456             attrs,
457             this,
458             path,
459             brace_token,
460             fields,
461             rest,
462             error,
463         })
464     }
465 }
466 
467 impl Parse for DefaultErrorAttribute {
468     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
469         Ok(Self { ty: input.parse()? })
470     }
471 }
472 
473 impl Parse for This {
474     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
475         Ok(Self {
476             _and_token: input.parse()?,
477             ident: input.parse()?,
478             _in_token: input.parse()?,
479         })
480     }
481 }
482 
483 impl Parse for InitializerField {
484     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
485         let attrs = input.call(Attribute::parse_outer)?;
486         Ok(Self {
487             attrs,
488             kind: input.parse()?,
489         })
490     }
491 }
492 
493 impl Parse for InitializerKind {
494     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
495         let lh = input.lookahead1();
496         if lh.peek(Token![_]) {
497             Ok(Self::Code {
498                 _underscore_token: input.parse()?,
499                 _colon_token: input.parse()?,
500                 block: input.parse()?,
501             })
502         } else if lh.peek(Ident) {
503             let ident = input.parse()?;
504             let lh = input.lookahead1();
505             if lh.peek(Token![<-]) {
506                 Ok(Self::Init {
507                     ident,
508                     _left_arrow_token: input.parse()?,
509                     value: input.parse()?,
510                 })
511             } else if lh.peek(Token![:]) {
512                 Ok(Self::Value {
513                     ident,
514                     value: Some((input.parse()?, input.parse()?)),
515                 })
516             } else if lh.peek(Token![,]) || lh.peek(End) {
517                 Ok(Self::Value { ident, value: None })
518             } else {
519                 Err(lh.error())
520             }
521         } else {
522             Err(lh.error())
523         }
524     }
525 }
526