xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision b3edf4467982447620505a28fc82e38a414c07dc)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 #include <optional>
25 
26 static llvm::cl::opt<bool>
27     EnableValueProfiling("enable-value-profiling",
28                          llvm::cl::desc("Enable value profiling"),
29                          llvm::cl::Hidden, llvm::cl::init(false));
30 
31 extern llvm::cl::opt<bool> SystemHeadersCoverage;
32 
33 using namespace clang;
34 using namespace CodeGen;
35 
36 void CodeGenPGO::setFuncName(StringRef Name,
37                              llvm::GlobalValue::LinkageTypes Linkage) {
38   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
39   FuncName = llvm::getPGOFuncName(
40       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
41       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
42 
43   // If we're generating a profile, create a variable for the name.
44   if (CGM.getCodeGenOpts().hasProfileClangInstr())
45     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
46 }
47 
48 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
49   setFuncName(Fn->getName(), Fn->getLinkage());
50   // Create PGOFuncName metadata.
51   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
52 }
53 
54 /// The version of the PGO hash algorithm.
55 enum PGOHashVersion : unsigned {
56   PGO_HASH_V1,
57   PGO_HASH_V2,
58   PGO_HASH_V3,
59 
60   // Keep this set to the latest hash version.
61   PGO_HASH_LATEST = PGO_HASH_V3
62 };
63 
64 namespace {
65 /// Stable hasher for PGO region counters.
66 ///
67 /// PGOHash produces a stable hash of a given function's control flow.
68 ///
69 /// Changing the output of this hash will invalidate all previously generated
70 /// profiles -- i.e., don't do it.
71 ///
72 /// \note  When this hash does eventually change (years?), we still need to
73 /// support old hashes.  We'll need to pull in the version number from the
74 /// profile data format and use the matching hash function.
75 class PGOHash {
76   uint64_t Working;
77   unsigned Count;
78   PGOHashVersion HashVersion;
79   llvm::MD5 MD5;
80 
81   static const int NumBitsPerType = 6;
82   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
83   static const unsigned TooBig = 1u << NumBitsPerType;
84 
85 public:
86   /// Hash values for AST nodes.
87   ///
88   /// Distinct values for AST nodes that have region counters attached.
89   ///
90   /// These values must be stable.  All new members must be added at the end,
91   /// and no members should be removed.  Changing the enumeration value for an
92   /// AST node will affect the hash of every function that contains that node.
93   enum HashType : unsigned char {
94     None = 0,
95     LabelStmt = 1,
96     WhileStmt,
97     DoStmt,
98     ForStmt,
99     CXXForRangeStmt,
100     ObjCForCollectionStmt,
101     SwitchStmt,
102     CaseStmt,
103     DefaultStmt,
104     IfStmt,
105     CXXTryStmt,
106     CXXCatchStmt,
107     ConditionalOperator,
108     BinaryOperatorLAnd,
109     BinaryOperatorLOr,
110     BinaryConditionalOperator,
111     // The preceding values are available with PGO_HASH_V1.
112 
113     EndOfScope,
114     IfThenBranch,
115     IfElseBranch,
116     GotoStmt,
117     IndirectGotoStmt,
118     BreakStmt,
119     ContinueStmt,
120     ReturnStmt,
121     ThrowExpr,
122     UnaryOperatorLNot,
123     BinaryOperatorLT,
124     BinaryOperatorGT,
125     BinaryOperatorLE,
126     BinaryOperatorGE,
127     BinaryOperatorEQ,
128     BinaryOperatorNE,
129     // The preceding values are available since PGO_HASH_V2.
130 
131     // Keep this last.  It's for the static assert that follows.
132     LastHashType
133   };
134   static_assert(LastHashType <= TooBig, "Too many types in HashType");
135 
136   PGOHash(PGOHashVersion HashVersion)
137       : Working(0), Count(0), HashVersion(HashVersion) {}
138   void combine(HashType Type);
139   uint64_t finalize();
140   PGOHashVersion getHashVersion() const { return HashVersion; }
141 };
142 const int PGOHash::NumBitsPerType;
143 const unsigned PGOHash::NumTypesPerWord;
144 const unsigned PGOHash::TooBig;
145 
146 /// Get the PGO hash version used in the given indexed profile.
147 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
148                                         CodeGenModule &CGM) {
149   if (PGOReader->getVersion() <= 4)
150     return PGO_HASH_V1;
151   if (PGOReader->getVersion() <= 5)
152     return PGO_HASH_V2;
153   return PGO_HASH_V3;
154 }
155 
156 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
157 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
158   using Base = RecursiveASTVisitor<MapRegionCounters>;
159 
160   /// The next counter value to assign.
161   unsigned NextCounter;
162   /// The function hash.
163   PGOHash Hash;
164   /// The map of statements to counters.
165   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
166   /// The next bitmap byte index to assign.
167   unsigned NextMCDCBitmapIdx;
168   /// The map of statements to MC/DC bitmap coverage objects.
169   llvm::DenseMap<const Stmt *, unsigned> &MCDCBitmapMap;
170   /// Maximum number of supported MC/DC conditions in a boolean expression.
171   unsigned MCDCMaxCond;
172   /// The profile version.
173   uint64_t ProfileVersion;
174   /// Diagnostics Engine used to report warnings.
175   DiagnosticsEngine &Diag;
176 
177   MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
178                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap,
179                     llvm::DenseMap<const Stmt *, unsigned> &MCDCBitmapMap,
180                     unsigned MCDCMaxCond, DiagnosticsEngine &Diag)
181       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
182         NextMCDCBitmapIdx(0), MCDCBitmapMap(MCDCBitmapMap),
183         MCDCMaxCond(MCDCMaxCond), ProfileVersion(ProfileVersion), Diag(Diag) {}
184 
185   // Blocks and lambdas are handled as separate functions, so we need not
186   // traverse them in the parent context.
187   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
188   bool TraverseLambdaExpr(LambdaExpr *LE) {
189     // Traverse the captures, but not the body.
190     for (auto C : zip(LE->captures(), LE->capture_inits()))
191       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
192     return true;
193   }
194   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
195 
196   bool VisitDecl(const Decl *D) {
197     switch (D->getKind()) {
198     default:
199       break;
200     case Decl::Function:
201     case Decl::CXXMethod:
202     case Decl::CXXConstructor:
203     case Decl::CXXDestructor:
204     case Decl::CXXConversion:
205     case Decl::ObjCMethod:
206     case Decl::Block:
207     case Decl::Captured:
208       CounterMap[D->getBody()] = NextCounter++;
209       break;
210     }
211     return true;
212   }
213 
214   /// If \p S gets a fresh counter, update the counter mappings. Return the
215   /// V1 hash of \p S.
216   PGOHash::HashType updateCounterMappings(Stmt *S) {
217     auto Type = getHashType(PGO_HASH_V1, S);
218     if (Type != PGOHash::None)
219       CounterMap[S] = NextCounter++;
220     return Type;
221   }
222 
223   /// The following stacks are used with dataTraverseStmtPre() and
224   /// dataTraverseStmtPost() to track the depth of nested logical operators in a
225   /// boolean expression in a function.  The ultimate purpose is to keep track
226   /// of the number of leaf-level conditions in the boolean expression so that a
227   /// profile bitmap can be allocated based on that number.
228   ///
229   /// The stacks are also used to find error cases and notify the user.  A
230   /// standard logical operator nest for a boolean expression could be in a form
231   /// similar to this: "x = a && b && c && (d || f)"
232   unsigned NumCond = 0;
233   bool SplitNestedLogicalOp = false;
234   SmallVector<const Stmt *, 16> NonLogOpStack;
235   SmallVector<const BinaryOperator *, 16> LogOpStack;
236 
237   // Hook: dataTraverseStmtPre() is invoked prior to visiting an AST Stmt node.
238   bool dataTraverseStmtPre(Stmt *S) {
239     /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
240     if (MCDCMaxCond == 0)
241       return true;
242 
243     /// At the top of the logical operator nest, reset the number of conditions.
244     if (LogOpStack.empty())
245       NumCond = 0;
246 
247     if (const Expr *E = dyn_cast<Expr>(S)) {
248       const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
249       if (BinOp && BinOp->isLogicalOp()) {
250         /// Check for "split-nested" logical operators. This happens when a new
251         /// boolean expression logical-op nest is encountered within an existing
252         /// boolean expression, separated by a non-logical operator.  For
253         /// example, in "x = (a && b && c && foo(d && f))", the "d && f" case
254         /// starts a new boolean expression that is separated from the other
255         /// conditions by the operator foo(). Split-nested cases are not
256         /// supported by MC/DC.
257         SplitNestedLogicalOp = SplitNestedLogicalOp || !NonLogOpStack.empty();
258 
259         LogOpStack.push_back(BinOp);
260         return true;
261       }
262     }
263 
264     /// Keep track of non-logical operators. These are OK as long as we don't
265     /// encounter a new logical operator after seeing one.
266     if (!LogOpStack.empty())
267       NonLogOpStack.push_back(S);
268 
269     return true;
270   }
271 
272   // Hook: dataTraverseStmtPost() is invoked by the AST visitor after visiting
273 // an AST Stmt node.  MC/DC will use it to signal when the top of a
274   // logical operation (boolean expression) nest is encountered.
275   bool dataTraverseStmtPost(Stmt *S) {
276     /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
277     if (MCDCMaxCond == 0)
278       return true;
279 
280     if (const Expr *E = dyn_cast<Expr>(S)) {
281       const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
282       if (BinOp && BinOp->isLogicalOp()) {
283         assert(LogOpStack.back() == BinOp);
284         LogOpStack.pop_back();
285 
286         /// At the top of logical operator nest:
287         if (LogOpStack.empty()) {
288           /// Was the "split-nested" logical operator case encountered?
289           if (SplitNestedLogicalOp) {
290             unsigned DiagID = Diag.getCustomDiagID(
291                 DiagnosticsEngine::Warning,
292                 "unsupported MC/DC boolean expression; "
293                 "contains an operation with a nested boolean expression. "
294                 "Expression will not be covered");
295             Diag.Report(S->getBeginLoc(), DiagID);
296             return false;
297           }
298 
299           /// Was the maximum number of conditions encountered?
300           if (NumCond > MCDCMaxCond) {
301             unsigned DiagID = Diag.getCustomDiagID(
302                 DiagnosticsEngine::Warning,
303                 "unsupported MC/DC boolean expression; "
304                 "number of conditions (%0) exceeds max (%1). "
305                 "Expression will not be covered");
306             Diag.Report(S->getBeginLoc(), DiagID) << NumCond << MCDCMaxCond;
307             return false;
308           }
309 
310           // Otherwise, allocate the number of bytes required for the bitmap
311           // based on the number of conditions. Must be at least 1-byte long.
312           MCDCBitmapMap[BinOp] = NextMCDCBitmapIdx;
313           unsigned SizeInBits = std::max<unsigned>(1L << NumCond, CHAR_BIT);
314           NextMCDCBitmapIdx += SizeInBits / CHAR_BIT;
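          // For example (illustrative): a decision with NumCond = 3 needs
          // 2^3 = 8 bits, i.e. one byte of bitmap storage, while NumCond = 6
          // (the current maximum) needs 2^6 = 64 bits, i.e. eight bytes.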
315         }
316         return true;
317       }
318     }
319 
320     if (!LogOpStack.empty())
321       NonLogOpStack.pop_back();
322 
323     return true;
324   }
325 
326   /// The RHS of all logical operators gets a fresh counter in order to count
327   /// how many times the RHS evaluates to true or false, depending on the
328   /// semantics of the operator. This is done only for profile version >= v7 in
329   /// order to preserve backward compatibility with older profiles. In addition,
330   /// to support MC/DC, count the total number of LHS and RHS conditions.
331   bool VisitBinaryOperator(BinaryOperator *S) {
332     if (S->isLogicalOp()) {
333       if (CodeGenFunction::isInstrumentedCondition(S->getLHS()))
334         NumCond++;
335 
336       if (CodeGenFunction::isInstrumentedCondition(S->getRHS())) {
337         if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
338           CounterMap[S->getRHS()] = NextCounter++;
339 
340         NumCond++;
341       }
342     }
343     return Base::VisitBinaryOperator(S);
344   }
345 
346   /// Include \p S in the function hash.
347   bool VisitStmt(Stmt *S) {
348     auto Type = updateCounterMappings(S);
349     if (Hash.getHashVersion() != PGO_HASH_V1)
350       Type = getHashType(Hash.getHashVersion(), S);
351     if (Type != PGOHash::None)
352       Hash.combine(Type);
353     return true;
354   }
355 
356   bool TraverseIfStmt(IfStmt *If) {
357     // If we used the V1 hash, use the default traversal.
358     if (Hash.getHashVersion() == PGO_HASH_V1)
359       return Base::TraverseIfStmt(If);
360 
361     // Otherwise, keep track of which branch we're in while traversing.
362     VisitStmt(If);
363     for (Stmt *CS : If->children()) {
364       if (!CS)
365         continue;
366       if (CS == If->getThen())
367         Hash.combine(PGOHash::IfThenBranch);
368       else if (CS == If->getElse())
369         Hash.combine(PGOHash::IfElseBranch);
370       TraverseStmt(CS);
371     }
372     Hash.combine(PGOHash::EndOfScope);
373     return true;
374   }
375 
376 // If the statement type \p N is nestable, and its nesting impacts profile
377 // stability, define a custom traversal which tracks the end of the statement
378 // in the hash (provided we're not using the V1 hash).
379 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
380   bool Traverse##N(N *S) {                                                     \
381     Base::Traverse##N(S);                                                      \
382     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
383       Hash.combine(PGOHash::EndOfScope);                                       \
384     return true;                                                               \
385   }
386 
387   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
388   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
389   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
390   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
391   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
392   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
393   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
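  // Illustrative example (not from the source): the EndOfScope marker is what
  // lets V2+ hashes distinguish nestings that V1 conflates. For
  //   "while (a) { foo(); } if (b) {}"   the hash stream is
  //   WhileStmt, EndOfScope, IfStmt, IfThenBranch, EndOfScope,
  // whereas for "while (a) { foo(); if (b) {} }" it is
  //   WhileStmt, IfStmt, IfThenBranch, EndOfScope, EndOfScope.
  // Under PGO_HASH_V1 both reduce to WhileStmt, IfStmt and hash identically.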
394 
395   /// Get version \p HashVersion of the PGO hash for \p S.
396   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
397     switch (S->getStmtClass()) {
398     default:
399       break;
400     case Stmt::LabelStmtClass:
401       return PGOHash::LabelStmt;
402     case Stmt::WhileStmtClass:
403       return PGOHash::WhileStmt;
404     case Stmt::DoStmtClass:
405       return PGOHash::DoStmt;
406     case Stmt::ForStmtClass:
407       return PGOHash::ForStmt;
408     case Stmt::CXXForRangeStmtClass:
409       return PGOHash::CXXForRangeStmt;
410     case Stmt::ObjCForCollectionStmtClass:
411       return PGOHash::ObjCForCollectionStmt;
412     case Stmt::SwitchStmtClass:
413       return PGOHash::SwitchStmt;
414     case Stmt::CaseStmtClass:
415       return PGOHash::CaseStmt;
416     case Stmt::DefaultStmtClass:
417       return PGOHash::DefaultStmt;
418     case Stmt::IfStmtClass:
419       return PGOHash::IfStmt;
420     case Stmt::CXXTryStmtClass:
421       return PGOHash::CXXTryStmt;
422     case Stmt::CXXCatchStmtClass:
423       return PGOHash::CXXCatchStmt;
424     case Stmt::ConditionalOperatorClass:
425       return PGOHash::ConditionalOperator;
426     case Stmt::BinaryConditionalOperatorClass:
427       return PGOHash::BinaryConditionalOperator;
428     case Stmt::BinaryOperatorClass: {
429       const BinaryOperator *BO = cast<BinaryOperator>(S);
430       if (BO->getOpcode() == BO_LAnd)
431         return PGOHash::BinaryOperatorLAnd;
432       if (BO->getOpcode() == BO_LOr)
433         return PGOHash::BinaryOperatorLOr;
434       if (HashVersion >= PGO_HASH_V2) {
435         switch (BO->getOpcode()) {
436         default:
437           break;
438         case BO_LT:
439           return PGOHash::BinaryOperatorLT;
440         case BO_GT:
441           return PGOHash::BinaryOperatorGT;
442         case BO_LE:
443           return PGOHash::BinaryOperatorLE;
444         case BO_GE:
445           return PGOHash::BinaryOperatorGE;
446         case BO_EQ:
447           return PGOHash::BinaryOperatorEQ;
448         case BO_NE:
449           return PGOHash::BinaryOperatorNE;
450         }
451       }
452       break;
453     }
454     }
455 
456     if (HashVersion >= PGO_HASH_V2) {
457       switch (S->getStmtClass()) {
458       default:
459         break;
460       case Stmt::GotoStmtClass:
461         return PGOHash::GotoStmt;
462       case Stmt::IndirectGotoStmtClass:
463         return PGOHash::IndirectGotoStmt;
464       case Stmt::BreakStmtClass:
465         return PGOHash::BreakStmt;
466       case Stmt::ContinueStmtClass:
467         return PGOHash::ContinueStmt;
468       case Stmt::ReturnStmtClass:
469         return PGOHash::ReturnStmt;
470       case Stmt::CXXThrowExprClass:
471         return PGOHash::ThrowExpr;
472       case Stmt::UnaryOperatorClass: {
473         const UnaryOperator *UO = cast<UnaryOperator>(S);
474         if (UO->getOpcode() == UO_LNot)
475           return PGOHash::UnaryOperatorLNot;
476         break;
477       }
478       }
479     }
480 
481     return PGOHash::None;
482   }
483 };
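// Worked example (illustrative): for "void f(int x) { if (x) g(); }",
// MapRegionCounters assigns counter 0 to the function body and counter 1 to
// the IfStmt, and (for V2+ hashes) combines IfStmt, IfThenBranch, and
// EndOfScope into the function hash; the call to g() contributes nothing.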
484 
485 /// A StmtVisitor that propagates the raw counts through the AST and
486 /// records the count at statements where the value may change.
487 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
488   /// PGO state.
489   CodeGenPGO &PGO;
490 
491   /// A flag that is set when the current count should be recorded on the
492   /// next statement, such as at the exit of a loop.
493   bool RecordNextStmtCount;
494 
495   /// The count at the current location in the traversal.
496   uint64_t CurrentCount;
497 
498   /// The map of statements to count values.
499   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
500 
501   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
502   struct BreakContinue {
503     uint64_t BreakCount = 0;
504     uint64_t ContinueCount = 0;
505     BreakContinue() = default;
506   };
507   SmallVector<BreakContinue, 8> BreakContinueStack;
508 
509   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
510                       CodeGenPGO &PGO)
511       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
512 
513   void RecordStmtCount(const Stmt *S) {
514     if (RecordNextStmtCount) {
515       CountMap[S] = CurrentCount;
516       RecordNextStmtCount = false;
517     }
518   }
519 
520   /// Set and return the current count.
521   uint64_t setCount(uint64_t Count) {
522     CurrentCount = Count;
523     return Count;
524   }
525 
526   void VisitStmt(const Stmt *S) {
527     RecordStmtCount(S);
528     for (const Stmt *Child : S->children())
529       if (Child)
530         this->Visit(Child);
531   }
532 
533   void VisitFunctionDecl(const FunctionDecl *D) {
534     // Counter tracks entry to the function body.
535     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
536     CountMap[D->getBody()] = BodyCount;
537     Visit(D->getBody());
538   }
539 
540   // Skip lambda expressions. We visit these as FunctionDecls when we're
541   // generating them and aren't interested in the body when generating a
542   // parent context.
543   void VisitLambdaExpr(const LambdaExpr *LE) {}
544 
545   void VisitCapturedDecl(const CapturedDecl *D) {
546     // Counter tracks entry to the capture body.
547     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
548     CountMap[D->getBody()] = BodyCount;
549     Visit(D->getBody());
550   }
551 
552   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
553     // Counter tracks entry to the method body.
554     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
555     CountMap[D->getBody()] = BodyCount;
556     Visit(D->getBody());
557   }
558 
559   void VisitBlockDecl(const BlockDecl *D) {
560     // Counter tracks entry to the block body.
561     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
562     CountMap[D->getBody()] = BodyCount;
563     Visit(D->getBody());
564   }
565 
566   void VisitReturnStmt(const ReturnStmt *S) {
567     RecordStmtCount(S);
568     if (S->getRetValue())
569       Visit(S->getRetValue());
570     CurrentCount = 0;
571     RecordNextStmtCount = true;
572   }
573 
574   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
575     RecordStmtCount(E);
576     if (E->getSubExpr())
577       Visit(E->getSubExpr());
578     CurrentCount = 0;
579     RecordNextStmtCount = true;
580   }
581 
582   void VisitGotoStmt(const GotoStmt *S) {
583     RecordStmtCount(S);
584     CurrentCount = 0;
585     RecordNextStmtCount = true;
586   }
587 
588   void VisitLabelStmt(const LabelStmt *S) {
589     RecordNextStmtCount = false;
590     // Counter tracks the block following the label.
591     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
592     CountMap[S] = BlockCount;
593     Visit(S->getSubStmt());
594   }
595 
596   void VisitBreakStmt(const BreakStmt *S) {
597     RecordStmtCount(S);
598     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
599     BreakContinueStack.back().BreakCount += CurrentCount;
600     CurrentCount = 0;
601     RecordNextStmtCount = true;
602   }
603 
604   void VisitContinueStmt(const ContinueStmt *S) {
605     RecordStmtCount(S);
606     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
607     BreakContinueStack.back().ContinueCount += CurrentCount;
608     CurrentCount = 0;
609     RecordNextStmtCount = true;
610   }
611 
612   void VisitWhileStmt(const WhileStmt *S) {
613     RecordStmtCount(S);
614     uint64_t ParentCount = CurrentCount;
615 
616     BreakContinueStack.push_back(BreakContinue());
617     // Visit the body region first so the break/continue adjustments can be
618     // included when visiting the condition.
619     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
620     CountMap[S->getBody()] = CurrentCount;
621     Visit(S->getBody());
622     uint64_t BackedgeCount = CurrentCount;
623 
624     // ...then go back and propagate counts through the condition. The count
625     // at the start of the condition is the sum of the incoming edges,
626     // the backedge from the end of the loop body, and the edges from
627     // continue statements.
628     BreakContinue BC = BreakContinueStack.pop_back_val();
629     uint64_t CondCount =
630         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
631     CountMap[S->getCond()] = CondCount;
632     Visit(S->getCond());
633     setCount(BC.BreakCount + CondCount - BodyCount);
634     RecordNextStmtCount = true;
635   }
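  // Worked example (illustrative): if the loop is entered 10 times
  // (ParentCount = 10) and its body runs 30 times (BodyCount = 30) with no
  // break or continue, then BackedgeCount = 30, the condition is evaluated
  // 10 + 30 = 40 times, and the exit count is 40 - 30 = 10, matching the
  // count that flowed in from the parent.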
636 
637   void VisitDoStmt(const DoStmt *S) {
638     RecordStmtCount(S);
639     uint64_t LoopCount = PGO.getRegionCount(S);
640 
641     BreakContinueStack.push_back(BreakContinue());
642     // The count doesn't include the fallthrough from the parent scope. Add it.
643     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
644     CountMap[S->getBody()] = BodyCount;
645     Visit(S->getBody());
646     uint64_t BackedgeCount = CurrentCount;
647 
648     BreakContinue BC = BreakContinueStack.pop_back_val();
649     // The count at the start of the condition is equal to the count at the
650     // end of the body, plus any continues.
651     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
652     CountMap[S->getCond()] = CondCount;
653     Visit(S->getCond());
654     setCount(BC.BreakCount + CondCount - LoopCount);
655     RecordNextStmtCount = true;
656   }
657 
658   void VisitForStmt(const ForStmt *S) {
659     RecordStmtCount(S);
660     if (S->getInit())
661       Visit(S->getInit());
662 
663     uint64_t ParentCount = CurrentCount;
664 
665     BreakContinueStack.push_back(BreakContinue());
666     // Visit the body region first. (This is basically the same as a while
667     // loop; see further comments in VisitWhileStmt.)
668     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
669     CountMap[S->getBody()] = BodyCount;
670     Visit(S->getBody());
671     uint64_t BackedgeCount = CurrentCount;
672     BreakContinue BC = BreakContinueStack.pop_back_val();
673 
674     // The increment is essentially part of the body but it needs to include
675     // the count for all the continue statements.
676     if (S->getInc()) {
677       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
678       CountMap[S->getInc()] = IncCount;
679       Visit(S->getInc());
680     }
681 
682     // ...then go back and propagate counts through the condition.
683     uint64_t CondCount =
684         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
685     if (S->getCond()) {
686       CountMap[S->getCond()] = CondCount;
687       Visit(S->getCond());
688     }
689     setCount(BC.BreakCount + CondCount - BodyCount);
690     RecordNextStmtCount = true;
691   }
692 
693   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
694     RecordStmtCount(S);
695     if (S->getInit())
696       Visit(S->getInit());
697     Visit(S->getLoopVarStmt());
698     Visit(S->getRangeStmt());
699     Visit(S->getBeginStmt());
700     Visit(S->getEndStmt());
701 
702     uint64_t ParentCount = CurrentCount;
703     BreakContinueStack.push_back(BreakContinue());
704     // Visit the body region first. (This is basically the same as a while
705     // loop; see further comments in VisitWhileStmt.)
706     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
707     CountMap[S->getBody()] = BodyCount;
708     Visit(S->getBody());
709     uint64_t BackedgeCount = CurrentCount;
710     BreakContinue BC = BreakContinueStack.pop_back_val();
711 
712     // The increment is essentially part of the body but it needs to include
713     // the count for all the continue statements.
714     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
715     CountMap[S->getInc()] = IncCount;
716     Visit(S->getInc());
717 
718     // ...then go back and propagate counts through the condition.
719     uint64_t CondCount =
720         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
721     CountMap[S->getCond()] = CondCount;
722     Visit(S->getCond());
723     setCount(BC.BreakCount + CondCount - BodyCount);
724     RecordNextStmtCount = true;
725   }
726 
727   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
728     RecordStmtCount(S);
729     Visit(S->getElement());
730     uint64_t ParentCount = CurrentCount;
731     BreakContinueStack.push_back(BreakContinue());
732     // Counter tracks the body of the loop.
733     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
734     CountMap[S->getBody()] = BodyCount;
735     Visit(S->getBody());
736     uint64_t BackedgeCount = CurrentCount;
737     BreakContinue BC = BreakContinueStack.pop_back_val();
738 
739     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
740              BodyCount);
741     RecordNextStmtCount = true;
742   }
743 
744   void VisitSwitchStmt(const SwitchStmt *S) {
745     RecordStmtCount(S);
746     if (S->getInit())
747       Visit(S->getInit());
748     Visit(S->getCond());
749     CurrentCount = 0;
750     BreakContinueStack.push_back(BreakContinue());
751     Visit(S->getBody());
752     // If the switch is inside a loop, add the continue counts.
753     BreakContinue BC = BreakContinueStack.pop_back_val();
754     if (!BreakContinueStack.empty())
755       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
756     // Counter tracks the exit block of the switch.
757     setCount(PGO.getRegionCount(S));
758     RecordNextStmtCount = true;
759   }
760 
761   void VisitSwitchCase(const SwitchCase *S) {
762     RecordNextStmtCount = false;
763     // Counter for this particular case. This counts only jumps from the
764     // switch header and does not include fallthrough from the case before
765     // this one.
766     uint64_t CaseCount = PGO.getRegionCount(S);
767     setCount(CurrentCount + CaseCount);
768     // We need the count without fallthrough in the mapping, so it's more useful
769     // for branch probabilities.
770     CountMap[S] = CaseCount;
771     RecordNextStmtCount = true;
772     Visit(S->getSubStmt());
773   }
774 
775   void VisitIfStmt(const IfStmt *S) {
776     RecordStmtCount(S);
777 
778     if (S->isConsteval()) {
779       const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
780       if (Stm)
781         Visit(Stm);
782       return;
783     }
784 
785     uint64_t ParentCount = CurrentCount;
786     if (S->getInit())
787       Visit(S->getInit());
788     Visit(S->getCond());
789 
790     // Counter tracks the "then" part of an if statement. The count for
791     // the "else" part, if it exists, will be calculated from this counter.
792     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
793     CountMap[S->getThen()] = ThenCount;
794     Visit(S->getThen());
795     uint64_t OutCount = CurrentCount;
796 
797     uint64_t ElseCount = ParentCount - ThenCount;
798     if (S->getElse()) {
799       setCount(ElseCount);
800       CountMap[S->getElse()] = ElseCount;
801       Visit(S->getElse());
802       OutCount += CurrentCount;
803     } else
804       OutCount += ElseCount;
805     setCount(OutCount);
806     RecordNextStmtCount = true;
807   }
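  // Worked example (illustrative): if the "if" is reached 100 times
  // (ParentCount = 100) and the "then" branch runs 70 times (ThenCount = 70),
  // the "else" count is derived as 100 - 70 = 30, and with no early exits the
  // merged OutCount is 70 + 30 = 100.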
808 
809   void VisitCXXTryStmt(const CXXTryStmt *S) {
810     RecordStmtCount(S);
811     Visit(S->getTryBlock());
812     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
813       Visit(S->getHandler(I));
814     // Counter tracks the continuation block of the try statement.
815     setCount(PGO.getRegionCount(S));
816     RecordNextStmtCount = true;
817   }
818 
819   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
820     RecordNextStmtCount = false;
821     // Counter tracks the catch statement's handler block.
822     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
823     CountMap[S] = CatchCount;
824     Visit(S->getHandlerBlock());
825   }
826 
827   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
828     RecordStmtCount(E);
829     uint64_t ParentCount = CurrentCount;
830     Visit(E->getCond());
831 
832     // Counter tracks the "true" part of a conditional operator. The
833     // count in the "false" part will be calculated from this counter.
834     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
835     CountMap[E->getTrueExpr()] = TrueCount;
836     Visit(E->getTrueExpr());
837     uint64_t OutCount = CurrentCount;
838 
839     uint64_t FalseCount = setCount(ParentCount - TrueCount);
840     CountMap[E->getFalseExpr()] = FalseCount;
841     Visit(E->getFalseExpr());
842     OutCount += CurrentCount;
843 
844     setCount(OutCount);
845     RecordNextStmtCount = true;
846   }
847 
848   void VisitBinLAnd(const BinaryOperator *E) {
849     RecordStmtCount(E);
850     uint64_t ParentCount = CurrentCount;
851     Visit(E->getLHS());
852     // Counter tracks the right hand side of a logical and operator.
853     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
854     CountMap[E->getRHS()] = RHSCount;
855     Visit(E->getRHS());
856     setCount(ParentCount + RHSCount - CurrentCount);
857     RecordNextStmtCount = true;
858   }
859 
860   void VisitBinLOr(const BinaryOperator *E) {
861     RecordStmtCount(E);
862     uint64_t ParentCount = CurrentCount;
863     Visit(E->getLHS());
864     // Counter tracks the right hand side of a logical or operator.
865     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
866     CountMap[E->getRHS()] = RHSCount;
867     Visit(E->getRHS());
868     setCount(ParentCount + RHSCount - CurrentCount);
869     RecordNextStmtCount = true;
870   }
871 };
872 } // end anonymous namespace
873 
874 void PGOHash::combine(HashType Type) {
875   // Check that we never combine 0 and only have six bits.
876   assert(Type && "Hash is invalid: unexpected type 0");
877   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
878 
879   // Pass through MD5 if enough work has built up.
880   if (Count && Count % NumTypesPerWord == 0) {
881     using namespace llvm::support;
882     uint64_t Swapped =
883         endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
884     MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
885     Working = 0;
886   }
887 
888   // Accumulate the current type.
889   ++Count;
890   Working = Working << NumBitsPerType | Type;
891 }
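// Illustrative sketch (not part of the implementation): with NumBitsPerType =
// 6, NumTypesPerWord = 64 / 6 = 10 types fit in one 64-bit word. Combining
// WhileStmt (2) and then IfStmt (10) leaves Working = (2 << 6) | 10 = 138;
// only on the eleventh combine() is the accumulated word flushed into the MD5
// stream (as little-endian bytes) before the new type is accumulated.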
892 
893 uint64_t PGOHash::finalize() {
894   // Use Working as the hash directly if we never used MD5.
895   if (Count <= NumTypesPerWord)
896     // No need to byte swap here, since none of the math was endian-dependent.
897     // This number will be byte-swapped as required on endianness transitions,
898     // so we will see the same value on the other side.
899     return Working;
900 
901   // Check for remaining work in Working.
902   if (Working) {
903     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
904     // is buggy because it converts a uint64_t into an array of uint8_t.
905     if (HashVersion < PGO_HASH_V3) {
906       MD5.update({(uint8_t)Working});
907     } else {
908       using namespace llvm::support;
909       uint64_t Swapped =
910           endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
911       MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
912     }
913   }
914 
915   // Finalize the MD5 and return the hash.
916   llvm::MD5::MD5Result Result;
917   MD5.final(Result);
918   return Result.low();
919 }
920 
921 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
922   const Decl *D = GD.getDecl();
923   if (!D->hasBody())
924     return;
925 
926   // Skip CUDA/HIP kernel launch stub functions.
927   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
928       D->hasAttr<CUDAGlobalAttr>())
929     return;
930 
931   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
932   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
933   if (!InstrumentRegions && !PGOReader)
934     return;
935   if (D->isImplicit())
936     return;
937   // Constructors and destructors may be represented by several functions in IR.
938   // If so, instrument only the base variant; the others are implemented by
939   // delegation to the base one and would otherwise be counted twice.
940   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
941     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
942       if (GD.getCtorType() != Ctor_Base &&
943           CodeGenFunction::IsConstructorDelegationValid(CCD))
944         return;
945   }
946   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
947     return;
948 
949   CGM.ClearUnusedCoverageMapping(D);
950   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
951     return;
952   if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
953     return;
954 
955   setFuncName(Fn);
956 
957   mapRegionCounters(D);
958   if (CGM.getCodeGenOpts().CoverageMapping)
959     emitCounterRegionMapping(D);
960   if (PGOReader) {
961     SourceManager &SM = CGM.getContext().getSourceManager();
962     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
963     computeRegionCounts(D);
964     applyFunctionAttributes(PGOReader, Fn);
965   }
966 }
967 
968 void CodeGenPGO::mapRegionCounters(const Decl *D) {
969   // Use the latest hash version when inserting instrumentation, but use the
970   // version in the indexed profile if we're reading PGO data.
971   PGOHashVersion HashVersion = PGO_HASH_LATEST;
972   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
973   if (auto *PGOReader = CGM.getPGOReader()) {
974     HashVersion = getPGOHashVersion(PGOReader, CGM);
975     ProfileVersion = PGOReader->getVersion();
976   }
977 
978   // If MC/DC is enabled, set the MaxConditions to a preset value. Otherwise,
979   // set it to zero. This value impacts the number of conditions accepted in a
980   // given boolean expression, which impacts the size of the bitmap used to
981   // track test vector execution for that boolean expression.  Because the
982   // bitmap scales exponentially (2^n) based on the number of conditions seen,
983   // the maximum value is hard-coded at 6 conditions, which is more than enough
984   // for most embedded applications. Setting a maximum value prevents the
985   // bitmap footprint from growing too large without the user's knowledge. In
986   // the future, this value could be adjusted with a command-line option.
987   unsigned MCDCMaxConditions = (CGM.getCodeGenOpts().MCDCCoverage) ? 6 : 0;
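  // For example (illustrative): with the cap of 6, a single boolean expression
  // needs at most a 2^6-bit (8-byte) bitmap; raising the cap to 20 conditions
  // would require 2^20 bits (128 KiB) for one expression.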
988 
989   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
990   RegionMCDCBitmapMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
991   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap,
992                            *RegionMCDCBitmapMap, MCDCMaxConditions,
993                            CGM.getDiags());
994   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
995     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
996   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
997     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
998   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
999     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
1000   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1001     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
1002   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
1003   NumRegionCounters = Walker.NextCounter;
1004   MCDCBitmapBytes = Walker.NextMCDCBitmapIdx;
1005   FunctionHash = Walker.Hash.finalize();
1006 }
1007 
1008 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
1009   if (!D->getBody())
1010     return true;
1011 
1012   // Skip host-only functions in the CUDA device compilation and device-only
1013   // functions in the host compilation. Just roughly filter them out based on
1014   // the function attributes. Functions that are effectively host-only or
1015   // device-only but lack those attributes may still get coverage mappings.
1016   if (CGM.getLangOpts().CUDA &&
1017       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
1018         !D->hasAttr<CUDAGlobalAttr>()) ||
1019        (!CGM.getLangOpts().CUDAIsDevice &&
1020         (D->hasAttr<CUDAGlobalAttr>() ||
1021          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
1022     return true;
1023 
1024   // Don't map the functions in system headers.
1025   const auto &SM = CGM.getContext().getSourceManager();
1026   auto Loc = D->getBody()->getBeginLoc();
1027   return !SystemHeadersCoverage && SM.isInSystemHeader(Loc);
1028 }
1029 
1030 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
1031   if (skipRegionMappingForDecl(D))
1032     return;
1033 
1034   std::string CoverageMapping;
1035   llvm::raw_string_ostream OS(CoverageMapping);
1036   RegionCondIDMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
1037   CoverageMappingGen MappingGen(
1038       *CGM.getCoverageMapping(), CGM.getContext().getSourceManager(),
1039       CGM.getLangOpts(), RegionCounterMap.get(), RegionMCDCBitmapMap.get(),
1040       RegionCondIDMap.get());
1041   MappingGen.emitCounterMapping(D, OS);
1042   OS.flush();
1043 
1044   if (CoverageMapping.empty())
1045     return;
1046 
1047   CGM.getCoverageMapping()->addFunctionMappingRecord(
1048       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
1049 }
1050 
1051 void
1052 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
1053                                     llvm::GlobalValue::LinkageTypes Linkage) {
1054   if (skipRegionMappingForDecl(D))
1055     return;
1056 
1057   std::string CoverageMapping;
1058   llvm::raw_string_ostream OS(CoverageMapping);
1059   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
1060                                 CGM.getContext().getSourceManager(),
1061                                 CGM.getLangOpts());
1062   MappingGen.emitEmptyMapping(D, OS);
1063   OS.flush();
1064 
1065   if (CoverageMapping.empty())
1066     return;
1067 
1068   setFuncName(Name, Linkage);
1069   CGM.getCoverageMapping()->addFunctionMappingRecord(
1070       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
1071 }
1072 
1073 void CodeGenPGO::computeRegionCounts(const Decl *D) {
1074   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
1075   ComputeRegionCounts Walker(*StmtCountMap, *this);
1076   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1077     Walker.VisitFunctionDecl(FD);
1078   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1079     Walker.VisitObjCMethodDecl(MD);
1080   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1081     Walker.VisitBlockDecl(BD);
1082   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1083     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
1084 }
1085 
1086 void
1087 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
1088                                     llvm::Function *Fn) {
1089   if (!haveRegionCounts())
1090     return;
1091 
1092   uint64_t FunctionCount = getRegionCount(nullptr);
1093   Fn->setEntryCount(FunctionCount);
1094 }
1095 
1096 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
1097                                       llvm::Value *StepV) {
1098   if (!RegionCounterMap || !Builder.GetInsertBlock())
1099     return;
1100 
1101   unsigned Counter = (*RegionCounterMap)[S];
1102 
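  // The call emitted below looks roughly like this in IR (an illustrative
  // sketch; the function name is a placeholder):
  //   call void @llvm.instrprof.increment(ptr @__profn_foo, i64 <hash>,
  //                                        i32 <num-counters>, i32 <index>)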
1103   llvm::Value *Args[] = {FuncNameVar,
1104                          Builder.getInt64(FunctionHash),
1105                          Builder.getInt32(NumRegionCounters),
1106                          Builder.getInt32(Counter), StepV};
1107   if (!StepV)
1108     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
1109                        ArrayRef(Args, 4));
1110   else
1111     Builder.CreateCall(
1112         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
1113         ArrayRef(Args));
1114 }
1115 
1116 bool CodeGenPGO::canEmitMCDCCoverage(const CGBuilderTy &Builder) {
1117   return (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1118           CGM.getCodeGenOpts().MCDCCoverage && Builder.GetInsertBlock());
1119 }
1120 
1121 void CodeGenPGO::emitMCDCParameters(CGBuilderTy &Builder) {
1122   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1123     return;
1124 
1125   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1126 
1127   // Emit intrinsic representing MCDC bitmap parameters at function entry.
1128   // This is used by the instrumentation pass, but it isn't actually lowered to
1129   // anything.
1130   llvm::Value *Args[3] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1131                           Builder.getInt64(FunctionHash),
1132                           Builder.getInt32(MCDCBitmapBytes)};
1133   Builder.CreateCall(
1134       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_parameters), Args);
1135 }
1136 
1137 void CodeGenPGO::emitMCDCTestVectorBitmapUpdate(CGBuilderTy &Builder,
1138                                                 const Expr *S,
1139                                                 Address MCDCCondBitmapAddr) {
1140   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1141     return;
1142 
1143   S = S->IgnoreParens();
1144 
1145   auto ExprMCDCBitmapMapIterator = RegionMCDCBitmapMap->find(S);
1146   if (ExprMCDCBitmapMapIterator == RegionMCDCBitmapMap->end())
1147     return;
1148 
1149   // Extract the ID of the global bitmap associated with this expression.
1150   unsigned MCDCTestVectorBitmapID = ExprMCDCBitmapMapIterator->second;
1151   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1152 
1153   // Emit intrinsic responsible for updating the global bitmap corresponding to
1154   // a boolean expression. The index being set is based on the value loaded
1155   // from a pointer to a dedicated temporary value on the stack that is itself
1156   // updated via emitMCDCCondBitmapReset() and emitMCDCCondBitmapUpdate(). The
1157   // index represents an executed test vector.
1158   llvm::Value *Args[5] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1159                           Builder.getInt64(FunctionHash),
1160                           Builder.getInt32(MCDCBitmapBytes),
1161                           Builder.getInt32(MCDCTestVectorBitmapID),
1162                           MCDCCondBitmapAddr.getPointer()};
1163   Builder.CreateCall(
1164       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_tvbitmap_update), Args);
1165 }
1166 
1167 void CodeGenPGO::emitMCDCCondBitmapReset(CGBuilderTy &Builder, const Expr *S,
1168                                          Address MCDCCondBitmapAddr) {
1169   if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1170     return;
1171 
1172   S = S->IgnoreParens();
1173 
1174   if (RegionMCDCBitmapMap->find(S) == RegionMCDCBitmapMap->end())
1175     return;
1176 
1177   // Emit intrinsic that resets a dedicated temporary value on the stack to 0.
1178   Builder.CreateStore(Builder.getInt32(0), MCDCCondBitmapAddr);
1179 }
1180 
1181 void CodeGenPGO::emitMCDCCondBitmapUpdate(CGBuilderTy &Builder, const Expr *S,
1182                                           Address MCDCCondBitmapAddr,
1183                                           llvm::Value *Val) {
1184   if (!canEmitMCDCCoverage(Builder) || !RegionCondIDMap)
1185     return;
1186 
1187   // Even though, for simplicity, parentheses and unary logical-NOT operators
1188   // are considered part of their underlying condition for both MC/DC and
1189   // branch coverage, the condition IDs themselves are assigned and tracked
1190   // using the underlying condition itself.  This is done solely for
1191   // consistency since parentheses and logical-NOTs are ignored when checking
1192   // whether the condition is actually an instrumentable condition. This can
1193   // also make debugging a bit easier.
1194   S = CodeGenFunction::stripCond(S);
1195 
1196   auto ExprMCDCConditionIDMapIterator = RegionCondIDMap->find(S);
1197   if (ExprMCDCConditionIDMapIterator == RegionCondIDMap->end())
1198     return;
1199 
1200   // Extract the ID of the condition we are setting in the bitmap.
1201   unsigned CondID = ExprMCDCConditionIDMapIterator->second;
1202   assert(CondID > 0 && "Condition has no ID!");
1203 
1204   auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1205 
1206   // Emit intrinsic that updates a dedicated temporary value on the stack after
1207   // a condition is evaluated. After the set of conditions has been updated,
1208   // the resulting value is used to update the boolean expression's bitmap.
1209   llvm::Value *Args[5] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1210                           Builder.getInt64(FunctionHash),
1211                           Builder.getInt32(CondID - 1),
1212                           MCDCCondBitmapAddr.getPointer(), Val};
1213   Builder.CreateCall(
1214       CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_condbitmap_update),
1215       Args);
1216 }
1217 
1218 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
1219   if (CGM.getCodeGenOpts().hasProfileClangInstr())
1220     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
1221                     uint32_t(EnableValueProfiling));
1222 }
1223 
1224 // This method either inserts a call to the profile run-time during
1225 // instrumentation or puts profile data into metadata for PGO use.
1226 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
1227     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
1228 
1229   if (!EnableValueProfiling)
1230     return;
1231 
1232   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
1233     return;
1234 
1235   if (isa<llvm::Constant>(ValuePtr))
1236     return;
1237 
1238   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
1239   if (InstrumentValueSites && RegionCounterMap) {
1240     auto BuilderInsertPoint = Builder.saveIP();
1241     Builder.SetInsertPoint(ValueSite);
1242     llvm::Value *Args[5] = {
1243         FuncNameVar,
1244         Builder.getInt64(FunctionHash),
1245         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1246         Builder.getInt32(ValueKind),
1247         Builder.getInt32(NumValueSites[ValueKind]++)
1248     };
1249     Builder.CreateCall(
1250         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1251     Builder.restoreIP(BuilderInsertPoint);
1252     return;
1253   }
1254 
1255   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1256   if (PGOReader && haveRegionCounts()) {
1257     // We record the three most frequently called functions at each call site.
1258     // Profile metadata contains "VP" string identifying this metadata
1259     // as value profiling data, then a uint32_t value for the value profiling
1260     // kind, a uint64_t value for the total number of times the call is
1261     // executed, followed by the function hash and execution count (uint64_t)
1262     // pairs for each function.
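    // For an indirect call whose profile saw callee foo 90 times and bar 10
    // times, the annotation looks roughly like (illustrative sketch):
    //   !{!"VP", i32 0, i64 100, i64 <md5 of "foo">, i64 90,
    //     i64 <md5 of "bar">, i64 10}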
1263     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1264       return;
1265 
1266     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1267                             (llvm::InstrProfValueKind)ValueKind,
1268                             NumValueSites[ValueKind]);
1269 
1270     NumValueSites[ValueKind]++;
1271   }
1272 }
1273 
1274 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1275                                   bool IsInMainFile) {
1276   CGM.getPGOStats().addVisited(IsInMainFile);
1277   RegionCounts.clear();
1278   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1279       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1280   if (auto E = RecordExpected.takeError()) {
1281     auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E)));
1282     if (IPE == llvm::instrprof_error::unknown_function)
1283       CGM.getPGOStats().addMissing(IsInMainFile);
1284     else if (IPE == llvm::instrprof_error::hash_mismatch)
1285       CGM.getPGOStats().addMismatched(IsInMainFile);
1286     else if (IPE == llvm::instrprof_error::malformed)
1287       // TODO: Consider a more specific warning for this case.
1288       CGM.getPGOStats().addMismatched(IsInMainFile);
1289     return;
1290   }
1291   ProfRecord =
1292       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1293   RegionCounts = ProfRecord->Counts;
1294 }
1295 
1296 /// Calculate what to divide by to scale weights.
1297 ///
1298 /// Given the maximum weight, calculate a divisor that will scale all the
1299 /// weights to strictly less than UINT32_MAX.
1300 static uint64_t calculateWeightScale(uint64_t MaxWeight) {
1301   return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
1302 }
1303 
1304 /// Scale an individual branch weight (and add 1).
1305 ///
1306 /// Scale a 64-bit weight down to 32-bits using \c Scale.
1307 ///
1308 /// According to Laplace's Rule of Succession, it is better to compute the
1309 /// weight based on the count plus 1, so universally add 1 to the value.
1310 ///
1311 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
1312 /// greater than \c Weight.
1313 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
1314   assert(Scale && "scale by 0?");
1315   uint64_t Scaled = Weight / Scale + 1;
1316   assert(Scaled <= UINT32_MAX && "overflow 32-bits");
1317   return Scaled;
1318 }
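// Worked example (illustrative): for TrueCount = 6e9 and FalseCount = 1e9 the
// scale is 6e9 / UINT32_MAX + 1 = 2, giving branch weights of
// 6e9 / 2 + 1 = 3000000001 and 1e9 / 2 + 1 = 500000001, both of which fit in
// 32 bits; when the maximum count is already below UINT32_MAX the scale is 1
// and each weight is simply the count plus one.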
1319 
1320 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1321                                                     uint64_t FalseCount) const {
1322   // Check for empty weights.
1323   if (!TrueCount && !FalseCount)
1324     return nullptr;
1325 
1326   // Calculate how to scale down to 32-bits.
1327   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1328 
1329   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1330   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1331                                       scaleBranchWeight(FalseCount, Scale));
1332 }
1333 
1334 llvm::MDNode *
1335 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1336   // We need at least two elements to create meaningful weights.
1337   if (Weights.size() < 2)
1338     return nullptr;
1339 
1340   // Check for empty weights.
1341   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1342   if (MaxWeight == 0)
1343     return nullptr;
1344 
1345   // Calculate how to scale down to 32-bits.
1346   uint64_t Scale = calculateWeightScale(MaxWeight);
1347 
1348   SmallVector<uint32_t, 16> ScaledWeights;
1349   ScaledWeights.reserve(Weights.size());
1350   for (uint64_t W : Weights)
1351     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1352 
1353   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1354   return MDHelper.createBranchWeights(ScaledWeights);
1355 }
1356 
1357 llvm::MDNode *
1358 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1359                                              uint64_t LoopCount) const {
1360   if (!PGO.haveRegionCounts())
1361     return nullptr;
1362   std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1363   if (!CondCount || *CondCount == 0)
1364     return nullptr;
1365   return createProfileWeights(LoopCount,
1366                               std::max(*CondCount, LoopCount) - LoopCount);
1367 }
1368