1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the GCNRegPressure class. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "GCNRegPressure.h" 15 #include "llvm/CodeGen/RegisterPressure.h" 16 17 using namespace llvm; 18 19 #define DEBUG_TYPE "machine-scheduler" 20 21 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 22 const GCNRPTracker::LiveRegSet &S2) { 23 if (S1.size() != S2.size()) 24 return false; 25 26 for (const auto &P : S1) { 27 auto I = S2.find(P.first); 28 if (I == S2.end() || I->second != P.second) 29 return false; 30 } 31 return true; 32 } 33 34 35 /////////////////////////////////////////////////////////////////////////////// 36 // GCNRegPressure 37 38 unsigned GCNRegPressure::getRegKind(Register Reg, 39 const MachineRegisterInfo &MRI) { 40 assert(Reg.isVirtual()); 41 const auto RC = MRI.getRegClass(Reg); 42 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo()); 43 return STI->isSGPRClass(RC) 44 ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) 45 : STI->isAGPRClass(RC) 46 ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) 47 : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 48 } 49 50 void GCNRegPressure::inc(unsigned Reg, 51 LaneBitmask PrevMask, 52 LaneBitmask NewMask, 53 const MachineRegisterInfo &MRI) { 54 if (SIRegisterInfo::getNumCoveredRegs(NewMask) == 55 SIRegisterInfo::getNumCoveredRegs(PrevMask)) 56 return; 57 58 int Sign = 1; 59 if (NewMask < PrevMask) { 60 std::swap(NewMask, PrevMask); 61 Sign = -1; 62 } 63 64 switch (auto Kind = getRegKind(Reg, MRI)) { 65 case SGPR32: 66 case VGPR32: 67 case AGPR32: 68 Value[Kind] += Sign; 69 break; 70 71 case SGPR_TUPLE: 72 case VGPR_TUPLE: 73 case AGPR_TUPLE: 74 assert(PrevMask < NewMask); 75 76 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] += 77 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask); 78 79 if (PrevMask.none()) { 80 assert(NewMask.any()); 81 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight(); 82 } 83 break; 84 85 default: llvm_unreachable("Unknown register kind"); 86 } 87 } 88 89 bool GCNRegPressure::less(const GCNSubtarget &ST, 90 const GCNRegPressure& O, 91 unsigned MaxOccupancy) const { 92 const auto SGPROcc = std::min(MaxOccupancy, 93 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 94 const auto VGPROcc = 95 std::min(MaxOccupancy, 96 ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts()))); 97 const auto OtherSGPROcc = std::min(MaxOccupancy, 98 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 99 const auto OtherVGPROcc = 100 std::min(MaxOccupancy, 101 ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts()))); 102 103 const auto Occ = std::min(SGPROcc, VGPROcc); 104 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 105 if (Occ != OtherOcc) 106 return Occ > OtherOcc; 107 108 bool SGPRImportant = SGPROcc < VGPROcc; 109 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 110 111 // if both pressures disagree on what is more important compare vgprs 112 if (SGPRImportant != OtherSGPRImportant) { 113 SGPRImportant = false; 114 } 115 116 // compare large regs pressure 117 bool SGPRFirst = SGPRImportant; 118 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 119 if (SGPRFirst) { 120 auto SW = getSGPRTuplesWeight(); 121 auto OtherSW = O.getSGPRTuplesWeight(); 122 if (SW != OtherSW) 123 return SW < OtherSW; 124 } else { 125 auto VW = getVGPRTuplesWeight(); 126 auto OtherVW = O.getVGPRTuplesWeight(); 127 if (VW != OtherVW) 128 return VW < OtherVW; 129 } 130 } 131 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 132 (getVGPRNum(ST.hasGFX90AInsts()) < 133 O.getVGPRNum(ST.hasGFX90AInsts())); 134 } 135 136 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 137 LLVM_DUMP_METHOD 138 Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) { 139 return Printable([&RP, ST](raw_ostream &OS) { 140 OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' ' 141 << "AGPRs: " << RP.getAGPRNum(); 142 if (ST) 143 OS << "(O" 144 << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts())) 145 << ')'; 146 OS << ", SGPRs: " << RP.getSGPRNum(); 147 if (ST) 148 OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')'; 149 OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight() 150 << ", LSGPR WT: " << RP.getSGPRTuplesWeight(); 151 if (ST) 152 OS << " -> Occ: " << RP.getOccupancy(*ST); 153 OS << '\n'; 154 }); 155 } 156 #endif 157 158 static LaneBitmask getDefRegMask(const MachineOperand &MO, 159 const MachineRegisterInfo &MRI) { 160 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual()); 161 162 // We don't rely on read-undef flag because in case of tentative schedule 163 // tracking it isn't set correctly yet. This works correctly however since 164 // use mask has been tracked before using LIS. 165 return MO.getSubReg() == 0 ? 166 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 167 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 168 } 169 170 static LaneBitmask getUsedRegMask(const MachineOperand &MO, 171 const MachineRegisterInfo &MRI, 172 const LiveIntervals &LIS) { 173 assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual()); 174 175 if (auto SubReg = MO.getSubReg()) 176 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg); 177 178 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg()); 179 if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs 180 return MaxMask; 181 182 // For a tentative schedule LIS isn't updated yet but livemask should remain 183 // the same on any schedule. Subreg defs can be reordered but they all must 184 // dominate uses anyway. 185 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex(); 186 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI); 187 } 188 189 static SmallVector<RegisterMaskPair, 8> 190 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS, 191 const MachineRegisterInfo &MRI) { 192 SmallVector<RegisterMaskPair, 8> Res; 193 for (const auto &MO : MI.operands()) { 194 if (!MO.isReg() || !MO.getReg().isVirtual()) 195 continue; 196 if (!MO.isUse() || !MO.readsReg()) 197 continue; 198 199 auto const UsedMask = getUsedRegMask(MO, MRI, LIS); 200 201 auto Reg = MO.getReg(); 202 auto I = llvm::find_if( 203 Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; }); 204 if (I != Res.end()) 205 I->LaneMask |= UsedMask; 206 else 207 Res.push_back(RegisterMaskPair(Reg, UsedMask)); 208 } 209 return Res; 210 } 211 212 /////////////////////////////////////////////////////////////////////////////// 213 // GCNRPTracker 214 215 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, 216 SlotIndex SI, 217 const LiveIntervals &LIS, 218 const MachineRegisterInfo &MRI) { 219 LaneBitmask LiveMask; 220 const auto &LI = LIS.getInterval(Reg); 221 if (LI.hasSubRanges()) { 222 for (const auto &S : LI.subranges()) 223 if (S.liveAt(SI)) { 224 LiveMask |= S.LaneMask; 225 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) || 226 LiveMask == MRI.getMaxLaneMaskForVReg(Reg)); 227 } 228 } else if (LI.liveAt(SI)) { 229 LiveMask = MRI.getMaxLaneMaskForVReg(Reg); 230 } 231 return LiveMask; 232 } 233 234 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 235 const LiveIntervals &LIS, 236 const MachineRegisterInfo &MRI) { 237 GCNRPTracker::LiveRegSet LiveRegs; 238 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 239 auto Reg = Register::index2VirtReg(I); 240 if (!LIS.hasInterval(Reg)) 241 continue; 242 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 243 if (LiveMask.any()) 244 LiveRegs[Reg] = LiveMask; 245 } 246 return LiveRegs; 247 } 248 249 void GCNRPTracker::reset(const MachineInstr &MI, 250 const LiveRegSet *LiveRegsCopy, 251 bool After) { 252 const MachineFunction &MF = *MI.getMF(); 253 MRI = &MF.getRegInfo(); 254 if (LiveRegsCopy) { 255 if (&LiveRegs != LiveRegsCopy) 256 LiveRegs = *LiveRegsCopy; 257 } else { 258 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 259 : getLiveRegsBefore(MI, LIS); 260 } 261 262 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 263 } 264 265 void GCNUpwardRPTracker::reset(const MachineInstr &MI, 266 const LiveRegSet *LiveRegsCopy) { 267 GCNRPTracker::reset(MI, LiveRegsCopy, true); 268 } 269 270 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 271 assert(MRI && "call reset first"); 272 273 LastTrackedMI = &MI; 274 275 if (MI.isDebugInstr()) 276 return; 277 278 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI); 279 280 // calc pressure at the MI (defs + uses) 281 auto AtMIPressure = CurPressure; 282 for (const auto &U : RegUses) { 283 auto LiveMask = LiveRegs[U.RegUnit]; 284 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI); 285 } 286 // update max pressure 287 MaxPressure = max(AtMIPressure, MaxPressure); 288 289 for (const auto &MO : MI.all_defs()) { 290 if (!MO.getReg().isVirtual() || MO.isDead()) 291 continue; 292 293 auto Reg = MO.getReg(); 294 auto I = LiveRegs.find(Reg); 295 if (I == LiveRegs.end()) 296 continue; 297 auto &LiveMask = I->second; 298 auto PrevMask = LiveMask; 299 LiveMask &= ~getDefRegMask(MO, *MRI); 300 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 301 if (LiveMask.none()) 302 LiveRegs.erase(I); 303 } 304 for (const auto &U : RegUses) { 305 auto &LiveMask = LiveRegs[U.RegUnit]; 306 auto PrevMask = LiveMask; 307 LiveMask |= U.LaneMask; 308 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 309 } 310 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 311 } 312 313 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 314 const LiveRegSet *LiveRegsCopy) { 315 MRI = &MI.getParent()->getParent()->getRegInfo(); 316 LastTrackedMI = nullptr; 317 MBBEnd = MI.getParent()->end(); 318 NextMI = &MI; 319 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 320 if (NextMI == MBBEnd) 321 return false; 322 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 323 return true; 324 } 325 326 bool GCNDownwardRPTracker::advanceBeforeNext() { 327 assert(MRI && "call reset first"); 328 if (!LastTrackedMI) 329 return NextMI == MBBEnd; 330 331 assert(NextMI == MBBEnd || !NextMI->isDebugInstr()); 332 333 SlotIndex SI = NextMI == MBBEnd 334 ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot() 335 : LIS.getInstructionIndex(*NextMI).getBaseIndex(); 336 assert(SI.isValid()); 337 338 // Remove dead registers or mask bits. 339 SmallSet<Register, 8> SeenRegs; 340 for (auto &MO : LastTrackedMI->operands()) { 341 if (!MO.isReg() || !MO.getReg().isVirtual()) 342 continue; 343 if (MO.isUse() && !MO.readsReg()) 344 continue; 345 if (!SeenRegs.insert(MO.getReg()).second) 346 continue; 347 const LiveInterval &LI = LIS.getInterval(MO.getReg()); 348 if (LI.hasSubRanges()) { 349 auto It = LiveRegs.end(); 350 for (const auto &S : LI.subranges()) { 351 if (!S.liveAt(SI)) { 352 if (It == LiveRegs.end()) { 353 It = LiveRegs.find(MO.getReg()); 354 if (It == LiveRegs.end()) 355 llvm_unreachable("register isn't live"); 356 } 357 auto PrevMask = It->second; 358 It->second &= ~S.LaneMask; 359 CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI); 360 } 361 } 362 if (It != LiveRegs.end() && It->second.none()) 363 LiveRegs.erase(It); 364 } else if (!LI.liveAt(SI)) { 365 auto It = LiveRegs.find(MO.getReg()); 366 if (It == LiveRegs.end()) 367 llvm_unreachable("register isn't live"); 368 CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI); 369 LiveRegs.erase(It); 370 } 371 } 372 373 MaxPressure = max(MaxPressure, CurPressure); 374 375 LastTrackedMI = nullptr; 376 377 return NextMI == MBBEnd; 378 } 379 380 void GCNDownwardRPTracker::advanceToNext() { 381 LastTrackedMI = &*NextMI++; 382 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 383 384 // Add new registers or mask bits. 385 for (const auto &MO : LastTrackedMI->all_defs()) { 386 Register Reg = MO.getReg(); 387 if (!Reg.isVirtual()) 388 continue; 389 auto &LiveMask = LiveRegs[Reg]; 390 auto PrevMask = LiveMask; 391 LiveMask |= getDefRegMask(MO, *MRI); 392 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 393 } 394 395 MaxPressure = max(MaxPressure, CurPressure); 396 } 397 398 bool GCNDownwardRPTracker::advance() { 399 if (NextMI == MBBEnd) 400 return false; 401 advanceBeforeNext(); 402 advanceToNext(); 403 return true; 404 } 405 406 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 407 while (NextMI != End) 408 if (!advance()) return false; 409 return true; 410 } 411 412 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 413 MachineBasicBlock::const_iterator End, 414 const LiveRegSet *LiveRegsCopy) { 415 reset(*Begin, LiveRegsCopy); 416 return advance(End); 417 } 418 419 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 420 LLVM_DUMP_METHOD 421 Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 422 const GCNRPTracker::LiveRegSet &TrackedLR, 423 const TargetRegisterInfo *TRI) { 424 return Printable([&LISLR, &TrackedLR, TRI](raw_ostream &OS) { 425 for (auto const &P : TrackedLR) { 426 auto I = LISLR.find(P.first); 427 if (I == LISLR.end()) { 428 OS << " " << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second) 429 << " isn't found in LIS reported set\n"; 430 } else if (I->second != P.second) { 431 OS << " " << printReg(P.first, TRI) 432 << " masks doesn't match: LIS reported " << PrintLaneMask(I->second) 433 << ", tracked " << PrintLaneMask(P.second) << '\n'; 434 } 435 } 436 for (auto const &P : LISLR) { 437 auto I = TrackedLR.find(P.first); 438 if (I == TrackedLR.end()) { 439 OS << " " << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second) 440 << " isn't found in tracked set\n"; 441 } 442 } 443 }); 444 } 445 446 bool GCNUpwardRPTracker::isValid() const { 447 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 448 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 449 const auto &TrackedLR = LiveRegs; 450 451 if (!isEqual(LISLR, TrackedLR)) { 452 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 453 " LIS reported livesets mismatch:\n" 454 << print(LISLR, *MRI); 455 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 456 return false; 457 } 458 459 auto LISPressure = getRegPressure(*MRI, LISLR); 460 if (LISPressure != CurPressure) { 461 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: " 462 << print(CurPressure) << "LIS rpt: " << print(LISPressure); 463 return false; 464 } 465 return true; 466 } 467 468 LLVM_DUMP_METHOD 469 Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs, 470 const MachineRegisterInfo &MRI) { 471 return Printable([&LiveRegs, &MRI](raw_ostream &OS) { 472 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 473 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 474 Register Reg = Register::index2VirtReg(I); 475 auto It = LiveRegs.find(Reg); 476 if (It != LiveRegs.end() && It->second.any()) 477 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 478 << PrintLaneMask(It->second); 479 } 480 OS << '\n'; 481 }); 482 } 483 484 LLVM_DUMP_METHOD 485 void GCNRegPressure::dump() const { dbgs() << print(*this); } 486 487 #endif 488