xref: /linux/rust/pin-init/internal/src/init.rs (revision d26732e57b06ef32dadfc32d5de9ac39262698cb)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::{format_ident, quote, quote_spanned};
5 use syn::{
6     braced,
7     parse::{End, Parse},
8     parse_quote,
9     punctuated::Punctuated,
10     spanned::Spanned,
11     token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12 };
13 
14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15 
/// Parsed form of an initializer invocation, i.e. the input
/// `#[attr]* [&this in] Path { field, ..., [..rest] } [? ErrorType]`.
pub(crate) struct Initializer {
    /// Outer attributes written before the initializer; only `default_error` is accepted
    /// by the parser.
    attrs: Vec<InitializerAttribute>,
    /// Optional `&ident in` prefix naming a binding for the slot pointer.
    this: Option<This>,
    /// Path of the struct being initialized.
    path: Path,
    /// The braces surrounding the field list; kept for its span in diagnostics.
    brace_token: token::Brace,
    /// The comma-separated field initializers, in source order.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional trailing `..<expr>` inside the braces.
    rest: Option<(Token![..], Expr)>,
    /// Optional `? <type>` after the closing brace, giving the error type.
    error: Option<(Token![?], Type)>,
}
25 
/// The `&ident in` prefix of an initializer.
struct This {
    /// The leading `&`; parsed only for syntax.
    _and_token: Token![&],
    /// Name the user chose for the slot binding.
    ident: Ident,
    /// The `in` keyword; parsed only for syntax.
    _in_token: Token![in],
}
31 
/// One entry of the field list together with the attributes written in front of it.
struct InitializerField {
    /// Outer attributes on the entry (for example `#[cfg(...)]`).
    attrs: Vec<Attribute>,
    /// What kind of entry this is and its payload.
    kind: InitializerKind,
}
36 
/// The syntactic forms an entry of the field list can take.
enum InitializerKind {
    /// `ident` or `ident: expr` -- the field is written by value.
    Value {
        /// Name of the field.
        ident: Ident,
        /// The `: expr` part; `None` for the shorthand form that reuses a binding named
        /// like the field.
        value: Option<(Token![:], Expr)>,
    },
    /// `ident <- expr` -- the field is initialized in place by an initializer expression.
    Init {
        /// Name of the field.
        ident: Ident,
        /// The `<-` token; parsed only for syntax.
        _left_arrow_token: Token![<-],
        /// Expression producing the initializer.
        value: Expr,
    },
    /// `_: { ... }` -- a block of user code emitted between field initializations.
    Code {
        /// The `_` token; parsed only for syntax.
        _underscore_token: Token![_],
        /// The `:` token; parsed only for syntax.
        _colon_token: Token![:],
        /// The user-provided block.
        block: Block,
    },
}
53 
54 impl InitializerKind {
55     fn ident(&self) -> Option<&Ident> {
56         match self {
57             Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58             Self::Code { .. } => None,
59         }
60     }
61 }
62 
/// Attributes that may be written in front of the whole initializer.
enum InitializerAttribute {
    /// `#[default_error(<type>)]` -- error type to use when no `? <type>` is given.
    DefaultError(DefaultErrorAttribute),
}
66 
/// Argument of a `default_error` attribute.
struct DefaultErrorAttribute {
    /// The error type named inside the attribute's parentheses.
    ty: Box<Type>,
}
70 
71 pub(crate) fn expand(
72     Initializer {
73         attrs,
74         this,
75         path,
76         brace_token,
77         fields,
78         rest,
79         error,
80     }: Initializer,
81     default_error: Option<&'static str>,
82     pinned: bool,
83     dcx: &mut DiagCtxt,
84 ) -> Result<TokenStream, ErrorGuaranteed> {
85     let error = error.map_or_else(
86         || {
87             if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
88                 #[expect(irrefutable_let_patterns)]
89                 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90                     Some(ty.clone())
91                 } else {
92                     acc
93                 }
94             }) {
95                 default_error
96             } else if let Some(default_error) = default_error {
97                 syn::parse_str(default_error).unwrap()
98             } else {
99                 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100                 parse_quote!(::core::convert::Infallible)
101             }
102         },
103         |(_, err)| Box::new(err),
104     );
105     let slot = format_ident!("slot");
106     let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
107         (
108             format_ident!("HasPinData"),
109             format_ident!("PinData"),
110             format_ident!("__pin_data"),
111             format_ident!("pin_init_from_closure"),
112         )
113     } else {
114         (
115             format_ident!("HasInitData"),
116             format_ident!("InitData"),
117             format_ident!("__init_data"),
118             format_ident!("init_from_closure"),
119         )
120     };
121     let init_kind = get_init_kind(rest, dcx);
122     let zeroable_check = match init_kind {
123         InitKind::Normal => quote!(),
124         InitKind::Zeroing => quote! {
125             // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
126             // Therefore we check if the struct implements `Zeroable` and then zero the memory.
127             // This allows us to also remove the check that all fields are present (since we
128             // already set the memory to zero and that is a valid bit pattern).
129             fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
130             where T: ::pin_init::Zeroable
131             {}
132             // Ensure that the struct is indeed `Zeroable`.
133             assert_zeroable(#slot);
134             // SAFETY: The type implements `Zeroable` by the check above.
135             unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
136         },
137     };
138     let this = match this {
139         None => quote!(),
140         Some(This { ident, .. }) => quote! {
141             // Create the `this` so it can be referenced by the user inside of the
142             // expressions creating the individual fields.
143             let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
144         },
145     };
146     // `mixed_site` ensures that the data is not accessible to the user-controlled code.
147     let data = Ident::new("__data", Span::mixed_site());
148     let init_fields = init_fields(&fields, pinned, &data, &slot);
149     let field_check = make_field_check(&fields, init_kind, &path);
150     Ok(quote! {{
151         // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return
152         // type and shadow it later when we insert the arbitrary user code. That way there will be
153         // no possibility of returning without `unsafe`.
154         struct __InitOk;
155 
156         // Get the data about fields from the supplied type.
157         // SAFETY: TODO
158         let #data = unsafe {
159             use ::pin_init::__internal::#has_data_trait;
160             // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
161             // generics (which need to be present with that syntax).
162             #path::#get_data()
163         };
164         // Ensure that `#data` really is of type `#data` and help with type inference:
165         let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>(
166             #data,
167             move |slot| {
168                 {
169                     // Shadow the structure so it cannot be used to return early.
170                     struct __InitOk;
171                     #zeroable_check
172                     #this
173                     #init_fields
174                     #field_check
175                 }
176                 Ok(__InitOk)
177             }
178         );
179         let init = move |slot| -> ::core::result::Result<(), #error> {
180             init(slot).map(|__InitOk| ())
181         };
182         // SAFETY: TODO
183         let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
184         init
185     }})
186 }
187 
/// How fields that are not mentioned in the initializer are treated.
enum InitKind {
    /// No usable `..` tail: every field has to be mentioned.
    Normal,
    /// The tail is `..Zeroable::init_zeroed()`: the memory is zeroed first, so fields may be
    /// omitted.
    Zeroing,
}
192 
193 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
194     let Some((dotdot, expr)) = rest else {
195         return InitKind::Normal;
196     };
197     match &expr {
198         Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
199             Expr::Path(ExprPath {
200                 attrs,
201                 qself: None,
202                 path:
203                     Path {
204                         leading_colon: None,
205                         segments,
206                     },
207             }) if attrs.is_empty()
208                 && segments.len() == 2
209                 && segments[0].ident == "Zeroable"
210                 && segments[0].arguments.is_none()
211                 && segments[1].ident == "init_zeroed"
212                 && segments[1].arguments.is_none() =>
213             {
214                 return InitKind::Zeroing;
215             }
216             _ => {}
217         },
218         _ => {}
219     }
220     dcx.error(
221         dotdot.span().join(expr.span()).unwrap_or(expr.span()),
222         "expected nothing or `..Zeroable::init_zeroed()`.",
223     );
224     InitKind::Normal
225 }
226 
227 /// Generate the code that initializes the fields of the struct using the initializers in `field`.
228 fn init_fields(
229     fields: &Punctuated<InitializerField, Token![,]>,
230     pinned: bool,
231     data: &Ident,
232     slot: &Ident,
233 ) -> TokenStream {
234     let mut guards = vec![];
235     let mut guard_attrs = vec![];
236     let mut res = TokenStream::new();
237     for InitializerField { attrs, kind } in fields {
238         let cfgs = {
239             let mut cfgs = attrs.clone();
240             cfgs.retain(|attr| attr.path().is_ident("cfg"));
241             cfgs
242         };
243         let init = match kind {
244             InitializerKind::Value { ident, value } => {
245                 let mut value_ident = ident.clone();
246                 let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
247                     // Setting the span of `value_ident` to `value`'s span improves error messages
248                     // when the type of `value` is wrong.
249                     value_ident.set_span(value.span());
250                     quote!(let #value_ident = #value;)
251                 });
252                 // Again span for better diagnostics
253                 let write = quote_spanned!(ident.span()=> ::core::ptr::write);
254                 let accessor = if pinned {
255                     let project_ident = format_ident!("__project_{ident}");
256                     quote! {
257                         // SAFETY: TODO
258                         unsafe { #data.#project_ident(&mut (*#slot).#ident) }
259                     }
260                 } else {
261                     quote! {
262                         // SAFETY: TODO
263                         unsafe { &mut (*#slot).#ident }
264                     }
265                 };
266                 quote! {
267                     #(#attrs)*
268                     {
269                         #value_prep
270                         // SAFETY: TODO
271                         unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
272                     }
273                     #(#cfgs)*
274                     #[allow(unused_variables)]
275                     let #ident = #accessor;
276                 }
277             }
278             InitializerKind::Init { ident, value, .. } => {
279                 // Again span for better diagnostics
280                 let init = format_ident!("init", span = value.span());
281                 if pinned {
282                     let project_ident = format_ident!("__project_{ident}");
283                     quote! {
284                         #(#attrs)*
285                         {
286                             let #init = #value;
287                             // SAFETY:
288                             // - `slot` is valid, because we are inside of an initializer closure, we
289                             //   return when an error/panic occurs.
290                             // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
291                             //   for `#ident`.
292                             unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
293                         }
294                         #(#cfgs)*
295                         // SAFETY: TODO
296                         #[allow(unused_variables)]
297                         let #ident = unsafe { #data.#project_ident(&mut (*#slot).#ident) };
298                     }
299                 } else {
300                     quote! {
301                         #(#attrs)*
302                         {
303                             let #init = #value;
304                             // SAFETY: `slot` is valid, because we are inside of an initializer
305                             // closure, we return when an error/panic occurs.
306                             unsafe {
307                                 ::pin_init::Init::__init(
308                                     #init,
309                                     ::core::ptr::addr_of_mut!((*#slot).#ident),
310                                 )?
311                             };
312                         }
313                         #(#cfgs)*
314                         // SAFETY: TODO
315                         #[allow(unused_variables)]
316                         let #ident = unsafe { &mut (*#slot).#ident };
317                     }
318                 }
319             }
320             InitializerKind::Code { block: value, .. } => quote! {
321                 #(#attrs)*
322                 #[allow(unused_braces)]
323                 #value
324             },
325         };
326         res.extend(init);
327         if let Some(ident) = kind.ident() {
328             // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
329             let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
330             res.extend(quote! {
331                 #(#cfgs)*
332                 // Create the drop guard:
333                 //
334                 // We rely on macro hygiene to make it impossible for users to access this local
335                 // variable.
336                 // SAFETY: We forget the guard later when initialization has succeeded.
337                 let #guard = unsafe {
338                     ::pin_init::__internal::DropGuard::new(
339                         ::core::ptr::addr_of_mut!((*slot).#ident)
340                     )
341                 };
342             });
343             guards.push(guard);
344             guard_attrs.push(cfgs);
345         }
346     }
347     quote! {
348         #res
349         // If execution reaches this point, all fields have been initialized. Therefore we can now
350         // dismiss the guards by forgetting them.
351         #(
352             #(#guard_attrs)*
353             ::core::mem::forget(#guards);
354         )*
355     }
356 }
357 
/// Generate the check for ensuring that every field has been initialized.
///
/// The check is a never-called closure containing a struct literal of `path` that lists
/// every field named in `fields` (code blocks are skipped), each with its attributes and
/// `::core::panic!()` as value. For [`InitKind::Zeroing`] the literal additionally ends in
/// `..zeroed`, so omitted fields are accepted.
///
/// The generated code refers to the slot pointer by the literal name `slot`.
fn make_field_check(
    fields: &Punctuated<InitializerField, Token![,]>,
    init_kind: InitKind,
    path: &Path,
) -> TokenStream {
    // Attributes and names of the entries that name a field, in matching order.
    let field_attrs = fields
        .iter()
        .filter_map(|f| f.kind.ident().map(|_| &f.attrs));
    let field_name = fields.iter().filter_map(|f| f.kind.ident());
    match init_kind {
        InitKind::Normal => quote! {
            // We use unreachable code to ensure that all fields have been mentioned exactly once,
            // this struct initializer will still be type-checked and complain with a very natural
            // error message if a field is forgotten/mentioned more than once.
            #[allow(unreachable_code, clippy::diverging_sub_expression)]
            // SAFETY: this code is never executed.
            let _ = || unsafe {
                ::core::ptr::write(slot, #path {
                    #(
                        #(#field_attrs)*
                        #field_name: ::core::panic!(),
                    )*
                })
            };
        },
        InitKind::Zeroing => quote! {
            // We use unreachable code to ensure that all fields have been mentioned at most once.
            // Since the user specified `..Zeroable::init_zeroed()` at the end, all missing fields
            // will be zeroed. This struct initializer will still be type-checked and complain with
            // a very natural error message if a field is mentioned more than once, or doesn't
            // exist.
            #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
            // SAFETY: this code is never executed.
            let _ = || unsafe {
                let mut zeroed = ::core::mem::zeroed();
                // We have to use type inference here to make zeroed have the correct type. This
                // does not get executed, so it has no effect.
                ::core::ptr::write(slot, zeroed);
                zeroed = ::core::mem::zeroed();
                ::core::ptr::write(slot, #path {
                    #(
                        #(#field_attrs)*
                        #field_name: ::core::panic!(),
                    )*
                    ..zeroed
                })
            };
        },
    }
}
408 
409 impl Parse for Initializer {
410     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
411         let attrs = input.call(Attribute::parse_outer)?;
412         let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
413         let path = input.parse()?;
414         let content;
415         let brace_token = braced!(content in input);
416         let mut fields = Punctuated::new();
417         loop {
418             let lh = content.lookahead1();
419             if lh.peek(End) || lh.peek(Token![..]) {
420                 break;
421             } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
422                 fields.push_value(content.parse()?);
423                 let lh = content.lookahead1();
424                 if lh.peek(End) {
425                     break;
426                 } else if lh.peek(Token![,]) {
427                     fields.push_punct(content.parse()?);
428                 } else {
429                     return Err(lh.error());
430                 }
431             } else {
432                 return Err(lh.error());
433             }
434         }
435         let rest = content
436             .peek(Token![..])
437             .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
438             .transpose()?;
439         let error = input
440             .peek(Token![?])
441             .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
442             .transpose()?;
443         let attrs = attrs
444             .into_iter()
445             .map(|a| {
446                 if a.path().is_ident("default_error") {
447                     a.parse_args::<DefaultErrorAttribute>()
448                         .map(InitializerAttribute::DefaultError)
449                 } else {
450                     Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
451                 }
452             })
453             .collect::<Result<Vec<_>, _>>()?;
454         Ok(Self {
455             attrs,
456             this,
457             path,
458             brace_token,
459             fields,
460             rest,
461             error,
462         })
463     }
464 }
465 
466 impl Parse for DefaultErrorAttribute {
467     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
468         Ok(Self { ty: input.parse()? })
469     }
470 }
471 
472 impl Parse for This {
473     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
474         Ok(Self {
475             _and_token: input.parse()?,
476             ident: input.parse()?,
477             _in_token: input.parse()?,
478         })
479     }
480 }
481 
482 impl Parse for InitializerField {
483     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
484         let attrs = input.call(Attribute::parse_outer)?;
485         Ok(Self {
486             attrs,
487             kind: input.parse()?,
488         })
489     }
490 }
491 
492 impl Parse for InitializerKind {
493     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
494         let lh = input.lookahead1();
495         if lh.peek(Token![_]) {
496             Ok(Self::Code {
497                 _underscore_token: input.parse()?,
498                 _colon_token: input.parse()?,
499                 block: input.parse()?,
500             })
501         } else if lh.peek(Ident) {
502             let ident = input.parse()?;
503             let lh = input.lookahead1();
504             if lh.peek(Token![<-]) {
505                 Ok(Self::Init {
506                     ident,
507                     _left_arrow_token: input.parse()?,
508                     value: input.parse()?,
509                 })
510             } else if lh.peek(Token![:]) {
511                 Ok(Self::Value {
512                     ident,
513                     value: Some((input.parse()?, input.parse()?)),
514                 })
515             } else if lh.peek(Token![,]) || lh.peek(End) {
516                 Ok(Self::Value { ident, value: None })
517             } else {
518                 Err(lh.error())
519             }
520         } else {
521             Err(lh.error())
522         }
523     }
524 }
525