xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1  //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // Instrumentation-based profile-guided optimization
10  //
11  //===----------------------------------------------------------------------===//
12  
13  #include "CodeGenPGO.h"
14  #include "CodeGenFunction.h"
15  #include "CoverageMappingGen.h"
16  #include "clang/AST/RecursiveASTVisitor.h"
17  #include "clang/AST/StmtVisitor.h"
18  #include "llvm/IR/Intrinsics.h"
19  #include "llvm/IR/MDBuilder.h"
20  #include "llvm/Support/CommandLine.h"
21  #include "llvm/Support/Endian.h"
22  #include "llvm/Support/FileSystem.h"
23  #include "llvm/Support/MD5.h"
24  #include <optional>
25  
26  namespace llvm {
// Declaration only; the definition lives elsewhere. When this flag is set,
// MapRegionCounters assigns extra counters (loop conditions/bodies, if/else
// arms, ?: arms).
27  extern cl::opt<bool> EnableSingleByteCoverage;
28  } // namespace llvm
29  
// Hidden command-line option "-enable-value-profiling"; defaults to false.
30  static llvm::cl::opt<bool>
31      EnableValueProfiling("enable-value-profiling",
32                           llvm::cl::desc("Enable value profiling"),
33                           llvm::cl::Hidden, llvm::cl::init(false));
34  
35  using namespace clang;
36  using namespace CodeGen;
37  
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)38  void CodeGenPGO::setFuncName(StringRef Name,
39                               llvm::GlobalValue::LinkageTypes Linkage) {
40    llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
41    FuncName = llvm::getPGOFuncName(
42        Name, Linkage, CGM.getCodeGenOpts().MainFileName,
43        PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
44  
45    // If we're generating a profile, create a variable for the name.
46    if (CGM.getCodeGenOpts().hasProfileClangInstr())
47      FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
48  }
49  
setFuncName(llvm::Function * Fn)50  void CodeGenPGO::setFuncName(llvm::Function *Fn) {
51    setFuncName(Fn->getName(), Fn->getLinkage());
52    // Create PGOFuncName meta data.
53    llvm::createPGOFuncNameMetadata(*Fn, FuncName);
54  }
55  
/// Versions of the algorithm that hashes a function's control flow for PGO.
/// The values are written out explicitly: consecutive, starting at zero.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Must always name the newest version listed above.
  PGO_HASH_LATEST = PGO_HASH_V3
};
65  
66  namespace {
67  /// Stable hasher for PGO region counters.
68  ///
69  /// PGOHash produces a stable hash of a given function's control flow.
70  ///
71  /// Changing the output of this hash will invalidate all previously generated
72  /// profiles -- i.e., don't do it.
73  ///
74  /// \note  When this hash does eventually change (years?), we still need to
75  /// support old hashes.  We'll need to pull in the version number from the
76  /// profile data format and use the matching hash function.
77  class PGOHash {
78    uint64_t Working;
79    unsigned Count;
80    PGOHashVersion HashVersion;
81    llvm::MD5 MD5;
82  
83    static const int NumBitsPerType = 6;
84    static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
85    static const unsigned TooBig = 1u << NumBitsPerType;
86  
87  public:
88    /// Hash values for AST nodes.
89    ///
90    /// Distinct values for AST nodes that have region counters attached.
91    ///
92    /// These values must be stable.  All new members must be added at the end,
93    /// and no members should be removed.  Changing the enumeration value for an
94    /// AST node will affect the hash of every function that contains that node.
95    enum HashType : unsigned char {
96      None = 0,
97      LabelStmt = 1,
98      WhileStmt,
99      DoStmt,
100      ForStmt,
101      CXXForRangeStmt,
102      ObjCForCollectionStmt,
103      SwitchStmt,
104      CaseStmt,
105      DefaultStmt,
106      IfStmt,
107      CXXTryStmt,
108      CXXCatchStmt,
109      ConditionalOperator,
110      BinaryOperatorLAnd,
111      BinaryOperatorLOr,
112      BinaryConditionalOperator,
113      // The preceding values are available with PGO_HASH_V1.
114  
115      EndOfScope,
116      IfThenBranch,
117      IfElseBranch,
118      GotoStmt,
119      IndirectGotoStmt,
120      BreakStmt,
121      ContinueStmt,
122      ReturnStmt,
123      ThrowExpr,
124      UnaryOperatorLNot,
125      BinaryOperatorLT,
126      BinaryOperatorGT,
127      BinaryOperatorLE,
128      BinaryOperatorGE,
129      BinaryOperatorEQ,
130      BinaryOperatorNE,
131      // The preceding values are available since PGO_HASH_V2.
132  
133      // Keep this last.  It's for the static assert that follows.
134      LastHashType
135    };
136    static_assert(LastHashType <= TooBig, "Too many types in HashType");
137  
PGOHash(PGOHashVersion HashVersion)138    PGOHash(PGOHashVersion HashVersion)
139        : Working(0), Count(0), HashVersion(HashVersion) {}
140    void combine(HashType Type);
141    uint64_t finalize();
getHashVersion() const142    PGOHashVersion getHashVersion() const { return HashVersion; }
143  };
144  const int PGOHash::NumBitsPerType;
145  const unsigned PGOHash::NumTypesPerWord;
146  const unsigned PGOHash::TooBig;
147  
148  /// Get the PGO hash version used in the given indexed profile.
getPGOHashVersion(llvm::IndexedInstrProfReader * PGOReader,CodeGenModule & CGM)149  static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
150                                          CodeGenModule &CGM) {
151    if (PGOReader->getVersion() <= 4)
152      return PGO_HASH_V1;
153    if (PGOReader->getVersion() <= 5)
154      return PGO_HASH_V2;
155    return PGO_HASH_V3;
156  }
157  
158  /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// While walking a function it:
/// - assigns the next free counter index to every statement that needs one
///   (recorded in CounterMap);
/// - folds the visited control-flow constructs into Hash;
/// - when MC/DC is enabled (MCDCMaxCond != 0), registers each supported
///   top-level logical-operator expression in MCDCState.DecisionByStmt and
///   warns about unsupported ones.
159  struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
160    using Base = RecursiveASTVisitor<MapRegionCounters>;
161  
162    /// The next counter value to assign.
163    unsigned NextCounter;
164    /// The function hash, extended as statements are visited.
165    PGOHash Hash;
166    /// The map of statements to counters.
167    llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
168    /// The state of MC/DC Coverage in this function.
169    MCDC::State &MCDCState;
170    /// Maximum number of supported MC/DC conditions in a boolean expression.
       /// A value of 0 means MC/DC is disabled; the MC/DC hooks then do nothing.
171    unsigned MCDCMaxCond;
172    /// The profile format version. From Version7 on, the RHS of an
       /// instrumented logical operator gets its own counter.
173    uint64_t ProfileVersion;
174    /// Diagnostics Engine used to report warnings.
175    DiagnosticsEngine &Diag;
176  
       /// Counter numbering starts at 0; \p HashVersion selects the hash variant.
MapRegionCounters__anon71781d390111::MapRegionCounters177    MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
178                      llvm::DenseMap<const Stmt *, unsigned> &CounterMap,
179                      MCDC::State &MCDCState, unsigned MCDCMaxCond,
180                      DiagnosticsEngine &Diag)
181        : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
182          MCDCState(MCDCState), MCDCMaxCond(MCDCMaxCond),
183          ProfileVersion(ProfileVersion), Diag(Diag) {}
184  
185    // Blocks and lambdas are handled as separate functions, so we need not
186    // traverse them in the parent context.
TraverseBlockExpr__anon71781d390111::MapRegionCounters187    bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaExpr__anon71781d390111::MapRegionCounters188    bool TraverseLambdaExpr(LambdaExpr *LE) {
189      // Traverse the captures, but not the body.
190      for (auto C : zip(LE->captures(), LE->capture_inits()))
191        TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
192      return true;
193    }
       // Captured statements are likewise not traversed in the parent context.
