1 //===- SuspendCrossingInfo.cpp - Utility for suspend crossing values ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
// The SuspendCrossingInfo maintains data that allows answering the question of
// whether, given two BasicBlocks A and B, there is a path from A to B that
// passes through a suspend point. Note that SuspendCrossingInfo is invalidated
// by changes to the CFG, including adding/removing BBs, due to its use of BB
// ptrs in the BlockToIndexMapping.
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/Coroutines/SuspendCrossingInfo.h"
16 #include "llvm/IR/ModuleSlotTracker.h"
17
18 // The "coro-suspend-crossing" flag is very noisy. There is another debug type,
19 // "coro-frame", which results in leaner debug spew.
20 #define DEBUG_TYPE "coro-suspend-crossing"
21
22 namespace llvm {
23 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
dumpBasicBlockLabel(const BasicBlock * BB,ModuleSlotTracker & MST)24 static void dumpBasicBlockLabel(const BasicBlock *BB, ModuleSlotTracker &MST) {
25 if (BB->hasName()) {
26 dbgs() << BB->getName();
27 return;
28 }
29
30 dbgs() << MST.getLocalSlot(BB);
31 }
32
33 LLVM_DUMP_METHOD void
dump(StringRef Label,BitVector const & BV,const ReversePostOrderTraversal<Function * > & RPOT,ModuleSlotTracker & MST) const34 SuspendCrossingInfo::dump(StringRef Label, BitVector const &BV,
35 const ReversePostOrderTraversal<Function *> &RPOT,
36 ModuleSlotTracker &MST) const {
37 dbgs() << Label << ":";
38 for (const BasicBlock *BB : RPOT) {
39 auto BBNo = Mapping.blockToIndex(BB);
40 if (BV[BBNo]) {
41 dbgs() << " ";
42 dumpBasicBlockLabel(BB, MST);
43 }
44 }
45 dbgs() << "\n";
46 }
47
dump() const48 LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
49 if (Block.empty())
50 return;
51
52 BasicBlock *const B = Mapping.indexToBlock(0);
53 Function *F = B->getParent();
54
55 ModuleSlotTracker MST(F->getParent());
56 MST.incorporateFunction(*F);
57
58 ReversePostOrderTraversal<Function *> RPOT(F);
59 for (const BasicBlock *BB : RPOT) {
60 auto BBNo = Mapping.blockToIndex(BB);
61 dumpBasicBlockLabel(BB, MST);
62 dbgs() << ":\n";
63 dump(" Consumes", Block[BBNo].Consumes, RPOT, MST);
64 dump(" Kills", Block[BBNo].Kills, RPOT, MST);
65 }
66 dbgs() << "\n";
67 }
68 #endif
69
hasPathCrossingSuspendPoint(BasicBlock * From,BasicBlock * To) const70 bool SuspendCrossingInfo::hasPathCrossingSuspendPoint(BasicBlock *From,
71 BasicBlock *To) const {
72 size_t const FromIndex = Mapping.blockToIndex(From);
73 size_t const ToIndex = Mapping.blockToIndex(To);
74 bool const Result = Block[ToIndex].Kills[FromIndex];
75 LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName()
76 << " crosses suspend point\n");
77 return Result;
78 }
79
hasPathOrLoopCrossingSuspendPoint(BasicBlock * From,BasicBlock * To) const80 bool SuspendCrossingInfo::hasPathOrLoopCrossingSuspendPoint(
81 BasicBlock *From, BasicBlock *To) const {
82 size_t const FromIndex = Mapping.blockToIndex(From);
83 size_t const ToIndex = Mapping.blockToIndex(To);
84 bool Result = Block[ToIndex].Kills[FromIndex] ||
85 (From == To && Block[ToIndex].KillLoop);
86 LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName()
87 << " crosses suspend point (path or loop)\n");
88 return Result;
89 }
90
91 template <bool Initialize>
computeBlockData(const ReversePostOrderTraversal<Function * > & RPOT)92 bool SuspendCrossingInfo::computeBlockData(
93 const ReversePostOrderTraversal<Function *> &RPOT) {
94 bool Changed = false;
95
96 for (const BasicBlock *BB : RPOT) {
97 auto BBNo = Mapping.blockToIndex(BB);
98 auto &B = Block[BBNo];
99
100 // We don't need to count the predecessors when initialization.
101 if constexpr (!Initialize)
102 // If all the predecessors of the current Block don't change,
103 // the BlockData for the current block must not change too.
104 if (all_of(predecessors(B), [this](BasicBlock *BB) {
105 return !Block[Mapping.blockToIndex(BB)].Changed;
106 })) {
107 B.Changed = false;
108 continue;
109 }
110
111 // Saved Consumes and Kills bitsets so that it is easy to see
112 // if anything changed after propagation.
113 auto SavedConsumes = B.Consumes;
114 auto SavedKills = B.Kills;
115
116 for (BasicBlock *PI : predecessors(B)) {
117 auto PrevNo = Mapping.blockToIndex(PI);
118 auto &P = Block[PrevNo];
119
120 // Propagate Kills and Consumes from predecessors into B.
121 B.Consumes |= P.Consumes;
122 B.Kills |= P.Kills;
123
124 // If block P is a suspend block, it should propagate kills into block
125 // B for every block P consumes.
126 if (P.Suspend)
127 B.Kills |= P.Consumes;
128 }
129
130 if (B.Suspend) {
131 // If block B is a suspend block, it should kill all of the blocks it
132 // consumes.
133 B.Kills |= B.Consumes;
134 } else if (B.End) {
135 // If block B is an end block, it should not propagate kills as the
136 // blocks following coro.end() are reached during initial invocation
137 // of the coroutine while all the data are still available on the
138 // stack or in the registers.
139 B.Kills.reset();
140 } else {
141 // This is reached when B block it not Suspend nor coro.end and it
142 // need to make sure that it is not in the kill set.
143 B.KillLoop |= B.Kills[BBNo];
144 B.Kills.reset(BBNo);
145 }
146
147 if constexpr (!Initialize) {
148 B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
149 Changed |= B.Changed;
150 }
151 }
152
153 return Changed;
154 }
155
SuspendCrossingInfo(Function & F,const SmallVectorImpl<AnyCoroSuspendInst * > & CoroSuspends,const SmallVectorImpl<AnyCoroEndInst * > & CoroEnds)156 SuspendCrossingInfo::SuspendCrossingInfo(
157 Function &F, const SmallVectorImpl<AnyCoroSuspendInst *> &CoroSuspends,
158 const SmallVectorImpl<AnyCoroEndInst *> &CoroEnds)
159 : Mapping(F) {
160 const size_t N = Mapping.size();
161 Block.resize(N);
162
163 // Initialize every block so that it consumes itself
164 for (size_t I = 0; I < N; ++I) {
165 auto &B = Block[I];
166 B.Consumes.resize(N);
167 B.Kills.resize(N);
168 B.Consumes.set(I);
169 B.Changed = true;
170 }
171
172 // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
173 // the code beyond coro.end is reachable during initial invocation of the
174 // coroutine.
175 for (auto *CE : CoroEnds) {
176 // Verify CoroEnd was normalized
177 assert(CE->getParent()->getFirstInsertionPt() == CE->getIterator() &&
178 CE->getParent()->size() <= 2 && "CoroEnd must be in its own BB");
179
180 getBlockData(CE->getParent()).End = true;
181 }
182
183 // Mark all suspend blocks and indicate that they kill everything they
184 // consume. Note, that crossing coro.save also requires a spill, as any code
185 // between coro.save and coro.suspend may resume the coroutine and all of the
186 // state needs to be saved by that time.
187 auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
188 BasicBlock *SuspendBlock = BarrierInst->getParent();
189 auto &B = getBlockData(SuspendBlock);
190 B.Suspend = true;
191 B.Kills |= B.Consumes;
192 };
193 for (auto *CSI : CoroSuspends) {
194 // Verify CoroSuspend was normalized
195 assert(CSI->getParent()->getFirstInsertionPt() == CSI->getIterator() &&
196 CSI->getParent()->size() <= 2 &&
197 "CoroSuspend must be in its own BB");
198
199 markSuspendBlock(CSI);
200 if (auto *Save = CSI->getCoroSave())
201 markSuspendBlock(Save);
202 }
203
204 // It is considered to be faster to use RPO traversal for forward-edges
205 // dataflow analysis.
206 ReversePostOrderTraversal<Function *> RPOT(&F);
207 computeBlockData</*Initialize=*/true>(RPOT);
208 while (computeBlockData</*Initialize*/ false>(RPOT))
209 ;
210
211 LLVM_DEBUG(dump());
212 }
213
214 } // namespace llvm
215