xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp (revision 90ec6a30353aa7caaf995ea50e2e23aa5a099600)
1 //===- AMDGPUEmitPrintf.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility function to lower a printf call into a series of device
10 // library calls on the AMDGPU target.
11 //
12 // WARNING: This file knows about certain library functions. It recognizes them
13 // by name, and hardwires knowledge of their semantics.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"
18 #include "llvm/ADT/SparseBitVector.h"
19 #include "llvm/Analysis/ValueTracking.h"
20 #include "llvm/IR/IRBuilder.h"
21 
22 #include <iostream>
23 
24 using namespace llvm;
25 
26 #define DEBUG_TYPE "amdgpu-emit-printf"
27 
28 static bool isCString(const Value *Arg) {
29   auto Ty = Arg->getType();
30   auto PtrTy = dyn_cast<PointerType>(Ty);
31   if (!PtrTy)
32     return false;
33 
34   auto IntTy = dyn_cast<IntegerType>(PtrTy->getElementType());
35   if (!IntTy)
36     return false;
37 
38   return IntTy->getBitWidth() == 8;
39 }
40 
41 static Value *fitArgInto64Bits(IRBuilder<> &Builder, Value *Arg) {
42   auto Int64Ty = Builder.getInt64Ty();
43   auto Ty = Arg->getType();
44 
45   if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
46     switch (IntTy->getBitWidth()) {
47     case 32:
48       return Builder.CreateZExt(Arg, Int64Ty);
49     case 64:
50       return Arg;
51     }
52   }
53 
54   if (Ty->getTypeID() == Type::DoubleTyID) {
55     return Builder.CreateBitCast(Arg, Int64Ty);
56   }
57 
58   if (isa<PointerType>(Ty)) {
59     return Builder.CreatePtrToInt(Arg, Int64Ty);
60   }
61 
62   llvm_unreachable("unexpected type");
63 }
64 
65 static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) {
66   auto Int64Ty = Builder.getInt64Ty();
67   auto M = Builder.GetInsertBlock()->getModule();
68   auto Fn = M->getOrInsertFunction("__ockl_printf_begin", Int64Ty, Int64Ty);
69   return Builder.CreateCall(Fn, Version);
70 }
71 
72 static Value *callAppendArgs(IRBuilder<> &Builder, Value *Desc, int NumArgs,
73                              Value *Arg0, Value *Arg1, Value *Arg2, Value *Arg3,
74                              Value *Arg4, Value *Arg5, Value *Arg6,
75                              bool IsLast) {
76   auto Int64Ty = Builder.getInt64Ty();
77   auto Int32Ty = Builder.getInt32Ty();
78   auto M = Builder.GetInsertBlock()->getModule();
79   auto Fn = M->getOrInsertFunction("__ockl_printf_append_args", Int64Ty,
80                                    Int64Ty, Int32Ty, Int64Ty, Int64Ty, Int64Ty,
81                                    Int64Ty, Int64Ty, Int64Ty, Int64Ty, Int32Ty);
82   auto IsLastValue = Builder.getInt32(IsLast);
83   auto NumArgsValue = Builder.getInt32(NumArgs);
84   return Builder.CreateCall(Fn, {Desc, NumArgsValue, Arg0, Arg1, Arg2, Arg3,
85                                  Arg4, Arg5, Arg6, IsLastValue});
86 }
87 
88 static Value *appendArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
89                         bool IsLast) {
90   auto Arg0 = fitArgInto64Bits(Builder, Arg);
91   auto Zero = Builder.getInt64(0);
92   return callAppendArgs(Builder, Desc, 1, Arg0, Zero, Zero, Zero, Zero, Zero,
93                         Zero, IsLast);
94 }
95 
96 // The device library does not provide strlen, so we build our own loop
97 // here. While we are at it, we also include the terminating null in the length.
98 static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) {
99   auto *Prev = Builder.GetInsertBlock();
100   Module *M = Prev->getModule();
101 
102   auto CharZero = Builder.getInt8(0);
103   auto One = Builder.getInt64(1);
104   auto Zero = Builder.getInt64(0);
105   auto Int64Ty = Builder.getInt64Ty();
106 
107   // The length is either zero for a null pointer, or the computed value for an
108   // actual string. We need a join block for a phi that represents the final
109   // value.
110   //
111   //  Strictly speaking, the zero does not matter since
112   // __ockl_printf_append_string_n ignores the length if the pointer is null.
113   BasicBlock *Join = nullptr;
114   if (Prev->getTerminator()) {
115     Join = Prev->splitBasicBlock(Builder.GetInsertPoint(),
116                                  "strlen.join");
117     Prev->getTerminator()->eraseFromParent();
118   } else {
119     Join = BasicBlock::Create(M->getContext(), "strlen.join",
120                               Prev->getParent());
121   }
122   BasicBlock *While =
123       BasicBlock::Create(M->getContext(), "strlen.while",
124                          Prev->getParent(), Join);
125   BasicBlock *WhileDone = BasicBlock::Create(
126       M->getContext(), "strlen.while.done",
127       Prev->getParent(), Join);
128 
129   // Emit an early return for when the pointer is null.
130   Builder.SetInsertPoint(Prev);
131   auto CmpNull =
132       Builder.CreateICmpEQ(Str, Constant::getNullValue(Str->getType()));
133   BranchInst::Create(Join, While, CmpNull, Prev);
134 
135   // Entry to the while loop.
136   Builder.SetInsertPoint(While);
137 
138   auto PtrPhi = Builder.CreatePHI(Str->getType(), 2);
139   PtrPhi->addIncoming(Str, Prev);
140   auto PtrNext = Builder.CreateGEP(PtrPhi, One);
141   PtrPhi->addIncoming(PtrNext, While);
142 
143   // Condition for the while loop.
144   auto Data = Builder.CreateLoad(PtrPhi);
145   auto Cmp = Builder.CreateICmpEQ(Data, CharZero);
146   Builder.CreateCondBr(Cmp, WhileDone, While);
147 
148   // Add one to the computed length.
149   Builder.SetInsertPoint(WhileDone, WhileDone->begin());
150   auto Begin = Builder.CreatePtrToInt(Str, Int64Ty);
151   auto End = Builder.CreatePtrToInt(PtrPhi, Int64Ty);
152   auto Len = Builder.CreateSub(End, Begin);
153   Len = Builder.CreateAdd(Len, One);
154 
155   // Final join.
156   BranchInst::Create(Join, WhileDone);
157   Builder.SetInsertPoint(Join, Join->begin());
158   auto LenPhi = Builder.CreatePHI(Len->getType(), 2);
159   LenPhi->addIncoming(Len, WhileDone);
160   LenPhi->addIncoming(Zero, Prev);
161 
162   return LenPhi;
163 }
164 
165 static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str,
166                                 Value *Length, bool isLast) {
167   auto Int64Ty = Builder.getInt64Ty();
168   auto CharPtrTy = Builder.getInt8PtrTy();
169   auto Int32Ty = Builder.getInt32Ty();
170   auto M = Builder.GetInsertBlock()->getModule();
171   auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty,
172                                    Int64Ty, CharPtrTy, Int64Ty, Int32Ty);
173   auto IsLastInt32 = Builder.getInt32(isLast);
174   return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32});
175 }
176 
177 static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg,
178                            bool IsLast) {
179   auto Length = getStrlenWithNull(Builder, Arg);
180   return callAppendStringN(Builder, Desc, Arg, Length, IsLast);
181 }
182 
183 static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
184                          bool SpecIsCString, bool IsLast) {
185   if (SpecIsCString && isCString(Arg)) {
186     return appendString(Builder, Desc, Arg, IsLast);
187   }
188   // If the format specifies a string but the argument is not, the frontend will
189   // have printed a warning. We just rely on undefined behaviour and send the
190   // argument anyway.
191   return appendArg(Builder, Desc, Arg, IsLast);
192 }
193 
194 // Scan the format string to locate all specifiers, and mark the ones that
195 // specify a string, i.e, the "%s" specifier with optional '*' characters.
196 static void locateCStrings(SparseBitVector<8> &BV, Value *Fmt) {
197   StringRef Str;
198   if (!getConstantStringInfo(Fmt, Str) || Str.empty())
199     return;
200 
201   static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn";
202   size_t SpecPos = 0;
203   // Skip the first argument, the format string.
204   unsigned ArgIdx = 1;
205 
206   while ((SpecPos = Str.find_first_of('%', SpecPos)) != StringRef::npos) {
207     if (Str[SpecPos + 1] == '%') {
208       SpecPos += 2;
209       continue;
210     }
211     auto SpecEnd = Str.find_first_of(ConvSpecifiers, SpecPos);
212     if (SpecEnd == StringRef::npos)
213       return;
214     auto Spec = Str.slice(SpecPos, SpecEnd + 1);
215     ArgIdx += Spec.count('*');
216     if (Str[SpecEnd] == 's') {
217       BV.set(ArgIdx);
218     }
219     SpecPos = SpecEnd + 1;
220     ++ArgIdx;
221   }
222 }
223 
224 Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder,
225                                   ArrayRef<Value *> Args) {
226   auto NumOps = Args.size();
227   assert(NumOps >= 1);
228 
229   auto Fmt = Args[0];
230   SparseBitVector<8> SpecIsCString;
231   locateCStrings(SpecIsCString, Fmt);
232 
233   auto Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0));
234   Desc = appendString(Builder, Desc, Fmt, NumOps == 1);
235 
236   // FIXME: This invokes hostcall once for each argument. We can pack up to
237   // seven scalar printf arguments in a single hostcall. See the signature of
238   // callAppendArgs().
239   for (unsigned int i = 1; i != NumOps; ++i) {
240     bool IsLast = i == NumOps - 1;
241     bool IsCString = SpecIsCString.test(i);
242     Desc = processArg(Builder, Desc, Args[i], IsCString, IsLast);
243   }
244 
245   return Builder.CreateTrunc(Desc, Builder.getInt32Ty());
246 }
247