1 //=== AMDGPUPrintfRuntimeBinding.cpp - OpenCL printf implementation -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // \file 9 // 10 // The pass bind printfs to a kernel arg pointer that will be bound to a buffer 11 // later by the runtime. 12 // 13 // This pass traverses the functions in the module and converts 14 // each call to printf to a sequence of operations that 15 // store the following into the printf buffer: 16 // - format string (passed as a module's metadata unique ID) 17 // - bitwise copies of printf arguments 18 // The backend passes will need to store metadata in the kernel 19 //===----------------------------------------------------------------------===// 20 21 #include "AMDGPU.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/Analysis/ValueTracking.h" 24 #include "llvm/IR/DiagnosticInfo.h" 25 #include "llvm/IR/Dominators.h" 26 #include "llvm/IR/IRBuilder.h" 27 #include "llvm/IR/Instructions.h" 28 #include "llvm/IR/Module.h" 29 #include "llvm/InitializePasses.h" 30 #include "llvm/Support/DataExtractor.h" 31 #include "llvm/TargetParser/Triple.h" 32 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 33 34 using namespace llvm; 35 36 #define DEBUG_TYPE "printfToRuntime" 37 enum { DWORD_ALIGN = 4 }; 38 39 namespace { 40 class AMDGPUPrintfRuntimeBinding final : public ModulePass { 41 42 public: 43 static char ID; 44 45 explicit AMDGPUPrintfRuntimeBinding() : ModulePass(ID) {} 46 47 private: 48 bool runOnModule(Module &M) override; 49 }; 50 51 class AMDGPUPrintfRuntimeBindingImpl { 52 public: 53 AMDGPUPrintfRuntimeBindingImpl() = default; 54 bool run(Module &M); 55 56 private: 57 void getConversionSpecifiers(SmallVectorImpl<char> &OpConvSpecifiers, 58 StringRef fmt, size_t num_ops) const; 59 60 bool lowerPrintfForGpu(Module &M); 61 62 const DataLayout *TD; 63 SmallVector<CallInst *, 32> Printfs; 64 }; 65 } // namespace 66 67 char AMDGPUPrintfRuntimeBinding::ID = 0; 68 69 INITIALIZE_PASS_BEGIN(AMDGPUPrintfRuntimeBinding, 70 "amdgpu-printf-runtime-binding", "AMDGPU Printf lowering", 71 false, false) 72 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 73 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 74 INITIALIZE_PASS_END(AMDGPUPrintfRuntimeBinding, "amdgpu-printf-runtime-binding", 75 "AMDGPU Printf lowering", false, false) 76 77 char &llvm::AMDGPUPrintfRuntimeBindingID = AMDGPUPrintfRuntimeBinding::ID; 78 79 ModulePass *llvm::createAMDGPUPrintfRuntimeBinding() { 80 return new AMDGPUPrintfRuntimeBinding(); 81 } 82 83 void AMDGPUPrintfRuntimeBindingImpl::getConversionSpecifiers( 84 SmallVectorImpl<char> &OpConvSpecifiers, StringRef Fmt, 85 size_t NumOps) const { 86 // not all format characters are collected. 87 // At this time the format characters of interest 88 // are %p and %s, which use to know if we 89 // are either storing a literal string or a 90 // pointer to the printf buffer. 91 static const char ConvSpecifiers[] = "cdieEfFgGaAosuxXp"; 92 size_t CurFmtSpecifierIdx = 0; 93 size_t PrevFmtSpecifierIdx = 0; 94 95 while ((CurFmtSpecifierIdx = Fmt.find_first_of( 96 ConvSpecifiers, CurFmtSpecifierIdx)) != StringRef::npos) { 97 bool ArgDump = false; 98 StringRef CurFmt = Fmt.substr(PrevFmtSpecifierIdx, 99 CurFmtSpecifierIdx - PrevFmtSpecifierIdx); 100 size_t pTag = CurFmt.find_last_of('%'); 101 if (pTag != StringRef::npos) { 102 ArgDump = true; 103 while (pTag && CurFmt[--pTag] == '%') { 104 ArgDump = !ArgDump; 105 } 106 } 107 108 if (ArgDump) 109 OpConvSpecifiers.push_back(Fmt[CurFmtSpecifierIdx]); 110 111 PrevFmtSpecifierIdx = ++CurFmtSpecifierIdx; 112 } 113 } 114 115 static bool shouldPrintAsStr(char Specifier, Type *OpType) { 116 return Specifier == 's' && isa<PointerType>(OpType); 117 } 118 119 constexpr StringLiteral NonLiteralStr("???"); 120 static_assert(NonLiteralStr.size() == 3); 121 122 static StringRef getAsConstantStr(Value *V) { 123 StringRef S; 124 if (!getConstantStringInfo(V, S)) 125 S = NonLiteralStr; 126 127 return S; 128 } 129 130 static void diagnoseInvalidFormatString(const CallBase *CI) { 131 CI->getContext().diagnose(DiagnosticInfoUnsupported( 132 *CI->getParent()->getParent(), 133 "printf format string must be a trivially resolved constant string " 134 "global variable", 135 CI->getDebugLoc())); 136 } 137 138 bool AMDGPUPrintfRuntimeBindingImpl::lowerPrintfForGpu(Module &M) { 139 LLVMContext &Ctx = M.getContext(); 140 IRBuilder<> Builder(Ctx); 141 Type *I32Ty = Type::getInt32Ty(Ctx); 142 143 // Instead of creating global variables, the printf format strings are 144 // extracted and passed as metadata. This avoids polluting llvm's symbol 145 // tables in this module. Metadata is going to be extracted by the backend 146 // passes and inserted into the OpenCL binary as appropriate. 147 NamedMDNode *metaD = M.getOrInsertNamedMetadata("llvm.printf.fmts"); 148 unsigned UniqID = metaD->getNumOperands(); 149 150 for (auto *CI : Printfs) { 151 unsigned NumOps = CI->arg_size(); 152 153 SmallString<16> OpConvSpecifiers; 154 Value *Op = CI->getArgOperand(0); 155 156 StringRef FormatStr; 157 if (!getConstantStringInfo(Op, FormatStr)) { 158 Value *Stripped = Op->stripPointerCasts(); 159 if (!isa<UndefValue>(Stripped) && !isa<ConstantPointerNull>(Stripped)) 160 diagnoseInvalidFormatString(CI); 161 continue; 162 } 163 164 // We need this call to ascertain that we are printing a string or a 165 // pointer. It takes out the specifiers and fills up the first arg. 166 getConversionSpecifiers(OpConvSpecifiers, FormatStr, NumOps - 1); 167 168 // Add metadata for the string 169 std::string AStreamHolder; 170 raw_string_ostream Sizes(AStreamHolder); 171 int Sum = DWORD_ALIGN; 172 Sizes << CI->arg_size() - 1; 173 Sizes << ':'; 174 for (unsigned ArgCount = 1; 175 ArgCount < CI->arg_size() && ArgCount <= OpConvSpecifiers.size(); 176 ArgCount++) { 177 Value *Arg = CI->getArgOperand(ArgCount); 178 Type *ArgType = Arg->getType(); 179 unsigned ArgSize = TD->getTypeAllocSize(ArgType); 180 // 181 // ArgSize by design should be a multiple of DWORD_ALIGN, 182 // expand the arguments that do not follow this rule. 183 // 184 if (ArgSize % DWORD_ALIGN != 0) { 185 Type *ResType = Type::getInt32Ty(Ctx); 186 if (auto *VecType = dyn_cast<VectorType>(ArgType)) 187 ResType = VectorType::get(ResType, VecType->getElementCount()); 188 Builder.SetInsertPoint(CI); 189 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 190 191 if (ArgType->isFloatingPointTy()) { 192 Arg = Builder.CreateBitCast( 193 Arg, 194 IntegerType::getIntNTy(Ctx, ArgType->getPrimitiveSizeInBits())); 195 } 196 197 if (OpConvSpecifiers[ArgCount - 1] == 'x' || 198 OpConvSpecifiers[ArgCount - 1] == 'X' || 199 OpConvSpecifiers[ArgCount - 1] == 'u' || 200 OpConvSpecifiers[ArgCount - 1] == 'o') 201 Arg = Builder.CreateZExt(Arg, ResType); 202 else 203 Arg = Builder.CreateSExt(Arg, ResType); 204 ArgType = Arg->getType(); 205 ArgSize = TD->getTypeAllocSize(ArgType); 206 CI->setOperand(ArgCount, Arg); 207 } 208 if (OpConvSpecifiers[ArgCount - 1] == 'f') { 209 ConstantFP *FpCons = dyn_cast<ConstantFP>(Arg); 210 if (FpCons) 211 ArgSize = 4; 212 else { 213 FPExtInst *FpExt = dyn_cast<FPExtInst>(Arg); 214 if (FpExt && FpExt->getType()->isDoubleTy() && 215 FpExt->getOperand(0)->getType()->isFloatTy()) 216 ArgSize = 4; 217 } 218 } 219 if (shouldPrintAsStr(OpConvSpecifiers[ArgCount - 1], ArgType)) 220 ArgSize = alignTo(getAsConstantStr(Arg).size() + 1, 4); 221 222 LLVM_DEBUG(dbgs() << "Printf ArgSize (in buffer) = " << ArgSize 223 << " for type: " << *ArgType << '\n'); 224 Sizes << ArgSize << ':'; 225 Sum += ArgSize; 226 } 227 LLVM_DEBUG(dbgs() << "Printf format string in source = " << FormatStr 228 << '\n'); 229 for (char C : FormatStr) { 230 // Rest of the C escape sequences (e.g. \') are handled correctly 231 // by the MDParser 232 switch (C) { 233 case '\a': 234 Sizes << "\\a"; 235 break; 236 case '\b': 237 Sizes << "\\b"; 238 break; 239 case '\f': 240 Sizes << "\\f"; 241 break; 242 case '\n': 243 Sizes << "\\n"; 244 break; 245 case '\r': 246 Sizes << "\\r"; 247 break; 248 case '\v': 249 Sizes << "\\v"; 250 break; 251 case ':': 252 // ':' cannot be scanned by Flex, as it is defined as a delimiter 253 // Replace it with it's octal representation \72 254 Sizes << "\\72"; 255 break; 256 default: 257 Sizes << C; 258 break; 259 } 260 } 261 262 // Insert the printf_alloc call 263 Builder.SetInsertPoint(CI); 264 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 265 266 AttributeList Attr = AttributeList::get(Ctx, AttributeList::FunctionIndex, 267 Attribute::NoUnwind); 268 269 Type *SizetTy = Type::getInt32Ty(Ctx); 270 271 Type *Tys_alloc[1] = {SizetTy}; 272 Type *I8Ty = Type::getInt8Ty(Ctx); 273 Type *I8Ptr = PointerType::get(Ctx, 1); 274 FunctionType *FTy_alloc = FunctionType::get(I8Ptr, Tys_alloc, false); 275 FunctionCallee PrintfAllocFn = 276 M.getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc, Attr); 277 278 LLVM_DEBUG(dbgs() << "Printf metadata = " << Sizes.str() << '\n'); 279 std::string fmtstr = itostr(++UniqID) + ":" + Sizes.str(); 280 MDString *fmtStrArray = MDString::get(Ctx, fmtstr); 281 282 MDNode *myMD = MDNode::get(Ctx, fmtStrArray); 283 metaD->addOperand(myMD); 284 Value *sumC = ConstantInt::get(SizetTy, Sum, false); 285 SmallVector<Value *, 1> alloc_args; 286 alloc_args.push_back(sumC); 287 CallInst *pcall = CallInst::Create(PrintfAllocFn, alloc_args, 288 "printf_alloc_fn", CI->getIterator()); 289 290 // 291 // Insert code to split basicblock with a 292 // piece of hammock code. 293 // basicblock splits after buffer overflow check 294 // 295 ConstantPointerNull *zeroIntPtr = 296 ConstantPointerNull::get(PointerType::get(Ctx, 1)); 297 auto *cmp = cast<ICmpInst>(Builder.CreateICmpNE(pcall, zeroIntPtr, "")); 298 if (!CI->use_empty()) { 299 Value *result = 300 Builder.CreateSExt(Builder.CreateNot(cmp), I32Ty, "printf_res"); 301 CI->replaceAllUsesWith(result); 302 } 303 SplitBlock(CI->getParent(), cmp); 304 Instruction *Brnch = 305 SplitBlockAndInsertIfThen(cmp, cmp->getNextNode(), false); 306 BasicBlock::iterator BrnchPoint = Brnch->getIterator(); 307 308 Builder.SetInsertPoint(Brnch); 309 310 // store unique printf id in the buffer 311 // 312 GetElementPtrInst *BufferIdx = GetElementPtrInst::Create( 313 I8Ty, pcall, ConstantInt::get(Ctx, APInt(32, 0)), "PrintBuffID", 314 BrnchPoint); 315 316 Type *idPointer = PointerType::get(Ctx, AMDGPUAS::GLOBAL_ADDRESS); 317 Value *id_gep_cast = 318 new BitCastInst(BufferIdx, idPointer, "PrintBuffIdCast", BrnchPoint); 319 320 new StoreInst(ConstantInt::get(I32Ty, UniqID), id_gep_cast, BrnchPoint); 321 322 // 1st 4 bytes hold the printf_id 323 // the following GEP is the buffer pointer 324 BufferIdx = GetElementPtrInst::Create(I8Ty, pcall, 325 ConstantInt::get(Ctx, APInt(32, 4)), 326 "PrintBuffGep", BrnchPoint); 327 328 Type *Int32Ty = Type::getInt32Ty(Ctx); 329 for (unsigned ArgCount = 1; 330 ArgCount < CI->arg_size() && ArgCount <= OpConvSpecifiers.size(); 331 ArgCount++) { 332 Value *Arg = CI->getArgOperand(ArgCount); 333 Type *ArgType = Arg->getType(); 334 SmallVector<Value *, 32> WhatToStore; 335 if (ArgType->isFPOrFPVectorTy() && !isa<VectorType>(ArgType)) { 336 if (OpConvSpecifiers[ArgCount - 1] == 'f') { 337 if (auto *FpCons = dyn_cast<ConstantFP>(Arg)) { 338 APFloat Val(FpCons->getValueAPF()); 339 bool Lost = false; 340 Val.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, 341 &Lost); 342 Arg = ConstantFP::get(Ctx, Val); 343 } else if (auto *FpExt = dyn_cast<FPExtInst>(Arg)) { 344 if (FpExt->getType()->isDoubleTy() && 345 FpExt->getOperand(0)->getType()->isFloatTy()) { 346 Arg = FpExt->getOperand(0); 347 } 348 } 349 } 350 WhatToStore.push_back(Arg); 351 } else if (isa<PointerType>(ArgType)) { 352 if (shouldPrintAsStr(OpConvSpecifiers[ArgCount - 1], ArgType)) { 353 StringRef S = getAsConstantStr(Arg); 354 if (!S.empty()) { 355 const uint64_t ReadSize = 4; 356 357 DataExtractor Extractor(S, /*IsLittleEndian=*/true, 8); 358 DataExtractor::Cursor Offset(0); 359 while (Offset && Offset.tell() < S.size()) { 360 uint64_t ReadNow = std::min(ReadSize, S.size() - Offset.tell()); 361 uint64_t ReadBytes = 0; 362 switch (ReadNow) { 363 default: llvm_unreachable("min(4, X) > 4?"); 364 case 1: 365 ReadBytes = Extractor.getU8(Offset); 366 break; 367 case 2: 368 ReadBytes = Extractor.getU16(Offset); 369 break; 370 case 3: 371 ReadBytes = Extractor.getU24(Offset); 372 break; 373 case 4: 374 ReadBytes = Extractor.getU32(Offset); 375 break; 376 } 377 378 cantFail(Offset.takeError(), 379 "failed to read bytes from constant array"); 380 381 APInt IntVal(8 * ReadSize, ReadBytes); 382 383 // TODO: Should not bothering aligning up. 384 if (ReadNow < ReadSize) 385 IntVal = IntVal.zext(8 * ReadSize); 386 387 Type *IntTy = Type::getIntNTy(Ctx, IntVal.getBitWidth()); 388 WhatToStore.push_back(ConstantInt::get(IntTy, IntVal)); 389 } 390 } else { 391 // Empty string, give a hint to RT it is no NULL 392 Value *ANumV = ConstantInt::get(Int32Ty, 0xFFFFFF00, false); 393 WhatToStore.push_back(ANumV); 394 } 395 } else { 396 WhatToStore.push_back(Arg); 397 } 398 } else { 399 WhatToStore.push_back(Arg); 400 } 401 for (unsigned I = 0, E = WhatToStore.size(); I != E; ++I) { 402 Value *TheBtCast = WhatToStore[I]; 403 unsigned ArgSize = TD->getTypeAllocSize(TheBtCast->getType()); 404 StoreInst *StBuff = new StoreInst(TheBtCast, BufferIdx, BrnchPoint); 405 LLVM_DEBUG(dbgs() << "inserting store to printf buffer:\n" 406 << *StBuff << '\n'); 407 (void)StBuff; 408 if (I + 1 == E && ArgCount + 1 == CI->arg_size()) 409 break; 410 BufferIdx = GetElementPtrInst::Create( 411 I8Ty, BufferIdx, {ConstantInt::get(I32Ty, ArgSize)}, 412 "PrintBuffNextPtr", BrnchPoint); 413 LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:\n" 414 << *BufferIdx << '\n'); 415 } 416 } 417 } 418 419 // erase the printf calls 420 for (auto *CI : Printfs) 421 CI->eraseFromParent(); 422 423 Printfs.clear(); 424 return true; 425 } 426 427 bool AMDGPUPrintfRuntimeBindingImpl::run(Module &M) { 428 Triple TT(M.getTargetTriple()); 429 if (TT.getArch() == Triple::r600) 430 return false; 431 432 auto *PrintfFunction = M.getFunction("printf"); 433 if (!PrintfFunction || !PrintfFunction->isDeclaration() || 434 M.getModuleFlag("openmp")) 435 return false; 436 437 for (auto &U : PrintfFunction->uses()) { 438 if (auto *CI = dyn_cast<CallInst>(U.getUser())) { 439 if (CI->isCallee(&U) && !CI->isNoBuiltin()) 440 Printfs.push_back(CI); 441 } 442 } 443 444 if (Printfs.empty()) 445 return false; 446 447 TD = &M.getDataLayout(); 448 449 return lowerPrintfForGpu(M); 450 } 451 452 bool AMDGPUPrintfRuntimeBinding::runOnModule(Module &M) { 453 return AMDGPUPrintfRuntimeBindingImpl().run(M); 454 } 455 456 PreservedAnalyses 457 AMDGPUPrintfRuntimeBindingPass::run(Module &M, ModuleAnalysisManager &AM) { 458 bool Changed = AMDGPUPrintfRuntimeBindingImpl().run(M); 459 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); 460 } 461