//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Instrumentation-based profile-guided optimization
//
//===----------------------------------------------------------------------===//

#include "CodeGenPGO.h"
#include "CodeGenFunction.h"
#include "CoverageMappingGen.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MD5.h"

static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));

using namespace clang;
using namespace CodeGen;

void CodeGenPGO::setFuncName(StringRef Name,
                             llvm::GlobalValue::LinkageTypes Linkage) {
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  FuncName = llvm::getPGOFuncName(
      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
      PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);

  // If we're generating a profile, create a variable for the name.
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
}

void CodeGenPGO::setFuncName(llvm::Function *Fn) {
  setFuncName(Fn->getName(), Fn->getLinkage());
  // Create PGOFuncName metadata.
  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
}

/// The version of the PGO hash algorithm.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V2
};

namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note When this hash does eventually change (years?), we still need to
/// support old hashes. We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
class PGOHash {
  uint64_t Working;
  unsigned Count;
  PGOHashVersion HashVersion;
  llvm::MD5 MD5;

  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable. All new members must be added at the end,
  /// and no members should be removed. Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available with PGO_HASH_V2.

    // Keep this last. It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  void combine(HashType Type);
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
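// Illustrative note (not part of the original documentation): each combined
// HashType occupies NumBitsPerType = 6 bits, so NumTypesPerWord = 10 types fit
// in one 64-bit word. For example, combining IfStmt (10) followed by
// EndOfScope (17) leaves Working == (10 << 6) | 17 == 657. A function with at
// most ten combined types hashes to this packed word directly; longer
// sequences spill Working through MD5 every ten types (see PGOHash::combine
// and PGOHash::finalize below).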
/// Get the PGO hash version used in the given indexed profile.
static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
                                        CodeGenModule &CGM) {
  if (PGOReader->getVersion() <= 4)
    return PGO_HASH_V1;
  return PGO_HASH_V2;
}

/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

  MapRegionCounters(PGOHashVersion HashVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (const auto &C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

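  // Illustrative note: counter numbering is decided here and is independent of
  // the hash version, because updateCounterMappings always consults the V1
  // hash types. For example, for `void f(int x) { if (x) ++x; }` the function
  // body receives counter 0 (via VisitDecl) and the IfStmt receives counter 1.
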
  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

  // If the statement type \p N is nestable, and its nesting impacts profile
  // stability, define a custom traversal which tracks the end of the statement
  // in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

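  // Illustrative note: tracking EndOfScope keeps nesting visible in the hash.
  // For example, under PGO_HASH_V2, `while (a) { while (b) {} }` combines
  // WhileStmt, WhileStmt, EndOfScope, EndOfScope, whereas two sibling loops
  // `while (a) {} while (b) {}` combine WhileStmt, EndOfScope, WhileStmt,
  // EndOfScope, so the two shapes hash differently.
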
  /// Get version \p HashVersion of the PGO hash for \p S.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion == PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    if (HashVersion == PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};

/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

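  // Illustrative note: a concrete (hypothetical) set of counts for the
  // while-loop math in VisitWhileStmt above. If the loop is entered 3 times
  // (ParentCount == 3) and the body runs 12 times (BodyCount == 12) with no
  // break, continue, or early exit, then BackedgeCount == 12,
  // CondCount == 3 + 12 == 15, and the count falling out of the loop is
  // CondCount - BodyCount == 3, as expected.
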
  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more
    // useful for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

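  // Illustrative note: a concrete (hypothetical) example of the if-statement
  // math in VisitIfStmt above. With ParentCount == 10 and a "then" region
  // count of 7, ElseCount == 10 - 7 == 3; if neither branch exits early,
  // OutCount ends up as 7 + 3 == 10, matching the count flowing into the
  // statement.
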
  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
} // end anonymous namespace

void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and that the type fits in six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}

uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working)
    MD5.update(Working);

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  using namespace llvm::support;
  return Result.low();
}

void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in
  // IR. If so, instrument only the base variant; the others are implemented
  // by delegation to the base one and would otherwise be counted twice.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  CGM.ClearUnusedCoverageMapping(D);
  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}

void CodeGenPGO::mapRegionCounters(const Decl *D) {
  // Use the latest hash version when inserting instrumentation, but use the
  // version in the indexed profile if we're reading PGO data.
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
  if (auto *PGOReader = CGM.getPGOReader())
    HashVersion = getPGOHashVersion(PGOReader, CGM);

  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(HashVersion, *RegionCounterMap);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash = Walker.Hash.finalize();
}

bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (!D->getBody())
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getBeginLoc();
  return SM.isInSystemHeader(Loc);
}

void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts(), RegionCounterMap.get());
  MappingGen.emitCounterMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
}

void
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
                                    llvm::GlobalValue::LinkageTypes Linkage) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts());
  MappingGen.emitEmptyMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  setFuncName(Name, Linkage);
  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
}

void CodeGenPGO::computeRegionCounts(const Decl *D) {
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  ComputeRegionCounts Walker(*StmtCountMap, *this);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.VisitFunctionDecl(FD);
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.VisitObjCMethodDecl(MD);
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.VisitBlockDecl(BD);
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
}

void
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
                                    llvm::Function *Fn) {
  if (!haveRegionCounts())
    return;

  uint64_t FunctionCount = getRegionCount(nullptr);
  Fn->setEntryCount(FunctionCount);
}

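// Illustrative note: the increment emitted below lowers to the LLVM instrprof
// intrinsics. Schematically, counter index I of a function with N counters
// becomes roughly
//   call void @llvm.instrprof.increment(i8* <FuncNameVar>, i64 <FunctionHash>,
//                                       i32 N, i32 I)
// and the step variant carries an extra i64 step operand.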
void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
                                      llvm::Value *StepV) {
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
    return;
  if (!Builder.GetInsertBlock())
    return;

  unsigned Counter = (*RegionCounterMap)[S];
  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());

  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
                         Builder.getInt64(FunctionHash),
                         Builder.getInt32(NumRegionCounters),
                         Builder.getInt32(Counter), StepV};
  if (!StepV)
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
                       makeArrayRef(Args, 4));
  else
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
        makeArrayRef(Args));
}

// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
                              llvm::Instruction *ValueSite,
                              llvm::Value *ValuePtr) {

  if (!EnableValueProfiling)
    return;

  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  if (isa<llvm::Constant>(ValuePtr))
    return;

  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
    llvm::Value *Args[5] = {
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // We record the three most frequently called functions at each call site.
    // Profile metadata contains "VP" string identifying this metadata
    // as value profiling data, then a uint32_t value for the value profiling
    // kind, a uint64_t value for the total number of times the call is
    // executed, followed by the function hash and execution count (uint64_t)
    // pairs for each function.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    NumValueSites[ValueKind]++;
  }
}

void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
                                  bool IsInMainFile) {
  CGM.getPGOStats().addVisited(IsInMainFile);
  RegionCounts.clear();
  llvm::Expected<llvm::InstrProfRecord> RecordExpected =
      PGOReader->getInstrProfRecord(FuncName, FunctionHash);
  if (auto E = RecordExpected.takeError()) {
    auto IPE = llvm::InstrProfError::take(std::move(E));
    if (IPE == llvm::instrprof_error::unknown_function)
      CGM.getPGOStats().addMissing(IsInMainFile);
    else if (IPE == llvm::instrprof_error::hash_mismatch)
      CGM.getPGOStats().addMismatched(IsInMainFile);
    else if (IPE == llvm::instrprof_error::malformed)
      // TODO: Consider a more specific warning for this case.
      CGM.getPGOStats().addMismatched(IsInMainFile);
    return;
  }
  ProfRecord =
      std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
  RegionCounts = ProfRecord->Counts;
}

/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
}

/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return Scaled;
}

llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
                                                    uint64_t FalseCount) {
  // Check for empty weights.
  if (!TrueCount && !FalseCount)
    return nullptr;

  // Calculate how to scale down to 32-bits.
  uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));

  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
                                      scaleBranchWeight(FalseCount, Scale));
}

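// Illustrative note: two hypothetical inputs for the weight scaling used by
// createProfileWeights above. With TrueCount == 5 and FalseCount == 0,
// Scale == 1 and the weights become {6, 1} (Laplace's +1). With
// TrueCount == 6,000,000,000 and FalseCount == 2,000,000,000,
// Scale == 6000000000 / 4294967295 + 1 == 2, giving weights
// {3000000001, 1000000001}, both of which fit in 32 bits.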
llvm::MDNode *
CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
  // We need at least two elements to create meaningful weights.
  if (Weights.size() < 2)
    return nullptr;

  // Check for empty weights.
  uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
  if (MaxWeight == 0)
    return nullptr;

  // Calculate how to scale down to 32-bits.
  uint64_t Scale = calculateWeightScale(MaxWeight);

  SmallVector<uint32_t, 16> ScaledWeights;
  ScaledWeights.reserve(Weights.size());
  for (uint64_t W : Weights)
    ScaledWeights.push_back(scaleBranchWeight(W, Scale));

  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  return MDHelper.createBranchWeights(ScaledWeights);
}

llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
                                                           uint64_t LoopCount) {
  if (!PGO.haveRegionCounts())
    return nullptr;
  Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
  assert(CondCount.hasValue() && "missing expected loop condition count");
  if (*CondCount == 0)
    return nullptr;
  return createProfileWeights(LoopCount,
                              std::max(*CondCount, LoopCount) - LoopCount);
}