xref: /linux/rust/pin-init/internal/src/init.rs (revision aeabc92eb2d8c27578274a7ec3d0d00558fedfc2)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::{format_ident, quote, quote_spanned};
5 use syn::{
6     braced,
7     parse::{End, Parse},
8     parse_quote,
9     punctuated::Punctuated,
10     spanned::Spanned,
11     token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12 };
13 
14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15 
/// A parsed initializer, e.g. `&this in Foo { a: 1, b <- init, _: { ... }, ..rest }? Error`.
///
/// Produced by the [`Parse`] implementation for this type and consumed by [`expand`].
pub(crate) struct Initializer {
    /// Outer attributes; only `#[default_error(<type>)]` is accepted by the parser.
    attrs: Vec<InitializerAttribute>,
    /// Optional `&<ident> in` prefix naming a pointer to the value being initialized.
    this: Option<This>,
    /// Path of the struct being initialized (generic arguments may be omitted).
    path: Path,
    /// The `{ ... }` delimiters; the closing brace is used to place the "missing error type"
    /// diagnostic.
    brace_token: token::Brace,
    /// The comma-separated field entries.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional trailing `..<expr>`; anything other than `..Zeroable::init_zeroed()` is reported
    /// as an error during expansion.
    rest: Option<(Token![..], Expr)>,
    /// Optional `? <type>` after the closing brace, naming the initializer's error type.
    error: Option<(Token![?], Type)>,
}
25 
/// The optional `&<ident> in` prefix of an initializer.
///
/// During expansion, `ident` is bound to a `NonNull` pointer to the slot being initialized so
/// that the field expressions can refer to it.
struct This {
    _and_token: Token![&],
    /// Name under which the pointer to the slot is made available to the user.
    ident: Ident,
    _in_token: Token![in],
}
31 
/// A single entry of the initializer's field list.
enum InitializerField {
    /// `ident: expr`, or the shorthand `ident` which uses the local variable of the same name.
    /// The value is written into the field.
    Value {
        ident: Ident,
        /// `None` for the shorthand form.
        value: Option<(Token![:], Expr)>,
    },
    /// `ident <- expr`, where `expr` is an initializer (`Init`/`PinInit`) for the field.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { ... }`: arbitrary user code that runs at this point of the initialization and does
    /// not correspond to any field.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}
48 
49 impl InitializerField {
50     fn ident(&self) -> Option<&Ident> {
51         match self {
52             Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
53             Self::Code { .. } => None,
54         }
55     }
56 }
57 
/// An attribute accepted on an initializer.
enum InitializerAttribute {
    /// `#[default_error(<type>)]`: error type used when no explicit `? <type>` is given.
    DefaultError(DefaultErrorAttribute),
}
61 
/// The argument of a `#[default_error(<type>)]` attribute.
struct DefaultErrorAttribute {
    /// The error type named by the attribute.
    ty: Box<Type>,
}
65 
66 pub(crate) fn expand(
67     Initializer {
68         attrs,
69         this,
70         path,
71         brace_token,
72         fields,
73         rest,
74         error,
75     }: Initializer,
76     default_error: Option<&'static str>,
77     pinned: bool,
78     dcx: &mut DiagCtxt,
79 ) -> Result<TokenStream, ErrorGuaranteed> {
80     let error = error.map_or_else(
81         || {
82             if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
83                 #[expect(irrefutable_let_patterns)]
84                 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
85                     Some(ty.clone())
86                 } else {
87                     acc
88                 }
89             }) {
90                 default_error
91             } else if let Some(default_error) = default_error {
92                 syn::parse_str(default_error).unwrap()
93             } else {
94                 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
95                 parse_quote!(::core::convert::Infallible)
96             }
97         },
98         |(_, err)| Box::new(err),
99     );
100     let slot = format_ident!("slot");
101     let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
102         (
103             format_ident!("HasPinData"),
104             format_ident!("PinData"),
105             format_ident!("__pin_data"),
106             format_ident!("pin_init_from_closure"),
107         )
108     } else {
109         (
110             format_ident!("HasInitData"),
111             format_ident!("InitData"),
112             format_ident!("__init_data"),
113             format_ident!("init_from_closure"),
114         )
115     };
116     let init_kind = get_init_kind(rest, dcx);
117     let zeroable_check = match init_kind {
118         InitKind::Normal => quote!(),
119         InitKind::Zeroing => quote! {
120             // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
121             // Therefore we check if the struct implements `Zeroable` and then zero the memory.
122             // This allows us to also remove the check that all fields are present (since we
123             // already set the memory to zero and that is a valid bit pattern).
124             fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
125             where T: ::pin_init::Zeroable
126             {}
127             // Ensure that the struct is indeed `Zeroable`.
128             assert_zeroable(#slot);
129             // SAFETY: The type implements `Zeroable` by the check above.
130             unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
131         },
132     };
133     let this = match this {
134         None => quote!(),
135         Some(This { ident, .. }) => quote! {
136             // Create the `this` so it can be referenced by the user inside of the
137             // expressions creating the individual fields.
138             let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
139         },
140     };
141     // `mixed_site` ensures that the data is not accessible to the user-controlled code.
142     let data = Ident::new("__data", Span::mixed_site());
143     let init_fields = init_fields(&fields, pinned, &data, &slot);
144     let field_check = make_field_check(&fields, init_kind, &path);
145     Ok(quote! {{
146         // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return
147         // type and shadow it later when we insert the arbitrary user code. That way there will be
148         // no possibility of returning without `unsafe`.
149         struct __InitOk;
150 
151         // Get the data about fields from the supplied type.
152         // SAFETY: TODO
153         let #data = unsafe {
154             use ::pin_init::__internal::#has_data_trait;
155             // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
156             // generics (which need to be present with that syntax).
157             #path::#get_data()
158         };
159         // Ensure that `#data` really is of type `#data` and help with type inference:
160         let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>(
161             #data,
162             move |slot| {
163                 {
164                     // Shadow the structure so it cannot be used to return early.
165                     struct __InitOk;
166                     #zeroable_check
167                     #this
168                     #init_fields
169                     #field_check
170                 }
171                 Ok(__InitOk)
172             }
173         );
174         let init = move |slot| -> ::core::result::Result<(), #error> {
175             init(slot).map(|__InitOk| ())
176         };
177         // SAFETY: TODO
178         let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
179         init
180     }})
181 }
182 
/// How fields that are not mentioned in the initializer are handled.
enum InitKind {
    /// No `..` was given: every field must be mentioned exactly once.
    Normal,
    /// `..Zeroable::init_zeroed()` was given: the memory is zeroed first, so fields may be
    /// omitted.
    Zeroing,
}
187 
188 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
189     let Some((dotdot, expr)) = rest else {
190         return InitKind::Normal;
191     };
192     match &expr {
193         Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
194             Expr::Path(ExprPath {
195                 attrs,
196                 qself: None,
197                 path:
198                     Path {
199                         leading_colon: None,
200                         segments,
201                     },
202             }) if attrs.is_empty()
203                 && segments.len() == 2
204                 && segments[0].ident == "Zeroable"
205                 && segments[0].arguments.is_none()
206                 && segments[1].ident == "init_zeroed"
207                 && segments[1].arguments.is_none() =>
208             {
209                 return InitKind::Zeroing;
210             }
211             _ => {}
212         },
213         _ => {}
214     }
215     dcx.error(
216         dotdot.span().join(expr.span()).unwrap_or(expr.span()),
217         "expected nothing or `..Zeroable::init_zeroed()`.",
218     );
219     InitKind::Normal
220 }
221 
222 /// Generate the code that initializes the fields of the struct using the initializers in `field`.
223 fn init_fields(
224     fields: &Punctuated<InitializerField, Token![,]>,
225     pinned: bool,
226     data: &Ident,
227     slot: &Ident,
228 ) -> TokenStream {
229     let mut guards = vec![];
230     let mut res = TokenStream::new();
231     for field in fields {
232         let init = match field {
233             InitializerField::Value { ident, value } => {
234                 let mut value_ident = ident.clone();
235                 let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
236                     // Setting the span of `value_ident` to `value`'s span improves error messages
237                     // when the type of `value` is wrong.
238                     value_ident.set_span(value.span());
239                     quote!(let #value_ident = #value;)
240                 });
241                 // Again span for better diagnostics
242                 let write = quote_spanned!(ident.span()=> ::core::ptr::write);
243                 let accessor = if pinned {
244                     let project_ident = format_ident!("__project_{ident}");
245                     quote! {
246                         // SAFETY: TODO
247                         unsafe { #data.#project_ident(&mut (*#slot).#ident) }
248                     }
249                 } else {
250                     quote! {
251                         // SAFETY: TODO
252                         unsafe { &mut (*#slot).#ident }
253                     }
254                 };
255                 quote! {
256                     {
257                         #value_prep
258                         // SAFETY: TODO
259                         unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
260                     }
261                     #[allow(unused_variables)]
262                     let #ident = #accessor;
263                 }
264             }
265             InitializerField::Init { ident, value, .. } => {
266                 // Again span for better diagnostics
267                 let init = format_ident!("init", span = value.span());
268                 if pinned {
269                     let project_ident = format_ident!("__project_{ident}");
270                     quote! {
271                         {
272                             let #init = #value;
273                             // SAFETY:
274                             // - `slot` is valid, because we are inside of an initializer closure, we
275                             //   return when an error/panic occurs.
276                             // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
277                             //   for `#ident`.
278                             unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
279                         }
280                         // SAFETY: TODO
281                         #[allow(unused_variables)]
282                         let #ident = unsafe { #data.#project_ident(&mut (*#slot).#ident) };
283                     }
284                 } else {
285                     quote! {
286                         {
287                             let #init = #value;
288                             // SAFETY: `slot` is valid, because we are inside of an initializer
289                             // closure, we return when an error/panic occurs.
290                             unsafe {
291                                 ::pin_init::Init::__init(
292                                     #init,
293                                     ::core::ptr::addr_of_mut!((*#slot).#ident),
294                                 )?
295                             };
296                         }
297                         // SAFETY: TODO
298                         #[allow(unused_variables)]
299                         let #ident = unsafe { &mut (*#slot).#ident };
300                     }
301                 }
302             }
303             InitializerField::Code { block: value, .. } => quote!(#[allow(unused_braces)] #value),
304         };
305         res.extend(init);
306         if let Some(ident) = field.ident() {
307             // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
308             let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
309             guards.push(guard.clone());
310             res.extend(quote! {
311                 // Create the drop guard:
312                 //
313                 // We rely on macro hygiene to make it impossible for users to access this local
314                 // variable.
315                 // SAFETY: We forget the guard later when initialization has succeeded.
316                 let #guard = unsafe {
317                     ::pin_init::__internal::DropGuard::new(
318                         ::core::ptr::addr_of_mut!((*slot).#ident)
319                     )
320                 };
321             });
322         }
323     }
324     quote! {
325         #res
326         // If execution reaches this point, all fields have been initialized. Therefore we can now
327         // dismiss the guards by forgetting them.
328         #(::core::mem::forget(#guards);)*
329     }
330 }
331 
332 /// Generate the check for ensuring that every field has been initialized.
333 fn make_field_check(
334     fields: &Punctuated<InitializerField, Token![,]>,
335     init_kind: InitKind,
336     path: &Path,
337 ) -> TokenStream {
338     let fields = fields.iter().filter_map(|f| f.ident());
339     match init_kind {
340         InitKind::Normal => quote! {
341             // We use unreachable code to ensure that all fields have been mentioned exactly once,
342             // this struct initializer will still be type-checked and complain with a very natural
343             // error message if a field is forgotten/mentioned more than once.
344             #[allow(unreachable_code, clippy::diverging_sub_expression)]
345             // SAFETY: this code is never executed.
346             let _ = || unsafe {
347                 ::core::ptr::write(slot, #path {
348                     #(
349                         #fields: ::core::panic!(),
350                     )*
351                 })
352             };
353         },
354         InitKind::Zeroing => quote! {
355             // We use unreachable code to ensure that all fields have been mentioned at most once.
356             // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
357             // be zeroed. This struct initializer will still be type-checked and complain with a
358             // very natural error message if a field is mentioned more than once, or doesn't exist.
359             #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
360             // SAFETY: this code is never executed.
361             let _ = || unsafe {
362                 let mut zeroed = ::core::mem::zeroed();
363                 // We have to use type inference here to make zeroed have the correct type. This
364                 // does not get executed, so it has no effect.
365                 ::core::ptr::write(slot, zeroed);
366                 zeroed = ::core::mem::zeroed();
367                 ::core::ptr::write(slot, #path {
368                     #(
369                         #fields: ::core::panic!(),
370                     )*
371                     ..zeroed
372                 })
373             };
374         },
375     }
376 }
377 
378 impl Parse for Initializer {
379     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
380         let attrs = input.call(Attribute::parse_outer)?;
381         let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
382         let path = input.parse()?;
383         let content;
384         let brace_token = braced!(content in input);
385         let mut fields = Punctuated::new();
386         loop {
387             let lh = content.lookahead1();
388             if lh.peek(End) || lh.peek(Token![..]) {
389                 break;
390             } else if lh.peek(Ident) || lh.peek(Token![_]) {
391                 fields.push_value(content.parse()?);
392                 let lh = content.lookahead1();
393                 if lh.peek(End) {
394                     break;
395                 } else if lh.peek(Token![,]) {
396                     fields.push_punct(content.parse()?);
397                 } else {
398                     return Err(lh.error());
399                 }
400             } else {
401                 return Err(lh.error());
402             }
403         }
404         let rest = content
405             .peek(Token![..])
406             .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
407             .transpose()?;
408         let error = input
409             .peek(Token![?])
410             .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
411             .transpose()?;
412         let attrs = attrs
413             .into_iter()
414             .map(|a| {
415                 if a.path().is_ident("default_error") {
416                     a.parse_args::<DefaultErrorAttribute>()
417                         .map(InitializerAttribute::DefaultError)
418                 } else {
419                     Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
420                 }
421             })
422             .collect::<Result<Vec<_>, _>>()?;
423         Ok(Self {
424             attrs,
425             this,
426             path,
427             brace_token,
428             fields,
429             rest,
430             error,
431         })
432     }
433 }
434 
435 impl Parse for DefaultErrorAttribute {
436     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
437         Ok(Self { ty: input.parse()? })
438     }
439 }
440 
441 impl Parse for This {
442     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
443         Ok(Self {
444             _and_token: input.parse()?,
445             ident: input.parse()?,
446             _in_token: input.parse()?,
447         })
448     }
449 }
450 
451 impl Parse for InitializerField {
452     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
453         let lh = input.lookahead1();
454         if lh.peek(Token![_]) {
455             Ok(Self::Code {
456                 _underscore_token: input.parse()?,
457                 _colon_token: input.parse()?,
458                 block: input.parse()?,
459             })
460         } else if lh.peek(Ident) {
461             let ident = input.parse()?;
462             let lh = input.lookahead1();
463             if lh.peek(Token![<-]) {
464                 Ok(Self::Init {
465                     ident,
466                     _left_arrow_token: input.parse()?,
467                     value: input.parse()?,
468                 })
469             } else if lh.peek(Token![:]) {
470                 Ok(Self::Value {
471                     ident,
472                     value: Some((input.parse()?, input.parse()?)),
473                 })
474             } else if lh.peek(Token![,]) || lh.peek(End) {
475                 Ok(Self::Value { ident, value: None })
476             } else {
477                 Err(lh.error())
478             }
479         } else {
480             Err(lh.error())
481         }
482     }
483 }
484