xref: /freebsd/contrib/llvm-project/clang/lib/Tooling/Refactoring/Rename/USRFindingAction.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Provides an action to find USR for the symbol at <offset>, as well as
11 /// all additional USRs.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
16 #include "clang/AST/AST.h"
17 #include "clang/AST/ASTConsumer.h"
18 #include "clang/AST/ASTContext.h"
19 #include "clang/AST/Decl.h"
20 #include "clang/AST/RecursiveASTVisitor.h"
21 #include "clang/Basic/FileManager.h"
22 #include "clang/Frontend/CompilerInstance.h"
23 #include "clang/Frontend/FrontendAction.h"
24 #include "clang/Lex/Lexer.h"
25 #include "clang/Lex/Preprocessor.h"
26 #include "clang/Tooling/CommonOptionsParser.h"
27 #include "clang/Tooling/Refactoring.h"
28 #include "clang/Tooling/Refactoring/Rename/USRFinder.h"
29 #include "clang/Tooling/Tooling.h"
30 
31 #include <algorithm>
32 #include <set>
33 #include <string>
34 #include <vector>
35 
36 using namespace llvm;
37 
38 namespace clang {
39 namespace tooling {
40 
41 const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) {
42   if (!FoundDecl)
43     return nullptr;
44   // If FoundDecl is a constructor or destructor, we want to instead take
45   // the Decl of the corresponding class.
46   if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl))
47     FoundDecl = CtorDecl->getParent();
48   else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl))
49     FoundDecl = DtorDecl->getParent();
50   // FIXME: (Alex L): Canonicalize implicit template instantions, just like
51   // the indexer does it.
52 
53   // Note: please update the declaration's doc comment every time the
54   // canonicalization rules are changed.
55   return FoundDecl;
56 }
57 
58 namespace {
59 // NamedDeclFindingConsumer should delegate finding USRs of given Decl to
60 // AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
61 // Decl refers to class and adds USRs of all overridden methods if Decl refers
62 // to virtual method.
63 class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
64 public:
65   AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
66       : FoundDecl(FoundDecl), Context(Context) {}
67 
68   std::vector<std::string> Find() {
69     // Fill OverriddenMethods and PartialSpecs storages.
70     TraverseAST(Context);
71     if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
72       addUSRsOfOverridenFunctions(MethodDecl);
73       for (const auto &OverriddenMethod : OverriddenMethods) {
74         if (checkIfOverriddenFunctionAscends(OverriddenMethod))
75           USRSet.insert(getUSRForDecl(OverriddenMethod));
76       }
77       addUSRsOfInstantiatedMethods(MethodDecl);
78     } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
79       handleCXXRecordDecl(RecordDecl);
80     } else if (const auto *TemplateDecl =
81                    dyn_cast<ClassTemplateDecl>(FoundDecl)) {
82       handleClassTemplateDecl(TemplateDecl);
83     } else if (const auto *FD = dyn_cast<FunctionDecl>(FoundDecl)) {
84       USRSet.insert(getUSRForDecl(FD));
85       if (const auto *FTD = FD->getPrimaryTemplate())
86         handleFunctionTemplateDecl(FTD);
87     } else if (const auto *FD = dyn_cast<FunctionTemplateDecl>(FoundDecl)) {
88       handleFunctionTemplateDecl(FD);
89     } else if (const auto *VTD = dyn_cast<VarTemplateDecl>(FoundDecl)) {
90       handleVarTemplateDecl(VTD);
91     } else if (const auto *VD =
92                    dyn_cast<VarTemplateSpecializationDecl>(FoundDecl)) {
93       // FIXME: figure out why FoundDecl can be a VarTemplateSpecializationDecl.
94       handleVarTemplateDecl(VD->getSpecializedTemplate());
95     } else if (const auto *VD = dyn_cast<VarDecl>(FoundDecl)) {
96       USRSet.insert(getUSRForDecl(VD));
97       if (const auto *VTD = VD->getDescribedVarTemplate())
98         handleVarTemplateDecl(VTD);
99     } else {
100       USRSet.insert(getUSRForDecl(FoundDecl));
101     }
102     return std::vector<std::string>(USRSet.begin(), USRSet.end());
103   }
104 
105   bool shouldVisitTemplateInstantiations() const { return true; }
106 
107   bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
108     if (MethodDecl->isVirtual())
109       OverriddenMethods.push_back(MethodDecl);
110     if (MethodDecl->getInstantiatedFromMemberFunction())
111       InstantiatedMethods.push_back(MethodDecl);
112     return true;
113   }
114 
115 private:
116   void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
117     if (!RecordDecl->getDefinition()) {
118       USRSet.insert(getUSRForDecl(RecordDecl));
119       return;
120     }
121     RecordDecl = RecordDecl->getDefinition();
122     if (const auto *ClassTemplateSpecDecl =
123             dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
124       handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
125     addUSRsOfCtorDtors(RecordDecl);
126   }
127 
128   void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
129     for (const auto *Specialization : TemplateDecl->specializations())
130       addUSRsOfCtorDtors(Specialization);
131     SmallVector<ClassTemplatePartialSpecializationDecl *, 4> PartialSpecs;
132     TemplateDecl->getPartialSpecializations(PartialSpecs);
133     for (const auto *Spec : PartialSpecs)
134       addUSRsOfCtorDtors(Spec);
135     addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
136   }
137 
138   void handleFunctionTemplateDecl(const FunctionTemplateDecl *FTD) {
139     USRSet.insert(getUSRForDecl(FTD));
140     USRSet.insert(getUSRForDecl(FTD->getTemplatedDecl()));
141     for (const auto *S : FTD->specializations())
142       USRSet.insert(getUSRForDecl(S));
143   }
144 
145   void handleVarTemplateDecl(const VarTemplateDecl *VTD) {
146     USRSet.insert(getUSRForDecl(VTD));
147     USRSet.insert(getUSRForDecl(VTD->getTemplatedDecl()));
148     for (const auto *Spec : VTD->specializations())
149       USRSet.insert(getUSRForDecl(Spec));
150     SmallVector<VarTemplatePartialSpecializationDecl *, 4> PartialSpecs;
151     VTD->getPartialSpecializations(PartialSpecs);
152     for (const auto *Spec : PartialSpecs)
153       USRSet.insert(getUSRForDecl(Spec));
154   }
155 
156   void addUSRsOfCtorDtors(const CXXRecordDecl *RD) {
157     const auto* RecordDecl = RD->getDefinition();
158 
159     // Skip if the CXXRecordDecl doesn't have definition.
160     if (!RecordDecl) {
161       USRSet.insert(getUSRForDecl(RD));
162       return;
163     }
164 
165     for (const auto *CtorDecl : RecordDecl->ctors())
166       USRSet.insert(getUSRForDecl(CtorDecl));
167     // Add template constructor decls, they are not in ctors() unfortunately.
168     if (RecordDecl->hasUserDeclaredConstructor())
169       for (const auto *D : RecordDecl->decls())
170         if (const auto *FTD = dyn_cast<FunctionTemplateDecl>(D))
171           if (const auto *Ctor =
172                   dyn_cast<CXXConstructorDecl>(FTD->getTemplatedDecl()))
173             USRSet.insert(getUSRForDecl(Ctor));
174 
175     USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
176     USRSet.insert(getUSRForDecl(RecordDecl));
177   }
178 
179   void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
180     USRSet.insert(getUSRForDecl(MethodDecl));
181     // Recursively visit each OverridenMethod.
182     for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
183       addUSRsOfOverridenFunctions(OverriddenMethod);
184   }
185 
186   void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) {
187     // For renaming a class template method, all references of the instantiated
188     // member methods should be renamed too, so add USRs of the instantiated
189     // methods to the USR set.
190     USRSet.insert(getUSRForDecl(MethodDecl));
191     if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction())
192       USRSet.insert(getUSRForDecl(FT));
193     for (const auto *Method : InstantiatedMethods) {
194       if (USRSet.find(getUSRForDecl(
195               Method->getInstantiatedFromMemberFunction())) != USRSet.end())
196         USRSet.insert(getUSRForDecl(Method));
197     }
198   }
199 
200   bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
201     for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
202       if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
203         return true;
204       return checkIfOverriddenFunctionAscends(OverriddenMethod);
205     }
206     return false;
207   }
208 
209   const Decl *FoundDecl;
210   ASTContext &Context;
211   std::set<std::string> USRSet;
212   std::vector<const CXXMethodDecl *> OverriddenMethods;
213   std::vector<const CXXMethodDecl *> InstantiatedMethods;
214 };
215 } // namespace
216 
217 std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND,
218                                                ASTContext &Context) {
219   AdditionalUSRFinder Finder(ND, Context);
220   return Finder.Find();
221 }
222 
223 class NamedDeclFindingConsumer : public ASTConsumer {
224 public:
225   NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
226                            ArrayRef<std::string> QualifiedNames,
227                            std::vector<std::string> &SpellingNames,
228                            std::vector<std::vector<std::string>> &USRList,
229                            bool Force, bool &ErrorOccurred)
230       : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
231         SpellingNames(SpellingNames), USRList(USRList), Force(Force),
232         ErrorOccurred(ErrorOccurred) {}
233 
234 private:
235   bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
236                   unsigned SymbolOffset, const std::string &QualifiedName) {
237     DiagnosticsEngine &Engine = Context.getDiagnostics();
238     const FileID MainFileID = SourceMgr.getMainFileID();
239 
240     if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
241       ErrorOccurred = true;
242       unsigned InvalidOffset = Engine.getCustomDiagID(
243           DiagnosticsEngine::Error,
244           "SourceLocation in file %0 at offset %1 is invalid");
245       Engine.Report(SourceLocation(), InvalidOffset)
246           << SourceMgr.getFileEntryRefForID(MainFileID)->getName()
247           << SymbolOffset;
248       return false;
249     }
250 
251     const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
252                                      .getLocWithOffset(SymbolOffset);
253     const NamedDecl *FoundDecl = QualifiedName.empty()
254                                      ? getNamedDeclAt(Context, Point)
255                                      : getNamedDeclFor(Context, QualifiedName);
256 
257     if (FoundDecl == nullptr) {
258       if (QualifiedName.empty()) {
259         FullSourceLoc FullLoc(Point, SourceMgr);
260         unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
261             DiagnosticsEngine::Error,
262             "clang-rename could not find symbol (offset %0)");
263         Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
264         ErrorOccurred = true;
265         return false;
266       }
267 
268       if (Force) {
269         SpellingNames.push_back(std::string());
270         USRList.push_back(std::vector<std::string>());
271         return true;
272       }
273 
274       unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
275           DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
276       Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
277       ErrorOccurred = true;
278       return false;
279     }
280 
281     FoundDecl = getCanonicalSymbolDeclaration(FoundDecl);
282     SpellingNames.push_back(FoundDecl->getNameAsString());
283     AdditionalUSRFinder Finder(FoundDecl, Context);
284     USRList.push_back(Finder.Find());
285     return true;
286   }
287 
288   void HandleTranslationUnit(ASTContext &Context) override {
289     const SourceManager &SourceMgr = Context.getSourceManager();
290     for (unsigned Offset : SymbolOffsets) {
291       if (!FindSymbol(Context, SourceMgr, Offset, ""))
292         return;
293     }
294     for (const std::string &QualifiedName : QualifiedNames) {
295       if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
296         return;
297     }
298   }
299 
300   ArrayRef<unsigned> SymbolOffsets;
301   ArrayRef<std::string> QualifiedNames;
302   std::vector<std::string> &SpellingNames;
303   std::vector<std::vector<std::string>> &USRList;
304   bool Force;
305   bool &ErrorOccurred;
306 };
307 
308 std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
309   return std::make_unique<NamedDeclFindingConsumer>(
310       SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
311       ErrorOccurred);
312 }
313 
314 } // end namespace tooling
315 } // end namespace clang
316