xref: /linux/rust/pin-init/internal/src/init.rs (revision 23b0f90ba871f096474e1c27c3d14f455189d2d9)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::{format_ident, quote, quote_spanned};
5 use syn::{
6     braced,
7     parse::{End, Parse},
8     parse_quote,
9     punctuated::Punctuated,
10     spanned::Spanned,
11     token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12 };
13 
14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15 
/// Parsed form of an initializer macro invocation:
/// `#[attrs]* (&this in)? Path { fields (, ..rest)? } (? Error)?`.
pub(crate) struct Initializer {
    /// Outer attributes placed before the initializer (`#[default_error(..)]`,
    /// `#[disable_initialized_field_access]`).
    attrs: Vec<InitializerAttribute>,
    /// Optional `&<ident> in` prefix naming a pointer to the value being initialized.
    this: Option<This>,
    /// Path of the struct being initialized.
    path: Path,
    /// The `{ ... }` delimiters; the closing brace is used as the span of the
    /// "missing `? <type>`" diagnostic.
    brace_token: token::Brace,
    /// Comma-separated field entries.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional `..expr` tail; expansion only accepts `..Zeroable::init_zeroed()`.
    rest: Option<(Token![..], Expr)>,
    /// Optional `? ErrorType` suffix after the closing brace.
    error: Option<(Token![?], Type)>,
}
25 
/// The optional `&<ident> in` prefix of an initializer.
///
/// In the generated code `ident` is bound to a `::core::ptr::NonNull` pointing at the value
/// being initialized, so field expressions can refer to it.
struct This {
    _and_token: Token![&],
    ident: Ident,
    _in_token: Token![in],
}
31 
/// A single entry of the initializer's `{ ... }` list together with its outer attributes.
struct InitializerField {
    /// Outer attributes of the entry. All of them are applied to the generated initialization
    /// code; the `#[cfg(...)]` ones are additionally repeated on the generated accessor and
    /// drop guard.
    attrs: Vec<Attribute>,
    kind: InitializerKind,
}
36 
/// The syntactic forms a field entry can take.
enum InitializerKind {
    /// `field: value` writes `value` into the field. The shorthand `field` (no `: value`)
    /// writes the local variable of the same name.
    Value {
        ident: Ident,
        value: Option<(Token![:], Expr)>,
    },
    /// `field <- init` initializes the field in place with the initializer `init`.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { ... }` inserts the block as-is at this point of the initialization.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}
53 
54 impl InitializerKind {
55     fn ident(&self) -> Option<&Ident> {
56         match self {
57             Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58             Self::Code { .. } => None,
59         }
60     }
61 }
62 
/// Attributes accepted in front of an initializer.
enum InitializerAttribute {
    /// `#[default_error(<type>)]`: error type used when no `? <type>` suffix is present.
    DefaultError(DefaultErrorAttribute),
    /// `#[disable_initialized_field_access]`: do not rebind already-initialized fields as
    /// local variables for the following field expressions.
    DisableInitializedFieldAccess,
}
67 
/// Argument of `#[default_error(<type>)]`: the error type to fall back to.
struct DefaultErrorAttribute {
    ty: Box<Type>,
}
71 
72 pub(crate) fn expand(
73     Initializer {
74         attrs,
75         this,
76         path,
77         brace_token,
78         fields,
79         rest,
80         error,
81     }: Initializer,
82     default_error: Option<&'static str>,
83     pinned: bool,
84     dcx: &mut DiagCtxt,
85 ) -> Result<TokenStream, ErrorGuaranteed> {
86     let error = error.map_or_else(
87         || {
88             if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
89                 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90                     Some(ty.clone())
91                 } else {
92                     acc
93                 }
94             }) {
95                 default_error
96             } else if let Some(default_error) = default_error {
97                 syn::parse_str(default_error).unwrap()
98             } else {
99                 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100                 parse_quote!(::core::convert::Infallible)
101             }
102         },
103         |(_, err)| Box::new(err),
104     );
105     let slot = format_ident!("slot");
106     let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
107         (
108             format_ident!("HasPinData"),
109             format_ident!("PinData"),
110             format_ident!("__pin_data"),
111             format_ident!("pin_init_from_closure"),
112         )
113     } else {
114         (
115             format_ident!("HasInitData"),
116             format_ident!("InitData"),
117             format_ident!("__init_data"),
118             format_ident!("init_from_closure"),
119         )
120     };
121     let init_kind = get_init_kind(rest, dcx);
122     let zeroable_check = match init_kind {
123         InitKind::Normal => quote!(),
124         InitKind::Zeroing => quote! {
125             // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
126             // Therefore we check if the struct implements `Zeroable` and then zero the memory.
127             // This allows us to also remove the check that all fields are present (since we
128             // already set the memory to zero and that is a valid bit pattern).
129             fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
130             where T: ::pin_init::Zeroable
131             {}
132             // Ensure that the struct is indeed `Zeroable`.
133             assert_zeroable(#slot);
134             // SAFETY: The type implements `Zeroable` by the check above.
135             unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
136         },
137     };
138     let this = match this {
139         None => quote!(),
140         Some(This { ident, .. }) => quote! {
141             // Create the `this` so it can be referenced by the user inside of the
142             // expressions creating the individual fields.
143             let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
144         },
145     };
146     // `mixed_site` ensures that the data is not accessible to the user-controlled code.
147     let data = Ident::new("__data", Span::mixed_site());
148     let init_fields = init_fields(
149         &fields,
150         pinned,
151         !attrs
152             .iter()
153             .any(|attr| matches!(attr, InitializerAttribute::DisableInitializedFieldAccess)),
154         &data,
155         &slot,
156     );
157     let field_check = make_field_check(&fields, init_kind, &path);
158     Ok(quote! {{
159         // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return
160         // type and shadow it later when we insert the arbitrary user code. That way there will be
161         // no possibility of returning without `unsafe`.
162         struct __InitOk;
163 
164         // Get the data about fields from the supplied type.
165         // SAFETY: TODO
166         let #data = unsafe {
167             use ::pin_init::__internal::#has_data_trait;
168             // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
169             // generics (which need to be present with that syntax).
170             #path::#get_data()
171         };
172         // Ensure that `#data` really is of type `#data` and help with type inference:
173         let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>(
174             #data,
175             move |slot| {
176                 {
177                     // Shadow the structure so it cannot be used to return early.
178                     struct __InitOk;
179                     #zeroable_check
180                     #this
181                     #init_fields
182                     #field_check
183                 }
184                 Ok(__InitOk)
185             }
186         );
187         let init = move |slot| -> ::core::result::Result<(), #error> {
188             init(slot).map(|__InitOk| ())
189         };
190         // SAFETY: TODO
191         let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
192         init
193     }})
194 }
195 
/// How fields that are not mentioned in the initializer are handled.
enum InitKind {
    /// No `..` tail: every field has to be initialized explicitly.
    Normal,
    /// `..Zeroable::init_zeroed()` tail: the memory is zeroed first, so fields may be omitted.
    Zeroing,
}
200 
201 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
202     let Some((dotdot, expr)) = rest else {
203         return InitKind::Normal;
204     };
205     match &expr {
206         Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
207             Expr::Path(ExprPath {
208                 attrs,
209                 qself: None,
210                 path:
211                     Path {
212                         leading_colon: None,
213                         segments,
214                     },
215             }) if attrs.is_empty()
216                 && segments.len() == 2
217                 && segments[0].ident == "Zeroable"
218                 && segments[0].arguments.is_none()
219                 && segments[1].ident == "init_zeroed"
220                 && segments[1].arguments.is_none() =>
221             {
222                 return InitKind::Zeroing;
223             }
224             _ => {}
225         },
226         _ => {}
227     }
228     dcx.error(
229         dotdot.span().join(expr.span()).unwrap_or(expr.span()),
230         "expected nothing or `..Zeroable::init_zeroed()`.",
231     );
232     InitKind::Normal
233 }
234 
235 /// Generate the code that initializes the fields of the struct using the initializers in `field`.
236 fn init_fields(
237     fields: &Punctuated<InitializerField, Token![,]>,
238     pinned: bool,
239     generate_initialized_accessors: bool,
240     data: &Ident,
241     slot: &Ident,
242 ) -> TokenStream {
243     let mut guards = vec![];
244     let mut guard_attrs = vec![];
245     let mut res = TokenStream::new();
246     for InitializerField { attrs, kind } in fields {
247         let cfgs = {
248             let mut cfgs = attrs.clone();
249             cfgs.retain(|attr| attr.path().is_ident("cfg"));
250             cfgs
251         };
252         let init = match kind {
253             InitializerKind::Value { ident, value } => {
254                 let mut value_ident = ident.clone();
255                 let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
256                     // Setting the span of `value_ident` to `value`'s span improves error messages
257                     // when the type of `value` is wrong.
258                     value_ident.set_span(value.span());
259                     quote!(let #value_ident = #value;)
260                 });
261                 // Again span for better diagnostics
262                 let write = quote_spanned!(ident.span()=> ::core::ptr::write);
263                 let accessor = if pinned {
264                     let project_ident = format_ident!("__project_{ident}");
265                     quote! {
266                         // SAFETY: TODO
267                         unsafe { #data.#project_ident(&mut (*#slot).#ident) }
268                     }
269                 } else {
270                     quote! {
271                         // SAFETY: TODO
272                         unsafe { &mut (*#slot).#ident }
273                     }
274                 };
275                 let accessor = generate_initialized_accessors.then(|| {
276                     quote! {
277                         #(#cfgs)*
278                         #[allow(unused_variables)]
279                         let #ident = #accessor;
280                     }
281                 });
282                 quote! {
283                     #(#attrs)*
284                     {
285                         #value_prep
286                         // SAFETY: TODO
287                         unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
288                     }
289                     #accessor
290                 }
291             }
292             InitializerKind::Init { ident, value, .. } => {
293                 // Again span for better diagnostics
294                 let init = format_ident!("init", span = value.span());
295                 let (value_init, accessor) = if pinned {
296                     let project_ident = format_ident!("__project_{ident}");
297                     (
298                         quote! {
299                             // SAFETY:
300                             // - `slot` is valid, because we are inside of an initializer closure, we
301                             //   return when an error/panic occurs.
302                             // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
303                             //   for `#ident`.
304                             unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
305                         },
306                         quote! {
307                             // SAFETY: TODO
308                             unsafe { #data.#project_ident(&mut (*#slot).#ident) }
309                         },
310                     )
311                 } else {
312                     (
313                         quote! {
314                             // SAFETY: `slot` is valid, because we are inside of an initializer
315                             // closure, we return when an error/panic occurs.
316                             unsafe {
317                                 ::pin_init::Init::__init(
318                                     #init,
319                                     ::core::ptr::addr_of_mut!((*#slot).#ident),
320                                 )?
321                             };
322                         },
323                         quote! {
324                             // SAFETY: TODO
325                             unsafe { &mut (*#slot).#ident }
326                         },
327                     )
328                 };
329                 let accessor = generate_initialized_accessors.then(|| {
330                     quote! {
331                         #(#cfgs)*
332                         #[allow(unused_variables)]
333                         let #ident = #accessor;
334                     }
335                 });
336                 quote! {
337                     #(#attrs)*
338                     {
339                         let #init = #value;
340                         #value_init
341                     }
342                     #accessor
343                 }
344             }
345             InitializerKind::Code { block: value, .. } => quote! {
346                 #(#attrs)*
347                 #[allow(unused_braces)]
348                 #value
349             },
350         };
351         res.extend(init);
352         if let Some(ident) = kind.ident() {
353             // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
354             let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
355             res.extend(quote! {
356                 #(#cfgs)*
357                 // Create the drop guard:
358                 //
359                 // We rely on macro hygiene to make it impossible for users to access this local
360                 // variable.
361                 // SAFETY: We forget the guard later when initialization has succeeded.
362                 let #guard = unsafe {
363                     ::pin_init::__internal::DropGuard::new(
364                         ::core::ptr::addr_of_mut!((*slot).#ident)
365                     )
366                 };
367             });
368             guards.push(guard);
369             guard_attrs.push(cfgs);
370         }
371     }
372     quote! {
373         #res
374         // If execution reaches this point, all fields have been initialized. Therefore we can now
375         // dismiss the guards by forgetting them.
376         #(
377             #(#guard_attrs)*
378             ::core::mem::forget(#guards);
379         )*
380     }
381 }
382 
383 /// Generate the check for ensuring that every field has been initialized.
384 fn make_field_check(
385     fields: &Punctuated<InitializerField, Token![,]>,
386     init_kind: InitKind,
387     path: &Path,
388 ) -> TokenStream {
389     let field_attrs = fields
390         .iter()
391         .filter_map(|f| f.kind.ident().map(|_| &f.attrs));
392     let field_name = fields.iter().filter_map(|f| f.kind.ident());
393     match init_kind {
394         InitKind::Normal => quote! {
395             // We use unreachable code to ensure that all fields have been mentioned exactly once,
396             // this struct initializer will still be type-checked and complain with a very natural
397             // error message if a field is forgotten/mentioned more than once.
398             #[allow(unreachable_code, clippy::diverging_sub_expression)]
399             // SAFETY: this code is never executed.
400             let _ = || unsafe {
401                 ::core::ptr::write(slot, #path {
402                     #(
403                         #(#field_attrs)*
404                         #field_name: ::core::panic!(),
405                     )*
406                 })
407             };
408         },
409         InitKind::Zeroing => quote! {
410             // We use unreachable code to ensure that all fields have been mentioned at most once.
411             // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
412             // be zeroed. This struct initializer will still be type-checked and complain with a
413             // very natural error message if a field is mentioned more than once, or doesn't exist.
414             #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
415             // SAFETY: this code is never executed.
416             let _ = || unsafe {
417                 ::core::ptr::write(slot, #path {
418                     #(
419                         #(#field_attrs)*
420                         #field_name: ::core::panic!(),
421                     )*
422                     ..::core::mem::zeroed()
423                 })
424             };
425         },
426     }
427 }
428 
429 impl Parse for Initializer {
430     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
431         let attrs = input.call(Attribute::parse_outer)?;
432         let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
433         let path = input.parse()?;
434         let content;
435         let brace_token = braced!(content in input);
436         let mut fields = Punctuated::new();
437         loop {
438             let lh = content.lookahead1();
439             if lh.peek(End) || lh.peek(Token![..]) {
440                 break;
441             } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
442                 fields.push_value(content.parse()?);
443                 let lh = content.lookahead1();
444                 if lh.peek(End) {
445                     break;
446                 } else if lh.peek(Token![,]) {
447                     fields.push_punct(content.parse()?);
448                 } else {
449                     return Err(lh.error());
450                 }
451             } else {
452                 return Err(lh.error());
453             }
454         }
455         let rest = content
456             .peek(Token![..])
457             .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
458             .transpose()?;
459         let error = input
460             .peek(Token![?])
461             .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
462             .transpose()?;
463         let attrs = attrs
464             .into_iter()
465             .map(|a| {
466                 if a.path().is_ident("default_error") {
467                     a.parse_args::<DefaultErrorAttribute>()
468                         .map(InitializerAttribute::DefaultError)
469                 } else if a.path().is_ident("disable_initialized_field_access") {
470                     a.meta
471                         .require_path_only()
472                         .map(|_| InitializerAttribute::DisableInitializedFieldAccess)
473                 } else {
474                     Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
475                 }
476             })
477             .collect::<Result<Vec<_>, _>>()?;
478         Ok(Self {
479             attrs,
480             this,
481             path,
482             brace_token,
483             fields,
484             rest,
485             error,
486         })
487     }
488 }
489 
490 impl Parse for DefaultErrorAttribute {
491     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
492         Ok(Self { ty: input.parse()? })
493     }
494 }
495 
496 impl Parse for This {
497     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
498         Ok(Self {
499             _and_token: input.parse()?,
500             ident: input.parse()?,
501             _in_token: input.parse()?,
502         })
503     }
504 }
505 
506 impl Parse for InitializerField {
507     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
508         let attrs = input.call(Attribute::parse_outer)?;
509         Ok(Self {
510             attrs,
511             kind: input.parse()?,
512         })
513     }
514 }
515 
516 impl Parse for InitializerKind {
517     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
518         let lh = input.lookahead1();
519         if lh.peek(Token![_]) {
520             Ok(Self::Code {
521                 _underscore_token: input.parse()?,
522                 _colon_token: input.parse()?,
523                 block: input.parse()?,
524             })
525         } else if lh.peek(Ident) {
526             let ident = input.parse()?;
527             let lh = input.lookahead1();
528             if lh.peek(Token![<-]) {
529                 Ok(Self::Init {
530                     ident,
531                     _left_arrow_token: input.parse()?,
532                     value: input.parse()?,
533                 })
534             } else if lh.peek(Token![:]) {
535                 Ok(Self::Value {
536                     ident,
537                     value: Some((input.parse()?, input.parse()?)),
538                 })
539             } else if lh.peek(Token![,]) || lh.peek(End) {
540                 Ok(Self::Value { ident, value: None })
541             } else {
542                 Err(lh.error())
543             }
544         } else {
545             Err(lh.error())
546         }
547     }
548 }
549