1 // SPDX-License-Identifier: Apache-2.0 OR MIT 2 3 use crate::helpers::{parse_generics, Generics}; 4 use proc_macro::{Group, Punct, Spacing, TokenStream, TokenTree}; 5 6 pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream { 7 // This proc-macro only does some pre-parsing and then delegates the actual parsing to 8 // `kernel::__pin_data!`. 9 10 let ( 11 Generics { 12 impl_generics, 13 decl_generics, 14 ty_generics, 15 }, 16 rest, 17 ) = parse_generics(input); 18 // The struct definition might contain the `Self` type. Since `__pin_data!` will define a new 19 // type with the same generics and bounds, this poses a problem, since `Self` will refer to the 20 // new type as opposed to this struct definition. Therefore we have to replace `Self` with the 21 // concrete name. 22 23 // Errors that occur when replacing `Self` with `struct_name`. 24 let mut errs = TokenStream::new(); 25 // The name of the struct with ty_generics. 26 let struct_name = rest 27 .iter() 28 .skip_while(|tt| !matches!(tt, TokenTree::Ident(i) if i.to_string() == "struct")) 29 .nth(1) 30 .and_then(|tt| match tt { 31 TokenTree::Ident(_) => { 32 let tt = tt.clone(); 33 let mut res = vec![tt]; 34 if !ty_generics.is_empty() { 35 // We add this, so it is maximally compatible with e.g. `Self::CONST` which 36 // will be replaced by `StructName::<$generics>::CONST`. 37 res.push(TokenTree::Punct(Punct::new(':', Spacing::Joint))); 38 res.push(TokenTree::Punct(Punct::new(':', Spacing::Alone))); 39 res.push(TokenTree::Punct(Punct::new('<', Spacing::Alone))); 40 res.extend(ty_generics.iter().cloned()); 41 res.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))); 42 } 43 Some(res) 44 } 45 _ => None, 46 }) 47 .unwrap_or_else(|| { 48 // If we did not find the name of the struct then we will use `Self` as the replacement 49 // and add a compile error to ensure it does not compile. 
50 errs.extend( 51 "::core::compile_error!(\"Could not locate type name.\");" 52 .parse::<TokenStream>() 53 .unwrap(), 54 ); 55 "Self".parse::<TokenStream>().unwrap().into_iter().collect() 56 }); 57 let impl_generics = impl_generics 58 .into_iter() 59 .flat_map(|tt| replace_self_and_deny_type_defs(&struct_name, tt, &mut errs)) 60 .collect::<Vec<_>>(); 61 let mut rest = rest 62 .into_iter() 63 .flat_map(|tt| { 64 // We ignore top level `struct` tokens, since they would emit a compile error. 65 if matches!(&tt, TokenTree::Ident(i) if i.to_string() == "struct") { 66 vec![tt] 67 } else { 68 replace_self_and_deny_type_defs(&struct_name, tt, &mut errs) 69 } 70 }) 71 .collect::<Vec<_>>(); 72 // This should be the body of the struct `{...}`. 73 let last = rest.pop(); 74 let mut quoted = quote!(::kernel::__pin_data! { 75 parse_input: 76 @args(#args), 77 @sig(#(#rest)*), 78 @impl_generics(#(#impl_generics)*), 79 @ty_generics(#(#ty_generics)*), 80 @decl_generics(#(#decl_generics)*), 81 @body(#last), 82 }); 83 quoted.extend(errs); 84 quoted 85 } 86 87 /// Replaces `Self` with `struct_name` and errors on `enum`, `trait`, `struct` `union` and `impl` 88 /// keywords. 89 /// 90 /// The error is appended to `errs` to allow normal parsing to continue. 
///
/// Groups are processed recursively. A rebuilt group keeps both the delimiter and the span of the
/// group it replaces, so diagnostics keep pointing at the user's code.
fn replace_self_and_deny_type_defs(
    struct_name: &[TokenTree],
    tt: TokenTree,
    errs: &mut TokenStream,
) -> Vec<TokenTree> {
    match tt {
        // Convert the identifier to a string once instead of once per keyword comparison.
        TokenTree::Ident(ref i)
            if matches!(
                i.to_string().as_str(),
                "enum" | "trait" | "struct" | "union" | "impl"
            ) =>
        {
            // Attach the error to the offending keyword so the diagnostic points at it.
            let span = tt.span();
            errs.extend(
                format!(
                    "::core::compile_error!(\"Cannot use `{i}` inside of struct definition with \
                        `#[pin_data]`.\");"
                )
                .parse::<TokenStream>()
                .unwrap()
                .into_iter()
                .map(|mut tok| {
                    tok.set_span(span);
                    tok
                }),
            );
            // The keyword itself is passed through unchanged so that parsing can continue.
            vec![tt]
        }
        TokenTree::Ident(i) if i.to_string() == "Self" => struct_name.to_vec(),
        TokenTree::Literal(_) | TokenTree::Punct(_) | TokenTree::Ident(_) => vec![tt],
        TokenTree::Group(g) => {
            let stream = g
                .stream()
                .into_iter()
                .flat_map(|tt| replace_self_and_deny_type_defs(struct_name, tt, errs))
                .collect::<TokenStream>();
            let mut group = Group::new(g.delimiter(), stream);
            // `Group::new` assigns `Span::call_site()`; restore the span of the original group.
            group.set_span(g.span());
            vec![TokenTree::Group(group)]
        }
    }
}