xref: /freebsd/contrib/llvm-project/llvm/lib/Target/BPF/BPFMIPeephole.cpp (revision 9c77fb6aaa366cbabc80ee1b834bcfe4df135491)
1 //===-------------- BPFMIPeephole.cpp - MI Peephole Cleanups  -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass performs peephole optimizations to cleanup ugly code sequences at
10 // MachineInstruction layer.
11 //
12 // Currently, there are two optimizations implemented:
13 //  - One pre-RA MachineSSA pass to eliminate type promotion sequences, those
14 //    zero extend 32-bit subregisters to 64-bit registers, if the compiler
15 //    could prove the subregisters is defined by 32-bit operations in which
16 //    case the upper half of the underlying 64-bit registers were zeroed
17 //    implicitly.
18 //
19 //  - One post-RA PreEmit pass to do final cleanup on some redundant
20 //    instructions generated due to bad RA on subregister.
21 //===----------------------------------------------------------------------===//
22 
23 #include "BPF.h"
24 #include "BPFInstrInfo.h"
25 #include "BPFTargetMachine.h"
26 #include "llvm/ADT/Statistic.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/CodeGen/LivePhysRegs.h"
29 #include "llvm/CodeGen/MachineFrameInfo.h"
30 #include "llvm/CodeGen/MachineFunctionPass.h"
31 #include "llvm/CodeGen/MachineInstrBuilder.h"
32 #include "llvm/CodeGen/MachineRegisterInfo.h"
33 #include "llvm/Support/Debug.h"
34 #include <set>
35 
36 using namespace llvm;
37 
38 #define DEBUG_TYPE "bpf-mi-zext-elim"
39 
// Threshold (in instructions) for treating a branch offset as within
// 16-bit range.  Deliberately half of INT16_MAX: the rewrites in
// adjustBranch() insert new instructions after distances were measured,
// so the cut-off must be conservative.
static cl::opt<int> GotolAbsLowBound("gotol-abs-low-bound", cl::Hidden,
  cl::init(INT16_MAX >> 1), cl::desc("Specify gotol lower bound"));

