1 //===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Check IR and adjust IR for verifier friendly codes. 10 // The following are done for IR checking: 11 // - no relocation globals in PHI node. 12 // The following are done for IR adjustment: 13 // - remove __builtin_bpf_passthrough builtins. Target independent IR 14 // optimizations are done and those builtins can be removed. 15 // - remove llvm.bpf.getelementptr.and.load builtins. 16 // - remove llvm.bpf.getelementptr.and.store builtins. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "BPF.h" 21 #include "BPFCORE.h" 22 #include "BPFTargetMachine.h" 23 #include "llvm/Analysis/LoopInfo.h" 24 #include "llvm/IR/DebugInfoMetadata.h" 25 #include "llvm/IR/GlobalVariable.h" 26 #include "llvm/IR/IRBuilder.h" 27 #include "llvm/IR/Instruction.h" 28 #include "llvm/IR/Instructions.h" 29 #include "llvm/IR/IntrinsicsBPF.h" 30 #include "llvm/IR/Module.h" 31 #include "llvm/IR/Type.h" 32 #include "llvm/IR/User.h" 33 #include "llvm/IR/Value.h" 34 #include "llvm/Pass.h" 35 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 36 37 #define DEBUG_TYPE "bpf-check-and-opt-ir" 38 39 using namespace llvm; 40 41 namespace { 42 43 class BPFCheckAndAdjustIR final : public ModulePass { 44 bool runOnModule(Module &F) override; 45 46 public: 47 static char ID; 48 BPFCheckAndAdjustIR() : ModulePass(ID) {} 49 virtual void getAnalysisUsage(AnalysisUsage &AU) const override; 50 51 private: 52 void checkIR(Module &M); 53 bool adjustIR(Module &M); 54 bool removePassThroughBuiltin(Module &M); 55 bool removeCompareBuiltin(Module &M); 56 bool sinkMinMax(Module &M); 57 bool removeGEPBuiltins(Module &M); 58 }; 59 } // End anonymous namespace 60 61 char BPFCheckAndAdjustIR::ID = 0; 62 INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR", 63 false, false) 64 65 ModulePass *llvm::createBPFCheckAndAdjustIR() { 66 return new BPFCheckAndAdjustIR(); 67 } 68 69 void BPFCheckAndAdjustIR::checkIR(Module &M) { 70 // Ensure relocation global won't appear in PHI node 71 // This may happen if the compiler generated the following code: 72 // B1: 73 // g1 = @llvm.skb_buff:0:1... 74 // ... 75 // goto B_COMMON 76 // B2: 77 // g2 = @llvm.skb_buff:0:2... 78 // ... 79 // goto B_COMMON 80 // B_COMMON: 81 // g = PHI(g1, g2) 82 // x = load g 83 // ... 84 // If anything likes the above "g = PHI(g1, g2)", issue a fatal error. 85 for (Function &F : M) 86 for (auto &BB : F) 87 for (auto &I : BB) { 88 PHINode *PN = dyn_cast<PHINode>(&I); 89 if (!PN || PN->use_empty()) 90 continue; 91 for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) { 92 auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i)); 93 if (!GV) 94 continue; 95 if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) || 96 GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr)) 97 report_fatal_error("relocation global in PHI node"); 98 } 99 } 100 } 101 102 bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) { 103 // Remove __builtin_bpf_passthrough()'s which are used to prevent 104 // certain IR optimizations. Now major IR optimizations are done, 105 // remove them. 106 bool Changed = false; 107 CallInst *ToBeDeleted = nullptr; 108 for (Function &F : M) 109 for (auto &BB : F) 110 for (auto &I : BB) { 111 if (ToBeDeleted) { 112 ToBeDeleted->eraseFromParent(); 113 ToBeDeleted = nullptr; 114 } 115 116 auto *Call = dyn_cast<CallInst>(&I); 117 if (!Call) 118 continue; 119 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand()); 120 if (!GV) 121 continue; 122 if (!GV->getName().starts_with("llvm.bpf.passthrough")) 123 continue; 124 Changed = true; 125 Value *Arg = Call->getArgOperand(1); 126 Call->replaceAllUsesWith(Arg); 127 ToBeDeleted = Call; 128 } 129 return Changed; 130 } 131 132 bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) { 133 // Remove __builtin_bpf_compare()'s which are used to prevent 134 // certain IR optimizations. Now major IR optimizations are done, 135 // remove them. 136 bool Changed = false; 137 CallInst *ToBeDeleted = nullptr; 138 for (Function &F : M) 139 for (auto &BB : F) 140 for (auto &I : BB) { 141 if (ToBeDeleted) { 142 ToBeDeleted->eraseFromParent(); 143 ToBeDeleted = nullptr; 144 } 145 146 auto *Call = dyn_cast<CallInst>(&I); 147 if (!Call) 148 continue; 149 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand()); 150 if (!GV) 151 continue; 152 if (!GV->getName().starts_with("llvm.bpf.compare")) 153 continue; 154 155 Changed = true; 156 Value *Arg0 = Call->getArgOperand(0); 157 Value *Arg1 = Call->getArgOperand(1); 158 Value *Arg2 = Call->getArgOperand(2); 159 160 auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue(); 161 CmpInst::Predicate Opcode = (CmpInst::Predicate)OpVal; 162 163 auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2); 164 ICmp->insertBefore(Call); 165 166 Call->replaceAllUsesWith(ICmp); 167 ToBeDeleted = Call; 168 } 169 return Changed; 170 } 171 172 struct MinMaxSinkInfo { 173 ICmpInst *ICmp; 174 Value *Other; 175 ICmpInst::Predicate Predicate; 176 CallInst *MinMax; 177 ZExtInst *ZExt; 178 SExtInst *SExt; 179 180 MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate) 181 : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr), 182 ZExt(nullptr), SExt(nullptr) {} 183 }; 184 185 static bool sinkMinMaxInBB(BasicBlock &BB, 186 const std::function<bool(Instruction *)> &Filter) { 187 // Check if V is: 188 // (fn %a %b) or (ext (fn %a %b)) 189 // Where: 190 // ext := sext | zext 191 // fn := smin | umin | smax | umax 192 auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) { 193 if (auto *ZExt = dyn_cast<ZExtInst>(V)) { 194 V = ZExt->getOperand(0); 195 Info.ZExt = ZExt; 196 } else if (auto *SExt = dyn_cast<SExtInst>(V)) { 197 V = SExt->getOperand(0); 198 Info.SExt = SExt; 199 } 200 201 auto *Call = dyn_cast<CallInst>(V); 202 if (!Call) 203 return false; 204 205 auto *Called = dyn_cast<Function>(Call->getCalledOperand()); 206 if (!Called) 207 return false; 208 209 switch (Called->getIntrinsicID()) { 210 case Intrinsic::smin: 211 case Intrinsic::umin: 212 case Intrinsic::smax: 213 case Intrinsic::umax: 214 break; 215 default: 216 return false; 217 } 218 219 if (!Filter(Call)) 220 return false; 221 222 Info.MinMax = Call; 223 224 return true; 225 }; 226 227 auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V, 228 MinMaxSinkInfo &Info) { 229 if (Info.SExt) { 230 if (Info.SExt->getType() == V->getType()) 231 return V; 232 return Builder.CreateSExt(V, Info.SExt->getType()); 233 } 234 if (Info.ZExt) { 235 if (Info.ZExt->getType() == V->getType()) 236 return V; 237 return Builder.CreateZExt(V, Info.ZExt->getType()); 238 } 239 return V; 240 }; 241 242 bool Changed = false; 243 SmallVector<MinMaxSinkInfo, 2> SinkList; 244 245 // Check BB for instructions like: 246 // insn := (icmp %a (fn ...)) | (icmp (fn ...) %a) 247 // 248 // Where: 249 // fn := min | max | (sext (min ...)) | (sext (max ...)) 250 // 251 // Put such instructions to SinkList. 252 for (Instruction &I : BB) { 253 ICmpInst *ICmp = dyn_cast<ICmpInst>(&I); 254 if (!ICmp) 255 continue; 256 if (!ICmp->isRelational()) 257 continue; 258 MinMaxSinkInfo First(ICmp, ICmp->getOperand(1), 259 ICmpInst::getSwappedPredicate(ICmp->getPredicate())); 260 MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate()); 261 bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First); 262 bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second); 263 if (!(FirstMinMax ^ SecondMinMax)) 264 continue; 265 SinkList.push_back(FirstMinMax ? First : Second); 266 } 267 268 // Iterate SinkList and replace each (icmp ...) with corresponding 269 // `x < a && x < b` or similar expression. 270 for (auto &Info : SinkList) { 271 ICmpInst *ICmp = Info.ICmp; 272 CallInst *MinMax = Info.MinMax; 273 Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID(); 274 ICmpInst::Predicate P = Info.Predicate; 275 if (ICmpInst::isSigned(P) && IID != Intrinsic::smin && 276 IID != Intrinsic::smax) 277 continue; 278 279 IRBuilder<> Builder(ICmp); 280 Value *X = Info.Other; 281 Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info); 282 Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info); 283 bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin; 284 bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax; 285 bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P); 286 bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P); 287 assert(IsMin ^ IsMax); 288 assert(IsLess ^ IsGreater); 289 290 Value *Replacement; 291 Value *LHS = Builder.CreateICmp(P, X, A); 292 Value *RHS = Builder.CreateICmp(P, X, B); 293 if ((IsLess && IsMin) || (IsGreater && IsMax)) 294 // x < min(a, b) -> x < a && x < b 295 // x > max(a, b) -> x > a && x > b 296 Replacement = Builder.CreateLogicalAnd(LHS, RHS); 297 else 298 // x > min(a, b) -> x > a || x > b 299 // x < max(a, b) -> x < a || x < b 300 Replacement = Builder.CreateLogicalOr(LHS, RHS); 301 302 ICmp->replaceAllUsesWith(Replacement); 303 304 Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax}; 305 for (Instruction *I : ToRemove) 306 if (I && I->use_empty()) 307 I->eraseFromParent(); 308 309 Changed = true; 310 } 311 312 return Changed; 313 } 314 315 // Do the following transformation: 316 // 317 // x < min(a, b) -> x < a && x < b 318 // x > min(a, b) -> x > a || x > b 319 // x < max(a, b) -> x < a || x < b 320 // x > max(a, b) -> x > a && x > b 321 // 322 // Such patterns are introduced by LICM.cpp:hoistMinMax() 323 // transformation and might lead to BPF verification failures for 324 // older kernels. 325 // 326 // To minimize "collateral" changes only do it for icmp + min/max 327 // calls when icmp is inside a loop and min/max is outside of that 328 // loop. 329 // 330 // Verification failure happens when: 331 // - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1; 332 // - verifier can recognize RHS as a constant scalar in some context; 333 // - verifier can't recognize RHS1 as a constant scalar in the same 334 // context; 335 // 336 // The "constant scalar" is not a compile time constant, but a register 337 // that holds a scalar value known to verifier at some point in time 338 // during abstract interpretation. 339 // 340 // See also: 341 // https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/ 342 bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) { 343 bool Changed = false; 344 345 for (Function &F : M) { 346 if (F.isDeclaration()) 347 continue; 348 349 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo(); 350 for (Loop *L : LI) 351 for (BasicBlock *BB : L->blocks()) { 352 // Filter out instructions coming from the same loop 353 Loop *BBLoop = LI.getLoopFor(BB); 354 auto OtherLoopFilter = [&](Instruction *I) { 355 return LI.getLoopFor(I->getParent()) != BBLoop; 356 }; 357 Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter); 358 } 359 } 360 361 return Changed; 362 } 363 364 void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const { 365 AU.addRequired<LoopInfoWrapperPass>(); 366 } 367 368 static void unrollGEPLoad(CallInst *Call) { 369 auto [GEP, Load] = BPFPreserveStaticOffsetPass::reconstructLoad(Call); 370 GEP->insertBefore(Call); 371 Load->insertBefore(Call); 372 Call->replaceAllUsesWith(Load); 373 Call->eraseFromParent(); 374 } 375 376 static void unrollGEPStore(CallInst *Call) { 377 auto [GEP, Store] = BPFPreserveStaticOffsetPass::reconstructStore(Call); 378 GEP->insertBefore(Call); 379 Store->insertBefore(Call); 380 Call->eraseFromParent(); 381 } 382 383 static bool removeGEPBuiltinsInFunc(Function &F) { 384 SmallVector<CallInst *> GEPLoads; 385 SmallVector<CallInst *> GEPStores; 386 for (auto &BB : F) 387 for (auto &Insn : BB) 388 if (auto *Call = dyn_cast<CallInst>(&Insn)) 389 if (auto *Called = Call->getCalledFunction()) 390 switch (Called->getIntrinsicID()) { 391 case Intrinsic::bpf_getelementptr_and_load: 392 GEPLoads.push_back(Call); 393 break; 394 case Intrinsic::bpf_getelementptr_and_store: 395 GEPStores.push_back(Call); 396 break; 397 } 398 399 if (GEPLoads.empty() && GEPStores.empty()) 400 return false; 401 402 for_each(GEPLoads, unrollGEPLoad); 403 for_each(GEPStores, unrollGEPStore); 404 405 return true; 406 } 407 408 // Rewrites the following builtins: 409 // - llvm.bpf.getelementptr.and.load 410 // - llvm.bpf.getelementptr.and.store 411 // As (load (getelementptr ...)) or (store (getelementptr ...)). 412 bool BPFCheckAndAdjustIR::removeGEPBuiltins(Module &M) { 413 bool Changed = false; 414 for (auto &F : M) 415 Changed = removeGEPBuiltinsInFunc(F) || Changed; 416 return Changed; 417 } 418 419 bool BPFCheckAndAdjustIR::adjustIR(Module &M) { 420 bool Changed = removePassThroughBuiltin(M); 421 Changed = removeCompareBuiltin(M) || Changed; 422 Changed = sinkMinMax(M) || Changed; 423 Changed = removeGEPBuiltins(M) || Changed; 424 return Changed; 425 } 426 427 bool BPFCheckAndAdjustIR::runOnModule(Module &M) { 428 checkIR(M); 429 return adjustIR(M); 430 } 431