1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the GCNRegPressure class. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "GCNRegPressure.h" 15 #include "AMDGPU.h" 16 #include "llvm/CodeGen/RegisterPressure.h" 17 18 using namespace llvm; 19 20 #define DEBUG_TYPE "machine-scheduler" 21 22 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 23 const GCNRPTracker::LiveRegSet &S2) { 24 if (S1.size() != S2.size()) 25 return false; 26 27 for (const auto &P : S1) { 28 auto I = S2.find(P.first); 29 if (I == S2.end() || I->second != P.second) 30 return false; 31 } 32 return true; 33 } 34 35 /////////////////////////////////////////////////////////////////////////////// 36 // GCNRegPressure 37 38 unsigned GCNRegPressure::getRegKind(Register Reg, 39 const MachineRegisterInfo &MRI) { 40 assert(Reg.isVirtual()); 41 const auto RC = MRI.getRegClass(Reg); 42 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo()); 43 return STI->isSGPRClass(RC) 44 ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) 45 : STI->isAGPRClass(RC) 46 ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) 47 : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 48 } 49 50 void GCNRegPressure::inc(unsigned Reg, 51 LaneBitmask PrevMask, 52 LaneBitmask NewMask, 53 const MachineRegisterInfo &MRI) { 54 if (SIRegisterInfo::getNumCoveredRegs(NewMask) == 55 SIRegisterInfo::getNumCoveredRegs(PrevMask)) 56 return; 57 58 int Sign = 1; 59 if (NewMask < PrevMask) { 60 std::swap(NewMask, PrevMask); 61 Sign = -1; 62 } 63 64 switch (auto Kind = getRegKind(Reg, MRI)) { 65 case SGPR32: 66 case VGPR32: 67 case AGPR32: 68 Value[Kind] += Sign; 69 break; 70 71 case SGPR_TUPLE: 72 case VGPR_TUPLE: 73 case AGPR_TUPLE: 74 assert(PrevMask < NewMask); 75 76 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] += 77 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask); 78 79 if (PrevMask.none()) { 80 assert(NewMask.any()); 81 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 82 Value[Kind] += 83 Sign * TRI->getRegClassWeight(MRI.getRegClass(Reg)).RegWeight; 84 } 85 break; 86 87 default: llvm_unreachable("Unknown register kind"); 88 } 89 } 90 91 bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O, 92 unsigned MaxOccupancy) const { 93 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); 94 95 const auto SGPROcc = std::min(MaxOccupancy, 96 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 97 const auto VGPROcc = 98 std::min(MaxOccupancy, 99 ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts()))); 100 const auto OtherSGPROcc = std::min(MaxOccupancy, 101 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 102 const auto OtherVGPROcc = 103 std::min(MaxOccupancy, 104 ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts()))); 105 106 const auto Occ = std::min(SGPROcc, VGPROcc); 107 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 108 109 // Give first precedence to the better occupancy. 110 if (Occ != OtherOcc) 111 return Occ > OtherOcc; 112 113 unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF); 114 unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF); 115 116 // SGPR excess pressure conditions 117 unsigned ExcessSGPR = std::max(static_cast<int>(getSGPRNum() - MaxSGPRs), 0); 118 unsigned OtherExcessSGPR = 119 std::max(static_cast<int>(O.getSGPRNum() - MaxSGPRs), 0); 120 121 auto WaveSize = ST.getWavefrontSize(); 122 // The number of virtual VGPRs required to handle excess SGPR 123 unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize; 124 unsigned OtherVGPRForSGPRSpills = 125 (OtherExcessSGPR + (WaveSize - 1)) / WaveSize; 126 127 unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs(); 128 129 // Unified excess pressure conditions, accounting for VGPRs used for SGPR 130 // spills 131 unsigned ExcessVGPR = 132 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) + 133 VGPRForSGPRSpills - MaxVGPRs), 134 0); 135 unsigned OtherExcessVGPR = 136 std::max(static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) + 137 OtherVGPRForSGPRSpills - MaxVGPRs), 138 0); 139 // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR 140 // spills 141 unsigned ExcessArchVGPR = std::max( 142 static_cast<int>(getVGPRNum(false) + VGPRForSGPRSpills - MaxArchVGPRs), 143 0); 144 unsigned OtherExcessArchVGPR = 145 std::max(static_cast<int>(O.getVGPRNum(false) + OtherVGPRForSGPRSpills - 146 MaxArchVGPRs), 147 0); 148 // AGPR excess pressure conditions 149 unsigned ExcessAGPR = std::max( 150 static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs) 151 : (getAGPRNum() - MaxVGPRs)), 152 0); 153 unsigned OtherExcessAGPR = std::max( 154 static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs) 155 : (O.getAGPRNum() - MaxVGPRs)), 156 0); 157 158 bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR; 159 bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR || 160 OtherExcessArchVGPR || OtherExcessAGPR; 161 162 // Give second precedence to the reduced number of spills to hold the register 163 // pressure. 164 if (ExcessRP || OtherExcessRP) { 165 // The difference in excess VGPR pressure, after including VGPRs used for 166 // SGPR spills 167 int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) - 168 (ExcessVGPR + ExcessArchVGPR + ExcessAGPR)); 169 170 int SGPRDiff = OtherExcessSGPR - ExcessSGPR; 171 172 if (VGPRDiff != 0) 173 return VGPRDiff > 0; 174 if (SGPRDiff != 0) { 175 unsigned PureExcessVGPR = 176 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs), 177 0) + 178 std::max(static_cast<int>(getVGPRNum(false) - MaxArchVGPRs), 0); 179 unsigned OtherPureExcessVGPR = 180 std::max( 181 static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs), 182 0) + 183 std::max(static_cast<int>(O.getVGPRNum(false) - MaxArchVGPRs), 0); 184 185 // If we have a special case where there is a tie in excess VGPR, but one 186 // of the pressures has VGPR usage from SGPR spills, prefer the pressure 187 // with SGPR spills. 188 if (PureExcessVGPR != OtherPureExcessVGPR) 189 return SGPRDiff < 0; 190 // If both pressures have the same excess pressure before and after 191 // accounting for SGPR spills, prefer fewer SGPR spills. 192 return SGPRDiff > 0; 193 } 194 } 195 196 bool SGPRImportant = SGPROcc < VGPROcc; 197 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 198 199 // If both pressures disagree on what is more important compare vgprs. 200 if (SGPRImportant != OtherSGPRImportant) { 201 SGPRImportant = false; 202 } 203 204 // Give third precedence to lower register tuple pressure. 205 bool SGPRFirst = SGPRImportant; 206 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 207 if (SGPRFirst) { 208 auto SW = getSGPRTuplesWeight(); 209 auto OtherSW = O.getSGPRTuplesWeight(); 210 if (SW != OtherSW) 211 return SW < OtherSW; 212 } else { 213 auto VW = getVGPRTuplesWeight(); 214 auto OtherVW = O.getVGPRTuplesWeight(); 215 if (VW != OtherVW) 216 return VW < OtherVW; 217 } 218 } 219 220 // Give final precedence to lower general RP. 221 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 222 (getVGPRNum(ST.hasGFX90AInsts()) < 223 O.getVGPRNum(ST.hasGFX90AInsts())); 224 } 225 226 Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) { 227 return Printable([&RP, ST](raw_ostream &OS) { 228 OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' ' 229 << "AGPRs: " << RP.getAGPRNum(); 230 if (ST) 231 OS << "(O" 232 << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts())) 233 << ')'; 234 OS << ", SGPRs: " << RP.getSGPRNum(); 235 if (ST) 236 OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')'; 237 OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight() 238 << ", LSGPR WT: " << RP.getSGPRTuplesWeight(); 239 if (ST) 240 OS << " -> Occ: " << RP.getOccupancy(*ST); 241 OS << '\n'; 242 }); 243 } 244 245 static LaneBitmask getDefRegMask(const MachineOperand &MO, 246 const MachineRegisterInfo &MRI) { 247 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual()); 248 249 // We don't rely on read-undef flag because in case of tentative schedule 250 // tracking it isn't set correctly yet. This works correctly however since 251 // use mask has been tracked before using LIS. 252 return MO.getSubReg() == 0 ? 253 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 254 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 255 } 256 257 static void 258 collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs, 259 const MachineInstr &MI, const LiveIntervals &LIS, 260 const MachineRegisterInfo &MRI) { 261 SlotIndex InstrSI; 262 for (const auto &MO : MI.operands()) { 263 if (!MO.isReg() || !MO.getReg().isVirtual()) 264 continue; 265 if (!MO.isUse() || !MO.readsReg()) 266 continue; 267 268 Register Reg = MO.getReg(); 269 if (llvm::any_of(RegMaskPairs, [Reg](const RegisterMaskPair &RM) { 270 return RM.RegUnit == Reg; 271 })) 272 continue; 273 274 LaneBitmask UseMask; 275 auto &LI = LIS.getInterval(Reg); 276 if (!LI.hasSubRanges()) 277 UseMask = MRI.getMaxLaneMaskForVReg(Reg); 278 else { 279 // For a tentative schedule LIS isn't updated yet but livemask should 280 // remain the same on any schedule. Subreg defs can be reordered but they 281 // all must dominate uses anyway. 282 if (!InstrSI) 283 InstrSI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex(); 284 UseMask = getLiveLaneMask(LI, InstrSI, MRI); 285 } 286 287 RegMaskPairs.emplace_back(Reg, UseMask); 288 } 289 } 290 291 /////////////////////////////////////////////////////////////////////////////// 292 // GCNRPTracker 293 294 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI, 295 const LiveIntervals &LIS, 296 const MachineRegisterInfo &MRI) { 297 return getLiveLaneMask(LIS.getInterval(Reg), SI, MRI); 298 } 299 300 LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI, 301 const MachineRegisterInfo &MRI) { 302 LaneBitmask LiveMask; 303 if (LI.hasSubRanges()) { 304 for (const auto &S : LI.subranges()) 305 if (S.liveAt(SI)) { 306 LiveMask |= S.LaneMask; 307 assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg()))); 308 } 309 } else if (LI.liveAt(SI)) { 310 LiveMask = MRI.getMaxLaneMaskForVReg(LI.reg()); 311 } 312 return LiveMask; 313 } 314 315 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 316 const LiveIntervals &LIS, 317 const MachineRegisterInfo &MRI) { 318 GCNRPTracker::LiveRegSet LiveRegs; 319 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 320 auto Reg = Register::index2VirtReg(I); 321 if (!LIS.hasInterval(Reg)) 322 continue; 323 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 324 if (LiveMask.any()) 325 LiveRegs[Reg] = LiveMask; 326 } 327 return LiveRegs; 328 } 329 330 void GCNRPTracker::reset(const MachineInstr &MI, 331 const LiveRegSet *LiveRegsCopy, 332 bool After) { 333 const MachineFunction &MF = *MI.getMF(); 334 MRI = &MF.getRegInfo(); 335 if (LiveRegsCopy) { 336 if (&LiveRegs != LiveRegsCopy) 337 LiveRegs = *LiveRegsCopy; 338 } else { 339 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 340 : getLiveRegsBefore(MI, LIS); 341 } 342 343 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 344 } 345 346 //////////////////////////////////////////////////////////////////////////////// 347 // GCNUpwardRPTracker 348 349 void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_, 350 const LiveRegSet &LiveRegs_) { 351 MRI = &MRI_; 352 LiveRegs = LiveRegs_; 353 LastTrackedMI = nullptr; 354 MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_); 355 } 356 357 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 358 assert(MRI && "call reset first"); 359 360 LastTrackedMI = &MI; 361 362 if (MI.isDebugInstr()) 363 return; 364 365 // Kill all defs. 366 GCNRegPressure DefPressure, ECDefPressure; 367 bool HasECDefs = false; 368 for (const MachineOperand &MO : MI.all_defs()) { 369 if (!MO.getReg().isVirtual()) 370 continue; 371 372 Register Reg = MO.getReg(); 373 LaneBitmask DefMask = getDefRegMask(MO, *MRI); 374 375 // Treat a def as fully live at the moment of definition: keep a record. 376 if (MO.isEarlyClobber()) { 377 ECDefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI); 378 HasECDefs = true; 379 } else 380 DefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI); 381 382 auto I = LiveRegs.find(Reg); 383 if (I == LiveRegs.end()) 384 continue; 385 386 LaneBitmask &LiveMask = I->second; 387 LaneBitmask PrevMask = LiveMask; 388 LiveMask &= ~DefMask; 389 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 390 if (LiveMask.none()) 391 LiveRegs.erase(I); 392 } 393 394 // Update MaxPressure with defs pressure. 395 DefPressure += CurPressure; 396 if (HasECDefs) 397 DefPressure += ECDefPressure; 398 MaxPressure = max(DefPressure, MaxPressure); 399 400 // Make uses alive. 401 SmallVector<RegisterMaskPair, 8> RegUses; 402 collectVirtualRegUses(RegUses, MI, LIS, *MRI); 403 for (const RegisterMaskPair &U : RegUses) { 404 LaneBitmask &LiveMask = LiveRegs[U.RegUnit]; 405 LaneBitmask PrevMask = LiveMask; 406 LiveMask |= U.LaneMask; 407 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 408 } 409 410 // Update MaxPressure with uses plus early-clobber defs pressure. 411 MaxPressure = HasECDefs ? max(CurPressure + ECDefPressure, MaxPressure) 412 : max(CurPressure, MaxPressure); 413 414 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 415 } 416 417 //////////////////////////////////////////////////////////////////////////////// 418 // GCNDownwardRPTracker 419 420 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 421 const LiveRegSet *LiveRegsCopy) { 422 MRI = &MI.getParent()->getParent()->getRegInfo(); 423 LastTrackedMI = nullptr; 424 MBBEnd = MI.getParent()->end(); 425 NextMI = &MI; 426 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 427 if (NextMI == MBBEnd) 428 return false; 429 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 430 return true; 431 } 432 433 bool GCNDownwardRPTracker::advanceBeforeNext() { 434 assert(MRI && "call reset first"); 435 if (!LastTrackedMI) 436 return NextMI == MBBEnd; 437 438 assert(NextMI == MBBEnd || !NextMI->isDebugInstr()); 439 440 SlotIndex SI = NextMI == MBBEnd 441 ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot() 442 : LIS.getInstructionIndex(*NextMI).getBaseIndex(); 443 assert(SI.isValid()); 444 445 // Remove dead registers or mask bits. 446 SmallSet<Register, 8> SeenRegs; 447 for (auto &MO : LastTrackedMI->operands()) { 448 if (!MO.isReg() || !MO.getReg().isVirtual()) 449 continue; 450 if (MO.isUse() && !MO.readsReg()) 451 continue; 452 if (!SeenRegs.insert(MO.getReg()).second) 453 continue; 454 const LiveInterval &LI = LIS.getInterval(MO.getReg()); 455 if (LI.hasSubRanges()) { 456 auto It = LiveRegs.end(); 457 for (const auto &S : LI.subranges()) { 458 if (!S.liveAt(SI)) { 459 if (It == LiveRegs.end()) { 460 It = LiveRegs.find(MO.getReg()); 461 if (It == LiveRegs.end()) 462 llvm_unreachable("register isn't live"); 463 } 464 auto PrevMask = It->second; 465 It->second &= ~S.LaneMask; 466 CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI); 467 } 468 } 469 if (It != LiveRegs.end() && It->second.none()) 470 LiveRegs.erase(It); 471 } else if (!LI.liveAt(SI)) { 472 auto It = LiveRegs.find(MO.getReg()); 473 if (It == LiveRegs.end()) 474 llvm_unreachable("register isn't live"); 475 CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI); 476 LiveRegs.erase(It); 477 } 478 } 479 480 MaxPressure = max(MaxPressure, CurPressure); 481 482 LastTrackedMI = nullptr; 483 484 return NextMI == MBBEnd; 485 } 486 487 void GCNDownwardRPTracker::advanceToNext() { 488 LastTrackedMI = &*NextMI++; 489 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 490 491 // Add new registers or mask bits. 492 for (const auto &MO : LastTrackedMI->all_defs()) { 493 Register Reg = MO.getReg(); 494 if (!Reg.isVirtual()) 495 continue; 496 auto &LiveMask = LiveRegs[Reg]; 497 auto PrevMask = LiveMask; 498 LiveMask |= getDefRegMask(MO, *MRI); 499 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 500 } 501 502 MaxPressure = max(MaxPressure, CurPressure); 503 } 504 505 bool GCNDownwardRPTracker::advance() { 506 if (NextMI == MBBEnd) 507 return false; 508 advanceBeforeNext(); 509 advanceToNext(); 510 return true; 511 } 512 513 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 514 while (NextMI != End) 515 if (!advance()) return false; 516 return true; 517 } 518 519 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 520 MachineBasicBlock::const_iterator End, 521 const LiveRegSet *LiveRegsCopy) { 522 reset(*Begin, LiveRegsCopy); 523 return advance(End); 524 } 525 526 Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 527 const GCNRPTracker::LiveRegSet &TrackedLR, 528 const TargetRegisterInfo *TRI, StringRef Pfx) { 529 return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) { 530 for (auto const &P : TrackedLR) { 531 auto I = LISLR.find(P.first); 532 if (I == LISLR.end()) { 533 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second) 534 << " isn't found in LIS reported set\n"; 535 } else if (I->second != P.second) { 536 OS << Pfx << printReg(P.first, TRI) 537 << " masks doesn't match: LIS reported " << PrintLaneMask(I->second) 538 << ", tracked " << PrintLaneMask(P.second) << '\n'; 539 } 540 } 541 for (auto const &P : LISLR) { 542 auto I = TrackedLR.find(P.first); 543 if (I == TrackedLR.end()) { 544 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second) 545 << " isn't found in tracked set\n"; 546 } 547 } 548 }); 549 } 550 551 bool GCNUpwardRPTracker::isValid() const { 552 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 553 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 554 const auto &TrackedLR = LiveRegs; 555 556 if (!isEqual(LISLR, TrackedLR)) { 557 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 558 " LIS reported livesets mismatch:\n" 559 << print(LISLR, *MRI); 560 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 561 return false; 562 } 563 564 auto LISPressure = getRegPressure(*MRI, LISLR); 565 if (LISPressure != CurPressure) { 566 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: " 567 << print(CurPressure) << "LIS rpt: " << print(LISPressure); 568 return false; 569 } 570 return true; 571 } 572 573 Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs, 574 const MachineRegisterInfo &MRI) { 575 return Printable([&LiveRegs, &MRI](raw_ostream &OS) { 576 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 577 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 578 Register Reg = Register::index2VirtReg(I); 579 auto It = LiveRegs.find(Reg); 580 if (It != LiveRegs.end() && It->second.any()) 581 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 582 << PrintLaneMask(It->second); 583 } 584 OS << '\n'; 585 }); 586 } 587 588 void GCNRegPressure::dump() const { dbgs() << print(*this); } 589 590 static cl::opt<bool> UseDownwardTracker( 591 "amdgpu-print-rp-downward", 592 cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"), 593 cl::init(false), cl::Hidden); 594 595 char llvm::GCNRegPressurePrinter::ID = 0; 596 char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID; 597 598 INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true) 599 600 // Return lanemask of Reg's subregs that are live-through at [Begin, End] and 601 // are fully covered by Mask. 602 static LaneBitmask 603 getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS, 604 Register Reg, SlotIndex Begin, SlotIndex End, 605 LaneBitmask Mask = LaneBitmask::getAll()) { 606 607 auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool { 608 auto *Segment = LR.getSegmentContaining(Begin); 609 return Segment && Segment->contains(End); 610 }; 611 612 LaneBitmask LiveThroughMask; 613 const LiveInterval &LI = LIS.getInterval(Reg); 614 if (LI.hasSubRanges()) { 615 for (auto &SR : LI.subranges()) { 616 if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR)) 617 LiveThroughMask |= SR.LaneMask; 618 } 619 } else { 620 LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg); 621 if ((RegMask & Mask) == RegMask && IsInOneSegment(LI)) 622 LiveThroughMask = RegMask; 623 } 624 625 return LiveThroughMask; 626 } 627 628 bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) { 629 const MachineRegisterInfo &MRI = MF.getRegInfo(); 630 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 631 const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); 632 633 auto &OS = dbgs(); 634 635 // Leading spaces are important for YAML syntax. 636 #define PFX " " 637 638 OS << "---\nname: " << MF.getName() << "\nbody: |\n"; 639 640 auto printRP = [](const GCNRegPressure &RP) { 641 return Printable([&RP](raw_ostream &OS) { 642 OS << format(PFX " %-5d", RP.getSGPRNum()) 643 << format(" %-5d", RP.getVGPRNum(false)); 644 }); 645 }; 646 647 auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR, 648 const GCNRPTracker::LiveRegSet &LISLR) { 649 if (LISLR != TrackedLR) { 650 OS << PFX " mis LIS: " << llvm::print(LISLR, MRI) 651 << reportMismatch(LISLR, TrackedLR, TRI, PFX " "); 652 } 653 }; 654 655 // Register pressure before and at an instruction (in program order). 656 SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP; 657 658 for (auto &MBB : MF) { 659 RP.clear(); 660 RP.reserve(MBB.size()); 661 662 OS << PFX; 663 MBB.printName(OS); 664 OS << ":\n"; 665 666 SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(&MBB); 667 SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(&MBB); 668 669 GCNRPTracker::LiveRegSet LiveIn, LiveOut; 670 GCNRegPressure RPAtMBBEnd; 671 672 if (UseDownwardTracker) { 673 if (MBB.empty()) { 674 LiveIn = LiveOut = getLiveRegs(MBBStartSlot, LIS, MRI); 675 RPAtMBBEnd = getRegPressure(MRI, LiveIn); 676 } else { 677 GCNDownwardRPTracker RPT(LIS); 678 RPT.reset(MBB.front()); 679 680 LiveIn = RPT.getLiveRegs(); 681 682 while (!RPT.advanceBeforeNext()) { 683 GCNRegPressure RPBeforeMI = RPT.getPressure(); 684 RPT.advanceToNext(); 685 RP.emplace_back(RPBeforeMI, RPT.getPressure()); 686 } 687 688 LiveOut = RPT.getLiveRegs(); 689 RPAtMBBEnd = RPT.getPressure(); 690 } 691 } else { 692 GCNUpwardRPTracker RPT(LIS); 693 RPT.reset(MRI, MBBEndSlot); 694 695 LiveOut = RPT.getLiveRegs(); 696 RPAtMBBEnd = RPT.getPressure(); 697 698 for (auto &MI : reverse(MBB)) { 699 RPT.resetMaxPressure(); 700 RPT.recede(MI); 701 if (!MI.isDebugInstr()) 702 RP.emplace_back(RPT.getPressure(), RPT.getMaxPressure()); 703 } 704 705 LiveIn = RPT.getLiveRegs(); 706 } 707 708 OS << PFX " Live-in: " << llvm::print(LiveIn, MRI); 709 if (!UseDownwardTracker) 710 ReportLISMismatchIfAny(LiveIn, getLiveRegs(MBBStartSlot, LIS, MRI)); 711 712 OS << PFX " SGPR VGPR\n"; 713 int I = 0; 714 for (auto &MI : MBB) { 715 if (!MI.isDebugInstr()) { 716 auto &[RPBeforeInstr, RPAtInstr] = 717 RP[UseDownwardTracker ? I : (RP.size() - 1 - I)]; 718 ++I; 719 OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " "; 720 } else 721 OS << PFX " "; 722 MI.print(OS); 723 } 724 OS << printRP(RPAtMBBEnd) << '\n'; 725 726 OS << PFX " Live-out:" << llvm::print(LiveOut, MRI); 727 if (UseDownwardTracker) 728 ReportLISMismatchIfAny(LiveOut, getLiveRegs(MBBEndSlot, LIS, MRI)); 729 730 GCNRPTracker::LiveRegSet LiveThrough; 731 for (auto [Reg, Mask] : LiveIn) { 732 LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Reg); 733 if (MaskIntersection.any()) { 734 LaneBitmask LTMask = getRegLiveThroughMask( 735 MRI, LIS, Reg, MBBStartSlot, MBBEndSlot, MaskIntersection); 736 if (LTMask.any()) 737 LiveThrough[Reg] = LTMask; 738 } 739 } 740 OS << PFX " Live-thr:" << llvm::print(LiveThrough, MRI); 741 OS << printRP(getRegPressure(MRI, LiveThrough)) << '\n'; 742 } 743 OS << "...\n"; 744 return false; 745 746 #undef PFX 747 }