xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/MLInlineAdvisor.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- MLInlineAdvisor.cpp - machine learned InlineAdvisor ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the interface between the inliner and a learned model.
10 // It delegates model evaluation to either the AOT compiled model (the
11 // 'release' mode) or a runtime-loaded model (the 'development' case).
12 //
13 //===----------------------------------------------------------------------===//
14 #include "llvm/Analysis/MLInlineAdvisor.h"
15 #include "llvm/ADT/SCCIterator.h"
16 #include "llvm/Analysis/AssumptionCache.h"
17 #include "llvm/Analysis/BlockFrequencyInfo.h"
18 #include "llvm/Analysis/CallGraph.h"
19 #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
20 #include "llvm/Analysis/InlineCost.h"
21 #include "llvm/Analysis/InlineModelFeatureMaps.h"
22 #include "llvm/Analysis/InteractiveModelRunner.h"
23 #include "llvm/Analysis/LazyCallGraph.h"
24 #include "llvm/Analysis/LoopInfo.h"
25 #include "llvm/Analysis/MLModelRunner.h"
26 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
27 #include "llvm/Analysis/ProfileSummaryInfo.h"
28 #include "llvm/Analysis/ReleaseModeModelRunner.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/IR/Dominators.h"
31 #include "llvm/IR/InstIterator.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/PassManager.h"
34 #include "llvm/Support/CommandLine.h"
35 
36 using namespace llvm;
37 
// Base path for the pair of named pipes used by the interactive mode: the
// advisor writes features to "<base>.out" and reads decisions from
// "<base>.in".
static cl::opt<std::string> InteractiveChannelBaseName(
    "inliner-interactive-channel-base", cl::Hidden,
    cl::desc(
        "Base file path for the interactive mode. The incoming filename should "
        "have the name <inliner-interactive-channel-base>.in, while the "
        "outgoing name should be <inliner-interactive-channel-base>.out"));
// Built at static-init time so the option description can embed
// DefaultDecisionName.
static const std::string InclDefaultMsg =
    (Twine("In interactive mode, also send the default policy decision: ") +
     DefaultDecisionName + ".")
        .str();
static cl::opt<bool>
    InteractiveIncludeDefault("inliner-interactive-include-default", cl::Hidden,
                              cl::desc(InclDefaultMsg));

// Criteria for bypassing the ML policy and returning the default (heuristic)
// advice for a call site instead.
enum class SkipMLPolicyCriteria { Never, IfCallerIsNotCold };

static cl::opt<SkipMLPolicyCriteria> SkipPolicy(
    "ml-inliner-skip-policy", cl::Hidden, cl::init(SkipMLPolicyCriteria::Never),
    cl::values(clEnumValN(SkipMLPolicyCriteria::Never, "never", "never"),
               clEnumValN(SkipMLPolicyCriteria::IfCallerIsNotCold,
                          "if-caller-not-cold", "if the caller is not cold")));

// Forwarded to the release-mode model runner, to select among multiple
// embedded models when the AOT artifact supports selection.
static cl::opt<std::string> ModelSelector("ml-inliner-model-selector",
                                          cl::Hidden, cl::init(""));

#if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL)
// codegen-ed file
#include "InlinerSizeModel.h" // NOLINT
using CompiledModelType = llvm::InlinerSizeModel;
#else
// No AOT-compiled model available; the no-op implementation makes
// isEmbeddedModelEvaluatorValid<CompiledModelType>() return false below.
using CompiledModelType = NoopSavedModelImpl;
#endif
70 
71 std::unique_ptr<InlineAdvisor>
72 llvm::getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM,
73                             std::function<bool(CallBase &)> GetDefaultAdvice) {
74   if (!llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() &&
75       InteractiveChannelBaseName.empty())
76     return nullptr;
77   std::unique_ptr<MLModelRunner> AOTRunner;
78   if (InteractiveChannelBaseName.empty())
79     AOTRunner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
80         M.getContext(), FeatureMap, DecisionName,
81         EmbeddedModelRunnerOptions().setModelSelector(ModelSelector));
82   else {
83     auto Features = FeatureMap;
84     if (InteractiveIncludeDefault)
85       Features.push_back(DefaultDecisionSpec);
86     AOTRunner = std::make_unique<InteractiveModelRunner>(
87         M.getContext(), Features, InlineDecisionSpec,
88         InteractiveChannelBaseName + ".out",
89         InteractiveChannelBaseName + ".in");
90   }
91   return std::make_unique<MLInlineAdvisor>(M, MAM, std::move(AOTRunner),
92                                            GetDefaultAdvice);
93 }
94 
#define DEBUG_TYPE "inline-ml"

