xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===-- SPIRVStructurizer.cpp ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "Analysis/SPIRVConvergenceRegionAnalysis.h"
12 #include "SPIRV.h"
13 #include "SPIRVStructurizerWrapper.h"
14 #include "SPIRVSubtarget.h"
15 #include "SPIRVUtils.h"
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/CodeGen/IntrinsicLowering.h"
20 #include "llvm/IR/CFG.h"
21 #include "llvm/IR/Dominators.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/IntrinsicInst.h"
24 #include "llvm/IR/Intrinsics.h"
25 #include "llvm/IR/IntrinsicsSPIRV.h"
26 #include "llvm/IR/LegacyPassManager.h"
27 #include "llvm/InitializePasses.h"
28 #include "llvm/Transforms/Utils.h"
29 #include "llvm/Transforms/Utils/Cloning.h"
30 #include "llvm/Transforms/Utils/LoopSimplify.h"
31 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
32 #include <stack>
33 #include <unordered_set>
34 
35 using namespace llvm;
36 using namespace SPIRV;
37 
38 using BlockSet = std::unordered_set<BasicBlock *>;
39 using Edge = std::pair<BasicBlock *, BasicBlock *>;
40 
41 // Helper function to do a partial order visit from the block |Start|, calling
42 // |Op| on each visited node.
43 static void partialOrderVisit(BasicBlock &Start,
44                               std::function<bool(BasicBlock *)> Op) {
45   PartialOrderingVisitor V(*Start.getParent());
46   V.partialOrderVisit(Start, Op);
47 }
48 
49 // Returns the exact convergence region in the tree defined by `Node` for which
50 // `BB` is the header, nullptr otherwise.
51 static const ConvergenceRegion *
52 getRegionForHeader(const ConvergenceRegion *Node, BasicBlock *BB) {
53   if (Node->Entry == BB)
54     return Node;
55 
56   for (auto *Child : Node->Children) {
57     const auto *CR = getRegionForHeader(Child, BB);
58     if (CR != nullptr)
59       return CR;
60   }
61   return nullptr;
62 }
63 
64 // Returns the single BasicBlock exiting the convergence region `CR`,
65 // nullptr if no such exit exists.
66 static BasicBlock *getExitFor(const ConvergenceRegion *CR) {
67   std::unordered_set<BasicBlock *> ExitTargets;
68   for (BasicBlock *Exit : CR->Exits) {
69     for (BasicBlock *Successor : successors(Exit)) {
70       if (CR->Blocks.count(Successor) == 0)
71         ExitTargets.insert(Successor);
72     }
73   }
74 
75   assert(ExitTargets.size() <= 1);
76   if (ExitTargets.size() == 0)
77     return nullptr;
78 
79   return *ExitTargets.begin();
80 }
81 
82 // Returns the merge block designated by I if I is a merge instruction, nullptr
83 // otherwise.
84 static BasicBlock *getDesignatedMergeBlock(Instruction *I) {
85   IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(I);
86   if (II == nullptr)
87     return nullptr;
88 
89   if (II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
90       II->getIntrinsicID() != Intrinsic::spv_selection_merge)
91     return nullptr;
92 
93   BlockAddress *BA = cast<BlockAddress>(II->getOperand(0));
94   return BA->getBasicBlock();
95 }
96 
97 // Returns the continue block designated by I if I is an OpLoopMerge, nullptr
98 // otherwise.
99 static BasicBlock *getDesignatedContinueBlock(Instruction *I) {
100   IntrinsicInst *II = dyn_cast_or_null<IntrinsicInst>(I);
101   if (II == nullptr)
102     return nullptr;
103 
104   if (II->getIntrinsicID() != Intrinsic::spv_loop_merge)
105     return nullptr;
106 
107   BlockAddress *BA = cast<BlockAddress>(II->getOperand(1));
108   return BA->getBasicBlock();
109 }
110 
111 // Returns true if Header has one merge instruction which designated Merge as
112 // merge block.
113 static bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge) {
114   for (auto &I : Header) {
115     BasicBlock *MB = getDesignatedMergeBlock(&I);
116     if (MB == &Merge)
117       return true;
118   }
119   return false;
120 }
121 
122 // Returns true if the BB has one OpLoopMerge instruction.
123 static bool hasLoopMergeInstruction(BasicBlock &BB) {
124   for (auto &I : BB)
125     if (getDesignatedContinueBlock(&I))
126       return true;
127   return false;
128 }
129 
130 // Returns true is I is an OpSelectionMerge or OpLoopMerge instruction, false
131 // otherwise.
132 static bool isMergeInstruction(Instruction *I) {
133   return getDesignatedMergeBlock(I) != nullptr;
134 }
135 
136 // Returns all blocks in F having at least one OpLoopMerge or OpSelectionMerge
137 // instruction.
138 static SmallPtrSet<BasicBlock *, 2> getHeaderBlocks(Function &F) {
139   SmallPtrSet<BasicBlock *, 2> Output;
140   for (BasicBlock &BB : F) {
141     for (Instruction &I : BB) {
142       if (getDesignatedMergeBlock(&I) != nullptr)
143         Output.insert(&BB);
144     }
145   }
146   return Output;
147 }
148 
149 // Returns all basic blocks in |F| referenced by at least 1
150 // OpSelectionMerge/OpLoopMerge instruction.
151 static SmallPtrSet<BasicBlock *, 2> getMergeBlocks(Function &F) {
152   SmallPtrSet<BasicBlock *, 2> Output;
153   for (BasicBlock &BB : F) {
154     for (Instruction &I : BB) {
155       BasicBlock *MB = getDesignatedMergeBlock(&I);
156       if (MB != nullptr)
157         Output.insert(MB);
158     }
159   }
160   return Output;
161 }
162 
163 // Return all the merge instructions contained in BB.
164 // Note: the SPIR-V spec doesn't allow a single BB to contain more than 1 merge
165 // instruction, but this can happen while we structurize the CFG.
166 static std::vector<Instruction *> getMergeInstructions(BasicBlock &BB) {
167   std::vector<Instruction *> Output;
168   for (Instruction &I : BB)
169     if (isMergeInstruction(&I))
170       Output.push_back(&I);
171   return Output;
172 }
173 
174 // Returns all basic blocks in |F| referenced as continue target by at least 1
175 // OpLoopMerge instruction.
176 static SmallPtrSet<BasicBlock *, 2> getContinueBlocks(Function &F) {
177   SmallPtrSet<BasicBlock *, 2> Output;
178   for (BasicBlock &BB : F) {
179     for (Instruction &I : BB) {
180       BasicBlock *MB = getDesignatedContinueBlock(&I);
181       if (MB != nullptr)
182         Output.insert(MB);
183     }
184   }
185   return Output;
186 }
187 
188 // Do a preorder traversal of the CFG starting from the BB |Start|.
189 // point. Calls |op| on each basic block encountered during the traversal.
190 static void visit(BasicBlock &Start, std::function<bool(BasicBlock *)> op) {
191   std::stack<BasicBlock *> ToVisit;
192   SmallPtrSet<BasicBlock *, 8> Seen;
193 
194   ToVisit.push(&Start);
195   Seen.insert(ToVisit.top());
196   while (ToVisit.size() != 0) {
197     BasicBlock *BB = ToVisit.top();
198     ToVisit.pop();
199 
200     if (!op(BB))
201       continue;
202 
203     for (auto Succ : successors(BB)) {
204       if (Seen.contains(Succ))
205         continue;
206       ToVisit.push(Succ);
207       Seen.insert(Succ);
208     }
209   }
210 }
211 
// Replaces the conditional and unconditional branch targets of |BB| by
// |NewTarget| if the target was |OldTarget|. This function also makes sure the
// associated merge instruction gets updated accordingly.
//
// The terminator of |BB| must be a BranchInst (enforced by cast<>). If, after
// the replacement, a conditional branch ends up with two identical successors,
// it is replaced by an unconditional branch, and a spv_selection_merge
// located right before the branch is erased.
static void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
                                   BasicBlock *NewTarget) {
  auto *BI = cast<BranchInst>(BB->getTerminator());

  // 1. Replace all matching successors.
  for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
    if (BI->getSuccessor(i) == OldTarget)
      BI->setSuccessor(i, NewTarget);
  }

  // Branch was unconditional, no fixup required.
  if (BI->isUnconditional())
    return;

  // Branch had 2 successors, maybe now both are the same?
  if (BI->getSuccessor(0) != BI->getSuccessor(1))
    return;

  // 2. Both successors are identical: replace the conditional branch with an
  // unconditional one.
  // Note: we may end up here because the original IR had such branches.
  // This means the common successor is not necessarily equal to NewTarget.
  IRBuilder<> Builder(BB);
  Builder.SetInsertPoint(BI);
  Builder.CreateBr(BI->getSuccessor(0));
  BI->eraseFromParent();

  // The branch was the only instruction, nothing else to do.
  if (BB->size() == 1)
    return;

  // Otherwise, we need to check: was there an OpSelectionMerge before this
  // branch? If we removed the OpBranchConditional, we must also remove the
  // OpSelectionMerge. This is not valid for OpLoopMerge, which is left
  // untouched.
  IntrinsicInst *II =
      dyn_cast<IntrinsicInst>(BB->getTerminator()->getPrevNode());
  if (!II || II->getIntrinsicID() != Intrinsic::spv_selection_merge)
    return;

  // Erase the merge instruction, then release its operand constant if nothing
  // else uses it anymore.
  Constant *C = cast<Constant>(II->getOperand(0));
  II->eraseFromParent();
  if (!C->isConstantUsed())
    C->destroyConstant();
}
257 
258 // Replaces the target of branch instruction in |BB| with |NewTarget| if it
259 // was |OldTarget|. This function also fixes the associated merge instruction.
260 // Note: this function does not simplify branching instructions, it only updates
261 // targets. See also: simplifyBranches.
262 static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
263                                  BasicBlock *NewTarget) {
264   auto *T = BB->getTerminator();
265   if (isa<ReturnInst>(T))
266     return;
267 
268   if (isa<BranchInst>(T))
269     return replaceIfBranchTargets(BB, OldTarget, NewTarget);
270 
271   if (auto *SI = dyn_cast<SwitchInst>(T)) {
272     for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
273       if (SI->getSuccessor(i) == OldTarget)
274         SI->setSuccessor(i, NewTarget);
275     }
276     return;
277   }
278 
279   assert(false && "Unhandled terminator type.");
280 }
281 
282 namespace {
283 // Given a reducible CFG, produces a structurized CFG in the SPIR-V sense,
284 // adding merge instructions when required.
285 class SPIRVStructurizer : public FunctionPass {
286   struct DivergentConstruct;
287   // Represents a list of condition/loops/switch constructs.
288   // See SPIR-V 2.11.2. Structured Control-flow Constructs for the list of
289   // constructs.
290   using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
291 
  // Represents a divergent construct in the SPIR-V sense.
  // Such constructs are represented by a header (entry), a merge block (exit),
  // and possibly a continue block (back-edge). A construct can contain other
  // constructs, but their boundaries do not cross.
  struct DivergentConstruct {
    // Block holding the merge instruction which opens this construct.
    BasicBlock *Header = nullptr;
    // Merge block designated by the merge instruction of the header.
    BasicBlock *Merge = nullptr;
    // Continue block designated by the header; stays nullptr when the header
    // has no OpLoopMerge.
    BasicBlock *Continue = nullptr;

    // Enclosing construct (non-owning); nullptr when there is none.
    DivergentConstruct *Parent = nullptr;
    // Constructs nested in this one; owned by this node.
    ConstructList Children;
  };
304 
  // A helper class to clean the construct boundaries.
  // It is used to gather the list of blocks that should belong to each
  // divergent construct, and possibly modify CFG edges when exits would cross
  // the boundary of multiple constructs.
309   struct Splitter {
310     Function &F;
311     LoopInfo &LI;
312     DomTreeBuilder::BBDomTree DT;
313     DomTreeBuilder::BBPostDomTree PDT;
314 
    // Binds the function and its loop info, then computes both dominator
    // trees through invalidate().
    Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }
316 
317     void invalidate() {
318       PDT.recalculate(F);
319       DT.recalculate(F);
320     }
321 
322     // Returns the list of blocks that belong to a SPIR-V loop construct,
323     // including the continue construct.
324     std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header,
325                                                      BasicBlock *Merge) {
326       assert(DT.dominates(Header, Merge));
327       std::vector<BasicBlock *> Output;
328       partialOrderVisit(*Header, [&](BasicBlock *BB) {
329         if (BB == Merge)
330           return false;
331         if (DT.dominates(Merge, BB) || !DT.dominates(Header, BB))
332           return false;
333         Output.push_back(BB);
334         return true;
335       });
336       return Output;
337     }
338 
339     // Returns the list of blocks that belong to a SPIR-V selection construct.
340     std::vector<BasicBlock *>
341     getSelectionConstructBlocks(DivergentConstruct *Node) {
342       assert(DT.dominates(Node->Header, Node->Merge));
343       BlockSet OutsideBlocks;
344       OutsideBlocks.insert(Node->Merge);
345 
346       for (DivergentConstruct *It = Node->Parent; It != nullptr;
347            It = It->Parent) {
348         OutsideBlocks.insert(It->Merge);
349         if (It->Continue)
350           OutsideBlocks.insert(It->Continue);
351       }
352 
353       std::vector<BasicBlock *> Output;
354       partialOrderVisit(*Node->Header, [&](BasicBlock *BB) {
355         if (OutsideBlocks.count(BB) != 0)
356           return false;
357         if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
358           return false;
359         Output.push_back(BB);
360         return true;
361       });
362       return Output;
363     }
364 
365     // Returns the list of blocks that belong to a SPIR-V switch construct.
366     std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header,
367                                                        BasicBlock *Merge) {
368       assert(DT.dominates(Header, Merge));
369 
370       std::vector<BasicBlock *> Output;
371       partialOrderVisit(*Header, [&](BasicBlock *BB) {
372         // the blocks structurally dominated by a switch header,
373         if (!DT.dominates(Header, BB))
374           return false;
375         // excluding blocks structurally dominated by the switch header’s merge
376         // block.
377         if (DT.dominates(Merge, BB) || BB == Merge)
378           return false;
379         Output.push_back(BB);
380         return true;
381       });
382       return Output;
383     }
384 
385     // Returns the list of blocks that belong to a SPIR-V case construct.
386     std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target,
387                                                      BasicBlock *Merge) {
388       assert(DT.dominates(Target, Merge));
389 
390       std::vector<BasicBlock *> Output;
391       partialOrderVisit(*Target, [&](BasicBlock *BB) {
392         // the blocks structurally dominated by an OpSwitch Target or Default
393         // block
394         if (!DT.dominates(Target, BB))
395           return false;
396         // excluding the blocks structurally dominated by the OpSwitch
397         // construct’s corresponding merge block.
398         if (DT.dominates(Merge, BB) || BB == Merge)
399           return false;
400         Output.push_back(BB);
401         return true;
402       });
403       return Output;
404     }
405 
    // Splits the given edges by creating proxy nodes so that each source block
    // appears at most once as the source of a returned edge.
    //
    // For every edge whose source was already seen as the source of a previous
    // edge, a new block "<Src>.new.src" is created, the branch targets of Src
    // equal to Dst are redirected to it, and it unconditionally branches to
    // Dst. The returned list contains the same edges, in the same order, with
    // such sources replaced by their proxy block.
    //
    // clang-format off
    //
    // In SPIR-V, constructs must have a single exit/merge.
    // Given nodes A and B in the construct, a node C outside, and the following edges.
    //  A -> C
    //  B -> C
    //
    // In such cases, we must create a new exit node D, that belongs to the construct to make it viable:
    // A -> D -> C
    // B -> D -> C
    //
    // This is fine (assuming C has no PHI nodes), but requires handling the merge instruction here.
    // By adding a proxy node, we create a regular divergent shape which can easily be regularized later on.
    // A -> D -> D1 -> C
    // B -> D -> D2 -> C
    //
    // A, B, D belong to the construct. D is the exit. D1 and D2 are empty.
    //
    // clang-format on
    std::vector<Edge>
    createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
      // Sources already used by a previously processed edge.
      std::unordered_set<BasicBlock *> Seen;
      std::vector<Edge> Output;
      Output.reserve(Edges.size());

      // |Edges| is a local copy: Src is rewritten in place when aliased.
      for (auto &[Src, Dst] : Edges) {
        auto [Iterator, Inserted] = Seen.insert(Src);
        if (!Inserted) {
          // Src already a source node. Cannot have 2 edges from A to B.
          // Creating alias source block.
          BasicBlock *NewSrc = BasicBlock::Create(
              F.getContext(), Src->getName() + ".new.src", &F);
          replaceBranchTargets(Src, Dst, NewSrc);
          IRBuilder<> Builder(NewSrc);
          Builder.CreateBr(Dst);
          Src = NewSrc;
        }

        Output.emplace_back(Src, Dst);
      }

      return Output;
    }
