1 //===- CtxProfAnalysis.h - maintain contextual profile info -*- C++ ---*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 #ifndef LLVM_ANALYSIS_CTXPROFANALYSIS_H 10 #define LLVM_ANALYSIS_CTXPROFANALYSIS_H 11 12 #include "llvm/ADT/SetVector.h" 13 #include "llvm/IR/GlobalValue.h" 14 #include "llvm/IR/InstrTypes.h" 15 #include "llvm/IR/IntrinsicInst.h" 16 #include "llvm/IR/PassManager.h" 17 #include "llvm/ProfileData/PGOCtxProfReader.h" 18 #include "llvm/Support/Compiler.h" 19 #include <optional> 20 21 namespace llvm { 22 23 class CtxProfAnalysis; 24 25 using FlatIndirectTargets = DenseMap<GlobalValue::GUID, uint64_t>; 26 using CtxProfFlatIndirectCallProfile = 27 DenseMap<GlobalValue::GUID, DenseMap<uint32_t, FlatIndirectTargets>>; 28 29 /// The instrumented contextual profile, produced by the CtxProfAnalysis. 30 class PGOContextualProfile { 31 friend class CtxProfAnalysis; 32 friend class CtxProfAnalysisPrinterPass; 33 struct FunctionInfo { 34 uint32_t NextCounterIndex = 0; 35 uint32_t NextCallsiteIndex = 0; 36 const std::string Name; 37 PGOCtxProfContext Index; FunctionInfoFunctionInfo38 FunctionInfo(StringRef Name) : Name(Name) {} 39 }; 40 PGOCtxProfile Profiles; 41 42 // True if this module is a post-thinlto module containing just functions 43 // participating in one or more contextual profiles. 44 bool IsInSpecializedModule = false; 45 46 // For the GUIDs in this module, associate metadata about each function which 47 // we'll need when we maintain the profiles during IPO transformations. 48 std::map<GlobalValue::GUID, FunctionInfo> FuncInfo; 49 50 /// Get the GUID of this Function if it's defined in this module. 51 LLVM_ABI GlobalValue::GUID getDefinedFunctionGUID(const Function &F) const; 52 53 // This is meant to be constructed from CtxProfAnalysis, which will also set 54 // its state piecemeal. 55 PGOContextualProfile() = default; 56 57 void initIndex(); 58 59 public: 60 PGOContextualProfile(const PGOContextualProfile &) = delete; 61 PGOContextualProfile(PGOContextualProfile &&) = default; 62 contexts()63 const CtxProfContextualProfiles &contexts() const { 64 return Profiles.Contexts; 65 } 66 profiles()67 const PGOCtxProfile &profiles() const { return Profiles; } 68 69 LLVM_ABI bool isInSpecializedModule() const; 70 isFunctionKnown(const Function & F)71 bool isFunctionKnown(const Function &F) const { 72 return getDefinedFunctionGUID(F) != 0; 73 } 74 getFunctionName(GlobalValue::GUID GUID)75 StringRef getFunctionName(GlobalValue::GUID GUID) const { 76 auto It = FuncInfo.find(GUID); 77 if (It == FuncInfo.end()) 78 return ""; 79 return It->second.Name; 80 } 81 getNumCounters(const Function & F)82 uint32_t getNumCounters(const Function &F) const { 83 assert(isFunctionKnown(F)); 84 return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCounterIndex; 85 } 86 getNumCallsites(const Function & F)87 uint32_t getNumCallsites(const Function &F) const { 88 assert(isFunctionKnown(F)); 89 return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex; 90 } 91 allocateNextCounterIndex(const Function & F)92 uint32_t allocateNextCounterIndex(const Function &F) { 93 assert(isFunctionKnown(F)); 94 return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCounterIndex++; 95 } 96 allocateNextCallsiteIndex(const Function & F)97 uint32_t allocateNextCallsiteIndex(const Function &F) { 98 assert(isFunctionKnown(F)); 99 return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex++; 100 } 101 102 using ConstVisitor = function_ref<void(const PGOCtxProfContext &)>; 103 using Visitor = function_ref<void(PGOCtxProfContext &)>; 104 105 LLVM_ABI void update(Visitor, const Function &F); 106 LLVM_ABI void visit(ConstVisitor, const Function *F = nullptr) const; 107 108 LLVM_ABI const CtxProfFlatProfile flatten() const; 109 LLVM_ABI const CtxProfFlatIndirectCallProfile flattenVirtCalls() const; 110 invalidate(Module &,const PreservedAnalyses & PA,ModuleAnalysisManager::Invalidator &)111 bool invalidate(Module &, const PreservedAnalyses &PA, 112 ModuleAnalysisManager::Invalidator &) { 113 // Check whether the analysis has been explicitly invalidated. Otherwise, 114 // it's stateless and remains preserved. 115 auto PAC = PA.getChecker<CtxProfAnalysis>(); 116 return !PAC.preservedWhenStateless(); 117 } 118 }; 119 120 class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> { 121 const std::optional<StringRef> Profile; 122 123 public: 124 LLVM_ABI static AnalysisKey Key; 125 LLVM_ABI explicit CtxProfAnalysis( 126 std::optional<StringRef> Profile = std::nullopt); 127 128 using Result = PGOContextualProfile; 129 130 LLVM_ABI PGOContextualProfile run(Module &M, ModuleAnalysisManager &MAM); 131 132 /// Get the instruction instrumenting a callsite, or nullptr if that cannot be 133 /// found. 134 LLVM_ABI static InstrProfCallsite *getCallsiteInstrumentation(CallBase &CB); 135 136 /// Get the instruction instrumenting a BB, or nullptr if not present. 137 LLVM_ABI static InstrProfIncrementInst *getBBInstrumentation(BasicBlock &BB); 138 139 /// Get the step instrumentation associated with a `select` 140 LLVM_ABI static InstrProfIncrementInstStep * 141 getSelectInstrumentation(SelectInst &SI); 142 143 // FIXME: refactor to an advisor model, and separate 144 LLVM_ABI static void collectIndirectCallPromotionList( 145 CallBase &IC, Result &Profile, 146 SetVector<std::pair<CallBase *, Function *>> &Candidates); 147 }; 148 149 class CtxProfAnalysisPrinterPass 150 : public PassInfoMixin<CtxProfAnalysisPrinterPass> { 151 public: 152 enum class PrintMode { Everything, YAML }; 153 LLVM_ABI explicit CtxProfAnalysisPrinterPass(raw_ostream &OS); 154 155 LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); isRequired()156 static bool isRequired() { return true; } 157 158 private: 159 raw_ostream &OS; 160 const PrintMode Mode; 161 }; 162 163 /// Utility that propagates counter values to each basic block and to each edge 164 /// when a basic block has more than one outgoing edge, using an adaptation of 165 /// PGOUseFunc::populateCounters. 166 // FIXME(mtrofin): look into factoring the code to share one implementation. 167 class ProfileAnnotatorImpl; 168 class ProfileAnnotator { 169 std::unique_ptr<ProfileAnnotatorImpl> PImpl; 170 171 public: 172 LLVM_ABI ProfileAnnotator(const Function &F, ArrayRef<uint64_t> RawCounters); 173 LLVM_ABI uint64_t getBBCount(const BasicBlock &BB) const; 174 175 // Finds the true and false counts for the given select instruction. Returns 176 // false if the select doesn't have instrumentation or if the count of the 177 // parent BB is 0. 178 LLVM_ABI bool getSelectInstrProfile(SelectInst &SI, uint64_t &TrueCount, 179 uint64_t &FalseCount) const; 180 // Clears Profile and populates it with the edge weights, in the same order as 181 // they need to appear in the MD_prof metadata. Also computes the max of those 182 // weights an returns it in MaxCount. Returs false if: 183 // - the BB has less than 2 successors 184 // - the counts are 0 185 LLVM_ABI bool getOutgoingBranchWeights(BasicBlock &BB, 186 SmallVectorImpl<uint64_t> &Profile, 187 uint64_t &MaxCount) const; 188 LLVM_ABI ~ProfileAnnotator(); 189 }; 190 191 /// Assign a GUID to functions as metadata. GUID calculation takes linkage into 192 /// account, which may change especially through and after thinlto. By 193 /// pre-computing and assigning as metadata, this mechanism is resilient to such 194 /// changes (as well as name changes e.g. suffix ".llvm." additions). 195 196 // FIXME(mtrofin): we can generalize this mechanism to calculate a GUID early in 197 // the pass pipeline, associate it with any Global Value, and then use it for 198 // PGO and ThinLTO. 199 // At that point, this should be moved elsewhere. 200 class AssignGUIDPass : public PassInfoMixin<AssignGUIDPass> { 201 public: 202 explicit AssignGUIDPass() = default; 203 204 /// Assign a GUID *if* one is not already assign, as a function metadata named 205 /// `GUIDMetadataName`. 206 LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); 207 LLVM_ABI static const char *GUIDMetadataName; 208 // This should become GlobalValue::getGUID 209 LLVM_ABI static uint64_t getGUID(const Function &F); 210 }; 211 212 } // namespace llvm 213 #endif // LLVM_ANALYSIS_CTXPROFANALYSIS_H 214