//===- MLInlineAdvisor.cpp - machine learned InlineAdvisor ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the interface between the inliner and a learned model.
// It delegates model evaluation to either the AOT compiled model (the
// 'release' mode) or a runtime-loaded model (the 'development' case).
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/MLInlineAdvisor.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/FunctionPropertiesAnalysis.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/InlineModelFeatureMaps.h"
#include "llvm/Analysis/InteractiveModelRunner.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/ReleaseModeModelRunner.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/CommandLine.h"

using namespace llvm;

// Interactive ("in-the-loop") mode: features are written to
// <base>.out and the decision is read back from <base>.in.
static cl::opt<std::string> InteractiveChannelBaseName(
    "inliner-interactive-channel-base", cl::Hidden,
    cl::desc(
        "Base file path for the interactive mode. The incoming filename should "
        "have the name <inliner-interactive-channel-base>.in, while the "
        "outgoing name should be <inliner-interactive-channel-base>.out"));
static const std::string InclDefaultMsg =
    (Twine("In interactive mode, also send the default policy decision: ") +
     DefaultDecisionName + ".")
        .str();
static cl::opt<bool>
    InteractiveIncludeDefault("inliner-interactive-include-default", cl::Hidden,
                              cl::desc(InclDefaultMsg));

// Criteria under which the ML policy is skipped for a call site and the
// default (heuristic) advice is returned instead.
enum class SkipMLPolicyCriteria { Never, IfCallerIsNotCold };

static cl::opt<SkipMLPolicyCriteria> SkipPolicy(
    "ml-inliner-skip-policy", cl::Hidden, cl::init(SkipMLPolicyCriteria::Never),
    cl::values(clEnumValN(SkipMLPolicyCriteria::Never, "never", "never"),
               clEnumValN(SkipMLPolicyCriteria::IfCallerIsNotCold,
                          "if-caller-not-cold", "if the caller is not cold")));

// Optional selector forwarded to the embedded (AOT) model, for models that
// bundle multiple policies.
static cl::opt<std::string> ModelSelector("ml-inliner-model-selector",
                                          cl::Hidden, cl::init(""));

#if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL)
// codegen-ed file
#include "InlinerSizeModel.h" // NOLINT
using CompiledModelType = llvm::InlinerSizeModel;
#else
using CompiledModelType = NoopSavedModelImpl;
#endif

/// Factory for the release-mode ML inline advisor. Returns nullptr when
/// neither a valid embedded (AOT) model nor an interactive channel is
/// available; otherwise wires the chosen MLModelRunner into an
/// MLInlineAdvisor.
std::unique_ptr<InlineAdvisor>
llvm::getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM,
                            std::function<bool(CallBase &)> GetDefaultAdvice) {
  if (!llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() &&
      InteractiveChannelBaseName.empty())
    return nullptr;
  std::unique_ptr<MLModelRunner> AOTRunner;
  if (InteractiveChannelBaseName.empty())
    AOTRunner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
        M.getContext(), FeatureMap, DecisionName,
        EmbeddedModelRunnerOptions().setModelSelector(ModelSelector));
  else {
    // The interactive runner takes a copy of the feature list; optionally
    // append the default decision as an extra outgoing "feature".
    auto Features = FeatureMap;
    if (InteractiveIncludeDefault)
      Features.push_back(DefaultDecisionSpec);
    AOTRunner = std::make_unique<InteractiveModelRunner>(
        M.getContext(), Features, InlineDecisionSpec,
        InteractiveChannelBaseName + ".out",
        InteractiveChannelBaseName + ".in");
  }
  return std::make_unique<MLInlineAdvisor>(M, MAM, std::move(AOTRunner),
                                           GetDefaultAdvice);
}

#define DEBUG_TYPE "inline-ml"

static cl::opt<float> SizeIncreaseThreshold(
    "ml-advisor-size-increase-threshold", cl::Hidden,
    cl::desc("Maximum factor by which expected native size may increase before "
             "blocking any further inlining."),
    cl::init(2.0));

static cl::opt<bool> KeepFPICache(
    "ml-advisor-keep-fpi-cache", cl::Hidden,
    cl::desc(
        "For test - keep the ML Inline advisor's FunctionPropertiesInfo cache"),
    cl::init(false));

// clang-format off
std::vector<TensorSpec> llvm::FeatureMap{
#define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
// InlineCost features - these must come first
  INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)

