1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the GCNRegPressure class. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "GCNRegPressure.h" 15 #include "llvm/CodeGen/RegisterPressure.h" 16 17 using namespace llvm; 18 19 #define DEBUG_TYPE "machine-scheduler" 20 21 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 22 LLVM_DUMP_METHOD 23 void llvm::printLivesAt(SlotIndex SI, 24 const LiveIntervals &LIS, 25 const MachineRegisterInfo &MRI) { 26 dbgs() << "Live regs at " << SI << ": " 27 << *LIS.getInstructionFromIndex(SI); 28 unsigned Num = 0; 29 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 30 const unsigned Reg = Register::index2VirtReg(I); 31 if (!LIS.hasInterval(Reg)) 32 continue; 33 const auto &LI = LIS.getInterval(Reg); 34 if (LI.hasSubRanges()) { 35 bool firstTime = true; 36 for (const auto &S : LI.subranges()) { 37 if (!S.liveAt(SI)) continue; 38 if (firstTime) { 39 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo()) 40 << '\n'; 41 firstTime = false; 42 } 43 dbgs() << " " << S << '\n'; 44 ++Num; 45 } 46 } else if (LI.liveAt(SI)) { 47 dbgs() << " " << LI << '\n'; 48 ++Num; 49 } 50 } 51 if (!Num) dbgs() << " <none>\n"; 52 } 53 #endif 54 55 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 56 const GCNRPTracker::LiveRegSet &S2) { 57 if (S1.size() != S2.size()) 58 return false; 59 60 for (const auto &P : S1) { 61 auto I = S2.find(P.first); 62 if (I == S2.end() || I->second != P.second) 63 return false; 64 } 65 return true; 66 } 67 68 69 
///////////////////////////////////////////////////////////////////////////////
// GCNRegPressure

