xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Analysis/MLInlineAdvisor.h (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
10 #define LLVM_ANALYSIS_MLINLINEADVISOR_H
11 
12 #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
13 #include "llvm/Analysis/InlineAdvisor.h"
14 #include "llvm/Analysis/LazyCallGraph.h"
15 #include "llvm/Analysis/MLModelRunner.h"
16 #include "llvm/Analysis/ProfileSummaryInfo.h"
17 #include "llvm/IR/PassManager.h"
18 
19 #include <deque>
20 #include <map>
21 #include <memory>
22 #include <optional>
23 
24 namespace llvm {
25 class DiagnosticInfoOptimizationBase;
26 class Module;
27 class MLInlineAdvice;
28 
29 class MLInlineAdvisor : public InlineAdvisor {
30 public:
31   MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
32                   std::unique_ptr<MLModelRunner> ModelRunner,
33                   std::function<bool(CallBase &)> GetDefaultAdvice);
34 
35   virtual ~MLInlineAdvisor() = default;
36 
37   void onPassEntry(LazyCallGraph::SCC *SCC) override;
38   void onPassExit(LazyCallGraph::SCC *SCC) override;
39 
getIRSize(Function & F)40   int64_t getIRSize(Function &F) const {
41     return getCachedFPI(F).TotalInstructionCount;
42   }
43   void onSuccessfulInlining(const MLInlineAdvice &Advice,
44                             bool CalleeWasDeleted);
45 
isForcedToStop()46   bool isForcedToStop() const { return ForceStop; }
47   int64_t getLocalCalls(Function &F);
getModelRunner()48   const MLModelRunner &getModelRunner() const { return *ModelRunner; }
49   FunctionPropertiesInfo &getCachedFPI(Function &) const;
50 
51 protected:
52   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
53 
54   std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB,
55                                                    bool Advice) override;
56 
57   virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB);
58 
59   virtual std::unique_ptr<MLInlineAdvice>
60   getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);
61 
62   // Get the initial 'level' of the function, or 0 if the function has been
63   // introduced afterwards.
64   // TODO: should we keep this updated?
65   unsigned getInitialFunctionLevel(const Function &F) const;
66 
67   std::unique_ptr<MLModelRunner> ModelRunner;
68   std::function<bool(CallBase &)> GetDefaultAdvice;
69 
70 private:
71   int64_t getModuleIRSize() const;
72   std::unique_ptr<InlineAdvice>
73   getSkipAdviceIfUnreachableCallsite(CallBase &CB);
74   void print(raw_ostream &OS) const override;
75 
76   // Using std::map to benefit from its iterator / reference non-invalidating
77   // semantics, which make it easy to use `getCachedFPI` results from multiple
78   // calls without needing to copy to avoid invalidation effects.
79   mutable std::map<const Function *, FunctionPropertiesInfo> FPICache;
80 
81   LazyCallGraph &CG;
82 
83   int64_t NodeCount = 0;
84   int64_t EdgeCount = 0;
85   int64_t EdgesOfLastSeenNodes = 0;
86 
87   std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
88   const int32_t InitialIRSize = 0;
89   int32_t CurrentIRSize = 0;
90   llvm::SmallPtrSet<const LazyCallGraph::Node *, 1> NodesInLastSCC;
91   DenseSet<const LazyCallGraph::Node *> AllNodes;
92   DenseSet<Function *> DeadFunctions;
93   bool ForceStop = false;
94   ProfileSummaryInfo &PSI;
95 };
96 
97 /// InlineAdvice that tracks changes post inlining. For that reason, it only
98 /// overrides the "successful inlining" extension points.
99 class MLInlineAdvice : public InlineAdvice {
100 public:
101   MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
102                  OptimizationRemarkEmitter &ORE, bool Recommendation);
103   virtual ~MLInlineAdvice() = default;
104 
105   void recordInliningImpl() override;
106   void recordInliningWithCalleeDeletedImpl() override;
107   void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
108   void recordUnattemptedInliningImpl() override;
109 
getCaller()110   Function *getCaller() const { return Caller; }
getCallee()111   Function *getCallee() const { return Callee; }
112 
113   const int64_t CallerIRSize;
114   const int64_t CalleeIRSize;
115   const int64_t CallerAndCalleeEdges;
116   void updateCachedCallerFPI(FunctionAnalysisManager &FAM) const;
117 
118 private:
119   void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);
getAdvisor()120   MLInlineAdvisor *getAdvisor() const {
121     return static_cast<MLInlineAdvisor *>(Advisor);
122   };
123   // Make a copy of the FPI of the caller right before inlining. If inlining
124   // fails, we can just update the cache with that value.
125   const FunctionPropertiesInfo PreInlineCallerFPI;
126   std::optional<FunctionPropertiesUpdater> FPU;
127 };
128 
129 } // namespace llvm
130 
131 #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H
132