1 //===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file This file defines a set of schedule DAG mutations that can be used to
10 // override default scheduler behavior to enforce specific scheduling patterns.
11 // They should be used in cases where runtime performance considerations, such
12 // as inter-wavefront interactions, mean that compile-time heuristics cannot
13 // predict the optimal instruction ordering, or in kernels where optimal
14 // instruction scheduling is important enough to warrant manual intervention.
15 //
16 //===----------------------------------------------------------------------===//
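//
// As a usage sketch (illustrative, and assuming the clang builtin that lowers
// to the SCHED_GROUP_BARRIER pseudo consumed by this mutation), a kernel could
// request a "2 DS reads, then 1 MFMA" pattern with mask values matching the
// SchedGroupMask enum below:
//
//   __builtin_amdgcn_sched_group_barrier(0x0100 /*DS_READ*/, 2, 0);
//   __builtin_amdgcn_sched_group_barrier(0x0008 /*MFMA*/,    1, 0);
//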
17
18 #include "AMDGPUIGroupLP.h"
19 #include "AMDGPUTargetMachine.h"
20 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
21 #include "SIInstrInfo.h"
22 #include "SIMachineFunctionInfo.h"
23 #include "llvm/ADT/BitmaskEnum.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/CodeGen/MachineScheduler.h"
26 #include "llvm/CodeGen/TargetOpcodes.h"
27
28 using namespace llvm;
29
30 #define DEBUG_TYPE "igrouplp"
31
32 namespace {
33
34 static cl::opt<bool> EnableExactSolver(
35 "amdgpu-igrouplp-exact-solver", cl::Hidden,
36 cl::desc("Whether to use the exponential time solver to fit "
37 "the instructions to the pipeline as closely as "
38 "possible."),
39 cl::init(false));
40
41 static cl::opt<unsigned> CutoffForExact(
42 "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
43 cl::desc("The maximum number of scheduling group conflicts "
44 "which we attempt to solve with the exponential time "
45 "exact solver. Problem sizes greater than this will"
46 "be solved by the less accurate greedy algorithm. Selecting "
47 "solver by size is superseded by manually selecting "
48 "the solver (e.g. by amdgpu-igrouplp-exact-solver"));
49
50 static cl::opt<uint64_t> MaxBranchesExplored(
51 "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
52 cl::desc("The amount of branches that we are willing to explore with"
53 "the exact algorithm before giving up."));
54
55 static cl::opt<bool> UseCostHeur(
56 "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
57 cl::desc("Whether to use the cost heuristic to make choices as we "
58 "traverse the search space using the exact solver. Defaulted "
59 "to on, and if turned off, we will use the node order -- "
60 "attempting to put the later nodes in the later sched groups. "
61 "Experimentally, results are mixed, so this should be set on a "
62 "case-by-case basis."));
63
64 // Components of the mask that determines which instruction types may be
65 // classified into a SchedGroup.
66 enum class SchedGroupMask {
67 NONE = 0u,
68 ALU = 1u << 0,
69 VALU = 1u << 1,
70 SALU = 1u << 2,
71 MFMA = 1u << 3,
72 VMEM = 1u << 4,
73 VMEM_READ = 1u << 5,
74 VMEM_WRITE = 1u << 6,
75 DS = 1u << 7,
76 DS_READ = 1u << 8,
77 DS_WRITE = 1u << 9,
78 TRANS = 1u << 10,
79 ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
80 DS_READ | DS_WRITE | TRANS,
81 LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
82 };
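
// For example (illustrative), LLVM_MARK_AS_BITMASK_ENUM enables the usual
// bitwise operators on this enum, so a group that accepts both DS reads and
// MFMAs can be expressed as:
//   SchedGroupMask M = SchedGroupMask::DS_READ | SchedGroupMask::MFMA;
//   bool AcceptsMFMA = (M & SchedGroupMask::MFMA) != SchedGroupMask::NONE;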
83
84 class SchedGroup;
85
86 // InstructionRule class is used to enact a filter which determines whether or
87 // not an SU maps to a given SchedGroup. It contains complementary data
88 // structures (e.g. Cache) to help those filters.
89 class InstructionRule {
90 protected:
91 const SIInstrInfo *TII;
92 unsigned SGID;
93 // A cache made available to the Filter to store SUnits for subsequent
94 // invocations of the Filter
95 std::optional<SmallVector<SUnit *, 4>> Cache;
96
97 public:
98 virtual bool
99 apply(const SUnit *, const ArrayRef<SUnit *>,
100 SmallVectorImpl<SchedGroup> &) {
101 return true;
102 };
103
104 InstructionRule(const SIInstrInfo *TII, unsigned SGID,
105 bool NeedsCache = false)
106 : TII(TII), SGID(SGID) {
107 if (NeedsCache) {
108 Cache = SmallVector<SUnit *, 4>();
109 }
110 }
111
112 virtual ~InstructionRule() = default;
113 };
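
// A minimal sketch of a custom rule (illustrative, not part of the upstream
// pass): accept only SUs whose instruction may load, using the standard
// MachineInstr query.
//
//   class IsLoadRule final : public InstructionRule {
//   public:
//     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
//                SmallVectorImpl<SchedGroup> &SyncPipe) override {
//       return SU->getInstr()->mayLoad();
//     }
//     IsLoadRule(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
//         : InstructionRule(TII, SGID, NeedsCache) {}
//   };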
114
115 using SUnitsToCandidateSGsMap = DenseMap<SUnit *, SmallVector<int, 4>>;
116
117 // Classify instructions into groups to enable fine-tuned control over the
118 // scheduler. These groups may be more specific than current SchedModel
119 // instruction classes.
120 class SchedGroup {
121 private:
122 // Mask that defines which instruction types can be classified into this
123 // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
124 // and SCHED_GROUP_BARRIER.
125 SchedGroupMask SGMask;
126
127 // Maximum number of SUnits that can be added to this group.
128 std::optional<unsigned> MaxSize;
129
130 // SchedGroups will only synchronize with other SchedGroups that have the same
131 // SyncID.
132 int SyncID = 0;
133
134 // SGID is used to map instructions to candidate SchedGroups
135 unsigned SGID;
136
137 // The different rules each instruction in this SchedGroup must conform to
138 SmallVector<std::shared_ptr<InstructionRule>, 4> Rules;
139
140 // Count of the number of created SchedGroups, used to initialize SGID.
141 static unsigned NumSchedGroups;
142
143 // Try to add an edge from SU A to SU B.
144 bool tryAddEdge(SUnit *A, SUnit *B);
145
146 // Use SGMask to determine whether we can classify MI as a member of this
147 // SchedGroup object.
148 bool canAddMI(const MachineInstr &MI) const;
149
150 public:
151 // Collection of SUnits that are classified as members of this group.
152 SmallVector<SUnit *, 32> Collection;
153
154 ScheduleDAGInstrs *DAG;
155 const SIInstrInfo *TII;
156
157 // Returns true if SU can be added to this SchedGroup.
158 bool canAddSU(SUnit &SU) const;
159
160 // Add DAG dependencies between all SUnits in this SchedGroup and this SU. If
161 // MakePred is true, SU will be a predecessor of the SUnits in this
162 // SchedGroup, otherwise SU will be a successor.
163 void link(SUnit &SU, bool MakePred = false);
164
165 // Add DAG dependencies, track which edges are added, and return the count
166 // of missed edges.
167 int link(SUnit &SU, bool MakePred,
168 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
169
170 // Add DAG dependencies between all SUnits in this SchedGroup and this SU.
171 // Use the predicate to determine whether SU should be a predecessor (P =
172 // true) or a successor (P = false) of this SchedGroup.
173 void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);
174
175 // Add DAG dependencies such that SUnits in this group shall be ordered
176 // before SUnits in OtherGroup.
177 void link(SchedGroup &OtherGroup);
178
179 // Returns true if no more instructions may be added to this group.
180 bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }
181
182 // Append a constraint that SUs must meet in order to fit into this
183 // SchedGroup. Since many rules involve the relationship between a SchedGroup
184 // and the SUnits in other SchedGroups, rules are checked at Pipeline Solve
185 // time (rather than SchedGroup init time).
186 void addRule(std::shared_ptr<InstructionRule> NewRule) {
187 Rules.push_back(NewRule);
188 }
189
190 // Returns true if the SU matches all rules
191 bool allowedByRules(const SUnit *SU,
192 SmallVectorImpl<SchedGroup> &SyncPipe) const {
193 for (auto &Rule : Rules) {
194 if (!Rule.get()->apply(SU, Collection, SyncPipe))
195 return false;
196 }
197 return true;
198 }
199
200 // Add SU to the SchedGroup.
201 void add(SUnit &SU) {
202 LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
203 << format_hex((int)SGMask, 10, true) << " adding "
204 << *SU.getInstr());
205 Collection.push_back(&SU);
206 }
207
208 // Remove last element in the SchedGroup
209 void pop() { Collection.pop_back(); }
210
211 // Identify and add all relevant SUs from the DAG to this SchedGroup.
212 void initSchedGroup();
213
214 // Add instructions to the SchedGroup bottom up starting from RIter.
215 // PipelineInstrs is a set of instructions that should not be added to the
216 // SchedGroup even when the other conditions for adding it are satisfied.
217 // RIter will be added to the SchedGroup as well, and dependencies will be
218 // added so that RIter will always be scheduled at the end of the group.
219 void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
220 SUnitsToCandidateSGsMap &SyncedInstrs);
221
222 void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);
223
224 int getSyncID() { return SyncID; }
225
226 int getSGID() { return SGID; }
227
228 SchedGroupMask getMask() { return SGMask; }
229
230 SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
231 ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
232 : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) {
233 SGID = NumSchedGroups++;
234 }
235
236 SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
237 ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
238 : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) {
239 SGID = NumSchedGroups++;
240 }
241 };
242
243 // Remove all existing edges from a SCHED_BARRIER or SCHED_GROUP_BARRIER.
244 static void resetEdges(SUnit &SU, ScheduleDAGInstrs *DAG) {
245 assert(SU.getInstr()->getOpcode() == AMDGPU::SCHED_BARRIER ||
246 SU.getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER ||
247 SU.getInstr()->getOpcode() == AMDGPU::IGLP_OPT);
248
249 while (!SU.Preds.empty())
250 for (auto &P : SU.Preds)
251 SU.removePred(P);
252
253 while (!SU.Succs.empty())
254 for (auto &S : SU.Succs)
255 for (auto &SP : S.getSUnit()->Preds)
256 if (SP.getSUnit() == &SU)
257 S.getSUnit()->removePred(SP);
258 }
259
260 using SUToCandSGsPair = std::pair<SUnit *, SmallVector<int, 4>>;
261 using SUsToCandSGsVec = SmallVector<SUToCandSGsPair, 4>;
262
263 // The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
264 // in non-trivial cases. For example, if the requested pipeline is
265 // {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
266 // in the DAG, then we will have an instruction that can not be trivially
267 // assigned to a SchedGroup. The PipelineSolver class implements two algorithms
268 // to find a good solution to the pipeline -- a greedy algorithm and an exact
269 // algorithm. The exact algorithm has an exponential time complexity and should
270 // only be used for small sized problems or medium sized problems where an exact
271 // solution is highly desired.
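// For example, with the {VMEM_READ, VALU, MFMA, VMEM_READ} pipeline above, a
// VMEM_READ SU has two candidate SchedGroups (index 0 and index 3), so
// SyncedInstrs records SU -> {0, 3} and the solver must pick the assignment
// that loses the fewest pipeline edges.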
272 class PipelineSolver {
273 ScheduleDAGMI *DAG;
274
275 // Instructions that can be assigned to multiple SchedGroups
276 DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
277 SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
278 DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
279 // The current working pipeline
280 SmallVector<SmallVector<SchedGroup, 4>, 4> CurrPipeline;
281 // The pipeline that has the best solution found so far
282 SmallVector<SmallVector<SchedGroup, 4>, 4> BestPipeline;
283
284 // Whether or not we actually have any SyncedInstrs to try to solve.
285 bool NeedsSolver = false;
286
287 // Compute an estimate of the size of the search tree -- the true size is
288 // the product of each conflictedInst.Matches.size() across all SyncPipelines
289 unsigned computeProblemSize();
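// For example, three conflicting SUs with two candidate SchedGroups each give
// an estimated problem size of 3, while the exact search tree has up to
// 2 * 2 * 2 = 8 assignment combinations (plus the branches that omit an SU).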
290
291 // The cost penalty of not assigning a SU to a SchedGroup
292 int MissPenalty = 0;
293
294 // Costs in terms of the number of edges we are unable to add
295 int BestCost = -1;
296 int CurrCost = 0;
297
298 // Index pointing to the conflicting instruction that is currently being
299 // fitted
300 int CurrConflInstNo = 0;
301 // Index to the pipeline that is currently being fitted
302 int CurrSyncGroupIdx = 0;
303 // The first non-trivial pipeline
304 int BeginSyncGroupIdx = 0;
305
306 // How many branches we have explored
307 uint64_t BranchesExplored = 0;
308
309 // The direction in which we process the candidate SchedGroups per SU
310 bool IsBottomUp = true;
311
312 // Update indices to fit next conflicting instruction
313 void advancePosition();
314 // Recede indices to attempt to find better fit for previous conflicting
315 // instruction
316 void retreatPosition();
317
318 // The exponential time algorithm which finds the provably best fit
319 bool solveExact();
320 // The polynomial time algorithm which attempts to find a good fit
321 bool solveGreedy();
322 // Find the best SchedGroup for the current SU using the heuristic given all
323 // current information. One step in the greedy algorithm. Templated against
324 // the SchedGroup iterator (either reverse or forward).
325 template <typename T>
326 void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,
327 T E);
328 // Whether or not the current solution is optimal
329 bool checkOptimal();
330 // Populate the ready list, prioritizing fewest missed edges first.
331 // Templated against the SchedGroup iterator (either reverse or forward).
332 template <typename T>
333 void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,
334 T E);
335 // Add edges corresponding to the SchedGroups as assigned by solver
336 void makePipeline();
337 // Link the SchedGroups in the best found pipeline.
338 // Templated against the SchedGroup iterator (either reverse or forward).
339 template <typename T> void linkSchedGroups(T I, T E);
340 // Add the edges from the SU to the other SchedGroups in pipeline, and
341 // return the number of edges missed.
342 int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
343 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
344 /// Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
345 /// returns the cost (in terms of missed pipeline edges), and tracks the edges
346 /// added in \p AddedEdges
347 template <typename T>
348 int linkSUnit(SUnit *SU, int SGID,
349 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
350 /// Remove the edges passed via \p AddedEdges
351 void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
352 // Convert the passed in maps to arrays for bidirectional iterators
353 void convertSyncMapsToArrays();
354
355 void reset();
356
357 public:
358 // Invoke the solver to map instructions to instruction groups. The problem-size
359 // heuristic and command-line options select the exact or the greedy algorithm.
360 void solve();
361
362 PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
363 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
364 ScheduleDAGMI *DAG, bool IsBottomUp = true)
365 : DAG(DAG), SyncedInstrs(SyncedInstrs),
366 SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {
367
368 for (auto &PipelineInstrs : SyncedInstrs) {
369 if (PipelineInstrs.second.size() > 0) {
370 NeedsSolver = true;
371 break;
372 }
373 }
374
375 if (!NeedsSolver)
376 return;
377
378 convertSyncMapsToArrays();
379
380 CurrPipeline = BestPipeline;
381
382 while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
383 PipelineInstrs[BeginSyncGroupIdx].size() == 0)
384 ++BeginSyncGroupIdx;
385
386 if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
387 return;
388 }
389 };
390
391 void PipelineSolver::reset() {
392
393 for (auto &SyncPipeline : CurrPipeline) {
394 for (auto &SG : SyncPipeline) {
395 SmallVector<SUnit *, 32> TempCollection = SG.Collection;
396 SG.Collection.clear();
397 auto SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
398 return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
399 });
400 if (SchedBarr != TempCollection.end())
401 SG.Collection.push_back(*SchedBarr);
402 }
403 }
404
405 CurrSyncGroupIdx = BeginSyncGroupIdx;
406 CurrConflInstNo = 0;
407 CurrCost = 0;
408 }
409
410 void PipelineSolver::convertSyncMapsToArrays() {
411 for (auto &SyncPipe : SyncedSchedGroups) {
412 BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);
413 }
414
415 int PipelineIDx = SyncedInstrs.size() - 1;
416 PipelineInstrs.resize(SyncedInstrs.size());
417 for (auto &SyncInstrMap : SyncedInstrs) {
418 for (auto &SUsToCandSGs : SyncInstrMap.second) {
419 if (PipelineInstrs[PipelineIDx].size() == 0) {
420 PipelineInstrs[PipelineIDx].push_back(
421 std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
422 continue;
423 }
424 auto SortPosition = PipelineInstrs[PipelineIDx].begin();
425 // Insert them in sorted order -- this allows for good parsing order in
426 // the greedy algorithm
427 while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
428 SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
429 ++SortPosition;
430 PipelineInstrs[PipelineIDx].insert(
431 SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
432 }
433 --PipelineIDx;
434 }
435 }
436
437 template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {
438 for (; I != E; ++I) {
439 auto &GroupA = *I;
440 for (auto J = std::next(I); J != E; ++J) {
441 auto &GroupB = *J;
442 GroupA.link(GroupB);
443 }
444 }
445 }
446
447 void PipelineSolver::makePipeline() {
448 // Preserve the order of barriers for subsequent SchedGroupBarrier mutations.
449 for (auto &SyncPipeline : BestPipeline) {
450 LLVM_DEBUG(dbgs() << "Printing SchedGroups\n");
451 for (auto &SG : SyncPipeline) {
452 LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID()
453 << " has: \n");
454 SUnit *SGBarr = nullptr;
455 for (auto &SU : SG.Collection) {
456 if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
457 SGBarr = SU;
458 LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n");
459 }
460 // Command-line-requested IGroupLP doesn't have a SGBarr.
461 if (!SGBarr)
462 continue;
463 resetEdges(*SGBarr, DAG);
464 SG.link(*SGBarr, false);
465 }
466 }
467
468 for (auto &SyncPipeline : BestPipeline) {
469 IsBottomUp ? linkSchedGroups(SyncPipeline.rbegin(), SyncPipeline.rend())
470 : linkSchedGroups(SyncPipeline.begin(), SyncPipeline.end());
471 }
472 }
473
474 template <typename T>
475 int PipelineSolver::linkSUnit(
476 SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
477 T I, T E) {
478 bool MakePred = false;
479 int AddedCost = 0;
480 for (; I < E; ++I) {
481 if (I->getSGID() == SGID) {
482 MakePred = true;
483 continue;
484 }
485 auto Group = *I;
486 AddedCost += Group.link(*SU, MakePred, AddedEdges);
487 assert(AddedCost >= 0);
488 }
489 return AddedCost;
490 }
491
492 int PipelineSolver::addEdges(
493 SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
494 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
495
496 // For IsBottomUp, the first SchedGroup in SyncPipeline contains the
497 // instructions that are the ultimate successors in the resultant mutation.
498 // Therefore, in such a configuration, the SchedGroups occurring before the
499 // candidate SGID are successors of the candidate SchedGroup, thus the current
500 // SU should be linked as a predecessor to SUs in those SchedGroups. The
501 // opposite is true if !IsBottomUp. IsBottomUp occurs in the case of multiple
502 // SCHED_GROUP_BARRIERS, or if a user specifies IGLP_OPT SchedGroups using
503 // IsBottomUp (in reverse).
504 return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, SyncPipeline.rbegin(),
505 SyncPipeline.rend())
506 : linkSUnit(SU, SGID, AddedEdges, SyncPipeline.begin(),
507 SyncPipeline.end());
508 }
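
// Worked example (illustrative): with a bottom-up pipeline {SG0, SG1, SG2,
// SG3} and a candidate SGID of 2, linkSUnit visits SG3, SG2, SG1, SG0. SG3 is
// seen before the candidate with MakePred == false, so SU is linked as a
// successor of SG3's SUnits; reaching SG2 flips MakePred, so SU becomes a
// predecessor of the SUnits in SG1 and SG0.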
509
510 void PipelineSolver::removeEdges(
511 const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
512 // Only remove the edges that we have added when testing
513 // the fit.
514 for (auto &PredSuccPair : EdgesToRemove) {
515 SUnit *Pred = PredSuccPair.first;
516 SUnit *Succ = PredSuccPair.second;
517
518 auto Match = llvm::find_if(
519 Succ->Preds, [&Pred](SDep &P) { return P.getSUnit() == Pred; });
520 if (Match != Succ->Preds.end()) {
521 assert(Match->isArtificial());
522 Succ->removePred(*Match);
523 }
524 }
525 }
526
527 void PipelineSolver::advancePosition() {
528 ++CurrConflInstNo;
529
530 if (static_cast<size_t>(CurrConflInstNo) >=
531 PipelineInstrs[CurrSyncGroupIdx].size()) {
532 CurrConflInstNo = 0;
533 ++CurrSyncGroupIdx;
534 // Advance to next non-trivial pipeline
535 while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
536 PipelineInstrs[CurrSyncGroupIdx].size() == 0)
537 ++CurrSyncGroupIdx;
538 }
539 }
540
541 void PipelineSolver::retreatPosition() {
542 assert(CurrConflInstNo >= 0);
543 assert(CurrSyncGroupIdx >= 0);
544
545 if (CurrConflInstNo > 0) {
546 --CurrConflInstNo;
547 return;
548 }
549
550 if (CurrConflInstNo == 0) {
551 // If we return to the starting position, we have explored
552 // the entire tree
553 if (CurrSyncGroupIdx == BeginSyncGroupIdx)
554 return;
555
556 --CurrSyncGroupIdx;
557 // Go to previous non-trivial pipeline
558 while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)
559 --CurrSyncGroupIdx;
560
561 CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
562 }
563 }
564
565 bool PipelineSolver::checkOptimal() {
566 if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
567 if (BestCost == -1 || CurrCost < BestCost) {
568 BestPipeline = CurrPipeline;
569 BestCost = CurrCost;
570 LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
571 }
572 assert(BestCost >= 0);
573 }
574
575 bool DoneExploring = false;
576 if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
577 DoneExploring = true;
578
579 return (DoneExploring || BestCost == 0);
580 }
581
582 template <typename T>
583 void PipelineSolver::populateReadyList(
584 SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {
585 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
586 auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
587 assert(CurrSU.second.size() >= 1);
588
589 for (; I != E; ++I) {
590 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
591 int CandSGID = *I;
592 SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
593 return SG.getSGID() == CandSGID;
594 });
595 assert(Match);
596
597 if (UseCostHeur) {
598 if (Match->isFull()) {
599 ReadyList.push_back(std::pair(*I, MissPenalty));
600 continue;
601 }
602
603 int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
604 ReadyList.push_back(std::pair(*I, TempCost));
605 removeEdges(AddedEdges);
606 } else
607 ReadyList.push_back(std::pair(*I, -1));
608 }
609
610 if (UseCostHeur) {
611 std::sort(ReadyList.begin(), ReadyList.end(),
612 [](std::pair<int, int> A, std::pair<int, int> B) {
613 return A.second < B.second;
614 });
615 }
616
617 assert(ReadyList.size() == CurrSU.second.size());
618 }
619
620 bool PipelineSolver::solveExact() {
621 if (checkOptimal())
622 return true;
623
624 if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
625 return false;
626
627 assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
628 assert(static_cast<size_t>(CurrConflInstNo) <
629 PipelineInstrs[CurrSyncGroupIdx].size());
630 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
631 LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
632 << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
633
634 // SchedGroup -> Cost pairs
635 SmallVector<std::pair<int, int>, 4> ReadyList;
636 // Prioritize the candidate sched groups in terms of lowest cost first
637 IsBottomUp ? populateReadyList(ReadyList, CurrSU.second.rbegin(),
638 CurrSU.second.rend())
639 : populateReadyList(ReadyList, CurrSU.second.begin(),
640 CurrSU.second.end());
641
642 auto I = ReadyList.begin();
643 auto E = ReadyList.end();
644 for (; I != E; ++I) {
645 // If we are trying SGs in least cost order, and the current SG is cost
646 // infeasible, then all subsequent SGs will also be cost infeasible, so we
647 // can prune.
648 if (BestCost != -1 && (CurrCost + I->second > BestCost))
649 return false;
650
651 int CandSGID = I->first;
652 int AddedCost = 0;
653 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
654 auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
655 SchedGroup *Match;
656 for (auto &SG : SyncPipeline) {
657 if (SG.getSGID() == CandSGID)
658 Match = &SG;
659 }
660
661 if (Match->isFull())
662 continue;
663
664 if (!Match->allowedByRules(CurrSU.first, SyncPipeline))
665 continue;
666
667 LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
668 << (int)Match->getMask() << " and ID " << CandSGID
669 << "\n");
670 Match->add(*CurrSU.first);
671 AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
672 LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
673 CurrCost += AddedCost;
674 advancePosition();
675 ++BranchesExplored;
676 bool FinishedExploring = false;
677 // Only recurse if the cost after adding edges can still beat the best known
678 // solution; otherwise prune this branch and backtrack.
679 if (CurrCost < BestCost || BestCost == -1) {
680 if (solveExact()) {
681 FinishedExploring = BestCost != 0;
682 if (!FinishedExploring)
683 return true;
684 }
685 }
686
687 retreatPosition();
688 CurrCost -= AddedCost;
689 removeEdges(AddedEdges);
690 Match->pop();
691 CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
692 if (FinishedExploring)
693 return true;
694 }
695
696 // Try the pipeline where the current instruction is omitted
697 // Potentially if we omit a problematic instruction from the pipeline,
698 // all the other instructions can nicely fit.
699 CurrCost += MissPenalty;
700 advancePosition();
701
702 LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");
703
704 bool FinishedExploring = false;
705 if (CurrCost < BestCost || BestCost == -1) {
706 if (solveExact()) {
707 FinishedExploring = BestCost != 0;
708 if (!FinishedExploring)
709 return true;
710 }
711 }
712
713 retreatPosition();
714 CurrCost -= MissPenalty;
715 return FinishedExploring;
716 }
717
718 template <typename T>
719 void PipelineSolver::greedyFind(
720 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
721 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
722 int BestNodeCost = -1;
723 int TempCost;
724 SchedGroup *BestGroup = nullptr;
725 int BestGroupID = -1;
726 auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
727 LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
728 << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
729
730 // Since we have added the potential SchedGroups from bottom up, but
731 // traversed the DAG from top down, parse over the groups from last to
732 // first. If we fail to do this for the greedy algorithm, the solution will
733 // likely not be good in more complex cases.
734 for (; I != E; ++I) {
735 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
736 int CandSGID = *I;
737 SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
738 return SG.getSGID() == CandSGID;
739 });
740 assert(Match);
741
742 LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
743 << (int)Match->getMask() << "\n");
744
745 if (Match->isFull()) {
746 LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
747 continue;
748 }
749 if (!Match->allowedByRules(CurrSU.first, SyncPipeline)) {
750 LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
751 continue;
752 }
753 TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
754 LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
755 if (TempCost < BestNodeCost || BestNodeCost == -1) {
756 BestGroup = Match;
757 BestNodeCost = TempCost;
758 BestGroupID = CandSGID;
759 }
760 removeEdges(AddedEdges);
761 if (BestNodeCost == 0)
762 break;
763 }
764
765 if (BestGroupID != -1) {
766 BestGroup->add(*CurrSU.first);
767 addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);
768 LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask"
769 << (int)BestGroup->getMask() << "\n");
770 BestCost += TempCost;
771 } else
772 BestCost += MissPenalty;
773
774 CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
775 }
776
777 bool PipelineSolver::solveGreedy() {
778 BestCost = 0;
779 std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
780
781 while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
782 SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
783 IsBottomUp
784 ? greedyFind(AddedEdges, CurrSU.second.rbegin(), CurrSU.second.rend())
785 : greedyFind(AddedEdges, CurrSU.second.begin(), CurrSU.second.end());
786 advancePosition();
787 }
788 BestPipeline = CurrPipeline;
789 removeEdges(AddedEdges);
790 return false;
791 }
792
793 unsigned PipelineSolver::computeProblemSize() {
794 unsigned ProblemSize = 0;
795 for (auto &PipeConflicts : PipelineInstrs) {
796 ProblemSize += PipeConflicts.size();
797 }
798
799 return ProblemSize;
800 }
801
802 void PipelineSolver::solve() {
803 if (!NeedsSolver)
804 return;
805
806 unsigned ProblemSize = computeProblemSize();
807 assert(ProblemSize > 0);
808
809 bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
810 MissPenalty = (ProblemSize / 2) + 1;
811
812 LLVM_DEBUG(DAG->dump());
813 if (EnableExactSolver || BelowCutoff) {
814 LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
815 solveGreedy();
816 reset();
817 LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
818 if (BestCost > 0) {
819 LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
820 solveExact();
821 LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
822 }
823 } else { // Use the Greedy Algorithm by default
824 LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
825 solveGreedy();
826 }
827
828 makePipeline();
829 LLVM_DEBUG(dbgs() << "After applying mutation\n");
830 LLVM_DEBUG(DAG->dump());
831 }
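
// The solver selection can also be steered from the command line using the
// options defined above, e.g. (illustrative llc invocations):
//   llc -amdgpu-igrouplp-exact-solver ...                   always use the exact solver
//   llc -amdgpu-igrouplp-exact-solver-cutoff=10 ...         exact solver for small problems only
//   llc -amdgpu-igrouplp-exact-solver-max-branches=100000   bound the exact search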
832
833 enum IGLPStrategyID : int {
834 MFMASmallGemmOptID = 0,
835 MFMASmallGemmSingleWaveOptID = 1,
836 MFMAExpInterleave = 2
837 };
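
// These IDs select a strategy below and correspond to the immediate operand of
// the IGLP_OPT pseudo instruction (illustratively, the value a kernel passes to
// the amdgcn iglp_opt builtin: 0, 1, or 2).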
838
839 // Implement an IGLP scheduling strategy.
840 class IGLPStrategy {
841 protected:
842 ScheduleDAGInstrs *DAG;
843
844 const SIInstrInfo *TII;
845
846 public:
847 /// Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
848 virtual bool applyIGLPStrategy(
849 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
850 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
851 AMDGPU::SchedulingPhase Phase) = 0;
852
853 // Returns true if this strategy should be applied to a ScheduleDAG.
854 virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
855 AMDGPU::SchedulingPhase Phase) = 0;
856
857 bool IsBottomUp = true;
858
859 IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
860 : DAG(DAG), TII(TII) {}
861
862 virtual ~IGLPStrategy() = default;
863 };
864
865 class MFMASmallGemmOpt final : public IGLPStrategy {
866 private:
867 public:
868 bool applyIGLPStrategy(
869 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
870 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
871 AMDGPU::SchedulingPhase Phase) override;
872
873 bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
874 AMDGPU::SchedulingPhase Phase) override {
875 return true;
876 }
877
878 MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
879 : IGLPStrategy(DAG, TII) {
880 IsBottomUp = true;
881 }
882 };
883
884 bool MFMASmallGemmOpt::applyIGLPStrategy(
885 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
886 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
887 AMDGPU::SchedulingPhase Phase) {
888 // Count the number of MFMA instructions.
889 unsigned MFMACount = 0;
890 for (const MachineInstr &I : *DAG)
891 if (TII->isMFMAorWMMA(I))
892 ++MFMACount;
893
894 const unsigned PipelineSyncID = 0;
895 SchedGroup *SG = nullptr;
896 for (unsigned I = 0; I < MFMACount * 3; ++I) {
897 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
898 SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
899 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
900
901 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
902 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
903 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
904 }
905
906 return true;
907 }
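
// The pipeline built above is, in effect, the following pair of scheduling
// requests repeated MFMACount * 3 times (mask values taken from SchedGroupMask;
// the builtin spelling is an illustrative equivalent):
//   __builtin_amdgcn_sched_group_barrier(0x0080 /*DS*/,   2, 0);
//   __builtin_amdgcn_sched_group_barrier(0x0008 /*MFMA*/, 1, 0);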
908
909 class MFMAExpInterleaveOpt final : public IGLPStrategy {
910 private:
911 // The count of TRANS SUs involved in the interleaved pipeline
912 static unsigned TransPipeCount;
913 // The count of MFMA SUs involved in the interleaved pipeline
914 static unsigned MFMAPipeCount;
915 // The count of Add SUs involved in the interleaved pipeline
916 static unsigned AddPipeCount;
917 // The number of transitive MFMA successors for each TRANS SU
918 static unsigned MFMAEnablement;
919 // The number of transitive TRANS predecessors for each MFMA SU
920 static unsigned ExpRequirement;
921 // The count of independent "chains" of MFMA instructions in the pipeline
922 static unsigned MFMAChains;
923 // The length of each independent "chain" of MFMA instructions
924 static unsigned MFMAChainLength;
925 // Whether or not the pipeline has V_CVT instructions
926 static bool HasCvt;
927 // Whether or not there are instructions between the TRANS instruction and
928 // V_CVT
929 static bool HasChainBetweenCvt;
930 // The first occurring DS_READ which feeds an MFMA chain
931 static std::optional<unsigned> FirstPipeDSR;
932 // The MFMAPipe SUs with no MFMA predecessors
933 SmallVector<SUnit *, 4> MFMAChainSeeds;
934 // Compute the heuristics for the pipeline, returning whether or not the DAG
935 // is well-formed for the mutation
936 bool analyzeDAG(const SIInstrInfo *TII);
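// For example, in the "small" kernel shape recognized by applyIGLPStrategy
// below, each TRANS op transitively enables 2 MFMAs (MFMAEnablement == 2),
// each MFMA transitively requires 4 TRANS results (ExpRequirement == 4), and
// TransPipeCount == 32.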
937
938 /// Whether or not the instruction is a transitive predecessor of an MFMA
939 /// instruction
940 class IsPipeExp final : public InstructionRule {
941 public:
942 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
943 SmallVectorImpl<SchedGroup> &SyncPipe) override {
944
945 auto DAG = SyncPipe[0].DAG;
946
947 if (Cache->empty()) {
948 auto I = DAG->SUnits.rbegin();
949 auto E = DAG->SUnits.rend();
950 for (; I != E; I++) {
951 if (TII->isMFMAorWMMA(*I->getInstr()))
952 Cache->push_back(&*I);
953 }
954 if (Cache->empty())
955 return false;
956 }
957
958 auto Reaches = (std::any_of(
959 Cache->begin(), Cache->end(), [&SU, &DAG](SUnit *TargetSU) {
960 return DAG->IsReachable(TargetSU, const_cast<SUnit *>(SU));
961 }));
962
963 return Reaches;
964 }
965 IsPipeExp(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
966 : InstructionRule(TII, SGID, NeedsCache) {}
967 };
968
969 /// Whether or not the instruction is a transitive predecessor of the
970 /// \p Number th MFMA of the MFMAs occurring after a TRANS instruction
971 class EnablesNthMFMA final : public InstructionRule {
972 private:
973 unsigned Number = 1;
974
975 public:
976 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
977 SmallVectorImpl<SchedGroup> &SyncPipe) override {
978 bool FoundTrans = false;
979 unsigned Counter = 1;
980 auto DAG = SyncPipe[0].DAG;
981
982 if (Cache->empty()) {
983 SmallVector<SUnit *, 8> Worklist;
984
985 auto I = DAG->SUnits.begin();
986 auto E = DAG->SUnits.end();
987 for (; I != E; I++) {
988 if (FoundTrans && TII->isMFMAorWMMA(*I->getInstr())) {
989 if (Counter == Number) {
990 Cache->push_back(&*I);
991 break;
992 }
993 ++Counter;
994 }
995 if (!FoundTrans && TII->isTRANS(I->getInstr()->getOpcode()))
996 FoundTrans = true;
997 }
998 if (Cache->empty())
999 return false;
1000 }
1001
1002 return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
1003 }
1004
1005 EnablesNthMFMA(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
1006 bool NeedsCache = false)
1007 : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
1008 };
1009
1010 /// Whether or not the instruction enables the exact MFMA that is the \p
1011 /// Number th MFMA in the chain starting with \p ChainSeed
1012 class EnablesNthMFMAInChain final : public InstructionRule {
1013 private:
1014 unsigned Number = 1;
1015 SUnit *ChainSeed;
1016
1017 public:
1018 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1019 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1020 auto DAG = SyncPipe[0].DAG;
1021
1022 if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
1023 return false;
1024
1025 if (Cache->empty()) {
1026 auto TempSU = ChainSeed;
1027 auto Depth = Number;
1028 while (Depth > 0) {
1029 --Depth;
1030 bool Found = false;
1031 for (auto &Succ : TempSU->Succs) {
1032 if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
1033 TempSU = Succ.getSUnit();
1034 Found = true;
1035 break;
1036 }
1037 }
1038 if (!Found)
1039 return false;
1040 }
1041
1042 Cache->push_back(TempSU);
1043 }
1044 // If we failed to find the instruction to be placed into the cache, we
1045 // would have already exited.
1046 assert(!Cache->empty());
1047
1048 return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
1049 }
1050
1051 EnablesNthMFMAInChain(unsigned Number, SUnit *ChainSeed,
1052 const SIInstrInfo *TII, unsigned SGID,
1053 bool NeedsCache = false)
1054 : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1055 ChainSeed(ChainSeed) {}
1056 };
1057
1058 /// Whether or not the instruction has fewer than \p Size immediate data
1059 /// successors. If \p HasIntermediary is true, this also tests whether all
1060 /// successors of the SUnit have fewer than \p Size data successors.
1061 class LessThanNSuccs final : public InstructionRule {
1062 private:
1063 unsigned Size = 1;
1064 bool HasIntermediary = false;
1065
1066 public:
1067 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1068 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1069 if (!SyncPipe.size())
1070 return false;
1071
1072 auto SuccSize = std::count_if(
1073 SU->Succs.begin(), SU->Succs.end(),
1074 [](const SDep &Succ) { return Succ.getKind() == SDep::Data; });
1075 if (SuccSize >= Size)
1076 return false;
1077
1078 if (HasIntermediary) {
1079 for (auto Succ : SU->Succs) {
1080 auto SuccSize = std::count_if(
1081 Succ.getSUnit()->Succs.begin(), Succ.getSUnit()->Succs.end(),
1082 [](const SDep &SuccSucc) {
1083 return SuccSucc.getKind() == SDep::Data;
1084 });
1085 if (SuccSize >= Size)
1086 return false;
1087 }
1088 }
1089
1090 return true;
1091 }
1092 LessThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
1093 bool HasIntermediary = false, bool NeedsCache = false)
1094 : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1095 HasIntermediary(HasIntermediary) {}
1096 };
1097
1098 /// Whether or not the instruction has greater than or equal to \p Size
1099 /// immediate data successors. If \p HasIntermediary is true, this also tests
1100 /// whether any successor of the SUnit has greater than or equal to \p Size
1101 /// data successors.
1102 class GreaterThanOrEqualToNSuccs final : public InstructionRule {
1103 private:
1104 unsigned Size = 1;
1105 bool HasIntermediary = false;
1106
1107 public:
1108 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1109 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1110 if (!SyncPipe.size())
1111 return false;
1112
1113 auto SuccSize = std::count_if(
1114 SU->Succs.begin(), SU->Succs.end(),
1115 [](const SDep &Succ) { return Succ.getKind() == SDep::Data; });
1116 if (SuccSize >= Size)
1117 return true;
1118
1119 if (HasIntermediary) {
1120 for (auto Succ : SU->Succs) {
1121 auto SuccSize = std::count_if(
1122 Succ.getSUnit()->Succs.begin(), Succ.getSUnit()->Succs.end(),
1123 [](const SDep &SuccSucc) {
1124 return SuccSucc.getKind() == SDep::Data;
1125 });
1126 if (SuccSize >= Size)
1127 return true;
1128 }
1129 }
1130
1131 return false;
1132 }
1133 GreaterThanOrEqualToNSuccs(unsigned Size, const SIInstrInfo *TII,
1134 unsigned SGID, bool HasIntermediary = false,
1135 bool NeedsCache = false)
1136 : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1137 HasIntermediary(HasIntermediary) {}
1138 };
1139
1140 // Whether or not the instruction is a relevant V_CVT instruction.
1141 class IsCvt final : public InstructionRule {
1142 public:
1143 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1144 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1145 auto Opc = SU->getInstr()->getOpcode();
1146 return Opc == AMDGPU::V_CVT_F16_F32_e32 ||
1147 Opc == AMDGPU::V_CVT_I32_F32_e32;
1148 }
1149 IsCvt(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1150 : InstructionRule(TII, SGID, NeedsCache) {}
1151 };
1152
1153 // Whether or not the instruction is FMA_F32.
1154 class IsFMA final : public InstructionRule {
1155 public:
1156 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1157 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1158 return SU->getInstr()->getOpcode() == AMDGPU::V_FMA_F32_e64 ||
1159 SU->getInstr()->getOpcode() == AMDGPU::V_PK_FMA_F32;
1160 }
1161 IsFMA(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1162 : InstructionRule(TII, SGID, NeedsCache) {}
1163 };
1164
1165 // Whether or not the instruction is a V_ADD_F32 instruction.
1166 class IsPipeAdd final : public InstructionRule {
1167 public:
1168 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1169 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1170 return SU->getInstr()->getOpcode() == AMDGPU::V_ADD_F32_e32;
1171 }
1172 IsPipeAdd(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1173 : InstructionRule(TII, SGID, NeedsCache) {}
1174 };
1175
1176 /// Whether or not the instruction is an immediate RAW successor
1177 /// of the SchedGroup \p Distance steps before.
1178 class IsSuccOfPrevNthGroup final : public InstructionRule {
1179 private:
1180 unsigned Distance = 1;
1181
1182 public:
1183 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1184 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1185 SchedGroup *OtherGroup = nullptr;
1186 if (!SyncPipe.size())
1187 return false;
1188
1189 for (auto &PipeSG : SyncPipe) {
1190 if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1191 OtherGroup = &PipeSG;
1192 }
1193
1194 if (!OtherGroup)
1195 return false;
1196 if (!OtherGroup->Collection.size())
1197 return true;
1198
1199 for (auto &OtherEle : OtherGroup->Collection) {
1200 for (auto &Succ : OtherEle->Succs) {
1201 if (Succ.getSUnit() == SU && Succ.getKind() == SDep::Data)
1202 return true;
1203 }
1204 }
1205
1206 return false;
1207 }
1208 IsSuccOfPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1209 unsigned SGID, bool NeedsCache = false)
1210 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1211 };
1212
1213 /// Whether or not the instruction is a transitive successor of any
1214 /// instruction in the SchedGroup \p Distance steps before.
1215 class IsReachableFromPrevNthGroup final : public InstructionRule {
1216 private:
1217 unsigned Distance = 1;
1218
1219 public:
1220 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1221 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1222 SchedGroup *OtherGroup = nullptr;
1223 if (!SyncPipe.size())
1224 return false;
1225
1226 for (auto &PipeSG : SyncPipe) {
1227 if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1228 OtherGroup = &PipeSG;
1229 }
1230
1231 if (!OtherGroup)
1232 return false;
1233 if (!OtherGroup->Collection.size())
1234 return true;
1235
1236 auto DAG = SyncPipe[0].DAG;
1237
1238 for (auto &OtherEle : OtherGroup->Collection)
1239 if (DAG->IsReachable(const_cast<SUnit *>(SU), OtherEle))
1240 return true;
1241
1242 return false;
1243 }
1244 IsReachableFromPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1245 unsigned SGID, bool NeedsCache = false)
1246 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1247 };
1248
1249 /// Whether or not the instruction occurs at or after the SU with NodeNum \p Number
1250 class OccursAtOrAfterNode final : public InstructionRule {
1251 private:
1252 unsigned Number = 1;
1253
1254 public:
1255 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1256 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1257
1258 return SU->NodeNum >= Number;
1259 }
1260 OccursAtOrAfterNode(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
1261 bool NeedsCache = false)
1262 : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
1263 };
1264
1265 /// Whether or not the SU is exactly the \p Number th MFMA in the chain
1266 /// starting with \p ChainSeed
1267 class IsExactMFMA final : public InstructionRule {
1268 private:
1269 unsigned Number = 1;
1270 SUnit *ChainSeed;
1271
1272 public:
1273 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1274 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1275 if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
1276 return false;
1277
1278 if (Cache->empty()) {
1279 auto TempSU = ChainSeed;
1280 auto Depth = Number;
1281 while (Depth > 0) {
1282 --Depth;
1283 bool Found = false;
1284 for (auto &Succ : TempSU->Succs) {
1285 if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
1286 TempSU = Succ.getSUnit();
1287 Found = true;
1288 break;
1289 }
1290 }
1291 if (!Found) {
1292 return false;
1293 }
1294 }
1295 Cache->push_back(TempSU);
1296 }
1297 // If we failed to find the instruction to be placed into the cache, we
1298 // would have already exited.
1299 assert(!Cache->empty());
1300
1301 return (*Cache)[0] == SU;
1302 }
1303
1304 IsExactMFMA(unsigned Number, SUnit *ChainSeed, const SIInstrInfo *TII,
1305 unsigned SGID, bool NeedsCache = false)
1306 : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1307 ChainSeed(ChainSeed) {}
1308 };
1309
1310 // Whether the instruction occurs after the first TRANS instruction. This
1311 // implies the instruction cannot be a predecessor of the first TRANS
1312 // instruction.
1313 class OccursAfterExp final : public InstructionRule {
1314 public:
1315 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1316 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1317
1318 SmallVector<SUnit *, 12> Worklist;
1319 auto DAG = SyncPipe[0].DAG;
1320 if (Cache->empty()) {
1321 for (auto &SU : DAG->SUnits)
1322 if (TII->isTRANS(SU.getInstr()->getOpcode())) {
1323 Cache->push_back(&SU);
1324 break;
1325 }
1326 if (Cache->empty())
1327 return false;
1328 }
1329
1330 return SU->NodeNum > (*Cache)[0]->NodeNum;
1331 }
1332
1333 OccursAfterExp(const SIInstrInfo *TII, unsigned SGID,
1334 bool NeedsCache = false)
1335 : InstructionRule(TII, SGID, NeedsCache) {}
1336 };
1337
1338 public:
1339 bool applyIGLPStrategy(
1340 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1341 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1342 AMDGPU::SchedulingPhase Phase) override;
1343
1344 bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1345 AMDGPU::SchedulingPhase Phase) override;
1346
1347 MFMAExpInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1348 : IGLPStrategy(DAG, TII) {
1349 IsBottomUp = false;
1350 }
1351 };
1352
1353 unsigned MFMAExpInterleaveOpt::TransPipeCount = 0;
1354 unsigned MFMAExpInterleaveOpt::MFMAPipeCount = 0;
1355 unsigned MFMAExpInterleaveOpt::AddPipeCount = 0;
1356 unsigned MFMAExpInterleaveOpt::MFMAEnablement = 0;
1357 unsigned MFMAExpInterleaveOpt::ExpRequirement = 0;
1358 unsigned MFMAExpInterleaveOpt::MFMAChains = 0;
1359 unsigned MFMAExpInterleaveOpt::MFMAChainLength = 0;
1360 bool MFMAExpInterleaveOpt::HasCvt = false;
1361 bool MFMAExpInterleaveOpt::HasChainBetweenCvt = false;
1362 std::optional<unsigned> MFMAExpInterleaveOpt::FirstPipeDSR = std::nullopt;
1363
1364 bool MFMAExpInterleaveOpt::analyzeDAG(const SIInstrInfo *TII) {
1365 SmallVector<SUnit *, 10> ExpPipeCands;
1366 SmallVector<SUnit *, 10> MFMAPipeCands;
1367 SmallVector<SUnit *, 10> MFMAPipeSUs;
1368 SmallVector<SUnit *, 10> PackSUs;
1369 SmallVector<SUnit *, 10> CvtSUs;
1370
1371 auto isBitPack = [](unsigned Opc) {
1372 return Opc == AMDGPU::V_PACK_B32_F16_e64 || Opc == AMDGPU::V_PERM_B32_e64;
1373 };
1374
1375 auto isCvt = [](unsigned Opc) {
1376 return Opc == AMDGPU::V_CVT_F16_F32_e32 || Opc == AMDGPU::V_CVT_I32_F32_e32;
1377 };
1378
1379 auto isAdd = [](unsigned Opc) { return Opc == AMDGPU::V_ADD_F32_e32; };
1380
1381 AddPipeCount = 0;
1382 for (SUnit &SU : DAG->SUnits) {
1383 auto Opc = SU.getInstr()->getOpcode();
1384 if (TII->isTRANS(Opc)) {
1385 // Avoid counting a potential bonus V_EXP which all the MFMAs depend on
1386 if (SU.Succs.size() >= 7)
1387 continue;
1388 for (auto &Succ : SU.Succs) {
1389 if (Succ.getSUnit()->Succs.size() >= 7)
1390 continue;
1391 }
1392 ExpPipeCands.push_back(&SU);
1393 }
1394
1395 if (TII->isMFMAorWMMA(*SU.getInstr()))
1396 MFMAPipeCands.push_back(&SU);
1397
1398 if (isBitPack(Opc))
1399 PackSUs.push_back(&SU);
1400
1401 if (isCvt(Opc))
1402 CvtSUs.push_back(&SU);
1403
1404 if (isAdd(Opc))
1405 ++AddPipeCount;
1406 }
1407
1408 if (!(PackSUs.size() && MFMAPipeCands.size() && ExpPipeCands.size()))
1409 return false;
1410
1411 TransPipeCount = 0;
1412
1413 std::optional<SUnit *> TempMFMA;
1414 std::optional<SUnit *> TempExp;
1415 // Count the number of EXPs that reach an MFMA
1416 for (auto &PredSU : ExpPipeCands) {
1417 for (auto &SuccSU : MFMAPipeCands) {
1418 if (DAG->IsReachable(SuccSU, PredSU)) {
1419 if (!TempExp) {
1420 TempExp = PredSU;
1421 TempMFMA = SuccSU;
1422 }
1423 MFMAPipeSUs.push_back(SuccSU);
1424 ++TransPipeCount;
1425 break;
1426 }
1427 }
1428 }
1429
1430 if (!(TempExp && TempMFMA))
1431 return false;
1432
1433 HasChainBetweenCvt =
1434 std::find_if((*TempExp)->Succs.begin(), (*TempExp)->Succs.end(),
1435 [&isCvt](SDep &Succ) {
1436 return isCvt(Succ.getSUnit()->getInstr()->getOpcode());
1437 }) == (*TempExp)->Succs.end();
1438
1439 // Count the number of MFMAs that are reached by an EXP
1440 for (auto &SuccSU : MFMAPipeCands) {
1441 if (MFMAPipeSUs.size() &&
1442 std::find_if(MFMAPipeSUs.begin(), MFMAPipeSUs.end(),
1443 [&SuccSU](SUnit *PotentialMatch) {
1444 return PotentialMatch->NodeNum == SuccSU->NodeNum;
1445 }) != MFMAPipeSUs.end())
1446 continue;
1447
1448 for (auto &PredSU : ExpPipeCands) {
1449 if (DAG->IsReachable(SuccSU, PredSU)) {
1450 MFMAPipeSUs.push_back(SuccSU);
1451 break;
1452 }
1453 }
1454 }
1455
1456 MFMAPipeCount = MFMAPipeSUs.size();
1457
1458 assert(TempExp && TempMFMA);
1459 assert(MFMAPipeCount > 0);
1460
1461 std::optional<SUnit *> TempCvt;
1462 for (auto &SuccSU : CvtSUs) {
1463 if (DAG->IsReachable(SuccSU, *TempExp)) {
1464 TempCvt = SuccSU;
1465 break;
1466 }
1467 }
1468
1469 HasCvt = false;
1470 if (TempCvt.has_value()) {
1471 for (auto &SuccSU : MFMAPipeSUs) {
1472 if (DAG->IsReachable(SuccSU, *TempCvt)) {
1473 HasCvt = true;
1474 break;
1475 }
1476 }
1477 }
1478
1479 MFMAChains = 0;
1480 for (auto &MFMAPipeSU : MFMAPipeSUs) {
1481 if (is_contained(MFMAChainSeeds, MFMAPipeSU))
1482 continue;
1483 if (!std::any_of(MFMAPipeSU->Preds.begin(), MFMAPipeSU->Preds.end(),
1484 [&TII](SDep &Succ) {
1485 return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());
1486 })) {
1487 MFMAChainSeeds.push_back(MFMAPipeSU);
1488 ++MFMAChains;
1489 }
1490 }
1491
1492 if (!MFMAChains)
1493 return false;
1494
1495 for (auto Pred : MFMAChainSeeds[0]->Preds) {
1496 if (TII->isDS(Pred.getSUnit()->getInstr()->getOpcode()) &&
1497 Pred.getSUnit()->getInstr()->mayLoad())
1498 FirstPipeDSR = Pred.getSUnit()->NodeNum;
1499 }
1500
1501 MFMAChainLength = MFMAPipeCount / MFMAChains;
1502
1503 // The number of bit pack operations that depend on a single V_EXP
1504 unsigned PackSuccCount = std::count_if(
1505 PackSUs.begin(), PackSUs.end(), [this, &TempExp](SUnit *VPack) {
1506 return DAG->IsReachable(VPack, *TempExp);
1507 });
1508
1509 // The number of bit pack operations an MFMA depends on
1510 unsigned PackPredCount =
1511 std::count_if((*TempMFMA)->Preds.begin(), (*TempMFMA)->Preds.end(),
1512 [&isBitPack](SDep &Pred) {
1513 auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1514 return isBitPack(Opc);
1515 });
1516
1517 auto PackPred =
1518 std::find_if((*TempMFMA)->Preds.begin(), (*TempMFMA)->Preds.end(),
1519 [&isBitPack](SDep &Pred) {
1520 auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1521 return isBitPack(Opc);
1522 });
1523
1524 if (PackPred == (*TempMFMA)->Preds.end())
1525 return false;
1526
1527 MFMAEnablement = 0;
1528 ExpRequirement = 0;
1529 // How many MFMAs depend on a single bit pack operation
1530 MFMAEnablement =
1531 std::count_if(PackPred->getSUnit()->Succs.begin(),
1532 PackPred->getSUnit()->Succs.end(), [&TII](SDep &Succ) {
1533 return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());
1534 });
1535
1536 // The number of MFMAs that depend on a single V_EXP
1537 MFMAEnablement *= PackSuccCount;
1538
1539 // The number of V_EXPs required to resolve all dependencies for an MFMA
1540 ExpRequirement =
1541 std::count_if(ExpPipeCands.begin(), ExpPipeCands.end(),
1542 [this, &PackPred](SUnit *ExpBase) {
1543 return DAG->IsReachable(PackPred->getSUnit(), ExpBase);
1544 });
1545
1546 ExpRequirement *= PackPredCount;
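// Illustrative example (assumed values, not taken from a real kernel): if each
// V_EXP feeds one bit-pack op (PackSuccCount == 1) and that pack feeds two
// MFMAs, then MFMAEnablement == 2; if an MFMA consumes two packs
// (PackPredCount == 2), each reachable from two EXPs, then ExpRequirement == 4,
// matching the "small kernel type" checked in applyIGLPStrategy.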
1547 return true;
1548 }
1549
1550 bool MFMAExpInterleaveOpt::shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1551 AMDGPU::SchedulingPhase Phase) {
1552 const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1553 const SIInstrInfo *TII = ST.getInstrInfo();
1554
1555 if (Phase != AMDGPU::SchedulingPhase::PostRA)
1556 MFMAChainSeeds.clear();
1557 if (Phase != AMDGPU::SchedulingPhase::PostRA && !analyzeDAG(TII))
1558 return false;
1559
1560 return true;
1561 }
1562
1563 bool MFMAExpInterleaveOpt::applyIGLPStrategy(
1564 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1565 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1566 AMDGPU::SchedulingPhase Phase) {
1567
1568 bool IsSmallKernelType =
1569 MFMAEnablement == 2 && ExpRequirement == 4 && TransPipeCount == 32;
1570 bool IsLargeKernelType =
1571 MFMAEnablement == 4 && ExpRequirement == 4 && TransPipeCount == 64;
1572
1573 if (!(IsSmallKernelType || IsLargeKernelType))
1574 return false;
1575
1576 const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1577 const SIInstrInfo *TII = ST.getInstrInfo();
1578
1579 unsigned PipelineSyncID = 0;
1580 SchedGroup *SG = nullptr;
1581
1582 unsigned MFMAChain = 0;
1583 unsigned PositionInChain = 0;
1584 unsigned CurrMFMAForTransPosition = 0;
1585
1586 auto incrementTransPosition = [&MFMAChain, &PositionInChain,
1587 &CurrMFMAForTransPosition]() {
1588 CurrMFMAForTransPosition += MFMAEnablement;
1589 PositionInChain = (CurrMFMAForTransPosition / MFMAChains);
1590 MFMAChain = CurrMFMAForTransPosition % MFMAChains;
1591 };
1592
1593 auto getNextTransPositionInChain = [&CurrMFMAForTransPosition]() {
1594 auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1595 return (TempMFMAForTrans / MFMAChains);
1596 };
1597
1598 auto getNextTransMFMAChain = [&CurrMFMAForTransPosition]() {
1599 auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1600 return TempMFMAForTrans % MFMAChains;
1601 };
1602
1603 unsigned CurrMFMAPosition = 0;
1604 unsigned MFMAChainForMFMA = 0;
1605 unsigned PositionInChainForMFMA = 0;
1606
1607 auto incrementMFMAPosition = [&CurrMFMAPosition, &MFMAChainForMFMA,
1608 &PositionInChainForMFMA]() {
1609 ++CurrMFMAPosition;
1610 MFMAChainForMFMA = CurrMFMAPosition % MFMAChains;
1611 PositionInChainForMFMA = CurrMFMAPosition / MFMAChains;
1612 };
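// E.g. (illustrative): with MFMAChains == 2, successive MFMA positions
// 0, 1, 2, 3 map to (chain 0, pos 0), (chain 1, pos 0), (chain 0, pos 1),
// (chain 1, pos 1). The trans-position counters above advance by
// MFMAEnablement MFMAs at a time instead of one.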
1613
1614 bool IsPostRA = Phase == AMDGPU::SchedulingPhase::PostRA;
1615 assert(IsPostRA || MFMAChainSeeds.size() == MFMAChains);
1616
1617 bool UsesFMA = IsSmallKernelType || !IsPostRA;
1618 bool UsesDSRead = IsLargeKernelType && !IsPostRA && FirstPipeDSR;
1619 bool UsesCvt = HasCvt && (IsSmallKernelType || !IsPostRA);
1620 bool UsesVALU = IsSmallKernelType;
1621
1622 // PHASE 1: "Prefetch"
1623 if (UsesFMA) {
1624 // First Round FMA
1625 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1626 SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1627 if (!IsPostRA && MFMAChains) {
1628 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1629 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1630 true));
1631 } else
1632 SG->addRule(
1633 std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1634 SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1635 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1636
1637 // Second Round FMA
1638 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1639 SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1640 if (!IsPostRA && MFMAChains) {
1641 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1642 getNextTransPositionInChain(),
1643 MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1644 } else
1645 SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1646 SG->getSGID(), true));
1647 SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1648 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1649 }
1650
1651 if (UsesDSRead) {
1652 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1653 SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1654 SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1655 SG->getSGID()));
1656 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1657 }
1658
1659 // First Round EXP
1660 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1661 SchedGroupMask::TRANS, ExpRequirement, PipelineSyncID, DAG, TII);
1662 if (!IsPostRA && MFMAChains)
1663 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1664 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(), true));
1665 else
1666 SG->addRule(std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1667 SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1668 SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1669 HasChainBetweenCvt));
1670 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1671
1672 incrementTransPosition();
1673
1674 // First Round CVT, Third Round FMA, Second Round EXP; interleaved
1675 for (unsigned I = 0; I < ExpRequirement; I++) {
1676 // First Round CVT
1677 if (UsesCvt) {
1678 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1679 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1680 SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1681 if (HasChainBetweenCvt)
1682 SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1683 1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1684 else
1685 SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(
1686 1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1687 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1688 }
1689
1690 // Third Round FMA
1691 if (UsesFMA) {
1692 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1693 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1694 if (!IsPostRA && MFMAChains) {
1695 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1696 getNextTransPositionInChain(),
1697 MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1698 } else
1699 SG->addRule(std::make_shared<EnablesNthMFMA>(2 * MFMAEnablement + 1,
1700 TII, SG->getSGID(), true));
1701 SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1702 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1703 }
1704
1705 // Second Round EXP
1706 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1707 SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1708 if (!IsPostRA && MFMAChains)
1709 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1710 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1711 true));
1712 else
1713 SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1714 SG->getSGID(), true));
1715 SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1716 SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1717 HasChainBetweenCvt));
1718 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1719 }
1720
1721 // The "extra" EXP which enables all MFMA
1722 // TODO: UsesExtraExp
1723 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1724 SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1725 SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1726 SG->addRule(std::make_shared<GreaterThanOrEqualToNSuccs>(
1727 8, TII, SG->getSGID(), HasChainBetweenCvt));
1728 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1729
1730 // PHASE 2: Main Interleave Loop
1731
1732 // The number of MFMAs per iteration
1733 unsigned MFMARatio =
1734 MFMAEnablement > ExpRequirement ? MFMAEnablement / ExpRequirement : 1;
1735 // The number of Exps per iteration
1736 unsigned ExpRatio =
1737 MFMAEnablement > ExpRequirement ? 1 : ExpRequirement / MFMAEnablement;
1738 // The remaining Exps
1739 unsigned RemainingExp = TransPipeCount > (2 * ExpRequirement)
1740 ? TransPipeCount - (2 * ExpRequirement)
1741 : 0;
1742 unsigned ExpLoopCount = RemainingExp / ExpRatio;
1743 // In-loop MFMAs
1744 unsigned MFMAInLoop = MFMAPipeCount > (MFMAEnablement * 2)
1745 ? MFMAPipeCount - (MFMAEnablement * 2)
1746 : 0;
1747 unsigned MFMALoopCount = MFMAInLoop / MFMARatio;
1748 unsigned VALUOps =
1749 AddPipeCount < MFMAPipeCount ? 1 : AddPipeCount / MFMAPipeCount;
1750 unsigned LoopSize = std::min(ExpLoopCount, MFMALoopCount);
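// Worked example (assuming the large kernel type and, illustratively, an MFMA
// pipe of 32): MFMARatio == 1, ExpRatio == 1, RemainingExp == 64 - 8 == 56,
// MFMAInLoop == 32 - 8 == 24, so LoopSize == min(56, 24) == 24 iterations.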
1751
1752 for (unsigned I = 0; I < LoopSize; I++) {
1753 if (!(I * ExpRatio % ExpRequirement))
1754 incrementTransPosition();
1755
1756 // Round N MFMA
1757 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1758 SchedGroupMask::MFMA, MFMARatio, PipelineSyncID, DAG, TII);
1759 if (!IsPostRA && MFMAChains)
1760 SG->addRule(std::make_shared<IsExactMFMA>(
1761 PositionInChainForMFMA, MFMAChainSeeds[MFMAChainForMFMA], TII,
1762 SG->getSGID(), true));
1763 else
1764 SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1765 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1766 incrementMFMAPosition();
1767
1768 if (UsesVALU) {
1769 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1770 SchedGroupMask::VALU, VALUOps, PipelineSyncID, DAG, TII);
1771 SG->addRule(std::make_shared<IsPipeAdd>(TII, SG->getSGID()));
1772 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1773 }
1774
1775 if (UsesDSRead && !(I % 4)) {
1776 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1777 SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1778 SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1779 SG->getSGID()));
1780 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1781 }
1782
1783 // CVT, EXP, FMA Interleaving
1784 for (unsigned J = 0; J < ExpRatio; J++) {
1785 auto MFMAOffset = (1 + UsesVALU) * MFMARatio * (I + 1);
1786 auto MaxMFMAOffset =
1787 (1 + UsesVALU) * ExpRequirement * MFMARatio / ExpRatio;
1788
1789 // Round N + 1 CVT
1790 if (UsesCvt) {
1791 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1792 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1793 SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1794 auto BaseDiff = (2 + UsesFMA) * (ExpRequirement - 1) + 1;
1795 auto DSROffset = I / 4 + 1;
1796 auto MaxDSROffset = MaxMFMAOffset / 4;
1797 // TODO: UsesExtraExp
1798 auto ExpOffset = I * ExpRatio + J >= ExpRequirement ? 0 : 1;
1799 auto CurrentOffset = UsesDSRead * std::min(MaxDSROffset, DSROffset) +
1800 std::min(MaxMFMAOffset, MFMAOffset) + BaseDiff +
1801 ExpOffset;
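// Roughly, CurrentOffset is how many SchedGroups back (within this SyncID)
// the EXP feeding this conversion was emitted, with the optional DS_READ,
// FMA, and VALU groups issued since then accounted for by the offsets above.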
1802 if (HasChainBetweenCvt)
1803 SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1804 CurrentOffset, TII, SG->getSGID()));
1805 else
1806 SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(CurrentOffset, TII,
1807 SG->getSGID()));
1808 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1809 }
1810
1811 // Round N + 3 FMA
1812 if (UsesFMA) {
1813 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1814 SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1815 if (!IsPostRA && MFMAChains)
1816 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1817 getNextTransPositionInChain(),
1818 MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(),
1819 true));
1820 else
1821 SG->addRule(std::make_shared<EnablesNthMFMA>(
1822 (((I * ExpRatio + J) / ExpRequirement) + 3) * MFMAEnablement + 1,
1823 TII, SG->getSGID(), true));
1824 SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1825 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1826 }
1827
1828 // Round N + 2 Exp
1829 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1830 SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1831 if (!IsPostRA && MFMAChains)
1832 SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1833 PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1834 true));
1835 else
1836 SG->addRule(std::make_shared<EnablesNthMFMA>(
1837 (((I * ExpRatio + J) / ExpRequirement) + 2) * MFMAEnablement + 1,
1838 TII, SG->getSGID(), true));
1839 SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1840 SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1841 HasChainBetweenCvt));
1842 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1843 }
1844 }
1845
1846 // PHASE 3: Remaining MFMAs
1847 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1848 SchedGroupMask::MFMA, MFMAEnablement * 2, PipelineSyncID, DAG, TII);
1849 SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1850 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1851 return true;
1852 }
1853
1854 class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1855 private:
1856 // Whether the DS_READ is a predecessor of the first four MFMAs in the region
1857 class EnablesInitialMFMA final : public InstructionRule {
1858 public:
1859 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1860 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1861 if (!SyncPipe.size())
1862 return false;
1863 int MFMAsFound = 0;
1864 if (!Cache->size()) {
1865 for (auto &Elt : SyncPipe[0].DAG->SUnits) {
1866 if (TII->isMFMAorWMMA(*Elt.getInstr())) {
1867 ++MFMAsFound;
1868 if (MFMAsFound > 4)
1869 break;
1870 Cache->push_back(&Elt);
1871 }
1872 }
1873 }
1874
1875 assert(Cache->size());
1876 auto DAG = SyncPipe[0].DAG;
1877 for (auto &Elt : *Cache) {
1878 if (DAG->IsReachable(Elt, const_cast<SUnit *>(SU)))
1879 return true;
1880 }
1881 return false;
1882 }
1883
1884 EnablesInitialMFMA(const SIInstrInfo *TII, unsigned SGID,
1885 bool NeedsCache = false)
1886 : InstructionRule(TII, SGID, NeedsCache) {}
1887 };
1888
1889 // Whether the MI is a V_PERM and is a predecessor of a common DS_WRITE
1890 class IsPermForDSW final : public InstructionRule {
1891 public:
1892 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1893 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1894 auto MI = SU->getInstr();
1895 if (MI->getOpcode() != AMDGPU::V_PERM_B32_e64)
1896 return false;
1897
1898 bool FitsInGroup = false;
1899 // Does the VALU have a DS_WRITE successor
1900 if (!Collection.size()) {
1901 for (auto &Succ : SU->Succs) {
1902 SUnit *SuccUnit = Succ.getSUnit();
1903 if (TII->isDS(*SuccUnit->getInstr()) &&
1904 SuccUnit->getInstr()->mayStore()) {
1905 Cache->push_back(SuccUnit);
1906 FitsInGroup = true;
1907 }
1908 }
1909 return FitsInGroup;
1910 }
1911
1912 assert(Cache->size());
1913
1914 // Does the VALU have a DS_WRITE successor that is the same as another
1915 // VALU already in the group? The V_PERMs will all share one DS_WRITE succ
1916 return llvm::any_of(*Cache, [&SU](SUnit *Elt) {
1917 return llvm::any_of(SU->Succs, [&Elt](const SDep &ThisSucc) {
1918 return ThisSucc.getSUnit() == Elt;
1919 });
1920 });
1921 }
1922
1923 IsPermForDSW(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1924 : InstructionRule(TII, SGID, NeedsCache) {}
1925 };
1926
1927 // Whether the SU is a successor of any element in the previous SchedGroup
1928 class IsSuccOfPrevGroup final : public InstructionRule {
1929 public:
1930 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1931 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1932 SchedGroup *OtherGroup = nullptr;
1933 for (auto &PipeSG : SyncPipe) {
1934 if ((unsigned)PipeSG.getSGID() == SGID - 1) {
1935 OtherGroup = &PipeSG;
1936 }
1937 }
1938
1939 if (!OtherGroup)
1940 return false;
1941 if (!OtherGroup->Collection.size())
1942 return true;
1943
1944 // Does the previous VALU have this DS_WRITE as a successor?
1945 return (std::any_of(OtherGroup->Collection.begin(),
1946 OtherGroup->Collection.end(), [&SU](SUnit *Elt) {
1947 return std::any_of(Elt->Succs.begin(),
1948 Elt->Succs.end(),
1949 [&SU](SDep &Succ) {
1950 return Succ.getSUnit() == SU;
1951 });
1952 }));
1953 }
1954 IsSuccOfPrevGroup(const SIInstrInfo *TII, unsigned SGID,
1955 bool NeedsCache = false)
1956 : InstructionRule(TII, SGID, NeedsCache) {}
1957 };
1958
1959 // Whether the combined load width of the group does not exceed 128 bits
1960 class VMEMSize final : public InstructionRule {
1961 public:
1962 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1963 SmallVectorImpl<SchedGroup> &SyncPipe) override {
1964 auto MI = SU->getInstr();
1965 if (MI->getOpcode() == TargetOpcode::BUNDLE)
1966 return false;
1967 if (!Collection.size())
1968 return true;
1969
1970 int NumBits = 0;
1971
1972 auto TRI = TII->getRegisterInfo();
1973 auto &MRI = MI->getParent()->getParent()->getRegInfo();
1974 for (auto &Elt : Collection) {
1975 auto Op = Elt->getInstr()->getOperand(0);
1976 auto Size =
1977 TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(MRI, Op));
1978 NumBits += Size;
1979 }
1980
1981 if (NumBits < 128) {
1982 assert(TII->isVMEM(*MI) && MI->mayLoad());
1983 if (NumBits + TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(
1984 MRI, MI->getOperand(0))) <=
1985 128)
1986 return true;
1987 }
1988
1989 return false;
1990 }
1991
1992 VMEMSize(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1993 : InstructionRule(TII, SGID, NeedsCache) {}
1994 };
1995
1996 /// Whether the SU shares a V_PERM predecessor with any SU in the SchedGroup
1997 /// that is \p Distance steps away
1998 class SharesPredWithPrevNthGroup final : public InstructionRule {
1999 private:
2000 unsigned Distance = 1;
2001
2002 public:
2003 bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
2004 SmallVectorImpl<SchedGroup> &SyncPipe) override {
2005 SchedGroup *OtherGroup = nullptr;
2006 if (!SyncPipe.size())
2007 return false;
2008
2009 if (!Cache->size()) {
2010
2011 for (auto &PipeSG : SyncPipe) {
2012 if ((unsigned)PipeSG.getSGID() == SGID - Distance) {
2013 OtherGroup = &PipeSG;
2014 }
2015 }
2016
2017 if (!OtherGroup)
2018 return false;
2019 if (!OtherGroup->Collection.size())
2020 return true;
2021
2022 for (auto &OtherEle : OtherGroup->Collection) {
2023 for (auto &Pred : OtherEle->Preds) {
2024 if (Pred.getSUnit()->getInstr()->getOpcode() ==
2025 AMDGPU::V_PERM_B32_e64)
2026 Cache->push_back(Pred.getSUnit());
2027 }
2028 }
2029
2030 // If the other group has no PERM preds, then this group won't share any
2031 if (!Cache->size())
2032 return false;
2033 }
2034
2035 auto DAG = SyncPipe[0].DAG;
2036 // Does the previous DS_WRITE share a V_PERM predecessor with this
2037 // VMEM_READ
2038 return llvm::any_of(*Cache, [&SU, &DAG](SUnit *Elt) {
2039 return DAG->IsReachable(const_cast<SUnit *>(SU), Elt);
2040 });
2041 }
2042 SharesPredWithPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
2043 unsigned SGID, bool NeedsCache = false)
2044 : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
2045 };
2046
2047 public:
2048 bool applyIGLPStrategy(
2049 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2050 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2051 AMDGPU::SchedulingPhase Phase) override;
2052
2053 bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
2054 AMDGPU::SchedulingPhase Phase) override {
2055 return true;
2056 }
2057
2058 MFMASmallGemmSingleWaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
2059 : IGLPStrategy(DAG, TII) {
2060 IsBottomUp = false;
2061 }
2062 };
2063
2064 static unsigned DSWCount = 0;
2065 static unsigned DSWWithPermCount = 0;
2066 static unsigned DSWWithSharedVMEMCount = 0;
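// These counters are deliberately file-static: they are computed once during
// the initial (pre-RA) pass and reused when the strategy is re-applied to the
// same region in later scheduling phases.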
2067
2068 bool MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
2069 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2070 DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2071 AMDGPU::SchedulingPhase Phase) {
2072 unsigned MFMACount = 0;
2073 unsigned DSRCount = 0;
2074
2075 bool IsInitial = Phase == AMDGPU::SchedulingPhase::Initial;
2076
2077 assert((!IsInitial || (DSWCount == 0 && DSWWithPermCount == 0 &&
2078 DSWWithSharedVMEMCount == 0)) &&
2079 "DSWCounters should be zero in pre-RA scheduling!");
2080 SmallVector<SUnit *, 6> DSWithPerms;
2081 for (auto &SU : DAG->SUnits) {
2082 auto I = SU.getInstr();
2083 if (TII->isMFMAorWMMA(*I))
2084 ++MFMACount;
2085 else if (TII->isDS(*I)) {
2086 if (I->mayLoad())
2087 ++DSRCount;
2088 else if (I->mayStore() && IsInitial) {
2089 ++DSWCount;
2090 for (auto Pred : SU.Preds) {
2091 if (Pred.getSUnit()->getInstr()->getOpcode() ==
2092 AMDGPU::V_PERM_B32_e64) {
2093 DSWithPerms.push_back(&SU);
2094 break;
2095 }
2096 }
2097 }
2098 }
2099 }
2100
2101 if (IsInitial) {
2102 DSWWithPermCount = DSWithPerms.size();
2103 auto I = DSWithPerms.begin();
2104 auto E = DSWithPerms.end();
2105
2106 // Get the count of DS_WRITES with V_PERM predecessors which
2107 // have loop carried dependencies (WAR) on the same VMEM_READs.
2108 // We consider partial overlap as a miss -- in other words,
2109 // for a given DS_W, we only consider another DS_W as matching
2110 // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
2111 // for every V_PERM pred of this DS_W.
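// E.g. (illustrative): two DS_WRITEs whose V_PERM preds all read from the
// same pair of VMEM loads count as one matched pair and bump
// DSWWithSharedVMEMCount by 2.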
2112 DenseMap<MachineInstr *, SUnit *> VMEMLookup;
2113 SmallVector<SUnit *, 6> Counted;
2114 for (; I != E; I++) {
2115 SUnit *Cand = nullptr;
2116 bool MissedAny = false;
2117 for (auto &Pred : (*I)->Preds) {
2118 if (Pred.getSUnit()->getInstr()->getOpcode() != AMDGPU::V_PERM_B32_e64)
2119 continue;
2120
2121 if (Cand && llvm::is_contained(Counted, Cand))
2122 break;
2123
2124 for (auto &Succ : Pred.getSUnit()->Succs) {
2125 auto MI = Succ.getSUnit()->getInstr();
2126 if (!TII->isVMEM(*MI) || !MI->mayLoad())
2127 continue;
2128
2129 if (MissedAny || !VMEMLookup.size()) {
2130 MissedAny = true;
2131 VMEMLookup[MI] = *I;
2132 continue;
2133 }
2134
2135 if (!VMEMLookup.contains(MI)) {
2136 MissedAny = true;
2137 VMEMLookup[MI] = *I;
2138 continue;
2139 }
2140
2141 Cand = VMEMLookup[MI];
2142 if (llvm::is_contained(Counted, Cand)) {
2143 MissedAny = true;
2144 break;
2145 }
2146 }
2147 }
2148 if (!MissedAny && Cand) {
2149 DSWWithSharedVMEMCount += 2;
2150 Counted.push_back(Cand);
2151 Counted.push_back(*I);
2152 }
2153 }
2154 }
2155
2156 assert(DSWWithSharedVMEMCount <= DSWWithPermCount);
2157 SchedGroup *SG;
2158 unsigned PipelineSyncID = 0;
2159 // For kernels with V_PERM, there are enough VALU to mix in between MFMAs
2160 if (DSWWithPermCount) {
2161 for (unsigned I = 0; I < MFMACount; I++) {
2162 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2163 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2164 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2165
2166 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2167 SchedGroupMask::VALU, 2, PipelineSyncID, DAG, TII);
2168 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2169 }
2170 }
2171
2172 PipelineSyncID = 1;
2173 // Phase 1: Break up DS_READ and MFMA clusters.
2174 // First, DS_READs to make the initial MFMA ready, then interleave MFMAs with
2175 // DS_READ prefetches
2176
2177 // Make the initial MFMA ready
2178 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2179 SchedGroupMask::DS_READ, 4, PipelineSyncID, DAG, TII);
2180 SG->addRule(std::make_shared<EnablesInitialMFMA>(TII, SG->getSGID(), true));
2181 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2182
2183 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2184 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2185 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2186
2187 // Interleave MFMA with DS_READ prefetch
2188 for (unsigned I = 0; I < DSRCount - 4; ++I) {
2189 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2190 SchedGroupMask::DS_READ, 1, PipelineSyncID, DAG, TII);
2191 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2192
2193 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2194 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2195 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2196 }
2197
2198 // Phase 2a: Loop carried dependency with V_PERM
2199 // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
2200 // depend on. Interleave MFMA to keep XDL unit busy throughout.
2201 for (unsigned I = 0; I < DSWWithPermCount - DSWWithSharedVMEMCount; ++I) {
2202 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2203 SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2204 SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2205 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2206
2207 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2208 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2209 SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2210 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2211
2212 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2213 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2214 SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2215 1, TII, SG->getSGID(), true));
2216 SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2217 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2218
2219 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2220 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2221 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2222
2223 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2224 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2225 SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2226 3, TII, SG->getSGID(), true));
2227 SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2228 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2229
2230 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2231 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2232 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2233 }
2234
2235 // Phase 2b: Loop carried dependency without V_PERM
2236 // Schedule DS_WRITE as closely as possible to the VMEM_READ they depend on.
2237 // Interleave MFMA to keep XDL unit busy throughout.
2238 for (unsigned I = 0; I < DSWCount - DSWWithPermCount; I++) {
2239 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2240 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2241 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2242
2243 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2244 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2245 SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2246 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2247
2248 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2249 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2250 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2251 }
2252
2253 // Phase 2c: Loop carried dependency with V_PERM, VMEM_READs are
2254 // ultimately used by two DS_WRITEs
2255 // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
2256 // depend on. Interleave MFMA to keep XDL unit busy throughout.
2257
2258 for (unsigned I = 0; I < DSWWithSharedVMEMCount; ++I) {
2259 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2260 SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2261 SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2262 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2263
2264 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2265 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2266 SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2267 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2268
2269 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2270 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2271 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2272
2273 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2274 SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2275 SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2276 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2277
2278 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2279 SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2280 SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2281 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2282
2283 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2284 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2285 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2286
2287 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2288 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2289 SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2290 2, TII, SG->getSGID(), true));
2291 SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2292 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2293
2294 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2295 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2296 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2297
2298 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2299 SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2300 SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2301 4, TII, SG->getSGID(), true));
2302 SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2303 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2304
2305 SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2306 SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2307 SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2308 }
2309
2310 return true;
2311 }
2312
2313 static std::unique_ptr<IGLPStrategy>
2314 createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
2315 const SIInstrInfo *TII) {
2316 switch (ID) {
2317 case MFMASmallGemmOptID:
2318 return std::make_unique<MFMASmallGemmOpt>(DAG, TII);
2319 case MFMASmallGemmSingleWaveOptID:
2320 return std::make_unique<MFMASmallGemmSingleWaveOpt>(DAG, TII);
2321 case MFMAExpInterleave:
2322 return std::make_unique<MFMAExpInterleaveOpt>(DAG, TII);
2323 }
2324
2325 llvm_unreachable("Unknown IGLPStrategyID");
2326 }
2327
2328 class IGroupLPDAGMutation : public ScheduleDAGMutation {
2329 private:
2330 const SIInstrInfo *TII;
2331
2332 ScheduleDAGMI *DAG;
2333
2334 // Organize lists of SchedGroups by their SyncID. SchedGroups /
2335 // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
2336 // between them.
2337 DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
2338
2339 // Used to track instructions that can be mapped to multiple sched groups
2340 DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
2341
2342 // Add DAG edges that enforce SCHED_BARRIER ordering.
2343 void addSchedBarrierEdges(SUnit &SU);
2344
2345 // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
2346 // not be reordered across the SCHED_BARRIER. This is used for the base
2347 // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
2348 // SCHED_BARRIER will always block all instructions that can be classified
2349 // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
2350 // and may only synchronize with some SchedGroups. Returns the inverse of
2351 // Mask. SCHED_BARRIER's mask describes which instruction types should be
2352 // allowed to be scheduled across it. Invert the mask to get the
2353 // SchedGroupMask of instructions that should be barred.
2354 SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;
2355
2356 // Create SchedGroups for a SCHED_GROUP_BARRIER.
2357 void initSchedGroupBarrierPipelineStage(
2358 std::vector<SUnit>::reverse_iterator RIter);
2359
2360 bool initIGLPOpt(SUnit &SU);
2361
2362 public:
2363 void apply(ScheduleDAGInstrs *DAGInstrs) override;
2364
2365 // The order in which the PipelineSolver should process the candidate
2366 // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last
2367 // created SchedGroup first, and will consider that as the ultimate
2368 // predecessor group when linking. TOP_DOWN instead links and processes the
2369 // first created SchedGroup first.
2370 bool IsBottomUp = true;
2371
2372 // The scheduling phase this application of IGLP corresponds with.
2373 AMDGPU::SchedulingPhase Phase = AMDGPU::SchedulingPhase::Initial;
2374
2375 IGroupLPDAGMutation() = default;
2376 IGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) : Phase(Phase) {}
2377 };
2378
2379 unsigned SchedGroup::NumSchedGroups = 0;
2380
2381 bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
2382 if (A != B && DAG->canAddEdge(B, A)) {
2383 DAG->addEdge(B, SDep(A, SDep::Artificial));
2384 return true;
2385 }
2386 return false;
2387 }
2388
2389 bool SchedGroup::canAddMI(const MachineInstr &MI) const {
2390 bool Result = false;
2391 if (MI.isMetaInstruction())
2392 Result = false;
2393
2394 else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
2395 (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI) ||
2396 TII->isTRANS(MI)))
2397 Result = true;
2398
2399 else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
2400 TII->isVALU(MI) && !TII->isMFMAorWMMA(MI) && !TII->isTRANS(MI))
2401 Result = true;
2402
2403 else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
2404 TII->isSALU(MI))
2405 Result = true;
2406
2407 else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
2408 TII->isMFMAorWMMA(MI))
2409 Result = true;
2410
2411 else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
2412 (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2413 Result = true;
2414
2415 else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
2416 MI.mayLoad() &&
2417 (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2418 Result = true;
2419
2420 else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
2421 MI.mayStore() &&
2422 (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2423 Result = true;
2424
2425 else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
2426 TII->isDS(MI))
2427 Result = true;
2428
2429 else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
2430 MI.mayLoad() && TII->isDS(MI))
2431 Result = true;
2432
2433 else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
2434 MI.mayStore() && TII->isDS(MI))
2435 Result = true;
2436
2437 else if (((SGMask & SchedGroupMask::TRANS) != SchedGroupMask::NONE) &&
2438 TII->isTRANS(MI))
2439 Result = true;
2440
2441 LLVM_DEBUG(
2442 dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
2443 << (Result ? " could classify " : " unable to classify ") << MI);
2444
2445 return Result;
2446 }
2447
2448 int SchedGroup::link(SUnit &SU, bool MakePred,
2449 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
2450 int MissedEdges = 0;
2451 for (auto *A : Collection) {
2452 SUnit *B = &SU;
2453 if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2454 continue;
2455 if (MakePred)
2456 std::swap(A, B);
2457
2458 if (DAG->IsReachable(B, A))
2459 continue;
2460
2461 // tryAddEdge returns false if there is a dependency that makes adding
2462 // the A->B edge impossible, otherwise it returns true.
2463 bool Added = tryAddEdge(A, B);
2464 if (Added)
2465 AddedEdges.emplace_back(A, B);
2466 else
2467 ++MissedEdges;
2468 }
2469
2470 return MissedEdges;
2471 }
2472
2473 void SchedGroup::link(SUnit &SU, bool MakePred) {
2474 for (auto *A : Collection) {
2475 SUnit *B = &SU;
2476 if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2477 continue;
2478 if (MakePred)
2479 std::swap(A, B);
2480
2481 tryAddEdge(A, B);
2482 }
2483 }
2484
2485 void SchedGroup::link(SUnit &SU,
2486 function_ref<bool(const SUnit *A, const SUnit *B)> P) {
2487 for (auto *A : Collection) {
2488 SUnit *B = &SU;
2489 if (P(A, B))
2490 std::swap(A, B);
2491
2492 tryAddEdge(A, B);
2493 }
2494 }
2495
2496 void SchedGroup::link(SchedGroup &OtherGroup) {
2497 for (auto *B : OtherGroup.Collection)
2498 link(*B);
2499 }
2500
2501 bool SchedGroup::canAddSU(SUnit &SU) const {
2502 MachineInstr &MI = *SU.getInstr();
2503 if (MI.getOpcode() != TargetOpcode::BUNDLE)
2504 return canAddMI(MI);
2505
2506 // Special case for bundled MIs.
2507 const MachineBasicBlock *MBB = MI.getParent();
2508 MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
2509 while (E != MBB->end() && E->isBundledWithPred())
2510 ++E;
2511
2512 // Return true if all of the bundled MIs can be added to this group.
2513 return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });
2514 }
2515
2516 void SchedGroup::initSchedGroup() {
2517 for (auto &SU : DAG->SUnits) {
2518 if (isFull())
2519 break;
2520
2521 if (canAddSU(SU))
2522 add(SU);
2523 }
2524 }
2525
2526 void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
2527 SUnitsToCandidateSGsMap &SyncedInstrs) {
2528 SUnit &InitSU = *RIter;
2529 for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {
2530 auto &SU = *RIter;
2531 if (isFull())
2532 break;
2533
2534 if (canAddSU(SU))
2535 SyncedInstrs[&SU].push_back(SGID);
2536 }
2537
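// Keep the SCHED_GROUP_BARRIER itself in the group, and grow MaxSize by one
// so that the barrier does not consume one of the requested slots.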
2538 add(InitSU);
2539 assert(MaxSize);
2540 (*MaxSize)++;
2541 }
2542
2543 void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {
2544 auto I = DAG->SUnits.rbegin();
2545 auto E = DAG->SUnits.rend();
2546 for (; I != E; ++I) {
2547 auto &SU = *I;
2548 if (isFull())
2549 break;
2550 if (canAddSU(SU))
2551 SyncedInstrs[&SU].push_back(SGID);
2552 }
2553 }
2554
2555 void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
2556 const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
2557 if (!TSchedModel || DAGInstrs->SUnits.empty())
2558 return;
2559
2560 LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
2561 const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
2562 TII = ST.getInstrInfo();
2563 DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
2564 SyncedSchedGroups.clear();
2565 SyncedInstrs.clear();
2566 bool FoundSB = false;
2567 bool FoundIGLP = false;
2568 bool ShouldApplyIGLP = false;
2569 for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
2570 unsigned Opc = R->getInstr()->getOpcode();
2571 // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
2572 if (Opc == AMDGPU::SCHED_BARRIER) {
2573 addSchedBarrierEdges(*R);
2574 FoundSB = true;
2575 } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
2576 initSchedGroupBarrierPipelineStage(R);
2577 FoundSB = true;
2578 } else if (Opc == AMDGPU::IGLP_OPT) {
2579 resetEdges(*R, DAG);
2580 if (!FoundSB && !FoundIGLP) {
2581 FoundIGLP = true;
2582 ShouldApplyIGLP = initIGLPOpt(*R);
2583 }
2584 }
2585 }
2586
2587 if (FoundSB || (FoundIGLP && ShouldApplyIGLP)) {
2588 PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
2589 // PipelineSolver performs the mutation by adding the edges it
2590 // determined as the best.
2591 PS.solve();
2592 return;
2593 }
2594 }
2595
2596 void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
2597 MachineInstr &MI = *SchedBarrier.getInstr();
2598 assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
2599 // Remove all existing edges from the SCHED_BARRIER that were added due to the
2600 // instruction having side effects.
2601 resetEdges(SchedBarrier, DAG);
2602 LLVM_DEBUG(dbgs() << "Building SchedGroup for SchedBarrier with Mask: "
2603 << MI.getOperand(0).getImm() << "\n");
2604 auto InvertedMask =
2605 invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());
2606 SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
2607 SG.initSchedGroup();
2608
2609 // Preserve original instruction ordering relative to the SCHED_BARRIER.
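// SUnits are numbered in program order, so the NodeNum comparison below
// decides on which side of the barrier each barred instruction is pinned.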
2610 SG.link(
2611 SchedBarrier,
2612 (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
2613 const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
2614 }
2615
2616 SchedGroupMask
2617 IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
2618 // Invert mask and erase bits for types of instructions that are implied to be
2619 // allowed past the SCHED_BARRIER.
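// E.g. (illustrative): a mask that only permits VMEM_READ to cross inverts to
// a barred set containing everything else; because VMEM_READ stays allowed,
// the combined VMEM bit is also cleared so loads are not blocked indirectly.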
2620 SchedGroupMask InvertedMask = ~Mask;
2621
2622 // ALU implies VALU, SALU, MFMA, TRANS.
2623 if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
2624 InvertedMask &= ~SchedGroupMask::VALU & ~SchedGroupMask::SALU &
2625 ~SchedGroupMask::MFMA & ~SchedGroupMask::TRANS;
2626 // VALU, SALU, MFMA, TRANS implies ALU.
2627 else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
2628 (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
2629 (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE ||
2630 (InvertedMask & SchedGroupMask::TRANS) == SchedGroupMask::NONE)
2631 InvertedMask &= ~SchedGroupMask::ALU;
2632
2633 // VMEM implies VMEM_READ, VMEM_WRITE.
2634 if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
2635 InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;
2636 // VMEM_READ, VMEM_WRITE implies VMEM.
2637 else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
2638 (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)
2639 InvertedMask &= ~SchedGroupMask::VMEM;
2640
2641 // DS implies DS_READ, DS_WRITE.
2642 if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
2643 InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;
2644 // DS_READ, DS_WRITE implies DS.
2645 else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
2646 (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
2647 InvertedMask &= ~SchedGroupMask::DS;
2648
2649 LLVM_DEBUG(dbgs() << "After Inverting, SchedGroup Mask: " << (int)InvertedMask
2650 << "\n");
2651
2652 return InvertedMask;
2653 }
2654
2655 void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
2656 std::vector<SUnit>::reverse_iterator RIter) {
2657 // Remove all existing edges from the SCHED_GROUP_BARRIER that were added due
2658 // to the instruction having side effects.
2659 resetEdges(*RIter, DAG);
2660 MachineInstr &SGB = *RIter->getInstr();
2661 assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
2662 int32_t SGMask = SGB.getOperand(0).getImm();
2663 int32_t Size = SGB.getOperand(1).getImm();
2664 int32_t SyncID = SGB.getOperand(2).getImm();
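// These correspond to the operands of llvm.amdgcn.sched.group.barrier(mask,
// size, sync_id); e.g. a mask of 8 (SchedGroupMask::MFMA) with size 4 requests
// a group of four MFMAs.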
2665
2666 auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,
2667 Size, SyncID, DAG, TII);
2668
2669 SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]);
2670 }
2671
2672 bool IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
2673 IGLPStrategyID StrategyID =
2674 (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm();
2675 auto S = createIGLPStrategy(StrategyID, DAG, TII);
2676 if (!S->shouldApplyStrategy(DAG, Phase))
2677 return false;
2678
2679 IsBottomUp = S->IsBottomUp;
2680 return S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups, Phase);
2681 }
2682
2683 } // namespace
2684
2685 namespace llvm {
2686
2687 /// \p Phase specifies whether or not this is a reentry into the
2688 /// IGroupLPDAGMutation. Since there may be multiple scheduling passes on the
2689 /// same scheduling region (e.g. pre and post-RA scheduling / multiple
2690 /// scheduling "phases"), we can reenter this mutation framework more than once
2691 /// for a given region.
2692 std::unique_ptr<ScheduleDAGMutation>
2693 createIGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) {
2694 return std::make_unique<IGroupLPDAGMutation>(Phase);
2695 }
2696
2697 } // end namespace llvm
2698