xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/FixupStatepointCallerSaved.cpp (revision 8311bc5f17dec348749f763b82dfe2737bc53cd7)
1 //===-- FixupStatepointCallerSaved.cpp - Fixup caller saved registers  ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
/// Statepoint instructions carry, in their deopt parameters, values which are
/// meaningful to the runtime and must be readable at the moment the call
/// returns. In other words, these values are "late read" by the runtime. If
/// this notion could be expressed to the register allocator, it would produce
/// the right form for us.
/// The need for this fixup (i.e. this pass) comes specifically from the fact
/// that such a late read cannot be described to the register allocator.
/// The register allocator may therefore put the value in a register clobbered
/// by the call. This pass forces the spill of such registers and replaces the
/// corresponding statepoint operands with references to the added spill slots.
20 ///
21 //===----------------------------------------------------------------------===//
22 
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/Statistic.h"
25 #include "llvm/CodeGen/MachineFrameInfo.h"
26 #include "llvm/CodeGen/MachineFunctionPass.h"
27 #include "llvm/CodeGen/StackMaps.h"
28 #include "llvm/CodeGen/TargetInstrInfo.h"
29 #include "llvm/IR/Statepoint.h"
30 #include "llvm/InitializePasses.h"
31 #include "llvm/Support/Debug.h"
32 
33 using namespace llvm;
34 
35 #define DEBUG_TYPE "fixup-statepoint-caller-saved"
36 STATISTIC(NumSpilledRegisters, "Number of spilled register");
37 STATISTIC(NumSpillSlotsAllocated, "Number of spill slots allocated");
38 STATISTIC(NumSpillSlotsExtended, "Number of spill slots extended");
39 
40 static cl::opt<bool> FixupSCSExtendSlotSize(
41     "fixup-scs-extend-slot-size", cl::Hidden, cl::init(false),
42     cl::desc("Allow spill in spill slot of greater size than register size"),
43     cl::Hidden);
44 
45 static cl::opt<bool> PassGCPtrInCSR(
46     "fixup-allow-gcptr-in-csr", cl::Hidden, cl::init(false),
47     cl::desc("Allow passing GC Pointer arguments in callee saved registers"));
48 
49 static cl::opt<bool> EnableCopyProp(
50     "fixup-scs-enable-copy-propagation", cl::Hidden, cl::init(true),
51     cl::desc("Enable simple copy propagation during register reloading"));
52 
53 // This is purely debugging option.
54 // It may be handy for investigating statepoint spilling issues.
55 static cl::opt<unsigned> MaxStatepointsWithRegs(
56     "fixup-max-csr-statepoints", cl::Hidden,
57     cl::desc("Max number of statepoints allowed to pass GC Ptrs in registers"));
58 
59 namespace {
60 
61 class FixupStatepointCallerSaved : public MachineFunctionPass {
62 public:
63   static char ID;
64 
65   FixupStatepointCallerSaved() : MachineFunctionPass(ID) {
66     initializeFixupStatepointCallerSavedPass(*PassRegistry::getPassRegistry());
67   }
68 
69   void getAnalysisUsage(AnalysisUsage &AU) const override {
70     AU.setPreservesCFG();
71     MachineFunctionPass::getAnalysisUsage(AU);
72   }
73 
74   StringRef getPassName() const override {
75     return "Fixup Statepoint Caller Saved";
76   }
77 
78   bool runOnMachineFunction(MachineFunction &MF) override;
79 };
80 
81 } // End anonymous namespace.
82 
83 char FixupStatepointCallerSaved::ID = 0;
84 char &llvm::FixupStatepointCallerSavedID = FixupStatepointCallerSaved::ID;
85 
86 INITIALIZE_PASS_BEGIN(FixupStatepointCallerSaved, DEBUG_TYPE,
87                       "Fixup Statepoint Caller Saved", false, false)
88 INITIALIZE_PASS_END(FixupStatepointCallerSaved, DEBUG_TYPE,
89                     "Fixup Statepoint Caller Saved", false, false)
90 
91 // Utility function to get size of the register.
92 static unsigned getRegisterSize(const TargetRegisterInfo &TRI, Register Reg) {
93   const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
94   return TRI.getSpillSize(*RC);
95 }
96 
97 // Try to eliminate redundant copy to register which we're going to
98 // spill, i.e. try to change:
99 //    X = COPY Y
100 //    SPILL X
101 //  to
102 //    SPILL Y
103 //  If there are no uses of X between copy and STATEPOINT, that COPY
104 //  may be eliminated.
105 //  Reg - register we're about to spill
106 //  RI - On entry points to statepoint.
107 //       On successful copy propagation set to new spill point.
108 //  IsKill - set to true if COPY is Kill (there are no uses of Y)
109 //  Returns either found source copy register or original one.
110 static Register performCopyPropagation(Register Reg,
111                                        MachineBasicBlock::iterator &RI,
112                                        bool &IsKill, const TargetInstrInfo &TII,
113                                        const TargetRegisterInfo &TRI) {
114   // First check if statepoint itself uses Reg in non-meta operands.
115   int Idx = RI->findRegisterUseOperandIdx(Reg, false, &TRI);
116   if (Idx >= 0 && (unsigned)Idx < StatepointOpers(&*RI).getNumDeoptArgsIdx()) {
117     IsKill = false;
118     return Reg;
119   }
120 
121   if (!EnableCopyProp)
122     return Reg;
123 
124   MachineBasicBlock *MBB = RI->getParent();
125   MachineBasicBlock::reverse_iterator E = MBB->rend();
126   MachineInstr *Def = nullptr, *Use = nullptr;
127   for (auto It = ++(RI.getReverse()); It != E; ++It) {
128     if (It->readsRegister(Reg, &TRI) && !Use)
129       Use = &*It;
130     if (It->modifiesRegister(Reg, &TRI)) {
131       Def = &*It;
132       break;
133     }
134   }
135 
136   if (!Def)
137     return Reg;
138 
139   auto DestSrc = TII.isCopyInstr(*Def);
140   if (!DestSrc || DestSrc->Destination->getReg() != Reg)
141     return Reg;
142 
143   Register SrcReg = DestSrc->Source->getReg();
144 
145   if (getRegisterSize(TRI, Reg) != getRegisterSize(TRI, SrcReg))
146     return Reg;
147 
148   LLVM_DEBUG(dbgs() << "spillRegisters: perform copy propagation "
149                     << printReg(Reg, &TRI) << " -> " << printReg(SrcReg, &TRI)
150                     << "\n");
151 
152   // Insert spill immediately after Def
153   RI = ++MachineBasicBlock::iterator(Def);
154   IsKill = DestSrc->Source->isKill();
155 
156   if (!Use) {
157     // There are no uses of original register between COPY and STATEPOINT.
158     // There can't be any after STATEPOINT, so we can eliminate Def.
159     LLVM_DEBUG(dbgs() << "spillRegisters: removing dead copy " << *Def);
160     Def->eraseFromParent();
161   } else if (IsKill) {
162     // COPY will remain in place, spill will be inserted *after* it, so it is
163     // not a kill of source anymore.
164     const_cast<MachineOperand *>(DestSrc->Source)->setIsKill(false);
165   }
166 
167   return SrcReg;
168 }
169 
170 namespace {
171 // Pair {Register, FrameIndex}
172 using RegSlotPair = std::pair<Register, int>;
173 
174 // Keeps track of what reloads were inserted in MBB.
175 class RegReloadCache {
176   using ReloadSet = SmallSet<RegSlotPair, 8>;
177   DenseMap<const MachineBasicBlock *, ReloadSet> Reloads;
178 
179 public:
180   RegReloadCache() = default;
181 
182   // Record reload of Reg from FI in block MBB
183   void recordReload(Register Reg, int FI, const MachineBasicBlock *MBB) {
184     RegSlotPair RSP(Reg, FI);
185     auto Res = Reloads[MBB].insert(RSP);
186     (void)Res;
187     assert(Res.second && "reload already exists");
188   }
189 
190   // Does basic block MBB contains reload of Reg from FI?
191   bool hasReload(Register Reg, int FI, const MachineBasicBlock *MBB) {
192     RegSlotPair RSP(Reg, FI);
193     return Reloads.count(MBB) && Reloads[MBB].count(RSP);
194   }
195 };
196 
197 // Cache used frame indexes during statepoint re-write to re-use them in
198 // processing next statepoint instruction.
199 // Two strategies. One is to preserve the size of spill slot while another one
200 // extends the size of spill slots to reduce the number of them, causing
201 // the less total frame size. But unspill will have "implicit" any extend.
202 class FrameIndexesCache {
203 private:
204   struct FrameIndexesPerSize {
205     // List of used frame indexes during processing previous statepoints.
206     SmallVector<int, 8> Slots;
207     // Current index of un-used yet frame index.
208     unsigned Index = 0;
209   };
210   MachineFrameInfo &MFI;
211   const TargetRegisterInfo &TRI;
212   // Map size to list of frame indexes of this size. If the mode is
213   // FixupSCSExtendSlotSize then the key 0 is used to keep all frame indexes.
214   // If the size of required spill slot is greater than in a cache then the
215   // size will be increased.
216   DenseMap<unsigned, FrameIndexesPerSize> Cache;
217 
218   // Keeps track of slots reserved for the shared landing pad processing.
219   // Initialized from GlobalIndices for the current EHPad.
220   SmallSet<int, 8> ReservedSlots;
221 
222   // Landing pad can be destination of several statepoints. Every register
223   // defined by such statepoints must be spilled to the same stack slot.
224   // This map keeps that information.
225   DenseMap<const MachineBasicBlock *, SmallVector<RegSlotPair, 8>>
226       GlobalIndices;
227 
228   FrameIndexesPerSize &getCacheBucket(unsigned Size) {
229     // In FixupSCSExtendSlotSize mode the bucket with 0 index is used
230     // for all sizes.
231     return Cache[FixupSCSExtendSlotSize ? 0 : Size];
232   }
233 
234 public:
235   FrameIndexesCache(MachineFrameInfo &MFI, const TargetRegisterInfo &TRI)
236       : MFI(MFI), TRI(TRI) {}
237   // Reset the current state of used frame indexes. After invocation of
238   // this function all frame indexes are available for allocation with
239   // the exception of slots reserved for landing pad processing (if any).
240   void reset(const MachineBasicBlock *EHPad) {
241     for (auto &It : Cache)
242       It.second.Index = 0;
243 
244     ReservedSlots.clear();
245     if (EHPad && GlobalIndices.count(EHPad))
246       for (auto &RSP : GlobalIndices[EHPad])
247         ReservedSlots.insert(RSP.second);
248   }
249 
250   // Get frame index to spill the register.
251   int getFrameIndex(Register Reg, MachineBasicBlock *EHPad) {
252     // Check if slot for Reg is already reserved at EHPad.
253     auto It = GlobalIndices.find(EHPad);
254     if (It != GlobalIndices.end()) {
255       auto &Vec = It->second;
256       auto Idx = llvm::find_if(
257           Vec, [Reg](RegSlotPair &RSP) { return Reg == RSP.first; });
258       if (Idx != Vec.end()) {
259         int FI = Idx->second;
260         LLVM_DEBUG(dbgs() << "Found global FI " << FI << " for register "
261                           << printReg(Reg, &TRI) << " at "
262                           << printMBBReference(*EHPad) << "\n");
263         assert(ReservedSlots.count(FI) && "using unreserved slot");
264         return FI;
265       }
266     }
267 
268     unsigned Size = getRegisterSize(TRI, Reg);
269     FrameIndexesPerSize &Line = getCacheBucket(Size);
270     while (Line.Index < Line.Slots.size()) {
271       int FI = Line.Slots[Line.Index++];
272       if (ReservedSlots.count(FI))
273         continue;
274       // If all sizes are kept together we probably need to extend the
275       // spill slot size.
276       if (MFI.getObjectSize(FI) < Size) {
277         MFI.setObjectSize(FI, Size);
278         MFI.setObjectAlignment(FI, Align(Size));
279         NumSpillSlotsExtended++;
280       }
281       return FI;
282     }
283     int FI = MFI.CreateSpillStackObject(Size, Align(Size));
284     NumSpillSlotsAllocated++;
285     Line.Slots.push_back(FI);
286     ++Line.Index;
287 
288     // Remember assignment {Reg, FI} for EHPad
289     if (EHPad) {
290       GlobalIndices[EHPad].push_back(std::make_pair(Reg, FI));
291       LLVM_DEBUG(dbgs() << "Reserved FI " << FI << " for spilling reg "
292                         << printReg(Reg, &TRI) << " at landing pad "
293                         << printMBBReference(*EHPad) << "\n");
294     }
295 
296     return FI;
297   }
298 
299   // Sort all registers to spill in descendent order. In the
300   // FixupSCSExtendSlotSize mode it will minimize the total frame size.
301   // In non FixupSCSExtendSlotSize mode we can skip this step.
302   void sortRegisters(SmallVectorImpl<Register> &Regs) {
303     if (!FixupSCSExtendSlotSize)
304       return;
305     llvm::sort(Regs, [&](Register &A, Register &B) {
306       return getRegisterSize(TRI, A) > getRegisterSize(TRI, B);
307     });
308   }
309 };
310 
311 // Describes the state of the current processing statepoint instruction.
312 class StatepointState {
313 private:
314   // statepoint instruction.
315   MachineInstr &MI;
316   MachineFunction &MF;
317   // If non-null then statepoint is invoke, and this points to the landing pad.
318   MachineBasicBlock *EHPad;
319   const TargetRegisterInfo &TRI;
320   const TargetInstrInfo &TII;
321   MachineFrameInfo &MFI;
322   // Mask with callee saved registers.
323   const uint32_t *Mask;
324   // Cache of frame indexes used on previous instruction processing.
325   FrameIndexesCache &CacheFI;
326   bool AllowGCPtrInCSR;
327   // Operands with physical registers requiring spilling.
328   SmallVector<unsigned, 8> OpsToSpill;
329   // Set of register to spill.
330   SmallVector<Register, 8> RegsToSpill;
331   // Set of registers to reload after statepoint.
332   SmallVector<Register, 8> RegsToReload;
333   // Map Register to Frame Slot index.
334   DenseMap<Register, int> RegToSlotIdx;
335 
336 public:
337   StatepointState(MachineInstr &MI, const uint32_t *Mask,
338                   FrameIndexesCache &CacheFI, bool AllowGCPtrInCSR)
339       : MI(MI), MF(*MI.getMF()), TRI(*MF.getSubtarget().getRegisterInfo()),
340         TII(*MF.getSubtarget().getInstrInfo()), MFI(MF.getFrameInfo()),
341         Mask(Mask), CacheFI(CacheFI), AllowGCPtrInCSR(AllowGCPtrInCSR) {
342 
343     // Find statepoint's landing pad, if any.
344     EHPad = nullptr;
345     MachineBasicBlock *MBB = MI.getParent();
346     // Invoke statepoint must be last one in block.
347     bool Last = std::none_of(++MI.getIterator(), MBB->end().getInstrIterator(),
348                              [](MachineInstr &I) {
349                                return I.getOpcode() == TargetOpcode::STATEPOINT;
350                              });
351 
352     if (!Last)
353       return;
354 
355     auto IsEHPad = [](MachineBasicBlock *B) { return B->isEHPad(); };
356 
357     assert(llvm::count_if(MBB->successors(), IsEHPad) < 2 && "multiple EHPads");
358 
359     auto It = llvm::find_if(MBB->successors(), IsEHPad);
360     if (It != MBB->succ_end())
361       EHPad = *It;
362   }
363 
364   MachineBasicBlock *getEHPad() const { return EHPad; }
365 
366   // Return true if register is callee saved.
367   bool isCalleeSaved(Register Reg) { return (Mask[Reg / 32] >> Reg % 32) & 1; }
368 
369   // Iterates over statepoint meta args to find caller saver registers.
370   // Also cache the size of found registers.
371   // Returns true if caller save registers found.
372   bool findRegistersToSpill() {
373     SmallSet<Register, 8> GCRegs;
374     // All GC pointer operands assigned to registers produce new value.
375     // Since they're tied to their defs, it is enough to collect def registers.
376     for (const auto &Def : MI.defs())
377       GCRegs.insert(Def.getReg());
378 
379     SmallSet<Register, 8> VisitedRegs;
380     for (unsigned Idx = StatepointOpers(&MI).getVarIdx(),
381                   EndIdx = MI.getNumOperands();
382          Idx < EndIdx; ++Idx) {
383       MachineOperand &MO = MI.getOperand(Idx);
384       // Leave `undef` operands as is, StackMaps will rewrite them
385       // into a constant.
386       if (!MO.isReg() || MO.isImplicit() || MO.isUndef())
387         continue;
388       Register Reg = MO.getReg();
389       assert(Reg.isPhysical() && "Only physical regs are expected");
390 
391       if (isCalleeSaved(Reg) && (AllowGCPtrInCSR || !GCRegs.contains(Reg)))
392         continue;
393 
394       LLVM_DEBUG(dbgs() << "Will spill " << printReg(Reg, &TRI) << " at index "
395                         << Idx << "\n");
396 
397       if (VisitedRegs.insert(Reg).second)
398         RegsToSpill.push_back(Reg);
399       OpsToSpill.push_back(Idx);
400     }
401     CacheFI.sortRegisters(RegsToSpill);
402     return !RegsToSpill.empty();
403   }
404 
405   // Spill all caller saved registers right before statepoint instruction.
406   // Remember frame index where register is spilled.
407   void spillRegisters() {
408     for (Register Reg : RegsToSpill) {
409       int FI = CacheFI.getFrameIndex(Reg, EHPad);
410 
411       NumSpilledRegisters++;
412       RegToSlotIdx[Reg] = FI;
413 
414       LLVM_DEBUG(dbgs() << "Spilling " << printReg(Reg, &TRI) << " to FI " << FI
415                         << "\n");
416 
417       // Perform trivial copy propagation
418       bool IsKill = true;
419       MachineBasicBlock::iterator InsertBefore(MI);
420       Reg = performCopyPropagation(Reg, InsertBefore, IsKill, TII, TRI);
421       const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
422 
423       LLVM_DEBUG(dbgs() << "Insert spill before " << *InsertBefore);
424       TII.storeRegToStackSlot(*MI.getParent(), InsertBefore, Reg, IsKill, FI,
425                               RC, &TRI, Register());
426     }
427   }
428 
429   void insertReloadBefore(unsigned Reg, MachineBasicBlock::iterator It,
430                           MachineBasicBlock *MBB) {
431     const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
432     int FI = RegToSlotIdx[Reg];
433     if (It != MBB->end()) {
434       TII.loadRegFromStackSlot(*MBB, It, Reg, FI, RC, &TRI, Register());
435       return;
436     }
437 
438     // To insert reload at the end of MBB, insert it before last instruction
439     // and then swap them.
440     assert(!MBB->empty() && "Empty block");
441     --It;
442     TII.loadRegFromStackSlot(*MBB, It, Reg, FI, RC, &TRI, Register());
443     MachineInstr *Reload = It->getPrevNode();
444     int Dummy = 0;
445     (void)Dummy;
446     assert(TII.isLoadFromStackSlot(*Reload, Dummy) == Reg);
447     assert(Dummy == FI);
448     MBB->remove(Reload);
449     MBB->insertAfter(It, Reload);
450   }
451 
452   // Insert reloads of (relocated) registers spilled in statepoint.
453   void insertReloads(MachineInstr *NewStatepoint, RegReloadCache &RC) {
454     MachineBasicBlock *MBB = NewStatepoint->getParent();
455     auto InsertPoint = std::next(NewStatepoint->getIterator());
456 
457     for (auto Reg : RegsToReload) {
458       insertReloadBefore(Reg, InsertPoint, MBB);
459       LLVM_DEBUG(dbgs() << "Reloading " << printReg(Reg, &TRI) << " from FI "
460                         << RegToSlotIdx[Reg] << " after statepoint\n");
461 
462       if (EHPad && !RC.hasReload(Reg, RegToSlotIdx[Reg], EHPad)) {
463         RC.recordReload(Reg, RegToSlotIdx[Reg], EHPad);
464         auto EHPadInsertPoint = EHPad->SkipPHIsLabelsAndDebug(EHPad->begin());
465         insertReloadBefore(Reg, EHPadInsertPoint, EHPad);
466         LLVM_DEBUG(dbgs() << "...also reload at EHPad "
467                           << printMBBReference(*EHPad) << "\n");
468       }
469     }
470   }
471 
472   // Re-write statepoint machine instruction to replace caller saved operands
473   // with indirect memory location (frame index).
474   MachineInstr *rewriteStatepoint() {
475     MachineInstr *NewMI =
476         MF.CreateMachineInstr(TII.get(MI.getOpcode()), MI.getDebugLoc(), true);
477     MachineInstrBuilder MIB(MF, NewMI);
478 
479     unsigned NumOps = MI.getNumOperands();
480 
481     // New indices for the remaining defs.
482     SmallVector<unsigned, 8> NewIndices;
483     unsigned NumDefs = MI.getNumDefs();
484     for (unsigned I = 0; I < NumDefs; ++I) {
485       MachineOperand &DefMO = MI.getOperand(I);
486       assert(DefMO.isReg() && DefMO.isDef() && "Expected Reg Def operand");
487       Register Reg = DefMO.getReg();
488       assert(DefMO.isTied() && "Def is expected to be tied");
489       // We skipped undef uses and did not spill them, so we should not
490       // proceed with defs here.
491       if (MI.getOperand(MI.findTiedOperandIdx(I)).isUndef()) {
492         if (AllowGCPtrInCSR) {
493           NewIndices.push_back(NewMI->getNumOperands());
494           MIB.addReg(Reg, RegState::Define);
495         }
496         continue;
497       }
498       if (!AllowGCPtrInCSR) {
499         assert(is_contained(RegsToSpill, Reg));
500         RegsToReload.push_back(Reg);
501       } else {
502         if (isCalleeSaved(Reg)) {
503           NewIndices.push_back(NewMI->getNumOperands());
504           MIB.addReg(Reg, RegState::Define);
505         } else {
506           NewIndices.push_back(NumOps);
507           RegsToReload.push_back(Reg);
508         }
509       }
510     }
511 
512     // Add End marker.
513     OpsToSpill.push_back(MI.getNumOperands());
514     unsigned CurOpIdx = 0;
515 
516     for (unsigned I = NumDefs; I < MI.getNumOperands(); ++I) {
517       MachineOperand &MO = MI.getOperand(I);
518       if (I == OpsToSpill[CurOpIdx]) {
519         int FI = RegToSlotIdx[MO.getReg()];
520         MIB.addImm(StackMaps::IndirectMemRefOp);
521         MIB.addImm(getRegisterSize(TRI, MO.getReg()));
522         assert(MO.isReg() && "Should be register");
523         assert(MO.getReg().isPhysical() && "Should be physical register");
524         MIB.addFrameIndex(FI);
525         MIB.addImm(0);
526         ++CurOpIdx;
527       } else {
528         MIB.add(MO);
529         unsigned OldDef;
530         if (AllowGCPtrInCSR && MI.isRegTiedToDefOperand(I, &OldDef)) {
531           assert(OldDef < NumDefs);
532           assert(NewIndices[OldDef] < NumOps);
533           MIB->tieOperands(NewIndices[OldDef], MIB->getNumOperands() - 1);
534         }
535       }
536     }
537     assert(CurOpIdx == (OpsToSpill.size() - 1) && "Not all operands processed");
538     // Add mem operands.
539     NewMI->setMemRefs(MF, MI.memoperands());
540     for (auto It : RegToSlotIdx) {
541       Register R = It.first;
542       int FrameIndex = It.second;
543       auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
544       MachineMemOperand::Flags Flags = MachineMemOperand::MOLoad;
545       if (is_contained(RegsToReload, R))
546         Flags |= MachineMemOperand::MOStore;
547       auto *MMO =
548           MF.getMachineMemOperand(PtrInfo, Flags, getRegisterSize(TRI, R),
549                                   MFI.getObjectAlign(FrameIndex));
550       NewMI->addMemOperand(MF, MMO);
551     }
552 
553     // Insert new statepoint and erase old one.
554     MI.getParent()->insert(MI, NewMI);
555 
556     LLVM_DEBUG(dbgs() << "rewritten statepoint to : " << *NewMI << "\n");
557     MI.eraseFromParent();
558     return NewMI;
559   }
560 };
561 
562 class StatepointProcessor {
563 private:
564   MachineFunction &MF;
565   const TargetRegisterInfo &TRI;
566   FrameIndexesCache CacheFI;
567   RegReloadCache ReloadCache;
568 
569 public:
570   StatepointProcessor(MachineFunction &MF)
571       : MF(MF), TRI(*MF.getSubtarget().getRegisterInfo()),
572         CacheFI(MF.getFrameInfo(), TRI) {}
573 
574   bool process(MachineInstr &MI, bool AllowGCPtrInCSR) {
575     StatepointOpers SO(&MI);
576     uint64_t Flags = SO.getFlags();
577     // Do nothing for LiveIn, it supports all registers.
578     if (Flags & (uint64_t)StatepointFlags::DeoptLiveIn)
579       return false;
580     LLVM_DEBUG(dbgs() << "\nMBB " << MI.getParent()->getNumber() << " "
581                       << MI.getParent()->getName() << " : process statepoint "
582                       << MI);
583     CallingConv::ID CC = SO.getCallingConv();
584     const uint32_t *Mask = TRI.getCallPreservedMask(MF, CC);
585     StatepointState SS(MI, Mask, CacheFI, AllowGCPtrInCSR);
586     CacheFI.reset(SS.getEHPad());
587 
588     if (!SS.findRegistersToSpill())
589       return false;
590 
591     SS.spillRegisters();
592     auto *NewStatepoint = SS.rewriteStatepoint();
593     SS.insertReloads(NewStatepoint, ReloadCache);
594     return true;
595   }
596 };
597 } // namespace
598 
599 bool FixupStatepointCallerSaved::runOnMachineFunction(MachineFunction &MF) {
600   if (skipFunction(MF.getFunction()))
601     return false;
602 
603   const Function &F = MF.getFunction();
604   if (!F.hasGC())
605     return false;
606 
607   SmallVector<MachineInstr *, 16> Statepoints;
608   for (MachineBasicBlock &BB : MF)
609     for (MachineInstr &I : BB)
610       if (I.getOpcode() == TargetOpcode::STATEPOINT)
611         Statepoints.push_back(&I);
612 
613   if (Statepoints.empty())
614     return false;
615 
616   bool Changed = false;
617   StatepointProcessor SPP(MF);
618   unsigned NumStatepoints = 0;
619   bool AllowGCPtrInCSR = PassGCPtrInCSR;
620   for (MachineInstr *I : Statepoints) {
621     ++NumStatepoints;
622     if (MaxStatepointsWithRegs.getNumOccurrences() &&
623         NumStatepoints >= MaxStatepointsWithRegs)
624       AllowGCPtrInCSR = false;
625     Changed |= SPP.process(*I, AllowGCPtrInCSR);
626   }
627   return Changed;
628 }
629