xref: /linux/rust/pin-init/internal/src/pin_data.rs (revision 57b0a0d7e5a063edceb50bffa648b49591112896)
1 // SPDX-License-Identifier: Apache-2.0 OR MIT
2 
3 use proc_macro2::TokenStream;
4 use quote::{format_ident, quote};
5 use syn::{
6     parse::{End, Nothing, Parse},
7     parse_quote, parse_quote_spanned,
8     spanned::Spanned,
9     visit_mut::VisitMut,
10     Field, Generics, Ident, Item, PathSegment, Type, TypePath, Visibility, WhereClause,
11 };
12 
13 use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
14 
/// Custom keywords accepted in the arguments of the `#[pin_data]` attribute.
pub(crate) mod kw {
    // Defines the `PinnedDrop` keyword, so that `#[pin_data(PinnedDrop)]` can be parsed by
    // `Args::parse`.
    syn::custom_keyword!(PinnedDrop);
}
18 
/// The parsed arguments of the `#[pin_data]` attribute.
pub(crate) enum Args {
    /// No arguments were given: `#[pin_data]`.
    Nothing(Nothing),
    /// `#[pin_data(PinnedDrop)]`: a `Drop` implementation is generated that delegates to the
    /// struct's `PinnedDrop` implementation.
    // The keyword token itself is never read in this file (only the variant is matched on),
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    PinnedDrop(kw::PinnedDrop),
}
24 
25 impl Parse for Args {
26     fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
27         let lh = input.lookahead1();
28         if lh.peek(End) {
29             input.parse().map(Self::Nothing)
30         } else if lh.peek(kw::PinnedDrop) {
31             input.parse().map(Self::PinnedDrop)
32         } else {
33             Err(lh.error())
34         }
35     }
36 }
37 
/// A field of the annotated struct together with its pinning status.
struct FieldInfo<'a> {
    /// The field, with any `#[pin]` attributes already removed.
    field: &'a Field,
    /// Whether the field is structurally pinned, i.e. it carried a `#[pin]` attribute.
    pinned: bool,
}
42 
43 pub(crate) fn pin_data(
44     args: Args,
45     input: Item,
46     dcx: &mut DiagCtxt,
47 ) -> Result<TokenStream, ErrorGuaranteed> {
48     let mut struct_ = match input {
49         Item::Struct(struct_) => struct_,
50         Item::Enum(enum_) => {
51             return Err(dcx.error(
52                 enum_.enum_token,
53                 "`#[pin_data]` only supports structs for now",
54             ));
55         }
56         Item::Union(union) => {
57             return Err(dcx.error(
58                 union.union_token,
59                 "`#[pin_data]` only supports structs for now",
60             ));
61         }
62         rest => {
63             return Err(dcx.error(
64                 rest,
65                 "`#[pin_data]` can only be applied to struct, enum and union definitions",
66             ));
67         }
68     };
69 
70     // The generics might contain the `Self` type. Since this macro will define a new type with the
71     // same generics and bounds, this poses a problem: `Self` will refer to the new type as opposed
72     // to this struct definition. Therefore we have to replace `Self` with the concrete name.
73     let mut replacer = {
74         let name = &struct_.ident;
75         let (_, ty_generics, _) = struct_.generics.split_for_impl();
76         SelfReplacer(parse_quote!(#name #ty_generics))
77     };
78     replacer.visit_generics_mut(&mut struct_.generics);
79     replacer.visit_fields_mut(&mut struct_.fields);
80 
81     let fields: Vec<FieldInfo<'_>> = struct_
82         .fields
83         .iter_mut()
84         .map(|field| {
85             let len = field.attrs.len();
86             field.attrs.retain(|a| !a.path().is_ident("pin"));
87             let pinned = len != field.attrs.len();
88 
89             FieldInfo {
90                 field: &*field,
91                 pinned,
92             }
93         })
94         .collect();
95 
96     for field in &fields {
97         let ident = field.field.ident.as_ref().unwrap();
98 
99         if !field.pinned && is_phantom_pinned(&field.field.ty) {
100             dcx.warn(
101                 field.field,
102                 format!(
103                     "The field `{ident}` of type `PhantomPinned` only has an effect \
104                     if it has the `#[pin]` attribute",
105                 ),
106             );
107         }
108     }
109 
110     let unpin_impl = generate_unpin_impl(&struct_.ident, &struct_.generics, &fields);
111     let drop_impl = generate_drop_impl(&struct_.ident, &struct_.generics, args);
112     let projections =
113         generate_projections(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
114     let the_pin_data =
115         generate_the_pin_data(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
116 
117     Ok(quote! {
118         #struct_
119         #projections
120         // We put the rest into this const item, because it then will not be accessible to anything
121         // outside.
122         const _: () = {
123             #the_pin_data
124             #unpin_impl
125             #drop_impl
126         };
127     })
128 }
129 
130 fn is_phantom_pinned(ty: &Type) -> bool {
131     match ty {
132         Type::Path(TypePath { qself: None, path }) => {
133             // Cannot possibly refer to `PhantomPinned` (except alias, but that's on the user).
134             if path.segments.len() > 3 {
135                 return false;
136             }
137             // If there is a `::`, then the path needs to be `::core::marker::PhantomPinned` or
138             // `::std::marker::PhantomPinned`.
139             if path.leading_colon.is_some() && path.segments.len() != 3 {
140                 return false;
141             }
142             let expected: Vec<&[&str]> = vec![&["PhantomPinned"], &["marker"], &["core", "std"]];
143             for (actual, expected) in path.segments.iter().rev().zip(expected) {
144                 if !actual.arguments.is_empty() || expected.iter().all(|e| actual.ident != e) {
145                     return false;
146                 }
147             }
148             true
149         }
150         _ => false,
151     }
152 }
153 
154 fn generate_unpin_impl(
155     ident: &Ident,
156     generics: &Generics,
157     fields: &[FieldInfo<'_>],
158 ) -> TokenStream {
159     let (_, ty_generics, _) = generics.split_for_impl();
160     let mut generics_with_pin_lt = generics.clone();
161     generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
162     generics_with_pin_lt.make_where_clause();
163     let (
164         impl_generics_with_pin_lt,
165         ty_generics_with_pin_lt,
166         Some(WhereClause {
167             where_token,
168             predicates,
169         }),
170     ) = generics_with_pin_lt.split_for_impl()
171     else {
172         unreachable!()
173     };
174     let pinned_fields = fields.iter().filter(|f| f.pinned).map(|f| f.field);
175     quote! {
176         // This struct will be used for the unpin analysis. It is needed, because only structurally
177         // pinned fields are relevant whether the struct should implement `Unpin`.
178         #[allow(dead_code)] // The fields below are never used.
179         struct __Unpin #generics_with_pin_lt
180         #where_token
181             #predicates
182         {
183             __phantom_pin: ::pin_init::__internal::PhantomInvariantLifetime<'__pin>,
184             __phantom: ::pin_init::__internal::PhantomInvariant<#ident #ty_generics>,
185             #(#pinned_fields),*
186         }
187 
188         #[doc(hidden)]
189         impl #impl_generics_with_pin_lt ::core::marker::Unpin for #ident #ty_generics
190         #where_token
191             __Unpin #ty_generics_with_pin_lt: ::core::marker::Unpin,
192             #predicates
193         {}
194     }
195 }
196 
197 fn generate_drop_impl(ident: &Ident, generics: &Generics, args: Args) -> TokenStream {
198     let (impl_generics, ty_generics, whr) = generics.split_for_impl();
199     let has_pinned_drop = matches!(args, Args::PinnedDrop(_));
200     // We need to disallow normal `Drop` implementation, the exact behavior depends on whether
201     // `PinnedDrop` was specified in `args`.
202     if has_pinned_drop {
203         // When `PinnedDrop` was specified we just implement `Drop` and delegate.
204         quote! {
205             impl #impl_generics ::core::ops::Drop for #ident #ty_generics
206                 #whr
207             {
208                 fn drop(&mut self) {
209                     // SAFETY: Since this is a destructor, `self` will not move after this function
210                     // terminates, since it is inaccessible.
211                     let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) };
212                     // SAFETY: Since this is a drop function, we can create this token to call the
213                     // pinned destructor of this type.
214                     let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() };
215                     ::pin_init::PinnedDrop::drop(pinned, token);
216                 }
217             }
218         }
219     } else {
220         // When no `PinnedDrop` was specified, then we have to prevent implementing drop.
221         quote! {
222             // We prevent this by creating a trait that will be implemented for all types implementing
223             // `Drop`. Additionally we will implement this trait for the struct leading to a conflict,
224             // if it also implements `Drop`
225             trait MustNotImplDrop {}
226             #[expect(drop_bounds)]
227             impl<T: ::core::ops::Drop + ?::core::marker::Sized> MustNotImplDrop for T {}
228             impl #impl_generics MustNotImplDrop for #ident #ty_generics
229                 #whr
230             {}
231             // We also take care to prevent users from writing a useless `PinnedDrop` implementation.
232             // They might implement `PinnedDrop` correctly for the struct, but forget to give
233             // `PinnedDrop` as the parameter to `#[pin_data]`.
234             #[expect(non_camel_case_types)]
235             trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {}
236             impl<T: ::pin_init::PinnedDrop + ?::core::marker::Sized>
237                 UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {}
238             impl #impl_generics
239                 UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for #ident #ty_generics
240                 #whr
241             {}
242         }
243     }
244 }
245 
246 fn generate_projections(
247     vis: &Visibility,
248     ident: &Ident,
249     generics: &Generics,
250     fields: &[FieldInfo<'_>],
251 ) -> TokenStream {
252     let (impl_generics, ty_generics, _) = generics.split_for_impl();
253     let mut generics_with_pin_lt = generics.clone();
254     generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
255     let (_, ty_generics_with_pin_lt, whr) = generics_with_pin_lt.split_for_impl();
256     let projection = format_ident!("{ident}Projection");
257     let this = format_ident!("this");
258 
259     let (fields_decl, fields_proj): (Vec<_>, Vec<_>) = fields
260         .iter()
261         .map(|field| {
262             let Field {
263                 vis,
264                 ident,
265                 ty,
266                 attrs,
267                 ..
268             } = &field.field;
269 
270             let mut no_doc_attrs = attrs.clone();
271             no_doc_attrs.retain(|a| !a.path().is_ident("doc"));
272             let ident = ident
273                 .as_ref()
274                 .expect("only structs with named fields are supported");
275             if field.pinned {
276                 (
277                     quote!(
278                         #(#attrs)*
279                         #vis #ident: ::core::pin::Pin<&'__pin mut #ty>,
280                     ),
281                     quote!(
282                         #(#no_doc_attrs)*
283                         // SAFETY: this field is structurally pinned.
284                         #ident: unsafe { ::core::pin::Pin::new_unchecked(&mut #this.#ident) },
285                     ),
286                 )
287             } else {
288                 (
289                     quote!(
290                         #(#attrs)*
291                         #vis #ident: &'__pin mut #ty,
292                     ),
293                     quote!(
294                         #(#no_doc_attrs)*
295                         #ident: &mut #this.#ident,
296                     ),
297                 )
298             }
299         })
300         .collect();
301     let structurally_pinned_fields_docs = fields
302         .iter()
303         .filter(|f| f.pinned)
304         .map(|f| format!(" - `{}`", f.field.ident.as_ref().unwrap()));
305     let not_structurally_pinned_fields_docs = fields
306         .iter()
307         .filter(|f| !f.pinned)
308         .map(|f| format!(" - `{}`", f.field.ident.as_ref().unwrap()));
309     let docs = format!(" Pin-projections of [`{ident}`]");
310     quote! {
311         #[doc = #docs]
312         #[allow(dead_code)]
313         #[doc(hidden)]
314         #vis struct #projection #generics_with_pin_lt
315             #whr
316         {
317             #(#fields_decl)*
318             ___pin_phantom_data: ::core::marker::PhantomData<&'__pin mut ()>,
319         }
320 
321         impl #impl_generics #ident #ty_generics
322             #whr
323         {
324             /// Pin-projects all fields of `Self`.
325             ///
326             /// These fields are structurally pinned:
327             #(#[doc = #structurally_pinned_fields_docs])*
328             ///
329             /// These fields are **not** structurally pinned:
330             #(#[doc = #not_structurally_pinned_fields_docs])*
331             #[inline]
332             #vis fn project<'__pin>(
333                 self: ::core::pin::Pin<&'__pin mut Self>,
334             ) -> #projection #ty_generics_with_pin_lt {
335                 // SAFETY: we only give access to `&mut` for fields not structurally pinned.
336                 let #this = unsafe { ::core::pin::Pin::get_unchecked_mut(self) };
337                 #projection {
338                     #(#fields_proj)*
339                     ___pin_phantom_data: ::core::marker::PhantomData,
340                 }
341             }
342         }
343     }
344 }
345 
346 fn generate_the_pin_data(
347     vis: &Visibility,
348     struct_name: &Ident,
349     generics: &Generics,
350     fields: &[FieldInfo<'_>],
351 ) -> TokenStream {
352     let (impl_generics, ty_generics, whr) = generics.split_for_impl();
353 
354     // For every field, we create an initializing projection function according to its projection
355     // type. If a field is structurally pinned, then it must be initialized via `PinInit`, if it is
356     // not structurally pinned, then it can be initialized via `Init`.
357     //
358     // The functions are `unsafe` to prevent accidentally calling them.
359     let field_accessors = fields
360         .iter()
361         .map(|f| {
362             let Field {
363                 vis,
364                 ident,
365                 ty,
366                 attrs,
367                 ..
368             } = f.field;
369 
370             let field_name = ident
371                 .as_ref()
372                 .expect("only structs with named fields are supported");
373             let project_ident = format_ident!("__project_{field_name}");
374             let (init_ty, init_fn, pin_marker, pin_safety) = if f.pinned {
375                 (
376                     quote!(PinInit),
377                     quote!(__pinned_init),
378                     quote!(Pinned),
379                     quote!(
380                         /// - `slot` will not move until it is dropped, i.e. it will be pinned.
381                     ),
382                 )
383             } else {
384                 (quote!(Init), quote!(__init), quote!(Unpinned), quote!())
385             };
386             quote! {
387                 /// # Safety
388                 ///
389                 /// - `slot` is a valid pointer to uninitialized memory.
390                 /// - the caller does not touch `slot` when `Err` is returned, they are only
391                 ///   permitted to deallocate.
392                 #pin_safety
393                 #(#attrs)*
394                 #vis unsafe fn #field_name<E>(
395                     self,
396                     slot: *mut #ty,
397                     init: impl ::pin_init::#init_ty<#ty, E>,
398                 ) -> ::core::result::Result<(), E> {
399                     // SAFETY: this function has the same safety requirements as the __init function
400                     // called below.
401                     unsafe { ::pin_init::#init_ty::#init_fn(init, slot) }
402                 }
403 
404                 /// # Safety
405                 ///
406                 /// - `slot` points to a `#ident` field of a pinned struct that this `__ThePinData`
407                 ///    describes.
408                 /// - `slot` is valid and properly aligned.
409                 /// - `*slot` is initialized, and the ownership is transferred to the returned
410                 ///    guard.
411                 #(#attrs)*
412                 #vis unsafe fn #project_ident(
413                     self,
414                     slot: *mut #ty,
415                 ) -> ::pin_init::__internal::DropGuard<::pin_init::__internal::#pin_marker, #ty> {
416                     // SAFETY:
417                     // - If `#pin_marker` is `Pinned`, the corresponding field is structurally
418                     //   pinned.
419                     // - Other safety requirements follows the safety requirement.
420                     unsafe { ::pin_init::__internal::DropGuard::new(slot) }
421                 }
422             }
423         })
424         .collect::<TokenStream>();
425     quote! {
426         // We declare this struct which will host all of the projection function for our type. It
427         // will be invariant over all generic parameters which are inherited from the struct.
428         #[doc(hidden)]
429         #vis struct __ThePinData #generics
430             #whr
431         {
432             __phantom: ::pin_init::__internal::PhantomInvariant<#struct_name #ty_generics>,
433         }
434 
435         impl #impl_generics ::core::clone::Clone for __ThePinData #ty_generics
436             #whr
437         {
438             fn clone(&self) -> Self { *self }
439         }
440 
441         impl #impl_generics ::core::marker::Copy for __ThePinData #ty_generics
442             #whr
443         {}
444 
445         #[allow(dead_code)] // Some functions might never be used and private.
446         #[expect(clippy::missing_safety_doc)]
447         impl #impl_generics __ThePinData #ty_generics
448             #whr
449         {
450             /// Type inference helper function.
451             #[inline(always)]
452             #vis fn __make_closure<__F, __E>(self, f: __F) -> __F
453             where
454                 __F: FnOnce(*mut #struct_name #ty_generics) ->
455                     ::core::result::Result<::pin_init::__internal::InitOk, __E>,
456             {
457                 f
458             }
459 
460             #field_accessors
461         }
462 
463         // SAFETY: We have added the correct projection functions above to `__ThePinData` and
464         // we also use the least restrictive generics possible.
465         unsafe impl #impl_generics ::pin_init::__internal::HasPinData for #struct_name #ty_generics
466             #whr
467         {
468             type PinData = __ThePinData #ty_generics;
469 
470             unsafe fn __pin_data() -> Self::PinData {
471                 __ThePinData { __phantom: ::pin_init::__internal::PhantomInvariant::new() }
472             }
473         }
474     }
475 }
476 
477 struct SelfReplacer(PathSegment);
478 
479 impl VisitMut for SelfReplacer {
480     fn visit_path_mut(&mut self, i: &mut syn::Path) {
481         if i.is_ident("Self") {
482             let span = i.span();
483             let seg = &self.0;
484             *i = parse_quote_spanned!(span=> #seg);
485         } else {
486             syn::visit_mut::visit_path_mut(self, i);
487         }
488     }
489 
490     fn visit_path_segment_mut(&mut self, seg: &mut PathSegment) {
491         if seg.ident == "Self" {
492             let span = seg.span();
493             let this = &self.0;
494             *seg = parse_quote_spanned!(span=> #this);
495         } else {
496             syn::visit_mut::visit_path_segment_mut(self, seg);
497         }
498     }
499 
500     fn visit_item_mut(&mut self, _: &mut Item) {
501         // Do not descend into items, since items reset/change what `Self` refers to.
502     }
503 }
504