xref: /freebsd/contrib/llvm-project/clang/lib/Tooling/Refactoring/Rename/USRFindingAction.cpp (revision 9c77fb6aaa366cbabc80ee1b834bcfe4df135491)
1 //===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Provides an action to find USR for the symbol at <offset>, as well as
11 /// all additional USRs.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
16 #include "clang/AST/ASTConsumer.h"
17 #include "clang/AST/ASTContext.h"
18 #include "clang/AST/Decl.h"
19 #include "clang/AST/RecursiveASTVisitor.h"
20 #include "clang/Frontend/CompilerInstance.h"
21 #include "clang/Lex/Lexer.h"
22 #include "clang/Tooling/Refactoring/Rename/USRFinder.h"
23 #include "clang/Tooling/Tooling.h"
24 
25 #include <set>
26 #include <string>
27 #include <vector>
28 
29 using namespace llvm;
30 
31 namespace clang {
32 namespace tooling {
33 
34 const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) {
35   if (!FoundDecl)
36     return nullptr;
37   // If FoundDecl is a constructor or destructor, we want to instead take
38   // the Decl of the corresponding class.
39   if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl))
40     FoundDecl = CtorDecl->getParent();
41   else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl))
42     FoundDecl = DtorDecl->getParent();
43   // FIXME: (Alex L): Canonicalize implicit template instantions, just like
44   // the indexer does it.
45 
46   // Note: please update the declaration's doc comment every time the
47   // canonicalization rules are changed.
48   return FoundDecl;
49 }
50 
51 namespace {
52 // NamedDeclFindingConsumer should delegate finding USRs of given Decl to
53 // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
54 // Decl refers to class and adds USRs of all overridden methods if Decl refers
55 // to virtual method.
56 class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
57 public:
58   AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
59       : FoundDecl(FoundDecl), Context(Context) {}
60 
61   std::vector<std::string> Find() {
62     // Fill OverriddenMethods and PartialSpecs storages.
63     TraverseAST(Context);
64     if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
65       addUSRsOfOverridenFunctions(MethodDecl);
66       for (const auto &OverriddenMethod : OverriddenMethods) {
67         if (checkIfOverriddenFunctionAscends(OverriddenMethod))
68           USRSet.insert(getUSRForDecl(OverriddenMethod));
69       }
70       addUSRsOfInstantiatedMethods(MethodDecl);
71     } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
72       handleCXXRecordDecl(RecordDecl);
73     } else if (const auto *TemplateDecl =
74                    dyn_cast<ClassTemplateDecl>(FoundDecl)) {
75       handleClassTemplateDecl(TemplateDecl);
76     } else if (const auto *FD = dyn_cast<FunctionDecl>(FoundDecl)) {
77       USRSet.insert(getUSRForDecl(FD));
78       if (const auto *FTD = FD->getPrimaryTemplate())
79         handleFunctionTemplateDecl(FTD);
80     } else if (const auto *FD = dyn_cast<FunctionTemplateDecl>(FoundDecl)) {
81       handleFunctionTemplateDecl(FD);
82     } else if (const auto *VTD = dyn_cast<VarTemplateDecl>(FoundDecl)) {
83       handleVarTemplateDecl(VTD);
84     } else if (const auto *VD =
85                    dyn_cast<VarTemplateSpecializationDecl>(FoundDecl)) {
86       // FIXME: figure out why FoundDecl can be a VarTemplateSpecializationDecl.
87       handleVarTemplateDecl(VD->getSpecializedTemplate());
88     } else if (const auto *VD = dyn_cast<VarDecl>(FoundDecl)) {
89       USRSet.insert(getUSRForDecl(VD));
90       if (const auto *VTD = VD->getDescribedVarTemplate())
91         handleVarTemplateDecl(VTD);
92     } else {
93       USRSet.insert(getUSRForDecl(FoundDecl));
94     }
95     return std::vector<std::string>(USRSet.begin(), USRSet.end());
96   }
97 
98   bool shouldVisitTemplateInstantiations() const { return true; }
99 
100   bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
101     if (MethodDecl->isVirtual())
102       OverriddenMethods.push_back(MethodDecl);
103     if (MethodDecl->getInstantiatedFromMemberFunction())
104       InstantiatedMethods.push_back(MethodDecl);
105     return true;
106   }
107 
108 private:
109   void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
110     if (!RecordDecl->getDefinition()) {
111       USRSet.insert(getUSRForDecl(RecordDecl));
112       return;
113     }
114     RecordDecl = RecordDecl->getDefinition();
115     if (const auto *ClassTemplateSpecDecl =
116             dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
117       handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
118     addUSRsOfCtorDtors(RecordDecl);
119   }
120 
121   void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
122     for (const auto *Specialization : TemplateDecl->specializations())
123       addUSRsOfCtorDtors(Specialization);
124     SmallVector<ClassTemplatePartialSpecializationDecl *, 4> PartialSpecs;
125     TemplateDecl->getPartialSpecializations(PartialSpecs);
126     for (const auto *Spec : PartialSpecs)
127       addUSRsOfCtorDtors(Spec);
128     addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
129   }
130 
131   void handleFunctionTemplateDecl(const FunctionTemplateDecl *FTD) {
132     USRSet.insert(getUSRForDecl(FTD));
133     USRSet.insert(getUSRForDecl(FTD->getTemplatedDecl()));
134     for (const auto *S : FTD->specializations())
135       USRSet.insert(getUSRForDecl(S));
136   }
137 
138   void handleVarTemplateDecl(const VarTemplateDecl *VTD) {
139     USRSet.insert(getUSRForDecl(VTD));
140     USRSet.insert(getUSRForDecl(VTD->getTemplatedDecl()));
141     for (const auto *Spec : VTD->specializations())
142       USRSet.insert(getUSRForDecl(Spec));
143     SmallVector<VarTemplatePartialSpecializationDecl *, 4> PartialSpecs;
144     VTD->getPartialSpecializations(PartialSpecs);
145     for (const auto *Spec : PartialSpecs)
146       USRSet.insert(getUSRForDecl(Spec));
147   }
148 
149   void addUSRsOfCtorDtors(const CXXRecordDecl *RD) {
150     const auto* RecordDecl = RD->getDefinition();
151 
152     // Skip if the CXXRecordDecl doesn't have definition.
153     if (!RecordDecl) {
154       USRSet.insert(getUSRForDecl(RD));
155       return;
156     }
157 
158     for (const auto *CtorDecl : RecordDecl->ctors())
159       USRSet.insert(getUSRForDecl(CtorDecl));
160     // Add template constructor decls, they are not in ctors() unfortunately.
161     if (RecordDecl->hasUserDeclaredConstructor())
162       for (const auto *D : RecordDecl->decls())
163         if (const auto *FTD = dyn_cast<FunctionTemplateDecl>(D))
164           if (const auto *Ctor =
165                   dyn_cast<CXXConstructorDecl>(FTD->getTemplatedDecl()))
166             USRSet.insert(getUSRForDecl(Ctor));
167 
168     USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
169     USRSet.insert(getUSRForDecl(RecordDecl));
170   }
171 
172   void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
173     USRSet.insert(getUSRForDecl(MethodDecl));
174     // Recursively visit each OverridenMethod.
175     for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
176       addUSRsOfOverridenFunctions(OverriddenMethod);
177   }
178 
179   void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) {
180     // For renaming a class template method, all references of the instantiated
181     // member methods should be renamed too, so add USRs of the instantiated
182     // methods to the USR set.
183     USRSet.insert(getUSRForDecl(MethodDecl));
184     if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction())
185       USRSet.insert(getUSRForDecl(FT));
186     for (const auto *Method : InstantiatedMethods) {
187       if (USRSet.find(getUSRForDecl(
188               Method->getInstantiatedFromMemberFunction())) != USRSet.end())
189         USRSet.insert(getUSRForDecl(Method));
190     }
191   }
192 
193   bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
194     for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
195       if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
196         return true;
197       return checkIfOverriddenFunctionAscends(OverriddenMethod);
198     }
199     return false;
200   }
201 
202   const Decl *FoundDecl;
203   ASTContext &Context;
204   std::set<std::string> USRSet;
205   std::vector<const CXXMethodDecl *> OverriddenMethods;
206   std::vector<const CXXMethodDecl *> InstantiatedMethods;
207 };
208 } // namespace
209 
210 std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND,
211                                                ASTContext &Context) {
212   AdditionalUSRFinder Finder(ND, Context);
213   return Finder.Find();
214 }
215 
216 class NamedDeclFindingConsumer : public ASTConsumer {
217 public:
218   NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
219                            ArrayRef<std::string> QualifiedNames,
220                            std::vector<std::string> &SpellingNames,
221                            std::vector<std::vector<std::string>> &USRList,
222                            bool Force, bool &ErrorOccurred)
223       : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
224         SpellingNames(SpellingNames), USRList(USRList), Force(Force),
225         ErrorOccurred(ErrorOccurred) {}
226 
227 private:
228   bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
229                   unsigned SymbolOffset, const std::string &QualifiedName) {
230     DiagnosticsEngine &Engine = Context.getDiagnostics();
231     const FileID MainFileID = SourceMgr.getMainFileID();
232 
233     if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
234       ErrorOccurred = true;
235       unsigned InvalidOffset = Engine.getCustomDiagID(
236           DiagnosticsEngine::Error,
237           "SourceLocation in file %0 at offset %1 is invalid");
238       Engine.Report(SourceLocation(), InvalidOffset)
239           << SourceMgr.getFileEntryRefForID(MainFileID)->getName()
240           << SymbolOffset;
241       return false;
242     }
243 
244     const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
245                                      .getLocWithOffset(SymbolOffset);
246     const NamedDecl *FoundDecl = QualifiedName.empty()
247                                      ? getNamedDeclAt(Context, Point)
248                                      : getNamedDeclFor(Context, QualifiedName);
249 
250     if (FoundDecl == nullptr) {
251       if (QualifiedName.empty()) {
252         FullSourceLoc FullLoc(Point, SourceMgr);
253         unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
254             DiagnosticsEngine::Error,
255             "clang-rename could not find symbol (offset %0)");
256         Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
257         ErrorOccurred = true;
258         return false;
259       }
260 
261       if (Force) {
262         SpellingNames.push_back(std::string());
263         USRList.push_back(std::vector<std::string>());
264         return true;
265       }
266 
267       unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
268           DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
269       Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
270       ErrorOccurred = true;
271       return false;
272     }
273 
274     FoundDecl = getCanonicalSymbolDeclaration(FoundDecl);
275     SpellingNames.push_back(FoundDecl->getNameAsString());
276     AdditionalUSRFinder Finder(FoundDecl, Context);
277     USRList.push_back(Finder.Find());
278     return true;
279   }
280 
281   void HandleTranslationUnit(ASTContext &Context) override {
282     const SourceManager &SourceMgr = Context.getSourceManager();
283     for (unsigned Offset : SymbolOffsets) {
284       if (!FindSymbol(Context, SourceMgr, Offset, ""))
285         return;
286     }
287     for (const std::string &QualifiedName : QualifiedNames) {
288       if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
289         return;
290     }
291   }
292 
293   ArrayRef<unsigned> SymbolOffsets;
294   ArrayRef<std::string> QualifiedNames;
295   std::vector<std::string> &SpellingNames;
296   std::vector<std::vector<std::string>> &USRList;
297   bool Force;
298   bool &ErrorOccurred;
299 };
300 
301 std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
302   return std::make_unique<NamedDeclFindingConsumer>(
303       SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
304       ErrorOccurred);
305 }
306 
307 } // end namespace tooling
308 } // end namespace clang
309