xref: /linux/rust/pin-init/internal/src/init.rs (revision 4883830e9784bdf6223fe0e5f1ea36d4a4ab4fef)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::{Span, TokenStream};
4 use quote::{format_ident, quote, quote_spanned};
5 use syn::{
6     braced,
7     parse::{End, Parse},
8     parse_quote,
9     punctuated::Punctuated,
10     spanned::Spanned,
11     token, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
12 };
13 
14 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
15 
16 pub(crate) struct Initializer {
17     this: Option<This>,
18     path: Path,
19     brace_token: token::Brace,
20     fields: Punctuated<InitializerField, Token![,]>,
21     rest: Option<(Token![..], Expr)>,
22     error: Option<(Token![?], Type)>,
23 }
24 
25 struct This {
26     _and_token: Token![&],
27     ident: Ident,
28     _in_token: Token![in],
29 }
30 
31 enum InitializerField {
32     Value {
33         ident: Ident,
34         value: Option<(Token![:], Expr)>,
35     },
36     Init {
37         ident: Ident,
38         _left_arrow_token: Token![<-],
39         value: Expr,
40     },
41     Code {
42         _underscore_token: Token![_],
43         _colon_token: Token![:],
44         block: Block,
45     },
46 }
47 
48 impl InitializerField {
49     fn ident(&self) -> Option<&Ident> {
50         match self {
51             Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
52             Self::Code { .. } => None,
53         }
54     }
55 }
56 
57 pub(crate) fn expand(
58     Initializer {
59         this,
60         path,
61         brace_token,
62         fields,
63         rest,
64         error,
65     }: Initializer,
66     default_error: Option<&'static str>,
67     pinned: bool,
68     dcx: &mut DiagCtxt,
69 ) -> Result<TokenStream, ErrorGuaranteed> {
70     let error = error.map_or_else(
71         || {
72             if let Some(default_error) = default_error {
73                 syn::parse_str(default_error).unwrap()
74             } else {
75                 dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
76                 parse_quote!(::core::convert::Infallible)
77             }
78         },
79         |(_, err)| err,
80     );
81     let slot = format_ident!("slot");
82     let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
83         (
84             format_ident!("HasPinData"),
85             format_ident!("PinData"),
86             format_ident!("__pin_data"),
87             format_ident!("pin_init_from_closure"),
88         )
89     } else {
90         (
91             format_ident!("HasInitData"),
92             format_ident!("InitData"),
93             format_ident!("__init_data"),
94             format_ident!("init_from_closure"),
95         )
96     };
97     let init_kind = get_init_kind(rest, dcx);
98     let zeroable_check = match init_kind {
99         InitKind::Normal => quote!(),
100         InitKind::Zeroing => quote! {
101             // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
102             // Therefore we check if the struct implements `Zeroable` and then zero the memory.
103             // This allows us to also remove the check that all fields are present (since we
104             // already set the memory to zero and that is a valid bit pattern).
105             fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
106             where T: ::pin_init::Zeroable
107             {}
108             // Ensure that the struct is indeed `Zeroable`.
109             assert_zeroable(#slot);
110             // SAFETY: The type implements `Zeroable` by the check above.
111             unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
112         },
113     };
114     let this = match this {
115         None => quote!(),
116         Some(This { ident, .. }) => quote! {
117             // Create the `this` so it can be referenced by the user inside of the
118             // expressions creating the individual fields.
119             let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
120         },
121     };
122     // `mixed_site` ensures that the data is not accessible to the user-controlled code.
123     let data = Ident::new("__data", Span::mixed_site());
124     let init_fields = init_fields(&fields, pinned, &data, &slot);
125     let field_check = make_field_check(&fields, init_kind, &path);
126     Ok(quote! {{
127         // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return
128         // type and shadow it later when we insert the arbitrary user code. That way there will be
129         // no possibility of returning without `unsafe`.
130         struct __InitOk;
131 
132         // Get the data about fields from the supplied type.
133         // SAFETY: TODO
134         let #data = unsafe {
135             use ::pin_init::__internal::#has_data_trait;
136             // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
137             // generics (which need to be present with that syntax).
138             #path::#get_data()
139         };
140         // Ensure that `#data` really is of type `#data` and help with type inference:
141         let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>(
142             #data,
143             move |slot| {
144                 {
145                     // Shadow the structure so it cannot be used to return early.
146                     struct __InitOk;
147                     #zeroable_check
148                     #this
149                     #init_fields
150                     #field_check
151                 }
152                 Ok(__InitOk)
153             }
154         );
155         let init = move |slot| -> ::core::result::Result<(), #error> {
156             init(slot).map(|__InitOk| ())
157         };
158         // SAFETY: TODO
159         let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
160         init
161     }})
162 }
163 
/// How fields that are not mentioned in an initializer are handled.
#[derive(Clone, Copy)]
enum InitKind {
    /// No `..` clause: every field must be listed explicitly.
    Normal,
    /// `..Zeroable::init_zeroed()`: the whole struct is zeroed first, so unlisted fields stay
    /// zero.
    Zeroing,
}
168 
169 fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
170     let Some((dotdot, expr)) = rest else {
171         return InitKind::Normal;
172     };
173     match &expr {
174         Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
175             Expr::Path(ExprPath {
176                 attrs,
177                 qself: None,
178                 path:
179                     Path {
180                         leading_colon: None,
181                         segments,
182                     },
183             }) if attrs.is_empty()
184                 && segments.len() == 2
185                 && segments[0].ident == "Zeroable"
186                 && segments[0].arguments.is_none()
187                 && segments[1].ident == "init_zeroed"
188                 && segments[1].arguments.is_none() =>
189             {
190                 return InitKind::Zeroing;
191             }
192             _ => {}
193         },
194         _ => {}
195     }
196     dcx.error(
197         dotdot.span().join(expr.span()).unwrap_or(expr.span()),
198         "expected nothing or `..Zeroable::init_zeroed()`.",
199     );
200     InitKind::Normal
201 }
202 
203 /// Generate the code that initializes the fields of the struct using the initializers in `field`.
204 fn init_fields(
205     fields: &Punctuated<InitializerField, Token![,]>,
206     pinned: bool,
207     data: &Ident,
208     slot: &Ident,
209 ) -> TokenStream {
210     let mut guards = vec![];
211     let mut res = TokenStream::new();
212     for field in fields {
213         let init = match field {
214             InitializerField::Value { ident, value } => {
215                 let mut value_ident = ident.clone();
216                 let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
217                     // Setting the span of `value_ident` to `value`'s span improves error messages
218                     // when the type of `value` is wrong.
219                     value_ident.set_span(value.span());
220                     quote!(let #value_ident = #value;)
221                 });
222                 // Again span for better diagnostics
223                 let write = quote_spanned!(ident.span()=> ::core::ptr::write);
224                 let accessor = if pinned {
225                     let project_ident = format_ident!("__project_{ident}");
226                     quote! {
227                         // SAFETY: TODO
228                         unsafe { #data.#project_ident(&mut (*#slot).#ident) }
229                     }
230                 } else {
231                     quote! {
232                         // SAFETY: TODO
233                         unsafe { &mut (*#slot).#ident }
234                     }
235                 };
236                 quote! {
237                     {
238                         #value_prep
239                         // SAFETY: TODO
240                         unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
241                     }
242                     #[allow(unused_variables)]
243                     let #ident = #accessor;
244                 }
245             }
246             InitializerField::Init { ident, value, .. } => {
247                 // Again span for better diagnostics
248                 let init = format_ident!("init", span = value.span());
249                 if pinned {
250                     let project_ident = format_ident!("__project_{ident}");
251                     quote! {
252                         {
253                             let #init = #value;
254                             // SAFETY:
255                             // - `slot` is valid, because we are inside of an initializer closure, we
256                             //   return when an error/panic occurs.
257                             // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
258                             //   for `#ident`.
259                             unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
260                         }
261                         // SAFETY: TODO
262                         #[allow(unused_variables)]
263                         let #ident = unsafe { #data.#project_ident(&mut (*#slot).#ident) };
264                     }
265                 } else {
266                     quote! {
267                         {
268                             let #init = #value;
269                             // SAFETY: `slot` is valid, because we are inside of an initializer
270                             // closure, we return when an error/panic occurs.
271                             unsafe {
272                                 ::pin_init::Init::__init(
273                                     #init,
274                                     ::core::ptr::addr_of_mut!((*#slot).#ident),
275                                 )?
276                             };
277                         }
278                         // SAFETY: TODO
279                         #[allow(unused_variables)]
280                         let #ident = unsafe { &mut (*#slot).#ident };
281                     }
282                 }
283             }
284             InitializerField::Code { block: value, .. } => quote!(#[allow(unused_braces)] #value),
285         };
286         res.extend(init);
287         if let Some(ident) = field.ident() {
288             // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
289             let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
290             guards.push(guard.clone());
291             res.extend(quote! {
292                 // Create the drop guard:
293                 //
294                 // We rely on macro hygiene to make it impossible for users to access this local
295                 // variable.
296                 // SAFETY: We forget the guard later when initialization has succeeded.
297                 let #guard = unsafe {
298                     ::pin_init::__internal::DropGuard::new(
299                         ::core::ptr::addr_of_mut!((*slot).#ident)
300                     )
301                 };
302             });
303         }
304     }
305     quote! {
306         #res
307         // If execution reaches this point, all fields have been initialized. Therefore we can now
308         // dismiss the guards by forgetting them.
309         #(::core::mem::forget(#guards);)*
310     }
311 }
312 
313 /// Generate the check for ensuring that every field has been initialized.
314 fn make_field_check(
315     fields: &Punctuated<InitializerField, Token![,]>,
316     init_kind: InitKind,
317     path: &Path,
318 ) -> TokenStream {
319     let fields = fields.iter().filter_map(|f| f.ident());
320     match init_kind {
321         InitKind::Normal => quote! {
322             // We use unreachable code to ensure that all fields have been mentioned exactly once,
323             // this struct initializer will still be type-checked and complain with a very natural
324             // error message if a field is forgotten/mentioned more than once.
325             #[allow(unreachable_code, clippy::diverging_sub_expression)]
326             // SAFETY: this code is never executed.
327             let _ = || unsafe {
328                 ::core::ptr::write(slot, #path {
329                     #(
330                         #fields: ::core::panic!(),
331                     )*
332                 })
333             };
334         },
335         InitKind::Zeroing => quote! {
336             // We use unreachable code to ensure that all fields have been mentioned at most once.
337             // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
338             // be zeroed. This struct initializer will still be type-checked and complain with a
339             // very natural error message if a field is mentioned more than once, or doesn't exist.
340             #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
341             // SAFETY: this code is never executed.
342             let _ = || unsafe {
343                 let mut zeroed = ::core::mem::zeroed();
344                 // We have to use type inference here to make zeroed have the correct type. This
345                 // does not get executed, so it has no effect.
346                 ::core::ptr::write(slot, zeroed);
347                 zeroed = ::core::mem::zeroed();
348                 ::core::ptr::write(slot, #path {
349                     #(
350                         #fields: ::core::panic!(),
351                     )*
352                     ..zeroed
353                 })
354             };
355         },
356     }
357 }
358 
359 impl Parse for Initializer {
360     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
361         let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
362         let path = input.parse()?;
363         let content;
364         let brace_token = braced!(content in input);
365         let mut fields = Punctuated::new();
366         loop {
367             let lh = content.lookahead1();
368             if lh.peek(End) || lh.peek(Token![..]) {
369                 break;
370             } else if lh.peek(Ident) || lh.peek(Token![_]) {
371                 fields.push_value(content.parse()?);
372                 let lh = content.lookahead1();
373                 if lh.peek(End) {
374                     break;
375                 } else if lh.peek(Token![,]) {
376                     fields.push_punct(content.parse()?);
377                 } else {
378                     return Err(lh.error());
379                 }
380             } else {
381                 return Err(lh.error());
382             }
383         }
384         let rest = content
385             .peek(Token![..])
386             .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
387             .transpose()?;
388         let error = input
389             .peek(Token![?])
390             .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
391             .transpose()?;
392         Ok(Self {
393             this,
394             path,
395             brace_token,
396             fields,
397             rest,
398             error,
399         })
400     }
401 }
402 
403 impl Parse for This {
404     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
405         Ok(Self {
406             _and_token: input.parse()?,
407             ident: input.parse()?,
408             _in_token: input.parse()?,
409         })
410     }
411 }
412 
413 impl Parse for InitializerField {
414     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
415         let lh = input.lookahead1();
416         if lh.peek(Token![_]) {
417             Ok(Self::Code {
418                 _underscore_token: input.parse()?,
419                 _colon_token: input.parse()?,
420                 block: input.parse()?,
421             })
422         } else if lh.peek(Ident) {
423             let ident = input.parse()?;
424             let lh = input.lookahead1();
425             if lh.peek(Token![<-]) {
426                 Ok(Self::Init {
427                     ident,
428                     _left_arrow_token: input.parse()?,
429                     value: input.parse()?,
430                 })
431             } else if lh.peek(Token![:]) {
432                 Ok(Self::Value {
433                     ident,
434                     value: Some((input.parse()?, input.parse()?)),
435                 })
436             } else if lh.peek(Token![,]) || lh.peek(End) {
437                 Ok(Self::Value { ident, value: None })
438             } else {
439                 Err(lh.error())
440             }
441         } else {
442             Err(lh.error())
443         }
444     }
445 }
446