xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP  ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file This file defines a set of schedule DAG mutations that can be used to
10 // override default scheduler behavior to enforce specific scheduling patterns.
11 // They should be used in cases where runtime performance considerations, such
12 // as inter-wavefront interactions, mean that compile-time heuristics cannot
13 // predict the optimal instruction ordering, or in kernels where optimal
14 // instruction scheduling is important enough to warrant manual intervention.
15 //
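//
// These mutations are typically requested from kernel source via the AMDGPU
// scheduling builtins. A minimal sketch (illustrative only; the mask values
// follow the SchedGroupMask encoding defined below):
//
//   // Request 2 DS_READ instructions followed by 1 MFMA, in sync group 0:
//   __builtin_amdgcn_sched_group_barrier(/*Mask=*/0x100, /*Size=*/2, /*SyncID=*/0);
//   __builtin_amdgcn_sched_group_barrier(/*Mask=*/0x008, /*Size=*/1, /*SyncID=*/0);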
16 //===----------------------------------------------------------------------===//
17 
18 #include "AMDGPUIGroupLP.h"
19 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
20 #include "SIInstrInfo.h"
21 #include "SIMachineFunctionInfo.h"
22 #include "llvm/ADT/BitmaskEnum.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/CodeGen/MachineScheduler.h"
25 #include "llvm/CodeGen/TargetOpcodes.h"
26 
27 using namespace llvm;
28 
29 #define DEBUG_TYPE "igrouplp"
30 
31 namespace {
32 
33 static cl::opt<bool> EnableExactSolver(
34     "amdgpu-igrouplp-exact-solver", cl::Hidden,
35     cl::desc("Whether to use the exponential time solver to fit "
36              "the instructions to the pipeline as closely as "
37              "possible."),
38     cl::init(false));
39 
40 static cl::opt<unsigned> CutoffForExact(
41     "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
42     cl::desc("The maximum number of scheduling group conflicts "
43              "which we attempt to solve with the exponential time "
44              "exact solver. Problem sizes greater than this will "
45              "be solved by the less accurate greedy algorithm. Selecting "
46              "the solver by size is superseded by manually selecting the "
47              "solver (e.g. by amdgpu-igrouplp-exact-solver)."));
48 
49 static cl::opt<uint64_t> MaxBranchesExplored(
50     "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
51     cl::desc("The number of branches that we are willing to explore with "
52              "the exact algorithm before giving up."));
53 
54 static cl::opt<bool> UseCostHeur(
55     "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
56     cl::desc("Whether to use the cost heuristic to make choices as we "
57              "traverse the search space using the exact solver. Defaulted "
58              "to on, and if turned off, we will use the node order -- "
59              "attempting to put the later nodes in the later sched groups. "
60              "Experimentally, results are mixed, so this should be set on a "
61              "case-by-case basis."));
62 
63 // Components of the mask that determines which instruction types may be
64 // classified into a SchedGroup.
65 enum class SchedGroupMask {
66   NONE = 0u,
67   ALU = 1u << 0,
68   VALU = 1u << 1,
69   SALU = 1u << 2,
70   MFMA = 1u << 3,
71   VMEM = 1u << 4,
72   VMEM_READ = 1u << 5,
73   VMEM_WRITE = 1u << 6,
74   DS = 1u << 7,
75   DS_READ = 1u << 8,
76   DS_WRITE = 1u << 9,
77   TRANS = 1u << 10,
78   ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
79         DS_READ | DS_WRITE | TRANS,
80   LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
81 };
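// Since SchedGroupMask is marked as a bitmask enum, categories can be
// combined with bitwise operators; e.g. (an illustrative sketch) a group
// accepting any LDS access could use:
//
//   SchedGroupMask LDSMask = SchedGroupMask::DS_READ | SchedGroupMask::DS_WRITE;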
82 
83 class SchedGroup;
84 
85 // The InstructionRule class is used to enact a filter which determines whether
86 // or not an SU maps to a given SchedGroup. It contains complementary data
87 // structures (e.g. Cache) to help those filters.
88 class InstructionRule {
89 protected:
90   const SIInstrInfo *TII;
91   unsigned SGID;
92   // A cache made available to the Filter to store SUnits for subsequent
93   // invocations of the Filter
94   std::optional<SmallVector<SUnit *, 4>> Cache;
95 
96 public:
97   virtual bool
98   apply(const SUnit *, const ArrayRef<SUnit *>,
99         SmallVectorImpl<SchedGroup> &) {
100     return true;
101   };
102 
103   InstructionRule(const SIInstrInfo *TII, unsigned SGID,
104                   bool NeedsCache = false)
105       : TII(TII), SGID(SGID) {
106     if (NeedsCache) {
107       Cache = SmallVector<SUnit *, 4>();
108     }
109   }
110 
111   virtual ~InstructionRule() = default;
112 };
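//
// A minimal sketch of a custom rule (hypothetical, for illustration only):
// restrict a SchedGroup to SUnits whose instruction may load.
//
//   class MayLoad final : public InstructionRule {
//   public:
//     bool apply(const SUnit *SU, const ArrayRef<SUnit *>,
//                SmallVectorImpl<SchedGroup> &) override {
//       return SU->getInstr()->mayLoad();
//     }
//     MayLoad(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
//         : InstructionRule(TII, SGID, NeedsCache) {}
//   };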
113 
114 using SUnitsToCandidateSGsMap = DenseMap<SUnit *, SmallVector<int, 4>>;
115 
116 // Classify instructions into groups to enable fine-tuned control over the
117 // scheduler. These groups may be more specific than current SchedModel
118 // instruction classes.
119 class SchedGroup {
120 private:
121   // Mask that defines which instruction types can be classified into this
122   // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
123   // and SCHED_GROUP_BARRIER.
124   SchedGroupMask SGMask;
125 
126   // Maximum number of SUnits that can be added to this group.
127   std::optional<unsigned> MaxSize;
128 
129   // SchedGroups will only synchronize with other SchedGroups that have the same
130   // SyncID.
131   int SyncID = 0;
132 
133   // SGID is used to map instructions to candidate SchedGroups
134   unsigned SGID;
135 
136   // The different rules each instruction in this SchedGroup must conform to
137   SmallVector<std::shared_ptr<InstructionRule>, 4> Rules;
138 
139   // Count of the number of created SchedGroups, used to initialize SGID.
140   static unsigned NumSchedGroups;
141 
142   // Try to add an edge from SU A to SU B.
143   bool tryAddEdge(SUnit *A, SUnit *B);
144 
145   // Use SGMask to determine whether we can classify MI as a member of this
146   // SchedGroup object.
147   bool canAddMI(const MachineInstr &MI) const;
148 
149 public:
150   // Collection of SUnits that are classified as members of this group.
151   SmallVector<SUnit *, 32> Collection;
152 
153   ScheduleDAGInstrs *DAG;
154   const SIInstrInfo *TII;
155 
156   // Returns true if SU can be added to this SchedGroup.
157   bool canAddSU(SUnit &SU) const;
158 
159   // Add DAG dependencies between all SUnits in this SchedGroup and this SU. If
160   // MakePred is true, SU will be a predecessor of the SUnits in this
161   // SchedGroup, otherwise SU will be a successor.
162   void link(SUnit &SU, bool MakePred = false);
163 
164   // Add DAG dependencies and track which edges are added, and the count of
165   // missed edges.
166   int link(SUnit &SU, bool MakePred,
167            std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
168 
169   // Add DAG dependencies between all SUnits in this SchedGroup and this SU.
170   // Use the predicate to determine whether SU should be a predecessor (P =
171   // true) or a successor (P = false) of this SchedGroup.
172   void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);
173 
174   // Add DAG dependencies such that SUnits in this group shall be ordered
175   // before SUnits in OtherGroup.
176   void link(SchedGroup &OtherGroup);
177 
178   // Returns true if no more instructions may be added to this group.
179   bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }
180 
181   // Append a constraint that SUs must meet in order to fit into this
182   // SchedGroup. Since many rules involve the relationship between a SchedGroup
183   // and the SUnits in other SchedGroups, rules are checked at Pipeline Solve
184   // time (rather than SchedGroup init time).
185   void addRule(std::shared_ptr<InstructionRule> NewRule) {
186     Rules.push_back(NewRule);
187   }
188 
189   // Returns true if the SU matches all rules
190   bool allowedByRules(const SUnit *SU,
191                       SmallVectorImpl<SchedGroup> &SyncPipe) const {
192     for (auto &Rule : Rules) {
193       if (!Rule->apply(SU, Collection, SyncPipe))
194         return false;
195     }
196     return true;
197   }
198 
199   // Add SU to the SchedGroup.
200   void add(SUnit &SU) {
201     LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
202                       << format_hex((int)SGMask, 10, true) << " adding "
203                       << *SU.getInstr());
204     Collection.push_back(&SU);
205   }
206 
207   // Remove last element in the SchedGroup
208   void pop() { Collection.pop_back(); }
209 
210   // Identify and add all relevant SUs from the DAG to this SchedGroup.
211   void initSchedGroup();
212 
213   // Add instructions to the SchedGroup bottom up starting from RIter.
214   // PipelineInstrs is a set of instructions that should not be added to the
215   // SchedGroup even when the other conditions for adding it are satisfied.
216   // RIter will be added to the SchedGroup as well, and dependencies will be
217   // added so that RIter will always be scheduled at the end of the group.
218   void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
219                       SUnitsToCandidateSGsMap &SyncedInstrs);
220 
221   void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);
222 
223   int getSyncID() { return SyncID; }
224 
225   int getSGID() { return SGID; }
226 
227   SchedGroupMask getMask() { return SGMask; }
228 
229   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
230              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
231       : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) {
232     SGID = NumSchedGroups++;
233   }
234 
235   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
236              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
237       : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) {
238     SGID = NumSchedGroups++;
239   }
240 };
241 
242 using SUToCandSGsPair = std::pair<SUnit *, SmallVector<int, 4>>;
243 using SUsToCandSGsVec = SmallVector<SUToCandSGsPair, 4>;
244 
245 // The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
246 // in non-trivial cases. For example, if the requested pipeline is
247 // {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
248 // in the DAG, then we will have an instruction that cannot be trivially
249 // assigned to a SchedGroup. The PipelineSolver class implements two algorithms
250 // to find a good solution to the pipeline -- a greedy algorithm and an exact
251 // algorithm. The exact algorithm has an exponential time complexity and should
252 // only be used for small sized problems or medium sized problems where an exact
253 // solution is highly desired.
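// For instance, with the pipeline above, a VMEM_READ SU has two candidate
// SchedGroups (the first and the fourth); the solver tries each assignment
// and keeps the one that misses the fewest pipeline edges.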
254 class PipelineSolver {
255   [[maybe_unused]] ScheduleDAGMI *DAG;
256 
257   // Instructions that can be assigned to multiple SchedGroups
258   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
259   SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
260   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
261   // The current working pipeline
262   SmallVector<SmallVector<SchedGroup, 4>, 4> CurrPipeline;
263   // The pipeline that has the best solution found so far
264   SmallVector<SmallVector<SchedGroup, 4>, 4> BestPipeline;
265 
266   // Whether or not we actually have any SyncedInstrs to try to solve.
267   bool NeedsSolver = false;
268 
269   // Compute an estimate of the size of the search tree -- the true size is
270   // the product of each conflicted instruction's candidate count across all SyncPipelines.
271   unsigned computeProblemSize();
272 
273   // The cost penalty of not assigning a SU to a SchedGroup
274   int MissPenalty = 0;
275 
276   // Costs in terms of the number of edges we are unable to add
277   int BestCost = -1;
278   int CurrCost = 0;
279 
280   // Index pointing to the conflicting instruction that is currently being
281   // fitted
282   int CurrConflInstNo = 0;
283   // Index to the pipeline that is currently being fitted
284   int CurrSyncGroupIdx = 0;
285   // The first non-trivial pipeline
286   int BeginSyncGroupIdx = 0;
287 
288   // How many branches we have explored
289   uint64_t BranchesExplored = 0;
290 
291   // The direction in which we process the candidate SchedGroups per SU
292   bool IsBottomUp = true;
293 
294   // Update indices to fit next conflicting instruction
295   void advancePosition();
296   // Recede indices to attempt to find better fit for previous conflicting
297   // instruction
298   void retreatPosition();
299 
300   // The exponential time algorithm which finds the provably best fit
301   bool solveExact();
302   // The polynomial time algorithm which attempts to find a good fit
303   bool solveGreedy();
304   // Find the best SchedGroup for the current SU using the heuristic given all
305   // current information. One step in the greedy algorithm. Templated against
306   // the SchedGroup iterator (either reverse or forward).
307   template <typename T>
308   void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,
309                   T E);
310   // Whether or not the current solution is optimal
311   bool checkOptimal();
312   // Populate the ready list, prioritizing fewest missed edges first.
313   // Templated against the SchedGroup iterator (either reverse or forward).
314   template <typename T>
315   void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,
316                          T E);
317   // Add edges corresponding to the SchedGroups as assigned by solver
318   void makePipeline();
319   // Link the SchedGroups in the best found pipeline.
320   // Templated against the SchedGroup iterator (either reverse or forward).
321   template <typename T> void linkSchedGroups(T I, T E);
322   // Add the edges from the SU to the other SchedGroups in pipeline, and
323   // return the number of edges missed.
324   int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
325                std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
326   /// Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
327   /// returns the cost (in terms of missed pipeline edges), and tracks the edges
328   /// added in \p AddedEdges
329   template <typename T>
330   int linkSUnit(SUnit *SU, int SGID,
331                 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
332   /// Remove the edges passed via \p AddedEdges
333   void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
334   // Convert the passed in maps to arrays for bidirectional iterators
335   void convertSyncMapsToArrays();
336 
337   void reset();
338 
339 public:
340   // Invoke the solver to map instructions to instruction groups. A size
341   // heuristic and command-line options select the exact or greedy algorithm.
342   void solve();
343 
344   PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
345                  DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
346                  ScheduleDAGMI *DAG, bool IsBottomUp = true)
347       : DAG(DAG), SyncedInstrs(SyncedInstrs),
348         SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {
349 
350     for (auto &PipelineInstrs : SyncedInstrs) {
351       if (PipelineInstrs.second.size() > 0) {
352         NeedsSolver = true;
353         break;
354       }
355     }
356 
357     if (!NeedsSolver)
358       return;
359 
360     convertSyncMapsToArrays();
361 
362     CurrPipeline = BestPipeline;
363 
364     while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
365            PipelineInstrs[BeginSyncGroupIdx].size() == 0)
366       ++BeginSyncGroupIdx;
367 
368     if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
369       return;
370   }
371 };
372 
373 void PipelineSolver::reset() {
374 
375   for (auto &SyncPipeline : CurrPipeline) {
376     for (auto &SG : SyncPipeline) {
377       SmallVector<SUnit *, 32> TempCollection = SG.Collection;
378       SG.Collection.clear();
379       auto *SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
380         return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
381       });
382       if (SchedBarr != TempCollection.end())
383         SG.Collection.push_back(*SchedBarr);
384     }
385   }
386 
387   CurrSyncGroupIdx = BeginSyncGroupIdx;
388   CurrConflInstNo = 0;
389   CurrCost = 0;
390 }
391 
392 void PipelineSolver::convertSyncMapsToArrays() {
393   for (auto &SyncPipe : SyncedSchedGroups) {
394     BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);
395   }
396 
397   int PipelineIDx = SyncedInstrs.size() - 1;
398   PipelineInstrs.resize(SyncedInstrs.size());
399   for (auto &SyncInstrMap : SyncedInstrs) {
400     for (auto &SUsToCandSGs : SyncInstrMap.second) {
401       if (PipelineInstrs[PipelineIDx].size() == 0) {
402         PipelineInstrs[PipelineIDx].push_back(
403             std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
404         continue;
405       }
406       auto *SortPosition = PipelineInstrs[PipelineIDx].begin();
407       // Insert them in sorted order -- this allows for good parsing order in
408       // the greedy algorithm
409       while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
410              SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
411         ++SortPosition;
412       PipelineInstrs[PipelineIDx].insert(
413           SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
414     }
415     --PipelineIDx;
416   }
417 }
418 
419 template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {
420   for (; I != E; ++I) {
421     auto &GroupA = *I;
422     for (auto J = std::next(I); J != E; ++J) {
423       auto &GroupB = *J;
424       GroupA.link(GroupB);
425     }
426   }
427 }
428 
429 void PipelineSolver::makePipeline() {
430   // Preserve the order of barriers for subsequent SchedGroupBarrier mutations
431   for (auto &SyncPipeline : BestPipeline) {
432     LLVM_DEBUG(dbgs() << "Printing SchedGroups\n");
433     for (auto &SG : SyncPipeline) {
434       LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID()
435                         << " has: \n");
436       SUnit *SGBarr = nullptr;
437       for (auto &SU : SG.Collection) {
438         if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
439           SGBarr = SU;
440         LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n");
441       }
442       // Command-line-requested IGroupLP pipelines don't have a SGBarr
443       if (!SGBarr)
444         continue;
445       SG.link(*SGBarr, false);
446     }
447   }
448 
449   for (auto &SyncPipeline : BestPipeline) {
450     IsBottomUp ? linkSchedGroups(SyncPipeline.rbegin(), SyncPipeline.rend())
451                : linkSchedGroups(SyncPipeline.begin(), SyncPipeline.end());
452   }
453 }
454 
455 template <typename T>
456 int PipelineSolver::linkSUnit(
457     SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
458     T I, T E) {
459   bool MakePred = false;
460   int AddedCost = 0;
461   for (; I < E; ++I) {
462     if (I->getSGID() == SGID) {
463       MakePred = true;
464       continue;
465     }
466     auto Group = *I;
467     AddedCost += Group.link(*SU, MakePred, AddedEdges);
468     assert(AddedCost >= 0);
469   }
470   return AddedCost;
471 }
472 
473 int PipelineSolver::addEdges(
474     SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
475     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
476 
477   // For IsBottomUp, the first SchedGroup in SyncPipeline contains the
478   // instructions that are the ultimate successors in the resultant mutation.
479   // Therefore, in such a configuration, the SchedGroups occurring before the
480   // candidate SGID are successors of the candidate SchedGroup, thus the current
481   // SU should be linked as a predecessor to SUs in those SchedGroups. The
482   // opposite is true if !IsBottomUp. IsBottomUp occurs in the case of multiple
483   // SCHED_GROUP_BARRIERS, or if a user specifies IGLP_OPT SchedGroups in
484   // bottom-up (reverse) order.
485   return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, SyncPipeline.rbegin(),
486                                 SyncPipeline.rend())
487                     : linkSUnit(SU, SGID, AddedEdges, SyncPipeline.begin(),
488                                 SyncPipeline.end());
489 }
490 
491 void PipelineSolver::removeEdges(
492     const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
493   // Only remove the edges that we have added when testing
494   // the fit.
495   for (auto &PredSuccPair : EdgesToRemove) {
496     SUnit *Pred = PredSuccPair.first;
497     SUnit *Succ = PredSuccPair.second;
498 
499     auto *Match = llvm::find_if(
500         Succ->Preds, [&Pred](SDep &P) { return P.getSUnit() == Pred; });
501     if (Match != Succ->Preds.end()) {
502       assert(Match->isArtificial());
503       Succ->removePred(*Match);
504     }
505   }
506 }
507 
508 void PipelineSolver::advancePosition() {
509   ++CurrConflInstNo;
510 
511   if (static_cast<size_t>(CurrConflInstNo) >=
512       PipelineInstrs[CurrSyncGroupIdx].size()) {
513     CurrConflInstNo = 0;
514     ++CurrSyncGroupIdx;
515     // Advance to next non-trivial pipeline
516     while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
517            PipelineInstrs[CurrSyncGroupIdx].size() == 0)
518       ++CurrSyncGroupIdx;
519   }
520 }
521 
522 void PipelineSolver::retreatPosition() {
523   assert(CurrConflInstNo >= 0);
524   assert(CurrSyncGroupIdx >= 0);
525 
526   if (CurrConflInstNo > 0) {
527     --CurrConflInstNo;
528     return;
529   }
530 
531   if (CurrConflInstNo == 0) {
532     // If we return to the starting position, we have explored
533     // the entire tree
534     if (CurrSyncGroupIdx == BeginSyncGroupIdx)
535       return;
536 
537     --CurrSyncGroupIdx;
538     // Go to previous non-trivial pipeline
539     while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)
540       --CurrSyncGroupIdx;
541 
542     CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
543   }
544 }
545 
546 bool PipelineSolver::checkOptimal() {
547   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
548     if (BestCost == -1 || CurrCost < BestCost) {
549       BestPipeline = CurrPipeline;
550       BestCost = CurrCost;
551       LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
552     }
553     assert(BestCost >= 0);
554   }
555 
556   bool DoneExploring = false;
557   if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
558     DoneExploring = true;
559 
560   return (DoneExploring || BestCost == 0);
561 }
562 
563 template <typename T>
564 void PipelineSolver::populateReadyList(
565     SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {
566   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
567   auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
568   assert(CurrSU.second.size() >= 1);
569 
570   for (; I != E; ++I) {
571     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
572     int CandSGID = *I;
573     SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
574       return SG.getSGID() == CandSGID;
575     });
576     assert(Match);
577 
578     if (UseCostHeur) {
579       if (Match->isFull()) {
580         ReadyList.push_back(std::pair(*I, MissPenalty));
581         continue;
582       }
583 
584       int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
585       ReadyList.push_back(std::pair(*I, TempCost));
586       removeEdges(AddedEdges);
587     } else
588       ReadyList.push_back(std::pair(*I, -1));
589   }
590 
591   if (UseCostHeur)
592     std::sort(ReadyList.begin(), ReadyList.end(), llvm::less_second());
593 
594   assert(ReadyList.size() == CurrSU.second.size());
595 }
596 
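// solveExact performs a branch-and-bound search: each conflicted SU branches
// over its candidate SchedGroups (cheapest first when the cost heuristic is
// enabled), plus the option of omitting the SU at cost MissPenalty; branches
// whose running cost cannot beat BestCost are pruned.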
597 bool PipelineSolver::solveExact() {
598   if (checkOptimal())
599     return true;
600 
601   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
602     return false;
603 
604   assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
605   assert(static_cast<size_t>(CurrConflInstNo) <
606          PipelineInstrs[CurrSyncGroupIdx].size());
607   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
608   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
609                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
610 
611   // SchedGroup -> Cost pairs
612   SmallVector<std::pair<int, int>, 4> ReadyList;
613   // Prioritize the candidate sched groups in terms of lowest cost first
614   IsBottomUp ? populateReadyList(ReadyList, CurrSU.second.rbegin(),
615                                  CurrSU.second.rend())
616              : populateReadyList(ReadyList, CurrSU.second.begin(),
617                                  CurrSU.second.end());
618 
619   auto *I = ReadyList.begin();
620   auto *E = ReadyList.end();
621   for (; I != E; ++I) {
622     // If we are trying SGs in least cost order, and the current SG is cost
623     // infeasible, then all subsequent SGs will also be cost infeasible, so we
624     // can prune.
625     if (BestCost != -1 && (CurrCost + I->second > BestCost))
626       return false;
627 
628     int CandSGID = I->first;
629     int AddedCost = 0;
630     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
631     auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
632     SchedGroup *Match = nullptr;
633     for (auto &SG : SyncPipeline)
634       if (SG.getSGID() == CandSGID)
635         Match = &SG;
636     assert(Match && "candidate SGID not present in SyncPipeline");
637 
638     if (Match->isFull())
639       continue;
640 
641     if (!Match->allowedByRules(CurrSU.first, SyncPipeline))
642       continue;
643 
644     LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
645                       << (int)Match->getMask() << " and ID " << CandSGID
646                       << "\n");
647     Match->add(*CurrSU.first);
648     AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
649     LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
650     CurrCost += AddedCost;
651     advancePosition();
652     ++BranchesExplored;
653     bool FinishedExploring = false;
654     // Only recurse if the cost after adding edges can still beat the best
655     // known solution; otherwise backtrack.
656     if (CurrCost < BestCost || BestCost == -1) {
657       if (solveExact()) {
658         FinishedExploring = BestCost != 0;
659         if (!FinishedExploring)
660           return true;
661       }
662     }
663 
664     retreatPosition();
665     CurrCost -= AddedCost;
666     removeEdges(AddedEdges);
667     Match->pop();
668     CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
669     if (FinishedExploring)
670       return true;
671   }
672 
673   // Try the pipeline where the current instruction is omitted
674   // Potentially if we omit a problematic instruction from the pipeline,
675   // all the other instructions can nicely fit.
676   CurrCost += MissPenalty;
677   advancePosition();
678 
679   LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");
680 
681   bool FinishedExploring = false;
682   if (CurrCost < BestCost || BestCost == -1) {
683     if (solveExact()) {
684       FinishedExploring = BestCost != 0;
685       if (!FinishedExploring)
686         return true;
687     }
688   }
689 
690   retreatPosition();
691   CurrCost -= MissPenalty;
692   return FinishedExploring;
693 }
694 
695 template <typename T>
696 void PipelineSolver::greedyFind(
697     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
698   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
699   int BestNodeCost = -1;
700   int TempCost;
701   SchedGroup *BestGroup = nullptr;
702   int BestGroupID = -1;
703   auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
704   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
705                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
706 
707   // Since we have added the potential SchedGroups from bottom up, but
708   // traversed the DAG from top down, parse over the groups from last to
709   // first. If we fail to do this for the greedy algorithm, the solution will
710   // likely not be good in more complex cases.
711   for (; I != E; ++I) {
712     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
713     int CandSGID = *I;
714     SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
715       return SG.getSGID() == CandSGID;
716     });
717     assert(Match);
718 
719     LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
720                       << (int)Match->getMask() << "\n");
721 
722     if (Match->isFull()) {
723       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
724       continue;
725     }
726     if (!Match->allowedByRules(CurrSU.first, SyncPipeline)) {
727       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
728       continue;
729     }
730     TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
731     LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
732     if (TempCost < BestNodeCost || BestNodeCost == -1) {
733       BestGroup = Match;
734       BestNodeCost = TempCost;
735       BestGroupID = CandSGID;
736     }
737     removeEdges(AddedEdges);
738     if (BestNodeCost == 0)
739       break;
740   }
741 
742   if (BestGroupID != -1) {
743     BestGroup->add(*CurrSU.first);
744     addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);
745     LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask "
746                       << (int)BestGroup->getMask() << "\n");
747     BestCost += BestNodeCost;
748   } else
749     BestCost += MissPenalty;
750 
751   CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
752 }
753 
754 bool PipelineSolver::solveGreedy() {
755   BestCost = 0;
756   std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
757 
758   while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
759     SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
760     IsBottomUp
761         ? greedyFind(AddedEdges, CurrSU.second.rbegin(), CurrSU.second.rend())
762         : greedyFind(AddedEdges, CurrSU.second.begin(), CurrSU.second.end());
763     advancePosition();
764   }
765   BestPipeline = CurrPipeline;
766   removeEdges(AddedEdges);
767   return false;
768 }
769 
770 unsigned PipelineSolver::computeProblemSize() {
771   unsigned ProblemSize = 0;
772   for (auto &PipeConflicts : PipelineInstrs) {
773     ProblemSize += PipeConflicts.size();
774   }
775 
776   return ProblemSize;
777 }
778 
779 void PipelineSolver::solve() {
780   if (!NeedsSolver)
781     return;
782 
783   unsigned ProblemSize = computeProblemSize();
784   assert(ProblemSize > 0);
785 
786   bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
787   MissPenalty = (ProblemSize / 2) + 1;
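  // e.g. two non-trivial pipelines with 3 and 4 conflicted SUs give
  // ProblemSize = 7, MissPenalty = 4, and BelowCutoff iff CutoffForExact >= 7.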
788 
789   LLVM_DEBUG(DAG->dump());
790   if (EnableExactSolver || BelowCutoff) {
791     LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
792     solveGreedy();
793     reset();
794     LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
795     if (BestCost > 0) {
796       LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
797       solveExact();
798       LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
799     }
800   } else { // Use the Greedy Algorithm by default
801     LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
802     solveGreedy();
803   }
804 
805   makePipeline();
806   LLVM_DEBUG(dbgs() << "After applying mutation\n");
807   LLVM_DEBUG(DAG->dump());
808 }
809 
810 enum IGLPStrategyID : int {
811   MFMASmallGemmOptID = 0,
812   MFMASmallGemmSingleWaveOptID = 1,
813   MFMAExpInterleaveID = 2,
814   MFMAExpSimpleInterleaveID = 3
815 };
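// A strategy is selected from kernel source with the iglp_opt builtin, e.g.
// (illustrative; the immediate corresponds to the enum above):
//
//   __builtin_amdgcn_iglp_opt(/*StrategyID=*/1); // MFMASmallGemmSingleWaveOptID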
816 
817 // Interface for implementing an IGLP scheduling strategy.
818 class IGLPStrategy {
819 protected:
820   ScheduleDAGInstrs *DAG;
821 
822   const SIInstrInfo *TII;
823 
824 public:
825   /// Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
826   virtual bool applyIGLPStrategy(
827       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
828       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
829       AMDGPU::SchedulingPhase Phase) = 0;
830 
831   // Returns true if this strategy should be applied to a ScheduleDAG.
832   virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
833                                    AMDGPU::SchedulingPhase Phase) = 0;
834 
835   bool IsBottomUp = true;
836 
837   IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
838       : DAG(DAG), TII(TII) {}
839 
840   virtual ~IGLPStrategy() = default;
841 };
842 
843 class MFMASmallGemmOpt final : public IGLPStrategy {
844 private:
845 public:
846   bool applyIGLPStrategy(
847       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
848       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
849       AMDGPU::SchedulingPhase Phase) override;
850 
851   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
852                            AMDGPU::SchedulingPhase Phase) override {
853     return true;
854   }
855 
856   MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
857       : IGLPStrategy(DAG, TII) {
858     IsBottomUp = true;
859   }
860 };
861 
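// The requested pipeline interleaves LDS and MFMA work: each of the
// MFMACount * 3 iterations below appends a DS group (up to 2 SUs) and an MFMA
// group (1 SU) to sync group 0. With MFMACount == 2, for example, the
// resulting pipeline is DS(2) MFMA(1) repeated six times.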
862 bool MFMASmallGemmOpt::applyIGLPStrategy(
863     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
864     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
865     AMDGPU::SchedulingPhase Phase) {
866   // Count the number of MFMA instructions.
867   unsigned MFMACount = 0;
868   for (const MachineInstr &I : *DAG)
869     if (TII->isMFMAorWMMA(I))
870       ++MFMACount;
871 
872   const unsigned PipelineSyncID = 0;
873   SchedGroup *SG = nullptr;
874   for (unsigned I = 0; I < MFMACount * 3; ++I) {
875     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
876         SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
877     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
878 
879     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
880         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
881     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
882   }
883 
884   return true;
885 }
886 
887 class MFMAExpInterleaveOpt final : public IGLPStrategy {
888 private:
889   // The count of TRANS SUs involved in the interleaved pipeline
890   static unsigned TransPipeCount;
891   // The count of MFMA SUs involved in the interleaved pipeline
892   static unsigned MFMAPipeCount;
893   // The count of Add SUs involved in the interleaved pipeline
894   static unsigned AddPipeCount;
895   // The number of transitive MFMA successors for each TRANS SU
896   static unsigned MFMAEnablement;
897   // The number of transitive TRANS predecessors for each MFMA SU
898   static unsigned ExpRequirement;
899   // The count of independent "chains" of MFMA instructions in the pipeline
900   static unsigned MFMAChains;
901   // The length of each independent "chain" of MFMA instructions
902   static unsigned MFMAChainLength;
903   // Whether or not the pipeline has V_CVT instructions
904   static bool HasCvt;
905   // Whether or not there are instructions between the TRANS instruction and
906   // V_CVT
907   static bool HasChainBetweenCvt;
908   // The first occurring DS_READ which feeds an MFMA chain
909   static std::optional<unsigned> FirstPipeDSR;
910   // The MFMAPipe SUs with no MFMA predecessors
911   SmallVector<SUnit *, 4> MFMAChainSeeds;
912   // Compute the heuristics for the pipeline, returning whether or not the DAG
913   // is well suited to the mutation
914   bool analyzeDAG(const SIInstrInfo *TII);
915 
916   /// Whether or not the instruction is a transitive predecessor of an MFMA
917   /// instruction
918   class IsPipeExp final : public InstructionRule {
919   public:
920     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
921                SmallVectorImpl<SchedGroup> &SyncPipe) override {
922 
923       auto *DAG = SyncPipe[0].DAG;
924 
925       if (Cache->empty()) {
926         auto I = DAG->SUnits.rbegin();
927         auto E = DAG->SUnits.rend();
928         for (; I != E; I++) {
929           if (TII->isMFMAorWMMA(*I->getInstr()))
930             Cache->push_back(&*I);
931         }
932         if (Cache->empty())
933           return false;
934       }
935 
936       auto Reaches = any_of(*Cache, [&SU, &DAG](SUnit *TargetSU) {
937         return DAG->IsReachable(TargetSU, const_cast<SUnit *>(SU));
938       });
939 
940       return Reaches;
941     }
942     IsPipeExp(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
943         : InstructionRule(TII, SGID, NeedsCache) {}
944   };
945 
946   /// Whether or not the instruction is a transitive predecessor of the
947   /// \p Number th MFMA of the MFMAs occurring after a TRANS instruction
948   class EnablesNthMFMA final : public InstructionRule {
949   private:
950     unsigned Number = 1;
951 
952   public:
953     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
954                SmallVectorImpl<SchedGroup> &SyncPipe) override {
955       bool FoundTrans = false;
956       unsigned Counter = 1;
957       auto *DAG = SyncPipe[0].DAG;
958 
959       if (Cache->empty()) {
960         auto I = DAG->SUnits.begin();
961         auto E = DAG->SUnits.end();
962         for (; I != E; I++) {
963           if (FoundTrans && TII->isMFMAorWMMA(*I->getInstr())) {
964             if (Counter == Number) {
965               Cache->push_back(&*I);
966               break;
967             }
968             ++Counter;
969           }
970           if (!FoundTrans && TII->isTRANS(I->getInstr()->getOpcode()))
971             FoundTrans = true;
972         }
973         if (Cache->empty())
974           return false;
975       }
976 
977       return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
978     }
979 
980     EnablesNthMFMA(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
981                    bool NeedsCache = false)
982         : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
983   };
984 
985   /// Whether or not the instruction enables the exact MFMA that is the \p
986   /// Number th MFMA in the chain starting with \p ChainSeed
987   class EnablesNthMFMAInChain final : public InstructionRule {
988   private:
989     unsigned Number = 1;
990     SUnit *ChainSeed;
991 
992   public:
993     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
994                SmallVectorImpl<SchedGroup> &SyncPipe) override {
995       auto *DAG = SyncPipe[0].DAG;
996 
997       if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
998         return false;
999 
1000       if (Cache->empty()) {
1001         auto *TempSU = ChainSeed;
1002         auto Depth = Number;
1003         while (Depth > 0) {
1004           --Depth;
1005           bool Found = false;
1006           for (auto &Succ : TempSU->Succs) {
1007             if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
1008               TempSU = Succ.getSUnit();
1009               Found = true;
1010               break;
1011             }
1012           }
1013           if (!Found)
1014             return false;
1015         }
1016 
1017         Cache->push_back(TempSU);
1018       }
1019       // If we failed to find the instruction to be placed into the cache, we
1020       // would have already exited.
1021       assert(!Cache->empty());
1022 
1023       return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
1024     }
1025 
1026     EnablesNthMFMAInChain(unsigned Number, SUnit *ChainSeed,
1027                           const SIInstrInfo *TII, unsigned SGID,
1028                           bool NeedsCache = false)
1029         : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1030           ChainSeed(ChainSeed) {}
1031   };
1032 
1033   /// Whether or not the instruction has fewer than \p Size immediate successors.
1034   /// If \p HasIntermediary is true, this also tests whether all successors of
1035   /// the SUnit have fewer than \p Size successors.
1036   class LessThanNSuccs final : public InstructionRule {
1037   private:
1038     unsigned Size = 1;
1039     bool HasIntermediary = false;
1040 
1041   public:
1042     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1043                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1044       if (!SyncPipe.size())
1045         return false;
1046 
1047       auto SuccSize = llvm::count_if(SU->Succs, [](const SDep &Succ) {
1048         return Succ.getKind() == SDep::Data;
1049       });
1050       if (SuccSize >= Size)
1051         return false;
1052 
1053       if (HasIntermediary) {
1054         for (auto Succ : SU->Succs) {
1055           auto SuccSize =
1056               llvm::count_if(Succ.getSUnit()->Succs, [](const SDep &SuccSucc) {
1057                 return SuccSucc.getKind() == SDep::Data;
1058               });
1059           if (SuccSize >= Size)
1060             return false;
1061         }
1062       }
1063 
1064       return true;
1065     }
1066     LessThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
1067                    bool HasIntermediary = false, bool NeedsCache = false)
1068         : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1069           HasIntermediary(HasIntermediary) {}
1070   };
1071 
1072   /// Whether or not the instruction has greater than or equal to \p Size
1073   /// immediate successors. If \p HasIntermediary is true, this also tests
1074   /// whether any successor of the SUnit has greater than or equal to \p Size
1075   /// successors.
1076   class GreaterThanOrEqualToNSuccs final : public InstructionRule {
1077   private:
1078     unsigned Size = 1;
1079     bool HasIntermediary = false;
1080 
1081   public:
1082     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1083                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1084       if (!SyncPipe.size())
1085         return false;
1086 
1087       auto SuccSize = llvm::count_if(SU->Succs, [](const SDep &Succ) {
1088         return Succ.getKind() == SDep::Data;
1089       });
1090       if (SuccSize >= Size)
1091         return true;
1092 
1093       if (HasIntermediary) {
1094         for (auto Succ : SU->Succs) {
1095           auto SuccSize =
1096               llvm::count_if(Succ.getSUnit()->Succs, [](const SDep &SuccSucc) {
1097                 return SuccSucc.getKind() == SDep::Data;
1098               });
1099           if (SuccSize >= Size)
1100             return true;
1101         }
1102       }
1103 
1104       return false;
1105     }
1106     GreaterThanOrEqualToNSuccs(unsigned Size, const SIInstrInfo *TII,
1107                                unsigned SGID, bool HasIntermediary = false,
1108                                bool NeedsCache = false)
1109         : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1110           HasIntermediary(HasIntermediary) {}
1111   };
1112 
1113   // Whether or not the instruction is a relevant V_CVT instruction.
1114   class IsCvt final : public InstructionRule {
1115   public:
1116     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1117                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1118       auto Opc = SU->getInstr()->getOpcode();
1119       return Opc == AMDGPU::V_CVT_F16_F32_e32 ||
1120              Opc == AMDGPU::V_CVT_I32_F32_e32;
1121     }
1122     IsCvt(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1123         : InstructionRule(TII, SGID, NeedsCache) {}
1124   };
1125 
1126   // Whether or not the instruction is FMA_F32.
1127   class IsFMA final : public InstructionRule {
1128   public:
1129     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1130                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1131       return SU->getInstr()->getOpcode() == AMDGPU::V_FMA_F32_e64 ||
1132              SU->getInstr()->getOpcode() == AMDGPU::V_PK_FMA_F32;
1133     }
1134     IsFMA(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1135         : InstructionRule(TII, SGID, NeedsCache) {}
1136   };
1137 
1138   // Whether or not the instruction is a V_ADD_F32 instruction.
1139   class IsPipeAdd final : public InstructionRule {
1140   public:
1141     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1142                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1143       return SU->getInstr()->getOpcode() == AMDGPU::V_ADD_F32_e32;
1144     }
1145     IsPipeAdd(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1146         : InstructionRule(TII, SGID, NeedsCache) {}
1147   };
1148 
1149   /// Whether or not the instruction is an immediate RAW successor
1150   /// of the SchedGroup \p Distance steps before.
1151   class IsSuccOfPrevNthGroup final : public InstructionRule {
1152   private:
1153     unsigned Distance = 1;
1154 
1155   public:
1156     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1157                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1158       SchedGroup *OtherGroup = nullptr;
1159       if (!SyncPipe.size())
1160         return false;
1161 
1162       for (auto &PipeSG : SyncPipe) {
1163         if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1164           OtherGroup = &PipeSG;
1165       }
1166 
1167       if (!OtherGroup)
1168         return false;
1169       if (!OtherGroup->Collection.size())
1170         return true;
1171 
1172       for (auto &OtherEle : OtherGroup->Collection) {
1173         for (auto &Succ : OtherEle->Succs) {
1174           if (Succ.getSUnit() == SU && Succ.getKind() == SDep::Data)
1175             return true;
1176         }
1177       }
1178 
1179       return false;
1180     }
1181     IsSuccOfPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1182                          unsigned SGID, bool NeedsCache = false)
1183         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1184   };
1185 
1186   /// Whether or not the instruction is a transitive successor of any
1187   /// instruction in the SchedGroup \p Distance steps before.
1188   class IsReachableFromPrevNthGroup final : public InstructionRule {
1189   private:
1190     unsigned Distance = 1;
1191 
1192   public:
1193     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1194                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1195       SchedGroup *OtherGroup = nullptr;
1196       if (!SyncPipe.size())
1197         return false;
1198 
1199       for (auto &PipeSG : SyncPipe) {
1200         if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1201           OtherGroup = &PipeSG;
1202       }
1203 
1204       if (!OtherGroup)
1205         return false;
1206       if (!OtherGroup->Collection.size())
1207         return true;
1208 
1209       auto *DAG = SyncPipe[0].DAG;
1210 
1211       for (auto &OtherEle : OtherGroup->Collection)
1212         if (DAG->IsReachable(const_cast<SUnit *>(SU), OtherEle))
1213           return true;
1214 
1215       return false;
1216     }
1217     IsReachableFromPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1218                                 unsigned SGID, bool NeedsCache = false)
1219         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1220   };
1221 
1222   /// Whether or not the instruction occurs at or after the SU with NodeNum \p Number
1223   class OccursAtOrAfterNode final : public InstructionRule {
1224   private:
1225     unsigned Number = 1;
1226 
1227   public:
1228     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1229                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1230 
1231       return SU->NodeNum >= Number;
1232     }
1233     OccursAtOrAfterNode(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
1234                         bool NeedsCache = false)
1235         : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
1236   };
1237 
1238   /// Whether or not the SU is exactly the \p Number th MFMA in the chain
1239   /// starting with \p ChainSeed
1240   class IsExactMFMA final : public InstructionRule {
1241   private:
1242     unsigned Number = 1;
1243     SUnit *ChainSeed;
1244 
1245   public:
1246     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1247                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1248       if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
1249         return false;
1250 
1251       if (Cache->empty()) {
1252         auto *TempSU = ChainSeed;
1253         auto Depth = Number;
1254         while (Depth > 0) {
1255           --Depth;
1256           bool Found = false;
1257           for (auto &Succ : TempSU->Succs) {
1258             if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
1259               TempSU = Succ.getSUnit();
1260               Found = true;
1261               break;
1262             }
1263           }
1264           if (!Found) {
1265             return false;
1266           }
1267         }
1268         Cache->push_back(TempSU);
1269       }
1270       // If we failed to find the instruction to be placed into the cache, we
1271       // would have already exited.
1272       assert(!Cache->empty());
1273 
1274       return (*Cache)[0] == SU;
1275     }
1276 
1277     IsExactMFMA(unsigned Number, SUnit *ChainSeed, const SIInstrInfo *TII,
1278                 unsigned SGID, bool NeedsCache = false)
1279         : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1280           ChainSeed(ChainSeed) {}
1281   };
1282 
1283   // Whether the instruction occurs after the first TRANS instruction. This
1284   // implies the instruction cannot be a predecessor of the first TRANS
1285   // instruction.
1286   class OccursAfterExp final : public InstructionRule {
1287   public:
1288     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1289                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1290 
1291       auto *DAG = SyncPipe[0].DAG;
1292       if (Cache->empty()) {
1293         for (auto &SU : DAG->SUnits)
1294           if (TII->isTRANS(SU.getInstr()->getOpcode())) {
1295             Cache->push_back(&SU);
1296             break;
1297           }
1298         if (Cache->empty())
1299           return false;
1300       }
1301 
1302       return SU->NodeNum > (*Cache)[0]->NodeNum;
1303     }
1304 
1305     OccursAfterExp(const SIInstrInfo *TII, unsigned SGID,
1306                    bool NeedsCache = false)
1307         : InstructionRule(TII, SGID, NeedsCache) {}
1308   };
1309 
1310 public:
1311   bool applyIGLPStrategy(
1312       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1313       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1314       AMDGPU::SchedulingPhase Phase) override;
1315 
1316   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1317                            AMDGPU::SchedulingPhase Phase) override;
1318 
1319   MFMAExpInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1320       : IGLPStrategy(DAG, TII) {
1321     IsBottomUp = false;
1322   }
1323 };
1324 
1325 unsigned MFMAExpInterleaveOpt::TransPipeCount = 0;
1326 unsigned MFMAExpInterleaveOpt::MFMAPipeCount = 0;
1327 unsigned MFMAExpInterleaveOpt::AddPipeCount = 0;
1328 unsigned MFMAExpInterleaveOpt::MFMAEnablement = 0;
1329 unsigned MFMAExpInterleaveOpt::ExpRequirement = 0;
1330 unsigned MFMAExpInterleaveOpt::MFMAChains = 0;
1331 unsigned MFMAExpInterleaveOpt::MFMAChainLength = 0;
1332 bool MFMAExpInterleaveOpt::HasCvt = false;
1333 bool MFMAExpInterleaveOpt::HasChainBetweenCvt = false;
1334 std::optional<unsigned> MFMAExpInterleaveOpt::FirstPipeDSR = std::nullopt;
1335 
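// A worked example of the derived quantities (hypothetical DAG): if each
// V_EXP feeds exactly one bit-pack (PackSuccCount = 1) and that bit-pack
// feeds two MFMAs, then MFMAEnablement = 2 * 1 = 2; if an MFMA consumes two
// bit-packs (PackPredCount = 2), each reachable from a single V_EXP, then
// ExpRequirement = 1 * 2 = 2.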
1336 bool MFMAExpInterleaveOpt::analyzeDAG(const SIInstrInfo *TII) {
1337   SmallVector<SUnit *, 10> ExpPipeCands;
1338   SmallVector<SUnit *, 10> MFMAPipeCands;
1339   SmallVector<SUnit *, 10> MFMAPipeSUs;
1340   SmallVector<SUnit *, 10> PackSUs;
1341   SmallVector<SUnit *, 10> CvtSUs;
1342 
1343   auto isBitPack = [](unsigned Opc) {
1344     return Opc == AMDGPU::V_PACK_B32_F16_e64 || Opc == AMDGPU::V_PERM_B32_e64;
1345   };
1346 
1347   auto isCvt = [](unsigned Opc) {
1348     return Opc == AMDGPU::V_CVT_F16_F32_e32 || Opc == AMDGPU::V_CVT_I32_F32_e32;
1349   };
1350 
1351   auto isAdd = [](unsigned Opc) { return Opc == AMDGPU::V_ADD_F32_e32; };
1352 
1353   AddPipeCount = 0;
1354   for (SUnit &SU : DAG->SUnits) {
1355     auto Opc = SU.getInstr()->getOpcode();
1356     if (TII->isTRANS(Opc)) {
1357       // Avoid counting a potential bonus V_EXP which all the MFMAs depend on.
1358       if (SU.Succs.size() >= 7)
1359         continue;
1360       if (llvm::any_of(SU.Succs, [](const SDep &Succ) {
1361             return Succ.getSUnit()->Succs.size() >= 7;
1362           }))
1363         continue;
1364       ExpPipeCands.push_back(&SU);
1365     }
1366 
1367     if (TII->isMFMAorWMMA(*SU.getInstr()))
1368       MFMAPipeCands.push_back(&SU);
1369 
1370     if (isBitPack(Opc))
1371       PackSUs.push_back(&SU);
1372 
1373     if (isCvt(Opc))
1374       CvtSUs.push_back(&SU);
1375 
1376     if (isAdd(Opc))
1377       ++AddPipeCount;
1378   }
1379 
1380   if (!(PackSUs.size() && MFMAPipeCands.size() && ExpPipeCands.size()))
1381     return false;
1382 
1383   TransPipeCount = 0;
1384 
1385   std::optional<SUnit *> TempMFMA;
1386   std::optional<SUnit *> TempExp;
1387   // Count the number of EXPs that reach an MFMA
1388   for (auto &PredSU : ExpPipeCands) {
1389     for (auto &SuccSU : MFMAPipeCands) {
1390       if (DAG->IsReachable(SuccSU, PredSU)) {
1391         if (!TempExp) {
1392           TempExp = PredSU;
1393           TempMFMA = SuccSU;
1394         }
1395         MFMAPipeSUs.push_back(SuccSU);
1396         ++TransPipeCount;
1397         break;
1398       }
1399     }
1400   }
1401 
1402   if (!(TempExp && TempMFMA))
1403     return false;
1404 
1405   HasChainBetweenCvt = none_of((*TempExp)->Succs, [&isCvt](SDep &Succ) {
1406     return isCvt(Succ.getSUnit()->getInstr()->getOpcode());
1407   });
1408 
1409   // Count the number of MFMAs that are reached by an EXP
1410   for (auto &SuccSU : MFMAPipeCands) {
1411     if (MFMAPipeSUs.size() &&
1412         any_of(MFMAPipeSUs, [&SuccSU](SUnit *PotentialMatch) {
1413           return PotentialMatch->NodeNum == SuccSU->NodeNum;
1414         }))
1415       continue;
1416 
1417     for (auto &PredSU : ExpPipeCands) {
1418       if (DAG->IsReachable(SuccSU, PredSU)) {
1419         MFMAPipeSUs.push_back(SuccSU);
1420         break;
1421       }
1422     }
1423   }
1424 
1425   MFMAPipeCount = MFMAPipeSUs.size();
1426 
1427   assert(TempExp && TempMFMA);
1428   assert(MFMAPipeCount > 0);
1429 
1430   std::optional<SUnit *> TempCvt;
1431   for (auto &SuccSU : CvtSUs) {
1432     if (DAG->IsReachable(SuccSU, *TempExp)) {
1433       TempCvt = SuccSU;
1434       break;
1435     }
1436   }
1437 
1438   HasCvt = false;
1439   if (TempCvt.has_value()) {
1440     for (auto &SuccSU : MFMAPipeSUs) {
1441       if (DAG->IsReachable(SuccSU, *TempCvt)) {
1442         HasCvt = true;
1443         break;
1444       }
1445     }
1446   }
1447 
1448   MFMAChains = 0;
1449   for (auto &MFMAPipeSU : MFMAPipeSUs) {
1450     if (is_contained(MFMAChainSeeds, MFMAPipeSU))
1451       continue;
1452     if (none_of(MFMAPipeSU->Preds, [&TII](SDep &Pred) {
1453           return TII->isMFMAorWMMA(*Pred.getSUnit()->getInstr());
1454         })) {
1455       MFMAChainSeeds.push_back(MFMAPipeSU);
1456       ++MFMAChains;
1457     }
1458   }
1459 
1460   if (!MFMAChains)
1461     return false;
1462 
1463   for (auto Pred : MFMAChainSeeds[0]->Preds) {
1464     if (TII->isDS(Pred.getSUnit()->getInstr()->getOpcode()) &&
1465         Pred.getSUnit()->getInstr()->mayLoad())
1466       FirstPipeDSR = Pred.getSUnit()->NodeNum;
1467   }
1468 
1469   MFMAChainLength = MFMAPipeCount / MFMAChains;
1470 
1471   // The number of bit pack operations that depend on a single V_EXP
1472   unsigned PackSuccCount =
1473       llvm::count_if(PackSUs, [this, &TempExp](SUnit *VPack) {
1474         return DAG->IsReachable(VPack, *TempExp);
1475       });
1476 
1477   // The number of bit pack operations an MFMA depends on
1478   unsigned PackPredCount =
1479       llvm::count_if((*TempMFMA)->Preds, [&isBitPack](SDep &Pred) {
1480         auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1481         return isBitPack(Opc);
1482       });
1483 
1484   auto *PackPred = llvm::find_if((*TempMFMA)->Preds, [&isBitPack](SDep &Pred) {
1485     auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1486     return isBitPack(Opc);
1487   });
1488 
1489   if (PackPred == (*TempMFMA)->Preds.end())
1490     return false;
1491 
1492   MFMAEnablement = 0;
1493   ExpRequirement = 0;
1494   // How many MFMAs depend on a single bit pack operation
1495   MFMAEnablement =
1496       llvm::count_if(PackPred->getSUnit()->Succs, [&TII](SDep &Succ) {
1497         return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());
1498       });
1499 
1500   // The number of MFMAs that depend on a single V_EXP
1501   MFMAEnablement *= PackSuccCount;
1502 
1503   // The number of V_EXPs required to resolve all dependencies for an MFMA
1504   ExpRequirement =
1505       llvm::count_if(ExpPipeCands, [this, &PackPred](SUnit *ExpBase) {
1506         return DAG->IsReachable(PackPred->getSUnit(), ExpBase);
1507       });
1508 
1509   ExpRequirement *= PackPredCount;
1510   return true;
1511 }
1512 
1513 bool MFMAExpInterleaveOpt::shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1514                                                AMDGPU::SchedulingPhase Phase) {
1515   const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1516   const SIInstrInfo *TII = ST.getInstrInfo();
1517 
1518   if (Phase != AMDGPU::SchedulingPhase::PostRA)
1519     MFMAChainSeeds.clear();
1520   if (Phase != AMDGPU::SchedulingPhase::PostRA && !analyzeDAG(TII))
1521     return false;
1522 
1523   return true;
1524 }
1525 
1526 bool MFMAExpInterleaveOpt::applyIGLPStrategy(
1527     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1528     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1529     AMDGPU::SchedulingPhase Phase) {
1530 
1531   bool IsSmallKernelType =
1532       MFMAEnablement == 2 && ExpRequirement == 4 && TransPipeCount == 32;
1533   bool IsLargeKernelType =
1534       MFMAEnablement == 4 && ExpRequirement == 4 && TransPipeCount == 64;
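       // For example, with the values above: the small kernel type has
       // MFMAEnablement == 2 and ExpRequirement == 4, so the main loop below uses
       // MFMARatio == 1 and ExpRatio == 2; the large kernel type has
       // MFMAEnablement == ExpRequirement == 4, giving MFMARatio == ExpRatio == 1.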
1535 
1536   if (!(IsSmallKernelType || IsLargeKernelType))
1537     return false;
1538 
1539   const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1540   const SIInstrInfo *TII = ST.getInstrInfo();
1541 
1542   unsigned PipelineSyncID = 0;
1543   SchedGroup *SG = nullptr;
1544 
1545   unsigned MFMAChain = 0;
1546   unsigned PositionInChain = 0;
1547   unsigned CurrMFMAForTransPosition = 0;
1548 
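       // MFMAs are assigned round-robin across the chains, so a linear MFMA index
       // maps to chain = index % MFMAChains and position = index / MFMAChains.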
1549   auto incrementTransPosition = [&MFMAChain, &PositionInChain,
1550                                  &CurrMFMAForTransPosition]() {
1551     CurrMFMAForTransPosition += MFMAEnablement;
1552     PositionInChain = (CurrMFMAForTransPosition / MFMAChains);
1553     MFMAChain = CurrMFMAForTransPosition % MFMAChains;
1554   };
1555 
1556   auto getNextTransPositionInChain = [&CurrMFMAForTransPosition]() {
1557     auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1558     return (TempMFMAForTrans / MFMAChains);
1559   };
1560 
1561   auto getNextTransMFMAChain = [&CurrMFMAForTransPosition]() {
1562     auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1563     return TempMFMAForTrans % MFMAChains;
1564   };
1565 
1566   unsigned CurrMFMAPosition = 0;
1567   unsigned MFMAChainForMFMA = 0;
1568   unsigned PositionInChainForMFMA = 0;
1569 
1570   auto incrementMFMAPosition = [&CurrMFMAPosition, &MFMAChainForMFMA,
1571                                 &PositionInChainForMFMA]() {
1572     ++CurrMFMAPosition;
1573     MFMAChainForMFMA = CurrMFMAPosition % MFMAChains;
1574     PositionInChainForMFMA = CurrMFMAPosition / MFMAChains;
1575   };
1576 
1577   bool IsPostRA = Phase == AMDGPU::SchedulingPhase::PostRA;
1578   assert(IsPostRA || MFMAChainSeeds.size() == MFMAChains);
1579 
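       // Which auxiliary pipes participate, based on the kernel type and the
       // scheduling phase.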
1580   bool UsesFMA = IsSmallKernelType || !IsPostRA;
1581   bool UsesDSRead = IsLargeKernelType && !IsPostRA && FirstPipeDSR;
1582   bool UsesCvt = HasCvt && (IsSmallKernelType || !IsPostRA);
1583   bool UsesVALU = IsSmallKernelType;
1584 
1585   // PHASE 1: "Prefetch"
1586   if (UsesFMA) {
1587     // First Round FMA
1588     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1589         SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1590     if (!IsPostRA && MFMAChains) {
1591       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1592           PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1593           true));
1594     } else
1595       SG->addRule(
1596           std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1597     SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1598     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1599 
1600     // Second Round FMA
1601     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1602         SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1603     if (!IsPostRA && MFMAChains) {
1604       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1605           getNextTransPositionInChain(),
1606           MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1607     } else
1608       SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1609                                                    SG->getSGID(), true));
1610     SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1611     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1612   }
1613 
1614   if (UsesDSRead) {
1615     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1616         SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1617     SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1618                                                       SG->getSGID()));
1619     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1620   }
1621 
1622   // First Round EXP
1623   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1624       SchedGroupMask::TRANS, ExpRequirement, PipelineSyncID, DAG, TII);
1625   if (!IsPostRA && MFMAChains)
1626     SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1627         PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(), true));
1628   else
1629     SG->addRule(std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1630   SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1631   SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1632                                                HasChainBetweenCvt));
1633   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1634 
1635   incrementTransPosition();
1636 
1637   // First Round CVT, Third Round FMA, Second Round EXP; interleaved
1638   for (unsigned I = 0; I < ExpRequirement; I++) {
1639     // First Round CVT
1640     if (UsesCvt) {
1641       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1642           SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1643       SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1644       if (HasChainBetweenCvt)
1645         SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1646             1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1647       else
1648         SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(
1649             1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1650       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1651     }
1652 
1653     // Third Round FMA
1654     if (UsesFMA) {
1655       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1656           SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1657       if (!IsPostRA && MFMAChains) {
1658         SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1659             getNextTransPositionInChain(),
1660             MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1661       } else
1662         SG->addRule(std::make_shared<EnablesNthMFMA>(2 * MFMAEnablement + 1,
1663                                                      TII, SG->getSGID(), true));
1664       SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1665       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1666     }
1667 
1668     // Second Round EXP
1669     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1670         SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1671     if (!IsPostRA && MFMAChains)
1672       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1673           PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1674           true));
1675     else
1676       SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1677                                                    SG->getSGID(), true));
1678     SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1679     SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1680                                                  HasChainBetweenCvt));
1681     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1682   }
1683 
1684   // The "extra" EXP which enables all MFMA
1685   // TODO: UsesExtraExp
1686   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
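       // EXPs with at least eight successors are heuristically treated as this
       // broadcast EXP (see the GreaterThanOrEqualToNSuccs rule below).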
1687       SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1688   SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1689   SG->addRule(std::make_shared<GreaterThanOrEqualToNSuccs>(
1690       8, TII, SG->getSGID(), HasChainBetweenCvt));
1691   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1692 
1693   // PHASE 2: Main Interleave Loop
1694 
1695   // The number of MFMAs per iteration
1696   unsigned MFMARatio =
1697       MFMAEnablement > ExpRequirement ? MFMAEnablement / ExpRequirement : 1;
1698   // The number of Exps per iteration
1699   unsigned ExpRatio =
1700       MFMAEnablement > ExpRequirement ? 1 : ExpRequirement / MFMAEnablement;
1701   // The remaining Exps
1702   unsigned RemainingExp = TransPipeCount > (2 * ExpRequirement)
1703                               ? TransPipeCount - (2 * ExpRequirement)
1704                               : 0;
1705   unsigned ExpLoopCount = RemainingExp / ExpRatio;
1706   // In loop MFMAs
1707   unsigned MFMAInLoop = MFMAPipeCount > (MFMAEnablement * 2)
1708                             ? MFMAPipeCount - (MFMAEnablement * 2)
1709                             : 0;
1710   unsigned MFMALoopCount = MFMAInLoop / MFMARatio;
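       // The number of adds to interleave per MFMA group (at least one).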
1711   unsigned VALUOps =
1712       AddPipeCount < MFMAPipeCount ? 1 : AddPipeCount / MFMAPipeCount;
1713   unsigned LoopSize = std::min(ExpLoopCount, MFMALoopCount);
1714 
1715   for (unsigned I = 0; I < LoopSize; I++) {
1716     if (!(I * ExpRatio % ExpRequirement))
1717       incrementTransPosition();
1718 
1719     // Round N MFMA
1720     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1721         SchedGroupMask::MFMA, MFMARatio, PipelineSyncID, DAG, TII);
1722     if (!IsPostRA && MFMAChains)
1723       SG->addRule(std::make_shared<IsExactMFMA>(
1724           PositionInChainForMFMA, MFMAChainSeeds[MFMAChainForMFMA], TII,
1725           SG->getSGID(), true));
1726     else
1727       SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1728     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1729     incrementMFMAPosition();
1730 
1731     if (UsesVALU) {
1732       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1733           SchedGroupMask::VALU, VALUOps, PipelineSyncID, DAG, TII);
1734       SG->addRule(std::make_shared<IsPipeAdd>(TII, SG->getSGID()));
1735       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1736     }
1737 
1738     if (UsesDSRead && !(I % 4)) {
1739       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1740           SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1741       SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1742                                                         SG->getSGID()));
1743       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1744     }
1745 
1746     // CVT, EXP, FMA Interleaving
1747     for (unsigned J = 0; J < ExpRatio; J++) {
1748       auto MFMAOffset = (1 + UsesVALU) * MFMARatio * (I + 1);
1749       auto MaxMFMAOffset =
1750           (1 + UsesVALU) * ExpRequirement * MFMARatio / ExpRatio;
1751 
1752       // Round N + 1 CVT
1753       if (UsesCvt) {
1754         SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1755             SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1756         SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1757         auto BaseDiff = (2 + UsesFMA) * (ExpRequirement - 1) + 1;
1758         auto DSROffset = I / 4 + 1;
1759         auto MaxDSROffset = MaxMFMAOffset / 4;
1760         // TODO: UsesExtraExp
1761         auto ExpOffset = I * ExpRatio + J >= ExpRequirement ? 0 : 1;
1762         auto CurrentOffset = UsesDSRead * std::min(MaxDSROffset, DSROffset) +
1763                              std::min(MaxMFMAOffset, MFMAOffset) + BaseDiff +
1764                              ExpOffset;
1765         if (HasChainBetweenCvt)
1766           SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1767               CurrentOffset, TII, SG->getSGID()));
1768         else
1769           SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(CurrentOffset, TII,
1770                                                              SG->getSGID()));
1771         SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1772       }
1773 
1774       // Round N + 3 FMA
1775       if (UsesFMA) {
1776         SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1777             SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1778         if (!IsPostRA && MFMAChains)
1779           SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1780               getNextTransPositionInChain(),
1781               MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(),
1782               true));
1783         else
1784           SG->addRule(std::make_shared<EnablesNthMFMA>(
1785               (((I * ExpRatio + J) / ExpRequirement) + 3) * MFMAEnablement + 1,
1786               TII, SG->getSGID(), true));
1787         SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1788         SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1789       }
1790 
1791       // Round N + 2 Exp
1792       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1793           SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1794       if (!IsPostRA && MFMAChains)
1795         SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1796             PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1797             true));
1798       else
1799         SG->addRule(std::make_shared<EnablesNthMFMA>(
1800             (((I * ExpRatio + J) / ExpRequirement) + 2) * MFMAEnablement + 1,
1801             TII, SG->getSGID(), true));
1802       SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1803       SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1804                                                    HasChainBetweenCvt));
1805       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1806     }
1807   }
1808 
1809   // PHASE 3: Remaining MFMAs
1810   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1811       SchedGroupMask::MFMA, MFMAEnablement * 2, PipelineSyncID, DAG, TII);
1812   SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1813   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1814   return true;
1815 }
1816 
1817 class MFMAExpSimpleInterleaveOpt final : public IGLPStrategy {
1818 public:
1819   bool applyIGLPStrategy(
1820       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1821       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1822       AMDGPU::SchedulingPhase Phase) override;
1823 
1824   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1825                            AMDGPU::SchedulingPhase Phase) override {
1826     return true;
1827   }
1828 
1829   MFMAExpSimpleInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1830       : IGLPStrategy(DAG, TII) {
1831     IsBottomUp = true;
1832   }
1833 };
1834 
1835 bool MFMAExpSimpleInterleaveOpt::applyIGLPStrategy(
1836     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1837     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1838     AMDGPU::SchedulingPhase Phase) {
1839   // Count the number of MFMA instructions.
1840   unsigned MFMACount = 0;
1841   for (const MachineInstr &I : *DAG)
1842     if (TII->isMFMAorWMMA(I))
1843       ++MFMACount;
1844 
1845   const unsigned PipelineSyncID = 0;
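       // Alternate single-TRANS and single-MFMA groups. The trip count allows up
       // to three TRANS ops per MFMA; groups that find no candidates are simply
       // left unfilled.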
1846   for (unsigned I = 0; I < MFMACount * 3; ++I) {
1847     SchedGroup *SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1848         SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1849     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1850 
1851     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1852         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1853     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1854   }
1855 
1856   return true;
1857 }
1858 
1859 class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1860 private:
1861   // Whether the DS_READ is a predecessor of the first four MFMAs in the region
1862   class EnablesInitialMFMA final : public InstructionRule {
1863   public:
1864     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1865                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1866       if (!SyncPipe.size())
1867         return false;
1868       int MFMAsFound = 0;
1869       if (!Cache->size()) {
1870         for (auto &Elt : SyncPipe[0].DAG->SUnits) {
1871           if (TII->isMFMAorWMMA(*Elt.getInstr())) {
1872             ++MFMAsFound;
1873             if (MFMAsFound > 4)
1874               break;
1875             Cache->push_back(&Elt);
1876           }
1877         }
1878       }
1879 
1880       auto *DAG = SyncPipe[0].DAG;
1881       for (auto &Elt : *Cache) {
1882         if (DAG->IsReachable(Elt, const_cast<SUnit *>(SU)))
1883           return true;
1884       }
1885       return false;
1886     }
1887 
1888     EnablesInitialMFMA(const SIInstrInfo *TII, unsigned SGID,
1889                        bool NeedsCache = false)
1890         : InstructionRule(TII, SGID, NeedsCache) {}
1891   };
1892 
1893   // Whether the MI is a V_PERM and is a predecessor of a common DS_WRITE
1894   class IsPermForDSW final : public InstructionRule {
1895   public:
1896     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1897                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1898       auto *MI = SU->getInstr();
1899       if (MI->getOpcode() != AMDGPU::V_PERM_B32_e64)
1900         return false;
1901 
1902       bool FitsInGroup = false;
1903       // Does the VALU have a DS_WRITE successor
1904       if (!Collection.size()) {
1905         for (auto &Succ : SU->Succs) {
1906           SUnit *SuccUnit = Succ.getSUnit();
1907           if (TII->isDS(*SuccUnit->getInstr()) &&
1908               SuccUnit->getInstr()->mayStore()) {
1909             Cache->push_back(SuccUnit);
1910             FitsInGroup = true;
1911           }
1912         }
1913         return FitsInGroup;
1914       }
1915 
1916       // Does the VALU have a DS_WRITE successor that is the same as another
1917       // VALU already in the group? The V_PERMs will all share one DS_WRITE succ.
1918       return llvm::any_of(*Cache, [&SU](SUnit *Elt) {
1919         return llvm::any_of(SU->Succs, [&Elt](const SDep &ThisSucc) {
1920           return ThisSucc.getSUnit() == Elt;
1921         });
1922       });
1923     }
1924 
1925     IsPermForDSW(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1926         : InstructionRule(TII, SGID, NeedsCache) {}
1927   };
1928 
1929   // Whether the SU is a successor of any element in the previous SchedGroup
1930   class IsSuccOfPrevGroup final : public InstructionRule {
1931   public:
1932     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1933                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1934       SchedGroup *OtherGroup = nullptr;
1935       for (auto &PipeSG : SyncPipe) {
1936         if ((unsigned)PipeSG.getSGID() == SGID - 1) {
1937           OtherGroup = &PipeSG;
1938         }
1939       }
1940 
1941       if (!OtherGroup)
1942         return false;
1943       if (!OtherGroup->Collection.size())
1944         return true;
1945 
1946       // Does the previous VALU have this DS_WRITE as a successor?
1947       return any_of(OtherGroup->Collection, [&SU](SUnit *Elt) {
1948         return any_of(Elt->Succs,
1949                       [&SU](SDep &Succ) { return Succ.getSUnit() == SU; });
1950       });
1951     }
1952     IsSuccOfPrevGroup(const SIInstrInfo *TII, unsigned SGID,
1953                       bool NeedsCache = false)
1954         : InstructionRule(TII, SGID, NeedsCache) {}
1955   };
1956 
1957   // Whether the combined load width of the group is 128 bits
1958   class VMEMSize final : public InstructionRule {
1959   public:
1960     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1961                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1962       auto *MI = SU->getInstr();
1963       if (MI->getOpcode() == TargetOpcode::BUNDLE)
1964         return false;
1965       if (!Collection.size())
1966         return true;
1967 
1968       int NumBits = 0;
1969 
1970       auto TRI = TII->getRegisterInfo();
1971       auto &MRI = MI->getParent()->getParent()->getRegInfo();
1972       for (auto &Elt : Collection) {
1973         auto Op = Elt->getInstr()->getOperand(0);
1974         auto Size =
1975             TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(MRI, Op));
1976         NumBits += Size;
1977       }
1978 
1979       if (NumBits < 128) {
1980         assert(TII->isVMEM(*MI) && MI->mayLoad());
1981         if (NumBits + TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(
1982                           MRI, MI->getOperand(0))) <=
1983             128)
1984           return true;
1985       }
1986 
1987       return false;
1988     }
1989 
1990     VMEMSize(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1991         : InstructionRule(TII, SGID, NeedsCache) {}
1992   };
1993 
1994   /// Whether the SU shares a V_PERM predecessor with any SU in the SchedGroup
1995   /// that is \p Distance steps away
1996   class SharesPredWithPrevNthGroup final : public InstructionRule {
1997   private:
1998     unsigned Distance = 1;
1999 
2000   public:
2001     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
2002                SmallVectorImpl<SchedGroup> &SyncPipe) override {
2003       SchedGroup *OtherGroup = nullptr;
2004       if (!SyncPipe.size())
2005         return false;
2006 
2007       if (!Cache->size()) {
2008 
2009         for (auto &PipeSG : SyncPipe) {
2010           if ((unsigned)PipeSG.getSGID() == SGID - Distance) {
2011             OtherGroup = &PipeSG;
2012           }
2013         }
2014 
2015         if (!OtherGroup)
2016           return false;
2017         if (!OtherGroup->Collection.size())
2018           return true;
2019 
2020         for (auto &OtherEle : OtherGroup->Collection) {
2021           for (auto &Pred : OtherEle->Preds) {
2022             if (Pred.getSUnit()->getInstr()->getOpcode() ==
2023                 AMDGPU::V_PERM_B32_e64)
2024               Cache->push_back(Pred.getSUnit());
2025           }
2026         }
2027 
2028         // If the other group has no PERM preds, then this group won't share any
2029         if (!Cache->size())
2030           return false;
2031       }
2032 
2033       auto *DAG = SyncPipe[0].DAG;
2034       // Does the previous DS_WRITE share a V_PERM predecessor with this
2035       // VMEM_READ
2036       return llvm::any_of(*Cache, [&SU, &DAG](SUnit *Elt) {
2037         return DAG->IsReachable(const_cast<SUnit *>(SU), Elt);
2038       });
2039     }
2040     SharesPredWithPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
2041                                unsigned SGID, bool NeedsCache = false)
2042         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
2043   };
2044 
2045 public:
2046   bool applyIGLPStrategy(
2047       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2048       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2049       AMDGPU::SchedulingPhase Phase) override;
2050 
2051   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
2052                            AMDGPU::SchedulingPhase Phase) override {
2053     return true;
2054   }
2055 
2056   MFMASmallGemmSingleWaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
2057       : IGLPStrategy(DAG, TII) {
2058     IsBottomUp = false;
2059   }
2060 };
2061 
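     // These counters persist across scheduling phases: they are computed during
     // the initial phase and reused when the mutation reenters for later phases.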
2062 static unsigned DSWCount = 0;
2063 static unsigned DSWWithPermCount = 0;
2064 static unsigned DSWWithSharedVMEMCount = 0;
2065 
2066 bool MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
2067     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2068     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2069     AMDGPU::SchedulingPhase Phase) {
2070   unsigned MFMACount = 0;
2071   unsigned DSRCount = 0;
2072 
2073   bool IsInitial = Phase == AMDGPU::SchedulingPhase::Initial;
2074 
2075   assert((!IsInitial || (DSWCount == 0 && DSWWithPermCount == 0 &&
2076                          DSWWithSharedVMEMCount == 0)) &&
2077          "DSWCounters should be zero in pre-RA scheduling!");
2078   SmallVector<SUnit *, 6> DSWithPerms;
2079   for (auto &SU : DAG->SUnits) {
2080     auto *I = SU.getInstr();
2081     if (TII->isMFMAorWMMA(*I))
2082       ++MFMACount;
2083     else if (TII->isDS(*I)) {
2084       if (I->mayLoad())
2085         ++DSRCount;
2086       else if (I->mayStore() && IsInitial) {
2087         ++DSWCount;
2088         for (auto Pred : SU.Preds) {
2089           if (Pred.getSUnit()->getInstr()->getOpcode() ==
2090               AMDGPU::V_PERM_B32_e64) {
2091             DSWithPerms.push_back(&SU);
2092             break;
2093           }
2094         }
2095       }
2096     }
2097   }
2098 
2099   if (IsInitial) {
2100     DSWWithPermCount = DSWithPerms.size();
2101     auto *I = DSWithPerms.begin();
2102     auto *E = DSWithPerms.end();
2103 
2104     // Get the count of DS_WRITES with V_PERM predecessors which
2105     // have loop carried dependencies (WAR) on the same VMEM_READs.
2106     // We consider partial overlap as a miss -- in other words,
2107     // for a given DS_W, we only consider another DS_W as matching
2108     // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
2109     // for every V_PERM pred of this DS_W.
2110     DenseMap<MachineInstr *, SUnit *> VMEMLookup;
2111     SmallVector<SUnit *, 6> Counted;
2112     for (; I != E; I++) {
2113       SUnit *Cand = nullptr;
2114       bool MissedAny = false;
2115       for (auto &Pred : (*I)->Preds) {
2116         if (Pred.getSUnit()->getInstr()->getOpcode() != AMDGPU::V_PERM_B32_e64)
2117           continue;
2118 
2119         if (Cand && llvm::is_contained(Counted, Cand))
2120           break;
2121 
2122         for (auto &Succ : Pred.getSUnit()->Succs) {
2123           auto *MI = Succ.getSUnit()->getInstr();
2124           if (!TII->isVMEM(*MI) || !MI->mayLoad())
2125             continue;
2126 
2127           if (MissedAny || !VMEMLookup.size()) {
2128             MissedAny = true;
2129             VMEMLookup[MI] = *I;
2130             continue;
2131           }
2132 
2133           auto [It, Inserted] = VMEMLookup.try_emplace(MI, *I);
2134           if (Inserted) {
2135             MissedAny = true;
2136             continue;
2137           }
2138 
2139           Cand = It->second;
2140           if (llvm::is_contained(Counted, Cand)) {
2141             MissedAny = true;
2142             break;
2143           }
2144         }
2145       }
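           // A full match: count both DS_WRITEs of the pair and mark them so that
           // neither is matched again.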
2146       if (!MissedAny && Cand) {
2147         DSWWithSharedVMEMCount += 2;
2148         Counted.push_back(Cand);
2149         Counted.push_back(*I);
2150       }
2151     }
2152   }
2153 
2154   assert(DSWWithSharedVMEMCount <= DSWWithPermCount);
2155   SchedGroup *SG;
2156   unsigned PipelineSyncID = 0;
2157   // For kernels with V_PERM, there are enough VALU to mix in between MFMAs
2158   if (DSWWithPermCount) {
2159     for (unsigned I = 0; I < MFMACount; I++) {
2160       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2161           SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2162       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2163 
2164       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2165           SchedGroupMask::VALU, 2, PipelineSyncID, DAG, TII);
2166       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2167     }
2168   }
2169 
2170   PipelineSyncID = 1;
2171   // Phase 1: Break up DS_READ and MFMA clusters.
2172   // First, DS_READs to make the initial MFMA ready; then interleave MFMAs with
2173   // DS_READ prefetches.
2174 
2175   // Make the initial MFMA ready
2176   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2177       SchedGroupMask::DS_READ, 4, PipelineSyncID, DAG, TII);
2178   SG->addRule(std::make_shared<EnablesInitialMFMA>(TII, SG->getSGID(), true));
2179   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2180 
2181   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2182       SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2183   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2184 
2185   // Interleave MFMA with DS_READ prefetch
2186   for (unsigned I = 0; I < DSRCount - 4; ++I) {
2187     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2188         SchedGroupMask::DS_READ, 1, PipelineSyncID, DAG, TII);
2189     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2190 
2191     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2192         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2193     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2194   }
2195 
2196   // Phase 2a: Loop carried dependency with V_PERM
2197   // Schedule V_PERMs & DS_WRITEs as closely as possible to the VMEM_READs they
2198   // depend on. Interleave MFMA to keep XDL unit busy throughout.
2199   for (unsigned I = 0; I < DSWWithPermCount - DSWWithSharedVMEMCount; ++I) {
2200     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2201         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2202     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2203     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2204 
2205     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2206         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2207     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2208     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2209 
2210     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2211         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2212     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2213         1, TII, SG->getSGID(), true));
2214     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2215     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2216 
2217     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2218         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2219     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2220 
2221     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2222         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2223     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2224         3, TII, SG->getSGID(), true));
2225     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2226     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2227 
2228     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2229         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2230     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2231   }
2232 
2233   // Phase 2b: Loop carried dependency without V_PERM
2234   // Schedule DS_WRITEs as closely as possible to the VMEM_READs they depend on.
2235   // Interleave MFMA to keep XDL unit busy throughout.
2236   for (unsigned I = 0; I < DSWCount - DSWWithPermCount; I++) {
2237     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2238         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2239     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2240 
2241     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2242         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2243     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2244     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2245 
2246     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2247         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2248     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2249   }
2250 
2251   // Phase 2c: Loop carried dependency with V_PERM, VMEM_READs are
2252   // ultimately used by two DS_WRITEs.
2253   // Schedule V_PERMs & DS_WRITEs as closely as possible to the VMEM_READs they
2254   // depend on. Interleave MFMA to keep XDL unit busy throughout.
2255 
2256   for (unsigned I = 0; I < DSWWithSharedVMEMCount; ++I) {
2257     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2258         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2259     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2260     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2261 
2262     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2263         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2264     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2265     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2266 
2267     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2268         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2269     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2270 
2271     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2272         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2273     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2274     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2275 
2276     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2277         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2278     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2279     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2280 
2281     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2282         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2283     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2284 
2285     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2286         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2287     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2288         2, TII, SG->getSGID(), true));
2289     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2290     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2291 
2292     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2293         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2294     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2295 
2296     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2297         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2298     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2299         4, TII, SG->getSGID(), true));
2300     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2301     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2302 
2303     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2304         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2305     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2306   }
2307 
2308   return true;
2309 }
2310 
2311 static std::unique_ptr<IGLPStrategy>
2312 createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
2313                    const SIInstrInfo *TII) {
2314   switch (ID) {
2315   case MFMASmallGemmOptID:
2316     return std::make_unique<MFMASmallGemmOpt>(DAG, TII);
2317   case MFMASmallGemmSingleWaveOptID:
2318     return std::make_unique<MFMASmallGemmSingleWaveOpt>(DAG, TII);
2319   case MFMAExpInterleaveID:
2320     return std::make_unique<MFMAExpInterleaveOpt>(DAG, TII);
2321   case MFMAExpSimpleInterleaveID:
2322     return std::make_unique<MFMAExpSimpleInterleaveOpt>(DAG, TII);
2323   }
2324 
2325   llvm_unreachable("Unknown IGLPStrategyID");
2326 }
2327 
2328 class IGroupLPDAGMutation : public ScheduleDAGMutation {
2329 private:
2330   const SIInstrInfo *TII;
2331 
2332   ScheduleDAGMI *DAG;
2333 
2334   // Organize lists of SchedGroups by their SyncID. SchedGroups /
2335   // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
2336   // between them.
2337   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
2338 
2339   // Used to track instructions that can be mapped to multiple sched groups
2340   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
2341 
2342   // Add DAG edges that enforce SCHED_BARRIER ordering.
2343   void addSchedBarrierEdges(SUnit &SU);
2344 
2345   // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
2346   // not be reordered across the SCHED_BARRIER. This is used for the base
2347   // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
2348   // SCHED_BARRIER will always block all instructions that can be classified
2349   // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
2350   // and may only synchronize with some SchedGroups. Returns the inverse of
2351   // Mask. SCHED_BARRIER's mask describes which instruction types should be
2352   // allowed to be scheduled across it. Invert the mask to get the
2353   // SchedGroupMask of instructions that should be barred.
2354   SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;
2355 
2356   // Create SchedGroups for a SCHED_GROUP_BARRIER.
2357   void initSchedGroupBarrierPipelineStage(
2358       std::vector<SUnit>::reverse_iterator RIter);
2359 
2360   bool initIGLPOpt(SUnit &SU);
2361 
2362 public:
2363   void apply(ScheduleDAGInstrs *DAGInstrs) override;
2364 
2365   // The order in which the PipelineSolver should process the candidate
2366   // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last
2367   // created SchedGroup first, and will consider that as the ultimate
2368   // predecessor group when linking. TOP_DOWN instead links and processes the
2369   // first created SchedGroup first.
2370   bool IsBottomUp = true;
2371 
2372   // The scheduling phase this application of IGLP corresponds with.
2373   AMDGPU::SchedulingPhase Phase = AMDGPU::SchedulingPhase::Initial;
2374 
2375   IGroupLPDAGMutation() = default;
2376   IGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) : Phase(Phase) {}
2377 };
2378 
2379 unsigned SchedGroup::NumSchedGroups = 0;
2380 
2381 bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
2382   if (A != B && DAG->canAddEdge(B, A)) {
2383     DAG->addEdge(B, SDep(A, SDep::Artificial));
2384     return true;
2385   }
2386   return false;
2387 }
2388 
2389 bool SchedGroup::canAddMI(const MachineInstr &MI) const {
2390   bool Result = false;
2391   if (MI.isMetaInstruction())
2392     Result = false;
2393 
2394   else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
2395            (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI) ||
2396             TII->isTRANS(MI)))
2397     Result = true;
2398 
2399   else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
2400            TII->isVALU(MI) && !TII->isMFMAorWMMA(MI) && !TII->isTRANS(MI))
2401     Result = true;
2402 
2403   else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
2404            TII->isSALU(MI))
2405     Result = true;
2406 
2407   else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
2408            TII->isMFMAorWMMA(MI))
2409     Result = true;
2410 
2411   else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
2412            TII->isVMEM(MI))
2413     Result = true;
2414 
2415   else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
2416            MI.mayLoad() && TII->isVMEM(MI))
2417     Result = true;
2418 
2419   else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
2420            MI.mayStore() && TII->isVMEM(MI))
2421     Result = true;
2422 
2423   else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
2424            TII->isDS(MI))
2425     Result = true;
2426 
2427   else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
2428            MI.mayLoad() && TII->isDS(MI))
2429     Result = true;
2430 
2431   else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
2432            MI.mayStore() && TII->isDS(MI))
2433     Result = true;
2434 
2435   else if (((SGMask & SchedGroupMask::TRANS) != SchedGroupMask::NONE) &&
2436            TII->isTRANS(MI))
2437     Result = true;
2438 
2439   LLVM_DEBUG(
2440       dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
2441              << (Result ? " could classify " : " unable to classify ") << MI);
2442 
2443   return Result;
2444 }
2445 
2446 int SchedGroup::link(SUnit &SU, bool MakePred,
2447                      std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
2448   int MissedEdges = 0;
2449   for (auto *A : Collection) {
2450     SUnit *B = &SU;
2451     if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2452       continue;
2453     if (MakePred)
2454       std::swap(A, B);
2455 
2456     if (DAG->IsReachable(B, A))
2457       continue;
2458 
2459     // tryAddEdge returns false if there is a dependency that makes adding
2460     // the A->B edge impossible; otherwise it returns true.
2461     bool Added = tryAddEdge(A, B);
2462     if (Added)
2463       AddedEdges.emplace_back(A, B);
2464     else
2465       ++MissedEdges;
2466   }
2467 
2468   return MissedEdges;
2469 }
2470 
2471 void SchedGroup::link(SUnit &SU, bool MakePred) {
2472   for (auto *A : Collection) {
2473     SUnit *B = &SU;
2474     if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2475       continue;
2476     if (MakePred)
2477       std::swap(A, B);
2478 
2479     tryAddEdge(A, B);
2480   }
2481 }
2482 
2483 void SchedGroup::link(SUnit &SU,
2484                       function_ref<bool(const SUnit *A, const SUnit *B)> P) {
2485   for (auto *A : Collection) {
2486     SUnit *B = &SU;
2487     if (P(A, B))
2488       std::swap(A, B);
2489 
2490     tryAddEdge(A, B);
2491   }
2492 }
2493 
2494 void SchedGroup::link(SchedGroup &OtherGroup) {
2495   for (auto *B : OtherGroup.Collection)
2496     link(*B);
2497 }
2498 
2499 bool SchedGroup::canAddSU(SUnit &SU) const {
2500   MachineInstr &MI = *SU.getInstr();
2501   if (MI.getOpcode() != TargetOpcode::BUNDLE)
2502     return canAddMI(MI);
2503 
2504   // Special case for bundled MIs.
2505   const MachineBasicBlock *MBB = MI.getParent();
2506   MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
2507   while (E != MBB->end() && E->isBundledWithPred())
2508     ++E;
2509 
2510   // Return true if all of the bundled MIs can be added to this group.
2511   return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });
2512 }
2513 
2514 void SchedGroup::initSchedGroup() {
2515   for (auto &SU : DAG->SUnits) {
2516     if (isFull())
2517       break;
2518 
2519     if (canAddSU(SU))
2520       add(SU);
2521   }
2522 }
2523 
2524 void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
2525                                 SUnitsToCandidateSGsMap &SyncedInstrs) {
2526   SUnit &InitSU = *RIter;
2527   for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {
2528     auto &SU = *RIter;
2529     if (isFull())
2530       break;
2531 
2532     if (canAddSU(SU))
2533       SyncedInstrs[&SU].push_back(SGID);
2534   }
2535 
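       // The SCHED_GROUP_BARRIER itself joins the group, so grow MaxSize by one
       // to preserve the requested number of candidate slots.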
2536   add(InitSU);
2537   assert(MaxSize);
2538   (*MaxSize)++;
2539 }
2540 
2541 void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {
2542   auto I = DAG->SUnits.rbegin();
2543   auto E = DAG->SUnits.rend();
2544   for (; I != E; ++I) {
2545     auto &SU = *I;
2546     if (isFull())
2547       break;
2548     if (canAddSU(SU))
2549       SyncedInstrs[&SU].push_back(SGID);
2550   }
2551 }
2552 
2553 void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
2554   const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
2555   if (!TSchedModel || DAGInstrs->SUnits.empty())
2556     return;
2557 
2558   LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
2559   const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
2560   TII = ST.getInstrInfo();
2561   DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
2562   SyncedSchedGroups.clear();
2563   SyncedInstrs.clear();
2564   bool FoundSB = false;
2565   bool FoundIGLP = false;
2566   bool ShouldApplyIGLP = false;
2567   for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
2568     unsigned Opc = R->getInstr()->getOpcode();
2569     // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
2570     if (Opc == AMDGPU::SCHED_BARRIER) {
2571       addSchedBarrierEdges(*R);
2572       FoundSB = true;
2573     } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
2574       initSchedGroupBarrierPipelineStage(R);
2575       FoundSB = true;
2576     } else if (Opc == AMDGPU::IGLP_OPT) {
2577       if (!FoundSB && !FoundIGLP) {
2578         FoundIGLP = true;
2579         ShouldApplyIGLP = initIGLPOpt(*R);
2580       }
2581     }
2582   }
2583 
2584   if (FoundSB || (FoundIGLP && ShouldApplyIGLP)) {
2585     PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
2586     // PipelineSolver performs the mutation by adding the edges it
2587     // determined to be the best.
2588     PS.solve();
2589     return;
2590   }
2591 }
2592 
2593 void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
2594   MachineInstr &MI = *SchedBarrier.getInstr();
2595   assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
2596   // Build a SchedGroup of the instruction types that must not be scheduled
2597   // across the SCHED_BARRIER, and link them to it.
2598   LLVM_DEBUG(dbgs() << "Building SchedGroup for SchedBarrier with Mask: "
2599                     << MI.getOperand(0).getImm() << "\n");
2600   auto InvertedMask =
2601       invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());
2602   SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
2603   SG.initSchedGroup();
2604 
2605   // Preserve original instruction ordering relative to the SCHED_BARRIER.
2606   SG.link(
2607       SchedBarrier,
2608       (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
2609           const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
2610 }
2611 
2612 SchedGroupMask
2613 IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
2614   // Invert mask and erase bits for types of instructions that are implied to be
2615   // allowed past the SCHED_BARRIER.
2616   SchedGroupMask InvertedMask = ~Mask;
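       // For example, a mask that only permits DS instructions across the barrier
       // inverts to one that blocks everything except DS; the implied DS_READ and
       // DS_WRITE bits are then cleared below.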
2617 
2618   // ALU implies VALU, SALU, MFMA, TRANS.
2619   if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
2620     InvertedMask &= ~SchedGroupMask::VALU & ~SchedGroupMask::SALU &
2621                     ~SchedGroupMask::MFMA & ~SchedGroupMask::TRANS;
2622   // VALU, SALU, MFMA, TRANS implies ALU.
2623   else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
2624            (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
2625            (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE ||
2626            (InvertedMask & SchedGroupMask::TRANS) == SchedGroupMask::NONE)
2627     InvertedMask &= ~SchedGroupMask::ALU;
2628 
2629   // VMEM implies VMEM_READ, VMEM_WRITE.
2630   if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
2631     InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;
2632   // VMEM_READ, VMEM_WRITE implies VMEM.
2633   else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
2634            (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)
2635     InvertedMask &= ~SchedGroupMask::VMEM;
2636 
2637   // DS implies DS_READ, DS_WRITE.
2638   if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
2639     InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;
2640   // DS_READ, DS_WRITE implies DS.
2641   else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
2642            (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
2643     InvertedMask &= ~SchedGroupMask::DS;
2644 
2645   LLVM_DEBUG(dbgs() << "After Inverting, SchedGroup Mask: " << (int)InvertedMask
2646                     << "\n");
2647 
2648   return InvertedMask;
2649 }
2650 
2651 void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
2652     std::vector<SUnit>::reverse_iterator RIter) {
2653   // Decode the SCHED_GROUP_BARRIER's operands (mask, size, sync ID) and build
2654   // the corresponding SchedGroup.
2655   MachineInstr &SGB = *RIter->getInstr();
2656   assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
2657   int32_t SGMask = SGB.getOperand(0).getImm();
2658   int32_t Size = SGB.getOperand(1).getImm();
2659   int32_t SyncID = SGB.getOperand(2).getImm();
2660 
2661   auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,
2662                                                     Size, SyncID, DAG, TII);
2663 
2664   SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]);
2665 }
2666 
2667 bool IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
2668   IGLPStrategyID StrategyID =
2669       (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm();
2670   auto S = createIGLPStrategy(StrategyID, DAG, TII);
2671   if (!S->shouldApplyStrategy(DAG, Phase))
2672     return false;
2673 
2674   IsBottomUp = S->IsBottomUp;
2675   return S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups, Phase);
2676 }
2677 
2678 } // namespace
2679 
2680 /// \p Phase specifies whether or not this is a reentry into the
2681 /// IGroupLPDAGMutation. Since there may be multiple scheduling passes on the
2682 /// same scheduling region (e.g. pre and post-RA scheduling / multiple
2683 /// scheduling "phases"), we can reenter this mutation framework more than once
2684 /// for a given region.
2685 std::unique_ptr<ScheduleDAGMutation>
2686 llvm::createIGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) {
2687   return std::make_unique<IGroupLPDAGMutation>(Phase);
2688 }
2689