1 //===- AMDGPUEmitPrintf.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Utility function to lower a printf call into a series of device 10 // library calls on the AMDGPU target. 11 // 12 // WARNING: This file knows about certain library functions. It recognizes them 13 // by name, and hardwires knowledge of their semantics. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h" 18 #include "llvm/ADT/SparseBitVector.h" 19 #include "llvm/Analysis/ValueTracking.h" 20 #include "llvm/IR/IRBuilder.h" 21 22 #include <iostream> 23 24 using namespace llvm; 25 26 #define DEBUG_TYPE "amdgpu-emit-printf" 27 28 static bool isCString(const Value *Arg) { 29 auto Ty = Arg->getType(); 30 auto PtrTy = dyn_cast<PointerType>(Ty); 31 if (!PtrTy) 32 return false; 33 34 auto IntTy = dyn_cast<IntegerType>(PtrTy->getElementType()); 35 if (!IntTy) 36 return false; 37 38 return IntTy->getBitWidth() == 8; 39 } 40 41 static Value *fitArgInto64Bits(IRBuilder<> &Builder, Value *Arg) { 42 auto Int64Ty = Builder.getInt64Ty(); 43 auto Ty = Arg->getType(); 44 45 if (auto IntTy = dyn_cast<IntegerType>(Ty)) { 46 switch (IntTy->getBitWidth()) { 47 case 32: 48 return Builder.CreateZExt(Arg, Int64Ty); 49 case 64: 50 return Arg; 51 } 52 } 53 54 if (Ty->getTypeID() == Type::DoubleTyID) { 55 return Builder.CreateBitCast(Arg, Int64Ty); 56 } 57 58 if (isa<PointerType>(Ty)) { 59 return Builder.CreatePtrToInt(Arg, Int64Ty); 60 } 61 62 llvm_unreachable("unexpected type"); 63 } 64 65 static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) { 66 auto Int64Ty = Builder.getInt64Ty(); 67 auto M = Builder.GetInsertBlock()->getModule(); 68 auto Fn = M->getOrInsertFunction("__ockl_printf_begin", Int64Ty, Int64Ty); 69 return Builder.CreateCall(Fn, Version); 70 } 71 72 static Value *callAppendArgs(IRBuilder<> &Builder, Value *Desc, int NumArgs, 73 Value *Arg0, Value *Arg1, Value *Arg2, Value *Arg3, 74 Value *Arg4, Value *Arg5, Value *Arg6, 75 bool IsLast) { 76 auto Int64Ty = Builder.getInt64Ty(); 77 auto Int32Ty = Builder.getInt32Ty(); 78 auto M = Builder.GetInsertBlock()->getModule(); 79 auto Fn = M->getOrInsertFunction("__ockl_printf_append_args", Int64Ty, 80 Int64Ty, Int32Ty, Int64Ty, Int64Ty, Int64Ty, 81 Int64Ty, Int64Ty, Int64Ty, Int64Ty, Int32Ty); 82 auto IsLastValue = Builder.getInt32(IsLast); 83 auto NumArgsValue = Builder.getInt32(NumArgs); 84 return Builder.CreateCall(Fn, {Desc, NumArgsValue, Arg0, Arg1, Arg2, Arg3, 85 Arg4, Arg5, Arg6, IsLastValue}); 86 } 87 88 static Value *appendArg(IRBuilder<> &Builder, Value *Desc, Value *Arg, 89 bool IsLast) { 90 auto Arg0 = fitArgInto64Bits(Builder, Arg); 91 auto Zero = Builder.getInt64(0); 92 return callAppendArgs(Builder, Desc, 1, Arg0, Zero, Zero, Zero, Zero, Zero, 93 Zero, IsLast); 94 } 95 96 // The device library does not provide strlen, so we build our own loop 97 // here. While we are at it, we also include the terminating null in the length. 98 static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) { 99 auto *Prev = Builder.GetInsertBlock(); 100 Module *M = Prev->getModule(); 101 102 auto CharZero = Builder.getInt8(0); 103 auto One = Builder.getInt64(1); 104 auto Zero = Builder.getInt64(0); 105 auto Int64Ty = Builder.getInt64Ty(); 106 107 // The length is either zero for a null pointer, or the computed value for an 108 // actual string. We need a join block for a phi that represents the final 109 // value. 110 // 111 // Strictly speaking, the zero does not matter since 112 // __ockl_printf_append_string_n ignores the length if the pointer is null. 113 BasicBlock *Join = nullptr; 114 if (Prev->getTerminator()) { 115 Join = Prev->splitBasicBlock(Builder.GetInsertPoint(), 116 "strlen.join"); 117 Prev->getTerminator()->eraseFromParent(); 118 } else { 119 Join = BasicBlock::Create(M->getContext(), "strlen.join", 120 Prev->getParent()); 121 } 122 BasicBlock *While = 123 BasicBlock::Create(M->getContext(), "strlen.while", 124 Prev->getParent(), Join); 125 BasicBlock *WhileDone = BasicBlock::Create( 126 M->getContext(), "strlen.while.done", 127 Prev->getParent(), Join); 128 129 // Emit an early return for when the pointer is null. 130 Builder.SetInsertPoint(Prev); 131 auto CmpNull = 132 Builder.CreateICmpEQ(Str, Constant::getNullValue(Str->getType())); 133 BranchInst::Create(Join, While, CmpNull, Prev); 134 135 // Entry to the while loop. 136 Builder.SetInsertPoint(While); 137 138 auto PtrPhi = Builder.CreatePHI(Str->getType(), 2); 139 PtrPhi->addIncoming(Str, Prev); 140 auto PtrNext = Builder.CreateGEP(PtrPhi, One); 141 PtrPhi->addIncoming(PtrNext, While); 142 143 // Condition for the while loop. 144 auto Data = Builder.CreateLoad(PtrPhi); 145 auto Cmp = Builder.CreateICmpEQ(Data, CharZero); 146 Builder.CreateCondBr(Cmp, WhileDone, While); 147 148 // Add one to the computed length. 149 Builder.SetInsertPoint(WhileDone, WhileDone->begin()); 150 auto Begin = Builder.CreatePtrToInt(Str, Int64Ty); 151 auto End = Builder.CreatePtrToInt(PtrPhi, Int64Ty); 152 auto Len = Builder.CreateSub(End, Begin); 153 Len = Builder.CreateAdd(Len, One); 154 155 // Final join. 156 BranchInst::Create(Join, WhileDone); 157 Builder.SetInsertPoint(Join, Join->begin()); 158 auto LenPhi = Builder.CreatePHI(Len->getType(), 2); 159 LenPhi->addIncoming(Len, WhileDone); 160 LenPhi->addIncoming(Zero, Prev); 161 162 return LenPhi; 163 } 164 165 static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str, 166 Value *Length, bool isLast) { 167 auto Int64Ty = Builder.getInt64Ty(); 168 auto CharPtrTy = Builder.getInt8PtrTy(); 169 auto Int32Ty = Builder.getInt32Ty(); 170 auto M = Builder.GetInsertBlock()->getModule(); 171 auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty, 172 Int64Ty, CharPtrTy, Int64Ty, Int32Ty); 173 auto IsLastInt32 = Builder.getInt32(isLast); 174 return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32}); 175 } 176 177 static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg, 178 bool IsLast) { 179 auto Length = getStrlenWithNull(Builder, Arg); 180 return callAppendStringN(Builder, Desc, Arg, Length, IsLast); 181 } 182 183 static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg, 184 bool SpecIsCString, bool IsLast) { 185 if (SpecIsCString && isCString(Arg)) { 186 return appendString(Builder, Desc, Arg, IsLast); 187 } 188 // If the format specifies a string but the argument is not, the frontend will 189 // have printed a warning. We just rely on undefined behaviour and send the 190 // argument anyway. 191 return appendArg(Builder, Desc, Arg, IsLast); 192 } 193 194 // Scan the format string to locate all specifiers, and mark the ones that 195 // specify a string, i.e, the "%s" specifier with optional '*' characters. 196 static void locateCStrings(SparseBitVector<8> &BV, Value *Fmt) { 197 StringRef Str; 198 if (!getConstantStringInfo(Fmt, Str) || Str.empty()) 199 return; 200 201 static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn"; 202 size_t SpecPos = 0; 203 // Skip the first argument, the format string. 204 unsigned ArgIdx = 1; 205 206 while ((SpecPos = Str.find_first_of('%', SpecPos)) != StringRef::npos) { 207 if (Str[SpecPos + 1] == '%') { 208 SpecPos += 2; 209 continue; 210 } 211 auto SpecEnd = Str.find_first_of(ConvSpecifiers, SpecPos); 212 if (SpecEnd == StringRef::npos) 213 return; 214 auto Spec = Str.slice(SpecPos, SpecEnd + 1); 215 ArgIdx += Spec.count('*'); 216 if (Str[SpecEnd] == 's') { 217 BV.set(ArgIdx); 218 } 219 SpecPos = SpecEnd + 1; 220 ++ArgIdx; 221 } 222 } 223 224 Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder, 225 ArrayRef<Value *> Args) { 226 auto NumOps = Args.size(); 227 assert(NumOps >= 1); 228 229 auto Fmt = Args[0]; 230 SparseBitVector<8> SpecIsCString; 231 locateCStrings(SpecIsCString, Fmt); 232 233 auto Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0)); 234 Desc = appendString(Builder, Desc, Fmt, NumOps == 1); 235 236 // FIXME: This invokes hostcall once for each argument. We can pack up to 237 // seven scalar printf arguments in a single hostcall. See the signature of 238 // callAppendArgs(). 239 for (unsigned int i = 1; i != NumOps; ++i) { 240 bool IsLast = i == NumOps - 1; 241 bool IsCString = SpecIsCString.test(i); 242 Desc = processArg(Builder, Desc, Args[i], IsCString, IsLast); 243 } 244 245 return Builder.CreateTrunc(Desc, Builder.getInt32Ty()); 246 } 247