1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13
14 #include "GCNRegPressure.h"
15 #include "AMDGPU.h"
16 #include "SIMachineFunctionInfo.h"
17 #include "llvm/CodeGen/RegisterPressure.h"
18
19 using namespace llvm;
20
21 #define DEBUG_TYPE "machine-scheduler"
22
isEqual(const GCNRPTracker::LiveRegSet & S1,const GCNRPTracker::LiveRegSet & S2)23 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
24 const GCNRPTracker::LiveRegSet &S2) {
25 if (S1.size() != S2.size())
26 return false;
27
28 for (const auto &P : S1) {
29 auto I = S2.find(P.first);
30 if (I == S2.end() || I->second != P.second)
31 return false;
32 }
33 return true;
34 }
35
36 ///////////////////////////////////////////////////////////////////////////////
37 // GCNRegPressure
38
getRegKind(const TargetRegisterClass * RC,const SIRegisterInfo * STI)39 unsigned GCNRegPressure::getRegKind(const TargetRegisterClass *RC,
40 const SIRegisterInfo *STI) {
41 return STI->isSGPRClass(RC) ? SGPR : (STI->isAGPRClass(RC) ? AGPR : VGPR);
42 }
43
inc(unsigned Reg,LaneBitmask PrevMask,LaneBitmask NewMask,const MachineRegisterInfo & MRI)44 void GCNRegPressure::inc(unsigned Reg,
45 LaneBitmask PrevMask,
46 LaneBitmask NewMask,
47 const MachineRegisterInfo &MRI) {
48 unsigned NewNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(NewMask);
49 unsigned PrevNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(PrevMask);
50 if (NewNumCoveredRegs == PrevNumCoveredRegs)
51 return;
52
53 int Sign = 1;
54 if (NewMask < PrevMask) {
55 std::swap(NewMask, PrevMask);
56 std::swap(NewNumCoveredRegs, PrevNumCoveredRegs);
57 Sign = -1;
58 }
59 assert(PrevMask < NewMask && PrevNumCoveredRegs < NewNumCoveredRegs &&
60 "prev mask should always be lesser than new");
61
62 const TargetRegisterClass *RC = MRI.getRegClass(Reg);
63 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
64 const SIRegisterInfo *STI = static_cast<const SIRegisterInfo *>(TRI);
65 unsigned RegKind = getRegKind(RC, STI);
66 if (TRI->getRegSizeInBits(*RC) != 32) {
67 // Reg is from a tuple register class.
68 if (PrevMask.none()) {
69 unsigned TupleIdx = TOTAL_KINDS + RegKind;
70 Value[TupleIdx] += Sign * TRI->getRegClassWeight(RC).RegWeight;
71 }
72 // Pressure scales with number of new registers covered by the new mask.
73 // Note when true16 is enabled, we can no longer safely use the following
74 // approach to calculate the difference in the number of 32-bit registers
75 // between two masks:
76 //
77 // Sign *= SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
78 //
79 // The issue is that the mask calculation `~PrevMask & NewMask` doesn't
80 // properly account for partial usage of a 32-bit register when dealing with
81 // 16-bit registers.
82 //
83 // Consider this example:
84 // Assume PrevMask = 0b0010 and NewMask = 0b1111. Here, the correct register
85 // usage difference should be 1, because even though PrevMask uses only half
86 // of a 32-bit register, it should still be counted as a full register use.
87 // However, the mask calculation yields `~PrevMask & NewMask = 0b1101`, and
88 // calling `getNumCoveredRegs` returns 2 instead of 1. This incorrect
89 // calculation can lead to integer overflow when Sign = -1.
90 Sign *= NewNumCoveredRegs - PrevNumCoveredRegs;
91 }
92 Value[RegKind] += Sign;
93 }
94
less(const MachineFunction & MF,const GCNRegPressure & O,unsigned MaxOccupancy) const95 bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O,
96 unsigned MaxOccupancy) const {
97 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
98 unsigned DynamicVGPRBlockSize =
99 MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize();
100
101 const auto SGPROcc = std::min(MaxOccupancy,
102 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
103 const auto VGPROcc = std::min(
104 MaxOccupancy, ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts()),
105 DynamicVGPRBlockSize));
106 const auto OtherSGPROcc = std::min(MaxOccupancy,
107 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
108 const auto OtherVGPROcc =
109 std::min(MaxOccupancy,
110 ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts()),
111 DynamicVGPRBlockSize));
112
113 const auto Occ = std::min(SGPROcc, VGPROcc);
114 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
115
116 // Give first precedence to the better occupancy.
117 if (Occ != OtherOcc)
118 return Occ > OtherOcc;
119
120 unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF);
121 unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF);
122
123 // SGPR excess pressure conditions
124 unsigned ExcessSGPR = std::max(static_cast<int>(getSGPRNum() - MaxSGPRs), 0);
125 unsigned OtherExcessSGPR =
126 std::max(static_cast<int>(O.getSGPRNum() - MaxSGPRs), 0);
127
128 auto WaveSize = ST.getWavefrontSize();
129 // The number of virtual VGPRs required to handle excess SGPR
130 unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize;
131 unsigned OtherVGPRForSGPRSpills =
132 (OtherExcessSGPR + (WaveSize - 1)) / WaveSize;
133
134 unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs();
135
136 // Unified excess pressure conditions, accounting for VGPRs used for SGPR
137 // spills
138 unsigned ExcessVGPR =
139 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) +
140 VGPRForSGPRSpills - MaxVGPRs),
141 0);
142 unsigned OtherExcessVGPR =
143 std::max(static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) +
144 OtherVGPRForSGPRSpills - MaxVGPRs),
145 0);
146 // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR
147 // spills
148 unsigned ExcessArchVGPR = std::max(
149 static_cast<int>(getVGPRNum(false) + VGPRForSGPRSpills - MaxArchVGPRs),
150 0);
151 unsigned OtherExcessArchVGPR =
152 std::max(static_cast<int>(O.getVGPRNum(false) + OtherVGPRForSGPRSpills -
153 MaxArchVGPRs),
154 0);
155 // AGPR excess pressure conditions
156 unsigned ExcessAGPR = std::max(
157 static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs)
158 : (getAGPRNum() - MaxVGPRs)),
159 0);
160 unsigned OtherExcessAGPR = std::max(
161 static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs)
162 : (O.getAGPRNum() - MaxVGPRs)),
163 0);
164
165 bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR;
166 bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR ||
167 OtherExcessArchVGPR || OtherExcessAGPR;
168
169 // Give second precedence to the reduced number of spills to hold the register
170 // pressure.
171 if (ExcessRP || OtherExcessRP) {
172 // The difference in excess VGPR pressure, after including VGPRs used for
173 // SGPR spills
174 int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) -
175 (ExcessVGPR + ExcessArchVGPR + ExcessAGPR));
176
177 int SGPRDiff = OtherExcessSGPR - ExcessSGPR;
178
179 if (VGPRDiff != 0)
180 return VGPRDiff > 0;
181 if (SGPRDiff != 0) {
182 unsigned PureExcessVGPR =
183 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
184 0) +
185 std::max(static_cast<int>(getVGPRNum(false) - MaxArchVGPRs), 0);
186 unsigned OtherPureExcessVGPR =
187 std::max(
188 static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
189 0) +
190 std::max(static_cast<int>(O.getVGPRNum(false) - MaxArchVGPRs), 0);
191
192 // If we have a special case where there is a tie in excess VGPR, but one
193 // of the pressures has VGPR usage from SGPR spills, prefer the pressure
194 // with SGPR spills.
195 if (PureExcessVGPR != OtherPureExcessVGPR)
196 return SGPRDiff < 0;
197 // If both pressures have the same excess pressure before and after
198 // accounting for SGPR spills, prefer fewer SGPR spills.
199 return SGPRDiff > 0;
200 }
201 }
202
203 bool SGPRImportant = SGPROcc < VGPROcc;
204 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
205
206 // If both pressures disagree on what is more important compare vgprs.
207 if (SGPRImportant != OtherSGPRImportant) {
208 SGPRImportant = false;
209 }
210
211 // Give third precedence to lower register tuple pressure.
212 bool SGPRFirst = SGPRImportant;
213 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
214 if (SGPRFirst) {
215 auto SW = getSGPRTuplesWeight();
216 auto OtherSW = O.getSGPRTuplesWeight();
217 if (SW != OtherSW)
218 return SW < OtherSW;
219 } else {
220 auto VW = getVGPRTuplesWeight();
221 auto OtherVW = O.getVGPRTuplesWeight();
222 if (VW != OtherVW)
223 return VW < OtherVW;
224 }
225 }
226
227 // Give final precedence to lower general RP.
228 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
229 (getVGPRNum(ST.hasGFX90AInsts()) <
230 O.getVGPRNum(ST.hasGFX90AInsts()));
231 }
232
print(const GCNRegPressure & RP,const GCNSubtarget * ST,unsigned DynamicVGPRBlockSize)233 Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST,
234 unsigned DynamicVGPRBlockSize) {
235 return Printable([&RP, ST, DynamicVGPRBlockSize](raw_ostream &OS) {
236 OS << "VGPRs: " << RP.getArchVGPRNum() << ' '
237 << "AGPRs: " << RP.getAGPRNum();
238 if (ST)
239 OS << "(O"
240 << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts()),
241 DynamicVGPRBlockSize)
242 << ')';
243 OS << ", SGPRs: " << RP.getSGPRNum();
244 if (ST)
245 OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')';
246 OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight()
247 << ", LSGPR WT: " << RP.getSGPRTuplesWeight();
248 if (ST)
249 OS << " -> Occ: " << RP.getOccupancy(*ST, DynamicVGPRBlockSize);
250 OS << '\n';
251 });
252 }
253
getDefRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI)254 static LaneBitmask getDefRegMask(const MachineOperand &MO,
255 const MachineRegisterInfo &MRI) {
256 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
257
258 // We don't rely on read-undef flag because in case of tentative schedule
259 // tracking it isn't set correctly yet. This works correctly however since
260 // use mask has been tracked before using LIS.
261 return MO.getSubReg() == 0 ?
262 MRI.getMaxLaneMaskForVReg(MO.getReg()) :
263 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
264 }
265
266 static void
collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> & VRegMaskOrUnits,const MachineInstr & MI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)267 collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits,
268 const MachineInstr &MI, const LiveIntervals &LIS,
269 const MachineRegisterInfo &MRI) {
270
271 auto &TRI = *MRI.getTargetRegisterInfo();
272 for (const auto &MO : MI.operands()) {
273 if (!MO.isReg() || !MO.getReg().isVirtual())
274 continue;
275 if (!MO.isUse() || !MO.readsReg())
276 continue;
277
278 Register Reg = MO.getReg();
279 auto I = llvm::find_if(VRegMaskOrUnits, [Reg](const VRegMaskOrUnit &RM) {
280 return RM.RegUnit == Reg;
281 });
282
283 auto &P = I == VRegMaskOrUnits.end()
284 ? VRegMaskOrUnits.emplace_back(Reg, LaneBitmask::getNone())
285 : *I;
286
287 P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(MO.getSubReg())
288 : MRI.getMaxLaneMaskForVReg(Reg);
289 }
290
291 SlotIndex InstrSI;
292 for (auto &P : VRegMaskOrUnits) {
293 auto &LI = LIS.getInterval(P.RegUnit);
294 if (!LI.hasSubRanges())
295 continue;
296
297 // For a tentative schedule LIS isn't updated yet but livemask should
298 // remain the same on any schedule. Subreg defs can be reordered but they
299 // all must dominate uses anyway.
300 if (!InstrSI)
301 InstrSI = LIS.getInstructionIndex(MI).getBaseIndex();
302
303 P.LaneMask = getLiveLaneMask(LI, InstrSI, MRI, P.LaneMask);
304 }
305 }
306
307 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp
getLanesWithProperty(const LiveIntervals & LIS,const MachineRegisterInfo & MRI,bool TrackLaneMasks,Register RegUnit,SlotIndex Pos,LaneBitmask SafeDefault,function_ref<bool (const LiveRange & LR,SlotIndex Pos)> Property)308 static LaneBitmask getLanesWithProperty(
309 const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
310 bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
311 LaneBitmask SafeDefault,
312 function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) {
313 if (RegUnit.isVirtual()) {
314 const LiveInterval &LI = LIS.getInterval(RegUnit);
315 LaneBitmask Result;
316 if (TrackLaneMasks && LI.hasSubRanges()) {
317 for (const LiveInterval::SubRange &SR : LI.subranges()) {
318 if (Property(SR, Pos))
319 Result |= SR.LaneMask;
320 }
321 } else if (Property(LI, Pos)) {
322 Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit)
323 : LaneBitmask::getAll();
324 }
325
326 return Result;
327 }
328
329 const LiveRange *LR = LIS.getCachedRegUnit(RegUnit);
330 if (LR == nullptr)
331 return SafeDefault;
332 return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone();
333 }
334
335 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp
336 /// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
337 /// The query starts with a lane bitmask which gets lanes/bits removed for every
338 /// use we find.
findUseBetween(unsigned Reg,LaneBitmask LastUseMask,SlotIndex PriorUseIdx,SlotIndex NextUseIdx,const MachineRegisterInfo & MRI,const SIRegisterInfo * TRI,const LiveIntervals * LIS,bool Upward=false)339 static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
340 SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
341 const MachineRegisterInfo &MRI,
342 const SIRegisterInfo *TRI,
343 const LiveIntervals *LIS,
344 bool Upward = false) {
345 for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
346 if (MO.isUndef())
347 continue;
348 const MachineInstr *MI = MO.getParent();
349 SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot();
350 bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
351 : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
352 if (!InRange)
353 continue;
354
355 unsigned SubRegIdx = MO.getSubReg();
356 LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubRegIdx);
357 LastUseMask &= ~UseMask;
358 if (LastUseMask.none())
359 return LaneBitmask::getNone();
360 }
361 return LastUseMask;
362 }
363
364 ////////////////////////////////////////////////////////////////////////////////
365 // GCNRPTarget
366
GCNRPTarget(const MachineFunction & MF,const GCNRegPressure & RP,bool CombineVGPRSavings)367 GCNRPTarget::GCNRPTarget(const MachineFunction &MF, const GCNRegPressure &RP,
368 bool CombineVGPRSavings)
369 : RP(RP), CombineVGPRSavings(CombineVGPRSavings) {
370 const Function &F = MF.getFunction();
371 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
372 setRegLimits(ST.getMaxNumSGPRs(F), ST.getMaxNumVGPRs(F), MF);
373 }
374
GCNRPTarget(unsigned NumSGPRs,unsigned NumVGPRs,const MachineFunction & MF,const GCNRegPressure & RP,bool CombineVGPRSavings)375 GCNRPTarget::GCNRPTarget(unsigned NumSGPRs, unsigned NumVGPRs,
376 const MachineFunction &MF, const GCNRegPressure &RP,
377 bool CombineVGPRSavings)
378 : RP(RP), CombineVGPRSavings(CombineVGPRSavings) {
379 setRegLimits(NumSGPRs, NumVGPRs, MF);
380 }
381
GCNRPTarget(unsigned Occupancy,const MachineFunction & MF,const GCNRegPressure & RP,bool CombineVGPRSavings)382 GCNRPTarget::GCNRPTarget(unsigned Occupancy, const MachineFunction &MF,
383 const GCNRegPressure &RP, bool CombineVGPRSavings)
384 : RP(RP), CombineVGPRSavings(CombineVGPRSavings) {
385 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
386 unsigned DynamicVGPRBlockSize =
387 MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize();
388 setRegLimits(ST.getMaxNumSGPRs(Occupancy, /*Addressable=*/false),
389 ST.getMaxNumVGPRs(Occupancy, DynamicVGPRBlockSize), MF);
390 }
391
setRegLimits(unsigned NumSGPRs,unsigned NumVGPRs,const MachineFunction & MF)392 void GCNRPTarget::setRegLimits(unsigned NumSGPRs, unsigned NumVGPRs,
393 const MachineFunction &MF) {
394 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
395 unsigned DynamicVGPRBlockSize =
396 MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize();
397 MaxSGPRs = std::min(ST.getAddressableNumSGPRs(), NumSGPRs);
398 MaxVGPRs = std::min(ST.getAddressableNumArchVGPRs(), NumVGPRs);
399 MaxUnifiedVGPRs =
400 ST.hasGFX90AInsts()
401 ? std::min(ST.getAddressableNumVGPRs(DynamicVGPRBlockSize), NumVGPRs)
402 : 0;
403 }
404
isSaveBeneficial(Register Reg,const MachineRegisterInfo & MRI) const405 bool GCNRPTarget::isSaveBeneficial(Register Reg,
406 const MachineRegisterInfo &MRI) const {
407 const TargetRegisterClass *RC = MRI.getRegClass(Reg);
408 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
409 const SIRegisterInfo *SRI = static_cast<const SIRegisterInfo *>(TRI);
410
411 if (SRI->isSGPRClass(RC))
412 return RP.getSGPRNum() > MaxSGPRs;
413 unsigned NumVGPRs =
414 SRI->isAGPRClass(RC) ? RP.getAGPRNum() : RP.getArchVGPRNum();
415 return isVGPRBankSaveBeneficial(NumVGPRs);
416 }
417
satisfied() const418 bool GCNRPTarget::satisfied() const {
419 if (RP.getSGPRNum() > MaxSGPRs)
420 return false;
421 if (RP.getVGPRNum(false) > MaxVGPRs &&
422 (!CombineVGPRSavings || !satisifiesVGPRBanksTarget()))
423 return false;
424 return satisfiesUnifiedTarget();
425 }
426
427 ///////////////////////////////////////////////////////////////////////////////
428 // GCNRPTracker
429
getLiveLaneMask(unsigned Reg,SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI,LaneBitmask LaneMaskFilter)430 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI,
431 const LiveIntervals &LIS,
432 const MachineRegisterInfo &MRI,
433 LaneBitmask LaneMaskFilter) {
434 return getLiveLaneMask(LIS.getInterval(Reg), SI, MRI, LaneMaskFilter);
435 }
436
getLiveLaneMask(const LiveInterval & LI,SlotIndex SI,const MachineRegisterInfo & MRI,LaneBitmask LaneMaskFilter)437 LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
438 const MachineRegisterInfo &MRI,
439 LaneBitmask LaneMaskFilter) {
440 LaneBitmask LiveMask;
441 if (LI.hasSubRanges()) {
442 for (const auto &S : LI.subranges())
443 if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(SI)) {
444 LiveMask |= S.LaneMask;
445 assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg())));
446 }
447 } else if (LI.liveAt(SI)) {
448 LiveMask = MRI.getMaxLaneMaskForVReg(LI.reg());
449 }
450 LiveMask &= LaneMaskFilter;
451 return LiveMask;
452 }
453
getLiveRegs(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)454 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
455 const LiveIntervals &LIS,
456 const MachineRegisterInfo &MRI) {
457 GCNRPTracker::LiveRegSet LiveRegs;
458 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
459 auto Reg = Register::index2VirtReg(I);
460 if (!LIS.hasInterval(Reg))
461 continue;
462 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
463 if (LiveMask.any())
464 LiveRegs[Reg] = LiveMask;
465 }
466 return LiveRegs;
467 }
468
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy,bool After)469 void GCNRPTracker::reset(const MachineInstr &MI,
470 const LiveRegSet *LiveRegsCopy,
471 bool After) {
472 const MachineFunction &MF = *MI.getMF();
473 MRI = &MF.getRegInfo();
474 if (LiveRegsCopy) {
475 if (&LiveRegs != LiveRegsCopy)
476 LiveRegs = *LiveRegsCopy;
477 } else {
478 LiveRegs = After ? getLiveRegsAfter(MI, LIS)
479 : getLiveRegsBefore(MI, LIS);
480 }
481
482 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
483 }
484
reset(const MachineRegisterInfo & MRI_,const LiveRegSet & LiveRegs_)485 void GCNRPTracker::reset(const MachineRegisterInfo &MRI_,
486 const LiveRegSet &LiveRegs_) {
487 MRI = &MRI_;
488 LiveRegs = LiveRegs_;
489 LastTrackedMI = nullptr;
490 MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
491 }
492
493 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp
getLastUsedLanes(Register RegUnit,SlotIndex Pos) const494 LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit,
495 SlotIndex Pos) const {
496 return getLanesWithProperty(
497 LIS, *MRI, true, RegUnit, Pos.getBaseIndex(), LaneBitmask::getNone(),
498 [](const LiveRange &LR, SlotIndex Pos) {
499 const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
500 return S != nullptr && S->end == Pos.getRegSlot();
501 });
502 }
503
504 ////////////////////////////////////////////////////////////////////////////////
505 // GCNUpwardRPTracker
506
/// Moves the upward (bottom-up) tracker across \p MI. The lanes defined by MI
/// are removed from the live set and the lanes MI reads are added. MaxPressure
/// is updated twice: once at the def point, with the defs counted as live, and
/// once after the uses are added, with early-clobber defs still counted.
/// Debug instructions only update LastTrackedMI.
void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
  assert(MRI && "call reset first");

  LastTrackedMI = &MI;

  if (MI.isDebugInstr())
    return;

  // Kill all defs.
  GCNRegPressure DefPressure, ECDefPressure;
  bool HasECDefs = false;
  for (const MachineOperand &MO : MI.all_defs()) {
    if (!MO.getReg().isVirtual())
      continue;

    Register Reg = MO.getReg();
    LaneBitmask DefMask = getDefRegMask(MO, *MRI);

    // Treat a def as fully live at the moment of definition: keep a record.
    // Early-clobber defs are accumulated separately so they can also be added
    // to the pressure measured after the uses below.
    if (MO.isEarlyClobber()) {
      ECDefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);
      HasECDefs = true;
    } else
      DefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);

    auto I = LiveRegs.find(Reg);
    if (I == LiveRegs.end())
      continue;

    // Above the def, the defined lanes are no longer live.
    LaneBitmask &LiveMask = I->second;
    LaneBitmask PrevMask = LiveMask;
    LiveMask &= ~DefMask;
    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
    if (LiveMask.none())
      LiveRegs.erase(I);
  }

  // Update MaxPressure with defs pressure.
  // CurPressure already excludes the killed def lanes, so adding the def
  // pressure re-counts them as live at the def point.
  DefPressure += CurPressure;
  if (HasECDefs)
    DefPressure += ECDefPressure;
  MaxPressure = max(DefPressure, MaxPressure);

  // Make uses alive.
  SmallVector<VRegMaskOrUnit, 8> RegUses;
  collectVirtualRegUses(RegUses, MI, LIS, *MRI);
  for (const VRegMaskOrUnit &U : RegUses) {
    LaneBitmask &LiveMask = LiveRegs[U.RegUnit];
    LaneBitmask PrevMask = LiveMask;
    LiveMask |= U.LaneMask;
    CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
  }

  // Update MaxPressure with uses plus early-clobber defs pressure.
  MaxPressure = HasECDefs ? max(CurPressure + ECDefPressure, MaxPressure)
                          : max(CurPressure, MaxPressure);

  // The incrementally maintained pressure must match a full recomputation.
  assert(CurPressure == getRegPressure(*MRI, LiveRegs));
}
566
567 ////////////////////////////////////////////////////////////////////////////////
568 // GCNDownwardRPTracker
569
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)570 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
571 const LiveRegSet *LiveRegsCopy) {
572 MRI = &MI.getParent()->getParent()->getRegInfo();
573 LastTrackedMI = nullptr;
574 MBBEnd = MI.getParent()->end();
575 NextMI = &MI;
576 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
577 if (NextMI == MBBEnd)
578 return false;
579 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
580 return true;
581 }
582
/// Drops registers, or individual lanes, read or written by the current
/// instruction that are no longer live at the following program point.
/// CurPressure and MaxPressure are updated accordingly.
///
/// With \p UseInternalIterator, the current instruction is LastTrackedMI and
/// liveness is tested at NextMI's base index, or at LastTrackedMI's dead slot
/// once the block end is reached. Otherwise the current instruction is \p MI,
/// liveness is tested at MI's base index, and MI's defs are skipped.
///
/// Clears LastTrackedMI. Returns true only in internal-iterator mode when the
/// end of the block has been reached.
bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI,
                                             bool UseInternalIterator) {
  assert(MRI && "call reset first");
  SlotIndex SI;
  const MachineInstr *CurrMI;
  if (UseInternalIterator) {
    // Nothing tracked yet: there is nothing to prune.
    if (!LastTrackedMI)
      return NextMI == MBBEnd;

    assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
    CurrMI = LastTrackedMI;

    SI = NextMI == MBBEnd
             ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
             : LIS.getInstructionIndex(*NextMI).getBaseIndex();
  } else { //! UseInternalIterator
    SI = LIS.getInstructionIndex(*MI).getBaseIndex();
    CurrMI = MI;
  }

  assert(SI.isValid());

  // Remove dead registers or mask bits.
  SmallSet<Register, 8> SeenRegs;
  for (auto &MO : CurrMI->operands()) {
    if (!MO.isReg() || !MO.getReg().isVirtual())
      continue;
    if (MO.isUse() && !MO.readsReg())
      continue;
    if (!UseInternalIterator && MO.isDef())
      continue;
    // Handle each register once, even if it appears in several operands.
    if (!SeenRegs.insert(MO.getReg()).second)
      continue;
    const LiveInterval &LI = LIS.getInterval(MO.getReg());
    if (LI.hasSubRanges()) {
      // Clear the lanes of every subrange that is dead at SI. The live-set
      // entry is looked up lazily, on the first dead subrange.
      auto It = LiveRegs.end();
      for (const auto &S : LI.subranges()) {
        if (!S.liveAt(SI)) {
          if (It == LiveRegs.end()) {
            It = LiveRegs.find(MO.getReg());
            if (It == LiveRegs.end())
              llvm_unreachable("register isn't live");
          }
          auto PrevMask = It->second;
          It->second &= ~S.LaneMask;
          CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI);
        }
      }
      if (It != LiveRegs.end() && It->second.none())
        LiveRegs.erase(It);
    } else if (!LI.liveAt(SI)) {
      // Without subranges, the whole register dies at once.
      auto It = LiveRegs.find(MO.getReg());
      if (It == LiveRegs.end())
        llvm_unreachable("register isn't live");
      CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI);
      LiveRegs.erase(It);
    }
  }

  MaxPressure = max(MaxPressure, CurPressure);

  LastTrackedMI = nullptr;

  return UseInternalIterator && (NextMI == MBBEnd);
}
648
advanceToNext(MachineInstr * MI,bool UseInternalIterator)649 void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI,
650 bool UseInternalIterator) {
651 if (UseInternalIterator) {
652 LastTrackedMI = &*NextMI++;
653 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
654 } else {
655 LastTrackedMI = MI;
656 }
657
658 const MachineInstr *CurrMI = LastTrackedMI;
659
660 // Add new registers or mask bits.
661 for (const auto &MO : CurrMI->all_defs()) {
662 Register Reg = MO.getReg();
663 if (!Reg.isVirtual())
664 continue;
665 auto &LiveMask = LiveRegs[Reg];
666 auto PrevMask = LiveMask;
667 LiveMask |= getDefRegMask(MO, *MRI);
668 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
669 }
670
671 MaxPressure = max(MaxPressure, CurPressure);
672 }
673
advance(MachineInstr * MI,bool UseInternalIterator)674 bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) {
675 if (UseInternalIterator && NextMI == MBBEnd)
676 return false;
677
678 advanceBeforeNext(MI, UseInternalIterator);
679 advanceToNext(MI, UseInternalIterator);
680 if (!UseInternalIterator) {
681 // We must remove any dead def lanes from the current RP
682 advanceBeforeNext(MI, true);
683 }
684 return true;
685 }
686
advance(MachineBasicBlock::const_iterator End)687 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
688 while (NextMI != End)
689 if (!advance()) return false;
690 return true;
691 }
692
advance(MachineBasicBlock::const_iterator Begin,MachineBasicBlock::const_iterator End,const LiveRegSet * LiveRegsCopy)693 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
694 MachineBasicBlock::const_iterator End,
695 const LiveRegSet *LiveRegsCopy) {
696 reset(*Begin, LiveRegsCopy);
697 return advance(End);
698 }
699
reportMismatch(const GCNRPTracker::LiveRegSet & LISLR,const GCNRPTracker::LiveRegSet & TrackedLR,const TargetRegisterInfo * TRI,StringRef Pfx)700 Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
701 const GCNRPTracker::LiveRegSet &TrackedLR,
702 const TargetRegisterInfo *TRI, StringRef Pfx) {
703 return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) {
704 for (auto const &P : TrackedLR) {
705 auto I = LISLR.find(P.first);
706 if (I == LISLR.end()) {
707 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
708 << " isn't found in LIS reported set\n";
709 } else if (I->second != P.second) {
710 OS << Pfx << printReg(P.first, TRI)
711 << " masks doesn't match: LIS reported " << PrintLaneMask(I->second)
712 << ", tracked " << PrintLaneMask(P.second) << '\n';
713 }
714 }
715 for (auto const &P : LISLR) {
716 auto I = TrackedLR.find(P.first);
717 if (I == TrackedLR.end()) {
718 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
719 << " isn't found in tracked set\n";
720 }
721 }
722 });
723 }
724
725 GCNRegPressure
bumpDownwardPressure(const MachineInstr * MI,const SIRegisterInfo * TRI) const726 GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
727 const SIRegisterInfo *TRI) const {
728 assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction.");
729
730 SlotIndex SlotIdx;
731 SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot();
732
733 // Account for register pressure similar to RegPressureTracker::recede().
734 RegisterOperands RegOpers;
735 RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/false);
736 RegOpers.adjustLaneLiveness(LIS, *MRI, SlotIdx);
737 GCNRegPressure TempPressure = CurPressure;
738
739 for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
740 Register Reg = Use.RegUnit;
741 if (!Reg.isVirtual())
742 continue;
743 LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
744 if (LastUseMask.none())
745 continue;
746 // The LastUseMask is queried from the liveness information of instruction
747 // which may be further down the schedule. Some lanes may actually not be
748 // last uses for the current position.
749 // FIXME: allow the caller to pass in the list of vreg uses that remain
750 // to be bottom-scheduled to avoid searching uses at each query.
751 SlotIndex CurrIdx;
752 const MachineBasicBlock *MBB = MI->getParent();
753 MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward(
754 LastTrackedMI ? LastTrackedMI : MBB->begin(), MBB->end());
755 if (IdxPos == MBB->end()) {
756 CurrIdx = LIS.getMBBEndIdx(MBB);
757 } else {
758 CurrIdx = LIS.getInstructionIndex(*IdxPos).getRegSlot();
759 }
760
761 LastUseMask =
762 findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS);
763 if (LastUseMask.none())
764 continue;
765
766 auto It = LiveRegs.find(Reg);
767 LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0);
768 LaneBitmask NewMask = LiveMask & ~LastUseMask;
769 TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
770 }
771
772 // Generate liveness for defs.
773 for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
774 Register Reg = Def.RegUnit;
775 if (!Reg.isVirtual())
776 continue;
777 auto It = LiveRegs.find(Reg);
778 LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0);
779 LaneBitmask NewMask = LiveMask | Def.LaneMask;
780 TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
781 }
782
783 return TempPressure;
784 }
785
isValid() const786 bool GCNUpwardRPTracker::isValid() const {
787 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
788 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
789 const auto &TrackedLR = LiveRegs;
790
791 if (!isEqual(LISLR, TrackedLR)) {
792 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
793 " LIS reported livesets mismatch:\n"
794 << print(LISLR, *MRI);
795 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
796 return false;
797 }
798
799 auto LISPressure = getRegPressure(*MRI, LISLR);
800 if (LISPressure != CurPressure) {
801 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "
802 << print(CurPressure) << "LIS rpt: " << print(LISPressure);
803 return false;
804 }
805 return true;
806 }
807
print(const GCNRPTracker::LiveRegSet & LiveRegs,const MachineRegisterInfo & MRI)808 Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
809 const MachineRegisterInfo &MRI) {
810 return Printable([&LiveRegs, &MRI](raw_ostream &OS) {
811 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
812 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
813 Register Reg = Register::index2VirtReg(I);
814 auto It = LiveRegs.find(Reg);
815 if (It != LiveRegs.end() && It->second.any())
816 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
817 << PrintLaneMask(It->second);
818 }
819 OS << '\n';
820 });
821 }
822
dump() const823 void GCNRegPressure::dump() const { dbgs() << print(*this); }
824
// Hidden command-line switch (-amdgpu-print-rp-downward, off by default) that
// makes GCNRegPressurePrinter use GCNDownwardRPTracker. NOTE(review): the
// tracker used when the flag is off is chosen in runOnMachineFunction;
// presumably the upward tracker, to be confirmed there.
static cl::opt<bool> UseDownwardTracker(
    "amdgpu-print-rp-downward",
    cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"),
    cl::init(false), cl::Hidden);

