//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Instrumentation-based profile-guided optimization
//
//===----------------------------------------------------------------------===//

#include "CodeGenPGO.h"
#include "CodeGenFunction.h"
#include "CoverageMappingGen.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MD5.h"

// Hidden cc1-level flag that gates emission/consumption of value-profiling
// data in valueProfile() below; off by default.
static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));

using namespace clang;
using namespace CodeGen;

/// Compute and cache the PGO name for this function and, when instrumenting,
/// create the global variable that holds the name.
///
/// The name format depends on the profile version: when reading an existing
/// indexed profile we use its version so lookups match; otherwise we use the
/// current writer version.
void CodeGenPGO::setFuncName(StringRef Name,
                             llvm::GlobalValue::LinkageTypes Linkage) {
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  FuncName = llvm::getPGOFuncName(
      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
      PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);

  // If we're generating a profile, create a variable for the name.
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
}

/// Convenience overload: derive the name and linkage from \p Fn, and also
/// attach the PGOFuncName metadata to the IR function.
void CodeGenPGO::setFuncName(llvm::Function *Fn) {
  setFuncName(Fn->getName(), Fn->getLinkage());
  // Create PGOFuncName meta data.
  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
}

/// The version of the PGO hash algorithm.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,
  PGO_HASH_V3,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};

namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note When this hash does eventually change (years?), we still need to
/// support old hashes. We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
class PGOHash {
  uint64_t Working;           // Pending HashTypes, packed 6 bits at a time.
  unsigned Count;             // Total number of types combined so far.
  PGOHashVersion HashVersion; // Which hash algorithm variant to produce.
  llvm::MD5 MD5;              // Used only once Working overflows one word.

  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable. All new members must be added at the end,
  /// and no members should be removed. Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available since PGO_HASH_V2.

    // Keep this last. It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  void combine(HashType Type);
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
// Out-of-line definitions for the static constants (required when they are
// odr-used before C++17 inline variables).
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;

/// Get the PGO hash version used in the given indexed profile.
///
/// Profile format versions <= 4 used the V1 hash, version 5 used V2, and
/// everything newer uses V3.
static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
                                        CodeGenModule &CGM) {
  if (PGOReader->getVersion() <= 4)
    return PGO_HASH_V1;
  if (PGOReader->getVersion() <= 5)
    return PGO_HASH_V2;
  return PGO_HASH_V3;
}

/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// Counter indices are assigned in traversal order; both the index assignment
/// and the hash combination below are order-sensitive, so the traversal must
/// stay deterministic across compiler versions.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  /// The profile version.
  uint64_t ProfileVersion;

  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        ProfileVersion(ProfileVersion) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      // The body itself gets the first counter: the function entry count.
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only valid for ">= v7" of the profile
  /// version so that we facilitate backward compatibility.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
      if (S->isLogicalOp() &&
          CodeGenFunction::isInstrumentedCondition(S->getRHS()))
        CounterMap[S->getRHS()] = NextCounter++;
    return Base::VisitBinaryOperator(S);
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    // Counter assignment always uses V1 semantics so that counter indices
    // stay comparable across hash versions; only the hash input differs.
    auto Type = updateCounterMappings(S);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion >= PGO_HASH_V2) {
        // Comparison operators only contribute to the hash from V2 onwards.
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    if (HashVersion >= PGO_HASH_V2) {
      // Additional statement kinds hashed from V2 onwards.
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};

/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  /// Left uninitialized here: every entry point (the Visit*Decl methods)
  /// calls setCount() before any read.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  /// If a previous statement requested it, record the current count for \p S
  /// and clear the request.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    // Control never falls through a return; the count resets to zero until
    // the next statement with a recorded count.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    // Like return: control never falls through a throw.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    // The count leaving via break rejoins the enclosing loop/switch exit.
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    // The count leaving via continue rejoins the loop's condition/increment.
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // Loop exit count: condition evaluations that did not enter the body,
    // plus any breaks.
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    // The implicit range/begin/end statements execute once, with the parent
    // count, before the loop proper starts.
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    // Each case label resets the count; the fall-into-body count is zero.
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
} // end anonymous namespace

