1 use proc_macro2::{Span, TokenStream}; 2 use syn::{ 3 parse_quote, Data, DataEnum, DataStruct, DataUnion, Error, Expr, ExprLit, ExprUnary, Lit, UnOp, 4 WherePredicate, 5 }; 6 7 use crate::{ 8 derive::try_from_bytes::derive_try_from_bytes, 9 repr::{CompoundRepr, EnumRepr, Repr, Spanned}, 10 util::{enum_size_from_repr, Ctx, FieldBounds, ImplBlockBuilder, Trait, TraitBound}, 11 }; 12 /// Returns `Ok(index)` if variant `index` of the enum has a discriminant of 13 /// zero. If `Err(bool)` is returned, the boolean is true if the enum has 14 /// unknown discriminants (e.g. discriminants set to const expressions which we 15 /// can't evaluate in a proc macro). If the enum has unknown discriminants, then 16 /// it might have a zero variant that we just can't detect. 17 pub(crate) fn find_zero_variant(enm: &DataEnum) -> Result<usize, bool> { 18 // Discriminants can be anywhere in the range [i128::MIN, u128::MAX] because 19 // the discriminant type may be signed or unsigned. Since we only care about 20 // tracking the discriminant when it's less than or equal to zero, we can 21 // avoid u128 -> i128 conversions and bounds checking by making the "next 22 // discriminant" value implicitly negative. 23 // Technically 64 bits is enough, but 128 is better for future compatibility 24 // with https://github.com/rust-lang/rust/issues/56071 25 let mut next_negative_discriminant = Some(0); 26 27 // Sometimes we encounter explicit discriminants that we can't know the 28 // value of (e.g. a constant expression that requires evaluation). These 29 // could evaluate to zero or a negative number, but we can't assume that 30 // they do (no false positives allowed!). So we treat them like strictly- 31 // positive values that can't result in any zero variants, and track whether 32 // we've encountered any unknown discriminants. 33 let mut has_unknown_discriminants = false; 34 35 for (i, v) in enm.variants.iter().enumerate() { 36 match v.discriminant.as_ref() { 37 // Implicit discriminant 38 None => { 39 match next_negative_discriminant.as_mut() { 40 Some(0) => return Ok(i), 41 // n is nonzero so subtraction is always safe 42 Some(n) => *n -= 1, 43 None => (), 44 } 45 } 46 // Explicit positive discriminant 47 Some((_, Expr::Lit(ExprLit { lit: Lit::Int(int), .. }))) => { 48 match int.base10_parse::<u128>().ok() { 49 Some(0) => return Ok(i), 50 Some(_) => next_negative_discriminant = None, 51 None => { 52 // Numbers should never fail to parse, but just in case: 53 has_unknown_discriminants = true; 54 next_negative_discriminant = None; 55 } 56 } 57 } 58 // Explicit negative discriminant 59 Some((_, Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }))) => match &**expr { 60 Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => { 61 match int.base10_parse::<u128>().ok() { 62 Some(0) => return Ok(i), 63 // x is nonzero so subtraction is always safe 64 Some(x) => next_negative_discriminant = Some(x - 1), 65 None => { 66 // Numbers should never fail to parse, but just in 67 // case: 68 has_unknown_discriminants = true; 69 next_negative_discriminant = None; 70 } 71 } 72 } 73 // Unknown negative discriminant (e.g. const repr) 74 _ => { 75 has_unknown_discriminants = true; 76 next_negative_discriminant = None; 77 } 78 }, 79 // Unknown discriminant (e.g. const expr) 80 _ => { 81 has_unknown_discriminants = true; 82 next_negative_discriminant = None; 83 } 84 } 85 } 86 87 Err(has_unknown_discriminants) 88 } 89 pub(crate) fn derive_from_zeros(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> { 90 let try_from_bytes = derive_try_from_bytes(ctx, top_level)?; 91 let from_zeros = match &ctx.ast.data { 92 Data::Struct(strct) => derive_from_zeros_struct(ctx, strct), 93 Data::Enum(enm) => derive_from_zeros_enum(ctx, enm)?, 94 Data::Union(unn) => derive_from_zeros_union(ctx, unn), 95 }; 96 Ok(IntoIterator::into_iter([try_from_bytes, from_zeros]).collect()) 97 } 98 pub(crate) fn derive_from_bytes(ctx: &Ctx, top_level: Trait) -> Result<TokenStream, Error> { 99 let from_zeros = derive_from_zeros(ctx, top_level)?; 100 let from_bytes = match &ctx.ast.data { 101 Data::Struct(strct) => derive_from_bytes_struct(ctx, strct), 102 Data::Enum(enm) => derive_from_bytes_enum(ctx, enm)?, 103 Data::Union(unn) => derive_from_bytes_union(ctx, unn), 104 }; 105 106 Ok(IntoIterator::into_iter([from_zeros, from_bytes]).collect()) 107 } 108 fn derive_from_zeros_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream { 109 ImplBlockBuilder::new(ctx, strct, Trait::FromZeros, FieldBounds::ALL_SELF).build() 110 } 111 fn derive_from_zeros_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> { 112 let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?; 113 114 // We don't actually care what the repr is; we just care that it's one of 115 // the allowed ones. 116 match repr { 117 Repr::Compound(Spanned { t: CompoundRepr::C | CompoundRepr::Primitive(_), span: _ }, _) => { 118 } 119 Repr::Transparent(_) | Repr::Compound(Spanned { t: CompoundRepr::Rust, span: _ }, _) => { 120 return ctx.error_or_skip( 121 Error::new( 122 Span::call_site(), 123 "must have #[repr(C)] or #[repr(Int)] attribute in order to guarantee this type's memory layout", 124 ), 125 ); 126 } 127 } 128 129 let zero_variant = match find_zero_variant(enm) { 130 Ok(index) => enm.variants.iter().nth(index).unwrap(), 131 // Has unknown variants 132 Err(true) => { 133 return ctx.error_or_skip(Error::new_spanned( 134 &ctx.ast, 135 "FromZeros only supported on enums with a variant that has a discriminant of `0`\n\ 136 help: This enum has discriminants which are not literal integers. One of those may \ 137 define or imply which variant has a discriminant of zero. Use a literal integer to \ 138 define or imply the variant with a discriminant of zero.", 139 )); 140 } 141 // Does not have unknown variants 142 Err(false) => { 143 return ctx.error_or_skip(Error::new_spanned( 144 &ctx.ast, 145 "FromZeros only supported on enums with a variant that has a discriminant of `0`", 146 )); 147 } 148 }; 149 150 let zerocopy_crate = &ctx.zerocopy_crate; 151 let explicit_bounds = zero_variant 152 .fields 153 .iter() 154 .map(|field| { 155 let ty = &field.ty; 156 parse_quote! { #ty: #zerocopy_crate::FromZeros } 157 }) 158 .collect::<Vec<WherePredicate>>(); 159 160 Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromZeros, FieldBounds::Explicit(explicit_bounds)) 161 .build()) 162 } 163 fn derive_from_zeros_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream { 164 let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]); 165 ImplBlockBuilder::new(ctx, unn, Trait::FromZeros, field_type_trait_bounds).build() 166 } 167 fn derive_from_bytes_struct(ctx: &Ctx, strct: &DataStruct) -> TokenStream { 168 ImplBlockBuilder::new(ctx, strct, Trait::FromBytes, FieldBounds::ALL_SELF).build() 169 } 170 fn derive_from_bytes_enum(ctx: &Ctx, enm: &DataEnum) -> Result<TokenStream, Error> { 171 let repr = EnumRepr::from_attrs(&ctx.ast.attrs)?; 172 173 let variants_required = 1usize << enum_size_from_repr(&repr)?; 174 if enm.variants.len() != variants_required { 175 return ctx.error_or_skip(Error::new_spanned( 176 &ctx.ast, 177 format!( 178 "FromBytes only supported on {} enum with {} variants", 179 repr.repr_type_name(), 180 variants_required 181 ), 182 )); 183 } 184 185 Ok(ImplBlockBuilder::new(ctx, enm, Trait::FromBytes, FieldBounds::ALL_SELF).build()) 186 } 187 fn derive_from_bytes_union(ctx: &Ctx, unn: &DataUnion) -> TokenStream { 188 let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf]); 189 ImplBlockBuilder::new(ctx, unn, Trait::FromBytes, field_type_trait_bounds).build() 190 } 191