1 //===- ConvergenceRegionAnalysis.h -----------------------------*- C++ -*--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The analysis determines the convergence region for each basic block of 10 // the module, and provides a tree-like structure describing the region 11 // hierarchy. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "SPIRVConvergenceRegionAnalysis.h" 16 #include "llvm/Analysis/LoopInfo.h" 17 #include "llvm/IR/Dominators.h" 18 #include "llvm/IR/IntrinsicInst.h" 19 #include "llvm/InitializePasses.h" 20 #include "llvm/Transforms/Utils/LoopSimplify.h" 21 #include <optional> 22 #include <queue> 23 24 #define DEBUG_TYPE "spirv-convergence-region-analysis" 25 26 using namespace llvm; 27 28 namespace llvm { 29 void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &); 30 } // namespace llvm 31 32 INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass, 33 "convergence-region", 34 "SPIRV convergence regions analysis", true, true) 35 INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 36 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 37 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 38 INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass, 39 "convergence-region", "SPIRV convergence regions analysis", 40 true, true) 41 42 namespace llvm { 43 namespace SPIRV { 44 namespace { 45 46 template <typename BasicBlockType, typename IntrinsicInstType> 47 std::optional<IntrinsicInstType *> 48 getConvergenceTokenInternal(BasicBlockType *BB) { 49 static_assert(std::is_const_v<IntrinsicInstType> == 50 std::is_const_v<BasicBlockType>, 51 "Constness must match between input and output."); 52 static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>, 53 "Input must be a basic block."); 54 static_assert( 55 std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>, 56 "Output type must be an intrinsic instruction."); 57 58 for (auto &I : *BB) { 59 if (auto *II = dyn_cast<IntrinsicInst>(&I)) { 60 switch (II->getIntrinsicID()) { 61 case Intrinsic::experimental_convergence_entry: 62 case Intrinsic::experimental_convergence_loop: 63 return II; 64 case Intrinsic::experimental_convergence_anchor: { 65 auto Bundle = II->getOperandBundle(LLVMContext::OB_convergencectrl); 66 assert(Bundle->Inputs.size() == 1 && 67 Bundle->Inputs[0]->getType()->isTokenTy()); 68 auto TII = dyn_cast<IntrinsicInst>(Bundle->Inputs[0].get()); 69 assert(TII != nullptr); 70 return TII; 71 } 72 } 73 } 74 75 if (auto *CI = dyn_cast<CallInst>(&I)) { 76 auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl); 77 if (!OB.has_value()) 78 continue; 79 return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]); 80 } 81 } 82 83 return std::nullopt; 84 } 85 86 // Given a ConvergenceRegion tree with |Start| as its root, finds the smallest 87 // region |Entry| belongs to. If |Entry| does not belong to the region defined 88 // by |Start|, this function returns |nullptr|. 89 ConvergenceRegion *findParentRegion(ConvergenceRegion *Start, 90 BasicBlock *Entry) { 91 ConvergenceRegion *Candidate = nullptr; 92 ConvergenceRegion *NextCandidate = Start; 93 94 while (Candidate != NextCandidate && NextCandidate != nullptr) { 95 Candidate = NextCandidate; 96 NextCandidate = nullptr; 97 98 // End of the search, we can return. 99 if (Candidate->Children.size() == 0) 100 return Candidate; 101 102 for (auto *Child : Candidate->Children) { 103 if (Child->Blocks.count(Entry) != 0) { 104 NextCandidate = Child; 105 break; 106 } 107 } 108 } 109 110 return Candidate; 111 } 112 113 } // anonymous namespace 114 115 std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB) { 116 return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB); 117 } 118 119 std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB) { 120 return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB); 121 } 122 123 ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI, 124 Function &F) 125 : DT(DT), LI(LI), Parent(nullptr) { 126 Entry = &F.getEntryBlock(); 127 ConvergenceToken = getConvergenceToken(Entry); 128 for (auto &B : F) { 129 Blocks.insert(&B); 130 if (isa<ReturnInst>(B.getTerminator())) 131 Exits.insert(&B); 132 } 133 } 134 135 ConvergenceRegion::ConvergenceRegion( 136 DominatorTree &DT, LoopInfo &LI, 137 std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry, 138 SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits) 139 : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry), 140 Exits(std::move(Exits)), Blocks(std::move(Blocks)) { 141 for ([[maybe_unused]] auto *BB : this->Exits) 142 assert(this->Blocks.count(BB) != 0); 143 assert(this->Blocks.count(this->Entry) != 0); 144 } 145 146 void ConvergenceRegion::releaseMemory() { 147 // Parent memory is owned by the parent. 148 Parent = nullptr; 149 for (auto *Child : Children) { 150 Child->releaseMemory(); 151 delete Child; 152 } 153 Children.resize(0); 154 } 155 156 void ConvergenceRegion::dump(const unsigned IndentSize) const { 157 const std::string Indent(IndentSize, '\t'); 158 dbgs() << Indent << this << ": {\n"; 159 dbgs() << Indent << " Parent: " << Parent << "\n"; 160 161 if (ConvergenceToken.value_or(nullptr)) { 162 dbgs() << Indent 163 << " ConvergenceToken: " << ConvergenceToken.value()->getName() 164 << "\n"; 165 } 166 167 if (Entry->getName() != "") 168 dbgs() << Indent << " Entry: " << Entry->getName() << "\n"; 169 else 170 dbgs() << Indent << " Entry: " << Entry << "\n"; 171 172 dbgs() << Indent << " Exits: { "; 173 for (const auto &Exit : Exits) { 174 if (Exit->getName() != "") 175 dbgs() << Exit->getName() << ", "; 176 else 177 dbgs() << Exit << ", "; 178 } 179 dbgs() << " }\n"; 180 181 dbgs() << Indent << " Blocks: { "; 182 for (const auto &Block : Blocks) { 183 if (Block->getName() != "") 184 dbgs() << Block->getName() << ", "; 185 else 186 dbgs() << Block << ", "; 187 } 188 dbgs() << " }\n"; 189 190 dbgs() << Indent << " Children: {\n"; 191 for (const auto Child : Children) 192 Child->dump(IndentSize + 2); 193 dbgs() << Indent << " }\n"; 194 195 dbgs() << Indent << "}\n"; 196 } 197 198 class ConvergenceRegionAnalyzer { 199 200 public: 201 ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI) 202 : DT(DT), LI(LI), F(F) {} 203 204 private: 205 bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const { 206 assert(From != To && "From == To. This is awkward."); 207 208 // We only handle loop in the simplified form. This means: 209 // - a single back-edge, a single latch. 210 // - meaning the back-edge target can only be the loop header. 211 // - meaning the From can only be the loop latch. 212 if (!LI.isLoopHeader(To)) 213 return false; 214 215 auto *L = LI.getLoopFor(To); 216 if (L->contains(From) && L->isLoopLatch(From)) 217 return true; 218 219 return false; 220 } 221 222 std::unordered_set<BasicBlock *> 223 findPathsToMatch(LoopInfo &LI, BasicBlock *From, 224 std::function<bool(const BasicBlock *)> isMatch) const { 225 std::unordered_set<BasicBlock *> Output; 226 227 if (isMatch(From)) 228 Output.insert(From); 229 230 auto *Terminator = From->getTerminator(); 231 for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) { 232 auto *To = Terminator->getSuccessor(i); 233 if (isBackEdge(From, To)) 234 continue; 235 236 auto ChildSet = findPathsToMatch(LI, To, isMatch); 237 if (ChildSet.size() == 0) 238 continue; 239 240 Output.insert(ChildSet.begin(), ChildSet.end()); 241 Output.insert(From); 242 if (LI.isLoopHeader(From)) { 243 auto *L = LI.getLoopFor(From); 244 for (auto *BB : L->getBlocks()) { 245 Output.insert(BB); 246 } 247 } 248 } 249 250 return Output; 251 } 252 253 SmallPtrSet<BasicBlock *, 2> 254 findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) { 255 SmallPtrSet<BasicBlock *, 2> Exits; 256 257 for (auto *B : RegionBlocks) { 258 auto *Terminator = B->getTerminator(); 259 for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) { 260 auto *Child = Terminator->getSuccessor(i); 261 if (RegionBlocks.count(Child) == 0) 262 Exits.insert(B); 263 } 264 } 265 266 return Exits; 267 } 268 269 public: 270 ConvergenceRegionInfo analyze() { 271 ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F); 272 std::queue<Loop *> ToProcess; 273 for (auto *L : LI.getLoopsInPreorder()) 274 ToProcess.push(L); 275 276 while (ToProcess.size() != 0) { 277 auto *L = ToProcess.front(); 278 ToProcess.pop(); 279 assert(L->isLoopSimplifyForm()); 280 281 auto CT = getConvergenceToken(L->getHeader()); 282 SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(), 283 L->block_end()); 284 SmallVector<BasicBlock *> LoopExits; 285 L->getExitingBlocks(LoopExits); 286 if (CT.has_value()) { 287 for (auto *Exit : LoopExits) { 288 auto N = findPathsToMatch(LI, Exit, [&CT](const BasicBlock *block) { 289 auto Token = getConvergenceToken(block); 290 if (Token == std::nullopt) 291 return false; 292 return Token.value() == CT.value(); 293 }); 294 RegionBlocks.insert(N.begin(), N.end()); 295 } 296 } 297 298 auto RegionExits = findExitNodes(RegionBlocks); 299 ConvergenceRegion *Region = new ConvergenceRegion( 300 DT, LI, CT, L->getHeader(), std::move(RegionBlocks), 301 std::move(RegionExits)); 302 Region->Parent = findParentRegion(TopLevelRegion, Region->Entry); 303 assert(Region->Parent != nullptr && "This is impossible."); 304 Region->Parent->Children.push_back(Region); 305 } 306 307 return ConvergenceRegionInfo(TopLevelRegion); 308 } 309 310 private: 311 DominatorTree &DT; 312 LoopInfo &LI; 313 Function &F; 314 }; 315 316 ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT, 317 LoopInfo &LI) { 318 ConvergenceRegionAnalyzer Analyzer(F, DT, LI); 319 return Analyzer.analyze(); 320 } 321 322 } // namespace SPIRV 323 324 char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0; 325 326 SPIRVConvergenceRegionAnalysisWrapperPass:: 327 SPIRVConvergenceRegionAnalysisWrapperPass() 328 : FunctionPass(ID) {} 329 330 bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) { 331 DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 332 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 333 334 CRI = SPIRV::getConvergenceRegions(F, DT, LI); 335 // Nothing was modified. 336 return false; 337 } 338 339 SPIRVConvergenceRegionAnalysis::Result 340 SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) { 341 Result CRI; 342 auto &DT = AM.getResult<DominatorTreeAnalysis>(F); 343 auto &LI = AM.getResult<LoopAnalysis>(F); 344 CRI = SPIRV::getConvergenceRegions(F, DT, LI); 345 return CRI; 346 } 347 348 AnalysisKey SPIRVConvergenceRegionAnalysis::Key; 349 350 } // namespace llvm 351