1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Instrumentation-based profile-guided optimization 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CodeGenPGO.h" 14 #include "CodeGenFunction.h" 15 #include "CoverageMappingGen.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/IR/Intrinsics.h" 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/Support/CommandLine.h" 21 #include "llvm/Support/Endian.h" 22 #include "llvm/Support/FileSystem.h" 23 #include "llvm/Support/MD5.h" 24 25 static llvm::cl::opt<bool> 26 EnableValueProfiling("enable-value-profiling", 27 llvm::cl::desc("Enable value profiling"), 28 llvm::cl::Hidden, llvm::cl::init(false)); 29 30 using namespace clang; 31 using namespace CodeGen; 32 33 void CodeGenPGO::setFuncName(StringRef Name, 34 llvm::GlobalValue::LinkageTypes Linkage) { 35 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 36 FuncName = llvm::getPGOFuncName( 37 Name, Linkage, CGM.getCodeGenOpts().MainFileName, 38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version); 39 40 // If we're generating a profile, create a variable for the name. 41 if (CGM.getCodeGenOpts().hasProfileClangInstr()) 42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName); 43 } 44 45 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 46 setFuncName(Fn->getName(), Fn->getLinkage()); 47 // Create PGOFuncName meta data. 48 llvm::createPGOFuncNameMetadata(*Fn, FuncName); 49 } 50 51 /// The version of the PGO hash algorithm. 
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,
  PGO_HASH_V3,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};

namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note  When this hash does eventually change (years?), we still need to
/// support old hashes.  We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
class PGOHash {
  /// Hash types combined since the last MD5 update, packed NumBitsPerType
  /// bits apiece (see combine()).
  uint64_t Working;
  /// Number of hash types combined so far.
  unsigned Count;
  /// Which version of the hash algorithm this instance computes.
  PGOHashVersion HashVersion;
  /// Receives each full word of Working once it has filled up.
  llvm::MD5 MD5;

  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available since PGO_HASH_V2.

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  /// Start an empty hash that follows the rules of \p HashVersion.
  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion) {}
  /// Fold \p Type into the hash. \p Type must not be None.
  void combine(HashType Type);
  /// Produce the final 64-bit hash value.
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
// Out-of-class definitions for the static constants declared in PGOHash.
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;

/// Get the PGO hash version used in the given indexed profile.
///
/// Profile versions up to 4 map to PGO_HASH_V1, version 5 maps to
/// PGO_HASH_V2, and anything newer maps to PGO_HASH_V3. \p CGM is not used.
static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
                                        CodeGenModule &CGM) {
  if (PGOReader->getVersion() <= 4)
    return PGO_HASH_V1;
  if (PGOReader->getVersion() <= 5)
    return PGO_HASH_V2;
  return PGO_HASH_V3;
}

/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  /// The profile version.
  uint64_t ProfileVersion;

  /// \param HashVersion    hash algorithm version used for the function hash.
  /// \param ProfileVersion profile format version; gates the extra counters
  ///                       assigned in VisitBinaryOperator.
  /// \param CounterMap     output map, filled in during the traversal.
  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        ProfileVersion(ProfileVersion) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Give the body of a function-like declaration the next free counter.
  /// Other kinds of declarations are left alone.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    // Counter assignment always follows the V1 hash type, independent of the
    // hash version used for the function hash itself.
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only valid for ">= v7" of the profile
  /// version so that we facilitate backward compatibility.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
      if (S->isLogicalOp() &&
          CodeGenFunction::isInstrumentedCondition(S->getRHS()))
        CounterMap[S->getRHS()] = NextCounter++;
    return Base::VisitBinaryOperator(S);
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    // For hash versions after V1, recompute the type with the active version
    // so the additional node kinds take part in the hash.
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  /// Traverse an if statement, marking the then/else branches and the end of
  /// the statement in the hash when a hash version newer than V1 is in use.
  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// Returns PGOHash::None for statements that do not take part in the hash
  /// under the requested version.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    // Statement kinds recognized by every hash version (plus the comparison
    // operators, which are only recognized from V2 on).
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Statement kinds that are only recognized from V2 on.
    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};

/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  // NOTE(review): not initialized by the constructor; the Visit*Decl entry
  // points below assign it through setCount() before it is read.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  /// Record the current count on \p S if a record is pending, and clear the
  /// pending flag.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  /// Default handling: record a pending count on \p S, then visit each
  /// non-null child.
  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  /// After a return the running count drops to zero, and that count is
  /// recorded on the next statement visited.
  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  /// Handled like a return: the running count drops to zero afterwards.
  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  /// Handled like a return: the running count drops to zero afterwards.
  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    // The label has its own counter, so any pending record is dropped.
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  /// Add the running count to the innermost break total, then zero it.
  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  /// Add the running count to the innermost continue total, then zero it.
  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    // CurrentCount equals BodyCount here; setCount() just assigned it.
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // Count after the loop: breaks plus the times the condition was evaluated
    // without entering the body.
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    // The body starts at zero; each case adds its own counter when visited
    // (see VisitSwitchCase).
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);

    // For a consteval if, visit only one branch (chosen by
    // isNegatedConsteval()) and leave the running count untouched.
    if (S->isConsteval()) {
      const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
      if (Stm)
        Visit(Stm);
      return;
    }

    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
} // end anonymous namespace

