xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CGDebugInfo.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/Support/CommandLine.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/MD5.h"
24 #include <optional>
25 
26 namespace llvm {
27 extern cl::opt<bool> EnableSingleByteCoverage;
28 } // namespace llvm
29 
// File-local boolean command-line option registered under the name
// "enable-value-profiling". It is hidden and initialised to false.
static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling",
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));
34 
35 using namespace clang;
36 using namespace CodeGen;
37 
38 void CodeGenPGO::setFuncName(StringRef Name,
39                              llvm::GlobalValue::LinkageTypes Linkage) {
40   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
41   FuncName = llvm::getPGOFuncName(
42       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
43       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
44 
45   // If we're generating a profile, create a variable for the name.
46   if (CGM.getCodeGenOpts().hasProfileClangInstr())
47     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
48 }
49 
50 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
51   setFuncName(Fn->getName(), Fn->getLinkage());
52   // Create PGOFuncName meta data.
53   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
54 }
55 
/// The version of the PGO hash algorithm.
///
/// The enumerators are ordered from oldest to newest; code in this file
/// compares them with relational operators (e.g. ">= PGO_HASH_V2"), so the
/// order must be kept.
enum PGOHashVersion : unsigned {
  // Original hash; only the first group of PGOHash::HashType values is used.
  PGO_HASH_V1,
  // Adds the second group of PGOHash::HashType values.
  PGO_HASH_V2,
  // Newest version. NOTE(review): how it differs from V2 is not visible in
  // this part of the file -- confirm against PGOHash::finalize().
  PGO_HASH_V3,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};
