//===- AMDGPUUnifyDivergentExitNodes.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a variant of the UnifyFunctionExitNodes pass. Rather than ensuring
// there is at most one ret and one unreachable instruction, it ensures there is
// at most one divergent exiting block.
//
// StructurizeCFG can't deal with multi-exit regions formed by branches to
// multiple return nodes. It is not desirable to structurize regions with
// uniform branches, so unifying those to the same return block as divergent
// branches inhibits use of scalar branching. It still can't deal with the case
// where one branch goes to a return and another to an unreachable, so in that
// case the unreachable is replaced with a return.
19*0b57cec5SDimitry Andric // 20*0b57cec5SDimitry Andric //===----------------------------------------------------------------------===// 21*0b57cec5SDimitry Andric 22*0b57cec5SDimitry Andric #include "AMDGPU.h" 23*0b57cec5SDimitry Andric #include "llvm/ADT/ArrayRef.h" 24*0b57cec5SDimitry Andric #include "llvm/ADT/SmallPtrSet.h" 25*0b57cec5SDimitry Andric #include "llvm/ADT/SmallVector.h" 26*0b57cec5SDimitry Andric #include "llvm/ADT/StringRef.h" 27*0b57cec5SDimitry Andric #include "llvm/Analysis/LegacyDivergenceAnalysis.h" 28*0b57cec5SDimitry Andric #include "llvm/Analysis/PostDominators.h" 29*0b57cec5SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 30*0b57cec5SDimitry Andric #include "llvm/Transforms/Utils/Local.h" 31*0b57cec5SDimitry Andric #include "llvm/IR/BasicBlock.h" 32*0b57cec5SDimitry Andric #include "llvm/IR/CFG.h" 33*0b57cec5SDimitry Andric #include "llvm/IR/Constants.h" 34*0b57cec5SDimitry Andric #include "llvm/IR/Function.h" 35*0b57cec5SDimitry Andric #include "llvm/IR/InstrTypes.h" 36*0b57cec5SDimitry Andric #include "llvm/IR/Instructions.h" 37*0b57cec5SDimitry Andric #include "llvm/IR/Intrinsics.h" 38*0b57cec5SDimitry Andric #include "llvm/IR/Type.h" 39*0b57cec5SDimitry Andric #include "llvm/Pass.h" 40*0b57cec5SDimitry Andric #include "llvm/Support/Casting.h" 41*0b57cec5SDimitry Andric #include "llvm/Transforms/Scalar.h" 42*0b57cec5SDimitry Andric #include "llvm/Transforms/Utils.h" 43*0b57cec5SDimitry Andric 44*0b57cec5SDimitry Andric using namespace llvm; 45*0b57cec5SDimitry Andric 46*0b57cec5SDimitry Andric #define DEBUG_TYPE "amdgpu-unify-divergent-exit-nodes" 47*0b57cec5SDimitry Andric 48*0b57cec5SDimitry Andric namespace { 49*0b57cec5SDimitry Andric 50*0b57cec5SDimitry Andric class AMDGPUUnifyDivergentExitNodes : public FunctionPass { 51*0b57cec5SDimitry Andric public: 52*0b57cec5SDimitry Andric static char ID; // Pass identification, replacement for typeid 53*0b57cec5SDimitry Andric 54*0b57cec5SDimitry Andric 
AMDGPUUnifyDivergentExitNodes() : FunctionPass(ID) { 55*0b57cec5SDimitry Andric initializeAMDGPUUnifyDivergentExitNodesPass(*PassRegistry::getPassRegistry()); 56*0b57cec5SDimitry Andric } 57*0b57cec5SDimitry Andric 58*0b57cec5SDimitry Andric // We can preserve non-critical-edgeness when we unify function exit nodes 59*0b57cec5SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override; 60*0b57cec5SDimitry Andric bool runOnFunction(Function &F) override; 61*0b57cec5SDimitry Andric }; 62*0b57cec5SDimitry Andric 63*0b57cec5SDimitry Andric } // end anonymous namespace 64*0b57cec5SDimitry Andric 65*0b57cec5SDimitry Andric char AMDGPUUnifyDivergentExitNodes::ID = 0; 66*0b57cec5SDimitry Andric 67*0b57cec5SDimitry Andric char &llvm::AMDGPUUnifyDivergentExitNodesID = AMDGPUUnifyDivergentExitNodes::ID; 68*0b57cec5SDimitry Andric 69*0b57cec5SDimitry Andric INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE, 70*0b57cec5SDimitry Andric "Unify divergent function exit nodes", false, false) 71*0b57cec5SDimitry Andric INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) 72*0b57cec5SDimitry Andric INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis) 73*0b57cec5SDimitry Andric INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE, 74*0b57cec5SDimitry Andric "Unify divergent function exit nodes", false, false) 75*0b57cec5SDimitry Andric 76*0b57cec5SDimitry Andric void AMDGPUUnifyDivergentExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{ 77*0b57cec5SDimitry Andric // TODO: Preserve dominator tree. 78*0b57cec5SDimitry Andric AU.addRequired<PostDominatorTreeWrapperPass>(); 79*0b57cec5SDimitry Andric 80*0b57cec5SDimitry Andric AU.addRequired<LegacyDivergenceAnalysis>(); 81*0b57cec5SDimitry Andric 82*0b57cec5SDimitry Andric // No divergent values are changed, only blocks and branch edges. 
83*0b57cec5SDimitry Andric AU.addPreserved<LegacyDivergenceAnalysis>(); 84*0b57cec5SDimitry Andric 85*0b57cec5SDimitry Andric // We preserve the non-critical-edgeness property 86*0b57cec5SDimitry Andric AU.addPreservedID(BreakCriticalEdgesID); 87*0b57cec5SDimitry Andric 88*0b57cec5SDimitry Andric // This is a cluster of orthogonal Transforms 89*0b57cec5SDimitry Andric AU.addPreservedID(LowerSwitchID); 90*0b57cec5SDimitry Andric FunctionPass::getAnalysisUsage(AU); 91*0b57cec5SDimitry Andric 92*0b57cec5SDimitry Andric AU.addRequired<TargetTransformInfoWrapperPass>(); 93*0b57cec5SDimitry Andric } 94*0b57cec5SDimitry Andric 95*0b57cec5SDimitry Andric /// \returns true if \p BB is reachable through only uniform branches. 96*0b57cec5SDimitry Andric /// XXX - Is there a more efficient way to find this? 97*0b57cec5SDimitry Andric static bool isUniformlyReached(const LegacyDivergenceAnalysis &DA, 98*0b57cec5SDimitry Andric BasicBlock &BB) { 99*0b57cec5SDimitry Andric SmallVector<BasicBlock *, 8> Stack; 100*0b57cec5SDimitry Andric SmallPtrSet<BasicBlock *, 8> Visited; 101*0b57cec5SDimitry Andric 102*0b57cec5SDimitry Andric for (BasicBlock *Pred : predecessors(&BB)) 103*0b57cec5SDimitry Andric Stack.push_back(Pred); 104*0b57cec5SDimitry Andric 105*0b57cec5SDimitry Andric while (!Stack.empty()) { 106*0b57cec5SDimitry Andric BasicBlock *Top = Stack.pop_back_val(); 107*0b57cec5SDimitry Andric if (!DA.isUniform(Top->getTerminator())) 108*0b57cec5SDimitry Andric return false; 109*0b57cec5SDimitry Andric 110*0b57cec5SDimitry Andric for (BasicBlock *Pred : predecessors(Top)) { 111*0b57cec5SDimitry Andric if (Visited.insert(Pred).second) 112*0b57cec5SDimitry Andric Stack.push_back(Pred); 113*0b57cec5SDimitry Andric } 114*0b57cec5SDimitry Andric } 115*0b57cec5SDimitry Andric 116*0b57cec5SDimitry Andric return true; 117*0b57cec5SDimitry Andric } 118*0b57cec5SDimitry Andric 119*0b57cec5SDimitry Andric static BasicBlock *unifyReturnBlockSet(Function &F, 120*0b57cec5SDimitry Andric 
ArrayRef<BasicBlock *> ReturningBlocks, 121*0b57cec5SDimitry Andric const TargetTransformInfo &TTI, 122*0b57cec5SDimitry Andric StringRef Name) { 123*0b57cec5SDimitry Andric // Otherwise, we need to insert a new basic block into the function, add a PHI 124*0b57cec5SDimitry Andric // nodes (if the function returns values), and convert all of the return 125*0b57cec5SDimitry Andric // instructions into unconditional branches. 126*0b57cec5SDimitry Andric BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), Name, &F); 127*0b57cec5SDimitry Andric 128*0b57cec5SDimitry Andric PHINode *PN = nullptr; 129*0b57cec5SDimitry Andric if (F.getReturnType()->isVoidTy()) { 130*0b57cec5SDimitry Andric ReturnInst::Create(F.getContext(), nullptr, NewRetBlock); 131*0b57cec5SDimitry Andric } else { 132*0b57cec5SDimitry Andric // If the function doesn't return void... add a PHI node to the block... 133*0b57cec5SDimitry Andric PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(), 134*0b57cec5SDimitry Andric "UnifiedRetVal"); 135*0b57cec5SDimitry Andric NewRetBlock->getInstList().push_back(PN); 136*0b57cec5SDimitry Andric ReturnInst::Create(F.getContext(), PN, NewRetBlock); 137*0b57cec5SDimitry Andric } 138*0b57cec5SDimitry Andric 139*0b57cec5SDimitry Andric // Loop over all of the blocks, replacing the return instruction with an 140*0b57cec5SDimitry Andric // unconditional branch. 141*0b57cec5SDimitry Andric for (BasicBlock *BB : ReturningBlocks) { 142*0b57cec5SDimitry Andric // Add an incoming element to the PHI node for every return instruction that 143*0b57cec5SDimitry Andric // is merging into this new block... 144*0b57cec5SDimitry Andric if (PN) 145*0b57cec5SDimitry Andric PN->addIncoming(BB->getTerminator()->getOperand(0), BB); 146*0b57cec5SDimitry Andric 147*0b57cec5SDimitry Andric // Remove and delete the return inst. 
148*0b57cec5SDimitry Andric BB->getTerminator()->eraseFromParent(); 149*0b57cec5SDimitry Andric BranchInst::Create(NewRetBlock, BB); 150*0b57cec5SDimitry Andric } 151*0b57cec5SDimitry Andric 152*0b57cec5SDimitry Andric for (BasicBlock *BB : ReturningBlocks) { 153*0b57cec5SDimitry Andric // Cleanup possible branch to unconditional branch to the return. 154*0b57cec5SDimitry Andric simplifyCFG(BB, TTI, {2}); 155*0b57cec5SDimitry Andric } 156*0b57cec5SDimitry Andric 157*0b57cec5SDimitry Andric return NewRetBlock; 158*0b57cec5SDimitry Andric } 159*0b57cec5SDimitry Andric 160*0b57cec5SDimitry Andric bool AMDGPUUnifyDivergentExitNodes::runOnFunction(Function &F) { 161*0b57cec5SDimitry Andric auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); 162*0b57cec5SDimitry Andric if (PDT.getRoots().size() <= 1) 163*0b57cec5SDimitry Andric return false; 164*0b57cec5SDimitry Andric 165*0b57cec5SDimitry Andric LegacyDivergenceAnalysis &DA = getAnalysis<LegacyDivergenceAnalysis>(); 166*0b57cec5SDimitry Andric 167*0b57cec5SDimitry Andric // Loop over all of the blocks in a function, tracking all of the blocks that 168*0b57cec5SDimitry Andric // return. 169*0b57cec5SDimitry Andric SmallVector<BasicBlock *, 4> ReturningBlocks; 170*0b57cec5SDimitry Andric SmallVector<BasicBlock *, 4> UnreachableBlocks; 171*0b57cec5SDimitry Andric 172*0b57cec5SDimitry Andric // Dummy return block for infinite loop. 
173*0b57cec5SDimitry Andric BasicBlock *DummyReturnBB = nullptr; 174*0b57cec5SDimitry Andric 175*0b57cec5SDimitry Andric for (BasicBlock *BB : PDT.getRoots()) { 176*0b57cec5SDimitry Andric if (isa<ReturnInst>(BB->getTerminator())) { 177*0b57cec5SDimitry Andric if (!isUniformlyReached(DA, *BB)) 178*0b57cec5SDimitry Andric ReturningBlocks.push_back(BB); 179*0b57cec5SDimitry Andric } else if (isa<UnreachableInst>(BB->getTerminator())) { 180*0b57cec5SDimitry Andric if (!isUniformlyReached(DA, *BB)) 181*0b57cec5SDimitry Andric UnreachableBlocks.push_back(BB); 182*0b57cec5SDimitry Andric } else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) { 183*0b57cec5SDimitry Andric 184*0b57cec5SDimitry Andric ConstantInt *BoolTrue = ConstantInt::getTrue(F.getContext()); 185*0b57cec5SDimitry Andric if (DummyReturnBB == nullptr) { 186*0b57cec5SDimitry Andric DummyReturnBB = BasicBlock::Create(F.getContext(), 187*0b57cec5SDimitry Andric "DummyReturnBlock", &F); 188*0b57cec5SDimitry Andric Type *RetTy = F.getReturnType(); 189*0b57cec5SDimitry Andric Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy); 190*0b57cec5SDimitry Andric ReturnInst::Create(F.getContext(), RetVal, DummyReturnBB); 191*0b57cec5SDimitry Andric ReturningBlocks.push_back(DummyReturnBB); 192*0b57cec5SDimitry Andric } 193*0b57cec5SDimitry Andric 194*0b57cec5SDimitry Andric if (BI->isUnconditional()) { 195*0b57cec5SDimitry Andric BasicBlock *LoopHeaderBB = BI->getSuccessor(0); 196*0b57cec5SDimitry Andric BI->eraseFromParent(); // Delete the unconditional branch. 197*0b57cec5SDimitry Andric // Add a new conditional branch with a dummy edge to the return block. 198*0b57cec5SDimitry Andric BranchInst::Create(LoopHeaderBB, DummyReturnBB, BoolTrue, BB); 199*0b57cec5SDimitry Andric } else { // Conditional branch. 200*0b57cec5SDimitry Andric // Create a new transition block to hold the conditional branch. 
201*0b57cec5SDimitry Andric BasicBlock *TransitionBB = BB->splitBasicBlock(BI, "TransitionBlock"); 202*0b57cec5SDimitry Andric 203*0b57cec5SDimitry Andric // Create a branch that will always branch to the transition block and 204*0b57cec5SDimitry Andric // references DummyReturnBB. 205*0b57cec5SDimitry Andric BB->getTerminator()->eraseFromParent(); 206*0b57cec5SDimitry Andric BranchInst::Create(TransitionBB, DummyReturnBB, BoolTrue, BB); 207*0b57cec5SDimitry Andric } 208*0b57cec5SDimitry Andric } 209*0b57cec5SDimitry Andric } 210*0b57cec5SDimitry Andric 211*0b57cec5SDimitry Andric if (!UnreachableBlocks.empty()) { 212*0b57cec5SDimitry Andric BasicBlock *UnreachableBlock = nullptr; 213*0b57cec5SDimitry Andric 214*0b57cec5SDimitry Andric if (UnreachableBlocks.size() == 1) { 215*0b57cec5SDimitry Andric UnreachableBlock = UnreachableBlocks.front(); 216*0b57cec5SDimitry Andric } else { 217*0b57cec5SDimitry Andric UnreachableBlock = BasicBlock::Create(F.getContext(), 218*0b57cec5SDimitry Andric "UnifiedUnreachableBlock", &F); 219*0b57cec5SDimitry Andric new UnreachableInst(F.getContext(), UnreachableBlock); 220*0b57cec5SDimitry Andric 221*0b57cec5SDimitry Andric for (BasicBlock *BB : UnreachableBlocks) { 222*0b57cec5SDimitry Andric // Remove and delete the unreachable inst. 223*0b57cec5SDimitry Andric BB->getTerminator()->eraseFromParent(); 224*0b57cec5SDimitry Andric BranchInst::Create(UnreachableBlock, BB); 225*0b57cec5SDimitry Andric } 226*0b57cec5SDimitry Andric } 227*0b57cec5SDimitry Andric 228*0b57cec5SDimitry Andric if (!ReturningBlocks.empty()) { 229*0b57cec5SDimitry Andric // Don't create a new unreachable inst if we have a return. The 230*0b57cec5SDimitry Andric // structurizer/annotator can't handle the multiple exits 231*0b57cec5SDimitry Andric 232*0b57cec5SDimitry Andric Type *RetTy = F.getReturnType(); 233*0b57cec5SDimitry Andric Value *RetVal = RetTy->isVoidTy() ? 
nullptr : UndefValue::get(RetTy); 234*0b57cec5SDimitry Andric // Remove and delete the unreachable inst. 235*0b57cec5SDimitry Andric UnreachableBlock->getTerminator()->eraseFromParent(); 236*0b57cec5SDimitry Andric 237*0b57cec5SDimitry Andric Function *UnreachableIntrin = 238*0b57cec5SDimitry Andric Intrinsic::getDeclaration(F.getParent(), Intrinsic::amdgcn_unreachable); 239*0b57cec5SDimitry Andric 240*0b57cec5SDimitry Andric // Insert a call to an intrinsic tracking that this is an unreachable 241*0b57cec5SDimitry Andric // point, in case we want to kill the active lanes or something later. 242*0b57cec5SDimitry Andric CallInst::Create(UnreachableIntrin, {}, "", UnreachableBlock); 243*0b57cec5SDimitry Andric 244*0b57cec5SDimitry Andric // Don't create a scalar trap. We would only want to trap if this code was 245*0b57cec5SDimitry Andric // really reached, but a scalar trap would happen even if no lanes 246*0b57cec5SDimitry Andric // actually reached here. 247*0b57cec5SDimitry Andric ReturnInst::Create(F.getContext(), RetVal, UnreachableBlock); 248*0b57cec5SDimitry Andric ReturningBlocks.push_back(UnreachableBlock); 249*0b57cec5SDimitry Andric } 250*0b57cec5SDimitry Andric } 251*0b57cec5SDimitry Andric 252*0b57cec5SDimitry Andric // Now handle return blocks. 253*0b57cec5SDimitry Andric if (ReturningBlocks.empty()) 254*0b57cec5SDimitry Andric return false; // No blocks return 255*0b57cec5SDimitry Andric 256*0b57cec5SDimitry Andric if (ReturningBlocks.size() == 1) 257*0b57cec5SDimitry Andric return false; // Already has a single return block 258*0b57cec5SDimitry Andric 259*0b57cec5SDimitry Andric const TargetTransformInfo &TTI 260*0b57cec5SDimitry Andric = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 261*0b57cec5SDimitry Andric 262*0b57cec5SDimitry Andric unifyReturnBlockSet(F, ReturningBlocks, TTI, "UnifiedReturnBlock"); 263*0b57cec5SDimitry Andric return true; 264*0b57cec5SDimitry Andric } 265