xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp (revision 35c0a8c449fd2b7f75029ebed5e10852240f0865)
1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "GCNRegPressure.h"
15 #include "AMDGPU.h"
16 #include "llvm/CodeGen/RegisterPressure.h"
17 
18 using namespace llvm;
19 
20 #define DEBUG_TYPE "machine-scheduler"
21 
22 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
23                    const GCNRPTracker::LiveRegSet &S2) {
24   if (S1.size() != S2.size())
25     return false;
26 
27   for (const auto &P : S1) {
28     auto I = S2.find(P.first);
29     if (I == S2.end() || I->second != P.second)
30       return false;
31   }
32   return true;
33 }
34 
35 ///////////////////////////////////////////////////////////////////////////////
36 // GCNRegPressure
37 
38 unsigned GCNRegPressure::getRegKind(Register Reg,
39                                     const MachineRegisterInfo &MRI) {
40   assert(Reg.isVirtual());
41   const auto RC = MRI.getRegClass(Reg);
42   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
43   return STI->isSGPRClass(RC)
44              ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE)
45          : STI->isAGPRClass(RC)
46              ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE)
47              : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
48 }
49 
50 void GCNRegPressure::inc(unsigned Reg,
51                          LaneBitmask PrevMask,
52                          LaneBitmask NewMask,
53                          const MachineRegisterInfo &MRI) {
54   if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
55       SIRegisterInfo::getNumCoveredRegs(PrevMask))
56     return;
57 
58   int Sign = 1;
59   if (NewMask < PrevMask) {
60     std::swap(NewMask, PrevMask);
61     Sign = -1;
62   }
63 
64   switch (auto Kind = getRegKind(Reg, MRI)) {
65   case SGPR32:
66   case VGPR32:
67   case AGPR32:
68     Value[Kind] += Sign;
69     break;
70 
71   case SGPR_TUPLE:
72   case VGPR_TUPLE:
73   case AGPR_TUPLE:
74     assert(PrevMask < NewMask);
75 
76     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
77       Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
78 
79     if (PrevMask.none()) {
80       assert(NewMask.any());
81       const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
82       Value[Kind] +=
83           Sign * TRI->getRegClassWeight(MRI.getRegClass(Reg)).RegWeight;
84     }
85     break;
86 
87   default: llvm_unreachable("Unknown register kind");
88   }
89 }
90 
/// Return true if this register pressure is preferable to \p O in \p MF.
///
/// Criteria, in decreasing precedence:
///  1. Higher occupancy, where each side's SGPR- and VGPR-limited occupancy
///     is first capped at \p MaxOccupancy.
///  2. If either side exceeds a register budget, less excess wins: first the
///     summed VGPR / arch-VGPR / AGPR excess (with VGPRs needed to spill
///     excess SGPRs added in), then the SGPR excess.
///  3. Lower tuple weight, checking SGPR tuples first only when both sides
///     agree that SGPRs limit occupancy.
///  4. Lower plain SGPR or VGPR count, chosen the same way as in step 3.
bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O,
                          unsigned MaxOccupancy) const {
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();

  const auto SGPROcc = std::min(MaxOccupancy,
                                ST.getOccupancyWithNumSGPRs(getSGPRNum()));
  const auto VGPROcc =
    std::min(MaxOccupancy,
             ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
  const auto OtherSGPROcc = std::min(MaxOccupancy,
                                ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
  const auto OtherVGPROcc =
    std::min(MaxOccupancy,
             ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts())));

  // Overall occupancy is limited by whichever register file is worse.
  const auto Occ = std::min(SGPROcc, VGPROcc);
  const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);

  // Give first precedence to the better occupancy.
  if (Occ != OtherOcc)
    return Occ > OtherOcc;

  unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF);
  unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF);

  // SGPR excess pressure conditions.
  // The unsigned difference is reinterpreted as int so that a count below the
  // limit (which wraps) clamps to 0. NOTE(review): this assumes the counts
  // and limits stay well below INT_MAX.
  unsigned ExcessSGPR = std::max(static_cast<int>(getSGPRNum() - MaxSGPRs), 0);
  unsigned OtherExcessSGPR =
      std::max(static_cast<int>(O.getSGPRNum() - MaxSGPRs), 0);

  auto WaveSize = ST.getWavefrontSize();
  // The number of virtual VGPRs required to handle excess SGPR:
  // ceil(ExcessSGPR / WaveSize), presumably one VGPR lane per spilled SGPR.
  unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize;
  unsigned OtherVGPRForSGPRSpills =
      (OtherExcessSGPR + (WaveSize - 1)) / WaveSize;

  unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs();

  // Unified excess pressure conditions, accounting for VGPRs used for SGPR
  // spills
  unsigned ExcessVGPR =
      std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) +
                                VGPRForSGPRSpills - MaxVGPRs),
               0);
  unsigned OtherExcessVGPR =
      std::max(static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) +
                                OtherVGPRForSGPRSpills - MaxVGPRs),
               0);
  // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR
  // spills
  unsigned ExcessArchVGPR = std::max(
      static_cast<int>(getVGPRNum(false) + VGPRForSGPRSpills - MaxArchVGPRs),
      0);
  unsigned OtherExcessArchVGPR =
      std::max(static_cast<int>(O.getVGPRNum(false) + OtherVGPRForSGPRSpills -
                                MaxArchVGPRs),
               0);
  // AGPR excess pressure conditions. With GFX90A instructions the AGPR limit
  // is the addressable arch-VGPR count; otherwise it is MaxVGPRs.
  unsigned ExcessAGPR = std::max(
      static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs)
                                           : (getAGPRNum() - MaxVGPRs)),
      0);
  unsigned OtherExcessAGPR = std::max(
      static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs)
                                           : (O.getAGPRNum() - MaxVGPRs)),
      0);

  bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR;
  bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR ||
                       OtherExcessArchVGPR || OtherExcessAGPR;

  // Give second precedence to the reduced number of spills to hold the register
  // pressure.
  if (ExcessRP || OtherExcessRP) {
    // The difference in excess VGPR pressure, after including VGPRs used for
    // SGPR spills. A positive value means this pressure has less excess.
    int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) -
                    (ExcessVGPR + ExcessArchVGPR + ExcessAGPR));

    // Positive when this pressure has fewer excess SGPRs than O.
    int SGPRDiff = OtherExcessSGPR - ExcessSGPR;

    if (VGPRDiff != 0)
      return VGPRDiff > 0;
    if (SGPRDiff != 0) {
      // Excess VGPRs *without* the VGPRs that SGPR spilling would need.
      unsigned PureExcessVGPR =
          std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
                   0) +
          std::max(static_cast<int>(getVGPRNum(false) - MaxArchVGPRs), 0);
      unsigned OtherPureExcessVGPR =
          std::max(
              static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
              0) +
          std::max(static_cast<int>(O.getVGPRNum(false) - MaxArchVGPRs), 0);

      // If we have a special case where there is a tie in excess VGPR, but one
      // of the pressures has VGPR usage from SGPR spills, prefer the pressure
      // with SGPR spills.
      if (PureExcessVGPR != OtherPureExcessVGPR)
        return SGPRDiff < 0;
      // If both pressures have the same excess pressure before and after
      // accounting for SGPR spills, prefer fewer SGPR spills.
      return SGPRDiff > 0;
    }
  }

  // SGPRs are "important" when they, rather than VGPRs, limit occupancy.
  bool SGPRImportant = SGPROcc < VGPROcc;
  const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;

  // If both pressures disagree on what is more important compare vgprs.
  if (SGPRImportant != OtherSGPRImportant) {
    SGPRImportant = false;
  }

  // Give third precedence to lower register tuple pressure.
  // Two iterations: the important register file is compared first, then the
  // other one (SGPRFirst flips after each iteration).
  bool SGPRFirst = SGPRImportant;
  for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
    if (SGPRFirst) {
      auto SW = getSGPRTuplesWeight();
      auto OtherSW = O.getSGPRTuplesWeight();
      if (SW != OtherSW)
        return SW < OtherSW;
    } else {
      auto VW = getVGPRTuplesWeight();
      auto OtherVW = O.getVGPRTuplesWeight();
      if (VW != OtherVW)
        return VW < OtherVW;
    }
  }

  // Give final precedence to lower general RP.
  return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
                         (getVGPRNum(ST.hasGFX90AInsts()) <
                          O.getVGPRNum(ST.hasGFX90AInsts()));
}
225 
226 Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) {
227   return Printable([&RP, ST](raw_ostream &OS) {
228     OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' '
229        << "AGPRs: " << RP.getAGPRNum();
230     if (ST)
231       OS << "(O"
232          << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts()))
233          << ')';
234     OS << ", SGPRs: " << RP.getSGPRNum();
235     if (ST)
236       OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')';
237     OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight()
238        << ", LSGPR WT: " << RP.getSGPRTuplesWeight();
239     if (ST)
240       OS << " -> Occ: " << RP.getOccupancy(*ST);
241     OS << '\n';
242   });
243 }
244 
245 static LaneBitmask getDefRegMask(const MachineOperand &MO,
246                                  const MachineRegisterInfo &MRI) {
247   assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
248 
249   // We don't rely on read-undef flag because in case of tentative schedule
250   // tracking it isn't set correctly yet. This works correctly however since
251   // use mask has been tracked before using LIS.
252   return MO.getSubReg() == 0 ?
253     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
254     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
255 }
256 
257 static void
258 collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs,
259                       const MachineInstr &MI, const LiveIntervals &LIS,
260                       const MachineRegisterInfo &MRI) {
261   SlotIndex InstrSI;
262   for (const auto &MO : MI.operands()) {
263     if (!MO.isReg() || !MO.getReg().isVirtual())
264       continue;
265     if (!MO.isUse() || !MO.readsReg())
266       continue;
267 
268     Register Reg = MO.getReg();
269     if (llvm::any_of(RegMaskPairs, [Reg](const RegisterMaskPair &RM) {
270           return RM.RegUnit == Reg;
271         }))
272       continue;
273 
274     LaneBitmask UseMask;
275     auto &LI = LIS.getInterval(Reg);
276     if (!LI.hasSubRanges())
277       UseMask = MRI.getMaxLaneMaskForVReg(Reg);
278     else {
279       // For a tentative schedule LIS isn't updated yet but livemask should
280       // remain the same on any schedule. Subreg defs can be reordered but they
281       // all must dominate uses anyway.
282       if (!InstrSI)
283         InstrSI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
284       UseMask = getLiveLaneMask(LI, InstrSI, MRI);
285     }
286 
287     RegMaskPairs.emplace_back(Reg, UseMask);
288   }
289 }
290 
291 ///////////////////////////////////////////////////////////////////////////////
292 // GCNRPTracker
293 
294 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI,
295                                   const LiveIntervals &LIS,
296                                   const MachineRegisterInfo &MRI) {
297   return getLiveLaneMask(LIS.getInterval(Reg), SI, MRI);
298 }
299 
300 LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
301                                   const MachineRegisterInfo &MRI) {
302   LaneBitmask LiveMask;
303   if (LI.hasSubRanges()) {
304     for (const auto &S : LI.subranges())
305       if (S.liveAt(SI)) {
306         LiveMask |= S.LaneMask;
307         assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg())));
308       }
309   } else if (LI.liveAt(SI)) {
310     LiveMask = MRI.getMaxLaneMaskForVReg(LI.reg());
311   }
312   return LiveMask;
313 }
314 
315 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
316                                            const LiveIntervals &LIS,
317                                            const MachineRegisterInfo &MRI) {
318   GCNRPTracker::LiveRegSet LiveRegs;
319   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
320     auto Reg = Register::index2VirtReg(I);
321     if (!LIS.hasInterval(Reg))
322       continue;
323     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
324     if (LiveMask.any())
325       LiveRegs[Reg] = LiveMask;
326   }
327   return LiveRegs;
328 }
329 
330 void GCNRPTracker::reset(const MachineInstr &MI,
331                          const LiveRegSet *LiveRegsCopy,
332                          bool After) {
333   const MachineFunction &MF = *MI.getMF();
334   MRI = &MF.getRegInfo();
335   if (LiveRegsCopy) {
336     if (&LiveRegs != LiveRegsCopy)
337       LiveRegs = *LiveRegsCopy;
338   } else {
339     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
340                      : getLiveRegsBefore(MI, LIS);
341   }
342 
343   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
344 }
345 
346 ////////////////////////////////////////////////////////////////////////////////
347 // GCNUpwardRPTracker
348 
349 void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_,
350                                const LiveRegSet &LiveRegs_) {
351   MRI = &MRI_;
352   LiveRegs = LiveRegs_;
353   LastTrackedMI = nullptr;
354   MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
355 }
356 
// Move the tracker upward across MI. On entry LiveRegs/CurPressure describe
// the point just after MI; on exit they describe the point just before MI.
// MaxPressure is raised to cover the peak reached while MI executes.
void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
  assert(MRI && "call reset first");

  LastTrackedMI = &MI;

  // Debug instructions do not affect liveness or pressure.
  if (MI.isDebugInstr())
    return;

  // Kill all defs.
  // DefPressure / ECDefPressure accumulate the pressure of MI's normal and
  // early-clobber defs, each counted as if fully live.
  GCNRegPressure DefPressure, ECDefPressure;
  bool HasECDefs = false;
  for (const MachineOperand &MO : MI.all_defs()) {
    if (!MO.getReg().isVirtual())
      continue;

    Register Reg = MO.getReg();
    LaneBitmask DefMask = getDefRegMask(MO, *MRI);

    // Treat a def as fully live at the moment of definition: keep a record.
    if (MO.isEarlyClobber()) {
      ECDefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);
      HasECDefs = true;
    } else
      DefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);

    // A def that is not live afterwards is dead; nothing to remove.
    auto I = LiveRegs.find(Reg);
    if (I == LiveRegs.end())
      continue;

    // Above MI, the defined lanes are no longer live.
    LaneBitmask &LiveMask = I->second;
    LaneBitmask PrevMask = LiveMask;
    LiveMask &= ~DefMask;
    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
    if (LiveMask.none())
      LiveRegs.erase(I);
  }

  // Update MaxPressure with defs pressure.
  // At the def point, all defs are live together with whatever stays live
  // across MI (CurPressure after the kills above); early-clobber defs are
  // added here as well.
  DefPressure += CurPressure;
  if (HasECDefs)
    DefPressure += ECDefPressure;
  MaxPressure = max(DefPressure, MaxPressure);

  // Make uses alive.
  SmallVector<RegisterMaskPair, 8> RegUses;
  collectVirtualRegUses(RegUses, MI, LIS, *MRI);
  for (const RegisterMaskPair &U : RegUses) {
    LaneBitmask &LiveMask = LiveRegs[U.RegUnit];
    LaneBitmask PrevMask = LiveMask;
    LiveMask |= U.LaneMask;
    CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
  }

  // Update MaxPressure with uses plus early-clobber defs pressure.
  // Early-clobber defs are counted live simultaneously with MI's uses.
  MaxPressure = HasECDefs ? max(CurPressure + ECDefPressure, MaxPressure)
                          : max(CurPressure, MaxPressure);

  // Incremental tracking must agree with recomputing from the live set.
  assert(CurPressure == getRegPressure(*MRI, LiveRegs));
}
416 
417 ////////////////////////////////////////////////////////////////////////////////
418 // GCNDownwardRPTracker
419 
420 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
421                                  const LiveRegSet *LiveRegsCopy) {
422   MRI = &MI.getParent()->getParent()->getRegInfo();
423   LastTrackedMI = nullptr;
424   MBBEnd = MI.getParent()->end();
425   NextMI = &MI;
426   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
427   if (NextMI == MBBEnd)
428     return false;
429   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
430   return true;
431 }
432 
// First half of a downward step: drop registers (or lanes) that
// LastTrackedMI touched and that are no longer live just before NextMI (or,
// at the block end, after LastTrackedMI's dead slot). Returns true when the
// tracker has reached the end of the block.
bool GCNDownwardRPTracker::advanceBeforeNext() {
  assert(MRI && "call reset first");
  // Nothing tracked yet (just reset or already advanced): nothing to prune.
  if (!LastTrackedMI)
    return NextMI == MBBEnd;

  assert(NextMI == MBBEnd || !NextMI->isDebugInstr());

  // Slot at which liveness is queried.
  SlotIndex SI = NextMI == MBBEnd
                     ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
                     : LIS.getInstructionIndex(*NextMI).getBaseIndex();
  assert(SI.isValid());

  // Remove dead registers or mask bits.
  // Only registers that LastTrackedMI defines or reads can have died here;
  // each register is visited once.
  SmallSet<Register, 8> SeenRegs;
  for (auto &MO : LastTrackedMI->operands()) {
    if (!MO.isReg() || !MO.getReg().isVirtual())
      continue;
    if (MO.isUse() && !MO.readsReg())
      continue;
    if (!SeenRegs.insert(MO.getReg()).second)
      continue;
    const LiveInterval &LI = LIS.getInterval(MO.getReg());
    if (LI.hasSubRanges()) {
      // Clear the lanes of every subrange not live at SI. The map lookup is
      // deferred until the first dead subrange is found.
      auto It = LiveRegs.end();
      for (const auto &S : LI.subranges()) {
        if (!S.liveAt(SI)) {
          if (It == LiveRegs.end()) {
            It = LiveRegs.find(MO.getReg());
            if (It == LiveRegs.end())
              llvm_unreachable("register isn't live");
          }
          auto PrevMask = It->second;
          It->second &= ~S.LaneMask;
          CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI);
        }
      }
      // Erase the entry once no lanes remain live.
      if (It != LiveRegs.end() && It->second.none())
        LiveRegs.erase(It);
    } else if (!LI.liveAt(SI)) {
      // No subranges: the whole register dies at once.
      auto It = LiveRegs.find(MO.getReg());
      if (It == LiveRegs.end())
        llvm_unreachable("register isn't live");
      CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI);
      LiveRegs.erase(It);
    }
  }

  MaxPressure = max(MaxPressure, CurPressure);

  LastTrackedMI = nullptr;

  return NextMI == MBBEnd;
}
486 
487 void GCNDownwardRPTracker::advanceToNext() {
488   LastTrackedMI = &*NextMI++;
489   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
490 
491   // Add new registers or mask bits.
492   for (const auto &MO : LastTrackedMI->all_defs()) {
493     Register Reg = MO.getReg();
494     if (!Reg.isVirtual())
495       continue;
496     auto &LiveMask = LiveRegs[Reg];
497     auto PrevMask = LiveMask;
498     LiveMask |= getDefRegMask(MO, *MRI);
499     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
500   }
501 
502   MaxPressure = max(MaxPressure, CurPressure);
503 }
504 
505 bool GCNDownwardRPTracker::advance() {
506   if (NextMI == MBBEnd)
507     return false;
508   advanceBeforeNext();
509   advanceToNext();
510   return true;
511 }
512 
513 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
514   while (NextMI != End)
515     if (!advance()) return false;
516   return true;
517 }
518 
519 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
520                                    MachineBasicBlock::const_iterator End,
521                                    const LiveRegSet *LiveRegsCopy) {
522   reset(*Begin, LiveRegsCopy);
523   return advance(End);
524 }
525 
526 Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
527                                const GCNRPTracker::LiveRegSet &TrackedLR,
528                                const TargetRegisterInfo *TRI, StringRef Pfx) {
529   return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) {
530     for (auto const &P : TrackedLR) {
531       auto I = LISLR.find(P.first);
532       if (I == LISLR.end()) {
533         OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
534            << " isn't found in LIS reported set\n";
535       } else if (I->second != P.second) {
536         OS << Pfx << printReg(P.first, TRI)
537            << " masks doesn't match: LIS reported " << PrintLaneMask(I->second)
538            << ", tracked " << PrintLaneMask(P.second) << '\n';
539       }
540     }
541     for (auto const &P : LISLR) {
542       auto I = TrackedLR.find(P.first);
543       if (I == TrackedLR.end()) {
544         OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
545            << " isn't found in tracked set\n";
546       }
547     }
548   });
549 }
550 
551 bool GCNUpwardRPTracker::isValid() const {
552   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
553   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
554   const auto &TrackedLR = LiveRegs;
555 
556   if (!isEqual(LISLR, TrackedLR)) {
557     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
558               " LIS reported livesets mismatch:\n"
559            << print(LISLR, *MRI);
560     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
561     return false;
562   }
563 
564   auto LISPressure = getRegPressure(*MRI, LISLR);
565   if (LISPressure != CurPressure) {
566     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "
567            << print(CurPressure) << "LIS rpt: " << print(LISPressure);
568     return false;
569   }
570   return true;
571 }
572 
573 Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
574                       const MachineRegisterInfo &MRI) {
575   return Printable([&LiveRegs, &MRI](raw_ostream &OS) {
576     const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
577     for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
578       Register Reg = Register::index2VirtReg(I);
579       auto It = LiveRegs.find(Reg);
580       if (It != LiveRegs.end() && It->second.any())
581         OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
582            << PrintLaneMask(It->second);
583     }
584     OS << '\n';
585   });
586 }
587 
588 void GCNRegPressure::dump() const { dbgs() << print(*this); }
589 
// Hidden command-line switch for the register-pressure printer pass: when
// set, GCNRegPressurePrinter walks each block with GCNDownwardRPTracker;
// otherwise (the default) it uses GCNUpwardRPTracker.
static cl::opt<bool> UseDownwardTracker(
    "amdgpu-print-rp-downward",
    cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"),
    cl::init(false), cl::Hidden);