/// Classify virtual register \p Reg into one of the pressure kinds.
///
/// The register class is tested in this order: SGPR class, class with AGPRs,
/// otherwise VGPR. Within each group a class that is exactly 32 bits wide maps
/// to the *32 kind and any other width maps to the *_TUPLE kind.
unsigned GCNRegPressure::getRegKind(Register Reg,
                                    const MachineRegisterInfo &MRI) {
  assert(Reg.isVirtual());
  const auto RC = MRI.getRegClass(Reg);
  auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
  return STI->isSGPRClass(RC) ?
    (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
    STI->hasAGPRs(RC) ?
      (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
      (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
}

/// Update the pressure counters for \p Reg whose live lane mask changes from
/// \p PrevMask to \p NewMask.
///
/// Nothing happens when both masks cover the same number of registers. When
/// the mask shrinks (NewMask < PrevMask) the two masks are swapped and the
/// delta is applied with a negative sign, so the code below always reasons
/// about growing from the smaller mask to the larger one.
void GCNRegPressure::inc(unsigned Reg,
                         LaneBitmask PrevMask,
                         LaneBitmask NewMask,
                         const MachineRegisterInfo &MRI) {
  if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
      SIRegisterInfo::getNumCoveredRegs(PrevMask))
    return;

  // Normalize so that PrevMask is the smaller mask; Sign records direction.
  int Sign = 1;
  if (NewMask < PrevMask) {
    std::swap(NewMask, PrevMask);
    Sign = -1;
  }

  switch (auto Kind = getRegKind(Reg, MRI)) {
  case SGPR32:
  case VGPR32:
  case AGPR32:
    // Single 32-bit register: counts as exactly one.
    Value[Kind] += Sign;
    break;

  case SGPR_TUPLE:
  case VGPR_TUPLE:
  case AGPR_TUPLE:
    assert(PrevMask < NewMask);

    // Adjust the matching 32-bit counter by the number of registers covered
    // by the lanes present in the larger mask but not in the smaller one.
    Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
      Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);

    // The tuple weight is only adjusted when the smaller mask is empty, i.e.
    // on the transition between "no lanes live" and "some lanes live".
    if (PrevMask.none()) {
      assert(NewMask.any());
      Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
    }
    break;

  default: llvm_unreachable("Unknown register kind");
  }
}

/// Return true if this pressure is considered better (lower) than \p O.
///
/// Comparison order:
///  1. Occupancy (min of SGPR- and VGPR-derived occupancy, each capped at
///     \p MaxOccupancy): the higher occupancy wins.
///  2. Tuple weights: SGPR and VGPR tuple weights are compared one after the
///     other, starting with the kind deemed more important; smaller wins.
///  3. Register count of the more important kind; smaller wins.
bool GCNRegPressure::less(const GCNSubtarget &ST,
                          const GCNRegPressure& O,
                          unsigned MaxOccupancy) const {
  const auto SGPROcc = std::min(MaxOccupancy,
                                ST.getOccupancyWithNumSGPRs(getSGPRNum()));
  const auto VGPROcc = std::min(MaxOccupancy,
                                ST.getOccupancyWithNumVGPRs(getVGPRNum()));
  const auto OtherSGPROcc = std::min(MaxOccupancy,
                                ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
  const auto OtherVGPROcc = std::min(MaxOccupancy,
                                ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));

  const auto Occ = std::min(SGPROcc, VGPROcc);
  const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
  if (Occ != OtherOcc)
    return Occ > OtherOcc;

  // SGPRs are "important" for a pressure when they are what limits its
  // occupancy (SGPR occupancy strictly below VGPR occupancy).
  bool SGPRImportant = SGPROcc < VGPROcc;
  const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;

  // If both pressures disagree on what is more important, compare VGPRs.
  if (SGPRImportant != OtherSGPRImportant) {
    SGPRImportant = false;
  }

  // Compare large (tuple) register pressure. The loop runs exactly twice:
  // first for the important kind, then for the other kind.
  bool SGPRFirst = SGPRImportant;
  for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
    if (SGPRFirst) {
      auto SW = getSGPRTuplesWeight();
      auto OtherSW = O.getSGPRTuplesWeight();
      if (SW != OtherSW)
        return SW < OtherSW;
    } else {
      auto VW = getVGPRTuplesWeight();
      auto OtherVW = O.getVGPRTuplesWeight();
      if (VW != OtherVW)
        return VW < OtherVW;
    }
  }
  // Tuple weights are equal: fall back to the plain register count.
  return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
                         (getVGPRNum() < O.getVGPRNum());
}

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print a one-line summary of this pressure to \p OS: VGPR and AGPR counts,
/// SGPR count, and the VGPR/SGPR tuple weights. When \p ST is non-null the
/// per-kind occupancy and the resulting overall occupancy are printed too.
LLVM_DUMP_METHOD
void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
  OS << "VGPRs: " << Value[VGPR32] << ' ';
  OS << "AGPRs: " << Value[AGPR32];
  if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
  OS << ", SGPRs: " << getSGPRNum();
  if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
  OS << ", LVGPR WT: " << getVGPRTuplesWeight()
     << ", LSGPR WT: " << getSGPRTuplesWeight();
  if (ST) OS << " -> Occ: " << getOccupancy(*ST);
  OS << '\n';
}
#endif

/// Return the lane mask written by virtual-register def operand \p MO: the
/// register's maximal lane mask when no subregister index is present,
/// otherwise the lane mask of that subregister index.
static LaneBitmask getDefRegMask(const MachineOperand &MO,
                                 const MachineRegisterInfo &MRI) {
  assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());

  // We don't rely on read-undef flag because in case of tentative schedule
  // tracking it isn't set correctly yet. This works correctly however since
  // use mask has been tracked before using LIS.
  return MO.getSubReg() == 0 ?
    MRI.getMaxLaneMaskForVReg(MO.getReg()) :
    MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
}

/// Return the lane mask read by virtual-register use operand \p MO.
///
/// A subregister use yields the subregister index's lane mask. A
/// full-register use yields either the register's maximal lane mask or, via
/// LIS, the lanes that are live at the using instruction.
static LaneBitmask getUsedRegMask(const MachineOperand &MO,
                                  const MachineRegisterInfo &MRI,
                                  const LiveIntervals &LIS) {
  assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual());

  if (auto SubReg = MO.getSubReg())
    return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);

  auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
  // NOTE(review): the trailing comment says "cannot have subregs", yet the
  // condition selects masks covering MORE than one register -- confirm that
  // the comparison direction is intended.
  if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
    return MaxMask;

  // For a tentative schedule LIS isn't updated yet but livemask should remain
  // the same on any schedule. Subreg defs can be reordered but they all must
  // dominate uses anyway.
  auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
  return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
}