// Safety valve: once the expected native size of the module exceeds this
// factor of its initial size, the advisor stops recommending any inlining.
static cl::opt<float> SizeIncreaseThreshold(
    "ml-advisor-size-increase-threshold", cl::Hidden,
    cl::desc("Maximum factor by which expected native size may increase before "
             "blocking any further inlining."),
    cl::init(2.0));

static cl::opt<bool> KeepFPICache(
    "ml-advisor-keep-fpi-cache", cl::Hidden,
    cl::desc(
        "For test - keep the ML Inline advisor's FunctionPropertiesInfo cache"),
    cl::init(false));

// clang-format off
// Tensor specs for the features fed to the model, in the order the model
// expects them. NOTE(review): this is a mutable global - the MLInlineAdvisor
// constructor may append IR2Vec embedding specs to it (see below).
std::vector<TensorSpec> llvm::FeatureMap{
#define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
// InlineCost features - these must come first
  INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)

// Non-cost features
  INLINE_FEATURE_ITERATOR(POPULATE_NAMES)
#undef POPULATE_NAMES
};
// clang-format on

// Names/specs of the model's decision output, the optional echoed default
// decision, and the training-time reward signal.
const char *const llvm::DecisionName = "inlining_decision";
const TensorSpec llvm::InlineDecisionSpec =
    TensorSpec::createSpec<int64_t>(DecisionName, {1});
const char *const llvm::DefaultDecisionName = "inlining_default";
const TensorSpec llvm::DefaultDecisionSpec =
    TensorSpec::createSpec<int64_t>(DefaultDecisionName, {1});
