//===- SPIRVConvergenceRegionAnalysis.cpp -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The analysis determines the convergence region for each basic block of
// a function, and provides a tree-like structure describing the region
// hierarchy.
12 //
13 //===----------------------------------------------------------------------===//
14
#include "SPIRVConvergenceRegionAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include <functional>
#include <optional>
#include <queue>
#include <unordered_map>
#include <unordered_set>
23
24 #define DEBUG_TYPE "spirv-convergence-region-analysis"
25
26 using namespace llvm;
27
namespace llvm {
// Declared up front so that the INITIALIZE_PASS_END macro below, which defines
// this initializer inside namespace llvm, has a matching prior declaration.
void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
} // namespace llvm

// Registers the legacy-PM wrapper pass under the command-line name
// "convergence-region". The trailing (true, true) arguments flag it as a
// CFG-only pass and as an analysis. The declared dependencies are loop
// simplification (analyze() asserts loop-simplify form), the dominator tree
// and loop info (both fetched in runOnFunction).
INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass,
                      "convergence-region",
                      "SPIRV convergence regions analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass,
                    "convergence-region", "SPIRV convergence regions analysis",
                    true, true)
41
42 namespace llvm {
43 namespace SPIRV {
44 namespace {
45
46 template <typename BasicBlockType, typename IntrinsicInstType>
47 std::optional<IntrinsicInstType *>
getConvergenceTokenInternal(BasicBlockType * BB)48 getConvergenceTokenInternal(BasicBlockType *BB) {
49 static_assert(std::is_const_v<IntrinsicInstType> ==
50 std::is_const_v<BasicBlockType>,
51 "Constness must match between input and output.");
52 static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>,
53 "Input must be a basic block.");
54 static_assert(
55 std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>,
56 "Output type must be an intrinsic instruction.");
57
58 for (auto &I : *BB) {
59 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
60 switch (II->getIntrinsicID()) {
61 case Intrinsic::experimental_convergence_entry:
62 case Intrinsic::experimental_convergence_loop:
63 return II;
64 case Intrinsic::experimental_convergence_anchor: {
65 auto Bundle = II->getOperandBundle(LLVMContext::OB_convergencectrl);
66 assert(Bundle->Inputs.size() == 1 &&
67 Bundle->Inputs[0]->getType()->isTokenTy());
68 auto TII = dyn_cast<IntrinsicInst>(Bundle->Inputs[0].get());
69 assert(TII != nullptr);
70 return TII;
71 }
72 }
73 }
74
75 if (auto *CI = dyn_cast<CallInst>(&I)) {
76 auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
77 if (!OB.has_value())
78 continue;
79 return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]);
80 }
81 }
82
83 return std::nullopt;
84 }
85
86 // Given a ConvergenceRegion tree with |Start| as its root, finds the smallest
87 // region |Entry| belongs to. If |Entry| does not belong to the region defined
88 // by |Start|, this function returns |nullptr|.
findParentRegion(ConvergenceRegion * Start,BasicBlock * Entry)89 ConvergenceRegion *findParentRegion(ConvergenceRegion *Start,
90 BasicBlock *Entry) {
91 ConvergenceRegion *Candidate = nullptr;
92 ConvergenceRegion *NextCandidate = Start;
93
94 while (Candidate != NextCandidate && NextCandidate != nullptr) {
95 Candidate = NextCandidate;
96 NextCandidate = nullptr;
97
98 // End of the search, we can return.
99 if (Candidate->Children.size() == 0)
100 return Candidate;
101
102 for (auto *Child : Candidate->Children) {
103 if (Child->Blocks.count(Entry) != 0) {
104 NextCandidate = Child;
105 break;
106 }
107 }
108 }
109
110 return Candidate;
111 }
112
113 } // anonymous namespace
114
getConvergenceToken(BasicBlock * BB)115 std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB) {
116 return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB);
117 }
118
getConvergenceToken(const BasicBlock * BB)119 std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB) {
120 return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB);
121 }
122
ConvergenceRegion(DominatorTree & DT,LoopInfo & LI,Function & F)123 ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
124 Function &F)
125 : DT(DT), LI(LI), Parent(nullptr) {
126 Entry = &F.getEntryBlock();
127 ConvergenceToken = getConvergenceToken(Entry);
128 for (auto &B : F) {
129 Blocks.insert(&B);
130 if (isa<ReturnInst>(B.getTerminator()))
131 Exits.insert(&B);
132 }
133 }
134
ConvergenceRegion(DominatorTree & DT,LoopInfo & LI,std::optional<IntrinsicInst * > ConvergenceToken,BasicBlock * Entry,SmallPtrSet<BasicBlock *,8> && Blocks,SmallPtrSet<BasicBlock *,2> && Exits)135 ConvergenceRegion::ConvergenceRegion(
136 DominatorTree &DT, LoopInfo &LI,
137 std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry,
138 SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits)
139 : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry),
140 Exits(std::move(Exits)), Blocks(std::move(Blocks)) {
141 for ([[maybe_unused]] auto *BB : this->Exits)
142 assert(this->Blocks.count(BB) != 0);
143 assert(this->Blocks.count(this->Entry) != 0);
144 }
145
releaseMemory()146 void ConvergenceRegion::releaseMemory() {
147 // Parent memory is owned by the parent.
148 Parent = nullptr;
149 for (auto *Child : Children) {
150 Child->releaseMemory();
151 delete Child;
152 }
153 Children.resize(0);
154 }
155
dump(const unsigned IndentSize) const156 void ConvergenceRegion::dump(const unsigned IndentSize) const {
157 const std::string Indent(IndentSize, '\t');
158 dbgs() << Indent << this << ": {\n";
159 dbgs() << Indent << " Parent: " << Parent << "\n";
160
161 if (ConvergenceToken.value_or(nullptr)) {
162 dbgs() << Indent
163 << " ConvergenceToken: " << ConvergenceToken.value()->getName()
164 << "\n";
165 }
166
167 if (Entry->getName() != "")
168 dbgs() << Indent << " Entry: " << Entry->getName() << "\n";
169 else
170 dbgs() << Indent << " Entry: " << Entry << "\n";
171
172 dbgs() << Indent << " Exits: { ";
173 for (const auto &Exit : Exits) {
174 if (Exit->getName() != "")
175 dbgs() << Exit->getName() << ", ";
176 else
177 dbgs() << Exit << ", ";
178 }
179 dbgs() << " }\n";
180
181 dbgs() << Indent << " Blocks: { ";
182 for (const auto &Block : Blocks) {
183 if (Block->getName() != "")
184 dbgs() << Block->getName() << ", ";
185 else
186 dbgs() << Block << ", ";
187 }
188 dbgs() << " }\n";
189
190 dbgs() << Indent << " Children: {\n";
191 for (const auto Child : Children)
192 Child->dump(IndentSize + 2);
193 dbgs() << Indent << " }\n";
194
195 dbgs() << Indent << "}\n";
196 }
197
198 class ConvergenceRegionAnalyzer {
199
200 public:
ConvergenceRegionAnalyzer(Function & F,DominatorTree & DT,LoopInfo & LI)201 ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI)
202 : DT(DT), LI(LI), F(F) {}
203
204 private:
isBackEdge(const BasicBlock * From,const BasicBlock * To) const205 bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const {
206 assert(From != To && "From == To. This is awkward.");
207
208 // We only handle loop in the simplified form. This means:
209 // - a single back-edge, a single latch.
210 // - meaning the back-edge target can only be the loop header.
211 // - meaning the From can only be the loop latch.
212 if (!LI.isLoopHeader(To))
213 return false;
214
215 auto *L = LI.getLoopFor(To);
216 if (L->contains(From) && L->isLoopLatch(From))
217 return true;
218
219 return false;
220 }
221
222 std::unordered_set<BasicBlock *>
findPathsToMatch(LoopInfo & LI,BasicBlock * From,std::function<bool (const BasicBlock *)> isMatch) const223 findPathsToMatch(LoopInfo &LI, BasicBlock *From,
224 std::function<bool(const BasicBlock *)> isMatch) const {
225 std::unordered_set<BasicBlock *> Output;
226
227 if (isMatch(From))
228 Output.insert(From);
229
230 auto *Terminator = From->getTerminator();
231 for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
232 auto *To = Terminator->getSuccessor(i);
233 if (isBackEdge(From, To))
234 continue;
235
236 auto ChildSet = findPathsToMatch(LI, To, isMatch);
237 if (ChildSet.size() == 0)
238 continue;
239
240 Output.insert(ChildSet.begin(), ChildSet.end());
241 Output.insert(From);
242 if (LI.isLoopHeader(From)) {
243 auto *L = LI.getLoopFor(From);
244 for (auto *BB : L->getBlocks()) {
245 Output.insert(BB);
246 }
247 }
248 }
249
250 return Output;
251 }
252
253 SmallPtrSet<BasicBlock *, 2>
findExitNodes(const SmallPtrSetImpl<BasicBlock * > & RegionBlocks)254 findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) {
255 SmallPtrSet<BasicBlock *, 2> Exits;
256
257 for (auto *B : RegionBlocks) {
258 auto *Terminator = B->getTerminator();
259 for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
260 auto *Child = Terminator->getSuccessor(i);
261 if (RegionBlocks.count(Child) == 0)
262 Exits.insert(B);
263 }
264 }
265
266 return Exits;
267 }
268
269 public:
analyze()270 ConvergenceRegionInfo analyze() {
271 ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F);
272 std::queue<Loop *> ToProcess;
273 for (auto *L : LI.getLoopsInPreorder())
274 ToProcess.push(L);
275
276 while (ToProcess.size() != 0) {
277 auto *L = ToProcess.front();
278 ToProcess.pop();
279 assert(L->isLoopSimplifyForm());
280
281 auto CT = getConvergenceToken(L->getHeader());
282 SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(),
283 L->block_end());
284 SmallVector<BasicBlock *> LoopExits;
285 L->getExitingBlocks(LoopExits);
286 if (CT.has_value()) {
287 for (auto *Exit : LoopExits) {
288 auto N = findPathsToMatch(LI, Exit, [&CT](const BasicBlock *block) {
289 auto Token = getConvergenceToken(block);
290 if (Token == std::nullopt)
291 return false;
292 return Token.value() == CT.value();
293 });
294 RegionBlocks.insert(N.begin(), N.end());
295 }
296 }
297
298 auto RegionExits = findExitNodes(RegionBlocks);
299 ConvergenceRegion *Region = new ConvergenceRegion(
300 DT, LI, CT, L->getHeader(), std::move(RegionBlocks),
301 std::move(RegionExits));
302 Region->Parent = findParentRegion(TopLevelRegion, Region->Entry);
303 assert(Region->Parent != nullptr && "This is impossible.");
304 Region->Parent->Children.push_back(Region);
305 }
306
307 return ConvergenceRegionInfo(TopLevelRegion);
308 }
309
310 private:
311 DominatorTree &DT;
312 LoopInfo &LI;
313 Function &F;
314 };
315
getConvergenceRegions(Function & F,DominatorTree & DT,LoopInfo & LI)316 ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
317 LoopInfo &LI) {
318 ConvergenceRegionAnalyzer Analyzer(F, DT, LI);
319 return Analyzer.analyze();
320 }
321
322 } // namespace SPIRV
323
324 char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0;
325
326 SPIRVConvergenceRegionAnalysisWrapperPass::
SPIRVConvergenceRegionAnalysisWrapperPass()327 SPIRVConvergenceRegionAnalysisWrapperPass()
328 : FunctionPass(ID) {}
329
runOnFunction(Function & F)330 bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) {
331 DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
332 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
333
334 CRI = SPIRV::getConvergenceRegions(F, DT, LI);
335 // Nothing was modified.
336 return false;
337 }
338
339 SPIRVConvergenceRegionAnalysis::Result
run(Function & F,FunctionAnalysisManager & AM)340 SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
341 Result CRI;
342 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
343 auto &LI = AM.getResult<LoopAnalysis>(F);
344 CRI = SPIRV::getConvergenceRegions(F, DT, LI);
345 return CRI;
346 }
347
348 AnalysisKey SPIRVConvergenceRegionAnalysis::Key;
349
350 } // namespace llvm
351