xref: /freebsd/contrib/llvm-project/clang/lib/ASTMatchers/ASTMatchFinder.cpp (revision 8ddb146abcdf061be9f2c0db7e391697dafad85c)
1 //===--- ASTMatchFinder.cpp - Structural query framework ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  Implements an algorithm to efficiently search for matches on AST nodes.
10 //  Uses memoization to support recursive matches like HasDescendant.
11 //
12 //  The general idea is to visit all AST nodes with a RecursiveASTVisitor,
13 //  calling the Matches(...) method of each matcher we are running on each
14 //  AST node. The matcher can recurse via the ASTMatchFinder interface.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "clang/ASTMatchers/ASTMatchFinder.h"
19 #include "clang/AST/ASTConsumer.h"
20 #include "clang/AST/ASTContext.h"
21 #include "clang/AST/RecursiveASTVisitor.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/Support/Timer.h"
25 #include <deque>
26 #include <memory>
27 #include <set>
28 
29 namespace clang {
30 namespace ast_matchers {
31 namespace internal {
32 namespace {
33 
34 typedef MatchFinder::MatchCallback MatchCallback;
35 
// Upper bound on the number of memoization entries kept in the result
// cache; once the cache has grown past this size it is cleared.
//
// 10k has been experimentally found to give a good trade-off
// of performance vs. memory consumption by running matchers
// that match on every statement over a very large codebase.
//
// FIXME: Do some performance optimization in general and
// revisit this number; also, put up micro-benchmarks that we can
// optimize this on.
static constexpr unsigned MaxMemoizationEntries = 10000;
45 
// Kinds of recursive match whose results are memoized. The kind is part of
// MatchKey, so results of different kinds for the same matcher and node are
// cached separately. Enumerator order takes part in MatchKey's ordering.
enum class MatchType : int {
  Ancestors,   // Match against the ancestors of a node.
  Descendants, // Match against all descendants of a node.
  Child,       // Match against the direct children of a node.
};
52 
53 // We use memoization to avoid running the same matcher on the same
54 // AST node twice.  This struct is the key for looking up match
55 // result.  It consists of an ID of the MatcherInterface (for
56 // identifying the matcher), a pointer to the AST node and the
57 // bound nodes before the matcher was executed.
58 //
59 // We currently only memoize on nodes whose pointers identify the
60 // nodes (\c Stmt and \c Decl, but not \c QualType or \c TypeLoc).
61 // For \c QualType and \c TypeLoc it is possible to implement
62 // generation of keys for each type.
63 // FIXME: Benchmark whether memoization of non-pointer typed nodes
64 // provides enough benefit for the additional amount of code.
65 struct MatchKey {
66   DynTypedMatcher::MatcherIDType MatcherID;
67   DynTypedNode Node;
68   BoundNodesTreeBuilder BoundNodes;
69   TraversalKind Traversal = TK_AsIs;
70   MatchType Type;
71 
72   bool operator<(const MatchKey &Other) const {
73     return std::tie(Traversal, Type, MatcherID, Node, BoundNodes) <
74            std::tie(Other.Traversal, Other.Type, Other.MatcherID, Other.Node,
75                     Other.BoundNodes);
76   }
77 };
78 
79 // Used to store the result of a match and possibly bound nodes.
80 struct MemoizedMatchResult {
81   bool ResultOfMatch;
82   BoundNodesTreeBuilder Nodes;
83 };
84 
85 // A RecursiveASTVisitor that traverses all children or all descendants of
86 // a node.
87 class MatchChildASTVisitor
88     : public RecursiveASTVisitor<MatchChildASTVisitor> {
89 public:
90   typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase;
91 
92   // Creates an AST visitor that matches 'matcher' on all children or
93   // descendants of a traversed node. max_depth is the maximum depth
94   // to traverse: use 1 for matching the children and INT_MAX for
95   // matching the descendants.
96   MatchChildASTVisitor(const DynTypedMatcher *Matcher, ASTMatchFinder *Finder,
97                        BoundNodesTreeBuilder *Builder, int MaxDepth,
98                        bool IgnoreImplicitChildren,
99                        ASTMatchFinder::BindKind Bind)
100       : Matcher(Matcher), Finder(Finder), Builder(Builder), CurrentDepth(0),
101         MaxDepth(MaxDepth), IgnoreImplicitChildren(IgnoreImplicitChildren),
102         Bind(Bind), Matches(false) {}
103 
104   // Returns true if a match is found in the subtree rooted at the
105   // given AST node. This is done via a set of mutually recursive
106   // functions. Here's how the recursion is done (the  *wildcard can
107   // actually be Decl, Stmt, or Type):
108   //
109   //   - Traverse(node) calls BaseTraverse(node) when it needs
110   //     to visit the descendants of node.
111   //   - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node))
112   //     Traverse*(c) for each child c of 'node'.
113   //   - Traverse*(c) in turn calls Traverse(c), completing the
114   //     recursion.
115   bool findMatch(const DynTypedNode &DynNode) {
116     reset();
117     if (const Decl *D = DynNode.get<Decl>())
118       traverse(*D);
119     else if (const Stmt *S = DynNode.get<Stmt>())
120       traverse(*S);
121     else if (const NestedNameSpecifier *NNS =
122              DynNode.get<NestedNameSpecifier>())
123       traverse(*NNS);
124     else if (const NestedNameSpecifierLoc *NNSLoc =
125              DynNode.get<NestedNameSpecifierLoc>())
126       traverse(*NNSLoc);
127     else if (const QualType *Q = DynNode.get<QualType>())
128       traverse(*Q);
129     else if (const TypeLoc *T = DynNode.get<TypeLoc>())
130       traverse(*T);
131     else if (const auto *C = DynNode.get<CXXCtorInitializer>())
132       traverse(*C);
133     else if (const TemplateArgumentLoc *TALoc =
134                  DynNode.get<TemplateArgumentLoc>())
135       traverse(*TALoc);
136     else if (const Attr *A = DynNode.get<Attr>())
137       traverse(*A);
138     // FIXME: Add other base types after adding tests.
139 
140     // It's OK to always overwrite the bound nodes, as if there was
141     // no match in this recursive branch, the result set is empty
142     // anyway.
143     *Builder = ResultBindings;
144 
145     return Matches;
146   }
147 
148   // The following are overriding methods from the base visitor class.
149   // They are public only to allow CRTP to work. They are *not *part
150   // of the public API of this class.
151   bool TraverseDecl(Decl *DeclNode) {
152 
153     if (DeclNode && DeclNode->isImplicit() &&
154         Finder->isTraversalIgnoringImplicitNodes())
155       return baseTraverse(*DeclNode);
156 
157     ScopedIncrement ScopedDepth(&CurrentDepth);
158     return (DeclNode == nullptr) || traverse(*DeclNode);
159   }
160 
161   Stmt *getStmtToTraverse(Stmt *StmtNode) {
162     Stmt *StmtToTraverse = StmtNode;
163     if (auto *ExprNode = dyn_cast_or_null<Expr>(StmtNode)) {
164       auto *LambdaNode = dyn_cast_or_null<LambdaExpr>(StmtNode);
165       if (LambdaNode && Finder->isTraversalIgnoringImplicitNodes())
166         StmtToTraverse = LambdaNode;
167       else
168         StmtToTraverse =
169             Finder->getASTContext().getParentMapContext().traverseIgnored(
170                 ExprNode);
171     }
172     return StmtToTraverse;
173   }
174 
175   bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr) {
176     // If we need to keep track of the depth, we can't perform data recursion.
177     if (CurrentDepth == 0 || (CurrentDepth <= MaxDepth && MaxDepth < INT_MAX))
178       Queue = nullptr;
179 
180     ScopedIncrement ScopedDepth(&CurrentDepth);
181     Stmt *StmtToTraverse = getStmtToTraverse(StmtNode);
182     if (!StmtToTraverse)
183       return true;
184 
185     if (IgnoreImplicitChildren && isa<CXXDefaultArgExpr>(StmtNode))
186       return true;
187 
188     if (!match(*StmtToTraverse))
189       return false;
190     return VisitorBase::TraverseStmt(StmtToTraverse, Queue);
191   }
192   // We assume that the QualType and the contained type are on the same
193   // hierarchy level. Thus, we try to match either of them.
194   bool TraverseType(QualType TypeNode) {
195     if (TypeNode.isNull())
196       return true;
197     ScopedIncrement ScopedDepth(&CurrentDepth);
198     // Match the Type.
199     if (!match(*TypeNode))
200       return false;
201     // The QualType is matched inside traverse.
202     return traverse(TypeNode);
203   }
204   // We assume that the TypeLoc, contained QualType and contained Type all are
205   // on the same hierarchy level. Thus, we try to match all of them.
206   bool TraverseTypeLoc(TypeLoc TypeLocNode) {
207     if (TypeLocNode.isNull())
208       return true;
209     ScopedIncrement ScopedDepth(&CurrentDepth);
210     // Match the Type.
211     if (!match(*TypeLocNode.getType()))
212       return false;
213     // Match the QualType.
214     if (!match(TypeLocNode.getType()))
215       return false;
216     // The TypeLoc is matched inside traverse.
217     return traverse(TypeLocNode);
218   }
219   bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
220     ScopedIncrement ScopedDepth(&CurrentDepth);
221     return (NNS == nullptr) || traverse(*NNS);
222   }
223   bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) {
224     if (!NNS)
225       return true;
226     ScopedIncrement ScopedDepth(&CurrentDepth);
227     if (!match(*NNS.getNestedNameSpecifier()))
228       return false;
229     return traverse(NNS);
230   }
231   bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit) {
232     if (!CtorInit)
233       return true;
234     ScopedIncrement ScopedDepth(&CurrentDepth);
235     return traverse(*CtorInit);
236   }
237   bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL) {
238     ScopedIncrement ScopedDepth(&CurrentDepth);
239     return traverse(TAL);
240   }
241   bool TraverseCXXForRangeStmt(CXXForRangeStmt *Node) {
242     if (!Finder->isTraversalIgnoringImplicitNodes())
243       return VisitorBase::TraverseCXXForRangeStmt(Node);
244     if (!Node)
245       return true;
246     ScopedIncrement ScopedDepth(&CurrentDepth);
247     if (auto *Init = Node->getInit())
248       if (!traverse(*Init))
249         return false;
250     if (!match(*Node->getLoopVariable()))
251       return false;
252     if (match(*Node->getRangeInit()))
253       if (!VisitorBase::TraverseStmt(Node->getRangeInit()))
254         return false;
255     if (!match(*Node->getBody()))
256       return false;
257     return VisitorBase::TraverseStmt(Node->getBody());
258   }
259   bool TraverseCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *Node) {
260     if (!Finder->isTraversalIgnoringImplicitNodes())
261       return VisitorBase::TraverseCXXRewrittenBinaryOperator(Node);
262     if (!Node)
263       return true;
264     ScopedIncrement ScopedDepth(&CurrentDepth);
265 
266     return match(*Node->getLHS()) && match(*Node->getRHS());
267   }
268   bool TraverseAttr(Attr *A) {
269     if (A == nullptr ||
270         (A->isImplicit() &&
271          Finder->getASTContext().getParentMapContext().getTraversalKind() ==
272              TK_IgnoreUnlessSpelledInSource))
273       return true;
274     ScopedIncrement ScopedDepth(&CurrentDepth);
275     return traverse(*A);
276   }
277   bool TraverseLambdaExpr(LambdaExpr *Node) {
278     if (!Finder->isTraversalIgnoringImplicitNodes())
279       return VisitorBase::TraverseLambdaExpr(Node);
280     if (!Node)
281       return true;
282     ScopedIncrement ScopedDepth(&CurrentDepth);
283 
284     for (unsigned I = 0, N = Node->capture_size(); I != N; ++I) {
285       const auto *C = Node->capture_begin() + I;
286       if (!C->isExplicit())
287         continue;
288       if (Node->isInitCapture(C) && !match(*C->getCapturedVar()))
289         return false;
290       if (!match(*Node->capture_init_begin()[I]))
291         return false;
292     }
293 
294     if (const auto *TPL = Node->getTemplateParameterList()) {
295       for (const auto *TP : *TPL) {
296         if (!match(*TP))
297           return false;
298       }
299     }
300 
301     for (const auto *P : Node->getCallOperator()->parameters()) {
302       if (!match(*P))
303         return false;
304     }
305 
306     if (!match(*Node->getBody()))
307       return false;
308 
309     return VisitorBase::TraverseStmt(Node->getBody());
310   }
311 
312   bool shouldVisitTemplateInstantiations() const { return true; }
313   bool shouldVisitImplicitCode() const { return !IgnoreImplicitChildren; }
314 
315 private:
316   // Used for updating the depth during traversal.
317   struct ScopedIncrement {
318     explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); }
319     ~ScopedIncrement() { --(*Depth); }
320 
321    private:
322     int *Depth;
323   };
324 
325   // Resets the state of this object.
326   void reset() {
327     Matches = false;
328     CurrentDepth = 0;
329   }
330 
331   // Forwards the call to the corresponding Traverse*() method in the
332   // base visitor class.
333   bool baseTraverse(const Decl &DeclNode) {
334     return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode));
335   }
336   bool baseTraverse(const Stmt &StmtNode) {
337     return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode));
338   }
339   bool baseTraverse(QualType TypeNode) {
340     return VisitorBase::TraverseType(TypeNode);
341   }
342   bool baseTraverse(TypeLoc TypeLocNode) {
343     return VisitorBase::TraverseTypeLoc(TypeLocNode);
344   }
345   bool baseTraverse(const NestedNameSpecifier &NNS) {
346     return VisitorBase::TraverseNestedNameSpecifier(
347         const_cast<NestedNameSpecifier*>(&NNS));
348   }
349   bool baseTraverse(NestedNameSpecifierLoc NNS) {
350     return VisitorBase::TraverseNestedNameSpecifierLoc(NNS);
351   }
352   bool baseTraverse(const CXXCtorInitializer &CtorInit) {
353     return VisitorBase::TraverseConstructorInitializer(
354         const_cast<CXXCtorInitializer *>(&CtorInit));
355   }
356   bool baseTraverse(TemplateArgumentLoc TAL) {
357     return VisitorBase::TraverseTemplateArgumentLoc(TAL);
358   }
359   bool baseTraverse(const Attr &AttrNode) {
360     return VisitorBase::TraverseAttr(const_cast<Attr *>(&AttrNode));
361   }
362 
363   // Sets 'Matched' to true if 'Matcher' matches 'Node' and:
364   //   0 < CurrentDepth <= MaxDepth.
365   //
366   // Returns 'true' if traversal should continue after this function
367   // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
368   template <typename T>
369   bool match(const T &Node) {
370     if (CurrentDepth == 0 || CurrentDepth > MaxDepth) {
371       return true;
372     }
373     if (Bind != ASTMatchFinder::BK_All) {
374       BoundNodesTreeBuilder RecursiveBuilder(*Builder);
375       if (Matcher->matches(DynTypedNode::create(Node), Finder,
376                            &RecursiveBuilder)) {
377         Matches = true;
378         ResultBindings.addMatch(RecursiveBuilder);
379         return false; // Abort as soon as a match is found.
380       }
381     } else {
382       BoundNodesTreeBuilder RecursiveBuilder(*Builder);
383       if (Matcher->matches(DynTypedNode::create(Node), Finder,
384                            &RecursiveBuilder)) {
385         // After the first match the matcher succeeds.
386         Matches = true;
387         ResultBindings.addMatch(RecursiveBuilder);
388       }
389     }
390     return true;
391   }
392 
393   // Traverses the subtree rooted at 'Node'; returns true if the
394   // traversal should continue after this function returns.
395   template <typename T>
396   bool traverse(const T &Node) {
397     static_assert(IsBaseType<T>::value,
398                   "traverse can only be instantiated with base type");
399     if (!match(Node))
400       return false;
401     return baseTraverse(Node);
402   }
403 
404   const DynTypedMatcher *const Matcher;
405   ASTMatchFinder *const Finder;
406   BoundNodesTreeBuilder *const Builder;
407   BoundNodesTreeBuilder ResultBindings;
408   int CurrentDepth;
409   const int MaxDepth;
410   const bool IgnoreImplicitChildren;
411   const ASTMatchFinder::BindKind Bind;
412   bool Matches;
413 };
414 
415 // Controls the outermost traversal of the AST and allows to match multiple
416 // matchers.
417 class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
418                         public ASTMatchFinder {
419 public:
420   MatchASTVisitor(const MatchFinder::MatchersByType *Matchers,
421                   const MatchFinder::MatchFinderOptions &Options)
422       : Matchers(Matchers), Options(Options), ActiveASTContext(nullptr) {}
423 
424   ~MatchASTVisitor() override {
425     if (Options.CheckProfiling) {
426       Options.CheckProfiling->Records = std::move(TimeByBucket);
427     }
428   }
429 
430   void onStartOfTranslationUnit() {
431     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
432     TimeBucketRegion Timer;
433     for (MatchCallback *MC : Matchers->AllCallbacks) {
434       if (EnableCheckProfiling)
435         Timer.setBucket(&TimeByBucket[MC->getID()]);
436       MC->onStartOfTranslationUnit();
437     }
438   }
439 
440   void onEndOfTranslationUnit() {
441     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
442     TimeBucketRegion Timer;
443     for (MatchCallback *MC : Matchers->AllCallbacks) {
444       if (EnableCheckProfiling)
445         Timer.setBucket(&TimeByBucket[MC->getID()]);
446       MC->onEndOfTranslationUnit();
447     }
448   }
449 
450   void set_active_ast_context(ASTContext *NewActiveASTContext) {
451     ActiveASTContext = NewActiveASTContext;
452   }
453 
454   // The following Visit*() and Traverse*() functions "override"
455   // methods in RecursiveASTVisitor.
456 
457   bool VisitTypedefNameDecl(TypedefNameDecl *DeclNode) {
458     // When we see 'typedef A B', we add name 'B' to the set of names
459     // A's canonical type maps to.  This is necessary for implementing
460     // isDerivedFrom(x) properly, where x can be the name of the base
461     // class or any of its aliases.
462     //
463     // In general, the is-alias-of (as defined by typedefs) relation
464     // is tree-shaped, as you can typedef a type more than once.  For
465     // example,
466     //
467     //   typedef A B;
468     //   typedef A C;
469     //   typedef C D;
470     //   typedef C E;
471     //
472     // gives you
473     //
474     //   A
475     //   |- B
476     //   `- C
477     //      |- D
478     //      `- E
479     //
480     // It is wrong to assume that the relation is a chain.  A correct
481     // implementation of isDerivedFrom() needs to recognize that B and
482     // E are aliases, even though neither is a typedef of the other.
483     // Therefore, we cannot simply walk through one typedef chain to
484     // find out whether the type name matches.
485     const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr();
486     const Type *CanonicalType =  // root of the typedef tree
487         ActiveASTContext->getCanonicalType(TypeNode);
488     TypeAliases[CanonicalType].insert(DeclNode);
489     return true;
490   }
491 
492   bool VisitObjCCompatibleAliasDecl(ObjCCompatibleAliasDecl *CAD) {
493     const ObjCInterfaceDecl *InterfaceDecl = CAD->getClassInterface();
494     CompatibleAliases[InterfaceDecl].insert(CAD);
495     return true;
496   }
497 
  // Overrides of the RecursiveASTVisitor traversal entry points. They are
  // only declared here; their definitions live outside the class body.
  bool TraverseDecl(Decl *DeclNode);
  bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr);
  bool TraverseType(QualType TypeNode);
  bool TraverseTypeLoc(TypeLoc TypeNode);
  bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS);
  bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS);
  bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit);
  bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL);
  bool TraverseAttr(Attr *AttrNode);
