xref: /linux/rust/macros/paste.rs (revision 7f71507851fc7764b36a3221839607d3a45c2025)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 use proc_macro::{Delimiter, Group, Ident, Spacing, Span, TokenTree};
4 
/// Splits the tokens of a `[< ... >]` paste group into `(text, span)` segments.
///
/// - Literals contribute their text; surrounding `"` are stripped so string literals can be
///   pasted.
/// - Identifiers contribute their name, without the `r#` prefix of raw identifiers.
/// - Invisible (`Delimiter::None`) groups are flattened recursively.
/// - `:modifier` applies to the segment immediately before it: `lower` / `upper` change its
///   case, and `span` stores that segment's span into `span` so the caller can give it to the
///   pasted identifier.
///
/// `span` is shared with the recursive calls, so the "at most once" rule for the `span`
/// modifier holds for the whole paste group rather than for a single nesting level.
///
/// # Panics
///
/// Panics on any other token, on a `:` not followed by an identifier, on a modifier with no
/// segment before it, on an unknown modifier, and on a second `span` modifier.
fn concat_helper(tokens: &[TokenTree], span: &mut Option<Span>) -> Vec<(String, Span)> {
    let mut tokens = tokens.iter();
    let mut segments = Vec::new();
    loop {
        match tokens.next() {
            None => break,
            Some(TokenTree::Literal(lit)) => {
                // Allow us to concat string literals by stripping quotes
                let mut value = lit.to_string();
                if value.starts_with('"') && value.ends_with('"') {
                    value.remove(0);
                    value.pop();
                }
                segments.push((value, lit.span()));
            }
            Some(TokenTree::Ident(ident)) => {
                let mut value = ident.to_string();
                if value.starts_with("r#") {
                    value.replace_range(0..2, "");
                }
                segments.push((value, ident.span()));
            }
            Some(TokenTree::Punct(p)) if p.as_char() == ':' => {
                let Some(TokenTree::Ident(ident)) = tokens.next() else {
                    panic!("expected identifier as modifier");
                };

                let (mut value, sp) = segments.pop().expect("expected identifier before modifier");
                match ident.to_string().as_str() {
                    // Use this segment's span for the whole pasted identifier (read by `concat`).
                    "span" => {
                        assert!(
                            span.is_none(),
                            "span modifier should only appear at most once"
                        );
                        *span = Some(sp);
                    }
                    "lower" => value = value.to_lowercase(),
                    "upper" => value = value.to_uppercase(),
                    v => panic!("unknown modifier `{v}`"),
                };
                segments.push((value, sp));
            }
            Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::None => {
                let inner = group.stream().into_iter().collect::<Vec<TokenTree>>();
                // Share `span` so a `span` modifier inside the group is honoured and counted.
                segments.append(&mut concat_helper(inner.as_slice(), span));
            }
            token => panic!("unexpected token in paste segments: {:?}", token),
        };
    }

    segments
}

/// Pastes the tokens between `<` and `>` of a `[< ... >]` group into a single identifier.
///
/// The identifier gets the span selected with the `span` modifier if one was given, and
/// `group_span` (the span of the `[< >]` group) otherwise.
///
/// # Panics
///
/// Panics if the segments are malformed (see `concat_helper`) or if the pasted text is not a
/// valid identifier (`Ident::new` rejects it).
fn concat(tokens: &[TokenTree], group_span: Span) -> TokenTree {
    let mut span = None;
    let segments = concat_helper(tokens, &mut span);
    let pasted: String = segments.into_iter().map(|(value, _)| value).collect();
    TokenTree::Ident(Ident::new(&pasted, span.unwrap_or(group_span)))
}
65 
66 pub(crate) fn expand(tokens: &mut Vec<TokenTree>) {
67     for token in tokens.iter_mut() {
68         if let TokenTree::Group(group) = token {
69             let delimiter = group.delimiter();
70             let span = group.span();
71             let mut stream: Vec<_> = group.stream().into_iter().collect();
72             // Find groups that looks like `[< A B C D >]`
73             if delimiter == Delimiter::Bracket
74                 && stream.len() >= 3
75                 && matches!(&stream[0], TokenTree::Punct(p) if p.as_char() == '<')
76                 && matches!(&stream[stream.len() - 1], TokenTree::Punct(p) if p.as_char() == '>')
77             {
78                 // Replace the group with concatenated token
79                 *token = concat(&stream[1..stream.len() - 1], span);
80             } else {
81                 // Recursively expand tokens inside the group
82                 expand(&mut stream);
83                 let mut group = Group::new(delimiter, stream.into_iter().collect());
84                 group.set_span(span);
85                 *token = TokenTree::Group(group);
86             }
87         }
88     }
89 
90     // Path segments cannot contain invisible delimiter group, so remove them if any.
91     for i in (0..tokens.len().saturating_sub(3)).rev() {
92         // Looking for a double colon
93         if matches!(
94             (&tokens[i + 1], &tokens[i + 2]),
95             (TokenTree::Punct(a), TokenTree::Punct(b))
96                 if a.as_char() == ':' && a.spacing() == Spacing::Joint && b.as_char() == ':'
97         ) {
98             match &tokens[i + 3] {
99                 TokenTree::Group(group) if group.delimiter() == Delimiter::None => {
100                     tokens.splice(i + 3..i + 4, group.stream());
101                 }
102                 _ => (),
103             }
104 
105             match &tokens[i] {
106                 TokenTree::Group(group) if group.delimiter() == Delimiter::None => {
107                     tokens.splice(i..i + 1, group.stream());
108                 }
109                 _ => (),
110             }
111         }
112     }
113 }
114