void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}

uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working) {
    // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
    // is buggy because it converts a uint64_t into an array of uint8_t.
    if (HashVersion < PGO_HASH_V3) {
      MD5.update({(uint8_t)Working});
    } else {
      using namespace llvm::support;
      uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
      MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    }
  }

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  return Result.low();
}

/// Decide whether \p Fn should participate in PGO (instrumentation and/or
/// profile use); if so, assign region counters, emit the coverage mapping,
/// and load any existing profile counts.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  // Skip CUDA/HIP kernel launch stub functions.
  if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
      D->hasAttr<CUDAGlobalAttr>())
    return;

  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  // Honor __attribute__((no_profile_instrument_function)).
  if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
    return;

  CGM.ClearUnusedCoverageMapping(D);
  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}

/// Walk \p D's body assigning counter indices to statements, and compute the
/// function's control-flow hash as a side effect.
void CodeGenPGO::mapRegionCounters(const Decl *D) {
  // Use the latest hash version when inserting instrumentation, but use the
  // version in the indexed profile if we're reading PGO data.
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
  uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
  if (auto *PGOReader = CGM.getPGOReader()) {
    HashVersion = getPGOHashVersion(PGOReader, CGM);
    ProfileVersion = PGOReader->getVersion();
  }

  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash = Walker.Hash.finalize();
}

/// Return true if no coverage mapping should be emitted for \p D.
bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (!D->getBody())
    return true;

  // Skip host-only functions in the CUDA device compilation and device-only
  // functions in the host compilation. Just roughly filter them out based on
  // the function attributes. If there are effectively host-only or device-only
  // ones, their coverage mapping may still be generated.
  if (CGM.getLangOpts().CUDA &&
      ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
        !D->hasAttr<CUDAGlobalAttr>()) ||
       (!CGM.getLangOpts().CUDAIsDevice &&
        (D->hasAttr<CUDAGlobalAttr>() ||
         (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getBeginLoc();
  return SM.isInSystemHeader(Loc);
}

/// Serialize \p D's coverage mapping (using the already-assigned region
/// counters) and register it with the module's coverage mapping emitter.
void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts(), RegionCounterMap.get());
  MappingGen.emitCounterMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
}

/// Register a counter-free ("empty") coverage mapping record for \p D.
/// Note the final `false` argument to addFunctionMappingRecord below, which
/// distinguishes this from the counter-bearing record above.
void
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
                                    llvm::GlobalValue::LinkageTypes Linkage) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  // No RegionCounterMap is passed: this mapping carries no counter data.
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts());
908 MappingGen.emitEmptyMapping(D, OS); 909 OS.flush(); 910 911 if (CoverageMapping.empty()) 912 return; 913 914 setFuncName(Name, Linkage); 915 CGM.getCoverageMapping()->addFunctionMappingRecord( 916 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false); 917 } 918 919 void CodeGenPGO::computeRegionCounts(const Decl *D) { 920 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>); 921 ComputeRegionCounts Walker(*StmtCountMap, *this); 922 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D)) 923 Walker.VisitFunctionDecl(FD); 924 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D)) 925 Walker.VisitObjCMethodDecl(MD); 926 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D)) 927 Walker.VisitBlockDecl(BD); 928 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D)) 929 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD)); 930 } 931 932 void 933 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, 934 llvm::Function *Fn) { 935 if (!haveRegionCounts()) 936 return; 937 938 uint64_t FunctionCount = getRegionCount(nullptr); 939 Fn->setEntryCount(FunctionCount); 940 } 941 942 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S, 943 llvm::Value *StepV) { 944 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap) 945 return; 946 if (!Builder.GetInsertBlock()) 947 return; 948 949 unsigned Counter = (*RegionCounterMap)[S]; 950 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); 951 952 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), 953 Builder.getInt64(FunctionHash), 954 Builder.getInt32(NumRegionCounters), 955 Builder.getInt32(Counter), StepV}; 956 if (!StepV) 957 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment), 958 makeArrayRef(Args, 4)); 959 else 960 Builder.CreateCall( 961 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step), 962 makeArrayRef(Args)); 963 } 964 965 // 
// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
                              llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {

  // Value profiling is gated behind the hidden -enable-value-profiling flag.
  if (!EnableValueProfiling)
    return;

  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  // A compile-time constant needs no value profiling.
  if (isa<llvm::Constant>(ValuePtr))
    return;

  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    // Instrumentation path: emit llvm.instrprof.value.profile immediately
    // before the value site, then restore the builder's insertion point.
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
    llvm::Value *Args[5] = {
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        // Post-increment: each call site gets the next site index for this
        // value kind.
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  // Profile-use path: annotate the value site with data from the loaded
  // profile record instead of emitting instrumentation.
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // We record the top most called three functions at each call site.
    // Profile metadata contains "VP" string identifying this metadata
    // as value profiling data, then a uint32_t value for the value profiling
    // kind, a uint64_t value for the total number of times the call is
    // executed, followed by the function hash and execution count (uint64_t)
    // pairs for each function.
    // Stop once we run past the number of value sites the profile recorded
    // for this kind; the profile and the IR would be out of sync.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    NumValueSites[ValueKind]++;
  }
}

