xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision 1db9f3b21e39176dd5b67cf8ac378633b172463e)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 #include <optional>
25 
26 static llvm::cl::opt<bool>
27     EnableValueProfiling("enable-value-profiling",
28                          llvm::cl::desc("Enable value profiling"),
29                          llvm::cl::Hidden, llvm::cl::init(false));
30 
31 using namespace clang;
32 using namespace CodeGen;
33 
34 void CodeGenPGO::setFuncName(StringRef Name,
35                              llvm::GlobalValue::LinkageTypes Linkage) {
36   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
37   FuncName = llvm::getPGOFuncName(
38       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
39       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
40 
41   // If we're generating a profile, create a variable for the name.
42   if (CGM.getCodeGenOpts().hasProfileClangInstr())
43     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
44 }
45 
46 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
47   setFuncName(Fn->getName(), Fn->getLinkage());
48   // Create PGOFuncName meta data.
49   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
50 }
51 
/// Versions of the PGO control-flow hash algorithm.
///
/// The numeric values are spelled out because versions are compared with
/// relational operators (e.g. "HashVersion >= PGO_HASH_V2"); never reorder.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Always an alias for the newest version listed above.
  PGO_HASH_LATEST = PGO_HASH_V3
};
61 
62 namespace {
63 /// Stable hasher for PGO region counters.
64 ///
65 /// PGOHash produces a stable hash of a given function's control flow.
66 ///
67 /// Changing the output of this hash will invalidate all previously generated
68 /// profiles -- i.e., don't do it.
69 ///
70 /// \note  When this hash does eventually change (years?), we still need to
71 /// support old hashes.  We'll need to pull in the version number from the
72 /// profile data format and use the matching hash function.
73 class PGOHash {
74   uint64_t Working;
75   unsigned Count;
76   PGOHashVersion HashVersion;
77   llvm::MD5 MD5;
78 
79   static const int NumBitsPerType = 6;
80   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
81   static const unsigned TooBig = 1u << NumBitsPerType;
82 
83 public:
84   /// Hash values for AST nodes.
85   ///
86   /// Distinct values for AST nodes that have region counters attached.
87   ///
88   /// These values must be stable.  All new members must be added at the end,
89   /// and no members should be removed.  Changing the enumeration value for an
90   /// AST node will affect the hash of every function that contains that node.
91   enum HashType : unsigned char {
92     None = 0,
93     LabelStmt = 1,
94     WhileStmt,
95     DoStmt,
96     ForStmt,
97     CXXForRangeStmt,
98     ObjCForCollectionStmt,
99     SwitchStmt,
100     CaseStmt,
101     DefaultStmt,
102     IfStmt,
103     CXXTryStmt,
104     CXXCatchStmt,
105     ConditionalOperator,
106     BinaryOperatorLAnd,
107     BinaryOperatorLOr,
108     BinaryConditionalOperator,
109     // The preceding values are available with PGO_HASH_V1.
110 
111     EndOfScope,
112     IfThenBranch,
113     IfElseBranch,
114     GotoStmt,
115     IndirectGotoStmt,
116     BreakStmt,
117     ContinueStmt,
118     ReturnStmt,
119     ThrowExpr,
120     UnaryOperatorLNot,
121     BinaryOperatorLT,
122     BinaryOperatorGT,
123     BinaryOperatorLE,
124     BinaryOperatorGE,
125     BinaryOperatorEQ,
126     BinaryOperatorNE,
127     // The preceding values are available since PGO_HASH_V2.
128 
129     // Keep this last.  It's for the static assert that follows.
130     LastHashType
131   };
132   static_assert(LastHashType <= TooBig, "Too many types in HashType");
133 
134   PGOHash(PGOHashVersion HashVersion)
135       : Working(0), Count(0), HashVersion(HashVersion) {}
136   void combine(HashType Type);
137   uint64_t finalize();
138   PGOHashVersion getHashVersion() const { return HashVersion; }
139 };
140 const int PGOHash::NumBitsPerType;
141 const unsigned PGOHash::NumTypesPerWord;
142 const unsigned PGOHash::TooBig;
143 
144 /// Get the PGO hash version used in the given indexed profile.
145 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
146                                         CodeGenModule &CGM) {
147   if (PGOReader->getVersion() <= 4)
148     return PGO_HASH_V1;
149   if (PGOReader->getVersion() <= 5)
150     return PGO_HASH_V2;
151   return PGO_HASH_V3;
152 }
153 
154 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
155 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
156   using Base = RecursiveASTVisitor<MapRegionCounters>;
157 
158   /// The next counter value to assign.
159   unsigned NextCounter;
160   /// The function hash.
161   PGOHash Hash;
162   /// The map of statements to counters.
163   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
164   /// The next bitmap byte index to assign.
165   unsigned NextMCDCBitmapIdx;
166   /// The map of statements to MC/DC bitmap coverage objects.
167   llvm::DenseMap<const Stmt *, unsigned> &MCDCBitmapMap;
168   /// Maximum number of supported MC/DC conditions in a boolean expression.
169   unsigned MCDCMaxCond;
170   /// The profile version.
171   uint64_t ProfileVersion;
172   /// Diagnostics Engine used to report warnings.
173   DiagnosticsEngine &Diag;
174 
175   MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
176                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap,
177                     llvm::DenseMap<const Stmt *, unsigned> &MCDCBitmapMap,
178                     unsigned MCDCMaxCond, DiagnosticsEngine &Diag)
179       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
180         NextMCDCBitmapIdx(0), MCDCBitmapMap(MCDCBitmapMap),
181         MCDCMaxCond(MCDCMaxCond), ProfileVersion(ProfileVersion), Diag(Diag) {}
182 
183   // Blocks and lambdas are handled as separate functions, so we need not
184   // traverse them in the parent context.
185   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
186   bool TraverseLambdaExpr(LambdaExpr *LE) {
187     // Traverse the captures, but not the body.
188     for (auto C : zip(LE->captures(), LE->capture_inits()))
189       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
190     return true;
191   }
192   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
193 
194   bool VisitDecl(const Decl *D) {
195     switch (D->getKind()) {
196     default:
197       break;
198     case Decl::Function:
199     case Decl::CXXMethod:
200     case Decl::CXXConstructor:
201     case Decl::CXXDestructor:
202     case Decl::CXXConversion:
203     case Decl::ObjCMethod:
204     case Decl::Block:
205     case Decl::Captured:
206       CounterMap[D->getBody()] = NextCounter++;
207       break;
208     }
209     return true;
210   }
211 
212   /// If \p S gets a fresh counter, update the counter mappings. Return the
213   /// V1 hash of \p S.
214   PGOHash::HashType updateCounterMappings(Stmt *S) {
215     auto Type = getHashType(PGO_HASH_V1, S);
216     if (Type != PGOHash::None)
217       CounterMap[S] = NextCounter++;
218     return Type;
219   }
220 
221   /// The following stacks are used with dataTraverseStmtPre() and
222   /// dataTraverseStmtPost() to track the depth of nested logical operators in a
223   /// boolean expression in a function.  The ultimate purpose is to keep track
224   /// of the number of leaf-level conditions in the boolean expression so that a
225   /// profile bitmap can be allocated based on that number.
226   ///
227   /// The stacks are also used to find error cases and notify the user.  A
228   /// standard logical operator nest for a boolean expression could be in a form
229   /// similar to this: "x = a && b && c && (d || f)"
230   unsigned NumCond = 0;
231   bool SplitNestedLogicalOp = false;
232   SmallVector<const Stmt *, 16> NonLogOpStack;
233   SmallVector<const BinaryOperator *, 16> LogOpStack;
234 
235   // Hook: dataTraverseStmtPre() is invoked prior to visiting an AST Stmt node.
236   bool dataTraverseStmtPre(Stmt *S) {
237     /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
238     if (MCDCMaxCond == 0)
239       return true;
240 
241     /// At the top of the logical operator nest, reset the number of conditions.
242     if (LogOpStack.empty())
243       NumCond = 0;
244 
245     if (const Expr *E = dyn_cast<Expr>(S)) {
246       const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
247       if (BinOp && BinOp->isLogicalOp()) {
248         /// Check for "split-nested" logical operators. This happens when a new
249         /// boolean expression logical-op nest is encountered within an existing
250         /// boolean expression, separated by a non-logical operator.  For
251         /// example, in "x = (a && b && c && foo(d && f))", the "d && f" case
252         /// starts a new boolean expression that is separated from the other
253         /// conditions by the operator foo(). Split-nested cases are not
254         /// supported by MC/DC.
255         SplitNestedLogicalOp = SplitNestedLogicalOp || !NonLogOpStack.empty();
256 
257         LogOpStack.push_back(BinOp);
258         return true;
259       }
260     }
261 
262     /// Keep track of non-logical operators. These are OK as long as we don't
263     /// encounter a new logical operator after seeing one.
264     if (!LogOpStack.empty())
265       NonLogOpStack.push_back(S);
266 
267     return true;
268   }
269 
270   // Hook: dataTraverseStmtPost() is invoked by the AST visitor after visiting
271   // an AST Stmt node.  MC/DC will use it to to signal when the top of a
272   // logical operation (boolean expression) nest is encountered.
273   bool dataTraverseStmtPost(Stmt *S) {
274     /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
275     if (MCDCMaxCond == 0)
276       return true;
277 
278     if (const Expr *E = dyn_cast<Expr>(S)) {
279       const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
280       if (BinOp && BinOp->isLogicalOp()) {
281         assert(LogOpStack.back() == BinOp);
282         LogOpStack.pop_back();
283 
284         /// At the top of logical operator nest:
285         if (LogOpStack.empty()) {
286           /// Was the "split-nested" logical operator case encountered?
287           if (SplitNestedLogicalOp) {
288             unsigned DiagID = Diag.getCustomDiagID(
289                 DiagnosticsEngine::Warning,
290                 "unsupported MC/DC boolean expression; "
291                 "contains an operation with a nested boolean expression. "
292                 "Expression will not be covered");
293             Diag.Report(S->getBeginLoc(), DiagID);
294             return false;
295           }
296 
297           /// Was the maximum number of conditions encountered?
298           if (NumCond > MCDCMaxCond) {
299             unsigned DiagID = Diag.getCustomDiagID(
300                 DiagnosticsEngine::Warning,
301                 "unsupported MC/DC boolean expression; "
302                 "number of conditions (%0) exceeds max (%1). "
303                 "Expression will not be covered");
304             Diag.Report(S->getBeginLoc(), DiagID) << NumCond << MCDCMaxCond;
305             return false;
306           }
307 
308           // Otherwise, allocate the number of bytes required for the bitmap
309           // based on the number of conditions. Must be at least 1-byte long.
310           MCDCBitmapMap[BinOp] = NextMCDCBitmapIdx;
311           unsigned SizeInBits = std::max<unsigned>(1L << NumCond, CHAR_BIT);
312           NextMCDCBitmapIdx += SizeInBits / CHAR_BIT;
313         }
314         return true;
315       }
316     }
317 
318     if (!LogOpStack.empty())
319       NonLogOpStack.pop_back();
320 
321     return true;
322   }
323 
324   /// The RHS of all logical operators gets a fresh counter in order to count
325   /// how many times the RHS evaluates to true or false, depending on the
326   /// semantics of the operator. This is only valid for ">= v7" of the profile
327   /// version so that we facilitate backward compatibility. In addition, in
328   /// order to use MC/DC, count the number of total LHS and RHS conditions.
329   bool VisitBinaryOperator(BinaryOperator *S) {
330     if (S->isLogicalOp()) {
331       if (CodeGenFunction::isInstrumentedCondition(S->getLHS()))
332         NumCond++;
333 
334       if (CodeGenFunction::isInstrumentedCondition(S->getRHS())) {
335         if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
336           CounterMap[S->getRHS()] = NextCounter++;
337 
338         NumCond++;
339       }
340     }
341     return Base::VisitBinaryOperator(S);
342   }
343 
344   /// Include \p S in the function hash.
345   bool VisitStmt(Stmt *S) {
346     auto Type = updateCounterMappings(S);
347     if (Hash.getHashVersion() != PGO_HASH_V1)
348       Type = getHashType(Hash.getHashVersion(), S);
349     if (Type != PGOHash::None)
350       Hash.combine(Type);
351     return true;
352   }
353 
354   bool TraverseIfStmt(IfStmt *If) {
355     // If we used the V1 hash, use the default traversal.
356     if (Hash.getHashVersion() == PGO_HASH_V1)
357       return Base::TraverseIfStmt(If);
358 
359     // Otherwise, keep track of which branch we're in while traversing.
360     VisitStmt(If);
361     for (Stmt *CS : If->children()) {
362       if (!CS)
363         continue;
364       if (CS == If->getThen())
365         Hash.combine(PGOHash::IfThenBranch);
366       else if (CS == If->getElse())
367         Hash.combine(PGOHash::IfElseBranch);
368       TraverseStmt(CS);
369     }
370     Hash.combine(PGOHash::EndOfScope);
371     return true;
372   }
373 
374 // If the statement type \p N is nestable, and its nesting impacts profile
375 // stability, define a custom traversal which tracks the end of the statement
376 // in the hash (provided we're not using the V1 hash).
377 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
378   bool Traverse##N(N *S) {                                                     \
379     Base::Traverse##N(S);                                                      \
380     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
381       Hash.combine(PGOHash::EndOfScope);                                       \
382     return true;                                                               \
383   }
384 
385   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
386   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
387   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
388   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
389   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
390   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
391   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
392 
393   /// Get version \p HashVersion of the PGO hash for \p S.
394   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
395     switch (S->getStmtClass()) {
396     default:
397       break;
398     case Stmt::LabelStmtClass:
399       return PGOHash::LabelStmt;
400     case Stmt::WhileStmtClass:
401       return PGOHash::WhileStmt;
402     case Stmt::DoStmtClass:
403       return PGOHash::DoStmt;
404     case Stmt::ForStmtClass:
405       return PGOHash::ForStmt;
406     case Stmt::CXXForRangeStmtClass:
407       return PGOHash::CXXForRangeStmt;
408     case Stmt::ObjCForCollectionStmtClass:
409       return PGOHash::ObjCForCollectionStmt;
410     case Stmt::SwitchStmtClass:
411       return PGOHash::SwitchStmt;
412     case Stmt::CaseStmtClass:
413       return PGOHash::CaseStmt;
414     case Stmt::DefaultStmtClass:
415       return PGOHash::DefaultStmt;
416     case Stmt::IfStmtClass:
417       return PGOHash::IfStmt;
418     case Stmt::CXXTryStmtClass:
419       return PGOHash::CXXTryStmt;
420     case Stmt::CXXCatchStmtClass:
421       return PGOHash::CXXCatchStmt;
422     case Stmt::ConditionalOperatorClass:
423       return PGOHash::ConditionalOperator;
424     case Stmt::BinaryConditionalOperatorClass:
425       return PGOHash::BinaryConditionalOperator;
426     case Stmt::BinaryOperatorClass: {
427       const BinaryOperator *BO = cast<BinaryOperator>(S);
428       if (BO->getOpcode() == BO_LAnd)
429         return PGOHash::BinaryOperatorLAnd;
430       if (BO->getOpcode() == BO_LOr)
431         return PGOHash::BinaryOperatorLOr;
432       if (HashVersion >= PGO_HASH_V2) {
433         switch (BO->getOpcode()) {
434         default:
435           break;
436         case BO_LT:
437           return PGOHash::BinaryOperatorLT;
438         case BO_GT:
439           return PGOHash::BinaryOperatorGT;
440         case BO_LE:
441           return PGOHash::BinaryOperatorLE;
442         case BO_GE:
443           return PGOHash::BinaryOperatorGE;
444         case BO_EQ:
445           return PGOHash::BinaryOperatorEQ;
446         case BO_NE:
447           return PGOHash::BinaryOperatorNE;
448         }
449       }
450       break;
451     }
452     }
453 
454     if (HashVersion >= PGO_HASH_V2) {
455       switch (S->getStmtClass()) {
456       default:
457         break;
458       case Stmt::GotoStmtClass:
459         return PGOHash::GotoStmt;
460       case Stmt::IndirectGotoStmtClass:
461         return PGOHash::IndirectGotoStmt;
462       case Stmt::BreakStmtClass:
463         return PGOHash::BreakStmt;
464       case Stmt::ContinueStmtClass:
465         return PGOHash::ContinueStmt;
466       case Stmt::ReturnStmtClass:
467         return PGOHash::ReturnStmt;
468       case Stmt::CXXThrowExprClass:
469         return PGOHash::ThrowExpr;
470       case Stmt::UnaryOperatorClass: {
471         const UnaryOperator *UO = cast<UnaryOperator>(S);
472         if (UO->getOpcode() == UO_LNot)
473           return PGOHash::UnaryOperatorLNot;
474         break;
475       }
476       }
477     }
478 
479     return PGOHash::None;
480   }
481 };
482 
483 /// A StmtVisitor that propagates the raw counts through the AST and
484 /// records the count at statements where the value may change.
485 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
486   /// PGO state.
487   CodeGenPGO &PGO;
488 
489   /// A flag that is set when the current count should be recorded on the
490   /// next statement, such as at the exit of a loop.
491   bool RecordNextStmtCount;
492 
493   /// The count at the current location in the traversal.
494   uint64_t CurrentCount;
495 
496   /// The map of statements to count values.
497   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
498 
499   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
500   struct BreakContinue {
501     uint64_t BreakCount = 0;
502     uint64_t ContinueCount = 0;
503     BreakContinue() = default;
504   };
505   SmallVector<BreakContinue, 8> BreakContinueStack;
506 
507   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
508                       CodeGenPGO &PGO)
509       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
510 
511   void RecordStmtCount(const Stmt *S) {
512     if (RecordNextStmtCount) {
513       CountMap[S] = CurrentCount;
514       RecordNextStmtCount = false;
515     }
516   }
517 
518   /// Set and return the current count.
519   uint64_t setCount(uint64_t Count) {
520     CurrentCount = Count;
521     return Count;
522   }
523 
524   void VisitStmt(const Stmt *S) {
525     RecordStmtCount(S);
526     for (const Stmt *Child : S->children())
527       if (Child)
528         this->Visit(Child);
529   }
530 
531   void VisitFunctionDecl(const FunctionDecl *D) {
532     // Counter tracks entry to the function body.
533     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
534     CountMap[D->getBody()] = BodyCount;
535     Visit(D->getBody());
536   }
537 
538   // Skip lambda expressions. We visit these as FunctionDecls when we're
539   // generating them and aren't interested in the body when generating a
540   // parent context.
541   void VisitLambdaExpr(const LambdaExpr *LE) {}
542 
543   void VisitCapturedDecl(const CapturedDecl *D) {
544     // Counter tracks entry to the capture body.
545     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
546     CountMap[D->getBody()] = BodyCount;
547     Visit(D->getBody());
548   }
549 
550   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
551     // Counter tracks entry to the method body.
552     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
553     CountMap[D->getBody()] = BodyCount;
554     Visit(D->getBody());
555   }
556 
557   void VisitBlockDecl(const BlockDecl *D) {
558     // Counter tracks entry to the block body.
559     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
560     CountMap[D->getBody()] = BodyCount;
561     Visit(D->getBody());
562   }
563 
564   void VisitReturnStmt(const ReturnStmt *S) {
565     RecordStmtCount(S);
566     if (S->getRetValue())
567       Visit(S->getRetValue());
568     CurrentCount = 0;
569     RecordNextStmtCount = true;
570   }
571 
572   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
573     RecordStmtCount(E);
574     if (E->getSubExpr())
575       Visit(E->getSubExpr());
576     CurrentCount = 0;
577     RecordNextStmtCount = true;
578   }
579 
580   void VisitGotoStmt(const GotoStmt *S) {
581     RecordStmtCount(S);
582     CurrentCount = 0;
583     RecordNextStmtCount = true;
584   }
585 
586   void VisitLabelStmt(const LabelStmt *S) {
587     RecordNextStmtCount = false;
588     // Counter tracks the block following the label.
589     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
590     CountMap[S] = BlockCount;
591     Visit(S->getSubStmt());
592   }
593 
594   void VisitBreakStmt(const BreakStmt *S) {
595     RecordStmtCount(S);
596     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
597     BreakContinueStack.back().BreakCount += CurrentCount;
598     CurrentCount = 0;
599     RecordNextStmtCount = true;
600   }
601 
602   void VisitContinueStmt(const ContinueStmt *S) {
603     RecordStmtCount(S);
604     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
605     BreakContinueStack.back().ContinueCount += CurrentCount;
606     CurrentCount = 0;
607     RecordNextStmtCount = true;
608   }
609 
610   void VisitWhileStmt(const WhileStmt *S) {
611     RecordStmtCount(S);
612     uint64_t ParentCount = CurrentCount;
613 
614     BreakContinueStack.push_back(BreakContinue());
615     // Visit the body region first so the break/continue adjustments can be
616     // included when visiting the condition.
617     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
618     CountMap[S->getBody()] = CurrentCount;
619     Visit(S->getBody());
620     uint64_t BackedgeCount = CurrentCount;
621 
622     // ...then go back and propagate counts through the condition. The count
623     // at the start of the condition is the sum of the incoming edges,
624     // the backedge from the end of the loop body, and the edges from
625     // continue statements.
626     BreakContinue BC = BreakContinueStack.pop_back_val();
627     uint64_t CondCount =
628         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
629     CountMap[S->getCond()] = CondCount;
630     Visit(S->getCond());
631     setCount(BC.BreakCount + CondCount - BodyCount);
632     RecordNextStmtCount = true;
633   }
634 
635   void VisitDoStmt(const DoStmt *S) {
636     RecordStmtCount(S);
637     uint64_t LoopCount = PGO.getRegionCount(S);
638 
639     BreakContinueStack.push_back(BreakContinue());
640     // The count doesn't include the fallthrough from the parent scope. Add it.
641     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
642     CountMap[S->getBody()] = BodyCount;
643     Visit(S->getBody());
644     uint64_t BackedgeCount = CurrentCount;
645 
646     BreakContinue BC = BreakContinueStack.pop_back_val();
647     // The count at the start of the condition is equal to the count at the
648     // end of the body, plus any continues.
649     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
650     CountMap[S->getCond()] = CondCount;
651     Visit(S->getCond());
652     setCount(BC.BreakCount + CondCount - LoopCount);
653     RecordNextStmtCount = true;
654   }
655 
656   void VisitForStmt(const ForStmt *S) {
657     RecordStmtCount(S);
658     if (S->getInit())
659       Visit(S->getInit());
660 
661     uint64_t ParentCount = CurrentCount;
662 
663     BreakContinueStack.push_back(BreakContinue());
664     // Visit the body region first. (This is basically the same as a while
665     // loop; see further comments in VisitWhileStmt.)
666     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
667     CountMap[S->getBody()] = BodyCount;
668     Visit(S->getBody());
669     uint64_t BackedgeCount = CurrentCount;
670     BreakContinue BC = BreakContinueStack.pop_back_val();
671 
672     // The increment is essentially part of the body but it needs to include
673     // the count for all the continue statements.
674     if (S->getInc()) {
675       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
676       CountMap[S->getInc()] = IncCount;
677       Visit(S->getInc());
678     }
679 
680     // ...then go back and propagate counts through the condition.
681     uint64_t CondCount =
682         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
683     if (S->getCond()) {
684       CountMap[S->getCond()] = CondCount;
685       Visit(S->getCond());
686     }
687     setCount(BC.BreakCount + CondCount - BodyCount);
688     RecordNextStmtCount = true;
689   }
690 
691   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
692     RecordStmtCount(S);
693     if (S->getInit())
694       Visit(S->getInit());
695     Visit(S->getLoopVarStmt());
696     Visit(S->getRangeStmt());
697     Visit(S->getBeginStmt());
698     Visit(S->getEndStmt());
699 
700     uint64_t ParentCount = CurrentCount;
701     BreakContinueStack.push_back(BreakContinue());
702     // Visit the body region first. (This is basically the same as a while
703     // loop; see further comments in VisitWhileStmt.)
704     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
705     CountMap[S->getBody()] = BodyCount;
706     Visit(S->getBody());
707     uint64_t BackedgeCount = CurrentCount;
708     BreakContinue BC = BreakContinueStack.pop_back_val();
709 
710     // The increment is essentially part of the body but it needs to include
711     // the count for all the continue statements.
712     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
713     CountMap[S->getInc()] = IncCount;
714     Visit(S->getInc());
715 
716     // ...then go back and propagate counts through the condition.
717     uint64_t CondCount =
718         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
719     CountMap[S->getCond()] = CondCount;
720     Visit(S->getCond());
721     setCount(BC.BreakCount + CondCount - BodyCount);
722     RecordNextStmtCount = true;
723   }
724 
725   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
726     RecordStmtCount(S);
727     Visit(S->getElement());
728     uint64_t ParentCount = CurrentCount;
729     BreakContinueStack.push_back(BreakContinue());
730     // Counter tracks the body of the loop.
731     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
732     CountMap[S->getBody()] = BodyCount;
733     Visit(S->getBody());
734     uint64_t BackedgeCount = CurrentCount;
735     BreakContinue BC = BreakContinueStack.pop_back_val();
736 
737     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
738              BodyCount);
739     RecordNextStmtCount = true;
740   }
741 
742   void VisitSwitchStmt(const SwitchStmt *S) {
743     RecordStmtCount(S);
744     if (S->getInit())
745       Visit(S->getInit());
746     Visit(S->getCond());
747     CurrentCount = 0;
748     BreakContinueStack.push_back(BreakContinue());
749     Visit(S->getBody());
750     // If the switch is inside a loop, add the continue counts.
751     BreakContinue BC = BreakContinueStack.pop_back_val();
752     if (!BreakContinueStack.empty())
753       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
754     // Counter tracks the exit block of the switch.
755     setCount(PGO.getRegionCount(S));
756     RecordNextStmtCount = true;
757   }
758 
759   void VisitSwitchCase(const SwitchCase *S) {
760     RecordNextStmtCount = false;
761     // Counter for this particular case. This counts only jumps from the
762     // switch header and does not include fallthrough from the case before
763     // this one.
764     uint64_t CaseCount = PGO.getRegionCount(S);
765     setCount(CurrentCount + CaseCount);
766     // We need the count without fallthrough in the mapping, so it's more useful
767     // for branch probabilities.
768     CountMap[S] = CaseCount;
769     RecordNextStmtCount = true;
770     Visit(S->getSubStmt());
771   }
772 
773   void VisitIfStmt(const IfStmt *S) {
774     RecordStmtCount(S);
775 
776     if (S->isConsteval()) {
777       const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
778       if (Stm)
779         Visit(Stm);
780       return;
781     }
782 
783     uint64_t ParentCount = CurrentCount;
784     if (S->getInit())
785       Visit(S->getInit());
786     Visit(S->getCond());
787 
788     // Counter tracks the "then" part of an if statement. The count for
789     // the "else" part, if it exists, will be calculated from this counter.
790     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
791     CountMap[S->getThen()] = ThenCount;
792     Visit(S->getThen());
793     uint64_t OutCount = CurrentCount;
794 
795     uint64_t ElseCount = ParentCount - ThenCount;
796     if (S->getElse()) {
797       setCount(ElseCount);
798       CountMap[S->getElse()] = ElseCount;
799       Visit(S->getElse());
800       OutCount += CurrentCount;
801     } else
802       OutCount += ElseCount;
803     setCount(OutCount);
804     RecordNextStmtCount = true;
805   }
806 
807   void VisitCXXTryStmt(const CXXTryStmt *S) {
808     RecordStmtCount(S);
809     Visit(S->getTryBlock());
810     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
811       Visit(S->getHandler(I));
812     // Counter tracks the continuation block of the try statement.
813     setCount(PGO.getRegionCount(S));
814     RecordNextStmtCount = true;
815   }
816 
817   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
818     RecordNextStmtCount = false;
819     // Counter tracks the catch statement's handler block.
820     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
821     CountMap[S] = CatchCount;
822     Visit(S->getHandlerBlock());
823   }
824 
  /// Propagate execution counts through a conditional operator.
  ///
  /// The operator's region counter gives the "true" count. The "false" count
  /// is derived as (count on entry - "true" count). The count afterwards is the
  /// sum of what flows out of both arms.
  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    // Count on entry, captured before the condition is visited.
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    // NOTE(review): this is an unsigned subtraction; it wraps if TrueCount
    // exceeds ParentCount in an inconsistent profile.
    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }
845 
  /// Propagate execution counts through a logical-and (`&&`).
  ///
  /// The operator's region counter counts evaluations of the right-hand side.
  /// Afterwards the count is ParentCount + RHSCount - CurrentCount. If visiting
  /// the RHS leaves CurrentCount equal to RHSCount, this is exactly the count
  /// on entry.
  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    // Adjust the entry count by however much the RHS visit changed the count
    // relative to RHSCount.
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
857 
  /// Propagate execution counts through a logical-or (`||`).
  ///
  /// The operator's region counter counts evaluations of the right-hand side.
  /// Afterwards the count is ParentCount + RHSCount - CurrentCount. If visiting
  /// the RHS leaves CurrentCount equal to RHSCount, this is exactly the count
  /// on entry.
  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    // Adjust the entry count by however much the RHS visit changed the count
    // relative to RHSCount.
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
869 };
870 } // end anonymous namespace
871 
/// Fold one node type into the running structural hash.
///
/// Types are packed NumBitsPerType bits at a time into the 64-bit Working
/// word. The word is flushed into MD5 lazily, at the start of the call that
/// follows a full word (Count a nonzero multiple of NumTypesPerWord). The last
/// word is therefore always left in Working for finalize().
void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    // Hash the word as little-endian bytes, so the digest does not depend on
    // the host's byte order.
    uint64_t Swapped =
        endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
    MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}
