1 //===- SPIRVConvergenceRegionAnalysis.h ------------------------*- C++ -*--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The analysis determines the convergence region for each basic block of 10 // the module, and provides a tree-like structure describing the region 11 // hierarchy. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H 16 #define LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H 17 18 #include "llvm/ADT/SmallPtrSet.h" 19 #include "llvm/Analysis/CFG.h" 20 #include "llvm/Analysis/LoopInfo.h" 21 #include "llvm/IR/Dominators.h" 22 #include <optional> 23 #include <unordered_set> 24 25 namespace llvm { 26 class IntrinsicInst; 27 class SPIRVSubtarget; 28 class MachineFunction; 29 class MachineModuleInfo; 30 31 namespace SPIRV { 32 33 // Returns the first convergence intrinsic found in |BB|, |nullopt| otherwise. 34 std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB); 35 std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB); 36 37 // Describes a hierarchy of convergence regions. 38 // A convergence region defines a CFG for which the execution flow can diverge 39 // starting from the entry block, but should reconverge back before the end of 40 // the exit blocks. 41 class ConvergenceRegion { 42 DominatorTree &DT; 43 LoopInfo &LI; 44 45 public: 46 // The parent region of this region, if any. 47 ConvergenceRegion *Parent = nullptr; 48 // The sub-regions contained in this region, if any. 49 SmallVector<ConvergenceRegion *> Children = {}; 50 // The convergence instruction linked to this region, if any. 51 std::optional<IntrinsicInst *> ConvergenceToken = std::nullopt; 52 // The only block with a predecessor outside of this region. 53 BasicBlock *Entry = nullptr; 54 // All the blocks with an edge leaving this convergence region. 55 SmallPtrSet<BasicBlock *, 2> Exits = {}; 56 // All the blocks that belongs to this region, including its subregions'. 57 SmallPtrSet<BasicBlock *, 8> Blocks = {}; 58 59 // Creates a single convergence region encapsulating the whole function |F|. 60 ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, Function &F); 61 62 // Creates a single convergence region defined by entry and exits nodes, a 63 // list of blocks, and possibly a convergence token. 64 ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, 65 std::optional<IntrinsicInst *> ConvergenceToken, 66 BasicBlock *Entry, SmallPtrSet<BasicBlock *, 8> &&Blocks, 67 SmallPtrSet<BasicBlock *, 2> &&Exits); 68 ConvergenceRegion(ConvergenceRegion && CR)69 ConvergenceRegion(ConvergenceRegion &&CR) 70 : DT(CR.DT), LI(CR.LI), Parent(std::move(CR.Parent)), 71 Children(std::move(CR.Children)), 72 ConvergenceToken(std::move(CR.ConvergenceToken)), 73 Entry(std::move(CR.Entry)), Exits(std::move(CR.Exits)), 74 Blocks(std::move(CR.Blocks)) {} 75 76 ConvergenceRegion(const ConvergenceRegion &other) = delete; 77 78 // Returns true if the given basic block belongs to this region, or to one of 79 // its subregion. contains(const BasicBlock * BB)80 bool contains(const BasicBlock *BB) const { return Blocks.count(BB) != 0; } 81 82 void releaseMemory(); 83 84 // Write to the debug output this region's hierarchy. 85 // |IndentSize| defines the number of tabs to print before any new line. 86 void dump(const unsigned IndentSize = 0) const; 87 }; 88 89 // Holds a ConvergenceRegion hierarchy. 90 class ConvergenceRegionInfo { 91 // The convergence region this structure holds. 92 ConvergenceRegion *TopLevelRegion; 93 94 public: ConvergenceRegionInfo()95 ConvergenceRegionInfo() : TopLevelRegion(nullptr) {} 96 97 // Creates a new ConvergenceRegionInfo. Ownership of the TopLevelRegion is 98 // passed to this object. ConvergenceRegionInfo(ConvergenceRegion * TopLevelRegion)99 ConvergenceRegionInfo(ConvergenceRegion *TopLevelRegion) 100 : TopLevelRegion(TopLevelRegion) {} 101 ~ConvergenceRegionInfo()102 ~ConvergenceRegionInfo() { releaseMemory(); } 103 ConvergenceRegionInfo(ConvergenceRegionInfo && LHS)104 ConvergenceRegionInfo(ConvergenceRegionInfo &&LHS) 105 : TopLevelRegion(LHS.TopLevelRegion) { 106 if (TopLevelRegion != LHS.TopLevelRegion) { 107 releaseMemory(); 108 TopLevelRegion = LHS.TopLevelRegion; 109 } 110 LHS.TopLevelRegion = nullptr; 111 } 112 113 ConvergenceRegionInfo &operator=(ConvergenceRegionInfo &&LHS) { 114 if (TopLevelRegion != LHS.TopLevelRegion) { 115 releaseMemory(); 116 TopLevelRegion = LHS.TopLevelRegion; 117 } 118 LHS.TopLevelRegion = nullptr; 119 return *this; 120 } 121 releaseMemory()122 void releaseMemory() { 123 if (TopLevelRegion == nullptr) 124 return; 125 126 TopLevelRegion->releaseMemory(); 127 delete TopLevelRegion; 128 TopLevelRegion = nullptr; 129 } 130 getTopLevelRegion()131 const ConvergenceRegion *getTopLevelRegion() const { return TopLevelRegion; } getWritableTopLevelRegion()132 ConvergenceRegion *getWritableTopLevelRegion() const { 133 return TopLevelRegion; 134 } 135 }; 136 137 } // namespace SPIRV 138 139 // Wrapper around the function above to use it with the legacy pass manager. 140 class SPIRVConvergenceRegionAnalysisWrapperPass : public FunctionPass { 141 SPIRV::ConvergenceRegionInfo CRI; 142 143 public: 144 static char ID; 145 146 SPIRVConvergenceRegionAnalysisWrapperPass(); 147 getAnalysisUsage(AnalysisUsage & AU)148 void getAnalysisUsage(AnalysisUsage &AU) const override { 149 AU.setPreservesAll(); 150 AU.addRequired<LoopInfoWrapperPass>(); 151 AU.addRequired<DominatorTreeWrapperPass>(); 152 }; 153 154 bool runOnFunction(Function &F) override; 155 getRegionInfo()156 SPIRV::ConvergenceRegionInfo &getRegionInfo() { return CRI; } getRegionInfo()157 const SPIRV::ConvergenceRegionInfo &getRegionInfo() const { return CRI; } 158 }; 159 160 // Wrapper around the function above to use it with the new pass manager. 161 class SPIRVConvergenceRegionAnalysis 162 : public AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis> { 163 friend AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis>; 164 static AnalysisKey Key; 165 166 public: 167 using Result = SPIRV::ConvergenceRegionInfo; 168 169 Result run(Function &F, FunctionAnalysisManager &AM); 170 }; 171 172 namespace SPIRV { 173 ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT, 174 LoopInfo &LI); 175 } // namespace SPIRV 176 177 } // namespace llvm 178 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H 179