507 
508   bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) {
509     if (auto *RF = dyn_cast<CXXForRangeStmt>(S)) {
510       {
511         ASTNodeNotAsIsSourceScope RAII(this, true);
512         TraverseStmt(RF->getInit());
513         // Don't traverse under the loop variable
514         match(*RF->getLoopVariable());
515         TraverseStmt(RF->getRangeInit());
516       }
517       {
518         ASTNodeNotSpelledInSourceScope RAII(this, true);
519         for (auto *SubStmt : RF->children()) {
520           if (SubStmt != RF->getBody())
521             TraverseStmt(SubStmt);
522         }
523       }
524       TraverseStmt(RF->getBody());
525       return true;
526     } else if (auto *RBO = dyn_cast<CXXRewrittenBinaryOperator>(S)) {
527       {
528         ASTNodeNotAsIsSourceScope RAII(this, true);
529         TraverseStmt(const_cast<Expr *>(RBO->getLHS()));
530         TraverseStmt(const_cast<Expr *>(RBO->getRHS()));
531       }
532       {
533         ASTNodeNotSpelledInSourceScope RAII(this, true);
534         for (auto *SubStmt : RBO->children()) {
535           TraverseStmt(SubStmt);
536         }
537       }
538       return true;
539     } else if (auto *LE = dyn_cast<LambdaExpr>(S)) {
540       for (auto I : llvm::zip(LE->captures(), LE->capture_inits())) {
541         auto C = std::get<0>(I);
542         ASTNodeNotSpelledInSourceScope RAII(
543             this, TraversingASTNodeNotSpelledInSource || !C.isExplicit());
544         TraverseLambdaCapture(LE, &C, std::get<1>(I));
545       }
546 
547       {
548         ASTNodeNotSpelledInSourceScope RAII(this, true);
549         TraverseDecl(LE->getLambdaClass());
550       }
551       {
552         ASTNodeNotAsIsSourceScope RAII(this, true);
553 
554         // We need to poke around to find the bits that might be explicitly
555         // written.
556         TypeLoc TL = LE->getCallOperator()->getTypeSourceInfo()->getTypeLoc();
557         FunctionProtoTypeLoc Proto = TL.getAsAdjusted<FunctionProtoTypeLoc>();
558 
559         if (auto *TPL = LE->getTemplateParameterList()) {
560           for (NamedDecl *D : *TPL) {
561             TraverseDecl(D);
562           }
563           if (Expr *RequiresClause = TPL->getRequiresClause()) {
564             TraverseStmt(RequiresClause);
565           }
566         }
567 
568         if (LE->hasExplicitParameters()) {
569           // Visit parameters.
570           for (ParmVarDecl *Param : Proto.getParams())
571             TraverseDecl(Param);
572         }
573 
574         const auto *T = Proto.getTypePtr();
575         for (const auto &E : T->exceptions())
576           TraverseType(E);
577 
578         if (Expr *NE = T->getNoexceptExpr())
579           TraverseStmt(NE, Queue);
580 
581         if (LE->hasExplicitResultType())
582           TraverseTypeLoc(Proto.getReturnLoc());
583         TraverseStmt(LE->getTrailingRequiresClause());
584       }
585 
586       TraverseStmt(LE->getBody());
587       return true;
588     }
589     return RecursiveASTVisitor<MatchASTVisitor>::dataTraverseNode(S, Queue);
590   }
591 
592   // Matches children or descendants of 'Node' with 'BaseMatcher'.
593   bool memoizedMatchesRecursively(const DynTypedNode &Node, ASTContext &Ctx,
594                                   const DynTypedMatcher &Matcher,
595                                   BoundNodesTreeBuilder *Builder, int MaxDepth,
596                                   BindKind Bind) {
597     // For AST-nodes that don't have an identity, we can't memoize.
598     if (!Node.getMemoizationData() || !Builder->isComparable())
599       return matchesRecursively(Node, Matcher, Builder, MaxDepth, Bind);
600 
601     MatchKey Key;
602     Key.MatcherID = Matcher.getID();
603     Key.Node = Node;
604     // Note that we key on the bindings *before* the match.
605     Key.BoundNodes = *Builder;
606     Key.Traversal = Ctx.getParentMapContext().getTraversalKind();
607     // Memoize result even doing a single-level match, it might be expensive.
608     Key.Type = MaxDepth == 1 ? MatchType::Child : MatchType::Descendants;
609     MemoizationMap::iterator I = ResultCache.find(Key);
610     if (I != ResultCache.end()) {
611       *Builder = I->second.Nodes;
612       return I->second.ResultOfMatch;
613     }
614 
615     MemoizedMatchResult Result;
616     Result.Nodes = *Builder;
617     Result.ResultOfMatch =
618         matchesRecursively(Node, Matcher, &Result.Nodes, MaxDepth, Bind);
619 
620     MemoizedMatchResult &CachedResult = ResultCache[Key];
621     CachedResult = std::move(Result);
622 
623     *Builder = CachedResult.Nodes;
624     return CachedResult.ResultOfMatch;
625   }
626 
627   // Matches children or descendants of 'Node' with 'BaseMatcher'.
628   bool matchesRecursively(const DynTypedNode &Node,
629                           const DynTypedMatcher &Matcher,
630                           BoundNodesTreeBuilder *Builder, int MaxDepth,
631                           BindKind Bind) {
632     bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
633                            TraversingASTChildrenNotSpelledInSource;
634 
635     bool IgnoreImplicitChildren = false;
636 
637     if (isTraversalIgnoringImplicitNodes()) {
638       IgnoreImplicitChildren = true;
639     }
640 
641     ASTNodeNotSpelledInSourceScope RAII(this, ScopedTraversal);
642 
643     MatchChildASTVisitor Visitor(&Matcher, this, Builder, MaxDepth,
644                                  IgnoreImplicitChildren, Bind);
645     return Visitor.findMatch(Node);
646   }
647 
  // Implements ASTMatchFinder::classIsDerivedFrom (defined outside the class
  // body). NOTE(review): 'Directly' presumably limits the check to direct
  // base classes — confirm against the definition.
  bool classIsDerivedFrom(const CXXRecordDecl *Declaration,
                          const Matcher<NamedDecl> &Base,
                          BoundNodesTreeBuilder *Builder,
                          bool Directly) override;

  // Implements ASTMatchFinder::objcClassIsDerivedFrom for Objective-C
  // interfaces (defined outside the class body). NOTE(review): 'Directly'
  // presumably limits the check to the direct superclass — confirm.
  bool objcClassIsDerivedFrom(const ObjCInterfaceDecl *Declaration,
                              const Matcher<NamedDecl> &Base,
                              BoundNodesTreeBuilder *Builder,
                              bool Directly) override;
