xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision 6966ac055c3b7a39266fb982493330df7a097997)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/Endian.h"
21 #include "llvm/Support/FileSystem.h"
22 #include "llvm/Support/MD5.h"
23 
24 static llvm::cl::opt<bool>
25     EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
26                          llvm::cl::desc("Enable value profiling"),
27                          llvm::cl::Hidden, llvm::cl::init(false));
28 
29 using namespace clang;
30 using namespace CodeGen;
31 
32 void CodeGenPGO::setFuncName(StringRef Name,
33                              llvm::GlobalValue::LinkageTypes Linkage) {
34   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
35   FuncName = llvm::getPGOFuncName(
36       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
37       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
38 
39   // If we're generating a profile, create a variable for the name.
40   if (CGM.getCodeGenOpts().hasProfileClangInstr())
41     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
42 }
43 
44 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
45   setFuncName(Fn->getName(), Fn->getLinkage());
46   // Create PGOFuncName meta data.
47   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
48 }
49 
50 /// The version of the PGO hash algorithm.
///
/// Each enumerator carries an explicit value so the numbering is visible at a
/// glance; append new versions instead of renumbering existing ones.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,

  // Alias for the newest version listed above; update when adding one.
  PGO_HASH_LATEST = PGO_HASH_V2
};
58 
59 namespace {
60 /// Stable hasher for PGO region counters.
61 ///
62 /// PGOHash produces a stable hash of a given function's control flow.
63 ///
64 /// Changing the output of this hash will invalidate all previously generated
65 /// profiles -- i.e., don't do it.
66 ///
67 /// \note  When this hash does eventually change (years?), we still need to
68 /// support old hashes.  We'll need to pull in the version number from the
69 /// profile data format and use the matching hash function.
70 class PGOHash {
71   uint64_t Working;
72   unsigned Count;
73   PGOHashVersion HashVersion;
74   llvm::MD5 MD5;
75 
76   static const int NumBitsPerType = 6;
77   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
78   static const unsigned TooBig = 1u << NumBitsPerType;
79 
80 public:
81   /// Hash values for AST nodes.
82   ///
83   /// Distinct values for AST nodes that have region counters attached.
84   ///
85   /// These values must be stable.  All new members must be added at the end,
86   /// and no members should be removed.  Changing the enumeration value for an
87   /// AST node will affect the hash of every function that contains that node.
88   enum HashType : unsigned char {
89     None = 0,
90     LabelStmt = 1,
91     WhileStmt,
92     DoStmt,
93     ForStmt,
94     CXXForRangeStmt,
95     ObjCForCollectionStmt,
96     SwitchStmt,
97     CaseStmt,
98     DefaultStmt,
99     IfStmt,
100     CXXTryStmt,
101     CXXCatchStmt,
102     ConditionalOperator,
103     BinaryOperatorLAnd,
104     BinaryOperatorLOr,
105     BinaryConditionalOperator,
106     // The preceding values are available with PGO_HASH_V1.
107 
108     EndOfScope,
109     IfThenBranch,
110     IfElseBranch,
111     GotoStmt,
112     IndirectGotoStmt,
113     BreakStmt,
114     ContinueStmt,
115     ReturnStmt,
116     ThrowExpr,
117     UnaryOperatorLNot,
118     BinaryOperatorLT,
119     BinaryOperatorGT,
120     BinaryOperatorLE,
121     BinaryOperatorGE,
122     BinaryOperatorEQ,
123     BinaryOperatorNE,
124     // The preceding values are available with PGO_HASH_V2.
125 
126     // Keep this last.  It's for the static assert that follows.
127     LastHashType
128   };
129   static_assert(LastHashType <= TooBig, "Too many types in HashType");
130 
131   PGOHash(PGOHashVersion HashVersion)
132       : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
133   void combine(HashType Type);
134   uint64_t finalize();
135   PGOHashVersion getHashVersion() const { return HashVersion; }
136 };
137 const int PGOHash::NumBitsPerType;
138 const unsigned PGOHash::NumTypesPerWord;
139 const unsigned PGOHash::TooBig;
140 
141 /// Get the PGO hash version used in the given indexed profile.
142 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
143                                         CodeGenModule &CGM) {
144   if (PGOReader->getVersion() <= 4)
145     return PGO_HASH_V1;
146   return PGO_HASH_V2;
147 }
148 
149 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
150 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
151   using Base = RecursiveASTVisitor<MapRegionCounters>;
152 
153   /// The next counter value to assign.
154   unsigned NextCounter;
155   /// The function hash.
156   PGOHash Hash;
157   /// The map of statements to counters.
158   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
159 
160   MapRegionCounters(PGOHashVersion HashVersion,
161                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
162       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
163 
164   // Blocks and lambdas are handled as separate functions, so we need not
165   // traverse them in the parent context.
166   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
167   bool TraverseLambdaExpr(LambdaExpr *LE) {
168     // Traverse the captures, but not the body.
169     for (const auto &C : zip(LE->captures(), LE->capture_inits()))
170       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
171     return true;
172   }
173   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
174 
175   bool VisitDecl(const Decl *D) {
176     switch (D->getKind()) {
177     default:
178       break;
179     case Decl::Function:
180     case Decl::CXXMethod:
181     case Decl::CXXConstructor:
182     case Decl::CXXDestructor:
183     case Decl::CXXConversion:
184     case Decl::ObjCMethod:
185     case Decl::Block:
186     case Decl::Captured:
187       CounterMap[D->getBody()] = NextCounter++;
188       break;
189     }
190     return true;
191   }
192 
193   /// If \p S gets a fresh counter, update the counter mappings. Return the
194   /// V1 hash of \p S.
195   PGOHash::HashType updateCounterMappings(Stmt *S) {
196     auto Type = getHashType(PGO_HASH_V1, S);
197     if (Type != PGOHash::None)
198       CounterMap[S] = NextCounter++;
199     return Type;
200   }
201 
202   /// Include \p S in the function hash.
203   bool VisitStmt(Stmt *S) {
204     auto Type = updateCounterMappings(S);
205     if (Hash.getHashVersion() != PGO_HASH_V1)
206       Type = getHashType(Hash.getHashVersion(), S);
207     if (Type != PGOHash::None)
208       Hash.combine(Type);
209     return true;
210   }
211 
212   bool TraverseIfStmt(IfStmt *If) {
213     // If we used the V1 hash, use the default traversal.
214     if (Hash.getHashVersion() == PGO_HASH_V1)
215       return Base::TraverseIfStmt(If);
216 
217     // Otherwise, keep track of which branch we're in while traversing.
218     VisitStmt(If);
219     for (Stmt *CS : If->children()) {
220       if (!CS)
221         continue;
222       if (CS == If->getThen())
223         Hash.combine(PGOHash::IfThenBranch);
224       else if (CS == If->getElse())
225         Hash.combine(PGOHash::IfElseBranch);
226       TraverseStmt(CS);
227     }
228     Hash.combine(PGOHash::EndOfScope);
229     return true;
230   }
231 
232 // If the statement type \p N is nestable, and its nesting impacts profile
233 // stability, define a custom traversal which tracks the end of the statement
234 // in the hash (provided we're not using the V1 hash).
235 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
236   bool Traverse##N(N *S) {                                                     \
237     Base::Traverse##N(S);                                                      \
238     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
239       Hash.combine(PGOHash::EndOfScope);                                       \
240     return true;                                                               \
241   }
242 
243   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
244   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
245   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
246   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
247   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
248   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
249   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
250 
251   /// Get version \p HashVersion of the PGO hash for \p S.
252   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
253     switch (S->getStmtClass()) {
254     default:
255       break;
256     case Stmt::LabelStmtClass:
257       return PGOHash::LabelStmt;
258     case Stmt::WhileStmtClass:
259       return PGOHash::WhileStmt;
260     case Stmt::DoStmtClass:
261       return PGOHash::DoStmt;
262     case Stmt::ForStmtClass:
263       return PGOHash::ForStmt;
264     case Stmt::CXXForRangeStmtClass:
265       return PGOHash::CXXForRangeStmt;
266     case Stmt::ObjCForCollectionStmtClass:
267       return PGOHash::ObjCForCollectionStmt;
268     case Stmt::SwitchStmtClass:
269       return PGOHash::SwitchStmt;
270     case Stmt::CaseStmtClass:
271       return PGOHash::CaseStmt;
272     case Stmt::DefaultStmtClass:
273       return PGOHash::DefaultStmt;
274     case Stmt::IfStmtClass:
275       return PGOHash::IfStmt;
276     case Stmt::CXXTryStmtClass:
277       return PGOHash::CXXTryStmt;
278     case Stmt::CXXCatchStmtClass:
279       return PGOHash::CXXCatchStmt;
280     case Stmt::ConditionalOperatorClass:
281       return PGOHash::ConditionalOperator;
282     case Stmt::BinaryConditionalOperatorClass:
283       return PGOHash::BinaryConditionalOperator;
284     case Stmt::BinaryOperatorClass: {
285       const BinaryOperator *BO = cast<BinaryOperator>(S);
286       if (BO->getOpcode() == BO_LAnd)
287         return PGOHash::BinaryOperatorLAnd;
288       if (BO->getOpcode() == BO_LOr)
289         return PGOHash::BinaryOperatorLOr;
290       if (HashVersion == PGO_HASH_V2) {
291         switch (BO->getOpcode()) {
292         default:
293           break;
294         case BO_LT:
295           return PGOHash::BinaryOperatorLT;
296         case BO_GT:
297           return PGOHash::BinaryOperatorGT;
298         case BO_LE:
299           return PGOHash::BinaryOperatorLE;
300         case BO_GE:
301           return PGOHash::BinaryOperatorGE;
302         case BO_EQ:
303           return PGOHash::BinaryOperatorEQ;
304         case BO_NE:
305           return PGOHash::BinaryOperatorNE;
306         }
307       }
308       break;
309     }
310     }
311 
312     if (HashVersion == PGO_HASH_V2) {
313       switch (S->getStmtClass()) {
314       default:
315         break;
316       case Stmt::GotoStmtClass:
317         return PGOHash::GotoStmt;
318       case Stmt::IndirectGotoStmtClass:
319         return PGOHash::IndirectGotoStmt;
320       case Stmt::BreakStmtClass:
321         return PGOHash::BreakStmt;
322       case Stmt::ContinueStmtClass:
323         return PGOHash::ContinueStmt;
324       case Stmt::ReturnStmtClass:
325         return PGOHash::ReturnStmt;
326       case Stmt::CXXThrowExprClass:
327         return PGOHash::ThrowExpr;
328       case Stmt::UnaryOperatorClass: {
329         const UnaryOperator *UO = cast<UnaryOperator>(S);
330         if (UO->getOpcode() == UO_LNot)
331           return PGOHash::UnaryOperatorLNot;
332         break;
333       }
334       }
335     }
336 
337     return PGOHash::None;
338   }
339 };
340 
341 /// A StmtVisitor that propagates the raw counts through the AST and
342 /// records the count at statements where the value may change.
343 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
344   /// PGO state.
345   CodeGenPGO &PGO;
346 
347   /// A flag that is set when the current count should be recorded on the
348   /// next statement, such as at the exit of a loop.
349   bool RecordNextStmtCount;
350 
351   /// The count at the current location in the traversal.
352   uint64_t CurrentCount;
353 
354   /// The map of statements to count values.
355   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
356 
357   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
358   struct BreakContinue {
359     uint64_t BreakCount;
360     uint64_t ContinueCount;
361     BreakContinue() : BreakCount(0), ContinueCount(0) {}
362   };
363   SmallVector<BreakContinue, 8> BreakContinueStack;
364 
365   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
366                       CodeGenPGO &PGO)
367       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
368 
369   void RecordStmtCount(const Stmt *S) {
370     if (RecordNextStmtCount) {
371       CountMap[S] = CurrentCount;
372       RecordNextStmtCount = false;
373     }
374   }
375 
376   /// Set and return the current count.
377   uint64_t setCount(uint64_t Count) {
378     CurrentCount = Count;
379     return Count;
380   }
381 
382   void VisitStmt(const Stmt *S) {
383     RecordStmtCount(S);
384     for (const Stmt *Child : S->children())
385       if (Child)
386         this->Visit(Child);
387   }
388 
389   void VisitFunctionDecl(const FunctionDecl *D) {
390     // Counter tracks entry to the function body.
391     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
392     CountMap[D->getBody()] = BodyCount;
393     Visit(D->getBody());
394   }
395 
396   // Skip lambda expressions. We visit these as FunctionDecls when we're
397   // generating them and aren't interested in the body when generating a
398   // parent context.
399   void VisitLambdaExpr(const LambdaExpr *LE) {}
400 
401   void VisitCapturedDecl(const CapturedDecl *D) {
402     // Counter tracks entry to the capture body.
403     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
404     CountMap[D->getBody()] = BodyCount;
405     Visit(D->getBody());
406   }
407 
408   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
409     // Counter tracks entry to the method body.
410     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
411     CountMap[D->getBody()] = BodyCount;
412     Visit(D->getBody());
413   }
414 
415   void VisitBlockDecl(const BlockDecl *D) {
416     // Counter tracks entry to the block body.
417     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
418     CountMap[D->getBody()] = BodyCount;
419     Visit(D->getBody());
420   }
421 
422   void VisitReturnStmt(const ReturnStmt *S) {
423     RecordStmtCount(S);
424     if (S->getRetValue())
425       Visit(S->getRetValue());
426     CurrentCount = 0;
427     RecordNextStmtCount = true;
428   }
429 
430   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
431     RecordStmtCount(E);
432     if (E->getSubExpr())
433       Visit(E->getSubExpr());
434     CurrentCount = 0;
435     RecordNextStmtCount = true;
436   }
437 
438   void VisitGotoStmt(const GotoStmt *S) {
439     RecordStmtCount(S);
440     CurrentCount = 0;
441     RecordNextStmtCount = true;
442   }
443 
444   void VisitLabelStmt(const LabelStmt *S) {
445     RecordNextStmtCount = false;
446     // Counter tracks the block following the label.
447     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
448     CountMap[S] = BlockCount;
449     Visit(S->getSubStmt());
450   }
451 
452   void VisitBreakStmt(const BreakStmt *S) {
453     RecordStmtCount(S);
454     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
455     BreakContinueStack.back().BreakCount += CurrentCount;
456     CurrentCount = 0;
457     RecordNextStmtCount = true;
458   }
459 
460   void VisitContinueStmt(const ContinueStmt *S) {
461     RecordStmtCount(S);
462     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
463     BreakContinueStack.back().ContinueCount += CurrentCount;
464     CurrentCount = 0;
465     RecordNextStmtCount = true;
466   }
467 
468   void VisitWhileStmt(const WhileStmt *S) {
469     RecordStmtCount(S);
470     uint64_t ParentCount = CurrentCount;
471 
472     BreakContinueStack.push_back(BreakContinue());
473     // Visit the body region first so the break/continue adjustments can be
474     // included when visiting the condition.
475     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
476     CountMap[S->getBody()] = CurrentCount;
477     Visit(S->getBody());
478     uint64_t BackedgeCount = CurrentCount;
479 
480     // ...then go back and propagate counts through the condition. The count
481     // at the start of the condition is the sum of the incoming edges,
482     // the backedge from the end of the loop body, and the edges from
483     // continue statements.
484     BreakContinue BC = BreakContinueStack.pop_back_val();
485     uint64_t CondCount =
486         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
487     CountMap[S->getCond()] = CondCount;
488     Visit(S->getCond());
489     setCount(BC.BreakCount + CondCount - BodyCount);
490     RecordNextStmtCount = true;
491   }
492 
493   void VisitDoStmt(const DoStmt *S) {
494     RecordStmtCount(S);
495     uint64_t LoopCount = PGO.getRegionCount(S);
496 
497     BreakContinueStack.push_back(BreakContinue());
498     // The count doesn't include the fallthrough from the parent scope. Add it.
499     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
500     CountMap[S->getBody()] = BodyCount;
501     Visit(S->getBody());
502     uint64_t BackedgeCount = CurrentCount;
503 
504     BreakContinue BC = BreakContinueStack.pop_back_val();
505     // The count at the start of the condition is equal to the count at the
506     // end of the body, plus any continues.
507     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
508     CountMap[S->getCond()] = CondCount;
509     Visit(S->getCond());
510     setCount(BC.BreakCount + CondCount - LoopCount);
511     RecordNextStmtCount = true;
512   }
513 
514   void VisitForStmt(const ForStmt *S) {
515     RecordStmtCount(S);
516     if (S->getInit())
517       Visit(S->getInit());
518 
519     uint64_t ParentCount = CurrentCount;
520 
521     BreakContinueStack.push_back(BreakContinue());
522     // Visit the body region first. (This is basically the same as a while
523     // loop; see further comments in VisitWhileStmt.)
524     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
525     CountMap[S->getBody()] = BodyCount;
526     Visit(S->getBody());
527     uint64_t BackedgeCount = CurrentCount;
528     BreakContinue BC = BreakContinueStack.pop_back_val();
529 
530     // The increment is essentially part of the body but it needs to include
531     // the count for all the continue statements.
532     if (S->getInc()) {
533       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
534       CountMap[S->getInc()] = IncCount;
535       Visit(S->getInc());
536     }
537 
538     // ...then go back and propagate counts through the condition.
539     uint64_t CondCount =
540         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
541     if (S->getCond()) {
542       CountMap[S->getCond()] = CondCount;
543       Visit(S->getCond());
544     }
545     setCount(BC.BreakCount + CondCount - BodyCount);
546     RecordNextStmtCount = true;
547   }
548 
549   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
550     RecordStmtCount(S);
551     if (S->getInit())
552       Visit(S->getInit());
553     Visit(S->getLoopVarStmt());
554     Visit(S->getRangeStmt());
555     Visit(S->getBeginStmt());
556     Visit(S->getEndStmt());
557 
558     uint64_t ParentCount = CurrentCount;
559     BreakContinueStack.push_back(BreakContinue());
560     // Visit the body region first. (This is basically the same as a while
561     // loop; see further comments in VisitWhileStmt.)
562     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
563     CountMap[S->getBody()] = BodyCount;
564     Visit(S->getBody());
565     uint64_t BackedgeCount = CurrentCount;
566     BreakContinue BC = BreakContinueStack.pop_back_val();
567 
568     // The increment is essentially part of the body but it needs to include
569     // the count for all the continue statements.
570     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
571     CountMap[S->getInc()] = IncCount;
572     Visit(S->getInc());
573 
574     // ...then go back and propagate counts through the condition.
575     uint64_t CondCount =
576         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
577     CountMap[S->getCond()] = CondCount;
578     Visit(S->getCond());
579     setCount(BC.BreakCount + CondCount - BodyCount);
580     RecordNextStmtCount = true;
581   }
582 
583   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
584     RecordStmtCount(S);
585     Visit(S->getElement());
586     uint64_t ParentCount = CurrentCount;
587     BreakContinueStack.push_back(BreakContinue());
588     // Counter tracks the body of the loop.
589     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
590     CountMap[S->getBody()] = BodyCount;
591     Visit(S->getBody());
592     uint64_t BackedgeCount = CurrentCount;
593     BreakContinue BC = BreakContinueStack.pop_back_val();
594 
595     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
596              BodyCount);
597     RecordNextStmtCount = true;
598   }
599 
600   void VisitSwitchStmt(const SwitchStmt *S) {
601     RecordStmtCount(S);
602     if (S->getInit())
603       Visit(S->getInit());
604     Visit(S->getCond());
605     CurrentCount = 0;
606     BreakContinueStack.push_back(BreakContinue());
607     Visit(S->getBody());
608     // If the switch is inside a loop, add the continue counts.
609     BreakContinue BC = BreakContinueStack.pop_back_val();
610     if (!BreakContinueStack.empty())
611       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
612     // Counter tracks the exit block of the switch.
613     setCount(PGO.getRegionCount(S));
614     RecordNextStmtCount = true;
615   }
616 
617   void VisitSwitchCase(const SwitchCase *S) {
618     RecordNextStmtCount = false;
619     // Counter for this particular case. This counts only jumps from the
620     // switch header and does not include fallthrough from the case before
621     // this one.
622     uint64_t CaseCount = PGO.getRegionCount(S);
623     setCount(CurrentCount + CaseCount);
624     // We need the count without fallthrough in the mapping, so it's more useful
625     // for branch probabilities.
626     CountMap[S] = CaseCount;
627     RecordNextStmtCount = true;
628     Visit(S->getSubStmt());
629   }
630 
631   void VisitIfStmt(const IfStmt *S) {
632     RecordStmtCount(S);
633     uint64_t ParentCount = CurrentCount;
634     if (S->getInit())
635       Visit(S->getInit());
636     Visit(S->getCond());
637 
638     // Counter tracks the "then" part of an if statement. The count for
639     // the "else" part, if it exists, will be calculated from this counter.
640     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
641     CountMap[S->getThen()] = ThenCount;
642     Visit(S->getThen());
643     uint64_t OutCount = CurrentCount;
644 
645     uint64_t ElseCount = ParentCount - ThenCount;
646     if (S->getElse()) {
647       setCount(ElseCount);
648       CountMap[S->getElse()] = ElseCount;
649       Visit(S->getElse());
650       OutCount += CurrentCount;
651     } else
652       OutCount += ElseCount;
653     setCount(OutCount);
654     RecordNextStmtCount = true;
655   }
656 
657   void VisitCXXTryStmt(const CXXTryStmt *S) {
658     RecordStmtCount(S);
659     Visit(S->getTryBlock());
660     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
661       Visit(S->getHandler(I));
662     // Counter tracks the continuation block of the try statement.
663     setCount(PGO.getRegionCount(S));
664     RecordNextStmtCount = true;
665   }
666 
667   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
668     RecordNextStmtCount = false;
669     // Counter tracks the catch statement's handler block.
670     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
671     CountMap[S] = CatchCount;
672     Visit(S->getHandlerBlock());
673   }
674 
675   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
676     RecordStmtCount(E);
677     uint64_t ParentCount = CurrentCount;
678     Visit(E->getCond());
679 
680     // Counter tracks the "true" part of a conditional operator. The
681     // count in the "false" part will be calculated from this counter.
682     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
683     CountMap[E->getTrueExpr()] = TrueCount;
684     Visit(E->getTrueExpr());
685     uint64_t OutCount = CurrentCount;
686 
687     uint64_t FalseCount = setCount(ParentCount - TrueCount);
688     CountMap[E->getFalseExpr()] = FalseCount;
689     Visit(E->getFalseExpr());
690     OutCount += CurrentCount;
691 
692     setCount(OutCount);
693     RecordNextStmtCount = true;
694   }
695 
696   void VisitBinLAnd(const BinaryOperator *E) {
697     RecordStmtCount(E);
698     uint64_t ParentCount = CurrentCount;
699     Visit(E->getLHS());
700     // Counter tracks the right hand side of a logical and operator.
701     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
702     CountMap[E->getRHS()] = RHSCount;
703     Visit(E->getRHS());
704     setCount(ParentCount + RHSCount - CurrentCount);
705     RecordNextStmtCount = true;
706   }
707 
708   void VisitBinLOr(const BinaryOperator *E) {
709     RecordStmtCount(E);
710     uint64_t ParentCount = CurrentCount;
711     Visit(E->getLHS());
712     // Counter tracks the right hand side of a logical or operator.
713     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
714     CountMap[E->getRHS()] = RHSCount;
715     Visit(E->getRHS());
716     setCount(ParentCount + RHSCount - CurrentCount);
717     RecordNextStmtCount = true;
718   }
719 };
720 } // end anonymous namespace
721 
722 void PGOHash::combine(HashType Type) {
723   // Check that we never combine 0 and only have six bits.
724   assert(Type && "Hash is invalid: unexpected type 0");
725   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
726 
727   // Pass through MD5 if enough work has built up.
728   if (Count && Count % NumTypesPerWord == 0) {
729     using namespace llvm::support;
730     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
731     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
732     Working = 0;
733   }
734 
735   // Accumulate the current type.
736   ++Count;
737   Working = Working << NumBitsPerType | Type;
738 }
739 
740 uint64_t PGOHash::finalize() {
741   // Use Working as the hash directly if we never used MD5.
742   if (Count <= NumTypesPerWord)
743     // No need to byte swap here, since none of the math was endian-dependent.
744     // This number will be byte-swapped as required on endianness transitions,
745     // so we will see the same value on the other side.
746     return Working;
747 
748   // Check for remaining work in Working.
749   if (Working)
750     MD5.update(Working);
751 
752   // Finalize the MD5 and return the hash.
753   llvm::MD5::MD5Result Result;
754   MD5.final(Result);
755   using namespace llvm::support;
756   return Result.low();
757 }
758 
759 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
760   const Decl *D = GD.getDecl();
761   if (!D->hasBody())
762     return;
763 
764   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
765   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
766   if (!InstrumentRegions && !PGOReader)
767     return;
768   if (D->isImplicit())
769     return;
770   // Constructors and destructors may be represented by several functions in IR.
771   // If so, instrument only base variant, others are implemented by delegation
772   // to the base one, it would be counted twice otherwise.
773   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
774     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
775       if (GD.getCtorType() != Ctor_Base &&
776           CodeGenFunction::IsConstructorDelegationValid(CCD))
777         return;
778   }
779   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
780     return;
781 
782   CGM.ClearUnusedCoverageMapping(D);
783   setFuncName(Fn);
784 
785   mapRegionCounters(D);
786   if (CGM.getCodeGenOpts().CoverageMapping)
787     emitCounterRegionMapping(D);
788   if (PGOReader) {
789     SourceManager &SM = CGM.getContext().getSourceManager();
790     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
791     computeRegionCounts(D);
792     applyFunctionAttributes(PGOReader, Fn);
793   }
794 }
795 
796 void CodeGenPGO::mapRegionCounters(const Decl *D) {
797   // Use the latest hash version when inserting instrumentation, but use the
798   // version in the indexed profile if we're reading PGO data.
799   PGOHashVersion HashVersion = PGO_HASH_LATEST;
800   if (auto *PGOReader = CGM.getPGOReader())
801     HashVersion = getPGOHashVersion(PGOReader, CGM);
802 
803   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
804   MapRegionCounters Walker(HashVersion, *RegionCounterMap);
805   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
806     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
807   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
808     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
809   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
810     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
811   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
812     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
813   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
814   NumRegionCounters = Walker.NextCounter;
815   FunctionHash = Walker.Hash.finalize();
816 }
817 
818 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
819   if (!D->getBody())
820     return true;
821 
822   // Don't map the functions in system headers.
823   const auto &SM = CGM.getContext().getSourceManager();
824   auto Loc = D->getBody()->getBeginLoc();
825   return SM.isInSystemHeader(Loc);
826 }
827 
828 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
829   if (skipRegionMappingForDecl(D))
830     return;
831 
832   std::string CoverageMapping;
833   llvm::raw_string_ostream OS(CoverageMapping);
834   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
835                                 CGM.getContext().getSourceManager(),
836                                 CGM.getLangOpts(), RegionCounterMap.get());
837   MappingGen.emitCounterMapping(D, OS);
838   OS.flush();
839 
840   if (CoverageMapping.empty())
841     return;
842 
843   CGM.getCoverageMapping()->addFunctionMappingRecord(
844       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
845 }
846 
847 void
848 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
849                                     llvm::GlobalValue::LinkageTypes Linkage) {
850   if (skipRegionMappingForDecl(D))
851     return;
852 
853   std::string CoverageMapping;
854   llvm::raw_string_ostream OS(CoverageMapping);
855   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
856                                 CGM.getContext().getSourceManager(),
857                                 CGM.getLangOpts());
858   MappingGen.emitEmptyMapping(D, OS);
859   OS.flush();
860 
861   if (CoverageMapping.empty())
862     return;
863 
864   setFuncName(Name, Linkage);
865   CGM.getCoverageMapping()->addFunctionMappingRecord(
866       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
867 }
868 
869 void CodeGenPGO::computeRegionCounts(const Decl *D) {
870   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
871   ComputeRegionCounts Walker(*StmtCountMap, *this);
872   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
873     Walker.VisitFunctionDecl(FD);
874   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
875     Walker.VisitObjCMethodDecl(MD);
876   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
877     Walker.VisitBlockDecl(BD);
878   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
879     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
880 }
881 
882 void
883 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
884                                     llvm::Function *Fn) {
885   if (!haveRegionCounts())
886     return;
887 
888   uint64_t FunctionCount = getRegionCount(nullptr);
889   Fn->setEntryCount(FunctionCount);
890 }
891 
892 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
893                                       llvm::Value *StepV) {
894   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
895     return;
896   if (!Builder.GetInsertBlock())
897     return;
898 
899   unsigned Counter = (*RegionCounterMap)[S];
900   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
901 
902   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
903                          Builder.getInt64(FunctionHash),
904                          Builder.getInt32(NumRegionCounters),
905                          Builder.getInt32(Counter), StepV};
906   if (!StepV)
907     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
908                        makeArrayRef(Args, 4));
909   else
910     Builder.CreateCall(
911         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
912         makeArrayRef(Args));
913 }
914 
915 // This method either inserts a call to the profile run-time during
916 // instrumentation or puts profile data into metadata for PGO use.
917 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
918     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
919 
920   if (!EnableValueProfiling)
921     return;
922 
923   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
924     return;
925 
926   if (isa<llvm::Constant>(ValuePtr))
927     return;
928 
929   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
930   if (InstrumentValueSites && RegionCounterMap) {
931     auto BuilderInsertPoint = Builder.saveIP();
932     Builder.SetInsertPoint(ValueSite);
933     llvm::Value *Args[5] = {
934         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
935         Builder.getInt64(FunctionHash),
936         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
937         Builder.getInt32(ValueKind),
938         Builder.getInt32(NumValueSites[ValueKind]++)
939     };
940     Builder.CreateCall(
941         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
942     Builder.restoreIP(BuilderInsertPoint);
943     return;
944   }
945 
946   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
947   if (PGOReader && haveRegionCounts()) {
948     // We record the top most called three functions at each call site.
949     // Profile metadata contains "VP" string identifying this metadata
950     // as value profiling data, then a uint32_t value for the value profiling
951     // kind, a uint64_t value for the total number of times the call is
952     // executed, followed by the function hash and execution count (uint64_t)
953     // pairs for each function.
954     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
955       return;
956 
957     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
958                             (llvm::InstrProfValueKind)ValueKind,
959                             NumValueSites[ValueKind]);
960 
961     NumValueSites[ValueKind]++;
962   }
963 }
964 
965 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
966                                   bool IsInMainFile) {
967   CGM.getPGOStats().addVisited(IsInMainFile);
968   RegionCounts.clear();
969   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
970       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
971   if (auto E = RecordExpected.takeError()) {
972     auto IPE = llvm::InstrProfError::take(std::move(E));
973     if (IPE == llvm::instrprof_error::unknown_function)
974       CGM.getPGOStats().addMissing(IsInMainFile);
975     else if (IPE == llvm::instrprof_error::hash_mismatch)
976       CGM.getPGOStats().addMismatched(IsInMainFile);
977     else if (IPE == llvm::instrprof_error::malformed)
978       // TODO: Consider a more specific warning for this case.
979       CGM.getPGOStats().addMismatched(IsInMainFile);
980     return;
981   }
982   ProfRecord =
983       llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
984   RegionCounts = ProfRecord->Counts;
985 }
986 
/// Compute the divisor used to shrink 64-bit branch weights to 32 bits.
///
/// The result depends on the largest weight \p MaxWeight. Dividing any weight
/// no greater than \p MaxWeight by the result gives a value strictly below
/// UINT32_MAX. Weights that already fit get a divisor of 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight >= UINT32_MAX)
    return MaxWeight / UINT32_MAX + 1;
  return 1;
}
994 
/// Shrink one 64-bit branch weight to 32 bits and add 1.
///
/// \p Weight is divided by \p Scale, and 1 is added to the quotient. The +1
/// follows Laplace's Rule of Succession: basing the weight on count + 1 gives
/// a better estimate, and it means a zero count never becomes a zero weight.
///
/// \pre \p Scale came from calculateWeightScale() applied to a maximum weight
/// that is at least \p Weight, so the result fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Result = Weight / Scale + 1;
  assert(Result <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Result);
}
1010 
1011 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1012                                                     uint64_t FalseCount) {
1013   // Check for empty weights.
1014   if (!TrueCount && !FalseCount)
1015     return nullptr;
1016 
1017   // Calculate how to scale down to 32-bits.
1018   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1019 
1020   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1021   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1022                                       scaleBranchWeight(FalseCount, Scale));
1023 }
1024 
1025 llvm::MDNode *
1026 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1027   // We need at least two elements to create meaningful weights.
1028   if (Weights.size() < 2)
1029     return nullptr;
1030 
1031   // Check for empty weights.
1032   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1033   if (MaxWeight == 0)
1034     return nullptr;
1035 
1036   // Calculate how to scale down to 32-bits.
1037   uint64_t Scale = calculateWeightScale(MaxWeight);
1038 
1039   SmallVector<uint32_t, 16> ScaledWeights;
1040   ScaledWeights.reserve(Weights.size());
1041   for (uint64_t W : Weights)
1042     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1043 
1044   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1045   return MDHelper.createBranchWeights(ScaledWeights);
1046 }
1047 
1048 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1049                                                            uint64_t LoopCount) {
1050   if (!PGO.haveRegionCounts())
1051     return nullptr;
1052   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1053   assert(CondCount.hasValue() && "missing expected loop condition count");
1054   if (*CondCount == 0)
1055     return nullptr;
1056   return createProfileWeights(LoopCount,
1057                               std::max(*CondCount, LoopCount) - LoopCount);
1058 }
1059