1 //===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Flattens the contextual profile and lowers it to MD_prof.
10 // This should happen after all IPO (which is assumed to have maintained the
11 // contextual profile) happened. Flattening consists of summing the values at
12 // the same index of the counters belonging to all the contexts of a function.
13 // The lowering consists of materializing the counter values to function
14 // entrypoint counts and branch probabilities.
15 //
16 // This pass also removes contextual instrumentation, which has been kept around
17 // to facilitate its functionality.
18 //
19 //===----------------------------------------------------------------------===//
20
21 #include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/ScopeExit.h"
24 #include "llvm/Analysis/CFG.h"
25 #include "llvm/Analysis/CtxProfAnalysis.h"
26 #include "llvm/Analysis/ProfileSummaryInfo.h"
27 #include "llvm/IR/Analysis.h"
28 #include "llvm/IR/CFG.h"
29 #include "llvm/IR/Dominators.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/IntrinsicInst.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/PassManager.h"
34 #include "llvm/IR/ProfileSummary.h"
35 #include "llvm/ProfileData/ProfileCommon.h"
36 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
37 #include "llvm/Transforms/Scalar/DCE.h"
38 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
39
40 using namespace llvm;
41
42 #define DEBUG_TYPE "ctx_prof_flatten"
43
44 namespace {
45
46 /// Assign branch weights and function entry count. Also update the PSI
47 /// builder.
assignProfileData(Function & F,ArrayRef<uint64_t> RawCounters)48 void assignProfileData(Function &F, ArrayRef<uint64_t> RawCounters) {
49 assert(!RawCounters.empty());
50 ProfileAnnotator PA(F, RawCounters);
51
52 F.setEntryCount(RawCounters[0]);
53 SmallVector<uint64_t, 2> ProfileHolder;
54
55 for (auto &BB : F) {
56 for (auto &I : BB)
57 if (auto *SI = dyn_cast<SelectInst>(&I)) {
58 uint64_t TrueCount, FalseCount = 0;
59 if (!PA.getSelectInstrProfile(*SI, TrueCount, FalseCount))
60 continue;
61 setProfMetadata(F.getParent(), SI, {TrueCount, FalseCount},
62 std::max(TrueCount, FalseCount));
63 }
64 if (succ_size(&BB) < 2)
65 continue;
66 uint64_t MaxCount = 0;
67 if (!PA.getOutgoingBranchWeights(BB, ProfileHolder, MaxCount))
68 continue;
69 assert(MaxCount > 0);
70 setProfMetadata(F.getParent(), BB.getTerminator(), ProfileHolder, MaxCount);
71 }
72 }
73
areAllBBsReachable(const Function & F,FunctionAnalysisManager & FAM)74 [[maybe_unused]] bool areAllBBsReachable(const Function &F,
75 FunctionAnalysisManager &FAM) {
76 auto &DT = FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(F));
77 return llvm::all_of(
78 F, [&](const BasicBlock &BB) { return DT.isReachableFromEntry(&BB); });
79 }
80
clearColdFunctionProfile(Function & F)81 void clearColdFunctionProfile(Function &F) {
82 for (auto &BB : F)
83 BB.getTerminator()->setMetadata(LLVMContext::MD_prof, nullptr);
84 F.setEntryCount(0U);
85 }
86
removeInstrumentation(Function & F)87 void removeInstrumentation(Function &F) {
88 for (auto &BB : F)
89 for (auto &I : llvm::make_early_inc_range(BB))
90 if (isa<InstrProfCntrInstBase>(I))
91 I.eraseFromParent();
92 }
93
annotateIndirectCall(Module & M,CallBase & CB,const DenseMap<uint32_t,FlatIndirectTargets> & FlatProf,const InstrProfCallsite & Ins)94 void annotateIndirectCall(
95 Module &M, CallBase &CB,
96 const DenseMap<uint32_t, FlatIndirectTargets> &FlatProf,
97 const InstrProfCallsite &Ins) {
98 auto Idx = Ins.getIndex()->getZExtValue();
99 auto FIt = FlatProf.find(Idx);
100 if (FIt == FlatProf.end())
101 return;
102 const auto &Targets = FIt->second;
103 SmallVector<InstrProfValueData, 2> Data;
104 uint64_t Sum = 0;
105 for (auto &[Guid, Count] : Targets) {
106 Data.push_back({/*.Value=*/Guid, /*.Count=*/Count});
107 Sum += Count;
108 }
109
110 llvm::sort(Data,
111 [](const InstrProfValueData &A, const InstrProfValueData &B) {
112 return A.Count > B.Count;
113 });
114 llvm::annotateValueSite(M, CB, Data, Sum,
115 InstrProfValueKind::IPVK_IndirectCallTarget,
116 Data.size());
117 LLVM_DEBUG(dbgs() << "[ctxprof] flat indirect call prof: " << CB
118 << CB.getMetadata(LLVMContext::MD_prof) << "\n");
119 }
120
121 // We normally return a "Changed" bool, but the calling pass' run assumes
122 // something will change - some profile will be added - so this won't add much
123 // by returning false when applicable.
annotateIndirectCalls(Module & M,const CtxProfAnalysis::Result & CtxProf)124 void annotateIndirectCalls(Module &M, const CtxProfAnalysis::Result &CtxProf) {
125 const auto FlatIndCalls = CtxProf.flattenVirtCalls();
126 for (auto &F : M) {
127 if (F.isDeclaration())
128 continue;
129 auto FlatProfIter = FlatIndCalls.find(AssignGUIDPass::getGUID(F));
130 if (FlatProfIter == FlatIndCalls.end())
131 continue;
132 const auto &FlatProf = FlatProfIter->second;
133 for (auto &BB : F) {
134 for (auto &I : BB) {
135 auto *CB = dyn_cast<CallBase>(&I);
136 if (!CB || !CB->isIndirectCall())
137 continue;
138 if (auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(*CB))
139 annotateIndirectCall(M, *CB, FlatProf, *Ins);
140 }
141 }
142 }
143 }
144
145 } // namespace
146
run(Module & M,ModuleAnalysisManager & MAM)147 PreservedAnalyses PGOCtxProfFlatteningPass::run(Module &M,
148 ModuleAnalysisManager &MAM) {
149 // Ensure in all cases the instrumentation is removed: if this module had no
150 // roots, the contextual profile would evaluate to false, but there would
151 // still be instrumentation.
152 // Note: in such cases we leave as-is any other profile info (if present -
153 // e.g. synthetic weights, etc) because it wouldn't interfere with the
154 // contextual - based one (which would be in other modules)
155 auto OnExit = llvm::make_scope_exit([&]() {
156 if (IsPreThinlink)
157 return;
158 for (auto &F : M)
159 removeInstrumentation(F);
160 });
161 auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M);
162 // post-thinlink, we only reprocess for the module(s) containing the
163 // contextual tree. For everything else, OnExit will just clean the
164 // instrumentation.
165 if (!IsPreThinlink && !CtxProf.isInSpecializedModule())
166 return PreservedAnalyses::none();
167
168 if (IsPreThinlink)
169 annotateIndirectCalls(M, CtxProf);
170 const auto FlattenedProfile = CtxProf.flatten();
171
172 for (auto &F : M) {
173 if (F.isDeclaration())
174 continue;
175
176 assert(areAllBBsReachable(
177 F, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M)
178 .getManager()) &&
179 "Function has unreacheable basic blocks. The expectation was that "
180 "DCE was run before.");
181
182 auto It = FlattenedProfile.find(AssignGUIDPass::getGUID(F));
183 // If this function didn't appear in the contextual profile, it's cold.
184 if (It == FlattenedProfile.end())
185 clearColdFunctionProfile(F);
186 else
187 assignProfileData(F, It->second);
188 }
189 InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs);
190 // use here the flat profiles just so the importer doesn't complain about
191 // how different the PSIs are between the module with the roots and the
192 // various modules it imports.
193 for (auto &C : FlattenedProfile) {
194 PB.addEntryCount(C.second[0]);
195 for (auto V : llvm::drop_begin(C.second))
196 PB.addInternalCount(V);
197 }
198
199 M.setProfileSummary(PB.getSummary()->getMD(M.getContext()),
200 ProfileSummary::Kind::PSK_Instr);
201 PreservedAnalyses PA;
202 PA.abandon<ProfileSummaryAnalysis>();
203 MAM.invalidate(M, PA);
204 auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M);
205 PSI.refresh(PB.getSummary());
206 return PreservedAnalyses::none();
207 }
208