1 // SPDX-License-Identifier: (BSD-2-Clause OR Apache-2.0) OR MIT 2 3 use proc_macro2::{Span, TokenStream}; 4 use syn::{ 5 parse_quote, Data, DataEnum, DataStruct, DataUnion, Error, Expr, ExprLit, ExprUnary, Lit, UnOp, 6 WherePredicate, 7 }; 8 9 use crate::{ 10 derive::try_from_bytes::derive_try_from_bytes, 11 repr::{CompoundRepr, EnumRepr, Repr, Spanned}, 12 util::{enum_size_from_repr, Ctx, FieldBounds, ImplBlockBuilder, Trait, TraitBound}, 13 }; 14 /// Returns `Ok(index)` if variant `index` of the enum has a discriminant of 15 /// zero. If `Err(bool)` is returned, the boolean is true if the enum has 16 /// unknown discriminants (e.g. discriminants set to const expressions which we 17 /// can't evaluate in a proc macro). If the enum has unknown discriminants, then 18 /// it might have a zero variant that we just can't detect. 19 pub(crate) fn find_zero_variant(enm: &DataEnum) -> Result<usize, bool> { 20 // Discriminants can be anywhere in the range [i128::MIN, u128::MAX] because 21 // the discriminant type may be signed or unsigned. Since we only care about 22 // tracking the discriminant when it's less than or equal to zero, we can 23 // avoid u128 -> i128 conversions and bounds checking by making the "next 24 // discriminant" value implicitly negative. 25 // Technically 64 bits is enough, but 128 is better for future compatibility 26 // with https://github.com/rust-lang/rust/issues/56071 27 let mut next_negative_discriminant = Some(0); 28 29 // Sometimes we encounter explicit discriminants that we can't know the 30 // value of (e.g. a constant expression that requires evaluation). These 31 // could evaluate to zero or a negative number, but we can't assume that 32 // they do (no false positives allowed!). So we treat them like strictly- 33 // positive values that can't result in any zero variants, and track whether 34 // we've encountered any unknown discriminants. 35 let mut has_unknown_discriminants = false; 36 37 for (i, v) in enm.variants.iter().enumerate() { 38 match v.discriminant.as_ref() { 39 // Implicit discriminant 40 None => { 41 match next_negative_discriminant.as_mut() { 42 Some(0) => return Ok(i), 43 // n is nonzero so subtraction is always safe 44 Some(n) => *n -= 1, 45 None => (), 46 } 47 } 48 // Explicit positive discriminant 49 Some((_, Expr::Lit(ExprLit { lit: Lit::Int(int), .. }))) => { 50 match int.base10_parse::<u128>().ok() { 51 Some(0) => return Ok(i), 52 Some(_) => next_negative_discriminant = None, 53 None => { 54 // Numbers should never fail to parse, but just in case: 55 has_unknown_discriminants = true; 56 next_negative_discriminant = None; 57 } 58 } 59 } 60 // Explicit negative discriminant 61 Some((_, Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }))) => match &**expr { 62 Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => { 63 match int.base10_parse::<u128>().ok() { 64 Some(0) => return Ok(i), 65 // x is nonzero so subtraction is always safe 66 Some(x) => next_negative_discriminant = Some(x - 1), 67 None => { 68 // Numbers should never fail to parse, but just in 69 // case: 70 has_unknown_discriminants = true; 71 next_negative_discriminant = None; 72 } 73 } 74 } 75 // Unknown negative discriminant (e.g. const repr) 76 _ => { 77 has_unknown_discriminants = true; 78 next_negative_discriminant = None; 79 } 80 }, 81 // Unknown discriminant (e.g. const expr) 82 _ => { 83 has_unknown_discriminants = true; 84 next_negative_discriminant = None; 85 } 86 } 87 } 88 89 Err(has_unknown_discriminants) 90 } 91 pub(crate) fn derive_from_zeros(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> { 92 let try_from_bytes = derive_try_from_bytes(ctx, top_level)?; 93 let from_zeros = match &ctx.ast.data { 94 Data::Struct(strct) => derive_from_zeros_struct(ctx, strct), 95 Data::Enum(enm) => derive_from_zeros_enum(ctx, enm)?, 96 Data::Union(unn) => derive_from_zeros_union(ctx, unn), 97 }; 98 Ok(IntoIterator::into_iter([try_from_bytes, from_zeros]).collect()) 99 } 100 pub(crate) fn derive_from_bytes(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> { 101 let from_zeros = derive_from_zeros(ctx, top_level)?; 102 let from_bytes = match &ctx.ast.data { 103 Data::Struct(strct) => derive_from_bytes_struct(ctx, strct), 104 Data::Enum(enm) => derive_from_bytes_enum(ctx, enm)?, 105 Data::Union(unn) => derive_from_bytes_union(ctx, unn), 106 }; 107 108 Ok(IntoIterator::into_iter([from_zeros, from_bytes]).collect()) 109 } 110 fn derive_from_zeros_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream { 111 ImplBlockBuilder::new(ctx, strct, Trait::FromZeros, FieldBounds::ALL_SELF).build() 112 } 113 fn derive_from_zeros_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> { 114 let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?; 115 116 // We don't actually care what the repr is; we just care that it's one of 117 // the allowed ones. 118 match repr { 119 Repr::Compound(Spanned { t: CompoundRepr::C | CompoundRepr::Primitive(_), span: _ }, _) => { 120 } 121 Repr::Transparent(_) | Repr::Compound(Spanned { t: CompoundRepr::Rust, span: _ }, _) => { 122 return ctx.error_or_skip( 123 Error::new( 124 Span::call_site(), 125 "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout", 126 ), 127 ); 128 } 129 } 130 131 let zero_variant = match find_zero_variant(enm) { 132 Ok(index) => enm.variants.iter().nth(index).unwrap(), 133 // Has unknown variants 134 Err(true) => { 135 return ctx.error_or_skip(Error::new_spanned( 136 &ctx.ast, 137 "FromZeros only supported on enums with a variant that has a discriminant of `0`\n\ 138 help: This enum has discriminants which are not literal integers. One of those may \ 139 define or imply which variant has a discriminant of zero. Use a literal integer to \ 140 define or imply the variant with a discriminant of zero.", 141 )); 142 } 143 // Does not have unknown variants 144 Err(false) => { 145 return ctx.error_or_skip(Error::new_spanned( 146 &ctx.ast, 147 "FromZeros only supported on enums with a variant that has a discriminant of `0`", 148 )); 149 } 150 }; 151 152 let zerocopy_crate = &ctx.zerocopy_crate; 153 let explicit_bounds = zero_variant 154 .fields 155 .iter() 156 .map(|field| { 157 let ty = &field.ty; 158 parse_quote! { #ty: #zerocopy_crate::FromZeros } 159 }) 160 .collect::<Vec<WherePredicate>>(); 161 162 Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromZeros, FieldBounds::Explicit(explicit_bounds)) 163 .build()) 164 } 165 fn derive_from_zeros_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream { 166 let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]); 167 ImplBlockBuilder::new(ctx, unn, Trait::FromZeros, field_type_trait_bounds).build() 168 } 169 fn derive_from_bytes_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream { 170 ImplBlockBuilder::new(ctx, strct, Trait::FromBytes, FieldBounds::ALL_SELF).build() 171 } 172 fn derive_from_bytes_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> { 173 let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?; 174 175 let variants_required = 1usize << enum_size_from_repr(&repr)?; 176 if enm.variants.len() != variants_required { 177 return ctx.error_or_skip(Error::new_spanned( 178 &ctx.ast, 179 format!( 180 "FromBytes only supported on {} enum with {} variants", 181 repr.repr_type_name(), 182 variants_required 183 ), 184 )); 185 } 186 187 Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromBytes, FieldBounds::ALL_SELF).build()) 188 } 189 fn derive_from_bytes_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream { 190 let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]); 191 ImplBlockBuilder::new(ctx, unn, Trait::FromBytes, field_type_trait_bounds).build() 192 } 193