xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp (revision e3f4a63af63bea70bc86b6c790b14aa5ee99fcd0)
1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "GCNRegPressure.h"
15 #include "AMDGPU.h"
16 #include "SIMachineFunctionInfo.h"
17 #include "llvm/CodeGen/RegisterPressure.h"
18 
19 using namespace llvm;
20 
21 #define DEBUG_TYPE "machine-scheduler"
22 
23 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
24                    const GCNRPTracker::LiveRegSet &S2) {
25   if (S1.size() != S2.size())
26     return false;
27 
28   for (const auto &P : S1) {
29     auto I = S2.find(P.first);
30     if (I == S2.end() || I->second != P.second)
31       return false;
32   }
33   return true;
34 }
35 
36 ///////////////////////////////////////////////////////////////////////////////
37 // GCNRegPressure
38 
39 unsigned GCNRegPressure::getRegKind(const TargetRegisterClass *RC,
40                                     const SIRegisterInfo *STI) {
41   return STI->isSGPRClass(RC) ? SGPR : (STI->isAGPRClass(RC) ? AGPR : VGPR);
42 }
43 
/// Updates the pressure counters for \p Reg when its live lanes change from
/// \p PrevMask to \p NewMask.
///
/// Nothing changes when both masks cover the same number of registers (as
/// reported by SIRegisterInfo::getNumCoveredRegs). Otherwise the counter of
/// the register's kind is raised when the mask grows and lowered when it
/// shrinks. For register classes whose size is not 32 bits, the tuple weight
/// counter is additionally adjusted when one of the two masks is empty.
void GCNRegPressure::inc(unsigned Reg,
                         LaneBitmask PrevMask,
                         LaneBitmask NewMask,
                         const MachineRegisterInfo &MRI) {
  unsigned NewNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(NewMask);
  unsigned PrevNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(PrevMask);
  if (NewNumCoveredRegs == PrevNumCoveredRegs)
    return;

  // Normalize so that PrevMask is the smaller mask; Sign remembers whether
  // pressure is being added (+1) or removed (-1).
  int Sign = 1;
  if (NewMask < PrevMask) {
    std::swap(NewMask, PrevMask);
    std::swap(NewNumCoveredRegs, PrevNumCoveredRegs);
    Sign = -1;
  }
  assert(PrevMask < NewMask && PrevNumCoveredRegs < NewNumCoveredRegs &&
         "prev mask should always be lesser than new");

  const TargetRegisterClass *RC = MRI.getRegClass(Reg);
  const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
  const SIRegisterInfo *STI = static_cast<const SIRegisterInfo *>(TRI);
  unsigned RegKind = getRegKind(RC, STI);
  if (TRI->getRegSizeInBits(*RC) != 32) {
    // Reg is from a tuple register class.
    if (PrevMask.none()) {
      // After normalization an empty PrevMask means the register is switching
      // between "not live at all" and "live", so the tuple weight of its class
      // is added or removed as a whole.
      unsigned TupleIdx = TOTAL_KINDS + RegKind;
      Value[TupleIdx] += Sign * TRI->getRegClassWeight(RC).RegWeight;
    }
    // Pressure scales with number of new registers covered by the new mask.
    // Note when true16 is enabled, we can no longer safely use the following
    // approach to calculate the difference in the number of 32-bit registers
    // between two masks:
    //
    // Sign *= SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
    //
    // The issue is that the mask calculation `~PrevMask & NewMask` doesn't
    // properly account for partial usage of a 32-bit register when dealing with
    // 16-bit registers.
    //
    // Consider this example:
    // Assume PrevMask = 0b0010 and NewMask = 0b1111. Here, the correct register
    // usage difference should be 1, because even though PrevMask uses only half
    // of a 32-bit register, it should still be counted as a full register use.
    // However, the mask calculation yields `~PrevMask & NewMask = 0b1101`, and
    // calling `getNumCoveredRegs` returns 2 instead of 1. This incorrect
    // calculation can lead to integer overflow when Sign = -1.
    Sign *= NewNumCoveredRegs - PrevNumCoveredRegs;
  }
  // For 32-bit classes Sign is still +/-1 here; for tuples it is the signed
  // difference in covered registers.
  Value[RegKind] += Sign;
}
94 
/// Returns true if this register pressure is preferable to \p O for function
/// \p MF.
///
/// The comparison is layered:
///  1. higher occupancy (each side capped at \p MaxOccupancy) wins;
///  2. if either side exceeds the SGPR / VGPR / ArchVGPR / AGPR limits, the
///     excess amounts decide (VGPR excess first, then SGPR excess);
///  3. lower register tuple weight wins;
///  4. lower plain SGPR or VGPR count wins.
bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O,
                          unsigned MaxOccupancy) const {
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  unsigned DynamicVGPRBlockSize =
      MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize();

  const auto SGPROcc = std::min(MaxOccupancy,
                                ST.getOccupancyWithNumSGPRs(getSGPRNum()));
  const auto VGPROcc = std::min(
      MaxOccupancy, ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts()),
                                                DynamicVGPRBlockSize));
  const auto OtherSGPROcc = std::min(MaxOccupancy,
                                ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
  const auto OtherVGPROcc =
      std::min(MaxOccupancy,
               ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts()),
                                           DynamicVGPRBlockSize));

  // The achievable occupancy is limited by whichever register file is tighter.
  const auto Occ = std::min(SGPROcc, VGPROcc);
  const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);

  // Give first precedence to the better occupancy.
  if (Occ != OtherOcc)
    return Occ > OtherOcc;

  unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF);
  unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF);

  // SGPR excess pressure conditions
  // The unsigned difference is reinterpreted as int so that "below the limit"
  // becomes negative and is clamped to zero by std::max.
  unsigned ExcessSGPR = std::max(static_cast<int>(getSGPRNum() - MaxSGPRs), 0);
  unsigned OtherExcessSGPR =
      std::max(static_cast<int>(O.getSGPRNum() - MaxSGPRs), 0);

  auto WaveSize = ST.getWavefrontSize();
  // The number of virtual VGPRs required to handle excess SGPR
  // (rounded up to a whole number of WaveSize-sized groups).
  unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize;
  unsigned OtherVGPRForSGPRSpills =
      (OtherExcessSGPR + (WaveSize - 1)) / WaveSize;

  unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs();

  // Unified excess pressure conditions, accounting for VGPRs used for SGPR
  // spills
  unsigned ExcessVGPR =
      std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) +
                                VGPRForSGPRSpills - MaxVGPRs),
               0);
  unsigned OtherExcessVGPR =
      std::max(static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) +
                                OtherVGPRForSGPRSpills - MaxVGPRs),
               0);
  // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR
  // spills
  unsigned ExcessArchVGPR = std::max(
      static_cast<int>(getVGPRNum(false) + VGPRForSGPRSpills - MaxArchVGPRs),
      0);
  unsigned OtherExcessArchVGPR =
      std::max(static_cast<int>(O.getVGPRNum(false) + OtherVGPRForSGPRSpills -
                                MaxArchVGPRs),
               0);
  // AGPR excess pressure conditions
  unsigned ExcessAGPR = std::max(
      static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs)
                                           : (getAGPRNum() - MaxVGPRs)),
      0);
  unsigned OtherExcessAGPR = std::max(
      static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs)
                                           : (O.getAGPRNum() - MaxVGPRs)),
      0);

  bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR;
  bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR ||
                       OtherExcessArchVGPR || OtherExcessAGPR;

  // Give second precedence to the reduced number of spills to hold the register
  // pressure.
  if (ExcessRP || OtherExcessRP) {
    // The difference in excess VGPR pressure, after including VGPRs used for
    // SGPR spills
    int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) -
                    (ExcessVGPR + ExcessArchVGPR + ExcessAGPR));

    int SGPRDiff = OtherExcessSGPR - ExcessSGPR;

    if (VGPRDiff != 0)
      return VGPRDiff > 0;
    if (SGPRDiff != 0) {
      // Excess VGPR pressure computed again, this time without the VGPRs
      // needed for SGPR spills.
      unsigned PureExcessVGPR =
          std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
                   0) +
          std::max(static_cast<int>(getVGPRNum(false) - MaxArchVGPRs), 0);
      unsigned OtherPureExcessVGPR =
          std::max(
              static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
              0) +
          std::max(static_cast<int>(O.getVGPRNum(false) - MaxArchVGPRs), 0);

      // If we have a special case where there is a tie in excess VGPR, but one
      // of the pressures has VGPR usage from SGPR spills, prefer the pressure
      // with SGPR spills.
      if (PureExcessVGPR != OtherPureExcessVGPR)
        return SGPRDiff < 0;
      // If both pressures have the same excess pressure before and after
      // accounting for SGPR spills, prefer fewer SGPR spills.
      return SGPRDiff > 0;
    }
    // Both differences are zero: fall through to the remaining tie-breakers.
  }

  bool SGPRImportant = SGPROcc < VGPROcc;
  const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;

  // If both pressures disagree on what is more important compare vgprs.
  if (SGPRImportant != OtherSGPRImportant) {
    SGPRImportant = false;
  }

  // Give third precedence to lower register tuple pressure.
  // Two iterations: the more important register file is compared first, then
  // the other one.
  bool SGPRFirst = SGPRImportant;
  for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
    if (SGPRFirst) {
      auto SW = getSGPRTuplesWeight();
      auto OtherSW = O.getSGPRTuplesWeight();
      if (SW != OtherSW)
        return SW < OtherSW;
    } else {
      auto VW = getVGPRTuplesWeight();
      auto OtherVW = O.getVGPRTuplesWeight();
      if (VW != OtherVW)
        return VW < OtherVW;
    }
  }

  // Give final precedence to lower general RP.
  return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
                         (getVGPRNum(ST.hasGFX90AInsts()) <
                          O.getVGPRNum(ST.hasGFX90AInsts()));
}
232 
233 Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST,
234                       unsigned DynamicVGPRBlockSize) {
235   return Printable([&RP, ST, DynamicVGPRBlockSize](raw_ostream &OS) {
236     OS << "VGPRs: " << RP.getArchVGPRNum() << ' '
237        << "AGPRs: " << RP.getAGPRNum();
238     if (ST)
239       OS << "(O"
240          << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts()),
241                                          DynamicVGPRBlockSize)
242          << ')';
243     OS << ", SGPRs: " << RP.getSGPRNum();
244     if (ST)
245       OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')';
246     OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight()
247        << ", LSGPR WT: " << RP.getSGPRTuplesWeight();
248     if (ST)
249       OS << " -> Occ: " << RP.getOccupancy(*ST, DynamicVGPRBlockSize);
250     OS << '\n';
251   });
252 }
253 
254 static LaneBitmask getDefRegMask(const MachineOperand &MO,
255                                  const MachineRegisterInfo &MRI) {
256   assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
257 
258   // We don't rely on read-undef flag because in case of tentative schedule
259   // tracking it isn't set correctly yet. This works correctly however since
260   // use mask has been tracked before using LIS.
261   return MO.getSubReg() == 0 ?
262     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
263     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
264 }
265 
/// Collects the virtual registers read by \p MI into \p VRegMaskOrUnits.
///
/// Each register gets a single entry whose lane mask is the union of the lanes
/// of all its reading use operands; entries already present in the vector are
/// reused and merged into. For every entry whose live interval has subranges
/// the mask is then narrowed to the lanes that are live at the base slot index
/// of \p MI.
static void
collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits,
                      const MachineInstr &MI, const LiveIntervals &LIS,
                      const MachineRegisterInfo &MRI) {

  auto &TRI = *MRI.getTargetRegisterInfo();
  for (const auto &MO : MI.operands()) {
    // Only virtual register operands that actually read the register count.
    if (!MO.isReg() || !MO.getReg().isVirtual())
      continue;
    if (!MO.isUse() || !MO.readsReg())
      continue;

    Register Reg = MO.getReg();
    auto I = llvm::find_if(VRegMaskOrUnits, [Reg](const VRegMaskOrUnit &RM) {
      return RM.RegUnit == Reg;
    });

    // Reuse the existing entry for Reg or append a fresh one with no lanes.
    auto &P = I == VRegMaskOrUnits.end()
                  ? VRegMaskOrUnits.emplace_back(Reg, LaneBitmask::getNone())
                  : *I;

    // A use without a subregister index reads every lane of the register.
    P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(MO.getSubReg())
                                 : MRI.getMaxLaneMaskForVReg(Reg);
  }

  // Computed lazily: only needed if some register has subranges.
  SlotIndex InstrSI;
  for (auto &P : VRegMaskOrUnits) {
    auto &LI = LIS.getInterval(P.RegUnit);
    if (!LI.hasSubRanges())
      continue;

    // For a tentative schedule LIS isn't updated yet but livemask should
    // remain the same on any schedule. Subreg defs can be reordered but they
    // all must dominate uses anyway.
    if (!InstrSI)
      InstrSI = LIS.getInstructionIndex(MI).getBaseIndex();

    P.LaneMask = getLiveLaneMask(LI, InstrSI, MRI, P.LaneMask);
  }
}
306 
307 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp
308 static LaneBitmask getLanesWithProperty(
309     const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
310     bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
311     LaneBitmask SafeDefault,
312     function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) {
313   if (RegUnit.isVirtual()) {
314     const LiveInterval &LI = LIS.getInterval(RegUnit);
315     LaneBitmask Result;
316     if (TrackLaneMasks && LI.hasSubRanges()) {
317       for (const LiveInterval::SubRange &SR : LI.subranges()) {
318         if (Property(SR, Pos))
319           Result |= SR.LaneMask;
320       }
321     } else if (Property(LI, Pos)) {
322       Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit)
323                               : LaneBitmask::getAll();
324     }
325 
326     return Result;
327   }
328 
329   const LiveRange *LR = LIS.getCachedRegUnit(RegUnit);
330   if (LR == nullptr)
331     return SafeDefault;
332   return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone();
333 }
334 
335 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp
336 /// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
337 /// The query starts with a lane bitmask which gets lanes/bits removed for every
338 /// use we find.
339 static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
340                                   SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
341                                   const MachineRegisterInfo &MRI,
342                                   const SIRegisterInfo *TRI,
343                                   const LiveIntervals *LIS,
344                                   bool Upward = false) {
345   for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
346     if (MO.isUndef())
347       continue;
348     const MachineInstr *MI = MO.getParent();
349     SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot();
350     bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
351                           : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
352     if (!InRange)
353       continue;
354 
355     unsigned SubRegIdx = MO.getSubReg();
356     LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubRegIdx);
357     LastUseMask &= ~UseMask;
358     if (LastUseMask.none())
359       return LaneBitmask::getNone();
360   }
361   return LastUseMask;
362 }
363 
364 ////////////////////////////////////////////////////////////////////////////////
365 // GCNRPTarget
366 
367 GCNRPTarget::GCNRPTarget(const MachineFunction &MF, const GCNRegPressure &RP,
368                          bool CombineVGPRSavings)
369     : RP(RP), CombineVGPRSavings(CombineVGPRSavings) {
370   const Function &F = MF.getFunction();
371   const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
372   setRegLimits(ST.getMaxNumSGPRs(F), ST.getMaxNumVGPRs(F), MF);
373 }
374 
/// Creates a target from explicitly requested register counts. \p NumSGPRs
/// and \p NumVGPRs are forwarded to setRegLimits(), which caps them at what
/// the subtarget can address.
GCNRPTarget::GCNRPTarget(unsigned NumSGPRs, unsigned NumVGPRs,
                         const MachineFunction &MF, const GCNRegPressure &RP,
                         bool CombineVGPRSavings)
    : RP(RP), CombineVGPRSavings(CombineVGPRSavings) {
  setRegLimits(NumSGPRs, NumVGPRs, MF);
}
381 
382 GCNRPTarget::GCNRPTarget(unsigned Occupancy, const MachineFunction &MF,
383                          const GCNRegPressure &RP, bool CombineVGPRSavings)
384     : RP(RP), CombineVGPRSavings(CombineVGPRSavings) {
385   const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
386   unsigned DynamicVGPRBlockSize =
387       MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize();
388   setRegLimits(ST.getMaxNumSGPRs(Occupancy, /*Addressable=*/false),
389                ST.getMaxNumVGPRs(Occupancy, DynamicVGPRBlockSize), MF);
390 }
391 
392 void GCNRPTarget::setRegLimits(unsigned NumSGPRs, unsigned NumVGPRs,
393                                const MachineFunction &MF) {
394   const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
395   unsigned DynamicVGPRBlockSize =
396       MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize();
397   MaxSGPRs = std::min(ST.getAddressableNumSGPRs(), NumSGPRs);
398   MaxVGPRs = std::min(ST.getAddressableNumArchVGPRs(), NumVGPRs);
399   MaxUnifiedVGPRs =
400       ST.hasGFX90AInsts()
401           ? std::min(ST.getAddressableNumVGPRs(DynamicVGPRBlockSize), NumVGPRs)
402           : 0;
403 }
404 
405 bool GCNRPTarget::isSaveBeneficial(Register Reg,
406                                    const MachineRegisterInfo &MRI) const {
407   const TargetRegisterClass *RC = MRI.getRegClass(Reg);
408   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
409   const SIRegisterInfo *SRI = static_cast<const SIRegisterInfo *>(TRI);
410 
411   if (SRI->isSGPRClass(RC))
412     return RP.getSGPRNum() > MaxSGPRs;
413   unsigned NumVGPRs =
414       SRI->isAGPRClass(RC) ? RP.getAGPRNum() : RP.getArchVGPRNum();
415   return isVGPRBankSaveBeneficial(NumVGPRs);
416 }
417 
418 bool GCNRPTarget::satisfied() const {
419   if (RP.getSGPRNum() > MaxSGPRs)
420     return false;
421   if (RP.getVGPRNum(false) > MaxVGPRs &&
422       (!CombineVGPRSavings || !satisifiesVGPRBanksTarget()))
423     return false;
424   return satisfiesUnifiedTarget();
425 }
426 
427 ///////////////////////////////////////////////////////////////////////////////
428 // GCNRPTracker
429 
/// Convenience overload: looks up the live interval of \p Reg in \p LIS and
/// forwards to the LiveInterval overload of getLiveLaneMask().
LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI,
                                  const LiveIntervals &LIS,
                                  const MachineRegisterInfo &MRI,
                                  LaneBitmask LaneMaskFilter) {
  return getLiveLaneMask(LIS.getInterval(Reg), SI, MRI, LaneMaskFilter);
}
436 
437 LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
438                                   const MachineRegisterInfo &MRI,
439                                   LaneBitmask LaneMaskFilter) {
440   LaneBitmask LiveMask;
441   if (LI.hasSubRanges()) {
442     for (const auto &S : LI.subranges())
443       if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(SI)) {
444         LiveMask |= S.LaneMask;
445         assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg())));
446       }
447   } else if (LI.liveAt(SI)) {
448     LiveMask = MRI.getMaxLaneMaskForVReg(LI.reg());
449   }
450   LiveMask &= LaneMaskFilter;
451   return LiveMask;
452 }
453 
454 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
455                                            const LiveIntervals &LIS,
456                                            const MachineRegisterInfo &MRI) {
457   GCNRPTracker::LiveRegSet LiveRegs;
458   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
459     auto Reg = Register::index2VirtReg(I);
460     if (!LIS.hasInterval(Reg))
461       continue;
462     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
463     if (LiveMask.any())
464       LiveRegs[Reg] = LiveMask;
465   }
466   return LiveRegs;
467 }
468 
469 void GCNRPTracker::reset(const MachineInstr &MI,
470                          const LiveRegSet *LiveRegsCopy,
471                          bool After) {
472   const MachineFunction &MF = *MI.getMF();
473   MRI = &MF.getRegInfo();
474   if (LiveRegsCopy) {
475     if (&LiveRegs != LiveRegsCopy)
476       LiveRegs = *LiveRegsCopy;
477   } else {
478     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
479                      : getLiveRegsBefore(MI, LIS);
480   }
481 
482   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
483 }
484 
485 void GCNRPTracker::reset(const MachineRegisterInfo &MRI_,
486                          const LiveRegSet &LiveRegs_) {
487   MRI = &MRI_;
488   LiveRegs = LiveRegs_;
489   LastTrackedMI = nullptr;
490   MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
491 }
492 
493 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp
494 LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit,
495                                            SlotIndex Pos) const {
496   return getLanesWithProperty(
497       LIS, *MRI, true, RegUnit, Pos.getBaseIndex(), LaneBitmask::getNone(),
498       [](const LiveRange &LR, SlotIndex Pos) {
499         const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
500         return S != nullptr && S->end == Pos.getRegSlot();
501       });
502 }
503 
504 ////////////////////////////////////////////////////////////////////////////////
505 // GCNUpwardRPTracker
506 
/// Moves the tracker upward across \p MI: the lanes defined by \p MI are
/// removed from the live set, the lanes it reads are added, and CurPressure /
/// MaxPressure are updated accordingly. For debug instructions only
/// LastTrackedMI is updated.
void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
  assert(MRI && "call reset first");

  LastTrackedMI = &MI;

  if (MI.isDebugInstr())
    return;

  // Kill all defs.
  // DefPressure / ECDefPressure accumulate the pressure of MI's regular and
  // early-clobber virtual register defs respectively.
  GCNRegPressure DefPressure, ECDefPressure;
  bool HasECDefs = false;
  for (const MachineOperand &MO : MI.all_defs()) {
    if (!MO.getReg().isVirtual())
      continue;

    Register Reg = MO.getReg();
    LaneBitmask DefMask = getDefRegMask(MO, *MRI);

    // Treat a def as fully live at the moment of definition: keep a record.
    if (MO.isEarlyClobber()) {
      ECDefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);
      HasECDefs = true;
    } else
      DefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);

    // A def of a register that is not in the live set needs no further
    // bookkeeping.
    auto I = LiveRegs.find(Reg);
    if (I == LiveRegs.end())
      continue;

    // Drop the defined lanes from the live mask and lower the pressure.
    LaneBitmask &LiveMask = I->second;
    LaneBitmask PrevMask = LiveMask;
    LiveMask &= ~DefMask;
    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
    if (LiveMask.none())
      LiveRegs.erase(I);
  }

  // Update MaxPressure with defs pressure.
  DefPressure += CurPressure;
  if (HasECDefs)
    DefPressure += ECDefPressure;
  MaxPressure = max(DefPressure, MaxPressure);

  // Make uses alive.
  SmallVector<VRegMaskOrUnit, 8> RegUses;
  collectVirtualRegUses(RegUses, MI, LIS, *MRI);
  for (const VRegMaskOrUnit &U : RegUses) {
    // operator[] inserts the register with an empty mask if it was not live.
    LaneBitmask &LiveMask = LiveRegs[U.RegUnit];
    LaneBitmask PrevMask = LiveMask;
    LiveMask |= U.LaneMask;
    CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
  }

  // Update MaxPressure with uses plus early-clobber defs pressure.
  MaxPressure = HasECDefs ? max(CurPressure + ECDefPressure, MaxPressure)
                          : max(CurPressure, MaxPressure);

  // Sanity check: incremental tracking must agree with a full recomputation.
  assert(CurPressure == getRegPressure(*MRI, LiveRegs));
}
566 
567 ////////////////////////////////////////////////////////////////////////////////
568 // GCNDownwardRPTracker
569 
570 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
571                                  const LiveRegSet *LiveRegsCopy) {
572   MRI = &MI.getParent()->getParent()->getRegInfo();
573   LastTrackedMI = nullptr;
574   MBBEnd = MI.getParent()->end();
575   NextMI = &MI;
576   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
577   if (NextMI == MBBEnd)
578     return false;
579   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
580   return true;
581 }
582 
/// Removes from the live set the registers / lanes touched by the current
/// instruction that are no longer live at the next position, and updates
/// CurPressure and MaxPressure.
///
/// With \p UseInternalIterator the current instruction is LastTrackedMI and
/// the position is the base index of NextMI (or the dead slot of LastTrackedMI
/// when NextMI is the block end); if there is no LastTrackedMI the function
/// returns immediately. Otherwise \p MI is the current instruction, its defs
/// are skipped, and the position is the base index of \p MI.
///
/// \returns true only when the internal iterator is used and has reached the
/// end of the block.
bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI,
                                             bool UseInternalIterator) {
  assert(MRI && "call reset first");
  SlotIndex SI;
  const MachineInstr *CurrMI;
  if (UseInternalIterator) {
    if (!LastTrackedMI)
      return NextMI == MBBEnd;

    assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
    CurrMI = LastTrackedMI;

    SI = NextMI == MBBEnd
             ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
             : LIS.getInstructionIndex(*NextMI).getBaseIndex();
  } else { //! UseInternalIterator
    SI = LIS.getInstructionIndex(*MI).getBaseIndex();
    CurrMI = MI;
  }

  assert(SI.isValid());

  // Remove dead registers or mask bits.
  // SeenRegs makes sure each register is processed once even if it appears in
  // several operands.
  SmallSet<Register, 8> SeenRegs;
  for (auto &MO : CurrMI->operands()) {
    if (!MO.isReg() || !MO.getReg().isVirtual())
      continue;
    if (MO.isUse() && !MO.readsReg())
      continue;
    if (!UseInternalIterator && MO.isDef())
      continue;
    if (!SeenRegs.insert(MO.getReg()).second)
      continue;
    const LiveInterval &LI = LIS.getInterval(MO.getReg());
    if (LI.hasSubRanges()) {
      // The LiveRegs lookup is done lazily, on the first dead subrange.
      auto It = LiveRegs.end();
      for (const auto &S : LI.subranges()) {
        if (!S.liveAt(SI)) {
          if (It == LiveRegs.end()) {
            It = LiveRegs.find(MO.getReg());
            if (It == LiveRegs.end())
              llvm_unreachable("register isn't live");
          }
          // Clear the lanes of the dead subrange and lower the pressure.
          auto PrevMask = It->second;
          It->second &= ~S.LaneMask;
          CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI);
        }
      }
      if (It != LiveRegs.end() && It->second.none())
        LiveRegs.erase(It);
    } else if (!LI.liveAt(SI)) {
      // No subranges: the whole register dies at once.
      auto It = LiveRegs.find(MO.getReg());
      if (It == LiveRegs.end())
        llvm_unreachable("register isn't live");
      CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI);
      LiveRegs.erase(It);
    }
  }

  MaxPressure = max(MaxPressure, CurPressure);

  LastTrackedMI = nullptr;

  return UseInternalIterator && (NextMI == MBBEnd);
}
648 
649 void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI,
650                                          bool UseInternalIterator) {
651   if (UseInternalIterator) {
652     LastTrackedMI = &*NextMI++;
653     NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
654   } else {
655     LastTrackedMI = MI;
656   }
657 
658   const MachineInstr *CurrMI = LastTrackedMI;
659 
660   // Add new registers or mask bits.
661   for (const auto &MO : CurrMI->all_defs()) {
662     Register Reg = MO.getReg();
663     if (!Reg.isVirtual())
664       continue;
665     auto &LiveMask = LiveRegs[Reg];
666     auto PrevMask = LiveMask;
667     LiveMask |= getDefRegMask(MO, *MRI);
668     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
669   }
670 
671   MaxPressure = max(MaxPressure, CurPressure);
672 }
673 
674 bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) {
675   if (UseInternalIterator && NextMI == MBBEnd)
676     return false;
677 
678   advanceBeforeNext(MI, UseInternalIterator);
679   advanceToNext(MI, UseInternalIterator);
680   if (!UseInternalIterator) {
681     // We must remove any dead def lanes from the current RP
682     advanceBeforeNext(MI, true);
683   }
684   return true;
685 }
686 
687 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
688   while (NextMI != End)
689     if (!advance()) return false;
690   return true;
691 }
692 
/// Resets the tracker at \p Begin (optionally seeding the live set from
/// \p LiveRegsCopy) and then advances it up to \p End.
bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
                                   MachineBasicBlock::const_iterator End,
                                   const LiveRegSet *LiveRegsCopy) {
  // NOTE(review): the bool result of reset() is discarded here; reset()
  // returns false without resetting the base tracker when only debug
  // instructions follow Begin. Confirm callers never hit that case.
  reset(*Begin, LiveRegsCopy);
  return advance(End);
}
699 
700 Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
701                                const GCNRPTracker::LiveRegSet &TrackedLR,
702                                const TargetRegisterInfo *TRI, StringRef Pfx) {
703   return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) {
704     for (auto const &P : TrackedLR) {
705       auto I = LISLR.find(P.first);
706       if (I == LISLR.end()) {
707         OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
708            << " isn't found in LIS reported set\n";
709       } else if (I->second != P.second) {
710         OS << Pfx << printReg(P.first, TRI)
711            << " masks doesn't match: LIS reported " << PrintLaneMask(I->second)
712            << ", tracked " << PrintLaneMask(P.second) << '\n';
713       }
714     }
715     for (auto const &P : LISLR) {
716       auto I = TrackedLR.find(P.first);
717       if (I == TrackedLR.end()) {
718         OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
719            << " isn't found in tracked set\n";
720       }
721     }
722   });
723 }
724 
/// Computes the register pressure that would result from tracking \p MI next,
/// without modifying the tracker: starting from a copy of CurPressure, lanes
/// whose last use is \p MI are removed and lanes defined by \p MI are added.
/// \returns the resulting pressure.
GCNRegPressure
GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
                                           const SIRegisterInfo *TRI) const {
  assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction.");

  SlotIndex SlotIdx;
  SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot();

  // Account for register pressure similar to RegPressureTracker::recede().
  RegisterOperands RegOpers;
  RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/false);
  RegOpers.adjustLaneLiveness(LIS, *MRI, SlotIdx);
  // All updates go to this copy; CurPressure itself is left untouched.
  GCNRegPressure TempPressure = CurPressure;

  for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
    Register Reg = Use.RegUnit;
    if (!Reg.isVirtual())
      continue;
    LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
    if (LastUseMask.none())
      continue;
    // The LastUseMask is queried from the liveness information of instruction
    // which may be further down the schedule. Some lanes may actually not be
    // last uses for the current position.
    // FIXME: allow the caller to pass in the list of vreg uses that remain
    // to be bottom-scheduled to avoid searching uses at each query.
    SlotIndex CurrIdx;
    const MachineBasicBlock *MBB = MI->getParent();
    // Current position: first non-debug instruction starting at LastTrackedMI
    // (or at the block start if nothing has been tracked yet).
    MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward(
        LastTrackedMI ? LastTrackedMI : MBB->begin(), MBB->end());
    if (IdxPos == MBB->end()) {
      CurrIdx = LIS.getMBBEndIdx(MBB);
    } else {
      CurrIdx = LIS.getInstructionIndex(*IdxPos).getRegSlot();
    }

    // Lanes that are still read by an instruction in [CurrIdx, SlotIdx) are
    // not last uses at MI.
    LastUseMask =
        findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS);
    if (LastUseMask.none())
      continue;

    // Remove the last-used lanes from the register's live mask.
    auto It = LiveRegs.find(Reg);
    LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0);
    LaneBitmask NewMask = LiveMask & ~LastUseMask;
    TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
  }

  // Generate liveness for defs.
  for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
    Register Reg = Def.RegUnit;
    if (!Reg.isVirtual())
      continue;
    auto It = LiveRegs.find(Reg);
    LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0);
    LaneBitmask NewMask = LiveMask | Def.LaneMask;
    TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
  }

  return TempPressure;
}
785 
786 bool GCNUpwardRPTracker::isValid() const {
787   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
788   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
789   const auto &TrackedLR = LiveRegs;
790 
791   if (!isEqual(LISLR, TrackedLR)) {
792     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
793               " LIS reported livesets mismatch:\n"
794            << print(LISLR, *MRI);
795     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
796     return false;
797   }
798 
799   auto LISPressure = getRegPressure(*MRI, LISLR);
800   if (LISPressure != CurPressure) {
801     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "
802            << print(CurPressure) << "LIS rpt: " << print(LISPressure);
803     return false;
804   }
805   return true;
806 }
807 
808 Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
809                       const MachineRegisterInfo &MRI) {
810   return Printable([&LiveRegs, &MRI](raw_ostream &OS) {
811     const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
812     for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
813       Register Reg = Register::index2VirtReg(I);
814       auto It = LiveRegs.find(Reg);
815       if (It != LiveRegs.end() && It->second.any())
816         OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
817            << PrintLaneMask(It->second);
818     }
819     OS << '\n';
820   });
821 }
822 
823 void GCNRegPressure::dump() const { dbgs() << print(*this); }
824 
// Hidden command-line switch, off by default. When set,
// GCNRegPressurePrinter::runOnMachineFunction walks each block with
// GCNDownwardRPTracker instead of GCNUpwardRPTracker.
static cl::opt<bool> UseDownwardTracker(
    "amdgpu-print-rp-downward",
    cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"),
    cl::init(false), cl::Hidden);