452 
453     AllocaInst *CreateVariable(Function &F, Type *Type,
454                                BasicBlock::iterator Position) {
455       const DataLayout &DL = F.getDataLayout();
456       return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
457                             Position);
458     }
459 
    // Given a construct defined by |Header|, and a list of exiting edges
    // |Edges|, creates a new single exit node, fixing up those edges.
    //
    // The new block is named "<Header>.new.exit". Every edge source is
    // redirected to it, and it dispatches to the original destinations:
    //  - 1 destination: unconditional branch.
    //  - 2 destinations: conditional branch on the value of a selector
    //    variable.
    //  - more: switch on the selector variable, the first destination being
    //    the default case.
    // The selector is an i32 alloca placed in the first block of the function;
    // each edge source stores the index of its destination into it.
    // Returns the newly created exit block.
    BasicBlock *createSingleExitNode(BasicBlock *Header,
                                     std::vector<Edge> &Edges) {

      // Make sure each source block appears only once as an edge source.
      std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);

      // Unique destinations, in order of first appearance, and the constant
      // index assigned to each of them.
      std::vector<BasicBlock *> Dsts;
      std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
      auto NewExit = BasicBlock::Create(F.getContext(),
                                        Header->getName() + ".new.exit", &F);
      IRBuilder<> ExitBuilder(NewExit);
      for (auto &[Src, Dst] : FixedEdges) {
        if (DstToIndex.count(Dst) != 0)
          continue;
        DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size()));
        Dsts.push_back(Dst);
      }

      // Single destination: no selector required, simply funnel all the edges
      // through the new exit.
      if (Dsts.size() == 1) {
        for (auto &[Src, Dst] : FixedEdges) {
          replaceBranchTargets(Src, Dst, NewExit);
        }
        ExitBuilder.CreateBr(Dsts[0]);
        return NewExit;
      }

      // Multiple destinations: each source records the index of its
      // destination in the selector variable before jumping to the new exit.
      AllocaInst *Variable = CreateVariable(F, ExitBuilder.getInt32Ty(),
                                            F.begin()->getFirstInsertionPt());
      for (auto &[Src, Dst] : FixedEdges) {
        IRBuilder<> B2(Src);
        // The store is emitted at the first insertion point of the source.
        B2.SetInsertPoint(Src->getFirstInsertionPt());
        B2.CreateStore(DstToIndex[Dst], Variable);
        replaceBranchTargets(Src, Dst, NewExit);
      }

      Value *Load = ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);

      // If we can avoid an OpSwitch, generate an OpBranch. Reason is some
      // OpBranch are allowed to exist without a new OpSelectionMerge if one of
      // the branch is the parent's merge node, while OpSwitches are not.
      if (Dsts.size() == 2) {
        Value *Condition =
            ExitBuilder.CreateCmp(CmpInst::ICMP_EQ, DstToIndex[Dsts[0]], Load);
        ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
        return NewExit;
      }

      // General case: the first destination is the default target, the others
      // get an explicit case.
      SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
      for (BasicBlock *BB : drop_begin(Dsts))
        Sw->addCase(DstToIndex[BB], BB);
      return NewExit;
    }
