xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Merge the multiple exit targets of a convergence region into a single block.
// Each exit target is assigned a constant value. Each exit block stores the
// constant matching the target it branches to into a stack variable, and the
// new exit target loads that variable and uses a switch to re-route to the
// correct basic block.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "Analysis/SPIRVConvergenceRegionAnalysis.h"
16 #include "SPIRV.h"
17 #include "SPIRVSubtarget.h"
18 #include "SPIRVUtils.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Analysis/LoopInfo.h"
22 #include "llvm/CodeGen/IntrinsicLowering.h"
23 #include "llvm/IR/Dominators.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/Intrinsics.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Transforms/Utils/Cloning.h"
28 #include "llvm/Transforms/Utils/LoopSimplify.h"
29 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
30 
31 using namespace llvm;
32 
33 namespace {
34 
35 class SPIRVMergeRegionExitTargets : public FunctionPass {
36 public:
37   static char ID;
38 
  /// Builds the pass, handing the class-wide |ID| to the FunctionPass base.
  SPIRVMergeRegionExitTargets() : FunctionPass(ID) {}
40 
41   // Gather all the successors of |BB|.
42   // This function asserts if the terminator neither a branch, switch or return.
gatherSuccessors(BasicBlock * BB)43   std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
44     std::unordered_set<BasicBlock *> output;
45     auto *T = BB->getTerminator();
46 
47     if (auto *BI = dyn_cast<BranchInst>(T)) {
48       output.insert(BI->getSuccessor(0));
49       if (BI->isConditional())
50         output.insert(BI->getSuccessor(1));
51       return output;
52     }
53 
54     if (auto *SI = dyn_cast<SwitchInst>(T)) {
55       output.insert(SI->getDefaultDest());
56       for (auto &Case : SI->cases())
57         output.insert(Case.getCaseSuccessor());
58       return output;
59     }
60 
61     assert(isa<ReturnInst>(T) && "Unhandled terminator type.");
62     return output;
63   }
64 
65   /// Create a value in BB set to the value associated with the branch the block
66   /// terminator will take.
createExitVariable(BasicBlock * BB,const DenseMap<BasicBlock *,ConstantInt * > & TargetToValue)67   llvm::Value *createExitVariable(
68       BasicBlock *BB,
69       const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
70     auto *T = BB->getTerminator();
71     if (isa<ReturnInst>(T))
72       return nullptr;
73 
74     IRBuilder<> Builder(BB);
75     Builder.SetInsertPoint(T);
76 
77     if (auto *BI = dyn_cast<BranchInst>(T)) {
78 
79       BasicBlock *LHSTarget = BI->getSuccessor(0);
80       BasicBlock *RHSTarget =
81           BI->isConditional() ? BI->getSuccessor(1) : nullptr;
82 
83       Value *LHS = TargetToValue.lookup(LHSTarget);
84       Value *RHS = TargetToValue.lookup(RHSTarget);
85 
86       if (LHS == nullptr || RHS == nullptr)
87         return LHS == nullptr ? RHS : LHS;
88       return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
89     }
90 
91     // TODO: add support for switch cases.
92     llvm_unreachable("Unhandled terminator type.");
93   }
94 
95   /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
replaceBranchTargets(BasicBlock * BB,const SmallPtrSet<BasicBlock *,4> & ToReplace,BasicBlock * NewTarget)96   void replaceBranchTargets(BasicBlock *BB,
97                             const SmallPtrSet<BasicBlock *, 4> &ToReplace,
98                             BasicBlock *NewTarget) {
99     auto *T = BB->getTerminator();
100     if (isa<ReturnInst>(T))
101       return;
102 
103     if (auto *BI = dyn_cast<BranchInst>(T)) {
104       for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
105         if (ToReplace.count(BI->getSuccessor(i)) != 0)
106           BI->setSuccessor(i, NewTarget);
107       }
108       return;
109     }
110 
111     if (auto *SI = dyn_cast<SwitchInst>(T)) {
112       for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
113         if (ToReplace.count(SI->getSuccessor(i)) != 0)
114           SI->setSuccessor(i, NewTarget);
115       }
116       return;
117     }
118 
119     assert(false && "Unhandled terminator type.");
120   }
121 
CreateVariable(Function & F,Type * Type,BasicBlock::iterator Position)122   AllocaInst *CreateVariable(Function &F, Type *Type,
123                              BasicBlock::iterator Position) {
124     const DataLayout &DL = F.getDataLayout();
125     return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
126                           Position);
127   }
128 
129   // Run the pass on the given convergence region, ignoring the sub-regions.
130   // Returns true if the CFG changed, false otherwise.
runOnConvergenceRegionNoRecurse(LoopInfo & LI,SPIRV::ConvergenceRegion * CR)131   bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
132                                        SPIRV::ConvergenceRegion *CR) {
133     // Gather all the exit targets for this region.
134     SmallPtrSet<BasicBlock *, 4> ExitTargets;
135     for (BasicBlock *Exit : CR->Exits) {
136       for (BasicBlock *Target : gatherSuccessors(Exit)) {
137         if (CR->Blocks.count(Target) == 0)
138           ExitTargets.insert(Target);
139       }
140     }
141 
142     // If we have zero or one exit target, nothing do to.
143     if (ExitTargets.size() <= 1)
144       return false;
145 
146     // Create the new single exit target.
147     auto F = CR->Entry->getParent();
148     auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
149     IRBuilder<> Builder(NewExitTarget);
150 
151     AllocaInst *Variable = CreateVariable(*F, Builder.getInt32Ty(),
152                                           F->begin()->getFirstInsertionPt());
153 
154     // CodeGen output needs to be stable. Using the set as-is would order
155     // the targets differently depending on the allocation pattern.
156     // Sorting per basic-block ordering in the function.
157     std::vector<BasicBlock *> SortedExitTargets;
158     std::vector<BasicBlock *> SortedExits;
159     for (BasicBlock &BB : *F) {
160       if (ExitTargets.count(&BB) != 0)
161         SortedExitTargets.push_back(&BB);
162       if (CR->Exits.count(&BB) != 0)
163         SortedExits.push_back(&BB);
164     }
165 
166     // Creating one constant per distinct exit target. This will be route to the
167     // correct target.
168     DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
169     for (BasicBlock *Target : SortedExitTargets)
170       TargetToValue.insert(
171           std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
172 
173     // Creating one variable per exit node, set to the constant matching the
174     // targeted external block.
175     std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
176     for (auto Exit : SortedExits) {
177       llvm::Value *Value = createExitVariable(Exit, TargetToValue);
178       IRBuilder<> B2(Exit);
179       B2.SetInsertPoint(Exit->getFirstInsertionPt());
180       B2.CreateStore(Value, Variable);
181       ExitToVariable.emplace_back(std::make_pair(Exit, Value));
182     }
183 
184     llvm::Value *Load = Builder.CreateLoad(Builder.getInt32Ty(), Variable);
185 
186     // Creating the switch to jump to the correct exit target.
187     llvm::SwitchInst *Sw = Builder.CreateSwitch(Load, SortedExitTargets[0],
188                                                 SortedExitTargets.size() - 1);
189     for (size_t i = 1; i < SortedExitTargets.size(); i++) {
190       BasicBlock *BB = SortedExitTargets[i];
191       Sw->addCase(TargetToValue[BB], BB);
192     }
193 
194     // Fix exit branches to redirect to the new exit.
195     for (auto Exit : CR->Exits)
196       replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
197 
198     CR = CR->Parent;
199     while (CR) {
200       CR->Blocks.insert(NewExitTarget);
201       CR = CR->Parent;
202     }
203 
204     return true;
205   }
206 
207   /// Run the pass on the given convergence region and sub-regions (DFS).
208   /// Returns true if a region/sub-region was modified, false otherwise.
209   /// This returns as soon as one region/sub-region has been modified.
runOnConvergenceRegion(LoopInfo & LI,SPIRV::ConvergenceRegion * CR)210   bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) {
211     for (auto *Child : CR->Children)
212       if (runOnConvergenceRegion(LI, Child))
213         return true;
214 
215     return runOnConvergenceRegionNoRecurse(LI, CR);
216   }
217 
218 #if !NDEBUG
219   /// Validates each edge exiting the region has the same destination basic
220   /// block.
validateRegionExits(const SPIRV::ConvergenceRegion * CR)221   void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
222     for (auto *Child : CR->Children)
223       validateRegionExits(Child);
224 
225     std::unordered_set<BasicBlock *> ExitTargets;
226     for (auto *Exit : CR->Exits) {
227       auto Set = gatherSuccessors(Exit);
228       for (auto *BB : Set) {
229         if (CR->Blocks.count(BB) == 0)
230           ExitTargets.insert(BB);
231       }
232     }
233 
234     assert(ExitTargets.size() <= 1);
235   }
236 #endif
237 
runOnFunction(Function & F)238   virtual bool runOnFunction(Function &F) override {
239     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
240     auto *TopLevelRegion =
241         getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
242             .getRegionInfo()
243             .getWritableTopLevelRegion();
244 
245     // FIXME: very inefficient method: each time a region is modified, we bubble
246     // back up, and recompute the whole convergence region tree. Once the
247     // algorithm is completed and test coverage good enough, rewrite this pass
248     // to be efficient instead of simple.
249     bool modified = false;
250     while (runOnConvergenceRegion(LI, TopLevelRegion)) {
251       modified = true;
252     }
253 
254 #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
255     validateRegionExits(TopLevelRegion);
256 #endif
257     return modified;
258   }
259 
  /// Declares the analyses this pass depends on, and the ones it preserves.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Analyses which must be available when the pass runs.
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();

    // The convergence region analysis is declared as preserved: the pass
    // itself registers the block it creates in the affected regions.
    AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }
268 };
269 } // namespace
270 
// Definition of the pass identifier declared in the class.
char SPIRVMergeRegionExitTargets::ID = 0;

// Pass registration, together with the passes declared as its dependencies.
// NOTE(review): the registered name and description mention "split region
// exit blocks" while this pass merges region exit targets; this looks like a
// leftover from another pass. Confirm before renaming, since the name is
// user-visible.
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                      "SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)

INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
                    "SPIRV split region exit blocks", false, false)
282 
283 FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
284   return new SPIRVMergeRegionExitTargets();
285 }
286