xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 
10 #include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
11 #include "llvm/ADT/STLExtras.h"
12 #include "llvm/Analysis/CFG.h"
13 #include "llvm/Analysis/CtxProfAnalysis.h"
14 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
15 #include "llvm/IR/Analysis.h"
16 #include "llvm/IR/Constants.h"
17 #include "llvm/IR/DiagnosticInfo.h"
18 #include "llvm/IR/GlobalValue.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstrTypes.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/IR/PassManager.h"
25 #include "llvm/ProfileData/CtxInstrContextNode.h"
26 #include "llvm/ProfileData/InstrProf.h"
27 #include "llvm/Support/CommandLine.h"
28 #include <utility>
29 
30 using namespace llvm;
31 
32 #define DEBUG_TYPE "ctx-instr-lower"
33 
34 static cl::list<std::string> ContextRoots(
35     "profile-context-root", cl::Hidden,
36     cl::desc(
37         "A function name, assumed to be global, which will be treated as the "
38         "root of an interesting graph, which will be profiled independently "
39         "from other similar graphs."));
40 
isCtxIRPGOInstrEnabled()41 bool PGOCtxProfLoweringPass::isCtxIRPGOInstrEnabled() {
42   return !ContextRoots.empty();
43 }
44 
// The names of the symbols we expect to find in compiler-rt. A namespace is
// used for readability. The names are compile-time constants, so they cannot
// be accidentally re-pointed at another string.
namespace CompilerRtAPINames {
static constexpr const char *StartCtx = "__llvm_ctx_profile_start_context";
static constexpr const char *ReleaseCtx = "__llvm_ctx_profile_release_context";
static constexpr const char *GetCtx = "__llvm_ctx_profile_get_context";
static constexpr const char *ExpectedCalleeTLS =
    "__llvm_ctx_profile_expected_callee";
static constexpr const char *CallsiteTLS = "__llvm_ctx_profile_callsite";
} // namespace CompilerRtAPINames
54 
55 namespace {
56 // The lowering logic and state.
57 class CtxInstrumentationLowerer final {
58   Module &M;
59   ModuleAnalysisManager &MAM;
60   Type *ContextNodeTy = nullptr;
61   StructType *FunctionDataTy = nullptr;
62 
63   DenseSet<const Function *> ContextRootSet;
64   Function *StartCtx = nullptr;
65   Function *GetCtx = nullptr;
66   Function *ReleaseCtx = nullptr;
67   GlobalVariable *ExpectedCalleeTLS = nullptr;
68   GlobalVariable *CallsiteInfoTLS = nullptr;
69   Constant *CannotBeRootInitializer = nullptr;
70 
71 public:
72   CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
73   // return true if lowering happened (i.e. a change was made)
74   bool lowerFunction(Function &F);
75 };
76 
77 // llvm.instrprof.increment[.step] captures the total number of counters as one
78 // of its parameters, and llvm.instrprof.callsite captures the total number of
79 // callsites. Those values are the same for instances of those intrinsics in
80 // this function. Find the first instance of each and return them.
getNumCountersAndCallsites(const Function & F)81 std::pair<uint32_t, uint32_t> getNumCountersAndCallsites(const Function &F) {
82   uint32_t NumCounters = 0;
83   uint32_t NumCallsites = 0;
84   for (const auto &BB : F) {
85     for (const auto &I : BB) {
86       if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
87         uint32_t V =
88             static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
89         assert((!NumCounters || V == NumCounters) &&
90                "expected all llvm.instrprof.increment[.step] intrinsics to "
91                "have the same total nr of counters parameter");
92         NumCounters = V;
93       } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
94         uint32_t V =
95             static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
96         assert((!NumCallsites || V == NumCallsites) &&
97                "expected all llvm.instrprof.callsite intrinsics to have the "
98                "same total nr of callsites parameter");
99         NumCallsites = V;
100       }
101 #if NDEBUG
102       if (NumCounters && NumCallsites)
103         return std::make_pair(NumCounters, NumCallsites);
104 #endif
105     }
106   }
107   return {NumCounters, NumCallsites};
108 }
109 
emitUnsupportedRootError(const Function & F,StringRef Reason)110 void emitUnsupportedRootError(const Function &F, StringRef Reason) {
111   F.getContext().emitError("[ctxprof] The function " + F.getName() +
112                            " was indicated as context root but " + Reason +
113                            ", which is not supported.");
114 }
115 } // namespace
116 
117 // set up tie-in with compiler-rt.
118 // NOTE!!!
119 // These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
CtxInstrumentationLowerer(Module & M,ModuleAnalysisManager & MAM)120 CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
121                                                      ModuleAnalysisManager &MAM)
122     : M(M), MAM(MAM) {
123   auto *PointerTy = PointerType::get(M.getContext(), 0);
124   auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
125   auto *I32Ty = Type::getInt32Ty(M.getContext());
126   auto *I64Ty = Type::getInt64Ty(M.getContext());
127 
128 #define _PTRDECL(_, __) PointerTy,
129 #define _VOLATILE_PTRDECL(_, __) PointerTy,
130 #define _CONTEXT_ROOT PointerTy,
131 #define _MUTEXDECL(_) SanitizerMutexType,
132 
133   FunctionDataTy = StructType::get(
134       M.getContext(), {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT,
135                                              _VOLATILE_PTRDECL, _MUTEXDECL)});
136 #undef _PTRDECL
137 #undef _CONTEXT_ROOT
138 #undef _VOLATILE_PTRDECL
139 #undef _MUTEXDECL
140 
141 #define _PTRDECL(_, __) Constant::getNullValue(PointerTy),
142 #define _VOLATILE_PTRDECL(_, __) _PTRDECL(_, __)
143 #define _MUTEXDECL(_) Constant::getNullValue(SanitizerMutexType),
144 #define _CONTEXT_ROOT                                                          \
145   Constant::getIntegerValue(                                                   \
146       PointerTy,                                                               \
147       APInt(M.getDataLayout().getPointerTypeSizeInBits(PointerTy), 1U)),
148   CannotBeRootInitializer = ConstantStruct::get(
149       FunctionDataTy, {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT,
150                                              _VOLATILE_PTRDECL, _MUTEXDECL)});
151 #undef _PTRDECL
152 #undef _CONTEXT_ROOT
153 #undef _VOLATILE_PTRDECL
154 #undef _MUTEXDECL
155 
156   // The Context header.
157   ContextNodeTy = StructType::get(M.getContext(), {
158                                                       I64Ty,     /*Guid*/
159                                                       PointerTy, /*Next*/
160                                                       I32Ty,     /*NumCounters*/
161                                                       I32Ty, /*NumCallsites*/
162                                                   });
163 
164   // Define a global for each entrypoint. We'll reuse the entrypoint's name
165   // as prefix. We assume the entrypoint names to be unique.
166   for (const auto &Fname : ContextRoots) {
167     if (const auto *F = M.getFunction(Fname)) {
168       if (F->isDeclaration())
169         continue;
170       ContextRootSet.insert(F);
171       for (const auto &BB : *F)
172         for (const auto &I : BB)
173           if (const auto *CB = dyn_cast<CallBase>(&I))
174             if (CB->isMustTailCall())
175               emitUnsupportedRootError(*F, "it features musttail calls");
176     }
177   }
178 
179   // Declare the functions we will call.
180   StartCtx = cast<Function>(
181       M.getOrInsertFunction(
182            CompilerRtAPINames::StartCtx,
183            FunctionType::get(PointerTy,
184                              {PointerTy, /*FunctionData*/
185                               I64Ty, /*Guid*/ I32Ty,
186                               /*NumCounters*/ I32Ty /*NumCallsites*/},
187                              false))
188           .getCallee());
189   GetCtx = cast<Function>(
190       M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
191                             FunctionType::get(PointerTy,
192                                               {PointerTy, /*FunctionData*/
193                                                PointerTy, /*Callee*/
194                                                I64Ty,     /*Guid*/
195                                                I32Ty,     /*NumCounters*/
196                                                I32Ty},    /*NumCallsites*/
197                                               false))
198           .getCallee());
199   ReleaseCtx = cast<Function>(
200       M.getOrInsertFunction(CompilerRtAPINames::ReleaseCtx,
201                             FunctionType::get(Type::getVoidTy(M.getContext()),
202                                               {
203                                                   PointerTy, /*FunctionData*/
204                                               },
205                                               false))
206           .getCallee());
207 
208   // Declare the TLSes we will need to use.
209   CallsiteInfoTLS =
210       new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
211                          nullptr, CompilerRtAPINames::CallsiteTLS);
212   CallsiteInfoTLS->setThreadLocal(true);
213   CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
214   ExpectedCalleeTLS =
215       new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
216                          nullptr, CompilerRtAPINames::ExpectedCalleeTLS);
217   ExpectedCalleeTLS->setThreadLocal(true);
218   ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
219 }
220 
run(Module & M,ModuleAnalysisManager & MAM)221 PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
222                                               ModuleAnalysisManager &MAM) {
223   CtxInstrumentationLowerer Lowerer(M, MAM);
224   bool Changed = false;
225   for (auto &F : M)
226     Changed |= Lowerer.lowerFunction(F);
227   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
228 }
229 
lowerFunction(Function & F)230 bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
231   if (F.isDeclaration())
232     return false;
233 
234   // Probably pointless to try to do anything here, unlikely to be
235   // performance-affecting.
236   if (!llvm::canReturn(F)) {
237     for (auto &BB : F)
238       for (auto &I : make_early_inc_range(BB))
239         if (isa<InstrProfCntrInstBase>(&I))
240           I.eraseFromParent();
241     if (ContextRootSet.contains(&F))
242       emitUnsupportedRootError(F, "it does not return");
243     return true;
244   }
245 
246   auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
247   auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
248 
249   Value *Guid = nullptr;
250   auto [NumCounters, NumCallsites] = getNumCountersAndCallsites(F);
251 
252   Value *Context = nullptr;
253   Value *RealContext = nullptr;
254 
255   StructType *ThisContextType = nullptr;
256   Value *TheRootFuctionData = nullptr;
257   Value *ExpectedCalleeTLSAddr = nullptr;
258   Value *CallsiteInfoTLSAddr = nullptr;
259   const bool HasMusttail = [&F]() {
260     for (auto &BB : F)
261       for (auto &I : BB)
262         if (auto *CB = dyn_cast<CallBase>(&I))
263           if (CB->isMustTailCall())
264             return true;
265     return false;
266   }();
267 
268   if (HasMusttail && ContextRootSet.contains(&F)) {
269     F.getContext().emitError(
270         "[ctx_prof] A function with musttail calls was explicitly requested as "
271         "root. That is not supported because we cannot instrument a return "
272         "instruction to release the context: " +
273         F.getName());
274     return false;
275   }
276   auto &Head = F.getEntryBlock();
277   for (auto &I : Head) {
278     // Find the increment intrinsic in the entry basic block.
279     if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
280       assert(Mark->getIndex()->isZero());
281 
282       IRBuilder<> Builder(Mark);
283       Guid = Builder.getInt64(
284           AssignGUIDPass::getGUID(cast<Function>(*Mark->getNameValue())));
285       // The type of the context of this function is now knowable since we have
286       // NumCallsites and NumCounters. We delcare it here because it's more
287       // convenient - we have the Builder.
288       ThisContextType = StructType::get(
289           F.getContext(),
290           {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NumCounters),
291            ArrayType::get(Builder.getPtrTy(), NumCallsites)});
292       // Figure out which way we obtain the context object for this function -
293       // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
294       // former case, we also set TheRootFuctionData since we need to release it
295       // at the end (plus it can be used to know if we have an entrypoint or a
296       // regular function)
297       // Don't set a name, they end up taking a lot of space and we don't need
298       // them.
299 
300       // Zero-initialize the FunctionData, except for functions that have
301       // musttail calls. There, we set the CtxRoot field to 1, which will be
302       // treated as a "can't be set as root".
303       TheRootFuctionData = new GlobalVariable(
304           M, FunctionDataTy, false, GlobalVariable::InternalLinkage,
305           HasMusttail ? CannotBeRootInitializer
306                       : Constant::getNullValue(FunctionDataTy));
307 
308       if (ContextRootSet.contains(&F)) {
309         Context = Builder.CreateCall(
310             StartCtx, {TheRootFuctionData, Guid, Builder.getInt32(NumCounters),
311                        Builder.getInt32(NumCallsites)});
312         ORE.emit(
313             [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
314       } else {
315         Context = Builder.CreateCall(GetCtx, {TheRootFuctionData, &F, Guid,
316                                               Builder.getInt32(NumCounters),
317                                               Builder.getInt32(NumCallsites)});
318         ORE.emit([&] {
319           return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
320         });
321       }
322       // The context could be scratch.
323       auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
324       if (NumCallsites > 0) {
325         // Figure out which index of the TLS 2-element buffers to use.
326         // Scratch context => we use index == 1. Real contexts => index == 0.
327         auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
328         // The GEPs corresponding to that index, in the respective TLS.
329         ExpectedCalleeTLSAddr = Builder.CreateGEP(
330             PointerType::getUnqual(F.getContext()),
331             Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
332         CallsiteInfoTLSAddr = Builder.CreateGEP(
333             Builder.getInt32Ty(),
334             Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
335       }
336       // Because the context pointer may have LSB set (to indicate scratch),
337       // clear it for the value we use as base address for the counter vector.
338       // This way, if later we want to have "real" (not clobbered) buffers
339       // acting as scratch, the lowering (at least this part of it that deals
340       // with counters) stays the same.
341       RealContext = Builder.CreateIntToPtr(
342           Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
343           PointerType::getUnqual(F.getContext()));
344       I.eraseFromParent();
345       break;
346     }
347   }
348   if (!Context) {
349     ORE.emit([&] {
350       return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
351              << "Function doesn't have instrumentation, skipping";
352     });
353     return false;
354   }
355 
356   bool ContextWasReleased = false;
357   for (auto &BB : F) {
358     for (auto &I : llvm::make_early_inc_range(BB)) {
359       if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
360         IRBuilder<> Builder(Instr);
361         switch (Instr->getIntrinsicID()) {
362         case llvm::Intrinsic::instrprof_increment:
363         case llvm::Intrinsic::instrprof_increment_step: {
364           // Increments (or increment-steps) are just a typical load - increment
365           // - store in the RealContext.
366           auto *AsStep = cast<InstrProfIncrementInst>(Instr);
367           auto *GEP = Builder.CreateGEP(
368               ThisContextType, RealContext,
369               {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
370           Builder.CreateStore(
371               Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
372                                 AsStep->getStep()),
373               GEP);
374         } break;
375         case llvm::Intrinsic::instrprof_callsite:
376           // callsite lowering: write the called value in the expected callee
377           // TLS we treat the TLS as volatile because of signal handlers and to
378           // avoid these being moved away from the callsite they decorate.
379           auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
380           Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
381                               true);
382           // write the GEP of the slot in the sub-contexts portion of the
383           // context in TLS. Now, here, we use the actual Context value - as
384           // returned from compiler-rt - which may have the LSB set if the
385           // Context was scratch. Since the header of the context object and
386           // then the values are all 8-aligned (or, really, insofar as we care,
387           // they are even) - if the context is scratch (meaning, an odd value),
388           // so will the GEP. This is important because this is then visible to
389           // compiler-rt which will produce scratch contexts for callers that
390           // have a scratch context.
391           Builder.CreateStore(
392               Builder.CreateGEP(ThisContextType, Context,
393                                 {Builder.getInt32(0), Builder.getInt32(2),
394                                  CSIntrinsic->getIndex()}),
395               CallsiteInfoTLSAddr, true);
396           break;
397         }
398         I.eraseFromParent();
399       } else if (!HasMusttail && isa<ReturnInst>(I)) {
400         // Remember to release the context if we are an entrypoint.
401         IRBuilder<> Builder(&I);
402         Builder.CreateCall(ReleaseCtx, {TheRootFuctionData});
403         ContextWasReleased = true;
404       }
405     }
406   }
407   if (!HasMusttail && !ContextWasReleased)
408     F.getContext().emitError(
409         "[ctx_prof] A function that doesn't have musttail calls was "
410         "instrumented but it has no `ret` "
411         "instructions above which to release the context: " +
412         F.getName());
413   return true;
414 }
415 
run(Module & M,ModuleAnalysisManager & MAM)416 PreservedAnalyses NoinlineNonPrevailing::run(Module &M,
417                                              ModuleAnalysisManager &MAM) {
418   bool Changed = false;
419   for (auto &F : M) {
420     if (F.isDeclaration())
421       continue;
422     if (F.hasFnAttribute(Attribute::NoInline))
423       continue;
424     if (!F.isWeakForLinker())
425       continue;
426 
427     if (F.hasFnAttribute(Attribute::AlwaysInline))
428       F.removeFnAttr(Attribute::AlwaysInline);
429 
430     F.addFnAttr(Attribute::NoInline);
431     Changed = true;
432   }
433   if (Changed)
434     return PreservedAnalyses::none();
435   return PreservedAnalyses::all();
436 }
437