/// Collect the virtual registers read by \p MI together with the union of the
/// lane masks used by all of its operands that read each register. Every
/// register appears at most once in the result.
static SmallVector<RegisterMaskPair, 8>
collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
                      const MachineRegisterInfo &MRI) {
  SmallVector<RegisterMaskPair, 8> Res;
  for (const auto &MO : MI.operands()) {
    if (!MO.isReg() || !MO.getReg().isVirtual())
      continue;
    if (!MO.isUse() || !MO.readsReg())
      continue;

    auto const UsedMask = getUsedRegMask(MO, MRI, LIS);

    // Merge with an existing entry for the same register, if any.
    auto Reg = MO.getReg();
    auto I = llvm::find_if(
        Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; });
    if (I != Res.end())
      I->LaneMask |= UsedMask;
    else
      Res.push_back(RegisterMaskPair(Reg, UsedMask));
  }
  return Res;
}

///////////////////////////////////////////////////////////////////////////////
// GCNRPTracker

/// Return the lanes of \p Reg that are live at \p SI according to \p LIS.
///
/// With subranges the result is the union of the lane masks of all subranges
/// live at \p SI; without subranges it is the register's maximal lane mask if
/// the main interval is live at \p SI. Otherwise the default-constructed mask
/// is returned.
LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
                                  SlotIndex SI,
                                  const LiveIntervals &LIS,
                                  const MachineRegisterInfo &MRI) {
  LaneBitmask LiveMask;
  const auto &LI = LIS.getInterval(Reg);
  if (LI.hasSubRanges()) {
    for (const auto &S : LI.subranges())
      if (S.liveAt(SI)) {
        LiveMask |= S.LaneMask;
        // The accumulated mask must never exceed the register's maximal mask.
        assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
               LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
      }
  } else if (LI.liveAt(SI)) {
    LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
  }
  return LiveMask;
}

/// Build the set of virtual registers that have an interval in \p LIS and at
/// least one lane live at \p SI, mapped to their live lane masks.
GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
                                           const LiveIntervals &LIS,
                                           const MachineRegisterInfo &MRI) {
  GCNRPTracker::LiveRegSet LiveRegs;
  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
    auto Reg = Register::index2VirtReg(I);
    if (!LIS.hasInterval(Reg))
      continue;
    auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
    if (LiveMask.any())
      LiveRegs[Reg] = LiveMask;
  }
  return LiveRegs;
}

/// Initialize the tracker at \p MI.
///
/// \param MI           instruction whose function supplies the register info.
/// \param LiveRegsCopy if non-null, the live set to start from (copied unless
///                     it already is this tracker's own set); if null, the
///                     live set is computed from LIS.
/// \param After        when computing from LIS, take the live set after
///                     (true) or before (false) \p MI.
///
/// Current and maximum pressure are both set to the pressure of the resulting
/// live set.
void GCNRPTracker::reset(const MachineInstr &MI,
                         const LiveRegSet *LiveRegsCopy,
                         bool After) {
  const MachineFunction &MF = *MI.getMF();
  MRI = &MF.getRegInfo();
  if (LiveRegsCopy) {
    // Avoid self-assignment when the caller passes our own set back in.
    if (&LiveRegs != LiveRegsCopy)
      LiveRegs = *LiveRegsCopy;
  } else {
    LiveRegs = After ? getLiveRegsAfter(MI, LIS)
                     : getLiveRegsBefore(MI, LIS);
  }

  MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
}

/// Reset the upward tracker to the state after \p MI.
void GCNUpwardRPTracker::reset(const MachineInstr &MI,
                               const LiveRegSet *LiveRegsCopy) {
  GCNRPTracker::reset(MI, LiveRegsCopy, true);
}

