xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/GCNRegPressure.h (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- GCNRegPressure.h -----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file defines the GCNRegPressure class, which tracks register pressure
11 /// by keeping count of the SGPRs/VGPRs used and the weights of large SGPR/VGPR
12 /// tuples. It also implements a compare function, which compares different
13 /// register pressures and declares the one with the higher occupancy the winner.
14 ///
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
18 #define LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
19 
20 #include "GCNSubtarget.h"
21 #include "llvm/CodeGen/LiveIntervals.h"
22 #include <algorithm>
23 
24 namespace llvm {
25 
26 class MachineRegisterInfo;
27 class raw_ostream;
28 class SlotIndex;
29 
30 struct GCNRegPressure {
31   enum RegKind {
32     SGPR32,
33     SGPR_TUPLE,
34     VGPR32,
35     VGPR_TUPLE,
36     AGPR32,
37     AGPR_TUPLE,
38     TOTAL_KINDS
39   };
40 
GCNRegPressureGCNRegPressure41   GCNRegPressure() {
42     clear();
43   }
44 
emptyGCNRegPressure45   bool empty() const { return getSGPRNum() == 0 && getVGPRNum(false) == 0; }
46 
clearGCNRegPressure47   void clear() { std::fill(&Value[0], &Value[TOTAL_KINDS], 0); }
48 
getSGPRNumGCNRegPressure49   unsigned getSGPRNum() const { return Value[SGPR32]; }
getVGPRNumGCNRegPressure50   unsigned getVGPRNum(bool UnifiedVGPRFile) const {
51     if (UnifiedVGPRFile) {
52       return Value[AGPR32] ? alignTo(Value[VGPR32], 4) + Value[AGPR32]
53                            : Value[VGPR32] + Value[AGPR32];
54     }
55     return std::max(Value[VGPR32], Value[AGPR32]);
56   }
getAGPRNumGCNRegPressure57   unsigned getAGPRNum() const { return Value[AGPR32]; }
58 
getVGPRTuplesWeightGCNRegPressure59   unsigned getVGPRTuplesWeight() const { return std::max(Value[VGPR_TUPLE],
60                                                          Value[AGPR_TUPLE]); }
getSGPRTuplesWeightGCNRegPressure61   unsigned getSGPRTuplesWeight() const { return Value[SGPR_TUPLE]; }
62 
getOccupancyGCNRegPressure63   unsigned getOccupancy(const GCNSubtarget &ST) const {
64     return std::min(ST.getOccupancyWithNumSGPRs(getSGPRNum()),
65              ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
66   }
67 
68   void inc(unsigned Reg,
69            LaneBitmask PrevMask,
70            LaneBitmask NewMask,
71            const MachineRegisterInfo &MRI);
72 
higherOccupancyGCNRegPressure73   bool higherOccupancy(const GCNSubtarget &ST, const GCNRegPressure& O) const {
74     return getOccupancy(ST) > O.getOccupancy(ST);
75   }
76 
77   /// Compares \p this GCNRegpressure to \p O, returning true if \p this is
78   /// less. Since GCNRegpressure contains different types of pressures, and due
79   /// to target-specific pecularities (e.g. we care about occupancy rather than
80   /// raw register usage), we determine if \p this GCNRegPressure is less than
81   /// \p O based on the following tiered comparisons (in order order of
82   /// precedence):
83   /// 1. Better occupancy
84   /// 2. Less spilling (first preference to VGPR spills, then to SGPR spills)
85   /// 3. Less tuple register pressure (first preference to VGPR tuples if we
86   /// determine that SGPR pressure is not important)
87   /// 4. Less raw register pressure (first preference to VGPR tuples if we
88   /// determine that SGPR pressure is not important)
89   bool less(const MachineFunction &MF, const GCNRegPressure &O,
90             unsigned MaxOccupancy = std::numeric_limits<unsigned>::max()) const;
91 
92   bool operator==(const GCNRegPressure &O) const {
93     return std::equal(&Value[0], &Value[TOTAL_KINDS], O.Value);
94   }
95 
96   bool operator!=(const GCNRegPressure &O) const {
97     return !(*this == O);
98   }
99 
100   GCNRegPressure &operator+=(const GCNRegPressure &RHS) {
101     for (unsigned I = 0; I < TOTAL_KINDS; ++I)
102       Value[I] += RHS.Value[I];
103     return *this;
104   }
105 
106   GCNRegPressure &operator-=(const GCNRegPressure &RHS) {
107     for (unsigned I = 0; I < TOTAL_KINDS; ++I)
108       Value[I] -= RHS.Value[I];
109     return *this;
110   }
111 
112   void dump() const;
113 
114 private:
115   unsigned Value[TOTAL_KINDS];
116 
117   static unsigned getRegKind(Register Reg, const MachineRegisterInfo &MRI);
118 
119   friend GCNRegPressure max(const GCNRegPressure &P1,
120                             const GCNRegPressure &P2);
121 
122   friend Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST);
123 };
124 
/// Returns the per-kind maximum of \p P1 and \p P2.
inline GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2) {
  constexpr unsigned NumKinds = GCNRegPressure::TOTAL_KINDS;
  GCNRegPressure Result;
  std::transform(P1.Value, P1.Value + NumKinds, P2.Value, Result.Value,
                 [](unsigned A, unsigned B) { return std::max(A, B); });
  return Result;
}
131 
132 inline GCNRegPressure operator+(const GCNRegPressure &P1,
133                                 const GCNRegPressure &P2) {
134   GCNRegPressure Sum = P1;
135   Sum += P2;
136   return Sum;
137 }
138 
139 inline GCNRegPressure operator-(const GCNRegPressure &P1,
140                                 const GCNRegPressure &P2) {
141   GCNRegPressure Diff = P1;
142   Diff -= P2;
143   return Diff;
144 }
145 
146 class GCNRPTracker {
147 public:
148   using LiveRegSet = DenseMap<unsigned, LaneBitmask>;
149 
150 protected:
151   const LiveIntervals &LIS;
152   LiveRegSet LiveRegs;
153   GCNRegPressure CurPressure, MaxPressure;
154   const MachineInstr *LastTrackedMI = nullptr;
155   mutable const MachineRegisterInfo *MRI = nullptr;
156 
GCNRPTracker(const LiveIntervals & LIS_)157   GCNRPTracker(const LiveIntervals &LIS_) : LIS(LIS_) {}
158 
159   void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy,
160              bool After);
161 
162 public:
163   // live regs for the current state
decltype(LiveRegs)164   const decltype(LiveRegs) &getLiveRegs() const { return LiveRegs; }
getLastTrackedMI()165   const MachineInstr *getLastTrackedMI() const { return LastTrackedMI; }
166 
clearMaxPressure()167   void clearMaxPressure() { MaxPressure.clear(); }
168 
getPressure()169   GCNRegPressure getPressure() const { return CurPressure; }
170 
moveLiveRegs()171   decltype(LiveRegs) moveLiveRegs() {
172     return std::move(LiveRegs);
173   }
174 };
175 
176 GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
177                                      const MachineRegisterInfo &MRI);
178 
179 class GCNUpwardRPTracker : public GCNRPTracker {
180 public:
GCNUpwardRPTracker(const LiveIntervals & LIS_)181   GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
182 
183   // reset tracker and set live register set to the specified value.
184   void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_);
185 
186   // reset tracker at the specified slot index.
reset(const MachineRegisterInfo & MRI,SlotIndex SI)187   void reset(const MachineRegisterInfo &MRI, SlotIndex SI) {
188     reset(MRI, llvm::getLiveRegs(SI, LIS, MRI));
189   }
190 
191   // reset tracker to the end of the MBB.
reset(const MachineBasicBlock & MBB)192   void reset(const MachineBasicBlock &MBB) {
193     reset(MBB.getParent()->getRegInfo(),
194           LIS.getSlotIndexes()->getMBBEndIdx(&MBB));
195   }
196 
197   // reset tracker to the point just after MI (in program order).
reset(const MachineInstr & MI)198   void reset(const MachineInstr &MI) {
199     reset(MI.getMF()->getRegInfo(), LIS.getInstructionIndex(MI).getDeadSlot());
200   }
201 
202   // move to the state just before the MI (in program order).
203   void recede(const MachineInstr &MI);
204 
205   // checks whether the tracker's state after receding MI corresponds
206   // to reported by LIS.
207   bool isValid() const;
208 
getMaxPressure()209   const GCNRegPressure &getMaxPressure() const { return MaxPressure; }
210 
resetMaxPressure()211   void resetMaxPressure() { MaxPressure = CurPressure; }
212 
getMaxPressureAndReset()213   GCNRegPressure getMaxPressureAndReset() {
214     GCNRegPressure RP = MaxPressure;
215     resetMaxPressure();
216     return RP;
217   }
218 };
219 
220 class GCNDownwardRPTracker : public GCNRPTracker {
221   // Last position of reset or advanceBeforeNext
222   MachineBasicBlock::const_iterator NextMI;
223 
224   MachineBasicBlock::const_iterator MBBEnd;
225 
226 public:
GCNDownwardRPTracker(const LiveIntervals & LIS_)227   GCNDownwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
228 
getNext()229   MachineBasicBlock::const_iterator getNext() const { return NextMI; }
230 
231   // Return MaxPressure and clear it.
moveMaxPressure()232   GCNRegPressure moveMaxPressure() {
233     auto Res = MaxPressure;
234     MaxPressure.clear();
235     return Res;
236   }
237 
238   // Reset tracker to the point before the MI
239   // filling live regs upon this point using LIS.
240   // Returns false if block is empty except debug values.
241   bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr);
242 
243   // Move to the state right before the next MI or after the end of MBB.
244   // Returns false if reached end of the block.
245   bool advanceBeforeNext();
246 
247   // Move to the state at the MI, advanceBeforeNext has to be called first.
248   void advanceToNext();
249 
250   // Move to the state at the next MI. Returns false if reached end of block.
251   bool advance();
252 
253   // Advance instructions until before End.
254   bool advance(MachineBasicBlock::const_iterator End);
255 
256   // Reset to Begin and advance to End.
257   bool advance(MachineBasicBlock::const_iterator Begin,
258                MachineBasicBlock::const_iterator End,
259                const LiveRegSet *LiveRegsCopy = nullptr);
260 };
261 
262 LaneBitmask getLiveLaneMask(unsigned Reg,
263                             SlotIndex SI,
264                             const LiveIntervals &LIS,
265                             const MachineRegisterInfo &MRI);
266 
267 LaneBitmask getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
268                             const MachineRegisterInfo &MRI);
269 
270 GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
271                                      const MachineRegisterInfo &MRI);
272 
273 /// creates a map MachineInstr -> LiveRegSet
274 /// R - range of iterators on instructions
275 /// After - upon entry or exit of every instruction
276 /// Note: there is no entry in the map for instructions with empty live reg set
277 /// Complexity = O(NumVirtRegs * averageLiveRangeSegmentsPerReg * lg(R))
278 template <typename Range>
279 DenseMap<MachineInstr*, GCNRPTracker::LiveRegSet>
getLiveRegMap(Range && R,bool After,LiveIntervals & LIS)280 getLiveRegMap(Range &&R, bool After, LiveIntervals &LIS) {
281   std::vector<SlotIndex> Indexes;
282   Indexes.reserve(std::distance(R.begin(), R.end()));
283   auto &SII = *LIS.getSlotIndexes();
284   for (MachineInstr *I : R) {
285     auto SI = SII.getInstructionIndex(*I);
286     Indexes.push_back(After ? SI.getDeadSlot() : SI.getBaseIndex());
287   }
288   llvm::sort(Indexes);
289 
290   auto &MRI = (*R.begin())->getParent()->getParent()->getRegInfo();
291   DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> LiveRegMap;
292   SmallVector<SlotIndex, 32> LiveIdxs, SRLiveIdxs;
293   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
294     auto Reg = Register::index2VirtReg(I);
295     if (!LIS.hasInterval(Reg))
296       continue;
297     auto &LI = LIS.getInterval(Reg);
298     LiveIdxs.clear();
299     if (!LI.findIndexesLiveAt(Indexes, std::back_inserter(LiveIdxs)))
300       continue;
301     if (!LI.hasSubRanges()) {
302       for (auto SI : LiveIdxs)
303         LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] =
304           MRI.getMaxLaneMaskForVReg(Reg);
305     } else
306       for (const auto &S : LI.subranges()) {
307         // constrain search for subranges by indexes live at main range
308         SRLiveIdxs.clear();
309         S.findIndexesLiveAt(LiveIdxs, std::back_inserter(SRLiveIdxs));
310         for (auto SI : SRLiveIdxs)
311           LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] |= S.LaneMask;
312       }
313   }
314   return LiveRegMap;
315 }
316 
getLiveRegsAfter(const MachineInstr & MI,const LiveIntervals & LIS)317 inline GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI,
318                                                  const LiveIntervals &LIS) {
319   return getLiveRegs(LIS.getInstructionIndex(MI).getDeadSlot(), LIS,
320                      MI.getParent()->getParent()->getRegInfo());
321 }
322 
getLiveRegsBefore(const MachineInstr & MI,const LiveIntervals & LIS)323 inline GCNRPTracker::LiveRegSet getLiveRegsBefore(const MachineInstr &MI,
324                                                   const LiveIntervals &LIS) {
325   return getLiveRegs(LIS.getInstructionIndex(MI).getBaseIndex(), LIS,
326                      MI.getParent()->getParent()->getRegInfo());
327 }
328 
329 template <typename Range>
getRegPressure(const MachineRegisterInfo & MRI,Range && LiveRegs)330 GCNRegPressure getRegPressure(const MachineRegisterInfo &MRI,
331                               Range &&LiveRegs) {
332   GCNRegPressure Res;
333   for (const auto &RM : LiveRegs)
334     Res.inc(RM.first, LaneBitmask::getNone(), RM.second, MRI);
335   return Res;
336 }
337 
338 bool isEqual(const GCNRPTracker::LiveRegSet &S1,
339              const GCNRPTracker::LiveRegSet &S2);
340 
341 Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST = nullptr);
342 
343 Printable print(const GCNRPTracker::LiveRegSet &LiveRegs,
344                 const MachineRegisterInfo &MRI);
345 
346 Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
347                          const GCNRPTracker::LiveRegSet &TrackedL,
348                          const TargetRegisterInfo *TRI, StringRef Pfx = "  ");
349 
350 struct GCNRegPressurePrinter : public MachineFunctionPass {
351   static char ID;
352 
353 public:
GCNRegPressurePrinterGCNRegPressurePrinter354   GCNRegPressurePrinter() : MachineFunctionPass(ID) {}
355 
356   bool runOnMachineFunction(MachineFunction &MF) override;
357 
getAnalysisUsageGCNRegPressurePrinter358   void getAnalysisUsage(AnalysisUsage &AU) const override {
359     AU.addRequired<LiveIntervalsWrapperPass>();
360     AU.setPreservesAll();
361     MachineFunctionPass::getAnalysisUsage(AU);
362   }
363 };
364 
365 } // end namespace llvm
366 
367 #endif // LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
368