xref: /linux/rust/pin-init/internal/src/init.rs (revision 580cc37b1de4fcd9997c48d7080e744533f09f36)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::{format_ident, quote, quote_spanned};
5 use syn::{
6     braced,
7     parse::{End, Parse},
8     parse_quote,
9     punctuated::Punctuated,
10     spanned::Spanned,
11     token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12 };
13 
14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15 
16 pub(crate) struct Initializer {
17     attrs: Vec<InitializerAttribute>,
18     this: Option<This>,
19     path: Path,
20     brace_token: token::Brace,
21     fields: Punctuated<InitializerField, Token![,]>,
22     rest: Option<(Token![..], Expr)>,
23     error: Option<(Token![?], Type)>,
24 }
25 
26 struct This {
27     _and_token: Token![&],
28     ident: Ident,
29     _in_token: Token![in],
30 }
31 
32 struct InitializerField {
33     attrs: Vec<Attribute>,
34     kind: InitializerKind,
35 }
36 
37 enum InitializerKind {
38     Value {
39         ident: Ident,
40         value: Option<(Token![:], Expr)>,
41     },
42     Init {
43         ident: Ident,
44         _left_arrow_token: Token![<-],
45         value: Expr,
46     },
47     Code {
48         _underscore_token: Token![_],
49         _colon_token: Token![:],
50         block: Block,
51     },
52 }
53 
54 impl InitializerKind {
55     fn ident(&self) -> Option<&Ident> {
56         match self {
57             Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58             Self::Code { .. } => None,
59         }
60     }
61 }
62 
63 enum InitializerAttribute {
64     DefaultError(DefaultErrorAttribute),
65 }
66 
67 struct DefaultErrorAttribute {
68     ty: Box<Type>,
69 }
70 
71 pub(crate) fn expand(
72     Initializer {
73         attrs,
74         this,
75         path,
76         brace_token,
77         fields,
78         rest,
79         error,
80     }: Initializer,
81     default_error: Option<&'static str>,
82     pinned: bool,
83     dcx: &mut DiagCtxt,
84 ) -> Result<TokenStream, ErrorGuaranteed> {
85     let error = error.map_or_else(
86         || {
87             if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
88                 #[expect(irrefutable_let_patterns)]
89                 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90                     Some(ty.clone())
91                 } else {
92                     acc
93                 }
94             }) {
95                 default_error
96             } else if let Some(default_error) = default_error {
97                 syn::parse_str(default_error).unwrap()
98             } else {
99                 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100                 parse_quote!(::core::convert::Infallible)
101             }
102         },
103         |(_, err)| Box::new(err),
104     );
105     let slot = format_ident!("slot");
106     let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
107         (
108             format_ident!("HasPinData"),
109             format_ident!("PinData"),
110             format_ident!("__pin_data"),
111             format_ident!("pin_init_from_closure"),
112         )
113     } else {
114         (
115             format_ident!("HasInitData"),
116             format_ident!("InitData"),
117             format_ident!("__init_data"),
118             format_ident!("init_from_closure"),
119         )
120     };
121     let init_kind = get_init_kind(rest, dcx);
122     let zeroable_check = match init_kind {
123         InitKind::Normal => quote!(),
124         InitKind::Zeroing => quote! {
125             // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
126             // Therefore we check if the struct implements `Zeroable` and then zero the memory.
127             // This allows us to also remove the check that all fields are present (since we
128             // already set the memory to zero and that is a valid bit pattern).
129             fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
130             where T: ::pin_init::Zeroable
131             {}
132             // Ensure that the struct is indeed `Zeroable`.
133             assert_zeroable(#slot);
134             // SAFETY: The type implements `Zeroable` by the check above.
135             unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
136         },
137     };
138     let this = match this {
139         None => quote!(),
140         Some(This { ident, .. }) => quote! {
141             // Create the `this` so it can be referenced by the user inside of the
142             // expressions creating the individual fields.
143             let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
144         },
145     };
146     // `mixed_site` ensures that the data is not accessible to the user-controlled code.
147     let data = Ident::new("__data", Span::mixed_site());
148     let init_fields = init_fields(&fields, pinned, &data, &slot);
149     let field_check = make_field_check(&fields, init_kind, &path);
150     Ok(quote! {{
151         // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return
152         // type and shadow it later when we insert the arbitrary user code. That way there will be
153         // no possibility of returning without `unsafe`.
154         struct __InitOk;
155 
156         // Get the data about fields from the supplied type.
157         // SAFETY: TODO
158         let #data = unsafe {
159             use ::pin_init::__internal::#has_data_trait;
160             // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
161             // generics (which need to be present with that syntax).
162             #path::#get_data()
163         };
164         // Ensure that `#data` really is of type `#data` and help with type inference:
165         let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>(
166             #data,
167             move |slot| {
168                 {
169                     // Shadow the structure so it cannot be used to return early.
170                     struct __InitOk;
171                     #zeroable_check
172                     #this
173                     #init_fields
174                     #field_check
175                 }
176                 Ok(__InitOk)
177             }
178         );
179         let init = move |slot| -> ::core::result::Result<(), #error> {
180             init(slot).map(|__InitOk| ())
181         };
182         // SAFETY: TODO
183         let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
184         init
185     }})
186 }
187 
188 enum InitKind {
189     Normal,
190     Zeroing,
191 }
192 
193 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
194     let Some((dotdot, expr)) = rest else {
195         return InitKind::Normal;
196     };
197     match &expr {
198         Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
199             Expr::Path(ExprPath {
200                 attrs,
201                 qself: None,
202                 path:
203                     Path {
204                         leading_colon: None,
205                         segments,
206                     },
207             }) if attrs.is_empty()
208                 && segments.len() == 2
209                 && segments[0].ident == "Zeroable"
210                 && segments[0].arguments.is_none()
211                 && segments[1].ident == "init_zeroed"
212                 && segments[1].arguments.is_none() =>
213             {
214                 return InitKind::Zeroing;
215             }
216             _ => {}
217         },
218         _ => {}
219     }
220     dcx.error(
221         dotdot.span().join(expr.span()).unwrap_or(expr.span()),
222         "expected nothing or `..Zeroable::init_zeroed()`.",
223     );
224     InitKind::Normal
225 }
226 
227 /// Generate the code that initializes the fields of the struct using the initializers in `field`.
228 fn init_fields(
229     fields: &Punctuated<InitializerField, Token![,]>,
230     pinned: bool,
231     data: &Ident,
232     slot: &Ident,
233 ) -> TokenStream {
234     let mut guards = vec![];
235     let mut guard_attrs = vec![];
236     let mut res = TokenStream::new();
237     for InitializerField { attrs, kind } in fields {
238         let cfgs = {
239             let mut cfgs = attrs.clone();
240             cfgs.retain(|attr| attr.path().is_ident("cfg"));
241             cfgs
242         };
243         let init = match kind {
244             InitializerKind::Value { ident, value } => {
245                 let mut value_ident = ident.clone();
246                 let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
247                     // Setting the span of `value_ident` to `value`'s span improves error messages
248                     // when the type of `value` is wrong.
249                     value_ident.set_span(value.span());
250                     quote!(let #value_ident = #value;)
251                 });
252                 // Again span for better diagnostics
253                 let write = quote_spanned!(ident.span()=> ::core::ptr::write);
254                 // NOTE: the field accessor ensures that the initialized field is properly aligned.
255                 // Unaligned fields will cause the compiler to emit E0793. We do not support
256                 // unaligned fields since `Init::__init` requires an aligned pointer; the call to
257                 // `ptr::write` below has the same requirement.
258                 let accessor = if pinned {
259                     let project_ident = format_ident!("__project_{ident}");
260                     quote! {
261                         // SAFETY: TODO
262                         unsafe { #data.#project_ident(&mut (*#slot).#ident) }
263                     }
264                 } else {
265                     quote! {
266                         // SAFETY: TODO
267                         unsafe { &mut (*#slot).#ident }
268                     }
269                 };
270                 quote! {
271                     #(#attrs)*
272                     {
273                         #value_prep
274                         // SAFETY: TODO
275                         unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
276                     }
277                     #(#cfgs)*
278                     #[allow(unused_variables)]
279                     let #ident = #accessor;
280                 }
281             }
282             InitializerKind::Init { ident, value, .. } => {
283                 // Again span for better diagnostics
284                 let init = format_ident!("init", span = value.span());
285                 // NOTE: the field accessor ensures that the initialized field is properly aligned.
286                 // Unaligned fields will cause the compiler to emit E0793. We do not support
287                 // unaligned fields since `Init::__init` requires an aligned pointer; the call to
288                 // `ptr::write` below has the same requirement.
289                 let (value_init, accessor) = if pinned {
290                     let project_ident = format_ident!("__project_{ident}");
291                     (
292                         quote! {
293                             // SAFETY:
294                             // - `slot` is valid, because we are inside of an initializer closure, we
295                             //   return when an error/panic occurs.
296                             // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
297                             //   for `#ident`.
298                             unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
299                         },
300                         quote! {
301                             // SAFETY: TODO
302                             unsafe { #data.#project_ident(&mut (*#slot).#ident) }
303                         },
304                     )
305                 } else {
306                     (
307                         quote! {
308                             // SAFETY: `slot` is valid, because we are inside of an initializer
309                             // closure, we return when an error/panic occurs.
310                             unsafe {
311                                 ::pin_init::Init::__init(
312                                     #init,
313                                     ::core::ptr::addr_of_mut!((*#slot).#ident),
314                                 )?
315                             };
316                         },
317                         quote! {
318                             // SAFETY: TODO
319                             unsafe { &mut (*#slot).#ident }
320                         },
321                     )
322                 };
323                 quote! {
324                     #(#attrs)*
325                     {
326                         let #init = #value;
327                         #value_init
328                     }
329                     #(#cfgs)*
330                     #[allow(unused_variables)]
331                     let #ident = #accessor;
332                 }
333             }
334             InitializerKind::Code { block: value, .. } => quote! {
335                 #(#attrs)*
336                 #[allow(unused_braces)]
337                 #value
338             },
339         };
340         res.extend(init);
341         if let Some(ident) = kind.ident() {
342             // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
343             let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
344             res.extend(quote! {
345                 #(#cfgs)*
346                 // Create the drop guard:
347                 //
348                 // We rely on macro hygiene to make it impossible for users to access this local
349                 // variable.
350                 // SAFETY: We forget the guard later when initialization has succeeded.
351                 let #guard = unsafe {
352                     ::pin_init::__internal::DropGuard::new(
353                         ::core::ptr::addr_of_mut!((*slot).#ident)
354                     )
355                 };
356             });
357             guards.push(guard);
358             guard_attrs.push(cfgs);
359         }
360     }
361     quote! {
362         #res
363         // If execution reaches this point, all fields have been initialized. Therefore we can now
364         // dismiss the guards by forgetting them.
365         #(
366             #(#guard_attrs)*
367             ::core::mem::forget(#guards);
368         )*
369     }
370 }
371 
372 /// Generate the check for ensuring that every field has been initialized.
373 fn make_field_check(
374     fields: &Punctuated<InitializerField, Token![,]>,
375     init_kind: InitKind,
376     path: &Path,
377 ) -> TokenStream {
378     let field_attrs = fields
379         .iter()
380         .filter_map(|f| f.kind.ident().map(|_| &f.attrs));
381     let field_name = fields.iter().filter_map(|f| f.kind.ident());
382     match init_kind {
383         InitKind::Normal => quote! {
384             // We use unreachable code to ensure that all fields have been mentioned exactly once,
385             // this struct initializer will still be type-checked and complain with a very natural
386             // error message if a field is forgotten/mentioned more than once.
387             #[allow(unreachable_code, clippy::diverging_sub_expression)]
388             // SAFETY: this code is never executed.
389             let _ = || unsafe {
390                 ::core::ptr::write(slot, #path {
391                     #(
392                         #(#field_attrs)*
393                         #field_name: ::core::panic!(),
394                     )*
395                 })
396             };
397         },
398         InitKind::Zeroing => quote! {
399             // We use unreachable code to ensure that all fields have been mentioned at most once.
400             // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
401             // be zeroed. This struct initializer will still be type-checked and complain with a
402             // very natural error message if a field is mentioned more than once, or doesn't exist.
403             #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
404             // SAFETY: this code is never executed.
405             let _ = || unsafe {
406                 ::core::ptr::write(slot, #path {
407                     #(
408                         #(#field_attrs)*
409                         #field_name: ::core::panic!(),
410                     )*
411                     ..::core::mem::zeroed()
412                 })
413             };
414         },
415     }
416 }
417 
418 impl Parse for Initializer {
419     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
420         let attrs = input.call(Attribute::parse_outer)?;
421         let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
422         let path = input.parse()?;
423         let content;
424         let brace_token = braced!(content in input);
425         let mut fields = Punctuated::new();
426         loop {
427             let lh = content.lookahead1();
428             if lh.peek(End) || lh.peek(Token![..]) {
429                 break;
430             } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
431                 fields.push_value(content.parse()?);
432                 let lh = content.lookahead1();
433                 if lh.peek(End) {
434                     break;
435                 } else if lh.peek(Token![,]) {
436                     fields.push_punct(content.parse()?);
437                 } else {
438                     return Err(lh.error());
439                 }
440             } else {
441                 return Err(lh.error());
442             }
443         }
444         let rest = content
445             .peek(Token![..])
446             .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
447             .transpose()?;
448         let error = input
449             .peek(Token![?])
450             .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
451             .transpose()?;
452         let attrs = attrs
453             .into_iter()
454             .map(|a| {
455                 if a.path().is_ident("default_error") {
456                     a.parse_args::<DefaultErrorAttribute>()
457                         .map(InitializerAttribute::DefaultError)
458                 } else {
459                     Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
460                 }
461             })
462             .collect::<Result<Vec<_>, _>>()?;
463         Ok(Self {
464             attrs,
465             this,
466             path,
467             brace_token,
468             fields,
469             rest,
470             error,
471         })
472     }
473 }
474 
475 impl Parse for DefaultErrorAttribute {
476     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
477         Ok(Self { ty: input.parse()? })
478     }
479 }
480 
481 impl Parse for This {
482     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
483         Ok(Self {
484             _and_token: input.parse()?,
485             ident: input.parse()?,
486             _in_token: input.parse()?,
487         })
488     }
489 }
490 
491 impl Parse for InitializerField {
492     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
493         let attrs = input.call(Attribute::parse_outer)?;
494         Ok(Self {
495             attrs,
496             kind: input.parse()?,
497         })
498     }
499 }
500 
501 impl Parse for InitializerKind {
502     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
503         let lh = input.lookahead1();
504         if lh.peek(Token![_]) {
505             Ok(Self::Code {
506                 _underscore_token: input.parse()?,
507                 _colon_token: input.parse()?,
508                 block: input.parse()?,
509             })
510         } else if lh.peek(Ident) {
511             let ident = input.parse()?;
512             let lh = input.lookahead1();
513             if lh.peek(Token![<-]) {
514                 Ok(Self::Init {
515                     ident,
516                     _left_arrow_token: input.parse()?,
517                     value: input.parse()?,
518                 })
519             } else if lh.peek(Token![:]) {
520                 Ok(Self::Value {
521                     ident,
522                     value: Some((input.parse()?, input.parse()?)),
523                 })
524             } else if lh.peek(Token![,]) || lh.peek(End) {
525                 Ok(Self::Value { ident, value: None })
526             } else {
527                 Err(lh.error())
528             }
529         } else {
530             Err(lh.error())
531         }
532     }
533 }
534