1 //===--- RefactoringCallbacks.cpp - Structural query framework ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // 10 //===----------------------------------------------------------------------===// 11 #include "clang/Tooling/RefactoringCallbacks.h" 12 #include "clang/ASTMatchers/ASTMatchFinder.h" 13 #include "clang/Basic/SourceLocation.h" 14 #include "clang/Lex/Lexer.h" 15 16 using llvm::StringError; 17 using llvm::make_error; 18 19 namespace clang { 20 namespace tooling { 21 22 RefactoringCallback::RefactoringCallback() {} 23 tooling::Replacements &RefactoringCallback::getReplacements() { 24 return Replace; 25 } 26 27 ASTMatchRefactorer::ASTMatchRefactorer( 28 std::map<std::string, Replacements> &FileToReplaces) 29 : FileToReplaces(FileToReplaces) {} 30 31 void ASTMatchRefactorer::addDynamicMatcher( 32 const ast_matchers::internal::DynTypedMatcher &Matcher, 33 RefactoringCallback *Callback) { 34 MatchFinder.addDynamicMatcher(Matcher, Callback); 35 Callbacks.push_back(Callback); 36 } 37 38 class RefactoringASTConsumer : public ASTConsumer { 39 public: 40 explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring) 41 : Refactoring(Refactoring) {} 42 43 void HandleTranslationUnit(ASTContext &Context) override { 44 // The ASTMatchRefactorer is re-used between translation units. 45 // Clear the matchers so that each Replacement is only emitted once. 46 for (const auto &Callback : Refactoring.Callbacks) { 47 Callback->getReplacements().clear(); 48 } 49 Refactoring.MatchFinder.matchAST(Context); 50 for (const auto &Callback : Refactoring.Callbacks) { 51 for (const auto &Replacement : Callback->getReplacements()) { 52 llvm::Error Err = 53 Refactoring.FileToReplaces[Replacement.getFilePath()].add( 54 Replacement); 55 if (Err) { 56 llvm::errs() << "Skipping replacement " << Replacement.toString() 57 << " due to this error:\n" 58 << toString(std::move(Err)) << "\n"; 59 } 60 } 61 } 62 } 63 64 private: 65 ASTMatchRefactorer &Refactoring; 66 }; 67 68 std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() { 69 return std::make_unique<RefactoringASTConsumer>(*this); 70 } 71 72 static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From, 73 StringRef Text) { 74 return tooling::Replacement( 75 Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text); 76 } 77 static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From, 78 const Stmt &To) { 79 return replaceStmtWithText( 80 Sources, From, 81 Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()), 82 Sources, LangOptions())); 83 } 84 85 ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText) 86 : FromId(FromId), ToText(ToText) {} 87 88 void ReplaceStmtWithText::run( 89 const ast_matchers::MatchFinder::MatchResult &Result) { 90 if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId)) { 91 auto Err = Replace.add(tooling::Replacement( 92 *Result.SourceManager, 93 CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText)); 94 // FIXME: better error handling. For now, just print error message in the 95 // release version. 96 if (Err) { 97 llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 98 assert(false); 99 } 100 } 101 } 102 103 ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId) 104 : FromId(FromId), ToId(ToId) {} 105 106 void ReplaceStmtWithStmt::run( 107 const ast_matchers::MatchFinder::MatchResult &Result) { 108 const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId); 109 const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ToId); 110 if (FromMatch && ToMatch) { 111 auto Err = Replace.add( 112 replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch)); 113 // FIXME: better error handling. For now, just print error message in the 114 // release version. 115 if (Err) { 116 llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 117 assert(false); 118 } 119 } 120 } 121 122 ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id, 123 bool PickTrueBranch) 124 : Id(Id), PickTrueBranch(PickTrueBranch) {} 125 126 void ReplaceIfStmtWithItsBody::run( 127 const ast_matchers::MatchFinder::MatchResult &Result) { 128 if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(Id)) { 129 const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse(); 130 if (Body) { 131 auto Err = 132 Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body)); 133 // FIXME: better error handling. For now, just print error message in the 134 // release version. 135 if (Err) { 136 llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 137 assert(false); 138 } 139 } else if (!PickTrueBranch) { 140 // If we want to use the 'else'-branch, but it doesn't exist, delete 141 // the whole 'if'. 142 auto Err = 143 Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, "")); 144 // FIXME: better error handling. For now, just print error message in the 145 // release version. 146 if (Err) { 147 llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 148 assert(false); 149 } 150 } 151 } 152 } 153 154 ReplaceNodeWithTemplate::ReplaceNodeWithTemplate( 155 llvm::StringRef FromId, std::vector<TemplateElement> Template) 156 : FromId(FromId), Template(std::move(Template)) {} 157 158 llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>> 159 ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) { 160 std::vector<TemplateElement> ParsedTemplate; 161 for (size_t Index = 0; Index < ToTemplate.size();) { 162 if (ToTemplate[Index] == '$') { 163 if (ToTemplate.substr(Index, 2) == "$$") { 164 Index += 2; 165 ParsedTemplate.push_back( 166 TemplateElement{TemplateElement::Literal, "$"}); 167 } else if (ToTemplate.substr(Index, 2) == "${") { 168 size_t EndOfIdentifier = ToTemplate.find("}", Index); 169 if (EndOfIdentifier == std::string::npos) { 170 return make_error<StringError>( 171 "Unterminated ${...} in replacement template near " + 172 ToTemplate.substr(Index), 173 llvm::inconvertibleErrorCode()); 174 } 175 std::string SourceNodeName = 176 ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2); 177 ParsedTemplate.push_back( 178 TemplateElement{TemplateElement::Identifier, SourceNodeName}); 179 Index = EndOfIdentifier + 1; 180 } else { 181 return make_error<StringError>( 182 "Invalid $ in replacement template near " + 183 ToTemplate.substr(Index), 184 llvm::inconvertibleErrorCode()); 185 } 186 } else { 187 size_t NextIndex = ToTemplate.find('$', Index + 1); 188 ParsedTemplate.push_back( 189 TemplateElement{TemplateElement::Literal, 190 ToTemplate.substr(Index, NextIndex - Index)}); 191 Index = NextIndex; 192 } 193 } 194 return std::unique_ptr<ReplaceNodeWithTemplate>( 195 new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate))); 196 } 197 198 void ReplaceNodeWithTemplate::run( 199 const ast_matchers::MatchFinder::MatchResult &Result) { 200 const auto &NodeMap = Result.Nodes.getMap(); 201 202 std::string ToText; 203 for (const auto &Element : Template) { 204 switch (Element.Type) { 205 case TemplateElement::Literal: 206 ToText += Element.Value; 207 break; 208 case TemplateElement::Identifier: { 209 auto NodeIter = NodeMap.find(Element.Value); 210 if (NodeIter == NodeMap.end()) { 211 llvm::errs() << "Node " << Element.Value 212 << " used in replacement template not bound in Matcher \n"; 213 llvm::report_fatal_error("Unbound node in replacement template."); 214 } 215 CharSourceRange Source = 216 CharSourceRange::getTokenRange(NodeIter->second.getSourceRange()); 217 ToText += Lexer::getSourceText(Source, *Result.SourceManager, 218 Result.Context->getLangOpts()); 219 break; 220 } 221 } 222 } 223 if (NodeMap.count(FromId) == 0) { 224 llvm::errs() << "Node to be replaced " << FromId 225 << " not bound in query.\n"; 226 llvm::report_fatal_error("FromId node not bound in MatchResult"); 227 } 228 auto Replacement = 229 tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText, 230 Result.Context->getLangOpts()); 231 llvm::Error Err = Replace.add(Replacement); 232 if (Err) { 233 llvm::errs() << "Query and replace failed in " << Replacement.getFilePath() 234 << "! " << llvm::toString(std::move(Err)) << "\n"; 235 llvm::report_fatal_error("Replacement failed"); 236 } 237 } 238 239 } // end namespace tooling 240 } // end namespace clang 241