1*0fca6ea1SDimitry Andric //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2*0fca6ea1SDimitry Andric //
3*0fca6ea1SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*0fca6ea1SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*0fca6ea1SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*0fca6ea1SDimitry Andric //
7*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===//
8*0fca6ea1SDimitry Andric //
9*0fca6ea1SDimitry Andric // Merge the multiple exit targets of a convergence region into a single block.
10*0fca6ea1SDimitry Andric // Each exit target will be assigned a constant value, and a phi node + switch
11*0fca6ea1SDimitry Andric // will allow the new exit target to re-route to the correct basic block.
12*0fca6ea1SDimitry Andric //
13*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===//
14*0fca6ea1SDimitry Andric
15*0fca6ea1SDimitry Andric #include "Analysis/SPIRVConvergenceRegionAnalysis.h"
16*0fca6ea1SDimitry Andric #include "SPIRV.h"
17*0fca6ea1SDimitry Andric #include "SPIRVSubtarget.h"
18*0fca6ea1SDimitry Andric #include "SPIRVTargetMachine.h"
19*0fca6ea1SDimitry Andric #include "SPIRVUtils.h"
20*0fca6ea1SDimitry Andric #include "llvm/ADT/DenseMap.h"
21*0fca6ea1SDimitry Andric #include "llvm/ADT/SmallPtrSet.h"
22*0fca6ea1SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
23*0fca6ea1SDimitry Andric #include "llvm/CodeGen/IntrinsicLowering.h"
24*0fca6ea1SDimitry Andric #include "llvm/IR/CFG.h"
25*0fca6ea1SDimitry Andric #include "llvm/IR/Dominators.h"
26*0fca6ea1SDimitry Andric #include "llvm/IR/IRBuilder.h"
27*0fca6ea1SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
28*0fca6ea1SDimitry Andric #include "llvm/IR/Intrinsics.h"
29*0fca6ea1SDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h"
30*0fca6ea1SDimitry Andric #include "llvm/InitializePasses.h"
31*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h"
32*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/LoopSimplify.h"
33*0fca6ea1SDimitry Andric #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
34*0fca6ea1SDimitry Andric
35*0fca6ea1SDimitry Andric using namespace llvm;
36*0fca6ea1SDimitry Andric
37*0fca6ea1SDimitry Andric namespace llvm {
38*0fca6ea1SDimitry Andric void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
39*0fca6ea1SDimitry Andric
40*0fca6ea1SDimitry Andric class SPIRVMergeRegionExitTargets : public FunctionPass {
41*0fca6ea1SDimitry Andric public:
42*0fca6ea1SDimitry Andric static char ID;
43*0fca6ea1SDimitry Andric
SPIRVMergeRegionExitTargets()44*0fca6ea1SDimitry Andric SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
45*0fca6ea1SDimitry Andric initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
46*0fca6ea1SDimitry Andric };
47*0fca6ea1SDimitry Andric
48*0fca6ea1SDimitry Andric // Gather all the successors of |BB|.
49*0fca6ea1SDimitry Andric // This function asserts if the terminator neither a branch, switch or return.
gatherSuccessors(BasicBlock * BB)50*0fca6ea1SDimitry Andric std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
51*0fca6ea1SDimitry Andric std::unordered_set<BasicBlock *> output;
52*0fca6ea1SDimitry Andric auto *T = BB->getTerminator();
53*0fca6ea1SDimitry Andric
54*0fca6ea1SDimitry Andric if (auto *BI = dyn_cast<BranchInst>(T)) {
55*0fca6ea1SDimitry Andric output.insert(BI->getSuccessor(0));
56*0fca6ea1SDimitry Andric if (BI->isConditional())
57*0fca6ea1SDimitry Andric output.insert(BI->getSuccessor(1));
58*0fca6ea1SDimitry Andric return output;
59*0fca6ea1SDimitry Andric }
60*0fca6ea1SDimitry Andric
61*0fca6ea1SDimitry Andric if (auto *SI = dyn_cast<SwitchInst>(T)) {
62*0fca6ea1SDimitry Andric output.insert(SI->getDefaultDest());
63*0fca6ea1SDimitry Andric for (auto &Case : SI->cases())
64*0fca6ea1SDimitry Andric output.insert(Case.getCaseSuccessor());
65*0fca6ea1SDimitry Andric return output;
66*0fca6ea1SDimitry Andric }
67*0fca6ea1SDimitry Andric
68*0fca6ea1SDimitry Andric assert(isa<ReturnInst>(T) && "Unhandled terminator type.");
69*0fca6ea1SDimitry Andric return output;
70*0fca6ea1SDimitry Andric }
71*0fca6ea1SDimitry Andric
72*0fca6ea1SDimitry Andric /// Create a value in BB set to the value associated with the branch the block
73*0fca6ea1SDimitry Andric /// terminator will take.
createExitVariable(BasicBlock * BB,const DenseMap<BasicBlock *,ConstantInt * > & TargetToValue)74*0fca6ea1SDimitry Andric llvm::Value *createExitVariable(
75*0fca6ea1SDimitry Andric BasicBlock *BB,
76*0fca6ea1SDimitry Andric const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
77*0fca6ea1SDimitry Andric auto *T = BB->getTerminator();
78*0fca6ea1SDimitry Andric if (isa<ReturnInst>(T))
79*0fca6ea1SDimitry Andric return nullptr;
80*0fca6ea1SDimitry Andric
81*0fca6ea1SDimitry Andric IRBuilder<> Builder(BB);
82*0fca6ea1SDimitry Andric Builder.SetInsertPoint(T);
83*0fca6ea1SDimitry Andric
84*0fca6ea1SDimitry Andric if (auto *BI = dyn_cast<BranchInst>(T)) {
85*0fca6ea1SDimitry Andric
86*0fca6ea1SDimitry Andric BasicBlock *LHSTarget = BI->getSuccessor(0);
87*0fca6ea1SDimitry Andric BasicBlock *RHSTarget =
88*0fca6ea1SDimitry Andric BI->isConditional() ? BI->getSuccessor(1) : nullptr;
89*0fca6ea1SDimitry Andric
90*0fca6ea1SDimitry Andric Value *LHS = TargetToValue.count(LHSTarget) != 0
91*0fca6ea1SDimitry Andric ? TargetToValue.at(LHSTarget)
92*0fca6ea1SDimitry Andric : nullptr;
93*0fca6ea1SDimitry Andric Value *RHS = TargetToValue.count(RHSTarget) != 0
94*0fca6ea1SDimitry Andric ? TargetToValue.at(RHSTarget)
95*0fca6ea1SDimitry Andric : nullptr;
96*0fca6ea1SDimitry Andric
97*0fca6ea1SDimitry Andric if (LHS == nullptr || RHS == nullptr)
98*0fca6ea1SDimitry Andric return LHS == nullptr ? RHS : LHS;
99*0fca6ea1SDimitry Andric return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
100*0fca6ea1SDimitry Andric }
101*0fca6ea1SDimitry Andric
102*0fca6ea1SDimitry Andric // TODO: add support for switch cases.
103*0fca6ea1SDimitry Andric llvm_unreachable("Unhandled terminator type.");
104*0fca6ea1SDimitry Andric }
105*0fca6ea1SDimitry Andric
106*0fca6ea1SDimitry Andric /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
replaceBranchTargets(BasicBlock * BB,const SmallPtrSet<BasicBlock *,4> & ToReplace,BasicBlock * NewTarget)107*0fca6ea1SDimitry Andric void replaceBranchTargets(BasicBlock *BB,
108*0fca6ea1SDimitry Andric const SmallPtrSet<BasicBlock *, 4> &ToReplace,
109*0fca6ea1SDimitry Andric BasicBlock *NewTarget) {
110*0fca6ea1SDimitry Andric auto *T = BB->getTerminator();
111*0fca6ea1SDimitry Andric if (isa<ReturnInst>(T))
112*0fca6ea1SDimitry Andric return;
113*0fca6ea1SDimitry Andric
114*0fca6ea1SDimitry Andric if (auto *BI = dyn_cast<BranchInst>(T)) {
115*0fca6ea1SDimitry Andric for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
116*0fca6ea1SDimitry Andric if (ToReplace.count(BI->getSuccessor(i)) != 0)
117*0fca6ea1SDimitry Andric BI->setSuccessor(i, NewTarget);
118*0fca6ea1SDimitry Andric }
119*0fca6ea1SDimitry Andric return;
120*0fca6ea1SDimitry Andric }
121*0fca6ea1SDimitry Andric
122*0fca6ea1SDimitry Andric if (auto *SI = dyn_cast<SwitchInst>(T)) {
123*0fca6ea1SDimitry Andric for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
124*0fca6ea1SDimitry Andric if (ToReplace.count(SI->getSuccessor(i)) != 0)
125*0fca6ea1SDimitry Andric SI->setSuccessor(i, NewTarget);
126*0fca6ea1SDimitry Andric }
127*0fca6ea1SDimitry Andric return;
128*0fca6ea1SDimitry Andric }
129*0fca6ea1SDimitry Andric
130*0fca6ea1SDimitry Andric assert(false && "Unhandled terminator type.");
131*0fca6ea1SDimitry Andric }
132*0fca6ea1SDimitry Andric
133*0fca6ea1SDimitry Andric // Run the pass on the given convergence region, ignoring the sub-regions.
134*0fca6ea1SDimitry Andric // Returns true if the CFG changed, false otherwise.
runOnConvergenceRegionNoRecurse(LoopInfo & LI,const SPIRV::ConvergenceRegion * CR)135*0fca6ea1SDimitry Andric bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
136*0fca6ea1SDimitry Andric const SPIRV::ConvergenceRegion *CR) {
137*0fca6ea1SDimitry Andric // Gather all the exit targets for this region.
138*0fca6ea1SDimitry Andric SmallPtrSet<BasicBlock *, 4> ExitTargets;
139*0fca6ea1SDimitry Andric for (BasicBlock *Exit : CR->Exits) {
140*0fca6ea1SDimitry Andric for (BasicBlock *Target : gatherSuccessors(Exit)) {
141*0fca6ea1SDimitry Andric if (CR->Blocks.count(Target) == 0)
142*0fca6ea1SDimitry Andric ExitTargets.insert(Target);
143*0fca6ea1SDimitry Andric }
144*0fca6ea1SDimitry Andric }
145*0fca6ea1SDimitry Andric
146*0fca6ea1SDimitry Andric // If we have zero or one exit target, nothing do to.
147*0fca6ea1SDimitry Andric if (ExitTargets.size() <= 1)
148*0fca6ea1SDimitry Andric return false;
149*0fca6ea1SDimitry Andric
150*0fca6ea1SDimitry Andric // Create the new single exit target.
151*0fca6ea1SDimitry Andric auto F = CR->Entry->getParent();
152*0fca6ea1SDimitry Andric auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
153*0fca6ea1SDimitry Andric IRBuilder<> Builder(NewExitTarget);
154*0fca6ea1SDimitry Andric
155*0fca6ea1SDimitry Andric // CodeGen output needs to be stable. Using the set as-is would order
156*0fca6ea1SDimitry Andric // the targets differently depending on the allocation pattern.
157*0fca6ea1SDimitry Andric // Sorting per basic-block ordering in the function.
158*0fca6ea1SDimitry Andric std::vector<BasicBlock *> SortedExitTargets;
159*0fca6ea1SDimitry Andric std::vector<BasicBlock *> SortedExits;
160*0fca6ea1SDimitry Andric for (BasicBlock &BB : *F) {
161*0fca6ea1SDimitry Andric if (ExitTargets.count(&BB) != 0)
162*0fca6ea1SDimitry Andric SortedExitTargets.push_back(&BB);
163*0fca6ea1SDimitry Andric if (CR->Exits.count(&BB) != 0)
164*0fca6ea1SDimitry Andric SortedExits.push_back(&BB);
165*0fca6ea1SDimitry Andric }
166*0fca6ea1SDimitry Andric
167*0fca6ea1SDimitry Andric // Creating one constant per distinct exit target. This will be route to the
168*0fca6ea1SDimitry Andric // correct target.
169*0fca6ea1SDimitry Andric DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
170*0fca6ea1SDimitry Andric for (BasicBlock *Target : SortedExitTargets)
171*0fca6ea1SDimitry Andric TargetToValue.insert(
172*0fca6ea1SDimitry Andric std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
173*0fca6ea1SDimitry Andric
174*0fca6ea1SDimitry Andric // Creating one variable per exit node, set to the constant matching the
175*0fca6ea1SDimitry Andric // targeted external block.
176*0fca6ea1SDimitry Andric std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
177*0fca6ea1SDimitry Andric for (auto Exit : SortedExits) {
178*0fca6ea1SDimitry Andric llvm::Value *Value = createExitVariable(Exit, TargetToValue);
179*0fca6ea1SDimitry Andric ExitToVariable.emplace_back(std::make_pair(Exit, Value));
180*0fca6ea1SDimitry Andric }
181*0fca6ea1SDimitry Andric
182*0fca6ea1SDimitry Andric // Gather the correct value depending on the exit we came from.
183*0fca6ea1SDimitry Andric llvm::PHINode *node =
184*0fca6ea1SDimitry Andric Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
185*0fca6ea1SDimitry Andric for (auto [BB, Value] : ExitToVariable) {
186*0fca6ea1SDimitry Andric node->addIncoming(Value, BB);
187*0fca6ea1SDimitry Andric }
188*0fca6ea1SDimitry Andric
189*0fca6ea1SDimitry Andric // Creating the switch to jump to the correct exit target.
190*0fca6ea1SDimitry Andric llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
191*0fca6ea1SDimitry Andric SortedExitTargets.size() - 1);
192*0fca6ea1SDimitry Andric for (size_t i = 1; i < SortedExitTargets.size(); i++) {
193*0fca6ea1SDimitry Andric BasicBlock *BB = SortedExitTargets[i];
194*0fca6ea1SDimitry Andric Sw->addCase(TargetToValue[BB], BB);
195*0fca6ea1SDimitry Andric }
196*0fca6ea1SDimitry Andric
197*0fca6ea1SDimitry Andric // Fix exit branches to redirect to the new exit.
198*0fca6ea1SDimitry Andric for (auto Exit : CR->Exits)
199*0fca6ea1SDimitry Andric replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
200*0fca6ea1SDimitry Andric
201*0fca6ea1SDimitry Andric return true;
202*0fca6ea1SDimitry Andric }
203*0fca6ea1SDimitry Andric
204*0fca6ea1SDimitry Andric /// Run the pass on the given convergence region and sub-regions (DFS).
205*0fca6ea1SDimitry Andric /// Returns true if a region/sub-region was modified, false otherwise.
206*0fca6ea1SDimitry Andric /// This returns as soon as one region/sub-region has been modified.
runOnConvergenceRegion(LoopInfo & LI,const SPIRV::ConvergenceRegion * CR)207*0fca6ea1SDimitry Andric bool runOnConvergenceRegion(LoopInfo &LI,
208*0fca6ea1SDimitry Andric const SPIRV::ConvergenceRegion *CR) {
209*0fca6ea1SDimitry Andric for (auto *Child : CR->Children)
210*0fca6ea1SDimitry Andric if (runOnConvergenceRegion(LI, Child))
211*0fca6ea1SDimitry Andric return true;
212*0fca6ea1SDimitry Andric
213*0fca6ea1SDimitry Andric return runOnConvergenceRegionNoRecurse(LI, CR);
214*0fca6ea1SDimitry Andric }
215*0fca6ea1SDimitry Andric
216*0fca6ea1SDimitry Andric #if !NDEBUG
217*0fca6ea1SDimitry Andric /// Validates each edge exiting the region has the same destination basic
218*0fca6ea1SDimitry Andric /// block.
validateRegionExits(const SPIRV::ConvergenceRegion * CR)219*0fca6ea1SDimitry Andric void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
220*0fca6ea1SDimitry Andric for (auto *Child : CR->Children)
221*0fca6ea1SDimitry Andric validateRegionExits(Child);
222*0fca6ea1SDimitry Andric
223*0fca6ea1SDimitry Andric std::unordered_set<BasicBlock *> ExitTargets;
224*0fca6ea1SDimitry Andric for (auto *Exit : CR->Exits) {
225*0fca6ea1SDimitry Andric auto Set = gatherSuccessors(Exit);
226*0fca6ea1SDimitry Andric for (auto *BB : Set) {
227*0fca6ea1SDimitry Andric if (CR->Blocks.count(BB) == 0)
228*0fca6ea1SDimitry Andric ExitTargets.insert(BB);
229*0fca6ea1SDimitry Andric }
230*0fca6ea1SDimitry Andric }
231*0fca6ea1SDimitry Andric
232*0fca6ea1SDimitry Andric assert(ExitTargets.size() <= 1);
233*0fca6ea1SDimitry Andric }
234*0fca6ea1SDimitry Andric #endif
235*0fca6ea1SDimitry Andric
runOnFunction(Function & F)236*0fca6ea1SDimitry Andric virtual bool runOnFunction(Function &F) override {
237*0fca6ea1SDimitry Andric LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
238*0fca6ea1SDimitry Andric const auto *TopLevelRegion =
239*0fca6ea1SDimitry Andric getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
240*0fca6ea1SDimitry Andric .getRegionInfo()
241*0fca6ea1SDimitry Andric .getTopLevelRegion();
242*0fca6ea1SDimitry Andric
243*0fca6ea1SDimitry Andric // FIXME: very inefficient method: each time a region is modified, we bubble
244*0fca6ea1SDimitry Andric // back up, and recompute the whole convergence region tree. Once the
245*0fca6ea1SDimitry Andric // algorithm is completed and test coverage good enough, rewrite this pass
246*0fca6ea1SDimitry Andric // to be efficient instead of simple.
247*0fca6ea1SDimitry Andric bool modified = false;
248*0fca6ea1SDimitry Andric while (runOnConvergenceRegion(LI, TopLevelRegion)) {
249*0fca6ea1SDimitry Andric TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
250*0fca6ea1SDimitry Andric .getRegionInfo()
251*0fca6ea1SDimitry Andric .getTopLevelRegion();
252*0fca6ea1SDimitry Andric modified = true;
253*0fca6ea1SDimitry Andric }
254*0fca6ea1SDimitry Andric
255*0fca6ea1SDimitry Andric #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
256*0fca6ea1SDimitry Andric validateRegionExits(TopLevelRegion);
257*0fca6ea1SDimitry Andric #endif
258*0fca6ea1SDimitry Andric return modified;
259*0fca6ea1SDimitry Andric }
260*0fca6ea1SDimitry Andric
getAnalysisUsage(AnalysisUsage & AU) const261*0fca6ea1SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override {
262*0fca6ea1SDimitry Andric AU.addRequired<DominatorTreeWrapperPass>();
263*0fca6ea1SDimitry Andric AU.addRequired<LoopInfoWrapperPass>();
264*0fca6ea1SDimitry Andric AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
265*0fca6ea1SDimitry Andric FunctionPass::getAnalysisUsage(AU);
266*0fca6ea1SDimitry Andric }
267*0fca6ea1SDimitry Andric };
268*0fca6ea1SDimitry Andric } // namespace llvm
269*0fca6ea1SDimitry Andric
270*0fca6ea1SDimitry Andric char SPIRVMergeRegionExitTargets::ID = 0;
271*0fca6ea1SDimitry Andric
272*0fca6ea1SDimitry Andric INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
273*0fca6ea1SDimitry Andric "SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)274*0fca6ea1SDimitry Andric INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
275*0fca6ea1SDimitry Andric INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
276*0fca6ea1SDimitry Andric INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
277*0fca6ea1SDimitry Andric INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
278*0fca6ea1SDimitry Andric
279*0fca6ea1SDimitry Andric INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
280*0fca6ea1SDimitry Andric "SPIRV split region exit blocks", false, false)
281*0fca6ea1SDimitry Andric
282*0fca6ea1SDimitry Andric FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
283*0fca6ea1SDimitry Andric return new SPIRVMergeRegionExitTargets();
284*0fca6ea1SDimitry Andric }
285