xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision 2f9966ff63d65bd474478888c9088eeae3f9c669)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 #include <optional>
25 
// Hidden boolean command-line option "enable-value-profiling"; it is
// initialized to false.
static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling",
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));

// Declaration only; the definition of this option is not in the visible part
// of this file.
extern llvm::cl::opt<bool> SystemHeadersCoverage;
32 
33 using namespace clang;
34 using namespace CodeGen;
35 
36 void CodeGenPGO::setFuncName(StringRef Name,
37                              llvm::GlobalValue::LinkageTypes Linkage) {
38   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
39   FuncName = llvm::getPGOFuncName(
40       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
41       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
42 
43   // If we're generating a profile, create a variable for the name.
44   if (CGM.getCodeGenOpts().hasProfileClangInstr())
45     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
46 }
47 
48 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
49   setFuncName(Fn->getName(), Fn->getLinkage());
50   // Create PGOFuncName meta data.
51   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
52 }
53 
/// Versions of the PGO hash algorithm, numbered consecutively from zero.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Alias for the newest version above; keep it pointing at the last entry.
  PGO_HASH_LATEST = PGO_HASH_V3
};
63 
64 namespace {
65 /// Stable hasher for PGO region counters.
66 ///
67 /// PGOHash produces a stable hash of a given function's control flow.
68 ///
69 /// Changing the output of this hash will invalidate all previously generated
70 /// profiles -- i.e., don't do it.
71 ///
72 /// \note  When this hash does eventually change (years?), we still need to
73 /// support old hashes.  We'll need to pull in the version number from the
74 /// profile data format and use the matching hash function.
75 class PGOHash {
76   uint64_t Working;
77   unsigned Count;
78   PGOHashVersion HashVersion;
79   llvm::MD5 MD5;
80 
81   static const int NumBitsPerType = 6;
82   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
83   static const unsigned TooBig = 1u << NumBitsPerType;
84 
85 public:
86   /// Hash values for AST nodes.
87   ///
88   /// Distinct values for AST nodes that have region counters attached.
89   ///
90   /// These values must be stable.  All new members must be added at the end,
91   /// and no members should be removed.  Changing the enumeration value for an
92   /// AST node will affect the hash of every function that contains that node.
93   enum HashType : unsigned char {
94     None = 0,
95     LabelStmt = 1,
96     WhileStmt,
97     DoStmt,
98     ForStmt,
99     CXXForRangeStmt,
100     ObjCForCollectionStmt,
101     SwitchStmt,
102     CaseStmt,
103     DefaultStmt,
104     IfStmt,
105     CXXTryStmt,
106     CXXCatchStmt,
107     ConditionalOperator,
108     BinaryOperatorLAnd,
109     BinaryOperatorLOr,
110     BinaryConditionalOperator,
111     // The preceding values are available with PGO_HASH_V1.
112 
113     EndOfScope,
114     IfThenBranch,
115     IfElseBranch,
116     GotoStmt,
117     IndirectGotoStmt,
118     BreakStmt,
119     ContinueStmt,
120     ReturnStmt,
121     ThrowExpr,
122     UnaryOperatorLNot,
123     BinaryOperatorLT,
124     BinaryOperatorGT,
125     BinaryOperatorLE,
126     BinaryOperatorGE,
127     BinaryOperatorEQ,
128     BinaryOperatorNE,
129     // The preceding values are available since PGO_HASH_V2.
130 
131     // Keep this last.  It's for the static assert that follows.
132     LastHashType
133   };
134   static_assert(LastHashType <= TooBig, "Too many types in HashType");
135 
136   PGOHash(PGOHashVersion HashVersion)
137       : Working(0), Count(0), HashVersion(HashVersion) {}
138   void combine(HashType Type);
139   uint64_t finalize();
140   PGOHashVersion getHashVersion() const { return HashVersion; }
141 };
// Out-of-class definitions for the static data members that PGOHash declares
// and initializes in its class body.
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
145 
146 /// Get the PGO hash version used in the given indexed profile.
147 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
148                                         CodeGenModule &CGM) {
149   if (PGOReader->getVersion() <= 4)
150     return PGO_HASH_V1;
151   if (PGOReader->getVersion() <= 5)
152     return PGO_HASH_V2;
153   return PGO_HASH_V3;
154 }
155 
156 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
157 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
158   using Base = RecursiveASTVisitor<MapRegionCounters>;
159 
160   /// The next counter value to assign.
161   unsigned NextCounter;
162   /// The function hash.
163   PGOHash Hash;
164   /// The map of statements to counters.
165   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
166   /// The next bitmap byte index to assign.
167   unsigned NextMCDCBitmapIdx;
168   /// The map of statements to MC/DC bitmap coverage objects.
169   llvm::DenseMap<const Stmt *, unsigned> &MCDCBitmapMap;
170   /// Maximum number of supported MC/DC conditions in a boolean expression.
171   unsigned MCDCMaxCond;
172   /// The profile version.
173   uint64_t ProfileVersion;
174   /// Diagnostics Engine used to report warnings.
175   DiagnosticsEngine &Diag;
176 
  /// Set up a visitor whose counter and MC/DC bitmap numbering both start at
  /// zero and whose function hash uses \p HashVersion.
  ///
  /// \p CounterMap and \p MCDCBitmapMap are held by reference and filled in
  /// while traversing.  When \p MCDCMaxCond is 0 the dataTraverseStmtPre() and
  /// dataTraverseStmtPost() hooks return without doing any MC/DC tracking.
  /// \p Diag receives the MC/DC warnings emitted by dataTraverseStmtPost().
  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap,
                    llvm::DenseMap<const Stmt *, unsigned> &MCDCBitmapMap,
                    unsigned MCDCMaxCond, DiagnosticsEngine &Diag)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        NextMCDCBitmapIdx(0), MCDCBitmapMap(MCDCBitmapMap),
        MCDCMaxCond(MCDCMaxCond), ProfileVersion(ProfileVersion), Diag(Diag) {}
