1 //===-- SPIRVStructurizer.cpp ----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 //===----------------------------------------------------------------------===// 10 11 #include "Analysis/SPIRVConvergenceRegionAnalysis.h" 12 #include "SPIRV.h" 13 #include "SPIRVStructurizerWrapper.h" 14 #include "SPIRVSubtarget.h" 15 #include "SPIRVUtils.h" 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/SmallPtrSet.h" 18 #include "llvm/Analysis/LoopInfo.h" 19 #include "llvm/CodeGen/IntrinsicLowering.h" 20 #include "llvm/IR/CFG.h" 21 #include "llvm/IR/Dominators.h" 22 #include "llvm/IR/IRBuilder.h" 23 #include "llvm/IR/IntrinsicInst.h" 24 #include "llvm/IR/Intrinsics.h" 25 #include "llvm/IR/IntrinsicsSPIRV.h" 26 #include "llvm/IR/LegacyPassManager.h" 27 #include "llvm/InitializePasses.h" 28 #include "llvm/Transforms/Utils.h" 29 #include "llvm/Transforms/Utils/Cloning.h" 30 #include "llvm/Transforms/Utils/LoopSimplify.h" 31 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" 32 #include <stack> 33 #include <unordered_set> 34 35 using namespace llvm; 36 using namespace SPIRV; 37 38 using BlockSet = std::unordered_set<BasicBlock *>; 39 using Edge = std::pair<BasicBlock *, BasicBlock *>; 40 41 // Helper function to do a partial order visit from the block |Start|, calling 42 // |Op| on each visited node. 43 static void partialOrderVisit(BasicBlock &Start, 44 std::function<bool(BasicBlock *)> Op) { 45 PartialOrderingVisitor V(*Start.getParent()); 46 V.partialOrderVisit(Start, Op); 47 } 48 49 // Returns the exact convergence region in the tree defined by `Node` for which 50 // `BB` is the header, nullptr otherwise. 
51 static const ConvergenceRegion * 52 getRegionForHeader(const ConvergenceRegion *Node, BasicBlock *BB) { 53 if (Node->Entry == BB) 54 return Node; 55 56 for (auto *Child : Node->Children) { 57 const auto *CR = getRegionForHeader(Child, BB); 58 if (CR != nullptr) 59 return CR; 60 } 61 return nullptr; 62 } 63 64 // Returns the single BasicBlock exiting the convergence region `CR`, 65 // nullptr if no such exit exists. 66 static BasicBlock *getExitFor(const ConvergenceRegion *CR) { 67 std::unordered_set<BasicBlock *> ExitTargets; 68 for (BasicBlock *Exit : CR->Exits) { 69 for (BasicBlock *Successor : successors(Exit)) { 70 if (CR->Blocks.count(Successor) == 0) 71 ExitTargets.insert(Successor); 72 } 73 } 74 75 assert(ExitTargets.size() <= 1); 76 if (ExitTargets.size() == 0) 77 return nullptr; 78 79 return *ExitTargets.begin(); 80 } 81 82 // Returns the merge block designated by I if I is a merge instruction, nullptr 83 // otherwise. 84 static BasicBlock *getDesignatedMergeBlock(Instruction *I) { 85 IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(I); 86 if (II == nullptr) 87 return nullptr; 88 89 if (II->getIntrinsicID() != Intrinsic::spv_loop_merge && 90 II->getIntrinsicID() != Intrinsic::spv_selection_merge) 91 return nullptr; 92 93 BlockAddress *BA = cast<BlockAddress>(II->getOperand(0)); 94 return BA->getBasicBlock(); 95 } 96 97 // Returns the continue block designated by I if I is an OpLoopMerge, nullptr 98 // otherwise. 99 static BasicBlock *getDesignatedContinueBlock(Instruction *I) { 100 IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(I); 101 if (II == nullptr) 102 return nullptr; 103 104 if (II->getIntrinsicID() != Intrinsic::spv_loop_merge) 105 return nullptr; 106 107 BlockAddress *BA = cast<BlockAddress>(II->getOperand(1)); 108 return BA->getBasicBlock(); 109 } 110 111 // Returns true if Header has one merge instruction which designated Merge as 112 // merge block. 
113 static bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge) { 114 for (auto &I : Header) { 115 BasicBlock *MB = getDesignatedMergeBlock(&I); 116 if (MB == &Merge) 117 return true; 118 } 119 return false; 120 } 121 122 // Returns true if the BB has one OpLoopMerge instruction. 123 static bool hasLoopMergeInstruction(BasicBlock &BB) { 124 for (auto &I : BB) 125 if (getDesignatedContinueBlock(&I)) 126 return true; 127 return false; 128 } 129 130 // Returns true is I is an OpSelectionMerge or OpLoopMerge instruction, false 131 // otherwise. 132 static bool isMergeInstruction(Instruction *I) { 133 return getDesignatedMergeBlock(I) != nullptr; 134 } 135 136 // Returns all blocks in F having at least one OpLoopMerge or OpSelectionMerge 137 // instruction. 138 static SmallPtrSet<BasicBlock *, 2> getHeaderBlocks(Function &F) { 139 SmallPtrSet<BasicBlock *, 2> Output; 140 for (BasicBlock &BB : F) { 141 for (Instruction &I : BB) { 142 if (getDesignatedMergeBlock(&I) != nullptr) 143 Output.insert(&BB); 144 } 145 } 146 return Output; 147 } 148 149 // Returns all basic blocks in |F| referenced by at least 1 150 // OpSelectionMerge/OpLoopMerge instruction. 151 static SmallPtrSet<BasicBlock *, 2> getMergeBlocks(Function &F) { 152 SmallPtrSet<BasicBlock *, 2> Output; 153 for (BasicBlock &BB : F) { 154 for (Instruction &I : BB) { 155 BasicBlock *MB = getDesignatedMergeBlock(&I); 156 if (MB != nullptr) 157 Output.insert(MB); 158 } 159 } 160 return Output; 161 } 162 163 // Return all the merge instructions contained in BB. 164 // Note: the SPIR-V spec doesn't allow a single BB to contain more than 1 merge 165 // instruction, but this can happen while we structurize the CFG. 
166 static std::vector<Instruction *> getMergeInstructions(BasicBlock &BB) { 167 std::vector<Instruction *> Output; 168 for (Instruction &I : BB) 169 if (isMergeInstruction(&I)) 170 Output.push_back(&I); 171 return Output; 172 } 173 174 // Returns all basic blocks in |F| referenced as continue target by at least 1 175 // OpLoopMerge instruction. 176 static SmallPtrSet<BasicBlock *, 2> getContinueBlocks(Function &F) { 177 SmallPtrSet<BasicBlock *, 2> Output; 178 for (BasicBlock &BB : F) { 179 for (Instruction &I : BB) { 180 BasicBlock *MB = getDesignatedContinueBlock(&I); 181 if (MB != nullptr) 182 Output.insert(MB); 183 } 184 } 185 return Output; 186 } 187 188 // Do a preorder traversal of the CFG starting from the BB |Start|. 189 // point. Calls |op| on each basic block encountered during the traversal. 190 static void visit(BasicBlock &Start, std::function<bool(BasicBlock *)> op) { 191 std::stack<BasicBlock *> ToVisit; 192 SmallPtrSet<BasicBlock *, 8> Seen; 193 194 ToVisit.push(&Start); 195 Seen.insert(ToVisit.top()); 196 while (ToVisit.size() != 0) { 197 BasicBlock *BB = ToVisit.top(); 198 ToVisit.pop(); 199 200 if (!op(BB)) 201 continue; 202 203 for (auto Succ : successors(BB)) { 204 if (Seen.contains(Succ)) 205 continue; 206 ToVisit.push(Succ); 207 Seen.insert(Succ); 208 } 209 } 210 } 211 212 // Replaces the conditional and unconditional branch targets of |BB| by 213 // |NewTarget| if the target was |OldTarget|. This function also makes sure the 214 // associated merge instruction gets updated accordingly. 215 static void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, 216 BasicBlock *NewTarget) { 217 auto *BI = cast<BranchInst>(BB->getTerminator()); 218 219 // 1. Replace all matching successors. 220 for (size_t i = 0; i < BI->getNumSuccessors(); i++) { 221 if (BI->getSuccessor(i) == OldTarget) 222 BI->setSuccessor(i, NewTarget); 223 } 224 225 // Branch was unconditional, no fixup required. 
226 if (BI->isUnconditional()) 227 return; 228 229 // Branch had 2 successors, maybe now both are the same? 230 if (BI->getSuccessor(0) != BI->getSuccessor(1)) 231 return; 232 233 // Note: we may end up here because the original IR had such branches. 234 // This means Target is not necessarily equal to NewTarget. 235 IRBuilder<> Builder(BB); 236 Builder.SetInsertPoint(BI); 237 Builder.CreateBr(BI->getSuccessor(0)); 238 BI->eraseFromParent(); 239 240 // The branch was the only instruction, nothing else to do. 241 if (BB->size() == 1) 242 return; 243 244 // Otherwise, we need to check: was there an OpSelectionMerge before this 245 // branch? If we removed the OpBranchConditional, we must also remove the 246 // OpSelectionMerge. This is not valid for OpLoopMerge: 247 IntrinsicInst *II = 248 dyn_cast<IntrinsicInst>(BB->getTerminator()->getPrevNode()); 249 if (!II || II->getIntrinsicID() != Intrinsic::spv_selection_merge) 250 return; 251 252 Constant *C = cast<Constant>(II->getOperand(0)); 253 II->eraseFromParent(); 254 if (!C->isConstantUsed()) 255 C->destroyConstant(); 256 } 257 258 // Replaces the target of branch instruction in |BB| with |NewTarget| if it 259 // was |OldTarget|. This function also fixes the associated merge instruction. 260 // Note: this function does not simplify branching instructions, it only updates 261 // targets. See also: simplifyBranches. 
262 static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, 263 BasicBlock *NewTarget) { 264 auto *T = BB->getTerminator(); 265 if (isa<ReturnInst>(T)) 266 return; 267 268 if (isa<BranchInst>(T)) 269 return replaceIfBranchTargets(BB, OldTarget, NewTarget); 270 271 if (auto *SI = dyn_cast<SwitchInst>(T)) { 272 for (size_t i = 0; i < SI->getNumSuccessors(); i++) { 273 if (SI->getSuccessor(i) == OldTarget) 274 SI->setSuccessor(i, NewTarget); 275 } 276 return; 277 } 278 279 assert(false && "Unhandled terminator type."); 280 } 281 282 namespace { 283 // Given a reducible CFG, produces a structurized CFG in the SPIR-V sense, 284 // adding merge instructions when required. 285 class SPIRVStructurizer : public FunctionPass { 286 struct DivergentConstruct; 287 // Represents a list of condition/loops/switch constructs. 288 // See SPIR-V 2.11.2. Structured Control-flow Constructs for the list of 289 // constructs. 290 using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>; 291 292 // Represents a divergent construct in the SPIR-V sense. 293 // Such constructs are represented by a header (entry), a merge block (exit), 294 // and possibly a continue block (back-edge). A construct can contain other 295 // constructs, but their boundaries do not cross. 296 struct DivergentConstruct { 297 BasicBlock *Header = nullptr; 298 BasicBlock *Merge = nullptr; 299 BasicBlock *Continue = nullptr; 300 301 DivergentConstruct *Parent = nullptr; 302 ConstructList Children; 303 }; 304 305 // An helper class to clean the construct boundaries. 306 // It is used to gather the list of blocks that should belong to each 307 // divergent construct, and possibly modify CFG edges when exits would cross 308 // the boundary of multiple constructs. 
309 struct Splitter { 310 Function &F; 311 LoopInfo &LI; 312 DomTreeBuilder::BBDomTree DT; 313 DomTreeBuilder::BBPostDomTree PDT; 314 315 Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); } 316 317 void invalidate() { 318 PDT.recalculate(F); 319 DT.recalculate(F); 320 } 321 322 // Returns the list of blocks that belong to a SPIR-V loop construct, 323 // including the continue construct. 324 std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header, 325 BasicBlock *Merge) { 326 assert(DT.dominates(Header, Merge)); 327 std::vector<BasicBlock *> Output; 328 partialOrderVisit(*Header, [&](BasicBlock *BB) { 329 if (BB == Merge) 330 return false; 331 if (DT.dominates(Merge, BB) || !DT.dominates(Header, BB)) 332 return false; 333 Output.push_back(BB); 334 return true; 335 }); 336 return Output; 337 } 338 339 // Returns the list of blocks that belong to a SPIR-V selection construct. 340 std::vector<BasicBlock *> 341 getSelectionConstructBlocks(DivergentConstruct *Node) { 342 assert(DT.dominates(Node->Header, Node->Merge)); 343 BlockSet OutsideBlocks; 344 OutsideBlocks.insert(Node->Merge); 345 346 for (DivergentConstruct *It = Node->Parent; It != nullptr; 347 It = It->Parent) { 348 OutsideBlocks.insert(It->Merge); 349 if (It->Continue) 350 OutsideBlocks.insert(It->Continue); 351 } 352 353 std::vector<BasicBlock *> Output; 354 partialOrderVisit(*Node->Header, [&](BasicBlock *BB) { 355 if (OutsideBlocks.count(BB) != 0) 356 return false; 357 if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB)) 358 return false; 359 Output.push_back(BB); 360 return true; 361 }); 362 return Output; 363 } 364 365 // Returns the list of blocks that belong to a SPIR-V switch construct. 
366 std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header, 367 BasicBlock *Merge) { 368 assert(DT.dominates(Header, Merge)); 369 370 std::vector<BasicBlock *> Output; 371 partialOrderVisit(*Header, [&](BasicBlock *BB) { 372 // the blocks structurally dominated by a switch header, 373 if (!DT.dominates(Header, BB)) 374 return false; 375 // excluding blocks structurally dominated by the switch header’s merge 376 // block. 377 if (DT.dominates(Merge, BB) || BB == Merge) 378 return false; 379 Output.push_back(BB); 380 return true; 381 }); 382 return Output; 383 } 384 385 // Returns the list of blocks that belong to a SPIR-V case construct. 386 std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target, 387 BasicBlock *Merge) { 388 assert(DT.dominates(Target, Merge)); 389 390 std::vector<BasicBlock *> Output; 391 partialOrderVisit(*Target, [&](BasicBlock *BB) { 392 // the blocks structurally dominated by an OpSwitch Target or Default 393 // block 394 if (!DT.dominates(Target, BB)) 395 return false; 396 // excluding the blocks structurally dominated by the OpSwitch 397 // construct’s corresponding merge block. 398 if (DT.dominates(Merge, BB) || BB == Merge) 399 return false; 400 Output.push_back(BB); 401 return true; 402 }); 403 return Output; 404 } 405 406 // Splits the given edges by recreating proxy nodes so that the destination 407 // has unique incoming edges from this region. 408 // 409 // clang-format off 410 // 411 // In SPIR-V, constructs must have a single exit/merge. 412 // Given nodes A and B in the construct, a node C outside, and the following edges. 413 // A -> C 414 // B -> C 415 // 416 // In such cases, we must create a new exit node D, that belong to the construct to make is viable: 417 // A -> D -> C 418 // B -> D -> C 419 // 420 // This is fine (assuming C has no PHI nodes), but requires handling the merge instruction here. 
421 // By adding a proxy node, we create a regular divergent shape which can easily be regularized later on. 422 // A -> D -> D1 -> C 423 // B -> D -> D2 -> C 424 // 425 // A, B, D belongs to the construct. D is the exit. D1 and D2 are empty. 426 // 427 // clang-format on 428 std::vector<Edge> 429 createAliasBlocksForComplexEdges(std::vector<Edge> Edges) { 430 std::unordered_set<BasicBlock *> Seen; 431 std::vector<Edge> Output; 432 Output.reserve(Edges.size()); 433 434 for (auto &[Src, Dst] : Edges) { 435 auto [Iterator, Inserted] = Seen.insert(Src); 436 if (!Inserted) { 437 // Src already a source node. Cannot have 2 edges from A to B. 438 // Creating alias source block. 439 BasicBlock *NewSrc = BasicBlock::Create( 440 F.getContext(), Src->getName() + ".new.src", &F); 441 replaceBranchTargets(Src, Dst, NewSrc); 442 IRBuilder<> Builder(NewSrc); 443 Builder.CreateBr(Dst); 444 Src = NewSrc; 445 } 446 447 Output.emplace_back(Src, Dst); 448 } 449 450 return Output; 451 } 452 453 AllocaInst *CreateVariable(Function &F, Type *Type, 454 BasicBlock::iterator Position) { 455 const DataLayout &DL = F.getDataLayout(); 456 return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg", 457 Position); 458 } 459 460 // Given a construct defined by |Header|, and a list of exiting edges 461 // |Edges|, creates a new single exit node, fixing up those edges. 
462 BasicBlock *createSingleExitNode(BasicBlock *Header, 463 std::vector<Edge> &Edges) { 464 465 std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges); 466 467 std::vector<BasicBlock *> Dsts; 468 std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex; 469 auto NewExit = BasicBlock::Create(F.getContext(), 470 Header->getName() + ".new.exit", &F); 471 IRBuilder<> ExitBuilder(NewExit); 472 for (auto &[Src, Dst] : FixedEdges) { 473 if (DstToIndex.count(Dst) != 0) 474 continue; 475 DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size())); 476 Dsts.push_back(Dst); 477 } 478 479 if (Dsts.size() == 1) { 480 for (auto &[Src, Dst] : FixedEdges) { 481 replaceBranchTargets(Src, Dst, NewExit); 482 } 483 ExitBuilder.CreateBr(Dsts[0]); 484 return NewExit; 485 } 486 487 AllocaInst *Variable = CreateVariable(F, ExitBuilder.getInt32Ty(), 488 F.begin()->getFirstInsertionPt()); 489 for (auto &[Src, Dst] : FixedEdges) { 490 IRBuilder<> B2(Src); 491 B2.SetInsertPoint(Src->getFirstInsertionPt()); 492 B2.CreateStore(DstToIndex[Dst], Variable); 493 replaceBranchTargets(Src, Dst, NewExit); 494 } 495 496 Value *Load = ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable); 497 498 // If we can avoid an OpSwitch, generate an OpBranch. Reason is some 499 // OpBranch are allowed to exist without a new OpSelectionMerge if one of 500 // the branch is the parent's merge node, while OpSwitches are not. 501 if (Dsts.size() == 2) { 502 Value *Condition = 503 ExitBuilder.CreateCmp(CmpInst::ICMP_EQ, DstToIndex[Dsts[0]], Load); 504 ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]); 505 return NewExit; 506 } 507 508 SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1); 509 for (BasicBlock *BB : drop_begin(Dsts)) 510 Sw->addCase(DstToIndex[BB], BB); 511 return NewExit; 512 } 513 }; 514 515 /// Create a value in BB set to the value associated with the branch the block 516 /// terminator will take. 
517 Value *createExitVariable( 518 BasicBlock *BB, 519 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) { 520 auto *T = BB->getTerminator(); 521 if (isa<ReturnInst>(T)) 522 return nullptr; 523 524 IRBuilder<> Builder(BB); 525 Builder.SetInsertPoint(T); 526 527 if (auto *BI = dyn_cast<BranchInst>(T)) { 528 529 BasicBlock *LHSTarget = BI->getSuccessor(0); 530 BasicBlock *RHSTarget = 531 BI->isConditional() ? BI->getSuccessor(1) : nullptr; 532 533 Value *LHS = TargetToValue.lookup(LHSTarget); 534 Value *RHS = TargetToValue.lookup(RHSTarget); 535 536 if (LHS == nullptr || RHS == nullptr) 537 return LHS == nullptr ? RHS : LHS; 538 return Builder.CreateSelect(BI->getCondition(), LHS, RHS); 539 } 540 541 // TODO: add support for switch cases. 542 llvm_unreachable("Unhandled terminator type."); 543 } 544 545 // Creates a new basic block in F with a single OpUnreachable instruction. 546 BasicBlock *CreateUnreachable(Function &F) { 547 BasicBlock *BB = BasicBlock::Create(F.getContext(), "unreachable", &F); 548 IRBuilder<> Builder(BB); 549 Builder.CreateUnreachable(); 550 return BB; 551 } 552 553 // Add OpLoopMerge instruction on cycles. 554 bool addMergeForLoops(Function &F) { 555 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 556 auto *TopLevelRegion = 557 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>() 558 .getRegionInfo() 559 .getTopLevelRegion(); 560 561 bool Modified = false; 562 for (auto &BB : F) { 563 // Not a loop header. Ignoring for now. 564 if (!LI.isLoopHeader(&BB)) 565 continue; 566 auto *L = LI.getLoopFor(&BB); 567 568 // This loop header is not the entrance of a convergence region. Ignoring 569 // this block. 570 auto *CR = getRegionForHeader(TopLevelRegion, &BB); 571 if (CR == nullptr) 572 continue; 573 574 IRBuilder<> Builder(&BB); 575 576 auto *Merge = getExitFor(CR); 577 // We are indeed in a loop, but there are no exits (infinite loop). 
578 // This could be caused by a bad shader, but also could be an artifact 579 // from an earlier optimization. It is not always clear if structurally 580 // reachable means runtime reachable, so we cannot error-out. What we must 581 // do however is to make is legal on the SPIR-V point of view, hence 582 // adding an unreachable merge block. 583 if (Merge == nullptr) { 584 BranchInst *Br = cast<BranchInst>(BB.getTerminator()); 585 assert(Br->isUnconditional()); 586 587 Merge = CreateUnreachable(F); 588 Builder.SetInsertPoint(Br); 589 Builder.CreateCondBr(Builder.getFalse(), Merge, Br->getSuccessor(0)); 590 Br->eraseFromParent(); 591 } 592 593 auto *Continue = L->getLoopLatch(); 594 595 Builder.SetInsertPoint(BB.getTerminator()); 596 auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge); 597 auto ContinueAddress = BlockAddress::get(Continue->getParent(), Continue); 598 SmallVector<Value *, 2> Args = {MergeAddress, ContinueAddress}; 599 SmallVector<unsigned, 1> LoopControlImms = 600 getSpirvLoopControlOperandsFromLoopMetadata(L); 601 for (unsigned Imm : LoopControlImms) 602 Args.emplace_back(ConstantInt::get(Builder.getInt32Ty(), Imm)); 603 Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {Args}); 604 Modified = true; 605 } 606 607 return Modified; 608 } 609 610 // Adds an OpSelectionMerge to the immediate dominator or each node with an 611 // in-degree of 2 or more which is not already the merge target of an 612 // OpLoopMerge/OpSelectionMerge. 
613 bool addMergeForNodesWithMultiplePredecessors(Function &F) { 614 DomTreeBuilder::BBDomTree DT; 615 DT.recalculate(F); 616 617 bool Modified = false; 618 for (auto &BB : F) { 619 if (pred_size(&BB) <= 1) 620 continue; 621 622 if (hasLoopMergeInstruction(BB) && pred_size(&BB) <= 2) 623 continue; 624 625 assert(DT.getNode(&BB)->getIDom()); 626 BasicBlock *Header = DT.getNode(&BB)->getIDom()->getBlock(); 627 628 if (isDefinedAsSelectionMergeBy(*Header, BB)) 629 continue; 630 631 IRBuilder<> Builder(Header); 632 Builder.SetInsertPoint(Header->getTerminator()); 633 634 auto MergeAddress = BlockAddress::get(BB.getParent(), &BB); 635 createOpSelectMerge(&Builder, MergeAddress); 636 637 Modified = true; 638 } 639 640 return Modified; 641 } 642 643 // When a block has multiple OpSelectionMerge/OpLoopMerge instructions, sorts 644 // them to put the "largest" first. A merge instruction is defined as larger 645 // than another when its target merge block post-dominates the other target's 646 // merge block. (This ordering should match the nesting ordering of the source 647 // HLSL). 
648 bool sortSelectionMerge(Function &F, BasicBlock &Block) { 649 std::vector<Instruction *> MergeInstructions; 650 for (Instruction &I : Block) 651 if (isMergeInstruction(&I)) 652 MergeInstructions.push_back(&I); 653 654 if (MergeInstructions.size() <= 1) 655 return false; 656 657 Instruction *InsertionPoint = *MergeInstructions.begin(); 658 659 PartialOrderingVisitor Visitor(F); 660 std::sort(MergeInstructions.begin(), MergeInstructions.end(), 661 [&Visitor](Instruction *Left, Instruction *Right) { 662 if (Left == Right) 663 return false; 664 BasicBlock *RightMerge = getDesignatedMergeBlock(Right); 665 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left); 666 return !Visitor.compare(RightMerge, LeftMerge); 667 }); 668 669 for (Instruction *I : MergeInstructions) { 670 I->moveBefore(InsertionPoint->getIterator()); 671 InsertionPoint = I; 672 } 673 674 return true; 675 } 676 677 // Sorts selection merge headers in |F|. 678 // A is sorted before B if the merge block designated by B is an ancestor of 679 // the one designated by A. 680 bool sortSelectionMergeHeaders(Function &F) { 681 bool Modified = false; 682 for (BasicBlock &BB : F) { 683 Modified |= sortSelectionMerge(F, BB); 684 } 685 return Modified; 686 } 687 688 // Split basic blocks containing multiple OpLoopMerge/OpSelectionMerge 689 // instructions so each basic block contains only a single merge instruction. 
690 bool splitBlocksWithMultipleHeaders(Function &F) { 691 std::stack<BasicBlock *> Work; 692 for (auto &BB : F) { 693 std::vector<Instruction *> MergeInstructions = getMergeInstructions(BB); 694 if (MergeInstructions.size() <= 1) 695 continue; 696 Work.push(&BB); 697 } 698 699 const bool Modified = Work.size() > 0; 700 while (Work.size() > 0) { 701 BasicBlock *Header = Work.top(); 702 Work.pop(); 703 704 std::vector<Instruction *> MergeInstructions = 705 getMergeInstructions(*Header); 706 for (unsigned i = 1; i < MergeInstructions.size(); i++) { 707 BasicBlock *NewBlock = 708 Header->splitBasicBlock(MergeInstructions[i], "new.header"); 709 710 if (getDesignatedContinueBlock(MergeInstructions[0]) == nullptr) { 711 BasicBlock *Unreachable = CreateUnreachable(F); 712 713 BranchInst *BI = cast<BranchInst>(Header->getTerminator()); 714 IRBuilder<> Builder(Header); 715 Builder.SetInsertPoint(BI); 716 Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable); 717 BI->eraseFromParent(); 718 } 719 720 Header = NewBlock; 721 } 722 } 723 724 return Modified; 725 } 726 727 // Adds an OpSelectionMerge to each block with an out-degree >= 2 which 728 // doesn't already have an OpSelectionMerge. 
729 bool addMergeForDivergentBlocks(Function &F) { 730 DomTreeBuilder::BBPostDomTree PDT; 731 PDT.recalculate(F); 732 bool Modified = false; 733 734 auto MergeBlocks = getMergeBlocks(F); 735 auto ContinueBlocks = getContinueBlocks(F); 736 737 for (auto &BB : F) { 738 if (getMergeInstructions(BB).size() != 0) 739 continue; 740 741 std::vector<BasicBlock *> Candidates; 742 for (BasicBlock *Successor : successors(&BB)) { 743 if (MergeBlocks.contains(Successor)) 744 continue; 745 if (ContinueBlocks.contains(Successor)) 746 continue; 747 Candidates.push_back(Successor); 748 } 749 750 if (Candidates.size() <= 1) 751 continue; 752 753 Modified = true; 754 BasicBlock *Merge = Candidates[0]; 755 756 auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge); 757 IRBuilder<> Builder(&BB); 758 Builder.SetInsertPoint(BB.getTerminator()); 759 createOpSelectMerge(&Builder, MergeAddress); 760 } 761 762 return Modified; 763 } 764 765 // Gather all the exit nodes for the construct header by |Header| and 766 // containing the blocks |Construct|. 767 std::vector<Edge> getExitsFrom(const BlockSet &Construct, 768 BasicBlock &Header) { 769 std::vector<Edge> Output; 770 visit(Header, [&](BasicBlock *Item) { 771 if (Construct.count(Item) == 0) 772 return false; 773 774 for (BasicBlock *Successor : successors(Item)) { 775 if (Construct.count(Successor) == 0) 776 Output.emplace_back(Item, Successor); 777 } 778 return true; 779 }); 780 781 return Output; 782 } 783 784 // Build a divergent construct tree searching from |BB|. 785 // If |Parent| is not null, this tree is attached to the parent's tree. 
786 void constructDivergentConstruct(BlockSet &Visited, Splitter &S, 787 BasicBlock *BB, DivergentConstruct *Parent) { 788 if (Visited.count(BB) != 0) 789 return; 790 Visited.insert(BB); 791 792 auto MIS = getMergeInstructions(*BB); 793 if (MIS.size() == 0) { 794 for (BasicBlock *Successor : successors(BB)) 795 constructDivergentConstruct(Visited, S, Successor, Parent); 796 return; 797 } 798 799 assert(MIS.size() == 1); 800 Instruction *MI = MIS[0]; 801 802 BasicBlock *Merge = getDesignatedMergeBlock(MI); 803 BasicBlock *Continue = getDesignatedContinueBlock(MI); 804 805 auto Output = std::make_unique<DivergentConstruct>(); 806 Output->Header = BB; 807 Output->Merge = Merge; 808 Output->Continue = Continue; 809 Output->Parent = Parent; 810 811 constructDivergentConstruct(Visited, S, Merge, Parent); 812 if (Continue) 813 constructDivergentConstruct(Visited, S, Continue, Output.get()); 814 815 for (BasicBlock *Successor : successors(BB)) 816 constructDivergentConstruct(Visited, S, Successor, Output.get()); 817 818 if (Parent) 819 Parent->Children.emplace_back(std::move(Output)); 820 } 821 822 // Returns the blocks belonging to the divergent construct |Node|. 823 BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) { 824 assert(Node->Header && Node->Merge); 825 826 if (Node->Continue) { 827 auto LoopBlocks = S.getLoopConstructBlocks(Node->Header, Node->Merge); 828 return BlockSet(LoopBlocks.begin(), LoopBlocks.end()); 829 } 830 831 auto SelectionBlocks = S.getSelectionConstructBlocks(Node); 832 return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end()); 833 } 834 835 // Fixup the construct |Node| to respect a set of rules defined by the SPIR-V 836 // spec. 837 bool fixupConstruct(Splitter &S, DivergentConstruct *Node) { 838 bool Modified = false; 839 for (auto &Child : Node->Children) 840 Modified |= fixupConstruct(S, Child.get()); 841 842 // This construct is the root construct. 
Does not represent any real 843 // construct, just a way to access the first level of the forest. 844 if (Node->Parent == nullptr) 845 return Modified; 846 847 // This node's parent is the root. Meaning this is a top-level construct. 848 // There can be multiple exists, but all are guaranteed to exit at most 1 849 // construct since we are at first level. 850 if (Node->Parent->Header == nullptr) 851 return Modified; 852 853 // Health check for the structure. 854 assert(Node->Header && Node->Merge); 855 assert(Node->Parent->Header && Node->Parent->Merge); 856 857 BlockSet ConstructBlocks = getConstructBlocks(S, Node); 858 auto Edges = getExitsFrom(ConstructBlocks, *Node->Header); 859 860 // No edges exiting the construct. 861 if (Edges.size() < 1) 862 return Modified; 863 864 bool HasBadEdge = Node->Merge == Node->Parent->Merge || 865 Node->Merge == Node->Parent->Continue; 866 // BasicBlock *Target = Edges[0].second; 867 for (auto &[Src, Dst] : Edges) { 868 // - Breaking from a selection construct: S is a selection construct, S is 869 // the innermost structured 870 // control-flow construct containing A, and B is the merge block for S 871 // - Breaking from the innermost loop: S is the innermost loop construct 872 // containing A, 873 // and B is the merge block for S 874 if (Node->Merge == Dst) 875 continue; 876 877 // Entering the innermost loop’s continue construct: S is the innermost 878 // loop construct containing A, and B is the continue target for S 879 if (Node->Continue == Dst) 880 continue; 881 882 // TODO: what about cases branching to another case in the switch? Seems 883 // to work, but need to double check. 884 HasBadEdge = true; 885 } 886 887 if (!HasBadEdge) 888 return Modified; 889 890 // Create a single exit node gathering all exit edges. 891 BasicBlock *NewExit = S.createSingleExitNode(Node->Header, Edges); 892 893 // Fixup this construct's merge node to point to the new exit. 
894 // Note: this algorithm fixes inner-most divergence construct first. So 895 // recursive structures sharing a single merge node are fixed from the 896 // inside toward the outside. 897 auto MergeInstructions = getMergeInstructions(*Node->Header); 898 assert(MergeInstructions.size() == 1); 899 Instruction *I = MergeInstructions[0]; 900 BlockAddress *BA = cast<BlockAddress>(I->getOperand(0)); 901 if (BA->getBasicBlock() == Node->Merge) { 902 auto MergeAddress = BlockAddress::get(NewExit->getParent(), NewExit); 903 I->setOperand(0, MergeAddress); 904 } 905 906 // Clean up of the possible dangling BockAddr operands to prevent MIR 907 // comments about "address of removed block taken". 908 if (!BA->isConstantUsed()) 909 BA->destroyConstant(); 910 911 Node->Merge = NewExit; 912 // Regenerate the dom trees. 913 S.invalidate(); 914 return true; 915 } 916 917 bool splitCriticalEdges(Function &F) { 918 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 919 Splitter S(F, LI); 920 921 DivergentConstruct Root; 922 BlockSet Visited; 923 constructDivergentConstruct(Visited, S, &*F.begin(), &Root); 924 return fixupConstruct(S, &Root); 925 } 926 927 // Simplify branches when possible: 928 // - if the 2 sides of a conditional branch are the same, transforms it to an 929 // unconditional branch. 930 // - if a switch has only 2 distinct successors, converts it to a conditional 931 // branch. 
932 bool simplifyBranches(Function &F) { 933 bool Modified = false; 934 935 for (BasicBlock &BB : F) { 936 SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator()); 937 if (!SI) 938 continue; 939 if (SI->getNumCases() > 1) 940 continue; 941 942 Modified = true; 943 IRBuilder<> Builder(&BB); 944 Builder.SetInsertPoint(SI); 945 946 if (SI->getNumCases() == 0) { 947 Builder.CreateBr(SI->getDefaultDest()); 948 } else { 949 Value *Condition = 950 Builder.CreateCmp(CmpInst::ICMP_EQ, SI->getCondition(), 951 SI->case_begin()->getCaseValue()); 952 Builder.CreateCondBr(Condition, SI->case_begin()->getCaseSuccessor(), 953 SI->getDefaultDest()); 954 } 955 SI->eraseFromParent(); 956 } 957 958 return Modified; 959 } 960 961 // Makes sure every case target in |F| is unique. If 2 cases branch to the 962 // same basic block, one of the targets is updated so it jumps to a new basic 963 // block ending with a single unconditional branch to the original target. 964 bool splitSwitchCases(Function &F) { 965 bool Modified = false; 966 967 for (BasicBlock &BB : F) { 968 SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator()); 969 if (!SI) 970 continue; 971 972 BlockSet Seen; 973 Seen.insert(SI->getDefaultDest()); 974 975 auto It = SI->case_begin(); 976 while (It != SI->case_end()) { 977 BasicBlock *Target = It->getCaseSuccessor(); 978 if (Seen.count(Target) == 0) { 979 Seen.insert(Target); 980 ++It; 981 continue; 982 } 983 984 Modified = true; 985 BasicBlock *NewTarget = 986 BasicBlock::Create(F.getContext(), "new.sw.case", &F); 987 IRBuilder<> Builder(NewTarget); 988 Builder.CreateBr(Target); 989 SI->addCase(It->getCaseValue(), NewTarget); 990 It = SI->removeCase(It); 991 } 992 } 993 994 return Modified; 995 } 996 997 // Removes blocks not contributing to any structured CFG. This assumes there 998 // is no PHI nodes. 
bool removeUselessBlocks(Function &F) {
  std::vector<BasicBlock *> ToRemove;

  auto MergeBlocks = getMergeBlocks(F);
  auto ContinueBlocks = getContinueBlocks(F);

  for (BasicBlock &BB : F) {
    // Only blocks made of a single instruction (their terminator) are
    // candidates.
    if (BB.size() != 1)
      continue;

    if (isa<ReturnInst>(BB.getTerminator()))
      continue;

    // Blocks designated as merge or continue targets must be kept.
    if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
      continue;

    if (BB.getUniqueSuccessor() == nullptr)
      continue;

    BasicBlock *Successor = BB.getUniqueSuccessor();
    // Work on a copy of the predecessor list: replaceBranchTargets()
    // presumably rewrites the predecessors' terminators, which would change
    // the list while it is being iterated. TODO confirm against the helper.
    std::vector<BasicBlock *> Predecessors(predecessors(&BB).begin(),
                                           predecessors(&BB).end());
    for (BasicBlock *Predecessor : Predecessors)
      replaceBranchTargets(Predecessor, &BB, Successor);
    ToRemove.push_back(&BB);
  }

  // Erasure is deferred so the function's block list is not modified while
  // it is being iterated above.
  for (BasicBlock *BB : ToRemove)
    BB->eraseFromParent();

  return ToRemove.size() != 0;
}

