1 //=== AMDGPUPrintfRuntimeBinding.cpp - OpenCL printf implementation -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // \file 9 // 10 // The pass bind printfs to a kernel arg pointer that will be bound to a buffer 11 // later by the runtime. 12 // 13 // This pass traverses the functions in the module and converts 14 // each call to printf to a sequence of operations that 15 // store the following into the printf buffer: 16 // - format string (passed as a module's metadata unique ID) 17 // - bitwise copies of printf arguments 18 // The backend passes will need to store metadata in the kernel 19 //===----------------------------------------------------------------------===// 20 21 #include "AMDGPU.h" 22 #include "llvm/ADT/Triple.h" 23 #include "llvm/Analysis/InstructionSimplify.h" 24 #include "llvm/Analysis/TargetLibraryInfo.h" 25 #include "llvm/Analysis/ValueTracking.h" 26 #include "llvm/IR/DiagnosticInfo.h" 27 #include "llvm/IR/Dominators.h" 28 #include "llvm/IR/IRBuilder.h" 29 #include "llvm/IR/Instructions.h" 30 #include "llvm/InitializePasses.h" 31 #include "llvm/Support/DataExtractor.h" 32 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 33 34 using namespace llvm; 35 36 #define DEBUG_TYPE "printfToRuntime" 37 #define DWORD_ALIGN 4 38 39 namespace { 40 class AMDGPUPrintfRuntimeBinding final : public ModulePass { 41 42 public: 43 static char ID; 44 45 explicit AMDGPUPrintfRuntimeBinding(); 46 47 private: 48 bool runOnModule(Module &M) override; 49 50 void getAnalysisUsage(AnalysisUsage &AU) const override { 51 AU.addRequired<TargetLibraryInfoWrapperPass>(); 52 AU.addRequired<DominatorTreeWrapperPass>(); 53 } 54 }; 55 56 class AMDGPUPrintfRuntimeBindingImpl { 57 public: 58 AMDGPUPrintfRuntimeBindingImpl( 59 function_ref<const DominatorTree &(Function &)> GetDT, 60 function_ref<const TargetLibraryInfo &(Function &)> GetTLI) 61 : GetDT(GetDT), GetTLI(GetTLI) {} 62 bool run(Module &M); 63 64 private: 65 void getConversionSpecifiers(SmallVectorImpl<char> &OpConvSpecifiers, 66 StringRef fmt, size_t num_ops) const; 67 68 bool lowerPrintfForGpu(Module &M); 69 70 Value *simplify(Instruction *I, const TargetLibraryInfo *TLI, 71 const DominatorTree *DT) { 72 return simplifyInstruction(I, {*TD, TLI, DT}); 73 } 74 75 const DataLayout *TD; 76 function_ref<const DominatorTree &(Function &)> GetDT; 77 function_ref<const TargetLibraryInfo &(Function &)> GetTLI; 78 SmallVector<CallInst *, 32> Printfs; 79 }; 80 } // namespace 81 82 char AMDGPUPrintfRuntimeBinding::ID = 0; 83 84 INITIALIZE_PASS_BEGIN(AMDGPUPrintfRuntimeBinding, 85 "amdgpu-printf-runtime-binding", "AMDGPU Printf lowering", 86 false, false) 87 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 88 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 89 INITIALIZE_PASS_END(AMDGPUPrintfRuntimeBinding, "amdgpu-printf-runtime-binding", 90 "AMDGPU Printf lowering", false, false) 91 92 char &llvm::AMDGPUPrintfRuntimeBindingID = AMDGPUPrintfRuntimeBinding::ID; 93 94 namespace llvm { 95 ModulePass *createAMDGPUPrintfRuntimeBinding() { 96 return new AMDGPUPrintfRuntimeBinding(); 97 } 98 } // namespace llvm 99 100 AMDGPUPrintfRuntimeBinding::AMDGPUPrintfRuntimeBinding() : ModulePass(ID) { 101 initializeAMDGPUPrintfRuntimeBindingPass(*PassRegistry::getPassRegistry()); 102 } 103 104 void AMDGPUPrintfRuntimeBindingImpl::getConversionSpecifiers( 105 SmallVectorImpl<char> &OpConvSpecifiers, StringRef Fmt, 106 size_t NumOps) const { 107 // not all format characters are collected. 108 // At this time the format characters of interest 109 // are %p and %s, which use to know if we 110 // are either storing a literal string or a 111 // pointer to the printf buffer. 112 static const char ConvSpecifiers[] = "cdieEfgGaosuxXp"; 113 size_t CurFmtSpecifierIdx = 0; 114 size_t PrevFmtSpecifierIdx = 0; 115 116 while ((CurFmtSpecifierIdx = Fmt.find_first_of( 117 ConvSpecifiers, CurFmtSpecifierIdx)) != StringRef::npos) { 118 bool ArgDump = false; 119 StringRef CurFmt = Fmt.substr(PrevFmtSpecifierIdx, 120 CurFmtSpecifierIdx - PrevFmtSpecifierIdx); 121 size_t pTag = CurFmt.find_last_of("%"); 122 if (pTag != StringRef::npos) { 123 ArgDump = true; 124 while (pTag && CurFmt[--pTag] == '%') { 125 ArgDump = !ArgDump; 126 } 127 } 128 129 if (ArgDump) 130 OpConvSpecifiers.push_back(Fmt[CurFmtSpecifierIdx]); 131 132 PrevFmtSpecifierIdx = ++CurFmtSpecifierIdx; 133 } 134 } 135 136 static bool shouldPrintAsStr(char Specifier, Type *OpType) { 137 return Specifier == 's' && isa<PointerType>(OpType); 138 } 139 140 constexpr StringLiteral NonLiteralStr("???"); 141 static_assert(NonLiteralStr.size() == 3); 142 143 static StringRef getAsConstantStr(Value *V) { 144 StringRef S; 145 if (!getConstantStringInfo(V, S)) 146 S = NonLiteralStr; 147 148 return S; 149 } 150 151 static void diagnoseInvalidFormatString(const CallBase *CI) { 152 DiagnosticInfoUnsupported UnsupportedFormatStr( 153 *CI->getParent()->getParent(), 154 "printf format string must be a trivially resolved constant string " 155 "global variable", 156 CI->getDebugLoc()); 157 CI->getContext().diagnose(UnsupportedFormatStr); 158 } 159 160 bool AMDGPUPrintfRuntimeBindingImpl::lowerPrintfForGpu(Module &M) { 161 LLVMContext &Ctx = M.getContext(); 162 IRBuilder<> Builder(Ctx); 163 Type *I32Ty = Type::getInt32Ty(Ctx); 164 165 // Instead of creating global variables, the printf format strings are 166 // extracted and passed as metadata. This avoids polluting llvm's symbol 167 // tables in this module. Metadata is going to be extracted by the backend 168 // passes and inserted into the OpenCL binary as appropriate. 169 NamedMDNode *metaD = M.getOrInsertNamedMetadata("llvm.printf.fmts"); 170 unsigned UniqID = metaD->getNumOperands(); 171 172 for (auto *CI : Printfs) { 173 unsigned NumOps = CI->arg_size(); 174 175 SmallString<16> OpConvSpecifiers; 176 Value *Op = CI->getArgOperand(0); 177 178 if (auto LI = dyn_cast<LoadInst>(Op)) { 179 Op = LI->getPointerOperand(); 180 for (auto *Use : Op->users()) { 181 if (auto SI = dyn_cast<StoreInst>(Use)) { 182 Op = SI->getValueOperand(); 183 break; 184 } 185 } 186 } 187 188 if (auto I = dyn_cast<Instruction>(Op)) { 189 Value *Op_simplified = 190 simplify(I, &GetTLI(*I->getFunction()), &GetDT(*I->getFunction())); 191 if (Op_simplified) 192 Op = Op_simplified; 193 } 194 195 StringRef FormatStr; 196 if (!getConstantStringInfo(Op, FormatStr)) { 197 Value *Stripped = Op->stripPointerCasts(); 198 if (!isa<UndefValue>(Stripped) && !isa<ConstantPointerNull>(Stripped)) 199 diagnoseInvalidFormatString(CI); 200 continue; 201 } 202 203 // We need this call to ascertain that we are printing a string or a 204 // pointer. It takes out the specifiers and fills up the first arg. 205 getConversionSpecifiers(OpConvSpecifiers, FormatStr, NumOps - 1); 206 207 // Add metadata for the string 208 std::string AStreamHolder; 209 raw_string_ostream Sizes(AStreamHolder); 210 int Sum = DWORD_ALIGN; 211 Sizes << CI->arg_size() - 1; 212 Sizes << ':'; 213 for (unsigned ArgCount = 1; 214 ArgCount < CI->arg_size() && ArgCount <= OpConvSpecifiers.size(); 215 ArgCount++) { 216 Value *Arg = CI->getArgOperand(ArgCount); 217 Type *ArgType = Arg->getType(); 218 unsigned ArgSize = TD->getTypeAllocSize(ArgType); 219 // 220 // ArgSize by design should be a multiple of DWORD_ALIGN, 221 // expand the arguments that do not follow this rule. 222 // 223 if (ArgSize % DWORD_ALIGN != 0) { 224 Type *ResType = Type::getInt32Ty(Ctx); 225 if (auto *VecType = dyn_cast<VectorType>(ArgType)) 226 ResType = VectorType::get(ResType, VecType->getElementCount()); 227 Builder.SetInsertPoint(CI); 228 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 229 230 if (ArgType->isFloatingPointTy()) { 231 Arg = Builder.CreateBitCast( 232 Arg, 233 IntegerType::getIntNTy(Ctx, ArgType->getPrimitiveSizeInBits())); 234 } 235 236 if (OpConvSpecifiers[ArgCount - 1] == 'x' || 237 OpConvSpecifiers[ArgCount - 1] == 'X' || 238 OpConvSpecifiers[ArgCount - 1] == 'u' || 239 OpConvSpecifiers[ArgCount - 1] == 'o') 240 Arg = Builder.CreateZExt(Arg, ResType); 241 else 242 Arg = Builder.CreateSExt(Arg, ResType); 243 ArgType = Arg->getType(); 244 ArgSize = TD->getTypeAllocSize(ArgType); 245 CI->setOperand(ArgCount, Arg); 246 } 247 if (OpConvSpecifiers[ArgCount - 1] == 'f') { 248 ConstantFP *FpCons = dyn_cast<ConstantFP>(Arg); 249 if (FpCons) 250 ArgSize = 4; 251 else { 252 FPExtInst *FpExt = dyn_cast<FPExtInst>(Arg); 253 if (FpExt && FpExt->getType()->isDoubleTy() && 254 FpExt->getOperand(0)->getType()->isFloatTy()) 255 ArgSize = 4; 256 } 257 } 258 if (shouldPrintAsStr(OpConvSpecifiers[ArgCount - 1], ArgType)) 259 ArgSize = alignTo(getAsConstantStr(Arg).size() + 1, 4); 260 261 LLVM_DEBUG(dbgs() << "Printf ArgSize (in buffer) = " << ArgSize 262 << " for type: " << *ArgType << '\n'); 263 Sizes << ArgSize << ':'; 264 Sum += ArgSize; 265 } 266 LLVM_DEBUG(dbgs() << "Printf format string in source = " << FormatStr 267 << '\n'); 268 for (char C : FormatStr) { 269 // Rest of the C escape sequences (e.g. \') are handled correctly 270 // by the MDParser 271 switch (C) { 272 case '\a': 273 Sizes << "\\a"; 274 break; 275 case '\b': 276 Sizes << "\\b"; 277 break; 278 case '\f': 279 Sizes << "\\f"; 280 break; 281 case '\n': 282 Sizes << "\\n"; 283 break; 284 case '\r': 285 Sizes << "\\r"; 286 break; 287 case '\v': 288 Sizes << "\\v"; 289 break; 290 case ':': 291 // ':' cannot be scanned by Flex, as it is defined as a delimiter 292 // Replace it with it's octal representation \72 293 Sizes << "\\72"; 294 break; 295 default: 296 Sizes << C; 297 break; 298 } 299 } 300 301 // Insert the printf_alloc call 302 Builder.SetInsertPoint(CI); 303 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 304 305 AttributeList Attr = AttributeList::get(Ctx, AttributeList::FunctionIndex, 306 Attribute::NoUnwind); 307 308 Type *SizetTy = Type::getInt32Ty(Ctx); 309 310 Type *Tys_alloc[1] = {SizetTy}; 311 Type *I8Ty = Type::getInt8Ty(Ctx); 312 Type *I8Ptr = PointerType::get(I8Ty, 1); 313 FunctionType *FTy_alloc = FunctionType::get(I8Ptr, Tys_alloc, false); 314 FunctionCallee PrintfAllocFn = 315 M.getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc, Attr); 316 317 LLVM_DEBUG(dbgs() << "Printf metadata = " << Sizes.str() << '\n'); 318 std::string fmtstr = itostr(++UniqID) + ":" + Sizes.str(); 319 MDString *fmtStrArray = MDString::get(Ctx, fmtstr); 320 321 MDNode *myMD = MDNode::get(Ctx, fmtStrArray); 322 metaD->addOperand(myMD); 323 Value *sumC = ConstantInt::get(SizetTy, Sum, false); 324 SmallVector<Value *, 1> alloc_args; 325 alloc_args.push_back(sumC); 326 CallInst *pcall = 327 CallInst::Create(PrintfAllocFn, alloc_args, "printf_alloc_fn", CI); 328 329 // 330 // Insert code to split basicblock with a 331 // piece of hammock code. 332 // basicblock splits after buffer overflow check 333 // 334 ConstantPointerNull *zeroIntPtr = 335 ConstantPointerNull::get(PointerType::get(I8Ty, 1)); 336 auto *cmp = cast<ICmpInst>(Builder.CreateICmpNE(pcall, zeroIntPtr, "")); 337 if (!CI->use_empty()) { 338 Value *result = 339 Builder.CreateSExt(Builder.CreateNot(cmp), I32Ty, "printf_res"); 340 CI->replaceAllUsesWith(result); 341 } 342 SplitBlock(CI->getParent(), cmp); 343 Instruction *Brnch = 344 SplitBlockAndInsertIfThen(cmp, cmp->getNextNode(), false); 345 346 Builder.SetInsertPoint(Brnch); 347 348 // store unique printf id in the buffer 349 // 350 GetElementPtrInst *BufferIdx = GetElementPtrInst::Create( 351 I8Ty, pcall, ConstantInt::get(Ctx, APInt(32, 0)), "PrintBuffID", Brnch); 352 353 Type *idPointer = PointerType::get(I32Ty, AMDGPUAS::GLOBAL_ADDRESS); 354 Value *id_gep_cast = 355 new BitCastInst(BufferIdx, idPointer, "PrintBuffIdCast", Brnch); 356 357 new StoreInst(ConstantInt::get(I32Ty, UniqID), id_gep_cast, Brnch); 358 359 // 1st 4 bytes hold the printf_id 360 // the following GEP is the buffer pointer 361 BufferIdx = GetElementPtrInst::Create(I8Ty, pcall, 362 ConstantInt::get(Ctx, APInt(32, 4)), 363 "PrintBuffGep", Brnch); 364 365 Type *Int32Ty = Type::getInt32Ty(Ctx); 366 for (unsigned ArgCount = 1; 367 ArgCount < CI->arg_size() && ArgCount <= OpConvSpecifiers.size(); 368 ArgCount++) { 369 Value *Arg = CI->getArgOperand(ArgCount); 370 Type *ArgType = Arg->getType(); 371 SmallVector<Value *, 32> WhatToStore; 372 if (ArgType->isFPOrFPVectorTy() && !isa<VectorType>(ArgType)) { 373 if (OpConvSpecifiers[ArgCount - 1] == 'f') { 374 if (auto *FpCons = dyn_cast<ConstantFP>(Arg)) { 375 APFloat Val(FpCons->getValueAPF()); 376 bool Lost = false; 377 Val.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, 378 &Lost); 379 Arg = ConstantFP::get(Ctx, Val); 380 } else if (auto *FpExt = dyn_cast<FPExtInst>(Arg)) { 381 if (FpExt->getType()->isDoubleTy() && 382 FpExt->getOperand(0)->getType()->isFloatTy()) { 383 Arg = FpExt->getOperand(0); 384 } 385 } 386 } 387 WhatToStore.push_back(Arg); 388 } else if (isa<PointerType>(ArgType)) { 389 if (shouldPrintAsStr(OpConvSpecifiers[ArgCount - 1], ArgType)) { 390 StringRef S = getAsConstantStr(Arg); 391 if (!S.empty()) { 392 const uint64_t ReadSize = 4; 393 394 DataExtractor Extractor(S, /*IsLittleEndian=*/true, 8); 395 DataExtractor::Cursor Offset(0); 396 while (Offset && Offset.tell() < S.size()) { 397 uint64_t ReadNow = std::min(ReadSize, S.size() - Offset.tell()); 398 uint64_t ReadBytes = 0; 399 switch (ReadNow) { 400 default: llvm_unreachable("min(4, X) > 4?"); 401 case 1: 402 ReadBytes = Extractor.getU8(Offset); 403 break; 404 case 2: 405 ReadBytes = Extractor.getU16(Offset); 406 break; 407 case 3: 408 ReadBytes = Extractor.getU24(Offset); 409 break; 410 case 4: 411 ReadBytes = Extractor.getU32(Offset); 412 break; 413 } 414 415 cantFail(Offset.takeError(), 416 "failed to read bytes from constant array"); 417 418 APInt IntVal(8 * ReadSize, ReadBytes); 419 420 // TODO: Should not bothering aligning up. 421 if (ReadNow < ReadSize) 422 IntVal = IntVal.zext(8 * ReadSize); 423 424 Type *IntTy = Type::getIntNTy(Ctx, IntVal.getBitWidth()); 425 WhatToStore.push_back(ConstantInt::get(IntTy, IntVal)); 426 } 427 } else { 428 // Empty string, give a hint to RT it is no NULL 429 Value *ANumV = ConstantInt::get(Int32Ty, 0xFFFFFF00, false); 430 WhatToStore.push_back(ANumV); 431 } 432 } else { 433 WhatToStore.push_back(Arg); 434 } 435 } else { 436 WhatToStore.push_back(Arg); 437 } 438 for (unsigned I = 0, E = WhatToStore.size(); I != E; ++I) { 439 Value *TheBtCast = WhatToStore[I]; 440 unsigned ArgSize = TD->getTypeAllocSize(TheBtCast->getType()); 441 SmallVector<Value *, 1> BuffOffset; 442 BuffOffset.push_back(ConstantInt::get(I32Ty, ArgSize)); 443 444 Type *ArgPointer = PointerType::get(TheBtCast->getType(), 1); 445 Value *CastedGEP = 446 new BitCastInst(BufferIdx, ArgPointer, "PrintBuffPtrCast", Brnch); 447 StoreInst *StBuff = new StoreInst(TheBtCast, CastedGEP, Brnch); 448 LLVM_DEBUG(dbgs() << "inserting store to printf buffer:\n" 449 << *StBuff << '\n'); 450 (void)StBuff; 451 if (I + 1 == E && ArgCount + 1 == CI->arg_size()) 452 break; 453 BufferIdx = GetElementPtrInst::Create(I8Ty, BufferIdx, BuffOffset, 454 "PrintBuffNextPtr", Brnch); 455 LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:\n" 456 << *BufferIdx << '\n'); 457 } 458 } 459 } 460 461 // erase the printf calls 462 for (auto *CI : Printfs) 463 CI->eraseFromParent(); 464 465 Printfs.clear(); 466 return true; 467 } 468 469 bool AMDGPUPrintfRuntimeBindingImpl::run(Module &M) { 470 Triple TT(M.getTargetTriple()); 471 if (TT.getArch() == Triple::r600) 472 return false; 473 474 auto PrintfFunction = M.getFunction("printf"); 475 if (!PrintfFunction || !PrintfFunction->isDeclaration()) 476 return false; 477 478 for (auto &U : PrintfFunction->uses()) { 479 if (auto *CI = dyn_cast<CallInst>(U.getUser())) { 480 if (CI->isCallee(&U)) 481 Printfs.push_back(CI); 482 } 483 } 484 485 if (Printfs.empty()) 486 return false; 487 488 TD = &M.getDataLayout(); 489 490 return lowerPrintfForGpu(M); 491 } 492 493 bool AMDGPUPrintfRuntimeBinding::runOnModule(Module &M) { 494 auto GetDT = [this](Function &F) -> DominatorTree & { 495 return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); 496 }; 497 auto GetTLI = [this](Function &F) -> TargetLibraryInfo & { 498 return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 499 }; 500 501 return AMDGPUPrintfRuntimeBindingImpl(GetDT, GetTLI).run(M); 502 } 503 504 PreservedAnalyses 505 AMDGPUPrintfRuntimeBindingPass::run(Module &M, ModuleAnalysisManager &AM) { 506 FunctionAnalysisManager &FAM = 507 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 508 auto GetDT = [&FAM](Function &F) -> DominatorTree & { 509 return FAM.getResult<DominatorTreeAnalysis>(F); 510 }; 511 auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & { 512 return FAM.getResult<TargetLibraryAnalysis>(F); 513 }; 514 bool Changed = AMDGPUPrintfRuntimeBindingImpl(GetDT, GetTLI).run(M); 515 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); 516 } 517