1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the GCNRegPressure class. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "GCNRegPressure.h" 15 #include "AMDGPU.h" 16 #include "SIMachineFunctionInfo.h" 17 #include "llvm/CodeGen/RegisterPressure.h" 18 19 using namespace llvm; 20 21 #define DEBUG_TYPE "machine-scheduler" 22 23 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 24 const GCNRPTracker::LiveRegSet &S2) { 25 if (S1.size() != S2.size()) 26 return false; 27 28 for (const auto &P : S1) { 29 auto I = S2.find(P.first); 30 if (I == S2.end() || I->second != P.second) 31 return false; 32 } 33 return true; 34 } 35 36 /////////////////////////////////////////////////////////////////////////////// 37 // GCNRegPressure 38 39 unsigned GCNRegPressure::getRegKind(const TargetRegisterClass *RC, 40 const SIRegisterInfo *STI) { 41 return STI->isSGPRClass(RC) ? SGPR : (STI->isAGPRClass(RC) ? AGPR : VGPR); 42 } 43 44 void GCNRegPressure::inc(unsigned Reg, 45 LaneBitmask PrevMask, 46 LaneBitmask NewMask, 47 const MachineRegisterInfo &MRI) { 48 unsigned NewNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(NewMask); 49 unsigned PrevNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(PrevMask); 50 if (NewNumCoveredRegs == PrevNumCoveredRegs) 51 return; 52 53 int Sign = 1; 54 if (NewMask < PrevMask) { 55 std::swap(NewMask, PrevMask); 56 std::swap(NewNumCoveredRegs, PrevNumCoveredRegs); 57 Sign = -1; 58 } 59 assert(PrevMask < NewMask && PrevNumCoveredRegs < NewNumCoveredRegs && 60 "prev mask should always be lesser than new"); 61 62 const TargetRegisterClass *RC = MRI.getRegClass(Reg); 63 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 64 const SIRegisterInfo *STI = static_cast<const SIRegisterInfo *>(TRI); 65 unsigned RegKind = getRegKind(RC, STI); 66 if (TRI->getRegSizeInBits(*RC) != 32) { 67 // Reg is from a tuple register class. 68 if (PrevMask.none()) { 69 unsigned TupleIdx = TOTAL_KINDS + RegKind; 70 Value[TupleIdx] += Sign * TRI->getRegClassWeight(RC).RegWeight; 71 } 72 // Pressure scales with number of new registers covered by the new mask. 73 // Note when true16 is enabled, we can no longer safely use the following 74 // approach to calculate the difference in the number of 32-bit registers 75 // between two masks: 76 // 77 // Sign *= SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask); 78 // 79 // The issue is that the mask calculation `~PrevMask & NewMask` doesn't 80 // properly account for partial usage of a 32-bit register when dealing with 81 // 16-bit registers. 82 // 83 // Consider this example: 84 // Assume PrevMask = 0b0010 and NewMask = 0b1111. Here, the correct register 85 // usage difference should be 1, because even though PrevMask uses only half 86 // of a 32-bit register, it should still be counted as a full register use. 87 // However, the mask calculation yields `~PrevMask & NewMask = 0b1101`, and 88 // calling `getNumCoveredRegs` returns 2 instead of 1. This incorrect 89 // calculation can lead to integer overflow when Sign = -1. 90 Sign *= NewNumCoveredRegs - PrevNumCoveredRegs; 91 } 92 Value[RegKind] += Sign; 93 } 94 95 bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O, 96 unsigned MaxOccupancy) const { 97 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); 98 unsigned DynamicVGPRBlockSize = 99 MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize(); 100 101 const auto SGPROcc = std::min(MaxOccupancy, 102 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 103 const auto VGPROcc = std::min( 104 MaxOccupancy, ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts()), 105 DynamicVGPRBlockSize)); 106 const auto OtherSGPROcc = std::min(MaxOccupancy, 107 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 108 const auto OtherVGPROcc = 109 std::min(MaxOccupancy, 110 ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts()), 111 DynamicVGPRBlockSize)); 112 113 const auto Occ = std::min(SGPROcc, VGPROcc); 114 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 115 116 // Give first precedence to the better occupancy. 117 if (Occ != OtherOcc) 118 return Occ > OtherOcc; 119 120 unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF); 121 unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF); 122 123 // SGPR excess pressure conditions 124 unsigned ExcessSGPR = std::max(static_cast<int>(getSGPRNum() - MaxSGPRs), 0); 125 unsigned OtherExcessSGPR = 126 std::max(static_cast<int>(O.getSGPRNum() - MaxSGPRs), 0); 127 128 auto WaveSize = ST.getWavefrontSize(); 129 // The number of virtual VGPRs required to handle excess SGPR 130 unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize; 131 unsigned OtherVGPRForSGPRSpills = 132 (OtherExcessSGPR + (WaveSize - 1)) / WaveSize; 133 134 unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs(); 135 136 // Unified excess pressure conditions, accounting for VGPRs used for SGPR 137 // spills 138 unsigned ExcessVGPR = 139 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) + 140 VGPRForSGPRSpills - MaxVGPRs), 141 0); 142 unsigned OtherExcessVGPR = 143 std::max(static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) + 144 OtherVGPRForSGPRSpills - MaxVGPRs), 145 0); 146 // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR 147 // spills 148 unsigned ExcessArchVGPR = std::max( 149 static_cast<int>(getVGPRNum(false) + VGPRForSGPRSpills - MaxArchVGPRs), 150 0); 151 unsigned OtherExcessArchVGPR = 152 std::max(static_cast<int>(O.getVGPRNum(false) + OtherVGPRForSGPRSpills - 153 MaxArchVGPRs), 154 0); 155 // AGPR excess pressure conditions 156 unsigned ExcessAGPR = std::max( 157 static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs) 158 : (getAGPRNum() - MaxVGPRs)), 159 0); 160 unsigned OtherExcessAGPR = std::max( 161 static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs) 162 : (O.getAGPRNum() - MaxVGPRs)), 163 0); 164 165 bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR; 166 bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR || 167 OtherExcessArchVGPR || OtherExcessAGPR; 168 169 // Give second precedence to the reduced number of spills to hold the register 170 // pressure. 171 if (ExcessRP || OtherExcessRP) { 172 // The difference in excess VGPR pressure, after including VGPRs used for 173 // SGPR spills 174 int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) - 175 (ExcessVGPR + ExcessArchVGPR + ExcessAGPR)); 176 177 int SGPRDiff = OtherExcessSGPR - ExcessSGPR; 178 179 if (VGPRDiff != 0) 180 return VGPRDiff > 0; 181 if (SGPRDiff != 0) { 182 unsigned PureExcessVGPR = 183 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs), 184 0) + 185 std::max(static_cast<int>(getVGPRNum(false) - MaxArchVGPRs), 0); 186 unsigned OtherPureExcessVGPR = 187 std::max( 188 static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs), 189 0) + 190 std::max(static_cast<int>(O.getVGPRNum(false) - MaxArchVGPRs), 0); 191 192 // If we have a special case where there is a tie in excess VGPR, but one 193 // of the pressures has VGPR usage from SGPR spills, prefer the pressure 194 // with SGPR spills. 195 if (PureExcessVGPR != OtherPureExcessVGPR) 196 return SGPRDiff < 0; 197 // If both pressures have the same excess pressure before and after 198 // accounting for SGPR spills, prefer fewer SGPR spills. 199 return SGPRDiff > 0; 200 } 201 } 202 203 bool SGPRImportant = SGPROcc < VGPROcc; 204 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 205 206 // If both pressures disagree on what is more important compare vgprs. 207 if (SGPRImportant != OtherSGPRImportant) { 208 SGPRImportant = false; 209 } 210 211 // Give third precedence to lower register tuple pressure. 212 bool SGPRFirst = SGPRImportant; 213 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 214 if (SGPRFirst) { 215 auto SW = getSGPRTuplesWeight(); 216 auto OtherSW = O.getSGPRTuplesWeight(); 217 if (SW != OtherSW) 218 return SW < OtherSW; 219 } else { 220 auto VW = getVGPRTuplesWeight(); 221 auto OtherVW = O.getVGPRTuplesWeight(); 222 if (VW != OtherVW) 223 return VW < OtherVW; 224 } 225 } 226 227 // Give final precedence to lower general RP. 228 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 229 (getVGPRNum(ST.hasGFX90AInsts()) < 230 O.getVGPRNum(ST.hasGFX90AInsts())); 231 } 232 233 Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST, 234 unsigned DynamicVGPRBlockSize) { 235 return Printable([&RP, ST, DynamicVGPRBlockSize](raw_ostream &OS) { 236 OS << "VGPRs: " << RP.getArchVGPRNum() << ' ' 237 << "AGPRs: " << RP.getAGPRNum(); 238 if (ST) 239 OS << "(O" 240 << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts()), 241 DynamicVGPRBlockSize) 242 << ')'; 243 OS << ", SGPRs: " << RP.getSGPRNum(); 244 if (ST) 245 OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')'; 246 OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight() 247 << ", LSGPR WT: " << RP.getSGPRTuplesWeight(); 248 if (ST) 249 OS << " -> Occ: " << RP.getOccupancy(*ST, DynamicVGPRBlockSize); 250 OS << '\n'; 251 }); 252 } 253 254 static LaneBitmask getDefRegMask(const MachineOperand &MO, 255 const MachineRegisterInfo &MRI) { 256 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual()); 257 258 // We don't rely on read-undef flag because in case of tentative schedule 259 // tracking it isn't set correctly yet. This works correctly however since 260 // use mask has been tracked before using LIS. 261 return MO.getSubReg() == 0 ? 262 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 263 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 264 } 265 266 static void 267 collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits, 268 const MachineInstr &MI, const LiveIntervals &LIS, 269 const MachineRegisterInfo &MRI) { 270 271 auto &TRI = *MRI.getTargetRegisterInfo(); 272 for (const auto &MO : MI.operands()) { 273 if (!MO.isReg() || !MO.getReg().isVirtual()) 274 continue; 275 if (!MO.isUse() || !MO.readsReg()) 276 continue; 277 278 Register Reg = MO.getReg(); 279 auto I = llvm::find_if(VRegMaskOrUnits, [Reg](const VRegMaskOrUnit &RM) { 280 return RM.RegUnit == Reg; 281 }); 282 283 auto &P = I == VRegMaskOrUnits.end() 284 ? VRegMaskOrUnits.emplace_back(Reg, LaneBitmask::getNone()) 285 : *I; 286 287 P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(MO.getSubReg()) 288 : MRI.getMaxLaneMaskForVReg(Reg); 289 } 290 291 SlotIndex InstrSI; 292 for (auto &P : VRegMaskOrUnits) { 293 auto &LI = LIS.getInterval(P.RegUnit); 294 if (!LI.hasSubRanges()) 295 continue; 296 297 // For a tentative schedule LIS isn't updated yet but livemask should 298 // remain the same on any schedule. Subreg defs can be reordered but they 299 // all must dominate uses anyway. 300 if (!InstrSI) 301 InstrSI = LIS.getInstructionIndex(MI).getBaseIndex(); 302 303 P.LaneMask = getLiveLaneMask(LI, InstrSI, MRI, P.LaneMask); 304 } 305 } 306 307 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp 308 static LaneBitmask getLanesWithProperty( 309 const LiveIntervals &LIS, const MachineRegisterInfo &MRI, 310 bool TrackLaneMasks, Register RegUnit, SlotIndex Pos, 311 LaneBitmask SafeDefault, 312 function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) { 313 if (RegUnit.isVirtual()) { 314 const LiveInterval &LI = LIS.getInterval(RegUnit); 315 LaneBitmask Result; 316 if (TrackLaneMasks && LI.hasSubRanges()) { 317 for (const LiveInterval::SubRange &SR : LI.subranges()) { 318 if (Property(SR, Pos)) 319 Result |= SR.LaneMask; 320 } 321 } else if (Property(LI, Pos)) { 322 Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit) 323 : LaneBitmask::getAll(); 324 } 325 326 return Result; 327 } 328 329 const LiveRange *LR = LIS.getCachedRegUnit(RegUnit); 330 if (LR == nullptr) 331 return SafeDefault; 332 return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone(); 333 } 334 335 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp 336 /// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}. 337 /// The query starts with a lane bitmask which gets lanes/bits removed for every 338 /// use we find. 339 static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask, 340 SlotIndex PriorUseIdx, SlotIndex NextUseIdx, 341 const MachineRegisterInfo &MRI, 342 const SIRegisterInfo *TRI, 343 const LiveIntervals *LIS, 344 bool Upward = false) { 345 for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) { 346 if (MO.isUndef()) 347 continue; 348 const MachineInstr *MI = MO.getParent(); 349 SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot(); 350 bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx) 351 : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx); 352 if (!InRange) 353 continue; 354 355 unsigned SubRegIdx = MO.getSubReg(); 356 LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubRegIdx); 357 LastUseMask &= ~UseMask; 358 if (LastUseMask.none()) 359 return LaneBitmask::getNone(); 360 } 361 return LastUseMask; 362 } 363 364 //////////////////////////////////////////////////////////////////////////////// 365 // GCNRPTarget 366 367 GCNRPTarget::GCNRPTarget(const MachineFunction &MF, const GCNRegPressure &RP, 368 bool CombineVGPRSavings) 369 : RP(RP), CombineVGPRSavings(CombineVGPRSavings) { 370 const Function &F = MF.getFunction(); 371 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); 372 setRegLimits(ST.getMaxNumSGPRs(F), ST.getMaxNumVGPRs(F), MF); 373 } 374 375 GCNRPTarget::GCNRPTarget(unsigned NumSGPRs, unsigned NumVGPRs, 376 const MachineFunction &MF, const GCNRegPressure &RP, 377 bool CombineVGPRSavings) 378 : RP(RP), CombineVGPRSavings(CombineVGPRSavings) { 379 setRegLimits(NumSGPRs, NumVGPRs, MF); 380 } 381 382 GCNRPTarget::GCNRPTarget(unsigned Occupancy, const MachineFunction &MF, 383 const GCNRegPressure &RP, bool CombineVGPRSavings) 384 : RP(RP), CombineVGPRSavings(CombineVGPRSavings) { 385 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); 386 unsigned DynamicVGPRBlockSize = 387 MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize(); 388 setRegLimits(ST.getMaxNumSGPRs(Occupancy, /*Addressable=*/false), 389 ST.getMaxNumVGPRs(Occupancy, DynamicVGPRBlockSize), MF); 390 } 391 392 void GCNRPTarget::setRegLimits(unsigned NumSGPRs, unsigned NumVGPRs, 393 const MachineFunction &MF) { 394 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); 395 unsigned DynamicVGPRBlockSize = 396 MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize(); 397 MaxSGPRs = std::min(ST.getAddressableNumSGPRs(), NumSGPRs); 398 MaxVGPRs = std::min(ST.getAddressableNumArchVGPRs(), NumVGPRs); 399 MaxUnifiedVGPRs = 400 ST.hasGFX90AInsts() 401 ? std::min(ST.getAddressableNumVGPRs(DynamicVGPRBlockSize), NumVGPRs) 402 : 0; 403 } 404 405 bool GCNRPTarget::isSaveBeneficial(Register Reg, 406 const MachineRegisterInfo &MRI) const { 407 const TargetRegisterClass *RC = MRI.getRegClass(Reg); 408 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 409 const SIRegisterInfo *SRI = static_cast<const SIRegisterInfo *>(TRI); 410 411 if (SRI->isSGPRClass(RC)) 412 return RP.getSGPRNum() > MaxSGPRs; 413 unsigned NumVGPRs = 414 SRI->isAGPRClass(RC) ? RP.getAGPRNum() : RP.getArchVGPRNum(); 415 return isVGPRBankSaveBeneficial(NumVGPRs); 416 } 417 418 bool GCNRPTarget::satisfied() const { 419 if (RP.getSGPRNum() > MaxSGPRs) 420 return false; 421 if (RP.getVGPRNum(false) > MaxVGPRs && 422 (!CombineVGPRSavings || !satisifiesVGPRBanksTarget())) 423 return false; 424 return satisfiesUnifiedTarget(); 425 } 426 427 /////////////////////////////////////////////////////////////////////////////// 428 // GCNRPTracker 429 430 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI, 431 const LiveIntervals &LIS, 432 const MachineRegisterInfo &MRI, 433 LaneBitmask LaneMaskFilter) { 434 return getLiveLaneMask(LIS.getInterval(Reg), SI, MRI, LaneMaskFilter); 435 } 436 437 LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI, 438 const MachineRegisterInfo &MRI, 439 LaneBitmask LaneMaskFilter) { 440 LaneBitmask LiveMask; 441 if (LI.hasSubRanges()) { 442 for (const auto &S : LI.subranges()) 443 if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(SI)) { 444 LiveMask |= S.LaneMask; 445 assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg()))); 446 } 447 } else if (LI.liveAt(SI)) { 448 LiveMask = MRI.getMaxLaneMaskForVReg(LI.reg()); 449 } 450 LiveMask &= LaneMaskFilter; 451 return LiveMask; 452 } 453 454 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 455 const LiveIntervals &LIS, 456 const MachineRegisterInfo &MRI) { 457 GCNRPTracker::LiveRegSet LiveRegs; 458 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 459 auto Reg = Register::index2VirtReg(I); 460 if (!LIS.hasInterval(Reg)) 461 continue; 462 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 463 if (LiveMask.any()) 464 LiveRegs[Reg] = LiveMask; 465 } 466 return LiveRegs; 467 } 468 469 void GCNRPTracker::reset(const MachineInstr &MI, 470 const LiveRegSet *LiveRegsCopy, 471 bool After) { 472 const MachineFunction &MF = *MI.getMF(); 473 MRI = &MF.getRegInfo(); 474 if (LiveRegsCopy) { 475 if (&LiveRegs != LiveRegsCopy) 476 LiveRegs = *LiveRegsCopy; 477 } else { 478 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 479 : getLiveRegsBefore(MI, LIS); 480 } 481 482 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 483 } 484 485 void GCNRPTracker::reset(const MachineRegisterInfo &MRI_, 486 const LiveRegSet &LiveRegs_) { 487 MRI = &MRI_; 488 LiveRegs = LiveRegs_; 489 LastTrackedMI = nullptr; 490 MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_); 491 } 492 493 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp 494 LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit, 495 SlotIndex Pos) const { 496 return getLanesWithProperty( 497 LIS, *MRI, true, RegUnit, Pos.getBaseIndex(), LaneBitmask::getNone(), 498 [](const LiveRange &LR, SlotIndex Pos) { 499 const LiveRange::Segment *S = LR.getSegmentContaining(Pos); 500 return S != nullptr && S->end == Pos.getRegSlot(); 501 }); 502 } 503 504 //////////////////////////////////////////////////////////////////////////////// 505 // GCNUpwardRPTracker 506 507 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 508 assert(MRI && "call reset first"); 509 510 LastTrackedMI = &MI; 511 512 if (MI.isDebugInstr()) 513 return; 514 515 // Kill all defs. 516 GCNRegPressure DefPressure, ECDefPressure; 517 bool HasECDefs = false; 518 for (const MachineOperand &MO : MI.all_defs()) { 519 if (!MO.getReg().isVirtual()) 520 continue; 521 522 Register Reg = MO.getReg(); 523 LaneBitmask DefMask = getDefRegMask(MO, *MRI); 524 525 // Treat a def as fully live at the moment of definition: keep a record. 526 if (MO.isEarlyClobber()) { 527 ECDefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI); 528 HasECDefs = true; 529 } else 530 DefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI); 531 532 auto I = LiveRegs.find(Reg); 533 if (I == LiveRegs.end()) 534 continue; 535 536 LaneBitmask &LiveMask = I->second; 537 LaneBitmask PrevMask = LiveMask; 538 LiveMask &= ~DefMask; 539 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 540 if (LiveMask.none()) 541 LiveRegs.erase(I); 542 } 543 544 // Update MaxPressure with defs pressure. 545 DefPressure += CurPressure; 546 if (HasECDefs) 547 DefPressure += ECDefPressure; 548 MaxPressure = max(DefPressure, MaxPressure); 549 550 // Make uses alive. 551 SmallVector<VRegMaskOrUnit, 8> RegUses; 552 collectVirtualRegUses(RegUses, MI, LIS, *MRI); 553 for (const VRegMaskOrUnit &U : RegUses) { 554 LaneBitmask &LiveMask = LiveRegs[U.RegUnit]; 555 LaneBitmask PrevMask = LiveMask; 556 LiveMask |= U.LaneMask; 557 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 558 } 559 560 // Update MaxPressure with uses plus early-clobber defs pressure. 561 MaxPressure = HasECDefs ? max(CurPressure + ECDefPressure, MaxPressure) 562 : max(CurPressure, MaxPressure); 563 564 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 565 } 566 567 //////////////////////////////////////////////////////////////////////////////// 568 // GCNDownwardRPTracker 569 570 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 571 const LiveRegSet *LiveRegsCopy) { 572 MRI = &MI.getParent()->getParent()->getRegInfo(); 573 LastTrackedMI = nullptr; 574 MBBEnd = MI.getParent()->end(); 575 NextMI = &MI; 576 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 577 if (NextMI == MBBEnd) 578 return false; 579 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 580 return true; 581 } 582 583 bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI, 584 bool UseInternalIterator) { 585 assert(MRI && "call reset first"); 586 SlotIndex SI; 587 const MachineInstr *CurrMI; 588 if (UseInternalIterator) { 589 if (!LastTrackedMI) 590 return NextMI == MBBEnd; 591 592 assert(NextMI == MBBEnd || !NextMI->isDebugInstr()); 593 CurrMI = LastTrackedMI; 594 595 SI = NextMI == MBBEnd 596 ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot() 597 : LIS.getInstructionIndex(*NextMI).getBaseIndex(); 598 } else { //! UseInternalIterator 599 SI = LIS.getInstructionIndex(*MI).getBaseIndex(); 600 CurrMI = MI; 601 } 602 603 assert(SI.isValid()); 604 605 // Remove dead registers or mask bits. 606 SmallSet<Register, 8> SeenRegs; 607 for (auto &MO : CurrMI->operands()) { 608 if (!MO.isReg() || !MO.getReg().isVirtual()) 609 continue; 610 if (MO.isUse() && !MO.readsReg()) 611 continue; 612 if (!UseInternalIterator && MO.isDef()) 613 continue; 614 if (!SeenRegs.insert(MO.getReg()).second) 615 continue; 616 const LiveInterval &LI = LIS.getInterval(MO.getReg()); 617 if (LI.hasSubRanges()) { 618 auto It = LiveRegs.end(); 619 for (const auto &S : LI.subranges()) { 620 if (!S.liveAt(SI)) { 621 if (It == LiveRegs.end()) { 622 It = LiveRegs.find(MO.getReg()); 623 if (It == LiveRegs.end()) 624 llvm_unreachable("register isn't live"); 625 } 626 auto PrevMask = It->second; 627 It->second &= ~S.LaneMask; 628 CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI); 629 } 630 } 631 if (It != LiveRegs.end() && It->second.none()) 632 LiveRegs.erase(It); 633 } else if (!LI.liveAt(SI)) { 634 auto It = LiveRegs.find(MO.getReg()); 635 if (It == LiveRegs.end()) 636 llvm_unreachable("register isn't live"); 637 CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI); 638 LiveRegs.erase(It); 639 } 640 } 641 642 MaxPressure = max(MaxPressure, CurPressure); 643 644 LastTrackedMI = nullptr; 645 646 return UseInternalIterator && (NextMI == MBBEnd); 647 } 648 649 void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI, 650 bool UseInternalIterator) { 651 if (UseInternalIterator) { 652 LastTrackedMI = &*NextMI++; 653 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 654 } else { 655 LastTrackedMI = MI; 656 } 657 658 const MachineInstr *CurrMI = LastTrackedMI; 659 660 // Add new registers or mask bits. 661 for (const auto &MO : CurrMI->all_defs()) { 662 Register Reg = MO.getReg(); 663 if (!Reg.isVirtual()) 664 continue; 665 auto &LiveMask = LiveRegs[Reg]; 666 auto PrevMask = LiveMask; 667 LiveMask |= getDefRegMask(MO, *MRI); 668 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 669 } 670 671 MaxPressure = max(MaxPressure, CurPressure); 672 } 673 674 bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) { 675 if (UseInternalIterator && NextMI == MBBEnd) 676 return false; 677 678 advanceBeforeNext(MI, UseInternalIterator); 679 advanceToNext(MI, UseInternalIterator); 680 if (!UseInternalIterator) { 681 // We must remove any dead def lanes from the current RP 682 advanceBeforeNext(MI, true); 683 } 684 return true; 685 } 686 687 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 688 while (NextMI != End) 689 if (!advance()) return false; 690 return true; 691 } 692 693 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 694 MachineBasicBlock::const_iterator End, 695 const LiveRegSet *LiveRegsCopy) { 696 reset(*Begin, LiveRegsCopy); 697 return advance(End); 698 } 699 700 Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 701 const GCNRPTracker::LiveRegSet &TrackedLR, 702 const TargetRegisterInfo *TRI, StringRef Pfx) { 703 return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) { 704 for (auto const &P : TrackedLR) { 705 auto I = LISLR.find(P.first); 706 if (I == LISLR.end()) { 707 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second) 708 << " isn't found in LIS reported set\n"; 709 } else if (I->second != P.second) { 710 OS << Pfx << printReg(P.first, TRI) 711 << " masks doesn't match: LIS reported " << PrintLaneMask(I->second) 712 << ", tracked " << PrintLaneMask(P.second) << '\n'; 713 } 714 } 715 for (auto const &P : LISLR) { 716 auto I = TrackedLR.find(P.first); 717 if (I == TrackedLR.end()) { 718 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second) 719 << " isn't found in tracked set\n"; 720 } 721 } 722 }); 723 } 724 725 GCNRegPressure 726 GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI, 727 const SIRegisterInfo *TRI) const { 728 assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction."); 729 730 SlotIndex SlotIdx; 731 SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot(); 732 733 // Account for register pressure similar to RegPressureTracker::recede(). 734 RegisterOperands RegOpers; 735 RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/false); 736 RegOpers.adjustLaneLiveness(LIS, *MRI, SlotIdx); 737 GCNRegPressure TempPressure = CurPressure; 738 739 for (const VRegMaskOrUnit &Use : RegOpers.Uses) { 740 Register Reg = Use.RegUnit; 741 if (!Reg.isVirtual()) 742 continue; 743 LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx); 744 if (LastUseMask.none()) 745 continue; 746 // The LastUseMask is queried from the liveness information of instruction 747 // which may be further down the schedule. Some lanes may actually not be 748 // last uses for the current position. 749 // FIXME: allow the caller to pass in the list of vreg uses that remain 750 // to be bottom-scheduled to avoid searching uses at each query. 751 SlotIndex CurrIdx; 752 const MachineBasicBlock *MBB = MI->getParent(); 753 MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward( 754 LastTrackedMI ? LastTrackedMI : MBB->begin(), MBB->end()); 755 if (IdxPos == MBB->end()) { 756 CurrIdx = LIS.getMBBEndIdx(MBB); 757 } else { 758 CurrIdx = LIS.getInstructionIndex(*IdxPos).getRegSlot(); 759 } 760 761 LastUseMask = 762 findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS); 763 if (LastUseMask.none()) 764 continue; 765 766 auto It = LiveRegs.find(Reg); 767 LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0); 768 LaneBitmask NewMask = LiveMask & ~LastUseMask; 769 TempPressure.inc(Reg, LiveMask, NewMask, *MRI); 770 } 771 772 // Generate liveness for defs. 773 for (const VRegMaskOrUnit &Def : RegOpers.Defs) { 774 Register Reg = Def.RegUnit; 775 if (!Reg.isVirtual()) 776 continue; 777 auto It = LiveRegs.find(Reg); 778 LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0); 779 LaneBitmask NewMask = LiveMask | Def.LaneMask; 780 TempPressure.inc(Reg, LiveMask, NewMask, *MRI); 781 } 782 783 return TempPressure; 784 } 785 786 bool GCNUpwardRPTracker::isValid() const { 787 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 788 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 789 const auto &TrackedLR = LiveRegs; 790 791 if (!isEqual(LISLR, TrackedLR)) { 792 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 793 " LIS reported livesets mismatch:\n" 794 << print(LISLR, *MRI); 795 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 796 return false; 797 } 798 799 auto LISPressure = getRegPressure(*MRI, LISLR); 800 if (LISPressure != CurPressure) { 801 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: " 802 << print(CurPressure) << "LIS rpt: " << print(LISPressure); 803 return false; 804 } 805 return true; 806 } 807 808 Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs, 809 const MachineRegisterInfo &MRI) { 810 return Printable([&LiveRegs, &MRI](raw_ostream &OS) { 811 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 812 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 813 Register Reg = Register::index2VirtReg(I); 814 auto It = LiveRegs.find(Reg); 815 if (It != LiveRegs.end() && It->second.any()) 816 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 817 << PrintLaneMask(It->second); 818 } 819 OS << '\n'; 820 }); 821 } 822 823 void GCNRegPressure::dump() const { dbgs() << print(*this); } 824 825 static cl::opt<bool> UseDownwardTracker( 826 "amdgpu-print-rp-downward", 827 cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"), 828 cl::init(false), cl::Hidden); 829 830 char llvm::GCNRegPressurePrinter::ID = 0; 831 char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID; 832 833 INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true) 834 835 // Return lanemask of Reg's subregs that are live-through at [Begin, End] and 836 // are fully covered by Mask. 837 static LaneBitmask 838 getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS, 839 Register Reg, SlotIndex Begin, SlotIndex End, 840 LaneBitmask Mask = LaneBitmask::getAll()) { 841 842 auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool { 843 auto *Segment = LR.getSegmentContaining(Begin); 844 return Segment && Segment->contains(End); 845 }; 846 847 LaneBitmask LiveThroughMask; 848 const LiveInterval &LI = LIS.getInterval(Reg); 849 if (LI.hasSubRanges()) { 850 for (auto &SR : LI.subranges()) { 851 if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR)) 852 LiveThroughMask |= SR.LaneMask; 853 } 854 } else { 855 LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg); 856 if ((RegMask & Mask) == RegMask && IsInOneSegment(LI)) 857 LiveThroughMask = RegMask; 858 } 859 860 return LiveThroughMask; 861 } 862 863 bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) { 864 const MachineRegisterInfo &MRI = MF.getRegInfo(); 865 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 866 const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); 867 868 auto &OS = dbgs(); 869 870 // Leading spaces are important for YAML syntax. 871 #define PFX " " 872 873 OS << "---\nname: " << MF.getName() << "\nbody: |\n"; 874 875 auto printRP = [](const GCNRegPressure &RP) { 876 return Printable([&RP](raw_ostream &OS) { 877 OS << format(PFX " %-5d", RP.getSGPRNum()) 878 << format(" %-5d", RP.getVGPRNum(false)); 879 }); 880 }; 881 882 auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR, 883 const GCNRPTracker::LiveRegSet &LISLR) { 884 if (LISLR != TrackedLR) { 885 OS << PFX " mis LIS: " << llvm::print(LISLR, MRI) 886 << reportMismatch(LISLR, TrackedLR, TRI, PFX " "); 887 } 888 }; 889 890 // Register pressure before and at an instruction (in program order). 891 SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP; 892 893 for (auto &MBB : MF) { 894 RP.clear(); 895 RP.reserve(MBB.size()); 896 897 OS << PFX; 898 MBB.printName(OS); 899 OS << ":\n"; 900 901 SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(&MBB); 902 SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(&MBB); 903 904 GCNRPTracker::LiveRegSet LiveIn, LiveOut; 905 GCNRegPressure RPAtMBBEnd; 906 907 if (UseDownwardTracker) { 908 if (MBB.empty()) { 909 LiveIn = LiveOut = getLiveRegs(MBBStartSlot, LIS, MRI); 910 RPAtMBBEnd = getRegPressure(MRI, LiveIn); 911 } else { 912 GCNDownwardRPTracker RPT(LIS); 913 RPT.reset(MBB.front()); 914 915 LiveIn = RPT.getLiveRegs(); 916 917 while (!RPT.advanceBeforeNext()) { 918 GCNRegPressure RPBeforeMI = RPT.getPressure(); 919 RPT.advanceToNext(); 920 RP.emplace_back(RPBeforeMI, RPT.getPressure()); 921 } 922 923 LiveOut = RPT.getLiveRegs(); 924 RPAtMBBEnd = RPT.getPressure(); 925 } 926 } else { 927 GCNUpwardRPTracker RPT(LIS); 928 RPT.reset(MRI, MBBEndSlot); 929 930 LiveOut = RPT.getLiveRegs(); 931 RPAtMBBEnd = RPT.getPressure(); 932 933 for (auto &MI : reverse(MBB)) { 934 RPT.resetMaxPressure(); 935 RPT.recede(MI); 936 if (!MI.isDebugInstr()) 937 RP.emplace_back(RPT.getPressure(), RPT.getMaxPressure()); 938 } 939 940 LiveIn = RPT.getLiveRegs(); 941 } 942 943 OS << PFX " Live-in: " << llvm::print(LiveIn, MRI); 944 if (!UseDownwardTracker) 945 ReportLISMismatchIfAny(LiveIn, getLiveRegs(MBBStartSlot, LIS, MRI)); 946 947 OS << PFX " SGPR VGPR\n"; 948 int I = 0; 949 for (auto &MI : MBB) { 950 if (!MI.isDebugInstr()) { 951 auto &[RPBeforeInstr, RPAtInstr] = 952 RP[UseDownwardTracker ? I : (RP.size() - 1 - I)]; 953 ++I; 954 OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " "; 955 } else 956 OS << PFX " "; 957 MI.print(OS); 958 } 959 OS << printRP(RPAtMBBEnd) << '\n'; 960 961 OS << PFX " Live-out:" << llvm::print(LiveOut, MRI); 962 if (UseDownwardTracker) 963 ReportLISMismatchIfAny(LiveOut, getLiveRegs(MBBEndSlot, LIS, MRI)); 964 965 GCNRPTracker::LiveRegSet LiveThrough; 966 for (auto [Reg, Mask] : LiveIn) { 967 LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Reg); 968 if (MaskIntersection.any()) { 969 LaneBitmask LTMask = getRegLiveThroughMask( 970 MRI, LIS, Reg, MBBStartSlot, MBBEndSlot, MaskIntersection); 971 if (LTMask.any()) 972 LiveThrough[Reg] = LTMask; 973 } 974 } 975 OS << PFX " Live-thr:" << llvm::print(LiveThrough, MRI); 976 OS << printRP(getRegPressure(MRI, LiveThrough)) << '\n'; 977 } 978 OS << "...\n"; 979 return false; 980 981 #undef PFX 982 } 983