1 //===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP  ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file This file defines a set of schedule DAG mutations that can be used to
10 // override default scheduler behavior to enforce specific scheduling patterns.
11 // They should be used in cases where runtime performance considerations, such
12 // as inter-wavefront interactions, mean that compile-time heuristics cannot
13 // predict the optimal instruction ordering, or in kernels where optimum
14 // instruction scheduling is important enough to warrant manual intervention.
15 //
16 //===----------------------------------------------------------------------===//
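//
// Illustrative usage sketch (the builtin name and mask encoding below reflect
// the AMDGPU scheduling builtins as understood here; verify against the
// compiler in use). A kernel that wants two LDS reads followed by one MFMA
// per slot, all in sync group 0, might request it with:
//
//   __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // 2 x DS read
//   __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 x MFMA
//
// The mask values mirror the SchedGroupMask bits defined later in this file.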
17 
18 #include "AMDGPUIGroupLP.h"
19 #include "AMDGPUTargetMachine.h"
20 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
21 #include "SIInstrInfo.h"
22 #include "SIMachineFunctionInfo.h"
23 #include "llvm/ADT/BitmaskEnum.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/CodeGen/MachineScheduler.h"
26 #include "llvm/CodeGen/TargetOpcodes.h"
27 
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "igrouplp"
31 
32 namespace {
33 
34 static cl::opt<bool> EnableExactSolver(
35     "amdgpu-igrouplp-exact-solver", cl::Hidden,
36     cl::desc("Whether to use the exponential time solver to fit "
37              "the instructions to the pipeline as closely as "
38              "possible."),
39     cl::init(false));
40 
41 static cl::opt<unsigned> CutoffForExact(
42     "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
43     cl::desc("The maximum number of scheduling group conflicts "
44              "which we attempt to solve with the exponential time "
45              "exact solver. Problem sizes greater than this will "
46              "be solved by the less accurate greedy algorithm. Selecting "
47              "the solver by size is superseded by manually selecting "
48              "the solver (e.g. by amdgpu-igrouplp-exact-solver)."));
49 
50 static cl::opt<uint64_t> MaxBranchesExplored(
51     "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
52     cl::desc("The number of branches that we are willing to explore with "
53              "the exact algorithm before giving up."));
54 
55 static cl::opt<bool> UseCostHeur(
56     "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
57     cl::desc("Whether to use the cost heuristic to make choices as we "
58              "traverse the search space using the exact solver. Defaulted "
59              "to on, and if turned off, we will use the node order -- "
60              "attempting to put the later nodes in the later sched groups. "
61              "Experimentally, results are mixed, so this should be set on a "
62              "case-by-case basis."));
63 
64 // Components of the mask that determines which instruction types may be
65 // classified into a SchedGroup.
66 enum class SchedGroupMask {
67   NONE = 0u,
68   ALU = 1u << 0,
69   VALU = 1u << 1,
70   SALU = 1u << 2,
71   MFMA = 1u << 3,
72   VMEM = 1u << 4,
73   VMEM_READ = 1u << 5,
74   VMEM_WRITE = 1u << 6,
75   DS = 1u << 7,
76   DS_READ = 1u << 8,
77   DS_WRITE = 1u << 9,
78   TRANS = 1u << 10,
79   ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
80         DS_READ | DS_WRITE | TRANS,
81   LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
82 };
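// Note (illustrative): the bits may be OR'd together, so a SchedGroup created
// with (SchedGroupMask::DS_READ | SchedGroupMask::DS_WRITE) accepts either
// kind of LDS access, while SchedGroupMask::ALL places no type restriction.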
83 
84 class SchedGroup;
85 
86 // The InstructionRule class enacts a filter that determines whether or
87 // not an SU maps to a given SchedGroup. It contains complementary data
88 // structures (e.g. Cache) to help those filters.
89 class InstructionRule {
90 protected:
91   const SIInstrInfo *TII;
92   unsigned SGID;
93   // A cache made available to the Filter to store SUnits for subsequent
94   // invocations of the Filter
95   std::optional<SmallVector<SUnit *, 4>> Cache;
96 
97 public:
98   virtual bool
99   apply(const SUnit *, const ArrayRef<SUnit *>,
100         SmallVectorImpl<SchedGroup> &) {
101     return true;
102   };
103 
104   InstructionRule(const SIInstrInfo *TII, unsigned SGID,
105                   bool NeedsCache = false)
106       : TII(TII), SGID(SGID) {
107     if (NeedsCache) {
108       Cache = SmallVector<SUnit *, 4>();
109     }
110   }
111 
112   virtual ~InstructionRule() = default;
113 };
114 
115 typedef DenseMap<SUnit *, SmallVector<int, 4>> SUnitsToCandidateSGsMap;
116 
117 // Classify instructions into groups to enable fine-tuned control over the
118 // scheduler. These groups may be more specific than current SchedModel
119 // instruction classes.
120 class SchedGroup {
121 private:
122   // Mask that defines which instruction types can be classified into this
123   // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
124   // and SCHED_GROUP_BARRIER.
125   SchedGroupMask SGMask;
126 
127   // Maximum number of SUnits that can be added to this group.
128   std::optional<unsigned> MaxSize;
129 
130   // SchedGroups will only synchronize with other SchedGroups that have the same
131   // SyncID.
132   int SyncID = 0;
133 
134   // SGID is used to map instructions to candidate SchedGroups
135   unsigned SGID;
136 
137   // The different rules each instruction in this SchedGroup must conform to
138   SmallVector<std::shared_ptr<InstructionRule>, 4> Rules;
139 
140   // Count of the number of created SchedGroups, used to initialize SGID.
141   static unsigned NumSchedGroups;
142 
143   const SIInstrInfo *TII;
144 
145   // Try to add an edge from SU A to SU B.
146   bool tryAddEdge(SUnit *A, SUnit *B);
147 
148   // Use SGMask to determine whether we can classify MI as a member of this
149   // SchedGroup object.
150   bool canAddMI(const MachineInstr &MI) const;
151 
152 public:
153   // Collection of SUnits that are classified as members of this group.
154   SmallVector<SUnit *, 32> Collection;
155 
156   ScheduleDAGInstrs *DAG;
157 
158   // Returns true if SU can be added to this SchedGroup.
159   bool canAddSU(SUnit &SU) const;
160 
161   // Add DAG dependencies between all SUnits in this SchedGroup and this SU. If
162   // MakePred is true, SU will be a predecessor of the SUnits in this
163   // SchedGroup, otherwise SU will be a successor.
164   void link(SUnit &SU, bool MakePred = false);
165 
166   // Add DAG dependencies, track which edges are added, and return the count
167   // of missed edges.
168   int link(SUnit &SU, bool MakePred,
169            std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
170 
171   // Add DAG dependencies between all SUnits in this SchedGroup and this SU.
172   // Use the predicate to determine whether SU should be a predecessor (P =
173   // true) or a successor (P = false) of this SchedGroup.
174   void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);
175 
176   // Add DAG dependencies such that SUnits in this group shall be ordered
177   // before SUnits in OtherGroup.
178   void link(SchedGroup &OtherGroup);
179 
180   // Returns true if no more instructions may be added to this group.
181   bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }
182 
183   // Append a constraint that SUs must meet in order to fit into this
184   // SchedGroup. Since many rules involve the relationship between a SchedGroup
185   // and the SUnits in other SchedGroups, rules are checked at Pipeline Solve
186   // time (rather than SchedGroup init time).
187   void addRule(std::shared_ptr<InstructionRule> NewRule) {
188     Rules.push_back(NewRule);
189   }
190 
191   // Returns true if the SU matches all rules
192   bool allowedByRules(const SUnit *SU,
193                       SmallVectorImpl<SchedGroup> &SyncPipe) const {
194     if (Rules.empty())
195       return true;
196     for (size_t I = 0; I < Rules.size(); I++) {
197       auto TheRule = Rules[I].get();
198       if (!TheRule->apply(SU, Collection, SyncPipe)) {
199         return false;
200       }
201     }
202     return true;
203   }
204 
205   // Add SU to the SchedGroup.
206   void add(SUnit &SU) {
207     LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
208                       << format_hex((int)SGMask, 10, true) << " adding "
209                       << *SU.getInstr());
210     Collection.push_back(&SU);
211   }
212 
213   // Remove last element in the SchedGroup
214   void pop() { Collection.pop_back(); }
215 
216   // Identify and add all relevant SUs from the DAG to this SchedGroup.
217   void initSchedGroup();
218 
219   // Add instructions to the SchedGroup bottom up starting from RIter.
220   // PipelineInstrs is a set of instructions that should not be added to the
221   // SchedGroup even when the other conditions for adding it are satisfied.
222   // RIter will be added to the SchedGroup as well, and dependencies will be
223   // added so that RIter will always be scheduled at the end of the group.
224   void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
225                       SUnitsToCandidateSGsMap &SyncedInstrs);
226 
227   void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);
228 
229   int getSyncID() { return SyncID; }
230 
231   int getSGID() { return SGID; }
232 
233   SchedGroupMask getMask() { return SGMask; }
234 
235   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
236              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
237       : SGMask(SGMask), MaxSize(MaxSize), TII(TII), DAG(DAG) {
238     SGID = NumSchedGroups++;
239   }
240 
241   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
242              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
243       : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), TII(TII), DAG(DAG) {
244     SGID = NumSchedGroups++;
245   }
246 };
247 
248 // Remove all existing edges from a SCHED_BARRIER, SCHED_GROUP_BARRIER, or IGLP_OPT.
249 static void resetEdges(SUnit &SU, ScheduleDAGInstrs *DAG) {
250   assert(SU.getInstr()->getOpcode() == AMDGPU::SCHED_BARRIER ||
251          SU.getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER ||
252          SU.getInstr()->getOpcode() == AMDGPU::IGLP_OPT);
253 
254   while (!SU.Preds.empty())
255     for (auto &P : SU.Preds)
256       SU.removePred(P);
257 
258   while (!SU.Succs.empty())
259     for (auto &S : SU.Succs)
260       for (auto &SP : S.getSUnit()->Preds)
261         if (SP.getSUnit() == &SU)
262           S.getSUnit()->removePred(SP);
263 }
264 
265 typedef std::pair<SUnit *, SmallVector<int, 4>> SUToCandSGsPair;
266 typedef SmallVector<SUToCandSGsPair, 4> SUsToCandSGsVec;
267 
268 // The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
269 // in non-trivial cases. For example, if the requested pipeline is
270 // {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
271 // in the DAG, then we will have an instruction that cannot be trivially
272 // assigned to a SchedGroup. The PipelineSolver class implements two algorithms
273 // to find a good solution to the pipeline -- a greedy algorithm and an exact
274 // algorithm. The exact algorithm has an exponential time complexity and should
275 // only be used for small problems, or for medium-sized problems where an exact
276 // solution is highly desired.
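//
// Illustrative example (hypothetical): with the pipeline
// {VMEM_READ, VALU, MFMA, VMEM_READ}, a VMEM_READ SU has two candidate
// SchedGroups. The greedy algorithm assigns it to whichever candidate
// currently misses the fewest edges; the exact algorithm branches on both
// candidates (and on leaving the SU unassigned, at a MissPenalty cost) and
// keeps the lowest-cost assignment it finds.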
277 class PipelineSolver {
278   ScheduleDAGMI *DAG;
279 
280   // Instructions that can be assigned to multiple SchedGroups
281   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
282   SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
283   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
284   // The current working pipeline
285   SmallVector<SmallVector<SchedGroup, 4>, 4> CurrPipeline;
286   // The pipeline that has the best solution found so far
287   SmallVector<SmallVector<SchedGroup, 4>, 4> BestPipeline;
288 
289   // Whether or not we actually have any SyncedInstrs to try to solve.
290   bool NeedsSolver = false;
291 
292   // Compute an estimate of the size of the search tree -- the true size is
293   // the product of each conflictedInst.Matches.size() across all SyncPipelines
294   unsigned computeProblemSize();
295 
296   // The cost penalty of not assigning a SU to a SchedGroup
297   int MissPenalty = 0;
298 
299   // Costs in terms of the number of edges we are unable to add
300   int BestCost = -1;
301   int CurrCost = 0;
302 
303   // Index pointing to the conflicting instruction that is currently being
304   // fitted
305   int CurrConflInstNo = 0;
306   // Index to the pipeline that is currently being fitted
307   int CurrSyncGroupIdx = 0;
308   // The first non trivial pipeline
309   int BeginSyncGroupIdx = 0;
310 
311   // How many branches we have explored
312   uint64_t BranchesExplored = 0;
313 
314   // The direction in which we process the candidate SchedGroups per SU
315   bool IsBottomUp = true;
316 
317   // Update indices to fit next conflicting instruction
318   void advancePosition();
319   // Recede indices to attempt to find better fit for previous conflicting
320   // instruction
321   void retreatPosition();
322 
323   // The exponential time algorithm which finds the provably best fit
324   bool solveExact();
325   // The polynomial time algorithm which attempts to find a good fit
326   bool solveGreedy();
327   // Find the best SchedGroup for the current SU using the heuristic given all
328   // current information. One step in the greedy algorithm. Templated against
329   // the SchedGroup iterator (either reverse or forward).
330   template <typename T>
331   void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,
332                   T E);
333   // Whether or not the current solution is optimal
334   bool checkOptimal();
335   // Populate the ready list, prioritizing fewest missed edges first.
336   // Templated against the SchedGroup iterator (either reverse or forward).
337   template <typename T>
338   void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,
339                          T E);
340   // Add edges corresponding to the SchedGroups as assigned by solver
341   void makePipeline();
342   // Link the SchedGroups in the best found pipeline.
343   // Templated against the SchedGroup iterator (either reverse or forward).
344   template <typename T> void linkSchedGroups(T I, T E);
345   // Add the edges from the SU to the other SchedGroups in pipeline, and
346   // return the number of edges missed.
347   int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
348                std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
349   /// Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
350   /// returns the cost (in terms of missed pipeline edges), and tracks the edges
351   /// added in \p AddedEdges
352   template <typename T>
353   int linkSUnit(SUnit *SU, int SGID,
354                 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
355   /// Remove the edges passed via \p AddedEdges
356   void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
357   // Convert the passed in maps to arrays for bidirectional iterators
358   void convertSyncMapsToArrays();
359 
360   void reset();
361 
362 public:
363   // Invoke the solver to map instructions to instruction groups. A heuristic and
364   // the command-line options determine whether to use the exact or greedy algorithm.
365   void solve();
366 
367   PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
368                  DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
369                  ScheduleDAGMI *DAG, bool IsBottomUp = true)
370       : DAG(DAG), SyncedInstrs(SyncedInstrs),
371         SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {
372 
373     for (auto &PipelineInstrs : SyncedInstrs) {
374       if (PipelineInstrs.second.size() > 0) {
375         NeedsSolver = true;
376         break;
377       }
378     }
379 
380     if (!NeedsSolver)
381       return;
382 
383     convertSyncMapsToArrays();
384 
385     CurrPipeline = BestPipeline;
386 
387     while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
388            PipelineInstrs[BeginSyncGroupIdx].size() == 0)
389       ++BeginSyncGroupIdx;
390 
391     if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
392       return;
393   }
394 };
395 
396 void PipelineSolver::reset() {
397 
398   for (auto &SyncPipeline : CurrPipeline) {
399     for (auto &SG : SyncPipeline) {
400       SmallVector<SUnit *, 32> TempCollection = SG.Collection;
401       SG.Collection.clear();
402       auto SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
403         return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
404       });
405       if (SchedBarr != TempCollection.end())
406         SG.Collection.push_back(*SchedBarr);
407     }
408   }
409 
410   CurrSyncGroupIdx = BeginSyncGroupIdx;
411   CurrConflInstNo = 0;
412   CurrCost = 0;
413 }
414 
415 void PipelineSolver::convertSyncMapsToArrays() {
416   for (auto &SyncPipe : SyncedSchedGroups) {
417     BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);
418   }
419 
420   int PipelineIDx = SyncedInstrs.size() - 1;
421   PipelineInstrs.resize(SyncedInstrs.size());
422   for (auto &SyncInstrMap : SyncedInstrs) {
423     for (auto &SUsToCandSGs : SyncInstrMap.second) {
424       if (PipelineInstrs[PipelineIDx].size() == 0) {
425         PipelineInstrs[PipelineIDx].push_back(
426             std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
427         continue;
428       }
429       auto SortPosition = PipelineInstrs[PipelineIDx].begin();
430       // Insert them in sorted order -- this allows for good parsing order in
431       // the greedy algorithm
432       while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
433              SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
434         ++SortPosition;
435       PipelineInstrs[PipelineIDx].insert(
436           SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
437     }
438     --PipelineIDx;
439   }
440 }
441 
442 template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {
443   for (; I != E; ++I) {
444     auto &GroupA = *I;
445     for (auto J = std::next(I); J != E; ++J) {
446       auto &GroupB = *J;
447       GroupA.link(GroupB);
448     }
449   }
450 }
451 
452 void PipelineSolver::makePipeline() {
453   // Preserve the order of barrier for subsequent SchedGroupBarrier mutations
454   for (auto &SyncPipeline : BestPipeline) {
455     LLVM_DEBUG(dbgs() << "Printing SchedGroups\n");
456     for (auto &SG : SyncPipeline) {
457       LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID()
458                         << " has: \n");
459       SUnit *SGBarr = nullptr;
460       for (auto &SU : SG.Collection) {
461         if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
462           SGBarr = SU;
463         LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n");
464       }
465       // Command-line-requested IGroupLP doesn't have an SGBarr
466       if (!SGBarr)
467         continue;
468       resetEdges(*SGBarr, DAG);
469       SG.link(*SGBarr, false);
470     }
471   }
472 
473   for (auto &SyncPipeline : BestPipeline) {
474     IsBottomUp ? linkSchedGroups(SyncPipeline.rbegin(), SyncPipeline.rend())
475                : linkSchedGroups(SyncPipeline.begin(), SyncPipeline.end());
476   }
477 }
478 
479 template <typename T>
480 int PipelineSolver::linkSUnit(
481     SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
482     T I, T E) {
483   bool MakePred = false;
484   int AddedCost = 0;
485   for (; I < E; ++I) {
486     if (I->getSGID() == SGID) {
487       MakePred = true;
488       continue;
489     }
490     auto Group = *I;
491     AddedCost += Group.link(*SU, MakePred, AddedEdges);
492     assert(AddedCost >= 0);
493   }
494   return AddedCost;
495 }
496 
497 int PipelineSolver::addEdges(
498     SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
499     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
500 
501   // For IsBottomUp, the first SchedGroup in SyncPipeline contains the
502   // instructions that are the ultimate successors in the resultant mutation.
503   // Therefore, in such a configuration, the SchedGroups occurring before the
504   // candidate SGID are successors of the candidate SchedGroup, thus the current
505   // SU should be linked as a predecessor to SUs in those SchedGroups. The
506   // opposite is true if !IsBottomUp. IsBottomUp occurs in the case of multiple
507   // SCHED_GROUP_BARRIERS, or if a user specifies IGLP_OPT SchedGroups using
508   // IsBottomUp (in reverse).
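  // Illustrative sketch (group names are hypothetical): with IsBottomUp and a
  // SyncPipeline stored as {G0, G1, G2}, where G0 holds the ultimate
  // successors, assigning SU to G1 links SU as a successor of the SUs already
  // placed in G2 and as a predecessor of the SUs already placed in G0.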
509   return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, SyncPipeline.rbegin(),
510                                 SyncPipeline.rend())
511                     : linkSUnit(SU, SGID, AddedEdges, SyncPipeline.begin(),
512                                 SyncPipeline.end());
513 }
514 
515 void PipelineSolver::removeEdges(
516     const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
517   // Only remove the edges that we have added when testing
518   // the fit.
519   for (auto &PredSuccPair : EdgesToRemove) {
520     SUnit *Pred = PredSuccPair.first;
521     SUnit *Succ = PredSuccPair.second;
522 
523     auto Match = llvm::find_if(
524         Succ->Preds, [&Pred](SDep &P) { return P.getSUnit() == Pred; });
525     if (Match != Succ->Preds.end()) {
526       assert(Match->isArtificial());
527       Succ->removePred(*Match);
528     }
529   }
530 }
531 
532 void PipelineSolver::advancePosition() {
533   ++CurrConflInstNo;
534 
535   if (static_cast<size_t>(CurrConflInstNo) >=
536       PipelineInstrs[CurrSyncGroupIdx].size()) {
537     CurrConflInstNo = 0;
538     ++CurrSyncGroupIdx;
539     // Advance to next non-trivial pipeline
540     while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
541            PipelineInstrs[CurrSyncGroupIdx].size() == 0)
542       ++CurrSyncGroupIdx;
543   }
544 }
545 
546 void PipelineSolver::retreatPosition() {
547   assert(CurrConflInstNo >= 0);
548   assert(CurrSyncGroupIdx >= 0);
549 
550   if (CurrConflInstNo > 0) {
551     --CurrConflInstNo;
552     return;
553   }
554 
555   if (CurrConflInstNo == 0) {
556     // If we return to the starting position, we have explored
557     // the entire tree
558     if (CurrSyncGroupIdx == BeginSyncGroupIdx)
559       return;
560 
561     --CurrSyncGroupIdx;
562     // Go to previous non-trivial pipeline
563     while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)
564       --CurrSyncGroupIdx;
565 
566     CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
567   }
568 }
569 
570 bool PipelineSolver::checkOptimal() {
571   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
572     if (BestCost == -1 || CurrCost < BestCost) {
573       BestPipeline = CurrPipeline;
574       BestCost = CurrCost;
575       LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
576     }
577     assert(BestCost >= 0);
578   }
579 
580   bool DoneExploring = false;
581   if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
582     DoneExploring = true;
583 
584   return (DoneExploring || BestCost == 0);
585 }
586 
587 template <typename T>
588 void PipelineSolver::populateReadyList(
589     SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {
590   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
591   auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
592   assert(CurrSU.second.size() >= 1);
593 
594   for (; I != E; ++I) {
595     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
596     int CandSGID = *I;
597     SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
598       return SG.getSGID() == CandSGID;
599     });
600     assert(Match);
601 
602     if (UseCostHeur) {
603       if (Match->isFull()) {
604         ReadyList.push_back(std::pair(*I, MissPenalty));
605         continue;
606       }
607 
608       int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
609       ReadyList.push_back(std::pair(*I, TempCost));
610       removeEdges(AddedEdges);
611     } else
612       ReadyList.push_back(std::pair(*I, -1));
613   }
614 
615   if (UseCostHeur) {
616     std::sort(ReadyList.begin(), ReadyList.end(),
617               [](std::pair<int, int> A, std::pair<int, int> B) {
618                 return A.second < B.second;
619               });
620   }
621 
622   assert(ReadyList.size() == CurrSU.second.size());
623 }
624 
625 bool PipelineSolver::solveExact() {
626   if (checkOptimal())
627     return true;
628 
629   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
630     return false;
631 
632   assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
633   assert(static_cast<size_t>(CurrConflInstNo) <
634          PipelineInstrs[CurrSyncGroupIdx].size());
635   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
636   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
637                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
638 
639   // SchedGroup -> Cost pairs
640   SmallVector<std::pair<int, int>, 4> ReadyList;
641   // Prioritize the candidate sched groups in terms of lowest cost first
642   IsBottomUp ? populateReadyList(ReadyList, CurrSU.second.rbegin(),
643                                  CurrSU.second.rend())
644              : populateReadyList(ReadyList, CurrSU.second.begin(),
645                                  CurrSU.second.end());
646 
647   auto I = ReadyList.begin();
648   auto E = ReadyList.end();
649   for (; I != E; ++I) {
650     // If we are trying SGs in least cost order, and the current SG is cost
651     // infeasible, then all subsequent SGs will also be cost infeasible, so we
652     // can prune.
653     if (BestCost != -1 && (CurrCost + I->second > BestCost))
654       return false;
655 
656     int CandSGID = I->first;
657     int AddedCost = 0;
658     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
659     auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
660     SchedGroup *Match = nullptr;
661     for (auto &SG : SyncPipeline) {
662       if (SG.getSGID() == CandSGID)
663         Match = &SG;
664     }
665 
666     if (Match->isFull())
667       continue;
668 
669     if (!Match->allowedByRules(CurrSU.first, SyncPipeline))
670       continue;
671 
672     LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
673                       << (int)Match->getMask() << " and ID " << CandSGID
674                       << "\n");
675     Match->add(*CurrSU.first);
676     AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
677     LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
678     CurrCost += AddedCost;
679     advancePosition();
680     ++BranchesExplored;
681     bool FinishedExploring = false;
682     // Only recurse if the cost after adding edges can still beat the best
683     // known solution; otherwise backtrack.
684     if (CurrCost < BestCost || BestCost == -1) {
685       if (solveExact()) {
686         FinishedExploring = BestCost != 0;
687         if (!FinishedExploring)
688           return true;
689       }
690     }
691 
692     retreatPosition();
693     CurrCost -= AddedCost;
694     removeEdges(AddedEdges);
695     Match->pop();
696     CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
697     if (FinishedExploring)
698       return true;
699   }
700 
701   // Try the pipeline where the current instruction is omitted
702   // Potentially if we omit a problematic instruction from the pipeline,
703   // all the other instructions can nicely fit.
704   CurrCost += MissPenalty;
705   advancePosition();
706 
707   LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");
708 
709   bool FinishedExploring = false;
710   if (CurrCost < BestCost || BestCost == -1) {
711     if (solveExact()) {
712       FinishedExploring = BestCost != 0;
713       if (!FinishedExploring)
714         return true;
715     }
716   }
717 
718   retreatPosition();
719   CurrCost -= MissPenalty;
720   return FinishedExploring;
721 }
722 
723 template <typename T>
724 void PipelineSolver::greedyFind(
725     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
726   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
727   int BestNodeCost = -1;
728   int TempCost;
729   SchedGroup *BestGroup = nullptr;
730   int BestGroupID = -1;
731   auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
732   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
733                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
734 
735   // Since we have added the potential SchedGroups from bottom up, but
736   // traversed the DAG from top down, parse over the groups from last to
737   // first. If we fail to do this for the greedy algorithm, the solution will
738   // likely not be good in more complex cases.
739   for (; I != E; ++I) {
740     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
741     int CandSGID = *I;
742     SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
743       return SG.getSGID() == CandSGID;
744     });
745     assert(Match);
746 
747     LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
748                       << (int)Match->getMask() << "\n");
749 
750     if (Match->isFull()) {
751       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
752       continue;
753     }
754     if (!Match->allowedByRules(CurrSU.first, SyncPipeline)) {
755       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
756       continue;
757     }
758     TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
759     LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
760     if (TempCost < BestNodeCost || BestNodeCost == -1) {
761       BestGroup = Match;
762       BestNodeCost = TempCost;
763       BestGroupID = CandSGID;
764     }
765     removeEdges(AddedEdges);
766     if (BestNodeCost == 0)
767       break;
768   }
769 
770   if (BestGroupID != -1) {
771     BestGroup->add(*CurrSU.first);
772     addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);
773     LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask "
774                       << (int)BestGroup->getMask() << "\n");
775     BestCost += TempCost;
776   } else
777     BestCost += MissPenalty;
778 
779   CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
780 }
781 
782 bool PipelineSolver::solveGreedy() {
783   BestCost = 0;
784   std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
785 
786   while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
787     SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
788     IsBottomUp
789         ? greedyFind(AddedEdges, CurrSU.second.rbegin(), CurrSU.second.rend())
790         : greedyFind(AddedEdges, CurrSU.second.begin(), CurrSU.second.end());
791     advancePosition();
792   }
793   BestPipeline = CurrPipeline;
794   removeEdges(AddedEdges);
795   return false;
796 }
797 
798 unsigned PipelineSolver::computeProblemSize() {
799   unsigned ProblemSize = 0;
800   for (auto &PipeConflicts : PipelineInstrs) {
801     ProblemSize += PipeConflicts.size();
802   }
803 
804   return ProblemSize;
805 }
806 
807 void PipelineSolver::solve() {
808   if (!NeedsSolver)
809     return;
810 
811   unsigned ProblemSize = computeProblemSize();
812   assert(ProblemSize > 0);
813 
814   bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
815   MissPenalty = (ProblemSize / 2) + 1;
816 
817   LLVM_DEBUG(DAG->dump());
818   if (EnableExactSolver || BelowCutoff) {
819     LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
820     solveGreedy();
821     reset();
822     LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
823     if (BestCost > 0) {
824       LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
825       solveExact();
826       LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
827     }
828   } else { // Use the Greedy Algorithm by default
829     LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
830     solveGreedy();
831   }
832 
833   makePipeline();
834   LLVM_DEBUG(dbgs() << "After applying mutation\n");
835   LLVM_DEBUG(DAG->dump());
836 }
837 
838 enum IGLPStrategyID : int {
839   MFMASmallGemmOptID = 0,
840   MFMASmallGemmSingleWaveOptID = 1,
841 };
842 
843 // Implement a IGLP scheduling strategy.
844 class IGLPStrategy {
845 protected:
846   ScheduleDAGInstrs *DAG;
847 
848   const SIInstrInfo *TII;
849 
850 public:
851   /// Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
852   virtual void applyIGLPStrategy(
853       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
854       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
855       bool IsReentry) = 0;
856 
857   // Returns true if this strategy should be applied to a ScheduleDAG.
858   virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) = 0;
859 
860   bool IsBottomUp = true;
861 
862   IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
863       : DAG(DAG), TII(TII) {}
864 
865   virtual ~IGLPStrategy() = default;
866 };
867 
868 class MFMASmallGemmOpt final : public IGLPStrategy {
869 private:
870 public:
871   void applyIGLPStrategy(
872       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
873       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
874       bool IsReentry) override;
875 
876   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; }
877 
878   MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
879       : IGLPStrategy(DAG, TII) {
880     IsBottomUp = true;
881   }
882 };
883 
884 void MFMASmallGemmOpt::applyIGLPStrategy(
885     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
886     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
887     bool IsReentry) {
888   // Count the number of MFMA instructions.
889   unsigned MFMACount = 0;
890   for (const MachineInstr &I : *DAG)
891     if (TII->isMFMAorWMMA(I))
892       ++MFMACount;
893 
894   const unsigned PipelineSyncID = 0;
895   SchedGroup *SG = nullptr;
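  // Build an alternating {DS: up to 2, MFMA: 1} pattern, repeated three times
  // per MFMA counted above, all in the same sync group so that LDS traffic is
  // interleaved with the MFMAs.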
896   for (unsigned I = 0; I < MFMACount * 3; ++I) {
897     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
898         SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
899     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
900 
901     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
902         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
903     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
904   }
905 }
906 
907 class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
908 private:
909   // Whether the DS_READ is a predecessor of the first four MFMAs in the region
910   class EnablesInitialMFMA final : public InstructionRule {
911   public:
912     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
913                SmallVectorImpl<SchedGroup> &SyncPipe) override {
914       if (!SyncPipe.size())
915         return false;
916       int MFMAsFound = 0;
917       if (!Cache->size()) {
918         for (auto &Elt : SyncPipe[0].DAG->SUnits) {
919           if (TII->isMFMAorWMMA(*Elt.getInstr())) {
920             ++MFMAsFound;
921             if (MFMAsFound > 4)
922               break;
923             Cache->push_back(&Elt);
924           }
925         }
926       }
927 
928       assert(Cache->size());
929       auto DAG = SyncPipe[0].DAG;
930       for (auto &Elt : *Cache) {
931         if (DAG->IsReachable(Elt, const_cast<SUnit *>(SU)))
932           return true;
933       }
934       return false;
935     }
936 
937     EnablesInitialMFMA(const SIInstrInfo *TII, unsigned SGID,
938                        bool NeedsCache = false)
939         : InstructionRule(TII, SGID, NeedsCache) {}
940   };
941 
942   // Whether the MI is a V_PERM and is a predecessor of a common DS_WRITE
943   class IsPermForDSW final : public InstructionRule {
944   public:
945     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
946                SmallVectorImpl<SchedGroup> &SyncPipe) override {
947       auto MI = SU->getInstr();
948       if (MI->getOpcode() != AMDGPU::V_PERM_B32_e64)
949         return false;
950 
951       bool FitsInGroup = false;
952       // Does the VALU have a DS_WRITE successor
953       if (!Collection.size()) {
954         for (auto &Succ : SU->Succs) {
955           SUnit *SuccUnit = Succ.getSUnit();
956           if (TII->isDS(*SuccUnit->getInstr()) &&
957               SuccUnit->getInstr()->mayStore()) {
958             Cache->push_back(SuccUnit);
959             FitsInGroup = true;
960           }
961         }
962         return FitsInGroup;
963       }
964 
965       assert(Cache->size());
966 
967       // Does the VALU have a DS_WRITE successor that is the same as other
968       // VALU already in the group. The V_PERMs will all share 1 DS_W succ
969       return llvm::any_of(*Cache, [&SU](SUnit *Elt) {
970         return llvm::any_of(SU->Succs, [&Elt](const SDep &ThisSucc) {
971           return ThisSucc.getSUnit() == Elt;
972         });
973       });
974     }
975 
976     IsPermForDSW(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
977         : InstructionRule(TII, SGID, NeedsCache) {}
978   };
979 
980   // Whether the SU is a successor of any element in previous SchedGroup
981   class IsSuccOfPrevGroup final : public InstructionRule {
982   public:
983     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
984                SmallVectorImpl<SchedGroup> &SyncPipe) override {
985       SchedGroup *OtherGroup = nullptr;
986       for (auto &PipeSG : SyncPipe) {
987         if ((unsigned)PipeSG.getSGID() == SGID - 1) {
988           OtherGroup = &PipeSG;
989         }
990       }
991 
992       if (!OtherGroup)
993         return false;
994       if (!OtherGroup->Collection.size())
995         return true;
996 
997       // Does the previous VALU have this DS_Write as a successor
998       return (std::any_of(OtherGroup->Collection.begin(),
999                           OtherGroup->Collection.end(), [&SU](SUnit *Elt) {
1000                             return std::any_of(Elt->Succs.begin(),
1001                                                Elt->Succs.end(),
1002                                                [&SU](SDep &Succ) {
1003                                                  return Succ.getSUnit() == SU;
1004                                                });
1005                           }));
1006     }
1007     IsSuccOfPrevGroup(const SIInstrInfo *TII, unsigned SGID,
1008                       bool NeedsCache = false)
1009         : InstructionRule(TII, SGID, NeedsCache) {}
1010   };
1011 
1012   // Whether the combined load width of group is 128 bits
1013   class VMEMSize final : public InstructionRule {
1014   public:
1015     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1016                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1017       auto MI = SU->getInstr();
1018       if (MI->getOpcode() == TargetOpcode::BUNDLE)
1019         return false;
1020       if (!Collection.size())
1021         return true;
1022 
1023       int NumBits = 0;
1024 
1025       auto TRI = TII->getRegisterInfo();
1026       auto &MRI = MI->getParent()->getParent()->getRegInfo();
1027       for (auto &Elt : Collection) {
1028         auto Op = Elt->getInstr()->getOperand(0);
1029         auto Size =
1030             TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(MRI, Op));
1031         NumBits += Size;
1032       }
1033 
1034       if (NumBits < 128) {
1035         assert(TII->isVMEM(*MI) && MI->mayLoad());
1036         if (NumBits + TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(
1037                           MRI, MI->getOperand(0))) <=
1038             128)
1039           return true;
1040       }
1041 
1042       return false;
1043     }
1044 
1045     VMEMSize(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1046         : InstructionRule(TII, SGID, NeedsCache) {}
1047   };
1048 
1049   /// Whether the SU shares a V_PERM predecessor with any SU in the SchedGroup
1050   /// that is \p Distance steps away
1051   class SharesPredWithPrevNthGroup final : public InstructionRule {
1052   private:
1053     unsigned Distance = 1;
1054 
1055   public:
1056     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1057                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1058       SchedGroup *OtherGroup = nullptr;
1059       if (!SyncPipe.size())
1060         return false;
1061 
1062       if (!Cache->size()) {
1063 
1064         for (auto &PipeSG : SyncPipe) {
1065           if ((unsigned)PipeSG.getSGID() == SGID - Distance) {
1066             OtherGroup = &PipeSG;
1067           }
1068         }
1069 
1070         if (!OtherGroup)
1071           return false;
1072         if (!OtherGroup->Collection.size())
1073           return true;
1074 
1075         for (auto &OtherEle : OtherGroup->Collection) {
1076           for (auto &Pred : OtherEle->Preds) {
1077             if (Pred.getSUnit()->getInstr()->getOpcode() ==
1078                 AMDGPU::V_PERM_B32_e64)
1079               Cache->push_back(Pred.getSUnit());
1080           }
1081         }
1082 
1083         // If the other group has no PERM preds, then this group won't share any
1084         if (!Cache->size())
1085           return false;
1086       }
1087 
1088       auto DAG = SyncPipe[0].DAG;
1089       // Does the previous DS_WRITE share a V_PERM predecessor with this
1090       // VMEM_READ
1091       return llvm::any_of(*Cache, [&SU, &DAG](SUnit *Elt) {
1092         return DAG->IsReachable(const_cast<SUnit *>(SU), Elt);
1093       });
1094     }
1095     SharesPredWithPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1096                                unsigned SGID, bool NeedsCache = false)
1097         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1098   };
1099 
1100 public:
1101   void applyIGLPStrategy(
1102       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1103       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1104       bool IsReentry) override;
1105 
1106   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; }
1107 
1108   MFMASmallGemmSingleWaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1109       : IGLPStrategy(DAG, TII) {
1110     IsBottomUp = false;
1111   }
1112 };
1113 
1114 static unsigned DSWCount = 0;
1115 static unsigned DSWWithPermCount = 0;
1116 static unsigned DSWWithSharedVMEMCount = 0;
1117 
1118 void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
1119     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1120     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1121     bool IsReentry) {
1122   unsigned MFMACount = 0;
1123   unsigned DSRCount = 0;
1124 
1125   assert((IsReentry || (DSWCount == 0 && DSWWithPermCount == 0 &&
1126                         DSWWithSharedVMEMCount == 0)) &&
1127          "DSWCounters should be zero in pre-RA scheduling!");
1128   SmallVector<SUnit *, 6> DSWithPerms;
1129   for (auto &SU : DAG->SUnits) {
1130     auto I = SU.getInstr();
1131     if (TII->isMFMAorWMMA(*I))
1132       ++MFMACount;
1133     else if (TII->isDS(*I)) {
1134       if (I->mayLoad())
1135         ++DSRCount;
1136       else if (I->mayStore() && !IsReentry) {
1137         ++DSWCount;
1138         for (auto Pred : SU.Preds) {
1139           if (Pred.getSUnit()->getInstr()->getOpcode() ==
1140               AMDGPU::V_PERM_B32_e64) {
1141             DSWithPerms.push_back(&SU);
1142             break;
1143           }
1144         }
1145       }
1146     }
1147   }
1148 
1149   if (!IsReentry) {
1150     DSWWithPermCount = DSWithPerms.size();
1151     auto I = DSWithPerms.begin();
1152     auto E = DSWithPerms.end();
1153 
1154     // Get the count of DS_WRITES with V_PERM predecessors which
1155     // have loop carried dependencies (WAR) on the same VMEM_READs.
1156     // We consider partial overlap as a miss -- in other words,
1157     // for a given DS_W, we only consider another DS_W as matching
1158     // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
1159     // for every V_PERM pred of this DS_W.
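    // Illustrative example (hypothetical): if DS_W0 and DS_W1 each have
    // V_PERM preds fed by the same set of VMEM_READs, the pair is counted
    // once below and DSWWithSharedVMEMCount is incremented by 2.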
1160     DenseMap<MachineInstr *, SUnit *> VMEMLookup;
1161     SmallVector<SUnit *, 6> Counted;
1162     for (; I != E; I++) {
1163       SUnit *Cand = nullptr;
1164       bool MissedAny = false;
1165       for (auto &Pred : (*I)->Preds) {
1166         if (Pred.getSUnit()->getInstr()->getOpcode() != AMDGPU::V_PERM_B32_e64)
1167           continue;
1168 
1169         if (Cand && llvm::is_contained(Counted, Cand))
1170           break;
1171 
1172         for (auto &Succ : Pred.getSUnit()->Succs) {
1173           auto MI = Succ.getSUnit()->getInstr();
1174           if (!TII->isVMEM(*MI) || !MI->mayLoad())
1175             continue;
1176 
1177           if (MissedAny || !VMEMLookup.size()) {
1178             MissedAny = true;
1179             VMEMLookup[MI] = *I;
1180             continue;
1181           }
1182 
1183           if (!VMEMLookup.contains(MI)) {
1184             MissedAny = true;
1185             VMEMLookup[MI] = *I;
1186             continue;
1187           }
1188 
1189           Cand = VMEMLookup[MI];
1190           if (llvm::is_contained(Counted, Cand)) {
1191             MissedAny = true;
1192             break;
1193           }
1194         }
1195       }
1196       if (!MissedAny && Cand) {
1197         DSWWithSharedVMEMCount += 2;
1198         Counted.push_back(Cand);
1199         Counted.push_back(*I);
1200       }
1201     }
1202   }
1203 
1204   assert(DSWWithSharedVMEMCount <= DSWWithPermCount);
1205   SchedGroup *SG;
1206   unsigned PipelineSyncID = 0;
1207   // For kernels with V_PERM, there are enough VALU to mix in between MFMAs
1208   if (DSWWithPermCount) {
1209     for (unsigned I = 0; I < MFMACount; I++) {
1210       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1211           SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1212       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1213 
1214       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1215           SchedGroupMask::VALU, 2, PipelineSyncID, DAG, TII);
1216       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1217     }
1218   }
1219 
1220   PipelineSyncID = 1;
1221   // Phase 1: Break up DS_READ and MFMA clusters.
1222   // First DS_READ to make ready initial MFMA, then interleave MFMA with DS_READ
1223   // prefetch
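  // Resulting group sequence (as built below): DS_READ:4, MFMA:1, then
  // (DS_READ:1, MFMA:1) repeated DSRCount - 4 times.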
1224 
1225   // Make ready initial MFMA
1226   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1227       SchedGroupMask::DS_READ, 4, PipelineSyncID, DAG, TII);
1228   SG->addRule(std::make_shared<EnablesInitialMFMA>(TII, SG->getSGID(), true));
1229   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1230 
1231   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1232       SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1233   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1234 
1235   // Interleave MFMA with DS_READ prefetch
1236   for (unsigned I = 0; I < DSRCount - 4; ++I) {
1237     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1238         SchedGroupMask::DS_READ, 1, PipelineSyncID, DAG, TII);
1239     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1240 
1241     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1242         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1243     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1244   }
1245 
1246   // Phase 2a: Loop carried dependency with V_PERM
1247   // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
1248   // depend on. Interleave MFMA to keep XDL unit busy throughout.
1249   for (unsigned I = 0; I < DSWWithPermCount - DSWWithSharedVMEMCount; ++I) {
1250     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1251         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
1252     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
1253     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1254 
1255     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1256         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1257     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID(), false));
1258     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1259 
1260     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1261         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1262     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1263         1, TII, SG->getSGID(), true));
1264     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1265     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1266 
1267     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1268         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1269     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1270 
1271     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1272         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1273     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1274         3, TII, SG->getSGID(), true));
1275     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1276     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1277 
1278     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1279         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1280     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1281   }
1282 
1283   // Phase 2b: Loop carried dependency without V_PERM
1284   // Schedule DS_WRITE as closely as possible to the VMEM_READ they depend on.
1285   // Interleave MFMA to keep XDL unit busy throughout.
1286   for (unsigned I = 0; I < DSWCount - DSWWithPermCount; I++) {
1287     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1288         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1289     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1290 
1291     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1292         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1293     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1294     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1295 
1296     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1297         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1298     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1299   }
1300 
1301   // Phase 2c: Loop carried dependency with V_PERM, VMEM_READs are
1302   // ultimately used by two DS_WRITE
1303   // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
1304   // depend on. Interleave MFMA to keep XDL unit busy throughout.
1305 
1306   for (unsigned I = 0; I < DSWWithSharedVMEMCount; ++I) {
1307     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1308         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
1309     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
1310     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1311 
1312     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1313         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1314     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID(), false));
1315     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1316 
1317     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1318         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1319     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1320 
1321     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1322         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
1323     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
1324     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1325 
1326     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1327         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1328     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID(), false));
1329     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1330 
1331     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1332         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1333     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1334 
1335     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1336         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1337     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1338         2, TII, SG->getSGID(), true));
1339     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1340     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1341 
1342     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1343         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1344     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1345 
1346     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1347         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1348     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1349         4, TII, SG->getSGID(), true));
1350     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1351     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1352 
1353     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1354         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1355     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1356   }
1357 }
1358 
1359 static std::unique_ptr<IGLPStrategy>
1360 createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
1361                    const SIInstrInfo *TII) {
1362   switch (ID) {
1363   case MFMASmallGemmOptID:
1364     return std::make_unique<MFMASmallGemmOpt>(DAG, TII);
1365   case MFMASmallGemmSingleWaveOptID:
1366     return std::make_unique<MFMASmallGemmSingleWaveOpt>(DAG, TII);
1367   }
1368 
1369   llvm_unreachable("Unknown IGLPStrategyID");
1370 }
1371 
1372 class IGroupLPDAGMutation : public ScheduleDAGMutation {
1373 private:
1374   const SIInstrInfo *TII;
1375 
1376   ScheduleDAGMI *DAG;
1377 
1378   // Organize lists of SchedGroups by their SyncID. SchedGroups /
1379   // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
1380   // between them.
1381   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
1382 
1383   // Used to track instructions that can be mapped to multiple sched groups
1384   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
1385 
1386   // Add DAG edges that enforce SCHED_BARRIER ordering.
1387   void addSchedBarrierEdges(SUnit &SU);
1388 
1389   // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
1390   // not be reordered across the SCHED_BARRIER. This is used for the base
1391   // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
1392   // SCHED_BARRIER will always block all instructions that can be classified
1393   // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
1394   // and may only synchronize with some SchedGroups. Returns the inverse of
1395   // Mask. SCHED_BARRIER's mask describes which instruction types should be
1396   // allowed to be scheduled across it. Invert the mask to get the
1397   // SchedGroupMask of instructions that should be barred.
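  // Illustrative example (simplified; the real inversion also accounts for the
  // ALU/VMEM/DS super- and sub-type bits): a SCHED_BARRIER whose mask allows
  // only MFMA to cross inverts to roughly ALL & ~MFMA, i.e. every other
  // instruction class is barred from moving across the barrier.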
1398   SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;
1399 
1400   // Create SchedGroups for a SCHED_GROUP_BARRIER.
1401   void initSchedGroupBarrierPipelineStage(
1402       std::vector<SUnit>::reverse_iterator RIter);
1403 
1404   void initIGLPOpt(SUnit &SU);
1405 
1406 public:
1407   void apply(ScheduleDAGInstrs *DAGInstrs) override;
1408 
1409   // The order in which the PipelineSolver should process the candidate
1410   // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last
1411   // created SchedGroup first, and will consider that as the ultimate
1412   // predecessor group when linking. TOP_DOWN instead links and processes the
1413   // first created SchedGroup first.
1414   bool IsBottomUp = true;
1415 
1416   // Whether or not this is a reentry into the IGroupLPDAGMutation.
1417   bool IsReentry = false;
1418 
1419   IGroupLPDAGMutation() = default;
1420   IGroupLPDAGMutation(bool IsReentry) : IsReentry(IsReentry) {}
1421 };
1422 
1423 unsigned SchedGroup::NumSchedGroups = 0;
1424 
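// Try to add an artificial edge from A to B (i.e. B becomes a successor of A)
// so that A is scheduled before B. Returns true if the edge was added, and
// false if A == B or if adding the edge would create a cycle in the DAG.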
1425 bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
1426   if (A != B && DAG->canAddEdge(B, A)) {
1427     DAG->addEdge(B, SDep(A, SDep::Artificial));
1428     return true;
1429   }
1430   return false;
1431 }
1432 
1433 bool SchedGroup::canAddMI(const MachineInstr &MI) const {
1434   bool Result = false;
1435   if (MI.isMetaInstruction())
1436     Result = false;
1437 
1438   else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
1439            (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI) ||
1440             TII->isTRANS(MI)))
1441     Result = true;
1442 
1443   else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
1444            TII->isVALU(MI) && !TII->isMFMAorWMMA(MI) && !TII->isTRANS(MI))
1445     Result = true;
1446 
1447   else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
1448            TII->isSALU(MI))
1449     Result = true;
1450 
1451   else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
1452            TII->isMFMAorWMMA(MI))
1453     Result = true;
1454 
1455   else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
1456            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
1457     Result = true;
1458 
1459   else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
1460            MI.mayLoad() &&
1461            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
1462     Result = true;
1463 
1464   else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
1465            MI.mayStore() &&
1466            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
1467     Result = true;
1468 
1469   else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
1470            TII->isDS(MI))
1471     Result = true;
1472 
1473   else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
1474            MI.mayLoad() && TII->isDS(MI))
1475     Result = true;
1476 
1477   else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
1478            MI.mayStore() && TII->isDS(MI))
1479     Result = true;
1480 
1481   else if (((SGMask & SchedGroupMask::TRANS) != SchedGroupMask::NONE) &&
1482            TII->isTRANS(MI))
1483     Result = true;
1484 
1485   LLVM_DEBUG(
1486       dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
1487              << (Result ? " could classify " : " unable to classify ") << MI);
1488 
1489   return Result;
1490 }
1491 
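// Add artificial edges between SU and every instruction already in this
// group, skipping SCHED_GROUP_BARRIERs and pairs whose ordering is already
// enforced by existing dependencies. With MakePred set, SU becomes a
// predecessor of the group members instead of their successor. Successfully
// added edges are recorded in AddedEdges for the caller; the return value is
// the number of edges that could not be added.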
1492 int SchedGroup::link(SUnit &SU, bool MakePred,
1493                      std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
1494   int MissedEdges = 0;
1495   for (auto *A : Collection) {
1496     SUnit *B = &SU;
1497     if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
1498       continue;
1499     if (MakePred)
1500       std::swap(A, B);
1501 
1502     if (DAG->IsReachable(B, A))
1503       continue;
1504 
1505     // tryAddEdge returns false if there is a dependency that makes adding
1506     // the A->B edge impossible; otherwise it returns true.
1507     bool Added = tryAddEdge(A, B);
1508     if (Added)
1509       AddedEdges.push_back(std::pair(A, B));
1510     else
1511       ++MissedEdges;
1512   }
1513 
1514   return MissedEdges;
1515 }
1516 
1517 void SchedGroup::link(SUnit &SU, bool MakePred) {
1518   for (auto *A : Collection) {
1519     SUnit *B = &SU;
1520     if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
1521       continue;
1522     if (MakePred)
1523       std::swap(A, B);
1524 
1525     tryAddEdge(A, B);
1526   }
1527 }
1528 
1529 void SchedGroup::link(SUnit &SU,
1530                       function_ref<bool(const SUnit *A, const SUnit *B)> P) {
1531   for (auto *A : Collection) {
1532     SUnit *B = &SU;
1533     if (P(A, B))
1534       std::swap(A, B);
1535 
1536     tryAddEdge(A, B);
1537   }
1538 }
1539 
1540 void SchedGroup::link(SchedGroup &OtherGroup) {
1541   for (auto *B : OtherGroup.Collection)
1542     link(*B);
1543 }
1544 
1545 bool SchedGroup::canAddSU(SUnit &SU) const {
1546   MachineInstr &MI = *SU.getInstr();
1547   if (MI.getOpcode() != TargetOpcode::BUNDLE)
1548     return canAddMI(MI);
1549 
1550   // Special case for bundled MIs.
1551   const MachineBasicBlock *MBB = MI.getParent();
1552   MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
1553   while (E != MBB->end() && E->isBundledWithPred())
1554     ++E;
1555 
1556   // Return true if all of the bundled MIs can be added to this group.
1557   return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });
1558 }
1559 
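// Greedily fill this group, in top-down DAG order, with instructions that
// match SGMask until the group is full.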
1560 void SchedGroup::initSchedGroup() {
1561   for (auto &SU : DAG->SUnits) {
1562     if (isFull())
1563       break;
1564 
1565     if (canAddSU(SU))
1566       add(SU);
1567   }
1568 }
1569 
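// Variant used for SCHED_GROUP_BARRIER pipelines: starting from the barrier's
// SUnit (RIter) and walking up the block, record this group's SGID as a
// candidate for every matching instruction in SyncedInstrs rather than adding
// it directly; the PipelineSolver makes the final assignment. The barrier
// SUnit itself is added to the group and MaxSize is incremented to account
// for it.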
1570 void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
1571                                 SUnitsToCandidateSGsMap &SyncedInstrs) {
1572   SUnit &InitSU = *RIter;
1573   for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {
1574     auto &SU = *RIter;
1575     if (isFull())
1576       break;
1577 
1578     if (canAddSU(SU))
1579       SyncedInstrs[&SU].push_back(SGID);
1580   }
1581 
1582   add(InitSU);
1583   assert(MaxSize);
1584   (*MaxSize)++;
1585 }
1586 
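// As above, but scan the whole DAG bottom-up and only record candidate SGIDs;
// used when the group is created programmatically by an IGLP strategy rather
// than by a SCHED_GROUP_BARRIER instruction.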
1587 void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {
1588   auto I = DAG->SUnits.rbegin();
1589   auto E = DAG->SUnits.rend();
1590   for (; I != E; ++I) {
1591     auto &SU = *I;
1592     if (isFull())
1593       break;
1594 
1595     if (canAddSU(SU))
1596       SyncedInstrs[&SU].push_back(SGID);
1597   }
1598 }
1599 
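// Scan the scheduling region bottom-up for SCHED_BARRIER, SCHED_GROUP_BARRIER
// and IGLP_OPT pseudo instructions, build the SchedGroups they describe, and
// hand the result to the PipelineSolver, which adds the ordering edges.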
1600 void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
1601   const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
1602   if (!TSchedModel || DAGInstrs->SUnits.empty())
1603     return;
1604 
1605   LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
1606   const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
1607   TII = ST.getInstrInfo();
1608   DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
1609   SyncedSchedGroups.clear();
1610   SyncedInstrs.clear();
1611   bool foundSB = false;
1612   bool foundIGLP = false;
1613   for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
1614     unsigned Opc = R->getInstr()->getOpcode();
1615     // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
1616     if (Opc == AMDGPU::SCHED_BARRIER) {
1617       addSchedBarrierEdges(*R);
1618       foundSB = true;
1619     } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
1620       initSchedGroupBarrierPipelineStage(R);
1621       foundSB = true;
1622     } else if (Opc == AMDGPU::IGLP_OPT) {
1623       resetEdges(*R, DAG);
1624       if (!foundSB && !foundIGLP)
1625         initIGLPOpt(*R);
1626       foundIGLP = true;
1627     }
1628   }
1629 
1630   if (foundSB || foundIGLP) {
1631     PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
1632     // PipelineSolver performs the mutation by adding the edges it
1633     // determined to be the best.
1634     PS.solve();
1635   }
1636 }
1637 
1638 void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
1639   MachineInstr &MI = *SchedBarrier.getInstr();
1640   assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
1641   // Remove all existing edges from the SCHED_BARRIER that were added due to the
1642   // instruction having side effects.
1643   resetEdges(SchedBarrier, DAG);
1644   LLVM_DEBUG(dbgs() << "Building SchedGroup for SchedBarrier with Mask: "
1645                     << MI.getOperand(0).getImm() << "\n");
1646   auto InvertedMask =
1647       invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());
1648   SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
1649   SG.initSchedGroup();
1650 
1651   // Preserve original instruction ordering relative to the SCHED_BARRIER.
1652   SG.link(
1653       SchedBarrier,
1654       (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
1655           const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
1656 }
1657 
1658 SchedGroupMask
1659 IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
1660   // Invert mask and erase bits for types of instructions that are implied to be
1661   // allowed past the SCHED_BARRIER.
1662   SchedGroupMask InvertedMask = ~Mask;
1663 
1664   // ALU implies VALU, SALU, MFMA, TRANS.
1665   if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
1666     InvertedMask &= ~SchedGroupMask::VALU & ~SchedGroupMask::SALU &
1667                     ~SchedGroupMask::MFMA & ~SchedGroupMask::TRANS;
1668   // VALU, SALU, MFMA, TRANS implies ALU.
1669   else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
1670            (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
1671            (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE ||
1672            (InvertedMask & SchedGroupMask::TRANS) == SchedGroupMask::NONE)
1673     InvertedMask &= ~SchedGroupMask::ALU;
1674 
1675   // VMEM implies VMEM_READ, VMEM_WRITE.
1676   if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
1677     InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;
1678   // VMEM_READ, VMEM_WRITE implies VMEM.
1679   else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
1680            (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)
1681     InvertedMask &= ~SchedGroupMask::VMEM;
1682 
1683   // DS implies DS_READ, DS_WRITE.
1684   if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
1685     InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;
1686   // DS_READ, DS_WRITE implies DS.
1687   else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
1688            (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
1689     InvertedMask &= ~SchedGroupMask::DS;
1690 
1691   LLVM_DEBUG(dbgs() << "After Inverting, SchedGroup Mask: " << (int)InvertedMask
1692                     << "\n");
1693 
1694   return InvertedMask;
1695 }
1696 
1697 void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
1698     std::vector<SUnit>::reverse_iterator RIter) {
1699   // Remove all existing edges from the SCHED_GROUP_BARRIER that were added due
1700   // to the instruction having side effects.
1701   resetEdges(*RIter, DAG);
1702   MachineInstr &SGB = *RIter->getInstr();
1703   assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
1704   int32_t SGMask = SGB.getOperand(0).getImm();
1705   int32_t Size = SGB.getOperand(1).getImm();
1706   int32_t SyncID = SGB.getOperand(2).getImm();
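  // Illustrative example (mask value per the SchedGroupMask definition above):
  // a SCHED_GROUP_BARRIER with operands (/*mask=*/0x8, /*size=*/1,
  // /*syncid=*/0), e.g. as produced from llvm.amdgcn.sched.group.barrier(8, 1,
  // 0), requests one SchedGroup of at most one MFMA instruction under sync
  // ID 0.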
1707 
1708   auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,
1709                                                     Size, SyncID, DAG, TII);
1710 
1711   SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]);
1712 }
1713 
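// Build the IGLPStrategy selected by the IGLP_OPT immediate and, if
// shouldApplyStrategy allows it, adopt its processing direction and apply it
// to this region.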
1714 void IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
1715   IGLPStrategyID StrategyID =
1716       (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm();
1717   auto S = createIGLPStrategy(StrategyID, DAG, TII);
1718   if (S->shouldApplyStrategy(DAG)) {
1719     IsBottomUp = S->IsBottomUp;
1720     S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups, IsReentry);
1721   }
1722 }
1723 
1724 } // namespace
1725 
1726 namespace llvm {
1727 
1728 /// \p IsReentry specifies whether or not this is a reentry into the
1729 /// IGroupLPDAGMutation. Since there may be multiple scheduling passes on the
1730 /// same scheduling region (e.g. pre and post-RA scheduling / multiple
1731 /// scheduling "phases"), we can reenter this mutation framework more than once
1732 /// for a given region.
1733 std::unique_ptr<ScheduleDAGMutation> createIGroupLPDAGMutation(bool IsReentry) {
1734   return std::make_unique<IGroupLPDAGMutation>(IsReentry);
1735 }
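
// Illustrative use (the exact call site lives in the AMDGPU scheduler setup):
//   DAG->addMutation(createIGroupLPDAGMutation(/*IsReentry=*/false));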
1736 
1737 } // end namespace llvm
1738