1 //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H 10 #define LLVM_ANALYSIS_MLINLINEADVISOR_H 11 12 #include "llvm/Analysis/FunctionPropertiesAnalysis.h" 13 #include "llvm/Analysis/InlineAdvisor.h" 14 #include "llvm/Analysis/LazyCallGraph.h" 15 #include "llvm/Analysis/MLModelRunner.h" 16 #include "llvm/Analysis/ProfileSummaryInfo.h" 17 #include "llvm/IR/PassManager.h" 18 19 #include <deque> 20 #include <map> 21 #include <memory> 22 #include <optional> 23 24 namespace llvm { 25 class DiagnosticInfoOptimizationBase; 26 class Module; 27 class MLInlineAdvice; 28 29 class MLInlineAdvisor : public InlineAdvisor { 30 public: 31 MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM, 32 std::unique_ptr<MLModelRunner> ModelRunner, 33 std::function<bool(CallBase &)> GetDefaultAdvice); 34 35 virtual ~MLInlineAdvisor() = default; 36 37 void onPassEntry(LazyCallGraph::SCC *SCC) override; 38 void onPassExit(LazyCallGraph::SCC *SCC) override; 39 getIRSize(Function & F)40 int64_t getIRSize(Function &F) const { 41 return getCachedFPI(F).TotalInstructionCount; 42 } 43 void onSuccessfulInlining(const MLInlineAdvice &Advice, 44 bool CalleeWasDeleted); 45 isForcedToStop()46 bool isForcedToStop() const { return ForceStop; } 47 int64_t getLocalCalls(Function &F); getModelRunner()48 const MLModelRunner &getModelRunner() const { return *ModelRunner; } 49 FunctionPropertiesInfo &getCachedFPI(Function &) const; 50 51 protected: 52 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override; 53 54 std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB, 55 bool Advice) override; 56 57 virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB); 58 59 virtual std::unique_ptr<MLInlineAdvice> 60 getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE); 61 62 // Get the initial 'level' of the function, or 0 if the function has been 63 // introduced afterwards. 64 // TODO: should we keep this updated? 65 unsigned getInitialFunctionLevel(const Function &F) const; 66 67 std::unique_ptr<MLModelRunner> ModelRunner; 68 std::function<bool(CallBase &)> GetDefaultAdvice; 69 70 private: 71 int64_t getModuleIRSize() const; 72 std::unique_ptr<InlineAdvice> 73 getSkipAdviceIfUnreachableCallsite(CallBase &CB); 74 void print(raw_ostream &OS) const override; 75 76 // Using std::map to benefit from its iterator / reference non-invalidating 77 // semantics, which make it easy to use `getCachedFPI` results from multiple 78 // calls without needing to copy to avoid invalidation effects. 79 mutable std::map<const Function *, FunctionPropertiesInfo> FPICache; 80 81 LazyCallGraph &CG; 82 83 int64_t NodeCount = 0; 84 int64_t EdgeCount = 0; 85 int64_t EdgesOfLastSeenNodes = 0; 86 87 std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels; 88 const int32_t InitialIRSize = 0; 89 int32_t CurrentIRSize = 0; 90 llvm::SmallPtrSet<const LazyCallGraph::Node *, 1> NodesInLastSCC; 91 DenseSet<const LazyCallGraph::Node *> AllNodes; 92 DenseSet<Function *> DeadFunctions; 93 bool ForceStop = false; 94 ProfileSummaryInfo &PSI; 95 }; 96 97 /// InlineAdvice that tracks changes post inlining. For that reason, it only 98 /// overrides the "successful inlining" extension points. 99 class MLInlineAdvice : public InlineAdvice { 100 public: 101 MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB, 102 OptimizationRemarkEmitter &ORE, bool Recommendation); 103 virtual ~MLInlineAdvice() = default; 104 105 void recordInliningImpl() override; 106 void recordInliningWithCalleeDeletedImpl() override; 107 void recordUnsuccessfulInliningImpl(const InlineResult &Result) override; 108 void recordUnattemptedInliningImpl() override; 109 getCaller()110 Function *getCaller() const { return Caller; } getCallee()111 Function *getCallee() const { return Callee; } 112 113 const int64_t CallerIRSize; 114 const int64_t CalleeIRSize; 115 const int64_t CallerAndCalleeEdges; 116 void updateCachedCallerFPI(FunctionAnalysisManager &FAM) const; 117 118 private: 119 void reportContextForRemark(DiagnosticInfoOptimizationBase &OR); getAdvisor()120 MLInlineAdvisor *getAdvisor() const { 121 return static_cast<MLInlineAdvisor *>(Advisor); 122 }; 123 // Make a copy of the FPI of the caller right before inlining. If inlining 124 // fails, we can just update the cache with that value. 125 const FunctionPropertiesInfo PreInlineCallerFPI; 126 std::optional<FunctionPropertiesUpdater> FPU; 127 }; 128 129 } // namespace llvm 130 131 #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H 132