xref: /linux/rust/zerocopy-derive/derive/from_bytes.rs (revision b437b3832874d4df88195d31b9052417674ffaed)
1 use proc_macro2::{Span, TokenStream};
2 use syn::{
3     parse_quote, Data, DataEnum, DataStruct, DataUnion, Error, Expr, ExprLit, ExprUnary, Lit, UnOp,
4     WherePredicate,
5 };
6 
7 use crate::{
8     derive::try_from_bytes::derive_try_from_bytes,
9     repr::{CompoundRepr, EnumRepr, Repr, Spanned},
10     util::{enum_size_from_repr, Ctx, FieldBounds, ImplBlockBuilder, Trait, TraitBound},
11 };
12 /// Returns `Ok(index)` if variant `index` of the enum has a discriminant of
13 /// zero. If `Err(bool)` is returned, the boolean is true if the enum has
14 /// unknown discriminants (e.g. discriminants set to const expressions which we
15 /// can't evaluate in a proc macro). If the enum has unknown discriminants, then
16 /// it might have a zero variant that we just can't detect.
17 pub(crate) fn find_zero_variant(enm: &DataEnum) -> Result<usize, bool> {
18     // Discriminants can be anywhere in the range [i128::MIN, u128::MAX] because
19     // the discriminant type may be signed or unsigned. Since we only care about
20     // tracking the discriminant when it's less than or equal to zero, we can
21     // avoid u128 -> i128 conversions and bounds checking by making the "next
22     // discriminant" value implicitly negative.
23     // Technically 64 bits is enough, but 128 is better for future compatibility
24     // with https://github.com/rust-lang/rust/issues/56071
25     let mut next_negative_discriminant = Some(0);
26 
27     // Sometimes we encounter explicit discriminants that we can't know the
28     // value of (e.g. a constant expression that requires evaluation). These
29     // could evaluate to zero or a negative number, but we can't assume that
30     // they do (no false positives allowed!). So we treat them like strictly-
31     // positive values that can't result in any zero variants, and track whether
32     // we've encountered any unknown discriminants.
33     let mut has_unknown_discriminants = false;
34 
35     for (i, v) in enm.variants.iter().enumerate() {
36         match v.discriminant.as_ref() {
37             // Implicit discriminant
38             None => {
39                 match next_negative_discriminant.as_mut() {
40                     Some(0) => return Ok(i),
41                     // n is nonzero so subtraction is always safe
42                     Some(n) => *n -= 1,
43                     None => (),
44                 }
45             }
46             // Explicit positive discriminant
47             Some((_, Expr::Lit(ExprLit { lit: Lit::Int(int), .. }))) => {
48                 match int.base10_parse::<u128>().ok() {
49                     Some(0) => return Ok(i),
50                     Some(_) => next_negative_discriminant = None,
51                     None => {
52                         // Numbers should never fail to parse, but just in case:
53                         has_unknown_discriminants = true;
54                         next_negative_discriminant = None;
55                     }
56                 }
57             }
58             // Explicit negative discriminant
59             Some((_, Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }))) => match &**expr {
60                 Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => {
61                     match int.base10_parse::<u128>().ok() {
62                         Some(0) => return Ok(i),
63                         // x is nonzero so subtraction is always safe
64                         Some(x) => next_negative_discriminant = Some(x - 1),
65                         None => {
66                             // Numbers should never fail to parse, but just in
67                             // case:
68                             has_unknown_discriminants = true;
69                             next_negative_discriminant = None;
70                         }
71                     }
72                 }
73                 // Unknown negative discriminant (e.g. const repr)
74                 _ => {
75                     has_unknown_discriminants = true;
76                     next_negative_discriminant = None;
77                 }
78             },
79             // Unknown discriminant (e.g. const expr)
80             _ => {
81                 has_unknown_discriminants = true;
82                 next_negative_discriminant = None;
83             }
84         }
85     }
86 
87     Err(has_unknown_discriminants)
88 }
89 pub(crate) fn derive_from_zeros(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
90     let try_from_bytes = derive_try_from_bytes(ctx, top_level)?;
91     let from_zeros = match &ctx.ast.data {
92         Data::Struct(strct) => derive_from_zeros_struct(ctx, strct),
93         Data::Enum(enm) => derive_from_zeros_enum(ctx, enm)?,
94         Data::Union(unn) => derive_from_zeros_union(ctx, unn),
95     };
96     Ok(IntoIterator::into_iter([try_from_bytes, from_zeros]).collect())
97 }
98 pub(crate) fn derive_from_bytes(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
99     let from_zeros = derive_from_zeros(ctx, top_level)?;
100     let from_bytes = match &ctx.ast.data {
101         Data::Struct(strct) => derive_from_bytes_struct(ctx, strct),
102         Data::Enum(enm) => derive_from_bytes_enum(ctx, enm)?,
103         Data::Union(unn) => derive_from_bytes_union(ctx, unn),
104     };
105 
106     Ok(IntoIterator::into_iter([from_zeros, from_bytes]).collect())
107 }
108 fn derive_from_zeros_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
109     ImplBlockBuilder::new(ctx, strct, Trait::FromZeros, FieldBounds::ALL_SELF).build()
110 }
111 fn derive_from_zeros_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
112     let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
113 
114     // We don't actually care what the repr is; we just care that it's one of
115     // the allowed ones.
116     match repr {
117         Repr::Compound(Spanned { t: CompoundRepr::C | CompoundRepr::Primitive(_), span: _ }, _) => {
118         }
119         Repr::Transparent(_) | Repr::Compound(Spanned { t: CompoundRepr::Rust, span: _ }, _) => {
120             return ctx.error_or_skip(
121                 Error::new(
122                     Span::call_site(),
123                     "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout",
124                 ),
125             );
126         }
127     }
128 
129     let zero_variant = match find_zero_variant(enm) {
130         Ok(index) => enm.variants.iter().nth(index).unwrap(),
131         // Has unknown variants
132         Err(true) => {
133             return ctx.error_or_skip(Error::new_spanned(
134                 &ctx.ast,
135                 "FromZeros only supported on enums with a variant that has a discriminant of `0`\n\
136                 help: This enum has discriminants which are not literal integers. One of those may \
137                 define or imply which variant has a discriminant of zero. Use a literal integer to \
138                 define or imply the variant with a discriminant of zero.",
139             ));
140         }
141         // Does not have unknown variants
142         Err(false) => {
143             return ctx.error_or_skip(Error::new_spanned(
144                 &ctx.ast,
145                 "FromZeros only supported on enums with a variant that has a discriminant of `0`",
146             ));
147         }
148     };
149 
150     let zerocopy_crate = &ctx.zerocopy_crate;
151     let explicit_bounds = zero_variant
152         .fields
153         .iter()
154         .map(|field| {
155             let ty = &field.ty;
156             parse_quote! { #ty: #zerocopy_crate::FromZeros }
157         })
158         .collect::<Vec<WherePredicate>>();
159 
160     Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromZeros, FieldBounds::Explicit(explicit_bounds))
161         .build())
162 }
163 fn derive_from_zeros_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
164     let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
165     ImplBlockBuilder::new(ctx, unn, Trait::FromZeros, field_type_trait_bounds).build()
166 }
167 fn derive_from_bytes_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
168     ImplBlockBuilder::new(ctx, strct, Trait::FromBytes, FieldBounds::ALL_SELF).build()
169 }
170 fn derive_from_bytes_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
171     let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
172 
173     let variants_required = 1usize << enum_size_from_repr(&repr)?;
174     if enm.variants.len() != variants_required {
175         return ctx.error_or_skip(Error::new_spanned(
176             &ctx.ast,
177             format!(
178                 "FromBytes only supported on {} enum with {} variants",
179                 repr.repr_type_name(),
180                 variants_required
181             ),
182         ));
183     }
184 
185     Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromBytes, FieldBounds::ALL_SELF).build())
186 }
187 fn derive_from_bytes_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
188     let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
189     ImplBlockBuilder::new(ctx, unn, Trait::FromBytes, field_type_trait_bounds).build()
190 }
191