xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp (revision b2d2a78ad80ec68d4a17f5aef97d21686cb1e29b)
1 //===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 
10 #include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
11 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
12 #include "llvm/IR/Analysis.h"
13 #include "llvm/IR/DiagnosticInfo.h"
14 #include "llvm/IR/IRBuilder.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/IR/IntrinsicInst.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/PassManager.h"
19 #include "llvm/Support/CommandLine.h"
20 #include <utility>
21 
22 using namespace llvm;
23 
24 #define DEBUG_TYPE "ctx-instr-lower"
25 
26 static cl::list<std::string> ContextRoots(
27     "profile-context-root", cl::Hidden,
28     cl::desc(
29         "A function name, assumed to be global, which will be treated as the "
30         "root of an interesting graph, which will be profiled independently "
31         "from other similar graphs."));
32 
33 bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {
34   return !ContextRoots.empty();
35 }
36 
// The names of the symbols we expect in compiler-rt. A namespace is used for
// readability at the use sites. The names are compile-time constants: both the
// characters and the pointers themselves are immutable.
namespace CompilerRtAPINames {
static constexpr const char *StartCtx = "__llvm_ctx_profile_start_context";
static constexpr const char *ReleaseCtx = "__llvm_ctx_profile_release_context";
static constexpr const char *GetCtx = "__llvm_ctx_profile_get_context";
static constexpr const char *ExpectedCalleeTLS =
    "__llvm_ctx_profile_expected_callee";
static constexpr const char *CallsiteTLS = "__llvm_ctx_profile_callsite";
} // namespace CompilerRtAPINames
46 
47 namespace {
48 // The lowering logic and state.
49 class CtxInstrumentationLowerer final {
50   Module &M;
51   ModuleAnalysisManager &MAM;
52   Type *ContextNodeTy = nullptr;
53   Type *ContextRootTy = nullptr;
54 
55   DenseMap<const Function *, Constant *> ContextRootMap;
56   Function *StartCtx = nullptr;
57   Function *GetCtx = nullptr;
58   Function *ReleaseCtx = nullptr;
59   GlobalVariable *ExpectedCalleeTLS = nullptr;
60   GlobalVariable *CallsiteInfoTLS = nullptr;
61 
62 public:
63   CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
64   // return true if lowering happened (i.e. a change was made)
65   bool lowerFunction(Function &F);
66 };
67 
68 // llvm.instrprof.increment[.step] captures the total number of counters as one
69 // of its parameters, and llvm.instrprof.callsite captures the total number of
70 // callsites. Those values are the same for instances of those intrinsics in
71 // this function. Find the first instance of each and return them.
72 std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {
73   uint32_t NrCounters = 0;
74   uint32_t NrCallsites = 0;
75   for (const auto &BB : F) {
76     for (const auto &I : BB) {
77       if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
78         uint32_t V =
79             static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
80         assert((!NrCounters || V == NrCounters) &&
81                "expected all llvm.instrprof.increment[.step] intrinsics to "
82                "have the same total nr of counters parameter");
83         NrCounters = V;
84       } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
85         uint32_t V =
86             static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
87         assert((!NrCallsites || V == NrCallsites) &&
88                "expected all llvm.instrprof.callsite intrinsics to have the "
89                "same total nr of callsites parameter");
90         NrCallsites = V;
91       }
92 #if NDEBUG
93       if (NrCounters && NrCallsites)
94         return std::make_pair(NrCounters, NrCallsites);
95 #endif
96     }
97   }
98   return {NrCounters, NrCallsites};
99 }
100 } // namespace
101 
102 // set up tie-in with compiler-rt.
103 // NOTE!!!
104 // These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
105 CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
106                                                      ModuleAnalysisManager &MAM)
107     : M(M), MAM(MAM) {
108   auto *PointerTy = PointerType::get(M.getContext(), 0);
109   auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
110   auto *I32Ty = Type::getInt32Ty(M.getContext());
111   auto *I64Ty = Type::getInt64Ty(M.getContext());
112 
113   // The ContextRoot type
114   ContextRootTy =
115       StructType::get(M.getContext(), {
116                                           PointerTy,          /*FirstNode*/
117                                           PointerTy,          /*FirstMemBlock*/
118                                           PointerTy,          /*CurrentMem*/
119                                           SanitizerMutexType, /*Taken*/
120                                       });
121   // The Context header.
122   ContextNodeTy = StructType::get(M.getContext(), {
123                                                       I64Ty,     /*Guid*/
124                                                       PointerTy, /*Next*/
125                                                       I32Ty,     /*NrCounters*/
126                                                       I32Ty,     /*NrCallsites*/
127                                                   });
128 
129   // Define a global for each entrypoint. We'll reuse the entrypoint's name as
130   // prefix. We assume the entrypoint names to be unique.
131   for (const auto &Fname : ContextRoots) {
132     if (const auto *F = M.getFunction(Fname)) {
133       if (F->isDeclaration())
134         continue;
135       auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
136       cast<GlobalVariable>(G)->setInitializer(
137           Constant::getNullValue(ContextRootTy));
138       ContextRootMap.insert(std::make_pair(F, G));
139       for (const auto &BB : *F)
140         for (const auto &I : BB)
141           if (const auto *CB = dyn_cast<CallBase>(&I))
142             if (CB->isMustTailCall()) {
143               M.getContext().emitError(
144                   "The function " + Fname +
145                   " was indicated as a context root, but it features musttail "
146                   "calls, which is not supported.");
147             }
148     }
149   }
150 
151   // Declare the functions we will call.
152   StartCtx = cast<Function>(
153       M.getOrInsertFunction(
154            CompilerRtAPINames::StartCtx,
155            FunctionType::get(ContextNodeTy->getPointerTo(),
156                              {ContextRootTy->getPointerTo(), /*ContextRoot*/
157                               I64Ty, /*Guid*/ I32Ty,
158                               /*NrCounters*/ I32Ty /*NrCallsites*/},
159                              false))
160           .getCallee());
161   GetCtx = cast<Function>(
162       M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
163                             FunctionType::get(ContextNodeTy->getPointerTo(),
164                                               {PointerTy, /*Callee*/
165                                                I64Ty,     /*Guid*/
166                                                I32Ty,     /*NrCounters*/
167                                                I32Ty},    /*NrCallsites*/
168                                               false))
169           .getCallee());
170   ReleaseCtx = cast<Function>(
171       M.getOrInsertFunction(
172            CompilerRtAPINames::ReleaseCtx,
173            FunctionType::get(Type::getVoidTy(M.getContext()),
174                              {
175                                  ContextRootTy->getPointerTo(), /*ContextRoot*/
176                              },
177                              false))
178           .getCallee());
179 
180   // Declare the TLSes we will need to use.
181   CallsiteInfoTLS =
182       new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
183                          nullptr, CompilerRtAPINames::CallsiteTLS);
184   CallsiteInfoTLS->setThreadLocal(true);
185   CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
186   ExpectedCalleeTLS =
187       new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
188                          nullptr, CompilerRtAPINames::ExpectedCalleeTLS);
189   ExpectedCalleeTLS->setThreadLocal(true);
190   ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
191 }
192 
193 PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
194                                               ModuleAnalysisManager &MAM) {
195   CtxInstrumentationLowerer Lowerer(M, MAM);
196   bool Changed = false;
197   for (auto &F : M)
198     Changed |= Lowerer.lowerFunction(F);
199   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
200 }
201 
202 bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
203   if (F.isDeclaration())
204     return false;
205   auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
206   auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
207 
208   Value *Guid = nullptr;
209   auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F);
210 
211   Value *Context = nullptr;
212   Value *RealContext = nullptr;
213 
214   StructType *ThisContextType = nullptr;
215   Value *TheRootContext = nullptr;
216   Value *ExpectedCalleeTLSAddr = nullptr;
217   Value *CallsiteInfoTLSAddr = nullptr;
218 
219   auto &Head = F.getEntryBlock();
220   for (auto &I : Head) {
221     // Find the increment intrinsic in the entry basic block.
222     if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
223       assert(Mark->getIndex()->isZero());
224 
225       IRBuilder<> Builder(Mark);
226       // FIXME(mtrofin): use InstrProfSymtab::getCanonicalName
227       Guid = Builder.getInt64(F.getGUID());
228       // The type of the context of this function is now knowable since we have
229       // NrCallsites and NrCounters. We delcare it here because it's more
230       // convenient - we have the Builder.
231       ThisContextType = StructType::get(
232           F.getContext(),
233           {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
234            ArrayType::get(Builder.getPtrTy(), NrCallsites)});
235       // Figure out which way we obtain the context object for this function -
236       // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
237       // former case, we also set TheRootContext since we need to release it
238       // at the end (plus it can be used to know if we have an entrypoint or a
239       // regular function)
240       auto Iter = ContextRootMap.find(&F);
241       if (Iter != ContextRootMap.end()) {
242         TheRootContext = Iter->second;
243         Context = Builder.CreateCall(StartCtx, {TheRootContext, Guid,
244                                                 Builder.getInt32(NrCounters),
245                                                 Builder.getInt32(NrCallsites)});
246         ORE.emit(
247             [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
248       } else {
249         Context =
250             Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters),
251                                         Builder.getInt32(NrCallsites)});
252         ORE.emit([&] {
253           return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
254         });
255       }
256       // The context could be scratch.
257       auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
258       if (NrCallsites > 0) {
259         // Figure out which index of the TLS 2-element buffers to use.
260         // Scratch context => we use index == 1. Real contexts => index == 0.
261         auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
262         // The GEPs corresponding to that index, in the respective TLS.
263         ExpectedCalleeTLSAddr = Builder.CreateGEP(
264             Builder.getInt8Ty()->getPointerTo(),
265             Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
266         CallsiteInfoTLSAddr = Builder.CreateGEP(
267             Builder.getInt32Ty(),
268             Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
269       }
270       // Because the context pointer may have LSB set (to indicate scratch),
271       // clear it for the value we use as base address for the counter vector.
272       // This way, if later we want to have "real" (not clobbered) buffers
273       // acting as scratch, the lowering (at least this part of it that deals
274       // with counters) stays the same.
275       RealContext = Builder.CreateIntToPtr(
276           Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
277           ThisContextType->getPointerTo());
278       I.eraseFromParent();
279       break;
280     }
281   }
282   if (!Context) {
283     ORE.emit([&] {
284       return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
285              << "Function doesn't have instrumentation, skipping";
286     });
287     return false;
288   }
289 
290   bool ContextWasReleased = false;
291   for (auto &BB : F) {
292     for (auto &I : llvm::make_early_inc_range(BB)) {
293       if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
294         IRBuilder<> Builder(Instr);
295         switch (Instr->getIntrinsicID()) {
296         case llvm::Intrinsic::instrprof_increment:
297         case llvm::Intrinsic::instrprof_increment_step: {
298           // Increments (or increment-steps) are just a typical load - increment
299           // - store in the RealContext.
300           auto *AsStep = cast<InstrProfIncrementInst>(Instr);
301           auto *GEP = Builder.CreateGEP(
302               ThisContextType, RealContext,
303               {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
304           Builder.CreateStore(
305               Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
306                                 AsStep->getStep()),
307               GEP);
308         } break;
309         case llvm::Intrinsic::instrprof_callsite:
310           // callsite lowering: write the called value in the expected callee
311           // TLS we treat the TLS as volatile because of signal handlers and to
312           // avoid these being moved away from the callsite they decorate.
313           auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
314           Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
315                               true);
316           // write the GEP of the slot in the sub-contexts portion of the
317           // context in TLS. Now, here, we use the actual Context value - as
318           // returned from compiler-rt - which may have the LSB set if the
319           // Context was scratch. Since the header of the context object and
320           // then the values are all 8-aligned (or, really, insofar as we care,
321           // they are even) - if the context is scratch (meaning, an odd value),
322           // so will the GEP. This is important because this is then visible to
323           // compiler-rt which will produce scratch contexts for callers that
324           // have a scratch context.
325           Builder.CreateStore(
326               Builder.CreateGEP(ThisContextType, Context,
327                                 {Builder.getInt32(0), Builder.getInt32(2),
328                                  CSIntrinsic->getIndex()}),
329               CallsiteInfoTLSAddr, true);
330           break;
331         }
332         I.eraseFromParent();
333       } else if (TheRootContext && isa<ReturnInst>(I)) {
334         // Remember to release the context if we are an entrypoint.
335         IRBuilder<> Builder(&I);
336         Builder.CreateCall(ReleaseCtx, {TheRootContext});
337         ContextWasReleased = true;
338       }
339     }
340   }
341   // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
342   // to disallow this, (so this then stays as an error), another is to detect
343   // that and then do a wrapper or disallow the tail call. This only affects
344   // instrumentation, when we want to detect the call graph.
345   if (TheRootContext && !ContextWasReleased)
346     F.getContext().emitError(
347         "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
348         "instructions above which to release the context: " +
349         F.getName());
350   return true;
351 }
352