xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP  ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file This file defines a set of schedule DAG mutations that can be used to
10 // override default scheduler behavior to enforce specific scheduling patterns.
11 // They should be used in cases where runtime performance considerations such as
12 // inter-wavefront interactions, mean that compile-time heuristics cannot
13 // predict the optimal instruction ordering, or in kernels where optimum
14 // instruction scheduling is important enough to warrant manual intervention.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "AMDGPUIGroupLP.h"
19 #include "AMDGPUTargetMachine.h"
20 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
21 #include "SIInstrInfo.h"
22 #include "SIMachineFunctionInfo.h"
23 #include "llvm/ADT/BitmaskEnum.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/CodeGen/MachineScheduler.h"
26 #include "llvm/CodeGen/TargetOpcodes.h"
27 
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "igrouplp"
31 
32 namespace {
33 
34 static cl::opt<bool> EnableExactSolver(
35     "amdgpu-igrouplp-exact-solver", cl::Hidden,
36     cl::desc("Whether to use the exponential time solver to fit "
37              "the instructions to the pipeline as closely as "
38              "possible."),
39     cl::init(false));
40 
41 static cl::opt<unsigned> CutoffForExact(
42     "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
43     cl::desc("The maximum number of scheduling group conflicts "
44              "which we attempt to solve with the exponential time "
45              "exact solver. Problem sizes greater than this will"
46              "be solved by the less accurate greedy algorithm. Selecting "
47              "solver by size is superseded by manually selecting "
48              "the solver (e.g. by amdgpu-igrouplp-exact-solver"));
49 
50 static cl::opt<uint64_t> MaxBranchesExplored(
51     "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
52     cl::desc("The amount of branches that we are willing to explore with"
53              "the exact algorithm before giving up."));
54 
55 static cl::opt<bool> UseCostHeur(
56     "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
57     cl::desc("Whether to use the cost heuristic to make choices as we "
58              "traverse the search space using the exact solver. Defaulted "
59              "to on, and if turned off, we will use the node order -- "
60              "attempting to put the later nodes in the later sched groups. "
61              "Experimentally, results are mixed, so this should be set on a "
62              "case-by-case basis."));
63 
64 // Components of the mask that determines which instruction types may be may be
65 // classified into a SchedGroup.
66 enum class SchedGroupMask {
67   NONE = 0u,
68   ALU = 1u << 0,
69   VALU = 1u << 1,
70   SALU = 1u << 2,
71   MFMA = 1u << 3,
72   VMEM = 1u << 4,
73   VMEM_READ = 1u << 5,
74   VMEM_WRITE = 1u << 6,
75   DS = 1u << 7,
76   DS_READ = 1u << 8,
77   DS_WRITE = 1u << 9,
78   TRANS = 1u << 10,
79   ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
80         DS_READ | DS_WRITE | TRANS,
81   LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
82 };
83 
84 class SchedGroup;
85 
86 // InstructionRule class is used to enact a filter which determines whether or
87 // not an SU maps to a given SchedGroup. It contains complementary data
88 // structures (e.g Cache) to help those filters.
89 class InstructionRule {
90 protected:
91   const SIInstrInfo *TII;
92   unsigned SGID;
93   // A cache made available to the Filter to store SUnits for subsequent
94   // invocations of the Filter
95   std::optional<SmallVector<SUnit *, 4>> Cache;
96 
97 public:
98   virtual bool
99   apply(const SUnit *, const ArrayRef<SUnit *>,
100         SmallVectorImpl<SchedGroup> &) {
101     return true;
102   };
103 
104   InstructionRule(const SIInstrInfo *TII, unsigned SGID,
105                   bool NeedsCache = false)
106       : TII(TII), SGID(SGID) {
107     if (NeedsCache) {
108       Cache = SmallVector<SUnit *, 4>();
109     }
110   }
111 
112   virtual ~InstructionRule() = default;
113 };
114 
115 using SUnitsToCandidateSGsMap = DenseMap<SUnit *, SmallVector<int, 4>>;
116 
117 // Classify instructions into groups to enable fine tuned control over the
118 // scheduler. These groups may be more specific than current SchedModel
119 // instruction classes.
120 class SchedGroup {
121 private:
122   // Mask that defines which instruction types can be classified into this
123   // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
124   // and SCHED_GROUP_BARRIER.
125   SchedGroupMask SGMask;
126 
127   // Maximum number of SUnits that can be added to this group.
128   std::optional<unsigned> MaxSize;
129 
130   // SchedGroups will only synchronize with other SchedGroups that have the same
131   // SyncID.
132   int SyncID = 0;
133 
134   // SGID is used to map instructions to candidate SchedGroups
135   unsigned SGID;
136 
137   // The different rules each instruction in this SchedGroup must conform to
138   SmallVector<std::shared_ptr<InstructionRule>, 4> Rules;
139 
140   // Count of the number of created SchedGroups, used to initialize SGID.
141   static unsigned NumSchedGroups;
142 
143   // Try to add and edge from SU A to SU B.
144   bool tryAddEdge(SUnit *A, SUnit *B);
145 
146   // Use SGMask to determine whether we can classify MI as a member of this
147   // SchedGroup object.
148   bool canAddMI(const MachineInstr &MI) const;
149 
150 public:
151   // Collection of SUnits that are classified as members of this group.
152   SmallVector<SUnit *, 32> Collection;
153 
154   ScheduleDAGInstrs *DAG;
155   const SIInstrInfo *TII;
156 
157   // Returns true if SU can be added to this SchedGroup.
158   bool canAddSU(SUnit &SU) const;
159 
160   // Add DAG dependencies from all SUnits in this SchedGroup and this SU. If
161   // MakePred is true, SU will be a predecessor of the SUnits in this
162   // SchedGroup, otherwise SU will be a successor.
163   void link(SUnit &SU, bool MakePred = false);
164 
165   // Add DAG dependencies and track which edges are added, and the count of
166   // missed edges
167   int link(SUnit &SU, bool MakePred,
168            std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
169 
170   // Add DAG dependencies from all SUnits in this SchedGroup and this SU.
171   // Use the predicate to determine whether SU should be a predecessor (P =
172   // true) or a successor (P = false) of this SchedGroup.
173   void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);
174 
175   // Add DAG dependencies such that SUnits in this group shall be ordered
176   // before SUnits in OtherGroup.
177   void link(SchedGroup &OtherGroup);
178 
179   // Returns true if no more instructions may be added to this group.
180   bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }
181 
182   // Append a constraint that SUs must meet in order to fit into this
183   // SchedGroup. Since many rules involve the relationship between a SchedGroup
184   // and the SUnits in other SchedGroups, rules are checked at Pipeline Solve
185   // time (rather than SchedGroup init time.)
186   void addRule(std::shared_ptr<InstructionRule> NewRule) {
187     Rules.push_back(NewRule);
188   }
189 
190   // Returns true if the SU matches all rules
191   bool allowedByRules(const SUnit *SU,
192                       SmallVectorImpl<SchedGroup> &SyncPipe) const {
193     for (auto &Rule : Rules) {
194       if (!Rule.get()->apply(SU, Collection, SyncPipe))
195         return false;
196     }
197     return true;
198   }
199 
200   // Add SU to the SchedGroup.
201   void add(SUnit &SU) {
202     LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
203                       << format_hex((int)SGMask, 10, true) << " adding "
204                       << *SU.getInstr());
205     Collection.push_back(&SU);
206   }
207 
208   // Remove last element in the SchedGroup
209   void pop() { Collection.pop_back(); }
210 
211   // Identify and add all relevant SUs from the DAG to this SchedGroup.
212   void initSchedGroup();
213 
214   // Add instructions to the SchedGroup bottom up starting from RIter.
215   // PipelineInstrs is a set of instructions that should not be added to the
216   // SchedGroup even when the other conditions for adding it are satisfied.
217   // RIter will be added to the SchedGroup as well, and dependencies will be
218   // added so that RIter will always be scheduled at the end of the group.
219   void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
220                       SUnitsToCandidateSGsMap &SyncedInstrs);
221 
222   void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);
223 
224   int getSyncID() { return SyncID; }
225 
226   int getSGID() { return SGID; }
227 
228   SchedGroupMask getMask() { return SGMask; }
229 
230   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
231              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
232       : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) {
233     SGID = NumSchedGroups++;
234   }
235 
236   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
237              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
238       : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) {
239     SGID = NumSchedGroups++;
240   }
241 };
242 
243 // Remove all existing edges from a SCHED_BARRIER or SCHED_GROUP_BARRIER.
244 static void resetEdges(SUnit &SU, ScheduleDAGInstrs *DAG) {
245   assert(SU.getInstr()->getOpcode() == AMDGPU::SCHED_BARRIER ||
246          SU.getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER ||
247          SU.getInstr()->getOpcode() == AMDGPU::IGLP_OPT);
248 
249   while (!SU.Preds.empty())
250     for (auto &P : SU.Preds)
251       SU.removePred(P);
252 
253   while (!SU.Succs.empty())
254     for (auto &S : SU.Succs)
255       for (auto &SP : S.getSUnit()->Preds)
256         if (SP.getSUnit() == &SU)
257           S.getSUnit()->removePred(SP);
258 }
259 
260 using SUToCandSGsPair = std::pair<SUnit *, SmallVector<int, 4>>;
261 using SUsToCandSGsVec = SmallVector<SUToCandSGsPair, 4>;
262 
263 // The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
264 // in non-trivial cases. For example, if the requested pipeline is
265 // {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
266 // in the DAG, then we will have an instruction that can not be trivially
267 // assigned to a SchedGroup. The PipelineSolver class implements two algorithms
268 // to find a good solution to the pipeline -- a greedy algorithm and an exact
269 // algorithm. The exact algorithm has an exponential time complexity and should
270 // only be used for small sized problems or medium sized problems where an exact
271 // solution is highly desired.
272 class PipelineSolver {
273   ScheduleDAGMI *DAG;
274 
275   // Instructions that can be assigned to multiple SchedGroups
276   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
277   SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
278   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
279   // The current working pipeline
280   SmallVector<SmallVector<SchedGroup, 4>, 4> CurrPipeline;
281   // The pipeline that has the best solution found so far
282   SmallVector<SmallVector<SchedGroup, 4>, 4> BestPipeline;
283 
284   // Whether or not we actually have any SyncedInstrs to try to solve.
285   bool NeedsSolver = false;
286 
287   // Compute an estimate of the size of search tree -- the true size is
288   // the product of each conflictedInst.Matches.size() across all SyncPipelines
289   unsigned computeProblemSize();
290 
291   // The cost penalty of not assigning a SU to a SchedGroup
292   int MissPenalty = 0;
293 
294   // Costs in terms of the number of edges we are unable to add
295   int BestCost = -1;
296   int CurrCost = 0;
297 
298   // Index pointing to the conflicting instruction that is currently being
299   // fitted
300   int CurrConflInstNo = 0;
301   // Index to the pipeline that is currently being fitted
302   int CurrSyncGroupIdx = 0;
303   // The first non trivial pipeline
304   int BeginSyncGroupIdx = 0;
305 
306   // How many branches we have explored
307   uint64_t BranchesExplored = 0;
308 
309   // The direction in which we process the candidate SchedGroups per SU
310   bool IsBottomUp = true;
311 
312   // Update indices to fit next conflicting instruction
313   void advancePosition();
314   // Recede indices to attempt to find better fit for previous conflicting
315   // instruction
316   void retreatPosition();
317 
318   // The exponential time algorithm which finds the provably best fit
319   bool solveExact();
320   // The polynomial time algorithm which attempts to find a good fit
321   bool solveGreedy();
322   // Find the best SchedGroup for the current SU using the heuristic given all
323   // current information. One step in the greedy algorithm. Templated against
324   // the SchedGroup iterator (either reverse or forward).
325   template <typename T>
326   void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,
327                   T E);
328   // Whether or not the current solution is optimal
329   bool checkOptimal();
330   // Populate the ready list, prioiritizing fewest missed edges first
331   // Templated against the SchedGroup iterator (either reverse or forward).
332   template <typename T>
333   void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,
334                          T E);
335   // Add edges corresponding to the SchedGroups as assigned by solver
336   void makePipeline();
337   // Link the SchedGroups in the best found pipeline.
338   // Tmplated against the SchedGroup iterator (either reverse or forward).
339   template <typename T> void linkSchedGroups(T I, T E);
340   // Add the edges from the SU to the other SchedGroups in pipeline, and
341   // return the number of edges missed.
342   int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
343                std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
344   /// Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
345   /// returns the cost (in terms of missed pipeline edges), and tracks the edges
346   /// added in \p AddedEdges
347   template <typename T>
348   int linkSUnit(SUnit *SU, int SGID,
349                 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
350   /// Remove the edges passed via \p AddedEdges
351   void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
352   // Convert the passed in maps to arrays for bidirectional iterators
353   void convertSyncMapsToArrays();
354 
355   void reset();
356 
357 public:
358   // Invoke the solver to map instructions to instruction groups. Heuristic &&
359   // command-line-option determines to use exact or greedy algorithm.
360   void solve();
361 
362   PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
363                  DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
364                  ScheduleDAGMI *DAG, bool IsBottomUp = true)
365       : DAG(DAG), SyncedInstrs(SyncedInstrs),
366         SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {
367 
368     for (auto &PipelineInstrs : SyncedInstrs) {
369       if (PipelineInstrs.second.size() > 0) {
370         NeedsSolver = true;
371         break;
372       }
373     }
374 
375     if (!NeedsSolver)
376       return;
377 
378     convertSyncMapsToArrays();
379 
380     CurrPipeline = BestPipeline;
381 
382     while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
383            PipelineInstrs[BeginSyncGroupIdx].size() == 0)
384       ++BeginSyncGroupIdx;
385 
386     if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
387       return;
388   }
389 };
390 
391 void PipelineSolver::reset() {
392 
393   for (auto &SyncPipeline : CurrPipeline) {
394     for (auto &SG : SyncPipeline) {
395       SmallVector<SUnit *, 32> TempCollection = SG.Collection;
396       SG.Collection.clear();
397       auto SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
398         return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
399       });
400       if (SchedBarr != TempCollection.end())
401         SG.Collection.push_back(*SchedBarr);
402     }
403   }
404 
405   CurrSyncGroupIdx = BeginSyncGroupIdx;
406   CurrConflInstNo = 0;
407   CurrCost = 0;
408 }
409 
410 void PipelineSolver::convertSyncMapsToArrays() {
411   for (auto &SyncPipe : SyncedSchedGroups) {
412     BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);
413   }
414 
415   int PipelineIDx = SyncedInstrs.size() - 1;
416   PipelineInstrs.resize(SyncedInstrs.size());
417   for (auto &SyncInstrMap : SyncedInstrs) {
418     for (auto &SUsToCandSGs : SyncInstrMap.second) {
419       if (PipelineInstrs[PipelineIDx].size() == 0) {
420         PipelineInstrs[PipelineIDx].push_back(
421             std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
422         continue;
423       }
424       auto SortPosition = PipelineInstrs[PipelineIDx].begin();
425       // Insert them in sorted order -- this allows for good parsing order in
426       // the greedy algorithm
427       while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
428              SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
429         ++SortPosition;
430       PipelineInstrs[PipelineIDx].insert(
431           SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
432     }
433     --PipelineIDx;
434   }
435 }
436 
437 template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {
438   for (; I != E; ++I) {
439     auto &GroupA = *I;
440     for (auto J = std::next(I); J != E; ++J) {
441       auto &GroupB = *J;
442       GroupA.link(GroupB);
443     }
444   }
445 }
446 
447 void PipelineSolver::makePipeline() {
448   // Preserve the order of barrier for subsequent SchedGroupBarrier mutations
449   for (auto &SyncPipeline : BestPipeline) {
450     LLVM_DEBUG(dbgs() << "Printing SchedGroups\n");
451     for (auto &SG : SyncPipeline) {
452       LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID()
453                         << " has: \n");
454       SUnit *SGBarr = nullptr;
455       for (auto &SU : SG.Collection) {
456         if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
457           SGBarr = SU;
458         LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n");
459       }
460       // Command line requested IGroupLP doesn't have SGBarr
461       if (!SGBarr)
462         continue;
463       resetEdges(*SGBarr, DAG);
464       SG.link(*SGBarr, false);
465     }
466   }
467 
468   for (auto &SyncPipeline : BestPipeline) {
469     IsBottomUp ? linkSchedGroups(SyncPipeline.rbegin(), SyncPipeline.rend())
470                : linkSchedGroups(SyncPipeline.begin(), SyncPipeline.end());
471   }
472 }
473 
474 template <typename T>
475 int PipelineSolver::linkSUnit(
476     SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
477     T I, T E) {
478   bool MakePred = false;
479   int AddedCost = 0;
480   for (; I < E; ++I) {
481     if (I->getSGID() == SGID) {
482       MakePred = true;
483       continue;
484     }
485     auto Group = *I;
486     AddedCost += Group.link(*SU, MakePred, AddedEdges);
487     assert(AddedCost >= 0);
488   }
489   return AddedCost;
490 }
491 
492 int PipelineSolver::addEdges(
493     SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
494     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
495 
496   // For IsBottomUp, the first SchedGroup in SyncPipeline contains the
497   // instructions that are the ultimate successors in the resultant mutation.
498   // Therefore, in such a configuration, the SchedGroups occurring before the
499   // candidate SGID are successors of the candidate SchedGroup, thus the current
500   // SU should be linked as a predecessor to SUs in those SchedGroups. The
501   // opposite is true if !IsBottomUp. IsBottomUp occurs in the case of multiple
502   // SCHED_GROUP_BARRIERS, or if a user specifies IGLP_OPT SchedGroups using
503   // IsBottomUp (in reverse).
504   return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, SyncPipeline.rbegin(),
505                                 SyncPipeline.rend())
506                     : linkSUnit(SU, SGID, AddedEdges, SyncPipeline.begin(),
507                                 SyncPipeline.end());
508 }
509 
510 void PipelineSolver::removeEdges(
511     const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
512   // Only remove the edges that we have added when testing
513   // the fit.
514   for (auto &PredSuccPair : EdgesToRemove) {
515     SUnit *Pred = PredSuccPair.first;
516     SUnit *Succ = PredSuccPair.second;
517 
518     auto Match = llvm::find_if(
519         Succ->Preds, [&Pred](SDep &P) { return P.getSUnit() == Pred; });
520     if (Match != Succ->Preds.end()) {
521       assert(Match->isArtificial());
522       Succ->removePred(*Match);
523     }
524   }
525 }
526 
527 void PipelineSolver::advancePosition() {
528   ++CurrConflInstNo;
529 
530   if (static_cast<size_t>(CurrConflInstNo) >=
531       PipelineInstrs[CurrSyncGroupIdx].size()) {
532     CurrConflInstNo = 0;
533     ++CurrSyncGroupIdx;
534     // Advance to next non-trivial pipeline
535     while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
536            PipelineInstrs[CurrSyncGroupIdx].size() == 0)
537       ++CurrSyncGroupIdx;
538   }
539 }
540 
541 void PipelineSolver::retreatPosition() {
542   assert(CurrConflInstNo >= 0);
543   assert(CurrSyncGroupIdx >= 0);
544 
545   if (CurrConflInstNo > 0) {
546     --CurrConflInstNo;
547     return;
548   }
549 
550   if (CurrConflInstNo == 0) {
551     // If we return to the starting position, we have explored
552     // the entire tree
553     if (CurrSyncGroupIdx == BeginSyncGroupIdx)
554       return;
555 
556     --CurrSyncGroupIdx;
557     // Go to previous non-trivial pipeline
558     while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)
559       --CurrSyncGroupIdx;
560 
561     CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
562   }
563 }
564 
565 bool PipelineSolver::checkOptimal() {
566   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
567     if (BestCost == -1 || CurrCost < BestCost) {
568       BestPipeline = CurrPipeline;
569       BestCost = CurrCost;
570       LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
571     }
572     assert(BestCost >= 0);
573   }
574 
575   bool DoneExploring = false;
576   if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
577     DoneExploring = true;
578 
579   return (DoneExploring || BestCost == 0);
580 }
581 
582 template <typename T>
583 void PipelineSolver::populateReadyList(
584     SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {
585   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
586   auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
587   assert(CurrSU.second.size() >= 1);
588 
589   for (; I != E; ++I) {
590     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
591     int CandSGID = *I;
592     SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
593       return SG.getSGID() == CandSGID;
594     });
595     assert(Match);
596 
597     if (UseCostHeur) {
598       if (Match->isFull()) {
599         ReadyList.push_back(std::pair(*I, MissPenalty));
600         continue;
601       }
602 
603       int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
604       ReadyList.push_back(std::pair(*I, TempCost));
605       removeEdges(AddedEdges);
606     } else
607       ReadyList.push_back(std::pair(*I, -1));
608   }
609 
610   if (UseCostHeur) {
611     std::sort(ReadyList.begin(), ReadyList.end(),
612               [](std::pair<int, int> A, std::pair<int, int> B) {
613                 return A.second < B.second;
614               });
615   }
616 
617   assert(ReadyList.size() == CurrSU.second.size());
618 }
619 
620 bool PipelineSolver::solveExact() {
621   if (checkOptimal())
622     return true;
623 
624   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
625     return false;
626 
627   assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
628   assert(static_cast<size_t>(CurrConflInstNo) <
629          PipelineInstrs[CurrSyncGroupIdx].size());
630   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
631   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
632                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
633 
634   // SchedGroup -> Cost pairs
635   SmallVector<std::pair<int, int>, 4> ReadyList;
636   // Prioritize the candidate sched groups in terms of lowest cost first
637   IsBottomUp ? populateReadyList(ReadyList, CurrSU.second.rbegin(),
638                                  CurrSU.second.rend())
639              : populateReadyList(ReadyList, CurrSU.second.begin(),
640                                  CurrSU.second.end());
641 
642   auto I = ReadyList.begin();
643   auto E = ReadyList.end();
644   for (; I != E; ++I) {
645     // If we are trying SGs in least cost order, and the current SG is cost
646     // infeasible, then all subsequent SGs will also be cost infeasible, so we
647     // can prune.
648     if (BestCost != -1 && (CurrCost + I->second > BestCost))
649       return false;
650 
651     int CandSGID = I->first;
652     int AddedCost = 0;
653     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
654     auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
655     SchedGroup *Match;
656     for (auto &SG : SyncPipeline) {
657       if (SG.getSGID() == CandSGID)
658         Match = &SG;
659     }
660 
661     if (Match->isFull())
662       continue;
663 
664     if (!Match->allowedByRules(CurrSU.first, SyncPipeline))
665       continue;
666 
667     LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
668                       << (int)Match->getMask() << "and ID " << CandSGID
669                       << "\n");
670     Match->add(*CurrSU.first);
671     AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
672     LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
673     CurrCost += AddedCost;
674     advancePosition();
675     ++BranchesExplored;
676     bool FinishedExploring = false;
677     // If the Cost after adding edges is greater than a known solution,
678     // backtrack
679     if (CurrCost < BestCost || BestCost == -1) {
680       if (solveExact()) {
681         FinishedExploring = BestCost != 0;
682         if (!FinishedExploring)
683           return true;
684       }
685     }
686 
687     retreatPosition();
688     CurrCost -= AddedCost;
689     removeEdges(AddedEdges);
690     Match->pop();
691     CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
692     if (FinishedExploring)
693       return true;
694   }
695 
696   // Try the pipeline where the current instruction is omitted
697   // Potentially if we omit a problematic instruction from the pipeline,
698   // all the other instructions can nicely fit.
699   CurrCost += MissPenalty;
700   advancePosition();
701 
702   LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");
703 
704   bool FinishedExploring = false;
705   if (CurrCost < BestCost || BestCost == -1) {
706     if (solveExact()) {
707       bool FinishedExploring = BestCost != 0;
708       if (!FinishedExploring)
709         return true;
710     }
711   }
712 
713   retreatPosition();
714   CurrCost -= MissPenalty;
715   return FinishedExploring;
716 }
717 
718 template <typename T>
719 void PipelineSolver::greedyFind(
720     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
721   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
722   int BestNodeCost = -1;
723   int TempCost;
724   SchedGroup *BestGroup = nullptr;
725   int BestGroupID = -1;
726   auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
727   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
728                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
729 
730   // Since we have added the potential SchedGroups from bottom up, but
731   // traversed the DAG from top down, parse over the groups from last to
732   // first. If we fail to do this for the greedy algorithm, the solution will
733   // likely not be good in more complex cases.
734   for (; I != E; ++I) {
735     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
736     int CandSGID = *I;
737     SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
738       return SG.getSGID() == CandSGID;
739     });
740     assert(Match);
741 
742     LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
743                       << (int)Match->getMask() << "\n");
744 
745     if (Match->isFull()) {
746       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
747       continue;
748     }
749     if (!Match->allowedByRules(CurrSU.first, SyncPipeline)) {
750       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
751       continue;
752     }
753     TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
754     LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
755     if (TempCost < BestNodeCost || BestNodeCost == -1) {
756       BestGroup = Match;
757       BestNodeCost = TempCost;
758       BestGroupID = CandSGID;
759     }
760     removeEdges(AddedEdges);
761     if (BestNodeCost == 0)
762       break;
763   }
764 
765   if (BestGroupID != -1) {
766     BestGroup->add(*CurrSU.first);
767     addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);
768     LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask"
769                       << (int)BestGroup->getMask() << "\n");
770     BestCost += TempCost;
771   } else
772     BestCost += MissPenalty;
773 
774   CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
775 }
776 
777 bool PipelineSolver::solveGreedy() {
778   BestCost = 0;
779   std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
780 
781   while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
782     SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
783     IsBottomUp
784         ? greedyFind(AddedEdges, CurrSU.second.rbegin(), CurrSU.second.rend())
785         : greedyFind(AddedEdges, CurrSU.second.begin(), CurrSU.second.end());
786     advancePosition();
787   }
788   BestPipeline = CurrPipeline;
789   removeEdges(AddedEdges);
790   return false;
791 }
792 
793 unsigned PipelineSolver::computeProblemSize() {
794   unsigned ProblemSize = 0;
795   for (auto &PipeConflicts : PipelineInstrs) {
796     ProblemSize += PipeConflicts.size();
797   }
798 
799   return ProblemSize;
800 }
801 
802 void PipelineSolver::solve() {
803   if (!NeedsSolver)
804     return;
805 
806   unsigned ProblemSize = computeProblemSize();
807   assert(ProblemSize > 0);
808 
809   bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
810   MissPenalty = (ProblemSize / 2) + 1;
811 
812   LLVM_DEBUG(DAG->dump());
813   if (EnableExactSolver || BelowCutoff) {
814     LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
815     solveGreedy();
816     reset();
817     LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
818     if (BestCost > 0) {
819       LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
820       solveExact();
821       LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
822     }
823   } else { // Use the Greedy Algorithm by default
824     LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
825     solveGreedy();
826   }
827 
828   makePipeline();
829   LLVM_DEBUG(dbgs() << "After applying mutation\n");
830   LLVM_DEBUG(DAG->dump());
831 }
832 
/// Identifiers selecting which IGLP strategy to build. The numeric values are
/// fixed explicitly because they are used as stable IDs.
enum IGLPStrategyID : int {
  /// Selects the MFMASmallGemmOpt strategy.
  MFMASmallGemmOptID = 0,
  /// Presumably selects a single-wave variant of the small-GEMM strategy
  /// (its class is not shown here) -- confirm at the use site.
  MFMASmallGemmSingleWaveOptID = 1,
  /// Selects the MFMAExpInterleaveOpt strategy.
  MFMAExpInterleave = 2
};
838 
839 // Implement a IGLP scheduling strategy.
840 class IGLPStrategy {
841 protected:
842   ScheduleDAGInstrs *DAG;
843 
844   const SIInstrInfo *TII;
845 
846 public:
847   /// Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
848   virtual bool applyIGLPStrategy(
849       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
850       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
851       AMDGPU::SchedulingPhase Phase) = 0;
852 
853   // Returns true if this strategy should be applied to a ScheduleDAG.
854   virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
855                                    AMDGPU::SchedulingPhase Phase) = 0;
856 
857   bool IsBottomUp = true;
858 
859   IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
860       : DAG(DAG), TII(TII) {}
861 
862   virtual ~IGLPStrategy() = default;
863 };
864 
865 class MFMASmallGemmOpt final : public IGLPStrategy {
866 private:
867 public:
868   bool applyIGLPStrategy(
869       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
870       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
871       AMDGPU::SchedulingPhase Phase) override;
872 
873   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
874                            AMDGPU::SchedulingPhase Phase) override {
875     return true;
876   }
877 
878   MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
879       : IGLPStrategy(DAG, TII) {
880     IsBottomUp = true;
881   }
882 };
883 
884 bool MFMASmallGemmOpt::applyIGLPStrategy(
885     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
886     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
887     AMDGPU::SchedulingPhase Phase) {
888   // Count the number of MFMA instructions.
889   unsigned MFMACount = 0;
890   for (const MachineInstr &I : *DAG)
891     if (TII->isMFMAorWMMA(I))
892       ++MFMACount;
893 
894   const unsigned PipelineSyncID = 0;
895   SchedGroup *SG = nullptr;
896   for (unsigned I = 0; I < MFMACount * 3; ++I) {
897     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
898         SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
899     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
900 
901     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
902         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
903     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
904   }
905 
906   return true;
907 }
908 
909 class MFMAExpInterleaveOpt final : public IGLPStrategy {
910 private:
911   // The count of TRANS SUs involved in the interleaved pipeline
912   static unsigned TransPipeCount;
913   // The count of MFMA SUs involved in the interleaved pipeline
914   static unsigned MFMAPipeCount;
915   // The count of Add SUs involved in the interleaved pipeline
916   static unsigned AddPipeCount;
917   // The number of transitive MFMA successors for each TRANS SU
918   static unsigned MFMAEnablement;
919   // The number of transitive TRANS predecessors for each MFMA SU
920   static unsigned ExpRequirement;
921   // The count of independent "chains" of MFMA instructions in the pipeline
922   static unsigned MFMAChains;
923   // The length of each independent "chain" of MFMA instructions
924   static unsigned MFMAChainLength;
925   // Whether or not the pipeline has V_CVT instructions
926   static bool HasCvt;
927   // Whether or not there are instructions between the TRANS instruction and
928   // V_CVT
929   static bool HasChainBetweenCvt;
930   // The first occuring DS_READ which feeds an MFMA chain
931   static std::optional<unsigned> FirstPipeDSR;
932   // The MFMAPipe SUs with no MFMA predecessors
933   SmallVector<SUnit *, 4> MFMAChainSeeds;
934   // Compute the heuristics for the pipeline, returning whether or not the DAG
935   // is well formatted for the mutation
936   bool analyzeDAG(const SIInstrInfo *TII);
937 
938   /// Whether or not the instruction is a transitive predecessor of an MFMA
939   /// instruction
940   class IsPipeExp final : public InstructionRule {
941   public:
942     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
943                SmallVectorImpl<SchedGroup> &SyncPipe) override {
944 
945       auto DAG = SyncPipe[0].DAG;
946 
947       if (Cache->empty()) {
948         auto I = DAG->SUnits.rbegin();
949         auto E = DAG->SUnits.rend();
950         for (; I != E; I++) {
951           if (TII->isMFMAorWMMA(*I->getInstr()))
952             Cache->push_back(&*I);
953         }
954         if (Cache->empty())
955           return false;
956       }
957 
958       auto Reaches = (std::any_of(
959           Cache->begin(), Cache->end(), [&SU, &DAG](SUnit *TargetSU) {
960             return DAG->IsReachable(TargetSU, const_cast<SUnit *>(SU));
961           }));
962 
963       return Reaches;
964     }
965     IsPipeExp(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
966         : InstructionRule(TII, SGID, NeedsCache) {}
967   };
968 
969   /// Whether or not the instruction is a transitive predecessor of the
970   /// \p Number th MFMA of the MFMAs occuring after a TRANS instruction
971   class EnablesNthMFMA final : public InstructionRule {
972   private:
973     unsigned Number = 1;
974 
975   public:
976     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
977                SmallVectorImpl<SchedGroup> &SyncPipe) override {
978       bool FoundTrans = false;
979       unsigned Counter = 1;
980       auto DAG = SyncPipe[0].DAG;
981 
982       if (Cache->empty()) {
983         SmallVector<SUnit *, 8> Worklist;
984 
985         auto I = DAG->SUnits.begin();
986         auto E = DAG->SUnits.end();
987         for (; I != E; I++) {
988           if (FoundTrans && TII->isMFMAorWMMA(*I->getInstr())) {
989             if (Counter == Number) {
990               Cache->push_back(&*I);
991               break;
992             }
993             ++Counter;
994           }
995           if (!FoundTrans && TII->isTRANS(I->getInstr()->getOpcode()))
996             FoundTrans = true;
997         }
998         if (Cache->empty())
999           return false;
1000       }
1001 
1002       return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
1003     }
1004 
1005     EnablesNthMFMA(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
1006                    bool NeedsCache = false)
1007         : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
1008   };
1009 
1010   /// Whether or not the instruction enables the exact MFMA that is the \p
1011   /// Number th MFMA in the chain starting with \p ChainSeed
1012   class EnablesNthMFMAInChain final : public InstructionRule {
1013   private:
1014     unsigned Number = 1;
1015     SUnit *ChainSeed;
1016 
1017   public:
1018     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1019                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1020       auto DAG = SyncPipe[0].DAG;
1021 
1022       if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
1023         return false;
1024 
1025       if (Cache->empty()) {
1026         auto TempSU = ChainSeed;
1027         auto Depth = Number;
1028         while (Depth > 0) {
1029           --Depth;
1030           bool Found = false;
1031           for (auto &Succ : TempSU->Succs) {
1032             if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
1033               TempSU = Succ.getSUnit();
1034               Found = true;
1035               break;
1036             }
1037           }
1038           if (!Found)
1039             return false;
1040         }
1041 
1042         Cache->push_back(TempSU);
1043       }
1044       // If we failed to find the instruction to be placed into the cache, we
1045       // would have already exited.
1046       assert(!Cache->empty());
1047 
1048       return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
1049     }
1050 
1051     EnablesNthMFMAInChain(unsigned Number, SUnit *ChainSeed,
1052                           const SIInstrInfo *TII, unsigned SGID,
1053                           bool NeedsCache = false)
1054         : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1055           ChainSeed(ChainSeed) {}
1056   };
1057 
1058   /// Whether or not the instruction has less than \p Size immediate successors.
1059   /// If \p HasIntermediary is true, this tests also whether all successors of
1060   /// the SUnit have less than \p Size successors.
1061   class LessThanNSuccs final : public InstructionRule {
1062   private:
1063     unsigned Size = 1;
1064     bool HasIntermediary = false;
1065 
1066   public:
1067     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1068                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1069       if (!SyncPipe.size())
1070         return false;
1071 
1072       auto SuccSize = std::count_if(
1073           SU->Succs.begin(), SU->Succs.end(),
1074           [](const SDep &Succ) { return Succ.getKind() == SDep::Data; });
1075       if (SuccSize >= Size)
1076         return false;
1077 
1078       if (HasIntermediary) {
1079         for (auto Succ : SU->Succs) {
1080           auto SuccSize = std::count_if(
1081               Succ.getSUnit()->Succs.begin(), Succ.getSUnit()->Succs.end(),
1082               [](const SDep &SuccSucc) {
1083                 return SuccSucc.getKind() == SDep::Data;
1084               });
1085           if (SuccSize >= Size)
1086             return false;
1087         }
1088       }
1089 
1090       return true;
1091     }
1092     LessThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
1093                    bool HasIntermediary = false, bool NeedsCache = false)
1094         : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1095           HasIntermediary(HasIntermediary) {}
1096   };
1097 
1098   /// Whether or not the instruction has greater than or equal to \p Size
1099   /// immediate successors. If \p HasIntermediary is true, this tests also
1100   /// whether all successors of the SUnit have greater than or equal to \p Size
1101   /// successors.
1102   class GreaterThanOrEqualToNSuccs final : public InstructionRule {
1103   private:
1104     unsigned Size = 1;
1105     bool HasIntermediary = false;
1106 
1107   public:
1108     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1109                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1110       if (!SyncPipe.size())
1111         return false;
1112 
1113       auto SuccSize = std::count_if(
1114           SU->Succs.begin(), SU->Succs.end(),
1115           [](const SDep &Succ) { return Succ.getKind() == SDep::Data; });
1116       if (SuccSize >= Size)
1117         return true;
1118 
1119       if (HasIntermediary) {
1120         for (auto Succ : SU->Succs) {
1121           auto SuccSize = std::count_if(
1122               Succ.getSUnit()->Succs.begin(), Succ.getSUnit()->Succs.end(),
1123               [](const SDep &SuccSucc) {
1124                 return SuccSucc.getKind() == SDep::Data;
1125               });
1126           if (SuccSize >= Size)
1127             return true;
1128         }
1129       }
1130 
1131       return false;
1132     }
1133     GreaterThanOrEqualToNSuccs(unsigned Size, const SIInstrInfo *TII,
1134                                unsigned SGID, bool HasIntermediary = false,
1135                                bool NeedsCache = false)
1136         : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1137           HasIntermediary(HasIntermediary) {}
1138   };
1139 
1140   // Whether or not the instruction is a relevant V_CVT instruction.
1141   class IsCvt final : public InstructionRule {
1142   public:
1143     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1144                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1145       auto Opc = SU->getInstr()->getOpcode();
1146       return Opc == AMDGPU::V_CVT_F16_F32_e32 ||
1147              Opc == AMDGPU::V_CVT_I32_F32_e32;
1148     }
1149     IsCvt(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1150         : InstructionRule(TII, SGID, NeedsCache) {}
1151   };
1152 
1153   // Whether or not the instruction is FMA_F32.
1154   class IsFMA final : public InstructionRule {
1155   public:
1156     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1157                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1158       return SU->getInstr()->getOpcode() == AMDGPU::V_FMA_F32_e64 ||
1159              SU->getInstr()->getOpcode() == AMDGPU::V_PK_FMA_F32;
1160     }
1161     IsFMA(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1162         : InstructionRule(TII, SGID, NeedsCache) {}
1163   };
1164 
1165   // Whether or not the instruction is a V_ADD_F32 instruction.
1166   class IsPipeAdd final : public InstructionRule {
1167   public:
1168     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1169                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1170       return SU->getInstr()->getOpcode() == AMDGPU::V_ADD_F32_e32;
1171     }
1172     IsPipeAdd(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1173         : InstructionRule(TII, SGID, NeedsCache) {}
1174   };
1175 
1176   /// Whether or not the instruction is an immediate RAW successor
1177   /// of the SchedGroup \p Distance steps before.
1178   class IsSuccOfPrevNthGroup final : public InstructionRule {
1179   private:
1180     unsigned Distance = 1;
1181 
1182   public:
1183     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1184                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1185       SchedGroup *OtherGroup = nullptr;
1186       if (!SyncPipe.size())
1187         return false;
1188 
1189       for (auto &PipeSG : SyncPipe) {
1190         if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1191           OtherGroup = &PipeSG;
1192       }
1193 
1194       if (!OtherGroup)
1195         return false;
1196       if (!OtherGroup->Collection.size())
1197         return true;
1198 
1199       for (auto &OtherEle : OtherGroup->Collection) {
1200         for (auto &Succ : OtherEle->Succs) {
1201           if (Succ.getSUnit() == SU && Succ.getKind() == SDep::Data)
1202             return true;
1203         }
1204       }
1205 
1206       return false;
1207     }
1208     IsSuccOfPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1209                          unsigned SGID, bool NeedsCache = false)
1210         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1211   };
1212 
1213   /// Whether or not the instruction is a transitive successor of any
1214   /// instruction the the SchedGroup \p Distance steps before.
1215   class IsReachableFromPrevNthGroup final : public InstructionRule {
1216   private:
1217     unsigned Distance = 1;
1218 
1219   public:
1220     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1221                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1222       SchedGroup *OtherGroup = nullptr;
1223       if (!SyncPipe.size())
1224         return false;
1225 
1226       for (auto &PipeSG : SyncPipe) {
1227         if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1228           OtherGroup = &PipeSG;
1229       }
1230 
1231       if (!OtherGroup)
1232         return false;
1233       if (!OtherGroup->Collection.size())
1234         return true;
1235 
1236       auto DAG = SyncPipe[0].DAG;
1237 
1238       for (auto &OtherEle : OtherGroup->Collection)
1239         if (DAG->IsReachable(const_cast<SUnit *>(SU), OtherEle))
1240           return true;
1241 
1242       return false;
1243     }
1244     IsReachableFromPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1245                                 unsigned SGID, bool NeedsCache = false)
1246         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1247   };
1248 
1249   /// Whether or not the instruction occurs after the SU with NodeNUm \p Number
1250   class OccursAtOrAfterNode final : public InstructionRule {
1251   private:
1252     unsigned Number = 1;
1253 
1254   public:
1255     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1256                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1257 
1258       return SU->NodeNum >= Number;
1259     }
1260     OccursAtOrAfterNode(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
1261                         bool NeedsCache = false)
1262         : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
1263   };
1264 
1265   /// Whether or not the SU is exactly the \p Number th MFMA in the chain
1266   /// starting with \p ChainSeed
1267   class IsExactMFMA final : public InstructionRule {
1268   private:
1269     unsigned Number = 1;
1270     SUnit *ChainSeed;
1271 
1272   public:
1273     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1274                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1275       if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
1276         return false;
1277 
1278       if (Cache->empty()) {
1279         auto TempSU = ChainSeed;
1280         auto Depth = Number;
1281         while (Depth > 0) {
1282           --Depth;
1283           bool Found = false;
1284           for (auto &Succ : TempSU->Succs) {
1285             if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
1286               TempSU = Succ.getSUnit();
1287               Found = true;
1288               break;
1289             }
1290           }
1291           if (!Found) {
1292             return false;
1293           }
1294         }
1295         Cache->push_back(TempSU);
1296       }
1297       // If we failed to find the instruction to be placed into the cache, we
1298       // would have already exited.
1299       assert(!Cache->empty());
1300 
1301       return (*Cache)[0] == SU;
1302     }
1303 
1304     IsExactMFMA(unsigned Number, SUnit *ChainSeed, const SIInstrInfo *TII,
1305                 unsigned SGID, bool NeedsCache = false)
1306         : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1307           ChainSeed(ChainSeed) {}
1308   };
1309 
1310   // Whether the instruction occurs after the first TRANS instruction. This
1311   // implies the instruction can not be a predecessor of the first TRANS
1312   // insruction
1313   class OccursAfterExp final : public InstructionRule {
1314   public:
1315     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1316                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1317 
1318       SmallVector<SUnit *, 12> Worklist;
1319       auto DAG = SyncPipe[0].DAG;
1320       if (Cache->empty()) {
1321         for (auto &SU : DAG->SUnits)
1322           if (TII->isTRANS(SU.getInstr()->getOpcode())) {
1323             Cache->push_back(&SU);
1324             break;
1325           }
1326         if (Cache->empty())
1327           return false;
1328       }
1329 
1330       return SU->NodeNum > (*Cache)[0]->NodeNum;
1331     }
1332 
1333     OccursAfterExp(const SIInstrInfo *TII, unsigned SGID,
1334                    bool NeedsCache = false)
1335         : InstructionRule(TII, SGID, NeedsCache) {}
1336   };
1337 
1338 public:
1339   bool applyIGLPStrategy(
1340       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1341       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1342       AMDGPU::SchedulingPhase Phase) override;
1343 
1344   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1345                            AMDGPU::SchedulingPhase Phase) override;
1346 
1347   MFMAExpInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1348       : IGLPStrategy(DAG, TII) {
1349     IsBottomUp = false;
1350   }
1351 };
1352 
1353 unsigned MFMAExpInterleaveOpt::TransPipeCount = 0;
1354 unsigned MFMAExpInterleaveOpt::MFMAPipeCount = 0;
1355 unsigned MFMAExpInterleaveOpt::AddPipeCount = 0;
1356 unsigned MFMAExpInterleaveOpt::MFMAEnablement = 0;
1357 unsigned MFMAExpInterleaveOpt::ExpRequirement = 0;
1358 unsigned MFMAExpInterleaveOpt::MFMAChains = 0;
1359 unsigned MFMAExpInterleaveOpt::MFMAChainLength = 0;
1360 bool MFMAExpInterleaveOpt::HasCvt = false;
1361 bool MFMAExpInterleaveOpt::HasChainBetweenCvt = false;
1362 std::optional<unsigned> MFMAExpInterleaveOpt::FirstPipeDSR = std::nullopt;
1363 
1364 bool MFMAExpInterleaveOpt::analyzeDAG(const SIInstrInfo *TII) {
1365   SmallVector<SUnit *, 10> ExpPipeCands;
1366   SmallVector<SUnit *, 10> MFMAPipeCands;
1367   SmallVector<SUnit *, 10> MFMAPipeSUs;
1368   SmallVector<SUnit *, 10> PackSUs;
1369   SmallVector<SUnit *, 10> CvtSUs;
1370 
1371   auto isBitPack = [](unsigned Opc) {
1372     return Opc == AMDGPU::V_PACK_B32_F16_e64 || Opc == AMDGPU::V_PERM_B32_e64;
1373   };
1374 
1375   auto isCvt = [](unsigned Opc) {
1376     return Opc == AMDGPU::V_CVT_F16_F32_e32 || Opc == AMDGPU::V_CVT_I32_F32_e32;
1377   };
1378 
1379   auto isAdd = [](unsigned Opc) { return Opc == AMDGPU::V_ADD_F32_e32; };
1380 
1381   AddPipeCount = 0;
1382   for (SUnit &SU : DAG->SUnits) {
1383     auto Opc = SU.getInstr()->getOpcode();
1384     if (TII->isTRANS(Opc)) {
1385       // Avoid counting a potential bonus V_EXP which all the MFMA depend on
1386       if (SU.Succs.size() >= 7)
1387         continue;
1388       for (auto &Succ : SU.Succs) {
1389         if (Succ.getSUnit()->Succs.size() >= 7)
1390           continue;
1391       }
1392       ExpPipeCands.push_back(&SU);
1393     }
1394 
1395     if (TII->isMFMAorWMMA(*SU.getInstr()))
1396       MFMAPipeCands.push_back(&SU);
1397 
1398     if (isBitPack(Opc))
1399       PackSUs.push_back(&SU);
1400 
1401     if (isCvt(Opc))
1402       CvtSUs.push_back(&SU);
1403 
1404     if (isAdd(Opc))
1405       ++AddPipeCount;
1406   }
1407 
1408   if (!(PackSUs.size() && MFMAPipeCands.size() && ExpPipeCands.size()))
1409     return false;
1410 
1411   TransPipeCount = 0;
1412 
1413   std::optional<SUnit *> TempMFMA;
1414   std::optional<SUnit *> TempExp;
1415   // Count the number of EXPs that reach an MFMA
1416   for (auto &PredSU : ExpPipeCands) {
1417     for (auto &SuccSU : MFMAPipeCands) {
1418       if (DAG->IsReachable(SuccSU, PredSU)) {
1419         if (!TempExp) {
1420           TempExp = PredSU;
1421           TempMFMA = SuccSU;
1422         }
1423         MFMAPipeSUs.push_back(SuccSU);
1424         ++TransPipeCount;
1425         break;
1426       }
1427     }
1428   }
1429 
1430   if (!(TempExp && TempMFMA))
1431     return false;
1432 
1433   HasChainBetweenCvt =
1434       std::find_if((*TempExp)->Succs.begin(), (*TempExp)->Succs.end(),
1435                    [&isCvt](SDep &Succ) {
1436                      return isCvt(Succ.getSUnit()->getInstr()->getOpcode());
1437                    }) == (*TempExp)->Succs.end();
1438 
1439   // Count the number of MFMAs that are reached by an EXP
1440   for (auto &SuccSU : MFMAPipeCands) {
1441     if (MFMAPipeSUs.size() &&
1442         std::find_if(MFMAPipeSUs.begin(), MFMAPipeSUs.end(),
1443                      [&SuccSU](SUnit *PotentialMatch) {
1444                        return PotentialMatch->NodeNum == SuccSU->NodeNum;
1445                      }) != MFMAPipeSUs.end())
1446       continue;
1447 
1448     for (auto &PredSU : ExpPipeCands) {
1449       if (DAG->IsReachable(SuccSU, PredSU)) {
1450         MFMAPipeSUs.push_back(SuccSU);
1451         break;
1452       }
1453     }
1454   }
1455 
1456   MFMAPipeCount = MFMAPipeSUs.size();
1457 
1458   assert(TempExp && TempMFMA);
1459   assert(MFMAPipeCount > 0);
1460 
1461   std::optional<SUnit *> TempCvt;
1462   for (auto &SuccSU : CvtSUs) {
1463     if (DAG->IsReachable(SuccSU, *TempExp)) {
1464       TempCvt = SuccSU;
1465       break;
1466     }
1467   }
1468 
1469   HasCvt = false;
1470   if (TempCvt.has_value()) {
1471     for (auto &SuccSU : MFMAPipeSUs) {
1472       if (DAG->IsReachable(SuccSU, *TempCvt)) {
1473         HasCvt = true;
1474         break;
1475       }
1476     }
1477   }
1478 
1479   MFMAChains = 0;
1480   for (auto &MFMAPipeSU : MFMAPipeSUs) {
1481     if (is_contained(MFMAChainSeeds, MFMAPipeSU))
1482       continue;
1483     if (!std::any_of(MFMAPipeSU->Preds.begin(), MFMAPipeSU->Preds.end(),
1484                      [&TII](SDep &Succ) {
1485                        return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());
1486                      })) {
1487       MFMAChainSeeds.push_back(MFMAPipeSU);
1488       ++MFMAChains;
1489     }
1490   }
1491 
1492   if (!MFMAChains)
1493     return false;
1494 
1495   for (auto Pred : MFMAChainSeeds[0]->Preds) {
1496     if (TII->isDS(Pred.getSUnit()->getInstr()->getOpcode()) &&
1497         Pred.getSUnit()->getInstr()->mayLoad())
1498       FirstPipeDSR = Pred.getSUnit()->NodeNum;
1499   }
1500 
1501   MFMAChainLength = MFMAPipeCount / MFMAChains;
1502 
1503   // The number of bit pack operations that depend on a single V_EXP
1504   unsigned PackSuccCount = std::count_if(
1505       PackSUs.begin(), PackSUs.end(), [this, &TempExp](SUnit *VPack) {
1506         return DAG->IsReachable(VPack, *TempExp);
1507       });
1508 
1509   // The number of bit pack operations an MFMA depends on
1510   unsigned PackPredCount =
1511       std::count_if((*TempMFMA)->Preds.begin(), (*TempMFMA)->Preds.end(),
1512                     [&isBitPack](SDep &Pred) {
1513                       auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1514                       return isBitPack(Opc);
1515                     });
1516 
1517   auto PackPred =
1518       std::find_if((*TempMFMA)->Preds.begin(), (*TempMFMA)->Preds.end(),
1519                    [&isBitPack](SDep &Pred) {
1520                      auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1521                      return isBitPack(Opc);
1522                    });
1523 
1524   if (PackPred == (*TempMFMA)->Preds.end())
1525     return false;
1526 
1527   MFMAEnablement = 0;
1528   ExpRequirement = 0;
1529   // How many MFMAs depend on a single bit pack operation
1530   MFMAEnablement =
1531       std::count_if(PackPred->getSUnit()->Succs.begin(),
1532                     PackPred->getSUnit()->Succs.end(), [&TII](SDep &Succ) {
1533                       return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());
1534                     });
1535 
1536   // The number of MFMAs that depend on a single V_EXP
1537   MFMAEnablement *= PackSuccCount;
1538 
1539   // The number of V_EXPs required to resolve all dependencies for an MFMA
1540   ExpRequirement =
1541       std::count_if(ExpPipeCands.begin(), ExpPipeCands.end(),
1542                     [this, &PackPred](SUnit *ExpBase) {
1543                       return DAG->IsReachable(PackPred->getSUnit(), ExpBase);
1544                     });
1545 
1546   ExpRequirement *= PackPredCount;
1547   return true;
1548 }
1549 
1550 bool MFMAExpInterleaveOpt::shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1551                                                AMDGPU::SchedulingPhase Phase) {
1552   const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1553   const SIInstrInfo *TII = ST.getInstrInfo();
1554 
1555   if (Phase != AMDGPU::SchedulingPhase::PostRA)
1556     MFMAChainSeeds.clear();
1557   if (Phase != AMDGPU::SchedulingPhase::PostRA && !analyzeDAG(TII))
1558     return false;
1559 
1560   return true;
1561 }
1562 
1563 bool MFMAExpInterleaveOpt::applyIGLPStrategy(
1564     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1565     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1566     AMDGPU::SchedulingPhase Phase) {
1567 
1568   bool IsSmallKernelType =
1569       MFMAEnablement == 2 && ExpRequirement == 4 && TransPipeCount == 32;
1570   bool IsLargeKernelType =
1571       MFMAEnablement == 4 && ExpRequirement == 4 && TransPipeCount == 64;
1572 
1573   if (!(IsSmallKernelType || IsLargeKernelType))
1574     return false;
1575 
1576   const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1577   const SIInstrInfo *TII = ST.getInstrInfo();
1578 
1579   unsigned PipelineSyncID = 0;
1580   SchedGroup *SG = nullptr;
1581 
1582   unsigned MFMAChain = 0;
1583   unsigned PositionInChain = 0;
1584   unsigned CurrMFMAForTransPosition = 0;
1585 
1586   auto incrementTransPosition = [&MFMAChain, &PositionInChain,
1587                                  &CurrMFMAForTransPosition]() {
1588     CurrMFMAForTransPosition += MFMAEnablement;
1589     PositionInChain = (CurrMFMAForTransPosition / MFMAChains);
1590     MFMAChain = CurrMFMAForTransPosition % MFMAChains;
1591   };
1592 
1593   auto getNextTransPositionInChain = [&CurrMFMAForTransPosition]() {
1594     auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1595     return (TempMFMAForTrans / MFMAChains);
1596   };
1597 
1598   auto getNextTransMFMAChain = [&CurrMFMAForTransPosition]() {
1599     auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1600     return TempMFMAForTrans % MFMAChains;
1601   };
1602 
1603   unsigned CurrMFMAPosition = 0;
1604   unsigned MFMAChainForMFMA = 0;
1605   unsigned PositionInChainForMFMA = 0;
1606 
1607   auto incrementMFMAPosition = [&CurrMFMAPosition, &MFMAChainForMFMA,
1608                                 &PositionInChainForMFMA]() {
1609     ++CurrMFMAPosition;
1610     MFMAChainForMFMA = CurrMFMAPosition % MFMAChains;
1611     PositionInChainForMFMA = CurrMFMAPosition / MFMAChains;
1612   };
1613 
1614   bool IsPostRA = Phase == AMDGPU::SchedulingPhase::PostRA;
1615   assert(IsPostRA || MFMAChainSeeds.size() == MFMAChains);
1616 
1617   bool UsesFMA = IsSmallKernelType || !IsPostRA;
1618   bool UsesDSRead = IsLargeKernelType && !IsPostRA && FirstPipeDSR;
1619   bool UsesCvt = HasCvt && (IsSmallKernelType || !IsPostRA);
1620   bool UsesVALU = IsSmallKernelType;
1621 
1622   // PHASE 1: "Prefetch"
1623   if (UsesFMA) {
1624     // First Round FMA
1625     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1626         SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1627     if (!IsPostRA && MFMAChains) {
1628       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1629           PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1630           true));
1631     } else
1632       SG->addRule(
1633           std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1634     SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1635     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1636 
1637     // Second Round FMA
1638     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1639         SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1640     if (!IsPostRA && MFMAChains) {
1641       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1642           getNextTransPositionInChain(),
1643           MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1644     } else
1645       SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1646                                                    SG->getSGID(), true));
1647     SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1648     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1649   }
1650 
1651   if (UsesDSRead) {
1652     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1653         SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1654     SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1655                                                       SG->getSGID()));
1656     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1657   }
1658 
1659   // First Round EXP
1660   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1661       SchedGroupMask::TRANS, ExpRequirement, PipelineSyncID, DAG, TII);
1662   if (!IsPostRA && MFMAChains)
1663     SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1664         PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(), true));
1665   else
1666     SG->addRule(std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1667   SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1668   SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1669                                                HasChainBetweenCvt));
1670   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1671 
1672   incrementTransPosition();
1673 
1674   // First Round CVT, Third Round FMA, Second Round EXP; interleaved
1675   for (unsigned I = 0; I < ExpRequirement; I++) {
1676     // First Round CVT
1677     if (UsesCvt) {
1678       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1679           SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1680       SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1681       if (HasChainBetweenCvt)
1682         SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1683             1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1684       else
1685         SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(
1686             1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1687       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1688     }
1689 
1690     // Third Round FMA
1691     if (UsesFMA) {
1692       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1693           SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1694       if (!IsPostRA && MFMAChains) {
1695         SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1696             getNextTransPositionInChain(),
1697             MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1698       } else
1699         SG->addRule(std::make_shared<EnablesNthMFMA>(2 * MFMAEnablement + 1,
1700                                                      TII, SG->getSGID(), true));
1701       SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1702       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1703     }
1704 
1705     // Second Round EXP
1706     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1707         SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1708     if (!IsPostRA && MFMAChains)
1709       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1710           PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1711           true));
1712     else
1713       SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1714                                                    SG->getSGID(), true));
1715     SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1716     SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1717                                                  HasChainBetweenCvt));
1718     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1719   }
1720 
1721   // The "extra" EXP which enables all MFMA
1722   // TODO: UsesExtraExp
1723   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1724       SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1725   SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1726   SG->addRule(std::make_shared<GreaterThanOrEqualToNSuccs>(
1727       8, TII, SG->getSGID(), HasChainBetweenCvt));
1728   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1729 
1730   // PHASE 2: Main Interleave Loop
1731 
1732   // The number of MFMAs per iteration
1733   unsigned MFMARatio =
1734       MFMAEnablement > ExpRequirement ? MFMAEnablement / ExpRequirement : 1;
1735   // The number of Exps per iteration
1736   unsigned ExpRatio =
1737       MFMAEnablement > ExpRequirement ? 1 : ExpRequirement / MFMAEnablement;
1738   // The reamaining Exps
1739   unsigned RemainingExp = TransPipeCount > (2 * ExpRequirement)
1740                               ? TransPipeCount - (2 * ExpRequirement)
1741                               : 0;
1742   unsigned ExpLoopCount = RemainingExp / ExpRatio;
1743   // In loop MFMAs
1744   unsigned MFMAInLoop = MFMAPipeCount > (MFMAEnablement * 2)
1745                             ? MFMAPipeCount - (MFMAEnablement * 2)
1746                             : 0;
1747   unsigned MFMALoopCount = MFMAInLoop / MFMARatio;
1748   unsigned VALUOps =
1749       AddPipeCount < MFMAPipeCount ? 1 : AddPipeCount / MFMAPipeCount;
1750   unsigned LoopSize = std::min(ExpLoopCount, MFMALoopCount);
1751 
1752   for (unsigned I = 0; I < LoopSize; I++) {
1753     if (!(I * ExpRatio % ExpRequirement))
1754       incrementTransPosition();
1755 
1756     // Round N MFMA
1757     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1758         SchedGroupMask::MFMA, MFMARatio, PipelineSyncID, DAG, TII);
1759     if (!IsPostRA && MFMAChains)
1760       SG->addRule(std::make_shared<IsExactMFMA>(
1761           PositionInChainForMFMA, MFMAChainSeeds[MFMAChainForMFMA], TII,
1762           SG->getSGID(), true));
1763     else
1764       SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1765     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1766     incrementMFMAPosition();
1767 
1768     if (UsesVALU) {
1769       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1770           SchedGroupMask::VALU, VALUOps, PipelineSyncID, DAG, TII);
1771       SG->addRule(std::make_shared<IsPipeAdd>(TII, SG->getSGID()));
1772       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1773     }
1774 
1775     if (UsesDSRead && !(I % 4)) {
1776       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1777           SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1778       SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1779                                                         SG->getSGID()));
1780       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1781     }
1782 
1783     // CVT, EXP, FMA Interleaving
1784     for (unsigned J = 0; J < ExpRatio; J++) {
1785       auto MFMAOffset = (1 + UsesVALU) * MFMARatio * (I + 1);
1786       auto MaxMFMAOffset =
1787           (1 + UsesVALU) * ExpRequirement * MFMARatio / ExpRatio;
1788 
1789       // Round N + 1 CVT
1790       if (UsesCvt) {
1791         SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1792             SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1793         SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1794         auto BaseDiff = (2 + UsesFMA) * (ExpRequirement - 1) + 1;
1795         auto DSROffset = I / 4 + 1;
1796         auto MaxDSROffset = MaxMFMAOffset / 4;
1797         // TODO: UsesExtraExp
1798         auto ExpOffset = I * ExpRatio + J >= ExpRequirement ? 0 : 1;
1799         auto CurrentOffset = UsesDSRead * std::min(MaxDSROffset, DSROffset) +
1800                              std::min(MaxMFMAOffset, MFMAOffset) + BaseDiff +
1801                              ExpOffset;
1802         if (HasChainBetweenCvt)
1803           SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1804               CurrentOffset, TII, SG->getSGID()));
1805         else
1806           SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(CurrentOffset, TII,
1807                                                              SG->getSGID()));
1808         SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1809       }
1810 
1811       // Round N + 3 FMA
1812       if (UsesFMA) {
1813         SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1814             SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1815         if (!IsPostRA && MFMAChains)
1816           SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1817               getNextTransPositionInChain(),
1818               MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(),
1819               true));
1820         else
1821           SG->addRule(std::make_shared<EnablesNthMFMA>(
1822               (((I * ExpRatio + J) / ExpRequirement) + 3) * MFMAEnablement + 1,
1823               TII, SG->getSGID(), true));
1824         SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1825         SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1826       }
1827 
1828       // Round N + 2 Exp
1829       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1830           SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1831       if (!IsPostRA && MFMAChains)
1832         SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1833             PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1834             true));
1835       else
1836         SG->addRule(std::make_shared<EnablesNthMFMA>(
1837             (((I * ExpRatio + J) / ExpRequirement) + 2) * MFMAEnablement + 1,
1838             TII, SG->getSGID(), true));
1839       SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1840       SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1841                                                    HasChainBetweenCvt));
1842       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1843     }
1844   }
1845 
1846   // PHASE 3: Remaining MFMAs
1847   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1848       SchedGroupMask::MFMA, MFMAEnablement * 2, PipelineSyncID, DAG, TII);
1849   SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1850   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1851   return true;
1852 }
1853 
1854 class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1855 private:
1856   // Whether the DS_READ is a predecessor of first four MFMA in region
1857   class EnablesInitialMFMA final : public InstructionRule {
1858   public:
1859     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1860                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1861       if (!SyncPipe.size())
1862         return false;
1863       int MFMAsFound = 0;
1864       if (!Cache->size()) {
1865         for (auto &Elt : SyncPipe[0].DAG->SUnits) {
1866           if (TII->isMFMAorWMMA(*Elt.getInstr())) {
1867             ++MFMAsFound;
1868             if (MFMAsFound > 4)
1869               break;
1870             Cache->push_back(&Elt);
1871           }
1872         }
1873       }
1874 
1875       assert(Cache->size());
1876       auto DAG = SyncPipe[0].DAG;
1877       for (auto &Elt : *Cache) {
1878         if (DAG->IsReachable(Elt, const_cast<SUnit *>(SU)))
1879           return true;
1880       }
1881       return false;
1882     }
1883 
1884     EnablesInitialMFMA(const SIInstrInfo *TII, unsigned SGID,
1885                        bool NeedsCache = false)
1886         : InstructionRule(TII, SGID, NeedsCache) {}
1887   };
1888 
1889   // Whether the MI is a V_PERM and is a predecessor of a common DS_WRITE
1890   class IsPermForDSW final : public InstructionRule {
1891   public:
1892     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1893                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1894       auto MI = SU->getInstr();
1895       if (MI->getOpcode() != AMDGPU::V_PERM_B32_e64)
1896         return false;
1897 
1898       bool FitsInGroup = false;
1899       // Does the VALU have a DS_WRITE successor
1900       if (!Collection.size()) {
1901         for (auto &Succ : SU->Succs) {
1902           SUnit *SuccUnit = Succ.getSUnit();
1903           if (TII->isDS(*SuccUnit->getInstr()) &&
1904               SuccUnit->getInstr()->mayStore()) {
1905             Cache->push_back(SuccUnit);
1906             FitsInGroup = true;
1907           }
1908         }
1909         return FitsInGroup;
1910       }
1911 
1912       assert(Cache->size());
1913 
1914       // Does the VALU have a DS_WRITE successor that is the same as other
1915       // VALU already in the group. The V_PERMs will all share 1 DS_W succ
1916       return llvm::any_of(*Cache, [&SU](SUnit *Elt) {
1917         return llvm::any_of(SU->Succs, [&Elt](const SDep &ThisSucc) {
1918           return ThisSucc.getSUnit() == Elt;
1919         });
1920       });
1921     }
1922 
1923     IsPermForDSW(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1924         : InstructionRule(TII, SGID, NeedsCache) {}
1925   };
1926 
1927   // Whether the SU is a successor of any element in previous SchedGroup
1928   class IsSuccOfPrevGroup final : public InstructionRule {
1929   public:
1930     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1931                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1932       SchedGroup *OtherGroup = nullptr;
1933       for (auto &PipeSG : SyncPipe) {
1934         if ((unsigned)PipeSG.getSGID() == SGID - 1) {
1935           OtherGroup = &PipeSG;
1936         }
1937       }
1938 
1939       if (!OtherGroup)
1940         return false;
1941       if (!OtherGroup->Collection.size())
1942         return true;
1943 
1944       // Does the previous VALU have this DS_Write as a successor
1945       return (std::any_of(OtherGroup->Collection.begin(),
1946                           OtherGroup->Collection.end(), [&SU](SUnit *Elt) {
1947                             return std::any_of(Elt->Succs.begin(),
1948                                                Elt->Succs.end(),
1949                                                [&SU](SDep &Succ) {
1950                                                  return Succ.getSUnit() == SU;
1951                                                });
1952                           }));
1953     }
1954     IsSuccOfPrevGroup(const SIInstrInfo *TII, unsigned SGID,
1955                       bool NeedsCache = false)
1956         : InstructionRule(TII, SGID, NeedsCache) {}
1957   };
1958 
1959   // Whether the combined load width of group is 128 bits
1960   class VMEMSize final : public InstructionRule {
1961   public:
1962     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1963                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1964       auto MI = SU->getInstr();
1965       if (MI->getOpcode() == TargetOpcode::BUNDLE)
1966         return false;
1967       if (!Collection.size())
1968         return true;
1969 
1970       int NumBits = 0;
1971 
1972       auto TRI = TII->getRegisterInfo();
1973       auto &MRI = MI->getParent()->getParent()->getRegInfo();
1974       for (auto &Elt : Collection) {
1975         auto Op = Elt->getInstr()->getOperand(0);
1976         auto Size =
1977             TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(MRI, Op));
1978         NumBits += Size;
1979       }
1980 
1981       if (NumBits < 128) {
1982         assert(TII->isVMEM(*MI) && MI->mayLoad());
1983         if (NumBits + TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(
1984                           MRI, MI->getOperand(0))) <=
1985             128)
1986           return true;
1987       }
1988 
1989       return false;
1990     }
1991 
1992     VMEMSize(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1993         : InstructionRule(TII, SGID, NeedsCache) {}
1994   };
1995 
1996   /// Whether the SU shares a V_PERM predecessor with any SU in the SchedGroup
1997   /// that is \p Distance steps away
1998   class SharesPredWithPrevNthGroup final : public InstructionRule {
1999   private:
2000     unsigned Distance = 1;
2001 
2002   public:
2003     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
2004                SmallVectorImpl<SchedGroup> &SyncPipe) override {
2005       SchedGroup *OtherGroup = nullptr;
2006       if (!SyncPipe.size())
2007         return false;
2008 
2009       if (!Cache->size()) {
2010 
2011         for (auto &PipeSG : SyncPipe) {
2012           if ((unsigned)PipeSG.getSGID() == SGID - Distance) {
2013             OtherGroup = &PipeSG;
2014           }
2015         }
2016 
2017         if (!OtherGroup)
2018           return false;
2019         if (!OtherGroup->Collection.size())
2020           return true;
2021 
2022         for (auto &OtherEle : OtherGroup->Collection) {
2023           for (auto &Pred : OtherEle->Preds) {
2024             if (Pred.getSUnit()->getInstr()->getOpcode() ==
2025                 AMDGPU::V_PERM_B32_e64)
2026               Cache->push_back(Pred.getSUnit());
2027           }
2028         }
2029 
2030         // If the other group has no PERM preds, then this group won't share any
2031         if (!Cache->size())
2032           return false;
2033       }
2034 
2035       auto DAG = SyncPipe[0].DAG;
2036       // Does the previous DS_WRITE share a V_PERM predecessor with this
2037       // VMEM_READ
2038       return llvm::any_of(*Cache, [&SU, &DAG](SUnit *Elt) {
2039         return DAG->IsReachable(const_cast<SUnit *>(SU), Elt);
2040       });
2041     }
2042     SharesPredWithPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
2043                                unsigned SGID, bool NeedsCache = false)
2044         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
2045   };
2046 
2047 public:
2048   bool applyIGLPStrategy(
2049       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2050       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2051       AMDGPU::SchedulingPhase Phase) override;
2052 
2053   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
2054                            AMDGPU::SchedulingPhase Phase) override {
2055     return true;
2056   }
2057 
2058   MFMASmallGemmSingleWaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
2059       : IGLPStrategy(DAG, TII) {
2060     IsBottomUp = false;
2061   }
2062 };
2063 
// DS_WRITE statistics used by MFMASmallGemmSingleWaveOpt::applyIGLPStrategy.
// They are written only while scheduling in the Initial phase and are read
// back unchanged by later phases, which is why they live at file scope rather
// than in the strategy object.
// NOTE(review): being process-wide mutable state, they are shared by every
// function scheduled in this process; confirm they are reset between regions.