890 
/// Return the 64-bit structural hash accumulated by combine().
///
/// If at most NumTypesPerWord types were combined, MD5 was never used and the
/// packed Working word itself is the hash. Otherwise any pending non-zero word
/// is fed to MD5, and the low 64 bits of the digest are returned.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working) {
    // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
    // is buggy because it converts a uint64_t into an array of uint8_t.
    // Concretely, (uint8_t)Working truncates the word to its low byte, so only
    // that single byte is hashed. Changing this would change the hashes of
    // existing v1/v2 profiles.
    if (HashVersion < PGO_HASH_V3) {
      MD5.update({(uint8_t)Working});
    } else {
      using namespace llvm::support;
      uint64_t Swapped =
          endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
      MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    }
  }

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  return Result.low();
}
918 
919 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
920   const Decl *D = GD.getDecl();
921   if (!D->hasBody())
922     return;
923 
924   // Skip CUDA/HIP kernel launch stub functions.
925   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
926       D->hasAttr<CUDAGlobalAttr>())
927     return;
928 
929   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
930   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
931   if (!InstrumentRegions && !PGOReader)
932     return;
933   if (D->isImplicit())
934     return;
935   // Constructors and destructors may be represented by several functions in IR.
936   // If so, instrument only base variant, others are implemented by delegation
937   // to the base one, it would be counted twice otherwise.
938   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
939     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
940       if (GD.getCtorType() != Ctor_Base &&
941           CodeGenFunction::IsConstructorDelegationValid(CCD))
942         return;
943   }
944   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
945     return;
946 
947   CGM.ClearUnusedCoverageMapping(D);
948   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
949     return;
950   if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
951     return;
952 
953   setFuncName(Fn);
954 
955   mapRegionCounters(D);
956   if (CGM.getCodeGenOpts().CoverageMapping)
957     emitCounterRegionMapping(D);
958   if (PGOReader) {
959     SourceManager &SM = CGM.getContext().getSourceManager();
960     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
961     computeRegionCounts(D);
962     applyFunctionAttributes(PGOReader, Fn);
963   }
964 }
965 
966 void CodeGenPGO::mapRegionCounters(const Decl *D) {
967   // Use the latest hash version when inserting instrumentation, but use the
968   // version in the indexed profile if we're reading PGO data.
969   PGOHashVersion HashVersion = PGO_HASH_LATEST;
970   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
971   if (auto *PGOReader = CGM.getPGOReader()) {
972     HashVersion = getPGOHashVersion(PGOReader, CGM);
973     ProfileVersion = PGOReader->getVersion();
974   }
975 
976   // If MC/DC is enabled, set the MaxConditions to a preset value. Otherwise,
977   // set it to zero. This value impacts the number of conditions accepted in a
978   // given boolean expression, which impacts the size of the bitmap used to
979   // track test vector execution for that boolean expression.  Because the
980   // bitmap scales exponentially (2^n) based on the number of conditions seen,
981   // the maximum value is hard-coded at 6 conditions, which is more than enough
982   // for most embedded applications. Setting a maximum value prevents the
983   // bitmap footprint from growing too large without the user's knowledge. In
984   // the future, this value could be adjusted with a command-line option.
985   unsigned MCDCMaxConditions = (CGM.getCodeGenOpts().MCDCCoverage) ? 6 : 0;
986 
987   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
988   RegionMCDCBitmapMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
989   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap,
990                            *RegionMCDCBitmapMap, MCDCMaxConditions,
991                            CGM.getDiags());
992   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
993     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
994   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
995     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
996   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
997     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
998   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
999     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
1000   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
1001   NumRegionCounters = Walker.NextCounter;
1002   MCDCBitmapBytes = Walker.NextMCDCBitmapIdx;
1003   FunctionHash = Walker.Hash.finalize();
1004 }
1005 
1006 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
1007   if (!D->getBody())
1008     return true;
1009 
1010   // Skip host-only functions in the CUDA device compilation and device-only
1011   // functions in the host compilation. Just roughly filter them out based on
1012   // the function attributes. If there are effectively host-only or device-only
1013   // ones, their coverage mapping may still be generated.
1014   if (CGM.getLangOpts().CUDA &&
1015       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
1016         !D->hasAttr<CUDAGlobalAttr>()) ||
1017        (!CGM.getLangOpts().CUDAIsDevice &&
1018         (D->hasAttr<CUDAGlobalAttr>() ||
1019          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
1020     return true;
1021 
1022   // Don't map the functions in system headers.
1023   const auto &SM = CGM.getContext().getSourceManager();
1024   auto Loc = D->getBody()->getBeginLoc();
1025   return SM.isInSystemHeader(Loc);
1026 }
1027 
1028 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
1029   if (skipRegionMappingForDecl(D))
1030     return;
1031 
1032   std::string CoverageMapping;
1033   llvm::raw_string_ostream OS(CoverageMapping);
1034   RegionCondIDMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
1035   CoverageMappingGen MappingGen(
1036       *CGM.getCoverageMapping(), CGM.getContext().getSourceManager(),
1037       CGM.getLangOpts(), RegionCounterMap.get(), RegionMCDCBitmapMap.get(),
1038       RegionCondIDMap.get());
1039   MappingGen.emitCounterMapping(D, OS);
1040   OS.flush();
1041 
1042   if (CoverageMapping.empty())
1043     return;
1044 
1045   CGM.getCoverageMapping()->addFunctionMappingRecord(
1046       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
1047 }
1048 
1049 void
1050 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
1051                                     llvm::GlobalValue::LinkageTypes Linkage) {
1052   if (skipRegionMappingForDecl(D))
1053     return;
1054 
1055   std::string CoverageMapping;
1056   llvm::raw_string_ostream OS(CoverageMapping);
1057   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
1058                                 CGM.getContext().getSourceManager(),
1059                                 CGM.getLangOpts());
1060   MappingGen.emitEmptyMapping(D, OS);
1061   OS.flush();
1062 
1063   if (CoverageMapping.empty())
1064     return;
1065 
1066   setFuncName(Name, Linkage);
1067   CGM.getCoverageMapping()->addFunctionMappingRecord(
1068       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
1069 }
1070 
1071 void CodeGenPGO::computeRegionCounts(const Decl *D) {
1072   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
1073   ComputeRegionCounts Walker(*StmtCountMap, *this);
1074   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1075     Walker.VisitFunctionDecl(FD);
1076   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1077     Walker.VisitObjCMethodDecl(MD);
1078   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1079     Walker.VisitBlockDecl(BD);
1080   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1081     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
1082 }
1083 
1084 void
1085 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
1086                                     llvm::Function *Fn) {
1087   if (!haveRegionCounts())
1088     return;
1089 
1090   uint64_t FunctionCount = getRegionCount(nullptr);
1091   Fn->setEntryCount(FunctionCount);
1092 }
1093 
1094 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
1095                                       llvm::Value *StepV) {
1096   if (!RegionCounterMap || !Builder.GetInsertBlock())
1097     return;
1098 
1099   unsigned Counter = (*RegionCounterMap)[S];
1100 
1101   llvm::Value *Args[] = {FuncNameVar,
1102                          Builder.getInt64(FunctionHash),
1103                          Builder.getInt32(NumRegionCounters),
1104                          Builder.getInt32(Counter), StepV};
1105   if (!StepV)
1106     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
1107                        ArrayRef(Args, 4));
1108   else
1109     Builder.CreateCall(
1110         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
1111         ArrayRef(Args));
1112 }
1113 
1114 bool CodeGenPGO::canEmitMCDCCoverage(const CGBuilderTy &Builder) {
1115   return (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1116           CGM.getCodeGenOpts().MCDCCoverage && Builder.GetInsertBlock());
1117 }
1118 
1119 void CodeGenPGO::emitMCDCParameters(CGBuilderTy &Builder) {
1120   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1121     return;
1122 
1123   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1124 
1125   // Emit intrinsic representing MCDC bitmap parameters at function entry.
1126   // This is used by the instrumentation pass, but it isn't actually lowered to
1127   // anything.
1128   llvm::Value *Args[3] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1129                           Builder.getInt64(FunctionHash),
1130                           Builder.getInt32(MCDCBitmapBytes)};
1131   Builder.CreateCall(
1132       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_parameters), Args);
1133 }
1134 
1135 void CodeGenPGO::emitMCDCTestVectorBitmapUpdate(CGBuilderTy &Builder,
1136                                                 const Expr *S,
1137                                                 Address MCDCCondBitmapAddr) {
1138   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1139     return;
1140 
1141   S = S->IgnoreParens();
1142 
1143   auto ExprMCDCBitmapMapIterator = RegionMCDCBitmapMap->find(S);
1144   if (ExprMCDCBitmapMapIterator == RegionMCDCBitmapMap->end())
1145     return;
1146 
1147   // Extract the ID of the global bitmap associated with this expression.
1148   unsigned MCDCTestVectorBitmapID = ExprMCDCBitmapMapIterator->second;
1149   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1150 
1151   // Emit intrinsic responsible for updating the global bitmap corresponding to
1152   // a boolean expression. The index being set is based on the value loaded
1153   // from a pointer to a dedicated temporary value on the stack that is itself
1154   // updated via emitMCDCCondBitmapReset() and emitMCDCCondBitmapUpdate(). The
1155   // index represents an executed test vector.
1156   llvm::Value *Args[5] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1157                           Builder.getInt64(FunctionHash),
1158                           Builder.getInt32(MCDCBitmapBytes),
1159                           Builder.getInt32(MCDCTestVectorBitmapID),
1160                           MCDCCondBitmapAddr.getPointer()};
1161   Builder.CreateCall(
1162       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_tvbitmap_update), Args);
1163 }
1164 
1165 void CodeGenPGO::emitMCDCCondBitmapReset(CGBuilderTy &Builder, const Expr *S,
1166                                          Address MCDCCondBitmapAddr) {
1167   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1168     return;
1169 
1170   S = S->IgnoreParens();
1171 
1172   if (RegionMCDCBitmapMap->find(S) == RegionMCDCBitmapMap->end())
1173     return;
1174 
1175   // Emit intrinsic that resets a dedicated temporary value on the stack to 0.
1176   Builder.CreateStore(Builder.getInt32(0), MCDCCondBitmapAddr);
1177 }
1178 
1179 void CodeGenPGO::emitMCDCCondBitmapUpdate(CGBuilderTy &Builder, const Expr *S,
1180                                           Address MCDCCondBitmapAddr,
1181                                           llvm::Value *Val) {
1182   if (!canEmitMCDCCoverage(Builder) || !RegionCondIDMap)
1183     return;
1184 
1185   // Even though, for simplicity, parentheses and unary logical-NOT operators
1186   // are considered part of their underlying condition for both MC/DC and
1187   // branch coverage, the condition IDs themselves are assigned and tracked
1188   // using the underlying condition itself.  This is done solely for
1189   // consistency since parentheses and logical-NOTs are ignored when checking
1190   // whether the condition is actually an instrumentable condition. This can
1191   // also make debugging a bit easier.
1192   S = CodeGenFunction::stripCond(S);
1193 
1194   auto ExprMCDCConditionIDMapIterator = RegionCondIDMap->find(S);
1195   if (ExprMCDCConditionIDMapIterator == RegionCondIDMap->end())
1196     return;
1197 
1198   // Extract the ID of the condition we are setting in the bitmap.
1199   unsigned CondID = ExprMCDCConditionIDMapIterator->second;
1200   assert(CondID > 0 && "Condition has no ID!");
1201 
1202   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1203 
1204   // Emit intrinsic that updates a dedicated temporary value on the stack after
1205   // a condition is evaluated. After the set of conditions has been updated,
1206   // the resulting value is used to update the boolean expression's bitmap.
1207   llvm::Value *Args[5] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1208                           Builder.getInt64(FunctionHash),
1209                           Builder.getInt32(CondID - 1),
1210                           MCDCCondBitmapAddr.getPointer(), Val};
1211   Builder.CreateCall(
1212       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_condbitmap_update),
1213       Args);
1214 }
1215 
1216 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
1217   if (CGM.getCodeGenOpts().hasProfileClangInstr())
1218     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
1219                     uint32_t(EnableValueProfiling));
1220 }
1221 
1222 // This method either inserts a call to the profile run-time during
1223 // instrumentation or puts profile data into metadata for PGO use.
1224 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
1225     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
1226 
1227   if (!EnableValueProfiling)
1228     return;
1229 
1230   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
1231     return;
1232 
1233   if (isa<llvm::Constant>(ValuePtr))
1234     return;
1235 
1236   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
1237   if (InstrumentValueSites && RegionCounterMap) {
1238     auto BuilderInsertPoint = Builder.saveIP();
1239     Builder.SetInsertPoint(ValueSite);
1240     llvm::Value *Args[5] = {
1241         FuncNameVar,
1242         Builder.getInt64(FunctionHash),
1243         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1244         Builder.getInt32(ValueKind),
1245         Builder.getInt32(NumValueSites[ValueKind]++)
1246     };
1247     Builder.CreateCall(
1248         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1249     Builder.restoreIP(BuilderInsertPoint);
1250     return;
1251   }
1252 
1253   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1254   if (PGOReader && haveRegionCounts()) {
1255     // We record the top most called three functions at each call site.
1256     // Profile metadata contains "VP" string identifying this metadata
1257     // as value profiling data, then a uint32_t value for the value profiling
1258     // kind, a uint64_t value for the total number of times the call is
1259     // executed, followed by the function hash and execution count (uint64_t)
1260     // pairs for each function.
1261     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1262       return;
1263 
1264     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1265                             (llvm::InstrProfValueKind)ValueKind,
1266                             NumValueSites[ValueKind]);
1267 
1268     NumValueSites[ValueKind]++;
1269   }
1270 }
1271 
1272 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1273                                   bool IsInMainFile) {
1274   CGM.getPGOStats().addVisited(IsInMainFile);
1275   RegionCounts.clear();
1276   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1277       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1278   if (auto E = RecordExpected.takeError()) {
1279     auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E)));
1280     if (IPE == llvm::instrprof_error::unknown_function)
1281       CGM.getPGOStats().addMissing(IsInMainFile);
1282     else if (IPE == llvm::instrprof_error::hash_mismatch)
1283       CGM.getPGOStats().addMismatched(IsInMainFile);
1284     else if (IPE == llvm::instrprof_error::malformed)
1285       // TODO: Consider a more specific warning for this case.
1286       CGM.getPGOStats().addMismatched(IsInMainFile);
1287     return;
1288   }
1289   ProfRecord =
1290       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1291   RegionCounts = ProfRecord->Counts;
1292 }
1293 
/// Compute the divisor used to narrow 64-bit weights to 32 bits.
///
/// Given the largest weight, returns the smallest divisor used by this scheme
/// such that every weight divided by it is strictly less than UINT32_MAX.
/// Weights that already fit need no scaling (divisor 1).
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight >= UINT32_MAX)
    return MaxWeight / UINT32_MAX + 1;
  return 1;
}
1301 
/// Narrow one 64-bit branch weight to 32 bits and add one.
///
/// The weight is divided by \c Scale. One is then added to every weight
/// (Laplace's Rule of Succession: estimating from count + 1 is preferable to
/// the raw count).
///
/// \pre \c Scale came from \a calculateWeightScale() applied to a weight no
/// smaller than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Narrowed = Weight / Scale + 1;
  assert(Narrowed <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Narrowed);
}
1317 
1318 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1319                                                     uint64_t FalseCount) const {
1320   // Check for empty weights.
1321   if (!TrueCount && !FalseCount)
1322     return nullptr;
1323 
1324   // Calculate how to scale down to 32-bits.
1325   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1326 
1327   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1328   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1329                                       scaleBranchWeight(FalseCount, Scale));
1330 }
1331 
1332 llvm::MDNode *
1333 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1334   // We need at least two elements to create meaningful weights.
1335   if (Weights.size() < 2)
1336     return nullptr;
1337 
1338   // Check for empty weights.
1339   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1340   if (MaxWeight == 0)
1341     return nullptr;
1342 
1343   // Calculate how to scale down to 32-bits.
1344   uint64_t Scale = calculateWeightScale(MaxWeight);
1345 
1346   SmallVector<uint32_t, 16> ScaledWeights;
1347   ScaledWeights.reserve(Weights.size());
1348   for (uint64_t W : Weights)
1349     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1350 
1351   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1352   return MDHelper.createBranchWeights(ScaledWeights);
1353 }
1354 
1355 llvm::MDNode *
1356 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1357                                              uint64_t LoopCount) const {
1358   if (!PGO.haveRegionCounts())
1359     return nullptr;
1360   std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1361   if (!CondCount || *CondCount == 0)
1362     return nullptr;
1363   return createProfileWeights(LoopCount,
1364                               std::max(*CondCount, LoopCount) - LoopCount);
1365 }
1366