xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVUtils.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains miscellaneous utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVUtils.h"
14 #include "MCTargetDesc/SPIRVBaseInfo.h"
15 #include "SPIRV.h"
16 #include "SPIRVInstrInfo.h"
17 #include "SPIRVSubtarget.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
20 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
21 #include "llvm/CodeGen/MachineInstr.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 #include "llvm/Demangle/Demangle.h"
24 #include "llvm/IR/IntrinsicsSPIRV.h"
25 
26 namespace llvm {
27 
28 // The following functions are used to add these string literals as a series of
29 // 32-bit integer operands with the correct format, and unpack them if necessary
30 // when making string comparisons in compiler passes.
31 // SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
convertCharsToWord(const StringRef & Str,unsigned i)32 static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
33   uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
34   for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
35     unsigned StrIndex = i + WordIndex;
36     uint8_t CharToAdd = 0;       // Initilize char as padding/null.
37     if (StrIndex < Str.size()) { // If it's within the string, get a real char.
38       CharToAdd = Str[StrIndex];
39     }
40     Word |= (CharToAdd << (WordIndex * 8));
41   }
42   return Word;
43 }
44 
45 // Get length including padding and null terminator.
getPaddedLen(const StringRef & Str)46 static size_t getPaddedLen(const StringRef &Str) {
47   const size_t Len = Str.size() + 1;
48   return (Len % 4 == 0) ? Len : Len + (4 - (Len % 4));
49 }
50 
addStringImm(const StringRef & Str,MCInst & Inst)51 void addStringImm(const StringRef &Str, MCInst &Inst) {
52   const size_t PaddedLen = getPaddedLen(Str);
53   for (unsigned i = 0; i < PaddedLen; i += 4) {
54     // Add an operand for the 32-bits of chars or padding.
55     Inst.addOperand(MCOperand::createImm(convertCharsToWord(Str, i)));
56   }
57 }
58 
addStringImm(const StringRef & Str,MachineInstrBuilder & MIB)59 void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
60   const size_t PaddedLen = getPaddedLen(Str);
61   for (unsigned i = 0; i < PaddedLen; i += 4) {
62     // Add an operand for the 32-bits of chars or padding.
63     MIB.addImm(convertCharsToWord(Str, i));
64   }
65 }
66 
addStringImm(const StringRef & Str,IRBuilder<> & B,std::vector<Value * > & Args)67 void addStringImm(const StringRef &Str, IRBuilder<> &B,
68                   std::vector<Value *> &Args) {
69   const size_t PaddedLen = getPaddedLen(Str);
70   for (unsigned i = 0; i < PaddedLen; i += 4) {
71     // Add a vector element for the 32-bits of chars or padding.
72     Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
73   }
74 }
75 
getStringImm(const MachineInstr & MI,unsigned StartIndex)76 std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) {
77   return getSPIRVStringOperand(MI, StartIndex);
78 }
79 
addNumImm(const APInt & Imm,MachineInstrBuilder & MIB)80 void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
81   const auto Bitwidth = Imm.getBitWidth();
82   if (Bitwidth == 1)
83     return; // Already handled
84   else if (Bitwidth <= 32) {
85     MIB.addImm(Imm.getZExtValue());
86     // Asm Printer needs this info to print floating-type correctly
87     if (Bitwidth == 16)
88       MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
89     return;
90   } else if (Bitwidth <= 64) {
91     uint64_t FullImm = Imm.getZExtValue();
92     uint32_t LowBits = FullImm & 0xffffffff;
93     uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
94     MIB.addImm(LowBits).addImm(HighBits);
95     return;
96   }
97   report_fatal_error("Unsupported constant bitwidth");
98 }
99 
buildOpName(Register Target,const StringRef & Name,MachineIRBuilder & MIRBuilder)100 void buildOpName(Register Target, const StringRef &Name,
101                  MachineIRBuilder &MIRBuilder) {
102   if (!Name.empty()) {
103     auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
104     addStringImm(Name, MIB);
105   }
106 }
107 
finishBuildOpDecorate(MachineInstrBuilder & MIB,const std::vector<uint32_t> & DecArgs,StringRef StrImm)108 static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
109                                   const std::vector<uint32_t> &DecArgs,
110                                   StringRef StrImm) {
111   if (!StrImm.empty())
112     addStringImm(StrImm, MIB);
113   for (const auto &DecArg : DecArgs)
114     MIB.addImm(DecArg);
115 }
116 
buildOpDecorate(Register Reg,MachineIRBuilder & MIRBuilder,SPIRV::Decoration::Decoration Dec,const std::vector<uint32_t> & DecArgs,StringRef StrImm)117 void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
118                      SPIRV::Decoration::Decoration Dec,
119                      const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
120   auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
121                  .addUse(Reg)
122                  .addImm(static_cast<uint32_t>(Dec));
123   finishBuildOpDecorate(MIB, DecArgs, StrImm);
124 }
125 
buildOpDecorate(Register Reg,MachineInstr & I,const SPIRVInstrInfo & TII,SPIRV::Decoration::Decoration Dec,const std::vector<uint32_t> & DecArgs,StringRef StrImm)126 void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
127                      SPIRV::Decoration::Decoration Dec,
128                      const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
129   MachineBasicBlock &MBB = *I.getParent();
130   auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate))
131                  .addUse(Reg)
132                  .addImm(static_cast<uint32_t>(Dec));
133   finishBuildOpDecorate(MIB, DecArgs, StrImm);
134 }
135 
buildOpSpirvDecorations(Register Reg,MachineIRBuilder & MIRBuilder,const MDNode * GVarMD)136 void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
137                              const MDNode *GVarMD) {
138   for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) {
139     auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I));
140     if (!OpMD)
141       report_fatal_error("Invalid decoration");
142     if (OpMD->getNumOperands() == 0)
143       report_fatal_error("Expect operand(s) of the decoration");
144     ConstantInt *DecorationId =
145         mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(0));
146     if (!DecorationId)
147       report_fatal_error("Expect SPIR-V <Decoration> operand to be the first "
148                          "element of the decoration");
149     auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
150                    .addUse(Reg)
151                    .addImm(static_cast<uint32_t>(DecorationId->getZExtValue()));
152     for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) {
153       if (ConstantInt *OpV =
154               mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(OpI)))
155         MIB.addImm(static_cast<uint32_t>(OpV->getZExtValue()));
156       else if (MDString *OpV = dyn_cast<MDString>(OpMD->getOperand(OpI)))
157         addStringImm(OpV->getString(), MIB);
158       else
159         report_fatal_error("Unexpected operand of the decoration");
160     }
161   }
162 }
163 
164 // TODO: maybe the following two functions should be handled in the subtarget
165 // to allow for different OpenCL vs Vulkan handling.
storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC)166 unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
167   switch (SC) {
168   case SPIRV::StorageClass::Function:
169     return 0;
170   case SPIRV::StorageClass::CrossWorkgroup:
171     return 1;
172   case SPIRV::StorageClass::UniformConstant:
173     return 2;
174   case SPIRV::StorageClass::Workgroup:
175     return 3;
176   case SPIRV::StorageClass::Generic:
177     return 4;
178   case SPIRV::StorageClass::DeviceOnlyINTEL:
179     return 5;
180   case SPIRV::StorageClass::HostOnlyINTEL:
181     return 6;
182   case SPIRV::StorageClass::Input:
183     return 7;
184   default:
185     report_fatal_error("Unable to get address space id");
186   }
187 }
188 
189 SPIRV::StorageClass::StorageClass
addressSpaceToStorageClass(unsigned AddrSpace,const SPIRVSubtarget & STI)190 addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
191   switch (AddrSpace) {
192   case 0:
193     return SPIRV::StorageClass::Function;
194   case 1:
195     return SPIRV::StorageClass::CrossWorkgroup;
196   case 2:
197     return SPIRV::StorageClass::UniformConstant;
198   case 3:
199     return SPIRV::StorageClass::Workgroup;
200   case 4:
201     return SPIRV::StorageClass::Generic;
202   case 5:
203     return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
204                ? SPIRV::StorageClass::DeviceOnlyINTEL
205                : SPIRV::StorageClass::CrossWorkgroup;
206   case 6:
207     return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
208                ? SPIRV::StorageClass::HostOnlyINTEL
209                : SPIRV::StorageClass::CrossWorkgroup;
210   case 7:
211     return SPIRV::StorageClass::Input;
212   default:
213     report_fatal_error("Unknown address space");
214   }
215 }
216 
217 SPIRV::MemorySemantics::MemorySemantics
getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC)218 getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) {
219   switch (SC) {
220   case SPIRV::StorageClass::StorageBuffer:
221   case SPIRV::StorageClass::Uniform:
222     return SPIRV::MemorySemantics::UniformMemory;
223   case SPIRV::StorageClass::Workgroup:
224     return SPIRV::MemorySemantics::WorkgroupMemory;
225   case SPIRV::StorageClass::CrossWorkgroup:
226     return SPIRV::MemorySemantics::CrossWorkgroupMemory;
227   case SPIRV::StorageClass::AtomicCounter:
228     return SPIRV::MemorySemantics::AtomicCounterMemory;
229   case SPIRV::StorageClass::Image:
230     return SPIRV::MemorySemantics::ImageMemory;
231   default:
232     return SPIRV::MemorySemantics::None;
233   }
234 }
235 
getMemSemantics(AtomicOrdering Ord)236 SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
237   switch (Ord) {
238   case AtomicOrdering::Acquire:
239     return SPIRV::MemorySemantics::Acquire;
240   case AtomicOrdering::Release:
241     return SPIRV::MemorySemantics::Release;
242   case AtomicOrdering::AcquireRelease:
243     return SPIRV::MemorySemantics::AcquireRelease;
244   case AtomicOrdering::SequentiallyConsistent:
245     return SPIRV::MemorySemantics::SequentiallyConsistent;
246   case AtomicOrdering::Unordered:
247   case AtomicOrdering::Monotonic:
248   case AtomicOrdering::NotAtomic:
249     return SPIRV::MemorySemantics::None;
250   }
251   llvm_unreachable(nullptr);
252 }
253 
getDefInstrMaybeConstant(Register & ConstReg,const MachineRegisterInfo * MRI)254 MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
255                                        const MachineRegisterInfo *MRI) {
256   MachineInstr *MI = MRI->getVRegDef(ConstReg);
257   MachineInstr *ConstInstr =
258       MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
259           ? MRI->getVRegDef(MI->getOperand(1).getReg())
260           : MI;
261   if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
262     if (GI->is(Intrinsic::spv_track_constant)) {
263       ConstReg = ConstInstr->getOperand(2).getReg();
264       return MRI->getVRegDef(ConstReg);
265     }
266   } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
267     ConstReg = ConstInstr->getOperand(1).getReg();
268     return MRI->getVRegDef(ConstReg);
269   }
270   return MRI->getVRegDef(ConstReg);
271 }
272 
getIConstVal(Register ConstReg,const MachineRegisterInfo * MRI)273 uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
274   const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
275   assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
276   return MI->getOperand(1).getCImm()->getValue().getZExtValue();
277 }
278 
isSpvIntrinsic(const MachineInstr & MI,Intrinsic::ID IntrinsicID)279 bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
280   if (const auto *GI = dyn_cast<GIntrinsic>(&MI))
281     return GI->is(IntrinsicID);
282   return false;
283 }
284 
getMDOperandAsType(const MDNode * N,unsigned I)285 Type *getMDOperandAsType(const MDNode *N, unsigned I) {
286   Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
287   return toTypedPointer(ElementTy);
288 }
289 
290 // The set of names is borrowed from the SPIR-V translator.
291 // TODO: may be implemented in SPIRVBuiltins.td.
isPipeOrAddressSpaceCastBI(const StringRef MangledName)292 static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) {
293   return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" ||
294          MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" ||
295          MangledName == "write_pipe_4" || MangledName == "read_pipe_4" ||
296          MangledName == "reserve_write_pipe" ||
297          MangledName == "reserve_read_pipe" ||
298          MangledName == "commit_write_pipe" ||
299          MangledName == "commit_read_pipe" ||
300          MangledName == "work_group_reserve_write_pipe" ||
301          MangledName == "work_group_reserve_read_pipe" ||
302          MangledName == "work_group_commit_write_pipe" ||
303          MangledName == "work_group_commit_read_pipe" ||
304          MangledName == "get_pipe_num_packets_ro" ||
305          MangledName == "get_pipe_max_packets_ro" ||
306          MangledName == "get_pipe_num_packets_wo" ||
307          MangledName == "get_pipe_max_packets_wo" ||
308          MangledName == "sub_group_reserve_write_pipe" ||
309          MangledName == "sub_group_reserve_read_pipe" ||
310          MangledName == "sub_group_commit_write_pipe" ||
311          MangledName == "sub_group_commit_read_pipe" ||
312          MangledName == "to_global" || MangledName == "to_local" ||
313          MangledName == "to_private";
314 }
315 
isEnqueueKernelBI(const StringRef MangledName)316 static bool isEnqueueKernelBI(const StringRef MangledName) {
317   return MangledName == "__enqueue_kernel_basic" ||
318          MangledName == "__enqueue_kernel_basic_events" ||
319          MangledName == "__enqueue_kernel_varargs" ||
320          MangledName == "__enqueue_kernel_events_varargs";
321 }
322 
isKernelQueryBI(const StringRef MangledName)323 static bool isKernelQueryBI(const StringRef MangledName) {
324   return MangledName == "__get_kernel_work_group_size_impl" ||
325          MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" ||
326          MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" ||
327          MangledName == "__get_kernel_preferred_work_group_size_multiple_impl";
328 }
329 
isNonMangledOCLBuiltin(StringRef Name)330 static bool isNonMangledOCLBuiltin(StringRef Name) {
331   if (!Name.starts_with("__"))
332     return false;
333 
334   return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) ||
335          isPipeOrAddressSpaceCastBI(Name.drop_front(2)) ||
336          Name == "__translate_sampler_initializer";
337 }
338 
getOclOrSpirvBuiltinDemangledName(StringRef Name)339 std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
340   bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
341   bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
342   bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
343   bool IsMangled = Name.starts_with("_Z");
344 
345   // Otherwise use simple demangling to return the function name.
346   if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
347     return Name.str();
348 
349   // Try to use the itanium demangler.
350   if (char *DemangledName = itaniumDemangle(Name.data())) {
351     std::string Result = DemangledName;
352     free(DemangledName);
353     return Result;
354   }
355 
356   // Autocheck C++, maybe need to do explicit check of the source language.
357   // OpenCL C++ built-ins are declared in cl namespace.
358   // TODO: consider using 'St' abbriviation for cl namespace mangling.
359   // Similar to ::std:: in C++.
360   size_t Start, Len = 0;
361   size_t DemangledNameLenStart = 2;
362   if (Name.starts_with("_ZN")) {
363     // Skip CV and ref qualifiers.
364     size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3);
365     // All built-ins are in the ::cl:: namespace.
366     if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv")
367       return std::string();
368     DemangledNameLenStart = NameSpaceStart + 11;
369   }
370   Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
371   Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart)
372       .getAsInteger(10, Len);
373   return Name.substr(Start, Len).str();
374 }
375 
hasBuiltinTypePrefix(StringRef Name)376 bool hasBuiltinTypePrefix(StringRef Name) {
377   if (Name.starts_with("opencl.") || Name.starts_with("ocl_") ||
378       Name.starts_with("spirv."))
379     return true;
380   return false;
381 }
382 
isSpecialOpaqueType(const Type * Ty)383 bool isSpecialOpaqueType(const Type *Ty) {
384   if (const TargetExtType *EType = dyn_cast<TargetExtType>(Ty))
385     return hasBuiltinTypePrefix(EType->getName());
386 
387   return false;
388 }
389 
isEntryPoint(const Function & F)390 bool isEntryPoint(const Function &F) {
391   // OpenCL handling: any function with the SPIR_KERNEL
392   // calling convention will be a potential entry point.
393   if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
394     return true;
395 
396   // HLSL handling: special attribute are emitted from the
397   // front-end.
398   if (F.getFnAttribute("hlsl.shader").isValid())
399     return true;
400 
401   return false;
402 }
403 
parseBasicTypeName(StringRef & TypeName,LLVMContext & Ctx)404 Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
405   TypeName.consume_front("atomic_");
406   if (TypeName.consume_front("void"))
407     return Type::getVoidTy(Ctx);
408   else if (TypeName.consume_front("bool"))
409     return Type::getIntNTy(Ctx, 1);
410   else if (TypeName.consume_front("char") ||
411            TypeName.consume_front("unsigned char") ||
412            TypeName.consume_front("uchar"))
413     return Type::getInt8Ty(Ctx);
414   else if (TypeName.consume_front("short") ||
415            TypeName.consume_front("unsigned short") ||
416            TypeName.consume_front("ushort"))
417     return Type::getInt16Ty(Ctx);
418   else if (TypeName.consume_front("int") ||
419            TypeName.consume_front("unsigned int") ||
420            TypeName.consume_front("uint"))
421     return Type::getInt32Ty(Ctx);
422   else if (TypeName.consume_front("long") ||
423            TypeName.consume_front("unsigned long") ||
424            TypeName.consume_front("ulong"))
425     return Type::getInt64Ty(Ctx);
426   else if (TypeName.consume_front("half"))
427     return Type::getHalfTy(Ctx);
428   else if (TypeName.consume_front("float"))
429     return Type::getFloatTy(Ctx);
430   else if (TypeName.consume_front("double"))
431     return Type::getDoubleTy(Ctx);
432 
433   // Unable to recognize SPIRV type name
434   return nullptr;
435 }
436 
437 } // namespace llvm
438