xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp (revision 2e3507c25e42292b45a5482e116d278f5515d04d)
1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "GCNRegPressure.h"
15 #include "llvm/CodeGen/RegisterPressure.h"
16 
17 using namespace llvm;
18 
19 #define DEBUG_TYPE "machine-scheduler"
20 
21 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
22                    const GCNRPTracker::LiveRegSet &S2) {
23   if (S1.size() != S2.size())
24     return false;
25 
26   for (const auto &P : S1) {
27     auto I = S2.find(P.first);
28     if (I == S2.end() || I->second != P.second)
29       return false;
30   }
31   return true;
32 }
33 
34 
35 ///////////////////////////////////////////////////////////////////////////////
36 // GCNRegPressure
37 
38 unsigned GCNRegPressure::getRegKind(Register Reg,
39                                     const MachineRegisterInfo &MRI) {
40   assert(Reg.isVirtual());
41   const auto RC = MRI.getRegClass(Reg);
42   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
43   return STI->isSGPRClass(RC)
44              ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE)
45          : STI->isAGPRClass(RC)
46              ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE)
47              : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
48 }
49 
50 void GCNRegPressure::inc(unsigned Reg,
51                          LaneBitmask PrevMask,
52                          LaneBitmask NewMask,
53                          const MachineRegisterInfo &MRI) {
54   if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
55       SIRegisterInfo::getNumCoveredRegs(PrevMask))
56     return;
57 
58   int Sign = 1;
59   if (NewMask < PrevMask) {
60     std::swap(NewMask, PrevMask);
61     Sign = -1;
62   }
63 
64   switch (auto Kind = getRegKind(Reg, MRI)) {
65   case SGPR32:
66   case VGPR32:
67   case AGPR32:
68     Value[Kind] += Sign;
69     break;
70 
71   case SGPR_TUPLE:
72   case VGPR_TUPLE:
73   case AGPR_TUPLE:
74     assert(PrevMask < NewMask);
75 
76     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
77       Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
78 
79     if (PrevMask.none()) {
80       assert(NewMask.any());
81       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
82     }
83     break;
84 
85   default: llvm_unreachable("Unknown register kind");
86   }
87 }
88 
89 bool GCNRegPressure::less(const GCNSubtarget &ST,
90                           const GCNRegPressure& O,
91                           unsigned MaxOccupancy) const {
92   const auto SGPROcc = std::min(MaxOccupancy,
93                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
94   const auto VGPROcc =
95     std::min(MaxOccupancy,
96              ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
97   const auto OtherSGPROcc = std::min(MaxOccupancy,
98                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
99   const auto OtherVGPROcc =
100     std::min(MaxOccupancy,
101              ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts())));
102 
103   const auto Occ = std::min(SGPROcc, VGPROcc);
104   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
105   if (Occ != OtherOcc)
106     return Occ > OtherOcc;
107 
108   bool SGPRImportant = SGPROcc < VGPROcc;
109   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
110 
111   // if both pressures disagree on what is more important compare vgprs
112   if (SGPRImportant != OtherSGPRImportant) {
113     SGPRImportant = false;
114   }
115 
116   // compare large regs pressure
117   bool SGPRFirst = SGPRImportant;
118   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
119     if (SGPRFirst) {
120       auto SW = getSGPRTuplesWeight();
121       auto OtherSW = O.getSGPRTuplesWeight();
122       if (SW != OtherSW)
123         return SW < OtherSW;
124     } else {
125       auto VW = getVGPRTuplesWeight();
126       auto OtherVW = O.getVGPRTuplesWeight();
127       if (VW != OtherVW)
128         return VW < OtherVW;
129     }
130   }
131   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
132                          (getVGPRNum(ST.hasGFX90AInsts()) <
133                           O.getVGPRNum(ST.hasGFX90AInsts()));
134 }
135 
136 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
137 LLVM_DUMP_METHOD
138 Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) {
139   return Printable([&RP, ST](raw_ostream &OS) {
140     OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' '
141        << "AGPRs: " << RP.getAGPRNum();
142     if (ST)
143       OS << "(O"
144          << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts()))
145          << ')';
146     OS << ", SGPRs: " << RP.getSGPRNum();
147     if (ST)
148       OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')';
149     OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight()
150        << ", LSGPR WT: " << RP.getSGPRTuplesWeight();
151     if (ST)
152       OS << " -> Occ: " << RP.getOccupancy(*ST);
153     OS << '\n';
154   });
155 }
156 #endif
157 
158 static LaneBitmask getDefRegMask(const MachineOperand &MO,
159                                  const MachineRegisterInfo &MRI) {
160   assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
161 
162   // We don't rely on read-undef flag because in case of tentative schedule
163   // tracking it isn't set correctly yet. This works correctly however since
164   // use mask has been tracked before using LIS.
165   return MO.getSubReg() == 0 ?
166     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
167     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
168 }
169 
170 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
171                                   const MachineRegisterInfo &MRI,
172                                   const LiveIntervals &LIS) {
173   assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual());
174 
175   if (auto SubReg = MO.getSubReg())
176     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
177 
178   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
179   if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
180     return MaxMask;
181 
182   // For a tentative schedule LIS isn't updated yet but livemask should remain
183   // the same on any schedule. Subreg defs can be reordered but they all must
184   // dominate uses anyway.
185   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
186   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
187 }
188 
189 static SmallVector<RegisterMaskPair, 8>
190 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
191                       const MachineRegisterInfo &MRI) {
192   SmallVector<RegisterMaskPair, 8> Res;
193   for (const auto &MO : MI.operands()) {
194     if (!MO.isReg() || !MO.getReg().isVirtual())
195       continue;
196     if (!MO.isUse() || !MO.readsReg())
197       continue;
198 
199     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
200 
201     auto Reg = MO.getReg();
202     auto I = llvm::find_if(
203         Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; });
204     if (I != Res.end())
205       I->LaneMask |= UsedMask;
206     else
207       Res.push_back(RegisterMaskPair(Reg, UsedMask));
208   }
209   return Res;
210 }
211 
212 ///////////////////////////////////////////////////////////////////////////////
213 // GCNRPTracker
214 
215 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
216                                   SlotIndex SI,
217                                   const LiveIntervals &LIS,
218                                   const MachineRegisterInfo &MRI) {
219   LaneBitmask LiveMask;
220   const auto &LI = LIS.getInterval(Reg);
221   if (LI.hasSubRanges()) {
222     for (const auto &S : LI.subranges())
223       if (S.liveAt(SI)) {
224         LiveMask |= S.LaneMask;
225         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
226                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
227       }
228   } else if (LI.liveAt(SI)) {
229     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
230   }
231   return LiveMask;
232 }
233 
234 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
235                                            const LiveIntervals &LIS,
236                                            const MachineRegisterInfo &MRI) {
237   GCNRPTracker::LiveRegSet LiveRegs;
238   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
239     auto Reg = Register::index2VirtReg(I);
240     if (!LIS.hasInterval(Reg))
241       continue;
242     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
243     if (LiveMask.any())
244       LiveRegs[Reg] = LiveMask;
245   }
246   return LiveRegs;
247 }
248 
249 void GCNRPTracker::reset(const MachineInstr &MI,
250                          const LiveRegSet *LiveRegsCopy,
251                          bool After) {
252   const MachineFunction &MF = *MI.getMF();
253   MRI = &MF.getRegInfo();
254   if (LiveRegsCopy) {
255     if (&LiveRegs != LiveRegsCopy)
256       LiveRegs = *LiveRegsCopy;
257   } else {
258     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
259                      : getLiveRegsBefore(MI, LIS);
260   }
261 
262   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
263 }
264 
265 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
266                                const LiveRegSet *LiveRegsCopy) {
267   GCNRPTracker::reset(MI, LiveRegsCopy, true);
268 }
269 
270 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
271   assert(MRI && "call reset first");
272 
273   LastTrackedMI = &MI;
274 
275   if (MI.isDebugInstr())
276     return;
277 
278   auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
279 
280   // calc pressure at the MI (defs + uses)
281   auto AtMIPressure = CurPressure;
282   for (const auto &U : RegUses) {
283     auto LiveMask = LiveRegs[U.RegUnit];
284     AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
285   }
286   // update max pressure
287   MaxPressure = max(AtMIPressure, MaxPressure);
288 
289   for (const auto &MO : MI.all_defs()) {
290     if (!MO.getReg().isVirtual() || MO.isDead())
291       continue;
292 
293     auto Reg = MO.getReg();
294     auto I = LiveRegs.find(Reg);
295     if (I == LiveRegs.end())
296       continue;
297     auto &LiveMask = I->second;
298     auto PrevMask = LiveMask;
299     LiveMask &= ~getDefRegMask(MO, *MRI);
300     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
301     if (LiveMask.none())
302       LiveRegs.erase(I);
303   }
304   for (const auto &U : RegUses) {
305     auto &LiveMask = LiveRegs[U.RegUnit];
306     auto PrevMask = LiveMask;
307     LiveMask |= U.LaneMask;
308     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
309   }
310   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
311 }
312 
313 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
314                                  const LiveRegSet *LiveRegsCopy) {
315   MRI = &MI.getParent()->getParent()->getRegInfo();
316   LastTrackedMI = nullptr;
317   MBBEnd = MI.getParent()->end();
318   NextMI = &MI;
319   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
320   if (NextMI == MBBEnd)
321     return false;
322   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
323   return true;
324 }
325 
326 bool GCNDownwardRPTracker::advanceBeforeNext() {
327   assert(MRI && "call reset first");
328   if (!LastTrackedMI)
329     return NextMI == MBBEnd;
330 
331   assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
332 
333   SlotIndex SI = NextMI == MBBEnd
334                      ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
335                      : LIS.getInstructionIndex(*NextMI).getBaseIndex();
336   assert(SI.isValid());
337 
338   // Remove dead registers or mask bits.
339   SmallSet<Register, 8> SeenRegs;
340   for (auto &MO : LastTrackedMI->operands()) {
341     if (!MO.isReg() || !MO.getReg().isVirtual())
342       continue;
343     if (MO.isUse() && !MO.readsReg())
344       continue;
345     if (!SeenRegs.insert(MO.getReg()).second)
346       continue;
347     const LiveInterval &LI = LIS.getInterval(MO.getReg());
348     if (LI.hasSubRanges()) {
349       auto It = LiveRegs.end();
350       for (const auto &S : LI.subranges()) {
351         if (!S.liveAt(SI)) {
352           if (It == LiveRegs.end()) {
353             It = LiveRegs.find(MO.getReg());
354             if (It == LiveRegs.end())
355               llvm_unreachable("register isn't live");
356           }
357           auto PrevMask = It->second;
358           It->second &= ~S.LaneMask;
359           CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI);
360         }
361       }
362       if (It != LiveRegs.end() && It->second.none())
363         LiveRegs.erase(It);
364     } else if (!LI.liveAt(SI)) {
365       auto It = LiveRegs.find(MO.getReg());
366       if (It == LiveRegs.end())
367         llvm_unreachable("register isn't live");
368       CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI);
369       LiveRegs.erase(It);
370     }
371   }
372 
373   MaxPressure = max(MaxPressure, CurPressure);
374 
375   LastTrackedMI = nullptr;
376 
377   return NextMI == MBBEnd;
378 }
379 
380 void GCNDownwardRPTracker::advanceToNext() {
381   LastTrackedMI = &*NextMI++;
382   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
383 
384   // Add new registers or mask bits.
385   for (const auto &MO : LastTrackedMI->all_defs()) {
386     Register Reg = MO.getReg();
387     if (!Reg.isVirtual())
388       continue;
389     auto &LiveMask = LiveRegs[Reg];
390     auto PrevMask = LiveMask;
391     LiveMask |= getDefRegMask(MO, *MRI);
392     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
393   }
394 
395   MaxPressure = max(MaxPressure, CurPressure);
396 }
397 
398 bool GCNDownwardRPTracker::advance() {
399   if (NextMI == MBBEnd)
400     return false;
401   advanceBeforeNext();
402   advanceToNext();
403   return true;
404 }
405 
406 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
407   while (NextMI != End)
408     if (!advance()) return false;
409   return true;
410 }
411 
412 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
413                                    MachineBasicBlock::const_iterator End,
414                                    const LiveRegSet *LiveRegsCopy) {
415   reset(*Begin, LiveRegsCopy);
416   return advance(End);
417 }
418 
419 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
420 LLVM_DUMP_METHOD
421 Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
422                                const GCNRPTracker::LiveRegSet &TrackedLR,
423                                const TargetRegisterInfo *TRI) {
424   return Printable([&LISLR, &TrackedLR, TRI](raw_ostream &OS) {
425     for (auto const &P : TrackedLR) {
426       auto I = LISLR.find(P.first);
427       if (I == LISLR.end()) {
428         OS << "  " << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
429            << " isn't found in LIS reported set\n";
430       } else if (I->second != P.second) {
431         OS << "  " << printReg(P.first, TRI)
432            << " masks doesn't match: LIS reported " << PrintLaneMask(I->second)
433            << ", tracked " << PrintLaneMask(P.second) << '\n';
434       }
435     }
436     for (auto const &P : LISLR) {
437       auto I = TrackedLR.find(P.first);
438       if (I == TrackedLR.end()) {
439         OS << "  " << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
440            << " isn't found in tracked set\n";
441       }
442     }
443   });
444 }
445 
446 bool GCNUpwardRPTracker::isValid() const {
447   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
448   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
449   const auto &TrackedLR = LiveRegs;
450 
451   if (!isEqual(LISLR, TrackedLR)) {
452     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
453               " LIS reported livesets mismatch:\n"
454            << print(LISLR, *MRI);
455     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
456     return false;
457   }
458 
459   auto LISPressure = getRegPressure(*MRI, LISLR);
460   if (LISPressure != CurPressure) {
461     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "
462            << print(CurPressure) << "LIS rpt: " << print(LISPressure);
463     return false;
464   }
465   return true;
466 }
467 
468 LLVM_DUMP_METHOD
469 Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
470                       const MachineRegisterInfo &MRI) {
471   return Printable([&LiveRegs, &MRI](raw_ostream &OS) {
472     const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
473     for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
474       Register Reg = Register::index2VirtReg(I);
475       auto It = LiveRegs.find(Reg);
476       if (It != LiveRegs.end() && It->second.any())
477         OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
478            << PrintLaneMask(It->second);
479     }
480     OS << '\n';
481   });
482 }
483 
484 LLVM_DUMP_METHOD
485 void GCNRegPressure::dump() const { dbgs() << print(*this); }
486 
487 #endif
488