// Non-cost features
  INLINE_FEATURE_ITERATOR(POPULATE_NAMES)
#undef POPULATE_NAMES
};
// clang-format on

const char *const llvm::DecisionName = "inlining_decision";
const TensorSpec llvm::InlineDecisionSpec =
    TensorSpec::createSpec<int64_t>(DecisionName, {1});
const char *const llvm::DefaultDecisionName = "inlining_default";
const TensorSpec llvm::DefaultDecisionSpec =
    TensorSpec::createSpec<int64_t>(DefaultDecisionName, {1});
const char *const llvm::RewardName = "delta_size";

/// Return \p I as a CallBase if it is a direct call to a function with a
/// definition in this module (i.e. a call the inliner could act on),
/// nullptr otherwise.
CallBase *getInlinableCS(Instruction &I) {
  if (auto *CS = dyn_cast<CallBase>(&I))
    if (Function *Callee = CS->getCalledFunction()) {
      if (!Callee->isDeclaration()) {
        return CS;
      }
    }
  return nullptr;
}

MLInlineAdvisor::MLInlineAdvisor(
    Module &M, ModuleAnalysisManager &MAM,
    std::unique_ptr<MLModelRunner> Runner,
    std::function<bool(CallBase &)> GetDefaultAdvice)
    : InlineAdvisor(
          M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
      ModelRunner(std::move(Runner)), GetDefaultAdvice(GetDefaultAdvice),
      CG(MAM.getResult<LazyCallGraphAnalysis>(M)),
      UseIR2Vec(MAM.getCachedResult<IR2VecVocabAnalysis>(M) != nullptr),
      InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize),
      PSI(MAM.getResult<ProfileSummaryAnalysis>(M)) {
  assert(ModelRunner);
  ModelRunner->switchContext("");
  // Extract the 'call site height' feature - the position of a call site
  // relative to the farthest statically reachable SCC node. We don't mutate
  // this value while inlining happens. Empirically, this feature proved
  // critical in behavioral cloning - i.e. training a model to mimic the manual
  // heuristic's decisions - and, thus, equally important for training for
  // improvement.
  CallGraph CGraph(M);
  // scc_begin visits SCCs bottom-up (callees before callers), so a callee's
  // level is already recorded by the time its callers are processed.
  for (auto I = scc_begin(&CGraph); !I.isAtEnd(); ++I) {
    const std::vector<CallGraphNode *> &CGNodes = *I;
    unsigned Level = 0;
    for (auto *CGNode : CGNodes) {
      Function *F = CGNode->getFunction();
      if (!F || F->isDeclaration())
        continue;
      for (auto &I : instructions(F)) {
        if (auto *CS = getInlinableCS(I)) {
          auto *Called = CS->getCalledFunction();
          auto Pos = FunctionLevels.find(&CG.get(*Called));
          // In bottom up traversal, an inlinable callee is either in the
          // same SCC, or to a function in a visited SCC. So not finding its
          // level means we haven't visited it yet, meaning it's in this SCC.
          if (Pos == FunctionLevels.end())
            continue;
          Level = std::max(Level, Pos->second + 1);
        }
      }
    }
    // All functions in one SCC share the same level.
    for (auto *CGNode : CGNodes) {
      Function *F = CGNode->getFunction();
      if (F && !F->isDeclaration())
        FunctionLevels[&CG.get(*F)] = Level;
    }
  }
  // Seed the module-wide node/edge counts that are later delta-updated as
  // inlining proceeds.
  for (auto KVP : FunctionLevels) {
    AllNodes.insert(KVP.first);
    EdgeCount += getLocalCalls(KVP.first->getFunction());
  }
  NodeCount = AllNodes.size();

  if (auto IR2VecVocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(M)) {
    if (!IR2VecVocabResult->isValid()) {
      M.getContext().emitError("IR2VecVocabAnalysis is not valid");
      return;
    }
    // Add the IR2Vec features to the feature map
    // NOTE(review): this appends to the global FeatureMap after the model
    // runner was constructed (and, in interactive mode, after it copied the
    // feature list in getReleaseModeAdvisor). Presumably the runner is set up
    // to tolerate/expect these trailing entries - confirm against the runner
    // implementations.
    auto IR2VecDim = IR2VecVocabResult->getDimension();
    FeatureMap.push_back(
        TensorSpec::createSpec<float>("callee_embedding", {IR2VecDim}));
    FeatureMap.push_back(
        TensorSpec::createSpec<float>("caller_embedding", {IR2VecDim}));
  }
}

