1 //===--- USRFindingAction.cpp - Clang refactoring library -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// Provides an action to find USR for the symbol at <offset>, as well as 11 /// all additional USRs. 12 /// 13 //===----------------------------------------------------------------------===// 14 15 #include "clang/Tooling/Refactoring/Rename/USRFindingAction.h" 16 #include "clang/AST/AST.h" 17 #include "clang/AST/ASTConsumer.h" 18 #include "clang/AST/ASTContext.h" 19 #include "clang/AST/Decl.h" 20 #include "clang/AST/RecursiveASTVisitor.h" 21 #include "clang/Basic/FileManager.h" 22 #include "clang/Frontend/CompilerInstance.h" 23 #include "clang/Frontend/FrontendAction.h" 24 #include "clang/Lex/Lexer.h" 25 #include "clang/Lex/Preprocessor.h" 26 #include "clang/Tooling/CommonOptionsParser.h" 27 #include "clang/Tooling/Refactoring.h" 28 #include "clang/Tooling/Refactoring/Rename/USRFinder.h" 29 #include "clang/Tooling/Tooling.h" 30 31 #include <algorithm> 32 #include <set> 33 #include <string> 34 #include <vector> 35 36 using namespace llvm; 37 38 namespace clang { 39 namespace tooling { 40 41 const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) { 42 if (!FoundDecl) 43 return nullptr; 44 // If FoundDecl is a constructor or destructor, we want to instead take 45 // the Decl of the corresponding class. 46 if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl)) 47 FoundDecl = CtorDecl->getParent(); 48 else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl)) 49 FoundDecl = DtorDecl->getParent(); 50 // FIXME: (Alex L): Canonicalize implicit template instantions, just like 51 // the indexer does it. 52 53 // Note: please update the declaration's doc comment every time the 54 // canonicalization rules are changed. 55 return FoundDecl; 56 } 57 58 namespace { 59 // NamedDeclFindingConsumer should delegate finding USRs of given Decl to 60 // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given 61 // Decl refers to class and adds USRs of all overridden methods if Decl refers 62 // to virtual method. 63 class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> { 64 public: 65 AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context) 66 : FoundDecl(FoundDecl), Context(Context) {} 67 68 std::vector<std::string> Find() { 69 // Fill OverriddenMethods and PartialSpecs storages. 70 TraverseAST(Context); 71 if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) { 72 addUSRsOfOverridenFunctions(MethodDecl); 73 for (const auto &OverriddenMethod : OverriddenMethods) { 74 if (checkIfOverriddenFunctionAscends(OverriddenMethod)) 75 USRSet.insert(getUSRForDecl(OverriddenMethod)); 76 } 77 addUSRsOfInstantiatedMethods(MethodDecl); 78 } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) { 79 handleCXXRecordDecl(RecordDecl); 80 } else if (const auto *TemplateDecl = 81 dyn_cast<ClassTemplateDecl>(FoundDecl)) { 82 handleClassTemplateDecl(TemplateDecl); 83 } else if (const auto *FD = dyn_cast<FunctionDecl>(FoundDecl)) { 84 USRSet.insert(getUSRForDecl(FD)); 85 if (const auto *FTD = FD->getPrimaryTemplate()) 86 handleFunctionTemplateDecl(FTD); 87 } else if (const auto *FD = dyn_cast<FunctionTemplateDecl>(FoundDecl)) { 88 handleFunctionTemplateDecl(FD); 89 } else if (const auto *VTD = dyn_cast<VarTemplateDecl>(FoundDecl)) { 90 handleVarTemplateDecl(VTD); 91 } else if (const auto *VD = 92 dyn_cast<VarTemplateSpecializationDecl>(FoundDecl)) { 93 // FIXME: figure out why FoundDecl can be a VarTemplateSpecializationDecl. 94 handleVarTemplateDecl(VD->getSpecializedTemplate()); 95 } else if (const auto *VD = dyn_cast<VarDecl>(FoundDecl)) { 96 USRSet.insert(getUSRForDecl(VD)); 97 if (const auto *VTD = VD->getDescribedVarTemplate()) 98 handleVarTemplateDecl(VTD); 99 } else { 100 USRSet.insert(getUSRForDecl(FoundDecl)); 101 } 102 return std::vector<std::string>(USRSet.begin(), USRSet.end()); 103 } 104 105 bool shouldVisitTemplateInstantiations() const { return true; } 106 107 bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) { 108 if (MethodDecl->isVirtual()) 109 OverriddenMethods.push_back(MethodDecl); 110 if (MethodDecl->getInstantiatedFromMemberFunction()) 111 InstantiatedMethods.push_back(MethodDecl); 112 return true; 113 } 114 115 private: 116 void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) { 117 if (!RecordDecl->getDefinition()) { 118 USRSet.insert(getUSRForDecl(RecordDecl)); 119 return; 120 } 121 RecordDecl = RecordDecl->getDefinition(); 122 if (const auto *ClassTemplateSpecDecl = 123 dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl)) 124 handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate()); 125 addUSRsOfCtorDtors(RecordDecl); 126 } 127 128 void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) { 129 for (const auto *Specialization : TemplateDecl->specializations()) 130 addUSRsOfCtorDtors(Specialization); 131 SmallVector<ClassTemplatePartialSpecializationDecl *, 4> PartialSpecs; 132 TemplateDecl->getPartialSpecializations(PartialSpecs); 133 for (const auto *Spec : PartialSpecs) 134 addUSRsOfCtorDtors(Spec); 135 addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl()); 136 } 137 138 void handleFunctionTemplateDecl(const FunctionTemplateDecl *FTD) { 139 USRSet.insert(getUSRForDecl(FTD)); 140 USRSet.insert(getUSRForDecl(FTD->getTemplatedDecl())); 141 for (const auto *S : FTD->specializations()) 142 USRSet.insert(getUSRForDecl(S)); 143 } 144 145 void handleVarTemplateDecl(const VarTemplateDecl *VTD) { 146 USRSet.insert(getUSRForDecl(VTD)); 147 USRSet.insert(getUSRForDecl(VTD->getTemplatedDecl())); 148 for (const auto *Spec : VTD->specializations()) 149 USRSet.insert(getUSRForDecl(Spec)); 150 SmallVector<VarTemplatePartialSpecializationDecl *, 4> PartialSpecs; 151 VTD->getPartialSpecializations(PartialSpecs); 152 for (const auto *Spec : PartialSpecs) 153 USRSet.insert(getUSRForDecl(Spec)); 154 } 155 156 void addUSRsOfCtorDtors(const CXXRecordDecl *RD) { 157 const auto* RecordDecl = RD->getDefinition(); 158 159 // Skip if the CXXRecordDecl doesn't have definition. 160 if (!RecordDecl) { 161 USRSet.insert(getUSRForDecl(RD)); 162 return; 163 } 164 165 for (const auto *CtorDecl : RecordDecl->ctors()) 166 USRSet.insert(getUSRForDecl(CtorDecl)); 167 // Add template constructor decls, they are not in ctors() unfortunately. 168 if (RecordDecl->hasUserDeclaredConstructor()) 169 for (const auto *D : RecordDecl->decls()) 170 if (const auto *FTD = dyn_cast<FunctionTemplateDecl>(D)) 171 if (const auto *Ctor = 172 dyn_cast<CXXConstructorDecl>(FTD->getTemplatedDecl())) 173 USRSet.insert(getUSRForDecl(Ctor)); 174 175 USRSet.insert(getUSRForDecl(RecordDecl->getDestructor())); 176 USRSet.insert(getUSRForDecl(RecordDecl)); 177 } 178 179 void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) { 180 USRSet.insert(getUSRForDecl(MethodDecl)); 181 // Recursively visit each OverridenMethod. 182 for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) 183 addUSRsOfOverridenFunctions(OverriddenMethod); 184 } 185 186 void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) { 187 // For renaming a class template method, all references of the instantiated 188 // member methods should be renamed too, so add USRs of the instantiated 189 // methods to the USR set. 190 USRSet.insert(getUSRForDecl(MethodDecl)); 191 if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction()) 192 USRSet.insert(getUSRForDecl(FT)); 193 for (const auto *Method : InstantiatedMethods) { 194 if (USRSet.find(getUSRForDecl( 195 Method->getInstantiatedFromMemberFunction())) != USRSet.end()) 196 USRSet.insert(getUSRForDecl(Method)); 197 } 198 } 199 200 bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) { 201 for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) { 202 if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end()) 203 return true; 204 return checkIfOverriddenFunctionAscends(OverriddenMethod); 205 } 206 return false; 207 } 208 209 const Decl *FoundDecl; 210 ASTContext &Context; 211 std::set<std::string> USRSet; 212 std::vector<const CXXMethodDecl *> OverriddenMethods; 213 std::vector<const CXXMethodDecl *> InstantiatedMethods; 214 }; 215 } // namespace 216 217 std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND, 218 ASTContext &Context) { 219 AdditionalUSRFinder Finder(ND, Context); 220 return Finder.Find(); 221 } 222 223 class NamedDeclFindingConsumer : public ASTConsumer { 224 public: 225 NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets, 226 ArrayRef<std::string> QualifiedNames, 227 std::vector<std::string> &SpellingNames, 228 std::vector<std::vector<std::string>> &USRList, 229 bool Force, bool &ErrorOccurred) 230 : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames), 231 SpellingNames(SpellingNames), USRList(USRList), Force(Force), 232 ErrorOccurred(ErrorOccurred) {} 233 234 private: 235 bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr, 236 unsigned SymbolOffset, const std::string &QualifiedName) { 237 DiagnosticsEngine &Engine = Context.getDiagnostics(); 238 const FileID MainFileID = SourceMgr.getMainFileID(); 239 240 if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) { 241 ErrorOccurred = true; 242 unsigned InvalidOffset = Engine.getCustomDiagID( 243 DiagnosticsEngine::Error, 244 "SourceLocation in file %0 at offset %1 is invalid"); 245 Engine.Report(SourceLocation(), InvalidOffset) 246 << SourceMgr.getFileEntryRefForID(MainFileID)->getName() 247 << SymbolOffset; 248 return false; 249 } 250 251 const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID) 252 .getLocWithOffset(SymbolOffset); 253 const NamedDecl *FoundDecl = QualifiedName.empty() 254 ? getNamedDeclAt(Context, Point) 255 : getNamedDeclFor(Context, QualifiedName); 256 257 if (FoundDecl == nullptr) { 258 if (QualifiedName.empty()) { 259 FullSourceLoc FullLoc(Point, SourceMgr); 260 unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID( 261 DiagnosticsEngine::Error, 262 "clang-rename could not find symbol (offset %0)"); 263 Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset; 264 ErrorOccurred = true; 265 return false; 266 } 267 268 if (Force) { 269 SpellingNames.push_back(std::string()); 270 USRList.push_back(std::vector<std::string>()); 271 return true; 272 } 273 274 unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID( 275 DiagnosticsEngine::Error, "clang-rename could not find symbol %0"); 276 Engine.Report(CouldNotFindSymbolNamed) << QualifiedName; 277 ErrorOccurred = true; 278 return false; 279 } 280 281 FoundDecl = getCanonicalSymbolDeclaration(FoundDecl); 282 SpellingNames.push_back(FoundDecl->getNameAsString()); 283 AdditionalUSRFinder Finder(FoundDecl, Context); 284 USRList.push_back(Finder.Find()); 285 return true; 286 } 287 288 void HandleTranslationUnit(ASTContext &Context) override { 289 const SourceManager &SourceMgr = Context.getSourceManager(); 290 for (unsigned Offset : SymbolOffsets) { 291 if (!FindSymbol(Context, SourceMgr, Offset, "")) 292 return; 293 } 294 for (const std::string &QualifiedName : QualifiedNames) { 295 if (!FindSymbol(Context, SourceMgr, 0, QualifiedName)) 296 return; 297 } 298 } 299 300 ArrayRef<unsigned> SymbolOffsets; 301 ArrayRef<std::string> QualifiedNames; 302 std::vector<std::string> &SpellingNames; 303 std::vector<std::vector<std::string>> &USRList; 304 bool Force; 305 bool &ErrorOccurred; 306 }; 307 308 std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() { 309 return std::make_unique<NamedDeclFindingConsumer>( 310 SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force, 311 ErrorOccurred); 312 } 313 314 } // end namespace tooling 315 } // end namespace clang 316