xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUUnifyDivergentExitNodes.cpp (revision 8311bc5f17dec348749f763b82dfe2737bc53cd7)
1 //===- AMDGPUUnifyDivergentExitNodes.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a variant of the UnifyFunctionExitNodes pass. Rather than ensuring
10 // there is at most one ret and one unreachable instruction, it ensures there is
11 // at most one divergent exiting block.
12 //
13 // StructurizeCFG can't deal with multi-exit regions formed by branches to
14 // multiple return nodes. It is not desirable to structurize regions with
15 // uniform branches, so unifying those to the same return block as divergent
16 // branches inhibits use of scalar branching. It still can't deal with the case
17 // where one branch goes to return, and one unreachable. Replace unreachable in
18 // this case with a return.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "AMDGPUUnifyDivergentExitNodes.h"
23 #include "AMDGPU.h"
24 #include "SIDefines.h"
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Analysis/DomTreeUpdater.h"
30 #include "llvm/Analysis/PostDominators.h"
31 #include "llvm/Analysis/TargetTransformInfo.h"
32 #include "llvm/Analysis/UniformityAnalysis.h"
33 #include "llvm/IR/BasicBlock.h"
34 #include "llvm/IR/CFG.h"
35 #include "llvm/IR/Constants.h"
36 #include "llvm/IR/Dominators.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/IRBuilder.h"
39 #include "llvm/IR/InstrTypes.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/Intrinsics.h"
42 #include "llvm/IR/IntrinsicsAMDGPU.h"
43 #include "llvm/IR/Type.h"
44 #include "llvm/InitializePasses.h"
45 #include "llvm/Pass.h"
46 #include "llvm/Support/Casting.h"
47 #include "llvm/Transforms/Scalar.h"
48 #include "llvm/Transforms/Utils.h"
49 #include "llvm/Transforms/Utils/Local.h"
50 
51 using namespace llvm;
52 
53 #define DEBUG_TYPE "amdgpu-unify-divergent-exit-nodes"
54 
55 namespace {
56 
57 class AMDGPUUnifyDivergentExitNodesImpl {
58 private:
59   const TargetTransformInfo *TTI = nullptr;
60 
61 public:
62   AMDGPUUnifyDivergentExitNodesImpl() = delete;
63   AMDGPUUnifyDivergentExitNodesImpl(const TargetTransformInfo *TTI)
64       : TTI(TTI) {}
65 
66   // We can preserve non-critical-edgeness when we unify function exit nodes
67   BasicBlock *unifyReturnBlockSet(Function &F, DomTreeUpdater &DTU,
68                                   ArrayRef<BasicBlock *> ReturningBlocks,
69                                   StringRef Name);
70   bool run(Function &F, DominatorTree *DT, const PostDominatorTree &PDT,
71            const UniformityInfo &UA);
72 };
73 
74 class AMDGPUUnifyDivergentExitNodes : public FunctionPass {
75 public:
76   static char ID;
77   AMDGPUUnifyDivergentExitNodes() : FunctionPass(ID) {
78     initializeAMDGPUUnifyDivergentExitNodesPass(
79         *PassRegistry::getPassRegistry());
80   }
81   void getAnalysisUsage(AnalysisUsage &AU) const override;
82   bool runOnFunction(Function &F) override;
83 };
84 } // end anonymous namespace
85 
86 char AMDGPUUnifyDivergentExitNodes::ID = 0;
87 
88 char &llvm::AMDGPUUnifyDivergentExitNodesID = AMDGPUUnifyDivergentExitNodes::ID;
89 
90 INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
91                       "Unify divergent function exit nodes", false, false)
92 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
93 INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
94 INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
95 INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
96                     "Unify divergent function exit nodes", false, false)
97 
98 void AMDGPUUnifyDivergentExitNodes::getAnalysisUsage(AnalysisUsage &AU) const {
99   if (RequireAndPreserveDomTree)
100     AU.addRequired<DominatorTreeWrapperPass>();
101 
102   AU.addRequired<PostDominatorTreeWrapperPass>();
103 
104   AU.addRequired<UniformityInfoWrapperPass>();
105 
106   if (RequireAndPreserveDomTree) {
107     AU.addPreserved<DominatorTreeWrapperPass>();
108     // FIXME: preserve PostDominatorTreeWrapperPass
109   }
110 
111   // No divergent values are changed, only blocks and branch edges.
112   AU.addPreserved<UniformityInfoWrapperPass>();
113 
114   // We preserve the non-critical-edgeness property
115   AU.addPreservedID(BreakCriticalEdgesID);
116 
117   // This is a cluster of orthogonal Transforms
118   AU.addPreservedID(LowerSwitchID);
119   FunctionPass::getAnalysisUsage(AU);
120 
121   AU.addRequired<TargetTransformInfoWrapperPass>();
122 }
123 
124 /// \returns true if \p BB is reachable through only uniform branches.
125 /// XXX - Is there a more efficient way to find this?
126 static bool isUniformlyReached(const UniformityInfo &UA, BasicBlock &BB) {
127   SmallVector<BasicBlock *, 8> Stack(predecessors(&BB));
128   SmallPtrSet<BasicBlock *, 8> Visited;
129 
130   while (!Stack.empty()) {
131     BasicBlock *Top = Stack.pop_back_val();
132     if (!UA.isUniform(Top->getTerminator()))
133       return false;
134 
135     for (BasicBlock *Pred : predecessors(Top)) {
136       if (Visited.insert(Pred).second)
137         Stack.push_back(Pred);
138     }
139   }
140 
141   return true;
142 }
143 
144 BasicBlock *AMDGPUUnifyDivergentExitNodesImpl::unifyReturnBlockSet(
145     Function &F, DomTreeUpdater &DTU, ArrayRef<BasicBlock *> ReturningBlocks,
146     StringRef Name) {
147   // Otherwise, we need to insert a new basic block into the function, add a PHI
148   // nodes (if the function returns values), and convert all of the return
149   // instructions into unconditional branches.
150   BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), Name, &F);
151   IRBuilder<> B(NewRetBlock);
152 
153   PHINode *PN = nullptr;
154   if (F.getReturnType()->isVoidTy()) {
155     B.CreateRetVoid();
156   } else {
157     // If the function doesn't return void... add a PHI node to the block...
158     PN = B.CreatePHI(F.getReturnType(), ReturningBlocks.size(),
159                      "UnifiedRetVal");
160     B.CreateRet(PN);
161   }
162 
163   // Loop over all of the blocks, replacing the return instruction with an
164   // unconditional branch.
165   std::vector<DominatorTree::UpdateType> Updates;
166   Updates.reserve(ReturningBlocks.size());
167   for (BasicBlock *BB : ReturningBlocks) {
168     // Add an incoming element to the PHI node for every return instruction that
169     // is merging into this new block...
170     if (PN)
171       PN->addIncoming(BB->getTerminator()->getOperand(0), BB);
172 
173     // Remove and delete the return inst.
174     BB->getTerminator()->eraseFromParent();
175     BranchInst::Create(NewRetBlock, BB);
176     Updates.push_back({DominatorTree::Insert, BB, NewRetBlock});
177   }
178 
179   if (RequireAndPreserveDomTree)
180     DTU.applyUpdates(Updates);
181   Updates.clear();
182 
183   for (BasicBlock *BB : ReturningBlocks) {
184     // Cleanup possible branch to unconditional branch to the return.
185     simplifyCFG(BB, *TTI, RequireAndPreserveDomTree ? &DTU : nullptr,
186                 SimplifyCFGOptions().bonusInstThreshold(2));
187   }
188 
189   return NewRetBlock;
190 }
191 
192 bool AMDGPUUnifyDivergentExitNodesImpl::run(Function &F, DominatorTree *DT,
193                                             const PostDominatorTree &PDT,
194                                             const UniformityInfo &UA) {
195   if (PDT.root_size() == 0 ||
196       (PDT.root_size() == 1 &&
197        !isa<BranchInst>(PDT.getRoot()->getTerminator())))
198     return false;
199 
200   // Loop over all of the blocks in a function, tracking all of the blocks that
201   // return.
202   SmallVector<BasicBlock *, 4> ReturningBlocks;
203   SmallVector<BasicBlock *, 4> UnreachableBlocks;
204 
205   // Dummy return block for infinite loop.
206   BasicBlock *DummyReturnBB = nullptr;
207 
208   bool Changed = false;
209   std::vector<DominatorTree::UpdateType> Updates;
210 
211   // TODO: For now we unify all exit blocks, even though they are uniformly
212   // reachable, if there are any exits not uniformly reached. This is to
213   // workaround the limitation of structurizer, which can not handle multiple
214   // function exits. After structurizer is able to handle multiple function
215   // exits, we should only unify UnreachableBlocks that are not uniformly
216   // reachable.
217   bool HasDivergentExitBlock = llvm::any_of(
218       PDT.roots(), [&](auto BB) { return !isUniformlyReached(UA, *BB); });
219 
220   for (BasicBlock *BB : PDT.roots()) {
221     if (isa<ReturnInst>(BB->getTerminator())) {
222       if (HasDivergentExitBlock)
223         ReturningBlocks.push_back(BB);
224     } else if (isa<UnreachableInst>(BB->getTerminator())) {
225       if (HasDivergentExitBlock)
226         UnreachableBlocks.push_back(BB);
227     } else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {
228 
229       ConstantInt *BoolTrue = ConstantInt::getTrue(F.getContext());
230       if (DummyReturnBB == nullptr) {
231         DummyReturnBB = BasicBlock::Create(F.getContext(),
232                                            "DummyReturnBlock", &F);
233         Type *RetTy = F.getReturnType();
234         Value *RetVal = RetTy->isVoidTy() ? nullptr : PoisonValue::get(RetTy);
235         ReturnInst::Create(F.getContext(), RetVal, DummyReturnBB);
236         ReturningBlocks.push_back(DummyReturnBB);
237       }
238 
239       if (BI->isUnconditional()) {
240         BasicBlock *LoopHeaderBB = BI->getSuccessor(0);
241         BI->eraseFromParent(); // Delete the unconditional branch.
242         // Add a new conditional branch with a dummy edge to the return block.
243         BranchInst::Create(LoopHeaderBB, DummyReturnBB, BoolTrue, BB);
244         Updates.push_back({DominatorTree::Insert, BB, DummyReturnBB});
245       } else { // Conditional branch.
246         SmallVector<BasicBlock *, 2> Successors(successors(BB));
247 
248         // Create a new transition block to hold the conditional branch.
249         BasicBlock *TransitionBB = BB->splitBasicBlock(BI, "TransitionBlock");
250 
251         Updates.reserve(Updates.size() + 2 * Successors.size() + 2);
252 
253         // 'Successors' become successors of TransitionBB instead of BB,
254         // and TransitionBB becomes a single successor of BB.
255         Updates.push_back({DominatorTree::Insert, BB, TransitionBB});
256         for (BasicBlock *Successor : Successors) {
257           Updates.push_back({DominatorTree::Insert, TransitionBB, Successor});
258           Updates.push_back({DominatorTree::Delete, BB, Successor});
259         }
260 
261         // Create a branch that will always branch to the transition block and
262         // references DummyReturnBB.
263         BB->getTerminator()->eraseFromParent();
264         BranchInst::Create(TransitionBB, DummyReturnBB, BoolTrue, BB);
265         Updates.push_back({DominatorTree::Insert, BB, DummyReturnBB});
266       }
267       Changed = true;
268     }
269   }
270 
271   if (!UnreachableBlocks.empty()) {
272     BasicBlock *UnreachableBlock = nullptr;
273 
274     if (UnreachableBlocks.size() == 1) {
275       UnreachableBlock = UnreachableBlocks.front();
276     } else {
277       UnreachableBlock = BasicBlock::Create(F.getContext(),
278                                             "UnifiedUnreachableBlock", &F);
279       new UnreachableInst(F.getContext(), UnreachableBlock);
280 
281       Updates.reserve(Updates.size() + UnreachableBlocks.size());
282       for (BasicBlock *BB : UnreachableBlocks) {
283         // Remove and delete the unreachable inst.
284         BB->getTerminator()->eraseFromParent();
285         BranchInst::Create(UnreachableBlock, BB);
286         Updates.push_back({DominatorTree::Insert, BB, UnreachableBlock});
287       }
288       Changed = true;
289     }
290 
291     if (!ReturningBlocks.empty()) {
292       // Don't create a new unreachable inst if we have a return. The
293       // structurizer/annotator can't handle the multiple exits
294 
295       Type *RetTy = F.getReturnType();
296       Value *RetVal = RetTy->isVoidTy() ? nullptr : PoisonValue::get(RetTy);
297       // Remove and delete the unreachable inst.
298       UnreachableBlock->getTerminator()->eraseFromParent();
299 
300       Function *UnreachableIntrin =
301         Intrinsic::getDeclaration(F.getParent(), Intrinsic::amdgcn_unreachable);
302 
303       // Insert a call to an intrinsic tracking that this is an unreachable
304       // point, in case we want to kill the active lanes or something later.
305       CallInst::Create(UnreachableIntrin, {}, "", UnreachableBlock);
306 
307       // Don't create a scalar trap. We would only want to trap if this code was
308       // really reached, but a scalar trap would happen even if no lanes
309       // actually reached here.
310       ReturnInst::Create(F.getContext(), RetVal, UnreachableBlock);
311       ReturningBlocks.push_back(UnreachableBlock);
312       Changed = true;
313     }
314   }
315 
316   // FIXME: add PDT here once simplifycfg is ready.
317   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
318   if (RequireAndPreserveDomTree)
319     DTU.applyUpdates(Updates);
320   Updates.clear();
321 
322   // Now handle return blocks.
323   if (ReturningBlocks.empty())
324     return Changed; // No blocks return
325 
326   if (ReturningBlocks.size() == 1)
327     return Changed; // Already has a single return block
328 
329   unifyReturnBlockSet(F, DTU, ReturningBlocks, "UnifiedReturnBlock");
330   return true;
331 }
332 
333 bool AMDGPUUnifyDivergentExitNodes::runOnFunction(Function &F) {
334   DominatorTree *DT = nullptr;
335   if (RequireAndPreserveDomTree)
336     DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
337   const auto &PDT =
338       getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
339   const auto &UA = getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
340   const auto *TranformInfo =
341       &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
342   return AMDGPUUnifyDivergentExitNodesImpl(TranformInfo).run(F, DT, PDT, UA);
343 }
344 
345 PreservedAnalyses
346 AMDGPUUnifyDivergentExitNodesPass::run(Function &F,
347                                        FunctionAnalysisManager &AM) {
348   DominatorTree *DT = nullptr;
349   if (RequireAndPreserveDomTree)
350     DT = &AM.getResult<DominatorTreeAnalysis>(F);
351 
352   const auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
353   const auto &UA = AM.getResult<UniformityInfoAnalysis>(F);
354   const auto *TransformInfo = &AM.getResult<TargetIRAnalysis>(F);
355   return AMDGPUUnifyDivergentExitNodesImpl(TransformInfo).run(F, DT, PDT, UA)
356              ? PreservedAnalyses::none()
357              : PreservedAnalyses::all();
358 }
359