xref: /freebsd/contrib/llvm-project/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp (revision cb14a3fe5122c879eae1fb480ed7ce82a699ddb6)
1 //===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Check IR and adjust IR for verifier-friendly code.
// The following are done for IR checking:
//   - no relocation globals in PHI nodes.
// The following are done for IR adjustment:
//   - remove __builtin_bpf_passthrough builtins. Target-independent IR
//     optimizations are done at this point, so those builtins can be removed.
//   - replace __builtin_bpf_compare builtins with plain icmp instructions.
//   - split comparisons against min/max calls that live outside the
//     comparison's loop into two comparisons (see sinkMinMax()).
//   - remove llvm.bpf.getelementptr.and.load builtins.
//   - remove llvm.bpf.getelementptr.and.store builtins.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "BPF.h"
21 #include "BPFCORE.h"
22 #include "BPFTargetMachine.h"
23 #include "llvm/Analysis/LoopInfo.h"
24 #include "llvm/IR/DebugInfoMetadata.h"
25 #include "llvm/IR/GlobalVariable.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/Instruction.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/IntrinsicsBPF.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/IR/User.h"
33 #include "llvm/IR/Value.h"
34 #include "llvm/Pass.h"
35 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
36 
37 #define DEBUG_TYPE "bpf-check-and-opt-ir"
38 
39 using namespace llvm;
40 
41 namespace {
42 
43 class BPFCheckAndAdjustIR final : public ModulePass {
44   bool runOnModule(Module &F) override;
45 
46 public:
47   static char ID;
48   BPFCheckAndAdjustIR() : ModulePass(ID) {}
49   virtual void getAnalysisUsage(AnalysisUsage &AU) const override;
50 
51 private:
52   void checkIR(Module &M);
53   bool adjustIR(Module &M);
54   bool removePassThroughBuiltin(Module &M);
55   bool removeCompareBuiltin(Module &M);
56   bool sinkMinMax(Module &M);
57   bool removeGEPBuiltins(Module &M);
58 };
59 } // End anonymous namespace
60 
61 char BPFCheckAndAdjustIR::ID = 0;
62 INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR",
63                 false, false)
64 
65 ModulePass *llvm::createBPFCheckAndAdjustIR() {
66   return new BPFCheckAndAdjustIR();
67 }
68 
69 void BPFCheckAndAdjustIR::checkIR(Module &M) {
70   // Ensure relocation global won't appear in PHI node
71   // This may happen if the compiler generated the following code:
72   //   B1:
73   //      g1 = @llvm.skb_buff:0:1...
74   //      ...
75   //      goto B_COMMON
76   //   B2:
77   //      g2 = @llvm.skb_buff:0:2...
78   //      ...
79   //      goto B_COMMON
80   //   B_COMMON:
81   //      g = PHI(g1, g2)
82   //      x = load g
83   //      ...
84   // If anything likes the above "g = PHI(g1, g2)", issue a fatal error.
85   for (Function &F : M)
86     for (auto &BB : F)
87       for (auto &I : BB) {
88         PHINode *PN = dyn_cast<PHINode>(&I);
89         if (!PN || PN->use_empty())
90           continue;
91         for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) {
92           auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i));
93           if (!GV)
94             continue;
95           if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
96               GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
97             report_fatal_error("relocation global in PHI node");
98         }
99       }
100 }
101 
102 bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) {
103   // Remove __builtin_bpf_passthrough()'s which are used to prevent
104   // certain IR optimizations. Now major IR optimizations are done,
105   // remove them.
106   bool Changed = false;
107   CallInst *ToBeDeleted = nullptr;
108   for (Function &F : M)
109     for (auto &BB : F)
110       for (auto &I : BB) {
111         if (ToBeDeleted) {
112           ToBeDeleted->eraseFromParent();
113           ToBeDeleted = nullptr;
114         }
115 
116         auto *Call = dyn_cast<CallInst>(&I);
117         if (!Call)
118           continue;
119         auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
120         if (!GV)
121           continue;
122         if (!GV->getName().starts_with("llvm.bpf.passthrough"))
123           continue;
124         Changed = true;
125         Value *Arg = Call->getArgOperand(1);
126         Call->replaceAllUsesWith(Arg);
127         ToBeDeleted = Call;
128       }
129   return Changed;
130 }
131 
132 bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
133   // Remove __builtin_bpf_compare()'s which are used to prevent
134   // certain IR optimizations. Now major IR optimizations are done,
135   // remove them.
136   bool Changed = false;
137   CallInst *ToBeDeleted = nullptr;
138   for (Function &F : M)
139     for (auto &BB : F)
140       for (auto &I : BB) {
141         if (ToBeDeleted) {
142           ToBeDeleted->eraseFromParent();
143           ToBeDeleted = nullptr;
144         }
145 
146         auto *Call = dyn_cast<CallInst>(&I);
147         if (!Call)
148           continue;
149         auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
150         if (!GV)
151           continue;
152         if (!GV->getName().starts_with("llvm.bpf.compare"))
153           continue;
154 
155         Changed = true;
156         Value *Arg0 = Call->getArgOperand(0);
157         Value *Arg1 = Call->getArgOperand(1);
158         Value *Arg2 = Call->getArgOperand(2);
159 
160         auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue();
161         CmpInst::Predicate Opcode = (CmpInst::Predicate)OpVal;
162 
163         auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2);
164         ICmp->insertBefore(Call);
165 
166         Call->replaceAllUsesWith(ICmp);
167         ToBeDeleted = Call;
168       }
169   return Changed;
170 }
171 
172 struct MinMaxSinkInfo {
173   ICmpInst *ICmp;
174   Value *Other;
175   ICmpInst::Predicate Predicate;
176   CallInst *MinMax;
177   ZExtInst *ZExt;
178   SExtInst *SExt;
179 
180   MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
181       : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
182         ZExt(nullptr), SExt(nullptr) {}
183 };
184 
185 static bool sinkMinMaxInBB(BasicBlock &BB,
186                            const std::function<bool(Instruction *)> &Filter) {
187   // Check if V is:
188   //   (fn %a %b) or (ext (fn %a %b))
189   // Where:
190   //   ext := sext | zext
191   //   fn  := smin | umin | smax | umax
192   auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) {
193     if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
194       V = ZExt->getOperand(0);
195       Info.ZExt = ZExt;
196     } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
197       V = SExt->getOperand(0);
198       Info.SExt = SExt;
199     }
200 
201     auto *Call = dyn_cast<CallInst>(V);
202     if (!Call)
203       return false;
204 
205     auto *Called = dyn_cast<Function>(Call->getCalledOperand());
206     if (!Called)
207       return false;
208 
209     switch (Called->getIntrinsicID()) {
210     case Intrinsic::smin:
211     case Intrinsic::umin:
212     case Intrinsic::smax:
213     case Intrinsic::umax:
214       break;
215     default:
216       return false;
217     }
218 
219     if (!Filter(Call))
220       return false;
221 
222     Info.MinMax = Call;
223 
224     return true;
225   };
226 
227   auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
228                              MinMaxSinkInfo &Info) {
229     if (Info.SExt) {
230       if (Info.SExt->getType() == V->getType())
231         return V;
232       return Builder.CreateSExt(V, Info.SExt->getType());
233     }
234     if (Info.ZExt) {
235       if (Info.ZExt->getType() == V->getType())
236         return V;
237       return Builder.CreateZExt(V, Info.ZExt->getType());
238     }
239     return V;
240   };
241 
242   bool Changed = false;
243   SmallVector<MinMaxSinkInfo, 2> SinkList;
244 
245   // Check BB for instructions like:
246   //   insn := (icmp %a (fn ...)) | (icmp (fn ...)  %a)
247   //
248   // Where:
249   //   fn := min | max | (sext (min ...)) | (sext (max ...))
250   //
251   // Put such instructions to SinkList.
252   for (Instruction &I : BB) {
253     ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
254     if (!ICmp)
255       continue;
256     if (!ICmp->isRelational())
257       continue;
258     MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
259                          ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
260     MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
261     bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
262     bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
263     if (!(FirstMinMax ^ SecondMinMax))
264       continue;
265     SinkList.push_back(FirstMinMax ? First : Second);
266   }
267 
268   // Iterate SinkList and replace each (icmp ...) with corresponding
269   // `x < a && x < b` or similar expression.
270   for (auto &Info : SinkList) {
271     ICmpInst *ICmp = Info.ICmp;
272     CallInst *MinMax = Info.MinMax;
273     Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
274     ICmpInst::Predicate P = Info.Predicate;
275     if (ICmpInst::isSigned(P) && IID != Intrinsic::smin &&
276         IID != Intrinsic::smax)
277       continue;
278 
279     IRBuilder<> Builder(ICmp);
280     Value *X = Info.Other;
281     Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
282     Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
283     bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
284     bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
285     bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
286     bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
287     assert(IsMin ^ IsMax);
288     assert(IsLess ^ IsGreater);
289 
290     Value *Replacement;
291     Value *LHS = Builder.CreateICmp(P, X, A);
292     Value *RHS = Builder.CreateICmp(P, X, B);
293     if ((IsLess && IsMin) || (IsGreater && IsMax))
294       // x < min(a, b) -> x < a && x < b
295       // x > max(a, b) -> x > a && x > b
296       Replacement = Builder.CreateLogicalAnd(LHS, RHS);
297     else
298       // x > min(a, b) -> x > a || x > b
299       // x < max(a, b) -> x < a || x < b
300       Replacement = Builder.CreateLogicalOr(LHS, RHS);
301 
302     ICmp->replaceAllUsesWith(Replacement);
303 
304     Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
305     for (Instruction *I : ToRemove)
306       if (I && I->use_empty())
307         I->eraseFromParent();
308 
309     Changed = true;
310   }
311 
312   return Changed;
313 }
314 
315 // Do the following transformation:
316 //
317 //   x < min(a, b) -> x < a && x < b
318 //   x > min(a, b) -> x > a || x > b
319 //   x < max(a, b) -> x < a || x < b
320 //   x > max(a, b) -> x > a && x > b
321 //
322 // Such patterns are introduced by LICM.cpp:hoistMinMax()
323 // transformation and might lead to BPF verification failures for
324 // older kernels.
325 //
326 // To minimize "collateral" changes only do it for icmp + min/max
327 // calls when icmp is inside a loop and min/max is outside of that
328 // loop.
329 //
330 // Verification failure happens when:
331 // - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
332 // - verifier can recognize RHS as a constant scalar in some context;
333 // - verifier can't recognize RHS1 as a constant scalar in the same
334 //   context;
335 //
336 // The "constant scalar" is not a compile time constant, but a register
337 // that holds a scalar value known to verifier at some point in time
338 // during abstract interpretation.
339 //
340 // See also:
341 //   https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
342 bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
343   bool Changed = false;
344 
345   for (Function &F : M) {
346     if (F.isDeclaration())
347       continue;
348 
349     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
350     for (Loop *L : LI)
351       for (BasicBlock *BB : L->blocks()) {
352         // Filter out instructions coming from the same loop
353         Loop *BBLoop = LI.getLoopFor(BB);
354         auto OtherLoopFilter = [&](Instruction *I) {
355           return LI.getLoopFor(I->getParent()) != BBLoop;
356         };
357         Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
358       }
359   }
360 
361   return Changed;
362 }
363 
364 void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
365   AU.addRequired<LoopInfoWrapperPass>();
366 }
367 
368 static void unrollGEPLoad(CallInst *Call) {
369   auto [GEP, Load] = BPFPreserveStaticOffsetPass::reconstructLoad(Call);
370   GEP->insertBefore(Call);
371   Load->insertBefore(Call);
372   Call->replaceAllUsesWith(Load);
373   Call->eraseFromParent();
374 }
375 
376 static void unrollGEPStore(CallInst *Call) {
377   auto [GEP, Store] = BPFPreserveStaticOffsetPass::reconstructStore(Call);
378   GEP->insertBefore(Call);
379   Store->insertBefore(Call);
380   Call->eraseFromParent();
381 }
382 
383 static bool removeGEPBuiltinsInFunc(Function &F) {
384   SmallVector<CallInst *> GEPLoads;
385   SmallVector<CallInst *> GEPStores;
386   for (auto &BB : F)
387     for (auto &Insn : BB)
388       if (auto *Call = dyn_cast<CallInst>(&Insn))
389         if (auto *Called = Call->getCalledFunction())
390           switch (Called->getIntrinsicID()) {
391           case Intrinsic::bpf_getelementptr_and_load:
392             GEPLoads.push_back(Call);
393             break;
394           case Intrinsic::bpf_getelementptr_and_store:
395             GEPStores.push_back(Call);
396             break;
397           }
398 
399   if (GEPLoads.empty() && GEPStores.empty())
400     return false;
401 
402   for_each(GEPLoads, unrollGEPLoad);
403   for_each(GEPStores, unrollGEPStore);
404 
405   return true;
406 }
407 
408 // Rewrites the following builtins:
409 // - llvm.bpf.getelementptr.and.load
410 // - llvm.bpf.getelementptr.and.store
411 // As (load (getelementptr ...)) or (store (getelementptr ...)).
412 bool BPFCheckAndAdjustIR::removeGEPBuiltins(Module &M) {
413   bool Changed = false;
414   for (auto &F : M)
415     Changed = removeGEPBuiltinsInFunc(F) || Changed;
416   return Changed;
417 }
418 
419 bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
420   bool Changed = removePassThroughBuiltin(M);
421   Changed = removeCompareBuiltin(M) || Changed;
422   Changed = sinkMinMax(M) || Changed;
423   Changed = removeGEPBuiltins(M) || Changed;
424   return Changed;
425 }
426 
427 bool BPFCheckAndAdjustIR::runOnModule(Module &M) {
428   checkIR(M);
429   return adjustIR(M);
430 }
431