657 
658   // Implements ASTMatchFinder::matchesChildOf.
659   bool matchesChildOf(const DynTypedNode &Node, ASTContext &Ctx,
660                       const DynTypedMatcher &Matcher,
661                       BoundNodesTreeBuilder *Builder, BindKind Bind) override {
662     if (ResultCache.size() > MaxMemoizationEntries)
663       ResultCache.clear();
664     return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, 1, Bind);
665   }
666   // Implements ASTMatchFinder::matchesDescendantOf.
667   bool matchesDescendantOf(const DynTypedNode &Node, ASTContext &Ctx,
668                            const DynTypedMatcher &Matcher,
669                            BoundNodesTreeBuilder *Builder,
670                            BindKind Bind) override {
671     if (ResultCache.size() > MaxMemoizationEntries)
672       ResultCache.clear();
673     return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, INT_MAX,
674                                       Bind);
675   }
676   // Implements ASTMatchFinder::matchesAncestorOf.
677   bool matchesAncestorOf(const DynTypedNode &Node, ASTContext &Ctx,
678                          const DynTypedMatcher &Matcher,
679                          BoundNodesTreeBuilder *Builder,
680                          AncestorMatchMode MatchMode) override {
681     // Reset the cache outside of the recursive call to make sure we
682     // don't invalidate any iterators.
683     if (ResultCache.size() > MaxMemoizationEntries)
684       ResultCache.clear();
685     if (MatchMode == AncestorMatchMode::AMM_ParentOnly)
686       return matchesParentOf(Node, Matcher, Builder);
687     return matchesAnyAncestorOf(Node, Ctx, Matcher, Builder);
688   }
689 
690   // Matches all registered matchers on the given node and calls the
691   // result callback for every node that matches.
692   void match(const DynTypedNode &Node) {
693     // FIXME: Improve this with a switch or a visitor pattern.
694     if (auto *N = Node.get<Decl>()) {
695       match(*N);
696     } else if (auto *N = Node.get<Stmt>()) {
697       match(*N);
698     } else if (auto *N = Node.get<Type>()) {
699       match(*N);
700     } else if (auto *N = Node.get<QualType>()) {
701       match(*N);
702     } else if (auto *N = Node.get<NestedNameSpecifier>()) {
703       match(*N);
704     } else if (auto *N = Node.get<NestedNameSpecifierLoc>()) {
705       match(*N);
706     } else if (auto *N = Node.get<TypeLoc>()) {
707       match(*N);
708     } else if (auto *N = Node.get<CXXCtorInitializer>()) {
709       match(*N);
710     } else if (auto *N = Node.get<TemplateArgumentLoc>()) {
711       match(*N);
712     } else if (auto *N = Node.get<Attr>()) {
713       match(*N);
714     }
715   }
716 
717   template <typename T> void match(const T &Node) {
718     matchDispatch(&Node);
719   }
720 
721   // Implements ASTMatchFinder::getASTContext.
722   ASTContext &getASTContext() const override { return *ActiveASTContext; }
723 
724   bool shouldVisitTemplateInstantiations() const { return true; }
725   bool shouldVisitImplicitCode() const { return true; }
726 
727   // We visit the lambda body explicitly, so instruct the RAV
728   // to not visit it on our behalf too.
729   bool shouldVisitLambdaBody() const { return false; }
730 
731   bool IsMatchingInASTNodeNotSpelledInSource() const override {
732     return TraversingASTNodeNotSpelledInSource;
733   }
734   bool isMatchingChildrenNotSpelledInSource() const override {
735     return TraversingASTChildrenNotSpelledInSource;
736   }
737   void setMatchingChildrenNotSpelledInSource(bool Set) override {
738     TraversingASTChildrenNotSpelledInSource = Set;
739   }
740 
741   bool IsMatchingInASTNodeNotAsIs() const override {
742     return TraversingASTNodeNotAsIs;
743   }
744 
745   bool TraverseTemplateInstantiations(ClassTemplateDecl *D) {
746     ASTNodeNotSpelledInSourceScope RAII(this, true);
747     return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
748         D);
749   }
750 
751   bool TraverseTemplateInstantiations(VarTemplateDecl *D) {
752     ASTNodeNotSpelledInSourceScope RAII(this, true);
753     return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
754         D);
755   }
756 
757   bool TraverseTemplateInstantiations(FunctionTemplateDecl *D) {
758     ASTNodeNotSpelledInSourceScope RAII(this, true);
759     return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
760         D);
761   }
762 
763 private:
764   bool TraversingASTNodeNotSpelledInSource = false;
765   bool TraversingASTNodeNotAsIs = false;
766   bool TraversingASTChildrenNotSpelledInSource = false;
767 
768   struct ASTNodeNotSpelledInSourceScope {
769     ASTNodeNotSpelledInSourceScope(MatchASTVisitor *V, bool B)
770         : MV(V), MB(V->TraversingASTNodeNotSpelledInSource) {
771       V->TraversingASTNodeNotSpelledInSource = B;
772     }
773     ~ASTNodeNotSpelledInSourceScope() {
774       MV->TraversingASTNodeNotSpelledInSource = MB;
775     }
776 
777   private:
778     MatchASTVisitor *MV;
779     bool MB;
780   };
781 
782   struct ASTNodeNotAsIsSourceScope {
783     ASTNodeNotAsIsSourceScope(MatchASTVisitor *V, bool B)
784         : MV(V), MB(V->TraversingASTNodeNotAsIs) {
785       V->TraversingASTNodeNotAsIs = B;
786     }
787     ~ASTNodeNotAsIsSourceScope() { MV->TraversingASTNodeNotAsIs = MB; }
788 
789   private:
790     MatchASTVisitor *MV;
791     bool MB;
792   };
793 
794   class TimeBucketRegion {
795   public:
796     TimeBucketRegion() : Bucket(nullptr) {}
797     ~TimeBucketRegion() { setBucket(nullptr); }
798 
799     /// Start timing for \p NewBucket.
800     ///
801     /// If there was a bucket already set, it will finish the timing for that
802     /// other bucket.
803     /// \p NewBucket will be timed until the next call to \c setBucket() or
804     /// until the \c TimeBucketRegion is destroyed.
805     /// If \p NewBucket is the same as the currently timed bucket, this call
806     /// does nothing.
807     void setBucket(llvm::TimeRecord *NewBucket) {
808       if (Bucket != NewBucket) {
809         auto Now = llvm::TimeRecord::getCurrentTime(true);
810         if (Bucket)
811           *Bucket += Now;
812         if (NewBucket)
813           *NewBucket -= Now;
814         Bucket = NewBucket;
815       }
816     }
817 
818   private:
819     llvm::TimeRecord *Bucket;
820   };
821 
822   /// Runs all the \p Matchers on \p Node.
823   ///
824   /// Used by \c matchDispatch() below.
825   template <typename T, typename MC>
826   void matchWithoutFilter(const T &Node, const MC &Matchers) {
827     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
828     TimeBucketRegion Timer;
829     for (const auto &MP : Matchers) {
830       if (EnableCheckProfiling)
831         Timer.setBucket(&TimeByBucket[MP.second->getID()]);
832       BoundNodesTreeBuilder Builder;
833       if (MP.first.matches(Node, this, &Builder)) {
834         MatchVisitor Visitor(ActiveASTContext, MP.second);
835         Builder.visitMatches(&Visitor);
836       }
837     }
838   }
839 
840   void matchWithFilter(const DynTypedNode &DynNode) {
841     auto Kind = DynNode.getNodeKind();
842     auto it = MatcherFiltersMap.find(Kind);
843     const auto &Filter =
844         it != MatcherFiltersMap.end() ? it->second : getFilterForKind(Kind);
845 
846     if (Filter.empty())
847       return;
848 
849     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
850     TimeBucketRegion Timer;
851     auto &Matchers = this->Matchers->DeclOrStmt;
852     for (unsigned short I : Filter) {
853       auto &MP = Matchers[I];
854       if (EnableCheckProfiling)
855         Timer.setBucket(&TimeByBucket[MP.second->getID()]);
856       BoundNodesTreeBuilder Builder;
857 
858       {
859         TraversalKindScope RAII(getASTContext(), MP.first.getTraversalKind());
860         if (getASTContext().getParentMapContext().traverseIgnored(DynNode) !=
861             DynNode)
862           continue;
863       }
864 
865       if (MP.first.matches(DynNode, this, &Builder)) {
866         MatchVisitor Visitor(ActiveASTContext, MP.second);
867         Builder.visitMatches(&Visitor);
868       }
869     }
870   }
871 
872   const std::vector<unsigned short> &getFilterForKind(ASTNodeKind Kind) {
873     auto &Filter = MatcherFiltersMap[Kind];
874     auto &Matchers = this->Matchers->DeclOrStmt;
875     assert((Matchers.size() < USHRT_MAX) && "Too many matchers.");
876     for (unsigned I = 0, E = Matchers.size(); I != E; ++I) {
877       if (Matchers[I].first.canMatchNodesOfKind(Kind)) {
878         Filter.push_back(I);
879       }
880     }
881     return Filter;
882   }
883 
884   /// @{
885   /// Overloads to pair the different node types to their matchers.
886   void matchDispatch(const Decl *Node) {
887     return matchWithFilter(DynTypedNode::create(*Node));
888   }
889   void matchDispatch(const Stmt *Node) {
890     return matchWithFilter(DynTypedNode::create(*Node));
891   }
892 
893   void matchDispatch(const Type *Node) {
894     matchWithoutFilter(QualType(Node, 0), Matchers->Type);
895   }
896   void matchDispatch(const TypeLoc *Node) {
897     matchWithoutFilter(*Node, Matchers->TypeLoc);
898   }
899   void matchDispatch(const QualType *Node) {
900     matchWithoutFilter(*Node, Matchers->Type);
901   }
902   void matchDispatch(const NestedNameSpecifier *Node) {
903     matchWithoutFilter(*Node, Matchers->NestedNameSpecifier);
904   }
905   void matchDispatch(const NestedNameSpecifierLoc *Node) {
906     matchWithoutFilter(*Node, Matchers->NestedNameSpecifierLoc);
907   }
908   void matchDispatch(const CXXCtorInitializer *Node) {
909     matchWithoutFilter(*Node, Matchers->CtorInit);
910   }
911   void matchDispatch(const TemplateArgumentLoc *Node) {
912     matchWithoutFilter(*Node, Matchers->TemplateArgumentLoc);
913   }
914   void matchDispatch(const Attr *Node) {
915     matchWithoutFilter(*Node, Matchers->Attr);
916   }
917   void matchDispatch(const void *) { /* Do nothing. */ }
918   /// @}
919 
920   // Returns whether a direct parent of \p Node matches \p Matcher.
921   // Unlike matchesAnyAncestorOf there's no memoization: it doesn't save much.
922   bool matchesParentOf(const DynTypedNode &Node, const DynTypedMatcher &Matcher,
923                        BoundNodesTreeBuilder *Builder) {
924     for (const auto &Parent : ActiveASTContext->getParents(Node)) {
925       BoundNodesTreeBuilder BuilderCopy = *Builder;
926       if (Matcher.matches(Parent, this, &BuilderCopy)) {
927         *Builder = std::move(BuilderCopy);
928         return true;
929       }
930     }
931     return false;
932   }
933 
  // Returns whether an ancestor of \p Node matches \p Matcher.
  //
  // The order of matching (which can lead to different nodes being bound in
  // case there are multiple matches) is breadth first search.
  //
  // To allow memoization in the very common case of having deeply nested
  // expressions inside a template function, we first walk up the AST, memoizing
  // the result of the match along the way, as long as there is only a single
  // parent.
  //
  // Once there are multiple parents, the breadth first search order does not
  // allow simple memoization on the ancestors. Thus, we only memoize as long
  // as there is a single parent.
  //
  // We avoid a recursive implementation to prevent excessive stack use on
  // very deep ASTs (similarly to RecursiveASTVisitor's data recursion).
  //
  // Every candidate ancestor is tried on a copy of *Builder, and the copy is
  // committed to *Builder only when that candidate matches. On a cache hit,
  // *Builder is overwritten with the memoized bindings.
  //
  // \p Ctx is only consulted for the current traversal kind, which becomes
  // part of each cache key; parents are looked up through ActiveASTContext.
  //
  // NOTE(review): ResultCache entries are added here without any size check
  // against MaxMemoizationEntries; confirm the caller bounds the cache.
  bool matchesAnyAncestorOf(DynTypedNode Node, ASTContext &Ctx,
                            const DynTypedMatcher &Matcher,
                            BoundNodesTreeBuilder *Builder) {

    // Memoization keys that can be updated with the result.
    // These are the memoizable nodes in the chain of unique parents, which
    // terminates when a node has multiple parents, or matches, or is the root.
    std::vector<MatchKey> Keys;
    // When returning, update the memoization cache.
    // Every key collected so far records the final result and *Builder.
    auto Finish = [&](bool Matched) {
      for (const auto &Key : Keys) {
        MemoizedMatchResult &CachedResult = ResultCache[Key];
        CachedResult.ResultOfMatch = Matched;
        CachedResult.Nodes = *Builder;
      }
      return Matched;
    };

    // Loop while there's a single parent and we want to attempt memoization.
    DynTypedNodeList Parents{ArrayRef<DynTypedNode>()}; // after loop: size != 1
    for (;;) {
      // A cache key only makes sense if memoization is possible.
      if (Builder->isComparable()) {
        Keys.emplace_back();
        Keys.back().MatcherID = Matcher.getID();
        Keys.back().Node = Node;
        Keys.back().BoundNodes = *Builder;
        Keys.back().Traversal = Ctx.getParentMapContext().getTraversalKind();
        Keys.back().Type = MatchType::Ancestors;

        // Check the cache.
        MemoizationMap::iterator I = ResultCache.find(Keys.back());
        if (I != ResultCache.end()) {
          Keys.pop_back(); // Don't populate the cache for the matching node!
          *Builder = I->second.Nodes;
          return Finish(I->second.ResultOfMatch);
        }
      }

      Parents = ActiveASTContext->getParents(Node);
      // Either no parents or multiple parents: leave chain+memoize mode and
      // enter bfs+forgetful mode.
      if (Parents.size() != 1)
        break;

      // Check the next parent.
      Node = *Parents.begin();
      BoundNodesTreeBuilder BuilderCopy = *Builder;
      if (Matcher.matches(Node, this, &BuilderCopy)) {
        *Builder = std::move(BuilderCopy);
        return Finish(true);
      }
    }
    // We reached the end of the chain.

    if (Parents.empty()) {
      // Nodes may have no parents if:
      //  a) the node is the TranslationUnitDecl
      //  b) we have a limited traversal scope that excludes the parent edges
      //  c) there is a bug in the AST, and the node is not reachable
      // Usually the traversal scope is the whole AST, which precludes b.
      // Bugs are common enough that it's worthwhile asserting when we can.
#ifndef NDEBUG
      if (!Node.get<TranslationUnitDecl>() &&
          /* Traversal scope is full AST if any of the bounds are the TU */
          llvm::any_of(ActiveASTContext->getTraversalScope(), [](Decl *D) {
            return D->getKind() == Decl::TranslationUnit;
          })) {
        llvm::errs() << "Tried to match orphan node:\n";
        Node.dump(llvm::errs(), *ActiveASTContext);
        llvm_unreachable("Parent map should be complete!");
      }
#endif
    } else {
      assert(Parents.size() > 1);
      // BFS starting from the parents not yet considered.
      // Memoization of newly visited nodes is not possible (but we still update
      // results for the elements in the chain we found above).
      // NOTE(review): the initial Parents are not inserted into Visited, so a
      // node that is both an initial parent and an ancestor of another initial
      // parent can be queued twice. Also, if getMemoizationData() can return
      // the same value (e.g. null) for distinct parents, all but the first of
      // them are dropped. Confirm both are acceptable.
      std::deque<DynTypedNode> Queue(Parents.begin(), Parents.end());
      llvm::DenseSet<const void *> Visited;
      while (!Queue.empty()) {
        BoundNodesTreeBuilder BuilderCopy = *Builder;
        if (Matcher.matches(Queue.front(), this, &BuilderCopy)) {
          *Builder = std::move(BuilderCopy);
          return Finish(true);
        }
        for (const auto &Parent : ActiveASTContext->getParents(Queue.front())) {
          // Make sure we do not visit the same node twice.
          // Otherwise, we'll visit the common ancestors as often as there
          // are splits on the way down.
          if (Visited.insert(Parent.getMemoizationData()).second)
            Queue.push_back(Parent);
        }
        Queue.pop_front();
      }
    }
    return Finish(false);
  }
