xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision 7ef62cebc2f965b0f640263e179276928885e33d)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 #include <optional>
25 
26 static llvm::cl::opt<bool>
27     EnableValueProfiling("enable-value-profiling",
28                          llvm::cl::desc("Enable value profiling"),
29                          llvm::cl::Hidden, llvm::cl::init(false));
30 
31 using namespace clang;
32 using namespace CodeGen;
33 
34 void CodeGenPGO::setFuncName(StringRef Name,
35                              llvm::GlobalValue::LinkageTypes Linkage) {
36   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
37   FuncName = llvm::getPGOFuncName(
38       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
39       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
40 
41   // If we're generating a profile, create a variable for the name.
42   if (CGM.getCodeGenOpts().hasProfileClangInstr())
43     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
44 }
45 
46 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
47   setFuncName(Fn->getName(), Fn->getLinkage());
48   // Create PGOFuncName meta data.
49   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
50 }
51 
/// Revisions of the PGO hash algorithm. Values are ordered so that a later
/// revision compares greater than an earlier one (callers test e.g.
/// `>= PGO_HASH_V2`).
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Alias for the newest revision; update it whenever a revision is added.
  PGO_HASH_LATEST = PGO_HASH_V3
};
61 
62 namespace {
63 /// Stable hasher for PGO region counters.
64 ///
65 /// PGOHash produces a stable hash of a given function's control flow.
66 ///
67 /// Changing the output of this hash will invalidate all previously generated
68 /// profiles -- i.e., don't do it.
69 ///
70 /// \note  When this hash does eventually change (years?), we still need to
71 /// support old hashes.  We'll need to pull in the version number from the
72 /// profile data format and use the matching hash function.
73 class PGOHash {
74   uint64_t Working;
75   unsigned Count;
76   PGOHashVersion HashVersion;
77   llvm::MD5 MD5;
78 
79   static const int NumBitsPerType = 6;
80   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
81   static const unsigned TooBig = 1u << NumBitsPerType;
82 
83 public:
84   /// Hash values for AST nodes.
85   ///
86   /// Distinct values for AST nodes that have region counters attached.
87   ///
88   /// These values must be stable.  All new members must be added at the end,
89   /// and no members should be removed.  Changing the enumeration value for an
90   /// AST node will affect the hash of every function that contains that node.
91   enum HashType : unsigned char {
92     None = 0,
93     LabelStmt = 1,
94     WhileStmt,
95     DoStmt,
96     ForStmt,
97     CXXForRangeStmt,
98     ObjCForCollectionStmt,
99     SwitchStmt,
100     CaseStmt,
101     DefaultStmt,
102     IfStmt,
103     CXXTryStmt,
104     CXXCatchStmt,
105     ConditionalOperator,
106     BinaryOperatorLAnd,
107     BinaryOperatorLOr,
108     BinaryConditionalOperator,
109     // The preceding values are available with PGO_HASH_V1.
110 
111     EndOfScope,
112     IfThenBranch,
113     IfElseBranch,
114     GotoStmt,
115     IndirectGotoStmt,
116     BreakStmt,
117     ContinueStmt,
118     ReturnStmt,
119     ThrowExpr,
120     UnaryOperatorLNot,
121     BinaryOperatorLT,
122     BinaryOperatorGT,
123     BinaryOperatorLE,
124     BinaryOperatorGE,
125     BinaryOperatorEQ,
126     BinaryOperatorNE,
127     // The preceding values are available since PGO_HASH_V2.
128 
129     // Keep this last.  It's for the static assert that follows.
130     LastHashType
131   };
132   static_assert(LastHashType <= TooBig, "Too many types in HashType");
133 
134   PGOHash(PGOHashVersion HashVersion)
135       : Working(0), Count(0), HashVersion(HashVersion) {}
136   void combine(HashType Type);
137   uint64_t finalize();
138   PGOHashVersion getHashVersion() const { return HashVersion; }
139 };
140 const int PGOHash::NumBitsPerType;
141 const unsigned PGOHash::NumTypesPerWord;
142 const unsigned PGOHash::TooBig;
143 
144 /// Get the PGO hash version used in the given indexed profile.
145 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
146                                         CodeGenModule &CGM) {
147   if (PGOReader->getVersion() <= 4)
148     return PGO_HASH_V1;
149   if (PGOReader->getVersion() <= 5)
150     return PGO_HASH_V2;
151   return PGO_HASH_V3;
152 }
153 
154 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
155 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
156   using Base = RecursiveASTVisitor<MapRegionCounters>;
157 
158   /// The next counter value to assign.
159   unsigned NextCounter;
160   /// The function hash.
161   PGOHash Hash;
162   /// The map of statements to counters.
163   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
164   /// The profile version.
165   uint64_t ProfileVersion;
166 
167   MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
168                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
169       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
170         ProfileVersion(ProfileVersion) {}
171 
172   // Blocks and lambdas are handled as separate functions, so we need not
173   // traverse them in the parent context.
174   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
175   bool TraverseLambdaExpr(LambdaExpr *LE) {
176     // Traverse the captures, but not the body.
177     for (auto C : zip(LE->captures(), LE->capture_inits()))
178       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
179     return true;
180   }
181   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
182 
183   bool VisitDecl(const Decl *D) {
184     switch (D->getKind()) {
185     default:
186       break;
187     case Decl::Function:
188     case Decl::CXXMethod:
189     case Decl::CXXConstructor:
190     case Decl::CXXDestructor:
191     case Decl::CXXConversion:
192     case Decl::ObjCMethod:
193     case Decl::Block:
194     case Decl::Captured:
195       CounterMap[D->getBody()] = NextCounter++;
196       break;
197     }
198     return true;
199   }
200 
201   /// If \p S gets a fresh counter, update the counter mappings. Return the
202   /// V1 hash of \p S.
203   PGOHash::HashType updateCounterMappings(Stmt *S) {
204     auto Type = getHashType(PGO_HASH_V1, S);
205     if (Type != PGOHash::None)
206       CounterMap[S] = NextCounter++;
207     return Type;
208   }
209 
210   /// The RHS of all logical operators gets a fresh counter in order to count
211   /// how many times the RHS evaluates to true or false, depending on the
212   /// semantics of the operator. This is only valid for ">= v7" of the profile
213   /// version so that we facilitate backward compatibility.
214   bool VisitBinaryOperator(BinaryOperator *S) {
215     if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
216       if (S->isLogicalOp() &&
217           CodeGenFunction::isInstrumentedCondition(S->getRHS()))
218         CounterMap[S->getRHS()] = NextCounter++;
219     return Base::VisitBinaryOperator(S);
220   }
221 
222   /// Include \p S in the function hash.
223   bool VisitStmt(Stmt *S) {
224     auto Type = updateCounterMappings(S);
225     if (Hash.getHashVersion() != PGO_HASH_V1)
226       Type = getHashType(Hash.getHashVersion(), S);
227     if (Type != PGOHash::None)
228       Hash.combine(Type);
229     return true;
230   }
231 
232   bool TraverseIfStmt(IfStmt *If) {
233     // If we used the V1 hash, use the default traversal.
234     if (Hash.getHashVersion() == PGO_HASH_V1)
235       return Base::TraverseIfStmt(If);
236 
237     // Otherwise, keep track of which branch we're in while traversing.
238     VisitStmt(If);
239     for (Stmt *CS : If->children()) {
240       if (!CS)
241         continue;
242       if (CS == If->getThen())
243         Hash.combine(PGOHash::IfThenBranch);
244       else if (CS == If->getElse())
245         Hash.combine(PGOHash::IfElseBranch);
246       TraverseStmt(CS);
247     }
248     Hash.combine(PGOHash::EndOfScope);
249     return true;
250   }
251 
252 // If the statement type \p N is nestable, and its nesting impacts profile
253 // stability, define a custom traversal which tracks the end of the statement
254 // in the hash (provided we're not using the V1 hash).
255 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
256   bool Traverse##N(N *S) {                                                     \
257     Base::Traverse##N(S);                                                      \
258     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
259       Hash.combine(PGOHash::EndOfScope);                                       \
260     return true;                                                               \
261   }
262 
263   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
264   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
265   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
266   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
267   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
268   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
269   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
270 
271   /// Get version \p HashVersion of the PGO hash for \p S.
272   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
273     switch (S->getStmtClass()) {
274     default:
275       break;
276     case Stmt::LabelStmtClass:
277       return PGOHash::LabelStmt;
278     case Stmt::WhileStmtClass:
279       return PGOHash::WhileStmt;
280     case Stmt::DoStmtClass:
281       return PGOHash::DoStmt;
282     case Stmt::ForStmtClass:
283       return PGOHash::ForStmt;
284     case Stmt::CXXForRangeStmtClass:
285       return PGOHash::CXXForRangeStmt;
286     case Stmt::ObjCForCollectionStmtClass:
287       return PGOHash::ObjCForCollectionStmt;
288     case Stmt::SwitchStmtClass:
289       return PGOHash::SwitchStmt;
290     case Stmt::CaseStmtClass:
291       return PGOHash::CaseStmt;
292     case Stmt::DefaultStmtClass:
293       return PGOHash::DefaultStmt;
294     case Stmt::IfStmtClass:
295       return PGOHash::IfStmt;
296     case Stmt::CXXTryStmtClass:
297       return PGOHash::CXXTryStmt;
298     case Stmt::CXXCatchStmtClass:
299       return PGOHash::CXXCatchStmt;
300     case Stmt::ConditionalOperatorClass:
301       return PGOHash::ConditionalOperator;
302     case Stmt::BinaryConditionalOperatorClass:
303       return PGOHash::BinaryConditionalOperator;
304     case Stmt::BinaryOperatorClass: {
305       const BinaryOperator *BO = cast<BinaryOperator>(S);
306       if (BO->getOpcode() == BO_LAnd)
307         return PGOHash::BinaryOperatorLAnd;
308       if (BO->getOpcode() == BO_LOr)
309         return PGOHash::BinaryOperatorLOr;
310       if (HashVersion >= PGO_HASH_V2) {
311         switch (BO->getOpcode()) {
312         default:
313           break;
314         case BO_LT:
315           return PGOHash::BinaryOperatorLT;
316         case BO_GT:
317           return PGOHash::BinaryOperatorGT;
318         case BO_LE:
319           return PGOHash::BinaryOperatorLE;
320         case BO_GE:
321           return PGOHash::BinaryOperatorGE;
322         case BO_EQ:
323           return PGOHash::BinaryOperatorEQ;
324         case BO_NE:
325           return PGOHash::BinaryOperatorNE;
326         }
327       }
328       break;
329     }
330     }
331 
332     if (HashVersion >= PGO_HASH_V2) {
333       switch (S->getStmtClass()) {
334       default:
335         break;
336       case Stmt::GotoStmtClass:
337         return PGOHash::GotoStmt;
338       case Stmt::IndirectGotoStmtClass:
339         return PGOHash::IndirectGotoStmt;
340       case Stmt::BreakStmtClass:
341         return PGOHash::BreakStmt;
342       case Stmt::ContinueStmtClass:
343         return PGOHash::ContinueStmt;
344       case Stmt::ReturnStmtClass:
345         return PGOHash::ReturnStmt;
346       case Stmt::CXXThrowExprClass:
347         return PGOHash::ThrowExpr;
348       case Stmt::UnaryOperatorClass: {
349         const UnaryOperator *UO = cast<UnaryOperator>(S);
350         if (UO->getOpcode() == UO_LNot)
351           return PGOHash::UnaryOperatorLNot;
352         break;
353       }
354       }
355     }
356 
357     return PGOHash::None;
358   }
359 };
360 
361 /// A StmtVisitor that propagates the raw counts through the AST and
362 /// records the count at statements where the value may change.
363 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
364   /// PGO state.
365   CodeGenPGO &PGO;
366 
367   /// A flag that is set when the current count should be recorded on the
368   /// next statement, such as at the exit of a loop.
369   bool RecordNextStmtCount;
370 
371   /// The count at the current location in the traversal.
372   uint64_t CurrentCount;
373 
374   /// The map of statements to count values.
375   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
376 
377   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
378   struct BreakContinue {
379     uint64_t BreakCount;
380     uint64_t ContinueCount;
381     BreakContinue() : BreakCount(0), ContinueCount(0) {}
382   };
383   SmallVector<BreakContinue, 8> BreakContinueStack;
384 
385   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
386                       CodeGenPGO &PGO)
387       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
388 
389   void RecordStmtCount(const Stmt *S) {
390     if (RecordNextStmtCount) {
391       CountMap[S] = CurrentCount;
392       RecordNextStmtCount = false;
393     }
394   }
395 
396   /// Set and return the current count.
397   uint64_t setCount(uint64_t Count) {
398     CurrentCount = Count;
399     return Count;
400   }
401 
402   void VisitStmt(const Stmt *S) {
403     RecordStmtCount(S);
404     for (const Stmt *Child : S->children())
405       if (Child)
406         this->Visit(Child);
407   }
408 
409   void VisitFunctionDecl(const FunctionDecl *D) {
410     // Counter tracks entry to the function body.
411     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
412     CountMap[D->getBody()] = BodyCount;
413     Visit(D->getBody());
414   }
415 
416   // Skip lambda expressions. We visit these as FunctionDecls when we're
417   // generating them and aren't interested in the body when generating a
418   // parent context.
419   void VisitLambdaExpr(const LambdaExpr *LE) {}
420 
421   void VisitCapturedDecl(const CapturedDecl *D) {
422     // Counter tracks entry to the capture body.
423     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
424     CountMap[D->getBody()] = BodyCount;
425     Visit(D->getBody());
426   }
427 
428   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
429     // Counter tracks entry to the method body.
430     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
431     CountMap[D->getBody()] = BodyCount;
432     Visit(D->getBody());
433   }
434 
435   void VisitBlockDecl(const BlockDecl *D) {
436     // Counter tracks entry to the block body.
437     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
438     CountMap[D->getBody()] = BodyCount;
439     Visit(D->getBody());
440   }
441 
442   void VisitReturnStmt(const ReturnStmt *S) {
443     RecordStmtCount(S);
444     if (S->getRetValue())
445       Visit(S->getRetValue());
446     CurrentCount = 0;
447     RecordNextStmtCount = true;
448   }
449 
450   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
451     RecordStmtCount(E);
452     if (E->getSubExpr())
453       Visit(E->getSubExpr());
454     CurrentCount = 0;
455     RecordNextStmtCount = true;
456   }
457 
458   void VisitGotoStmt(const GotoStmt *S) {
459     RecordStmtCount(S);
460     CurrentCount = 0;
461     RecordNextStmtCount = true;
462   }
463 
464   void VisitLabelStmt(const LabelStmt *S) {
465     RecordNextStmtCount = false;
466     // Counter tracks the block following the label.
467     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
468     CountMap[S] = BlockCount;
469     Visit(S->getSubStmt());
470   }
471 
472   void VisitBreakStmt(const BreakStmt *S) {
473     RecordStmtCount(S);
474     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
475     BreakContinueStack.back().BreakCount += CurrentCount;
476     CurrentCount = 0;
477     RecordNextStmtCount = true;
478   }
479 
480   void VisitContinueStmt(const ContinueStmt *S) {
481     RecordStmtCount(S);
482     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
483     BreakContinueStack.back().ContinueCount += CurrentCount;
484     CurrentCount = 0;
485     RecordNextStmtCount = true;
486   }
487 
488   void VisitWhileStmt(const WhileStmt *S) {
489     RecordStmtCount(S);
490     uint64_t ParentCount = CurrentCount;
491 
492     BreakContinueStack.push_back(BreakContinue());
493     // Visit the body region first so the break/continue adjustments can be
494     // included when visiting the condition.
495     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
496     CountMap[S->getBody()] = CurrentCount;
497     Visit(S->getBody());
498     uint64_t BackedgeCount = CurrentCount;
499 
500     // ...then go back and propagate counts through the condition. The count
501     // at the start of the condition is the sum of the incoming edges,
502     // the backedge from the end of the loop body, and the edges from
503     // continue statements.
504     BreakContinue BC = BreakContinueStack.pop_back_val();
505     uint64_t CondCount =
506         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
507     CountMap[S->getCond()] = CondCount;
508     Visit(S->getCond());
509     setCount(BC.BreakCount + CondCount - BodyCount);
510     RecordNextStmtCount = true;
511   }
512 
513   void VisitDoStmt(const DoStmt *S) {
514     RecordStmtCount(S);
515     uint64_t LoopCount = PGO.getRegionCount(S);
516 
517     BreakContinueStack.push_back(BreakContinue());
518     // The count doesn't include the fallthrough from the parent scope. Add it.
519     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
520     CountMap[S->getBody()] = BodyCount;
521     Visit(S->getBody());
522     uint64_t BackedgeCount = CurrentCount;
523 
524     BreakContinue BC = BreakContinueStack.pop_back_val();
525     // The count at the start of the condition is equal to the count at the
526     // end of the body, plus any continues.
527     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
528     CountMap[S->getCond()] = CondCount;
529     Visit(S->getCond());
530     setCount(BC.BreakCount + CondCount - LoopCount);
531     RecordNextStmtCount = true;
532   }
533 
534   void VisitForStmt(const ForStmt *S) {
535     RecordStmtCount(S);
536     if (S->getInit())
537       Visit(S->getInit());
538 
539     uint64_t ParentCount = CurrentCount;
540 
541     BreakContinueStack.push_back(BreakContinue());
542     // Visit the body region first. (This is basically the same as a while
543     // loop; see further comments in VisitWhileStmt.)
544     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
545     CountMap[S->getBody()] = BodyCount;
546     Visit(S->getBody());
547     uint64_t BackedgeCount = CurrentCount;
548     BreakContinue BC = BreakContinueStack.pop_back_val();
549 
550     // The increment is essentially part of the body but it needs to include
551     // the count for all the continue statements.
552     if (S->getInc()) {
553       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
554       CountMap[S->getInc()] = IncCount;
555       Visit(S->getInc());
556     }
557 
558     // ...then go back and propagate counts through the condition.
559     uint64_t CondCount =
560         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
561     if (S->getCond()) {
562       CountMap[S->getCond()] = CondCount;
563       Visit(S->getCond());
564     }
565     setCount(BC.BreakCount + CondCount - BodyCount);
566     RecordNextStmtCount = true;
567   }
568 
569   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
570     RecordStmtCount(S);
571     if (S->getInit())
572       Visit(S->getInit());
573     Visit(S->getLoopVarStmt());
574     Visit(S->getRangeStmt());
575     Visit(S->getBeginStmt());
576     Visit(S->getEndStmt());
577 
578     uint64_t ParentCount = CurrentCount;
579     BreakContinueStack.push_back(BreakContinue());
580     // Visit the body region first. (This is basically the same as a while
581     // loop; see further comments in VisitWhileStmt.)
582     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
583     CountMap[S->getBody()] = BodyCount;
584     Visit(S->getBody());
585     uint64_t BackedgeCount = CurrentCount;
586     BreakContinue BC = BreakContinueStack.pop_back_val();
587 
588     // The increment is essentially part of the body but it needs to include
589     // the count for all the continue statements.
590     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
591     CountMap[S->getInc()] = IncCount;
592     Visit(S->getInc());
593 
594     // ...then go back and propagate counts through the condition.
595     uint64_t CondCount =
596         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
597     CountMap[S->getCond()] = CondCount;
598     Visit(S->getCond());
599     setCount(BC.BreakCount + CondCount - BodyCount);
600     RecordNextStmtCount = true;
601   }
602 
603   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
604     RecordStmtCount(S);
605     Visit(S->getElement());
606     uint64_t ParentCount = CurrentCount;
607     BreakContinueStack.push_back(BreakContinue());
608     // Counter tracks the body of the loop.
609     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
610     CountMap[S->getBody()] = BodyCount;
611     Visit(S->getBody());
612     uint64_t BackedgeCount = CurrentCount;
613     BreakContinue BC = BreakContinueStack.pop_back_val();
614 
615     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
616              BodyCount);
617     RecordNextStmtCount = true;
618   }
619 
620   void VisitSwitchStmt(const SwitchStmt *S) {
621     RecordStmtCount(S);
622     if (S->getInit())
623       Visit(S->getInit());
624     Visit(S->getCond());
625     CurrentCount = 0;
626     BreakContinueStack.push_back(BreakContinue());
627     Visit(S->getBody());
628     // If the switch is inside a loop, add the continue counts.
629     BreakContinue BC = BreakContinueStack.pop_back_val();
630     if (!BreakContinueStack.empty())
631       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
632     // Counter tracks the exit block of the switch.
633     setCount(PGO.getRegionCount(S));
634     RecordNextStmtCount = true;
635   }
636 
637   void VisitSwitchCase(const SwitchCase *S) {
638     RecordNextStmtCount = false;
639     // Counter for this particular case. This counts only jumps from the
640     // switch header and does not include fallthrough from the case before
641     // this one.
642     uint64_t CaseCount = PGO.getRegionCount(S);
643     setCount(CurrentCount + CaseCount);
644     // We need the count without fallthrough in the mapping, so it's more useful
645     // for branch probabilities.
646     CountMap[S] = CaseCount;
647     RecordNextStmtCount = true;
648     Visit(S->getSubStmt());
649   }
650 
651   void VisitIfStmt(const IfStmt *S) {
652     RecordStmtCount(S);
653 
654     if (S->isConsteval()) {
655       const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
656       if (Stm)
657         Visit(Stm);
658       return;
659     }
660 
661     uint64_t ParentCount = CurrentCount;
662     if (S->getInit())
663       Visit(S->getInit());
664     Visit(S->getCond());
665 
666     // Counter tracks the "then" part of an if statement. The count for
667     // the "else" part, if it exists, will be calculated from this counter.
668     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
669     CountMap[S->getThen()] = ThenCount;
670     Visit(S->getThen());
671     uint64_t OutCount = CurrentCount;
672 
673     uint64_t ElseCount = ParentCount - ThenCount;
674     if (S->getElse()) {
675       setCount(ElseCount);
676       CountMap[S->getElse()] = ElseCount;
677       Visit(S->getElse());
678       OutCount += CurrentCount;
679     } else
680       OutCount += ElseCount;
681     setCount(OutCount);
682     RecordNextStmtCount = true;
683   }
684 
685   void VisitCXXTryStmt(const CXXTryStmt *S) {
686     RecordStmtCount(S);
687     Visit(S->getTryBlock());
688     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
689       Visit(S->getHandler(I));
690     // Counter tracks the continuation block of the try statement.
691     setCount(PGO.getRegionCount(S));
692     RecordNextStmtCount = true;
693   }
694 
695   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
696     RecordNextStmtCount = false;
697     // Counter tracks the catch statement's handler block.
698     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
699     CountMap[S] = CatchCount;
700     Visit(S->getHandlerBlock());
701   }
702 
703   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
704     RecordStmtCount(E);
705     uint64_t ParentCount = CurrentCount;
706     Visit(E->getCond());
707 
708     // Counter tracks the "true" part of a conditional operator. The
709     // count in the "false" part will be calculated from this counter.
710     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
711     CountMap[E->getTrueExpr()] = TrueCount;
712     Visit(E->getTrueExpr());
713     uint64_t OutCount = CurrentCount;
714 
715     uint64_t FalseCount = setCount(ParentCount - TrueCount);
716     CountMap[E->getFalseExpr()] = FalseCount;
717     Visit(E->getFalseExpr());
718     OutCount += CurrentCount;
719 
720     setCount(OutCount);
721     RecordNextStmtCount = true;
722   }
723 
724   void VisitBinLAnd(const BinaryOperator *E) {
725     RecordStmtCount(E);
726     uint64_t ParentCount = CurrentCount;
727     Visit(E->getLHS());
728     // Counter tracks the right hand side of a logical and operator.
729     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
730     CountMap[E->getRHS()] = RHSCount;
731     Visit(E->getRHS());
732     setCount(ParentCount + RHSCount - CurrentCount);
733     RecordNextStmtCount = true;
734   }
735 
736   void VisitBinLOr(const BinaryOperator *E) {
737     RecordStmtCount(E);
738     uint64_t ParentCount = CurrentCount;
739     Visit(E->getLHS());
740     // Counter tracks the right hand side of a logical or operator.
741     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
742     CountMap[E->getRHS()] = RHSCount;
743     Visit(E->getRHS());
744     setCount(ParentCount + RHSCount - CurrentCount);
745     RecordNextStmtCount = true;
746   }
747 };
748 } // end anonymous namespace
749 
750 void PGOHash::combine(HashType Type) {
751   // Check that we never combine 0 and only have six bits.
752   assert(Type && "Hash is invalid: unexpected type 0");
753   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
754 
755   // Pass through MD5 if enough work has built up.
756   if (Count && Count % NumTypesPerWord == 0) {
757     using namespace llvm::support;
758     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
759     MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
760     Working = 0;
761   }
762 
763   // Accumulate the current type.
764   ++Count;
765   Working = Working << NumBitsPerType | Type;
766 }
767 
768 uint64_t PGOHash::finalize() {
769   // Use Working as the hash directly if we never used MD5.
770   if (Count <= NumTypesPerWord)
771     // No need to byte swap here, since none of the math was endian-dependent.
772     // This number will be byte-swapped as required on endianness transitions,
773     // so we will see the same value on the other side.
774     return Working;
775 
776   // Check for remaining work in Working.
777   if (Working) {
778     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
779     // is buggy because it converts a uint64_t into an array of uint8_t.
780     if (HashVersion < PGO_HASH_V3) {
781       MD5.update({(uint8_t)Working});
782     } else {
783       using namespace llvm::support;
784       uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
785       MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
786     }
787   }
788 
789   // Finalize the MD5 and return the hash.
790   llvm::MD5::MD5Result Result;
791   MD5.final(Result);
792   return Result.low();
793 }
794 
795 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
796   const Decl *D = GD.getDecl();
797   if (!D->hasBody())
798     return;
799 
800   // Skip CUDA/HIP kernel launch stub functions.
801   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
802       D->hasAttr<CUDAGlobalAttr>())
803     return;
804 
805   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
806   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
807   if (!InstrumentRegions && !PGOReader)
808     return;
809   if (D->isImplicit())
810     return;
811   // Constructors and destructors may be represented by several functions in IR.
812   // If so, instrument only base variant, others are implemented by delegation
813   // to the base one, it would be counted twice otherwise.
814   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
815     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
816       if (GD.getCtorType() != Ctor_Base &&
817           CodeGenFunction::IsConstructorDelegationValid(CCD))
818         return;
819   }
820   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
821     return;
822 
823   CGM.ClearUnusedCoverageMapping(D);
824   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
825     return;
826   if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
827     return;
828 
829   setFuncName(Fn);
830 
831   mapRegionCounters(D);
832   if (CGM.getCodeGenOpts().CoverageMapping)
833     emitCounterRegionMapping(D);
834   if (PGOReader) {
835     SourceManager &SM = CGM.getContext().getSourceManager();
836     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
837     computeRegionCounts(D);
838     applyFunctionAttributes(PGOReader, Fn);
839   }
840 }
841 
842 void CodeGenPGO::mapRegionCounters(const Decl *D) {
843   // Use the latest hash version when inserting instrumentation, but use the
844   // version in the indexed profile if we're reading PGO data.
845   PGOHashVersion HashVersion = PGO_HASH_LATEST;
846   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
847   if (auto *PGOReader = CGM.getPGOReader()) {
848     HashVersion = getPGOHashVersion(PGOReader, CGM);
849     ProfileVersion = PGOReader->getVersion();
850   }
851 
852   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
853   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
854   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
855     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
856   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
857     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
858   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
859     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
860   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
861     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
862   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
863   NumRegionCounters = Walker.NextCounter;
864   FunctionHash = Walker.Hash.finalize();
865 }
866 
867 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
868   if (!D->getBody())
869     return true;
870 
871   // Skip host-only functions in the CUDA device compilation and device-only
872   // functions in the host compilation. Just roughly filter them out based on
873   // the function attributes. If there are effectively host-only or device-only
874   // ones, their coverage mapping may still be generated.
875   if (CGM.getLangOpts().CUDA &&
876       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
877         !D->hasAttr<CUDAGlobalAttr>()) ||
878        (!CGM.getLangOpts().CUDAIsDevice &&
879         (D->hasAttr<CUDAGlobalAttr>() ||
880          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
881     return true;
882 
883   // Don't map the functions in system headers.
884   const auto &SM = CGM.getContext().getSourceManager();
885   auto Loc = D->getBody()->getBeginLoc();
886   return SM.isInSystemHeader(Loc);
887 }
888 
889 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
890   if (skipRegionMappingForDecl(D))
891     return;
892 
893   std::string CoverageMapping;
894   llvm::raw_string_ostream OS(CoverageMapping);
895   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
896                                 CGM.getContext().getSourceManager(),
897                                 CGM.getLangOpts(), RegionCounterMap.get());
898   MappingGen.emitCounterMapping(D, OS);
899   OS.flush();
900 
901   if (CoverageMapping.empty())
902     return;
903 
904   CGM.getCoverageMapping()->addFunctionMappingRecord(
905       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
906 }
907 
908 void
909 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
910                                     llvm::GlobalValue::LinkageTypes Linkage) {
911   if (skipRegionMappingForDecl(D))
912     return;
913 
914   std::string CoverageMapping;
915   llvm::raw_string_ostream OS(CoverageMapping);
916   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
917                                 CGM.getContext().getSourceManager(),
918                                 CGM.getLangOpts());
919   MappingGen.emitEmptyMapping(D, OS);
920   OS.flush();
921 
922   if (CoverageMapping.empty())
923     return;
924 
925   setFuncName(Name, Linkage);
926   CGM.getCoverageMapping()->addFunctionMappingRecord(
927       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
928 }
929 
930 void CodeGenPGO::computeRegionCounts(const Decl *D) {
931   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
932   ComputeRegionCounts Walker(*StmtCountMap, *this);
933   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
934     Walker.VisitFunctionDecl(FD);
935   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
936     Walker.VisitObjCMethodDecl(MD);
937   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
938     Walker.VisitBlockDecl(BD);
939   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
940     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
941 }
942 
943 void
944 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
945                                     llvm::Function *Fn) {
946   if (!haveRegionCounts())
947     return;
948 
949   uint64_t FunctionCount = getRegionCount(nullptr);
950   Fn->setEntryCount(FunctionCount);
951 }
952 
953 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
954                                       llvm::Value *StepV) {
955   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
956     return;
957   if (!Builder.GetInsertBlock())
958     return;
959 
960   unsigned Counter = (*RegionCounterMap)[S];
961   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
962 
963   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
964                          Builder.getInt64(FunctionHash),
965                          Builder.getInt32(NumRegionCounters),
966                          Builder.getInt32(Counter), StepV};
967   if (!StepV)
968     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
969                        ArrayRef(Args, 4));
970   else
971     Builder.CreateCall(
972         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
973         ArrayRef(Args));
974 }
975 
976 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
977   if (CGM.getCodeGenOpts().hasProfileClangInstr())
978     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
979                     uint32_t(EnableValueProfiling));
980 }
981 
// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
//
// Parameters:
//   Builder   - IR builder; its insertion point is saved and restored around
//               the instrumentation, so callers see it unchanged.
//   ValueKind - value-profiling kind; also indexes NumValueSites, the
//               per-kind running count of value sites in this function.
//   ValueSite - instruction the value profile is attached to; the
//               instrumentation call is emitted at this instruction.
//   ValuePtr  - the value being profiled; constants are ignored.
//
// Everything is a no-op unless -enable-value-profiling is set.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
    llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {

  if (!EnableValueProfiling)
    return;

  // Need a value, a site, and a builder that is positioned in a block.
  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  // Constant values are never profiled.
  if (isa<llvm::Constant>(ValuePtr))
    return;

  // Instrumentation path: only taken when counters were mapped for this
  // function.
  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
    // Operands of llvm.instrprof.value.profile. The last operand
    // post-increments NumValueSites[ValueKind], so consecutive sites of one
    // kind get consecutive indices.
    llvm::Value *Args[5] = {
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  // Profile-use path: annotate the site from the loaded profile record.
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // We record the top most called three functions at each call site.
    // Profile metadata contains "VP" string identifying this metadata
    // as value profiling data, then a uint32_t value for the value profiling
    // kind, a uint64_t value for the total number of times the call is
    // executed, followed by the function hash and execution count (uint64_t)
    // pairs for each function.
    // (Format details above are produced by llvm::annotateValueSite; confirm
    // against its implementation if they matter.)
    //
    // Sites beyond what the profile record knows about for this kind are
    // left unannotated, and the site index is not advanced for them.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    // Advance to the next site index for this kind.
    NumValueSites[ValueKind]++;
  }
}
1031 
1032 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1033                                   bool IsInMainFile) {
1034   CGM.getPGOStats().addVisited(IsInMainFile);
1035   RegionCounts.clear();
1036   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1037       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1038   if (auto E = RecordExpected.takeError()) {
1039     auto IPE = llvm::InstrProfError::take(std::move(E));
1040     if (IPE == llvm::instrprof_error::unknown_function)
1041       CGM.getPGOStats().addMissing(IsInMainFile);
1042     else if (IPE == llvm::instrprof_error::hash_mismatch)
1043       CGM.getPGOStats().addMismatched(IsInMainFile);
1044     else if (IPE == llvm::instrprof_error::malformed)
1045       // TODO: Consider a more specific warning for this case.
1046       CGM.getPGOStats().addMismatched(IsInMainFile);
1047     return;
1048   }
1049   ProfRecord =
1050       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1051   RegionCounts = ProfRecord->Counts;
1052 }
1053 
/// Compute the divisor used to bring branch weights into 32-bit range.
///
/// A maximum below UINT32_MAX needs no scaling, so the divisor is 1.
/// Otherwise the divisor is chosen so that \p MaxWeight divided by it is
/// strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1061 
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight (typically the maximum of all weights being scaled);
/// otherwise the scaled result may not fit in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // Safe narrowing: the assertion above guarantees the value fits.
  return static_cast<uint32_t>(Scaled);
}
1077 
1078 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1079                                                     uint64_t FalseCount) const {
1080   // Check for empty weights.
1081   if (!TrueCount && !FalseCount)
1082     return nullptr;
1083 
1084   // Calculate how to scale down to 32-bits.
1085   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1086 
1087   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1088   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1089                                       scaleBranchWeight(FalseCount, Scale));
1090 }
1091 
1092 llvm::MDNode *
1093 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1094   // We need at least two elements to create meaningful weights.
1095   if (Weights.size() < 2)
1096     return nullptr;
1097 
1098   // Check for empty weights.
1099   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1100   if (MaxWeight == 0)
1101     return nullptr;
1102 
1103   // Calculate how to scale down to 32-bits.
1104   uint64_t Scale = calculateWeightScale(MaxWeight);
1105 
1106   SmallVector<uint32_t, 16> ScaledWeights;
1107   ScaledWeights.reserve(Weights.size());
1108   for (uint64_t W : Weights)
1109     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1110 
1111   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1112   return MDHelper.createBranchWeights(ScaledWeights);
1113 }
1114 
1115 llvm::MDNode *
1116 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1117                                              uint64_t LoopCount) const {
1118   if (!PGO.haveRegionCounts())
1119     return nullptr;
1120   std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1121   if (!CondCount || *CondCount == 0)
1122     return nullptr;
1123   return createProfileWeights(LoopCount,
1124                               std::max(*CondCount, LoopCount) - LoopCount);
1125 }
1126