xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- SPIRVBuiltins.cpp - SPIR-V Built-in Functions ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering builtin function calls and types using their
10 // demangled names and TableGen records.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "SPIRVBuiltins.h"
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Analysis/ValueTracking.h"
20 #include "llvm/IR/IntrinsicsSPIRV.h"
21 #include <string>
22 #include <tuple>
23 
24 #define DEBUG_TYPE "spirv-builtins"
25 
26 namespace llvm {
27 namespace SPIRV {
28 #define GET_BuiltinGroup_DECL
29 #include "SPIRVGenTables.inc"
30 
31 struct DemangledBuiltin {
32   StringRef Name;
33   InstructionSet::InstructionSet Set;
34   BuiltinGroup Group;
35   uint8_t MinNumArgs;
36   uint8_t MaxNumArgs;
37 };
38 
39 #define GET_DemangledBuiltins_DECL
40 #define GET_DemangledBuiltins_IMPL
41 
42 struct IncomingCall {
43   const std::string BuiltinName;
44   const DemangledBuiltin *Builtin;
45 
46   const Register ReturnRegister;
47   const SPIRVType *ReturnType;
48   const SmallVectorImpl<Register> &Arguments;
49 
IncomingCallllvm::SPIRV::IncomingCall50   IncomingCall(const std::string BuiltinName, const DemangledBuiltin *Builtin,
51                const Register ReturnRegister, const SPIRVType *ReturnType,
52                const SmallVectorImpl<Register> &Arguments)
53       : BuiltinName(BuiltinName), Builtin(Builtin),
54         ReturnRegister(ReturnRegister), ReturnType(ReturnType),
55         Arguments(Arguments) {}
56 
isSpirvOpllvm::SPIRV::IncomingCall57   bool isSpirvOp() const { return BuiltinName.rfind("__spirv_", 0) == 0; }
58 };
59 
60 struct NativeBuiltin {
61   StringRef Name;
62   InstructionSet::InstructionSet Set;
63   uint32_t Opcode;
64 };
65 
66 #define GET_NativeBuiltins_DECL
67 #define GET_NativeBuiltins_IMPL
68 
69 struct GroupBuiltin {
70   StringRef Name;
71   uint32_t Opcode;
72   uint32_t GroupOperation;
73   bool IsElect;
74   bool IsAllOrAny;
75   bool IsAllEqual;
76   bool IsBallot;
77   bool IsInverseBallot;
78   bool IsBallotBitExtract;
79   bool IsBallotFindBit;
80   bool IsLogical;
81   bool NoGroupOperation;
82   bool HasBoolArg;
83 };
84 
85 #define GET_GroupBuiltins_DECL
86 #define GET_GroupBuiltins_IMPL
87 
88 struct IntelSubgroupsBuiltin {
89   StringRef Name;
90   uint32_t Opcode;
91   bool IsBlock;
92   bool IsWrite;
93 };
94 
95 #define GET_IntelSubgroupsBuiltins_DECL
96 #define GET_IntelSubgroupsBuiltins_IMPL
97 
98 struct AtomicFloatingBuiltin {
99   StringRef Name;
100   uint32_t Opcode;
101 };
102 
103 #define GET_AtomicFloatingBuiltins_DECL
104 #define GET_AtomicFloatingBuiltins_IMPL
105 struct GroupUniformBuiltin {
106   StringRef Name;
107   uint32_t Opcode;
108   bool IsLogical;
109 };
110 
111 #define GET_GroupUniformBuiltins_DECL
112 #define GET_GroupUniformBuiltins_IMPL
113 
114 struct GetBuiltin {
115   StringRef Name;
116   InstructionSet::InstructionSet Set;
117   BuiltIn::BuiltIn Value;
118 };
119 
120 using namespace BuiltIn;
121 #define GET_GetBuiltins_DECL
122 #define GET_GetBuiltins_IMPL
123 
124 struct ImageQueryBuiltin {
125   StringRef Name;
126   InstructionSet::InstructionSet Set;
127   uint32_t Component;
128 };
129 
130 #define GET_ImageQueryBuiltins_DECL
131 #define GET_ImageQueryBuiltins_IMPL
132 
133 struct ConvertBuiltin {
134   StringRef Name;
135   InstructionSet::InstructionSet Set;
136   bool IsDestinationSigned;
137   bool IsSaturated;
138   bool IsRounded;
139   bool IsBfloat16;
140   FPRoundingMode::FPRoundingMode RoundingMode;
141 };
142 
143 struct VectorLoadStoreBuiltin {
144   StringRef Name;
145   InstructionSet::InstructionSet Set;
146   uint32_t Number;
147   uint32_t ElementCount;
148   bool IsRounded;
149   FPRoundingMode::FPRoundingMode RoundingMode;
150 };
151 
152 using namespace FPRoundingMode;
153 #define GET_ConvertBuiltins_DECL
154 #define GET_ConvertBuiltins_IMPL
155 
156 using namespace InstructionSet;
157 #define GET_VectorLoadStoreBuiltins_DECL
158 #define GET_VectorLoadStoreBuiltins_IMPL
159 
160 #define GET_CLMemoryScope_DECL
161 #define GET_CLSamplerAddressingMode_DECL
162 #define GET_CLMemoryFenceFlags_DECL
163 #define GET_ExtendedBuiltins_DECL
164 #include "SPIRVGenTables.inc"
165 } // namespace SPIRV
166 
//===----------------------------------------------------------------------===//
// Misc functions for looking up builtins and verifying requirements using
// TableGen records
//===----------------------------------------------------------------------===//
171 
172 namespace SPIRV {
173 /// Parses the name part of the demangled builtin call.
lookupBuiltinNameHelper(StringRef DemangledCall)174 std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
175   const static std::string PassPrefix = "(anonymous namespace)::";
176   std::string BuiltinName;
177   // Itanium Demangler result may have "(anonymous namespace)::" prefix
178   if (DemangledCall.starts_with(PassPrefix.c_str()))
179     BuiltinName = DemangledCall.substr(PassPrefix.length());
180   else
181     BuiltinName = DemangledCall;
182   // Extract the builtin function name and types of arguments from the call
183   // skeleton.
184   BuiltinName = BuiltinName.substr(0, BuiltinName.find('('));
185 
186   // Account for possible "__spirv_ocl_" prefix in SPIR-V friendly LLVM IR
187   if (BuiltinName.rfind("__spirv_ocl_", 0) == 0)
188     BuiltinName = BuiltinName.substr(12);
189 
190   // Check if the extracted name contains type information between angle
191   // brackets. If so, the builtin is an instantiated template - needs to have
192   // the information after angle brackets and return type removed.
193   if (BuiltinName.find('<') && BuiltinName.back() == '>') {
194     BuiltinName = BuiltinName.substr(0, BuiltinName.find('<'));
195     BuiltinName = BuiltinName.substr(BuiltinName.find_last_of(' ') + 1);
196   }
197 
198   // Check if the extracted name begins with "__spirv_ImageSampleExplicitLod"
199   // contains return type information at the end "_R<type>", if so extract the
200   // plain builtin name without the type information.
201   if (StringRef(BuiltinName).contains("__spirv_ImageSampleExplicitLod") &&
202       StringRef(BuiltinName).contains("_R")) {
203     BuiltinName = BuiltinName.substr(0, BuiltinName.find("_R"));
204   }
205 
206   return BuiltinName;
207 }
208 } // namespace SPIRV
209 
210 /// Looks up the demangled builtin call in the SPIRVBuiltins.td records using
211 /// the provided \p DemangledCall and specified \p Set.
212 ///
213 /// The lookup follows the following algorithm, returning the first successful
214 /// match:
215 /// 1. Search with the plain demangled name (expecting a 1:1 match).
216 /// 2. Search with the prefix before or suffix after the demangled name
217 /// signyfying the type of the first argument.
218 ///
219 /// \returns Wrapper around the demangled call and found builtin definition.
220 static std::unique_ptr<const SPIRV::IncomingCall>
lookupBuiltin(StringRef DemangledCall,SPIRV::InstructionSet::InstructionSet Set,Register ReturnRegister,const SPIRVType * ReturnType,const SmallVectorImpl<Register> & Arguments)221 lookupBuiltin(StringRef DemangledCall,
222               SPIRV::InstructionSet::InstructionSet Set,
223               Register ReturnRegister, const SPIRVType *ReturnType,
224               const SmallVectorImpl<Register> &Arguments) {
225   std::string BuiltinName = SPIRV::lookupBuiltinNameHelper(DemangledCall);
226 
227   SmallVector<StringRef, 10> BuiltinArgumentTypes;
228   StringRef BuiltinArgs =
229       DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
230   BuiltinArgs.split(BuiltinArgumentTypes, ',', -1, false);
231 
232   // Look up the builtin in the defined set. Start with the plain demangled
233   // name, expecting a 1:1 match in the defined builtin set.
234   const SPIRV::DemangledBuiltin *Builtin;
235   if ((Builtin = SPIRV::lookupBuiltin(BuiltinName, Set)))
236     return std::make_unique<SPIRV::IncomingCall>(
237         BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
238 
239   // If the initial look up was unsuccessful and the demangled call takes at
240   // least 1 argument, add a prefix or suffix signifying the type of the first
241   // argument and repeat the search.
242   if (BuiltinArgumentTypes.size() >= 1) {
243     char FirstArgumentType = BuiltinArgumentTypes[0][0];
244     // Prefix to be added to the builtin's name for lookup.
245     // For example, OpenCL "abs" taking an unsigned value has a prefix "u_".
246     std::string Prefix;
247 
248     switch (FirstArgumentType) {
249     // Unsigned:
250     case 'u':
251       if (Set == SPIRV::InstructionSet::OpenCL_std)
252         Prefix = "u_";
253       else if (Set == SPIRV::InstructionSet::GLSL_std_450)
254         Prefix = "u";
255       break;
256     // Signed:
257     case 'c':
258     case 's':
259     case 'i':
260     case 'l':
261       if (Set == SPIRV::InstructionSet::OpenCL_std)
262         Prefix = "s_";
263       else if (Set == SPIRV::InstructionSet::GLSL_std_450)
264         Prefix = "s";
265       break;
266     // Floating-point:
267     case 'f':
268     case 'd':
269     case 'h':
270       if (Set == SPIRV::InstructionSet::OpenCL_std ||
271           Set == SPIRV::InstructionSet::GLSL_std_450)
272         Prefix = "f";
273       break;
274     }
275 
276     // If argument-type name prefix was added, look up the builtin again.
277     if (!Prefix.empty() &&
278         (Builtin = SPIRV::lookupBuiltin(Prefix + BuiltinName, Set)))
279       return std::make_unique<SPIRV::IncomingCall>(
280           BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
281 
282     // If lookup with a prefix failed, find a suffix to be added to the
283     // builtin's name for lookup. For example, OpenCL "group_reduce_max" taking
284     // an unsigned value has a suffix "u".
285     std::string Suffix;
286 
287     switch (FirstArgumentType) {
288     // Unsigned:
289     case 'u':
290       Suffix = "u";
291       break;
292     // Signed:
293     case 'c':
294     case 's':
295     case 'i':
296     case 'l':
297       Suffix = "s";
298       break;
299     // Floating-point:
300     case 'f':
301     case 'd':
302     case 'h':
303       Suffix = "f";
304       break;
305     }
306 
307     // If argument-type name suffix was added, look up the builtin again.
308     if (!Suffix.empty() &&
309         (Builtin = SPIRV::lookupBuiltin(BuiltinName + Suffix, Set)))
310       return std::make_unique<SPIRV::IncomingCall>(
311           BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
312   }
313 
314   // No builtin with such name was found in the set.
315   return nullptr;
316 }
317 
getBlockStructInstr(Register ParamReg,MachineRegisterInfo * MRI)318 static MachineInstr *getBlockStructInstr(Register ParamReg,
319                                          MachineRegisterInfo *MRI) {
320   // We expect the following sequence of instructions:
321   //   %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
322   //   or       = G_GLOBAL_VALUE @block_literal_global
323   //   %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
324   //   %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
325   MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg);
326   assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
327          MI->getOperand(1).isReg());
328   Register BitcastReg = MI->getOperand(1).getReg();
329   MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg);
330   assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
331          BitcastMI->getOperand(2).isReg());
332   Register ValueReg = BitcastMI->getOperand(2).getReg();
333   MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
334   return ValueMI;
335 }
336 
337 // Return an integer constant corresponding to the given register and
338 // defined in spv_track_constant.
339 // TODO: maybe unify with prelegalizer pass.
getConstFromIntrinsic(Register Reg,MachineRegisterInfo * MRI)340 static unsigned getConstFromIntrinsic(Register Reg, MachineRegisterInfo *MRI) {
341   MachineInstr *DefMI = MRI->getUniqueVRegDef(Reg);
342   assert(isSpvIntrinsic(*DefMI, Intrinsic::spv_track_constant) &&
343          DefMI->getOperand(2).isReg());
344   MachineInstr *DefMI2 = MRI->getUniqueVRegDef(DefMI->getOperand(2).getReg());
345   assert(DefMI2->getOpcode() == TargetOpcode::G_CONSTANT &&
346          DefMI2->getOperand(1).isCImm());
347   return DefMI2->getOperand(1).getCImm()->getValue().getZExtValue();
348 }
349 
350 // Return type of the instruction result from spv_assign_type intrinsic.
351 // TODO: maybe unify with prelegalizer pass.
getMachineInstrType(MachineInstr * MI)352 static const Type *getMachineInstrType(MachineInstr *MI) {
353   MachineInstr *NextMI = MI->getNextNode();
354   if (!NextMI)
355     return nullptr;
356   if (isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_name))
357     if ((NextMI = NextMI->getNextNode()) == nullptr)
358       return nullptr;
359   Register ValueReg = MI->getOperand(0).getReg();
360   if ((!isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_type) &&
361        !isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_ptr_type)) ||
362       NextMI->getOperand(1).getReg() != ValueReg)
363     return nullptr;
364   Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0);
365   assert(Ty && "Type is expected");
366   return Ty;
367 }
368 
getBlockStructType(Register ParamReg,MachineRegisterInfo * MRI)369 static const Type *getBlockStructType(Register ParamReg,
370                                       MachineRegisterInfo *MRI) {
371   // In principle, this information should be passed to us from Clang via
372   // an elementtype attribute. However, said attribute requires that
373   // the function call be an intrinsic, which is not. Instead, we rely on being
374   // able to trace this to the declaration of a variable: OpenCL C specification
375   // section 6.12.5 should guarantee that we can do this.
376   MachineInstr *MI = getBlockStructInstr(ParamReg, MRI);
377   if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE)
378     return MI->getOperand(1).getGlobal()->getType();
379   assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) &&
380          "Blocks in OpenCL C must be traceable to allocation site");
381   return getMachineInstrType(MI);
382 }
383 
384 //===----------------------------------------------------------------------===//
385 // Helper functions for building misc instructions
386 //===----------------------------------------------------------------------===//
387 
388 /// Helper function building either a resulting scalar or vector bool register
389 /// depending on the expected \p ResultType.
390 ///
391 /// \returns Tuple of the resulting register and its type.
392 static std::tuple<Register, SPIRVType *>
buildBoolRegister(MachineIRBuilder & MIRBuilder,const SPIRVType * ResultType,SPIRVGlobalRegistry * GR)393 buildBoolRegister(MachineIRBuilder &MIRBuilder, const SPIRVType *ResultType,
394                   SPIRVGlobalRegistry *GR) {
395   LLT Type;
396   SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
397 
398   if (ResultType->getOpcode() == SPIRV::OpTypeVector) {
399     unsigned VectorElements = ResultType->getOperand(2).getImm();
400     BoolType =
401         GR->getOrCreateSPIRVVectorType(BoolType, VectorElements, MIRBuilder);
402     const FixedVectorType *LLVMVectorType =
403         cast<FixedVectorType>(GR->getTypeForSPIRVType(BoolType));
404     Type = LLT::vector(LLVMVectorType->getElementCount(), 1);
405   } else {
406     Type = LLT::scalar(1);
407   }
408 
409   Register ResultRegister =
410       MIRBuilder.getMRI()->createGenericVirtualRegister(Type);
411   MIRBuilder.getMRI()->setRegClass(ResultRegister, &SPIRV::IDRegClass);
412   GR->assignSPIRVTypeToVReg(BoolType, ResultRegister, MIRBuilder.getMF());
413   return std::make_tuple(ResultRegister, BoolType);
414 }
415 
416 /// Helper function for building either a vector or scalar select instruction
417 /// depending on the expected \p ResultType.
buildSelectInst(MachineIRBuilder & MIRBuilder,Register ReturnRegister,Register SourceRegister,const SPIRVType * ReturnType,SPIRVGlobalRegistry * GR)418 static bool buildSelectInst(MachineIRBuilder &MIRBuilder,
419                             Register ReturnRegister, Register SourceRegister,
420                             const SPIRVType *ReturnType,
421                             SPIRVGlobalRegistry *GR) {
422   Register TrueConst, FalseConst;
423 
424   if (ReturnType->getOpcode() == SPIRV::OpTypeVector) {
425     unsigned Bits = GR->getScalarOrVectorBitWidth(ReturnType);
426     uint64_t AllOnes = APInt::getAllOnes(Bits).getZExtValue();
427     TrueConst = GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType);
428     FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType);
429   } else {
430     TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType);
431     FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType);
432   }
433   return MIRBuilder.buildSelect(ReturnRegister, SourceRegister, TrueConst,
434                                 FalseConst);
435 }
436 
437 /// Helper function for building a load instruction loading into the
438 /// \p DestinationReg.
buildLoadInst(SPIRVType * BaseType,Register PtrRegister,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR,LLT LowLevelType,Register DestinationReg=Register (0))439 static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
440                               MachineIRBuilder &MIRBuilder,
441                               SPIRVGlobalRegistry *GR, LLT LowLevelType,
442                               Register DestinationReg = Register(0)) {
443   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
444   if (!DestinationReg.isValid()) {
445     DestinationReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
446     MRI->setType(DestinationReg, LLT::scalar(32));
447     GR->assignSPIRVTypeToVReg(BaseType, DestinationReg, MIRBuilder.getMF());
448   }
449   // TODO: consider using correct address space and alignment (p0 is canonical
450   // type for selection though).
451   MachinePointerInfo PtrInfo = MachinePointerInfo();
452   MIRBuilder.buildLoad(DestinationReg, PtrRegister, PtrInfo, Align());
453   return DestinationReg;
454 }
455 
456 /// Helper function for building a load instruction for loading a builtin global
457 /// variable of \p BuiltinValue value.
buildBuiltinVariableLoad(MachineIRBuilder & MIRBuilder,SPIRVType * VariableType,SPIRVGlobalRegistry * GR,SPIRV::BuiltIn::BuiltIn BuiltinValue,LLT LLType,Register Reg=Register (0),bool isConst=true,bool hasLinkageTy=true)458 static Register buildBuiltinVariableLoad(
459     MachineIRBuilder &MIRBuilder, SPIRVType *VariableType,
460     SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType,
461     Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) {
462   Register NewRegister =
463       MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
464   MIRBuilder.getMRI()->setType(NewRegister,
465                                LLT::pointer(0, GR->getPointerSize()));
466   SPIRVType *PtrType = GR->getOrCreateSPIRVPointerType(
467       VariableType, MIRBuilder, SPIRV::StorageClass::Input);
468   GR->assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
469 
470   // Set up the global OpVariable with the necessary builtin decorations.
471   Register Variable = GR->buildGlobalVariable(
472       NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr,
473       SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst,
474       /* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder,
475       false);
476 
477   // Load the value from the global variable.
478   Register LoadedRegister =
479       buildLoadInst(VariableType, Variable, MIRBuilder, GR, LLType, Reg);
480   MIRBuilder.getMRI()->setType(LoadedRegister, LLType);
481   return LoadedRegister;
482 }
483 
/// Helper external function for inserting an ASSIGN_TYPE instruction between
/// \p Reg and its definition, setting the new register as the destination of
/// the definition and assigning a SPIRVType to both registers. If \p SpirvTy is
/// provided, it is used as the SPIRVType in ASSIGN_TYPE; otherwise one is
/// created from \p Ty. Defined in SPIRVPreLegalizer.cpp.
489 extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
490                                   SPIRVGlobalRegistry *GR,
491                                   MachineIRBuilder &MIB,
492                                   MachineRegisterInfo &MRI);
493 
494 // TODO: Move to TableGen.
495 static SPIRV::MemorySemantics::MemorySemantics
getSPIRVMemSemantics(std::memory_order MemOrder)496 getSPIRVMemSemantics(std::memory_order MemOrder) {
497   switch (MemOrder) {
498   case std::memory_order::memory_order_relaxed:
499     return SPIRV::MemorySemantics::None;
500   case std::memory_order::memory_order_acquire:
501     return SPIRV::MemorySemantics::Acquire;
502   case std::memory_order::memory_order_release:
503     return SPIRV::MemorySemantics::Release;
504   case std::memory_order::memory_order_acq_rel:
505     return SPIRV::MemorySemantics::AcquireRelease;
506   case std::memory_order::memory_order_seq_cst:
507     return SPIRV::MemorySemantics::SequentiallyConsistent;
508   default:
509     report_fatal_error("Unknown CL memory scope");
510   }
511 }
512 
getSPIRVScope(SPIRV::CLMemoryScope ClScope)513 static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
514   switch (ClScope) {
515   case SPIRV::CLMemoryScope::memory_scope_work_item:
516     return SPIRV::Scope::Invocation;
517   case SPIRV::CLMemoryScope::memory_scope_work_group:
518     return SPIRV::Scope::Workgroup;
519   case SPIRV::CLMemoryScope::memory_scope_device:
520     return SPIRV::Scope::Device;
521   case SPIRV::CLMemoryScope::memory_scope_all_svm_devices:
522     return SPIRV::Scope::CrossDevice;
523   case SPIRV::CLMemoryScope::memory_scope_sub_group:
524     return SPIRV::Scope::Subgroup;
525   }
526   report_fatal_error("Unknown CL memory scope");
527 }
528 
buildConstantIntReg(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR,unsigned BitWidth=32)529 static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
530                                     SPIRVGlobalRegistry *GR,
531                                     unsigned BitWidth = 32) {
532   SPIRVType *IntType = GR->getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
533   return GR->buildConstantInt(Val, MIRBuilder, IntType);
534 }
535 
buildScopeReg(Register CLScopeRegister,SPIRV::Scope::Scope Scope,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR,MachineRegisterInfo * MRI)536 static Register buildScopeReg(Register CLScopeRegister,
537                               SPIRV::Scope::Scope Scope,
538                               MachineIRBuilder &MIRBuilder,
539                               SPIRVGlobalRegistry *GR,
540                               MachineRegisterInfo *MRI) {
541   if (CLScopeRegister.isValid()) {
542     auto CLScope =
543         static_cast<SPIRV::CLMemoryScope>(getIConstVal(CLScopeRegister, MRI));
544     Scope = getSPIRVScope(CLScope);
545 
546     if (CLScope == static_cast<unsigned>(Scope)) {
547       MRI->setRegClass(CLScopeRegister, &SPIRV::IDRegClass);
548       return CLScopeRegister;
549     }
550   }
551   return buildConstantIntReg(Scope, MIRBuilder, GR);
552 }
553 
buildMemSemanticsReg(Register SemanticsRegister,Register PtrRegister,unsigned & Semantics,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)554 static Register buildMemSemanticsReg(Register SemanticsRegister,
555                                      Register PtrRegister, unsigned &Semantics,
556                                      MachineIRBuilder &MIRBuilder,
557                                      SPIRVGlobalRegistry *GR) {
558   if (SemanticsRegister.isValid()) {
559     MachineRegisterInfo *MRI = MIRBuilder.getMRI();
560     std::memory_order Order =
561         static_cast<std::memory_order>(getIConstVal(SemanticsRegister, MRI));
562     Semantics =
563         getSPIRVMemSemantics(Order) |
564         getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
565 
566     if (Order == Semantics) {
567       MRI->setRegClass(SemanticsRegister, &SPIRV::IDRegClass);
568       return SemanticsRegister;
569     }
570   }
571   return buildConstantIntReg(Semantics, MIRBuilder, GR);
572 }
573 
buildOpFromWrapper(MachineIRBuilder & MIRBuilder,unsigned Opcode,const SPIRV::IncomingCall * Call,Register TypeReg,ArrayRef<uint32_t> ImmArgs={})574 static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
575                                const SPIRV::IncomingCall *Call,
576                                Register TypeReg,
577                                ArrayRef<uint32_t> ImmArgs = {}) {
578   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
579   auto MIB = MIRBuilder.buildInstr(Opcode);
580   if (TypeReg.isValid())
581     MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
582   unsigned Sz = Call->Arguments.size() - ImmArgs.size();
583   for (unsigned i = 0; i < Sz; ++i) {
584     Register ArgReg = Call->Arguments[i];
585     if (!MRI->getRegClassOrNull(ArgReg))
586       MRI->setRegClass(ArgReg, &SPIRV::IDRegClass);
587     MIB.addUse(ArgReg);
588   }
589   for (uint32_t ImmArg : ImmArgs)
590     MIB.addImm(ImmArg);
591   return true;
592 }
593 
594 /// Helper function for translating atomic init to OpStore.
buildAtomicInitInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder)595 static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
596                                 MachineIRBuilder &MIRBuilder) {
597   if (Call->isSpirvOp())
598     return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call, Register(0));
599 
600   assert(Call->Arguments.size() == 2 &&
601          "Need 2 arguments for atomic init translation");
602   MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
603   MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
604   MIRBuilder.buildInstr(SPIRV::OpStore)
605       .addUse(Call->Arguments[0])
606       .addUse(Call->Arguments[1]);
607   return true;
608 }
609 
610 /// Helper function for building an atomic load instruction.
buildAtomicLoadInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)611 static bool buildAtomicLoadInst(const SPIRV::IncomingCall *Call,
612                                 MachineIRBuilder &MIRBuilder,
613                                 SPIRVGlobalRegistry *GR) {
614   Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
615   if (Call->isSpirvOp())
616     return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicLoad, Call, TypeReg);
617 
618   Register PtrRegister = Call->Arguments[0];
619   MIRBuilder.getMRI()->setRegClass(PtrRegister, &SPIRV::IDRegClass);
620   // TODO: if true insert call to __translate_ocl_memory_sccope before
621   // OpAtomicLoad and the function implementation. We can use Translator's
622   // output for transcoding/atomic_explicit_arguments.cl as an example.
623   Register ScopeRegister;
624   if (Call->Arguments.size() > 1) {
625     ScopeRegister = Call->Arguments[1];
626     MIRBuilder.getMRI()->setRegClass(ScopeRegister, &SPIRV::IDRegClass);
627   } else
628     ScopeRegister = buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
629 
630   Register MemSemanticsReg;
631   if (Call->Arguments.size() > 2) {
632     // TODO: Insert call to __translate_ocl_memory_order before OpAtomicLoad.
633     MemSemanticsReg = Call->Arguments[2];
634     MIRBuilder.getMRI()->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
635   } else {
636     int Semantics =
637         SPIRV::MemorySemantics::SequentiallyConsistent |
638         getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
639     MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
640   }
641 
642   MIRBuilder.buildInstr(SPIRV::OpAtomicLoad)
643       .addDef(Call->ReturnRegister)
644       .addUse(TypeReg)
645       .addUse(PtrRegister)
646       .addUse(ScopeRegister)
647       .addUse(MemSemanticsReg);
648   return true;
649 }
650 
651 /// Helper function for building an atomic store instruction.
buildAtomicStoreInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)652 static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
653                                  MachineIRBuilder &MIRBuilder,
654                                  SPIRVGlobalRegistry *GR) {
655   if (Call->isSpirvOp())
656     return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
657 
658   Register ScopeRegister =
659       buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
660   Register PtrRegister = Call->Arguments[0];
661   MIRBuilder.getMRI()->setRegClass(PtrRegister, &SPIRV::IDRegClass);
662   int Semantics =
663       SPIRV::MemorySemantics::SequentiallyConsistent |
664       getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
665   Register MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
666   MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
667   MIRBuilder.buildInstr(SPIRV::OpAtomicStore)
668       .addUse(PtrRegister)
669       .addUse(ScopeRegister)
670       .addUse(MemSemanticsReg)
671       .addUse(Call->Arguments[1]);
672   return true;
673 }
674 
675 /// Helper function for building an atomic compare-exchange instruction.
buildAtomicCompareExchangeInst(const SPIRV::IncomingCall * Call,const SPIRV::DemangledBuiltin * Builtin,unsigned Opcode,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)676 static bool buildAtomicCompareExchangeInst(
677     const SPIRV::IncomingCall *Call, const SPIRV::DemangledBuiltin *Builtin,
678     unsigned Opcode, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) {
679   if (Call->isSpirvOp())
680     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
681                               GR->getSPIRVTypeID(Call->ReturnType));
682 
683   bool IsCmpxchg = Call->Builtin->Name.contains("cmpxchg");
684   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
685 
686   Register ObjectPtr = Call->Arguments[0];   // Pointer (volatile A *object.)
687   Register ExpectedArg = Call->Arguments[1]; // Comparator (C* expected).
688   Register Desired = Call->Arguments[2];     // Value (C Desired).
689   MRI->setRegClass(ObjectPtr, &SPIRV::IDRegClass);
690   MRI->setRegClass(ExpectedArg, &SPIRV::IDRegClass);
691   MRI->setRegClass(Desired, &SPIRV::IDRegClass);
692   SPIRVType *SpvDesiredTy = GR->getSPIRVTypeForVReg(Desired);
693   LLT DesiredLLT = MRI->getType(Desired);
694 
695   assert(GR->getSPIRVTypeForVReg(ObjectPtr)->getOpcode() ==
696          SPIRV::OpTypePointer);
697   unsigned ExpectedType = GR->getSPIRVTypeForVReg(ExpectedArg)->getOpcode();
698   (void)ExpectedType;
699   assert(IsCmpxchg ? ExpectedType == SPIRV::OpTypeInt
700                    : ExpectedType == SPIRV::OpTypePointer);
701   assert(GR->isScalarOfType(Desired, SPIRV::OpTypeInt));
702 
703   SPIRVType *SpvObjectPtrTy = GR->getSPIRVTypeForVReg(ObjectPtr);
704   assert(SpvObjectPtrTy->getOperand(2).isReg() && "SPIRV type is expected");
705   auto StorageClass = static_cast<SPIRV::StorageClass::StorageClass>(
706       SpvObjectPtrTy->getOperand(1).getImm());
707   auto MemSemStorage = getMemSemanticsForStorageClass(StorageClass);
708 
709   Register MemSemEqualReg;
710   Register MemSemUnequalReg;
711   uint64_t MemSemEqual =
712       IsCmpxchg
713           ? SPIRV::MemorySemantics::None
714           : SPIRV::MemorySemantics::SequentiallyConsistent | MemSemStorage;
715   uint64_t MemSemUnequal =
716       IsCmpxchg
717           ? SPIRV::MemorySemantics::None
718           : SPIRV::MemorySemantics::SequentiallyConsistent | MemSemStorage;
719   if (Call->Arguments.size() >= 4) {
720     assert(Call->Arguments.size() >= 5 &&
721            "Need 5+ args for explicit atomic cmpxchg");
722     auto MemOrdEq =
723         static_cast<std::memory_order>(getIConstVal(Call->Arguments[3], MRI));
724     auto MemOrdNeq =
725         static_cast<std::memory_order>(getIConstVal(Call->Arguments[4], MRI));
726     MemSemEqual = getSPIRVMemSemantics(MemOrdEq) | MemSemStorage;
727     MemSemUnequal = getSPIRVMemSemantics(MemOrdNeq) | MemSemStorage;
728     if (MemOrdEq == MemSemEqual)
729       MemSemEqualReg = Call->Arguments[3];
730     if (MemOrdNeq == MemSemEqual)
731       MemSemUnequalReg = Call->Arguments[4];
732     MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
733     MRI->setRegClass(Call->Arguments[4], &SPIRV::IDRegClass);
734   }
735   if (!MemSemEqualReg.isValid())
736     MemSemEqualReg = buildConstantIntReg(MemSemEqual, MIRBuilder, GR);
737   if (!MemSemUnequalReg.isValid())
738     MemSemUnequalReg = buildConstantIntReg(MemSemUnequal, MIRBuilder, GR);
739 
740   Register ScopeReg;
741   auto Scope = IsCmpxchg ? SPIRV::Scope::Workgroup : SPIRV::Scope::Device;
742   if (Call->Arguments.size() >= 6) {
743     assert(Call->Arguments.size() == 6 &&
744            "Extra args for explicit atomic cmpxchg");
745     auto ClScope = static_cast<SPIRV::CLMemoryScope>(
746         getIConstVal(Call->Arguments[5], MRI));
747     Scope = getSPIRVScope(ClScope);
748     if (ClScope == static_cast<unsigned>(Scope))
749       ScopeReg = Call->Arguments[5];
750     MRI->setRegClass(Call->Arguments[5], &SPIRV::IDRegClass);
751   }
752   if (!ScopeReg.isValid())
753     ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);
754 
755   Register Expected = IsCmpxchg
756                           ? ExpectedArg
757                           : buildLoadInst(SpvDesiredTy, ExpectedArg, MIRBuilder,
758                                           GR, LLT::scalar(32));
759   MRI->setType(Expected, DesiredLLT);
760   Register Tmp = !IsCmpxchg ? MRI->createGenericVirtualRegister(DesiredLLT)
761                             : Call->ReturnRegister;
762   if (!MRI->getRegClassOrNull(Tmp))
763     MRI->setRegClass(Tmp, &SPIRV::IDRegClass);
764   GR->assignSPIRVTypeToVReg(SpvDesiredTy, Tmp, MIRBuilder.getMF());
765 
766   SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
767   MIRBuilder.buildInstr(Opcode)
768       .addDef(Tmp)
769       .addUse(GR->getSPIRVTypeID(IntTy))
770       .addUse(ObjectPtr)
771       .addUse(ScopeReg)
772       .addUse(MemSemEqualReg)
773       .addUse(MemSemUnequalReg)
774       .addUse(Desired)
775       .addUse(Expected);
776   if (!IsCmpxchg) {
777     MIRBuilder.buildInstr(SPIRV::OpStore).addUse(ExpectedArg).addUse(Tmp);
778     MIRBuilder.buildICmp(CmpInst::ICMP_EQ, Call->ReturnRegister, Tmp, Expected);
779   }
780   return true;
781 }
782 
783 /// Helper function for building atomic instructions.
buildAtomicRMWInst(const SPIRV::IncomingCall * Call,unsigned Opcode,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)784 static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
785                                MachineIRBuilder &MIRBuilder,
786                                SPIRVGlobalRegistry *GR) {
787   if (Call->isSpirvOp())
788     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
789                               GR->getSPIRVTypeID(Call->ReturnType));
790 
791   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
792   Register ScopeRegister =
793       Call->Arguments.size() >= 4 ? Call->Arguments[3] : Register();
794 
795   assert(Call->Arguments.size() <= 4 &&
796          "Too many args for explicit atomic RMW");
797   ScopeRegister = buildScopeReg(ScopeRegister, SPIRV::Scope::Workgroup,
798                                 MIRBuilder, GR, MRI);
799 
800   Register PtrRegister = Call->Arguments[0];
801   unsigned Semantics = SPIRV::MemorySemantics::None;
802   MRI->setRegClass(PtrRegister, &SPIRV::IDRegClass);
803   Register MemSemanticsReg =
804       Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
805   MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
806                                          Semantics, MIRBuilder, GR);
807   MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
808   Register ValueReg = Call->Arguments[1];
809   Register ValueTypeReg = GR->getSPIRVTypeID(Call->ReturnType);
810   // support cl_ext_float_atomics
811   if (Call->ReturnType->getOpcode() == SPIRV::OpTypeFloat) {
812     if (Opcode == SPIRV::OpAtomicIAdd) {
813       Opcode = SPIRV::OpAtomicFAddEXT;
814     } else if (Opcode == SPIRV::OpAtomicISub) {
815       // Translate OpAtomicISub applied to a floating type argument to
816       // OpAtomicFAddEXT with the negative value operand
817       Opcode = SPIRV::OpAtomicFAddEXT;
818       Register NegValueReg =
819           MRI->createGenericVirtualRegister(MRI->getType(ValueReg));
820       MRI->setRegClass(NegValueReg, &SPIRV::IDRegClass);
821       GR->assignSPIRVTypeToVReg(Call->ReturnType, NegValueReg,
822                                 MIRBuilder.getMF());
823       MIRBuilder.buildInstr(TargetOpcode::G_FNEG)
824           .addDef(NegValueReg)
825           .addUse(ValueReg);
826       insertAssignInstr(NegValueReg, nullptr, Call->ReturnType, GR, MIRBuilder,
827                         MIRBuilder.getMF().getRegInfo());
828       ValueReg = NegValueReg;
829     }
830   }
831   MIRBuilder.buildInstr(Opcode)
832       .addDef(Call->ReturnRegister)
833       .addUse(ValueTypeReg)
834       .addUse(PtrRegister)
835       .addUse(ScopeRegister)
836       .addUse(MemSemanticsReg)
837       .addUse(ValueReg);
838   return true;
839 }
840 
841 /// Helper function for building an atomic floating-type instruction.
buildAtomicFloatingRMWInst(const SPIRV::IncomingCall * Call,unsigned Opcode,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)842 static bool buildAtomicFloatingRMWInst(const SPIRV::IncomingCall *Call,
843                                        unsigned Opcode,
844                                        MachineIRBuilder &MIRBuilder,
845                                        SPIRVGlobalRegistry *GR) {
846   assert(Call->Arguments.size() == 4 &&
847          "Wrong number of atomic floating-type builtin");
848 
849   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
850 
851   Register PtrReg = Call->Arguments[0];
852   MRI->setRegClass(PtrReg, &SPIRV::IDRegClass);
853 
854   Register ScopeReg = Call->Arguments[1];
855   MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
856 
857   Register MemSemanticsReg = Call->Arguments[2];
858   MRI->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
859 
860   Register ValueReg = Call->Arguments[3];
861   MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
862 
863   MIRBuilder.buildInstr(Opcode)
864       .addDef(Call->ReturnRegister)
865       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
866       .addUse(PtrReg)
867       .addUse(ScopeReg)
868       .addUse(MemSemanticsReg)
869       .addUse(ValueReg);
870   return true;
871 }
872 
873 /// Helper function for building atomic flag instructions (e.g.
874 /// OpAtomicFlagTestAndSet).
buildAtomicFlagInst(const SPIRV::IncomingCall * Call,unsigned Opcode,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)875 static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
876                                 unsigned Opcode, MachineIRBuilder &MIRBuilder,
877                                 SPIRVGlobalRegistry *GR) {
878   bool IsSet = Opcode == SPIRV::OpAtomicFlagTestAndSet;
879   Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
880   if (Call->isSpirvOp())
881     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
882                               IsSet ? TypeReg : Register(0));
883 
884   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
885   Register PtrRegister = Call->Arguments[0];
886   unsigned Semantics = SPIRV::MemorySemantics::SequentiallyConsistent;
887   Register MemSemanticsReg =
888       Call->Arguments.size() >= 2 ? Call->Arguments[1] : Register();
889   MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
890                                          Semantics, MIRBuilder, GR);
891 
892   assert((Opcode != SPIRV::OpAtomicFlagClear ||
893           (Semantics != SPIRV::MemorySemantics::Acquire &&
894            Semantics != SPIRV::MemorySemantics::AcquireRelease)) &&
895          "Invalid memory order argument!");
896 
897   Register ScopeRegister =
898       Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
899   ScopeRegister =
900       buildScopeReg(ScopeRegister, SPIRV::Scope::Device, MIRBuilder, GR, MRI);
901 
902   auto MIB = MIRBuilder.buildInstr(Opcode);
903   if (IsSet)
904     MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
905 
906   MIB.addUse(PtrRegister).addUse(ScopeRegister).addUse(MemSemanticsReg);
907   return true;
908 }
909 
910 /// Helper function for building barriers, i.e., memory/control ordering
911 /// operations.
buildBarrierInst(const SPIRV::IncomingCall * Call,unsigned Opcode,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)912 static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
913                              MachineIRBuilder &MIRBuilder,
914                              SPIRVGlobalRegistry *GR) {
915   if (Call->isSpirvOp())
916     return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0));
917 
918   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
919   unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
920   unsigned MemSemantics = SPIRV::MemorySemantics::None;
921 
922   if (MemFlags & SPIRV::CLK_LOCAL_MEM_FENCE)
923     MemSemantics |= SPIRV::MemorySemantics::WorkgroupMemory;
924 
925   if (MemFlags & SPIRV::CLK_GLOBAL_MEM_FENCE)
926     MemSemantics |= SPIRV::MemorySemantics::CrossWorkgroupMemory;
927 
928   if (MemFlags & SPIRV::CLK_IMAGE_MEM_FENCE)
929     MemSemantics |= SPIRV::MemorySemantics::ImageMemory;
930 
931   if (Opcode == SPIRV::OpMemoryBarrier) {
932     std::memory_order MemOrder =
933         static_cast<std::memory_order>(getIConstVal(Call->Arguments[1], MRI));
934     MemSemantics = getSPIRVMemSemantics(MemOrder) | MemSemantics;
935   } else {
936     MemSemantics |= SPIRV::MemorySemantics::SequentiallyConsistent;
937   }
938 
939   Register MemSemanticsReg;
940   if (MemFlags == MemSemantics) {
941     MemSemanticsReg = Call->Arguments[0];
942     MRI->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
943   } else
944     MemSemanticsReg = buildConstantIntReg(MemSemantics, MIRBuilder, GR);
945 
946   Register ScopeReg;
947   SPIRV::Scope::Scope Scope = SPIRV::Scope::Workgroup;
948   SPIRV::Scope::Scope MemScope = Scope;
949   if (Call->Arguments.size() >= 2) {
950     assert(
951         ((Opcode != SPIRV::OpMemoryBarrier && Call->Arguments.size() == 2) ||
952          (Opcode == SPIRV::OpMemoryBarrier && Call->Arguments.size() == 3)) &&
953         "Extra args for explicitly scoped barrier");
954     Register ScopeArg = (Opcode == SPIRV::OpMemoryBarrier) ? Call->Arguments[2]
955                                                            : Call->Arguments[1];
956     SPIRV::CLMemoryScope CLScope =
957         static_cast<SPIRV::CLMemoryScope>(getIConstVal(ScopeArg, MRI));
958     MemScope = getSPIRVScope(CLScope);
959     if (!(MemFlags & SPIRV::CLK_LOCAL_MEM_FENCE) ||
960         (Opcode == SPIRV::OpMemoryBarrier))
961       Scope = MemScope;
962 
963     if (CLScope == static_cast<unsigned>(Scope)) {
964       ScopeReg = Call->Arguments[1];
965       MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
966     }
967   }
968 
969   if (!ScopeReg.isValid())
970     ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);
971 
972   auto MIB = MIRBuilder.buildInstr(Opcode).addUse(ScopeReg);
973   if (Opcode != SPIRV::OpMemoryBarrier)
974     MIB.addUse(buildConstantIntReg(MemScope, MIRBuilder, GR));
975   MIB.addUse(MemSemanticsReg);
976   return true;
977 }
978 
getNumComponentsForDim(SPIRV::Dim::Dim dim)979 static unsigned getNumComponentsForDim(SPIRV::Dim::Dim dim) {
980   switch (dim) {
981   case SPIRV::Dim::DIM_1D:
982   case SPIRV::Dim::DIM_Buffer:
983     return 1;
984   case SPIRV::Dim::DIM_2D:
985   case SPIRV::Dim::DIM_Cube:
986   case SPIRV::Dim::DIM_Rect:
987     return 2;
988   case SPIRV::Dim::DIM_3D:
989     return 3;
990   default:
991     report_fatal_error("Cannot get num components for given Dim");
992   }
993 }
994 
995 /// Helper function for obtaining the number of size components.
getNumSizeComponents(SPIRVType * imgType)996 static unsigned getNumSizeComponents(SPIRVType *imgType) {
997   assert(imgType->getOpcode() == SPIRV::OpTypeImage);
998   auto dim = static_cast<SPIRV::Dim::Dim>(imgType->getOperand(2).getImm());
999   unsigned numComps = getNumComponentsForDim(dim);
1000   bool arrayed = imgType->getOperand(4).getImm() == 1;
1001   return arrayed ? numComps + 1 : numComps;
1002 }
1003 
1004 //===----------------------------------------------------------------------===//
1005 // Implementation functions for each builtin group
1006 //===----------------------------------------------------------------------===//
1007 
generateExtInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1008 static bool generateExtInst(const SPIRV::IncomingCall *Call,
1009                             MachineIRBuilder &MIRBuilder,
1010                             SPIRVGlobalRegistry *GR) {
1011   // Lookup the extended instruction number in the TableGen records.
1012   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1013   uint32_t Number =
1014       SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;
1015 
1016   // Build extended instruction.
1017   auto MIB =
1018       MIRBuilder.buildInstr(SPIRV::OpExtInst)
1019           .addDef(Call->ReturnRegister)
1020           .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1021           .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
1022           .addImm(Number);
1023 
1024   for (auto Argument : Call->Arguments)
1025     MIB.addUse(Argument);
1026   return true;
1027 }
1028 
generateRelationalInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1029 static bool generateRelationalInst(const SPIRV::IncomingCall *Call,
1030                                    MachineIRBuilder &MIRBuilder,
1031                                    SPIRVGlobalRegistry *GR) {
1032   // Lookup the instruction opcode in the TableGen records.
1033   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1034   unsigned Opcode =
1035       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1036 
1037   Register CompareRegister;
1038   SPIRVType *RelationType;
1039   std::tie(CompareRegister, RelationType) =
1040       buildBoolRegister(MIRBuilder, Call->ReturnType, GR);
1041 
1042   // Build relational instruction.
1043   auto MIB = MIRBuilder.buildInstr(Opcode)
1044                  .addDef(CompareRegister)
1045                  .addUse(GR->getSPIRVTypeID(RelationType));
1046 
1047   for (auto Argument : Call->Arguments)
1048     MIB.addUse(Argument);
1049 
1050   // Build select instruction.
1051   return buildSelectInst(MIRBuilder, Call->ReturnRegister, CompareRegister,
1052                          Call->ReturnType, GR);
1053 }
1054 
generateGroupInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1055 static bool generateGroupInst(const SPIRV::IncomingCall *Call,
1056                               MachineIRBuilder &MIRBuilder,
1057                               SPIRVGlobalRegistry *GR) {
1058   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1059   const SPIRV::GroupBuiltin *GroupBuiltin =
1060       SPIRV::lookupGroupBuiltin(Builtin->Name);
1061 
1062   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1063   if (Call->isSpirvOp()) {
1064     if (GroupBuiltin->NoGroupOperation)
1065       return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
1066                                 GR->getSPIRVTypeID(Call->ReturnType));
1067 
1068     // Group Operation is a literal
1069     Register GroupOpReg = Call->Arguments[1];
1070     const MachineInstr *MI = getDefInstrMaybeConstant(GroupOpReg, MRI);
1071     if (!MI || MI->getOpcode() != TargetOpcode::G_CONSTANT)
1072       report_fatal_error(
1073           "Group Operation parameter must be an integer constant");
1074     uint64_t GrpOp = MI->getOperand(1).getCImm()->getValue().getZExtValue();
1075     Register ScopeReg = Call->Arguments[0];
1076     if (!MRI->getRegClassOrNull(ScopeReg))
1077       MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
1078     auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
1079                    .addDef(Call->ReturnRegister)
1080                    .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1081                    .addUse(ScopeReg)
1082                    .addImm(GrpOp);
1083     for (unsigned i = 2; i < Call->Arguments.size(); ++i) {
1084       Register ArgReg = Call->Arguments[i];
1085       if (!MRI->getRegClassOrNull(ArgReg))
1086         MRI->setRegClass(ArgReg, &SPIRV::IDRegClass);
1087       MIB.addUse(ArgReg);
1088     }
1089     return true;
1090   }
1091 
1092   Register Arg0;
1093   if (GroupBuiltin->HasBoolArg) {
1094     Register ConstRegister = Call->Arguments[0];
1095     auto ArgInstruction = getDefInstrMaybeConstant(ConstRegister, MRI);
1096     (void)ArgInstruction;
1097     // TODO: support non-constant bool values.
1098     assert(ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT &&
1099            "Only constant bool value args are supported");
1100     if (GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode() !=
1101         SPIRV::OpTypeBool)
1102       Arg0 = GR->buildConstantInt(getIConstVal(ConstRegister, MRI), MIRBuilder,
1103                                   GR->getOrCreateSPIRVBoolType(MIRBuilder));
1104   }
1105 
1106   Register GroupResultRegister = Call->ReturnRegister;
1107   SPIRVType *GroupResultType = Call->ReturnType;
1108 
1109   // TODO: maybe we need to check whether the result type is already boolean
1110   // and in this case do not insert select instruction.
1111   const bool HasBoolReturnTy =
1112       GroupBuiltin->IsElect || GroupBuiltin->IsAllOrAny ||
1113       GroupBuiltin->IsAllEqual || GroupBuiltin->IsLogical ||
1114       GroupBuiltin->IsInverseBallot || GroupBuiltin->IsBallotBitExtract;
1115 
1116   if (HasBoolReturnTy)
1117     std::tie(GroupResultRegister, GroupResultType) =
1118         buildBoolRegister(MIRBuilder, Call->ReturnType, GR);
1119 
1120   auto Scope = Builtin->Name.starts_with("sub_group") ? SPIRV::Scope::Subgroup
1121                                                       : SPIRV::Scope::Workgroup;
1122   Register ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
1123 
1124   // Build work/sub group instruction.
1125   auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
1126                  .addDef(GroupResultRegister)
1127                  .addUse(GR->getSPIRVTypeID(GroupResultType))
1128                  .addUse(ScopeRegister);
1129 
1130   if (!GroupBuiltin->NoGroupOperation)
1131     MIB.addImm(GroupBuiltin->GroupOperation);
1132   if (Call->Arguments.size() > 0) {
1133     MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
1134     MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
1135     for (unsigned i = 1; i < Call->Arguments.size(); i++) {
1136       MIB.addUse(Call->Arguments[i]);
1137       MRI->setRegClass(Call->Arguments[i], &SPIRV::IDRegClass);
1138     }
1139   }
1140 
1141   // Build select instruction.
1142   if (HasBoolReturnTy)
1143     buildSelectInst(MIRBuilder, Call->ReturnRegister, GroupResultRegister,
1144                     Call->ReturnType, GR);
1145   return true;
1146 }
1147 
generateIntelSubgroupsInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1148 static bool generateIntelSubgroupsInst(const SPIRV::IncomingCall *Call,
1149                                        MachineIRBuilder &MIRBuilder,
1150                                        SPIRVGlobalRegistry *GR) {
1151   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1152   MachineFunction &MF = MIRBuilder.getMF();
1153   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
1154   if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1155     std::string DiagMsg = std::string(Builtin->Name) +
1156                           ": the builtin requires the following SPIR-V "
1157                           "extension: SPV_INTEL_subgroups";
1158     report_fatal_error(DiagMsg.c_str(), false);
1159   }
1160   const SPIRV::IntelSubgroupsBuiltin *IntelSubgroups =
1161       SPIRV::lookupIntelSubgroupsBuiltin(Builtin->Name);
1162 
1163   uint32_t OpCode = IntelSubgroups->Opcode;
1164   if (Call->isSpirvOp()) {
1165     bool IsSet = OpCode != SPIRV::OpSubgroupBlockWriteINTEL &&
1166                  OpCode != SPIRV::OpSubgroupImageBlockWriteINTEL;
1167     return buildOpFromWrapper(MIRBuilder, OpCode, Call,
1168                               IsSet ? GR->getSPIRVTypeID(Call->ReturnType)
1169                                     : Register(0));
1170   }
1171 
1172   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1173   if (IntelSubgroups->IsBlock) {
1174     // Minimal number or arguments set in TableGen records is 1
1175     if (SPIRVType *Arg0Type = GR->getSPIRVTypeForVReg(Call->Arguments[0])) {
1176       if (Arg0Type->getOpcode() == SPIRV::OpTypeImage) {
1177         // TODO: add required validation from the specification:
1178         // "'Image' must be an object whose type is OpTypeImage with a 'Sampled'
1179         // operand of 0 or 2. If the 'Sampled' operand is 2, then some
1180         // dimensions require a capability."
1181         switch (OpCode) {
1182         case SPIRV::OpSubgroupBlockReadINTEL:
1183           OpCode = SPIRV::OpSubgroupImageBlockReadINTEL;
1184           break;
1185         case SPIRV::OpSubgroupBlockWriteINTEL:
1186           OpCode = SPIRV::OpSubgroupImageBlockWriteINTEL;
1187           break;
1188         }
1189       }
1190     }
1191   }
1192 
1193   // TODO: opaque pointers types should be eventually resolved in such a way
1194   // that validation of block read is enabled with respect to the following
1195   // specification requirement:
1196   // "'Result Type' may be a scalar or vector type, and its component type must
1197   // be equal to the type pointed to by 'Ptr'."
1198   // For example, function parameter type should not be default i8 pointer, but
1199   // depend on the result type of the instruction where it is used as a pointer
1200   // argument of OpSubgroupBlockReadINTEL
1201 
1202   // Build Intel subgroups instruction
1203   MachineInstrBuilder MIB =
1204       IntelSubgroups->IsWrite
1205           ? MIRBuilder.buildInstr(OpCode)
1206           : MIRBuilder.buildInstr(OpCode)
1207                 .addDef(Call->ReturnRegister)
1208                 .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1209   for (size_t i = 0; i < Call->Arguments.size(); ++i) {
1210     MIB.addUse(Call->Arguments[i]);
1211     MRI->setRegClass(Call->Arguments[i], &SPIRV::IDRegClass);
1212   }
1213 
1214   return true;
1215 }
1216 
generateGroupUniformInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1217 static bool generateGroupUniformInst(const SPIRV::IncomingCall *Call,
1218                                      MachineIRBuilder &MIRBuilder,
1219                                      SPIRVGlobalRegistry *GR) {
1220   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1221   MachineFunction &MF = MIRBuilder.getMF();
1222   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
1223   if (!ST->canUseExtension(
1224           SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1225     std::string DiagMsg = std::string(Builtin->Name) +
1226                           ": the builtin requires the following SPIR-V "
1227                           "extension: SPV_KHR_uniform_group_instructions";
1228     report_fatal_error(DiagMsg.c_str(), false);
1229   }
1230   const SPIRV::GroupUniformBuiltin *GroupUniform =
1231       SPIRV::lookupGroupUniformBuiltin(Builtin->Name);
1232   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1233 
1234   Register GroupResultReg = Call->ReturnRegister;
1235   MRI->setRegClass(GroupResultReg, &SPIRV::IDRegClass);
1236 
1237   // Scope
1238   Register ScopeReg = Call->Arguments[0];
1239   MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
1240 
1241   // Group Operation
1242   Register ConstGroupOpReg = Call->Arguments[1];
1243   const MachineInstr *Const = getDefInstrMaybeConstant(ConstGroupOpReg, MRI);
1244   if (!Const || Const->getOpcode() != TargetOpcode::G_CONSTANT)
1245     report_fatal_error(
1246         "expect a constant group operation for a uniform group instruction",
1247         false);
1248   const MachineOperand &ConstOperand = Const->getOperand(1);
1249   if (!ConstOperand.isCImm())
1250     report_fatal_error("uniform group instructions: group operation must be an "
1251                        "integer constant",
1252                        false);
1253 
1254   // Value
1255   Register ValueReg = Call->Arguments[2];
1256   MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
1257 
1258   auto MIB = MIRBuilder.buildInstr(GroupUniform->Opcode)
1259                  .addDef(GroupResultReg)
1260                  .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1261                  .addUse(ScopeReg);
1262   addNumImm(ConstOperand.getCImm()->getValue(), MIB);
1263   MIB.addUse(ValueReg);
1264 
1265   return true;
1266 }
1267 
generateKernelClockInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1268 static bool generateKernelClockInst(const SPIRV::IncomingCall *Call,
1269                                     MachineIRBuilder &MIRBuilder,
1270                                     SPIRVGlobalRegistry *GR) {
1271   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1272   MachineFunction &MF = MIRBuilder.getMF();
1273   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
1274   if (!ST->canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock)) {
1275     std::string DiagMsg = std::string(Builtin->Name) +
1276                           ": the builtin requires the following SPIR-V "
1277                           "extension: SPV_KHR_shader_clock";
1278     report_fatal_error(DiagMsg.c_str(), false);
1279   }
1280 
1281   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1282   Register ResultReg = Call->ReturnRegister;
1283   MRI->setRegClass(ResultReg, &SPIRV::IDRegClass);
1284 
1285   // Deduce the `Scope` operand from the builtin function name.
1286   SPIRV::Scope::Scope ScopeArg =
1287       StringSwitch<SPIRV::Scope::Scope>(Builtin->Name)
1288           .EndsWith("device", SPIRV::Scope::Scope::Device)
1289           .EndsWith("work_group", SPIRV::Scope::Scope::Workgroup)
1290           .EndsWith("sub_group", SPIRV::Scope::Scope::Subgroup);
1291   Register ScopeReg = buildConstantIntReg(ScopeArg, MIRBuilder, GR);
1292 
1293   MIRBuilder.buildInstr(SPIRV::OpReadClockKHR)
1294       .addDef(ResultReg)
1295       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1296       .addUse(ScopeReg);
1297 
1298   return true;
1299 }
1300 
1301 // These queries ask for a single size_t result for a given dimension index, e.g.
1302 // size_t get_global_id(uint dimindex). In SPIR-V, the builtins corresponding to
1303 // these values are all vec3 types, so we need to extract the correct index or
1304 // return defaultVal (0 or 1 depending on the query). We also handle extending
1305 // or truncating in case size_t does not match the expected result type's
1306 // bitwidth.
1307 //
1308 // For a constant index >= 3 we generate:
1309 //  %res = OpConstant %SizeT 0
1310 //
1311 // For other indices we generate:
1312 //  %g = OpVariable %ptr_V3_SizeT Input
1313 //  OpDecorate %g BuiltIn XXX
1314 //  OpDecorate %g LinkageAttributes "__spirv_BuiltInXXX"
1315 //  OpDecorate %g Constant
1316 //  %loadedVec = OpLoad %V3_SizeT %g
1317 //
1318 //  Then, if the index is constant < 3, we generate:
1319 //    %res = OpCompositeExtract %SizeT %loadedVec idx
1320 //  If the index is dynamic, we generate:
1321 //    %tmp = OpVectorExtractDynamic %SizeT %loadedVec %idx
1322 //    %cmp = OpULessThan %bool %idx %const_3
1323 //    %res = OpSelect %SizeT %cmp %tmp %const_0
1324 //
1325 //  If the bitwidth of %res does not match the expected return type, we add an
1326 //  extend or truncate.
genWorkgroupQuery(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR,SPIRV::BuiltIn::BuiltIn BuiltinValue,uint64_t DefaultValue)1327 static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
1328                               MachineIRBuilder &MIRBuilder,
1329                               SPIRVGlobalRegistry *GR,
1330                               SPIRV::BuiltIn::BuiltIn BuiltinValue,
1331                               uint64_t DefaultValue) {
1332   Register IndexRegister = Call->Arguments[0];
1333   const unsigned ResultWidth = Call->ReturnType->getOperand(1).getImm();
1334   const unsigned PointerSize = GR->getPointerSize();
1335   const SPIRVType *PointerSizeType =
1336       GR->getOrCreateSPIRVIntegerType(PointerSize, MIRBuilder);
1337   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1338   auto IndexInstruction = getDefInstrMaybeConstant(IndexRegister, MRI);
1339 
1340   // Set up the final register to do truncation or extension on at the end.
1341   Register ToTruncate = Call->ReturnRegister;
1342 
1343   // If the index is constant, we can statically determine if it is in range.
1344   bool IsConstantIndex =
1345       IndexInstruction->getOpcode() == TargetOpcode::G_CONSTANT;
1346 
1347   // If it's out of range (max dimension is 3), we can just return the constant
1348   // default value (0 or 1 depending on which query function).
1349   if (IsConstantIndex && getIConstVal(IndexRegister, MRI) >= 3) {
1350     Register DefaultReg = Call->ReturnRegister;
1351     if (PointerSize != ResultWidth) {
1352       DefaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
1353       MRI->setRegClass(DefaultReg, &SPIRV::IDRegClass);
1354       GR->assignSPIRVTypeToVReg(PointerSizeType, DefaultReg,
1355                                 MIRBuilder.getMF());
1356       ToTruncate = DefaultReg;
1357     }
1358     auto NewRegister =
1359         GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
1360     MIRBuilder.buildCopy(DefaultReg, NewRegister);
1361   } else { // If it could be in range, we need to load from the given builtin.
1362     auto Vec3Ty =
1363         GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder);
1364     Register LoadedVector =
1365         buildBuiltinVariableLoad(MIRBuilder, Vec3Ty, GR, BuiltinValue,
1366                                  LLT::fixed_vector(3, PointerSize));
1367     // Set up the vreg to extract the result to (possibly a new temporary one).
1368     Register Extracted = Call->ReturnRegister;
1369     if (!IsConstantIndex || PointerSize != ResultWidth) {
1370       Extracted = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
1371       MRI->setRegClass(Extracted, &SPIRV::IDRegClass);
1372       GR->assignSPIRVTypeToVReg(PointerSizeType, Extracted, MIRBuilder.getMF());
1373     }
1374     // Use Intrinsic::spv_extractelt so dynamic vs static extraction is
1375     // handled later: extr = spv_extractelt LoadedVector, IndexRegister.
1376     MachineInstrBuilder ExtractInst = MIRBuilder.buildIntrinsic(
1377         Intrinsic::spv_extractelt, ArrayRef<Register>{Extracted}, true, false);
1378     ExtractInst.addUse(LoadedVector).addUse(IndexRegister);
1379 
1380     // If the index is dynamic, need check if it's < 3, and then use a select.
1381     if (!IsConstantIndex) {
1382       insertAssignInstr(Extracted, nullptr, PointerSizeType, GR, MIRBuilder,
1383                         *MRI);
1384 
1385       auto IndexType = GR->getSPIRVTypeForVReg(IndexRegister);
1386       auto BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
1387 
1388       Register CompareRegister =
1389           MRI->createGenericVirtualRegister(LLT::scalar(1));
1390       MRI->setRegClass(CompareRegister, &SPIRV::IDRegClass);
1391       GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());
1392 
1393       // Use G_ICMP to check if idxVReg < 3.
1394       MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
1395                            GR->buildConstantInt(3, MIRBuilder, IndexType));
1396 
1397       // Get constant for the default value (0 or 1 depending on which
1398       // function).
1399       Register DefaultRegister =
1400           GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
1401 
1402       // Get a register for the selection result (possibly a new temporary one).
1403       Register SelectionResult = Call->ReturnRegister;
1404       if (PointerSize != ResultWidth) {
1405         SelectionResult =
1406             MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
1407         MRI->setRegClass(SelectionResult, &SPIRV::IDRegClass);
1408         GR->assignSPIRVTypeToVReg(PointerSizeType, SelectionResult,
1409                                   MIRBuilder.getMF());
1410       }
1411       // Create the final G_SELECT to return the extracted value or the default.
1412       MIRBuilder.buildSelect(SelectionResult, CompareRegister, Extracted,
1413                              DefaultRegister);
1414       ToTruncate = SelectionResult;
1415     } else {
1416       ToTruncate = Extracted;
1417     }
1418   }
1419   // Alter the result's bitwidth if it does not match the SizeT value extracted.
1420   if (PointerSize != ResultWidth)
1421     MIRBuilder.buildZExtOrTrunc(Call->ReturnRegister, ToTruncate);
1422   return true;
1423 }
1424 
generateBuiltinVar(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1425 static bool generateBuiltinVar(const SPIRV::IncomingCall *Call,
1426                                MachineIRBuilder &MIRBuilder,
1427                                SPIRVGlobalRegistry *GR) {
1428   // Lookup the builtin variable record.
1429   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1430   SPIRV::BuiltIn::BuiltIn Value =
1431       SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;
1432 
1433   if (Value == SPIRV::BuiltIn::GlobalInvocationId)
1434     return genWorkgroupQuery(Call, MIRBuilder, GR, Value, 0);
1435 
1436   // Build a load instruction for the builtin variable.
1437   unsigned BitWidth = GR->getScalarOrVectorBitWidth(Call->ReturnType);
1438   LLT LLType;
1439   if (Call->ReturnType->getOpcode() == SPIRV::OpTypeVector)
1440     LLType =
1441         LLT::fixed_vector(Call->ReturnType->getOperand(2).getImm(), BitWidth);
1442   else
1443     LLType = LLT::scalar(BitWidth);
1444 
1445   return buildBuiltinVariableLoad(MIRBuilder, Call->ReturnType, GR, Value,
1446                                   LLType, Call->ReturnRegister);
1447 }
1448 
generateAtomicInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1449 static bool generateAtomicInst(const SPIRV::IncomingCall *Call,
1450                                MachineIRBuilder &MIRBuilder,
1451                                SPIRVGlobalRegistry *GR) {
1452   // Lookup the instruction opcode in the TableGen records.
1453   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1454   unsigned Opcode =
1455       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1456 
1457   switch (Opcode) {
1458   case SPIRV::OpStore:
1459     return buildAtomicInitInst(Call, MIRBuilder);
1460   case SPIRV::OpAtomicLoad:
1461     return buildAtomicLoadInst(Call, MIRBuilder, GR);
1462   case SPIRV::OpAtomicStore:
1463     return buildAtomicStoreInst(Call, MIRBuilder, GR);
1464   case SPIRV::OpAtomicCompareExchange:
1465   case SPIRV::OpAtomicCompareExchangeWeak:
1466     return buildAtomicCompareExchangeInst(Call, Builtin, Opcode, MIRBuilder,
1467                                           GR);
1468   case SPIRV::OpAtomicIAdd:
1469   case SPIRV::OpAtomicISub:
1470   case SPIRV::OpAtomicOr:
1471   case SPIRV::OpAtomicXor:
1472   case SPIRV::OpAtomicAnd:
1473   case SPIRV::OpAtomicExchange:
1474     return buildAtomicRMWInst(Call, Opcode, MIRBuilder, GR);
1475   case SPIRV::OpMemoryBarrier:
1476     return buildBarrierInst(Call, SPIRV::OpMemoryBarrier, MIRBuilder, GR);
1477   case SPIRV::OpAtomicFlagTestAndSet:
1478   case SPIRV::OpAtomicFlagClear:
1479     return buildAtomicFlagInst(Call, Opcode, MIRBuilder, GR);
1480   default:
1481     if (Call->isSpirvOp())
1482       return buildOpFromWrapper(MIRBuilder, Opcode, Call,
1483                                 GR->getSPIRVTypeID(Call->ReturnType));
1484     return false;
1485   }
1486 }
1487 
generateAtomicFloatingInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1488 static bool generateAtomicFloatingInst(const SPIRV::IncomingCall *Call,
1489                                        MachineIRBuilder &MIRBuilder,
1490                                        SPIRVGlobalRegistry *GR) {
1491   // Lookup the instruction opcode in the TableGen records.
1492   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1493   unsigned Opcode = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name)->Opcode;
1494 
1495   switch (Opcode) {
1496   case SPIRV::OpAtomicFAddEXT:
1497   case SPIRV::OpAtomicFMinEXT:
1498   case SPIRV::OpAtomicFMaxEXT:
1499     return buildAtomicFloatingRMWInst(Call, Opcode, MIRBuilder, GR);
1500   default:
1501     return false;
1502   }
1503 }
1504 
generateBarrierInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1505 static bool generateBarrierInst(const SPIRV::IncomingCall *Call,
1506                                 MachineIRBuilder &MIRBuilder,
1507                                 SPIRVGlobalRegistry *GR) {
1508   // Lookup the instruction opcode in the TableGen records.
1509   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1510   unsigned Opcode =
1511       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1512 
1513   return buildBarrierInst(Call, Opcode, MIRBuilder, GR);
1514 }
1515 
generateCastToPtrInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder)1516 static bool generateCastToPtrInst(const SPIRV::IncomingCall *Call,
1517                                   MachineIRBuilder &MIRBuilder) {
1518   MIRBuilder.buildInstr(TargetOpcode::G_ADDRSPACE_CAST)
1519       .addDef(Call->ReturnRegister)
1520       .addUse(Call->Arguments[0]);
1521   return true;
1522 }
1523 
generateDotOrFMulInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1524 static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call,
1525                                   MachineIRBuilder &MIRBuilder,
1526                                   SPIRVGlobalRegistry *GR) {
1527   if (Call->isSpirvOp())
1528     return buildOpFromWrapper(MIRBuilder, SPIRV::OpDot, Call,
1529                               GR->getSPIRVTypeID(Call->ReturnType));
1530   unsigned Opcode = GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode();
1531   bool IsVec = Opcode == SPIRV::OpTypeVector;
1532   // Use OpDot only in case of vector args and OpFMul in case of scalar args.
1533   MIRBuilder.buildInstr(IsVec ? SPIRV::OpDot : SPIRV::OpFMulS)
1534       .addDef(Call->ReturnRegister)
1535       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1536       .addUse(Call->Arguments[0])
1537       .addUse(Call->Arguments[1]);
1538   return true;
1539 }
1540 
generateWaveInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1541 static bool generateWaveInst(const SPIRV::IncomingCall *Call,
1542                              MachineIRBuilder &MIRBuilder,
1543                              SPIRVGlobalRegistry *GR) {
1544   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1545   SPIRV::BuiltIn::BuiltIn Value =
1546       SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;
1547 
1548   // For now, we only support a single Wave intrinsic with a single return type.
1549   assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt);
1550   LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(Call->ReturnType));
1551 
1552   return buildBuiltinVariableLoad(
1553       MIRBuilder, Call->ReturnType, GR, Value, LLType, Call->ReturnRegister,
1554       /* isConst= */ false, /* hasLinkageTy= */ false);
1555 }
1556 
generateGetQueryInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1557 static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
1558                                  MachineIRBuilder &MIRBuilder,
1559                                  SPIRVGlobalRegistry *GR) {
1560   // Lookup the builtin record.
1561   SPIRV::BuiltIn::BuiltIn Value =
1562       SPIRV::lookupGetBuiltin(Call->Builtin->Name, Call->Builtin->Set)->Value;
1563   uint64_t IsDefault = (Value == SPIRV::BuiltIn::GlobalSize ||
1564                         Value == SPIRV::BuiltIn::WorkgroupSize ||
1565                         Value == SPIRV::BuiltIn::EnqueuedWorkgroupSize);
1566   return genWorkgroupQuery(Call, MIRBuilder, GR, Value, IsDefault ? 1 : 0);
1567 }
1568 
generateImageSizeQueryInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1569 static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
1570                                        MachineIRBuilder &MIRBuilder,
1571                                        SPIRVGlobalRegistry *GR) {
1572   // Lookup the image size query component number in the TableGen records.
1573   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1574   uint32_t Component =
1575       SPIRV::lookupImageQueryBuiltin(Builtin->Name, Builtin->Set)->Component;
1576   // Query result may either be a vector or a scalar. If return type is not a
1577   // vector, expect only a single size component. Otherwise get the number of
1578   // expected components.
1579   SPIRVType *RetTy = Call->ReturnType;
1580   unsigned NumExpectedRetComponents = RetTy->getOpcode() == SPIRV::OpTypeVector
1581                                           ? RetTy->getOperand(2).getImm()
1582                                           : 1;
1583   // Get the actual number of query result/size components.
1584   SPIRVType *ImgType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
1585   unsigned NumActualRetComponents = getNumSizeComponents(ImgType);
1586   Register QueryResult = Call->ReturnRegister;
1587   SPIRVType *QueryResultType = Call->ReturnType;
1588   if (NumExpectedRetComponents != NumActualRetComponents) {
1589     QueryResult = MIRBuilder.getMRI()->createGenericVirtualRegister(
1590         LLT::fixed_vector(NumActualRetComponents, 32));
1591     MIRBuilder.getMRI()->setRegClass(QueryResult, &SPIRV::IDRegClass);
1592     SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
1593     QueryResultType = GR->getOrCreateSPIRVVectorType(
1594         IntTy, NumActualRetComponents, MIRBuilder);
1595     GR->assignSPIRVTypeToVReg(QueryResultType, QueryResult, MIRBuilder.getMF());
1596   }
1597   bool IsDimBuf = ImgType->getOperand(2).getImm() == SPIRV::Dim::DIM_Buffer;
1598   unsigned Opcode =
1599       IsDimBuf ? SPIRV::OpImageQuerySize : SPIRV::OpImageQuerySizeLod;
1600   MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
1601   auto MIB = MIRBuilder.buildInstr(Opcode)
1602                  .addDef(QueryResult)
1603                  .addUse(GR->getSPIRVTypeID(QueryResultType))
1604                  .addUse(Call->Arguments[0]);
1605   if (!IsDimBuf)
1606     MIB.addUse(buildConstantIntReg(0, MIRBuilder, GR)); // Lod id.
1607   if (NumExpectedRetComponents == NumActualRetComponents)
1608     return true;
1609   if (NumExpectedRetComponents == 1) {
1610     // Only 1 component is expected, build OpCompositeExtract instruction.
1611     unsigned ExtractedComposite =
1612         Component == 3 ? NumActualRetComponents - 1 : Component;
1613     assert(ExtractedComposite < NumActualRetComponents &&
1614            "Invalid composite index!");
1615     Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
1616     SPIRVType *NewType = nullptr;
1617     if (QueryResultType->getOpcode() == SPIRV::OpTypeVector) {
1618       Register NewTypeReg = QueryResultType->getOperand(1).getReg();
1619       if (TypeReg != NewTypeReg &&
1620           (NewType = GR->getSPIRVTypeForVReg(NewTypeReg)) != nullptr)
1621         TypeReg = NewTypeReg;
1622     }
1623     MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
1624         .addDef(Call->ReturnRegister)
1625         .addUse(TypeReg)
1626         .addUse(QueryResult)
1627         .addImm(ExtractedComposite);
1628     if (NewType != nullptr)
1629       insertAssignInstr(Call->ReturnRegister, nullptr, NewType, GR, MIRBuilder,
1630                         MIRBuilder.getMF().getRegInfo());
1631   } else {
1632     // More than 1 component is expected, fill a new vector.
1633     auto MIB = MIRBuilder.buildInstr(SPIRV::OpVectorShuffle)
1634                    .addDef(Call->ReturnRegister)
1635                    .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1636                    .addUse(QueryResult)
1637                    .addUse(QueryResult);
1638     for (unsigned i = 0; i < NumExpectedRetComponents; ++i)
1639       MIB.addImm(i < NumActualRetComponents ? i : 0xffffffff);
1640   }
1641   return true;
1642 }
1643 
generateImageMiscQueryInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1644 static bool generateImageMiscQueryInst(const SPIRV::IncomingCall *Call,
1645                                        MachineIRBuilder &MIRBuilder,
1646                                        SPIRVGlobalRegistry *GR) {
1647   assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt &&
1648          "Image samples query result must be of int type!");
1649 
1650   // Lookup the instruction opcode in the TableGen records.
1651   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1652   unsigned Opcode =
1653       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1654 
1655   Register Image = Call->Arguments[0];
1656   MIRBuilder.getMRI()->setRegClass(Image, &SPIRV::IDRegClass);
1657   SPIRV::Dim::Dim ImageDimensionality = static_cast<SPIRV::Dim::Dim>(
1658       GR->getSPIRVTypeForVReg(Image)->getOperand(2).getImm());
1659   (void)ImageDimensionality;
1660 
1661   switch (Opcode) {
1662   case SPIRV::OpImageQuerySamples:
1663     assert(ImageDimensionality == SPIRV::Dim::DIM_2D &&
1664            "Image must be of 2D dimensionality");
1665     break;
1666   case SPIRV::OpImageQueryLevels:
1667     assert((ImageDimensionality == SPIRV::Dim::DIM_1D ||
1668             ImageDimensionality == SPIRV::Dim::DIM_2D ||
1669             ImageDimensionality == SPIRV::Dim::DIM_3D ||
1670             ImageDimensionality == SPIRV::Dim::DIM_Cube) &&
1671            "Image must be of 1D/2D/3D/Cube dimensionality");
1672     break;
1673   }
1674 
1675   MIRBuilder.buildInstr(Opcode)
1676       .addDef(Call->ReturnRegister)
1677       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1678       .addUse(Image);
1679   return true;
1680 }
1681 
1682 // TODO: Move to TableGen.
1683 static SPIRV::SamplerAddressingMode::SamplerAddressingMode
getSamplerAddressingModeFromBitmask(unsigned Bitmask)1684 getSamplerAddressingModeFromBitmask(unsigned Bitmask) {
1685   switch (Bitmask & SPIRV::CLK_ADDRESS_MODE_MASK) {
1686   case SPIRV::CLK_ADDRESS_CLAMP:
1687     return SPIRV::SamplerAddressingMode::Clamp;
1688   case SPIRV::CLK_ADDRESS_CLAMP_TO_EDGE:
1689     return SPIRV::SamplerAddressingMode::ClampToEdge;
1690   case SPIRV::CLK_ADDRESS_REPEAT:
1691     return SPIRV::SamplerAddressingMode::Repeat;
1692   case SPIRV::CLK_ADDRESS_MIRRORED_REPEAT:
1693     return SPIRV::SamplerAddressingMode::RepeatMirrored;
1694   case SPIRV::CLK_ADDRESS_NONE:
1695     return SPIRV::SamplerAddressingMode::None;
1696   default:
1697     report_fatal_error("Unknown CL address mode");
1698   }
1699 }
1700 
getSamplerParamFromBitmask(unsigned Bitmask)1701 static unsigned getSamplerParamFromBitmask(unsigned Bitmask) {
1702   return (Bitmask & SPIRV::CLK_NORMALIZED_COORDS_TRUE) ? 1 : 0;
1703 }
1704 
1705 static SPIRV::SamplerFilterMode::SamplerFilterMode
getSamplerFilterModeFromBitmask(unsigned Bitmask)1706 getSamplerFilterModeFromBitmask(unsigned Bitmask) {
1707   if (Bitmask & SPIRV::CLK_FILTER_LINEAR)
1708     return SPIRV::SamplerFilterMode::Linear;
1709   if (Bitmask & SPIRV::CLK_FILTER_NEAREST)
1710     return SPIRV::SamplerFilterMode::Nearest;
1711   return SPIRV::SamplerFilterMode::Nearest;
1712 }
1713 
generateReadImageInst(const StringRef DemangledCall,const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1714 static bool generateReadImageInst(const StringRef DemangledCall,
1715                                   const SPIRV::IncomingCall *Call,
1716                                   MachineIRBuilder &MIRBuilder,
1717                                   SPIRVGlobalRegistry *GR) {
1718   Register Image = Call->Arguments[0];
1719   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1720   MRI->setRegClass(Image, &SPIRV::IDRegClass);
1721   MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
1722   bool HasOclSampler = DemangledCall.contains_insensitive("ocl_sampler");
1723   bool HasMsaa = DemangledCall.contains_insensitive("msaa");
1724   if (HasOclSampler || HasMsaa)
1725     MRI->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
1726   if (HasOclSampler) {
1727     Register Sampler = Call->Arguments[1];
1728 
1729     if (!GR->isScalarOfType(Sampler, SPIRV::OpTypeSampler) &&
1730         getDefInstrMaybeConstant(Sampler, MRI)->getOperand(1).isCImm()) {
1731       uint64_t SamplerMask = getIConstVal(Sampler, MRI);
1732       Sampler = GR->buildConstantSampler(
1733           Register(), getSamplerAddressingModeFromBitmask(SamplerMask),
1734           getSamplerParamFromBitmask(SamplerMask),
1735           getSamplerFilterModeFromBitmask(SamplerMask), MIRBuilder,
1736           GR->getSPIRVTypeForVReg(Sampler));
1737     }
1738     SPIRVType *ImageType = GR->getSPIRVTypeForVReg(Image);
1739     SPIRVType *SampledImageType =
1740         GR->getOrCreateOpTypeSampledImage(ImageType, MIRBuilder);
1741     Register SampledImage = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1742 
1743     MIRBuilder.buildInstr(SPIRV::OpSampledImage)
1744         .addDef(SampledImage)
1745         .addUse(GR->getSPIRVTypeID(SampledImageType))
1746         .addUse(Image)
1747         .addUse(Sampler);
1748 
1749     Register Lod = GR->buildConstantFP(APFloat::getZero(APFloat::IEEEsingle()),
1750                                        MIRBuilder);
1751     SPIRVType *TempType = Call->ReturnType;
1752     bool NeedsExtraction = false;
1753     if (TempType->getOpcode() != SPIRV::OpTypeVector) {
1754       TempType =
1755           GR->getOrCreateSPIRVVectorType(Call->ReturnType, 4, MIRBuilder);
1756       NeedsExtraction = true;
1757     }
1758     LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(TempType));
1759     Register TempRegister = MRI->createGenericVirtualRegister(LLType);
1760     MRI->setRegClass(TempRegister, &SPIRV::IDRegClass);
1761     GR->assignSPIRVTypeToVReg(TempType, TempRegister, MIRBuilder.getMF());
1762 
1763     MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
1764         .addDef(NeedsExtraction ? TempRegister : Call->ReturnRegister)
1765         .addUse(GR->getSPIRVTypeID(TempType))
1766         .addUse(SampledImage)
1767         .addUse(Call->Arguments[2]) // Coordinate.
1768         .addImm(SPIRV::ImageOperand::Lod)
1769         .addUse(Lod);
1770 
1771     if (NeedsExtraction)
1772       MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
1773           .addDef(Call->ReturnRegister)
1774           .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1775           .addUse(TempRegister)
1776           .addImm(0);
1777   } else if (HasMsaa) {
1778     MIRBuilder.buildInstr(SPIRV::OpImageRead)
1779         .addDef(Call->ReturnRegister)
1780         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1781         .addUse(Image)
1782         .addUse(Call->Arguments[1]) // Coordinate.
1783         .addImm(SPIRV::ImageOperand::Sample)
1784         .addUse(Call->Arguments[2]);
1785   } else {
1786     MIRBuilder.buildInstr(SPIRV::OpImageRead)
1787         .addDef(Call->ReturnRegister)
1788         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1789         .addUse(Image)
1790         .addUse(Call->Arguments[1]); // Coordinate.
1791   }
1792   return true;
1793 }
1794 
generateWriteImageInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1795 static bool generateWriteImageInst(const SPIRV::IncomingCall *Call,
1796                                    MachineIRBuilder &MIRBuilder,
1797                                    SPIRVGlobalRegistry *GR) {
1798   MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
1799   MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
1800   MIRBuilder.getMRI()->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
1801   MIRBuilder.buildInstr(SPIRV::OpImageWrite)
1802       .addUse(Call->Arguments[0])  // Image.
1803       .addUse(Call->Arguments[1])  // Coordinate.
1804       .addUse(Call->Arguments[2]); // Texel.
1805   return true;
1806 }
1807 
generateSampleImageInst(const StringRef DemangledCall,const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1808 static bool generateSampleImageInst(const StringRef DemangledCall,
1809                                     const SPIRV::IncomingCall *Call,
1810                                     MachineIRBuilder &MIRBuilder,
1811                                     SPIRVGlobalRegistry *GR) {
1812   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1813   if (Call->Builtin->Name.contains_insensitive(
1814           "__translate_sampler_initializer")) {
1815     // Build sampler literal.
1816     uint64_t Bitmask = getIConstVal(Call->Arguments[0], MRI);
1817     Register Sampler = GR->buildConstantSampler(
1818         Call->ReturnRegister, getSamplerAddressingModeFromBitmask(Bitmask),
1819         getSamplerParamFromBitmask(Bitmask),
1820         getSamplerFilterModeFromBitmask(Bitmask), MIRBuilder, Call->ReturnType);
1821     return Sampler.isValid();
1822   } else if (Call->Builtin->Name.contains_insensitive("__spirv_SampledImage")) {
1823     // Create OpSampledImage.
1824     Register Image = Call->Arguments[0];
1825     SPIRVType *ImageType = GR->getSPIRVTypeForVReg(Image);
1826     SPIRVType *SampledImageType =
1827         GR->getOrCreateOpTypeSampledImage(ImageType, MIRBuilder);
1828     Register SampledImage =
1829         Call->ReturnRegister.isValid()
1830             ? Call->ReturnRegister
1831             : MRI->createVirtualRegister(&SPIRV::IDRegClass);
1832     MIRBuilder.buildInstr(SPIRV::OpSampledImage)
1833         .addDef(SampledImage)
1834         .addUse(GR->getSPIRVTypeID(SampledImageType))
1835         .addUse(Image)
1836         .addUse(Call->Arguments[1]); // Sampler.
1837     return true;
1838   } else if (Call->Builtin->Name.contains_insensitive(
1839                  "__spirv_ImageSampleExplicitLod")) {
1840     // Sample an image using an explicit level of detail.
1841     std::string ReturnType = DemangledCall.str();
1842     if (DemangledCall.contains("_R")) {
1843       ReturnType = ReturnType.substr(ReturnType.find("_R") + 2);
1844       ReturnType = ReturnType.substr(0, ReturnType.find('('));
1845     }
1846     SPIRVType *Type =
1847         Call->ReturnType
1848             ? Call->ReturnType
1849             : GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
1850     if (!Type) {
1851       std::string DiagMsg =
1852           "Unable to recognize SPIRV type name: " + ReturnType;
1853       report_fatal_error(DiagMsg.c_str());
1854     }
1855     MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
1856     MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
1857     MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
1858 
1859     MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
1860         .addDef(Call->ReturnRegister)
1861         .addUse(GR->getSPIRVTypeID(Type))
1862         .addUse(Call->Arguments[0]) // Image.
1863         .addUse(Call->Arguments[1]) // Coordinate.
1864         .addImm(SPIRV::ImageOperand::Lod)
1865         .addUse(Call->Arguments[3]);
1866     return true;
1867   }
1868   return false;
1869 }
1870 
generateSelectInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder)1871 static bool generateSelectInst(const SPIRV::IncomingCall *Call,
1872                                MachineIRBuilder &MIRBuilder) {
1873   MIRBuilder.buildSelect(Call->ReturnRegister, Call->Arguments[0],
1874                          Call->Arguments[1], Call->Arguments[2]);
1875   return true;
1876 }
1877 
generateConstructInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1878 static bool generateConstructInst(const SPIRV::IncomingCall *Call,
1879                                   MachineIRBuilder &MIRBuilder,
1880                                   SPIRVGlobalRegistry *GR) {
1881   return buildOpFromWrapper(MIRBuilder, SPIRV::OpCompositeConstruct, Call,
1882                             GR->getSPIRVTypeID(Call->ReturnType));
1883 }
1884 
generateCoopMatrInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1885 static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
1886                                  MachineIRBuilder &MIRBuilder,
1887                                  SPIRVGlobalRegistry *GR) {
1888   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1889   unsigned Opcode =
1890       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1891   bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR;
1892   unsigned ArgSz = Call->Arguments.size();
1893   unsigned LiteralIdx = 0;
1894   if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3)
1895     LiteralIdx = 3;
1896   else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4)
1897     LiteralIdx = 4;
1898   SmallVector<uint32_t, 1> ImmArgs;
1899   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1900   if (LiteralIdx > 0)
1901     ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI));
1902   Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
1903   if (Opcode == SPIRV::OpCooperativeMatrixLengthKHR) {
1904     SPIRVType *CoopMatrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
1905     if (!CoopMatrType)
1906       report_fatal_error("Can't find a register's type definition");
1907     MIRBuilder.buildInstr(Opcode)
1908         .addDef(Call->ReturnRegister)
1909         .addUse(TypeReg)
1910         .addUse(CoopMatrType->getOperand(0).getReg());
1911     return true;
1912   }
1913   return buildOpFromWrapper(MIRBuilder, Opcode, Call,
1914                             IsSet ? TypeReg : Register(0), ImmArgs);
1915 }
1916 
generateSpecConstantInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1917 static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
1918                                      MachineIRBuilder &MIRBuilder,
1919                                      SPIRVGlobalRegistry *GR) {
1920   // Lookup the instruction opcode in the TableGen records.
1921   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1922   unsigned Opcode =
1923       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1924   const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1925 
1926   switch (Opcode) {
1927   case SPIRV::OpSpecConstant: {
1928     // Build the SpecID decoration.
1929     unsigned SpecId =
1930         static_cast<unsigned>(getIConstVal(Call->Arguments[0], MRI));
1931     buildOpDecorate(Call->ReturnRegister, MIRBuilder, SPIRV::Decoration::SpecId,
1932                     {SpecId});
1933     // Determine the constant MI.
1934     Register ConstRegister = Call->Arguments[1];
1935     const MachineInstr *Const = getDefInstrMaybeConstant(ConstRegister, MRI);
1936     assert(Const &&
1937            (Const->getOpcode() == TargetOpcode::G_CONSTANT ||
1938             Const->getOpcode() == TargetOpcode::G_FCONSTANT) &&
1939            "Argument should be either an int or floating-point constant");
1940     // Determine the opcode and built the OpSpec MI.
1941     const MachineOperand &ConstOperand = Const->getOperand(1);
1942     if (Call->ReturnType->getOpcode() == SPIRV::OpTypeBool) {
1943       assert(ConstOperand.isCImm() && "Int constant operand is expected");
1944       Opcode = ConstOperand.getCImm()->getValue().getZExtValue()
1945                    ? SPIRV::OpSpecConstantTrue
1946                    : SPIRV::OpSpecConstantFalse;
1947     }
1948     auto MIB = MIRBuilder.buildInstr(Opcode)
1949                    .addDef(Call->ReturnRegister)
1950                    .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1951 
1952     if (Call->ReturnType->getOpcode() != SPIRV::OpTypeBool) {
1953       if (Const->getOpcode() == TargetOpcode::G_CONSTANT)
1954         addNumImm(ConstOperand.getCImm()->getValue(), MIB);
1955       else
1956         addNumImm(ConstOperand.getFPImm()->getValueAPF().bitcastToAPInt(), MIB);
1957     }
1958     return true;
1959   }
1960   case SPIRV::OpSpecConstantComposite: {
1961     auto MIB = MIRBuilder.buildInstr(Opcode)
1962                    .addDef(Call->ReturnRegister)
1963                    .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1964     for (unsigned i = 0; i < Call->Arguments.size(); i++)
1965       MIB.addUse(Call->Arguments[i]);
1966     return true;
1967   }
1968   default:
1969     return false;
1970   }
1971 }
1972 
buildNDRange(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)1973 static bool buildNDRange(const SPIRV::IncomingCall *Call,
1974                          MachineIRBuilder &MIRBuilder,
1975                          SPIRVGlobalRegistry *GR) {
1976   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1977   MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
1978   SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
1979   assert(PtrType->getOpcode() == SPIRV::OpTypePointer &&
1980          PtrType->getOperand(2).isReg());
1981   Register TypeReg = PtrType->getOperand(2).getReg();
1982   SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg);
1983   MachineFunction &MF = MIRBuilder.getMF();
1984   Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1985   GR->assignSPIRVTypeToVReg(StructType, TmpReg, MF);
1986   // Skip the first arg, it's the destination pointer. OpBuildNDRange takes
1987   // three other arguments, so pass zero constant on absence.
1988   unsigned NumArgs = Call->Arguments.size();
1989   assert(NumArgs >= 2);
1990   Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2];
1991   MRI->setRegClass(GlobalWorkSize, &SPIRV::IDRegClass);
1992   Register LocalWorkSize =
1993       NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3];
1994   if (LocalWorkSize.isValid())
1995     MRI->setRegClass(LocalWorkSize, &SPIRV::IDRegClass);
1996   Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1];
1997   if (GlobalWorkOffset.isValid())
1998     MRI->setRegClass(GlobalWorkOffset, &SPIRV::IDRegClass);
1999   if (NumArgs < 4) {
2000     Register Const;
2001     SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize);
2002     if (SpvTy->getOpcode() == SPIRV::OpTypePointer) {
2003       MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize);
2004       assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) &&
2005              DefInstr->getOperand(3).isReg());
2006       Register GWSPtr = DefInstr->getOperand(3).getReg();
2007       if (!MRI->getRegClassOrNull(GWSPtr))
2008         MRI->setRegClass(GWSPtr, &SPIRV::IDRegClass);
2009       // TODO: Maybe simplify generation of the type of the fields.
2010       unsigned Size = Call->Builtin->Name == "ndrange_3D" ? 3 : 2;
2011       unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
2012       Type *BaseTy = IntegerType::get(MF.getFunction().getContext(), BitWidth);
2013       Type *FieldTy = ArrayType::get(BaseTy, Size);
2014       SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder);
2015       GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass);
2016       GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize, MF);
2017       MIRBuilder.buildInstr(SPIRV::OpLoad)
2018           .addDef(GlobalWorkSize)
2019           .addUse(GR->getSPIRVTypeID(SpvFieldTy))
2020           .addUse(GWSPtr);
2021       const SPIRVSubtarget &ST =
2022           cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
2023       Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
2024                                            SpvFieldTy, *ST.getInstrInfo());
2025     } else {
2026       Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
2027     }
2028     if (!LocalWorkSize.isValid())
2029       LocalWorkSize = Const;
2030     if (!GlobalWorkOffset.isValid())
2031       GlobalWorkOffset = Const;
2032   }
2033   assert(LocalWorkSize.isValid() && GlobalWorkOffset.isValid());
2034   MIRBuilder.buildInstr(SPIRV::OpBuildNDRange)
2035       .addDef(TmpReg)
2036       .addUse(TypeReg)
2037       .addUse(GlobalWorkSize)
2038       .addUse(LocalWorkSize)
2039       .addUse(GlobalWorkOffset);
2040   return MIRBuilder.buildInstr(SPIRV::OpStore)
2041       .addUse(Call->Arguments[0])
2042       .addUse(TmpReg);
2043 }
2044 
2045 // TODO: maybe move to the global register.
2046 static SPIRVType *
getOrCreateSPIRVDeviceEventPointer(MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2047 getOrCreateSPIRVDeviceEventPointer(MachineIRBuilder &MIRBuilder,
2048                                    SPIRVGlobalRegistry *GR) {
2049   LLVMContext &Context = MIRBuilder.getMF().getFunction().getContext();
2050   Type *OpaqueType = StructType::getTypeByName(Context, "spirv.DeviceEvent");
2051   if (!OpaqueType)
2052     OpaqueType = StructType::getTypeByName(Context, "opencl.clk_event_t");
2053   if (!OpaqueType)
2054     OpaqueType = StructType::create(Context, "spirv.DeviceEvent");
2055   unsigned SC0 = storageClassToAddressSpace(SPIRV::StorageClass::Function);
2056   unsigned SC1 = storageClassToAddressSpace(SPIRV::StorageClass::Generic);
2057   Type *PtrType = PointerType::get(PointerType::get(OpaqueType, SC0), SC1);
2058   return GR->getOrCreateSPIRVType(PtrType, MIRBuilder);
2059 }
2060 
buildEnqueueKernel(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2061 static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
2062                                MachineIRBuilder &MIRBuilder,
2063                                SPIRVGlobalRegistry *GR) {
2064   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2065   const DataLayout &DL = MIRBuilder.getDataLayout();
2066   bool IsSpirvOp = Call->isSpirvOp();
2067   bool HasEvents = Call->Builtin->Name.contains("events") || IsSpirvOp;
2068   const SPIRVType *Int32Ty = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
2069 
2070   // Make vararg instructions before OpEnqueueKernel.
2071   // Local sizes arguments: Sizes of block invoke arguments. Clang generates
2072   // local size operands as an array, so we need to unpack them.
2073   SmallVector<Register, 16> LocalSizes;
2074   if (Call->Builtin->Name.contains("_varargs") || IsSpirvOp) {
2075     const unsigned LocalSizeArrayIdx = HasEvents ? 9 : 6;
2076     Register GepReg = Call->Arguments[LocalSizeArrayIdx];
2077     MachineInstr *GepMI = MRI->getUniqueVRegDef(GepReg);
2078     assert(isSpvIntrinsic(*GepMI, Intrinsic::spv_gep) &&
2079            GepMI->getOperand(3).isReg());
2080     Register ArrayReg = GepMI->getOperand(3).getReg();
2081     MachineInstr *ArrayMI = MRI->getUniqueVRegDef(ArrayReg);
2082     const Type *LocalSizeTy = getMachineInstrType(ArrayMI);
2083     assert(LocalSizeTy && "Local size type is expected");
2084     const uint64_t LocalSizeNum =
2085         cast<ArrayType>(LocalSizeTy)->getNumElements();
2086     unsigned SC = storageClassToAddressSpace(SPIRV::StorageClass::Generic);
2087     const LLT LLType = LLT::pointer(SC, GR->getPointerSize());
2088     const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
2089         Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
2090     for (unsigned I = 0; I < LocalSizeNum; ++I) {
2091       Register Reg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
2092       MRI->setType(Reg, LLType);
2093       GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
2094       auto GEPInst = MIRBuilder.buildIntrinsic(
2095           Intrinsic::spv_gep, ArrayRef<Register>{Reg}, true, false);
2096       GEPInst
2097           .addImm(GepMI->getOperand(2).getImm())          // In bound.
2098           .addUse(ArrayMI->getOperand(0).getReg())        // Alloca.
2099           .addUse(buildConstantIntReg(0, MIRBuilder, GR)) // Indices.
2100           .addUse(buildConstantIntReg(I, MIRBuilder, GR));
2101       LocalSizes.push_back(Reg);
2102     }
2103   }
2104 
2105   // SPIRV OpEnqueueKernel instruction has 10+ arguments.
2106   auto MIB = MIRBuilder.buildInstr(SPIRV::OpEnqueueKernel)
2107                  .addDef(Call->ReturnRegister)
2108                  .addUse(GR->getSPIRVTypeID(Int32Ty));
2109 
2110   // Copy all arguments before block invoke function pointer.
2111   const unsigned BlockFIdx = HasEvents ? 6 : 3;
2112   for (unsigned i = 0; i < BlockFIdx; i++)
2113     MIB.addUse(Call->Arguments[i]);
2114 
2115   // If there are no event arguments in the original call, add dummy ones.
2116   if (!HasEvents) {
2117     MIB.addUse(buildConstantIntReg(0, MIRBuilder, GR)); // Dummy num events.
2118     Register NullPtr = GR->getOrCreateConstNullPtr(
2119         MIRBuilder, getOrCreateSPIRVDeviceEventPointer(MIRBuilder, GR));
2120     MIB.addUse(NullPtr); // Dummy wait events.
2121     MIB.addUse(NullPtr); // Dummy ret event.
2122   }
2123 
2124   MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
2125   assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
2126   // Invoke: Pointer to invoke function.
2127   MIB.addGlobalAddress(BlockMI->getOperand(1).getGlobal());
2128 
2129   Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
2130   // Param: Pointer to block literal.
2131   MIB.addUse(BlockLiteralReg);
2132 
2133   Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
2134   // TODO: these numbers should be obtained from block literal structure.
2135   // Param Size: Size of block literal structure.
2136   MIB.addUse(buildConstantIntReg(DL.getTypeStoreSize(PType), MIRBuilder, GR));
2137   // Param Aligment: Aligment of block literal structure.
2138   MIB.addUse(
2139       buildConstantIntReg(DL.getPrefTypeAlign(PType).value(), MIRBuilder, GR));
2140 
2141   for (unsigned i = 0; i < LocalSizes.size(); i++)
2142     MIB.addUse(LocalSizes[i]);
2143   return true;
2144 }
2145 
generateEnqueueInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2146 static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
2147                                 MachineIRBuilder &MIRBuilder,
2148                                 SPIRVGlobalRegistry *GR) {
2149   // Lookup the instruction opcode in the TableGen records.
2150   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2151   unsigned Opcode =
2152       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2153 
2154   switch (Opcode) {
2155   case SPIRV::OpRetainEvent:
2156   case SPIRV::OpReleaseEvent:
2157     MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
2158     return MIRBuilder.buildInstr(Opcode).addUse(Call->Arguments[0]);
2159   case SPIRV::OpCreateUserEvent:
2160   case SPIRV::OpGetDefaultQueue:
2161     return MIRBuilder.buildInstr(Opcode)
2162         .addDef(Call->ReturnRegister)
2163         .addUse(GR->getSPIRVTypeID(Call->ReturnType));
2164   case SPIRV::OpIsValidEvent:
2165     MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
2166     return MIRBuilder.buildInstr(Opcode)
2167         .addDef(Call->ReturnRegister)
2168         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2169         .addUse(Call->Arguments[0]);
2170   case SPIRV::OpSetUserEventStatus:
2171     MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
2172     MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
2173     return MIRBuilder.buildInstr(Opcode)
2174         .addUse(Call->Arguments[0])
2175         .addUse(Call->Arguments[1]);
2176   case SPIRV::OpCaptureEventProfilingInfo:
2177     MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
2178     MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
2179     MIRBuilder.getMRI()->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
2180     return MIRBuilder.buildInstr(Opcode)
2181         .addUse(Call->Arguments[0])
2182         .addUse(Call->Arguments[1])
2183         .addUse(Call->Arguments[2]);
2184   case SPIRV::OpBuildNDRange:
2185     return buildNDRange(Call, MIRBuilder, GR);
2186   case SPIRV::OpEnqueueKernel:
2187     return buildEnqueueKernel(Call, MIRBuilder, GR);
2188   default:
2189     return false;
2190   }
2191 }
2192 
generateAsyncCopy(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2193 static bool generateAsyncCopy(const SPIRV::IncomingCall *Call,
2194                               MachineIRBuilder &MIRBuilder,
2195                               SPIRVGlobalRegistry *GR) {
2196   // Lookup the instruction opcode in the TableGen records.
2197   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2198   unsigned Opcode =
2199       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2200 
2201   bool IsSet = Opcode == SPIRV::OpGroupAsyncCopy;
2202   Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
2203   if (Call->isSpirvOp())
2204     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
2205                               IsSet ? TypeReg : Register(0));
2206 
2207   auto Scope = buildConstantIntReg(SPIRV::Scope::Workgroup, MIRBuilder, GR);
2208 
2209   switch (Opcode) {
2210   case SPIRV::OpGroupAsyncCopy: {
2211     SPIRVType *NewType =
2212         Call->ReturnType->getOpcode() == SPIRV::OpTypeEvent
2213             ? nullptr
2214             : GR->getOrCreateSPIRVTypeByName("spirv.Event", MIRBuilder);
2215     Register TypeReg = GR->getSPIRVTypeID(NewType ? NewType : Call->ReturnType);
2216     unsigned NumArgs = Call->Arguments.size();
2217     Register EventReg = Call->Arguments[NumArgs - 1];
2218     bool Res = MIRBuilder.buildInstr(Opcode)
2219                    .addDef(Call->ReturnRegister)
2220                    .addUse(TypeReg)
2221                    .addUse(Scope)
2222                    .addUse(Call->Arguments[0])
2223                    .addUse(Call->Arguments[1])
2224                    .addUse(Call->Arguments[2])
2225                    .addUse(Call->Arguments.size() > 4
2226                                ? Call->Arguments[3]
2227                                : buildConstantIntReg(1, MIRBuilder, GR))
2228                    .addUse(EventReg);
2229     if (NewType != nullptr)
2230       insertAssignInstr(Call->ReturnRegister, nullptr, NewType, GR, MIRBuilder,
2231                         MIRBuilder.getMF().getRegInfo());
2232     return Res;
2233   }
2234   case SPIRV::OpGroupWaitEvents:
2235     return MIRBuilder.buildInstr(Opcode)
2236         .addUse(Scope)
2237         .addUse(Call->Arguments[0])
2238         .addUse(Call->Arguments[1]);
2239   default:
2240     return false;
2241   }
2242 }
2243 
generateConvertInst(const StringRef DemangledCall,const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2244 static bool generateConvertInst(const StringRef DemangledCall,
2245                                 const SPIRV::IncomingCall *Call,
2246                                 MachineIRBuilder &MIRBuilder,
2247                                 SPIRVGlobalRegistry *GR) {
2248   // Lookup the conversion builtin in the TableGen records.
2249   const SPIRV::ConvertBuiltin *Builtin =
2250       SPIRV::lookupConvertBuiltin(Call->Builtin->Name, Call->Builtin->Set);
2251 
2252   if (!Builtin && Call->isSpirvOp()) {
2253     const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2254     unsigned Opcode =
2255         SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2256     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
2257                               GR->getSPIRVTypeID(Call->ReturnType));
2258   }
2259 
2260   if (Builtin->IsSaturated)
2261     buildOpDecorate(Call->ReturnRegister, MIRBuilder,
2262                     SPIRV::Decoration::SaturatedConversion, {});
2263   if (Builtin->IsRounded)
2264     buildOpDecorate(Call->ReturnRegister, MIRBuilder,
2265                     SPIRV::Decoration::FPRoundingMode,
2266                     {(unsigned)Builtin->RoundingMode});
2267 
2268   std::string NeedExtMsg;              // no errors if empty
2269   bool IsRightComponentsNumber = true; // check if input/output accepts vectors
2270   unsigned Opcode = SPIRV::OpNop;
2271   if (GR->isScalarOrVectorOfType(Call->Arguments[0], SPIRV::OpTypeInt)) {
2272     // Int -> ...
2273     if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) {
2274       // Int -> Int
2275       if (Builtin->IsSaturated)
2276         Opcode = Builtin->IsDestinationSigned ? SPIRV::OpSatConvertUToS
2277                                               : SPIRV::OpSatConvertSToU;
2278       else
2279         Opcode = Builtin->IsDestinationSigned ? SPIRV::OpUConvert
2280                                               : SPIRV::OpSConvert;
2281     } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
2282                                           SPIRV::OpTypeFloat)) {
2283       // Int -> Float
2284       if (Builtin->IsBfloat16) {
2285         const auto *ST = static_cast<const SPIRVSubtarget *>(
2286             &MIRBuilder.getMF().getSubtarget());
2287         if (!ST->canUseExtension(
2288                 SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
2289           NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
2290         IsRightComponentsNumber =
2291             GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
2292             GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
2293         Opcode = SPIRV::OpConvertBF16ToFINTEL;
2294       } else {
2295         bool IsSourceSigned =
2296             DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
2297         Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
2298       }
2299     }
2300   } else if (GR->isScalarOrVectorOfType(Call->Arguments[0],
2301                                         SPIRV::OpTypeFloat)) {
2302     // Float -> ...
2303     if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) {
2304       // Float -> Int
2305       if (Builtin->IsBfloat16) {
2306         const auto *ST = static_cast<const SPIRVSubtarget *>(
2307             &MIRBuilder.getMF().getSubtarget());
2308         if (!ST->canUseExtension(
2309                 SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
2310           NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
2311         IsRightComponentsNumber =
2312             GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
2313             GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
2314         Opcode = SPIRV::OpConvertFToBF16INTEL;
2315       } else {
2316         Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
2317                                               : SPIRV::OpConvertFToU;
2318       }
2319     } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
2320                                           SPIRV::OpTypeFloat)) {
2321       // Float -> Float
2322       Opcode = SPIRV::OpFConvert;
2323     }
2324   }
2325 
2326   if (!NeedExtMsg.empty()) {
2327     std::string DiagMsg = std::string(Builtin->Name) +
2328                           ": the builtin requires the following SPIR-V "
2329                           "extension: " +
2330                           NeedExtMsg;
2331     report_fatal_error(DiagMsg.c_str(), false);
2332   }
2333   if (!IsRightComponentsNumber) {
2334     std::string DiagMsg =
2335         std::string(Builtin->Name) +
2336         ": result and argument must have the same number of components";
2337     report_fatal_error(DiagMsg.c_str(), false);
2338   }
2339   assert(Opcode != SPIRV::OpNop &&
2340          "Conversion between the types not implemented!");
2341 
2342   MIRBuilder.buildInstr(Opcode)
2343       .addDef(Call->ReturnRegister)
2344       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2345       .addUse(Call->Arguments[0]);
2346   return true;
2347 }
2348 
generateVectorLoadStoreInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2349 static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
2350                                         MachineIRBuilder &MIRBuilder,
2351                                         SPIRVGlobalRegistry *GR) {
2352   // Lookup the vector load/store builtin in the TableGen records.
2353   const SPIRV::VectorLoadStoreBuiltin *Builtin =
2354       SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
2355                                           Call->Builtin->Set);
2356   // Build extended instruction.
2357   auto MIB =
2358       MIRBuilder.buildInstr(SPIRV::OpExtInst)
2359           .addDef(Call->ReturnRegister)
2360           .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2361           .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
2362           .addImm(Builtin->Number);
2363   for (auto Argument : Call->Arguments)
2364     MIB.addUse(Argument);
2365   if (Builtin->Name.contains("load") && Builtin->ElementCount > 1)
2366     MIB.addImm(Builtin->ElementCount);
2367 
2368   // Rounding mode should be passed as a last argument in the MI for builtins
2369   // like "vstorea_halfn_r".
2370   if (Builtin->IsRounded)
2371     MIB.addImm(static_cast<uint32_t>(Builtin->RoundingMode));
2372   return true;
2373 }
2374 
generateLoadStoreInst(const SPIRV::IncomingCall * Call,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2375 static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
2376                                   MachineIRBuilder &MIRBuilder,
2377                                   SPIRVGlobalRegistry *GR) {
2378   // Lookup the instruction opcode in the TableGen records.
2379   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2380   unsigned Opcode =
2381       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2382   bool IsLoad = Opcode == SPIRV::OpLoad;
2383   // Build the instruction.
2384   auto MIB = MIRBuilder.buildInstr(Opcode);
2385   if (IsLoad) {
2386     MIB.addDef(Call->ReturnRegister);
2387     MIB.addUse(GR->getSPIRVTypeID(Call->ReturnType));
2388   }
2389   // Add a pointer to the value to load/store.
2390   MIB.addUse(Call->Arguments[0]);
2391   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2392   MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
2393   // Add a value to store.
2394   if (!IsLoad) {
2395     MIB.addUse(Call->Arguments[1]);
2396     MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
2397   }
2398   // Add optional memory attributes and an alignment.
2399   unsigned NumArgs = Call->Arguments.size();
2400   if ((IsLoad && NumArgs >= 2) || NumArgs >= 3) {
2401     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 1 : 2], MRI));
2402     MRI->setRegClass(Call->Arguments[IsLoad ? 1 : 2], &SPIRV::IDRegClass);
2403   }
2404   if ((IsLoad && NumArgs >= 3) || NumArgs >= 4) {
2405     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 2 : 3], MRI));
2406     MRI->setRegClass(Call->Arguments[IsLoad ? 2 : 3], &SPIRV::IDRegClass);
2407   }
2408   return true;
2409 }
2410 
2411 namespace SPIRV {
2412 // Try to find a builtin function attributes by a demangled function name and
2413 // return a tuple <builtin group, op code, ext instruction number>, or a special
2414 // tuple value <-1, 0, 0> if the builtin function is not found.
2415 // Not all builtin functions are supported, only those with a ready-to-use op
2416 // code or instruction number defined in TableGen.
2417 // TODO: consider a major rework of mapping demangled calls into a builtin
2418 // functions to unify search and decrease number of individual cases.
2419 std::tuple<int, unsigned, unsigned>
mapBuiltinToOpcode(const StringRef DemangledCall,SPIRV::InstructionSet::InstructionSet Set)2420 mapBuiltinToOpcode(const StringRef DemangledCall,
2421                    SPIRV::InstructionSet::InstructionSet Set) {
2422   Register Reg;
2423   SmallVector<Register> Args;
2424   std::unique_ptr<const IncomingCall> Call =
2425       lookupBuiltin(DemangledCall, Set, Reg, nullptr, Args);
2426   if (!Call)
2427     return std::make_tuple(-1, 0, 0);
2428 
2429   switch (Call->Builtin->Group) {
2430   case SPIRV::Relational:
2431   case SPIRV::Atomic:
2432   case SPIRV::Barrier:
2433   case SPIRV::CastToPtr:
2434   case SPIRV::ImageMiscQuery:
2435   case SPIRV::SpecConstant:
2436   case SPIRV::Enqueue:
2437   case SPIRV::AsyncCopy:
2438   case SPIRV::LoadStore:
2439   case SPIRV::CoopMatr:
2440     if (const auto *R =
2441             SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
2442       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2443     break;
2444   case SPIRV::Extended:
2445     if (const auto *R = SPIRV::lookupExtendedBuiltin(Call->Builtin->Name,
2446                                                      Call->Builtin->Set))
2447       return std::make_tuple(Call->Builtin->Group, 0, R->Number);
2448     break;
2449   case SPIRV::VectorLoadStore:
2450     if (const auto *R = SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
2451                                                             Call->Builtin->Set))
2452       return std::make_tuple(SPIRV::Extended, 0, R->Number);
2453     break;
2454   case SPIRV::Group:
2455     if (const auto *R = SPIRV::lookupGroupBuiltin(Call->Builtin->Name))
2456       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2457     break;
2458   case SPIRV::AtomicFloating:
2459     if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Call->Builtin->Name))
2460       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2461     break;
2462   case SPIRV::IntelSubgroups:
2463     if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Call->Builtin->Name))
2464       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2465     break;
2466   case SPIRV::GroupUniform:
2467     if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Call->Builtin->Name))
2468       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2469     break;
2470   case SPIRV::WriteImage:
2471     return std::make_tuple(Call->Builtin->Group, SPIRV::OpImageWrite, 0);
2472   case SPIRV::Select:
2473     return std::make_tuple(Call->Builtin->Group, TargetOpcode::G_SELECT, 0);
2474   case SPIRV::Construct:
2475     return std::make_tuple(Call->Builtin->Group, SPIRV::OpCompositeConstruct,
2476                            0);
2477   case SPIRV::KernelClock:
2478     return std::make_tuple(Call->Builtin->Group, SPIRV::OpReadClockKHR, 0);
2479   default:
2480     return std::make_tuple(-1, 0, 0);
2481   }
2482   return std::make_tuple(-1, 0, 0);
2483 }
2484 
lowerBuiltin(const StringRef DemangledCall,SPIRV::InstructionSet::InstructionSet Set,MachineIRBuilder & MIRBuilder,const Register OrigRet,const Type * OrigRetTy,const SmallVectorImpl<Register> & Args,SPIRVGlobalRegistry * GR)2485 std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
2486                                  SPIRV::InstructionSet::InstructionSet Set,
2487                                  MachineIRBuilder &MIRBuilder,
2488                                  const Register OrigRet, const Type *OrigRetTy,
2489                                  const SmallVectorImpl<Register> &Args,
2490                                  SPIRVGlobalRegistry *GR) {
2491   LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
2492 
2493   // SPIR-V type and return register.
2494   Register ReturnRegister = OrigRet;
2495   SPIRVType *ReturnType = nullptr;
2496   if (OrigRetTy && !OrigRetTy->isVoidTy()) {
2497     ReturnType = GR->assignTypeToVReg(OrigRetTy, OrigRet, MIRBuilder);
2498     if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
2499       MIRBuilder.getMRI()->setRegClass(ReturnRegister, &SPIRV::IDRegClass);
2500   } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
2501     ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
2502     MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(32));
2503     ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
2504   }
2505 
2506   // Lookup the builtin in the TableGen records.
2507   std::unique_ptr<const IncomingCall> Call =
2508       lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
2509 
2510   if (!Call) {
2511     LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");
2512     return std::nullopt;
2513   }
2514 
2515   // TODO: check if the provided args meet the builtin requirments.
2516   assert(Args.size() >= Call->Builtin->MinNumArgs &&
2517          "Too few arguments to generate the builtin");
2518   if (Call->Builtin->MaxNumArgs && Args.size() > Call->Builtin->MaxNumArgs)
2519     LLVM_DEBUG(dbgs() << "More arguments provided than required!\n");
2520 
2521   // Match the builtin with implementation based on the grouping.
2522   switch (Call->Builtin->Group) {
2523   case SPIRV::Extended:
2524     return generateExtInst(Call.get(), MIRBuilder, GR);
2525   case SPIRV::Relational:
2526     return generateRelationalInst(Call.get(), MIRBuilder, GR);
2527   case SPIRV::Group:
2528     return generateGroupInst(Call.get(), MIRBuilder, GR);
2529   case SPIRV::Variable:
2530     return generateBuiltinVar(Call.get(), MIRBuilder, GR);
2531   case SPIRV::Atomic:
2532     return generateAtomicInst(Call.get(), MIRBuilder, GR);
2533   case SPIRV::AtomicFloating:
2534     return generateAtomicFloatingInst(Call.get(), MIRBuilder, GR);
2535   case SPIRV::Barrier:
2536     return generateBarrierInst(Call.get(), MIRBuilder, GR);
2537   case SPIRV::CastToPtr:
2538     return generateCastToPtrInst(Call.get(), MIRBuilder);
2539   case SPIRV::Dot:
2540     return generateDotOrFMulInst(Call.get(), MIRBuilder, GR);
2541   case SPIRV::Wave:
2542     return generateWaveInst(Call.get(), MIRBuilder, GR);
2543   case SPIRV::GetQuery:
2544     return generateGetQueryInst(Call.get(), MIRBuilder, GR);
2545   case SPIRV::ImageSizeQuery:
2546     return generateImageSizeQueryInst(Call.get(), MIRBuilder, GR);
2547   case SPIRV::ImageMiscQuery:
2548     return generateImageMiscQueryInst(Call.get(), MIRBuilder, GR);
2549   case SPIRV::ReadImage:
2550     return generateReadImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
2551   case SPIRV::WriteImage:
2552     return generateWriteImageInst(Call.get(), MIRBuilder, GR);
2553   case SPIRV::SampleImage:
2554     return generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
2555   case SPIRV::Select:
2556     return generateSelectInst(Call.get(), MIRBuilder);
2557   case SPIRV::Construct:
2558     return generateConstructInst(Call.get(), MIRBuilder, GR);
2559   case SPIRV::SpecConstant:
2560     return generateSpecConstantInst(Call.get(), MIRBuilder, GR);
2561   case SPIRV::Enqueue:
2562     return generateEnqueueInst(Call.get(), MIRBuilder, GR);
2563   case SPIRV::AsyncCopy:
2564     return generateAsyncCopy(Call.get(), MIRBuilder, GR);
2565   case SPIRV::Convert:
2566     return generateConvertInst(DemangledCall, Call.get(), MIRBuilder, GR);
2567   case SPIRV::VectorLoadStore:
2568     return generateVectorLoadStoreInst(Call.get(), MIRBuilder, GR);
2569   case SPIRV::LoadStore:
2570     return generateLoadStoreInst(Call.get(), MIRBuilder, GR);
2571   case SPIRV::IntelSubgroups:
2572     return generateIntelSubgroupsInst(Call.get(), MIRBuilder, GR);
2573   case SPIRV::GroupUniform:
2574     return generateGroupUniformInst(Call.get(), MIRBuilder, GR);
2575   case SPIRV::KernelClock:
2576     return generateKernelClockInst(Call.get(), MIRBuilder, GR);
2577   case SPIRV::CoopMatr:
2578     return generateCoopMatrInst(Call.get(), MIRBuilder, GR);
2579   }
2580   return false;
2581 }
2582 
parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,unsigned ArgIdx,LLVMContext & Ctx)2583 Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
2584                                        unsigned ArgIdx, LLVMContext &Ctx) {
2585   SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
2586   StringRef BuiltinArgs =
2587       DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
2588   BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
2589   if (ArgIdx >= BuiltinArgsTypeStrs.size())
2590     return nullptr;
2591   StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
2592 
2593   // Parse strings representing OpenCL builtin types.
2594   if (hasBuiltinTypePrefix(TypeStr)) {
2595     // OpenCL builtin types in demangled call strings have the following format:
2596     // e.g. ocl_image2d_ro
2597     [[maybe_unused]] bool IsOCLBuiltinType = TypeStr.consume_front("ocl_");
2598     assert(IsOCLBuiltinType && "Invalid OpenCL builtin prefix");
2599 
2600     // Check if this is pointer to a builtin type and not just pointer
2601     // representing a builtin type. In case it is a pointer to builtin type,
2602     // this will require additional handling in the method calling
2603     // parseBuiltinCallArgumentBaseType(...) as this function only retrieves the
2604     // base types.
2605     if (TypeStr.ends_with("*"))
2606       TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" *"));
2607 
2608     return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
2609                                                Ctx);
2610   }
2611 
2612   // Parse type name in either "typeN" or "type vector[N]" format, where
2613   // N is the number of elements of the vector.
2614   Type *BaseType;
2615   unsigned VecElts = 0;
2616 
2617   BaseType = parseBasicTypeName(TypeStr, Ctx);
2618   if (!BaseType)
2619     // Unable to recognize SPIRV type name.
2620     return nullptr;
2621 
2622   // Handle "typeN*" or "type vector[N]*".
2623   TypeStr.consume_back("*");
2624 
2625   if (TypeStr.consume_front(" vector["))
2626     TypeStr = TypeStr.substr(0, TypeStr.find(']'));
2627 
2628   TypeStr.getAsInteger(10, VecElts);
2629   if (VecElts > 0)
2630     BaseType = VectorType::get(
2631         BaseType->isVoidTy() ? Type::getInt8Ty(Ctx) : BaseType, VecElts, false);
2632 
2633   return BaseType;
2634 }
2635 
2636 struct BuiltinType {
2637   StringRef Name;
2638   uint32_t Opcode;
2639 };
2640 
2641 #define GET_BuiltinTypes_DECL
2642 #define GET_BuiltinTypes_IMPL
2643 
2644 struct OpenCLType {
2645   StringRef Name;
2646   StringRef SpirvTypeLiteral;
2647 };
2648 
2649 #define GET_OpenCLTypes_DECL
2650 #define GET_OpenCLTypes_IMPL
2651 
2652 #include "SPIRVGenTables.inc"
2653 } // namespace SPIRV
2654 
2655 //===----------------------------------------------------------------------===//
2656 // Misc functions for parsing builtin types.
2657 //===----------------------------------------------------------------------===//
2658 
parseTypeString(const StringRef Name,LLVMContext & Context)2659 static Type *parseTypeString(const StringRef Name, LLVMContext &Context) {
2660   if (Name.starts_with("void"))
2661     return Type::getVoidTy(Context);
2662   else if (Name.starts_with("int") || Name.starts_with("uint"))
2663     return Type::getInt32Ty(Context);
2664   else if (Name.starts_with("float"))
2665     return Type::getFloatTy(Context);
2666   else if (Name.starts_with("half"))
2667     return Type::getHalfTy(Context);
2668   report_fatal_error("Unable to recognize type!");
2669 }
2670 
2671 //===----------------------------------------------------------------------===//
2672 // Implementation functions for builtin types.
2673 //===----------------------------------------------------------------------===//
2674 
getNonParameterizedType(const TargetExtType * ExtensionType,const SPIRV::BuiltinType * TypeRecord,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2675 static SPIRVType *getNonParameterizedType(const TargetExtType *ExtensionType,
2676                                           const SPIRV::BuiltinType *TypeRecord,
2677                                           MachineIRBuilder &MIRBuilder,
2678                                           SPIRVGlobalRegistry *GR) {
2679   unsigned Opcode = TypeRecord->Opcode;
2680   // Create or get an existing type from GlobalRegistry.
2681   return GR->getOrCreateOpTypeByOpcode(ExtensionType, MIRBuilder, Opcode);
2682 }
2683 
getSamplerType(MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2684 static SPIRVType *getSamplerType(MachineIRBuilder &MIRBuilder,
2685                                  SPIRVGlobalRegistry *GR) {
2686   // Create or get an existing type from GlobalRegistry.
2687   return GR->getOrCreateOpTypeSampler(MIRBuilder);
2688 }
2689 
getPipeType(const TargetExtType * ExtensionType,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2690 static SPIRVType *getPipeType(const TargetExtType *ExtensionType,
2691                               MachineIRBuilder &MIRBuilder,
2692                               SPIRVGlobalRegistry *GR) {
2693   assert(ExtensionType->getNumIntParameters() == 1 &&
2694          "Invalid number of parameters for SPIR-V pipe builtin!");
2695   // Create or get an existing type from GlobalRegistry.
2696   return GR->getOrCreateOpTypePipe(MIRBuilder,
2697                                    SPIRV::AccessQualifier::AccessQualifier(
2698                                        ExtensionType->getIntParameter(0)));
2699 }
2700 
getCoopMatrType(const TargetExtType * ExtensionType,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2701 static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
2702                                   MachineIRBuilder &MIRBuilder,
2703                                   SPIRVGlobalRegistry *GR) {
2704   assert(ExtensionType->getNumIntParameters() == 4 &&
2705          "Invalid number of parameters for SPIR-V coop matrices builtin!");
2706   assert(ExtensionType->getNumTypeParameters() == 1 &&
2707          "SPIR-V coop matrices builtin type must have a type parameter!");
2708   const SPIRVType *ElemType =
2709       GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
2710   // Create or get an existing type from GlobalRegistry.
2711   return GR->getOrCreateOpTypeCoopMatr(
2712       MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0),
2713       ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
2714       ExtensionType->getIntParameter(3));
2715 }
2716 
2717 static SPIRVType *
getImageType(const TargetExtType * ExtensionType,const SPIRV::AccessQualifier::AccessQualifier Qualifier,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2718 getImageType(const TargetExtType *ExtensionType,
2719              const SPIRV::AccessQualifier::AccessQualifier Qualifier,
2720              MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) {
2721   assert(ExtensionType->getNumTypeParameters() == 1 &&
2722          "SPIR-V image builtin type must have sampled type parameter!");
2723   const SPIRVType *SampledType =
2724       GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
2725   assert(ExtensionType->getNumIntParameters() == 7 &&
2726          "Invalid number of parameters for SPIR-V image builtin!");
2727   // Create or get an existing type from GlobalRegistry.
2728   return GR->getOrCreateOpTypeImage(
2729       MIRBuilder, SampledType,
2730       SPIRV::Dim::Dim(ExtensionType->getIntParameter(0)),
2731       ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
2732       ExtensionType->getIntParameter(3), ExtensionType->getIntParameter(4),
2733       SPIRV::ImageFormat::ImageFormat(ExtensionType->getIntParameter(5)),
2734       Qualifier == SPIRV::AccessQualifier::WriteOnly
2735           ? SPIRV::AccessQualifier::WriteOnly
2736           : SPIRV::AccessQualifier::AccessQualifier(
2737                 ExtensionType->getIntParameter(6)));
2738 }
2739 
getSampledImageType(const TargetExtType * OpaqueType,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2740 static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
2741                                       MachineIRBuilder &MIRBuilder,
2742                                       SPIRVGlobalRegistry *GR) {
2743   SPIRVType *OpaqueImageType = getImageType(
2744       OpaqueType, SPIRV::AccessQualifier::ReadOnly, MIRBuilder, GR);
2745   // Create or get an existing type from GlobalRegistry.
2746   return GR->getOrCreateOpTypeSampledImage(OpaqueImageType, MIRBuilder);
2747 }
2748 
2749 namespace SPIRV {
parseBuiltinTypeNameToTargetExtType(std::string TypeName,LLVMContext & Context)2750 TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
2751                                                    LLVMContext &Context) {
2752   StringRef NameWithParameters = TypeName;
2753 
2754   // Pointers-to-opaque-structs representing OpenCL types are first translated
2755   // to equivalent SPIR-V types. OpenCL builtin type names should have the
2756   // following format: e.g. %opencl.event_t
2757   if (NameWithParameters.starts_with("opencl.")) {
2758     const SPIRV::OpenCLType *OCLTypeRecord =
2759         SPIRV::lookupOpenCLType(NameWithParameters);
2760     if (!OCLTypeRecord)
2761       report_fatal_error("Missing TableGen record for OpenCL type: " +
2762                          NameWithParameters);
2763     NameWithParameters = OCLTypeRecord->SpirvTypeLiteral;
2764     // Continue with the SPIR-V builtin type...
2765   }
2766 
2767   // Names of the opaque structs representing a SPIR-V builtins without
2768   // parameters should have the following format: e.g. %spirv.Event
2769   assert(NameWithParameters.starts_with("spirv.") &&
2770          "Unknown builtin opaque type!");
2771 
2772   // Parameterized SPIR-V builtins names follow this format:
2773   // e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
2774   if (!NameWithParameters.contains('_'))
2775     return TargetExtType::get(Context, NameWithParameters);
2776 
2777   SmallVector<StringRef> Parameters;
2778   unsigned BaseNameLength = NameWithParameters.find('_') - 1;
2779   SplitString(NameWithParameters.substr(BaseNameLength + 1), Parameters, "_");
2780 
2781   SmallVector<Type *, 1> TypeParameters;
2782   bool HasTypeParameter = !isDigit(Parameters[0][0]);
2783   if (HasTypeParameter)
2784     TypeParameters.push_back(parseTypeString(Parameters[0], Context));
2785   SmallVector<unsigned> IntParameters;
2786   for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
2787     unsigned IntParameter = 0;
2788     bool ValidLiteral = !Parameters[i].getAsInteger(10, IntParameter);
2789     (void)ValidLiteral;
2790     assert(ValidLiteral &&
2791            "Invalid format of SPIR-V builtin parameter literal!");
2792     IntParameters.push_back(IntParameter);
2793   }
2794   return TargetExtType::get(Context,
2795                             NameWithParameters.substr(0, BaseNameLength),
2796                             TypeParameters, IntParameters);
2797 }
2798 
lowerBuiltinType(const Type * OpaqueType,SPIRV::AccessQualifier::AccessQualifier AccessQual,MachineIRBuilder & MIRBuilder,SPIRVGlobalRegistry * GR)2799 SPIRVType *lowerBuiltinType(const Type *OpaqueType,
2800                             SPIRV::AccessQualifier::AccessQualifier AccessQual,
2801                             MachineIRBuilder &MIRBuilder,
2802                             SPIRVGlobalRegistry *GR) {
2803   // In LLVM IR, SPIR-V and OpenCL builtin types are represented as either
2804   // target(...) target extension types or pointers-to-opaque-structs. The
2805   // approach relying on structs is deprecated and works only in the non-opaque
2806   // pointer mode (-opaque-pointers=0).
2807   // In order to maintain compatibility with LLVM IR generated by older versions
2808   // of Clang and LLVM/SPIR-V Translator, the pointers-to-opaque-structs are
2809   // "translated" to target extension types. This translation is temporary and
2810   // will be removed in the future release of LLVM.
2811   const TargetExtType *BuiltinType = dyn_cast<TargetExtType>(OpaqueType);
2812   if (!BuiltinType)
2813     BuiltinType = parseBuiltinTypeNameToTargetExtType(
2814         OpaqueType->getStructName().str(), MIRBuilder.getContext());
2815 
2816   unsigned NumStartingVRegs = MIRBuilder.getMRI()->getNumVirtRegs();
2817 
2818   const StringRef Name = BuiltinType->getName();
2819   LLVM_DEBUG(dbgs() << "Lowering builtin type: " << Name << "\n");
2820 
2821   // Lookup the demangled builtin type in the TableGen records.
2822   const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
2823   if (!TypeRecord)
2824     report_fatal_error("Missing TableGen record for builtin type: " + Name);
2825 
2826   // "Lower" the BuiltinType into TargetType. The following get<...>Type methods
2827   // use the implementation details from TableGen records or TargetExtType
2828   // parameters to either create a new OpType<...> machine instruction or get an
2829   // existing equivalent SPIRVType from GlobalRegistry.
2830   SPIRVType *TargetType;
2831   switch (TypeRecord->Opcode) {
2832   case SPIRV::OpTypeImage:
2833     TargetType = getImageType(BuiltinType, AccessQual, MIRBuilder, GR);
2834     break;
2835   case SPIRV::OpTypePipe:
2836     TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
2837     break;
2838   case SPIRV::OpTypeDeviceEvent:
2839     TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
2840     break;
2841   case SPIRV::OpTypeSampler:
2842     TargetType = getSamplerType(MIRBuilder, GR);
2843     break;
2844   case SPIRV::OpTypeSampledImage:
2845     TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
2846     break;
2847   case SPIRV::OpTypeCooperativeMatrixKHR:
2848     TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
2849     break;
2850   default:
2851     TargetType =
2852         getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
2853     break;
2854   }
2855 
2856   // Emit OpName instruction if a new OpType<...> instruction was added
2857   // (equivalent type was not found in GlobalRegistry).
2858   if (NumStartingVRegs < MIRBuilder.getMRI()->getNumVirtRegs())
2859     buildOpName(GR->getSPIRVTypeID(TargetType), Name, MIRBuilder);
2860 
2861   return TargetType;
2862 }
2863 } // namespace SPIRV
2864 } // namespace llvm
2865