xref: /freebsd/contrib/llvm-project/clang/lib/AST/ParentMapContext.cpp (revision e64bea71c21eb42e97aa615188ba91f6cce0d36d)
1 //===- ParentMapContext.cpp - Map of parents using DynTypedNode -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Similar to ParentMap.cpp, but generalizes to non-Stmt nodes, which can have
10 // multiple parents.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "clang/AST/ParentMapContext.h"
15 #include "clang/AST/Decl.h"
16 #include "clang/AST/Expr.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/TemplateBase.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 
21 using namespace clang;
22 
23 ParentMapContext::ParentMapContext(ASTContext &Ctx) : ASTCtx(Ctx) {}
24 
25 ParentMapContext::~ParentMapContext() = default;
26 
27 void ParentMapContext::clear() { Parents.reset(); }
28 
29 const Expr *ParentMapContext::traverseIgnored(const Expr *E) const {
30   return traverseIgnored(const_cast<Expr *>(E));
31 }
32 
33 Expr *ParentMapContext::traverseIgnored(Expr *E) const {
34   if (!E)
35     return nullptr;
36 
37   switch (Traversal) {
38   case TK_AsIs:
39     return E;
40   case TK_IgnoreUnlessSpelledInSource:
41     return E->IgnoreUnlessSpelledInSource();
42   }
43   llvm_unreachable("Invalid Traversal type!");
44 }
45 
46 DynTypedNode ParentMapContext::traverseIgnored(const DynTypedNode &N) const {
47   if (const auto *E = N.get<Expr>()) {
48     return DynTypedNode::create(*traverseIgnored(E));
49   }
50   return N;
51 }
52 
53 template <typename T, typename... U>
54 static std::tuple<bool, DynTypedNodeList, const T *, const U *...>
55 matchParents(const DynTypedNodeList &NodeList,
56              ParentMapContext::ParentMap *ParentMap);
57 
58 template <typename, typename...> struct MatchParents;
59 
60 class ParentMapContext::ParentMap {
61 
62   template <typename, typename...> friend struct ::MatchParents;
63 
64   /// Contains parents of a node.
65   class ParentVector {
66   public:
67     ParentVector() = default;
68     explicit ParentVector(size_t N, const DynTypedNode &Value) {
69       Items.reserve(N);
70       for (; N > 0; --N)
71         push_back(Value);
72     }
73     bool contains(const DynTypedNode &Value) const {
74       const void *Identity = Value.getMemoizationData();
75       assert(Identity);
76       return Dedup.contains(Identity);
77     }
78     void push_back(const DynTypedNode &Value) {
79       const void *Identity = Value.getMemoizationData();
80       if (!Identity || Dedup.insert(Identity).second) {
81         Items.push_back(Value);
82       }
83     }
84     ArrayRef<DynTypedNode> view() const { return Items; }
85 
86   private:
87     llvm::SmallVector<DynTypedNode, 1> Items;
88     llvm::SmallPtrSet<const void *, 2> Dedup;
89   };
90 
91   /// Maps from a node to its parents. This is used for nodes that have
92   /// pointer identity only, which are more common and we can save space by
93   /// only storing a unique pointer to them.
94   using ParentMapPointers =
95       llvm::DenseMap<const void *,
96                      llvm::PointerUnion<const Decl *, const Stmt *,
97                                         DynTypedNode *, ParentVector *>>;
98 
99   /// Parent map for nodes without pointer identity. We store a full
100   /// DynTypedNode for all keys.
101   using ParentMapOtherNodes =
102       llvm::DenseMap<DynTypedNode,
103                      llvm::PointerUnion<const Decl *, const Stmt *,
104                                         DynTypedNode *, ParentVector *>>;
105 
106   ParentMapPointers PointerParents;
107   ParentMapOtherNodes OtherParents;
108   class ASTVisitor;
109 
110   static DynTypedNode
111   getSingleDynTypedNodeFromParentMap(ParentMapPointers::mapped_type U) {
112     if (const auto *D = dyn_cast<const Decl *>(U))
113       return DynTypedNode::create(*D);
114     if (const auto *S = dyn_cast<const Stmt *>(U))
115       return DynTypedNode::create(*S);
116     return *cast<DynTypedNode *>(U);
117   }
118 
119   template <typename NodeTy, typename MapTy>
120   static DynTypedNodeList getDynNodeFromMap(const NodeTy &Node,
121                                                         const MapTy &Map) {
122     auto I = Map.find(Node);
123     if (I == Map.end()) {
124       return ArrayRef<DynTypedNode>();
125     }
126     if (const auto *V = dyn_cast<ParentVector *>(I->second)) {
127       return V->view();
128     }
129     return getSingleDynTypedNodeFromParentMap(I->second);
130   }
131 
132 public:
133   ParentMap(ASTContext &Ctx);
134   ~ParentMap() {
135     for (const auto &Entry : PointerParents) {
136       if (auto *DTN = dyn_cast<DynTypedNode *>(Entry.second)) {
137         delete DTN;
138       } else if (auto *PV = dyn_cast<ParentVector *>(Entry.second)) {
139         delete PV;
140       }
141     }
142     for (const auto &Entry : OtherParents) {
143       if (auto *DTN = dyn_cast<DynTypedNode *>(Entry.second)) {
144         delete DTN;
145       } else if (auto *PV = dyn_cast<ParentVector *>(Entry.second)) {
146         delete PV;
147       }
148     }
149   }
150 
151   DynTypedNodeList getParents(TraversalKind TK, const DynTypedNode &Node) {
152     if (Node.getNodeKind().hasPointerIdentity()) {
153       auto ParentList =
154           getDynNodeFromMap(Node.getMemoizationData(), PointerParents);
155       if (ParentList.size() > 0 && TK == TK_IgnoreUnlessSpelledInSource) {
156 
157         const auto *ChildExpr = Node.get<Expr>();
158 
159         {
160           // Don't match explicit node types because different stdlib
161           // implementations implement this in different ways and have
162           // different intermediate nodes.
163           // Look up 4 levels for a cxxRewrittenBinaryOperator as that is
164           // enough for the major stdlib implementations.
165           auto RewrittenBinOpParentsList = ParentList;
166           int I = 0;
167           while (ChildExpr && RewrittenBinOpParentsList.size() == 1 &&
168                  I++ < 4) {
169             const auto *S = RewrittenBinOpParentsList[0].get<Stmt>();
170             if (!S)
171               break;
172 
173             const auto *RWBO = dyn_cast<CXXRewrittenBinaryOperator>(S);
174             if (!RWBO) {
175               RewrittenBinOpParentsList = getDynNodeFromMap(S, PointerParents);
176               continue;
177             }
178             if (RWBO->getLHS()->IgnoreUnlessSpelledInSource() != ChildExpr &&
179                 RWBO->getRHS()->IgnoreUnlessSpelledInSource() != ChildExpr)
180               break;
181             return DynTypedNode::create(*RWBO);
182           }
183         }
184 
185         const auto *ParentExpr = ParentList[0].get<Expr>();
186         if (ParentExpr && ChildExpr)
187           return AscendIgnoreUnlessSpelledInSource(ParentExpr, ChildExpr);
188 
189         {
190           auto AncestorNodes =
191               matchParents<DeclStmt, CXXForRangeStmt>(ParentList, this);
192           if (std::get<bool>(AncestorNodes) &&
193               std::get<const CXXForRangeStmt *>(AncestorNodes)
194                       ->getLoopVarStmt() ==
195                   std::get<const DeclStmt *>(AncestorNodes))
196             return std::get<DynTypedNodeList>(AncestorNodes);
197         }
198         {
199           auto AncestorNodes = matchParents<VarDecl, DeclStmt, CXXForRangeStmt>(
200               ParentList, this);
201           if (std::get<bool>(AncestorNodes) &&
202               std::get<const CXXForRangeStmt *>(AncestorNodes)
203                       ->getRangeStmt() ==
204                   std::get<const DeclStmt *>(AncestorNodes))
205             return std::get<DynTypedNodeList>(AncestorNodes);
206         }
207         {
208           auto AncestorNodes =
209               matchParents<CXXMethodDecl, CXXRecordDecl, LambdaExpr>(ParentList,
210                                                                      this);
211           if (std::get<bool>(AncestorNodes))
212             return std::get<DynTypedNodeList>(AncestorNodes);
213         }
214         {
215           auto AncestorNodes =
216               matchParents<FunctionTemplateDecl, CXXRecordDecl, LambdaExpr>(
217                   ParentList, this);
218           if (std::get<bool>(AncestorNodes))
219             return std::get<DynTypedNodeList>(AncestorNodes);
220         }
221       }
222       return ParentList;
223     }
224     return getDynNodeFromMap(Node, OtherParents);
225   }
226 
227   DynTypedNodeList AscendIgnoreUnlessSpelledInSource(const Expr *E,
228                                                      const Expr *Child) {
229 
230     auto ShouldSkip = [](const Expr *E, const Expr *Child) {
231       if (isa<ImplicitCastExpr>(E))
232         return true;
233 
234       if (isa<FullExpr>(E))
235         return true;
236 
237       if (isa<MaterializeTemporaryExpr>(E))
238         return true;
239 
240       if (isa<CXXBindTemporaryExpr>(E))
241         return true;
242 
243       if (isa<ParenExpr>(E))
244         return true;
245 
246       if (isa<ExprWithCleanups>(E))
247         return true;
248 
249       auto SR = Child->getSourceRange();
250 
251       if (const auto *C = dyn_cast<CXXFunctionalCastExpr>(E)) {
252         if (C->getSourceRange() == SR)
253           return true;
254       }
255 
256       if (const auto *C = dyn_cast<CXXConstructExpr>(E)) {
257         if (C->getSourceRange() == SR || C->isElidable())
258           return true;
259       }
260 
261       if (const auto *C = dyn_cast<CXXMemberCallExpr>(E)) {
262         if (C->getSourceRange() == SR)
263           return true;
264       }
265 
266       if (const auto *C = dyn_cast<MemberExpr>(E)) {
267         if (C->getSourceRange() == SR)
268           return true;
269       }
270       return false;
271     };
272 
273     while (ShouldSkip(E, Child)) {
274       auto It = PointerParents.find(E);
275       if (It == PointerParents.end())
276         break;
277       const auto *S = dyn_cast<const Stmt *>(It->second);
278       if (!S) {
279         if (auto *Vec = dyn_cast<ParentVector *>(It->second))
280           return Vec->view();
281         return getSingleDynTypedNodeFromParentMap(It->second);
282       }
283       const auto *P = dyn_cast<Expr>(S);
284       if (!P)
285         return DynTypedNode::create(*S);
286       Child = E;
287       E = P;
288     }
289     return DynTypedNode::create(*E);
290   }
291 };
292 
293 template <typename T, typename... U> struct MatchParents {
294   static std::tuple<bool, DynTypedNodeList, const T *, const U *...>
295   match(const DynTypedNodeList &NodeList,
296         ParentMapContext::ParentMap *ParentMap) {
297     if (const auto *TypedNode = NodeList[0].get<T>()) {
298       auto NextParentList =
299           ParentMap->getDynNodeFromMap(TypedNode, ParentMap->PointerParents);
300       if (NextParentList.size() == 1) {
301         auto TailTuple = MatchParents<U...>::match(NextParentList, ParentMap);
302         if (std::get<bool>(TailTuple)) {
303           return std::apply(
304               [TypedNode](bool, DynTypedNodeList NodeList, auto... TupleTail) {
305                 return std::make_tuple(true, NodeList, TypedNode, TupleTail...);
306               },
307               TailTuple);
308         }
309       }
310     }
311     return std::tuple_cat(std::make_tuple(false, NodeList),
312                           std::tuple<const T *, const U *...>());
313   }
314 };
315 
316 template <typename T> struct MatchParents<T> {
317   static std::tuple<bool, DynTypedNodeList, const T *>
318   match(const DynTypedNodeList &NodeList,
319         ParentMapContext::ParentMap *ParentMap) {
320     if (const auto *TypedNode = NodeList[0].get<T>()) {
321       auto NextParentList =
322           ParentMap->getDynNodeFromMap(TypedNode, ParentMap->PointerParents);
323       if (NextParentList.size() == 1)
324         return std::make_tuple(true, NodeList, TypedNode);
325     }
326     return std::make_tuple(false, NodeList, nullptr);
327   }
328 };
329 
330 template <typename T, typename... U>
331 std::tuple<bool, DynTypedNodeList, const T *, const U *...>
332 matchParents(const DynTypedNodeList &NodeList,
333              ParentMapContext::ParentMap *ParentMap) {
334   return MatchParents<T, U...>::match(NodeList, ParentMap);
335 }
336 
337 /// Template specializations to abstract away from pointers and TypeLocs.
338 /// @{
339 template <typename T> static DynTypedNode createDynTypedNode(const T &Node) {
340   return DynTypedNode::create(*Node);
341 }
342 template <> DynTypedNode createDynTypedNode(const TypeLoc &Node) {
343   return DynTypedNode::create(Node);
344 }
345 template <>
346 DynTypedNode createDynTypedNode(const NestedNameSpecifierLoc &Node) {
347   return DynTypedNode::create(Node);
348 }
349 template <> DynTypedNode createDynTypedNode(const ObjCProtocolLoc &Node) {
350   return DynTypedNode::create(Node);
351 }
352 /// @}
353 
354 /// A \c RecursiveASTVisitor that builds a map from nodes to their
355 /// parents as defined by the \c RecursiveASTVisitor.
356 ///
357 /// Note that the relationship described here is purely in terms of AST
358 /// traversal - there are other relationships (for example declaration context)
359 /// in the AST that are better modeled by special matchers.
360 class ParentMapContext::ParentMap::ASTVisitor
361     : public RecursiveASTVisitor<ASTVisitor> {
362 public:
363   ASTVisitor(ParentMap &Map) : Map(Map) {}
364 
365 private:
366   friend class RecursiveASTVisitor<ASTVisitor>;
367 
368   using VisitorBase = RecursiveASTVisitor<ASTVisitor>;
369 
370   bool shouldVisitTemplateInstantiations() const { return true; }
371 
372   bool shouldVisitImplicitCode() const { return true; }
373 
374   /// Record the parent of the node we're visiting.
375   /// MapNode is the child, the parent is on top of ParentStack.
376   /// Parents is the parent storage (either PointerParents or OtherParents).
377   template <typename MapNodeTy, typename MapTy>
378   void addParent(MapNodeTy MapNode, MapTy *Parents) {
379     if (ParentStack.empty())
380       return;
381 
382     // FIXME: Currently we add the same parent multiple times, but only
383     // when no memoization data is available for the type.
384     // For example when we visit all subexpressions of template
385     // instantiations; this is suboptimal, but benign: the only way to
386     // visit those is with hasAncestor / hasParent, and those do not create
387     // new matches.
388     // The plan is to enable DynTypedNode to be storable in a map or hash
389     // map. The main problem there is to implement hash functions /
390     // comparison operators for all types that DynTypedNode supports that
391     // do not have pointer identity.
392     auto &NodeOrVector = (*Parents)[MapNode];
393     if (NodeOrVector.isNull()) {
394       if (const auto *D = ParentStack.back().get<Decl>())
395         NodeOrVector = D;
396       else if (const auto *S = ParentStack.back().get<Stmt>())
397         NodeOrVector = S;
398       else
399         NodeOrVector = new DynTypedNode(ParentStack.back());
400     } else {
401       if (!isa<ParentVector *>(NodeOrVector)) {
402         auto *Vector = new ParentVector(
403             1, getSingleDynTypedNodeFromParentMap(NodeOrVector));
404         delete dyn_cast<DynTypedNode *>(NodeOrVector);
405         NodeOrVector = Vector;
406       }
407 
408       auto *Vector = cast<ParentVector *>(NodeOrVector);
409       // Skip duplicates for types that have memoization data.
410       // We must check that the type has memoization data before calling
411       // llvm::is_contained() because DynTypedNode::operator== can't compare all
412       // types.
413       bool Found = ParentStack.back().getMemoizationData() &&
414                    llvm::is_contained(*Vector, ParentStack.back());
415       if (!Found)
416         Vector->push_back(ParentStack.back());
417     }
418   }
419 
420   template <typename T> static bool isNull(T Node) { return !Node; }
421   static bool isNull(ObjCProtocolLoc Node) { return false; }
422 
423   template <typename T, typename MapNodeTy, typename BaseTraverseFn,
424             typename MapTy>
425   bool TraverseNode(T Node, MapNodeTy MapNode, BaseTraverseFn BaseTraverse,
426                     MapTy *Parents) {
427     if (isNull(Node))
428       return true;
429     addParent(MapNode, Parents);
430     ParentStack.push_back(createDynTypedNode(Node));
431     bool Result = BaseTraverse();
432     ParentStack.pop_back();
433     return Result;
434   }
435 
436   bool TraverseDecl(Decl *DeclNode) {
437     return TraverseNode(
438         DeclNode, DeclNode, [&] { return VisitorBase::TraverseDecl(DeclNode); },
439         &Map.PointerParents);
440   }
441   bool TraverseTypeLoc(TypeLoc TypeLocNode) {
442     return TraverseNode(
443         TypeLocNode, DynTypedNode::create(TypeLocNode),
444         [&] { return VisitorBase::TraverseTypeLoc(TypeLocNode); },
445         &Map.OtherParents);
446   }
447   bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNSLocNode) {
448     return TraverseNode(
449         NNSLocNode, DynTypedNode::create(NNSLocNode),
450         [&] { return VisitorBase::TraverseNestedNameSpecifierLoc(NNSLocNode); },
451         &Map.OtherParents);
452   }
453   bool TraverseAttr(Attr *AttrNode) {
454     return TraverseNode(
455         AttrNode, AttrNode, [&] { return VisitorBase::TraverseAttr(AttrNode); },
456         &Map.PointerParents);
457   }
458   bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLocNode) {
459     return TraverseNode(
460         ProtocolLocNode, DynTypedNode::create(ProtocolLocNode),
461         [&] { return VisitorBase::TraverseObjCProtocolLoc(ProtocolLocNode); },
462         &Map.OtherParents);
463   }
464 
465   // Using generic TraverseNode for Stmt would prevent data-recursion.
466   bool dataTraverseStmtPre(Stmt *StmtNode) {
467     addParent(StmtNode, &Map.PointerParents);
468     ParentStack.push_back(DynTypedNode::create(*StmtNode));
469     return true;
470   }
471   bool dataTraverseStmtPost(Stmt *StmtNode) {
472     ParentStack.pop_back();
473     return true;
474   }
475 
476   ParentMap &Map;
477   llvm::SmallVector<DynTypedNode, 16> ParentStack;
478 };
479 
480 ParentMapContext::ParentMap::ParentMap(ASTContext &Ctx) {
481   ASTVisitor(*this).TraverseAST(Ctx);
482 }
483 
484 DynTypedNodeList ParentMapContext::getParents(const DynTypedNode &Node) {
485   if (!Parents)
486     // We build the parent map for the traversal scope (usually whole TU), as
487     // hasAncestor can escape any subtree.
488     Parents = std::make_unique<ParentMap>(ASTCtx);
489   return Parents->getParents(getTraversalKind(), Node);
490 }
491