1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Instrumentation-based profile-guided optimization 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CodeGenPGO.h" 14 #include "CodeGenFunction.h" 15 #include "CoverageMappingGen.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/IR/Intrinsics.h" 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/Support/CommandLine.h" 21 #include "llvm/Support/Endian.h" 22 #include "llvm/Support/FileSystem.h" 23 #include "llvm/Support/MD5.h" 24 25 static llvm::cl::opt<bool> 26 EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore, 27 llvm::cl::desc("Enable value profiling"), 28 llvm::cl::Hidden, llvm::cl::init(false)); 29 30 using namespace clang; 31 using namespace CodeGen; 32 33 void CodeGenPGO::setFuncName(StringRef Name, 34 llvm::GlobalValue::LinkageTypes Linkage) { 35 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 36 FuncName = llvm::getPGOFuncName( 37 Name, Linkage, CGM.getCodeGenOpts().MainFileName, 38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version); 39 40 // If we're generating a profile, create a variable for the name. 41 if (CGM.getCodeGenOpts().hasProfileClangInstr()) 42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName); 43 } 44 45 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 46 setFuncName(Fn->getName(), Fn->getLinkage()); 47 // Create PGOFuncName meta data. 48 llvm::createPGOFuncNameMetadata(*Fn, FuncName); 49 } 50 51 /// The version of the PGO hash algorithm. 
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V2
};

namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note  When this hash does eventually change (years?), we still need to
/// support old hashes.  We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
class PGOHash {
  /// The word that combine() packs NumBitsPerType-bit type codes into.
  uint64_t Working;
  /// The number of types combined so far.
  unsigned Count;
  /// The hash version this hasher was constructed with.
  PGOHashVersion HashVersion;
  /// Receives each full Working word once it has been filled by combine().
  llvm::MD5 MD5;

  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available with PGO_HASH_V2.

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  /// Create an empty hasher that reports \p HashVersion from
  /// getHashVersion().
  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  /// Append \p Type to the hash.
  void combine(HashType Type);
  /// Return the hash of everything combined so far.
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;

/// Get the PGO hash version used in the given indexed profile.
///
/// Readers reporting a version of 4 or lower map to PGO_HASH_V1; every later
/// version maps to PGO_HASH_V2. \p CGM is not used.
static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
                                        CodeGenModule &CGM) {
  if (PGOReader->getVersion() <= 4)
    return PGO_HASH_V1;
  return PGO_HASH_V2;
}

/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

  MapRegionCounters(PGOHashVersion HashVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Assign the next counter to the body of a function-like declaration
  /// (function, method, constructor, destructor, conversion, ObjC method,
  /// block or captured decl). Other declaration kinds get no counter.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    // Counter assignment always keys off the V1 hash type, whatever hash
    // version is in use for the function hash itself.
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    // For newer hash versions, re-query the type so that the statements only
    // known to those versions also contribute to the hash.
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// Returns PGOHash::None for statements that do not take part in the hash
  /// at the requested version.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    // Statement classes recognized by every hash version (plus the comparison
    // operators, which are gated on V2 inside the BinaryOperator case).
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion == PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Statement classes that only the V2 hash distinguishes.
    if (HashVersion == PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};

/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  // CurrentCount is not initialized here; the Visit*Decl entry points below
  // set it through setCount() before visiting the body.
  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  /// Record the current count for \p S if a preceding statement requested it,
  /// and clear the request.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  /// Default handling: record a pending count and visit all non-null children.
  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    // Nothing falls through a return: zero the count and have the next
    // statement record it.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    // Same treatment as a return statement.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    // Same treatment as a return statement.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    // The label has its own counter, so any pending request is dropped.
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    // Credit the current count to the innermost loop/switch's break total.
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    // Credit the current count to the innermost entry's continue total.
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    // CurrentCount equals BodyCount here; setCount() just assigned it.
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // Count after the loop: breaks, plus condition evaluations that did not
    // enter the body.
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // Count after the loop: breaks, plus condition evaluations that did not
    // loop back (LoopCount excludes the initial fallthrough entry).
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // No separate condition is visited for this loop; the exit count is
    // formed directly from the same terms the other loops use.
    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    // The body starts at zero; VisitSwitchCase adds each case's own count.
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
} // end anonymous namespace

/// Append \p Type to the hash.
///
/// Types are packed NumBitsPerType bits at a time into Working. Once
/// NumTypesPerWord types have accumulated, the next call first feeds the
/// full word to MD5 (as little-endian bytes) and starts a new word.
void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}

