//===- LowerExpectIntrinsic.cpp - Lower expect intrinsic ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers the 'expect' intrinsic to LLVM metadata.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerExpectIntrinsic.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/MisExpect.h"

#include <cmath>
#include <cstdint>
#include <tuple>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "lower-expect-intrinsic"

STATISTIC(ExpectIntrinsicsHandled,
          "Number of 'expect' intrinsic instructions handled");

// These default values are chosen to represent an extremely skewed outcome for
// a condition, but they leave some room for interpretation by later passes.
//
// If the documentation for __builtin_expect() was made explicit that it should
// only be used in extreme cases, we could make this ratio higher. As it stands,
// programmers may be using __builtin_expect() / llvm.expect to annotate that a
// branch is likely or unlikely to be taken.

// WARNING: these values are internal implementation detail of the pass.
48 // They should not be exposed to the outside of the pass, front-end codegen 49 // should emit @llvm.expect intrinsics instead of using these weights directly. 50 // Transforms should use TargetTransformInfo's getPredictableBranchThreshold(). 51 static cl::opt<uint32_t> LikelyBranchWeight( 52 "likely-branch-weight", cl::Hidden, cl::init(2000), 53 cl::desc("Weight of the branch likely to be taken (default = 2000)")); 54 static cl::opt<uint32_t> UnlikelyBranchWeight( 55 "unlikely-branch-weight", cl::Hidden, cl::init(1), 56 cl::desc("Weight of the branch unlikely to be taken (default = 1)")); 57 58 static std::tuple<uint32_t, uint32_t> 59 getBranchWeight(Intrinsic::ID IntrinsicID, CallInst *CI, int BranchCount) { 60 if (IntrinsicID == Intrinsic::expect) { 61 // __builtin_expect 62 return std::make_tuple(LikelyBranchWeight.getValue(), 63 UnlikelyBranchWeight.getValue()); 64 } else { 65 // __builtin_expect_with_probability 66 assert(CI->getNumOperands() >= 3 && 67 "expect with probability must have 3 arguments"); 68 auto *Confidence = cast<ConstantFP>(CI->getArgOperand(2)); 69 double TrueProb = Confidence->getValueAPF().convertToDouble(); 70 assert((TrueProb >= 0.0 && TrueProb <= 1.0) && 71 "probability value must be in the range [0.0, 1.0]"); 72 double FalseProb = (1.0 - TrueProb) / (BranchCount - 1); 73 uint32_t LikelyBW = ceil((TrueProb * (double)(INT32_MAX - 1)) + 1.0); 74 uint32_t UnlikelyBW = ceil((FalseProb * (double)(INT32_MAX - 1)) + 1.0); 75 return std::make_tuple(LikelyBW, UnlikelyBW); 76 } 77 } 78 79 static bool handleSwitchExpect(SwitchInst &SI) { 80 CallInst *CI = dyn_cast<CallInst>(SI.getCondition()); 81 if (!CI) 82 return false; 83 84 Function *Fn = CI->getCalledFunction(); 85 if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect && 86 Fn->getIntrinsicID() != Intrinsic::expect_with_probability)) 87 return false; 88 89 Value *ArgValue = CI->getArgOperand(0); 90 ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1)); 91 if 
(!ExpectedValue) 92 return false; 93 94 SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue); 95 unsigned n = SI.getNumCases(); // +1 for default case. 96 uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal; 97 std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = 98 getBranchWeight(Fn->getIntrinsicID(), CI, n + 1); 99 100 SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeightVal); 101 102 uint64_t Index = (Case == *SI.case_default()) ? 0 : Case.getCaseIndex() + 1; 103 Weights[Index] = LikelyBranchWeightVal; 104 105 misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true); 106 107 SI.setCondition(ArgValue); 108 109 SI.setMetadata(LLVMContext::MD_prof, 110 MDBuilder(CI->getContext()).createBranchWeights(Weights)); 111 112 return true; 113 } 114 115 /// Handler for PHINodes that define the value argument to an 116 /// @llvm.expect call. 117 /// 118 /// If the operand of the phi has a constant value and it 'contradicts' 119 /// with the expected value of phi def, then the corresponding incoming 120 /// edge of the phi is unlikely to be taken. Using that information, 121 /// the branch probability info for the originating branch can be inferred. 122 static void handlePhiDef(CallInst *Expect) { 123 Value &Arg = *Expect->getArgOperand(0); 124 ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(Expect->getArgOperand(1)); 125 if (!ExpectedValue) 126 return; 127 const APInt &ExpectedPhiValue = ExpectedValue->getValue(); 128 bool ExpectedValueIsLikely = true; 129 Function *Fn = Expect->getCalledFunction(); 130 // If the function is expect_with_probability, then we need to take the 131 // probability into consideration. For example, in 132 // expect.with.probability.i64(i64 %a, i64 1, double 0.0), the 133 // "ExpectedValue" 1 is unlikely. This affects probability propagation later. 
134 if (Fn->getIntrinsicID() == Intrinsic::expect_with_probability) { 135 auto *Confidence = cast<ConstantFP>(Expect->getArgOperand(2)); 136 double TrueProb = Confidence->getValueAPF().convertToDouble(); 137 ExpectedValueIsLikely = (TrueProb > 0.5); 138 } 139 140 // Walk up in backward a list of instructions that 141 // have 'copy' semantics by 'stripping' the copies 142 // until a PHI node or an instruction of unknown kind 143 // is reached. Negation via xor is also handled. 144 // 145 // C = PHI(...); 146 // B = C; 147 // A = B; 148 // D = __builtin_expect(A, 0); 149 // 150 Value *V = &Arg; 151 SmallVector<Instruction *, 4> Operations; 152 while (!isa<PHINode>(V)) { 153 if (ZExtInst *ZExt = dyn_cast<ZExtInst>(V)) { 154 V = ZExt->getOperand(0); 155 Operations.push_back(ZExt); 156 continue; 157 } 158 159 if (SExtInst *SExt = dyn_cast<SExtInst>(V)) { 160 V = SExt->getOperand(0); 161 Operations.push_back(SExt); 162 continue; 163 } 164 165 BinaryOperator *BinOp = dyn_cast<BinaryOperator>(V); 166 if (!BinOp || BinOp->getOpcode() != Instruction::Xor) 167 return; 168 169 ConstantInt *CInt = dyn_cast<ConstantInt>(BinOp->getOperand(1)); 170 if (!CInt) 171 return; 172 173 V = BinOp->getOperand(0); 174 Operations.push_back(BinOp); 175 } 176 177 // Executes the recorded operations on input 'Value'. 
178 auto ApplyOperations = [&](const APInt &Value) { 179 APInt Result = Value; 180 for (auto *Op : llvm::reverse(Operations)) { 181 switch (Op->getOpcode()) { 182 case Instruction::Xor: 183 Result ^= cast<ConstantInt>(Op->getOperand(1))->getValue(); 184 break; 185 case Instruction::ZExt: 186 Result = Result.zext(Op->getType()->getIntegerBitWidth()); 187 break; 188 case Instruction::SExt: 189 Result = Result.sext(Op->getType()->getIntegerBitWidth()); 190 break; 191 default: 192 llvm_unreachable("Unexpected operation"); 193 } 194 } 195 return Result; 196 }; 197 198 auto *PhiDef = cast<PHINode>(V); 199 200 // Get the first dominating conditional branch of the operand 201 // i's incoming block. 202 auto GetDomConditional = [&](unsigned i) -> BranchInst * { 203 BasicBlock *BB = PhiDef->getIncomingBlock(i); 204 BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); 205 if (BI && BI->isConditional()) 206 return BI; 207 BB = BB->getSinglePredecessor(); 208 if (!BB) 209 return nullptr; 210 BI = dyn_cast<BranchInst>(BB->getTerminator()); 211 if (!BI || BI->isUnconditional()) 212 return nullptr; 213 return BI; 214 }; 215 216 // Now walk through all Phi operands to find phi oprerands with values 217 // conflicting with the expected phi output value. Any such operand 218 // indicates the incoming edge to that operand is unlikely. 219 for (unsigned i = 0, e = PhiDef->getNumIncomingValues(); i != e; ++i) { 220 221 Value *PhiOpnd = PhiDef->getIncomingValue(i); 222 ConstantInt *CI = dyn_cast<ConstantInt>(PhiOpnd); 223 if (!CI) 224 continue; 225 226 // Not an interesting case when IsUnlikely is false -- we can not infer 227 // anything useful when: 228 // (1) We expect some phi output and the operand value matches it, or 229 // (2) We don't expect some phi output (i.e. the "ExpectedValue" has low 230 // probability) and the operand value doesn't match that. 
231 const APInt &CurrentPhiValue = ApplyOperations(CI->getValue()); 232 if (ExpectedValueIsLikely == (ExpectedPhiValue == CurrentPhiValue)) 233 continue; 234 235 BranchInst *BI = GetDomConditional(i); 236 if (!BI) 237 continue; 238 239 MDBuilder MDB(PhiDef->getContext()); 240 241 // There are two situations in which an operand of the PhiDef comes 242 // from a given successor of a branch instruction BI. 243 // 1) When the incoming block of the operand is the successor block; 244 // 2) When the incoming block is BI's enclosing block and the 245 // successor is the PhiDef's enclosing block. 246 // 247 // Returns true if the operand which comes from OpndIncomingBB 248 // comes from outgoing edge of BI that leads to Succ block. 249 auto *OpndIncomingBB = PhiDef->getIncomingBlock(i); 250 auto IsOpndComingFromSuccessor = [&](BasicBlock *Succ) { 251 if (OpndIncomingBB == Succ) 252 // If this successor is the incoming block for this 253 // Phi operand, then this successor does lead to the Phi. 254 return true; 255 if (OpndIncomingBB == BI->getParent() && Succ == PhiDef->getParent()) 256 // Otherwise, if the edge is directly from the branch 257 // to the Phi, this successor is the one feeding this 258 // Phi operand. 259 return true; 260 return false; 261 }; 262 uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal; 263 std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight( 264 Expect->getCalledFunction()->getIntrinsicID(), Expect, 2); 265 if (!ExpectedValueIsLikely) 266 std::swap(LikelyBranchWeightVal, UnlikelyBranchWeightVal); 267 268 if (IsOpndComingFromSuccessor(BI->getSuccessor(1))) 269 BI->setMetadata(LLVMContext::MD_prof, 270 MDB.createBranchWeights(LikelyBranchWeightVal, 271 UnlikelyBranchWeightVal)); 272 else if (IsOpndComingFromSuccessor(BI->getSuccessor(0))) 273 BI->setMetadata(LLVMContext::MD_prof, 274 MDB.createBranchWeights(UnlikelyBranchWeightVal, 275 LikelyBranchWeightVal)); 276 } 277 } 278 279 // Handle both BranchInst and SelectInst. 
280 template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { 281 282 // Handle non-optimized IR code like: 283 // %expval = call i64 @llvm.expect.i64(i64 %conv1, i64 1) 284 // %tobool = icmp ne i64 %expval, 0 285 // br i1 %tobool, label %if.then, label %if.end 286 // 287 // Or the following simpler case: 288 // %expval = call i1 @llvm.expect.i1(i1 %cmp, i1 1) 289 // br i1 %expval, label %if.then, label %if.end 290 291 CallInst *CI; 292 293 ICmpInst *CmpI = dyn_cast<ICmpInst>(BSI.getCondition()); 294 CmpInst::Predicate Predicate; 295 ConstantInt *CmpConstOperand = nullptr; 296 if (!CmpI) { 297 CI = dyn_cast<CallInst>(BSI.getCondition()); 298 Predicate = CmpInst::ICMP_NE; 299 } else { 300 Predicate = CmpI->getPredicate(); 301 if (Predicate != CmpInst::ICMP_NE && Predicate != CmpInst::ICMP_EQ) 302 return false; 303 304 CmpConstOperand = dyn_cast<ConstantInt>(CmpI->getOperand(1)); 305 if (!CmpConstOperand) 306 return false; 307 CI = dyn_cast<CallInst>(CmpI->getOperand(0)); 308 } 309 310 if (!CI) 311 return false; 312 313 uint64_t ValueComparedTo = 0; 314 if (CmpConstOperand) { 315 if (CmpConstOperand->getBitWidth() > 64) 316 return false; 317 ValueComparedTo = CmpConstOperand->getZExtValue(); 318 } 319 320 Function *Fn = CI->getCalledFunction(); 321 if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect && 322 Fn->getIntrinsicID() != Intrinsic::expect_with_probability)) 323 return false; 324 325 Value *ArgValue = CI->getArgOperand(0); 326 ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1)); 327 if (!ExpectedValue) 328 return false; 329 330 MDBuilder MDB(CI->getContext()); 331 MDNode *Node; 332 333 uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal; 334 std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = 335 getBranchWeight(Fn->getIntrinsicID(), CI, 2); 336 337 SmallVector<uint32_t, 4> ExpectedWeights; 338 if ((ExpectedValue->getZExtValue() == ValueComparedTo) == 339 (Predicate == CmpInst::ICMP_EQ)) { 340 Node = 341 
MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal); 342 ExpectedWeights = {LikelyBranchWeightVal, UnlikelyBranchWeightVal}; 343 } else { 344 Node = 345 MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal); 346 ExpectedWeights = {UnlikelyBranchWeightVal, LikelyBranchWeightVal}; 347 } 348 349 if (CmpI) 350 CmpI->setOperand(0, ArgValue); 351 else 352 BSI.setCondition(ArgValue); 353 354 misexpect::checkFrontendInstrumentation(BSI, ExpectedWeights); 355 356 BSI.setMetadata(LLVMContext::MD_prof, Node); 357 358 return true; 359 } 360 361 static bool handleBranchExpect(BranchInst &BI) { 362 if (BI.isUnconditional()) 363 return false; 364 365 return handleBrSelExpect<BranchInst>(BI); 366 } 367 368 static bool lowerExpectIntrinsic(Function &F) { 369 bool Changed = false; 370 371 for (BasicBlock &BB : F) { 372 // Create "block_weights" metadata. 373 if (BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator())) { 374 if (handleBranchExpect(*BI)) 375 ExpectIntrinsicsHandled++; 376 } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator())) { 377 if (handleSwitchExpect(*SI)) 378 ExpectIntrinsicsHandled++; 379 } 380 381 // Remove llvm.expect intrinsics. Iterate backwards in order 382 // to process select instructions before the intrinsic gets 383 // removed. 
384 for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(BB))) { 385 CallInst *CI = dyn_cast<CallInst>(&Inst); 386 if (!CI) { 387 if (SelectInst *SI = dyn_cast<SelectInst>(&Inst)) { 388 if (handleBrSelExpect(*SI)) 389 ExpectIntrinsicsHandled++; 390 } 391 continue; 392 } 393 394 Function *Fn = CI->getCalledFunction(); 395 if (Fn && (Fn->getIntrinsicID() == Intrinsic::expect || 396 Fn->getIntrinsicID() == Intrinsic::expect_with_probability)) { 397 // Before erasing the llvm.expect, walk backward to find 398 // phi that define llvm.expect's first arg, and 399 // infer branch probability: 400 handlePhiDef(CI); 401 Value *Exp = CI->getArgOperand(0); 402 CI->replaceAllUsesWith(Exp); 403 CI->eraseFromParent(); 404 Changed = true; 405 } 406 } 407 } 408 409 return Changed; 410 } 411 412 PreservedAnalyses LowerExpectIntrinsicPass::run(Function &F, 413 FunctionAnalysisManager &) { 414 if (lowerExpectIntrinsic(F)) 415 return PreservedAnalyses::none(); 416 417 return PreservedAnalyses::all(); 418 } 419 420 namespace { 421 /// Legacy pass for lowering expect intrinsics out of the IR. 422 /// 423 /// When this pass is run over a function it uses expect intrinsics which feed 424 /// branches and switches to provide branch weight metadata for those 425 /// terminators. It then removes the expect intrinsics from the IR so the rest 426 /// of the optimizer can ignore them. 427 class LowerExpectIntrinsic : public FunctionPass { 428 public: 429 static char ID; 430 LowerExpectIntrinsic() : FunctionPass(ID) { 431 initializeLowerExpectIntrinsicPass(*PassRegistry::getPassRegistry()); 432 } 433 434 bool runOnFunction(Function &F) override { return lowerExpectIntrinsic(F); } 435 }; 436 } // namespace 437 438 char LowerExpectIntrinsic::ID = 0; 439 INITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect", 440 "Lower 'expect' Intrinsics", false, false) 441 442 FunctionPass *llvm::createLowerExpectIntrinsicPass() { 443 return new LowerExpectIntrinsic(); 444 } 445