xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision cfd6422a5217410fbd66f7a7a8a64d9d85e61229)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 
// Hidden boolean command-line option "enable-value-profiling"; it defaults to
// false. llvm::cl::ZeroOrMore allows the flag to appear any number of times
// on the command line.
static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));
29 
30 using namespace clang;
31 using namespace CodeGen;
32 
33 void CodeGenPGO::setFuncName(StringRef Name,
34                              llvm::GlobalValue::LinkageTypes Linkage) {
35   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36   FuncName = llvm::getPGOFuncName(
37       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39 
40   // If we're generating a profile, create a variable for the name.
41   if (CGM.getCodeGenOpts().hasProfileClangInstr())
42     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43 }
44 
45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46   setFuncName(Fn->getName(), Fn->getLinkage());
47   // Create PGOFuncName meta data.
48   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49 }
50 
/// Versions of the PGO function-hash algorithm, listed oldest first.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Alias for the newest version; update it whenever a version is added.
  PGO_HASH_LATEST = PGO_HASH_V3
};
60 
61 namespace {
62 /// Stable hasher for PGO region counters.
63 ///
64 /// PGOHash produces a stable hash of a given function's control flow.
65 ///
66 /// Changing the output of this hash will invalidate all previously generated
67 /// profiles -- i.e., don't do it.
68 ///
69 /// \note  When this hash does eventually change (years?), we still need to
70 /// support old hashes.  We'll need to pull in the version number from the
71 /// profile data format and use the matching hash function.
72 class PGOHash {
73   uint64_t Working;
74   unsigned Count;
75   PGOHashVersion HashVersion;
76   llvm::MD5 MD5;
77 
78   static const int NumBitsPerType = 6;
79   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
80   static const unsigned TooBig = 1u << NumBitsPerType;
81 
82 public:
83   /// Hash values for AST nodes.
84   ///
85   /// Distinct values for AST nodes that have region counters attached.
86   ///
87   /// These values must be stable.  All new members must be added at the end,
88   /// and no members should be removed.  Changing the enumeration value for an
89   /// AST node will affect the hash of every function that contains that node.
90   enum HashType : unsigned char {
91     None = 0,
92     LabelStmt = 1,
93     WhileStmt,
94     DoStmt,
95     ForStmt,
96     CXXForRangeStmt,
97     ObjCForCollectionStmt,
98     SwitchStmt,
99     CaseStmt,
100     DefaultStmt,
101     IfStmt,
102     CXXTryStmt,
103     CXXCatchStmt,
104     ConditionalOperator,
105     BinaryOperatorLAnd,
106     BinaryOperatorLOr,
107     BinaryConditionalOperator,
108     // The preceding values are available with PGO_HASH_V1.
109 
110     EndOfScope,
111     IfThenBranch,
112     IfElseBranch,
113     GotoStmt,
114     IndirectGotoStmt,
115     BreakStmt,
116     ContinueStmt,
117     ReturnStmt,
118     ThrowExpr,
119     UnaryOperatorLNot,
120     BinaryOperatorLT,
121     BinaryOperatorGT,
122     BinaryOperatorLE,
123     BinaryOperatorGE,
124     BinaryOperatorEQ,
125     BinaryOperatorNE,
126     // The preceding values are available since PGO_HASH_V2.
127 
128     // Keep this last.  It's for the static assert that follows.
129     LastHashType
130   };
131   static_assert(LastHashType <= TooBig, "Too many types in HashType");
132 
133   PGOHash(PGOHashVersion HashVersion)
134       : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
135   void combine(HashType Type);
136   uint64_t finalize();
137   PGOHashVersion getHashVersion() const { return HashVersion; }
138 };
// Out-of-class definitions of PGOHash's static constants. They carry no
// initializer here because each one is already initialized inside the class.
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
142 
143 /// Get the PGO hash version used in the given indexed profile.
144 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
145                                         CodeGenModule &CGM) {
146   if (PGOReader->getVersion() <= 4)
147     return PGO_HASH_V1;
148   if (PGOReader->getVersion() <= 5)
149     return PGO_HASH_V2;
150   return PGO_HASH_V3;
151 }
152 
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// While walking, it also feeds the function hash: the order in which nodes
/// are visited and combined determines both the counter numbering and the
/// resulting hash.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

  /// Start with counter 0 and an empty hash of version \p HashVersion;
  /// results are written into the caller-owned \p CounterMap.
  MapRegionCounters(PGOHashVersion HashVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  // Captured statements are likewise not walked as part of the parent.
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Assign a counter to the body of each function-like declaration that is
  /// visited. Other declaration kinds get no counter.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  ///
  /// Counter assignment always keys off the V1 hash type, whatever hash
  /// version is in use.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    // For newer hash versions, re-query the type so the hash also covers the
    // statement kinds that V1 ignores.
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  /// Custom traversal of an if statement that, for hash versions after V1,
  /// marks the then/else branches and the end of the statement in the hash.
  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
//
// The generated traversal ignores the result of the base traversal and always
// returns true.
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// Returns PGOHash::None for statements that do not take part in the hash
  /// at that version.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    // Statement kinds hashed by every version (the V2+ comparison operators
    // inside the BinaryOperator case are the exception).
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      // Comparison operators are hashed from V2 onwards only.
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Statement kinds that are hashed from V2 onwards only.
    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
344 
345 /// A StmtVisitor that propagates the raw counts through the AST and
346 /// records the count at statements where the value may change.
347 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
348   /// PGO state.
349   CodeGenPGO &PGO;
350 
351   /// A flag that is set when the current count should be recorded on the
352   /// next statement, such as at the exit of a loop.
353   bool RecordNextStmtCount;
354 
355   /// The count at the current location in the traversal.
356   uint64_t CurrentCount;
357 
358   /// The map of statements to count values.
359   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
360 
361   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
362   struct BreakContinue {
363     uint64_t BreakCount;
364     uint64_t ContinueCount;
365     BreakContinue() : BreakCount(0), ContinueCount(0) {}
366   };
367   SmallVector<BreakContinue, 8> BreakContinueStack;
368 
369   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
370                       CodeGenPGO &PGO)
371       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
372 
373   void RecordStmtCount(const Stmt *S) {
374     if (RecordNextStmtCount) {
375       CountMap[S] = CurrentCount;
376       RecordNextStmtCount = false;
377     }
378   }
379 
380   /// Set and return the current count.
381   uint64_t setCount(uint64_t Count) {
382     CurrentCount = Count;
383     return Count;
384   }
385 
386   void VisitStmt(const Stmt *S) {
387     RecordStmtCount(S);
388     for (const Stmt *Child : S->children())
389       if (Child)
390         this->Visit(Child);
391   }
392 
393   void VisitFunctionDecl(const FunctionDecl *D) {
394     // Counter tracks entry to the function body.
395     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
396     CountMap[D->getBody()] = BodyCount;
397     Visit(D->getBody());
398   }
399 
400   // Skip lambda expressions. We visit these as FunctionDecls when we're
401   // generating them and aren't interested in the body when generating a
402   // parent context.
403   void VisitLambdaExpr(const LambdaExpr *LE) {}
404 
405   void VisitCapturedDecl(const CapturedDecl *D) {
406     // Counter tracks entry to the capture body.
407     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
408     CountMap[D->getBody()] = BodyCount;
409     Visit(D->getBody());
410   }
411 
412   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
413     // Counter tracks entry to the method body.
414     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
415     CountMap[D->getBody()] = BodyCount;
416     Visit(D->getBody());
417   }
418 
419   void VisitBlockDecl(const BlockDecl *D) {
420     // Counter tracks entry to the block body.
421     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
422     CountMap[D->getBody()] = BodyCount;
423     Visit(D->getBody());
424   }
425 
426   void VisitReturnStmt(const ReturnStmt *S) {
427     RecordStmtCount(S);
428     if (S->getRetValue())
429       Visit(S->getRetValue());
430     CurrentCount = 0;
431     RecordNextStmtCount = true;
432   }
433 
434   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
435     RecordStmtCount(E);
436     if (E->getSubExpr())
437       Visit(E->getSubExpr());
438     CurrentCount = 0;
439     RecordNextStmtCount = true;
440   }
441 
442   void VisitGotoStmt(const GotoStmt *S) {
443     RecordStmtCount(S);
444     CurrentCount = 0;
445     RecordNextStmtCount = true;
446   }
447 
448   void VisitLabelStmt(const LabelStmt *S) {
449     RecordNextStmtCount = false;
450     // Counter tracks the block following the label.
451     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
452     CountMap[S] = BlockCount;
453     Visit(S->getSubStmt());
454   }
455 
456   void VisitBreakStmt(const BreakStmt *S) {
457     RecordStmtCount(S);
458     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
459     BreakContinueStack.back().BreakCount += CurrentCount;
460     CurrentCount = 0;
461     RecordNextStmtCount = true;
462   }
463 
464   void VisitContinueStmt(const ContinueStmt *S) {
465     RecordStmtCount(S);
466     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
467     BreakContinueStack.back().ContinueCount += CurrentCount;
468     CurrentCount = 0;
469     RecordNextStmtCount = true;
470   }
471 
472   void VisitWhileStmt(const WhileStmt *S) {
473     RecordStmtCount(S);
474     uint64_t ParentCount = CurrentCount;
475 
476     BreakContinueStack.push_back(BreakContinue());
477     // Visit the body region first so the break/continue adjustments can be
478     // included when visiting the condition.
479     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
480     CountMap[S->getBody()] = CurrentCount;
481     Visit(S->getBody());
482     uint64_t BackedgeCount = CurrentCount;
483 
484     // ...then go back and propagate counts through the condition. The count
485     // at the start of the condition is the sum of the incoming edges,
486     // the backedge from the end of the loop body, and the edges from
487     // continue statements.
488     BreakContinue BC = BreakContinueStack.pop_back_val();
489     uint64_t CondCount =
490         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
491     CountMap[S->getCond()] = CondCount;
492     Visit(S->getCond());
493     setCount(BC.BreakCount + CondCount - BodyCount);
494     RecordNextStmtCount = true;
495   }
496 
497   void VisitDoStmt(const DoStmt *S) {
498     RecordStmtCount(S);
499     uint64_t LoopCount = PGO.getRegionCount(S);
500 
501     BreakContinueStack.push_back(BreakContinue());
502     // The count doesn't include the fallthrough from the parent scope. Add it.
503     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
504     CountMap[S->getBody()] = BodyCount;
505     Visit(S->getBody());
506     uint64_t BackedgeCount = CurrentCount;
507 
508     BreakContinue BC = BreakContinueStack.pop_back_val();
509     // The count at the start of the condition is equal to the count at the
510     // end of the body, plus any continues.
511     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
512     CountMap[S->getCond()] = CondCount;
513     Visit(S->getCond());
514     setCount(BC.BreakCount + CondCount - LoopCount);
515     RecordNextStmtCount = true;
516   }
517 
518   void VisitForStmt(const ForStmt *S) {
519     RecordStmtCount(S);
520     if (S->getInit())
521       Visit(S->getInit());
522 
523     uint64_t ParentCount = CurrentCount;
524 
525     BreakContinueStack.push_back(BreakContinue());
526     // Visit the body region first. (This is basically the same as a while
527     // loop; see further comments in VisitWhileStmt.)
528     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
529     CountMap[S->getBody()] = BodyCount;
530     Visit(S->getBody());
531     uint64_t BackedgeCount = CurrentCount;
532     BreakContinue BC = BreakContinueStack.pop_back_val();
533 
534     // The increment is essentially part of the body but it needs to include
535     // the count for all the continue statements.
536     if (S->getInc()) {
537       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
538       CountMap[S->getInc()] = IncCount;
539       Visit(S->getInc());
540     }
541 
542     // ...then go back and propagate counts through the condition.
543     uint64_t CondCount =
544         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
545     if (S->getCond()) {
546       CountMap[S->getCond()] = CondCount;
547       Visit(S->getCond());
548     }
549     setCount(BC.BreakCount + CondCount - BodyCount);
550     RecordNextStmtCount = true;
551   }
552 
553   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
554     RecordStmtCount(S);
555     if (S->getInit())
556       Visit(S->getInit());
557     Visit(S->getLoopVarStmt());
558     Visit(S->getRangeStmt());
559     Visit(S->getBeginStmt());
560     Visit(S->getEndStmt());
561 
562     uint64_t ParentCount = CurrentCount;
563     BreakContinueStack.push_back(BreakContinue());
564     // Visit the body region first. (This is basically the same as a while
565     // loop; see further comments in VisitWhileStmt.)
566     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
567     CountMap[S->getBody()] = BodyCount;
568     Visit(S->getBody());
569     uint64_t BackedgeCount = CurrentCount;
570     BreakContinue BC = BreakContinueStack.pop_back_val();
571 
572     // The increment is essentially part of the body but it needs to include
573     // the count for all the continue statements.
574     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
575     CountMap[S->getInc()] = IncCount;
576     Visit(S->getInc());
577 
578     // ...then go back and propagate counts through the condition.
579     uint64_t CondCount =
580         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
581     CountMap[S->getCond()] = CondCount;
582     Visit(S->getCond());
583     setCount(BC.BreakCount + CondCount - BodyCount);
584     RecordNextStmtCount = true;
585   }
586 
587   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
588     RecordStmtCount(S);
589     Visit(S->getElement());
590     uint64_t ParentCount = CurrentCount;
591     BreakContinueStack.push_back(BreakContinue());
592     // Counter tracks the body of the loop.
593     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
594     CountMap[S->getBody()] = BodyCount;
595     Visit(S->getBody());
596     uint64_t BackedgeCount = CurrentCount;
597     BreakContinue BC = BreakContinueStack.pop_back_val();
598 
599     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
600              BodyCount);
601     RecordNextStmtCount = true;
602   }
603 
604   void VisitSwitchStmt(const SwitchStmt *S) {
605     RecordStmtCount(S);
606     if (S->getInit())
607       Visit(S->getInit());
608     Visit(S->getCond());
609     CurrentCount = 0;
610     BreakContinueStack.push_back(BreakContinue());
611     Visit(S->getBody());
612     // If the switch is inside a loop, add the continue counts.
613     BreakContinue BC = BreakContinueStack.pop_back_val();
614     if (!BreakContinueStack.empty())
615       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
616     // Counter tracks the exit block of the switch.
617     setCount(PGO.getRegionCount(S));
618     RecordNextStmtCount = true;
619   }
620 
621   void VisitSwitchCase(const SwitchCase *S) {
622     RecordNextStmtCount = false;
623     // Counter for this particular case. This counts only jumps from the
624     // switch header and does not include fallthrough from the case before
625     // this one.
626     uint64_t CaseCount = PGO.getRegionCount(S);
627     setCount(CurrentCount + CaseCount);
628     // We need the count without fallthrough in the mapping, so it's more useful
629     // for branch probabilities.
630     CountMap[S] = CaseCount;
631     RecordNextStmtCount = true;
632     Visit(S->getSubStmt());
633   }
634 
635   void VisitIfStmt(const IfStmt *S) {
636     RecordStmtCount(S);
637     uint64_t ParentCount = CurrentCount;
638     if (S->getInit())
639       Visit(S->getInit());
640     Visit(S->getCond());
641 
642     // Counter tracks the "then" part of an if statement. The count for
643     // the "else" part, if it exists, will be calculated from this counter.
644     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
645     CountMap[S->getThen()] = ThenCount;
646     Visit(S->getThen());
647     uint64_t OutCount = CurrentCount;
648 
649     uint64_t ElseCount = ParentCount - ThenCount;
650     if (S->getElse()) {
651       setCount(ElseCount);
652       CountMap[S->getElse()] = ElseCount;
653       Visit(S->getElse());
654       OutCount += CurrentCount;
655     } else
656       OutCount += ElseCount;
657     setCount(OutCount);
658     RecordNextStmtCount = true;
659   }
660 
661   void VisitCXXTryStmt(const CXXTryStmt *S) {
662     RecordStmtCount(S);
663     Visit(S->getTryBlock());
664     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
665       Visit(S->getHandler(I));
666     // Counter tracks the continuation block of the try statement.
667     setCount(PGO.getRegionCount(S));
668     RecordNextStmtCount = true;
669   }
670 
671   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
672     RecordNextStmtCount = false;
673     // Counter tracks the catch statement's handler block.
674     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
675     CountMap[S] = CatchCount;
676     Visit(S->getHandlerBlock());
677   }
678 
679   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
680     RecordStmtCount(E);
681     uint64_t ParentCount = CurrentCount;
682     Visit(E->getCond());
683 
684     // Counter tracks the "true" part of a conditional operator. The
685     // count in the "false" part will be calculated from this counter.
686     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
687     CountMap[E->getTrueExpr()] = TrueCount;
688     Visit(E->getTrueExpr());
689     uint64_t OutCount = CurrentCount;
690 
691     uint64_t FalseCount = setCount(ParentCount - TrueCount);
692     CountMap[E->getFalseExpr()] = FalseCount;
693     Visit(E->getFalseExpr());
694     OutCount += CurrentCount;
695 
696     setCount(OutCount);
697     RecordNextStmtCount = true;
698   }
699 
700   void VisitBinLAnd(const BinaryOperator *E) {
701     RecordStmtCount(E);
702     uint64_t ParentCount = CurrentCount;
703     Visit(E->getLHS());
704     // Counter tracks the right hand side of a logical and operator.
705     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
706     CountMap[E->getRHS()] = RHSCount;
707     Visit(E->getRHS());
708     setCount(ParentCount + RHSCount - CurrentCount);
709     RecordNextStmtCount = true;
710   }
711 
712   void VisitBinLOr(const BinaryOperator *E) {
713     RecordStmtCount(E);
714     uint64_t ParentCount = CurrentCount;
715     Visit(E->getLHS());
716     // Counter tracks the right hand side of a logical or operator.
717     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
718     CountMap[E->getRHS()] = RHSCount;
719     Visit(E->getRHS());
720     setCount(ParentCount + RHSCount - CurrentCount);
721     RecordNextStmtCount = true;
722   }
723 };
724 } // end anonymous namespace
725 
726 void PGOHash::combine(HashType Type) {
727   // Check that we never combine 0 and only have six bits.
728   assert(Type && "Hash is invalid: unexpected type 0");
729   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
730 
731   // Pass through MD5 if enough work has built up.
732   if (Count && Count % NumTypesPerWord == 0) {
733     using namespace llvm::support;
734     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
735     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
736     Working = 0;
737   }
738 
739   // Accumulate the current type.
740   ++Count;
741   Working = Working << NumBitsPerType | Type;
742 }
743 
744 uint64_t PGOHash::finalize() {
745   // Use Working as the hash directly if we never used MD5.
746   if (Count <= NumTypesPerWord)
747     // No need to byte swap here, since none of the math was endian-dependent.
748     // This number will be byte-swapped as required on endianness transitions,
749     // so we will see the same value on the other side.
750     return Working;
751 
752   // Check for remaining work in Working.
753   if (Working) {
754     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
755     // is buggy because it converts a uint64_t into an array of uint8_t.
756     if (HashVersion < PGO_HASH_V3) {
757       MD5.update({(uint8_t)Working});
758     } else {
759       using namespace llvm::support;
760       uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
761       MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
762     }
763   }
764 
765   // Finalize the MD5 and return the hash.
766   llvm::MD5::MD5Result Result;
767   MD5.final(Result);
768   return Result.low();
769 }
770 
771 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
772   const Decl *D = GD.getDecl();
773   if (!D->hasBody())
774     return;
775 
776   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
777   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
778   if (!InstrumentRegions && !PGOReader)
779     return;
780   if (D->isImplicit())
781     return;
782   // Constructors and destructors may be represented by several functions in IR.
783   // If so, instrument only base variant, others are implemented by delegation
784   // to the base one, it would be counted twice otherwise.
785   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
786     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
787       if (GD.getCtorType() != Ctor_Base &&
788           CodeGenFunction::IsConstructorDelegationValid(CCD))
789         return;
790   }
791   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
792     return;
793 
794   CGM.ClearUnusedCoverageMapping(D);
795   setFuncName(Fn);
796 
797   mapRegionCounters(D);
798   if (CGM.getCodeGenOpts().CoverageMapping)
799     emitCounterRegionMapping(D);
800   if (PGOReader) {
801     SourceManager &SM = CGM.getContext().getSourceManager();
802     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
803     computeRegionCounts(D);
804     applyFunctionAttributes(PGOReader, Fn);
805   }
806 }
807 
808 void CodeGenPGO::mapRegionCounters(const Decl *D) {
809   // Use the latest hash version when inserting instrumentation, but use the
810   // version in the indexed profile if we're reading PGO data.
811   PGOHashVersion HashVersion = PGO_HASH_LATEST;
812   if (auto *PGOReader = CGM.getPGOReader())
813     HashVersion = getPGOHashVersion(PGOReader, CGM);
814 
815   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
816   MapRegionCounters Walker(HashVersion, *RegionCounterMap);
817   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
818     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
819   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
820     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
821   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
822     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
823   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
824     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
825   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
826   NumRegionCounters = Walker.NextCounter;
827   FunctionHash = Walker.Hash.finalize();
828 }
829 
830 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
831   if (!D->getBody())
832     return true;
833 
834   // Don't map the functions in system headers.
835   const auto &SM = CGM.getContext().getSourceManager();
836   auto Loc = D->getBody()->getBeginLoc();
837   return SM.isInSystemHeader(Loc);
838 }
839 
840 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
841   if (skipRegionMappingForDecl(D))
842     return;
843 
844   std::string CoverageMapping;
845   llvm::raw_string_ostream OS(CoverageMapping);
846   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
847                                 CGM.getContext().getSourceManager(),
848                                 CGM.getLangOpts(), RegionCounterMap.get());
849   MappingGen.emitCounterMapping(D, OS);
850   OS.flush();
851 
852   if (CoverageMapping.empty())
853     return;
854 
855   CGM.getCoverageMapping()->addFunctionMappingRecord(
856       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
857 }
858 
859 void
860 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
861                                     llvm::GlobalValue::LinkageTypes Linkage) {
862   if (skipRegionMappingForDecl(D))
863     return;
864 
865   std::string CoverageMapping;
866   llvm::raw_string_ostream OS(CoverageMapping);
867   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
868                                 CGM.getContext().getSourceManager(),
869                                 CGM.getLangOpts());
870   MappingGen.emitEmptyMapping(D, OS);
871   OS.flush();
872 
873   if (CoverageMapping.empty())
874     return;
875 
876   setFuncName(Name, Linkage);
877   CGM.getCoverageMapping()->addFunctionMappingRecord(
878       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
879 }
880 
881 void CodeGenPGO::computeRegionCounts(const Decl *D) {
882   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
883   ComputeRegionCounts Walker(*StmtCountMap, *this);
884   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
885     Walker.VisitFunctionDecl(FD);
886   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
887     Walker.VisitObjCMethodDecl(MD);
888   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
889     Walker.VisitBlockDecl(BD);
890   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
891     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
892 }
893 
894 void
895 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
896                                     llvm::Function *Fn) {
897   if (!haveRegionCounts())
898     return;
899 
900   uint64_t FunctionCount = getRegionCount(nullptr);
901   Fn->setEntryCount(FunctionCount);
902 }
903 
904 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
905                                       llvm::Value *StepV) {
906   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
907     return;
908   if (!Builder.GetInsertBlock())
909     return;
910 
911   unsigned Counter = (*RegionCounterMap)[S];
912   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
913 
914   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
915                          Builder.getInt64(FunctionHash),
916                          Builder.getInt32(NumRegionCounters),
917                          Builder.getInt32(Counter), StepV};
918   if (!StepV)
919     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
920                        makeArrayRef(Args, 4));
921   else
922     Builder.CreateCall(
923         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
924         makeArrayRef(Args));
925 }
926 
927 // This method either inserts a call to the profile run-time during
928 // instrumentation or puts profile data into metadata for PGO use.
929 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
930     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
931 
932   if (!EnableValueProfiling)
933     return;
934 
935   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
936     return;
937 
938   if (isa<llvm::Constant>(ValuePtr))
939     return;
940 
941   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
942   if (InstrumentValueSites && RegionCounterMap) {
943     auto BuilderInsertPoint = Builder.saveIP();
944     Builder.SetInsertPoint(ValueSite);
945     llvm::Value *Args[5] = {
946         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
947         Builder.getInt64(FunctionHash),
948         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
949         Builder.getInt32(ValueKind),
950         Builder.getInt32(NumValueSites[ValueKind]++)
951     };
952     Builder.CreateCall(
953         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
954     Builder.restoreIP(BuilderInsertPoint);
955     return;
956   }
957 
958   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
959   if (PGOReader && haveRegionCounts()) {
960     // We record the top most called three functions at each call site.
961     // Profile metadata contains "VP" string identifying this metadata
962     // as value profiling data, then a uint32_t value for the value profiling
963     // kind, a uint64_t value for the total number of times the call is
964     // executed, followed by the function hash and execution count (uint64_t)
965     // pairs for each function.
966     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
967       return;
968 
969     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
970                             (llvm::InstrProfValueKind)ValueKind,
971                             NumValueSites[ValueKind]);
972 
973     NumValueSites[ValueKind]++;
974   }
975 }
976 
977 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
978                                   bool IsInMainFile) {
979   CGM.getPGOStats().addVisited(IsInMainFile);
980   RegionCounts.clear();
981   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
982       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
983   if (auto E = RecordExpected.takeError()) {
984     auto IPE = llvm::InstrProfError::take(std::move(E));
985     if (IPE == llvm::instrprof_error::unknown_function)
986       CGM.getPGOStats().addMissing(IsInMainFile);
987     else if (IPE == llvm::instrprof_error::hash_mismatch)
988       CGM.getPGOStats().addMismatched(IsInMainFile);
989     else if (IPE == llvm::instrprof_error::malformed)
990       // TODO: Consider a more specific warning for this case.
991       CGM.getPGOStats().addMismatched(IsInMainFile);
992     return;
993   }
994   ProfRecord =
995       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
996   RegionCounts = ProfRecord->Counts;
997 }
998 
/// Compute the divisor used to scale weights down.
///
/// \returns 1 when \p MaxWeight is already below UINT32_MAX; otherwise the
/// smallest-by-construction divisor \c MaxWeight / UINT32_MAX + 1, which
/// brings every weight up to \p MaxWeight strictly below UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1006 
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale is non-zero and was calculated by \a calculateWeightScale()
/// from a maximum weight no less than \c Weight; only then is the scaled
/// value (including the added 1) guaranteed to fit in 32 bits.
///
/// \returns \c Weight / \c Scale + 1, narrowed to 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The narrowing is intentional and guarded by the assertion above.
  return static_cast<uint32_t>(Scaled);
}
1022 
1023 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1024                                                     uint64_t FalseCount) {
1025   // Check for empty weights.
1026   if (!TrueCount && !FalseCount)
1027     return nullptr;
1028 
1029   // Calculate how to scale down to 32-bits.
1030   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1031 
1032   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1033   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1034                                       scaleBranchWeight(FalseCount, Scale));
1035 }
1036 
1037 llvm::MDNode *
1038 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1039   // We need at least two elements to create meaningful weights.
1040   if (Weights.size() < 2)
1041     return nullptr;
1042 
1043   // Check for empty weights.
1044   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1045   if (MaxWeight == 0)
1046     return nullptr;
1047 
1048   // Calculate how to scale down to 32-bits.
1049   uint64_t Scale = calculateWeightScale(MaxWeight);
1050 
1051   SmallVector<uint32_t, 16> ScaledWeights;
1052   ScaledWeights.reserve(Weights.size());
1053   for (uint64_t W : Weights)
1054     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1055 
1056   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1057   return MDHelper.createBranchWeights(ScaledWeights);
1058 }
1059 
1060 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1061                                                            uint64_t LoopCount) {
1062   if (!PGO.haveRegionCounts())
1063     return nullptr;
1064   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1065   if (!CondCount || *CondCount == 0)
1066     return nullptr;
1067   return createProfileWeights(LoopCount,
1068                               std::max(*CondCount, LoopCount) - LoopCount);
1069 }
1070