const char *const llvm::RewardName = "delta_size";
128 
129 CallBase *getInlinableCS(Instruction &I) {
130   if (auto *CS = dyn_cast<CallBase>(&I))
131     if (Function *Callee = CS->getCalledFunction()) {
132       if (!Callee->isDeclaration()) {
133         return CS;
134       }
135     }
136   return nullptr;
137 }
138 
// Construct the advisor: capture module-wide state (IR size, node/edge
// counts), compute the static per-function "call site height" feature, and
// optionally extend the feature map with IR2Vec embeddings.
MLInlineAdvisor::MLInlineAdvisor(
    Module &M, ModuleAnalysisManager &MAM,
    std::unique_ptr<MLModelRunner> Runner,
    std::function<bool(CallBase &)> GetDefaultAdvice)
    : InlineAdvisor(
          M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
      ModelRunner(std::move(Runner)), GetDefaultAdvice(GetDefaultAdvice),
      CG(MAM.getResult<LazyCallGraphAnalysis>(M)),
      UseIR2Vec(MAM.getCachedResult<IR2VecVocabAnalysis>(M) != nullptr),
      InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize),
      PSI(MAM.getResult<ProfileSummaryAnalysis>(M)) {
  assert(ModelRunner);
  ModelRunner->switchContext("");
  // Extract the 'call site height' feature - the position of a call site
  // relative to the farthest statically reachable SCC node. We don't mutate
  // this value while inlining happens. Empirically, this feature proved
  // critical in behavioral cloning - i.e. training a model to mimic the manual
  // heuristic's decisions - and, thus, equally important for training for
  // improvement.
  CallGraph CGraph(M);
  // Bottom-up SCC walk: callees are visited before (or together with) their
  // callers, so a callee's level is known unless it is in the current SCC.
  for (auto I = scc_begin(&CGraph); !I.isAtEnd(); ++I) {
    const std::vector<CallGraphNode *> &CGNodes = *I;
    unsigned Level = 0;
    // First pass: the SCC's level is 1 + the max level of any callee in an
    // already-visited SCC.
    for (auto *CGNode : CGNodes) {
      Function *F = CGNode->getFunction();
      if (!F || F->isDeclaration())
        continue;
      for (auto &I : instructions(F)) {
        if (auto *CS = getInlinableCS(I)) {
          auto *Called = CS->getCalledFunction();
          auto Pos = FunctionLevels.find(&CG.get(*Called));
          // In bottom up traversal, an inlinable callee is either in the
          // same SCC, or to a function in a visited SCC. So not finding its
          // level means we haven't visited it yet, meaning it's in this SCC.
          if (Pos == FunctionLevels.end())
            continue;
          Level = std::max(Level, Pos->second + 1);
        }
      }
    }
    // Second pass: record the shared level for every defined function in the
    // SCC.
    for (auto *CGNode : CGNodes) {
      Function *F = CGNode->getFunction();
      if (F && !F->isDeclaration())
        FunctionLevels[&CG.get(*F)] = Level;
    }
  }
  // Seed the module-wide node and edge counts from the functions discovered
  // above.
  for (auto KVP : FunctionLevels) {
    AllNodes.insert(KVP.first);
    EdgeCount += getLocalCalls(KVP.first->getFunction());
  }
  NodeCount = AllNodes.size();

  if (auto IR2VecVocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(M)) {
    if (!IR2VecVocabResult->isValid()) {
      M.getContext().emitError("IR2VecVocabAnalysis is not valid");
      return;
    }
    // Add the IR2Vec features to the feature map
    // NOTE(review): FeatureMap is a mutable global (see its definition above);
    // constructing more than one advisor with IR2Vec enabled would append
    // these specs repeatedly - confirm this is intended / single-advisor only.
    auto IR2VecDim = IR2VecVocabResult->getDimension();
    FeatureMap.push_back(
        TensorSpec::createSpec<float>("callee_embedding", {IR2VecDim}));
    FeatureMap.push_back(
        TensorSpec::createSpec<float>("caller_embedding", {IR2VecDim}));
  }
}
204 
205 unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
206   return CG.lookup(F) ? FunctionLevels.at(CG.lookup(F)) : 0;
207 }
208 
// Re-sync module-wide node/edge state at the start of an inliner run over
// CurSCC, accounting for what other passes did since the last run.
void MLInlineAdvisor::onPassEntry(LazyCallGraph::SCC *CurSCC) {
  if (!CurSCC || ForceStop)
    return;
  FPICache.clear();
  // Function passes executed between InlinerPass runs may have changed the
  // module-wide features.
  // The cgscc pass manager rules are such that:
  // - if a pass leads to merging SCCs, then the pipeline is restarted on the
  // merged SCC
  // - if a pass leads to splitting the SCC, then we continue with one of the
  // splits
  // This means that the NodesInLastSCC is a superset (not strict) of the nodes
  // that subsequent passes would have processed
  // - in addition, if new Nodes were created by a pass (e.g. CoroSplit),
  // they'd be adjacent to Nodes in the last SCC. So we just need to check the
  // boundary of Nodes in NodesInLastSCC for Nodes we haven't seen. We don't
  // care about the nature of the Edge (call or ref). `FunctionLevels`-wise, we
  // record them at the same level as the original node (this is a choice, may
  // need revisiting).
  // - nodes are only deleted at the end of a call graph walk where they are
  // batch deleted, so we shouldn't see any dead nodes here.
  // Drain NodesInLastSCC as a worklist: newly discovered adjacent nodes are
  // inserted back into it, so the loop also processes their neighborhoods.
  while (!NodesInLastSCC.empty()) {
    const auto *N = *NodesInLastSCC.begin();
    assert(!N->isDead());
    NodesInLastSCC.erase(N);
    EdgeCount += getLocalCalls(N->getFunction());
    const auto NLevel = FunctionLevels.at(N);
    for (const auto &E : *(*N)) {
      const auto *AdjNode = &E.getNode();
      assert(!AdjNode->isDead() && !AdjNode->getFunction().isDeclaration());
      auto I = AllNodes.insert(AdjNode);
      // We've discovered a new function.
      if (I.second) {
        ++NodeCount;
        NodesInLastSCC.insert(AdjNode);
        FunctionLevels[AdjNode] = NLevel;
      }
    }
  }

  // The edges of the nodes seen at the last onPassExit were counted then;
  // remove them so the per-node additions above don't double-count.
  EdgeCount -= EdgesOfLastSeenNodes;
  EdgesOfLastSeenNodes = 0;

  // (Re)use NodesInLastSCC to remember the nodes in the SCC right now,
  // in case the SCC is split before onPassExit and some nodes are split out
  assert(NodesInLastSCC.empty());
  for (const auto &N : *CurSCC)
    NodesInLastSCC.insert(&N);
}
258 
// Snapshot the edges of the nodes we just processed, so the next onPassEntry
// can reconcile the module-wide edge count.
void MLInlineAdvisor::onPassExit(LazyCallGraph::SCC *CurSCC) {
  // No need to keep this around - function passes will invalidate it.
  if (!KeepFPICache)
    FPICache.clear();
  if (!CurSCC || ForceStop)
    return;
  // Keep track of the nodes and edges we last saw. Then, in onPassEntry,
  // we update the node count and edge count from the subset of these nodes that
  // survived.
  EdgesOfLastSeenNodes = 0;

  // Check on nodes that were in SCC onPassEntry
  for (const LazyCallGraph::Node *N : NodesInLastSCC) {
    assert(!N->isDead());
    EdgesOfLastSeenNodes += getLocalCalls(N->getFunction());
  }

  // Check on nodes that may have got added to SCC
  for (const auto &N : *CurSCC) {
    assert(!N.isDead());
    // insert() returns {iterator, inserted}; only count the edges of nodes we
    // hadn't already accounted for above.
    auto I = NodesInLastSCC.insert(&N);
    if (I.second)
      EdgesOfLastSeenNodes += getLocalCalls(N.getFunction());
  }
  assert(NodeCount >= NodesInLastSCC.size());
  assert(EdgeCount >= EdgesOfLastSeenNodes);
}
286 
287 int64_t MLInlineAdvisor::getLocalCalls(Function &F) {
288   return getCachedFPI(F).DirectCallsToDefinedFunctions;
289 }
290 
// Update the internal state of the advisor, and force invalidate feature
// analysis. Currently, we maintain minimal (and very simple) global state - the
// number of functions and the number of static calls. We also keep track of the
// total IR size in this module, to stop misbehaving policies at a certain bloat
// factor (SizeIncreaseThreshold)
void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice,
                                           bool CalleeWasDeleted) {
  assert(!ForceStop);
  Function *Caller = Advice.getCaller();
  Function *Callee = Advice.getCallee();
  // The caller features aren't valid anymore.
  {
    PreservedAnalyses PA = PreservedAnalyses::all();
    PA.abandon<FunctionPropertiesAnalysis>();
    PA.abandon<LoopAnalysis>();
    FAM.invalidate(*Caller, PA);
  }
  // Commit the pending FPI update for the caller into our cache.
  Advice.updateCachedCallerFPI(FAM);
  // Delta-update the total IR size: the caller changed, and the callee's size
  // either persists (not deleted) or is gone entirely.
  int64_t IRSizeAfter =
      getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
  CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
  // Trip the safety valve once the module grew past the allowed bloat factor.
  if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize)
    ForceStop = true;

  // We can delta-update module-wide features. We know the inlining only changed
  // the caller, and maybe the callee (by deleting the latter).
  // Nodes are simple to update.
  // For edges, we 'forget' the edges that the caller and callee used to have
  // before inlining, and add back what they currently have together.
  int64_t NewCallerAndCalleeEdges =
      getCachedFPI(*Caller).DirectCallsToDefinedFunctions;

  // A dead function's node is not actually removed from the call graph until
  // the end of the call graph walk, but the node no longer belongs to any valid
  // SCC.
  if (CalleeWasDeleted) {
    --NodeCount;
    NodesInLastSCC.erase(CG.lookup(*Callee));
    DeadFunctions.insert(Callee);
  } else {
    NewCallerAndCalleeEdges +=
        getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
  }
  EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
  assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0);
}
337 
338 int64_t MLInlineAdvisor::getModuleIRSize() const {
339   int64_t Ret = 0;
340   for (auto &F : M)
341     if (!F.isDeclaration())
342       Ret += getIRSize(F);
343   return Ret;
344 }
345 
346 FunctionPropertiesInfo &MLInlineAdvisor::getCachedFPI(Function &F) const {
347   auto InsertPair = FPICache.try_emplace(&F);
348   if (!InsertPair.second)
349     return InsertPair.first->second;
350   InsertPair.first->second = FAM.getResult<FunctionPropertiesAnalysis>(F);
351   return InsertPair.first->second;
352 }
353 
// Main entry point: gather features for call site CB, run the model, and
// return the advice. Falls back to the inert base InlineAdvice whenever no
// model-tracked state change can happen (unreachable site, "never inline",
// recursion, force-stop, or a call site that can't be inlined for
// correctness).
std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
  if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
    return Skip;

  auto &Caller = *CB.getCaller();
  auto &Callee = *CB.getCalledFunction();

  auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
    return FAM.getResult<AssumptionAnalysis>(F);
  };
  auto &TIR = FAM.getResult<TargetIRAnalysis>(Callee);
  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(Caller);

  // Optionally bypass the ML policy entirely for non-cold callers and defer
  // to the default heuristic's decision.
  if (SkipPolicy == SkipMLPolicyCriteria::IfCallerIsNotCold) {
    if (!PSI.isFunctionEntryCold(&Caller))
      return std::make_unique<InlineAdvice>(this, CB, ORE,
                                            GetDefaultAdvice(CB));
  }
  auto MandatoryKind = InlineAdvisor::getMandatoryKind(CB, FAM, ORE);
  // If this is a "never inline" case, there won't be any changes to internal
  // state we need to track, so we can just return the base InlineAdvice, which
  // will do nothing interesting.
  // Same thing if this is a recursive case.
  if (MandatoryKind == InlineAdvisor::MandatoryInliningKind::Never ||
      &Caller == &Callee)
    return getMandatoryAdvice(CB, false);

  bool Mandatory =
      MandatoryKind == InlineAdvisor::MandatoryInliningKind::Always;

  // If we need to stop, we won't want to track anymore any state changes, so
  // we just return the base InlineAdvice, which acts as a noop.
  if (ForceStop) {
    ORE.emit([&] {
      return OptimizationRemarkMissed(DEBUG_TYPE, "ForceStop", &CB)
             << "Won't attempt inlining because module size grew too much.";
    });
    return std::make_unique<InlineAdvice>(this, CB, ORE, Mandatory);
  }

  int CostEstimate = 0;
  if (!Mandatory) {
    auto IsCallSiteInlinable =
        llvm::getInliningCostEstimate(CB, TIR, GetAssumptionCache);
    if (!IsCallSiteInlinable) {
      // We can't inline this for correctness reasons, so return the base
      // InlineAdvice, as we don't care about tracking any state changes (which
      // won't happen).
      return std::make_unique<InlineAdvice>(this, CB, ORE, false);
    }
    CostEstimate = *IsCallSiteInlinable;
  }

  // Cost features are computed even for mandatory call sites, since the
  // computation can itself fail (in which case we bail).
  const auto CostFeatures =
      llvm::getInliningCostFeatures(CB, TIR, GetAssumptionCache);
  if (!CostFeatures) {
    return std::make_unique<InlineAdvice>(this, CB, ORE, false);
  }

  if (Mandatory)
    return getMandatoryAdvice(CB, true);

  // Count the constant-valued actual arguments at this call site.
  auto NumCtantParams = 0;
  for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) {
    NumCtantParams += (isa<Constant>(*I));
  }

  auto &CallerBefore = getCachedFPI(Caller);
  auto &CalleeBefore = getCachedFPI(Callee);

  // Populate the model's input tensors, one slot per feature.
  *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_basic_block_count) =
      CalleeBefore.BasicBlockCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::callsite_height) =
      getInitialFunctionLevel(Caller);
  *ModelRunner->getTensor<int64_t>(FeatureIndex::node_count) = NodeCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::nr_ctant_params) =
      NumCtantParams;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::edge_count) = EdgeCount;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_users) =
      CallerBefore.Uses;
  *ModelRunner->getTensor<int64_t>(
      FeatureIndex::caller_conditionally_executed_blocks) =
      CallerBefore.BlocksReachedFromConditionalInstruction;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_basic_block_count) =
      CallerBefore.BasicBlockCount;
  *ModelRunner->getTensor<int64_t>(
      FeatureIndex::callee_conditionally_executed_blocks) =
      CalleeBefore.BlocksReachedFromConditionalInstruction;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_users) =
      CalleeBefore.Uses;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::cost_estimate) = CostEstimate;
  *ModelRunner->getTensor<int64_t>(FeatureIndex::is_callee_avail_external) =
      Callee.hasAvailableExternallyLinkage();
  *ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
      Caller.hasAvailableExternallyLinkage();

  if (UseIR2Vec) {
    // Python side expects float embeddings. The IR2Vec embeddings are doubles
    // as of now due to the restriction of fromJSON method used by the
    // readVocabulary method in ir2vec::Embeddings.
    auto setEmbedding = [&](const ir2vec::Embedding &Embedding,
                            FeatureIndex Index) {
      llvm::transform(Embedding, ModelRunner->getTensor<float>(Index),
                      [](double Val) { return static_cast<float>(Val); });
    };

    setEmbedding(CalleeBefore.getFunctionEmbedding(),
                 FeatureIndex::callee_embedding);
    setEmbedding(CallerBefore.getFunctionEmbedding(),
                 FeatureIndex::caller_embedding);
  }

  // Add the cost features
  for (size_t I = 0;
       I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
    *ModelRunner->getTensor<int64_t>(inlineCostFeatureToMlFeature(
        static_cast<InlineCostFeatureIndex>(I))) = CostFeatures->at(I);
  }
  // This one would have been set up to be right at the end.
  if (!InteractiveChannelBaseName.empty() && InteractiveIncludeDefault)
    *ModelRunner->getTensor<int64_t>(FeatureMap.size()) = GetDefaultAdvice(CB);
  return getAdviceFromModel(CB, ORE);
}
477 
478 std::unique_ptr<MLInlineAdvice>
479 MLInlineAdvisor::getAdviceFromModel(CallBase &CB,
480                                     OptimizationRemarkEmitter &ORE) {
481   return std::make_unique<MLInlineAdvice>(
482       this, CB, ORE, static_cast<bool>(ModelRunner->evaluate<int64_t>()));
483 }
484 
485 std::unique_ptr<InlineAdvice>
486 MLInlineAdvisor::getSkipAdviceIfUnreachableCallsite(CallBase &CB) {
487   if (!FAM.getResult<DominatorTreeAnalysis>(*CB.getCaller())
488            .isReachableFromEntry(CB.getParent()))
489     return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
490   return nullptr;
491 }
492 
493 std::unique_ptr<InlineAdvice> MLInlineAdvisor::getMandatoryAdvice(CallBase &CB,
494                                                                   bool Advice) {
495   // Make sure we track inlinings in all cases - mandatory or not.
496   if (auto Skip = getSkipAdviceIfUnreachableCallsite(CB))
497     return Skip;
498   if (Advice && !ForceStop)
499     return getMandatoryAdviceImpl(CB);
500 
501   // If this is a "never inline" case, there won't be any changes to internal
502   // state we need to track, so we can just return the base InlineAdvice, which
503   // will do nothing interesting.
504   // Same if we are forced to stop - we don't track anymore.
505   return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), Advice);
506 }
507 
508 std::unique_ptr<MLInlineAdvice>
509 MLInlineAdvisor::getMandatoryAdviceImpl(CallBase &CB) {
510   return std::make_unique<MLInlineAdvice>(this, CB, getCallerORE(CB), true);
511 }
512 
513 void MLInlineAdvisor::print(raw_ostream &OS) const {
514   OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount
515      << " EdgesOfLastSeenNodes: " << EdgesOfLastSeenNodes << "\n";
516   OS << "[MLInlineAdvisor] FPI:\n";
517   for (auto I : FPICache) {
518     OS << I.first->getName() << ":\n";
519     I.second.print(OS);
520     OS << "\n";
521   }
522   OS << "\n";
523   OS << "[MLInlineAdvisor] FuncLevels:\n";
524   for (auto I : FunctionLevels)
525     OS << (DeadFunctions.contains(&I.first->getFunction())
526                ? "<deleted>"
527                : I.first->getFunction().getName())
528        << " : " << I.second << "\n";
529 
530   OS << "\n";
531 }
532 
// Snapshot pre-inline IR sizes and edge counts so onSuccessfulInlining can
// delta-update the advisor's module-wide state later. When the advisor is
// force-stopped those values are never used, so they are skipped (recorded
// as 0) to avoid recomputing function properties.
MLInlineAdvice::MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
                               OptimizationRemarkEmitter &ORE,
                               bool Recommendation)
    : InlineAdvice(Advisor, CB, ORE, Recommendation),
      CallerIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Caller)),
      CalleeIRSize(Advisor->isForcedToStop() ? 0 : Advisor->getIRSize(*Callee)),
      CallerAndCalleeEdges(Advisor->isForcedToStop()
                               ? 0
                               : (Advisor->getLocalCalls(*Caller) +
                                  Advisor->getLocalCalls(*Callee))),
      PreInlineCallerFPI(Advisor->getCachedFPI(*Caller)) {
  // Only a "yes" recommendation can result in an inlining we must account
  // for; arm the FunctionPropertiesUpdater in that case (consumed by
  // updateCachedCallerFPI).
  if (Recommendation)
    FPU.emplace(Advisor->getCachedFPI(*getCaller()), CB);
}
547 
548 void MLInlineAdvice::reportContextForRemark(
549     DiagnosticInfoOptimizationBase &OR) {
550   using namespace ore;
551   OR << NV("Callee", Callee->getName());
552   for (size_t I = 0; I < FeatureMap.size(); ++I)
553     OR << NV(FeatureMap[I].name(),
554              *getAdvisor()->getModelRunner().getTensor<int64_t>(I));
555   OR << NV("ShouldInline", isInliningRecommended());
556 }
557 
// Commit the pending FunctionPropertiesUpdater changes for the caller into the
// advisor's FPI cache. Only valid when FPU is engaged, i.e. after a
// recommendation to inline (see the constructor).
void MLInlineAdvice::updateCachedCallerFPI(FunctionAnalysisManager &FAM) const {
  FPU->finish(FAM);
}
561 
562 void MLInlineAdvice::recordInliningImpl() {
563   ORE.emit([&]() {
564     OptimizationRemark R(DEBUG_TYPE, "InliningSuccess", DLoc, Block);
565     reportContextForRemark(R);
566     return R;
567   });
568   getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ false);
569 }
570 
571 void MLInlineAdvice::recordInliningWithCalleeDeletedImpl() {
572   ORE.emit([&]() {
573     OptimizationRemark R(DEBUG_TYPE, "InliningSuccessWithCalleeDeleted", DLoc,
574                          Block);
575     reportContextForRemark(R);
576     return R;
577   });
578   getAdvisor()->onSuccessfulInlining(*this, /*CalleeWasDeleted*/ true);
579 }
580 
581 void MLInlineAdvice::recordUnsuccessfulInliningImpl(
582     const InlineResult &Result) {
583   getAdvisor()->getCachedFPI(*Caller) = PreInlineCallerFPI;
584   ORE.emit([&]() {
585     OptimizationRemarkMissed R(DEBUG_TYPE, "InliningAttemptedAndUnsuccessful",
586                                DLoc, Block);
587     reportContextForRemark(R);
588     return R;
589   });
590 }
// Inlining was never attempted (e.g. the advice said no): emit a missed
// remark. FPU must not be armed, since no inlining happened.
// NOTE(review): the remark name "IniningNotAttempted" is misspelled
// ("Inlining"), but remark names are runtime-visible strings that tooling and
// tests may match on - confirm before renaming.
void MLInlineAdvice::recordUnattemptedInliningImpl() {
  assert(!FPU);
  ORE.emit([&]() {
    OptimizationRemarkMissed R(DEBUG_TYPE, "IniningNotAttempted", DLoc, Block);
    reportContextForRemark(R);
    return R;
  });
}
599