xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/RegisterPressure.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- RegisterPressure.cpp - Dynamic Register Pressure -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the RegisterPressure class which can be used to track
10 // MachineInstr level register pressure.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/RegisterPressure.h"
15 #include "llvm/ADT/ArrayRef.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/CodeGen/LiveInterval.h"
19 #include "llvm/CodeGen/LiveIntervals.h"
20 #include "llvm/CodeGen/MachineBasicBlock.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstr.h"
23 #include "llvm/CodeGen/MachineInstrBundle.h"
24 #include "llvm/CodeGen/MachineOperand.h"
25 #include "llvm/CodeGen/MachineRegisterInfo.h"
26 #include "llvm/CodeGen/RegisterClassInfo.h"
27 #include "llvm/CodeGen/SlotIndexes.h"
28 #include "llvm/CodeGen/TargetRegisterInfo.h"
29 #include "llvm/CodeGen/TargetSubtargetInfo.h"
30 #include "llvm/Config/llvm-config.h"
31 #include "llvm/MC/LaneBitmask.h"
32 #include "llvm/MC/MCRegisterInfo.h"
33 #include "llvm/Support/Compiler.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <algorithm>
38 #include <cassert>
39 #include <cstdint>
40 #include <cstdlib>
41 #include <cstring>
42 #include <iterator>
43 #include <limits>
44 #include <utility>
45 #include <vector>
46 
47 using namespace llvm;
48 
49 /// Increase pressure for each pressure set provided by TargetRegisterInfo.
increaseSetPressure(std::vector<unsigned> & CurrSetPressure,const MachineRegisterInfo & MRI,unsigned Reg,LaneBitmask PrevMask,LaneBitmask NewMask)50 static void increaseSetPressure(std::vector<unsigned> &CurrSetPressure,
51                                 const MachineRegisterInfo &MRI, unsigned Reg,
52                                 LaneBitmask PrevMask, LaneBitmask NewMask) {
53   assert((PrevMask & ~NewMask).none() && "Must not remove bits");
54   if (PrevMask.any() || NewMask.none())
55     return;
56 
57   PSetIterator PSetI = MRI.getPressureSets(Reg);
58   unsigned Weight = PSetI.getWeight();
59   for (; PSetI.isValid(); ++PSetI)
60     CurrSetPressure[*PSetI] += Weight;
61 }
62 
63 /// Decrease pressure for each pressure set provided by TargetRegisterInfo.
decreaseSetPressure(std::vector<unsigned> & CurrSetPressure,const MachineRegisterInfo & MRI,Register Reg,LaneBitmask PrevMask,LaneBitmask NewMask)64 static void decreaseSetPressure(std::vector<unsigned> &CurrSetPressure,
65                                 const MachineRegisterInfo &MRI, Register Reg,
66                                 LaneBitmask PrevMask, LaneBitmask NewMask) {
67   assert((NewMask & ~PrevMask).none() && "Must not add bits");
68   if (NewMask.any() || PrevMask.none())
69     return;
70 
71   PSetIterator PSetI = MRI.getPressureSets(Reg);
72   unsigned Weight = PSetI.getWeight();
73   for (; PSetI.isValid(); ++PSetI) {
74     assert(CurrSetPressure[*PSetI] >= Weight && "register pressure underflow");
75     CurrSetPressure[*PSetI] -= Weight;
76   }
77 }
78 
79 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
80 LLVM_DUMP_METHOD
dumpRegSetPressure(ArrayRef<unsigned> SetPressure,const TargetRegisterInfo * TRI)81 void llvm::dumpRegSetPressure(ArrayRef<unsigned> SetPressure,
82                               const TargetRegisterInfo *TRI) {
83   bool Empty = true;
84   for (unsigned i = 0, e = SetPressure.size(); i < e; ++i) {
85     if (SetPressure[i] != 0) {
86       dbgs() << TRI->getRegPressureSetName(i) << "=" << SetPressure[i] << '\n';
87       Empty = false;
88     }
89   }
90   if (Empty)
91     dbgs() << "\n";
92 }
93 
94 LLVM_DUMP_METHOD
dump(const TargetRegisterInfo * TRI) const95 void RegisterPressure::dump(const TargetRegisterInfo *TRI) const {
96   dbgs() << "Max Pressure: ";
97   dumpRegSetPressure(MaxSetPressure, TRI);
98   dbgs() << "Live In: ";
99   for (const RegisterMaskPair &P : LiveInRegs) {
100     dbgs() << printVRegOrUnit(P.RegUnit, TRI);
101     if (!P.LaneMask.all())
102       dbgs() << ':' << PrintLaneMask(P.LaneMask);
103     dbgs() << ' ';
104   }
105   dbgs() << '\n';
106   dbgs() << "Live Out: ";
107   for (const RegisterMaskPair &P : LiveOutRegs) {
108     dbgs() << printVRegOrUnit(P.RegUnit, TRI);
109     if (!P.LaneMask.all())
110       dbgs() << ':' << PrintLaneMask(P.LaneMask);
111     dbgs() << ' ';
112   }
113   dbgs() << '\n';
114 }
115 
116 LLVM_DUMP_METHOD
dump() const117 void RegPressureTracker::dump() const {
118   if (!isTopClosed() || !isBottomClosed()) {
119     dbgs() << "Curr Pressure: ";
120     dumpRegSetPressure(CurrSetPressure, TRI);
121   }
122   P.dump(TRI);
123 }
124 
125 LLVM_DUMP_METHOD
dump(const TargetRegisterInfo & TRI) const126 void PressureDiff::dump(const TargetRegisterInfo &TRI) const {
127   const char *sep = "";
128   for (const PressureChange &Change : *this) {
129     if (!Change.isValid())
130       break;
131     dbgs() << sep << TRI.getRegPressureSetName(Change.getPSet())
132            << " " << Change.getUnitInc();
133     sep = "    ";
134   }
135   dbgs() << '\n';
136 }
137 
138 LLVM_DUMP_METHOD
dump() const139 void PressureChange::dump() const {
140   dbgs() << "[" << getPSetOrMax() << ", " << getUnitInc() << "]\n";
141 }
142 
dump() const143 void RegPressureDelta::dump() const {
144   dbgs() << "[Excess=";
145   Excess.dump();
146   dbgs() << ", CriticalMax=";
147   CriticalMax.dump();
148   dbgs() << ", CurrentMax=";
149   CurrentMax.dump();
150   dbgs() << "]\n";
151 }
152 
153 #endif
154 
increaseRegPressure(Register RegUnit,LaneBitmask PreviousMask,LaneBitmask NewMask)155 void RegPressureTracker::increaseRegPressure(Register RegUnit,
156                                              LaneBitmask PreviousMask,
157                                              LaneBitmask NewMask) {
158   if (PreviousMask.any() || NewMask.none())
159     return;
160 
161   PSetIterator PSetI = MRI->getPressureSets(RegUnit);
162   unsigned Weight = PSetI.getWeight();
163   for (; PSetI.isValid(); ++PSetI) {
164     CurrSetPressure[*PSetI] += Weight;
165     P.MaxSetPressure[*PSetI] =
166         std::max(P.MaxSetPressure[*PSetI], CurrSetPressure[*PSetI]);
167   }
168 }
169 
decreaseRegPressure(Register RegUnit,LaneBitmask PreviousMask,LaneBitmask NewMask)170 void RegPressureTracker::decreaseRegPressure(Register RegUnit,
171                                              LaneBitmask PreviousMask,
172                                              LaneBitmask NewMask) {
173   decreaseSetPressure(CurrSetPressure, *MRI, RegUnit, PreviousMask, NewMask);
174 }
175 
176 /// Clear the result so it can be used for another round of pressure tracking.
reset()177 void IntervalPressure::reset() {
178   TopIdx = BottomIdx = SlotIndex();
179   MaxSetPressure.clear();
180   LiveInRegs.clear();
181   LiveOutRegs.clear();
182 }
183 
184 /// Clear the result so it can be used for another round of pressure tracking.
reset()185 void RegionPressure::reset() {
186   TopPos = BottomPos = MachineBasicBlock::const_iterator();
187   MaxSetPressure.clear();
188   LiveInRegs.clear();
189   LiveOutRegs.clear();
190 }
191 
192 /// If the current top is not less than or equal to the next index, open it.
193 /// We happen to need the SlotIndex for the next top for pressure update.
openTop(SlotIndex NextTop)194 void IntervalPressure::openTop(SlotIndex NextTop) {
195   if (TopIdx <= NextTop)
196     return;
197   TopIdx = SlotIndex();
198   LiveInRegs.clear();
199 }
200 
201 /// If the current top is the previous instruction (before receding), open it.
openTop(MachineBasicBlock::const_iterator PrevTop)202 void RegionPressure::openTop(MachineBasicBlock::const_iterator PrevTop) {
203   if (TopPos != PrevTop)
204     return;
205   TopPos = MachineBasicBlock::const_iterator();
206   LiveInRegs.clear();
207 }
208 
209 /// If the current bottom is not greater than the previous index, open it.
openBottom(SlotIndex PrevBottom)210 void IntervalPressure::openBottom(SlotIndex PrevBottom) {
211   if (BottomIdx > PrevBottom)
212     return;
213   BottomIdx = SlotIndex();
214   LiveInRegs.clear();
215 }
216 
217 /// If the current bottom is the previous instr (before advancing), open it.
openBottom(MachineBasicBlock::const_iterator PrevBottom)218 void RegionPressure::openBottom(MachineBasicBlock::const_iterator PrevBottom) {
219   if (BottomPos != PrevBottom)
220     return;
221   BottomPos = MachineBasicBlock::const_iterator();
222   LiveInRegs.clear();
223 }
224 
init(const MachineRegisterInfo & MRI)225 void LiveRegSet::init(const MachineRegisterInfo &MRI) {
226   const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo();
227   unsigned NumRegUnits = TRI.getNumRegs();
228   unsigned NumVirtRegs = MRI.getNumVirtRegs();
229   Regs.setUniverse(NumRegUnits + NumVirtRegs);
230   this->NumRegUnits = NumRegUnits;
231 }
232 
clear()233 void LiveRegSet::clear() {
234   Regs.clear();
235 }
236 
getLiveRange(const LiveIntervals & LIS,unsigned Reg)237 static const LiveRange *getLiveRange(const LiveIntervals &LIS, unsigned Reg) {
238   if (Register::isVirtualRegister(Reg))
239     return &LIS.getInterval(Reg);
240   return LIS.getCachedRegUnit(Reg);
241 }
242 
reset()243 void RegPressureTracker::reset() {
244   MBB = nullptr;
245   LIS = nullptr;
246 
247   CurrSetPressure.clear();
248   LiveThruPressure.clear();
249   P.MaxSetPressure.clear();
250 
251   if (RequireIntervals)
252     static_cast<IntervalPressure&>(P).reset();
253   else
254     static_cast<RegionPressure&>(P).reset();
255 
256   LiveRegs.clear();
257   UntiedDefs.clear();
258 }
259 
260 /// Setup the RegPressureTracker.
261 ///
262 /// TODO: Add support for pressure without LiveIntervals.
init(const MachineFunction * mf,const RegisterClassInfo * rci,const LiveIntervals * lis,const MachineBasicBlock * mbb,MachineBasicBlock::const_iterator pos,bool TrackLaneMasks,bool TrackUntiedDefs)263 void RegPressureTracker::init(const MachineFunction *mf,
264                               const RegisterClassInfo *rci,
265                               const LiveIntervals *lis,
266                               const MachineBasicBlock *mbb,
267                               MachineBasicBlock::const_iterator pos,
268                               bool TrackLaneMasks, bool TrackUntiedDefs) {
269   reset();
270 
271   MF = mf;
272   TRI = MF->getSubtarget().getRegisterInfo();
273   RCI = rci;
274   MRI = &MF->getRegInfo();
275   MBB = mbb;
276   this->TrackUntiedDefs = TrackUntiedDefs;
277   this->TrackLaneMasks = TrackLaneMasks;
278 
279   if (RequireIntervals) {
280     assert(lis && "IntervalPressure requires LiveIntervals");
281     LIS = lis;
282   }
283 
284   CurrPos = pos;
285   CurrSetPressure.assign(TRI->getNumRegPressureSets(), 0);
286 
287   P.MaxSetPressure = CurrSetPressure;
288 
289   LiveRegs.init(*MRI);
290   if (TrackUntiedDefs)
291     UntiedDefs.setUniverse(MRI->getNumVirtRegs());
292 }
293 
294 /// Does this pressure result have a valid top position and live ins.
isTopClosed() const295 bool RegPressureTracker::isTopClosed() const {
296   if (RequireIntervals)
297     return static_cast<IntervalPressure&>(P).TopIdx.isValid();
298   return (static_cast<RegionPressure&>(P).TopPos ==
299           MachineBasicBlock::const_iterator());
300 }
301 
302 /// Does this pressure result have a valid bottom position and live outs.
isBottomClosed() const303 bool RegPressureTracker::isBottomClosed() const {
304   if (RequireIntervals)
305     return static_cast<IntervalPressure&>(P).BottomIdx.isValid();
306   return (static_cast<RegionPressure&>(P).BottomPos ==
307           MachineBasicBlock::const_iterator());
308 }
309 
getCurrSlot() const310 SlotIndex RegPressureTracker::getCurrSlot() const {
311   MachineBasicBlock::const_iterator IdxPos =
312     skipDebugInstructionsForward(CurrPos, MBB->end());
313   if (IdxPos == MBB->end())
314     return LIS->getMBBEndIdx(MBB);
315   return LIS->getInstructionIndex(*IdxPos).getRegSlot();
316 }
317 
318 /// Set the boundary for the top of the region and summarize live ins.
closeTop()319 void RegPressureTracker::closeTop() {
320   if (RequireIntervals)
321     static_cast<IntervalPressure&>(P).TopIdx = getCurrSlot();
322   else
323     static_cast<RegionPressure&>(P).TopPos = CurrPos;
324 
325   assert(P.LiveInRegs.empty() && "inconsistent max pressure result");
326   P.LiveInRegs.reserve(LiveRegs.size());
327   LiveRegs.appendTo(P.LiveInRegs);
328 }
329 
330 /// Set the boundary for the bottom of the region and summarize live outs.
closeBottom()331 void RegPressureTracker::closeBottom() {
332   if (RequireIntervals)
333     static_cast<IntervalPressure&>(P).BottomIdx = getCurrSlot();
334   else
335     static_cast<RegionPressure&>(P).BottomPos = CurrPos;
336 
337   assert(P.LiveOutRegs.empty() && "inconsistent max pressure result");
338   P.LiveOutRegs.reserve(LiveRegs.size());
339   LiveRegs.appendTo(P.LiveOutRegs);
340 }
341 
342 /// Finalize the region boundaries and record live ins and live outs.
closeRegion()343 void RegPressureTracker::closeRegion() {
344   if (!isTopClosed() && !isBottomClosed()) {
345     assert(LiveRegs.size() == 0 && "no region boundary");
346     return;
347   }
348   if (!isBottomClosed())
349     closeBottom();
350   else if (!isTopClosed())
351     closeTop();
352   // If both top and bottom are closed, do nothing.
353 }
354 
355 /// The register tracker is unaware of global liveness so ignores normal
356 /// live-thru ranges. However, two-address or coalesced chains can also lead
357 /// to live ranges with no holes. Count these to inform heuristics that we
358 /// can never drop below this pressure.
initLiveThru(const RegPressureTracker & RPTracker)359 void RegPressureTracker::initLiveThru(const RegPressureTracker &RPTracker) {
360   LiveThruPressure.assign(TRI->getNumRegPressureSets(), 0);
361   assert(isBottomClosed() && "need bottom-up tracking to intialize.");
362   for (const RegisterMaskPair &Pair : P.LiveOutRegs) {
363     Register RegUnit = Pair.RegUnit;
364     if (RegUnit.isVirtual() && !RPTracker.hasUntiedDef(RegUnit))
365       increaseSetPressure(LiveThruPressure, *MRI, RegUnit,
366                           LaneBitmask::getNone(), Pair.LaneMask);
367   }
368 }
369 
getRegLanes(ArrayRef<RegisterMaskPair> RegUnits,Register RegUnit)370 static LaneBitmask getRegLanes(ArrayRef<RegisterMaskPair> RegUnits,
371                                Register RegUnit) {
372   auto I = llvm::find_if(RegUnits, [RegUnit](const RegisterMaskPair Other) {
373     return Other.RegUnit == RegUnit;
374   });
375   if (I == RegUnits.end())
376     return LaneBitmask::getNone();
377   return I->LaneMask;
378 }
379 
addRegLanes(SmallVectorImpl<RegisterMaskPair> & RegUnits,RegisterMaskPair Pair)380 static void addRegLanes(SmallVectorImpl<RegisterMaskPair> &RegUnits,
381                         RegisterMaskPair Pair) {
382   Register RegUnit = Pair.RegUnit;
383   assert(Pair.LaneMask.any());
384   auto I = llvm::find_if(RegUnits, [RegUnit](const RegisterMaskPair Other) {
385     return Other.RegUnit == RegUnit;
386   });
387   if (I == RegUnits.end()) {
388     RegUnits.push_back(Pair);
389   } else {
390     I->LaneMask |= Pair.LaneMask;
391   }
392 }
393 
setRegZero(SmallVectorImpl<RegisterMaskPair> & RegUnits,Register RegUnit)394 static void setRegZero(SmallVectorImpl<RegisterMaskPair> &RegUnits,
395                        Register RegUnit) {
396   auto I = llvm::find_if(RegUnits, [RegUnit](const RegisterMaskPair Other) {
397     return Other.RegUnit == RegUnit;
398   });
399   if (I == RegUnits.end()) {
400     RegUnits.push_back(RegisterMaskPair(RegUnit, LaneBitmask::getNone()));
401   } else {
402     I->LaneMask = LaneBitmask::getNone();
403   }
404 }
405 
removeRegLanes(SmallVectorImpl<RegisterMaskPair> & RegUnits,RegisterMaskPair Pair)406 static void removeRegLanes(SmallVectorImpl<RegisterMaskPair> &RegUnits,
407                            RegisterMaskPair Pair) {
408   Register RegUnit = Pair.RegUnit;
409   assert(Pair.LaneMask.any());
410   auto I = llvm::find_if(RegUnits, [RegUnit](const RegisterMaskPair Other) {
411     return Other.RegUnit == RegUnit;
412   });
413   if (I != RegUnits.end()) {
414     I->LaneMask &= ~Pair.LaneMask;
415     if (I->LaneMask.none())
416       RegUnits.erase(I);
417   }
418 }
419 
420 static LaneBitmask
getLanesWithProperty(const LiveIntervals & LIS,const MachineRegisterInfo & MRI,bool TrackLaneMasks,Register RegUnit,SlotIndex Pos,LaneBitmask SafeDefault,bool (* Property)(const LiveRange & LR,SlotIndex Pos))421 getLanesWithProperty(const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
422                      bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
423                      LaneBitmask SafeDefault,
424                      bool (*Property)(const LiveRange &LR, SlotIndex Pos)) {
425   if (RegUnit.isVirtual()) {
426     const LiveInterval &LI = LIS.getInterval(RegUnit);
427     LaneBitmask Result;
428     if (TrackLaneMasks && LI.hasSubRanges()) {
429         for (const LiveInterval::SubRange &SR : LI.subranges()) {
430           if (Property(SR, Pos))
431             Result |= SR.LaneMask;
432         }
433     } else if (Property(LI, Pos)) {
434       Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit)
435                               : LaneBitmask::getAll();
436     }
437 
438     return Result;
439   } else {
440     const LiveRange *LR = LIS.getCachedRegUnit(RegUnit);
441     // Be prepared for missing liveranges: We usually do not compute liveranges
442     // for physical registers on targets with many registers (GPUs).
443     if (LR == nullptr)
444       return SafeDefault;
445     return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone();
446   }
447 }
448 
getLiveLanesAt(const LiveIntervals & LIS,const MachineRegisterInfo & MRI,bool TrackLaneMasks,Register RegUnit,SlotIndex Pos)449 static LaneBitmask getLiveLanesAt(const LiveIntervals &LIS,
450                                   const MachineRegisterInfo &MRI,
451                                   bool TrackLaneMasks, Register RegUnit,
452                                   SlotIndex Pos) {
453   return getLanesWithProperty(LIS, MRI, TrackLaneMasks, RegUnit, Pos,
454                               LaneBitmask::getAll(),
455                               [](const LiveRange &LR, SlotIndex Pos) {
456                                 return LR.liveAt(Pos);
457                               });
458 }
459 
460 namespace {
461 
462 /// Collect this instruction's unique uses and defs into SmallVectors for
463 /// processing defs and uses in order.
464 ///
465 /// FIXME: always ignore tied opers
466 class RegisterOperandsCollector {
467   friend class llvm::RegisterOperands;
468 
469   RegisterOperands &RegOpers;
470   const TargetRegisterInfo &TRI;
471   const MachineRegisterInfo &MRI;
472   bool IgnoreDead;
473 
RegisterOperandsCollector(RegisterOperands & RegOpers,const TargetRegisterInfo & TRI,const MachineRegisterInfo & MRI,bool IgnoreDead)474   RegisterOperandsCollector(RegisterOperands &RegOpers,
475                             const TargetRegisterInfo &TRI,
476                             const MachineRegisterInfo &MRI, bool IgnoreDead)
477     : RegOpers(RegOpers), TRI(TRI), MRI(MRI), IgnoreDead(IgnoreDead) {}
478 
collectInstr(const MachineInstr & MI) const479   void collectInstr(const MachineInstr &MI) const {
480     for (ConstMIBundleOperands OperI(MI); OperI.isValid(); ++OperI)
481       collectOperand(*OperI);
482 
483     // Remove redundant physreg dead defs.
484     for (const RegisterMaskPair &P : RegOpers.Defs)
485       removeRegLanes(RegOpers.DeadDefs, P);
486   }
487 
collectInstrLanes(const MachineInstr & MI) const488   void collectInstrLanes(const MachineInstr &MI) const {
489     for (ConstMIBundleOperands OperI(MI); OperI.isValid(); ++OperI)
490       collectOperandLanes(*OperI);
491 
492     // Remove redundant physreg dead defs.
493     for (const RegisterMaskPair &P : RegOpers.Defs)
494       removeRegLanes(RegOpers.DeadDefs, P);
495   }
496 
497   /// Push this operand's register onto the correct vectors.
collectOperand(const MachineOperand & MO) const498   void collectOperand(const MachineOperand &MO) const {
499     if (!MO.isReg() || !MO.getReg())
500       return;
501     Register Reg = MO.getReg();
502     if (MO.isUse()) {
503       if (!MO.isUndef() && !MO.isInternalRead())
504         pushReg(Reg, RegOpers.Uses);
505     } else {
506       assert(MO.isDef());
507       // Subregister definitions may imply a register read.
508       if (MO.readsReg())
509         pushReg(Reg, RegOpers.Uses);
510 
511       if (MO.isDead()) {
512         if (!IgnoreDead)
513           pushReg(Reg, RegOpers.DeadDefs);
514       } else
515         pushReg(Reg, RegOpers.Defs);
516     }
517   }
518 
pushReg(Register Reg,SmallVectorImpl<RegisterMaskPair> & RegUnits) const519   void pushReg(Register Reg,
520                SmallVectorImpl<RegisterMaskPair> &RegUnits) const {
521     if (Reg.isVirtual()) {
522       addRegLanes(RegUnits, RegisterMaskPair(Reg, LaneBitmask::getAll()));
523     } else if (MRI.isAllocatable(Reg)) {
524       for (MCRegUnit Unit : TRI.regunits(Reg.asMCReg()))
525         addRegLanes(RegUnits, RegisterMaskPair(Unit, LaneBitmask::getAll()));
526     }
527   }
528 
collectOperandLanes(const MachineOperand & MO) const529   void collectOperandLanes(const MachineOperand &MO) const {
530     if (!MO.isReg() || !MO.getReg())
531       return;
532     Register Reg = MO.getReg();
533     unsigned SubRegIdx = MO.getSubReg();
534     if (MO.isUse()) {
535       if (!MO.isUndef() && !MO.isInternalRead())
536         pushRegLanes(Reg, SubRegIdx, RegOpers.Uses);
537     } else {
538       assert(MO.isDef());
539       // Treat read-undef subreg defs as definitions of the whole register.
540       if (MO.isUndef())
541         SubRegIdx = 0;
542 
543       if (MO.isDead()) {
544         if (!IgnoreDead)
545           pushRegLanes(Reg, SubRegIdx, RegOpers.DeadDefs);
546       } else
547         pushRegLanes(Reg, SubRegIdx, RegOpers.Defs);
548     }
549   }
550 
pushRegLanes(Register Reg,unsigned SubRegIdx,SmallVectorImpl<RegisterMaskPair> & RegUnits) const551   void pushRegLanes(Register Reg, unsigned SubRegIdx,
552                     SmallVectorImpl<RegisterMaskPair> &RegUnits) const {
553     if (Reg.isVirtual()) {
554       LaneBitmask LaneMask = SubRegIdx != 0
555                              ? TRI.getSubRegIndexLaneMask(SubRegIdx)
556                              : MRI.getMaxLaneMaskForVReg(Reg);
557       addRegLanes(RegUnits, RegisterMaskPair(Reg, LaneMask));
558     } else if (MRI.isAllocatable(Reg)) {
559       for (MCRegUnit Unit : TRI.regunits(Reg.asMCReg()))
560         addRegLanes(RegUnits, RegisterMaskPair(Unit, LaneBitmask::getAll()));
561     }
562   }
563 };
564 
565 } // end anonymous namespace
566 
collect(const MachineInstr & MI,const TargetRegisterInfo & TRI,const MachineRegisterInfo & MRI,bool TrackLaneMasks,bool IgnoreDead)567 void RegisterOperands::collect(const MachineInstr &MI,
568                                const TargetRegisterInfo &TRI,
569                                const MachineRegisterInfo &MRI,
570                                bool TrackLaneMasks, bool IgnoreDead) {
571   RegisterOperandsCollector Collector(*this, TRI, MRI, IgnoreDead);
572   if (TrackLaneMasks)
573     Collector.collectInstrLanes(MI);
574   else
575     Collector.collectInstr(MI);
576 }
577 
detectDeadDefs(const MachineInstr & MI,const LiveIntervals & LIS)578 void RegisterOperands::detectDeadDefs(const MachineInstr &MI,
579                                       const LiveIntervals &LIS) {
580   SlotIndex SlotIdx = LIS.getInstructionIndex(MI);
581   for (auto *RI = Defs.begin(); RI != Defs.end(); /*empty*/) {
582     Register Reg = RI->RegUnit;
583     const LiveRange *LR = getLiveRange(LIS, Reg);
584     if (LR != nullptr) {
585       LiveQueryResult LRQ = LR->Query(SlotIdx);
586       if (LRQ.isDeadDef()) {
587         // LiveIntervals knows this is a dead even though it's MachineOperand is
588         // not flagged as such.
589         DeadDefs.push_back(*RI);
590         RI = Defs.erase(RI);
591         continue;
592       }
593     }
594     ++RI;
595   }
596 }
597 
adjustLaneLiveness(const LiveIntervals & LIS,const MachineRegisterInfo & MRI,SlotIndex Pos,MachineInstr * AddFlagsMI)598 void RegisterOperands::adjustLaneLiveness(const LiveIntervals &LIS,
599                                           const MachineRegisterInfo &MRI,
600                                           SlotIndex Pos,
601                                           MachineInstr *AddFlagsMI) {
602   for (auto *I = Defs.begin(); I != Defs.end();) {
603     LaneBitmask LiveAfter = getLiveLanesAt(LIS, MRI, true, I->RegUnit,
604                                            Pos.getDeadSlot());
605     // If the def is all that is live after the instruction, then in case
606     // of a subregister def we need a read-undef flag.
607     Register RegUnit = I->RegUnit;
608     if (RegUnit.isVirtual() && AddFlagsMI != nullptr &&
609         (LiveAfter & ~I->LaneMask).none())
610       AddFlagsMI->setRegisterDefReadUndef(RegUnit);
611 
612     LaneBitmask ActualDef = I->LaneMask & LiveAfter;
613     if (ActualDef.none()) {
614       I = Defs.erase(I);
615     } else {
616       I->LaneMask = ActualDef;
617       ++I;
618     }
619   }
620 
621   // For uses just copy the information from LIS.
622   for (auto &[RegUnit, LaneMask] : Uses)
623     LaneMask = getLiveLanesAt(LIS, MRI, true, RegUnit, Pos.getBaseIndex());
624 
625   if (AddFlagsMI != nullptr) {
626     for (const RegisterMaskPair &P : DeadDefs) {
627       Register RegUnit = P.RegUnit;
628       if (!RegUnit.isVirtual())
629         continue;
630       LaneBitmask LiveAfter = getLiveLanesAt(LIS, MRI, true, RegUnit,
631                                              Pos.getDeadSlot());
632       if (LiveAfter.none())
633         AddFlagsMI->setRegisterDefReadUndef(RegUnit);
634     }
635   }
636 }
637 
638 /// Initialize an array of N PressureDiffs.
init(unsigned N)639 void PressureDiffs::init(unsigned N) {
640   Size = N;
641   if (N <= Max) {
642     memset(PDiffArray, 0, N * sizeof(PressureDiff));
643     return;
644   }
645   Max = Size;
646   free(PDiffArray);
647   PDiffArray = static_cast<PressureDiff*>(safe_calloc(N, sizeof(PressureDiff)));
648 }
649 
addInstruction(unsigned Idx,const RegisterOperands & RegOpers,const MachineRegisterInfo & MRI)650 void PressureDiffs::addInstruction(unsigned Idx,
651                                    const RegisterOperands &RegOpers,
652                                    const MachineRegisterInfo &MRI) {
653   PressureDiff &PDiff = (*this)[Idx];
654   assert(!PDiff.begin()->isValid() && "stale PDiff");
655   for (const RegisterMaskPair &P : RegOpers.Defs)
656     PDiff.addPressureChange(P.RegUnit, true, &MRI);
657 
658   for (const RegisterMaskPair &P : RegOpers.Uses)
659     PDiff.addPressureChange(P.RegUnit, false, &MRI);
660 }
661 
662 /// Add a change in pressure to the pressure diff of a given instruction.
addPressureChange(Register RegUnit,bool IsDec,const MachineRegisterInfo * MRI)663 void PressureDiff::addPressureChange(Register RegUnit, bool IsDec,
664                                      const MachineRegisterInfo *MRI) {
665   PSetIterator PSetI = MRI->getPressureSets(RegUnit);
666   int Weight = IsDec ? -PSetI.getWeight() : PSetI.getWeight();
667   for (; PSetI.isValid(); ++PSetI) {
668     // Find an existing entry in the pressure diff for this PSet.
669     PressureDiff::iterator I = nonconst_begin(), E = nonconst_end();
670     for (; I != E && I->isValid(); ++I) {
671       if (I->getPSet() >= *PSetI)
672         break;
673     }
674     // If all pressure sets are more constrained, skip the remaining PSets.
675     if (I == E)
676       break;
677     // Insert this PressureChange.
678     if (!I->isValid() || I->getPSet() != *PSetI) {
679       PressureChange PTmp = PressureChange(*PSetI);
680       for (PressureDiff::iterator J = I; J != E && PTmp.isValid(); ++J)
681         std::swap(*J, PTmp);
682     }
683     // Update the units for this pressure set.
684     unsigned NewUnitInc = I->getUnitInc() + Weight;
685     if (NewUnitInc != 0) {
686       I->setUnitInc(NewUnitInc);
687     } else {
688       // Remove entry
689       PressureDiff::iterator J;
690       for (J = std::next(I); J != E && J->isValid(); ++J, ++I)
691         *I = *J;
692       *I = PressureChange();
693     }
694   }
695 }
696 
697 /// Force liveness of registers.
addLiveRegs(ArrayRef<RegisterMaskPair> Regs)698 void RegPressureTracker::addLiveRegs(ArrayRef<RegisterMaskPair> Regs) {
699   for (const RegisterMaskPair &P : Regs) {
700     LaneBitmask PrevMask = LiveRegs.insert(P);
701     LaneBitmask NewMask = PrevMask | P.LaneMask;
702     increaseRegPressure(P.RegUnit, PrevMask, NewMask);
703   }
704 }
705 
discoverLiveInOrOut(RegisterMaskPair Pair,SmallVectorImpl<RegisterMaskPair> & LiveInOrOut)706 void RegPressureTracker::discoverLiveInOrOut(RegisterMaskPair Pair,
707     SmallVectorImpl<RegisterMaskPair> &LiveInOrOut) {
708   assert(Pair.LaneMask.any());
709 
710   Register RegUnit = Pair.RegUnit;
711   auto I = llvm::find_if(LiveInOrOut, [RegUnit](const RegisterMaskPair &Other) {
712     return Other.RegUnit == RegUnit;
713   });
714   LaneBitmask PrevMask;
715   LaneBitmask NewMask;
716   if (I == LiveInOrOut.end()) {
717     PrevMask = LaneBitmask::getNone();
718     NewMask = Pair.LaneMask;
719     LiveInOrOut.push_back(Pair);
720   } else {
721     PrevMask = I->LaneMask;
722     NewMask = PrevMask | Pair.LaneMask;
723     I->LaneMask = NewMask;
724   }
725   increaseSetPressure(P.MaxSetPressure, *MRI, RegUnit, PrevMask, NewMask);
726 }
727 
discoverLiveIn(RegisterMaskPair Pair)728 void RegPressureTracker::discoverLiveIn(RegisterMaskPair Pair) {
729   discoverLiveInOrOut(Pair, P.LiveInRegs);
730 }
731 
discoverLiveOut(RegisterMaskPair Pair)732 void RegPressureTracker::discoverLiveOut(RegisterMaskPair Pair) {
733   discoverLiveInOrOut(Pair, P.LiveOutRegs);
734 }
735 
bumpDeadDefs(ArrayRef<RegisterMaskPair> DeadDefs)736 void RegPressureTracker::bumpDeadDefs(ArrayRef<RegisterMaskPair> DeadDefs) {
737   for (const RegisterMaskPair &P : DeadDefs) {
738     Register Reg = P.RegUnit;
739     LaneBitmask LiveMask = LiveRegs.contains(Reg);
740     LaneBitmask BumpedMask = LiveMask | P.LaneMask;
741     increaseRegPressure(Reg, LiveMask, BumpedMask);
742   }
743   for (const RegisterMaskPair &P : DeadDefs) {
744     Register Reg = P.RegUnit;
745     LaneBitmask LiveMask = LiveRegs.contains(Reg);
746     LaneBitmask BumpedMask = LiveMask | P.LaneMask;
747     decreaseRegPressure(Reg, BumpedMask, LiveMask);
748   }
749 }
750 
/// Recede across the previous instruction. If LiveUses is provided, record any
/// RegUnits that are made live by the current instruction's uses. This includes
/// registers that are both defined and used by the instruction. If a pressure
/// difference pointer is provided, record the changes in pressure caused by
/// this instruction independent of liveness.
recede(const RegisterOperands & RegOpers,SmallVectorImpl<RegisterMaskPair> * LiveUses)756 void RegPressureTracker::recede(const RegisterOperands &RegOpers,
757                                 SmallVectorImpl<RegisterMaskPair> *LiveUses) {
758   assert(!CurrPos->isDebugOrPseudoInstr());
759 
760   // Boost pressure for all dead defs together.
761   bumpDeadDefs(RegOpers.DeadDefs);
762 
763   // Kill liveness at live defs.
764   // TODO: consider earlyclobbers?
765   for (const RegisterMaskPair &Def : RegOpers.Defs) {
766     Register Reg = Def.RegUnit;
767 
768     LaneBitmask PreviousMask = LiveRegs.erase(Def);
769     LaneBitmask NewMask = PreviousMask & ~Def.LaneMask;
770 
771     LaneBitmask LiveOut = Def.LaneMask & ~PreviousMask;
772     if (LiveOut.any()) {
773       discoverLiveOut(RegisterMaskPair(Reg, LiveOut));
774       // Retroactively model effects on pressure of the live out lanes.
775       increaseSetPressure(CurrSetPressure, *MRI, Reg, LaneBitmask::getNone(),
776                           LiveOut);
777       PreviousMask = LiveOut;
778     }
779 
780     if (NewMask.none()) {
781       // Add a 0 entry to LiveUses as a marker that the complete vreg has become
782       // dead.
783       if (TrackLaneMasks && LiveUses != nullptr)
784         setRegZero(*LiveUses, Reg);
785     }
786 
787     decreaseRegPressure(Reg, PreviousMask, NewMask);
788   }
789 
790   SlotIndex SlotIdx;
791   if (RequireIntervals)
792     SlotIdx = LIS->getInstructionIndex(*CurrPos).getRegSlot();
793 
794   // Generate liveness for uses.
795   for (const RegisterMaskPair &Use : RegOpers.Uses) {
796     Register Reg = Use.RegUnit;
797     assert(Use.LaneMask.any());
798     LaneBitmask PreviousMask = LiveRegs.insert(Use);
799     LaneBitmask NewMask = PreviousMask | Use.LaneMask;
800     if (NewMask == PreviousMask)
801       continue;
802 
803     // Did the register just become live?
804     if (PreviousMask.none()) {
805       if (LiveUses != nullptr) {
806         if (!TrackLaneMasks) {
807           addRegLanes(*LiveUses, RegisterMaskPair(Reg, NewMask));
808         } else {
809           auto I =
810               llvm::find_if(*LiveUses, [Reg](const RegisterMaskPair Other) {
811                 return Other.RegUnit == Reg;
812               });
813           bool IsRedef = I != LiveUses->end();
814           if (IsRedef) {
815             // ignore re-defs here...
816             assert(I->LaneMask.none());
817             removeRegLanes(*LiveUses, RegisterMaskPair(Reg, NewMask));
818           } else {
819             addRegLanes(*LiveUses, RegisterMaskPair(Reg, NewMask));
820           }
821         }
822       }
823 
824       // Discover live outs if this may be the first occurance of this register.
825       if (RequireIntervals) {
826         LaneBitmask LiveOut = getLiveThroughAt(Reg, SlotIdx);
827         if (LiveOut.any())
828           discoverLiveOut(RegisterMaskPair(Reg, LiveOut));
829       }
830     }
831 
832     increaseRegPressure(Reg, PreviousMask, NewMask);
833   }
834   if (TrackUntiedDefs) {
835     for (const RegisterMaskPair &Def : RegOpers.Defs) {
836       Register RegUnit = Def.RegUnit;
837       if (RegUnit.isVirtual() &&
838           (LiveRegs.contains(RegUnit) & Def.LaneMask).none())
839         UntiedDefs.insert(RegUnit);
840     }
841   }
842 }
843 
recedeSkipDebugValues()844 void RegPressureTracker::recedeSkipDebugValues() {
845   assert(CurrPos != MBB->begin());
846   if (!isBottomClosed())
847     closeBottom();
848 
849   // Open the top of the region using block iterators.
850   if (!RequireIntervals && isTopClosed())
851     static_cast<RegionPressure&>(P).openTop(CurrPos);
852 
853   // Find the previous instruction.
854   CurrPos = prev_nodbg(CurrPos, MBB->begin());
855 
856   SlotIndex SlotIdx;
857   if (RequireIntervals && !CurrPos->isDebugOrPseudoInstr())
858     SlotIdx = LIS->getInstructionIndex(*CurrPos).getRegSlot();
859 
860   // Open the top of the region using slot indexes.
861   if (RequireIntervals && isTopClosed())
862     static_cast<IntervalPressure&>(P).openTop(SlotIdx);
863 }
864 
recede(SmallVectorImpl<RegisterMaskPair> * LiveUses)865 void RegPressureTracker::recede(SmallVectorImpl<RegisterMaskPair> *LiveUses) {
866   recedeSkipDebugValues();
867   if (CurrPos->isDebugInstr() || CurrPos->isPseudoProbe()) {
868     // It's possible to only have debug_value and pseudo probe instructions and
869     // hit the start of the block.
870     assert(CurrPos == MBB->begin());
871     return;
872   }
873 
874   const MachineInstr &MI = *CurrPos;
875   RegisterOperands RegOpers;
876   RegOpers.collect(MI, *TRI, *MRI, TrackLaneMasks, /*IgnoreDead=*/false);
877   if (TrackLaneMasks) {
878     SlotIndex SlotIdx = LIS->getInstructionIndex(*CurrPos).getRegSlot();
879     RegOpers.adjustLaneLiveness(*LIS, *MRI, SlotIdx);
880   } else if (RequireIntervals) {
881     RegOpers.detectDeadDefs(MI, *LIS);
882   }
883 
884   recede(RegOpers, LiveUses);
885 }
886 
887 /// Advance across the current instruction.
advance(const RegisterOperands & RegOpers)888 void RegPressureTracker::advance(const RegisterOperands &RegOpers) {
889   assert(!TrackUntiedDefs && "unsupported mode");
890   assert(CurrPos != MBB->end());
891   if (!isTopClosed())
892     closeTop();
893 
894   SlotIndex SlotIdx;
895   if (RequireIntervals)
896     SlotIdx = getCurrSlot();
897 
898   // Open the bottom of the region using slot indexes.
899   if (isBottomClosed()) {
900     if (RequireIntervals)
901       static_cast<IntervalPressure&>(P).openBottom(SlotIdx);
902     else
903       static_cast<RegionPressure&>(P).openBottom(CurrPos);
904   }
905 
906   for (const RegisterMaskPair &Use : RegOpers.Uses) {
907     Register Reg = Use.RegUnit;
908     LaneBitmask LiveMask = LiveRegs.contains(Reg);
909     LaneBitmask LiveIn = Use.LaneMask & ~LiveMask;
910     if (LiveIn.any()) {
911       discoverLiveIn(RegisterMaskPair(Reg, LiveIn));
912       increaseRegPressure(Reg, LiveMask, LiveMask | LiveIn);
913       LiveRegs.insert(RegisterMaskPair(Reg, LiveIn));
914     }
915     // Kill liveness at last uses.
916     if (RequireIntervals) {
917       LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
918       if (LastUseMask.any()) {
919         LiveRegs.erase(RegisterMaskPair(Reg, LastUseMask));
920         decreaseRegPressure(Reg, LiveMask, LiveMask & ~LastUseMask);
921       }
922     }
923   }
924 
925   // Generate liveness for defs.
926   for (const RegisterMaskPair &Def : RegOpers.Defs) {
927     LaneBitmask PreviousMask = LiveRegs.insert(Def);
928     LaneBitmask NewMask = PreviousMask | Def.LaneMask;
929     increaseRegPressure(Def.RegUnit, PreviousMask, NewMask);
930   }
931 
932   // Boost pressure for all dead defs together.
933   bumpDeadDefs(RegOpers.DeadDefs);
934 
935   // Find the next instruction.
936   CurrPos = next_nodbg(CurrPos, MBB->end());
937 }
938 
advance()939 void RegPressureTracker::advance() {
940   const MachineInstr &MI = *CurrPos;
941   RegisterOperands RegOpers;
942   RegOpers.collect(MI, *TRI, *MRI, TrackLaneMasks, false);
943   if (TrackLaneMasks) {
944     SlotIndex SlotIdx = getCurrSlot();
945     RegOpers.adjustLaneLiveness(*LIS, *MRI, SlotIdx);
946   }
947   advance(RegOpers);
948 }
949 
950 /// Find the max change in excess pressure across all sets.
computeExcessPressureDelta(ArrayRef<unsigned> OldPressureVec,ArrayRef<unsigned> NewPressureVec,RegPressureDelta & Delta,const RegisterClassInfo * RCI,ArrayRef<unsigned> LiveThruPressureVec)951 static void computeExcessPressureDelta(ArrayRef<unsigned> OldPressureVec,
952                                        ArrayRef<unsigned> NewPressureVec,
953                                        RegPressureDelta &Delta,
954                                        const RegisterClassInfo *RCI,
955                                        ArrayRef<unsigned> LiveThruPressureVec) {
956   Delta.Excess = PressureChange();
957   for (unsigned i = 0, e = OldPressureVec.size(); i < e; ++i) {
958     unsigned POld = OldPressureVec[i];
959     unsigned PNew = NewPressureVec[i];
960     int PDiff = (int)PNew - (int)POld;
961     if (!PDiff) // No change in this set in the common case.
962       continue;
963     // Only consider change beyond the limit.
964     unsigned Limit = RCI->getRegPressureSetLimit(i);
965     if (!LiveThruPressureVec.empty())
966       Limit += LiveThruPressureVec[i];
967 
968     if (Limit > POld) {
969       if (Limit > PNew)
970         PDiff = 0;            // Under the limit
971       else
972         PDiff = PNew - Limit; // Just exceeded limit.
973     } else if (Limit > PNew)
974       PDiff = Limit - POld;   // Just obeyed limit.
975 
976     if (PDiff) {
977       Delta.Excess = PressureChange(i);
978       Delta.Excess.setUnitInc(PDiff);
979       break;
980     }
981   }
982 }
983 
984 /// Find the max change in max pressure that either surpasses a critical PSet
985 /// limit or exceeds the current MaxPressureLimit.
986 ///
987 /// FIXME: comparing each element of the old and new MaxPressure vectors here is
988 /// silly. It's done now to demonstrate the concept but will go away with a
989 /// RegPressureTracker API change to work with pressure differences.
computeMaxPressureDelta(ArrayRef<unsigned> OldMaxPressureVec,ArrayRef<unsigned> NewMaxPressureVec,ArrayRef<PressureChange> CriticalPSets,ArrayRef<unsigned> MaxPressureLimit,RegPressureDelta & Delta)990 static void computeMaxPressureDelta(ArrayRef<unsigned> OldMaxPressureVec,
991                                     ArrayRef<unsigned> NewMaxPressureVec,
992                                     ArrayRef<PressureChange> CriticalPSets,
993                                     ArrayRef<unsigned> MaxPressureLimit,
994                                     RegPressureDelta &Delta) {
995   Delta.CriticalMax = PressureChange();
996   Delta.CurrentMax = PressureChange();
997 
998   unsigned CritIdx = 0, CritEnd = CriticalPSets.size();
999   for (unsigned i = 0, e = OldMaxPressureVec.size(); i < e; ++i) {
1000     unsigned POld = OldMaxPressureVec[i];
1001     unsigned PNew = NewMaxPressureVec[i];
1002     if (PNew == POld) // No change in this set in the common case.
1003       continue;
1004 
1005     if (!Delta.CriticalMax.isValid()) {
1006       while (CritIdx != CritEnd && CriticalPSets[CritIdx].getPSet() < i)
1007         ++CritIdx;
1008 
1009       if (CritIdx != CritEnd && CriticalPSets[CritIdx].getPSet() == i) {
1010         int PDiff = (int)PNew - (int)CriticalPSets[CritIdx].getUnitInc();
1011         if (PDiff > 0) {
1012           Delta.CriticalMax = PressureChange(i);
1013           Delta.CriticalMax.setUnitInc(PDiff);
1014         }
1015       }
1016     }
1017     // Find the first increase above MaxPressureLimit.
1018     // (Ignores negative MDiff).
1019     if (!Delta.CurrentMax.isValid() && PNew > MaxPressureLimit[i]) {
1020       Delta.CurrentMax = PressureChange(i);
1021       Delta.CurrentMax.setUnitInc(PNew - POld);
1022       if (CritIdx == CritEnd || Delta.CriticalMax.isValid())
1023         break;
1024     }
1025   }
1026 }
1027 
1028 /// Record the upward impact of a single instruction on current register
1029 /// pressure. Unlike the advance/recede pressure tracking interface, this does
1030 /// not discover live in/outs.
1031 ///
1032 /// This is intended for speculative queries. It leaves pressure inconsistent
1033 /// with the current position, so must be restored by the caller.
bumpUpwardPressure(const MachineInstr * MI)1034 void RegPressureTracker::bumpUpwardPressure(const MachineInstr *MI) {
1035   assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction.");
1036 
1037   SlotIndex SlotIdx;
1038   if (RequireIntervals)
1039     SlotIdx = LIS->getInstructionIndex(*MI).getRegSlot();
1040 
1041   // Account for register pressure similar to RegPressureTracker::recede().
1042   RegisterOperands RegOpers;
1043   RegOpers.collect(*MI, *TRI, *MRI, TrackLaneMasks, /*IgnoreDead=*/true);
1044   assert(RegOpers.DeadDefs.empty());
1045   if (TrackLaneMasks)
1046     RegOpers.adjustLaneLiveness(*LIS, *MRI, SlotIdx);
1047   else if (RequireIntervals)
1048     RegOpers.detectDeadDefs(*MI, *LIS);
1049 
1050   // Boost max pressure for all dead defs together.
1051   // Since CurrSetPressure and MaxSetPressure
1052   bumpDeadDefs(RegOpers.DeadDefs);
1053 
1054   // Kill liveness at live defs.
1055   for (const RegisterMaskPair &P : RegOpers.Defs) {
1056     Register Reg = P.RegUnit;
1057     LaneBitmask LiveAfter = LiveRegs.contains(Reg);
1058     LaneBitmask UseLanes = getRegLanes(RegOpers.Uses, Reg);
1059     LaneBitmask DefLanes = P.LaneMask;
1060     LaneBitmask LiveBefore = (LiveAfter & ~DefLanes) | UseLanes;
1061 
1062     // There may be parts of the register that were dead before the
1063     // instruction, but became live afterwards. Similarly, some parts
1064     // may have been killed in this instruction.
1065     decreaseRegPressure(Reg, LiveAfter, LiveAfter & LiveBefore);
1066     increaseRegPressure(Reg, LiveAfter, ~LiveAfter & LiveBefore);
1067   }
1068   // Generate liveness for uses.
1069   for (const RegisterMaskPair &P : RegOpers.Uses) {
1070     Register Reg = P.RegUnit;
1071     // If this register was also in a def operand, we've handled it
1072     // with defs.
1073     if (getRegLanes(RegOpers.Defs, Reg).any())
1074       continue;
1075     LaneBitmask LiveAfter = LiveRegs.contains(Reg);
1076     LaneBitmask LiveBefore = LiveAfter | P.LaneMask;
1077     increaseRegPressure(Reg, LiveAfter, LiveBefore);
1078   }
1079 }
1080 
1081 /// Consider the pressure increase caused by traversing this instruction
1082 /// bottom-up. Find the pressure set with the most change beyond its pressure
1083 /// limit based on the tracker's current pressure, and return the change in
1084 /// number of register units of that pressure set introduced by this
1085 /// instruction.
1086 ///
1087 /// This assumes that the current LiveOut set is sufficient.
1088 ///
1089 /// This is expensive for an on-the-fly query because it calls
1090 /// bumpUpwardPressure to recompute the pressure sets based on current
1091 /// liveness. This mainly exists to verify correctness, e.g. with
1092 /// -verify-misched. getUpwardPressureDelta is the fast version of this query
1093 /// that uses the per-SUnit cache of the PressureDiff.
1094 void RegPressureTracker::
getMaxUpwardPressureDelta(const MachineInstr * MI,PressureDiff * PDiff,RegPressureDelta & Delta,ArrayRef<PressureChange> CriticalPSets,ArrayRef<unsigned> MaxPressureLimit)1095 getMaxUpwardPressureDelta(const MachineInstr *MI, PressureDiff *PDiff,
1096                           RegPressureDelta &Delta,
1097                           ArrayRef<PressureChange> CriticalPSets,
1098                           ArrayRef<unsigned> MaxPressureLimit) {
1099   // Snapshot Pressure.
1100   // FIXME: The snapshot heap space should persist. But I'm planning to
1101   // summarize the pressure effect so we don't need to snapshot at all.
1102   std::vector<unsigned> SavedPressure = CurrSetPressure;
1103   std::vector<unsigned> SavedMaxPressure = P.MaxSetPressure;
1104 
1105   bumpUpwardPressure(MI);
1106 
1107   computeExcessPressureDelta(SavedPressure, CurrSetPressure, Delta, RCI,
1108                              LiveThruPressure);
1109   computeMaxPressureDelta(SavedMaxPressure, P.MaxSetPressure, CriticalPSets,
1110                           MaxPressureLimit, Delta);
1111   assert(Delta.CriticalMax.getUnitInc() >= 0 &&
1112          Delta.CurrentMax.getUnitInc() >= 0 && "cannot decrease max pressure");
1113 
1114   // Restore the tracker's state.
1115   P.MaxSetPressure.swap(SavedMaxPressure);
1116   CurrSetPressure.swap(SavedPressure);
1117 
1118 #ifndef NDEBUG
1119   if (!PDiff)
1120     return;
1121 
1122   // Check if the alternate algorithm yields the same result.
1123   RegPressureDelta Delta2;
1124   getUpwardPressureDelta(MI, *PDiff, Delta2, CriticalPSets, MaxPressureLimit);
1125   if (Delta != Delta2) {
1126     dbgs() << "PDiff: ";
1127     PDiff->dump(*TRI);
1128     dbgs() << "DELTA: " << *MI;
1129     if (Delta.Excess.isValid())
1130       dbgs() << "Excess1 " << TRI->getRegPressureSetName(Delta.Excess.getPSet())
1131              << " " << Delta.Excess.getUnitInc() << "\n";
1132     if (Delta.CriticalMax.isValid())
1133       dbgs() << "Critic1 " << TRI->getRegPressureSetName(Delta.CriticalMax.getPSet())
1134              << " " << Delta.CriticalMax.getUnitInc() << "\n";
1135     if (Delta.CurrentMax.isValid())
1136       dbgs() << "CurrMx1 " << TRI->getRegPressureSetName(Delta.CurrentMax.getPSet())
1137              << " " << Delta.CurrentMax.getUnitInc() << "\n";
1138     if (Delta2.Excess.isValid())
1139       dbgs() << "Excess2 " << TRI->getRegPressureSetName(Delta2.Excess.getPSet())
1140              << " " << Delta2.Excess.getUnitInc() << "\n";
1141     if (Delta2.CriticalMax.isValid())
1142       dbgs() << "Critic2 " << TRI->getRegPressureSetName(Delta2.CriticalMax.getPSet())
1143              << " " << Delta2.CriticalMax.getUnitInc() << "\n";
1144     if (Delta2.CurrentMax.isValid())
1145       dbgs() << "CurrMx2 " << TRI->getRegPressureSetName(Delta2.CurrentMax.getPSet())
1146              << " " << Delta2.CurrentMax.getUnitInc() << "\n";
1147     llvm_unreachable("RegP Delta Mismatch");
1148   }
1149 #endif
1150 }
1151 
1152 /// This is the fast version of querying register pressure that does not
1153 /// directly depend on current liveness.
1154 ///
1155 /// @param Delta captures information needed for heuristics.
1156 ///
1157 /// @param CriticalPSets Are the pressure sets that are known to exceed some
1158 /// limit within the region, not necessarily at the current position.
1159 ///
1160 /// @param MaxPressureLimit Is the max pressure within the region, not
1161 /// necessarily at the current position.
1162 void RegPressureTracker::
getUpwardPressureDelta(const MachineInstr * MI,PressureDiff & PDiff,RegPressureDelta & Delta,ArrayRef<PressureChange> CriticalPSets,ArrayRef<unsigned> MaxPressureLimit) const1163 getUpwardPressureDelta(const MachineInstr *MI, /*const*/ PressureDiff &PDiff,
1164                        RegPressureDelta &Delta,
1165                        ArrayRef<PressureChange> CriticalPSets,
1166                        ArrayRef<unsigned> MaxPressureLimit) const {
1167   unsigned CritIdx = 0, CritEnd = CriticalPSets.size();
1168   for (PressureDiff::const_iterator
1169          PDiffI = PDiff.begin(), PDiffE = PDiff.end();
1170        PDiffI != PDiffE && PDiffI->isValid(); ++PDiffI) {
1171 
1172     unsigned PSetID = PDiffI->getPSet();
1173     unsigned Limit = RCI->getRegPressureSetLimit(PSetID);
1174     if (!LiveThruPressure.empty())
1175       Limit += LiveThruPressure[PSetID];
1176 
1177     unsigned POld = CurrSetPressure[PSetID];
1178     unsigned MOld = P.MaxSetPressure[PSetID];
1179     unsigned MNew = MOld;
1180     // Ignore DeadDefs here because they aren't captured by PressureChange.
1181     unsigned PNew = POld + PDiffI->getUnitInc();
1182     assert((PDiffI->getUnitInc() >= 0) == (PNew >= POld)
1183            && "PSet overflow/underflow");
1184     if (PNew > MOld)
1185       MNew = PNew;
1186     // Check if current pressure has exceeded the limit.
1187     if (!Delta.Excess.isValid()) {
1188       unsigned ExcessInc = 0;
1189       if (PNew > Limit)
1190         ExcessInc = POld > Limit ? PNew - POld : PNew - Limit;
1191       else if (POld > Limit)
1192         ExcessInc = Limit - POld;
1193       if (ExcessInc) {
1194         Delta.Excess = PressureChange(PSetID);
1195         Delta.Excess.setUnitInc(ExcessInc);
1196       }
1197     }
1198     // Check if max pressure has exceeded a critical pressure set max.
1199     if (MNew == MOld)
1200       continue;
1201     if (!Delta.CriticalMax.isValid()) {
1202       while (CritIdx != CritEnd && CriticalPSets[CritIdx].getPSet() < PSetID)
1203         ++CritIdx;
1204 
1205       if (CritIdx != CritEnd && CriticalPSets[CritIdx].getPSet() == PSetID) {
1206         int CritInc = (int)MNew - (int)CriticalPSets[CritIdx].getUnitInc();
1207         if (CritInc > 0 && CritInc <= std::numeric_limits<int16_t>::max()) {
1208           Delta.CriticalMax = PressureChange(PSetID);
1209           Delta.CriticalMax.setUnitInc(CritInc);
1210         }
1211       }
1212     }
1213     // Check if max pressure has exceeded the current max.
1214     if (!Delta.CurrentMax.isValid() && MNew > MaxPressureLimit[PSetID]) {
1215       Delta.CurrentMax = PressureChange(PSetID);
1216       Delta.CurrentMax.setUnitInc(MNew - MOld);
1217     }
1218   }
1219 }
1220 
1221 /// Helper to find a vreg use between two indices [PriorUseIdx, NextUseIdx).
1222 /// The query starts with a lane bitmask which gets lanes/bits removed for every
1223 /// use we find.
findUseBetween(unsigned Reg,LaneBitmask LastUseMask,SlotIndex PriorUseIdx,SlotIndex NextUseIdx,const MachineRegisterInfo & MRI,const LiveIntervals * LIS)1224 static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
1225                                   SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
1226                                   const MachineRegisterInfo &MRI,
1227                                   const LiveIntervals *LIS) {
1228   const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo();
1229   for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
1230     if (MO.isUndef())
1231       continue;
1232     const MachineInstr *MI = MO.getParent();
1233     SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot();
1234     if (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx) {
1235       unsigned SubRegIdx = MO.getSubReg();
1236       LaneBitmask UseMask = TRI.getSubRegIndexLaneMask(SubRegIdx);
1237       LastUseMask &= ~UseMask;
1238       if (LastUseMask.none())
1239         return LaneBitmask::getNone();
1240     }
1241   }
1242   return LastUseMask;
1243 }
1244 
getLiveLanesAt(Register RegUnit,SlotIndex Pos) const1245 LaneBitmask RegPressureTracker::getLiveLanesAt(Register RegUnit,
1246                                                SlotIndex Pos) const {
1247   assert(RequireIntervals);
1248   return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit, Pos,
1249                               LaneBitmask::getAll(),
1250       [](const LiveRange &LR, SlotIndex Pos) {
1251         return LR.liveAt(Pos);
1252       });
1253 }
1254 
getLastUsedLanes(Register RegUnit,SlotIndex Pos) const1255 LaneBitmask RegPressureTracker::getLastUsedLanes(Register RegUnit,
1256                                                  SlotIndex Pos) const {
1257   assert(RequireIntervals);
1258   return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit,
1259                               Pos.getBaseIndex(), LaneBitmask::getNone(),
1260       [](const LiveRange &LR, SlotIndex Pos) {
1261         const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
1262         return S != nullptr && S->end == Pos.getRegSlot();
1263       });
1264 }
1265 
getLiveThroughAt(Register RegUnit,SlotIndex Pos) const1266 LaneBitmask RegPressureTracker::getLiveThroughAt(Register RegUnit,
1267                                                  SlotIndex Pos) const {
1268   assert(RequireIntervals);
1269   return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit, Pos,
1270                               LaneBitmask::getNone(),
1271       [](const LiveRange &LR, SlotIndex Pos) {
1272         const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
1273         return S != nullptr && S->start < Pos.getRegSlot(true) &&
1274                S->end != Pos.getDeadSlot();
1275       });
1276 }
1277 
1278 /// Record the downward impact of a single instruction on current register
1279 /// pressure. Unlike the advance/recede pressure tracking interface, this does
1280 /// not discover live in/outs.
1281 ///
1282 /// This is intended for speculative queries. It leaves pressure inconsistent
1283 /// with the current position, so must be restored by the caller.
bumpDownwardPressure(const MachineInstr * MI)1284 void RegPressureTracker::bumpDownwardPressure(const MachineInstr *MI) {
1285   assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction.");
1286 
1287   SlotIndex SlotIdx;
1288   if (RequireIntervals)
1289     SlotIdx = LIS->getInstructionIndex(*MI).getRegSlot();
1290 
1291   // Account for register pressure similar to RegPressureTracker::advance().
1292   RegisterOperands RegOpers;
1293   RegOpers.collect(*MI, *TRI, *MRI, TrackLaneMasks, /*IgnoreDead=*/false);
1294   if (TrackLaneMasks)
1295     RegOpers.adjustLaneLiveness(*LIS, *MRI, SlotIdx);
1296 
1297   if (RequireIntervals) {
1298     for (const RegisterMaskPair &Use : RegOpers.Uses) {
1299       Register Reg = Use.RegUnit;
1300       LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
1301       if (LastUseMask.none())
1302         continue;
1303       // The LastUseMask is queried from the liveness information of instruction
1304       // which may be further down the schedule. Some lanes may actually not be
1305       // last uses for the current position.
1306       // FIXME: allow the caller to pass in the list of vreg uses that remain
1307       // to be bottom-scheduled to avoid searching uses at each query.
1308       SlotIndex CurrIdx = getCurrSlot();
1309       LastUseMask
1310         = findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, LIS);
1311       if (LastUseMask.none())
1312         continue;
1313 
1314       LaneBitmask LiveMask = LiveRegs.contains(Reg);
1315       LaneBitmask NewMask = LiveMask & ~LastUseMask;
1316       decreaseRegPressure(Reg, LiveMask, NewMask);
1317     }
1318   }
1319 
1320   // Generate liveness for defs.
1321   for (const RegisterMaskPair &Def : RegOpers.Defs) {
1322     Register Reg = Def.RegUnit;
1323     LaneBitmask LiveMask = LiveRegs.contains(Reg);
1324     LaneBitmask NewMask = LiveMask | Def.LaneMask;
1325     increaseRegPressure(Reg, LiveMask, NewMask);
1326   }
1327 
1328   // Boost pressure for all dead defs together.
1329   bumpDeadDefs(RegOpers.DeadDefs);
1330 }
1331 
1332 /// Consider the pressure increase caused by traversing this instruction
1333 /// top-down. Find the register class with the most change in its pressure limit
1334 /// based on the tracker's current pressure, and return the number of excess
1335 /// register units of that pressure set introduced by this instruction.
1336 ///
1337 /// This assumes that the current LiveIn set is sufficient.
1338 ///
1339 /// This is expensive for an on-the-fly query because it calls
1340 /// bumpDownwardPressure to recompute the pressure sets based on current
1341 /// liveness. We don't yet have a fast version of downward pressure tracking
1342 /// analogous to getUpwardPressureDelta.
1343 void RegPressureTracker::
getMaxDownwardPressureDelta(const MachineInstr * MI,RegPressureDelta & Delta,ArrayRef<PressureChange> CriticalPSets,ArrayRef<unsigned> MaxPressureLimit)1344 getMaxDownwardPressureDelta(const MachineInstr *MI, RegPressureDelta &Delta,
1345                             ArrayRef<PressureChange> CriticalPSets,
1346                             ArrayRef<unsigned> MaxPressureLimit) {
1347   // Snapshot Pressure.
1348   std::vector<unsigned> SavedPressure = CurrSetPressure;
1349   std::vector<unsigned> SavedMaxPressure = P.MaxSetPressure;
1350 
1351   bumpDownwardPressure(MI);
1352 
1353   computeExcessPressureDelta(SavedPressure, CurrSetPressure, Delta, RCI,
1354                              LiveThruPressure);
1355   computeMaxPressureDelta(SavedMaxPressure, P.MaxSetPressure, CriticalPSets,
1356                           MaxPressureLimit, Delta);
1357   assert(Delta.CriticalMax.getUnitInc() >= 0 &&
1358          Delta.CurrentMax.getUnitInc() >= 0 && "cannot decrease max pressure");
1359 
1360   // Restore the tracker's state.
1361   P.MaxSetPressure.swap(SavedMaxPressure);
1362   CurrSetPressure.swap(SavedPressure);
1363 }
1364 
1365 /// Get the pressure of each PSet after traversing this instruction bottom-up.
1366 void RegPressureTracker::
getUpwardPressure(const MachineInstr * MI,std::vector<unsigned> & PressureResult,std::vector<unsigned> & MaxPressureResult)1367 getUpwardPressure(const MachineInstr *MI,
1368                   std::vector<unsigned> &PressureResult,
1369                   std::vector<unsigned> &MaxPressureResult) {
1370   // Snapshot pressure.
1371   PressureResult = CurrSetPressure;
1372   MaxPressureResult = P.MaxSetPressure;
1373 
1374   bumpUpwardPressure(MI);
1375 
1376   // Current pressure becomes the result. Restore current pressure.
1377   P.MaxSetPressure.swap(MaxPressureResult);
1378   CurrSetPressure.swap(PressureResult);
1379 }
1380 
1381 /// Get the pressure of each PSet after traversing this instruction top-down.
1382 void RegPressureTracker::
getDownwardPressure(const MachineInstr * MI,std::vector<unsigned> & PressureResult,std::vector<unsigned> & MaxPressureResult)1383 getDownwardPressure(const MachineInstr *MI,
1384                     std::vector<unsigned> &PressureResult,
1385                     std::vector<unsigned> &MaxPressureResult) {
1386   // Snapshot pressure.
1387   PressureResult = CurrSetPressure;
1388   MaxPressureResult = P.MaxSetPressure;
1389 
1390   bumpDownwardPressure(MI);
1391 
1392   // Current pressure becomes the result. Restore current pressure.
1393   P.MaxSetPressure.swap(MaxPressureResult);
1394   CurrSetPressure.swap(PressureResult);
1395 }
1396