xref: /linux/rust/zerocopy-derive/derive/from_bytes.rs (revision 056a5087d87ead77dedbe9cf5bde53b7cd4b4651)
1 // SPDX-License-Identifier: (BSD-2-Clause OR Apache-2.0) OR MIT
2 
3 use proc_macro2::{Span, TokenStream};
4 use syn::{
5     parse_quote, Data, DataEnum, DataStruct, DataUnion, Error, Expr, ExprLit, ExprUnary, Lit, UnOp,
6     WherePredicate,
7 };
8 
9 use crate::{
10     derive::try_from_bytes::derive_try_from_bytes,
11     repr::{CompoundRepr, EnumRepr, Repr, Spanned},
12     util::{enum_size_from_repr, Ctx, FieldBounds, ImplBlockBuilder, Trait, TraitBound},
13 };
14 /// Returns `Ok(index)` if variant `index` of the enum has a discriminant of
15 /// zero. If `Err(bool)` is returned, the boolean is true if the enum has
16 /// unknown discriminants (e.g. discriminants set to const expressions which we
17 /// can't evaluate in a proc macro). If the enum has unknown discriminants, then
18 /// it might have a zero variant that we just can't detect.
19 pub(crate) fn find_zero_variant(enm: &DataEnum) -> Result<usize, bool> {
20     // Discriminants can be anywhere in the range [i128::MIN, u128::MAX] because
21     // the discriminant type may be signed or unsigned. Since we only care about
22     // tracking the discriminant when it's less than or equal to zero, we can
23     // avoid u128 -> i128 conversions and bounds checking by making the "next
24     // discriminant" value implicitly negative.
25     // Technically 64 bits is enough, but 128 is better for future compatibility
26     // with https://github.com/rust-lang/rust/issues/56071
27     let mut next_negative_discriminant = Some(0);
28 
29     // Sometimes we encounter explicit discriminants that we can't know the
30     // value of (e.g. a constant expression that requires evaluation). These
31     // could evaluate to zero or a negative number, but we can't assume that
32     // they do (no false positives allowed!). So we treat them like strictly-
33     // positive values that can't result in any zero variants, and track whether
34     // we've encountered any unknown discriminants.
35     let mut has_unknown_discriminants = false;
36 
37     for (i, v) in enm.variants.iter().enumerate() {
38         match v.discriminant.as_ref() {
39             // Implicit discriminant
40             None => {
41                 match next_negative_discriminant.as_mut() {
42                     Some(0) => return Ok(i),
43                     // n is nonzero so subtraction is always safe
44                     Some(n) => *n -= 1,
45                     None => (),
46                 }
47             }
48             // Explicit positive discriminant
49             Some((_, Expr::Lit(ExprLit { lit: Lit::Int(int), .. }))) => {
50                 match int.base10_parse::<u128>().ok() {
51                     Some(0) => return Ok(i),
52                     Some(_) => next_negative_discriminant = None,
53                     None => {
54                         // Numbers should never fail to parse, but just in case:
55                         has_unknown_discriminants = true;
56                         next_negative_discriminant = None;
57                     }
58                 }
59             }
60             // Explicit negative discriminant
61             Some((_, Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }))) => match &**expr {
62                 Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => {
63                     match int.base10_parse::<u128>().ok() {
64                         Some(0) => return Ok(i),
65                         // x is nonzero so subtraction is always safe
66                         Some(x) => next_negative_discriminant = Some(x - 1),
67                         None => {
68                             // Numbers should never fail to parse, but just in
69                             // case:
70                             has_unknown_discriminants = true;
71                             next_negative_discriminant = None;
72                         }
73                     }
74                 }
75                 // Unknown negative discriminant (e.g. const repr)
76                 _ => {
77                     has_unknown_discriminants = true;
78                     next_negative_discriminant = None;
79                 }
80             },
81             // Unknown discriminant (e.g. const expr)
82             _ => {
83                 has_unknown_discriminants = true;
84                 next_negative_discriminant = None;
85             }
86         }
87     }
88 
89     Err(has_unknown_discriminants)
90 }
91 pub(crate) fn derive_from_zeros(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
92     let try_from_bytes = derive_try_from_bytes(ctx, top_level)?;
93     let from_zeros = match &ctx.ast.data {
94         Data::Struct(strct) => derive_from_zeros_struct(ctx, strct),
95         Data::Enum(enm) => derive_from_zeros_enum(ctx, enm)?,
96         Data::Union(unn) => derive_from_zeros_union(ctx, unn),
97     };
98     Ok(IntoIterator::into_iter([try_from_bytes, from_zeros]).collect())
99 }
100 pub(crate) fn derive_from_bytes(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
101     let from_zeros = derive_from_zeros(ctx, top_level)?;
102     let from_bytes = match &ctx.ast.data {
103         Data::Struct(strct) => derive_from_bytes_struct(ctx, strct),
104         Data::Enum(enm) => derive_from_bytes_enum(ctx, enm)?,
105         Data::Union(unn) => derive_from_bytes_union(ctx, unn),
106     };
107 
108     Ok(IntoIterator::into_iter([from_zeros, from_bytes]).collect())
109 }
110 fn derive_from_zeros_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
111     ImplBlockBuilder::new(ctx, strct, Trait::FromZeros, FieldBounds::ALL_SELF).build()
112 }
113 fn derive_from_zeros_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
114     let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
115 
116     // We don't actually care what the repr is; we just care that it's one of
117     // the allowed ones.
118     match repr {
119         Repr::Compound(Spanned { t: CompoundRepr::C | CompoundRepr::Primitive(_), span: _ }, _) => {
120         }
121         Repr::Transparent(_) | Repr::Compound(Spanned { t: CompoundRepr::Rust, span: _ }, _) => {
122             return ctx.error_or_skip(
123                 Error::new(
124                     Span::call_site(),
125                     "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout",
126                 ),
127             );
128         }
129     }
130 
131     let zero_variant = match find_zero_variant(enm) {
132         Ok(index) => enm.variants.iter().nth(index).unwrap(),
133         // Has unknown variants
134         Err(true) => {
135             return ctx.error_or_skip(Error::new_spanned(
136                 &ctx.ast,
137                 "FromZeros only supported on enums with a variant that has a discriminant of `0`\n\
138                 help: This enum has discriminants which are not literal integers. One of those may \
139                 define or imply which variant has a discriminant of zero. Use a literal integer to \
140                 define or imply the variant with a discriminant of zero.",
141             ));
142         }
143         // Does not have unknown variants
144         Err(false) => {
145             return ctx.error_or_skip(Error::new_spanned(
146                 &ctx.ast,
147                 "FromZeros only supported on enums with a variant that has a discriminant of `0`",
148             ));
149         }
150     };
151 
152     let zerocopy_crate = &ctx.zerocopy_crate;
153     let explicit_bounds = zero_variant
154         .fields
155         .iter()
156         .map(|field| {
157             let ty = &field.ty;
158             parse_quote! { #ty: #zerocopy_crate::FromZeros }
159         })
160         .collect::<Vec<WherePredicate>>();
161 
162     Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromZeros, FieldBounds::Explicit(explicit_bounds))
163         .build())
164 }
165 fn derive_from_zeros_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
166     let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
167     ImplBlockBuilder::new(ctx, unn, Trait::FromZeros, field_type_trait_bounds).build()
168 }
169 fn derive_from_bytes_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
170     ImplBlockBuilder::new(ctx, strct, Trait::FromBytes, FieldBounds::ALL_SELF).build()
171 }
172 fn derive_from_bytes_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
173     let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
174 
175     let variants_required = 1usize << enum_size_from_repr(&repr)?;
176     if enm.variants.len() != variants_required {
177         return ctx.error_or_skip(Error::new_spanned(
178             &ctx.ast,
179             format!(
180                 "FromBytes only supported on {} enum with {} variants",
181                 repr.repr_type_name(),
182                 variants_required
183             ),
184         ));
185     }
186 
187     Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromBytes, FieldBounds::ALL_SELF).build())
188 }
189 fn derive_from_bytes_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
190     let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
191     ImplBlockBuilder::new(ctx, unn, Trait::FromBytes, field_type_trait_bounds).build()
192 }
193