1048 
1049   // Implements a BoundNodesTree::Visitor that calls a MatchCallback with
1050   // the aggregated bound nodes for each match.
1051   class MatchVisitor : public BoundNodesTreeBuilder::Visitor {
1052   public:
1053     MatchVisitor(ASTContext* Context,
1054                  MatchFinder::MatchCallback* Callback)
1055       : Context(Context),
1056         Callback(Callback) {}
1057 
1058     void visitMatch(const BoundNodes& BoundNodesView) override {
1059       TraversalKindScope RAII(*Context, Callback->getCheckTraversalKind());
1060       Callback->run(MatchFinder::MatchResult(BoundNodesView, Context));
1061     }
1062 
1063   private:
1064     ASTContext* Context;
1065     MatchFinder::MatchCallback* Callback;
1066   };
1067 
1068   // Returns true if 'TypeNode' has an alias that matches the given matcher.
1069   bool typeHasMatchingAlias(const Type *TypeNode,
1070                             const Matcher<NamedDecl> &Matcher,
1071                             BoundNodesTreeBuilder *Builder) {
1072     const Type *const CanonicalType =
1073       ActiveASTContext->getCanonicalType(TypeNode);
1074     auto Aliases = TypeAliases.find(CanonicalType);
1075     if (Aliases == TypeAliases.end())
1076       return false;
1077     for (const TypedefNameDecl *Alias : Aliases->second) {
1078       BoundNodesTreeBuilder Result(*Builder);
1079       if (Matcher.matches(*Alias, this, &Result)) {
1080         *Builder = std::move(Result);
1081         return true;
1082       }
1083     }
1084     return false;
1085   }
1086 
1087   bool
1088   objcClassHasMatchingCompatibilityAlias(const ObjCInterfaceDecl *InterfaceDecl,
1089                                          const Matcher<NamedDecl> &Matcher,
1090                                          BoundNodesTreeBuilder *Builder) {
1091     auto Aliases = CompatibleAliases.find(InterfaceDecl);
1092     if (Aliases == CompatibleAliases.end())
1093       return false;
1094     for (const ObjCCompatibleAliasDecl *Alias : Aliases->second) {
1095       BoundNodesTreeBuilder Result(*Builder);
1096       if (Matcher.matches(*Alias, this, &Result)) {
1097         *Builder = std::move(Result);
1098         return true;
1099       }
1100     }
1101     return false;
1102   }
1103 
  /// Bucket to record map.
  ///
  /// Used to get the appropriate bucket for each matcher.
  /// Keyed by the callback's getID(); entries are only created when check
  /// profiling is enabled.
  llvm::StringMap<llvm::TimeRecord> TimeByBucket;

  /// Registered matchers, grouped by the node type they apply to (DeclOrStmt,
  /// Type, TypeLoc, ...). Not owned here: MatchFinder passes a pointer to its
  /// own lists when it constructs the visitor (presumably stored by the
  /// constructor — confirm).
  const MatchFinder::MatchersByType *Matchers;

  /// Filtered list of matcher indices for each matcher kind.
  ///
  /// \c Decl and \c Stmt toplevel matchers usually apply to a specific node
  /// kind (and derived kinds) so it is a waste to try every matcher on every
  /// node.
  /// We precalculate a list of matchers that pass the toplevel restrict check.
  /// The list for a kind is built lazily on first use (see getFilterForKind).
  llvm::DenseMap<ASTNodeKind, std::vector<unsigned short>> MatcherFiltersMap;

  /// Finder options (e.g. check profiling). Held by reference, so the options
  /// object must outlive this visitor.
  const MatchFinder::MatchFinderOptions &Options;
  /// The AST currently being matched; used for parent lookups, canonical
  /// types and as the context handed to match callbacks.
  ASTContext *ActiveASTContext;

  // Maps a canonical type to its TypedefDecls.
  llvm::DenseMap<const Type*, std::set<const TypedefNameDecl*> > TypeAliases;

  // Maps an Objective-C interface to its ObjCCompatibleAliasDecls.
  llvm::DenseMap<const ObjCInterfaceDecl *,
                 llvm::SmallPtrSet<const ObjCCompatibleAliasDecl *, 2>>
      CompatibleAliases;

  // Maps (matcher, node) -> the match result for memoization.
  // Keys also include the bound nodes, traversal kind and match type.
  typedef std::map<MatchKey, MemoizedMatchResult> MemoizationMap;
  MemoizationMap ResultCache;
