1 //===- SampleContextTracker.cpp - Context-sensitive Profile Tracker -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the SampleContextTracker used by CSSPGO. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/Transforms/IPO/SampleContextTracker.h" 14 #include "llvm/ADT/StringMap.h" 15 #include "llvm/ADT/StringRef.h" 16 #include "llvm/IR/DebugInfoMetadata.h" 17 #include "llvm/IR/Instructions.h" 18 #include "llvm/ProfileData/SampleProf.h" 19 #include <map> 20 #include <queue> 21 #include <vector> 22 23 using namespace llvm; 24 using namespace sampleprof; 25 26 #define DEBUG_TYPE "sample-context-tracker" 27 28 namespace llvm { 29 30 ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite, 31 StringRef CalleeName) { 32 if (CalleeName.empty()) 33 return getHottestChildContext(CallSite); 34 35 uint32_t Hash = nodeHash(CalleeName, CallSite); 36 auto It = AllChildContext.find(Hash); 37 if (It != AllChildContext.end()) 38 return &It->second; 39 return nullptr; 40 } 41 42 ContextTrieNode * 43 ContextTrieNode::getHottestChildContext(const LineLocation &CallSite) { 44 // CSFDO-TODO: This could be slow, change AllChildContext so we can 45 // do point look up for child node by call site alone. 46 // Retrieve the child node with max count for indirect call 47 ContextTrieNode *ChildNodeRet = nullptr; 48 uint64_t MaxCalleeSamples = 0; 49 for (auto &It : AllChildContext) { 50 ContextTrieNode &ChildNode = It.second; 51 if (ChildNode.CallSiteLoc != CallSite) 52 continue; 53 FunctionSamples *Samples = ChildNode.getFunctionSamples(); 54 if (!Samples) 55 continue; 56 if (Samples->getTotalSamples() > MaxCalleeSamples) { 57 ChildNodeRet = &ChildNode; 58 MaxCalleeSamples = Samples->getTotalSamples(); 59 } 60 } 61 62 return ChildNodeRet; 63 } 64 65 ContextTrieNode &ContextTrieNode::moveToChildContext( 66 const LineLocation &CallSite, ContextTrieNode &&NodeToMove, 67 StringRef ContextStrToRemove, bool DeleteNode) { 68 uint32_t Hash = nodeHash(NodeToMove.getFuncName(), CallSite); 69 assert(!AllChildContext.count(Hash) && "Node to remove must exist"); 70 LineLocation OldCallSite = NodeToMove.CallSiteLoc; 71 ContextTrieNode &OldParentContext = *NodeToMove.getParentContext(); 72 AllChildContext[Hash] = NodeToMove; 73 ContextTrieNode &NewNode = AllChildContext[Hash]; 74 NewNode.CallSiteLoc = CallSite; 75 76 // Walk through nodes in the moved the subtree, and update 77 // FunctionSamples' context as for the context promotion. 78 // We also need to set new parant link for all children. 79 std::queue<ContextTrieNode *> NodeToUpdate; 80 NewNode.setParentContext(this); 81 NodeToUpdate.push(&NewNode); 82 83 while (!NodeToUpdate.empty()) { 84 ContextTrieNode *Node = NodeToUpdate.front(); 85 NodeToUpdate.pop(); 86 FunctionSamples *FSamples = Node->getFunctionSamples(); 87 88 if (FSamples) { 89 FSamples->getContext().promoteOnPath(ContextStrToRemove); 90 FSamples->getContext().setState(SyntheticContext); 91 LLVM_DEBUG(dbgs() << " Context promoted to: " << FSamples->getContext() 92 << "\n"); 93 } 94 95 for (auto &It : Node->getAllChildContext()) { 96 ContextTrieNode *ChildNode = &It.second; 97 ChildNode->setParentContext(Node); 98 NodeToUpdate.push(ChildNode); 99 } 100 } 101 102 // Original context no longer needed, destroy if requested. 103 if (DeleteNode) 104 OldParentContext.removeChildContext(OldCallSite, NewNode.getFuncName()); 105 106 return NewNode; 107 } 108 109 void ContextTrieNode::removeChildContext(const LineLocation &CallSite, 110 StringRef CalleeName) { 111 uint32_t Hash = nodeHash(CalleeName, CallSite); 112 // Note this essentially calls dtor and destroys that child context 113 AllChildContext.erase(Hash); 114 } 115 116 std::map<uint32_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() { 117 return AllChildContext; 118 } 119 120 const StringRef ContextTrieNode::getFuncName() const { return FuncName; } 121 122 FunctionSamples *ContextTrieNode::getFunctionSamples() const { 123 return FuncSamples; 124 } 125 126 void ContextTrieNode::setFunctionSamples(FunctionSamples *FSamples) { 127 FuncSamples = FSamples; 128 } 129 130 LineLocation ContextTrieNode::getCallSiteLoc() const { return CallSiteLoc; } 131 132 ContextTrieNode *ContextTrieNode::getParentContext() const { 133 return ParentContext; 134 } 135 136 void ContextTrieNode::setParentContext(ContextTrieNode *Parent) { 137 ParentContext = Parent; 138 } 139 140 void ContextTrieNode::dump() { 141 dbgs() << "Node: " << FuncName << "\n" 142 << " Callsite: " << CallSiteLoc << "\n" 143 << " Children:\n"; 144 145 for (auto &It : AllChildContext) { 146 dbgs() << " Node: " << It.second.getFuncName() << "\n"; 147 } 148 } 149 150 uint32_t ContextTrieNode::nodeHash(StringRef ChildName, 151 const LineLocation &Callsite) { 152 // We still use child's name for child hash, this is 153 // because for children of root node, we don't have 154 // different line/discriminator, and we'll rely on name 155 // to differentiate children. 156 uint32_t NameHash = std::hash<std::string>{}(ChildName.str()); 157 uint32_t LocId = (Callsite.LineOffset << 16) | Callsite.Discriminator; 158 return NameHash + (LocId << 5) + LocId; 159 } 160 161 ContextTrieNode *ContextTrieNode::getOrCreateChildContext( 162 const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) { 163 uint32_t Hash = nodeHash(CalleeName, CallSite); 164 auto It = AllChildContext.find(Hash); 165 if (It != AllChildContext.end()) { 166 assert(It->second.getFuncName() == CalleeName && 167 "Hash collision for child context node"); 168 return &It->second; 169 } 170 171 if (!AllowCreate) 172 return nullptr; 173 174 AllChildContext[Hash] = ContextTrieNode(this, CalleeName, nullptr, CallSite); 175 return &AllChildContext[Hash]; 176 } 177 178 // Profiler tracker than manages profiles and its associated context 179 SampleContextTracker::SampleContextTracker( 180 StringMap<FunctionSamples> &Profiles) { 181 for (auto &FuncSample : Profiles) { 182 FunctionSamples *FSamples = &FuncSample.second; 183 SampleContext Context(FuncSample.first(), RawContext); 184 LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context << "\n"); 185 if (!Context.isBaseContext()) 186 FuncToCtxtProfileSet[Context.getNameWithoutContext()].insert(FSamples); 187 ContextTrieNode *NewNode = getOrCreateContextPath(Context, true); 188 assert(!NewNode->getFunctionSamples() && 189 "New node can't have sample profile"); 190 NewNode->setFunctionSamples(FSamples); 191 } 192 } 193 194 FunctionSamples * 195 SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst, 196 StringRef CalleeName) { 197 LLVM_DEBUG(dbgs() << "Getting callee context for instr: " << Inst << "\n"); 198 DILocation *DIL = Inst.getDebugLoc(); 199 if (!DIL) 200 return nullptr; 201 202 // For indirect call, CalleeName will be empty, in which case the context 203 // profile for callee with largest total samples will be returned. 204 ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, CalleeName); 205 if (CalleeContext) { 206 FunctionSamples *FSamples = CalleeContext->getFunctionSamples(); 207 LLVM_DEBUG(if (FSamples) { 208 dbgs() << " Callee context found: " << FSamples->getContext() << "\n"; 209 }); 210 return FSamples; 211 } 212 213 return nullptr; 214 } 215 216 std::vector<const FunctionSamples *> 217 SampleContextTracker::getIndirectCalleeContextSamplesFor( 218 const DILocation *DIL) { 219 std::vector<const FunctionSamples *> R; 220 if (!DIL) 221 return R; 222 223 ContextTrieNode *CallerNode = getContextFor(DIL); 224 LineLocation CallSite = FunctionSamples::getCallSiteIdentifier(DIL); 225 for (auto &It : CallerNode->getAllChildContext()) { 226 ContextTrieNode &ChildNode = It.second; 227 if (ChildNode.getCallSiteLoc() != CallSite) 228 continue; 229 if (FunctionSamples *CalleeSamples = ChildNode.getFunctionSamples()) 230 R.push_back(CalleeSamples); 231 } 232 233 return R; 234 } 235 236 FunctionSamples * 237 SampleContextTracker::getContextSamplesFor(const DILocation *DIL) { 238 assert(DIL && "Expect non-null location"); 239 240 ContextTrieNode *ContextNode = getContextFor(DIL); 241 if (!ContextNode) 242 return nullptr; 243 244 // We may have inlined callees during pre-LTO compilation, in which case 245 // we need to rely on the inline stack from !dbg to mark context profile 246 // as inlined, instead of `MarkContextSamplesInlined` during inlining. 247 // Sample profile loader walks through all instructions to get profile, 248 // which calls this function. So once that is done, all previously inlined 249 // context profile should be marked properly. 250 FunctionSamples *Samples = ContextNode->getFunctionSamples(); 251 if (Samples && ContextNode->getParentContext() != &RootContext) 252 Samples->getContext().setState(InlinedContext); 253 254 return Samples; 255 } 256 257 FunctionSamples * 258 SampleContextTracker::getContextSamplesFor(const SampleContext &Context) { 259 ContextTrieNode *Node = getContextFor(Context); 260 if (!Node) 261 return nullptr; 262 263 return Node->getFunctionSamples(); 264 } 265 266 SampleContextTracker::ContextSamplesTy & 267 SampleContextTracker::getAllContextSamplesFor(const Function &Func) { 268 StringRef CanonName = FunctionSamples::getCanonicalFnName(Func); 269 return FuncToCtxtProfileSet[CanonName]; 270 } 271 272 SampleContextTracker::ContextSamplesTy & 273 SampleContextTracker::getAllContextSamplesFor(StringRef Name) { 274 return FuncToCtxtProfileSet[Name]; 275 } 276 277 FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func, 278 bool MergeContext) { 279 StringRef CanonName = FunctionSamples::getCanonicalFnName(Func); 280 return getBaseSamplesFor(CanonName, MergeContext); 281 } 282 283 FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name, 284 bool MergeContext) { 285 LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n"); 286 // Base profile is top-level node (child of root node), so try to retrieve 287 // existing top-level node for given function first. If it exists, it could be 288 // that we've merged base profile before, or there's actually context-less 289 // profile from the input (e.g. due to unreliable stack walking). 290 ContextTrieNode *Node = getTopLevelContextNode(Name); 291 if (MergeContext) { 292 LLVM_DEBUG(dbgs() << " Merging context profile into base profile: " << Name 293 << "\n"); 294 295 // We have profile for function under different contexts, 296 // create synthetic base profile and merge context profiles 297 // into base profile. 298 for (auto *CSamples : FuncToCtxtProfileSet[Name]) { 299 SampleContext &Context = CSamples->getContext(); 300 ContextTrieNode *FromNode = getContextFor(Context); 301 if (FromNode == Node) 302 continue; 303 304 // Skip inlined context profile and also don't re-merge any context 305 if (Context.hasState(InlinedContext) || Context.hasState(MergedContext)) 306 continue; 307 308 ContextTrieNode &ToNode = promoteMergeContextSamplesTree(*FromNode); 309 assert((!Node || Node == &ToNode) && "Expect only one base profile"); 310 Node = &ToNode; 311 } 312 } 313 314 // Still no profile even after merge/promotion (if allowed) 315 if (!Node) 316 return nullptr; 317 318 return Node->getFunctionSamples(); 319 } 320 321 void SampleContextTracker::markContextSamplesInlined( 322 const FunctionSamples *InlinedSamples) { 323 assert(InlinedSamples && "Expect non-null inlined samples"); 324 LLVM_DEBUG(dbgs() << "Marking context profile as inlined: " 325 << InlinedSamples->getContext() << "\n"); 326 InlinedSamples->getContext().setState(InlinedContext); 327 } 328 329 void SampleContextTracker::promoteMergeContextSamplesTree( 330 const Instruction &Inst, StringRef CalleeName) { 331 LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n" 332 << Inst << "\n"); 333 // Get the caller context for the call instruction, we don't use callee 334 // name from call because there can be context from indirect calls too. 335 DILocation *DIL = Inst.getDebugLoc(); 336 ContextTrieNode *CallerNode = getContextFor(DIL); 337 if (!CallerNode) 338 return; 339 340 // Get the context that needs to be promoted 341 LineLocation CallSite = FunctionSamples::getCallSiteIdentifier(DIL); 342 // For indirect call, CalleeName will be empty, in which case we need to 343 // promote all non-inlined child context profiles. 344 if (CalleeName.empty()) { 345 for (auto &It : CallerNode->getAllChildContext()) { 346 ContextTrieNode *NodeToPromo = &It.second; 347 if (CallSite != NodeToPromo->getCallSiteLoc()) 348 continue; 349 FunctionSamples *FromSamples = NodeToPromo->getFunctionSamples(); 350 if (FromSamples && FromSamples->getContext().hasState(InlinedContext)) 351 continue; 352 promoteMergeContextSamplesTree(*NodeToPromo); 353 } 354 return; 355 } 356 357 // Get the context for the given callee that needs to be promoted 358 ContextTrieNode *NodeToPromo = 359 CallerNode->getChildContext(CallSite, CalleeName); 360 if (!NodeToPromo) 361 return; 362 363 promoteMergeContextSamplesTree(*NodeToPromo); 364 } 365 366 ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( 367 ContextTrieNode &NodeToPromo) { 368 // Promote the input node to be directly under root. This can happen 369 // when we decided to not inline a function under context represented 370 // by the input node. The promote and merge is then needed to reflect 371 // the context profile in the base (context-less) profile. 372 FunctionSamples *FromSamples = NodeToPromo.getFunctionSamples(); 373 assert(FromSamples && "Shouldn't promote a context without profile"); 374 LLVM_DEBUG(dbgs() << " Found context tree root to promote: " 375 << FromSamples->getContext() << "\n"); 376 377 assert(!FromSamples->getContext().hasState(InlinedContext) && 378 "Shouldn't promote inlined context profile"); 379 StringRef ContextStrToRemove = FromSamples->getContext().getCallingContext(); 380 return promoteMergeContextSamplesTree(NodeToPromo, RootContext, 381 ContextStrToRemove); 382 } 383 384 void SampleContextTracker::dump() { 385 dbgs() << "Context Profile Tree:\n"; 386 std::queue<ContextTrieNode *> NodeQueue; 387 NodeQueue.push(&RootContext); 388 389 while (!NodeQueue.empty()) { 390 ContextTrieNode *Node = NodeQueue.front(); 391 NodeQueue.pop(); 392 Node->dump(); 393 394 for (auto &It : Node->getAllChildContext()) { 395 ContextTrieNode *ChildNode = &It.second; 396 NodeQueue.push(ChildNode); 397 } 398 } 399 } 400 401 ContextTrieNode * 402 SampleContextTracker::getContextFor(const SampleContext &Context) { 403 return getOrCreateContextPath(Context, false); 404 } 405 406 ContextTrieNode * 407 SampleContextTracker::getCalleeContextFor(const DILocation *DIL, 408 StringRef CalleeName) { 409 assert(DIL && "Expect non-null location"); 410 411 ContextTrieNode *CallContext = getContextFor(DIL); 412 if (!CallContext) 413 return nullptr; 414 415 // When CalleeName is empty, the child context profile with max 416 // total samples will be returned. 417 return CallContext->getChildContext( 418 FunctionSamples::getCallSiteIdentifier(DIL), CalleeName); 419 } 420 421 ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) { 422 assert(DIL && "Expect non-null location"); 423 SmallVector<std::pair<LineLocation, StringRef>, 10> S; 424 425 // Use C++ linkage name if possible. 426 const DILocation *PrevDIL = DIL; 427 for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) { 428 StringRef Name = PrevDIL->getScope()->getSubprogram()->getLinkageName(); 429 if (Name.empty()) 430 Name = PrevDIL->getScope()->getSubprogram()->getName(); 431 S.push_back( 432 std::make_pair(FunctionSamples::getCallSiteIdentifier(DIL), 433 PrevDIL->getScope()->getSubprogram()->getLinkageName())); 434 PrevDIL = DIL; 435 } 436 437 // Push root node, note that root node like main may only 438 // a name, but not linkage name. 439 StringRef RootName = PrevDIL->getScope()->getSubprogram()->getLinkageName(); 440 if (RootName.empty()) 441 RootName = PrevDIL->getScope()->getSubprogram()->getName(); 442 S.push_back(std::make_pair(LineLocation(0, 0), RootName)); 443 444 ContextTrieNode *ContextNode = &RootContext; 445 int I = S.size(); 446 while (--I >= 0 && ContextNode) { 447 LineLocation &CallSite = S[I].first; 448 StringRef &CalleeName = S[I].second; 449 ContextNode = ContextNode->getChildContext(CallSite, CalleeName); 450 } 451 452 if (I < 0) 453 return ContextNode; 454 455 return nullptr; 456 } 457 458 ContextTrieNode * 459 SampleContextTracker::getOrCreateContextPath(const SampleContext &Context, 460 bool AllowCreate) { 461 ContextTrieNode *ContextNode = &RootContext; 462 StringRef ContextRemain = Context; 463 StringRef ChildContext; 464 StringRef CalleeName; 465 LineLocation CallSiteLoc(0, 0); 466 467 while (ContextNode && !ContextRemain.empty()) { 468 auto ContextSplit = SampleContext::splitContextString(ContextRemain); 469 ChildContext = ContextSplit.first; 470 ContextRemain = ContextSplit.second; 471 LineLocation NextCallSiteLoc(0, 0); 472 SampleContext::decodeContextString(ChildContext, CalleeName, 473 NextCallSiteLoc); 474 475 // Create child node at parent line/disc location 476 if (AllowCreate) { 477 ContextNode = 478 ContextNode->getOrCreateChildContext(CallSiteLoc, CalleeName); 479 } else { 480 ContextNode = ContextNode->getChildContext(CallSiteLoc, CalleeName); 481 } 482 CallSiteLoc = NextCallSiteLoc; 483 } 484 485 assert((!AllowCreate || ContextNode) && 486 "Node must exist if creation is allowed"); 487 return ContextNode; 488 } 489 490 ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) { 491 return RootContext.getChildContext(LineLocation(0, 0), FName); 492 } 493 494 ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) { 495 assert(!getTopLevelContextNode(FName) && "Node to add must not exist"); 496 return *RootContext.getOrCreateChildContext(LineLocation(0, 0), FName); 497 } 498 499 void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode, 500 ContextTrieNode &ToNode, 501 StringRef ContextStrToRemove) { 502 FunctionSamples *FromSamples = FromNode.getFunctionSamples(); 503 FunctionSamples *ToSamples = ToNode.getFunctionSamples(); 504 if (FromSamples && ToSamples) { 505 // Merge/duplicate FromSamples into ToSamples 506 ToSamples->merge(*FromSamples); 507 ToSamples->getContext().setState(SyntheticContext); 508 FromSamples->getContext().setState(MergedContext); 509 } else if (FromSamples) { 510 // Transfer FromSamples from FromNode to ToNode 511 ToNode.setFunctionSamples(FromSamples); 512 FromSamples->getContext().setState(SyntheticContext); 513 FromSamples->getContext().promoteOnPath(ContextStrToRemove); 514 FromNode.setFunctionSamples(nullptr); 515 } 516 } 517 518 ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( 519 ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent, 520 StringRef ContextStrToRemove) { 521 assert(!ContextStrToRemove.empty() && "Context to remove can't be empty"); 522 523 // Ignore call site location if destination is top level under root 524 LineLocation NewCallSiteLoc = LineLocation(0, 0); 525 LineLocation OldCallSiteLoc = FromNode.getCallSiteLoc(); 526 ContextTrieNode &FromNodeParent = *FromNode.getParentContext(); 527 ContextTrieNode *ToNode = nullptr; 528 bool MoveToRoot = (&ToNodeParent == &RootContext); 529 if (!MoveToRoot) { 530 NewCallSiteLoc = OldCallSiteLoc; 531 } 532 533 // Locate destination node, create/move if not existing 534 ToNode = ToNodeParent.getChildContext(NewCallSiteLoc, FromNode.getFuncName()); 535 if (!ToNode) { 536 // Do not delete node to move from its parent here because 537 // caller is iterating over children of that parent node. 538 ToNode = &ToNodeParent.moveToChildContext( 539 NewCallSiteLoc, std::move(FromNode), ContextStrToRemove, false); 540 } else { 541 // Destination node exists, merge samples for the context tree 542 mergeContextNode(FromNode, *ToNode, ContextStrToRemove); 543 LLVM_DEBUG(dbgs() << " Context promoted and merged to: " 544 << ToNode->getFunctionSamples()->getContext() << "\n"); 545 546 // Recursively promote and merge children 547 for (auto &It : FromNode.getAllChildContext()) { 548 ContextTrieNode &FromChildNode = It.second; 549 promoteMergeContextSamplesTree(FromChildNode, *ToNode, 550 ContextStrToRemove); 551 } 552 553 // Remove children once they're all merged 554 FromNode.getAllChildContext().clear(); 555 } 556 557 // For root of subtree, remove itself from old parent too 558 if (MoveToRoot) 559 FromNodeParent.removeChildContext(OldCallSiteLoc, ToNode->getFuncName()); 560 561 return *ToNode; 562 } 563 564 // Replace call graph edges with dynamic call edges from the profile. 565 void SampleContextTracker::addCallGraphEdges(CallGraph &CG, 566 StringMap<Function *> &SymbolMap) { 567 // Add profile call edges to the call graph. 568 std::queue<ContextTrieNode *> NodeQueue; 569 NodeQueue.push(&RootContext); 570 while (!NodeQueue.empty()) { 571 ContextTrieNode *Node = NodeQueue.front(); 572 NodeQueue.pop(); 573 Function *F = SymbolMap.lookup(Node->getFuncName()); 574 for (auto &I : Node->getAllChildContext()) { 575 ContextTrieNode *ChildNode = &I.second; 576 NodeQueue.push(ChildNode); 577 if (F && !F->isDeclaration()) { 578 Function *Callee = SymbolMap.lookup(ChildNode->getFuncName()); 579 if (Callee && !Callee->isDeclaration()) 580 CG[F]->addCalledFunction(nullptr, CG[Callee]); 581 } 582 } 583 } 584 } 585 } // namespace llvm 586