xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/BranchProbabilityInfo.cpp (revision 4b50c451720d8b427757a6da1dd2bb4c52cd9e35)
1 //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Loops should be simplified before this analysis.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Analysis/BranchProbabilityInfo.h"
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/ADT/SCCIterator.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/TargetLibraryInfo.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/CFG.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Dominators.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/InstrTypes.h"
27 #include "llvm/IR/Instruction.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/LLVMContext.h"
30 #include "llvm/IR/Metadata.h"
31 #include "llvm/IR/PassManager.h"
32 #include "llvm/IR/Type.h"
33 #include "llvm/IR/Value.h"
34 #include "llvm/Pass.h"
35 #include "llvm/Support/BranchProbability.h"
36 #include "llvm/Support/Casting.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <cassert>
40 #include <cstdint>
41 #include <iterator>
42 #include <utility>
43 
44 using namespace llvm;
45 
46 #define DEBUG_TYPE "branch-prob"
47 
// Hidden boolean command-line flag "print-bpi"; it is initialised to false.
static cl::opt<bool> PrintBranchProb(
    "print-bpi", cl::init(false), cl::Hidden,
    cl::desc("Print the branch probability info."));

// Hidden string command-line option "print-bpi-func-name". Unlike
// PrintBranchProb above it is not declared static, so it has external linkage.
cl::opt<std::string> PrintBranchProbFuncName(
    "print-bpi-func-name", cl::Hidden,
    cl::desc("The option to specify the name of the function "
             "whose branch probability info is printed."));
56 
// Register BranchProbabilityInfoWrapperPass under the name "branch-prob" and
// declare the two wrapper passes it depends on (LoopInfo and
// TargetLibraryInfo).
// NOTE(review): the trailing `false, true` arguments are presumably the
// "CFG only" and "is analysis" flags of the INITIALIZE_PASS macros - confirm
// against the macro definitions.
INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
                      "Branch Probability Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
                    "Branch Probability Analysis", false, true)

// Definition of the pass's static ID member.
char BranchProbabilityInfoWrapperPass::ID = 0;
65 
66 // Weights are for internal use only. They are used by heuristics to help to
67 // estimate edges' probability. Example:
68 //
69 // Using "Loop Branch Heuristics" we predict weights of edges for the
70 // block BB2.
71 //         ...
72 //          |
73 //          V
74 //         BB1<-+
75 //          |   |
76 //          |   | (Weight = 124)
77 //          V   |
78 //         BB2--+
79 //          |
80 //          | (Weight = 4)
81 //          V
82 //         BB3
83 //
84 // Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
85 // Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
// Loop Branch Heuristics: weight of an edge that stays in the loop.
static constexpr uint32_t LBH_TAKEN_WEIGHT = 124;
// Loop Branch Heuristics: weight of an edge that leaves the loop.
static constexpr uint32_t LBH_NONTAKEN_WEIGHT = 4;
// Unlikely edges within a loop get half the weight of a taken edge.
static constexpr uint32_t LBH_UNLIKELY_WEIGHT = 62;
90 
91 /// Unreachable-terminating branch taken probability.
92 ///
93 /// This is the probability for a branch being taken to a block that terminates
94 /// (eventually) in unreachable. These are predicted as unlikely as possible.
95 /// All reachable probability will equally share the remaining part.
96 static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);
97 
/// Cold-call heuristic: weight of an edge leading into a cold block.
///
/// A block counts as cold when it is post-dominated by a block containing a
/// call to a function carrying the 'cold' attribute.
static constexpr uint32_t CC_TAKEN_WEIGHT = 4;

/// Cold-call heuristic: weight of an edge that does not lead into a cold
/// block.
static constexpr uint32_t CC_NONTAKEN_WEIGHT = 64;

/// Pointer heuristic: weights of the predicted / non-predicted edge of a
/// pointer equality comparison.
static constexpr uint32_t PH_TAKEN_WEIGHT = 20;
static constexpr uint32_t PH_NONTAKEN_WEIGHT = 12;

/// Zero heuristic: weights of the predicted / non-predicted edge of an integer
/// comparison against a small constant.
static constexpr uint32_t ZH_TAKEN_WEIGHT = 20;
static constexpr uint32_t ZH_NONTAKEN_WEIGHT = 12;

/// Floating-point heuristic: weights of the predicted / non-predicted edge of
/// a floating point comparison.
static constexpr uint32_t FPH_TAKEN_WEIGHT = 20;
static constexpr uint32_t FPH_NONTAKEN_WEIGHT = 12;

