xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/GCNRegPressure.h (revision f5f40dd63bc7acbb5312b26ac1ea1103c12352a6)
1 //===- GCNRegPressure.h -----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
/// This file defines the GCNRegPressure class, which tracks register pressure
/// by keeping count of the SGPRs/VGPRs in use, along with weights for large
/// SGPR/VGPR tuples. It also provides a compare function, which compares
/// different register pressures and declares the one with the higher
/// occupancy the winner.
14 ///
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
18 #define LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
19 
20 #include "GCNSubtarget.h"
21 #include "llvm/CodeGen/LiveIntervals.h"
22 #include <algorithm>
23 
24 namespace llvm {
25 
26 class MachineRegisterInfo;
27 class raw_ostream;
28 class SlotIndex;
29 
30 struct GCNRegPressure {
31   enum RegKind {
32     SGPR32,
33     SGPR_TUPLE,
34     VGPR32,
35     VGPR_TUPLE,
36     AGPR32,
37     AGPR_TUPLE,
38     TOTAL_KINDS
39   };
40 
41   GCNRegPressure() {
42     clear();
43   }
44 
45   bool empty() const { return getSGPRNum() == 0 && getVGPRNum(false) == 0; }
46 
47   void clear() { std::fill(&Value[0], &Value[TOTAL_KINDS], 0); }
48 
49   unsigned getSGPRNum() const { return Value[SGPR32]; }
50   unsigned getVGPRNum(bool UnifiedVGPRFile) const {
51     if (UnifiedVGPRFile) {
52       return Value[AGPR32] ? alignTo(Value[VGPR32], 4) + Value[AGPR32]
53                            : Value[VGPR32] + Value[AGPR32];
54     }
55     return std::max(Value[VGPR32], Value[AGPR32]);
56   }
57   unsigned getAGPRNum() const { return Value[AGPR32]; }
58 
59   unsigned getVGPRTuplesWeight() const { return std::max(Value[VGPR_TUPLE],
60                                                          Value[AGPR_TUPLE]); }
61   unsigned getSGPRTuplesWeight() const { return Value[SGPR_TUPLE]; }
62 
63   unsigned getOccupancy(const GCNSubtarget &ST) const {
64     return std::min(ST.getOccupancyWithNumSGPRs(getSGPRNum()),
65              ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
66   }
67 
68   void inc(unsigned Reg,
69            LaneBitmask PrevMask,
70            LaneBitmask NewMask,
71            const MachineRegisterInfo &MRI);
72 
73   bool higherOccupancy(const GCNSubtarget &ST, const GCNRegPressure& O) const {
74     return getOccupancy(ST) > O.getOccupancy(ST);
75   }
76 
77   bool less(const GCNSubtarget &ST, const GCNRegPressure& O,
78     unsigned MaxOccupancy = std::numeric_limits<unsigned>::max()) const;
79 
80   bool operator==(const GCNRegPressure &O) const {
81     return std::equal(&Value[0], &Value[TOTAL_KINDS], O.Value);
82   }
83 
84   bool operator!=(const GCNRegPressure &O) const {
85     return !(*this == O);
86   }
87 
88   GCNRegPressure &operator+=(const GCNRegPressure &RHS) {
89     for (unsigned I = 0; I < TOTAL_KINDS; ++I)
90       Value[I] += RHS.Value[I];
91     return *this;
92   }
93 
94   GCNRegPressure &operator-=(const GCNRegPressure &RHS) {
95     for (unsigned I = 0; I < TOTAL_KINDS; ++I)
96       Value[I] -= RHS.Value[I];
97     return *this;
98   }
99 
100   void dump() const;
101 
102 private:
103   unsigned Value[TOTAL_KINDS];
104 
105   static unsigned getRegKind(Register Reg, const MachineRegisterInfo &MRI);
106 
107   friend GCNRegPressure max(const GCNRegPressure &P1,
108                             const GCNRegPressure &P2);
109 
110   friend Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST);
111 };
112 
113 inline GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2) {
114   GCNRegPressure Res;
115   for (unsigned I = 0; I < GCNRegPressure::TOTAL_KINDS; ++I)
116     Res.Value[I] = std::max(P1.Value[I], P2.Value[I]);
117   return Res;
118 }
119 
120 inline GCNRegPressure operator+(const GCNRegPressure &P1,
121                                 const GCNRegPressure &P2) {
122   GCNRegPressure Sum = P1;
123   Sum += P2;
124   return Sum;
125 }
126 
127 inline GCNRegPressure operator-(const GCNRegPressure &P1,
128                                 const GCNRegPressure &P2) {
129   GCNRegPressure Diff = P1;
130   Diff -= P2;
131   return Diff;
132 }
133 
134 class GCNRPTracker {
135 public:
136   using LiveRegSet = DenseMap<unsigned, LaneBitmask>;
137 
138 protected:
139   const LiveIntervals &LIS;
140   LiveRegSet LiveRegs;
141   GCNRegPressure CurPressure, MaxPressure;
142   const MachineInstr *LastTrackedMI = nullptr;
143   mutable const MachineRegisterInfo *MRI = nullptr;
144 
145   GCNRPTracker(const LiveIntervals &LIS_) : LIS(LIS_) {}
146 
147   void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy,
148              bool After);
149 
150 public:
151   // live regs for the current state
152   const decltype(LiveRegs) &getLiveRegs() const { return LiveRegs; }
153   const MachineInstr *getLastTrackedMI() const { return LastTrackedMI; }
154 
155   void clearMaxPressure() { MaxPressure.clear(); }
156 
157   GCNRegPressure getPressure() const { return CurPressure; }
158 
159   decltype(LiveRegs) moveLiveRegs() {
160     return std::move(LiveRegs);
161   }
162 };
163 
164 GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
165                                      const MachineRegisterInfo &MRI);
166 
167 class GCNUpwardRPTracker : public GCNRPTracker {
168 public:
169   GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
170 
171   // reset tracker and set live register set to the specified value.
172   void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_);
173 
174   // reset tracker at the specified slot index.
175   void reset(const MachineRegisterInfo &MRI, SlotIndex SI) {
176     reset(MRI, llvm::getLiveRegs(SI, LIS, MRI));
177   }
178 
179   // reset tracker to the end of the MBB.
180   void reset(const MachineBasicBlock &MBB) {
181     reset(MBB.getParent()->getRegInfo(),
182           LIS.getSlotIndexes()->getMBBEndIdx(&MBB));
183   }
184 
185   // reset tracker to the point just after MI (in program order).
186   void reset(const MachineInstr &MI) {
187     reset(MI.getMF()->getRegInfo(), LIS.getInstructionIndex(MI).getDeadSlot());
188   }
189 
190   // move to the state just before the MI (in program order).
191   void recede(const MachineInstr &MI);
192 
193   // checks whether the tracker's state after receding MI corresponds
194   // to reported by LIS.
195   bool isValid() const;
196 
197   const GCNRegPressure &getMaxPressure() const { return MaxPressure; }
198 
199   void resetMaxPressure() { MaxPressure = CurPressure; }
200 
201   GCNRegPressure getMaxPressureAndReset() {
202     GCNRegPressure RP = MaxPressure;
203     resetMaxPressure();
204     return RP;
205   }
206 };
207 
208 class GCNDownwardRPTracker : public GCNRPTracker {
209   // Last position of reset or advanceBeforeNext
210   MachineBasicBlock::const_iterator NextMI;
211 
212   MachineBasicBlock::const_iterator MBBEnd;
213 
214 public:
215   GCNDownwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
216 
217   MachineBasicBlock::const_iterator getNext() const { return NextMI; }
218 
219   // Return MaxPressure and clear it.
220   GCNRegPressure moveMaxPressure() {
221     auto Res = MaxPressure;
222     MaxPressure.clear();
223     return Res;
224   }
225 
226   // Reset tracker to the point before the MI
227   // filling live regs upon this point using LIS.
228   // Returns false if block is empty except debug values.
229   bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr);
230 
231   // Move to the state right before the next MI or after the end of MBB.
232   // Returns false if reached end of the block.
233   bool advanceBeforeNext();
234 
235   // Move to the state at the MI, advanceBeforeNext has to be called first.
236   void advanceToNext();
237 
238   // Move to the state at the next MI. Returns false if reached end of block.
239   bool advance();
240 
241   // Advance instructions until before End.
242   bool advance(MachineBasicBlock::const_iterator End);
243 
244   // Reset to Begin and advance to End.
245   bool advance(MachineBasicBlock::const_iterator Begin,
246                MachineBasicBlock::const_iterator End,
247                const LiveRegSet *LiveRegsCopy = nullptr);
248 };
249 
250 LaneBitmask getLiveLaneMask(unsigned Reg,
251                             SlotIndex SI,
252                             const LiveIntervals &LIS,
253                             const MachineRegisterInfo &MRI);
254 
255 LaneBitmask getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
256                             const MachineRegisterInfo &MRI);
257 
258 GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
259                                      const MachineRegisterInfo &MRI);
260 
261 /// creates a map MachineInstr -> LiveRegSet
262 /// R - range of iterators on instructions
263 /// After - upon entry or exit of every instruction
264 /// Note: there is no entry in the map for instructions with empty live reg set
265 /// Complexity = O(NumVirtRegs * averageLiveRangeSegmentsPerReg * lg(R))
266 template <typename Range>
267 DenseMap<MachineInstr*, GCNRPTracker::LiveRegSet>
268 getLiveRegMap(Range &&R, bool After, LiveIntervals &LIS) {
269   std::vector<SlotIndex> Indexes;
270   Indexes.reserve(std::distance(R.begin(), R.end()));
271   auto &SII = *LIS.getSlotIndexes();
272   for (MachineInstr *I : R) {
273     auto SI = SII.getInstructionIndex(*I);
274     Indexes.push_back(After ? SI.getDeadSlot() : SI.getBaseIndex());
275   }
276   llvm::sort(Indexes);
277 
278   auto &MRI = (*R.begin())->getParent()->getParent()->getRegInfo();
279   DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> LiveRegMap;
280   SmallVector<SlotIndex, 32> LiveIdxs, SRLiveIdxs;
281   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
282     auto Reg = Register::index2VirtReg(I);
283     if (!LIS.hasInterval(Reg))
284       continue;
285     auto &LI = LIS.getInterval(Reg);
286     LiveIdxs.clear();
287     if (!LI.findIndexesLiveAt(Indexes, std::back_inserter(LiveIdxs)))
288       continue;
289     if (!LI.hasSubRanges()) {
290       for (auto SI : LiveIdxs)
291         LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] =
292           MRI.getMaxLaneMaskForVReg(Reg);
293     } else
294       for (const auto &S : LI.subranges()) {
295         // constrain search for subranges by indexes live at main range
296         SRLiveIdxs.clear();
297         S.findIndexesLiveAt(LiveIdxs, std::back_inserter(SRLiveIdxs));
298         for (auto SI : SRLiveIdxs)
299           LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] |= S.LaneMask;
300       }
301   }
302   return LiveRegMap;
303 }
304 
305 inline GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI,
306                                                  const LiveIntervals &LIS) {
307   return getLiveRegs(LIS.getInstructionIndex(MI).getDeadSlot(), LIS,
308                      MI.getParent()->getParent()->getRegInfo());
309 }
310 
311 inline GCNRPTracker::LiveRegSet getLiveRegsBefore(const MachineInstr &MI,
312                                                   const LiveIntervals &LIS) {
313   return getLiveRegs(LIS.getInstructionIndex(MI).getBaseIndex(), LIS,
314                      MI.getParent()->getParent()->getRegInfo());
315 }
316 
317 template <typename Range>
318 GCNRegPressure getRegPressure(const MachineRegisterInfo &MRI,
319                               Range &&LiveRegs) {
320   GCNRegPressure Res;
321   for (const auto &RM : LiveRegs)
322     Res.inc(RM.first, LaneBitmask::getNone(), RM.second, MRI);
323   return Res;
324 }
325 
326 bool isEqual(const GCNRPTracker::LiveRegSet &S1,
327              const GCNRPTracker::LiveRegSet &S2);
328 
329 Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST = nullptr);
330 
331 Printable print(const GCNRPTracker::LiveRegSet &LiveRegs,
332                 const MachineRegisterInfo &MRI);
333 
334 Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
335                          const GCNRPTracker::LiveRegSet &TrackedL,
336                          const TargetRegisterInfo *TRI, StringRef Pfx = "  ");
337 
338 struct GCNRegPressurePrinter : public MachineFunctionPass {
339   static char ID;
340 
341 public:
342   GCNRegPressurePrinter() : MachineFunctionPass(ID) {}
343 
344   bool runOnMachineFunction(MachineFunction &MF) override;
345 
346   void getAnalysisUsage(AnalysisUsage &AU) const override {
347     AU.addRequired<LiveIntervals>();
348     AU.setPreservesAll();
349     MachineFunctionPass::getAnalysisUsage(AU);
350   }
351 };
352 
353 } // end namespace llvm
354 
355 #endif // LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
356