// SPDX-License-Identifier: GPL-2.0

use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
use std::collections::HashSet;
use std::fmt::Write;

/// Returns the name of the `HAS_*` constant generated for the method named `fn_name`.
///
/// A raw-identifier prefix is stripped first, so `fn r#try` maps to `HAS_TRY` rather than the
/// invalid token sequence `HAS_R#TRY`.
fn has_const_name(fn_name: &str) -> String {
    let name = fn_name.strip_prefix("r#").unwrap_or(fn_name);
    format!("HAS_{}", name.to_uppercase())
}

/// Implementation of the `#[vtable]` attribute.
///
/// `ts` must be a `trait` or `impl` item whose last token tree is its brace-delimited body.
/// The top level of that body is scanned for `fn` and `const` items, and generated associated
/// constants are prepended to it:
///
/// - on a trait: a `const USE_VTABLE_ATTR: ();` marker declaration, plus
///   `const HAS_<NAME>: bool = false;` for every method `<name>`;
/// - on an impl: `const USE_VTABLE_ATTR: () = ();`, plus `const HAS_<NAME>: bool = true;` for
///   every method the impl defines.
///
/// A `HAS_<NAME>` constant that the body already declares is not generated, which lets the user
/// override it. Each generated name is emitted at most once. `_attr` is ignored.
///
/// # Panics
///
/// Panics if `ts` contains neither a `trait` nor an `impl` keyword, or if its last token tree is
/// not a brace-delimited group.
pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream {
    let mut tokens: Vec<_> = ts.into_iter().collect();

    // Scan for the `trait` or `impl` keyword.
    let is_trait = tokens
        .iter()
        .find_map(|token| match token {
            TokenTree::Ident(ident) => match ident.to_string().as_str() {
                "trait" => Some(true),
                "impl" => Some(false),
                _ => None,
            },
            _ => None,
        })
        .expect("#[vtable] attribute should only be applied to trait or impl block");

    // Retrieve the main body. The main body should be the last token tree.
    let body = match tokens.pop() {
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
        _ => panic!("cannot locate main body of trait or impl block"),
    };

    // Only the top level of the body is scanned: parameter lists, method bodies and other
    // nested groups arrive as single `TokenTree::Group`s and are not descended into.
    let mut body_it = body.stream().into_iter();
    let mut functions = Vec::new();
    let mut consts = HashSet::new();
    while let Some(token) = body_it.next() {
        match token {
            TokenTree::Ident(ident) if ident.to_string() == "fn" => {
                let fn_name = match body_it.next() {
                    Some(TokenTree::Ident(ident)) => ident.to_string(),
                    // Possibly we've encountered a fn pointer type instead.
                    _ => continue,
                };
                functions.push(fn_name);
            }
            TokenTree::Ident(ident) if ident.to_string() == "const" => {
                let const_name = match body_it.next() {
                    Some(TokenTree::Ident(ident)) => ident.to_string(),
                    // Possibly we've encountered an inline const block instead.
                    _ => continue,
                };
                consts.insert(const_name);
            }
            _ => (),
        }
    }

    let mut const_items = if is_trait {
        "
        /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable)
        /// attribute when implementing this trait.
        const USE_VTABLE_ATTR: ();
        "
        .to_owned()
    } else {
        "const USE_VTABLE_ATTR: () = ();".to_owned()
    };

    for f in functions {
        let name = has_const_name(&f);
        // Skip if it's declared already -- this allows user override. Recording the generated
        // name here also keeps two methods that map to the same constant (e.g. `foo` and
        // `FOO`) from emitting a duplicate item, in impls as well as in traits.
        if !consts.insert(name.clone()) {
            continue;
        }
        if is_trait {
            // We don't know on the implementation-site whether a method is required or
            // provided so we have to generate a const for all methods.
            write!(
                const_items,
                "/// Indicates if the `{f}` method is overridden by the implementor.\nconst {name}: bool = false;",
            )
            .expect("writing to a `String` cannot fail");
        } else {
            write!(const_items, "const {name}: bool = true;")
                .expect("writing to a `String` cannot fail");
        }
    }

    // Generated items go first, followed by the user's original body tokens.
    let mut new_body: TokenStream = const_items
        .parse()
        .expect("#[vtable] generated items should always lex as valid Rust tokens");
    new_body.extend(body.stream());
    tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body)));
    tokens.into_iter().collect()
}