513   };
514 
515   /// Create a value in BB set to the value associated with the branch the block
516   /// terminator will take.
517   Value *createExitVariable(
518       BasicBlock *BB,
519       const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
520     auto *T = BB->getTerminator();
521     if (isa<ReturnInst>(T))
522       return nullptr;
523 
524     IRBuilder<> Builder(BB);
525     Builder.SetInsertPoint(T);
526 
527     if (auto *BI = dyn_cast<BranchInst>(T)) {
528 
529       BasicBlock *LHSTarget = BI->getSuccessor(0);
530       BasicBlock *RHSTarget =
531           BI->isConditional() ? BI->getSuccessor(1) : nullptr;
532 
533       Value *LHS = TargetToValue.lookup(LHSTarget);
534       Value *RHS = TargetToValue.lookup(RHSTarget);
535 
536       if (LHS == nullptr || RHS == nullptr)
537         return LHS == nullptr ? RHS : LHS;
538       return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
539     }
540 
541     // TODO: add support for switch cases.
542     llvm_unreachable("Unhandled terminator type.");
543   }
544 
545   // Creates a new basic block in F with a single OpUnreachable instruction.
546   BasicBlock *CreateUnreachable(Function &F) {
547     BasicBlock *BB = BasicBlock::Create(F.getContext(), "unreachable", &F);
548     IRBuilder<> Builder(BB);
549     Builder.CreateUnreachable();
550     return BB;
551   }
552 
  // Add OpLoopMerge instruction on cycles.
  //
  // For each loop header of F which is also the entry of a convergence
  // region, inserts a spv_loop_merge intrinsic before the terminator of the
  // header. Its operands are the merge block (the exit of the region), the
  // continue block (the loop latch), followed by the loop control operands
  // derived from the loop metadata.
  // Returns true if at least one instruction was added.
  bool addMergeForLoops(Function &F) {
    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    auto *TopLevelRegion =
        getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
            .getRegionInfo()
            .getTopLevelRegion();

    bool Modified = false;
    for (auto &BB : F) {
      // Not a loop header. Ignoring for now.
      if (!LI.isLoopHeader(&BB))
        continue;
      auto *L = LI.getLoopFor(&BB);

      // This loop header is not the entrance of a convergence region. Ignoring
      // this block.
      auto *CR = getRegionForHeader(TopLevelRegion, &BB);
      if (CR == nullptr)
        continue;

      IRBuilder<> Builder(&BB);

      auto *Merge = getExitFor(CR);
      // We are indeed in a loop, but there are no exits (infinite loop).
      // This could be caused by a bad shader, but also could be an artifact
      // from an earlier optimization. It is not always clear if structurally
      // reachable means runtime reachable, so we cannot error-out. What we must
      // do however is to make it legal from the SPIR-V point of view, hence
      // adding an unreachable merge block.
      if (Merge == nullptr) {
        // The header is expected to end with an unconditional branch here.
        BranchInst *Br = cast<BranchInst>(BB.getTerminator());
        assert(Br->isUnconditional());

        // Replace the branch with a conditional branch on `false`: the new
        // unreachable block is the `true` target, the original successor the
        // `false` target.
        Merge = CreateUnreachable(F);
        Builder.SetInsertPoint(Br);
        Builder.CreateCondBr(Builder.getFalse(), Merge, Br->getSuccessor(0));
        Br->eraseFromParent();
      }

      // NOTE(review): the result of getLoopLatch() is dereferenced below
      // without a null check; presumably loops always have a single latch
      // when this runs (e.g. loop-simplify form) - confirm.
      auto *Continue = L->getLoopLatch();

      Builder.SetInsertPoint(BB.getTerminator());
      auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
      auto ContinueAddress = BlockAddress::get(Continue->getParent(), Continue);
      SmallVector<Value *, 2> Args = {MergeAddress, ContinueAddress};
      SmallVector<unsigned, 1> LoopControlImms =
          getSpirvLoopControlOperandsFromLoopMetadata(L);
      for (unsigned Imm : LoopControlImms)
        Args.emplace_back(ConstantInt::get(Builder.getInt32Ty(), Imm));
      Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {Args});
      Modified = true;
    }

    return Modified;
  }