65 
66 namespace {
67 /// Stable hasher for PGO region counters.
68 ///
69 /// PGOHash produces a stable hash of a given function's control flow.
70 ///
71 /// Changing the output of this hash will invalidate all previously generated
72 /// profiles -- i.e., don't do it.
73 ///
74 /// \note  When this hash does eventually change (years?), we still need to
75 /// support old hashes.  We'll need to pull in the version number from the
76 /// profile data format and use the matching hash function.
77 class PGOHash {
78   uint64_t Working;
79   unsigned Count;
80   PGOHashVersion HashVersion;
81   llvm::MD5 MD5;
82 
83   static const int NumBitsPerType = 6;
84   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
85   static const unsigned TooBig = 1u << NumBitsPerType;
86 
87 public:
88   /// Hash values for AST nodes.
89   ///
90   /// Distinct values for AST nodes that have region counters attached.
91   ///
92   /// These values must be stable.  All new members must be added at the end,
93   /// and no members should be removed.  Changing the enumeration value for an
94   /// AST node will affect the hash of every function that contains that node.
95   enum HashType : unsigned char {
96     None = 0,
97     LabelStmt = 1,
98     WhileStmt,
99     DoStmt,
100     ForStmt,
101     CXXForRangeStmt,
102     ObjCForCollectionStmt,
103     SwitchStmt,
104     CaseStmt,
105     DefaultStmt,
106     IfStmt,
107     CXXTryStmt,
108     CXXCatchStmt,
109     ConditionalOperator,
110     BinaryOperatorLAnd,
111     BinaryOperatorLOr,
112     BinaryConditionalOperator,
113     // The preceding values are available with PGO_HASH_V1.
114 
115     EndOfScope,
116     IfThenBranch,
117     IfElseBranch,
118     GotoStmt,
119     IndirectGotoStmt,
120     BreakStmt,
121     ContinueStmt,
122     ReturnStmt,
123     ThrowExpr,
124     UnaryOperatorLNot,
125     BinaryOperatorLT,
126     BinaryOperatorGT,
127     BinaryOperatorLE,
128     BinaryOperatorGE,
129     BinaryOperatorEQ,
130     BinaryOperatorNE,
131     // The preceding values are available since PGO_HASH_V2.
132 
133     // Keep this last.  It's for the static assert that follows.
134     LastHashType
135   };
136   static_assert(LastHashType <= TooBig, "Too many types in HashType");
137 
138   PGOHash(PGOHashVersion HashVersion)
139       : Working(0), Count(0), HashVersion(HashVersion) {}
140   void combine(HashType Type);
141   uint64_t finalize();
142   PGOHashVersion getHashVersion() const { return HashVersion; }
143 };
// Out-of-class definitions of PGOHash's static constants. The initialisers
// are given on the in-class declarations, so none are repeated here.
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
147 
148 /// Get the PGO hash version used in the given indexed profile.
149 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
150                                         CodeGenModule &CGM) {
151   if (PGOReader->getVersion() <= 4)
152     return PGO_HASH_V1;
153   if (PGOReader->getVersion() <= 5)
154     return PGO_HASH_V2;
155   return PGO_HASH_V3;
156 }
157 
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// While walking a function it hands out consecutive counter numbers
/// (NextCounter), feeds the visited statements into the function hash (Hash)
/// and, when MC/DC is enabled (MCDCMaxCond != 0), registers the boolean
/// expressions that form MC/DC decisions in MCDCState.
///
/// The order in which counters are assigned and hash types are combined is
/// determined by the traversal order below, so statements in these methods
/// must not be reordered casually.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, CounterPair> &CounterMap;
  /// The state of MC/DC Coverage in this function.
  MCDC::State &MCDCState;
  /// Maximum number of supported MC/DC conditions in a boolean expression.
  /// Zero means MC/DC is disabled.
  unsigned MCDCMaxCond;
  /// The profile version.
  uint64_t ProfileVersion;
  /// Diagnostics Engine used to report warnings.
  DiagnosticsEngine &Diag;

  /// Create a visitor that starts numbering counters at zero and hashes with
  /// \p HashVersion. The referenced map, MC/DC state and diagnostics engine
  /// are stored by reference.
  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, CounterPair> &CounterMap,
                    MCDC::State &MCDCState, unsigned MCDCMaxCond,
                    DiagnosticsEngine &Diag)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        MCDCState(MCDCState), MCDCMaxCond(MCDCMaxCond),
        ProfileVersion(ProfileVersion), Diag(Diag) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  // Captured statements are not traversed in the parent context either.
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Assign the next counter to the body of a function-like declaration.
  /// Declarations of any other kind are left alone.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  ///
  /// A statement gets a counter exactly when its V1 hash type is not None,
  /// independently of the hash version used for the function hash.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The following stacks are used with dataTraverseStmtPre() and
  /// dataTraverseStmtPost() to track the depth of nested logical operators in a
  /// boolean expression in a function.  The ultimate purpose is to keep track
  /// of the number of leaf-level conditions in the boolean expression so that a
  /// profile bitmap can be allocated based on that number.
  ///
  /// The stacks are also used to find error cases and notify the user.  A
  /// standard logical operator nest for a boolean expression could be in a form
  /// similar to this: "x = a && b && c && (d || f)"
  ///
  /// NumCond is incremented in VisitBinaryOperator() and reset in
  /// dataTraverseStmtPre() whenever LogOpStack is empty.
  unsigned NumCond = 0;
  bool SplitNestedLogicalOp = false;
  SmallVector<const Stmt *, 16> NonLogOpStack;
  SmallVector<const BinaryOperator *, 16> LogOpStack;

  // Hook: dataTraverseStmtPre() is invoked prior to visiting an AST Stmt node.
  bool dataTraverseStmtPre(Stmt *S) {
    /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
    if (MCDCMaxCond == 0)
      return true;

    /// At the top of the logical operator nest, reset the number of conditions,
    /// also forget previously seen split nesting cases.
    if (LogOpStack.empty()) {
      NumCond = 0;
      SplitNestedLogicalOp = false;
    }

    if (const Expr *E = dyn_cast<Expr>(S)) {
      const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
      if (BinOp && BinOp->isLogicalOp()) {
        /// Check for "split-nested" logical operators. This happens when a new
        /// boolean expression logical-op nest is encountered within an existing
        /// boolean expression, separated by a non-logical operator.  For
        /// example, in "x = (a && b && c && foo(d && f))", the "d && f" case
        /// starts a new boolean expression that is separated from the other
        /// conditions by the operator foo(). Split-nested cases are not
        /// supported by MC/DC.
        SplitNestedLogicalOp = SplitNestedLogicalOp || !NonLogOpStack.empty();

        LogOpStack.push_back(BinOp);
        return true;
      }
    }

    /// Keep track of non-logical operators. These are OK as long as we don't
    /// encounter a new logical operator after seeing one.
    if (!LogOpStack.empty())
      NonLogOpStack.push_back(S);

    return true;
  }

  // Hook: dataTraverseStmtPost() is invoked by the AST visitor after visiting
  // an AST Stmt node.  MC/DC will use it to signal when the top of a
  // logical operation (boolean expression) nest is encountered.
  bool dataTraverseStmtPost(Stmt *S) {
    /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
    if (MCDCMaxCond == 0)
      return true;

    if (const Expr *E = dyn_cast<Expr>(S)) {
      const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
      if (BinOp && BinOp->isLogicalOp()) {
        // This mirrors the push done in dataTraverseStmtPre().
        assert(LogOpStack.back() == BinOp);
        LogOpStack.pop_back();

        /// At the top of logical operator nest:
        if (LogOpStack.empty()) {
          /// Was the "split-nested" logical operator case encountered?
          if (SplitNestedLogicalOp) {
            unsigned DiagID = Diag.getCustomDiagID(
                DiagnosticsEngine::Warning,
                "unsupported MC/DC boolean expression; "
                "contains an operation with a nested boolean expression. "
                "Expression will not be covered");
            Diag.Report(S->getBeginLoc(), DiagID);
            return true;
          }

          /// Was the maximum number of conditions encountered?
          if (NumCond > MCDCMaxCond) {
            unsigned DiagID = Diag.getCustomDiagID(
                DiagnosticsEngine::Warning,
                "unsupported MC/DC boolean expression; "
                "number of conditions (%0) exceeds max (%1). "
                "Expression will not be covered");
            Diag.Report(S->getBeginLoc(), DiagID) << NumCond << MCDCMaxCond;
            return true;
          }

          // Otherwise, allocate the Decision.
          MCDCState.DecisionByStmt[BinOp].BitmapIdx = 0;
        }
        return true;
      }
    }

    // Undo the push of a non-logical statement done in dataTraverseStmtPre().
    if (!LogOpStack.empty())
      NonLogOpStack.pop_back();

    return true;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only valid for ">= v7" of the profile
  /// version so that we facilitate backward compatibility. In addition, in
  /// order to use MC/DC, count the number of total LHS and RHS conditions.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (S->isLogicalOp()) {
      if (CodeGenFunction::isInstrumentedCondition(S->getLHS()))
        NumCond++;

      if (CodeGenFunction::isInstrumentedCondition(S->getRHS())) {
        if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
          CounterMap[S->getRHS()] = NextCounter++;

        NumCond++;
      }
    }
    return Base::VisitBinaryOperator(S);
  }

  /// In single byte coverage mode, give the true and the false expression of
  /// a conditional operator a counter each (when present).
  bool VisitConditionalOperator(ConditionalOperator *S) {
    if (llvm::EnableSingleByteCoverage && S->getTrueExpr())
      CounterMap[S->getTrueExpr()] = NextCounter++;
    if (llvm::EnableSingleByteCoverage && S->getFalseExpr())
      CounterMap[S->getFalseExpr()] = NextCounter++;
    return Base::VisitConditionalOperator(S);
  }

  /// Include \p S in the function hash.
  ///
  /// Counter assignment always follows the V1 hash type; the value combined
  /// into the hash follows the hash version in use.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  /// Traverse an if statement. For hash versions newer than V1 the hash
  /// additionally records which branch is being entered and where the
  /// statement ends.
  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // When single byte coverage mode is enabled, add a counter to then and
    // else.
    bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
    for (Stmt *CS : If->children()) {
      if (!CS || NoSingleByteCoverage)
        continue;
      if (CS == If->getThen())
        CounterMap[If->getThen()] = NextCounter++;
      else if (CS == If->getElse())
        CounterMap[If->getElse()] = NextCounter++;
    }

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);

    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

  /// Traverse a while loop; for hash versions newer than V1 the end of the
  /// loop is recorded in the hash.
  bool TraverseWhileStmt(WhileStmt *While) {
    // When single byte coverage mode is enabled, add a counter to condition and
    // body.
    bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
    for (Stmt *CS : While->children()) {
      if (!CS || NoSingleByteCoverage)
        continue;
      if (CS == While->getCond())
        CounterMap[While->getCond()] = NextCounter++;
      else if (CS == While->getBody())
        CounterMap[While->getBody()] = NextCounter++;
    }

    Base::TraverseWhileStmt(While);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Hash.combine(PGOHash::EndOfScope);
    return true;
  }

  /// Traverse a do loop; for hash versions newer than V1 the end of the loop
  /// is recorded in the hash.
  bool TraverseDoStmt(DoStmt *Do) {
    // When single byte coverage mode is enabled, add a counter to condition and
    // body.
    bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
    for (Stmt *CS : Do->children()) {
      if (!CS || NoSingleByteCoverage)
        continue;
      if (CS == Do->getCond())
        CounterMap[Do->getCond()] = NextCounter++;
      else if (CS == Do->getBody())
        CounterMap[Do->getBody()] = NextCounter++;
    }

    Base::TraverseDoStmt(Do);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Hash.combine(PGOHash::EndOfScope);
    return true;
  }

  /// Traverse a for loop; for hash versions newer than V1 the end of the loop
  /// is recorded in the hash.
  bool TraverseForStmt(ForStmt *For) {
    // When single byte coverage mode is enabled, add a counter to condition,
    // increment and body.
    bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
    for (Stmt *CS : For->children()) {
      if (!CS || NoSingleByteCoverage)
        continue;
      if (CS == For->getCond())
        CounterMap[For->getCond()] = NextCounter++;
      else if (CS == For->getInc())
        CounterMap[For->getInc()] = NextCounter++;
      else if (CS == For->getBody())
        CounterMap[For->getBody()] = NextCounter++;
    }

    Base::TraverseForStmt(For);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Hash.combine(PGOHash::EndOfScope);
    return true;
  }

  /// Traverse a range-based for loop; for hash versions newer than V1 the end
  /// of the loop is recorded in the hash.
  bool TraverseCXXForRangeStmt(CXXForRangeStmt *ForRange) {
    // When single byte coverage mode is enabled, add a counter to body.
    bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
    for (Stmt *CS : ForRange->children()) {
      if (!CS || NoSingleByteCoverage)
        continue;
      if (CS == ForRange->getBody())
        CounterMap[ForRange->getBody()] = NextCounter++;
    }

    Base::TraverseCXXForRangeStmt(ForRange);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// Returns PGOHash::None for statements that do not take part in the hash
  /// under that version.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    // Statement classes known to every hash version. Comparison operators are
    // the exception: they are only reported from PGO_HASH_V2 on.
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Statement classes that only take part in the hash from PGO_HASH_V2 on.
    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
576 
577 /// A StmtVisitor that propagates the raw counts through the AST and
578 /// records the count at statements where the value may change.
579 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
580   /// PGO state.
581   CodeGenPGO &PGO;
582 
583   /// A flag that is set when the current count should be recorded on the
584   /// next statement, such as at the exit of a loop.
585   bool RecordNextStmtCount;
586 
587   /// The count at the current location in the traversal.
588   uint64_t CurrentCount;
589 
590   /// The map of statements to count values.
591   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
592 
593   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
594   struct BreakContinue {
595     uint64_t BreakCount = 0;
596     uint64_t ContinueCount = 0;
597     BreakContinue() = default;
598   };
599   SmallVector<BreakContinue, 8> BreakContinueStack;
600 
601   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
602                       CodeGenPGO &PGO)
603       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
604 
605   void RecordStmtCount(const Stmt *S) {
606     if (RecordNextStmtCount) {
607       CountMap[S] = CurrentCount;
608       RecordNextStmtCount = false;
609     }
610   }
611 
612   /// Set and return the current count.
613   uint64_t setCount(uint64_t Count) {
614     CurrentCount = Count;
615     return Count;
616   }
617 
618   void VisitStmt(const Stmt *S) {
619     RecordStmtCount(S);
620     for (const Stmt *Child : S->children())
621       if (Child)
622         this->Visit(Child);
623   }
624 
625   void VisitFunctionDecl(const FunctionDecl *D) {
626     // Counter tracks entry to the function body.
627     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
628     CountMap[D->getBody()] = BodyCount;
629     Visit(D->getBody());
630   }
631 
632   // Skip lambda expressions. We visit these as FunctionDecls when we're
633   // generating them and aren't interested in the body when generating a
634   // parent context.
635   void VisitLambdaExpr(const LambdaExpr *LE) {}
636 
637   void VisitCapturedDecl(const CapturedDecl *D) {
638     // Counter tracks entry to the capture body.
639     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
640     CountMap[D->getBody()] = BodyCount;
641     Visit(D->getBody());
642   }
643 
644   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
645     // Counter tracks entry to the method body.
646     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
647     CountMap[D->getBody()] = BodyCount;
648     Visit(D->getBody());
649   }
650 
651   void VisitBlockDecl(const BlockDecl *D) {
652     // Counter tracks entry to the block body.
653     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
654     CountMap[D->getBody()] = BodyCount;
655     Visit(D->getBody());
656   }
657 
658   void VisitReturnStmt(const ReturnStmt *S) {
659     RecordStmtCount(S);
660     if (S->getRetValue())
661       Visit(S->getRetValue());
662     CurrentCount = 0;
663     RecordNextStmtCount = true;
664   }
665 
666   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
667     RecordStmtCount(E);
668     if (E->getSubExpr())
669       Visit(E->getSubExpr());
670     CurrentCount = 0;
671     RecordNextStmtCount = true;
672   }
673 
674   void VisitGotoStmt(const GotoStmt *S) {
675     RecordStmtCount(S);
676     CurrentCount = 0;
677     RecordNextStmtCount = true;
678   }
679 
680   void VisitLabelStmt(const LabelStmt *S) {
681     RecordNextStmtCount = false;
682     // Counter tracks the block following the label.
683     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
684     CountMap[S] = BlockCount;
685     Visit(S->getSubStmt());
686   }
687 
688   void VisitBreakStmt(const BreakStmt *S) {
689     RecordStmtCount(S);
690     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
691     BreakContinueStack.back().BreakCount += CurrentCount;
692     CurrentCount = 0;
693     RecordNextStmtCount = true;
694   }
695 
696   void VisitContinueStmt(const ContinueStmt *S) {
697     RecordStmtCount(S);
698     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
699     BreakContinueStack.back().ContinueCount += CurrentCount;
700     CurrentCount = 0;
701     RecordNextStmtCount = true;
702   }
703 
704   void VisitWhileStmt(const WhileStmt *S) {
705     RecordStmtCount(S);
706     uint64_t ParentCount = CurrentCount;
707 
708     BreakContinueStack.push_back(BreakContinue());
709     // Visit the body region first so the break/continue adjustments can be
710     // included when visiting the condition.
711     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
712     CountMap[S->getBody()] = CurrentCount;
713     Visit(S->getBody());
714     uint64_t BackedgeCount = CurrentCount;
715 
716     // ...then go back and propagate counts through the condition. The count
717     // at the start of the condition is the sum of the incoming edges,
718     // the backedge from the end of the loop body, and the edges from
719     // continue statements.
720     BreakContinue BC = BreakContinueStack.pop_back_val();
721     uint64_t CondCount =
722         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
723     CountMap[S->getCond()] = CondCount;
724     Visit(S->getCond());
725     setCount(BC.BreakCount + CondCount - BodyCount);
726     RecordNextStmtCount = true;
727   }
728 
729   void VisitDoStmt(const DoStmt *S) {
730     RecordStmtCount(S);
731     uint64_t LoopCount = PGO.getRegionCount(S);
732 
733     BreakContinueStack.push_back(BreakContinue());
734     // The count doesn't include the fallthrough from the parent scope. Add it.
735     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
736     CountMap[S->getBody()] = BodyCount;
737     Visit(S->getBody());
738     uint64_t BackedgeCount = CurrentCount;
739 
740     BreakContinue BC = BreakContinueStack.pop_back_val();
741     // The count at the start of the condition is equal to the count at the
742     // end of the body, plus any continues.
743     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
744     CountMap[S->getCond()] = CondCount;
745     Visit(S->getCond());
746     setCount(BC.BreakCount + CondCount - LoopCount);
747     RecordNextStmtCount = true;
748   }
749 
750   void VisitForStmt(const ForStmt *S) {
751     RecordStmtCount(S);
752     if (S->getInit())
753       Visit(S->getInit());
754 
755     uint64_t ParentCount = CurrentCount;
756 
757     BreakContinueStack.push_back(BreakContinue());
758     // Visit the body region first. (This is basically the same as a while
759     // loop; see further comments in VisitWhileStmt.)
760     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
761     CountMap[S->getBody()] = BodyCount;
762     Visit(S->getBody());
763     uint64_t BackedgeCount = CurrentCount;
764     BreakContinue BC = BreakContinueStack.pop_back_val();
765 
766     // The increment is essentially part of the body but it needs to include
767     // the count for all the continue statements.
768     if (S->getInc()) {
769       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
770       CountMap[S->getInc()] = IncCount;
771       Visit(S->getInc());
772     }
773 
774     // ...then go back and propagate counts through the condition.
775     uint64_t CondCount =
776         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
777     if (S->getCond()) {
778       CountMap[S->getCond()] = CondCount;
779       Visit(S->getCond());
780     }
781     setCount(BC.BreakCount + CondCount - BodyCount);
782     RecordNextStmtCount = true;
783   }
784 
785   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
786     RecordStmtCount(S);
787     if (S->getInit())
788       Visit(S->getInit());
789     Visit(S->getLoopVarStmt());
790     Visit(S->getRangeStmt());
791     Visit(S->getBeginStmt());
792     Visit(S->getEndStmt());
793 
794     uint64_t ParentCount = CurrentCount;
795     BreakContinueStack.push_back(BreakContinue());
796     // Visit the body region first. (This is basically the same as a while
797     // loop; see further comments in VisitWhileStmt.)
798     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
799     CountMap[S->getBody()] = BodyCount;
800     Visit(S->getBody());
801     uint64_t BackedgeCount = CurrentCount;
802     BreakContinue BC = BreakContinueStack.pop_back_val();
803 
804     // The increment is essentially part of the body but it needs to include
805     // the count for all the continue statements.
806     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
807     CountMap[S->getInc()] = IncCount;
808     Visit(S->getInc());
809 
810     // ...then go back and propagate counts through the condition.
811     uint64_t CondCount =
812         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
813     CountMap[S->getCond()] = CondCount;
814     Visit(S->getCond());
815     setCount(BC.BreakCount + CondCount - BodyCount);
816     RecordNextStmtCount = true;
817   }
818 
819   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
820     RecordStmtCount(S);
821     Visit(S->getElement());
822     uint64_t ParentCount = CurrentCount;
823     BreakContinueStack.push_back(BreakContinue());
824     // Counter tracks the body of the loop.
825     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
826     CountMap[S->getBody()] = BodyCount;
827     Visit(S->getBody());
828     uint64_t BackedgeCount = CurrentCount;
829     BreakContinue BC = BreakContinueStack.pop_back_val();
830 
831     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
832              BodyCount);
833     RecordNextStmtCount = true;
834   }
835 
836   void VisitSwitchStmt(const SwitchStmt *S) {
837     RecordStmtCount(S);
838     if (S->getInit())
839       Visit(S->getInit());
840     Visit(S->getCond());
841     CurrentCount = 0;
842     BreakContinueStack.push_back(BreakContinue());
843     Visit(S->getBody());
844     // If the switch is inside a loop, add the continue counts.
845     BreakContinue BC = BreakContinueStack.pop_back_val();
846     if (!BreakContinueStack.empty())
847       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
848     // Counter tracks the exit block of the switch.
849     setCount(PGO.getRegionCount(S));
850     RecordNextStmtCount = true;
851   }
852 
853   void VisitSwitchCase(const SwitchCase *S) {
854     RecordNextStmtCount = false;
855     // Counter for this particular case. This counts only jumps from the
856     // switch header and does not include fallthrough from the case before
857     // this one.
858     uint64_t CaseCount = PGO.getRegionCount(S);
859     setCount(CurrentCount + CaseCount);
860     // We need the count without fallthrough in the mapping, so it's more useful
861     // for branch probabilities.
862     CountMap[S] = CaseCount;
863     RecordNextStmtCount = true;
864     Visit(S->getSubStmt());
865   }
866 
867   void VisitIfStmt(const IfStmt *S) {
868     RecordStmtCount(S);
869 
870     if (S->isConsteval()) {
871       const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
872       if (Stm)
873         Visit(Stm);
874       return;
875     }
876 
877     uint64_t ParentCount = CurrentCount;
878     if (S->getInit())
879       Visit(S->getInit());
880     Visit(S->getCond());
881 
882     // Counter tracks the "then" part of an if statement. The count for
883     // the "else" part, if it exists, will be calculated from this counter.
884     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
885     CountMap[S->getThen()] = ThenCount;
886     Visit(S->getThen());
887     uint64_t OutCount = CurrentCount;
888 
889     uint64_t ElseCount = ParentCount - ThenCount;
890     if (S->getElse()) {
891       setCount(ElseCount);
892       CountMap[S->getElse()] = ElseCount;
893       Visit(S->getElse());
894       OutCount += CurrentCount;
895     } else
896       OutCount += ElseCount;
897     setCount(OutCount);
898     RecordNextStmtCount = true;
899   }
900 
901   void VisitCXXTryStmt(const CXXTryStmt *S) {
902     RecordStmtCount(S);
903     Visit(S->getTryBlock());
904     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
905       Visit(S->getHandler(I));
906     // Counter tracks the continuation block of the try statement.
907     setCount(PGO.getRegionCount(S));
908     RecordNextStmtCount = true;
909   }
910 
911   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
912     RecordNextStmtCount = false;
913     // Counter tracks the catch statement's handler block.
914     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
915     CountMap[S] = CatchCount;
916     Visit(S->getHandlerBlock());
917   }
918 
919   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
920     RecordStmtCount(E);
921     uint64_t ParentCount = CurrentCount;
922     Visit(E->getCond());
923 
924     // Counter tracks the "true" part of a conditional operator. The
925     // count in the "false" part will be calculated from this counter.
926     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
927     CountMap[E->getTrueExpr()] = TrueCount;
928     Visit(E->getTrueExpr());
929     uint64_t OutCount = CurrentCount;
930 
931     uint64_t FalseCount = setCount(ParentCount - TrueCount);
932     CountMap[E->getFalseExpr()] = FalseCount;
933     Visit(E->getFalseExpr());
934     OutCount += CurrentCount;
935 
936     setCount(OutCount);
937     RecordNextStmtCount = true;
938   }
939 
940   void VisitBinLAnd(const BinaryOperator *E) {
941     RecordStmtCount(E);
942     uint64_t ParentCount = CurrentCount;
943     Visit(E->getLHS());
944     // Counter tracks the right hand side of a logical and operator.
945     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
946     CountMap[E->getRHS()] = RHSCount;
947     Visit(E->getRHS());
948     setCount(ParentCount + RHSCount - CurrentCount);
949     RecordNextStmtCount = true;
950   }
951 
952   void VisitBinLOr(const BinaryOperator *E) {
953     RecordStmtCount(E);
954     uint64_t ParentCount = CurrentCount;
955     Visit(E->getLHS());
956     // Counter tracks the right hand side of a logical or operator.
957     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
958     CountMap[E->getRHS()] = RHSCount;
959     Visit(E->getRHS());
960     setCount(ParentCount + RHSCount - CurrentCount);
961     RecordNextStmtCount = true;
962   }
963 };
964 } // end anonymous namespace
965 
966 void PGOHash::combine(HashType Type) {
967   // Check that we never combine 0 and only have six bits.
968   assert(Type && "Hash is invalid: unexpected type 0");
969   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
970 
971   // Pass through MD5 if enough work has built up.
972   if (Count && Count % NumTypesPerWord == 0) {
973     using namespace llvm::support;
974     uint64_t Swapped =
975         endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
976     MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
977     Working = 0;
978   }
979 
980   // Accumulate the current type.
981   ++Count;
982   Working = Working << NumBitsPerType | Type;
983 }
984 
985 uint64_t PGOHash::finalize() {
986   // Use Working as the hash directly if we never used MD5.
987   if (Count <= NumTypesPerWord)
988     // No need to byte swap here, since none of the math was endian-dependent.
989     // This number will be byte-swapped as required on endianness transitions,
990     // so we will see the same value on the other side.
991     return Working;
992 
993   // Check for remaining work in Working.
994   if (Working) {
995     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
996     // is buggy because it converts a uint64_t into an array of uint8_t.
997     if (HashVersion < PGO_HASH_V3) {
998       MD5.update({(uint8_t)Working});
999     } else {
1000       using namespace llvm::support;
1001       uint64_t Swapped =
1002           endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
1003       MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
1004     }
1005   }
1006 
1007   // Finalize the MD5 and return the hash.
1008   llvm::MD5::MD5Result Result;
1009   MD5.final(Result);
1010   return Result.low();
1011 }
1012 
1013 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
1014   const Decl *D = GD.getDecl();
1015   if (!D->hasBody())
1016     return;
1017 
1018   // Skip CUDA/HIP kernel launch stub functions.
1019   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
1020       D->hasAttr<CUDAGlobalAttr>())
1021     return;
1022 
1023   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
1024   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1025   if (!InstrumentRegions && !PGOReader)
1026     return;
1027   if (D->isImplicit())
1028     return;
1029   // Constructors and destructors may be represented by several functions in IR.
1030   // If so, instrument only base variant, others are implemented by delegation
1031   // to the base one, it would be counted twice otherwise.
1032   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
1033     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
1034       if (GD.getCtorType() != Ctor_Base &&
1035           CodeGenFunction::IsConstructorDelegationValid(CCD))
1036         return;
1037   }
1038   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
1039     return;
1040 
1041   CGM.ClearUnusedCoverageMapping(D);
1042   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
1043     return;
1044   if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
1045     return;
1046 
1047   SourceManager &SM = CGM.getContext().getSourceManager();
1048   if (!llvm::coverage::SystemHeadersCoverage &&
1049       SM.isInSystemHeader(D->getLocation()))
1050     return;
1051 
1052   setFuncName(Fn);
1053 
1054   mapRegionCounters(D);
1055   if (CGM.getCodeGenOpts().CoverageMapping)
1056     emitCounterRegionMapping(D);
1057   if (PGOReader) {
1058     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
1059     computeRegionCounts(D);
1060     applyFunctionAttributes(PGOReader, Fn);
1061   }
1062 }
1063 
1064 void CodeGenPGO::mapRegionCounters(const Decl *D) {
1065   // Use the latest hash version when inserting instrumentation, but use the
1066   // version in the indexed profile if we're reading PGO data.
1067   PGOHashVersion HashVersion = PGO_HASH_LATEST;
1068   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
1069   if (auto *PGOReader = CGM.getPGOReader()) {
1070     HashVersion = getPGOHashVersion(PGOReader, CGM);
1071     ProfileVersion = PGOReader->getVersion();
1072   }
1073 
1074   // If MC/DC is enabled, set the MaxConditions to a preset value. Otherwise,
1075   // set it to zero. This value impacts the number of conditions accepted in a
1076   // given boolean expression, which impacts the size of the bitmap used to
1077   // track test vector execution for that boolean expression.  Because the
1078   // bitmap scales exponentially (2^n) based on the number of conditions seen,
1079   // the maximum value is hard-coded at 6 conditions, which is more than enough
1080   // for most embedded applications. Setting a maximum value prevents the
1081   // bitmap footprint from growing too large without the user's knowledge. In
1082   // the future, this value could be adjusted with a command-line option.
1083   unsigned MCDCMaxConditions =
1084       (CGM.getCodeGenOpts().MCDCCoverage ? CGM.getCodeGenOpts().MCDCMaxConds
1085                                          : 0);
1086 
1087   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, CounterPair>);
1088   RegionMCDCState.reset(new MCDC::State);
1089   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap,
1090                            *RegionMCDCState, MCDCMaxConditions, CGM.getDiags());
1091   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1092     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
1093   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1094     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
1095   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1096     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
1097   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1098     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
1099   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
1100   NumRegionCounters = Walker.NextCounter;
1101   FunctionHash = Walker.Hash.finalize();
1102 }
1103 
1104 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
1105   if (!D->getBody())
1106     return true;
1107 
1108   // Skip host-only functions in the CUDA device compilation and device-only
1109   // functions in the host compilation. Just roughly filter them out based on
1110   // the function attributes. If there are effectively host-only or device-only
1111   // ones, their coverage mapping may still be generated.
1112   if (CGM.getLangOpts().CUDA &&
1113       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
1114         !D->hasAttr<CUDAGlobalAttr>()) ||
1115        (!CGM.getLangOpts().CUDAIsDevice &&
1116         (D->hasAttr<CUDAGlobalAttr>() ||
1117          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
1118     return true;
1119 
1120   // Don't map the functions in system headers.
1121   const auto &SM = CGM.getContext().getSourceManager();
1122   auto Loc = D->getBody()->getBeginLoc();
1123   return !llvm::coverage::SystemHeadersCoverage && SM.isInSystemHeader(Loc);
1124 }
1125 
1126 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
1127   if (skipRegionMappingForDecl(D))
1128     return;
1129 
1130   std::string CoverageMapping;
1131   llvm::raw_string_ostream OS(CoverageMapping);
1132   RegionMCDCState->BranchByStmt.clear();
1133   CoverageMappingGen MappingGen(
1134       *CGM.getCoverageMapping(), CGM.getContext().getSourceManager(),
1135       CGM.getLangOpts(), RegionCounterMap.get(), RegionMCDCState.get());
1136   MappingGen.emitCounterMapping(D, OS);
1137 
1138   if (CoverageMapping.empty())
1139     return;
1140 
1141   CGM.getCoverageMapping()->addFunctionMappingRecord(
1142       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
1143 }
1144 
1145 void
1146 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
1147                                     llvm::GlobalValue::LinkageTypes Linkage) {
1148   if (skipRegionMappingForDecl(D))
1149     return;
1150 
1151   std::string CoverageMapping;
1152   llvm::raw_string_ostream OS(CoverageMapping);
1153   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
1154                                 CGM.getContext().getSourceManager(),
1155                                 CGM.getLangOpts());
1156   MappingGen.emitEmptyMapping(D, OS);
1157 
1158   if (CoverageMapping.empty())
1159     return;
1160 
1161   setFuncName(Name, Linkage);
1162   CGM.getCoverageMapping()->addFunctionMappingRecord(
1163       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
1164 }
1165 
1166 void CodeGenPGO::computeRegionCounts(const Decl *D) {
1167   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
1168   ComputeRegionCounts Walker(*StmtCountMap, *this);
1169   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1170     Walker.VisitFunctionDecl(FD);
1171   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1172     Walker.VisitObjCMethodDecl(MD);
1173   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1174     Walker.VisitBlockDecl(BD);
1175   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1176     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
1177 }
1178 
1179 void
1180 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
1181                                     llvm::Function *Fn) {
1182   if (!haveRegionCounts())
1183     return;
1184 
1185   uint64_t FunctionCount = getRegionCount(nullptr);
1186   Fn->setEntryCount(FunctionCount);
1187 }
1188 
1189 std::pair<bool, bool> CodeGenPGO::getIsCounterPair(const Stmt *S) const {
1190   if (!RegionCounterMap)
1191     return {false, false};
1192 
1193   auto I = RegionCounterMap->find(S);
1194   if (I == RegionCounterMap->end())
1195     return {false, false};
1196 
1197   return {I->second.Executed.hasValue(), I->second.Skipped.hasValue()};
1198 }
1199 
1200 void CodeGenPGO::emitCounterSetOrIncrement(CGBuilderTy &Builder, const Stmt *S,
1201                                            llvm::Value *StepV) {
1202   if (!RegionCounterMap || !Builder.GetInsertBlock())
1203     return;
1204 
1205   unsigned Counter = (*RegionCounterMap)[S].Executed;
1206 
1207   // Make sure that pointer to global is passed in with zero addrspace
1208   // This is relevant during GPU profiling
1209   auto *NormalizedFuncNameVarPtr =
1210       llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
1211           FuncNameVar, llvm::PointerType::get(CGM.getLLVMContext(), 0));
1212 
1213   llvm::Value *Args[] = {
1214       NormalizedFuncNameVarPtr, Builder.getInt64(FunctionHash),
1215       Builder.getInt32(NumRegionCounters), Builder.getInt32(Counter), StepV};
1216 
1217   if (llvm::EnableSingleByteCoverage)
1218     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_cover),
1219                        ArrayRef(Args, 4));
1220   else if (!StepV)
1221     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
1222                        ArrayRef(Args, 4));
1223   else
1224     Builder.CreateCall(
1225         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step), Args);
1226 }
1227 
1228 bool CodeGenPGO::canEmitMCDCCoverage(const CGBuilderTy &Builder) {
1229   return (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1230           CGM.getCodeGenOpts().MCDCCoverage && Builder.GetInsertBlock());
1231 }
1232 
1233 void CodeGenPGO::emitMCDCParameters(CGBuilderTy &Builder) {
1234   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1235     return;
1236 
1237   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1238 
1239   // Emit intrinsic representing MCDC bitmap parameters at function entry.
1240   // This is used by the instrumentation pass, but it isn't actually lowered to
1241   // anything.
1242   llvm::Value *Args[3] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1243                           Builder.getInt64(FunctionHash),
1244                           Builder.getInt32(RegionMCDCState->BitmapBits)};
1245   Builder.CreateCall(
1246       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_parameters), Args);
1247 }
1248 
1249 void CodeGenPGO::emitMCDCTestVectorBitmapUpdate(CGBuilderTy &Builder,
1250                                                 const Expr *S,
1251                                                 Address MCDCCondBitmapAddr,
1252                                                 CodeGenFunction &CGF) {
1253   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1254     return;
1255 
1256   S = S->IgnoreParens();
1257 
1258   auto DecisionStateIter = RegionMCDCState->DecisionByStmt.find(S);
1259   if (DecisionStateIter == RegionMCDCState->DecisionByStmt.end())
1260     return;
1261 
1262   // Don't create tvbitmap_update if the record is allocated but excluded.
1263   // Or `bitmap |= (1 << 0)` would be wrongly executed to the next bitmap.
1264   if (DecisionStateIter->second.Indices.size() == 0)
1265     return;
1266 
1267   // Extract the offset of the global bitmap associated with this expression.
1268   unsigned MCDCTestVectorBitmapOffset = DecisionStateIter->second.BitmapIdx;
1269   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1270 
1271   // Emit intrinsic responsible for updating the global bitmap corresponding to
1272   // a boolean expression. The index being set is based on the value loaded
1273   // from a pointer to a dedicated temporary value on the stack that is itself
1274   // updated via emitMCDCCondBitmapReset() and emitMCDCCondBitmapUpdate(). The
1275   // index represents an executed test vector.
1276   llvm::Value *Args[4] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1277                           Builder.getInt64(FunctionHash),
1278                           Builder.getInt32(MCDCTestVectorBitmapOffset),
1279                           MCDCCondBitmapAddr.emitRawPointer(CGF)};
1280   Builder.CreateCall(
1281       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_tvbitmap_update), Args);
1282 }
1283 
1284 void CodeGenPGO::emitMCDCCondBitmapReset(CGBuilderTy &Builder, const Expr *S,
1285                                          Address MCDCCondBitmapAddr) {
1286   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1287     return;
1288 
1289   S = S->IgnoreParens();
1290 
1291   if (!RegionMCDCState->DecisionByStmt.contains(S))
1292     return;
1293 
1294   // Emit intrinsic that resets a dedicated temporary value on the stack to 0.
1295   Builder.CreateStore(Builder.getInt32(0), MCDCCondBitmapAddr);
1296 }
1297 
1298 void CodeGenPGO::emitMCDCCondBitmapUpdate(CGBuilderTy &Builder, const Expr *S,
1299                                           Address MCDCCondBitmapAddr,
1300                                           llvm::Value *Val,
1301                                           CodeGenFunction &CGF) {
1302   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1303     return;
1304 
1305   // Even though, for simplicity, parentheses and unary logical-NOT operators
1306   // are considered part of their underlying condition for both MC/DC and
1307   // branch coverage, the condition IDs themselves are assigned and tracked
1308   // using the underlying condition itself.  This is done solely for
1309   // consistency since parentheses and logical-NOTs are ignored when checking
1310   // whether the condition is actually an instrumentable condition. This can
1311   // also make debugging a bit easier.
1312   S = CodeGenFunction::stripCond(S);
1313 
1314   auto BranchStateIter = RegionMCDCState->BranchByStmt.find(S);
1315   if (BranchStateIter == RegionMCDCState->BranchByStmt.end())
1316     return;
1317 
1318   // Extract the ID of the condition we are setting in the bitmap.
1319   const auto &Branch = BranchStateIter->second;
1320   assert(Branch.ID >= 0 && "Condition has no ID!");
1321   assert(Branch.DecisionStmt);
1322 
1323   // Cancel the emission if the Decision is erased after the allocation.
1324   const auto DecisionIter =
1325       RegionMCDCState->DecisionByStmt.find(Branch.DecisionStmt);
1326   if (DecisionIter == RegionMCDCState->DecisionByStmt.end())
1327     return;
1328 
1329   const auto &TVIdxs = DecisionIter->second.Indices[Branch.ID];
1330 
1331   auto *CurTV = Builder.CreateLoad(MCDCCondBitmapAddr,
1332                                    "mcdc." + Twine(Branch.ID + 1) + ".cur");
1333   auto *NewTV = Builder.CreateAdd(CurTV, Builder.getInt32(TVIdxs[true]));
1334   NewTV = Builder.CreateSelect(
1335       Val, NewTV, Builder.CreateAdd(CurTV, Builder.getInt32(TVIdxs[false])));
1336   Builder.CreateStore(NewTV, MCDCCondBitmapAddr);
1337 }
1338 
1339 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
1340   if (CGM.getCodeGenOpts().hasProfileClangInstr())
1341     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
1342                     uint32_t(EnableValueProfiling));
1343 }
1344 
1345 void CodeGenPGO::setProfileVersion(llvm::Module &M) {
1346   if (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1347       llvm::EnableSingleByteCoverage) {
1348     const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
1349     llvm::Type *IntTy64 = llvm::Type::getInt64Ty(M.getContext());
1350     uint64_t ProfileVersion =
1351         (INSTR_PROF_RAW_VERSION | VARIANT_MASK_BYTE_COVERAGE);
1352 
1353     auto IRLevelVersionVariable = new llvm::GlobalVariable(
1354         M, IntTy64, true, llvm::GlobalValue::WeakAnyLinkage,
1355         llvm::Constant::getIntegerValue(IntTy64,
1356                                         llvm::APInt(64, ProfileVersion)),
1357         VarName);
1358 
1359     IRLevelVersionVariable->setVisibility(llvm::GlobalValue::HiddenVisibility);
1360     llvm::Triple TT(M.getTargetTriple());
1361     if (TT.isGPU())
1362       IRLevelVersionVariable->setVisibility(
1363           llvm::GlobalValue::ProtectedVisibility);
1364     if (TT.supportsCOMDAT()) {
1365       IRLevelVersionVariable->setLinkage(llvm::GlobalValue::ExternalLinkage);
1366       IRLevelVersionVariable->setComdat(M.getOrInsertComdat(VarName));
1367     }
1368     IRLevelVersionVariable->setDSOLocal(true);
1369   }
1370 }
1371 
1372 // This method either inserts a call to the profile run-time during
1373 // instrumentation or puts profile data into metadata for PGO use.
1374 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
1375     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
1376 
1377   if (!EnableValueProfiling)
1378     return;
1379 
1380   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
1381     return;
1382 
1383   if (isa<llvm::Constant>(ValuePtr))
1384     return;
1385 
1386   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
1387   if (InstrumentValueSites && RegionCounterMap) {
1388     auto BuilderInsertPoint = Builder.saveIP();
1389     Builder.SetInsertPoint(ValueSite);
1390     llvm::Value *Args[5] = {
1391         FuncNameVar,
1392         Builder.getInt64(FunctionHash),
1393         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1394         Builder.getInt32(ValueKind),
1395         Builder.getInt32(NumValueSites[ValueKind]++)
1396     };
1397     Builder.CreateCall(
1398         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1399     Builder.restoreIP(BuilderInsertPoint);
1400     return;
1401   }
1402 
1403   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1404   if (PGOReader && haveRegionCounts()) {
1405     // We record the top most called three functions at each call site.
1406     // Profile metadata contains "VP" string identifying this metadata
1407     // as value profiling data, then a uint32_t value for the value profiling
1408     // kind, a uint64_t value for the total number of times the call is
1409     // executed, followed by the function hash and execution count (uint64_t)
1410     // pairs for each function.
1411     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1412       return;
1413 
1414     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1415                             (llvm::InstrProfValueKind)ValueKind,
1416                             NumValueSites[ValueKind]);
1417 
1418     NumValueSites[ValueKind]++;
1419   }
1420 }
1421 
1422 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1423                                   bool IsInMainFile) {
1424   CGM.getPGOStats().addVisited(IsInMainFile);
1425   RegionCounts.clear();
1426   auto RecordExpected = PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1427   if (auto E = RecordExpected.takeError()) {
1428     auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E)));
1429     if (IPE == llvm::instrprof_error::unknown_function)
1430       CGM.getPGOStats().addMissing(IsInMainFile);
1431     else if (IPE == llvm::instrprof_error::hash_mismatch)
1432       CGM.getPGOStats().addMismatched(IsInMainFile);
1433     else if (IPE == llvm::instrprof_error::malformed)
1434       // TODO: Consider a more specific warning for this case.
1435       CGM.getPGOStats().addMismatched(IsInMainFile);
1436     return;
1437   }
1438   ProfRecord =
1439       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1440   RegionCounts = ProfRecord->Counts;
1441 }
1442 
/// Compute the divisor used to scale weights down.
///
/// Given the largest weight, the result is a divisor that brings every
/// weight to strictly less than UINT32_MAX. Weights that already are below
/// UINT32_MAX need no scaling, so the divisor is 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1450 
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight (i.e. from the maximum of the weights being scaled).
/// Only then is the scaled value guaranteed to fit into 32 bits.
///
/// \returns \c Weight / \c Scale + 1.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The assertion above makes this narrowing lossless.
  return static_cast<uint32_t>(Scaled);
}
1466 
1467 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1468                                                     uint64_t FalseCount) const {
1469   // Check for empty weights.
1470   if (!TrueCount && !FalseCount)
1471     return nullptr;
1472 
1473   // Calculate how to scale down to 32-bits.
1474   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1475 
1476   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1477   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1478                                       scaleBranchWeight(FalseCount, Scale));
1479 }
1480 
1481 llvm::MDNode *
1482 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1483   // We need at least two elements to create meaningful weights.
1484   if (Weights.size() < 2)
1485     return nullptr;
1486 
1487   // Check for empty weights.
1488   uint64_t MaxWeight = *llvm::max_element(Weights);
1489   if (MaxWeight == 0)
1490     return nullptr;
1491 
1492   // Calculate how to scale down to 32-bits.
1493   uint64_t Scale = calculateWeightScale(MaxWeight);
1494 
1495   SmallVector<uint32_t, 16> ScaledWeights;
1496   ScaledWeights.reserve(Weights.size());
1497   for (uint64_t W : Weights)
1498     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1499 
1500   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1501   return MDHelper.createBranchWeights(ScaledWeights);
1502 }
1503 
1504 llvm::MDNode *
1505 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1506                                              uint64_t LoopCount) const {
1507   if (!PGO->haveRegionCounts())
1508     return nullptr;
1509   std::optional<uint64_t> CondCount = PGO->getStmtCount(Cond);
1510   if (!CondCount || *CondCount == 0)
1511     return nullptr;
1512   return createProfileWeights(LoopCount,
1513                               std::max(*CondCount, LoopCount) - LoopCount);
1514 }
1515 
1516 void CodeGenFunction::incrementProfileCounter(const Stmt *S,
1517                                               llvm::Value *StepV) {
1518   if (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1519       !CurFn->hasFnAttribute(llvm::Attribute::NoProfile) &&
1520       !CurFn->hasFnAttribute(llvm::Attribute::SkipProfile)) {
1521     auto AL = ApplyDebugLocation::CreateArtificial(*this);
1522     PGO->emitCounterSetOrIncrement(Builder, S, StepV);
1523   }
1524   PGO->setCurrentStmt(S);
1525 }
1526 
/// Forward to CodeGenPGO::getIsCounterPair for \p S and return its result
/// unchanged.
std::pair<bool, bool> CodeGenFunction::getIsCounterPair(const Stmt *S) const {
  return PGO->getIsCounterPair(S);
}
/// Forward \p Skipped and \p S to PGO's markStmtAsUsed.
void CodeGenFunction::markStmtAsUsed(bool Skipped, const Stmt *S) {
  PGO->markStmtAsUsed(Skipped, S);
}
/// Forward \p S to PGO's markStmtMaybeUsed.
void CodeGenFunction::markStmtMaybeUsed(const Stmt *S) {
  PGO->markStmtMaybeUsed(S);
}
1536 
1537 void CodeGenFunction::maybeCreateMCDCCondBitmap() {
1538   if (isMCDCCoverageEnabled()) {
1539     PGO->emitMCDCParameters(Builder);
1540     MCDCCondBitmapAddr = CreateIRTemp(getContext().UnsignedIntTy, "mcdc.addr");
1541   }
1542 }
1543 void CodeGenFunction::maybeResetMCDCCondBitmap(const Expr *E) {
1544   if (isMCDCCoverageEnabled() && isBinaryLogicalOp(E)) {
1545     PGO->emitMCDCCondBitmapReset(Builder, E, MCDCCondBitmapAddr);
1546     PGO->setCurrentStmt(E);
1547   }
1548 }
1549 void CodeGenFunction::maybeUpdateMCDCTestVectorBitmap(const Expr *E) {
1550   if (isMCDCCoverageEnabled() && isBinaryLogicalOp(E)) {
1551     PGO->emitMCDCTestVectorBitmapUpdate(Builder, E, MCDCCondBitmapAddr, *this);
1552     PGO->setCurrentStmt(E);
1553   }
1554 }
1555 
1556 void CodeGenFunction::maybeUpdateMCDCCondBitmap(const Expr *E,
1557                                                 llvm::Value *Val) {
1558   if (isMCDCCoverageEnabled()) {
1559     PGO->emitMCDCCondBitmapUpdate(Builder, E, MCDCCondBitmapAddr, Val, *this);
1560     PGO->setCurrentStmt(E);
1561   }
1562 }
1563 
1564 uint64_t CodeGenFunction::getProfileCount(const Stmt *S) {
1565   return PGO->getStmtCount(S).value_or(0);
1566 }
1567 
/// Set the profiler's current count by forwarding \p Count to PGO's
/// setCurrentRegionCount.
void CodeGenFunction::setCurrentProfileCount(uint64_t Count) {
  PGO->setCurrentRegionCount(Count);
}
1572 
/// Get the profiler's current count. This is generally the count for the most
/// recently incremented counter. The value is whatever PGO's
/// getCurrentRegionCount returns.
uint64_t CodeGenFunction::getCurrentProfileCount() {
  return PGO->getCurrentRegionCount();
}
1578