xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision 419822b372f543b22d7fb04eae0dffacf058feb6)
1  //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // Instrumentation-based profile-guided optimization
10  //
11  //===----------------------------------------------------------------------===//
12  
13  #include "CodeGenPGO.h"
14  #include "CodeGenFunction.h"
15  #include "CoverageMappingGen.h"
16  #include "clang/AST/RecursiveASTVisitor.h"
17  #include "clang/AST/StmtVisitor.h"
18  #include "llvm/IR/Intrinsics.h"
19  #include "llvm/IR/MDBuilder.h"
20  #include "llvm/Support/CommandLine.h"
21  #include "llvm/Support/Endian.h"
22  #include "llvm/Support/FileSystem.h"
23  #include "llvm/Support/MD5.h"
24  
25  static llvm::cl::opt<bool>
26      EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27                           llvm::cl::desc("Enable value profiling"),
28                           llvm::cl::Hidden, llvm::cl::init(false));
29  
30  using namespace clang;
31  using namespace CodeGen;
32  
33  void CodeGenPGO::setFuncName(StringRef Name,
34                               llvm::GlobalValue::LinkageTypes Linkage) {
35    llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36    FuncName = llvm::getPGOFuncName(
37        Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38        PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39  
40    // If we're generating a profile, create a variable for the name.
41    if (CGM.getCodeGenOpts().hasProfileClangInstr())
42      FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43  }
44  
45  void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46    setFuncName(Fn->getName(), Fn->getLinkage());
47    // Create PGOFuncName meta data.
48    llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49  }
50  
/// Revisions of the PGO control-flow hash algorithm.
///
/// Older revisions must remain computable so that profiles produced with
/// them can still be matched. The enumerators are ordered, so later
/// revisions compare greater than earlier ones.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Always alias the newest revision.
  PGO_HASH_LATEST = PGO_HASH_V3
};
60  
61  namespace {
62  /// Stable hasher for PGO region counters.
63  ///
64  /// PGOHash produces a stable hash of a given function's control flow.
65  ///
66  /// Changing the output of this hash will invalidate all previously generated
67  /// profiles -- i.e., don't do it.
68  ///
69  /// \note  When this hash does eventually change (years?), we still need to
70  /// support old hashes.  We'll need to pull in the version number from the
71  /// profile data format and use the matching hash function.
72  class PGOHash {
73    uint64_t Working;
74    unsigned Count;
75    PGOHashVersion HashVersion;
76    llvm::MD5 MD5;
77  
78    static const int NumBitsPerType = 6;
79    static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
80    static const unsigned TooBig = 1u << NumBitsPerType;
81  
82  public:
83    /// Hash values for AST nodes.
84    ///
85    /// Distinct values for AST nodes that have region counters attached.
86    ///
87    /// These values must be stable.  All new members must be added at the end,
88    /// and no members should be removed.  Changing the enumeration value for an
89    /// AST node will affect the hash of every function that contains that node.
90    enum HashType : unsigned char {
91      None = 0,
92      LabelStmt = 1,
93      WhileStmt,
94      DoStmt,
95      ForStmt,
96      CXXForRangeStmt,
97      ObjCForCollectionStmt,
98      SwitchStmt,
99      CaseStmt,
100      DefaultStmt,
101      IfStmt,
102      CXXTryStmt,
103      CXXCatchStmt,
104      ConditionalOperator,
105      BinaryOperatorLAnd,
106      BinaryOperatorLOr,
107      BinaryConditionalOperator,
108      // The preceding values are available with PGO_HASH_V1.
109  
110      EndOfScope,
111      IfThenBranch,
112      IfElseBranch,
113      GotoStmt,
114      IndirectGotoStmt,
115      BreakStmt,
116      ContinueStmt,
117      ReturnStmt,
118      ThrowExpr,
119      UnaryOperatorLNot,
120      BinaryOperatorLT,
121      BinaryOperatorGT,
122      BinaryOperatorLE,
123      BinaryOperatorGE,
124      BinaryOperatorEQ,
125      BinaryOperatorNE,
126      // The preceding values are available since PGO_HASH_V2.
127  
128      // Keep this last.  It's for the static assert that follows.
129      LastHashType
130    };
131    static_assert(LastHashType <= TooBig, "Too many types in HashType");
132  
133    PGOHash(PGOHashVersion HashVersion)
134        : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
135    void combine(HashType Type);
136    uint64_t finalize();
137    PGOHashVersion getHashVersion() const { return HashVersion; }
138  };
139  const int PGOHash::NumBitsPerType;
140  const unsigned PGOHash::NumTypesPerWord;
141  const unsigned PGOHash::TooBig;
142  
143  /// Get the PGO hash version used in the given indexed profile.
144  static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
145                                          CodeGenModule &CGM) {
146    if (PGOReader->getVersion() <= 4)
147      return PGO_HASH_V1;
148    if (PGOReader->getVersion() <= 5)
149      return PGO_HASH_V2;
150    return PGO_HASH_V3;
151  }
152  
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// While walking a function-like declaration it does two things:
/// 1. It numbers, via NextCounter, each statement that receives a region
///    counter and records that number in CounterMap.
/// 2. It folds the kinds of the visited nodes into Hash.
///
/// PGOHash::combine() is order-sensitive, so changing the order in which
/// nodes are visited or combined here changes the function hash. That would
/// invalidate existing profiles.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  /// The profile version. Within this visitor it is only consulted by
  /// VisitBinaryOperator, to decide whether logical-operator RHSs get their
  /// own counter.
  uint64_t ProfileVersion;

  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        ProfileVersion(ProfileVersion) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  // Captured statements are skipped in the parent context as well; a
  // CapturedDecl is mapped on its own by mapRegionCounters.
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Give the body of each function-like declaration its own counter.
  /// ComputeRegionCounts reads that counter back as the count on entry to
  /// the body. Other declaration kinds are ignored.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  ///
  /// Whether a statement is counted is decided by the V1 classification
  /// regardless of the active hash version. This presumably keeps the set of
  /// counters independent of the hash revision; confirm before changing it.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only valid for ">= v7" of the profile
  /// version so that we facilitate backward compatibility.
  /// The RHS only gets a counter when CodeGenFunction::isInstrumentedCondition
  /// accepts it.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
      if (S->isLogicalOp() &&
          CodeGenFunction::isInstrumentedCondition(S->getRHS()))
        CounterMap[S->getRHS()] = NextCounter++;
    return Base::VisitBinaryOperator(S);
  }

  /// Include \p S in the function hash.
  ///
  /// Counter assignment always uses the V1 type (see updateCounterMappings).
  /// For V2 and later, the type folded into the hash is recomputed with the
  /// active version, so V2-only node kinds contribute to the hash without
  /// getting a counter.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  /// Traverse an if statement. For V2 and later, the hash also records
  /// whether each child is the then branch or the else branch, followed by an
  /// EndOfScope marker.
  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