/// Level of \p F as computed in the constructor; 0 for functions unknown to
/// the lazy call graph.
unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
  return CG.lookup(F) ? FunctionLevels.at(CG.lookup(F)) : 0;
}

void MLInlineAdvisor::onPassEntry(LazyCallGraph::SCC *CurSCC) {
  if (!CurSCC || ForceStop)
    return;
  FPICache.clear();
  // Function passes executed between InlinerPass runs may have changed the
  // module-wide features.
  // The cgscc pass manager rules are such that:
  // - if a pass leads to merging SCCs, then the pipeline is restarted on the
  // merged SCC
  // - if a pass leads to splitting the SCC, then we continue with one of the
  // splits
  // This means that the NodesInLastSCC is a superset (not strict) of the nodes
  // that subsequent passes would have processed
  // - in addition, if new Nodes were created by a pass (e.g. CoroSplit),
  // they'd be adjacent to Nodes in the last SCC. So we just need to check the
  // boundary of Nodes in NodesInLastSCC for Nodes we haven't seen. We don't
  // care about the nature of the Edge (call or ref). `FunctionLevels`-wise, we
  // record them at the same level as the original node (this is a choice, may
  // need revisiting).
  // - nodes are only deleted at the end of a call graph walk where they are
  // batch deleted, so we shouldn't see any dead nodes here.
  while (!NodesInLastSCC.empty()) {
    const auto *N = *NodesInLastSCC.begin();
    assert(!N->isDead());
    NodesInLastSCC.erase(N);
    EdgeCount += getLocalCalls(N->getFunction());
    const auto NLevel = FunctionLevels.at(N);
    for (const auto &E : *(*N)) {
      const auto *AdjNode = &E.getNode();
      assert(!AdjNode->isDead() && !AdjNode->getFunction().isDeclaration());
      auto I = AllNodes.insert(AdjNode);
      // We've discovered a new function.
      if (I.second) {
        ++NodeCount;
        NodesInLastSCC.insert(AdjNode);
        FunctionLevels[AdjNode] = NLevel;
      }
    }
  }

  // The edges of the last-seen nodes were re-counted just above; drop the
  // stale tally recorded in onPassExit.
  EdgeCount -= EdgesOfLastSeenNodes;
  EdgesOfLastSeenNodes = 0;

  // (Re)use NodesInLastSCC to remember the nodes in the SCC right now,
  // in case the SCC is split before onPassExit and some nodes are split out
  assert(NodesInLastSCC.empty());
  for (const auto &N : *CurSCC)
    NodesInLastSCC.insert(&N);
}

void MLInlineAdvisor::onPassExit(LazyCallGraph::SCC *CurSCC) {
  // No need to keep this around - function passes will invalidate it.
  if (!KeepFPICache)
    FPICache.clear();
  if (!CurSCC || ForceStop)
    return;
  // Keep track of the nodes and edges we last saw. Then, in onPassEntry,
  // we update the node count and edge count from the subset of these nodes that
  // survived.
  EdgesOfLastSeenNodes = 0;

  // Check on nodes that were in SCC onPassEntry
  for (const LazyCallGraph::Node *N : NodesInLastSCC) {
    assert(!N->isDead());
    EdgesOfLastSeenNodes += getLocalCalls(N->getFunction());
  }

  // Check on nodes that may have got added to SCC
  for (const auto &N : *CurSCC) {
    assert(!N.isDead());
    auto I = NodesInLastSCC.insert(&N);
    if (I.second)
      EdgesOfLastSeenNodes += getLocalCalls(N.getFunction());
  }
  assert(NodeCount >= NodesInLastSCC.size());
  assert(EdgeCount >= EdgesOfLastSeenNodes);
}

/// Number of direct calls from \p F to functions defined in this module, per
/// the (cached) function properties analysis.
int64_t MLInlineAdvisor::getLocalCalls(Function &F) {
  return getCachedFPI(F).DirectCallsToDefinedFunctions;
}