/// Invoke heuristic: weight of the edge to the normal destination.
///
/// Reaching the normal destination of an invoke is expected almost always, so
/// the weight is set absurdly high (2^20 - 1) so that nested loops subsume it.
static constexpr uint32_t IH_TAKEN_WEIGHT = (1u << 20) - 1;

/// Invoke heuristic: weight of the edge to the unwind destination.
///
/// The unwind edge is essentially never taken.
static constexpr uint32_t IH_NONTAKEN_WEIGHT = 1;
133 
134 /// Add \p BB to PostDominatedByUnreachable set if applicable.
135 void
136 BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) {
137   const Instruction *TI = BB->getTerminator();
138   if (TI->getNumSuccessors() == 0) {
139     if (isa<UnreachableInst>(TI) ||
140         // If this block is terminated by a call to
141         // @llvm.experimental.deoptimize then treat it like an unreachable since
142         // the @llvm.experimental.deoptimize call is expected to practically
143         // never execute.
144         BB->getTerminatingDeoptimizeCall())
145       PostDominatedByUnreachable.insert(BB);
146     return;
147   }
148 
149   // If the terminator is an InvokeInst, check only the normal destination block
150   // as the unwind edge of InvokeInst is also very unlikely taken.
151   if (auto *II = dyn_cast<InvokeInst>(TI)) {
152     if (PostDominatedByUnreachable.count(II->getNormalDest()))
153       PostDominatedByUnreachable.insert(BB);
154     return;
155   }
156 
157   for (auto *I : successors(BB))
158     // If any of successor is not post dominated then BB is also not.
159     if (!PostDominatedByUnreachable.count(I))
160       return;
161 
162   PostDominatedByUnreachable.insert(BB);
163 }
164 
165 /// Add \p BB to PostDominatedByColdCall set if applicable.
166 void
167 BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) {
168   assert(!PostDominatedByColdCall.count(BB));
169   const Instruction *TI = BB->getTerminator();
170   if (TI->getNumSuccessors() == 0)
171     return;
172 
173   // If all of successor are post dominated then BB is also done.
174   if (llvm::all_of(successors(BB), [&](const BasicBlock *SuccBB) {
175         return PostDominatedByColdCall.count(SuccBB);
176       })) {
177     PostDominatedByColdCall.insert(BB);
178     return;
179   }
180 
181   // If the terminator is an InvokeInst, check only the normal destination
182   // block as the unwind edge of InvokeInst is also very unlikely taken.
183   if (auto *II = dyn_cast<InvokeInst>(TI))
184     if (PostDominatedByColdCall.count(II->getNormalDest())) {
185       PostDominatedByColdCall.insert(BB);
186       return;
187     }
188 
189   // Otherwise, if the block itself contains a cold function, add it to the
190   // set of blocks post-dominated by a cold call.
191   for (auto &I : *BB)
192     if (const CallInst *CI = dyn_cast<CallInst>(&I))
193       if (CI->hasFnAttr(Attribute::Cold)) {
194         PostDominatedByColdCall.insert(BB);
195         return;
196       }
197 }
198 
199 /// Calculate edge weights for successors lead to unreachable.
200 ///
201 /// Predict that a successor which leads necessarily to an
202 /// unreachable-terminated block as extremely unlikely.
203 bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
204   const Instruction *TI = BB->getTerminator();
205   (void) TI;
206   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
207   assert(!isa<InvokeInst>(TI) &&
208          "Invokes should have already been handled by calcInvokeHeuristics");
209 
210   SmallVector<unsigned, 4> UnreachableEdges;
211   SmallVector<unsigned, 4> ReachableEdges;
212 
213   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
214     if (PostDominatedByUnreachable.count(*I))
215       UnreachableEdges.push_back(I.getSuccessorIndex());
216     else
217       ReachableEdges.push_back(I.getSuccessorIndex());
218 
219   // Skip probabilities if all were reachable.
220   if (UnreachableEdges.empty())
221     return false;
222 
223   if (ReachableEdges.empty()) {
224     BranchProbability Prob(1, UnreachableEdges.size());
225     for (unsigned SuccIdx : UnreachableEdges)
226       setEdgeProbability(BB, SuccIdx, Prob);
227     return true;
228   }
229 
230   auto UnreachableProb = UR_TAKEN_PROB;
231   auto ReachableProb =
232       (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
233       ReachableEdges.size();
234 
235   for (unsigned SuccIdx : UnreachableEdges)
236     setEdgeProbability(BB, SuccIdx, UnreachableProb);
237   for (unsigned SuccIdx : ReachableEdges)
238     setEdgeProbability(BB, SuccIdx, ReachableProb);
239 
240   return true;
241 }
242 
243 // Propagate existing explicit probabilities from either profile data or
244 // 'expect' intrinsic processing. Examine metadata against unreachable
245 // heuristic. The probability of the edge coming to unreachable block is
246 // set to min of metadata and unreachable heuristic.
247 bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
248   const Instruction *TI = BB->getTerminator();
249   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
250   if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI)))
251     return false;
252 
253   MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
254   if (!WeightsNode)
255     return false;
256 
257   // Check that the number of successors is manageable.
258   assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
259 
260   // Ensure there are weights for all of the successors. Note that the first
261   // operand to the metadata node is a name, not a weight.
262   if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
263     return false;
264 
265   // Build up the final weights that will be used in a temporary buffer.
266   // Compute the sum of all weights to later decide whether they need to
267   // be scaled to fit in 32 bits.
268   uint64_t WeightSum = 0;
269   SmallVector<uint32_t, 2> Weights;
270   SmallVector<unsigned, 2> UnreachableIdxs;
271   SmallVector<unsigned, 2> ReachableIdxs;
272   Weights.reserve(TI->getNumSuccessors());
273   for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
274     ConstantInt *Weight =
275         mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
276     if (!Weight)
277       return false;
278     assert(Weight->getValue().getActiveBits() <= 32 &&
279            "Too many bits for uint32_t");
280     Weights.push_back(Weight->getZExtValue());
281     WeightSum += Weights.back();
282     if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1)))
283       UnreachableIdxs.push_back(i - 1);
284     else
285       ReachableIdxs.push_back(i - 1);
286   }
287   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
288 
289   // If the sum of weights does not fit in 32 bits, scale every weight down
290   // accordingly.
291   uint64_t ScalingFactor =
292       (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
293 
294   if (ScalingFactor > 1) {
295     WeightSum = 0;
296     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
297       Weights[i] /= ScalingFactor;
298       WeightSum += Weights[i];
299     }
300   }
301   assert(WeightSum <= UINT32_MAX &&
302          "Expected weights to scale down to 32 bits");
303 
304   if (WeightSum == 0 || ReachableIdxs.size() == 0) {
305     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
306       Weights[i] = 1;
307     WeightSum = TI->getNumSuccessors();
308   }
309 
310   // Set the probability.
311   SmallVector<BranchProbability, 2> BP;
312   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
313     BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) });
314 
315   // Examine the metadata against unreachable heuristic.
316   // If the unreachable heuristic is more strong then we use it for this edge.
317   if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) {
318     auto ToDistribute = BranchProbability::getZero();
319     auto UnreachableProb = UR_TAKEN_PROB;
320     for (auto i : UnreachableIdxs)
321       if (UnreachableProb < BP[i]) {
322         ToDistribute += BP[i] - UnreachableProb;
323         BP[i] = UnreachableProb;
324       }
325 
326     // If we modified the probability of some edges then we must distribute
327     // the difference between reachable blocks.
328     if (ToDistribute > BranchProbability::getZero()) {
329       BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
330       for (auto i : ReachableIdxs)
331         BP[i] += PerEdge;
332     }
333   }
334 
335   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
336     setEdgeProbability(BB, i, BP[i]);
337 
338   return true;
339 }
340 
341 /// Calculate edge weights for edges leading to cold blocks.
342 ///
343 /// A cold block is one post-dominated by  a block with a call to a
344 /// cold function.  Those edges are unlikely to be taken, so we give
345 /// them relatively low weight.
346 ///
347 /// Return true if we could compute the weights for cold edges.
348 /// Return false, otherwise.
349 bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
350   const Instruction *TI = BB->getTerminator();
351   (void) TI;
352   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
353   assert(!isa<InvokeInst>(TI) &&
354          "Invokes should have already been handled by calcInvokeHeuristics");
355 
356   // Determine which successors are post-dominated by a cold block.
357   SmallVector<unsigned, 4> ColdEdges;
358   SmallVector<unsigned, 4> NormalEdges;
359   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
360     if (PostDominatedByColdCall.count(*I))
361       ColdEdges.push_back(I.getSuccessorIndex());
362     else
363       NormalEdges.push_back(I.getSuccessorIndex());
364 
365   // Skip probabilities if no cold edges.
366   if (ColdEdges.empty())
367     return false;
368 
369   if (NormalEdges.empty()) {
370     BranchProbability Prob(1, ColdEdges.size());
371     for (unsigned SuccIdx : ColdEdges)
372       setEdgeProbability(BB, SuccIdx, Prob);
373     return true;
374   }
375 
376   auto ColdProb = BranchProbability::getBranchProbability(
377       CC_TAKEN_WEIGHT,
378       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
379   auto NormalProb = BranchProbability::getBranchProbability(
380       CC_NONTAKEN_WEIGHT,
381       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
382 
383   for (unsigned SuccIdx : ColdEdges)
384     setEdgeProbability(BB, SuccIdx, ColdProb);
385   for (unsigned SuccIdx : NormalEdges)
386     setEdgeProbability(BB, SuccIdx, NormalProb);
387 
388   return true;
389 }
390 
391 // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
392 // between two pointer or pointer and NULL will fail.
393 bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
394   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
395   if (!BI || !BI->isConditional())
396     return false;
397 
398   Value *Cond = BI->getCondition();
399   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
400   if (!CI || !CI->isEquality())
401     return false;
402 
403   Value *LHS = CI->getOperand(0);
404 
405   if (!LHS->getType()->isPointerTy())
406     return false;
407 
408   assert(CI->getOperand(1)->getType()->isPointerTy());
409 
410   // p != 0   ->   isProb = true
411   // p == 0   ->   isProb = false
412   // p != q   ->   isProb = true
413   // p == q   ->   isProb = false;
414   unsigned TakenIdx = 0, NonTakenIdx = 1;
415   bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
416   if (!isProb)
417     std::swap(TakenIdx, NonTakenIdx);
418 
419   BranchProbability TakenProb(PH_TAKEN_WEIGHT,
420                               PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
421   setEdgeProbability(BB, TakenIdx, TakenProb);
422   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
423   return true;
424 }
425 
426 static int getSCCNum(const BasicBlock *BB,
427                      const BranchProbabilityInfo::SccInfo &SccI) {
428   auto SccIt = SccI.SccNums.find(BB);
429   if (SccIt == SccI.SccNums.end())
430     return -1;
431   return SccIt->second;
432 }
433 
434 // Consider any block that is an entry point to the SCC as a header.
435 static bool isSCCHeader(const BasicBlock *BB, int SccNum,
436                         BranchProbabilityInfo::SccInfo &SccI) {
437   assert(getSCCNum(BB, SccI) == SccNum);
438 
439   // Lazily compute the set of headers for a given SCC and cache the results
440   // in the SccHeaderMap.
441   if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum))
442     SccI.SccHeaders.resize(SccNum + 1);
443   auto &HeaderMap = SccI.SccHeaders[SccNum];
444   bool Inserted;
445   BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt;
446   std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false));
447   if (Inserted) {
448     bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)),
449                                  [&](const BasicBlock *Pred) {
450                                    return getSCCNum(Pred, SccI) != SccNum;
451                                  });
452     HeaderMapIt->second = IsHeader;
453     return IsHeader;
454   } else
455     return HeaderMapIt->second;
456 }
457 
458 // Compute the unlikely successors to the block BB in the loop L, specifically
459 // those that are unlikely because this is a loop, and add them to the
460 // UnlikelyBlocks set.
461 static void
462 computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
463                           SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
464   // Sometimes in a loop we have a branch whose condition is made false by
465   // taking it. This is typically something like
466   //  int n = 0;
467   //  while (...) {
468   //    if (++n >= MAX) {
469   //      n = 0;
470   //    }
471   //  }
472   // In this sort of situation taking the branch means that at the very least it
473   // won't be taken again in the next iteration of the loop, so we should
474   // consider it less likely than a typical branch.
475   //
476   // We detect this by looking back through the graph of PHI nodes that sets the
477   // value that the condition depends on, and seeing if we can reach a successor
478   // block which can be determined to make the condition false.
479   //
480   // FIXME: We currently consider unlikely blocks to be half as likely as other
481   // blocks, but if we consider the example above the likelyhood is actually
482   // 1/MAX. We could therefore be more precise in how unlikely we consider
483   // blocks to be, but it would require more careful examination of the form
484   // of the comparison expression.
485   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
486   if (!BI || !BI->isConditional())
487     return;
488 
489   // Check if the branch is based on an instruction compared with a constant
490   CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
491   if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
492       !isa<Constant>(CI->getOperand(1)))
493     return;
494 
495   // Either the instruction must be a PHI, or a chain of operations involving
496   // constants that ends in a PHI which we can then collapse into a single value
497   // if the PHI value is known.
498   Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
499   PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
500   Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
501   // Collect the instructions until we hit a PHI
502   SmallVector<BinaryOperator *, 1> InstChain;
503   while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
504          isa<Constant>(CmpLHS->getOperand(1))) {
505     // Stop if the chain extends outside of the loop
506     if (!L->contains(CmpLHS))
507       return;
508     InstChain.push_back(cast<BinaryOperator>(CmpLHS));
509     CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
510     if (CmpLHS)
511       CmpPHI = dyn_cast<PHINode>(CmpLHS);
512   }
513   if (!CmpPHI || !L->contains(CmpPHI))
514     return;
515 
516   // Trace the phi node to find all values that come from successors of BB
517   SmallPtrSet<PHINode*, 8> VisitedInsts;
518   SmallVector<PHINode*, 8> WorkList;
519   WorkList.push_back(CmpPHI);
520   VisitedInsts.insert(CmpPHI);
521   while (!WorkList.empty()) {
522     PHINode *P = WorkList.back();
523     WorkList.pop_back();
524     for (BasicBlock *B : P->blocks()) {
525       // Skip blocks that aren't part of the loop
526       if (!L->contains(B))
527         continue;
528       Value *V = P->getIncomingValueForBlock(B);
529       // If the source is a PHI add it to the work list if we haven't
530       // already visited it.
531       if (PHINode *PN = dyn_cast<PHINode>(V)) {
532         if (VisitedInsts.insert(PN).second)
533           WorkList.push_back(PN);
534         continue;
535       }
536       // If this incoming value is a constant and B is a successor of BB, then
537       // we can constant-evaluate the compare to see if it makes the branch be
538       // taken or not.
539       Constant *CmpLHSConst = dyn_cast<Constant>(V);
540       if (!CmpLHSConst ||
541           std::find(succ_begin(BB), succ_end(BB), B) == succ_end(BB))
542         continue;
543       // First collapse InstChain
544       for (Instruction *I : llvm::reverse(InstChain)) {
545         CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
546                                         cast<Constant>(I->getOperand(1)), true);
547         if (!CmpLHSConst)
548           break;
549       }
550       if (!CmpLHSConst)
551         continue;
552       // Now constant-evaluate the compare
553       Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
554                                                   CmpLHSConst, CmpConst, true);
555       // If the result means we don't branch to the block then that block is
556       // unlikely.
557       if (Result &&
558           ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
559            (Result->isOneValue() && B == BI->getSuccessor(1))))
560         UnlikelyBlocks.insert(B);
561     }
562   }
563 }
564 
565 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
566 // as taken, exiting edges as not-taken.
567 bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
568                                                      const LoopInfo &LI,
569                                                      SccInfo &SccI) {
570   int SccNum;
571   Loop *L = LI.getLoopFor(BB);
572   if (!L) {
573     SccNum = getSCCNum(BB, SccI);
574     if (SccNum < 0)
575       return false;
576   }
577 
578   SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
579   if (L)
580     computeUnlikelySuccessors(BB, L, UnlikelyBlocks);
581 
582   SmallVector<unsigned, 8> BackEdges;
583   SmallVector<unsigned, 8> ExitingEdges;
584   SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
585   SmallVector<unsigned, 8> UnlikelyEdges;
586 
587   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
588     // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch
589     // irreducible loops.
590     if (L) {
591       if (UnlikelyBlocks.count(*I) != 0)
592         UnlikelyEdges.push_back(I.getSuccessorIndex());
593       else if (!L->contains(*I))
594         ExitingEdges.push_back(I.getSuccessorIndex());
595       else if (L->getHeader() == *I)
596         BackEdges.push_back(I.getSuccessorIndex());
597       else
598         InEdges.push_back(I.getSuccessorIndex());
599     } else {
600       if (getSCCNum(*I, SccI) != SccNum)
601         ExitingEdges.push_back(I.getSuccessorIndex());
602       else if (isSCCHeader(*I, SccNum, SccI))
603         BackEdges.push_back(I.getSuccessorIndex());
604       else
605         InEdges.push_back(I.getSuccessorIndex());
606     }
607   }
608 
609   if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty())
610     return false;
611 
612   // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and
613   // normalize them so that they sum up to one.
614   unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
615                    (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
616                    (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) +
617                    (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
618 
619   if (uint32_t numBackEdges = BackEdges.size()) {
620     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
621     auto Prob = TakenProb / numBackEdges;
622     for (unsigned SuccIdx : BackEdges)
623       setEdgeProbability(BB, SuccIdx, Prob);
624   }
625 
626   if (uint32_t numInEdges = InEdges.size()) {
627     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
628     auto Prob = TakenProb / numInEdges;
629     for (unsigned SuccIdx : InEdges)
630       setEdgeProbability(BB, SuccIdx, Prob);
631   }
632 
633   if (uint32_t numExitingEdges = ExitingEdges.size()) {
634     BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT,
635                                                        Denom);
636     auto Prob = NotTakenProb / numExitingEdges;
637     for (unsigned SuccIdx : ExitingEdges)
638       setEdgeProbability(BB, SuccIdx, Prob);
639   }
640 
641   if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) {
642     BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT,
643                                                        Denom);
644     auto Prob = UnlikelyProb / numUnlikelyEdges;
645     for (unsigned SuccIdx : UnlikelyEdges)
646       setEdgeProbability(BB, SuccIdx, Prob);
647   }
648 
649   return true;
650 }
651 
652 bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
653                                                const TargetLibraryInfo *TLI) {
654   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
655   if (!BI || !BI->isConditional())
656     return false;
657 
658   Value *Cond = BI->getCondition();
659   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
660   if (!CI)
661     return false;
662 
663   auto GetConstantInt = [](Value *V) {
664     if (auto *I = dyn_cast<BitCastInst>(V))
665       return dyn_cast<ConstantInt>(I->getOperand(0));
666     return dyn_cast<ConstantInt>(V);
667   };
668 
669   Value *RHS = CI->getOperand(1);
670   ConstantInt *CV = GetConstantInt(RHS);
671   if (!CV)
672     return false;
673 
674   // If the LHS is the result of AND'ing a value with a single bit bitmask,
675   // we don't have information about probabilities.
676   if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
677     if (LHS->getOpcode() == Instruction::And)
678       if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1)))
679         if (AndRHS->getValue().isPowerOf2())
680           return false;
681 
682   // Check if the LHS is the return value of a library function
683   LibFunc Func = NumLibFuncs;
684   if (TLI)
685     if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
686       if (Function *CalledFn = Call->getCalledFunction())
687         TLI->getLibFunc(*CalledFn, Func);
688 
689   bool isProb;
690   if (Func == LibFunc_strcasecmp ||
691       Func == LibFunc_strcmp ||
692       Func == LibFunc_strncasecmp ||
693       Func == LibFunc_strncmp ||
694       Func == LibFunc_memcmp) {
695     // strcmp and similar functions return zero, negative, or positive, if the
696     // first string is equal, less, or greater than the second. We consider it
697     // likely that the strings are not equal, so a comparison with zero is
698     // probably false, but also a comparison with any other number is also
699     // probably false given that what exactly is returned for nonzero values is
700     // not specified. Any kind of comparison other than equality we know
701     // nothing about.
702     switch (CI->getPredicate()) {
703     case CmpInst::ICMP_EQ:
704       isProb = false;
705       break;
706     case CmpInst::ICMP_NE:
707       isProb = true;
708       break;
709     default:
710       return false;
711     }
712   } else if (CV->isZero()) {
713     switch (CI->getPredicate()) {
714     case CmpInst::ICMP_EQ:
715       // X == 0   ->  Unlikely
716       isProb = false;
717       break;
718     case CmpInst::ICMP_NE:
719       // X != 0   ->  Likely
720       isProb = true;
721       break;
722     case CmpInst::ICMP_SLT:
723       // X < 0   ->  Unlikely
724       isProb = false;
725       break;
726     case CmpInst::ICMP_SGT:
727       // X > 0   ->  Likely
728       isProb = true;
729       break;
730     default:
731       return false;
732     }
733   } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
734     // InstCombine canonicalizes X <= 0 into X < 1.
735     // X <= 0   ->  Unlikely
736     isProb = false;
737   } else if (CV->isMinusOne()) {
738     switch (CI->getPredicate()) {
739     case CmpInst::ICMP_EQ:
740       // X == -1  ->  Unlikely
741       isProb = false;
742       break;
743     case CmpInst::ICMP_NE:
744       // X != -1  ->  Likely
745       isProb = true;
746       break;
747     case CmpInst::ICMP_SGT:
748       // InstCombine canonicalizes X >= 0 into X > -1.
749       // X >= 0   ->  Likely
750       isProb = true;
751       break;
752     default:
753       return false;
754     }
755   } else {
756     return false;
757   }
758 
759   unsigned TakenIdx = 0, NonTakenIdx = 1;
760 
761   if (!isProb)
762     std::swap(TakenIdx, NonTakenIdx);
763 
764   BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
765                               ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
766   setEdgeProbability(BB, TakenIdx, TakenProb);
767   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
768   return true;
769 }
770 
771 bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
772   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
773   if (!BI || !BI->isConditional())
774     return false;
775 
776   Value *Cond = BI->getCondition();
777   FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
778   if (!FCmp)
779     return false;
780 
781   bool isProb;
782   if (FCmp->isEquality()) {
783     // f1 == f2 -> Unlikely
784     // f1 != f2 -> Likely
785     isProb = !FCmp->isTrueWhenEqual();
786   } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
787     // !isnan -> Likely
788     isProb = true;
789   } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
790     // isnan -> Unlikely
791     isProb = false;
792   } else {
793     return false;
794   }
795 
796   unsigned TakenIdx = 0, NonTakenIdx = 1;
797 
798   if (!isProb)
799     std::swap(TakenIdx, NonTakenIdx);
800 
801   BranchProbability TakenProb(FPH_TAKEN_WEIGHT,
802                               FPH_TAKEN_WEIGHT + FPH_NONTAKEN_WEIGHT);
803   setEdgeProbability(BB, TakenIdx, TakenProb);
804   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
805   return true;
806 }
807 
808 bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
809   const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
810   if (!II)
811     return false;
812 
813   BranchProbability TakenProb(IH_TAKEN_WEIGHT,
814                               IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
815   setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb);
816   setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl());
817   return true;
818 }
819 
820 void BranchProbabilityInfo::releaseMemory() {
821   Probs.clear();
822 }
823 
824 void BranchProbabilityInfo::print(raw_ostream &OS) const {
825   OS << "---- Branch Probabilities ----\n";
826   // We print the probabilities from the last function the analysis ran over,
827   // or the function it is currently running over.
828   assert(LastF && "Cannot print prior to running over a function");
829   for (const auto &BI : *LastF) {
830     for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
831          ++SI) {
832       printEdgeProbability(OS << "  ", &BI, *SI);
833     }
834   }
835 }
836 
837 bool BranchProbabilityInfo::
838 isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
839   // Hot probability is at least 4/5 = 80%
840   // FIXME: Compare against a static "hot" BranchProbability.
841   return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
842 }
843 
844 const BasicBlock *
845 BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
846   auto MaxProb = BranchProbability::getZero();
847   const BasicBlock *MaxSucc = nullptr;
848 
849   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
850     const BasicBlock *Succ = *I;
851     auto Prob = getEdgeProbability(BB, Succ);
852     if (Prob > MaxProb) {
853       MaxProb = Prob;
854       MaxSucc = Succ;
855     }
856   }
857 
858   // Hot probability is at least 4/5 = 80%
859   if (MaxProb > BranchProbability(4, 5))
860     return MaxSucc;
861 
862   return nullptr;
863 }
864 
865 /// Get the raw edge probability for the edge. If can't find it, return a
866 /// default probability 1/N where N is the number of successors. Here an edge is
867 /// specified using PredBlock and an
868 /// index to the successors.
869 BranchProbability
870 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
871                                           unsigned IndexInSuccessors) const {
872   auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
873 
874   if (I != Probs.end())
875     return I->second;
876 
877   return {1, static_cast<uint32_t>(succ_size(Src))};
878 }
879 
880 BranchProbability
881 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
882                                           succ_const_iterator Dst) const {
883   return getEdgeProbability(Src, Dst.getSuccessorIndex());
884 }
885 
886 /// Get the raw edge probability calculated for the block pair. This returns the
887 /// sum of all raw edge probabilities from Src to Dst.
888 BranchProbability
889 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
890                                           const BasicBlock *Dst) const {
891   auto Prob = BranchProbability::getZero();
892   bool FoundProb = false;
893   for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
894     if (*I == Dst) {
895       auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex()));
896       if (MapI != Probs.end()) {
897         FoundProb = true;
898         Prob += MapI->second;
899       }
900     }
901   uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src));
902   return FoundProb ? Prob : BranchProbability(1, succ_num);
903 }
904 
905 /// Set the edge probability for a given edge specified by PredBlock and an
906 /// index to the successors.
907 void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src,
908                                                unsigned IndexInSuccessors,
909                                                BranchProbability Prob) {
910   Probs[std::make_pair(Src, IndexInSuccessors)] = Prob;
911   Handles.insert(BasicBlockCallbackVH(Src, this));
912   LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> "
913                     << IndexInSuccessors << " successor probability to " << Prob
914                     << "\n");
915 }
916 
917 raw_ostream &
918 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
919                                             const BasicBlock *Src,
920                                             const BasicBlock *Dst) const {
921   const BranchProbability Prob = getEdgeProbability(Src, Dst);
922   OS << "edge " << Src->getName() << " -> " << Dst->getName()
923      << " probability is " << Prob
924      << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
925 
926   return OS;
927 }
928 
929 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
930   for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) {
931     auto Key = I->first;
932     if (Key.first == BB)
933       Probs.erase(Key);
934   }
935 }
936 
937 void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
938                                       const TargetLibraryInfo *TLI) {
939   LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
940                     << " ----\n\n");
941   LastF = &F; // Store the last function we ran on for printing.
942   assert(PostDominatedByUnreachable.empty());
943   assert(PostDominatedByColdCall.empty());
944 
945   // Record SCC numbers of blocks in the CFG to identify irreducible loops.
946   // FIXME: We could only calculate this if the CFG is known to be irreducible
947   // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
948   int SccNum = 0;
949   SccInfo SccI;
950   for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
951        ++It, ++SccNum) {
952     // Ignore single-block SCCs since they either aren't loops or LoopInfo will
953     // catch them.
954     const std::vector<const BasicBlock *> &Scc = *It;
955     if (Scc.size() == 1)
956       continue;
957 
958     LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
959     for (auto *BB : Scc) {
960       LLVM_DEBUG(dbgs() << " " << BB->getName());
961       SccI.SccNums[BB] = SccNum;
962     }
963     LLVM_DEBUG(dbgs() << "\n");
964   }
965 
966   // Walk the basic blocks in post-order so that we can build up state about
967   // the successors of a block iteratively.
968   for (auto BB : post_order(&F.getEntryBlock())) {
969     LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
970                       << "\n");
971     updatePostDominatedByUnreachable(BB);
972     updatePostDominatedByColdCall(BB);
973     // If there is no at least two successors, no sense to set probability.
974     if (BB->getTerminator()->getNumSuccessors() < 2)
975       continue;
976     if (calcMetadataWeights(BB))
977       continue;
978     if (calcInvokeHeuristics(BB))
979       continue;
980     if (calcUnreachableHeuristics(BB))
981       continue;
982     if (calcColdCallHeuristics(BB))
983       continue;
984     if (calcLoopBranchHeuristics(BB, LI, SccI))
985       continue;
986     if (calcPointerHeuristics(BB))
987       continue;
988     if (calcZeroHeuristics(BB, TLI))
989       continue;
990     if (calcFloatingPointHeuristics(BB))
991       continue;
992   }
993 
994   PostDominatedByUnreachable.clear();
995   PostDominatedByColdCall.clear();
996 
997   if (PrintBranchProb &&
998       (PrintBranchProbFuncName.empty() ||
999        F.getName().equals(PrintBranchProbFuncName))) {
1000     print(dbgs());
1001   }
1002 }
1003 
1004 void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
1005     AnalysisUsage &AU) const {
1006   // We require DT so it's available when LI is available. The LI updating code
1007   // asserts that DT is also present so if we don't make sure that we have DT
1008   // here, that assert will trigger.
1009   AU.addRequired<DominatorTreeWrapperPass>();
1010   AU.addRequired<LoopInfoWrapperPass>();
1011   AU.addRequired<TargetLibraryInfoWrapperPass>();
1012   AU.setPreservesAll();
1013 }
1014 
1015 bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
1016   const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1017   const TargetLibraryInfo &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
1018   BPI.calculate(F, LI, &TLI);
1019   return false;
1020 }
1021 
1022 void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
1023 
1024 void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
1025                                              const Module *) const {
1026   BPI.print(OS);
1027 }
1028 
1029 AnalysisKey BranchProbabilityAnalysis::Key;
1030 BranchProbabilityInfo
1031 BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
1032   BranchProbabilityInfo BPI;
1033   BPI.calculate(F, AM.getResult<LoopAnalysis>(F), &AM.getResult<TargetLibraryAnalysis>(F));
1034   return BPI;
1035 }
1036 
1037 PreservedAnalyses
1038 BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
1039   OS << "Printing analysis results of BPI for function "
1040      << "'" << F.getName() << "':"
1041      << "\n";
1042   AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
1043   return PreservedAnalyses::all();
1044 }
1045