//===--- USRFindingAction.cpp - Clang refactoring library -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// /// \file /// Provides an action to find USR for the symbol at , as well as /// all additional USRs. /// //===----------------------------------------------------------------------===// #include "clang/Tooling/Refactoring/Rename/USRFindingAction.h" #include "clang/AST/AST.h" #include "clang/AST/ASTConsumer.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Decl.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/Basic/FileManager.h" #include "clang/Frontend/CompilerInstance.h" #include "clang/Frontend/FrontendAction.h" #include "clang/Lex/Lexer.h" #include "clang/Lex/Preprocessor.h" #include "clang/Tooling/CommonOptionsParser.h" #include "clang/Tooling/Refactoring.h" #include "clang/Tooling/Refactoring/Rename/USRFinder.h" #include "clang/Tooling/Tooling.h" #include #include #include #include using namespace llvm; namespace clang { namespace tooling { const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) { if (!FoundDecl) return nullptr; // If FoundDecl is a constructor or destructor, we want to instead take // the Decl of the corresponding class. if (const auto *CtorDecl = dyn_cast(FoundDecl)) FoundDecl = CtorDecl->getParent(); else if (const auto *DtorDecl = dyn_cast(FoundDecl)) FoundDecl = DtorDecl->getParent(); // FIXME: (Alex L): Canonicalize implicit template instantions, just like // the indexer does it. // Note: please update the declaration's doc comment every time the // canonicalization rules are changed. return FoundDecl; } namespace { // NamedDeclFindingConsumer should delegate finding USRs of given Decl to // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given // Decl refers to class and adds USRs of all overridden methods if Decl refers // to virtual method. class AdditionalUSRFinder : public RecursiveASTVisitor { public: AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context) : FoundDecl(FoundDecl), Context(Context) {} std::vector Find() { // Fill OverriddenMethods and PartialSpecs storages. TraverseAST(Context); if (const auto *MethodDecl = dyn_cast(FoundDecl)) { addUSRsOfOverridenFunctions(MethodDecl); for (const auto &OverriddenMethod : OverriddenMethods) { if (checkIfOverriddenFunctionAscends(OverriddenMethod)) USRSet.insert(getUSRForDecl(OverriddenMethod)); } addUSRsOfInstantiatedMethods(MethodDecl); } else if (const auto *RecordDecl = dyn_cast(FoundDecl)) { handleCXXRecordDecl(RecordDecl); } else if (const auto *TemplateDecl = dyn_cast(FoundDecl)) { handleClassTemplateDecl(TemplateDecl); } else { USRSet.insert(getUSRForDecl(FoundDecl)); } return std::vector(USRSet.begin(), USRSet.end()); } bool shouldVisitTemplateInstantiations() const { return true; } bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) { if (MethodDecl->isVirtual()) OverriddenMethods.push_back(MethodDecl); if (MethodDecl->getInstantiatedFromMemberFunction()) InstantiatedMethods.push_back(MethodDecl); return true; } bool VisitClassTemplatePartialSpecializationDecl( const ClassTemplatePartialSpecializationDecl *PartialSpec) { PartialSpecs.push_back(PartialSpec); return true; } private: void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) { if (!RecordDecl->getDefinition()) { USRSet.insert(getUSRForDecl(RecordDecl)); return; } RecordDecl = RecordDecl->getDefinition(); if (const auto *ClassTemplateSpecDecl = dyn_cast(RecordDecl)) handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate()); addUSRsOfCtorDtors(RecordDecl); } void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) { for (const auto *Specialization : TemplateDecl->specializations()) addUSRsOfCtorDtors(Specialization); for (const auto *PartialSpec : PartialSpecs) { if (PartialSpec->getSpecializedTemplate() == TemplateDecl) addUSRsOfCtorDtors(PartialSpec); } addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl()); } void addUSRsOfCtorDtors(const CXXRecordDecl *RD) { const auto* RecordDecl = RD->getDefinition(); // Skip if the CXXRecordDecl doesn't have definition. if (!RecordDecl) { USRSet.insert(getUSRForDecl(RD)); return; } for (const auto *CtorDecl : RecordDecl->ctors()) USRSet.insert(getUSRForDecl(CtorDecl)); // Add template constructor decls, they are not in ctors() unfortunately. if (RecordDecl->hasUserDeclaredConstructor()) for (const auto *D : RecordDecl->decls()) if (const auto *FTD = dyn_cast(D)) if (const auto *Ctor = dyn_cast(FTD->getTemplatedDecl())) USRSet.insert(getUSRForDecl(Ctor)); USRSet.insert(getUSRForDecl(RecordDecl->getDestructor())); USRSet.insert(getUSRForDecl(RecordDecl)); } void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) { USRSet.insert(getUSRForDecl(MethodDecl)); // Recursively visit each OverridenMethod. for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) addUSRsOfOverridenFunctions(OverriddenMethod); } void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) { // For renaming a class template method, all references of the instantiated // member methods should be renamed too, so add USRs of the instantiated // methods to the USR set. USRSet.insert(getUSRForDecl(MethodDecl)); if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction()) USRSet.insert(getUSRForDecl(FT)); for (const auto *Method : InstantiatedMethods) { if (USRSet.find(getUSRForDecl( Method->getInstantiatedFromMemberFunction())) != USRSet.end()) USRSet.insert(getUSRForDecl(Method)); } } bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) { for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) { if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end()) return true; return checkIfOverriddenFunctionAscends(OverriddenMethod); } return false; } const Decl *FoundDecl; ASTContext &Context; std::set USRSet; std::vector OverriddenMethods; std::vector InstantiatedMethods; std::vector PartialSpecs; }; } // namespace std::vector getUSRsForDeclaration(const NamedDecl *ND, ASTContext &Context) { AdditionalUSRFinder Finder(ND, Context); return Finder.Find(); } class NamedDeclFindingConsumer : public ASTConsumer { public: NamedDeclFindingConsumer(ArrayRef SymbolOffsets, ArrayRef QualifiedNames, std::vector &SpellingNames, std::vector> &USRList, bool Force, bool &ErrorOccurred) : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames), SpellingNames(SpellingNames), USRList(USRList), Force(Force), ErrorOccurred(ErrorOccurred) {} private: bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr, unsigned SymbolOffset, const std::string &QualifiedName) { DiagnosticsEngine &Engine = Context.getDiagnostics(); const FileID MainFileID = SourceMgr.getMainFileID(); if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) { ErrorOccurred = true; unsigned InvalidOffset = Engine.getCustomDiagID( DiagnosticsEngine::Error, "SourceLocation in file %0 at offset %1 is invalid"); Engine.Report(SourceLocation(), InvalidOffset) << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset; return false; } const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID) .getLocWithOffset(SymbolOffset); const NamedDecl *FoundDecl = QualifiedName.empty() ? getNamedDeclAt(Context, Point) : getNamedDeclFor(Context, QualifiedName); if (FoundDecl == nullptr) { if (QualifiedName.empty()) { FullSourceLoc FullLoc(Point, SourceMgr); unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID( DiagnosticsEngine::Error, "clang-rename could not find symbol (offset %0)"); Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset; ErrorOccurred = true; return false; } if (Force) { SpellingNames.push_back(std::string()); USRList.push_back(std::vector()); return true; } unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID( DiagnosticsEngine::Error, "clang-rename could not find symbol %0"); Engine.Report(CouldNotFindSymbolNamed) << QualifiedName; ErrorOccurred = true; return false; } FoundDecl = getCanonicalSymbolDeclaration(FoundDecl); SpellingNames.push_back(FoundDecl->getNameAsString()); AdditionalUSRFinder Finder(FoundDecl, Context); USRList.push_back(Finder.Find()); return true; } void HandleTranslationUnit(ASTContext &Context) override { const SourceManager &SourceMgr = Context.getSourceManager(); for (unsigned Offset : SymbolOffsets) { if (!FindSymbol(Context, SourceMgr, Offset, "")) return; } for (const std::string &QualifiedName : QualifiedNames) { if (!FindSymbol(Context, SourceMgr, 0, QualifiedName)) return; } } ArrayRef SymbolOffsets; ArrayRef QualifiedNames; std::vector &SpellingNames; std::vector> &USRList; bool Force; bool &ErrorOccurred; }; std::unique_ptr USRFindingAction::newASTConsumer() { return std::make_unique( SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force, ErrorOccurred); } } // end namespace tooling } // end namespace clang