// Counts zext (shift-pair) sequences removed by BPFMIPeephole.
STATISTIC(ZExtElemNum, "Number of zero extension shifts eliminated");
44 
45 namespace {
46 
// Pre-RA MachineSSA pass that removes zero-extension code for 32-bit
// sub-registers whose defining instructions already left the upper half
// of the underlying 64-bit register zeroed.
struct BPFMIPeephole : public MachineFunctionPass {

  static char ID;
  const BPFInstrInfo *TII;  // Target instruction info; set in initialize().
  MachineFunction *MF;      // Function currently being processed.
  MachineRegisterInfo *MRI; // Virtual-register def/use info; set in initialize().

  BPFMIPeephole() : MachineFunctionPass(ID) {}

private:
  // Initialize class variables.
  void initialize(MachineFunction &MFParm);

  // Helpers that prove the value feeding a zero extension was produced
  // by a 32-bit operation, possibly looking through COPYs and PHIs.
  bool isCopyFrom32Def(MachineInstr *CopyMI);
  bool isInsnFrom32Def(MachineInstr *DefInsn);
  bool isPhiFrom32Def(MachineInstr *MovMI);
  bool isMovFrom32Def(MachineInstr *MovMI);
  // The two rewrites; each returns true if it eliminated anything.
  bool eliminateZExtSeq();
  bool eliminateZExt();

  // PHIs visited during the current query; breaks cycles in the
  // recursive isPhiFrom32Def/isInsnFrom32Def walk.
  std::set<MachineInstr *> PhiInsns;

public:

  // Main entry point for this pass.
  bool runOnMachineFunction(MachineFunction &MF) override {
    if (skipFunction(MF.getFunction()))
      return false;

    initialize(MF);

    // First try to eliminate (zext, lshift, rshift) and then
    // try to eliminate zext.
    bool ZExtSeqExist, ZExtExist;
    ZExtSeqExist = eliminateZExtSeq();
    ZExtExist = eliminateZExt();
    return ZExtSeqExist || ZExtExist;
  }
};
86 
87 // Initialize class variables.
88 void BPFMIPeephole::initialize(MachineFunction &MFParm) {
89   MF = &MFParm;
90   MRI = &MF->getRegInfo();
91   TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
92   LLVM_DEBUG(dbgs() << "*** BPF MachineSSA ZEXT Elim peephole pass ***\n\n");
93 }
94 
95 bool BPFMIPeephole::isCopyFrom32Def(MachineInstr *CopyMI)
96 {
97   MachineOperand &opnd = CopyMI->getOperand(1);
98 
99   if (!opnd.isReg())
100     return false;
101 
102   // Return false if getting value from a 32bit physical register.
103   // Most likely, this physical register is aliased to
104   // function call return value or current function parameters.
105   Register Reg = opnd.getReg();
106   if (!Reg.isVirtual())
107     return false;
108 
109   if (MRI->getRegClass(Reg) == &BPF::GPRRegClass)
110     return false;
111 
112   MachineInstr *DefInsn = MRI->getVRegDef(Reg);
113   if (!isInsnFrom32Def(DefInsn))
114     return false;
115 
116   return true;
117 }
118 
119 bool BPFMIPeephole::isPhiFrom32Def(MachineInstr *PhiMI)
120 {
121   for (unsigned i = 1, e = PhiMI->getNumOperands(); i < e; i += 2) {
122     MachineOperand &opnd = PhiMI->getOperand(i);
123 
124     if (!opnd.isReg())
125       return false;
126 
127     MachineInstr *PhiDef = MRI->getVRegDef(opnd.getReg());
128     if (!PhiDef)
129       return false;
130     if (PhiDef->isPHI()) {
131       if (!PhiInsns.insert(PhiDef).second)
132         return false;
133       if (!isPhiFrom32Def(PhiDef))
134         return false;
135     }
136     if (PhiDef->getOpcode() == BPF::COPY && !isCopyFrom32Def(PhiDef))
137       return false;
138   }
139 
140   return true;
141 }
142 
143 // The \p DefInsn instruction defines a virtual register.
144 bool BPFMIPeephole::isInsnFrom32Def(MachineInstr *DefInsn)
145 {
146   if (!DefInsn)
147     return false;
148 
149   if (DefInsn->isPHI()) {
150     if (!PhiInsns.insert(DefInsn).second)
151       return false;
152     if (!isPhiFrom32Def(DefInsn))
153       return false;
154   } else if (DefInsn->getOpcode() == BPF::COPY) {
155     if (!isCopyFrom32Def(DefInsn))
156       return false;
157   }
158 
159   return true;
160 }
161 
162 bool BPFMIPeephole::isMovFrom32Def(MachineInstr *MovMI)
163 {
164   MachineInstr *DefInsn = MRI->getVRegDef(MovMI->getOperand(1).getReg());
165 
166   LLVM_DEBUG(dbgs() << "  Def of Mov Src:");
167   LLVM_DEBUG(DefInsn->dump());
168 
169   PhiInsns.clear();
170   if (!isInsnFrom32Def(DefInsn))
171     return false;
172 
173   LLVM_DEBUG(dbgs() << "  One ZExt elim sequence identified.\n");
174 
175   return true;
176 }
177 
// Rewrite the zero-extension idiom
//   MOV_32_64 rB, wA
//   SLL_ri    rB, rB, 32
//   SRL_ri    rB, rB, 32
// into a single SUBREG_TO_REG when wA provably comes from a 32-bit def,
// so the shifts are redundant.  Returns true if anything was removed.
bool BPFMIPeephole::eliminateZExtSeq() {
  MachineInstr* ToErase = nullptr;
  bool Eliminated = false;

  for (MachineBasicBlock &MBB : *MF) {
    for (MachineInstr &MI : MBB) {
      // If the previous instruction was marked for elimination, remove it now.
      if (ToErase) {
        ToErase->eraseFromParent();
        ToErase = nullptr;
      }

      // Eliminate the 32-bit to 64-bit zero extension sequence when possible.
      //
      //   MOV_32_64 rB, wA
      //   SLL_ri    rB, rB, 32
      //   SRL_ri    rB, rB, 32
      if (MI.getOpcode() == BPF::SRL_ri &&
          MI.getOperand(2).getImm() == 32) {
        Register DstReg = MI.getOperand(0).getReg();
        Register ShfReg = MI.getOperand(1).getReg();
        MachineInstr *SllMI = MRI->getVRegDef(ShfReg);

        LLVM_DEBUG(dbgs() << "Starting SRL found:");
        LLVM_DEBUG(MI.dump());

        // The SRL input must be an SLL by exactly 32; a PHI would make
        // the pairing ambiguous.
        if (!SllMI ||
            SllMI->isPHI() ||
            SllMI->getOpcode() != BPF::SLL_ri ||
            SllMI->getOperand(2).getImm() != 32)
          continue;

        LLVM_DEBUG(dbgs() << "  SLL found:");
        LLVM_DEBUG(SllMI->dump());

        // And the SLL input must be the 32->64 type-cast move.
        MachineInstr *MovMI = MRI->getVRegDef(SllMI->getOperand(1).getReg());
        if (!MovMI ||
            MovMI->isPHI() ||
            MovMI->getOpcode() != BPF::MOV_32_64)
          continue;

        LLVM_DEBUG(dbgs() << "  Type cast Mov found:");
        LLVM_DEBUG(MovMI->dump());

        Register SubReg = MovMI->getOperand(1).getReg();
        // Only safe when the 32-bit source was itself produced by a
        // 32-bit operation (upper half already zero).
        if (!isMovFrom32Def(MovMI)) {
          LLVM_DEBUG(dbgs()
                     << "  One ZExt elim sequence failed qualifying elim.\n");
          continue;
        }

        BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), DstReg)
          .addImm(0).addReg(SubReg).addImm(BPF::sub_32);

        SllMI->eraseFromParent();
        MovMI->eraseFromParent();
        // MI is the right shift, we can't erase it in its own iteration.
        // Mark it to ToErase, and erase in the next iteration.
        ToErase = &MI;
        ZExtElemNum++;
        Eliminated = true;
      }
    }
  }

  return Eliminated;
}
245 
// Replace a lone MOV_32_64 whose 32-bit source is already known to be
// zero-extended with SUBREG_TO_REG, removing a redundant instruction.
// Returns true if anything changed.
bool BPFMIPeephole::eliminateZExt() {
  MachineInstr* ToErase = nullptr;
  bool Eliminated = false;

  for (MachineBasicBlock &MBB : *MF) {
    for (MachineInstr &MI : MBB) {
      // If the previous instruction was marked for elimination, remove it now.
      if (ToErase) {
        ToErase->eraseFromParent();
        ToErase = nullptr;
      }

      if (MI.getOpcode() != BPF::MOV_32_64)
        continue;

      // Eliminate MOV_32_64 if possible.
      //   MOV_32_64 rA, wB
      //
      // If wB has been zero extended, replace it with a SUBREG_TO_REG.
      // This is to workaround BPF programs where pkt->{data, data_end}
      // is encoded as u32, but actually the verifier populates them
      // as 64bit pointer. The MOV_32_64 will zero out the top 32 bits.
      LLVM_DEBUG(dbgs() << "Candidate MOV_32_64 instruction:");
      LLVM_DEBUG(MI.dump());

      if (!isMovFrom32Def(&MI))
        continue;

      LLVM_DEBUG(dbgs() << "Removing the MOV_32_64 instruction\n");

      Register dst = MI.getOperand(0).getReg();
      Register src = MI.getOperand(1).getReg();

      // Build a SUBREG_TO_REG instruction.
      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), dst)
        .addImm(0).addReg(src).addImm(BPF::sub_32);

      // MI is the current iterator position; defer its erasure to the
      // next loop iteration.
      ToErase = &MI;
      Eliminated = true;
    }
  }

  return Eliminated;
}
290 
291 } // end default namespace
292 
INITIALIZE_PASS(BPFMIPeephole, DEBUG_TYPE,
                "BPF MachineSSA Peephole Optimization For ZEXT Eliminate",
                false, false)

