xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1*700637cbSDimitry Andric //===-- SPIRVStructurizer.cpp ----------------------*- C++ -*-===//
2*700637cbSDimitry Andric //
3*700637cbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*700637cbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*700637cbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*700637cbSDimitry Andric //
7*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
8*700637cbSDimitry Andric //
9*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
10*700637cbSDimitry Andric 
11*700637cbSDimitry Andric #include "Analysis/SPIRVConvergenceRegionAnalysis.h"
12*700637cbSDimitry Andric #include "SPIRV.h"
13*700637cbSDimitry Andric #include "SPIRVStructurizerWrapper.h"
14*700637cbSDimitry Andric #include "SPIRVSubtarget.h"
15*700637cbSDimitry Andric #include "SPIRVUtils.h"
16*700637cbSDimitry Andric #include "llvm/ADT/DenseMap.h"
17*700637cbSDimitry Andric #include "llvm/ADT/SmallPtrSet.h"
18*700637cbSDimitry Andric #include "llvm/Analysis/LoopInfo.h"
19*700637cbSDimitry Andric #include "llvm/CodeGen/IntrinsicLowering.h"
20*700637cbSDimitry Andric #include "llvm/IR/CFG.h"
21*700637cbSDimitry Andric #include "llvm/IR/Dominators.h"
22*700637cbSDimitry Andric #include "llvm/IR/IRBuilder.h"
23*700637cbSDimitry Andric #include "llvm/IR/IntrinsicInst.h"
24*700637cbSDimitry Andric #include "llvm/IR/Intrinsics.h"
25*700637cbSDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h"
26*700637cbSDimitry Andric #include "llvm/IR/LegacyPassManager.h"
27*700637cbSDimitry Andric #include "llvm/InitializePasses.h"
28*700637cbSDimitry Andric #include "llvm/Transforms/Utils.h"
29*700637cbSDimitry Andric #include "llvm/Transforms/Utils/Cloning.h"
30*700637cbSDimitry Andric #include "llvm/Transforms/Utils/LoopSimplify.h"
31*700637cbSDimitry Andric #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
32*700637cbSDimitry Andric #include <stack>
33*700637cbSDimitry Andric #include <unordered_set>
34*700637cbSDimitry Andric 
35*700637cbSDimitry Andric using namespace llvm;
36*700637cbSDimitry Andric using namespace SPIRV;
37*700637cbSDimitry Andric 
38*700637cbSDimitry Andric using BlockSet = std::unordered_set<BasicBlock *>;
39*700637cbSDimitry Andric using Edge = std::pair<BasicBlock *, BasicBlock *>;
40*700637cbSDimitry Andric 
41*700637cbSDimitry Andric // Helper function to do a partial order visit from the block |Start|, calling
42*700637cbSDimitry Andric // |Op| on each visited node.
partialOrderVisit(BasicBlock & Start,std::function<bool (BasicBlock *)> Op)43*700637cbSDimitry Andric static void partialOrderVisit(BasicBlock &Start,
44*700637cbSDimitry Andric                               std::function<bool(BasicBlock *)> Op) {
45*700637cbSDimitry Andric   PartialOrderingVisitor V(*Start.getParent());
46*700637cbSDimitry Andric   V.partialOrderVisit(Start, Op);
47*700637cbSDimitry Andric }
48*700637cbSDimitry Andric 
49*700637cbSDimitry Andric // Returns the exact convergence region in the tree defined by `Node` for which
50*700637cbSDimitry Andric // `BB` is the header, nullptr otherwise.
51*700637cbSDimitry Andric static const ConvergenceRegion *
getRegionForHeader(const ConvergenceRegion * Node,BasicBlock * BB)52*700637cbSDimitry Andric getRegionForHeader(const ConvergenceRegion *Node, BasicBlock *BB) {
53*700637cbSDimitry Andric   if (Node->Entry == BB)
54*700637cbSDimitry Andric     return Node;
55*700637cbSDimitry Andric 
56*700637cbSDimitry Andric   for (auto *Child : Node->Children) {
57*700637cbSDimitry Andric     const auto *CR = getRegionForHeader(Child, BB);
58*700637cbSDimitry Andric     if (CR != nullptr)
59*700637cbSDimitry Andric       return CR;
60*700637cbSDimitry Andric   }
61*700637cbSDimitry Andric   return nullptr;
62*700637cbSDimitry Andric }
63*700637cbSDimitry Andric 
64*700637cbSDimitry Andric // Returns the single BasicBlock exiting the convergence region `CR`,
65*700637cbSDimitry Andric // nullptr if no such exit exists.
getExitFor(const ConvergenceRegion * CR)66*700637cbSDimitry Andric static BasicBlock *getExitFor(const ConvergenceRegion *CR) {
67*700637cbSDimitry Andric   std::unordered_set<BasicBlock *> ExitTargets;
68*700637cbSDimitry Andric   for (BasicBlock *Exit : CR->Exits) {
69*700637cbSDimitry Andric     for (BasicBlock *Successor : successors(Exit)) {
70*700637cbSDimitry Andric       if (CR->Blocks.count(Successor) == 0)
71*700637cbSDimitry Andric         ExitTargets.insert(Successor);
72*700637cbSDimitry Andric     }
73*700637cbSDimitry Andric   }
74*700637cbSDimitry Andric 
75*700637cbSDimitry Andric   assert(ExitTargets.size() <= 1);
76*700637cbSDimitry Andric   if (ExitTargets.size() == 0)
77*700637cbSDimitry Andric     return nullptr;
78*700637cbSDimitry Andric 
79*700637cbSDimitry Andric   return *ExitTargets.begin();
80*700637cbSDimitry Andric }
81*700637cbSDimitry Andric 
82*700637cbSDimitry Andric // Returns the merge block designated by I if I is a merge instruction, nullptr
83*700637cbSDimitry Andric // otherwise.
getDesignatedMergeBlock(Instruction * I)84*700637cbSDimitry Andric static BasicBlock *getDesignatedMergeBlock(Instruction *I) {
85*700637cbSDimitry Andric   IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(I);
86*700637cbSDimitry Andric   if (II == nullptr)
87*700637cbSDimitry Andric     return nullptr;
88*700637cbSDimitry Andric 
89*700637cbSDimitry Andric   if (II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
90*700637cbSDimitry Andric       II->getIntrinsicID() != Intrinsic::spv_selection_merge)
91*700637cbSDimitry Andric     return nullptr;
92*700637cbSDimitry Andric 
93*700637cbSDimitry Andric   BlockAddress *BA = cast<BlockAddress>(II->getOperand(0));
94*700637cbSDimitry Andric   return BA->getBasicBlock();
95*700637cbSDimitry Andric }
96*700637cbSDimitry Andric 
97*700637cbSDimitry Andric // Returns the continue block designated by I if I is an OpLoopMerge, nullptr
98*700637cbSDimitry Andric // otherwise.
getDesignatedContinueBlock(Instruction * I)99*700637cbSDimitry Andric static BasicBlock *getDesignatedContinueBlock(Instruction *I) {
100*700637cbSDimitry Andric   IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(I);
101*700637cbSDimitry Andric   if (II == nullptr)
102*700637cbSDimitry Andric     return nullptr;
103*700637cbSDimitry Andric 
104*700637cbSDimitry Andric   if (II->getIntrinsicID() != Intrinsic::spv_loop_merge)
105*700637cbSDimitry Andric     return nullptr;
106*700637cbSDimitry Andric 
107*700637cbSDimitry Andric   BlockAddress *BA = cast<BlockAddress>(II->getOperand(1));
108*700637cbSDimitry Andric   return BA->getBasicBlock();
109*700637cbSDimitry Andric }
110*700637cbSDimitry Andric 
111*700637cbSDimitry Andric // Returns true if Header has one merge instruction which designated Merge as
112*700637cbSDimitry Andric // merge block.
isDefinedAsSelectionMergeBy(BasicBlock & Header,BasicBlock & Merge)113*700637cbSDimitry Andric static bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge) {
114*700637cbSDimitry Andric   for (auto &I : Header) {
115*700637cbSDimitry Andric     BasicBlock *MB = getDesignatedMergeBlock(&I);
116*700637cbSDimitry Andric     if (MB == &Merge)
117*700637cbSDimitry Andric       return true;
118*700637cbSDimitry Andric   }
119*700637cbSDimitry Andric   return false;
120*700637cbSDimitry Andric }
121*700637cbSDimitry Andric 
122*700637cbSDimitry Andric // Returns true if the BB has one OpLoopMerge instruction.
hasLoopMergeInstruction(BasicBlock & BB)123*700637cbSDimitry Andric static bool hasLoopMergeInstruction(BasicBlock &BB) {
124*700637cbSDimitry Andric   for (auto &I : BB)
125*700637cbSDimitry Andric     if (getDesignatedContinueBlock(&I))
126*700637cbSDimitry Andric       return true;
127*700637cbSDimitry Andric   return false;
128*700637cbSDimitry Andric }
129*700637cbSDimitry Andric 
130*700637cbSDimitry Andric // Returns true is I is an OpSelectionMerge or OpLoopMerge instruction, false
131*700637cbSDimitry Andric // otherwise.
isMergeInstruction(Instruction * I)132*700637cbSDimitry Andric static bool isMergeInstruction(Instruction *I) {
133*700637cbSDimitry Andric   return getDesignatedMergeBlock(I) != nullptr;
134*700637cbSDimitry Andric }
135*700637cbSDimitry Andric 
136*700637cbSDimitry Andric // Returns all blocks in F having at least one OpLoopMerge or OpSelectionMerge
137*700637cbSDimitry Andric // instruction.
getHeaderBlocks(Function & F)138*700637cbSDimitry Andric static SmallPtrSet<BasicBlock *, 2> getHeaderBlocks(Function &F) {
139*700637cbSDimitry Andric   SmallPtrSet<BasicBlock *, 2> Output;
140*700637cbSDimitry Andric   for (BasicBlock &BB : F) {
141*700637cbSDimitry Andric     for (Instruction &I : BB) {
142*700637cbSDimitry Andric       if (getDesignatedMergeBlock(&I) != nullptr)
143*700637cbSDimitry Andric         Output.insert(&BB);
144*700637cbSDimitry Andric     }
145*700637cbSDimitry Andric   }
146*700637cbSDimitry Andric   return Output;
147*700637cbSDimitry Andric }
148*700637cbSDimitry Andric 
149*700637cbSDimitry Andric // Returns all basic blocks in |F| referenced by at least 1
150*700637cbSDimitry Andric // OpSelectionMerge/OpLoopMerge instruction.
getMergeBlocks(Function & F)151*700637cbSDimitry Andric static SmallPtrSet<BasicBlock *, 2> getMergeBlocks(Function &F) {
152*700637cbSDimitry Andric   SmallPtrSet<BasicBlock *, 2> Output;
153*700637cbSDimitry Andric   for (BasicBlock &BB : F) {
154*700637cbSDimitry Andric     for (Instruction &I : BB) {
155*700637cbSDimitry Andric       BasicBlock *MB = getDesignatedMergeBlock(&I);
156*700637cbSDimitry Andric       if (MB != nullptr)
157*700637cbSDimitry Andric         Output.insert(MB);
158*700637cbSDimitry Andric     }
159*700637cbSDimitry Andric   }
160*700637cbSDimitry Andric   return Output;
161*700637cbSDimitry Andric }
162*700637cbSDimitry Andric 
163*700637cbSDimitry Andric // Return all the merge instructions contained in BB.
164*700637cbSDimitry Andric // Note: the SPIR-V spec doesn't allow a single BB to contain more than 1 merge
165*700637cbSDimitry Andric // instruction, but this can happen while we structurize the CFG.
getMergeInstructions(BasicBlock & BB)166*700637cbSDimitry Andric static std::vector<Instruction *> getMergeInstructions(BasicBlock &BB) {
167*700637cbSDimitry Andric   std::vector<Instruction *> Output;
168*700637cbSDimitry Andric   for (Instruction &I : BB)
169*700637cbSDimitry Andric     if (isMergeInstruction(&I))
170*700637cbSDimitry Andric       Output.push_back(&I);
171*700637cbSDimitry Andric   return Output;
172*700637cbSDimitry Andric }
173*700637cbSDimitry Andric 
174*700637cbSDimitry Andric // Returns all basic blocks in |F| referenced as continue target by at least 1
175*700637cbSDimitry Andric // OpLoopMerge instruction.
getContinueBlocks(Function & F)176*700637cbSDimitry Andric static SmallPtrSet<BasicBlock *, 2> getContinueBlocks(Function &F) {
177*700637cbSDimitry Andric   SmallPtrSet<BasicBlock *, 2> Output;
178*700637cbSDimitry Andric   for (BasicBlock &BB : F) {
179*700637cbSDimitry Andric     for (Instruction &I : BB) {
180*700637cbSDimitry Andric       BasicBlock *MB = getDesignatedContinueBlock(&I);
181*700637cbSDimitry Andric       if (MB != nullptr)
182*700637cbSDimitry Andric         Output.insert(MB);
183*700637cbSDimitry Andric     }
184*700637cbSDimitry Andric   }
185*700637cbSDimitry Andric   return Output;
186*700637cbSDimitry Andric }
187*700637cbSDimitry Andric 
188*700637cbSDimitry Andric // Do a preorder traversal of the CFG starting from the BB |Start|.
189*700637cbSDimitry Andric // point. Calls |op| on each basic block encountered during the traversal.
visit(BasicBlock & Start,std::function<bool (BasicBlock *)> op)190*700637cbSDimitry Andric static void visit(BasicBlock &Start, std::function<bool(BasicBlock *)> op) {
191*700637cbSDimitry Andric   std::stack<BasicBlock *> ToVisit;
192*700637cbSDimitry Andric   SmallPtrSet<BasicBlock *, 8> Seen;
193*700637cbSDimitry Andric 
194*700637cbSDimitry Andric   ToVisit.push(&Start);
195*700637cbSDimitry Andric   Seen.insert(ToVisit.top());
196*700637cbSDimitry Andric   while (ToVisit.size() != 0) {
197*700637cbSDimitry Andric     BasicBlock *BB = ToVisit.top();
198*700637cbSDimitry Andric     ToVisit.pop();
199*700637cbSDimitry Andric 
200*700637cbSDimitry Andric     if (!op(BB))
201*700637cbSDimitry Andric       continue;
202*700637cbSDimitry Andric 
203*700637cbSDimitry Andric     for (auto Succ : successors(BB)) {
204*700637cbSDimitry Andric       if (Seen.contains(Succ))
205*700637cbSDimitry Andric         continue;
206*700637cbSDimitry Andric       ToVisit.push(Succ);
207*700637cbSDimitry Andric       Seen.insert(Succ);
208*700637cbSDimitry Andric     }
209*700637cbSDimitry Andric   }
210*700637cbSDimitry Andric }
211*700637cbSDimitry Andric 
212*700637cbSDimitry Andric // Replaces the conditional and unconditional branch targets of |BB| by
213*700637cbSDimitry Andric // |NewTarget| if the target was |OldTarget|. This function also makes sure the
214*700637cbSDimitry Andric // associated merge instruction gets updated accordingly.
replaceIfBranchTargets(BasicBlock * BB,BasicBlock * OldTarget,BasicBlock * NewTarget)215*700637cbSDimitry Andric static void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
216*700637cbSDimitry Andric                                    BasicBlock *NewTarget) {
217*700637cbSDimitry Andric   auto *BI = cast<BranchInst>(BB->getTerminator());
218*700637cbSDimitry Andric 
219*700637cbSDimitry Andric   // 1. Replace all matching successors.
220*700637cbSDimitry Andric   for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
221*700637cbSDimitry Andric     if (BI->getSuccessor(i) == OldTarget)
222*700637cbSDimitry Andric       BI->setSuccessor(i, NewTarget);
223*700637cbSDimitry Andric   }
224*700637cbSDimitry Andric 
225*700637cbSDimitry Andric   // Branch was unconditional, no fixup required.
226*700637cbSDimitry Andric   if (BI->isUnconditional())
227*700637cbSDimitry Andric     return;
228*700637cbSDimitry Andric 
229*700637cbSDimitry Andric   // Branch had 2 successors, maybe now both are the same?
230*700637cbSDimitry Andric   if (BI->getSuccessor(0) != BI->getSuccessor(1))
231*700637cbSDimitry Andric     return;
232*700637cbSDimitry Andric 
233*700637cbSDimitry Andric   // Note: we may end up here because the original IR had such branches.
234*700637cbSDimitry Andric   // This means Target is not necessarily equal to NewTarget.
235*700637cbSDimitry Andric   IRBuilder<> Builder(BB);
236*700637cbSDimitry Andric   Builder.SetInsertPoint(BI);
237*700637cbSDimitry Andric   Builder.CreateBr(BI->getSuccessor(0));
238*700637cbSDimitry Andric   BI->eraseFromParent();
239*700637cbSDimitry Andric 
240*700637cbSDimitry Andric   // The branch was the only instruction, nothing else to do.
241*700637cbSDimitry Andric   if (BB->size() == 1)
242*700637cbSDimitry Andric     return;
243*700637cbSDimitry Andric 
244*700637cbSDimitry Andric   // Otherwise, we need to check: was there an OpSelectionMerge before this
245*700637cbSDimitry Andric   // branch? If we removed the OpBranchConditional, we must also remove the
246*700637cbSDimitry Andric   // OpSelectionMerge. This is not valid for OpLoopMerge:
247*700637cbSDimitry Andric   IntrinsicInst *II =
248*700637cbSDimitry Andric       dyn_cast<IntrinsicInst>(BB->getTerminator()->getPrevNode());
249*700637cbSDimitry Andric   if (!II || II->getIntrinsicID() != Intrinsic::spv_selection_merge)
250*700637cbSDimitry Andric     return;
251*700637cbSDimitry Andric 
252*700637cbSDimitry Andric   Constant *C = cast<Constant>(II->getOperand(0));
253*700637cbSDimitry Andric   II->eraseFromParent();
254*700637cbSDimitry Andric   if (!C->isConstantUsed())
255*700637cbSDimitry Andric     C->destroyConstant();
256*700637cbSDimitry Andric }
257*700637cbSDimitry Andric 
258*700637cbSDimitry Andric // Replaces the target of branch instruction in |BB| with |NewTarget| if it
259*700637cbSDimitry Andric // was |OldTarget|. This function also fixes the associated merge instruction.
260*700637cbSDimitry Andric // Note: this function does not simplify branching instructions, it only updates
261*700637cbSDimitry Andric // targets. See also: simplifyBranches.
replaceBranchTargets(BasicBlock * BB,BasicBlock * OldTarget,BasicBlock * NewTarget)262*700637cbSDimitry Andric static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
263*700637cbSDimitry Andric                                  BasicBlock *NewTarget) {
264*700637cbSDimitry Andric   auto *T = BB->getTerminator();
265*700637cbSDimitry Andric   if (isa<ReturnInst>(T))
266*700637cbSDimitry Andric     return;
267*700637cbSDimitry Andric 
268*700637cbSDimitry Andric   if (isa<BranchInst>(T))
269*700637cbSDimitry Andric     return replaceIfBranchTargets(BB, OldTarget, NewTarget);
270*700637cbSDimitry Andric 
271*700637cbSDimitry Andric   if (auto *SI = dyn_cast<SwitchInst>(T)) {
272*700637cbSDimitry Andric     for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
273*700637cbSDimitry Andric       if (SI->getSuccessor(i) == OldTarget)
274*700637cbSDimitry Andric         SI->setSuccessor(i, NewTarget);
275*700637cbSDimitry Andric     }
276*700637cbSDimitry Andric     return;
277*700637cbSDimitry Andric   }
278*700637cbSDimitry Andric 
279*700637cbSDimitry Andric   assert(false && "Unhandled terminator type.");
280*700637cbSDimitry Andric }
281*700637cbSDimitry Andric 
282*700637cbSDimitry Andric namespace {
283*700637cbSDimitry Andric // Given a reducible CFG, produces a structurized CFG in the SPIR-V sense,
284*700637cbSDimitry Andric // adding merge instructions when required.
285*700637cbSDimitry Andric class SPIRVStructurizer : public FunctionPass {
286*700637cbSDimitry Andric   struct DivergentConstruct;
287*700637cbSDimitry Andric   // Represents a list of condition/loops/switch constructs.
288*700637cbSDimitry Andric   // See SPIR-V 2.11.2. Structured Control-flow Constructs for the list of
289*700637cbSDimitry Andric   // constructs.
290*700637cbSDimitry Andric   using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
291*700637cbSDimitry Andric 
292*700637cbSDimitry Andric   // Represents a divergent construct in the SPIR-V sense.
293*700637cbSDimitry Andric   // Such constructs are represented by a header (entry), a merge block (exit),
294*700637cbSDimitry Andric   // and possibly a continue block (back-edge). A construct can contain other
295*700637cbSDimitry Andric   // constructs, but their boundaries do not cross.
296*700637cbSDimitry Andric   struct DivergentConstruct {
297*700637cbSDimitry Andric     BasicBlock *Header = nullptr;
298*700637cbSDimitry Andric     BasicBlock *Merge = nullptr;
299*700637cbSDimitry Andric     BasicBlock *Continue = nullptr;
300*700637cbSDimitry Andric 
301*700637cbSDimitry Andric     DivergentConstruct *Parent = nullptr;
302*700637cbSDimitry Andric     ConstructList Children;
303*700637cbSDimitry Andric   };
304*700637cbSDimitry Andric 
305*700637cbSDimitry Andric   // An helper class to clean the construct boundaries.
306*700637cbSDimitry Andric   // It is used to gather the list of blocks that should belong to each
307*700637cbSDimitry Andric   // divergent construct, and possibly modify CFG edges when exits would cross
308*700637cbSDimitry Andric   // the boundary of multiple constructs.
309*700637cbSDimitry Andric   struct Splitter {
310*700637cbSDimitry Andric     Function &F;
311*700637cbSDimitry Andric     LoopInfo &LI;
312*700637cbSDimitry Andric     DomTreeBuilder::BBDomTree DT;
313*700637cbSDimitry Andric     DomTreeBuilder::BBPostDomTree PDT;
314*700637cbSDimitry Andric 
Splitter__anon1b48a8830111::SPIRVStructurizer::Splitter315*700637cbSDimitry Andric     Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }
316*700637cbSDimitry Andric 
invalidate__anon1b48a8830111::SPIRVStructurizer::Splitter317*700637cbSDimitry Andric     void invalidate() {
318*700637cbSDimitry Andric       PDT.recalculate(F);
319*700637cbSDimitry Andric       DT.recalculate(F);
320*700637cbSDimitry Andric     }
321*700637cbSDimitry Andric 
322*700637cbSDimitry Andric     // Returns the list of blocks that belong to a SPIR-V loop construct,
323*700637cbSDimitry Andric     // including the continue construct.
getLoopConstructBlocks__anon1b48a8830111::SPIRVStructurizer::Splitter324*700637cbSDimitry Andric     std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header,
325*700637cbSDimitry Andric                                                      BasicBlock *Merge) {
326*700637cbSDimitry Andric       assert(DT.dominates(Header, Merge));
327*700637cbSDimitry Andric       std::vector<BasicBlock *> Output;
328*700637cbSDimitry Andric       partialOrderVisit(*Header, [&](BasicBlock *BB) {
329*700637cbSDimitry Andric         if (BB == Merge)
330*700637cbSDimitry Andric           return false;
331*700637cbSDimitry Andric         if (DT.dominates(Merge, BB) || !DT.dominates(Header, BB))
332*700637cbSDimitry Andric           return false;
333*700637cbSDimitry Andric         Output.push_back(BB);
334*700637cbSDimitry Andric         return true;
335*700637cbSDimitry Andric       });
336*700637cbSDimitry Andric       return Output;
337*700637cbSDimitry Andric     }
338*700637cbSDimitry Andric 
339*700637cbSDimitry Andric     // Returns the list of blocks that belong to a SPIR-V selection construct.
340*700637cbSDimitry Andric     std::vector<BasicBlock *>
getSelectionConstructBlocks__anon1b48a8830111::SPIRVStructurizer::Splitter341*700637cbSDimitry Andric     getSelectionConstructBlocks(DivergentConstruct *Node) {
342*700637cbSDimitry Andric       assert(DT.dominates(Node->Header, Node->Merge));
343*700637cbSDimitry Andric       BlockSet OutsideBlocks;
344*700637cbSDimitry Andric       OutsideBlocks.insert(Node->Merge);
345*700637cbSDimitry Andric 
346*700637cbSDimitry Andric       for (DivergentConstruct *It = Node->Parent; It != nullptr;
347*700637cbSDimitry Andric            It = It->Parent) {
348*700637cbSDimitry Andric         OutsideBlocks.insert(It->Merge);
349*700637cbSDimitry Andric         if (It->Continue)
350*700637cbSDimitry Andric           OutsideBlocks.insert(It->Continue);
351*700637cbSDimitry Andric       }
352*700637cbSDimitry Andric 
353*700637cbSDimitry Andric       std::vector<BasicBlock *> Output;
354*700637cbSDimitry Andric       partialOrderVisit(*Node->Header, [&](BasicBlock *BB) {
355*700637cbSDimitry Andric         if (OutsideBlocks.count(BB) != 0)
356*700637cbSDimitry Andric           return false;
357*700637cbSDimitry Andric         if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
358*700637cbSDimitry Andric           return false;
359*700637cbSDimitry Andric         Output.push_back(BB);
360*700637cbSDimitry Andric         return true;
361*700637cbSDimitry Andric       });
362*700637cbSDimitry Andric       return Output;
363*700637cbSDimitry Andric     }
364*700637cbSDimitry Andric 
365*700637cbSDimitry Andric     // Returns the list of blocks that belong to a SPIR-V switch construct.
getSwitchConstructBlocks__anon1b48a8830111::SPIRVStructurizer::Splitter366*700637cbSDimitry Andric     std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header,
367*700637cbSDimitry Andric                                                        BasicBlock *Merge) {
368*700637cbSDimitry Andric       assert(DT.dominates(Header, Merge));
369*700637cbSDimitry Andric 
370*700637cbSDimitry Andric       std::vector<BasicBlock *> Output;
371*700637cbSDimitry Andric       partialOrderVisit(*Header, [&](BasicBlock *BB) {
372*700637cbSDimitry Andric         // the blocks structurally dominated by a switch header,
373*700637cbSDimitry Andric         if (!DT.dominates(Header, BB))
374*700637cbSDimitry Andric           return false;
375*700637cbSDimitry Andric         // excluding blocks structurally dominated by the switch header’s merge
376*700637cbSDimitry Andric         // block.
377*700637cbSDimitry Andric         if (DT.dominates(Merge, BB) || BB == Merge)
378*700637cbSDimitry Andric           return false;
379*700637cbSDimitry Andric         Output.push_back(BB);
380*700637cbSDimitry Andric         return true;
381*700637cbSDimitry Andric       });
382*700637cbSDimitry Andric       return Output;
383*700637cbSDimitry Andric     }
384*700637cbSDimitry Andric 
385*700637cbSDimitry Andric     // Returns the list of blocks that belong to a SPIR-V case construct.
getCaseConstructBlocks__anon1b48a8830111::SPIRVStructurizer::Splitter386*700637cbSDimitry Andric     std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target,
387*700637cbSDimitry Andric                                                      BasicBlock *Merge) {
388*700637cbSDimitry Andric       assert(DT.dominates(Target, Merge));
389*700637cbSDimitry Andric 
390*700637cbSDimitry Andric       std::vector<BasicBlock *> Output;
391*700637cbSDimitry Andric       partialOrderVisit(*Target, [&](BasicBlock *BB) {
392*700637cbSDimitry Andric         // the blocks structurally dominated by an OpSwitch Target or Default
393*700637cbSDimitry Andric         // block
394*700637cbSDimitry Andric         if (!DT.dominates(Target, BB))
395*700637cbSDimitry Andric           return false;
396*700637cbSDimitry Andric         // excluding the blocks structurally dominated by the OpSwitch
397*700637cbSDimitry Andric         // construct’s corresponding merge block.
398*700637cbSDimitry Andric         if (DT.dominates(Merge, BB) || BB == Merge)
399*700637cbSDimitry Andric           return false;
400*700637cbSDimitry Andric         Output.push_back(BB);
401*700637cbSDimitry Andric         return true;
402*700637cbSDimitry Andric       });
403*700637cbSDimitry Andric       return Output;
404*700637cbSDimitry Andric     }
405*700637cbSDimitry Andric 
406*700637cbSDimitry Andric     // Splits the given edges by recreating proxy nodes so that the destination
407*700637cbSDimitry Andric     // has unique incoming edges from this region.
408*700637cbSDimitry Andric     //
409*700637cbSDimitry Andric     // clang-format off
410*700637cbSDimitry Andric     //
411*700637cbSDimitry Andric     // In SPIR-V, constructs must have a single exit/merge.
412*700637cbSDimitry Andric     // Given nodes A and B in the construct, a node C outside, and the following edges.
413*700637cbSDimitry Andric     //  A -> C
414*700637cbSDimitry Andric     //  B -> C
415*700637cbSDimitry Andric     //
416*700637cbSDimitry Andric     // In such cases, we must create a new exit node D, that belong to the construct to make is viable:
417*700637cbSDimitry Andric     // A -> D -> C
418*700637cbSDimitry Andric     // B -> D -> C
419*700637cbSDimitry Andric     //
420*700637cbSDimitry Andric     // This is fine (assuming C has no PHI nodes), but requires handling the merge instruction here.
421*700637cbSDimitry Andric     // By adding a proxy node, we create a regular divergent shape which can easily be regularized later on.
422*700637cbSDimitry Andric     // A -> D -> D1 -> C
423*700637cbSDimitry Andric     // B -> D -> D2 -> C
424*700637cbSDimitry Andric     //
425*700637cbSDimitry Andric     // A, B, D belongs to the construct. D is the exit. D1 and D2 are empty.
426*700637cbSDimitry Andric     //
427*700637cbSDimitry Andric     // clang-format on
428*700637cbSDimitry Andric     std::vector<Edge>
createAliasBlocksForComplexEdges__anon1b48a8830111::SPIRVStructurizer::Splitter429*700637cbSDimitry Andric     createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
430*700637cbSDimitry Andric       std::unordered_set<BasicBlock *> Seen;
431*700637cbSDimitry Andric       std::vector<Edge> Output;
432*700637cbSDimitry Andric       Output.reserve(Edges.size());
433*700637cbSDimitry Andric 
434*700637cbSDimitry Andric       for (auto &[Src, Dst] : Edges) {
435*700637cbSDimitry Andric         auto [Iterator, Inserted] = Seen.insert(Src);
436*700637cbSDimitry Andric         if (!Inserted) {
437*700637cbSDimitry Andric           // Src already a source node. Cannot have 2 edges from A to B.
438*700637cbSDimitry Andric           // Creating alias source block.
439*700637cbSDimitry Andric           BasicBlock *NewSrc = BasicBlock::Create(
440*700637cbSDimitry Andric               F.getContext(), Src->getName() + ".new.src", &F);
441*700637cbSDimitry Andric           replaceBranchTargets(Src, Dst, NewSrc);
442*700637cbSDimitry Andric           IRBuilder<> Builder(NewSrc);
443*700637cbSDimitry Andric           Builder.CreateBr(Dst);
444*700637cbSDimitry Andric           Src = NewSrc;
445*700637cbSDimitry Andric         }
446*700637cbSDimitry Andric 
447*700637cbSDimitry Andric         Output.emplace_back(Src, Dst);
448*700637cbSDimitry Andric       }
449*700637cbSDimitry Andric 
450*700637cbSDimitry Andric       return Output;
451*700637cbSDimitry Andric     }
452*700637cbSDimitry Andric 
CreateVariable__anon1b48a8830111::SPIRVStructurizer::Splitter453*700637cbSDimitry Andric     AllocaInst *CreateVariable(Function &F, Type *Type,
454*700637cbSDimitry Andric                                BasicBlock::iterator Position) {
455*700637cbSDimitry Andric       const DataLayout &DL = F.getDataLayout();
456*700637cbSDimitry Andric       return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
457*700637cbSDimitry Andric                             Position);
458*700637cbSDimitry Andric     }
459*700637cbSDimitry Andric 
460*700637cbSDimitry Andric     // Given a construct defined by |Header|, and a list of exiting edges
461*700637cbSDimitry Andric     // |Edges|, creates a new single exit node, fixing up those edges.
createSingleExitNode__anon1b48a8830111::SPIRVStructurizer::Splitter462*700637cbSDimitry Andric     BasicBlock *createSingleExitNode(BasicBlock *Header,
463*700637cbSDimitry Andric                                      std::vector<Edge> &Edges) {
464*700637cbSDimitry Andric 
465*700637cbSDimitry Andric       std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
466*700637cbSDimitry Andric 
467*700637cbSDimitry Andric       std::vector<BasicBlock *> Dsts;
468*700637cbSDimitry Andric       std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
469*700637cbSDimitry Andric       auto NewExit = BasicBlock::Create(F.getContext(),
470*700637cbSDimitry Andric                                         Header->getName() + ".new.exit", &F);
471*700637cbSDimitry Andric       IRBuilder<> ExitBuilder(NewExit);
472*700637cbSDimitry Andric       for (auto &[Src, Dst] : FixedEdges) {
473*700637cbSDimitry Andric         if (DstToIndex.count(Dst) != 0)
474*700637cbSDimitry Andric           continue;
475*700637cbSDimitry Andric         DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size()));
476*700637cbSDimitry Andric         Dsts.push_back(Dst);
477*700637cbSDimitry Andric       }
478*700637cbSDimitry Andric 
479*700637cbSDimitry Andric       if (Dsts.size() == 1) {
480*700637cbSDimitry Andric         for (auto &[Src, Dst] : FixedEdges) {
481*700637cbSDimitry Andric           replaceBranchTargets(Src, Dst, NewExit);
482*700637cbSDimitry Andric         }
483*700637cbSDimitry Andric         ExitBuilder.CreateBr(Dsts[0]);
484*700637cbSDimitry Andric         return NewExit;
485*700637cbSDimitry Andric       }
486*700637cbSDimitry Andric 
487*700637cbSDimitry Andric       AllocaInst *Variable = CreateVariable(F, ExitBuilder.getInt32Ty(),
488*700637cbSDimitry Andric                                             F.begin()->getFirstInsertionPt());
489*700637cbSDimitry Andric       for (auto &[Src, Dst] : FixedEdges) {
490*700637cbSDimitry Andric         IRBuilder<> B2(Src);
491*700637cbSDimitry Andric         B2.SetInsertPoint(Src->getFirstInsertionPt());
492*700637cbSDimitry Andric         B2.CreateStore(DstToIndex[Dst], Variable);
493*700637cbSDimitry Andric         replaceBranchTargets(Src, Dst, NewExit);
494*700637cbSDimitry Andric       }
495*700637cbSDimitry Andric 
496*700637cbSDimitry Andric       Value *Load = ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);
497*700637cbSDimitry Andric 
498*700637cbSDimitry Andric       // If we can avoid an OpSwitch, generate an OpBranch. Reason is some
499*700637cbSDimitry Andric       // OpBranch are allowed to exist without a new OpSelectionMerge if one of
500*700637cbSDimitry Andric       // the branch is the parent's merge node, while OpSwitches are not.
501*700637cbSDimitry Andric       if (Dsts.size() == 2) {
502*700637cbSDimitry Andric         Value *Condition =
503*700637cbSDimitry Andric             ExitBuilder.CreateCmp(CmpInst::ICMP_EQ, DstToIndex[Dsts[0]], Load);
504*700637cbSDimitry Andric         ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
505*700637cbSDimitry Andric         return NewExit;
506*700637cbSDimitry Andric       }
507*700637cbSDimitry Andric 
508*700637cbSDimitry Andric       SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
509*700637cbSDimitry Andric       for (BasicBlock *BB : drop_begin(Dsts))
510*700637cbSDimitry Andric         Sw->addCase(DstToIndex[BB], BB);
511*700637cbSDimitry Andric       return NewExit;
512*700637cbSDimitry Andric     }
513*700637cbSDimitry Andric   };
514*700637cbSDimitry Andric 
515*700637cbSDimitry Andric   /// Create a value in BB set to the value associated with the branch the block
516*700637cbSDimitry Andric   /// terminator will take.
createExitVariable(BasicBlock * BB,const DenseMap<BasicBlock *,ConstantInt * > & TargetToValue)517*700637cbSDimitry Andric   Value *createExitVariable(
518*700637cbSDimitry Andric       BasicBlock *BB,
519*700637cbSDimitry Andric       const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
520*700637cbSDimitry Andric     auto *T = BB->getTerminator();
521*700637cbSDimitry Andric     if (isa<ReturnInst>(T))
522*700637cbSDimitry Andric       return nullptr;
523*700637cbSDimitry Andric 
524*700637cbSDimitry Andric     IRBuilder<> Builder(BB);
525*700637cbSDimitry Andric     Builder.SetInsertPoint(T);
526*700637cbSDimitry Andric 
527*700637cbSDimitry Andric     if (auto *BI = dyn_cast<BranchInst>(T)) {
528*700637cbSDimitry Andric 
529*700637cbSDimitry Andric       BasicBlock *LHSTarget = BI->getSuccessor(0);
530*700637cbSDimitry Andric       BasicBlock *RHSTarget =
531*700637cbSDimitry Andric           BI->isConditional() ? BI->getSuccessor(1) : nullptr;
532*700637cbSDimitry Andric 
533*700637cbSDimitry Andric       Value *LHS = TargetToValue.lookup(LHSTarget);
534*700637cbSDimitry Andric       Value *RHS = TargetToValue.lookup(RHSTarget);
535*700637cbSDimitry Andric 
536*700637cbSDimitry Andric       if (LHS == nullptr || RHS == nullptr)
537*700637cbSDimitry Andric         return LHS == nullptr ? RHS : LHS;
538*700637cbSDimitry Andric       return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
539*700637cbSDimitry Andric     }
540*700637cbSDimitry Andric 
541*700637cbSDimitry Andric     // TODO: add support for switch cases.
542*700637cbSDimitry Andric     llvm_unreachable("Unhandled terminator type.");
543*700637cbSDimitry Andric   }
544*700637cbSDimitry Andric 
545*700637cbSDimitry Andric   // Creates a new basic block in F with a single OpUnreachable instruction.
CreateUnreachable(Function & F)546*700637cbSDimitry Andric   BasicBlock *CreateUnreachable(Function &F) {
547*700637cbSDimitry Andric     BasicBlock *BB = BasicBlock::Create(F.getContext(), "unreachable", &F);
548*700637cbSDimitry Andric     IRBuilder<> Builder(BB);
549*700637cbSDimitry Andric     Builder.CreateUnreachable();
550*700637cbSDimitry Andric     return BB;
551*700637cbSDimitry Andric   }
552*700637cbSDimitry Andric 
553*700637cbSDimitry Andric   // Add OpLoopMerge instruction on cycles.
addMergeForLoops(Function & F)554*700637cbSDimitry Andric   bool addMergeForLoops(Function &F) {
555*700637cbSDimitry Andric     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
556*700637cbSDimitry Andric     auto *TopLevelRegion =
557*700637cbSDimitry Andric         getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
558*700637cbSDimitry Andric             .getRegionInfo()
559*700637cbSDimitry Andric             .getTopLevelRegion();
560*700637cbSDimitry Andric 
561*700637cbSDimitry Andric     bool Modified = false;
562*700637cbSDimitry Andric     for (auto &BB : F) {
563*700637cbSDimitry Andric       // Not a loop header. Ignoring for now.
564*700637cbSDimitry Andric       if (!LI.isLoopHeader(&BB))
565*700637cbSDimitry Andric         continue;
566*700637cbSDimitry Andric       auto *L = LI.getLoopFor(&BB);
567*700637cbSDimitry Andric 
568*700637cbSDimitry Andric       // This loop header is not the entrance of a convergence region. Ignoring
569*700637cbSDimitry Andric       // this block.
570*700637cbSDimitry Andric       auto *CR = getRegionForHeader(TopLevelRegion, &BB);
571*700637cbSDimitry Andric       if (CR == nullptr)
572*700637cbSDimitry Andric         continue;
573*700637cbSDimitry Andric 
574*700637cbSDimitry Andric       IRBuilder<> Builder(&BB);
575*700637cbSDimitry Andric 
576*700637cbSDimitry Andric       auto *Merge = getExitFor(CR);
577*700637cbSDimitry Andric       // We are indeed in a loop, but there are no exits (infinite loop).
578*700637cbSDimitry Andric       // This could be caused by a bad shader, but also could be an artifact
579*700637cbSDimitry Andric       // from an earlier optimization. It is not always clear if structurally
580*700637cbSDimitry Andric       // reachable means runtime reachable, so we cannot error-out. What we must
581*700637cbSDimitry Andric       // do however is to make is legal on the SPIR-V point of view, hence
582*700637cbSDimitry Andric       // adding an unreachable merge block.
583*700637cbSDimitry Andric       if (Merge == nullptr) {
584*700637cbSDimitry Andric         BranchInst *Br = cast<BranchInst>(BB.getTerminator());
585*700637cbSDimitry Andric         assert(Br->isUnconditional());
586*700637cbSDimitry Andric 
587*700637cbSDimitry Andric         Merge = CreateUnreachable(F);
588*700637cbSDimitry Andric         Builder.SetInsertPoint(Br);
589*700637cbSDimitry Andric         Builder.CreateCondBr(Builder.getFalse(), Merge, Br->getSuccessor(0));
590*700637cbSDimitry Andric         Br->eraseFromParent();
591*700637cbSDimitry Andric       }
592*700637cbSDimitry Andric 
593*700637cbSDimitry Andric       auto *Continue = L->getLoopLatch();
594*700637cbSDimitry Andric 
595*700637cbSDimitry Andric       Builder.SetInsertPoint(BB.getTerminator());
596*700637cbSDimitry Andric       auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
597*700637cbSDimitry Andric       auto ContinueAddress = BlockAddress::get(Continue->getParent(), Continue);
598*700637cbSDimitry Andric       SmallVector<Value *, 2> Args = {MergeAddress, ContinueAddress};
599*700637cbSDimitry Andric       SmallVector<unsigned, 1> LoopControlImms =
600*700637cbSDimitry Andric           getSpirvLoopControlOperandsFromLoopMetadata(L);
601*700637cbSDimitry Andric       for (unsigned Imm : LoopControlImms)
602*700637cbSDimitry Andric         Args.emplace_back(ConstantInt::get(Builder.getInt32Ty(), Imm));
603*700637cbSDimitry Andric       Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {Args});
604*700637cbSDimitry Andric       Modified = true;
605*700637cbSDimitry Andric     }
606*700637cbSDimitry Andric 
607*700637cbSDimitry Andric     return Modified;
608*700637cbSDimitry Andric   }
609*700637cbSDimitry Andric 
610*700637cbSDimitry Andric   // Adds an OpSelectionMerge to the immediate dominator or each node with an
611*700637cbSDimitry Andric   // in-degree of 2 or more which is not already the merge target of an
612*700637cbSDimitry Andric   // OpLoopMerge/OpSelectionMerge.
addMergeForNodesWithMultiplePredecessors(Function & F)613*700637cbSDimitry Andric   bool addMergeForNodesWithMultiplePredecessors(Function &F) {
614*700637cbSDimitry Andric     DomTreeBuilder::BBDomTree DT;
615*700637cbSDimitry Andric     DT.recalculate(F);
616*700637cbSDimitry Andric 
617*700637cbSDimitry Andric     bool Modified = false;
618*700637cbSDimitry Andric     for (auto &BB : F) {
619*700637cbSDimitry Andric       if (pred_size(&BB) <= 1)
620*700637cbSDimitry Andric         continue;
621*700637cbSDimitry Andric 
622*700637cbSDimitry Andric       if (hasLoopMergeInstruction(BB) && pred_size(&BB) <= 2)
623*700637cbSDimitry Andric         continue;
624*700637cbSDimitry Andric 
625*700637cbSDimitry Andric       assert(DT.getNode(&BB)->getIDom());
626*700637cbSDimitry Andric       BasicBlock *Header = DT.getNode(&BB)->getIDom()->getBlock();
627*700637cbSDimitry Andric 
628*700637cbSDimitry Andric       if (isDefinedAsSelectionMergeBy(*Header, BB))
629*700637cbSDimitry Andric         continue;
630*700637cbSDimitry Andric 
631*700637cbSDimitry Andric       IRBuilder<> Builder(Header);
632*700637cbSDimitry Andric       Builder.SetInsertPoint(Header->getTerminator());
633*700637cbSDimitry Andric 
634*700637cbSDimitry Andric       auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
635*700637cbSDimitry Andric       createOpSelectMerge(&Builder, MergeAddress);
636*700637cbSDimitry Andric 
637*700637cbSDimitry Andric       Modified = true;
638*700637cbSDimitry Andric     }
639*700637cbSDimitry Andric 
640*700637cbSDimitry Andric     return Modified;
641*700637cbSDimitry Andric   }
642*700637cbSDimitry Andric 
643*700637cbSDimitry Andric   // When a block has multiple OpSelectionMerge/OpLoopMerge instructions, sorts
644*700637cbSDimitry Andric   // them to put the "largest" first. A merge instruction is defined as larger
645*700637cbSDimitry Andric   // than another when its target merge block post-dominates the other target's
646*700637cbSDimitry Andric   // merge block. (This ordering should match the nesting ordering of the source
647*700637cbSDimitry Andric   // HLSL).
sortSelectionMerge(Function & F,BasicBlock & Block)648*700637cbSDimitry Andric   bool sortSelectionMerge(Function &F, BasicBlock &Block) {
649*700637cbSDimitry Andric     std::vector<Instruction *> MergeInstructions;
650*700637cbSDimitry Andric     for (Instruction &I : Block)
651*700637cbSDimitry Andric       if (isMergeInstruction(&I))
652*700637cbSDimitry Andric         MergeInstructions.push_back(&I);
653*700637cbSDimitry Andric 
654*700637cbSDimitry Andric     if (MergeInstructions.size() <= 1)
655*700637cbSDimitry Andric       return false;
656*700637cbSDimitry Andric 
657*700637cbSDimitry Andric     Instruction *InsertionPoint = *MergeInstructions.begin();
658*700637cbSDimitry Andric 
659*700637cbSDimitry Andric     PartialOrderingVisitor Visitor(F);
660*700637cbSDimitry Andric     std::sort(MergeInstructions.begin(), MergeInstructions.end(),
661*700637cbSDimitry Andric               [&Visitor](Instruction *Left, Instruction *Right) {
662*700637cbSDimitry Andric                 if (Left == Right)
663*700637cbSDimitry Andric                   return false;
664*700637cbSDimitry Andric                 BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
665*700637cbSDimitry Andric                 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
666*700637cbSDimitry Andric                 return !Visitor.compare(RightMerge, LeftMerge);
667*700637cbSDimitry Andric               });
668*700637cbSDimitry Andric 
669*700637cbSDimitry Andric     for (Instruction *I : MergeInstructions) {
670*700637cbSDimitry Andric       I->moveBefore(InsertionPoint->getIterator());
671*700637cbSDimitry Andric       InsertionPoint = I;
672*700637cbSDimitry Andric     }
673*700637cbSDimitry Andric 
674*700637cbSDimitry Andric     return true;
675*700637cbSDimitry Andric   }
676*700637cbSDimitry Andric 
677*700637cbSDimitry Andric   // Sorts selection merge headers in |F|.
678*700637cbSDimitry Andric   // A is sorted before B if the merge block designated by B is an ancestor of
679*700637cbSDimitry Andric   // the one designated by A.
sortSelectionMergeHeaders(Function & F)680*700637cbSDimitry Andric   bool sortSelectionMergeHeaders(Function &F) {
681*700637cbSDimitry Andric     bool Modified = false;
682*700637cbSDimitry Andric     for (BasicBlock &BB : F) {
683*700637cbSDimitry Andric       Modified |= sortSelectionMerge(F, BB);
684*700637cbSDimitry Andric     }
685*700637cbSDimitry Andric     return Modified;
686*700637cbSDimitry Andric   }
687*700637cbSDimitry Andric 
688*700637cbSDimitry Andric   // Split basic blocks containing multiple OpLoopMerge/OpSelectionMerge
689*700637cbSDimitry Andric   // instructions so each basic block contains only a single merge instruction.
splitBlocksWithMultipleHeaders(Function & F)690*700637cbSDimitry Andric   bool splitBlocksWithMultipleHeaders(Function &F) {
691*700637cbSDimitry Andric     std::stack<BasicBlock *> Work;
692*700637cbSDimitry Andric     for (auto &BB : F) {
693*700637cbSDimitry Andric       std::vector<Instruction *> MergeInstructions = getMergeInstructions(BB);
694*700637cbSDimitry Andric       if (MergeInstructions.size() <= 1)
695*700637cbSDimitry Andric         continue;
696*700637cbSDimitry Andric       Work.push(&BB);
697*700637cbSDimitry Andric     }
698*700637cbSDimitry Andric 
699*700637cbSDimitry Andric     const bool Modified = Work.size() > 0;
700*700637cbSDimitry Andric     while (Work.size() > 0) {
701*700637cbSDimitry Andric       BasicBlock *Header = Work.top();
702*700637cbSDimitry Andric       Work.pop();
703*700637cbSDimitry Andric 
704*700637cbSDimitry Andric       std::vector<Instruction *> MergeInstructions =
705*700637cbSDimitry Andric           getMergeInstructions(*Header);
706*700637cbSDimitry Andric       for (unsigned i = 1; i < MergeInstructions.size(); i++) {
707*700637cbSDimitry Andric         BasicBlock *NewBlock =
708*700637cbSDimitry Andric             Header->splitBasicBlock(MergeInstructions[i], "new.header");
709*700637cbSDimitry Andric 
710*700637cbSDimitry Andric         if (getDesignatedContinueBlock(MergeInstructions[0]) == nullptr) {
711*700637cbSDimitry Andric           BasicBlock *Unreachable = CreateUnreachable(F);
712*700637cbSDimitry Andric 
713*700637cbSDimitry Andric           BranchInst *BI = cast<BranchInst>(Header->getTerminator());
714*700637cbSDimitry Andric           IRBuilder<> Builder(Header);
715*700637cbSDimitry Andric           Builder.SetInsertPoint(BI);
716*700637cbSDimitry Andric           Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable);
717*700637cbSDimitry Andric           BI->eraseFromParent();
718*700637cbSDimitry Andric         }
719*700637cbSDimitry Andric 
720*700637cbSDimitry Andric         Header = NewBlock;
721*700637cbSDimitry Andric       }
722*700637cbSDimitry Andric     }
723*700637cbSDimitry Andric 
724*700637cbSDimitry Andric     return Modified;
725*700637cbSDimitry Andric   }
726*700637cbSDimitry Andric 
727*700637cbSDimitry Andric   // Adds an OpSelectionMerge to each block with an out-degree >= 2 which
728*700637cbSDimitry Andric   // doesn't already have an OpSelectionMerge.
addMergeForDivergentBlocks(Function & F)729*700637cbSDimitry Andric   bool addMergeForDivergentBlocks(Function &F) {
730*700637cbSDimitry Andric     DomTreeBuilder::BBPostDomTree PDT;
731*700637cbSDimitry Andric     PDT.recalculate(F);
732*700637cbSDimitry Andric     bool Modified = false;
733*700637cbSDimitry Andric 
734*700637cbSDimitry Andric     auto MergeBlocks = getMergeBlocks(F);
735*700637cbSDimitry Andric     auto ContinueBlocks = getContinueBlocks(F);
736*700637cbSDimitry Andric 
737*700637cbSDimitry Andric     for (auto &BB : F) {
738*700637cbSDimitry Andric       if (getMergeInstructions(BB).size() != 0)
739*700637cbSDimitry Andric         continue;
740*700637cbSDimitry Andric 
741*700637cbSDimitry Andric       std::vector<BasicBlock *> Candidates;
742*700637cbSDimitry Andric       for (BasicBlock *Successor : successors(&BB)) {
743*700637cbSDimitry Andric         if (MergeBlocks.contains(Successor))
744*700637cbSDimitry Andric           continue;
745*700637cbSDimitry Andric         if (ContinueBlocks.contains(Successor))
746*700637cbSDimitry Andric           continue;
747*700637cbSDimitry Andric         Candidates.push_back(Successor);
748*700637cbSDimitry Andric       }
749*700637cbSDimitry Andric 
750*700637cbSDimitry Andric       if (Candidates.size() <= 1)
751*700637cbSDimitry Andric         continue;
752*700637cbSDimitry Andric 
753*700637cbSDimitry Andric       Modified = true;
754*700637cbSDimitry Andric       BasicBlock *Merge = Candidates[0];
755*700637cbSDimitry Andric 
756*700637cbSDimitry Andric       auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
757*700637cbSDimitry Andric       IRBuilder<> Builder(&BB);
758*700637cbSDimitry Andric       Builder.SetInsertPoint(BB.getTerminator());
759*700637cbSDimitry Andric       createOpSelectMerge(&Builder, MergeAddress);
760*700637cbSDimitry Andric     }
761*700637cbSDimitry Andric 
762*700637cbSDimitry Andric     return Modified;
763*700637cbSDimitry Andric   }
764*700637cbSDimitry Andric 
765*700637cbSDimitry Andric   // Gather all the exit nodes for the construct header by |Header| and
766*700637cbSDimitry Andric   // containing the blocks |Construct|.
getExitsFrom(const BlockSet & Construct,BasicBlock & Header)767*700637cbSDimitry Andric   std::vector<Edge> getExitsFrom(const BlockSet &Construct,
768*700637cbSDimitry Andric                                  BasicBlock &Header) {
769*700637cbSDimitry Andric     std::vector<Edge> Output;
770*700637cbSDimitry Andric     visit(Header, [&](BasicBlock *Item) {
771*700637cbSDimitry Andric       if (Construct.count(Item) == 0)
772*700637cbSDimitry Andric         return false;
773*700637cbSDimitry Andric 
774*700637cbSDimitry Andric       for (BasicBlock *Successor : successors(Item)) {
775*700637cbSDimitry Andric         if (Construct.count(Successor) == 0)
776*700637cbSDimitry Andric           Output.emplace_back(Item, Successor);
777*700637cbSDimitry Andric       }
778*700637cbSDimitry Andric       return true;
779*700637cbSDimitry Andric     });
780*700637cbSDimitry Andric 
781*700637cbSDimitry Andric     return Output;
782*700637cbSDimitry Andric   }
783*700637cbSDimitry Andric 
784*700637cbSDimitry Andric   // Build a divergent construct tree searching from |BB|.
785*700637cbSDimitry Andric   // If |Parent| is not null, this tree is attached to the parent's tree.
constructDivergentConstruct(BlockSet & Visited,Splitter & S,BasicBlock * BB,DivergentConstruct * Parent)786*700637cbSDimitry Andric   void constructDivergentConstruct(BlockSet &Visited, Splitter &S,
787*700637cbSDimitry Andric                                    BasicBlock *BB, DivergentConstruct *Parent) {
788*700637cbSDimitry Andric     if (Visited.count(BB) != 0)
789*700637cbSDimitry Andric       return;
790*700637cbSDimitry Andric     Visited.insert(BB);
791*700637cbSDimitry Andric 
792*700637cbSDimitry Andric     auto MIS = getMergeInstructions(*BB);
793*700637cbSDimitry Andric     if (MIS.size() == 0) {
794*700637cbSDimitry Andric       for (BasicBlock *Successor : successors(BB))
795*700637cbSDimitry Andric         constructDivergentConstruct(Visited, S, Successor, Parent);
796*700637cbSDimitry Andric       return;
797*700637cbSDimitry Andric     }
798*700637cbSDimitry Andric 
799*700637cbSDimitry Andric     assert(MIS.size() == 1);
800*700637cbSDimitry Andric     Instruction *MI = MIS[0];
801*700637cbSDimitry Andric 
802*700637cbSDimitry Andric     BasicBlock *Merge = getDesignatedMergeBlock(MI);
803*700637cbSDimitry Andric     BasicBlock *Continue = getDesignatedContinueBlock(MI);
804*700637cbSDimitry Andric 
805*700637cbSDimitry Andric     auto Output = std::make_unique<DivergentConstruct>();
806*700637cbSDimitry Andric     Output->Header = BB;
807*700637cbSDimitry Andric     Output->Merge = Merge;
808*700637cbSDimitry Andric     Output->Continue = Continue;
809*700637cbSDimitry Andric     Output->Parent = Parent;
810*700637cbSDimitry Andric 
811*700637cbSDimitry Andric     constructDivergentConstruct(Visited, S, Merge, Parent);
812*700637cbSDimitry Andric     if (Continue)
813*700637cbSDimitry Andric       constructDivergentConstruct(Visited, S, Continue, Output.get());
814*700637cbSDimitry Andric 
815*700637cbSDimitry Andric     for (BasicBlock *Successor : successors(BB))
816*700637cbSDimitry Andric       constructDivergentConstruct(Visited, S, Successor, Output.get());
817*700637cbSDimitry Andric 
818*700637cbSDimitry Andric     if (Parent)
819*700637cbSDimitry Andric       Parent->Children.emplace_back(std::move(Output));
820*700637cbSDimitry Andric   }
821*700637cbSDimitry Andric 
822*700637cbSDimitry Andric   // Returns the blocks belonging to the divergent construct |Node|.
getConstructBlocks(Splitter & S,DivergentConstruct * Node)823*700637cbSDimitry Andric   BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
824*700637cbSDimitry Andric     assert(Node->Header && Node->Merge);
825*700637cbSDimitry Andric 
826*700637cbSDimitry Andric     if (Node->Continue) {
827*700637cbSDimitry Andric       auto LoopBlocks = S.getLoopConstructBlocks(Node->Header, Node->Merge);
828*700637cbSDimitry Andric       return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
829*700637cbSDimitry Andric     }
830*700637cbSDimitry Andric 
831*700637cbSDimitry Andric     auto SelectionBlocks = S.getSelectionConstructBlocks(Node);
832*700637cbSDimitry Andric     return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
833*700637cbSDimitry Andric   }
834*700637cbSDimitry Andric 
835*700637cbSDimitry Andric   // Fixup the construct |Node| to respect a set of rules defined by the SPIR-V
836*700637cbSDimitry Andric   // spec.
fixupConstruct(Splitter & S,DivergentConstruct * Node)837*700637cbSDimitry Andric   bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
838*700637cbSDimitry Andric     bool Modified = false;
839*700637cbSDimitry Andric     for (auto &Child : Node->Children)
840*700637cbSDimitry Andric       Modified |= fixupConstruct(S, Child.get());
841*700637cbSDimitry Andric 
842*700637cbSDimitry Andric     // This construct is the root construct. Does not represent any real
843*700637cbSDimitry Andric     // construct, just a way to access the first level of the forest.
844*700637cbSDimitry Andric     if (Node->Parent == nullptr)
845*700637cbSDimitry Andric       return Modified;
846*700637cbSDimitry Andric 
847*700637cbSDimitry Andric     // This node's parent is the root. Meaning this is a top-level construct.
848*700637cbSDimitry Andric     // There can be multiple exists, but all are guaranteed to exit at most 1
849*700637cbSDimitry Andric     // construct since we are at first level.
850*700637cbSDimitry Andric     if (Node->Parent->Header == nullptr)
851*700637cbSDimitry Andric       return Modified;
852*700637cbSDimitry Andric 
853*700637cbSDimitry Andric     // Health check for the structure.
854*700637cbSDimitry Andric     assert(Node->Header && Node->Merge);
855*700637cbSDimitry Andric     assert(Node->Parent->Header && Node->Parent->Merge);
856*700637cbSDimitry Andric 
857*700637cbSDimitry Andric     BlockSet ConstructBlocks = getConstructBlocks(S, Node);
858*700637cbSDimitry Andric     auto Edges = getExitsFrom(ConstructBlocks, *Node->Header);
859*700637cbSDimitry Andric 
860*700637cbSDimitry Andric     //  No edges exiting the construct.
861*700637cbSDimitry Andric     if (Edges.size() < 1)
862*700637cbSDimitry Andric       return Modified;
863*700637cbSDimitry Andric 
864*700637cbSDimitry Andric     bool HasBadEdge = Node->Merge == Node->Parent->Merge ||
865*700637cbSDimitry Andric                       Node->Merge == Node->Parent->Continue;
866*700637cbSDimitry Andric     // BasicBlock *Target = Edges[0].second;
867*700637cbSDimitry Andric     for (auto &[Src, Dst] : Edges) {
868*700637cbSDimitry Andric       // - Breaking from a selection construct: S is a selection construct, S is
869*700637cbSDimitry Andric       // the innermost structured
870*700637cbSDimitry Andric       //   control-flow construct containing A, and B is the merge block for S
871*700637cbSDimitry Andric       // - Breaking from the innermost loop: S is the innermost loop construct
872*700637cbSDimitry Andric       // containing A,
873*700637cbSDimitry Andric       //   and B is the merge block for S
874*700637cbSDimitry Andric       if (Node->Merge == Dst)
875*700637cbSDimitry Andric         continue;
876*700637cbSDimitry Andric 
877*700637cbSDimitry Andric       // Entering the innermost loop’s continue construct: S is the innermost
878*700637cbSDimitry Andric       // loop construct containing A, and B is the continue target for S
879*700637cbSDimitry Andric       if (Node->Continue == Dst)
880*700637cbSDimitry Andric         continue;
881*700637cbSDimitry Andric 
882*700637cbSDimitry Andric       // TODO: what about cases branching to another case in the switch? Seems
883*700637cbSDimitry Andric       // to work, but need to double check.
884*700637cbSDimitry Andric       HasBadEdge = true;
885*700637cbSDimitry Andric     }
886*700637cbSDimitry Andric 
887*700637cbSDimitry Andric     if (!HasBadEdge)
888*700637cbSDimitry Andric       return Modified;
889*700637cbSDimitry Andric 
890*700637cbSDimitry Andric     // Create a single exit node gathering all exit edges.
891*700637cbSDimitry Andric     BasicBlock *NewExit = S.createSingleExitNode(Node->Header, Edges);
892*700637cbSDimitry Andric 
893*700637cbSDimitry Andric     // Fixup this construct's merge node to point to the new exit.
894*700637cbSDimitry Andric     // Note: this algorithm fixes inner-most divergence construct first. So
895*700637cbSDimitry Andric     // recursive structures sharing a single merge node are fixed from the
896*700637cbSDimitry Andric     // inside toward the outside.
897*700637cbSDimitry Andric     auto MergeInstructions = getMergeInstructions(*Node->Header);
898*700637cbSDimitry Andric     assert(MergeInstructions.size() == 1);
899*700637cbSDimitry Andric     Instruction *I = MergeInstructions[0];
900*700637cbSDimitry Andric     BlockAddress *BA = cast<BlockAddress>(I->getOperand(0));
901*700637cbSDimitry Andric     if (BA->getBasicBlock() == Node->Merge) {
902*700637cbSDimitry Andric       auto MergeAddress = BlockAddress::get(NewExit->getParent(), NewExit);
903*700637cbSDimitry Andric       I->setOperand(0, MergeAddress);
904*700637cbSDimitry Andric     }
905*700637cbSDimitry Andric 
906*700637cbSDimitry Andric     // Clean up of the possible dangling BockAddr operands to prevent MIR
907*700637cbSDimitry Andric     // comments about "address of removed block taken".
908*700637cbSDimitry Andric     if (!BA->isConstantUsed())
909*700637cbSDimitry Andric       BA->destroyConstant();
910*700637cbSDimitry Andric 
911*700637cbSDimitry Andric     Node->Merge = NewExit;
912*700637cbSDimitry Andric     // Regenerate the dom trees.
913*700637cbSDimitry Andric     S.invalidate();
914*700637cbSDimitry Andric     return true;
915*700637cbSDimitry Andric   }
916*700637cbSDimitry Andric 
splitCriticalEdges(Function & F)917*700637cbSDimitry Andric   bool splitCriticalEdges(Function &F) {
918*700637cbSDimitry Andric     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
919*700637cbSDimitry Andric     Splitter S(F, LI);
920*700637cbSDimitry Andric 
921*700637cbSDimitry Andric     DivergentConstruct Root;
922*700637cbSDimitry Andric     BlockSet Visited;
923*700637cbSDimitry Andric     constructDivergentConstruct(Visited, S, &*F.begin(), &Root);
924*700637cbSDimitry Andric     return fixupConstruct(S, &Root);
925*700637cbSDimitry Andric   }
926*700637cbSDimitry Andric 
927*700637cbSDimitry Andric   // Simplify branches when possible:
928*700637cbSDimitry Andric   //  - if the 2 sides of a conditional branch are the same, transforms it to an
929*700637cbSDimitry Andric   //  unconditional branch.
930*700637cbSDimitry Andric   //  - if a switch has only 2 distinct successors, converts it to a conditional
931*700637cbSDimitry Andric   //  branch.
simplifyBranches(Function & F)932*700637cbSDimitry Andric   bool simplifyBranches(Function &F) {
933*700637cbSDimitry Andric     bool Modified = false;
934*700637cbSDimitry Andric 
935*700637cbSDimitry Andric     for (BasicBlock &BB : F) {
936*700637cbSDimitry Andric       SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator());
937*700637cbSDimitry Andric       if (!SI)
938*700637cbSDimitry Andric         continue;
939*700637cbSDimitry Andric       if (SI->getNumCases() > 1)
940*700637cbSDimitry Andric         continue;
941*700637cbSDimitry Andric 
942*700637cbSDimitry Andric       Modified = true;
943*700637cbSDimitry Andric       IRBuilder<> Builder(&BB);
944*700637cbSDimitry Andric       Builder.SetInsertPoint(SI);
945*700637cbSDimitry Andric 
946*700637cbSDimitry Andric       if (SI->getNumCases() == 0) {
947*700637cbSDimitry Andric         Builder.CreateBr(SI->getDefaultDest());
948*700637cbSDimitry Andric       } else {
949*700637cbSDimitry Andric         Value *Condition =
950*700637cbSDimitry Andric             Builder.CreateCmp(CmpInst::ICMP_EQ, SI->getCondition(),
951*700637cbSDimitry Andric                               SI->case_begin()->getCaseValue());
952*700637cbSDimitry Andric         Builder.CreateCondBr(Condition, SI->case_begin()->getCaseSuccessor(),
953*700637cbSDimitry Andric                              SI->getDefaultDest());
954*700637cbSDimitry Andric       }
955*700637cbSDimitry Andric       SI->eraseFromParent();
956*700637cbSDimitry Andric     }
957*700637cbSDimitry Andric 
958*700637cbSDimitry Andric     return Modified;
959*700637cbSDimitry Andric   }
960*700637cbSDimitry Andric 
961*700637cbSDimitry Andric   // Makes sure every case target in |F| is unique. If 2 cases branch to the
962*700637cbSDimitry Andric   // same basic block, one of the targets is updated so it jumps to a new basic
963*700637cbSDimitry Andric   // block ending with a single unconditional branch to the original target.
splitSwitchCases(Function & F)964*700637cbSDimitry Andric   bool splitSwitchCases(Function &F) {
965*700637cbSDimitry Andric     bool Modified = false;
966*700637cbSDimitry Andric 
967*700637cbSDimitry Andric     for (BasicBlock &BB : F) {
968*700637cbSDimitry Andric       SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator());
969*700637cbSDimitry Andric       if (!SI)
970*700637cbSDimitry Andric         continue;
971*700637cbSDimitry Andric 
972*700637cbSDimitry Andric       BlockSet Seen;
973*700637cbSDimitry Andric       Seen.insert(SI->getDefaultDest());
974*700637cbSDimitry Andric 
975*700637cbSDimitry Andric       auto It = SI->case_begin();
976*700637cbSDimitry Andric       while (It != SI->case_end()) {
977*700637cbSDimitry Andric         BasicBlock *Target = It->getCaseSuccessor();
978*700637cbSDimitry Andric         if (Seen.count(Target) == 0) {
979*700637cbSDimitry Andric           Seen.insert(Target);
980*700637cbSDimitry Andric           ++It;
981*700637cbSDimitry Andric           continue;
982*700637cbSDimitry Andric         }
983*700637cbSDimitry Andric 
984*700637cbSDimitry Andric         Modified = true;
985*700637cbSDimitry Andric         BasicBlock *NewTarget =
986*700637cbSDimitry Andric             BasicBlock::Create(F.getContext(), "new.sw.case", &F);
987*700637cbSDimitry Andric         IRBuilder<> Builder(NewTarget);
988*700637cbSDimitry Andric         Builder.CreateBr(Target);
989*700637cbSDimitry Andric         SI->addCase(It->getCaseValue(), NewTarget);
990*700637cbSDimitry Andric         It = SI->removeCase(It);
991*700637cbSDimitry Andric       }
992*700637cbSDimitry Andric     }
993*700637cbSDimitry Andric 
994*700637cbSDimitry Andric     return Modified;
995*700637cbSDimitry Andric   }
996*700637cbSDimitry Andric 
997*700637cbSDimitry Andric   // Removes blocks not contributing to any structured CFG. This assumes there
998*700637cbSDimitry Andric   // is no PHI nodes.
removeUselessBlocks(Function & F)999*700637cbSDimitry Andric   bool removeUselessBlocks(Function &F) {
1000*700637cbSDimitry Andric     std::vector<BasicBlock *> ToRemove;
1001*700637cbSDimitry Andric 
1002*700637cbSDimitry Andric     auto MergeBlocks = getMergeBlocks(F);
1003*700637cbSDimitry Andric     auto ContinueBlocks = getContinueBlocks(F);
1004*700637cbSDimitry Andric 
1005*700637cbSDimitry Andric     for (BasicBlock &BB : F) {
1006*700637cbSDimitry Andric       if (BB.size() != 1)
1007*700637cbSDimitry Andric         continue;
1008*700637cbSDimitry Andric 
1009*700637cbSDimitry Andric       if (isa<ReturnInst>(BB.getTerminator()))
1010*700637cbSDimitry Andric         continue;
1011*700637cbSDimitry Andric 
1012*700637cbSDimitry Andric       if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
1013*700637cbSDimitry Andric         continue;
1014*700637cbSDimitry Andric 
1015*700637cbSDimitry Andric       if (BB.getUniqueSuccessor() == nullptr)
1016*700637cbSDimitry Andric         continue;
1017*700637cbSDimitry Andric 
1018*700637cbSDimitry Andric       BasicBlock *Successor = BB.getUniqueSuccessor();
1019*700637cbSDimitry Andric       std::vector<BasicBlock *> Predecessors(predecessors(&BB).begin(),
1020*700637cbSDimitry Andric                                              predecessors(&BB).end());
1021*700637cbSDimitry Andric       for (BasicBlock *Predecessor : Predecessors)
1022*700637cbSDimitry Andric         replaceBranchTargets(Predecessor, &BB, Successor);
1023*700637cbSDimitry Andric       ToRemove.push_back(&BB);
1024*700637cbSDimitry Andric     }
1025*700637cbSDimitry Andric 
1026*700637cbSDimitry Andric     for (BasicBlock *BB : ToRemove)
1027*700637cbSDimitry Andric       BB->eraseFromParent();
1028*700637cbSDimitry Andric 
1029*700637cbSDimitry Andric     return ToRemove.size() != 0;
1030*700637cbSDimitry Andric   }
1031*700637cbSDimitry Andric 
// Final fix-up: adds a selection-merge intrinsic (via createOpSelectMerge) to
// blocks that branch divergently without a header. A block qualifies when:
//  - it is not already a header block, and
//  - it has at least two successors that are not merge, continue or header
//    blocks.
// The merge target is chosen from the block's immediate post-dominator:
//  - If that node has no block (Merge == nullptr, presumably the
//    post-dominator tree's virtual root, e.g. when paths end in different
//    exits), the header's first successor is used directly.
//  - Otherwise the immediate post-dominator is split, and the new
//    "new.merge" block becomes the merge target.
// Returns true if at least one merge intrinsic was created.
//
// NOTE(review): DT and PDT are computed once before the loop and are not
// updated after splitBasicBlockBefore creates new blocks, so later iterations
// may query stale trees. MergeBlocks / ContinueBlocks / HeaderBlocks are
// likewise not updated with the merges added here.
bool addHeaderToRemainingDivergentDAG(Function &F) {
  bool Modified = false;

  auto MergeBlocks = getMergeBlocks(F);
  auto ContinueBlocks = getContinueBlocks(F);
  auto HeaderBlocks = getHeaderBlocks(F);

  DomTreeBuilder::BBDomTree DT;
  DomTreeBuilder::BBPostDomTree PDT;
  PDT.recalculate(F);
  DT.recalculate(F);

  for (BasicBlock &BB : F) {
    if (HeaderBlocks.count(&BB) != 0)
      continue;
    if (succ_size(&BB) < 2)
      continue;

    // Count outgoing edges that are not already exits handled by an existing
    // construct (merge/continue targets) and do not enter another header.
    size_t CandidateEdges = 0;
    for (BasicBlock *Successor : successors(&BB)) {
      if (MergeBlocks.count(Successor) != 0 ||
          ContinueBlocks.count(Successor) != 0)
        continue;
      if (HeaderBlocks.count(Successor) != 0)
        continue;
      CandidateEdges += 1;
    }

    if (CandidateEdges <= 1)
      continue;

    BasicBlock *Header = &BB;
    // May be nullptr when the immediate post-dominator has no block.
    BasicBlock *Merge = PDT.getNode(&BB)->getIDom()->getBlock();

    // NOTE(review): DT.dominates(Header, Header) is always true, so this
    // callback returns false for the very first node visited (Header).
    // Assuming visit() does not expand the successors of a node for which the
    // callback returns false, the traversal stops immediately and HasBadBlock
    // can never become true. The intended predicates were probably different
    // (e.g. a negated dominance test); confirm before relying on this filter.
    bool HasBadBlock = false;
    visit(*Header, [&](const BasicBlock *Node) {
      if (DT.dominates(Header, Node))
        return false;
      if (PDT.dominates(Merge, Node))
        return false;
      if (Node == Header || Node == Merge)
        return true;

      HasBadBlock |= MergeBlocks.count(Node) != 0 ||
                     ContinueBlocks.count(Node) != 0 ||
                     HeaderBlocks.count(Node) != 0;
      return !HasBadBlock;
    });

    if (HasBadBlock)
      continue;

    Modified = true;

    if (Merge == nullptr) {
      // No block post-dominates the header: fall back to its first successor
      // as the merge target (no block is split in this case).
      Merge = *successors(Header).begin();
      IRBuilder<> Builder(Header);
      Builder.SetInsertPoint(Header->getTerminator());

      auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
      createOpSelectMerge(&Builder, MergeAddress);
      continue;
    }

    // Split Merge before its terminator, or before a merge instruction
    // directly preceding the terminator, so that instruction stays with the
    // terminator in Merge. splitBasicBlockBefore moves the earlier
    // instructions into NewMerge, which takes over Merge's predecessors and
    // falls through to Merge.
    Instruction *SplitInstruction = Merge->getTerminator();
    if (isMergeInstruction(SplitInstruction->getPrevNode()))
      SplitInstruction = SplitInstruction->getPrevNode();
    BasicBlock *NewMerge =
        Merge->splitBasicBlockBefore(SplitInstruction, "new.merge");

    // Emit the selection merge just before the header's terminator.
    IRBuilder<> Builder(Header);
    Builder.SetInsertPoint(Header->getTerminator());

    auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge);
    createOpSelectMerge(&Builder, MergeAddress);
  }

  return Modified;
}
1111*700637cbSDimitry Andric 
public:
  // Pass identifier handed (by reference) to the FunctionPass base
  // constructor; defined and zero-initialized at namespace scope after the
  // class.
  static char ID;

  // Constructs the legacy function pass, identified by ID.
  SPIRVStructurizer() : FunctionPass(ID) {}
