xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/BranchProbabilityInfo.cpp (revision b4af4f93c682e445bf159f0d1ec90b636296c946)
1 //===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Loops should be simplified before this analysis.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Analysis/BranchProbabilityInfo.h"
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/ADT/SCCIterator.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/PostDominators.h"
20 #include "llvm/Analysis/TargetLibraryInfo.h"
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/BasicBlock.h"
23 #include "llvm/IR/CFG.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/Dominators.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/InstrTypes.h"
28 #include "llvm/IR/Instruction.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/IR/Metadata.h"
32 #include "llvm/IR/PassManager.h"
33 #include "llvm/IR/Type.h"
34 #include "llvm/IR/Value.h"
35 #include "llvm/InitializePasses.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Support/BranchProbability.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include <cassert>
43 #include <cstdint>
44 #include <iterator>
45 #include <utility>
46 
47 using namespace llvm;
48 
49 #define DEBUG_TYPE "branch-prob"
50 
51 static cl::opt<bool> PrintBranchProb(
52     "print-bpi", cl::init(false), cl::Hidden,
53     cl::desc("Print the branch probability info."));
54 
55 cl::opt<std::string> PrintBranchProbFuncName(
56     "print-bpi-func-name", cl::Hidden,
57     cl::desc("The option to specify the name of the function "
58              "whose branch probability info is printed."));
59 
60 INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
61                       "Branch Probability Analysis", false, true)
62 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
63 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
64 INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
65                     "Branch Probability Analysis", false, true)
66 
67 BranchProbabilityInfoWrapperPass::BranchProbabilityInfoWrapperPass()
68     : FunctionPass(ID) {
69   initializeBranchProbabilityInfoWrapperPassPass(
70       *PassRegistry::getPassRegistry());
71 }
72 
73 char BranchProbabilityInfoWrapperPass::ID = 0;
74 
75 // Weights are for internal use only. They are used by heuristics to help to
76 // estimate edges' probability. Example:
77 //
78 // Using "Loop Branch Heuristics" we predict weights of edges for the
79 // block BB2.
80 //         ...
81 //          |
82 //          V
83 //         BB1<-+
84 //          |   |
85 //          |   | (Weight = 124)
86 //          V   |
87 //         BB2--+
88 //          |
89 //          | (Weight = 4)
90 //          V
91 //         BB3
92 //
93 // Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
94 // Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
95 static const uint32_t LBH_TAKEN_WEIGHT = 124;
96 static const uint32_t LBH_NONTAKEN_WEIGHT = 4;
97 // Unlikely edges within a loop are half as likely as other edges
98 static const uint32_t LBH_UNLIKELY_WEIGHT = 62;
99 
100 /// Unreachable-terminating branch taken probability.
101 ///
102 /// This is the probability for a branch being taken to a block that terminates
103 /// (eventually) in unreachable. These are predicted as unlikely as possible.
104 /// All reachable probability will equally share the remaining part.
105 static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);
106 
107 /// Weight for a branch taken going into a cold block.
108 ///
109 /// This is the weight for a branch taken toward a block marked
110 /// cold.  A block is marked cold if it's postdominated by a
111 /// block containing a call to a cold function.  Cold functions
112 /// are those marked with attribute 'cold'.
113 static const uint32_t CC_TAKEN_WEIGHT = 4;
114 
115 /// Weight for a branch not-taken into a cold block.
116 ///
117 /// This is the weight for a branch not taken toward a block marked
118 /// cold.
119 static const uint32_t CC_NONTAKEN_WEIGHT = 64;
120 
121 static const uint32_t PH_TAKEN_WEIGHT = 20;
122 static const uint32_t PH_NONTAKEN_WEIGHT = 12;
123 
124 static const uint32_t ZH_TAKEN_WEIGHT = 20;
125 static const uint32_t ZH_NONTAKEN_WEIGHT = 12;
126 
127 static const uint32_t FPH_TAKEN_WEIGHT = 20;
128 static const uint32_t FPH_NONTAKEN_WEIGHT = 12;
129 
130 /// This is the probability for an ordered floating point comparison.
131 static const uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1;
132 /// This is the probability for an unordered floating point comparison, it means
133 /// one or two of the operands are NaN. Usually it is used to test for an
134 /// exceptional case, so the result is unlikely.
135 static const uint32_t FPH_UNO_WEIGHT = 1;
136 
137 /// Invoke-terminating normal branch taken weight
138 ///
139 /// This is the weight for branching to the normal destination of an invoke
140 /// instruction. We expect this to happen most of the time. Set the weight to an
141 /// absurdly high value so that nested loops subsume it.
142 static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1;
143 
144 /// Invoke-terminating normal branch not-taken weight.
145 ///
146 /// This is the weight for branching to the unwind destination of an invoke
147 /// instruction. This is essentially never taken.
148 static const uint32_t IH_NONTAKEN_WEIGHT = 1;
149 
150 static void UpdatePDTWorklist(const BasicBlock *BB, PostDominatorTree *PDT,
151                               SmallVectorImpl<const BasicBlock *> &WorkList,
152                               SmallPtrSetImpl<const BasicBlock *> &TargetSet) {
153   SmallVector<BasicBlock *, 8> Descendants;
154   SmallPtrSet<const BasicBlock *, 16> NewItems;
155 
156   PDT->getDescendants(const_cast<BasicBlock *>(BB), Descendants);
157   for (auto *BB : Descendants)
158     if (TargetSet.insert(BB).second)
159       for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI)
160         if (!TargetSet.count(*PI))
161           NewItems.insert(*PI);
162   WorkList.insert(WorkList.end(), NewItems.begin(), NewItems.end());
163 }
164 
165 /// Compute a set of basic blocks that are post-dominated by unreachables.
166 void BranchProbabilityInfo::computePostDominatedByUnreachable(
167     const Function &F, PostDominatorTree *PDT) {
168   SmallVector<const BasicBlock *, 8> WorkList;
169   for (auto &BB : F) {
170     const Instruction *TI = BB.getTerminator();
171     if (TI->getNumSuccessors() == 0) {
172       if (isa<UnreachableInst>(TI) ||
173           // If this block is terminated by a call to
174           // @llvm.experimental.deoptimize then treat it like an unreachable
175           // since the @llvm.experimental.deoptimize call is expected to
176           // practically never execute.
177           BB.getTerminatingDeoptimizeCall())
178         UpdatePDTWorklist(&BB, PDT, WorkList, PostDominatedByUnreachable);
179     }
180   }
181 
182   while (!WorkList.empty()) {
183     const BasicBlock *BB = WorkList.pop_back_val();
184     if (PostDominatedByUnreachable.count(BB))
185       continue;
186     // If the terminator is an InvokeInst, check only the normal destination
187     // block as the unwind edge of InvokeInst is also very unlikely taken.
188     if (auto *II = dyn_cast<InvokeInst>(BB->getTerminator())) {
189       if (PostDominatedByUnreachable.count(II->getNormalDest()))
190         UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByUnreachable);
191     }
192     // If all the successors are unreachable, BB is unreachable as well.
193     else if (!successors(BB).empty() &&
194              llvm::all_of(successors(BB), [this](const BasicBlock *Succ) {
195                return PostDominatedByUnreachable.count(Succ);
196              }))
197       UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByUnreachable);
198   }
199 }
200 
201 /// compute a set of basic blocks that are post-dominated by ColdCalls.
202 void BranchProbabilityInfo::computePostDominatedByColdCall(
203     const Function &F, PostDominatorTree *PDT) {
204   SmallVector<const BasicBlock *, 8> WorkList;
205   for (auto &BB : F)
206     for (auto &I : BB)
207       if (const CallInst *CI = dyn_cast<CallInst>(&I))
208         if (CI->hasFnAttr(Attribute::Cold))
209           UpdatePDTWorklist(&BB, PDT, WorkList, PostDominatedByColdCall);
210 
211   while (!WorkList.empty()) {
212     const BasicBlock *BB = WorkList.pop_back_val();
213 
214     // If the terminator is an InvokeInst, check only the normal destination
215     // block as the unwind edge of InvokeInst is also very unlikely taken.
216     if (auto *II = dyn_cast<InvokeInst>(BB->getTerminator())) {
217       if (PostDominatedByColdCall.count(II->getNormalDest()))
218         UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByColdCall);
219     }
220     // If all of successor are post dominated then BB is also done.
221     else if (!successors(BB).empty() &&
222              llvm::all_of(successors(BB), [this](const BasicBlock *Succ) {
223                return PostDominatedByColdCall.count(Succ);
224              }))
225       UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByColdCall);
226   }
227 }
228 
229 /// Calculate edge weights for successors lead to unreachable.
230 ///
231 /// Predict that a successor which leads necessarily to an
232 /// unreachable-terminated block as extremely unlikely.
233 bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
234   const Instruction *TI = BB->getTerminator();
235   (void) TI;
236   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
237   assert(!isa<InvokeInst>(TI) &&
238          "Invokes should have already been handled by calcInvokeHeuristics");
239 
240   SmallVector<unsigned, 4> UnreachableEdges;
241   SmallVector<unsigned, 4> ReachableEdges;
242 
243   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
244     if (PostDominatedByUnreachable.count(*I))
245       UnreachableEdges.push_back(I.getSuccessorIndex());
246     else
247       ReachableEdges.push_back(I.getSuccessorIndex());
248 
249   // Skip probabilities if all were reachable.
250   if (UnreachableEdges.empty())
251     return false;
252 
253   if (ReachableEdges.empty()) {
254     BranchProbability Prob(1, UnreachableEdges.size());
255     for (unsigned SuccIdx : UnreachableEdges)
256       setEdgeProbability(BB, SuccIdx, Prob);
257     return true;
258   }
259 
260   auto UnreachableProb = UR_TAKEN_PROB;
261   auto ReachableProb =
262       (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
263       ReachableEdges.size();
264 
265   for (unsigned SuccIdx : UnreachableEdges)
266     setEdgeProbability(BB, SuccIdx, UnreachableProb);
267   for (unsigned SuccIdx : ReachableEdges)
268     setEdgeProbability(BB, SuccIdx, ReachableProb);
269 
270   return true;
271 }
272 
273 // Propagate existing explicit probabilities from either profile data or
274 // 'expect' intrinsic processing. Examine metadata against unreachable
275 // heuristic. The probability of the edge coming to unreachable block is
276 // set to min of metadata and unreachable heuristic.
277 bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
278   const Instruction *TI = BB->getTerminator();
279   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
280   if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI)))
281     return false;
282 
283   MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
284   if (!WeightsNode)
285     return false;
286 
287   // Check that the number of successors is manageable.
288   assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
289 
290   // Ensure there are weights for all of the successors. Note that the first
291   // operand to the metadata node is a name, not a weight.
292   if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
293     return false;
294 
295   // Build up the final weights that will be used in a temporary buffer.
296   // Compute the sum of all weights to later decide whether they need to
297   // be scaled to fit in 32 bits.
298   uint64_t WeightSum = 0;
299   SmallVector<uint32_t, 2> Weights;
300   SmallVector<unsigned, 2> UnreachableIdxs;
301   SmallVector<unsigned, 2> ReachableIdxs;
302   Weights.reserve(TI->getNumSuccessors());
303   for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
304     ConstantInt *Weight =
305         mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
306     if (!Weight)
307       return false;
308     assert(Weight->getValue().getActiveBits() <= 32 &&
309            "Too many bits for uint32_t");
310     Weights.push_back(Weight->getZExtValue());
311     WeightSum += Weights.back();
312     if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1)))
313       UnreachableIdxs.push_back(i - 1);
314     else
315       ReachableIdxs.push_back(i - 1);
316   }
317   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
318 
319   // If the sum of weights does not fit in 32 bits, scale every weight down
320   // accordingly.
321   uint64_t ScalingFactor =
322       (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
323 
324   if (ScalingFactor > 1) {
325     WeightSum = 0;
326     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
327       Weights[i] /= ScalingFactor;
328       WeightSum += Weights[i];
329     }
330   }
331   assert(WeightSum <= UINT32_MAX &&
332          "Expected weights to scale down to 32 bits");
333 
334   if (WeightSum == 0 || ReachableIdxs.size() == 0) {
335     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
336       Weights[i] = 1;
337     WeightSum = TI->getNumSuccessors();
338   }
339 
340   // Set the probability.
341   SmallVector<BranchProbability, 2> BP;
342   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
343     BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) });
344 
345   // Examine the metadata against unreachable heuristic.
346   // If the unreachable heuristic is more strong then we use it for this edge.
347   if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) {
348     auto ToDistribute = BranchProbability::getZero();
349     auto UnreachableProb = UR_TAKEN_PROB;
350     for (auto i : UnreachableIdxs)
351       if (UnreachableProb < BP[i]) {
352         ToDistribute += BP[i] - UnreachableProb;
353         BP[i] = UnreachableProb;
354       }
355 
356     // If we modified the probability of some edges then we must distribute
357     // the difference between reachable blocks.
358     if (ToDistribute > BranchProbability::getZero()) {
359       BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
360       for (auto i : ReachableIdxs)
361         BP[i] += PerEdge;
362     }
363   }
364 
365   for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
366     setEdgeProbability(BB, i, BP[i]);
367 
368   return true;
369 }
370 
371 /// Calculate edge weights for edges leading to cold blocks.
372 ///
373 /// A cold block is one post-dominated by  a block with a call to a
374 /// cold function.  Those edges are unlikely to be taken, so we give
375 /// them relatively low weight.
376 ///
377 /// Return true if we could compute the weights for cold edges.
378 /// Return false, otherwise.
379 bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
380   const Instruction *TI = BB->getTerminator();
381   (void) TI;
382   assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
383   assert(!isa<InvokeInst>(TI) &&
384          "Invokes should have already been handled by calcInvokeHeuristics");
385 
386   // Determine which successors are post-dominated by a cold block.
387   SmallVector<unsigned, 4> ColdEdges;
388   SmallVector<unsigned, 4> NormalEdges;
389   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
390     if (PostDominatedByColdCall.count(*I))
391       ColdEdges.push_back(I.getSuccessorIndex());
392     else
393       NormalEdges.push_back(I.getSuccessorIndex());
394 
395   // Skip probabilities if no cold edges.
396   if (ColdEdges.empty())
397     return false;
398 
399   if (NormalEdges.empty()) {
400     BranchProbability Prob(1, ColdEdges.size());
401     for (unsigned SuccIdx : ColdEdges)
402       setEdgeProbability(BB, SuccIdx, Prob);
403     return true;
404   }
405 
406   auto ColdProb = BranchProbability::getBranchProbability(
407       CC_TAKEN_WEIGHT,
408       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
409   auto NormalProb = BranchProbability::getBranchProbability(
410       CC_NONTAKEN_WEIGHT,
411       (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
412 
413   for (unsigned SuccIdx : ColdEdges)
414     setEdgeProbability(BB, SuccIdx, ColdProb);
415   for (unsigned SuccIdx : NormalEdges)
416     setEdgeProbability(BB, SuccIdx, NormalProb);
417 
418   return true;
419 }
420 
421 // Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
422 // between two pointer or pointer and NULL will fail.
423 bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
424   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
425   if (!BI || !BI->isConditional())
426     return false;
427 
428   Value *Cond = BI->getCondition();
429   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
430   if (!CI || !CI->isEquality())
431     return false;
432 
433   Value *LHS = CI->getOperand(0);
434 
435   if (!LHS->getType()->isPointerTy())
436     return false;
437 
438   assert(CI->getOperand(1)->getType()->isPointerTy());
439 
440   // p != 0   ->   isProb = true
441   // p == 0   ->   isProb = false
442   // p != q   ->   isProb = true
443   // p == q   ->   isProb = false;
444   unsigned TakenIdx = 0, NonTakenIdx = 1;
445   bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
446   if (!isProb)
447     std::swap(TakenIdx, NonTakenIdx);
448 
449   BranchProbability TakenProb(PH_TAKEN_WEIGHT,
450                               PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
451   setEdgeProbability(BB, TakenIdx, TakenProb);
452   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
453   return true;
454 }
455 
456 static int getSCCNum(const BasicBlock *BB,
457                      const BranchProbabilityInfo::SccInfo &SccI) {
458   auto SccIt = SccI.SccNums.find(BB);
459   if (SccIt == SccI.SccNums.end())
460     return -1;
461   return SccIt->second;
462 }
463 
464 // Consider any block that is an entry point to the SCC as a header.
465 static bool isSCCHeader(const BasicBlock *BB, int SccNum,
466                         BranchProbabilityInfo::SccInfo &SccI) {
467   assert(getSCCNum(BB, SccI) == SccNum);
468 
469   // Lazily compute the set of headers for a given SCC and cache the results
470   // in the SccHeaderMap.
471   if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum))
472     SccI.SccHeaders.resize(SccNum + 1);
473   auto &HeaderMap = SccI.SccHeaders[SccNum];
474   bool Inserted;
475   BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt;
476   std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false));
477   if (Inserted) {
478     bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)),
479                                  [&](const BasicBlock *Pred) {
480                                    return getSCCNum(Pred, SccI) != SccNum;
481                                  });
482     HeaderMapIt->second = IsHeader;
483     return IsHeader;
484   } else
485     return HeaderMapIt->second;
486 }
487 
488 // Compute the unlikely successors to the block BB in the loop L, specifically
489 // those that are unlikely because this is a loop, and add them to the
490 // UnlikelyBlocks set.
491 static void
492 computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
493                           SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
494   // Sometimes in a loop we have a branch whose condition is made false by
495   // taking it. This is typically something like
496   //  int n = 0;
497   //  while (...) {
498   //    if (++n >= MAX) {
499   //      n = 0;
500   //    }
501   //  }
502   // In this sort of situation taking the branch means that at the very least it
503   // won't be taken again in the next iteration of the loop, so we should
504   // consider it less likely than a typical branch.
505   //
506   // We detect this by looking back through the graph of PHI nodes that sets the
507   // value that the condition depends on, and seeing if we can reach a successor
508   // block which can be determined to make the condition false.
509   //
510   // FIXME: We currently consider unlikely blocks to be half as likely as other
511   // blocks, but if we consider the example above the likelyhood is actually
512   // 1/MAX. We could therefore be more precise in how unlikely we consider
513   // blocks to be, but it would require more careful examination of the form
514   // of the comparison expression.
515   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
516   if (!BI || !BI->isConditional())
517     return;
518 
519   // Check if the branch is based on an instruction compared with a constant
520   CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
521   if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
522       !isa<Constant>(CI->getOperand(1)))
523     return;
524 
525   // Either the instruction must be a PHI, or a chain of operations involving
526   // constants that ends in a PHI which we can then collapse into a single value
527   // if the PHI value is known.
528   Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
529   PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
530   Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
531   // Collect the instructions until we hit a PHI
532   SmallVector<BinaryOperator *, 1> InstChain;
533   while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
534          isa<Constant>(CmpLHS->getOperand(1))) {
535     // Stop if the chain extends outside of the loop
536     if (!L->contains(CmpLHS))
537       return;
538     InstChain.push_back(cast<BinaryOperator>(CmpLHS));
539     CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
540     if (CmpLHS)
541       CmpPHI = dyn_cast<PHINode>(CmpLHS);
542   }
543   if (!CmpPHI || !L->contains(CmpPHI))
544     return;
545 
546   // Trace the phi node to find all values that come from successors of BB
547   SmallPtrSet<PHINode*, 8> VisitedInsts;
548   SmallVector<PHINode*, 8> WorkList;
549   WorkList.push_back(CmpPHI);
550   VisitedInsts.insert(CmpPHI);
551   while (!WorkList.empty()) {
552     PHINode *P = WorkList.back();
553     WorkList.pop_back();
554     for (BasicBlock *B : P->blocks()) {
555       // Skip blocks that aren't part of the loop
556       if (!L->contains(B))
557         continue;
558       Value *V = P->getIncomingValueForBlock(B);
559       // If the source is a PHI add it to the work list if we haven't
560       // already visited it.
561       if (PHINode *PN = dyn_cast<PHINode>(V)) {
562         if (VisitedInsts.insert(PN).second)
563           WorkList.push_back(PN);
564         continue;
565       }
566       // If this incoming value is a constant and B is a successor of BB, then
567       // we can constant-evaluate the compare to see if it makes the branch be
568       // taken or not.
569       Constant *CmpLHSConst = dyn_cast<Constant>(V);
570       if (!CmpLHSConst ||
571           std::find(succ_begin(BB), succ_end(BB), B) == succ_end(BB))
572         continue;
573       // First collapse InstChain
574       for (Instruction *I : llvm::reverse(InstChain)) {
575         CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
576                                         cast<Constant>(I->getOperand(1)), true);
577         if (!CmpLHSConst)
578           break;
579       }
580       if (!CmpLHSConst)
581         continue;
582       // Now constant-evaluate the compare
583       Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
584                                                   CmpLHSConst, CmpConst, true);
585       // If the result means we don't branch to the block then that block is
586       // unlikely.
587       if (Result &&
588           ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
589            (Result->isOneValue() && B == BI->getSuccessor(1))))
590         UnlikelyBlocks.insert(B);
591     }
592   }
593 }
594 
595 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
596 // as taken, exiting edges as not-taken.
597 bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
598                                                      const LoopInfo &LI,
599                                                      SccInfo &SccI) {
600   int SccNum;
601   Loop *L = LI.getLoopFor(BB);
602   if (!L) {
603     SccNum = getSCCNum(BB, SccI);
604     if (SccNum < 0)
605       return false;
606   }
607 
608   SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
609   if (L)
610     computeUnlikelySuccessors(BB, L, UnlikelyBlocks);
611 
612   SmallVector<unsigned, 8> BackEdges;
613   SmallVector<unsigned, 8> ExitingEdges;
614   SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
615   SmallVector<unsigned, 8> UnlikelyEdges;
616 
617   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
618     // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch
619     // irreducible loops.
620     if (L) {
621       if (UnlikelyBlocks.count(*I) != 0)
622         UnlikelyEdges.push_back(I.getSuccessorIndex());
623       else if (!L->contains(*I))
624         ExitingEdges.push_back(I.getSuccessorIndex());
625       else if (L->getHeader() == *I)
626         BackEdges.push_back(I.getSuccessorIndex());
627       else
628         InEdges.push_back(I.getSuccessorIndex());
629     } else {
630       if (getSCCNum(*I, SccI) != SccNum)
631         ExitingEdges.push_back(I.getSuccessorIndex());
632       else if (isSCCHeader(*I, SccNum, SccI))
633         BackEdges.push_back(I.getSuccessorIndex());
634       else
635         InEdges.push_back(I.getSuccessorIndex());
636     }
637   }
638 
639   if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty())
640     return false;
641 
642   // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and
643   // normalize them so that they sum up to one.
644   unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
645                    (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
646                    (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) +
647                    (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
648 
649   if (uint32_t numBackEdges = BackEdges.size()) {
650     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
651     auto Prob = TakenProb / numBackEdges;
652     for (unsigned SuccIdx : BackEdges)
653       setEdgeProbability(BB, SuccIdx, Prob);
654   }
655 
656   if (uint32_t numInEdges = InEdges.size()) {
657     BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
658     auto Prob = TakenProb / numInEdges;
659     for (unsigned SuccIdx : InEdges)
660       setEdgeProbability(BB, SuccIdx, Prob);
661   }
662 
663   if (uint32_t numExitingEdges = ExitingEdges.size()) {
664     BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT,
665                                                        Denom);
666     auto Prob = NotTakenProb / numExitingEdges;
667     for (unsigned SuccIdx : ExitingEdges)
668       setEdgeProbability(BB, SuccIdx, Prob);
669   }
670 
671   if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) {
672     BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT,
673                                                        Denom);
674     auto Prob = UnlikelyProb / numUnlikelyEdges;
675     for (unsigned SuccIdx : UnlikelyEdges)
676       setEdgeProbability(BB, SuccIdx, Prob);
677   }
678 
679   return true;
680 }
681 
682 bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
683                                                const TargetLibraryInfo *TLI) {
684   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
685   if (!BI || !BI->isConditional())
686     return false;
687 
688   Value *Cond = BI->getCondition();
689   ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
690   if (!CI)
691     return false;
692 
693   auto GetConstantInt = [](Value *V) {
694     if (auto *I = dyn_cast<BitCastInst>(V))
695       return dyn_cast<ConstantInt>(I->getOperand(0));
696     return dyn_cast<ConstantInt>(V);
697   };
698 
699   Value *RHS = CI->getOperand(1);
700   ConstantInt *CV = GetConstantInt(RHS);
701   if (!CV)
702     return false;
703 
704   // If the LHS is the result of AND'ing a value with a single bit bitmask,
705   // we don't have information about probabilities.
706   if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
707     if (LHS->getOpcode() == Instruction::And)
708       if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1)))
709         if (AndRHS->getValue().isPowerOf2())
710           return false;
711 
712   // Check if the LHS is the return value of a library function
713   LibFunc Func = NumLibFuncs;
714   if (TLI)
715     if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
716       if (Function *CalledFn = Call->getCalledFunction())
717         TLI->getLibFunc(*CalledFn, Func);
718 
719   bool isProb;
720   if (Func == LibFunc_strcasecmp ||
721       Func == LibFunc_strcmp ||
722       Func == LibFunc_strncasecmp ||
723       Func == LibFunc_strncmp ||
724       Func == LibFunc_memcmp) {
725     // strcmp and similar functions return zero, negative, or positive, if the
726     // first string is equal, less, or greater than the second. We consider it
727     // likely that the strings are not equal, so a comparison with zero is
728     // probably false, but also a comparison with any other number is also
729     // probably false given that what exactly is returned for nonzero values is
730     // not specified. Any kind of comparison other than equality we know
731     // nothing about.
732     switch (CI->getPredicate()) {
733     case CmpInst::ICMP_EQ:
734       isProb = false;
735       break;
736     case CmpInst::ICMP_NE:
737       isProb = true;
738       break;
739     default:
740       return false;
741     }
742   } else if (CV->isZero()) {
743     switch (CI->getPredicate()) {
744     case CmpInst::ICMP_EQ:
745       // X == 0   ->  Unlikely
746       isProb = false;
747       break;
748     case CmpInst::ICMP_NE:
749       // X != 0   ->  Likely
750       isProb = true;
751       break;
752     case CmpInst::ICMP_SLT:
753       // X < 0   ->  Unlikely
754       isProb = false;
755       break;
756     case CmpInst::ICMP_SGT:
757       // X > 0   ->  Likely
758       isProb = true;
759       break;
760     default:
761       return false;
762     }
763   } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
764     // InstCombine canonicalizes X <= 0 into X < 1.
765     // X <= 0   ->  Unlikely
766     isProb = false;
767   } else if (CV->isMinusOne()) {
768     switch (CI->getPredicate()) {
769     case CmpInst::ICMP_EQ:
770       // X == -1  ->  Unlikely
771       isProb = false;
772       break;
773     case CmpInst::ICMP_NE:
774       // X != -1  ->  Likely
775       isProb = true;
776       break;
777     case CmpInst::ICMP_SGT:
778       // InstCombine canonicalizes X >= 0 into X > -1.
779       // X >= 0   ->  Likely
780       isProb = true;
781       break;
782     default:
783       return false;
784     }
785   } else {
786     return false;
787   }
788 
789   unsigned TakenIdx = 0, NonTakenIdx = 1;
790 
791   if (!isProb)
792     std::swap(TakenIdx, NonTakenIdx);
793 
794   BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
795                               ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
796   setEdgeProbability(BB, TakenIdx, TakenProb);
797   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
798   return true;
799 }
800 
801 bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
802   const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
803   if (!BI || !BI->isConditional())
804     return false;
805 
806   Value *Cond = BI->getCondition();
807   FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
808   if (!FCmp)
809     return false;
810 
811   uint32_t TakenWeight = FPH_TAKEN_WEIGHT;
812   uint32_t NontakenWeight = FPH_NONTAKEN_WEIGHT;
813   bool isProb;
814   if (FCmp->isEquality()) {
815     // f1 == f2 -> Unlikely
816     // f1 != f2 -> Likely
817     isProb = !FCmp->isTrueWhenEqual();
818   } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
819     // !isnan -> Likely
820     isProb = true;
821     TakenWeight = FPH_ORD_WEIGHT;
822     NontakenWeight = FPH_UNO_WEIGHT;
823   } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
824     // isnan -> Unlikely
825     isProb = false;
826     TakenWeight = FPH_ORD_WEIGHT;
827     NontakenWeight = FPH_UNO_WEIGHT;
828   } else {
829     return false;
830   }
831 
832   unsigned TakenIdx = 0, NonTakenIdx = 1;
833 
834   if (!isProb)
835     std::swap(TakenIdx, NonTakenIdx);
836 
837   BranchProbability TakenProb(TakenWeight, TakenWeight + NontakenWeight);
838   setEdgeProbability(BB, TakenIdx, TakenProb);
839   setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
840   return true;
841 }
842 
843 bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
844   const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
845   if (!II)
846     return false;
847 
848   BranchProbability TakenProb(IH_TAKEN_WEIGHT,
849                               IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
850   setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb);
851   setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl());
852   return true;
853 }
854 
855 void BranchProbabilityInfo::releaseMemory() {
856   Probs.clear();
857 }
858 
859 void BranchProbabilityInfo::print(raw_ostream &OS) const {
860   OS << "---- Branch Probabilities ----\n";
861   // We print the probabilities from the last function the analysis ran over,
862   // or the function it is currently running over.
863   assert(LastF && "Cannot print prior to running over a function");
864   for (const auto &BI : *LastF) {
865     for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
866          ++SI) {
867       printEdgeProbability(OS << "  ", &BI, *SI);
868     }
869   }
870 }
871 
872 bool BranchProbabilityInfo::
873 isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
874   // Hot probability is at least 4/5 = 80%
875   // FIXME: Compare against a static "hot" BranchProbability.
876   return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
877 }
878 
879 const BasicBlock *
880 BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
881   auto MaxProb = BranchProbability::getZero();
882   const BasicBlock *MaxSucc = nullptr;
883 
884   for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
885     const BasicBlock *Succ = *I;
886     auto Prob = getEdgeProbability(BB, Succ);
887     if (Prob > MaxProb) {
888       MaxProb = Prob;
889       MaxSucc = Succ;
890     }
891   }
892 
893   // Hot probability is at least 4/5 = 80%
894   if (MaxProb > BranchProbability(4, 5))
895     return MaxSucc;
896 
897   return nullptr;
898 }
899 
900 /// Get the raw edge probability for the edge. If can't find it, return a
901 /// default probability 1/N where N is the number of successors. Here an edge is
902 /// specified using PredBlock and an
903 /// index to the successors.
904 BranchProbability
905 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
906                                           unsigned IndexInSuccessors) const {
907   auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
908 
909   if (I != Probs.end())
910     return I->second;
911 
912   return {1, static_cast<uint32_t>(succ_size(Src))};
913 }
914 
915 BranchProbability
916 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
917                                           succ_const_iterator Dst) const {
918   return getEdgeProbability(Src, Dst.getSuccessorIndex());
919 }
920 
921 /// Get the raw edge probability calculated for the block pair. This returns the
922 /// sum of all raw edge probabilities from Src to Dst.
923 BranchProbability
924 BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
925                                           const BasicBlock *Dst) const {
926   auto Prob = BranchProbability::getZero();
927   bool FoundProb = false;
928   for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
929     if (*I == Dst) {
930       auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex()));
931       if (MapI != Probs.end()) {
932         FoundProb = true;
933         Prob += MapI->second;
934       }
935     }
936   uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src));
937   return FoundProb ? Prob : BranchProbability(1, succ_num);
938 }
939 
940 /// Set the edge probability for a given edge specified by PredBlock and an
941 /// index to the successors.
942 void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src,
943                                                unsigned IndexInSuccessors,
944                                                BranchProbability Prob) {
945   Probs[std::make_pair(Src, IndexInSuccessors)] = Prob;
946   Handles.insert(BasicBlockCallbackVH(Src, this));
947   LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> "
948                     << IndexInSuccessors << " successor probability to " << Prob
949                     << "\n");
950 }
951 
952 raw_ostream &
953 BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
954                                             const BasicBlock *Src,
955                                             const BasicBlock *Dst) const {
956   const BranchProbability Prob = getEdgeProbability(Src, Dst);
957   OS << "edge " << Src->getName() << " -> " << Dst->getName()
958      << " probability is " << Prob
959      << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
960 
961   return OS;
962 }
963 
964 void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
965   for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) {
966     auto Key = I->first;
967     if (Key.first == BB)
968       Probs.erase(Key);
969   }
970 }
971 
972 void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
973                                       const TargetLibraryInfo *TLI) {
974   LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
975                     << " ----\n\n");
976   LastF = &F; // Store the last function we ran on for printing.
977   assert(PostDominatedByUnreachable.empty());
978   assert(PostDominatedByColdCall.empty());
979 
980   // Record SCC numbers of blocks in the CFG to identify irreducible loops.
981   // FIXME: We could only calculate this if the CFG is known to be irreducible
982   // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
983   int SccNum = 0;
984   SccInfo SccI;
985   for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
986        ++It, ++SccNum) {
987     // Ignore single-block SCCs since they either aren't loops or LoopInfo will
988     // catch them.
989     const std::vector<const BasicBlock *> &Scc = *It;
990     if (Scc.size() == 1)
991       continue;
992 
993     LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
994     for (auto *BB : Scc) {
995       LLVM_DEBUG(dbgs() << " " << BB->getName());
996       SccI.SccNums[BB] = SccNum;
997     }
998     LLVM_DEBUG(dbgs() << "\n");
999   }
1000 
1001   std::unique_ptr<PostDominatorTree> PDT =
1002       std::make_unique<PostDominatorTree>(const_cast<Function &>(F));
1003   computePostDominatedByUnreachable(F, PDT.get());
1004   computePostDominatedByColdCall(F, PDT.get());
1005 
1006   // Walk the basic blocks in post-order so that we can build up state about
1007   // the successors of a block iteratively.
1008   for (auto BB : post_order(&F.getEntryBlock())) {
1009     LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
1010                       << "\n");
1011     // If there is no at least two successors, no sense to set probability.
1012     if (BB->getTerminator()->getNumSuccessors() < 2)
1013       continue;
1014     if (calcMetadataWeights(BB))
1015       continue;
1016     if (calcInvokeHeuristics(BB))
1017       continue;
1018     if (calcUnreachableHeuristics(BB))
1019       continue;
1020     if (calcColdCallHeuristics(BB))
1021       continue;
1022     if (calcLoopBranchHeuristics(BB, LI, SccI))
1023       continue;
1024     if (calcPointerHeuristics(BB))
1025       continue;
1026     if (calcZeroHeuristics(BB, TLI))
1027       continue;
1028     if (calcFloatingPointHeuristics(BB))
1029       continue;
1030   }
1031 
1032   PostDominatedByUnreachable.clear();
1033   PostDominatedByColdCall.clear();
1034 
1035   if (PrintBranchProb &&
1036       (PrintBranchProbFuncName.empty() ||
1037        F.getName().equals(PrintBranchProbFuncName))) {
1038     print(dbgs());
1039   }
1040 }
1041 
1042 void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
1043     AnalysisUsage &AU) const {
1044   // We require DT so it's available when LI is available. The LI updating code
1045   // asserts that DT is also present so if we don't make sure that we have DT
1046   // here, that assert will trigger.
1047   AU.addRequired<DominatorTreeWrapperPass>();
1048   AU.addRequired<LoopInfoWrapperPass>();
1049   AU.addRequired<TargetLibraryInfoWrapperPass>();
1050   AU.setPreservesAll();
1051 }
1052 
1053 bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
1054   const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1055   const TargetLibraryInfo &TLI =
1056       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1057   BPI.calculate(F, LI, &TLI);
1058   return false;
1059 }
1060 
1061 void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
1062 
1063 void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
1064                                              const Module *) const {
1065   BPI.print(OS);
1066 }
1067 
1068 AnalysisKey BranchProbabilityAnalysis::Key;
1069 BranchProbabilityInfo
1070 BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
1071   BranchProbabilityInfo BPI;
1072   BPI.calculate(F, AM.getResult<LoopAnalysis>(F), &AM.getResult<TargetLibraryAnalysis>(F));
1073   return BPI;
1074 }
1075 
1076 PreservedAnalyses
1077 BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
1078   OS << "Printing analysis results of BPI for function "
1079      << "'" << F.getName() << "':"
1080      << "\n";
1081   AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
1082   return PreservedAnalyses::all();
1083 }
1084