xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Flattens the contextual profile and lowers it to MD_prof.
10 // This should happen after all IPO (which is assumed to have maintained the
11 // contextual profile) happened. Flattening consists of summing the values at
12 // the same index of the counters belonging to all the contexts of a function.
13 // The lowering consists of materializing the counter values to function
14 // entrypoint counts and branch probabilities.
15 //
16 // This pass also removes contextual instrumentation, which has been kept around
17 // to facilitate its functionality.
18 //
19 //===----------------------------------------------------------------------===//
20 
#include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/IR/Analysis.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ProfileSummary.h"
#include "llvm/ProfileData/ProfileCommon.h"
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
#include "llvm/Transforms/Scalar/DCE.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <utility>
39 
40 using namespace llvm;
41 
42 #define DEBUG_TYPE "ctx_prof_flatten"
43 
44 namespace {
45 
46 /// Assign branch weights and function entry count. Also update the PSI
47 /// builder.
assignProfileData(Function & F,ArrayRef<uint64_t> RawCounters)48 void assignProfileData(Function &F, ArrayRef<uint64_t> RawCounters) {
49   assert(!RawCounters.empty());
50   ProfileAnnotator PA(F, RawCounters);
51 
52   F.setEntryCount(RawCounters[0]);
53   SmallVector<uint64_t, 2> ProfileHolder;
54 
55   for (auto &BB : F) {
56     for (auto &I : BB)
57       if (auto *SI = dyn_cast<SelectInst>(&I)) {
58         uint64_t TrueCount, FalseCount = 0;
59         if (!PA.getSelectInstrProfile(*SI, TrueCount, FalseCount))
60           continue;
61         setProfMetadata(F.getParent(), SI, {TrueCount, FalseCount},
62                         std::max(TrueCount, FalseCount));
63       }
64     if (succ_size(&BB) < 2)
65       continue;
66     uint64_t MaxCount = 0;
67     if (!PA.getOutgoingBranchWeights(BB, ProfileHolder, MaxCount))
68       continue;
69     assert(MaxCount > 0);
70     setProfMetadata(F.getParent(), BB.getTerminator(), ProfileHolder, MaxCount);
71   }
72 }
73 
areAllBBsReachable(const Function & F,FunctionAnalysisManager & FAM)74 [[maybe_unused]] bool areAllBBsReachable(const Function &F,
75                                          FunctionAnalysisManager &FAM) {
76   auto &DT = FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(F));
77   return llvm::all_of(
78       F, [&](const BasicBlock &BB) { return DT.isReachableFromEntry(&BB); });
79 }
80 
clearColdFunctionProfile(Function & F)81 void clearColdFunctionProfile(Function &F) {
82   for (auto &BB : F)
83     BB.getTerminator()->setMetadata(LLVMContext::MD_prof, nullptr);
84   F.setEntryCount(0U);
85 }
86 
removeInstrumentation(Function & F)87 void removeInstrumentation(Function &F) {
88   for (auto &BB : F)
89     for (auto &I : llvm::make_early_inc_range(BB))
90       if (isa<InstrProfCntrInstBase>(I))
91         I.eraseFromParent();
92 }
93 
annotateIndirectCall(Module & M,CallBase & CB,const DenseMap<uint32_t,FlatIndirectTargets> & FlatProf,const InstrProfCallsite & Ins)94 void annotateIndirectCall(
95     Module &M, CallBase &CB,
96     const DenseMap<uint32_t, FlatIndirectTargets> &FlatProf,
97     const InstrProfCallsite &Ins) {
98   auto Idx = Ins.getIndex()->getZExtValue();
99   auto FIt = FlatProf.find(Idx);
100   if (FIt == FlatProf.end())
101     return;
102   const auto &Targets = FIt->second;
103   SmallVector<InstrProfValueData, 2> Data;
104   uint64_t Sum = 0;
105   for (auto &[Guid, Count] : Targets) {
106     Data.push_back({/*.Value=*/Guid, /*.Count=*/Count});
107     Sum += Count;
108   }
109 
110   llvm::sort(Data,
111              [](const InstrProfValueData &A, const InstrProfValueData &B) {
112                return A.Count > B.Count;
113              });
114   llvm::annotateValueSite(M, CB, Data, Sum,
115                           InstrProfValueKind::IPVK_IndirectCallTarget,
116                           Data.size());
117   LLVM_DEBUG(dbgs() << "[ctxprof] flat indirect call prof: " << CB
118                     << CB.getMetadata(LLVMContext::MD_prof) << "\n");
119 }
120 
121 // We normally return a "Changed" bool, but the calling pass' run assumes
122 // something will change - some profile will be added - so this won't add much
123 // by returning false when applicable.
annotateIndirectCalls(Module & M,const CtxProfAnalysis::Result & CtxProf)124 void annotateIndirectCalls(Module &M, const CtxProfAnalysis::Result &CtxProf) {
125   const auto FlatIndCalls = CtxProf.flattenVirtCalls();
126   for (auto &F : M) {
127     if (F.isDeclaration())
128       continue;
129     auto FlatProfIter = FlatIndCalls.find(AssignGUIDPass::getGUID(F));
130     if (FlatProfIter == FlatIndCalls.end())
131       continue;
132     const auto &FlatProf = FlatProfIter->second;
133     for (auto &BB : F) {
134       for (auto &I : BB) {
135         auto *CB = dyn_cast<CallBase>(&I);
136         if (!CB || !CB->isIndirectCall())
137           continue;
138         if (auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(*CB))
139           annotateIndirectCall(M, *CB, FlatProf, *Ins);
140       }
141     }
142   }
143 }
144 
145 } // namespace
146 
run(Module & M,ModuleAnalysisManager & MAM)147 PreservedAnalyses PGOCtxProfFlatteningPass::run(Module &M,
148                                                 ModuleAnalysisManager &MAM) {
149   // Ensure in all cases the instrumentation is removed: if this module had no
150   // roots, the contextual profile would evaluate to false, but there would
151   // still be instrumentation.
152   // Note: in such cases we leave as-is any other profile info (if present -
153   // e.g. synthetic weights, etc) because it wouldn't interfere with the
154   // contextual - based one (which would be in other modules)
155   auto OnExit = llvm::make_scope_exit([&]() {
156     if (IsPreThinlink)
157       return;
158     for (auto &F : M)
159       removeInstrumentation(F);
160   });
161   auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M);
162   // post-thinlink, we only reprocess for the module(s) containing the
163   // contextual tree. For everything else, OnExit will just clean the
164   // instrumentation.
165   if (!IsPreThinlink && !CtxProf.isInSpecializedModule())
166     return PreservedAnalyses::none();
167 
168   if (IsPreThinlink)
169     annotateIndirectCalls(M, CtxProf);
170   const auto FlattenedProfile = CtxProf.flatten();
171 
172   for (auto &F : M) {
173     if (F.isDeclaration())
174       continue;
175 
176     assert(areAllBBsReachable(
177                F, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M)
178                       .getManager()) &&
179            "Function has unreacheable basic blocks. The expectation was that "
180            "DCE was run before.");
181 
182     auto It = FlattenedProfile.find(AssignGUIDPass::getGUID(F));
183     // If this function didn't appear in the contextual profile, it's cold.
184     if (It == FlattenedProfile.end())
185       clearColdFunctionProfile(F);
186     else
187       assignProfileData(F, It->second);
188   }
189   InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs);
190   // use here the flat profiles just so the importer doesn't complain about
191   // how different the PSIs are between the module with the roots and the
192   // various modules it imports.
193   for (auto &C : FlattenedProfile) {
194     PB.addEntryCount(C.second[0]);
195     for (auto V : llvm::drop_begin(C.second))
196       PB.addInternalCount(V);
197   }
198 
199   M.setProfileSummary(PB.getSummary()->getMD(M.getContext()),
200                       ProfileSummary::Kind::PSK_Instr);
201   PreservedAnalyses PA;
202   PA.abandon<ProfileSummaryAnalysis>();
203   MAM.invalidate(M, PA);
204   auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M);
205   PSI.refresh(PB.getSummary());
206   return PreservedAnalyses::none();
207 }
208