xref: /linux/rust/pin-init/internal/src/pin_data.rs (revision 430654211d566f86e8ee533ff1b01a42be6b602c)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::TokenStream;
4 use quote::{format_ident, quote};
5 use syn::{
6     parse::{End, Nothing, Parse},
7     parse_quote, parse_quote_spanned,
8     spanned::Spanned,
9     visit_mut::VisitMut,
10     Field, Generics, Ident, Item, PathSegment, Type, TypePath, Visibility, WhereClause,
11 };
12 
13 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
14 
/// Custom keywords accepted in the argument list of `#[pin_data(...)]`.
pub(crate) mod kw {
    // Defines a `PinnedDrop` keyword token type that `Args::parse` peeks for and parses.
    syn::custom_keyword!(PinnedDrop);
}
18 
/// Parsed arguments of the `#[pin_data]` attribute.
pub(crate) enum Args {
    /// `#[pin_data]` was given no arguments.
    Nothing(Nothing),
    /// `#[pin_data(PinnedDrop)]`: a `Drop` impl delegating to the user's `PinnedDrop` impl
    /// is generated (see `generate_drop_impl`).
    // The keyword token itself is never read; only the variant is matched on, hence the allow.
    #[allow(dead_code)]
    PinnedDrop(kw::PinnedDrop),
}
24 
25 impl Parse for Args {
26     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
27         let lh = input.lookahead1();
28         if lh.peek(End) {
29             input.parse().map(Self::Nothing)
30         } else if lh.peek(kw::PinnedDrop) {
31             input.parse().map(Self::PinnedDrop)
32         } else {
33             Err(lh.error())
34         }
35     }
36 }
37 
38 pub(crate) fn pin_data(
39     args: Args,
40     input: Item,
41     dcx: &mut DiagCtxt,
42 ) -> Result<TokenStream, ErrorGuaranteed> {
43     let mut struct_ = match input {
44         Item::Struct(struct_) => struct_,
45         Item::Enum(enum_) => {
46             return Err(dcx.error(
47                 enum_.enum_token,
48                 "`#[pin_data]` only supports structs for now",
49             ));
50         }
51         Item::Union(union) => {
52             return Err(dcx.error(
53                 union.union_token,
54                 "`#[pin_data]` only supports structs for now",
55             ));
56         }
57         rest => {
58             return Err(dcx.error(
59                 rest,
60                 "`#[pin_data]` can only be applied to struct, enum and union definitions",
61             ));
62         }
63     };
64 
65     // The generics might contain the `Self` type. Since this macro will define a new type with the
66     // same generics and bounds, this poses a problem: `Self` will refer to the new type as opposed
67     // to this struct definition. Therefore we have to replace `Self` with the concrete name.
68     let mut replacer = {
69         let name = &struct_.ident;
70         let (_, ty_generics, _) = struct_.generics.split_for_impl();
71         SelfReplacer(parse_quote!(#name #ty_generics))
72     };
73     replacer.visit_generics_mut(&mut struct_.generics);
74     replacer.visit_fields_mut(&mut struct_.fields);
75 
76     let fields: Vec<(bool, &Field)> = struct_
77         .fields
78         .iter_mut()
79         .map(|field| {
80             let len = field.attrs.len();
81             field.attrs.retain(|a| !a.path().is_ident("pin"));
82             (len != field.attrs.len(), &*field)
83         })
84         .collect();
85 
86     for (pinned, field) in &fields {
87         if !pinned && is_phantom_pinned(&field.ty) {
88             dcx.warn(
89                 field,
90                 format!(
91                     "The field `{}` of type `PhantomPinned` only has an effect \
92                     if it has the `#[pin]` attribute",
93                     field.ident.as_ref().unwrap(),
94                 ),
95             );
96         }
97     }
98 
99     let unpin_impl = generate_unpin_impl(&struct_.ident, &struct_.generics, &fields);
100     let drop_impl = generate_drop_impl(&struct_.ident, &struct_.generics, args);
101     let projections =
102         generate_projections(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
103     let the_pin_data =
104         generate_the_pin_data(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
105 
106     Ok(quote! {
107         #struct_
108         #projections
109         // We put the rest into this const item, because it then will not be accessible to anything
110         // outside.
111         const _: () = {
112             #the_pin_data
113             #unpin_impl
114             #drop_impl
115         };
116     })
117 }
118 
119 fn is_phantom_pinned(ty: &Type) -> bool {
120     match ty {
121         Type::Path(TypePath { qself: None, path }) => {
122             // Cannot possibly refer to `PhantomPinned` (except alias, but that's on the user).
123             if path.segments.len() > 3 {
124                 return false;
125             }
126             // If there is a `::`, then the path needs to be `::core::marker::PhantomPinned` or
127             // `::std::marker::PhantomPinned`.
128             if path.leading_colon.is_some() && path.segments.len() != 3 {
129                 return false;
130             }
131             let expected: Vec<&[&str]> = vec![&["PhantomPinned"], &["marker"], &["core", "std"]];
132             for (actual, expected) in path.segments.iter().rev().zip(expected) {
133                 if !actual.arguments.is_empty() || expected.iter().all(|e| actual.ident != e) {
134                     return false;
135                 }
136             }
137             true
138         }
139         _ => false,
140     }
141 }
142 
143 fn generate_unpin_impl(
144     ident: &Ident,
145     generics: &Generics,
146     fields: &[(bool, &Field)],
147 ) -> TokenStream {
148     let (_, ty_generics, _) = generics.split_for_impl();
149     let mut generics_with_pin_lt = generics.clone();
150     generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
151     generics_with_pin_lt.make_where_clause();
152     let (
153         impl_generics_with_pin_lt,
154         ty_generics_with_pin_lt,
155         Some(WhereClause {
156             where_token,
157             predicates,
158         }),
159     ) = generics_with_pin_lt.split_for_impl()
160     else {
161         unreachable!()
162     };
163     let pinned_fields = fields.iter().filter_map(|(b, f)| b.then_some(f));
164     quote! {
165         // This struct will be used for the unpin analysis. It is needed, because only structurally
166         // pinned fields are relevant whether the struct should implement `Unpin`.
167         #[allow(dead_code)] // The fields below are never used.
168         struct __Unpin #generics_with_pin_lt
169         #where_token
170             #predicates
171         {
172             __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>,
173             __phantom: ::core::marker::PhantomData<
174                 fn(#ident #ty_generics) -> #ident #ty_generics
175             >,
176             #(#pinned_fields),*
177         }
178 
179         #[doc(hidden)]
180         impl #impl_generics_with_pin_lt ::core::marker::Unpin for #ident #ty_generics
181         #where_token
182             __Unpin #ty_generics_with_pin_lt: ::core::marker::Unpin,
183             #predicates
184         {}
185     }
186 }
187 
188 fn generate_drop_impl(ident: &Ident, generics: &Generics, args: Args) -> TokenStream {
189     let (impl_generics, ty_generics, whr) = generics.split_for_impl();
190     let has_pinned_drop = matches!(args, Args::PinnedDrop(_));
191     // We need to disallow normal `Drop` implementation, the exact behavior depends on whether
192     // `PinnedDrop` was specified in `args`.
193     if has_pinned_drop {
194         // When `PinnedDrop` was specified we just implement `Drop` and delegate.
195         quote! {
196             impl #impl_generics ::core::ops::Drop for #ident #ty_generics
197                 #whr
198             {
199                 fn drop(&mut self) {
200                     // SAFETY: Since this is a destructor, `self` will not move after this function
201                     // terminates, since it is inaccessible.
202                     let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) };
203                     // SAFETY: Since this is a drop function, we can create this token to call the
204                     // pinned destructor of this type.
205                     let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() };
206                     ::pin_init::PinnedDrop::drop(pinned, token);
207                 }
208             }
209         }
210     } else {
211         // When no `PinnedDrop` was specified, then we have to prevent implementing drop.
212         quote! {
213             // We prevent this by creating a trait that will be implemented for all types implementing
214             // `Drop`. Additionally we will implement this trait for the struct leading to a conflict,
215             // if it also implements `Drop`
216             trait MustNotImplDrop {}
217             #[expect(drop_bounds)]
218             impl<T: ::core::ops::Drop + ?::core::marker::Sized> MustNotImplDrop for T {}
219             impl #impl_generics MustNotImplDrop for #ident #ty_generics
220                 #whr
221             {}
222             // We also take care to prevent users from writing a useless `PinnedDrop` implementation.
223             // They might implement `PinnedDrop` correctly for the struct, but forget to give
224             // `PinnedDrop` as the parameter to `#[pin_data]`.
225             #[expect(non_camel_case_types)]
226             trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {}
227             impl<T: ::pin_init::PinnedDrop + ?::core::marker::Sized>
228                 UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {}
229             impl #impl_generics
230                 UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for #ident #ty_generics
231                 #whr
232             {}
233         }
234     }
235 }
236 
237 fn generate_projections(
238     vis: &Visibility,
239     ident: &Ident,
240     generics: &Generics,
241     fields: &[(bool, &Field)],
242 ) -> TokenStream {
243     let (impl_generics, ty_generics, _) = generics.split_for_impl();
244     let mut generics_with_pin_lt = generics.clone();
245     generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
246     let (_, ty_generics_with_pin_lt, whr) = generics_with_pin_lt.split_for_impl();
247     let projection = format_ident!("{ident}Projection");
248     let this = format_ident!("this");
249 
250     let (fields_decl, fields_proj): (Vec<_>, Vec<_>) = fields
251         .iter()
252         .map(|(pinned, field)| {
253             let Field {
254                 vis,
255                 ident,
256                 ty,
257                 attrs,
258                 ..
259             } = field;
260 
261             let mut no_doc_attrs = attrs.clone();
262             no_doc_attrs.retain(|a| !a.path().is_ident("doc"));
263             let ident = ident
264                 .as_ref()
265                 .expect("only structs with named fields are supported");
266             if *pinned {
267                 (
268                     quote!(
269                         #(#attrs)*
270                         #vis #ident: ::core::pin::Pin<&'__pin mut #ty>,
271                     ),
272                     quote!(
273                         #(#no_doc_attrs)*
274                         // SAFETY: this field is structurally pinned.
275                         #ident: unsafe { ::core::pin::Pin::new_unchecked(&mut #this.#ident) },
276                     ),
277                 )
278             } else {
279                 (
280                     quote!(
281                         #(#attrs)*
282                         #vis #ident: &'__pin mut #ty,
283                     ),
284                     quote!(
285                         #(#no_doc_attrs)*
286                         #ident: &mut #this.#ident,
287                     ),
288                 )
289             }
290         })
291         .collect();
292     let structurally_pinned_fields_docs = fields
293         .iter()
294         .filter_map(|(pinned, field)| pinned.then_some(field))
295         .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap()));
296     let not_structurally_pinned_fields_docs = fields
297         .iter()
298         .filter_map(|(pinned, field)| (!pinned).then_some(field))
299         .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap()));
300     let docs = format!(" Pin-projections of [`{ident}`]");
301     quote! {
302         #[doc = #docs]
303         #[allow(dead_code)]
304         #[doc(hidden)]
305         #vis struct #projection #generics_with_pin_lt
306             #whr
307         {
308             #(#fields_decl)*
309             ___pin_phantom_data: ::core::marker::PhantomData<&'__pin mut ()>,
310         }
311 
312         impl #impl_generics #ident #ty_generics
313             #whr
314         {
315             /// Pin-projects all fields of `Self`.
316             ///
317             /// These fields are structurally pinned:
318             #(#[doc = #structurally_pinned_fields_docs])*
319             ///
320             /// These fields are **not** structurally pinned:
321             #(#[doc = #not_structurally_pinned_fields_docs])*
322             #[inline]
323             #vis fn project<'__pin>(
324                 self: ::core::pin::Pin<&'__pin mut Self>,
325             ) -> #projection #ty_generics_with_pin_lt {
326                 // SAFETY: we only give access to `&mut` for fields not structurally pinned.
327                 let #this = unsafe { ::core::pin::Pin::get_unchecked_mut(self) };
328                 #projection {
329                     #(#fields_proj)*
330                     ___pin_phantom_data: ::core::marker::PhantomData,
331                 }
332             }
333         }
334     }
335 }
336 
337 fn generate_the_pin_data(
338     vis: &Visibility,
339     ident: &Ident,
340     generics: &Generics,
341     fields: &[(bool, &Field)],
342 ) -> TokenStream {
343     let (impl_generics, ty_generics, whr) = generics.split_for_impl();
344 
345     // For every field, we create an initializing projection function according to its projection
346     // type. If a field is structurally pinned, then it must be initialized via `PinInit`, if it is
347     // not structurally pinned, then it can be initialized via `Init`.
348     //
349     // The functions are `unsafe` to prevent accidentally calling them.
350     fn handle_field(
351         Field {
352             vis,
353             ident,
354             ty,
355             attrs,
356             ..
357         }: &Field,
358         struct_ident: &Ident,
359         pinned: bool,
360     ) -> TokenStream {
361         let ident = ident
362             .as_ref()
363             .expect("only structs with named fields are supported");
364         let project_ident = format_ident!("__project_{ident}");
365         let (init_ty, init_fn, project_ty, project_body, pin_safety) = if pinned {
366             (
367                 quote!(PinInit),
368                 quote!(__pinned_init),
369                 quote!(::core::pin::Pin<&'__slot mut #ty>),
370                 // SAFETY: this field is structurally pinned.
371                 quote!(unsafe { ::core::pin::Pin::new_unchecked(slot) }),
372                 quote!(
373                     /// - `slot` will not move until it is dropped, i.e. it will be pinned.
374                 ),
375             )
376         } else {
377             (
378                 quote!(Init),
379                 quote!(__init),
380                 quote!(&'__slot mut #ty),
381                 quote!(slot),
382                 quote!(),
383             )
384         };
385         let slot_safety = format!(
386             " `slot` points at the field `{ident}` inside of `{struct_ident}`, which is pinned.",
387         );
388         quote! {
389             /// # Safety
390             ///
391             /// - `slot` is a valid pointer to uninitialized memory.
392             /// - the caller does not touch `slot` when `Err` is returned, they are only permitted
393             ///   to deallocate.
394             #pin_safety
395             #(#attrs)*
396             #vis unsafe fn #ident<E>(
397                 self,
398                 slot: *mut #ty,
399                 init: impl ::pin_init::#init_ty<#ty, E>,
400             ) -> ::core::result::Result<(), E> {
401                 // SAFETY: this function has the same safety requirements as the __init function
402                 // called below.
403                 unsafe { ::pin_init::#init_ty::#init_fn(init, slot) }
404             }
405 
406             /// # Safety
407             ///
408             #[doc = #slot_safety]
409             #(#attrs)*
410             #vis unsafe fn #project_ident<'__slot>(
411                 self,
412                 slot: &'__slot mut #ty,
413             ) -> #project_ty {
414                 #project_body
415             }
416         }
417     }
418 
419     let field_accessors = fields
420         .iter()
421         .map(|(pinned, field)| handle_field(field, ident, *pinned))
422         .collect::<TokenStream>();
423     quote! {
424         // We declare this struct which will host all of the projection function for our type. It
425         // will be invariant over all generic parameters which are inherited from the struct.
426         #[doc(hidden)]
427         #vis struct __ThePinData #generics
428             #whr
429         {
430             __phantom: ::core::marker::PhantomData<
431                 fn(#ident #ty_generics) -> #ident #ty_generics
432             >,
433         }
434 
435         impl #impl_generics ::core::clone::Clone for __ThePinData #ty_generics
436             #whr
437         {
438             fn clone(&self) -> Self { *self }
439         }
440 
441         impl #impl_generics ::core::marker::Copy for __ThePinData #ty_generics
442             #whr
443         {}
444 
445         #[allow(dead_code)] // Some functions might never be used and private.
446         #[expect(clippy::missing_safety_doc)]
447         impl #impl_generics __ThePinData #ty_generics
448             #whr
449         {
450             #field_accessors
451         }
452 
453         // SAFETY: We have added the correct projection functions above to `__ThePinData` and
454         // we also use the least restrictive generics possible.
455         unsafe impl #impl_generics ::pin_init::__internal::HasPinData for #ident #ty_generics
456             #whr
457         {
458             type PinData = __ThePinData #ty_generics;
459 
460             unsafe fn __pin_data() -> Self::PinData {
461                 __ThePinData { __phantom: ::core::marker::PhantomData }
462             }
463         }
464 
465         // SAFETY: TODO
466         unsafe impl #impl_generics ::pin_init::__internal::PinData for __ThePinData #ty_generics
467             #whr
468         {
469             type Datee = #ident #ty_generics;
470         }
471     }
472 }
473 
474 struct SelfReplacer(PathSegment);
475 
476 impl VisitMut for SelfReplacer {
477     fn visit_path_mut(&mut self, i: &mut syn::Path) {
478         if i.is_ident("Self") {
479             let span = i.span();
480             let seg = &self.0;
481             *i = parse_quote_spanned!(span=> #seg);
482         } else {
483             syn::visit_mut::visit_path_mut(self, i);
484         }
485     }
486 
487     fn visit_path_segment_mut(&mut self, seg: &mut PathSegment) {
488         if seg.ident == "Self" {
489             let span = seg.span();
490             let this = &self.0;
491             *seg = parse_quote_spanned!(span=> #this);
492         } else {
493             syn::visit_mut::visit_path_segment_mut(self, seg);
494         }
495     }
496 
497     fn visit_item_mut(&mut self, _: &mut Item) {
498         // Do not descend into items, since items reset/change what `Self` refers to.
499     }
500 }
501