1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Instrumentation-based profile-guided optimization 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CodeGenPGO.h" 14 #include "CodeGenFunction.h" 15 #include "CoverageMappingGen.h" 16 #include "clang/AST/RecursiveASTVisitor.h" 17 #include "clang/AST/StmtVisitor.h" 18 #include "llvm/IR/Intrinsics.h" 19 #include "llvm/IR/MDBuilder.h" 20 #include "llvm/Support/CommandLine.h" 21 #include "llvm/Support/Endian.h" 22 #include "llvm/Support/FileSystem.h" 23 #include "llvm/Support/MD5.h" 24 25 static llvm::cl::opt<bool> 26 EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore, 27 llvm::cl::desc("Enable value profiling"), 28 llvm::cl::Hidden, llvm::cl::init(false)); 29 30 using namespace clang; 31 using namespace CodeGen; 32 33 void CodeGenPGO::setFuncName(StringRef Name, 34 llvm::GlobalValue::LinkageTypes Linkage) { 35 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 36 FuncName = llvm::getPGOFuncName( 37 Name, Linkage, CGM.getCodeGenOpts().MainFileName, 38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version); 39 40 // If we're generating a profile, create a variable for the name. 41 if (CGM.getCodeGenOpts().hasProfileClangInstr()) 42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName); 43 } 44 45 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 46 setFuncName(Fn->getName(), Fn->getLinkage()); 47 // Create PGOFuncName meta data. 48 llvm::createPGOFuncNameMetadata(*Fn, FuncName); 49 } 50 51 /// The version of the PGO hash algorithm. 
///
/// When reading a profile, the version is chosen from the indexed profile's
/// format version by getPGOHashVersion(); when instrumenting, mapRegionCounters
/// uses PGO_HASH_LATEST.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,
  PGO_HASH_V3,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};

namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note  When this hash does eventually change (years?), we still need to
/// support old hashes.  We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
class PGOHash {
  // Types combined since the last MD5 flush, packed NumBitsPerType bits each
  // (see combine()).
  uint64_t Working;
  // Total number of types combined so far.
  unsigned Count;
  // Selects version-specific behavior (notably in finalize()).
  PGOHashVersion HashVersion;
  // Running digest; only fed once more than NumTypesPerWord types are seen.
  llvm::MD5 MD5;

  static const int NumBitsPerType = 6;
  // 64 / 6 == 10 types fit in one Working word.
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available since PGO_HASH_V2.

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  /// Append \p Type (must be non-zero and < TooBig) to the hash.
  void combine(HashType Type);
  /// Produce the final 64-bit hash value.
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;

/// Get the PGO hash version used in the given indexed profile.
///
/// Indexed format versions <= 4 map to V1, version 5 to V2, and anything
/// newer to V3. NOTE(review): \p CGM is currently unused.
static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
                                        CodeGenModule &CGM) {
  if (PGOReader->getVersion() <= 4)
    return PGO_HASH_V1;
  if (PGOReader->getVersion() <= 5)
    return PGO_HASH_V2;
  return PGO_HASH_V3;
}

/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
///
/// It simultaneously assigns counter indices (always based on the V1 set of
/// hash types) and accumulates the function hash (using the requested hash
/// version).
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

  MapRegionCounters(PGOHashVersion HashVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Give the body of each function-like declaration its own counter.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  ///
  /// Counter assignment is always keyed on the V1 type set, so which
  /// statements get counters does not depend on the hash version.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    // For V2+ re-classify S so the extra V2 node types are also hashed.
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    // Children are visited in If->children() order; a marker is mixed into
    // the hash just before the then/else branch, and EndOfScope after all.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

  // If the statement type \p N is nestable, and its nesting impacts profile
  // stability, define a custom traversal which tracks the end of the statement
  // in the hash (provided we're not using the V1 hash).
  // NOTE(review): the result of Base::Traverse##N is discarded; the generated
  // traversal always returns true.
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// Returns PGOHash::None for statements that contribute nothing to the hash
  /// (and, when called with PGO_HASH_V1, get no counter).
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      // Short-circuit operators are hashed in every version.
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      // Relational/equality operators only affect V2+ hashes.
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Additional node types that only contribute to V2+ hashes.
    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};

