xref: /linux/rust/zerocopy-derive/derive/from_bytes.rs (revision b079329b8691768962aa514b8f8c9077ca352459)
1*5f85604cSMiguel Ojeda // SPDX-License-Identifier: (BSD-2-Clause OR Apache-2.0) OR MIT
2*5f85604cSMiguel Ojeda 
3b437b383SMiguel Ojeda use proc_macro2::{Span, TokenStream};
4b437b383SMiguel Ojeda use syn::{
5b437b383SMiguel Ojeda     parse_quote, Data, DataEnum, DataStruct, DataUnion, Error, Expr, ExprLit, ExprUnary, Lit, UnOp,
6b437b383SMiguel Ojeda     WherePredicate,
7b437b383SMiguel Ojeda };
8b437b383SMiguel Ojeda 
9b437b383SMiguel Ojeda use crate::{
10b437b383SMiguel Ojeda     derive::try_from_bytes::derive_try_from_bytes,
11b437b383SMiguel Ojeda     repr::{CompoundRepr, EnumRepr, Repr, Spanned},
12b437b383SMiguel Ojeda     util::{enum_size_from_repr, Ctx, FieldBounds, ImplBlockBuilder, Trait, TraitBound},
13b437b383SMiguel Ojeda };
14b437b383SMiguel Ojeda /// Returns `Ok(index)` if variant `index` of the enum has a discriminant of
15b437b383SMiguel Ojeda /// zero. If `Err(bool)` is returned, the boolean is true if the enum has
16b437b383SMiguel Ojeda /// unknown discriminants (e.g. discriminants set to const expressions which we
17b437b383SMiguel Ojeda /// can't evaluate in a proc macro). If the enum has unknown discriminants, then
18b437b383SMiguel Ojeda /// it might have a zero variant that we just can't detect.
19b437b383SMiguel Ojeda pub(crate) fn find_zero_variant(enm: &DataEnum) -> Result<usize, bool> {
20b437b383SMiguel Ojeda     // Discriminants can be anywhere in the range [i128::MIN, u128::MAX] because
21b437b383SMiguel Ojeda     // the discriminant type may be signed or unsigned. Since we only care about
22b437b383SMiguel Ojeda     // tracking the discriminant when it's less than or equal to zero, we can
23b437b383SMiguel Ojeda     // avoid u128 -> i128 conversions and bounds checking by making the "next
24b437b383SMiguel Ojeda     // discriminant" value implicitly negative.
25b437b383SMiguel Ojeda     // Technically 64 bits is enough, but 128 is better for future compatibility
26b437b383SMiguel Ojeda     // with https://github.com/rust-lang/rust/issues/56071
27b437b383SMiguel Ojeda     let mut next_negative_discriminant = Some(0);
28b437b383SMiguel Ojeda 
29b437b383SMiguel Ojeda     // Sometimes we encounter explicit discriminants that we can't know the
30b437b383SMiguel Ojeda     // value of (e.g. a constant expression that requires evaluation). These
31b437b383SMiguel Ojeda     // could evaluate to zero or a negative number, but we can't assume that
32b437b383SMiguel Ojeda     // they do (no false positives allowed!). So we treat them like strictly-
33b437b383SMiguel Ojeda     // positive values that can't result in any zero variants, and track whether
34b437b383SMiguel Ojeda     // we've encountered any unknown discriminants.
35b437b383SMiguel Ojeda     let mut has_unknown_discriminants = false;
36b437b383SMiguel Ojeda 
37b437b383SMiguel Ojeda     for (i, v) in enm.variants.iter().enumerate() {
38b437b383SMiguel Ojeda         match v.discriminant.as_ref() {
39b437b383SMiguel Ojeda             // Implicit discriminant
40b437b383SMiguel Ojeda             None => {
41b437b383SMiguel Ojeda                 match next_negative_discriminant.as_mut() {
42b437b383SMiguel Ojeda                     Some(0) => return Ok(i),
43b437b383SMiguel Ojeda                     // n is nonzero so subtraction is always safe
44b437b383SMiguel Ojeda                     Some(n) => *n -= 1,
45b437b383SMiguel Ojeda                     None => (),
46b437b383SMiguel Ojeda                 }
47b437b383SMiguel Ojeda             }
48b437b383SMiguel Ojeda             // Explicit positive discriminant
49b437b383SMiguel Ojeda             Some((_, Expr::Lit(ExprLit { lit: Lit::Int(int), .. }))) => {
50b437b383SMiguel Ojeda                 match int.base10_parse::<u128>().ok() {
51b437b383SMiguel Ojeda                     Some(0) => return Ok(i),
52b437b383SMiguel Ojeda                     Some(_) => next_negative_discriminant = None,
53b437b383SMiguel Ojeda                     None => {
54b437b383SMiguel Ojeda                         // Numbers should never fail to parse, but just in case:
55b437b383SMiguel Ojeda                         has_unknown_discriminants = true;
56b437b383SMiguel Ojeda                         next_negative_discriminant = None;
57b437b383SMiguel Ojeda                     }
58b437b383SMiguel Ojeda                 }
59b437b383SMiguel Ojeda             }
60b437b383SMiguel Ojeda             // Explicit negative discriminant
61b437b383SMiguel Ojeda             Some((_, Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }))) => match &**expr {
62b437b383SMiguel Ojeda                 Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => {
63b437b383SMiguel Ojeda                     match int.base10_parse::<u128>().ok() {
64b437b383SMiguel Ojeda                         Some(0) => return Ok(i),
65b437b383SMiguel Ojeda                         // x is nonzero so subtraction is always safe
66b437b383SMiguel Ojeda                         Some(x) => next_negative_discriminant = Some(x - 1),
67b437b383SMiguel Ojeda                         None => {
68b437b383SMiguel Ojeda                             // Numbers should never fail to parse, but just in
69b437b383SMiguel Ojeda                             // case:
70b437b383SMiguel Ojeda                             has_unknown_discriminants = true;
71b437b383SMiguel Ojeda                             next_negative_discriminant = None;
72b437b383SMiguel Ojeda                         }
73b437b383SMiguel Ojeda                     }
74b437b383SMiguel Ojeda                 }
75b437b383SMiguel Ojeda                 // Unknown negative discriminant (e.g. const repr)
76b437b383SMiguel Ojeda                 _ => {
77b437b383SMiguel Ojeda                     has_unknown_discriminants = true;
78b437b383SMiguel Ojeda                     next_negative_discriminant = None;
79b437b383SMiguel Ojeda                 }
80b437b383SMiguel Ojeda             },
81b437b383SMiguel Ojeda             // Unknown discriminant (e.g. const expr)
82b437b383SMiguel Ojeda             _ => {
83b437b383SMiguel Ojeda                 has_unknown_discriminants = true;
84b437b383SMiguel Ojeda                 next_negative_discriminant = None;
85b437b383SMiguel Ojeda             }
86b437b383SMiguel Ojeda         }
87b437b383SMiguel Ojeda     }
88b437b383SMiguel Ojeda 
89b437b383SMiguel Ojeda     Err(has_unknown_discriminants)
90b437b383SMiguel Ojeda }
91b437b383SMiguel Ojeda pub(crate) fn derive_from_zeros(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
92b437b383SMiguel Ojeda     let try_from_bytes = derive_try_from_bytes(ctx, top_level)?;
93b437b383SMiguel Ojeda     let from_zeros = match &ctx.ast.data {
94b437b383SMiguel Ojeda         Data::Struct(strct) => derive_from_zeros_struct(ctx, strct),
95b437b383SMiguel Ojeda         Data::Enum(enm) => derive_from_zeros_enum(ctx, enm)?,
96b437b383SMiguel Ojeda         Data::Union(unn) => derive_from_zeros_union(ctx, unn),
97b437b383SMiguel Ojeda     };
98b437b383SMiguel Ojeda     Ok(IntoIterator::into_iter([try_from_bytes, from_zeros]).collect())
99b437b383SMiguel Ojeda }
100b437b383SMiguel Ojeda pub(crate) fn derive_from_bytes(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> {
101b437b383SMiguel Ojeda     let from_zeros = derive_from_zeros(ctx, top_level)?;
102b437b383SMiguel Ojeda     let from_bytes = match &ctx.ast.data {
103b437b383SMiguel Ojeda         Data::Struct(strct) => derive_from_bytes_struct(ctx, strct),
104b437b383SMiguel Ojeda         Data::Enum(enm) => derive_from_bytes_enum(ctx, enm)?,
105b437b383SMiguel Ojeda         Data::Union(unn) => derive_from_bytes_union(ctx, unn),
106b437b383SMiguel Ojeda     };
107b437b383SMiguel Ojeda 
108b437b383SMiguel Ojeda     Ok(IntoIterator::into_iter([from_zeros, from_bytes]).collect())
109b437b383SMiguel Ojeda }
110b437b383SMiguel Ojeda fn derive_from_zeros_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
111b437b383SMiguel Ojeda     ImplBlockBuilder::new(ctx, strct, Trait::FromZeros, FieldBounds::ALL_SELF).build()
112b437b383SMiguel Ojeda }
113b437b383SMiguel Ojeda fn derive_from_zeros_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
114b437b383SMiguel Ojeda     let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
115b437b383SMiguel Ojeda 
116b437b383SMiguel Ojeda     // We don't actually care what the repr is; we just care that it's one of
117b437b383SMiguel Ojeda     // the allowed ones.
118b437b383SMiguel Ojeda     match repr {
119b437b383SMiguel Ojeda         Repr::Compound(Spanned { t: CompoundRepr::C | CompoundRepr::Primitive(_), span: _ }, _) => {
120b437b383SMiguel Ojeda         }
121b437b383SMiguel Ojeda         Repr::Transparent(_) | Repr::Compound(Spanned { t: CompoundRepr::Rust, span: _ }, _) => {
122b437b383SMiguel Ojeda             return ctx.error_or_skip(
123b437b383SMiguel Ojeda                 Error::new(
124b437b383SMiguel Ojeda                     Span::call_site(),
125b437b383SMiguel Ojeda                     "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout",
126b437b383SMiguel Ojeda                 ),
127b437b383SMiguel Ojeda             );
128b437b383SMiguel Ojeda         }
129b437b383SMiguel Ojeda     }
130b437b383SMiguel Ojeda 
131b437b383SMiguel Ojeda     let zero_variant = match find_zero_variant(enm) {
132b437b383SMiguel Ojeda         Ok(index) => enm.variants.iter().nth(index).unwrap(),
133b437b383SMiguel Ojeda         // Has unknown variants
134b437b383SMiguel Ojeda         Err(true) => {
135b437b383SMiguel Ojeda             return ctx.error_or_skip(Error::new_spanned(
136b437b383SMiguel Ojeda                 &ctx.ast,
137b437b383SMiguel Ojeda                 "FromZeros only supported on enums with a variant that has a discriminant of `0`\n\
138b437b383SMiguel Ojeda                 help: This enum has discriminants which are not literal integers. One of those may \
139b437b383SMiguel Ojeda                 define or imply which variant has a discriminant of zero. Use a literal integer to \
140b437b383SMiguel Ojeda                 define or imply the variant with a discriminant of zero.",
141b437b383SMiguel Ojeda             ));
142b437b383SMiguel Ojeda         }
143b437b383SMiguel Ojeda         // Does not have unknown variants
144b437b383SMiguel Ojeda         Err(false) => {
145b437b383SMiguel Ojeda             return ctx.error_or_skip(Error::new_spanned(
146b437b383SMiguel Ojeda                 &ctx.ast,
147b437b383SMiguel Ojeda                 "FromZeros only supported on enums with a variant that has a discriminant of `0`",
148b437b383SMiguel Ojeda             ));
149b437b383SMiguel Ojeda         }
150b437b383SMiguel Ojeda     };
151b437b383SMiguel Ojeda 
152b437b383SMiguel Ojeda     let zerocopy_crate = &ctx.zerocopy_crate;
153b437b383SMiguel Ojeda     let explicit_bounds = zero_variant
154b437b383SMiguel Ojeda         .fields
155b437b383SMiguel Ojeda         .iter()
156b437b383SMiguel Ojeda         .map(|field| {
157b437b383SMiguel Ojeda             let ty = &field.ty;
158b437b383SMiguel Ojeda             parse_quote! { #ty: #zerocopy_crate::FromZeros }
159b437b383SMiguel Ojeda         })
160b437b383SMiguel Ojeda         .collect::<Vec<WherePredicate>>();
161b437b383SMiguel Ojeda 
162b437b383SMiguel Ojeda     Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromZeros, FieldBounds::Explicit(explicit_bounds))
163b437b383SMiguel Ojeda         .build())
164b437b383SMiguel Ojeda }
165b437b383SMiguel Ojeda fn derive_from_zeros_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
166b437b383SMiguel Ojeda     let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
167b437b383SMiguel Ojeda     ImplBlockBuilder::new(ctx, unn, Trait::FromZeros, field_type_trait_bounds).build()
168b437b383SMiguel Ojeda }
169b437b383SMiguel Ojeda fn derive_from_bytes_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream {
170b437b383SMiguel Ojeda     ImplBlockBuilder::new(ctx, strct, Trait::FromBytes, FieldBounds::ALL_SELF).build()
171b437b383SMiguel Ojeda }
172b437b383SMiguel Ojeda fn derive_from_bytes_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> {
173b437b383SMiguel Ojeda     let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?;
174b437b383SMiguel Ojeda 
175b437b383SMiguel Ojeda     let variants_required = 1usize << enum_size_from_repr(&repr)?;
176b437b383SMiguel Ojeda     if enm.variants.len() != variants_required {
177b437b383SMiguel Ojeda         return ctx.error_or_skip(Error::new_spanned(
178b437b383SMiguel Ojeda             &ctx.ast,
179b437b383SMiguel Ojeda             format!(
180b437b383SMiguel Ojeda                 "FromBytes only supported on {} enum with {} variants",
181b437b383SMiguel Ojeda                 repr.repr_type_name(),
182b437b383SMiguel Ojeda                 variants_required
183b437b383SMiguel Ojeda             ),
184b437b383SMiguel Ojeda         ));
185b437b383SMiguel Ojeda     }
186b437b383SMiguel Ojeda 
187b437b383SMiguel Ojeda     Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromBytes, FieldBounds::ALL_SELF).build())
188b437b383SMiguel Ojeda }
189b437b383SMiguel Ojeda fn derive_from_bytes_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream {
190b437b383SMiguel Ojeda     let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]);
191b437b383SMiguel Ojeda     ImplBlockBuilder::new(ctx, unn, Trait::FromBytes, field_type_trait_bounds).build()
192b437b383SMiguel Ojeda }
193