184 
185   // Blocks and lambdas are handled as separate functions, so we need not
186   // traverse them in the parent context.
187   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
188   bool TraverseLambdaExpr(LambdaExpr *LE) {
189     // Traverse the captures, but not the body.
190     for (auto C : zip(LE->captures(), LE->capture_inits()))
191       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
192     return true;
193   }
194   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
195 
196   bool VisitDecl(const Decl *D) {
197     switch (D->getKind()) {
198     default:
199       break;
200     case Decl::Function:
201     case Decl::CXXMethod:
202     case Decl::CXXConstructor:
203     case Decl::CXXDestructor:
204     case Decl::CXXConversion:
205     case Decl::ObjCMethod:
206     case Decl::Block:
207     case Decl::Captured:
208       CounterMap[D->getBody()] = NextCounter++;
209       break;
210     }
211     return true;
212   }
213 
214   /// If \p S gets a fresh counter, update the counter mappings. Return the
215   /// V1 hash of \p S.
216   PGOHash::HashType updateCounterMappings(Stmt *S) {
217     auto Type = getHashType(PGO_HASH_V1, S);
218     if (Type != PGOHash::None)
219       CounterMap[S] = NextCounter++;
220     return Type;
221   }
222 
223   /// The following stacks are used with dataTraverseStmtPre() and
224   /// dataTraverseStmtPost() to track the depth of nested logical operators in a
225   /// boolean expression in a function.  The ultimate purpose is to keep track
226   /// of the number of leaf-level conditions in the boolean expression so that a
227   /// profile bitmap can be allocated based on that number.
228   ///
229   /// The stacks are also used to find error cases and notify the user.  A
230   /// standard logical operator nest for a boolean expression could be in a form
231   /// similar to this: "x = a && b && c && (d || f)"
232   unsigned NumCond = 0;
233   bool SplitNestedLogicalOp = false;
234   SmallVector<const Stmt *, 16> NonLogOpStack;
235   SmallVector<const BinaryOperator *, 16> LogOpStack;
236 
237   // Hook: dataTraverseStmtPre() is invoked prior to visiting an AST Stmt node.
238   bool dataTraverseStmtPre(Stmt *S) {
239     /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
240     if (MCDCMaxCond == 0)
241       return true;
242 
243     /// At the top of the logical operator nest, reset the number of conditions,
244     /// also forget previously seen split nesting cases.
245     if (LogOpStack.empty()) {
246       NumCond = 0;
247       SplitNestedLogicalOp = false;
248     }
249 
250     if (const Expr *E = dyn_cast<Expr>(S)) {
251       const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
252       if (BinOp && BinOp->isLogicalOp()) {
253         /// Check for "split-nested" logical operators. This happens when a new
254         /// boolean expression logical-op nest is encountered within an existing
255         /// boolean expression, separated by a non-logical operator.  For
256         /// example, in "x = (a && b && c && foo(d && f))", the "d && f" case
257         /// starts a new boolean expression that is separated from the other
258         /// conditions by the operator foo(). Split-nested cases are not
259         /// supported by MC/DC.
260         SplitNestedLogicalOp = SplitNestedLogicalOp || !NonLogOpStack.empty();
261 
262         LogOpStack.push_back(BinOp);
263         return true;
264       }
265     }
266 
267     /// Keep track of non-logical operators. These are OK as long as we don't
268     /// encounter a new logical operator after seeing one.
269     if (!LogOpStack.empty())
270       NonLogOpStack.push_back(S);
271 
272     return true;
273   }
274 
275   // Hook: dataTraverseStmtPost() is invoked by the AST visitor after visiting
276   // an AST Stmt node.  MC/DC will use it to to signal when the top of a
277   // logical operation (boolean expression) nest is encountered.
278   bool dataTraverseStmtPost(Stmt *S) {
279     /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
280     if (MCDCMaxCond == 0)
281       return true;
282 
283     if (const Expr *E = dyn_cast<Expr>(S)) {
284       const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
285       if (BinOp && BinOp->isLogicalOp()) {
286         assert(LogOpStack.back() == BinOp);
287         LogOpStack.pop_back();
288 
289         /// At the top of logical operator nest:
290         if (LogOpStack.empty()) {
291           /// Was the "split-nested" logical operator case encountered?
292           if (SplitNestedLogicalOp) {
293             unsigned DiagID = Diag.getCustomDiagID(
294                 DiagnosticsEngine::Warning,
295                 "unsupported MC/DC boolean expression; "
296                 "contains an operation with a nested boolean expression. "
297                 "Expression will not be covered");
298             Diag.Report(S->getBeginLoc(), DiagID);
299             return true;
300           }
301 
302           /// Was the maximum number of conditions encountered?
303           if (NumCond > MCDCMaxCond) {
304             unsigned DiagID = Diag.getCustomDiagID(
305                 DiagnosticsEngine::Warning,
306                 "unsupported MC/DC boolean expression; "
307                 "number of conditions (%0) exceeds max (%1). "
308                 "Expression will not be covered");
309             Diag.Report(S->getBeginLoc(), DiagID) << NumCond << MCDCMaxCond;
310             return true;
311           }
312 
313           // Otherwise, allocate the number of bytes required for the bitmap
314           // based on the number of conditions. Must be at least 1-byte long.
315           MCDCBitmapMap[BinOp] = NextMCDCBitmapIdx;
316           unsigned SizeInBits = std::max<unsigned>(1L << NumCond, CHAR_BIT);
317           NextMCDCBitmapIdx += SizeInBits / CHAR_BIT;
318         }
319         return true;
320       }
321     }
322 
323     if (!LogOpStack.empty())
324       NonLogOpStack.pop_back();
325 
326     return true;
327   }
328 
329   /// The RHS of all logical operators gets a fresh counter in order to count
330   /// how many times the RHS evaluates to true or false, depending on the
331   /// semantics of the operator. This is only valid for ">= v7" of the profile
332   /// version so that we facilitate backward compatibility. In addition, in
333   /// order to use MC/DC, count the number of total LHS and RHS conditions.
334   bool VisitBinaryOperator(BinaryOperator *S) {
335     if (S->isLogicalOp()) {
336       if (CodeGenFunction::isInstrumentedCondition(S->getLHS()))
337         NumCond++;
338 
339       if (CodeGenFunction::isInstrumentedCondition(S->getRHS())) {
340         if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
341           CounterMap[S->getRHS()] = NextCounter++;
342 
343         NumCond++;
344       }
345     }
346     return Base::VisitBinaryOperator(S);
347   }
348 
349   /// Include \p S in the function hash.
350   bool VisitStmt(Stmt *S) {
351     auto Type = updateCounterMappings(S);
352     if (Hash.getHashVersion() != PGO_HASH_V1)
353       Type = getHashType(Hash.getHashVersion(), S);
354     if (Type != PGOHash::None)
355       Hash.combine(Type);
356     return true;
357   }
358 
359   bool TraverseIfStmt(IfStmt *If) {
360     // If we used the V1 hash, use the default traversal.
361     if (Hash.getHashVersion() == PGO_HASH_V1)
362       return Base::TraverseIfStmt(If);
363 
364     // Otherwise, keep track of which branch we're in while traversing.
365     VisitStmt(If);
366     for (Stmt *CS : If->children()) {
367       if (!CS)
368         continue;
369       if (CS == If->getThen())
370         Hash.combine(PGOHash::IfThenBranch);
371       else if (CS == If->getElse())
372         Hash.combine(PGOHash::IfElseBranch);
373       TraverseStmt(CS);
374     }
375     Hash.combine(PGOHash::EndOfScope);
376     return true;
377   }
378 
379 // If the statement type \p N is nestable, and its nesting impacts profile
380 // stability, define a custom traversal which tracks the end of the statement
381 // in the hash (provided we're not using the V1 hash).
382 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
383   bool Traverse##N(N *S) {                                                     \
384     Base::Traverse##N(S);                                                      \
385     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
386       Hash.combine(PGOHash::EndOfScope);                                       \
387     return true;                                                               \
388   }
389 
390   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
391   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
392   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
393   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
394   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
395   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
396   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
397 
398   /// Get version \p HashVersion of the PGO hash for \p S.
399   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
400     switch (S->getStmtClass()) {
401     default:
402       break;
403     case Stmt::LabelStmtClass:
404       return PGOHash::LabelStmt;
405     case Stmt::WhileStmtClass:
406       return PGOHash::WhileStmt;
407     case Stmt::DoStmtClass:
408       return PGOHash::DoStmt;
409     case Stmt::ForStmtClass:
410       return PGOHash::ForStmt;
411     case Stmt::CXXForRangeStmtClass:
412       return PGOHash::CXXForRangeStmt;
413     case Stmt::ObjCForCollectionStmtClass:
414       return PGOHash::ObjCForCollectionStmt;
415     case Stmt::SwitchStmtClass:
416       return PGOHash::SwitchStmt;
417     case Stmt::CaseStmtClass:
418       return PGOHash::CaseStmt;
419     case Stmt::DefaultStmtClass:
420       return PGOHash::DefaultStmt;
421     case Stmt::IfStmtClass:
422       return PGOHash::IfStmt;
423     case Stmt::CXXTryStmtClass:
424       return PGOHash::CXXTryStmt;
425     case Stmt::CXXCatchStmtClass:
426       return PGOHash::CXXCatchStmt;
427     case Stmt::ConditionalOperatorClass:
428       return PGOHash::ConditionalOperator;
429     case Stmt::BinaryConditionalOperatorClass:
430       return PGOHash::BinaryConditionalOperator;
431     case Stmt::BinaryOperatorClass: {
432       const BinaryOperator *BO = cast<BinaryOperator>(S);
433       if (BO->getOpcode() == BO_LAnd)
434         return PGOHash::BinaryOperatorLAnd;
435       if (BO->getOpcode() == BO_LOr)
436         return PGOHash::BinaryOperatorLOr;
437       if (HashVersion >= PGO_HASH_V2) {
438         switch (BO->getOpcode()) {
439         default:
440           break;
441         case BO_LT:
442           return PGOHash::BinaryOperatorLT;
443         case BO_GT:
444           return PGOHash::BinaryOperatorGT;
445         case BO_LE:
446           return PGOHash::BinaryOperatorLE;
447         case BO_GE:
448           return PGOHash::BinaryOperatorGE;
449         case BO_EQ:
450           return PGOHash::BinaryOperatorEQ;
451         case BO_NE:
452           return PGOHash::BinaryOperatorNE;
453         }
454       }
455       break;
456     }
457     }
458 
459     if (HashVersion >= PGO_HASH_V2) {
460       switch (S->getStmtClass()) {
461       default:
462         break;
463       case Stmt::GotoStmtClass:
464         return PGOHash::GotoStmt;
465       case Stmt::IndirectGotoStmtClass:
466         return PGOHash::IndirectGotoStmt;
467       case Stmt::BreakStmtClass:
468         return PGOHash::BreakStmt;
469       case Stmt::ContinueStmtClass:
470         return PGOHash::ContinueStmt;
471       case Stmt::ReturnStmtClass:
472         return PGOHash::ReturnStmt;
473       case Stmt::CXXThrowExprClass:
474         return PGOHash::ThrowExpr;
475       case Stmt::UnaryOperatorClass: {
476         const UnaryOperator *UO = cast<UnaryOperator>(S);
477         if (UO->getOpcode() == UO_LNot)
478           return PGOHash::UnaryOperatorLNot;
479         break;
480       }
481       }
482     }
483 
484     return PGOHash::None;
485   }
486 };
487 
488 /// A StmtVisitor that propagates the raw counts through the AST and
489 /// records the count at statements where the value may change.
490 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
491   /// PGO state.
492   CodeGenPGO &PGO;
493 
494   /// A flag that is set when the current count should be recorded on the
495   /// next statement, such as at the exit of a loop.
496   bool RecordNextStmtCount;
497 
498   /// The count at the current location in the traversal.
499   uint64_t CurrentCount;
500 
501   /// The map of statements to count values.
502   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
503 
504   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
505   struct BreakContinue {
506     uint64_t BreakCount = 0;
507     uint64_t ContinueCount = 0;
508     BreakContinue() = default;
509   };
510   SmallVector<BreakContinue, 8> BreakContinueStack;
511 
512   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
513                       CodeGenPGO &PGO)
514       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
515 
516   void RecordStmtCount(const Stmt *S) {
517     if (RecordNextStmtCount) {
518       CountMap[S] = CurrentCount;
519       RecordNextStmtCount = false;
520     }
521   }
522 
523   /// Set and return the current count.
524   uint64_t setCount(uint64_t Count) {
525     CurrentCount = Count;
526     return Count;
527   }
528 
529   void VisitStmt(const Stmt *S) {
530     RecordStmtCount(S);
531     for (const Stmt *Child : S->children())
532       if (Child)
533         this->Visit(Child);
534   }
535 
536   void VisitFunctionDecl(const FunctionDecl *D) {
537     // Counter tracks entry to the function body.
538     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
539     CountMap[D->getBody()] = BodyCount;
540     Visit(D->getBody());
541   }
542 
543   // Skip lambda expressions. We visit these as FunctionDecls when we're
544   // generating them and aren't interested in the body when generating a
545   // parent context.
546   void VisitLambdaExpr(const LambdaExpr *LE) {}
547 
548   void VisitCapturedDecl(const CapturedDecl *D) {
549     // Counter tracks entry to the capture body.
550     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
551     CountMap[D->getBody()] = BodyCount;
552     Visit(D->getBody());
553   }
554 
555   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
556     // Counter tracks entry to the method body.
557     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
558     CountMap[D->getBody()] = BodyCount;
559     Visit(D->getBody());
560   }
561 
562   void VisitBlockDecl(const BlockDecl *D) {
563     // Counter tracks entry to the block body.
564     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
565     CountMap[D->getBody()] = BodyCount;
566     Visit(D->getBody());
567   }
568 
569   void VisitReturnStmt(const ReturnStmt *S) {
570     RecordStmtCount(S);
571     if (S->getRetValue())
572       Visit(S->getRetValue());
573     CurrentCount = 0;
574     RecordNextStmtCount = true;
575   }
576 
577   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
578     RecordStmtCount(E);
579     if (E->getSubExpr())
580       Visit(E->getSubExpr());
581     CurrentCount = 0;
582     RecordNextStmtCount = true;
583   }
584 
585   void VisitGotoStmt(const GotoStmt *S) {
586     RecordStmtCount(S);
587     CurrentCount = 0;
588     RecordNextStmtCount = true;
589   }
590 
591   void VisitLabelStmt(const LabelStmt *S) {
592     RecordNextStmtCount = false;
593     // Counter tracks the block following the label.
594     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
595     CountMap[S] = BlockCount;
596     Visit(S->getSubStmt());
597   }
598 
599   void VisitBreakStmt(const BreakStmt *S) {
600     RecordStmtCount(S);
601     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
602     BreakContinueStack.back().BreakCount += CurrentCount;
603     CurrentCount = 0;
604     RecordNextStmtCount = true;
605   }
606 
607   void VisitContinueStmt(const ContinueStmt *S) {
608     RecordStmtCount(S);
609     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
610     BreakContinueStack.back().ContinueCount += CurrentCount;
611     CurrentCount = 0;
612     RecordNextStmtCount = true;
613   }
614 
615   void VisitWhileStmt(const WhileStmt *S) {
616     RecordStmtCount(S);
617     uint64_t ParentCount = CurrentCount;
618 
619     BreakContinueStack.push_back(BreakContinue());
620     // Visit the body region first so the break/continue adjustments can be
621     // included when visiting the condition.
622     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
623     CountMap[S->getBody()] = CurrentCount;
624     Visit(S->getBody());
625     uint64_t BackedgeCount = CurrentCount;
626 
627     // ...then go back and propagate counts through the condition. The count
628     // at the start of the condition is the sum of the incoming edges,
629     // the backedge from the end of the loop body, and the edges from
630     // continue statements.
631     BreakContinue BC = BreakContinueStack.pop_back_val();
632     uint64_t CondCount =
633         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
634     CountMap[S->getCond()] = CondCount;
635     Visit(S->getCond());
636     setCount(BC.BreakCount + CondCount - BodyCount);
637     RecordNextStmtCount = true;
638   }
639 
640   void VisitDoStmt(const DoStmt *S) {
641     RecordStmtCount(S);
642     uint64_t LoopCount = PGO.getRegionCount(S);
643 
644     BreakContinueStack.push_back(BreakContinue());
645     // The count doesn't include the fallthrough from the parent scope. Add it.
646     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
647     CountMap[S->getBody()] = BodyCount;
648     Visit(S->getBody());
649     uint64_t BackedgeCount = CurrentCount;
650 
651     BreakContinue BC = BreakContinueStack.pop_back_val();
652     // The count at the start of the condition is equal to the count at the
653     // end of the body, plus any continues.
654     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
655     CountMap[S->getCond()] = CondCount;
656     Visit(S->getCond());
657     setCount(BC.BreakCount + CondCount - LoopCount);
658     RecordNextStmtCount = true;
659   }
660 
661   void VisitForStmt(const ForStmt *S) {
662     RecordStmtCount(S);
663     if (S->getInit())
664       Visit(S->getInit());
665 
666     uint64_t ParentCount = CurrentCount;
667 
668     BreakContinueStack.push_back(BreakContinue());
669     // Visit the body region first. (This is basically the same as a while
670     // loop; see further comments in VisitWhileStmt.)
671     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
672     CountMap[S->getBody()] = BodyCount;
673     Visit(S->getBody());
674     uint64_t BackedgeCount = CurrentCount;
675     BreakContinue BC = BreakContinueStack.pop_back_val();
676 
677     // The increment is essentially part of the body but it needs to include
678     // the count for all the continue statements.
679     if (S->getInc()) {
680       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
681       CountMap[S->getInc()] = IncCount;
682       Visit(S->getInc());
683     }
684 
685     // ...then go back and propagate counts through the condition.
686     uint64_t CondCount =
687         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
688     if (S->getCond()) {
689       CountMap[S->getCond()] = CondCount;
690       Visit(S->getCond());
691     }
692     setCount(BC.BreakCount + CondCount - BodyCount);
693     RecordNextStmtCount = true;
694   }
695 
696   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
697     RecordStmtCount(S);
698     if (S->getInit())
699       Visit(S->getInit());
700     Visit(S->getLoopVarStmt());
701     Visit(S->getRangeStmt());
702     Visit(S->getBeginStmt());
703     Visit(S->getEndStmt());
704 
705     uint64_t ParentCount = CurrentCount;
706     BreakContinueStack.push_back(BreakContinue());
707     // Visit the body region first. (This is basically the same as a while
708     // loop; see further comments in VisitWhileStmt.)
709     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
710     CountMap[S->getBody()] = BodyCount;
711     Visit(S->getBody());
712     uint64_t BackedgeCount = CurrentCount;
713     BreakContinue BC = BreakContinueStack.pop_back_val();
714 
715     // The increment is essentially part of the body but it needs to include
716     // the count for all the continue statements.
717     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
718     CountMap[S->getInc()] = IncCount;
719     Visit(S->getInc());
720 
721     // ...then go back and propagate counts through the condition.
722     uint64_t CondCount =
723         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
724     CountMap[S->getCond()] = CondCount;
725     Visit(S->getCond());
726     setCount(BC.BreakCount + CondCount - BodyCount);
727     RecordNextStmtCount = true;
728   }
729 
730   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
731     RecordStmtCount(S);
732     Visit(S->getElement());
733     uint64_t ParentCount = CurrentCount;
734     BreakContinueStack.push_back(BreakContinue());
735     // Counter tracks the body of the loop.
736     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
737     CountMap[S->getBody()] = BodyCount;
738     Visit(S->getBody());
739     uint64_t BackedgeCount = CurrentCount;
740     BreakContinue BC = BreakContinueStack.pop_back_val();
741 
742     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
743              BodyCount);
744     RecordNextStmtCount = true;
745   }
746 
747   void VisitSwitchStmt(const SwitchStmt *S) {
748     RecordStmtCount(S);
749     if (S->getInit())
750       Visit(S->getInit());
751     Visit(S->getCond());
752     CurrentCount = 0;
753     BreakContinueStack.push_back(BreakContinue());
754     Visit(S->getBody());
755     // If the switch is inside a loop, add the continue counts.
756     BreakContinue BC = BreakContinueStack.pop_back_val();
757     if (!BreakContinueStack.empty())
758       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
759     // Counter tracks the exit block of the switch.
760     setCount(PGO.getRegionCount(S));
761     RecordNextStmtCount = true;
762   }
763 
764   void VisitSwitchCase(const SwitchCase *S) {
765     RecordNextStmtCount = false;
766     // Counter for this particular case. This counts only jumps from the
767     // switch header and does not include fallthrough from the case before
768     // this one.
769     uint64_t CaseCount = PGO.getRegionCount(S);
770     setCount(CurrentCount + CaseCount);
771     // We need the count without fallthrough in the mapping, so it's more useful
772     // for branch probabilities.
773     CountMap[S] = CaseCount;
774     RecordNextStmtCount = true;
775     Visit(S->getSubStmt());
776   }
777 
778   void VisitIfStmt(const IfStmt *S) {
779     RecordStmtCount(S);
780 
781     if (S->isConsteval()) {
782       const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
783       if (Stm)
784         Visit(Stm);
785       return;
786     }
787 
788     uint64_t ParentCount = CurrentCount;
789     if (S->getInit())
790       Visit(S->getInit());
791     Visit(S->getCond());
792 
793     // Counter tracks the "then" part of an if statement. The count for
794     // the "else" part, if it exists, will be calculated from this counter.
795     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
796     CountMap[S->getThen()] = ThenCount;
797     Visit(S->getThen());
798     uint64_t OutCount = CurrentCount;
799 
800     uint64_t ElseCount = ParentCount - ThenCount;
801     if (S->getElse()) {
802       setCount(ElseCount);
803       CountMap[S->getElse()] = ElseCount;
804       Visit(S->getElse());
805       OutCount += CurrentCount;
806     } else
807       OutCount += ElseCount;
808     setCount(OutCount);
809     RecordNextStmtCount = true;
810   }
811 
812   void VisitCXXTryStmt(const CXXTryStmt *S) {
813     RecordStmtCount(S);
814     Visit(S->getTryBlock());
815     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
816       Visit(S->getHandler(I));
817     // Counter tracks the continuation block of the try statement.
818     setCount(PGO.getRegionCount(S));
819     RecordNextStmtCount = true;
820   }
821 
822   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
823     RecordNextStmtCount = false;
824     // Counter tracks the catch statement's handler block.
825     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
826     CountMap[S] = CatchCount;
827     Visit(S->getHandlerBlock());
828   }
829 
830   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
831     RecordStmtCount(E);
832     uint64_t ParentCount = CurrentCount;
833     Visit(E->getCond());
834 
835     // Counter tracks the "true" part of a conditional operator. The
836     // count in the "false" part will be calculated from this counter.
837     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
838     CountMap[E->getTrueExpr()] = TrueCount;
839     Visit(E->getTrueExpr());
840     uint64_t OutCount = CurrentCount;
841 
842     uint64_t FalseCount = setCount(ParentCount - TrueCount);
843     CountMap[E->getFalseExpr()] = FalseCount;
844     Visit(E->getFalseExpr());
845     OutCount += CurrentCount;
846 
847     setCount(OutCount);
848     RecordNextStmtCount = true;
849   }
850 
851   void VisitBinLAnd(const BinaryOperator *E) {
852     RecordStmtCount(E);
853     uint64_t ParentCount = CurrentCount;
854     Visit(E->getLHS());
855     // Counter tracks the right hand side of a logical and operator.
856     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
857     CountMap[E->getRHS()] = RHSCount;
858     Visit(E->getRHS());
859     setCount(ParentCount + RHSCount - CurrentCount);
860     RecordNextStmtCount = true;
861   }
862 
863   void VisitBinLOr(const BinaryOperator *E) {
864     RecordStmtCount(E);
865     uint64_t ParentCount = CurrentCount;
866     Visit(E->getLHS());
867     // Counter tracks the right hand side of a logical or operator.
868     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
869     CountMap[E->getRHS()] = RHSCount;
870     Visit(E->getRHS());
871     setCount(ParentCount + RHSCount - CurrentCount);
872     RecordNextStmtCount = true;
873   }
874 };
875 } // end anonymous namespace
876 
877 void PGOHash::combine(HashType Type) {
878   // Check that we never combine 0 and only have six bits.
879   assert(Type && "Hash is invalid: unexpected type 0");
880   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
881 
882   // Pass through MD5 if enough work has built up.
883   if (Count && Count % NumTypesPerWord == 0) {
884     using namespace llvm::support;
885     uint64_t Swapped =
886         endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
887     MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
888     Working = 0;
889   }
890 
891   // Accumulate the current type.
892   ++Count;
893   Working = Working << NumBitsPerType | Type;
894 }
895 
896 uint64_t PGOHash::finalize() {
897   // Use Working as the hash directly if we never used MD5.
898   if (Count <= NumTypesPerWord)
899     // No need to byte swap here, since none of the math was endian-dependent.
900     // This number will be byte-swapped as required on endianness transitions,
901     // so we will see the same value on the other side.
902     return Working;
903 
904   // Check for remaining work in Working.
905   if (Working) {
906     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
907     // is buggy because it converts a uint64_t into an array of uint8_t.
908     if (HashVersion < PGO_HASH_V3) {
909       MD5.update({(uint8_t)Working});
910     } else {
911       using namespace llvm::support;
912       uint64_t Swapped =
913           endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
914       MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
915     }
916   }
917 
918   // Finalize the MD5 and return the hash.
919   llvm::MD5::MD5Result Result;
920   MD5.final(Result);
921   return Result.low();
922 }
923 
924 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
925   const Decl *D = GD.getDecl();
926   if (!D->hasBody())
927     return;
928 
929   // Skip CUDA/HIP kernel launch stub functions.
930   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
931       D->hasAttr<CUDAGlobalAttr>())
932     return;
933 
934   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
935   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
936   if (!InstrumentRegions && !PGOReader)
937     return;
938   if (D->isImplicit())
939     return;
940   // Constructors and destructors may be represented by several functions in IR.
941   // If so, instrument only base variant, others are implemented by delegation
942   // to the base one, it would be counted twice otherwise.
943   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
944     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
945       if (GD.getCtorType() != Ctor_Base &&
946           CodeGenFunction::IsConstructorDelegationValid(CCD))
947         return;
948   }
949   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
950     return;
951 
952   CGM.ClearUnusedCoverageMapping(D);
953   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
954     return;
955   if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
956     return;
957 
958   setFuncName(Fn);
959 
960   mapRegionCounters(D);
961   if (CGM.getCodeGenOpts().CoverageMapping)
962     emitCounterRegionMapping(D);
963   if (PGOReader) {
964     SourceManager &SM = CGM.getContext().getSourceManager();
965     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
966     computeRegionCounts(D);
967     applyFunctionAttributes(PGOReader, Fn);
968   }
969 }
970 
971 void CodeGenPGO::mapRegionCounters(const Decl *D) {
972   // Use the latest hash version when inserting instrumentation, but use the
973   // version in the indexed profile if we're reading PGO data.
974   PGOHashVersion HashVersion = PGO_HASH_LATEST;
975   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
976   if (auto *PGOReader = CGM.getPGOReader()) {
977     HashVersion = getPGOHashVersion(PGOReader, CGM);
978     ProfileVersion = PGOReader->getVersion();
979   }
980 
981   // If MC/DC is enabled, set the MaxConditions to a preset value. Otherwise,
982   // set it to zero. This value impacts the number of conditions accepted in a
983   // given boolean expression, which impacts the size of the bitmap used to
984   // track test vector execution for that boolean expression.  Because the
985   // bitmap scales exponentially (2^n) based on the number of conditions seen,
986   // the maximum value is hard-coded at 6 conditions, which is more than enough
987   // for most embedded applications. Setting a maximum value prevents the
988   // bitmap footprint from growing too large without the user's knowledge. In
989   // the future, this value could be adjusted with a command-line option.
990   unsigned MCDCMaxConditions = (CGM.getCodeGenOpts().MCDCCoverage) ? 6 : 0;
991 
992   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
993   RegionMCDCBitmapMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
994   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap,
995                            *RegionMCDCBitmapMap, MCDCMaxConditions,
996                            CGM.getDiags());
997   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
998     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
999   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1000     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
1001   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1002     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
1003   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1004     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
1005   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
1006   NumRegionCounters = Walker.NextCounter;
1007   MCDCBitmapBytes = Walker.NextMCDCBitmapIdx;
1008   FunctionHash = Walker.Hash.finalize();
1009 }
1010 
1011 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
1012   if (!D->getBody())
1013     return true;
1014 
1015   // Skip host-only functions in the CUDA device compilation and device-only
1016   // functions in the host compilation. Just roughly filter them out based on
1017   // the function attributes. If there are effectively host-only or device-only
1018   // ones, their coverage mapping may still be generated.
1019   if (CGM.getLangOpts().CUDA &&
1020       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
1021         !D->hasAttr<CUDAGlobalAttr>()) ||
1022        (!CGM.getLangOpts().CUDAIsDevice &&
1023         (D->hasAttr<CUDAGlobalAttr>() ||
1024          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
1025     return true;
1026 
1027   // Don't map the functions in system headers.
1028   const auto &SM = CGM.getContext().getSourceManager();
1029   auto Loc = D->getBody()->getBeginLoc();
1030   return !SystemHeadersCoverage && SM.isInSystemHeader(Loc);
1031 }
1032 
1033 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
1034   if (skipRegionMappingForDecl(D))
1035     return;
1036 
1037   std::string CoverageMapping;
1038   llvm::raw_string_ostream OS(CoverageMapping);
1039   RegionCondIDMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
1040   CoverageMappingGen MappingGen(
1041       *CGM.getCoverageMapping(), CGM.getContext().getSourceManager(),
1042       CGM.getLangOpts(), RegionCounterMap.get(), RegionMCDCBitmapMap.get(),
1043       RegionCondIDMap.get());
1044   MappingGen.emitCounterMapping(D, OS);
1045   OS.flush();
1046 
1047   if (CoverageMapping.empty())
1048     return;
1049 
1050   CGM.getCoverageMapping()->addFunctionMappingRecord(
1051       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
1052 }
1053 
1054 void
1055 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
1056                                     llvm::GlobalValue::LinkageTypes Linkage) {
1057   if (skipRegionMappingForDecl(D))
1058     return;
1059 
1060   std::string CoverageMapping;
1061   llvm::raw_string_ostream OS(CoverageMapping);
1062   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
1063                                 CGM.getContext().getSourceManager(),
1064                                 CGM.getLangOpts());
1065   MappingGen.emitEmptyMapping(D, OS);
1066   OS.flush();
1067 
1068   if (CoverageMapping.empty())
1069     return;
1070 
1071   setFuncName(Name, Linkage);
1072   CGM.getCoverageMapping()->addFunctionMappingRecord(
1073       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
1074 }
1075 
1076 void CodeGenPGO::computeRegionCounts(const Decl *D) {
1077   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
1078   ComputeRegionCounts Walker(*StmtCountMap, *this);
1079   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1080     Walker.VisitFunctionDecl(FD);
1081   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1082     Walker.VisitObjCMethodDecl(MD);
1083   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1084     Walker.VisitBlockDecl(BD);
1085   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1086     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
1087 }
1088 
1089 void
1090 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
1091                                     llvm::Function *Fn) {
1092   if (!haveRegionCounts())
1093     return;
1094 
1095   uint64_t FunctionCount = getRegionCount(nullptr);
1096   Fn->setEntryCount(FunctionCount);
1097 }
1098 
1099 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
1100                                       llvm::Value *StepV) {
1101   if (!RegionCounterMap || !Builder.GetInsertBlock())
1102     return;
1103 
1104   unsigned Counter = (*RegionCounterMap)[S];
1105 
1106   llvm::Value *Args[] = {FuncNameVar,
1107                          Builder.getInt64(FunctionHash),
1108                          Builder.getInt32(NumRegionCounters),
1109                          Builder.getInt32(Counter), StepV};
1110   if (!StepV)
1111     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
1112                        ArrayRef(Args, 4));
1113   else
1114     Builder.CreateCall(
1115         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
1116         ArrayRef(Args));
1117 }
1118 
1119 bool CodeGenPGO::canEmitMCDCCoverage(const CGBuilderTy &Builder) {
1120   return (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1121           CGM.getCodeGenOpts().MCDCCoverage && Builder.GetInsertBlock());
1122 }
1123 
1124 void CodeGenPGO::emitMCDCParameters(CGBuilderTy &Builder) {
1125   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1126     return;
1127 
1128   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1129 
1130   // Emit intrinsic representing MCDC bitmap parameters at function entry.
1131   // This is used by the instrumentation pass, but it isn't actually lowered to
1132   // anything.
1133   llvm::Value *Args[3] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1134                           Builder.getInt64(FunctionHash),
1135                           Builder.getInt32(MCDCBitmapBytes)};
1136   Builder.CreateCall(
1137       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_parameters), Args);
1138 }
1139 
1140 void CodeGenPGO::emitMCDCTestVectorBitmapUpdate(CGBuilderTy &Builder,
1141                                                 const Expr *S,
1142                                                 Address MCDCCondBitmapAddr) {
1143   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1144     return;
1145 
1146   S = S->IgnoreParens();
1147 
1148   auto ExprMCDCBitmapMapIterator = RegionMCDCBitmapMap->find(S);
1149   if (ExprMCDCBitmapMapIterator == RegionMCDCBitmapMap->end())
1150     return;
1151 
1152   // Extract the ID of the global bitmap associated with this expression.
1153   unsigned MCDCTestVectorBitmapID = ExprMCDCBitmapMapIterator->second;
1154   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1155 
1156   // Emit intrinsic responsible for updating the global bitmap corresponding to
1157   // a boolean expression. The index being set is based on the value loaded
1158   // from a pointer to a dedicated temporary value on the stack that is itself
1159   // updated via emitMCDCCondBitmapReset() and emitMCDCCondBitmapUpdate(). The
1160   // index represents an executed test vector.
1161   llvm::Value *Args[5] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1162                           Builder.getInt64(FunctionHash),
1163                           Builder.getInt32(MCDCBitmapBytes),
1164                           Builder.getInt32(MCDCTestVectorBitmapID),
1165                           MCDCCondBitmapAddr.getPointer()};
1166   Builder.CreateCall(
1167       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_tvbitmap_update), Args);
1168 }
1169 
1170 void CodeGenPGO::emitMCDCCondBitmapReset(CGBuilderTy &Builder, const Expr *S,
1171                                          Address MCDCCondBitmapAddr) {
1172   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1173     return;
1174 
1175   S = S->IgnoreParens();
1176 
1177   if (RegionMCDCBitmapMap->find(S) == RegionMCDCBitmapMap->end())
1178     return;
1179 
1180   // Emit intrinsic that resets a dedicated temporary value on the stack to 0.
1181   Builder.CreateStore(Builder.getInt32(0), MCDCCondBitmapAddr);
1182 }
1183 
1184 void CodeGenPGO::emitMCDCCondBitmapUpdate(CGBuilderTy &Builder, const Expr *S,
1185                                           Address MCDCCondBitmapAddr,
1186                                           llvm::Value *Val) {
1187   if (!canEmitMCDCCoverage(Builder) || !RegionCondIDMap)
1188     return;
1189 
1190   // Even though, for simplicity, parentheses and unary logical-NOT operators
1191   // are considered part of their underlying condition for both MC/DC and
1192   // branch coverage, the condition IDs themselves are assigned and tracked
1193   // using the underlying condition itself.  This is done solely for
1194   // consistency since parentheses and logical-NOTs are ignored when checking
1195   // whether the condition is actually an instrumentable condition. This can
1196   // also make debugging a bit easier.
1197   S = CodeGenFunction::stripCond(S);
1198 
1199   auto ExprMCDCConditionIDMapIterator = RegionCondIDMap->find(S);
1200   if (ExprMCDCConditionIDMapIterator == RegionCondIDMap->end())
1201     return;
1202 
1203   // Extract the ID of the condition we are setting in the bitmap.
1204   unsigned CondID = ExprMCDCConditionIDMapIterator->second;
1205   assert(CondID > 0 && "Condition has no ID!");
1206 
1207   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1208 
1209   // Emit intrinsic that updates a dedicated temporary value on the stack after
1210   // a condition is evaluated. After the set of conditions has been updated,
1211   // the resulting value is used to update the boolean expression's bitmap.
1212   llvm::Value *Args[5] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1213                           Builder.getInt64(FunctionHash),
1214                           Builder.getInt32(CondID - 1),
1215                           MCDCCondBitmapAddr.getPointer(), Val};
1216   Builder.CreateCall(
1217       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_condbitmap_update),
1218       Args);
1219 }
1220 
1221 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
1222   if (CGM.getCodeGenOpts().hasProfileClangInstr())
1223     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
1224                     uint32_t(EnableValueProfiling));
1225 }
1226 
1227 // This method either inserts a call to the profile run-time during
1228 // instrumentation or puts profile data into metadata for PGO use.
1229 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
1230     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
1231 
1232   if (!EnableValueProfiling)
1233     return;
1234 
1235   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
1236     return;
1237 
1238   if (isa<llvm::Constant>(ValuePtr))
1239     return;
1240 
1241   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
1242   if (InstrumentValueSites && RegionCounterMap) {
1243     auto BuilderInsertPoint = Builder.saveIP();
1244     Builder.SetInsertPoint(ValueSite);
1245     llvm::Value *Args[5] = {
1246         FuncNameVar,
1247         Builder.getInt64(FunctionHash),
1248         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1249         Builder.getInt32(ValueKind),
1250         Builder.getInt32(NumValueSites[ValueKind]++)
1251     };
1252     Builder.CreateCall(
1253         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1254     Builder.restoreIP(BuilderInsertPoint);
1255     return;
1256   }
1257 
1258   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1259   if (PGOReader && haveRegionCounts()) {
1260     // We record the top most called three functions at each call site.
1261     // Profile metadata contains "VP" string identifying this metadata
1262     // as value profiling data, then a uint32_t value for the value profiling
1263     // kind, a uint64_t value for the total number of times the call is
1264     // executed, followed by the function hash and execution count (uint64_t)
1265     // pairs for each function.
1266     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1267       return;
1268 
1269     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1270                             (llvm::InstrProfValueKind)ValueKind,
1271                             NumValueSites[ValueKind]);
1272 
1273     NumValueSites[ValueKind]++;
1274   }
1275 }
1276 
1277 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1278                                   bool IsInMainFile) {
1279   CGM.getPGOStats().addVisited(IsInMainFile);
1280   RegionCounts.clear();
1281   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1282       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1283   if (auto E = RecordExpected.takeError()) {
1284     auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E)));
1285     if (IPE == llvm::instrprof_error::unknown_function)
1286       CGM.getPGOStats().addMissing(IsInMainFile);
1287     else if (IPE == llvm::instrprof_error::hash_mismatch)
1288       CGM.getPGOStats().addMismatched(IsInMainFile);
1289     else if (IPE == llvm::instrprof_error::malformed)
1290       // TODO: Consider a more specific warning for this case.
1291       CGM.getPGOStats().addMismatched(IsInMainFile);
1292     return;
1293   }
1294   ProfRecord =
1295       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1296   RegionCounts = ProfRecord->Counts;
1297 }
1298 
/// Compute the divisor used to scale weights down.
///
/// For a maximum weight of \p MaxWeight, the returned divisor brings every
/// weight to strictly less than UINT32_MAX. Weights that are already below
/// UINT32_MAX need no scaling, so the divisor is 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1306 
/// Scale one branch weight down to 32 bits and add 1.
///
/// The 64-bit \p Weight is divided by \p Scale. Following Laplace's Rule of
/// Succession, the weight is better computed from the count plus 1, so 1 is
/// added to every result.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Quotient = Weight / Scale;
  const uint64_t Smoothed = Quotient + 1;
  assert(Smoothed <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Smoothed);
}
1322 
1323 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1324                                                     uint64_t FalseCount) const {
1325   // Check for empty weights.
1326   if (!TrueCount && !FalseCount)
1327     return nullptr;
1328 
1329   // Calculate how to scale down to 32-bits.
1330   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1331 
1332   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1333   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1334                                       scaleBranchWeight(FalseCount, Scale));
1335 }
1336 
1337 llvm::MDNode *
1338 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1339   // We need at least two elements to create meaningful weights.
1340   if (Weights.size() < 2)
1341     return nullptr;
1342 
1343   // Check for empty weights.
1344   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1345   if (MaxWeight == 0)
1346     return nullptr;
1347 
1348   // Calculate how to scale down to 32-bits.
1349   uint64_t Scale = calculateWeightScale(MaxWeight);
1350 
1351   SmallVector<uint32_t, 16> ScaledWeights;
1352   ScaledWeights.reserve(Weights.size());
1353   for (uint64_t W : Weights)
1354     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1355 
1356   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1357   return MDHelper.createBranchWeights(ScaledWeights);
1358 }
1359 
1360 llvm::MDNode *
1361 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1362                                              uint64_t LoopCount) const {
1363   if (!PGO.haveRegionCounts())
1364     return nullptr;
1365   std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1366   if (!CondCount || *CondCount == 0)
1367     return nullptr;
1368   return createProfileWeights(LoopCount,
1369                               std::max(*CondCount, LoopCount) - LoopCount);
1370 }
1371