1 //===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Check IR and adjust IR for verifier friendly codes. 10 // The following are done for IR checking: 11 // - no relocation globals in PHI node. 12 // The following are done for IR adjustment: 13 // - remove __builtin_bpf_passthrough builtins. Target independent IR 14 // optimizations are done and those builtins can be removed. 15 // 16 //===----------------------------------------------------------------------===// 17 18 #include "BPF.h" 19 #include "BPFCORE.h" 20 #include "BPFTargetMachine.h" 21 #include "llvm/Analysis/LoopInfo.h" 22 #include "llvm/IR/DebugInfoMetadata.h" 23 #include "llvm/IR/GlobalVariable.h" 24 #include "llvm/IR/IRBuilder.h" 25 #include "llvm/IR/Instruction.h" 26 #include "llvm/IR/Instructions.h" 27 #include "llvm/IR/Module.h" 28 #include "llvm/IR/Type.h" 29 #include "llvm/IR/User.h" 30 #include "llvm/IR/Value.h" 31 #include "llvm/Pass.h" 32 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 33 34 #define DEBUG_TYPE "bpf-check-and-opt-ir" 35 36 using namespace llvm; 37 38 namespace { 39 40 class BPFCheckAndAdjustIR final : public ModulePass { 41 bool runOnModule(Module &F) override; 42 43 public: 44 static char ID; 45 BPFCheckAndAdjustIR() : ModulePass(ID) {} 46 virtual void getAnalysisUsage(AnalysisUsage &AU) const override; 47 48 private: 49 void checkIR(Module &M); 50 bool adjustIR(Module &M); 51 bool removePassThroughBuiltin(Module &M); 52 bool removeCompareBuiltin(Module &M); 53 bool sinkMinMax(Module &M); 54 }; 55 } // End anonymous namespace 56 57 char BPFCheckAndAdjustIR::ID = 0; 58 INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR", 59 false, false) 60 61 ModulePass *llvm::createBPFCheckAndAdjustIR() { 62 return new BPFCheckAndAdjustIR(); 63 } 64 65 void BPFCheckAndAdjustIR::checkIR(Module &M) { 66 // Ensure relocation global won't appear in PHI node 67 // This may happen if the compiler generated the following code: 68 // B1: 69 // g1 = @llvm.skb_buff:0:1... 70 // ... 71 // goto B_COMMON 72 // B2: 73 // g2 = @llvm.skb_buff:0:2... 74 // ... 75 // goto B_COMMON 76 // B_COMMON: 77 // g = PHI(g1, g2) 78 // x = load g 79 // ... 80 // If anything likes the above "g = PHI(g1, g2)", issue a fatal error. 81 for (Function &F : M) 82 for (auto &BB : F) 83 for (auto &I : BB) { 84 PHINode *PN = dyn_cast<PHINode>(&I); 85 if (!PN || PN->use_empty()) 86 continue; 87 for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) { 88 auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i)); 89 if (!GV) 90 continue; 91 if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) || 92 GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr)) 93 report_fatal_error("relocation global in PHI node"); 94 } 95 } 96 } 97 98 bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) { 99 // Remove __builtin_bpf_passthrough()'s which are used to prevent 100 // certain IR optimizations. Now major IR optimizations are done, 101 // remove them. 102 bool Changed = false; 103 CallInst *ToBeDeleted = nullptr; 104 for (Function &F : M) 105 for (auto &BB : F) 106 for (auto &I : BB) { 107 if (ToBeDeleted) { 108 ToBeDeleted->eraseFromParent(); 109 ToBeDeleted = nullptr; 110 } 111 112 auto *Call = dyn_cast<CallInst>(&I); 113 if (!Call) 114 continue; 115 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand()); 116 if (!GV) 117 continue; 118 if (!GV->getName().startswith("llvm.bpf.passthrough")) 119 continue; 120 Changed = true; 121 Value *Arg = Call->getArgOperand(1); 122 Call->replaceAllUsesWith(Arg); 123 ToBeDeleted = Call; 124 } 125 return Changed; 126 } 127 128 bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) { 129 // Remove __builtin_bpf_compare()'s which are used to prevent 130 // certain IR optimizations. Now major IR optimizations are done, 131 // remove them. 132 bool Changed = false; 133 CallInst *ToBeDeleted = nullptr; 134 for (Function &F : M) 135 for (auto &BB : F) 136 for (auto &I : BB) { 137 if (ToBeDeleted) { 138 ToBeDeleted->eraseFromParent(); 139 ToBeDeleted = nullptr; 140 } 141 142 auto *Call = dyn_cast<CallInst>(&I); 143 if (!Call) 144 continue; 145 auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand()); 146 if (!GV) 147 continue; 148 if (!GV->getName().startswith("llvm.bpf.compare")) 149 continue; 150 151 Changed = true; 152 Value *Arg0 = Call->getArgOperand(0); 153 Value *Arg1 = Call->getArgOperand(1); 154 Value *Arg2 = Call->getArgOperand(2); 155 156 auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue(); 157 CmpInst::Predicate Opcode = (CmpInst::Predicate)OpVal; 158 159 auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2); 160 ICmp->insertBefore(Call); 161 162 Call->replaceAllUsesWith(ICmp); 163 ToBeDeleted = Call; 164 } 165 return Changed; 166 } 167 168 struct MinMaxSinkInfo { 169 ICmpInst *ICmp; 170 Value *Other; 171 ICmpInst::Predicate Predicate; 172 CallInst *MinMax; 173 ZExtInst *ZExt; 174 SExtInst *SExt; 175 176 MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate) 177 : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr), 178 ZExt(nullptr), SExt(nullptr) {} 179 }; 180 181 static bool sinkMinMaxInBB(BasicBlock &BB, 182 const std::function<bool(Instruction *)> &Filter) { 183 // Check if V is: 184 // (fn %a %b) or (ext (fn %a %b)) 185 // Where: 186 // ext := sext | zext 187 // fn := smin | umin | smax | umax 188 auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) { 189 if (auto *ZExt = dyn_cast<ZExtInst>(V)) { 190 V = ZExt->getOperand(0); 191 Info.ZExt = ZExt; 192 } else if (auto *SExt = dyn_cast<SExtInst>(V)) { 193 V = SExt->getOperand(0); 194 Info.SExt = SExt; 195 } 196 197 auto *Call = dyn_cast<CallInst>(V); 198 if (!Call) 199 return false; 200 201 auto *Called = dyn_cast<Function>(Call->getCalledOperand()); 202 if (!Called) 203 return false; 204 205 switch (Called->getIntrinsicID()) { 206 case Intrinsic::smin: 207 case Intrinsic::umin: 208 case Intrinsic::smax: 209 case Intrinsic::umax: 210 break; 211 default: 212 return false; 213 } 214 215 if (!Filter(Call)) 216 return false; 217 218 Info.MinMax = Call; 219 220 return true; 221 }; 222 223 auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V, 224 MinMaxSinkInfo &Info) { 225 if (Info.SExt) { 226 if (Info.SExt->getType() == V->getType()) 227 return V; 228 return Builder.CreateSExt(V, Info.SExt->getType()); 229 } 230 if (Info.ZExt) { 231 if (Info.ZExt->getType() == V->getType()) 232 return V; 233 return Builder.CreateZExt(V, Info.ZExt->getType()); 234 } 235 return V; 236 }; 237 238 bool Changed = false; 239 SmallVector<MinMaxSinkInfo, 2> SinkList; 240 241 // Check BB for instructions like: 242 // insn := (icmp %a (fn ...)) | (icmp (fn ...) %a) 243 // 244 // Where: 245 // fn := min | max | (sext (min ...)) | (sext (max ...)) 246 // 247 // Put such instructions to SinkList. 248 for (Instruction &I : BB) { 249 ICmpInst *ICmp = dyn_cast<ICmpInst>(&I); 250 if (!ICmp) 251 continue; 252 if (!ICmp->isRelational()) 253 continue; 254 MinMaxSinkInfo First(ICmp, ICmp->getOperand(1), 255 ICmpInst::getSwappedPredicate(ICmp->getPredicate())); 256 MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate()); 257 bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First); 258 bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second); 259 if (!(FirstMinMax ^ SecondMinMax)) 260 continue; 261 SinkList.push_back(FirstMinMax ? First : Second); 262 } 263 264 // Iterate SinkList and replace each (icmp ...) with corresponding 265 // `x < a && x < b` or similar expression. 266 for (auto &Info : SinkList) { 267 ICmpInst *ICmp = Info.ICmp; 268 CallInst *MinMax = Info.MinMax; 269 Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID(); 270 ICmpInst::Predicate P = Info.Predicate; 271 if (ICmpInst::isSigned(P) && IID != Intrinsic::smin && 272 IID != Intrinsic::smax) 273 continue; 274 275 IRBuilder<> Builder(ICmp); 276 Value *X = Info.Other; 277 Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info); 278 Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info); 279 bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin; 280 bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax; 281 bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P); 282 bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P); 283 assert(IsMin ^ IsMax); 284 assert(IsLess ^ IsGreater); 285 286 Value *Replacement; 287 Value *LHS = Builder.CreateICmp(P, X, A); 288 Value *RHS = Builder.CreateICmp(P, X, B); 289 if ((IsLess && IsMin) || (IsGreater && IsMax)) 290 // x < min(a, b) -> x < a && x < b 291 // x > max(a, b) -> x > a && x > b 292 Replacement = Builder.CreateLogicalAnd(LHS, RHS); 293 else 294 // x > min(a, b) -> x > a || x > b 295 // x < max(a, b) -> x < a || x < b 296 Replacement = Builder.CreateLogicalOr(LHS, RHS); 297 298 ICmp->replaceAllUsesWith(Replacement); 299 300 Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax}; 301 for (Instruction *I : ToRemove) 302 if (I && I->use_empty()) 303 I->eraseFromParent(); 304 305 Changed = true; 306 } 307 308 return Changed; 309 } 310 311 // Do the following transformation: 312 // 313 // x < min(a, b) -> x < a && x < b 314 // x > min(a, b) -> x > a || x > b 315 // x < max(a, b) -> x < a || x < b 316 // x > max(a, b) -> x > a && x > b 317 // 318 // Such patterns are introduced by LICM.cpp:hoistMinMax() 319 // transformation and might lead to BPF verification failures for 320 // older kernels. 321 // 322 // To minimize "collateral" changes only do it for icmp + min/max 323 // calls when icmp is inside a loop and min/max is outside of that 324 // loop. 325 // 326 // Verification failure happens when: 327 // - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1; 328 // - verifier can recognize RHS as a constant scalar in some context; 329 // - verifier can't recognize RHS1 as a constant scalar in the same 330 // context; 331 // 332 // The "constant scalar" is not a compile time constant, but a register 333 // that holds a scalar value known to verifier at some point in time 334 // during abstract interpretation. 335 // 336 // See also: 337 // https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/ 338 bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) { 339 bool Changed = false; 340 341 for (Function &F : M) { 342 if (F.isDeclaration()) 343 continue; 344 345 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo(); 346 for (Loop *L : LI) 347 for (BasicBlock *BB : L->blocks()) { 348 // Filter out instructions coming from the same loop 349 Loop *BBLoop = LI.getLoopFor(BB); 350 auto OtherLoopFilter = [&](Instruction *I) { 351 return LI.getLoopFor(I->getParent()) != BBLoop; 352 }; 353 Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter); 354 } 355 } 356 357 return Changed; 358 } 359 360 void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const { 361 AU.addRequired<LoopInfoWrapperPass>(); 362 } 363 364 bool BPFCheckAndAdjustIR::adjustIR(Module &M) { 365 bool Changed = removePassThroughBuiltin(M); 366 Changed = removeCompareBuiltin(M) || Changed; 367 Changed = sinkMinMax(M) || Changed; 368 return Changed; 369 } 370 371 bool BPFCheckAndAdjustIR::runOnModule(Module &M) { 372 checkIR(M); 373 return adjustIR(M); 374 } 375