1 //===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Check IR and adjust IR to produce verifier-friendly code.
// The following is done for IR checking:
//   - no relocation globals in PHI node.
// The following is done for IR adjustment:
//   - remove __builtin_bpf_passthrough builtins. Target independent IR
//     optimizations are done and those builtins can be removed.
//   - remove __builtin_bpf_compare builtins, replacing them with plain
//     icmp instructions.
//   - sink min/max calls into icmp instructions located in a different
//     loop (x < min(a, b) -> x < a && x < b, etc.).
//   - remove llvm.bpf.getelementptr.and.load builtins.
//   - remove llvm.bpf.getelementptr.and.store builtins.
//   - for loads and stores with base addresses from non-zero address space
//     cast base address to zero address space (support for BPF address
//     spaces).
19 //
20 //===----------------------------------------------------------------------===//
21
22 #include "BPF.h"
23 #include "BPFCORE.h"
24 #include "BPFTargetMachine.h"
25 #include "llvm/Analysis/LoopInfo.h"
26 #include "llvm/IR/DebugInfoMetadata.h"
27 #include "llvm/IR/GlobalVariable.h"
28 #include "llvm/IR/IRBuilder.h"
29 #include "llvm/IR/Instruction.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/IntrinsicsBPF.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/Type.h"
34 #include "llvm/IR/User.h"
35 #include "llvm/IR/Value.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
38
39 #define DEBUG_TYPE "bpf-check-and-opt-ir"
40
41 using namespace llvm;
42
43 namespace {
44
45 class BPFCheckAndAdjustIR final : public ModulePass {
46 bool runOnModule(Module &F) override;
47
48 public:
49 static char ID;
BPFCheckAndAdjustIR()50 BPFCheckAndAdjustIR() : ModulePass(ID) {}
51 virtual void getAnalysisUsage(AnalysisUsage &AU) const override;
52
53 private:
54 void checkIR(Module &M);
55 bool adjustIR(Module &M);
56 bool removePassThroughBuiltin(Module &M);
57 bool removeCompareBuiltin(Module &M);
58 bool sinkMinMax(Module &M);
59 bool removeGEPBuiltins(Module &M);
60 bool insertASpaceCasts(Module &M);
61 };
62 } // End anonymous namespace
63
64 char BPFCheckAndAdjustIR::ID = 0;
65 INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR",
66 false, false)
67
createBPFCheckAndAdjustIR()68 ModulePass *llvm::createBPFCheckAndAdjustIR() {
69 return new BPFCheckAndAdjustIR();
70 }
71
checkIR(Module & M)72 void BPFCheckAndAdjustIR::checkIR(Module &M) {
73 // Ensure relocation global won't appear in PHI node
74 // This may happen if the compiler generated the following code:
75 // B1:
76 // g1 = @llvm.skb_buff:0:1...
77 // ...
78 // goto B_COMMON
79 // B2:
80 // g2 = @llvm.skb_buff:0:2...
81 // ...
82 // goto B_COMMON
83 // B_COMMON:
84 // g = PHI(g1, g2)
85 // x = load g
86 // ...
87 // If anything likes the above "g = PHI(g1, g2)", issue a fatal error.
88 for (Function &F : M)
89 for (auto &BB : F)
90 for (auto &I : BB) {
91 PHINode *PN = dyn_cast<PHINode>(&I);
92 if (!PN || PN->use_empty())
93 continue;
94 for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) {
95 auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i));
96 if (!GV)
97 continue;
98 if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
99 GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
100 report_fatal_error("relocation global in PHI node");
101 }
102 }
103 }
104
removePassThroughBuiltin(Module & M)105 bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) {
106 // Remove __builtin_bpf_passthrough()'s which are used to prevent
107 // certain IR optimizations. Now major IR optimizations are done,
108 // remove them.
109 bool Changed = false;
110 CallInst *ToBeDeleted = nullptr;
111 for (Function &F : M)
112 for (auto &BB : F)
113 for (auto &I : BB) {
114 if (ToBeDeleted) {
115 ToBeDeleted->eraseFromParent();
116 ToBeDeleted = nullptr;
117 }
118
119 auto *Call = dyn_cast<CallInst>(&I);
120 if (!Call)
121 continue;
122 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
123 if (!GV)
124 continue;
125 if (!GV->getName().starts_with("llvm.bpf.passthrough"))
126 continue;
127 Changed = true;
128 Value *Arg = Call->getArgOperand(1);
129 Call->replaceAllUsesWith(Arg);
130 ToBeDeleted = Call;
131 }
132 return Changed;
133 }
134
removeCompareBuiltin(Module & M)135 bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
136 // Remove __builtin_bpf_compare()'s which are used to prevent
137 // certain IR optimizations. Now major IR optimizations are done,
138 // remove them.
139 bool Changed = false;
140 CallInst *ToBeDeleted = nullptr;
141 for (Function &F : M)
142 for (auto &BB : F)
143 for (auto &I : BB) {
144 if (ToBeDeleted) {
145 ToBeDeleted->eraseFromParent();
146 ToBeDeleted = nullptr;
147 }
148
149 auto *Call = dyn_cast<CallInst>(&I);
150 if (!Call)
151 continue;
152 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
153 if (!GV)
154 continue;
155 if (!GV->getName().starts_with("llvm.bpf.compare"))
156 continue;
157
158 Changed = true;
159 Value *Arg0 = Call->getArgOperand(0);
160 Value *Arg1 = Call->getArgOperand(1);
161 Value *Arg2 = Call->getArgOperand(2);
162
163 auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue();
164 CmpInst::Predicate Opcode = (CmpInst::Predicate)OpVal;
165
166 auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2);
167 ICmp->insertBefore(Call);
168
169 Call->replaceAllUsesWith(ICmp);
170 ToBeDeleted = Call;
171 }
172 return Changed;
173 }
174
175 struct MinMaxSinkInfo {
176 ICmpInst *ICmp;
177 Value *Other;
178 ICmpInst::Predicate Predicate;
179 CallInst *MinMax;
180 ZExtInst *ZExt;
181 SExtInst *SExt;
182
MinMaxSinkInfoMinMaxSinkInfo183 MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
184 : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
185 ZExt(nullptr), SExt(nullptr) {}
186 };
187
sinkMinMaxInBB(BasicBlock & BB,const std::function<bool (Instruction *)> & Filter)188 static bool sinkMinMaxInBB(BasicBlock &BB,
189 const std::function<bool(Instruction *)> &Filter) {
190 // Check if V is:
191 // (fn %a %b) or (ext (fn %a %b))
192 // Where:
193 // ext := sext | zext
194 // fn := smin | umin | smax | umax
195 auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) {
196 if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
197 V = ZExt->getOperand(0);
198 Info.ZExt = ZExt;
199 } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
200 V = SExt->getOperand(0);
201 Info.SExt = SExt;
202 }
203
204 auto *Call = dyn_cast<CallInst>(V);
205 if (!Call)
206 return false;
207
208 auto *Called = dyn_cast<Function>(Call->getCalledOperand());
209 if (!Called)
210 return false;
211
212 switch (Called->getIntrinsicID()) {
213 case Intrinsic::smin:
214 case Intrinsic::umin:
215 case Intrinsic::smax:
216 case Intrinsic::umax:
217 break;
218 default:
219 return false;
220 }
221
222 if (!Filter(Call))
223 return false;
224
225 Info.MinMax = Call;
226
227 return true;
228 };
229
230 auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
231 MinMaxSinkInfo &Info) {
232 if (Info.SExt) {
233 if (Info.SExt->getType() == V->getType())
234 return V;
235 return Builder.CreateSExt(V, Info.SExt->getType());
236 }
237 if (Info.ZExt) {
238 if (Info.ZExt->getType() == V->getType())
239 return V;
240 return Builder.CreateZExt(V, Info.ZExt->getType());
241 }
242 return V;
243 };
244
245 bool Changed = false;
246 SmallVector<MinMaxSinkInfo, 2> SinkList;
247
248 // Check BB for instructions like:
249 // insn := (icmp %a (fn ...)) | (icmp (fn ...) %a)
250 //
251 // Where:
252 // fn := min | max | (sext (min ...)) | (sext (max ...))
253 //
254 // Put such instructions to SinkList.
255 for (Instruction &I : BB) {
256 ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
257 if (!ICmp)
258 continue;
259 if (!ICmp->isRelational())
260 continue;
261 MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
262 ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
263 MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
264 bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
265 bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
266 if (!(FirstMinMax ^ SecondMinMax))
267 continue;
268 SinkList.push_back(FirstMinMax ? First : Second);
269 }
270
271 // Iterate SinkList and replace each (icmp ...) with corresponding
272 // `x < a && x < b` or similar expression.
273 for (auto &Info : SinkList) {
274 ICmpInst *ICmp = Info.ICmp;
275 CallInst *MinMax = Info.MinMax;
276 Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
277 ICmpInst::Predicate P = Info.Predicate;
278 if (ICmpInst::isSigned(P) && IID != Intrinsic::smin &&
279 IID != Intrinsic::smax)
280 continue;
281
282 IRBuilder<> Builder(ICmp);
283 Value *X = Info.Other;
284 Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
285 Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
286 bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
287 bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
288 bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
289 bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
290 assert(IsMin ^ IsMax);
291 assert(IsLess ^ IsGreater);
292
293 Value *Replacement;
294 Value *LHS = Builder.CreateICmp(P, X, A);
295 Value *RHS = Builder.CreateICmp(P, X, B);
296 if ((IsLess && IsMin) || (IsGreater && IsMax))
297 // x < min(a, b) -> x < a && x < b
298 // x > max(a, b) -> x > a && x > b
299 Replacement = Builder.CreateLogicalAnd(LHS, RHS);
300 else
301 // x > min(a, b) -> x > a || x > b
302 // x < max(a, b) -> x < a || x < b
303 Replacement = Builder.CreateLogicalOr(LHS, RHS);
304
305 ICmp->replaceAllUsesWith(Replacement);
306
307 Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
308 for (Instruction *I : ToRemove)
309 if (I && I->use_empty())
310 I->eraseFromParent();
311
312 Changed = true;
313 }
314
315 return Changed;
316 }
317
318 // Do the following transformation:
319 //
320 // x < min(a, b) -> x < a && x < b
321 // x > min(a, b) -> x > a || x > b
322 // x < max(a, b) -> x < a || x < b
323 // x > max(a, b) -> x > a && x > b
324 //
325 // Such patterns are introduced by LICM.cpp:hoistMinMax()
326 // transformation and might lead to BPF verification failures for
327 // older kernels.
328 //
329 // To minimize "collateral" changes only do it for icmp + min/max
330 // calls when icmp is inside a loop and min/max is outside of that
331 // loop.
332 //
333 // Verification failure happens when:
334 // - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
335 // - verifier can recognize RHS as a constant scalar in some context;
336 // - verifier can't recognize RHS1 as a constant scalar in the same
337 // context;
338 //
339 // The "constant scalar" is not a compile time constant, but a register
340 // that holds a scalar value known to verifier at some point in time
341 // during abstract interpretation.
342 //
343 // See also:
344 // https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
sinkMinMax(Module & M)345 bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
346 bool Changed = false;
347
348 for (Function &F : M) {
349 if (F.isDeclaration())
350 continue;
351
352 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
353 for (Loop *L : LI)
354 for (BasicBlock *BB : L->blocks()) {
355 // Filter out instructions coming from the same loop
356 Loop *BBLoop = LI.getLoopFor(BB);
357 auto OtherLoopFilter = [&](Instruction *I) {
358 return LI.getLoopFor(I->getParent()) != BBLoop;
359 };
360 Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
361 }
362 }
363
364 return Changed;
365 }
366
getAnalysisUsage(AnalysisUsage & AU) const367 void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
368 AU.addRequired<LoopInfoWrapperPass>();
369 }
370
unrollGEPLoad(CallInst * Call)371 static void unrollGEPLoad(CallInst *Call) {
372 auto [GEP, Load] = BPFPreserveStaticOffsetPass::reconstructLoad(Call);
373 GEP->insertBefore(Call);
374 Load->insertBefore(Call);
375 Call->replaceAllUsesWith(Load);
376 Call->eraseFromParent();
377 }
378
unrollGEPStore(CallInst * Call)379 static void unrollGEPStore(CallInst *Call) {
380 auto [GEP, Store] = BPFPreserveStaticOffsetPass::reconstructStore(Call);
381 GEP->insertBefore(Call);
382 Store->insertBefore(Call);
383 Call->eraseFromParent();
384 }
385
removeGEPBuiltinsInFunc(Function & F)386 static bool removeGEPBuiltinsInFunc(Function &F) {
387 SmallVector<CallInst *> GEPLoads;
388 SmallVector<CallInst *> GEPStores;
389 for (auto &BB : F)
390 for (auto &Insn : BB)
391 if (auto *Call = dyn_cast<CallInst>(&Insn))
392 if (auto *Called = Call->getCalledFunction())
393 switch (Called->getIntrinsicID()) {
394 case Intrinsic::bpf_getelementptr_and_load:
395 GEPLoads.push_back(Call);
396 break;
397 case Intrinsic::bpf_getelementptr_and_store:
398 GEPStores.push_back(Call);
399 break;
400 }
401
402 if (GEPLoads.empty() && GEPStores.empty())
403 return false;
404
405 for_each(GEPLoads, unrollGEPLoad);
406 for_each(GEPStores, unrollGEPStore);
407
408 return true;
409 }
410
411 // Rewrites the following builtins:
412 // - llvm.bpf.getelementptr.and.load
413 // - llvm.bpf.getelementptr.and.store
414 // As (load (getelementptr ...)) or (store (getelementptr ...)).
removeGEPBuiltins(Module & M)415 bool BPFCheckAndAdjustIR::removeGEPBuiltins(Module &M) {
416 bool Changed = false;
417 for (auto &F : M)
418 Changed = removeGEPBuiltinsInFunc(F) || Changed;
419 return Changed;
420 }
421
422 // Wrap ToWrap with cast to address space zero:
423 // - if ToWrap is a getelementptr,
424 // wrap it's base pointer instead and return a copy;
425 // - if ToWrap is Instruction, insert address space cast
426 // immediately after ToWrap;
427 // - if ToWrap is not an Instruction (function parameter
428 // or a global value), insert address space cast at the
429 // beginning of the Function F;
430 // - use Cache to avoid inserting too many casts;
aspaceWrapValue(DenseMap<Value *,Value * > & Cache,Function * F,Value * ToWrap)431 static Value *aspaceWrapValue(DenseMap<Value *, Value *> &Cache, Function *F,
432 Value *ToWrap) {
433 auto It = Cache.find(ToWrap);
434 if (It != Cache.end())
435 return It->getSecond();
436
437 if (auto *GEP = dyn_cast<GetElementPtrInst>(ToWrap)) {
438 Value *Ptr = GEP->getPointerOperand();
439 Value *WrappedPtr = aspaceWrapValue(Cache, F, Ptr);
440 auto *GEPTy = cast<PointerType>(GEP->getType());
441 auto *NewGEP = GEP->clone();
442 NewGEP->insertAfter(GEP);
443 NewGEP->mutateType(GEPTy->getPointerTo(0));
444 NewGEP->setOperand(GEP->getPointerOperandIndex(), WrappedPtr);
445 NewGEP->setName(GEP->getName());
446 Cache[ToWrap] = NewGEP;
447 return NewGEP;
448 }
449
450 IRBuilder IB(F->getContext());
451 if (Instruction *InsnPtr = dyn_cast<Instruction>(ToWrap))
452 IB.SetInsertPoint(*InsnPtr->getInsertionPointAfterDef());
453 else
454 IB.SetInsertPoint(F->getEntryBlock().getFirstInsertionPt());
455 auto *PtrTy = cast<PointerType>(ToWrap->getType());
456 auto *ASZeroPtrTy = PtrTy->getPointerTo(0);
457 auto *ACast = IB.CreateAddrSpaceCast(ToWrap, ASZeroPtrTy, ToWrap->getName());
458 Cache[ToWrap] = ACast;
459 return ACast;
460 }
461
462 // Wrap a pointer operand OpNum of instruction I
463 // with cast to address space zero
aspaceWrapOperand(DenseMap<Value *,Value * > & Cache,Instruction * I,unsigned OpNum)464 static void aspaceWrapOperand(DenseMap<Value *, Value *> &Cache, Instruction *I,
465 unsigned OpNum) {
466 Value *OldOp = I->getOperand(OpNum);
467 if (OldOp->getType()->getPointerAddressSpace() == 0)
468 return;
469
470 Value *NewOp = aspaceWrapValue(Cache, I->getFunction(), OldOp);
471 I->setOperand(OpNum, NewOp);
472 // Check if there are any remaining users of old GEP,
473 // delete those w/o users
474 for (;;) {
475 auto *OldGEP = dyn_cast<GetElementPtrInst>(OldOp);
476 if (!OldGEP)
477 break;
478 if (!OldGEP->use_empty())
479 break;
480 OldOp = OldGEP->getPointerOperand();
481 OldGEP->eraseFromParent();
482 }
483 }
484
485 // Support for BPF address spaces:
486 // - for each function in the module M, update pointer operand of
487 // each memory access instruction (load/store/cmpxchg/atomicrmw)
488 // by casting it from non-zero address space to zero address space, e.g:
489 //
490 // (load (ptr addrspace (N) %p) ...)
491 // -> (load (addrspacecast ptr addrspace (N) %p to ptr))
492 //
493 // - assign section with name .addr_space.N for globals defined in
494 // non-zero address space N
insertASpaceCasts(Module & M)495 bool BPFCheckAndAdjustIR::insertASpaceCasts(Module &M) {
496 bool Changed = false;
497 for (Function &F : M) {
498 DenseMap<Value *, Value *> CastsCache;
499 for (BasicBlock &BB : F) {
500 for (Instruction &I : BB) {
501 unsigned PtrOpNum;
502
503 if (auto *LD = dyn_cast<LoadInst>(&I))
504 PtrOpNum = LD->getPointerOperandIndex();
505 else if (auto *ST = dyn_cast<StoreInst>(&I))
506 PtrOpNum = ST->getPointerOperandIndex();
507 else if (auto *CmpXchg = dyn_cast<AtomicCmpXchgInst>(&I))
508 PtrOpNum = CmpXchg->getPointerOperandIndex();
509 else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I))
510 PtrOpNum = RMW->getPointerOperandIndex();
511 else
512 continue;
513
514 aspaceWrapOperand(CastsCache, &I, PtrOpNum);
515 }
516 }
517 Changed |= !CastsCache.empty();
518 }
519 // Merge all globals within same address space into single
520 // .addr_space.<addr space no> section
521 for (GlobalVariable &G : M.globals()) {
522 if (G.getAddressSpace() == 0 || G.hasSection())
523 continue;
524 SmallString<16> SecName;
525 raw_svector_ostream OS(SecName);
526 OS << ".addr_space." << G.getAddressSpace();
527 G.setSection(SecName);
528 // Prevent having separate section for constants
529 G.setConstant(false);
530 }
531 return Changed;
532 }
533
adjustIR(Module & M)534 bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
535 bool Changed = removePassThroughBuiltin(M);
536 Changed = removeCompareBuiltin(M) || Changed;
537 Changed = sinkMinMax(M) || Changed;
538 Changed = removeGEPBuiltins(M) || Changed;
539 Changed = insertASpaceCasts(M) || Changed;
540 return Changed;
541 }
542
runOnModule(Module & M)543 bool BPFCheckAndAdjustIR::runOnModule(Module &M) {
544 checkIR(M);
545 return adjustIR(M);
546 }
547