xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===- AMDGPUEmitPrintf.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility function to lower a printf call into a series of device
10 // library calls on the AMDGPU target.
11 //
12 // WARNING: This file knows about certain library functions. It recognizes them
13 // by name, and hardwires knowledge of their semantics.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"
18 #include "llvm/ADT/SparseBitVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Analysis/ValueTracking.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/Support/DataExtractor.h"
23 #include "llvm/Support/MD5.h"
24 #include "llvm/Support/MathExtras.h"
25 
26 using namespace llvm;
27 
28 #define DEBUG_TYPE "amdgpu-emit-printf"
29 
30 static Value *fitArgInto64Bits(IRBuilder<> &Builder, Value *Arg) {
31   auto Int64Ty = Builder.getInt64Ty();
32   auto Ty = Arg->getType();
33 
34   if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
35     switch (IntTy->getBitWidth()) {
36     case 32:
37       return Builder.CreateZExt(Arg, Int64Ty);
38     case 64:
39       return Arg;
40     }
41   }
42 
43   if (Ty->getTypeID() == Type::DoubleTyID) {
44     return Builder.CreateBitCast(Arg, Int64Ty);
45   }
46 
47   if (isa<PointerType>(Ty)) {
48     return Builder.CreatePtrToInt(Arg, Int64Ty);
49   }
50 
51   llvm_unreachable("unexpected type");
52 }
53 
54 static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) {
55   auto Int64Ty = Builder.getInt64Ty();
56   auto M = Builder.GetInsertBlock()->getModule();
57   auto Fn = M->getOrInsertFunction("__ockl_printf_begin", Int64Ty, Int64Ty);
58   return Builder.CreateCall(Fn, Version);
59 }
60 
61 static Value *callAppendArgs(IRBuilder<> &Builder, Value *Desc, int NumArgs,
62                              Value *Arg0, Value *Arg1, Value *Arg2, Value *Arg3,
63                              Value *Arg4, Value *Arg5, Value *Arg6,
64                              bool IsLast) {
65   auto Int64Ty = Builder.getInt64Ty();
66   auto Int32Ty = Builder.getInt32Ty();
67   auto M = Builder.GetInsertBlock()->getModule();
68   auto Fn = M->getOrInsertFunction("__ockl_printf_append_args", Int64Ty,
69                                    Int64Ty, Int32Ty, Int64Ty, Int64Ty, Int64Ty,
70                                    Int64Ty, Int64Ty, Int64Ty, Int64Ty, Int32Ty);
71   auto IsLastValue = Builder.getInt32(IsLast);
72   auto NumArgsValue = Builder.getInt32(NumArgs);
73   return Builder.CreateCall(Fn, {Desc, NumArgsValue, Arg0, Arg1, Arg2, Arg3,
74                                  Arg4, Arg5, Arg6, IsLastValue});
75 }
76 
77 static Value *appendArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
78                         bool IsLast) {
79   auto Arg0 = fitArgInto64Bits(Builder, Arg);
80   auto Zero = Builder.getInt64(0);
81   return callAppendArgs(Builder, Desc, 1, Arg0, Zero, Zero, Zero, Zero, Zero,
82                         Zero, IsLast);
83 }
84 
// The device library does not provide strlen, so we build our own loop
// here. While we are at it, we also include the terminating null in the length.
//
// Emits, at Builder's insertion point, IR computing the i64 length of Str
// including its NUL terminator; a null Str yields 0. Control flow created:
//
//   Prev:              Str == null ? br strlen.join : br strlen.while
//   strlen.while:      advance a byte pointer until it points at a 0 byte
//   strlen.while.done: length = (end pointer - Str) + 1
//   strlen.join:       phi [length, strlen.while.done], [0, Prev]
//
// If the current block already has a terminator, it is split at the insertion
// point and the tail becomes strlen.join. Otherwise strlen.join is a new block
// appended to the function. On return, Builder is positioned in strlen.join
// immediately after the length phi, which is the returned value.
static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) {
  auto *Prev = Builder.GetInsertBlock();
  Module *M = Prev->getModule();

  auto CharZero = Builder.getInt8(0);
  auto One = Builder.getInt64(1);
  auto Zero = Builder.getInt64(0);
  auto Int64Ty = Builder.getInt64Ty();

  // The length is either zero for a null pointer, or the computed value for an
  // actual string. We need a join block for a phi that represents the final
  // value.
  //
  // Strictly speaking, the zero does not matter since
  // __ockl_printf_append_string_n ignores the length if the pointer is null.
  // In the buffered path the value is also used as a copy/frame size, so there
  // a null string simply occupies no space.
  BasicBlock *Join = nullptr;
  if (Prev->getTerminator()) {
    Join = Prev->splitBasicBlock(Builder.GetInsertPoint(),
                                 "strlen.join");
    // splitBasicBlock ends Prev with an unconditional branch to Join; drop it
    // so the conditional null check below can terminate Prev instead.
    Prev->getTerminator()->eraseFromParent();
  } else {
    Join = BasicBlock::Create(M->getContext(), "strlen.join",
                              Prev->getParent());
  }
  // Both loop blocks are placed before Join in the function's block list.
  BasicBlock *While =
      BasicBlock::Create(M->getContext(), "strlen.while",
                         Prev->getParent(), Join);
  BasicBlock *WhileDone = BasicBlock::Create(
      M->getContext(), "strlen.while.done",
      Prev->getParent(), Join);

  // Branch straight to the join block when the pointer is null; otherwise
  // enter the loop.
  Builder.SetInsertPoint(Prev);
  auto CmpNull =
      Builder.CreateICmpEQ(Str, Constant::getNullValue(Str->getType()));
  BranchInst::Create(Join, While, CmpNull, Prev);

  // Entry to the while loop.
  Builder.SetInsertPoint(While);

  // PtrPhi is the current byte position: Str on entry, PtrPhi + 1 on the
  // back edge.
  auto PtrPhi = Builder.CreatePHI(Str->getType(), 2);
  PtrPhi->addIncoming(Str, Prev);
  auto PtrNext = Builder.CreateGEP(Builder.getInt8Ty(), PtrPhi, One);
  PtrPhi->addIncoming(PtrNext, While);

  // Condition for the while loop: exit once the byte at PtrPhi is zero.
  auto Data = Builder.CreateLoad(Builder.getInt8Ty(), PtrPhi);
  auto Cmp = Builder.CreateICmpEQ(Data, CharZero);
  Builder.CreateCondBr(Cmp, WhileDone, While);

  // PtrPhi now points at the terminator; the distance from Str plus one is
  // the length including the terminator.
  Builder.SetInsertPoint(WhileDone, WhileDone->begin());
  auto Begin = Builder.CreatePtrToInt(Str, Int64Ty);
  auto End = Builder.CreatePtrToInt(PtrPhi, Int64Ty);
  auto Len = Builder.CreateSub(End, Begin);
  Len = Builder.CreateAdd(Len, One);

  // Final join: merge the computed length with 0 from the null-pointer edge.
  BranchInst::Create(Join, WhileDone);
  Builder.SetInsertPoint(Join, Join->begin());
  auto LenPhi = Builder.CreatePHI(Len->getType(), 2);
  LenPhi->addIncoming(Len, WhileDone);
  LenPhi->addIncoming(Zero, Prev);

  return LenPhi;
}
153 
154 static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str,
155                                 Value *Length, bool isLast) {
156   auto Int64Ty = Builder.getInt64Ty();
157   auto IsLastInt32 = Builder.getInt32(isLast);
158   auto M = Builder.GetInsertBlock()->getModule();
159   auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty,
160                                    Desc->getType(), Str->getType(),
161                                    Length->getType(), IsLastInt32->getType());
162   return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32});
163 }
164 
165 static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg,
166                            bool IsLast) {
167   auto Length = getStrlenWithNull(Builder, Arg);
168   return callAppendStringN(Builder, Desc, Arg, Length, IsLast);
169 }
170 
171 static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,
172                          bool SpecIsCString, bool IsLast) {
173   if (SpecIsCString && isa<PointerType>(Arg->getType())) {
174     return appendString(Builder, Desc, Arg, IsLast);
175   }
176   // If the format specifies a string but the argument is not, the frontend will
177   // have printed a warning. We just rely on undefined behaviour and send the
178   // argument anyway.
179   return appendArg(Builder, Desc, Arg, IsLast);
180 }
181 
182 // Scan the format string to locate all specifiers, and mark the ones that
183 // specify a string, i.e, the "%s" specifier with optional '*' characters.
184 static void locateCStrings(SparseBitVector<8> &BV, StringRef Str) {
185   static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn";
186   size_t SpecPos = 0;
187   // Skip the first argument, the format string.
188   unsigned ArgIdx = 1;
189 
190   while ((SpecPos = Str.find_first_of('%', SpecPos)) != StringRef::npos) {
191     if (Str[SpecPos + 1] == '%') {
192       SpecPos += 2;
193       continue;
194     }
195     auto SpecEnd = Str.find_first_of(ConvSpecifiers, SpecPos);
196     if (SpecEnd == StringRef::npos)
197       return;
198     auto Spec = Str.slice(SpecPos, SpecEnd + 1);
199     ArgIdx += Spec.count('*');
200     if (Str[SpecEnd] == 's') {
201       BV.set(ArgIdx);
202     }
203     SpecPos = SpecEnd + 1;
204     ++ArgIdx;
205   }
206 }
207 
208 // helper struct to package the string related data
209 struct StringData {
210   StringRef Str;
211   Value *RealSize = nullptr;
212   Value *AlignedSize = nullptr;
213   bool IsConst = true;
214 
215   StringData(StringRef ST, Value *RS, Value *AS, bool IC)
216       : Str(ST), RealSize(RS), AlignedSize(AS), IsConst(IC) {}
217 };
218 
219 // Calculates frame size required for current printf expansion and allocates
220 // space on printf buffer. Printf frame includes following contents
221 // [ ControlDWord , format string/Hash , Arguments (each aligned to 8 byte) ]
222 static Value *callBufferedPrintfStart(
223     IRBuilder<> &Builder, ArrayRef<Value *> Args, Value *Fmt,
224     bool isConstFmtStr, SparseBitVector<8> &SpecIsCString,
225     SmallVectorImpl<StringData> &StringContents, Value *&ArgSize) {
226   Module *M = Builder.GetInsertBlock()->getModule();
227   Value *NonConstStrLen = nullptr;
228   Value *LenWithNull = nullptr;
229   Value *LenWithNullAligned = nullptr;
230   Value *TempAdd = nullptr;
231 
232   // First 4 bytes to be reserved for control dword
233   size_t BufSize = 4;
234   if (isConstFmtStr)
235     // First 8 bytes of MD5 hash
236     BufSize += 8;
237   else {
238     LenWithNull = getStrlenWithNull(Builder, Fmt);
239 
240     // Align the computed length to next 8 byte boundary
241     TempAdd = Builder.CreateAdd(LenWithNull,
242                                 ConstantInt::get(LenWithNull->getType(), 7U));
243     NonConstStrLen = Builder.CreateAnd(
244         TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U));
245 
246     StringContents.push_back(
247         StringData(StringRef(), LenWithNull, NonConstStrLen, false));
248   }
249 
250   for (size_t i = 1; i < Args.size(); i++) {
251     if (SpecIsCString.test(i)) {
252       StringRef ArgStr;
253       if (getConstantStringInfo(Args[i], ArgStr)) {
254         auto alignedLen = alignTo(ArgStr.size() + 1, 8);
255         StringContents.push_back(StringData(
256             ArgStr,
257             /*RealSize*/ nullptr, /*AlignedSize*/ nullptr, /*IsConst*/ true));
258         BufSize += alignedLen;
259       } else {
260         LenWithNull = getStrlenWithNull(Builder, Args[i]);
261 
262         // Align the computed length to next 8 byte boundary
263         TempAdd = Builder.CreateAdd(
264             LenWithNull, ConstantInt::get(LenWithNull->getType(), 7U));
265         LenWithNullAligned = Builder.CreateAnd(
266             TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U));
267 
268         if (NonConstStrLen) {
269           auto Val = Builder.CreateAdd(LenWithNullAligned, NonConstStrLen,
270                                        "cumulativeAdd");
271           NonConstStrLen = Val;
272         } else
273           NonConstStrLen = LenWithNullAligned;
274 
275         StringContents.push_back(
276             StringData(StringRef(), LenWithNull, LenWithNullAligned, false));
277       }
278     } else {
279       int AllocSize = M->getDataLayout().getTypeAllocSize(Args[i]->getType());
280       // We end up expanding non string arguments to 8 bytes
281       // (args smaller than 8 bytes)
282       BufSize += std::max(AllocSize, 8);
283     }
284   }
285 
286   // calculate final size value to be passed to printf_alloc
287   Value *SizeToReserve = ConstantInt::get(Builder.getInt64Ty(), BufSize, false);
288   SmallVector<Value *, 1> Alloc_args;
289   if (NonConstStrLen)
290     SizeToReserve = Builder.CreateAdd(NonConstStrLen, SizeToReserve);
291 
292   ArgSize = Builder.CreateTrunc(SizeToReserve, Builder.getInt32Ty());
293   Alloc_args.push_back(ArgSize);
294 
295   // call the printf_alloc function
296   AttributeList Attr = AttributeList::get(
297       Builder.getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind);
298 
299   Type *Tys_alloc[1] = {Builder.getInt32Ty()};
300   Type *PtrTy =
301       Builder.getPtrTy(M->getDataLayout().getDefaultGlobalsAddressSpace());
302   FunctionType *FTy_alloc = FunctionType::get(PtrTy, Tys_alloc, false);
303   auto PrintfAllocFn =
304       M->getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc, Attr);
305 
306   return Builder.CreateCall(PrintfAllocFn, Alloc_args, "printf_alloc_fn");
307 }
308 
309 // Prepare constant string argument to push onto the buffer
310 static void processConstantStringArg(StringData *SD, IRBuilder<> &Builder,
311                                      SmallVectorImpl<Value *> &WhatToStore) {
312   std::string Str(SD->Str.str() + '\0');
313 
314   DataExtractor Extractor(Str, /*IsLittleEndian=*/true, 8);
315   DataExtractor::Cursor Offset(0);
316   while (Offset && Offset.tell() < Str.size()) {
317     const uint64_t ReadSize = 4;
318     uint64_t ReadNow = std::min(ReadSize, Str.size() - Offset.tell());
319     uint64_t ReadBytes = 0;
320     switch (ReadNow) {
321     default:
322       llvm_unreachable("min(4, X) > 4?");
323     case 1:
324       ReadBytes = Extractor.getU8(Offset);
325       break;
326     case 2:
327       ReadBytes = Extractor.getU16(Offset);
328       break;
329     case 3:
330       ReadBytes = Extractor.getU24(Offset);
331       break;
332     case 4:
333       ReadBytes = Extractor.getU32(Offset);
334       break;
335     }
336     cantFail(Offset.takeError(), "failed to read bytes from constant array");
337 
338     APInt IntVal(8 * ReadSize, ReadBytes);
339 
340     // TODO: Should not bother aligning up.
341     if (ReadNow < ReadSize)
342       IntVal = IntVal.zext(8 * ReadSize);
343 
344     Type *IntTy = Type::getIntNTy(Builder.getContext(), IntVal.getBitWidth());
345     WhatToStore.push_back(ConstantInt::get(IntTy, IntVal));
346   }
347   // Additional padding for 8 byte alignment
348   int Rem = (Str.size() % 8);
349   if (Rem > 0 && Rem <= 4)
350     WhatToStore.push_back(ConstantInt::get(Builder.getInt32Ty(), 0));
351 }
352 
353 static Value *processNonStringArg(Value *Arg, IRBuilder<> &Builder) {
354   const DataLayout &DL = Builder.GetInsertBlock()->getDataLayout();
355   auto Ty = Arg->getType();
356 
357   if (auto IntTy = dyn_cast<IntegerType>(Ty)) {
358     if (IntTy->getBitWidth() < 64) {
359       return Builder.CreateZExt(Arg, Builder.getInt64Ty());
360     }
361   }
362 
363   if (Ty->isFloatingPointTy()) {
364     if (DL.getTypeAllocSize(Ty) < 8) {
365       return Builder.CreateFPExt(Arg, Builder.getDoubleTy());
366     }
367   }
368 
369   return Arg;
370 }
371 
372 static void
373 callBufferedPrintfArgPush(IRBuilder<> &Builder, ArrayRef<Value *> Args,
374                           Value *PtrToStore, SparseBitVector<8> &SpecIsCString,
375                           SmallVectorImpl<StringData> &StringContents,
376                           bool IsConstFmtStr) {
377   Module *M = Builder.GetInsertBlock()->getModule();
378   const DataLayout &DL = M->getDataLayout();
379   auto StrIt = StringContents.begin();
380   size_t i = IsConstFmtStr ? 1 : 0;
381   for (; i < Args.size(); i++) {
382     SmallVector<Value *, 32> WhatToStore;
383     if ((i == 0) || SpecIsCString.test(i)) {
384       if (StrIt->IsConst) {
385         processConstantStringArg(StrIt, Builder, WhatToStore);
386         StrIt++;
387       } else {
388         // This copies the contents of the string, however the next offset
389         // is at aligned length, the extra space that might be created due
390         // to alignment padding is not populated with any specific value
391         // here. This would be safe as long as runtime is sync with
392         // the offsets.
393         Builder.CreateMemCpy(PtrToStore, /*DstAlign*/ Align(1), Args[i],
394                              /*SrcAlign*/ Args[i]->getPointerAlignment(DL),
395                              StrIt->RealSize);
396 
397         PtrToStore =
398             Builder.CreateInBoundsGEP(Builder.getInt8Ty(), PtrToStore,
399                                       {StrIt->AlignedSize}, "PrintBuffNextPtr");
400         LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:"
401                           << *PtrToStore << '\n');
402 
403         // done with current argument, move to next
404         StrIt++;
405         continue;
406       }
407     } else {
408       WhatToStore.push_back(processNonStringArg(Args[i], Builder));
409     }
410 
411     for (Value *toStore : WhatToStore) {
412       StoreInst *StBuff = Builder.CreateStore(toStore, PtrToStore);
413       LLVM_DEBUG(dbgs() << "inserting store to printf buffer:" << *StBuff
414                         << '\n');
415       (void)StBuff;
416       PtrToStore = Builder.CreateConstInBoundsGEP1_32(
417           Builder.getInt8Ty(), PtrToStore,
418           M->getDataLayout().getTypeAllocSize(toStore->getType()),
419           "PrintBuffNextPtr");
420       LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:" << *PtrToStore
421                         << '\n');
422     }
423   }
424 }
425 
426 Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder, ArrayRef<Value *> Args,
427                                   bool IsBuffered) {
428   auto NumOps = Args.size();
429   assert(NumOps >= 1);
430 
431   auto Fmt = Args[0];
432   SparseBitVector<8> SpecIsCString;
433   StringRef FmtStr;
434 
435   if (getConstantStringInfo(Fmt, FmtStr))
436     locateCStrings(SpecIsCString, FmtStr);
437 
438   if (IsBuffered) {
439     SmallVector<StringData, 8> StringContents;
440     Module *M = Builder.GetInsertBlock()->getModule();
441     LLVMContext &Ctx = Builder.getContext();
442     auto Int8Ty = Builder.getInt8Ty();
443     auto Int32Ty = Builder.getInt32Ty();
444     bool IsConstFmtStr = !FmtStr.empty();
445 
446     Value *ArgSize = nullptr;
447     Value *Ptr =
448         callBufferedPrintfStart(Builder, Args, Fmt, IsConstFmtStr,
449                                 SpecIsCString, StringContents, ArgSize);
450 
451     // The buffered version still follows OpenCL printf standards for
452     // printf return value, i.e 0 on success, -1 on failure.
453     ConstantPointerNull *zeroIntPtr =
454         ConstantPointerNull::get(cast<PointerType>(Ptr->getType()));
455 
456     auto *Cmp = cast<ICmpInst>(Builder.CreateICmpNE(Ptr, zeroIntPtr, ""));
457 
458     BasicBlock *End = BasicBlock::Create(Ctx, "end.block",
459                                          Builder.GetInsertBlock()->getParent());
460     BasicBlock *ArgPush = BasicBlock::Create(
461         Ctx, "argpush.block", Builder.GetInsertBlock()->getParent());
462 
463     BranchInst::Create(ArgPush, End, Cmp, Builder.GetInsertBlock());
464     Builder.SetInsertPoint(ArgPush);
465 
466     // Create controlDWord and store as the first entry, format as follows
467     // Bit 0 (LSB) -> stream (1 if stderr, 0 if stdout, printf always outputs to
468     // stdout) Bit 1 -> constant format string (1 if constant) Bits 2-31 -> size
469     // of printf data frame
470     auto ConstantTwo = Builder.getInt32(2);
471     auto ControlDWord = Builder.CreateShl(ArgSize, ConstantTwo);
472     if (IsConstFmtStr)
473       ControlDWord = Builder.CreateOr(ControlDWord, ConstantTwo);
474 
475     Builder.CreateStore(ControlDWord, Ptr);
476 
477     Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 4);
478 
479     // Create MD5 hash for costant format string, push low 64 bits of the
480     // same onto buffer and metadata.
481     NamedMDNode *metaD = M->getOrInsertNamedMetadata("llvm.printf.fmts");
482     if (IsConstFmtStr) {
483       MD5 Hasher;
484       MD5::MD5Result Hash;
485       Hasher.update(FmtStr);
486       Hasher.final(Hash);
487 
488       // Try sticking to llvm.printf.fmts format, although we are not going to
489       // use the ID and argument size fields while printing,
490       std::string MetadataStr =
491           "0:0:" + llvm::utohexstr(Hash.low(), /*LowerCase=*/true) + "," +
492           FmtStr.str();
493       MDString *fmtStrArray = MDString::get(Ctx, MetadataStr);
494       MDNode *myMD = MDNode::get(Ctx, fmtStrArray);
495       metaD->addOperand(myMD);
496 
497       Builder.CreateStore(Builder.getInt64(Hash.low()), Ptr);
498       Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 8);
499     } else {
500       // Include a dummy metadata instance in case of only non constant
501       // format string usage, This might be an absurd usecase but needs to
502       // be done for completeness
503       if (metaD->getNumOperands() == 0) {
504         MDString *fmtStrArray =
505             MDString::get(Ctx, "0:0:ffffffff,\"Non const format string\"");
506         MDNode *myMD = MDNode::get(Ctx, fmtStrArray);
507         metaD->addOperand(myMD);
508       }
509     }
510 
511     // Push The printf arguments onto buffer
512     callBufferedPrintfArgPush(Builder, Args, Ptr, SpecIsCString, StringContents,
513                               IsConstFmtStr);
514 
515     // End block, returns -1 on failure
516     BranchInst::Create(End, ArgPush);
517     Builder.SetInsertPoint(End);
518     return Builder.CreateSExt(Builder.CreateNot(Cmp), Int32Ty, "printf_result");
519   }
520 
521   auto Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0));
522   Desc = appendString(Builder, Desc, Fmt, NumOps == 1);
523 
524   // FIXME: This invokes hostcall once for each argument. We can pack up to
525   // seven scalar printf arguments in a single hostcall. See the signature of
526   // callAppendArgs().
527   for (unsigned int i = 1; i != NumOps; ++i) {
528     bool IsLast = i == NumOps - 1;
529     bool IsCString = SpecIsCString.test(i);
530     Desc = processArg(Builder, Desc, Args[i], IsCString, IsLast);
531   }
532 
533   return Builder.CreateTrunc(Desc, Builder.getInt32Ty());
534 }
535