xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- SPIRVConvergenceRegionAnalysis.h ------------------------*- C++ -*--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis determines the convergence region for each basic block of
10 // the module, and provides a tree-like structure describing the region
11 // hierarchy.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
16 #define LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
17 
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Analysis/CFG.h"
20 #include "llvm/Analysis/LoopInfo.h"
21 #include "llvm/IR/Dominators.h"
22 #include <optional>
23 #include <unordered_set>
24 
25 namespace llvm {
26 class IntrinsicInst;
27 class SPIRVSubtarget;
28 class MachineFunction;
29 class MachineModuleInfo;
30 
31 namespace SPIRV {
32 
33 // Returns the first convergence intrinsic found in |BB|, |nullopt| otherwise.
34 std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB);
35 std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB);
36 
37 // Describes a hierarchy of convergence regions.
38 // A convergence region defines a CFG for which the execution flow can diverge
39 // starting from the entry block, but should reconverge back before the end of
40 // the exit blocks.
41 class ConvergenceRegion {
42   DominatorTree &DT;
43   LoopInfo &LI;
44 
45 public:
46   // The parent region of this region, if any.
47   ConvergenceRegion *Parent = nullptr;
48   // The sub-regions contained in this region, if any.
49   SmallVector<ConvergenceRegion *> Children = {};
50   // The convergence instruction linked to this region, if any.
51   std::optional<IntrinsicInst *> ConvergenceToken = std::nullopt;
52   // The only block with a predecessor outside of this region.
53   BasicBlock *Entry = nullptr;
54   // All the blocks with an edge leaving this convergence region.
55   SmallPtrSet<BasicBlock *, 2> Exits = {};
56   // All the blocks that belongs to this region, including its subregions'.
57   SmallPtrSet<BasicBlock *, 8> Blocks = {};
58 
59   // Creates a single convergence region encapsulating the whole function |F|.
60   ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, Function &F);
61 
62   // Creates a single convergence region defined by entry and exits nodes, a
63   // list of blocks, and possibly a convergence token.
64   ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
65                     std::optional<IntrinsicInst *> ConvergenceToken,
66                     BasicBlock *Entry, SmallPtrSet<BasicBlock *, 8> &&Blocks,
67                     SmallPtrSet<BasicBlock *, 2> &&Exits);
68 
ConvergenceRegion(ConvergenceRegion && CR)69   ConvergenceRegion(ConvergenceRegion &&CR)
70       : DT(CR.DT), LI(CR.LI), Parent(std::move(CR.Parent)),
71         Children(std::move(CR.Children)),
72         ConvergenceToken(std::move(CR.ConvergenceToken)),
73         Entry(std::move(CR.Entry)), Exits(std::move(CR.Exits)),
74         Blocks(std::move(CR.Blocks)) {}
75 
76   ConvergenceRegion(const ConvergenceRegion &other) = delete;
77 
78   // Returns true if the given basic block belongs to this region, or to one of
79   // its subregion.
contains(const BasicBlock * BB)80   bool contains(const BasicBlock *BB) const { return Blocks.count(BB) != 0; }
81 
82   void releaseMemory();
83 
84   // Write to the debug output this region's hierarchy.
85   // |IndentSize| defines the number of tabs to print before any new line.
86   void dump(const unsigned IndentSize = 0) const;
87 };
88 
89 // Holds a ConvergenceRegion hierarchy.
90 class ConvergenceRegionInfo {
91   // The convergence region this structure holds.
92   ConvergenceRegion *TopLevelRegion;
93 
94 public:
ConvergenceRegionInfo()95   ConvergenceRegionInfo() : TopLevelRegion(nullptr) {}
96 
97   // Creates a new ConvergenceRegionInfo. Ownership of the TopLevelRegion is
98   // passed to this object.
ConvergenceRegionInfo(ConvergenceRegion * TopLevelRegion)99   ConvergenceRegionInfo(ConvergenceRegion *TopLevelRegion)
100       : TopLevelRegion(TopLevelRegion) {}
101 
~ConvergenceRegionInfo()102   ~ConvergenceRegionInfo() { releaseMemory(); }
103 
ConvergenceRegionInfo(ConvergenceRegionInfo && LHS)104   ConvergenceRegionInfo(ConvergenceRegionInfo &&LHS)
105       : TopLevelRegion(LHS.TopLevelRegion) {
106     if (TopLevelRegion != LHS.TopLevelRegion) {
107       releaseMemory();
108       TopLevelRegion = LHS.TopLevelRegion;
109     }
110     LHS.TopLevelRegion = nullptr;
111   }
112 
113   ConvergenceRegionInfo &operator=(ConvergenceRegionInfo &&LHS) {
114     if (TopLevelRegion != LHS.TopLevelRegion) {
115       releaseMemory();
116       TopLevelRegion = LHS.TopLevelRegion;
117     }
118     LHS.TopLevelRegion = nullptr;
119     return *this;
120   }
121 
releaseMemory()122   void releaseMemory() {
123     if (TopLevelRegion == nullptr)
124       return;
125 
126     TopLevelRegion->releaseMemory();
127     delete TopLevelRegion;
128     TopLevelRegion = nullptr;
129   }
130 
getTopLevelRegion()131   const ConvergenceRegion *getTopLevelRegion() const { return TopLevelRegion; }
getWritableTopLevelRegion()132   ConvergenceRegion *getWritableTopLevelRegion() const {
133     return TopLevelRegion;
134   }
135 };
136 
137 } // namespace SPIRV
138 
139 // Wrapper around the function above to use it with the legacy pass manager.
140 class SPIRVConvergenceRegionAnalysisWrapperPass : public FunctionPass {
141   SPIRV::ConvergenceRegionInfo CRI;
142 
143 public:
144   static char ID;
145 
146   SPIRVConvergenceRegionAnalysisWrapperPass();
147 
getAnalysisUsage(AnalysisUsage & AU)148   void getAnalysisUsage(AnalysisUsage &AU) const override {
149     AU.setPreservesAll();
150     AU.addRequired<LoopInfoWrapperPass>();
151     AU.addRequired<DominatorTreeWrapperPass>();
152   };
153 
154   bool runOnFunction(Function &F) override;
155 
getRegionInfo()156   SPIRV::ConvergenceRegionInfo &getRegionInfo() { return CRI; }
getRegionInfo()157   const SPIRV::ConvergenceRegionInfo &getRegionInfo() const { return CRI; }
158 };
159 
160 // Wrapper around the function above to use it with the new pass manager.
161 class SPIRVConvergenceRegionAnalysis
162     : public AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis> {
163   friend AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis>;
164   static AnalysisKey Key;
165 
166 public:
167   using Result = SPIRV::ConvergenceRegionInfo;
168 
169   Result run(Function &F, FunctionAnalysisManager &AM);
170 };
171 
172 namespace SPIRV {
173 ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
174                                             LoopInfo &LI);
175 } // namespace SPIRV
176 
177 } // namespace llvm
178 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
179