xref: /linux/rust/pin-init/internal/src/pin_data.rs (revision fea304ec875454360a3be106e0baad96032bf9fe)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::TokenStream;
4 use quote::{format_ident, quote};
5 use syn::{
6     parse::{End, Nothing, Parse},
7     parse_quote, parse_quote_spanned,
8     spanned::Spanned,
9     visit_mut::VisitMut,
10     Field, Generics, Ident, Item, PathSegment, Type, TypePath, Visibility, WhereClause,
11 };
12 
13 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
14 
15 pub(crate) mod kw {
16     syn::custom_keyword!(PinnedDrop);
17 }
18 
19 pub(crate) enum Args {
20     Nothing(Nothing),
21     #[allow(dead_code)]
22     PinnedDrop(kw::PinnedDrop),
23 }
24 
25 impl Parse for Args {
26     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
27         let lh = input.lookahead1();
28         if lh.peek(End) {
29             input.parse().map(Self::Nothing)
30         } else if lh.peek(kw::PinnedDrop) {
31             input.parse().map(Self::PinnedDrop)
32         } else {
33             Err(lh.error())
34         }
35     }
36 }
37 
38 struct FieldInfo<'a> {
39     field: &'a Field,
40     pinned: bool,
41 }
42 
43 pub(crate) fn pin_data(
44     args: Args,
45     input: Item,
46     dcx: &mut DiagCtxt,
47 ) -> Result<TokenStream, ErrorGuaranteed> {
48     let mut struct_ = match input {
49         Item::Struct(struct_) => struct_,
50         Item::Enum(enum_) => {
51             return Err(dcx.error(
52                 enum_.enum_token,
53                 "`#[pin_data]` only supports structs for now",
54             ));
55         }
56         Item::Union(union) => {
57             return Err(dcx.error(
58                 union.union_token,
59                 "`#[pin_data]` only supports structs for now",
60             ));
61         }
62         rest => {
63             return Err(dcx.error(
64                 rest,
65                 "`#[pin_data]` can only be applied to struct, enum and union definitions",
66             ));
67         }
68     };
69 
70     // The generics might contain the `Self` type. Since this macro will define a new type with the
71     // same generics and bounds, this poses a problem: `Self` will refer to the new type as opposed
72     // to this struct definition. Therefore we have to replace `Self` with the concrete name.
73     let mut replacer = {
74         let name = &struct_.ident;
75         let (_, ty_generics, _) = struct_.generics.split_for_impl();
76         SelfReplacer(parse_quote!(#name #ty_generics))
77     };
78     replacer.visit_generics_mut(&mut struct_.generics);
79     replacer.visit_fields_mut(&mut struct_.fields);
80 
81     let fields: Vec<FieldInfo<'_>> = struct_
82         .fields
83         .iter_mut()
84         .map(|field| {
85             let len = field.attrs.len();
86             field.attrs.retain(|a| !a.path().is_ident("pin"));
87             let pinned = len != field.attrs.len();
88 
89             FieldInfo {
90                 field: &*field,
91                 pinned,
92             }
93         })
94         .collect();
95 
96     for field in &fields {
97         let ident = field.field.ident.as_ref().unwrap();
98 
99         if !field.pinned && is_phantom_pinned(&field.field.ty) {
100             dcx.warn(
101                 field.field,
102                 format!(
103                     "The field `{ident}` of type `PhantomPinned` only has an effect \
104                     if it has the `#[pin]` attribute",
105                 ),
106             );
107         }
108     }
109 
110     let unpin_impl = generate_unpin_impl(&struct_.ident, &struct_.generics, &fields);
111     let drop_impl = generate_drop_impl(&struct_.ident, &struct_.generics, args);
112     let projections =
113         generate_projections(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
114     let the_pin_data =
115         generate_the_pin_data(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
116 
117     Ok(quote! {
118         #struct_
119         #projections
120         // We put the rest into this const item, because it then will not be accessible to anything
121         // outside.
122         const _: () = {
123             #the_pin_data
124             #unpin_impl
125             #drop_impl
126         };
127     })
128 }
129 
130 fn is_phantom_pinned(ty: &Type) -> bool {
131     match ty {
132         Type::Path(TypePath { qself: None, path }) => {
133             // Cannot possibly refer to `PhantomPinned` (except alias, but that's on the user).
134             if path.segments.len() > 3 {
135                 return false;
136             }
137             // If there is a `::`, then the path needs to be `::core::marker::PhantomPinned` or
138             // `::std::marker::PhantomPinned`.
139             if path.leading_colon.is_some() && path.segments.len() != 3 {
140                 return false;
141             }
142             let expected: Vec<&[&str]> = vec![&["PhantomPinned"], &["marker"], &["core", "std"]];
143             for (actual, expected) in path.segments.iter().rev().zip(expected) {
144                 if !actual.arguments.is_empty() || expected.iter().all(|e| actual.ident != e) {
145                     return false;
146                 }
147             }
148             true
149         }
150         _ => false,
151     }
152 }
153 
154 fn generate_unpin_impl(
155     ident: &Ident,
156     generics: &Generics,
157     fields: &[FieldInfo<'_>],
158 ) -> TokenStream {
159     let (_, ty_generics, _) = generics.split_for_impl();
160     let mut generics_with_pin_lt = generics.clone();
161     generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
162     generics_with_pin_lt.make_where_clause();
163     let (
164         impl_generics_with_pin_lt,
165         ty_generics_with_pin_lt,
166         Some(WhereClause {
167             where_token,
168             predicates,
169         }),
170     ) = generics_with_pin_lt.split_for_impl()
171     else {
172         unreachable!()
173     };
174     let pinned_fields = fields.iter().filter(|f| f.pinned).map(|f| f.field);
175     quote! {
176         // This struct will be used for the unpin analysis. It is needed, because only structurally
177         // pinned fields are relevant whether the struct should implement `Unpin`.
178         #[allow(dead_code)] // The fields below are never used.
179         struct __Unpin #generics_with_pin_lt
180         #where_token
181             #predicates
182         {
183             __phantom_pin: ::pin_init::__internal::PhantomInvariantLifetime<'__pin>,
184             __phantom: ::pin_init::__internal::PhantomInvariant<#ident #ty_generics>,
185             #(#pinned_fields),*
186         }
187 
188         #[doc(hidden)]
189         impl #impl_generics_with_pin_lt ::core::marker::Unpin for #ident #ty_generics
190         #where_token
191             __Unpin #ty_generics_with_pin_lt: ::core::marker::Unpin,
192             #predicates
193         {}
194     }
195 }
196 
197 fn generate_drop_impl(ident: &Ident, generics: &Generics, args: Args) -> TokenStream {
198     let (impl_generics, ty_generics, whr) = generics.split_for_impl();
199     let has_pinned_drop = matches!(args, Args::PinnedDrop(_));
200     // We need to disallow normal `Drop` implementation, the exact behavior depends on whether
201     // `PinnedDrop` was specified in `args`.
202     if has_pinned_drop {
203         // When `PinnedDrop` was specified we just implement `Drop` and delegate.
204         quote! {
205             impl #impl_generics ::core::ops::Drop for #ident #ty_generics
206                 #whr
207             {
208                 fn drop(&mut self) {
209                     // SAFETY: Since this is a destructor, `self` will not move after this function
210                     // terminates, since it is inaccessible.
211                     let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) };
212                     // SAFETY: Since this is a drop function, we can create this token to call the
213                     // pinned destructor of this type.
214                     let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() };
215                     ::pin_init::PinnedDrop::drop(pinned, token);
216                 }
217             }
218         }
219     } else {
220         // When no `PinnedDrop` was specified, then we have to prevent implementing drop.
221         quote! {
222             // We prevent this by creating a trait that will be implemented for all types implementing
223             // `Drop`. Additionally we will implement this trait for the struct leading to a conflict,
224             // if it also implements `Drop`
225             trait MustNotImplDrop {}
226             #[expect(drop_bounds)]
227             impl<T: ::core::ops::Drop + ?::core::marker::Sized> MustNotImplDrop for T {}
228             impl #impl_generics MustNotImplDrop for #ident #ty_generics
229                 #whr
230             {}
231             // We also take care to prevent users from writing a useless `PinnedDrop` implementation.
232             // They might implement `PinnedDrop` correctly for the struct, but forget to give
233             // `PinnedDrop` as the parameter to `#[pin_data]`.
234             #[expect(non_camel_case_types)]
235             trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {}
236             impl<T: ::pin_init::PinnedDrop + ?::core::marker::Sized>
237                 UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {}
238             impl #impl_generics
239                 UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for #ident #ty_generics
240                 #whr
241             {}
242         }
243     }
244 }
245 
246 fn generate_projections(
247     vis: &Visibility,
248     ident: &Ident,
249     generics: &Generics,
250     fields: &[FieldInfo<'_>],
251 ) -> TokenStream {
252     let (impl_generics, ty_generics, _) = generics.split_for_impl();
253     let mut generics_with_pin_lt = generics.clone();
254     generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
255     let (_, ty_generics_with_pin_lt, whr) = generics_with_pin_lt.split_for_impl();
256     let projection = format_ident!("{ident}Projection");
257     let this = format_ident!("this");
258 
259     let (fields_decl, fields_proj): (Vec<_>, Vec<_>) = fields
260         .iter()
261         .map(|field| {
262             let Field {
263                 vis,
264                 ident,
265                 ty,
266                 attrs,
267                 ..
268             } = &field.field;
269 
270             let mut no_doc_attrs = attrs.clone();
271             no_doc_attrs.retain(|a| !a.path().is_ident("doc"));
272             let ident = ident
273                 .as_ref()
274                 .expect("only structs with named fields are supported");
275             if field.pinned {
276                 (
277                     quote!(
278                         #(#attrs)*
279                         #vis #ident: ::core::pin::Pin<&'__pin mut #ty>,
280                     ),
281                     quote!(
282                         #(#no_doc_attrs)*
283                         // SAFETY: this field is structurally pinned.
284                         #ident: unsafe { ::core::pin::Pin::new_unchecked(&mut #this.#ident) },
285                     ),
286                 )
287             } else {
288                 (
289                     quote!(
290                         #(#attrs)*
291                         #vis #ident: &'__pin mut #ty,
292                     ),
293                     quote!(
294                         #(#no_doc_attrs)*
295                         #ident: &mut #this.#ident,
296                     ),
297                 )
298             }
299         })
300         .collect();
301     let structurally_pinned_fields_docs = fields
302         .iter()
303         .filter(|f| f.pinned)
304         .map(|f| format!(" - `{}`", f.field.ident.as_ref().unwrap()));
305     let not_structurally_pinned_fields_docs = fields
306         .iter()
307         .filter(|f| !f.pinned)
308         .map(|f| format!(" - `{}`", f.field.ident.as_ref().unwrap()));
309     let docs = format!(" Pin-projections of [`{ident}`]");
310     quote! {
311         #[doc = #docs]
312         #[allow(dead_code)]
313         #[doc(hidden)]
314         #vis struct #projection #generics_with_pin_lt
315             #whr
316         {
317             #(#fields_decl)*
318             ___pin_phantom_data: ::core::marker::PhantomData<&'__pin mut ()>,
319         }
320 
321         impl #impl_generics #ident #ty_generics
322             #whr
323         {
324             /// Pin-projects all fields of `Self`.
325             ///
326             /// These fields are structurally pinned:
327             #(#[doc = #structurally_pinned_fields_docs])*
328             ///
329             /// These fields are **not** structurally pinned:
330             #(#[doc = #not_structurally_pinned_fields_docs])*
331             #[inline]
332             #vis fn project<'__pin>(
333                 self: ::core::pin::Pin<&'__pin mut Self>,
334             ) -> #projection #ty_generics_with_pin_lt {
335                 // SAFETY: we only give access to `&mut` for fields not structurally pinned.
336                 let #this = unsafe { ::core::pin::Pin::get_unchecked_mut(self) };
337                 #projection {
338                     #(#fields_proj)*
339                     ___pin_phantom_data: ::core::marker::PhantomData,
340                 }
341             }
342         }
343     }
344 }
345 
346 fn generate_the_pin_data(
347     vis: &Visibility,
348     struct_name: &Ident,
349     generics: &Generics,
350     fields: &[FieldInfo<'_>],
351 ) -> TokenStream {
352     let (impl_generics, ty_generics, whr) = generics.split_for_impl();
353 
354     // For every field, we create an initializing projection function according to its projection
355     // type. If a field is structurally pinned, then it must be initialized via `PinInit`, if it is
356     // not structurally pinned, then it can be initialized via `Init`.
357     //
358     // The functions are `unsafe` to prevent accidentally calling them.
359     let field_accessors = fields
360         .iter()
361         .map(|f| {
362             let Field {
363                 vis,
364                 ident,
365                 ty,
366                 attrs,
367                 ..
368             } = f.field;
369 
370             let field_name = ident
371                 .as_ref()
372                 .expect("only structs with named fields are supported");
373             let project_ident = format_ident!("__project_{field_name}");
374             let (init_ty, init_fn, project_ty, project_body, pin_safety) = if f.pinned {
375                 (
376                     quote!(PinInit),
377                     quote!(__pinned_init),
378                     quote!(::core::pin::Pin<&'__slot mut #ty>),
379                     // SAFETY: this field is structurally pinned.
380                     quote!(unsafe { ::core::pin::Pin::new_unchecked(slot) }),
381                     quote!(
382                         /// - `slot` will not move until it is dropped, i.e. it will be pinned.
383                     ),
384                 )
385             } else {
386                 (
387                     quote!(Init),
388                     quote!(__init),
389                     quote!(&'__slot mut #ty),
390                     quote!(slot),
391                     quote!(),
392                 )
393             };
394             let slot_safety = format!(
395                 " `slot` points at the field `{field_name}` inside of `{struct_name}`, which is pinned.",
396             );
397             quote! {
398                 /// # Safety
399                 ///
400                 /// - `slot` is a valid pointer to uninitialized memory.
401                 /// - the caller does not touch `slot` when `Err` is returned, they are only
402                 ///   permitted to deallocate.
403                 #pin_safety
404                 #(#attrs)*
405                 #vis unsafe fn #field_name<E>(
406                     self,
407                     slot: *mut #ty,
408                     init: impl ::pin_init::#init_ty<#ty, E>,
409                 ) -> ::core::result::Result<(), E> {
410                     // SAFETY: this function has the same safety requirements as the __init function
411                     // called below.
412                     unsafe { ::pin_init::#init_ty::#init_fn(init, slot) }
413                 }
414 
415                 /// # Safety
416                 ///
417                 #[doc = #slot_safety]
418                 #(#attrs)*
419                 #vis unsafe fn #project_ident<'__slot>(
420                     self,
421                     slot: &'__slot mut #ty,
422                 ) -> #project_ty {
423                     #project_body
424                 }
425             }
426         })
427         .collect::<TokenStream>();
428     quote! {
429         // We declare this struct which will host all of the projection function for our type. It
430         // will be invariant over all generic parameters which are inherited from the struct.
431         #[doc(hidden)]
432         #vis struct __ThePinData #generics
433             #whr
434         {
435             __phantom: ::pin_init::__internal::PhantomInvariant<#struct_name #ty_generics>,
436         }
437 
438         impl #impl_generics ::core::clone::Clone for __ThePinData #ty_generics
439             #whr
440         {
441             fn clone(&self) -> Self { *self }
442         }
443 
444         impl #impl_generics ::core::marker::Copy for __ThePinData #ty_generics
445             #whr
446         {}
447 
448         #[allow(dead_code)] // Some functions might never be used and private.
449         #[expect(clippy::missing_safety_doc)]
450         impl #impl_generics __ThePinData #ty_generics
451             #whr
452         {
453             #field_accessors
454         }
455 
456         // SAFETY: We have added the correct projection functions above to `__ThePinData` and
457         // we also use the least restrictive generics possible.
458         unsafe impl #impl_generics ::pin_init::__internal::HasPinData for #struct_name #ty_generics
459             #whr
460         {
461             type PinData = __ThePinData #ty_generics;
462 
463             unsafe fn __pin_data() -> Self::PinData {
464                 __ThePinData { __phantom: ::pin_init::__internal::PhantomInvariant::new() }
465             }
466         }
467 
468         // SAFETY: TODO
469         unsafe impl #impl_generics ::pin_init::__internal::PinData for __ThePinData #ty_generics
470             #whr
471         {
472             type Datee = #struct_name #ty_generics;
473         }
474     }
475 }
476 
477 struct SelfReplacer(PathSegment);
478 
479 impl VisitMut for SelfReplacer {
480     fn visit_path_mut(&mut self, i: &mut syn::Path) {
481         if i.is_ident("Self") {
482             let span = i.span();
483             let seg = &self.0;
484             *i = parse_quote_spanned!(span=> #seg);
485         } else {
486             syn::visit_mut::visit_path_mut(self, i);
487         }
488     }
489 
490     fn visit_path_segment_mut(&mut self, seg: &mut PathSegment) {
491         if seg.ident == "Self" {
492             let span = seg.span();
493             let this = &self.0;
494             *seg = parse_quote_spanned!(span=> #this);
495         } else {
496             syn::visit_mut::visit_path_segment_mut(self, seg);
497         }
498     }
499 
500     fn visit_item_mut(&mut self, _: &mut Item) {
501         // Do not descend into items, since items reset/change what `Self` refers to.
502     }
503 }
504