609 
610   // Adds an OpSelectionMerge to the immediate dominator or each node with an
611   // in-degree of 2 or more which is not already the merge target of an
612   // OpLoopMerge/OpSelectionMerge.
613   bool addMergeForNodesWithMultiplePredecessors(Function &F) {
614     DomTreeBuilder::BBDomTree DT;
615     DT.recalculate(F);
616 
617     bool Modified = false;
618     for (auto &BB : F) {
619       if (pred_size(&BB) <= 1)
620         continue;
621 
622       if (hasLoopMergeInstruction(BB) && pred_size(&BB) <= 2)
623         continue;
624 
625       assert(DT.getNode(&BB)->getIDom());
626       BasicBlock *Header = DT.getNode(&BB)->getIDom()->getBlock();
627 
628       if (isDefinedAsSelectionMergeBy(*Header, BB))
629         continue;
630 
631       IRBuilder<> Builder(Header);
632       Builder.SetInsertPoint(Header->getTerminator());
633 
634       auto MergeAddress = BlockAddress::get(BB.getParent(), &BB);
635       createOpSelectMerge(&Builder, MergeAddress);
636 
637       Modified = true;
638     }
639 
640     return Modified;
641   }
642 
643   // When a block has multiple OpSelectionMerge/OpLoopMerge instructions, sorts
644   // them to put the "largest" first. A merge instruction is defined as larger
645   // than another when its target merge block post-dominates the other target's
646   // merge block. (This ordering should match the nesting ordering of the source
647   // HLSL).
648   bool sortSelectionMerge(Function &F, BasicBlock &Block) {
649     std::vector<Instruction *> MergeInstructions;
650     for (Instruction &I : Block)
651       if (isMergeInstruction(&I))
652         MergeInstructions.push_back(&I);
653 
654     if (MergeInstructions.size() <= 1)
655       return false;
656 
657     Instruction *InsertionPoint = *MergeInstructions.begin();
658 
659     PartialOrderingVisitor Visitor(F);
660     std::sort(MergeInstructions.begin(), MergeInstructions.end(),
661               [&Visitor](Instruction *Left, Instruction *Right) {
662                 if (Left == Right)
663                   return false;
664                 BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
665                 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
666                 return !Visitor.compare(RightMerge, LeftMerge);
667               });
668 
669     for (Instruction *I : MergeInstructions) {
670       I->moveBefore(InsertionPoint->getIterator());
671       InsertionPoint = I;
672     }
673 
674     return true;
675   }
676 
677   // Sorts selection merge headers in |F|.
678   // A is sorted before B if the merge block designated by B is an ancestor of
679   // the one designated by A.
680   bool sortSelectionMergeHeaders(Function &F) {
681     bool Modified = false;
682     for (BasicBlock &BB : F) {
683       Modified |= sortSelectionMerge(F, BB);
684     }
685     return Modified;
686   }
687 
  // Split basic blocks containing multiple OpLoopMerge/OpSelectionMerge
  // instructions so each basic block contains only a single merge instruction.
  //
  // Each affected block is split right before its 2nd, 3rd, ... merge
  // instruction; the new blocks are named "new.header".
  // Returns true if at least one block had to be split.
  bool splitBlocksWithMultipleHeaders(Function &F) {
    // Gather the blocks to process first: splitting adds blocks to F.
    std::stack<BasicBlock *> Work;
    for (auto &BB : F) {
      std::vector<Instruction *> MergeInstructions = getMergeInstructions(BB);
      if (MergeInstructions.size() <= 1)
        continue;
      Work.push(&BB);
    }

    const bool Modified = Work.size() > 0;
    while (Work.size() > 0) {
      BasicBlock *Header = Work.top();
      Work.pop();

      std::vector<Instruction *> MergeInstructions =
          getMergeInstructions(*Header);
      for (unsigned i = 1; i < MergeInstructions.size(); i++) {
        // Everything from the i-th merge instruction moves to NewBlock; Header
        // now ends with an unconditional branch to NewBlock.
        BasicBlock *NewBlock =
            Header->splitBasicBlock(MergeInstructions[i], "new.header");

        // When the first merge instruction is not an OpLoopMerge, replace the
        // unconditional branch with a conditional branch on `true` whose
        // other target is a new unreachable block.
        // NOTE(review): this always inspects MergeInstructions[0], while from
        // the second iteration on the merge instruction left in Header is
        // MergeInstructions[i - 1] - confirm this is intended.
        if (getDesignatedContinueBlock(MergeInstructions[0]) == nullptr) {
          BasicBlock *Unreachable = CreateUnreachable(F);

          BranchInst *BI = cast<BranchInst>(Header->getTerminator());
          IRBuilder<> Builder(Header);
          Builder.SetInsertPoint(BI);
          Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable);
          BI->eraseFromParent();
        }

        // The remaining merge instructions now live in NewBlock.
        Header = NewBlock;
      }
    }

    return Modified;
  }