// Adds a selection merge to blocks of |F| that branch to several targets
// but are not yet headers.
// A block is a candidate when it is not in the set returned by
// getHeaderBlocks() and at least 2 of its successor edges lead to blocks
// that are neither merge, continue, nor header blocks. The designated merge
// block is either a new block split out of the candidate's immediate
// post-dominator, or the candidate's first successor when the immediate
// post-dominator has no block.
// Returns true if at least one selection merge was added.
bool addHeaderToRemainingDivergentDAG(Function &F) {
  bool Modified = false;

  auto MergeBlocks = getMergeBlocks(F);
  auto ContinueBlocks = getContinueBlocks(F);
  auto HeaderBlocks = getHeaderBlocks(F);

  // Note: both trees (and the 3 sets above) are computed once here and are
  // not refreshed when blocks are split further down.
  DomTreeBuilder::BBDomTree DT;
  DomTreeBuilder::BBPostDomTree PDT;
  PDT.recalculate(F);
  DT.recalculate(F);

  for (BasicBlock &BB : F) {
    if (HeaderBlocks.count(&BB) != 0)
      continue;
    if (succ_size(&BB) < 2)
      continue;

    // Count the outgoing edges which do not lead to a merge, continue or
    // header block.
    size_t CandidateEdges = 0;
    for (BasicBlock *Successor : successors(&BB)) {
      if (MergeBlocks.count(Successor) != 0 ||
          ContinueBlocks.count(Successor) != 0)
        continue;
      if (HeaderBlocks.count(Successor) != 0)
        continue;
      CandidateEdges += 1;
    }

    if (CandidateEdges <= 1)
      continue;

    BasicBlock *Header = &BB;
    // Block of the immediate post-dominator of the header; this can be null
    // and is checked below.
    BasicBlock *Merge = PDT.getNode(&BB)->getIDom()->getBlock();

    // Give up on this header if the visit reaches a block which is already a
    // merge, continue or header block.
    // NOTE(review): the meaning of the callback's return value is defined by
    // visit(), which is not visible here — presumably it tells whether the
    // traversal continues past |Node|; confirm against its definition.
    bool HasBadBlock = false;
    visit(*Header, [&](const BasicBlock *Node) {
      if (DT.dominates(Header, Node))
        return false;
      if (PDT.dominates(Merge, Node))
        return false;
      if (Node == Header || Node == Merge)
        return true;

      HasBadBlock |= MergeBlocks.count(Node) != 0 ||
                     ContinueBlocks.count(Node) != 0 ||
                     HeaderBlocks.count(Node) != 0;
      return !HasBadBlock;
    });

    if (HasBadBlock)
      continue;

    Modified = true;

    // No post-dominating block: designate the first successor as merge
    // block, without splitting anything.
    if (Merge == nullptr) {
      Merge = *successors(Header).begin();
      IRBuilder<> Builder(Header);
      Builder.SetInsertPoint(Header->getTerminator());

      auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
      createOpSelectMerge(&Builder, MergeAddress);
      continue;
    }

    // Split the merge block right before its terminator, or before the merge
    // instruction preceding the terminator if there is one, and designate
    // the block returned by the split as the merge block.
    Instruction *SplitInstruction = Merge->getTerminator();
    if (isMergeInstruction(SplitInstruction->getPrevNode()))
      SplitInstruction = SplitInstruction->getPrevNode();
    BasicBlock *NewMerge =
        Merge->splitBasicBlockBefore(SplitInstruction, "new.merge");

    IRBuilder<> Builder(Header);
    Builder.SetInsertPoint(Header->getTerminator());

    auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge);
    createOpSelectMerge(&Builder, MergeAddress);
  }

  return Modified;
}