/// Fold \p Type into the hash.
///
/// Types are packed NumBitsPerType bits apiece into Working. Once
/// NumTypesPerWord types have been packed, the word is passed to MD5 in
/// little-endian byte order and Working is reset before the next type is
/// added.
void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}

/// Produce the final hash value.
///
/// If no more than NumTypesPerWord types were combined, MD5 was never used
/// and Working is returned directly. Otherwise any remaining bits in Working
/// are fed to MD5 and the low 64 bits of the digest are returned.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working) {
    // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
    // is buggy because it converts a uint64_t into an array of uint8_t.
    if (HashVersion < PGO_HASH_V3) {
      MD5.update({(uint8_t)Working});
    } else {
      using namespace llvm::support;
      uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
      MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    }
  }

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  return Result.low();
}

/// Set up the PGO state for the function \p Fn generated from \p GD.
///
/// Does nothing when the declaration has no body, is a CUDA/HIP kernel stub
/// in a host compilation, is implicit, is a non-base constructor/destructor
/// variant covered by the checks below, or when neither instrumentation nor a
/// profile reader is active. Otherwise sets the function name, maps region
/// counters, optionally emits the coverage mapping, and, when a profile
/// reader is present, loads and propagates the counts.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  // Skip CUDA/HIP kernel launch stub functions.
  if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
      D->hasAttr<CUDAGlobalAttr>())
    return;

  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  // This runs before the NoProfile check, so it happens even for functions
  // that are not profiled.
  CGM.ClearUnusedCoverageMapping(D);
  if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
    return;

  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}

/// Walk \p D to assign a counter index to each instrumented statement and to
/// compute the function hash. Fills RegionCounterMap, NumRegionCounters and
/// FunctionHash.
void CodeGenPGO::mapRegionCounters(const Decl *D) {
  // Use the latest hash version when inserting instrumentation, but use the
  // version in the indexed profile if we're reading PGO data.
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
  uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
  if (auto *PGOReader = CGM.getPGOReader()) {
    HashVersion = getPGOHashVersion(PGOReader, CGM);
    ProfileVersion = PGOReader->getVersion();
  }

  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
  // Only these four declaration kinds are traversed; anything else leaves
  // the walker untouched and trips the assertion below.
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash = Walker.Hash.finalize();
}

/// Return true when no coverage mapping should be generated for \p D: it has
/// no body, it is filtered out by the CUDA host/device checks below, or its
/// body begins in a system header.
bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (!D->getBody())
    return true;

  // Skip host-only functions in the CUDA device compilation and device-only
  // functions in the host compilation. Just roughly filter them out based on
  // the function attributes. If there are effectively host-only or device-only
  // ones, their coverage mapping may still be generated.
  if (CGM.getLangOpts().CUDA &&
      ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
        !D->hasAttr<CUDAGlobalAttr>()) ||
       (!CGM.getLangOpts().CUDAIsDevice &&
        (D->hasAttr<CUDAGlobalAttr>() ||
         (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getBeginLoc();
  return SM.isInSystemHeader(Loc);
}

/// Generate the coverage mapping for \p D using the current RegionCounterMap
/// and, unless it comes out empty, add it as a function mapping record under
/// the current function name and hash.
void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts(), RegionCounterMap.get());
  MappingGen.emitCounterMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
}

/// Generate an empty coverage mapping for \p D (no counter map is supplied)
/// and, unless it comes out empty, add it as a function mapping record under
/// the name derived from \p Name and \p Linkage.
void
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
                                    llvm::GlobalValue::LinkageTypes Linkage) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts());
  MappingGen.emitEmptyMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  setFuncName(Name, Linkage);
  // NOTE(review): the trailing `false` presumably marks the record as not
  // used -- confirm against addFunctionMappingRecord's declaration.
  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
}

/// Propagate the loaded region counts through the body of \p D, filling
/// StmtCountMap. Only function, ObjC method, block and captured declarations
/// are handled.
void CodeGenPGO::computeRegionCounts(const Decl *D) {
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  ComputeRegionCounts Walker(*StmtCountMap, *this);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.VisitFunctionDecl(FD);
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.VisitObjCMethodDecl(MD);
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.VisitBlockDecl(BD);
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
}

/// Set the entry count of \p Fn from the loaded profile, if region counts are
/// available. \p PGOReader is not used here.
void
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
                                    llvm::Function *Fn) {
  if (!haveRegionCounts())
    return;

  // NOTE(review): presumably a null statement selects the function entry
  // counter -- confirm against getRegionCount in CodeGenPGO.h.
  uint64_t FunctionCount = getRegionCount(nullptr);
  Fn->setEntryCount(FunctionCount);
}