// Update the internal state of the advisor, and force invalidate feature
// analysis. Currently, we maintain minimal (and very simple) global state - the
// number of functions and the number of static calls. We also keep track of the
// total IR size in this module, to stop misbehaving policies at a certain bloat
// factor (SizeIncreaseThreshold)
void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice,
                                           bool CalleeWasDeleted) {
  assert(!ForceStop);
  Function *Caller = Advice.getCaller();
  Function *Callee = Advice.getCallee();
  // The caller features aren't valid anymore.
  {
    PreservedAnalyses PA = PreservedAnalyses::all();
    PA.abandon<FunctionPropertiesAnalysis>();
    PA.abandon<LoopAnalysis>();
    FAM.invalidate(*Caller, PA);
  }
  Advice.updateCachedCallerFPI(FAM);
  // Track IR growth: once the module exceeds SizeIncreaseThreshold x the
  // initial size, stop inlining altogether (ForceStop).
  int64_t IRSizeAfter =
      getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
  CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
  if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize)
    ForceStop = true;

  // We can delta-update module-wide features. We know the inlining only changed
  // the caller, and maybe the callee (by deleting the latter).
  // Nodes are simple to update.
  // For edges, we 'forget' the edges that the caller and callee used to have
  // before inlining, and add back what they currently have together.
  int64_t NewCallerAndCalleeEdges =
      getCachedFPI(*Caller).DirectCallsToDefinedFunctions;

  // A dead function's node is not actually removed from the call graph until
  // the end of the call graph walk, but the node no longer belongs to any valid
  // SCC.
  if (CalleeWasDeleted) {
    --NodeCount;
    NodesInLastSCC.erase(CG.lookup(*Callee));
    DeadFunctions.insert(Callee);
  } else {
    NewCallerAndCalleeEdges +=
        getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
  }
  EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
  assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0);
}

/// Sum of getIRSize over all defined functions in the module.
int64_t MLInlineAdvisor::getModuleIRSize() const {
  int64_t Ret = 0;
  for (auto &F : M)
    if (!F.isDeclaration())
      Ret += getIRSize(F);
  return Ret;
}

/// Fetch (and memoize) the FunctionPropertiesInfo for \p F. The cache is
/// cleared on pass entry/exit (unless KeepFPICache is set, for tests).
FunctionPropertiesInfo &MLInlineAdvisor::getCachedFPI(Function &F) const {
  auto InsertPair = FPICache.try_emplace(&F);
  if (!InsertPair.second)
    return InsertPair.first->second;
  InsertPair.first->second = FAM.getResult<FunctionPropertiesAnalysis>(F);
  return InsertPair.first->second;
}

