1 // SPDX-License-Identifier: GPL-2.0 2 3 use std::collections::HashSet; 4 use std::fmt::Write; 5 6 use proc_macro2::{Delimiter, Group, TokenStream, TokenTree}; 7 8 pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream { 9 let mut tokens: Vec<_> = ts.into_iter().collect(); 10 11 // Scan for the `trait` or `impl` keyword. 12 let is_trait = tokens 13 .iter() 14 .find_map(|token| match token { 15 TokenTree::Ident(ident) => match ident.to_string().as_str() { 16 "trait" => Some(true), 17 "impl" => Some(false), 18 _ => None, 19 }, 20 _ => None, 21 }) 22 .expect("#[vtable] attribute should only be applied to trait or impl block"); 23 24 // Retrieve the main body. The main body should be the last token tree. 25 let body = match tokens.pop() { 26 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, 27 _ => panic!("cannot locate main body of trait or impl block"), 28 }; 29 30 let mut body_it = body.stream().into_iter(); 31 let mut functions = Vec::new(); 32 let mut consts = HashSet::new(); 33 while let Some(token) = body_it.next() { 34 match token { 35 TokenTree::Ident(ident) if ident == "fn" => { 36 let fn_name = match body_it.next() { 37 Some(TokenTree::Ident(ident)) => ident.to_string(), 38 // Possibly we've encountered a fn pointer type instead. 39 _ => continue, 40 }; 41 functions.push(fn_name); 42 } 43 TokenTree::Ident(ident) if ident == "const" => { 44 let const_name = match body_it.next() { 45 Some(TokenTree::Ident(ident)) => ident.to_string(), 46 // Possibly we've encountered an inline const block instead. 47 _ => continue, 48 }; 49 consts.insert(const_name); 50 } 51 _ => (), 52 } 53 } 54 55 let mut const_items; 56 if is_trait { 57 const_items = " 58 /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable) 59 /// attribute when implementing this trait. 60 const USE_VTABLE_ATTR: (); 61 " 62 .to_owned(); 63 64 for f in functions { 65 let gen_const_name = format!("HAS_{}", f.to_uppercase()); 66 // Skip if it's declared already -- this allows user override. 67 if consts.contains(&gen_const_name) { 68 continue; 69 } 70 // We don't know on the implementation-site whether a method is required or provided 71 // so we have to generate a const for all methods. 72 write!( 73 const_items, 74 "/// Indicates if the `{f}` method is overridden by the implementor. 75 const {gen_const_name}: bool = false;", 76 ) 77 .unwrap(); 78 consts.insert(gen_const_name); 79 } 80 } else { 81 const_items = "const USE_VTABLE_ATTR: () = ();".to_owned(); 82 83 for f in functions { 84 let gen_const_name = format!("HAS_{}", f.to_uppercase()); 85 if consts.contains(&gen_const_name) { 86 continue; 87 } 88 write!(const_items, "const {gen_const_name}: bool = true;").unwrap(); 89 } 90 } 91 92 let new_body = vec![const_items.parse().unwrap(), body.stream()] 93 .into_iter() 94 .collect(); 95 tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body))); 96 tokens.into_iter().collect() 97 } 98