/// Move the tracker upward across \p MI, so that afterwards the live set and
/// current pressure describe the state before \p MI. Debug instructions only
/// update LastTrackedMI.
void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
  assert(MRI && "call reset first");

  LastTrackedMI = &MI;

  if (MI.isDebugInstr())
    return;

  auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);

  // Calculate pressure at the MI (defs + uses): start from the pressure below
  // MI (which still contains the defs) and add the used lanes on top.
  auto AtMIPressure = CurPressure;
  for (const auto &U : RegUses) {
    auto LiveMask = LiveRegs[U.RegUnit];
    AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
  }
  // Update max pressure.
  MaxPressure = max(AtMIPressure, MaxPressure);

  // Remove the lanes defined by MI from the live set (dead defs are skipped).
  for (const auto &MO : MI.operands()) {
    if (!MO.isReg() || !MO.isDef() || !MO.getReg().isVirtual() || MO.isDead())
      continue;

    auto Reg = MO.getReg();
    auto I = LiveRegs.find(Reg);
    if (I == LiveRegs.end())
      continue;
    auto &LiveMask = I->second;
    auto PrevMask = LiveMask;
    LiveMask &= ~getDefRegMask(MO, *MRI);
    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
    if (LiveMask.none())
      LiveRegs.erase(I);
  }
  // Add the lanes used by MI to the live set. This runs after the def loop so
  // that a register both defined and used by MI ends up live above it.
  for (const auto &U : RegUses) {
    auto &LiveMask = LiveRegs[U.RegUnit];
    auto PrevMask = LiveMask;
    LiveMask |= U.LaneMask;
    CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
  }
  // The incrementally maintained pressure must match a full recomputation.
  assert(CurPressure == getRegPressure(*MRI, LiveRegs));
}

/// Reset the downward tracker to the state before the first non-debug
/// instruction at or after \p MI in its basic block.
///
/// \returns false if skipping debug instructions forward from \p MI reaches
/// the end of the block; in that case the base tracker state is not reset.
bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
                                 const LiveRegSet *LiveRegsCopy) {
  MRI = &MI.getParent()->getParent()->getRegInfo();
  LastTrackedMI = nullptr;
  MBBEnd = MI.getParent()->end();
  NextMI = &MI;
  NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
  if (NextMI == MBBEnd)
    return false;
  GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
  return true;
}

/// Bring the live set to the state immediately before NextMI by dropping the
/// registers/lanes that LIS reports as no longer live there.
///
/// \returns false if there is no further non-debug instruction in the block.
bool GCNDownwardRPTracker::advanceBeforeNext() {
  assert(MRI && "call reset first");

  NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
  if (NextMI == MBBEnd)
    return false;

  SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
  assert(SI.isValid());

  // Remove dead registers or mask bits.
  // NOTE(review): entries are erased from LiveRegs while it is being
  // range-iterated; this relies on the container tolerating erasure of the
  // current element -- confirm against LiveRegSet's underlying type.
  for (auto &It : LiveRegs) {
    const LiveInterval &LI = LIS.getInterval(It.first);
    if (LI.hasSubRanges()) {
      // Clear the lanes of every subrange that is not live at SI.
      for (const auto &S : LI.subranges()) {
        if (!S.liveAt(SI)) {
          auto PrevMask = It.second;
          It.second &= ~S.LaneMask;
          CurPressure.inc(It.first, PrevMask, It.second, *MRI);
        }
      }
    } else if (!LI.liveAt(SI)) {
      // Whole register is dead at SI.
      auto PrevMask = It.second;
      It.second = LaneBitmask::getNone();
      CurPressure.inc(It.first, PrevMask, It.second, *MRI);
    }
    if (It.second.none())
      LiveRegs.erase(It.first);
  }

  MaxPressure = max(MaxPressure, CurPressure);

  return true;
}

