xref: /freebsd/contrib/llvm-project/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp (revision 3a56015a2f5d630910177fa79a522bb95511ccf7)
1 //===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Check IR and adjust IR to produce verifier-friendly code.
// The following checks are done on the IR:
//   - no relocation globals in PHI nodes.
// The following IR adjustments are done:
//   - remove __builtin_bpf_passthrough builtins. Target independent IR
//     optimizations are done and those builtins can be removed.
//   - replace __builtin_bpf_compare builtins with plain icmp instructions.
//   - split comparisons against min/max calls computed outside of a loop
//     into pairs of comparisons (e.g. x < min(a, b) -> x < a && x < b).
//   - remove llvm.bpf.getelementptr.and.load builtins.
//   - remove llvm.bpf.getelementptr.and.store builtins.
//   - for loads, stores, cmpxchg and atomicrmw with pointer operands from a
//     non-zero address space, cast the pointer to address space zero, and
//     place globals of address space N into section .addr_space.N
//     (support for BPF address spaces).
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "BPF.h"
23 #include "BPFCORE.h"
24 #include "BPFTargetMachine.h"
25 #include "llvm/Analysis/LoopInfo.h"
26 #include "llvm/IR/DebugInfoMetadata.h"
27 #include "llvm/IR/GlobalVariable.h"
28 #include "llvm/IR/IRBuilder.h"
29 #include "llvm/IR/Instruction.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/IntrinsicsBPF.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/Type.h"
34 #include "llvm/IR/User.h"
35 #include "llvm/IR/Value.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
38 
39 #define DEBUG_TYPE "bpf-check-and-opt-ir"
40 
41 using namespace llvm;
42 
43 namespace {
44 
45 class BPFCheckAndAdjustIR final : public ModulePass {
46   bool runOnModule(Module &F) override;
47 
48 public:
49   static char ID;
50   BPFCheckAndAdjustIR() : ModulePass(ID) {}
51   virtual void getAnalysisUsage(AnalysisUsage &AU) const override;
52 
53 private:
54   void checkIR(Module &M);
55   bool adjustIR(Module &M);
56   bool removePassThroughBuiltin(Module &M);
57   bool removeCompareBuiltin(Module &M);
58   bool sinkMinMax(Module &M);
59   bool removeGEPBuiltins(Module &M);
60   bool insertASpaceCasts(Module &M);
61 };
62 } // End anonymous namespace
63 
64 char BPFCheckAndAdjustIR::ID = 0;
65 INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR",
66                 false, false)
67 
68 ModulePass *llvm::createBPFCheckAndAdjustIR() {
69   return new BPFCheckAndAdjustIR();
70 }
71 
72 void BPFCheckAndAdjustIR::checkIR(Module &M) {
73   // Ensure relocation global won't appear in PHI node
74   // This may happen if the compiler generated the following code:
75   //   B1:
76   //      g1 = @llvm.skb_buff:0:1...
77   //      ...
78   //      goto B_COMMON
79   //   B2:
80   //      g2 = @llvm.skb_buff:0:2...
81   //      ...
82   //      goto B_COMMON
83   //   B_COMMON:
84   //      g = PHI(g1, g2)
85   //      x = load g
86   //      ...
87   // If anything likes the above "g = PHI(g1, g2)", issue a fatal error.
88   for (Function &F : M)
89     for (auto &BB : F)
90       for (auto &I : BB) {
91         PHINode *PN = dyn_cast<PHINode>(&I);
92         if (!PN || PN->use_empty())
93           continue;
94         for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) {
95           auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i));
96           if (!GV)
97             continue;
98           if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
99               GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
100             report_fatal_error("relocation global in PHI node");
101         }
102       }
103 }
104 
105 bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) {
106   // Remove __builtin_bpf_passthrough()'s which are used to prevent
107   // certain IR optimizations. Now major IR optimizations are done,
108   // remove them.
109   bool Changed = false;
110   CallInst *ToBeDeleted = nullptr;
111   for (Function &F : M)
112     for (auto &BB : F)
113       for (auto &I : BB) {
114         if (ToBeDeleted) {
115           ToBeDeleted->eraseFromParent();
116           ToBeDeleted = nullptr;
117         }
118 
119         auto *Call = dyn_cast<CallInst>(&I);
120         if (!Call)
121           continue;
122         auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
123         if (!GV)
124           continue;
125         if (!GV->getName().starts_with("llvm.bpf.passthrough"))
126           continue;
127         Changed = true;
128         Value *Arg = Call->getArgOperand(1);
129         Call->replaceAllUsesWith(Arg);
130         ToBeDeleted = Call;
131       }
132   return Changed;
133 }
134 
135 bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
136   // Remove __builtin_bpf_compare()'s which are used to prevent
137   // certain IR optimizations. Now major IR optimizations are done,
138   // remove them.
139   bool Changed = false;
140   CallInst *ToBeDeleted = nullptr;
141   for (Function &F : M)
142     for (auto &BB : F)
143       for (auto &I : BB) {
144         if (ToBeDeleted) {
145           ToBeDeleted->eraseFromParent();
146           ToBeDeleted = nullptr;
147         }
148 
149         auto *Call = dyn_cast<CallInst>(&I);
150         if (!Call)
151           continue;
152         auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
153         if (!GV)
154           continue;
155         if (!GV->getName().starts_with("llvm.bpf.compare"))
156           continue;
157 
158         Changed = true;
159         Value *Arg0 = Call->getArgOperand(0);
160         Value *Arg1 = Call->getArgOperand(1);
161         Value *Arg2 = Call->getArgOperand(2);
162 
163         auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue();
164         CmpInst::Predicate Opcode = (CmpInst::Predicate)OpVal;
165 
166         auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2);
167         ICmp->insertBefore(Call);
168 
169         Call->replaceAllUsesWith(ICmp);
170         ToBeDeleted = Call;
171       }
172   return Changed;
173 }
174 
175 struct MinMaxSinkInfo {
176   ICmpInst *ICmp;
177   Value *Other;
178   ICmpInst::Predicate Predicate;
179   CallInst *MinMax;
180   ZExtInst *ZExt;
181   SExtInst *SExt;
182 
183   MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
184       : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
185         ZExt(nullptr), SExt(nullptr) {}
186 };
187 
188 static bool sinkMinMaxInBB(BasicBlock &BB,
189                            const std::function<bool(Instruction *)> &Filter) {
190   // Check if V is:
191   //   (fn %a %b) or (ext (fn %a %b))
192   // Where:
193   //   ext := sext | zext
194   //   fn  := smin | umin | smax | umax
195   auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) {
196     if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
197       V = ZExt->getOperand(0);
198       Info.ZExt = ZExt;
199     } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
200       V = SExt->getOperand(0);
201       Info.SExt = SExt;
202     }
203 
204     auto *Call = dyn_cast<CallInst>(V);
205     if (!Call)
206       return false;
207 
208     auto *Called = dyn_cast<Function>(Call->getCalledOperand());
209     if (!Called)
210       return false;
211 
212     switch (Called->getIntrinsicID()) {
213     case Intrinsic::smin:
214     case Intrinsic::umin:
215     case Intrinsic::smax:
216     case Intrinsic::umax:
217       break;
218     default:
219       return false;
220     }
221 
222     if (!Filter(Call))
223       return false;
224 
225     Info.MinMax = Call;
226 
227     return true;
228   };
229 
230   auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
231                              MinMaxSinkInfo &Info) {
232     if (Info.SExt) {
233       if (Info.SExt->getType() == V->getType())
234         return V;
235       return Builder.CreateSExt(V, Info.SExt->getType());
236     }
237     if (Info.ZExt) {
238       if (Info.ZExt->getType() == V->getType())
239         return V;
240       return Builder.CreateZExt(V, Info.ZExt->getType());
241     }
242     return V;
243   };
244 
245   bool Changed = false;
246   SmallVector<MinMaxSinkInfo, 2> SinkList;
247 
248   // Check BB for instructions like:
249   //   insn := (icmp %a (fn ...)) | (icmp (fn ...)  %a)
250   //
251   // Where:
252   //   fn := min | max | (sext (min ...)) | (sext (max ...))
253   //
254   // Put such instructions to SinkList.
255   for (Instruction &I : BB) {
256     ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
257     if (!ICmp)
258       continue;
259     if (!ICmp->isRelational())
260       continue;
261     MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
262                          ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
263     MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
264     bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
265     bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
266     if (!(FirstMinMax ^ SecondMinMax))
267       continue;
268     SinkList.push_back(FirstMinMax ? First : Second);
269   }
270 
271   // Iterate SinkList and replace each (icmp ...) with corresponding
272   // `x < a && x < b` or similar expression.
273   for (auto &Info : SinkList) {
274     ICmpInst *ICmp = Info.ICmp;
275     CallInst *MinMax = Info.MinMax;
276     Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
277     ICmpInst::Predicate P = Info.Predicate;
278     if (ICmpInst::isSigned(P) && IID != Intrinsic::smin &&
279         IID != Intrinsic::smax)
280       continue;
281 
282     IRBuilder<> Builder(ICmp);
283     Value *X = Info.Other;
284     Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
285     Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
286     bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
287     bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
288     bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
289     bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
290     assert(IsMin ^ IsMax);
291     assert(IsLess ^ IsGreater);
292 
293     Value *Replacement;
294     Value *LHS = Builder.CreateICmp(P, X, A);
295     Value *RHS = Builder.CreateICmp(P, X, B);
296     if ((IsLess && IsMin) || (IsGreater && IsMax))
297       // x < min(a, b) -> x < a && x < b
298       // x > max(a, b) -> x > a && x > b
299       Replacement = Builder.CreateLogicalAnd(LHS, RHS);
300     else
301       // x > min(a, b) -> x > a || x > b
302       // x < max(a, b) -> x < a || x < b
303       Replacement = Builder.CreateLogicalOr(LHS, RHS);
304 
305     ICmp->replaceAllUsesWith(Replacement);
306 
307     Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
308     for (Instruction *I : ToRemove)
309       if (I && I->use_empty())
310         I->eraseFromParent();
311 
312     Changed = true;
313   }
314 
315   return Changed;
316 }
317 
318 // Do the following transformation:
319 //
320 //   x < min(a, b) -> x < a && x < b
321 //   x > min(a, b) -> x > a || x > b
322 //   x < max(a, b) -> x < a || x < b
323 //   x > max(a, b) -> x > a && x > b
324 //
325 // Such patterns are introduced by LICM.cpp:hoistMinMax()
326 // transformation and might lead to BPF verification failures for
327 // older kernels.
328 //
329 // To minimize "collateral" changes only do it for icmp + min/max
330 // calls when icmp is inside a loop and min/max is outside of that
331 // loop.
332 //
333 // Verification failure happens when:
334 // - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
335 // - verifier can recognize RHS as a constant scalar in some context;
336 // - verifier can't recognize RHS1 as a constant scalar in the same
337 //   context;
338 //
339 // The "constant scalar" is not a compile time constant, but a register
340 // that holds a scalar value known to verifier at some point in time
341 // during abstract interpretation.
342 //
343 // See also:
344 //   https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
345 bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
346   bool Changed = false;
347 
348   for (Function &F : M) {
349     if (F.isDeclaration())
350       continue;
351 
352     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
353     for (Loop *L : LI)
354       for (BasicBlock *BB : L->blocks()) {
355         // Filter out instructions coming from the same loop
356         Loop *BBLoop = LI.getLoopFor(BB);
357         auto OtherLoopFilter = [&](Instruction *I) {
358           return LI.getLoopFor(I->getParent()) != BBLoop;
359         };
360         Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
361       }
362   }
363 
364   return Changed;
365 }
366 
367 void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
368   AU.addRequired<LoopInfoWrapperPass>();
369 }
370 
371 static void unrollGEPLoad(CallInst *Call) {
372   auto [GEP, Load] = BPFPreserveStaticOffsetPass::reconstructLoad(Call);
373   GEP->insertBefore(Call);
374   Load->insertBefore(Call);
375   Call->replaceAllUsesWith(Load);
376   Call->eraseFromParent();
377 }
378 
379 static void unrollGEPStore(CallInst *Call) {
380   auto [GEP, Store] = BPFPreserveStaticOffsetPass::reconstructStore(Call);
381   GEP->insertBefore(Call);
382   Store->insertBefore(Call);
383   Call->eraseFromParent();
384 }
385 
386 static bool removeGEPBuiltinsInFunc(Function &F) {
387   SmallVector<CallInst *> GEPLoads;
388   SmallVector<CallInst *> GEPStores;
389   for (auto &BB : F)
390     for (auto &Insn : BB)
391       if (auto *Call = dyn_cast<CallInst>(&Insn))
392         if (auto *Called = Call->getCalledFunction())
393           switch (Called->getIntrinsicID()) {
394           case Intrinsic::bpf_getelementptr_and_load:
395             GEPLoads.push_back(Call);
396             break;
397           case Intrinsic::bpf_getelementptr_and_store:
398             GEPStores.push_back(Call);
399             break;
400           }
401 
402   if (GEPLoads.empty() && GEPStores.empty())
403     return false;
404 
405   for_each(GEPLoads, unrollGEPLoad);
406   for_each(GEPStores, unrollGEPStore);
407 
408   return true;
409 }
410 
411 // Rewrites the following builtins:
412 // - llvm.bpf.getelementptr.and.load
413 // - llvm.bpf.getelementptr.and.store
414 // As (load (getelementptr ...)) or (store (getelementptr ...)).
415 bool BPFCheckAndAdjustIR::removeGEPBuiltins(Module &M) {
416   bool Changed = false;
417   for (auto &F : M)
418     Changed = removeGEPBuiltinsInFunc(F) || Changed;
419   return Changed;
420 }
421 
422 // Wrap ToWrap with cast to address space zero:
423 // - if ToWrap is a getelementptr,
424 //   wrap it's base pointer instead and return a copy;
425 // - if ToWrap is Instruction, insert address space cast
426 //   immediately after ToWrap;
427 // - if ToWrap is not an Instruction (function parameter
428 //   or a global value), insert address space cast at the
429 //   beginning of the Function F;
430 // - use Cache to avoid inserting too many casts;
431 static Value *aspaceWrapValue(DenseMap<Value *, Value *> &Cache, Function *F,
432                               Value *ToWrap) {
433   auto It = Cache.find(ToWrap);
434   if (It != Cache.end())
435     return It->getSecond();
436 
437   if (auto *GEP = dyn_cast<GetElementPtrInst>(ToWrap)) {
438     Value *Ptr = GEP->getPointerOperand();
439     Value *WrappedPtr = aspaceWrapValue(Cache, F, Ptr);
440     auto *GEPTy = cast<PointerType>(GEP->getType());
441     auto *NewGEP = GEP->clone();
442     NewGEP->insertAfter(GEP);
443     NewGEP->mutateType(GEPTy->getPointerTo(0));
444     NewGEP->setOperand(GEP->getPointerOperandIndex(), WrappedPtr);
445     NewGEP->setName(GEP->getName());
446     Cache[ToWrap] = NewGEP;
447     return NewGEP;
448   }
449 
450   IRBuilder IB(F->getContext());
451   if (Instruction *InsnPtr = dyn_cast<Instruction>(ToWrap))
452     IB.SetInsertPoint(*InsnPtr->getInsertionPointAfterDef());
453   else
454     IB.SetInsertPoint(F->getEntryBlock().getFirstInsertionPt());
455   auto *PtrTy = cast<PointerType>(ToWrap->getType());
456   auto *ASZeroPtrTy = PtrTy->getPointerTo(0);
457   auto *ACast = IB.CreateAddrSpaceCast(ToWrap, ASZeroPtrTy, ToWrap->getName());
458   Cache[ToWrap] = ACast;
459   return ACast;
460 }
461 
462 // Wrap a pointer operand OpNum of instruction I
463 // with cast to address space zero
464 static void aspaceWrapOperand(DenseMap<Value *, Value *> &Cache, Instruction *I,
465                               unsigned OpNum) {
466   Value *OldOp = I->getOperand(OpNum);
467   if (OldOp->getType()->getPointerAddressSpace() == 0)
468     return;
469 
470   Value *NewOp = aspaceWrapValue(Cache, I->getFunction(), OldOp);
471   I->setOperand(OpNum, NewOp);
472   // Check if there are any remaining users of old GEP,
473   // delete those w/o users
474   for (;;) {
475     auto *OldGEP = dyn_cast<GetElementPtrInst>(OldOp);
476     if (!OldGEP)
477       break;
478     if (!OldGEP->use_empty())
479       break;
480     OldOp = OldGEP->getPointerOperand();
481     OldGEP->eraseFromParent();
482   }
483 }
484 
485 // Support for BPF address spaces:
486 // - for each function in the module M, update pointer operand of
487 //   each memory access instruction (load/store/cmpxchg/atomicrmw)
488 //   by casting it from non-zero address space to zero address space, e.g:
489 //
490 //   (load (ptr addrspace (N) %p) ...)
491 //     -> (load (addrspacecast ptr addrspace (N) %p to ptr))
492 //
493 // - assign section with name .addr_space.N for globals defined in
494 //   non-zero address space N
495 bool BPFCheckAndAdjustIR::insertASpaceCasts(Module &M) {
496   bool Changed = false;
497   for (Function &F : M) {
498     DenseMap<Value *, Value *> CastsCache;
499     for (BasicBlock &BB : F) {
500       for (Instruction &I : BB) {
501         unsigned PtrOpNum;
502 
503         if (auto *LD = dyn_cast<LoadInst>(&I))
504           PtrOpNum = LD->getPointerOperandIndex();
505         else if (auto *ST = dyn_cast<StoreInst>(&I))
506           PtrOpNum = ST->getPointerOperandIndex();
507         else if (auto *CmpXchg = dyn_cast<AtomicCmpXchgInst>(&I))
508           PtrOpNum = CmpXchg->getPointerOperandIndex();
509         else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I))
510           PtrOpNum = RMW->getPointerOperandIndex();
511         else
512           continue;
513 
514         aspaceWrapOperand(CastsCache, &I, PtrOpNum);
515       }
516     }
517     Changed |= !CastsCache.empty();
518   }
519   // Merge all globals within same address space into single
520   // .addr_space.<addr space no> section
521   for (GlobalVariable &G : M.globals()) {
522     if (G.getAddressSpace() == 0 || G.hasSection())
523       continue;
524     SmallString<16> SecName;
525     raw_svector_ostream OS(SecName);
526     OS << ".addr_space." << G.getAddressSpace();
527     G.setSection(SecName);
528     // Prevent having separate section for constants
529     G.setConstant(false);
530   }
531   return Changed;
532 }
533 
534 bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
535   bool Changed = removePassThroughBuiltin(M);
536   Changed = removeCompareBuiltin(M) || Changed;
537   Changed = sinkMinMax(M) || Changed;
538   Changed = removeGEPBuiltins(M) || Changed;
539   Changed = insertASpaceCasts(M) || Changed;
540   return Changed;
541 }
542 
543 bool BPFCheckAndAdjustIR::runOnModule(Module &M) {
544   checkIR(M);
545   return adjustIR(M);
546 }
547