1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Instrumentation-based profile-guided optimization 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CodeGenPGO.h" 14 #include "CodeGenFunction.h" 15 #include "CoverageMappingGen.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/IR/Intrinsics.h" 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/Support/CommandLine.h" 21 #include "llvm/Support/Endian.h" 22 #include "llvm/Support/FileSystem.h" 23 #include "llvm/Support/MD5.h" 24 #include <optional> 25 26 static llvm::cl::opt<bool> 27 EnableValueProfiling("enable-value-profiling", 28 llvm::cl::desc("Enable value profiling"), 29 llvm::cl::Hidden, llvm::cl::init(false)); 30 31 using namespace clang; 32 using namespace CodeGen; 33 34 void CodeGenPGO::setFuncName(StringRef Name, 35 llvm::GlobalValue::LinkageTypes Linkage) { 36 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 37 FuncName = llvm::getPGOFuncName( 38 Name, Linkage, CGM.getCodeGenOpts().MainFileName, 39 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version); 40 41 // If we're generating a profile, create a variable for the name. 42 if (CGM.getCodeGenOpts().hasProfileClangInstr()) 43 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName); 44 } 45 46 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 47 setFuncName(Fn->getName(), Fn->getLinkage()); 48 // Create PGOFuncName meta data. 49 llvm::createPGOFuncNameMetadata(*Fn, FuncName); 50 } 51 52 /// The version of the PGO hash algorithm. 53 enum PGOHashVersion : unsigned { 54 PGO_HASH_V1, 55 PGO_HASH_V2, 56 PGO_HASH_V3, 57 58 // Keep this set to the latest hash version. 59 PGO_HASH_LATEST = PGO_HASH_V3 60 }; 61 62 namespace { 63 /// Stable hasher for PGO region counters. 64 /// 65 /// PGOHash produces a stable hash of a given function's control flow. 66 /// 67 /// Changing the output of this hash will invalidate all previously generated 68 /// profiles -- i.e., don't do it. 69 /// 70 /// \note When this hash does eventually change (years?), we still need to 71 /// support old hashes. We'll need to pull in the version number from the 72 /// profile data format and use the matching hash function. 73 class PGOHash { 74 uint64_t Working; 75 unsigned Count; 76 PGOHashVersion HashVersion; 77 llvm::MD5 MD5; 78 79 static const int NumBitsPerType = 6; 80 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 81 static const unsigned TooBig = 1u << NumBitsPerType; 82 83 public: 84 /// Hash values for AST nodes. 85 /// 86 /// Distinct values for AST nodes that have region counters attached. 87 /// 88 /// These values must be stable. All new members must be added at the end, 89 /// and no members should be removed. Changing the enumeration value for an 90 /// AST node will affect the hash of every function that contains that node. 91 enum HashType : unsigned char { 92 None = 0, 93 LabelStmt = 1, 94 WhileStmt, 95 DoStmt, 96 ForStmt, 97 CXXForRangeStmt, 98 ObjCForCollectionStmt, 99 SwitchStmt, 100 CaseStmt, 101 DefaultStmt, 102 IfStmt, 103 CXXTryStmt, 104 CXXCatchStmt, 105 ConditionalOperator, 106 BinaryOperatorLAnd, 107 BinaryOperatorLOr, 108 BinaryConditionalOperator, 109 // The preceding values are available with PGO_HASH_V1. 110 111 EndOfScope, 112 IfThenBranch, 113 IfElseBranch, 114 GotoStmt, 115 IndirectGotoStmt, 116 BreakStmt, 117 ContinueStmt, 118 ReturnStmt, 119 ThrowExpr, 120 UnaryOperatorLNot, 121 BinaryOperatorLT, 122 BinaryOperatorGT, 123 BinaryOperatorLE, 124 BinaryOperatorGE, 125 BinaryOperatorEQ, 126 BinaryOperatorNE, 127 // The preceding values are available since PGO_HASH_V2. 128 129 // Keep this last. It's for the static assert that follows. 130 LastHashType 131 }; 132 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 133 134 PGOHash(PGOHashVersion HashVersion) 135 : Working(0), Count(0), HashVersion(HashVersion) {} 136 void combine(HashType Type); 137 uint64_t finalize(); 138 PGOHashVersion getHashVersion() const { return HashVersion; } 139 }; 140 const int PGOHash::NumBitsPerType; 141 const unsigned PGOHash::NumTypesPerWord; 142 const unsigned PGOHash::TooBig; 143 144 /// Get the PGO hash version used in the given indexed profile. 145 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader, 146 CodeGenModule &CGM) { 147 if (PGOReader->getVersion() <= 4) 148 return PGO_HASH_V1; 149 if (PGOReader->getVersion() <= 5) 150 return PGO_HASH_V2; 151 return PGO_HASH_V3; 152 } 153 154 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 155 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { 156 using Base = RecursiveASTVisitor<MapRegionCounters>; 157 158 /// The next counter value to assign. 159 unsigned NextCounter; 160 /// The function hash. 161 PGOHash Hash; 162 /// The map of statements to counters. 163 llvm::DenseMap<const Stmt *, unsigned> &CounterMap; 164 /// The profile version. 165 uint64_t ProfileVersion; 166 167 MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion, 168 llvm::DenseMap<const Stmt *, unsigned> &CounterMap) 169 : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap), 170 ProfileVersion(ProfileVersion) {} 171 172 // Blocks and lambdas are handled as separate functions, so we need not 173 // traverse them in the parent context. 174 bool TraverseBlockExpr(BlockExpr *BE) { return true; } 175 bool TraverseLambdaExpr(LambdaExpr *LE) { 176 // Traverse the captures, but not the body. 177 for (auto C : zip(LE->captures(), LE->capture_inits())) 178 TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C)); 179 return true; 180 } 181 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } 182 183 bool VisitDecl(const Decl *D) { 184 switch (D->getKind()) { 185 default: 186 break; 187 case Decl::Function: 188 case Decl::CXXMethod: 189 case Decl::CXXConstructor: 190 case Decl::CXXDestructor: 191 case Decl::CXXConversion: 192 case Decl::ObjCMethod: 193 case Decl::Block: 194 case Decl::Captured: 195 CounterMap[D->getBody()] = NextCounter++; 196 break; 197 } 198 return true; 199 } 200 201 /// If \p S gets a fresh counter, update the counter mappings. Return the 202 /// V1 hash of \p S. 203 PGOHash::HashType updateCounterMappings(Stmt *S) { 204 auto Type = getHashType(PGO_HASH_V1, S); 205 if (Type != PGOHash::None) 206 CounterMap[S] = NextCounter++; 207 return Type; 208 } 209 210 /// The RHS of all logical operators gets a fresh counter in order to count 211 /// how many times the RHS evaluates to true or false, depending on the 212 /// semantics of the operator. This is only valid for ">= v7" of the profile 213 /// version so that we facilitate backward compatibility. 214 bool VisitBinaryOperator(BinaryOperator *S) { 215 if (ProfileVersion >= llvm::IndexedInstrProf::Version7) 216 if (S->isLogicalOp() && 217 CodeGenFunction::isInstrumentedCondition(S->getRHS())) 218 CounterMap[S->getRHS()] = NextCounter++; 219 return Base::VisitBinaryOperator(S); 220 } 221 222 /// Include \p S in the function hash. 223 bool VisitStmt(Stmt *S) { 224 auto Type = updateCounterMappings(S); 225 if (Hash.getHashVersion() != PGO_HASH_V1) 226 Type = getHashType(Hash.getHashVersion(), S); 227 if (Type != PGOHash::None) 228 Hash.combine(Type); 229 return true; 230 } 231 232 bool TraverseIfStmt(IfStmt *If) { 233 // If we used the V1 hash, use the default traversal. 234 if (Hash.getHashVersion() == PGO_HASH_V1) 235 return Base::TraverseIfStmt(If); 236 237 // Otherwise, keep track of which branch we're in while traversing. 238 VisitStmt(If); 239 for (Stmt *CS : If->children()) { 240 if (!CS) 241 continue; 242 if (CS == If->getThen()) 243 Hash.combine(PGOHash::IfThenBranch); 244 else if (CS == If->getElse()) 245 Hash.combine(PGOHash::IfElseBranch); 246 TraverseStmt(CS); 247 } 248 Hash.combine(PGOHash::EndOfScope); 249 return true; 250 } 251 252 // If the statement type \p N is nestable, and its nesting impacts profile 253 // stability, define a custom traversal which tracks the end of the statement 254 // in the hash (provided we're not using the V1 hash). 255 #define DEFINE_NESTABLE_TRAVERSAL(N) \ 256 bool Traverse##N(N *S) { \ 257 Base::Traverse##N(S); \ 258 if (Hash.getHashVersion() != PGO_HASH_V1) \ 259 Hash.combine(PGOHash::EndOfScope); \ 260 return true; \ 261 } 262 263 DEFINE_NESTABLE_TRAVERSAL(WhileStmt) 264 DEFINE_NESTABLE_TRAVERSAL(DoStmt) 265 DEFINE_NESTABLE_TRAVERSAL(ForStmt) 266 DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt) 267 DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt) 268 DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt) 269 DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt) 270 271 /// Get version \p HashVersion of the PGO hash for \p S. 272 PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) { 273 switch (S->getStmtClass()) { 274 default: 275 break; 276 case Stmt::LabelStmtClass: 277 return PGOHash::LabelStmt; 278 case Stmt::WhileStmtClass: 279 return PGOHash::WhileStmt; 280 case Stmt::DoStmtClass: 281 return PGOHash::DoStmt; 282 case Stmt::ForStmtClass: 283 return PGOHash::ForStmt; 284 case Stmt::CXXForRangeStmtClass: 285 return PGOHash::CXXForRangeStmt; 286 case Stmt::ObjCForCollectionStmtClass: 287 return PGOHash::ObjCForCollectionStmt; 288 case Stmt::SwitchStmtClass: 289 return PGOHash::SwitchStmt; 290 case Stmt::CaseStmtClass: 291 return PGOHash::CaseStmt; 292 case Stmt::DefaultStmtClass: 293 return PGOHash::DefaultStmt; 294 case Stmt::IfStmtClass: 295 return PGOHash::IfStmt; 296 case Stmt::CXXTryStmtClass: 297 return PGOHash::CXXTryStmt; 298 case Stmt::CXXCatchStmtClass: 299 return PGOHash::CXXCatchStmt; 300 case Stmt::ConditionalOperatorClass: 301 return PGOHash::ConditionalOperator; 302 case Stmt::BinaryConditionalOperatorClass: 303 return PGOHash::BinaryConditionalOperator; 304 case Stmt::BinaryOperatorClass: { 305 const BinaryOperator *BO = cast<BinaryOperator>(S); 306 if (BO->getOpcode() == BO_LAnd) 307 return PGOHash::BinaryOperatorLAnd; 308 if (BO->getOpcode() == BO_LOr) 309 return PGOHash::BinaryOperatorLOr; 310 if (HashVersion >= PGO_HASH_V2) { 311 switch (BO->getOpcode()) { 312 default: 313 break; 314 case BO_LT: 315 return PGOHash::BinaryOperatorLT; 316 case BO_GT: 317 return PGOHash::BinaryOperatorGT; 318 case BO_LE: 319 return PGOHash::BinaryOperatorLE; 320 case BO_GE: 321 return PGOHash::BinaryOperatorGE; 322 case BO_EQ: 323 return PGOHash::BinaryOperatorEQ; 324 case BO_NE: 325 return PGOHash::BinaryOperatorNE; 326 } 327 } 328 break; 329 } 330 } 331 332 if (HashVersion >= PGO_HASH_V2) { 333 switch (S->getStmtClass()) { 334 default: 335 break; 336 case Stmt::GotoStmtClass: 337 return PGOHash::GotoStmt; 338 case Stmt::IndirectGotoStmtClass: 339 return PGOHash::IndirectGotoStmt; 340 case Stmt::BreakStmtClass: 341 return PGOHash::BreakStmt; 342 case Stmt::ContinueStmtClass: 343 return PGOHash::ContinueStmt; 344 case Stmt::ReturnStmtClass: 345 return PGOHash::ReturnStmt; 346 case Stmt::CXXThrowExprClass: 347 return PGOHash::ThrowExpr; 348 case Stmt::UnaryOperatorClass: { 349 const UnaryOperator *UO = cast<UnaryOperator>(S); 350 if (UO->getOpcode() == UO_LNot) 351 return PGOHash::UnaryOperatorLNot; 352 break; 353 } 354 } 355 } 356 357 return PGOHash::None; 358 } 359 }; 360 361 /// A StmtVisitor that propagates the raw counts through the AST and 362 /// records the count at statements where the value may change. 363 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { 364 /// PGO state. 365 CodeGenPGO &PGO; 366 367 /// A flag that is set when the current count should be recorded on the 368 /// next statement, such as at the exit of a loop. 369 bool RecordNextStmtCount; 370 371 /// The count at the current location in the traversal. 372 uint64_t CurrentCount; 373 374 /// The map of statements to count values. 375 llvm::DenseMap<const Stmt *, uint64_t> &CountMap; 376 377 /// BreakContinueStack - Keep counts of breaks and continues inside loops. 378 struct BreakContinue { 379 uint64_t BreakCount; 380 uint64_t ContinueCount; 381 BreakContinue() : BreakCount(0), ContinueCount(0) {} 382 }; 383 SmallVector<BreakContinue, 8> BreakContinueStack; 384 385 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, 386 CodeGenPGO &PGO) 387 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} 388 389 void RecordStmtCount(const Stmt *S) { 390 if (RecordNextStmtCount) { 391 CountMap[S] = CurrentCount; 392 RecordNextStmtCount = false; 393 } 394 } 395 396 /// Set and return the current count. 397 uint64_t setCount(uint64_t Count) { 398 CurrentCount = Count; 399 return Count; 400 } 401 402 void VisitStmt(const Stmt *S) { 403 RecordStmtCount(S); 404 for (const Stmt *Child : S->children()) 405 if (Child) 406 this->Visit(Child); 407 } 408 409 void VisitFunctionDecl(const FunctionDecl *D) { 410 // Counter tracks entry to the function body. 411 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 412 CountMap[D->getBody()] = BodyCount; 413 Visit(D->getBody()); 414 } 415 416 // Skip lambda expressions. We visit these as FunctionDecls when we're 417 // generating them and aren't interested in the body when generating a 418 // parent context. 419 void VisitLambdaExpr(const LambdaExpr *LE) {} 420 421 void VisitCapturedDecl(const CapturedDecl *D) { 422 // Counter tracks entry to the capture body. 423 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 424 CountMap[D->getBody()] = BodyCount; 425 Visit(D->getBody()); 426 } 427 428 void VisitObjCMethodDecl(const ObjCMethodDecl *D) { 429 // Counter tracks entry to the method body. 430 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 431 CountMap[D->getBody()] = BodyCount; 432 Visit(D->getBody()); 433 } 434 435 void VisitBlockDecl(const BlockDecl *D) { 436 // Counter tracks entry to the block body. 437 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); 438 CountMap[D->getBody()] = BodyCount; 439 Visit(D->getBody()); 440 } 441 442 void VisitReturnStmt(const ReturnStmt *S) { 443 RecordStmtCount(S); 444 if (S->getRetValue()) 445 Visit(S->getRetValue()); 446 CurrentCount = 0; 447 RecordNextStmtCount = true; 448 } 449 450 void VisitCXXThrowExpr(const CXXThrowExpr *E) { 451 RecordStmtCount(E); 452 if (E->getSubExpr()) 453 Visit(E->getSubExpr()); 454 CurrentCount = 0; 455 RecordNextStmtCount = true; 456 } 457 458 void VisitGotoStmt(const GotoStmt *S) { 459 RecordStmtCount(S); 460 CurrentCount = 0; 461 RecordNextStmtCount = true; 462 } 463 464 void VisitLabelStmt(const LabelStmt *S) { 465 RecordNextStmtCount = false; 466 // Counter tracks the block following the label. 467 uint64_t BlockCount = setCount(PGO.getRegionCount(S)); 468 CountMap[S] = BlockCount; 469 Visit(S->getSubStmt()); 470 } 471 472 void VisitBreakStmt(const BreakStmt *S) { 473 RecordStmtCount(S); 474 assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); 475 BreakContinueStack.back().BreakCount += CurrentCount; 476 CurrentCount = 0; 477 RecordNextStmtCount = true; 478 } 479 480 void VisitContinueStmt(const ContinueStmt *S) { 481 RecordStmtCount(S); 482 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); 483 BreakContinueStack.back().ContinueCount += CurrentCount; 484 CurrentCount = 0; 485 RecordNextStmtCount = true; 486 } 487 488 void VisitWhileStmt(const WhileStmt *S) { 489 RecordStmtCount(S); 490 uint64_t ParentCount = CurrentCount; 491 492 BreakContinueStack.push_back(BreakContinue()); 493 // Visit the body region first so the break/continue adjustments can be 494 // included when visiting the condition. 495 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 496 CountMap[S->getBody()] = CurrentCount; 497 Visit(S->getBody()); 498 uint64_t BackedgeCount = CurrentCount; 499 500 // ...then go back and propagate counts through the condition. The count 501 // at the start of the condition is the sum of the incoming edges, 502 // the backedge from the end of the loop body, and the edges from 503 // continue statements. 504 BreakContinue BC = BreakContinueStack.pop_back_val(); 505 uint64_t CondCount = 506 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 507 CountMap[S->getCond()] = CondCount; 508 Visit(S->getCond()); 509 setCount(BC.BreakCount + CondCount - BodyCount); 510 RecordNextStmtCount = true; 511 } 512 513 void VisitDoStmt(const DoStmt *S) { 514 RecordStmtCount(S); 515 uint64_t LoopCount = PGO.getRegionCount(S); 516 517 BreakContinueStack.push_back(BreakContinue()); 518 // The count doesn't include the fallthrough from the parent scope. Add it. 519 uint64_t BodyCount = setCount(LoopCount + CurrentCount); 520 CountMap[S->getBody()] = BodyCount; 521 Visit(S->getBody()); 522 uint64_t BackedgeCount = CurrentCount; 523 524 BreakContinue BC = BreakContinueStack.pop_back_val(); 525 // The count at the start of the condition is equal to the count at the 526 // end of the body, plus any continues. 527 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount); 528 CountMap[S->getCond()] = CondCount; 529 Visit(S->getCond()); 530 setCount(BC.BreakCount + CondCount - LoopCount); 531 RecordNextStmtCount = true; 532 } 533 534 void VisitForStmt(const ForStmt *S) { 535 RecordStmtCount(S); 536 if (S->getInit()) 537 Visit(S->getInit()); 538 539 uint64_t ParentCount = CurrentCount; 540 541 BreakContinueStack.push_back(BreakContinue()); 542 // Visit the body region first. (This is basically the same as a while 543 // loop; see further comments in VisitWhileStmt.) 544 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 545 CountMap[S->getBody()] = BodyCount; 546 Visit(S->getBody()); 547 uint64_t BackedgeCount = CurrentCount; 548 BreakContinue BC = BreakContinueStack.pop_back_val(); 549 550 // The increment is essentially part of the body but it needs to include 551 // the count for all the continue statements. 552 if (S->getInc()) { 553 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 554 CountMap[S->getInc()] = IncCount; 555 Visit(S->getInc()); 556 } 557 558 // ...then go back and propagate counts through the condition. 559 uint64_t CondCount = 560 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 561 if (S->getCond()) { 562 CountMap[S->getCond()] = CondCount; 563 Visit(S->getCond()); 564 } 565 setCount(BC.BreakCount + CondCount - BodyCount); 566 RecordNextStmtCount = true; 567 } 568 569 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { 570 RecordStmtCount(S); 571 if (S->getInit()) 572 Visit(S->getInit()); 573 Visit(S->getLoopVarStmt()); 574 Visit(S->getRangeStmt()); 575 Visit(S->getBeginStmt()); 576 Visit(S->getEndStmt()); 577 578 uint64_t ParentCount = CurrentCount; 579 BreakContinueStack.push_back(BreakContinue()); 580 // Visit the body region first. (This is basically the same as a while 581 // loop; see further comments in VisitWhileStmt.) 582 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 583 CountMap[S->getBody()] = BodyCount; 584 Visit(S->getBody()); 585 uint64_t BackedgeCount = CurrentCount; 586 BreakContinue BC = BreakContinueStack.pop_back_val(); 587 588 // The increment is essentially part of the body but it needs to include 589 // the count for all the continue statements. 590 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); 591 CountMap[S->getInc()] = IncCount; 592 Visit(S->getInc()); 593 594 // ...then go back and propagate counts through the condition. 595 uint64_t CondCount = 596 setCount(ParentCount + BackedgeCount + BC.ContinueCount); 597 CountMap[S->getCond()] = CondCount; 598 Visit(S->getCond()); 599 setCount(BC.BreakCount + CondCount - BodyCount); 600 RecordNextStmtCount = true; 601 } 602 603 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { 604 RecordStmtCount(S); 605 Visit(S->getElement()); 606 uint64_t ParentCount = CurrentCount; 607 BreakContinueStack.push_back(BreakContinue()); 608 // Counter tracks the body of the loop. 609 uint64_t BodyCount = setCount(PGO.getRegionCount(S)); 610 CountMap[S->getBody()] = BodyCount; 611 Visit(S->getBody()); 612 uint64_t BackedgeCount = CurrentCount; 613 BreakContinue BC = BreakContinueStack.pop_back_val(); 614 615 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - 616 BodyCount); 617 RecordNextStmtCount = true; 618 } 619 620 void VisitSwitchStmt(const SwitchStmt *S) { 621 RecordStmtCount(S); 622 if (S->getInit()) 623 Visit(S->getInit()); 624 Visit(S->getCond()); 625 CurrentCount = 0; 626 BreakContinueStack.push_back(BreakContinue()); 627 Visit(S->getBody()); 628 // If the switch is inside a loop, add the continue counts. 629 BreakContinue BC = BreakContinueStack.pop_back_val(); 630 if (!BreakContinueStack.empty()) 631 BreakContinueStack.back().ContinueCount += BC.ContinueCount; 632 // Counter tracks the exit block of the switch. 633 setCount(PGO.getRegionCount(S)); 634 RecordNextStmtCount = true; 635 } 636 637 void VisitSwitchCase(const SwitchCase *S) { 638 RecordNextStmtCount = false; 639 // Counter for this particular case. This counts only jumps from the 640 // switch header and does not include fallthrough from the case before 641 // this one. 642 uint64_t CaseCount = PGO.getRegionCount(S); 643 setCount(CurrentCount + CaseCount); 644 // We need the count without fallthrough in the mapping, so it's more useful 645 // for branch probabilities. 646 CountMap[S] = CaseCount; 647 RecordNextStmtCount = true; 648 Visit(S->getSubStmt()); 649 } 650 651 void VisitIfStmt(const IfStmt *S) { 652 RecordStmtCount(S); 653 654 if (S->isConsteval()) { 655 const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse(); 656 if (Stm) 657 Visit(Stm); 658 return; 659 } 660 661 uint64_t ParentCount = CurrentCount; 662 if (S->getInit()) 663 Visit(S->getInit()); 664 Visit(S->getCond()); 665 666 // Counter tracks the "then" part of an if statement. The count for 667 // the "else" part, if it exists, will be calculated from this counter. 668 uint64_t ThenCount = setCount(PGO.getRegionCount(S)); 669 CountMap[S->getThen()] = ThenCount; 670 Visit(S->getThen()); 671 uint64_t OutCount = CurrentCount; 672 673 uint64_t ElseCount = ParentCount - ThenCount; 674 if (S->getElse()) { 675 setCount(ElseCount); 676 CountMap[S->getElse()] = ElseCount; 677 Visit(S->getElse()); 678 OutCount += CurrentCount; 679 } else 680 OutCount += ElseCount; 681 setCount(OutCount); 682 RecordNextStmtCount = true; 683 } 684 685 void VisitCXXTryStmt(const CXXTryStmt *S) { 686 RecordStmtCount(S); 687 Visit(S->getTryBlock()); 688 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) 689 Visit(S->getHandler(I)); 690 // Counter tracks the continuation block of the try statement. 691 setCount(PGO.getRegionCount(S)); 692 RecordNextStmtCount = true; 693 } 694 695 void VisitCXXCatchStmt(const CXXCatchStmt *S) { 696 RecordNextStmtCount = false; 697 // Counter tracks the catch statement's handler block. 698 uint64_t CatchCount = setCount(PGO.getRegionCount(S)); 699 CountMap[S] = CatchCount; 700 Visit(S->getHandlerBlock()); 701 } 702 703 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { 704 RecordStmtCount(E); 705 uint64_t ParentCount = CurrentCount; 706 Visit(E->getCond()); 707 708 // Counter tracks the "true" part of a conditional operator. The 709 // count in the "false" part will be calculated from this counter. 710 uint64_t TrueCount = setCount(PGO.getRegionCount(E)); 711 CountMap[E->getTrueExpr()] = TrueCount; 712 Visit(E->getTrueExpr()); 713 uint64_t OutCount = CurrentCount; 714 715 uint64_t FalseCount = setCount(ParentCount - TrueCount); 716 CountMap[E->getFalseExpr()] = FalseCount; 717 Visit(E->getFalseExpr()); 718 OutCount += CurrentCount; 719 720 setCount(OutCount); 721 RecordNextStmtCount = true; 722 } 723 724 void VisitBinLAnd(const BinaryOperator *E) { 725 RecordStmtCount(E); 726 uint64_t ParentCount = CurrentCount; 727 Visit(E->getLHS()); 728 // Counter tracks the right hand side of a logical and operator. 729 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 730 CountMap[E->getRHS()] = RHSCount; 731 Visit(E->getRHS()); 732 setCount(ParentCount + RHSCount - CurrentCount); 733 RecordNextStmtCount = true; 734 } 735 736 void VisitBinLOr(const BinaryOperator *E) { 737 RecordStmtCount(E); 738 uint64_t ParentCount = CurrentCount; 739 Visit(E->getLHS()); 740 // Counter tracks the right hand side of a logical or operator. 741 uint64_t RHSCount = setCount(PGO.getRegionCount(E)); 742 CountMap[E->getRHS()] = RHSCount; 743 Visit(E->getRHS()); 744 setCount(ParentCount + RHSCount - CurrentCount); 745 RecordNextStmtCount = true; 746 } 747 }; 748 } // end anonymous namespace 749 750 void PGOHash::combine(HashType Type) { 751 // Check that we never combine 0 and only have six bits. 752 assert(Type && "Hash is invalid: unexpected type 0"); 753 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types"); 754 755 // Pass through MD5 if enough work has built up. 756 if (Count && Count % NumTypesPerWord == 0) { 757 using namespace llvm::support; 758 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working); 759 MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 760 Working = 0; 761 } 762 763 // Accumulate the current type. 764 ++Count; 765 Working = Working << NumBitsPerType | Type; 766 } 767 768 uint64_t PGOHash::finalize() { 769 // Use Working as the hash directly if we never used MD5. 770 if (Count <= NumTypesPerWord) 771 // No need to byte swap here, since none of the math was endian-dependent. 772 // This number will be byte-swapped as required on endianness transitions, 773 // so we will see the same value on the other side. 774 return Working; 775 776 // Check for remaining work in Working. 777 if (Working) { 778 // Keep the buggy behavior from v1 and v2 for backward-compatibility. This 779 // is buggy because it converts a uint64_t into an array of uint8_t. 780 if (HashVersion < PGO_HASH_V3) { 781 MD5.update({(uint8_t)Working}); 782 } else { 783 using namespace llvm::support; 784 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working); 785 MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped))); 786 } 787 } 788 789 // Finalize the MD5 and return the hash. 790 llvm::MD5::MD5Result Result; 791 MD5.final(Result); 792 return Result.low(); 793 } 794 795 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) { 796 const Decl *D = GD.getDecl(); 797 if (!D->hasBody()) 798 return; 799 800 // Skip CUDA/HIP kernel launch stub functions. 801 if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice && 802 D->hasAttr<CUDAGlobalAttr>()) 803 return; 804 805 bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr(); 806 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 807 if (!InstrumentRegions && !PGOReader) 808 return; 809 if (D->isImplicit()) 810 return; 811 // Constructors and destructors may be represented by several functions in IR. 812 // If so, instrument only base variant, others are implemented by delegation 813 // to the base one, it would be counted twice otherwise. 814 if (CGM.getTarget().getCXXABI().hasConstructorVariants()) { 815 if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D)) 816 if (GD.getCtorType() != Ctor_Base && 817 CodeGenFunction::IsConstructorDelegationValid(CCD)) 818 return; 819 } 820 if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base) 821 return; 822 823 CGM.ClearUnusedCoverageMapping(D); 824 if (Fn->hasFnAttribute(llvm::Attribute::NoProfile)) 825 return; 826 if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile)) 827 return; 828 829 setFuncName(Fn); 830 831 mapRegionCounters(D); 832 if (CGM.getCodeGenOpts().CoverageMapping) 833 emitCounterRegionMapping(D); 834 if (PGOReader) { 835 SourceManager &SM = CGM.getContext().getSourceManager(); 836 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation())); 837 computeRegionCounts(D); 838 applyFunctionAttributes(PGOReader, Fn); 839 } 840 } 841 842 void CodeGenPGO::mapRegionCounters(const Decl *D) { 843 // Use the latest hash version when inserting instrumentation, but use the 844 // version in the indexed profile if we're reading PGO data. 845 PGOHashVersion HashVersion = PGO_HASH_LATEST; 846 uint64_t ProfileVersion = llvm::IndexedInstrProf::Version; 847 if (auto *PGOReader = CGM.getPGOReader()) { 848 HashVersion = getPGOHashVersion(PGOReader, CGM); 849 ProfileVersion = PGOReader->getVersion(); 850 } 851 852 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>); 853 MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap); 854 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 855 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD)); 856 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 857 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD)); 858 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 859 Walker.TraverseDecl(const_cast<BlockDecl *>(BD)); 860 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 861 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD)); 862 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl"); 863 NumRegionCounters = Walker.NextCounter; 864 FunctionHash = Walker.Hash.finalize(); 865 } 866 867 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) { 868 if (!D->getBody()) 869 return true; 870 871 // Skip host-only functions in the CUDA device compilation and device-only 872 // functions in the host compilation. Just roughly filter them out based on 873 // the function attributes. If there are effectively host-only or device-only 874 // ones, their coverage mapping may still be generated. 875 if (CGM.getLangOpts().CUDA && 876 ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() && 877 !D->hasAttr<CUDAGlobalAttr>()) || 878 (!CGM.getLangOpts().CUDAIsDevice && 879 (D->hasAttr<CUDAGlobalAttr>() || 880 (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>()))))) 881 return true; 882 883 // Don't map the functions in system headers. 884 const auto &SM = CGM.getContext().getSourceManager(); 885 auto Loc = D->getBody()->getBeginLoc(); 886 return SM.isInSystemHeader(Loc); 887 } 888 889 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) { 890 if (skipRegionMappingForDecl(D)) 891 return; 892 893 std::string CoverageMapping; 894 llvm::raw_string_ostream OS(CoverageMapping); 895 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 896 CGM.getContext().getSourceManager(), 897 CGM.getLangOpts(), RegionCounterMap.get()); 898 MappingGen.emitCounterMapping(D, OS); 899 OS.flush(); 900 901 if (CoverageMapping.empty()) 902 return; 903 904 CGM.getCoverageMapping()->addFunctionMappingRecord( 905 FuncNameVar, FuncName, FunctionHash, CoverageMapping); 906 } 907 908 void 909 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name, 910 llvm::GlobalValue::LinkageTypes Linkage) { 911 if (skipRegionMappingForDecl(D)) 912 return; 913 914 std::string CoverageMapping; 915 llvm::raw_string_ostream OS(CoverageMapping); 916 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(), 917 CGM.getContext().getSourceManager(), 918 CGM.getLangOpts()); 919 MappingGen.emitEmptyMapping(D, OS); 920 OS.flush(); 921 922 if (CoverageMapping.empty()) 923 return; 924 925 setFuncName(Name, Linkage); 926 CGM.getCoverageMapping()->addFunctionMappingRecord( 927 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false); 928 } 929 930 void CodeGenPGO::computeRegionCounts(const Decl *D) { 931 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>); 932 ComputeRegionCounts Walker(*StmtCountMap, *this); 933 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 934 Walker.VisitFunctionDecl(FD); 935 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 936 Walker.VisitObjCMethodDecl(MD); 937 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 938 Walker.VisitBlockDecl(BD); 939 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 940 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD)); 941 } 942 943 void 944 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, 945 llvm::Function *Fn) { 946 if (!haveRegionCounts()) 947 return; 948 949 uint64_t FunctionCount = getRegionCount(nullptr); 950 Fn->setEntryCount(FunctionCount); 951 } 952 953 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S, 954 llvm::Value *StepV) { 955 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap) 956 return; 957 if (!Builder.GetInsertBlock()) 958 return; 959 960 unsigned Counter = (*RegionCounterMap)[S]; 961 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 962 963 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), 964 Builder.getInt64(FunctionHash), 965 Builder.getInt32(NumRegionCounters), 966 Builder.getInt32(Counter), StepV}; 967 if (!StepV) 968 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment), 969 ArrayRef(Args, 4)); 970 else 971 Builder.CreateCall( 972 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step), 973 ArrayRef(Args)); 974 } 975 976 void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) { 977 if (CGM.getCodeGenOpts().hasProfileClangInstr()) 978 M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling", 979 uint32_t(EnableValueProfiling)); 980 } 981 982 // This method either inserts a call to the profile run-time during 983 // instrumentation or puts profile data into metadata for PGO use. 984 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind, 985 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) { 986 987 if (!EnableValueProfiling) 988 return; 989 990 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock()) 991 return; 992 993 if (isa<llvm::Constant>(ValuePtr)) 994 return; 995 996 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr(); 997 if (InstrumentValueSites && RegionCounterMap) { 998 auto BuilderInsertPoint = Builder.saveIP(); 999 Builder.SetInsertPoint(ValueSite); 1000 llvm::Value *Args[5] = { 1001 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()), 1002 Builder.getInt64(FunctionHash), 1003 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()), 1004 Builder.getInt32(ValueKind), 1005 Builder.getInt32(NumValueSites[ValueKind]++) 1006 }; 1007 Builder.CreateCall( 1008 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args); 1009 Builder.restoreIP(BuilderInsertPoint); 1010 return; 1011 } 1012 1013 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 1014 if (PGOReader && haveRegionCounts()) { 1015 // We record the top most called three functions at each call site. 1016 // Profile metadata contains "VP" string identifying this metadata 1017 // as value profiling data, then a uint32_t value for the value profiling 1018 // kind, a uint64_t value for the total number of times the call is 1019 // executed, followed by the function hash and execution count (uint64_t) 1020 // pairs for each function. 1021 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind)) 1022 return; 1023 1024 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord, 1025 (llvm::InstrProfValueKind)ValueKind, 1026 NumValueSites[ValueKind]); 1027 1028 NumValueSites[ValueKind]++; 1029 } 1030 } 1031 1032 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 1033 bool IsInMainFile) { 1034 CGM.getPGOStats().addVisited(IsInMainFile); 1035 RegionCounts.clear(); 1036 llvm::Expected<llvm::InstrProfRecord> RecordExpected = 1037 PGOReader->getInstrProfRecord(FuncName, FunctionHash); 1038 if (auto E = RecordExpected.takeError()) { 1039 auto IPE = llvm::InstrProfError::take(std::move(E)); 1040 if (IPE == llvm::instrprof_error::unknown_function) 1041 CGM.getPGOStats().addMissing(IsInMainFile); 1042 else if (IPE == llvm::instrprof_error::hash_mismatch) 1043 CGM.getPGOStats().addMismatched(IsInMainFile); 1044 else if (IPE == llvm::instrprof_error::malformed) 1045 // TODO: Consider a more specific warning for this case. 1046 CGM.getPGOStats().addMismatched(IsInMainFile); 1047 return; 1048 } 1049 ProfRecord = 1050 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get())); 1051 RegionCounts = ProfRecord->Counts; 1052 } 1053 1054 /// Calculate what to divide by to scale weights. 1055 /// 1056 /// Given the maximum weight, calculate a divisor that will scale all the 1057 /// weights to strictly less than UINT32_MAX. 1058 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 1059 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 1060 } 1061 1062 /// Scale an individual branch weight (and add 1). 1063 /// 1064 /// Scale a 64-bit weight down to 32-bits using \c Scale. 1065 /// 1066 /// According to Laplace's Rule of Succession, it is better to compute the 1067 /// weight based on the count plus 1, so universally add 1 to the value. 1068 /// 1069 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 1070 /// greater than \c Weight. 1071 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 1072 assert(Scale && "scale by 0?"); 1073 uint64_t Scaled = Weight / Scale + 1; 1074 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 1075 return Scaled; 1076 } 1077 1078 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 1079 uint64_t FalseCount) const { 1080 // Check for empty weights. 1081 if (!TrueCount && !FalseCount) 1082 return nullptr; 1083 1084 // Calculate how to scale down to 32-bits. 1085 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 1086 1087 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1088 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 1089 scaleBranchWeight(FalseCount, Scale)); 1090 } 1091 1092 llvm::MDNode * 1093 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const { 1094 // We need at least two elements to create meaningful weights. 1095 if (Weights.size() < 2) 1096 return nullptr; 1097 1098 // Check for empty weights. 1099 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 1100 if (MaxWeight == 0) 1101 return nullptr; 1102 1103 // Calculate how to scale down to 32-bits. 1104 uint64_t Scale = calculateWeightScale(MaxWeight); 1105 1106 SmallVector<uint32_t, 16> ScaledWeights; 1107 ScaledWeights.reserve(Weights.size()); 1108 for (uint64_t W : Weights) 1109 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 1110 1111 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1112 return MDHelper.createBranchWeights(ScaledWeights); 1113 } 1114 1115 llvm::MDNode * 1116 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 1117 uint64_t LoopCount) const { 1118 if (!PGO.haveRegionCounts()) 1119 return nullptr; 1120 std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 1121 if (!CondCount || *CondCount == 0) 1122 return nullptr; 1123 return createProfileWeights(LoopCount, 1124 std::max(*CondCount, LoopCount) - LoopCount); 1125 } 1126