/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
///
/// NOTE(review): all count arithmetic is uint64_t. If the loaded profile's
/// counters are mutually inconsistent, subtractions such as
/// `CondCount - BodyCount` or `ParentCount - ThenCount` can wrap around.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  /// NOTE(review): not initialized by the constructor; the Visit*Decl entry
  /// points set it via setCount() before anything reads it.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  /// If a count was requested for the next statement, record it for \p S.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  /// Default: record a pending count, then visit all non-null children.
  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // return, throw and goto transfer control away: the code after them starts
  // at count 0, and whatever follows gets its count recorded.
  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  // break/continue hand the current count to the innermost enclosing
  // loop/switch entry; the enclosing visitor folds it back in.
  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // Exit count: breaks plus condition evaluations that did not enter the
    // body.
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    // The body is entered through case labels, each of which adds its own
    // counter in VisitSwitchCase, so the running count starts at 0.
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    // NOTE(review): wraps if the profile reports ThenCount > ParentCount.
    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  // NOTE(review): for VisitBinLAnd/VisitBinLOr the exit count below equals
  // ParentCount whenever the RHS leaves CurrentCount == RHSCount. If the RHS
  // can change the count (e.g. it contains a throw), the intended exit count
  // looks like `ParentCount - RHSCount + CurrentCount` instead -- confirm.
  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
} // end anonymous namespace

/// Pack \p Type into the working word, first flushing a full word (one that
/// already holds NumTypesPerWord types) into the MD5 state.
void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    // Feed the word in little-endian byte order so the digest does not depend
    // on host endianness.
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}

/// Return the hash: the packed Working word itself for functions with at most
/// NumTypesPerWord types, otherwise the low 64 bits of the MD5 digest.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working) {
    // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
    // is buggy because it converts a uint64_t into an array of uint8_t.
    // (Only the low byte of Working reaches MD5 here.)
    if (HashVersion < PGO_HASH_V3) {
      MD5.update({(uint8_t)Working});
    } else {
      using namespace llvm::support;
      uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
      MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    }
  }

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  return Result.low();
}

/// Set up counters for \p GD's body in \p Fn.
///
/// Does nothing unless clang instrumentation is enabled or a profile is
/// loaded. Otherwise it names the function and assigns counters/hash. It
/// emits the coverage mapping if requested and, when a profile reader exists,
/// loads the counts, propagates them over the AST and sets Fn's entry count.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  CGM.ClearUnusedCoverageMapping(D);
  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}

/// Walk \p D's body with MapRegionCounters, filling RegionCounterMap and
/// setting NumRegionCounters and FunctionHash.
void CodeGenPGO::mapRegionCounters(const Decl *D) {
  // Use the latest hash version when inserting instrumentation, but use the
  // version in the indexed profile if we're reading PGO data.
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
  if (auto *PGOReader = CGM.getPGOReader())
    HashVersion = getPGOHashVersion(PGOReader, CGM);

  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(HashVersion, *RegionCounterMap);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash = Walker.Hash.finalize();
}

/// True when no coverage mapping should be emitted for \p D: it has no body,
/// or its body begins in a system header.
bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (!D->getBody())
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getBeginLoc();
  return SM.isInSystemHeader(Loc);
}

/// Generate the coverage mapping for \p D using the counters from
/// mapRegionCounters and register it (if non-empty) with the module's
/// coverage mapping.
void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts(), RegionCounterMap.get());
  MappingGen.emitCounterMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
}

/// Like emitCounterRegionMapping, but emits an empty mapping (no counter map)
/// for \p D under \p Name / \p Linkage.
void
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
                                    llvm::GlobalValue::LinkageTypes Linkage) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts());
  MappingGen.emitEmptyMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  setFuncName(Name, Linkage);
  // NOTE(review): the trailing `false` presumably marks the record as not
  // used -- confirm against addFunctionMappingRecord's signature.
  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
}

/// Propagate the loaded region counts over \p D's body, filling StmtCountMap.
void CodeGenPGO::computeRegionCounts(const Decl *D) {
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  ComputeRegionCounts Walker(*StmtCountMap, *this);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.VisitFunctionDecl(FD);
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.VisitObjCMethodDecl(MD);
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.VisitBlockDecl(BD);
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    // NOTE(review): VisitCapturedDecl takes a const pointer; the const_cast
    // is unnecessary.
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
}

/// Set \p Fn's entry count from the loaded counts, if any were loaded.
/// NOTE(review): \p PGOReader is currently unused.
void
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
                                    llvm::Function *Fn) {
  if (!haveRegionCounts())
    return;

  uint64_t FunctionCount = getRegionCount(nullptr);
  Fn->setEntryCount(FunctionCount);
}

/// Emit a counter increment for \p S at the builder's insertion point:
/// llvm.instrprof.increment when \p StepV is null, otherwise
/// llvm.instrprof.increment.step with \p StepV as the step.
/// No-op unless clang instrumentation is on and counters were mapped.
void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
                                      llvm::Value *StepV) {
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
    return;
  if (!Builder.GetInsertBlock())
    return;

  // NOTE(review): DenseMap::operator[] default-inserts 0 if S was never
  // mapped, silently reusing counter 0.
  unsigned Counter = (*RegionCounterMap)[S];
  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());

  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
                         Builder.getInt64(FunctionHash),
                         Builder.getInt32(NumRegionCounters),
                         Builder.getInt32(Counter), StepV};
  // The plain increment intrinsic takes only the first four operands.
  if (!StepV)
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
                       makeArrayRef(Args, 4));
  else
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
        makeArrayRef(Args));
}

// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
929 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind, 930 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) { 931 932 if (!EnableValueProfiling) 933 return; 934 935 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock()) 936 return; 937 938 if (isa<llvm::Constant>(ValuePtr)) 939 return; 940 941 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr(); 942 if (InstrumentValueSites && RegionCounterMap) { 943 auto BuilderInsertPoint = Builder.saveIP(); 944 Builder.SetInsertPoint(ValueSite); 945 llvm::Value *Args[5] = { 946 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()), 947 Builder.getInt64(FunctionHash), 948 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()), 949 Builder.getInt32(ValueKind), 950 Builder.getInt32(NumValueSites[ValueKind]++) 951 }; 952 Builder.CreateCall( 953 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args); 954 Builder.restoreIP(BuilderInsertPoint); 955 return; 956 } 957 958 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 959 if (PGOReader && haveRegionCounts()) { 960 // We record the top most called three functions at each call site. 961 // Profile metadata contains "VP" string identifying this metadata 962 // as value profiling data, then a uint32_t value for the value profiling 963 // kind, a uint64_t value for the total number of times the call is 964 // executed, followed by the function hash and execution count (uint64_t) 965 // pairs for each function. 
966 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind)) 967 return; 968 969 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord, 970 (llvm::InstrProfValueKind)ValueKind, 971 NumValueSites[ValueKind]); 972 973 NumValueSites[ValueKind]++; 974 } 975 } 976 977 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 978 bool IsInMainFile) { 979 CGM.getPGOStats().addVisited(IsInMainFile); 980 RegionCounts.clear(); 981 llvm::Expected<llvm::InstrProfRecord> RecordExpected = 982 PGOReader->getInstrProfRecord(FuncName, FunctionHash); 983 if (auto E = RecordExpected.takeError()) { 984 auto IPE = llvm::InstrProfError::take(std::move(E)); 985 if (IPE == llvm::instrprof_error::unknown_function) 986 CGM.getPGOStats().addMissing(IsInMainFile); 987 else if (IPE == llvm::instrprof_error::hash_mismatch) 988 CGM.getPGOStats().addMismatched(IsInMainFile); 989 else if (IPE == llvm::instrprof_error::malformed) 990 // TODO: Consider a more specific warning for this case. 991 CGM.getPGOStats().addMismatched(IsInMainFile); 992 return; 993 } 994 ProfRecord = 995 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get())); 996 RegionCounts = ProfRecord->Counts; 997 } 998 999 /// Calculate what to divide by to scale weights. 1000 /// 1001 /// Given the maximum weight, calculate a divisor that will scale all the 1002 /// weights to strictly less than UINT32_MAX. 1003 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 1004 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 1005 } 1006 1007 /// Scale an individual branch weight (and add 1). 1008 /// 1009 /// Scale a 64-bit weight down to 32-bits using \c Scale. 1010 /// 1011 /// According to Laplace's Rule of Succession, it is better to compute the 1012 /// weight based on the count plus 1, so universally add 1 to the value. 1013 /// 1014 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 1015 /// greater than \c Weight. 
1016 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 1017 assert(Scale && "scale by 0?"); 1018 uint64_t Scaled = Weight / Scale + 1; 1019 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 1020 return Scaled; 1021 } 1022 1023 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 1024 uint64_t FalseCount) { 1025 // Check for empty weights. 1026 if (!TrueCount && !FalseCount) 1027 return nullptr; 1028 1029 // Calculate how to scale down to 32-bits. 1030 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 1031 1032 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1033 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 1034 scaleBranchWeight(FalseCount, Scale)); 1035 } 1036 1037 llvm::MDNode * 1038 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) { 1039 // We need at least two elements to create meaningful weights. 1040 if (Weights.size() < 2) 1041 return nullptr; 1042 1043 // Check for empty weights. 1044 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 1045 if (MaxWeight == 0) 1046 return nullptr; 1047 1048 // Calculate how to scale down to 32-bits. 1049 uint64_t Scale = calculateWeightScale(MaxWeight); 1050 1051 SmallVector<uint32_t, 16> ScaledWeights; 1052 ScaledWeights.reserve(Weights.size()); 1053 for (uint64_t W : Weights) 1054 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 1055 1056 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 1057 return MDHelper.createBranchWeights(ScaledWeights); 1058 } 1059 1060 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 1061 uint64_t LoopCount) { 1062 if (!PGO.haveRegionCounts()) 1063 return nullptr; 1064 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 1065 if (!CondCount || *CondCount == 0) 1066 return nullptr; 1067 return createProfileWeights(LoopCount, 1068 std::max(*CondCount, LoopCount) - LoopCount); 1069 } 1070