TraverseCapturedStmt__anon71781d390111::MapRegionCounters194    bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
195  
       /// Give the body of each function-like declaration (functions, methods,
       /// constructors, destructors, conversions, ObjC methods, blocks and
       /// captured declarations) the next counter.
VisitDecl__anon71781d390111::MapRegionCounters196    bool VisitDecl(const Decl *D) {
197      switch (D->getKind()) {
198      default:
199        break;
200      case Decl::Function:
201      case Decl::CXXMethod:
202      case Decl::CXXConstructor:
203      case Decl::CXXDestructor:
204      case Decl::CXXConversion:
205      case Decl::ObjCMethod:
206      case Decl::Block:
207      case Decl::Captured:
208        CounterMap[D->getBody()] = NextCounter++;
209        break;
210      }
211      return true;
212    }
213  
214    /// If \p S gets a fresh counter, update the counter mappings. Return the
215    /// V1 hash of \p S.
       /// Whether a counter is assigned is always decided by the V1
       /// classification, regardless of the version Hash was created with.
updateCounterMappings__anon71781d390111::MapRegionCounters216    PGOHash::HashType updateCounterMappings(Stmt *S) {
217      auto Type = getHashType(PGO_HASH_V1, S);
218      if (Type != PGOHash::None)
219        CounterMap[S] = NextCounter++;
220      return Type;
221    }
222  
223    /// The following stacks are used with dataTraverseStmtPre() and
224    /// dataTraverseStmtPost() to track the depth of nested logical operators in a
225    /// boolean expression in a function.  The ultimate purpose is to keep track
226    /// of the number of leaf-level conditions in the boolean expression so that a
227    /// profile bitmap can be allocated based on that number.
228    ///
229    /// The stacks are also used to find error cases and notify the user.  A
230    /// standard logical operator nest for a boolean expression could be in a form
231    /// similar to this: "x = a && b && c && (d || f)"
       /// Number of conditions counted in the current top-level expression.
232    unsigned NumCond = 0;
       /// Set once a logical operator is found beneath a non-logical operator
       /// inside the current top-level expression.
233    bool SplitNestedLogicalOp = false;
       /// Non-logical statements entered while inside a logical-operator nest.
234    SmallVector<const Stmt *, 16> NonLogOpStack;
       /// Logical operators currently being traversed, innermost last.
235    SmallVector<const BinaryOperator *, 16> LogOpStack;
236  
237    // Hook: dataTraverseStmtPre() is invoked prior to visiting an AST Stmt node.
dataTraverseStmtPre__anon71781d390111::MapRegionCounters238    bool dataTraverseStmtPre(Stmt *S) {
239      /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
240      if (MCDCMaxCond == 0)
241        return true;
242  
243      /// At the top of the logical operator nest, reset the number of conditions,
244      /// also forget previously seen split nesting cases.
245      if (LogOpStack.empty()) {
246        NumCond = 0;
247        SplitNestedLogicalOp = false;
248      }
249  
250      if (const Expr *E = dyn_cast<Expr>(S)) {
251        const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
252        if (BinOp && BinOp->isLogicalOp()) {
253          /// Check for "split-nested" logical operators. This happens when a new
254          /// boolean expression logical-op nest is encountered within an existing
255          /// boolean expression, separated by a non-logical operator.  For
256          /// example, in "x = (a && b && c && foo(d && f))", the "d && f" case
257          /// starts a new boolean expression that is separated from the other
258          /// conditions by the operator foo(). Split-nested cases are not
259          /// supported by MC/DC.
260          SplitNestedLogicalOp = SplitNestedLogicalOp || !NonLogOpStack.empty();
261  
262          LogOpStack.push_back(BinOp);
263          return true;
264        }
265      }
266  
267      /// Keep track of non-logical operators. These are OK as long as we don't
268      /// encounter a new logical operator after seeing one.
269      if (!LogOpStack.empty())
270        NonLogOpStack.push_back(S);
271  
272      return true;
273    }
274  
275    // Hook: dataTraverseStmtPost() is invoked by the AST visitor after visiting
276    // an AST Stmt node.  MC/DC will use it to signal when the top of a
277    // logical operation (boolean expression) nest is encountered.
dataTraverseStmtPost__anon71781d390111::MapRegionCounters278    bool dataTraverseStmtPost(Stmt *S) {
279      /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
280      if (MCDCMaxCond == 0)
281        return true;
282  
283      if (const Expr *E = dyn_cast<Expr>(S)) {
284        const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
285        if (BinOp && BinOp->isLogicalOp()) {
286          assert(LogOpStack.back() == BinOp);
287          LogOpStack.pop_back();
288  
289          /// At the top of logical operator nest:
290          if (LogOpStack.empty()) {
291            /// Was the "split-nested" logical operator case encountered?
292            if (SplitNestedLogicalOp) {
293              unsigned DiagID = Diag.getCustomDiagID(
294                  DiagnosticsEngine::Warning,
295                  "unsupported MC/DC boolean expression; "
296                  "contains an operation with a nested boolean expression. "
297                  "Expression will not be covered");
298              Diag.Report(S->getBeginLoc(), DiagID);
299              return true;
300            }
301  
302            /// Was the maximum number of conditions encountered?
303            if (NumCond > MCDCMaxCond) {
304              unsigned DiagID = Diag.getCustomDiagID(
305                  DiagnosticsEngine::Warning,
306                  "unsupported MC/DC boolean expression; "
307                  "number of conditions (%0) exceeds max (%1). "
308                  "Expression will not be covered");
309              Diag.Report(S->getBeginLoc(), DiagID) << NumCond << MCDCMaxCond;
310              return true;
311            }
312  
313            // Otherwise, allocate the Decision.
314            MCDCState.DecisionByStmt[BinOp].BitmapIdx = 0;
315          }
316          return true;
317        }
318      }
319  
           // Mirror the push in dataTraverseStmtPre(): non-logical statements
           // were only pushed while inside a logical-operator nest.
320      if (!LogOpStack.empty())
321        NonLogOpStack.pop_back();
322  
323      return true;
324    }
325  
326    /// The RHS of all logical operators gets a fresh counter in order to count
327    /// how many times the RHS evaluates to true or false, depending on the
328    /// semantics of the operator. This is only valid for ">= v7" of the profile
329    /// version so that we facilitate backward compatibility. In addition, in
330    /// order to use MC/DC, count the number of total LHS and RHS conditions.
VisitBinaryOperator__anon71781d390111::MapRegionCounters331    bool VisitBinaryOperator(BinaryOperator *S) {
332      if (S->isLogicalOp()) {
333        if (CodeGenFunction::isInstrumentedCondition(S->getLHS()))
334          NumCond++;
335  
336        if (CodeGenFunction::isInstrumentedCondition(S->getRHS())) {
337          if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
338            CounterMap[S->getRHS()] = NextCounter++;
339  
340          NumCond++;
341        }
342      }
343      return Base::VisitBinaryOperator(S);
344    }
345  
       /// In single-byte coverage mode, give each arm of ?: its own counter
       /// before the default visit.
VisitConditionalOperator__anon71781d390111::MapRegionCounters346    bool VisitConditionalOperator(ConditionalOperator *S) {
347      if (llvm::EnableSingleByteCoverage && S->getTrueExpr())
348        CounterMap[S->getTrueExpr()] = NextCounter++;
349      if (llvm::EnableSingleByteCoverage && S->getFalseExpr())
350        CounterMap[S->getFalseExpr()] = NextCounter++;
351      return Base::VisitConditionalOperator(S);
352    }
353  
354    /// Include \p S in the function hash.
       /// A counter is assigned by the V1 classification; the hash itself uses
       /// the classification of Hash's own version when that is not V1.
VisitStmt__anon71781d390111::MapRegionCounters355    bool VisitStmt(Stmt *S) {
356      auto Type = updateCounterMappings(S);
357      if (Hash.getHashVersion() != PGO_HASH_V1)
358        Type = getHashType(Hash.getHashVersion(), S);
359      if (Type != PGOHash::None)
360        Hash.combine(Type);
361      return true;
362    }
363  
       /// For non-V1 hashes, mark the then/else branches and the end of the if
       /// in the hash; in single-byte coverage mode, also give then/else their
       /// own counters.