726 
727   // Adds an OpSelectionMerge to each block with an out-degree >= 2 which
728   // doesn't already have an OpSelectionMerge.
729   bool addMergeForDivergentBlocks(Function &F) {
730     DomTreeBuilder::BBPostDomTree PDT;
731     PDT.recalculate(F);
732     bool Modified = false;
733 
734     auto MergeBlocks = getMergeBlocks(F);
735     auto ContinueBlocks = getContinueBlocks(F);
736 
737     for (auto &BB : F) {
738       if (getMergeInstructions(BB).size() != 0)
739         continue;
740 
741       std::vector<BasicBlock *> Candidates;
742       for (BasicBlock *Successor : successors(&BB)) {
743         if (MergeBlocks.contains(Successor))
744           continue;
745         if (ContinueBlocks.contains(Successor))
746           continue;
747         Candidates.push_back(Successor);
748       }
749 
750       if (Candidates.size() <= 1)
751         continue;
752 
753       Modified = true;
754       BasicBlock *Merge = Candidates[0];
755 
756       auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
757       IRBuilder<> Builder(&BB);
758       Builder.SetInsertPoint(BB.getTerminator());
759       createOpSelectMerge(&Builder, MergeAddress);
760     }
761 
762     return Modified;
763   }
764 
765   // Gather all the exit nodes for the construct header by |Header| and
766   // containing the blocks |Construct|.
767   std::vector<Edge> getExitsFrom(const BlockSet &Construct,
768                                  BasicBlock &Header) {
769     std::vector<Edge> Output;
770     visit(Header, [&](BasicBlock *Item) {
771       if (Construct.count(Item) == 0)
772         return false;
773 
774       for (BasicBlock *Successor : successors(Item)) {
775         if (Construct.count(Successor) == 0)
776           Output.emplace_back(Item, Successor);
777       }
778       return true;
779     });
780 
781     return Output;
782   }
783 
  // Build a divergent construct tree searching from |BB|.
  // If |Parent| is not null, this tree is attached to the parent's tree.
  //
  // |Visited| holds the blocks already processed; each block is handled at
  // most once. |S| is only forwarded to the recursive calls.
  void constructDivergentConstruct(BlockSet &Visited, Splitter &S,
                                   BasicBlock *BB, DivergentConstruct *Parent) {
    if (Visited.count(BB) != 0)
      return;
    Visited.insert(BB);

    // Not a header: the successors belong to the same construct as BB.
    auto MIS = getMergeInstructions(*BB);
    if (MIS.size() == 0) {
      for (BasicBlock *Successor : successors(BB))
        constructDivergentConstruct(Visited, S, Successor, Parent);
      return;
    }

    // Blocks are expected to hold a single merge instruction at this point.
    assert(MIS.size() == 1);
    Instruction *MI = MIS[0];

    BasicBlock *Merge = getDesignatedMergeBlock(MI);
    BasicBlock *Continue = getDesignatedContinueBlock(MI);

    auto Output = std::make_unique<DivergentConstruct>();
    Output->Header = BB;
    Output->Merge = Merge;
    Output->Continue = Continue;
    Output->Parent = Parent;

    // The merge block is explored as part of the parent construct, and before
    // the successors of BB, so that it gets marked as visited first. The
    // continue block and the successors of BB are explored as part of the new
    // construct.
    constructDivergentConstruct(Visited, S, Merge, Parent);
    if (Continue)
      constructDivergentConstruct(Visited, S, Continue, Output.get());

    for (BasicBlock *Successor : successors(BB))
      constructDivergentConstruct(Visited, S, Successor, Output.get());

    // Without a parent, the new construct (and its children) is dropped when
    // Output goes out of scope.
    if (Parent)
      Parent->Children.emplace_back(std::move(Output));
  }