// Pass identification storage and the externally visible reference to it.
char llvm::GCNRegPressurePrinter::ID = 0;
char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID;

// Register the pass under the argument name "amdgpu-print-rp" with an empty
// description string. The two trailing flags are the macro's CFG-only and
// is-analysis arguments (see llvm/PassSupport.h).
INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true)
834 
835 // Return lanemask of Reg's subregs that are live-through at [Begin, End] and
836 // are fully covered by Mask.
837 static LaneBitmask
838 getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS,
839                       Register Reg, SlotIndex Begin, SlotIndex End,
840                       LaneBitmask Mask = LaneBitmask::getAll()) {
841 
842   auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool {
843     auto *Segment = LR.getSegmentContaining(Begin);
844     return Segment && Segment->contains(End);
845   };
846 
847   LaneBitmask LiveThroughMask;
848   const LiveInterval &LI = LIS.getInterval(Reg);
849   if (LI.hasSubRanges()) {
850     for (auto &SR : LI.subranges()) {
851       if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR))
852         LiveThroughMask |= SR.LaneMask;
853     }
854   } else {
855     LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg);
856     if ((RegMask & Mask) == RegMask && IsInOneSegment(LI))
857       LiveThroughMask = RegMask;
858   }
859 
860   return LiveThroughMask;
861 }
862 
/// Print a register pressure report for every basic block of \p MF to dbgs().
///
/// The report starts with "---" and a "name:" line and ends with "...". For
/// each block it prints the live-in set, one SGPR/VGPR pressure pair before
/// and at every non-debug instruction, the pressure at the end of the block,
/// the live-out set, and the live-through set with its pressure. The
/// UseDownwardTracker option selects whether the numbers are produced by
/// GCNDownwardRPTracker or GCNUpwardRPTracker. When the tracked live set at
/// the far end of the walk differs from what LiveIntervals reports, a
/// "mis LIS:" note is added.
///
/// \returns false unconditionally.
bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) {
  const MachineRegisterInfo &MRI = MF.getRegInfo();
  const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
  const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();

  auto &OS = dbgs();

// Leading spaces are important for YAML syntax.
#define PFX "  "

  OS << "---\nname: " << MF.getName() << "\nbody:             |\n";

  // Formats one pressure value as two left-aligned, 5-wide columns: the SGPR
  // count followed by the VGPR count. RP is captured by reference, so the
  // returned Printable must be streamed while RP is still alive.
  auto printRP = [](const GCNRegPressure &RP) {
    return Printable([&RP](raw_ostream &OS) {
      OS << format(PFX "  %-5d", RP.getSGPRNum())
         << format(" %-5d", RP.getVGPRNum(false));
    });
  };

  // Emits the LIS live set and the mismatch details when the tracked live set
  // differs from the one computed from LiveIntervals; prints nothing
  // otherwise.
  auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR,
                                    const GCNRPTracker::LiveRegSet &LISLR) {
    if (LISLR != TrackedLR) {
      OS << PFX "  mis LIS: " << llvm::print(LISLR, MRI)
         << reportMismatch(LISLR, TrackedLR, TRI, PFX "    ");
    }
  };

  // Register pressure before and at an instruction (in program order).
  // NOTE: the upward walk below fills this in reverse program order; the
  // printing loop compensates by indexing from the back.
  SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP;

  for (auto &MBB : MF) {
    // Reuse the buffer across blocks.
    RP.clear();
    RP.reserve(MBB.size());

    OS << PFX;
    MBB.printName(OS);
    OS << ":\n";

    SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(&MBB);
    SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(&MBB);

    GCNRPTracker::LiveRegSet LiveIn, LiveOut;
    GCNRegPressure RPAtMBBEnd;

    if (UseDownwardTracker) {
      if (MBB.empty()) {
        // Nothing to walk: take the live set at the block start for both
        // ends and derive the pressure from it.
        LiveIn = LiveOut = getLiveRegs(MBBStartSlot, LIS, MRI);
        RPAtMBBEnd = getRegPressure(MRI, LiveIn);
      } else {
        GCNDownwardRPTracker RPT(LIS);
        RPT.reset(MBB.front());

        LiveIn = RPT.getLiveRegs();

        // Record the pressure just before and just after each tracked
        // instruction. The loop stops once advanceBeforeNext() returns true
        // (presumably when the end of the block is reached; confirm against
        // GCNDownwardRPTracker).
        while (!RPT.advanceBeforeNext()) {
          GCNRegPressure RPBeforeMI = RPT.getPressure();
          RPT.advanceToNext();
          RP.emplace_back(RPBeforeMI, RPT.getPressure());
        }

        LiveOut = RPT.getLiveRegs();
        RPAtMBBEnd = RPT.getPressure();
      }
    } else {
      GCNUpwardRPTracker RPT(LIS);
      RPT.reset(MRI, MBBEndSlot);

      LiveOut = RPT.getLiveRegs();
      RPAtMBBEnd = RPT.getPressure();

      // Walk the block bottom-up. The max pressure is reset before every
      // step so that getMaxPressure() reflects only the instruction just
      // receded over. Debug instructions are receded over but get no entry.
      for (auto &MI : reverse(MBB)) {
        RPT.resetMaxPressure();
        RPT.recede(MI);
        if (!MI.isDebugInstr())
          RP.emplace_back(RPT.getPressure(), RPT.getMaxPressure());
      }

      LiveIn = RPT.getLiveRegs();
    }

    OS << PFX "  Live-in: " << llvm::print(LiveIn, MRI);
    // The upward walk ends at the block start, so that is where its tracked
    // set is compared against LiveIntervals.
    if (!UseDownwardTracker)
      ReportLISMismatchIfAny(LiveIn, getLiveRegs(MBBStartSlot, LIS, MRI));

    OS << PFX "  SGPR  VGPR\n";
    // Counts the non-debug instructions printed so far; used to pick the
    // matching entry of RP.
    int I = 0;
    for (auto &MI : MBB) {
      if (!MI.isDebugInstr()) {
        // Downward entries are in program order; upward entries were pushed
        // in reverse, so index them from the back.
        auto &[RPBeforeInstr, RPAtInstr] =
            RP[UseDownwardTracker ? I : (RP.size() - 1 - I)];
        ++I;
        OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << "  ";
      } else
        // Debug instructions have no pressure entry; pad so that the
        // instruction text lines up with the other rows.
        OS << PFX "               ";
      MI.print(OS);
    }
    OS << printRP(RPAtMBBEnd) << '\n';

    OS << PFX "  Live-out:" << llvm::print(LiveOut, MRI);
    // The downward walk ends at the block end, so that is where its tracked
    // set is compared against LiveIntervals.
    if (UseDownwardTracker)
      ReportLISMismatchIfAny(LiveOut, getLiveRegs(MBBEndSlot, LIS, MRI));

    // Live-through registers: lanes that are both live-in and live-out and
    // that getRegLiveThroughMask() confirms for the block's slot range.
    GCNRPTracker::LiveRegSet LiveThrough;
    for (auto [Reg, Mask] : LiveIn) {
      LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Reg);
      if (MaskIntersection.any()) {
        LaneBitmask LTMask = getRegLiveThroughMask(
            MRI, LIS, Reg, MBBStartSlot, MBBEndSlot, MaskIntersection);
        if (LTMask.any())
          LiveThrough[Reg] = LTMask;
      }
    }
    OS << PFX "  Live-thr:" << llvm::print(LiveThrough, MRI);
    OS << printRP(getRegPressure(MRI, LiveThrough)) << '\n';
  }
  OS << "...\n";
  return false;

#undef PFX
}
983