1116*700637cbSDimitry Andric 
// Runs the structurization steps below, in this exact order, on F. Each step
// is always executed (Modified |= ... evaluates every call), regardless of
// whether earlier steps changed anything.
// Returns true if any step modified the function.
virtual bool runOnFunction(Function &F) override {
  bool Modified = false;

  // In LLVM, switches are allowed to have several cases branching to the same
  // basic block. This is allowed in SPIR-V, but can make structurizing SPIR-V
  // harder, so first remove such edge cases.
  Modified |= splitSwitchCases(F);

  // LLVM allows conditional branches to have both sides jumping to the same
  // block. It also allows switches to have a single default, or just one
  // case. Clean this up now.
  Modified |= simplifyBranches(F);

  // At this stage, we should have a reducible CFG with cycles.
  // STEP 1: Add OpLoopMerge instructions to loop headers.
  Modified |= addMergeForLoops(F);

  // STEP 2: Add OpSelectionMerge to each node with an in-degree >= 2.
  Modified |= addMergeForNodesWithMultiplePredecessors(F);

  // STEP 3:
  // Sort selection merges so the largest construct goes first.
  // This simplifies the next step.
  Modified |= sortSelectionMergeHeaders(F);

  // STEP 4: At this stage, a single basic block can have multiple
  // OpLoopMerge/OpSelectionMerge instructions. Split such blocks so each
  // BB has a single merge instruction.
  Modified |= splitBlocksWithMultipleHeaders(F);

  // STEP 5: The previous steps added merge blocks for the loops and for the
  // natural merge blocks (in-degree >= 2). What remains are conditions with
  // an exiting branch (return, unreachable). In such cases, we must start
  // from the header and add headers to divergent constructs that have none.
  Modified |= addMergeForDivergentBlocks(F);

  // STEP 6: At this stage, we have several divergent constructs, each defined
  // by a header and a merge block. But their boundaries have no constraints:
  // a construct exit could lie outside of the parent construct's exit. Such
  // edges are called critical edges. We need to split those edges into
  // several parts, each part exiting the parent's construct through its merge
  // block.
  Modified |= splitCriticalEdges(F);

  // STEP 7: The previous steps possibly created a lot of "proxy" blocks:
  // blocks with a single unconditional branch, used to create a valid
  // divergent construct tree. Some of them are still required (e.g. nodes
  // allowing a valid exit through the parent's merge block). But some are
  // left-overs of past transformations, and could cause actual validation
  // issues. E.g. the SPIR-V spec allows a construct to break to the parent's
  // loop construct without an OpSelectionMerge, but this requires a straight
  // jump. If a proxy block lies between the conditional branch and the
  // parent's merge, the CFG is not valid.
  Modified |= removeUselessBlocks(F);

  // STEP 8: Final fix-up step: our tree boundaries are correct, but some
  // blocks are branching with no header. Those are often simple conditional
  // branches with 1 or 2 returning edges. Add a header for those.
  Modified |= addHeaderToRemainingDivergentDAG(F);

  // STEP 9: Sort basic blocks to match both the LLVM & SPIR-V requirements.
  Modified |= sortBlocks(F);

  return Modified;
}
1182*700637cbSDimitry Andric 
// Declares this pass's legacy-pass-manager dependencies.
// Required: the dominator tree, loop info and the SPIR-V convergence region
// analysis. Preserved: the convergence region analysis.
// NOTE(review): the pass creates, splits and erases basic blocks, so confirm
// that the convergence region info really remains valid afterwards. Also,
// LoopSimplify is listed in this file's INITIALIZE_PASS_DEPENDENCY list but
// is not requested here.
void getAnalysisUsage(AnalysisUsage &AU) const override {
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();

  AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
  FunctionPass::getAnalysisUsage(AU);
}
1191*700637cbSDimitry Andric 
createOpSelectMerge(IRBuilder<> * Builder,BlockAddress * MergeAddress)1192*700637cbSDimitry Andric   void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
1193*700637cbSDimitry Andric     Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();
1194*700637cbSDimitry Andric 
1195*700637cbSDimitry Andric     MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
1196*700637cbSDimitry Andric 
1197*700637cbSDimitry Andric     ConstantInt *BranchHint = ConstantInt::get(Builder->getInt32Ty(), 0);
1198*700637cbSDimitry Andric 
1199*700637cbSDimitry Andric     if (MDNode) {
1200*700637cbSDimitry Andric       assert(MDNode->getNumOperands() == 2 &&
1201*700637cbSDimitry Andric              "invalid metadata hlsl.controlflow.hint");
1202*700637cbSDimitry Andric       BranchHint = mdconst::extract<ConstantInt>(MDNode->getOperand(1));
1203*700637cbSDimitry Andric     }
1204*700637cbSDimitry Andric 
1205*700637cbSDimitry Andric     SmallVector<Value *, 2> Args = {MergeAddress, BranchHint};
1206*700637cbSDimitry Andric 
1207*700637cbSDimitry Andric     Builder->CreateIntrinsic(Intrinsic::spv_selection_merge,
1208*700637cbSDimitry Andric                              {MergeAddress->getType()}, Args);
1209*700637cbSDimitry Andric   }
1210*700637cbSDimitry Andric };
1211*700637cbSDimitry Andric } // anonymous namespace
1212*700637cbSDimitry Andric 
// Definition of the legacy pass identifier declared in the class.
char SPIRVStructurizer::ID = 0;

