xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision 8ddb146abcdf061be9f2c0db7e391697dafad85c)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 
25 static llvm::cl::opt<bool>
26     EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27                          llvm::cl::desc("Enable value profiling"),
28                          llvm::cl::Hidden, llvm::cl::init(false));
29 
30 using namespace clang;
31 using namespace CodeGen;
32 
33 void CodeGenPGO::setFuncName(StringRef Name,
34                              llvm::GlobalValue::LinkageTypes Linkage) {
35   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36   FuncName = llvm::getPGOFuncName(
37       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39 
40   // If we're generating a profile, create a variable for the name.
41   if (CGM.getCodeGenOpts().hasProfileClangInstr())
42     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43 }
44 
45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46   setFuncName(Fn->getName(), Fn->getLinkage());
47   // Create PGOFuncName meta data.
48   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49 }
50 
/// Identifies which revision of the PGO hash algorithm is in use.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Alias for the newest version; update it whenever a version is added.
  PGO_HASH_LATEST = PGO_HASH_V3
};
60 
61 namespace {
62 /// Stable hasher for PGO region counters.
63 ///
64 /// PGOHash produces a stable hash of a given function's control flow.
65 ///
66 /// Changing the output of this hash will invalidate all previously generated
67 /// profiles -- i.e., don't do it.
68 ///
69 /// \note  When this hash does eventually change (years?), we still need to
70 /// support old hashes.  We'll need to pull in the version number from the
71 /// profile data format and use the matching hash function.
72 class PGOHash {
73   uint64_t Working;
74   unsigned Count;
75   PGOHashVersion HashVersion;
76   llvm::MD5 MD5;
77 
78   static const int NumBitsPerType = 6;
79   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
80   static const unsigned TooBig = 1u << NumBitsPerType;
81 
82 public:
83   /// Hash values for AST nodes.
84   ///
85   /// Distinct values for AST nodes that have region counters attached.
86   ///
87   /// These values must be stable.  All new members must be added at the end,
88   /// and no members should be removed.  Changing the enumeration value for an
89   /// AST node will affect the hash of every function that contains that node.
90   enum HashType : unsigned char {
91     None = 0,
92     LabelStmt = 1,
93     WhileStmt,
94     DoStmt,
95     ForStmt,
96     CXXForRangeStmt,
97     ObjCForCollectionStmt,
98     SwitchStmt,
99     CaseStmt,
100     DefaultStmt,
101     IfStmt,
102     CXXTryStmt,
103     CXXCatchStmt,
104     ConditionalOperator,
105     BinaryOperatorLAnd,
106     BinaryOperatorLOr,
107     BinaryConditionalOperator,
108     // The preceding values are available with PGO_HASH_V1.
109 
110     EndOfScope,
111     IfThenBranch,
112     IfElseBranch,
113     GotoStmt,
114     IndirectGotoStmt,
115     BreakStmt,
116     ContinueStmt,
117     ReturnStmt,
118     ThrowExpr,
119     UnaryOperatorLNot,
120     BinaryOperatorLT,
121     BinaryOperatorGT,
122     BinaryOperatorLE,
123     BinaryOperatorGE,
124     BinaryOperatorEQ,
125     BinaryOperatorNE,
126     // The preceding values are available since PGO_HASH_V2.
127 
128     // Keep this last.  It's for the static assert that follows.
129     LastHashType
130   };
131   static_assert(LastHashType <= TooBig, "Too many types in HashType");
132 
133   PGOHash(PGOHashVersion HashVersion)
134       : Working(0), Count(0), HashVersion(HashVersion) {}
135   void combine(HashType Type);
136   uint64_t finalize();
137   PGOHashVersion getHashVersion() const { return HashVersion; }
138 };
139 const int PGOHash::NumBitsPerType;
140 const unsigned PGOHash::NumTypesPerWord;
141 const unsigned PGOHash::TooBig;
142 
143 /// Get the PGO hash version used in the given indexed profile.
144 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
145                                         CodeGenModule &CGM) {
146   if (PGOReader->getVersion() <= 4)
147     return PGO_HASH_V1;
148   if (PGOReader->getVersion() <= 5)
149     return PGO_HASH_V2;
150   return PGO_HASH_V3;
151 }
152 
153 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
154 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
155   using Base = RecursiveASTVisitor<MapRegionCounters>;
156 
157   /// The next counter value to assign.
158   unsigned NextCounter;
159   /// The function hash.
160   PGOHash Hash;
161   /// The map of statements to counters.
162   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
163   /// The profile version.
164   uint64_t ProfileVersion;
165 
166   MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
167                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
168       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
169         ProfileVersion(ProfileVersion) {}
170 
171   // Blocks and lambdas are handled as separate functions, so we need not
172   // traverse them in the parent context.
173   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
174   bool TraverseLambdaExpr(LambdaExpr *LE) {
175     // Traverse the captures, but not the body.
176     for (auto C : zip(LE->captures(), LE->capture_inits()))
177       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
178     return true;
179   }
180   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
181 
182   bool VisitDecl(const Decl *D) {
183     switch (D->getKind()) {
184     default:
185       break;
186     case Decl::Function:
187     case Decl::CXXMethod:
188     case Decl::CXXConstructor:
189     case Decl::CXXDestructor:
190     case Decl::CXXConversion:
191     case Decl::ObjCMethod:
192     case Decl::Block:
193     case Decl::Captured:
194       CounterMap[D->getBody()] = NextCounter++;
195       break;
196     }
197     return true;
198   }
199 
200   /// If \p S gets a fresh counter, update the counter mappings. Return the
201   /// V1 hash of \p S.
202   PGOHash::HashType updateCounterMappings(Stmt *S) {
203     auto Type = getHashType(PGO_HASH_V1, S);
204     if (Type != PGOHash::None)
205       CounterMap[S] = NextCounter++;
206     return Type;
207   }
208 
209   /// The RHS of all logical operators gets a fresh counter in order to count
210   /// how many times the RHS evaluates to true or false, depending on the
211   /// semantics of the operator. This is only valid for ">= v7" of the profile
212   /// version so that we facilitate backward compatibility.
213   bool VisitBinaryOperator(BinaryOperator *S) {
214     if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
215       if (S->isLogicalOp() &&
216           CodeGenFunction::isInstrumentedCondition(S->getRHS()))
217         CounterMap[S->getRHS()] = NextCounter++;
218     return Base::VisitBinaryOperator(S);
219   }
220 
221   /// Include \p S in the function hash.
222   bool VisitStmt(Stmt *S) {
223     auto Type = updateCounterMappings(S);
224     if (Hash.getHashVersion() != PGO_HASH_V1)
225       Type = getHashType(Hash.getHashVersion(), S);
226     if (Type != PGOHash::None)
227       Hash.combine(Type);
228     return true;
229   }
230 
231   bool TraverseIfStmt(IfStmt *If) {
232     // If we used the V1 hash, use the default traversal.
233     if (Hash.getHashVersion() == PGO_HASH_V1)
234       return Base::TraverseIfStmt(If);
235 
236     // Otherwise, keep track of which branch we're in while traversing.
237     VisitStmt(If);
238     for (Stmt *CS : If->children()) {
239       if (!CS)
240         continue;
241       if (CS == If->getThen())
242         Hash.combine(PGOHash::IfThenBranch);
243       else if (CS == If->getElse())
244         Hash.combine(PGOHash::IfElseBranch);
245       TraverseStmt(CS);
246     }
247     Hash.combine(PGOHash::EndOfScope);
248     return true;
249   }
250 
251 // If the statement type \p N is nestable, and its nesting impacts profile
252 // stability, define a custom traversal which tracks the end of the statement
253 // in the hash (provided we're not using the V1 hash).
254 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
255   bool Traverse##N(N *S) {                                                     \
256     Base::Traverse##N(S);                                                      \
257     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
258       Hash.combine(PGOHash::EndOfScope);                                       \
259     return true;                                                               \
260   }
261 
262   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
263   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
264   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
265   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
266   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
267   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
268   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
269 
270   /// Get version \p HashVersion of the PGO hash for \p S.
271   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
272     switch (S->getStmtClass()) {
273     default:
274       break;
275     case Stmt::LabelStmtClass:
276       return PGOHash::LabelStmt;
277     case Stmt::WhileStmtClass:
278       return PGOHash::WhileStmt;
279     case Stmt::DoStmtClass:
280       return PGOHash::DoStmt;
281     case Stmt::ForStmtClass:
282       return PGOHash::ForStmt;
283     case Stmt::CXXForRangeStmtClass:
284       return PGOHash::CXXForRangeStmt;
285     case Stmt::ObjCForCollectionStmtClass:
286       return PGOHash::ObjCForCollectionStmt;
287     case Stmt::SwitchStmtClass:
288       return PGOHash::SwitchStmt;
289     case Stmt::CaseStmtClass:
290       return PGOHash::CaseStmt;
291     case Stmt::DefaultStmtClass:
292       return PGOHash::DefaultStmt;
293     case Stmt::IfStmtClass:
294       return PGOHash::IfStmt;
295     case Stmt::CXXTryStmtClass:
296       return PGOHash::CXXTryStmt;
297     case Stmt::CXXCatchStmtClass:
298       return PGOHash::CXXCatchStmt;
299     case Stmt::ConditionalOperatorClass:
300       return PGOHash::ConditionalOperator;
301     case Stmt::BinaryConditionalOperatorClass:
302       return PGOHash::BinaryConditionalOperator;
303     case Stmt::BinaryOperatorClass: {
304       const BinaryOperator *BO = cast<BinaryOperator>(S);
305       if (BO->getOpcode() == BO_LAnd)
306         return PGOHash::BinaryOperatorLAnd;
307       if (BO->getOpcode() == BO_LOr)
308         return PGOHash::BinaryOperatorLOr;
309       if (HashVersion >= PGO_HASH_V2) {
310         switch (BO->getOpcode()) {
311         default:
312           break;
313         case BO_LT:
314           return PGOHash::BinaryOperatorLT;
315         case BO_GT:
316           return PGOHash::BinaryOperatorGT;
317         case BO_LE:
318           return PGOHash::BinaryOperatorLE;
319         case BO_GE:
320           return PGOHash::BinaryOperatorGE;
321         case BO_EQ:
322           return PGOHash::BinaryOperatorEQ;
323         case BO_NE:
324           return PGOHash::BinaryOperatorNE;
325         }
326       }
327       break;
328     }
329     }
330 
331     if (HashVersion >= PGO_HASH_V2) {
332       switch (S->getStmtClass()) {
333       default:
334         break;
335       case Stmt::GotoStmtClass:
336         return PGOHash::GotoStmt;
337       case Stmt::IndirectGotoStmtClass:
338         return PGOHash::IndirectGotoStmt;
339       case Stmt::BreakStmtClass:
340         return PGOHash::BreakStmt;
341       case Stmt::ContinueStmtClass:
342         return PGOHash::ContinueStmt;
343       case Stmt::ReturnStmtClass:
344         return PGOHash::ReturnStmt;
345       case Stmt::CXXThrowExprClass:
346         return PGOHash::ThrowExpr;
347       case Stmt::UnaryOperatorClass: {
348         const UnaryOperator *UO = cast<UnaryOperator>(S);
349         if (UO->getOpcode() == UO_LNot)
350           return PGOHash::UnaryOperatorLNot;
351         break;
352       }
353       }
354     }
355 
356     return PGOHash::None;
357   }
358 };
359 
360 /// A StmtVisitor that propagates the raw counts through the AST and
361 /// records the count at statements where the value may change.
362 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
363   /// PGO state.
364   CodeGenPGO &PGO;
365 
366   /// A flag that is set when the current count should be recorded on the
367   /// next statement, such as at the exit of a loop.
368   bool RecordNextStmtCount;
369 
370   /// The count at the current location in the traversal.
371   uint64_t CurrentCount;
372 
373   /// The map of statements to count values.
374   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
375 
376   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
377   struct BreakContinue {
378     uint64_t BreakCount;
379     uint64_t ContinueCount;
380     BreakContinue() : BreakCount(0), ContinueCount(0) {}
381   };
382   SmallVector<BreakContinue, 8> BreakContinueStack;
383 
384   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
385                       CodeGenPGO &PGO)
386       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
387 
388   void RecordStmtCount(const Stmt *S) {
389     if (RecordNextStmtCount) {
390       CountMap[S] = CurrentCount;
391       RecordNextStmtCount = false;
392     }
393   }
394 
395   /// Set and return the current count.
396   uint64_t setCount(uint64_t Count) {
397     CurrentCount = Count;
398     return Count;
399   }
400 
401   void VisitStmt(const Stmt *S) {
402     RecordStmtCount(S);
403     for (const Stmt *Child : S->children())
404       if (Child)
405         this->Visit(Child);
406   }
407 
408   void VisitFunctionDecl(const FunctionDecl *D) {
409     // Counter tracks entry to the function body.
410     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
411     CountMap[D->getBody()] = BodyCount;
412     Visit(D->getBody());
413   }
414 
415   // Skip lambda expressions. We visit these as FunctionDecls when we're
416   // generating them and aren't interested in the body when generating a
417   // parent context.
418   void VisitLambdaExpr(const LambdaExpr *LE) {}
419 
420   void VisitCapturedDecl(const CapturedDecl *D) {
421     // Counter tracks entry to the capture body.
422     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
423     CountMap[D->getBody()] = BodyCount;
424     Visit(D->getBody());
425   }
426 
427   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
428     // Counter tracks entry to the method body.
429     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
430     CountMap[D->getBody()] = BodyCount;
431     Visit(D->getBody());
432   }
433 
434   void VisitBlockDecl(const BlockDecl *D) {
435     // Counter tracks entry to the block body.
436     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
437     CountMap[D->getBody()] = BodyCount;
438     Visit(D->getBody());
439   }
440 
441   void VisitReturnStmt(const ReturnStmt *S) {
442     RecordStmtCount(S);
443     if (S->getRetValue())
444       Visit(S->getRetValue());
445     CurrentCount = 0;
446     RecordNextStmtCount = true;
447   }
448 
449   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
450     RecordStmtCount(E);
451     if (E->getSubExpr())
452       Visit(E->getSubExpr());
453     CurrentCount = 0;
454     RecordNextStmtCount = true;
455   }
456 
457   void VisitGotoStmt(const GotoStmt *S) {
458     RecordStmtCount(S);
459     CurrentCount = 0;
460     RecordNextStmtCount = true;
461   }
462 
463   void VisitLabelStmt(const LabelStmt *S) {
464     RecordNextStmtCount = false;
465     // Counter tracks the block following the label.
466     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
467     CountMap[S] = BlockCount;
468     Visit(S->getSubStmt());
469   }
470 
471   void VisitBreakStmt(const BreakStmt *S) {
472     RecordStmtCount(S);
473     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
474     BreakContinueStack.back().BreakCount += CurrentCount;
475     CurrentCount = 0;
476     RecordNextStmtCount = true;
477   }
478 
479   void VisitContinueStmt(const ContinueStmt *S) {
480     RecordStmtCount(S);
481     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
482     BreakContinueStack.back().ContinueCount += CurrentCount;
483     CurrentCount = 0;
484     RecordNextStmtCount = true;
485   }
486 
487   void VisitWhileStmt(const WhileStmt *S) {
488     RecordStmtCount(S);
489     uint64_t ParentCount = CurrentCount;
490 
491     BreakContinueStack.push_back(BreakContinue());
492     // Visit the body region first so the break/continue adjustments can be
493     // included when visiting the condition.
494     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
495     CountMap[S->getBody()] = CurrentCount;
496     Visit(S->getBody());
497     uint64_t BackedgeCount = CurrentCount;
498 
499     // ...then go back and propagate counts through the condition. The count
500     // at the start of the condition is the sum of the incoming edges,
501     // the backedge from the end of the loop body, and the edges from
502     // continue statements.
503     BreakContinue BC = BreakContinueStack.pop_back_val();
504     uint64_t CondCount =
505         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
506     CountMap[S->getCond()] = CondCount;
507     Visit(S->getCond());
508     setCount(BC.BreakCount + CondCount - BodyCount);
509     RecordNextStmtCount = true;
510   }
511 
512   void VisitDoStmt(const DoStmt *S) {
513     RecordStmtCount(S);
514     uint64_t LoopCount = PGO.getRegionCount(S);
515 
516     BreakContinueStack.push_back(BreakContinue());
517     // The count doesn't include the fallthrough from the parent scope. Add it.
518     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
519     CountMap[S->getBody()] = BodyCount;
520     Visit(S->getBody());
521     uint64_t BackedgeCount = CurrentCount;
522 
523     BreakContinue BC = BreakContinueStack.pop_back_val();
524     // The count at the start of the condition is equal to the count at the
525     // end of the body, plus any continues.
526     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
527     CountMap[S->getCond()] = CondCount;
528     Visit(S->getCond());
529     setCount(BC.BreakCount + CondCount - LoopCount);
530     RecordNextStmtCount = true;
531   }
532 
533   void VisitForStmt(const ForStmt *S) {
534     RecordStmtCount(S);
535     if (S->getInit())
536       Visit(S->getInit());
537 
538     uint64_t ParentCount = CurrentCount;
539 
540     BreakContinueStack.push_back(BreakContinue());
541     // Visit the body region first. (This is basically the same as a while
542     // loop; see further comments in VisitWhileStmt.)
543     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
544     CountMap[S->getBody()] = BodyCount;
545     Visit(S->getBody());
546     uint64_t BackedgeCount = CurrentCount;
547     BreakContinue BC = BreakContinueStack.pop_back_val();
548 
549     // The increment is essentially part of the body but it needs to include
550     // the count for all the continue statements.
551     if (S->getInc()) {
552       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
553       CountMap[S->getInc()] = IncCount;
554       Visit(S->getInc());
555     }
556 
557     // ...then go back and propagate counts through the condition.
558     uint64_t CondCount =
559         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
560     if (S->getCond()) {
561       CountMap[S->getCond()] = CondCount;
562       Visit(S->getCond());
563     }
564     setCount(BC.BreakCount + CondCount - BodyCount);
565     RecordNextStmtCount = true;
566   }
567 
568   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
569     RecordStmtCount(S);
570     if (S->getInit())
571       Visit(S->getInit());
572     Visit(S->getLoopVarStmt());
573     Visit(S->getRangeStmt());
574     Visit(S->getBeginStmt());
575     Visit(S->getEndStmt());
576 
577     uint64_t ParentCount = CurrentCount;
578     BreakContinueStack.push_back(BreakContinue());
579     // Visit the body region first. (This is basically the same as a while
580     // loop; see further comments in VisitWhileStmt.)
581     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
582     CountMap[S->getBody()] = BodyCount;
583     Visit(S->getBody());
584     uint64_t BackedgeCount = CurrentCount;
585     BreakContinue BC = BreakContinueStack.pop_back_val();
586 
587     // The increment is essentially part of the body but it needs to include
588     // the count for all the continue statements.
589     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
590     CountMap[S->getInc()] = IncCount;
591     Visit(S->getInc());
592 
593     // ...then go back and propagate counts through the condition.
594     uint64_t CondCount =
595         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
596     CountMap[S->getCond()] = CondCount;
597     Visit(S->getCond());
598     setCount(BC.BreakCount + CondCount - BodyCount);
599     RecordNextStmtCount = true;
600   }
601 
602   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
603     RecordStmtCount(S);
604     Visit(S->getElement());
605     uint64_t ParentCount = CurrentCount;
606     BreakContinueStack.push_back(BreakContinue());
607     // Counter tracks the body of the loop.
608     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
609     CountMap[S->getBody()] = BodyCount;
610     Visit(S->getBody());
611     uint64_t BackedgeCount = CurrentCount;
612     BreakContinue BC = BreakContinueStack.pop_back_val();
613 
614     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
615              BodyCount);
616     RecordNextStmtCount = true;
617   }
618 
619   void VisitSwitchStmt(const SwitchStmt *S) {
620     RecordStmtCount(S);
621     if (S->getInit())
622       Visit(S->getInit());
623     Visit(S->getCond());
624     CurrentCount = 0;
625     BreakContinueStack.push_back(BreakContinue());
626     Visit(S->getBody());
627     // If the switch is inside a loop, add the continue counts.
628     BreakContinue BC = BreakContinueStack.pop_back_val();
629     if (!BreakContinueStack.empty())
630       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
631     // Counter tracks the exit block of the switch.
632     setCount(PGO.getRegionCount(S));
633     RecordNextStmtCount = true;
634   }
635 
636   void VisitSwitchCase(const SwitchCase *S) {
637     RecordNextStmtCount = false;
638     // Counter for this particular case. This counts only jumps from the
639     // switch header and does not include fallthrough from the case before
640     // this one.
641     uint64_t CaseCount = PGO.getRegionCount(S);
642     setCount(CurrentCount + CaseCount);
643     // We need the count without fallthrough in the mapping, so it's more useful
644     // for branch probabilities.
645     CountMap[S] = CaseCount;
646     RecordNextStmtCount = true;
647     Visit(S->getSubStmt());
648   }
649 
650   void VisitIfStmt(const IfStmt *S) {
651     RecordStmtCount(S);
652 
653     if (S->isConsteval()) {
654       const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
655       if (Stm)
656         Visit(Stm);
657       return;
658     }
659 
660     uint64_t ParentCount = CurrentCount;
661     if (S->getInit())
662       Visit(S->getInit());
663     Visit(S->getCond());
664 
665     // Counter tracks the "then" part of an if statement. The count for
666     // the "else" part, if it exists, will be calculated from this counter.
667     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
668     CountMap[S->getThen()] = ThenCount;
669     Visit(S->getThen());
670     uint64_t OutCount = CurrentCount;
671 
672     uint64_t ElseCount = ParentCount - ThenCount;
673     if (S->getElse()) {
674       setCount(ElseCount);
675       CountMap[S->getElse()] = ElseCount;
676       Visit(S->getElse());
677       OutCount += CurrentCount;
678     } else
679       OutCount += ElseCount;
680     setCount(OutCount);
681     RecordNextStmtCount = true;
682   }
683 
684   void VisitCXXTryStmt(const CXXTryStmt *S) {
685     RecordStmtCount(S);
686     Visit(S->getTryBlock());
687     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
688       Visit(S->getHandler(I));
689     // Counter tracks the continuation block of the try statement.
690     setCount(PGO.getRegionCount(S));
691     RecordNextStmtCount = true;
692   }
693 
694   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
695     RecordNextStmtCount = false;
696     // Counter tracks the catch statement's handler block.
697     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
698     CountMap[S] = CatchCount;
699     Visit(S->getHandlerBlock());
700   }
701 
702   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
703     RecordStmtCount(E);
704     uint64_t ParentCount = CurrentCount;
705     Visit(E->getCond());
706 
707     // Counter tracks the "true" part of a conditional operator. The
708     // count in the "false" part will be calculated from this counter.
709     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
710     CountMap[E->getTrueExpr()] = TrueCount;
711     Visit(E->getTrueExpr());
712     uint64_t OutCount = CurrentCount;
713 
714     uint64_t FalseCount = setCount(ParentCount - TrueCount);
715     CountMap[E->getFalseExpr()] = FalseCount;
716     Visit(E->getFalseExpr());
717     OutCount += CurrentCount;
718 
719     setCount(OutCount);
720     RecordNextStmtCount = true;
721   }
722 
723   void VisitBinLAnd(const BinaryOperator *E) {
724     RecordStmtCount(E);
725     uint64_t ParentCount = CurrentCount;
726     Visit(E->getLHS());
727     // Counter tracks the right hand side of a logical and operator.
728     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
729     CountMap[E->getRHS()] = RHSCount;
730     Visit(E->getRHS());
731     setCount(ParentCount + RHSCount - CurrentCount);
732     RecordNextStmtCount = true;
733   }
734 
735   void VisitBinLOr(const BinaryOperator *E) {
736     RecordStmtCount(E);
737     uint64_t ParentCount = CurrentCount;
738     Visit(E->getLHS());
739     // Counter tracks the right hand side of a logical or operator.
740     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
741     CountMap[E->getRHS()] = RHSCount;
742     Visit(E->getRHS());
743     setCount(ParentCount + RHSCount - CurrentCount);
744     RecordNextStmtCount = true;
745   }
746 };
747 } // end anonymous namespace
748 
749 void PGOHash::combine(HashType Type) {
750   // Check that we never combine 0 and only have six bits.
751   assert(Type && "Hash is invalid: unexpected type 0");
752   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
753 
754   // Pass through MD5 if enough work has built up.
755   if (Count && Count % NumTypesPerWord == 0) {
756     using namespace llvm::support;
757     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
758     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
759     Working = 0;
760   }
761 
762   // Accumulate the current type.
763   ++Count;
764   Working = Working << NumBitsPerType | Type;
765 }
766 
767 uint64_t PGOHash::finalize() {
768   // Use Working as the hash directly if we never used MD5.
769   if (Count <= NumTypesPerWord)
770     // No need to byte swap here, since none of the math was endian-dependent.
771     // This number will be byte-swapped as required on endianness transitions,
772     // so we will see the same value on the other side.
773     return Working;
774 
775   // Check for remaining work in Working.
776   if (Working) {
777     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
778     // is buggy because it converts a uint64_t into an array of uint8_t.
779     if (HashVersion < PGO_HASH_V3) {
780       MD5.update({(uint8_t)Working});
781     } else {
782       using namespace llvm::support;
783       uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
784       MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
785     }
786   }
787 
788   // Finalize the MD5 and return the hash.
789   llvm::MD5::MD5Result Result;
790   MD5.final(Result);
791   return Result.low();
792 }
793 
794 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
795   const Decl *D = GD.getDecl();
796   if (!D->hasBody())
797     return;
798 
799   // Skip CUDA/HIP kernel launch stub functions.
800   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
801       D->hasAttr<CUDAGlobalAttr>())
802     return;
803 
804   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
805   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
806   if (!InstrumentRegions && !PGOReader)
807     return;
808   if (D->isImplicit())
809     return;
810   // Constructors and destructors may be represented by several functions in IR.
811   // If so, instrument only base variant, others are implemented by delegation
812   // to the base one, it would be counted twice otherwise.
813   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
814     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
815       if (GD.getCtorType() != Ctor_Base &&
816           CodeGenFunction::IsConstructorDelegationValid(CCD))
817         return;
818   }
819   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
820     return;
821 
822   CGM.ClearUnusedCoverageMapping(D);
823   if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
824     return;
825 
826   setFuncName(Fn);
827 
828   mapRegionCounters(D);
829   if (CGM.getCodeGenOpts().CoverageMapping)
830     emitCounterRegionMapping(D);
831   if (PGOReader) {
832     SourceManager &SM = CGM.getContext().getSourceManager();
833     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
834     computeRegionCounts(D);
835     applyFunctionAttributes(PGOReader, Fn);
836   }
837 }
838 
839 void CodeGenPGO::mapRegionCounters(const Decl *D) {
840   // Use the latest hash version when inserting instrumentation, but use the
841   // version in the indexed profile if we're reading PGO data.
842   PGOHashVersion HashVersion = PGO_HASH_LATEST;
843   uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
844   if (auto *PGOReader = CGM.getPGOReader()) {
845     HashVersion = getPGOHashVersion(PGOReader, CGM);
846     ProfileVersion = PGOReader->getVersion();
847   }
848 
849   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
850   MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
851   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
852     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
853   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
854     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
855   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
856     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
857   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
858     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
859   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
860   NumRegionCounters = Walker.NextCounter;
861   FunctionHash = Walker.Hash.finalize();
862 }
863 
864 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
865   if (!D->getBody())
866     return true;
867 
868   // Skip host-only functions in the CUDA device compilation and device-only
869   // functions in the host compilation. Just roughly filter them out based on
870   // the function attributes. If there are effectively host-only or device-only
871   // ones, their coverage mapping may still be generated.
872   if (CGM.getLangOpts().CUDA &&
873       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
874         !D->hasAttr<CUDAGlobalAttr>()) ||
875        (!CGM.getLangOpts().CUDAIsDevice &&
876         (D->hasAttr<CUDAGlobalAttr>() ||
877          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
878     return true;
879 
880   // Don't map the functions in system headers.
881   const auto &SM = CGM.getContext().getSourceManager();
882   auto Loc = D->getBody()->getBeginLoc();
883   return SM.isInSystemHeader(Loc);
884 }
885 
886 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
887   if (skipRegionMappingForDecl(D))
888     return;
889 
890   std::string CoverageMapping;
891   llvm::raw_string_ostream OS(CoverageMapping);
892   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
893                                 CGM.getContext().getSourceManager(),
894                                 CGM.getLangOpts(), RegionCounterMap.get());
895   MappingGen.emitCounterMapping(D, OS);
896   OS.flush();
897 
898   if (CoverageMapping.empty())
899     return;
900 
901   CGM.getCoverageMapping()->addFunctionMappingRecord(
902       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
903 }
904 
905 void
906 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
907                                     llvm::GlobalValue::LinkageTypes Linkage) {
908   if (skipRegionMappingForDecl(D))
909     return;
910 
911   std::string CoverageMapping;
912   llvm::raw_string_ostream OS(CoverageMapping);
913   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
914                                 CGM.getContext().getSourceManager(),
915                                 CGM.getLangOpts());
916   MappingGen.emitEmptyMapping(D, OS);
917   OS.flush();
918 
919   if (CoverageMapping.empty())
920     return;
921 
922   setFuncName(Name, Linkage);
923   CGM.getCoverageMapping()->addFunctionMappingRecord(
924       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
925 }
926 
927 void CodeGenPGO::computeRegionCounts(const Decl *D) {
928   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
929   ComputeRegionCounts Walker(*StmtCountMap, *this);
930   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
931     Walker.VisitFunctionDecl(FD);
932   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
933     Walker.VisitObjCMethodDecl(MD);
934   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
935     Walker.VisitBlockDecl(BD);
936   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
937     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
938 }
939 
940 void
941 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
942                                     llvm::Function *Fn) {
943   if (!haveRegionCounts())
944     return;
945 
946   uint64_t FunctionCount = getRegionCount(nullptr);
947   Fn->setEntryCount(FunctionCount);
948 }
949 
950 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
951                                       llvm::Value *StepV) {
952   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
953     return;
954   if (!Builder.GetInsertBlock())
955     return;
956 
957   unsigned Counter = (*RegionCounterMap)[S];
958   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
959 
960   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
961                          Builder.getInt64(FunctionHash),
962                          Builder.getInt32(NumRegionCounters),
963                          Builder.getInt32(Counter), StepV};
964   if (!StepV)
965     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
966                        makeArrayRef(Args, 4));
967   else
968     Builder.CreateCall(
969         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
970         makeArrayRef(Args));
971 }
972 
973 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
974   if (CGM.getCodeGenOpts().hasProfileClangInstr())
975     M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
976                     uint32_t(EnableValueProfiling));
977 }
978 
979 // This method either inserts a call to the profile run-time during
980 // instrumentation or puts profile data into metadata for PGO use.
981 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
982     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
983 
984   if (!EnableValueProfiling)
985     return;
986 
987   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
988     return;
989 
990   if (isa<llvm::Constant>(ValuePtr))
991     return;
992 
993   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
994   if (InstrumentValueSites && RegionCounterMap) {
995     auto BuilderInsertPoint = Builder.saveIP();
996     Builder.SetInsertPoint(ValueSite);
997     llvm::Value *Args[5] = {
998         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
999         Builder.getInt64(FunctionHash),
1000         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1001         Builder.getInt32(ValueKind),
1002         Builder.getInt32(NumValueSites[ValueKind]++)
1003     };
1004     Builder.CreateCall(
1005         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1006     Builder.restoreIP(BuilderInsertPoint);
1007     return;
1008   }
1009 
1010   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1011   if (PGOReader && haveRegionCounts()) {
1012     // We record the top most called three functions at each call site.
1013     // Profile metadata contains "VP" string identifying this metadata
1014     // as value profiling data, then a uint32_t value for the value profiling
1015     // kind, a uint64_t value for the total number of times the call is
1016     // executed, followed by the function hash and execution count (uint64_t)
1017     // pairs for each function.
1018     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1019       return;
1020 
1021     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1022                             (llvm::InstrProfValueKind)ValueKind,
1023                             NumValueSites[ValueKind]);
1024 
1025     NumValueSites[ValueKind]++;
1026   }
1027 }
1028 
1029 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1030                                   bool IsInMainFile) {
1031   CGM.getPGOStats().addVisited(IsInMainFile);
1032   RegionCounts.clear();
1033   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1034       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1035   if (auto E = RecordExpected.takeError()) {
1036     auto IPE = llvm::InstrProfError::take(std::move(E));
1037     if (IPE == llvm::instrprof_error::unknown_function)
1038       CGM.getPGOStats().addMissing(IsInMainFile);
1039     else if (IPE == llvm::instrprof_error::hash_mismatch)
1040       CGM.getPGOStats().addMismatched(IsInMainFile);
1041     else if (IPE == llvm::instrprof_error::malformed)
1042       // TODO: Consider a more specific warning for this case.
1043       CGM.getPGOStats().addMismatched(IsInMainFile);
1044     return;
1045   }
1046   ProfRecord =
1047       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1048   RegionCounts = ProfRecord->Counts;
1049 }
1050 
/// Compute the divisor used to scale weights down.
///
/// \param MaxWeight the largest weight that will be scaled.
/// \returns 1 when \p MaxWeight is already below UINT32_MAX; otherwise a
/// divisor large enough that every weight up to \p MaxWeight, once divided by
/// it, is strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1058 
/// Scale one branch weight down to 32 bits and add 1.
///
/// The 64-bit \p Weight is divided by \p Scale, and 1 is then added
/// unconditionally: according to Laplace's Rule of Succession it is better to
/// base the weight on the count plus 1.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight, so that the result fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Adjusted = Weight / Scale + 1;
  assert(Adjusted <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Adjusted);
}
1074 
1075 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1076                                                     uint64_t FalseCount) const {
1077   // Check for empty weights.
1078   if (!TrueCount && !FalseCount)
1079     return nullptr;
1080 
1081   // Calculate how to scale down to 32-bits.
1082   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1083 
1084   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1085   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1086                                       scaleBranchWeight(FalseCount, Scale));
1087 }
1088 
1089 llvm::MDNode *
1090 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1091   // We need at least two elements to create meaningful weights.
1092   if (Weights.size() < 2)
1093     return nullptr;
1094 
1095   // Check for empty weights.
1096   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1097   if (MaxWeight == 0)
1098     return nullptr;
1099 
1100   // Calculate how to scale down to 32-bits.
1101   uint64_t Scale = calculateWeightScale(MaxWeight);
1102 
1103   SmallVector<uint32_t, 16> ScaledWeights;
1104   ScaledWeights.reserve(Weights.size());
1105   for (uint64_t W : Weights)
1106     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1107 
1108   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1109   return MDHelper.createBranchWeights(ScaledWeights);
1110 }
1111 
1112 llvm::MDNode *
1113 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1114                                              uint64_t LoopCount) const {
1115   if (!PGO.haveRegionCounts())
1116     return nullptr;
1117   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1118   if (!CondCount || *CondCount == 0)
1119     return nullptr;
1120   return createProfileWeights(LoopCount,
1121                               std::max(*CondCount, LoopCount) - LoopCount);
1122 }
1123