xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/FixupStatepointCallerSaved.cpp (revision d5e3895ea4fe4ef9db8823774e07b4368180a23e)
1 //===-- FixupStatepointCallerSaved.cpp - Fixup caller saved registers  ----===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
10 /// \file
/// The deopt parameters of a statepoint instruction contain values which are
/// meaningful to the runtime and must still be readable at the moment the
/// call returns. In other words, these values are "late read" by the runtime.
/// If this notion could be expressed to the register allocator, it would
/// produce the right form for us.
/// This pass exists precisely because such a late read cannot be described
/// to the register allocator, so the allocator may place the value in a
/// register that is clobbered by the call.
/// This pass forces the spill of such registers and replaces the corresponding
/// statepoint operands with references to the added spill slots.
21 ///
22 //===----------------------------------------------------------------------===//
23 
24 #include "llvm/ADT/SmallSet.h"
25 #include "llvm/ADT/Statistic.h"
26 #include "llvm/CodeGen/MachineFrameInfo.h"
27 #include "llvm/CodeGen/MachineFunctionPass.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/CodeGen/Passes.h"
30 #include "llvm/CodeGen/StackMaps.h"
31 #include "llvm/CodeGen/TargetFrameLowering.h"
32 #include "llvm/CodeGen/TargetInstrInfo.h"
33 #include "llvm/IR/Statepoint.h"
34 #include "llvm/InitializePasses.h"
35 #include "llvm/Support/Debug.h"
36 
37 using namespace llvm;
38 
39 #define DEBUG_TYPE "fixup-statepoint-caller-saved"
40 STATISTIC(NumSpilledRegisters, "Number of spilled register");
41 STATISTIC(NumSpillSlotsAllocated, "Number of spill slots allocated");
42 STATISTIC(NumSpillSlotsExtended, "Number of spill slots extended");
43 
44 static cl::opt<bool> FixupSCSExtendSlotSize(
45     "fixup-scs-extend-slot-size", cl::Hidden, cl::init(false),
46     cl::desc("Allow spill in spill slot of greater size than register size"),
47     cl::Hidden);
48 
49 namespace {
50 
51 class FixupStatepointCallerSaved : public MachineFunctionPass {
52 public:
53   static char ID;
54 
55   FixupStatepointCallerSaved() : MachineFunctionPass(ID) {
56     initializeFixupStatepointCallerSavedPass(*PassRegistry::getPassRegistry());
57   }
58 
59   void getAnalysisUsage(AnalysisUsage &AU) const override {
60     AU.setPreservesCFG();
61     MachineFunctionPass::getAnalysisUsage(AU);
62   }
63 
64   StringRef getPassName() const override {
65     return "Fixup Statepoint Caller Saved";
66   }
67 
68   bool runOnMachineFunction(MachineFunction &MF) override;
69 };
70 } // End anonymous namespace.
71 
72 char FixupStatepointCallerSaved::ID = 0;
73 char &llvm::FixupStatepointCallerSavedID = FixupStatepointCallerSaved::ID;
74 
75 INITIALIZE_PASS_BEGIN(FixupStatepointCallerSaved, DEBUG_TYPE,
76                       "Fixup Statepoint Caller Saved", false, false)
77 INITIALIZE_PASS_END(FixupStatepointCallerSaved, DEBUG_TYPE,
78                     "Fixup Statepoint Caller Saved", false, false)
79 
80 // Utility function to get size of the register.
81 static unsigned getRegisterSize(const TargetRegisterInfo &TRI, Register Reg) {
82   const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
83   return TRI.getSpillSize(*RC);
84 }
85 
86 namespace {
87 // Cache used frame indexes during statepoint re-write to re-use them in
88 // processing next statepoint instruction.
89 // Two strategies. One is to preserve the size of spill slot while another one
90 // extends the size of spill slots to reduce the number of them, causing
91 // the less total frame size. But unspill will have "implicit" any extend.
92 class FrameIndexesCache {
93 private:
94   struct FrameIndexesPerSize {
95     // List of used frame indexes during processing previous statepoints.
96     SmallVector<int, 8> Slots;
97     // Current index of un-used yet frame index.
98     unsigned Index = 0;
99   };
100   MachineFrameInfo &MFI;
101   const TargetRegisterInfo &TRI;
102   // Map size to list of frame indexes of this size. If the mode is
103   // FixupSCSExtendSlotSize then the key 0 is used to keep all frame indexes.
104   // If the size of required spill slot is greater than in a cache then the
105   // size will be increased.
106   DenseMap<unsigned, FrameIndexesPerSize> Cache;
107 
108 public:
109   FrameIndexesCache(MachineFrameInfo &MFI, const TargetRegisterInfo &TRI)
110       : MFI(MFI), TRI(TRI) {}
111   // Reset the current state of used frame indexes. After invocation of
112   // this function all frame indexes are available for allocation.
113   void reset() {
114     for (auto &It : Cache)
115       It.second.Index = 0;
116   }
117   // Get frame index to spill the register.
118   int getFrameIndex(Register Reg) {
119     unsigned Size = getRegisterSize(TRI, Reg);
120     // In FixupSCSExtendSlotSize mode the bucket with 0 index is used
121     // for all sizes.
122     unsigned Bucket = FixupSCSExtendSlotSize ? 0 : Size;
123     FrameIndexesPerSize &Line = Cache[Bucket];
124     if (Line.Index < Line.Slots.size()) {
125       int FI = Line.Slots[Line.Index++];
126       // If all sizes are kept together we probably need to extend the
127       // spill slot size.
128       if (MFI.getObjectSize(FI) < Size) {
129         MFI.setObjectSize(FI, Size);
130         MFI.setObjectAlignment(FI, Align(Size));
131         NumSpillSlotsExtended++;
132       }
133       return FI;
134     }
135     int FI = MFI.CreateSpillStackObject(Size, Align(Size));
136     NumSpillSlotsAllocated++;
137     Line.Slots.push_back(FI);
138     ++Line.Index;
139     return FI;
140   }
141   // Sort all registers to spill in descendent order. In the
142   // FixupSCSExtendSlotSize mode it will minimize the total frame size.
143   // In non FixupSCSExtendSlotSize mode we can skip this step.
144   void sortRegisters(SmallVectorImpl<Register> &Regs) {
145     if (!FixupSCSExtendSlotSize)
146       return;
147     llvm::sort(Regs.begin(), Regs.end(), [&](Register &A, Register &B) {
148       return getRegisterSize(TRI, A) > getRegisterSize(TRI, B);
149     });
150   }
151 };
152 
153 // Describes the state of the current processing statepoint instruction.
154 class StatepointState {
155 private:
156   // statepoint instruction.
157   MachineInstr &MI;
158   MachineFunction &MF;
159   const TargetRegisterInfo &TRI;
160   const TargetInstrInfo &TII;
161   MachineFrameInfo &MFI;
162   // Mask with callee saved registers.
163   const uint32_t *Mask;
164   // Cache of frame indexes used on previous instruction processing.
165   FrameIndexesCache &CacheFI;
166   // Operands with physical registers requiring spilling.
167   SmallVector<unsigned, 8> OpsToSpill;
168   // Set of register to spill.
169   SmallVector<Register, 8> RegsToSpill;
170   // Map Register to Frame Slot index.
171   DenseMap<Register, int> RegToSlotIdx;
172 
173 public:
174   StatepointState(MachineInstr &MI, const uint32_t *Mask,
175                   FrameIndexesCache &CacheFI)
176       : MI(MI), MF(*MI.getMF()), TRI(*MF.getSubtarget().getRegisterInfo()),
177         TII(*MF.getSubtarget().getInstrInfo()), MFI(MF.getFrameInfo()),
178         Mask(Mask), CacheFI(CacheFI) {}
179   // Return true if register is callee saved.
180   bool isCalleeSaved(Register Reg) { return (Mask[Reg / 32] >> Reg % 32) & 1; }
181   // Iterates over statepoint meta args to find caller saver registers.
182   // Also cache the size of found registers.
183   // Returns true if caller save registers found.
184   bool findRegistersToSpill() {
185     SmallSet<Register, 8> VisitedRegs;
186     for (unsigned Idx = StatepointOpers(&MI).getVarIdx(),
187                   EndIdx = MI.getNumOperands();
188          Idx < EndIdx; ++Idx) {
189       MachineOperand &MO = MI.getOperand(Idx);
190       if (!MO.isReg() || MO.isImplicit())
191         continue;
192       Register Reg = MO.getReg();
193       assert(Reg.isPhysical() && "Only physical regs are expected");
194       if (isCalleeSaved(Reg))
195         continue;
196       if (VisitedRegs.insert(Reg).second)
197         RegsToSpill.push_back(Reg);
198       OpsToSpill.push_back(Idx);
199     }
200     CacheFI.sortRegisters(RegsToSpill);
201     return !RegsToSpill.empty();
202   }
203   // Spill all caller saved registers right before statepoint instruction.
204   // Remember frame index where register is spilled.
205   void spillRegisters() {
206     for (Register Reg : RegsToSpill) {
207       int FI = CacheFI.getFrameIndex(Reg);
208       const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
209       TII.storeRegToStackSlot(*MI.getParent(), MI, Reg, true /*is_Kill*/, FI,
210                               RC, &TRI);
211       NumSpilledRegisters++;
212       RegToSlotIdx[Reg] = FI;
213     }
214   }
215   // Re-write statepoint machine instruction to replace caller saved operands
216   // with indirect memory location (frame index).
217   void rewriteStatepoint() {
218     MachineInstr *NewMI =
219         MF.CreateMachineInstr(TII.get(MI.getOpcode()), MI.getDebugLoc(), true);
220     MachineInstrBuilder MIB(MF, NewMI);
221 
222     // Add End marker.
223     OpsToSpill.push_back(MI.getNumOperands());
224     unsigned CurOpIdx = 0;
225 
226     for (unsigned I = 0; I < MI.getNumOperands(); ++I) {
227       MachineOperand &MO = MI.getOperand(I);
228       if (I == OpsToSpill[CurOpIdx]) {
229         int FI = RegToSlotIdx[MO.getReg()];
230         MIB.addImm(StackMaps::IndirectMemRefOp);
231         MIB.addImm(getRegisterSize(TRI, MO.getReg()));
232         assert(MO.isReg() && "Should be register");
233         assert(MO.getReg().isPhysical() && "Should be physical register");
234         MIB.addFrameIndex(FI);
235         MIB.addImm(0);
236         ++CurOpIdx;
237       } else
238         MIB.add(MO);
239     }
240     assert(CurOpIdx == (OpsToSpill.size() - 1) && "Not all operands processed");
241     // Add mem operands.
242     NewMI->setMemRefs(MF, MI.memoperands());
243     for (auto It : RegToSlotIdx) {
244       int FrameIndex = It.second;
245       auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
246       auto *MMO = MF.getMachineMemOperand(PtrInfo, MachineMemOperand::MOLoad,
247                                           getRegisterSize(TRI, It.first),
248                                           MFI.getObjectAlign(FrameIndex));
249       NewMI->addMemOperand(MF, MMO);
250     }
251     // Insert new statepoint and erase old one.
252     MI.getParent()->insert(MI, NewMI);
253     MI.eraseFromParent();
254   }
255 };
256 
257 class StatepointProcessor {
258 private:
259   MachineFunction &MF;
260   const TargetRegisterInfo &TRI;
261   FrameIndexesCache CacheFI;
262 
263 public:
264   StatepointProcessor(MachineFunction &MF)
265       : MF(MF), TRI(*MF.getSubtarget().getRegisterInfo()),
266         CacheFI(MF.getFrameInfo(), TRI) {}
267 
268   bool process(MachineInstr &MI) {
269     StatepointOpers SO(&MI);
270     uint64_t Flags = SO.getFlags();
271     // Do nothing for LiveIn, it supports all registers.
272     if (Flags & (uint64_t)StatepointFlags::DeoptLiveIn)
273       return false;
274     CallingConv::ID CC = SO.getCallingConv();
275     const uint32_t *Mask = TRI.getCallPreservedMask(MF, CC);
276     CacheFI.reset();
277     StatepointState SS(MI, Mask, CacheFI);
278 
279     if (!SS.findRegistersToSpill())
280       return false;
281 
282     SS.spillRegisters();
283     SS.rewriteStatepoint();
284     return true;
285   }
286 };
287 } // namespace
288 
289 bool FixupStatepointCallerSaved::runOnMachineFunction(MachineFunction &MF) {
290   if (skipFunction(MF.getFunction()))
291     return false;
292 
293   const Function &F = MF.getFunction();
294   if (!F.hasGC())
295     return false;
296 
297   SmallVector<MachineInstr *, 16> Statepoints;
298   for (MachineBasicBlock &BB : MF)
299     for (MachineInstr &I : BB)
300       if (I.getOpcode() == TargetOpcode::STATEPOINT)
301         Statepoints.push_back(&I);
302 
303   if (Statepoints.empty())
304     return false;
305 
306   bool Changed = false;
307   StatepointProcessor SPP(MF);
308   for (MachineInstr *I : Statepoints)
309     Changed |= SPP.process(*I);
310   return Changed;
311 }
312