// Pass identification for the legacy pass manager.
char llvm::GCNRegPressurePrinter::ID = 0;
char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID;

// Registers the printer under the name "amdgpu-print-rp" with an empty
// description. The trailing flags are INITIALIZE_PASS's CFGOnly and
// is_analysis arguments.
INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true)
834
835 // Return lanemask of Reg's subregs that are live-through at [Begin, End] and
836 // are fully covered by Mask.
837 static LaneBitmask
getRegLiveThroughMask(const MachineRegisterInfo & MRI,const LiveIntervals & LIS,Register Reg,SlotIndex Begin,SlotIndex End,LaneBitmask Mask=LaneBitmask::getAll ())838 getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS,
839 Register Reg, SlotIndex Begin, SlotIndex End,
840 LaneBitmask Mask = LaneBitmask::getAll()) {
841
842 auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool {
843 auto *Segment = LR.getSegmentContaining(Begin);
844 return Segment && Segment->contains(End);
845 };
846
847 LaneBitmask LiveThroughMask;
848 const LiveInterval &LI = LIS.getInterval(Reg);
849 if (LI.hasSubRanges()) {
850 for (auto &SR : LI.subranges()) {
851 if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR))
852 LiveThroughMask |= SR.LaneMask;
853 }
854 } else {
855 LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg);
856 if ((RegMask & Mask) == RegMask && IsInOneSegment(LI))
857 LiveThroughMask = RegMask;
858 }
859
860 return LiveThroughMask;
861 }
862
runOnMachineFunction(MachineFunction & MF)863 bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) {
864 const MachineRegisterInfo &MRI = MF.getRegInfo();
865 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
866 const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
867
868 auto &OS = dbgs();
869
870 // Leading spaces are important for YAML syntax.
871 #define PFX " "
872
873 OS << "---\nname: " << MF.getName() << "\nbody: |\n";
874
875 auto printRP = [](const GCNRegPressure &RP) {
876 return Printable([&RP](raw_ostream &OS) {
877 OS << format(PFX " %-5d", RP.getSGPRNum())
878 << format(" %-5d", RP.getVGPRNum(false));
879 });
880 };
881
882 auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR,
883 const GCNRPTracker::LiveRegSet &LISLR) {
884 if (LISLR != TrackedLR) {
885 OS << PFX " mis LIS: " << llvm::print(LISLR, MRI)
886 << reportMismatch(LISLR, TrackedLR, TRI, PFX " ");
887 }
888 };
889
890 // Register pressure before and at an instruction (in program order).
891 SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP;
892
893 for (auto &MBB : MF) {
894 RP.clear();
895 RP.reserve(MBB.size());
896
897 OS << PFX;
898 MBB.printName(OS);
899 OS << ":\n";
900
901 SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(&MBB);
902 SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(&MBB);
903
904 GCNRPTracker::LiveRegSet LiveIn, LiveOut;
905 GCNRegPressure RPAtMBBEnd;
906
907 if (UseDownwardTracker) {
908 if (MBB.empty()) {
909 LiveIn = LiveOut = getLiveRegs(MBBStartSlot, LIS, MRI);
910 RPAtMBBEnd = getRegPressure(MRI, LiveIn);
911 } else {
912 GCNDownwardRPTracker RPT(LIS);
913 RPT.reset(MBB.front());
914
915 LiveIn = RPT.getLiveRegs();
916
917 while (!RPT.advanceBeforeNext()) {
918 GCNRegPressure RPBeforeMI = RPT.getPressure();
919 RPT.advanceToNext();
920 RP.emplace_back(RPBeforeMI, RPT.getPressure());
921 }
922
923 LiveOut = RPT.getLiveRegs();
924 RPAtMBBEnd = RPT.getPressure();
925 }
926 } else {
927 GCNUpwardRPTracker RPT(LIS);
928 RPT.reset(MRI, MBBEndSlot);
929
930 LiveOut = RPT.getLiveRegs();
931 RPAtMBBEnd = RPT.getPressure();
932
933 for (auto &MI : reverse(MBB)) {
934 RPT.resetMaxPressure();
935 RPT.recede(MI);
936 if (!MI.isDebugInstr())
937 RP.emplace_back(RPT.getPressure(), RPT.getMaxPressure());
938 }
939
940 LiveIn = RPT.getLiveRegs();
941 }
942
943 OS << PFX " Live-in: " << llvm::print(LiveIn, MRI);
944 if (!UseDownwardTracker)
945 ReportLISMismatchIfAny(LiveIn, getLiveRegs(MBBStartSlot, LIS, MRI));
946
947 OS << PFX " SGPR VGPR\n";
948 int I = 0;
949 for (auto &MI : MBB) {
950 if (!MI.isDebugInstr()) {
951 auto &[RPBeforeInstr, RPAtInstr] =
952 RP[UseDownwardTracker ? I : (RP.size() - 1 - I)];
953 ++I;
954 OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " ";
955 } else
956 OS << PFX " ";
957 MI.print(OS);
958 }
959 OS << printRP(RPAtMBBEnd) << '\n';
960
961 OS << PFX " Live-out:" << llvm::print(LiveOut, MRI);
962 if (UseDownwardTracker)
963 ReportLISMismatchIfAny(LiveOut, getLiveRegs(MBBEndSlot, LIS, MRI));
964
965 GCNRPTracker::LiveRegSet LiveThrough;
966 for (auto [Reg, Mask] : LiveIn) {
967 LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Reg);
968 if (MaskIntersection.any()) {
969 LaneBitmask LTMask = getRegLiveThroughMask(
970 MRI, LIS, Reg, MBBStartSlot, MBBEndSlot, MaskIntersection);
971 if (LTMask.any())
972 LiveThrough[Reg] = LTMask;
973 }
974 }
975 OS << PFX " Live-thr:" << llvm::print(LiveThrough, MRI);
976 OS << printRP(getRegPressure(MRI, LiveThrough)) << '\n';
977 }
978 OS << "...\n";
979 return false;
980
981 #undef PFX
982 }
983