1 //===- GCNRegPressure.h -----------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file defines the GCNRegPressure class, which tracks registry pressure 11 /// by bookkeeping number of SGPR/VGPRs used, weights for large SGPR/VGPRs. It 12 /// also implements a compare function, which compares different register 13 /// pressures, and declares one with max occupancy as winner. 14 /// 15 //===----------------------------------------------------------------------===// 16 17 #ifndef LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H 18 #define LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H 19 20 #include "GCNSubtarget.h" 21 #include "llvm/CodeGen/LiveIntervals.h" 22 #include <algorithm> 23 24 namespace llvm { 25 26 class MachineRegisterInfo; 27 class raw_ostream; 28 class SlotIndex; 29 30 struct GCNRegPressure { 31 enum RegKind { 32 SGPR32, 33 SGPR_TUPLE, 34 VGPR32, 35 VGPR_TUPLE, 36 AGPR32, 37 AGPR_TUPLE, 38 TOTAL_KINDS 39 }; 40 41 GCNRegPressure() { 42 clear(); 43 } 44 45 bool empty() const { return getSGPRNum() == 0 && getVGPRNum(false) == 0; } 46 47 void clear() { std::fill(&Value[0], &Value[TOTAL_KINDS], 0); } 48 49 unsigned getSGPRNum() const { return Value[SGPR32]; } 50 unsigned getVGPRNum(bool UnifiedVGPRFile) const { 51 if (UnifiedVGPRFile) { 52 return Value[AGPR32] ? alignTo(Value[VGPR32], 4) + Value[AGPR32] 53 : Value[VGPR32] + Value[AGPR32]; 54 } 55 return std::max(Value[VGPR32], Value[AGPR32]); 56 } 57 unsigned getAGPRNum() const { return Value[AGPR32]; } 58 59 unsigned getVGPRTuplesWeight() const { return std::max(Value[VGPR_TUPLE], 60 Value[AGPR_TUPLE]); } 61 unsigned getSGPRTuplesWeight() const { return Value[SGPR_TUPLE]; } 62 63 unsigned getOccupancy(const GCNSubtarget &ST) const { 64 return std::min(ST.getOccupancyWithNumSGPRs(getSGPRNum()), 65 ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts()))); 66 } 67 68 void inc(unsigned Reg, 69 LaneBitmask PrevMask, 70 LaneBitmask NewMask, 71 const MachineRegisterInfo &MRI); 72 73 bool higherOccupancy(const GCNSubtarget &ST, const GCNRegPressure& O) const { 74 return getOccupancy(ST) > O.getOccupancy(ST); 75 } 76 77 /// Compares \p this GCNRegpressure to \p O, returning true if \p this is 78 /// less. Since GCNRegpressure contains different types of pressures, and due 79 /// to target-specific pecularities (e.g. we care about occupancy rather than 80 /// raw register usage), we determine if \p this GCNRegPressure is less than 81 /// \p O based on the following tiered comparisons (in order order of 82 /// precedence): 83 /// 1. Better occupancy 84 /// 2. Less spilling (first preference to VGPR spills, then to SGPR spills) 85 /// 3. Less tuple register pressure (first preference to VGPR tuples if we 86 /// determine that SGPR pressure is not important) 87 /// 4. Less raw register pressure (first preference to VGPR tuples if we 88 /// determine that SGPR pressure is not important) 89 bool less(const MachineFunction &MF, const GCNRegPressure &O, 90 unsigned MaxOccupancy = std::numeric_limits<unsigned>::max()) const; 91 92 bool operator==(const GCNRegPressure &O) const { 93 return std::equal(&Value[0], &Value[TOTAL_KINDS], O.Value); 94 } 95 96 bool operator!=(const GCNRegPressure &O) const { 97 return !(*this == O); 98 } 99 100 GCNRegPressure &operator+=(const GCNRegPressure &RHS) { 101 for (unsigned I = 0; I < TOTAL_KINDS; ++I) 102 Value[I] += RHS.Value[I]; 103 return *this; 104 } 105 106 GCNRegPressure &operator-=(const GCNRegPressure &RHS) { 107 for (unsigned I = 0; I < TOTAL_KINDS; ++I) 108 Value[I] -= RHS.Value[I]; 109 return *this; 110 } 111 112 void dump() const; 113 114 private: 115 unsigned Value[TOTAL_KINDS]; 116 117 static unsigned getRegKind(Register Reg, const MachineRegisterInfo &MRI); 118 119 friend GCNRegPressure max(const GCNRegPressure &P1, 120 const GCNRegPressure &P2); 121 122 friend Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST); 123 }; 124 125 inline GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2) { 126 GCNRegPressure Res; 127 for (unsigned I = 0; I < GCNRegPressure::TOTAL_KINDS; ++I) 128 Res.Value[I] = std::max(P1.Value[I], P2.Value[I]); 129 return Res; 130 } 131 132 inline GCNRegPressure operator+(const GCNRegPressure &P1, 133 const GCNRegPressure &P2) { 134 GCNRegPressure Sum = P1; 135 Sum += P2; 136 return Sum; 137 } 138 139 inline GCNRegPressure operator-(const GCNRegPressure &P1, 140 const GCNRegPressure &P2) { 141 GCNRegPressure Diff = P1; 142 Diff -= P2; 143 return Diff; 144 } 145 146 class GCNRPTracker { 147 public: 148 using LiveRegSet = DenseMap<unsigned, LaneBitmask>; 149 150 protected: 151 const LiveIntervals &LIS; 152 LiveRegSet LiveRegs; 153 GCNRegPressure CurPressure, MaxPressure; 154 const MachineInstr *LastTrackedMI = nullptr; 155 mutable const MachineRegisterInfo *MRI = nullptr; 156 157 GCNRPTracker(const LiveIntervals &LIS_) : LIS(LIS_) {} 158 159 void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy, 160 bool After); 161 162 public: 163 // live regs for the current state 164 const decltype(LiveRegs) &getLiveRegs() const { return LiveRegs; } 165 const MachineInstr *getLastTrackedMI() const { return LastTrackedMI; } 166 167 void clearMaxPressure() { MaxPressure.clear(); } 168 169 GCNRegPressure getPressure() const { return CurPressure; } 170 171 decltype(LiveRegs) moveLiveRegs() { 172 return std::move(LiveRegs); 173 } 174 }; 175 176 GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS, 177 const MachineRegisterInfo &MRI); 178 179 class GCNUpwardRPTracker : public GCNRPTracker { 180 public: 181 GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {} 182 183 // reset tracker and set live register set to the specified value. 184 void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_); 185 186 // reset tracker at the specified slot index. 187 void reset(const MachineRegisterInfo &MRI, SlotIndex SI) { 188 reset(MRI, llvm::getLiveRegs(SI, LIS, MRI)); 189 } 190 191 // reset tracker to the end of the MBB. 192 void reset(const MachineBasicBlock &MBB) { 193 reset(MBB.getParent()->getRegInfo(), 194 LIS.getSlotIndexes()->getMBBEndIdx(&MBB)); 195 } 196 197 // reset tracker to the point just after MI (in program order). 198 void reset(const MachineInstr &MI) { 199 reset(MI.getMF()->getRegInfo(), LIS.getInstructionIndex(MI).getDeadSlot()); 200 } 201 202 // move to the state just before the MI (in program order). 203 void recede(const MachineInstr &MI); 204 205 // checks whether the tracker's state after receding MI corresponds 206 // to reported by LIS. 207 bool isValid() const; 208 209 const GCNRegPressure &getMaxPressure() const { return MaxPressure; } 210 211 void resetMaxPressure() { MaxPressure = CurPressure; } 212 213 GCNRegPressure getMaxPressureAndReset() { 214 GCNRegPressure RP = MaxPressure; 215 resetMaxPressure(); 216 return RP; 217 } 218 }; 219 220 class GCNDownwardRPTracker : public GCNRPTracker { 221 // Last position of reset or advanceBeforeNext 222 MachineBasicBlock::const_iterator NextMI; 223 224 MachineBasicBlock::const_iterator MBBEnd; 225 226 public: 227 GCNDownwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {} 228 229 MachineBasicBlock::const_iterator getNext() const { return NextMI; } 230 231 // Return MaxPressure and clear it. 232 GCNRegPressure moveMaxPressure() { 233 auto Res = MaxPressure; 234 MaxPressure.clear(); 235 return Res; 236 } 237 238 // Reset tracker to the point before the MI 239 // filling live regs upon this point using LIS. 240 // Returns false if block is empty except debug values. 241 bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr); 242 243 // Move to the state right before the next MI or after the end of MBB. 244 // Returns false if reached end of the block. 245 bool advanceBeforeNext(); 246 247 // Move to the state at the MI, advanceBeforeNext has to be called first. 248 void advanceToNext(); 249 250 // Move to the state at the next MI. Returns false if reached end of block. 251 bool advance(); 252 253 // Advance instructions until before End. 254 bool advance(MachineBasicBlock::const_iterator End); 255 256 // Reset to Begin and advance to End. 257 bool advance(MachineBasicBlock::const_iterator Begin, 258 MachineBasicBlock::const_iterator End, 259 const LiveRegSet *LiveRegsCopy = nullptr); 260 }; 261 262 LaneBitmask getLiveLaneMask(unsigned Reg, 263 SlotIndex SI, 264 const LiveIntervals &LIS, 265 const MachineRegisterInfo &MRI); 266 267 LaneBitmask getLiveLaneMask(const LiveInterval &LI, SlotIndex SI, 268 const MachineRegisterInfo &MRI); 269 270 GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS, 271 const MachineRegisterInfo &MRI); 272 273 /// creates a map MachineInstr -> LiveRegSet 274 /// R - range of iterators on instructions 275 /// After - upon entry or exit of every instruction 276 /// Note: there is no entry in the map for instructions with empty live reg set 277 /// Complexity = O(NumVirtRegs * averageLiveRangeSegmentsPerReg * lg(R)) 278 template <typename Range> 279 DenseMap<MachineInstr*, GCNRPTracker::LiveRegSet> 280 getLiveRegMap(Range &&R, bool After, LiveIntervals &LIS) { 281 std::vector<SlotIndex> Indexes; 282 Indexes.reserve(std::distance(R.begin(), R.end())); 283 auto &SII = *LIS.getSlotIndexes(); 284 for (MachineInstr *I : R) { 285 auto SI = SII.getInstructionIndex(*I); 286 Indexes.push_back(After ? SI.getDeadSlot() : SI.getBaseIndex()); 287 } 288 llvm::sort(Indexes); 289 290 auto &MRI = (*R.begin())->getParent()->getParent()->getRegInfo(); 291 DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> LiveRegMap; 292 SmallVector<SlotIndex, 32> LiveIdxs, SRLiveIdxs; 293 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 294 auto Reg = Register::index2VirtReg(I); 295 if (!LIS.hasInterval(Reg)) 296 continue; 297 auto &LI = LIS.getInterval(Reg); 298 LiveIdxs.clear(); 299 if (!LI.findIndexesLiveAt(Indexes, std::back_inserter(LiveIdxs))) 300 continue; 301 if (!LI.hasSubRanges()) { 302 for (auto SI : LiveIdxs) 303 LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] = 304 MRI.getMaxLaneMaskForVReg(Reg); 305 } else 306 for (const auto &S : LI.subranges()) { 307 // constrain search for subranges by indexes live at main range 308 SRLiveIdxs.clear(); 309 S.findIndexesLiveAt(LiveIdxs, std::back_inserter(SRLiveIdxs)); 310 for (auto SI : SRLiveIdxs) 311 LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] |= S.LaneMask; 312 } 313 } 314 return LiveRegMap; 315 } 316 317 inline GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI, 318 const LiveIntervals &LIS) { 319 return getLiveRegs(LIS.getInstructionIndex(MI).getDeadSlot(), LIS, 320 MI.getParent()->getParent()->getRegInfo()); 321 } 322 323 inline GCNRPTracker::LiveRegSet getLiveRegsBefore(const MachineInstr &MI, 324 const LiveIntervals &LIS) { 325 return getLiveRegs(LIS.getInstructionIndex(MI).getBaseIndex(), LIS, 326 MI.getParent()->getParent()->getRegInfo()); 327 } 328 329 template <typename Range> 330 GCNRegPressure getRegPressure(const MachineRegisterInfo &MRI, 331 Range &&LiveRegs) { 332 GCNRegPressure Res; 333 for (const auto &RM : LiveRegs) 334 Res.inc(RM.first, LaneBitmask::getNone(), RM.second, MRI); 335 return Res; 336 } 337 338 bool isEqual(const GCNRPTracker::LiveRegSet &S1, 339 const GCNRPTracker::LiveRegSet &S2); 340 341 Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST = nullptr); 342 343 Printable print(const GCNRPTracker::LiveRegSet &LiveRegs, 344 const MachineRegisterInfo &MRI); 345 346 Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 347 const GCNRPTracker::LiveRegSet &TrackedL, 348 const TargetRegisterInfo *TRI, StringRef Pfx = " "); 349 350 struct GCNRegPressurePrinter : public MachineFunctionPass { 351 static char ID; 352 353 public: 354 GCNRegPressurePrinter() : MachineFunctionPass(ID) {} 355 356 bool runOnMachineFunction(MachineFunction &MF) override; 357 358 void getAnalysisUsage(AnalysisUsage &AU) const override { 359 AU.addRequired<LiveIntervalsWrapperPass>(); 360 AU.setPreservesAll(); 361 MachineFunctionPass::getAnalysisUsage(AU); 362 } 363 }; 364 365 } // end namespace llvm 366 367 #endif // LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H 368