1 //===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP  ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file This file defines a set of schedule DAG mutations that can be used to
10 // override default scheduler behavior to enforce specific scheduling patterns.
11 // They should be used in cases where runtime performance considerations, such
12 // as inter-wavefront interactions, mean that compile-time heuristics cannot
13 // predict the optimal instruction ordering, or in kernels where optimum
14 // instruction scheduling is important enough to warrant manual intervention.
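// The mutations are driven by the SCHED_BARRIER, SCHED_GROUP_BARRIER and
// IGLP_OPT pseudo instructions handled below; these are typically emitted for
// the corresponding source-level builtins (illustratively, along the lines of
// __builtin_amdgcn_sched_group_barrier(mask, size, sync_id)).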
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "AMDGPUIGroupLP.h"
19 #include "AMDGPUTargetMachine.h"
20 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
21 #include "SIInstrInfo.h"
22 #include "SIMachineFunctionInfo.h"
23 #include "llvm/ADT/BitmaskEnum.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/CodeGen/MachineScheduler.h"
26 #include "llvm/CodeGen/TargetOpcodes.h"
27 
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "igrouplp"
31 
32 namespace {
33 
34 static cl::opt<bool> EnableExactSolver(
35     "amdgpu-igrouplp-exact-solver", cl::Hidden,
36     cl::desc("Whether to use the exponential time solver to fit "
37              "the instructions to the pipeline as closely as "
38              "possible."),
39     cl::init(false));
40 
41 static cl::opt<unsigned> CutoffForExact(
42     "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
43     cl::desc("The maximum number of scheduling group conflicts "
44              "which we attempt to solve with the exponential time "
45              "exact solver. Problem sizes greater than this will "
46              "be solved by the less accurate greedy algorithm. Selecting "
47              "solver by size is superseded by manually selecting "
48              "the solver (e.g. by amdgpu-igrouplp-exact-solver)."));
49 
50 static cl::opt<uint64_t> MaxBranchesExplored(
51     "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
52     cl::desc("The number of branches that we are willing to explore with "
53              "the exact algorithm before giving up."));
54 
55 static cl::opt<bool> UseCostHeur(
56     "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
57     cl::desc("Whether to use the cost heuristic to make choices as we "
58              "traverse the search space using the exact solver. Defaulted "
59              "to on, and if turned off, we will use the node order -- "
60              "attempting to put the later nodes in the later sched groups. "
61              "Experimentally, results are mixed, so this should be set on a "
62              "case-by-case basis."));
63 
64 // Components of the mask that determines which instruction types may be
65 // classified into a SchedGroup.
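// For example, a mask of 0x8 selects only MFMA/WMMA instructions, while a mask
// of 0x120 (VMEM_READ | DS_READ) selects any VMEM/FLAT or LDS load.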
66 enum class SchedGroupMask {
67   NONE = 0u,
68   ALU = 1u << 0,
69   VALU = 1u << 1,
70   SALU = 1u << 2,
71   MFMA = 1u << 3,
72   VMEM = 1u << 4,
73   VMEM_READ = 1u << 5,
74   VMEM_WRITE = 1u << 6,
75   DS = 1u << 7,
76   DS_READ = 1u << 8,
77   DS_WRITE = 1u << 9,
78   ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
79         DS_READ | DS_WRITE,
80   LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
81 };
82 
83 typedef DenseMap<SUnit *, SmallVector<int, 4>> SUnitsToCandidateSGsMap;
84 
85 // Classify instructions into groups to enable fine-tuned control over the
86 // scheduler. These groups may be more specific than current SchedModel
87 // instruction classes.
88 class SchedGroup {
89 private:
90   // Mask that defines which instruction types can be classified into this
91   // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
92   // and SCHED_GROUP_BARRIER.
93   SchedGroupMask SGMask;
94 
95   // Maximum number of SUnits that can be added to this group.
96   std::optional<unsigned> MaxSize;
97 
98   // SchedGroups will only synchronize with other SchedGroups that have the same
99   // SyncID.
100   int SyncID = 0;
101 
102   // SGID is used to map instructions to candidate SchedGroups
103   unsigned SGID;
104 
105   // Count of the number of created SchedGroups, used to initialize SGID.
106   static unsigned NumSchedGroups;
107 
108   ScheduleDAGInstrs *DAG;
109 
110   const SIInstrInfo *TII;
111 
112   // Try to add an edge from SU A to SU B.
113   bool tryAddEdge(SUnit *A, SUnit *B);
114 
115   // Use SGMask to determine whether we can classify MI as a member of this
116   // SchedGroup object.
117   bool canAddMI(const MachineInstr &MI) const;
118 
119 public:
120   // Collection of SUnits that are classified as members of this group.
121   SmallVector<SUnit *, 32> Collection;
122 
123   // Returns true if SU can be added to this SchedGroup.
124   bool canAddSU(SUnit &SU) const;
125 
126   // Add DAG dependencies between all SUnits in this SchedGroup and this SU. If
127   // MakePred is true, SU will be a predecessor of the SUnits in this
128   // SchedGroup, otherwise SU will be a successor.
129   void link(SUnit &SU, bool MakePred = false);
130 
131   // Add DAG dependencies and track which edges are added, as well as the
132   // count of missed edges.
133   int link(SUnit &SU, bool MakePred,
134            std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
135 
136   // Add DAG dependencies between all SUnits in this SchedGroup and this SU.
137   // Use the predicate to determine whether SU should be a predecessor (P =
138   // true) or a successor (P = false) of this SchedGroup.
139   void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);
140 
141   // Add DAG dependencies such that SUnits in this group shall be ordered
142   // before SUnits in OtherGroup.
143   void link(SchedGroup &OtherGroup);
144 
145   // Returns true if no more instructions may be added to this group.
146   bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }
147 
148   // Add SU to the SchedGroup.
149   void add(SUnit &SU) {
150     LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
151                       << format_hex((int)SGMask, 10, true) << " adding "
152                       << *SU.getInstr());
153     Collection.push_back(&SU);
154   }
155 
156   // Remove the last element in the SchedGroup.
157   void pop() { Collection.pop_back(); }
158 
159   // Identify and add all relevant SUs from the DAG to this SchedGroup.
160   void initSchedGroup();
161 
162   // Add instructions to the SchedGroup bottom up starting from RIter. SUnits
163   // that can be classified into this group are not added directly; each one is
164   // recorded in SyncedInstrs together with this group's SGID, for the
165   // PipelineSolver to resolve later. RIter itself will be added to the group,
166   // and dependencies will be added so that it is always scheduled at the end.
167   void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
168                       SUnitsToCandidateSGsMap &SyncedInstrs);
169 
170   void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);
171 
172   int getSyncID() { return SyncID; }
173 
174   int getSGID() { return SGID; }
175 
176   SchedGroupMask getMask() { return SGMask; }
177 
178   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
179              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
180       : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) {
181     SGID = NumSchedGroups++;
182   }
183 
184   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
185              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
186       : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) {
187     SGID = NumSchedGroups++;
188   }
189 };
190 
191 // Remove all existing edges from a SCHED_BARRIER or SCHED_GROUP_BARRIER.
192 static void resetEdges(SUnit &SU, ScheduleDAGInstrs *DAG) {
193   assert(SU.getInstr()->getOpcode() == AMDGPU::SCHED_BARRIER ||
194          SU.getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER ||
195          SU.getInstr()->getOpcode() == AMDGPU::IGLP_OPT);
196 
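  // removePred() erases entries from the Preds/Succs lists while we iterate,
  // so keep re-scanning until the lists are drained.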
197   while (!SU.Preds.empty())
198     for (auto &P : SU.Preds)
199       SU.removePred(P);
200 
201   while (!SU.Succs.empty())
202     for (auto &S : SU.Succs)
203       for (auto &SP : S.getSUnit()->Preds)
204         if (SP.getSUnit() == &SU)
205           S.getSUnit()->removePred(SP);
206 }
207 
208 typedef std::pair<SUnit *, SmallVector<int, 4>> SUToCandSGsPair;
209 typedef SmallVector<SUToCandSGsPair, 4> SUsToCandSGsVec;
210 
211 // The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
212 // in non-trivial cases. For example, if the requested pipeline is
213 // {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
214 // in the DAG, then we will have an instruction that cannot be trivially
215 // assigned to a SchedGroup. The PipelineSolver class implements two algorithms
216 // to find a good solution to the pipeline -- a greedy algorithm and an exact
217 // algorithm. The exact algorithm has an exponential time complexity and should
218 // only be used for small problems, or for medium-sized problems where an exact
219 // solution is highly desired.
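// For instance, with the example pipeline above, every VMEM_READ SU is a
// candidate for both the first and the last SchedGroup; the solver picks the
// assignment that minimizes the number of ordering edges that cannot be added.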
220 class PipelineSolver {
221   ScheduleDAGMI *DAG;
222 
223   // Instructions that can be assigned to multiple SchedGroups
224   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
225   SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
226   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
227   // The current working pipeline
228   SmallVector<SmallVector<SchedGroup, 4>, 4> CurrPipeline;
229   // The pipeline that has the best solution found so far
230   SmallVector<SmallVector<SchedGroup, 4>, 4> BestPipeline;
231 
232   // Whether or not we actually have any SyncedInstrs to try to solve.
233   bool NeedsSolver = false;
234 
235   // Compute an estimate of the size of the search tree -- the true size is
236   // the product of each conflictedInst.Matches.size() across all SyncPipelines
237   unsigned computeProblemSize();
238 
239   // The cost penalty of not assigning a SU to a SchedGroup
240   int MissPenalty = 0;
241 
242   // Costs in terms of the number of edges we are unable to add
243   int BestCost = -1;
244   int CurrCost = 0;
245 
246   // Index pointing to the conflicting instruction that is currently being
247   // fitted
248   int CurrConflInstNo = 0;
249   // Index to the pipeline that is currently being fitted
250   int CurrSyncGroupIdx = 0;
251   // The first non-trivial pipeline
252   int BeginSyncGroupIdx = 0;
253 
254   // How many branches we have explored
255   uint64_t BranchesExplored = 0;
256 
257   // Update indices to fit next conflicting instruction
258   void advancePosition();
259   // Recede indices to attempt to find better fit for previous conflicting
260   // instruction
261   void retreatPosition();
262 
263   // The exponential time algorithm which finds the provably best fit
264   bool solveExact();
265   // The polynomial time algorithm which attempts to find a good fit
266   bool solveGreedy();
267   // Whether or not the current solution is optimal
268   bool checkOptimal();
269   // Populate the ready list, prioritizing fewest missed edges first
270   void populateReadyList(SUToCandSGsPair &CurrSU,
271                          SmallVectorImpl<std::pair<int, int>> &ReadyList,
272                          SmallVectorImpl<SchedGroup> &SyncPipeline);
273   // Add edges corresponding to the SchedGroups as assigned by solver
274   void makePipeline();
275   // Add the edges from the SU to the other SchedGroups in pipeline, and
276   // return the number of edges missed.
277   int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
278                std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
279   // Remove the edges passed via AddedEdges
280   void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
281   // Convert the passed in maps to arrays for bidirectional iterators
282   void convertSyncMapsToArrays();
283 
284   void reset();
285 
286 public:
287   // Invoke the solver to map instructions to instruction groups. Problem size
288   // and command-line options determine whether the exact or greedy algorithm is used.
289   void solve();
290 
291   PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
292                  DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
293                  ScheduleDAGMI *DAG)
294       : DAG(DAG), SyncedInstrs(SyncedInstrs),
295         SyncedSchedGroups(SyncedSchedGroups) {
296 
297     for (auto &PipelineInstrs : SyncedInstrs) {
298       if (PipelineInstrs.second.size() > 0) {
299         NeedsSolver = true;
300         break;
301       }
302     }
303 
304     if (!NeedsSolver)
305       return;
306 
307     convertSyncMapsToArrays();
308 
309     CurrPipeline = BestPipeline;
310 
311     while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
312            PipelineInstrs[BeginSyncGroupIdx].size() == 0)
313       ++BeginSyncGroupIdx;
314 
315     if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
316       return;
317   }
318 };
319 
320 void PipelineSolver::reset() {
321 
322   for (auto &SyncPipeline : CurrPipeline) {
323     for (auto &SG : SyncPipeline) {
324       SmallVector<SUnit *, 32> TempCollection = SG.Collection;
325       SG.Collection.clear();
326       auto SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
327         return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
328       });
329       if (SchedBarr != TempCollection.end())
330         SG.Collection.push_back(*SchedBarr);
331     }
332   }
333 
334   CurrSyncGroupIdx = BeginSyncGroupIdx;
335   CurrConflInstNo = 0;
336   CurrCost = 0;
337 }
338 
339 void PipelineSolver::convertSyncMapsToArrays() {
340   for (auto &SyncPipe : SyncedSchedGroups) {
341     BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);
342   }
343 
344   int PipelineIDx = SyncedInstrs.size() - 1;
345   PipelineInstrs.resize(SyncedInstrs.size());
346   for (auto &SyncInstrMap : SyncedInstrs) {
347     for (auto &SUsToCandSGs : SyncInstrMap.second) {
348       if (PipelineInstrs[PipelineIDx].size() == 0) {
349         PipelineInstrs[PipelineIDx].push_back(
350             std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
351         continue;
352       }
353       auto SortPosition = PipelineInstrs[PipelineIDx].begin();
354       // Insert them in sorted order -- this allows for good parsing order in
355       // the greedy algorithm
356       while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
357              SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
358         ++SortPosition;
359       PipelineInstrs[PipelineIDx].insert(
360           SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
361     }
362     --PipelineIDx;
363   }
364 }
365 
366 void PipelineSolver::makePipeline() {
367   // Preserve the ordering of the barriers for subsequent SchedGroupBarrier mutations
368   for (auto &SyncPipeline : BestPipeline) {
369     for (auto &SG : SyncPipeline) {
370       SUnit *SGBarr = nullptr;
371       for (auto &SU : SG.Collection) {
372         if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
373           SGBarr = SU;
374       }
375       // SchedGroups not created from a SCHED_GROUP_BARRIER (e.g. IGLP_OPT) have no SGBarr
376       if (!SGBarr)
377         continue;
378       resetEdges(*SGBarr, DAG);
379       SG.link(*SGBarr, false);
380     }
381   }
382 
383   for (auto &SyncPipeline : BestPipeline) {
384     auto I = SyncPipeline.rbegin();
385     auto E = SyncPipeline.rend();
386     for (; I != E; ++I) {
387       auto &GroupA = *I;
388       for (auto J = std::next(I); J != E; ++J) {
389         auto &GroupB = *J;
390         GroupA.link(GroupB);
391       }
392     }
393   }
394 }
395 
396 int PipelineSolver::addEdges(
397     SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
398     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
399   int AddedCost = 0;
400   bool MakePred = false;
401 
402   // The groups in the pipeline are in reverse order. Thus,
403   // by traversing them from last to first, we are traversing
404 // them in the order in which they were introduced in the code. After we
405   // pass the group the SU is being assigned to, it should be
406   // linked as a predecessor of the subsequent SchedGroups
407   auto GroupNo = (int)SyncPipeline.size() - 1;
408   for (; GroupNo >= 0; GroupNo--) {
409     if (SyncPipeline[GroupNo].getSGID() == SGID) {
410       MakePred = true;
411       continue;
412     }
413     auto Group = &SyncPipeline[GroupNo];
414     AddedCost += Group->link(*SU, MakePred, AddedEdges);
415     assert(AddedCost >= 0);
416   }
417 
418   return AddedCost;
419 }
420 
421 void PipelineSolver::removeEdges(
422     const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
423   // Only remove the edges that we have added when testing
424   // the fit.
425   for (auto &PredSuccPair : EdgesToRemove) {
426     SUnit *Pred = PredSuccPair.first;
427     SUnit *Succ = PredSuccPair.second;
428 
429     auto Match = llvm::find_if(
430         Succ->Preds, [&Pred](SDep &P) { return P.getSUnit() == Pred; });
431     if (Match != Succ->Preds.end()) {
432       assert(Match->isArtificial());
433       Succ->removePred(*Match);
434     }
435   }
436 }
437 
438 void PipelineSolver::advancePosition() {
439   ++CurrConflInstNo;
440 
441   if (static_cast<size_t>(CurrConflInstNo) >=
442       PipelineInstrs[CurrSyncGroupIdx].size()) {
443     CurrConflInstNo = 0;
444     ++CurrSyncGroupIdx;
445     // Advance to next non-trivial pipeline
446     while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
447            PipelineInstrs[CurrSyncGroupIdx].size() == 0)
448       ++CurrSyncGroupIdx;
449   }
450 }
451 
452 void PipelineSolver::retreatPosition() {
453   assert(CurrConflInstNo >= 0);
454   assert(CurrSyncGroupIdx >= 0);
455 
456   if (CurrConflInstNo > 0) {
457     --CurrConflInstNo;
458     return;
459   }
460 
461   if (CurrConflInstNo == 0) {
462     // If we return to the starting position, we have explored
463     // the entire tree
464     if (CurrSyncGroupIdx == BeginSyncGroupIdx)
465       return;
466 
467     --CurrSyncGroupIdx;
468     // Go to previous non-trivial pipeline
469     while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)
470       --CurrSyncGroupIdx;
471 
472     CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
473   }
474 }
475 
476 bool PipelineSolver::checkOptimal() {
477   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
478     if (BestCost == -1 || CurrCost < BestCost) {
479       BestPipeline = CurrPipeline;
480       BestCost = CurrCost;
481       LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
482     }
483     assert(BestCost >= 0);
484   }
485 
486   bool DoneExploring = false;
487   if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
488     DoneExploring = true;
489 
490   return (DoneExploring || BestCost == 0);
491 }
492 
493 void PipelineSolver::populateReadyList(
494     SUToCandSGsPair &CurrSU, SmallVectorImpl<std::pair<int, int>> &ReadyList,
495     SmallVectorImpl<SchedGroup> &SyncPipeline) {
496   assert(CurrSU.second.size() >= 1);
497   auto I = CurrSU.second.rbegin();
498   auto E = CurrSU.second.rend();
499   for (; I != E; ++I) {
500     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
501     int CandSGID = *I;
502     SchedGroup *Match = nullptr;
503     for (auto &SG : SyncPipeline) {
504       if (SG.getSGID() == CandSGID)
505         Match = &SG;
506     }
507 
508     if (UseCostHeur) {
509       if (Match->isFull()) {
510         ReadyList.push_back(std::pair(*I, MissPenalty));
511         continue;
512       }
513 
514       int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
515       ReadyList.push_back(std::pair(*I, TempCost));
516       removeEdges(AddedEdges);
517     } else
518       ReadyList.push_back(std::pair(*I, -1));
519   }
520 
521   if (UseCostHeur) {
522     std::sort(ReadyList.begin(), ReadyList.end(),
523               [](std::pair<int, int> A, std::pair<int, int> B) {
524                 return A.second < B.second;
525               });
526   }
527 
528   assert(ReadyList.size() == CurrSU.second.size());
529 }
530 
531 bool PipelineSolver::solveExact() {
532   if (checkOptimal())
533     return true;
534 
535   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
536     return false;
537 
538   assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
539   assert(static_cast<size_t>(CurrConflInstNo) <
540          PipelineInstrs[CurrSyncGroupIdx].size());
541   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
542   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
543                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
544 
545   // SchedGroup -> Cost pairs
546   SmallVector<std::pair<int, int>, 4> ReadyList;
547   // Prioritize the candidate sched groups in terms of lowest cost first
548   populateReadyList(CurrSU, ReadyList, CurrPipeline[CurrSyncGroupIdx]);
549 
550   auto I = ReadyList.begin();
551   auto E = ReadyList.end();
552   for (; I != E; ++I) {
553     // If we are trying SGs in least cost order, and the current SG is cost
554     // infeasible, then all subsequent SGs will also be cost infeasible, so we
555     // can prune.
556     if (BestCost != -1 && (CurrCost + I->second > BestCost))
557       return false;
558 
559     int CandSGID = I->first;
560     int AddedCost = 0;
561     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
562     auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
563     SchedGroup *Match = nullptr;
564     for (auto &SG : SyncPipeline) {
565       if (SG.getSGID() == CandSGID)
566         Match = &SG;
567     }
568 
569     if (Match->isFull())
570       continue;
571 
572     LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
573                       << (int)Match->getMask() << " and ID " << CandSGID
574                       << "\n");
575     Match->add(*CurrSU.first);
576     AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
577     LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
578     CurrCost += AddedCost;
579     advancePosition();
580     ++BranchesExplored;
581     bool FinishedExploring = false;
582     // If the Cost after adding edges is greater than a known solution,
583     // backtrack
584     if (CurrCost < BestCost || BestCost == -1) {
585       if (solveExact()) {
586         FinishedExploring = BestCost != 0;
587         if (!FinishedExploring)
588           return true;
589       }
590     }
591 
592     retreatPosition();
593     CurrCost -= AddedCost;
594     removeEdges(AddedEdges);
595     Match->pop();
596     CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
597     if (FinishedExploring)
598       return true;
599   }
600 
601   // Try the pipeline where the current instruction is omitted
602   // Potentially if we omit a problematic instruction from the pipeline,
603   // all the other instructions can nicely fit.
604   CurrCost += MissPenalty;
605   advancePosition();
606 
607   LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");
608 
609   bool FinishedExploring = false;
610   if (CurrCost < BestCost || BestCost == -1) {
611     if (solveExact()) {
612         FinishedExploring = BestCost != 0;
613       if (!FinishedExploring)
614         return true;
615     }
616   }
617 
618   retreatPosition();
619   CurrCost -= MissPenalty;
620   return FinishedExploring;
621 }
622 
623 bool PipelineSolver::solveGreedy() {
624   BestCost = 0;
625   std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
626 
627   while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
628     SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
629     int BestNodeCost = -1;
630     int TempCost;
631     SchedGroup *BestGroup = nullptr;
632     int BestGroupID = -1;
633     auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
634     LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
635                       << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
636 
637     // Since we have added the potential SchedGroups from bottom up, but
638     // traversed the DAG from top down, parse over the groups from last to
639     // first. If we fail to do this for the greedy algorithm, the solution will
640     // likely not be good in more complex cases.
641     auto I = CurrSU.second.rbegin();
642     auto E = CurrSU.second.rend();
643     for (; I != E; ++I) {
644       std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
645       int CandSGID = *I;
646       SchedGroup *Match = nullptr;
647       for (auto &SG : SyncPipeline) {
648         if (SG.getSGID() == CandSGID)
649           Match = &SG;
650       }
651 
652       LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
653                         << (int)Match->getMask() << "\n");
654 
655       if (Match->isFull()) {
656         LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
657         continue;
658       }
659       TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
660       LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
661       if (TempCost < BestNodeCost || BestNodeCost == -1) {
662         BestGroup = Match;
663         BestNodeCost = TempCost;
664         BestGroupID = CandSGID;
665       }
666       removeEdges(AddedEdges);
667       if (BestNodeCost == 0)
668         break;
669     }
670 
671     if (BestGroupID != -1) {
672       BestGroup->add(*CurrSU.first);
673       addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);
674       LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask "
675                         << (int)BestGroup->getMask() << "\n");
676       BestCost += BestNodeCost;
677     } else
678       BestCost += MissPenalty;
679 
680     CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
681     advancePosition();
682   }
683   BestPipeline = CurrPipeline;
684   removeEdges(AddedEdges);
685   return false;
686 }
687 
688 unsigned PipelineSolver::computeProblemSize() {
689   unsigned ProblemSize = 0;
690   for (auto &PipeConflicts : PipelineInstrs) {
691     ProblemSize += PipeConflicts.size();
692   }
693 
694   return ProblemSize;
695 }
696 
697 void PipelineSolver::solve() {
698   if (!NeedsSolver)
699     return;
700 
701   unsigned ProblemSize = computeProblemSize();
702   assert(ProblemSize > 0);
703 
704   bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
705   MissPenalty = (ProblemSize / 2) + 1;
706 
707   LLVM_DEBUG(DAG->dump());
708   if (EnableExactSolver || BelowCutoff) {
709     LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
710     solveGreedy();
711     reset();
712     LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
713     if (BestCost > 0) {
714       LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
715       solveExact();
716       LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
717     }
718   } else { // Use the Greedy Algorithm by default
719     LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
720     solveGreedy();
721   }
722 
723   makePipeline();
724 }
725 
726 enum IGLPStrategyID : int { MFMASmallGemmOptID = 0 };
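// The strategy ID is the immediate operand of the IGLP_OPT pseudo instruction
// (see initIGLPOpt below).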
727 
728 // Implement an IGLP scheduling strategy.
729 class IGLPStrategy {
730 protected:
731   ScheduleDAGInstrs *DAG;
732 
733   const SIInstrInfo *TII;
734 
735 public:
736   // Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
737   virtual void applyIGLPStrategy(
738       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
739       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) = 0;
740 
741   // Returns true if this strategy should be applied to a ScheduleDAG.
742   virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) = 0;
743 
744   IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
745       : DAG(DAG), TII(TII) {}
746 
747   virtual ~IGLPStrategy() = default;
748 };
749 
750 class MFMASmallGemmOpt final : public IGLPStrategy {
751 public:
752   void applyIGLPStrategy(
753       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
754       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) override;
755 
756   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; }
757 
758   MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
759       : IGLPStrategy(DAG, TII) {}
760 };
761 
762 void MFMASmallGemmOpt::applyIGLPStrategy(
763     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
764     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) {
765   // Count the number of MFMA instructions.
766   unsigned MFMACount = 0;
767   for (const MachineInstr &I : *DAG)
768     if (TII->isMFMAorWMMA(I))
769       ++MFMACount;
770 
771   const unsigned PipelineSyncID = 0;
772   SchedGroup *SG = nullptr;
773   for (unsigned I = 0; I < MFMACount * 3; ++I) {
774     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
775         SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
776     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
777 
778     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
779         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
780     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
781   }
782 }
783 
784 static std::unique_ptr<IGLPStrategy>
785 createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
786                    const SIInstrInfo *TII) {
787   switch (ID) {
788   case MFMASmallGemmOptID:
789     return std::make_unique<MFMASmallGemmOpt>(DAG, TII);
790   }
791 
792   llvm_unreachable("Unknown IGLPStrategyID");
793 }
794 
795 class IGroupLPDAGMutation : public ScheduleDAGMutation {
796 private:
797   const SIInstrInfo *TII;
798 
799   ScheduleDAGMI *DAG;
800 
801   // Organize lists of SchedGroups by their SyncID. SchedGroups /
802   // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
803   // between them.
804   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
805 
806   // Used to track instructions that can be mapped to multiple sched groups
807   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
808 
809   // Add DAG edges that enforce SCHED_BARRIER ordering.
810   void addSchedBarrierEdges(SUnit &SU);
811 
812   // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
813   // not be reordered across the SCHED_BARRIER. This is used for the base
814   // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
815   // SCHED_BARRIER will always block all instructions that can be classified
816   // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
817   // and may only synchronize with some SchedGroups. Returns the inverse of
818   // Mask. SCHED_BARRIER's mask describes which instruction types should be
819   // allowed to be scheduled across it. Invert the mask to get the
820   // SchedGroupMask of instructions that should be barred.
821   SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;
822 
823   // Create SchedGroups for a SCHED_GROUP_BARRIER.
824   void initSchedGroupBarrierPipelineStage(
825       std::vector<SUnit>::reverse_iterator RIter);
826 
827   void initIGLPOpt(SUnit &SU);
828 
829 public:
830   void apply(ScheduleDAGInstrs *DAGInstrs) override;
831 
832   IGroupLPDAGMutation() = default;
833 };
834 
835 unsigned SchedGroup::NumSchedGroups = 0;
836 
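// Edges are added as artificial dependencies so that the solver can later
// identify and remove exactly the edges introduced by this mutation.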
837 bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
838   if (A != B && DAG->canAddEdge(B, A)) {
839     DAG->addEdge(B, SDep(A, SDep::Artificial));
840     return true;
841   }
842   return false;
843 }
844 
845 bool SchedGroup::canAddMI(const MachineInstr &MI) const {
846   bool Result = false;
847   if (MI.isMetaInstruction())
848     Result = false;
849 
850   else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
851            (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI)))
852     Result = true;
853 
854   else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
855            TII->isVALU(MI) && !TII->isMFMAorWMMA(MI))
856     Result = true;
857 
858   else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
859            TII->isSALU(MI))
860     Result = true;
861 
862   else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
863            TII->isMFMAorWMMA(MI))
864     Result = true;
865 
866   else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
867            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
868     Result = true;
869 
870   else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
871            MI.mayLoad() &&
872            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
873     Result = true;
874 
875   else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
876            MI.mayStore() &&
877            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
878     Result = true;
879 
880   else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
881            TII->isDS(MI))
882     Result = true;
883 
884   else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
885            MI.mayLoad() && TII->isDS(MI))
886     Result = true;
887 
888   else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
889            MI.mayStore() && TII->isDS(MI))
890     Result = true;
891 
892   LLVM_DEBUG(
893       dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
894              << (Result ? " could classify " : " unable to classify ") << MI);
895 
896   return Result;
897 }
898 
899 int SchedGroup::link(SUnit &SU, bool MakePred,
900                      std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
901   int MissedEdges = 0;
902   for (auto *A : Collection) {
903     SUnit *B = &SU;
904     if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
905       continue;
906     if (MakePred)
907       std::swap(A, B);
908 
909     if (DAG->IsReachable(B, A))
910       continue;
911     // tryAddEdge returns false if there is a dependency that makes adding
912     // the A->B edge impossible; otherwise it returns true.
913     bool Added = tryAddEdge(A, B);
914     if (Added)
915       AddedEdges.push_back(std::pair(A, B));
916     else
917       ++MissedEdges;
918   }
919 
920   return MissedEdges;
921 }
922 
923 void SchedGroup::link(SUnit &SU, bool MakePred) {
924   for (auto *A : Collection) {
925     SUnit *B = &SU;
926     if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
927       continue;
928     if (MakePred)
929       std::swap(A, B);
930 
931     tryAddEdge(A, B);
932   }
933 }
934 
935 void SchedGroup::link(SUnit &SU,
936                       function_ref<bool(const SUnit *A, const SUnit *B)> P) {
937   for (auto *A : Collection) {
938     SUnit *B = &SU;
939     if (P(A, B))
940       std::swap(A, B);
941 
942     tryAddEdge(A, B);
943   }
944 }
945 
946 void SchedGroup::link(SchedGroup &OtherGroup) {
947   for (auto *B : OtherGroup.Collection)
948     link(*B);
949 }
950 
951 bool SchedGroup::canAddSU(SUnit &SU) const {
952   MachineInstr &MI = *SU.getInstr();
953   if (MI.getOpcode() != TargetOpcode::BUNDLE)
954     return canAddMI(MI);
955 
956   // Special case for bundled MIs.
957   const MachineBasicBlock *MBB = MI.getParent();
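  // MI is the BUNDLE header: B and E both start at the instruction that follows
  // it, and E is then advanced past the last MI bundled with its predecessor.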
958   MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
959   while (E != MBB->end() && E->isBundledWithPred())
960     ++E;
961 
962   // Return true if all of the bundled MIs can be added to this group.
963   return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });
964 }
965 
966 void SchedGroup::initSchedGroup() {
967   for (auto &SU : DAG->SUnits) {
968     if (isFull())
969       break;
970 
971     if (canAddSU(SU))
972       add(SU);
973   }
974 }
975 
976 void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
977                                 SUnitsToCandidateSGsMap &SyncedInstrs) {
978   SUnit &InitSU = *RIter;
979   for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {
980     auto &SU = *RIter;
981     if (isFull())
982       break;
983 
984     if (canAddSU(SU))
985       SyncedInstrs[&SU].push_back(SGID);
986   }
987 
988   add(InitSU);
989   assert(MaxSize);
990   (*MaxSize)++;
991 }
992 
993 void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {
994   auto I = DAG->SUnits.rbegin();
995   auto E = DAG->SUnits.rend();
996   for (; I != E; ++I) {
997     auto &SU = *I;
998     if (isFull())
999       break;
1000 
1001     if (canAddSU(SU))
1002       SyncedInstrs[&SU].push_back(SGID);
1003   }
1004 }
1005 
1006 void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
1007   const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
1008   if (!TSchedModel || DAGInstrs->SUnits.empty())
1009     return;
1010 
1011   LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
1012   const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
1013   TII = ST.getInstrInfo();
1014   DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
1015   SyncedSchedGroups.clear();
1016   SyncedInstrs.clear();
1017   bool foundSB = false;
1018   bool foundIGLP = false;
1019   for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
1020     unsigned Opc = R->getInstr()->getOpcode();
1021     // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
1022     if (Opc == AMDGPU::SCHED_BARRIER) {
1023       addSchedBarrierEdges(*R);
1024       foundSB = true;
1025     } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
1026       initSchedGroupBarrierPipelineStage(R);
1027       foundSB = true;
1028     } else if (Opc == AMDGPU::IGLP_OPT) {
1029       resetEdges(*R, DAG);
1030       if (!foundSB && !foundIGLP)
1031         initIGLPOpt(*R);
1032       foundIGLP = true;
1033     }
1034   }
1035 
1036   if (foundSB || foundIGLP) {
1037     PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG);
1038     // PipelineSolver performs the mutation by adding the edges it
1039     // determined to be the best.
1040     PS.solve();
1041   }
1042 }
1043 
1044 void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
1045   MachineInstr &MI = *SchedBarrier.getInstr();
1046   assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
1047   // Remove all existing edges from the SCHED_BARRIER that were added due to the
1048   // instruction having side effects.
1049   resetEdges(SchedBarrier, DAG);
1050   auto InvertedMask =
1051       invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());
1052   SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
1053   SG.initSchedGroup();
1054   // Preserve original instruction ordering relative to the SCHED_BARRIER.
1055   SG.link(
1056       SchedBarrier,
1057       (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
1058           const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
1059 }
1060 
1061 SchedGroupMask
1062 IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
1063   // Invert mask and erase bits for types of instructions that are implied to be
1064   // allowed past the SCHED_BARRIER.
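  // For example, a SCHED_BARRIER mask of 0x2 (only VALU may cross) inverts to a
  // mask that bars everything except VALU; the ALU bit is then also cleared from
  // the barred set, since barring generic ALU instructions would also bar the
  // VALU instructions the mask allows across.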
1065   SchedGroupMask InvertedMask = ~Mask;
1066 
1067   // ALU implies VALU, SALU, MFMA.
1068   if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
1069     InvertedMask &=
1070         ~SchedGroupMask::VALU & ~SchedGroupMask::SALU & ~SchedGroupMask::MFMA;
1071   // VALU, SALU, MFMA implies ALU.
1072   else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
1073            (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
1074            (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE)
1075     InvertedMask &= ~SchedGroupMask::ALU;
1076 
1077   // VMEM implies VMEM_READ, VMEM_WRITE.
1078   if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
1079     InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;
1080   // VMEM_READ, VMEM_WRITE implies VMEM.
1081   else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
1082            (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)
1083     InvertedMask &= ~SchedGroupMask::VMEM;
1084 
1085   // DS implies DS_READ, DS_WRITE.
1086   if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
1087     InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;
1088   // DS_READ, DS_WRITE implies DS.
1089   else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
1090            (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
1091     InvertedMask &= ~SchedGroupMask::DS;
1092 
1093   return InvertedMask;
1094 }
1095 
1096 void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
1097     std::vector<SUnit>::reverse_iterator RIter) {
1098   // Remove all existing edges from the SCHED_GROUP_BARRIER that were added due
1099   // to the instruction having side effects.
1100   resetEdges(*RIter, DAG);
1101   MachineInstr &SGB = *RIter->getInstr();
1102   assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
1103   int32_t SGMask = SGB.getOperand(0).getImm();
1104   int32_t Size = SGB.getOperand(1).getImm();
1105   int32_t SyncID = SGB.getOperand(2).getImm();
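  // For example, a SCHED_GROUP_BARRIER with operands (8, 1, 0) requests a group
  // of a single MFMA instruction (mask 0x8) in sync pipeline 0.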
1106 
1107   auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,
1108                                                     Size, SyncID, DAG, TII);
1109 
1110   SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]);
1111 }
1112 
1113 void IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
1114   IGLPStrategyID StrategyID =
1115       (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm();
1116   auto S = createIGLPStrategy(StrategyID, DAG, TII);
1117   if (S->shouldApplyStrategy(DAG))
1118     S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups);
1119 }
1120 
1121 } // namespace
1122 
1123 namespace llvm {
1124 
1125 std::unique_ptr<ScheduleDAGMutation> createIGroupLPDAGMutation() {
1126   return std::make_unique<IGroupLPDAGMutation>();
1127 }
1128 
1129 } // end namespace llvm
1130