xref: /linux/rust/macros/vtable.rs (revision 5f160950a5cdc36f222299905e09a72f67ebfcd4)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 use std::{
4     collections::HashSet,
5     iter::Extend, //
6 };
7 
8 use proc_macro2::{
9     Ident,
10     TokenStream, //
11 };
12 use quote::ToTokens;
13 use syn::{
14     parse_quote,
15     Error,
16     ImplItem,
17     Item,
18     ItemImpl,
19     ItemTrait,
20     Result,
21     TraitItem, //
22 };
23 
24 fn handle_trait(mut item: ItemTrait) -> Result<ItemTrait> {
25     let mut gen_items = Vec::new();
26     let mut gen_consts = HashSet::new();
27 
28     gen_items.push(parse_quote! {
29          /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable)
30          /// attribute when implementing this trait.
31          const USE_VTABLE_ATTR: ();
32     });
33 
34     for item in &item.items {
35         if let TraitItem::Fn(fn_item) = item {
36             let name = &fn_item.sig.ident;
37             let gen_const_name = Ident::new(
38                 &format!("HAS_{}", name.to_string().to_uppercase()),
39                 name.span(),
40             );
41             // Skip if it's declared already -- this can happen if `#[cfg]` is used to selectively
42             // define functions.
43             // FIXME: `#[cfg]` should be copied and propagated to the generated consts.
44             if gen_consts.contains(&gen_const_name) {
45                 continue;
46             }
47 
48             // We don't know on the implementation-site whether a method is required or provided
49             // so we have to generate a const for all methods.
50             let comment =
51                 format!("Indicates if the `{name}` method is overridden by the implementor.");
52             gen_items.push(parse_quote! {
53                 #[doc = #comment]
54                 const #gen_const_name: bool = false;
55             });
56             gen_consts.insert(gen_const_name);
57         }
58     }
59 
60     item.items.extend(gen_items);
61     Ok(item)
62 }
63 
64 fn handle_impl(mut item: ItemImpl) -> Result<ItemImpl> {
65     let mut gen_items = Vec::new();
66     let mut defined_consts = HashSet::new();
67 
68     // Iterate over all user-defined constants to gather any possible explicit overrides.
69     for item in &item.items {
70         if let ImplItem::Const(const_item) = item {
71             defined_consts.insert(const_item.ident.clone());
72         }
73     }
74 
75     gen_items.push(parse_quote! {
76         const USE_VTABLE_ATTR: () = ();
77     });
78 
79     for item in &item.items {
80         if let ImplItem::Fn(fn_item) = item {
81             let name = &fn_item.sig.ident;
82             let gen_const_name = Ident::new(
83                 &format!("HAS_{}", name.to_string().to_uppercase()),
84                 name.span(),
85             );
86             // Skip if it's declared already -- this allows user override.
87             if defined_consts.contains(&gen_const_name) {
88                 continue;
89             }
90             gen_items.push(parse_quote! {
91                 const #gen_const_name: bool = true;
92             });
93             defined_consts.insert(gen_const_name);
94         }
95     }
96 
97     item.items.extend(gen_items);
98     Ok(item)
99 }
100 
101 pub(crate) fn vtable(input: Item) -> Result<TokenStream> {
102     match input {
103         Item::Trait(item) => Ok(handle_trait(item)?.into_token_stream()),
104         Item::Impl(item) => Ok(handle_impl(item)?.into_token_stream()),
105         _ => Err(Error::new_spanned(
106             input,
107             "`#[vtable]` attribute should only be applied to trait or impl block",
108         ))?,
109     }
110 }
111