// Registers the legacy pass under the command-line name "spirv-structurizer"
// and declares the passes it depends on. The two trailing `false` arguments
// are INITIALIZE_PASS's "CFG only" and "is analysis" flags.
INITIALIZE_PASS_BEGIN(SPIRVStructurizer, "spirv-structurizer",
                      "structurize SPIRV", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)

INITIALIZE_PASS_END(SPIRVStructurizer, "spirv-structurizer",
                    "structurize SPIRV", false, false)
1224*700637cbSDimitry Andric 
1225*700637cbSDimitry Andric FunctionPass *llvm::createSPIRVStructurizerPass() {
1226*700637cbSDimitry Andric   return new SPIRVStructurizer();
1227*700637cbSDimitry Andric }
1228*700637cbSDimitry Andric 
run(Function & F,FunctionAnalysisManager & AF)1229*700637cbSDimitry Andric PreservedAnalyses SPIRVStructurizerWrapper::run(Function &F,
1230*700637cbSDimitry Andric                                                 FunctionAnalysisManager &AF) {
1231*700637cbSDimitry Andric 
1232*700637cbSDimitry Andric   auto FPM = legacy::FunctionPassManager(F.getParent());
1233*700637cbSDimitry Andric   FPM.add(createSPIRVStructurizerPass());
1234*700637cbSDimitry Andric 
1235*700637cbSDimitry Andric   if (!FPM.run(F))
1236*700637cbSDimitry Andric     return PreservedAnalyses::all();
1237*700637cbSDimitry Andric   PreservedAnalyses PA;
1238*700637cbSDimitry Andric   PA.preserveSet<CFGAnalyses>();
1239*700637cbSDimitry Andric   return PA;
1240*700637cbSDimitry Andric }
1241