xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/GCNRegPressure.h (revision a4e5e0106ac7145f56eb39a691e302cabb4635be)
1 //===- GCNRegPressure.h -----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
/// This file defines the GCNRegPressure class, which tracks register pressure
/// by keeping count of the SGPRs/VGPRs in use and of the weights of large
/// SGPR/VGPR tuples. It also implements a compare function, which compares
/// different register pressures and declares the one with the higher occupancy
/// the winner.
14 ///
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
18 #define LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
19 
#include "GCNSubtarget.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include <algorithm>
#include <iterator>
#include <limits>
#include <vector>
23 
24 namespace llvm {
25 
26 class MachineRegisterInfo;
27 class raw_ostream;
28 class SlotIndex;
29 
30 struct GCNRegPressure {
31   enum RegKind {
32     SGPR32,
33     SGPR_TUPLE,
34     VGPR32,
35     VGPR_TUPLE,
36     AGPR32,
37     AGPR_TUPLE,
38     TOTAL_KINDS
39   };
40 
41   GCNRegPressure() {
42     clear();
43   }
44 
45   bool empty() const { return getSGPRNum() == 0 && getVGPRNum(false) == 0; }
46 
47   void clear() { std::fill(&Value[0], &Value[TOTAL_KINDS], 0); }
48 
49   unsigned getSGPRNum() const { return Value[SGPR32]; }
50   unsigned getVGPRNum(bool UnifiedVGPRFile) const {
51     if (UnifiedVGPRFile) {
52       return Value[AGPR32] ? alignTo(Value[VGPR32], 4) + Value[AGPR32]
53                            : Value[VGPR32] + Value[AGPR32];
54     }
55     return std::max(Value[VGPR32], Value[AGPR32]);
56   }
57   unsigned getAGPRNum() const { return Value[AGPR32]; }
58 
59   unsigned getVGPRTuplesWeight() const { return std::max(Value[VGPR_TUPLE],
60                                                          Value[AGPR_TUPLE]); }
61   unsigned getSGPRTuplesWeight() const { return Value[SGPR_TUPLE]; }
62 
63   unsigned getOccupancy(const GCNSubtarget &ST) const {
64     return std::min(ST.getOccupancyWithNumSGPRs(getSGPRNum()),
65              ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
66   }
67 
68   void inc(unsigned Reg,
69            LaneBitmask PrevMask,
70            LaneBitmask NewMask,
71            const MachineRegisterInfo &MRI);
72 
73   bool higherOccupancy(const GCNSubtarget &ST, const GCNRegPressure& O) const {
74     return getOccupancy(ST) > O.getOccupancy(ST);
75   }
76 
77   bool less(const GCNSubtarget &ST, const GCNRegPressure& O,
78     unsigned MaxOccupancy = std::numeric_limits<unsigned>::max()) const;
79 
80   bool operator==(const GCNRegPressure &O) const {
81     return std::equal(&Value[0], &Value[TOTAL_KINDS], O.Value);
82   }
83 
84   bool operator!=(const GCNRegPressure &O) const {
85     return !(*this == O);
86   }
87 
88   void dump() const;
89 
90 private:
91   unsigned Value[TOTAL_KINDS];
92 
93   static unsigned getRegKind(Register Reg, const MachineRegisterInfo &MRI);
94 
95   friend GCNRegPressure max(const GCNRegPressure &P1,
96                             const GCNRegPressure &P2);
97 
98   friend Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST);
99 };
100 
101 inline GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2) {
102   GCNRegPressure Res;
103   for (unsigned I = 0; I < GCNRegPressure::TOTAL_KINDS; ++I)
104     Res.Value[I] = std::max(P1.Value[I], P2.Value[I]);
105   return Res;
106 }
107 
108 class GCNRPTracker {
109 public:
110   using LiveRegSet = DenseMap<unsigned, LaneBitmask>;
111 
112 protected:
113   const LiveIntervals &LIS;
114   LiveRegSet LiveRegs;
115   GCNRegPressure CurPressure, MaxPressure;
116   const MachineInstr *LastTrackedMI = nullptr;
117   mutable const MachineRegisterInfo *MRI = nullptr;
118 
119   GCNRPTracker(const LiveIntervals &LIS_) : LIS(LIS_) {}
120 
121   void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy,
122              bool After);
123 
124 public:
125   // live regs for the current state
126   const decltype(LiveRegs) &getLiveRegs() const { return LiveRegs; }
127   const MachineInstr *getLastTrackedMI() const { return LastTrackedMI; }
128 
129   void clearMaxPressure() { MaxPressure.clear(); }
130 
131   // returns MaxPressure, resetting it
132   decltype(MaxPressure) moveMaxPressure() {
133     auto Res = MaxPressure;
134     MaxPressure.clear();
135     return Res;
136   }
137 
138   decltype(LiveRegs) moveLiveRegs() {
139     return std::move(LiveRegs);
140   }
141 };
142 
143 class GCNUpwardRPTracker : public GCNRPTracker {
144 public:
145   GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
146 
147   // reset tracker to the point just below MI
148   // filling live regs upon this point using LIS
149   void reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr);
150 
151   // move to the state just above the MI
152   void recede(const MachineInstr &MI);
153 
154   // checks whether the tracker's state after receding MI corresponds
155   // to reported by LIS
156   bool isValid() const;
157 };
158 
159 class GCNDownwardRPTracker : public GCNRPTracker {
160   // Last position of reset or advanceBeforeNext
161   MachineBasicBlock::const_iterator NextMI;
162 
163   MachineBasicBlock::const_iterator MBBEnd;
164 
165 public:
166   GCNDownwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
167 
168   MachineBasicBlock::const_iterator getNext() const { return NextMI; }
169 
170   // Reset tracker to the point before the MI
171   // filling live regs upon this point using LIS.
172   // Returns false if block is empty except debug values.
173   bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr);
174 
175   // Move to the state right before the next MI or after the end of MBB.
176   // Returns false if reached end of the block.
177   bool advanceBeforeNext();
178 
179   // Move to the state at the MI, advanceBeforeNext has to be called first.
180   void advanceToNext();
181 
182   // Move to the state at the next MI. Returns false if reached end of block.
183   bool advance();
184 
185   // Advance instructions until before End.
186   bool advance(MachineBasicBlock::const_iterator End);
187 
188   // Reset to Begin and advance to End.
189   bool advance(MachineBasicBlock::const_iterator Begin,
190                MachineBasicBlock::const_iterator End,
191                const LiveRegSet *LiveRegsCopy = nullptr);
192 };
193 
194 LaneBitmask getLiveLaneMask(unsigned Reg,
195                             SlotIndex SI,
196                             const LiveIntervals &LIS,
197                             const MachineRegisterInfo &MRI);
198 
199 GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI,
200                                      const LiveIntervals &LIS,
201                                      const MachineRegisterInfo &MRI);
202 
203 /// creates a map MachineInstr -> LiveRegSet
204 /// R - range of iterators on instructions
205 /// After - upon entry or exit of every instruction
206 /// Note: there is no entry in the map for instructions with empty live reg set
207 /// Complexity = O(NumVirtRegs * averageLiveRangeSegmentsPerReg * lg(R))
208 template <typename Range>
209 DenseMap<MachineInstr*, GCNRPTracker::LiveRegSet>
210 getLiveRegMap(Range &&R, bool After, LiveIntervals &LIS) {
211   std::vector<SlotIndex> Indexes;
212   Indexes.reserve(std::distance(R.begin(), R.end()));
213   auto &SII = *LIS.getSlotIndexes();
214   for (MachineInstr *I : R) {
215     auto SI = SII.getInstructionIndex(*I);
216     Indexes.push_back(After ? SI.getDeadSlot() : SI.getBaseIndex());
217   }
218   llvm::sort(Indexes);
219 
220   auto &MRI = (*R.begin())->getParent()->getParent()->getRegInfo();
221   DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> LiveRegMap;
222   SmallVector<SlotIndex, 32> LiveIdxs, SRLiveIdxs;
223   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
224     auto Reg = Register::index2VirtReg(I);
225     if (!LIS.hasInterval(Reg))
226       continue;
227     auto &LI = LIS.getInterval(Reg);
228     LiveIdxs.clear();
229     if (!LI.findIndexesLiveAt(Indexes, std::back_inserter(LiveIdxs)))
230       continue;
231     if (!LI.hasSubRanges()) {
232       for (auto SI : LiveIdxs)
233         LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] =
234           MRI.getMaxLaneMaskForVReg(Reg);
235     } else
236       for (const auto &S : LI.subranges()) {
237         // constrain search for subranges by indexes live at main range
238         SRLiveIdxs.clear();
239         S.findIndexesLiveAt(LiveIdxs, std::back_inserter(SRLiveIdxs));
240         for (auto SI : SRLiveIdxs)
241           LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] |= S.LaneMask;
242       }
243   }
244   return LiveRegMap;
245 }
246 
247 inline GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI,
248                                                  const LiveIntervals &LIS) {
249   return getLiveRegs(LIS.getInstructionIndex(MI).getDeadSlot(), LIS,
250                      MI.getParent()->getParent()->getRegInfo());
251 }
252 
253 inline GCNRPTracker::LiveRegSet getLiveRegsBefore(const MachineInstr &MI,
254                                                   const LiveIntervals &LIS) {
255   return getLiveRegs(LIS.getInstructionIndex(MI).getBaseIndex(), LIS,
256                      MI.getParent()->getParent()->getRegInfo());
257 }
258 
259 template <typename Range>
260 GCNRegPressure getRegPressure(const MachineRegisterInfo &MRI,
261                               Range &&LiveRegs) {
262   GCNRegPressure Res;
263   for (const auto &RM : LiveRegs)
264     Res.inc(RM.first, LaneBitmask::getNone(), RM.second, MRI);
265   return Res;
266 }
267 
268 bool isEqual(const GCNRPTracker::LiveRegSet &S1,
269              const GCNRPTracker::LiveRegSet &S2);
270 
271 Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST = nullptr);
272 
273 Printable print(const GCNRPTracker::LiveRegSet &LiveRegs,
274                 const MachineRegisterInfo &MRI);
275 
276 Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
277                          const GCNRPTracker::LiveRegSet &TrackedL,
278                          const TargetRegisterInfo *TRI);
279 
280 } // end namespace llvm
281 
282 #endif // LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
283