821 
822   // Returns the blocks belonging to the divergent construct |Node|.
823   BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
824     assert(Node->Header && Node->Merge);
825 
826     if (Node->Continue) {
827       auto LoopBlocks = S.getLoopConstructBlocks(Node->Header, Node->Merge);
828       return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
829     }
830 
831     auto SelectionBlocks = S.getSelectionConstructBlocks(Node);
832     return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
833   }
834 
835   // Fixup the construct |Node| to respect a set of rules defined by the SPIR-V
836   // spec.
837   bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
838     bool Modified = false;
839     for (auto &Child : Node->Children)
840       Modified |= fixupConstruct(S, Child.get());
841 
842     // This construct is the root construct. Does not represent any real
843     // construct, just a way to access the first level of the forest.
844     if (Node->Parent == nullptr)
845       return Modified;
846 
847     // This node's parent is the root. Meaning this is a top-level construct.
848     // There can be multiple exists, but all are guaranteed to exit at most 1
849     // construct since we are at first level.
850     if (Node->Parent->Header == nullptr)
851       return Modified;
852 
853     // Health check for the structure.
854     assert(Node->Header && Node->Merge);
855     assert(Node->Parent->Header && Node->Parent->Merge);
856 
857     BlockSet ConstructBlocks = getConstructBlocks(S, Node);
858     auto Edges = getExitsFrom(ConstructBlocks, *Node->Header);
859 
860     //  No edges exiting the construct.
861     if (Edges.size() < 1)
862       return Modified;
863 
864     bool HasBadEdge = Node->Merge == Node->Parent->Merge ||
865                       Node->Merge == Node->Parent->Continue;
866     // BasicBlock *Target = Edges[0].second;
867     for (auto &[Src, Dst] : Edges) {
868       // - Breaking from a selection construct: S is a selection construct, S is
869       // the innermost structured
870       //   control-flow construct containing A, and B is the merge block for S
871       // - Breaking from the innermost loop: S is the innermost loop construct
872       // containing A,
873       //   and B is the merge block for S
874       if (Node->Merge == Dst)
875         continue;
876 
877       // Entering the innermost loop’s continue construct: S is the innermost
878       // loop construct containing A, and B is the continue target for S
879       if (Node->Continue == Dst)
880         continue;
881 
882       // TODO: what about cases branching to another case in the switch? Seems
883       // to work, but need to double check.
884       HasBadEdge = true;
885     }
886 
887     if (!HasBadEdge)
888       return Modified;
889 
890     // Create a single exit node gathering all exit edges.
891     BasicBlock *NewExit = S.createSingleExitNode(Node->Header, Edges);
892 
893     // Fixup this construct's merge node to point to the new exit.
894     // Note: this algorithm fixes inner-most divergence construct first. So
895     // recursive structures sharing a single merge node are fixed from the
896     // inside toward the outside.
897     auto MergeInstructions = getMergeInstructions(*Node->Header);
898     assert(MergeInstructions.size() == 1);
899     Instruction *I = MergeInstructions[0];
900     BlockAddress *BA = cast<BlockAddress>(I->getOperand(0));
901     if (BA->getBasicBlock() == Node->Merge) {
902       auto MergeAddress = BlockAddress::get(NewExit->getParent(), NewExit);
903       I->setOperand(0, MergeAddress);
904     }
905 
906     // Clean up of the possible dangling BockAddr operands to prevent MIR
907     // comments about "address of removed block taken".
908     if (!BA->isConstantUsed())
909       BA->destroyConstant();
910 
911     Node->Merge = NewExit;
912     // Regenerate the dom trees.
913     S.invalidate();
914     return true;
915   }
916 
917   bool splitCriticalEdges(Function &F) {
918     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
919     Splitter S(F, LI);
920 
921     DivergentConstruct Root;
922     BlockSet Visited;
923     constructDivergentConstruct(Visited, S, &*F.begin(), &Root);
924     return fixupConstruct(S, &Root);
925   }
926 
927   // Simplify branches when possible:
928   //  - if the 2 sides of a conditional branch are the same, transforms it to an
929   //  unconditional branch.
930   //  - if a switch has only 2 distinct successors, converts it to a conditional
931   //  branch.
932   bool simplifyBranches(Function &F) {
933     bool Modified = false;
934 
935     for (BasicBlock &BB : F) {
936       SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator());
937       if (!SI)
938         continue;
939       if (SI->getNumCases() > 1)
940         continue;
941 
942       Modified = true;
943       IRBuilder<> Builder(&BB);
944       Builder.SetInsertPoint(SI);
945 
946       if (SI->getNumCases() == 0) {
947         Builder.CreateBr(SI->getDefaultDest());
948       } else {
949         Value *Condition =
950             Builder.CreateCmp(CmpInst::ICMP_EQ, SI->getCondition(),
951                               SI->case_begin()->getCaseValue());
952         Builder.CreateCondBr(Condition, SI->case_begin()->getCaseSuccessor(),
953                              SI->getDefaultDest());
954       }
955       SI->eraseFromParent();
956     }
957 
958     return Modified;
959   }
960 
961   // Makes sure every case target in |F| is unique. If 2 cases branch to the
962   // same basic block, one of the targets is updated so it jumps to a new basic
963   // block ending with a single unconditional branch to the original target.
964   bool splitSwitchCases(Function &F) {
965     bool Modified = false;
966 
967     for (BasicBlock &BB : F) {
968       SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator());
969       if (!SI)
970         continue;
971 
972       BlockSet Seen;
973       Seen.insert(SI->getDefaultDest());
974 
975       auto It = SI->case_begin();
976       while (It != SI->case_end()) {
977         BasicBlock *Target = It->getCaseSuccessor();
978         if (Seen.count(Target) == 0) {
979           Seen.insert(Target);
980           ++It;
981           continue;
982         }
983 
984         Modified = true;
985         BasicBlock *NewTarget =
986             BasicBlock::Create(F.getContext(), "new.sw.case", &F);
987         IRBuilder<> Builder(NewTarget);
988         Builder.CreateBr(Target);
989         SI->addCase(It->getCaseValue(), NewTarget);
990         It = SI->removeCase(It);
991       }
992     }
993 
994     return Modified;
995   }
996 
997   // Removes blocks not contributing to any structured CFG. This assumes there
998   // is no PHI nodes.
999   bool removeUselessBlocks(Function &F) {
1000     std::vector<BasicBlock *> ToRemove;
1001 
1002     auto MergeBlocks = getMergeBlocks(F);
1003     auto ContinueBlocks = getContinueBlocks(F);
1004 
1005     for (BasicBlock &BB : F) {
1006       if (BB.size() != 1)
1007         continue;
1008 
1009       if (isa<ReturnInst>(BB.getTerminator()))
1010         continue;
1011 
1012       if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
1013         continue;
1014 
1015       if (BB.getUniqueSuccessor() == nullptr)
1016         continue;
1017 
1018       BasicBlock *Successor = BB.getUniqueSuccessor();
1019       std::vector<BasicBlock *> Predecessors(predecessors(&BB).begin(),
1020                                              predecessors(&BB).end());
1021       for (BasicBlock *Predecessor : Predecessors)
1022         replaceBranchTargets(Predecessor, &BB, Successor);
1023       ToRemove.push_back(&BB);
1024     }
1025 
1026     for (BasicBlock *BB : ToRemove)
1027       BB->eraseFromParent();
1028 
1029     return ToRemove.size() != 0;
1030   }
1031 
1032   bool addHeaderToRemainingDivergentDAG(Function &F) {
1033     bool Modified = false;
1034 
1035     auto MergeBlocks = getMergeBlocks(F);
1036     auto ContinueBlocks = getContinueBlocks(F);
1037     auto HeaderBlocks = getHeaderBlocks(F);
1038 
1039     DomTreeBuilder::BBDomTree DT;
1040     DomTreeBuilder::BBPostDomTree PDT;
1041     PDT.recalculate(F);
1042     DT.recalculate(F);
1043 
1044     for (BasicBlock &BB : F) {
1045       if (HeaderBlocks.count(&BB) != 0)
1046         continue;
1047       if (succ_size(&BB) < 2)
1048         continue;
1049 
1050       size_t CandidateEdges = 0;
1051       for (BasicBlock *Successor : successors(&BB)) {
1052         if (MergeBlocks.count(Successor) != 0 ||
1053             ContinueBlocks.count(Successor) != 0)
1054           continue;
1055         if (HeaderBlocks.count(Successor) != 0)
1056           continue;
1057         CandidateEdges += 1;
1058       }
1059 
1060       if (CandidateEdges <= 1)
1061         continue;
1062 
1063       BasicBlock *Header = &BB;
1064       BasicBlock *Merge = PDT.getNode(&BB)->getIDom()->getBlock();
1065 
1066       bool HasBadBlock = false;
1067       visit(*Header, [&](const BasicBlock *Node) {
1068         if (DT.dominates(Header, Node))
1069           return false;
1070         if (PDT.dominates(Merge, Node))
1071           return false;
1072         if (Node == Header || Node == Merge)
1073           return true;
1074 
1075         HasBadBlock |= MergeBlocks.count(Node) != 0 ||
1076                        ContinueBlocks.count(Node) != 0 ||
1077                        HeaderBlocks.count(Node) != 0;
1078         return !HasBadBlock;
1079       });
1080 
1081       if (HasBadBlock)
1082         continue;
1083 
1084       Modified = true;
1085 
1086       if (Merge == nullptr) {
1087         Merge = *successors(Header).begin();
1088         IRBuilder<> Builder(Header);
1089         Builder.SetInsertPoint(Header->getTerminator());
1090 
1091         auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
1092         createOpSelectMerge(&Builder, MergeAddress);
1093         continue;
1094       }
1095 
1096       Instruction *SplitInstruction = Merge->getTerminator();
1097       if (isMergeInstruction(SplitInstruction->getPrevNode()))
1098         SplitInstruction = SplitInstruction->getPrevNode();
1099       BasicBlock *NewMerge =
1100           Merge->splitBasicBlockBefore(SplitInstruction, "new.merge");
1101 
1102       IRBuilder<> Builder(Header);
1103       Builder.SetInsertPoint(Header->getTerminator());
1104 
1105       auto MergeAddress = BlockAddress::get(NewMerge->getParent(), NewMerge);
1106       createOpSelectMerge(&Builder, MergeAddress);
1107     }
1108 
1109     return Modified;
1110   }
1111 
1112 public:
1113   static char ID; // Pass identifier; handed to the FunctionPass base class by the constructor below.
1114 
1115   SPIRVStructurizer() : FunctionPass(ID) {} // Builds the pass, forwarding ID to FunctionPass.
1116 
1117   virtual bool runOnFunction(Function &F) override {
1118     bool Modified = false;
1119 
1120     // In LLVM, Switches are allowed to have several cases branching to the same
1121     // basic block. This is allowed in SPIR-V, but can make structurizing SPIR-V
1122     // harder, so first remove edge cases.
1123     Modified |= splitSwitchCases(F);
1124 
1125     // LLVM allows conditional branches to have both side jumping to the same
1126     // block. It also allows switched to have a single default, or just one
1127     // case. Cleaning this up now.
1128     Modified |= simplifyBranches(F);
1129 
1130     // At this state, we should have a reducible CFG with cycles.
1131     // STEP 1: Adding OpLoopMerge instructions to loop headers.
1132     Modified |= addMergeForLoops(F);
1133 
1134     // STEP 2: adding OpSelectionMerge to each node with an in-degree >= 2.
1135     Modified |= addMergeForNodesWithMultiplePredecessors(F);
1136 
1137     // STEP 3:
1138     // Sort selection merge, the largest construct goes first.
1139     // This simplifies the next step.
1140     Modified |= sortSelectionMergeHeaders(F);
1141 
1142     // STEP 4: As this stage, we can have a single basic block with multiple
1143     // OpLoopMerge/OpSelectionMerge instructions. Splitting this block so each
1144     // BB has a single merge instruction.
1145     Modified |= splitBlocksWithMultipleHeaders(F);
1146 
1147     // STEP 5: In the previous steps, we added merge blocks the loops and
1148     // natural merge blocks (in-degree >= 2). What remains are conditions with
1149     // an exiting branch (return, unreachable). In such case, we must start from
1150     // the header, and add headers to divergent construct with no headers.
1151     Modified |= addMergeForDivergentBlocks(F);
1152 
1153     // STEP 6: At this stage, we have several divergent construct defines by a
1154     // header and a merge block. But their boundaries have no constraints: a
1155     // construct exit could be outside of the parents' construct exit. Such
1156     // edges are called critical edges. What we need is to split those edges
1157     // into several parts. Each part exiting the parent's construct by its merge
1158     // block.
1159     Modified |= splitCriticalEdges(F);
1160 
1161     // STEP 7: The previous steps possibly created a lot of "proxy" blocks.
1162     // Blocks with a single unconditional branch, used to create a valid
1163     // divergent construct tree. Some nodes are still requires (e.g: nodes
1164     // allowing a valid exit through the parent's merge block). But some are
1165     // left-overs of past transformations, and could cause actual validation
1166     // issues. E.g: the SPIR-V spec allows a construct to break to the parents
1167     // loop construct without an OpSelectionMerge, but this requires a straight
1168     // jump. If a proxy block lies between the conditional branch and the
1169     // parent's merge, the CFG is not valid.
1170     Modified |= removeUselessBlocks(F);
1171 
1172     // STEP 8: Final fix-up steps: our tree boundaries are correct, but some
1173     // blocks are branching with no header. Those are often simple conditional
1174     // branches with 1 or 2 returning edges. Adding a header for those.
1175     Modified |= addHeaderToRemainingDivergentDAG(F);
1176 
1177     // STEP 9: sort basic blocks to match both the LLVM & SPIR-V requirements.
1178     Modified |= sortBlocks(F);
1179 
1180     return Modified;
1181   }
1182 
1183   void getAnalysisUsage(AnalysisUsage &AU) const override {
         // Analyses this pass declares as required.
1184     AU.addRequired<DominatorTreeWrapperPass>();
1185     AU.addRequired<LoopInfoWrapperPass>();
1186     AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
1187 
         // The convergence region analysis is the only one declared preserved.
1188     AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
         // Let the base class add its own requirements, if any.
1189     FunctionPass::getAnalysisUsage(AU);
1190   }
1191 
1192   void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress) {
1193     Instruction *BBTerminatorInst = Builder->GetInsertBlock()->getTerminator();
1194 
1195     MDNode *MDNode = BBTerminatorInst->getMetadata("hlsl.controlflow.hint");
1196 
1197     ConstantInt *BranchHint = ConstantInt::get(Builder->getInt32Ty(), 0);
1198 
1199     if (MDNode) {
1200       assert(MDNode->getNumOperands() == 2 &&
1201              "invalid metadata hlsl.controlflow.hint");
1202       BranchHint = mdconst::extract<ConstantInt>(MDNode->getOperand(1));
1203     }
1204 
1205     SmallVector<Value *, 2> Args = {MergeAddress, BranchHint};
1206 
1207     Builder->CreateIntrinsic(Intrinsic::spv_selection_merge,
1208                              {MergeAddress->getType()}, Args);
1209   }
1210 };
1211 } // anonymous namespace
1212 
1213 char SPIRVStructurizer::ID = 0; // Out-of-class definition of the static pass identifier.
1214 
     // Pass registration under the name "spirv-structurizer", with the
     // description "structurize SPIRV".
     // NOTE(review): the two trailing `false` arguments are presumably the
     // CFG-only and is-analysis flags of INITIALIZE_PASS_BEGIN/END - confirm
     // against llvm/PassSupport.h.