char BPFMIPeephole::ID = 0;
// Factory used by the BPF target to add this pass to the pipeline.
FunctionPass* llvm::createBPFMIPeepholePass() { return new BPFMIPeephole(); }

// Counter for the pre-emit pass below.
STATISTIC(RedundantMovElemNum, "Number of redundant moves eliminated");
301 
302 namespace {
303 
// Post-RA pre-emit pass: removes redundant self-moves, relaxes
// out-of-16-bit-range branches to gotol (JMPL), inserts missing
// caller-saved spills around certain calls (see BPFFastCall below),
// drops no-op 'may_goto' inline asm, and appends an exit after a
// trapping JAL at the end of the function.
struct BPFMIPreEmitPeephole : public MachineFunctionPass {

  static char ID;
  MachineFunction *MF;           // Function currently being processed.
  const TargetRegisterInfo *TRI; // Set in initialize().
  const BPFInstrInfo *TII;       // Set in initialize().
  bool SupportGotol;             // True if the subtarget has the gotol insn.

  BPFMIPreEmitPeephole() : MachineFunctionPass(ID) {}

private:
  // Initialize class variables.
  void initialize(MachineFunction &MFParm);

  // True if Num fits the conservative 16-bit branch-offset range.
  bool in16BitRange(int Num);
  // The individual cleanups; each returns true on change.
  bool eliminateRedundantMov();
  bool adjustBranch();
  bool insertMissingCallerSavedSpills();
  bool removeMayGotoZero();
  bool addExitAfterUnreachable();

public:

  // Main entry point for this pass.
  bool runOnMachineFunction(MachineFunction &MF) override {
    if (skipFunction(MF.getFunction()))
      return false;

    initialize(MF);

    bool Changed;
    Changed = eliminateRedundantMov();
    // Branch relaxation is only possible when the subtarget has gotol.
    if (SupportGotol)
      Changed = adjustBranch() || Changed;
    Changed |= insertMissingCallerSavedSpills();
    Changed |= removeMayGotoZero();
    Changed |= addExitAfterUnreachable();
    return Changed;
  }
};
344 
345 // Initialize class variables.
346 void BPFMIPreEmitPeephole::initialize(MachineFunction &MFParm) {
347   MF = &MFParm;
348   TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
349   TRI = MF->getSubtarget<BPFSubtarget>().getRegisterInfo();
350   SupportGotol = MF->getSubtarget<BPFSubtarget>().hasGotol();
351   LLVM_DEBUG(dbgs() << "*** BPF PreEmit peephole pass ***\n\n");
352 }
353 
354 bool BPFMIPreEmitPeephole::eliminateRedundantMov() {
355   MachineInstr* ToErase = nullptr;
356   bool Eliminated = false;
357 
358   for (MachineBasicBlock &MBB : *MF) {
359     for (MachineInstr &MI : MBB) {
360       // If the previous instruction was marked for elimination, remove it now.
361       if (ToErase) {
362         LLVM_DEBUG(dbgs() << "  Redundant Mov Eliminated:");
363         LLVM_DEBUG(ToErase->dump());
364         ToErase->eraseFromParent();
365         ToErase = nullptr;
366       }
367 
368       // Eliminate identical move:
369       //
370       //   MOV rA, rA
371       //
372       // Note that we cannot remove
373       //   MOV_32_64  rA, wA
374       //   MOV_rr_32  wA, wA
375       // as these two instructions having side effects, zeroing out
376       // top 32 bits of rA.
377       unsigned Opcode = MI.getOpcode();
378       if (Opcode == BPF::MOV_rr) {
379         Register dst = MI.getOperand(0).getReg();
380         Register src = MI.getOperand(1).getReg();
381 
382         if (dst != src)
383           continue;
384 
385         ToErase = &MI;
386         RedundantMovElemNum++;
387         Eliminated = true;
388       }
389     }
390   }
391 
392   return Eliminated;
393 }
394 
395 bool BPFMIPreEmitPeephole::in16BitRange(int Num) {
396   // Well, the cut-off is not precisely at 16bit range since
397   // new codes are added during the transformation. So let us
398   // a little bit conservative.
399   return Num >= -GotolAbsLowBound && Num <= GotolAbsLowBound;
400 }
401 
// Before cpu=v4, only 16bit branch target offset (-0x8000 to 0x7fff)
// is supported for both unconditional (JMP) and condition (JEQ, JSGT,
// etc.) branches. In certain cases, e.g., full unrolling, the branch
// target offset might exceed 16bit range. If this happens, the llvm
// will generate incorrect code as the offset is truncated to 16bit.
//
// To fix this rare case, a new insn JMPL is introduced. This new
// insn supports 32bit branch target offset. The compiler
// does not use this insn during insn selection. Rather, BPF backend
// will estimate the branch target offset and do JMP -> JMPL and
// JEQ -> JEQ + JMPL conversion if the estimated branch target offset
// is beyond 16bit.
bool BPFMIPreEmitPeephole::adjustBranch() {
  bool Changed = false;
  int CurrNumInsns = 0;
  // Cumulative instruction count at the end of each block, in layout order.
  DenseMap<MachineBasicBlock *, int> SoFarNumInsns;
  // Layout successor of each block (nullptr for the last block).
  DenseMap<MachineBasicBlock *, MachineBasicBlock *> FollowThroughBB;
  std::vector<MachineBasicBlock *> MBBs;

  MachineBasicBlock *PrevBB = nullptr;
  for (MachineBasicBlock &MBB : *MF) {
    // MBB.size() is the number of insns in this basic block, including some
    // debug info, e.g., DEBUG_VALUE, so we may over-count a little bit.
    // Typically we have way more normal insns than DEBUG_VALUE insns.
    // Also, if we indeed need to convert conditional branch like JEQ to
    // JEQ + JMPL, we actually introduced some new insns like below.
    CurrNumInsns += (int)MBB.size();
    SoFarNumInsns[&MBB] = CurrNumInsns;
    if (PrevBB != nullptr)
      FollowThroughBB[PrevBB] = &MBB;
    PrevBB = &MBB;
    // A list of original BBs to make later traversal easier.
    MBBs.push_back(&MBB);
  }
  FollowThroughBB[PrevBB] = nullptr;

  for (unsigned i = 0; i < MBBs.size(); i++) {
    // We have four cases here:
    //  (1). no terminator, simple follow through.
    //  (2). jmp to another bb.
    //  (3). conditional jmp to another bb or follow through.
    //  (4). conditional jmp followed by an unconditional jmp.
    MachineInstr *CondJmp = nullptr, *UncondJmp = nullptr;

    MachineBasicBlock *MBB = MBBs[i];
    for (MachineInstr &Term : MBB->terminators()) {
      if (Term.isConditionalBranch()) {
        assert(CondJmp == nullptr);
        CondJmp = &Term;
      } else if (Term.isUnconditionalBranch()) {
        assert(UncondJmp == nullptr);
        UncondJmp = &Term;
      }
    }

    // (1). no terminator, simple follow through.
    if (!CondJmp && !UncondJmp)
      continue;

    MachineBasicBlock *CondTargetBB, *JmpBB;
    CurrNumInsns = SoFarNumInsns[MBB];

    // (2). jmp to another bb.
    if (!CondJmp && UncondJmp) {
      // Distance is measured from the end of MBB to the start of JmpBB.
      JmpBB = UncondJmp->getOperand(0).getMBB();
      if (in16BitRange(SoFarNumInsns[JmpBB] - JmpBB->size() - CurrNumInsns))
        continue;

      // replace this insn as a JMPL.
      BuildMI(MBB, UncondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(JmpBB);
      UncondJmp->eraseFromParent();
      Changed = true;
      continue;
    }

    const BasicBlock *TermBB = MBB->getBasicBlock();
    int Dist;

    // (3). conditional jmp to another bb or follow through.
    if (!UncondJmp) {
      CondTargetBB = CondJmp->getOperand(2).getMBB();
      MachineBasicBlock *FollowBB = FollowThroughBB[MBB];
      Dist = SoFarNumInsns[CondTargetBB] - CondTargetBB->size() - CurrNumInsns;
      if (in16BitRange(Dist))
        continue;

      // We have
      //   B2: ...
      //       if (cond) goto B5
      //   B3: ...
      // where B2 -> B5 is beyond 16bit range.
      //
      // We do not have 32bit cond jmp insn. So we try to do
      // the following.
      //   B2:     ...
      //           if (cond) goto New_B1
      //   New_B0  goto B3
      //   New_B1: gotol B5
      //   B3: ...
      // Basically two new basic blocks are created.
      MachineBasicBlock *New_B0 = MF->CreateMachineBasicBlock(TermBB);
      MachineBasicBlock *New_B1 = MF->CreateMachineBasicBlock(TermBB);

      // Insert New_B0 and New_B1 into function block list.
      MachineFunction::iterator MBB_I  = ++MBB->getIterator();
      MF->insert(MBB_I, New_B0);
      MF->insert(MBB_I, New_B1);

      // replace B2 cond jump: same opcode and condition operands, but the
      // target becomes New_B1.  Operand 1 is either a register or an
      // immediate depending on the compare flavor.
      if (CondJmp->getOperand(1).isReg())
        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
            .addReg(CondJmp->getOperand(0).getReg())
            .addReg(CondJmp->getOperand(1).getReg())
            .addMBB(New_B1);
      else
        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
            .addReg(CondJmp->getOperand(0).getReg())
            .addImm(CondJmp->getOperand(1).getImm())
            .addMBB(New_B1);

      // it is possible that CondTargetBB and FollowBB are the same. But the
      // above Dist checking should already filtered this case.
      MBB->removeSuccessor(CondTargetBB);
      MBB->removeSuccessor(FollowBB);
      MBB->addSuccessor(New_B0);
      MBB->addSuccessor(New_B1);

      // Populate insns in New_B0 and New_B1.
      BuildMI(New_B0, CondJmp->getDebugLoc(), TII->get(BPF::JMP)).addMBB(FollowBB);
      BuildMI(New_B1, CondJmp->getDebugLoc(), TII->get(BPF::JMPL))
          .addMBB(CondTargetBB);

      New_B0->addSuccessor(FollowBB);
      New_B1->addSuccessor(CondTargetBB);
      CondJmp->eraseFromParent();
      Changed = true;
      continue;
    }

    //  (4). conditional jmp followed by an unconditional jmp.
    CondTargetBB = CondJmp->getOperand(2).getMBB();
    JmpBB = UncondJmp->getOperand(0).getMBB();

    // We have
    //   B2: ...
    //       if (cond) goto B5
    //       JMP B7
    //   B3: ...
    //
    // If only B2->B5 is out of 16bit range, we can do
    //   B2: ...
    //       if (cond) goto new_B
    //       JMP B7
    //   New_B: gotol B5
    //   B3: ...
    //
    // If only 'JMP B7' is out of 16bit range, we can replace
    // 'JMP B7' with 'JMPL B7'.
    //
    // If both B2->B5 and 'JMP B7' is out of range, just do
    // both the above transformations.
    Dist = SoFarNumInsns[CondTargetBB] - CondTargetBB->size() - CurrNumInsns;
    if (!in16BitRange(Dist)) {
      MachineBasicBlock *New_B = MF->CreateMachineBasicBlock(TermBB);

      // Insert New_B0 into function block list.
      MF->insert(++MBB->getIterator(), New_B);

      // replace B2 cond jump
      if (CondJmp->getOperand(1).isReg())
        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
            .addReg(CondJmp->getOperand(0).getReg())
            .addReg(CondJmp->getOperand(1).getReg())
            .addMBB(New_B)
      else
        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
            .addReg(CondJmp->getOperand(0).getReg())
            .addImm(CondJmp->getOperand(1).getImm())
            .addMBB(New_B);

      if (CondTargetBB != JmpBB)
        MBB->removeSuccessor(CondTargetBB);
      MBB->addSuccessor(New_B);

      // Populate insn in New_B.
      BuildMI(New_B, CondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(CondTargetBB);

      New_B->addSuccessor(CondTargetBB);
      CondJmp->eraseFromParent();
      Changed = true;
    }

    // NOTE(review): unlike the earlier distance checks, this one does not
    // subtract JmpBB->size(); presumably the GotolAbsLowBound slack keeps
    // it safe — confirm before tightening the bound.
    if (!in16BitRange(SoFarNumInsns[JmpBB] - CurrNumInsns)) {
      BuildMI(MBB, UncondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(JmpBB);
      UncondJmp->eraseFromParent();
      Changed = true;
    }
  }

  return Changed;
}
603 
// Caller-saved registers considered for spill/fill around calls.
static const unsigned CallerSavedRegs[] = {BPF::R0, BPF::R1, BPF::R2,
                                           BPF::R3, BPF::R4, BPF::R5};

// A call site paired with a bitmask (bit index = register enum value)
// of the caller-saved registers live across it.
struct BPFFastCall {
  MachineInstr *MI;
  unsigned LiveCallerSavedRegs;
};
611 
// Scan \p BB bottom-up while tracking physical-register liveness, and
// record in \p Calls every call instruction together with the set of
// caller-saved registers that are live across it (live after the call
// and not redefined by the call itself).  Calls are appended in reverse
// program order; \p Calls is cleared first.
static void collectBPFFastCalls(const TargetRegisterInfo *TRI,
                                LivePhysRegs &LiveRegs, MachineBasicBlock &BB,
                                SmallVectorImpl<BPFFastCall> &Calls) {
  LiveRegs.init(*TRI);
  LiveRegs.addLiveOuts(BB);
  Calls.clear();
  for (MachineInstr &MI : llvm::reverse(BB)) {
    if (MI.isCall()) {
      unsigned LiveCallerSavedRegs = 0;
      for (MCRegister R : CallerSavedRegs) {
        // Spill/fill is needed if any sub-register of R is live here and
        // the call does not (re)define it.
        bool DoSpillFill = false;
        for (MCPhysReg SR : TRI->subregs(R))
          DoSpillFill |= !MI.definesRegister(SR, TRI) && LiveRegs.contains(SR);
        if (!DoSpillFill)
          continue;
        LiveCallerSavedRegs |= 1 << R;
      }
      if (LiveCallerSavedRegs)
        Calls.push_back({&MI, LiveCallerSavedRegs});
    }
    // Step liveness backward past MI (including past the call itself).
    LiveRegs.stepBackward(MI);
  }
}
635 
636 static int64_t computeMinFixedObjOffset(MachineFrameInfo &MFI,
637                                         unsigned SlotSize) {
638   int64_t MinFixedObjOffset = 0;
639   // Same logic as in X86FrameLowering::adjustFrameForMsvcCxxEh()
640   for (int I = MFI.getObjectIndexBegin(); I < MFI.getObjectIndexEnd(); ++I) {
641     if (MFI.isDeadObjectIndex(I))
642       continue;
643     MinFixedObjOffset = std::min(MinFixedObjOffset, MFI.getObjectOffset(I));
644   }
645   MinFixedObjOffset -=
646       (SlotSize + MinFixedObjOffset % SlotSize) & (SlotSize - 1);
647   return MinFixedObjOffset;
648 }
649 
// For every call with caller-saved registers live across it (as found by
// collectBPFFastCalls), allocate fixed 8-byte stack slots just below the
// lowest existing fixed object and bracket the call with STD (spill) /
// LDD (fill) pairs for each such register.  Returns true on change.
bool BPFMIPreEmitPeephole::insertMissingCallerSavedSpills() {
  MachineFrameInfo &MFI = MF->getFrameInfo();
  SmallVector<BPFFastCall, 8> Calls;
  LivePhysRegs LiveRegs;
  const unsigned SlotSize = 8;
  int64_t MinFixedObjOffset = computeMinFixedObjOffset(MFI, SlotSize);
  bool Changed = false;
  for (MachineBasicBlock &BB : *MF) {
    collectBPFFastCalls(TRI, LiveRegs, BB, Calls);
    Changed |= !Calls.empty();
    for (BPFFastCall &Call : Calls) {
      // CurOffset restarts at MinFixedObjOffset for every call site; the
      // spill ranges of different calls do not need distinct slots since
      // the values only live across their own call.
      int64_t CurOffset = MinFixedObjOffset;
      for (MCRegister Reg : CallerSavedRegs) {
        if (((1 << Reg) & Call.LiveCallerSavedRegs) == 0)
          continue;
        // Allocate stack object
        CurOffset -= SlotSize;
        MFI.CreateFixedSpillStackObject(SlotSize, CurOffset);
        // Generate spill (before the call, frame pointer R10 based)
        BuildMI(BB, Call.MI->getIterator(), Call.MI->getDebugLoc(),
                TII->get(BPF::STD))
            .addReg(Reg, RegState::Kill)
            .addReg(BPF::R10)
            .addImm(CurOffset);
        // Generate fill (immediately after the call)
        BuildMI(BB, ++Call.MI->getIterator(), Call.MI->getDebugLoc(),
                TII->get(BPF::LDD))
            .addReg(Reg, RegState::Define)
            .addReg(BPF::R10)
            .addImm(CurOffset);
      }
    }
  }
  return Changed;
}
685 
// Remove 'may_goto' inline-asm branches whose target is the next block
// in layout order: taken or not, such a branch ends up at the same
// place.  If the branch was the only instruction in its block, the
// whole block is deleted and predecessors are rewired to the target.
bool BPFMIPreEmitPeephole::removeMayGotoZero() {
  bool Changed = false;
  MachineBasicBlock *Prev_MBB, *Curr_MBB = nullptr;

  // Walk blocks in reverse layout order so Prev_MBB is always the layout
  // successor of Curr_MBB; early-inc allows erasing Curr_MBB mid-walk.
  for (MachineBasicBlock &MBB : make_early_inc_range(reverse(*MF))) {
    Prev_MBB = Curr_MBB;
    Curr_MBB = &MBB;
    if (Prev_MBB == nullptr || Curr_MBB->empty())
      continue;

    MachineInstr &MI = Curr_MBB->back();
    if (MI.getOpcode() != TargetOpcode::INLINEASM_BR)
      continue;

    const char *AsmStr = MI.getOperand(0).getSymbolName();
    SmallVector<StringRef, 4> AsmPieces;
    SplitString(AsmStr, AsmPieces, ";\n");

    // Do not support multiple insns in one inline asm.
    if (AsmPieces.size() != 1)
      continue;

    // The asm insn must be a may_goto insn.
    SmallVector<StringRef, 4> AsmOpPieces;
    SplitString(AsmPieces[0], AsmOpPieces, " ");
    if (AsmOpPieces.size() != 2 || AsmOpPieces[0] != "may_goto")
      continue;
    // Enforce the format of 'may_goto <label>'.
    if (AsmOpPieces[1] != "${0:l}" && AsmOpPieces[1] != "$0")
      continue;

    // Get the may_goto branch target; only a branch to the layout
    // successor is a no-op we can delete.
    MachineOperand &MO = MI.getOperand(InlineAsm::MIOp_FirstOperand + 1);
    if (!MO.isMBB() || MO.getMBB() != Prev_MBB)
      continue;

    Changed = true;
    if (Curr_MBB->begin() == MI) {
      // Single 'may_goto' insn in the same basic block: drop the whole
      // block and point all predecessors at the fall-through block.
      Curr_MBB->removeSuccessor(Prev_MBB);
      for (MachineBasicBlock *Pred : Curr_MBB->predecessors())
        Pred->replaceSuccessor(Curr_MBB, Prev_MBB);
      Curr_MBB->eraseFromParent();
      // Keep the walk consistent: the surviving block takes Curr_MBB's role.
      Curr_MBB = Prev_MBB;
    } else {
      // Remove 'may_goto' insn.
      MI.eraseFromParent();
    }
  }

  return Changed;
}
738 
739 // If the last insn in a funciton is 'JAL &bpf_unreachable', let us add an
740 // 'exit' insn after that insn. This will ensure no fallthrough at the last
741 // insn, making kernel verification easier.
742 bool BPFMIPreEmitPeephole::addExitAfterUnreachable() {
743   MachineBasicBlock &MBB = MF->back();
744   MachineInstr &MI = MBB.back();
745   if (MI.getOpcode() != BPF::JAL || !MI.getOperand(0).isGlobal() ||
746       MI.getOperand(0).getGlobal()->getName() != BPF_TRAP)
747     return false;
748 
749   BuildMI(&MBB, MI.getDebugLoc(), TII->get(BPF::RET));
750   return true;
751 }
752 
753 } // end default namespace
754 
// NOTE(review): the "pemit" spelling below looks like a long-standing
// typo, but this string is the pass's registered argument name —
// presumably user-visible — so it is deliberately kept as-is; confirm
// before renaming.
INITIALIZE_PASS(BPFMIPreEmitPeephole, "bpf-mi-pemit-peephole",
                "BPF PreEmit Peephole Optimization", false, false)

char BPFMIPreEmitPeephole::ID = 0;
// Factory used by the BPF target to add this pass to the pipeline.
FunctionPass* llvm::createBPFMIPreEmitPeepholePass()
{
  return new BPFMIPreEmitPeephole();
}
763