1 //===- AMDGPUEmitPrintf.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Utility function to lower a printf call into a series of device 10 // library calls on the AMDGPU target. 11 // 12 // WARNING: This file knows about certain library functions. It recognizes them 13 // by name, and hardwires knowledge of their semantics. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h" 18 #include "llvm/ADT/SparseBitVector.h" 19 #include "llvm/Analysis/ValueTracking.h" 20 21 using namespace llvm; 22 23 #define DEBUG_TYPE "amdgpu-emit-printf" 24 25 static Value *fitArgInto64Bits(IRBuilder<> &Builder, Value *Arg) { 26 auto Int64Ty = Builder.getInt64Ty(); 27 auto Ty = Arg->getType(); 28 29 if (auto IntTy = dyn_cast<IntegerType>(Ty)) { 30 switch (IntTy->getBitWidth()) { 31 case 32: 32 return Builder.CreateZExt(Arg, Int64Ty); 33 case 64: 34 return Arg; 35 } 36 } 37 38 if (Ty->getTypeID() == Type::DoubleTyID) { 39 return Builder.CreateBitCast(Arg, Int64Ty); 40 } 41 42 if (isa<PointerType>(Ty)) { 43 return Builder.CreatePtrToInt(Arg, Int64Ty); 44 } 45 46 llvm_unreachable("unexpected type"); 47 } 48 49 static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) { 50 auto Int64Ty = Builder.getInt64Ty(); 51 auto M = Builder.GetInsertBlock()->getModule(); 52 auto Fn = M->getOrInsertFunction("__ockl_printf_begin", Int64Ty, Int64Ty); 53 if (!M->getModuleFlag("amdgpu_hostcall")) { 54 M->addModuleFlag(llvm::Module::Override, "amdgpu_hostcall", 1); 55 } 56 return Builder.CreateCall(Fn, Version); 57 } 58 59 static Value *callAppendArgs(IRBuilder<> &Builder, Value *Desc, int NumArgs, 60 Value *Arg0, Value *Arg1, Value *Arg2, Value *Arg3, 61 Value *Arg4, Value *Arg5, Value *Arg6, 62 bool IsLast) { 63 auto Int64Ty = Builder.getInt64Ty(); 64 auto Int32Ty = Builder.getInt32Ty(); 65 auto M = Builder.GetInsertBlock()->getModule(); 66 auto Fn = M->getOrInsertFunction("__ockl_printf_append_args", Int64Ty, 67 Int64Ty, Int32Ty, Int64Ty, Int64Ty, Int64Ty, 68 Int64Ty, Int64Ty, Int64Ty, Int64Ty, Int32Ty); 69 auto IsLastValue = Builder.getInt32(IsLast); 70 auto NumArgsValue = Builder.getInt32(NumArgs); 71 return Builder.CreateCall(Fn, {Desc, NumArgsValue, Arg0, Arg1, Arg2, Arg3, 72 Arg4, Arg5, Arg6, IsLastValue}); 73 } 74 75 static Value *appendArg(IRBuilder<> &Builder, Value *Desc, Value *Arg, 76 bool IsLast) { 77 auto Arg0 = fitArgInto64Bits(Builder, Arg); 78 auto Zero = Builder.getInt64(0); 79 return callAppendArgs(Builder, Desc, 1, Arg0, Zero, Zero, Zero, Zero, Zero, 80 Zero, IsLast); 81 } 82 83 // The device library does not provide strlen, so we build our own loop 84 // here. While we are at it, we also include the terminating null in the length. 85 static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) { 86 auto *Prev = Builder.GetInsertBlock(); 87 Module *M = Prev->getModule(); 88 89 auto CharZero = Builder.getInt8(0); 90 auto One = Builder.getInt64(1); 91 auto Zero = Builder.getInt64(0); 92 auto Int64Ty = Builder.getInt64Ty(); 93 94 // The length is either zero for a null pointer, or the computed value for an 95 // actual string. We need a join block for a phi that represents the final 96 // value. 97 // 98 // Strictly speaking, the zero does not matter since 99 // __ockl_printf_append_string_n ignores the length if the pointer is null. 100 BasicBlock *Join = nullptr; 101 if (Prev->getTerminator()) { 102 Join = Prev->splitBasicBlock(Builder.GetInsertPoint(), 103 "strlen.join"); 104 Prev->getTerminator()->eraseFromParent(); 105 } else { 106 Join = BasicBlock::Create(M->getContext(), "strlen.join", 107 Prev->getParent()); 108 } 109 BasicBlock *While = 110 BasicBlock::Create(M->getContext(), "strlen.while", 111 Prev->getParent(), Join); 112 BasicBlock *WhileDone = BasicBlock::Create( 113 M->getContext(), "strlen.while.done", 114 Prev->getParent(), Join); 115 116 // Emit an early return for when the pointer is null. 117 Builder.SetInsertPoint(Prev); 118 auto CmpNull = 119 Builder.CreateICmpEQ(Str, Constant::getNullValue(Str->getType())); 120 BranchInst::Create(Join, While, CmpNull, Prev); 121 122 // Entry to the while loop. 123 Builder.SetInsertPoint(While); 124 125 auto PtrPhi = Builder.CreatePHI(Str->getType(), 2); 126 PtrPhi->addIncoming(Str, Prev); 127 auto PtrNext = Builder.CreateGEP(Builder.getInt8Ty(), PtrPhi, One); 128 PtrPhi->addIncoming(PtrNext, While); 129 130 // Condition for the while loop. 131 auto Data = Builder.CreateLoad(Builder.getInt8Ty(), PtrPhi); 132 auto Cmp = Builder.CreateICmpEQ(Data, CharZero); 133 Builder.CreateCondBr(Cmp, WhileDone, While); 134 135 // Add one to the computed length. 136 Builder.SetInsertPoint(WhileDone, WhileDone->begin()); 137 auto Begin = Builder.CreatePtrToInt(Str, Int64Ty); 138 auto End = Builder.CreatePtrToInt(PtrPhi, Int64Ty); 139 auto Len = Builder.CreateSub(End, Begin); 140 Len = Builder.CreateAdd(Len, One); 141 142 // Final join. 143 BranchInst::Create(Join, WhileDone); 144 Builder.SetInsertPoint(Join, Join->begin()); 145 auto LenPhi = Builder.CreatePHI(Len->getType(), 2); 146 LenPhi->addIncoming(Len, WhileDone); 147 LenPhi->addIncoming(Zero, Prev); 148 149 return LenPhi; 150 } 151 152 static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str, 153 Value *Length, bool isLast) { 154 auto Int64Ty = Builder.getInt64Ty(); 155 auto CharPtrTy = Builder.getInt8PtrTy(); 156 auto Int32Ty = Builder.getInt32Ty(); 157 auto M = Builder.GetInsertBlock()->getModule(); 158 auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty, 159 Int64Ty, CharPtrTy, Int64Ty, Int32Ty); 160 auto IsLastInt32 = Builder.getInt32(isLast); 161 return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32}); 162 } 163 164 static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg, 165 bool IsLast) { 166 Arg = Builder.CreateBitCast( 167 Arg, Builder.getInt8PtrTy(Arg->getType()->getPointerAddressSpace())); 168 auto Length = getStrlenWithNull(Builder, Arg); 169 return callAppendStringN(Builder, Desc, Arg, Length, IsLast); 170 } 171 172 static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg, 173 bool SpecIsCString, bool IsLast) { 174 if (SpecIsCString && isa<PointerType>(Arg->getType())) { 175 return appendString(Builder, Desc, Arg, IsLast); 176 } 177 // If the format specifies a string but the argument is not, the frontend will 178 // have printed a warning. We just rely on undefined behaviour and send the 179 // argument anyway. 180 return appendArg(Builder, Desc, Arg, IsLast); 181 } 182 183 // Scan the format string to locate all specifiers, and mark the ones that 184 // specify a string, i.e, the "%s" specifier with optional '*' characters. 185 static void locateCStrings(SparseBitVector<8> &BV, Value *Fmt) { 186 StringRef Str; 187 if (!getConstantStringInfo(Fmt, Str) || Str.empty()) 188 return; 189 190 static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn"; 191 size_t SpecPos = 0; 192 // Skip the first argument, the format string. 193 unsigned ArgIdx = 1; 194 195 while ((SpecPos = Str.find_first_of('%', SpecPos)) != StringRef::npos) { 196 if (Str[SpecPos + 1] == '%') { 197 SpecPos += 2; 198 continue; 199 } 200 auto SpecEnd = Str.find_first_of(ConvSpecifiers, SpecPos); 201 if (SpecEnd == StringRef::npos) 202 return; 203 auto Spec = Str.slice(SpecPos, SpecEnd + 1); 204 ArgIdx += Spec.count('*'); 205 if (Str[SpecEnd] == 's') { 206 BV.set(ArgIdx); 207 } 208 SpecPos = SpecEnd + 1; 209 ++ArgIdx; 210 } 211 } 212 213 Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder, 214 ArrayRef<Value *> Args) { 215 auto NumOps = Args.size(); 216 assert(NumOps >= 1); 217 218 auto Fmt = Args[0]; 219 SparseBitVector<8> SpecIsCString; 220 locateCStrings(SpecIsCString, Fmt); 221 222 auto Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0)); 223 Desc = appendString(Builder, Desc, Fmt, NumOps == 1); 224 225 // FIXME: This invokes hostcall once for each argument. We can pack up to 226 // seven scalar printf arguments in a single hostcall. See the signature of 227 // callAppendArgs(). 228 for (unsigned int i = 1; i != NumOps; ++i) { 229 bool IsLast = i == NumOps - 1; 230 bool IsCString = SpecIsCString.test(i); 231 Desc = processArg(Builder, Desc, Args[i], IsCString, IsLast); 232 } 233 234 return Builder.CreateTrunc(Desc, Builder.getInt32Ty()); 235 } 236