xref: /freebsd/contrib/llvm-project/llvm/lib/Target/BPF/BPFAdjustOpt.cpp (revision 53120fbb68952b7d620c2c0e1cf05c5017fc1b27)
1 //===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Adjust optimization to make the code more kernel verifier friendly.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "BPF.h"
14 #include "BPFCORE.h"
15 #include "BPFTargetMachine.h"
16 #include "llvm/IR/Instruction.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/IntrinsicsBPF.h"
19 #include "llvm/IR/Module.h"
20 #include "llvm/IR/PatternMatch.h"
21 #include "llvm/IR/Type.h"
22 #include "llvm/IR/User.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/Pass.h"
25 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
26 
27 #define DEBUG_TYPE "bpf-adjust-opt"
28 
29 using namespace llvm;
30 using namespace llvm::PatternMatch;
31 
32 static cl::opt<bool>
33     DisableBPFserializeICMP("bpf-disable-serialize-icmp", cl::Hidden,
34                             cl::desc("BPF: Disable Serializing ICMP insns."),
35                             cl::init(false));
36 
37 static cl::opt<bool> DisableBPFavoidSpeculation(
38     "bpf-disable-avoid-speculation", cl::Hidden,
39     cl::desc("BPF: Disable Avoiding Speculative Code Motion."),
40     cl::init(false));
41 
42 namespace {
43 class BPFAdjustOptImpl {
44   struct PassThroughInfo {
45     Instruction *Input;
46     Instruction *UsedInst;
47     uint32_t OpIdx;
48     PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx)
49         : Input(I), UsedInst(U), OpIdx(Idx) {}
50   };
51 
52 public:
53   BPFAdjustOptImpl(Module *M) : M(M) {}
54 
55   bool run();
56 
57 private:
58   Module *M;
59   SmallVector<PassThroughInfo, 16> PassThroughs;
60 
61   bool adjustICmpToBuiltin();
62   void adjustBasicBlock(BasicBlock &BB);
63   bool serializeICMPCrossBB(BasicBlock &BB);
64   void adjustInst(Instruction &I);
65   bool serializeICMPInBB(Instruction &I);
66   bool avoidSpeculation(Instruction &I);
67   bool insertPassThrough();
68 };
69 
70 } // End anonymous namespace
71 
72 bool BPFAdjustOptImpl::run() {
73   bool Changed = adjustICmpToBuiltin();
74 
75   for (Function &F : *M)
76     for (auto &BB : F) {
77       adjustBasicBlock(BB);
78       for (auto &I : BB)
79         adjustInst(I);
80     }
81   return insertPassThrough() || Changed;
82 }
83 
84 // Commit acabad9ff6bf ("[InstCombine] try to canonicalize icmp with
85 // trunc op into mask and cmp") added a transformation to
86 // convert "(conv)a < power_2_const" to "a & <const>" in certain
87 // cases and bpf kernel verifier has to handle the resulted code
88 // conservatively and this may reject otherwise legitimate program.
89 // Here, we change related icmp code to a builtin which will
90 // be restored to original icmp code later to prevent that
91 // InstCombine transformatin.
92 bool BPFAdjustOptImpl::adjustICmpToBuiltin() {
93   bool Changed = false;
94   ICmpInst *ToBeDeleted = nullptr;
95   for (Function &F : *M)
96     for (auto &BB : F)
97       for (auto &I : BB) {
98         if (ToBeDeleted) {
99           ToBeDeleted->eraseFromParent();
100           ToBeDeleted = nullptr;
101         }
102 
103         auto *Icmp = dyn_cast<ICmpInst>(&I);
104         if (!Icmp)
105           continue;
106 
107         Value *Op0 = Icmp->getOperand(0);
108         if (!isa<TruncInst>(Op0))
109           continue;
110 
111         auto ConstOp1 = dyn_cast<ConstantInt>(Icmp->getOperand(1));
112         if (!ConstOp1)
113           continue;
114 
115         auto ConstOp1Val = ConstOp1->getValue().getZExtValue();
116         auto Op = Icmp->getPredicate();
117         if (Op == ICmpInst::ICMP_ULT || Op == ICmpInst::ICMP_UGE) {
118           if ((ConstOp1Val - 1) & ConstOp1Val)
119             continue;
120         } else if (Op == ICmpInst::ICMP_ULE || Op == ICmpInst::ICMP_UGT) {
121           if (ConstOp1Val & (ConstOp1Val + 1))
122             continue;
123         } else {
124           continue;
125         }
126 
127         Constant *Opcode =
128             ConstantInt::get(Type::getInt32Ty(BB.getContext()), Op);
129         Function *Fn = Intrinsic::getDeclaration(
130             M, Intrinsic::bpf_compare, {Op0->getType(), ConstOp1->getType()});
131         auto *NewInst = CallInst::Create(Fn, {Opcode, Op0, ConstOp1});
132         NewInst->insertBefore(&I);
133         Icmp->replaceAllUsesWith(NewInst);
134         Changed = true;
135         ToBeDeleted = Icmp;
136       }
137 
138   return Changed;
139 }
140 
141 bool BPFAdjustOptImpl::insertPassThrough() {
142   for (auto &Info : PassThroughs) {
143     auto *CI = BPFCoreSharedInfo::insertPassThrough(
144         M, Info.UsedInst->getParent(), Info.Input, Info.UsedInst);
145     Info.UsedInst->setOperand(Info.OpIdx, CI);
146   }
147 
148   return !PassThroughs.empty();
149 }
150 
151 // To avoid combining conditionals in the same basic block by
152 // instrcombine optimization.
153 bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) {
154   // For:
155   //   comp1 = icmp <opcode> ...;
156   //   comp2 = icmp <opcode> ...;
157   //   ... or comp1 comp2 ...
158   // changed to:
159   //   comp1 = icmp <opcode> ...;
160   //   comp2 = icmp <opcode> ...;
161   //   new_comp1 = __builtin_bpf_passthrough(seq_num, comp1)
162   //   ... or new_comp1 comp2 ...
163   Value *Op0, *Op1;
164   // Use LogicalOr (accept `or i1` as well as `select i1 Op0, true, Op1`)
165   if (!match(&I, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
166     return false;
167   auto *Icmp1 = dyn_cast<ICmpInst>(Op0);
168   if (!Icmp1)
169     return false;
170   auto *Icmp2 = dyn_cast<ICmpInst>(Op1);
171   if (!Icmp2)
172     return false;
173 
174   Value *Icmp1Op0 = Icmp1->getOperand(0);
175   Value *Icmp2Op0 = Icmp2->getOperand(0);
176   if (Icmp1Op0 != Icmp2Op0)
177     return false;
178 
179   // Now we got two icmp instructions which feed into
180   // an "or" instruction.
181   PassThroughInfo Info(Icmp1, &I, 0);
182   PassThroughs.push_back(Info);
183   return true;
184 }
185 
186 // To avoid combining conditionals in the same basic block by
187 // instrcombine optimization.
188 bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) {
189   // For:
190   //   B1:
191   //     comp1 = icmp <opcode> ...;
192   //     if (comp1) goto B2 else B3;
193   //   B2:
194   //     comp2 = icmp <opcode> ...;
195   //     if (comp2) goto B4 else B5;
196   //   B4:
197   //     ...
198   // changed to:
199   //   B1:
200   //     comp1 = icmp <opcode> ...;
201   //     comp1 = __builtin_bpf_passthrough(seq_num, comp1);
202   //     if (comp1) goto B2 else B3;
203   //   B2:
204   //     comp2 = icmp <opcode> ...;
205   //     if (comp2) goto B4 else B5;
206   //   B4:
207   //     ...
208 
209   // Check basic predecessors, if two of them (say B1, B2) are using
210   // icmp instructions to generate conditions and one is the predesessor
211   // of another (e.g., B1 is the predecessor of B2). Add a passthrough
212   // barrier after icmp inst of block B1.
213   BasicBlock *B2 = BB.getSinglePredecessor();
214   if (!B2)
215     return false;
216 
217   BasicBlock *B1 = B2->getSinglePredecessor();
218   if (!B1)
219     return false;
220 
221   Instruction *TI = B2->getTerminator();
222   auto *BI = dyn_cast<BranchInst>(TI);
223   if (!BI || !BI->isConditional())
224     return false;
225   auto *Cond = dyn_cast<ICmpInst>(BI->getCondition());
226   if (!Cond || B2->getFirstNonPHI() != Cond)
227     return false;
228   Value *B2Op0 = Cond->getOperand(0);
229   auto Cond2Op = Cond->getPredicate();
230 
231   TI = B1->getTerminator();
232   BI = dyn_cast<BranchInst>(TI);
233   if (!BI || !BI->isConditional())
234     return false;
235   Cond = dyn_cast<ICmpInst>(BI->getCondition());
236   if (!Cond)
237     return false;
238   Value *B1Op0 = Cond->getOperand(0);
239   auto Cond1Op = Cond->getPredicate();
240 
241   if (B1Op0 != B2Op0)
242     return false;
243 
244   if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) {
245     if (Cond2Op != ICmpInst::ICMP_SLT && Cond2Op != ICmpInst::ICMP_SLE)
246       return false;
247   } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) {
248     if (Cond2Op != ICmpInst::ICMP_SGT && Cond2Op != ICmpInst::ICMP_SGE)
249       return false;
250   } else if (Cond1Op == ICmpInst::ICMP_ULT || Cond1Op == ICmpInst::ICMP_ULE) {
251     if (Cond2Op != ICmpInst::ICMP_UGT && Cond2Op != ICmpInst::ICMP_UGE)
252       return false;
253   } else if (Cond1Op == ICmpInst::ICMP_UGT || Cond1Op == ICmpInst::ICMP_UGE) {
254     if (Cond2Op != ICmpInst::ICMP_ULT && Cond2Op != ICmpInst::ICMP_ULE)
255       return false;
256   } else {
257     return false;
258   }
259 
260   PassThroughInfo Info(Cond, BI, 0);
261   PassThroughs.push_back(Info);
262 
263   return true;
264 }
265 
266 // To avoid speculative hoisting certain computations out of
267 // a basic block.
268 bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) {
269   if (auto *LdInst = dyn_cast<LoadInst>(&I)) {
270     if (auto *GV = dyn_cast<GlobalVariable>(LdInst->getOperand(0))) {
271       if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
272           GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
273         return false;
274     }
275   }
276 
277   if (!isa<LoadInst>(&I) && !isa<CallInst>(&I))
278     return false;
279 
280   // For:
281   //   B1:
282   //     var = ...
283   //     ...
284   //     /* icmp may not be in the same block as var = ... */
285   //     comp1 = icmp <opcode> var, <const>;
286   //     if (comp1) goto B2 else B3;
287   //   B2:
288   //     ... var ...
289   // change to:
290   //   B1:
291   //     var = ...
292   //     ...
293   //     /* icmp may not be in the same block as var = ... */
294   //     comp1 = icmp <opcode> var, <const>;
295   //     if (comp1) goto B2 else B3;
296   //   B2:
297   //     var = __builtin_bpf_passthrough(seq_num, var);
298   //     ... var ...
299   bool isCandidate = false;
300   SmallVector<PassThroughInfo, 4> Candidates;
301   for (User *U : I.users()) {
302     Instruction *Inst = dyn_cast<Instruction>(U);
303     if (!Inst)
304       continue;
305 
306     // May cover a little bit more than the
307     // above pattern.
308     if (auto *Icmp1 = dyn_cast<ICmpInst>(Inst)) {
309       Value *Icmp1Op1 = Icmp1->getOperand(1);
310       if (!isa<Constant>(Icmp1Op1))
311         return false;
312       isCandidate = true;
313       continue;
314     }
315 
316     // Ignore the use in the same basic block as the definition.
317     if (Inst->getParent() == I.getParent())
318       continue;
319 
320     // use in a different basic block, If there is a call or
321     // load/store insn before this instruction in this basic
322     // block. Most likely it cannot be hoisted out. Skip it.
323     for (auto &I2 : *Inst->getParent()) {
324       if (isa<CallInst>(&I2))
325         return false;
326       if (isa<LoadInst>(&I2) || isa<StoreInst>(&I2))
327         return false;
328       if (&I2 == Inst)
329         break;
330     }
331 
332     // It should be used in a GEP or a simple arithmetic like
333     // ZEXT/SEXT which is used for GEP.
334     if (Inst->getOpcode() == Instruction::ZExt ||
335         Inst->getOpcode() == Instruction::SExt) {
336       PassThroughInfo Info(&I, Inst, 0);
337       Candidates.push_back(Info);
338     } else if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
339       // traverse GEP inst to find Use operand index
340       unsigned i, e;
341       for (i = 1, e = GI->getNumOperands(); i != e; ++i) {
342         Value *V = GI->getOperand(i);
343         if (V == &I)
344           break;
345       }
346       if (i == e)
347         continue;
348 
349       PassThroughInfo Info(&I, GI, i);
350       Candidates.push_back(Info);
351     }
352   }
353 
354   if (!isCandidate || Candidates.empty())
355     return false;
356 
357   llvm::append_range(PassThroughs, Candidates);
358   return true;
359 }
360 
361 void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) {
362   if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB))
363     return;
364 }
365 
366 void BPFAdjustOptImpl::adjustInst(Instruction &I) {
367   if (!DisableBPFserializeICMP && serializeICMPInBB(I))
368     return;
369   if (!DisableBPFavoidSpeculation && avoidSpeculation(I))
370     return;
371 }
372 
373 PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) {
374   return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none()
375                                     : PreservedAnalyses::all();
376 }
377