1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 #include <optional>
25
26 namespace llvm {
27 extern cl::opt<bool> EnableSingleByteCoverage;
28 } // namespace llvm
29
30 static llvm::cl::opt<bool>
31 EnableValueProfiling("enable-value-profiling",
32 llvm::cl::desc("Enable value profiling"),
33 llvm::cl::Hidden, llvm::cl::init(false));
34
35 using namespace clang;
36 using namespace CodeGen;
37
38 void CodeGenPGO::setFuncName(StringRef Name,
39 llvm::GlobalValue::LinkageTypes Linkage) {
40 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
41 FuncName = llvm::getPGOFuncName(
42 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
43 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
44
45 // If we're generating a profile, create a variable for the name.
46 if (CGM.getCodeGenOpts().hasProfileClangInstr())
47 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
48 }
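// For illustration (assuming the usual LLVM PGO naming scheme, not anything
// specific to this file): a local-linkage function is prefixed with the main
// file name to keep it distinct, e.g. a static "foo" in bar.c becomes
// "bar.c:foo", while an externally visible function keeps its plain name.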
49
50 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
51 setFuncName(Fn->getName(), Fn->getLinkage());
52 // Create PGOFuncName meta data.
53 llvm::createPGOFuncNameMetadata(*Fn, FuncName);
54 }
55
56 /// The version of the PGO hash algorithm.
57 enum PGOHashVersion : unsigned {
58 PGO_HASH_V1,
59 PGO_HASH_V2,
60 PGO_HASH_V3,
61
62 // Keep this set to the latest hash version.
63 PGO_HASH_LATEST = PGO_HASH_V3
64 };
65
66 namespace {
67 /// Stable hasher for PGO region counters.
68 ///
69 /// PGOHash produces a stable hash of a given function's control flow.
70 ///
71 /// Changing the output of this hash will invalidate all previously generated
72 /// profiles -- i.e., don't do it.
73 ///
74 /// \note When this hash does eventually change (years?), we still need to
75 /// support old hashes. We'll need to pull in the version number from the
76 /// profile data format and use the matching hash function.
77 class PGOHash {
78 uint64_t Working;
79 unsigned Count;
80 PGOHashVersion HashVersion;
81 llvm::MD5 MD5;
82
83 static const int NumBitsPerType = 6;
84 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
85 static const unsigned TooBig = 1u << NumBitsPerType;
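// Layout note: with 6 bits per type, ten types fit in a 64-bit word
// (64 / 6 == 10), and every HashType value must stay below TooBig
// (1 << 6 == 64) so it can be encoded without loss.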
86
87 public:
88 /// Hash values for AST nodes.
89 ///
90 /// Distinct values for AST nodes that have region counters attached.
91 ///
92 /// These values must be stable. All new members must be added at the end,
93 /// and no members should be removed. Changing the enumeration value for an
94 /// AST node will affect the hash of every function that contains that node.
95 enum HashType : unsigned char {
96 None = 0,
97 LabelStmt = 1,
98 WhileStmt,
99 DoStmt,
100 ForStmt,
101 CXXForRangeStmt,
102 ObjCForCollectionStmt,
103 SwitchStmt,
104 CaseStmt,
105 DefaultStmt,
106 IfStmt,
107 CXXTryStmt,
108 CXXCatchStmt,
109 ConditionalOperator,
110 BinaryOperatorLAnd,
111 BinaryOperatorLOr,
112 BinaryConditionalOperator,
113 // The preceding values are available with PGO_HASH_V1.
114
115 EndOfScope,
116 IfThenBranch,
117 IfElseBranch,
118 GotoStmt,
119 IndirectGotoStmt,
120 BreakStmt,
121 ContinueStmt,
122 ReturnStmt,
123 ThrowExpr,
124 UnaryOperatorLNot,
125 BinaryOperatorLT,
126 BinaryOperatorGT,
127 BinaryOperatorLE,
128 BinaryOperatorGE,
129 BinaryOperatorEQ,
130 BinaryOperatorNE,
131 // The preceding values are available since PGO_HASH_V2.
132
133 // Keep this last. It's for the static assert that follows.
134 LastHashType
135 };
136 static_assert(LastHashType <= TooBig, "Too many types in HashType");
137
138 PGOHash(PGOHashVersion HashVersion)
139 : Working(0), Count(0), HashVersion(HashVersion) {}
140 void combine(HashType Type);
141 uint64_t finalize();
142 PGOHashVersion getHashVersion() const { return HashVersion; }
143 };
144 const int PGOHash::NumBitsPerType;
145 const unsigned PGOHash::NumTypesPerWord;
146 const unsigned PGOHash::TooBig;
147
148 /// Get the PGO hash version used in the given indexed profile.
149 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
150 CodeGenModule &CGM) {
151 if (PGOReader->getVersion() <= 4)
152 return PGO_HASH_V1;
153 if (PGOReader->getVersion() <= 5)
154 return PGO_HASH_V2;
155 return PGO_HASH_V3;
156 }
157
158 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
159 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
160 using Base = RecursiveASTVisitor<MapRegionCounters>;
161
162 /// The next counter value to assign.
163 unsigned NextCounter;
164 /// The function hash.
165 PGOHash Hash;
166 /// The map of statements to counters.
167 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
168 /// The state of MC/DC Coverage in this function.
169 MCDC::State &MCDCState;
170 /// Maximum number of supported MC/DC conditions in a boolean expression.
171 unsigned MCDCMaxCond;
172 /// The profile version.
173 uint64_t ProfileVersion;
174 /// Diagnostics Engine used to report warnings.
175 DiagnosticsEngine &Diag;
176
177 MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
178 llvm::DenseMap<const Stmt *, unsigned> &CounterMap,
179 MCDC::State &MCDCState, unsigned MCDCMaxCond,
180 DiagnosticsEngine &Diag)
181 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
182 MCDCState(MCDCState), MCDCMaxCond(MCDCMaxCond),
183 ProfileVersion(ProfileVersion), Diag(Diag) {}
184
185 // Blocks and lambdas are handled as separate functions, so we need not
186 // traverse them in the parent context.
187 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
188 bool TraverseLambdaExpr(LambdaExpr *LE) {
189 // Traverse the captures, but not the body.
190 for (auto C : zip(LE->captures(), LE->capture_inits()))
191 TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
192 return true;
193 }
194 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
195
196 bool VisitDecl(const Decl *D) {
197 switch (D->getKind()) {
198 default:
199 break;
200 case Decl::Function:
201 case Decl::CXXMethod:
202 case Decl::CXXConstructor:
203 case Decl::CXXDestructor:
204 case Decl::CXXConversion:
205 case Decl::ObjCMethod:
206 case Decl::Block:
207 case Decl::Captured:
208 CounterMap[D->getBody()] = NextCounter++;
209 break;
210 }
211 return true;
212 }
213
214 /// If \p S gets a fresh counter, update the counter mappings. Return the
215 /// V1 hash of \p S.
216 PGOHash::HashType updateCounterMappings(Stmt *S) {
217 auto Type = getHashType(PGO_HASH_V1, S);
218 if (Type != PGOHash::None)
219 CounterMap[S] = NextCounter++;
220 return Type;
221 }
222
223 /// The following stacks are used with dataTraverseStmtPre() and
224 /// dataTraverseStmtPost() to track the depth of nested logical operators in a
225 /// boolean expression in a function. The ultimate purpose is to keep track
226 /// of the number of leaf-level conditions in the boolean expression so that a
227 /// profile bitmap can be allocated based on that number.
228 ///
229 /// The stacks are also used to find error cases and notify the user. A
230 /// standard logical operator nest for a boolean expression could be in a form
231 /// similar to this: "x = a && b && c && (d || f)"
232 unsigned NumCond = 0;
233 bool SplitNestedLogicalOp = false;
234 SmallVector<const Stmt *, 16> NonLogOpStack;
235 SmallVector<const BinaryOperator *, 16> LogOpStack;
236
237 // Hook: dataTraverseStmtPre() is invoked prior to visiting an AST Stmt node.
238 bool dataTraverseStmtPre(Stmt *S) {
239 /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
240 if (MCDCMaxCond == 0)
241 return true;
242
243 /// At the top of the logical operator nest, reset the number of conditions
244 /// and forget any previously seen split-nesting cases.
245 if (LogOpStack.empty()) {
246 NumCond = 0;
247 SplitNestedLogicalOp = false;
248 }
249
250 if (const Expr *E = dyn_cast<Expr>(S)) {
251 const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
252 if (BinOp && BinOp->isLogicalOp()) {
253 /// Check for "split-nested" logical operators. This happens when a new
254 /// boolean expression logical-op nest is encountered within an existing
255 /// boolean expression, separated by a non-logical operator. For
256 /// example, in "x = (a && b && c && foo(d && f))", the "d && f" case
257 /// starts a new boolean expression that is separated from the other
258 /// conditions by the operator foo(). Split-nested cases are not
259 /// supported by MC/DC.
260 SplitNestedLogicalOp = SplitNestedLogicalOp || !NonLogOpStack.empty();
261
262 LogOpStack.push_back(BinOp);
263 return true;
264 }
265 }
266
267 /// Keep track of non-logical operators. These are OK as long as we don't
268 /// encounter a new logical operator after seeing one.
269 if (!LogOpStack.empty())
270 NonLogOpStack.push_back(S);
271
272 return true;
273 }
274
275 // Hook: dataTraverseStmtPost() is invoked by the AST visitor after visiting
276 // an AST Stmt node. MC/DC will use it to signal when the top of a
277 // logical operation (boolean expression) nest is encountered.
278 bool dataTraverseStmtPost(Stmt *S) {
279 /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
280 if (MCDCMaxCond == 0)
281 return true;
282
283 if (const Expr *E = dyn_cast<Expr>(S)) {
284 const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
285 if (BinOp && BinOp->isLogicalOp()) {
286 assert(LogOpStack.back() == BinOp);
287 LogOpStack.pop_back();
288
289 /// At the top of logical operator nest:
290 if (LogOpStack.empty()) {
291 /// Was the "split-nested" logical operator case encountered?
292 if (SplitNestedLogicalOp) {
293 unsigned DiagID = Diag.getCustomDiagID(
294 DiagnosticsEngine::Warning,
295 "unsupported MC/DC boolean expression; "
296 "contains an operation with a nested boolean expression. "
297 "Expression will not be covered");
298 Diag.Report(S->getBeginLoc(), DiagID);
299 return true;
300 }
301
302 /// Was the maximum number of conditions encountered?
303 if (NumCond > MCDCMaxCond) {
304 unsigned DiagID = Diag.getCustomDiagID(
305 DiagnosticsEngine::Warning,
306 "unsupported MC/DC boolean expression; "
307 "number of conditions (%0) exceeds max (%1). "
308 "Expression will not be covered");
309 Diag.Report(S->getBeginLoc(), DiagID) << NumCond << MCDCMaxCond;
310 return true;
311 }
312
313 // Otherwise, allocate the Decision.
314 MCDCState.DecisionByStmt[BinOp].BitmapIdx = 0;
315 }
316 return true;
317 }
318 }
319
320 if (!LogOpStack.empty())
321 NonLogOpStack.pop_back();
322
323 return true;
324 }
325
326 /// The RHS of all logical operators gets a fresh counter in order to count
327 /// how many times the RHS evaluates to true or false, depending on the
328 /// semantics of the operator. This is only done for profile version >= v7,
329 /// to preserve backward compatibility with older profiles. In addition, in
330 /// order to use MC/DC, count the number of total LHS and RHS conditions.
331 bool VisitBinaryOperator(BinaryOperator *S) {
332 if (S->isLogicalOp()) {
333 if (CodeGenFunction::isInstrumentedCondition(S->getLHS()))
334 NumCond++;
335
336 if (CodeGenFunction::isInstrumentedCondition(S->getRHS())) {
337 if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
338 CounterMap[S->getRHS()] = NextCounter++;
339
340 NumCond++;
341 }
342 }
343 return Base::VisitBinaryOperator(S);
344 }
345
346 bool VisitConditionalOperator(ConditionalOperator *S) {
347 if (llvm::EnableSingleByteCoverage && S->getTrueExpr())
348 CounterMap[S->getTrueExpr()] = NextCounter++;
349 if (llvm::EnableSingleByteCoverage && S->getFalseExpr())
350 CounterMap[S->getFalseExpr()] = NextCounter++;
351 return Base::VisitConditionalOperator(S);
352 }
353
354 /// Include \p S in the function hash.
355 bool VisitStmt(Stmt *S) {
356 auto Type = updateCounterMappings(S);
357 if (Hash.getHashVersion() != PGO_HASH_V1)
358 Type = getHashType(Hash.getHashVersion(), S);
359 if (Type != PGOHash::None)
360 Hash.combine(Type);
361 return true;
362 }
363
364 bool TraverseIfStmt(IfStmt *If) {
365 // If we used the V1 hash, use the default traversal.
366 if (Hash.getHashVersion() == PGO_HASH_V1)
367 return Base::TraverseIfStmt(If);
368
369 // When single byte coverage mode is enabled, add a counter to then and
370 // else.
371 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
372 for (Stmt *CS : If->children()) {
373 if (!CS || NoSingleByteCoverage)
374 continue;
375 if (CS == If->getThen())
376 CounterMap[If->getThen()] = NextCounter++;
377 else if (CS == If->getElse())
378 CounterMap[If->getElse()] = NextCounter++;
379 }
380
381 // Otherwise, keep track of which branch we're in while traversing.
382 VisitStmt(If);
383
384 for (Stmt *CS : If->children()) {
385 if (!CS)
386 continue;
387 if (CS == If->getThen())
388 Hash.combine(PGOHash::IfThenBranch);
389 else if (CS == If->getElse())
390 Hash.combine(PGOHash::IfElseBranch);
391 TraverseStmt(CS);
392 }
393 Hash.combine(PGOHash::EndOfScope);
394 return true;
395 }
396
397 bool TraverseWhileStmt(WhileStmt *While) {
398 // When single byte coverage mode is enabled, add a counter to condition and
399 // body.
400 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
401 for (Stmt *CS : While->children()) {
402 if (!CS || NoSingleByteCoverage)
403 continue;
404 if (CS == While->getCond())
405 CounterMap[While->getCond()] = NextCounter++;
406 else if (CS == While->getBody())
407 CounterMap[While->getBody()] = NextCounter++;
408 }
409
410 Base::TraverseWhileStmt(While);
411 if (Hash.getHashVersion() != PGO_HASH_V1)
412 Hash.combine(PGOHash::EndOfScope);
413 return true;
414 }
415
416 bool TraverseDoStmt(DoStmt *Do) {
417 // When single byte coverage mode is enabled, add a counter to condition and
418 // body.
419 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
420 for (Stmt *CS : Do->children()) {
421 if (!CS || NoSingleByteCoverage)
422 continue;
423 if (CS == Do->getCond())
424 CounterMap[Do->getCond()] = NextCounter++;
425 else if (CS == Do->getBody())
426 CounterMap[Do->getBody()] = NextCounter++;
427 }
428
429 Base::TraverseDoStmt(Do);
430 if (Hash.getHashVersion() != PGO_HASH_V1)
431 Hash.combine(PGOHash::EndOfScope);
432 return true;
433 }
434
435 bool TraverseForStmt(ForStmt *For) {
436 // When single byte coverage mode is enabled, add a counter to condition,
437 // increment and body.
438 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
439 for (Stmt *CS : For->children()) {
440 if (!CS || NoSingleByteCoverage)
441 continue;
442 if (CS == For->getCond())
443 CounterMap[For->getCond()] = NextCounter++;
444 else if (CS == For->getInc())
445 CounterMap[For->getInc()] = NextCounter++;
446 else if (CS == For->getBody())
447 CounterMap[For->getBody()] = NextCounter++;
448 }
449
450 Base::TraverseForStmt(For);
451 if (Hash.getHashVersion() != PGO_HASH_V1)
452 Hash.combine(PGOHash::EndOfScope);
453 return true;
454 }
455
456 bool TraverseCXXForRangeStmt(CXXForRangeStmt *ForRange) {
457 // When single byte coverage mode is enabled, add a counter to body.
458 bool NoSingleByteCoverage = !llvm::EnableSingleByteCoverage;
459 for (Stmt *CS : ForRange->children()) {
460 if (!CS || NoSingleByteCoverage)
461 continue;
462 if (CS == ForRange->getBody())
463 CounterMap[ForRange->getBody()] = NextCounter++;
464 }
465
466 Base::TraverseCXXForRangeStmt(ForRange);
467 if (Hash.getHashVersion() != PGO_HASH_V1)
468 Hash.combine(PGOHash::EndOfScope);
469 return true;
470 }
471
472 // If the statement type \p N is nestable, and its nesting impacts profile
473 // stability, define a custom traversal which tracks the end of the statement
474 // in the hash (provided we're not using the V1 hash).
475 #define DEFINE_NESTABLE_TRAVERSAL(N) \
476 bool Traverse##N(N *S) { \
477 Base::Traverse##N(S); \
478 if (Hash.getHashVersion() != PGO_HASH_V1) \
479 Hash.combine(PGOHash::EndOfScope); \
480 return true; \
481 }
482
483 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
484 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
485 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
486
487 /// Get version \p HashVersion of the PGO hash for \p S.
488 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
489 switch (S->getStmtClass()) {
490 default:
491 break;
492 case Stmt::LabelStmtClass:
493 return PGOHash::LabelStmt;
494 case Stmt::WhileStmtClass:
495 return PGOHash::WhileStmt;
496 case Stmt::DoStmtClass:
497 return PGOHash::DoStmt;
498 case Stmt::ForStmtClass:
499 return PGOHash::ForStmt;
500 case Stmt::CXXForRangeStmtClass:
501 return PGOHash::CXXForRangeStmt;
502 case Stmt::ObjCForCollectionStmtClass:
503 return PGOHash::ObjCForCollectionStmt;
504 case Stmt::SwitchStmtClass:
505 return PGOHash::SwitchStmt;
506 case Stmt::CaseStmtClass:
507 return PGOHash::CaseStmt;
508 case Stmt::DefaultStmtClass:
509 return PGOHash::DefaultStmt;
510 case Stmt::IfStmtClass:
511 return PGOHash::IfStmt;
512 case Stmt::CXXTryStmtClass:
513 return PGOHash::CXXTryStmt;
514 case Stmt::CXXCatchStmtClass:
515 return PGOHash::CXXCatchStmt;
516 case Stmt::ConditionalOperatorClass:
517 return PGOHash::ConditionalOperator;
518 case Stmt::BinaryConditionalOperatorClass:
519 return PGOHash::BinaryConditionalOperator;
520 case Stmt::BinaryOperatorClass: {
521 const BinaryOperator *BO = cast<BinaryOperator>(S);
522 if (BO->getOpcode() == BO_LAnd)
523 return PGOHash::BinaryOperatorLAnd;
524 if (BO->getOpcode() == BO_LOr)
525 return PGOHash::BinaryOperatorLOr;
526 if (HashVersion >= PGO_HASH_V2) {
527 switch (BO->getOpcode()) {
528 default:
529 break;
530 case BO_LT:
531 return PGOHash::BinaryOperatorLT;
532 case BO_GT:
533 return PGOHash::BinaryOperatorGT;
534 case BO_LE:
535 return PGOHash::BinaryOperatorLE;
536 case BO_GE:
537 return PGOHash::BinaryOperatorGE;
538 case BO_EQ:
539 return PGOHash::BinaryOperatorEQ;
540 case BO_NE:
541 return PGOHash::BinaryOperatorNE;
542 }
543 }
544 break;
545 }
546 }
547
548 if (HashVersion >= PGO_HASH_V2) {
549 switch (S->getStmtClass()) {
550 default:
551 break;
552 case Stmt::GotoStmtClass:
553 return PGOHash::GotoStmt;
554 case Stmt::IndirectGotoStmtClass:
555 return PGOHash::IndirectGotoStmt;
556 case Stmt::BreakStmtClass:
557 return PGOHash::BreakStmt;
558 case Stmt::ContinueStmtClass:
559 return PGOHash::ContinueStmt;
560 case Stmt::ReturnStmtClass:
561 return PGOHash::ReturnStmt;
562 case Stmt::CXXThrowExprClass:
563 return PGOHash::ThrowExpr;
564 case Stmt::UnaryOperatorClass: {
565 const UnaryOperator *UO = cast<UnaryOperator>(S);
566 if (UO->getOpcode() == UO_LNot)
567 return PGOHash::UnaryOperatorLNot;
568 break;
569 }
570 }
571 }
572
573 return PGOHash::None;
574 }
575 };
576
577 /// A StmtVisitor that propagates the raw counts through the AST and
578 /// records the count at statements where the value may change.
579 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
580 /// PGO state.
581 CodeGenPGO &PGO;
582
583 /// A flag that is set when the current count should be recorded on the
584 /// next statement, such as at the exit of a loop.
585 bool RecordNextStmtCount;
586
587 /// The count at the current location in the traversal.
588 uint64_t CurrentCount;
589
590 /// The map of statements to count values.
591 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
592
593 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
594 struct BreakContinue {
595 uint64_t BreakCount = 0;
596 uint64_t ContinueCount = 0;
597 BreakContinue() = default;
598 };
599 SmallVector<BreakContinue, 8> BreakContinueStack;
600
601 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
602 CodeGenPGO &PGO)
603 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
604
605 void RecordStmtCount(const Stmt *S) {
606 if (RecordNextStmtCount) {
607 CountMap[S] = CurrentCount;
608 RecordNextStmtCount = false;
609 }
610 }
611
612 /// Set and return the current count.
613 uint64_t setCount(uint64_t Count) {
614 CurrentCount = Count;
615 return Count;
616 }
617
618 void VisitStmt(const Stmt *S) {
619 RecordStmtCount(S);
620 for (const Stmt *Child : S->children())
621 if (Child)
622 this->Visit(Child);
623 }
624
625 void VisitFunctionDecl(const FunctionDecl *D) {
626 // Counter tracks entry to the function body.
627 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
628 CountMap[D->getBody()] = BodyCount;
629 Visit(D->getBody());
630 }
631
632 // Skip lambda expressions. We visit these as FunctionDecls when we're
633 // generating them and aren't interested in the body when generating a
634 // parent context.
635 void VisitLambdaExpr(const LambdaExpr *LE) {}
636
637 void VisitCapturedDecl(const CapturedDecl *D) {
638 // Counter tracks entry to the capture body.
639 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
640 CountMap[D->getBody()] = BodyCount;
641 Visit(D->getBody());
642 }
643
644 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
645 // Counter tracks entry to the method body.
646 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
647 CountMap[D->getBody()] = BodyCount;
648 Visit(D->getBody());
649 }
650
651 void VisitBlockDecl(const BlockDecl *D) {
652 // Counter tracks entry to the block body.
653 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
654 CountMap[D->getBody()] = BodyCount;
655 Visit(D->getBody());
656 }
657
658 void VisitReturnStmt(const ReturnStmt *S) {
659 RecordStmtCount(S);
660 if (S->getRetValue())
661 Visit(S->getRetValue());
662 CurrentCount = 0;
663 RecordNextStmtCount = true;
664 }
665
666 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
667 RecordStmtCount(E);
668 if (E->getSubExpr())
669 Visit(E->getSubExpr());
670 CurrentCount = 0;
671 RecordNextStmtCount = true;
672 }
673
674 void VisitGotoStmt(const GotoStmt *S) {
675 RecordStmtCount(S);
676 CurrentCount = 0;
677 RecordNextStmtCount = true;
678 }
679
680 void VisitLabelStmt(const LabelStmt *S) {
681 RecordNextStmtCount = false;
682 // Counter tracks the block following the label.
683 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
684 CountMap[S] = BlockCount;
685 Visit(S->getSubStmt());
686 }
687
688 void VisitBreakStmt(const BreakStmt *S) {
689 RecordStmtCount(S);
690 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
691 BreakContinueStack.back().BreakCount += CurrentCount;
692 CurrentCount = 0;
693 RecordNextStmtCount = true;
694 }
695
696 void VisitContinueStmt(const ContinueStmt *S) {
697 RecordStmtCount(S);
698 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
699 BreakContinueStack.back().ContinueCount += CurrentCount;
700 CurrentCount = 0;
701 RecordNextStmtCount = true;
702 }
703
704 void VisitWhileStmt(const WhileStmt *S) {
705 RecordStmtCount(S);
706 uint64_t ParentCount = CurrentCount;
707
708 BreakContinueStack.push_back(BreakContinue());
709 // Visit the body region first so the break/continue adjustments can be
710 // included when visiting the condition.
711 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
712 CountMap[S->getBody()] = CurrentCount;
713 Visit(S->getBody());
714 uint64_t BackedgeCount = CurrentCount;
715
716 // ...then go back and propagate counts through the condition. The count
717 // at the start of the condition is the sum of the incoming edges,
718 // the backedge from the end of the loop body, and the edges from
719 // continue statements.
720 BreakContinue BC = BreakContinueStack.pop_back_val();
721 uint64_t CondCount =
722 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
723 CountMap[S->getCond()] = CondCount;
724 Visit(S->getCond());
725 setCount(BC.BreakCount + CondCount - BodyCount);
726 RecordNextStmtCount = true;
727 }
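// Rough example of the arithmetic above, assuming no break or continue:
// with ParentCount == 10 and a body counter of 30, BackedgeCount is 30,
// CondCount becomes 10 + 30 == 40, and the exit count is 40 - 30 == 10,
// i.e. every entry to the loop eventually leaves it.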
728
729 void VisitDoStmt(const DoStmt *S) {
730 RecordStmtCount(S);
731 uint64_t LoopCount = PGO.getRegionCount(S);
732
733 BreakContinueStack.push_back(BreakContinue());
734 // The count doesn't include the fallthrough from the parent scope. Add it.
735 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
736 CountMap[S->getBody()] = BodyCount;
737 Visit(S->getBody());
738 uint64_t BackedgeCount = CurrentCount;
739
740 BreakContinue BC = BreakContinueStack.pop_back_val();
741 // The count at the start of the condition is equal to the count at the
742 // end of the body, plus any continues.
743 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
744 CountMap[S->getCond()] = CondCount;
745 Visit(S->getCond());
746 setCount(BC.BreakCount + CondCount - LoopCount);
747 RecordNextStmtCount = true;
748 }
749
750 void VisitForStmt(const ForStmt *S) {
751 RecordStmtCount(S);
752 if (S->getInit())
753 Visit(S->getInit());
754
755 uint64_t ParentCount = CurrentCount;
756
757 BreakContinueStack.push_back(BreakContinue());
758 // Visit the body region first. (This is basically the same as a while
759 // loop; see further comments in VisitWhileStmt.)
760 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
761 CountMap[S->getBody()] = BodyCount;
762 Visit(S->getBody());
763 uint64_t BackedgeCount = CurrentCount;
764 BreakContinue BC = BreakContinueStack.pop_back_val();
765
766 // The increment is essentially part of the body but it needs to include
767 // the count for all the continue statements.
768 if (S->getInc()) {
769 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
770 CountMap[S->getInc()] = IncCount;
771 Visit(S->getInc());
772 }
773
774 // ...then go back and propagate counts through the condition.
775 uint64_t CondCount =
776 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
777 if (S->getCond()) {
778 CountMap[S->getCond()] = CondCount;
779 Visit(S->getCond());
780 }
781 setCount(BC.BreakCount + CondCount - BodyCount);
782 RecordNextStmtCount = true;
783 }
784
785 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
786 RecordStmtCount(S);
787 if (S->getInit())
788 Visit(S->getInit());
789 Visit(S->getLoopVarStmt());
790 Visit(S->getRangeStmt());
791 Visit(S->getBeginStmt());
792 Visit(S->getEndStmt());
793
794 uint64_t ParentCount = CurrentCount;
795 BreakContinueStack.push_back(BreakContinue());
796 // Visit the body region first. (This is basically the same as a while
797 // loop; see further comments in VisitWhileStmt.)
798 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
799 CountMap[S->getBody()] = BodyCount;
800 Visit(S->getBody());
801 uint64_t BackedgeCount = CurrentCount;
802 BreakContinue BC = BreakContinueStack.pop_back_val();
803
804 // The increment is essentially part of the body but it needs to include
805 // the count for all the continue statements.
806 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
807 CountMap[S->getInc()] = IncCount;
808 Visit(S->getInc());
809
810 // ...then go back and propagate counts through the condition.
811 uint64_t CondCount =
812 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
813 CountMap[S->getCond()] = CondCount;
814 Visit(S->getCond());
815 setCount(BC.BreakCount + CondCount - BodyCount);
816 RecordNextStmtCount = true;
817 }
818
819 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
820 RecordStmtCount(S);
821 Visit(S->getElement());
822 uint64_t ParentCount = CurrentCount;
823 BreakContinueStack.push_back(BreakContinue());
824 // Counter tracks the body of the loop.
825 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
826 CountMap[S->getBody()] = BodyCount;
827 Visit(S->getBody());
828 uint64_t BackedgeCount = CurrentCount;
829 BreakContinue BC = BreakContinueStack.pop_back_val();
830
831 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
832 BodyCount);
833 RecordNextStmtCount = true;
834 }
835
836 void VisitSwitchStmt(const SwitchStmt *S) {
837 RecordStmtCount(S);
838 if (S->getInit())
839 Visit(S->getInit());
840 Visit(S->getCond());
841 CurrentCount = 0;
842 BreakContinueStack.push_back(BreakContinue());
843 Visit(S->getBody());
844 // If the switch is inside a loop, add the continue counts.
845 BreakContinue BC = BreakContinueStack.pop_back_val();
846 if (!BreakContinueStack.empty())
847 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
848 // Counter tracks the exit block of the switch.
849 setCount(PGO.getRegionCount(S));
850 RecordNextStmtCount = true;
851 }
852
853 void VisitSwitchCase(const SwitchCase *S) {
854 RecordNextStmtCount = false;
855 // Counter for this particular case. This counts only jumps from the
856 // switch header and does not include fallthrough from the case before
857 // this one.
858 uint64_t CaseCount = PGO.getRegionCount(S);
859 setCount(CurrentCount + CaseCount);
860 // We need the count without fallthrough in the mapping, so it's more useful
861 // for branch probabilities.
862 CountMap[S] = CaseCount;
863 RecordNextStmtCount = true;
864 Visit(S->getSubStmt());
865 }
866
867 void VisitIfStmt(const IfStmt *S) {
868 RecordStmtCount(S);
869
870 if (S->isConsteval()) {
871 const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
872 if (Stm)
873 Visit(Stm);
874 return;
875 }
876
877 uint64_t ParentCount = CurrentCount;
878 if (S->getInit())
879 Visit(S->getInit());
880 Visit(S->getCond());
881
882 // Counter tracks the "then" part of an if statement. The count for
883 // the "else" part, if it exists, will be calculated from this counter.
884 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
885 CountMap[S->getThen()] = ThenCount;
886 Visit(S->getThen());
887 uint64_t OutCount = CurrentCount;
888
889 uint64_t ElseCount = ParentCount - ThenCount;
890 if (S->getElse()) {
891 setCount(ElseCount);
892 CountMap[S->getElse()] = ElseCount;
893 Visit(S->getElse());
894 OutCount += CurrentCount;
895 } else
896 OutCount += ElseCount;
897 setCount(OutCount);
898 RecordNextStmtCount = true;
899 }
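// Rough example: if the statement is reached 100 times (ParentCount) and the
// "then" counter reads 60, the "else" count is derived as 100 - 60 == 40;
// OutCount then sums whatever flows out of the two branches.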
900
901 void VisitCXXTryStmt(const CXXTryStmt *S) {
902 RecordStmtCount(S);
903 Visit(S->getTryBlock());
904 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
905 Visit(S->getHandler(I));
906 // Counter tracks the continuation block of the try statement.
907 setCount(PGO.getRegionCount(S));
908 RecordNextStmtCount = true;
909 }
910
911 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
912 RecordNextStmtCount = false;
913 // Counter tracks the catch statement's handler block.
914 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
915 CountMap[S] = CatchCount;
916 Visit(S->getHandlerBlock());
917 }
918
919 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
920 RecordStmtCount(E);
921 uint64_t ParentCount = CurrentCount;
922 Visit(E->getCond());
923
924 // Counter tracks the "true" part of a conditional operator. The
925 // count in the "false" part will be calculated from this counter.
926 uint64_t TrueCount = setCount(PGO.getRegionCount(E));
927 CountMap[E->getTrueExpr()] = TrueCount;
928 Visit(E->getTrueExpr());
929 uint64_t OutCount = CurrentCount;
930
931 uint64_t FalseCount = setCount(ParentCount - TrueCount);
932 CountMap[E->getFalseExpr()] = FalseCount;
933 Visit(E->getFalseExpr());
934 OutCount += CurrentCount;
935
936 setCount(OutCount);
937 RecordNextStmtCount = true;
938 }
939
940 void VisitBinLAnd(const BinaryOperator *E) {
941 RecordStmtCount(E);
942 uint64_t ParentCount = CurrentCount;
943 Visit(E->getLHS());
944 // Counter tracks the right hand side of a logical and operator.
945 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
946 CountMap[E->getRHS()] = RHSCount;
947 Visit(E->getRHS());
948 setCount(ParentCount + RHSCount - CurrentCount);
949 RecordNextStmtCount = true;
950 }
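// Rough example: if the "&&" is reached 100 times (ParentCount) and the RHS
// counter reads 70, the LHS must have short-circuited to false 30 times.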
951
952 void VisitBinLOr(const BinaryOperator *E) {
953 RecordStmtCount(E);
954 uint64_t ParentCount = CurrentCount;
955 Visit(E->getLHS());
956 // Counter tracks the right hand side of a logical or operator.
957 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
958 CountMap[E->getRHS()] = RHSCount;
959 Visit(E->getRHS());
960 setCount(ParentCount + RHSCount - CurrentCount);
961 RecordNextStmtCount = true;
962 }
963 };
964 } // end anonymous namespace
965
966 void PGOHash::combine(HashType Type) {
967 // Check that we never combine 0 and only have six bits.
968 assert(Type && "Hash is invalid: unexpected type 0");
969 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
970
971 // Pass through MD5 if enough work has built up.
972 if (Count && Count % NumTypesPerWord == 0) {
973 using namespace llvm::support;
974 uint64_t Swapped =
975 endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
976 MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
977 Working = 0;
978 }
979
980 // Accumulate the current type.
981 ++Count;
982 Working = Working << NumBitsPerType | Type;
983 }
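// For illustration: after combining types t1, t2, t3, Working holds
// ((t1 << 6 | t2) << 6) | t3. Once ten types have been packed, the next
// combine() flushes the full 64-bit word into the running MD5 before
// packing continues.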
984
985 uint64_t PGOHash::finalize() {
986 // Use Working as the hash directly if we never used MD5.
987 if (Count <= NumTypesPerWord)
988 // No need to byte swap here, since none of the math was endian-dependent.
989 // This number will be byte-swapped as required on endianness transitions,
990 // so we will see the same value on the other side.
991 return Working;
992
993 // Check for remaining work in Working.
994 if (Working) {
995 // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
996 // is buggy because it converts a uint64_t into an array of uint8_t.
997 if (HashVersion < PGO_HASH_V3) {
998 MD5.update({(uint8_t)Working});
999 } else {
1000 using namespace llvm::support;
1001 uint64_t Swapped =
1002 endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
1003 MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
1004 }
1005 }
1006
1007 // Finalize the MD5 and return the hash.
1008 llvm::MD5::MD5Result Result;
1009 MD5.final(Result);
1010 return Result.low();
1011 }
1012
1013 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
1014 const Decl *D = GD.getDecl();
1015 if (!D->hasBody())
1016 return;
1017
1018 // Skip CUDA/HIP kernel launch stub functions.
1019 if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
1020 D->hasAttr<CUDAGlobalAttr>())
1021 return;
1022
1023 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
1024 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1025 if (!InstrumentRegions && !PGOReader)
1026 return;
1027 if (D->isImplicit())
1028 return;
1029 // Constructors and destructors may be represented by several functions in IR.
1030 // If so, instrument only the base variant; the others are implemented by
1031 // delegation to the base one and would otherwise be counted twice.
1032 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
1033 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
1034 if (GD.getCtorType() != Ctor_Base &&
1035 CodeGenFunction::IsConstructorDelegationValid(CCD))
1036 return;
1037 }
1038 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
1039 return;
1040
1041 CGM.ClearUnusedCoverageMapping(D);
1042 if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
1043 return;
1044 if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
1045 return;
1046
1047 SourceManager &SM = CGM.getContext().getSourceManager();
1048 if (!llvm::coverage::SystemHeadersCoverage &&
1049 SM.isInSystemHeader(D->getLocation()))
1050 return;
1051
1052 setFuncName(Fn);
1053
1054 mapRegionCounters(D);
1055 if (CGM.getCodeGenOpts().CoverageMapping)
1056 emitCounterRegionMapping(D);
1057 if (PGOReader) {
1058 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
1059 computeRegionCounts(D);
1060 applyFunctionAttributes(PGOReader, Fn);
1061 }
1062 }
1063
1064 void CodeGenPGO::mapRegionCounters(const Decl *D) {
1065 // Use the latest hash version when inserting instrumentation, but use the
1066 // version in the indexed profile if we're reading PGO data.
1067 PGOHashVersion HashVersion = PGO_HASH_LATEST;
1068 uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
1069 if (auto *PGOReader = CGM.getPGOReader()) {
1070 HashVersion = getPGOHashVersion(PGOReader, CGM);
1071 ProfileVersion = PGOReader->getVersion();
1072 }
1073
1074 // If MC/DC is enabled, set the MaxConditions to a preset value. Otherwise,
1075 // set it to zero. This value impacts the number of conditions accepted in a
1076 // given boolean expression, which impacts the size of the bitmap used to
1077 // track test vector execution for that boolean expression. Because the
1078 // bitmap scales exponentially (2^n) based on the number of conditions seen,
1079 // the maximum value is hard-coded at 6 conditions, which is more than enough
1080 // for most embedded applications. Setting a maximum value prevents the
1081 // bitmap footprint from growing too large without the user's knowledge. In
1082 // the future, this value could be adjusted with a command-line option.
1083 unsigned MCDCMaxConditions =
1084 (CGM.getCodeGenOpts().MCDCCoverage ? CGM.getCodeGenOpts().MCDCMaxConds
1085 : 0);
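// For example, with the default cap of 6 conditions a single decision needs
// at most 2^6 == 64 test-vector bits of bitmap; each additional condition
// would double that.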
1086
1087 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
1088 RegionMCDCState.reset(new MCDC::State);
1089 MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap,
1090 *RegionMCDCState, MCDCMaxConditions, CGM.getDiags());
1091 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1092 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
1093 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1094 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
1095 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1096 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
1097 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1098 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
1099 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
1100 NumRegionCounters = Walker.NextCounter;
1101 FunctionHash = Walker.Hash.finalize();
1102 }
1103
1104 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
1105 if (!D->getBody())
1106 return true;
1107
1108 // Skip host-only functions in the CUDA device compilation and device-only
1109 // functions in the host compilation. Just roughly filter them out based on
1110 // the function attributes. If there are effectively host-only or device-only
1111 // ones, their coverage mapping may still be generated.
1112 if (CGM.getLangOpts().CUDA &&
1113 ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
1114 !D->hasAttr<CUDAGlobalAttr>()) ||
1115 (!CGM.getLangOpts().CUDAIsDevice &&
1116 (D->hasAttr<CUDAGlobalAttr>() ||
1117 (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
1118 return true;
1119
1120 // Don't map the functions in system headers.
1121 const auto &SM = CGM.getContext().getSourceManager();
1122 auto Loc = D->getBody()->getBeginLoc();
1123 return !llvm::coverage::SystemHeadersCoverage && SM.isInSystemHeader(Loc);
1124 }
1125
1126 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
1127 if (skipRegionMappingForDecl(D))
1128 return;
1129
1130 std::string CoverageMapping;
1131 llvm::raw_string_ostream OS(CoverageMapping);
1132 RegionMCDCState->BranchByStmt.clear();
1133 CoverageMappingGen MappingGen(
1134 *CGM.getCoverageMapping(), CGM.getContext().getSourceManager(),
1135 CGM.getLangOpts(), RegionCounterMap.get(), RegionMCDCState.get());
1136 MappingGen.emitCounterMapping(D, OS);
1137 OS.flush();
1138
1139 if (CoverageMapping.empty())
1140 return;
1141
1142 CGM.getCoverageMapping()->addFunctionMappingRecord(
1143 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
1144 }
1145
1146 void
1147 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
1148 llvm::GlobalValue::LinkageTypes Linkage) {
1149 if (skipRegionMappingForDecl(D))
1150 return;
1151
1152 std::string CoverageMapping;
1153 llvm::raw_string_ostream OS(CoverageMapping);
1154 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
1155 CGM.getContext().getSourceManager(),
1156 CGM.getLangOpts());
1157 MappingGen.emitEmptyMapping(D, OS);
1158 OS.flush();
1159
1160 if (CoverageMapping.empty())
1161 return;
1162
1163 setFuncName(Name, Linkage);
1164 CGM.getCoverageMapping()->addFunctionMappingRecord(
1165 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
1166 }
1167
1168 void CodeGenPGO::computeRegionCounts(const Decl *D) {
1169 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
1170 ComputeRegionCounts Walker(*StmtCountMap, *this);
1171 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1172 Walker.VisitFunctionDecl(FD);
1173 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1174 Walker.VisitObjCMethodDecl(MD);
1175 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1176 Walker.VisitBlockDecl(BD);
1177 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1178 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
1179 }
1180
1181 void
1182 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
1183 llvm::Function *Fn) {
1184 if (!haveRegionCounts())
1185 return;
1186
1187 uint64_t FunctionCount = getRegionCount(nullptr);
1188 Fn->setEntryCount(FunctionCount);
1189 }
1190
1191 void CodeGenPGO::emitCounterSetOrIncrement(CGBuilderTy &Builder, const Stmt *S,
1192 llvm::Value *StepV) {
1193 if (!RegionCounterMap || !Builder.GetInsertBlock())
1194 return;
1195
1196 unsigned Counter = (*RegionCounterMap)[S];
1197
1198 llvm::Value *Args[] = {FuncNameVar,
1199 Builder.getInt64(FunctionHash),
1200 Builder.getInt32(NumRegionCounters),
1201 Builder.getInt32(Counter), StepV};
1202
1203 if (llvm::EnableSingleByteCoverage)
1204 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_cover),
1205 ArrayRef(Args, 4));
1206 else {
1207 if (!StepV)
1208 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
1209 ArrayRef(Args, 4));
1210 else
1211 Builder.CreateCall(
1212 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step), Args);
1213 }
1214 }
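// Sketch of what this emits (names illustrative, not taken from this file):
//   call void @llvm.instrprof.increment(ptr @__profn_foo, i64 <hash>,
//                                       i32 <num-counters>, i32 <index>)
// or the .step / llvm.instrprof.cover variants; the LLVM instrumentation
// pass later lowers these to real counter updates.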
1215
1216 bool CodeGenPGO::canEmitMCDCCoverage(const CGBuilderTy &Builder) {
1217 return (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1218 CGM.getCodeGenOpts().MCDCCoverage && Builder.GetInsertBlock());
1219 }
1220
1221 void CodeGenPGO::emitMCDCParameters(CGBuilderTy &Builder) {
1222 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1223 return;
1224
1225 auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1226
1227 // Emit intrinsic representing MCDC bitmap parameters at function entry.
1228 // This is used by the instrumentation pass, but it isn't actually lowered to
1229 // anything.
1230 llvm::Value *Args[3] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1231 Builder.getInt64(FunctionHash),
1232 Builder.getInt32(RegionMCDCState->BitmapBits)};
1233 Builder.CreateCall(
1234 CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_parameters), Args);
1235 }
1236
1237 void CodeGenPGO::emitMCDCTestVectorBitmapUpdate(CGBuilderTy &Builder,
1238 const Expr *S,
1239 Address MCDCCondBitmapAddr,
1240 CodeGenFunction &CGF) {
1241 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1242 return;
1243
1244 S = S->IgnoreParens();
1245
1246 auto DecisionStateIter = RegionMCDCState->DecisionByStmt.find(S);
1247 if (DecisionStateIter == RegionMCDCState->DecisionByStmt.end())
1248 return;
1249
1250 // Don't create tvbitmap_update if the record is allocated but excluded.
1251 // Otherwise `bitmap |= (1 << 0)` would be wrongly applied to the next bitmap.
1252 if (DecisionStateIter->second.Indices.size() == 0)
1253 return;
1254
1255 // Extract the offset of the global bitmap associated with this expression.
1256 unsigned MCDCTestVectorBitmapOffset = DecisionStateIter->second.BitmapIdx;
1257 auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1258
1259 // Emit intrinsic responsible for updating the global bitmap corresponding to
1260 // a boolean expression. The index being set is based on the value loaded
1261 // from a pointer to a dedicated temporary value on the stack that is itself
1262 // updated via emitMCDCCondBitmapReset() and emitMCDCCondBitmapUpdate(). The
1263 // index represents an executed test vector.
1264 llvm::Value *Args[4] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1265 Builder.getInt64(FunctionHash),
1266 Builder.getInt32(MCDCTestVectorBitmapOffset),
1267 MCDCCondBitmapAddr.emitRawPointer(CGF)};
1268 Builder.CreateCall(
1269 CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_tvbitmap_update), Args);
1270 }
1271
1272 void CodeGenPGO::emitMCDCCondBitmapReset(CGBuilderTy &Builder, const Expr *S,
1273 Address MCDCCondBitmapAddr) {
1274 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1275 return;
1276
1277 S = S->IgnoreParens();
1278
1279 if (!RegionMCDCState->DecisionByStmt.contains(S))
1280 return;
1281
1282 // Emit intrinsic that resets a dedicated temporary value on the stack to 0.
1283 Builder.CreateStore(Builder.getInt32(0), MCDCCondBitmapAddr);
1284 }
1285
1286 void CodeGenPGO::emitMCDCCondBitmapUpdate(CGBuilderTy &Builder, const Expr *S,
1287 Address MCDCCondBitmapAddr,
1288 llvm::Value *Val,
1289 CodeGenFunction &CGF) {
1290 if (!canEmitMCDCCoverage(Builder) || !RegionMCDCState)
1291 return;
1292
1293 // Even though, for simplicity, parentheses and unary logical-NOT operators
1294 // are considered part of their underlying condition for both MC/DC and
1295 // branch coverage, the condition IDs themselves are assigned and tracked
1296 // using the underlying condition itself. This is done solely for
1297 // consistency since parentheses and logical-NOTs are ignored when checking
1298 // whether the condition is actually an instrumentable condition. This can
1299 // also make debugging a bit easier.
1300 S = CodeGenFunction::stripCond(S);
1301
1302 auto BranchStateIter = RegionMCDCState->BranchByStmt.find(S);
1303 if (BranchStateIter == RegionMCDCState->BranchByStmt.end())
1304 return;
1305
1306 // Extract the ID of the condition we are setting in the bitmap.
1307 const auto &Branch = BranchStateIter->second;
1308 assert(Branch.ID >= 0 && "Condition has no ID!");
1309 assert(Branch.DecisionStmt);
1310
1311 // Cancel the emission if the Decision is erased after the allocation.
1312 const auto DecisionIter =
1313 RegionMCDCState->DecisionByStmt.find(Branch.DecisionStmt);
1314 if (DecisionIter == RegionMCDCState->DecisionByStmt.end())
1315 return;
1316
1317 const auto &TVIdxs = DecisionIter->second.Indices[Branch.ID];
1318
1319 auto *CurTV = Builder.CreateLoad(MCDCCondBitmapAddr,
1320 "mcdc." + Twine(Branch.ID + 1) + ".cur");
1321 auto *NewTV = Builder.CreateAdd(CurTV, Builder.getInt32(TVIdxs[true]));
1322 NewTV = Builder.CreateSelect(
1323 Val, NewTV, Builder.CreateAdd(CurTV, Builder.getInt32(TVIdxs[false])));
1324 Builder.CreateStore(NewTV, MCDCCondBitmapAddr);
1325 }
1326
1327 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
1328 if (CGM.getCodeGenOpts().hasProfileClangInstr())
1329 M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
1330 uint32_t(EnableValueProfiling));
1331 }
1332
1333 void CodeGenPGO::setProfileVersion(llvm::Module &M) {
1334 if (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1335 llvm::EnableSingleByteCoverage) {
1336 const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
1337 llvm::Type *IntTy64 = llvm::Type::getInt64Ty(M.getContext());
1338 uint64_t ProfileVersion =
1339 (INSTR_PROF_RAW_VERSION | VARIANT_MASK_BYTE_COVERAGE);
1340
1341 auto IRLevelVersionVariable = new llvm::GlobalVariable(
1342 M, IntTy64, true, llvm::GlobalValue::WeakAnyLinkage,
1343 llvm::Constant::getIntegerValue(IntTy64,
1344 llvm::APInt(64, ProfileVersion)),
1345 VarName);
1346
1347 IRLevelVersionVariable->setVisibility(llvm::GlobalValue::HiddenVisibility);
1348 llvm::Triple TT(M.getTargetTriple());
1349 if (TT.supportsCOMDAT()) {
1350 IRLevelVersionVariable->setLinkage(llvm::GlobalValue::ExternalLinkage);
1351 IRLevelVersionVariable->setComdat(M.getOrInsertComdat(VarName));
1352 }
1353 IRLevelVersionVariable->setDSOLocal(true);
1354 }
1355 }
1356
1357 // This method either inserts a call to the profile run-time during
1358 // instrumentation or puts profile data into metadata for PGO use.
1359 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
1360 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
1361
1362 if (!EnableValueProfiling)
1363 return;
1364
1365 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
1366 return;
1367
1368 if (isa<llvm::Constant>(ValuePtr))
1369 return;
1370
1371 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
1372 if (InstrumentValueSites && RegionCounterMap) {
1373 auto BuilderInsertPoint = Builder.saveIP();
1374 Builder.SetInsertPoint(ValueSite);
1375 llvm::Value *Args[5] = {
1376 FuncNameVar,
1377 Builder.getInt64(FunctionHash),
1378 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
1379 Builder.getInt32(ValueKind),
1380 Builder.getInt32(NumValueSites[ValueKind]++)
1381 };
1382 Builder.CreateCall(
1383 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
1384 Builder.restoreIP(BuilderInsertPoint);
1385 return;
1386 }
1387
1388 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
1389 if (PGOReader && haveRegionCounts()) {
1390 // We record the three most frequently called functions at each call site.
1391 // Profile metadata contains "VP" string identifying this metadata
1392 // as value profiling data, then a uint32_t value for the value profiling
1393 // kind, a uint64_t value for the total number of times the call is
1394 // executed, followed by the function hash and execution count (uint64_t)
1395 // pairs for each function.
1396 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
1397 return;
1398
1399 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
1400 (llvm::InstrProfValueKind)ValueKind,
1401 NumValueSites[ValueKind]);
1402
1403 NumValueSites[ValueKind]++;
1404 }
1405 }
1406
1407 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1408 bool IsInMainFile) {
1409 CGM.getPGOStats().addVisited(IsInMainFile);
1410 RegionCounts.clear();
1411 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1412 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1413 if (auto E = RecordExpected.takeError()) {
1414 auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E)));
1415 if (IPE == llvm::instrprof_error::unknown_function)
1416 CGM.getPGOStats().addMissing(IsInMainFile);
1417 else if (IPE == llvm::instrprof_error::hash_mismatch)
1418 CGM.getPGOStats().addMismatched(IsInMainFile);
1419 else if (IPE == llvm::instrprof_error::malformed)
1420 // TODO: Consider a more specific warning for this case.
1421 CGM.getPGOStats().addMismatched(IsInMainFile);
1422 return;
1423 }
1424 ProfRecord =
1425 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1426 RegionCounts = ProfRecord->Counts;
1427 }
1428
1429 /// Calculate what to divide by to scale weights.
1430 ///
1431 /// Given the maximum weight, calculate a divisor that will scale all the
1432 /// weights to strictly less than UINT32_MAX.
1433 static uint64_t calculateWeightScale(uint64_t MaxWeight) {
1434 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
1435 }
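// Worked example: if MaxWeight is 10,000,000,000 the scale is
// 10,000,000,000 / UINT32_MAX + 1 == 3, and scaleBranchWeight() below maps
// that weight to 10,000,000,000 / 3 + 1 == 3,333,333,334, which fits in
// 32 bits.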
1436
1437 /// Scale an individual branch weight (and add 1).
1438 ///
1439 /// Scale a 64-bit weight down to 32-bits using \c Scale.
1440 ///
1441 /// According to Laplace's Rule of Succession, it is better to compute the
1442 /// weight based on the count plus 1, so universally add 1 to the value.
1443 ///
1444 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
1445 /// less than \c Weight.
1446 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
1447 assert(Scale && "scale by 0?");
1448 uint64_t Scaled = Weight / Scale + 1;
1449 assert(Scaled <= UINT32_MAX && "overflow 32-bits");
1450 return Scaled;
1451 }
1452
1453 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1454 uint64_t FalseCount) const {
1455 // Check for empty weights.
1456 if (!TrueCount && !FalseCount)
1457 return nullptr;
1458
1459 // Calculate how to scale down to 32-bits.
1460 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1461
1462 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1463 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1464 scaleBranchWeight(FalseCount, Scale));
1465 }
1466
1467 llvm::MDNode *
1468 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1469 // We need at least two elements to create meaningful weights.
1470 if (Weights.size() < 2)
1471 return nullptr;
1472
1473 // Check for empty weights.
1474 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1475 if (MaxWeight == 0)
1476 return nullptr;
1477
1478 // Calculate how to scale down to 32-bits.
1479 uint64_t Scale = calculateWeightScale(MaxWeight);
1480
1481 SmallVector<uint32_t, 16> ScaledWeights;
1482 ScaledWeights.reserve(Weights.size());
1483 for (uint64_t W : Weights)
1484 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1485
1486 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1487 return MDHelper.createBranchWeights(ScaledWeights);
1488 }
1489
1490 llvm::MDNode *
1491 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1492 uint64_t LoopCount) const {
1493 if (!PGO.haveRegionCounts())
1494 return nullptr;
1495 std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1496 if (!CondCount || *CondCount == 0)
1497 return nullptr;
1498 return createProfileWeights(LoopCount,
1499 std::max(*CondCount, LoopCount) - LoopCount);
1500 }
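// In other words the loop branch is weighted (LoopCount, CondCount - LoopCount),
// roughly "times the backedge was taken" versus "times the loop was exited",
// with the subtraction clamped so the exit weight never goes negative.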
1501