1215 INITIALIZE_PASS_BEGIN(SPIRVStructurizer, "spirv-structurizer",
1216                       "structurize SPIRV", false, false)
     // Passes declared as dependencies of this one.
     // NOTE(review): LoopSimplify is listed here but is not requested in
     // getAnalysisUsage() - confirm this is intended.
1217 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
1218 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
1219 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
1220 INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
1221 
1222 INITIALIZE_PASS_END(SPIRVStructurizer, "spirv-structurizer",
1223                     "structurize SPIRV", false, false)
1224 
1225 FunctionPass *llvm::createSPIRVStructurizerPass() {
       // Factory for the legacy pass: returns a heap-allocated instance.
       // NOTE(review): ownership is presumably taken by the pass manager the
       // result is added to - confirm against callers.
1226   return new SPIRVStructurizer();
1227 }
1228 
1229 PreservedAnalyses SPIRVStructurizerWrapper::run(Function &F,
1230                                                 FunctionAnalysisManager &AF) {
1231 
1232   auto FPM = legacy::FunctionPassManager(F.getParent());
1233   FPM.add(createSPIRVStructurizerPass());
1234 
1235   if (!FPM.run(F))
1236     return PreservedAnalyses::all();
1237   PreservedAnalyses PA;
1238   PA.preserveSet<CFGAnalyses>();
1239   return PA;
1240 }
1241