1 //===--- USRFindingAction.cpp - Clang refactoring library -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// Provides an action to find USR for the symbol at <offset>, as well as 11 /// all additional USRs. 12 /// 13 //===----------------------------------------------------------------------===// 14 15 #include "clang/Tooling/Refactoring/Rename/USRFindingAction.h" 16 #include "clang/AST/ASTConsumer.h" 17 #include "clang/AST/ASTContext.h" 18 #include "clang/AST/Decl.h" 19 #include "clang/AST/RecursiveASTVisitor.h" 20 #include "clang/Frontend/CompilerInstance.h" 21 #include "clang/Lex/Lexer.h" 22 #include "clang/Tooling/Refactoring/Rename/USRFinder.h" 23 #include "clang/Tooling/Tooling.h" 24 25 #include <set> 26 #include <string> 27 #include <vector> 28 29 using namespace llvm; 30 31 namespace clang { 32 namespace tooling { 33 34 const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) { 35 if (!FoundDecl) 36 return nullptr; 37 // If FoundDecl is a constructor or destructor, we want to instead take 38 // the Decl of the corresponding class. 39 if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl)) 40 FoundDecl = CtorDecl->getParent(); 41 else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl)) 42 FoundDecl = DtorDecl->getParent(); 43 // FIXME: (Alex L): Canonicalize implicit template instantions, just like 44 // the indexer does it. 45 46 // Note: please update the declaration's doc comment every time the 47 // canonicalization rules are changed. 48 return FoundDecl; 49 } 50 51 namespace { 52 // NamedDeclFindingConsumer should delegate finding USRs of given Decl to 53 // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given 54 // Decl refers to class and adds USRs of all overridden methods if Decl refers 55 // to virtual method. 56 class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> { 57 public: 58 AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context) 59 : FoundDecl(FoundDecl), Context(Context) {} 60 61 std::vector<std::string> Find() { 62 // Fill OverriddenMethods and PartialSpecs storages. 63 TraverseAST(Context); 64 if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) { 65 addUSRsOfOverridenFunctions(MethodDecl); 66 for (const auto &OverriddenMethod : OverriddenMethods) { 67 if (checkIfOverriddenFunctionAscends(OverriddenMethod)) 68 USRSet.insert(getUSRForDecl(OverriddenMethod)); 69 } 70 addUSRsOfInstantiatedMethods(MethodDecl); 71 } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) { 72 handleCXXRecordDecl(RecordDecl); 73 } else if (const auto *TemplateDecl = 74 dyn_cast<ClassTemplateDecl>(FoundDecl)) { 75 handleClassTemplateDecl(TemplateDecl); 76 } else if (const auto *FD = dyn_cast<FunctionDecl>(FoundDecl)) { 77 USRSet.insert(getUSRForDecl(FD)); 78 if (const auto *FTD = FD->getPrimaryTemplate()) 79 handleFunctionTemplateDecl(FTD); 80 } else if (const auto *FD = dyn_cast<FunctionTemplateDecl>(FoundDecl)) { 81 handleFunctionTemplateDecl(FD); 82 } else if (const auto *VTD = dyn_cast<VarTemplateDecl>(FoundDecl)) { 83 handleVarTemplateDecl(VTD); 84 } else if (const auto *VD = 85 dyn_cast<VarTemplateSpecializationDecl>(FoundDecl)) { 86 // FIXME: figure out why FoundDecl can be a VarTemplateSpecializationDecl. 87 handleVarTemplateDecl(VD->getSpecializedTemplate()); 88 } else if (const auto *VD = dyn_cast<VarDecl>(FoundDecl)) { 89 USRSet.insert(getUSRForDecl(VD)); 90 if (const auto *VTD = VD->getDescribedVarTemplate()) 91 handleVarTemplateDecl(VTD); 92 } else { 93 USRSet.insert(getUSRForDecl(FoundDecl)); 94 } 95 return std::vector<std::string>(USRSet.begin(), USRSet.end()); 96 } 97 98 bool shouldVisitTemplateInstantiations() const { return true; } 99 100 bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) { 101 if (MethodDecl->isVirtual()) 102 OverriddenMethods.push_back(MethodDecl); 103 if (MethodDecl->getInstantiatedFromMemberFunction()) 104 InstantiatedMethods.push_back(MethodDecl); 105 return true; 106 } 107 108 private: 109 void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) { 110 if (!RecordDecl->getDefinition()) { 111 USRSet.insert(getUSRForDecl(RecordDecl)); 112 return; 113 } 114 RecordDecl = RecordDecl->getDefinition(); 115 if (const auto *ClassTemplateSpecDecl = 116 dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl)) 117 handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate()); 118 addUSRsOfCtorDtors(RecordDecl); 119 } 120 121 void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) { 122 for (const auto *Specialization : TemplateDecl->specializations()) 123 addUSRsOfCtorDtors(Specialization); 124 SmallVector<ClassTemplatePartialSpecializationDecl *, 4> PartialSpecs; 125 TemplateDecl->getPartialSpecializations(PartialSpecs); 126 for (const auto *Spec : PartialSpecs) 127 addUSRsOfCtorDtors(Spec); 128 addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl()); 129 } 130 131 void handleFunctionTemplateDecl(const FunctionTemplateDecl *FTD) { 132 USRSet.insert(getUSRForDecl(FTD)); 133 USRSet.insert(getUSRForDecl(FTD->getTemplatedDecl())); 134 for (const auto *S : FTD->specializations()) 135 USRSet.insert(getUSRForDecl(S)); 136 } 137 138 void handleVarTemplateDecl(const VarTemplateDecl *VTD) { 139 USRSet.insert(getUSRForDecl(VTD)); 140 USRSet.insert(getUSRForDecl(VTD->getTemplatedDecl())); 141 for (const auto *Spec : VTD->specializations()) 142 USRSet.insert(getUSRForDecl(Spec)); 143 SmallVector<VarTemplatePartialSpecializationDecl *, 4> PartialSpecs; 144 VTD->getPartialSpecializations(PartialSpecs); 145 for (const auto *Spec : PartialSpecs) 146 USRSet.insert(getUSRForDecl(Spec)); 147 } 148 149 void addUSRsOfCtorDtors(const CXXRecordDecl *RD) { 150 const auto* RecordDecl = RD->getDefinition(); 151 152 // Skip if the CXXRecordDecl doesn't have definition. 153 if (!RecordDecl) { 154 USRSet.insert(getUSRForDecl(RD)); 155 return; 156 } 157 158 for (const auto *CtorDecl : RecordDecl->ctors()) 159 USRSet.insert(getUSRForDecl(CtorDecl)); 160 // Add template constructor decls, they are not in ctors() unfortunately. 161 if (RecordDecl->hasUserDeclaredConstructor()) 162 for (const auto *D : RecordDecl->decls()) 163 if (const auto *FTD = dyn_cast<FunctionTemplateDecl>(D)) 164 if (const auto *Ctor = 165 dyn_cast<CXXConstructorDecl>(FTD->getTemplatedDecl())) 166 USRSet.insert(getUSRForDecl(Ctor)); 167 168 USRSet.insert(getUSRForDecl(RecordDecl->getDestructor())); 169 USRSet.insert(getUSRForDecl(RecordDecl)); 170 } 171 172 void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) { 173 USRSet.insert(getUSRForDecl(MethodDecl)); 174 // Recursively visit each OverridenMethod. 175 for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) 176 addUSRsOfOverridenFunctions(OverriddenMethod); 177 } 178 179 void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) { 180 // For renaming a class template method, all references of the instantiated 181 // member methods should be renamed too, so add USRs of the instantiated 182 // methods to the USR set. 183 USRSet.insert(getUSRForDecl(MethodDecl)); 184 if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction()) 185 USRSet.insert(getUSRForDecl(FT)); 186 for (const auto *Method : InstantiatedMethods) { 187 if (USRSet.find(getUSRForDecl( 188 Method->getInstantiatedFromMemberFunction())) != USRSet.end()) 189 USRSet.insert(getUSRForDecl(Method)); 190 } 191 } 192 193 bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) { 194 for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) { 195 if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end()) 196 return true; 197 return checkIfOverriddenFunctionAscends(OverriddenMethod); 198 } 199 return false; 200 } 201 202 const Decl *FoundDecl; 203 ASTContext &Context; 204 std::set<std::string> USRSet; 205 std::vector<const CXXMethodDecl *> OverriddenMethods; 206 std::vector<const CXXMethodDecl *> InstantiatedMethods; 207 }; 208 } // namespace 209 210 std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND, 211 ASTContext &Context) { 212 AdditionalUSRFinder Finder(ND, Context); 213 return Finder.Find(); 214 } 215 216 class NamedDeclFindingConsumer : public ASTConsumer { 217 public: 218 NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets, 219 ArrayRef<std::string> QualifiedNames, 220 std::vector<std::string> &SpellingNames, 221 std::vector<std::vector<std::string>> &USRList, 222 bool Force, bool &ErrorOccurred) 223 : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames), 224 SpellingNames(SpellingNames), USRList(USRList), Force(Force), 225 ErrorOccurred(ErrorOccurred) {} 226 227 private: 228 bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr, 229 unsigned SymbolOffset, const std::string &QualifiedName) { 230 DiagnosticsEngine &Engine = Context.getDiagnostics(); 231 const FileID MainFileID = SourceMgr.getMainFileID(); 232 233 if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) { 234 ErrorOccurred = true; 235 unsigned InvalidOffset = Engine.getCustomDiagID( 236 DiagnosticsEngine::Error, 237 "SourceLocation in file %0 at offset %1 is invalid"); 238 Engine.Report(SourceLocation(), InvalidOffset) 239 << SourceMgr.getFileEntryRefForID(MainFileID)->getName() 240 << SymbolOffset; 241 return false; 242 } 243 244 const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID) 245 .getLocWithOffset(SymbolOffset); 246 const NamedDecl *FoundDecl = QualifiedName.empty() 247 ? getNamedDeclAt(Context, Point) 248 : getNamedDeclFor(Context, QualifiedName); 249 250 if (FoundDecl == nullptr) { 251 if (QualifiedName.empty()) { 252 FullSourceLoc FullLoc(Point, SourceMgr); 253 unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID( 254 DiagnosticsEngine::Error, 255 "clang-rename could not find symbol (offset %0)"); 256 Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset; 257 ErrorOccurred = true; 258 return false; 259 } 260 261 if (Force) { 262 SpellingNames.push_back(std::string()); 263 USRList.push_back(std::vector<std::string>()); 264 return true; 265 } 266 267 unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID( 268 DiagnosticsEngine::Error, "clang-rename could not find symbol %0"); 269 Engine.Report(CouldNotFindSymbolNamed) << QualifiedName; 270 ErrorOccurred = true; 271 return false; 272 } 273 274 FoundDecl = getCanonicalSymbolDeclaration(FoundDecl); 275 SpellingNames.push_back(FoundDecl->getNameAsString()); 276 AdditionalUSRFinder Finder(FoundDecl, Context); 277 USRList.push_back(Finder.Find()); 278 return true; 279 } 280 281 void HandleTranslationUnit(ASTContext &Context) override { 282 const SourceManager &SourceMgr = Context.getSourceManager(); 283 for (unsigned Offset : SymbolOffsets) { 284 if (!FindSymbol(Context, SourceMgr, Offset, "")) 285 return; 286 } 287 for (const std::string &QualifiedName : QualifiedNames) { 288 if (!FindSymbol(Context, SourceMgr, 0, QualifiedName)) 289 return; 290 } 291 } 292 293 ArrayRef<unsigned> SymbolOffsets; 294 ArrayRef<std::string> QualifiedNames; 295 std::vector<std::string> &SpellingNames; 296 std::vector<std::vector<std::string>> &USRList; 297 bool Force; 298 bool &ErrorOccurred; 299 }; 300 301 std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() { 302 return std::make_unique<NamedDeclFindingConsumer>( 303 SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force, 304 ErrorOccurred); 305 } 306 307 } // end namespace tooling 308 } // end namespace clang 309