1 //===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Check IR and adjust IR for verifier friendly codes. 10 // The following are done for IR checking: 11 // - no relocation globals in PHI node. 12 // The following are done for IR adjustment: 13 // - remove __builtin_bpf_passthrough builtins. Target independent IR 14 // optimizations are done and those builtins can be removed. 15 // - remove llvm.bpf.getelementptr.and.load builtins. 16 // - remove llvm.bpf.getelementptr.and.store builtins. 17 // - for loads and stores with base addresses from non-zero address space 18 // cast base address to zero address space (support for BPF address spaces). 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "BPF.h" 23 #include "BPFCORE.h" 24 #include "BPFTargetMachine.h" 25 #include "llvm/Analysis/LoopInfo.h" 26 #include "llvm/IR/DebugInfoMetadata.h" 27 #include "llvm/IR/GlobalVariable.h" 28 #include "llvm/IR/IRBuilder.h" 29 #include "llvm/IR/Instruction.h" 30 #include "llvm/IR/Instructions.h" 31 #include "llvm/IR/IntrinsicsBPF.h" 32 #include "llvm/IR/Module.h" 33 #include "llvm/IR/Type.h" 34 #include "llvm/IR/User.h" 35 #include "llvm/IR/Value.h" 36 #include "llvm/Pass.h" 37 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 38 39 #define DEBUG_TYPE "bpf-check-and-opt-ir" 40 41 using namespace llvm; 42 43 namespace { 44 45 class BPFCheckAndAdjustIR final : public ModulePass { 46 bool runOnModule(Module &F) override; 47 48 public: 49 static char ID; 50 BPFCheckAndAdjustIR() : ModulePass(ID) {} 51 virtual void getAnalysisUsage(AnalysisUsage &AU) const override; 52 53 private: 54 void checkIR(Module &M); 55 bool adjustIR(Module &M); 56 bool removePassThroughBuiltin(Module &M); 57 bool removeCompareBuiltin(Module &M); 58 bool sinkMinMax(Module &M); 59 bool removeGEPBuiltins(Module &M); 60 bool insertASpaceCasts(Module &M); 61 }; 62 } // End anonymous namespace 63 64 char BPFCheckAndAdjustIR::ID = 0; 65 INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR", 66 false, false) 67 68 ModulePass *llvm::createBPFCheckAndAdjustIR() { 69 return new BPFCheckAndAdjustIR(); 70 } 71 72 void BPFCheckAndAdjustIR::checkIR(Module &M) { 73 // Ensure relocation global won't appear in PHI node 74 // This may happen if the compiler generated the following code: 75 // B1: 76 // g1 = @llvm.skb_buff:0:1... 77 // ... 78 // goto B_COMMON 79 // B2: 80 // g2 = @llvm.skb_buff:0:2... 81 // ... 82 // goto B_COMMON 83 // B_COMMON: 84 // g = PHI(g1, g2) 85 // x = load g 86 // ... 87 // If anything likes the above "g = PHI(g1, g2)", issue a fatal error. 88 for (Function &F : M) 89 for (auto &BB : F) 90 for (auto &I : BB) { 91 PHINode *PN = dyn_cast<PHINode>(&I); 92 if (!PN || PN->use_empty()) 93 continue; 94 for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) { 95 auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i)); 96 if (!GV) 97 continue; 98 if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) || 99 GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr)) 100 report_fatal_error("relocation global in PHI node"); 101 } 102 } 103 } 104 105 bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) { 106 // Remove __builtin_bpf_passthrough()'s which are used to prevent 107 // certain IR optimizations. Now major IR optimizations are done, 108 // remove them. 109 bool Changed = false; 110 CallInst *ToBeDeleted = nullptr; 111 for (Function &F : M) 112 for (auto &BB : F) 113 for (auto &I : BB) { 114 if (ToBeDeleted) { 115 ToBeDeleted->eraseFromParent(); 116 ToBeDeleted = nullptr; 117 } 118 119 auto *Call = dyn_cast<CallInst>(&I); 120 if (!Call) 121 continue; 122 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand()); 123 if (!GV) 124 continue; 125 if (!GV->getName().starts_with("llvm.bpf.passthrough")) 126 continue; 127 Changed = true; 128 Value *Arg = Call->getArgOperand(1); 129 Call->replaceAllUsesWith(Arg); 130 ToBeDeleted = Call; 131 } 132 return Changed; 133 } 134 135 bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) { 136 // Remove __builtin_bpf_compare()'s which are used to prevent 137 // certain IR optimizations. Now major IR optimizations are done, 138 // remove them. 139 bool Changed = false; 140 CallInst *ToBeDeleted = nullptr; 141 for (Function &F : M) 142 for (auto &BB : F) 143 for (auto &I : BB) { 144 if (ToBeDeleted) { 145 ToBeDeleted->eraseFromParent(); 146 ToBeDeleted = nullptr; 147 } 148 149 auto *Call = dyn_cast<CallInst>(&I); 150 if (!Call) 151 continue; 152 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand()); 153 if (!GV) 154 continue; 155 if (!GV->getName().starts_with("llvm.bpf.compare")) 156 continue; 157 158 Changed = true; 159 Value *Arg0 = Call->getArgOperand(0); 160 Value *Arg1 = Call->getArgOperand(1); 161 Value *Arg2 = Call->getArgOperand(2); 162 163 auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue(); 164 CmpInst::Predicate Opcode = (CmpInst::Predicate)OpVal; 165 166 auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2); 167 ICmp->insertBefore(Call); 168 169 Call->replaceAllUsesWith(ICmp); 170 ToBeDeleted = Call; 171 } 172 return Changed; 173 } 174 175 struct MinMaxSinkInfo { 176 ICmpInst *ICmp; 177 Value *Other; 178 ICmpInst::Predicate Predicate; 179 CallInst *MinMax; 180 ZExtInst *ZExt; 181 SExtInst *SExt; 182 183 MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate) 184 : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr), 185 ZExt(nullptr), SExt(nullptr) {} 186 }; 187 188 static bool sinkMinMaxInBB(BasicBlock &BB, 189 const std::function<bool(Instruction *)> &Filter) { 190 // Check if V is: 191 // (fn %a %b) or (ext (fn %a %b)) 192 // Where: 193 // ext := sext | zext 194 // fn := smin | umin | smax | umax 195 auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) { 196 if (auto *ZExt = dyn_cast<ZExtInst>(V)) { 197 V = ZExt->getOperand(0); 198 Info.ZExt = ZExt; 199 } else if (auto *SExt = dyn_cast<SExtInst>(V)) { 200 V = SExt->getOperand(0); 201 Info.SExt = SExt; 202 } 203 204 auto *Call = dyn_cast<CallInst>(V); 205 if (!Call) 206 return false; 207 208 auto *Called = dyn_cast<Function>(Call->getCalledOperand()); 209 if (!Called) 210 return false; 211 212 switch (Called->getIntrinsicID()) { 213 case Intrinsic::smin: 214 case Intrinsic::umin: 215 case Intrinsic::smax: 216 case Intrinsic::umax: 217 break; 218 default: 219 return false; 220 } 221 222 if (!Filter(Call)) 223 return false; 224 225 Info.MinMax = Call; 226 227 return true; 228 }; 229 230 auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V, 231 MinMaxSinkInfo &Info) { 232 if (Info.SExt) { 233 if (Info.SExt->getType() == V->getType()) 234 return V; 235 return Builder.CreateSExt(V, Info.SExt->getType()); 236 } 237 if (Info.ZExt) { 238 if (Info.ZExt->getType() == V->getType()) 239 return V; 240 return Builder.CreateZExt(V, Info.ZExt->getType()); 241 } 242 return V; 243 }; 244 245 bool Changed = false; 246 SmallVector<MinMaxSinkInfo, 2> SinkList; 247 248 // Check BB for instructions like: 249 // insn := (icmp %a (fn ...)) | (icmp (fn ...) %a) 250 // 251 // Where: 252 // fn := min | max | (sext (min ...)) | (sext (max ...)) 253 // 254 // Put such instructions to SinkList. 255 for (Instruction &I : BB) { 256 ICmpInst *ICmp = dyn_cast<ICmpInst>(&I); 257 if (!ICmp) 258 continue; 259 if (!ICmp->isRelational()) 260 continue; 261 MinMaxSinkInfo First(ICmp, ICmp->getOperand(1), 262 ICmpInst::getSwappedPredicate(ICmp->getPredicate())); 263 MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate()); 264 bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First); 265 bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second); 266 if (!(FirstMinMax ^ SecondMinMax)) 267 continue; 268 SinkList.push_back(FirstMinMax ? First : Second); 269 } 270 271 // Iterate SinkList and replace each (icmp ...) with corresponding 272 // `x < a && x < b` or similar expression. 273 for (auto &Info : SinkList) { 274 ICmpInst *ICmp = Info.ICmp; 275 CallInst *MinMax = Info.MinMax; 276 Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID(); 277 ICmpInst::Predicate P = Info.Predicate; 278 if (ICmpInst::isSigned(P) && IID != Intrinsic::smin && 279 IID != Intrinsic::smax) 280 continue; 281 282 IRBuilder<> Builder(ICmp); 283 Value *X = Info.Other; 284 Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info); 285 Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info); 286 bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin; 287 bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax; 288 bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P); 289 bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P); 290 assert(IsMin ^ IsMax); 291 assert(IsLess ^ IsGreater); 292 293 Value *Replacement; 294 Value *LHS = Builder.CreateICmp(P, X, A); 295 Value *RHS = Builder.CreateICmp(P, X, B); 296 if ((IsLess && IsMin) || (IsGreater && IsMax)) 297 // x < min(a, b) -> x < a && x < b 298 // x > max(a, b) -> x > a && x > b 299 Replacement = Builder.CreateLogicalAnd(LHS, RHS); 300 else 301 // x > min(a, b) -> x > a || x > b 302 // x < max(a, b) -> x < a || x < b 303 Replacement = Builder.CreateLogicalOr(LHS, RHS); 304 305 ICmp->replaceAllUsesWith(Replacement); 306 307 Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax}; 308 for (Instruction *I : ToRemove) 309 if (I && I->use_empty()) 310 I->eraseFromParent(); 311 312 Changed = true; 313 } 314 315 return Changed; 316 } 317 318 // Do the following transformation: 319 // 320 // x < min(a, b) -> x < a && x < b 321 // x > min(a, b) -> x > a || x > b 322 // x < max(a, b) -> x < a || x < b 323 // x > max(a, b) -> x > a && x > b 324 // 325 // Such patterns are introduced by LICM.cpp:hoistMinMax() 326 // transformation and might lead to BPF verification failures for 327 // older kernels. 328 // 329 // To minimize "collateral" changes only do it for icmp + min/max 330 // calls when icmp is inside a loop and min/max is outside of that 331 // loop. 332 // 333 // Verification failure happens when: 334 // - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1; 335 // - verifier can recognize RHS as a constant scalar in some context; 336 // - verifier can't recognize RHS1 as a constant scalar in the same 337 // context; 338 // 339 // The "constant scalar" is not a compile time constant, but a register 340 // that holds a scalar value known to verifier at some point in time 341 // during abstract interpretation. 342 // 343 // See also: 344 // https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/ 345 bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) { 346 bool Changed = false; 347 348 for (Function &F : M) { 349 if (F.isDeclaration()) 350 continue; 351 352 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo(); 353 for (Loop *L : LI) 354 for (BasicBlock *BB : L->blocks()) { 355 // Filter out instructions coming from the same loop 356 Loop *BBLoop = LI.getLoopFor(BB); 357 auto OtherLoopFilter = [&](Instruction *I) { 358 return LI.getLoopFor(I->getParent()) != BBLoop; 359 }; 360 Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter); 361 } 362 } 363 364 return Changed; 365 } 366 367 void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const { 368 AU.addRequired<LoopInfoWrapperPass>(); 369 } 370 371 static void unrollGEPLoad(CallInst *Call) { 372 auto [GEP, Load] = BPFPreserveStaticOffsetPass::reconstructLoad(Call); 373 GEP->insertBefore(Call); 374 Load->insertBefore(Call); 375 Call->replaceAllUsesWith(Load); 376 Call->eraseFromParent(); 377 } 378 379 static void unrollGEPStore(CallInst *Call) { 380 auto [GEP, Store] = BPFPreserveStaticOffsetPass::reconstructStore(Call); 381 GEP->insertBefore(Call); 382 Store->insertBefore(Call); 383 Call->eraseFromParent(); 384 } 385 386 static bool removeGEPBuiltinsInFunc(Function &F) { 387 SmallVector<CallInst *> GEPLoads; 388 SmallVector<CallInst *> GEPStores; 389 for (auto &BB : F) 390 for (auto &Insn : BB) 391 if (auto *Call = dyn_cast<CallInst>(&Insn)) 392 if (auto *Called = Call->getCalledFunction()) 393 switch (Called->getIntrinsicID()) { 394 case Intrinsic::bpf_getelementptr_and_load: 395 GEPLoads.push_back(Call); 396 break; 397 case Intrinsic::bpf_getelementptr_and_store: 398 GEPStores.push_back(Call); 399 break; 400 } 401 402 if (GEPLoads.empty() && GEPStores.empty()) 403 return false; 404 405 for_each(GEPLoads, unrollGEPLoad); 406 for_each(GEPStores, unrollGEPStore); 407 408 return true; 409 } 410 411 // Rewrites the following builtins: 412 // - llvm.bpf.getelementptr.and.load 413 // - llvm.bpf.getelementptr.and.store 414 // As (load (getelementptr ...)) or (store (getelementptr ...)). 415 bool BPFCheckAndAdjustIR::removeGEPBuiltins(Module &M) { 416 bool Changed = false; 417 for (auto &F : M) 418 Changed = removeGEPBuiltinsInFunc(F) || Changed; 419 return Changed; 420 } 421 422 // Wrap ToWrap with cast to address space zero: 423 // - if ToWrap is a getelementptr, 424 // wrap it's base pointer instead and return a copy; 425 // - if ToWrap is Instruction, insert address space cast 426 // immediately after ToWrap; 427 // - if ToWrap is not an Instruction (function parameter 428 // or a global value), insert address space cast at the 429 // beginning of the Function F; 430 // - use Cache to avoid inserting too many casts; 431 static Value *aspaceWrapValue(DenseMap<Value *, Value *> &Cache, Function *F, 432 Value *ToWrap) { 433 auto It = Cache.find(ToWrap); 434 if (It != Cache.end()) 435 return It->getSecond(); 436 437 if (auto *GEP = dyn_cast<GetElementPtrInst>(ToWrap)) { 438 Value *Ptr = GEP->getPointerOperand(); 439 Value *WrappedPtr = aspaceWrapValue(Cache, F, Ptr); 440 auto *GEPTy = cast<PointerType>(GEP->getType()); 441 auto *NewGEP = GEP->clone(); 442 NewGEP->insertAfter(GEP); 443 NewGEP->mutateType(GEPTy->getPointerTo(0)); 444 NewGEP->setOperand(GEP->getPointerOperandIndex(), WrappedPtr); 445 NewGEP->setName(GEP->getName()); 446 Cache[ToWrap] = NewGEP; 447 return NewGEP; 448 } 449 450 IRBuilder IB(F->getContext()); 451 if (Instruction *InsnPtr = dyn_cast<Instruction>(ToWrap)) 452 IB.SetInsertPoint(*InsnPtr->getInsertionPointAfterDef()); 453 else 454 IB.SetInsertPoint(F->getEntryBlock().getFirstInsertionPt()); 455 auto *PtrTy = cast<PointerType>(ToWrap->getType()); 456 auto *ASZeroPtrTy = PtrTy->getPointerTo(0); 457 auto *ACast = IB.CreateAddrSpaceCast(ToWrap, ASZeroPtrTy, ToWrap->getName()); 458 Cache[ToWrap] = ACast; 459 return ACast; 460 } 461 462 // Wrap a pointer operand OpNum of instruction I 463 // with cast to address space zero 464 static void aspaceWrapOperand(DenseMap<Value *, Value *> &Cache, Instruction *I, 465 unsigned OpNum) { 466 Value *OldOp = I->getOperand(OpNum); 467 if (OldOp->getType()->getPointerAddressSpace() == 0) 468 return; 469 470 Value *NewOp = aspaceWrapValue(Cache, I->getFunction(), OldOp); 471 I->setOperand(OpNum, NewOp); 472 // Check if there are any remaining users of old GEP, 473 // delete those w/o users 474 for (;;) { 475 auto *OldGEP = dyn_cast<GetElementPtrInst>(OldOp); 476 if (!OldGEP) 477 break; 478 if (!OldGEP->use_empty()) 479 break; 480 OldOp = OldGEP->getPointerOperand(); 481 OldGEP->eraseFromParent(); 482 } 483 } 484 485 // Support for BPF address spaces: 486 // - for each function in the module M, update pointer operand of 487 // each memory access instruction (load/store/cmpxchg/atomicrmw) 488 // by casting it from non-zero address space to zero address space, e.g: 489 // 490 // (load (ptr addrspace (N) %p) ...) 491 // -> (load (addrspacecast ptr addrspace (N) %p to ptr)) 492 // 493 // - assign section with name .addr_space.N for globals defined in 494 // non-zero address space N 495 bool BPFCheckAndAdjustIR::insertASpaceCasts(Module &M) { 496 bool Changed = false; 497 for (Function &F : M) { 498 DenseMap<Value *, Value *> CastsCache; 499 for (BasicBlock &BB : F) { 500 for (Instruction &I : BB) { 501 unsigned PtrOpNum; 502 503 if (auto *LD = dyn_cast<LoadInst>(&I)) 504 PtrOpNum = LD->getPointerOperandIndex(); 505 else if (auto *ST = dyn_cast<StoreInst>(&I)) 506 PtrOpNum = ST->getPointerOperandIndex(); 507 else if (auto *CmpXchg = dyn_cast<AtomicCmpXchgInst>(&I)) 508 PtrOpNum = CmpXchg->getPointerOperandIndex(); 509 else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I)) 510 PtrOpNum = RMW->getPointerOperandIndex(); 511 else 512 continue; 513 514 aspaceWrapOperand(CastsCache, &I, PtrOpNum); 515 } 516 } 517 Changed |= !CastsCache.empty(); 518 } 519 // Merge all globals within same address space into single 520 // .addr_space.<addr space no> section 521 for (GlobalVariable &G : M.globals()) { 522 if (G.getAddressSpace() == 0 || G.hasSection()) 523 continue; 524 SmallString<16> SecName; 525 raw_svector_ostream OS(SecName); 526 OS << ".addr_space." << G.getAddressSpace(); 527 G.setSection(SecName); 528 // Prevent having separate section for constants 529 G.setConstant(false); 530 } 531 return Changed; 532 } 533 534 bool BPFCheckAndAdjustIR::adjustIR(Module &M) { 535 bool Changed = removePassThroughBuiltin(M); 536 Changed = removeCompareBuiltin(M) || Changed; 537 Changed = sinkMinMax(M) || Changed; 538 Changed = removeGEPBuiltins(M) || Changed; 539 Changed = insertASpaceCasts(M) || Changed; 540 return Changed; 541 } 542 543 bool BPFCheckAndAdjustIR::runOnModule(Module &M) { 544 checkIR(M); 545 return adjustIR(M); 546 } 547