// Number of DS_WRITEs seen in the region.
static unsigned DSWCount{0};
// Number of those DS_WRITEs with at least one V_PERM_B32_e64 predecessor.
static unsigned DSWWithPermCount{0};
// Number of V_PERM-fed DS_WRITEs paired up because they use the same VMEM_READs.
static unsigned DSWWithSharedVMEMCount{0};
2067 
2068 bool MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
2069     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2070     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2071     AMDGPU::SchedulingPhase Phase) {
2072   unsigned MFMACount = 0;
2073   unsigned DSRCount = 0;
2074 
2075   bool IsInitial = Phase == AMDGPU::SchedulingPhase::Initial;
2076 
2077   assert((!IsInitial || (DSWCount == 0 && DSWWithPermCount == 0 &&
2078                          DSWWithSharedVMEMCount == 0)) &&
2079          "DSWCounters should be zero in pre-RA scheduling!");
2080   SmallVector<SUnit *, 6> DSWithPerms;
2081   for (auto &SU : DAG->SUnits) {
2082     auto I = SU.getInstr();
2083     if (TII->isMFMAorWMMA(*I))
2084       ++MFMACount;
2085     else if (TII->isDS(*I)) {
2086       if (I->mayLoad())
2087         ++DSRCount;
2088       else if (I->mayStore() && IsInitial) {
2089         ++DSWCount;
2090         for (auto Pred : SU.Preds) {
2091           if (Pred.getSUnit()->getInstr()->getOpcode() ==
2092               AMDGPU::V_PERM_B32_e64) {
2093             DSWithPerms.push_back(&SU);
2094             break;
2095           }
2096         }
2097       }
2098     }
2099   }
2100 
2101   if (IsInitial) {
2102     DSWWithPermCount = DSWithPerms.size();
2103     auto I = DSWithPerms.begin();
2104     auto E = DSWithPerms.end();
2105 
2106     // Get the count of DS_WRITES with V_PERM predecessors which
2107     // have loop carried dependencies (WAR) on the same VMEM_READs.
2108     // We consider partial overlap as a miss -- in other words,
2109     // for a given DS_W, we only consider another DS_W as matching
2110     // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
2111     // for every V_PERM pred of this DS_W.
2112     DenseMap<MachineInstr *, SUnit *> VMEMLookup;
2113     SmallVector<SUnit *, 6> Counted;
2114     for (; I != E; I++) {
2115       SUnit *Cand = nullptr;
2116       bool MissedAny = false;
2117       for (auto &Pred : (*I)->Preds) {
2118         if (Pred.getSUnit()->getInstr()->getOpcode() != AMDGPU::V_PERM_B32_e64)
2119           continue;
2120 
2121         if (Cand && llvm::is_contained(Counted, Cand))
2122           break;
2123 
2124         for (auto &Succ : Pred.getSUnit()->Succs) {
2125           auto MI = Succ.getSUnit()->getInstr();
2126           if (!TII->isVMEM(*MI) || !MI->mayLoad())
2127             continue;
2128 
2129           if (MissedAny || !VMEMLookup.size()) {
2130             MissedAny = true;
2131             VMEMLookup[MI] = *I;
2132             continue;
2133           }
2134 
2135           if (!VMEMLookup.contains(MI)) {
2136             MissedAny = true;
2137             VMEMLookup[MI] = *I;
2138             continue;
2139           }
2140 
2141           Cand = VMEMLookup[MI];
2142           if (llvm::is_contained(Counted, Cand)) {
2143             MissedAny = true;
2144             break;
2145           }
2146         }
2147       }
2148       if (!MissedAny && Cand) {
2149         DSWWithSharedVMEMCount += 2;
2150         Counted.push_back(Cand);
2151         Counted.push_back(*I);
2152       }
2153     }
2154   }
2155 
2156   assert(DSWWithSharedVMEMCount <= DSWWithPermCount);
2157   SchedGroup *SG;
2158   unsigned PipelineSyncID = 0;
2159   // For kernels with V_PERM, there are enough VALU to mix in between MFMAs
2160   if (DSWWithPermCount) {
2161     for (unsigned I = 0; I < MFMACount; I++) {
2162       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2163           SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2164       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2165 
2166       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2167           SchedGroupMask::VALU, 2, PipelineSyncID, DAG, TII);
2168       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2169     }
2170   }
2171 
2172   PipelineSyncID = 1;
2173   // Phase 1: Break up DS_READ and MFMA clusters.
2174   // First DS_READ to make ready initial MFMA, then interleave MFMA with DS_READ
2175   // prefetch
2176 
2177   // Make ready initial MFMA
2178   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2179       SchedGroupMask::DS_READ, 4, PipelineSyncID, DAG, TII);
2180   SG->addRule(std::make_shared<EnablesInitialMFMA>(TII, SG->getSGID(), true));
2181   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2182 
2183   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2184       SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2185   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2186 
2187   // Interleave MFMA with DS_READ prefetch
2188   for (unsigned I = 0; I < DSRCount - 4; ++I) {
2189     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2190         SchedGroupMask::DS_READ, 1, PipelineSyncID, DAG, TII);
2191     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2192 
2193     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2194         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2195     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2196   }
2197 
2198   // Phase 2a: Loop carried dependency with V_PERM
2199   // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
2200   // depend on. Interleave MFMA to keep XDL unit busy throughout.
2201   for (unsigned I = 0; I < DSWWithPermCount - DSWWithSharedVMEMCount; ++I) {
2202     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2203         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2204     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2205     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2206 
2207     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2208         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2209     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2210     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2211 
2212     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2213         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2214     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2215         1, TII, SG->getSGID(), true));
2216     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2217     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2218 
2219     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2220         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2221     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2222 
2223     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2224         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2225     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2226         3, TII, SG->getSGID(), true));
2227     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2228     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2229 
2230     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2231         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2232     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2233   }
2234 
2235   // Phase 2b: Loop carried dependency without V_PERM
2236   // Schedule DS_WRITE as closely as possible to the VMEM_READ they depend on.
2237   // Interleave MFMA to keep XDL unit busy throughout.
2238   for (unsigned I = 0; I < DSWCount - DSWWithPermCount; I++) {
2239     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2240         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2241     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2242 
2243     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2244         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2245     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2246     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2247 
2248     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2249         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2250     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2251   }
2252 
2253   // Phase 2c: Loop carried dependency with V_PERM, VMEM_READs are
2254   // ultimately used by two DS_WRITE
2255   // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
2256   // depend on. Interleave MFMA to keep XDL unit busy throughout.
2257 
2258   for (unsigned I = 0; I < DSWWithSharedVMEMCount; ++I) {
2259     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2260         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2261     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2262     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2263 
2264     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2265         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2266     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2267     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2268 
2269     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2270         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2271     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2272 
2273     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2274         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2275     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2276     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2277 
2278     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2279         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2280     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2281     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2282 
2283     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2284         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2285     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2286 
2287     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2288         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2289     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2290         2, TII, SG->getSGID(), true));
2291     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2292     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2293 
2294     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2295         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2296     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2297 
2298     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2299         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2300     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2301         4, TII, SG->getSGID(), true));
2302     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2303     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2304 
2305     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2306         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2307     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2308   }
2309 
2310   return true;
2311 }
2312 
2313 static std::unique_ptr<IGLPStrategy>
2314 createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
2315                    const SIInstrInfo *TII) {
2316   switch (ID) {
2317   case MFMASmallGemmOptID:
2318     return std::make_unique<MFMASmallGemmOpt>(DAG, TII);
2319   case MFMASmallGemmSingleWaveOptID:
2320     return std::make_unique<MFMASmallGemmSingleWaveOpt>(DAG, TII);
2321   case MFMAExpInterleave:
2322     return std::make_unique<MFMAExpInterleaveOpt>(DAG, TII);
2323   }
2324 
2325   llvm_unreachable("Unknown IGLPStrategyID");
2326 }
2327 
2328 class IGroupLPDAGMutation : public ScheduleDAGMutation {
2329 private:
2330   const SIInstrInfo *TII;
2331 
2332   ScheduleDAGMI *DAG;
2333 
2334   // Organize lists of SchedGroups by their SyncID. SchedGroups /
2335   // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
2336   // between then.
2337   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
2338 
2339   // Used to track instructions that can be mapped to multiple sched groups
2340   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
2341 
2342   // Add DAG edges that enforce SCHED_BARRIER ordering.
2343   void addSchedBarrierEdges(SUnit &SU);
2344 
2345   // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
2346   // not be reordered accross the SCHED_BARRIER. This is used for the base
2347   // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
2348   // SCHED_BARRIER will always block all instructions that can be classified
2349   // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
2350   // and may only synchronize with some SchedGroups. Returns the inverse of
2351   // Mask. SCHED_BARRIER's mask describes which instruction types should be
2352   // allowed to be scheduled across it. Invert the mask to get the
2353   // SchedGroupMask of instructions that should be barred.
2354   SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;
2355 
2356   // Create SchedGroups for a SCHED_GROUP_BARRIER.
2357   void initSchedGroupBarrierPipelineStage(
2358       std::vector<SUnit>::reverse_iterator RIter);
2359 
2360   bool initIGLPOpt(SUnit &SU);
2361 
2362 public:
2363   void apply(ScheduleDAGInstrs *DAGInstrs) override;
2364 
2365   // The order in which the PipelineSolver should process the candidate
2366   // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last
2367   // created SchedGroup first, and will consider that as the ultimate
2368   // predecessor group when linking. TOP_DOWN instead links and processes the
2369   // first created SchedGroup first.
2370   bool IsBottomUp = true;
2371 
2372   // The scheduling phase this application of IGLP corresponds with.
2373   AMDGPU::SchedulingPhase Phase = AMDGPU::SchedulingPhase::Initial;
2374 
2375   IGroupLPDAGMutation() = default;
2376   IGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) : Phase(Phase) {}
2377 };
2378 
2379 unsigned SchedGroup::NumSchedGroups = 0;
2380 
2381 bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
2382   if (A != B && DAG->canAddEdge(B, A)) {
2383     DAG->addEdge(B, SDep(A, SDep::Artificial));
2384     return true;
2385   }
2386   return false;
2387 }
2388 
2389 bool SchedGroup::canAddMI(const MachineInstr &MI) const {
2390   bool Result = false;
2391   if (MI.isMetaInstruction())
2392     Result = false;
2393 
2394   else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
2395            (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI) ||
2396             TII->isTRANS(MI)))
2397     Result = true;
2398 
2399   else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
2400            TII->isVALU(MI) && !TII->isMFMAorWMMA(MI) && !TII->isTRANS(MI))
2401     Result = true;
2402 
2403   else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
2404            TII->isSALU(MI))
2405     Result = true;
2406 
2407   else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
2408            TII->isMFMAorWMMA(MI))
2409     Result = true;
2410 
2411   else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
2412            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2413     Result = true;
2414 
2415   else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
2416            MI.mayLoad() &&
2417            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2418     Result = true;
2419 
2420   else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
2421            MI.mayStore() &&
2422            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2423     Result = true;
2424 
2425   else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
2426            TII->isDS(MI))
2427     Result = true;
2428 
2429   else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
2430            MI.mayLoad() && TII->isDS(MI))
2431     Result = true;
2432 
2433   else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
2434            MI.mayStore() && TII->isDS(MI))
2435     Result = true;
2436 
2437   else if (((SGMask & SchedGroupMask::TRANS) != SchedGroupMask::NONE) &&
2438            TII->isTRANS(MI))
2439     Result = true;
2440 
2441   LLVM_DEBUG(
2442       dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
2443              << (Result ? " could classify " : " unable to classify ") << MI);
2444 
2445   return Result;
2446 }
2447 
2448 int SchedGroup::link(SUnit &SU, bool MakePred,
2449                      std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
2450   int MissedEdges = 0;
2451   for (auto *A : Collection) {
2452     SUnit *B = &SU;
2453     if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2454       continue;
2455     if (MakePred)
2456       std::swap(A, B);
2457 
2458     if (DAG->IsReachable(B, A))
2459       continue;
2460 
2461     // tryAddEdge returns false if there is a dependency that makes adding
2462     // the A->B edge impossible, otherwise it returns true;
2463     bool Added = tryAddEdge(A, B);
2464     if (Added)
2465       AddedEdges.emplace_back(A, B);
2466     else
2467       ++MissedEdges;
2468   }
2469 
2470   return MissedEdges;
2471 }
2472 
2473 void SchedGroup::link(SUnit &SU, bool MakePred) {
2474   for (auto *A : Collection) {
2475     SUnit *B = &SU;
2476     if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2477       continue;
2478     if (MakePred)
2479       std::swap(A, B);
2480 
2481     tryAddEdge(A, B);
2482   }
2483 }
2484 
2485 void SchedGroup::link(SUnit &SU,
2486                       function_ref<bool(const SUnit *A, const SUnit *B)> P) {
2487   for (auto *A : Collection) {
2488     SUnit *B = &SU;
2489     if (P(A, B))
2490       std::swap(A, B);
2491 
2492     tryAddEdge(A, B);
2493   }
2494 }
2495 
2496 void SchedGroup::link(SchedGroup &OtherGroup) {
2497   for (auto *B : OtherGroup.Collection)
2498     link(*B);
2499 }
2500 
2501 bool SchedGroup::canAddSU(SUnit &SU) const {
2502   MachineInstr &MI = *SU.getInstr();
2503   if (MI.getOpcode() != TargetOpcode::BUNDLE)
2504     return canAddMI(MI);
2505 
2506   // Special case for bundled MIs.
2507   const MachineBasicBlock *MBB = MI.getParent();
2508   MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
2509   while (E != MBB->end() && E->isBundledWithPred())
2510     ++E;
2511 
2512   // Return true if all of the bundled MIs can be added to this group.
2513   return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });
2514 }
2515 
2516 void SchedGroup::initSchedGroup() {
2517   for (auto &SU : DAG->SUnits) {
2518     if (isFull())
2519       break;
2520 
2521     if (canAddSU(SU))
2522       add(SU);
2523   }
2524 }
2525 
2526 void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
2527                                 SUnitsToCandidateSGsMap &SyncedInstrs) {
2528   SUnit &InitSU = *RIter;
2529   for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {
2530     auto &SU = *RIter;
2531     if (isFull())
2532       break;
2533 
2534     if (canAddSU(SU))
2535       SyncedInstrs[&SU].push_back(SGID);
2536   }
2537 
2538   add(InitSU);
2539   assert(MaxSize);
2540   (*MaxSize)++;
2541 }
2542 
2543 void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {
2544   auto I = DAG->SUnits.rbegin();
2545   auto E = DAG->SUnits.rend();
2546   for (; I != E; ++I) {
2547     auto &SU = *I;
2548     if (isFull())
2549       break;
2550     if (canAddSU(SU))
2551       SyncedInstrs[&SU].push_back(SGID);
2552   }
2553 }
2554 
2555 void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
2556   const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
2557   if (!TSchedModel || DAGInstrs->SUnits.empty())
2558     return;
2559 
2560   LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
2561   const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
2562   TII = ST.getInstrInfo();
2563   DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
2564   SyncedSchedGroups.clear();
2565   SyncedInstrs.clear();
2566   bool FoundSB = false;
2567   bool FoundIGLP = false;
2568   bool ShouldApplyIGLP = false;
2569   for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
2570     unsigned Opc = R->getInstr()->getOpcode();
2571     // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
2572     if (Opc == AMDGPU::SCHED_BARRIER) {
2573       addSchedBarrierEdges(*R);
2574       FoundSB = true;
2575     } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
2576       initSchedGroupBarrierPipelineStage(R);
2577       FoundSB = true;
2578     } else if (Opc == AMDGPU::IGLP_OPT) {
2579       resetEdges(*R, DAG);
2580       if (!FoundSB && !FoundIGLP) {
2581         FoundIGLP = true;
2582         ShouldApplyIGLP = initIGLPOpt(*R);
2583       }
2584     }
2585   }
2586 
2587   if (FoundSB || (FoundIGLP && ShouldApplyIGLP)) {
2588     PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
2589     // PipelineSolver performs the mutation by adding the edges it
2590     // determined as the best
2591     PS.solve();
2592     return;
2593   }
2594 }
2595 
2596 void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
2597   MachineInstr &MI = *SchedBarrier.getInstr();
2598   assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
2599   // Remove all existing edges from the SCHED_BARRIER that were added due to the
2600   // instruction having side effects.
2601   resetEdges(SchedBarrier, DAG);
2602   LLVM_DEBUG(dbgs() << "Building SchedGroup for SchedBarrier with Mask: "
2603                     << MI.getOperand(0).getImm() << "\n");
2604   auto InvertedMask =
2605       invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());
2606   SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
2607   SG.initSchedGroup();
2608 
2609   // Preserve original instruction ordering relative to the SCHED_BARRIER.
2610   SG.link(
2611       SchedBarrier,
2612       (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
2613           const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
2614 }
2615 
2616 SchedGroupMask
2617 IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
2618   // Invert mask and erase bits for types of instructions that are implied to be
2619   // allowed past the SCHED_BARRIER.
2620   SchedGroupMask InvertedMask = ~Mask;
2621 
2622   // ALU implies VALU, SALU, MFMA, TRANS.
2623   if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
2624     InvertedMask &= ~SchedGroupMask::VALU & ~SchedGroupMask::SALU &
2625                     ~SchedGroupMask::MFMA & ~SchedGroupMask::TRANS;
2626   // VALU, SALU, MFMA, TRANS implies ALU.
2627   else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
2628            (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
2629            (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE ||
2630            (InvertedMask & SchedGroupMask::TRANS) == SchedGroupMask::NONE)
2631     InvertedMask &= ~SchedGroupMask::ALU;
2632 
2633   // VMEM implies VMEM_READ, VMEM_WRITE.
2634   if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
2635     InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;
2636   // VMEM_READ, VMEM_WRITE implies VMEM.
2637   else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
2638            (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)
2639     InvertedMask &= ~SchedGroupMask::VMEM;
2640 
2641   // DS implies DS_READ, DS_WRITE.
2642   if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
2643     InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;
2644   // DS_READ, DS_WRITE implies DS.
2645   else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
2646            (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
2647     InvertedMask &= ~SchedGroupMask::DS;
2648 
2649   LLVM_DEBUG(dbgs() << "After Inverting, SchedGroup Mask: " << (int)InvertedMask
2650                     << "\n");
2651 
2652   return InvertedMask;
2653 }
2654 
2655 void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
2656     std::vector<SUnit>::reverse_iterator RIter) {
2657   // Remove all existing edges from the SCHED_GROUP_BARRIER that were added due
2658   // to the instruction having side effects.
2659   resetEdges(*RIter, DAG);
2660   MachineInstr &SGB = *RIter->getInstr();
2661   assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
2662   int32_t SGMask = SGB.getOperand(0).getImm();
2663   int32_t Size = SGB.getOperand(1).getImm();
2664   int32_t SyncID = SGB.getOperand(2).getImm();
2665 
2666   auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,
2667                                                     Size, SyncID, DAG, TII);
2668 
2669   SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]);
2670 }
2671 
2672 bool IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
2673   IGLPStrategyID StrategyID =
2674       (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm();
2675   auto S = createIGLPStrategy(StrategyID, DAG, TII);
2676   if (!S->shouldApplyStrategy(DAG, Phase))
2677     return false;
2678 
2679   IsBottomUp = S->IsBottomUp;
2680   return S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups, Phase);
2681 }
2682 
2683 } // namespace
2684 
2685 namespace llvm {
2686 
2687 /// \p Phase specifes whether or not this is a reentry into the
2688 /// IGroupLPDAGMutation. Since there may be multiple scheduling passes on the
2689 /// same scheduling region (e.g. pre and post-RA scheduling / multiple
2690 /// scheduling "phases"), we can reenter this mutation framework more than once
2691 /// for a given region.
2692 std::unique_ptr<ScheduleDAGMutation>
2693 createIGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) {
2694   return std::make_unique<IGroupLPDAGMutation>(Phase);
2695 }
2696 
2697 } // end namespace llvm
2698