1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the GCNRegPressure class. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "GCNRegPressure.h" 15 #include "llvm/CodeGen/RegisterPressure.h" 16 17 using namespace llvm; 18 19 #define DEBUG_TYPE "machine-scheduler" 20 21 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 22 LLVM_DUMP_METHOD 23 void llvm::printLivesAt(SlotIndex SI, 24 const LiveIntervals &LIS, 25 const MachineRegisterInfo &MRI) { 26 dbgs() << "Live regs at " << SI << ": " 27 << *LIS.getInstructionFromIndex(SI); 28 unsigned Num = 0; 29 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 30 const Register Reg = Register::index2VirtReg(I); 31 if (!LIS.hasInterval(Reg)) 32 continue; 33 const auto &LI = LIS.getInterval(Reg); 34 if (LI.hasSubRanges()) { 35 bool firstTime = true; 36 for (const auto &S : LI.subranges()) { 37 if (!S.liveAt(SI)) continue; 38 if (firstTime) { 39 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo()) 40 << '\n'; 41 firstTime = false; 42 } 43 dbgs() << " " << S << '\n'; 44 ++Num; 45 } 46 } else if (LI.liveAt(SI)) { 47 dbgs() << " " << LI << '\n'; 48 ++Num; 49 } 50 } 51 if (!Num) dbgs() << " <none>\n"; 52 } 53 #endif 54 55 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 56 const GCNRPTracker::LiveRegSet &S2) { 57 if (S1.size() != S2.size()) 58 return false; 59 60 for (const auto &P : S1) { 61 auto I = S2.find(P.first); 62 if (I == S2.end() || I->second != P.second) 63 return false; 64 } 65 return true; 66 } 67 68 69 /////////////////////////////////////////////////////////////////////////////// 70 // GCNRegPressure 71 72 unsigned GCNRegPressure::getRegKind(Register Reg, 73 const MachineRegisterInfo &MRI) { 74 assert(Reg.isVirtual()); 75 const auto RC = MRI.getRegClass(Reg); 76 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo()); 77 return STI->isSGPRClass(RC) 78 ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) 79 : STI->isAGPRClass(RC) 80 ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) 81 : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 82 } 83 84 void GCNRegPressure::inc(unsigned Reg, 85 LaneBitmask PrevMask, 86 LaneBitmask NewMask, 87 const MachineRegisterInfo &MRI) { 88 if (SIRegisterInfo::getNumCoveredRegs(NewMask) == 89 SIRegisterInfo::getNumCoveredRegs(PrevMask)) 90 return; 91 92 int Sign = 1; 93 if (NewMask < PrevMask) { 94 std::swap(NewMask, PrevMask); 95 Sign = -1; 96 } 97 98 switch (auto Kind = getRegKind(Reg, MRI)) { 99 case SGPR32: 100 case VGPR32: 101 case AGPR32: 102 Value[Kind] += Sign; 103 break; 104 105 case SGPR_TUPLE: 106 case VGPR_TUPLE: 107 case AGPR_TUPLE: 108 assert(PrevMask < NewMask); 109 110 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] += 111 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask); 112 113 if (PrevMask.none()) { 114 assert(NewMask.any()); 115 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight(); 116 } 117 break; 118 119 default: llvm_unreachable("Unknown register kind"); 120 } 121 } 122 123 bool GCNRegPressure::less(const GCNSubtarget &ST, 124 const GCNRegPressure& O, 125 unsigned MaxOccupancy) const { 126 const auto SGPROcc = std::min(MaxOccupancy, 127 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 128 const auto VGPROcc = 129 std::min(MaxOccupancy, 130 ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts()))); 131 const auto OtherSGPROcc = std::min(MaxOccupancy, 132 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 133 const auto OtherVGPROcc = 134 std::min(MaxOccupancy, 135 ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts()))); 136 137 const auto Occ = std::min(SGPROcc, VGPROcc); 138 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 139 if (Occ != OtherOcc) 140 return Occ > OtherOcc; 141 142 bool SGPRImportant = SGPROcc < VGPROcc; 143 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 144 145 // if both pressures disagree on what is more important compare vgprs 146 if (SGPRImportant != OtherSGPRImportant) { 147 SGPRImportant = false; 148 } 149 150 // compare large regs pressure 151 bool SGPRFirst = SGPRImportant; 152 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 153 if (SGPRFirst) { 154 auto SW = getSGPRTuplesWeight(); 155 auto OtherSW = O.getSGPRTuplesWeight(); 156 if (SW != OtherSW) 157 return SW < OtherSW; 158 } else { 159 auto VW = getVGPRTuplesWeight(); 160 auto OtherVW = O.getVGPRTuplesWeight(); 161 if (VW != OtherVW) 162 return VW < OtherVW; 163 } 164 } 165 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 166 (getVGPRNum(ST.hasGFX90AInsts()) < 167 O.getVGPRNum(ST.hasGFX90AInsts())); 168 } 169 170 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 171 LLVM_DUMP_METHOD 172 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const { 173 OS << "VGPRs: " << Value[VGPR32] << ' '; 174 OS << "AGPRs: " << Value[AGPR32]; 175 if (ST) OS << "(O" 176 << ST->getOccupancyWithNumVGPRs(getVGPRNum(ST->hasGFX90AInsts())) 177 << ')'; 178 OS << ", SGPRs: " << getSGPRNum(); 179 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')'; 180 OS << ", LVGPR WT: " << getVGPRTuplesWeight() 181 << ", LSGPR WT: " << getSGPRTuplesWeight(); 182 if (ST) OS << " -> Occ: " << getOccupancy(*ST); 183 OS << '\n'; 184 } 185 #endif 186 187 static LaneBitmask getDefRegMask(const MachineOperand &MO, 188 const MachineRegisterInfo &MRI) { 189 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual()); 190 191 // We don't rely on read-undef flag because in case of tentative schedule 192 // tracking it isn't set correctly yet. This works correctly however since 193 // use mask has been tracked before using LIS. 194 return MO.getSubReg() == 0 ? 195 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 196 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 197 } 198 199 static LaneBitmask getUsedRegMask(const MachineOperand &MO, 200 const MachineRegisterInfo &MRI, 201 const LiveIntervals &LIS) { 202 assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual()); 203 204 if (auto SubReg = MO.getSubReg()) 205 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg); 206 207 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg()); 208 if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs 209 return MaxMask; 210 211 // For a tentative schedule LIS isn't updated yet but livemask should remain 212 // the same on any schedule. Subreg defs can be reordered but they all must 213 // dominate uses anyway. 214 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex(); 215 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI); 216 } 217 218 static SmallVector<RegisterMaskPair, 8> 219 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS, 220 const MachineRegisterInfo &MRI) { 221 SmallVector<RegisterMaskPair, 8> Res; 222 for (const auto &MO : MI.operands()) { 223 if (!MO.isReg() || !MO.getReg().isVirtual()) 224 continue; 225 if (!MO.isUse() || !MO.readsReg()) 226 continue; 227 228 auto const UsedMask = getUsedRegMask(MO, MRI, LIS); 229 230 auto Reg = MO.getReg(); 231 auto I = llvm::find_if( 232 Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; }); 233 if (I != Res.end()) 234 I->LaneMask |= UsedMask; 235 else 236 Res.push_back(RegisterMaskPair(Reg, UsedMask)); 237 } 238 return Res; 239 } 240 241 /////////////////////////////////////////////////////////////////////////////// 242 // GCNRPTracker 243 244 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, 245 SlotIndex SI, 246 const LiveIntervals &LIS, 247 const MachineRegisterInfo &MRI) { 248 LaneBitmask LiveMask; 249 const auto &LI = LIS.getInterval(Reg); 250 if (LI.hasSubRanges()) { 251 for (const auto &S : LI.subranges()) 252 if (S.liveAt(SI)) { 253 LiveMask |= S.LaneMask; 254 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) || 255 LiveMask == MRI.getMaxLaneMaskForVReg(Reg)); 256 } 257 } else if (LI.liveAt(SI)) { 258 LiveMask = MRI.getMaxLaneMaskForVReg(Reg); 259 } 260 return LiveMask; 261 } 262 263 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 264 const LiveIntervals &LIS, 265 const MachineRegisterInfo &MRI) { 266 GCNRPTracker::LiveRegSet LiveRegs; 267 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 268 auto Reg = Register::index2VirtReg(I); 269 if (!LIS.hasInterval(Reg)) 270 continue; 271 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 272 if (LiveMask.any()) 273 LiveRegs[Reg] = LiveMask; 274 } 275 return LiveRegs; 276 } 277 278 void GCNRPTracker::reset(const MachineInstr &MI, 279 const LiveRegSet *LiveRegsCopy, 280 bool After) { 281 const MachineFunction &MF = *MI.getMF(); 282 MRI = &MF.getRegInfo(); 283 if (LiveRegsCopy) { 284 if (&LiveRegs != LiveRegsCopy) 285 LiveRegs = *LiveRegsCopy; 286 } else { 287 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 288 : getLiveRegsBefore(MI, LIS); 289 } 290 291 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 292 } 293 294 void GCNUpwardRPTracker::reset(const MachineInstr &MI, 295 const LiveRegSet *LiveRegsCopy) { 296 GCNRPTracker::reset(MI, LiveRegsCopy, true); 297 } 298 299 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 300 assert(MRI && "call reset first"); 301 302 LastTrackedMI = &MI; 303 304 if (MI.isDebugInstr()) 305 return; 306 307 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI); 308 309 // calc pressure at the MI (defs + uses) 310 auto AtMIPressure = CurPressure; 311 for (const auto &U : RegUses) { 312 auto LiveMask = LiveRegs[U.RegUnit]; 313 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI); 314 } 315 // update max pressure 316 MaxPressure = max(AtMIPressure, MaxPressure); 317 318 for (const auto &MO : MI.operands()) { 319 if (!MO.isReg() || !MO.isDef() || !MO.getReg().isVirtual() || MO.isDead()) 320 continue; 321 322 auto Reg = MO.getReg(); 323 auto I = LiveRegs.find(Reg); 324 if (I == LiveRegs.end()) 325 continue; 326 auto &LiveMask = I->second; 327 auto PrevMask = LiveMask; 328 LiveMask &= ~getDefRegMask(MO, *MRI); 329 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 330 if (LiveMask.none()) 331 LiveRegs.erase(I); 332 } 333 for (const auto &U : RegUses) { 334 auto &LiveMask = LiveRegs[U.RegUnit]; 335 auto PrevMask = LiveMask; 336 LiveMask |= U.LaneMask; 337 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 338 } 339 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 340 } 341 342 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 343 const LiveRegSet *LiveRegsCopy) { 344 MRI = &MI.getParent()->getParent()->getRegInfo(); 345 LastTrackedMI = nullptr; 346 MBBEnd = MI.getParent()->end(); 347 NextMI = &MI; 348 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 349 if (NextMI == MBBEnd) 350 return false; 351 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 352 return true; 353 } 354 355 bool GCNDownwardRPTracker::advanceBeforeNext() { 356 assert(MRI && "call reset first"); 357 358 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 359 if (NextMI == MBBEnd) 360 return false; 361 362 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex(); 363 assert(SI.isValid()); 364 365 // Remove dead registers or mask bits. 366 for (auto &It : LiveRegs) { 367 const LiveInterval &LI = LIS.getInterval(It.first); 368 if (LI.hasSubRanges()) { 369 for (const auto &S : LI.subranges()) { 370 if (!S.liveAt(SI)) { 371 auto PrevMask = It.second; 372 It.second &= ~S.LaneMask; 373 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 374 } 375 } 376 } else if (!LI.liveAt(SI)) { 377 auto PrevMask = It.second; 378 It.second = LaneBitmask::getNone(); 379 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 380 } 381 if (It.second.none()) 382 LiveRegs.erase(It.first); 383 } 384 385 MaxPressure = max(MaxPressure, CurPressure); 386 387 return true; 388 } 389 390 void GCNDownwardRPTracker::advanceToNext() { 391 LastTrackedMI = &*NextMI++; 392 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 393 394 // Add new registers or mask bits. 395 for (const auto &MO : LastTrackedMI->operands()) { 396 if (!MO.isReg() || !MO.isDef()) 397 continue; 398 Register Reg = MO.getReg(); 399 if (!Reg.isVirtual()) 400 continue; 401 auto &LiveMask = LiveRegs[Reg]; 402 auto PrevMask = LiveMask; 403 LiveMask |= getDefRegMask(MO, *MRI); 404 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 405 } 406 407 MaxPressure = max(MaxPressure, CurPressure); 408 } 409 410 bool GCNDownwardRPTracker::advance() { 411 // If we have just called reset live set is actual. 412 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext())) 413 return false; 414 advanceToNext(); 415 return true; 416 } 417 418 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 419 while (NextMI != End) 420 if (!advance()) return false; 421 return true; 422 } 423 424 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 425 MachineBasicBlock::const_iterator End, 426 const LiveRegSet *LiveRegsCopy) { 427 reset(*Begin, LiveRegsCopy); 428 return advance(End); 429 } 430 431 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 432 LLVM_DUMP_METHOD 433 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 434 const GCNRPTracker::LiveRegSet &TrackedLR, 435 const TargetRegisterInfo *TRI) { 436 for (auto const &P : TrackedLR) { 437 auto I = LISLR.find(P.first); 438 if (I == LISLR.end()) { 439 dbgs() << " " << printReg(P.first, TRI) 440 << ":L" << PrintLaneMask(P.second) 441 << " isn't found in LIS reported set\n"; 442 } 443 else if (I->second != P.second) { 444 dbgs() << " " << printReg(P.first, TRI) 445 << " masks doesn't match: LIS reported " 446 << PrintLaneMask(I->second) 447 << ", tracked " 448 << PrintLaneMask(P.second) 449 << '\n'; 450 } 451 } 452 for (auto const &P : LISLR) { 453 auto I = TrackedLR.find(P.first); 454 if (I == TrackedLR.end()) { 455 dbgs() << " " << printReg(P.first, TRI) 456 << ":L" << PrintLaneMask(P.second) 457 << " isn't found in tracked set\n"; 458 } 459 } 460 } 461 462 bool GCNUpwardRPTracker::isValid() const { 463 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 464 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 465 const auto &TrackedLR = LiveRegs; 466 467 if (!isEqual(LISLR, TrackedLR)) { 468 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 469 " LIS reported livesets mismatch:\n"; 470 printLivesAt(SI, LIS, *MRI); 471 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 472 return false; 473 } 474 475 auto LISPressure = getRegPressure(*MRI, LISLR); 476 if (LISPressure != CurPressure) { 477 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "; 478 CurPressure.print(dbgs()); 479 dbgs() << "LIS rpt: "; 480 LISPressure.print(dbgs()); 481 return false; 482 } 483 return true; 484 } 485 486 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs, 487 const MachineRegisterInfo &MRI) { 488 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 489 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 490 Register Reg = Register::index2VirtReg(I); 491 auto It = LiveRegs.find(Reg); 492 if (It != LiveRegs.end() && It->second.any()) 493 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 494 << PrintLaneMask(It->second); 495 } 496 OS << '\n'; 497 } 498 #endif 499