xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===- SPIRVConvergenceRegionAnalysis.h ------------------------*- C++ -*--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis determines the convergence region for each basic block of
10 // the module, and provides a tree-like structure describing the region
11 // hierarchy.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
16 #define LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
17 
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Analysis/CFG.h"
20 #include "llvm/Analysis/LoopInfo.h"
21 #include "llvm/IR/Dominators.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include <iostream>
24 #include <optional>
25 #include <unordered_set>
26 
27 namespace llvm {
// Forward declarations of LLVM codegen types.
// NOTE(review): none of these names is referenced in this header; they may be
// removable once no includer relies on them.
class MachineFunction;
class MachineModuleInfo;
class SPIRVSubtarget;
31 
32 namespace SPIRV {
33 
34 // Returns the first convergence intrinsic found in |BB|, |nullopt| otherwise.
35 std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB);
36 std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB);
37 
38 // Describes a hierarchy of convergence regions.
39 // A convergence region defines a CFG for which the execution flow can diverge
40 // starting from the entry block, but should reconverge back before the end of
41 // the exit blocks.
42 class ConvergenceRegion {
43   DominatorTree &DT;
44   LoopInfo &LI;
45 
46 public:
47   // The parent region of this region, if any.
48   ConvergenceRegion *Parent = nullptr;
49   // The sub-regions contained in this region, if any.
50   SmallVector<ConvergenceRegion *> Children = {};
51   // The convergence instruction linked to this region, if any.
52   std::optional<IntrinsicInst *> ConvergenceToken = std::nullopt;
53   // The only block with a predecessor outside of this region.
54   BasicBlock *Entry = nullptr;
55   // All the blocks with an edge leaving this convergence region.
56   SmallPtrSet<BasicBlock *, 2> Exits = {};
57   // All the blocks that belongs to this region, including its subregions'.
58   SmallPtrSet<BasicBlock *, 8> Blocks = {};
59 
60   // Creates a single convergence region encapsulating the whole function |F|.
61   ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, Function &F);
62 
63   // Creates a single convergence region defined by entry and exits nodes, a
64   // list of blocks, and possibly a convergence token.
65   ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
66                     std::optional<IntrinsicInst *> ConvergenceToken,
67                     BasicBlock *Entry, SmallPtrSet<BasicBlock *, 8> &&Blocks,
68                     SmallPtrSet<BasicBlock *, 2> &&Exits);
69 
70   ConvergenceRegion(ConvergenceRegion &&CR)
71       : DT(CR.DT), LI(CR.LI), Parent(std::move(CR.Parent)),
72         Children(std::move(CR.Children)),
73         ConvergenceToken(std::move(CR.ConvergenceToken)),
74         Entry(std::move(CR.Entry)), Exits(std::move(CR.Exits)),
75         Blocks(std::move(CR.Blocks)) {}
76 
77   ConvergenceRegion(const ConvergenceRegion &other) = delete;
78 
79   // Returns true if the given basic block belongs to this region, or to one of
80   // its subregion.
81   bool contains(const BasicBlock *BB) const { return Blocks.count(BB) != 0; }
82 
83   void releaseMemory();
84 
85   // Write to the debug output this region's hierarchy.
86   // |IndentSize| defines the number of tabs to print before any new line.
87   void dump(const unsigned IndentSize = 0) const;
88 };
89 
90 // Holds a ConvergenceRegion hierarchy.
91 class ConvergenceRegionInfo {
92   // The convergence region this structure holds.
93   ConvergenceRegion *TopLevelRegion;
94 
95 public:
96   ConvergenceRegionInfo() : TopLevelRegion(nullptr) {}
97 
98   // Creates a new ConvergenceRegionInfo. Ownership of the TopLevelRegion is
99   // passed to this object.
100   ConvergenceRegionInfo(ConvergenceRegion *TopLevelRegion)
101       : TopLevelRegion(TopLevelRegion) {}
102 
103   ~ConvergenceRegionInfo() { releaseMemory(); }
104 
105   ConvergenceRegionInfo(ConvergenceRegionInfo &&LHS)
106       : TopLevelRegion(LHS.TopLevelRegion) {
107     if (TopLevelRegion != LHS.TopLevelRegion) {
108       releaseMemory();
109       TopLevelRegion = LHS.TopLevelRegion;
110     }
111     LHS.TopLevelRegion = nullptr;
112   }
113 
114   ConvergenceRegionInfo &operator=(ConvergenceRegionInfo &&LHS) {
115     if (TopLevelRegion != LHS.TopLevelRegion) {
116       releaseMemory();
117       TopLevelRegion = LHS.TopLevelRegion;
118     }
119     LHS.TopLevelRegion = nullptr;
120     return *this;
121   }
122 
123   void releaseMemory() {
124     if (TopLevelRegion == nullptr)
125       return;
126 
127     TopLevelRegion->releaseMemory();
128     delete TopLevelRegion;
129     TopLevelRegion = nullptr;
130   }
131 
132   const ConvergenceRegion *getTopLevelRegion() const { return TopLevelRegion; }
133 };
134 
135 } // namespace SPIRV
136 
137 // Wrapper around the function above to use it with the legacy pass manager.
138 class SPIRVConvergenceRegionAnalysisWrapperPass : public FunctionPass {
139   SPIRV::ConvergenceRegionInfo CRI;
140 
141 public:
142   static char ID;
143 
144   SPIRVConvergenceRegionAnalysisWrapperPass();
145 
146   void getAnalysisUsage(AnalysisUsage &AU) const override {
147     AU.setPreservesAll();
148     AU.addRequired<LoopInfoWrapperPass>();
149     AU.addRequired<DominatorTreeWrapperPass>();
150   };
151 
152   bool runOnFunction(Function &F) override;
153 
154   SPIRV::ConvergenceRegionInfo &getRegionInfo() { return CRI; }
155   const SPIRV::ConvergenceRegionInfo &getRegionInfo() const { return CRI; }
156 };
157 
158 // Wrapper around the function above to use it with the new pass manager.
159 class SPIRVConvergenceRegionAnalysis
160     : public AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis> {
161   friend AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis>;
162   static AnalysisKey Key;
163 
164 public:
165   using Result = SPIRV::ConvergenceRegionInfo;
166 
167   Result run(Function &F, FunctionAnalysisManager &AM);
168 };
169 
170 namespace SPIRV {
171 ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
172                                             LoopInfo &LI);
173 } // namespace SPIRV
174 
175 } // namespace llvm
176 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H
177