// Pass identification for the register-pressure printer.
char llvm::GCNRegPressurePrinter::ID = 0;
char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID;

// Register the printer under the command-line name "amdgpu-print-rp" (with
// an empty description); the two trailing flags are INITIALIZE_PASS's
// cfg-only and is-analysis arguments.
INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true)
598 INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true)
599 
600 // Return lanemask of Reg's subregs that are live-through at [Begin, End] and
601 // are fully covered by Mask.
602 static LaneBitmask
603 getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS,
604                       Register Reg, SlotIndex Begin, SlotIndex End,
605                       LaneBitmask Mask = LaneBitmask::getAll()) {
606 
607   auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool {
608     auto *Segment = LR.getSegmentContaining(Begin);
609     return Segment && Segment->contains(End);
610   };
611 
612   LaneBitmask LiveThroughMask;
613   const LiveInterval &LI = LIS.getInterval(Reg);
614   if (LI.hasSubRanges()) {
615     for (auto &SR : LI.subranges()) {
616       if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR))
617         LiveThroughMask |= SR.LaneMask;
618     }
619   } else {
620     LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg);
621     if ((RegMask & Mask) == RegMask && IsInOneSegment(LI))
622       LiveThroughMask = RegMask;
623   }
624 
625   return LiveThroughMask;
626 }
627 
628 bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) {
629   const MachineRegisterInfo &MRI = MF.getRegInfo();
630   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
631   const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
632 
633   auto &OS = dbgs();
634 
635 // Leading spaces are important for YAML syntax.
636 #define PFX "  "
637 
638   OS << "---\nname: " << MF.getName() << "\nbody:             |\n";
639 
640   auto printRP = [](const GCNRegPressure &RP) {
641     return Printable([&RP](raw_ostream &OS) {
642       OS << format(PFX "  %-5d", RP.getSGPRNum())
643          << format(" %-5d", RP.getVGPRNum(false));
644     });
645   };
646 
647   auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR,
648                                     const GCNRPTracker::LiveRegSet &LISLR) {
649     if (LISLR != TrackedLR) {
650       OS << PFX "  mis LIS: " << llvm::print(LISLR, MRI)
651          << reportMismatch(LISLR, TrackedLR, TRI, PFX "    ");
652     }
653   };
654 
655   // Register pressure before and at an instruction (in program order).
656   SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP;
657 
658   for (auto &MBB : MF) {
659     RP.clear();
660     RP.reserve(MBB.size());
661 
662     OS << PFX;
663     MBB.printName(OS);
664     OS << ":\n";
665 
666     SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(&MBB);
667     SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(&MBB);
668 
669     GCNRPTracker::LiveRegSet LiveIn, LiveOut;
670     GCNRegPressure RPAtMBBEnd;
671 
672     if (UseDownwardTracker) {
673       if (MBB.empty()) {
674         LiveIn = LiveOut = getLiveRegs(MBBStartSlot, LIS, MRI);
675         RPAtMBBEnd = getRegPressure(MRI, LiveIn);
676       } else {
677         GCNDownwardRPTracker RPT(LIS);
678         RPT.reset(MBB.front());
679 
680         LiveIn = RPT.getLiveRegs();
681 
682         while (!RPT.advanceBeforeNext()) {
683           GCNRegPressure RPBeforeMI = RPT.getPressure();
684           RPT.advanceToNext();
685           RP.emplace_back(RPBeforeMI, RPT.getPressure());
686         }
687 
688         LiveOut = RPT.getLiveRegs();
689         RPAtMBBEnd = RPT.getPressure();
690       }
691     } else {
692       GCNUpwardRPTracker RPT(LIS);
693       RPT.reset(MRI, MBBEndSlot);
694 
695       LiveOut = RPT.getLiveRegs();
696       RPAtMBBEnd = RPT.getPressure();
697 
698       for (auto &MI : reverse(MBB)) {
699         RPT.resetMaxPressure();
700         RPT.recede(MI);
701         if (!MI.isDebugInstr())
702           RP.emplace_back(RPT.getPressure(), RPT.getMaxPressure());
703       }
704 
705       LiveIn = RPT.getLiveRegs();
706     }
707 
708     OS << PFX "  Live-in: " << llvm::print(LiveIn, MRI);
709     if (!UseDownwardTracker)
710       ReportLISMismatchIfAny(LiveIn, getLiveRegs(MBBStartSlot, LIS, MRI));
711 
712     OS << PFX "  SGPR  VGPR\n";
713     int I = 0;
714     for (auto &MI : MBB) {
715       if (!MI.isDebugInstr()) {
716         auto &[RPBeforeInstr, RPAtInstr] =
717             RP[UseDownwardTracker ? I : (RP.size() - 1 - I)];
718         ++I;
719         OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << "  ";
720       } else
721         OS << PFX "               ";
722       MI.print(OS);
723     }
724     OS << printRP(RPAtMBBEnd) << '\n';
725 
726     OS << PFX "  Live-out:" << llvm::print(LiveOut, MRI);
727     if (UseDownwardTracker)
728       ReportLISMismatchIfAny(LiveOut, getLiveRegs(MBBEndSlot, LIS, MRI));
729 
730     GCNRPTracker::LiveRegSet LiveThrough;
731     for (auto [Reg, Mask] : LiveIn) {
732       LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Reg);
733       if (MaskIntersection.any()) {
734         LaneBitmask LTMask = getRegLiveThroughMask(
735             MRI, LIS, Reg, MBBStartSlot, MBBEndSlot, MaskIntersection);
736         if (LTMask.any())
737           LiveThrough[Reg] = LTMask;
738       }
739     }
740     OS << PFX "  Live-thr:" << llvm::print(LiveThrough, MRI);
741     OS << printRP(getRegPressure(MRI, LiveThrough)) << '\n';
742   }
743   OS << "...\n";
744   return false;
745 
746 #undef PFX
747 }