TraverseIfStmt__anon71781d390111::MapRegionCounters364    bool TraverseIfStmt(IfStmt *If) {
365      // If we used the V1 hash, use the default traversal.
366      if (Hash.getHashVersion() == PGO_HASH_V1)
367        return Base::TraverseIfStmt(If);
368  
369      // When single byte coverage mode is enabled, add a counter to then and
370      // else.
371      bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
372      for (Stmt *CS : If->children()) {
373        if (!CS || NoSingleByteCoverage)
374          continue;
375        if (CS == If->getThen())
376          CounterMap[If->getThen()] = NextCounter++;
377        else if (CS == If->getElse())
378          CounterMap[If->getElse()] = NextCounter++;
379      }
380  
381      // Otherwise, keep track of which branch we're in while traversing.
382      VisitStmt(If);
383  
384      for (Stmt *CS : If->children()) {
385        if (!CS)
386          continue;
387        if (CS == If->getThen())
388          Hash.combine(PGOHash::IfThenBranch);
389        else if (CS == If->getElse())
390          Hash.combine(PGOHash::IfElseBranch);
391        TraverseStmt(CS);
392      }
393      Hash.combine(PGOHash::EndOfScope);
394      return true;
395    }
396  
TraverseWhileStmt__anon71781d390111::MapRegionCounters397    bool TraverseWhileStmt(WhileStmt *While) {
398      // When single byte coverage mode is enabled, add a counter to condition and
399      // body.
400      bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
401      for (Stmt *CS : While->children()) {
402        if (!CS || NoSingleByteCoverage)
403          continue;
404        if (CS == While->getCond())
405          CounterMap[While->getCond()] = NextCounter++;
406        else if (CS == While->getBody())
407          CounterMap[While->getBody()] = NextCounter++;
408      }
409  
410      Base::TraverseWhileStmt(While);
411      if (Hash.getHashVersion() != PGO_HASH_V1)
412        Hash.combine(PGOHash::EndOfScope);
413      return true;
414    }
415  
TraverseDoStmt__anon71781d390111::MapRegionCounters416    bool TraverseDoStmt(DoStmt *Do) {
417      // When single byte coverage mode is enabled, add a counter to condition and
418      // body.
419      bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
420      for (Stmt *CS : Do->children()) {
421        if (!CS || NoSingleByteCoverage)
422          continue;
423        if (CS == Do->getCond())
424          CounterMap[Do->getCond()] = NextCounter++;
425        else if (CS == Do->getBody())
426          CounterMap[Do->getBody()] = NextCounter++;
427      }
428  
429      Base::TraverseDoStmt(Do);
430      if (Hash.getHashVersion() != PGO_HASH_V1)
431        Hash.combine(PGOHash::EndOfScope);
432      return true;
433    }
434  
TraverseForStmt__anon71781d390111::MapRegionCounters435    bool TraverseForStmt(ForStmt *For) {
436      // When single byte coverage mode is enabled, add a counter to condition,
437      // increment and body.
438      bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
439      for (Stmt *CS : For->children()) {
440        if (!CS || NoSingleByteCoverage)
441          continue;
442        if (CS == For->getCond())
443          CounterMap[For->getCond()] = NextCounter++;
444        else if (CS == For->getInc())
445          CounterMap[For->getInc()] = NextCounter++;
446        else if (CS == For->getBody())
447          CounterMap[For->getBody()] = NextCounter++;
448      }
449  
450      Base::TraverseForStmt(For);
451      if (Hash.getHashVersion() != PGO_HASH_V1)
452        Hash.combine(PGOHash::EndOfScope);
453      return true;
454    }
455  
TraverseCXXForRangeStmt__anon71781d390111::MapRegionCounters456    bool TraverseCXXForRangeStmt(CXXForRangeStmt *ForRange) {
457      // When single byte coverage mode is enabled, add a counter to body.
458      bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
459      for (Stmt *CS : ForRange->children()) {
460        if (!CS || NoSingleByteCoverage)
461          continue;
462        if (CS == ForRange->getBody())
463          CounterMap[ForRange->getBody()] = NextCounter++;
464      }
465  
466      Base::TraverseCXXForRangeStmt(ForRange);
467      if (Hash.getHashVersion() != PGO_HASH_V1)
468        Hash.combine(PGOHash::EndOfScope);
469      return true;
470    }
471  
472  // If the statement type \p N is nestable, and its nesting impacts profile
473  // stability, define a custom traversal which tracks the end of the statement
474  // in the hash (provided we're not using the V1 hash).
475  #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
476    bool Traverse##N(N *S) {                                                     \
477      Base::Traverse##N(S);                                                      \
478      if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
479        Hash.combine(PGOHash::EndOfScope);                                       \
480      return true;                                                               \
481    }
482  
483    DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
DEFINE_NESTABLE_TRAVERSAL__anon71781d390111::MapRegionCounters484    DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
485    DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
486  
487    /// Get version \p HashVersion of the PGO hash for \p S.
       /// Returns PGOHash::None for statements that do not contribute.
488    PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
489      switch (S->getStmtClass()) {
490      default:
491        break;
492      case Stmt::LabelStmtClass:
493        return PGOHash::LabelStmt;
494      case Stmt::WhileStmtClass:
495        return PGOHash::WhileStmt;
496      case Stmt::DoStmtClass:
497        return PGOHash::DoStmt;
498      case Stmt::ForStmtClass:
499        return PGOHash::ForStmt;
500      case Stmt::CXXForRangeStmtClass:
501        return PGOHash::CXXForRangeStmt;
502      case Stmt::ObjCForCollectionStmtClass:
503        return PGOHash::ObjCForCollectionStmt;
504      case Stmt::SwitchStmtClass:
505        return PGOHash::SwitchStmt;
506      case Stmt::CaseStmtClass:
507        return PGOHash::CaseStmt;
508      case Stmt::DefaultStmtClass:
509        return PGOHash::DefaultStmt;
510      case Stmt::IfStmtClass:
511        return PGOHash::IfStmt;
512      case Stmt::CXXTryStmtClass:
513        return PGOHash::CXXTryStmt;
514      case Stmt::CXXCatchStmtClass:
515        return PGOHash::CXXCatchStmt;
516      case Stmt::ConditionalOperatorClass:
517        return PGOHash::ConditionalOperator;
518      case Stmt::BinaryConditionalOperatorClass:
519        return PGOHash::BinaryConditionalOperator;
520      case Stmt::BinaryOperatorClass: {
521        const BinaryOperator *BO = cast<BinaryOperator>(S);
522        if (BO->getOpcode() == BO_LAnd)
523          return PGOHash::BinaryOperatorLAnd;
524        if (BO->getOpcode() == BO_LOr)
525          return PGOHash::BinaryOperatorLOr;
526        if (HashVersion >= PGO_HASH_V2) {
527          switch (BO->getOpcode()) {
528          default:
529            break;
530          case BO_LT:
531            return PGOHash::BinaryOperatorLT;
532          case BO_GT:
533            return PGOHash::BinaryOperatorGT;
534          case BO_LE:
535            return PGOHash::BinaryOperatorLE;
536          case BO_GE:
537            return PGOHash::BinaryOperatorGE;
538          case BO_EQ:
539            return PGOHash::BinaryOperatorEQ;
540          case BO_NE:
541            return PGOHash::BinaryOperatorNE;
542          }
543        }
544        break;
545      }
546      }
547  
           // Statement kinds that only contribute from PGO_HASH_V2 onward.