std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
  // Call sites in unreachable blocks won't be inlined; return base advice.
  if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
    return Skip;

  auto &Caller = *CB.getCaller();
  auto &Callee = *CB.getCalledFunction();

  auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
    return FAM.getResult<AssumptionAnalysis>(F);
  };
  auto &TIR = FAM.getResult<TargetIRAnalysis>(Callee);
  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(Caller);

  // Optionally defer to the default heuristic for non-cold callers.
  if (SkipPolicy == SkipMLPolicyCriteria::IfCallerIsNotCold) {
    if (!PSI.isFunctionEntryCold(&Caller))
      return std::make_unique<InlineAdvice>(this, CB, ORE,
                                            GetDefaultAdvice(CB));
  }
  auto MandatoryKind = InlineAdvisor::getMandatoryKind(CB, FAM, ORE);
  // If this is a "never inline" case, there won't be any changes to internal
  // state we need to track, so we can just return the base InlineAdvice, which
  // will do nothing interesting.
  // Same thing if this is a recursive case.
  if (MandatoryKind == InlineAdvisor::MandatoryInliningKind::Never ||
      &Caller == &Callee)
    return getMandatoryAdvice(CB, false);

  bool Mandatory =
      MandatoryKind == InlineAdvisor::MandatoryInliningKind::Always;

  // If we need to stop, we won't want to track anymore any state changes, so
  // we just return the base InlineAdvice, which acts as a noop.
  if (ForceStop) {
    ORE.emit([&] {
      return OptimizationRemarkMissed(DEBUG_TYPE, "ForceStop", &CB)
             << "Won't attempt inlining because module size grew too much.";
    });
    return std::make_unique<InlineAdvice>(this, CB, ORE, Mandatory);
  }

  int CostEstimate = 0;
  if (!Mandatory) {
    auto IsCallSiteInlinable =
        llvm::getInliningCostEstimate(CB, TIR, GetAssumptionCache);
    if (!IsCallSiteInlinable) {
      // We can't inline this for correctness reasons, so return the base
      // InlineAdvice, as we don't care about tracking any state changes (which
      // won't happen).
      return std::make_unique<InlineAdvice>(this, CB, ORE, false);
    }
    CostEstimate = *IsCallSiteInlinable;
  }

  const auto CostFeatures =
      llvm::getInliningCostFeatures(CB, TIR, GetAssumptionCache);
  if (!CostFeatures) {
    return std::make_unique<InlineAdvice>(this, CB, ORE, false);
  }

  if (Mandatory)
    return getMandatoryAdvice(CB, true);

  // Count constant arguments at this call site - one of the model features.
  auto NumCtantParams = 0;
  for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) {
    NumCtantParams += (isa<Constant>(*I));
  }

  auto &CallerBefore = getCachedFPI(Caller);
  auto &CalleeBefore = getCachedFPI(Callee);

  // Populate the model's input tensors.
  *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_basic_block_count) =
      CalleeBefore.BasicBlockCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::callsite_height) =
      getInitialFunctionLevel(Caller);
  *ModelRunner->getTensor<int64_t>(FeatureIndex::node_count) = NodeCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::nr_ctant_params) =
      NumCtantParams;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::edge_count) = EdgeCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_users) =
      CallerBefore.Uses;
  *ModelRunner->getTensor<int64_t>(
      FeatureIndex::caller_conditionally_executed_blocks) =
      CallerBefore.BlocksReachedFromConditionalInstruction;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_basic_block_count) =
      CallerBefore.BasicBlockCount;
  *ModelRunner->getTensor<int64_t>(
      FeatureIndex::callee_conditionally_executed_blocks) =
      CalleeBefore.BlocksReachedFromConditionalInstruction;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_users) =
      CalleeBefore.Uses;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::cost_estimate) = CostEstimate;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::is_callee_avail_external) =
      Callee.hasAvailableExternallyLinkage();
  *ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
      Caller.hasAvailableExternallyLinkage();

  if (UseIR2Vec) {
    // Python side expects float embeddings. The IR2Vec embeddings are doubles
    // as of now due to the restriction of fromJSON method used by the
    // readVocabulary method in ir2vec::Embeddings.
    auto setEmbedding = [&](const ir2vec::Embedding &Embedding,
                            FeatureIndex Index) {
      llvm::transform(Embedding, ModelRunner->getTensor<float>(Index),
                      [](double Val) { return static_cast<float>(Val); });
    };

    setEmbedding(CalleeBefore.getFunctionEmbedding(),
                 FeatureIndex::callee_embedding);
    setEmbedding(CallerBefore.getFunctionEmbedding(),
                 FeatureIndex::caller_embedding);
  }

  // Add the cost features
  for (size_t I = 0;
       I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
    *ModelRunner->getTensor<int64_t>(inlineCostFeatureToMlFeature(
        static_cast<InlineCostFeatureIndex>(I))) = CostFeatures->at(I);
  }
  // This one would have been set up to be right at the end.
  if (!InteractiveChannelBaseName.empty() && InteractiveIncludeDefault)
    *ModelRunner->getTensor<int64_t>(FeatureMap.size()) = GetDefaultAdvice(CB);
  return getAdviceFromModel(CB, ORE);
}

/// Run the model and wrap its boolean decision in an MLInlineAdvice.
std::unique_ptr<MLInlineAdvice>
MLInlineAdvisor::getAdviceFromModel(CallBase &CB,
                                    OptimizationRemarkEmitter &ORE) {
  return std::make_unique<MLInlineAdvice>(
      this, CB, ORE, static_cast<bool>(ModelRunner->evaluate<int64_t>()));
}

/// If \p CB sits in a block unreachable from its caller's entry, return a
/// "don't inline" base advice (no state tracking needed); nullptr otherwise.
std::unique_ptr<InlineAdvice>
MLInlineAdvisor::getSkipAdviceIfUnreachableCallsite(CallBase &CB) {
  if (!FAM.getResult<DominatorTreeAnalysis>(*CB.getCaller())
           .isReachableFromEntry(CB.getParent()))
    return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
  return nullptr;
}

std::unique_ptr<InlineAdvice> MLInlineAdvisor::getMandatoryAdvice(CallBase &CB,
                                                                  bool Advice) {
  // Make sure we track inlinings in all cases - mandatory or not.
  if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
    return Skip;
  if (Advice && !ForceStop)
    return getMandatoryAdviceImpl(CB);

  // If this is a "never inline" case, there won't be any changes to internal
  // state we need to track, so we can just return the base InlineAdvice, which
  // will do nothing interesting.
  // Same if we are forced to stop - we don't track anymore.
  return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), Advice);
}