/// Step over NextMI: it becomes LastTrackedMI, NextMI moves to the following
/// instruction, and the lanes defined by the stepped-over instruction's
/// virtual-register defs are added to the live set.
void GCNDownwardRPTracker::advanceToNext() {
  LastTrackedMI = &*NextMI++;

  // Add new registers or mask bits.
  for (const auto &MO : LastTrackedMI->operands()) {
    if (!MO.isReg() || !MO.isDef())
      continue;
    Register Reg = MO.getReg();
    if (!Reg.isVirtual())
      continue;
    auto &LiveMask = LiveRegs[Reg];
    auto PrevMask = LiveMask;
    LiveMask |= getDefRegMask(MO, *MRI);
    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
  }

  MaxPressure = max(MaxPressure, CurPressure);
}

/// Advance the tracker across one instruction.
///
/// \returns false when the end of the block has been reached.
bool GCNDownwardRPTracker::advance() {
  // If we have just called reset, the live set is already up to date
  // (LastTrackedMI is null), so advanceBeforeNext() is skipped.
  if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
    return false;
  advanceToNext();
  return true;
}

/// Advance until NextMI equals \p End.
///
/// \returns false if a single-step advance() fails before \p End is reached.
bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
  while (NextMI != End)
    if (!advance()) return false;
  return true;
}

/// Reset the tracker at \p Begin (optionally seeding it with
/// \p LiveRegsCopy) and advance it up to \p End. The result of reset() is not
/// checked here; the return value is that of advance(End).
bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
                                   MachineBasicBlock::const_iterator End,
                                   const LiveRegSet *LiveRegsCopy) {
  reset(*Begin, LiveRegsCopy);
  return advance(End);
}

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print to dbgs() every difference between the LIS-derived live set
/// \p LISLR and the tracked live set \p TrackedLR: registers missing from
/// either side and registers whose lane masks differ.
LLVM_DUMP_METHOD
static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
                           const GCNRPTracker::LiveRegSet &TrackedLR,
                           const TargetRegisterInfo *TRI) {
  // Tracked entries that are absent from, or differ from, the LIS set.
  for (auto const &P : TrackedLR) {
    auto I = LISLR.find(P.first);
    if (I == LISLR.end()) {
      dbgs() << " " << printReg(P.first, TRI)
             << ":L" << PrintLaneMask(P.second)
             << " isn't found in LIS reported set\n";
    }
    else if (I->second != P.second) {
      dbgs() << " " << printReg(P.first, TRI)
        << " masks doesn't match: LIS reported "
        << PrintLaneMask(I->second)
        << ", tracked "
        << PrintLaneMask(P.second)
        << '\n';
    }
  }
  // LIS entries that the tracker does not have at all.
  for (auto const &P : LISLR) {
    auto I = TrackedLR.find(P.first);
    if (I == TrackedLR.end()) {
      dbgs() << " " << printReg(P.first, TRI)
             << ":L" << PrintLaneMask(P.second)
             << " isn't found in tracked set\n";
    }
  }
}

/// Check the tracker state against LIS at the base index of LastTrackedMI.
///
/// Both the live set and the current pressure are compared with values
/// recomputed from LIS; on a mismatch diagnostics are written to dbgs() and
/// false is returned.
bool GCNUpwardRPTracker::isValid() const {
  const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
  const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
  const auto &TrackedLR = LiveRegs;

  if (!isEqual(LISLR, TrackedLR)) {
    dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
              " LIS reported livesets mismatch:\n";
    printLivesAt(SI, LIS, *MRI);
    reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
    return false;
  }

  auto LISPressure = getRegPressure(*MRI, LISLR);
  if (LISPressure != CurPressure) {
    dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
    CurPressure.print(dbgs());
    dbgs() << "LIS rpt: ";
    LISPressure.print(dbgs());
    return false;
  }
  return true;
}

/// Print to \p OS, on a single line and in virtual-register index order,
/// every register in \p LiveRegs whose lane mask is non-empty, each followed
/// by ':' and its lane mask.
void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
                                 const MachineRegisterInfo &MRI) {
  const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
    unsigned Reg = Register::index2VirtReg(I);
    auto It = LiveRegs.find(Reg);
    if (It != LiveRegs.end() && It->second.any())
      OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
         << PrintLaneMask(It->second);
  }
  OS << '\n';
}
#endif