xref: /linux/rust/pin-init/internal/src/pin_data.rs (revision 37a93dd5c49b5fda807fd204edf2547c3493319c)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::TokenStream;
4 use quote::{format_ident, quote};
5 use syn::{
6     parse::{End, Nothing, Parse},
7     parse_quote, parse_quote_spanned,
8     spanned::Spanned,
9     visit_mut::VisitMut,
10     Field, Generics, Ident, Item, PathSegment, Type, TypePath, Visibility, WhereClause,
11 };
12 
13 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
14 
/// Custom keywords recognized in the argument list of `#[pin_data]`.
pub(crate) mod kw {
    // Defines the `kw::PinnedDrop` token type, which `Args::parse` peeks for to recognize
    // `#[pin_data(PinnedDrop)]`.
    syn::custom_keyword!(PinnedDrop);
}
18 
/// Parsed arguments of the `#[pin_data]` attribute.
pub(crate) enum Args {
    /// `#[pin_data]` without any arguments.
    Nothing(Nothing),
    /// `#[pin_data(PinnedDrop)]`: the generated `Drop` impl delegates to the struct's
    /// `PinnedDrop` implementation.
    // Only the variant is ever matched on (`Args::PinnedDrop(_)`); the keyword token itself is
    // never read, hence the `dead_code` allowance.
    #[allow(dead_code)]
    PinnedDrop(kw::PinnedDrop),
}
24 
25 impl Parse for Args {
26     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
27         let lh = input.lookahead1();
28         if lh.peek(End) {
29             input.parse().map(Self::Nothing)
30         } else if lh.peek(kw::PinnedDrop) {
31             input.parse().map(Self::PinnedDrop)
32         } else {
33             Err(lh.error())
34         }
35     }
36 }
37 
38 pub(crate) fn pin_data(
39     args: Args,
40     input: Item,
41     dcx: &mut DiagCtxt,
42 ) -> Result<TokenStream, ErrorGuaranteed> {
43     let mut struct_ = match input {
44         Item::Struct(struct_) => struct_,
45         Item::Enum(enum_) => {
46             return Err(dcx.error(
47                 enum_.enum_token,
48                 "`#[pin_data]` only supports structs for now",
49             ));
50         }
51         Item::Union(union) => {
52             return Err(dcx.error(
53                 union.union_token,
54                 "`#[pin_data]` only supports structs for now",
55             ));
56         }
57         rest => {
58             return Err(dcx.error(
59                 rest,
60                 "`#[pin_data]` can only be applied to struct, enum and union definitions",
61             ));
62         }
63     };
64 
65     // The generics might contain the `Self` type. Since this macro will define a new type with the
66     // same generics and bounds, this poses a problem: `Self` will refer to the new type as opposed
67     // to this struct definition. Therefore we have to replace `Self` with the concrete name.
68     let mut replacer = {
69         let name = &struct_.ident;
70         let (_, ty_generics, _) = struct_.generics.split_for_impl();
71         SelfReplacer(parse_quote!(#name #ty_generics))
72     };
73     replacer.visit_generics_mut(&mut struct_.generics);
74     replacer.visit_fields_mut(&mut struct_.fields);
75 
76     let fields: Vec<(bool, &Field)> = struct_
77         .fields
78         .iter_mut()
79         .map(|field| {
80             let len = field.attrs.len();
81             field.attrs.retain(|a| !a.path().is_ident("pin"));
82             (len != field.attrs.len(), &*field)
83         })
84         .collect();
85 
86     for (pinned, field) in &fields {
87         if !pinned && is_phantom_pinned(&field.ty) {
88             dcx.error(
89                 field,
90                 format!(
91                     "The field `{}` of type `PhantomPinned` only has an effect \
92                     if it has the `#[pin]` attribute",
93                     field.ident.as_ref().unwrap(),
94                 ),
95             );
96         }
97     }
98 
99     let unpin_impl = generate_unpin_impl(&struct_.ident, &struct_.generics, &fields);
100     let drop_impl = generate_drop_impl(&struct_.ident, &struct_.generics, args);
101     let projections =
102         generate_projections(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
103     let the_pin_data =
104         generate_the_pin_data(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
105 
106     Ok(quote! {
107         #struct_
108         #projections
109         // We put the rest into this const item, because it then will not be accessible to anything
110         // outside.
111         const _: () = {
112             #the_pin_data
113             #unpin_impl
114             #drop_impl
115         };
116     })
117 }
118 
119 fn is_phantom_pinned(ty: &Type) -> bool {
120     match ty {
121         Type::Path(TypePath { qself: None, path }) => {
122             // Cannot possibly refer to `PhantomPinned` (except alias, but that's on the user).
123             if path.segments.len() > 3 {
124                 return false;
125             }
126             // If there is a `::`, then the path needs to be `::core::marker::PhantomPinned` or
127             // `::std::marker::PhantomPinned`.
128             if path.leading_colon.is_some() && path.segments.len() != 3 {
129                 return false;
130             }
131             let expected: Vec<&[&str]> = vec![&["PhantomPinned"], &["marker"], &["core", "std"]];
132             for (actual, expected) in path.segments.iter().rev().zip(expected) {
133                 if !actual.arguments.is_empty() || expected.iter().all(|e| actual.ident != e) {
134                     return false;
135                 }
136             }
137             true
138         }
139         _ => false,
140     }
141 }
142 
143 fn generate_unpin_impl(
144     ident: &Ident,
145     generics: &Generics,
146     fields: &[(bool, &Field)],
147 ) -> TokenStream {
148     let (_, ty_generics, _) = generics.split_for_impl();
149     let mut generics_with_pin_lt = generics.clone();
150     generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
151     generics_with_pin_lt.make_where_clause();
152     let (
153         impl_generics_with_pin_lt,
154         ty_generics_with_pin_lt,
155         Some(WhereClause {
156             where_token,
157             predicates,
158         }),
159     ) = generics_with_pin_lt.split_for_impl()
160     else {
161         unreachable!()
162     };
163     let pinned_fields = fields.iter().filter_map(|(b, f)| b.then_some(f));
164     quote! {
165         // This struct will be used for the unpin analysis. It is needed, because only structurally
166         // pinned fields are relevant whether the struct should implement `Unpin`.
167         #[allow(dead_code)] // The fields below are never used.
168         struct __Unpin #generics_with_pin_lt
169         #where_token
170             #predicates
171         {
172             __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>,
173             __phantom: ::core::marker::PhantomData<
174                 fn(#ident #ty_generics) -> #ident #ty_generics
175             >,
176             #(#pinned_fields),*
177         }
178 
179         #[doc(hidden)]
180         impl #impl_generics_with_pin_lt ::core::marker::Unpin for #ident #ty_generics
181         #where_token
182             __Unpin #ty_generics_with_pin_lt: ::core::marker::Unpin,
183             #predicates
184         {}
185     }
186 }
187 
188 fn generate_drop_impl(ident: &Ident, generics: &Generics, args: Args) -> TokenStream {
189     let (impl_generics, ty_generics, whr) = generics.split_for_impl();
190     let has_pinned_drop = matches!(args, Args::PinnedDrop(_));
191     // We need to disallow normal `Drop` implementation, the exact behavior depends on whether
192     // `PinnedDrop` was specified in `args`.
193     if has_pinned_drop {
194         // When `PinnedDrop` was specified we just implement `Drop` and delegate.
195         quote! {
196             impl #impl_generics ::core::ops::Drop for #ident #ty_generics
197                 #whr
198             {
199                 fn drop(&mut self) {
200                     // SAFETY: Since this is a destructor, `self` will not move after this function
201                     // terminates, since it is inaccessible.
202                     let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) };
203                     // SAFETY: Since this is a drop function, we can create this token to call the
204                     // pinned destructor of this type.
205                     let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() };
206                     ::pin_init::PinnedDrop::drop(pinned, token);
207                 }
208             }
209         }
210     } else {
211         // When no `PinnedDrop` was specified, then we have to prevent implementing drop.
212         quote! {
213             // We prevent this by creating a trait that will be implemented for all types implementing
214             // `Drop`. Additionally we will implement this trait for the struct leading to a conflict,
215             // if it also implements `Drop`
216             trait MustNotImplDrop {}
217             #[expect(drop_bounds)]
218             impl<T: ::core::ops::Drop + ?::core::marker::Sized> MustNotImplDrop for T {}
219             impl #impl_generics MustNotImplDrop for #ident #ty_generics
220                 #whr
221             {}
222             // We also take care to prevent users from writing a useless `PinnedDrop` implementation.
223             // They might implement `PinnedDrop` correctly for the struct, but forget to give
224             // `PinnedDrop` as the parameter to `#[pin_data]`.
225             #[expect(non_camel_case_types)]
226             trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {}
227             impl<T: ::pin_init::PinnedDrop + ?::core::marker::Sized>
228                 UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {}
229             impl #impl_generics
230                 UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for #ident #ty_generics
231                 #whr
232             {}
233         }
234     }
235 }
236 
237 fn generate_projections(
238     vis: &Visibility,
239     ident: &Ident,
240     generics: &Generics,
241     fields: &[(bool, &Field)],
242 ) -> TokenStream {
243     let (impl_generics, ty_generics, _) = generics.split_for_impl();
244     let mut generics_with_pin_lt = generics.clone();
245     generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
246     let (_, ty_generics_with_pin_lt, whr) = generics_with_pin_lt.split_for_impl();
247     let projection = format_ident!("{ident}Projection");
248     let this = format_ident!("this");
249 
250     let (fields_decl, fields_proj) = collect_tuple(fields.iter().map(
251         |(
252             pinned,
253             Field {
254                 vis,
255                 ident,
256                 ty,
257                 attrs,
258                 ..
259             },
260         )| {
261             let mut attrs = attrs.clone();
262             attrs.retain(|a| !a.path().is_ident("pin"));
263             let mut no_doc_attrs = attrs.clone();
264             no_doc_attrs.retain(|a| !a.path().is_ident("doc"));
265             let ident = ident
266                 .as_ref()
267                 .expect("only structs with named fields are supported");
268             if *pinned {
269                 (
270                     quote!(
271                         #(#attrs)*
272                         #vis #ident: ::core::pin::Pin<&'__pin mut #ty>,
273                     ),
274                     quote!(
275                         #(#no_doc_attrs)*
276                         // SAFETY: this field is structurally pinned.
277                         #ident: unsafe { ::core::pin::Pin::new_unchecked(&mut #this.#ident) },
278                     ),
279                 )
280             } else {
281                 (
282                     quote!(
283                         #(#attrs)*
284                         #vis #ident: &'__pin mut #ty,
285                     ),
286                     quote!(
287                         #(#no_doc_attrs)*
288                         #ident: &mut #this.#ident,
289                     ),
290                 )
291             }
292         },
293     ));
294     let structurally_pinned_fields_docs = fields
295         .iter()
296         .filter_map(|(pinned, field)| pinned.then_some(field))
297         .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap()));
298     let not_structurally_pinned_fields_docs = fields
299         .iter()
300         .filter_map(|(pinned, field)| (!pinned).then_some(field))
301         .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap()));
302     let docs = format!(" Pin-projections of [`{ident}`]");
303     quote! {
304         #[doc = #docs]
305         #[allow(dead_code)]
306         #[doc(hidden)]
307         #vis struct #projection #generics_with_pin_lt {
308             #(#fields_decl)*
309             ___pin_phantom_data: ::core::marker::PhantomData<&'__pin mut ()>,
310         }
311 
312         impl #impl_generics #ident #ty_generics
313             #whr
314         {
315             /// Pin-projects all fields of `Self`.
316             ///
317             /// These fields are structurally pinned:
318             #(#[doc = #structurally_pinned_fields_docs])*
319             ///
320             /// These fields are **not** structurally pinned:
321             #(#[doc = #not_structurally_pinned_fields_docs])*
322             #[inline]
323             #vis fn project<'__pin>(
324                 self: ::core::pin::Pin<&'__pin mut Self>,
325             ) -> #projection #ty_generics_with_pin_lt {
326                 // SAFETY: we only give access to `&mut` for fields not structurally pinned.
327                 let #this = unsafe { ::core::pin::Pin::get_unchecked_mut(self) };
328                 #projection {
329                     #(#fields_proj)*
330                     ___pin_phantom_data: ::core::marker::PhantomData,
331                 }
332             }
333         }
334     }
335 }
336 
337 fn generate_the_pin_data(
338     vis: &Visibility,
339     ident: &Ident,
340     generics: &Generics,
341     fields: &[(bool, &Field)],
342 ) -> TokenStream {
343     let (impl_generics, ty_generics, whr) = generics.split_for_impl();
344 
345     // For every field, we create an initializing projection function according to its projection
346     // type. If a field is structurally pinned, then it must be initialized via `PinInit`, if it is
347     // not structurally pinned, then it can be initialized via `Init`.
348     //
349     // The functions are `unsafe` to prevent accidentally calling them.
350     fn handle_field(
351         Field {
352             vis,
353             ident,
354             ty,
355             attrs,
356             ..
357         }: &Field,
358         struct_ident: &Ident,
359         pinned: bool,
360     ) -> TokenStream {
361         let mut attrs = attrs.clone();
362         attrs.retain(|a| !a.path().is_ident("pin"));
363         let ident = ident
364             .as_ref()
365             .expect("only structs with named fields are supported");
366         let project_ident = format_ident!("__project_{ident}");
367         let (init_ty, init_fn, project_ty, project_body, pin_safety) = if pinned {
368             (
369                 quote!(PinInit),
370                 quote!(__pinned_init),
371                 quote!(::core::pin::Pin<&'__slot mut #ty>),
372                 // SAFETY: this field is structurally pinned.
373                 quote!(unsafe { ::core::pin::Pin::new_unchecked(slot) }),
374                 quote!(
375                     /// - `slot` will not move until it is dropped, i.e. it will be pinned.
376                 ),
377             )
378         } else {
379             (
380                 quote!(Init),
381                 quote!(__init),
382                 quote!(&'__slot mut #ty),
383                 quote!(slot),
384                 quote!(),
385             )
386         };
387         let slot_safety = format!(
388             " `slot` points at the field `{ident}` inside of `{struct_ident}`, which is pinned.",
389         );
390         quote! {
391             /// # Safety
392             ///
393             /// - `slot` is a valid pointer to uninitialized memory.
394             /// - the caller does not touch `slot` when `Err` is returned, they are only permitted
395             ///   to deallocate.
396             #pin_safety
397             #(#attrs)*
398             #vis unsafe fn #ident<E>(
399                 self,
400                 slot: *mut #ty,
401                 init: impl ::pin_init::#init_ty<#ty, E>,
402             ) -> ::core::result::Result<(), E> {
403                 // SAFETY: this function has the same safety requirements as the __init function
404                 // called below.
405                 unsafe { ::pin_init::#init_ty::#init_fn(init, slot) }
406             }
407 
408             /// # Safety
409             ///
410             #[doc = #slot_safety]
411             #(#attrs)*
412             #vis unsafe fn #project_ident<'__slot>(
413                 self,
414                 slot: &'__slot mut #ty,
415             ) -> #project_ty {
416                 #project_body
417             }
418         }
419     }
420 
421     let field_accessors = fields
422         .iter()
423         .map(|(pinned, field)| handle_field(field, ident, *pinned))
424         .collect::<TokenStream>();
425     quote! {
426         // We declare this struct which will host all of the projection function for our type. It
427         // will be invariant over all generic parameters which are inherited from the struct.
428         #[doc(hidden)]
429         #vis struct __ThePinData #generics
430             #whr
431         {
432             __phantom: ::core::marker::PhantomData<
433                 fn(#ident #ty_generics) -> #ident #ty_generics
434             >,
435         }
436 
437         impl #impl_generics ::core::clone::Clone for __ThePinData #ty_generics
438             #whr
439         {
440             fn clone(&self) -> Self { *self }
441         }
442 
443         impl #impl_generics ::core::marker::Copy for __ThePinData #ty_generics
444             #whr
445         {}
446 
447         #[allow(dead_code)] // Some functions might never be used and private.
448         #[expect(clippy::missing_safety_doc)]
449         impl #impl_generics __ThePinData #ty_generics
450             #whr
451         {
452             #field_accessors
453         }
454 
455         // SAFETY: We have added the correct projection functions above to `__ThePinData` and
456         // we also use the least restrictive generics possible.
457         unsafe impl #impl_generics ::pin_init::__internal::HasPinData for #ident #ty_generics
458             #whr
459         {
460             type PinData = __ThePinData #ty_generics;
461 
462             unsafe fn __pin_data() -> Self::PinData {
463                 __ThePinData { __phantom: ::core::marker::PhantomData }
464             }
465         }
466 
467         // SAFETY: TODO
468         unsafe impl #impl_generics ::pin_init::__internal::PinData for __ThePinData #ty_generics
469             #whr
470         {
471             type Datee = #ident #ty_generics;
472         }
473     }
474 }
475 
476 struct SelfReplacer(PathSegment);
477 
478 impl VisitMut for SelfReplacer {
479     fn visit_path_mut(&mut self, i: &mut syn::Path) {
480         if i.is_ident("Self") {
481             let span = i.span();
482             let seg = &self.0;
483             *i = parse_quote_spanned!(span=> #seg);
484         } else {
485             syn::visit_mut::visit_path_mut(self, i);
486         }
487     }
488 
489     fn visit_path_segment_mut(&mut self, seg: &mut PathSegment) {
490         if seg.ident == "Self" {
491             let span = seg.span();
492             let this = &self.0;
493             *seg = parse_quote_spanned!(span=> #this);
494         } else {
495             syn::visit_mut::visit_path_segment_mut(self, seg);
496         }
497     }
498 
499     fn visit_item_mut(&mut self, _: &mut Item) {
500         // Do not descend into items, since items reset/change what `Self` refers to.
501     }
502 }
503 
/// Splits an iterator of pairs into two vectors, preserving the order of the elements.
///
/// This is [`Iterator::unzip`], which has been stable since Rust 1.0. Unlike collecting into a
/// tuple of `Vec`s via `FromIterator` (stable only since Rust 1.79), it needs no MSRV bump, and it
/// reserves capacity from the iterator's size hint.
fn collect_tuple<A, B>(iter: impl Iterator<Item = (A, B)>) -> (Vec<A>, Vec<B>) {
    iter.unzip()
}
514