xref: /linux/rust/pin-init/internal/src/init.rs (revision e92b2872d0b198a77c0a438c5cdb1c5510762c1b)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::{format_ident, quote, quote_spanned};
5 use syn::{
6     braced,
7     parse::{End, Parse},
8     parse_quote,
9     punctuated::Punctuated,
10     spanned::Spanned,
11     token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12 };
13 
14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15 
/// Parsed input of an initializer macro invocation, e.g.
/// `#[default_error(E)] &this in Foo { a: 1, b <- init, _: { ... }, ..Zeroable::init_zeroed() }? E`.
pub(crate) struct Initializer {
    /// Outer attributes of the invocation; only `#[default_error(<type>)]` is accepted.
    attrs: Vec<InitializerAttribute>,
    /// Optional `&<ident> in` prefix naming a pointer to the memory being initialized.
    this: Option<This>,
    /// Path of the struct that is being initialized.
    path: Path,
    /// The braces around the field list; the closing brace is used as an error span.
    brace_token: token::Brace,
    /// The comma-separated field entries.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional `..<expr>` trailer; only `..Zeroable::init_zeroed()` is accepted on expansion.
    rest: Option<(Token![..], Expr)>,
    /// Optional `? <type>` suffix naming the error type of the initializer.
    error: Option<(Token![?], Type)>,
}
25 
/// The optional `&<ident> in` prefix of an initializer.
///
/// On expansion, `<ident>` is bound to a `NonNull` pointer to the memory being initialized so
/// that the field initializer expressions can refer to it.
struct This {
    _and_token: Token![&],
    ident: Ident,
    _in_token: Token![in],
}
31 
/// One entry of the brace-delimited field list together with its outer attributes.
struct InitializerField {
    /// Outer attributes (e.g. `#[cfg(...)]`) written in front of the entry.
    attrs: Vec<Attribute>,
    /// The entry itself.
    kind: InitializerKind,
}
36 
/// The syntactic form of a field-list entry.
enum InitializerKind {
    /// `field: expr` or the shorthand `field`: the value is written into the field.
    Value {
        ident: Ident,
        /// `None` for the shorthand form, in which a variable named like the field is used.
        value: Option<(Token![:], Expr)>,
    },
    /// `field <- initializer`: the field is initialized in place by the given initializer.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { ... }`: a block of user code that runs at this point of the initialization.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}
53 
54 impl InitializerKind {
55     fn ident(&self) -> Option<&Ident> {
56         match self {
57             Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58             Self::Code { .. } => None,
59         }
60     }
61 }
62 
/// An attribute accepted on the initializer invocation itself.
enum InitializerAttribute {
    /// `#[default_error(<type>)]`.
    DefaultError(DefaultErrorAttribute),
}
66 
/// Argument of `#[default_error(<type>)]`: the error type used when no `? <type>` is written.
struct DefaultErrorAttribute {
    ty: Box<Type>,
}
70 
71 pub(crate) fn expand(
72     Initializer {
73         attrs,
74         this,
75         path,
76         brace_token,
77         fields,
78         rest,
79         error,
80     }: Initializer,
81     default_error: Option<&'static str>,
82     pinned: bool,
83     dcx: &mut DiagCtxt,
84 ) -> Result<TokenStream, ErrorGuaranteed> {
85     let error = error.map_or_else(
86         || {
87             if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
88                 #[expect(irrefutable_let_patterns)]
89                 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90                     Some(ty.clone())
91                 } else {
92                     acc
93                 }
94             }) {
95                 default_error
96             } else if let Some(default_error) = default_error {
97                 syn::parse_str(default_error).unwrap()
98             } else {
99                 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100                 parse_quote!(::core::convert::Infallible)
101             }
102         },
103         |(_, err)| Box::new(err),
104     );
105     let slot = format_ident!("slot");
106     let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
107         (
108             format_ident!("HasPinData"),
109             format_ident!("PinData"),
110             format_ident!("__pin_data"),
111             format_ident!("pin_init_from_closure"),
112         )
113     } else {
114         (
115             format_ident!("HasInitData"),
116             format_ident!("InitData"),
117             format_ident!("__init_data"),
118             format_ident!("init_from_closure"),
119         )
120     };
121     let init_kind = get_init_kind(rest, dcx);
122     let zeroable_check = match init_kind {
123         InitKind::Normal => quote!(),
124         InitKind::Zeroing => quote! {
125             // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
126             // Therefore we check if the struct implements `Zeroable` and then zero the memory.
127             // This allows us to also remove the check that all fields are present (since we
128             // already set the memory to zero and that is a valid bit pattern).
129             fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
130             where T: ::pin_init::Zeroable
131             {}
132             // Ensure that the struct is indeed `Zeroable`.
133             assert_zeroable(#slot);
134             // SAFETY: The type implements `Zeroable` by the check above.
135             unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
136         },
137     };
138     let this = match this {
139         None => quote!(),
140         Some(This { ident, .. }) => quote! {
141             // Create the `this` so it can be referenced by the user inside of the
142             // expressions creating the individual fields.
143             let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
144         },
145     };
146     // `mixed_site` ensures that the data is not accessible to the user-controlled code.
147     let data = Ident::new("__data", Span::mixed_site());
148     let init_fields = init_fields(&fields, pinned, &data, &slot);
149     let field_check = make_field_check(&fields, init_kind, &path);
150     Ok(quote! {{
151         // Get the data about fields from the supplied type.
152         // SAFETY: TODO
153         let #data = unsafe {
154             use ::pin_init::__internal::#has_data_trait;
155             // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
156             // generics (which need to be present with that syntax).
157             #path::#get_data()
158         };
159         // Ensure that `#data` really is of type `#data` and help with type inference:
160         let init = ::pin_init::__internal::#data_trait::make_closure::<_, #error>(
161             #data,
162             move |slot| {
163                 #zeroable_check
164                 #this
165                 #init_fields
166                 #field_check
167                 // SAFETY: we are the `init!` macro that is allowed to call this.
168                 Ok(unsafe { ::pin_init::__internal::InitOk::new() })
169             }
170         );
171         let init = move |slot| -> ::core::result::Result<(), #error> {
172             init(slot).map(|__InitOk| ())
173         };
174         // SAFETY: TODO
175         let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
176         // FIXME: this let binding is required to avoid a compiler error (cycle when computing the
177         // opaque type returned by this function) before Rust 1.81. Remove after MSRV bump.
178         #[allow(
179             clippy::let_and_return,
180             reason = "some clippy versions warn about the let binding"
181         )]
182         init
183     }})
184 }
185 
/// How the struct's memory is prepared before its fields are initialized.
enum InitKind {
    /// No `..` trailer: every field has to be initialized explicitly.
    Normal,
    /// `..Zeroable::init_zeroed()` trailer: the memory is zeroed first, so fields may be omitted.
    Zeroing,
}
190 
191 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
192     let Some((dotdot, expr)) = rest else {
193         return InitKind::Normal;
194     };
195     match &expr {
196         Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
197             Expr::Path(ExprPath {
198                 attrs,
199                 qself: None,
200                 path:
201                     Path {
202                         leading_colon: None,
203                         segments,
204                     },
205             }) if attrs.is_empty()
206                 && segments.len() == 2
207                 && segments[0].ident == "Zeroable"
208                 && segments[0].arguments.is_none()
209                 && segments[1].ident == "init_zeroed"
210                 && segments[1].arguments.is_none() =>
211             {
212                 return InitKind::Zeroing;
213             }
214             _ => {}
215         },
216         _ => {}
217     }
218     dcx.error(
219         dotdot.span().join(expr.span()).unwrap_or(expr.span()),
220         "expected nothing or `..Zeroable::init_zeroed()`.",
221     );
222     InitKind::Normal
223 }
224 
225 /// Generate the code that initializes the fields of the struct using the initializers in `field`.
226 fn init_fields(
227     fields: &Punctuated<InitializerField, Token![,]>,
228     pinned: bool,
229     data: &Ident,
230     slot: &Ident,
231 ) -> TokenStream {
232     let mut guards = vec![];
233     let mut guard_attrs = vec![];
234     let mut res = TokenStream::new();
235     for InitializerField { attrs, kind } in fields {
236         let cfgs = {
237             let mut cfgs = attrs.clone();
238             cfgs.retain(|attr| attr.path().is_ident("cfg"));
239             cfgs
240         };
241         let init = match kind {
242             InitializerKind::Value { ident, value } => {
243                 let mut value_ident = ident.clone();
244                 let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
245                     // Setting the span of `value_ident` to `value`'s span improves error messages
246                     // when the type of `value` is wrong.
247                     value_ident.set_span(value.span());
248                     quote!(let #value_ident = #value;)
249                 });
250                 // Again span for better diagnostics
251                 let write = quote_spanned!(ident.span()=> ::core::ptr::write);
252                 quote! {
253                     #(#attrs)*
254                     {
255                         #value_prep
256                         // SAFETY: TODO
257                         unsafe { #write(&raw mut (*#slot).#ident, #value_ident) };
258                     }
259                 }
260             }
261             InitializerKind::Init { ident, value, .. } => {
262                 // Again span for better diagnostics
263                 let init = format_ident!("init", span = value.span());
264                 let value_init = if pinned {
265                     quote! {
266                         // SAFETY:
267                         // - `slot` is valid, because we are inside of an initializer closure, we
268                         //   return when an error/panic occurs.
269                         // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
270                         //   for `#ident`.
271                         unsafe { #data.#ident(&raw mut (*#slot).#ident, #init)? };
272                     }
273                 } else {
274                     quote! {
275                         // SAFETY: `slot` is valid, because we are inside of an initializer
276                         // closure, we return when an error/panic occurs.
277                         unsafe {
278                             ::pin_init::Init::__init(
279                                 #init,
280                                 &raw mut (*#slot).#ident,
281                             )?
282                         };
283                     }
284                 };
285                 quote! {
286                     #(#attrs)*
287                     {
288                         let #init = #value;
289                         #value_init
290                     }
291                 }
292             }
293             InitializerKind::Code { block: value, .. } => quote! {
294                 #(#attrs)*
295                 #[allow(unused_braces)]
296                 #value
297             },
298         };
299         res.extend(init);
300         if let Some(ident) = kind.ident() {
301             // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
302             let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
303 
304             // NOTE: The reference is derived from the guard so that it only lives as long as the
305             // guard does and cannot escape the scope. If it's created via `&mut (*#slot).#ident`
306             // like the unaligned field guard, it will become effectively `'static`.
307             let accessor = if pinned {
308                 let project_ident = format_ident!("__project_{ident}");
309                 quote! {
310                     // SAFETY: the initialization is pinned.
311                     unsafe { #data.#project_ident(#guard.let_binding()) }
312                 }
313             } else {
314                 quote! {
315                     #guard.let_binding()
316                 }
317             };
318 
319             res.extend(quote! {
320                 #(#cfgs)*
321                 // Create the drop guard.
322                 //
323                 // SAFETY:
324                 // - `&raw mut (*slot).#ident` is valid.
325                 // - `make_field_check` checks that `&raw mut (*slot).#ident` is properly aligned.
326                 // - `(*slot).#ident` has been initialized above.
327                 // - We only need the ownership to the pointee back when initialization has
328                 //   succeeded, where we `forget` the guard.
329                 let mut #guard = unsafe {
330                     ::pin_init::__internal::DropGuard::new(
331                         &raw mut (*slot).#ident
332                     )
333                 };
334 
335                 #(#cfgs)*
336                 #[allow(unused_variables)]
337                 let #ident = #accessor;
338             });
339             guards.push(guard);
340             guard_attrs.push(cfgs);
341         }
342     }
343     quote! {
344         #res
345         // If execution reaches this point, all fields have been initialized. Therefore we can now
346         // dismiss the guards by forgetting them.
347         #(
348             #(#guard_attrs)*
349             ::core::mem::forget(#guards);
350         )*
351     }
352 }
353 
354 /// Generate the check for ensuring that every field has been initialized and aligned.
355 fn make_field_check(
356     fields: &Punctuated<InitializerField, Token![,]>,
357     init_kind: InitKind,
358     path: &Path,
359 ) -> TokenStream {
360     let field_attrs: Vec<_> = fields
361         .iter()
362         .filter_map(|f| f.kind.ident().map(|_| &f.attrs))
363         .collect();
364     let field_name: Vec<_> = fields.iter().filter_map(|f| f.kind.ident()).collect();
365     let zeroing_trailer = match init_kind {
366         InitKind::Normal => None,
367         InitKind::Zeroing => Some(quote! {
368             ..::core::mem::zeroed()
369         }),
370     };
371     quote! {
372         #[allow(unreachable_code, clippy::diverging_sub_expression)]
373         // We use unreachable code to perform field checks. They're still checked by the compiler.
374         // SAFETY: this code is never executed.
375         let _ = || unsafe {
376             // Create references to ensure that the initialized field is properly aligned.
377             // Unaligned fields will cause the compiler to emit E0793. We do not support
378             // unaligned fields since `Init::__init` requires an aligned pointer; the call to
379             // `ptr::write` for value-initialization case has the same requirement.
380             #(
381                 #(#field_attrs)*
382                 let _ = &(*slot).#field_name;
383             )*
384 
385             // If the zeroing trailer is not present, this checks that all fields have been
386             // mentioned exactly once. If the zeroing trailer is present, all missing fields will be
387             // zeroed, so this checks that all fields have been mentioned at most once. The use of
388             // struct initializer will still generate very natural error messages for any misuse.
389             ::core::ptr::write(slot, #path {
390                 #(
391                     #(#field_attrs)*
392                     #field_name: ::core::panic!(),
393                 )*
394                 #zeroing_trailer
395             })
396         };
397     }
398 }
399 
400 impl Parse for Initializer {
401     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
402         let attrs = input.call(Attribute::parse_outer)?;
403         let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
404         let path = input.parse()?;
405         let content;
406         let brace_token = braced!(content in input);
407         let mut fields = Punctuated::new();
408         loop {
409             let lh = content.lookahead1();
410             if lh.peek(End) || lh.peek(Token![..]) {
411                 break;
412             } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
413                 fields.push_value(content.parse()?);
414                 let lh = content.lookahead1();
415                 if lh.peek(End) {
416                     break;
417                 } else if lh.peek(Token![,]) {
418                     fields.push_punct(content.parse()?);
419                 } else {
420                     return Err(lh.error());
421                 }
422             } else {
423                 return Err(lh.error());
424             }
425         }
426         let rest = content
427             .peek(Token![..])
428             .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
429             .transpose()?;
430         let error = input
431             .peek(Token![?])
432             .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
433             .transpose()?;
434         let attrs = attrs
435             .into_iter()
436             .map(|a| {
437                 if a.path().is_ident("default_error") {
438                     a.parse_args::<DefaultErrorAttribute>()
439                         .map(InitializerAttribute::DefaultError)
440                 } else {
441                     Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
442                 }
443             })
444             .collect::<Result<Vec<_>, _>>()?;
445         Ok(Self {
446             attrs,
447             this,
448             path,
449             brace_token,
450             fields,
451             rest,
452             error,
453         })
454     }
455 }
456 
457 impl Parse for DefaultErrorAttribute {
458     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
459         Ok(Self { ty: input.parse()? })
460     }
461 }
462 
463 impl Parse for This {
464     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
465         Ok(Self {
466             _and_token: input.parse()?,
467             ident: input.parse()?,
468             _in_token: input.parse()?,
469         })
470     }
471 }
472 
473 impl Parse for InitializerField {
474     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
475         let attrs = input.call(Attribute::parse_outer)?;
476         Ok(Self {
477             attrs,
478             kind: input.parse()?,
479         })
480     }
481 }
482 
483 impl Parse for InitializerKind {
484     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
485         let lh = input.lookahead1();
486         if lh.peek(Token![_]) {
487             Ok(Self::Code {
488                 _underscore_token: input.parse()?,
489                 _colon_token: input.parse()?,
490                 block: input.parse()?,
491             })
492         } else if lh.peek(Ident) {
493             let ident = input.parse()?;
494             let lh = input.lookahead1();
495             if lh.peek(Token![<-]) {
496                 Ok(Self::Init {
497                     ident,
498                     _left_arrow_token: input.parse()?,
499                     value: input.parse()?,
500                 })
501             } else if lh.peek(Token![:]) {
502                 Ok(Self::Value {
503                     ident,
504                     value: Some((input.parse()?, input.parse()?)),
505                 })
506             } else if lh.peek(Token![,]) || lh.peek(End) {
507                 Ok(Self::Value { ident, value: None })
508             } else {
509                 Err(lh.error())
510             }
511         } else {
512             Err(lh.error())
513         }
514     }
515 }
516