/// Load this function's profile record from \p PGOReader, keyed by FuncName
/// and FunctionHash. On success, ProfRecord and RegionCounts are populated;
/// on failure, the error kind is folded into the module's PGO statistics
/// (missing vs. mismatched) and the counts stay empty.
void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
                                  bool IsInMainFile) {
  CGM.getPGOStats().addVisited(IsInMainFile);
  RegionCounts.clear();
  llvm::Expected<llvm::InstrProfRecord> RecordExpected =
      PGOReader->getInstrProfRecord(FuncName, FunctionHash);
  if (auto E = RecordExpected.takeError()) {
    auto IPE = llvm::InstrProfError::take(std::move(E));
    if (IPE == llvm::instrprof_error::unknown_function)
      CGM.getPGOStats().addMissing(IsInMainFile);
    else if (IPE == llvm::instrprof_error::hash_mismatch)
      CGM.getPGOStats().addMismatched(IsInMainFile);
    else if (IPE == llvm::instrprof_error::malformed)
      // TODO: Consider a more specific warning for this case.
      CGM.getPGOStats().addMismatched(IsInMainFile);
    return;
  }
  ProfRecord =
      std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
  RegionCounts = ProfRecord->Counts;
}

/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
}

/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  // With the precondition above, Weight / Scale < UINT32_MAX, so the +1
  // cannot overflow 32 bits; the assert double-checks that.
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return Scaled;
}

/// Build "branch_weights" metadata for a two-way branch from the counts of
/// the true and false edges. Returns null when both counts are zero.
llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
                                                    uint64_t FalseCount) const {
  // Check for empty weights.
  if (!TrueCount && !FalseCount)
    return nullptr;

  // Calculate how to scale down to 32-bits.
  uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));

  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
                                      scaleBranchWeight(FalseCount, Scale));
}

/// Build "branch_weights" metadata for a multi-way branch (e.g. a switch)
/// from a list of 64-bit counts. Returns null when there are fewer than two
/// weights or all weights are zero.
llvm::MDNode *
CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
  // We need at least two elements to create meaningful weights.
  if (Weights.size() < 2)
    return nullptr;

  // Check for empty weights.
  uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
  if (MaxWeight == 0)
    return nullptr;

  // Calculate how to scale down to 32-bits.
1087 uint64_t Scale = calculateWeightScale(MaxWeight); 1088 1089 SmallVector<uint32_t, 16> ScaledWeights; 1090 ScaledWeights.reserve(Weights.size()); 1091 for (uint64_t W : Weights) 1092 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 1093 1094 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1095 return MDHelper.createBranchWeights(ScaledWeights); 1096 } 1097 1098 llvm::MDNode * 1099 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 1100 uint64_t LoopCount) const { 1101 if (!PGO.haveRegionCounts()) 1102 return nullptr; 1103 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 1104 if (!CondCount || *CondCount == 0) 1105 return nullptr; 1106 return createProfileWeights(LoopCount, 1107 std::max(*CondCount, LoopCount) - LoopCount); 1108 } 1109