/// Return the hash of everything combined so far.
///
/// When no more than NumTypesPerWord types were combined, the packed word is
/// returned as-is; otherwise the result is the low 64 bits of the MD5 digest.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  // NOTE(review): unlike combine(), which passes an explicit 8-byte array,
  // this passes the uint64_t itself. Presumably it converts implicitly to a
  // one-element ArrayRef<uint8_t>, so only the low byte of the last partial
  // word would reach MD5 -- confirm against llvm::MD5::update. Changing this
  // alters the hash of affected functions, so it needs a new hash version
  // rather than an in-place fix.
  if (Working)
    MD5.update(Working);

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  using namespace llvm::support;
  return Result.low();
}

/// Set up PGO state for the function \p Fn generated from \p GD.
///
/// Does nothing for declarations without a body, when neither clang
/// instrumentation nor a profile reader is active, for implicit declarations,
/// and for the constructor/destructor variants filtered out below. Otherwise
/// it names the function, maps its region counters, optionally emits the
/// coverage mapping, and, when a profile reader is present, loads and applies
/// the recorded counts.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  CGM.ClearUnusedCoverageMapping(D);
  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}

/// Walk \p D to build the statement-to-counter map, the number of region
/// counters, and the function hash.
void CodeGenPGO::mapRegionCounters(const Decl *D) {
  // Use the latest hash version when inserting instrumentation, but use the
  // version in the indexed profile if we're reading PGO data.
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
  if (auto *PGOReader = CGM.getPGOReader())
    HashVersion = getPGOHashVersion(PGOReader, CGM);

  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(HashVersion, *RegionCounterMap);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash = Walker.Hash.finalize();
}

/// Return true if no coverage mapping should be emitted for \p D: it has no
/// body, or its body begins in a system header.
bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (!D->getBody())
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getBeginLoc();
  return SM.isInSystemHeader(Loc);
}

/// Generate the coverage mapping for \p D from the region counter map and,
/// if it is non-empty, register it as a function mapping record.
void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts(), RegionCounterMap.get());
  MappingGen.emitCounterMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
}

/// Generate an empty coverage mapping for \p D and, if the result is
/// non-empty, name the function from \p Name and \p Linkage and register the
/// record.
void
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
                                    llvm::GlobalValue::LinkageTypes Linkage) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts());
  MappingGen.emitEmptyMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  setFuncName(Name, Linkage);
  // NOTE(review): the trailing 'false' presumably marks the record as not
  // used -- confirm against addFunctionMappingRecord's declaration.
  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
}

/// Build the statement-to-count map for \p D by running ComputeRegionCounts
/// from the entry point that matches the declaration's kind.
void CodeGenPGO::computeRegionCounts(const Decl *D) {
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  ComputeRegionCounts Walker(*StmtCountMap, *this);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.VisitFunctionDecl(FD);
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.VisitObjCMethodDecl(MD);
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.VisitBlockDecl(BD);
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
}

/// Set \p Fn's entry count from the loaded profile, if region counts are
/// available. \p PGOReader is not used.
void
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
                                    llvm::Function *Fn) {
  if (!haveRegionCounts())
    return;

  // NOTE(review): nullptr is not a statement this file maps explicitly;
  // presumably getRegionCount() resolves it to the entry counter -- confirm
  // against its definition.
  uint64_t FunctionCount = getRegionCount(nullptr);
  Fn->setEntryCount(FunctionCount);
}

/// Emit a counter increment for \p S at the builder's insertion point.
///
/// Does nothing unless clang instrumentation is enabled, a region counter map
/// exists, and the builder has an insertion block. Calls instrprof_increment
/// when \p StepV is null and instrprof_increment_step with \p StepV otherwise.
void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
                                      llvm::Value *StepV) {
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
    return;
  if (!Builder.GetInsertBlock())
    return;

  unsigned Counter = (*RegionCounterMap)[S];
  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());

  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
                         Builder.getInt64(FunctionHash),
                         Builder.getInt32(NumRegionCounters),
                         Builder.getInt32(Counter), StepV};
  // Without a step only the first four arguments are passed.
  if (!StepV)
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
                       makeArrayRef(Args, 4));
  else
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
        makeArrayRef(Args));
}

// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
918 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind, 919 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) { 920 921 if (!EnableValueProfiling) 922 return; 923 924 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock()) 925 return; 926 927 if (isa<llvm::Constant>(ValuePtr)) 928 return; 929 930 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr(); 931 if (InstrumentValueSites && RegionCounterMap) { 932 auto BuilderInsertPoint = Builder.saveIP(); 933 Builder.SetInsertPoint(ValueSite); 934 llvm::Value *Args[5] = { 935 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()), 936 Builder.getInt64(FunctionHash), 937 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()), 938 Builder.getInt32(ValueKind), 939 Builder.getInt32(NumValueSites[ValueKind]++) 940 }; 941 Builder.CreateCall( 942 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args); 943 Builder.restoreIP(BuilderInsertPoint); 944 return; 945 } 946 947 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 948 if (PGOReader && haveRegionCounts()) { 949 // We record the top most called three functions at each call site. 950 // Profile metadata contains "VP" string identifying this metadata 951 // as value profiling data, then a uint32_t value for the value profiling 952 // kind, a uint64_t value for the total number of times the call is 953 // executed, followed by the function hash and execution count (uint64_t) 954 // pairs for each function. 
955 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind)) 956 return; 957 958 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord, 959 (llvm::InstrProfValueKind)ValueKind, 960 NumValueSites[ValueKind]); 961 962 NumValueSites[ValueKind]++; 963 } 964 } 965 966 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 967 bool IsInMainFile) { 968 CGM.getPGOStats().addVisited(IsInMainFile); 969 RegionCounts.clear(); 970 llvm::Expected<llvm::InstrProfRecord> RecordExpected = 971 PGOReader->getInstrProfRecord(FuncName, FunctionHash); 972 if (auto E = RecordExpected.takeError()) { 973 auto IPE = llvm::InstrProfError::take(std::move(E)); 974 if (IPE == llvm::instrprof_error::unknown_function) 975 CGM.getPGOStats().addMissing(IsInMainFile); 976 else if (IPE == llvm::instrprof_error::hash_mismatch) 977 CGM.getPGOStats().addMismatched(IsInMainFile); 978 else if (IPE == llvm::instrprof_error::malformed) 979 // TODO: Consider a more specific warning for this case. 980 CGM.getPGOStats().addMismatched(IsInMainFile); 981 return; 982 } 983 ProfRecord = 984 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get())); 985 RegionCounts = ProfRecord->Counts; 986 } 987 988 /// Calculate what to divide by to scale weights. 989 /// 990 /// Given the maximum weight, calculate a divisor that will scale all the 991 /// weights to strictly less than UINT32_MAX. 992 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 993 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 994 } 995 996 /// Scale an individual branch weight (and add 1). 997 /// 998 /// Scale a 64-bit weight down to 32-bits using \c Scale. 999 /// 1000 /// According to Laplace's Rule of Succession, it is better to compute the 1001 /// weight based on the count plus 1, so universally add 1 to the value. 1002 /// 1003 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 1004 /// greater than \c Weight. 
1005 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 1006 assert(Scale && "scale by 0?"); 1007 uint64_t Scaled = Weight / Scale + 1; 1008 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 1009 return Scaled; 1010 } 1011 1012 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 1013 uint64_t FalseCount) { 1014 // Check for empty weights. 1015 if (!TrueCount && !FalseCount) 1016 return nullptr; 1017 1018 // Calculate how to scale down to 32-bits. 1019 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 1020 1021 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1022 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 1023 scaleBranchWeight(FalseCount, Scale)); 1024 } 1025 1026 llvm::MDNode * 1027 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) { 1028 // We need at least two elements to create meaningful weights. 1029 if (Weights.size() < 2) 1030 return nullptr; 1031 1032 // Check for empty weights. 1033 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 1034 if (MaxWeight == 0) 1035 return nullptr; 1036 1037 // Calculate how to scale down to 32-bits. 
1038 uint64_t Scale = calculateWeightScale(MaxWeight); 1039 1040 SmallVector<uint32_t, 16> ScaledWeights; 1041 ScaledWeights.reserve(Weights.size()); 1042 for (uint64_t W : Weights) 1043 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 1044 1045 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1046 return MDHelper.createBranchWeights(ScaledWeights); 1047 } 1048 1049 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 1050 uint64_t LoopCount) { 1051 if (!PGO.haveRegionCounts()) 1052 return nullptr; 1053 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 1054 assert(CondCount.hasValue() && "missing expected loop condition count"); 1055 if (*CondCount == 0) 1056 return nullptr; 1057 return createProfileWeights(LoopCount, 1058 std::max(*CondCount, LoopCount) - LoopCount); 1059 } 1060