/// Emit a counter increment for the counter mapped to \p S at the builder's
/// current insertion point.
///
/// Emits a call to the instrprof_increment intrinsic, or to
/// instrprof_increment_step when \p StepV is non-null. Does nothing when
/// instrumentation is disabled, no counter map exists, or the builder has no
/// insertion block.
void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
                                      llvm::Value *StepV) {
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
    return;
  if (!Builder.GetInsertBlock())
    return;

  unsigned Counter = (*RegionCounterMap)[S];
  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());

  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
                         Builder.getInt64(FunctionHash),
                         Builder.getInt32(NumRegionCounters),
                         Builder.getInt32(Counter), StepV};
  // Without a step value, only the first four arguments are passed.
  if (!StepV)
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
                       makeArrayRef(Args, 4));
  else
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
        makeArrayRef(Args));
}

/// When instrumentation is enabled, record the value of the
/// -enable-value-profiling option in \p M as the "EnableValueProfiling"
/// module flag.
void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
    M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
                    uint32_t(EnableValueProfiling));
}

// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
    llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {

  if (!EnableValueProfiling)
    return;

  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  // Constant values are not profiled.
  if (isa<llvm::Constant>(ValuePtr))
    return;

  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    // Emit the profiling call at the value site, then restore the builder's
    // original insertion point.
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
    llvm::Value *Args[5] = {
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // We record the top most called three functions at each call site.
    // Profile metadata contains "VP" string identifying this metadata
    // as value profiling data, then a uint32_t value for the value profiling
    // kind, a uint64_t value for the total number of times the call is
    // executed, followed by the function hash and execution count (uint64_t)
    // pairs for each function.
    // Skip sites beyond those present in the profile record.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    NumValueSites[ValueKind]++;
  }
}

/// Load the counts recorded for the current function name and hash from
/// \p PGOReader into ProfRecord and RegionCounts.
///
/// On failure RegionCounts is left empty, and unknown-function,
/// hash-mismatch and malformed errors are tallied in the PGO statistics.
/// \p IsInMainFile is passed through to those statistics.
void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
                                  bool IsInMainFile) {
  CGM.getPGOStats().addVisited(IsInMainFile);
  RegionCounts.clear();
  llvm::Expected<llvm::InstrProfRecord> RecordExpected =
      PGOReader->getInstrProfRecord(FuncName, FunctionHash);
  if (auto E = RecordExpected.takeError()) {
    auto IPE = llvm::InstrProfError::take(std::move(E));
    if (IPE == llvm::instrprof_error::unknown_function)
      CGM.getPGOStats().addMissing(IsInMainFile);
    else if (IPE == llvm::instrprof_error::hash_mismatch)
      CGM.getPGOStats().addMismatched(IsInMainFile);
    else if (IPE == llvm::instrprof_error::malformed)
      // TODO: Consider a more specific warning for this case.
      CGM.getPGOStats().addMismatched(IsInMainFile);
    return;
  }
  ProfRecord =
      std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
  RegionCounts = ProfRecord->Counts;
}

/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
}

/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
1065 /// 1066 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 1067 /// greater than \c Weight. 1068 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 1069 assert(Scale && "scale by 0?"); 1070 uint64_t Scaled = Weight / Scale + 1; 1071 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 1072 return Scaled; 1073 } 1074 1075 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 1076 uint64_t FalseCount) const { 1077 // Check for empty weights. 1078 if (!TrueCount && !FalseCount) 1079 return nullptr; 1080 1081 // Calculate how to scale down to 32-bits. 1082 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 1083 1084 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1085 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 1086 scaleBranchWeight(FalseCount, Scale)); 1087 } 1088 1089 llvm::MDNode * 1090 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const { 1091 // We need at least two elements to create meaningful weights. 1092 if (Weights.size() < 2) 1093 return nullptr; 1094 1095 // Check for empty weights. 1096 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 1097 if (MaxWeight == 0) 1098 return nullptr; 1099 1100 // Calculate how to scale down to 32-bits. 
1101 uint64_t Scale = calculateWeightScale(MaxWeight); 1102 1103 SmallVector<uint32_t, 16> ScaledWeights; 1104 ScaledWeights.reserve(Weights.size()); 1105 for (uint64_t W : Weights) 1106 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 1107 1108 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1109 return MDHelper.createBranchWeights(ScaledWeights); 1110 } 1111 1112 llvm::MDNode * 1113 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 1114 uint64_t LoopCount) const { 1115 if (!PGO.haveRegionCounts()) 1116 return nullptr; 1117 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 1118 if (!CondCount || *CondCount == 0) 1119 return nullptr; 1120 return createProfileWeights(LoopCount, 1121 std::max(*CondCount, LoopCount) - LoopCount); 1122 } 1123