1 //===- SPIRVConvergenceRegionAnalysis.h ------------------------*- C++ -*--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The analysis determines the convergence region for each basic block of 10 // the module, and provides a tree-like structure describing the region 11 // hierarchy. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H 16 #define LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H 17 18 #include "llvm/ADT/SmallPtrSet.h" 19 #include "llvm/Analysis/CFG.h" 20 #include "llvm/Analysis/LoopInfo.h" 21 #include "llvm/IR/Dominators.h" 22 #include "llvm/IR/IntrinsicInst.h" 23 #include <iostream> 24 #include <optional> 25 #include <unordered_set> 26 27 namespace llvm { 28 class SPIRVSubtarget; 29 class MachineFunction; 30 class MachineModuleInfo; 31 32 namespace SPIRV { 33 34 // Returns the first convergence intrinsic found in |BB|, |nullopt| otherwise. 35 std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB); 36 std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB); 37 38 // Describes a hierarchy of convergence regions. 39 // A convergence region defines a CFG for which the execution flow can diverge 40 // starting from the entry block, but should reconverge back before the end of 41 // the exit blocks. 42 class ConvergenceRegion { 43 DominatorTree &DT; 44 LoopInfo &LI; 45 46 public: 47 // The parent region of this region, if any. 48 ConvergenceRegion *Parent = nullptr; 49 // The sub-regions contained in this region, if any. 50 SmallVector<ConvergenceRegion *> Children = {}; 51 // The convergence instruction linked to this region, if any. 52 std::optional<IntrinsicInst *> ConvergenceToken = std::nullopt; 53 // The only block with a predecessor outside of this region. 54 BasicBlock *Entry = nullptr; 55 // All the blocks with an edge leaving this convergence region. 56 SmallPtrSet<BasicBlock *, 2> Exits = {}; 57 // All the blocks that belongs to this region, including its subregions'. 58 SmallPtrSet<BasicBlock *, 8> Blocks = {}; 59 60 // Creates a single convergence region encapsulating the whole function |F|. 61 ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, Function &F); 62 63 // Creates a single convergence region defined by entry and exits nodes, a 64 // list of blocks, and possibly a convergence token. 65 ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, 66 std::optional<IntrinsicInst *> ConvergenceToken, 67 BasicBlock *Entry, SmallPtrSet<BasicBlock *, 8> &&Blocks, 68 SmallPtrSet<BasicBlock *, 2> &&Exits); 69 ConvergenceRegion(ConvergenceRegion && CR)70 ConvergenceRegion(ConvergenceRegion &&CR) 71 : DT(CR.DT), LI(CR.LI), Parent(std::move(CR.Parent)), 72 Children(std::move(CR.Children)), 73 ConvergenceToken(std::move(CR.ConvergenceToken)), 74 Entry(std::move(CR.Entry)), Exits(std::move(CR.Exits)), 75 Blocks(std::move(CR.Blocks)) {} 76 77 ConvergenceRegion(const ConvergenceRegion &other) = delete; 78 79 // Returns true if the given basic block belongs to this region, or to one of 80 // its subregion. contains(const BasicBlock * BB)81 bool contains(const BasicBlock *BB) const { return Blocks.count(BB) != 0; } 82 83 void releaseMemory(); 84 85 // Write to the debug output this region's hierarchy. 86 // |IndentSize| defines the number of tabs to print before any new line. 87 void dump(const unsigned IndentSize = 0) const; 88 }; 89 90 // Holds a ConvergenceRegion hierarchy. 91 class ConvergenceRegionInfo { 92 // The convergence region this structure holds. 93 ConvergenceRegion *TopLevelRegion; 94 95 public: ConvergenceRegionInfo()96 ConvergenceRegionInfo() : TopLevelRegion(nullptr) {} 97 98 // Creates a new ConvergenceRegionInfo. Ownership of the TopLevelRegion is 99 // passed to this object. ConvergenceRegionInfo(ConvergenceRegion * TopLevelRegion)100 ConvergenceRegionInfo(ConvergenceRegion *TopLevelRegion) 101 : TopLevelRegion(TopLevelRegion) {} 102 ~ConvergenceRegionInfo()103 ~ConvergenceRegionInfo() { releaseMemory(); } 104 ConvergenceRegionInfo(ConvergenceRegionInfo && LHS)105 ConvergenceRegionInfo(ConvergenceRegionInfo &&LHS) 106 : TopLevelRegion(LHS.TopLevelRegion) { 107 if (TopLevelRegion != LHS.TopLevelRegion) { 108 releaseMemory(); 109 TopLevelRegion = LHS.TopLevelRegion; 110 } 111 LHS.TopLevelRegion = nullptr; 112 } 113 114 ConvergenceRegionInfo &operator=(ConvergenceRegionInfo &&LHS) { 115 if (TopLevelRegion != LHS.TopLevelRegion) { 116 releaseMemory(); 117 TopLevelRegion = LHS.TopLevelRegion; 118 } 119 LHS.TopLevelRegion = nullptr; 120 return *this; 121 } 122 releaseMemory()123 void releaseMemory() { 124 if (TopLevelRegion == nullptr) 125 return; 126 127 TopLevelRegion->releaseMemory(); 128 delete TopLevelRegion; 129 TopLevelRegion = nullptr; 130 } 131 getTopLevelRegion()132 const ConvergenceRegion *getTopLevelRegion() const { return TopLevelRegion; } 133 }; 134 135 } // namespace SPIRV 136 137 // Wrapper around the function above to use it with the legacy pass manager. 138 class SPIRVConvergenceRegionAnalysisWrapperPass : public FunctionPass { 139 SPIRV::ConvergenceRegionInfo CRI; 140 141 public: 142 static char ID; 143 144 SPIRVConvergenceRegionAnalysisWrapperPass(); 145 getAnalysisUsage(AnalysisUsage & AU)146 void getAnalysisUsage(AnalysisUsage &AU) const override { 147 AU.setPreservesAll(); 148 AU.addRequired<LoopInfoWrapperPass>(); 149 AU.addRequired<DominatorTreeWrapperPass>(); 150 }; 151 152 bool runOnFunction(Function &F) override; 153 getRegionInfo()154 SPIRV::ConvergenceRegionInfo &getRegionInfo() { return CRI; } getRegionInfo()155 const SPIRV::ConvergenceRegionInfo &getRegionInfo() const { return CRI; } 156 }; 157 158 // Wrapper around the function above to use it with the new pass manager. 159 class SPIRVConvergenceRegionAnalysis 160 : public AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis> { 161 friend AnalysisInfoMixin<SPIRVConvergenceRegionAnalysis>; 162 static AnalysisKey Key; 163 164 public: 165 using Result = SPIRV::ConvergenceRegionInfo; 166 167 Result run(Function &F, FunctionAnalysisManager &AM); 168 }; 169 170 namespace SPIRV { 171 ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT, 172 LoopInfo &LI); 173 } // namespace SPIRV 174 175 } // namespace llvm 176 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVCONVERGENCEREGIONANALYSIS_H 177