xref: /linux/rust/zerocopy-derive/util.rs (revision 5f85604cf0877b0369dfd68cd50cf61c0f134819)
1*5f85604cSMiguel Ojeda // SPDX-License-Identifier: (BSD-2-Clause OR Apache-2.0) OR MIT
2*5f85604cSMiguel Ojeda 
3b437b383SMiguel Ojeda // Copyright 2019 The Fuchsia Authors
4b437b383SMiguel Ojeda //
5b437b383SMiguel Ojeda // Licensed under a BSD-style license <LICENSE-BSD>, Apache License, Version 2.0
6b437b383SMiguel Ojeda // <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0>, or the MIT
7b437b383SMiguel Ojeda // license <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your option.
8b437b383SMiguel Ojeda // This file may not be copied, modified, or distributed except according to
9b437b383SMiguel Ojeda // those terms.
10b437b383SMiguel Ojeda 
11b437b383SMiguel Ojeda use std::num::NonZeroU32;
12b437b383SMiguel Ojeda 
13b437b383SMiguel Ojeda use proc_macro2::{Span, TokenStream};
14b437b383SMiguel Ojeda use quote::{quote, quote_spanned, ToTokens};
15b437b383SMiguel Ojeda use syn::{
16b437b383SMiguel Ojeda     parse_quote, spanned::Spanned as _, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error,
17b437b383SMiguel Ojeda     Expr, ExprLit, Field, GenericParam, Ident, Index, Lit, LitStr, Meta, Path, Type, Variant,
18b437b383SMiguel Ojeda     Visibility, WherePredicate,
19b437b383SMiguel Ojeda };
20b437b383SMiguel Ojeda 
21b437b383SMiguel Ojeda use crate::repr::{CompoundRepr, EnumRepr, PrimitiveRepr, Repr, Spanned};
22b437b383SMiguel Ojeda 
/// State shared by the code generators for one derive input: the parsed input
/// itself plus the options collected from its `#[zerocopy(...)]` attributes.
pub(crate) struct Ctx {
    // The parsed derive input.
    pub(crate) ast: DeriveInput,
    // Path used to refer to the zerocopy crate in generated code. This is
    // `::zerocopy` unless overridden with `#[zerocopy(crate = "...")]`.
    pub(crate) zerocopy_crate: Path,

    // `true` if the last `#[zerocopy(on_error = ...)]` attribute had the value
    // `"skip"`; `false` if it had the value `"fail"` or if none is provided.
    pub(crate) skip_on_error: bool,

