xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp (revision b59017c5cad90d0f09a59e68c00457b7faf93e7c)
//===- SPIRVConvergenceRegionAnalysis.cpp ----------------------*- C++ -*--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis determines the convergence region for each basic block of
10 // the module, and provides a tree-like structure describing the region
11 // hierarchy.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "SPIRVConvergenceRegionAnalysis.h"
16 #include "llvm/Analysis/LoopInfo.h"
17 #include "llvm/IR/Dominators.h"
18 #include "llvm/IR/IntrinsicInst.h"
19 #include "llvm/InitializePasses.h"
20 #include "llvm/Transforms/Utils/LoopSimplify.h"
21 #include <optional>
22 #include <queue>
23 
24 #define DEBUG_TYPE "spirv-convergence-region-analysis"
25 
26 using namespace llvm;
27 
28 namespace llvm {
29 void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
30 } // namespace llvm
31 
32 INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass,
33                       "convergence-region",
34                       "SPIRV convergence regions analysis", true, true)
35 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
36 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
37 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
38 INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass,
39                     "convergence-region", "SPIRV convergence regions analysis",
40                     true, true)
41 
42 namespace llvm {
43 namespace SPIRV {
44 namespace {
45 
46 template <typename BasicBlockType, typename IntrinsicInstType>
47 std::optional<IntrinsicInstType *>
48 getConvergenceTokenInternal(BasicBlockType *BB) {
49   static_assert(std::is_const_v<IntrinsicInstType> ==
50                     std::is_const_v<BasicBlockType>,
51                 "Constness must match between input and output.");
52   static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>,
53                 "Input must be a basic block.");
54   static_assert(
55       std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>,
56       "Output type must be an intrinsic instruction.");
57 
58   for (auto &I : *BB) {
59     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
60       switch (II->getIntrinsicID()) {
61       case Intrinsic::experimental_convergence_entry:
62       case Intrinsic::experimental_convergence_loop:
63         return II;
64       case Intrinsic::experimental_convergence_anchor: {
65         auto Bundle = II->getOperandBundle(LLVMContext::OB_convergencectrl);
66         assert(Bundle->Inputs.size() == 1 &&
67                Bundle->Inputs[0]->getType()->isTokenTy());
68         auto TII = dyn_cast<IntrinsicInst>(Bundle->Inputs[0].get());
69         assert(TII != nullptr);
70         return TII;
71       }
72       }
73     }
74 
75     if (auto *CI = dyn_cast<CallInst>(&I)) {
76       auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
77       if (!OB.has_value())
78         continue;
79       return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]);
80     }
81   }
82 
83   return std::nullopt;
84 }
85 
86 // Given a ConvergenceRegion tree with |Start| as its root, finds the smallest
87 // region |Entry| belongs to. If |Entry| does not belong to the region defined
88 // by |Start|, this function returns |nullptr|.
89 ConvergenceRegion *findParentRegion(ConvergenceRegion *Start,
90                                     BasicBlock *Entry) {
91   ConvergenceRegion *Candidate = nullptr;
92   ConvergenceRegion *NextCandidate = Start;
93 
94   while (Candidate != NextCandidate && NextCandidate != nullptr) {
95     Candidate = NextCandidate;
96     NextCandidate = nullptr;
97 
98     // End of the search, we can return.
99     if (Candidate->Children.size() == 0)
100       return Candidate;
101 
102     for (auto *Child : Candidate->Children) {
103       if (Child->Blocks.count(Entry) != 0) {
104         NextCandidate = Child;
105         break;
106       }
107     }
108   }
109 
110   return Candidate;
111 }
112 
113 } // anonymous namespace
114 
115 std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB) {
116   return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB);
117 }
118 
119 std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB) {
120   return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB);
121 }
122 
123 ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
124                                      Function &F)
125     : DT(DT), LI(LI), Parent(nullptr) {
126   Entry = &F.getEntryBlock();
127   ConvergenceToken = getConvergenceToken(Entry);
128   for (auto &B : F) {
129     Blocks.insert(&B);
130     if (isa<ReturnInst>(B.getTerminator()))
131       Exits.insert(&B);
132   }
133 }
134 
135 ConvergenceRegion::ConvergenceRegion(
136     DominatorTree &DT, LoopInfo &LI,
137     std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry,
138     SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits)
139     : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry),
140       Exits(std::move(Exits)), Blocks(std::move(Blocks)) {
141   for ([[maybe_unused]] auto *BB : this->Exits)
142     assert(this->Blocks.count(BB) != 0);
143   assert(this->Blocks.count(this->Entry) != 0);
144 }
145 
146 void ConvergenceRegion::releaseMemory() {
147   // Parent memory is owned by the parent.
148   Parent = nullptr;
149   for (auto *Child : Children) {
150     Child->releaseMemory();
151     delete Child;
152   }
153   Children.resize(0);
154 }
155 
156 void ConvergenceRegion::dump(const unsigned IndentSize) const {
157   const std::string Indent(IndentSize, '\t');
158   dbgs() << Indent << this << ": {\n";
159   dbgs() << Indent << "	Parent: " << Parent << "\n";
160 
161   if (ConvergenceToken.value_or(nullptr)) {
162     dbgs() << Indent
163            << "	ConvergenceToken: " << ConvergenceToken.value()->getName()
164            << "\n";
165   }
166 
167   if (Entry->getName() != "")
168     dbgs() << Indent << "	Entry: " << Entry->getName() << "\n";
169   else
170     dbgs() << Indent << "	Entry: " << Entry << "\n";
171 
172   dbgs() << Indent << "	Exits: { ";
173   for (const auto &Exit : Exits) {
174     if (Exit->getName() != "")
175       dbgs() << Exit->getName() << ", ";
176     else
177       dbgs() << Exit << ", ";
178   }
179   dbgs() << "	}\n";
180 
181   dbgs() << Indent << "	Blocks: { ";
182   for (const auto &Block : Blocks) {
183     if (Block->getName() != "")
184       dbgs() << Block->getName() << ", ";
185     else
186       dbgs() << Block << ", ";
187   }
188   dbgs() << "	}\n";
189 
190   dbgs() << Indent << "	Children: {\n";
191   for (const auto Child : Children)
192     Child->dump(IndentSize + 2);
193   dbgs() << Indent << "	}\n";
194 
195   dbgs() << Indent << "}\n";
196 }
197 
198 class ConvergenceRegionAnalyzer {
199 
200 public:
201   ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI)
202       : DT(DT), LI(LI), F(F) {}
203 
204 private:
205   bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const {
206     assert(From != To && "From == To. This is awkward.");
207 
208     // We only handle loop in the simplified form. This means:
209     // - a single back-edge, a single latch.
210     // - meaning the back-edge target can only be the loop header.
211     // - meaning the From can only be the loop latch.
212     if (!LI.isLoopHeader(To))
213       return false;
214 
215     auto *L = LI.getLoopFor(To);
216     if (L->contains(From) && L->isLoopLatch(From))
217       return true;
218 
219     return false;
220   }
221 
222   std::unordered_set<BasicBlock *>
223   findPathsToMatch(LoopInfo &LI, BasicBlock *From,
224                    std::function<bool(const BasicBlock *)> isMatch) const {
225     std::unordered_set<BasicBlock *> Output;
226 
227     if (isMatch(From))
228       Output.insert(From);
229 
230     auto *Terminator = From->getTerminator();
231     for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
232       auto *To = Terminator->getSuccessor(i);
233       if (isBackEdge(From, To))
234         continue;
235 
236       auto ChildSet = findPathsToMatch(LI, To, isMatch);
237       if (ChildSet.size() == 0)
238         continue;
239 
240       Output.insert(ChildSet.begin(), ChildSet.end());
241       Output.insert(From);
242       if (LI.isLoopHeader(From)) {
243         auto *L = LI.getLoopFor(From);
244         for (auto *BB : L->getBlocks()) {
245           Output.insert(BB);
246         }
247       }
248     }
249 
250     return Output;
251   }
252 
253   SmallPtrSet<BasicBlock *, 2>
254   findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) {
255     SmallPtrSet<BasicBlock *, 2> Exits;
256 
257     for (auto *B : RegionBlocks) {
258       auto *Terminator = B->getTerminator();
259       for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
260         auto *Child = Terminator->getSuccessor(i);
261         if (RegionBlocks.count(Child) == 0)
262           Exits.insert(B);
263       }
264     }
265 
266     return Exits;
267   }
268 
269 public:
270   ConvergenceRegionInfo analyze() {
271     ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F);
272     std::queue<Loop *> ToProcess;
273     for (auto *L : LI.getLoopsInPreorder())
274       ToProcess.push(L);
275 
276     while (ToProcess.size() != 0) {
277       auto *L = ToProcess.front();
278       ToProcess.pop();
279       assert(L->isLoopSimplifyForm());
280 
281       auto CT = getConvergenceToken(L->getHeader());
282       SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(),
283                                                 L->block_end());
284       SmallVector<BasicBlock *> LoopExits;
285       L->getExitingBlocks(LoopExits);
286       if (CT.has_value()) {
287         for (auto *Exit : LoopExits) {
288           auto N = findPathsToMatch(LI, Exit, [&CT](const BasicBlock *block) {
289             auto Token = getConvergenceToken(block);
290             if (Token == std::nullopt)
291               return false;
292             return Token.value() == CT.value();
293           });
294           RegionBlocks.insert(N.begin(), N.end());
295         }
296       }
297 
298       auto RegionExits = findExitNodes(RegionBlocks);
299       ConvergenceRegion *Region = new ConvergenceRegion(
300           DT, LI, CT, L->getHeader(), std::move(RegionBlocks),
301           std::move(RegionExits));
302       Region->Parent = findParentRegion(TopLevelRegion, Region->Entry);
303       assert(Region->Parent != nullptr && "This is impossible.");
304       Region->Parent->Children.push_back(Region);
305     }
306 
307     return ConvergenceRegionInfo(TopLevelRegion);
308   }
309 
310 private:
311   DominatorTree &DT;
312   LoopInfo &LI;
313   Function &F;
314 };
315 
316 ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
317                                             LoopInfo &LI) {
318   ConvergenceRegionAnalyzer Analyzer(F, DT, LI);
319   return Analyzer.analyze();
320 }
321 
322 } // namespace SPIRV
323 
324 char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0;
325 
326 SPIRVConvergenceRegionAnalysisWrapperPass::
327     SPIRVConvergenceRegionAnalysisWrapperPass()
328     : FunctionPass(ID) {}
329 
330 bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) {
331   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
332   LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
333 
334   CRI = SPIRV::getConvergenceRegions(F, DT, LI);
335   // Nothing was modified.
336   return false;
337 }
338 
339 SPIRVConvergenceRegionAnalysis::Result
340 SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
341   Result CRI;
342   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
343   auto &LI = AM.getResult<LoopAnalysis>(F);
344   CRI = SPIRV::getConvergenceRegions(F, DT, LI);
345   return CRI;
346 }
347 
348 AnalysisKey SPIRVConvergenceRegionAnalysis::Key;
349 
350 } // namespace llvm
351