xref: /linux/rust/pin-init/internal/src/init.rs (revision 5423ef9d4db852835746001d0840231227bb0e39)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::{format_ident, quote};
5 use syn::{
6     braced,
7     parse::{End, Parse},
8     parse_quote,
9     punctuated::Punctuated,
10     spanned::Spanned,
11     token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12 };
13 
14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15 
/// A parsed initializer, i.e. the input of the initializer macros:
/// `#[default_error(<type>)]* (&this in)? Path { fields.., (..rest)? } (? <type>)?`.
pub(crate) struct Initializer {
    /// Attributes written before the initializer; only `#[default_error(<type>)]` is accepted.
    attrs: Vec<InitializerAttribute>,
    /// Optional `&<ident> in` prefix naming a pointer to the value under construction.
    this: Option<This>,
    /// Path of the type that is being initialized.
    path: Path,
    /// The braces around the field list; the closing span is used for diagnostics.
    brace_token: token::Brace,
    /// The comma-separated field entries.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional trailing `..expr`; expansion only accepts `..Zeroable::init_zeroed()`.
    rest: Option<(Token![..], Expr)>,
    /// Optional `? <type>` after the closing brace, naming the initializer's error type.
    error: Option<(Token![?], Type)>,
}
25 
/// The `&<ident> in` prefix of an initializer.
///
/// During expansion `<ident>` is bound to a `NonNull` made from the slot pointer, so the
/// expressions that create the individual fields can refer to the value being initialized.
struct This {
    _and_token: Token![&],
    /// Name under which the pointer to the value being initialized is made available.
    ident: Ident,
    _in_token: Token![in],
}
31 
/// One entry of the field list together with its outer attributes (e.g. `#[cfg(..)]`).
struct InitializerField {
    /// Outer attributes written before the entry.
    attrs: Vec<Attribute>,
    /// The entry itself.
    kind: InitializerKind,
}
36 
/// The different forms an entry of the field list can take.
enum InitializerKind {
    /// `field: value`, or the shorthand `field` (meaning `field: field`); the value is written
    /// directly into the field.
    Value {
        ident: Ident,
        value: Option<(Token![:], Expr)>,
    },
    /// `field <- initializer`; the field is initialized in place by the given initializer.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { ... }`; arbitrary code that is emitted verbatim between the field initializations.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}
53 
54 impl InitializerKind {
55     fn ident(&self) -> Option<&Ident> {
56         match self {
57             Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
58             Self::Code { .. } => None,
59         }
60     }
61 }
62 
/// An attribute that is accepted in front of an initializer.
enum InitializerAttribute {
    /// `#[default_error(<type>)]`
    DefaultError(DefaultErrorAttribute),
}
66 
/// Argument of `#[default_error(<type>)]`.
///
/// The type is used as the initializer's error type when no explicit `? <type>` is given.
struct DefaultErrorAttribute {
    ty: Box<Type>,
}
70 
71 pub(crate) fn expand(
72     Initializer {
73         attrs,
74         this,
75         path,
76         brace_token,
77         fields,
78         rest,
79         error,
80     }: Initializer,
81     default_error: Option<&'static str>,
82     pinned: bool,
83     dcx: &mut DiagCtxt,
84 ) -> Result<TokenStream, ErrorGuaranteed> {
85     let error = error.map_or_else(
86         || {
87             if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
88                 #[expect(irrefutable_let_patterns)]
89                 if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
90                     Some(ty.clone())
91                 } else {
92                     acc
93                 }
94             }) {
95                 default_error
96             } else if let Some(default_error) = default_error {
97                 syn::parse_str(default_error).unwrap()
98             } else {
99                 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
100                 parse_quote!(::core::convert::Infallible)
101             }
102         },
103         |(_, err)| Box::new(err),
104     );
105     let slot = format_ident!("slot");
106     let (has_data_trait, get_data, init_from_closure) = if pinned {
107         (
108             format_ident!("HasPinData"),
109             format_ident!("__pin_data"),
110             format_ident!("pin_init_from_closure"),
111         )
112     } else {
113         (
114             format_ident!("HasInitData"),
115             format_ident!("__init_data"),
116             format_ident!("init_from_closure"),
117         )
118     };
119     let init_kind = get_init_kind(rest, dcx);
120     let zeroable_check = match init_kind {
121         InitKind::Normal => quote!(),
122         InitKind::Zeroing => quote! {
123             // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
124             // Therefore we check if the struct implements `Zeroable` and then zero the memory.
125             // This allows us to also remove the check that all fields are present (since we
126             // already set the memory to zero and that is a valid bit pattern).
127             fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
128             where T: ::pin_init::Zeroable
129             {}
130             // Ensure that the struct is indeed `Zeroable`.
131             assert_zeroable(#slot);
132             // SAFETY: The type implements `Zeroable` by the check above.
133             unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
134         },
135     };
136     let this = match this {
137         None => quote!(),
138         Some(This { ident, .. }) => quote! {
139             // Create the `this` so it can be referenced by the user inside of the
140             // expressions creating the individual fields.
141             let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
142         },
143     };
144     // `mixed_site` ensures that the data is not accessible to the user-controlled code.
145     let data = Ident::new("__data", Span::mixed_site());
146     let init_fields = init_fields(&fields, pinned, &data, &slot);
147     let field_check = make_field_check(&fields, init_kind, &path);
148     Ok(quote! {{
149         // Get the data about fields from the supplied type.
150         // SAFETY: TODO
151         let #data = unsafe {
152             use ::pin_init::__internal::#has_data_trait;
153             // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
154             // generics (which need to be present with that syntax).
155             #path::#get_data()
156         };
157         // Ensure that `#data` really is of type `#data` and help with type inference:
158         let init = #data.__make_closure::<_, #error>(
159             move |slot| {
160                 #zeroable_check
161                 #this
162                 #init_fields
163                 #field_check
164                 // SAFETY: we are the `init!` macro that is allowed to call this.
165                 Ok(unsafe { ::pin_init::__internal::InitOk::new() })
166             }
167         );
168         let init = move |slot| -> ::core::result::Result<(), #error> {
169             init(slot).map(|__InitOk| ())
170         };
171         // SAFETY: TODO
172         unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }
173     }})
174 }
175 
/// How the fields not mentioned in the initializer are handled.
enum InitKind {
    /// No trailing `..expr`: every field must be mentioned exactly once.
    Normal,
    /// Trailing `..Zeroable::init_zeroed()`: the memory is zeroed first, so fields may be omitted.
    Zeroing,
}
180 
181 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
182     let Some((dotdot, expr)) = rest else {
183         return InitKind::Normal;
184     };
185     match &expr {
186         Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
187             Expr::Path(ExprPath {
188                 attrs,
189                 qself: None,
190                 path:
191                     Path {
192                         leading_colon: None,
193                         segments,
194                     },
195             }) if attrs.is_empty()
196                 && segments.len() == 2
197                 && segments[0].ident == "Zeroable"
198                 && segments[0].arguments.is_none()
199                 && segments[1].ident == "init_zeroed"
200                 && segments[1].arguments.is_none() =>
201             {
202                 return InitKind::Zeroing;
203             }
204             _ => {}
205         },
206         _ => {}
207     }
208     dcx.error(
209         dotdot.span().join(expr.span()).unwrap_or(expr.span()),
210         "expected nothing or `..Zeroable::init_zeroed()`.",
211     );
212     InitKind::Normal
213 }
214 
215 /// Generate the code that initializes the fields of the struct using the initializers in `field`.
216 fn init_fields(
217     fields: &Punctuated<InitializerField, Token![,]>,
218     pinned: bool,
219     data: &Ident,
220     slot: &Ident,
221 ) -> TokenStream {
222     let mut guards = vec![];
223     let mut guard_attrs = vec![];
224     let mut res = TokenStream::new();
225     for InitializerField { attrs, kind } in fields {
226         let cfgs = {
227             let mut cfgs = attrs.clone();
228             cfgs.retain(|attr| attr.path().is_ident("cfg"));
229             cfgs
230         };
231 
232         let ident = match kind {
233             InitializerKind::Value { ident, .. } => ident,
234             InitializerKind::Init { ident, .. } => ident,
235             InitializerKind::Code { block, .. } => {
236                 res.extend(quote! {
237                     #(#attrs)*
238                     #[allow(unused_braces)]
239                     #block
240                 });
241                 continue;
242             }
243         };
244 
245         let slot = if pinned {
246             quote! {
247                 // SAFETY:
248                 // - `slot` is valid and properly aligned.
249                 // - `make_field_check` checks that `&raw mut (*slot).#ident` is properly aligned.
250                 // - `make_field_check` prevents `#ident` from being used twice, therefore
251                 //   `(*slot).#ident` is exclusively accessed and has not been initialized.
252                 (unsafe { #data.#ident(#slot) })
253             }
254         } else {
255             quote! {
256                 // For `init!()` macro, everything is unpinned.
257                 // SAFETY:
258                 // - `&raw mut (*slot).#ident` is valid.
259                 // - `make_field_check` checks that `&raw mut (*slot).#ident` is properly aligned.
260                 // - `make_field_check` prevents `#ident` from being used twice, therefore
261                 //   `(*slot).#ident` is exclusively accessed and has not been initialized.
262                 (unsafe {
263                     ::pin_init::__internal::Slot::<::pin_init::__internal::Unpinned, _>::new(
264                         &raw mut (*#slot).#ident
265                     )
266                 })
267             }
268         };
269 
270         // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
271         let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
272 
273         let init = match kind {
274             InitializerKind::Value { ident, value } => {
275                 let value = value
276                     .as_ref()
277                     .map(|(_, value)| quote!(#value))
278                     .unwrap_or_else(|| quote!(#ident));
279 
280                 quote! {
281                     #(#attrs)*
282                     let mut #guard = #slot.write(#value);
283 
284                 }
285             }
286             InitializerKind::Init { value, .. } => {
287                 quote! {
288                     #(#attrs)*
289                     let mut #guard = #slot.init(#value)?;
290                 }
291             }
292             InitializerKind::Code { .. } => unreachable!(),
293         };
294 
295         res.extend(quote! {
296             #init
297 
298             #(#cfgs)*
299             // Allow `non_snake_case` since the same warning is going to be reported for the struct
300             // field.
301             #[allow(unused_variables, non_snake_case)]
302             let #ident = #guard.let_binding();
303         });
304 
305         guards.push(guard);
306         guard_attrs.push(cfgs);
307     }
308     quote! {
309         #res
310         // If execution reaches this point, all fields have been initialized. Therefore we can now
311         // dismiss the guards by forgetting them.
312         #(
313             #(#guard_attrs)*
314             ::core::mem::forget(#guards);
315         )*
316     }
317 }
318 
319 /// Generate the check for ensuring that every field has been initialized and aligned.
320 fn make_field_check(
321     fields: &Punctuated<InitializerField, Token![,]>,
322     init_kind: InitKind,
323     path: &Path,
324 ) -> TokenStream {
325     let field_attrs: Vec<_> = fields
326         .iter()
327         .filter_map(|f| f.kind.ident().map(|_| &f.attrs))
328         .collect();
329     let field_name: Vec<_> = fields.iter().filter_map(|f| f.kind.ident()).collect();
330     let zeroing_trailer = match init_kind {
331         InitKind::Normal => None,
332         InitKind::Zeroing => Some(quote! {
333             ..::core::mem::zeroed()
334         }),
335     };
336     quote! {
337         #[allow(unreachable_code, clippy::diverging_sub_expression)]
338         // We use unreachable code to perform field checks. They're still checked by the compiler.
339         // SAFETY: this code is never executed.
340         let _ = || unsafe {
341             // Create references to ensure that the initialized field is properly aligned.
342             // Unaligned fields will cause the compiler to emit E0793. We do not support
343             // unaligned fields since `Init::__init` requires an aligned pointer; the call to
344             // `ptr::write` for value-initialization case has the same requirement.
345             #(
346                 #(#field_attrs)*
347                 let _ = &(*slot).#field_name;
348             )*
349 
350             // If the zeroing trailer is not present, this checks that all fields have been
351             // mentioned exactly once. If the zeroing trailer is present, all missing fields will be
352             // zeroed, so this checks that all fields have been mentioned at most once. The use of
353             // struct initializer will still generate very natural error messages for any misuse.
354             ::core::ptr::write(slot, #path {
355                 #(
356                     #(#field_attrs)*
357                     #field_name: ::core::panic!(),
358                 )*
359                 #zeroing_trailer
360             })
361         };
362     }
363 }
364 
365 impl Parse for Initializer {
366     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
367         let attrs = input.call(Attribute::parse_outer)?;
368         let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
369         let path = input.parse()?;
370         let content;
371         let brace_token = braced!(content in input);
372         let mut fields = Punctuated::new();
373         loop {
374             let lh = content.lookahead1();
375             if lh.peek(End) || lh.peek(Token![..]) {
376                 break;
377             } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
378                 fields.push_value(content.parse()?);
379                 let lh = content.lookahead1();
380                 if lh.peek(End) {
381                     break;
382                 } else if lh.peek(Token![,]) {
383                     fields.push_punct(content.parse()?);
384                 } else {
385                     return Err(lh.error());
386                 }
387             } else {
388                 return Err(lh.error());
389             }
390         }
391         let rest = content
392             .peek(Token![..])
393             .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
394             .transpose()?;
395         let error = input
396             .peek(Token![?])
397             .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
398             .transpose()?;
399         let attrs = attrs
400             .into_iter()
401             .map(|a| {
402                 if a.path().is_ident("default_error") {
403                     a.parse_args::<DefaultErrorAttribute>()
404                         .map(InitializerAttribute::DefaultError)
405                 } else {
406                     Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
407                 }
408             })
409             .collect::<Result<Vec<_>, _>>()?;
410         Ok(Self {
411             attrs,
412             this,
413             path,
414             brace_token,
415             fields,
416             rest,
417             error,
418         })
419     }
420 }
421 
422 impl Parse for DefaultErrorAttribute {
423     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
424         Ok(Self { ty: input.parse()? })
425     }
426 }
427 
428 impl Parse for This {
429     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
430         Ok(Self {
431             _and_token: input.parse()?,
432             ident: input.parse()?,
433             _in_token: input.parse()?,
434         })
435     }
436 }
437 
438 impl Parse for InitializerField {
439     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
440         let attrs = input.call(Attribute::parse_outer)?;
441         Ok(Self {
442             attrs,
443             kind: input.parse()?,
444         })
445     }
446 }
447 
448 impl Parse for InitializerKind {
449     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
450         let lh = input.lookahead1();
451         if lh.peek(Token![_]) {
452             Ok(Self::Code {
453                 _underscore_token: input.parse()?,
454                 _colon_token: input.parse()?,
455                 block: input.parse()?,
456             })
457         } else if lh.peek(Ident) {
458             let ident = input.parse()?;
459             let lh = input.lookahead1();
460             if lh.peek(Token![<-]) {
461                 Ok(Self::Init {
462                     ident,
463                     _left_arrow_token: input.parse()?,
464                     value: input.parse()?,
465                 })
466             } else if lh.peek(Token![:]) {
467                 Ok(Self::Value {
468                     ident,
469                     value: Some((input.parse()?, input.parse()?)),
470                 })
471             } else if lh.peek(Token![,]) || lh.peek(End) {
472                 Ok(Self::Value { ident, value: None })
473             } else {
474                 Err(lh.error())
475             }
476         } else {
477             Err(lh.error())
478         }
479     }
480 }
481