public:
static char ID;

SPIRVStructurizer() : FunctionPass(ID) {}

// Runs every structurization step on |F|, in order. Returns true if any of
// the steps reported a modification.
virtual bool runOnFunction(Function &F) override {
  bool Modified = false;

  // In LLVM, switches are allowed to have several cases branching to the
  // same basic block. This is allowed in SPIR-V, but can make structurizing
  // SPIR-V harder, so first remove those edge cases.
  Modified |= splitSwitchCases(F);

  // LLVM allows conditional branches to have both sides jumping to the same
  // block. It also allows switches to have a single default, or just one
  // case. Cleaning this up now.
  // NOTE(review): simplifyBranches() only rewrites switch terminators;
  // conditional branches with identical sides are not modified by it.
  Modified |= simplifyBranches(F);

  // At this stage, we should have a reducible CFG with cycles.
  // STEP 1: Adding OpLoopMerge instructions to loop headers.
  Modified |= addMergeForLoops(F);

  // STEP 2: Adding OpSelectionMerge to each node with an in-degree >= 2.
  Modified |= addMergeForNodesWithMultiplePredecessors(F);

  // STEP 3:
  // Sort selection merges, the largest construct goes first.
  // This simplifies the next step.
  Modified |= sortSelectionMergeHeaders(F);

  // STEP 4: At this stage, we can have a single basic block with multiple
  // OpLoopMerge/OpSelectionMerge instructions. Splitting this block so each
  // BB has a single merge instruction.
  Modified |= splitBlocksWithMultipleHeaders(F);

  // STEP 5: In the previous steps, we added merge blocks to the loops and
  // natural merge blocks (in-degree >= 2). What remains are conditions with
  // an exiting branch (return, unreachable). In such cases, we must start
  // from the header, and add headers to divergent constructs with no
  // headers.
  Modified |= addMergeForDivergentBlocks(F);

  // STEP 6: At this stage, we have several divergent constructs defined by
  // a header and a merge block. But their boundaries have no constraints: a
  // construct exit could be outside of the parent's construct exit. Such
  // edges are called critical edges. What we need is to split those edges
  // into several parts, each part exiting the parent's construct by its
  // merge block.
  Modified |= splitCriticalEdges(F);

  // STEP 7: The previous steps possibly created a lot of "proxy" blocks:
  // blocks with a single unconditional branch, used to create a valid
  // divergent construct tree. Some nodes are still required (e.g: nodes
  // allowing a valid exit through the parent's merge block). But some are
  // left-overs of past transformations, and could cause actual validation
  // issues. E.g: the SPIR-V spec allows a construct to break to the parent's
  // loop construct without an OpSelectionMerge, but this requires a straight
  // jump. If a proxy block lies between the conditional branch and the
  // parent's merge, the CFG is not valid.
  Modified |= removeUselessBlocks(F);

  // STEP 8: Final fix-up steps: our tree boundaries are correct, but some
  // blocks are branching with no header. Those are often simple conditional
  // branches with 1 or 2 returning edges. Adding a header for those.
  Modified |= addHeaderToRemainingDivergentDAG(F);

  // STEP 9: Sort basic blocks to match both the LLVM & SPIR-V requirements.
  Modified |= sortBlocks(F);

  return Modified;
}

