xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp (revision 53071ed1c96db7f89defc99c95b0ad1031d48f45)
1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GCNRegPressure.h"
10 #include "AMDGPUSubtarget.h"
11 #include "SIRegisterInfo.h"
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/CodeGen/LiveInterval.h"
14 #include "llvm/CodeGen/LiveIntervals.h"
15 #include "llvm/CodeGen/MachineInstr.h"
16 #include "llvm/CodeGen/MachineOperand.h"
17 #include "llvm/CodeGen/MachineRegisterInfo.h"
18 #include "llvm/CodeGen/RegisterPressure.h"
19 #include "llvm/CodeGen/SlotIndexes.h"
20 #include "llvm/CodeGen/TargetRegisterInfo.h"
21 #include "llvm/Config/llvm-config.h"
22 #include "llvm/MC/LaneBitmask.h"
23 #include "llvm/Support/Compiler.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include <algorithm>
28 #include <cassert>
29 
30 using namespace llvm;
31 
32 #define DEBUG_TYPE "machine-scheduler"
33 
34 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
35 LLVM_DUMP_METHOD
36 void llvm::printLivesAt(SlotIndex SI,
37                         const LiveIntervals &LIS,
38                         const MachineRegisterInfo &MRI) {
39   dbgs() << "Live regs at " << SI << ": "
40          << *LIS.getInstructionFromIndex(SI);
41   unsigned Num = 0;
42   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
43     const unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
44     if (!LIS.hasInterval(Reg))
45       continue;
46     const auto &LI = LIS.getInterval(Reg);
47     if (LI.hasSubRanges()) {
48       bool firstTime = true;
49       for (const auto &S : LI.subranges()) {
50         if (!S.liveAt(SI)) continue;
51         if (firstTime) {
52           dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
53                  << '\n';
54           firstTime = false;
55         }
56         dbgs() << "  " << S << '\n';
57         ++Num;
58       }
59     } else if (LI.liveAt(SI)) {
60       dbgs() << "  " << LI << '\n';
61       ++Num;
62     }
63   }
64   if (!Num) dbgs() << "  <none>\n";
65 }
66 #endif
67 
68 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
69                    const GCNRPTracker::LiveRegSet &S2) {
70   if (S1.size() != S2.size())
71     return false;
72 
73   for (const auto &P : S1) {
74     auto I = S2.find(P.first);
75     if (I == S2.end() || I->second != P.second)
76       return false;
77   }
78   return true;
79 }
80 
81 
82 ///////////////////////////////////////////////////////////////////////////////
83 // GCNRegPressure
84 
85 unsigned GCNRegPressure::getRegKind(unsigned Reg,
86                                     const MachineRegisterInfo &MRI) {
87   assert(TargetRegisterInfo::isVirtualRegister(Reg));
88   const auto RC = MRI.getRegClass(Reg);
89   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
90   return STI->isSGPRClass(RC) ?
91     (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
92     STI->hasAGPRs(RC) ?
93       (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
94       (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
95 }
96 
97 void GCNRegPressure::inc(unsigned Reg,
98                          LaneBitmask PrevMask,
99                          LaneBitmask NewMask,
100                          const MachineRegisterInfo &MRI) {
101   if (NewMask == PrevMask)
102     return;
103 
104   int Sign = 1;
105   if (NewMask < PrevMask) {
106     std::swap(NewMask, PrevMask);
107     Sign = -1;
108   }
109 #ifndef NDEBUG
110   const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg);
111 #endif
112   switch (auto Kind = getRegKind(Reg, MRI)) {
113   case SGPR32:
114   case VGPR32:
115   case AGPR32:
116     assert(PrevMask.none() && NewMask == MaxMask);
117     Value[Kind] += Sign;
118     break;
119 
120   case SGPR_TUPLE:
121   case VGPR_TUPLE:
122   case AGPR_TUPLE:
123     assert(NewMask < MaxMask || NewMask == MaxMask);
124     assert(PrevMask < NewMask);
125 
126     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
127       Sign * (~PrevMask & NewMask).getNumLanes();
128 
129     if (PrevMask.none()) {
130       assert(NewMask.any());
131       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
132     }
133     break;
134 
135   default: llvm_unreachable("Unknown register kind");
136   }
137 }
138 
139 bool GCNRegPressure::less(const GCNSubtarget &ST,
140                           const GCNRegPressure& O,
141                           unsigned MaxOccupancy) const {
142   const auto SGPROcc = std::min(MaxOccupancy,
143                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
144   const auto VGPROcc = std::min(MaxOccupancy,
145                                 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
146   const auto OtherSGPROcc = std::min(MaxOccupancy,
147                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
148   const auto OtherVGPROcc = std::min(MaxOccupancy,
149                                 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
150 
151   const auto Occ = std::min(SGPROcc, VGPROcc);
152   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
153   if (Occ != OtherOcc)
154     return Occ > OtherOcc;
155 
156   bool SGPRImportant = SGPROcc < VGPROcc;
157   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
158 
159   // if both pressures disagree on what is more important compare vgprs
160   if (SGPRImportant != OtherSGPRImportant) {
161     SGPRImportant = false;
162   }
163 
164   // compare large regs pressure
165   bool SGPRFirst = SGPRImportant;
166   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
167     if (SGPRFirst) {
168       auto SW = getSGPRTuplesWeight();
169       auto OtherSW = O.getSGPRTuplesWeight();
170       if (SW != OtherSW)
171         return SW < OtherSW;
172     } else {
173       auto VW = getVGPRTuplesWeight();
174       auto OtherVW = O.getVGPRTuplesWeight();
175       if (VW != OtherVW)
176         return VW < OtherVW;
177     }
178   }
179   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
180                          (getVGPRNum() < O.getVGPRNum());
181 }
182 
183 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
184 LLVM_DUMP_METHOD
185 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
186   OS << "VGPRs: " << getVGPRNum();
187   if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
188   OS << ", SGPRs: " << getSGPRNum();
189   if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
190   OS << ", LVGPR WT: " << getVGPRTuplesWeight()
191      << ", LSGPR WT: " << getSGPRTuplesWeight();
192   if (ST) OS << " -> Occ: " << getOccupancy(*ST);
193   OS << '\n';
194 }
195 #endif
196 
197 static LaneBitmask getDefRegMask(const MachineOperand &MO,
198                                  const MachineRegisterInfo &MRI) {
199   assert(MO.isDef() && MO.isReg() &&
200     TargetRegisterInfo::isVirtualRegister(MO.getReg()));
201 
202   // We don't rely on read-undef flag because in case of tentative schedule
203   // tracking it isn't set correctly yet. This works correctly however since
204   // use mask has been tracked before using LIS.
205   return MO.getSubReg() == 0 ?
206     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
207     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
208 }
209 
210 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
211                                   const MachineRegisterInfo &MRI,
212                                   const LiveIntervals &LIS) {
213   assert(MO.isUse() && MO.isReg() &&
214          TargetRegisterInfo::isVirtualRegister(MO.getReg()));
215 
216   if (auto SubReg = MO.getSubReg())
217     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
218 
219   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
220   if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs
221     return MaxMask;
222 
223   // For a tentative schedule LIS isn't updated yet but livemask should remain
224   // the same on any schedule. Subreg defs can be reordered but they all must
225   // dominate uses anyway.
226   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
227   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
228 }
229 
230 static SmallVector<RegisterMaskPair, 8>
231 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
232                       const MachineRegisterInfo &MRI) {
233   SmallVector<RegisterMaskPair, 8> Res;
234   for (const auto &MO : MI.operands()) {
235     if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()))
236       continue;
237     if (!MO.isUse() || !MO.readsReg())
238       continue;
239 
240     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
241 
242     auto Reg = MO.getReg();
243     auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
244       return RM.RegUnit == Reg;
245     });
246     if (I != Res.end())
247       I->LaneMask |= UsedMask;
248     else
249       Res.push_back(RegisterMaskPair(Reg, UsedMask));
250   }
251   return Res;
252 }
253 
254 ///////////////////////////////////////////////////////////////////////////////
255 // GCNRPTracker
256 
257 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
258                                   SlotIndex SI,
259                                   const LiveIntervals &LIS,
260                                   const MachineRegisterInfo &MRI) {
261   LaneBitmask LiveMask;
262   const auto &LI = LIS.getInterval(Reg);
263   if (LI.hasSubRanges()) {
264     for (const auto &S : LI.subranges())
265       if (S.liveAt(SI)) {
266         LiveMask |= S.LaneMask;
267         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
268                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
269       }
270   } else if (LI.liveAt(SI)) {
271     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
272   }
273   return LiveMask;
274 }
275 
276 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
277                                            const LiveIntervals &LIS,
278                                            const MachineRegisterInfo &MRI) {
279   GCNRPTracker::LiveRegSet LiveRegs;
280   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
281     auto Reg = TargetRegisterInfo::index2VirtReg(I);
282     if (!LIS.hasInterval(Reg))
283       continue;
284     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
285     if (LiveMask.any())
286       LiveRegs[Reg] = LiveMask;
287   }
288   return LiveRegs;
289 }
290 
291 void GCNRPTracker::reset(const MachineInstr &MI,
292                          const LiveRegSet *LiveRegsCopy,
293                          bool After) {
294   const MachineFunction &MF = *MI.getMF();
295   MRI = &MF.getRegInfo();
296   if (LiveRegsCopy) {
297     if (&LiveRegs != LiveRegsCopy)
298       LiveRegs = *LiveRegsCopy;
299   } else {
300     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
301                      : getLiveRegsBefore(MI, LIS);
302   }
303 
304   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
305 }
306 
307 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
308                                const LiveRegSet *LiveRegsCopy) {
309   GCNRPTracker::reset(MI, LiveRegsCopy, true);
310 }
311 
/// Move the tracker upward across \p MI, updating the live set and the
/// current and maximum pressure.
///
/// reset() must have been called before (MRI is asserted to be set). For a
/// debug instruction only LastTrackedMI is updated.
void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
  assert(MRI && "call reset first");

  LastTrackedMI = &MI;

  if (MI.isDebugInstr())
    return;

  auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);

  // calc pressure at the MI (defs + uses)
  // Works on a copy of the current pressure, so CurPressure itself is not
  // modified by this loop.
  // NOTE(review): if LiveRegSet::operator[] default-inserts missing keys (as
  // map-like containers usually do), this lookup adds an empty-mask entry
  // that the final loop below then fills in - confirm.
  auto AtMIPressure = CurPressure;
  for (const auto &U : RegUses) {
    auto LiveMask = LiveRegs[U.RegUnit];
    AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
  }
  // update max pressure
  MaxPressure = max(AtMIPressure, MaxPressure);

  // Remove the lanes defined by MI from the live set. Non-register,
  // non-virtual and dead defs are skipped, as are registers that are not in
  // the live set.
  for (const auto &MO : MI.defs()) {
    if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()) ||
         MO.isDead())
      continue;

    auto Reg = MO.getReg();
    auto I = LiveRegs.find(Reg);
    if (I == LiveRegs.end())
      continue;
    auto &LiveMask = I->second;
    auto PrevMask = LiveMask;
    LiveMask &= ~getDefRegMask(MO, *MRI);
    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
    // Drop the entry once no lane of the register remains live.
    if (LiveMask.none())
      LiveRegs.erase(I);
  }
  // Add the lanes read by MI to the live set. This runs after the defs, so a
  // register that MI both defines and reads ends up live.
  for (const auto &U : RegUses) {
    auto &LiveMask = LiveRegs[U.RegUnit];
    auto PrevMask = LiveMask;
    LiveMask |= U.LaneMask;
    CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
  }
  // The incrementally maintained pressure must match a recomputation from
  // the live set.
  assert(CurPressure == getRegPressure(*MRI, LiveRegs));
}
355 
356 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
357                                  const LiveRegSet *LiveRegsCopy) {
358   MRI = &MI.getParent()->getParent()->getRegInfo();
359   LastTrackedMI = nullptr;
360   MBBEnd = MI.getParent()->end();
361   NextMI = &MI;
362   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
363   if (NextMI == MBBEnd)
364     return false;
365   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
366   return true;
367 }
368 
/// Bring the live set to the state just before the next non-debug
/// instruction: lanes and registers that LIS no longer reports live at that
/// instruction's base slot index are removed, and the pressure is updated.
///
/// reset() must have been called before (MRI is asserted to be set).
/// \returns false when no non-debug instruction is left in the block.
bool GCNDownwardRPTracker::advanceBeforeNext() {
  assert(MRI && "call reset first");

  NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
  if (NextMI == MBBEnd)
    return false;

  SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
  assert(SI.isValid());

  // Remove dead registers or mask bits.
  // NOTE(review): entries are erased from LiveRegs while it is being
  // range-iterated; this relies on the container tolerating erasure during
  // iteration - confirm for LiveRegSet's underlying type.
  for (auto &It : LiveRegs) {
    const LiveInterval &LI = LIS.getInterval(It.first);
    if (LI.hasSubRanges()) {
      // Clear the lanes of every subrange that is not live at SI.
      for (const auto &S : LI.subranges()) {
        if (!S.liveAt(SI)) {
          auto PrevMask = It.second;
          It.second &= ~S.LaneMask;
          CurPressure.inc(It.first, PrevMask, It.second, *MRI);
        }
      }
    } else if (!LI.liveAt(SI)) {
      // No subranges: the whole register dies at once.
      auto PrevMask = It.second;
      It.second = LaneBitmask::getNone();
      CurPressure.inc(It.first, PrevMask, It.second, *MRI);
    }
    if (It.second.none())
      LiveRegs.erase(It.first);
  }

  MaxPressure = max(MaxPressure, CurPressure);

  return true;
}
403 
404 void GCNDownwardRPTracker::advanceToNext() {
405   LastTrackedMI = &*NextMI++;
406 
407   // Add new registers or mask bits.
408   for (const auto &MO : LastTrackedMI->defs()) {
409     if (!MO.isReg())
410       continue;
411     unsigned Reg = MO.getReg();
412     if (!TargetRegisterInfo::isVirtualRegister(Reg))
413       continue;
414     auto &LiveMask = LiveRegs[Reg];
415     auto PrevMask = LiveMask;
416     LiveMask |= getDefRegMask(MO, *MRI);
417     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
418   }
419 
420   MaxPressure = max(MaxPressure, CurPressure);
421 }
422 
423 bool GCNDownwardRPTracker::advance() {
424   // If we have just called reset live set is actual.
425   if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
426     return false;
427   advanceToNext();
428   return true;
429 }
430 
431 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
432   while (NextMI != End)
433     if (!advance()) return false;
434   return true;
435 }
436 
437 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
438                                    MachineBasicBlock::const_iterator End,
439                                    const LiveRegSet *LiveRegsCopy) {
440   reset(*Begin, LiveRegsCopy);
441   return advance(End);
442 }
443 
444 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
445 LLVM_DUMP_METHOD
446 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
447                            const GCNRPTracker::LiveRegSet &TrackedLR,
448                            const TargetRegisterInfo *TRI) {
449   for (auto const &P : TrackedLR) {
450     auto I = LISLR.find(P.first);
451     if (I == LISLR.end()) {
452       dbgs() << "  " << printReg(P.first, TRI)
453              << ":L" << PrintLaneMask(P.second)
454              << " isn't found in LIS reported set\n";
455     }
456     else if (I->second != P.second) {
457       dbgs() << "  " << printReg(P.first, TRI)
458         << " masks doesn't match: LIS reported "
459         << PrintLaneMask(I->second)
460         << ", tracked "
461         << PrintLaneMask(P.second)
462         << '\n';
463     }
464   }
465   for (auto const &P : LISLR) {
466     auto I = TrackedLR.find(P.first);
467     if (I == TrackedLR.end()) {
468       dbgs() << "  " << printReg(P.first, TRI)
469              << ":L" << PrintLaneMask(P.second)
470              << " isn't found in tracked set\n";
471     }
472   }
473 }
474 
475 bool GCNUpwardRPTracker::isValid() const {
476   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
477   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
478   const auto &TrackedLR = LiveRegs;
479 
480   if (!isEqual(LISLR, TrackedLR)) {
481     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
482               " LIS reported livesets mismatch:\n";
483     printLivesAt(SI, LIS, *MRI);
484     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
485     return false;
486   }
487 
488   auto LISPressure = getRegPressure(*MRI, LISLR);
489   if (LISPressure != CurPressure) {
490     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
491     CurPressure.print(dbgs());
492     dbgs() << "LIS rpt: ";
493     LISPressure.print(dbgs());
494     return false;
495   }
496   return true;
497 }
498 
499 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
500                                  const MachineRegisterInfo &MRI) {
501   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
502   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
503     unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
504     auto It = LiveRegs.find(Reg);
505     if (It != LiveRegs.end() && It->second.any())
506       OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
507          << PrintLaneMask(It->second);
508   }
509   OS << '\n';
510 }
511 #endif
512