xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp (revision 1165fc9a526630487a1feb63daef65c5aee1a583)
1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "GCNRegPressure.h"
15 #include "llvm/CodeGen/RegisterPressure.h"
16 
17 using namespace llvm;
18 
19 #define DEBUG_TYPE "machine-scheduler"
20 
21 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
22 LLVM_DUMP_METHOD
23 void llvm::printLivesAt(SlotIndex SI,
24                         const LiveIntervals &LIS,
25                         const MachineRegisterInfo &MRI) {
26   dbgs() << "Live regs at " << SI << ": "
27          << *LIS.getInstructionFromIndex(SI);
28   unsigned Num = 0;
29   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
30     const Register Reg = Register::index2VirtReg(I);
31     if (!LIS.hasInterval(Reg))
32       continue;
33     const auto &LI = LIS.getInterval(Reg);
34     if (LI.hasSubRanges()) {
35       bool firstTime = true;
36       for (const auto &S : LI.subranges()) {
37         if (!S.liveAt(SI)) continue;
38         if (firstTime) {
39           dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
40                  << '\n';
41           firstTime = false;
42         }
43         dbgs() << "  " << S << '\n';
44         ++Num;
45       }
46     } else if (LI.liveAt(SI)) {
47       dbgs() << "  " << LI << '\n';
48       ++Num;
49     }
50   }
51   if (!Num) dbgs() << "  <none>\n";
52 }
53 #endif
54 
55 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
56                    const GCNRPTracker::LiveRegSet &S2) {
57   if (S1.size() != S2.size())
58     return false;
59 
60   for (const auto &P : S1) {
61     auto I = S2.find(P.first);
62     if (I == S2.end() || I->second != P.second)
63       return false;
64   }
65   return true;
66 }
67 
68 
69 ///////////////////////////////////////////////////////////////////////////////
70 // GCNRegPressure
71 
72 unsigned GCNRegPressure::getRegKind(Register Reg,
73                                     const MachineRegisterInfo &MRI) {
74   assert(Reg.isVirtual());
75   const auto RC = MRI.getRegClass(Reg);
76   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
77   return STI->isSGPRClass(RC)
78              ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE)
79          : STI->isAGPRClass(RC)
80              ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE)
81              : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
82 }
83 
84 void GCNRegPressure::inc(unsigned Reg,
85                          LaneBitmask PrevMask,
86                          LaneBitmask NewMask,
87                          const MachineRegisterInfo &MRI) {
88   if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
89       SIRegisterInfo::getNumCoveredRegs(PrevMask))
90     return;
91 
92   int Sign = 1;
93   if (NewMask < PrevMask) {
94     std::swap(NewMask, PrevMask);
95     Sign = -1;
96   }
97 
98   switch (auto Kind = getRegKind(Reg, MRI)) {
99   case SGPR32:
100   case VGPR32:
101   case AGPR32:
102     Value[Kind] += Sign;
103     break;
104 
105   case SGPR_TUPLE:
106   case VGPR_TUPLE:
107   case AGPR_TUPLE:
108     assert(PrevMask < NewMask);
109 
110     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
111       Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
112 
113     if (PrevMask.none()) {
114       assert(NewMask.any());
115       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
116     }
117     break;
118 
119   default: llvm_unreachable("Unknown register kind");
120   }
121 }
122 
123 bool GCNRegPressure::less(const GCNSubtarget &ST,
124                           const GCNRegPressure& O,
125                           unsigned MaxOccupancy) const {
126   const auto SGPROcc = std::min(MaxOccupancy,
127                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
128   const auto VGPROcc =
129     std::min(MaxOccupancy,
130              ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
131   const auto OtherSGPROcc = std::min(MaxOccupancy,
132                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
133   const auto OtherVGPROcc =
134     std::min(MaxOccupancy,
135              ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts())));
136 
137   const auto Occ = std::min(SGPROcc, VGPROcc);
138   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
139   if (Occ != OtherOcc)
140     return Occ > OtherOcc;
141 
142   bool SGPRImportant = SGPROcc < VGPROcc;
143   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
144 
145   // if both pressures disagree on what is more important compare vgprs
146   if (SGPRImportant != OtherSGPRImportant) {
147     SGPRImportant = false;
148   }
149 
150   // compare large regs pressure
151   bool SGPRFirst = SGPRImportant;
152   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
153     if (SGPRFirst) {
154       auto SW = getSGPRTuplesWeight();
155       auto OtherSW = O.getSGPRTuplesWeight();
156       if (SW != OtherSW)
157         return SW < OtherSW;
158     } else {
159       auto VW = getVGPRTuplesWeight();
160       auto OtherVW = O.getVGPRTuplesWeight();
161       if (VW != OtherVW)
162         return VW < OtherVW;
163     }
164   }
165   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
166                          (getVGPRNum(ST.hasGFX90AInsts()) <
167                           O.getVGPRNum(ST.hasGFX90AInsts()));
168 }
169 
170 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
171 LLVM_DUMP_METHOD
172 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
173   OS << "VGPRs: " << Value[VGPR32] << ' ';
174   OS << "AGPRs: " << Value[AGPR32];
175   if (ST) OS << "(O"
176              << ST->getOccupancyWithNumVGPRs(getVGPRNum(ST->hasGFX90AInsts()))
177              << ')';
178   OS << ", SGPRs: " << getSGPRNum();
179   if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
180   OS << ", LVGPR WT: " << getVGPRTuplesWeight()
181      << ", LSGPR WT: " << getSGPRTuplesWeight();
182   if (ST) OS << " -> Occ: " << getOccupancy(*ST);
183   OS << '\n';
184 }
185 #endif
186 
187 static LaneBitmask getDefRegMask(const MachineOperand &MO,
188                                  const MachineRegisterInfo &MRI) {
189   assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
190 
191   // We don't rely on read-undef flag because in case of tentative schedule
192   // tracking it isn't set correctly yet. This works correctly however since
193   // use mask has been tracked before using LIS.
194   return MO.getSubReg() == 0 ?
195     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
196     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
197 }
198 
199 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
200                                   const MachineRegisterInfo &MRI,
201                                   const LiveIntervals &LIS) {
202   assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual());
203 
204   if (auto SubReg = MO.getSubReg())
205     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
206 
207   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
208   if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
209     return MaxMask;
210 
211   // For a tentative schedule LIS isn't updated yet but livemask should remain
212   // the same on any schedule. Subreg defs can be reordered but they all must
213   // dominate uses anyway.
214   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
215   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
216 }
217 
218 static SmallVector<RegisterMaskPair, 8>
219 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
220                       const MachineRegisterInfo &MRI) {
221   SmallVector<RegisterMaskPair, 8> Res;
222   for (const auto &MO : MI.operands()) {
223     if (!MO.isReg() || !MO.getReg().isVirtual())
224       continue;
225     if (!MO.isUse() || !MO.readsReg())
226       continue;
227 
228     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
229 
230     auto Reg = MO.getReg();
231     auto I = llvm::find_if(
232         Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; });
233     if (I != Res.end())
234       I->LaneMask |= UsedMask;
235     else
236       Res.push_back(RegisterMaskPair(Reg, UsedMask));
237   }
238   return Res;
239 }
240 
241 ///////////////////////////////////////////////////////////////////////////////
242 // GCNRPTracker
243 
244 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
245                                   SlotIndex SI,
246                                   const LiveIntervals &LIS,
247                                   const MachineRegisterInfo &MRI) {
248   LaneBitmask LiveMask;
249   const auto &LI = LIS.getInterval(Reg);
250   if (LI.hasSubRanges()) {
251     for (const auto &S : LI.subranges())
252       if (S.liveAt(SI)) {
253         LiveMask |= S.LaneMask;
254         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
255                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
256       }
257   } else if (LI.liveAt(SI)) {
258     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
259   }
260   return LiveMask;
261 }
262 
263 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
264                                            const LiveIntervals &LIS,
265                                            const MachineRegisterInfo &MRI) {
266   GCNRPTracker::LiveRegSet LiveRegs;
267   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
268     auto Reg = Register::index2VirtReg(I);
269     if (!LIS.hasInterval(Reg))
270       continue;
271     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
272     if (LiveMask.any())
273       LiveRegs[Reg] = LiveMask;
274   }
275   return LiveRegs;
276 }
277 
278 void GCNRPTracker::reset(const MachineInstr &MI,
279                          const LiveRegSet *LiveRegsCopy,
280                          bool After) {
281   const MachineFunction &MF = *MI.getMF();
282   MRI = &MF.getRegInfo();
283   if (LiveRegsCopy) {
284     if (&LiveRegs != LiveRegsCopy)
285       LiveRegs = *LiveRegsCopy;
286   } else {
287     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
288                      : getLiveRegsBefore(MI, LIS);
289   }
290 
291   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
292 }
293 
294 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
295                                const LiveRegSet *LiveRegsCopy) {
296   GCNRPTracker::reset(MI, LiveRegsCopy, true);
297 }
298 
299 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
300   assert(MRI && "call reset first");
301 
302   LastTrackedMI = &MI;
303 
304   if (MI.isDebugInstr())
305     return;
306 
307   auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
308 
309   // calc pressure at the MI (defs + uses)
310   auto AtMIPressure = CurPressure;
311   for (const auto &U : RegUses) {
312     auto LiveMask = LiveRegs[U.RegUnit];
313     AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
314   }
315   // update max pressure
316   MaxPressure = max(AtMIPressure, MaxPressure);
317 
318   for (const auto &MO : MI.operands()) {
319     if (!MO.isReg() || !MO.isDef() || !MO.getReg().isVirtual() || MO.isDead())
320       continue;
321 
322     auto Reg = MO.getReg();
323     auto I = LiveRegs.find(Reg);
324     if (I == LiveRegs.end())
325       continue;
326     auto &LiveMask = I->second;
327     auto PrevMask = LiveMask;
328     LiveMask &= ~getDefRegMask(MO, *MRI);
329     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
330     if (LiveMask.none())
331       LiveRegs.erase(I);
332   }
333   for (const auto &U : RegUses) {
334     auto &LiveMask = LiveRegs[U.RegUnit];
335     auto PrevMask = LiveMask;
336     LiveMask |= U.LaneMask;
337     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
338   }
339   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
340 }
341 
342 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
343                                  const LiveRegSet *LiveRegsCopy) {
344   MRI = &MI.getParent()->getParent()->getRegInfo();
345   LastTrackedMI = nullptr;
346   MBBEnd = MI.getParent()->end();
347   NextMI = &MI;
348   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
349   if (NextMI == MBBEnd)
350     return false;
351   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
352   return true;
353 }
354 
355 bool GCNDownwardRPTracker::advanceBeforeNext() {
356   assert(MRI && "call reset first");
357 
358   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
359   if (NextMI == MBBEnd)
360     return false;
361 
362   SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
363   assert(SI.isValid());
364 
365   // Remove dead registers or mask bits.
366   for (auto &It : LiveRegs) {
367     const LiveInterval &LI = LIS.getInterval(It.first);
368     if (LI.hasSubRanges()) {
369       for (const auto &S : LI.subranges()) {
370         if (!S.liveAt(SI)) {
371           auto PrevMask = It.second;
372           It.second &= ~S.LaneMask;
373           CurPressure.inc(It.first, PrevMask, It.second, *MRI);
374         }
375       }
376     } else if (!LI.liveAt(SI)) {
377       auto PrevMask = It.second;
378       It.second = LaneBitmask::getNone();
379       CurPressure.inc(It.first, PrevMask, It.second, *MRI);
380     }
381     if (It.second.none())
382       LiveRegs.erase(It.first);
383   }
384 
385   MaxPressure = max(MaxPressure, CurPressure);
386 
387   return true;
388 }
389 
390 void GCNDownwardRPTracker::advanceToNext() {
391   LastTrackedMI = &*NextMI++;
392   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
393 
394   // Add new registers or mask bits.
395   for (const auto &MO : LastTrackedMI->operands()) {
396     if (!MO.isReg() || !MO.isDef())
397       continue;
398     Register Reg = MO.getReg();
399     if (!Reg.isVirtual())
400       continue;
401     auto &LiveMask = LiveRegs[Reg];
402     auto PrevMask = LiveMask;
403     LiveMask |= getDefRegMask(MO, *MRI);
404     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
405   }
406 
407   MaxPressure = max(MaxPressure, CurPressure);
408 }
409 
410 bool GCNDownwardRPTracker::advance() {
411   // If we have just called reset live set is actual.
412   if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
413     return false;
414   advanceToNext();
415   return true;
416 }
417 
418 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
419   while (NextMI != End)
420     if (!advance()) return false;
421   return true;
422 }
423 
424 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
425                                    MachineBasicBlock::const_iterator End,
426                                    const LiveRegSet *LiveRegsCopy) {
427   reset(*Begin, LiveRegsCopy);
428   return advance(End);
429 }
430 
431 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
432 LLVM_DUMP_METHOD
433 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
434                            const GCNRPTracker::LiveRegSet &TrackedLR,
435                            const TargetRegisterInfo *TRI) {
436   for (auto const &P : TrackedLR) {
437     auto I = LISLR.find(P.first);
438     if (I == LISLR.end()) {
439       dbgs() << "  " << printReg(P.first, TRI)
440              << ":L" << PrintLaneMask(P.second)
441              << " isn't found in LIS reported set\n";
442     }
443     else if (I->second != P.second) {
444       dbgs() << "  " << printReg(P.first, TRI)
445         << " masks doesn't match: LIS reported "
446         << PrintLaneMask(I->second)
447         << ", tracked "
448         << PrintLaneMask(P.second)
449         << '\n';
450     }
451   }
452   for (auto const &P : LISLR) {
453     auto I = TrackedLR.find(P.first);
454     if (I == TrackedLR.end()) {
455       dbgs() << "  " << printReg(P.first, TRI)
456              << ":L" << PrintLaneMask(P.second)
457              << " isn't found in tracked set\n";
458     }
459   }
460 }
461 
462 bool GCNUpwardRPTracker::isValid() const {
463   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
464   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
465   const auto &TrackedLR = LiveRegs;
466 
467   if (!isEqual(LISLR, TrackedLR)) {
468     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
469               " LIS reported livesets mismatch:\n";
470     printLivesAt(SI, LIS, *MRI);
471     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
472     return false;
473   }
474 
475   auto LISPressure = getRegPressure(*MRI, LISLR);
476   if (LISPressure != CurPressure) {
477     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
478     CurPressure.print(dbgs());
479     dbgs() << "LIS rpt: ";
480     LISPressure.print(dbgs());
481     return false;
482   }
483   return true;
484 }
485 
486 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
487                                  const MachineRegisterInfo &MRI) {
488   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
489   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
490     Register Reg = Register::index2VirtReg(I);
491     auto It = LiveRegs.find(Reg);
492     if (It != LiveRegs.end() && It->second.any())
493       OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
494          << PrintLaneMask(It->second);
495   }
496   OS << '\n';
497 }
498 #endif
499