// Declares the analyses this pass requires (dominator tree, loop info and
// SPIR-V convergence regions) and marks the convergence region analysis as
// preserved.
void getAnalysisUsage(AnalysisUsage &AU) const override {
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();

  AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
  FunctionPass::getAnalysisUsage(AU);
}

// Emits a call to the spv_selection_merge intrinsic at the insertion point
// of |Builder|, with |MergeAddress| as first argument.
// The second argument is a branch hint: operand 1 of the
// "hlsl.controlflow.hint" metadata attached to the terminator of the
// builder's insertion block when that metadata is present, the i32
// constant 0 otherwise.
void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
  Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();

  MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint");

  // Default hint, used when the terminator carries no hint metadata.
  ConstantInt *BranchHint = ConstantInt::get(Builder->getInt32Ty(), 0);

  if (MDNode) {
    assert(MDNode->getNumOperands() == 2 &&
           "invalid metadata hlsl.controlflow.hint");
    BranchHint = mdconst::extract<ConstantInt>(MDNode->getOperand(1));
  }

  SmallVector<Value *, 2> Args = {MergeAddress, BranchHint};

  Builder->CreateIntrinsic(Intrinsic::spv_selection_merge,
                           {MergeAddress->getType()}, Args);
}
};
} // anonymous namespace

char SPIRVStructurizer::ID = 0;

INITIALIZE_PASS_BEGIN(SPIRVStructurizer, "spirv-structurizer",
                      "structurize SPIRV", false, false)
1217 INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 1218 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 1219 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 1220 INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass) 1221 1222 INITIALIZE_PASS_END(SPIRVStructurizer, "spirv-structurizer", 1223 "structurize SPIRV", false, false) 1224 1225 FunctionPass *llvm::createSPIRVStructurizerPass() { 1226 return new SPIRVStructurizer(); 1227 } 1228 1229 PreservedAnalyses SPIRVStructurizerWrapper::run(Function &F, 1230 FunctionAnalysisManager &AF) { 1231 1232 auto FPM = legacy::FunctionPassManager(F.getParent()); 1233 FPM.add(createSPIRVStructurizerPass()); 1234 1235 if (!FPM.run(F)) 1236 return PreservedAnalyses::all(); 1237 PreservedAnalyses PA; 1238 PA.preserveSet<CFGAnalyses>(); 1239 return PA; 1240 } 1241