xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/MLInlineAdvisor.cpp (revision b59017c5cad90d0f09a59e68c00457b7faf93e7c)
1 //===- MLInlineAdvisor.cpp - machine learned InlineAdvisor ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the interface between the inliner and a learned model.
10 // It delegates model evaluation to either the AOT compiled model (the
11 // 'release' mode) or a runtime-loaded model (the 'development' case).
12 //
13 //===----------------------------------------------------------------------===//
14 #include "llvm/Analysis/MLInlineAdvisor.h"
15 #include "llvm/ADT/SCCIterator.h"
16 #include "llvm/Analysis/AssumptionCache.h"
17 #include "llvm/Analysis/BlockFrequencyInfo.h"
18 #include "llvm/Analysis/CallGraph.h"
19 #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
20 #include "llvm/Analysis/InlineCost.h"
21 #include "llvm/Analysis/InlineModelFeatureMaps.h"
22 #include "llvm/Analysis/InteractiveModelRunner.h"
23 #include "llvm/Analysis/LazyCallGraph.h"
24 #include "llvm/Analysis/LoopInfo.h"
25 #include "llvm/Analysis/MLModelRunner.h"
26 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
27 #include "llvm/Analysis/ProfileSummaryInfo.h"
28 #include "llvm/Analysis/ReleaseModeModelRunner.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/IR/Dominators.h"
31 #include "llvm/IR/InstIterator.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/PassManager.h"
34 #include "llvm/Support/CommandLine.h"
35 
36 using namespace llvm;
37 
38 static cl::opt<std::string> InteractiveChannelBaseName(
39     "inliner-interactive-channel-base", cl::Hidden,
40     cl::desc(
41         "Base file path for the interactive mode. The incoming filename should "
42         "have the name <inliner-interactive-channel-base>.in, while the "
43         "outgoing name should be <inliner-interactive-channel-base>.out"));
44 static const std::string InclDefaultMsg =
45     (Twine("In interactive mode, also send the default policy decision: ") +
46      DefaultDecisionName + ".")
47         .str();
48 static cl::opt<bool>
49     InteractiveIncludeDefault("inliner-interactive-include-default", cl::Hidden,
50                               cl::desc(InclDefaultMsg));
51 
52 enum class SkipMLPolicyCriteria { Never, IfCallerIsNotCold };
53 
54 static cl::opt<SkipMLPolicyCriteria> SkipPolicy(
55     "ml-inliner-skip-policy", cl::Hidden, cl::init(SkipMLPolicyCriteria::Never),
56     cl::values(clEnumValN(SkipMLPolicyCriteria::Never, "never", "never"),
57                clEnumValN(SkipMLPolicyCriteria::IfCallerIsNotCold,
58                           "if-caller-not-cold", "if the caller is not cold")));
59 
60 static cl::opt<std::string> ModelSelector("ml-inliner-model-selector",
61                                           cl::Hidden, cl::init(""));
62 
63 #if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL)
64 // codegen-ed file
65 #include "InlinerSizeModel.h" // NOLINT
66 using CompiledModelType = llvm::InlinerSizeModel;
67 #else
68 using CompiledModelType = NoopSavedModelImpl;
69 #endif
70 
71 std::unique_ptr<InlineAdvisor>
72 llvm::getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM,
73                             std::function<bool(CallBase &)> GetDefaultAdvice) {
74   if (!llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() &&
75       InteractiveChannelBaseName.empty())
76     return nullptr;
77   std::unique_ptr<MLModelRunner> AOTRunner;
78   if (InteractiveChannelBaseName.empty())
79     AOTRunner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
80         M.getContext(), FeatureMap, DecisionName,
81         EmbeddedModelRunnerOptions().setModelSelector(ModelSelector));
82   else {
83     auto Features = FeatureMap;
84     if (InteractiveIncludeDefault)
85       Features.push_back(DefaultDecisionSpec);
86     AOTRunner = std::make_unique<InteractiveModelRunner>(
87         M.getContext(), Features, InlineDecisionSpec,
88         InteractiveChannelBaseName + ".out",
89         InteractiveChannelBaseName + ".in");
90   }
91   return std::make_unique<MLInlineAdvisor>(M, MAM, std::move(AOTRunner),
92                                            GetDefaultAdvice);
93 }
94 
95 #define DEBUG_TYPE "inline-ml"
96 
97 static cl::opt<float> SizeIncreaseThreshold(
98     "ml-advisor-size-increase-threshold", cl::Hidden,
99     cl::desc("Maximum factor by which expected native size may increase before "
100              "blocking any further inlining."),
101     cl::init(2.0));
102 
103 static cl::opt<bool> KeepFPICache(
104     "ml-advisor-keep-fpi-cache", cl::Hidden,
105     cl::desc(
106         "For test - keep the ML Inline advisor's FunctionPropertiesInfo cache"),
107     cl::init(false));
108 
109 // clang-format off
110 const std::vector<TensorSpec> llvm::FeatureMap{
111 #define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
112 // InlineCost features - these must come first
113   INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
114 
115 // Non-cost features
116   INLINE_FEATURE_ITERATOR(POPULATE_NAMES)
117 #undef POPULATE_NAMES
118 };
119 // clang-format on
120 
121 const char *const llvm::DecisionName = "inlining_decision";
122 const TensorSpec llvm::InlineDecisionSpec =
123     TensorSpec::createSpec<int64_t>(DecisionName, {1});
124 const char *const llvm::DefaultDecisionName = "inlining_default";
125 const TensorSpec llvm::DefaultDecisionSpec =
126     TensorSpec::createSpec<int64_t>(DefaultDecisionName, {1});
127 const char *const llvm::RewardName = "delta_size";
128 
129 CallBase *getInlinableCS(Instruction &I) {
130   if (auto *CS = dyn_cast<CallBase>(&I))
131     if (Function *Callee = CS->getCalledFunction()) {
132       if (!Callee->isDeclaration()) {
133         return CS;
134       }
135     }
136   return nullptr;
137 }
138 
139 MLInlineAdvisor::MLInlineAdvisor(
140     Module &M, ModuleAnalysisManager &MAM,
141     std::unique_ptr<MLModelRunner> Runner,
142     std::function<bool(CallBase &)> GetDefaultAdvice)
143     : InlineAdvisor(
144           M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
145       ModelRunner(std::move(Runner)), GetDefaultAdvice(GetDefaultAdvice),
146       CG(MAM.getResult<LazyCallGraphAnalysis>(M)),
147       InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize),
148       PSI(MAM.getResult<ProfileSummaryAnalysis>(M)) {
149   assert(ModelRunner);
150   ModelRunner->switchContext("");
151   // Extract the 'call site height' feature - the position of a call site
152   // relative to the farthest statically reachable SCC node. We don't mutate
153   // this value while inlining happens. Empirically, this feature proved
154   // critical in behavioral cloning - i.e. training a model to mimic the manual
155   // heuristic's decisions - and, thus, equally important for training for
156   // improvement.
157   CallGraph CGraph(M);
158   for (auto I = scc_begin(&CGraph); !I.isAtEnd(); ++I) {
159     const std::vector<CallGraphNode *> &CGNodes = *I;
160     unsigned Level = 0;
161     for (auto *CGNode : CGNodes) {
162       Function *F = CGNode->getFunction();
163       if (!F || F->isDeclaration())
164         continue;
165       for (auto &I : instructions(F)) {
166         if (auto *CS = getInlinableCS(I)) {
167           auto *Called = CS->getCalledFunction();
168           auto Pos = FunctionLevels.find(&CG.get(*Called));
169           // In bottom up traversal, an inlinable callee is either in the
170           // same SCC, or to a function in a visited SCC. So not finding its
171           // level means we haven't visited it yet, meaning it's in this SCC.
172           if (Pos == FunctionLevels.end())
173             continue;
174           Level = std::max(Level, Pos->second + 1);
175         }
176       }
177     }
178     for (auto *CGNode : CGNodes) {
179       Function *F = CGNode->getFunction();
180       if (F && !F->isDeclaration())
181         FunctionLevels[&CG.get(*F)] = Level;
182     }
183   }
184   for (auto KVP : FunctionLevels) {
185     AllNodes.insert(KVP.first);
186     EdgeCount += getLocalCalls(KVP.first->getFunction());
187   }
188   NodeCount = AllNodes.size();
189 }
190 
191 unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
192   return CG.lookup(F) ? FunctionLevels.at(CG.lookup(F)) : 0;
193 }
194 
195 void MLInlineAdvisor::onPassEntry(LazyCallGraph::SCC *CurSCC) {
196   if (!CurSCC || ForceStop)
197     return;
198   FPICache.clear();
199   // Function passes executed between InlinerPass runs may have changed the
200   // module-wide features.
201   // The cgscc pass manager rules are such that:
202   // - if a pass leads to merging SCCs, then the pipeline is restarted on the
203   // merged SCC
204   // - if a pass leads to splitting the SCC, then we continue with one of the
205   // splits
206   // This means that the NodesInLastSCC is a superset (not strict) of the nodes
207   // that subsequent passes would have processed
208   // - in addition, if new Nodes were created by a pass (e.g. CoroSplit),
209   // they'd be adjacent to Nodes in the last SCC. So we just need to check the
210   // boundary of Nodes in NodesInLastSCC for Nodes we haven't seen. We don't
211   // care about the nature of the Edge (call or ref). `FunctionLevels`-wise, we
212   // record them at the same level as the original node (this is a choice, may
213   // need revisiting).
214   // - nodes are only deleted at the end of a call graph walk where they are
215   // batch deleted, so we shouldn't see any dead nodes here.
216   while (!NodesInLastSCC.empty()) {
217     const auto *N = *NodesInLastSCC.begin();
218     assert(!N->isDead());
219     NodesInLastSCC.erase(N);
220     EdgeCount += getLocalCalls(N->getFunction());
221     const auto NLevel = FunctionLevels.at(N);
222     for (const auto &E : *(*N)) {
223       const auto *AdjNode = &E.getNode();
224       assert(!AdjNode->isDead() && !AdjNode->getFunction().isDeclaration());
225       auto I = AllNodes.insert(AdjNode);
226       // We've discovered a new function.
227       if (I.second) {
228         ++NodeCount;
229         NodesInLastSCC.insert(AdjNode);
230         FunctionLevels[AdjNode] = NLevel;
231       }
232     }
233   }
234 
235   EdgeCount -= EdgesOfLastSeenNodes;
236   EdgesOfLastSeenNodes = 0;
237 
238   // (Re)use NodesInLastSCC to remember the nodes in the SCC right now,
239   // in case the SCC is split before onPassExit and some nodes are split out
240   assert(NodesInLastSCC.empty());
241   for (const auto &N : *CurSCC)
242     NodesInLastSCC.insert(&N);
243 }
244 
245 void MLInlineAdvisor::onPassExit(LazyCallGraph::SCC *CurSCC) {
246   // No need to keep this around - function passes will invalidate it.
247   if (!KeepFPICache)
248     FPICache.clear();
249   if (!CurSCC || ForceStop)
250     return;
251   // Keep track of the nodes and edges we last saw. Then, in onPassEntry,
252   // we update the node count and edge count from the subset of these nodes that
253   // survived.
254   EdgesOfLastSeenNodes = 0;
255 
256   // Check on nodes that were in SCC onPassEntry
257   for (const LazyCallGraph::Node *N : NodesInLastSCC) {
258     assert(!N->isDead());
259     EdgesOfLastSeenNodes += getLocalCalls(N->getFunction());
260   }
261 
262   // Check on nodes that may have got added to SCC
263   for (const auto &N : *CurSCC) {
264     assert(!N.isDead());
265     auto I = NodesInLastSCC.insert(&N);
266     if (I.second)
267       EdgesOfLastSeenNodes += getLocalCalls(N.getFunction());
268   }
269   assert(NodeCount >= NodesInLastSCC.size());
270   assert(EdgeCount >= EdgesOfLastSeenNodes);
271 }
272 
273 int64_t MLInlineAdvisor::getLocalCalls(Function &F) {
274   return getCachedFPI(F).DirectCallsToDefinedFunctions;
275 }
276 
277 // Update the internal state of the advisor, and force invalidate feature
278 // analysis. Currently, we maintain minimal (and very simple) global state - the
279 // number of functions and the number of static calls. We also keep track of the
280 // total IR size in this module, to stop misbehaving policies at a certain bloat
281 // factor (SizeIncreaseThreshold)
282 void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice,
283                                            bool CalleeWasDeleted) {
284   assert(!ForceStop);
285   Function *Caller = Advice.getCaller();
286   Function *Callee = Advice.getCallee();
287   // The caller features aren't valid anymore.
288   {
289     PreservedAnalyses PA = PreservedAnalyses::all();
290     PA.abandon<FunctionPropertiesAnalysis>();
291     PA.abandon<DominatorTreeAnalysis>();
292     PA.abandon<LoopAnalysis>();
293     FAM.invalidate(*Caller, PA);
294   }
295   Advice.updateCachedCallerFPI(FAM);
296   int64_t IRSizeAfter =
297       getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
298   CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
299   if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize)
300     ForceStop = true;
301 
302   // We can delta-update module-wide features. We know the inlining only changed
303   // the caller, and maybe the callee (by deleting the latter).
304   // Nodes are simple to update.
305   // For edges, we 'forget' the edges that the caller and callee used to have
306   // before inlining, and add back what they currently have together.
307   int64_t NewCallerAndCalleeEdges =
308       getCachedFPI(*Caller).DirectCallsToDefinedFunctions;
309 
310   // A dead function's node is not actually removed from the call graph until
311   // the end of the call graph walk, but the node no longer belongs to any valid
312   // SCC.
313   if (CalleeWasDeleted) {
314     --NodeCount;
315     NodesInLastSCC.erase(CG.lookup(*Callee));
316     DeadFunctions.insert(Callee);
317   } else {
318     NewCallerAndCalleeEdges +=
319         getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
320   }
321   EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
322   assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0);
323 }
324 
325 int64_t MLInlineAdvisor::getModuleIRSize() const {
326   int64_t Ret = 0;
327   for (auto &F : M)
328     if (!F.isDeclaration())
329       Ret += getIRSize(F);
330   return Ret;
331 }
332 
333 FunctionPropertiesInfo &MLInlineAdvisor::getCachedFPI(Function &F) const {
334   auto InsertPair =
335       FPICache.insert(std::make_pair(&F, FunctionPropertiesInfo()));
336   if (!InsertPair.second)
337     return InsertPair.first->second;
338   InsertPair.first->second = FAM.getResult<FunctionPropertiesAnalysis>(F);
339   return InsertPair.first->second;
340 }
341 
342 std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
343   if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
344     return Skip;
345 
346   auto &Caller = *CB.getCaller();
347   auto &Callee = *CB.getCalledFunction();
348 
349   auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
350     return FAM.getResult<AssumptionAnalysis>(F);
351   };
352   auto &TIR = FAM.getResult<TargetIRAnalysis>(Callee);
353   auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(Caller);
354 
355   if (SkipPolicy == SkipMLPolicyCriteria::IfCallerIsNotCold) {
356     if (!PSI.isFunctionEntryCold(&Caller))
357       return std::make_unique<InlineAdvice>(this, CB, ORE,
358                                             GetDefaultAdvice(CB));
359   }
360   auto MandatoryKind = InlineAdvisor::getMandatoryKind(CB, FAM, ORE);
361   // If this is a "never inline" case, there won't be any changes to internal
362   // state we need to track, so we can just return the base InlineAdvice, which
363   // will do nothing interesting.
364   // Same thing if this is a recursive case.
365   if (MandatoryKind == InlineAdvisor::MandatoryInliningKind::Never ||
366       &Caller == &Callee)
367     return getMandatoryAdvice(CB, false);
368 
369   bool Mandatory =
370       MandatoryKind == InlineAdvisor::MandatoryInliningKind::Always;
371 
372   // If we need to stop, we won't want to track anymore any state changes, so
373   // we just return the base InlineAdvice, which acts as a noop.
374   if (ForceStop) {
375     ORE.emit([&] {
376       return OptimizationRemarkMissed(DEBUG_TYPE, "ForceStop", &CB)
377              << "Won't attempt inlining because module size grew too much.";
378     });
379     return std::make_unique<InlineAdvice>(this, CB, ORE, Mandatory);
380   }
381 
382   int CostEstimate = 0;
383   if (!Mandatory) {
384     auto IsCallSiteInlinable =
385         llvm::getInliningCostEstimate(CB, TIR, GetAssumptionCache);
386     if (!IsCallSiteInlinable) {
387       // We can't inline this for correctness reasons, so return the base
388       // InlineAdvice, as we don't care about tracking any state changes (which
389       // won't happen).
390       return std::make_unique<InlineAdvice>(this, CB, ORE, false);
391     }
392     CostEstimate = *IsCallSiteInlinable;
393   }
394 
395   const auto CostFeatures =
396       llvm::getInliningCostFeatures(CB, TIR, GetAssumptionCache);
397   if (!CostFeatures) {
398     return std::make_unique<InlineAdvice>(this, CB, ORE, false);
399   }
400 
401   if (Mandatory)
402     return getMandatoryAdvice(CB, true);
403 
404   auto NrCtantParams = 0;
405   for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) {
406     NrCtantParams += (isa<Constant>(*I));
407   }
408 
409   auto &CallerBefore = getCachedFPI(Caller);
410   auto &CalleeBefore = getCachedFPI(Callee);
411 
412   *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_basic_block_count) =
413       CalleeBefore.BasicBlockCount;
414   *ModelRunner->getTensor<int64_t>(FeatureIndex::callsite_height) =
415       getInitialFunctionLevel(Caller);
416   *ModelRunner->getTensor<int64_t>(FeatureIndex::node_count) = NodeCount;
417   *ModelRunner->getTensor<int64_t>(FeatureIndex::nr_ctant_params) =
418       NrCtantParams;
419   *ModelRunner->getTensor<int64_t>(FeatureIndex::edge_count) = EdgeCount;
420   *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_users) =
421       CallerBefore.Uses;
422   *ModelRunner->getTensor<int64_t>(
423       FeatureIndex::caller_conditionally_executed_blocks) =
424       CallerBefore.BlocksReachedFromConditionalInstruction;
425   *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_basic_block_count) =
426       CallerBefore.BasicBlockCount;
427   *ModelRunner->getTensor<int64_t>(
428       FeatureIndex::callee_conditionally_executed_blocks) =
429       CalleeBefore.BlocksReachedFromConditionalInstruction;
430   *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_users) =
431       CalleeBefore.Uses;
432   *ModelRunner->getTensor<int64_t>(FeatureIndex::cost_estimate) = CostEstimate;
433   *ModelRunner->getTensor<int64_t>(FeatureIndex::is_callee_avail_external) =
434       Callee.hasAvailableExternallyLinkage();
435   *ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
436       Caller.hasAvailableExternallyLinkage();
437 
438   // Add the cost features
439   for (size_t I = 0;
440        I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
441     *ModelRunner->getTensor<int64_t>(inlineCostFeatureToMlFeature(
442         static_cast<InlineCostFeatureIndex>(I))) = CostFeatures->at(I);
443   }
444   // This one would have been set up to be right at the end.
445   if (!InteractiveChannelBaseName.empty() && InteractiveIncludeDefault)
446     *ModelRunner->getTensor<int64_t>(InlineCostFeatureIndex::NumberOfFeatures) =
447         GetDefaultAdvice(CB);
448   return getAdviceFromModel(CB, ORE);
449 }
450 
451 std::unique_ptr<MLInlineAdvice>
452 MLInlineAdvisor::getAdviceFromModel(CallBase &CB,
453                                     OptimizationRemarkEmitter &ORE) {
454   return std::make_unique<MLInlineAdvice>(
455       this, CB, ORE, static_cast<bool>(ModelRunner->evaluate<int64_t>()));
456 }
457 
458 std::unique_ptr<InlineAdvice>
459 MLInlineAdvisor::getSkipAdviceIfUnreachableCallsite(CallBase &CB) {
460   if (!FAM.getResult<DominatorTreeAnalysis>(*CB.getCaller())
461            .isReachableFromEntry(CB.getParent()))
462     return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
463   return nullptr;
464 }
465 
466 std::unique_ptr<InlineAdvice> MLInlineAdvisor::getMandatoryAdvice(CallBase &CB,
467                                                                   bool Advice) {
468   // Make sure we track inlinings in all cases - mandatory or not.
469   if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
470     return Skip;
471   if (Advice && !ForceStop)
472     return getMandatoryAdviceImpl(CB);
473 
474   // If this is a "never inline" case, there won't be any changes to internal
475   // state we need to track, so we can just return the base InlineAdvice, which
476   // will do nothing interesting.
477   // Same if we are forced to stop - we don't track anymore.
478   return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), Advice);
479 }
480 
481 std::unique_ptr<MLInlineAdvice>
482 MLInlineAdvisor::getMandatoryAdviceImpl(CallBase &CB) {
483   return std::make_unique<MLInlineAdvice>(this, CB, getCallerORE(CB), true);
484 }
485 
486 void MLInlineAdvisor::print(raw_ostream &OS) const {
487   OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount
488      << " EdgesOfLastSeenNodes: " << EdgesOfLastSeenNodes << "\n";
489   OS << "[MLInlineAdvisor] FPI:\n";
490   for (auto I : FPICache) {
491     OS << I.first->getName() << ":\n";
492     I.second.print(OS);
493     OS << "\n";
494   }
495   OS << "\n";
496   OS << "[MLInlineAdvisor] FuncLevels:\n";
497   for (auto I : FunctionLevels)
498     OS << (DeadFunctions.contains(&I.first->getFunction())
499                ? "<deleted>"
500                : I.first->getFunction().getName())
501        << " : " << I.second << "\n";
502 
503   OS << "\n";
504 }
505 
506 MLInlineAdvice::MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
507                                OptimizationRemarkEmitter &ORE,
508                                bool Recommendation)
509     : InlineAdvice(Advisor, CB, ORE, Recommendation),
510       CallerIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Caller)),
511       CalleeIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Callee)),
512       CallerAndCalleeEdges(Advisor->isForcedToStop()
513                                ? 0
514                                : (Advisor->getLocalCalls(*Caller) +
515                                   Advisor->getLocalCalls(*Callee))),
516       PreInlineCallerFPI(Advisor->getCachedFPI(*Caller)) {
517   if (Recommendation)
518     FPU.emplace(Advisor->getCachedFPI(*getCaller()), CB);
519 }
520 
521 void MLInlineAdvice::reportContextForRemark(
522     DiagnosticInfoOptimizationBase &OR) {
523   using namespace ore;
524   OR << NV("Callee", Callee->getName());
525   for (size_t I = 0; I < NumberOfFeatures; ++I)
526     OR << NV(FeatureMap[I].name(),
527              *getAdvisor()->getModelRunner().getTensor<int64_t>(I));
528   OR << NV("ShouldInline", isInliningRecommended());
529 }
530 
531 void MLInlineAdvice::updateCachedCallerFPI(FunctionAnalysisManager &FAM) const {
532   FPU->finish(FAM);
533 }
534 
535 void MLInlineAdvice::recordInliningImpl() {
536   ORE.emit([&]() {
537     OptimizationRemark R(DEBUG_TYPE, "InliningSuccess", DLoc, Block);
538     reportContextForRemark(R);
539     return R;
540   });
541   getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ false);
542 }
543 
544 void MLInlineAdvice::recordInliningWithCalleeDeletedImpl() {
545   ORE.emit([&]() {
546     OptimizationRemark R(DEBUG_TYPE, "InliningSuccessWithCalleeDeleted", DLoc,
547                          Block);
548     reportContextForRemark(R);
549     return R;
550   });
551   getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ true);
552 }
553 
554 void MLInlineAdvice::recordUnsuccessfulInliningImpl(
555     const InlineResult &Result) {
556   getAdvisor()->getCachedFPI(*Caller) = PreInlineCallerFPI;
557   ORE.emit([&]() {
558     OptimizationRemarkMissed R(DEBUG_TYPE, "InliningAttemptedAndUnsuccessful",
559                                DLoc, Block);
560     reportContextForRemark(R);
561     return R;
562   });
563 }
564 void MLInlineAdvice::recordUnattemptedInliningImpl() {
565   assert(!FPU);
566   ORE.emit([&]() {
567     OptimizationRemarkMissed R(DEBUG_TYPE, "IniningNotAttempted", DLoc, Block);
568     reportContextForRemark(R);
569     return R;
570   });
571 }
572