1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the GCNRegPressure class. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "GCNRegPressure.h" 15 #include "AMDGPUSubtarget.h" 16 #include "SIRegisterInfo.h" 17 #include "llvm/ADT/SmallVector.h" 18 #include "llvm/CodeGen/LiveInterval.h" 19 #include "llvm/CodeGen/LiveIntervals.h" 20 #include "llvm/CodeGen/MachineInstr.h" 21 #include "llvm/CodeGen/MachineOperand.h" 22 #include "llvm/CodeGen/MachineRegisterInfo.h" 23 #include "llvm/CodeGen/RegisterPressure.h" 24 #include "llvm/CodeGen/SlotIndexes.h" 25 #include "llvm/CodeGen/TargetRegisterInfo.h" 26 #include "llvm/Config/llvm-config.h" 27 #include "llvm/MC/LaneBitmask.h" 28 #include "llvm/Support/Compiler.h" 29 #include "llvm/Support/Debug.h" 30 #include "llvm/Support/ErrorHandling.h" 31 #include "llvm/Support/raw_ostream.h" 32 #include <algorithm> 33 #include <cassert> 34 35 using namespace llvm; 36 37 #define DEBUG_TYPE "machine-scheduler" 38 39 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 40 LLVM_DUMP_METHOD 41 void llvm::printLivesAt(SlotIndex SI, 42 const LiveIntervals &LIS, 43 const MachineRegisterInfo &MRI) { 44 dbgs() << "Live regs at " << SI << ": " 45 << *LIS.getInstructionFromIndex(SI); 46 unsigned Num = 0; 47 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 48 const unsigned Reg = Register::index2VirtReg(I); 49 if (!LIS.hasInterval(Reg)) 50 continue; 51 const auto &LI = LIS.getInterval(Reg); 52 if (LI.hasSubRanges()) { 53 bool firstTime = true; 54 for (const auto &S : LI.subranges()) { 55 if (!S.liveAt(SI)) continue; 56 if (firstTime) { 57 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo()) 58 << '\n'; 59 firstTime = false; 60 } 61 dbgs() << " " << S << '\n'; 62 ++Num; 63 } 64 } else if (LI.liveAt(SI)) { 65 dbgs() << " " << LI << '\n'; 66 ++Num; 67 } 68 } 69 if (!Num) dbgs() << " <none>\n"; 70 } 71 #endif 72 73 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 74 const GCNRPTracker::LiveRegSet &S2) { 75 if (S1.size() != S2.size()) 76 return false; 77 78 for (const auto &P : S1) { 79 auto I = S2.find(P.first); 80 if (I == S2.end() || I->second != P.second) 81 return false; 82 } 83 return true; 84 } 85 86 87 /////////////////////////////////////////////////////////////////////////////// 88 // GCNRegPressure 89 90 unsigned GCNRegPressure::getRegKind(unsigned Reg, 91 const MachineRegisterInfo &MRI) { 92 assert(Register::isVirtualRegister(Reg)); 93 const auto RC = MRI.getRegClass(Reg); 94 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo()); 95 return STI->isSGPRClass(RC) ? 96 (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) : 97 STI->hasAGPRs(RC) ? 98 (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) : 99 (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 100 } 101 102 void GCNRegPressure::inc(unsigned Reg, 103 LaneBitmask PrevMask, 104 LaneBitmask NewMask, 105 const MachineRegisterInfo &MRI) { 106 if (SIRegisterInfo::getNumCoveredRegs(NewMask) == 107 SIRegisterInfo::getNumCoveredRegs(PrevMask)) 108 return; 109 110 int Sign = 1; 111 if (NewMask < PrevMask) { 112 std::swap(NewMask, PrevMask); 113 Sign = -1; 114 } 115 116 switch (auto Kind = getRegKind(Reg, MRI)) { 117 case SGPR32: 118 case VGPR32: 119 case AGPR32: 120 Value[Kind] += Sign; 121 break; 122 123 case SGPR_TUPLE: 124 case VGPR_TUPLE: 125 case AGPR_TUPLE: 126 assert(PrevMask < NewMask); 127 128 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] += 129 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask); 130 131 if (PrevMask.none()) { 132 assert(NewMask.any()); 133 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight(); 134 } 135 break; 136 137 default: llvm_unreachable("Unknown register kind"); 138 } 139 } 140 141 bool GCNRegPressure::less(const GCNSubtarget &ST, 142 const GCNRegPressure& O, 143 unsigned MaxOccupancy) const { 144 const auto SGPROcc = std::min(MaxOccupancy, 145 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 146 const auto VGPROcc = std::min(MaxOccupancy, 147 ST.getOccupancyWithNumVGPRs(getVGPRNum())); 148 const auto OtherSGPROcc = std::min(MaxOccupancy, 149 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 150 const auto OtherVGPROcc = std::min(MaxOccupancy, 151 ST.getOccupancyWithNumVGPRs(O.getVGPRNum())); 152 153 const auto Occ = std::min(SGPROcc, VGPROcc); 154 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 155 if (Occ != OtherOcc) 156 return Occ > OtherOcc; 157 158 bool SGPRImportant = SGPROcc < VGPROcc; 159 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 160 161 // if both pressures disagree on what is more important compare vgprs 162 if (SGPRImportant != OtherSGPRImportant) { 163 SGPRImportant = false; 164 } 165 166 // compare large regs pressure 167 bool SGPRFirst = SGPRImportant; 168 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 169 if (SGPRFirst) { 170 auto SW = getSGPRTuplesWeight(); 171 auto OtherSW = O.getSGPRTuplesWeight(); 172 if (SW != OtherSW) 173 return SW < OtherSW; 174 } else { 175 auto VW = getVGPRTuplesWeight(); 176 auto OtherVW = O.getVGPRTuplesWeight(); 177 if (VW != OtherVW) 178 return VW < OtherVW; 179 } 180 } 181 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 182 (getVGPRNum() < O.getVGPRNum()); 183 } 184 185 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 186 LLVM_DUMP_METHOD 187 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const { 188 OS << "VGPRs: " << Value[VGPR32] << ' '; 189 OS << "AGPRs: " << Value[AGPR32]; 190 if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')'; 191 OS << ", SGPRs: " << getSGPRNum(); 192 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')'; 193 OS << ", LVGPR WT: " << getVGPRTuplesWeight() 194 << ", LSGPR WT: " << getSGPRTuplesWeight(); 195 if (ST) OS << " -> Occ: " << getOccupancy(*ST); 196 OS << '\n'; 197 } 198 #endif 199 200 static LaneBitmask getDefRegMask(const MachineOperand &MO, 201 const MachineRegisterInfo &MRI) { 202 assert(MO.isDef() && MO.isReg() && Register::isVirtualRegister(MO.getReg())); 203 204 // We don't rely on read-undef flag because in case of tentative schedule 205 // tracking it isn't set correctly yet. This works correctly however since 206 // use mask has been tracked before using LIS. 207 return MO.getSubReg() == 0 ? 208 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 209 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 210 } 211 212 static LaneBitmask getUsedRegMask(const MachineOperand &MO, 213 const MachineRegisterInfo &MRI, 214 const LiveIntervals &LIS) { 215 assert(MO.isUse() && MO.isReg() && Register::isVirtualRegister(MO.getReg())); 216 217 if (auto SubReg = MO.getSubReg()) 218 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg); 219 220 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg()); 221 if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs 222 return MaxMask; 223 224 // For a tentative schedule LIS isn't updated yet but livemask should remain 225 // the same on any schedule. Subreg defs can be reordered but they all must 226 // dominate uses anyway. 227 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex(); 228 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI); 229 } 230 231 static SmallVector<RegisterMaskPair, 8> 232 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS, 233 const MachineRegisterInfo &MRI) { 234 SmallVector<RegisterMaskPair, 8> Res; 235 for (const auto &MO : MI.operands()) { 236 if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg())) 237 continue; 238 if (!MO.isUse() || !MO.readsReg()) 239 continue; 240 241 auto const UsedMask = getUsedRegMask(MO, MRI, LIS); 242 243 auto Reg = MO.getReg(); 244 auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) { 245 return RM.RegUnit == Reg; 246 }); 247 if (I != Res.end()) 248 I->LaneMask |= UsedMask; 249 else 250 Res.push_back(RegisterMaskPair(Reg, UsedMask)); 251 } 252 return Res; 253 } 254 255 /////////////////////////////////////////////////////////////////////////////// 256 // GCNRPTracker 257 258 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, 259 SlotIndex SI, 260 const LiveIntervals &LIS, 261 const MachineRegisterInfo &MRI) { 262 LaneBitmask LiveMask; 263 const auto &LI = LIS.getInterval(Reg); 264 if (LI.hasSubRanges()) { 265 for (const auto &S : LI.subranges()) 266 if (S.liveAt(SI)) { 267 LiveMask |= S.LaneMask; 268 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) || 269 LiveMask == MRI.getMaxLaneMaskForVReg(Reg)); 270 } 271 } else if (LI.liveAt(SI)) { 272 LiveMask = MRI.getMaxLaneMaskForVReg(Reg); 273 } 274 return LiveMask; 275 } 276 277 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 278 const LiveIntervals &LIS, 279 const MachineRegisterInfo &MRI) { 280 GCNRPTracker::LiveRegSet LiveRegs; 281 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 282 auto Reg = Register::index2VirtReg(I); 283 if (!LIS.hasInterval(Reg)) 284 continue; 285 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 286 if (LiveMask.any()) 287 LiveRegs[Reg] = LiveMask; 288 } 289 return LiveRegs; 290 } 291 292 void GCNRPTracker::reset(const MachineInstr &MI, 293 const LiveRegSet *LiveRegsCopy, 294 bool After) { 295 const MachineFunction &MF = *MI.getMF(); 296 MRI = &MF.getRegInfo(); 297 if (LiveRegsCopy) { 298 if (&LiveRegs != LiveRegsCopy) 299 LiveRegs = *LiveRegsCopy; 300 } else { 301 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 302 : getLiveRegsBefore(MI, LIS); 303 } 304 305 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 306 } 307 308 void GCNUpwardRPTracker::reset(const MachineInstr &MI, 309 const LiveRegSet *LiveRegsCopy) { 310 GCNRPTracker::reset(MI, LiveRegsCopy, true); 311 } 312 313 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 314 assert(MRI && "call reset first"); 315 316 LastTrackedMI = &MI; 317 318 if (MI.isDebugInstr()) 319 return; 320 321 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI); 322 323 // calc pressure at the MI (defs + uses) 324 auto AtMIPressure = CurPressure; 325 for (const auto &U : RegUses) { 326 auto LiveMask = LiveRegs[U.RegUnit]; 327 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI); 328 } 329 // update max pressure 330 MaxPressure = max(AtMIPressure, MaxPressure); 331 332 for (const auto &MO : MI.operands()) { 333 if (!MO.isReg() || !MO.isDef() || 334 !Register::isVirtualRegister(MO.getReg()) || MO.isDead()) 335 continue; 336 337 auto Reg = MO.getReg(); 338 auto I = LiveRegs.find(Reg); 339 if (I == LiveRegs.end()) 340 continue; 341 auto &LiveMask = I->second; 342 auto PrevMask = LiveMask; 343 LiveMask &= ~getDefRegMask(MO, *MRI); 344 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 345 if (LiveMask.none()) 346 LiveRegs.erase(I); 347 } 348 for (const auto &U : RegUses) { 349 auto &LiveMask = LiveRegs[U.RegUnit]; 350 auto PrevMask = LiveMask; 351 LiveMask |= U.LaneMask; 352 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 353 } 354 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 355 } 356 357 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 358 const LiveRegSet *LiveRegsCopy) { 359 MRI = &MI.getParent()->getParent()->getRegInfo(); 360 LastTrackedMI = nullptr; 361 MBBEnd = MI.getParent()->end(); 362 NextMI = &MI; 363 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 364 if (NextMI == MBBEnd) 365 return false; 366 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 367 return true; 368 } 369 370 bool GCNDownwardRPTracker::advanceBeforeNext() { 371 assert(MRI && "call reset first"); 372 373 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 374 if (NextMI == MBBEnd) 375 return false; 376 377 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex(); 378 assert(SI.isValid()); 379 380 // Remove dead registers or mask bits. 381 for (auto &It : LiveRegs) { 382 const LiveInterval &LI = LIS.getInterval(It.first); 383 if (LI.hasSubRanges()) { 384 for (const auto &S : LI.subranges()) { 385 if (!S.liveAt(SI)) { 386 auto PrevMask = It.second; 387 It.second &= ~S.LaneMask; 388 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 389 } 390 } 391 } else if (!LI.liveAt(SI)) { 392 auto PrevMask = It.second; 393 It.second = LaneBitmask::getNone(); 394 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 395 } 396 if (It.second.none()) 397 LiveRegs.erase(It.first); 398 } 399 400 MaxPressure = max(MaxPressure, CurPressure); 401 402 return true; 403 } 404 405 void GCNDownwardRPTracker::advanceToNext() { 406 LastTrackedMI = &*NextMI++; 407 408 // Add new registers or mask bits. 409 for (const auto &MO : LastTrackedMI->operands()) { 410 if (!MO.isReg() || !MO.isDef()) 411 continue; 412 Register Reg = MO.getReg(); 413 if (!Register::isVirtualRegister(Reg)) 414 continue; 415 auto &LiveMask = LiveRegs[Reg]; 416 auto PrevMask = LiveMask; 417 LiveMask |= getDefRegMask(MO, *MRI); 418 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 419 } 420 421 MaxPressure = max(MaxPressure, CurPressure); 422 } 423 424 bool GCNDownwardRPTracker::advance() { 425 // If we have just called reset live set is actual. 426 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext())) 427 return false; 428 advanceToNext(); 429 return true; 430 } 431 432 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 433 while (NextMI != End) 434 if (!advance()) return false; 435 return true; 436 } 437 438 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 439 MachineBasicBlock::const_iterator End, 440 const LiveRegSet *LiveRegsCopy) { 441 reset(*Begin, LiveRegsCopy); 442 return advance(End); 443 } 444 445 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 446 LLVM_DUMP_METHOD 447 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 448 const GCNRPTracker::LiveRegSet &TrackedLR, 449 const TargetRegisterInfo *TRI) { 450 for (auto const &P : TrackedLR) { 451 auto I = LISLR.find(P.first); 452 if (I == LISLR.end()) { 453 dbgs() << " " << printReg(P.first, TRI) 454 << ":L" << PrintLaneMask(P.second) 455 << " isn't found in LIS reported set\n"; 456 } 457 else if (I->second != P.second) { 458 dbgs() << " " << printReg(P.first, TRI) 459 << " masks doesn't match: LIS reported " 460 << PrintLaneMask(I->second) 461 << ", tracked " 462 << PrintLaneMask(P.second) 463 << '\n'; 464 } 465 } 466 for (auto const &P : LISLR) { 467 auto I = TrackedLR.find(P.first); 468 if (I == TrackedLR.end()) { 469 dbgs() << " " << printReg(P.first, TRI) 470 << ":L" << PrintLaneMask(P.second) 471 << " isn't found in tracked set\n"; 472 } 473 } 474 } 475 476 bool GCNUpwardRPTracker::isValid() const { 477 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 478 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 479 const auto &TrackedLR = LiveRegs; 480 481 if (!isEqual(LISLR, TrackedLR)) { 482 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 483 " LIS reported livesets mismatch:\n"; 484 printLivesAt(SI, LIS, *MRI); 485 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 486 return false; 487 } 488 489 auto LISPressure = getRegPressure(*MRI, LISLR); 490 if (LISPressure != CurPressure) { 491 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "; 492 CurPressure.print(dbgs()); 493 dbgs() << "LIS rpt: "; 494 LISPressure.print(dbgs()); 495 return false; 496 } 497 return true; 498 } 499 500 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs, 501 const MachineRegisterInfo &MRI) { 502 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 503 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 504 unsigned Reg = Register::index2VirtReg(I); 505 auto It = LiveRegs.find(Reg); 506 if (It != LiveRegs.end() && It->second.any()) 507 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 508 << PrintLaneMask(It->second); 509 } 510 OS << '\n'; 511 } 512 #endif 513