xref: /linux/rust/pin-init/internal/src/zeroable.rs (revision 37a93dd5c49b5fda807fd204edf2547c3493319c)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 use proc_macro2::TokenStream;
4 use quote::quote;
5 use syn::{parse_quote, Data, DeriveInput, Field, Fields};
6 
7 use crate::{diagnostics::ErrorGuaranteed, DiagCtxt};
8 
9 pub(crate) fn derive(
10     input: DeriveInput,
11     dcx: &mut DiagCtxt,
12 ) -> Result<TokenStream, ErrorGuaranteed> {
13     let fields = match input.data {
14         Data::Struct(data_struct) => data_struct.fields,
15         Data::Union(data_union) => Fields::Named(data_union.fields),
16         Data::Enum(data_enum) => {
17             return Err(dcx.error(data_enum.enum_token, "cannot derive `Zeroable` for an enum"));
18         }
19     };
20     let name = input.ident;
21     let mut generics = input.generics;
22     for param in generics.type_params_mut() {
23         param.bounds.insert(0, parse_quote!(::pin_init::Zeroable));
24     }
25     let (impl_gen, ty_gen, whr) = generics.split_for_impl();
26     let field_type = fields.iter().map(|field| &field.ty);
27     Ok(quote! {
28         // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
29         #[automatically_derived]
30         unsafe impl #impl_gen ::pin_init::Zeroable for #name #ty_gen
31             #whr
32         {}
33         const _: () = {
34             fn assert_zeroable<T: ?::core::marker::Sized + ::pin_init::Zeroable>() {}
35             fn ensure_zeroable #impl_gen ()
36                 #whr
37             {
38                 #(
39                     assert_zeroable::<#field_type>();
40                 )*
41             }
42         };
43     })
44 }
45 
46 pub(crate) fn maybe_derive(
47     input: DeriveInput,
48     dcx: &mut DiagCtxt,
49 ) -> Result<TokenStream, ErrorGuaranteed> {
50     let fields = match input.data {
51         Data::Struct(data_struct) => data_struct.fields,
52         Data::Union(data_union) => Fields::Named(data_union.fields),
53         Data::Enum(data_enum) => {
54             return Err(dcx.error(data_enum.enum_token, "cannot derive `Zeroable` for an enum"));
55         }
56     };
57     let name = input.ident;
58     let mut generics = input.generics;
59     for param in generics.type_params_mut() {
60         param.bounds.insert(0, parse_quote!(::pin_init::Zeroable));
61     }
62     for Field { ty, .. } in fields {
63         generics
64             .make_where_clause()
65             .predicates
66             // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds`
67             // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>.
68             .push(parse_quote!(#ty: for<'__dummy> ::pin_init::Zeroable));
69     }
70     let (impl_gen, ty_gen, whr) = generics.split_for_impl();
71     Ok(quote! {
72         // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
73         #[automatically_derived]
74         unsafe impl #impl_gen ::pin_init::Zeroable for #name #ty_gen
75             #whr
76         {}
77     })
78 }
79