xref: /freebsd/contrib/llvm-project/llvm/lib/Target/BPF/BPFMIPeephole.cpp (revision 5f757f3ff9144b609b3c433dfd370cc6bdc191ad)
1 //===-------------- BPFMIPeephole.cpp - MI Peephole Cleanups  -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass performs peephole optimizations to cleanup ugly code sequences at
10 // MachineInstruction layer.
11 //
12 // Currently, there are two optimizations implemented:
//  - One pre-RA MachineSSA pass to eliminate type promotion sequences, i.e.
//    those that zero-extend 32-bit subregisters to 64-bit registers, when the
//    compiler can prove the subregister is defined by a 32-bit operation, in
//    which case the upper half of the underlying 64-bit register has already
//    been zeroed implicitly.
18 //
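//  - One post-RA PreEmit pass to do final cleanup on some redundant
//    instructions generated due to bad RA on subregisters and, when the
//    subtarget supports gotol, to rewrite branches whose estimated target
//    offset does not fit in 16 bits.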
21 //===----------------------------------------------------------------------===//
22 
23 #include "BPF.h"
24 #include "BPFInstrInfo.h"
25 #include "BPFTargetMachine.h"
26 #include "llvm/ADT/Statistic.h"
27 #include "llvm/CodeGen/MachineFunctionPass.h"
28 #include "llvm/CodeGen/MachineInstrBuilder.h"
29 #include "llvm/CodeGen/MachineRegisterInfo.h"
30 #include "llvm/Support/Debug.h"
31 #include <set>
32 
33 using namespace llvm;
34 
35 #define DEBUG_TYPE "bpf-mi-zext-elim"
36 
37 static cl::opt<int> GotolAbsLowBound("gotol-abs-low-bound", cl::Hidden,
38   cl::init(INT16_MAX >> 1), cl::desc("Specify gotol lower bound"));
39 
40 STATISTIC(ZExtElemNum, "Number of zero extension shifts eliminated");
41 
42 namespace {
43 
44 struct BPFMIPeephole : public MachineFunctionPass {
45 
46   static char ID;
47   const BPFInstrInfo *TII;
48   MachineFunction *MF;
49   MachineRegisterInfo *MRI;
50 
51   BPFMIPeephole() : MachineFunctionPass(ID) {
52     initializeBPFMIPeepholePass(*PassRegistry::getPassRegistry());
53   }
54 
55 private:
56   // Initialize class variables.
57   void initialize(MachineFunction &MFParm);
58 
59   bool isCopyFrom32Def(MachineInstr *CopyMI);
60   bool isInsnFrom32Def(MachineInstr *DefInsn);
61   bool isPhiFrom32Def(MachineInstr *MovMI);
62   bool isMovFrom32Def(MachineInstr *MovMI);
63   bool eliminateZExtSeq();
64   bool eliminateZExt();
65 
66   std::set<MachineInstr *> PhiInsns;
67 
68 public:
69 
70   // Main entry point for this pass.
71   bool runOnMachineFunction(MachineFunction &MF) override {
72     if (skipFunction(MF.getFunction()))
73       return false;
74 
75     initialize(MF);
76 
77     // First try to eliminate (zext, lshift, rshift) and then
78     // try to eliminate zext.
79     bool ZExtSeqExist, ZExtExist;
80     ZExtSeqExist = eliminateZExtSeq();
81     ZExtExist = eliminateZExt();
82     return ZExtSeqExist || ZExtExist;
83   }
84 };
85 
86 // Initialize class variables.
87 void BPFMIPeephole::initialize(MachineFunction &MFParm) {
88   MF = &MFParm;
89   MRI = &MF->getRegInfo();
90   TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
91   LLVM_DEBUG(dbgs() << "*** BPF MachineSSA ZEXT Elim peephole pass ***\n\n");
92 }
93 
94 bool BPFMIPeephole::isCopyFrom32Def(MachineInstr *CopyMI)
95 {
96   MachineOperand &opnd = CopyMI->getOperand(1);
97 
98   if (!opnd.isReg())
99     return false;
100 
101   // Return false if getting value from a 32bit physical register.
102   // Most likely, this physical register is aliased to
103   // function call return value or current function parameters.
104   Register Reg = opnd.getReg();
105   if (!Reg.isVirtual())
106     return false;
107 
108   if (MRI->getRegClass(Reg) == &BPF::GPRRegClass)
109     return false;
110 
111   MachineInstr *DefInsn = MRI->getVRegDef(Reg);
112   if (!isInsnFrom32Def(DefInsn))
113     return false;
114 
115   return true;
116 }
117 
118 bool BPFMIPeephole::isPhiFrom32Def(MachineInstr *PhiMI)
119 {
120   for (unsigned i = 1, e = PhiMI->getNumOperands(); i < e; i += 2) {
121     MachineOperand &opnd = PhiMI->getOperand(i);
122 
123     if (!opnd.isReg())
124       return false;
125 
126     MachineInstr *PhiDef = MRI->getVRegDef(opnd.getReg());
127     if (!PhiDef)
128       return false;
129     if (PhiDef->isPHI()) {
130       if (!PhiInsns.insert(PhiDef).second)
131         return false;
132       if (!isPhiFrom32Def(PhiDef))
133         return false;
134     }
135     if (PhiDef->getOpcode() == BPF::COPY && !isCopyFrom32Def(PhiDef))
136       return false;
137   }
138 
139   return true;
140 }
141 
142 // The \p DefInsn instruction defines a virtual register.
143 bool BPFMIPeephole::isInsnFrom32Def(MachineInstr *DefInsn)
144 {
145   if (!DefInsn)
146     return false;
147 
148   if (DefInsn->isPHI()) {
149     if (!PhiInsns.insert(DefInsn).second)
150       return false;
151     if (!isPhiFrom32Def(DefInsn))
152       return false;
153   } else if (DefInsn->getOpcode() == BPF::COPY) {
154     if (!isCopyFrom32Def(DefInsn))
155       return false;
156   }
157 
158   return true;
159 }
160 
161 bool BPFMIPeephole::isMovFrom32Def(MachineInstr *MovMI)
162 {
163   MachineInstr *DefInsn = MRI->getVRegDef(MovMI->getOperand(1).getReg());
164 
165   LLVM_DEBUG(dbgs() << "  Def of Mov Src:");
166   LLVM_DEBUG(DefInsn->dump());
167 
168   PhiInsns.clear();
169   if (!isInsnFrom32Def(DefInsn))
170     return false;
171 
172   LLVM_DEBUG(dbgs() << "  One ZExt elim sequence identified.\n");
173 
174   return true;
175 }
176 
177 bool BPFMIPeephole::eliminateZExtSeq() {
178   MachineInstr* ToErase = nullptr;
179   bool Eliminated = false;
180 
181   for (MachineBasicBlock &MBB : *MF) {
182     for (MachineInstr &MI : MBB) {
183       // If the previous instruction was marked for elimination, remove it now.
184       if (ToErase) {
185         ToErase->eraseFromParent();
186         ToErase = nullptr;
187       }
188 
189       // Eliminate the 32-bit to 64-bit zero extension sequence when possible.
190       //
191       //   MOV_32_64 rB, wA
192       //   SLL_ri    rB, rB, 32
193       //   SRL_ri    rB, rB, 32
194       if (MI.getOpcode() == BPF::SRL_ri &&
195           MI.getOperand(2).getImm() == 32) {
196         Register DstReg = MI.getOperand(0).getReg();
197         Register ShfReg = MI.getOperand(1).getReg();
198         MachineInstr *SllMI = MRI->getVRegDef(ShfReg);
199 
200         LLVM_DEBUG(dbgs() << "Starting SRL found:");
201         LLVM_DEBUG(MI.dump());
202 
203         if (!SllMI ||
204             SllMI->isPHI() ||
205             SllMI->getOpcode() != BPF::SLL_ri ||
206             SllMI->getOperand(2).getImm() != 32)
207           continue;
208 
209         LLVM_DEBUG(dbgs() << "  SLL found:");
210         LLVM_DEBUG(SllMI->dump());
211 
212         MachineInstr *MovMI = MRI->getVRegDef(SllMI->getOperand(1).getReg());
213         if (!MovMI ||
214             MovMI->isPHI() ||
215             MovMI->getOpcode() != BPF::MOV_32_64)
216           continue;
217 
218         LLVM_DEBUG(dbgs() << "  Type cast Mov found:");
219         LLVM_DEBUG(MovMI->dump());
220 
221         Register SubReg = MovMI->getOperand(1).getReg();
222         if (!isMovFrom32Def(MovMI)) {
223           LLVM_DEBUG(dbgs()
224                      << "  One ZExt elim sequence failed qualifying elim.\n");
225           continue;
226         }
227 
228         BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), DstReg)
229           .addImm(0).addReg(SubReg).addImm(BPF::sub_32);
230 
231         SllMI->eraseFromParent();
232         MovMI->eraseFromParent();
233         // MI is the right shift, we can't erase it in it's own iteration.
234         // Mark it to ToErase, and erase in the next iteration.
235         ToErase = &MI;
236         ZExtElemNum++;
237         Eliminated = true;
238       }
239     }
240   }
241 
242   return Eliminated;
243 }
244 
245 bool BPFMIPeephole::eliminateZExt() {
246   MachineInstr* ToErase = nullptr;
247   bool Eliminated = false;
248 
249   for (MachineBasicBlock &MBB : *MF) {
250     for (MachineInstr &MI : MBB) {
251       // If the previous instruction was marked for elimination, remove it now.
252       if (ToErase) {
253         ToErase->eraseFromParent();
254         ToErase = nullptr;
255       }
256 
257       if (MI.getOpcode() != BPF::MOV_32_64)
258         continue;
259 
260       // Eliminate MOV_32_64 if possible.
261       //   MOV_32_64 rA, wB
262       //
263       // If wB has been zero extended, replace it with a SUBREG_TO_REG.
264       // This is to workaround BPF programs where pkt->{data, data_end}
265       // is encoded as u32, but actually the verifier populates them
266       // as 64bit pointer. The MOV_32_64 will zero out the top 32 bits.
267       LLVM_DEBUG(dbgs() << "Candidate MOV_32_64 instruction:");
268       LLVM_DEBUG(MI.dump());
269 
270       if (!isMovFrom32Def(&MI))
271         continue;
272 
273       LLVM_DEBUG(dbgs() << "Removing the MOV_32_64 instruction\n");
274 
275       Register dst = MI.getOperand(0).getReg();
276       Register src = MI.getOperand(1).getReg();
277 
278       // Build a SUBREG_TO_REG instruction.
279       BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), dst)
280         .addImm(0).addReg(src).addImm(BPF::sub_32);
281 
282       ToErase = &MI;
283       Eliminated = true;
284     }
285   }
286 
287   return Eliminated;
288 }
289 
290 } // end default namespace
291 
292 INITIALIZE_PASS(BPFMIPeephole, DEBUG_TYPE,
293                 "BPF MachineSSA Peephole Optimization For ZEXT Eliminate",
294                 false, false)
295 
296 char BPFMIPeephole::ID = 0;
297 FunctionPass* llvm::createBPFMIPeepholePass() { return new BPFMIPeephole(); }
298 
299 STATISTIC(RedundantMovElemNum, "Number of redundant moves eliminated");
300 
301 namespace {
302 
303 struct BPFMIPreEmitPeephole : public MachineFunctionPass {
304 
305   static char ID;
306   MachineFunction *MF;
307   const TargetRegisterInfo *TRI;
308   const BPFInstrInfo *TII;
309   bool SupportGotol;
310 
311   BPFMIPreEmitPeephole() : MachineFunctionPass(ID) {
312     initializeBPFMIPreEmitPeepholePass(*PassRegistry::getPassRegistry());
313   }
314 
315 private:
316   // Initialize class variables.
317   void initialize(MachineFunction &MFParm);
318 
319   bool in16BitRange(int Num);
320   bool eliminateRedundantMov();
321   bool adjustBranch();
322 
323 public:
324 
325   // Main entry point for this pass.
326   bool runOnMachineFunction(MachineFunction &MF) override {
327     if (skipFunction(MF.getFunction()))
328       return false;
329 
330     initialize(MF);
331 
332     bool Changed;
333     Changed = eliminateRedundantMov();
334     if (SupportGotol)
335       Changed = adjustBranch() || Changed;
336     return Changed;
337   }
338 };
339 
340 // Initialize class variables.
341 void BPFMIPreEmitPeephole::initialize(MachineFunction &MFParm) {
342   MF = &MFParm;
343   TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
344   TRI = MF->getSubtarget<BPFSubtarget>().getRegisterInfo();
345   SupportGotol = MF->getSubtarget<BPFSubtarget>().hasGotol();
346   LLVM_DEBUG(dbgs() << "*** BPF PreEmit peephole pass ***\n\n");
347 }
348 
349 bool BPFMIPreEmitPeephole::eliminateRedundantMov() {
350   MachineInstr* ToErase = nullptr;
351   bool Eliminated = false;
352 
353   for (MachineBasicBlock &MBB : *MF) {
354     for (MachineInstr &MI : MBB) {
355       // If the previous instruction was marked for elimination, remove it now.
356       if (ToErase) {
357         LLVM_DEBUG(dbgs() << "  Redundant Mov Eliminated:");
358         LLVM_DEBUG(ToErase->dump());
359         ToErase->eraseFromParent();
360         ToErase = nullptr;
361       }
362 
363       // Eliminate identical move:
364       //
365       //   MOV rA, rA
366       //
367       // Note that we cannot remove
368       //   MOV_32_64  rA, wA
369       //   MOV_rr_32  wA, wA
370       // as these two instructions having side effects, zeroing out
371       // top 32 bits of rA.
372       unsigned Opcode = MI.getOpcode();
373       if (Opcode == BPF::MOV_rr) {
374         Register dst = MI.getOperand(0).getReg();
375         Register src = MI.getOperand(1).getReg();
376 
377         if (dst != src)
378           continue;
379 
380         ToErase = &MI;
381         RedundantMovElemNum++;
382         Eliminated = true;
383       }
384     }
385   }
386 
387   return Eliminated;
388 }
389 
390 bool BPFMIPreEmitPeephole::in16BitRange(int Num) {
391   // Well, the cut-off is not precisely at 16bit range since
392   // new codes are added during the transformation. So let us
393   // a little bit conservative.
394   return Num >= -GotolAbsLowBound && Num <= GotolAbsLowBound;
395 }
396 
397 // Before cpu=v4, only 16bit branch target offset (-0x8000 to 0x7fff)
398 // is supported for both unconditional (JMP) and condition (JEQ, JSGT,
399 // etc.) branches. In certain cases, e.g., full unrolling, the branch
400 // target offset might exceed 16bit range. If this happens, the llvm
401 // will generate incorrect code as the offset is truncated to 16bit.
402 //
403 // To fix this rare case, a new insn JMPL is introduced. This new
404 // insn supports supports 32bit branch target offset. The compiler
405 // does not use this insn during insn selection. Rather, BPF backend
406 // will estimate the branch target offset and do JMP -> JMPL and
407 // JEQ -> JEQ + JMPL conversion if the estimated branch target offset
408 // is beyond 16bit.
409 bool BPFMIPreEmitPeephole::adjustBranch() {
410   bool Changed = false;
411   int CurrNumInsns = 0;
412   DenseMap<MachineBasicBlock *, int> SoFarNumInsns;
413   DenseMap<MachineBasicBlock *, MachineBasicBlock *> FollowThroughBB;
414   std::vector<MachineBasicBlock *> MBBs;
415 
416   MachineBasicBlock *PrevBB = nullptr;
417   for (MachineBasicBlock &MBB : *MF) {
418     // MBB.size() is the number of insns in this basic block, including some
419     // debug info, e.g., DEBUG_VALUE, so we may over-count a little bit.
420     // Typically we have way more normal insns than DEBUG_VALUE insns.
421     // Also, if we indeed need to convert conditional branch like JEQ to
422     // JEQ + JMPL, we actually introduced some new insns like below.
423     CurrNumInsns += (int)MBB.size();
424     SoFarNumInsns[&MBB] = CurrNumInsns;
425     if (PrevBB != nullptr)
426       FollowThroughBB[PrevBB] = &MBB;
427     PrevBB = &MBB;
428     // A list of original BBs to make later traveral easier.
429     MBBs.push_back(&MBB);
430   }
431   FollowThroughBB[PrevBB] = nullptr;
432 
433   for (unsigned i = 0; i < MBBs.size(); i++) {
434     // We have four cases here:
435     //  (1). no terminator, simple follow through.
436     //  (2). jmp to another bb.
437     //  (3). conditional jmp to another bb or follow through.
438     //  (4). conditional jmp followed by an unconditional jmp.
439     MachineInstr *CondJmp = nullptr, *UncondJmp = nullptr;
440 
441     MachineBasicBlock *MBB = MBBs[i];
442     for (MachineInstr &Term : MBB->terminators()) {
443       if (Term.isConditionalBranch()) {
444         assert(CondJmp == nullptr);
445         CondJmp = &Term;
446       } else if (Term.isUnconditionalBranch()) {
447         assert(UncondJmp == nullptr);
448         UncondJmp = &Term;
449       }
450     }
451 
452     // (1). no terminator, simple follow through.
453     if (!CondJmp && !UncondJmp)
454       continue;
455 
456     MachineBasicBlock *CondTargetBB, *JmpBB;
457     CurrNumInsns = SoFarNumInsns[MBB];
458 
459     // (2). jmp to another bb.
460     if (!CondJmp && UncondJmp) {
461       JmpBB = UncondJmp->getOperand(0).getMBB();
462       if (in16BitRange(SoFarNumInsns[JmpBB] - JmpBB->size() - CurrNumInsns))
463         continue;
464 
465       // replace this insn as a JMPL.
466       BuildMI(MBB, UncondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(JmpBB);
467       UncondJmp->eraseFromParent();
468       Changed = true;
469       continue;
470     }
471 
472     const BasicBlock *TermBB = MBB->getBasicBlock();
473     int Dist;
474 
475     // (3). conditional jmp to another bb or follow through.
476     if (!UncondJmp) {
477       CondTargetBB = CondJmp->getOperand(2).getMBB();
478       MachineBasicBlock *FollowBB = FollowThroughBB[MBB];
479       Dist = SoFarNumInsns[CondTargetBB] - CondTargetBB->size() - CurrNumInsns;
480       if (in16BitRange(Dist))
481         continue;
482 
483       // We have
484       //   B2: ...
485       //       if (cond) goto B5
486       //   B3: ...
487       // where B2 -> B5 is beyond 16bit range.
488       //
489       // We do not have 32bit cond jmp insn. So we try to do
490       // the following.
491       //   B2:     ...
492       //           if (cond) goto New_B1
493       //   New_B0  goto B3
494       //   New_B1: gotol B5
495       //   B3: ...
496       // Basically two new basic blocks are created.
497       MachineBasicBlock *New_B0 = MF->CreateMachineBasicBlock(TermBB);
498       MachineBasicBlock *New_B1 = MF->CreateMachineBasicBlock(TermBB);
499 
500       // Insert New_B0 and New_B1 into function block list.
501       MachineFunction::iterator MBB_I  = ++MBB->getIterator();
502       MF->insert(MBB_I, New_B0);
503       MF->insert(MBB_I, New_B1);
504 
505       // replace B2 cond jump
506       if (CondJmp->getOperand(1).isReg())
507         BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
508             .addReg(CondJmp->getOperand(0).getReg())
509             .addReg(CondJmp->getOperand(1).getReg())
510             .addMBB(New_B1);
511       else
512         BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
513             .addReg(CondJmp->getOperand(0).getReg())
514             .addImm(CondJmp->getOperand(1).getImm())
515             .addMBB(New_B1);
516 
517       // it is possible that CondTargetBB and FollowBB are the same. But the
518       // above Dist checking should already filtered this case.
519       MBB->removeSuccessor(CondTargetBB);
520       MBB->removeSuccessor(FollowBB);
521       MBB->addSuccessor(New_B0);
522       MBB->addSuccessor(New_B1);
523 
524       // Populate insns in New_B0 and New_B1.
525       BuildMI(New_B0, CondJmp->getDebugLoc(), TII->get(BPF::JMP)).addMBB(FollowBB);
526       BuildMI(New_B1, CondJmp->getDebugLoc(), TII->get(BPF::JMPL))
527           .addMBB(CondTargetBB);
528 
529       New_B0->addSuccessor(FollowBB);
530       New_B1->addSuccessor(CondTargetBB);
531       CondJmp->eraseFromParent();
532       Changed = true;
533       continue;
534     }
535 
536     //  (4). conditional jmp followed by an unconditional jmp.
537     CondTargetBB = CondJmp->getOperand(2).getMBB();
538     JmpBB = UncondJmp->getOperand(0).getMBB();
539 
540     // We have
541     //   B2: ...
542     //       if (cond) goto B5
543     //       JMP B7
544     //   B3: ...
545     //
546     // If only B2->B5 is out of 16bit range, we can do
547     //   B2: ...
548     //       if (cond) goto new_B
549     //       JMP B7
550     //   New_B: gotol B5
551     //   B3: ...
552     //
553     // If only 'JMP B7' is out of 16bit range, we can replace
554     // 'JMP B7' with 'JMPL B7'.
555     //
556     // If both B2->B5 and 'JMP B7' is out of range, just do
557     // both the above transformations.
558     Dist = SoFarNumInsns[CondTargetBB] - CondTargetBB->size() - CurrNumInsns;
559     if (!in16BitRange(Dist)) {
560       MachineBasicBlock *New_B = MF->CreateMachineBasicBlock(TermBB);
561 
562       // Insert New_B0 into function block list.
563       MF->insert(++MBB->getIterator(), New_B);
564 
565       // replace B2 cond jump
566       if (CondJmp->getOperand(1).isReg())
567         BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
568             .addReg(CondJmp->getOperand(0).getReg())
569             .addReg(CondJmp->getOperand(1).getReg())
570             .addMBB(New_B);
571       else
572         BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
573             .addReg(CondJmp->getOperand(0).getReg())
574             .addImm(CondJmp->getOperand(1).getImm())
575             .addMBB(New_B);
576 
577       if (CondTargetBB != JmpBB)
578         MBB->removeSuccessor(CondTargetBB);
579       MBB->addSuccessor(New_B);
580 
581       // Populate insn in New_B.
582       BuildMI(New_B, CondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(CondTargetBB);
583 
584       New_B->addSuccessor(CondTargetBB);
585       CondJmp->eraseFromParent();
586       Changed = true;
587     }
588 
589     if (!in16BitRange(SoFarNumInsns[JmpBB] - CurrNumInsns)) {
590       BuildMI(MBB, UncondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(JmpBB);
591       UncondJmp->eraseFromParent();
592       Changed = true;
593     }
594   }
595 
596   return Changed;
597 }
598 
599 } // end default namespace
600 
601 INITIALIZE_PASS(BPFMIPreEmitPeephole, "bpf-mi-pemit-peephole",
602                 "BPF PreEmit Peephole Optimization", false, false)
603 
604 char BPFMIPreEmitPeephole::ID = 0;
605 FunctionPass* llvm::createBPFMIPreEmitPeepholePass()
606 {
607   return new BPFMIPreEmitPeephole();
608 }
609