std::unique_ptr<MLInlineAdvice>
MLInlineAdvisor::getMandatoryAdviceImpl(CallBase &CB) {
  return std::make_unique<MLInlineAdvice>(this, CB, getCallerORE(CB), true);
}

/// Debug dump of the advisor's bookkeeping: counts, FPI cache, and function
/// levels (deleted functions shown as "<deleted>").
void MLInlineAdvisor::print(raw_ostream &OS) const {
  OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount
     << " EdgesOfLastSeenNodes: " << EdgesOfLastSeenNodes << "\n";
  OS << "[MLInlineAdvisor] FPI:\n";
  for (auto I : FPICache) {
    OS << I.first->getName() << ":\n";
    I.second.print(OS);
    OS << "\n";
  }
  OS << "\n";
  OS << "[MLInlineAdvisor] FuncLevels:\n";
  for (auto I : FunctionLevels)
    OS << (DeadFunctions.contains(&I.first->getFunction())
               ? "<deleted>"
               : I.first->getFunction().getName())
       << " : " << I.second << "\n";

  OS << "\n";
}

MLInlineAdvice::MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
                               OptimizationRemarkEmitter &ORE,
                               bool Recommendation)
    : InlineAdvice(Advisor, CB, ORE, Recommendation),
      // When the advisor is forced to stop, sizes/edges won't be used for
      // delta updates, so skip the (analysis-triggering) computations.
      CallerIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Caller)),
      CalleeIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Callee)),
      CallerAndCalleeEdges(Advisor->isForcedToStop()
                               ? 0
                               : (Advisor->getLocalCalls(*Caller) +
                                  Advisor->getLocalCalls(*Callee))),
      PreInlineCallerFPI(Advisor->getCachedFPI(*Caller)) {
  if (Recommendation)
    FPU.emplace(Advisor->getCachedFPI(*getCaller()), CB);
}

/// Attach callee name, all model feature values, and the decision to an
/// optimization remark.
void MLInlineAdvice::reportContextForRemark(
    DiagnosticInfoOptimizationBase &OR) {
  using namespace ore;
  OR << NV("Callee", Callee->getName());
  for (size_t I = 0; I < FeatureMap.size(); ++I)
    OR << NV(FeatureMap[I].name(),
             *getAdvisor()->getModelRunner().getTensor<int64_t>(I));
  OR << NV("ShouldInline", isInliningRecommended());
}

void MLInlineAdvice::updateCachedCallerFPI(FunctionAnalysisManager &FAM) const {
  FPU->finish(FAM);
}

void MLInlineAdvice::recordInliningImpl() {
  ORE.emit([&]() {
    OptimizationRemark R(DEBUG_TYPE, "InliningSuccess", DLoc, Block);
    reportContextForRemark(R);
    return R;
  });
  getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ false);
}

void MLInlineAdvice::recordInliningWithCalleeDeletedImpl() {
  ORE.emit([&]() {
    OptimizationRemark R(DEBUG_TYPE, "InliningSuccessWithCalleeDeleted", DLoc,
                         Block);
    reportContextForRemark(R);
    return R;
  });
  getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ true);
}

void MLInlineAdvice::recordUnsuccessfulInliningImpl(
    const InlineResult &Result) {
  // Inlining was attempted but failed: roll back the caller's cached FPI to
  // the pre-inline snapshot taken in the constructor.
  getAdvisor()->getCachedFPI(*Caller) = PreInlineCallerFPI;
  ORE.emit([&]() {
    OptimizationRemarkMissed R(DEBUG_TYPE, "InliningAttemptedAndUnsuccessful",
                               DLoc, Block);
    reportContextForRemark(R);
    return R;
  });
}
void MLInlineAdvice::recordUnattemptedInliningImpl() {
  assert(!FPU);
  // NOTE(review): remark name "IniningNotAttempted" is missing an 'l'
  // ("InliningNotAttempted"). Left as-is: remark names are runtime-visible
  // strings and external tooling may match on the existing spelling.
  ORE.emit([&]() {
    OptimizationRemarkMissed R(DEBUG_TYPE, "IniningNotAttempted", DLoc, Block);
    reportContextForRemark(R);
    return R;
  });
}