xref: /freebsd/contrib/llvm-project/llvm/lib/Target/BPF/BPFMIPeephole.cpp (revision 66fd12cf4896eb08ad8e7a2627537f84ead84dd3)
1 //===-------------- BPFMIPeephole.cpp - MI Peephole Cleanups  -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements peephole passes that clean up ugly code sequences at
// the MachineInstr layer.
//
// Currently, three optimizations are implemented:
//  - A pre-RA MachineSSA pass that eliminates type-promotion sequences which
//    zero-extend 32-bit subregisters to 64-bit registers, when the compiler
//    can prove the subregister is defined by a 32-bit operation, in which
//    case the upper half of the underlying 64-bit register was already
//    zeroed implicitly.
//
//  - A post-RA PreEmit pass that does a final cleanup of some redundant
//    instructions generated due to bad RA on subregisters.
//
//  - A pre-RA MachineSSA pass that eliminates truncations (AND masks or
//    SLL/SRL pairs) of values loaded with a matching width, since BPF loads
//    already zero-extend.
21 //===----------------------------------------------------------------------===//
22 
#include "BPF.h"
#include "BPFInstrInfo.h"
#include "BPFTargetMachine.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/Support/Debug.h"
#include <set>
32 
33 using namespace llvm;
34 
35 #define DEBUG_TYPE "bpf-mi-zext-elim"
36 
37 STATISTIC(ZExtElemNum, "Number of zero extension shifts eliminated");
38 
39 namespace {
40 
41 struct BPFMIPeephole : public MachineFunctionPass {
42 
43   static char ID;
44   const BPFInstrInfo *TII;
45   MachineFunction *MF;
46   MachineRegisterInfo *MRI;
47 
48   BPFMIPeephole() : MachineFunctionPass(ID) {
49     initializeBPFMIPeepholePass(*PassRegistry::getPassRegistry());
50   }
51 
52 private:
53   // Initialize class variables.
54   void initialize(MachineFunction &MFParm);
55 
56   bool isCopyFrom32Def(MachineInstr *CopyMI);
57   bool isInsnFrom32Def(MachineInstr *DefInsn);
58   bool isPhiFrom32Def(MachineInstr *MovMI);
59   bool isMovFrom32Def(MachineInstr *MovMI);
60   bool eliminateZExtSeq();
61   bool eliminateZExt();
62 
63   std::set<MachineInstr *> PhiInsns;
64 
65 public:
66 
67   // Main entry point for this pass.
68   bool runOnMachineFunction(MachineFunction &MF) override {
69     if (skipFunction(MF.getFunction()))
70       return false;
71 
72     initialize(MF);
73 
74     // First try to eliminate (zext, lshift, rshift) and then
75     // try to eliminate zext.
76     bool ZExtSeqExist, ZExtExist;
77     ZExtSeqExist = eliminateZExtSeq();
78     ZExtExist = eliminateZExt();
79     return ZExtSeqExist || ZExtExist;
80   }
81 };
82 
83 // Initialize class variables.
84 void BPFMIPeephole::initialize(MachineFunction &MFParm) {
85   MF = &MFParm;
86   MRI = &MF->getRegInfo();
87   TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
88   LLVM_DEBUG(dbgs() << "*** BPF MachineSSA ZEXT Elim peephole pass ***\n\n");
89 }
90 
91 bool BPFMIPeephole::isCopyFrom32Def(MachineInstr *CopyMI)
92 {
93   MachineOperand &opnd = CopyMI->getOperand(1);
94 
95   if (!opnd.isReg())
96     return false;
97 
98   // Return false if getting value from a 32bit physical register.
99   // Most likely, this physical register is aliased to
100   // function call return value or current function parameters.
101   Register Reg = opnd.getReg();
102   if (!Reg.isVirtual())
103     return false;
104 
105   if (MRI->getRegClass(Reg) == &BPF::GPRRegClass)
106     return false;
107 
108   MachineInstr *DefInsn = MRI->getVRegDef(Reg);
109   if (!isInsnFrom32Def(DefInsn))
110     return false;
111 
112   return true;
113 }
114 
115 bool BPFMIPeephole::isPhiFrom32Def(MachineInstr *PhiMI)
116 {
117   for (unsigned i = 1, e = PhiMI->getNumOperands(); i < e; i += 2) {
118     MachineOperand &opnd = PhiMI->getOperand(i);
119 
120     if (!opnd.isReg())
121       return false;
122 
123     MachineInstr *PhiDef = MRI->getVRegDef(opnd.getReg());
124     if (!PhiDef)
125       return false;
126     if (PhiDef->isPHI()) {
127       if (!PhiInsns.insert(PhiDef).second)
128         return false;
129       if (!isPhiFrom32Def(PhiDef))
130         return false;
131     }
132     if (PhiDef->getOpcode() == BPF::COPY && !isCopyFrom32Def(PhiDef))
133       return false;
134   }
135 
136   return true;
137 }
138 
139 // The \p DefInsn instruction defines a virtual register.
140 bool BPFMIPeephole::isInsnFrom32Def(MachineInstr *DefInsn)
141 {
142   if (!DefInsn)
143     return false;
144 
145   if (DefInsn->isPHI()) {
146     if (!PhiInsns.insert(DefInsn).second)
147       return false;
148     if (!isPhiFrom32Def(DefInsn))
149       return false;
150   } else if (DefInsn->getOpcode() == BPF::COPY) {
151     if (!isCopyFrom32Def(DefInsn))
152       return false;
153   }
154 
155   return true;
156 }
157 
158 bool BPFMIPeephole::isMovFrom32Def(MachineInstr *MovMI)
159 {
160   MachineInstr *DefInsn = MRI->getVRegDef(MovMI->getOperand(1).getReg());
161 
162   LLVM_DEBUG(dbgs() << "  Def of Mov Src:");
163   LLVM_DEBUG(DefInsn->dump());
164 
165   PhiInsns.clear();
166   if (!isInsnFrom32Def(DefInsn))
167     return false;
168 
169   LLVM_DEBUG(dbgs() << "  One ZExt elim sequence identified.\n");
170 
171   return true;
172 }
173 
174 bool BPFMIPeephole::eliminateZExtSeq() {
175   MachineInstr* ToErase = nullptr;
176   bool Eliminated = false;
177 
178   for (MachineBasicBlock &MBB : *MF) {
179     for (MachineInstr &MI : MBB) {
180       // If the previous instruction was marked for elimination, remove it now.
181       if (ToErase) {
182         ToErase->eraseFromParent();
183         ToErase = nullptr;
184       }
185 
186       // Eliminate the 32-bit to 64-bit zero extension sequence when possible.
187       //
188       //   MOV_32_64 rB, wA
189       //   SLL_ri    rB, rB, 32
190       //   SRL_ri    rB, rB, 32
191       if (MI.getOpcode() == BPF::SRL_ri &&
192           MI.getOperand(2).getImm() == 32) {
193         Register DstReg = MI.getOperand(0).getReg();
194         Register ShfReg = MI.getOperand(1).getReg();
195         MachineInstr *SllMI = MRI->getVRegDef(ShfReg);
196 
197         LLVM_DEBUG(dbgs() << "Starting SRL found:");
198         LLVM_DEBUG(MI.dump());
199 
200         if (!SllMI ||
201             SllMI->isPHI() ||
202             SllMI->getOpcode() != BPF::SLL_ri ||
203             SllMI->getOperand(2).getImm() != 32)
204           continue;
205 
206         LLVM_DEBUG(dbgs() << "  SLL found:");
207         LLVM_DEBUG(SllMI->dump());
208 
209         MachineInstr *MovMI = MRI->getVRegDef(SllMI->getOperand(1).getReg());
210         if (!MovMI ||
211             MovMI->isPHI() ||
212             MovMI->getOpcode() != BPF::MOV_32_64)
213           continue;
214 
215         LLVM_DEBUG(dbgs() << "  Type cast Mov found:");
216         LLVM_DEBUG(MovMI->dump());
217 
218         Register SubReg = MovMI->getOperand(1).getReg();
219         if (!isMovFrom32Def(MovMI)) {
220           LLVM_DEBUG(dbgs()
221                      << "  One ZExt elim sequence failed qualifying elim.\n");
222           continue;
223         }
224 
225         BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), DstReg)
226           .addImm(0).addReg(SubReg).addImm(BPF::sub_32);
227 
228         SllMI->eraseFromParent();
229         MovMI->eraseFromParent();
230         // MI is the right shift, we can't erase it in it's own iteration.
231         // Mark it to ToErase, and erase in the next iteration.
232         ToErase = &MI;
233         ZExtElemNum++;
234         Eliminated = true;
235       }
236     }
237   }
238 
239   return Eliminated;
240 }
241 
242 bool BPFMIPeephole::eliminateZExt() {
243   MachineInstr* ToErase = nullptr;
244   bool Eliminated = false;
245 
246   for (MachineBasicBlock &MBB : *MF) {
247     for (MachineInstr &MI : MBB) {
248       // If the previous instruction was marked for elimination, remove it now.
249       if (ToErase) {
250         ToErase->eraseFromParent();
251         ToErase = nullptr;
252       }
253 
254       if (MI.getOpcode() != BPF::MOV_32_64)
255         continue;
256 
257       // Eliminate MOV_32_64 if possible.
258       //   MOV_32_64 rA, wB
259       //
260       // If wB has been zero extended, replace it with a SUBREG_TO_REG.
261       // This is to workaround BPF programs where pkt->{data, data_end}
262       // is encoded as u32, but actually the verifier populates them
263       // as 64bit pointer. The MOV_32_64 will zero out the top 32 bits.
264       LLVM_DEBUG(dbgs() << "Candidate MOV_32_64 instruction:");
265       LLVM_DEBUG(MI.dump());
266 
267       if (!isMovFrom32Def(&MI))
268         continue;
269 
270       LLVM_DEBUG(dbgs() << "Removing the MOV_32_64 instruction\n");
271 
272       Register dst = MI.getOperand(0).getReg();
273       Register src = MI.getOperand(1).getReg();
274 
275       // Build a SUBREG_TO_REG instruction.
276       BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), dst)
277         .addImm(0).addReg(src).addImm(BPF::sub_32);
278 
279       ToErase = &MI;
280       Eliminated = true;
281     }
282   }
283 
284   return Eliminated;
285 }
286 
287 } // end default namespace
288 
289 INITIALIZE_PASS(BPFMIPeephole, DEBUG_TYPE,
290                 "BPF MachineSSA Peephole Optimization For ZEXT Eliminate",
291                 false, false)
292 
293 char BPFMIPeephole::ID = 0;
294 FunctionPass* llvm::createBPFMIPeepholePass() { return new BPFMIPeephole(); }
295 
296 STATISTIC(RedundantMovElemNum, "Number of redundant moves eliminated");
297 
298 namespace {
299 
300 struct BPFMIPreEmitPeephole : public MachineFunctionPass {
301 
302   static char ID;
303   MachineFunction *MF;
304   const TargetRegisterInfo *TRI;
305 
306   BPFMIPreEmitPeephole() : MachineFunctionPass(ID) {
307     initializeBPFMIPreEmitPeepholePass(*PassRegistry::getPassRegistry());
308   }
309 
310 private:
311   // Initialize class variables.
312   void initialize(MachineFunction &MFParm);
313 
314   bool eliminateRedundantMov();
315 
316 public:
317 
318   // Main entry point for this pass.
319   bool runOnMachineFunction(MachineFunction &MF) override {
320     if (skipFunction(MF.getFunction()))
321       return false;
322 
323     initialize(MF);
324 
325     return eliminateRedundantMov();
326   }
327 };
328 
329 // Initialize class variables.
330 void BPFMIPreEmitPeephole::initialize(MachineFunction &MFParm) {
331   MF = &MFParm;
332   TRI = MF->getSubtarget<BPFSubtarget>().getRegisterInfo();
333   LLVM_DEBUG(dbgs() << "*** BPF PreEmit peephole pass ***\n\n");
334 }
335 
336 bool BPFMIPreEmitPeephole::eliminateRedundantMov() {
337   MachineInstr* ToErase = nullptr;
338   bool Eliminated = false;
339 
340   for (MachineBasicBlock &MBB : *MF) {
341     for (MachineInstr &MI : MBB) {
342       // If the previous instruction was marked for elimination, remove it now.
343       if (ToErase) {
344         LLVM_DEBUG(dbgs() << "  Redundant Mov Eliminated:");
345         LLVM_DEBUG(ToErase->dump());
346         ToErase->eraseFromParent();
347         ToErase = nullptr;
348       }
349 
350       // Eliminate identical move:
351       //
352       //   MOV rA, rA
353       //
354       // Note that we cannot remove
355       //   MOV_32_64  rA, wA
356       //   MOV_rr_32  wA, wA
357       // as these two instructions having side effects, zeroing out
358       // top 32 bits of rA.
359       unsigned Opcode = MI.getOpcode();
360       if (Opcode == BPF::MOV_rr) {
361         Register dst = MI.getOperand(0).getReg();
362         Register src = MI.getOperand(1).getReg();
363 
364         if (dst != src)
365           continue;
366 
367         ToErase = &MI;
368         RedundantMovElemNum++;
369         Eliminated = true;
370       }
371     }
372   }
373 
374   return Eliminated;
375 }
376 
377 } // end default namespace
378 
379 INITIALIZE_PASS(BPFMIPreEmitPeephole, "bpf-mi-pemit-peephole",
380                 "BPF PreEmit Peephole Optimization", false, false)
381 
382 char BPFMIPreEmitPeephole::ID = 0;
383 FunctionPass* llvm::createBPFMIPreEmitPeepholePass()
384 {
385   return new BPFMIPreEmitPeephole();
386 }
387 
388 STATISTIC(TruncElemNum, "Number of truncation eliminated");
389 
390 namespace {
391 
392 struct BPFMIPeepholeTruncElim : public MachineFunctionPass {
393 
394   static char ID;
395   const BPFInstrInfo *TII;
396   MachineFunction *MF;
397   MachineRegisterInfo *MRI;
398 
399   BPFMIPeepholeTruncElim() : MachineFunctionPass(ID) {
400     initializeBPFMIPeepholeTruncElimPass(*PassRegistry::getPassRegistry());
401   }
402 
403 private:
404   // Initialize class variables.
405   void initialize(MachineFunction &MFParm);
406 
407   bool eliminateTruncSeq();
408 
409 public:
410 
411   // Main entry point for this pass.
412   bool runOnMachineFunction(MachineFunction &MF) override {
413     if (skipFunction(MF.getFunction()))
414       return false;
415 
416     initialize(MF);
417 
418     return eliminateTruncSeq();
419   }
420 };
421 
422 static bool TruncSizeCompatible(int TruncSize, unsigned opcode)
423 {
424   if (TruncSize == 1)
425     return opcode == BPF::LDB || opcode == BPF::LDB32;
426 
427   if (TruncSize == 2)
428     return opcode == BPF::LDH || opcode == BPF::LDH32;
429 
430   if (TruncSize == 4)
431     return opcode == BPF::LDW || opcode == BPF::LDW32;
432 
433   return false;
434 }
435 
436 // Initialize class variables.
437 void BPFMIPeepholeTruncElim::initialize(MachineFunction &MFParm) {
438   MF = &MFParm;
439   MRI = &MF->getRegInfo();
440   TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
441   LLVM_DEBUG(dbgs() << "*** BPF MachineSSA TRUNC Elim peephole pass ***\n\n");
442 }
443 
444 // Reg truncating is often the result of 8/16/32bit->64bit or
445 // 8/16bit->32bit conversion. If the reg value is loaded with
446 // masked byte width, the AND operation can be removed since
447 // BPF LOAD already has zero extension.
448 //
449 // This also solved a correctness issue.
450 // In BPF socket-related program, e.g., __sk_buff->{data, data_end}
451 // are 32-bit registers, but later on, kernel verifier will rewrite
452 // it with 64-bit value. Therefore, truncating the value after the
453 // load will result in incorrect code.
454 bool BPFMIPeepholeTruncElim::eliminateTruncSeq() {
455   MachineInstr* ToErase = nullptr;
456   bool Eliminated = false;
457 
458   for (MachineBasicBlock &MBB : *MF) {
459     for (MachineInstr &MI : MBB) {
460       // The second insn to remove if the eliminate candidate is a pair.
461       MachineInstr *MI2 = nullptr;
462       Register DstReg, SrcReg;
463       MachineInstr *DefMI;
464       int TruncSize = -1;
465 
466       // If the previous instruction was marked for elimination, remove it now.
467       if (ToErase) {
468         ToErase->eraseFromParent();
469         ToErase = nullptr;
470       }
471 
472       // AND A, 0xFFFFFFFF will be turned into SLL/SRL pair due to immediate
473       // for BPF ANDI is i32, and this case only happens on ALU64.
474       if (MI.getOpcode() == BPF::SRL_ri &&
475           MI.getOperand(2).getImm() == 32) {
476         SrcReg = MI.getOperand(1).getReg();
477         if (!MRI->hasOneNonDBGUse(SrcReg))
478           continue;
479 
480         MI2 = MRI->getVRegDef(SrcReg);
481         DstReg = MI.getOperand(0).getReg();
482 
483         if (!MI2 ||
484             MI2->getOpcode() != BPF::SLL_ri ||
485             MI2->getOperand(2).getImm() != 32)
486           continue;
487 
488         // Update SrcReg.
489         SrcReg = MI2->getOperand(1).getReg();
490         DefMI = MRI->getVRegDef(SrcReg);
491         if (DefMI)
492           TruncSize = 4;
493       } else if (MI.getOpcode() == BPF::AND_ri ||
494                  MI.getOpcode() == BPF::AND_ri_32) {
495         SrcReg = MI.getOperand(1).getReg();
496         DstReg = MI.getOperand(0).getReg();
497         DefMI = MRI->getVRegDef(SrcReg);
498 
499         if (!DefMI)
500           continue;
501 
502         int64_t imm = MI.getOperand(2).getImm();
503         if (imm == 0xff)
504           TruncSize = 1;
505         else if (imm == 0xffff)
506           TruncSize = 2;
507       }
508 
509       if (TruncSize == -1)
510         continue;
511 
512       // The definition is PHI node, check all inputs.
513       if (DefMI->isPHI()) {
514         bool CheckFail = false;
515 
516         for (unsigned i = 1, e = DefMI->getNumOperands(); i < e; i += 2) {
517           MachineOperand &opnd = DefMI->getOperand(i);
518           if (!opnd.isReg()) {
519             CheckFail = true;
520             break;
521           }
522 
523           MachineInstr *PhiDef = MRI->getVRegDef(opnd.getReg());
524           if (!PhiDef || PhiDef->isPHI() ||
525               !TruncSizeCompatible(TruncSize, PhiDef->getOpcode())) {
526             CheckFail = true;
527             break;
528           }
529         }
530 
531         if (CheckFail)
532           continue;
533       } else if (!TruncSizeCompatible(TruncSize, DefMI->getOpcode())) {
534         continue;
535       }
536 
537       BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::MOV_rr), DstReg)
538               .addReg(SrcReg);
539 
540       if (MI2)
541         MI2->eraseFromParent();
542 
543       // Mark it to ToErase, and erase in the next iteration.
544       ToErase = &MI;
545       TruncElemNum++;
546       Eliminated = true;
547     }
548   }
549 
550   return Eliminated;
551 }
552 
553 } // end default namespace
554 
555 INITIALIZE_PASS(BPFMIPeepholeTruncElim, "bpf-mi-trunc-elim",
556                 "BPF MachineSSA Peephole Optimization For TRUNC Eliminate",
557                 false, false)
558 
559 char BPFMIPeepholeTruncElim::ID = 0;
560 FunctionPass* llvm::createBPFMIPeepholeTruncElimPass()
561 {
562   return new BPFMIPeepholeTruncElim();
563 }
564