xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- SPIRVBuiltins.cpp - SPIR-V Built-in Functions ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering builtin function calls and types using their
10 // demangled names and TableGen records.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "SPIRVBuiltins.h"
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Analysis/ValueTracking.h"
20 #include "llvm/IR/IntrinsicsSPIRV.h"
21 #include <regex>
22 #include <string>
23 #include <tuple>
24 
25 #define DEBUG_TYPE "spirv-builtins"
26 
27 namespace llvm {
28 namespace SPIRV {
29 #define GET_BuiltinGroup_DECL
30 #include "SPIRVGenTables.inc"
31 
32 struct DemangledBuiltin {
33   StringRef Name;
34   InstructionSet::InstructionSet Set;
35   BuiltinGroup Group;
36   uint8_t MinNumArgs;
37   uint8_t MaxNumArgs;
38 };
39 
40 #define GET_DemangledBuiltins_DECL
41 #define GET_DemangledBuiltins_IMPL
42 
43 struct IncomingCall {
44   const std::string BuiltinName;
45   const DemangledBuiltin *Builtin;
46 
47   const Register ReturnRegister;
48   const SPIRVType *ReturnType;
49   const SmallVectorImpl<Register> &Arguments;
50 
51   IncomingCall(const std::string BuiltinName, const DemangledBuiltin *Builtin,
52                const Register ReturnRegister, const SPIRVType *ReturnType,
53                const SmallVectorImpl<Register> &Arguments)
54       : BuiltinName(BuiltinName), Builtin(Builtin),
55         ReturnRegister(ReturnRegister), ReturnType(ReturnType),
56         Arguments(Arguments) {}
57 
58   bool isSpirvOp() const { return BuiltinName.rfind("__spirv_", 0) == 0; }
59 };
60 
61 struct NativeBuiltin {
62   StringRef Name;
63   InstructionSet::InstructionSet Set;
64   uint32_t Opcode;
65 };
66 
67 #define GET_NativeBuiltins_DECL
68 #define GET_NativeBuiltins_IMPL
69 
70 struct GroupBuiltin {
71   StringRef Name;
72   uint32_t Opcode;
73   uint32_t GroupOperation;
74   bool IsElect;
75   bool IsAllOrAny;
76   bool IsAllEqual;
77   bool IsBallot;
78   bool IsInverseBallot;
79   bool IsBallotBitExtract;
80   bool IsBallotFindBit;
81   bool IsLogical;
82   bool NoGroupOperation;
83   bool HasBoolArg;
84 };
85 
86 #define GET_GroupBuiltins_DECL
87 #define GET_GroupBuiltins_IMPL
88 
89 struct IntelSubgroupsBuiltin {
90   StringRef Name;
91   uint32_t Opcode;
92   bool IsBlock;
93   bool IsWrite;
94   bool IsMedia;
95 };
96 
97 #define GET_IntelSubgroupsBuiltins_DECL
98 #define GET_IntelSubgroupsBuiltins_IMPL
99 
100 struct AtomicFloatingBuiltin {
101   StringRef Name;
102   uint32_t Opcode;
103 };
104 
105 #define GET_AtomicFloatingBuiltins_DECL
106 #define GET_AtomicFloatingBuiltins_IMPL
107 struct GroupUniformBuiltin {
108   StringRef Name;
109   uint32_t Opcode;
110   bool IsLogical;
111 };
112 
113 #define GET_GroupUniformBuiltins_DECL
114 #define GET_GroupUniformBuiltins_IMPL
115 
116 struct GetBuiltin {
117   StringRef Name;
118   InstructionSet::InstructionSet Set;
119   BuiltIn::BuiltIn Value;
120 };
121 
122 using namespace BuiltIn;
123 #define GET_GetBuiltins_DECL
124 #define GET_GetBuiltins_IMPL
125 
126 struct ImageQueryBuiltin {
127   StringRef Name;
128   InstructionSet::InstructionSet Set;
129   uint32_t Component;
130 };
131 
132 #define GET_ImageQueryBuiltins_DECL
133 #define GET_ImageQueryBuiltins_IMPL
134 
135 struct IntegerDotProductBuiltin {
136   StringRef Name;
137   uint32_t Opcode;
138   bool IsSwapReq;
139 };
140 
141 #define GET_IntegerDotProductBuiltins_DECL
142 #define GET_IntegerDotProductBuiltins_IMPL
143 
144 struct ConvertBuiltin {
145   StringRef Name;
146   InstructionSet::InstructionSet Set;
147   bool IsDestinationSigned;
148   bool IsSaturated;
149   bool IsRounded;
150   bool IsBfloat16;
151   FPRoundingMode::FPRoundingMode RoundingMode;
152 };
153 
154 struct VectorLoadStoreBuiltin {
155   StringRef Name;
156   InstructionSet::InstructionSet Set;
157   uint32_t Number;
158   uint32_t ElementCount;
159   bool IsRounded;
160   FPRoundingMode::FPRoundingMode RoundingMode;
161 };
162 
163 using namespace FPRoundingMode;
164 #define GET_ConvertBuiltins_DECL
165 #define GET_ConvertBuiltins_IMPL
166 
167 using namespace InstructionSet;
168 #define GET_VectorLoadStoreBuiltins_DECL
169 #define GET_VectorLoadStoreBuiltins_IMPL
170 
171 #define GET_CLMemoryScope_DECL
172 #define GET_CLSamplerAddressingMode_DECL
173 #define GET_CLMemoryFenceFlags_DECL
174 #define GET_ExtendedBuiltins_DECL
175 #include "SPIRVGenTables.inc"
176 } // namespace SPIRV
177 
178 //===----------------------------------------------------------------------===//
179 // Misc functions for looking up builtins and verifying requirements using
180 // TableGen records
181 //===----------------------------------------------------------------------===//
182 
183 namespace SPIRV {
184 /// Parses the name part of the demangled builtin call.
185 std::string lookupBuiltinNameHelper(StringRef DemangledCall,
186                                     FPDecorationId *DecorationId) {
187   StringRef PassPrefix = "(anonymous namespace)::";
188   std::string BuiltinName;
189   // Itanium Demangler result may have "(anonymous namespace)::" prefix
190   if (DemangledCall.starts_with(PassPrefix))
191     BuiltinName = DemangledCall.substr(PassPrefix.size());
192   else
193     BuiltinName = DemangledCall;
194   // Extract the builtin function name and types of arguments from the call
195   // skeleton.
196   BuiltinName = BuiltinName.substr(0, BuiltinName.find('('));
197 
198   // Account for possible "__spirv_ocl_" prefix in SPIR-V friendly LLVM IR
199   if (BuiltinName.rfind("__spirv_ocl_", 0) == 0)
200     BuiltinName = BuiltinName.substr(12);
201 
202   // Check if the extracted name contains type information between angle
203   // brackets. If so, the builtin is an instantiated template, and the
204   // template arguments and the return type need to be removed.
205   std::size_t Pos1 = BuiltinName.rfind('<');
206   if (Pos1 != std::string::npos && BuiltinName.back() == '>') {
207     std::size_t Pos2 = BuiltinName.rfind(' ', Pos1);
208     if (Pos2 == std::string::npos)
209       Pos2 = 0;
210     else
211       ++Pos2;
212     BuiltinName = BuiltinName.substr(Pos2, Pos1 - Pos2);
213     BuiltinName = BuiltinName.substr(BuiltinName.find_last_of(' ') + 1);
214   }
215 
216   // Check if the extracted name begins with:
217   // - "__spirv_ImageSampleExplicitLod"
218   // - "__spirv_ImageRead"
219   // - "__spirv_ImageWrite"
220   // - "__spirv_ImageQuerySizeLod"
221   // - "__spirv_UDotKHR"
222   // - "__spirv_SDotKHR"
223   // - "__spirv_SUDotKHR"
224   // - "__spirv_SDotAccSatKHR"
225   // - "__spirv_UDotAccSatKHR"
226   // - "__spirv_SUDotAccSatKHR"
227   // - "__spirv_ReadClockKHR"
228   // - "__spirv_SubgroupBlockReadINTEL"
229   // - "__spirv_SubgroupImageBlockReadINTEL"
230   // - "__spirv_SubgroupImageMediaBlockReadINTEL"
231   // - "__spirv_SubgroupImageMediaBlockWriteINTEL"
232   // - "__spirv_Convert"
233   // - "__spirv_UConvert"
234   // - "__spirv_SConvert"
235   // - "__spirv_FConvert"
236   // - "__spirv_SatConvert"
237   // and maybe contains return type information at the end "_R<type>".
238   // If so, extract the plain builtin name without the type information.
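  // For instance, a demangled name such as "__spirv_FConvert_Rhalf_rtz" is
  // expected to match the regex below with Match[1] == "__spirv_FConvert"
  // (the plain builtin name) and Match[4] == "rtz", which is then mapped to
  // a rounding-mode decoration via demangledPostfixToDecorationId().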
239   static const std::regex SpvWithR(
240       "(__spirv_(ImageSampleExplicitLod|ImageRead|ImageWrite|ImageQuerySizeLod|"
241       "UDotKHR|"
242       "SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|"
243       "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
244       "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
245       "Convert|"
246       "UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?");
247   std::smatch Match;
248   if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) {
249     std::ssub_match SubMatch;
250     if (DecorationId && Match.size() > 3) {
251       SubMatch = Match[4];
252       *DecorationId = demangledPostfixToDecorationId(SubMatch.str());
253     }
254     SubMatch = Match[1];
255     BuiltinName = SubMatch.str();
256   }
257 
258   return BuiltinName;
259 }
260 } // namespace SPIRV
261 
262 /// Looks up the demangled builtin call in the SPIRVBuiltins.td records using
263 /// the provided \p DemangledCall and specified \p Set.
264 ///
265 /// The lookup follows the following algorithm, returning the first successful
266 /// match:
267 /// 1. Search with the plain demangled name (expecting a 1:1 match).
268 /// 2. Search with the prefix before or suffix after the demangled name
269 /// signifying the type of the first argument.
270 ///
271 /// \returns Wrapper around the demangled call and found builtin definition.
272 static std::unique_ptr<const SPIRV::IncomingCall>
273 lookupBuiltin(StringRef DemangledCall,
274               SPIRV::InstructionSet::InstructionSet Set,
275               Register ReturnRegister, const SPIRVType *ReturnType,
276               const SmallVectorImpl<Register> &Arguments) {
277   std::string BuiltinName = SPIRV::lookupBuiltinNameHelper(DemangledCall);
278 
279   SmallVector<StringRef, 10> BuiltinArgumentTypes;
280   StringRef BuiltinArgs =
281       DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
282   BuiltinArgs.split(BuiltinArgumentTypes, ',', -1, false);
283 
284   // Look up the builtin in the defined set. Start with the plain demangled
285   // name, expecting a 1:1 match in the defined builtin set.
286   const SPIRV::DemangledBuiltin *Builtin;
287   if ((Builtin = SPIRV::lookupBuiltin(BuiltinName, Set)))
288     return std::make_unique<SPIRV::IncomingCall>(
289         BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
290 
291   // If the initial lookup was unsuccessful and the demangled call takes at
292   // least 1 argument, add a prefix or suffix signifying the type of the first
293   // argument and repeat the search.
294   if (BuiltinArgumentTypes.size() >= 1) {
295     char FirstArgumentType = BuiltinArgumentTypes[0][0];
296     // Prefix to be added to the builtin's name for lookup.
297     // For example, OpenCL "abs" taking an unsigned value has a prefix "u_".
298     std::string Prefix;
299 
300     switch (FirstArgumentType) {
301     // Unsigned:
302     case 'u':
303       if (Set == SPIRV::InstructionSet::OpenCL_std)
304         Prefix = "u_";
305       else if (Set == SPIRV::InstructionSet::GLSL_std_450)
306         Prefix = "u";
307       break;
308     // Signed:
309     case 'c':
310     case 's':
311     case 'i':
312     case 'l':
313       if (Set == SPIRV::InstructionSet::OpenCL_std)
314         Prefix = "s_";
315       else if (Set == SPIRV::InstructionSet::GLSL_std_450)
316         Prefix = "s";
317       break;
318     // Floating-point:
319     case 'f':
320     case 'd':
321     case 'h':
322       if (Set == SPIRV::InstructionSet::OpenCL_std ||
323           Set == SPIRV::InstructionSet::GLSL_std_450)
324         Prefix = "f";
325       break;
326     }
327 
328     // If an argument-type name prefix was added, look up the builtin again.
329     if (!Prefix.empty() &&
330         (Builtin = SPIRV::lookupBuiltin(Prefix + BuiltinName, Set)))
331       return std::make_unique<SPIRV::IncomingCall>(
332           BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
333 
334     // If lookup with a prefix failed, find a suffix to be added to the
335     // builtin's name for lookup. For example, OpenCL "group_reduce_max" taking
336     // an unsigned value has a suffix "u".
337     std::string Suffix;
338 
339     switch (FirstArgumentType) {
340     // Unsigned:
341     case 'u':
342       Suffix = "u";
343       break;
344     // Signed:
345     case 'c':
346     case 's':
347     case 'i':
348     case 'l':
349       Suffix = "s";
350       break;
351     // Floating-point:
352     case 'f':
353     case 'd':
354     case 'h':
355       Suffix = "f";
356       break;
357     }
358 
359     // If an argument-type name suffix was added, look up the builtin again.
360     if (!Suffix.empty() &&
361         (Builtin = SPIRV::lookupBuiltin(BuiltinName + Suffix, Set)))
362       return std::make_unique<SPIRV::IncomingCall>(
363           BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
364   }
365 
366   // No builtin with such name was found in the set.
367   return nullptr;
368 }
369 
370 static MachineInstr *getBlockStructInstr(Register ParamReg,
371                                          MachineRegisterInfo *MRI) {
372   // We expect the following sequence of instructions:
373   //   %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
374   //   or       = G_GLOBAL_VALUE @block_literal_global
375   //   %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
376   //   %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
377   MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg);
378   assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
379          MI->getOperand(1).isReg());
380   Register BitcastReg = MI->getOperand(1).getReg();
381   MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg);
382   assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
383          BitcastMI->getOperand(2).isReg());
384   Register ValueReg = BitcastMI->getOperand(2).getReg();
385   MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
386   return ValueMI;
387 }
388 
389 // Return an integer constant corresponding to the given register and
390 // defined in spv_track_constant.
391 // TODO: maybe unify with prelegalizer pass.
392 static unsigned getConstFromIntrinsic(Register Reg, MachineRegisterInfo *MRI) {
393   MachineInstr *DefMI = MRI->getUniqueVRegDef(Reg);
394   assert(DefMI->getOpcode() == TargetOpcode::G_CONSTANT &&
395          DefMI->getOperand(1).isCImm());
396   return DefMI->getOperand(1).getCImm()->getValue().getZExtValue();
397 }
398 
399 // Return type of the instruction result from spv_assign_type intrinsic.
400 // TODO: maybe unify with prelegalizer pass.
401 static const Type *getMachineInstrType(MachineInstr *MI) {
402   MachineInstr *NextMI = MI->getNextNode();
403   if (!NextMI)
404     return nullptr;
405   if (isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_name))
406     if ((NextMI = NextMI->getNextNode()) == nullptr)
407       return nullptr;
408   Register ValueReg = MI->getOperand(0).getReg();
409   if ((!isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_type) &&
410        !isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_ptr_type)) ||
411       NextMI->getOperand(1).getReg() != ValueReg)
412     return nullptr;
413   Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0);
414   assert(Ty && "Type is expected");
415   return Ty;
416 }
417 
418 static const Type *getBlockStructType(Register ParamReg,
419                                       MachineRegisterInfo *MRI) {
420   // In principle, this information should be passed to us from Clang via
421   // an elementtype attribute. However, said attribute requires that
422   // the function call be an intrinsic, which is not. Instead, we rely on being
423   // able to trace this to the declaration of a variable: OpenCL C specification
424   // section 6.12.5 should guarantee that we can do this.
425   MachineInstr *MI = getBlockStructInstr(ParamReg, MRI);
426   if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE)
427     return MI->getOperand(1).getGlobal()->getType();
428   assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) &&
429          "Blocks in OpenCL C must be traceable to allocation site");
430   return getMachineInstrType(MI);
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // Helper functions for building misc instructions
435 //===----------------------------------------------------------------------===//
436 
437 /// Helper function building either a resulting scalar or vector bool register
438 /// depending on the expected \p ResultType.
439 ///
440 /// \returns Tuple of the resulting register and its type.
441 static std::tuple<Register, SPIRVType *>
442 buildBoolRegister(MachineIRBuilder &MIRBuilder, const SPIRVType *ResultType,
443                   SPIRVGlobalRegistry *GR) {
444   LLT Type;
445   SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true);
446 
447   if (ResultType->getOpcode() == SPIRV::OpTypeVector) {
448     unsigned VectorElements = ResultType->getOperand(2).getImm();
449     BoolType = GR->getOrCreateSPIRVVectorType(BoolType, VectorElements,
450                                               MIRBuilder, true);
451     const FixedVectorType *LLVMVectorType =
452         cast<FixedVectorType>(GR->getTypeForSPIRVType(BoolType));
453     Type = LLT::vector(LLVMVectorType->getElementCount(), 1);
454   } else {
455     Type = LLT::scalar(1);
456   }
457 
458   Register ResultRegister =
459       MIRBuilder.getMRI()->createGenericVirtualRegister(Type);
460   MIRBuilder.getMRI()->setRegClass(ResultRegister, GR->getRegClass(ResultType));
461   GR->assignSPIRVTypeToVReg(BoolType, ResultRegister, MIRBuilder.getMF());
462   return std::make_tuple(ResultRegister, BoolType);
463 }
464 
465 /// Helper function for building either a vector or scalar select instruction
466 /// depending on the expected \p ResultType.
467 static bool buildSelectInst(MachineIRBuilder &MIRBuilder,
468                             Register ReturnRegister, Register SourceRegister,
469                             const SPIRVType *ReturnType,
470                             SPIRVGlobalRegistry *GR) {
471   Register TrueConst, FalseConst;
472 
473   if (ReturnType->getOpcode() == SPIRV::OpTypeVector) {
474     unsigned Bits = GR->getScalarOrVectorBitWidth(ReturnType);
475     uint64_t AllOnes = APInt::getAllOnes(Bits).getZExtValue();
476     TrueConst =
477         GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType, true);
478     FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType, true);
479   } else {
480     TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType, true);
481     FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType, true);
482   }
483 
484   return MIRBuilder.buildSelect(ReturnRegister, SourceRegister, TrueConst,
485                                 FalseConst);
486 }
487 
488 /// Helper function for building a load instruction loading into the
489 /// \p DestinationReg.
490 static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
491                               MachineIRBuilder &MIRBuilder,
492                               SPIRVGlobalRegistry *GR, LLT LowLevelType,
493                               Register DestinationReg = Register(0)) {
494   if (!DestinationReg.isValid())
495     DestinationReg = createVirtualRegister(BaseType, GR, MIRBuilder);
496   // TODO: consider using correct address space and alignment (p0 is canonical
497   // type for selection though).
498   MachinePointerInfo PtrInfo = MachinePointerInfo();
499   MIRBuilder.buildLoad(DestinationReg, PtrRegister, PtrInfo, Align());
500   return DestinationReg;
501 }
502 
503 /// Helper function for building a load instruction for loading a builtin global
504 /// variable of \p BuiltinValue value.
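/// The builtin is materialized as a global OpVariable in the Input storage
/// class, with an Import linkage name derived from \p BuiltinValue, and the
/// value loaded from it is returned.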
505 static Register buildBuiltinVariableLoad(
506     MachineIRBuilder &MIRBuilder, SPIRVType *VariableType,
507     SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType,
508     Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) {
509   Register NewRegister =
510       MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::pIDRegClass);
511   MIRBuilder.getMRI()->setType(
512       NewRegister,
513       LLT::pointer(storageClassToAddressSpace(SPIRV::StorageClass::Function),
514                    GR->getPointerSize()));
515   SPIRVType *PtrType = GR->getOrCreateSPIRVPointerType(
516       VariableType, MIRBuilder, SPIRV::StorageClass::Input);
517   GR->assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
518 
519   // Set up the global OpVariable with the necessary builtin decorations.
520   Register Variable = GR->buildGlobalVariable(
521       NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr,
522       SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst,
523       /* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder,
524       false);
525 
526   // Load the value from the global variable.
527   Register LoadedRegister =
528       buildLoadInst(VariableType, Variable, MIRBuilder, GR, LLType, Reg);
529   MIRBuilder.getMRI()->setType(LoadedRegister, LLType);
530   return LoadedRegister;
531 }
532 
533 /// Helper external function for inserting ASSIGN_TYPE instruction between \p Reg
534 /// and its definition, set the new register as a destination of the definition,
535 /// assign SPIRVType to both registers. If SpirvTy is provided, use it as
536 /// SPIRVType in ASSIGN_TYPE, otherwise create it from \p Ty. Defined in
537 /// SPIRVPreLegalizer.cpp.
538 extern void insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
539                               SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
540                               MachineRegisterInfo &MRI);
541 
542 // TODO: Move to TableGen.
543 static SPIRV::MemorySemantics::MemorySemantics
544 getSPIRVMemSemantics(std::memory_order MemOrder) {
545   switch (MemOrder) {
546   case std::memory_order_relaxed:
547     return SPIRV::MemorySemantics::None;
548   case std::memory_order_acquire:
549     return SPIRV::MemorySemantics::Acquire;
550   case std::memory_order_release:
551     return SPIRV::MemorySemantics::Release;
552   case std::memory_order_acq_rel:
553     return SPIRV::MemorySemantics::AcquireRelease;
554   case std::memory_order_seq_cst:
555     return SPIRV::MemorySemantics::SequentiallyConsistent;
556   default:
557     report_fatal_error("Unknown CL memory order");
558   }
559 }
560 
561 static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
562   switch (ClScope) {
563   case SPIRV::CLMemoryScope::memory_scope_work_item:
564     return SPIRV::Scope::Invocation;
565   case SPIRV::CLMemoryScope::memory_scope_work_group:
566     return SPIRV::Scope::Workgroup;
567   case SPIRV::CLMemoryScope::memory_scope_device:
568     return SPIRV::Scope::Device;
569   case SPIRV::CLMemoryScope::memory_scope_all_svm_devices:
570     return SPIRV::Scope::CrossDevice;
571   case SPIRV::CLMemoryScope::memory_scope_sub_group:
572     return SPIRV::Scope::Subgroup;
573   }
574   report_fatal_error("Unknown CL memory scope");
575 }
576 
577 static Register buildConstantIntReg32(uint64_t Val,
578                                       MachineIRBuilder &MIRBuilder,
579                                       SPIRVGlobalRegistry *GR) {
580   return GR->buildConstantInt(
581       Val, MIRBuilder, GR->getOrCreateSPIRVIntegerType(32, MIRBuilder), true);
582 }
583 
584 static Register buildScopeReg(Register CLScopeRegister,
585                               SPIRV::Scope::Scope Scope,
586                               MachineIRBuilder &MIRBuilder,
587                               SPIRVGlobalRegistry *GR,
588                               MachineRegisterInfo *MRI) {
589   if (CLScopeRegister.isValid()) {
590     auto CLScope =
591         static_cast<SPIRV::CLMemoryScope>(getIConstVal(CLScopeRegister, MRI));
592     Scope = getSPIRVScope(CLScope);
593 
594     if (CLScope == static_cast<unsigned>(Scope)) {
595       MRI->setRegClass(CLScopeRegister, &SPIRV::iIDRegClass);
596       return CLScopeRegister;
597     }
598   }
599   return buildConstantIntReg32(Scope, MIRBuilder, GR);
600 }
601 
602 static void setRegClassIfNull(Register Reg, MachineRegisterInfo *MRI,
603                               SPIRVGlobalRegistry *GR) {
604   if (MRI->getRegClassOrNull(Reg))
605     return;
606   SPIRVType *SpvType = GR->getSPIRVTypeForVReg(Reg);
607   MRI->setRegClass(Reg,
608                    SpvType ? GR->getRegClass(SpvType) : &SPIRV::iIDRegClass);
609 }
610 
611 static Register buildMemSemanticsReg(Register SemanticsRegister,
612                                      Register PtrRegister, unsigned &Semantics,
613                                      MachineIRBuilder &MIRBuilder,
614                                      SPIRVGlobalRegistry *GR) {
615   if (SemanticsRegister.isValid()) {
616     MachineRegisterInfo *MRI = MIRBuilder.getMRI();
617     std::memory_order Order =
618         static_cast<std::memory_order>(getIConstVal(SemanticsRegister, MRI));
619     Semantics =
620         getSPIRVMemSemantics(Order) |
621         getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
622     if (static_cast<unsigned>(Order) == Semantics) {
623       MRI->setRegClass(SemanticsRegister, &SPIRV::iIDRegClass);
624       return SemanticsRegister;
625     }
626   }
627   return buildConstantIntReg32(Semantics, MIRBuilder, GR);
628 }
629 
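/// Helper function for lowering direct "__spirv_*" wrapper calls: forwards
/// the call arguments as instruction operands (prefixed by the result
/// register and \p TypeReg when the opcode produces a value) and appends any
/// \p ImmArgs as trailing literal operands in place of the corresponding
/// final arguments.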
630 static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
631                                const SPIRV::IncomingCall *Call,
632                                Register TypeReg,
633                                ArrayRef<uint32_t> ImmArgs = {}) {
634   auto MIB = MIRBuilder.buildInstr(Opcode);
635   if (TypeReg.isValid())
636     MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
637   unsigned Sz = Call->Arguments.size() - ImmArgs.size();
638   for (unsigned i = 0; i < Sz; ++i)
639     MIB.addUse(Call->Arguments[i]);
640   for (uint32_t ImmArg : ImmArgs)
641     MIB.addImm(ImmArg);
642   return true;
643 }
644 
645 /// Helper function for translating atomic init to OpStore.
646 static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
647                                 MachineIRBuilder &MIRBuilder) {
648   if (Call->isSpirvOp())
649     return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call, Register(0));
650 
651   assert(Call->Arguments.size() == 2 &&
652          "Need 2 arguments for atomic init translation");
653   MIRBuilder.buildInstr(SPIRV::OpStore)
654       .addUse(Call->Arguments[0])
655       .addUse(Call->Arguments[1]);
656   return true;
657 }
658 
659 /// Helper function for building an atomic load instruction.
660 static bool buildAtomicLoadInst(const SPIRV::IncomingCall *Call,
661                                 MachineIRBuilder &MIRBuilder,
662                                 SPIRVGlobalRegistry *GR) {
663   Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
664   if (Call->isSpirvOp())
665     return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicLoad, Call, TypeReg);
666 
667   Register PtrRegister = Call->Arguments[0];
668   // TODO: if true insert call to __translate_ocl_memory_scope before
669   // OpAtomicLoad and the function implementation. We can use Translator's
670   // output for transcoding/atomic_explicit_arguments.cl as an example.
671   Register ScopeRegister =
672       Call->Arguments.size() > 1
673           ? Call->Arguments[1]
674           : buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
675   Register MemSemanticsReg;
676   if (Call->Arguments.size() > 2) {
677     // TODO: Insert call to __translate_ocl_memory_order before OpAtomicLoad.
678     MemSemanticsReg = Call->Arguments[2];
679   } else {
680     int Semantics =
681         SPIRV::MemorySemantics::SequentiallyConsistent |
682         getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
683     MemSemanticsReg = buildConstantIntReg32(Semantics, MIRBuilder, GR);
684   }
685 
686   MIRBuilder.buildInstr(SPIRV::OpAtomicLoad)
687       .addDef(Call->ReturnRegister)
688       .addUse(TypeReg)
689       .addUse(PtrRegister)
690       .addUse(ScopeRegister)
691       .addUse(MemSemanticsReg);
692   return true;
693 }
694 
695 /// Helper function for building an atomic store instruction.
696 static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
697                                  MachineIRBuilder &MIRBuilder,
698                                  SPIRVGlobalRegistry *GR) {
699   if (Call->isSpirvOp())
700     return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
701 
702   Register ScopeRegister =
703       buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
704   Register PtrRegister = Call->Arguments[0];
705   int Semantics =
706       SPIRV::MemorySemantics::SequentiallyConsistent |
707       getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
708   Register MemSemanticsReg = buildConstantIntReg32(Semantics, MIRBuilder, GR);
709   MIRBuilder.buildInstr(SPIRV::OpAtomicStore)
710       .addUse(PtrRegister)
711       .addUse(ScopeRegister)
712       .addUse(MemSemanticsReg)
713       .addUse(Call->Arguments[1]);
714   return true;
715 }
716 
717 /// Helper function for building an atomic compare-exchange instruction.
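/// For the OpenCL 2.0 atomic_compare_exchange_* form (i.e. when the builtin
/// name does not contain "cmpxchg"), the expected value is first loaded
/// through the comparator pointer, the value returned by the compare-exchange
/// is stored back through that pointer, and a boolean success flag is
/// returned.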
718 static bool buildAtomicCompareExchangeInst(
719     const SPIRV::IncomingCall *Call, const SPIRV::DemangledBuiltin *Builtin,
720     unsigned Opcode, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) {
721   if (Call->isSpirvOp())
722     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
723                               GR->getSPIRVTypeID(Call->ReturnType));
724 
725   bool IsCmpxchg = Call->Builtin->Name.contains("cmpxchg");
726   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
727 
728   Register ObjectPtr = Call->Arguments[0];   // Pointer (volatile A *object.)
729   Register ExpectedArg = Call->Arguments[1]; // Comparator (C* expected).
730   Register Desired = Call->Arguments[2];     // Value (C Desired).
731   SPIRVType *SpvDesiredTy = GR->getSPIRVTypeForVReg(Desired);
732   LLT DesiredLLT = MRI->getType(Desired);
733 
734   assert(GR->getSPIRVTypeForVReg(ObjectPtr)->getOpcode() ==
735          SPIRV::OpTypePointer);
736   unsigned ExpectedType = GR->getSPIRVTypeForVReg(ExpectedArg)->getOpcode();
737   (void)ExpectedType;
738   assert(IsCmpxchg ? ExpectedType == SPIRV::OpTypeInt
739                    : ExpectedType == SPIRV::OpTypePointer);
740   assert(GR->isScalarOfType(Desired, SPIRV::OpTypeInt));
741 
742   SPIRVType *SpvObjectPtrTy = GR->getSPIRVTypeForVReg(ObjectPtr);
743   assert(SpvObjectPtrTy->getOperand(2).isReg() && "SPIRV type is expected");
744   auto StorageClass = static_cast<SPIRV::StorageClass::StorageClass>(
745       SpvObjectPtrTy->getOperand(1).getImm());
746   auto MemSemStorage = getMemSemanticsForStorageClass(StorageClass);
747 
748   Register MemSemEqualReg;
749   Register MemSemUnequalReg;
750   uint64_t MemSemEqual =
751       IsCmpxchg
752           ? SPIRV::MemorySemantics::None
753           : SPIRV::MemorySemantics::SequentiallyConsistent | MemSemStorage;
754   uint64_t MemSemUnequal =
755       IsCmpxchg
756           ? SPIRV::MemorySemantics::None
757           : SPIRV::MemorySemantics::SequentiallyConsistent | MemSemStorage;
758   if (Call->Arguments.size() >= 4) {
759     assert(Call->Arguments.size() >= 5 &&
760            "Need 5+ args for explicit atomic cmpxchg");
761     auto MemOrdEq =
762         static_cast<std::memory_order>(getIConstVal(Call->Arguments[3], MRI));
763     auto MemOrdNeq =
764         static_cast<std::memory_order>(getIConstVal(Call->Arguments[4], MRI));
765     MemSemEqual = getSPIRVMemSemantics(MemOrdEq) | MemSemStorage;
766     MemSemUnequal = getSPIRVMemSemantics(MemOrdNeq) | MemSemStorage;
767     if (static_cast<unsigned>(MemOrdEq) == MemSemEqual)
768       MemSemEqualReg = Call->Arguments[3];
769     if (static_cast<unsigned>(MemOrdNeq) == MemSemUnequal)
770       MemSemUnequalReg = Call->Arguments[4];
771   }
772   if (!MemSemEqualReg.isValid())
773     MemSemEqualReg = buildConstantIntReg32(MemSemEqual, MIRBuilder, GR);
774   if (!MemSemUnequalReg.isValid())
775     MemSemUnequalReg = buildConstantIntReg32(MemSemUnequal, MIRBuilder, GR);
776 
777   Register ScopeReg;
778   auto Scope = IsCmpxchg ? SPIRV::Scope::Workgroup : SPIRV::Scope::Device;
779   if (Call->Arguments.size() >= 6) {
780     assert(Call->Arguments.size() == 6 &&
781            "Extra args for explicit atomic cmpxchg");
782     auto ClScope = static_cast<SPIRV::CLMemoryScope>(
783         getIConstVal(Call->Arguments[5], MRI));
784     Scope = getSPIRVScope(ClScope);
785     if (ClScope == static_cast<unsigned>(Scope))
786       ScopeReg = Call->Arguments[5];
787   }
788   if (!ScopeReg.isValid())
789     ScopeReg = buildConstantIntReg32(Scope, MIRBuilder, GR);
790 
791   Register Expected = IsCmpxchg
792                           ? ExpectedArg
793                           : buildLoadInst(SpvDesiredTy, ExpectedArg, MIRBuilder,
794                                           GR, LLT::scalar(64));
795   MRI->setType(Expected, DesiredLLT);
796   Register Tmp = !IsCmpxchg ? MRI->createGenericVirtualRegister(DesiredLLT)
797                             : Call->ReturnRegister;
798   if (!MRI->getRegClassOrNull(Tmp))
799     MRI->setRegClass(Tmp, GR->getRegClass(SpvDesiredTy));
800   GR->assignSPIRVTypeToVReg(SpvDesiredTy, Tmp, MIRBuilder.getMF());
801 
802   SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
803   MIRBuilder.buildInstr(Opcode)
804       .addDef(Tmp)
805       .addUse(GR->getSPIRVTypeID(IntTy))
806       .addUse(ObjectPtr)
807       .addUse(ScopeReg)
808       .addUse(MemSemEqualReg)
809       .addUse(MemSemUnequalReg)
810       .addUse(Desired)
811       .addUse(Expected);
812   if (!IsCmpxchg) {
813     MIRBuilder.buildInstr(SPIRV::OpStore).addUse(ExpectedArg).addUse(Tmp);
814     MIRBuilder.buildICmp(CmpInst::ICMP_EQ, Call->ReturnRegister, Tmp, Expected);
815   }
816   return true;
817 }
818 
819 /// Helper function for building atomic instructions.
820 static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
821                                MachineIRBuilder &MIRBuilder,
822                                SPIRVGlobalRegistry *GR) {
823   if (Call->isSpirvOp())
824     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
825                               GR->getSPIRVTypeID(Call->ReturnType));
826 
827   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
828   Register ScopeRegister =
829       Call->Arguments.size() >= 4 ? Call->Arguments[3] : Register();
830 
831   assert(Call->Arguments.size() <= 4 &&
832          "Too many args for explicit atomic RMW");
833   ScopeRegister = buildScopeReg(ScopeRegister, SPIRV::Scope::Workgroup,
834                                 MIRBuilder, GR, MRI);
835 
836   Register PtrRegister = Call->Arguments[0];
837   unsigned Semantics = SPIRV::MemorySemantics::None;
838   Register MemSemanticsReg =
839       Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
840   MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
841                                          Semantics, MIRBuilder, GR);
842   Register ValueReg = Call->Arguments[1];
843   Register ValueTypeReg = GR->getSPIRVTypeID(Call->ReturnType);
844   // support cl_ext_float_atomics
845   if (Call->ReturnType->getOpcode() == SPIRV::OpTypeFloat) {
846     if (Opcode == SPIRV::OpAtomicIAdd) {
847       Opcode = SPIRV::OpAtomicFAddEXT;
848     } else if (Opcode == SPIRV::OpAtomicISub) {
849       // Translate OpAtomicISub applied to a floating type argument to
850       // OpAtomicFAddEXT with the negative value operand
851       Opcode = SPIRV::OpAtomicFAddEXT;
852       Register NegValueReg =
853           MRI->createGenericVirtualRegister(MRI->getType(ValueReg));
854       MRI->setRegClass(NegValueReg, GR->getRegClass(Call->ReturnType));
855       GR->assignSPIRVTypeToVReg(Call->ReturnType, NegValueReg,
856                                 MIRBuilder.getMF());
857       MIRBuilder.buildInstr(TargetOpcode::G_FNEG)
858           .addDef(NegValueReg)
859           .addUse(ValueReg);
860       insertAssignInstr(NegValueReg, nullptr, Call->ReturnType, GR, MIRBuilder,
861                         MIRBuilder.getMF().getRegInfo());
862       ValueReg = NegValueReg;
863     }
864   }
865   MIRBuilder.buildInstr(Opcode)
866       .addDef(Call->ReturnRegister)
867       .addUse(ValueTypeReg)
868       .addUse(PtrRegister)
869       .addUse(ScopeRegister)
870       .addUse(MemSemanticsReg)
871       .addUse(ValueReg);
872   return true;
873 }
874 
875 /// Helper function for building an atomic floating-type instruction.
876 static bool buildAtomicFloatingRMWInst(const SPIRV::IncomingCall *Call,
877                                        unsigned Opcode,
878                                        MachineIRBuilder &MIRBuilder,
879                                        SPIRVGlobalRegistry *GR) {
880   assert(Call->Arguments.size() == 4 &&
881          "Wrong number of arguments for atomic floating-type builtin");
882   Register PtrReg = Call->Arguments[0];
883   Register ScopeReg = Call->Arguments[1];
884   Register MemSemanticsReg = Call->Arguments[2];
885   Register ValueReg = Call->Arguments[3];
886   MIRBuilder.buildInstr(Opcode)
887       .addDef(Call->ReturnRegister)
888       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
889       .addUse(PtrReg)
890       .addUse(ScopeReg)
891       .addUse(MemSemanticsReg)
892       .addUse(ValueReg);
893   return true;
894 }
895 
896 /// Helper function for building atomic flag instructions (e.g.
897 /// OpAtomicFlagTestAndSet).
898 static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
899                                 unsigned Opcode, MachineIRBuilder &MIRBuilder,
900                                 SPIRVGlobalRegistry *GR) {
901   bool IsSet = Opcode == SPIRV::OpAtomicFlagTestAndSet;
902   Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
903   if (Call->isSpirvOp())
904     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
905                               IsSet ? TypeReg : Register(0));
906 
907   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
908   Register PtrRegister = Call->Arguments[0];
909   unsigned Semantics = SPIRV::MemorySemantics::SequentiallyConsistent;
910   Register MemSemanticsReg =
911       Call->Arguments.size() >= 2 ? Call->Arguments[1] : Register();
912   MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
913                                          Semantics, MIRBuilder, GR);
914 
915   assert((Opcode != SPIRV::OpAtomicFlagClear ||
916           (Semantics != SPIRV::MemorySemantics::Acquire &&
917            Semantics != SPIRV::MemorySemantics::AcquireRelease)) &&
918          "Invalid memory order argument!");
919 
920   Register ScopeRegister =
921       Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
922   ScopeRegister =
923       buildScopeReg(ScopeRegister, SPIRV::Scope::Device, MIRBuilder, GR, MRI);
924 
925   auto MIB = MIRBuilder.buildInstr(Opcode);
926   if (IsSet)
927     MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
928 
929   MIB.addUse(PtrRegister).addUse(ScopeRegister).addUse(MemSemanticsReg);
930   return true;
931 }
932 
933 /// Helper function for building barriers, i.e., memory/control ordering
934 /// operations.
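/// For example, an OpenCL barrier(CLK_LOCAL_MEM_FENCE) call should lower to
/// OpControlBarrier with Workgroup execution and memory scopes and
/// WorkgroupMemory | SequentiallyConsistent memory semantics.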
935 static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
936                              MachineIRBuilder &MIRBuilder,
937                              SPIRVGlobalRegistry *GR) {
938   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
939   const auto *ST =
940       static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
941   if ((Opcode == SPIRV::OpControlBarrierArriveINTEL ||
942        Opcode == SPIRV::OpControlBarrierWaitINTEL) &&
943       !ST->canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
944     std::string DiagMsg = std::string(Builtin->Name) +
945                           ": the builtin requires the following SPIR-V "
946                           "extension: SPV_INTEL_split_barrier";
947     report_fatal_error(DiagMsg.c_str(), false);
948   }
949 
950   if (Call->isSpirvOp())
951     return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0));
952 
953   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
954   unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
955   unsigned MemSemantics = SPIRV::MemorySemantics::None;
956 
957   if (MemFlags & SPIRV::CLK_LOCAL_MEM_FENCE)
958     MemSemantics |= SPIRV::MemorySemantics::WorkgroupMemory;
959 
960   if (MemFlags & SPIRV::CLK_GLOBAL_MEM_FENCE)
961     MemSemantics |= SPIRV::MemorySemantics::CrossWorkgroupMemory;
962 
963   if (MemFlags & SPIRV::CLK_IMAGE_MEM_FENCE)
964     MemSemantics |= SPIRV::MemorySemantics::ImageMemory;
965 
966   if (Opcode == SPIRV::OpMemoryBarrier)
967     MemSemantics = getSPIRVMemSemantics(static_cast<std::memory_order>(
968                        getIConstVal(Call->Arguments[1], MRI))) |
969                    MemSemantics;
970   else if (Opcode == SPIRV::OpControlBarrierArriveINTEL)
971     MemSemantics |= SPIRV::MemorySemantics::Release;
972   else if (Opcode == SPIRV::OpControlBarrierWaitINTEL)
973     MemSemantics |= SPIRV::MemorySemantics::Acquire;
974   else
975     MemSemantics |= SPIRV::MemorySemantics::SequentiallyConsistent;
976 
977   Register MemSemanticsReg =
978       MemFlags == MemSemantics
979           ? Call->Arguments[0]
980           : buildConstantIntReg32(MemSemantics, MIRBuilder, GR);
981   Register ScopeReg;
982   SPIRV::Scope::Scope Scope = SPIRV::Scope::Workgroup;
983   SPIRV::Scope::Scope MemScope = Scope;
984   if (Call->Arguments.size() >= 2) {
985     assert(
986         ((Opcode != SPIRV::OpMemoryBarrier && Call->Arguments.size() == 2) ||
987          (Opcode == SPIRV::OpMemoryBarrier && Call->Arguments.size() == 3)) &&
988         "Extra args for explicitly scoped barrier");
989     Register ScopeArg = (Opcode == SPIRV::OpMemoryBarrier) ? Call->Arguments[2]
990                                                            : Call->Arguments[1];
991     SPIRV::CLMemoryScope CLScope =
992         static_cast<SPIRV::CLMemoryScope>(getIConstVal(ScopeArg, MRI));
993     MemScope = getSPIRVScope(CLScope);
994     if (!(MemFlags & SPIRV::CLK_LOCAL_MEM_FENCE) ||
995         (Opcode == SPIRV::OpMemoryBarrier))
996       Scope = MemScope;
997     if (CLScope == static_cast<unsigned>(Scope))
998       ScopeReg = Call->Arguments[1];
999   }
1000 
1001   if (!ScopeReg.isValid())
1002     ScopeReg = buildConstantIntReg32(Scope, MIRBuilder, GR);
1003 
1004   auto MIB = MIRBuilder.buildInstr(Opcode).addUse(ScopeReg);
1005   if (Opcode != SPIRV::OpMemoryBarrier)
1006     MIB.addUse(buildConstantIntReg32(MemScope, MIRBuilder, GR));
1007   MIB.addUse(MemSemanticsReg);
1008   return true;
1009 }
1010 
1011 /// Helper function for building extended bit operations.
1012 static bool buildExtendedBitOpsInst(const SPIRV::IncomingCall *Call,
1013                                     unsigned Opcode,
1014                                     MachineIRBuilder &MIRBuilder,
1015                                     SPIRVGlobalRegistry *GR) {
1016   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1017   const auto *ST =
1018       static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
1019   if ((Opcode == SPIRV::OpBitFieldInsert ||
1020        Opcode == SPIRV::OpBitFieldSExtract ||
1021        Opcode == SPIRV::OpBitFieldUExtract || Opcode == SPIRV::OpBitReverse) &&
1022       !ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
1023     std::string DiagMsg = std::string(Builtin->Name) +
1024                           ": the builtin requires the following SPIR-V "
1025                           "extension: SPV_KHR_bit_instructions";
1026     report_fatal_error(DiagMsg.c_str(), false);
1027   }
1028 
1029   // Generate SPIRV instruction accordingly.
1030   if (Call->isSpirvOp())
1031     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
1032                               GR->getSPIRVTypeID(Call->ReturnType));
1033 
1034   auto MIB = MIRBuilder.buildInstr(Opcode)
1035                  .addDef(Call->ReturnRegister)
1036                  .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1037   for (unsigned i = 0; i < Call->Arguments.size(); ++i)
1038     MIB.addUse(Call->Arguments[i]);
1039 
1040   return true;
1041 }
1042 
1043 /// Helper function for building Intel's bindless image instructions.
1044 static bool buildBindlessImageINTELInst(const SPIRV::IncomingCall *Call,
1045                                         unsigned Opcode,
1046                                         MachineIRBuilder &MIRBuilder,
1047                                         SPIRVGlobalRegistry *GR) {
1048   // Generate SPIRV instruction accordingly.
1049   if (Call->isSpirvOp())
1050     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
1051                               GR->getSPIRVTypeID(Call->ReturnType));
1052 
1053   MIRBuilder.buildInstr(Opcode)
1054       .addDef(Call->ReturnRegister)
1055       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1056       .addUse(Call->Arguments[0]);
1057 
1058   return true;
1059 }
1060 
1061 /// Helper function for building Intel's OpBitwiseFunctionINTEL instruction.
1062 static bool buildTernaryBitwiseFunctionINTELInst(
1063     const SPIRV::IncomingCall *Call, unsigned Opcode,
1064     MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) {
1065   // Generate SPIRV instruction accordingly.
1066   if (Call->isSpirvOp())
1067     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
1068                               GR->getSPIRVTypeID(Call->ReturnType));
1069 
1070   auto MIB = MIRBuilder.buildInstr(Opcode)
1071                  .addDef(Call->ReturnRegister)
1072                  .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1073   for (unsigned i = 0; i < Call->Arguments.size(); ++i)
1074     MIB.addUse(Call->Arguments[i]);
1075 
1076   return true;
1077 }
1078 
1079 /// Helper function for building Intel's 2D block I/O instructions.
1080 static bool build2DBlockIOINTELInst(const SPIRV::IncomingCall *Call,
1081                                     unsigned Opcode,
1082                                     MachineIRBuilder &MIRBuilder,
1083                                     SPIRVGlobalRegistry *GR) {
1084   // Generate SPIRV instruction accordingly.
1085   if (Call->isSpirvOp())
1086     return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0));
1087 
1088   auto MIB = MIRBuilder.buildInstr(Opcode)
1089                  .addDef(Call->ReturnRegister)
1090                  .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1091   for (unsigned i = 0; i < Call->Arguments.size(); ++i)
1092     MIB.addUse(Call->Arguments[i]);
1093 
1094   return true;
1095 }
1096 
1097 static unsigned getNumComponentsForDim(SPIRV::Dim::Dim dim) {
1098   switch (dim) {
1099   case SPIRV::Dim::DIM_1D:
1100   case SPIRV::Dim::DIM_Buffer:
1101     return 1;
1102   case SPIRV::Dim::DIM_2D:
1103   case SPIRV::Dim::DIM_Cube:
1104   case SPIRV::Dim::DIM_Rect:
1105     return 2;
1106   case SPIRV::Dim::DIM_3D:
1107     return 3;
1108   default:
1109     report_fatal_error("Cannot get num components for given Dim");
1110   }
1111 }
1112 
1113 /// Helper function for obtaining the number of size components.
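/// For example, a 2D image yields 2 size components, and an arrayed 2D image
/// yields 3 (the array index adds one component).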
1114 static unsigned getNumSizeComponents(SPIRVType *imgType) {
1115   assert(imgType->getOpcode() == SPIRV::OpTypeImage);
1116   auto dim = static_cast<SPIRV::Dim::Dim>(imgType->getOperand(2).getImm());
1117   unsigned numComps = getNumComponentsForDim(dim);
1118   bool arrayed = imgType->getOperand(4).getImm() == 1;
1119   return arrayed ? numComps + 1 : numComps;
1120 }
1121 
1122 //===----------------------------------------------------------------------===//
1123 // Implementation functions for each builtin group
1124 //===----------------------------------------------------------------------===//
1125 
1126 static bool generateExtInst(const SPIRV::IncomingCall *Call,
1127                             MachineIRBuilder &MIRBuilder,
1128                             SPIRVGlobalRegistry *GR) {
1129   // Lookup the extended instruction number in the TableGen records.
1130   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1131   uint32_t Number =
1132       SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;
1133 
1134   // Build extended instruction.
1135   auto MIB =
1136       MIRBuilder.buildInstr(SPIRV::OpExtInst)
1137           .addDef(Call->ReturnRegister)
1138           .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1139           .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
1140           .addImm(Number);
1141 
1142   for (auto Argument : Call->Arguments)
1143     MIB.addUse(Argument);
1144   return true;
1145 }
1146 
1147 static bool generateRelationalInst(const SPIRV::IncomingCall *Call,
1148                                    MachineIRBuilder &MIRBuilder,
1149                                    SPIRVGlobalRegistry *GR) {
1150   // Lookup the instruction opcode in the TableGen records.
1151   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1152   unsigned Opcode =
1153       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1154 
1155   Register CompareRegister;
1156   SPIRVType *RelationType;
1157   std::tie(CompareRegister, RelationType) =
1158       buildBoolRegister(MIRBuilder, Call->ReturnType, GR);
1159 
1160   // Build relational instruction.
1161   auto MIB = MIRBuilder.buildInstr(Opcode)
1162                  .addDef(CompareRegister)
1163                  .addUse(GR->getSPIRVTypeID(RelationType));
1164 
1165   for (auto Argument : Call->Arguments)
1166     MIB.addUse(Argument);
1167 
1168   // Build select instruction.
1169   return buildSelectInst(MIRBuilder, Call->ReturnRegister, CompareRegister,
1170                          Call->ReturnType, GR);
1171 }
1172 
1173 static bool generateGroupInst(const SPIRV::IncomingCall *Call,
1174                               MachineIRBuilder &MIRBuilder,
1175                               SPIRVGlobalRegistry *GR) {
1176   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1177   const SPIRV::GroupBuiltin *GroupBuiltin =
1178       SPIRV::lookupGroupBuiltin(Builtin->Name);
1179 
1180   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1181   if (Call->isSpirvOp()) {
1182     if (GroupBuiltin->NoGroupOperation) {
1183       SmallVector<uint32_t, 1> ImmArgs;
1184       if (GroupBuiltin->Opcode ==
1185               SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL &&
1186           Call->Arguments.size() > 4)
1187         ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[4], MRI));
1188       return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
1189                                 GR->getSPIRVTypeID(Call->ReturnType), ImmArgs);
1190     }
1191 
1192     // Group Operation is a literal
1193     Register GroupOpReg = Call->Arguments[1];
1194     const MachineInstr *MI = getDefInstrMaybeConstant(GroupOpReg, MRI);
1195     if (!MI || MI->getOpcode() != TargetOpcode::G_CONSTANT)
1196       report_fatal_error(
1197           "Group Operation parameter must be an integer constant");
1198     uint64_t GrpOp = MI->getOperand(1).getCImm()->getValue().getZExtValue();
1199     Register ScopeReg = Call->Arguments[0];
1200     auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
1201                    .addDef(Call->ReturnRegister)
1202                    .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1203                    .addUse(ScopeReg)
1204                    .addImm(GrpOp);
1205     for (unsigned i = 2; i < Call->Arguments.size(); ++i)
1206       MIB.addUse(Call->Arguments[i]);
1207     return true;
1208   }
1209 
1210   Register Arg0;
1211   if (GroupBuiltin->HasBoolArg) {
1212     SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true);
1213     Register BoolReg = Call->Arguments[0];
1214     SPIRVType *BoolRegType = GR->getSPIRVTypeForVReg(BoolReg);
1215     if (!BoolRegType)
1216       report_fatal_error("Can't find a register's type definition");
1217     MachineInstr *ArgInstruction = getDefInstrMaybeConstant(BoolReg, MRI);
1218     if (ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT) {
1219       if (BoolRegType->getOpcode() != SPIRV::OpTypeBool)
1220         Arg0 = GR->buildConstantInt(getIConstVal(BoolReg, MRI), MIRBuilder,
1221                                     BoolType, true);
1222     } else {
1223       if (BoolRegType->getOpcode() == SPIRV::OpTypeInt) {
1224         Arg0 = MRI->createGenericVirtualRegister(LLT::scalar(1));
1225         MRI->setRegClass(Arg0, &SPIRV::iIDRegClass);
1226         GR->assignSPIRVTypeToVReg(BoolType, Arg0, MIRBuilder.getMF());
1227         MIRBuilder.buildICmp(
1228             CmpInst::ICMP_NE, Arg0, BoolReg,
1229             GR->buildConstantInt(0, MIRBuilder, BoolRegType, true));
1230         insertAssignInstr(Arg0, nullptr, BoolType, GR, MIRBuilder,
1231                           MIRBuilder.getMF().getRegInfo());
1232       } else if (BoolRegType->getOpcode() != SPIRV::OpTypeBool) {
1233         report_fatal_error("Expect a boolean argument");
1234       }
1235       // If BoolReg is already a boolean register, we don't need to do anything.
1236     }
1237   }
1238 
1239   Register GroupResultRegister = Call->ReturnRegister;
1240   SPIRVType *GroupResultType = Call->ReturnType;
1241 
1242   // TODO: maybe we need to check whether the result type is already boolean,
1243   // and in that case do not insert a select instruction.
1244   const bool HasBoolReturnTy =
1245       GroupBuiltin->IsElect || GroupBuiltin->IsAllOrAny ||
1246       GroupBuiltin->IsAllEqual || GroupBuiltin->IsLogical ||
1247       GroupBuiltin->IsInverseBallot || GroupBuiltin->IsBallotBitExtract;
1248 
1249   if (HasBoolReturnTy)
1250     std::tie(GroupResultRegister, GroupResultType) =
1251         buildBoolRegister(MIRBuilder, Call->ReturnType, GR);
1252 
1253   auto Scope = Builtin->Name.starts_with("sub_group") ? SPIRV::Scope::Subgroup
1254                                                       : SPIRV::Scope::Workgroup;
1255   Register ScopeRegister = buildConstantIntReg32(Scope, MIRBuilder, GR);
1256 
1257   Register VecReg;
1258   if (GroupBuiltin->Opcode == SPIRV::OpGroupBroadcast &&
1259       Call->Arguments.size() > 2) {
1260     // For OpGroupBroadcast "LocalId must be an integer datatype. It must be a
1261     // scalar, a vector with 2 components, or a vector with 3 components.",
1262     // meaning that we must create a vector from the function arguments if
1263     // it's a work_group_broadcast(val, local_id_x, local_id_y) or
1264     // work_group_broadcast(val, local_id_x, local_id_y, local_id_z) call.
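    // For example (illustrative lowering), work_group_broadcast(v, x, y) builds
    // a two-component LocalId vector from x and y via G_BUILD_VECTOR and passes
    // it as the last operand of OpGroupBroadcast.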
1265     Register ElemReg = Call->Arguments[1];
1266     SPIRVType *ElemType = GR->getSPIRVTypeForVReg(ElemReg);
1267     if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeInt)
1268       report_fatal_error("Expect an integer <LocalId> argument");
1269     unsigned VecLen = Call->Arguments.size() - 1;
1270     VecReg = MRI->createGenericVirtualRegister(
1271         LLT::fixed_vector(VecLen, MRI->getType(ElemReg)));
1272     MRI->setRegClass(VecReg, &SPIRV::vIDRegClass);
1273     SPIRVType *VecType =
1274         GR->getOrCreateSPIRVVectorType(ElemType, VecLen, MIRBuilder, true);
1275     GR->assignSPIRVTypeToVReg(VecType, VecReg, MIRBuilder.getMF());
1276     auto MIB =
1277         MIRBuilder.buildInstr(TargetOpcode::G_BUILD_VECTOR).addDef(VecReg);
1278     for (unsigned i = 1; i < Call->Arguments.size(); i++) {
1279       MIB.addUse(Call->Arguments[i]);
1280       setRegClassIfNull(Call->Arguments[i], MRI, GR);
1281     }
1282     insertAssignInstr(VecReg, nullptr, VecType, GR, MIRBuilder,
1283                       MIRBuilder.getMF().getRegInfo());
1284   }
1285 
1286   // Build work/sub group instruction.
1287   auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
1288                  .addDef(GroupResultRegister)
1289                  .addUse(GR->getSPIRVTypeID(GroupResultType))
1290                  .addUse(ScopeRegister);
1291 
1292   if (!GroupBuiltin->NoGroupOperation)
1293     MIB.addImm(GroupBuiltin->GroupOperation);
1294   if (Call->Arguments.size() > 0) {
1295     MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
1296     setRegClassIfNull(Call->Arguments[0], MRI, GR);
1297     if (VecReg.isValid())
1298       MIB.addUse(VecReg);
1299     else
1300       for (unsigned i = 1; i < Call->Arguments.size(); i++)
1301         MIB.addUse(Call->Arguments[i]);
1302   }
1303 
1304   // Build select instruction.
1305   if (HasBoolReturnTy)
1306     buildSelectInst(MIRBuilder, Call->ReturnRegister, GroupResultRegister,
1307                     Call->ReturnType, GR);
1308   return true;
1309 }
1310 
1311 static bool generateIntelSubgroupsInst(const SPIRV::IncomingCall *Call,
1312                                        MachineIRBuilder &MIRBuilder,
1313                                        SPIRVGlobalRegistry *GR) {
1314   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1315   MachineFunction &MF = MIRBuilder.getMF();
1316   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
1317   const SPIRV::IntelSubgroupsBuiltin *IntelSubgroups =
1318       SPIRV::lookupIntelSubgroupsBuiltin(Builtin->Name);
1319 
1320   if (IntelSubgroups->IsMedia &&
1321       !ST->canUseExtension(SPIRV::Extension::SPV_INTEL_media_block_io)) {
1322     std::string DiagMsg = std::string(Builtin->Name) +
1323                           ": the builtin requires the following SPIR-V "
1324                           "extension: SPV_INTEL_media_block_io";
1325     report_fatal_error(DiagMsg.c_str(), false);
1326   } else if (!IntelSubgroups->IsMedia &&
1327              !ST->canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1328     std::string DiagMsg = std::string(Builtin->Name) +
1329                           ": the builtin requires the following SPIR-V "
1330                           "extension: SPV_INTEL_subgroups";
1331     report_fatal_error(DiagMsg.c_str(), false);
1332   }
1333 
1334   uint32_t OpCode = IntelSubgroups->Opcode;
1335   if (Call->isSpirvOp()) {
1336     bool IsSet = OpCode != SPIRV::OpSubgroupBlockWriteINTEL &&
1337                  OpCode != SPIRV::OpSubgroupImageBlockWriteINTEL &&
1338                  OpCode != SPIRV::OpSubgroupImageMediaBlockWriteINTEL;
1339     return buildOpFromWrapper(MIRBuilder, OpCode, Call,
1340                               IsSet ? GR->getSPIRVTypeID(Call->ReturnType)
1341                                     : Register(0));
1342   }
1343 
1344   if (IntelSubgroups->IsBlock) {
1345     // The minimal number of arguments set in TableGen records is 1.
1346     if (SPIRVType *Arg0Type = GR->getSPIRVTypeForVReg(Call->Arguments[0])) {
1347       if (Arg0Type->getOpcode() == SPIRV::OpTypeImage) {
1348         // TODO: add required validation from the specification:
1349         // "'Image' must be an object whose type is OpTypeImage with a 'Sampled'
1350         // operand of 0 or 2. If the 'Sampled' operand is 2, then some
1351         // dimensions require a capability."
1352         switch (OpCode) {
1353         case SPIRV::OpSubgroupBlockReadINTEL:
1354           OpCode = SPIRV::OpSubgroupImageBlockReadINTEL;
1355           break;
1356         case SPIRV::OpSubgroupBlockWriteINTEL:
1357           OpCode = SPIRV::OpSubgroupImageBlockWriteINTEL;
1358           break;
1359         }
1360       }
1361     }
1362   }
1363 
1364   // TODO: opaque pointer types should eventually be resolved in such a way
1365   // that validation of block read is enabled with respect to the following
1366   // specification requirement:
1367   // "'Result Type' may be a scalar or vector type, and its component type must
1368   // be equal to the type pointed to by 'Ptr'."
1369   // For example, function parameter type should not be default i8 pointer, but
1370   // depend on the result type of the instruction where it is used as a pointer
1371   // argument of OpSubgroupBlockReadINTEL
1372 
1373   // Build Intel subgroups instruction
1374   MachineInstrBuilder MIB =
1375       IntelSubgroups->IsWrite
1376           ? MIRBuilder.buildInstr(OpCode)
1377           : MIRBuilder.buildInstr(OpCode)
1378                 .addDef(Call->ReturnRegister)
1379                 .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1380   for (size_t i = 0; i < Call->Arguments.size(); ++i)
1381     MIB.addUse(Call->Arguments[i]);
1382   return true;
1383 }
1384 
1385 static bool generateGroupUniformInst(const SPIRV::IncomingCall *Call,
1386                                      MachineIRBuilder &MIRBuilder,
1387                                      SPIRVGlobalRegistry *GR) {
1388   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1389   MachineFunction &MF = MIRBuilder.getMF();
1390   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
1391   if (!ST->canUseExtension(
1392           SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1393     std::string DiagMsg = std::string(Builtin->Name) +
1394                           ": the builtin requires the following SPIR-V "
1395                           "extension: SPV_KHR_uniform_group_instructions";
1396     report_fatal_error(DiagMsg.c_str(), false);
1397   }
1398   const SPIRV::GroupUniformBuiltin *GroupUniform =
1399       SPIRV::lookupGroupUniformBuiltin(Builtin->Name);
1400   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1401 
1402   Register GroupResultReg = Call->ReturnRegister;
1403   Register ScopeReg = Call->Arguments[0];
1404   Register ValueReg = Call->Arguments[2];
1405 
1406   // Group Operation
1407   Register ConstGroupOpReg = Call->Arguments[1];
1408   const MachineInstr *Const = getDefInstrMaybeConstant(ConstGroupOpReg, MRI);
1409   if (!Const || Const->getOpcode() != TargetOpcode::G_CONSTANT)
1410     report_fatal_error(
1411         "expect a constant group operation for a uniform group instruction",
1412         false);
1413   const MachineOperand &ConstOperand = Const->getOperand(1);
1414   if (!ConstOperand.isCImm())
1415     report_fatal_error("uniform group instructions: group operation must be an "
1416                        "integer constant",
1417                        false);
1418 
1419   auto MIB = MIRBuilder.buildInstr(GroupUniform->Opcode)
1420                  .addDef(GroupResultReg)
1421                  .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1422                  .addUse(ScopeReg);
1423   addNumImm(ConstOperand.getCImm()->getValue(), MIB);
1424   MIB.addUse(ValueReg);
1425 
1426   return true;
1427 }
1428 
1429 static bool generateKernelClockInst(const SPIRV::IncomingCall *Call,
1430                                     MachineIRBuilder &MIRBuilder,
1431                                     SPIRVGlobalRegistry *GR) {
1432   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1433   MachineFunction &MF = MIRBuilder.getMF();
1434   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
1435   if (!ST->canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock)) {
1436     std::string DiagMsg = std::string(Builtin->Name) +
1437                           ": the builtin requires the following SPIR-V "
1438                           "extension: SPV_KHR_shader_clock";
1439     report_fatal_error(DiagMsg.c_str(), false);
1440   }
1441 
1442   Register ResultReg = Call->ReturnRegister;
1443 
1444   // Deduce the `Scope` operand from the builtin function name.
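  // For instance, a name ending in "device" (e.g. clock_read_device, name shown
  // for illustration) selects Device scope; the "work_group" and "sub_group"
  // suffixes select the Workgroup and Subgroup scopes respectively.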
1445   SPIRV::Scope::Scope ScopeArg =
1446       StringSwitch<SPIRV::Scope::Scope>(Builtin->Name)
1447           .EndsWith("device", SPIRV::Scope::Scope::Device)
1448           .EndsWith("work_group", SPIRV::Scope::Scope::Workgroup)
1449           .EndsWith("sub_group", SPIRV::Scope::Scope::Subgroup);
1450   Register ScopeReg = buildConstantIntReg32(ScopeArg, MIRBuilder, GR);
1451 
1452   MIRBuilder.buildInstr(SPIRV::OpReadClockKHR)
1453       .addDef(ResultReg)
1454       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1455       .addUse(ScopeReg);
1456 
1457   return true;
1458 }
1459 
1460 // These queries ask for a single size_t result for a given dimension index,
1461 // e.g. size_t get_global_id(uint dimindex). In SPIR-V, the builtins
1462 // corresponding to these values are all vec3 types, so we need to extract the
1463 // correct index or return DefaultValue (0 or 1 depending on the query). We also
1464 // handle extending or truncating in case size_t does not match the expected
1465 // result type's bitwidth.
1466 //
1467 // For a constant index >= 3 we generate:
1468 //  %res = OpConstant %SizeT DefaultValue
1469 //
1470 // For other indices we generate:
1471 //  %g = OpVariable %ptr_V3_SizeT Input
1472 //  OpDecorate %g BuiltIn XXX
1473 //  OpDecorate %g LinkageAttributes "__spirv_BuiltInXXX"
1474 //  OpDecorate %g Constant
1475 //  %loadedVec = OpLoad %V3_SizeT %g
1476 //
1477 //  Then, if the index is constant < 3, we generate:
1478 //    %res = OpCompositeExtract %SizeT %loadedVec idx
1479 //  If the index is dynamic, we generate:
1480 //    %tmp = OpVectorExtractDynamic %SizeT %loadedVec %idx
1481 //    %cmp = OpULessThan %bool %idx %const_3
1482 //    %res = OpSelect %SizeT %cmp %tmp %const_<DefaultValue>
1483 //
1484 //  If the bitwidth of %res does not match the expected return type, we add an
1485 //  extend or truncate.
1486 static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
1487                               MachineIRBuilder &MIRBuilder,
1488                               SPIRVGlobalRegistry *GR,
1489                               SPIRV::BuiltIn::BuiltIn BuiltinValue,
1490                               uint64_t DefaultValue) {
1491   Register IndexRegister = Call->Arguments[0];
1492   const unsigned ResultWidth = Call->ReturnType->getOperand(1).getImm();
1493   const unsigned PointerSize = GR->getPointerSize();
1494   const SPIRVType *PointerSizeType =
1495       GR->getOrCreateSPIRVIntegerType(PointerSize, MIRBuilder);
1496   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1497   auto IndexInstruction = getDefInstrMaybeConstant(IndexRegister, MRI);
1498 
1499   // Set up the final register to do truncation or extension on at the end.
1500   Register ToTruncate = Call->ReturnRegister;
1501 
1502   // If the index is constant, we can statically determine if it is in range.
1503   bool IsConstantIndex =
1504       IndexInstruction->getOpcode() == TargetOpcode::G_CONSTANT;
1505 
1506   // If it's out of range (max dimension is 3), we can just return the constant
1507   // default value (0 or 1 depending on which query function).
1508   if (IsConstantIndex && getIConstVal(IndexRegister, MRI) >= 3) {
1509     Register DefaultReg = Call->ReturnRegister;
1510     if (PointerSize != ResultWidth) {
1511       DefaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
1512       MRI->setRegClass(DefaultReg, &SPIRV::iIDRegClass);
1513       GR->assignSPIRVTypeToVReg(PointerSizeType, DefaultReg,
1514                                 MIRBuilder.getMF());
1515       ToTruncate = DefaultReg;
1516     }
1517     auto NewRegister =
1518         GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);
1519     MIRBuilder.buildCopy(DefaultReg, NewRegister);
1520   } else { // If it could be in range, we need to load from the given builtin.
1521     auto Vec3Ty =
1522         GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder, true);
1523     Register LoadedVector =
1524         buildBuiltinVariableLoad(MIRBuilder, Vec3Ty, GR, BuiltinValue,
1525                                  LLT::fixed_vector(3, PointerSize));
1526     // Set up the vreg to extract the result to (possibly a new temporary one).
1527     Register Extracted = Call->ReturnRegister;
1528     if (!IsConstantIndex || PointerSize != ResultWidth) {
1529       Extracted = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
1530       MRI->setRegClass(Extracted, &SPIRV::iIDRegClass);
1531       GR->assignSPIRVTypeToVReg(PointerSizeType, Extracted, MIRBuilder.getMF());
1532     }
1533     // Use Intrinsic::spv_extractelt so dynamic vs static extraction is
1534     // handled later: extr = spv_extractelt LoadedVector, IndexRegister.
1535     MachineInstrBuilder ExtractInst = MIRBuilder.buildIntrinsic(
1536         Intrinsic::spv_extractelt, ArrayRef<Register>{Extracted}, true, false);
1537     ExtractInst.addUse(LoadedVector).addUse(IndexRegister);
1538 
1539     // If the index is dynamic, check whether it's < 3 and then use a select.
1540     if (!IsConstantIndex) {
1541       insertAssignInstr(Extracted, nullptr, PointerSizeType, GR, MIRBuilder,
1542                         *MRI);
1543 
1544       auto IndexType = GR->getSPIRVTypeForVReg(IndexRegister);
1545       auto BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true);
1546 
1547       Register CompareRegister =
1548           MRI->createGenericVirtualRegister(LLT::scalar(1));
1549       MRI->setRegClass(CompareRegister, &SPIRV::iIDRegClass);
1550       GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());
1551 
1552       // Use G_ICMP to check if idxVReg < 3.
1553       MIRBuilder.buildICmp(
1554           CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
1555           GR->buildConstantInt(3, MIRBuilder, IndexType, true));
1556 
1557       // Get constant for the default value (0 or 1 depending on which
1558       // function).
1559       Register DefaultRegister =
1560           GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);
1561 
1562       // Get a register for the selection result (possibly a new temporary one).
1563       Register SelectionResult = Call->ReturnRegister;
1564       if (PointerSize != ResultWidth) {
1565         SelectionResult =
1566             MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
1567         MRI->setRegClass(SelectionResult, &SPIRV::iIDRegClass);
1568         GR->assignSPIRVTypeToVReg(PointerSizeType, SelectionResult,
1569                                   MIRBuilder.getMF());
1570       }
1571       // Create the final G_SELECT to return the extracted value or the default.
1572       MIRBuilder.buildSelect(SelectionResult, CompareRegister, Extracted,
1573                              DefaultRegister);
1574       ToTruncate = SelectionResult;
1575     } else {
1576       ToTruncate = Extracted;
1577     }
1578   }
1579   // Adjust the result's bitwidth if it does not match the extracted size_t.
1580   if (PointerSize != ResultWidth)
1581     MIRBuilder.buildZExtOrTrunc(Call->ReturnRegister, ToTruncate);
1582   return true;
1583 }
1584 
1585 static bool generateBuiltinVar(const SPIRV::IncomingCall *Call,
1586                                MachineIRBuilder &MIRBuilder,
1587                                SPIRVGlobalRegistry *GR) {
1588   // Lookup the builtin variable record.
1589   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1590   SPIRV::BuiltIn::BuiltIn Value =
1591       SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;
1592 
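  // get_global_id-style queries (GlobalInvocationId) need the vec3 load plus
  // index extraction and out-of-range handling, so route them through
  // genWorkgroupQuery above.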
1593   if (Value == SPIRV::BuiltIn::GlobalInvocationId)
1594     return genWorkgroupQuery(Call, MIRBuilder, GR, Value, 0);
1595 
1596   // Build a load instruction for the builtin variable.
1597   unsigned BitWidth = GR->getScalarOrVectorBitWidth(Call->ReturnType);
1598   LLT LLType;
1599   if (Call->ReturnType->getOpcode() == SPIRV::OpTypeVector)
1600     LLType =
1601         LLT::fixed_vector(Call->ReturnType->getOperand(2).getImm(), BitWidth);
1602   else
1603     LLType = LLT::scalar(BitWidth);
1604 
1605   return buildBuiltinVariableLoad(MIRBuilder, Call->ReturnType, GR, Value,
1606                                   LLType, Call->ReturnRegister);
1607 }
1608 
1609 static bool generateAtomicInst(const SPIRV::IncomingCall *Call,
1610                                MachineIRBuilder &MIRBuilder,
1611                                SPIRVGlobalRegistry *GR) {
1612   // Lookup the instruction opcode in the TableGen records.
1613   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1614   unsigned Opcode =
1615       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1616 
1617   switch (Opcode) {
1618   case SPIRV::OpStore:
1619     return buildAtomicInitInst(Call, MIRBuilder);
1620   case SPIRV::OpAtomicLoad:
1621     return buildAtomicLoadInst(Call, MIRBuilder, GR);
1622   case SPIRV::OpAtomicStore:
1623     return buildAtomicStoreInst(Call, MIRBuilder, GR);
1624   case SPIRV::OpAtomicCompareExchange:
1625   case SPIRV::OpAtomicCompareExchangeWeak:
1626     return buildAtomicCompareExchangeInst(Call, Builtin, Opcode, MIRBuilder,
1627                                           GR);
1628   case SPIRV::OpAtomicIAdd:
1629   case SPIRV::OpAtomicISub:
1630   case SPIRV::OpAtomicOr:
1631   case SPIRV::OpAtomicXor:
1632   case SPIRV::OpAtomicAnd:
1633   case SPIRV::OpAtomicExchange:
1634     return buildAtomicRMWInst(Call, Opcode, MIRBuilder, GR);
1635   case SPIRV::OpMemoryBarrier:
1636     return buildBarrierInst(Call, SPIRV::OpMemoryBarrier, MIRBuilder, GR);
1637   case SPIRV::OpAtomicFlagTestAndSet:
1638   case SPIRV::OpAtomicFlagClear:
1639     return buildAtomicFlagInst(Call, Opcode, MIRBuilder, GR);
1640   default:
1641     if (Call->isSpirvOp())
1642       return buildOpFromWrapper(MIRBuilder, Opcode, Call,
1643                                 GR->getSPIRVTypeID(Call->ReturnType));
1644     return false;
1645   }
1646 }
1647 
1648 static bool generateAtomicFloatingInst(const SPIRV::IncomingCall *Call,
1649                                        MachineIRBuilder &MIRBuilder,
1650                                        SPIRVGlobalRegistry *GR) {
1651   // Lookup the instruction opcode in the TableGen records.
1652   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1653   unsigned Opcode = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name)->Opcode;
1654 
1655   switch (Opcode) {
1656   case SPIRV::OpAtomicFAddEXT:
1657   case SPIRV::OpAtomicFMinEXT:
1658   case SPIRV::OpAtomicFMaxEXT:
1659     return buildAtomicFloatingRMWInst(Call, Opcode, MIRBuilder, GR);
1660   default:
1661     return false;
1662   }
1663 }
1664 
1665 static bool generateBarrierInst(const SPIRV::IncomingCall *Call,
1666                                 MachineIRBuilder &MIRBuilder,
1667                                 SPIRVGlobalRegistry *GR) {
1668   // Lookup the instruction opcode in the TableGen records.
1669   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1670   unsigned Opcode =
1671       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1672 
1673   return buildBarrierInst(Call, Opcode, MIRBuilder, GR);
1674 }
1675 
1676 static bool generateCastToPtrInst(const SPIRV::IncomingCall *Call,
1677                                   MachineIRBuilder &MIRBuilder,
1678                                   SPIRVGlobalRegistry *GR) {
1679   // Lookup the instruction opcode in the TableGen records.
1680   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1681   unsigned Opcode =
1682       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1683 
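  // OpGenericCastToPtrExplicit (used, for example, by the to_global/to_local/
  // to_private builtins; names given for illustration) carries the result
  // storage class as an immediate, while other cast builtins lower to a plain
  // G_ADDRSPACE_CAST.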
1684   if (Opcode == SPIRV::OpGenericCastToPtrExplicit) {
1685     SPIRV::StorageClass::StorageClass ResSC =
1686         GR->getPointerStorageClass(Call->ReturnRegister);
1687     if (!isGenericCastablePtr(ResSC))
1688       return false;
1689 
1690     MIRBuilder.buildInstr(Opcode)
1691         .addDef(Call->ReturnRegister)
1692         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1693         .addUse(Call->Arguments[0])
1694         .addImm(ResSC);
1695   } else {
1696     MIRBuilder.buildInstr(TargetOpcode::G_ADDRSPACE_CAST)
1697         .addDef(Call->ReturnRegister)
1698         .addUse(Call->Arguments[0]);
1699   }
1700   return true;
1701 }
1702 
1703 static bool generateDotOrFMulInst(const StringRef DemangledCall,
1704                                   const SPIRV::IncomingCall *Call,
1705                                   MachineIRBuilder &MIRBuilder,
1706                                   SPIRVGlobalRegistry *GR) {
1707   if (Call->isSpirvOp())
1708     return buildOpFromWrapper(MIRBuilder, SPIRV::OpDot, Call,
1709                               GR->getSPIRVTypeID(Call->ReturnType));
1710 
1711   bool IsVec = GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode() ==
1712                SPIRV::OpTypeVector;
1713   // Use OpDot for vector arguments and OpFMul for scalar arguments.
1714   uint32_t OC = IsVec ? SPIRV::OpDot : SPIRV::OpFMulS;
1715   bool IsSwapReq = false;
1716 
1717   const auto *ST =
1718       static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
1719   if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt) &&
1720       (ST->canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
1721        ST->isAtLeastSPIRVVer(VersionTuple(1, 6)))) {
1722     const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1723     const SPIRV::IntegerDotProductBuiltin *IntDot =
1724         SPIRV::lookupIntegerDotProductBuiltin(Builtin->Name);
1725     if (IntDot) {
1726       OC = IntDot->Opcode;
1727       IsSwapReq = IntDot->IsSwapReq;
1728     } else if (IsVec) {
1729       // Handling "dot" and "dot_acc_sat" builtins which use vectors of
1730       // integers.
1731       LLVMContext &Ctx = MIRBuilder.getContext();
1732       SmallVector<StringRef, 10> TypeStrs;
1733       SPIRV::parseBuiltinTypeStr(TypeStrs, DemangledCall, Ctx);
1734       bool IsFirstSigned = TypeStrs[0].trim()[0] != 'u';
1735       bool IsSecondSigned = TypeStrs[1].trim()[0] != 'u';
1736 
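      // Mixed signedness maps to the OpSUDot* forms, whose first operand must
      // be the signed vector; when the unsigned vector comes first (for example,
      // dot(uint4, int4), illustrative), the operands are swapped below.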
1737       if (Call->BuiltinName == "dot") {
1738         if (IsFirstSigned && IsSecondSigned)
1739           OC = SPIRV::OpSDot;
1740         else if (!IsFirstSigned && !IsSecondSigned)
1741           OC = SPIRV::OpUDot;
1742         else {
1743           OC = SPIRV::OpSUDot;
1744           if (!IsFirstSigned)
1745             IsSwapReq = true;
1746         }
1747       } else if (Call->BuiltinName == "dot_acc_sat") {
1748         if (IsFirstSigned && IsSecondSigned)
1749           OC = SPIRV::OpSDotAccSat;
1750         else if (!IsFirstSigned && !IsSecondSigned)
1751           OC = SPIRV::OpUDotAccSat;
1752         else {
1753           OC = SPIRV::OpSUDotAccSat;
1754           if (!IsFirstSigned)
1755             IsSwapReq = true;
1756         }
1757       }
1758     }
1759   }
1760 
1761   MachineInstrBuilder MIB = MIRBuilder.buildInstr(OC)
1762                                 .addDef(Call->ReturnRegister)
1763                                 .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1764 
1765   if (IsSwapReq) {
1766     MIB.addUse(Call->Arguments[1]);
1767     MIB.addUse(Call->Arguments[0]);
1768     // needed for dot_acc_sat* builtins
1769     for (size_t i = 2; i < Call->Arguments.size(); ++i)
1770       MIB.addUse(Call->Arguments[i]);
1771   } else {
1772     for (size_t i = 0; i < Call->Arguments.size(); ++i)
1773       MIB.addUse(Call->Arguments[i]);
1774   }
1775 
1776   // Add the Packed Vector Format operand for integer dot product builtins when
1777   // the arguments are scalars.
1778   if (!IsVec && OC != SPIRV::OpFMulS)
1779     MIB.addImm(0);
1780 
1781   return true;
1782 }
1783 
1784 static bool generateWaveInst(const SPIRV::IncomingCall *Call,
1785                              MachineIRBuilder &MIRBuilder,
1786                              SPIRVGlobalRegistry *GR) {
1787   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1788   SPIRV::BuiltIn::BuiltIn Value =
1789       SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;
1790 
1791   // For now, we only support a single Wave intrinsic with a single return type.
1792   assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt);
1793   LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(Call->ReturnType));
1794 
1795   return buildBuiltinVariableLoad(
1796       MIRBuilder, Call->ReturnType, GR, Value, LLType, Call->ReturnRegister,
1797       /* isConst= */ false, /* hasLinkageTy= */ false);
1798 }
1799 
1800 // We expect a builtin
1801 //     Name(ptr sret([RetType]) %result, Type %operand1, Type %operand2)
1802 // where %result is a pointer to where the result of the builtin execution
1803 // is to be stored, and generate the following instructions:
1804 //     Res = Opcode RetType Operand1 Operand2
1805 //     OpStore RetVariable Res
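// For example (illustrative; the exact demangled name may differ), a call such as
//     __spirv_IAddCarry(%a, %b)
// with %result pointing to a { i32, i32 } struct would produce:
//     %res = OpIAddCarry %struct_i32_i32 %a %b
//     OpStore %result %res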
1806 static bool generateICarryBorrowInst(const SPIRV::IncomingCall *Call,
1807                                      MachineIRBuilder &MIRBuilder,
1808                                      SPIRVGlobalRegistry *GR) {
1809   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1810   unsigned Opcode =
1811       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1812 
1813   Register SRetReg = Call->Arguments[0];
1814   SPIRVType *PtrRetType = GR->getSPIRVTypeForVReg(SRetReg);
1815   SPIRVType *RetType = GR->getPointeeType(PtrRetType);
1816   if (!RetType)
1817     report_fatal_error("The first parameter must be a pointer");
1818   if (RetType->getOpcode() != SPIRV::OpTypeStruct)
1819     report_fatal_error("Expected struct type result for the arithmetic with "
1820                        "overflow builtins");
1821 
1822   SPIRVType *OpType1 = GR->getSPIRVTypeForVReg(Call->Arguments[1]);
1823   SPIRVType *OpType2 = GR->getSPIRVTypeForVReg(Call->Arguments[2]);
1824   if (!OpType1 || !OpType2 || OpType1 != OpType2)
1825     report_fatal_error("Operands must have the same type");
1826   if (OpType1->getOpcode() == SPIRV::OpTypeVector)
1827     switch (Opcode) {
1828     case SPIRV::OpIAddCarryS:
1829       Opcode = SPIRV::OpIAddCarryV;
1830       break;
1831     case SPIRV::OpISubBorrowS:
1832       Opcode = SPIRV::OpISubBorrowV;
1833       break;
1834     }
1835 
1836   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1837   Register ResReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
1838   if (const TargetRegisterClass *DstRC =
1839           MRI->getRegClassOrNull(Call->Arguments[1])) {
1840     MRI->setRegClass(ResReg, DstRC);
1841     MRI->setType(ResReg, MRI->getType(Call->Arguments[1]));
1842   } else {
1843     MRI->setType(ResReg, LLT::scalar(64));
1844   }
1845   GR->assignSPIRVTypeToVReg(RetType, ResReg, MIRBuilder.getMF());
1846   MIRBuilder.buildInstr(Opcode)
1847       .addDef(ResReg)
1848       .addUse(GR->getSPIRVTypeID(RetType))
1849       .addUse(Call->Arguments[1])
1850       .addUse(Call->Arguments[2]);
1851   MIRBuilder.buildInstr(SPIRV::OpStore).addUse(SRetReg).addUse(ResReg);
1852   return true;
1853 }
1854 
1855 static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
1856                                  MachineIRBuilder &MIRBuilder,
1857                                  SPIRVGlobalRegistry *GR) {
1858   // Lookup the builtin record.
1859   SPIRV::BuiltIn::BuiltIn Value =
1860       SPIRV::lookupGetBuiltin(Call->Builtin->Name, Call->Builtin->Set)->Value;
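  // Per OpenCL, size-style queries (e.g. get_global_size) return 1 for an
  // out-of-range dimension index, while id/offset-style queries return 0.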
1861   const bool IsDefaultOne = (Value == SPIRV::BuiltIn::GlobalSize ||
1862                              Value == SPIRV::BuiltIn::NumWorkgroups ||
1863                              Value == SPIRV::BuiltIn::WorkgroupSize ||
1864                              Value == SPIRV::BuiltIn::EnqueuedWorkgroupSize);
1865   return genWorkgroupQuery(Call, MIRBuilder, GR, Value, IsDefaultOne ? 1 : 0);
1866 }
1867 
1868 static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
1869                                        MachineIRBuilder &MIRBuilder,
1870                                        SPIRVGlobalRegistry *GR) {
1871   // Lookup the image size query component number in the TableGen records.
1872   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1873   uint32_t Component =
1874       SPIRV::lookupImageQueryBuiltin(Builtin->Name, Builtin->Set)->Component;
1875   // Query result may either be a vector or a scalar. If return type is not a
1876   // vector, expect only a single size component. Otherwise get the number of
1877   // expected components.
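  // For example (assuming the usual TableGen component mapping), get_image_width
  // on a 2D image expects a single component: the two-element size vector is
  // queried and component 0 is extracted below.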
1878   unsigned NumExpectedRetComponents =
1879       Call->ReturnType->getOpcode() == SPIRV::OpTypeVector
1880           ? Call->ReturnType->getOperand(2).getImm()
1881           : 1;
1882   // Get the actual number of query result/size components.
1883   SPIRVType *ImgType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
1884   unsigned NumActualRetComponents = getNumSizeComponents(ImgType);
1885   Register QueryResult = Call->ReturnRegister;
1886   SPIRVType *QueryResultType = Call->ReturnType;
1887   if (NumExpectedRetComponents != NumActualRetComponents) {
1888     unsigned Bitwidth = Call->ReturnType->getOpcode() == SPIRV::OpTypeInt
1889                             ? Call->ReturnType->getOperand(1).getImm()
1890                             : 32;
1891     QueryResult = MIRBuilder.getMRI()->createGenericVirtualRegister(
1892         LLT::fixed_vector(NumActualRetComponents, Bitwidth));
1893     MIRBuilder.getMRI()->setRegClass(QueryResult, &SPIRV::vIDRegClass);
1894     SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(Bitwidth, MIRBuilder);
1895     QueryResultType = GR->getOrCreateSPIRVVectorType(
1896         IntTy, NumActualRetComponents, MIRBuilder, true);
1897     GR->assignSPIRVTypeToVReg(QueryResultType, QueryResult, MIRBuilder.getMF());
1898   }
1899   bool IsDimBuf = ImgType->getOperand(2).getImm() == SPIRV::Dim::DIM_Buffer;
1900   unsigned Opcode =
1901       IsDimBuf ? SPIRV::OpImageQuerySize : SPIRV::OpImageQuerySizeLod;
1902   auto MIB = MIRBuilder.buildInstr(Opcode)
1903                  .addDef(QueryResult)
1904                  .addUse(GR->getSPIRVTypeID(QueryResultType))
1905                  .addUse(Call->Arguments[0]);
1906   if (!IsDimBuf)
1907     MIB.addUse(buildConstantIntReg32(0, MIRBuilder, GR)); // Lod id.
1908   if (NumExpectedRetComponents == NumActualRetComponents)
1909     return true;
1910   if (NumExpectedRetComponents == 1) {
1911     // Only 1 component is expected, build OpCompositeExtract instruction.
1912     unsigned ExtractedComposite =
1913         Component == 3 ? NumActualRetComponents - 1 : Component;
1914     assert(ExtractedComposite < NumActualRetComponents &&
1915            "Invalid composite index!");
1916     Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
1917     SPIRVType *NewType = nullptr;
1918     if (QueryResultType->getOpcode() == SPIRV::OpTypeVector) {
1919       Register NewTypeReg = QueryResultType->getOperand(1).getReg();
1920       if (TypeReg != NewTypeReg &&
1921           (NewType = GR->getSPIRVTypeForVReg(NewTypeReg)) != nullptr)
1922         TypeReg = NewTypeReg;
1923     }
1924     MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
1925         .addDef(Call->ReturnRegister)
1926         .addUse(TypeReg)
1927         .addUse(QueryResult)
1928         .addImm(ExtractedComposite);
1929     if (NewType != nullptr)
1930       insertAssignInstr(Call->ReturnRegister, nullptr, NewType, GR, MIRBuilder,
1931                         MIRBuilder.getMF().getRegInfo());
1932   } else {
1933     // More than 1 component is expected, fill a new vector.
1934     auto MIB = MIRBuilder.buildInstr(SPIRV::OpVectorShuffle)
1935                    .addDef(Call->ReturnRegister)
1936                    .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1937                    .addUse(QueryResult)
1938                    .addUse(QueryResult);
1939     for (unsigned i = 0; i < NumExpectedRetComponents; ++i)
1940       MIB.addImm(i < NumActualRetComponents ? i : 0xffffffff);
1941   }
1942   return true;
1943 }
1944 
1945 static bool generateImageMiscQueryInst(const SPIRV::IncomingCall *Call,
1946                                        MachineIRBuilder &MIRBuilder,
1947                                        SPIRVGlobalRegistry *GR) {
1948   assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt &&
1949          "Image samples query result must be of int type!");
1950 
1951   // Lookup the instruction opcode in the TableGen records.
1952   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1953   unsigned Opcode =
1954       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1955 
1956   Register Image = Call->Arguments[0];
1957   SPIRV::Dim::Dim ImageDimensionality = static_cast<SPIRV::Dim::Dim>(
1958       GR->getSPIRVTypeForVReg(Image)->getOperand(2).getImm());
1959   (void)ImageDimensionality;
1960 
1961   switch (Opcode) {
1962   case SPIRV::OpImageQuerySamples:
1963     assert(ImageDimensionality == SPIRV::Dim::DIM_2D &&
1964            "Image must be of 2D dimensionality");
1965     break;
1966   case SPIRV::OpImageQueryLevels:
1967     assert((ImageDimensionality == SPIRV::Dim::DIM_1D ||
1968             ImageDimensionality == SPIRV::Dim::DIM_2D ||
1969             ImageDimensionality == SPIRV::Dim::DIM_3D ||
1970             ImageDimensionality == SPIRV::Dim::DIM_Cube) &&
1971            "Image must be of 1D/2D/3D/Cube dimensionality");
1972     break;
1973   }
1974 
1975   MIRBuilder.buildInstr(Opcode)
1976       .addDef(Call->ReturnRegister)
1977       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1978       .addUse(Image);
1979   return true;
1980 }
1981 
1982 // TODO: Move to TableGen.
1983 static SPIRV::SamplerAddressingMode::SamplerAddressingMode
1984 getSamplerAddressingModeFromBitmask(unsigned Bitmask) {
1985   switch (Bitmask & SPIRV::CLK_ADDRESS_MODE_MASK) {
1986   case SPIRV::CLK_ADDRESS_CLAMP:
1987     return SPIRV::SamplerAddressingMode::Clamp;
1988   case SPIRV::CLK_ADDRESS_CLAMP_TO_EDGE:
1989     return SPIRV::SamplerAddressingMode::ClampToEdge;
1990   case SPIRV::CLK_ADDRESS_REPEAT:
1991     return SPIRV::SamplerAddressingMode::Repeat;
1992   case SPIRV::CLK_ADDRESS_MIRRORED_REPEAT:
1993     return SPIRV::SamplerAddressingMode::RepeatMirrored;
1994   case SPIRV::CLK_ADDRESS_NONE:
1995     return SPIRV::SamplerAddressingMode::None;
1996   default:
1997     report_fatal_error("Unknown CL address mode");
1998   }
1999 }
2000 
2001 static unsigned getSamplerParamFromBitmask(unsigned Bitmask) {
2002   return (Bitmask & SPIRV::CLK_NORMALIZED_COORDS_TRUE) ? 1 : 0;
2003 }
2004 
2005 static SPIRV::SamplerFilterMode::SamplerFilterMode
2006 getSamplerFilterModeFromBitmask(unsigned Bitmask) {
2007   if (Bitmask & SPIRV::CLK_FILTER_LINEAR)
2008     return SPIRV::SamplerFilterMode::Linear;
2009   if (Bitmask & SPIRV::CLK_FILTER_NEAREST)
2010     return SPIRV::SamplerFilterMode::Nearest;
2011   return SPIRV::SamplerFilterMode::Nearest;
2012 }
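// For instance, an OpenCL literal sampler built from (CLK_NORMALIZED_COORDS_TRUE |
// CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_LINEAR) is decomposed by the helpers
// above into the ClampToEdge addressing mode, a normalized-coordinates flag of 1,
// and the Linear filter mode.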
2013 
2014 static bool generateReadImageInst(const StringRef DemangledCall,
2015                                   const SPIRV::IncomingCall *Call,
2016                                   MachineIRBuilder &MIRBuilder,
2017                                   SPIRVGlobalRegistry *GR) {
2018   if (Call->isSpirvOp())
2019     return buildOpFromWrapper(MIRBuilder, SPIRV::OpImageRead, Call,
2020                               GR->getSPIRVTypeID(Call->ReturnType));
2021   Register Image = Call->Arguments[0];
2022   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2023   bool HasOclSampler = DemangledCall.contains_insensitive("ocl_sampler");
2024   bool HasMsaa = DemangledCall.contains_insensitive("msaa");
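  // Roughly: a read with an OCL sampler becomes OpSampledImage followed by
  // OpImageSampleExplicitLod (with Lod 0); the MSAA variant becomes OpImageRead
  // with a Sample image operand; the plain form becomes OpImageRead alone.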
2025   if (HasOclSampler) {
2026     Register Sampler = Call->Arguments[1];
2027 
2028     if (!GR->isScalarOfType(Sampler, SPIRV::OpTypeSampler) &&
2029         getDefInstrMaybeConstant(Sampler, MRI)->getOperand(1).isCImm()) {
2030       uint64_t SamplerMask = getIConstVal(Sampler, MRI);
2031       Sampler = GR->buildConstantSampler(
2032           Register(), getSamplerAddressingModeFromBitmask(SamplerMask),
2033           getSamplerParamFromBitmask(SamplerMask),
2034           getSamplerFilterModeFromBitmask(SamplerMask), MIRBuilder);
2035     }
2036     SPIRVType *ImageType = GR->getSPIRVTypeForVReg(Image);
2037     SPIRVType *SampledImageType =
2038         GR->getOrCreateOpTypeSampledImage(ImageType, MIRBuilder);
2039     Register SampledImage = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
2040 
2041     MIRBuilder.buildInstr(SPIRV::OpSampledImage)
2042         .addDef(SampledImage)
2043         .addUse(GR->getSPIRVTypeID(SampledImageType))
2044         .addUse(Image)
2045         .addUse(Sampler);
2046 
2047     Register Lod = GR->buildConstantFP(APFloat::getZero(APFloat::IEEEsingle()),
2048                                        MIRBuilder);
2049 
2050     if (Call->ReturnType->getOpcode() != SPIRV::OpTypeVector) {
2051       SPIRVType *TempType =
2052           GR->getOrCreateSPIRVVectorType(Call->ReturnType, 4, MIRBuilder, true);
2053       Register TempRegister =
2054           MRI->createGenericVirtualRegister(GR->getRegType(TempType));
2055       MRI->setRegClass(TempRegister, GR->getRegClass(TempType));
2056       GR->assignSPIRVTypeToVReg(TempType, TempRegister, MIRBuilder.getMF());
2057       MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
2058           .addDef(TempRegister)
2059           .addUse(GR->getSPIRVTypeID(TempType))
2060           .addUse(SampledImage)
2061           .addUse(Call->Arguments[2]) // Coordinate.
2062           .addImm(SPIRV::ImageOperand::Lod)
2063           .addUse(Lod);
2064       MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
2065           .addDef(Call->ReturnRegister)
2066           .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2067           .addUse(TempRegister)
2068           .addImm(0);
2069     } else {
2070       MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
2071           .addDef(Call->ReturnRegister)
2072           .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2073           .addUse(SampledImage)
2074           .addUse(Call->Arguments[2]) // Coordinate.
2075           .addImm(SPIRV::ImageOperand::Lod)
2076           .addUse(Lod);
2077     }
2078   } else if (HasMsaa) {
2079     MIRBuilder.buildInstr(SPIRV::OpImageRead)
2080         .addDef(Call->ReturnRegister)
2081         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2082         .addUse(Image)
2083         .addUse(Call->Arguments[1]) // Coordinate.
2084         .addImm(SPIRV::ImageOperand::Sample)
2085         .addUse(Call->Arguments[2]);
2086   } else {
2087     MIRBuilder.buildInstr(SPIRV::OpImageRead)
2088         .addDef(Call->ReturnRegister)
2089         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2090         .addUse(Image)
2091         .addUse(Call->Arguments[1]); // Coordinate.
2092   }
2093   return true;
2094 }
2095 
2096 static bool generateWriteImageInst(const SPIRV::IncomingCall *Call,
2097                                    MachineIRBuilder &MIRBuilder,
2098                                    SPIRVGlobalRegistry *GR) {
2099   if (Call->isSpirvOp())
2100     return buildOpFromWrapper(MIRBuilder, SPIRV::OpImageWrite, Call,
2101                               Register(0));
2102   MIRBuilder.buildInstr(SPIRV::OpImageWrite)
2103       .addUse(Call->Arguments[0])  // Image.
2104       .addUse(Call->Arguments[1])  // Coordinate.
2105       .addUse(Call->Arguments[2]); // Texel.
2106   return true;
2107 }
2108 
2109 static bool generateSampleImageInst(const StringRef DemangledCall,
2110                                     const SPIRV::IncomingCall *Call,
2111                                     MachineIRBuilder &MIRBuilder,
2112                                     SPIRVGlobalRegistry *GR) {
2113   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2114   if (Call->Builtin->Name.contains_insensitive(
2115           "__translate_sampler_initializer")) {
2116     // Build sampler literal.
2117     uint64_t Bitmask = getIConstVal(Call->Arguments[0], MRI);
2118     Register Sampler = GR->buildConstantSampler(
2119         Call->ReturnRegister, getSamplerAddressingModeFromBitmask(Bitmask),
2120         getSamplerParamFromBitmask(Bitmask),
2121         getSamplerFilterModeFromBitmask(Bitmask), MIRBuilder);
2122     return Sampler.isValid();
2123   } else if (Call->Builtin->Name.contains_insensitive("__spirv_SampledImage")) {
2124     // Create OpSampledImage.
2125     Register Image = Call->Arguments[0];
2126     SPIRVType *ImageType = GR->getSPIRVTypeForVReg(Image);
2127     SPIRVType *SampledImageType =
2128         GR->getOrCreateOpTypeSampledImage(ImageType, MIRBuilder);
2129     Register SampledImage =
2130         Call->ReturnRegister.isValid()
2131             ? Call->ReturnRegister
2132             : MRI->createVirtualRegister(&SPIRV::iIDRegClass);
2133     MIRBuilder.buildInstr(SPIRV::OpSampledImage)
2134         .addDef(SampledImage)
2135         .addUse(GR->getSPIRVTypeID(SampledImageType))
2136         .addUse(Image)
2137         .addUse(Call->Arguments[1]); // Sampler.
2138     return true;
2139   } else if (Call->Builtin->Name.contains_insensitive(
2140                  "__spirv_ImageSampleExplicitLod")) {
2141     // Sample an image using an explicit level of detail.
2142     std::string ReturnType = DemangledCall.str();
2143     if (DemangledCall.contains("_R")) {
2144       ReturnType = ReturnType.substr(ReturnType.find("_R") + 2);
2145       ReturnType = ReturnType.substr(0, ReturnType.find('('));
2146     }
2147     SPIRVType *Type =
2148         Call->ReturnType
2149             ? Call->ReturnType
2150             : GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder, true);
2151     if (!Type) {
2152       std::string DiagMsg =
2153           "Unable to recognize SPIRV type name: " + ReturnType;
2154       report_fatal_error(DiagMsg.c_str());
2155     }
2156     MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
2157         .addDef(Call->ReturnRegister)
2158         .addUse(GR->getSPIRVTypeID(Type))
2159         .addUse(Call->Arguments[0]) // Image.
2160         .addUse(Call->Arguments[1]) // Coordinate.
2161         .addImm(SPIRV::ImageOperand::Lod)
2162         .addUse(Call->Arguments[3]);
2163     return true;
2164   }
2165   return false;
2166 }
2167 
2168 static bool generateSelectInst(const SPIRV::IncomingCall *Call,
2169                                MachineIRBuilder &MIRBuilder) {
2170   MIRBuilder.buildSelect(Call->ReturnRegister, Call->Arguments[0],
2171                          Call->Arguments[1], Call->Arguments[2]);
2172   return true;
2173 }
2174 
2175 static bool generateConstructInst(const SPIRV::IncomingCall *Call,
2176                                   MachineIRBuilder &MIRBuilder,
2177                                   SPIRVGlobalRegistry *GR) {
2178   createContinuedInstructions(MIRBuilder, SPIRV::OpCompositeConstruct, 3,
2179                               SPIRV::OpCompositeConstructContinuedINTEL,
2180                               Call->Arguments, Call->ReturnRegister,
2181                               GR->getSPIRVTypeID(Call->ReturnType));
2182   return true;
2183 }
2184 
2185 static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
2186                                  MachineIRBuilder &MIRBuilder,
2187                                  SPIRVGlobalRegistry *GR) {
2188   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2189   unsigned Opcode =
2190       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2191   bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR &&
2192                Opcode != SPIRV::OpCooperativeMatrixStoreCheckedINTEL &&
2193                Opcode != SPIRV::OpCooperativeMatrixPrefetchINTEL;
2194   unsigned ArgSz = Call->Arguments.size();
2195   unsigned LiteralIdx = 0;
2196   switch (Opcode) {
2197   // The memory operand is optional and is a literal.
2198   case SPIRV::OpCooperativeMatrixLoadKHR:
2199     LiteralIdx = ArgSz > 3 ? 3 : 0;
2200     break;
2201   case SPIRV::OpCooperativeMatrixStoreKHR:
2202     LiteralIdx = ArgSz > 4 ? 4 : 0;
2203     break;
2204   case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
2205     LiteralIdx = ArgSz > 7 ? 7 : 0;
2206     break;
2207   case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
2208     LiteralIdx = ArgSz > 8 ? 8 : 0;
2209     break;
2210   // The Cooperative Matrix Operands operand is optional and is a literal.
2211   case SPIRV::OpCooperativeMatrixMulAddKHR:
2212     LiteralIdx = ArgSz > 3 ? 3 : 0;
2213     break;
2214   }
2215 
2216   SmallVector<uint32_t, 1> ImmArgs;
2217   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2218   if (Opcode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
2219     const uint32_t CacheLevel = getConstFromIntrinsic(Call->Arguments[3], MRI);
2220     auto MIB = MIRBuilder.buildInstr(SPIRV::OpCooperativeMatrixPrefetchINTEL)
2221                    .addUse(Call->Arguments[0])  // pointer
2222                    .addUse(Call->Arguments[1])  // rows
2223                    .addUse(Call->Arguments[2])  // columns
2224                    .addImm(CacheLevel)          // cache level
2225                    .addUse(Call->Arguments[4]); // memory layout
2226     if (ArgSz > 5)
2227       MIB.addUse(Call->Arguments[5]); // stride
2228     if (ArgSz > 6) {
2229       const uint32_t MemOp = getConstFromIntrinsic(Call->Arguments[6], MRI);
2230       MIB.addImm(MemOp); // memory operand
2231     }
2232     return true;
2233   }
2234   if (LiteralIdx > 0)
2235     ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI));
2236   Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
2237   if (Opcode == SPIRV::OpCooperativeMatrixLengthKHR) {
2238     SPIRVType *CoopMatrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
2239     if (!CoopMatrType)
2240       report_fatal_error("Can't find a register's type definition");
2241     MIRBuilder.buildInstr(Opcode)
2242         .addDef(Call->ReturnRegister)
2243         .addUse(TypeReg)
2244         .addUse(CoopMatrType->getOperand(0).getReg());
2245     return true;
2246   }
2247   return buildOpFromWrapper(MIRBuilder, Opcode, Call,
2248                             IsSet ? TypeReg : Register(0), ImmArgs);
2249 }
2250 
2251 static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
2252                                      MachineIRBuilder &MIRBuilder,
2253                                      SPIRVGlobalRegistry *GR) {
2254   // Lookup the instruction opcode in the TableGen records.
2255   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2256   unsigned Opcode =
2257       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2258   const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2259 
2260   switch (Opcode) {
2261   case SPIRV::OpSpecConstant: {
2262     // Build the SpecID decoration.
2263     unsigned SpecId =
2264         static_cast<unsigned>(getIConstVal(Call->Arguments[0], MRI));
2265     buildOpDecorate(Call->ReturnRegister, MIRBuilder, SPIRV::Decoration::SpecId,
2266                     {SpecId});
2267     // Determine the constant MI.
2268     Register ConstRegister = Call->Arguments[1];
2269     const MachineInstr *Const = getDefInstrMaybeConstant(ConstRegister, MRI);
2270     assert(Const &&
2271            (Const->getOpcode() == TargetOpcode::G_CONSTANT ||
2272             Const->getOpcode() == TargetOpcode::G_FCONSTANT) &&
2273            "Argument should be either an int or floating-point constant");
2274     // Determine the opcode and build the OpSpec MI.
2275     const MachineOperand &ConstOperand = Const->getOperand(1);
2276     if (Call->ReturnType->getOpcode() == SPIRV::OpTypeBool) {
2277       assert(ConstOperand.isCImm() && "Int constant operand is expected");
2278       Opcode = ConstOperand.getCImm()->getValue().getZExtValue()
2279                    ? SPIRV::OpSpecConstantTrue
2280                    : SPIRV::OpSpecConstantFalse;
2281     }
2282     auto MIB = MIRBuilder.buildInstr(Opcode)
2283                    .addDef(Call->ReturnRegister)
2284                    .addUse(GR->getSPIRVTypeID(Call->ReturnType));
2285 
2286     if (Call->ReturnType->getOpcode() != SPIRV::OpTypeBool) {
2287       if (Const->getOpcode() == TargetOpcode::G_CONSTANT)
2288         addNumImm(ConstOperand.getCImm()->getValue(), MIB);
2289       else
2290         addNumImm(ConstOperand.getFPImm()->getValueAPF().bitcastToAPInt(), MIB);
2291     }
2292     return true;
2293   }
2294   case SPIRV::OpSpecConstantComposite: {
2295     createContinuedInstructions(MIRBuilder, Opcode, 3,
2296                                 SPIRV::OpSpecConstantCompositeContinuedINTEL,
2297                                 Call->Arguments, Call->ReturnRegister,
2298                                 GR->getSPIRVTypeID(Call->ReturnType));
2299     return true;
2300   }
2301   default:
2302     return false;
2303   }
2304 }
2305 
2306 static bool generateExtendedBitOpsInst(const SPIRV::IncomingCall *Call,
2307                                        MachineIRBuilder &MIRBuilder,
2308                                        SPIRVGlobalRegistry *GR) {
2309   // Lookup the instruction opcode in the TableGen records.
2310   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2311   unsigned Opcode =
2312       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2313 
2314   return buildExtendedBitOpsInst(Call, Opcode, MIRBuilder, GR);
2315 }
2316 
2317 static bool generateBindlessImageINTELInst(const SPIRV::IncomingCall *Call,
2318                                            MachineIRBuilder &MIRBuilder,
2319                                            SPIRVGlobalRegistry *GR) {
2320   // Lookup the instruction opcode in the TableGen records.
2321   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2322   unsigned Opcode =
2323       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2324 
2325   return buildBindlessImageINTELInst(Call, Opcode, MIRBuilder, GR);
2326 }
2327 
2328 static bool
2329 generateTernaryBitwiseFunctionINTELInst(const SPIRV::IncomingCall *Call,
2330                                         MachineIRBuilder &MIRBuilder,
2331                                         SPIRVGlobalRegistry *GR) {
2332   // Lookup the instruction opcode in the TableGen records.
2333   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2334   unsigned Opcode =
2335       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2336 
2337   return buildTernaryBitwiseFunctionINTELInst(Call, Opcode, MIRBuilder, GR);
2338 }
2339 
2340 static bool generate2DBlockIOINTELInst(const SPIRV::IncomingCall *Call,
2341                                        MachineIRBuilder &MIRBuilder,
2342                                        SPIRVGlobalRegistry *GR) {
2343   // Lookup the instruction opcode in the TableGen records.
2344   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2345   unsigned Opcode =
2346       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2347 
2348   return build2DBlockIOINTELInst(Call, Opcode, MIRBuilder, GR);
2349 }
2350 
2351 static bool buildNDRange(const SPIRV::IncomingCall *Call,
2352                          MachineIRBuilder &MIRBuilder,
2353                          SPIRVGlobalRegistry *GR) {
2354   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2355   SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
2356   assert(PtrType->getOpcode() == SPIRV::OpTypePointer &&
2357          PtrType->getOperand(2).isReg());
2358   Register TypeReg = PtrType->getOperand(2).getReg();
2359   SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg);
2360   MachineFunction &MF = MIRBuilder.getMF();
2361   Register TmpReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
2362   GR->assignSPIRVTypeToVReg(StructType, TmpReg, MF);
2363   // Skip the first arg, it's the destination pointer. OpBuildNDRange takes
2364   // three other arguments, so pass a zero constant for any that are absent.
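  // The OpenCL C ndrange_{1,2,3}D overloads (with the sret destination
  // prepended) thus yield 2 arguments (global size only), 3 arguments
  // (global + local size) or 4 arguments (offset + global + local size);
  // the code below normalizes the missing operands to zero constants.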
2365   unsigned NumArgs = Call->Arguments.size();
2366   assert(NumArgs >= 2);
2367   Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2];
2368   Register LocalWorkSize =
2369       NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3];
2370   Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1];
2371   if (NumArgs < 4) {
2372     Register Const;
2373     SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize);
2374     if (SpvTy->getOpcode() == SPIRV::OpTypePointer) {
2375       MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize);
2376       assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) &&
2377              DefInstr->getOperand(3).isReg());
2378       Register GWSPtr = DefInstr->getOperand(3).getReg();
2379       // TODO: Maybe simplify generation of the type of the fields.
2380       unsigned Size = Call->Builtin->Name == "ndrange_3D" ? 3 : 2;
2381       unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
2382       Type *BaseTy = IntegerType::get(MF.getFunction().getContext(), BitWidth);
2383       Type *FieldTy = ArrayType::get(BaseTy, Size);
2384       SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(
2385           FieldTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
2386       GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
2387       GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize, MF);
2388       MIRBuilder.buildInstr(SPIRV::OpLoad)
2389           .addDef(GlobalWorkSize)
2390           .addUse(GR->getSPIRVTypeID(SpvFieldTy))
2391           .addUse(GWSPtr);
2392       const SPIRVSubtarget &ST =
2393           cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
2394       Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
2395                                            SpvFieldTy, *ST.getInstrInfo());
2396     } else {
2397       Const = GR->buildConstantInt(0, MIRBuilder, SpvTy, true);
2398     }
2399     if (!LocalWorkSize.isValid())
2400       LocalWorkSize = Const;
2401     if (!GlobalWorkOffset.isValid())
2402       GlobalWorkOffset = Const;
2403   }
2404   assert(LocalWorkSize.isValid() && GlobalWorkOffset.isValid());
2405   MIRBuilder.buildInstr(SPIRV::OpBuildNDRange)
2406       .addDef(TmpReg)
2407       .addUse(TypeReg)
2408       .addUse(GlobalWorkSize)
2409       .addUse(LocalWorkSize)
2410       .addUse(GlobalWorkOffset);
2411   return MIRBuilder.buildInstr(SPIRV::OpStore)
2412       .addUse(Call->Arguments[0])
2413       .addUse(TmpReg);
2414 }
2415 
2416 // TODO: maybe move to the global registry.
2417 static SPIRVType *
2418 getOrCreateSPIRVDeviceEventPointer(MachineIRBuilder &MIRBuilder,
2419                                    SPIRVGlobalRegistry *GR) {
2420   LLVMContext &Context = MIRBuilder.getMF().getFunction().getContext();
2421   unsigned SC1 = storageClassToAddressSpace(SPIRV::StorageClass::Generic);
2422   Type *PtrType = PointerType::get(Context, SC1);
2423   return GR->getOrCreateSPIRVType(PtrType, MIRBuilder,
2424                                   SPIRV::AccessQualifier::ReadWrite, true);
2425 }
2426 
2427 static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
2428                                MachineIRBuilder &MIRBuilder,
2429                                SPIRVGlobalRegistry *GR) {
2430   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2431   const DataLayout &DL = MIRBuilder.getDataLayout();
2432   bool IsSpirvOp = Call->isSpirvOp();
2433   bool HasEvents = Call->Builtin->Name.contains("events") || IsSpirvOp;
2434   const SPIRVType *Int32Ty = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
2435 
2436   // Emit the vararg instructions before OpEnqueueKernel.
2437   // Local size arguments: sizes of the block invoke arguments. Clang generates
2438   // the local size operands as an array, so we need to unpack them here.
2439   SmallVector<Register, 16> LocalSizes;
2440   if (Call->Builtin->Name.contains("_varargs") || IsSpirvOp) {
2441     const unsigned LocalSizeArrayIdx = HasEvents ? 9 : 6;
2442     Register GepReg = Call->Arguments[LocalSizeArrayIdx];
2443     MachineInstr *GepMI = MRI->getUniqueVRegDef(GepReg);
2444     assert(isSpvIntrinsic(*GepMI, Intrinsic::spv_gep) &&
2445            GepMI->getOperand(3).isReg());
2446     Register ArrayReg = GepMI->getOperand(3).getReg();
2447     MachineInstr *ArrayMI = MRI->getUniqueVRegDef(ArrayReg);
2448     const Type *LocalSizeTy = getMachineInstrType(ArrayMI);
2449     assert(LocalSizeTy && "Local size type is expected");
2450     const uint64_t LocalSizeNum =
2451         cast<ArrayType>(LocalSizeTy)->getNumElements();
2452     unsigned SC = storageClassToAddressSpace(SPIRV::StorageClass::Generic);
2453     const LLT LLType = LLT::pointer(SC, GR->getPointerSize());
2454     const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
2455         Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
2456     for (unsigned I = 0; I < LocalSizeNum; ++I) {
2457       Register Reg = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
2458       MRI->setType(Reg, LLType);
2459       GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
2460       auto GEPInst = MIRBuilder.buildIntrinsic(
2461           Intrinsic::spv_gep, ArrayRef<Register>{Reg}, true, false);
2462       GEPInst
2463           .addImm(GepMI->getOperand(2).getImm())            // Inbounds flag.
2464           .addUse(ArrayMI->getOperand(0).getReg())          // Alloca.
2465           .addUse(buildConstantIntReg32(0, MIRBuilder, GR)) // Indices.
2466           .addUse(buildConstantIntReg32(I, MIRBuilder, GR));
2467       LocalSizes.push_back(Reg);
2468     }
2469   }
2470 
2471   // The SPIR-V OpEnqueueKernel instruction takes 10+ operands.
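  // Per the SPIR-V specification the operands are: Queue, Flags, ND Range,
  // Num Events, Wait Events, Ret Event, Invoke, Param, Param Size,
  // Param Align, followed by a variadic list of Local Size operands; the
  // code below emits them in that order.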
2472   auto MIB = MIRBuilder.buildInstr(SPIRV::OpEnqueueKernel)
2473                  .addDef(Call->ReturnRegister)
2474                  .addUse(GR->getSPIRVTypeID(Int32Ty));
2475 
2476   // Copy all arguments before block invoke function pointer.
2477   const unsigned BlockFIdx = HasEvents ? 6 : 3;
2478   for (unsigned i = 0; i < BlockFIdx; i++)
2479     MIB.addUse(Call->Arguments[i]);
2480 
2481   // If there are no event arguments in the original call, add dummy ones.
2482   if (!HasEvents) {
2483     MIB.addUse(buildConstantIntReg32(0, MIRBuilder, GR)); // Dummy num events.
2484     Register NullPtr = GR->getOrCreateConstNullPtr(
2485         MIRBuilder, getOrCreateSPIRVDeviceEventPointer(MIRBuilder, GR));
2486     MIB.addUse(NullPtr); // Dummy wait events.
2487     MIB.addUse(NullPtr); // Dummy ret event.
2488   }
2489 
2490   MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
2491   assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
2492   // Invoke: Pointer to invoke function.
2493   MIB.addGlobalAddress(BlockMI->getOperand(1).getGlobal());
2494 
2495   Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
2496   // Param: Pointer to block literal.
2497   MIB.addUse(BlockLiteralReg);
2498 
2499   Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
2500   // TODO: these numbers should be obtained from block literal structure.
2501   // Param Size: Size of block literal structure.
2502   MIB.addUse(buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
2503   // Param Alignment: Alignment of the block literal structure.
2504   MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
2505                                    MIRBuilder, GR));
2506 
2507   for (unsigned i = 0; i < LocalSizes.size(); i++)
2508     MIB.addUse(LocalSizes[i]);
2509   return true;
2510 }
2511 
2512 static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
2513                                 MachineIRBuilder &MIRBuilder,
2514                                 SPIRVGlobalRegistry *GR) {
2515   // Lookup the instruction opcode in the TableGen records.
2516   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2517   unsigned Opcode =
2518       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2519 
2520   switch (Opcode) {
2521   case SPIRV::OpRetainEvent:
2522   case SPIRV::OpReleaseEvent:
2523     return MIRBuilder.buildInstr(Opcode).addUse(Call->Arguments[0]);
2524   case SPIRV::OpCreateUserEvent:
2525   case SPIRV::OpGetDefaultQueue:
2526     return MIRBuilder.buildInstr(Opcode)
2527         .addDef(Call->ReturnRegister)
2528         .addUse(GR->getSPIRVTypeID(Call->ReturnType));
2529   case SPIRV::OpIsValidEvent:
2530     return MIRBuilder.buildInstr(Opcode)
2531         .addDef(Call->ReturnRegister)
2532         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2533         .addUse(Call->Arguments[0]);
2534   case SPIRV::OpSetUserEventStatus:
2535     return MIRBuilder.buildInstr(Opcode)
2536         .addUse(Call->Arguments[0])
2537         .addUse(Call->Arguments[1]);
2538   case SPIRV::OpCaptureEventProfilingInfo:
2539     return MIRBuilder.buildInstr(Opcode)
2540         .addUse(Call->Arguments[0])
2541         .addUse(Call->Arguments[1])
2542         .addUse(Call->Arguments[2]);
2543   case SPIRV::OpBuildNDRange:
2544     return buildNDRange(Call, MIRBuilder, GR);
2545   case SPIRV::OpEnqueueKernel:
2546     return buildEnqueueKernel(Call, MIRBuilder, GR);
2547   default:
2548     return false;
2549   }
2550 }
2551 
2552 static bool generateAsyncCopy(const SPIRV::IncomingCall *Call,
2553                               MachineIRBuilder &MIRBuilder,
2554                               SPIRVGlobalRegistry *GR) {
2555   // Lookup the instruction opcode in the TableGen records.
2556   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2557   unsigned Opcode =
2558       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2559 
2560   bool IsSet = Opcode == SPIRV::OpGroupAsyncCopy;
2561   Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
2562   if (Call->isSpirvOp())
2563     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
2564                               IsSet ? TypeReg : Register(0));
2565 
2566   auto Scope = buildConstantIntReg32(SPIRV::Scope::Workgroup, MIRBuilder, GR);
2567 
2568   switch (Opcode) {
2569   case SPIRV::OpGroupAsyncCopy: {
2570     SPIRVType *NewType =
2571         Call->ReturnType->getOpcode() == SPIRV::OpTypeEvent
2572             ? nullptr
2573             : GR->getOrCreateSPIRVTypeByName("spirv.Event", MIRBuilder, true);
2574     Register TypeReg = GR->getSPIRVTypeID(NewType ? NewType : Call->ReturnType);
2575     unsigned NumArgs = Call->Arguments.size();
2576     Register EventReg = Call->Arguments[NumArgs - 1];
2577     bool Res = MIRBuilder.buildInstr(Opcode)
2578                    .addDef(Call->ReturnRegister)
2579                    .addUse(TypeReg)
2580                    .addUse(Scope)
2581                    .addUse(Call->Arguments[0])
2582                    .addUse(Call->Arguments[1])
2583                    .addUse(Call->Arguments[2])
2584                    .addUse(Call->Arguments.size() > 4
2585                                ? Call->Arguments[3]
2586                                : buildConstantIntReg32(1, MIRBuilder, GR))
2587                    .addUse(EventReg);
2588     if (NewType != nullptr)
2589       insertAssignInstr(Call->ReturnRegister, nullptr, NewType, GR, MIRBuilder,
2590                         MIRBuilder.getMF().getRegInfo());
2591     return Res;
2592   }
2593   case SPIRV::OpGroupWaitEvents:
2594     return MIRBuilder.buildInstr(Opcode)
2595         .addUse(Scope)
2596         .addUse(Call->Arguments[0])
2597         .addUse(Call->Arguments[1]);
2598   default:
2599     return false;
2600   }
2601 }
2602 
2603 static bool generateConvertInst(const StringRef DemangledCall,
2604                                 const SPIRV::IncomingCall *Call,
2605                                 MachineIRBuilder &MIRBuilder,
2606                                 SPIRVGlobalRegistry *GR) {
2607   // Lookup the conversion builtin in the TableGen records.
2608   const SPIRV::ConvertBuiltin *Builtin =
2609       SPIRV::lookupConvertBuiltin(Call->Builtin->Name, Call->Builtin->Set);
2610 
2611   if (!Builtin && Call->isSpirvOp()) {
2612     const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2613     unsigned Opcode =
2614         SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2615     return buildOpFromWrapper(MIRBuilder, Opcode, Call,
2616                               GR->getSPIRVTypeID(Call->ReturnType));
2617   }
2618 
2619   if (Builtin->IsSaturated)
2620     buildOpDecorate(Call->ReturnRegister, MIRBuilder,
2621                     SPIRV::Decoration::SaturatedConversion, {});
2622   if (Builtin->IsRounded)
2623     buildOpDecorate(Call->ReturnRegister, MIRBuilder,
2624                     SPIRV::Decoration::FPRoundingMode,
2625                     {(unsigned)Builtin->RoundingMode});
2626 
2627   std::string NeedExtMsg;              // no errors if empty
2628   bool IsRightComponentsNumber = true; // argument/result component counts must match
2629   unsigned Opcode = SPIRV::OpNop;
2630   if (GR->isScalarOrVectorOfType(Call->Arguments[0], SPIRV::OpTypeInt)) {
2631     // Int -> ...
2632     if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) {
2633       // Int -> Int
2634       if (Builtin->IsSaturated)
2635         Opcode = Builtin->IsDestinationSigned ? SPIRV::OpSatConvertUToS
2636                                               : SPIRV::OpSatConvertSToU;
2637       else
2638         Opcode = Builtin->IsDestinationSigned ? SPIRV::OpUConvert
2639                                               : SPIRV::OpSConvert;
2640     } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
2641                                           SPIRV::OpTypeFloat)) {
2642       // Int -> Float
2643       if (Builtin->IsBfloat16) {
2644         const auto *ST = static_cast<const SPIRVSubtarget *>(
2645             &MIRBuilder.getMF().getSubtarget());
2646         if (!ST->canUseExtension(
2647                 SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
2648           NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
2649         IsRightComponentsNumber =
2650             GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
2651             GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
2652         Opcode = SPIRV::OpConvertBF16ToFINTEL;
2653       } else {
2654         bool IsSourceSigned =
2655             DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
2656         Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
2657       }
2658     }
2659   } else if (GR->isScalarOrVectorOfType(Call->Arguments[0],
2660                                         SPIRV::OpTypeFloat)) {
2661     // Float -> ...
2662     if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) {
2663       // Float -> Int
2664       if (Builtin->IsBfloat16) {
2665         const auto *ST = static_cast<const SPIRVSubtarget *>(
2666             &MIRBuilder.getMF().getSubtarget());
2667         if (!ST->canUseExtension(
2668                 SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
2669           NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
2670         IsRightComponentsNumber =
2671             GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
2672             GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
2673         Opcode = SPIRV::OpConvertFToBF16INTEL;
2674       } else {
2675         Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
2676                                               : SPIRV::OpConvertFToU;
2677       }
2678     } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
2679                                           SPIRV::OpTypeFloat)) {
2680       // Float -> Float
2681       Opcode = SPIRV::OpFConvert;
2682     }
2683   }
2684 
2685   if (!NeedExtMsg.empty()) {
2686     std::string DiagMsg = std::string(Builtin->Name) +
2687                           ": the builtin requires the following SPIR-V "
2688                           "extension: " +
2689                           NeedExtMsg;
2690     report_fatal_error(DiagMsg.c_str(), false);
2691   }
2692   if (!IsRightComponentsNumber) {
2693     std::string DiagMsg =
2694         std::string(Builtin->Name) +
2695         ": result and argument must have the same number of components";
2696     report_fatal_error(DiagMsg.c_str(), false);
2697   }
2698   assert(Opcode != SPIRV::OpNop &&
2699          "Conversion between the types not implemented!");
2700 
2701   MIRBuilder.buildInstr(Opcode)
2702       .addDef(Call->ReturnRegister)
2703       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2704       .addUse(Call->Arguments[0]);
2705   return true;
2706 }
2707 
2708 static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
2709                                         MachineIRBuilder &MIRBuilder,
2710                                         SPIRVGlobalRegistry *GR) {
2711   // Lookup the vector load/store builtin in the TableGen records.
2712   const SPIRV::VectorLoadStoreBuiltin *Builtin =
2713       SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
2714                                           Call->Builtin->Set);
2715   // Build extended instruction.
2716   auto MIB =
2717       MIRBuilder.buildInstr(SPIRV::OpExtInst)
2718           .addDef(Call->ReturnRegister)
2719           .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2720           .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
2721           .addImm(Builtin->Number);
2722   for (auto Argument : Call->Arguments)
2723     MIB.addUse(Argument);
2724   if (Builtin->Name.contains("load") && Builtin->ElementCount > 1)
2725     MIB.addImm(Builtin->ElementCount);
2726 
2727   // The rounding mode should be passed as the last argument in the MI for
2728   // builtins like "vstorea_halfn_r".
2729   if (Builtin->IsRounded)
2730     MIB.addImm(static_cast<uint32_t>(Builtin->RoundingMode));
2731   return true;
2732 }
2733 
2734 static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
2735                                   MachineIRBuilder &MIRBuilder,
2736                                   SPIRVGlobalRegistry *GR) {
2737   // Lookup the instruction opcode in the TableGen records.
2738   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2739   unsigned Opcode =
2740       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2741   bool IsLoad = Opcode == SPIRV::OpLoad;
2742   // Build the instruction.
2743   auto MIB = MIRBuilder.buildInstr(Opcode);
2744   if (IsLoad) {
2745     MIB.addDef(Call->ReturnRegister);
2746     MIB.addUse(GR->getSPIRVTypeID(Call->ReturnType));
2747   }
2748   // Add a pointer to the value to load/store.
2749   MIB.addUse(Call->Arguments[0]);
2750   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2751   // Add a value to store.
2752   if (!IsLoad)
2753     MIB.addUse(Call->Arguments[1]);
2754   // Add optional memory attributes and an alignment.
2755   unsigned NumArgs = Call->Arguments.size();
2756   if ((IsLoad && NumArgs >= 2) || NumArgs >= 3)
2757     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 1 : 2], MRI));
2758   if ((IsLoad && NumArgs >= 3) || NumArgs >= 4)
2759     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 2 : 3], MRI));
2760   return true;
2761 }
2762 
2763 namespace SPIRV {
2764 // Try to find builtin function attributes by a demangled function name and
2765 // return a tuple <builtin group, op code, ext instruction number>, or the
2766 // special tuple value <-1, 0, 0> if the builtin function is not found.
2767 // Not all builtin functions are supported, only those with a ready-to-use op
2768 // code or instruction number defined in TableGen.
2769 // TODO: consider a major rework of mapping demangled calls to builtin
2770 // functions to unify the search and decrease the number of individual cases.
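// A minimal usage sketch (illustrative only; DemangledName is a hypothetical
// variable, and the actual opcode and instruction numbers come from the
// generated TableGen tables):
//   auto [Group, Opcode, ExtNo] =
//       SPIRV::mapBuiltinToOpcode(DemangledName,
//                                 SPIRV::InstructionSet::OpenCL_std);
//   if (Group == -1) {
//     // Not a recognized builtin with a ready-to-use opcode.
//   }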
2771 std::tuple<int, unsigned, unsigned>
2772 mapBuiltinToOpcode(const StringRef DemangledCall,
2773                    SPIRV::InstructionSet::InstructionSet Set) {
2774   Register Reg;
2775   SmallVector<Register> Args;
2776   std::unique_ptr<const IncomingCall> Call =
2777       lookupBuiltin(DemangledCall, Set, Reg, nullptr, Args);
2778   if (!Call)
2779     return std::make_tuple(-1, 0, 0);
2780 
2781   switch (Call->Builtin->Group) {
2782   case SPIRV::Relational:
2783   case SPIRV::Atomic:
2784   case SPIRV::Barrier:
2785   case SPIRV::CastToPtr:
2786   case SPIRV::ImageMiscQuery:
2787   case SPIRV::SpecConstant:
2788   case SPIRV::Enqueue:
2789   case SPIRV::AsyncCopy:
2790   case SPIRV::LoadStore:
2791   case SPIRV::CoopMatr:
2792     if (const auto *R =
2793             SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
2794       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2795     break;
2796   case SPIRV::Extended:
2797     if (const auto *R = SPIRV::lookupExtendedBuiltin(Call->Builtin->Name,
2798                                                      Call->Builtin->Set))
2799       return std::make_tuple(Call->Builtin->Group, 0, R->Number);
2800     break;
2801   case SPIRV::VectorLoadStore:
2802     if (const auto *R = SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
2803                                                             Call->Builtin->Set))
2804       return std::make_tuple(SPIRV::Extended, 0, R->Number);
2805     break;
2806   case SPIRV::Group:
2807     if (const auto *R = SPIRV::lookupGroupBuiltin(Call->Builtin->Name))
2808       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2809     break;
2810   case SPIRV::AtomicFloating:
2811     if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Call->Builtin->Name))
2812       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2813     break;
2814   case SPIRV::IntelSubgroups:
2815     if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Call->Builtin->Name))
2816       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2817     break;
2818   case SPIRV::GroupUniform:
2819     if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Call->Builtin->Name))
2820       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2821     break;
2822   case SPIRV::IntegerDot:
2823     if (const auto *R =
2824             SPIRV::lookupIntegerDotProductBuiltin(Call->Builtin->Name))
2825       return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2826     break;
2827   case SPIRV::WriteImage:
2828     return std::make_tuple(Call->Builtin->Group, SPIRV::OpImageWrite, 0);
2829   case SPIRV::Select:
2830     return std::make_tuple(Call->Builtin->Group, TargetOpcode::G_SELECT, 0);
2831   case SPIRV::Construct:
2832     return std::make_tuple(Call->Builtin->Group, SPIRV::OpCompositeConstruct,
2833                            0);
2834   case SPIRV::KernelClock:
2835     return std::make_tuple(Call->Builtin->Group, SPIRV::OpReadClockKHR, 0);
2836   default:
2837     return std::make_tuple(-1, 0, 0);
2838   }
2839   return std::make_tuple(-1, 0, 0);
2840 }
2841 
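// Lower a demangled builtin call into SPIR-V MIR. Returns std::nullopt if no
// matching builtin record is found, otherwise the boolean result of the
// group-specific generator (false signals a recognized group that could not
// be lowered).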
2842 std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
2843                                  SPIRV::InstructionSet::InstructionSet Set,
2844                                  MachineIRBuilder &MIRBuilder,
2845                                  const Register OrigRet, const Type *OrigRetTy,
2846                                  const SmallVectorImpl<Register> &Args,
2847                                  SPIRVGlobalRegistry *GR) {
2848   LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
2849 
2850   // Lookup the builtin in the TableGen records.
2851   SPIRVType *SpvType = GR->getSPIRVTypeForVReg(OrigRet);
2852   assert(SpvType && "Inconsistent return register: expected valid type info");
2853   std::unique_ptr<const IncomingCall> Call =
2854       lookupBuiltin(DemangledCall, Set, OrigRet, SpvType, Args);
2855 
2856   if (!Call) {
2857     LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");
2858     return std::nullopt;
2859   }
2860 
2861   // TODO: check if the provided args meet the builtin requirements.
2862   assert(Args.size() >= Call->Builtin->MinNumArgs &&
2863          "Too few arguments to generate the builtin");
2864   if (Call->Builtin->MaxNumArgs && Args.size() > Call->Builtin->MaxNumArgs)
2865     LLVM_DEBUG(dbgs() << "More arguments provided than expected!\n");
2866 
2867   // Match the builtin with implementation based on the grouping.
2868   switch (Call->Builtin->Group) {
2869   case SPIRV::Extended:
2870     return generateExtInst(Call.get(), MIRBuilder, GR);
2871   case SPIRV::Relational:
2872     return generateRelationalInst(Call.get(), MIRBuilder, GR);
2873   case SPIRV::Group:
2874     return generateGroupInst(Call.get(), MIRBuilder, GR);
2875   case SPIRV::Variable:
2876     return generateBuiltinVar(Call.get(), MIRBuilder, GR);
2877   case SPIRV::Atomic:
2878     return generateAtomicInst(Call.get(), MIRBuilder, GR);
2879   case SPIRV::AtomicFloating:
2880     return generateAtomicFloatingInst(Call.get(), MIRBuilder, GR);
2881   case SPIRV::Barrier:
2882     return generateBarrierInst(Call.get(), MIRBuilder, GR);
2883   case SPIRV::CastToPtr:
2884     return generateCastToPtrInst(Call.get(), MIRBuilder, GR);
2885   case SPIRV::Dot:
2886   case SPIRV::IntegerDot:
2887     return generateDotOrFMulInst(DemangledCall, Call.get(), MIRBuilder, GR);
2888   case SPIRV::Wave:
2889     return generateWaveInst(Call.get(), MIRBuilder, GR);
2890   case SPIRV::ICarryBorrow:
2891     return generateICarryBorrowInst(Call.get(), MIRBuilder, GR);
2892   case SPIRV::GetQuery:
2893     return generateGetQueryInst(Call.get(), MIRBuilder, GR);
2894   case SPIRV::ImageSizeQuery:
2895     return generateImageSizeQueryInst(Call.get(), MIRBuilder, GR);
2896   case SPIRV::ImageMiscQuery:
2897     return generateImageMiscQueryInst(Call.get(), MIRBuilder, GR);
2898   case SPIRV::ReadImage:
2899     return generateReadImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
2900   case SPIRV::WriteImage:
2901     return generateWriteImageInst(Call.get(), MIRBuilder, GR);
2902   case SPIRV::SampleImage:
2903     return generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
2904   case SPIRV::Select:
2905     return generateSelectInst(Call.get(), MIRBuilder);
2906   case SPIRV::Construct:
2907     return generateConstructInst(Call.get(), MIRBuilder, GR);
2908   case SPIRV::SpecConstant:
2909     return generateSpecConstantInst(Call.get(), MIRBuilder, GR);
2910   case SPIRV::Enqueue:
2911     return generateEnqueueInst(Call.get(), MIRBuilder, GR);
2912   case SPIRV::AsyncCopy:
2913     return generateAsyncCopy(Call.get(), MIRBuilder, GR);
2914   case SPIRV::Convert:
2915     return generateConvertInst(DemangledCall, Call.get(), MIRBuilder, GR);
2916   case SPIRV::VectorLoadStore:
2917     return generateVectorLoadStoreInst(Call.get(), MIRBuilder, GR);
2918   case SPIRV::LoadStore:
2919     return generateLoadStoreInst(Call.get(), MIRBuilder, GR);
2920   case SPIRV::IntelSubgroups:
2921     return generateIntelSubgroupsInst(Call.get(), MIRBuilder, GR);
2922   case SPIRV::GroupUniform:
2923     return generateGroupUniformInst(Call.get(), MIRBuilder, GR);
2924   case SPIRV::KernelClock:
2925     return generateKernelClockInst(Call.get(), MIRBuilder, GR);
2926   case SPIRV::CoopMatr:
2927     return generateCoopMatrInst(Call.get(), MIRBuilder, GR);
2928   case SPIRV::ExtendedBitOps:
2929     return generateExtendedBitOpsInst(Call.get(), MIRBuilder, GR);
2930   case SPIRV::BindlessINTEL:
2931     return generateBindlessImageINTELInst(Call.get(), MIRBuilder, GR);
2932   case SPIRV::TernaryBitwiseINTEL:
2933     return generateTernaryBitwiseFunctionINTELInst(Call.get(), MIRBuilder, GR);
2934   case SPIRV::Block2DLoadStore:
2935     return generate2DBlockIOINTELInst(Call.get(), MIRBuilder, GR);
2936   }
2937   return false;
2938 }
2939 
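// Parse a single argument type string from a demangled builtin call into an
// LLVM type. Illustrative examples (assuming parseBasicTypeName recognizes
// the usual OpenCL scalar spellings):
//   "float4" or "float vector[4]" -> <4 x float>
//   "ocl_image2d_ro"              -> the SPIR-V image target extension type
//                                    mapped from "opencl.image2d_ro_t"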
2940 Type *parseBuiltinCallArgumentType(StringRef TypeStr, LLVMContext &Ctx) {
2941   // Parse strings representing OpenCL builtin types.
2942   if (hasBuiltinTypePrefix(TypeStr)) {
2943     // OpenCL builtin types in demangled call strings have the following format:
2944     // e.g. ocl_image2d_ro
2945     [[maybe_unused]] bool IsOCLBuiltinType = TypeStr.consume_front("ocl_");
2946     assert(IsOCLBuiltinType && "Invalid OpenCL builtin prefix");
2947 
2948     // Check if this is a pointer to a builtin type rather than just a pointer
2949     // representing a builtin type. If it is a pointer to a builtin type, the
2950     // method calling parseBuiltinCallArgumentBaseType(...) will need additional
2951     // handling, as this function only retrieves the base types.
2953     if (TypeStr.ends_with("*"))
2954       TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" *"));
2955 
2956     return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
2957                                                Ctx);
2958   }
2959 
2960   // Parse type name in either "typeN" or "type vector[N]" format, where
2961   // N is the number of elements of the vector.
2962   Type *BaseType;
2963   unsigned VecElts = 0;
2964 
2965   BaseType = parseBasicTypeName(TypeStr, Ctx);
2966   if (!BaseType)
2967     // Unable to recognize SPIRV type name.
2968     return nullptr;
2969 
2970   // Handle "typeN*" or "type vector[N]*".
2971   TypeStr.consume_back("*");
2972 
2973   if (TypeStr.consume_front(" vector["))
2974     TypeStr = TypeStr.substr(0, TypeStr.find(']'));
2975 
2976   TypeStr.getAsInteger(10, VecElts);
2977   if (VecElts > 0)
2978     BaseType = VectorType::get(
2979         BaseType->isVoidTy() ? Type::getInt8Ty(Ctx) : BaseType, VecElts, false);
2980 
2981   return BaseType;
2982 }
2983 
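// Split out the argument type strings of a demangled call, i.e. the text
// between '(' and ')'. E.g. for "foo(float, float*)" this fills
// BuiltinArgsTypeStrs with {"float", " float*"} (callers trim the strings).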
2984 bool parseBuiltinTypeStr(SmallVector<StringRef, 10> &BuiltinArgsTypeStrs,
2985                          const StringRef DemangledCall, LLVMContext &Ctx) {
2986   auto Pos1 = DemangledCall.find('(');
2987   if (Pos1 == StringRef::npos)
2988     return false;
2989   auto Pos2 = DemangledCall.find(')');
2990   if (Pos2 == StringRef::npos || Pos1 > Pos2)
2991     return false;
2992   DemangledCall.slice(Pos1 + 1, Pos2)
2993       .split(BuiltinArgsTypeStrs, ',', -1, false);
2994   return true;
2995 }
2996 
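// Return the LLVM type of the ArgIdx-th argument of the demangled call, or
// nullptr if the index is out of range or the type string is not recognized.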
2997 Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
2998                                        unsigned ArgIdx, LLVMContext &Ctx) {
2999   SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
3000   parseBuiltinTypeStr(BuiltinArgsTypeStrs, DemangledCall, Ctx);
3001   if (ArgIdx >= BuiltinArgsTypeStrs.size())
3002     return nullptr;
3003   StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
3004   return parseBuiltinCallArgumentType(TypeStr, Ctx);
3005 }
3006 
3007 struct BuiltinType {
3008   StringRef Name;
3009   uint32_t Opcode;
3010 };
3011 
3012 #define GET_BuiltinTypes_DECL
3013 #define GET_BuiltinTypes_IMPL
3014 
3015 struct OpenCLType {
3016   StringRef Name;
3017   StringRef SpirvTypeLiteral;
3018 };
3019 
3020 #define GET_OpenCLTypes_DECL
3021 #define GET_OpenCLTypes_IMPL
3022 
3023 #include "SPIRVGenTables.inc"
3024 } // namespace SPIRV
3025 
3026 //===----------------------------------------------------------------------===//
3027 // Misc functions for parsing builtin types.
3028 //===----------------------------------------------------------------------===//
3029 
3030 static Type *parseTypeString(const StringRef Name, LLVMContext &Context) {
3031   if (Name.starts_with("void"))
3032     return Type::getVoidTy(Context);
3033   else if (Name.starts_with("int") || Name.starts_with("uint"))
3034     return Type::getInt32Ty(Context);
3035   else if (Name.starts_with("float"))
3036     return Type::getFloatTy(Context);
3037   else if (Name.starts_with("half"))
3038     return Type::getHalfTy(Context);
3039   report_fatal_error("Unable to recognize type!");
3040 }
3041 
3042 //===----------------------------------------------------------------------===//
3043 // Implementation functions for builtin types.
3044 //===----------------------------------------------------------------------===//
3045 
3046 static SPIRVType *getNonParameterizedType(const TargetExtType *ExtensionType,
3047                                           const SPIRV::BuiltinType *TypeRecord,
3048                                           MachineIRBuilder &MIRBuilder,
3049                                           SPIRVGlobalRegistry *GR) {
3050   unsigned Opcode = TypeRecord->Opcode;
3051   // Create or get an existing type from GlobalRegistry.
3052   return GR->getOrCreateOpTypeByOpcode(ExtensionType, MIRBuilder, Opcode);
3053 }
3054 
3055 static SPIRVType *getSamplerType(MachineIRBuilder &MIRBuilder,
3056                                  SPIRVGlobalRegistry *GR) {
3057   // Create or get an existing type from GlobalRegistry.
3058   return GR->getOrCreateOpTypeSampler(MIRBuilder);
3059 }
3060 
3061 static SPIRVType *getPipeType(const TargetExtType *ExtensionType,
3062                               MachineIRBuilder &MIRBuilder,
3063                               SPIRVGlobalRegistry *GR) {
3064   assert(ExtensionType->getNumIntParameters() == 1 &&
3065          "Invalid number of parameters for SPIR-V pipe builtin!");
3066   // Create or get an existing type from GlobalRegistry.
3067   return GR->getOrCreateOpTypePipe(MIRBuilder,
3068                                    SPIRV::AccessQualifier::AccessQualifier(
3069                                        ExtensionType->getIntParameter(0)));
3070 }
3071 
3072 static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
3073                                   MachineIRBuilder &MIRBuilder,
3074                                   SPIRVGlobalRegistry *GR) {
3075   assert(ExtensionType->getNumIntParameters() == 4 &&
3076          "Invalid number of parameters for SPIR-V coop matrices builtin!");
3077   assert(ExtensionType->getNumTypeParameters() == 1 &&
3078          "SPIR-V coop matrices builtin type must have a type parameter!");
3079   const SPIRVType *ElemType =
3080       GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder,
3081                                SPIRV::AccessQualifier::ReadWrite, true);
3082   // Create or get an existing type from GlobalRegistry.
3083   return GR->getOrCreateOpTypeCoopMatr(
3084       MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0),
3085       ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
3086       ExtensionType->getIntParameter(3), true);
3087 }
3088 
3089 static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
3090                                       MachineIRBuilder &MIRBuilder,
3091                                       SPIRVGlobalRegistry *GR) {
3092   SPIRVType *OpaqueImageType = GR->getImageType(
3093       OpaqueType, SPIRV::AccessQualifier::ReadOnly, MIRBuilder);
3094   // Create or get an existing type from GlobalRegistry.
3095   return GR->getOrCreateOpTypeSampledImage(OpaqueImageType, MIRBuilder);
3096 }
3097 
3098 static SPIRVType *getInlineSpirvType(const TargetExtType *ExtensionType,
3099                                      MachineIRBuilder &MIRBuilder,
3100                                      SPIRVGlobalRegistry *GR) {
3101   assert(ExtensionType->getNumIntParameters() == 3 &&
3102          "Inline SPIR-V type builtin takes an opcode, size, and alignment "
3103          "parameter");
3104   auto Opcode = ExtensionType->getIntParameter(0);
3105 
3106   SmallVector<MCOperand> Operands;
3107   for (Type *Param : ExtensionType->type_params()) {
3108     if (const TargetExtType *ParamEType = dyn_cast<TargetExtType>(Param)) {
3109       if (ParamEType->getName() == "spirv.IntegralConstant") {
3110         assert(ParamEType->getNumTypeParameters() == 1 &&
3111                "Inline SPIR-V integral constant builtin must have a type "
3112                "parameter");
3113         assert(ParamEType->getNumIntParameters() == 1 &&
3114                "Inline SPIR-V integral constant builtin must have a "
3115                "value parameter");
3116 
3117         auto OperandValue = ParamEType->getIntParameter(0);
3118         auto *OperandType = ParamEType->getTypeParameter(0);
3119 
3120         const SPIRVType *OperandSPIRVType = GR->getOrCreateSPIRVType(
3121             OperandType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
3122 
3123         Operands.push_back(MCOperand::createReg(GR->buildConstantInt(
3124             OperandValue, MIRBuilder, OperandSPIRVType, true)));
3125         continue;
3126       } else if (ParamEType->getName() == "spirv.Literal") {
3127         assert(ParamEType->getNumTypeParameters() == 0 &&
3128                "Inline SPIR-V literal builtin does not take type "
3129                "parameters");
3130         assert(ParamEType->getNumIntParameters() == 1 &&
3131                "Inline SPIR-V literal builtin must have an integer "
3132                "parameter");
3133 
3134         auto OperandValue = ParamEType->getIntParameter(0);
3135 
3136         Operands.push_back(MCOperand::createImm(OperandValue));
3137         continue;
3138       }
3139     }
3140     const SPIRVType *TypeOperand = GR->getOrCreateSPIRVType(
3141         Param, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
3142     Operands.push_back(MCOperand::createReg(GR->getSPIRVTypeID(TypeOperand)));
3143   }
3144 
3145   return GR->getOrCreateUnknownType(ExtensionType, MIRBuilder, Opcode,
3146                                     Operands);
3147 }
3148 
3149 static SPIRVType *getVulkanBufferType(const TargetExtType *ExtensionType,
3150                                       MachineIRBuilder &MIRBuilder,
3151                                       SPIRVGlobalRegistry *GR) {
3152   assert(ExtensionType->getNumTypeParameters() == 1 &&
3153          "Vulkan buffers have exactly one type for the type of the buffer.");
3154   assert(ExtensionType->getNumIntParameters() == 2 &&
3155          "Vulkan buffers have 2 integer parameters: the storage class and "
3156          "whether the buffer is writable.");
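  // A hypothetical example: target("spirv.VulkanBuffer", %struct.S, 12, 1)
  // would describe a writable buffer of %struct.S in the StorageBuffer
  // storage class (12).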
3157 
3158   auto *T = ExtensionType->getTypeParameter(0);
3159   auto SC = static_cast<SPIRV::StorageClass::StorageClass>(
3160       ExtensionType->getIntParameter(0));
3161   bool IsWritable = ExtensionType->getIntParameter(1);
3162   return GR->getOrCreateVulkanBufferType(MIRBuilder, T, SC, IsWritable);
3163 }
3164 
3165 static SPIRVType *getLayoutType(const TargetExtType *ExtensionType,
3166                                 MachineIRBuilder &MIRBuilder,
3167                                 SPIRVGlobalRegistry *GR) {
3168   return GR->getOrCreateLayoutType(MIRBuilder, ExtensionType);
3169 }
3170 
3171 namespace SPIRV {
3172 TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
3173                                                    LLVMContext &Context) {
3174   StringRef NameWithParameters = TypeName;
3175 
3176   // Pointers-to-opaque-structs representing OpenCL types are first translated
3177   // to equivalent SPIR-V types. OpenCL builtin type names should have the
3178   // following format: e.g. %opencl.event_t
3179   if (NameWithParameters.starts_with("opencl.")) {
3180     const SPIRV::OpenCLType *OCLTypeRecord =
3181         SPIRV::lookupOpenCLType(NameWithParameters);
3182     if (!OCLTypeRecord)
3183       report_fatal_error("Missing TableGen record for OpenCL type: " +
3184                          NameWithParameters);
3185     NameWithParameters = OCLTypeRecord->SpirvTypeLiteral;
3186     // Continue with the SPIR-V builtin type...
3187   }
3188 
3189   // Names of the opaque structs representing SPIR-V builtins without
3190   // parameters should have the following format: e.g. %spirv.Event
3191   assert(NameWithParameters.starts_with("spirv.") &&
3192          "Unknown builtin opaque type!");
3193 
3194   // Parameterized SPIR-V builtin names follow this format:
3195   // e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
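  // For instance, "spirv.Pipe._0" is parsed into the base name "spirv.Pipe"
  // with the single integer parameter 0, while a leading non-digit parameter
  // (as in "spirv.Image._void_...") is treated as a type parameter.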
3196   if (!NameWithParameters.contains('_'))
3197     return TargetExtType::get(Context, NameWithParameters);
3198 
3199   SmallVector<StringRef> Parameters;
3200   unsigned BaseNameLength = NameWithParameters.find('_') - 1;
3201   SplitString(NameWithParameters.substr(BaseNameLength + 1), Parameters, "_");
3202 
3203   SmallVector<Type *, 1> TypeParameters;
3204   bool HasTypeParameter = !isDigit(Parameters[0][0]);
3205   if (HasTypeParameter)
3206     TypeParameters.push_back(parseTypeString(Parameters[0], Context));
3207   SmallVector<unsigned> IntParameters;
3208   for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
3209     unsigned IntParameter = 0;
3210     bool ValidLiteral = !Parameters[i].getAsInteger(10, IntParameter);
3211     (void)ValidLiteral;
3212     assert(ValidLiteral &&
3213            "Invalid format of SPIR-V builtin parameter literal!");
3214     IntParameters.push_back(IntParameter);
3215   }
3216   return TargetExtType::get(Context,
3217                             NameWithParameters.substr(0, BaseNameLength),
3218                             TypeParameters, IntParameters);
3219 }
3220 
3221 SPIRVType *lowerBuiltinType(const Type *OpaqueType,
3222                             SPIRV::AccessQualifier::AccessQualifier AccessQual,
3223                             MachineIRBuilder &MIRBuilder,
3224                             SPIRVGlobalRegistry *GR) {
3225   // In LLVM IR, SPIR-V and OpenCL builtin types are represented as either
3226   // target(...) target extension types or pointers-to-opaque-structs. The
3227   // approach relying on structs is deprecated and works only in the non-opaque
3228   // pointer mode (-opaque-pointers=0).
3229   // In order to maintain compatibility with LLVM IR generated by older versions
3230   // of Clang and LLVM/SPIR-V Translator, the pointers-to-opaque-structs are
3231   // "translated" to target extension types. This translation is temporary and
3232   // will be removed in a future release of LLVM.
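  // For example, a pointer to the opaque struct %opencl.event_t maps via the
  // OpenCL type records to target("spirv.Event"), which is then lowered
  // through its builtin type record.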
3233   const TargetExtType *BuiltinType = dyn_cast<TargetExtType>(OpaqueType);
3234   if (!BuiltinType)
3235     BuiltinType = parseBuiltinTypeNameToTargetExtType(
3236         OpaqueType->getStructName().str(), MIRBuilder.getContext());
3237 
3238   unsigned NumStartingVRegs = MIRBuilder.getMRI()->getNumVirtRegs();
3239 
3240   const StringRef Name = BuiltinType->getName();
3241   LLVM_DEBUG(dbgs() << "Lowering builtin type: " << Name << "\n");
3242 
3243   SPIRVType *TargetType;
3244   if (Name == "spirv.Type") {
3245     TargetType = getInlineSpirvType(BuiltinType, MIRBuilder, GR);
3246   } else if (Name == "spirv.VulkanBuffer") {
3247     TargetType = getVulkanBufferType(BuiltinType, MIRBuilder, GR);
3248   } else if (Name == "spirv.Layout") {
3249     TargetType = getLayoutType(BuiltinType, MIRBuilder, GR);
3250   } else {
3251     // Lookup the demangled builtin type in the TableGen records.
3252     const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
3253     if (!TypeRecord)
3254       report_fatal_error("Missing TableGen record for builtin type: " + Name);
3255 
3256     // "Lower" the BuiltinType into TargetType. The following get<...>Type
3257     // methods use the implementation details from TableGen records or
3258     // TargetExtType parameters to either create a new OpType<...> machine
3259     // instruction or get an existing equivalent SPIRVType from
3260     // GlobalRegistry.
3261 
3262     switch (TypeRecord->Opcode) {
3263     case SPIRV::OpTypeImage:
3264       TargetType = GR->getImageType(BuiltinType, AccessQual, MIRBuilder);
3265       break;
3266     case SPIRV::OpTypePipe:
3267       TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
3268       break;
3269     case SPIRV::OpTypeDeviceEvent:
3270       TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
3271       break;
3272     case SPIRV::OpTypeSampler:
3273       TargetType = getSamplerType(MIRBuilder, GR);
3274       break;
3275     case SPIRV::OpTypeSampledImage:
3276       TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
3277       break;
3278     case SPIRV::OpTypeCooperativeMatrixKHR:
3279       TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
3280       break;
3281     default:
3282       TargetType =
3283           getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
3284       break;
3285     }
3286   }
3287 
3288   // Emit OpName instruction if a new OpType<...> instruction was added
3289   // (equivalent type was not found in GlobalRegistry).
3290   if (NumStartingVRegs < MIRBuilder.getMRI()->getNumVirtRegs())
3291     buildOpName(GR->getSPIRVTypeID(TargetType), Name, MIRBuilder);
3292 
3293   return TargetType;
3294 }
3295 } // namespace SPIRV
3296 } // namespace llvm
3297