xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/FixupStatepointCallerSaved.cpp (revision f126d349810fdb512c0b01e101342d430b947488)
1 //===-- FixupStatepointCallerSaved.cpp - Fixup caller saved registers  ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
/// The deopt parameters of a statepoint instruction contain values which are
/// meaningful to the runtime and must be readable at the moment the call
/// returns. In other words, these values are "late read" by the runtime. If
/// we could express this notion to the register allocator, it would produce
/// the right form for us.
/// This fixup pass exists precisely because such a late read cannot be
/// described to the register allocator: it may put the value in a register
/// that is clobbered by the call.
/// This pass forces the spill of such registers and replaces the
/// corresponding statepoint operands with the added spill slots.
20 ///
21 //===----------------------------------------------------------------------===//
22 
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/Statistic.h"
25 #include "llvm/CodeGen/MachineFrameInfo.h"
26 #include "llvm/CodeGen/MachineFunctionPass.h"
27 #include "llvm/CodeGen/MachineRegisterInfo.h"
28 #include "llvm/CodeGen/Passes.h"
29 #include "llvm/CodeGen/StackMaps.h"
30 #include "llvm/CodeGen/TargetFrameLowering.h"
31 #include "llvm/CodeGen/TargetInstrInfo.h"
32 #include "llvm/IR/Statepoint.h"
33 #include "llvm/InitializePasses.h"
34 #include "llvm/Support/Debug.h"
35 
36 using namespace llvm;
37 
38 #define DEBUG_TYPE "fixup-statepoint-caller-saved"
39 STATISTIC(NumSpilledRegisters, "Number of spilled register");
40 STATISTIC(NumSpillSlotsAllocated, "Number of spill slots allocated");
41 STATISTIC(NumSpillSlotsExtended, "Number of spill slots extended");
42 
43 static cl::opt<bool> FixupSCSExtendSlotSize(
44     "fixup-scs-extend-slot-size", cl::Hidden, cl::init(false),
45     cl::desc("Allow spill in spill slot of greater size than register size"),
46     cl::Hidden);
47 
48 static cl::opt<bool> PassGCPtrInCSR(
49     "fixup-allow-gcptr-in-csr", cl::Hidden, cl::init(false),
50     cl::desc("Allow passing GC Pointer arguments in callee saved registers"));
51 
52 static cl::opt<bool> EnableCopyProp(
53     "fixup-scs-enable-copy-propagation", cl::Hidden, cl::init(true),
54     cl::desc("Enable simple copy propagation during register reloading"));
55 
56 // This is purely debugging option.
57 // It may be handy for investigating statepoint spilling issues.
58 static cl::opt<unsigned> MaxStatepointsWithRegs(
59     "fixup-max-csr-statepoints", cl::Hidden,
60     cl::desc("Max number of statepoints allowed to pass GC Ptrs in registers"));
61 
62 namespace {
63 
64 class FixupStatepointCallerSaved : public MachineFunctionPass {
65 public:
66   static char ID;
67 
68   FixupStatepointCallerSaved() : MachineFunctionPass(ID) {
69     initializeFixupStatepointCallerSavedPass(*PassRegistry::getPassRegistry());
70   }
71 
72   void getAnalysisUsage(AnalysisUsage &AU) const override {
73     AU.setPreservesCFG();
74     MachineFunctionPass::getAnalysisUsage(AU);
75   }
76 
77   StringRef getPassName() const override {
78     return "Fixup Statepoint Caller Saved";
79   }
80 
81   bool runOnMachineFunction(MachineFunction &MF) override;
82 };
83 
84 } // End anonymous namespace.
85 
86 char FixupStatepointCallerSaved::ID = 0;
87 char &llvm::FixupStatepointCallerSavedID = FixupStatepointCallerSaved::ID;
88 
89 INITIALIZE_PASS_BEGIN(FixupStatepointCallerSaved, DEBUG_TYPE,
90                       "Fixup Statepoint Caller Saved", false, false)
91 INITIALIZE_PASS_END(FixupStatepointCallerSaved, DEBUG_TYPE,
92                     "Fixup Statepoint Caller Saved", false, false)
93 
94 // Utility function to get size of the register.
95 static unsigned getRegisterSize(const TargetRegisterInfo &TRI, Register Reg) {
96   const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
97   return TRI.getSpillSize(*RC);
98 }
99 
100 // Try to eliminate redundant copy to register which we're going to
101 // spill, i.e. try to change:
102 //    X = COPY Y
103 //    SPILL X
104 //  to
105 //    SPILL Y
106 //  If there are no uses of X between copy and STATEPOINT, that COPY
107 //  may be eliminated.
108 //  Reg - register we're about to spill
109 //  RI - On entry points to statepoint.
110 //       On successful copy propagation set to new spill point.
111 //  IsKill - set to true if COPY is Kill (there are no uses of Y)
112 //  Returns either found source copy register or original one.
113 static Register performCopyPropagation(Register Reg,
114                                        MachineBasicBlock::iterator &RI,
115                                        bool &IsKill, const TargetInstrInfo &TII,
116                                        const TargetRegisterInfo &TRI) {
117   // First check if statepoint itself uses Reg in non-meta operands.
118   int Idx = RI->findRegisterUseOperandIdx(Reg, false, &TRI);
119   if (Idx >= 0 && (unsigned)Idx < StatepointOpers(&*RI).getNumDeoptArgsIdx()) {
120     IsKill = false;
121     return Reg;
122   }
123 
124   if (!EnableCopyProp)
125     return Reg;
126 
127   MachineBasicBlock *MBB = RI->getParent();
128   MachineBasicBlock::reverse_iterator E = MBB->rend();
129   MachineInstr *Def = nullptr, *Use = nullptr;
130   for (auto It = ++(RI.getReverse()); It != E; ++It) {
131     if (It->readsRegister(Reg, &TRI) && !Use)
132       Use = &*It;
133     if (It->modifiesRegister(Reg, &TRI)) {
134       Def = &*It;
135       break;
136     }
137   }
138 
139   if (!Def)
140     return Reg;
141 
142   auto DestSrc = TII.isCopyInstr(*Def);
143   if (!DestSrc || DestSrc->Destination->getReg() != Reg)
144     return Reg;
145 
146   Register SrcReg = DestSrc->Source->getReg();
147 
148   if (getRegisterSize(TRI, Reg) != getRegisterSize(TRI, SrcReg))
149     return Reg;
150 
151   LLVM_DEBUG(dbgs() << "spillRegisters: perform copy propagation "
152                     << printReg(Reg, &TRI) << " -> " << printReg(SrcReg, &TRI)
153                     << "\n");
154 
155   // Insert spill immediately after Def
156   RI = ++MachineBasicBlock::iterator(Def);
157   IsKill = DestSrc->Source->isKill();
158 
159   // There are no uses of original register between COPY and STATEPOINT.
160   // There can't be any after STATEPOINT, so we can eliminate Def.
161   if (!Use) {
162     LLVM_DEBUG(dbgs() << "spillRegisters: removing dead copy " << *Def);
163     Def->eraseFromParent();
164   }
165   return SrcReg;
166 }
167 
168 namespace {
169 // Pair {Register, FrameIndex}
170 using RegSlotPair = std::pair<Register, int>;
171 
172 // Keeps track of what reloads were inserted in MBB.
173 class RegReloadCache {
174   using ReloadSet = SmallSet<RegSlotPair, 8>;
175   DenseMap<const MachineBasicBlock *, ReloadSet> Reloads;
176 
177 public:
178   RegReloadCache() = default;
179 
180   // Record reload of Reg from FI in block MBB
181   void recordReload(Register Reg, int FI, const MachineBasicBlock *MBB) {
182     RegSlotPair RSP(Reg, FI);
183     auto Res = Reloads[MBB].insert(RSP);
184     (void)Res;
185     assert(Res.second && "reload already exists");
186   }
187 
188   // Does basic block MBB contains reload of Reg from FI?
189   bool hasReload(Register Reg, int FI, const MachineBasicBlock *MBB) {
190     RegSlotPair RSP(Reg, FI);
191     return Reloads.count(MBB) && Reloads[MBB].count(RSP);
192   }
193 };
194 
195 // Cache used frame indexes during statepoint re-write to re-use them in
196 // processing next statepoint instruction.
197 // Two strategies. One is to preserve the size of spill slot while another one
198 // extends the size of spill slots to reduce the number of them, causing
199 // the less total frame size. But unspill will have "implicit" any extend.
200 class FrameIndexesCache {
201 private:
202   struct FrameIndexesPerSize {
203     // List of used frame indexes during processing previous statepoints.
204     SmallVector<int, 8> Slots;
205     // Current index of un-used yet frame index.
206     unsigned Index = 0;
207   };
208   MachineFrameInfo &MFI;
209   const TargetRegisterInfo &TRI;
210   // Map size to list of frame indexes of this size. If the mode is
211   // FixupSCSExtendSlotSize then the key 0 is used to keep all frame indexes.
212   // If the size of required spill slot is greater than in a cache then the
213   // size will be increased.
214   DenseMap<unsigned, FrameIndexesPerSize> Cache;
215 
216   // Keeps track of slots reserved for the shared landing pad processing.
217   // Initialized from GlobalIndices for the current EHPad.
218   SmallSet<int, 8> ReservedSlots;
219 
220   // Landing pad can be destination of several statepoints. Every register
221   // defined by such statepoints must be spilled to the same stack slot.
222   // This map keeps that information.
223   DenseMap<const MachineBasicBlock *, SmallVector<RegSlotPair, 8>>
224       GlobalIndices;
225 
226   FrameIndexesPerSize &getCacheBucket(unsigned Size) {
227     // In FixupSCSExtendSlotSize mode the bucket with 0 index is used
228     // for all sizes.
229     return Cache[FixupSCSExtendSlotSize ? 0 : Size];
230   }
231 
232 public:
233   FrameIndexesCache(MachineFrameInfo &MFI, const TargetRegisterInfo &TRI)
234       : MFI(MFI), TRI(TRI) {}
235   // Reset the current state of used frame indexes. After invocation of
236   // this function all frame indexes are available for allocation with
237   // the exception of slots reserved for landing pad processing (if any).
238   void reset(const MachineBasicBlock *EHPad) {
239     for (auto &It : Cache)
240       It.second.Index = 0;
241 
242     ReservedSlots.clear();
243     if (EHPad && GlobalIndices.count(EHPad))
244       for (auto &RSP : GlobalIndices[EHPad])
245         ReservedSlots.insert(RSP.second);
246   }
247 
248   // Get frame index to spill the register.
249   int getFrameIndex(Register Reg, MachineBasicBlock *EHPad) {
250     // Check if slot for Reg is already reserved at EHPad.
251     auto It = GlobalIndices.find(EHPad);
252     if (It != GlobalIndices.end()) {
253       auto &Vec = It->second;
254       auto Idx = llvm::find_if(
255           Vec, [Reg](RegSlotPair &RSP) { return Reg == RSP.first; });
256       if (Idx != Vec.end()) {
257         int FI = Idx->second;
258         LLVM_DEBUG(dbgs() << "Found global FI " << FI << " for register "
259                           << printReg(Reg, &TRI) << " at "
260                           << printMBBReference(*EHPad) << "\n");
261         assert(ReservedSlots.count(FI) && "using unreserved slot");
262         return FI;
263       }
264     }
265 
266     unsigned Size = getRegisterSize(TRI, Reg);
267     FrameIndexesPerSize &Line = getCacheBucket(Size);
268     while (Line.Index < Line.Slots.size()) {
269       int FI = Line.Slots[Line.Index++];
270       if (ReservedSlots.count(FI))
271         continue;
272       // If all sizes are kept together we probably need to extend the
273       // spill slot size.
274       if (MFI.getObjectSize(FI) < Size) {
275         MFI.setObjectSize(FI, Size);
276         MFI.setObjectAlignment(FI, Align(Size));
277         NumSpillSlotsExtended++;
278       }
279       return FI;
280     }
281     int FI = MFI.CreateSpillStackObject(Size, Align(Size));
282     NumSpillSlotsAllocated++;
283     Line.Slots.push_back(FI);
284     ++Line.Index;
285 
286     // Remember assignment {Reg, FI} for EHPad
287     if (EHPad) {
288       GlobalIndices[EHPad].push_back(std::make_pair(Reg, FI));
289       LLVM_DEBUG(dbgs() << "Reserved FI " << FI << " for spilling reg "
290                         << printReg(Reg, &TRI) << " at landing pad "
291                         << printMBBReference(*EHPad) << "\n");
292     }
293 
294     return FI;
295   }
296 
297   // Sort all registers to spill in descendent order. In the
298   // FixupSCSExtendSlotSize mode it will minimize the total frame size.
299   // In non FixupSCSExtendSlotSize mode we can skip this step.
300   void sortRegisters(SmallVectorImpl<Register> &Regs) {
301     if (!FixupSCSExtendSlotSize)
302       return;
303     llvm::sort(Regs, [&](Register &A, Register &B) {
304       return getRegisterSize(TRI, A) > getRegisterSize(TRI, B);
305     });
306   }
307 };
308 
309 // Describes the state of the current processing statepoint instruction.
310 class StatepointState {
311 private:
312   // statepoint instruction.
313   MachineInstr &MI;
314   MachineFunction &MF;
315   // If non-null then statepoint is invoke, and this points to the landing pad.
316   MachineBasicBlock *EHPad;
317   const TargetRegisterInfo &TRI;
318   const TargetInstrInfo &TII;
319   MachineFrameInfo &MFI;
320   // Mask with callee saved registers.
321   const uint32_t *Mask;
322   // Cache of frame indexes used on previous instruction processing.
323   FrameIndexesCache &CacheFI;
324   bool AllowGCPtrInCSR;
325   // Operands with physical registers requiring spilling.
326   SmallVector<unsigned, 8> OpsToSpill;
327   // Set of register to spill.
328   SmallVector<Register, 8> RegsToSpill;
329   // Set of registers to reload after statepoint.
330   SmallVector<Register, 8> RegsToReload;
331   // Map Register to Frame Slot index.
332   DenseMap<Register, int> RegToSlotIdx;
333 
334 public:
335   StatepointState(MachineInstr &MI, const uint32_t *Mask,
336                   FrameIndexesCache &CacheFI, bool AllowGCPtrInCSR)
337       : MI(MI), MF(*MI.getMF()), TRI(*MF.getSubtarget().getRegisterInfo()),
338         TII(*MF.getSubtarget().getInstrInfo()), MFI(MF.getFrameInfo()),
339         Mask(Mask), CacheFI(CacheFI), AllowGCPtrInCSR(AllowGCPtrInCSR) {
340 
341     // Find statepoint's landing pad, if any.
342     EHPad = nullptr;
343     MachineBasicBlock *MBB = MI.getParent();
344     // Invoke statepoint must be last one in block.
345     bool Last = std::none_of(++MI.getIterator(), MBB->end().getInstrIterator(),
346                              [](MachineInstr &I) {
347                                return I.getOpcode() == TargetOpcode::STATEPOINT;
348                              });
349 
350     if (!Last)
351       return;
352 
353     auto IsEHPad = [](MachineBasicBlock *B) { return B->isEHPad(); };
354 
355     assert(llvm::count_if(MBB->successors(), IsEHPad) < 2 && "multiple EHPads");
356 
357     auto It = llvm::find_if(MBB->successors(), IsEHPad);
358     if (It != MBB->succ_end())
359       EHPad = *It;
360   }
361 
362   MachineBasicBlock *getEHPad() const { return EHPad; }
363 
364   // Return true if register is callee saved.
365   bool isCalleeSaved(Register Reg) { return (Mask[Reg / 32] >> Reg % 32) & 1; }
366 
367   // Iterates over statepoint meta args to find caller saver registers.
368   // Also cache the size of found registers.
369   // Returns true if caller save registers found.
370   bool findRegistersToSpill() {
371     SmallSet<Register, 8> GCRegs;
372     // All GC pointer operands assigned to registers produce new value.
373     // Since they're tied to their defs, it is enough to collect def registers.
374     for (const auto &Def : MI.defs())
375       GCRegs.insert(Def.getReg());
376 
377     SmallSet<Register, 8> VisitedRegs;
378     for (unsigned Idx = StatepointOpers(&MI).getVarIdx(),
379                   EndIdx = MI.getNumOperands();
380          Idx < EndIdx; ++Idx) {
381       MachineOperand &MO = MI.getOperand(Idx);
382       // Leave `undef` operands as is, StackMaps will rewrite them
383       // into a constant.
384       if (!MO.isReg() || MO.isImplicit() || MO.isUndef())
385         continue;
386       Register Reg = MO.getReg();
387       assert(Reg.isPhysical() && "Only physical regs are expected");
388 
389       if (isCalleeSaved(Reg) && (AllowGCPtrInCSR || !is_contained(GCRegs, Reg)))
390         continue;
391 
392       LLVM_DEBUG(dbgs() << "Will spill " << printReg(Reg, &TRI) << " at index "
393                         << Idx << "\n");
394 
395       if (VisitedRegs.insert(Reg).second)
396         RegsToSpill.push_back(Reg);
397       OpsToSpill.push_back(Idx);
398     }
399     CacheFI.sortRegisters(RegsToSpill);
400     return !RegsToSpill.empty();
401   }
402 
403   // Spill all caller saved registers right before statepoint instruction.
404   // Remember frame index where register is spilled.
405   void spillRegisters() {
406     for (Register Reg : RegsToSpill) {
407       int FI = CacheFI.getFrameIndex(Reg, EHPad);
408       const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
409 
410       NumSpilledRegisters++;
411       RegToSlotIdx[Reg] = FI;
412 
413       LLVM_DEBUG(dbgs() << "Spilling " << printReg(Reg, &TRI) << " to FI " << FI
414                         << "\n");
415 
416       // Perform trivial copy propagation
417       bool IsKill = true;
418       MachineBasicBlock::iterator InsertBefore(MI);
419       Reg = performCopyPropagation(Reg, InsertBefore, IsKill, TII, TRI);
420 
421       LLVM_DEBUG(dbgs() << "Insert spill before " << *InsertBefore);
422       TII.storeRegToStackSlot(*MI.getParent(), InsertBefore, Reg, IsKill, FI,
423                               RC, &TRI);
424     }
425   }
426 
427   void insertReloadBefore(unsigned Reg, MachineBasicBlock::iterator It,
428                           MachineBasicBlock *MBB) {
429     const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg);
430     int FI = RegToSlotIdx[Reg];
431     if (It != MBB->end()) {
432       TII.loadRegFromStackSlot(*MBB, It, Reg, FI, RC, &TRI);
433       return;
434     }
435 
436     // To insert reload at the end of MBB, insert it before last instruction
437     // and then swap them.
438     assert(!MBB->empty() && "Empty block");
439     --It;
440     TII.loadRegFromStackSlot(*MBB, It, Reg, FI, RC, &TRI);
441     MachineInstr *Reload = It->getPrevNode();
442     int Dummy = 0;
443     (void)Dummy;
444     assert(TII.isLoadFromStackSlot(*Reload, Dummy) == Reg);
445     assert(Dummy == FI);
446     MBB->remove(Reload);
447     MBB->insertAfter(It, Reload);
448   }
449 
450   // Insert reloads of (relocated) registers spilled in statepoint.
451   void insertReloads(MachineInstr *NewStatepoint, RegReloadCache &RC) {
452     MachineBasicBlock *MBB = NewStatepoint->getParent();
453     auto InsertPoint = std::next(NewStatepoint->getIterator());
454 
455     for (auto Reg : RegsToReload) {
456       insertReloadBefore(Reg, InsertPoint, MBB);
457       LLVM_DEBUG(dbgs() << "Reloading " << printReg(Reg, &TRI) << " from FI "
458                         << RegToSlotIdx[Reg] << " after statepoint\n");
459 
460       if (EHPad && !RC.hasReload(Reg, RegToSlotIdx[Reg], EHPad)) {
461         RC.recordReload(Reg, RegToSlotIdx[Reg], EHPad);
462         auto EHPadInsertPoint = EHPad->SkipPHIsLabelsAndDebug(EHPad->begin());
463         insertReloadBefore(Reg, EHPadInsertPoint, EHPad);
464         LLVM_DEBUG(dbgs() << "...also reload at EHPad "
465                           << printMBBReference(*EHPad) << "\n");
466       }
467     }
468   }
469 
470   // Re-write statepoint machine instruction to replace caller saved operands
471   // with indirect memory location (frame index).
472   MachineInstr *rewriteStatepoint() {
473     MachineInstr *NewMI =
474         MF.CreateMachineInstr(TII.get(MI.getOpcode()), MI.getDebugLoc(), true);
475     MachineInstrBuilder MIB(MF, NewMI);
476 
477     unsigned NumOps = MI.getNumOperands();
478 
479     // New indices for the remaining defs.
480     SmallVector<unsigned, 8> NewIndices;
481     unsigned NumDefs = MI.getNumDefs();
482     for (unsigned I = 0; I < NumDefs; ++I) {
483       MachineOperand &DefMO = MI.getOperand(I);
484       assert(DefMO.isReg() && DefMO.isDef() && "Expected Reg Def operand");
485       Register Reg = DefMO.getReg();
486       assert(DefMO.isTied() && "Def is expected to be tied");
487       // We skipped undef uses and did not spill them, so we should not
488       // proceed with defs here.
489       if (MI.getOperand(MI.findTiedOperandIdx(I)).isUndef()) {
490         if (AllowGCPtrInCSR) {
491           NewIndices.push_back(NewMI->getNumOperands());
492           MIB.addReg(Reg, RegState::Define);
493         }
494         continue;
495       }
496       if (!AllowGCPtrInCSR) {
497         assert(is_contained(RegsToSpill, Reg));
498         RegsToReload.push_back(Reg);
499       } else {
500         if (isCalleeSaved(Reg)) {
501           NewIndices.push_back(NewMI->getNumOperands());
502           MIB.addReg(Reg, RegState::Define);
503         } else {
504           NewIndices.push_back(NumOps);
505           RegsToReload.push_back(Reg);
506         }
507       }
508     }
509 
510     // Add End marker.
511     OpsToSpill.push_back(MI.getNumOperands());
512     unsigned CurOpIdx = 0;
513 
514     for (unsigned I = NumDefs; I < MI.getNumOperands(); ++I) {
515       MachineOperand &MO = MI.getOperand(I);
516       if (I == OpsToSpill[CurOpIdx]) {
517         int FI = RegToSlotIdx[MO.getReg()];
518         MIB.addImm(StackMaps::IndirectMemRefOp);
519         MIB.addImm(getRegisterSize(TRI, MO.getReg()));
520         assert(MO.isReg() && "Should be register");
521         assert(MO.getReg().isPhysical() && "Should be physical register");
522         MIB.addFrameIndex(FI);
523         MIB.addImm(0);
524         ++CurOpIdx;
525       } else {
526         MIB.add(MO);
527         unsigned OldDef;
528         if (AllowGCPtrInCSR && MI.isRegTiedToDefOperand(I, &OldDef)) {
529           assert(OldDef < NumDefs);
530           assert(NewIndices[OldDef] < NumOps);
531           MIB->tieOperands(NewIndices[OldDef], MIB->getNumOperands() - 1);
532         }
533       }
534     }
535     assert(CurOpIdx == (OpsToSpill.size() - 1) && "Not all operands processed");
536     // Add mem operands.
537     NewMI->setMemRefs(MF, MI.memoperands());
538     for (auto It : RegToSlotIdx) {
539       Register R = It.first;
540       int FrameIndex = It.second;
541       auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
542       MachineMemOperand::Flags Flags = MachineMemOperand::MOLoad;
543       if (is_contained(RegsToReload, R))
544         Flags |= MachineMemOperand::MOStore;
545       auto *MMO =
546           MF.getMachineMemOperand(PtrInfo, Flags, getRegisterSize(TRI, R),
547                                   MFI.getObjectAlign(FrameIndex));
548       NewMI->addMemOperand(MF, MMO);
549     }
550 
551     // Insert new statepoint and erase old one.
552     MI.getParent()->insert(MI, NewMI);
553 
554     LLVM_DEBUG(dbgs() << "rewritten statepoint to : " << *NewMI << "\n");
555     MI.eraseFromParent();
556     return NewMI;
557   }
558 };
559 
560 class StatepointProcessor {
561 private:
562   MachineFunction &MF;
563   const TargetRegisterInfo &TRI;
564   FrameIndexesCache CacheFI;
565   RegReloadCache ReloadCache;
566 
567 public:
568   StatepointProcessor(MachineFunction &MF)
569       : MF(MF), TRI(*MF.getSubtarget().getRegisterInfo()),
570         CacheFI(MF.getFrameInfo(), TRI) {}
571 
572   bool process(MachineInstr &MI, bool AllowGCPtrInCSR) {
573     StatepointOpers SO(&MI);
574     uint64_t Flags = SO.getFlags();
575     // Do nothing for LiveIn, it supports all registers.
576     if (Flags & (uint64_t)StatepointFlags::DeoptLiveIn)
577       return false;
578     LLVM_DEBUG(dbgs() << "\nMBB " << MI.getParent()->getNumber() << " "
579                       << MI.getParent()->getName() << " : process statepoint "
580                       << MI);
581     CallingConv::ID CC = SO.getCallingConv();
582     const uint32_t *Mask = TRI.getCallPreservedMask(MF, CC);
583     StatepointState SS(MI, Mask, CacheFI, AllowGCPtrInCSR);
584     CacheFI.reset(SS.getEHPad());
585 
586     if (!SS.findRegistersToSpill())
587       return false;
588 
589     SS.spillRegisters();
590     auto *NewStatepoint = SS.rewriteStatepoint();
591     SS.insertReloads(NewStatepoint, ReloadCache);
592     return true;
593   }
594 };
595 } // namespace
596 
597 bool FixupStatepointCallerSaved::runOnMachineFunction(MachineFunction &MF) {
598   if (skipFunction(MF.getFunction()))
599     return false;
600 
601   const Function &F = MF.getFunction();
602   if (!F.hasGC())
603     return false;
604 
605   SmallVector<MachineInstr *, 16> Statepoints;
606   for (MachineBasicBlock &BB : MF)
607     for (MachineInstr &I : BB)
608       if (I.getOpcode() == TargetOpcode::STATEPOINT)
609         Statepoints.push_back(&I);
610 
611   if (Statepoints.empty())
612     return false;
613 
614   bool Changed = false;
615   StatepointProcessor SPP(MF);
616   unsigned NumStatepoints = 0;
617   bool AllowGCPtrInCSR = PassGCPtrInCSR;
618   for (MachineInstr *I : Statepoints) {
619     ++NumStatepoints;
620     if (MaxStatepointsWithRegs.getNumOccurrences() &&
621         NumStatepoints >= MaxStatepointsWithRegs)
622       AllowGCPtrInCSR = false;
623     Changed |= SPP.process(*I, AllowGCPtrInCSR);
624   }
625   return Changed;
626 }
627