// The generated traversal ignores the result of the base traversal and
// always returns true; none of this visitor's callbacks return false.
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// The first switch covers the kinds known to every version. The
  /// comparison operators inside it are only recognized for V2 and later.
  /// The second switch covers kinds added in V2. Returns PGOHash::None for
  /// statements that do not contribute.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Node kinds that only contribute starting with PGO_HASH_V2.
    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
359  
360  /// A StmtVisitor that propagates the raw counts through the AST and
361  /// records the count at statements where the value may change.
362  struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
363    /// PGO state.
364    CodeGenPGO &PGO;
365  
366    /// A flag that is set when the current count should be recorded on the
367    /// next statement, such as at the exit of a loop.
368    bool RecordNextStmtCount;
369  
370    /// The count at the current location in the traversal.
371    uint64_t CurrentCount;
372  
373    /// The map of statements to count values.
374    llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
375  
376    /// BreakContinueStack - Keep counts of breaks and continues inside loops.
377    struct BreakContinue {
378      uint64_t BreakCount;
379      uint64_t ContinueCount;
380      BreakContinue() : BreakCount(0), ContinueCount(0) {}
381    };
382    SmallVector<BreakContinue, 8> BreakContinueStack;
383  
384    ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
385                        CodeGenPGO &PGO)
386        : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
387  
388    void RecordStmtCount(const Stmt *S) {
389      if (RecordNextStmtCount) {
390        CountMap[S] = CurrentCount;
391        RecordNextStmtCount = false;
392      }
393    }
394  
395    /// Set and return the current count.
396    uint64_t setCount(uint64_t Count) {
397      CurrentCount = Count;
398      return Count;
399    }
400  
401    void VisitStmt(const Stmt *S) {
402      RecordStmtCount(S);
403      for (const Stmt *Child : S->children())
404        if (Child)
405          this->Visit(Child);
406    }
407  
408    void VisitFunctionDecl(const FunctionDecl *D) {
409      // Counter tracks entry to the function body.
410      uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
411      CountMap[D->getBody()] = BodyCount;
412      Visit(D->getBody());
413    }
414  
415    // Skip lambda expressions. We visit these as FunctionDecls when we're
416    // generating them and aren't interested in the body when generating a
417    // parent context.
418    void VisitLambdaExpr(const LambdaExpr *LE) {}
419  
420    void VisitCapturedDecl(const CapturedDecl *D) {
421      // Counter tracks entry to the capture body.
422      uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
423      CountMap[D->getBody()] = BodyCount;
424      Visit(D->getBody());
425    }
426  
427    void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
428      // Counter tracks entry to the method body.
429      uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
430      CountMap[D->getBody()] = BodyCount;
431      Visit(D->getBody());
432    }
433  
434    void VisitBlockDecl(const BlockDecl *D) {
435      // Counter tracks entry to the block body.
436      uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
437      CountMap[D->getBody()] = BodyCount;
438      Visit(D->getBody());
439    }
440  
441    void VisitReturnStmt(const ReturnStmt *S) {
442      RecordStmtCount(S);
443      if (S->getRetValue())
444        Visit(S->getRetValue());
445      CurrentCount = 0;
446      RecordNextStmtCount = true;
447    }
448  
449    void VisitCXXThrowExpr(const CXXThrowExpr *E) {
450      RecordStmtCount(E);
451      if (E->getSubExpr())
452        Visit(E->getSubExpr());
453      CurrentCount = 0;
454      RecordNextStmtCount = true;
455    }
456  
457    void VisitGotoStmt(const GotoStmt *S) {
458      RecordStmtCount(S);
459      CurrentCount = 0;
460      RecordNextStmtCount = true;
461    }
462  
463    void VisitLabelStmt(const LabelStmt *S) {
464      RecordNextStmtCount = false;
465      // Counter tracks the block following the label.
466      uint64_t BlockCount = setCount(PGO.getRegionCount(S));
467      CountMap[S] = BlockCount;
468      Visit(S->getSubStmt());
469    }
470  
471    void VisitBreakStmt(const BreakStmt *S) {
472      RecordStmtCount(S);
473      assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
474      BreakContinueStack.back().BreakCount += CurrentCount;
475      CurrentCount = 0;
476      RecordNextStmtCount = true;
477    }
478  
479    void VisitContinueStmt(const ContinueStmt *S) {
480      RecordStmtCount(S);
481      assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
482      BreakContinueStack.back().ContinueCount += CurrentCount;
483      CurrentCount = 0;
484      RecordNextStmtCount = true;
485    }
486  
487    void VisitWhileStmt(const WhileStmt *S) {
488      RecordStmtCount(S);
489      uint64_t ParentCount = CurrentCount;
490  
491      BreakContinueStack.push_back(BreakContinue());
492      // Visit the body region first so the break/continue adjustments can be
493      // included when visiting the condition.
494      uint64_t BodyCount = setCount(PGO.getRegionCount(S));
495      CountMap[S->getBody()] = CurrentCount;
496      Visit(S->getBody());
497      uint64_t BackedgeCount = CurrentCount;
498  
499      // ...then go back and propagate counts through the condition. The count
500      // at the start of the condition is the sum of the incoming edges,
501      // the backedge from the end of the loop body, and the edges from
502      // continue statements.
503      BreakContinue BC = BreakContinueStack.pop_back_val();
504      uint64_t CondCount =
505          setCount(ParentCount + BackedgeCount + BC.ContinueCount);
506      CountMap[S->getCond()] = CondCount;
507      Visit(S->getCond());
508      setCount(BC.BreakCount + CondCount - BodyCount);
509      RecordNextStmtCount = true;
510    }
511  
512    void VisitDoStmt(const DoStmt *S) {
513      RecordStmtCount(S);
514      uint64_t LoopCount = PGO.getRegionCount(S);
515  
516      BreakContinueStack.push_back(BreakContinue());
517      // The count doesn't include the fallthrough from the parent scope. Add it.
518      uint64_t BodyCount = setCount(LoopCount + CurrentCount);
519      CountMap[S->getBody()] = BodyCount;
520      Visit(S->getBody());
521      uint64_t BackedgeCount = CurrentCount;
522  
523      BreakContinue BC = BreakContinueStack.pop_back_val();
524      // The count at the start of the condition is equal to the count at the
525      // end of the body, plus any continues.
526      uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
527      CountMap[S->getCond()] = CondCount;
528      Visit(S->getCond());
529      setCount(BC.BreakCount + CondCount - LoopCount);
530      RecordNextStmtCount = true;
531    }
532  
533    void VisitForStmt(const ForStmt *S) {
534      RecordStmtCount(S);
535      if (S->getInit())
536        Visit(S->getInit());
537  
538      uint64_t ParentCount = CurrentCount;
539  
540      BreakContinueStack.push_back(BreakContinue());
541      // Visit the body region first. (This is basically the same as a while
542      // loop; see further comments in VisitWhileStmt.)
543      uint64_t BodyCount = setCount(PGO.getRegionCount(S));
544      CountMap[S->getBody()] = BodyCount;
545      Visit(S->getBody());
546      uint64_t BackedgeCount = CurrentCount;
547      BreakContinue BC = BreakContinueStack.pop_back_val();
548  
549      // The increment is essentially part of the body but it needs to include
550      // the count for all the continue statements.
551      if (S->getInc()) {
552        uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
553        CountMap[S->getInc()] = IncCount;
554        Visit(S->getInc());
555      }
556  
557      // ...then go back and propagate counts through the condition.
558      uint64_t CondCount =
559          setCount(ParentCount + BackedgeCount + BC.ContinueCount);
560      if (S->getCond()) {
561        CountMap[S->getCond()] = CondCount;
562        Visit(S->getCond());
563      }
564      setCount(BC.BreakCount + CondCount - BodyCount);
565      RecordNextStmtCount = true;
566    }
567  
568    void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
569      RecordStmtCount(S);
570      if (S->getInit())
571        Visit(S->getInit());
572      Visit(S->getLoopVarStmt());
573      Visit(S->getRangeStmt());
574      Visit(S->getBeginStmt());
575      Visit(S->getEndStmt());
576  
577      uint64_t ParentCount = CurrentCount;
578      BreakContinueStack.push_back(BreakContinue());
579      // Visit the body region first. (This is basically the same as a while
580      // loop; see further comments in VisitWhileStmt.)
581      uint64_t BodyCount = setCount(PGO.getRegionCount(S));
582      CountMap[S->getBody()] = BodyCount;
583      Visit(S->getBody());
584      uint64_t BackedgeCount = CurrentCount;
585      BreakContinue BC = BreakContinueStack.pop_back_val();
586  
587      // The increment is essentially part of the body but it needs to include
588      // the count for all the continue statements.
589      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
590      CountMap[S->getInc()] = IncCount;
591      Visit(S->getInc());
592  
593      // ...then go back and propagate counts through the condition.
594      uint64_t CondCount =
595          setCount(ParentCount + BackedgeCount + BC.ContinueCount);
596      CountMap[S->getCond()] = CondCount;
597      Visit(S->getCond());
598      setCount(BC.BreakCount + CondCount - BodyCount);
599      RecordNextStmtCount = true;
600    }
601  
602    void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
603      RecordStmtCount(S);
604      Visit(S->getElement());
605      uint64_t ParentCount = CurrentCount;
606      BreakContinueStack.push_back(BreakContinue());
607      // Counter tracks the body of the loop.
608      uint64_t BodyCount = setCount(PGO.getRegionCount(S));
609      CountMap[S->getBody()] = BodyCount;
610      Visit(S->getBody());
611      uint64_t BackedgeCount = CurrentCount;
612      BreakContinue BC = BreakContinueStack.pop_back_val();
613  
614      setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
615               BodyCount);
616      RecordNextStmtCount = true;
617    }
618  
619    void VisitSwitchStmt(const SwitchStmt *S) {
620      RecordStmtCount(S);
621      if (S->getInit())
622        Visit(S->getInit());
623      Visit(S->getCond());
624      CurrentCount = 0;
625      BreakContinueStack.push_back(BreakContinue());
626      Visit(S->getBody());
627      // If the switch is inside a loop, add the continue counts.
628      BreakContinue BC = BreakContinueStack.pop_back_val();
629      if (!BreakContinueStack.empty())
630        BreakContinueStack.back().ContinueCount += BC.ContinueCount;
631      // Counter tracks the exit block of the switch.
632      setCount(PGO.getRegionCount(S));
633      RecordNextStmtCount = true;
634    }
635  
636    void VisitSwitchCase(const SwitchCase *S) {
637      RecordNextStmtCount = false;
638      // Counter for this particular case. This counts only jumps from the
639      // switch header and does not include fallthrough from the case before
640      // this one.
641      uint64_t CaseCount = PGO.getRegionCount(S);
642      setCount(CurrentCount + CaseCount);
643      // We need the count without fallthrough in the mapping, so it's more useful
644      // for branch probabilities.
645      CountMap[S] = CaseCount;
646      RecordNextStmtCount = true;
647      Visit(S->getSubStmt());
648    }
649  
650    void VisitIfStmt(const IfStmt *S) {
651      RecordStmtCount(S);
652      uint64_t ParentCount = CurrentCount;
653      if (S->getInit())
654        Visit(S->getInit());
655      Visit(S->getCond());
656  
657      // Counter tracks the "then" part of an if statement. The count for
658      // the "else" part, if it exists, will be calculated from this counter.
659      uint64_t ThenCount = setCount(PGO.getRegionCount(S));
660      CountMap[S->getThen()] = ThenCount;
661      Visit(S->getThen());
662      uint64_t OutCount = CurrentCount;
663  
664      uint64_t ElseCount = ParentCount - ThenCount;
665      if (S->getElse()) {
666        setCount(ElseCount);
667        CountMap[S->getElse()] = ElseCount;
668        Visit(S->getElse());
669        OutCount += CurrentCount;
670      } else
671        OutCount += ElseCount;
672      setCount(OutCount);
673      RecordNextStmtCount = true;
674    }
675  
676    void VisitCXXTryStmt(const CXXTryStmt *S) {
677      RecordStmtCount(S);
678      Visit(S->getTryBlock());
679      for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
680        Visit(S->getHandler(I));
681      // Counter tracks the continuation block of the try statement.
682      setCount(PGO.getRegionCount(S));
683      RecordNextStmtCount = true;
684    }
685  
686    void VisitCXXCatchStmt(const CXXCatchStmt *S) {
687      RecordNextStmtCount = false;
688      // Counter tracks the catch statement's handler block.
689      uint64_t CatchCount = setCount(PGO.getRegionCount(S));
690      CountMap[S] = CatchCount;
691      Visit(S->getHandlerBlock());
692    }
693  
694    void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
695      RecordStmtCount(E);
696      uint64_t ParentCount = CurrentCount;
697      Visit(E->getCond());
698  
699      // Counter tracks the "true" part of a conditional operator. The
700      // count in the "false" part will be calculated from this counter.
701      uint64_t TrueCount = setCount(PGO.getRegionCount(E));
702      CountMap[E->getTrueExpr()] = TrueCount;
703      Visit(E->getTrueExpr());
704      uint64_t OutCount = CurrentCount;
705  
706      uint64_t FalseCount = setCount(ParentCount - TrueCount);
707      CountMap[E->getFalseExpr()] = FalseCount;
708      Visit(E->getFalseExpr());
709      OutCount += CurrentCount;
710  
711      setCount(OutCount);
712      RecordNextStmtCount = true;
713    }
714  
715    void VisitBinLAnd(const BinaryOperator *E) {
716      RecordStmtCount(E);
717      uint64_t ParentCount = CurrentCount;
718      Visit(E->getLHS());
719      // Counter tracks the right hand side of a logical and operator.
720      uint64_t RHSCount = setCount(PGO.getRegionCount(E));
721      CountMap[E->getRHS()] = RHSCount;
722      Visit(E->getRHS());
723      setCount(ParentCount + RHSCount - CurrentCount);
724      RecordNextStmtCount = true;
725    }
726  
727    void VisitBinLOr(const BinaryOperator *E) {
728      RecordStmtCount(E);
729      uint64_t ParentCount = CurrentCount;
730      Visit(E->getLHS());
731      // Counter tracks the right hand side of a logical or operator.
732      uint64_t RHSCount = setCount(PGO.getRegionCount(E));
733      CountMap[E->getRHS()] = RHSCount;
734      Visit(E->getRHS());
735      setCount(ParentCount + RHSCount - CurrentCount);
736      RecordNextStmtCount = true;
737    }
738  };
739  } // end anonymous namespace
740  
741  void PGOHash::combine(HashType Type) {
742    // Check that we never combine 0 and only have six bits.
743    assert(Type && "Hash is invalid: unexpected type 0");
744    assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
745  
746    // Pass through MD5 if enough work has built up.
747    if (Count && Count % NumTypesPerWord == 0) {
748      using namespace llvm::support;
749      uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
750      MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
751      Working = 0;
752    }
753  
754    // Accumulate the current type.
755    ++Count;
756    Working = Working << NumBitsPerType | Type;
757  }
758  
759  uint64_t PGOHash::finalize() {
760    // Use Working as the hash directly if we never used MD5.
761    if (Count <= NumTypesPerWord)
762      // No need to byte swap here, since none of the math was endian-dependent.
763      // This number will be byte-swapped as required on endianness transitions,
764      // so we will see the same value on the other side.
765      return Working;
766  
767    // Check for remaining work in Working.
768    if (Working) {
769      // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
770      // is buggy because it converts a uint64_t into an array of uint8_t.
771      if (HashVersion < PGO_HASH_V3) {
772        MD5.update({(uint8_t)Working});
773      } else {
774        using namespace llvm::support;
775        uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
776        MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
777      }
778    }
779  
780    // Finalize the MD5 and return the hash.
781    llvm::MD5::MD5Result Result;
782    MD5.final(Result);
783    return Result.low();
784  }
785  
/// Prepare PGO state (counters, coverage mapping, loaded counts) for \p Fn,
/// the IR function emitted for \p GD.
///
/// Returns without doing anything when any of these hold:
/// - the declaration has no body;
/// - it is a CUDA/HIP __global__ function in a host (non-device) compilation,
///   i.e. a kernel launch stub;
/// - clang instrumentation is disabled and no PGO reader is present;
/// - the declaration is implicit;
/// - it is a non-base constructor variant whose delegation to the base
///   variant is valid (checked only on ABIs that have constructor variants);
/// - it is a non-base destructor variant.
///
/// Otherwise the function calls CGM.ClearUnusedCoverageMapping(D). Then,
/// unless \p Fn has the NoProfile attribute, it sets the function name, maps
/// the region counters, and emits the coverage mapping if that is enabled.
/// When a PGO reader exists, it also loads the profile counts, computes
/// per-statement counts, and applies the function attributes to \p Fn.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  // Skip CUDA/HIP kernel launch stub functions.
  if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
      D->hasAttr<CUDAGlobalAttr>())
    return;

  // Nothing to do unless we are either emitting instrumentation or reading
  // an existing profile.
  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  // This runs before the NoProfile check below, so it happens even for
  // NoProfile functions. NOTE(review): presumably it removes D from the set
  // of functions that would otherwise get an empty "unused" coverage
  // record; confirm in CodeGenModule.
  CGM.ClearUnusedCoverageMapping(D);
  if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
    return;

  // Presumably derives FuncName (and FuncNameVar when instrumenting) from
  // Fn. The llvm::Function overload is defined outside this view.
  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    // IsInMainFile is only used for the PGO statistics counters that
    // loadRegionCounts updates.
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}
830  
831  void CodeGenPGO::mapRegionCounters(const Decl *D) {
832    // Use the latest hash version when inserting instrumentation, but use the
833    // version in the indexed profile if we're reading PGO data.
834    PGOHashVersion HashVersion = PGO_HASH_LATEST;
835    uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
836    if (auto *PGOReader = CGM.getPGOReader()) {
837      HashVersion = getPGOHashVersion(PGOReader, CGM);
838      ProfileVersion = PGOReader->getVersion();
839    }
840  
841    RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
842    MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
843    if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
844      Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
845    else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
846      Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
847    else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
848      Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
849    else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
850      Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
851    assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
852    NumRegionCounters = Walker.NextCounter;
853    FunctionHash = Walker.Hash.finalize();
854  }
855  
856  bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
857    if (!D->getBody())
858      return true;
859  
860    // Skip host-only functions in the CUDA device compilation and device-only
861    // functions in the host compilation. Just roughly filter them out based on
862    // the function attributes. If there are effectively host-only or device-only
863    // ones, their coverage mapping may still be generated.
864    if (CGM.getLangOpts().CUDA &&
865        ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
866          !D->hasAttr<CUDAGlobalAttr>()) ||
867         (!CGM.getLangOpts().CUDAIsDevice &&
868          (D->hasAttr<CUDAGlobalAttr>() ||
869           (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
870      return true;
871  
872    // Don't map the functions in system headers.
873    const auto &SM = CGM.getContext().getSourceManager();
874    auto Loc = D->getBody()->getBeginLoc();
875    return SM.isInSystemHeader(Loc);
876  }
877  
878  void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
879    if (skipRegionMappingForDecl(D))
880      return;
881  
882    std::string CoverageMapping;
883    llvm::raw_string_ostream OS(CoverageMapping);
884    CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
885                                  CGM.getContext().getSourceManager(),
886                                  CGM.getLangOpts(), RegionCounterMap.get());
887    MappingGen.emitCounterMapping(D, OS);
888    OS.flush();
889  
890    if (CoverageMapping.empty())
891      return;
892  
893    CGM.getCoverageMapping()->addFunctionMappingRecord(
894        FuncNameVar, FuncName, FunctionHash, CoverageMapping);
895  }
896  
897  void
898  CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
899                                      llvm::GlobalValue::LinkageTypes Linkage) {
900    if (skipRegionMappingForDecl(D))
901      return;
902  
903    std::string CoverageMapping;
904    llvm::raw_string_ostream OS(CoverageMapping);
905    CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
906                                  CGM.getContext().getSourceManager(),
907                                  CGM.getLangOpts());
908    MappingGen.emitEmptyMapping(D, OS);
909    OS.flush();
910  
911    if (CoverageMapping.empty())
912      return;
913  
914    setFuncName(Name, Linkage);
915    CGM.getCoverageMapping()->addFunctionMappingRecord(
916        FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
917  }
918  
919  void CodeGenPGO::computeRegionCounts(const Decl *D) {
920    StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
921    ComputeRegionCounts Walker(*StmtCountMap, *this);
922    if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
923      Walker.VisitFunctionDecl(FD);
924    else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
925      Walker.VisitObjCMethodDecl(MD);
926    else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
927      Walker.VisitBlockDecl(BD);
928    else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
929      Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
930  }
931  
932  void
933  CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
934                                      llvm::Function *Fn) {
935    if (!haveRegionCounts())
936      return;
937  
938    uint64_t FunctionCount = getRegionCount(nullptr);
939    Fn->setEntryCount(FunctionCount);
940  }
941  
942  void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
943                                        llvm::Value *StepV) {
944    if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
945      return;
946    if (!Builder.GetInsertBlock())
947      return;
948  
949    unsigned Counter = (*RegionCounterMap)[S];
950    auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
951  
952    llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
953                           Builder.getInt64(FunctionHash),
954                           Builder.getInt32(NumRegionCounters),
955                           Builder.getInt32(Counter), StepV};
956    if (!StepV)
957      Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
958                         makeArrayRef(Args, 4));
959    else
960      Builder.CreateCall(
961          CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
962          makeArrayRef(Args));
963  }
964  
965  void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
966    if (CGM.getCodeGenOpts().hasProfileClangInstr())
967      M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
968                      uint32_t(EnableValueProfiling));
969  }
970  
971  // This method either inserts a call to the profile run-time during
972  // instrumentation or puts profile data into metadata for PGO use.
973  void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
974      llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
975  
976    if (!EnableValueProfiling)
977      return;
978  
979    if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
980      return;
981  
982    if (isa<llvm::Constant>(ValuePtr))
983      return;
984  
985    bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
986    if (InstrumentValueSites && RegionCounterMap) {
987      auto BuilderInsertPoint = Builder.saveIP();
988      Builder.SetInsertPoint(ValueSite);
989      llvm::Value *Args[5] = {
990          llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
991          Builder.getInt64(FunctionHash),
992          Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
993          Builder.getInt32(ValueKind),
994          Builder.getInt32(NumValueSites[ValueKind]++)
995      };
996      Builder.CreateCall(
997          CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
998      Builder.restoreIP(BuilderInsertPoint);
999      return;
1000    }
1001  
1002    llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1003    if (PGOReader && haveRegionCounts()) {
1004      // We record the top most called three functions at each call site.
1005      // Profile metadata contains "VP" string identifying this metadata
1006      // as value profiling data, then a uint32_t value for the value profiling
1007      // kind, a uint64_t value for the total number of times the call is
1008      // executed, followed by the function hash and execution count (uint64_t)
1009      // pairs for each function.
1010      if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1011        return;
1012  
1013      llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1014                              (llvm::InstrProfValueKind)ValueKind,
1015                              NumValueSites[ValueKind]);
1016  
1017      NumValueSites[ValueKind]++;
1018    }
1019  }
1020  
1021  void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1022                                    bool IsInMainFile) {
1023    CGM.getPGOStats().addVisited(IsInMainFile);
1024    RegionCounts.clear();
1025    llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1026        PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1027    if (auto E = RecordExpected.takeError()) {
1028      auto IPE = llvm::InstrProfError::take(std::move(E));
1029      if (IPE == llvm::instrprof_error::unknown_function)
1030        CGM.getPGOStats().addMissing(IsInMainFile);
1031      else if (IPE == llvm::instrprof_error::hash_mismatch)
1032        CGM.getPGOStats().addMismatched(IsInMainFile);
1033      else if (IPE == llvm::instrprof_error::malformed)
1034        // TODO: Consider a more specific warning for this case.
1035        CGM.getPGOStats().addMismatched(IsInMainFile);
1036      return;
1037    }
1038    ProfRecord =
1039        std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1040    RegionCounts = ProfRecord->Counts;
1041  }
1042  
/// Compute the divisor that brings a set of branch weights into 32-bit range.
///
/// \p MaxWeight is the largest weight in the set. If it is already below
/// UINT32_MAX the divisor is 1. Otherwise the divisor is
/// MaxWeight / UINT32_MAX + 1, which makes every scaled weight strictly less
/// than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight >= UINT32_MAX)
    return 1 + MaxWeight / UINT32_MAX;
  return 1;
}
1050  
/// Scale one branch weight down to 32 bits and add 1 to it.
///
/// The weight is divided by \c Scale, and 1 is added. Following Laplace's
/// Rule of Succession, basing the weight on count + 1 is preferable to using
/// the raw count, so the +1 is applied to every value.
///
/// \pre \c Scale came from \a calculateWeightScale() called with a value at
/// least as large as \c Weight, so the result fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Result = 1 + Weight / Scale;
  assert(Result <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Result);
}
1066  
1067  llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1068                                                      uint64_t FalseCount) const {
1069    // Check for empty weights.
1070    if (!TrueCount && !FalseCount)
1071      return nullptr;
1072  
1073    // Calculate how to scale down to 32-bits.
1074    uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1075  
1076    llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1077    return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1078                                        scaleBranchWeight(FalseCount, Scale));
1079  }
1080  
1081  llvm::MDNode *
1082  CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1083    // We need at least two elements to create meaningful weights.
1084    if (Weights.size() < 2)
1085      return nullptr;
1086  
1087    // Check for empty weights.
1088    uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1089    if (MaxWeight == 0)
1090      return nullptr;
1091  
1092    // Calculate how to scale down to 32-bits.
1093    uint64_t Scale = calculateWeightScale(MaxWeight);
1094  
1095    SmallVector<uint32_t, 16> ScaledWeights;
1096    ScaledWeights.reserve(Weights.size());
1097    for (uint64_t W : Weights)
1098      ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1099  
1100    llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1101    return MDHelper.createBranchWeights(ScaledWeights);
1102  }
1103  
1104  llvm::MDNode *
1105  CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1106                                               uint64_t LoopCount) const {
1107    if (!PGO.haveRegionCounts())
1108      return nullptr;
1109    Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1110    if (!CondCount || *CondCount == 0)
1111      return nullptr;
1112    return createProfileWeights(LoopCount,
1113                                std::max(*CondCount, LoopCount) - LoopCount);
1114  }
1115