1133 };
1134 
1135 static CXXRecordDecl *
1136 getAsCXXRecordDeclOrPrimaryTemplate(const Type *TypeNode) {
1137   if (auto *RD = TypeNode->getAsCXXRecordDecl())
1138     return RD;
1139 
1140   // Find the innermost TemplateSpecializationType that isn't an alias template.
1141   auto *TemplateType = TypeNode->getAs<TemplateSpecializationType>();
1142   while (TemplateType && TemplateType->isTypeAlias())
1143     TemplateType =
1144         TemplateType->getAliasedType()->getAs<TemplateSpecializationType>();
1145 
1146   // If this is the name of a (dependent) template specialization, use the
1147   // definition of the template, even though it might be specialized later.
1148   if (TemplateType)
1149     if (auto *ClassTemplate = dyn_cast_or_null<ClassTemplateDecl>(
1150           TemplateType->getTemplateName().getAsTemplateDecl()))
1151       return ClassTemplate->getTemplatedDecl();
1152 
1153   return nullptr;
1154 }
1155 
1156 // Returns true if the given C++ class is directly or indirectly derived
1157 // from a base type with the given name.  A class is not considered to be
1158 // derived from itself.
1159 bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration,
1160                                          const Matcher<NamedDecl> &Base,
1161                                          BoundNodesTreeBuilder *Builder,
1162                                          bool Directly) {
1163   if (!Declaration->hasDefinition())
1164     return false;
1165   for (const auto &It : Declaration->bases()) {
1166     const Type *TypeNode = It.getType().getTypePtr();
1167 
1168     if (typeHasMatchingAlias(TypeNode, Base, Builder))
1169       return true;
1170 
1171     // FIXME: Going to the primary template here isn't really correct, but
1172     // unfortunately we accept a Decl matcher for the base class not a Type
1173     // matcher, so it's the best thing we can do with our current interface.
1174     CXXRecordDecl *ClassDecl = getAsCXXRecordDeclOrPrimaryTemplate(TypeNode);
1175     if (!ClassDecl)
1176       continue;
1177     if (ClassDecl == Declaration) {
1178       // This can happen for recursive template definitions.
1179       continue;
1180     }
1181     BoundNodesTreeBuilder Result(*Builder);
1182     if (Base.matches(*ClassDecl, this, &Result)) {
1183       *Builder = std::move(Result);
1184       return true;
1185     }
1186     if (!Directly && classIsDerivedFrom(ClassDecl, Base, Builder, Directly))
1187       return true;
1188   }
1189   return false;
1190 }
1191 
1192 // Returns true if the given Objective-C class is directly or indirectly
1193 // derived from a matching base class. A class is not considered to be derived
1194 // from itself.
1195 bool MatchASTVisitor::objcClassIsDerivedFrom(
1196     const ObjCInterfaceDecl *Declaration, const Matcher<NamedDecl> &Base,
1197     BoundNodesTreeBuilder *Builder, bool Directly) {
1198   // Check if any of the superclasses of the class match.
1199   for (const ObjCInterfaceDecl *ClassDecl = Declaration->getSuperClass();
1200        ClassDecl != nullptr; ClassDecl = ClassDecl->getSuperClass()) {
1201     // Check if there are any matching compatibility aliases.
1202     if (objcClassHasMatchingCompatibilityAlias(ClassDecl, Base, Builder))
1203       return true;
1204 
1205     // Check if there are any matching type aliases.
1206     const Type *TypeNode = ClassDecl->getTypeForDecl();
1207     if (typeHasMatchingAlias(TypeNode, Base, Builder))
1208       return true;
1209 
1210     if (Base.matches(*ClassDecl, this, Builder))
1211       return true;
1212 
1213     // Not `return false` as a temporary workaround for PR43879.
1214     if (Directly)
1215       break;
1216   }
1217 
1218   return false;
1219 }
1220 
1221 bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) {
1222   if (!DeclNode) {
1223     return true;
1224   }
1225 
1226   bool ScopedTraversal =
1227       TraversingASTNodeNotSpelledInSource || DeclNode->isImplicit();
1228   bool ScopedChildren = TraversingASTChildrenNotSpelledInSource;
1229 
1230   if (const auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(DeclNode)) {
1231     auto SK = CTSD->getSpecializationKind();
1232     if (SK == TSK_ExplicitInstantiationDeclaration ||
1233         SK == TSK_ExplicitInstantiationDefinition)
1234       ScopedChildren = true;
1235   } else if (const auto *FD = dyn_cast<FunctionDecl>(DeclNode)) {
1236     if (FD->isDefaulted())
1237       ScopedChildren = true;
1238     if (FD->isTemplateInstantiation())
1239       ScopedTraversal = true;
1240   } else if (isa<BindingDecl>(DeclNode)) {
1241     ScopedChildren = true;
1242   }
1243 
1244   ASTNodeNotSpelledInSourceScope RAII1(this, ScopedTraversal);
1245   ASTChildrenNotSpelledInSourceScope RAII2(this, ScopedChildren);
1246 
1247   match(*DeclNode);
1248   return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode);
1249 }
1250 
1251 bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue) {
1252   if (!StmtNode) {
1253     return true;
1254   }
1255   bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
1256                          TraversingASTChildrenNotSpelledInSource;
1257 
1258   ASTNodeNotSpelledInSourceScope RAII(this, ScopedTraversal);
1259   match(*StmtNode);
1260   return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode, Queue);
1261 }
1262 
1263 bool MatchASTVisitor::TraverseType(QualType TypeNode) {
1264   match(TypeNode);
1265   return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode);
1266 }
1267 
1268 bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLocNode) {
1269   // The RecursiveASTVisitor only visits types if they're not within TypeLocs.
1270   // We still want to find those types via matchers, so we match them here. Note
1271   // that the TypeLocs are structurally a shadow-hierarchy to the expressed
1272   // type, so we visit all involved parts of a compound type when matching on
1273   // each TypeLoc.
1274   match(TypeLocNode);
1275   match(TypeLocNode.getType());
1276   return RecursiveASTVisitor<MatchASTVisitor>::TraverseTypeLoc(TypeLocNode);
1277 }
1278 
1279 bool MatchASTVisitor::TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
1280   match(*NNS);
1281   return RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifier(NNS);
1282 }
1283 
1284 bool MatchASTVisitor::TraverseNestedNameSpecifierLoc(
1285     NestedNameSpecifierLoc NNS) {
1286   if (!NNS)
1287     return true;
1288 
1289   match(NNS);
1290 
1291   // We only match the nested name specifier here (as opposed to traversing it)
1292   // because the traversal is already done in the parallel "Loc"-hierarchy.
1293   if (NNS.hasQualifier())
1294     match(*NNS.getNestedNameSpecifier());
1295   return
1296       RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifierLoc(NNS);
1297 }
1298 
1299 bool MatchASTVisitor::TraverseConstructorInitializer(
1300     CXXCtorInitializer *CtorInit) {
1301   if (!CtorInit)
1302     return true;
1303 
1304   bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
1305                          TraversingASTChildrenNotSpelledInSource;
1306 
1307   if (!CtorInit->isWritten())
1308     ScopedTraversal = true;
1309 
1310   ASTNodeNotSpelledInSourceScope RAII1(this, ScopedTraversal);
1311 
1312   match(*CtorInit);
1313 
1314   return RecursiveASTVisitor<MatchASTVisitor>::TraverseConstructorInitializer(
1315       CtorInit);
1316 }
1317 
1318 bool MatchASTVisitor::TraverseTemplateArgumentLoc(TemplateArgumentLoc Loc) {
1319   match(Loc);
1320   return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateArgumentLoc(Loc);
1321 }
1322 
1323 bool MatchASTVisitor::TraverseAttr(Attr *AttrNode) {
1324   match(*AttrNode);
1325   return RecursiveASTVisitor<MatchASTVisitor>::TraverseAttr(AttrNode);
1326 }
1327 
1328 class MatchASTConsumer : public ASTConsumer {
1329 public:
1330   MatchASTConsumer(MatchFinder *Finder,
1331                    MatchFinder::ParsingDoneTestCallback *ParsingDone)
1332       : Finder(Finder), ParsingDone(ParsingDone) {}
1333 
1334 private:
1335   void HandleTranslationUnit(ASTContext &Context) override {
1336     if (ParsingDone != nullptr) {
1337       ParsingDone->run();
1338     }
1339     Finder->matchAST(Context);
1340   }
1341 
1342   MatchFinder *Finder;
1343   MatchFinder::ParsingDoneTestCallback *ParsingDone;
1344 };
1345 
1346 } // end namespace
1347 } // end namespace internal
1348 
1349 MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes,
1350                                       ASTContext *Context)
1351   : Nodes(Nodes), Context(Context),
1352     SourceManager(&Context->getSourceManager()) {}
1353 
1354 MatchFinder::MatchCallback::~MatchCallback() {}
1355 MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {}
1356 
1357 MatchFinder::MatchFinder(MatchFinderOptions Options)
1358     : Options(std::move(Options)), ParsingDone(nullptr) {}
1359 
1360 MatchFinder::~MatchFinder() {}
1361 
1362 void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch,
1363                              MatchCallback *Action) {
1364   llvm::Optional<TraversalKind> TK;
1365   if (Action)
1366     TK = Action->getCheckTraversalKind();
1367   if (TK)
1368     Matchers.DeclOrStmt.emplace_back(traverse(*TK, NodeMatch), Action);
1369   else
1370     Matchers.DeclOrStmt.emplace_back(NodeMatch, Action);
1371   Matchers.AllCallbacks.insert(Action);
1372 }
1373 
1374 void MatchFinder::addMatcher(const TypeMatcher &NodeMatch,
1375                              MatchCallback *Action) {
1376   Matchers.Type.emplace_back(NodeMatch, Action);
1377   Matchers.AllCallbacks.insert(Action);
1378 }
1379 
1380 void MatchFinder::addMatcher(const StatementMatcher &NodeMatch,
1381                              MatchCallback *Action) {
1382   llvm::Optional<TraversalKind> TK;
1383   if (Action)
1384     TK = Action->getCheckTraversalKind();
1385   if (TK)
1386     Matchers.DeclOrStmt.emplace_back(traverse(*TK, NodeMatch), Action);
1387   else
1388     Matchers.DeclOrStmt.emplace_back(NodeMatch, Action);
1389   Matchers.AllCallbacks.insert(Action);
1390 }
1391 
1392 void MatchFinder::addMatcher(const NestedNameSpecifierMatcher &NodeMatch,
1393                              MatchCallback *Action) {
1394   Matchers.NestedNameSpecifier.emplace_back(NodeMatch, Action);
1395   Matchers.AllCallbacks.insert(Action);
1396 }
1397 
1398 void MatchFinder::addMatcher(const NestedNameSpecifierLocMatcher &NodeMatch,
1399                              MatchCallback *Action) {
1400   Matchers.NestedNameSpecifierLoc.emplace_back(NodeMatch, Action);
1401   Matchers.AllCallbacks.insert(Action);
1402 }
1403 
1404 void MatchFinder::addMatcher(const TypeLocMatcher &NodeMatch,
1405                              MatchCallback *Action) {
1406   Matchers.TypeLoc.emplace_back(NodeMatch, Action);
1407   Matchers.AllCallbacks.insert(Action);
1408 }
1409 
1410 void MatchFinder::addMatcher(const CXXCtorInitializerMatcher &NodeMatch,
1411                              MatchCallback *Action) {
1412   Matchers.CtorInit.emplace_back(NodeMatch, Action);
1413   Matchers.AllCallbacks.insert(Action);
1414 }
1415 
1416 void MatchFinder::addMatcher(const TemplateArgumentLocMatcher &NodeMatch,
1417                              MatchCallback *Action) {
1418   Matchers.TemplateArgumentLoc.emplace_back(NodeMatch, Action);
1419   Matchers.AllCallbacks.insert(Action);
1420 }
1421 
1422 void MatchFinder::addMatcher(const AttrMatcher &AttrMatch,
1423                              MatchCallback *Action) {
1424   Matchers.Attr.emplace_back(AttrMatch, Action);
1425   Matchers.AllCallbacks.insert(Action);
1426 }
1427 
1428 bool MatchFinder::addDynamicMatcher(const internal::DynTypedMatcher &NodeMatch,
1429                                     MatchCallback *Action) {
1430   if (NodeMatch.canConvertTo<Decl>()) {
1431     addMatcher(NodeMatch.convertTo<Decl>(), Action);
1432     return true;
1433   } else if (NodeMatch.canConvertTo<QualType>()) {
1434     addMatcher(NodeMatch.convertTo<QualType>(), Action);
1435     return true;
1436   } else if (NodeMatch.canConvertTo<Stmt>()) {
1437     addMatcher(NodeMatch.convertTo<Stmt>(), Action);
1438     return true;
1439   } else if (NodeMatch.canConvertTo<NestedNameSpecifier>()) {
1440     addMatcher(NodeMatch.convertTo<NestedNameSpecifier>(), Action);
1441     return true;
1442   } else if (NodeMatch.canConvertTo<NestedNameSpecifierLoc>()) {
1443     addMatcher(NodeMatch.convertTo<NestedNameSpecifierLoc>(), Action);
1444     return true;
1445   } else if (NodeMatch.canConvertTo<TypeLoc>()) {
1446     addMatcher(NodeMatch.convertTo<TypeLoc>(), Action);
1447     return true;
1448   } else if (NodeMatch.canConvertTo<CXXCtorInitializer>()) {
1449     addMatcher(NodeMatch.convertTo<CXXCtorInitializer>(), Action);
1450     return true;
1451   } else if (NodeMatch.canConvertTo<TemplateArgumentLoc>()) {
1452     addMatcher(NodeMatch.convertTo<TemplateArgumentLoc>(), Action);
1453     return true;
1454   } else if (NodeMatch.canConvertTo<Attr>()) {
1455     addMatcher(NodeMatch.convertTo<Attr>(), Action);
1456     return true;
1457   }
1458   return false;
1459 }
1460 
1461 std::unique_ptr<ASTConsumer> MatchFinder::newASTConsumer() {
1462   return std::make_unique<internal::MatchASTConsumer>(this, ParsingDone);
1463 }
1464 
1465 void MatchFinder::match(const clang::DynTypedNode &Node, ASTContext &Context) {
1466   internal::MatchASTVisitor Visitor(&Matchers, Options);
1467   Visitor.set_active_ast_context(&Context);
1468   Visitor.match(Node);
1469 }
1470 
1471 void MatchFinder::matchAST(ASTContext &Context) {
1472   internal::MatchASTVisitor Visitor(&Matchers, Options);
1473   Visitor.set_active_ast_context(&Context);
1474   Visitor.onStartOfTranslationUnit();
1475   Visitor.TraverseAST(Context);
1476   Visitor.onEndOfTranslationUnit();
1477 }
1478 
1479 void MatchFinder::registerTestCallbackAfterParsing(
1480     MatchFinder::ParsingDoneTestCallback *NewParsingDone) {
1481   ParsingDone = NewParsingDone;
1482 }
1483 
1484 StringRef MatchFinder::MatchCallback::getID() const { return "<unknown>"; }
1485 
1486 llvm::Optional<TraversalKind>
1487 MatchFinder::MatchCallback::getCheckTraversalKind() const {
1488   return llvm::None;
1489 }
1490 
1491 } // end namespace ast_matchers
1492 } // end namespace clang
1493