1 //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Merge the multiple exit targets of a convergence region into a single block. 10 // Each exit target will be assigned a constant value, and a phi node + switch 11 // will allow the new exit target to re-route to the correct basic block. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "Analysis/SPIRVConvergenceRegionAnalysis.h" 16 #include "SPIRV.h" 17 #include "SPIRVSubtarget.h" 18 #include "SPIRVTargetMachine.h" 19 #include "SPIRVUtils.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/SmallPtrSet.h" 22 #include "llvm/Analysis/LoopInfo.h" 23 #include "llvm/CodeGen/IntrinsicLowering.h" 24 #include "llvm/IR/CFG.h" 25 #include "llvm/IR/Dominators.h" 26 #include "llvm/IR/IRBuilder.h" 27 #include "llvm/IR/IntrinsicInst.h" 28 #include "llvm/IR/Intrinsics.h" 29 #include "llvm/IR/IntrinsicsSPIRV.h" 30 #include "llvm/InitializePasses.h" 31 #include "llvm/Transforms/Utils/Cloning.h" 32 #include "llvm/Transforms/Utils/LoopSimplify.h" 33 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" 34 35 using namespace llvm; 36 37 namespace llvm { 38 void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &); 39 40 class SPIRVMergeRegionExitTargets : public FunctionPass { 41 public: 42 static char ID; 43 44 SPIRVMergeRegionExitTargets() : FunctionPass(ID) { 45 initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry()); 46 }; 47 48 // Gather all the successors of |BB|. 49 // This function asserts if the terminator neither a branch, switch or return. 50 std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) { 51 std::unordered_set<BasicBlock *> output; 52 auto *T = BB->getTerminator(); 53 54 if (auto *BI = dyn_cast<BranchInst>(T)) { 55 output.insert(BI->getSuccessor(0)); 56 if (BI->isConditional()) 57 output.insert(BI->getSuccessor(1)); 58 return output; 59 } 60 61 if (auto *SI = dyn_cast<SwitchInst>(T)) { 62 output.insert(SI->getDefaultDest()); 63 for (auto &Case : SI->cases()) 64 output.insert(Case.getCaseSuccessor()); 65 return output; 66 } 67 68 assert(isa<ReturnInst>(T) && "Unhandled terminator type."); 69 return output; 70 } 71 72 /// Create a value in BB set to the value associated with the branch the block 73 /// terminator will take. 74 llvm::Value *createExitVariable( 75 BasicBlock *BB, 76 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) { 77 auto *T = BB->getTerminator(); 78 if (isa<ReturnInst>(T)) 79 return nullptr; 80 81 IRBuilder<> Builder(BB); 82 Builder.SetInsertPoint(T); 83 84 if (auto *BI = dyn_cast<BranchInst>(T)) { 85 86 BasicBlock *LHSTarget = BI->getSuccessor(0); 87 BasicBlock *RHSTarget = 88 BI->isConditional() ? BI->getSuccessor(1) : nullptr; 89 90 Value *LHS = TargetToValue.count(LHSTarget) != 0 91 ? TargetToValue.at(LHSTarget) 92 : nullptr; 93 Value *RHS = TargetToValue.count(RHSTarget) != 0 94 ? TargetToValue.at(RHSTarget) 95 : nullptr; 96 97 if (LHS == nullptr || RHS == nullptr) 98 return LHS == nullptr ? RHS : LHS; 99 return Builder.CreateSelect(BI->getCondition(), LHS, RHS); 100 } 101 102 // TODO: add support for switch cases. 103 llvm_unreachable("Unhandled terminator type."); 104 } 105 106 /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|. 107 void replaceBranchTargets(BasicBlock *BB, 108 const SmallPtrSet<BasicBlock *, 4> &ToReplace, 109 BasicBlock *NewTarget) { 110 auto *T = BB->getTerminator(); 111 if (isa<ReturnInst>(T)) 112 return; 113 114 if (auto *BI = dyn_cast<BranchInst>(T)) { 115 for (size_t i = 0; i < BI->getNumSuccessors(); i++) { 116 if (ToReplace.count(BI->getSuccessor(i)) != 0) 117 BI->setSuccessor(i, NewTarget); 118 } 119 return; 120 } 121 122 if (auto *SI = dyn_cast<SwitchInst>(T)) { 123 for (size_t i = 0; i < SI->getNumSuccessors(); i++) { 124 if (ToReplace.count(SI->getSuccessor(i)) != 0) 125 SI->setSuccessor(i, NewTarget); 126 } 127 return; 128 } 129 130 assert(false && "Unhandled terminator type."); 131 } 132 133 // Run the pass on the given convergence region, ignoring the sub-regions. 134 // Returns true if the CFG changed, false otherwise. 135 bool runOnConvergenceRegionNoRecurse(LoopInfo &LI, 136 const SPIRV::ConvergenceRegion *CR) { 137 // Gather all the exit targets for this region. 138 SmallPtrSet<BasicBlock *, 4> ExitTargets; 139 for (BasicBlock *Exit : CR->Exits) { 140 for (BasicBlock *Target : gatherSuccessors(Exit)) { 141 if (CR->Blocks.count(Target) == 0) 142 ExitTargets.insert(Target); 143 } 144 } 145 146 // If we have zero or one exit target, nothing do to. 147 if (ExitTargets.size() <= 1) 148 return false; 149 150 // Create the new single exit target. 151 auto F = CR->Entry->getParent(); 152 auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F); 153 IRBuilder<> Builder(NewExitTarget); 154 155 // CodeGen output needs to be stable. Using the set as-is would order 156 // the targets differently depending on the allocation pattern. 157 // Sorting per basic-block ordering in the function. 158 std::vector<BasicBlock *> SortedExitTargets; 159 std::vector<BasicBlock *> SortedExits; 160 for (BasicBlock &BB : *F) { 161 if (ExitTargets.count(&BB) != 0) 162 SortedExitTargets.push_back(&BB); 163 if (CR->Exits.count(&BB) != 0) 164 SortedExits.push_back(&BB); 165 } 166 167 // Creating one constant per distinct exit target. This will be route to the 168 // correct target. 169 DenseMap<BasicBlock *, ConstantInt *> TargetToValue; 170 for (BasicBlock *Target : SortedExitTargets) 171 TargetToValue.insert( 172 std::make_pair(Target, Builder.getInt32(TargetToValue.size()))); 173 174 // Creating one variable per exit node, set to the constant matching the 175 // targeted external block. 176 std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable; 177 for (auto Exit : SortedExits) { 178 llvm::Value *Value = createExitVariable(Exit, TargetToValue); 179 ExitToVariable.emplace_back(std::make_pair(Exit, Value)); 180 } 181 182 // Gather the correct value depending on the exit we came from. 183 llvm::PHINode *node = 184 Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size()); 185 for (auto [BB, Value] : ExitToVariable) { 186 node->addIncoming(Value, BB); 187 } 188 189 // Creating the switch to jump to the correct exit target. 190 llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0], 191 SortedExitTargets.size() - 1); 192 for (size_t i = 1; i < SortedExitTargets.size(); i++) { 193 BasicBlock *BB = SortedExitTargets[i]; 194 Sw->addCase(TargetToValue[BB], BB); 195 } 196 197 // Fix exit branches to redirect to the new exit. 198 for (auto Exit : CR->Exits) 199 replaceBranchTargets(Exit, ExitTargets, NewExitTarget); 200 201 return true; 202 } 203 204 /// Run the pass on the given convergence region and sub-regions (DFS). 205 /// Returns true if a region/sub-region was modified, false otherwise. 206 /// This returns as soon as one region/sub-region has been modified. 207 bool runOnConvergenceRegion(LoopInfo &LI, 208 const SPIRV::ConvergenceRegion *CR) { 209 for (auto *Child : CR->Children) 210 if (runOnConvergenceRegion(LI, Child)) 211 return true; 212 213 return runOnConvergenceRegionNoRecurse(LI, CR); 214 } 215 216 #if !NDEBUG 217 /// Validates each edge exiting the region has the same destination basic 218 /// block. 219 void validateRegionExits(const SPIRV::ConvergenceRegion *CR) { 220 for (auto *Child : CR->Children) 221 validateRegionExits(Child); 222 223 std::unordered_set<BasicBlock *> ExitTargets; 224 for (auto *Exit : CR->Exits) { 225 auto Set = gatherSuccessors(Exit); 226 for (auto *BB : Set) { 227 if (CR->Blocks.count(BB) == 0) 228 ExitTargets.insert(BB); 229 } 230 } 231 232 assert(ExitTargets.size() <= 1); 233 } 234 #endif 235 236 virtual bool runOnFunction(Function &F) override { 237 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 238 const auto *TopLevelRegion = 239 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>() 240 .getRegionInfo() 241 .getTopLevelRegion(); 242 243 // FIXME: very inefficient method: each time a region is modified, we bubble 244 // back up, and recompute the whole convergence region tree. Once the 245 // algorithm is completed and test coverage good enough, rewrite this pass 246 // to be efficient instead of simple. 247 bool modified = false; 248 while (runOnConvergenceRegion(LI, TopLevelRegion)) { 249 TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>() 250 .getRegionInfo() 251 .getTopLevelRegion(); 252 modified = true; 253 } 254 255 #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS) 256 validateRegionExits(TopLevelRegion); 257 #endif 258 return modified; 259 } 260 261 void getAnalysisUsage(AnalysisUsage &AU) const override { 262 AU.addRequired<DominatorTreeWrapperPass>(); 263 AU.addRequired<LoopInfoWrapperPass>(); 264 AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>(); 265 FunctionPass::getAnalysisUsage(AU); 266 } 267 }; 268 } // namespace llvm 269 270 char SPIRVMergeRegionExitTargets::ID = 0; 271 272 INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks", 273 "SPIRV split region exit blocks", false, false) 274 INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 275 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 276 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 277 INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass) 278 279 INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks", 280 "SPIRV split region exit blocks", false, false) 281 282 FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() { 283 return new SPIRVMergeRegionExitTargets(); 284 } 285