1 //===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 10 #include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h" 11 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 12 #include "llvm/IR/Analysis.h" 13 #include "llvm/IR/DiagnosticInfo.h" 14 #include "llvm/IR/IRBuilder.h" 15 #include "llvm/IR/Instructions.h" 16 #include "llvm/IR/IntrinsicInst.h" 17 #include "llvm/IR/Module.h" 18 #include "llvm/IR/PassManager.h" 19 #include "llvm/Support/CommandLine.h" 20 #include <utility> 21 22 using namespace llvm; 23 24 #define DEBUG_TYPE "ctx-instr-lower" 25 26 static cl::list<std::string> ContextRoots( 27 "profile-context-root", cl::Hidden, 28 cl::desc( 29 "A function name, assumed to be global, which will be treated as the " 30 "root of an interesting graph, which will be profiled independently " 31 "from other similar graphs.")); 32 33 bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() { 34 return !ContextRoots.empty(); 35 } 36 37 // the names of symbols we expect in compiler-rt. Using a namespace for 38 // readability. 39 namespace CompilerRtAPINames { 40 static auto StartCtx = "__llvm_ctx_profile_start_context"; 41 static auto ReleaseCtx = "__llvm_ctx_profile_release_context"; 42 static auto GetCtx = "__llvm_ctx_profile_get_context"; 43 static auto ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee"; 44 static auto CallsiteTLS = "__llvm_ctx_profile_callsite"; 45 } // namespace CompilerRtAPINames 46 47 namespace { 48 // The lowering logic and state. 49 class CtxInstrumentationLowerer final { 50 Module &M; 51 ModuleAnalysisManager &MAM; 52 Type *ContextNodeTy = nullptr; 53 Type *ContextRootTy = nullptr; 54 55 DenseMap<const Function *, Constant *> ContextRootMap; 56 Function *StartCtx = nullptr; 57 Function *GetCtx = nullptr; 58 Function *ReleaseCtx = nullptr; 59 GlobalVariable *ExpectedCalleeTLS = nullptr; 60 GlobalVariable *CallsiteInfoTLS = nullptr; 61 62 public: 63 CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM); 64 // return true if lowering happened (i.e. a change was made) 65 bool lowerFunction(Function &F); 66 }; 67 68 // llvm.instrprof.increment[.step] captures the total number of counters as one 69 // of its parameters, and llvm.instrprof.callsite captures the total number of 70 // callsites. Those values are the same for instances of those intrinsics in 71 // this function. Find the first instance of each and return them. 72 std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) { 73 uint32_t NrCounters = 0; 74 uint32_t NrCallsites = 0; 75 for (const auto &BB : F) { 76 for (const auto &I : BB) { 77 if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) { 78 uint32_t V = 79 static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue()); 80 assert((!NrCounters || V == NrCounters) && 81 "expected all llvm.instrprof.increment[.step] intrinsics to " 82 "have the same total nr of counters parameter"); 83 NrCounters = V; 84 } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) { 85 uint32_t V = 86 static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue()); 87 assert((!NrCallsites || V == NrCallsites) && 88 "expected all llvm.instrprof.callsite intrinsics to have the " 89 "same total nr of callsites parameter"); 90 NrCallsites = V; 91 } 92 #if NDEBUG 93 if (NrCounters && NrCallsites) 94 return std::make_pair(NrCounters, NrCallsites); 95 #endif 96 } 97 } 98 return {NrCounters, NrCallsites}; 99 } 100 } // namespace 101 102 // set up tie-in with compiler-rt. 103 // NOTE!!! 104 // These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h 105 CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M, 106 ModuleAnalysisManager &MAM) 107 : M(M), MAM(MAM) { 108 auto *PointerTy = PointerType::get(M.getContext(), 0); 109 auto *SanitizerMutexType = Type::getInt8Ty(M.getContext()); 110 auto *I32Ty = Type::getInt32Ty(M.getContext()); 111 auto *I64Ty = Type::getInt64Ty(M.getContext()); 112 113 // The ContextRoot type 114 ContextRootTy = 115 StructType::get(M.getContext(), { 116 PointerTy, /*FirstNode*/ 117 PointerTy, /*FirstMemBlock*/ 118 PointerTy, /*CurrentMem*/ 119 SanitizerMutexType, /*Taken*/ 120 }); 121 // The Context header. 122 ContextNodeTy = StructType::get(M.getContext(), { 123 I64Ty, /*Guid*/ 124 PointerTy, /*Next*/ 125 I32Ty, /*NrCounters*/ 126 I32Ty, /*NrCallsites*/ 127 }); 128 129 // Define a global for each entrypoint. We'll reuse the entrypoint's name as 130 // prefix. We assume the entrypoint names to be unique. 131 for (const auto &Fname : ContextRoots) { 132 if (const auto *F = M.getFunction(Fname)) { 133 if (F->isDeclaration()) 134 continue; 135 auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy); 136 cast<GlobalVariable>(G)->setInitializer( 137 Constant::getNullValue(ContextRootTy)); 138 ContextRootMap.insert(std::make_pair(F, G)); 139 for (const auto &BB : *F) 140 for (const auto &I : BB) 141 if (const auto *CB = dyn_cast<CallBase>(&I)) 142 if (CB->isMustTailCall()) { 143 M.getContext().emitError( 144 "The function " + Fname + 145 " was indicated as a context root, but it features musttail " 146 "calls, which is not supported."); 147 } 148 } 149 } 150 151 // Declare the functions we will call. 152 StartCtx = cast<Function>( 153 M.getOrInsertFunction( 154 CompilerRtAPINames::StartCtx, 155 FunctionType::get(ContextNodeTy->getPointerTo(), 156 {ContextRootTy->getPointerTo(), /*ContextRoot*/ 157 I64Ty, /*Guid*/ I32Ty, 158 /*NrCounters*/ I32Ty /*NrCallsites*/}, 159 false)) 160 .getCallee()); 161 GetCtx = cast<Function>( 162 M.getOrInsertFunction(CompilerRtAPINames::GetCtx, 163 FunctionType::get(ContextNodeTy->getPointerTo(), 164 {PointerTy, /*Callee*/ 165 I64Ty, /*Guid*/ 166 I32Ty, /*NrCounters*/ 167 I32Ty}, /*NrCallsites*/ 168 false)) 169 .getCallee()); 170 ReleaseCtx = cast<Function>( 171 M.getOrInsertFunction( 172 CompilerRtAPINames::ReleaseCtx, 173 FunctionType::get(Type::getVoidTy(M.getContext()), 174 { 175 ContextRootTy->getPointerTo(), /*ContextRoot*/ 176 }, 177 false)) 178 .getCallee()); 179 180 // Declare the TLSes we will need to use. 181 CallsiteInfoTLS = 182 new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage, 183 nullptr, CompilerRtAPINames::CallsiteTLS); 184 CallsiteInfoTLS->setThreadLocal(true); 185 CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility); 186 ExpectedCalleeTLS = 187 new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage, 188 nullptr, CompilerRtAPINames::ExpectedCalleeTLS); 189 ExpectedCalleeTLS->setThreadLocal(true); 190 ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility); 191 } 192 193 PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M, 194 ModuleAnalysisManager &MAM) { 195 CtxInstrumentationLowerer Lowerer(M, MAM); 196 bool Changed = false; 197 for (auto &F : M) 198 Changed |= Lowerer.lowerFunction(F); 199 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); 200 } 201 202 bool CtxInstrumentationLowerer::lowerFunction(Function &F) { 203 if (F.isDeclaration()) 204 return false; 205 auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 206 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); 207 208 Value *Guid = nullptr; 209 auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F); 210 211 Value *Context = nullptr; 212 Value *RealContext = nullptr; 213 214 StructType *ThisContextType = nullptr; 215 Value *TheRootContext = nullptr; 216 Value *ExpectedCalleeTLSAddr = nullptr; 217 Value *CallsiteInfoTLSAddr = nullptr; 218 219 auto &Head = F.getEntryBlock(); 220 for (auto &I : Head) { 221 // Find the increment intrinsic in the entry basic block. 222 if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) { 223 assert(Mark->getIndex()->isZero()); 224 225 IRBuilder<> Builder(Mark); 226 // FIXME(mtrofin): use InstrProfSymtab::getCanonicalName 227 Guid = Builder.getInt64(F.getGUID()); 228 // The type of the context of this function is now knowable since we have 229 // NrCallsites and NrCounters. We delcare it here because it's more 230 // convenient - we have the Builder. 231 ThisContextType = StructType::get( 232 F.getContext(), 233 {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters), 234 ArrayType::get(Builder.getPtrTy(), NrCallsites)}); 235 // Figure out which way we obtain the context object for this function - 236 // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the 237 // former case, we also set TheRootContext since we need to release it 238 // at the end (plus it can be used to know if we have an entrypoint or a 239 // regular function) 240 auto Iter = ContextRootMap.find(&F); 241 if (Iter != ContextRootMap.end()) { 242 TheRootContext = Iter->second; 243 Context = Builder.CreateCall(StartCtx, {TheRootContext, Guid, 244 Builder.getInt32(NrCounters), 245 Builder.getInt32(NrCallsites)}); 246 ORE.emit( 247 [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); }); 248 } else { 249 Context = 250 Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters), 251 Builder.getInt32(NrCallsites)}); 252 ORE.emit([&] { 253 return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F); 254 }); 255 } 256 // The context could be scratch. 257 auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty()); 258 if (NrCallsites > 0) { 259 // Figure out which index of the TLS 2-element buffers to use. 260 // Scratch context => we use index == 1. Real contexts => index == 0. 261 auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1)); 262 // The GEPs corresponding to that index, in the respective TLS. 263 ExpectedCalleeTLSAddr = Builder.CreateGEP( 264 Builder.getInt8Ty()->getPointerTo(), 265 Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index}); 266 CallsiteInfoTLSAddr = Builder.CreateGEP( 267 Builder.getInt32Ty(), 268 Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index}); 269 } 270 // Because the context pointer may have LSB set (to indicate scratch), 271 // clear it for the value we use as base address for the counter vector. 272 // This way, if later we want to have "real" (not clobbered) buffers 273 // acting as scratch, the lowering (at least this part of it that deals 274 // with counters) stays the same. 275 RealContext = Builder.CreateIntToPtr( 276 Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)), 277 ThisContextType->getPointerTo()); 278 I.eraseFromParent(); 279 break; 280 } 281 } 282 if (!Context) { 283 ORE.emit([&] { 284 return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F) 285 << "Function doesn't have instrumentation, skipping"; 286 }); 287 return false; 288 } 289 290 bool ContextWasReleased = false; 291 for (auto &BB : F) { 292 for (auto &I : llvm::make_early_inc_range(BB)) { 293 if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) { 294 IRBuilder<> Builder(Instr); 295 switch (Instr->getIntrinsicID()) { 296 case llvm::Intrinsic::instrprof_increment: 297 case llvm::Intrinsic::instrprof_increment_step: { 298 // Increments (or increment-steps) are just a typical load - increment 299 // - store in the RealContext. 300 auto *AsStep = cast<InstrProfIncrementInst>(Instr); 301 auto *GEP = Builder.CreateGEP( 302 ThisContextType, RealContext, 303 {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()}); 304 Builder.CreateStore( 305 Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP), 306 AsStep->getStep()), 307 GEP); 308 } break; 309 case llvm::Intrinsic::instrprof_callsite: 310 // callsite lowering: write the called value in the expected callee 311 // TLS we treat the TLS as volatile because of signal handlers and to 312 // avoid these being moved away from the callsite they decorate. 313 auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr); 314 Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr, 315 true); 316 // write the GEP of the slot in the sub-contexts portion of the 317 // context in TLS. Now, here, we use the actual Context value - as 318 // returned from compiler-rt - which may have the LSB set if the 319 // Context was scratch. Since the header of the context object and 320 // then the values are all 8-aligned (or, really, insofar as we care, 321 // they are even) - if the context is scratch (meaning, an odd value), 322 // so will the GEP. This is important because this is then visible to 323 // compiler-rt which will produce scratch contexts for callers that 324 // have a scratch context. 325 Builder.CreateStore( 326 Builder.CreateGEP(ThisContextType, Context, 327 {Builder.getInt32(0), Builder.getInt32(2), 328 CSIntrinsic->getIndex()}), 329 CallsiteInfoTLSAddr, true); 330 break; 331 } 332 I.eraseFromParent(); 333 } else if (TheRootContext && isa<ReturnInst>(I)) { 334 // Remember to release the context if we are an entrypoint. 335 IRBuilder<> Builder(&I); 336 Builder.CreateCall(ReleaseCtx, {TheRootContext}); 337 ContextWasReleased = true; 338 } 339 } 340 } 341 // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be 342 // to disallow this, (so this then stays as an error), another is to detect 343 // that and then do a wrapper or disallow the tail call. This only affects 344 // instrumentation, when we want to detect the call graph. 345 if (TheRootContext && !ContextWasReleased) 346 F.getContext().emitError( 347 "[ctx_prof] An entrypoint was instrumented but it has no `ret` " 348 "instructions above which to release the context: " + 349 F.getName()); 350 return true; 351 } 352