xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CodeGenPGO.cpp (revision b4af4f93c682e445bf159f0d1ec90b636296c946)
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 
25 static llvm::cl::opt<bool>
26     EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27                          llvm::cl::desc("Enable value profiling"),
28                          llvm::cl::Hidden, llvm::cl::init(false));
29 
30 using namespace clang;
31 using namespace CodeGen;
32 
33 void CodeGenPGO::setFuncName(StringRef Name,
34                              llvm::GlobalValue::LinkageTypes Linkage) {
35   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36   FuncName = llvm::getPGOFuncName(
37       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39 
40   // If we're generating a profile, create a variable for the name.
41   if (CGM.getCodeGenOpts().hasProfileClangInstr())
42     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43 }
44 
45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46   setFuncName(Fn->getName(), Fn->getLinkage());
47   // Create PGOFuncName meta data.
48   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49 }
50 
/// The version of the PGO hash algorithm.
///
/// Each version is given an explicit value so that inserting a new version
/// can never silently renumber an existing one.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,

  // Must always alias the most recent version listed above.
  PGO_HASH_LATEST = PGO_HASH_V2
};
59 
60 namespace {
61 /// Stable hasher for PGO region counters.
62 ///
63 /// PGOHash produces a stable hash of a given function's control flow.
64 ///
65 /// Changing the output of this hash will invalidate all previously generated
66 /// profiles -- i.e., don't do it.
67 ///
68 /// \note  When this hash does eventually change (years?), we still need to
69 /// support old hashes.  We'll need to pull in the version number from the
70 /// profile data format and use the matching hash function.
71 class PGOHash {
72   uint64_t Working;
73   unsigned Count;
74   PGOHashVersion HashVersion;
75   llvm::MD5 MD5;
76 
77   static const int NumBitsPerType = 6;
78   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
79   static const unsigned TooBig = 1u << NumBitsPerType;
80 
81 public:
82   /// Hash values for AST nodes.
83   ///
84   /// Distinct values for AST nodes that have region counters attached.
85   ///
86   /// These values must be stable.  All new members must be added at the end,
87   /// and no members should be removed.  Changing the enumeration value for an
88   /// AST node will affect the hash of every function that contains that node.
89   enum HashType : unsigned char {
90     None = 0,
91     LabelStmt = 1,
92     WhileStmt,
93     DoStmt,
94     ForStmt,
95     CXXForRangeStmt,
96     ObjCForCollectionStmt,
97     SwitchStmt,
98     CaseStmt,
99     DefaultStmt,
100     IfStmt,
101     CXXTryStmt,
102     CXXCatchStmt,
103     ConditionalOperator,
104     BinaryOperatorLAnd,
105     BinaryOperatorLOr,
106     BinaryConditionalOperator,
107     // The preceding values are available with PGO_HASH_V1.
108 
109     EndOfScope,
110     IfThenBranch,
111     IfElseBranch,
112     GotoStmt,
113     IndirectGotoStmt,
114     BreakStmt,
115     ContinueStmt,
116     ReturnStmt,
117     ThrowExpr,
118     UnaryOperatorLNot,
119     BinaryOperatorLT,
120     BinaryOperatorGT,
121     BinaryOperatorLE,
122     BinaryOperatorGE,
123     BinaryOperatorEQ,
124     BinaryOperatorNE,
125     // The preceding values are available with PGO_HASH_V2.
126 
127     // Keep this last.  It's for the static assert that follows.
128     LastHashType
129   };
130   static_assert(LastHashType <= TooBig, "Too many types in HashType");
131 
132   PGOHash(PGOHashVersion HashVersion)
133       : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
134   void combine(HashType Type);
135   uint64_t finalize();
136   PGOHashVersion getHashVersion() const { return HashVersion; }
137 };
138 const int PGOHash::NumBitsPerType;
139 const unsigned PGOHash::NumTypesPerWord;
140 const unsigned PGOHash::TooBig;
141 
142 /// Get the PGO hash version used in the given indexed profile.
143 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
144                                         CodeGenModule &CGM) {
145   if (PGOReader->getVersion() <= 4)
146     return PGO_HASH_V1;
147   return PGO_HASH_V2;
148 }
149 
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// The visitor does two things in one walk:
///  * It assigns consecutive counter indices (starting from NextCounter == 0)
///    to the bodies of function-like declarations (VisitDecl) and to every
///    statement whose V1 hash type is not PGOHash::None
///    (updateCounterMappings).
///  * It folds a per-statement HashType into Hash, which yields the
///    function's structural hash.
///
/// NOTE(review): the hash must reproduce values stored in previously
/// generated profiles. For an existing PGOHashVersion, the order and set of
/// Hash.combine() calls made during this traversal must never change.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

  MapRegionCounters(PGOHashVersion HashVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  // Captured statements are not walked in the parent context either.
  // Presumably, like blocks, they are mapped when their CapturedDecl is
  // processed on its own (VisitDecl handles Decl::Captured).
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Give the body of each function-like declaration visited a fresh counter.
  /// The body of the declaration being mapped is expected to receive counter 0,
  /// the entry counter asserted on by CodeGenPGO::mapRegionCounters.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  ///
  /// Counter assignment always uses the V1 classification, whatever hash
  /// version is in use. The set of counted statements therefore does not
  /// depend on the hash version.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// Include \p S in the function hash.
  ///
  /// The counter is assigned from the V1 type. For newer hash versions the
  /// type is recomputed, so that V2-only node kinds (gotos, returns,
  /// comparisons, ...) feed the hash without receiving counters.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  /// Traverse an if-statement.
  ///
  /// With V1 this is the default traversal. Otherwise the hash additionally
  /// records an IfThenBranch or IfElseBranch marker before the then/else child
  /// and an EndOfScope marker after all children. This distinguishes which
  /// branch a nested statement belongs to.
  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
//
// The result of Base::Traverse##N is discarded; these overrides always
// return true. The macro is not #undef'd.
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// Returns PGOHash::None for statements that contribute nothing. Dispatch is
  /// on the exact statement class. A subclass with its own class value (for
  /// example, presumably, CompoundAssignOperator) is therefore not matched by
  /// the BinaryOperator case.
  ///
  /// NOTE(review): the V2-only cases test `HashVersion == PGO_HASH_V2`
  /// exactly, so any future version beyond V2 has to be added here explicitly.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      // Comparison operators only contribute to the V2 hash.
      if (HashVersion == PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Node kinds that only exist in the V2 hash.
    if (HashVersion == PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
341 
342 /// A StmtVisitor that propagates the raw counts through the AST and
343 /// records the count at statements where the value may change.
344 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
345   /// PGO state.
346   CodeGenPGO &PGO;
347 
348   /// A flag that is set when the current count should be recorded on the
349   /// next statement, such as at the exit of a loop.
350   bool RecordNextStmtCount;
351 
352   /// The count at the current location in the traversal.
353   uint64_t CurrentCount;
354 
355   /// The map of statements to count values.
356   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
357 
358   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
359   struct BreakContinue {
360     uint64_t BreakCount;
361     uint64_t ContinueCount;
362     BreakContinue() : BreakCount(0), ContinueCount(0) {}
363   };
364   SmallVector<BreakContinue, 8> BreakContinueStack;
365 
366   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
367                       CodeGenPGO &PGO)
368       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
369 
370   void RecordStmtCount(const Stmt *S) {
371     if (RecordNextStmtCount) {
372       CountMap[S] = CurrentCount;
373       RecordNextStmtCount = false;
374     }
375   }
376 
377   /// Set and return the current count.
378   uint64_t setCount(uint64_t Count) {
379     CurrentCount = Count;
380     return Count;
381   }
382 
383   void VisitStmt(const Stmt *S) {
384     RecordStmtCount(S);
385     for (const Stmt *Child : S->children())
386       if (Child)
387         this->Visit(Child);
388   }
389 
390   void VisitFunctionDecl(const FunctionDecl *D) {
391     // Counter tracks entry to the function body.
392     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
393     CountMap[D->getBody()] = BodyCount;
394     Visit(D->getBody());
395   }
396 
397   // Skip lambda expressions. We visit these as FunctionDecls when we're
398   // generating them and aren't interested in the body when generating a
399   // parent context.
400   void VisitLambdaExpr(const LambdaExpr *LE) {}
401 
402   void VisitCapturedDecl(const CapturedDecl *D) {
403     // Counter tracks entry to the capture body.
404     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
405     CountMap[D->getBody()] = BodyCount;
406     Visit(D->getBody());
407   }
408 
409   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
410     // Counter tracks entry to the method body.
411     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
412     CountMap[D->getBody()] = BodyCount;
413     Visit(D->getBody());
414   }
415 
416   void VisitBlockDecl(const BlockDecl *D) {
417     // Counter tracks entry to the block body.
418     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
419     CountMap[D->getBody()] = BodyCount;
420     Visit(D->getBody());
421   }
422 
423   void VisitReturnStmt(const ReturnStmt *S) {
424     RecordStmtCount(S);
425     if (S->getRetValue())
426       Visit(S->getRetValue());
427     CurrentCount = 0;
428     RecordNextStmtCount = true;
429   }
430 
431   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
432     RecordStmtCount(E);
433     if (E->getSubExpr())
434       Visit(E->getSubExpr());
435     CurrentCount = 0;
436     RecordNextStmtCount = true;
437   }
438 
439   void VisitGotoStmt(const GotoStmt *S) {
440     RecordStmtCount(S);
441     CurrentCount = 0;
442     RecordNextStmtCount = true;
443   }
444 
445   void VisitLabelStmt(const LabelStmt *S) {
446     RecordNextStmtCount = false;
447     // Counter tracks the block following the label.
448     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
449     CountMap[S] = BlockCount;
450     Visit(S->getSubStmt());
451   }
452 
453   void VisitBreakStmt(const BreakStmt *S) {
454     RecordStmtCount(S);
455     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
456     BreakContinueStack.back().BreakCount += CurrentCount;
457     CurrentCount = 0;
458     RecordNextStmtCount = true;
459   }
460 
461   void VisitContinueStmt(const ContinueStmt *S) {
462     RecordStmtCount(S);
463     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
464     BreakContinueStack.back().ContinueCount += CurrentCount;
465     CurrentCount = 0;
466     RecordNextStmtCount = true;
467   }
468 
469   void VisitWhileStmt(const WhileStmt *S) {
470     RecordStmtCount(S);
471     uint64_t ParentCount = CurrentCount;
472 
473     BreakContinueStack.push_back(BreakContinue());
474     // Visit the body region first so the break/continue adjustments can be
475     // included when visiting the condition.
476     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
477     CountMap[S->getBody()] = CurrentCount;
478     Visit(S->getBody());
479     uint64_t BackedgeCount = CurrentCount;
480 
481     // ...then go back and propagate counts through the condition. The count
482     // at the start of the condition is the sum of the incoming edges,
483     // the backedge from the end of the loop body, and the edges from
484     // continue statements.
485     BreakContinue BC = BreakContinueStack.pop_back_val();
486     uint64_t CondCount =
487         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
488     CountMap[S->getCond()] = CondCount;
489     Visit(S->getCond());
490     setCount(BC.BreakCount + CondCount - BodyCount);
491     RecordNextStmtCount = true;
492   }
493 
494   void VisitDoStmt(const DoStmt *S) {
495     RecordStmtCount(S);
496     uint64_t LoopCount = PGO.getRegionCount(S);
497 
498     BreakContinueStack.push_back(BreakContinue());
499     // The count doesn't include the fallthrough from the parent scope. Add it.
500     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
501     CountMap[S->getBody()] = BodyCount;
502     Visit(S->getBody());
503     uint64_t BackedgeCount = CurrentCount;
504 
505     BreakContinue BC = BreakContinueStack.pop_back_val();
506     // The count at the start of the condition is equal to the count at the
507     // end of the body, plus any continues.
508     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
509     CountMap[S->getCond()] = CondCount;
510     Visit(S->getCond());
511     setCount(BC.BreakCount + CondCount - LoopCount);
512     RecordNextStmtCount = true;
513   }
514 
515   void VisitForStmt(const ForStmt *S) {
516     RecordStmtCount(S);
517     if (S->getInit())
518       Visit(S->getInit());
519 
520     uint64_t ParentCount = CurrentCount;
521 
522     BreakContinueStack.push_back(BreakContinue());
523     // Visit the body region first. (This is basically the same as a while
524     // loop; see further comments in VisitWhileStmt.)
525     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
526     CountMap[S->getBody()] = BodyCount;
527     Visit(S->getBody());
528     uint64_t BackedgeCount = CurrentCount;
529     BreakContinue BC = BreakContinueStack.pop_back_val();
530 
531     // The increment is essentially part of the body but it needs to include
532     // the count for all the continue statements.
533     if (S->getInc()) {
534       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
535       CountMap[S->getInc()] = IncCount;
536       Visit(S->getInc());
537     }
538 
539     // ...then go back and propagate counts through the condition.
540     uint64_t CondCount =
541         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
542     if (S->getCond()) {
543       CountMap[S->getCond()] = CondCount;
544       Visit(S->getCond());
545     }
546     setCount(BC.BreakCount + CondCount - BodyCount);
547     RecordNextStmtCount = true;
548   }
549 
550   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
551     RecordStmtCount(S);
552     if (S->getInit())
553       Visit(S->getInit());
554     Visit(S->getLoopVarStmt());
555     Visit(S->getRangeStmt());
556     Visit(S->getBeginStmt());
557     Visit(S->getEndStmt());
558 
559     uint64_t ParentCount = CurrentCount;
560     BreakContinueStack.push_back(BreakContinue());
561     // Visit the body region first. (This is basically the same as a while
562     // loop; see further comments in VisitWhileStmt.)
563     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
564     CountMap[S->getBody()] = BodyCount;
565     Visit(S->getBody());
566     uint64_t BackedgeCount = CurrentCount;
567     BreakContinue BC = BreakContinueStack.pop_back_val();
568 
569     // The increment is essentially part of the body but it needs to include
570     // the count for all the continue statements.
571     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
572     CountMap[S->getInc()] = IncCount;
573     Visit(S->getInc());
574 
575     // ...then go back and propagate counts through the condition.
576     uint64_t CondCount =
577         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
578     CountMap[S->getCond()] = CondCount;
579     Visit(S->getCond());
580     setCount(BC.BreakCount + CondCount - BodyCount);
581     RecordNextStmtCount = true;
582   }
583 
584   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
585     RecordStmtCount(S);
586     Visit(S->getElement());
587     uint64_t ParentCount = CurrentCount;
588     BreakContinueStack.push_back(BreakContinue());
589     // Counter tracks the body of the loop.
590     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
591     CountMap[S->getBody()] = BodyCount;
592     Visit(S->getBody());
593     uint64_t BackedgeCount = CurrentCount;
594     BreakContinue BC = BreakContinueStack.pop_back_val();
595 
596     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
597              BodyCount);
598     RecordNextStmtCount = true;
599   }
600 
601   void VisitSwitchStmt(const SwitchStmt *S) {
602     RecordStmtCount(S);
603     if (S->getInit())
604       Visit(S->getInit());
605     Visit(S->getCond());
606     CurrentCount = 0;
607     BreakContinueStack.push_back(BreakContinue());
608     Visit(S->getBody());
609     // If the switch is inside a loop, add the continue counts.
610     BreakContinue BC = BreakContinueStack.pop_back_val();
611     if (!BreakContinueStack.empty())
612       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
613     // Counter tracks the exit block of the switch.
614     setCount(PGO.getRegionCount(S));
615     RecordNextStmtCount = true;
616   }
617 
618   void VisitSwitchCase(const SwitchCase *S) {
619     RecordNextStmtCount = false;
620     // Counter for this particular case. This counts only jumps from the
621     // switch header and does not include fallthrough from the case before
622     // this one.
623     uint64_t CaseCount = PGO.getRegionCount(S);
624     setCount(CurrentCount + CaseCount);
625     // We need the count without fallthrough in the mapping, so it's more useful
626     // for branch probabilities.
627     CountMap[S] = CaseCount;
628     RecordNextStmtCount = true;
629     Visit(S->getSubStmt());
630   }
631 
632   void VisitIfStmt(const IfStmt *S) {
633     RecordStmtCount(S);
634     uint64_t ParentCount = CurrentCount;
635     if (S->getInit())
636       Visit(S->getInit());
637     Visit(S->getCond());
638 
639     // Counter tracks the "then" part of an if statement. The count for
640     // the "else" part, if it exists, will be calculated from this counter.
641     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
642     CountMap[S->getThen()] = ThenCount;
643     Visit(S->getThen());
644     uint64_t OutCount = CurrentCount;
645 
646     uint64_t ElseCount = ParentCount - ThenCount;
647     if (S->getElse()) {
648       setCount(ElseCount);
649       CountMap[S->getElse()] = ElseCount;
650       Visit(S->getElse());
651       OutCount += CurrentCount;
652     } else
653       OutCount += ElseCount;
654     setCount(OutCount);
655     RecordNextStmtCount = true;
656   }
657 
658   void VisitCXXTryStmt(const CXXTryStmt *S) {
659     RecordStmtCount(S);
660     Visit(S->getTryBlock());
661     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
662       Visit(S->getHandler(I));
663     // Counter tracks the continuation block of the try statement.
664     setCount(PGO.getRegionCount(S));
665     RecordNextStmtCount = true;
666   }
667 
668   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
669     RecordNextStmtCount = false;
670     // Counter tracks the catch statement's handler block.
671     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
672     CountMap[S] = CatchCount;
673     Visit(S->getHandlerBlock());
674   }
675 
676   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
677     RecordStmtCount(E);
678     uint64_t ParentCount = CurrentCount;
679     Visit(E->getCond());
680 
681     // Counter tracks the "true" part of a conditional operator. The
682     // count in the "false" part will be calculated from this counter.
683     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
684     CountMap[E->getTrueExpr()] = TrueCount;
685     Visit(E->getTrueExpr());
686     uint64_t OutCount = CurrentCount;
687 
688     uint64_t FalseCount = setCount(ParentCount - TrueCount);
689     CountMap[E->getFalseExpr()] = FalseCount;
690     Visit(E->getFalseExpr());
691     OutCount += CurrentCount;
692 
693     setCount(OutCount);
694     RecordNextStmtCount = true;
695   }
696 
697   void VisitBinLAnd(const BinaryOperator *E) {
698     RecordStmtCount(E);
699     uint64_t ParentCount = CurrentCount;
700     Visit(E->getLHS());
701     // Counter tracks the right hand side of a logical and operator.
702     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
703     CountMap[E->getRHS()] = RHSCount;
704     Visit(E->getRHS());
705     setCount(ParentCount + RHSCount - CurrentCount);
706     RecordNextStmtCount = true;
707   }
708 
709   void VisitBinLOr(const BinaryOperator *E) {
710     RecordStmtCount(E);
711     uint64_t ParentCount = CurrentCount;
712     Visit(E->getLHS());
713     // Counter tracks the right hand side of a logical or operator.
714     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
715     CountMap[E->getRHS()] = RHSCount;
716     Visit(E->getRHS());
717     setCount(ParentCount + RHSCount - CurrentCount);
718     RecordNextStmtCount = true;
719   }
720 };
721 } // end anonymous namespace
722 
723 void PGOHash::combine(HashType Type) {
724   // Check that we never combine 0 and only have six bits.
725   assert(Type && "Hash is invalid: unexpected type 0");
726   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
727 
728   // Pass through MD5 if enough work has built up.
729   if (Count && Count % NumTypesPerWord == 0) {
730     using namespace llvm::support;
731     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
732     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
733     Working = 0;
734   }
735 
736   // Accumulate the current type.
737   ++Count;
738   Working = Working << NumBitsPerType | Type;
739 }
740 
741 uint64_t PGOHash::finalize() {
742   // Use Working as the hash directly if we never used MD5.
743   if (Count <= NumTypesPerWord)
744     // No need to byte swap here, since none of the math was endian-dependent.
745     // This number will be byte-swapped as required on endianness transitions,
746     // so we will see the same value on the other side.
747     return Working;
748 
749   // Check for remaining work in Working.
750   if (Working)
751     MD5.update(Working);
752 
753   // Finalize the MD5 and return the hash.
754   llvm::MD5::MD5Result Result;
755   MD5.final(Result);
756   using namespace llvm::support;
757   return Result.low();
758 }
759 
760 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
761   const Decl *D = GD.getDecl();
762   if (!D->hasBody())
763     return;
764 
765   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
766   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
767   if (!InstrumentRegions && !PGOReader)
768     return;
769   if (D->isImplicit())
770     return;
771   // Constructors and destructors may be represented by several functions in IR.
772   // If so, instrument only base variant, others are implemented by delegation
773   // to the base one, it would be counted twice otherwise.
774   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
775     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
776       if (GD.getCtorType() != Ctor_Base &&
777           CodeGenFunction::IsConstructorDelegationValid(CCD))
778         return;
779   }
780   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
781     return;
782 
783   CGM.ClearUnusedCoverageMapping(D);
784   setFuncName(Fn);
785 
786   mapRegionCounters(D);
787   if (CGM.getCodeGenOpts().CoverageMapping)
788     emitCounterRegionMapping(D);
789   if (PGOReader) {
790     SourceManager &SM = CGM.getContext().getSourceManager();
791     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
792     computeRegionCounts(D);
793     applyFunctionAttributes(PGOReader, Fn);
794   }
795 }
796 
797 void CodeGenPGO::mapRegionCounters(const Decl *D) {
798   // Use the latest hash version when inserting instrumentation, but use the
799   // version in the indexed profile if we're reading PGO data.
800   PGOHashVersion HashVersion = PGO_HASH_LATEST;
801   if (auto *PGOReader = CGM.getPGOReader())
802     HashVersion = getPGOHashVersion(PGOReader, CGM);
803 
804   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
805   MapRegionCounters Walker(HashVersion, *RegionCounterMap);
806   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
807     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
808   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
809     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
810   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
811     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
812   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
813     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
814   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
815   NumRegionCounters = Walker.NextCounter;
816   FunctionHash = Walker.Hash.finalize();
817 }
818 
819 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
820   if (!D->getBody())
821     return true;
822 
823   // Don't map the functions in system headers.
824   const auto &SM = CGM.getContext().getSourceManager();
825   auto Loc = D->getBody()->getBeginLoc();
826   return SM.isInSystemHeader(Loc);
827 }
828 
829 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
830   if (skipRegionMappingForDecl(D))
831     return;
832 
833   std::string CoverageMapping;
834   llvm::raw_string_ostream OS(CoverageMapping);
835   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
836                                 CGM.getContext().getSourceManager(),
837                                 CGM.getLangOpts(), RegionCounterMap.get());
838   MappingGen.emitCounterMapping(D, OS);
839   OS.flush();
840 
841   if (CoverageMapping.empty())
842     return;
843 
844   CGM.getCoverageMapping()->addFunctionMappingRecord(
845       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
846 }
847 
848 void
849 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
850                                     llvm::GlobalValue::LinkageTypes Linkage) {
851   if (skipRegionMappingForDecl(D))
852     return;
853 
854   std::string CoverageMapping;
855   llvm::raw_string_ostream OS(CoverageMapping);
856   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
857                                 CGM.getContext().getSourceManager(),
858                                 CGM.getLangOpts());
859   MappingGen.emitEmptyMapping(D, OS);
860   OS.flush();
861 
862   if (CoverageMapping.empty())
863     return;
864 
865   setFuncName(Name, Linkage);
866   CGM.getCoverageMapping()->addFunctionMappingRecord(
867       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
868 }
869 
870 void CodeGenPGO::computeRegionCounts(const Decl *D) {
871   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
872   ComputeRegionCounts Walker(*StmtCountMap, *this);
873   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
874     Walker.VisitFunctionDecl(FD);
875   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
876     Walker.VisitObjCMethodDecl(MD);
877   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
878     Walker.VisitBlockDecl(BD);
879   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
880     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
881 }
882 
883 void
884 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
885                                     llvm::Function *Fn) {
886   if (!haveRegionCounts())
887     return;
888 
889   uint64_t FunctionCount = getRegionCount(nullptr);
890   Fn->setEntryCount(FunctionCount);
891 }
892 
893 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
894                                       llvm::Value *StepV) {
895   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
896     return;
897   if (!Builder.GetInsertBlock())
898     return;
899 
900   unsigned Counter = (*RegionCounterMap)[S];
901   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
902 
903   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
904                          Builder.getInt64(FunctionHash),
905                          Builder.getInt32(NumRegionCounters),
906                          Builder.getInt32(Counter), StepV};
907   if (!StepV)
908     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
909                        makeArrayRef(Args, 4));
910   else
911     Builder.CreateCall(
912         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
913         makeArrayRef(Args));
914 }
915 
// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
//
// ValueKind is the kind of value being profiled. It indexes NumValueSites and
// is cast to llvm::InstrProfValueKind on the PGO-use path. ValueSite is the
// instruction the profile is associated with, and ValuePtr is the observed
// value. Nothing happens unless -enable-value-profiling is set.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
    llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {

  if (!EnableValueProfiling)
    return;

  // Bail out without a value or a site, or when the builder has no insertion
  // block.
  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  // Constant values are never profiled.
  if (isa<llvm::Constant>(ValuePtr))
    return;

  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    // Instrumentation path: emit an instrprof_value_profile call at ValueSite,
    // then restore the caller's insertion point.
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
    llvm::Value *Args[5] = {
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        // Each site of this kind gets the next sequential index.
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  // PGO-use path: annotate ValueSite from the loaded profile record.
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // We record the top most called three functions at each call site.
    // Profile metadata contains "VP" string identifying this metadata
    // as value profiling data, then a uint32_t value for the value profiling
    // kind, a uint64_t value for the total number of times the call is
    // executed, followed by the function hash and execution count (uint64_t)
    // pairs for each function.
    //
    // Sites beyond the number the profile recorded for this kind get no
    // metadata.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    // Advance the per-kind site index. This presumably keeps indices aligned
    // with those assigned on the instrumentation path (confirm).
    NumValueSites[ValueKind]++;
  }
}
964 }
965 
966 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
967                                   bool IsInMainFile) {
968   CGM.getPGOStats().addVisited(IsInMainFile);
969   RegionCounts.clear();
970   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
971       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
972   if (auto E = RecordExpected.takeError()) {
973     auto IPE = llvm::InstrProfError::take(std::move(E));
974     if (IPE == llvm::instrprof_error::unknown_function)
975       CGM.getPGOStats().addMissing(IsInMainFile);
976     else if (IPE == llvm::instrprof_error::hash_mismatch)
977       CGM.getPGOStats().addMismatched(IsInMainFile);
978     else if (IPE == llvm::instrprof_error::malformed)
979       // TODO: Consider a more specific warning for this case.
980       CGM.getPGOStats().addMismatched(IsInMainFile);
981     return;
982   }
983   ProfRecord =
984       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
985   RegionCounts = ProfRecord->Counts;
986 }
987 
/// Compute the divisor used to bring 64-bit branch weights into 32 bits.
///
/// A maximum below UINT32_MAX needs no scaling, so the divisor is 1.
/// Otherwise the divisor is chosen so that MaxWeight / divisor is strictly
/// less than UINT32_MAX. This leaves room for the +1 that scaleBranchWeight()
/// adds.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight >= UINT32_MAX)
    return MaxWeight / UINT32_MAX + 1;
  return 1;
}
994 }
995 
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale is non-zero and was calculated by \a calculateWeightScale()
/// with a weight no less than \c Weight (typically the maximum of all weights
/// being scaled together). Otherwise the result may not fit in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // The precondition guarantees the value fits, so the narrowing is exact.
  return static_cast<uint32_t>(Scaled);
}
1010 }
1011 
1012 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1013                                                     uint64_t FalseCount) {
1014   // Check for empty weights.
1015   if (!TrueCount && !FalseCount)
1016     return nullptr;
1017 
1018   // Calculate how to scale down to 32-bits.
1019   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1020 
1021   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1022   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1023                                       scaleBranchWeight(FalseCount, Scale));
1024 }
1025 
1026 llvm::MDNode *
1027 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1028   // We need at least two elements to create meaningful weights.
1029   if (Weights.size() < 2)
1030     return nullptr;
1031 
1032   // Check for empty weights.
1033   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1034   if (MaxWeight == 0)
1035     return nullptr;
1036 
1037   // Calculate how to scale down to 32-bits.
1038   uint64_t Scale = calculateWeightScale(MaxWeight);
1039 
1040   SmallVector<uint32_t, 16> ScaledWeights;
1041   ScaledWeights.reserve(Weights.size());
1042   for (uint64_t W : Weights)
1043     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1044 
1045   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1046   return MDHelper.createBranchWeights(ScaledWeights);
1047 }
1048 
1049 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1050                                                            uint64_t LoopCount) {
1051   if (!PGO.haveRegionCounts())
1052     return nullptr;
1053   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1054   assert(CondCount.hasValue() && "missing expected loop condition count");
1055   if (*CondCount == 0)
1056     return nullptr;
1057   return createProfileWeights(LoopCount,
1058                               std::max(*CondCount, LoopCount) - LoopCount);
1059 }
1060