1 //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Merge the multiple exit targets of a convergence region into a single block.
10 // Each exit target will be assigned a constant value, and a phi node + switch
11 // will allow the new exit target to re-route to the correct basic block.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "Analysis/SPIRVConvergenceRegionAnalysis.h"
16 #include "SPIRV.h"
17 #include "SPIRVSubtarget.h"
18 #include "SPIRVTargetMachine.h"
19 #include "SPIRVUtils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Analysis/LoopInfo.h"
23 #include "llvm/CodeGen/IntrinsicLowering.h"
24 #include "llvm/IR/CFG.h"
25 #include "llvm/IR/Dominators.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Intrinsics.h"
29 #include "llvm/IR/IntrinsicsSPIRV.h"
30 #include "llvm/InitializePasses.h"
31 #include "llvm/Transforms/Utils/Cloning.h"
32 #include "llvm/Transforms/Utils/LoopSimplify.h"
33 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
34
35 using namespace llvm;
36
37 namespace llvm {
38 void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
39
40 class SPIRVMergeRegionExitTargets : public FunctionPass {
41 public:
42 static char ID;
43
SPIRVMergeRegionExitTargets()44 SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
45 initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
46 };
47
48 // Gather all the successors of |BB|.
49 // This function asserts if the terminator neither a branch, switch or return.
gatherSuccessors(BasicBlock * BB)50 std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
51 std::unordered_set<BasicBlock *> output;
52 auto *T = BB->getTerminator();
53
54 if (auto *BI = dyn_cast<BranchInst>(T)) {
55 output.insert(BI->getSuccessor(0));
56 if (BI->isConditional())
57 output.insert(BI->getSuccessor(1));
58 return output;
59 }
60
61 if (auto *SI = dyn_cast<SwitchInst>(T)) {
62 output.insert(SI->getDefaultDest());
63 for (auto &Case : SI->cases())
64 output.insert(Case.getCaseSuccessor());
65 return output;
66 }
67
68 assert(isa<ReturnInst>(T) && "Unhandled terminator type.");
69 return output;
70 }
71
72 /// Create a value in BB set to the value associated with the branch the block
73 /// terminator will take.
createExitVariable(BasicBlock * BB,const DenseMap<BasicBlock *,ConstantInt * > & TargetToValue)74 llvm::Value *createExitVariable(
75 BasicBlock *BB,
76 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
77 auto *T = BB->getTerminator();
78 if (isa<ReturnInst>(T))
79 return nullptr;
80
81 IRBuilder<> Builder(BB);
82 Builder.SetInsertPoint(T);
83
84 if (auto *BI = dyn_cast<BranchInst>(T)) {
85
86 BasicBlock *LHSTarget = BI->getSuccessor(0);
87 BasicBlock *RHSTarget =
88 BI->isConditional() ? BI->getSuccessor(1) : nullptr;
89
90 Value *LHS = TargetToValue.count(LHSTarget) != 0
91 ? TargetToValue.at(LHSTarget)
92 : nullptr;
93 Value *RHS = TargetToValue.count(RHSTarget) != 0
94 ? TargetToValue.at(RHSTarget)
95 : nullptr;
96
97 if (LHS == nullptr || RHS == nullptr)
98 return LHS == nullptr ? RHS : LHS;
99 return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
100 }
101
102 // TODO: add support for switch cases.
103 llvm_unreachable("Unhandled terminator type.");
104 }
105
106 /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
replaceBranchTargets(BasicBlock * BB,const SmallPtrSet<BasicBlock *,4> & ToReplace,BasicBlock * NewTarget)107 void replaceBranchTargets(BasicBlock *BB,
108 const SmallPtrSet<BasicBlock *, 4> &ToReplace,
109 BasicBlock *NewTarget) {
110 auto *T = BB->getTerminator();
111 if (isa<ReturnInst>(T))
112 return;
113
114 if (auto *BI = dyn_cast<BranchInst>(T)) {
115 for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
116 if (ToReplace.count(BI->getSuccessor(i)) != 0)
117 BI->setSuccessor(i, NewTarget);
118 }
119 return;
120 }
121
122 if (auto *SI = dyn_cast<SwitchInst>(T)) {
123 for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
124 if (ToReplace.count(SI->getSuccessor(i)) != 0)
125 SI->setSuccessor(i, NewTarget);
126 }
127 return;
128 }
129
130 assert(false && "Unhandled terminator type.");
131 }
132
133 // Run the pass on the given convergence region, ignoring the sub-regions.
134 // Returns true if the CFG changed, false otherwise.
runOnConvergenceRegionNoRecurse(LoopInfo & LI,const SPIRV::ConvergenceRegion * CR)135 bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
136 const SPIRV::ConvergenceRegion *CR) {
137 // Gather all the exit targets for this region.
138 SmallPtrSet<BasicBlock *, 4> ExitTargets;
139 for (BasicBlock *Exit : CR->Exits) {
140 for (BasicBlock *Target : gatherSuccessors(Exit)) {
141 if (CR->Blocks.count(Target) == 0)
142 ExitTargets.insert(Target);
143 }
144 }
145
146 // If we have zero or one exit target, nothing do to.
147 if (ExitTargets.size() <= 1)
148 return false;
149
150 // Create the new single exit target.
151 auto F = CR->Entry->getParent();
152 auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
153 IRBuilder<> Builder(NewExitTarget);
154
155 // CodeGen output needs to be stable. Using the set as-is would order
156 // the targets differently depending on the allocation pattern.
157 // Sorting per basic-block ordering in the function.
158 std::vector<BasicBlock *> SortedExitTargets;
159 std::vector<BasicBlock *> SortedExits;
160 for (BasicBlock &BB : *F) {
161 if (ExitTargets.count(&BB) != 0)
162 SortedExitTargets.push_back(&BB);
163 if (CR->Exits.count(&BB) != 0)
164 SortedExits.push_back(&BB);
165 }
166
167 // Creating one constant per distinct exit target. This will be route to the
168 // correct target.
169 DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
170 for (BasicBlock *Target : SortedExitTargets)
171 TargetToValue.insert(
172 std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
173
174 // Creating one variable per exit node, set to the constant matching the
175 // targeted external block.
176 std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
177 for (auto Exit : SortedExits) {
178 llvm::Value *Value = createExitVariable(Exit, TargetToValue);
179 ExitToVariable.emplace_back(std::make_pair(Exit, Value));
180 }
181
182 // Gather the correct value depending on the exit we came from.
183 llvm::PHINode *node =
184 Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
185 for (auto [BB, Value] : ExitToVariable) {
186 node->addIncoming(Value, BB);
187 }
188
189 // Creating the switch to jump to the correct exit target.
190 llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
191 SortedExitTargets.size() - 1);
192 for (size_t i = 1; i < SortedExitTargets.size(); i++) {
193 BasicBlock *BB = SortedExitTargets[i];
194 Sw->addCase(TargetToValue[BB], BB);
195 }
196
197 // Fix exit branches to redirect to the new exit.
198 for (auto Exit : CR->Exits)
199 replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
200
201 return true;
202 }
203
204 /// Run the pass on the given convergence region and sub-regions (DFS).
205 /// Returns true if a region/sub-region was modified, false otherwise.
206 /// This returns as soon as one region/sub-region has been modified.
runOnConvergenceRegion(LoopInfo & LI,const SPIRV::ConvergenceRegion * CR)207 bool runOnConvergenceRegion(LoopInfo &LI,
208 const SPIRV::ConvergenceRegion *CR) {
209 for (auto *Child : CR->Children)
210 if (runOnConvergenceRegion(LI, Child))
211 return true;
212
213 return runOnConvergenceRegionNoRecurse(LI, CR);
214 }
215
216 #if !NDEBUG
217 /// Validates each edge exiting the region has the same destination basic
218 /// block.
validateRegionExits(const SPIRV::ConvergenceRegion * CR)219 void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
220 for (auto *Child : CR->Children)
221 validateRegionExits(Child);
222
223 std::unordered_set<BasicBlock *> ExitTargets;
224 for (auto *Exit : CR->Exits) {
225 auto Set = gatherSuccessors(Exit);
226 for (auto *BB : Set) {
227 if (CR->Blocks.count(BB) == 0)
228 ExitTargets.insert(BB);
229 }
230 }
231
232 assert(ExitTargets.size() <= 1);
233 }
234 #endif
235
runOnFunction(Function & F)236 virtual bool runOnFunction(Function &F) override {
237 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
238 const auto *TopLevelRegion =
239 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
240 .getRegionInfo()
241 .getTopLevelRegion();
242
243 // FIXME: very inefficient method: each time a region is modified, we bubble
244 // back up, and recompute the whole convergence region tree. Once the
245 // algorithm is completed and test coverage good enough, rewrite this pass
246 // to be efficient instead of simple.
247 bool modified = false;
248 while (runOnConvergenceRegion(LI, TopLevelRegion)) {
249 TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
250 .getRegionInfo()
251 .getTopLevelRegion();
252 modified = true;
253 }
254
255 #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
256 validateRegionExits(TopLevelRegion);
257 #endif
258 return modified;
259 }
260
getAnalysisUsage(AnalysisUsage & AU) const261 void getAnalysisUsage(AnalysisUsage &AU) const override {
262 AU.addRequired<DominatorTreeWrapperPass>();
263 AU.addRequired<LoopInfoWrapperPass>();
264 AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
265 FunctionPass::getAnalysisUsage(AU);
266 }
267 };
268 } // namespace llvm
269
270 char SPIRVMergeRegionExitTargets::ID = 0;
271
272 INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
273 "SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)274 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
275 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
276 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
277 INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
278
279 INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
280 "SPIRV split region exit blocks", false, false)
281
282 FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
283 return new SPIRVMergeRegionExitTargets();
284 }
285