548      if (HashVersion >= PGO_HASH_V2) {
549        switch (S->getStmtClass()) {
550        default:
551          break;
552        case Stmt::GotoStmtClass:
553          return PGOHash::GotoStmt;
554        case Stmt::IndirectGotoStmtClass:
555          return PGOHash::IndirectGotoStmt;
556        case Stmt::BreakStmtClass:
557          return PGOHash::BreakStmt;
558        case Stmt::ContinueStmtClass:
559          return PGOHash::ContinueStmt;
560        case Stmt::ReturnStmtClass:
561          return PGOHash::ReturnStmt;
562        case Stmt::CXXThrowExprClass:
563          return PGOHash::ThrowExpr;
564        case Stmt::UnaryOperatorClass: {
565          const UnaryOperator *UO = cast<UnaryOperator>(S);
566          if (UO->getOpcode() == UO_LNot)
567            return PGOHash::UnaryOperatorLNot;
568          break;
569        }
570        }
571      }
572  
573      return PGOHash::None;
574    }
575  };
576  
577  /// A StmtVisitor that propagates the raw counts through the AST and
578  /// records the count at statements where the value may change.
579  struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
580    /// PGO state.
581    CodeGenPGO &PGO;
582  
583    /// A flag that is set when the current count should be recorded on the
584    /// next statement, such as at the exit of a loop.
585    bool RecordNextStmtCount;
586  
587    /// The count at the current location in the traversal.
588    uint64_t CurrentCount;
589  
590    /// The map of statements to count values.
591    llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
592  
/// Running totals for one enclosing loop or switch: the count that leaves it
/// through `break` and the count that jumps back through `continue`.
/// BreakContinueStack holds one entry per nesting level.
struct BreakContinue {
  uint64_t BreakCount{0};
  uint64_t ContinueCount{0};
  BreakContinue() = default;
};
599    SmallVector<BreakContinue, 8> BreakContinueStack;
600  
ComputeRegionCounts__anon71781d390111::ComputeRegionCounts601    ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
602                        CodeGenPGO &PGO)
603        : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
604  
RecordStmtCount__anon71781d390111::ComputeRegionCounts605    void RecordStmtCount(const Stmt *S) {
606      if (RecordNextStmtCount) {
607        CountMap[S] = CurrentCount;
608        RecordNextStmtCount = false;
609      }
610    }
611  
612    /// Set and return the current count.
setCount__anon71781d390111::ComputeRegionCounts613    uint64_t setCount(uint64_t Count) {
614      CurrentCount = Count;
615      return Count;
616    }
617  
VisitStmt__anon71781d390111::ComputeRegionCounts618    void VisitStmt(const Stmt *S) {
619      RecordStmtCount(S);
620      for (const Stmt *Child : S->children())
621        if (Child)
622          this->Visit(Child);
623    }
624  
VisitFunctionDecl__anon71781d390111::ComputeRegionCounts625    void VisitFunctionDecl(const FunctionDecl *D) {
626      // Counter tracks entry to the function body.
627      uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
628      CountMap[D->getBody()] = BodyCount;
629      Visit(D->getBody());
630    }
631  
632    // Skip lambda expressions. We visit these as FunctionDecls when we're
633    // generating them and aren't interested in the body when generating a
634    // parent context.
VisitLambdaExpr__anon71781d390111::ComputeRegionCounts635    void VisitLambdaExpr(const LambdaExpr *LE) {}
636  
VisitCapturedDecl__anon71781d390111::ComputeRegionCounts637    void VisitCapturedDecl(const CapturedDecl *D) {
638      // Counter tracks entry to the capture body.
639      uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
640      CountMap[D->getBody()] = BodyCount;
641      Visit(D->getBody());
642    }
643  
VisitObjCMethodDecl__anon71781d390111::ComputeRegionCounts644    void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
645      // Counter tracks entry to the method body.
646      uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
647      CountMap[D->getBody()] = BodyCount;
648      Visit(D->getBody());
649    }
650  
VisitBlockDecl__anon71781d390111::ComputeRegionCounts651    void VisitBlockDecl(const BlockDecl *D) {
652      // Counter tracks entry to the block body.
653      uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
654      CountMap[D->getBody()] = BodyCount;
655      Visit(D->getBody());
656    }
657  
VisitReturnStmt__anon71781d390111::ComputeRegionCounts658    void VisitReturnStmt(const ReturnStmt *S) {
659      RecordStmtCount(S);
660      if (S->getRetValue())
661        Visit(S->getRetValue());
662      CurrentCount = 0;
663      RecordNextStmtCount = true;
664    }
665  
VisitCXXThrowExpr__anon71781d390111::ComputeRegionCounts666    void VisitCXXThrowExpr(const CXXThrowExpr *E) {
667      RecordStmtCount(E);
668      if (E->getSubExpr())
669        Visit(E->getSubExpr());
670      CurrentCount = 0;
671      RecordNextStmtCount = true;
672    }
673  
VisitGotoStmt__anon71781d390111::ComputeRegionCounts674    void VisitGotoStmt(const GotoStmt *S) {
675      RecordStmtCount(S);
676      CurrentCount = 0;
677      RecordNextStmtCount = true;
678    }
679  
VisitLabelStmt__anon71781d390111::ComputeRegionCounts680    void VisitLabelStmt(const LabelStmt *S) {
681      RecordNextStmtCount = false;
682      // Counter tracks the block following the label.
683      uint64_t BlockCount = setCount(PGO.getRegionCount(S));
684      CountMap[S] = BlockCount;
685      Visit(S->getSubStmt());
686    }
687  
VisitBreakStmt__anon71781d390111::ComputeRegionCounts688    void VisitBreakStmt(const BreakStmt *S) {
689      RecordStmtCount(S);
690      assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
691      BreakContinueStack.back().BreakCount += CurrentCount;
692      CurrentCount = 0;
693      RecordNextStmtCount = true;
694    }
695  
VisitContinueStmt__anon71781d390111::ComputeRegionCounts696    void VisitContinueStmt(const ContinueStmt *S) {
697      RecordStmtCount(S);
698      assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
699      BreakContinueStack.back().ContinueCount += CurrentCount;
700      CurrentCount = 0;
701      RecordNextStmtCount = true;
702    }
703  
/// Propagate counts through a while loop. The loop's region counter is the
/// body-entry count. The condition is entered from the parent, from the end
/// of the body, and from `continue`s. The count after the loop is the break
/// total plus the condition evaluations that did not enter the body
/// (CondCount - BodyCount).
/// NOTE(review): BC.BreakCount + CondCount - BodyCount is unsigned; if
/// inconsistent counter data makes BodyCount larger than the sum, the result
/// wraps to a huge value.
VisitWhileStmt__anon71781d390111::ComputeRegionCounts704    void VisitWhileStmt(const WhileStmt *S) {
705      RecordStmtCount(S);
706      uint64_t ParentCount = CurrentCount;
707  
708      BreakContinueStack.push_back(BreakContinue());
709      // Visit the body region first so the break/continue adjustments can be
710      // included when visiting the condition.
711      uint64_t BodyCount = setCount(PGO.getRegionCount(S));
         // CurrentCount equals BodyCount here.
712      CountMap[S->getBody()] = CurrentCount;
713      Visit(S->getBody());
714      uint64_t BackedgeCount = CurrentCount;
715  
716      // ...then go back and propagate counts through the condition. The count
717      // at the start of the condition is the sum of the incoming edges,
718      // the backedge from the end of the loop body, and the edges from
719      // continue statements.
720      BreakContinue BC = BreakContinueStack.pop_back_val();
721      uint64_t CondCount =
722          setCount(ParentCount + BackedgeCount + BC.ContinueCount);
723      CountMap[S->getCond()] = CondCount;
724      Visit(S->getCond());
725      setCount(BC.BreakCount + CondCount - BodyCount);
726      RecordNextStmtCount = true;
727    }
728  
/// Propagate counts through a do/while loop. The loop's region counter
/// (LoopCount) excludes the first entry from the parent, so the body count is
/// LoopCount plus the incoming count. The condition is reached from the end of
/// the body and from `continue`s. The count after the loop is the break total
/// plus the condition evaluations that did not loop back
/// (CondCount - LoopCount).
/// NOTE(review): BC.BreakCount + CondCount - LoopCount is unsigned; if
/// inconsistent counter data makes LoopCount larger than the sum, the result
/// wraps to a huge value.
VisitDoStmt__anon71781d390111::ComputeRegionCounts729    void VisitDoStmt(const DoStmt *S) {
730      RecordStmtCount(S);
731      uint64_t LoopCount = PGO.getRegionCount(S);
732  
733      BreakContinueStack.push_back(BreakContinue());
734      // The count doesn't include the fallthrough from the parent scope. Add it.
735      uint64_t BodyCount = setCount(LoopCount + CurrentCount);
736      CountMap[S->getBody()] = BodyCount;
737      Visit(S->getBody());
738      uint64_t BackedgeCount = CurrentCount;
739  
740      BreakContinue BC = BreakContinueStack.pop_back_val();
741      // The count at the start of the condition is equal to the count at the
742      // end of the body, plus any continues.
743      uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
744      CountMap[S->getCond()] = CondCount;
745      Visit(S->getCond());
746      setCount(BC.BreakCount + CondCount - LoopCount);
747      RecordNextStmtCount = true;
748    }
749  
/// Propagate counts through a for loop. The init statement runs once at the
/// parent's count. After that the loop is handled like a while loop, plus an
/// increment that is reached from the end of the body and from `continue`s.
/// The condition's count is computed from the body's end count and the
/// continue total, not from the count left after visiting the increment.
/// NOTE(review): BC.BreakCount + CondCount - BodyCount is unsigned; if
/// inconsistent counter data makes BodyCount larger than the sum, the result
/// wraps to a huge value.
VisitForStmt__anon71781d390111::ComputeRegionCounts750    void VisitForStmt(const ForStmt *S) {
751      RecordStmtCount(S);
752      if (S->getInit())
753        Visit(S->getInit());
754  
755      uint64_t ParentCount = CurrentCount;
756  
757      BreakContinueStack.push_back(BreakContinue());
758      // Visit the body region first. (This is basically the same as a while
759      // loop; see further comments in VisitWhileStmt.)
760      uint64_t BodyCount = setCount(PGO.getRegionCount(S));
761      CountMap[S->getBody()] = BodyCount;
762      Visit(S->getBody());
763      uint64_t BackedgeCount = CurrentCount;
764      BreakContinue BC = BreakContinueStack.pop_back_val();
765  
766      // The increment is essentially part of the body but it needs to include
767      // the count for all the continue statements.
768      if (S->getInc()) {
769        uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
770        CountMap[S->getInc()] = IncCount;
771        Visit(S->getInc());
772      }
773  
774      // ...then go back and propagate counts through the condition.
775      uint64_t CondCount =
776          setCount(ParentCount + BackedgeCount + BC.ContinueCount);
777      if (S->getCond()) {
778        CountMap[S->getCond()] = CondCount;
779        Visit(S->getCond());
780      }
781      setCount(BC.BreakCount + CondCount - BodyCount);
782      RecordNextStmtCount = true;
783    }
784  
VisitCXXForRangeStmt__anon71781d390111::ComputeRegionCounts785    void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
786      RecordStmtCount(S);
787      if (S->getInit())
788        Visit(S->getInit());
789      Visit(S->getLoopVarStmt());
790      Visit(S->getRangeStmt());
791      Visit(S->getBeginStmt());
792      Visit(S->getEndStmt());
793  
794      uint64_t ParentCount = CurrentCount;
795      BreakContinueStack.push_back(BreakContinue());
796      // Visit the body region first. (This is basically the same as a while
797      // loop; see further comments in VisitWhileStmt.)
798      uint64_t BodyCount = setCount(PGO.getRegionCount(S));
799      CountMap[S->getBody()] = BodyCount;
800      Visit(S->getBody());
801      uint64_t BackedgeCount = CurrentCount;
802      BreakContinue BC = BreakContinueStack.pop_back_val();
803  
804      // The increment is essentially part of the body but it needs to include
805      // the count for all the continue statements.
806      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
807      CountMap[S->getInc()] = IncCount;
808      Visit(S->getInc());
809  
810      // ...then go back and propagate counts through the condition.
811      uint64_t CondCount =
812          setCount(ParentCount + BackedgeCount + BC.ContinueCount);
813      CountMap[S->getCond()] = CondCount;
814      Visit(S->getCond());
815      setCount(BC.BreakCount + CondCount - BodyCount);
816      RecordNextStmtCount = true;
817    }
818  
VisitObjCForCollectionStmt__anon71781d390111::ComputeRegionCounts819    void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
820      RecordStmtCount(S);
821      Visit(S->getElement());
822      uint64_t ParentCount = CurrentCount;
823      BreakContinueStack.push_back(BreakContinue());
824      // Counter tracks the body of the loop.
825      uint64_t BodyCount = setCount(PGO.getRegionCount(S));
826      CountMap[S->getBody()] = BodyCount;
827      Visit(S->getBody());
828      uint64_t BackedgeCount = CurrentCount;
829      BreakContinue BC = BreakContinueStack.pop_back_val();
830  
831      setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
832               BodyCount);
833      RecordNextStmtCount = true;
834    }
835  
VisitSwitchStmt__anon71781d390111::ComputeRegionCounts836    void VisitSwitchStmt(const SwitchStmt *S) {
837      RecordStmtCount(S);
838      if (S->getInit())
839        Visit(S->getInit());
840      Visit(S->getCond());
841      CurrentCount = 0;
842      BreakContinueStack.push_back(BreakContinue());
843      Visit(S->getBody());
844      // If the switch is inside a loop, add the continue counts.
845      BreakContinue BC = BreakContinueStack.pop_back_val();
846      if (!BreakContinueStack.empty())
847        BreakContinueStack.back().ContinueCount += BC.ContinueCount;
848      // Counter tracks the exit block of the switch.
849      setCount(PGO.getRegionCount(S));
850      RecordNextStmtCount = true;
851    }
852  
VisitSwitchCase__anon71781d390111::ComputeRegionCounts853    void VisitSwitchCase(const SwitchCase *S) {
854      RecordNextStmtCount = false;
855      // Counter for this particular case. This counts only jumps from the
856      // switch header and does not include fallthrough from the case before
857      // this one.
858      uint64_t CaseCount = PGO.getRegionCount(S);
859      setCount(CurrentCount + CaseCount);
860      // We need the count without fallthrough in the mapping, so it's more useful
861      // for branch probabilities.
862      CountMap[S] = CaseCount;
863      RecordNextStmtCount = true;
864      Visit(S->getSubStmt());
865    }
866  
VisitIfStmt__anon71781d390111::ComputeRegionCounts867    void VisitIfStmt(const IfStmt *S) {
868      RecordStmtCount(S);
869  
870      if (S->isConsteval()) {
871        const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
872        if (Stm)
873          Visit(Stm);
874        return;
875      }
876  
877      uint64_t ParentCount = CurrentCount;
878      if (S->getInit())
879        Visit(S->getInit());
880      Visit(S->getCond());
881  
882      // Counter tracks the "then" part of an if statement. The count for
883      // the "else" part, if it exists, will be calculated from this counter.
884      uint64_t ThenCount = setCount(PGO.getRegionCount(S));
885      CountMap[S->getThen()] = ThenCount;
886      Visit(S->getThen());
887      uint64_t OutCount = CurrentCount;
888  
889      uint64_t ElseCount = ParentCount - ThenCount;
890      if (S->getElse()) {
891        setCount(ElseCount);
892        CountMap[S->getElse()] = ElseCount;
893        Visit(S->getElse());
894        OutCount += CurrentCount;
895      } else
896        OutCount += ElseCount;
897      setCount(OutCount);
898      RecordNextStmtCount = true;
899    }
900  
VisitCXXTryStmt__anon71781d390111::ComputeRegionCounts901    void VisitCXXTryStmt(const CXXTryStmt *S) {
902      RecordStmtCount(S);
903      Visit(S->getTryBlock());
904      for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
905        Visit(S->getHandler(I));
906      // Counter tracks the continuation block of the try statement.
907      setCount(PGO.getRegionCount(S));
908      RecordNextStmtCount = true;
909    }
910  
VisitCXXCatchStmt__anon71781d390111::ComputeRegionCounts911    void VisitCXXCatchStmt(const CXXCatchStmt *S) {
912      RecordNextStmtCount = false;
913      // Counter tracks the catch statement's handler block.
914      uint64_t CatchCount = setCount(PGO.getRegionCount(S));
915      CountMap[S] = CatchCount;
916      Visit(S->getHandlerBlock());
917    }
918  
VisitAbstractConditionalOperator__anon71781d390111::ComputeRegionCounts919    void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
920      RecordStmtCount(E);
921      uint64_t ParentCount = CurrentCount;
922      Visit(E->getCond());
923  
924      // Counter tracks the "true" part of a conditional operator. The
925      // count in the "false" part will be calculated from this counter.
926      uint64_t TrueCount = setCount(PGO.getRegionCount(E));
927      CountMap[E->getTrueExpr()] = TrueCount;
928      Visit(E->getTrueExpr());
929      uint64_t OutCount = CurrentCount;
930  
931      uint64_t FalseCount = setCount(ParentCount - TrueCount);
932      CountMap[E->getFalseExpr()] = FalseCount;
933      Visit(E->getFalseExpr());
934      OutCount += CurrentCount;
935  
936      setCount(OutCount);
937      RecordNextStmtCount = true;
938    }
939  
VisitBinLAnd__anon71781d390111::ComputeRegionCounts940    void VisitBinLAnd(const BinaryOperator *E) {
941      RecordStmtCount(E);
942      uint64_t ParentCount = CurrentCount;
943      Visit(E->getLHS());
944      // Counter tracks the right hand side of a logical and operator.
945      uint64_t RHSCount = setCount(PGO.getRegionCount(E));
946      CountMap[E->getRHS()] = RHSCount;
947      Visit(E->getRHS());
948      setCount(ParentCount + RHSCount - CurrentCount);
949      RecordNextStmtCount = true;
950    }
951  
VisitBinLOr__anon71781d390111::ComputeRegionCounts952    void VisitBinLOr(const BinaryOperator *E) {
953      RecordStmtCount(E);
954      uint64_t ParentCount = CurrentCount;
955      Visit(E->getLHS());
956      // Counter tracks the right hand side of a logical or operator.
957      uint64_t RHSCount = setCount(PGO.getRegionCount(E));
958      CountMap[E->getRHS()] = RHSCount;
959      Visit(E->getRHS());
960      setCount(ParentCount + RHSCount - CurrentCount);
961      RecordNextStmtCount = true;
962    }
963  };
964  } // end anonymous namespace
965  
combine(HashType Type)966  void PGOHash::combine(HashType Type) {
967    // Check that we never combine 0 and only have six bits.
968    assert(Type && "Hash is invalid: unexpected type 0");
969    assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
970  
971    // Pass through MD5 if enough work has built up.
972    if (Count && Count % NumTypesPerWord == 0) {
973      using namespace llvm::support;
974      uint64_t Swapped =
975          endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
976      MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
977      Working = 0;
978    }
979  
980    // Accumulate the current type.
981    ++Count;
982    Working = Working << NumBitsPerType | Type;
983  }
984  
finalize()985  uint64_t PGOHash::finalize() {
986    // Use Working as the hash directly if we never used MD5.
987    if (Count <= NumTypesPerWord)
988      // No need to byte swap here, since none of the math was endian-dependent.
989      // This number will be byte-swapped as required on endianness transitions,
990      // so we will see the same value on the other side.
991      return Working;
992  
993    // Check for remaining work in Working.
994    if (Working) {
995      // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
996      // is buggy because it converts a uint64_t into an array of uint8_t.
997      if (HashVersion < PGO_HASH_V3) {
998        MD5.update({(uint8_t)Working});
999      } else {
1000        using namespace llvm::support;
1001        uint64_t Swapped =
1002            endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
1003        MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
1004      }
1005    }
1006  
1007    // Finalize the MD5 and return the hash.
1008    llvm::MD5::MD5Result Result;
1009    MD5.final(Result);
1010    return Result.low();
1011  }
1012  
assignRegionCounters(GlobalDecl GD,llvm::Function * Fn)1013  void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
1014    const Decl *D = GD.getDecl();
1015    if (!D->hasBody())
1016      return;
1017  
1018    // Skip CUDA/HIP kernel launch stub functions.
1019    if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
1020        D->hasAttr<CUDAGlobalAttr>())
1021      return;
1022  
1023    bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
1024    llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1025    if (!InstrumentRegions && !PGOReader)
1026      return;
1027    if (D->isImplicit())
1028      return;
1029    // Constructors and destructors may be represented by several functions in IR.
1030    // If so, instrument only base variant, others are implemented by delegation
1031    // to the base one, it would be counted twice otherwise.
1032    if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
1033      if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
1034        if (GD.getCtorType() != Ctor_Base &&
1035            CodeGenFunction::IsConstructorDelegationValid(CCD))
1036          return;
1037    }
1038    if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
1039      return;
1040  
1041    CGM.ClearUnusedCoverageMapping(D);
1042    if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
1043      return;
1044    if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
1045      return;
1046  
1047    SourceManager &SM = CGM.getContext().getSourceManager();
1048    if (!llvm::coverage::SystemHeadersCoverage &&
1049        SM.isInSystemHeader(D->getLocation()))
1050      return;
1051  
1052    setFuncName(Fn);
1053  
1054    mapRegionCounters(D);
1055    if (CGM.getCodeGenOpts().CoverageMapping)
1056      emitCounterRegionMapping(D);
1057    if (PGOReader) {
1058      loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
1059      computeRegionCounts(D);
1060      applyFunctionAttributes(PGOReader, Fn);
1061    }
1062  }
1063  
mapRegionCounters(const Decl * D)1064  void CodeGenPGO::mapRegionCounters(const Decl *D) {
1065    // Use the latest hash version when inserting instrumentation, but use the
1066    // version in the indexed profile if we're reading PGO data.
1067    PGOHashVersion HashVersion = PGO_HASH_LATEST;
1068    uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
1069    if (auto *PGOReader = CGM.getPGOReader()) {
1070      HashVersion = getPGOHashVersion(PGOReader, CGM);
1071      ProfileVersion = PGOReader->getVersion();
1072    }
1073  
1074    // If MC/DC is enabled, set the MaxConditions to a preset value. Otherwise,
1075    // set it to zero. This value impacts the number of conditions accepted in a
1076    // given boolean expression, which impacts the size of the bitmap used to
1077    // track test vector execution for that boolean expression.  Because the
1078    // bitmap scales exponentially (2^n) based on the number of conditions seen,
1079    // the maximum value is hard-coded at 6 conditions, which is more than enough
1080    // for most embedded applications. Setting a maximum value prevents the
1081    // bitmap footprint from growing too large without the user's knowledge. In
1082    // the future, this value could be adjusted with a command-line option.
1083    unsigned MCDCMaxConditions =
1084        (CGM.getCodeGenOpts().MCDCCoverage ? CGM.getCodeGenOpts().MCDCMaxConds
1085                                           : 0);
1086  
1087    RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
1088    RegionMCDCState.reset(new MCDC::State);
1089    MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap,
1090                             *RegionMCDCState, MCDCMaxConditions, CGM.getDiags());
1091    if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1092      Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
1093    else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1094      Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
1095    else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1096      Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
1097    else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1098      Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
1099    assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
1100    NumRegionCounters = Walker.NextCounter;
1101    FunctionHash = Walker.Hash.finalize();
1102  }
1103  
skipRegionMappingForDecl(const Decl * D)1104  bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
1105    if (!D->getBody())
1106      return true;
1107  
1108    // Skip host-only functions in the CUDA device compilation and device-only
1109    // functions in the host compilation. Just roughly filter them out based on
1110    // the function attributes. If there are effectively host-only or device-only
1111    // ones, their coverage mapping may still be generated.
1112    if (CGM.getLangOpts().CUDA &&
1113        ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
1114          !D->hasAttr<CUDAGlobalAttr>()) ||
1115         (!CGM.getLangOpts().CUDAIsDevice &&
1116          (D->hasAttr<CUDAGlobalAttr>() ||
1117           (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
1118      return true;
1119  
1120    // Don't map the functions in system headers.
1121    const auto &SM = CGM.getContext().getSourceManager();
1122    auto Loc = D->getBody()->getBeginLoc();
1123    return !llvm::coverage::SystemHeadersCoverage && SM.isInSystemHeader(Loc);
1124  }
1125  
emitCounterRegionMapping(const Decl * D)1126  void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
1127    if (skipRegionMappingForDecl(D))
1128      return;
1129  
1130    std::string CoverageMapping;
1131    llvm::raw_string_ostream OS(CoverageMapping);
1132    RegionMCDCState->BranchByStmt.clear();
1133    CoverageMappingGen MappingGen(
1134        *CGM.getCoverageMapping(), CGM.getContext().getSourceManager(),
1135        CGM.getLangOpts(), RegionCounterMap.get(), RegionMCDCState.get());
1136    MappingGen.emitCounterMapping(D, OS);
1137    OS.flush();
1138  
1139    if (CoverageMapping.empty())
1140      return;
1141  
1142    CGM.getCoverageMapping()->addFunctionMappingRecord(
1143        FuncNameVar, FuncName, FunctionHash, CoverageMapping);
1144  }
1145  
1146  void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)1147  CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
1148                                      llvm::GlobalValue::LinkageTypes Linkage) {
1149    if (skipRegionMappingForDecl(D))
1150      return;
1151  
1152    std::string CoverageMapping;
1153    llvm::raw_string_ostream OS(CoverageMapping);
1154    CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
1155                                  CGM.getContext().getSourceManager(),
1156                                  CGM.getLangOpts());
1157    MappingGen.emitEmptyMapping(D, OS);
1158    OS.flush();
1159  
1160    if (CoverageMapping.empty())
1161      return;
1162  
1163    setFuncName(Name, Linkage);
1164    CGM.getCoverageMapping()->addFunctionMappingRecord(
1165        FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
1166  }
1167  
computeRegionCounts(const Decl * D)1168  void CodeGenPGO::computeRegionCounts(const Decl *D) {
1169    StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
1170    ComputeRegionCounts Walker(*StmtCountMap, *this);
1171    if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1172      Walker.VisitFunctionDecl(FD);
1173    else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1174      Walker.VisitObjCMethodDecl(MD);
1175    else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1176      Walker.VisitBlockDecl(BD);
1177    else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1178      Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
1179  }
1180  
1181  void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)1182  CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
1183                                      llvm::Function *Fn) {
1184    if (!haveRegionCounts())
1185      return;
1186  
1187    uint64_t FunctionCount = getRegionCount(nullptr);
1188    Fn->setEntryCount(FunctionCount);
1189  }
1190  
emitCounterSetOrIncrement(CGBuilderTy & Builder,const Stmt * S,llvm::Value * StepV)1191  void CodeGenPGO::emitCounterSetOrIncrement(CGBuilderTy &Builder, const Stmt *S,
1192                                             llvm::Value *StepV) {
1193    if (!RegionCounterMap || !Builder.GetInsertBlock())
1194      return;
1195  
1196    unsigned Counter = (*RegionCounterMap)[S];
1197  
1198    llvm::Value *Args[] = {FuncNameVar,
1199                           Builder.getInt64(FunctionHash),
1200                           Builder.getInt32(NumRegionCounters),
1201                           Builder.getInt32(Counter), StepV};
1202  
1203    if (llvm::EnableSingleByteCoverage)
1204      Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_cover),
1205                         ArrayRef(Args, 4));
1206    else {
1207      if (!StepV)
1208        Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
1209                           ArrayRef(Args, 4));
1210      else
1211        Builder.CreateCall(
1212            CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step), Args);
1213    }
1214  }
1215  
canEmitMCDCCoverage(const CGBuilderTy & Builder)1216  bool CodeGenPGO::canEmitMCDCCoverage(const CGBuilderTy &Builder) {
1217    return (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1218            CGM.getCodeGenOpts().MCDCCoverage && Builder.GetInsertBlock());
1219  }
1220  
emitMCDCParameters(CGBuilderTy & Builder)1221  void CodeGenPGO::emitMCDCParameters(CGBuilderTy &Builder) {
1222    if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1223      return;
1224  
1225    auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1226  
1227    // Emit intrinsic representing MCDC bitmap parameters at function entry.
1228    // This is used by the instrumentation pass, but it isn't actually lowered to
1229    // anything.
1230    llvm::Value *Args[3] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1231                            Builder.getInt64(FunctionHash),
1232                            Builder.getInt32(RegionMCDCState->BitmapBits)};
1233    Builder.CreateCall(
1234        CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_parameters), Args);
1235  }
1236  
emitMCDCTestVectorBitmapUpdate(CGBuilderTy & Builder,const Expr * S,Address MCDCCondBitmapAddr,CodeGenFunction & CGF)1237  void CodeGenPGO::emitMCDCTestVectorBitmapUpdate(CGBuilderTy &Builder,
1238                                                  const Expr *S,
1239                                                  Address MCDCCondBitmapAddr,
1240                                                  CodeGenFunction &CGF) {
1241    if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1242      return;
1243  
1244    S = S->IgnoreParens();
1245  
1246    auto DecisionStateIter = RegionMCDCState->DecisionByStmt.find(S);
1247    if (DecisionStateIter == RegionMCDCState->DecisionByStmt.end())
1248      return;
1249  
1250    // Don't create tvbitmap_update if the record is allocated but excluded.
1251    // Or `bitmap |= (1 << 0)` would be wrongly executed to the next bitmap.
1252    if (DecisionStateIter->second.Indices.size() == 0)
1253      return;
1254  
1255    // Extract the offset of the global bitmap associated with this expression.
1256    unsigned MCDCTestVectorBitmapOffset = DecisionStateIter->second.BitmapIdx;
1257    auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1258  
1259    // Emit intrinsic responsible for updating the global bitmap corresponding to
1260    // a boolean expression. The index being set is based on the value loaded
1261    // from a pointer to a dedicated temporary value on the stack that is itself
1262    // updated via emitMCDCCondBitmapReset() and emitMCDCCondBitmapUpdate(). The
1263    // index represents an executed test vector.
1264    llvm::Value *Args[4] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1265                            Builder.getInt64(FunctionHash),
1266                            Builder.getInt32(MCDCTestVectorBitmapOffset),
1267                            MCDCCondBitmapAddr.emitRawPointer(CGF)};
1268    Builder.CreateCall(
1269        CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_tvbitmap_update), Args);
1270  }
1271  
emitMCDCCondBitmapReset(CGBuilderTy & Builder,const Expr * S,Address MCDCCondBitmapAddr)1272  void CodeGenPGO::emitMCDCCondBitmapReset(CGBuilderTy &Builder, const Expr *S,
1273                                           Address MCDCCondBitmapAddr) {
1274    if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1275      return;
1276  
1277    S = S->IgnoreParens();
1278  
1279    if (!RegionMCDCState->DecisionByStmt.contains(S))
1280      return;
1281  
1282    // Emit intrinsic that resets a dedicated temporary value on the stack to 0.
1283    Builder.CreateStore(Builder.getInt32(0), MCDCCondBitmapAddr);
1284  }
1285  
emitMCDCCondBitmapUpdate(CGBuilderTy & Builder,const Expr * S,Address MCDCCondBitmapAddr,llvm::Value * Val,CodeGenFunction & CGF)1286  void CodeGenPGO::emitMCDCCondBitmapUpdate(CGBuilderTy &Builder, const Expr *S,
1287                                            Address MCDCCondBitmapAddr,
1288                                            llvm::Value *Val,
1289                                            CodeGenFunction &CGF) {
1290    if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1291      return;
1292  
1293    // Even though, for simplicity, parentheses and unary logical-NOT operators
1294    // are considered part of their underlying condition for both MC/DC and
1295    // branch coverage, the condition IDs themselves are assigned and tracked
1296    // using the underlying condition itself.  This is done solely for
1297    // consistency since parentheses and logical-NOTs are ignored when checking
1298    // whether the condition is actually an instrumentable condition. This can
1299    // also make debugging a bit easier.
1300    S = CodeGenFunction::stripCond(S);
1301  
1302    auto BranchStateIter = RegionMCDCState->BranchByStmt.find(S);
1303    if (BranchStateIter == RegionMCDCState->BranchByStmt.end())
1304      return;
1305  
1306    // Extract the ID of the condition we are setting in the bitmap.
1307    const auto &Branch = BranchStateIter->second;
1308    assert(Branch.ID >= 0 && "Condition has no ID!");
1309    assert(Branch.DecisionStmt);
1310  
1311    // Cancel the emission if the Decision is erased after the allocation.
1312    const auto DecisionIter =
1313        RegionMCDCState->DecisionByStmt.find(Branch.DecisionStmt);
1314    if (DecisionIter == RegionMCDCState->DecisionByStmt.end())
1315      return;
1316  
1317    const auto &TVIdxs = DecisionIter->second.Indices[Branch.ID];
1318  
1319    auto *CurTV = Builder.CreateLoad(MCDCCondBitmapAddr,
1320                                     "mcdc." + Twine(Branch.ID + 1) + ".cur");
1321    auto *NewTV = Builder.CreateAdd(CurTV, Builder.getInt32(TVIdxs[true]));
1322    NewTV = Builder.CreateSelect(
1323        Val, NewTV, Builder.CreateAdd(CurTV, Builder.getInt32(TVIdxs[false])));
1324    Builder.CreateStore(NewTV, MCDCCondBitmapAddr);
1325  }
1326  
setValueProfilingFlag(llvm::Module & M)1327  void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
1328    if (CGM.getCodeGenOpts().hasProfileClangInstr())
1329      M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
1330                      uint32_t(EnableValueProfiling));
1331  }
1332  
setProfileVersion(llvm::Module & M)1333  void CodeGenPGO::setProfileVersion(llvm::Module &M) {
1334    if (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1335        llvm::EnableSingleByteCoverage) {
1336      const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
1337      llvm::Type *IntTy64 = llvm::Type::getInt64Ty(M.getContext());
1338      uint64_t ProfileVersion =
1339          (INSTR_PROF_RAW_VERSION | VARIANT_MASK_BYTE_COVERAGE);
1340  
1341      auto IRLevelVersionVariable = new llvm::GlobalVariable(
1342          M, IntTy64, true, llvm::GlobalValue::WeakAnyLinkage,
1343          llvm::Constant::getIntegerValue(IntTy64,
1344                                          llvm::APInt(64, ProfileVersion)),
1345          VarName);
1346  
1347      IRLevelVersionVariable->setVisibility(llvm::GlobalValue::HiddenVisibility);
1348      llvm::Triple TT(M.getTargetTriple());
1349      if (TT.supportsCOMDAT()) {
1350        IRLevelVersionVariable->setLinkage(llvm::GlobalValue::ExternalLinkage);
1351        IRLevelVersionVariable->setComdat(M.getOrInsertComdat(VarName));
1352      }
1353      IRLevelVersionVariable->setDSOLocal(true);
1354    }
1355  }
1356  
1357  // This method either inserts a call to the profile run-time during
1358  // instrumentation or puts profile data into metadata for PGO use.
valueProfile(CGBuilderTy & Builder,uint32_t ValueKind,llvm::Instruction * ValueSite,llvm::Value * ValuePtr)1359  void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
1360      llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
1361  
1362    if (!EnableValueProfiling)
1363      return;
1364  
1365    if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
1366      return;
1367  
1368    if (isa<llvm::Constant>(ValuePtr))
1369      return;
1370  
1371    bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
1372    if (InstrumentValueSites && RegionCounterMap) {
1373      auto BuilderInsertPoint = Builder.saveIP();
1374      Builder.SetInsertPoint(ValueSite);
1375      llvm::Value *Args[5] = {
1376          FuncNameVar,
1377          Builder.getInt64(FunctionHash),
1378          Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1379          Builder.getInt32(ValueKind),
1380          Builder.getInt32(NumValueSites[ValueKind]++)
1381      };
1382      Builder.CreateCall(
1383          CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1384      Builder.restoreIP(BuilderInsertPoint);
1385      return;
1386    }
1387  
1388    llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1389    if (PGOReader && haveRegionCounts()) {
1390      // We record the top most called three functions at each call site.
1391      // Profile metadata contains "VP" string identifying this metadata
1392      // as value profiling data, then a uint32_t value for the value profiling
1393      // kind, a uint64_t value for the total number of times the call is
1394      // executed, followed by the function hash and execution count (uint64_t)
1395      // pairs for each function.
1396      if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1397        return;
1398  
1399      llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1400                              (llvm::InstrProfValueKind)ValueKind,
1401                              NumValueSites[ValueKind]);
1402  
1403      NumValueSites[ValueKind]++;
1404    }
1405  }
1406  
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)1407  void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1408                                    bool IsInMainFile) {
1409    CGM.getPGOStats().addVisited(IsInMainFile);
1410    RegionCounts.clear();
1411    llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1412        PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1413    if (auto E = RecordExpected.takeError()) {
1414      auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E)));
1415      if (IPE == llvm::instrprof_error::unknown_function)
1416        CGM.getPGOStats().addMissing(IsInMainFile);
1417      else if (IPE == llvm::instrprof_error::hash_mismatch)
1418        CGM.getPGOStats().addMismatched(IsInMainFile);
1419      else if (IPE == llvm::instrprof_error::malformed)
1420        // TODO: Consider a more specific warning for this case.
1421        CGM.getPGOStats().addMismatched(IsInMainFile);
1422      return;
1423    }
1424    ProfRecord =
1425        std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1426    RegionCounts = ProfRecord->Counts;
1427  }
1428  
/// Compute the divisor used to bring 64-bit branch weights into 32-bit range.
///
/// Weights below UINT32_MAX need no scaling (divisor 1). Otherwise the
/// divisor is MaxWeight / UINT32_MAX + 1, chosen so that scaleBranchWeight()
/// of any weight not larger than \p MaxWeight stays within 32 bits.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1435  }
1436  
/// Scale a 64-bit branch weight down by \p Scale and add 1.
///
/// The +1 follows Laplace's Rule of Succession: estimating from count + 1 is
/// preferable to the raw count, and it keeps every weight non-zero.
///
/// \pre \p Scale came from calculateWeightScale() applied to a weight at
/// least as large as \p Weight, so the result fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Result = 1 + Weight / Scale;
  assert(Result <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Result);
}
1451  }
1452  
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount) const1453  llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1454                                                      uint64_t FalseCount) const {
1455    // Check for empty weights.
1456    if (!TrueCount && !FalseCount)
1457      return nullptr;
1458  
1459    // Calculate how to scale down to 32-bits.
1460    uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1461  
1462    llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1463    return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1464                                        scaleBranchWeight(FalseCount, Scale));
1465  }
1466  
1467  llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights) const1468  CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1469    // We need at least two elements to create meaningful weights.
1470    if (Weights.size() < 2)
1471      return nullptr;
1472  
1473    // Check for empty weights.
1474    uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1475    if (MaxWeight == 0)
1476      return nullptr;
1477  
1478    // Calculate how to scale down to 32-bits.
1479    uint64_t Scale = calculateWeightScale(MaxWeight);
1480  
1481    SmallVector<uint32_t, 16> ScaledWeights;
1482    ScaledWeights.reserve(Weights.size());
1483    for (uint64_t W : Weights)
1484      ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1485  
1486    llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1487    return MDHelper.createBranchWeights(ScaledWeights);
1488  }
1489  
1490  llvm::MDNode *
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount) const1491  CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1492                                               uint64_t LoopCount) const {
1493    if (!PGO.haveRegionCounts())
1494      return nullptr;
1495    std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1496    if (!CondCount || *CondCount == 0)
1497      return nullptr;
1498    return createProfileWeights(LoopCount,
1499                                std::max(*CondCount, LoopCount) - LoopCount);
1500  }
1501