    // The span of the last `#[zerocopy(on_error = ...)]` attribute, if any.
    pub(crate) on_error_span: Option<proc_macro2::Span>,
}
34b437b383SMiguel Ojeda 
35b437b383SMiguel Ojeda impl Ctx {
36b437b383SMiguel Ojeda     /// Attempt to extract a crate path from the provided attributes. Defaults to
37b437b383SMiguel Ojeda     /// `::zerocopy` if not found.
38b437b383SMiguel Ojeda     pub(crate) fn try_from_derive_input(ast: DeriveInput) -> Result<Self, Error> {
39b437b383SMiguel Ojeda         let mut path = parse_quote!(::zerocopy);
40b437b383SMiguel Ojeda         let mut skip_on_error = false;
41b437b383SMiguel Ojeda         let mut on_error_span = None;
42b437b383SMiguel Ojeda 
43b437b383SMiguel Ojeda         for attr in &ast.attrs {
44b437b383SMiguel Ojeda             if let Meta::List(ref meta_list) = attr.meta {
45b437b383SMiguel Ojeda                 if meta_list.path.is_ident("zerocopy") {
46b437b383SMiguel Ojeda                     attr.parse_nested_meta(|meta| {
47b437b383SMiguel Ojeda                         if meta.path.is_ident("crate") {
48b437b383SMiguel Ojeda                             let expr = meta.value().and_then(|value| value.parse());
49b437b383SMiguel Ojeda                             if let Ok(Expr::Lit(ExprLit { lit: Lit::Str(lit), .. })) = expr {
50b437b383SMiguel Ojeda                                 if let Ok(path_lit) = lit.parse::<Ident>() {
51b437b383SMiguel Ojeda                                     path = parse_quote!(::#path_lit);
52b437b383SMiguel Ojeda                                     return Ok(());
53b437b383SMiguel Ojeda                                 }
54b437b383SMiguel Ojeda                             }
55b437b383SMiguel Ojeda 
56b437b383SMiguel Ojeda                             return Err(Error::new(
57b437b383SMiguel Ojeda                                 Span::call_site(),
58b437b383SMiguel Ojeda                                 "`crate` attribute requires a path as the value",
59b437b383SMiguel Ojeda                             ));
60b437b383SMiguel Ojeda                         }
61b437b383SMiguel Ojeda 
62b437b383SMiguel Ojeda                         if meta.path.is_ident("on_error") {
63b437b383SMiguel Ojeda                             on_error_span = Some(meta.path.span());
64b437b383SMiguel Ojeda                             let value = meta.value()?;
65b437b383SMiguel Ojeda                             let s: LitStr = value.parse()?;
66b437b383SMiguel Ojeda                             match s.value().as_str() {
67b437b383SMiguel Ojeda                                 "skip" => skip_on_error = true,
68b437b383SMiguel Ojeda                                 "fail" => skip_on_error = false,
69b437b383SMiguel Ojeda                                 _ => return Err(Error::new(
70b437b383SMiguel Ojeda                                     s.span(),
71b437b383SMiguel Ojeda                                     "unrecognized value for `on_error` attribute from `zerocopy`; expected `skip` or `fail`",
72b437b383SMiguel Ojeda                                 )),
73b437b383SMiguel Ojeda                             }
74b437b383SMiguel Ojeda                             return Ok(());
75b437b383SMiguel Ojeda                         }
76b437b383SMiguel Ojeda 
77b437b383SMiguel Ojeda                         Err(Error::new(
78b437b383SMiguel Ojeda                             Span::call_site(),
79b437b383SMiguel Ojeda                             format!(
80b437b383SMiguel Ojeda                                 "unknown attribute encountered: {}",
81b437b383SMiguel Ojeda                                 meta.path.into_token_stream()
82b437b383SMiguel Ojeda                             ),
83b437b383SMiguel Ojeda                         ))
84b437b383SMiguel Ojeda                     })?;
85b437b383SMiguel Ojeda                 }
86b437b383SMiguel Ojeda             }
87b437b383SMiguel Ojeda         }
88b437b383SMiguel Ojeda 
89b437b383SMiguel Ojeda         Ok(Self { ast, zerocopy_crate: path, skip_on_error, on_error_span })
90b437b383SMiguel Ojeda     }
91b437b383SMiguel Ojeda 
92b437b383SMiguel Ojeda     pub(crate) fn with_input(&self, input: &DeriveInput) -> Self {
93b437b383SMiguel Ojeda         Self {
94b437b383SMiguel Ojeda             ast: input.clone(),
95b437b383SMiguel Ojeda             zerocopy_crate: self.zerocopy_crate.clone(),
96b437b383SMiguel Ojeda             skip_on_error: self.skip_on_error,
97b437b383SMiguel Ojeda             on_error_span: self.on_error_span,
98b437b383SMiguel Ojeda         }
99b437b383SMiguel Ojeda     }
100b437b383SMiguel Ojeda 
101b437b383SMiguel Ojeda     pub(crate) fn core_path(&self) -> TokenStream {
102b437b383SMiguel Ojeda         let zerocopy_crate = &self.zerocopy_crate;
103b437b383SMiguel Ojeda         quote!(#zerocopy_crate::util::macro_util::core_reexport)
104b437b383SMiguel Ojeda     }
105b437b383SMiguel Ojeda 
106b437b383SMiguel Ojeda     pub(crate) fn cfg_compile_error(&self) -> TokenStream {
107b437b383SMiguel Ojeda         // By checking both during the compilation of the proc macro *and* in
108b437b383SMiguel Ojeda         // the generated code, we ensure that `--cfg
109b437b383SMiguel Ojeda         // zerocopy_unstable_derive_on_error` need only be passed *either* when
110b437b383SMiguel Ojeda         // compiling this crate *or* when compiling the user's crate. The former
111b437b383SMiguel Ojeda         // is preferable, but in some situations (such as when cross-compiling
112b437b383SMiguel Ojeda         // using `cargo build --target`), it doesn't get propagated to this
113b437b383SMiguel Ojeda         // crate's build by default.
114b437b383SMiguel Ojeda         if cfg!(zerocopy_unstable_derive_on_error) {
115b437b383SMiguel Ojeda             quote!()
116b437b383SMiguel Ojeda         } else if let Some(span) = self.on_error_span {
117b437b383SMiguel Ojeda             let core = self.core_path();
118b437b383SMiguel Ojeda             let error_message = "`on_error` is experimental; pass '--cfg zerocopy_unstable_derive_on_error' to enable";
119b437b383SMiguel Ojeda             quote::quote_spanned! {span=>
120b437b383SMiguel Ojeda                 #[allow(unused_attributes, unexpected_cfgs)]
121b437b383SMiguel Ojeda                 const _: () = {
122b437b383SMiguel Ojeda                     #[cfg(not(zerocopy_unstable_derive_on_error))]
123b437b383SMiguel Ojeda                     #core::compile_error!(#error_message);
124b437b383SMiguel Ojeda                 };
125b437b383SMiguel Ojeda             }
126b437b383SMiguel Ojeda         } else {
127b437b383SMiguel Ojeda             quote!()
128b437b383SMiguel Ojeda         }
129b437b383SMiguel Ojeda     }
130b437b383SMiguel Ojeda 
131b437b383SMiguel Ojeda     pub(crate) fn error_or_skip<E>(&self, error: E) -> Result<TokenStream, E> {
132b437b383SMiguel Ojeda         if self.skip_on_error {
133b437b383SMiguel Ojeda             Ok(self.cfg_compile_error())
134b437b383SMiguel Ojeda         } else {
135b437b383SMiguel Ojeda             Err(error)
136b437b383SMiguel Ojeda         }
137b437b383SMiguel Ojeda     }
138b437b383SMiguel Ojeda }
139b437b383SMiguel Ojeda 
/// Uniform access to the fields, variants and tag of a struct, enum or union
/// body.
pub(crate) trait DataExt {
    /// Extracts the names and types of all fields. For enums, extracts the
    /// names and types of fields from each variant. For tuple structs, the
    /// names are the indices used to index into the struct (i.e., `0`, `1`,
    /// etc).
    ///
    /// FIXME: Extracting field names for enums doesn't really make sense. Types
    /// make sense because we don't care about where they live - we just care
    /// about transitive ownership. But for field names, we'd only use them when
    /// generating is_bit_valid, which cares about where they live.
    fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)>;

    /// Returns the fields grouped by variant. For enums there is one entry per
    /// variant, each paired with `Some(variant)`; structs and unions yield a
    /// single entry paired with `None`.
    fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)>;

    /// Returns the identifier of the tag type. Enums return
    /// `___ZerocopyTag`; structs and unions return `None`.
    fn tag(&self) -> Option<Ident>;
}
155b437b383SMiguel Ojeda 
156b437b383SMiguel Ojeda impl DataExt for Data {
157b437b383SMiguel Ojeda     fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
158b437b383SMiguel Ojeda         match self {
159b437b383SMiguel Ojeda             Data::Struct(strc) => strc.fields(),
160b437b383SMiguel Ojeda             Data::Enum(enm) => enm.fields(),
161b437b383SMiguel Ojeda             Data::Union(un) => un.fields(),
162b437b383SMiguel Ojeda         }
163b437b383SMiguel Ojeda     }
164b437b383SMiguel Ojeda 
165b437b383SMiguel Ojeda     fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
166b437b383SMiguel Ojeda         match self {
167b437b383SMiguel Ojeda             Data::Struct(strc) => strc.variants(),
168b437b383SMiguel Ojeda             Data::Enum(enm) => enm.variants(),
169b437b383SMiguel Ojeda             Data::Union(un) => un.variants(),
170b437b383SMiguel Ojeda         }
171b437b383SMiguel Ojeda     }
172b437b383SMiguel Ojeda 
173b437b383SMiguel Ojeda     fn tag(&self) -> Option<Ident> {
174b437b383SMiguel Ojeda         match self {
175b437b383SMiguel Ojeda             Data::Struct(strc) => strc.tag(),
176b437b383SMiguel Ojeda             Data::Enum(enm) => enm.tag(),
177b437b383SMiguel Ojeda             Data::Union(un) => un.tag(),
178b437b383SMiguel Ojeda         }
179b437b383SMiguel Ojeda     }
180b437b383SMiguel Ojeda }
181b437b383SMiguel Ojeda 
182b437b383SMiguel Ojeda impl DataExt for DataStruct {
183b437b383SMiguel Ojeda     fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
184b437b383SMiguel Ojeda         map_fields(&self.fields)
185b437b383SMiguel Ojeda     }
186b437b383SMiguel Ojeda 
187b437b383SMiguel Ojeda     fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
188b437b383SMiguel Ojeda         vec![(None, self.fields())]
189b437b383SMiguel Ojeda     }
190b437b383SMiguel Ojeda 
191b437b383SMiguel Ojeda     fn tag(&self) -> Option<Ident> {
192b437b383SMiguel Ojeda         None
193b437b383SMiguel Ojeda     }
194b437b383SMiguel Ojeda }
195b437b383SMiguel Ojeda 
196b437b383SMiguel Ojeda impl DataExt for DataEnum {
197b437b383SMiguel Ojeda     fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
198b437b383SMiguel Ojeda         map_fields(self.variants.iter().flat_map(|var| &var.fields))
199b437b383SMiguel Ojeda     }
200b437b383SMiguel Ojeda 
201b437b383SMiguel Ojeda     fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
202b437b383SMiguel Ojeda         self.variants.iter().map(|var| (Some(var), map_fields(&var.fields))).collect()
203b437b383SMiguel Ojeda     }
204b437b383SMiguel Ojeda 
205b437b383SMiguel Ojeda     fn tag(&self) -> Option<Ident> {
206b437b383SMiguel Ojeda         Some(Ident::new("___ZerocopyTag", Span::call_site()))
207b437b383SMiguel Ojeda     }
208b437b383SMiguel Ojeda }
209b437b383SMiguel Ojeda 
210b437b383SMiguel Ojeda impl DataExt for DataUnion {
211b437b383SMiguel Ojeda     fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> {
212b437b383SMiguel Ojeda         map_fields(&self.fields.named)
213b437b383SMiguel Ojeda     }
214b437b383SMiguel Ojeda 
215b437b383SMiguel Ojeda     fn variants(&self) -> Vec<(Option<&Variant>, Vec<(&Visibility, TokenStream, &Type)>)> {
216b437b383SMiguel Ojeda         vec![(None, self.fields())]
217b437b383SMiguel Ojeda     }
218b437b383SMiguel Ojeda 
219b437b383SMiguel Ojeda     fn tag(&self) -> Option<Ident> {
220b437b383SMiguel Ojeda         None
221b437b383SMiguel Ojeda     }
222b437b383SMiguel Ojeda }
223b437b383SMiguel Ojeda 
224b437b383SMiguel Ojeda fn map_fields<'a>(
225b437b383SMiguel Ojeda     fields: impl 'a + IntoIterator<Item = &'a Field>,
226b437b383SMiguel Ojeda ) -> Vec<(&'a Visibility, TokenStream, &'a Type)> {
227b437b383SMiguel Ojeda     fields
228b437b383SMiguel Ojeda         .into_iter()
229b437b383SMiguel Ojeda         .enumerate()
230b437b383SMiguel Ojeda         .map(|(idx, f)| {
231b437b383SMiguel Ojeda             (
232b437b383SMiguel Ojeda                 &f.vis,
233b437b383SMiguel Ojeda                 f.ident
234b437b383SMiguel Ojeda                     .as_ref()
235b437b383SMiguel Ojeda                     .map(ToTokens::to_token_stream)
236b437b383SMiguel Ojeda                     .unwrap_or_else(|| Index::from(idx).to_token_stream()),
237b437b383SMiguel Ojeda                 &f.ty,
238b437b383SMiguel Ojeda             )
239b437b383SMiguel Ojeda         })
240b437b383SMiguel Ojeda         .collect()
241b437b383SMiguel Ojeda }
242b437b383SMiguel Ojeda 
/// Renders `t` as a string and drops one leading `r#` (the raw-identifier
/// marker) if present. Only a single prefix is removed.
pub(crate) fn to_ident_str(t: &impl ToString) -> String {
    let rendered = t.to_string();
    match rendered.strip_prefix("r#").map(str::to_owned) {
        Some(without_marker) => without_marker,
        None => rendered,
    }
}
251b437b383SMiguel Ojeda 
/// This enum describes what kind of padding check needs to be generated for the
/// associated impl.
///
/// Each variant selects a validator trait and a macro name; see
/// `validator_trait_and_macro_idents`.
pub(crate) enum PaddingCheck {
    /// Check that the sum of the fields' sizes exactly equals the struct's
    /// size.
    ///
    /// Uses the `PaddingFree` trait and the `struct_padding` macro.
    Struct,
    /// Check that a `repr(C)` struct has no padding.
    ///
    /// Uses the `DynamicPaddingFree` trait and the
    /// `repr_c_struct_has_padding` macro.
    ReprCStruct,
    /// Check that the size of each field exactly equals the union's size.
    ///
    /// Uses the `PaddingFree` trait and the `union_padding` macro.
    Union,
    /// Check that every variant of the enum contains no padding.
    ///
    /// Because doing so requires a tag enum, this padding check requires an
    /// additional `TokenStream` which defines the tag enum as `___ZerocopyTag`.
    ///
    /// Uses the `PaddingFree` trait and the `enum_padding` macro.
    Enum { tag_type_definition: TokenStream },
}
268b437b383SMiguel Ojeda 
269b437b383SMiguel Ojeda impl PaddingCheck {
270b437b383SMiguel Ojeda     /// Returns the idents of the trait to use and the macro to call in order to
271b437b383SMiguel Ojeda     /// validate that a type passes the relevant padding check.
272b437b383SMiguel Ojeda     pub(crate) fn validator_trait_and_macro_idents(&self) -> (Ident, Ident) {
273b437b383SMiguel Ojeda         let (trt, mcro) = match self {
274b437b383SMiguel Ojeda             PaddingCheck::Struct => ("PaddingFree", "struct_padding"),
275b437b383SMiguel Ojeda             PaddingCheck::ReprCStruct => ("DynamicPaddingFree", "repr_c_struct_has_padding"),
276b437b383SMiguel Ojeda             PaddingCheck::Union => ("PaddingFree", "union_padding"),
277b437b383SMiguel Ojeda             PaddingCheck::Enum { .. } => ("PaddingFree", "enum_padding"),
278b437b383SMiguel Ojeda         };
279b437b383SMiguel Ojeda 
280b437b383SMiguel Ojeda         let trt = Ident::new(trt, Span::call_site());
281b437b383SMiguel Ojeda         let mcro = Ident::new(mcro, Span::call_site());
282b437b383SMiguel Ojeda         (trt, mcro)
283b437b383SMiguel Ojeda     }
284b437b383SMiguel Ojeda 
285b437b383SMiguel Ojeda     /// Sometimes performing the padding check requires some additional
286b437b383SMiguel Ojeda     /// "context" code. For enums, this is the definition of the tag enum.
287b437b383SMiguel Ojeda     pub(crate) fn validator_macro_context(&self) -> Option<&TokenStream> {
288b437b383SMiguel Ojeda         match self {
289b437b383SMiguel Ojeda             PaddingCheck::Struct | PaddingCheck::ReprCStruct | PaddingCheck::Union => None,
290b437b383SMiguel Ojeda             PaddingCheck::Enum { tag_type_definition } => Some(tag_type_definition),
291b437b383SMiguel Ojeda         }
292b437b383SMiguel Ojeda     }
293b437b383SMiguel Ojeda }
294b437b383SMiguel Ojeda 
/// A trait that generated code can name.
///
/// The `ToTokens` impl renders each variant as the trait of the same name.
/// `HasField` and `ProjectField` are additionally rendered with generic
/// arguments built from the data they carry.
#[derive(Clone)]
pub(crate) enum Trait {
    KnownLayout,
    HasTag,
    // Rendered as `HasField<field, variant_id, field_id>`.
    HasField {
        variant_id: Box<Expr>,
        field: Box<Type>,
        field_id: Box<Expr>,
    },
    // Rendered as `ProjectField<field, invariants, variant_id, field_id>`.
    ProjectField {
        variant_id: Box<Expr>,
        field: Box<Type>,
        field_id: Box<Expr>,
        invariants: Box<Type>,
    },
    Immutable,
    TryFromBytes,
    FromZeros,
    FromBytes,
    IntoBytes,
    Unaligned,
    // Unlike the other variants, `crate_path` resolves this one under the
    // `core` re-export's `marker` module rather than under the zerocopy crate.
    Sized,
    ByteHash,
    ByteEq,
    SplitAt,
}
321b437b383SMiguel Ojeda 
322b437b383SMiguel Ojeda impl ToTokens for Trait {
323b437b383SMiguel Ojeda     fn to_tokens(&self, tokens: &mut TokenStream) {
324b437b383SMiguel Ojeda         // According to [1], the format of the derived `Debug`` output is not
325b437b383SMiguel Ojeda         // stable and therefore not guaranteed to represent the variant names.
326b437b383SMiguel Ojeda         // Indeed with the (unstable) `fmt-debug` compiler flag [2], it can
327b437b383SMiguel Ojeda         // return only a minimalized output or empty string. To make sure this
328b437b383SMiguel Ojeda         // code will work in the future and independent of the compiler flag, we
329b437b383SMiguel Ojeda         // translate the variants to their names manually here.
330b437b383SMiguel Ojeda         //
331b437b383SMiguel Ojeda         // [1] https://doc.rust-lang.org/1.81.0/std/fmt/trait.Debug.html#stability
332b437b383SMiguel Ojeda         // [2] https://doc.rust-lang.org/beta/unstable-book/compiler-flags/fmt-debug.html
333b437b383SMiguel Ojeda         let s = match self {
334b437b383SMiguel Ojeda             Trait::HasField { .. } => "HasField",
335b437b383SMiguel Ojeda             Trait::ProjectField { .. } => "ProjectField",
336b437b383SMiguel Ojeda             Trait::KnownLayout => "KnownLayout",
337b437b383SMiguel Ojeda             Trait::HasTag => "HasTag",
338b437b383SMiguel Ojeda             Trait::Immutable => "Immutable",
339b437b383SMiguel Ojeda             Trait::TryFromBytes => "TryFromBytes",
340b437b383SMiguel Ojeda             Trait::FromZeros => "FromZeros",
341b437b383SMiguel Ojeda             Trait::FromBytes => "FromBytes",
342b437b383SMiguel Ojeda             Trait::IntoBytes => "IntoBytes",
343b437b383SMiguel Ojeda             Trait::Unaligned => "Unaligned",
344b437b383SMiguel Ojeda             Trait::Sized => "Sized",
345b437b383SMiguel Ojeda             Trait::ByteHash => "ByteHash",
346b437b383SMiguel Ojeda             Trait::ByteEq => "ByteEq",
347b437b383SMiguel Ojeda             Trait::SplitAt => "SplitAt",
348b437b383SMiguel Ojeda         };
349b437b383SMiguel Ojeda         let ident = Ident::new(s, Span::call_site());
350b437b383SMiguel Ojeda         let arguments: Option<syn::AngleBracketedGenericArguments> = match self {
351b437b383SMiguel Ojeda             Trait::HasField { variant_id, field, field_id } => {
352b437b383SMiguel Ojeda                 Some(parse_quote!(<#field, #variant_id, #field_id>))
353b437b383SMiguel Ojeda             }
354b437b383SMiguel Ojeda             Trait::ProjectField { variant_id, field, field_id, invariants } => {
355b437b383SMiguel Ojeda                 Some(parse_quote!(<#field, #invariants, #variant_id, #field_id>))
356b437b383SMiguel Ojeda             }
357b437b383SMiguel Ojeda             Trait::KnownLayout
358b437b383SMiguel Ojeda             | Trait::HasTag
359b437b383SMiguel Ojeda             | Trait::Immutable
360b437b383SMiguel Ojeda             | Trait::TryFromBytes
361b437b383SMiguel Ojeda             | Trait::FromZeros
362b437b383SMiguel Ojeda             | Trait::FromBytes
363b437b383SMiguel Ojeda             | Trait::IntoBytes
364b437b383SMiguel Ojeda             | Trait::Unaligned
365b437b383SMiguel Ojeda             | Trait::Sized
366b437b383SMiguel Ojeda             | Trait::ByteHash
367b437b383SMiguel Ojeda             | Trait::ByteEq
368b437b383SMiguel Ojeda             | Trait::SplitAt => None,
369b437b383SMiguel Ojeda         };
370b437b383SMiguel Ojeda         tokens.extend(quote!(#ident #arguments));
371b437b383SMiguel Ojeda     }
372b437b383SMiguel Ojeda }
373b437b383SMiguel Ojeda 
374b437b383SMiguel Ojeda impl Trait {
375b437b383SMiguel Ojeda     pub(crate) fn crate_path(&self, ctx: &Ctx) -> Path {
376b437b383SMiguel Ojeda         let zerocopy_crate = &ctx.zerocopy_crate;
377b437b383SMiguel Ojeda         let core = ctx.core_path();
378b437b383SMiguel Ojeda         match self {
379b437b383SMiguel Ojeda             Self::Sized => parse_quote!(#core::marker::#self),
380b437b383SMiguel Ojeda             _ => parse_quote!(#zerocopy_crate::#self),
381b437b383SMiguel Ojeda         }
382b437b383SMiguel Ojeda     }
383b437b383SMiguel Ojeda }
384b437b383SMiguel Ojeda 
/// A trait bound that is either a fixed trait or a placeholder for "the trait
/// currently being processed"; `normalize_bounds` resolves the placeholder.
pub(crate) enum TraitBound {
    // Placeholder that `normalize_bounds` replaces with its `slf` argument.
    Slf,
    // A specific trait, used as-is.
    Other(Trait),
}
389b437b383SMiguel Ojeda 
390b437b383SMiguel Ojeda pub(crate) enum FieldBounds<'a> {
391b437b383SMiguel Ojeda     None,
392b437b383SMiguel Ojeda     All(&'a [TraitBound]),
393b437b383SMiguel Ojeda     Trailing(&'a [TraitBound]),
394b437b383SMiguel Ojeda     Explicit(Vec<WherePredicate>),
395b437b383SMiguel Ojeda }
396b437b383SMiguel Ojeda 
397b437b383SMiguel Ojeda impl<'a> FieldBounds<'a> {
398b437b383SMiguel Ojeda     pub(crate) const ALL_SELF: FieldBounds<'a> = FieldBounds::All(&[TraitBound::Slf]);
399b437b383SMiguel Ojeda     pub(crate) const TRAILING_SELF: FieldBounds<'a> = FieldBounds::Trailing(&[TraitBound::Slf]);
400b437b383SMiguel Ojeda }
401b437b383SMiguel Ojeda 
402b437b383SMiguel Ojeda pub(crate) enum SelfBounds<'a> {
403b437b383SMiguel Ojeda     None,
404b437b383SMiguel Ojeda     All(&'a [Trait]),
405b437b383SMiguel Ojeda }
406b437b383SMiguel Ojeda 
407b437b383SMiguel Ojeda // FIXME(https://github.com/rust-lang/rust-clippy/issues/12908): This is a false
408b437b383SMiguel Ojeda // positive. Explicit lifetimes are actually necessary here.
409b437b383SMiguel Ojeda #[allow(clippy::needless_lifetimes)]
410b437b383SMiguel Ojeda impl<'a> SelfBounds<'a> {
411b437b383SMiguel Ojeda     pub(crate) const SIZED: Self = Self::All(&[Trait::Sized]);
412b437b383SMiguel Ojeda }
413b437b383SMiguel Ojeda 
414b437b383SMiguel Ojeda /// Normalizes a slice of bounds by replacing [`TraitBound::Slf`] with `slf`.
415b437b383SMiguel Ojeda pub(crate) fn normalize_bounds<'a>(
416b437b383SMiguel Ojeda     slf: &'a Trait,
417b437b383SMiguel Ojeda     bounds: &'a [TraitBound],
418b437b383SMiguel Ojeda ) -> impl 'a + Iterator<Item = Trait> {
419b437b383SMiguel Ojeda     bounds.iter().map(move |bound| match bound {
420b437b383SMiguel Ojeda         TraitBound::Slf => slf.clone(),
421b437b383SMiguel Ojeda         TraitBound::Other(trt) => trt.clone(),
422b437b383SMiguel Ojeda     })
423b437b383SMiguel Ojeda }
424b437b383SMiguel Ojeda 
/// Builder for a generated trait impl. It is created with `new`, optionally
/// configured through the setter methods, and turned into tokens by `build`.
pub(crate) struct ImplBlockBuilder<'a> {
    // Derive context; supplies the input AST and the zerocopy crate path.
    ctx: &'a Ctx,
    // Source of the fields, variants and tag of the type.
    data: &'a dyn DataExt,
    // The trait being implemented.
    trt: Trait,
    // Which bounds to add for field types.
    field_type_trait_bounds: FieldBounds<'a>,
    // Which bounds to add for `Self`; `SelfBounds::None` by default.
    self_type_trait_bounds: SelfBounds<'a>,
    // Padding check to generate, if any; `None` by default.
    padding_check: Option<PaddingCheck>,
    // Additional generic parameters; empty by default.
    // NOTE(review): presumably added to the impl's generics; confirm in `build`.
    param_extras: Vec<GenericParam>,
    // NOTE(review): presumably tokens placed inside the impl block; confirm in
    // `build`.
    inner_extras: Option<TokenStream>,
    // NOTE(review): presumably tokens placed outside the impl block; confirm in
    // `build`.
    outer_extras: Option<TokenStream>,
}
436b437b383SMiguel Ojeda 
437b437b383SMiguel Ojeda impl<'a> ImplBlockBuilder<'a> {
438b437b383SMiguel Ojeda     pub(crate) fn new(
439b437b383SMiguel Ojeda         ctx: &'a Ctx,
440b437b383SMiguel Ojeda         data: &'a dyn DataExt,
441b437b383SMiguel Ojeda         trt: Trait,
442b437b383SMiguel Ojeda         field_type_trait_bounds: FieldBounds<'a>,
443b437b383SMiguel Ojeda     ) -> Self {
444b437b383SMiguel Ojeda         Self {
445b437b383SMiguel Ojeda             ctx,
446b437b383SMiguel Ojeda             data,
447b437b383SMiguel Ojeda             trt,
448b437b383SMiguel Ojeda             field_type_trait_bounds,
449b437b383SMiguel Ojeda             self_type_trait_bounds: SelfBounds::None,
450b437b383SMiguel Ojeda             padding_check: None,
451b437b383SMiguel Ojeda             param_extras: Vec::new(),
452b437b383SMiguel Ojeda             inner_extras: None,
453b437b383SMiguel Ojeda             outer_extras: None,
454b437b383SMiguel Ojeda         }
455b437b383SMiguel Ojeda     }
456b437b383SMiguel Ojeda 
457b437b383SMiguel Ojeda     pub(crate) fn self_type_trait_bounds(mut self, self_type_trait_bounds: SelfBounds<'a>) -> Self {
458b437b383SMiguel Ojeda         self.self_type_trait_bounds = self_type_trait_bounds;
459b437b383SMiguel Ojeda         self
460b437b383SMiguel Ojeda     }
461b437b383SMiguel Ojeda 
462b437b383SMiguel Ojeda     pub(crate) fn padding_check<P: Into<Option<PaddingCheck>>>(mut self, padding_check: P) -> Self {
463b437b383SMiguel Ojeda         self.padding_check = padding_check.into();
464b437b383SMiguel Ojeda         self
465b437b383SMiguel Ojeda     }
466b437b383SMiguel Ojeda 
467b437b383SMiguel Ojeda     pub(crate) fn param_extras(mut self, param_extras: Vec<GenericParam>) -> Self {
468b437b383SMiguel Ojeda         self.param_extras.extend(param_extras);
469b437b383SMiguel Ojeda         self
470b437b383SMiguel Ojeda     }
471b437b383SMiguel Ojeda 
472b437b383SMiguel Ojeda     pub(crate) fn inner_extras(mut self, inner_extras: TokenStream) -> Self {
473b437b383SMiguel Ojeda         self.inner_extras = Some(inner_extras);
474b437b383SMiguel Ojeda         self
475b437b383SMiguel Ojeda     }
476b437b383SMiguel Ojeda 
477b437b383SMiguel Ojeda     pub(crate) fn outer_extras<T: Into<Option<TokenStream>>>(mut self, outer_extras: T) -> Self {
478b437b383SMiguel Ojeda         self.outer_extras = outer_extras.into();
479b437b383SMiguel Ojeda         self
480b437b383SMiguel Ojeda     }
481b437b383SMiguel Ojeda 
482b437b383SMiguel Ojeda     pub(crate) fn build(self) -> TokenStream {
483b437b383SMiguel Ojeda         // In this documentation, we will refer to this hypothetical struct:
484b437b383SMiguel Ojeda         //
485b437b383SMiguel Ojeda         //   #[derive(FromBytes)]
486b437b383SMiguel Ojeda         //   struct Foo<T, I: Iterator>
487b437b383SMiguel Ojeda         //   where
488b437b383SMiguel Ojeda         //       T: Copy,
489b437b383SMiguel Ojeda         //       I: Clone,
490b437b383SMiguel Ojeda         //       I::Item: Clone,
491b437b383SMiguel Ojeda         //   {
492b437b383SMiguel Ojeda         //       a: u8,
493b437b383SMiguel Ojeda         //       b: T,
494b437b383SMiguel Ojeda         //       c: I::Item,
495b437b383SMiguel Ojeda         //   }
496b437b383SMiguel Ojeda         //
497b437b383SMiguel Ojeda         // We extract the field types, which in this case are `u8`, `T`, and
498b437b383SMiguel Ojeda         // `I::Item`. We re-use the existing parameters and where clauses. If
499b437b383SMiguel Ojeda         // `require_trait_bound == true` (as it is for `FromBytes), we add where
500b437b383SMiguel Ojeda         // bounds for each field's type:
501b437b383SMiguel Ojeda         //
502b437b383SMiguel Ojeda         //   impl<T, I: Iterator> FromBytes for Foo<T, I>
503b437b383SMiguel Ojeda         //   where
504b437b383SMiguel Ojeda         //       T: Copy,
505b437b383SMiguel Ojeda         //       I: Clone,
506b437b383SMiguel Ojeda         //       I::Item: Clone,
507b437b383SMiguel Ojeda         //       T: FromBytes,
508b437b383SMiguel Ojeda         //       I::Item: FromBytes,
509b437b383SMiguel Ojeda         //   {
510b437b383SMiguel Ojeda         //   }
511b437b383SMiguel Ojeda         //
512b437b383SMiguel Ojeda         // NOTE: It is standard practice to only emit bounds for the type
513b437b383SMiguel Ojeda         // parameters themselves, not for field types based on those parameters
514b437b383SMiguel Ojeda         // (e.g., `T` vs `T::Foo`). For a discussion of why this is standard
515b437b383SMiguel Ojeda         // practice, see https://github.com/rust-lang/rust/issues/26925.
516b437b383SMiguel Ojeda         //
517b437b383SMiguel Ojeda         // The reason we diverge from this standard is that doing it that way
518b437b383SMiguel Ojeda         // for us would be unsound. E.g., consider a type, `T` where `T:
519b437b383SMiguel Ojeda         // FromBytes` but `T::Foo: !FromBytes`. It would not be sound for us to
520b437b383SMiguel Ojeda         // accept a type with a `T::Foo` field as `FromBytes` simply because `T:
521b437b383SMiguel Ojeda         // FromBytes`.
522b437b383SMiguel Ojeda         //
523b437b383SMiguel Ojeda         // While there's no getting around this requirement for us, it does have
524b437b383SMiguel Ojeda         // the pretty serious downside that, when lifetimes are involved, the
525b437b383SMiguel Ojeda         // trait solver ties itself in knots:
526b437b383SMiguel Ojeda         //
527b437b383SMiguel Ojeda         //     #[derive(Unaligned)]
528b437b383SMiguel Ojeda         //     #[repr(C)]
529b437b383SMiguel Ojeda         //     struct Dup<'a, 'b> {
530b437b383SMiguel Ojeda         //         a: PhantomData<&'a u8>,
531b437b383SMiguel Ojeda         //         b: PhantomData<&'b u8>,
532b437b383SMiguel Ojeda         //     }
533b437b383SMiguel Ojeda         //
534b437b383SMiguel Ojeda         //     error[E0283]: type annotations required: cannot resolve `core::marker::PhantomData<&'a u8>: zerocopy::Unaligned`
535b437b383SMiguel Ojeda         //      --> src/main.rs:6:10
536b437b383SMiguel Ojeda         //       |
537b437b383SMiguel Ojeda         //     6 | #[derive(Unaligned)]
538b437b383SMiguel Ojeda         //       |          ^^^^^^^^^
539b437b383SMiguel Ojeda         //       |
540b437b383SMiguel Ojeda         //       = note: required by `zerocopy::Unaligned`
541b437b383SMiguel Ojeda 
542b437b383SMiguel Ojeda         let type_ident = &self.ctx.ast.ident;
543b437b383SMiguel Ojeda         let trait_path = self.trt.crate_path(self.ctx);
544b437b383SMiguel Ojeda         let fields = self.data.fields();
545b437b383SMiguel Ojeda         let variants = self.data.variants();
546b437b383SMiguel Ojeda         let tag = self.data.tag();
547b437b383SMiguel Ojeda         let zerocopy_crate = &self.ctx.zerocopy_crate;
548b437b383SMiguel Ojeda 
549b437b383SMiguel Ojeda         fn bound_tt(ty: &Type, traits: impl Iterator<Item = Trait>, ctx: &Ctx) -> WherePredicate {
550b437b383SMiguel Ojeda             let traits = traits.map(|t| t.crate_path(ctx));
551b437b383SMiguel Ojeda             parse_quote!(#ty: #(#traits)+*)
552b437b383SMiguel Ojeda         }
553b437b383SMiguel Ojeda         let field_type_bounds: Vec<_> = match (self.field_type_trait_bounds, &fields[..]) {
554b437b383SMiguel Ojeda             (FieldBounds::All(traits), _) => fields
555b437b383SMiguel Ojeda                 .iter()
556b437b383SMiguel Ojeda                 .map(|(_vis, _name, ty)| {
557b437b383SMiguel Ojeda                     bound_tt(ty, normalize_bounds(&self.trt, traits), self.ctx)
558b437b383SMiguel Ojeda                 })
559b437b383SMiguel Ojeda                 .collect(),
560b437b383SMiguel Ojeda             (FieldBounds::None, _) | (FieldBounds::Trailing(..), []) => vec![],
561b437b383SMiguel Ojeda             (FieldBounds::Trailing(traits), [.., last]) => {
562b437b383SMiguel Ojeda                 vec![bound_tt(last.2, normalize_bounds(&self.trt, traits), self.ctx)]
563b437b383SMiguel Ojeda             }
564b437b383SMiguel Ojeda             (FieldBounds::Explicit(bounds), _) => bounds,
565b437b383SMiguel Ojeda         };
566b437b383SMiguel Ojeda 
567b437b383SMiguel Ojeda         let padding_check_bound = self
568b437b383SMiguel Ojeda             .padding_check
569b437b383SMiguel Ojeda             .map(|check| {
570b437b383SMiguel Ojeda                 // Parse the repr for `align` and `packed` modifiers. Note that
571b437b383SMiguel Ojeda                 // `Repr::<PrimitiveRepr, NonZeroU32>` is more permissive than
572b437b383SMiguel Ojeda                 // what Rust supports for structs, enums, or unions, and thus
573b437b383SMiguel Ojeda                 // reliably extracts these modifiers for any kind of type.
574b437b383SMiguel Ojeda                 let repr =
575b437b383SMiguel Ojeda                     Repr::<PrimitiveRepr, NonZeroU32>::from_attrs(&self.ctx.ast.attrs).unwrap();
576b437b383SMiguel Ojeda                 let core = self.ctx.core_path();
577b437b383SMiguel Ojeda                 let option = quote! { #core::option::Option };
578b437b383SMiguel Ojeda                 let nonzero = quote! { #core::num::NonZeroUsize };
579b437b383SMiguel Ojeda                 let none = quote! { #option::None::<#nonzero> };
580b437b383SMiguel Ojeda                 let repr_align =
581b437b383SMiguel Ojeda                     repr.get_align().map(|spanned| {
582b437b383SMiguel Ojeda                         let n = spanned.t.get();
583b437b383SMiguel Ojeda                         quote_spanned! { spanned.span => (#nonzero::new(#n as usize)) }
584b437b383SMiguel Ojeda                     }).unwrap_or(quote! { (#none) });
585b437b383SMiguel Ojeda                 let repr_packed =
586b437b383SMiguel Ojeda                     repr.get_packed().map(|packed| {
587b437b383SMiguel Ojeda                         let n = packed.get();
588b437b383SMiguel Ojeda                         quote! { (#nonzero::new(#n as usize)) }
589b437b383SMiguel Ojeda                     }).unwrap_or(quote! { (#none) });
590b437b383SMiguel Ojeda                 let variant_types = variants.iter().map(|(_, fields)| {
591b437b383SMiguel Ojeda                     let types = fields.iter().map(|(_vis, _name, ty)| ty);
592b437b383SMiguel Ojeda                     quote!([#((#types)),*])
593b437b383SMiguel Ojeda                 });
594b437b383SMiguel Ojeda                 let validator_context = check.validator_macro_context();
595b437b383SMiguel Ojeda                 let (trt, validator_macro) = check.validator_trait_and_macro_idents();
596b437b383SMiguel Ojeda                 let t = tag.iter();
597b437b383SMiguel Ojeda                 parse_quote! {
598b437b383SMiguel Ojeda                     (): #zerocopy_crate::util::macro_util::#trt<
599b437b383SMiguel Ojeda                         Self,
600b437b383SMiguel Ojeda                         {
601b437b383SMiguel Ojeda                             #validator_context
602b437b383SMiguel Ojeda                             #zerocopy_crate::#validator_macro!(Self, #repr_align, #repr_packed, #(#t,)* #(#variant_types),*)
603b437b383SMiguel Ojeda                         }
604b437b383SMiguel Ojeda                     >
605b437b383SMiguel Ojeda                 }
606b437b383SMiguel Ojeda             });
607b437b383SMiguel Ojeda 
608b437b383SMiguel Ojeda         let self_bounds: Option<WherePredicate> = match self.self_type_trait_bounds {
609b437b383SMiguel Ojeda             SelfBounds::None => None,
610b437b383SMiguel Ojeda             SelfBounds::All(traits) => {
611b437b383SMiguel Ojeda                 Some(bound_tt(&parse_quote!(Self), traits.iter().cloned(), self.ctx))
612b437b383SMiguel Ojeda             }
613b437b383SMiguel Ojeda         };
614b437b383SMiguel Ojeda 
615b437b383SMiguel Ojeda         let bounds = self
616b437b383SMiguel Ojeda             .ctx
617b437b383SMiguel Ojeda             .ast
618b437b383SMiguel Ojeda             .generics
619b437b383SMiguel Ojeda             .where_clause
620b437b383SMiguel Ojeda             .as_ref()
621b437b383SMiguel Ojeda             .map(|where_clause| where_clause.predicates.iter())
622b437b383SMiguel Ojeda             .into_iter()
623b437b383SMiguel Ojeda             .flatten()
624b437b383SMiguel Ojeda             .chain(field_type_bounds.iter())
625b437b383SMiguel Ojeda             .chain(padding_check_bound.iter())
626b437b383SMiguel Ojeda             .chain(self_bounds.iter());
627b437b383SMiguel Ojeda 
628b437b383SMiguel Ojeda         // The parameters with trait bounds, but without type defaults.
629b437b383SMiguel Ojeda         let mut params: Vec<_> = self
630b437b383SMiguel Ojeda             .ctx
631b437b383SMiguel Ojeda             .ast
632b437b383SMiguel Ojeda             .generics
633b437b383SMiguel Ojeda             .params
634b437b383SMiguel Ojeda             .clone()
635b437b383SMiguel Ojeda             .into_iter()
636b437b383SMiguel Ojeda             .map(|mut param| {
637b437b383SMiguel Ojeda                 match &mut param {
638b437b383SMiguel Ojeda                     GenericParam::Type(ty) => ty.default = None,
639b437b383SMiguel Ojeda                     GenericParam::Const(cnst) => cnst.default = None,
640b437b383SMiguel Ojeda                     GenericParam::Lifetime(_) => {}
641b437b383SMiguel Ojeda                 }
642b437b383SMiguel Ojeda                 parse_quote!(#param)
643b437b383SMiguel Ojeda             })
644b437b383SMiguel Ojeda             .chain(self.param_extras)
645b437b383SMiguel Ojeda             .collect();
646b437b383SMiguel Ojeda 
647b437b383SMiguel Ojeda         // For MSRV purposes, ensure that lifetimes precede types precede const
648b437b383SMiguel Ojeda         // generics.
649b437b383SMiguel Ojeda         params.sort_by_cached_key(|param| match param {
650b437b383SMiguel Ojeda             GenericParam::Lifetime(_) => 0,
651b437b383SMiguel Ojeda             GenericParam::Type(_) => 1,
652b437b383SMiguel Ojeda             GenericParam::Const(_) => 2,
653b437b383SMiguel Ojeda         });
654b437b383SMiguel Ojeda 
655b437b383SMiguel Ojeda         // The identifiers of the parameters without trait bounds or type
656b437b383SMiguel Ojeda         // defaults.
657b437b383SMiguel Ojeda         let param_idents = self.ctx.ast.generics.params.iter().map(|param| match param {
658b437b383SMiguel Ojeda             GenericParam::Type(ty) => {
659b437b383SMiguel Ojeda                 let ident = &ty.ident;
660b437b383SMiguel Ojeda                 quote!(#ident)
661b437b383SMiguel Ojeda             }
662b437b383SMiguel Ojeda             GenericParam::Lifetime(l) => {
663b437b383SMiguel Ojeda                 let ident = &l.lifetime;
664b437b383SMiguel Ojeda                 quote!(#ident)
665b437b383SMiguel Ojeda             }
666b437b383SMiguel Ojeda             GenericParam::Const(cnst) => {
667b437b383SMiguel Ojeda                 let ident = &cnst.ident;
668b437b383SMiguel Ojeda                 quote!({#ident})
669b437b383SMiguel Ojeda             }
670b437b383SMiguel Ojeda         });
671b437b383SMiguel Ojeda 
672b437b383SMiguel Ojeda         let inner_extras = self.inner_extras;
673b437b383SMiguel Ojeda         let allow_trivial_bounds =
674b437b383SMiguel Ojeda             if self.ctx.skip_on_error { quote!(#[allow(trivial_bounds)]) } else { quote!() };
675b437b383SMiguel Ojeda         let impl_tokens = quote! {
676b437b383SMiguel Ojeda             #allow_trivial_bounds
677b437b383SMiguel Ojeda             unsafe impl < #(#params),* > #trait_path for #type_ident < #(#param_idents),* >
678b437b383SMiguel Ojeda             where
679b437b383SMiguel Ojeda                 #(#bounds,)*
680b437b383SMiguel Ojeda             {
681b437b383SMiguel Ojeda                 fn only_derive_is_allowed_to_implement_this_trait() {}
682b437b383SMiguel Ojeda 
683b437b383SMiguel Ojeda                 #inner_extras
684b437b383SMiguel Ojeda             }
685b437b383SMiguel Ojeda         };
686b437b383SMiguel Ojeda 
687b437b383SMiguel Ojeda         let outer_extras = self.outer_extras.filter(|e| !e.is_empty());
688b437b383SMiguel Ojeda         let cfg_compile_error = self.ctx.cfg_compile_error();
689b437b383SMiguel Ojeda         const_block([Some(cfg_compile_error), Some(impl_tokens), outer_extras])
690b437b383SMiguel Ojeda     }
691b437b383SMiguel Ojeda }
692b437b383SMiguel Ojeda 
// Polyfill for `bool::then_some`, which was stabilized after our MSRV.
//
// On toolchains that already ship the inherent method, a call such as
// `b.then_some(...)` resolves to that method instead of to this trait, which
// leaves the trait looking unused. That is why `#[allow(unused)]` is needed.
//
// FIXME(#67): Remove this once our MSRV is >= 1.62.
#[allow(unused)]
trait BoolExt {
    // Returns `Some(t)` when `self` is `true` and `None` when it is `false`.
    // `t` is taken by value, so it is always evaluated by the caller.
    fn then_some<T>(self, t: T) -> Option<T>;
}

impl BoolExt for bool {
    fn then_some<T>(self, t: T) -> Option<T> {
        match self {
            true => Some(t),
            false => None,
        }
    }
}
714b437b383SMiguel Ojeda 
715b437b383SMiguel Ojeda pub(crate) fn const_block(items: impl IntoIterator<Item = Option<TokenStream>>) -> TokenStream {
716b437b383SMiguel Ojeda     let items = items.into_iter().flatten();
717b437b383SMiguel Ojeda     quote! {
718b437b383SMiguel Ojeda         #[allow(
719b437b383SMiguel Ojeda             // FIXME(#553): Add a test that generates a warning when
720b437b383SMiguel Ojeda             // `#[allow(deprecated)]` isn't present.
721b437b383SMiguel Ojeda             deprecated,
722b437b383SMiguel Ojeda             // Required on some rustc versions due to a lint that is only
723b437b383SMiguel Ojeda             // triggered when `derive(KnownLayout)` is applied to `repr(C)`
724b437b383SMiguel Ojeda             // structs that are generated by macros. See #2177 for details.
725b437b383SMiguel Ojeda             private_bounds,
726b437b383SMiguel Ojeda             non_local_definitions,
727b437b383SMiguel Ojeda             non_camel_case_types,
728b437b383SMiguel Ojeda             non_upper_case_globals,
729b437b383SMiguel Ojeda             non_snake_case,
730b437b383SMiguel Ojeda             non_ascii_idents,
731b437b383SMiguel Ojeda             clippy::missing_inline_in_public_items,
732b437b383SMiguel Ojeda         )]
733b437b383SMiguel Ojeda         #[deny(ambiguous_associated_items)]
734b437b383SMiguel Ojeda         // While there are not currently any warnings that this suppresses
735b437b383SMiguel Ojeda         // (that we're aware of), it's good future-proofing hygiene.
736b437b383SMiguel Ojeda         #[automatically_derived]
737b437b383SMiguel Ojeda         const _: () = {
738b437b383SMiguel Ojeda             #(#items)*
739b437b383SMiguel Ojeda         };
740b437b383SMiguel Ojeda     }
741b437b383SMiguel Ojeda }
742b437b383SMiguel Ojeda pub(crate) fn generate_tag_enum(ctx: &Ctx, repr: &EnumRepr, data: &DataEnum) -> TokenStream {
743b437b383SMiguel Ojeda     let zerocopy_crate = &ctx.zerocopy_crate;
744b437b383SMiguel Ojeda     let variants = data.variants.iter().map(|v| {
745b437b383SMiguel Ojeda         let ident = &v.ident;
746b437b383SMiguel Ojeda         if let Some((eq, discriminant)) = &v.discriminant {
747b437b383SMiguel Ojeda             quote! { #ident #eq #discriminant }
748b437b383SMiguel Ojeda         } else {
749b437b383SMiguel Ojeda             quote! { #ident }
750b437b383SMiguel Ojeda         }
751b437b383SMiguel Ojeda     });
752b437b383SMiguel Ojeda 
753b437b383SMiguel Ojeda     // Don't include any `repr(align)` when generating the tag enum, as that
754b437b383SMiguel Ojeda     // could add padding after the tag but before any variants, which is not the
755b437b383SMiguel Ojeda     // correct behavior.
756b437b383SMiguel Ojeda     let repr = match repr {
757b437b383SMiguel Ojeda         EnumRepr::Transparent(span) => quote::quote_spanned! { *span => #[repr(transparent)] },
758b437b383SMiguel Ojeda         EnumRepr::Compound(c, _) => quote! { #c },
759b437b383SMiguel Ojeda     };
760b437b383SMiguel Ojeda 
761b437b383SMiguel Ojeda     quote! {
762b437b383SMiguel Ojeda         #repr
763b437b383SMiguel Ojeda         #[allow(dead_code)]
764b437b383SMiguel Ojeda         pub enum ___ZerocopyTag {
765b437b383SMiguel Ojeda             #(#variants,)*
766b437b383SMiguel Ojeda         }
767b437b383SMiguel Ojeda 
768b437b383SMiguel Ojeda         // SAFETY: `___ZerocopyTag` has no fields, and so it does not permit
769b437b383SMiguel Ojeda         // interior mutation.
770b437b383SMiguel Ojeda         unsafe impl #zerocopy_crate::Immutable for ___ZerocopyTag {
771b437b383SMiguel Ojeda             fn only_derive_is_allowed_to_implement_this_trait() {}
772b437b383SMiguel Ojeda         }
773b437b383SMiguel Ojeda     }
774b437b383SMiguel Ojeda }
775b437b383SMiguel Ojeda pub(crate) fn enum_size_from_repr(repr: &EnumRepr) -> Result<usize, Error> {
776b437b383SMiguel Ojeda     use CompoundRepr::*;
777b437b383SMiguel Ojeda     use PrimitiveRepr::*;
778b437b383SMiguel Ojeda     use Repr::*;
779b437b383SMiguel Ojeda     match repr {
780b437b383SMiguel Ojeda         Transparent(span)
781b437b383SMiguel Ojeda         | Compound(
782b437b383SMiguel Ojeda             Spanned {
783b437b383SMiguel Ojeda                 t: C | Rust | Primitive(U32 | I32 | U64 | I64 | U128 | I128 | Usize | Isize),
784b437b383SMiguel Ojeda                 span,
785b437b383SMiguel Ojeda             },
786b437b383SMiguel Ojeda             _,
787b437b383SMiguel Ojeda         ) => Err(Error::new(
788b437b383SMiguel Ojeda             *span,
789b437b383SMiguel Ojeda             "`FromBytes` only supported on enums with `#[repr(...)]` attributes `u8`, `i8`, `u16`, or `i16`",
790b437b383SMiguel Ojeda         )),
791b437b383SMiguel Ojeda         Compound(Spanned { t: Primitive(U8 | I8), span: _ }, _align) => Ok(8),
792b437b383SMiguel Ojeda         Compound(Spanned { t: Primitive(U16 | I16), span: _ }, _align) => Ok(16),
793b437b383SMiguel Ojeda     }
794b437b383SMiguel Ojeda }
795b437b383SMiguel Ojeda 
#[cfg(test)]
pub(crate) mod testutil {
    use proc_macro2::TokenStream;
    use syn::visit::{self, Visit};

    /// Scans generated code for paths that violate our hygiene rules.
    ///
    /// `ts` is parsed as a whole file, and every path in it is inspected. A
    /// path with more than one segment whose first segment is `Self` (e.g.,
    /// `Self::Item`) is rejected.
    ///
    /// # Panics
    ///
    /// Panics if `ts` cannot be parsed as a [`syn::File`], or if a hygiene
    /// violation is found.
    pub(crate) fn check_hygiene(ts: TokenStream) {
        // Rejects multi-segment paths that are rooted at `Self`.
        struct AmbiguousItemVisitor;

        impl<'ast> Visit<'ast> for AmbiguousItemVisitor {
            fn visit_path(&mut self, path: &'ast syn::Path) {
                let rooted_at_self =
                    path.segments.first().map_or(false, |segment| segment.ident == "Self");
                if rooted_at_self && path.segments.len() > 1 {
                    panic!(
                        "Found ambiguous path `{}` in generated output. \
                         All associated item access must be fully qualified (e.g., `<Self as Trait>::Item`) \
                         to prevent hygiene issues.",
                        quote::quote!(#path)
                    );
                }
                // Continue into the path's children using the default
                // traversal.
                visit::visit_path(self, path);
            }
        }

        let parsed: syn::File = syn::parse2(ts).expect("failed to parse generated output as File");
        let mut visitor = AmbiguousItemVisitor;
        visitor.visit_file(&parsed);
    }

    #[test]
    fn test_check_hygiene_success() {
        // A fully qualified associated item access must be accepted.
        let generated = quote::quote! {
            fn foo() {
                let _ = <Self as Trait>::Item;
            }
        };
        check_hygiene(generated);
    }

    #[test]
    #[should_panic(expected = "Found ambiguous path `Self :: Ambiguous`")]
    fn test_check_hygiene_failure() {
        // An unqualified `Self::...` path must be rejected.
        let generated = quote::quote! {
            fn foo() {
                let _ = Self::Ambiguous;
            }
        };
        check_hygiene(generated);
    }
}
846