xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 #include <optional>
25 
// Declaration only: the option object itself is not defined in the visible
// part of this file. The region-counter visitor below reads it to decide
// whether the individual children of if/loop/conditional statements receive
// their own counters.
namespace llvm {
extern cl::opt<bool> EnableSingleByteCoverage;
} // namespace llvm
29 
// Hidden, file-local command-line flag "enable-value-profiling"; it is off
// unless explicitly requested. NOTE(review): its consumers are not in the
// visible part of this file.
static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling",
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));
34 
35 using namespace clang;
36 using namespace CodeGen;
37 
38 void CodeGenPGO::setFuncName(StringRef Name,
39                              llvm::GlobalValue::LinkageTypes Linkage) {
40   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
41   FuncName = llvm::getPGOFuncName(
42       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
43       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
44 
45   // If we're generating a profile, create a variable for the name.
46   if (CGM.getCodeGenOpts().hasProfileClangInstr())
47     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
48 }
49 
50 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
51   setFuncName(Fn->getName(), Fn->getLinkage());
52   // Create PGOFuncName meta data.
53   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
54 }
55 
/// Identifies which revision of the PGO hash algorithm is in use.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Alias for the newest version; update it whenever a version is added.
  PGO_HASH_LATEST = PGO_HASH_V3
};
65 
66 namespace {
67 /// Stable hasher for PGO region counters.
68 ///
69 /// PGOHash produces a stable hash of a given function's control flow.
70 ///
71 /// Changing the output of this hash will invalidate all previously generated
72 /// profiles -- i.e., don't do it.
73 ///
74 /// \note  When this hash does eventually change (years?), we still need to
75 /// support old hashes.  We'll need to pull in the version number from the
76 /// profile data format and use the matching hash function.
77 class PGOHash {
78   uint64_t Working;
79   unsigned Count;
80   PGOHashVersion HashVersion;
81   llvm::MD5 MD5;
82 
83   static const int NumBitsPerType = 6;
84   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
85   static const unsigned TooBig = 1u << NumBitsPerType;
86 
87 public:
88   /// Hash values for AST nodes.
89   ///
90   /// Distinct values for AST nodes that have region counters attached.
91   ///
92   /// These values must be stable.  All new members must be added at the end,
93   /// and no members should be removed.  Changing the enumeration value for an
94   /// AST node will affect the hash of every function that contains that node.
95   enum HashType : unsigned char {
96     None = 0,
97     LabelStmt = 1,
98     WhileStmt,
99     DoStmt,
100     ForStmt,
101     CXXForRangeStmt,
102     ObjCForCollectionStmt,
103     SwitchStmt,
104     CaseStmt,
105     DefaultStmt,
106     IfStmt,
107     CXXTryStmt,
108     CXXCatchStmt,
109     ConditionalOperator,
110     BinaryOperatorLAnd,
111     BinaryOperatorLOr,
112     BinaryConditionalOperator,
113     // The preceding values are available with PGO_HASH_V1.
114 
115     EndOfScope,
116     IfThenBranch,
117     IfElseBranch,
118     GotoStmt,
119     IndirectGotoStmt,
120     BreakStmt,
121     ContinueStmt,
122     ReturnStmt,
123     ThrowExpr,
124     UnaryOperatorLNot,
125     BinaryOperatorLT,
126     BinaryOperatorGT,
127     BinaryOperatorLE,
128     BinaryOperatorGE,
129     BinaryOperatorEQ,
130     BinaryOperatorNE,
131     // The preceding values are available since PGO_HASH_V2.
132 
133     // Keep this last.  It's for the static assert that follows.
134     LastHashType
135   };
136   static_assert(LastHashType <= TooBig, "Too many types in HashType");
137 
138   PGOHash(PGOHashVersion HashVersion)
139       : Working(0), Count(0), HashVersion(HashVersion) {}
140   void combine(HashType Type);
141   uint64_t finalize();
142   PGOHashVersion getHashVersion() const { return HashVersion; }
143 };
144 const int PGOHash::NumBitsPerType;
145 const unsigned PGOHash::NumTypesPerWord;
146 const unsigned PGOHash::TooBig;
147 
148 /// Get the PGO hash version used in the given indexed profile.
149 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
150                                         CodeGenModule &CGM) {
151   if (PGOReader->getVersion() <= 4)
152     return PGO_HASH_V1;
153   if (PGOReader->getVersion() <= 5)
154     return PGO_HASH_V2;
155   return PGO_HASH_V3;
156 }
157 
158 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
159 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
160   using Base = RecursiveASTVisitor<MapRegionCounters>;
161 
162   /// The next counter value to assign.
163   unsigned NextCounter;
164   /// The function hash.
165   PGOHash Hash;
166   /// The map of statements to counters.
167   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
168   /// The state of MC/DC Coverage in this function.
169   MCDC::State &MCDCState;
170   /// Maximum number of supported MC/DC conditions in a boolean expression.
171   unsigned MCDCMaxCond;
172   /// The profile version.
173   uint64_t ProfileVersion;
174   /// Diagnostics Engine used to report warnings.
175   DiagnosticsEngine &Diag;
176 
177   MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
178                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap,
179                     MCDC::State &MCDCState, unsigned MCDCMaxCond,
180                     DiagnosticsEngine &Diag)
181       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
182         MCDCState(MCDCState), MCDCMaxCond(MCDCMaxCond),
183         ProfileVersion(ProfileVersion), Diag(Diag) {}
184 
185   // Blocks and lambdas are handled as separate functions, so we need not
186   // traverse them in the parent context.
187   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
188   bool TraverseLambdaExpr(LambdaExpr *LE) {
189     // Traverse the captures, but not the body.
190     for (auto C : zip(LE->captures(), LE->capture_inits()))
191       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
192     return true;
193   }
194   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
195 
196   bool VisitDecl(const Decl *D) {
197     switch (D->getKind()) {
198     default:
199       break;
200     case Decl::Function:
201     case Decl::CXXMethod:
202     case Decl::CXXConstructor:
203     case Decl::CXXDestructor:
204     case Decl::CXXConversion:
205     case Decl::ObjCMethod:
206     case Decl::Block:
207     case Decl::Captured:
208       CounterMap[D->getBody()] = NextCounter++;
209       break;
210     }
211     return true;
212   }
213 
214   /// If \p S gets a fresh counter, update the counter mappings. Return the
215   /// V1 hash of \p S.
216   PGOHash::HashType updateCounterMappings(Stmt *S) {
217     auto Type = getHashType(PGO_HASH_V1, S);
218     if (Type != PGOHash::None)
219       CounterMap[S] = NextCounter++;
220     return Type;
221   }
222 
223   /// The following stacks are used with dataTraverseStmtPre() and
224   /// dataTraverseStmtPost() to track the depth of nested logical operators in a
225   /// boolean expression in a function.  The ultimate purpose is to keep track
226   /// of the number of leaf-level conditions in the boolean expression so that a
227   /// profile bitmap can be allocated based on that number.
228   ///
229   /// The stacks are also used to find error cases and notify the user.  A
230   /// standard logical operator nest for a boolean expression could be in a form
231   /// similar to this: "x = a && b && c && (d || f)"
232   unsigned NumCond = 0;
233   bool SplitNestedLogicalOp = false;
234   SmallVector<const Stmt *, 16> NonLogOpStack;
235   SmallVector<const BinaryOperator *, 16> LogOpStack;
236 
237   // Hook: dataTraverseStmtPre() is invoked prior to visiting an AST Stmt node.
238   bool dataTraverseStmtPre(Stmt *S) {
239     /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
240     if (MCDCMaxCond == 0)
241       return true;
242 
243     /// At the top of the logical operator nest, reset the number of conditions,
244     /// also forget previously seen split nesting cases.
245     if (LogOpStack.empty()) {
246       NumCond = 0;
247       SplitNestedLogicalOp = false;
248     }
249 
250     if (const Expr *E = dyn_cast<Expr>(S)) {
251       const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
252       if (BinOp && BinOp->isLogicalOp()) {
253         /// Check for "split-nested" logical operators. This happens when a new
254         /// boolean expression logical-op nest is encountered within an existing
255         /// boolean expression, separated by a non-logical operator.  For
256         /// example, in "x = (a && b && c && foo(d && f))", the "d && f" case
257         /// starts a new boolean expression that is separated from the other
258         /// conditions by the operator foo(). Split-nested cases are not
259         /// supported by MC/DC.
260         SplitNestedLogicalOp = SplitNestedLogicalOp || !NonLogOpStack.empty();
261 
262         LogOpStack.push_back(BinOp);
263         return true;
264       }
265     }
266 
267     /// Keep track of non-logical operators. These are OK as long as we don't
268     /// encounter a new logical operator after seeing one.
269     if (!LogOpStack.empty())
270       NonLogOpStack.push_back(S);
271 
272     return true;
273   }
274 
275   // Hook: dataTraverseStmtPost() is invoked by the AST visitor after visiting
276   // an AST Stmt node.  MC/DC will use it to to signal when the top of a
277   // logical operation (boolean expression) nest is encountered.
278   bool dataTraverseStmtPost(Stmt *S) {
279     /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
280     if (MCDCMaxCond == 0)
281       return true;
282 
283     if (const Expr *E = dyn_cast<Expr>(S)) {
284       const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
285       if (BinOp && BinOp->isLogicalOp()) {
286         assert(LogOpStack.back() == BinOp);
287         LogOpStack.pop_back();
288 
289         /// At the top of logical operator nest:
290         if (LogOpStack.empty()) {
291           /// Was the "split-nested" logical operator case encountered?
292           if (SplitNestedLogicalOp) {
293             unsigned DiagID = Diag.getCustomDiagID(
294                 DiagnosticsEngine::Warning,
295                 "unsupported MC/DC boolean expression; "
296                 "contains an operation with a nested boolean expression. "
297                 "Expression will not be covered");
298             Diag.Report(S->getBeginLoc(), DiagID);
299             return true;
300           }
301 
302           /// Was the maximum number of conditions encountered?
303           if (NumCond > MCDCMaxCond) {
304             unsigned DiagID = Diag.getCustomDiagID(
305                 DiagnosticsEngine::Warning,
306                 "unsupported MC/DC boolean expression; "
307                 "number of conditions (%0) exceeds max (%1). "
308                 "Expression will not be covered");
309             Diag.Report(S->getBeginLoc(), DiagID) << NumCond << MCDCMaxCond;
310             return true;
311           }
312 
313           // Otherwise, allocate the Decision.
314           MCDCState.DecisionByStmt[BinOp].BitmapIdx = 0;
315         }
316         return true;
317       }
318     }
319 
320     if (!LogOpStack.empty())
321       NonLogOpStack.pop_back();
322 
323     return true;
324   }
325 
326   /// The RHS of all logical operators gets a fresh counter in order to count
327   /// how many times the RHS evaluates to true or false, depending on the
328   /// semantics of the operator. This is only valid for ">= v7" of the profile
329   /// version so that we facilitate backward compatibility. In addition, in
330   /// order to use MC/DC, count the number of total LHS and RHS conditions.
331   bool VisitBinaryOperator(BinaryOperator *S) {
332     if (S->isLogicalOp()) {
333       if (CodeGenFunction::isInstrumentedCondition(S->getLHS()))
334         NumCond++;
335 
336       if (CodeGenFunction::isInstrumentedCondition(S->getRHS())) {
337         if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
338           CounterMap[S->getRHS()] = NextCounter++;
339 
340         NumCond++;
341       }
342     }
343     return Base::VisitBinaryOperator(S);
344   }
345 
346   bool VisitConditionalOperator(ConditionalOperator *S) {
347     if (llvm::EnableSingleByteCoverage && S->getTrueExpr())
348       CounterMap[S->getTrueExpr()] = NextCounter++;
349     if (llvm::EnableSingleByteCoverage && S->getFalseExpr())
350       CounterMap[S->getFalseExpr()] = NextCounter++;
351     return Base::VisitConditionalOperator(S);
352   }
353 
354   /// Include \p S in the function hash.
355   bool VisitStmt(Stmt *S) {
356     auto Type = updateCounterMappings(S);
357     if (Hash.getHashVersion() != PGO_HASH_V1)
358       Type = getHashType(Hash.getHashVersion(), S);
359     if (Type != PGOHash::None)
360       Hash.combine(Type);
361     return true;
362   }
363 
364   bool TraverseIfStmt(IfStmt *If) {
365     // If we used the V1 hash, use the default traversal.
366     if (Hash.getHashVersion() == PGO_HASH_V1)
367       return Base::TraverseIfStmt(If);
368 
369     // When single byte coverage mode is enabled, add a counter to then and
370     // else.
371     bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
372     for (Stmt *CS : If->children()) {
373       if (!CS || NoSingleByteCoverage)
374         continue;
375       if (CS == If->getThen())
376         CounterMap[If->getThen()] = NextCounter++;
377       else if (CS == If->getElse())
378         CounterMap[If->getElse()] = NextCounter++;
379     }
380 
381     // Otherwise, keep track of which branch we're in while traversing.
382     VisitStmt(If);
383 
384     for (Stmt *CS : If->children()) {
385       if (!CS)
386         continue;
387       if (CS == If->getThen())
388         Hash.combine(PGOHash::IfThenBranch);
389       else if (CS == If->getElse())
390         Hash.combine(PGOHash::IfElseBranch);
391       TraverseStmt(CS);
392     }
393     Hash.combine(PGOHash::EndOfScope);
394     return true;
395   }
396 
397   bool TraverseWhileStmt(WhileStmt *While) {
398     // When single byte coverage mode is enabled, add a counter to condition and
399     // body.
400     bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
401     for (Stmt *CS : While->children()) {
402       if (!CS || NoSingleByteCoverage)
403         continue;
404       if (CS == While->getCond())
405         CounterMap[While->getCond()] = NextCounter++;
406       else if (CS == While->getBody())
407         CounterMap[While->getBody()] = NextCounter++;
408     }
409 
410     Base::TraverseWhileStmt(While);
411     if (Hash.getHashVersion() != PGO_HASH_V1)
412       Hash.combine(PGOHash::EndOfScope);
413     return true;
414   }
415 
416   bool TraverseDoStmt(DoStmt *Do) {
417     // When single byte coverage mode is enabled, add a counter to condition and
418     // body.
419     bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
420     for (Stmt *CS : Do->children()) {
421       if (!CS || NoSingleByteCoverage)
422         continue;
423       if (CS == Do->getCond())
424         CounterMap[Do->getCond()] = NextCounter++;
425       else if (CS == Do->getBody())
426         CounterMap[Do->getBody()] = NextCounter++;
427     }
428 
429     Base::TraverseDoStmt(Do);
430     if (Hash.getHashVersion() != PGO_HASH_V1)
431       Hash.combine(PGOHash::EndOfScope);
432     return true;
433   }
434 
435   bool TraverseForStmt(ForStmt *For) {
436     // When single byte coverage mode is enabled, add a counter to condition,
437     // increment and body.
438     bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
439     for (Stmt *CS : For->children()) {
440       if (!CS || NoSingleByteCoverage)
441         continue;
442       if (CS == For->getCond())
443         CounterMap[For->getCond()] = NextCounter++;
444       else if (CS == For->getInc())
445         CounterMap[For->getInc()] = NextCounter++;
446       else if (CS == For->getBody())
447         CounterMap[For->getBody()] = NextCounter++;
448     }
449 
450     Base::TraverseForStmt(For);
451     if (Hash.getHashVersion() != PGO_HASH_V1)
452       Hash.combine(PGOHash::EndOfScope);
453     return true;
454   }
455 
456   bool TraverseCXXForRangeStmt(CXXForRangeStmt *ForRange) {
457     // When single byte coverage mode is enabled, add a counter to body.
458     bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
459     for (Stmt *CS : ForRange->children()) {
460       if (!CS || NoSingleByteCoverage)
461         continue;
462       if (CS == ForRange->getBody())
463         CounterMap[ForRange->getBody()] = NextCounter++;
464     }
465 
466     Base::TraverseCXXForRangeStmt(ForRange);
467     if (Hash.getHashVersion() != PGO_HASH_V1)
468       Hash.combine(PGOHash::EndOfScope);
469     return true;
470   }
471 
472 // If the statement type \p N is nestable, and its nesting impacts profile
473 // stability, define a custom traversal which tracks the end of the statement
474 // in the hash (provided we're not using the V1 hash).
475 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
476   bool Traverse##N(N *S) {                                                     \
477     Base::Traverse##N(S);                                                      \
478     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
479       Hash.combine(PGOHash::EndOfScope);                                       \
480     return true;                                                               \
481   }
482 
483   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
484   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
485   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
486 
487   /// Get version \p HashVersion of the PGO hash for \p S.
488   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
489     switch (S->getStmtClass()) {
490     default:
491       break;
492     case Stmt::LabelStmtClass:
493       return PGOHash::LabelStmt;
494     case Stmt::WhileStmtClass:
495       return PGOHash::WhileStmt;
496     case Stmt::DoStmtClass:
497       return PGOHash::DoStmt;
498     case Stmt::ForStmtClass:
499       return PGOHash::ForStmt;
500     case Stmt::CXXForRangeStmtClass:
501       return PGOHash::CXXForRangeStmt;
502     case Stmt::ObjCForCollectionStmtClass:
503       return PGOHash::ObjCForCollectionStmt;
504     case Stmt::SwitchStmtClass:
505       return PGOHash::SwitchStmt;
506     case Stmt::CaseStmtClass:
507       return PGOHash::CaseStmt;
508     case Stmt::DefaultStmtClass:
509       return PGOHash::DefaultStmt;
510     case Stmt::IfStmtClass:
511       return PGOHash::IfStmt;
512     case Stmt::CXXTryStmtClass:
513       return PGOHash::CXXTryStmt;
514     case Stmt::CXXCatchStmtClass:
515       return PGOHash::CXXCatchStmt;
516     case Stmt::ConditionalOperatorClass:
517       return PGOHash::ConditionalOperator;
518     case Stmt::BinaryConditionalOperatorClass:
519       return PGOHash::BinaryConditionalOperator;
520     case Stmt::BinaryOperatorClass: {
521       const BinaryOperator *BO = cast<BinaryOperator>(S);
522       if (BO->getOpcode() == BO_LAnd)
523         return PGOHash::BinaryOperatorLAnd;
524       if (BO->getOpcode() == BO_LOr)
525         return PGOHash::BinaryOperatorLOr;
526       if (HashVersion >= PGO_HASH_V2) {
527         switch (BO->getOpcode()) {
528         default:
529           break;
530         case BO_LT:
531           return PGOHash::BinaryOperatorLT;
532         case BO_GT:
533           return PGOHash::BinaryOperatorGT;
534         case BO_LE:
535           return PGOHash::BinaryOperatorLE;
536         case BO_GE:
537           return PGOHash::BinaryOperatorGE;
538         case BO_EQ:
539           return PGOHash::BinaryOperatorEQ;
540         case BO_NE:
541           return PGOHash::BinaryOperatorNE;
542         }
543       }
544       break;
545     }
546     }
547 
548     if (HashVersion >= PGO_HASH_V2) {
549       switch (S->getStmtClass()) {
550       default:
551         break;
552       case Stmt::GotoStmtClass:
553         return PGOHash::GotoStmt;
554       case Stmt::IndirectGotoStmtClass:
555         return PGOHash::IndirectGotoStmt;
556       case Stmt::BreakStmtClass:
557         return PGOHash::BreakStmt;
558       case Stmt::ContinueStmtClass:
559         return PGOHash::ContinueStmt;
560       case Stmt::ReturnStmtClass:
561         return PGOHash::ReturnStmt;
562       case Stmt::CXXThrowExprClass:
563         return PGOHash::ThrowExpr;
564       case Stmt::UnaryOperatorClass: {
565         const UnaryOperator *UO = cast<UnaryOperator>(S);
566         if (UO->getOpcode() == UO_LNot)
567           return PGOHash::UnaryOperatorLNot;
568         break;
569       }
570       }
571     }
572 
573     return PGOHash::None;
574   }
575 };
576 
577 /// A StmtVisitor that propagates the raw counts through the AST and
578 /// records the count at statements where the value may change.
579 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
580   /// PGO state.
581   CodeGenPGO &PGO;
582 
583   /// A flag that is set when the current count should be recorded on the
584   /// next statement, such as at the exit of a loop.
585   bool RecordNextStmtCount;
586 
587   /// The count at the current location in the traversal.
588   uint64_t CurrentCount;
589 
590   /// The map of statements to count values.
591   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
592 
593   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
594   struct BreakContinue {
595     uint64_t BreakCount = 0;
596     uint64_t ContinueCount = 0;
597     BreakContinue() = default;
598   };
599   SmallVector<BreakContinue, 8> BreakContinueStack;
600 
601   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
602                       CodeGenPGO &PGO)
603       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
604 
605   void RecordStmtCount(const Stmt *S) {
606     if (RecordNextStmtCount) {
607       CountMap[S] = CurrentCount;
608       RecordNextStmtCount = false;
609     }
610   }
611 
612   /// Set and return the current count.
613   uint64_t setCount(uint64_t Count) {
614     CurrentCount = Count;
615     return Count;
616   }
617 
618   void VisitStmt(const Stmt *S) {
619     RecordStmtCount(S);
620     for (const Stmt *Child : S->children())
621       if (Child)
622         this->Visit(Child);
623   }
624 
625   void VisitFunctionDecl(const FunctionDecl *D) {
626     // Counter tracks entry to the function body.
627     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
628     CountMap[D->getBody()] = BodyCount;
629     Visit(D->getBody());
630   }
631 
632   // Skip lambda expressions. We visit these as FunctionDecls when we're
633   // generating them and aren't interested in the body when generating a
634   // parent context.
635   void VisitLambdaExpr(const LambdaExpr *LE) {}
636 
637   void VisitCapturedDecl(const CapturedDecl *D) {
638     // Counter tracks entry to the capture body.
639     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
640     CountMap[D->getBody()] = BodyCount;
641     Visit(D->getBody());
642   }
643 
644   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
645     // Counter tracks entry to the method body.
646     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
647     CountMap[D->getBody()] = BodyCount;
648     Visit(D->getBody());
649   }
650 
651   void VisitBlockDecl(const BlockDecl *D) {
652     // Counter tracks entry to the block body.
653     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
654     CountMap[D->getBody()] = BodyCount;
655     Visit(D->getBody());
656   }
657 
658   void VisitReturnStmt(const ReturnStmt *S) {
659     RecordStmtCount(S);
660     if (S->getRetValue())
661       Visit(S->getRetValue());
662     CurrentCount = 0;
663     RecordNextStmtCount = true;
664   }
665 
666   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
667     RecordStmtCount(E);
668     if (E->getSubExpr())
669       Visit(E->getSubExpr());
670     CurrentCount = 0;
671     RecordNextStmtCount = true;
672   }
673 
674   void VisitGotoStmt(const GotoStmt *S) {
675     RecordStmtCount(S);
676     CurrentCount = 0;
677     RecordNextStmtCount = true;
678   }
679 
680   void VisitLabelStmt(const LabelStmt *S) {
681     RecordNextStmtCount = false;
682     // Counter tracks the block following the label.
683     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
684     CountMap[S] = BlockCount;
685     Visit(S->getSubStmt());
686   }
687 
688   void VisitBreakStmt(const BreakStmt *S) {
689     RecordStmtCount(S);
690     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
691     BreakContinueStack.back().BreakCount += CurrentCount;
692     CurrentCount = 0;
693     RecordNextStmtCount = true;
694   }
695 
696   void VisitContinueStmt(const ContinueStmt *S) {
697     RecordStmtCount(S);
698     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
699     BreakContinueStack.back().ContinueCount += CurrentCount;
700     CurrentCount = 0;
701     RecordNextStmtCount = true;
702   }
703 
704   void VisitWhileStmt(const WhileStmt *S) {
705     RecordStmtCount(S);
706     uint64_t ParentCount = CurrentCount;
707 
708     BreakContinueStack.push_back(BreakContinue());
709     // Visit the body region first so the break/continue adjustments can be
710     // included when visiting the condition.
711     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
712     CountMap[S->getBody()] = CurrentCount;
713     Visit(S->getBody());
714     uint64_t BackedgeCount = CurrentCount;
715 
716     // ...then go back and propagate counts through the condition. The count
717     // at the start of the condition is the sum of the incoming edges,
718     // the backedge from the end of the loop body, and the edges from
719     // continue statements.
720     BreakContinue BC = BreakContinueStack.pop_back_val();
721     uint64_t CondCount =
722         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
723     CountMap[S->getCond()] = CondCount;
724     Visit(S->getCond());
725     setCount(BC.BreakCount + CondCount - BodyCount);
726     RecordNextStmtCount = true;
727   }
728 
729   void VisitDoStmt(const DoStmt *S) {
730     RecordStmtCount(S);
731     uint64_t LoopCount = PGO.getRegionCount(S);
732 
733     BreakContinueStack.push_back(BreakContinue());
734     // The count doesn't include the fallthrough from the parent scope. Add it.
735     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
736     CountMap[S->getBody()] = BodyCount;
737     Visit(S->getBody());
738     uint64_t BackedgeCount = CurrentCount;
739 
740     BreakContinue BC = BreakContinueStack.pop_back_val();
741     // The count at the start of the condition is equal to the count at the
742     // end of the body, plus any continues.
743     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
744     CountMap[S->getCond()] = CondCount;
745     Visit(S->getCond());
746     setCount(BC.BreakCount + CondCount - LoopCount);
747     RecordNextStmtCount = true;
748   }
749 
750   void VisitForStmt(const ForStmt *S) {
751     RecordStmtCount(S);
752     if (S->getInit())
753       Visit(S->getInit());
754 
755     uint64_t ParentCount = CurrentCount;
756 
757     BreakContinueStack.push_back(BreakContinue());
758     // Visit the body region first. (This is basically the same as a while
759     // loop; see further comments in VisitWhileStmt.)
760     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
761     CountMap[S->getBody()] = BodyCount;
762     Visit(S->getBody());
763     uint64_t BackedgeCount = CurrentCount;
764     BreakContinue BC = BreakContinueStack.pop_back_val();
765 
766     // The increment is essentially part of the body but it needs to include
767     // the count for all the continue statements.
768     if (S->getInc()) {
769       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
770       CountMap[S->getInc()] = IncCount;
771       Visit(S->getInc());
772     }
773 
774     // ...then go back and propagate counts through the condition.
775     uint64_t CondCount =
776         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
777     if (S->getCond()) {
778       CountMap[S->getCond()] = CondCount;
779       Visit(S->getCond());
780     }
781     setCount(BC.BreakCount + CondCount - BodyCount);
782     RecordNextStmtCount = true;
783   }
784 
785   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
786     RecordStmtCount(S);
787     if (S->getInit())
788       Visit(S->getInit());
789     Visit(S->getLoopVarStmt());
790     Visit(S->getRangeStmt());
791     Visit(S->getBeginStmt());
792     Visit(S->getEndStmt());
793 
794     uint64_t ParentCount = CurrentCount;
795     BreakContinueStack.push_back(BreakContinue());
796     // Visit the body region first. (This is basically the same as a while
797     // loop; see further comments in VisitWhileStmt.)
798     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
799     CountMap[S->getBody()] = BodyCount;
800     Visit(S->getBody());
801     uint64_t BackedgeCount = CurrentCount;
802     BreakContinue BC = BreakContinueStack.pop_back_val();
803 
804     // The increment is essentially part of the body but it needs to include
805     // the count for all the continue statements.
806     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
807     CountMap[S->getInc()] = IncCount;
808     Visit(S->getInc());
809 
810     // ...then go back and propagate counts through the condition.
811     uint64_t CondCount =
812         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
813     CountMap[S->getCond()] = CondCount;
814     Visit(S->getCond());
815     setCount(BC.BreakCount + CondCount - BodyCount);
816     RecordNextStmtCount = true;
817   }
818 
819   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
820     RecordStmtCount(S);
821     Visit(S->getElement());
822     uint64_t ParentCount = CurrentCount;
823     BreakContinueStack.push_back(BreakContinue());
824     // Counter tracks the body of the loop.
825     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
826     CountMap[S->getBody()] = BodyCount;
827     Visit(S->getBody());
828     uint64_t BackedgeCount = CurrentCount;
829     BreakContinue BC = BreakContinueStack.pop_back_val();
830 
831     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
832              BodyCount);
833     RecordNextStmtCount = true;
834   }
835 
836   void VisitSwitchStmt(const SwitchStmt *S) {
837     RecordStmtCount(S);
838     if (S->getInit())
839       Visit(S->getInit());
840     Visit(S->getCond());
841     CurrentCount = 0;
842     BreakContinueStack.push_back(BreakContinue());
843     Visit(S->getBody());
844     // If the switch is inside a loop, add the continue counts.
845     BreakContinue BC = BreakContinueStack.pop_back_val();
846     if (!BreakContinueStack.empty())
847       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
848     // Counter tracks the exit block of the switch.
849     setCount(PGO.getRegionCount(S));
850     RecordNextStmtCount = true;
851   }
852 
853   void VisitSwitchCase(const SwitchCase *S) {
854     RecordNextStmtCount = false;
855     // Counter for this particular case. This counts only jumps from the
856     // switch header and does not include fallthrough from the case before
857     // this one.
858     uint64_t CaseCount = PGO.getRegionCount(S);
859     setCount(CurrentCount + CaseCount);
860     // We need the count without fallthrough in the mapping, so it's more useful
861     // for branch probabilities.
862     CountMap[S] = CaseCount;
863     RecordNextStmtCount = true;
864     Visit(S->getSubStmt());
865   }
866 
867   void VisitIfStmt(const IfStmt *S) {
868     RecordStmtCount(S);
869 
870     if (S->isConsteval()) {
871       const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
872       if (Stm)
873         Visit(Stm);
874       return;
875     }
876 
877     uint64_t ParentCount = CurrentCount;
878     if (S->getInit())
879       Visit(S->getInit());
880     Visit(S->getCond());
881 
882     // Counter tracks the "then" part of an if statement. The count for
883     // the "else" part, if it exists, will be calculated from this counter.
884     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
885     CountMap[S->getThen()] = ThenCount;
886     Visit(S->getThen());
887     uint64_t OutCount = CurrentCount;
888 
889     uint64_t ElseCount = ParentCount - ThenCount;
890     if (S->getElse()) {
891       setCount(ElseCount);
892       CountMap[S->getElse()] = ElseCount;
893       Visit(S->getElse());
894       OutCount += CurrentCount;
895     } else
896       OutCount += ElseCount;
897     setCount(OutCount);
898     RecordNextStmtCount = true;
899   }
900 
901   void VisitCXXTryStmt(const CXXTryStmt *S) {
902     RecordStmtCount(S);
903     Visit(S->getTryBlock());
904     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
905       Visit(S->getHandler(I));
906     // Counter tracks the continuation block of the try statement.
907     setCount(PGO.getRegionCount(S));
908     RecordNextStmtCount = true;
909   }
910 
911   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
912     RecordNextStmtCount = false;
913     // Counter tracks the catch statement's handler block.
914     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
915     CountMap[S] = CatchCount;
916     Visit(S->getHandlerBlock());
917   }
918 
919   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
920     RecordStmtCount(E);
921     uint64_t ParentCount = CurrentCount;
922     Visit(E->getCond());
923 
924     // Counter tracks the "true" part of a conditional operator. The
925     // count in the "false" part will be calculated from this counter.
926     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
927     CountMap[E->getTrueExpr()] = TrueCount;
928     Visit(E->getTrueExpr());
929     uint64_t OutCount = CurrentCount;
930 
931     uint64_t FalseCount = setCount(ParentCount - TrueCount);
932     CountMap[E->getFalseExpr()] = FalseCount;
933     Visit(E->getFalseExpr());
934     OutCount += CurrentCount;
935 
936     setCount(OutCount);
937     RecordNextStmtCount = true;
938   }
939 
940   void VisitBinLAnd(const BinaryOperator *E) {
941     RecordStmtCount(E);
942     uint64_t ParentCount = CurrentCount;
943     Visit(E->getLHS());
944     // Counter tracks the right hand side of a logical and operator.
945     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
946     CountMap[E->getRHS()] = RHSCount;
947     Visit(E->getRHS());
948     setCount(ParentCount + RHSCount - CurrentCount);
949     RecordNextStmtCount = true;
950   }
951 
952   void VisitBinLOr(const BinaryOperator *E) {
953     RecordStmtCount(E);
954     uint64_t ParentCount = CurrentCount;
955     Visit(E->getLHS());
956     // Counter tracks the right hand side of a logical or operator.
957     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
958     CountMap[E->getRHS()] = RHSCount;
959     Visit(E->getRHS());
960     setCount(ParentCount + RHSCount - CurrentCount);
961     RecordNextStmtCount = true;
962   }
963 };
964 } // end anonymous namespace
965 
966 void PGOHash::combine(HashType Type) {
967   // Check that we never combine 0 and only have six bits.
968   assert(Type && "Hash is invalid: unexpected type 0");
969   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
970 
971   // Pass through MD5 if enough work has built up.
972   if (Count && Count % NumTypesPerWord == 0) {
973     using namespace llvm::support;
974     uint64_t Swapped =
975         endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
976     MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
977     Working = 0;
978   }
979 
980   // Accumulate the current type.
981   ++Count;
982   Working = Working << NumBitsPerType | Type;
983 }
984 
985 uint64_t PGOHash::finalize() {
986   // Use Working as the hash directly if we never used MD5.
987   if (Count <= NumTypesPerWord)
988     // No need to byte swap here, since none of the math was endian-dependent.
989     // This number will be byte-swapped as required on endianness transitions,
990     // so we will see the same value on the other side.
991     return Working;
992 
993   // Check for remaining work in Working.
994   if (Working) {
995     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
996     // is buggy because it converts a uint64_t into an array of uint8_t.
997     if (HashVersion < PGO_HASH_V3) {
998       MD5.update({(uint8_t)Working});
999     } else {
1000       using namespace llvm::support;
1001       uint64_t Swapped =
1002           endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
1003       MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
1004     }
1005   }
1006 
1007   // Finalize the MD5 and return the hash.
1008   llvm::MD5::MD5Result Result;
1009   MD5.final(Result);
1010   return Result.low();
1011 }
1012 
1013 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
1014   const Decl *D = GD.getDecl();
1015   if (!D->hasBody())
1016     return;
1017 
1018   // Skip CUDA/HIP kernel launch stub functions.
1019   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
1020       D->hasAttr<CUDAGlobalAttr>())
1021     return;
1022 
1023   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
1024   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1025   if (!InstrumentRegions && !PGOReader)
1026     return;
1027   if (D->isImplicit())
1028     return;
1029   // Constructors and destructors may be represented by several functions in IR.
1030   // If so, instrument only base variant, others are implemented by delegation
1031   // to the base one, it would be counted twice otherwise.
1032   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
1033     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
1034       if (GD.getCtorType() != Ctor_Base &&
1035           CodeGenFunction::IsConstructorDelegationValid(CCD))
1036         return;
1037   }
1038   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
1039     return;
1040 
1041   CGM.ClearUnusedCoverageMapping(D);
1042   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
1043     return;
1044   if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
1045     return;
1046 
1047   SourceManager &SM = CGM.getContext().getSourceManager();
1048   if (!llvm::coverage::SystemHeadersCoverage &&
1049       SM.isInSystemHeader(D->getLocation()))
1050     return;
1051 
1052   setFuncName(Fn);
1053 
1054   mapRegionCounters(D);
1055   if (CGM.getCodeGenOpts().CoverageMapping)
1056     emitCounterRegionMapping(D);
1057   if (PGOReader) {
1058     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
1059     computeRegionCounts(D);
1060     applyFunctionAttributes(PGOReader, Fn);
1061   }
1062 }
1063 
1064 void CodeGenPGO::mapRegionCounters(const Decl *D) {
1065   // Use the latest hash version when inserting instrumentation, but use the
1066   // version in the indexed profile if we're reading PGO data.
1067   PGOHashVersion HashVersion = PGO_HASH_LATEST;
1068   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
1069   if (auto *PGOReader = CGM.getPGOReader()) {
1070     HashVersion = getPGOHashVersion(PGOReader, CGM);
1071     ProfileVersion = PGOReader->getVersion();
1072   }
1073 
1074   // If MC/DC is enabled, set the MaxConditions to a preset value. Otherwise,
1075   // set it to zero. This value impacts the number of conditions accepted in a
1076   // given boolean expression, which impacts the size of the bitmap used to
1077   // track test vector execution for that boolean expression.  Because the
1078   // bitmap scales exponentially (2^n) based on the number of conditions seen,
1079   // the maximum value is hard-coded at 6 conditions, which is more than enough
1080   // for most embedded applications. Setting a maximum value prevents the
1081   // bitmap footprint from growing too large without the user's knowledge. In
1082   // the future, this value could be adjusted with a command-line option.
1083   unsigned MCDCMaxConditions =
1084       (CGM.getCodeGenOpts().MCDCCoverage ? CGM.getCodeGenOpts().MCDCMaxConds
1085                                          : 0);
1086 
1087   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
1088   RegionMCDCState.reset(new MCDC::State);
1089   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap,
1090                            *RegionMCDCState, MCDCMaxConditions, CGM.getDiags());
1091   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1092     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
1093   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1094     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
1095   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1096     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
1097   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1098     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
1099   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
1100   NumRegionCounters = Walker.NextCounter;
1101   FunctionHash = Walker.Hash.finalize();
1102 }
1103 
1104 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
1105   if (!D->getBody())
1106     return true;
1107 
1108   // Skip host-only functions in the CUDA device compilation and device-only
1109   // functions in the host compilation. Just roughly filter them out based on
1110   // the function attributes. If there are effectively host-only or device-only
1111   // ones, their coverage mapping may still be generated.
1112   if (CGM.getLangOpts().CUDA &&
1113       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
1114         !D->hasAttr<CUDAGlobalAttr>()) ||
1115        (!CGM.getLangOpts().CUDAIsDevice &&
1116         (D->hasAttr<CUDAGlobalAttr>() ||
1117          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
1118     return true;
1119 
1120   // Don't map the functions in system headers.
1121   const auto &SM = CGM.getContext().getSourceManager();
1122   auto Loc = D->getBody()->getBeginLoc();
1123   return !llvm::coverage::SystemHeadersCoverage && SM.isInSystemHeader(Loc);
1124 }
1125 
1126 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
1127   if (skipRegionMappingForDecl(D))
1128     return;
1129 
1130   std::string CoverageMapping;
1131   llvm::raw_string_ostream OS(CoverageMapping);
1132   RegionMCDCState->BranchByStmt.clear();
1133   CoverageMappingGen MappingGen(
1134       *CGM.getCoverageMapping(), CGM.getContext().getSourceManager(),
1135       CGM.getLangOpts(), RegionCounterMap.get(), RegionMCDCState.get());
1136   MappingGen.emitCounterMapping(D, OS);
1137   OS.flush();
1138 
1139   if (CoverageMapping.empty())
1140     return;
1141 
1142   CGM.getCoverageMapping()->addFunctionMappingRecord(
1143       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
1144 }
1145 
1146 void
1147 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
1148                                     llvm::GlobalValue::LinkageTypes Linkage) {
1149   if (skipRegionMappingForDecl(D))
1150     return;
1151 
1152   std::string CoverageMapping;
1153   llvm::raw_string_ostream OS(CoverageMapping);
1154   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
1155                                 CGM.getContext().getSourceManager(),
1156                                 CGM.getLangOpts());
1157   MappingGen.emitEmptyMapping(D, OS);
1158   OS.flush();
1159 
1160   if (CoverageMapping.empty())
1161     return;
1162 
1163   setFuncName(Name, Linkage);
1164   CGM.getCoverageMapping()->addFunctionMappingRecord(
1165       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
1166 }
1167 
1168 void CodeGenPGO::computeRegionCounts(const Decl *D) {
1169   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
1170   ComputeRegionCounts Walker(*StmtCountMap, *this);
1171   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1172     Walker.VisitFunctionDecl(FD);
1173   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1174     Walker.VisitObjCMethodDecl(MD);
1175   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1176     Walker.VisitBlockDecl(BD);
1177   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1178     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
1179 }
1180 
1181 void
1182 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
1183                                     llvm::Function *Fn) {
1184   if (!haveRegionCounts())
1185     return;
1186 
1187   uint64_t FunctionCount = getRegionCount(nullptr);
1188   Fn->setEntryCount(FunctionCount);
1189 }
1190 
1191 void CodeGenPGO::emitCounterSetOrIncrement(CGBuilderTy &Builder, const Stmt *S,
1192                                            llvm::Value *StepV) {
1193   if (!RegionCounterMap || !Builder.GetInsertBlock())
1194     return;
1195 
1196   unsigned Counter = (*RegionCounterMap)[S];
1197 
1198   llvm::Value *Args[] = {FuncNameVar,
1199                          Builder.getInt64(FunctionHash),
1200                          Builder.getInt32(NumRegionCounters),
1201                          Builder.getInt32(Counter), StepV};
1202 
1203   if (llvm::EnableSingleByteCoverage)
1204     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_cover),
1205                        ArrayRef(Args, 4));
1206   else {
1207     if (!StepV)
1208       Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
1209                          ArrayRef(Args, 4));
1210     else
1211       Builder.CreateCall(
1212           CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step), Args);
1213   }
1214 }
1215 
1216 bool CodeGenPGO::canEmitMCDCCoverage(const CGBuilderTy &Builder) {
1217   return (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1218           CGM.getCodeGenOpts().MCDCCoverage && Builder.GetInsertBlock());
1219 }
1220 
1221 void CodeGenPGO::emitMCDCParameters(CGBuilderTy &Builder) {
1222   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1223     return;
1224 
1225   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1226 
1227   // Emit intrinsic representing MCDC bitmap parameters at function entry.
1228   // This is used by the instrumentation pass, but it isn't actually lowered to
1229   // anything.
1230   llvm::Value *Args[3] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1231                           Builder.getInt64(FunctionHash),
1232                           Builder.getInt32(RegionMCDCState->BitmapBits)};
1233   Builder.CreateCall(
1234       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_parameters), Args);
1235 }
1236 
1237 void CodeGenPGO::emitMCDCTestVectorBitmapUpdate(CGBuilderTy &Builder,
1238                                                 const Expr *S,
1239                                                 Address MCDCCondBitmapAddr,
1240                                                 CodeGenFunction &CGF) {
1241   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1242     return;
1243 
1244   S = S->IgnoreParens();
1245 
1246   auto DecisionStateIter = RegionMCDCState->DecisionByStmt.find(S);
1247   if (DecisionStateIter == RegionMCDCState->DecisionByStmt.end())
1248     return;
1249 
1250   // Don't create tvbitmap_update if the record is allocated but excluded.
1251   // Or `bitmap |= (1 << 0)` would be wrongly executed to the next bitmap.
1252   if (DecisionStateIter->second.Indices.size() == 0)
1253     return;
1254 
1255   // Extract the offset of the global bitmap associated with this expression.
1256   unsigned MCDCTestVectorBitmapOffset = DecisionStateIter->second.BitmapIdx;
1257   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1258 
1259   // Emit intrinsic responsible for updating the global bitmap corresponding to
1260   // a boolean expression. The index being set is based on the value loaded
1261   // from a pointer to a dedicated temporary value on the stack that is itself
1262   // updated via emitMCDCCondBitmapReset() and emitMCDCCondBitmapUpdate(). The
1263   // index represents an executed test vector.
1264   llvm::Value *Args[4] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1265                           Builder.getInt64(FunctionHash),
1266                           Builder.getInt32(MCDCTestVectorBitmapOffset),
1267                           MCDCCondBitmapAddr.emitRawPointer(CGF)};
1268   Builder.CreateCall(
1269       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_tvbitmap_update), Args);
1270 }
1271 
1272 void CodeGenPGO::emitMCDCCondBitmapReset(CGBuilderTy &Builder, const Expr *S,
1273                                          Address MCDCCondBitmapAddr) {
1274   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1275     return;
1276 
1277   S = S->IgnoreParens();
1278 
1279   if (!RegionMCDCState->DecisionByStmt.contains(S))
1280     return;
1281 
1282   // Emit intrinsic that resets a dedicated temporary value on the stack to 0.
1283   Builder.CreateStore(Builder.getInt32(0), MCDCCondBitmapAddr);
1284 }
1285 
1286 void CodeGenPGO::emitMCDCCondBitmapUpdate(CGBuilderTy &Builder, const Expr *S,
1287                                           Address MCDCCondBitmapAddr,
1288                                           llvm::Value *Val,
1289                                           CodeGenFunction &CGF) {
1290   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1291     return;
1292 
1293   // Even though, for simplicity, parentheses and unary logical-NOT operators
1294   // are considered part of their underlying condition for both MC/DC and
1295   // branch coverage, the condition IDs themselves are assigned and tracked
1296   // using the underlying condition itself.  This is done solely for
1297   // consistency since parentheses and logical-NOTs are ignored when checking
1298   // whether the condition is actually an instrumentable condition. This can
1299   // also make debugging a bit easier.
1300   S = CodeGenFunction::stripCond(S);
1301 
1302   auto BranchStateIter = RegionMCDCState->BranchByStmt.find(S);
1303   if (BranchStateIter == RegionMCDCState->BranchByStmt.end())
1304     return;
1305 
1306   // Extract the ID of the condition we are setting in the bitmap.
1307   const auto &Branch = BranchStateIter->second;
1308   assert(Branch.ID >= 0 && "Condition has no ID!");
1309   assert(Branch.DecisionStmt);
1310 
1311   // Cancel the emission if the Decision is erased after the allocation.
1312   const auto DecisionIter =
1313       RegionMCDCState->DecisionByStmt.find(Branch.DecisionStmt);
1314   if (DecisionIter == RegionMCDCState->DecisionByStmt.end())
1315     return;
1316 
1317   const auto &TVIdxs = DecisionIter->second.Indices[Branch.ID];
1318 
1319   auto *CurTV = Builder.CreateLoad(MCDCCondBitmapAddr,
1320                                    "mcdc." + Twine(Branch.ID + 1) + ".cur");
1321   auto *NewTV = Builder.CreateAdd(CurTV, Builder.getInt32(TVIdxs[true]));
1322   NewTV = Builder.CreateSelect(
1323       Val, NewTV, Builder.CreateAdd(CurTV, Builder.getInt32(TVIdxs[false])));
1324   Builder.CreateStore(NewTV, MCDCCondBitmapAddr);
1325 }
1326 
1327 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
1328   if (CGM.getCodeGenOpts().hasProfileClangInstr())
1329     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
1330                     uint32_t(EnableValueProfiling));
1331 }
1332 
1333 void CodeGenPGO::setProfileVersion(llvm::Module &M) {
1334   if (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1335       llvm::EnableSingleByteCoverage) {
1336     const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
1337     llvm::Type *IntTy64 = llvm::Type::getInt64Ty(M.getContext());
1338     uint64_t ProfileVersion =
1339         (INSTR_PROF_RAW_VERSION | VARIANT_MASK_BYTE_COVERAGE);
1340 
1341     auto IRLevelVersionVariable = new llvm::GlobalVariable(
1342         M, IntTy64, true, llvm::GlobalValue::WeakAnyLinkage,
1343         llvm::Constant::getIntegerValue(IntTy64,
1344                                         llvm::APInt(64, ProfileVersion)),
1345         VarName);
1346 
1347     IRLevelVersionVariable->setVisibility(llvm::GlobalValue::HiddenVisibility);
1348     llvm::Triple TT(M.getTargetTriple());
1349     if (TT.supportsCOMDAT()) {
1350       IRLevelVersionVariable->setLinkage(llvm::GlobalValue::ExternalLinkage);
1351       IRLevelVersionVariable->setComdat(M.getOrInsertComdat(VarName));
1352     }
1353     IRLevelVersionVariable->setDSOLocal(true);
1354   }
1355 }
1356 
1357 // This method either inserts a call to the profile run-time during
1358 // instrumentation or puts profile data into metadata for PGO use.
1359 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
1360     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
1361 
1362   if (!EnableValueProfiling)
1363     return;
1364 
1365   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
1366     return;
1367 
1368   if (isa<llvm::Constant>(ValuePtr))
1369     return;
1370 
1371   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
1372   if (InstrumentValueSites && RegionCounterMap) {
1373     auto BuilderInsertPoint = Builder.saveIP();
1374     Builder.SetInsertPoint(ValueSite);
1375     llvm::Value *Args[5] = {
1376         FuncNameVar,
1377         Builder.getInt64(FunctionHash),
1378         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1379         Builder.getInt32(ValueKind),
1380         Builder.getInt32(NumValueSites[ValueKind]++)
1381     };
1382     Builder.CreateCall(
1383         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1384     Builder.restoreIP(BuilderInsertPoint);
1385     return;
1386   }
1387 
1388   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1389   if (PGOReader && haveRegionCounts()) {
1390     // We record the top most called three functions at each call site.
1391     // Profile metadata contains "VP" string identifying this metadata
1392     // as value profiling data, then a uint32_t value for the value profiling
1393     // kind, a uint64_t value for the total number of times the call is
1394     // executed, followed by the function hash and execution count (uint64_t)
1395     // pairs for each function.
1396     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1397       return;
1398 
1399     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1400                             (llvm::InstrProfValueKind)ValueKind,
1401                             NumValueSites[ValueKind]);
1402 
1403     NumValueSites[ValueKind]++;
1404   }
1405 }
1406 
1407 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1408                                   bool IsInMainFile) {
1409   CGM.getPGOStats().addVisited(IsInMainFile);
1410   RegionCounts.clear();
1411   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1412       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1413   if (auto E = RecordExpected.takeError()) {
1414     auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E)));
1415     if (IPE == llvm::instrprof_error::unknown_function)
1416       CGM.getPGOStats().addMissing(IsInMainFile);
1417     else if (IPE == llvm::instrprof_error::hash_mismatch)
1418       CGM.getPGOStats().addMismatched(IsInMainFile);
1419     else if (IPE == llvm::instrprof_error::malformed)
1420       // TODO: Consider a more specific warning for this case.
1421       CGM.getPGOStats().addMismatched(IsInMainFile);
1422     return;
1423   }
1424   ProfRecord =
1425       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1426   RegionCounts = ProfRecord->Counts;
1427 }
1428 
/// Compute the divisor used to scale branch weights down to 32 bits.
///
/// \param MaxWeight the largest weight that is going to be scaled.
/// \returns 1 when \p MaxWeight is already below UINT32_MAX; otherwise a
/// divisor large enough that every weight up to \p MaxWeight divided by it is
/// strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1436 
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \param Weight the 64-bit count to scale.
/// \param Scale the non-zero divisor to apply.
/// \returns \c Weight / \c Scale + 1 as a 32-bit value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight (i.e. from the maximum of the weights being scaled).
/// A scale derived from a smaller weight does not guarantee that the result
/// fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The assert above establishes that the value fits; make the narrowing
  // conversion explicit.
  return static_cast<uint32_t>(Scaled);
}
1452 
1453 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1454                                                     uint64_t FalseCount) const {
1455   // Check for empty weights.
1456   if (!TrueCount && !FalseCount)
1457     return nullptr;
1458 
1459   // Calculate how to scale down to 32-bits.
1460   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1461 
1462   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1463   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1464                                       scaleBranchWeight(FalseCount, Scale));
1465 }
1466 
1467 llvm::MDNode *
1468 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1469   // We need at least two elements to create meaningful weights.
1470   if (Weights.size() < 2)
1471     return nullptr;
1472 
1473   // Check for empty weights.
1474   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1475   if (MaxWeight == 0)
1476     return nullptr;
1477 
1478   // Calculate how to scale down to 32-bits.
1479   uint64_t Scale = calculateWeightScale(MaxWeight);
1480 
1481   SmallVector<uint32_t, 16> ScaledWeights;
1482   ScaledWeights.reserve(Weights.size());
1483   for (uint64_t W : Weights)
1484     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1485 
1486   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1487   return MDHelper.createBranchWeights(ScaledWeights);
1488 }
1489 
1490 llvm::MDNode *
1491 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1492                                              uint64_t LoopCount) const {
1493   if (!PGO.haveRegionCounts())
1494     return nullptr;
1495   std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1496   if (!CondCount || *CondCount == 0)
1497     return nullptr;
1498   return createProfileWeights(LoopCount,
1499                               std::max(*CondCount, LoopCount) - LoopCount);
1500 }
1501