1 //===- SPIRVBuiltins.cpp - SPIR-V Built-in Functions ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering builtin function calls and types using their
10 // demangled names and TableGen records.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "SPIRVBuiltins.h"
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Analysis/ValueTracking.h"
20 #include "llvm/IR/IntrinsicsSPIRV.h"
21 #include <string>
22 #include <tuple>
23
24 #define DEBUG_TYPE "spirv-builtins"
25
26 namespace llvm {
27 namespace SPIRV {
28 #define GET_BuiltinGroup_DECL
29 #include "SPIRVGenTables.inc"
30
31 struct DemangledBuiltin {
32 StringRef Name;
33 InstructionSet::InstructionSet Set;
34 BuiltinGroup Group;
35 uint8_t MinNumArgs;
36 uint8_t MaxNumArgs;
37 };
38
39 #define GET_DemangledBuiltins_DECL
40 #define GET_DemangledBuiltins_IMPL
41
42 struct IncomingCall {
43 const std::string BuiltinName;
44 const DemangledBuiltin *Builtin;
45
46 const Register ReturnRegister;
47 const SPIRVType *ReturnType;
48 const SmallVectorImpl<Register> &Arguments;
49
50   IncomingCall(const std::string BuiltinName, const DemangledBuiltin *Builtin,
51 const Register ReturnRegister, const SPIRVType *ReturnType,
52 const SmallVectorImpl<Register> &Arguments)
53 : BuiltinName(BuiltinName), Builtin(Builtin),
54 ReturnRegister(ReturnRegister), ReturnType(ReturnType),
55 Arguments(Arguments) {}
56
57   bool isSpirvOp() const { return BuiltinName.rfind("__spirv_", 0) == 0; }
58 };
59
60 struct NativeBuiltin {
61 StringRef Name;
62 InstructionSet::InstructionSet Set;
63 uint32_t Opcode;
64 };
65
66 #define GET_NativeBuiltins_DECL
67 #define GET_NativeBuiltins_IMPL
68
69 struct GroupBuiltin {
70 StringRef Name;
71 uint32_t Opcode;
72 uint32_t GroupOperation;
73 bool IsElect;
74 bool IsAllOrAny;
75 bool IsAllEqual;
76 bool IsBallot;
77 bool IsInverseBallot;
78 bool IsBallotBitExtract;
79 bool IsBallotFindBit;
80 bool IsLogical;
81 bool NoGroupOperation;
82 bool HasBoolArg;
83 };
84
85 #define GET_GroupBuiltins_DECL
86 #define GET_GroupBuiltins_IMPL
87
88 struct IntelSubgroupsBuiltin {
89 StringRef Name;
90 uint32_t Opcode;
91 bool IsBlock;
92 bool IsWrite;
93 };
94
95 #define GET_IntelSubgroupsBuiltins_DECL
96 #define GET_IntelSubgroupsBuiltins_IMPL
97
98 struct AtomicFloatingBuiltin {
99 StringRef Name;
100 uint32_t Opcode;
101 };
102
103 #define GET_AtomicFloatingBuiltins_DECL
104 #define GET_AtomicFloatingBuiltins_IMPL
105 struct GroupUniformBuiltin {
106 StringRef Name;
107 uint32_t Opcode;
108 bool IsLogical;
109 };
110
111 #define GET_GroupUniformBuiltins_DECL
112 #define GET_GroupUniformBuiltins_IMPL
113
114 struct GetBuiltin {
115 StringRef Name;
116 InstructionSet::InstructionSet Set;
117 BuiltIn::BuiltIn Value;
118 };
119
120 using namespace BuiltIn;
121 #define GET_GetBuiltins_DECL
122 #define GET_GetBuiltins_IMPL
123
124 struct ImageQueryBuiltin {
125 StringRef Name;
126 InstructionSet::InstructionSet Set;
127 uint32_t Component;
128 };
129
130 #define GET_ImageQueryBuiltins_DECL
131 #define GET_ImageQueryBuiltins_IMPL
132
133 struct ConvertBuiltin {
134 StringRef Name;
135 InstructionSet::InstructionSet Set;
136 bool IsDestinationSigned;
137 bool IsSaturated;
138 bool IsRounded;
139 bool IsBfloat16;
140 FPRoundingMode::FPRoundingMode RoundingMode;
141 };
142
143 struct VectorLoadStoreBuiltin {
144 StringRef Name;
145 InstructionSet::InstructionSet Set;
146 uint32_t Number;
147 uint32_t ElementCount;
148 bool IsRounded;
149 FPRoundingMode::FPRoundingMode RoundingMode;
150 };
151
152 using namespace FPRoundingMode;
153 #define GET_ConvertBuiltins_DECL
154 #define GET_ConvertBuiltins_IMPL
155
156 using namespace InstructionSet;
157 #define GET_VectorLoadStoreBuiltins_DECL
158 #define GET_VectorLoadStoreBuiltins_IMPL
159
160 #define GET_CLMemoryScope_DECL
161 #define GET_CLSamplerAddressingMode_DECL
162 #define GET_CLMemoryFenceFlags_DECL
163 #define GET_ExtendedBuiltins_DECL
164 #include "SPIRVGenTables.inc"
165 } // namespace SPIRV
166
167 //===----------------------------------------------------------------------===//
168 // Misc functions for looking up builtins and verifying requirements using
169 // TableGen records
170 //===----------------------------------------------------------------------===//
171
172 namespace SPIRV {
173 /// Parses the name part of the demangled builtin call.
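///
/// For illustration (a rough sketch, not an exhaustive list): a demangled call
/// such as "__spirv_ocl_fmin(float, float)" yields "fmin", and
/// "__spirv_ImageSampleExplicitLod_Rfloat4(...)" yields
/// "__spirv_ImageSampleExplicitLod" once the "_R<type>" suffix is dropped.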
174 std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
175 const static std::string PassPrefix = "(anonymous namespace)::";
176 std::string BuiltinName;
177 // Itanium Demangler result may have "(anonymous namespace)::" prefix
178 if (DemangledCall.starts_with(PassPrefix.c_str()))
179 BuiltinName = DemangledCall.substr(PassPrefix.length());
180 else
181 BuiltinName = DemangledCall;
182 // Extract the builtin function name and types of arguments from the call
183 // skeleton.
184 BuiltinName = BuiltinName.substr(0, BuiltinName.find('('));
185
186 // Account for possible "__spirv_ocl_" prefix in SPIR-V friendly LLVM IR
187 if (BuiltinName.rfind("__spirv_ocl_", 0) == 0)
188 BuiltinName = BuiltinName.substr(12);
189
190 // Check if the extracted name contains type information between angle
191 // brackets. If so, the builtin is an instantiated template, and the
192 // information after the angle brackets and the return type must be removed.
193 if (BuiltinName.find('<') && BuiltinName.back() == '>') {
194 BuiltinName = BuiltinName.substr(0, BuiltinName.find('<'));
195 BuiltinName = BuiltinName.substr(BuiltinName.find_last_of(' ') + 1);
196 }
197
198 // Check if the extracted name begins with "__spirv_ImageSampleExplicitLod"
199 // and contains return type information at the end "_R<type>"; if so, extract
200 // the plain builtin name without the type information.
201 if (StringRef(BuiltinName).contains("__spirv_ImageSampleExplicitLod") &&
202 StringRef(BuiltinName).contains("_R")) {
203 BuiltinName = BuiltinName.substr(0, BuiltinName.find("_R"));
204 }
205
206 return BuiltinName;
207 }
208 } // namespace SPIRV
209
210 /// Looks up the demangled builtin call in the SPIRVBuiltins.td records using
211 /// the provided \p DemangledCall and specified \p Set.
212 ///
213 /// The lookup follows this algorithm, returning the first successful
214 /// match:
215 /// 1. Search with the plain demangled name (expecting a 1:1 match).
216 /// 2. Search with a prefix before or a suffix after the demangled name
217 ///    signifying the type of the first argument.
218 ///
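/// For illustration (a sketch based on the argument-type prefixes and suffixes
/// handled below): "abs(unsigned int)" in the OpenCL_std set is first looked
/// up as "abs"; if that fails, the 'u' of the first argument selects the "u_"
/// prefix and "u_abs" is tried, and for e.g. "group_reduce_max(unsigned int)"
/// the suffixed form "group_reduce_maxu" is tried instead.
///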
219 /// \returns Wrapper around the demangled call and found builtin definition.
220 static std::unique_ptr<const SPIRV::IncomingCall>
221 lookupBuiltin(StringRef DemangledCall,
222 SPIRV::InstructionSet::InstructionSet Set,
223 Register ReturnRegister, const SPIRVType *ReturnType,
224 const SmallVectorImpl<Register> &Arguments) {
225 std::string BuiltinName = SPIRV::lookupBuiltinNameHelper(DemangledCall);
226
227 SmallVector<StringRef, 10> BuiltinArgumentTypes;
228 StringRef BuiltinArgs =
229 DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
230 BuiltinArgs.split(BuiltinArgumentTypes, ',', -1, false);
231
232 // Look up the builtin in the defined set. Start with the plain demangled
233 // name, expecting a 1:1 match in the defined builtin set.
234 const SPIRV::DemangledBuiltin *Builtin;
235 if ((Builtin = SPIRV::lookupBuiltin(BuiltinName, Set)))
236 return std::make_unique<SPIRV::IncomingCall>(
237 BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
238
239 // If the initial look up was unsuccessful and the demangled call takes at
240 // least 1 argument, add a prefix or suffix signifying the type of the first
241 // argument and repeat the search.
242 if (BuiltinArgumentTypes.size() >= 1) {
243 char FirstArgumentType = BuiltinArgumentTypes[0][0];
244 // Prefix to be added to the builtin's name for lookup.
245 // For example, OpenCL "abs" taking an unsigned value has a prefix "u_".
246 std::string Prefix;
247
248 switch (FirstArgumentType) {
249 // Unsigned:
250 case 'u':
251 if (Set == SPIRV::InstructionSet::OpenCL_std)
252 Prefix = "u_";
253 else if (Set == SPIRV::InstructionSet::GLSL_std_450)
254 Prefix = "u";
255 break;
256 // Signed:
257 case 'c':
258 case 's':
259 case 'i':
260 case 'l':
261 if (Set == SPIRV::InstructionSet::OpenCL_std)
262 Prefix = "s_";
263 else if (Set == SPIRV::InstructionSet::GLSL_std_450)
264 Prefix = "s";
265 break;
266 // Floating-point:
267 case 'f':
268 case 'd':
269 case 'h':
270 if (Set == SPIRV::InstructionSet::OpenCL_std ||
271 Set == SPIRV::InstructionSet::GLSL_std_450)
272 Prefix = "f";
273 break;
274 }
275
276 // If argument-type name prefix was added, look up the builtin again.
277 if (!Prefix.empty() &&
278 (Builtin = SPIRV::lookupBuiltin(Prefix + BuiltinName, Set)))
279 return std::make_unique<SPIRV::IncomingCall>(
280 BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
281
282 // If lookup with a prefix failed, find a suffix to be added to the
283 // builtin's name for lookup. For example, OpenCL "group_reduce_max" taking
284 // an unsigned value has a suffix "u".
285 std::string Suffix;
286
287 switch (FirstArgumentType) {
288 // Unsigned:
289 case 'u':
290 Suffix = "u";
291 break;
292 // Signed:
293 case 'c':
294 case 's':
295 case 'i':
296 case 'l':
297 Suffix = "s";
298 break;
299 // Floating-point:
300 case 'f':
301 case 'd':
302 case 'h':
303 Suffix = "f";
304 break;
305 }
306
307 // If argument-type name suffix was added, look up the builtin again.
308 if (!Suffix.empty() &&
309 (Builtin = SPIRV::lookupBuiltin(BuiltinName + Suffix, Set)))
310 return std::make_unique<SPIRV::IncomingCall>(
311 BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
312 }
313
314 // No builtin with such name was found in the set.
315 return nullptr;
316 }
317
318 static MachineInstr *getBlockStructInstr(Register ParamReg,
319 MachineRegisterInfo *MRI) {
320 // We expect the following sequence of instructions:
321 // %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
322 // or = G_GLOBAL_VALUE @block_literal_global
323 // %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
324 // %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
325 MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg);
326 assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
327 MI->getOperand(1).isReg());
328 Register BitcastReg = MI->getOperand(1).getReg();
329 MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg);
330 assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
331 BitcastMI->getOperand(2).isReg());
332 Register ValueReg = BitcastMI->getOperand(2).getReg();
333 MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
334 return ValueMI;
335 }
336
337 // Return an integer constant corresponding to the given register and
338 // defined in spv_track_constant.
339 // TODO: maybe unify with prelegalizer pass.
340 static unsigned getConstFromIntrinsic(Register Reg, MachineRegisterInfo *MRI) {
341 MachineInstr *DefMI = MRI->getUniqueVRegDef(Reg);
342 assert(isSpvIntrinsic(*DefMI, Intrinsic::spv_track_constant) &&
343 DefMI->getOperand(2).isReg());
344 MachineInstr *DefMI2 = MRI->getUniqueVRegDef(DefMI->getOperand(2).getReg());
345 assert(DefMI2->getOpcode() == TargetOpcode::G_CONSTANT &&
346 DefMI2->getOperand(1).isCImm());
347 return DefMI2->getOperand(1).getCImm()->getValue().getZExtValue();
348 }
349
350 // Return type of the instruction result from spv_assign_type intrinsic.
351 // TODO: maybe unify with prelegalizer pass.
352 static const Type *getMachineInstrType(MachineInstr *MI) {
353 MachineInstr *NextMI = MI->getNextNode();
354 if (!NextMI)
355 return nullptr;
356 if (isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_name))
357 if ((NextMI = NextMI->getNextNode()) == nullptr)
358 return nullptr;
359 Register ValueReg = MI->getOperand(0).getReg();
360 if ((!isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_type) &&
361 !isSpvIntrinsic(*NextMI, Intrinsic::spv_assign_ptr_type)) ||
362 NextMI->getOperand(1).getReg() != ValueReg)
363 return nullptr;
364 Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0);
365 assert(Ty && "Type is expected");
366 return Ty;
367 }
368
369 static const Type *getBlockStructType(Register ParamReg,
370 MachineRegisterInfo *MRI) {
371 // In principle, this information should be passed to us from Clang via
372 // an elementtype attribute. However, said attribute requires that
373 // the function call be an intrinsic, which it is not. Instead, we rely on being
374 // able to trace this to the declaration of a variable: OpenCL C specification
375 // section 6.12.5 should guarantee that we can do this.
376 MachineInstr *MI = getBlockStructInstr(ParamReg, MRI);
377 if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE)
378 return MI->getOperand(1).getGlobal()->getType();
379 assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) &&
380 "Blocks in OpenCL C must be traceable to allocation site");
381 return getMachineInstrType(MI);
382 }
383
384 //===----------------------------------------------------------------------===//
385 // Helper functions for building misc instructions
386 //===----------------------------------------------------------------------===//
387
388 /// Helper function building either a resulting scalar or vector bool register
389 /// depending on the expected \p ResultType.
390 ///
391 /// \returns Tuple of the resulting register and its type.
392 static std::tuple<Register, SPIRVType *>
393 buildBoolRegister(MachineIRBuilder &MIRBuilder, const SPIRVType *ResultType,
394 SPIRVGlobalRegistry *GR) {
395 LLT Type;
396 SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
397
398 if (ResultType->getOpcode() == SPIRV::OpTypeVector) {
399 unsigned VectorElements = ResultType->getOperand(2).getImm();
400 BoolType =
401 GR->getOrCreateSPIRVVectorType(BoolType, VectorElements, MIRBuilder);
402 const FixedVectorType *LLVMVectorType =
403 cast<FixedVectorType>(GR->getTypeForSPIRVType(BoolType));
404 Type = LLT::vector(LLVMVectorType->getElementCount(), 1);
405 } else {
406 Type = LLT::scalar(1);
407 }
408
409 Register ResultRegister =
410 MIRBuilder.getMRI()->createGenericVirtualRegister(Type);
411 MIRBuilder.getMRI()->setRegClass(ResultRegister, &SPIRV::IDRegClass);
412 GR->assignSPIRVTypeToVReg(BoolType, ResultRegister, MIRBuilder.getMF());
413 return std::make_tuple(ResultRegister, BoolType);
414 }
415
416 /// Helper function for building either a vector or scalar select instruction
417 /// depending on the expected \p ResultType.
418 static bool buildSelectInst(MachineIRBuilder &MIRBuilder,
419 Register ReturnRegister, Register SourceRegister,
420 const SPIRVType *ReturnType,
421 SPIRVGlobalRegistry *GR) {
422 Register TrueConst, FalseConst;
423
424 if (ReturnType->getOpcode() == SPIRV::OpTypeVector) {
425 unsigned Bits = GR->getScalarOrVectorBitWidth(ReturnType);
426 uint64_t AllOnes = APInt::getAllOnes(Bits).getZExtValue();
427 TrueConst = GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType);
428 FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType);
429 } else {
430 TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType);
431 FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType);
432 }
433 return MIRBuilder.buildSelect(ReturnRegister, SourceRegister, TrueConst,
434 FalseConst);
435 }
436
437 /// Helper function for building a load instruction loading into the
438 /// \p DestinationReg.
439 static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
440 MachineIRBuilder &MIRBuilder,
441 SPIRVGlobalRegistry *GR, LLT LowLevelType,
442 Register DestinationReg = Register(0)) {
443 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
444 if (!DestinationReg.isValid()) {
445 DestinationReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
446 MRI->setType(DestinationReg, LLT::scalar(32));
447 GR->assignSPIRVTypeToVReg(BaseType, DestinationReg, MIRBuilder.getMF());
448 }
449 // TODO: consider using correct address space and alignment (p0 is canonical
450 // type for selection though).
451 MachinePointerInfo PtrInfo = MachinePointerInfo();
452 MIRBuilder.buildLoad(DestinationReg, PtrRegister, PtrInfo, Align());
453 return DestinationReg;
454 }
455
456 /// Helper function for building a load instruction for loading a builtin global
457 /// variable of \p BuiltinValue value.
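///
/// For illustration, loading e.g. the GlobalInvocationId builtin produces
/// roughly (a sketch; the decorations are attached by buildGlobalVariable):
///   %ptrTy = OpTypePointer Input %varTy
///   %var   = OpVariable %ptrTy Input
///   %val   = OpLoad %varTy %var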
458 static Register buildBuiltinVariableLoad(
459 MachineIRBuilder &MIRBuilder, SPIRVType *VariableType,
460 SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType,
461 Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) {
462 Register NewRegister =
463 MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
464 MIRBuilder.getMRI()->setType(NewRegister,
465 LLT::pointer(0, GR->getPointerSize()));
466 SPIRVType *PtrType = GR->getOrCreateSPIRVPointerType(
467 VariableType, MIRBuilder, SPIRV::StorageClass::Input);
468 GR->assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
469
470 // Set up the global OpVariable with the necessary builtin decorations.
471 Register Variable = GR->buildGlobalVariable(
472 NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr,
473 SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst,
474 /* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder,
475 false);
476
477 // Load the value from the global variable.
478 Register LoadedRegister =
479 buildLoadInst(VariableType, Variable, MIRBuilder, GR, LLType, Reg);
480 MIRBuilder.getMRI()->setType(LoadedRegister, LLType);
481 return LoadedRegister;
482 }
483
484 /// Helper external function for inserting an ASSIGN_TYPE instruction between \p Reg
485 /// and its definition, set the new register as a destination of the definition,
486 /// assign SPIRVType to both registers. If SpirvTy is provided, use it as
487 /// SPIRVType in ASSIGN_TYPE, otherwise create it from \p Ty. Defined in
488 /// SPIRVPreLegalizer.cpp.
489 extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
490 SPIRVGlobalRegistry *GR,
491 MachineIRBuilder &MIB,
492 MachineRegisterInfo &MRI);
493
494 // TODO: Move to TableGen.
495 static SPIRV::MemorySemantics::MemorySemantics
496 getSPIRVMemSemantics(std::memory_order MemOrder) {
497 switch (MemOrder) {
498 case std::memory_order::memory_order_relaxed:
499 return SPIRV::MemorySemantics::None;
500 case std::memory_order::memory_order_acquire:
501 return SPIRV::MemorySemantics::Acquire;
502 case std::memory_order::memory_order_release:
503 return SPIRV::MemorySemantics::Release;
504 case std::memory_order::memory_order_acq_rel:
505 return SPIRV::MemorySemantics::AcquireRelease;
506 case std::memory_order::memory_order_seq_cst:
507 return SPIRV::MemorySemantics::SequentiallyConsistent;
508 default:
509 report_fatal_error("Unknown CL memory order");
510 }
511 }
512
513 static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
514 switch (ClScope) {
515 case SPIRV::CLMemoryScope::memory_scope_work_item:
516 return SPIRV::Scope::Invocation;
517 case SPIRV::CLMemoryScope::memory_scope_work_group:
518 return SPIRV::Scope::Workgroup;
519 case SPIRV::CLMemoryScope::memory_scope_device:
520 return SPIRV::Scope::Device;
521 case SPIRV::CLMemoryScope::memory_scope_all_svm_devices:
522 return SPIRV::Scope::CrossDevice;
523 case SPIRV::CLMemoryScope::memory_scope_sub_group:
524 return SPIRV::Scope::Subgroup;
525 }
526 report_fatal_error("Unknown CL memory scope");
527 }
528
529 static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
530 SPIRVGlobalRegistry *GR,
531 unsigned BitWidth = 32) {
532 SPIRVType *IntType = GR->getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
533 return GR->buildConstantInt(Val, MIRBuilder, IntType);
534 }
535
536 static Register buildScopeReg(Register CLScopeRegister,
537 SPIRV::Scope::Scope Scope,
538 MachineIRBuilder &MIRBuilder,
539 SPIRVGlobalRegistry *GR,
540 MachineRegisterInfo *MRI) {
541 if (CLScopeRegister.isValid()) {
542 auto CLScope =
543 static_cast<SPIRV::CLMemoryScope>(getIConstVal(CLScopeRegister, MRI));
544 Scope = getSPIRVScope(CLScope);
545
546 if (CLScope == static_cast<unsigned>(Scope)) {
547 MRI->setRegClass(CLScopeRegister, &SPIRV::IDRegClass);
548 return CLScopeRegister;
549 }
550 }
551 return buildConstantIntReg(Scope, MIRBuilder, GR);
552 }
553
554 static Register buildMemSemanticsReg(Register SemanticsRegister,
555 Register PtrRegister, unsigned &Semantics,
556 MachineIRBuilder &MIRBuilder,
557 SPIRVGlobalRegistry *GR) {
558 if (SemanticsRegister.isValid()) {
559 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
560 std::memory_order Order =
561 static_cast<std::memory_order>(getIConstVal(SemanticsRegister, MRI));
562 Semantics =
563 getSPIRVMemSemantics(Order) |
564 getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
565
566 if (Order == Semantics) {
567 MRI->setRegClass(SemanticsRegister, &SPIRV::IDRegClass);
568 return SemanticsRegister;
569 }
570 }
571 return buildConstantIntReg(Semantics, MIRBuilder, GR);
572 }
573
574 static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
575 const SPIRV::IncomingCall *Call,
576 Register TypeReg,
577 ArrayRef<uint32_t> ImmArgs = {}) {
578 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
579 auto MIB = MIRBuilder.buildInstr(Opcode);
580 if (TypeReg.isValid())
581 MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
582 unsigned Sz = Call->Arguments.size() - ImmArgs.size();
583 for (unsigned i = 0; i < Sz; ++i) {
584 Register ArgReg = Call->Arguments[i];
585 if (!MRI->getRegClassOrNull(ArgReg))
586 MRI->setRegClass(ArgReg, &SPIRV::IDRegClass);
587 MIB.addUse(ArgReg);
588 }
589 for (uint32_t ImmArg : ImmArgs)
590 MIB.addImm(ImmArg);
591 return true;
592 }
593
594 /// Helper function for translating atomic init to OpStore.
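///
/// For illustration: atomic_init(obj, value) lowers to a plain
///   OpStore %obj %value
/// with no scope or memory-semantics operands.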
595 static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
596 MachineIRBuilder &MIRBuilder) {
597 if (Call->isSpirvOp())
598 return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call, Register(0));
599
600 assert(Call->Arguments.size() == 2 &&
601 "Need 2 arguments for atomic init translation");
602 MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
603 MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
604 MIRBuilder.buildInstr(SPIRV::OpStore)
605 .addUse(Call->Arguments[0])
606 .addUse(Call->Arguments[1]);
607 return true;
608 }
609
610 /// Helper function for building an atomic load instruction.
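///
/// For illustration, atomic_load(obj) with no explicit order/scope arguments
/// becomes roughly (a sketch):
///   %res = OpAtomicLoad %ty %obj %Device %SeqCst|<storage-class semantics>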
611 static bool buildAtomicLoadInst(const SPIRV::IncomingCall *Call,
612 MachineIRBuilder &MIRBuilder,
613 SPIRVGlobalRegistry *GR) {
614 Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
615 if (Call->isSpirvOp())
616 return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicLoad, Call, TypeReg);
617
618 Register PtrRegister = Call->Arguments[0];
619 MIRBuilder.getMRI()->setRegClass(PtrRegister, &SPIRV::IDRegClass);
620 // TODO: if true, insert a call to __translate_ocl_memory_scope before
621 // OpAtomicLoad and the function implementation. We can use Translator's
622 // output for transcoding/atomic_explicit_arguments.cl as an example.
623 Register ScopeRegister;
624 if (Call->Arguments.size() > 1) {
625 ScopeRegister = Call->Arguments[1];
626 MIRBuilder.getMRI()->setRegClass(ScopeRegister, &SPIRV::IDRegClass);
627 } else
628 ScopeRegister = buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
629
630 Register MemSemanticsReg;
631 if (Call->Arguments.size() > 2) {
632 // TODO: Insert call to __translate_ocl_memory_order before OpAtomicLoad.
633 MemSemanticsReg = Call->Arguments[2];
634 MIRBuilder.getMRI()->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
635 } else {
636 int Semantics =
637 SPIRV::MemorySemantics::SequentiallyConsistent |
638 getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
639 MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
640 }
641
642 MIRBuilder.buildInstr(SPIRV::OpAtomicLoad)
643 .addDef(Call->ReturnRegister)
644 .addUse(TypeReg)
645 .addUse(PtrRegister)
646 .addUse(ScopeRegister)
647 .addUse(MemSemanticsReg);
648 return true;
649 }
650
651 /// Helper function for building an atomic store instruction.
652 static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
653 MachineIRBuilder &MIRBuilder,
654 SPIRVGlobalRegistry *GR) {
655 if (Call->isSpirvOp())
656 return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
657
658 Register ScopeRegister =
659 buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
660 Register PtrRegister = Call->Arguments[0];
661 MIRBuilder.getMRI()->setRegClass(PtrRegister, &SPIRV::IDRegClass);
662 int Semantics =
663 SPIRV::MemorySemantics::SequentiallyConsistent |
664 getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
665 Register MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
666 MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
667 MIRBuilder.buildInstr(SPIRV::OpAtomicStore)
668 .addUse(PtrRegister)
669 .addUse(ScopeRegister)
670 .addUse(MemSemanticsReg)
671 .addUse(Call->Arguments[1]);
672 return true;
673 }
674
675 /// Helper function for building an atomic compare-exchange instruction.
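///
/// For illustration, the atomic_compare_exchange_strong form (as opposed to
/// the legacy atomic_cmpxchg) is lowered roughly as (a sketch; the scope and
/// semantics operands are computed below):
///   %exp = OpLoad %ty %expected
///   %old = OpAtomicCompareExchange %ty %object %scope %semEq %semNeq %desired %exp
///          OpStore %expected %old
///   %res = OpIEqual %bool %old %exp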
676 static bool buildAtomicCompareExchangeInst(
677 const SPIRV::IncomingCall *Call, const SPIRV::DemangledBuiltin *Builtin,
678 unsigned Opcode, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) {
679 if (Call->isSpirvOp())
680 return buildOpFromWrapper(MIRBuilder, Opcode, Call,
681 GR->getSPIRVTypeID(Call->ReturnType));
682
683 bool IsCmpxchg = Call->Builtin->Name.contains("cmpxchg");
684 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
685
686 Register ObjectPtr = Call->Arguments[0]; // Pointer (volatile A *object.)
687 Register ExpectedArg = Call->Arguments[1]; // Comparator (C* expected).
688 Register Desired = Call->Arguments[2]; // Value (C Desired).
689 MRI->setRegClass(ObjectPtr, &SPIRV::IDRegClass);
690 MRI->setRegClass(ExpectedArg, &SPIRV::IDRegClass);
691 MRI->setRegClass(Desired, &SPIRV::IDRegClass);
692 SPIRVType *SpvDesiredTy = GR->getSPIRVTypeForVReg(Desired);
693 LLT DesiredLLT = MRI->getType(Desired);
694
695 assert(GR->getSPIRVTypeForVReg(ObjectPtr)->getOpcode() ==
696 SPIRV::OpTypePointer);
697 unsigned ExpectedType = GR->getSPIRVTypeForVReg(ExpectedArg)->getOpcode();
698 (void)ExpectedType;
699 assert(IsCmpxchg ? ExpectedType == SPIRV::OpTypeInt
700 : ExpectedType == SPIRV::OpTypePointer);
701 assert(GR->isScalarOfType(Desired, SPIRV::OpTypeInt));
702
703 SPIRVType *SpvObjectPtrTy = GR->getSPIRVTypeForVReg(ObjectPtr);
704 assert(SpvObjectPtrTy->getOperand(2).isReg() && "SPIRV type is expected");
705 auto StorageClass = static_cast<SPIRV::StorageClass::StorageClass>(
706 SpvObjectPtrTy->getOperand(1).getImm());
707 auto MemSemStorage = getMemSemanticsForStorageClass(StorageClass);
708
709 Register MemSemEqualReg;
710 Register MemSemUnequalReg;
711 uint64_t MemSemEqual =
712 IsCmpxchg
713 ? SPIRV::MemorySemantics::None
714 : SPIRV::MemorySemantics::SequentiallyConsistent | MemSemStorage;
715 uint64_t MemSemUnequal =
716 IsCmpxchg
717 ? SPIRV::MemorySemantics::None
718 : SPIRV::MemorySemantics::SequentiallyConsistent | MemSemStorage;
719 if (Call->Arguments.size() >= 4) {
720 assert(Call->Arguments.size() >= 5 &&
721 "Need 5+ args for explicit atomic cmpxchg");
722 auto MemOrdEq =
723 static_cast<std::memory_order>(getIConstVal(Call->Arguments[3], MRI));
724 auto MemOrdNeq =
725 static_cast<std::memory_order>(getIConstVal(Call->Arguments[4], MRI));
726 MemSemEqual = getSPIRVMemSemantics(MemOrdEq) | MemSemStorage;
727 MemSemUnequal = getSPIRVMemSemantics(MemOrdNeq) | MemSemStorage;
728 if (MemOrdEq == MemSemEqual)
729 MemSemEqualReg = Call->Arguments[3];
730 if (MemOrdNeq == MemSemEqual)
731 MemSemUnequalReg = Call->Arguments[4];
732 MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
733 MRI->setRegClass(Call->Arguments[4], &SPIRV::IDRegClass);
734 }
735 if (!MemSemEqualReg.isValid())
736 MemSemEqualReg = buildConstantIntReg(MemSemEqual, MIRBuilder, GR);
737 if (!MemSemUnequalReg.isValid())
738 MemSemUnequalReg = buildConstantIntReg(MemSemUnequal, MIRBuilder, GR);
739
740 Register ScopeReg;
741 auto Scope = IsCmpxchg ? SPIRV::Scope::Workgroup : SPIRV::Scope::Device;
742 if (Call->Arguments.size() >= 6) {
743 assert(Call->Arguments.size() == 6 &&
744 "Extra args for explicit atomic cmpxchg");
745 auto ClScope = static_cast<SPIRV::CLMemoryScope>(
746 getIConstVal(Call->Arguments[5], MRI));
747 Scope = getSPIRVScope(ClScope);
748 if (ClScope == static_cast<unsigned>(Scope))
749 ScopeReg = Call->Arguments[5];
750 MRI->setRegClass(Call->Arguments[5], &SPIRV::IDRegClass);
751 }
752 if (!ScopeReg.isValid())
753 ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);
754
755 Register Expected = IsCmpxchg
756 ? ExpectedArg
757 : buildLoadInst(SpvDesiredTy, ExpectedArg, MIRBuilder,
758 GR, LLT::scalar(32));
759 MRI->setType(Expected, DesiredLLT);
760 Register Tmp = !IsCmpxchg ? MRI->createGenericVirtualRegister(DesiredLLT)
761 : Call->ReturnRegister;
762 if (!MRI->getRegClassOrNull(Tmp))
763 MRI->setRegClass(Tmp, &SPIRV::IDRegClass);
764 GR->assignSPIRVTypeToVReg(SpvDesiredTy, Tmp, MIRBuilder.getMF());
765
766 SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
767 MIRBuilder.buildInstr(Opcode)
768 .addDef(Tmp)
769 .addUse(GR->getSPIRVTypeID(IntTy))
770 .addUse(ObjectPtr)
771 .addUse(ScopeReg)
772 .addUse(MemSemEqualReg)
773 .addUse(MemSemUnequalReg)
774 .addUse(Desired)
775 .addUse(Expected);
776 if (!IsCmpxchg) {
777 MIRBuilder.buildInstr(SPIRV::OpStore).addUse(ExpectedArg).addUse(Tmp);
778 MIRBuilder.buildICmp(CmpInst::ICMP_EQ, Call->ReturnRegister, Tmp, Expected);
779 }
780 return true;
781 }
782
783 /// Helper function for building atomic instructions.
784 static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
785 MachineIRBuilder &MIRBuilder,
786 SPIRVGlobalRegistry *GR) {
787 if (Call->isSpirvOp())
788 return buildOpFromWrapper(MIRBuilder, Opcode, Call,
789 GR->getSPIRVTypeID(Call->ReturnType));
790
791 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
792 Register ScopeRegister =
793 Call->Arguments.size() >= 4 ? Call->Arguments[3] : Register();
794
795 assert(Call->Arguments.size() <= 4 &&
796 "Too many args for explicit atomic RMW");
797 ScopeRegister = buildScopeReg(ScopeRegister, SPIRV::Scope::Workgroup,
798 MIRBuilder, GR, MRI);
799
800 Register PtrRegister = Call->Arguments[0];
801 unsigned Semantics = SPIRV::MemorySemantics::None;
802 MRI->setRegClass(PtrRegister, &SPIRV::IDRegClass);
803 Register MemSemanticsReg =
804 Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
805 MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
806 Semantics, MIRBuilder, GR);
807 MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
808 Register ValueReg = Call->Arguments[1];
809 Register ValueTypeReg = GR->getSPIRVTypeID(Call->ReturnType);
810 // support cl_ext_float_atomics
811 if (Call->ReturnType->getOpcode() == SPIRV::OpTypeFloat) {
812 if (Opcode == SPIRV::OpAtomicIAdd) {
813 Opcode = SPIRV::OpAtomicFAddEXT;
814 } else if (Opcode == SPIRV::OpAtomicISub) {
815 // Translate OpAtomicISub applied to a floating type argument to
816 // OpAtomicFAddEXT with the negative value operand
817 Opcode = SPIRV::OpAtomicFAddEXT;
818 Register NegValueReg =
819 MRI->createGenericVirtualRegister(MRI->getType(ValueReg));
820 MRI->setRegClass(NegValueReg, &SPIRV::IDRegClass);
821 GR->assignSPIRVTypeToVReg(Call->ReturnType, NegValueReg,
822 MIRBuilder.getMF());
823 MIRBuilder.buildInstr(TargetOpcode::G_FNEG)
824 .addDef(NegValueReg)
825 .addUse(ValueReg);
826 insertAssignInstr(NegValueReg, nullptr, Call->ReturnType, GR, MIRBuilder,
827 MIRBuilder.getMF().getRegInfo());
828 ValueReg = NegValueReg;
829 }
830 }
831 MIRBuilder.buildInstr(Opcode)
832 .addDef(Call->ReturnRegister)
833 .addUse(ValueTypeReg)
834 .addUse(PtrRegister)
835 .addUse(ScopeRegister)
836 .addUse(MemSemanticsReg)
837 .addUse(ValueReg);
838 return true;
839 }
840
841 /// Helper function for building an atomic floating-type instruction.
842 static bool buildAtomicFloatingRMWInst(const SPIRV::IncomingCall *Call,
843 unsigned Opcode,
844 MachineIRBuilder &MIRBuilder,
845 SPIRVGlobalRegistry *GR) {
846 assert(Call->Arguments.size() == 4 &&
847 "Wrong number of atomic floating-type builtin");
848
849 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
850
851 Register PtrReg = Call->Arguments[0];
852 MRI->setRegClass(PtrReg, &SPIRV::IDRegClass);
853
854 Register ScopeReg = Call->Arguments[1];
855 MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
856
857 Register MemSemanticsReg = Call->Arguments[2];
858 MRI->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
859
860 Register ValueReg = Call->Arguments[3];
861 MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
862
863 MIRBuilder.buildInstr(Opcode)
864 .addDef(Call->ReturnRegister)
865 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
866 .addUse(PtrReg)
867 .addUse(ScopeReg)
868 .addUse(MemSemanticsReg)
869 .addUse(ValueReg);
870 return true;
871 }
872
873 /// Helper function for building atomic flag instructions (e.g.
874 /// OpAtomicFlagTestAndSet).
875 static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
876 unsigned Opcode, MachineIRBuilder &MIRBuilder,
877 SPIRVGlobalRegistry *GR) {
878 bool IsSet = Opcode == SPIRV::OpAtomicFlagTestAndSet;
879 Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
880 if (Call->isSpirvOp())
881 return buildOpFromWrapper(MIRBuilder, Opcode, Call,
882 IsSet ? TypeReg : Register(0));
883
884 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
885 Register PtrRegister = Call->Arguments[0];
886 unsigned Semantics = SPIRV::MemorySemantics::SequentiallyConsistent;
887 Register MemSemanticsReg =
888 Call->Arguments.size() >= 2 ? Call->Arguments[1] : Register();
889 MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
890 Semantics, MIRBuilder, GR);
891
892 assert((Opcode != SPIRV::OpAtomicFlagClear ||
893 (Semantics != SPIRV::MemorySemantics::Acquire &&
894 Semantics != SPIRV::MemorySemantics::AcquireRelease)) &&
895 "Invalid memory order argument!");
896
897 Register ScopeRegister =
898 Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
899 ScopeRegister =
900 buildScopeReg(ScopeRegister, SPIRV::Scope::Device, MIRBuilder, GR, MRI);
901
902 auto MIB = MIRBuilder.buildInstr(Opcode);
903 if (IsSet)
904 MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
905
906 MIB.addUse(PtrRegister).addUse(ScopeRegister).addUse(MemSemanticsReg);
907 return true;
908 }
909
910 /// Helper function for building barriers, i.e., memory/control ordering
911 /// operations.
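///
/// For illustration: barrier(CLK_LOCAL_MEM_FENCE) becomes roughly (a sketch)
///   OpControlBarrier %Workgroup %Workgroup %WorkgroupMemory|SeqCst
/// whereas fence-style builtins map onto OpMemoryBarrier with semantics
/// derived from the flag and memory-order arguments.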
912 static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
913 MachineIRBuilder &MIRBuilder,
914 SPIRVGlobalRegistry *GR) {
915 if (Call->isSpirvOp())
916 return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0));
917
918 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
919 unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
920 unsigned MemSemantics = SPIRV::MemorySemantics::None;
921
922 if (MemFlags & SPIRV::CLK_LOCAL_MEM_FENCE)
923 MemSemantics |= SPIRV::MemorySemantics::WorkgroupMemory;
924
925 if (MemFlags & SPIRV::CLK_GLOBAL_MEM_FENCE)
926 MemSemantics |= SPIRV::MemorySemantics::CrossWorkgroupMemory;
927
928 if (MemFlags & SPIRV::CLK_IMAGE_MEM_FENCE)
929 MemSemantics |= SPIRV::MemorySemantics::ImageMemory;
930
931 if (Opcode == SPIRV::OpMemoryBarrier) {
932 std::memory_order MemOrder =
933 static_cast<std::memory_order>(getIConstVal(Call->Arguments[1], MRI));
934 MemSemantics = getSPIRVMemSemantics(MemOrder) | MemSemantics;
935 } else {
936 MemSemantics |= SPIRV::MemorySemantics::SequentiallyConsistent;
937 }
938
939 Register MemSemanticsReg;
940 if (MemFlags == MemSemantics) {
941 MemSemanticsReg = Call->Arguments[0];
942 MRI->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
943 } else
944 MemSemanticsReg = buildConstantIntReg(MemSemantics, MIRBuilder, GR);
945
946 Register ScopeReg;
947 SPIRV::Scope::Scope Scope = SPIRV::Scope::Workgroup;
948 SPIRV::Scope::Scope MemScope = Scope;
949 if (Call->Arguments.size() >= 2) {
950 assert(
951 ((Opcode != SPIRV::OpMemoryBarrier && Call->Arguments.size() == 2) ||
952 (Opcode == SPIRV::OpMemoryBarrier && Call->Arguments.size() == 3)) &&
953 "Extra args for explicitly scoped barrier");
954 Register ScopeArg = (Opcode == SPIRV::OpMemoryBarrier) ? Call->Arguments[2]
955 : Call->Arguments[1];
956 SPIRV::CLMemoryScope CLScope =
957 static_cast<SPIRV::CLMemoryScope>(getIConstVal(ScopeArg, MRI));
958 MemScope = getSPIRVScope(CLScope);
959 if (!(MemFlags & SPIRV::CLK_LOCAL_MEM_FENCE) ||
960 (Opcode == SPIRV::OpMemoryBarrier))
961 Scope = MemScope;
962
963 if (CLScope == static_cast<unsigned>(Scope)) {
964 ScopeReg = Call->Arguments[1];
965 MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
966 }
967 }
968
969 if (!ScopeReg.isValid())
970 ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);
971
972 auto MIB = MIRBuilder.buildInstr(Opcode).addUse(ScopeReg);
973 if (Opcode != SPIRV::OpMemoryBarrier)
974 MIB.addUse(buildConstantIntReg(MemScope, MIRBuilder, GR));
975 MIB.addUse(MemSemanticsReg);
976 return true;
977 }
978
979 static unsigned getNumComponentsForDim(SPIRV::Dim::Dim dim) {
980 switch (dim) {
981 case SPIRV::Dim::DIM_1D:
982 case SPIRV::Dim::DIM_Buffer:
983 return 1;
984 case SPIRV::Dim::DIM_2D:
985 case SPIRV::Dim::DIM_Cube:
986 case SPIRV::Dim::DIM_Rect:
987 return 2;
988 case SPIRV::Dim::DIM_3D:
989 return 3;
990 default:
991 report_fatal_error("Cannot get num components for given Dim");
992 }
993 }
994
995 /// Helper function for obtaining the number of size components.
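/// For illustration: a 2D arrayed image (Dim 2D with the Arrayed operand set)
/// has 3 size components, while a plain 1D or buffer image has 1.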
996 static unsigned getNumSizeComponents(SPIRVType *imgType) {
997 assert(imgType->getOpcode() == SPIRV::OpTypeImage);
998 auto dim = static_cast<SPIRV::Dim::Dim>(imgType->getOperand(2).getImm());
999 unsigned numComps = getNumComponentsForDim(dim);
1000 bool arrayed = imgType->getOperand(4).getImm() == 1;
1001 return arrayed ? numComps + 1 : numComps;
1002 }
1003
1004 //===----------------------------------------------------------------------===//
1005 // Implementation functions for each builtin group
1006 //===----------------------------------------------------------------------===//
1007
1008 static bool generateExtInst(const SPIRV::IncomingCall *Call,
1009 MachineIRBuilder &MIRBuilder,
1010 SPIRVGlobalRegistry *GR) {
1011 // Lookup the extended instruction number in the TableGen records.
1012 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1013 uint32_t Number =
1014 SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;
1015
1016 // Build extended instruction.
1017 auto MIB =
1018 MIRBuilder.buildInstr(SPIRV::OpExtInst)
1019 .addDef(Call->ReturnRegister)
1020 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1021 .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
1022 .addImm(Number);
1023
1024 for (auto Argument : Call->Arguments)
1025 MIB.addUse(Argument);
1026 return true;
1027 }
1028
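// For illustration, a relational builtin such as isequal(float4, float4) is
// lowered roughly as (a sketch): the native comparison yields a bool vector
// that is mapped back onto the integer return type with a select:
//   %cmp = OpFOrdEqual %v4bool %a %b
//   %res = OpSelect %v4int %cmp %all_ones %zeros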
1029 static bool generateRelationalInst(const SPIRV::IncomingCall *Call,
1030 MachineIRBuilder &MIRBuilder,
1031 SPIRVGlobalRegistry *GR) {
1032 // Lookup the instruction opcode in the TableGen records.
1033 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1034 unsigned Opcode =
1035 SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1036
1037 Register CompareRegister;
1038 SPIRVType *RelationType;
1039 std::tie(CompareRegister, RelationType) =
1040 buildBoolRegister(MIRBuilder, Call->ReturnType, GR);
1041
1042 // Build relational instruction.
1043 auto MIB = MIRBuilder.buildInstr(Opcode)
1044 .addDef(CompareRegister)
1045 .addUse(GR->getSPIRVTypeID(RelationType));
1046
1047 for (auto Argument : Call->Arguments)
1048 MIB.addUse(Argument);
1049
1050 // Build select instruction.
1051 return buildSelectInst(MIRBuilder, Call->ReturnRegister, CompareRegister,
1052 Call->ReturnType, GR);
1053 }
1054
1055 static bool generateGroupInst(const SPIRV::IncomingCall *Call,
1056 MachineIRBuilder &MIRBuilder,
1057 SPIRVGlobalRegistry *GR) {
1058 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1059 const SPIRV::GroupBuiltin *GroupBuiltin =
1060 SPIRV::lookupGroupBuiltin(Builtin->Name);
1061
1062 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1063 if (Call->isSpirvOp()) {
1064 if (GroupBuiltin->NoGroupOperation)
1065 return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
1066 GR->getSPIRVTypeID(Call->ReturnType));
1067
1068 // Group Operation is a literal
1069 Register GroupOpReg = Call->Arguments[1];
1070 const MachineInstr *MI = getDefInstrMaybeConstant(GroupOpReg, MRI);
1071 if (!MI || MI->getOpcode() != TargetOpcode::G_CONSTANT)
1072 report_fatal_error(
1073 "Group Operation parameter must be an integer constant");
1074 uint64_t GrpOp = MI->getOperand(1).getCImm()->getValue().getZExtValue();
1075 Register ScopeReg = Call->Arguments[0];
1076 if (!MRI->getRegClassOrNull(ScopeReg))
1077 MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
1078 auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
1079 .addDef(Call->ReturnRegister)
1080 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1081 .addUse(ScopeReg)
1082 .addImm(GrpOp);
1083 for (unsigned i = 2; i < Call->Arguments.size(); ++i) {
1084 Register ArgReg = Call->Arguments[i];
1085 if (!MRI->getRegClassOrNull(ArgReg))
1086 MRI->setRegClass(ArgReg, &SPIRV::IDRegClass);
1087 MIB.addUse(ArgReg);
1088 }
1089 return true;
1090 }
1091
1092 Register Arg0;
1093 if (GroupBuiltin->HasBoolArg) {
1094 Register ConstRegister = Call->Arguments[0];
1095 auto ArgInstruction = getDefInstrMaybeConstant(ConstRegister, MRI);
1096 (void)ArgInstruction;
1097 // TODO: support non-constant bool values.
1098 assert(ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT &&
1099 "Only constant bool value args are supported");
1100 if (GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode() !=
1101 SPIRV::OpTypeBool)
1102 Arg0 = GR->buildConstantInt(getIConstVal(ConstRegister, MRI), MIRBuilder,
1103 GR->getOrCreateSPIRVBoolType(MIRBuilder));
1104 }
1105
1106 Register GroupResultRegister = Call->ReturnRegister;
1107 SPIRVType *GroupResultType = Call->ReturnType;
1108
1109 // TODO: maybe we need to check whether the result type is already boolean
1110 // and in this case do not insert select instruction.
1111 const bool HasBoolReturnTy =
1112 GroupBuiltin->IsElect || GroupBuiltin->IsAllOrAny ||
1113 GroupBuiltin->IsAllEqual || GroupBuiltin->IsLogical ||
1114 GroupBuiltin->IsInverseBallot || GroupBuiltin->IsBallotBitExtract;
1115
1116 if (HasBoolReturnTy)
1117 std::tie(GroupResultRegister, GroupResultType) =
1118 buildBoolRegister(MIRBuilder, Call->ReturnType, GR);
1119
1120 auto Scope = Builtin->Name.starts_with("sub_group") ? SPIRV::Scope::Subgroup
1121 : SPIRV::Scope::Workgroup;
1122 Register ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
1123
1124 // Build work/sub group instruction.
1125 auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
1126 .addDef(GroupResultRegister)
1127 .addUse(GR->getSPIRVTypeID(GroupResultType))
1128 .addUse(ScopeRegister);
1129
1130 if (!GroupBuiltin->NoGroupOperation)
1131 MIB.addImm(GroupBuiltin->GroupOperation);
1132 if (Call->Arguments.size() > 0) {
1133 MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
1134 MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
1135 for (unsigned i = 1; i < Call->Arguments.size(); i++) {
1136 MIB.addUse(Call->Arguments[i]);
1137 MRI->setRegClass(Call->Arguments[i], &SPIRV::IDRegClass);
1138 }
1139 }
1140
1141 // Build select instruction.
1142 if (HasBoolReturnTy)
1143 buildSelectInst(MIRBuilder, Call->ReturnRegister, GroupResultRegister,
1144 Call->ReturnType, GR);
1145 return true;
1146 }
1147
1148 static bool generateIntelSubgroupsInst(const SPIRV::IncomingCall *Call,
1149 MachineIRBuilder &MIRBuilder,
1150 SPIRVGlobalRegistry *GR) {
1151 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1152 MachineFunction &MF = MIRBuilder.getMF();
1153 const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
1154 if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1155 std::string DiagMsg = std::string(Builtin->Name) +
1156 ": the builtin requires the following SPIR-V "
1157 "extension: SPV_INTEL_subgroups";
1158 report_fatal_error(DiagMsg.c_str(), false);
1159 }
1160 const SPIRV::IntelSubgroupsBuiltin *IntelSubgroups =
1161 SPIRV::lookupIntelSubgroupsBuiltin(Builtin->Name);
1162
1163 uint32_t OpCode = IntelSubgroups->Opcode;
1164 if (Call->isSpirvOp()) {
1165 bool IsSet = OpCode != SPIRV::OpSubgroupBlockWriteINTEL &&
1166 OpCode != SPIRV::OpSubgroupImageBlockWriteINTEL;
1167 return buildOpFromWrapper(MIRBuilder, OpCode, Call,
1168 IsSet ? GR->getSPIRVTypeID(Call->ReturnType)
1169 : Register(0));
1170 }
1171
1172 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1173 if (IntelSubgroups->IsBlock) {
1174 // The minimal number of arguments set in the TableGen records is 1.
1175 if (SPIRVType *Arg0Type = GR->getSPIRVTypeForVReg(Call->Arguments[0])) {
1176 if (Arg0Type->getOpcode() == SPIRV::OpTypeImage) {
1177 // TODO: add required validation from the specification:
1178 // "'Image' must be an object whose type is OpTypeImage with a 'Sampled'
1179 // operand of 0 or 2. If the 'Sampled' operand is 2, then some
1180 // dimensions require a capability."
1181 switch (OpCode) {
1182 case SPIRV::OpSubgroupBlockReadINTEL:
1183 OpCode = SPIRV::OpSubgroupImageBlockReadINTEL;
1184 break;
1185 case SPIRV::OpSubgroupBlockWriteINTEL:
1186 OpCode = SPIRV::OpSubgroupImageBlockWriteINTEL;
1187 break;
1188 }
1189 }
1190 }
1191 }
1192
1193 // TODO: opaque pointer types should eventually be resolved in such a way
1194 // that validation of block read is enabled with respect to the following
1195 // specification requirement:
1196 // "'Result Type' may be a scalar or vector type, and its component type must
1197 // be equal to the type pointed to by 'Ptr'."
1198 // For example, function parameter type should not be default i8 pointer, but
1199 // depend on the result type of the instruction where it is used as a pointer
1200 // argument of OpSubgroupBlockReadINTEL
1201
1202 // Build Intel subgroups instruction
1203 MachineInstrBuilder MIB =
1204 IntelSubgroups->IsWrite
1205 ? MIRBuilder.buildInstr(OpCode)
1206 : MIRBuilder.buildInstr(OpCode)
1207 .addDef(Call->ReturnRegister)
1208 .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1209 for (size_t i = 0; i < Call->Arguments.size(); ++i) {
1210 MIB.addUse(Call->Arguments[i]);
1211 MRI->setRegClass(Call->Arguments[i], &SPIRV::IDRegClass);
1212 }
1213
1214 return true;
1215 }
1216
1217 static bool generateGroupUniformInst(const SPIRV::IncomingCall *Call,
1218 MachineIRBuilder &MIRBuilder,
1219 SPIRVGlobalRegistry *GR) {
1220 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1221 MachineFunction &MF = MIRBuilder.getMF();
1222 const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
1223 if (!ST->canUseExtension(
1224 SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1225 std::string DiagMsg = std::string(Builtin->Name) +
1226 ": the builtin requires the following SPIR-V "
1227 "extension: SPV_KHR_uniform_group_instructions";
1228 report_fatal_error(DiagMsg.c_str(), false);
1229 }
1230 const SPIRV::GroupUniformBuiltin *GroupUniform =
1231 SPIRV::lookupGroupUniformBuiltin(Builtin->Name);
1232 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1233
1234 Register GroupResultReg = Call->ReturnRegister;
1235 MRI->setRegClass(GroupResultReg, &SPIRV::IDRegClass);
1236
1237 // Scope
1238 Register ScopeReg = Call->Arguments[0];
1239 MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
1240
1241 // Group Operation
1242 Register ConstGroupOpReg = Call->Arguments[1];
1243 const MachineInstr *Const = getDefInstrMaybeConstant(ConstGroupOpReg, MRI);
1244 if (!Const || Const->getOpcode() != TargetOpcode::G_CONSTANT)
1245 report_fatal_error(
1246 "expect a constant group operation for a uniform group instruction",
1247 false);
1248 const MachineOperand &ConstOperand = Const->getOperand(1);
1249 if (!ConstOperand.isCImm())
1250 report_fatal_error("uniform group instructions: group operation must be an "
1251 "integer constant",
1252 false);
1253
1254 // Value
1255 Register ValueReg = Call->Arguments[2];
1256 MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
1257
1258 auto MIB = MIRBuilder.buildInstr(GroupUniform->Opcode)
1259 .addDef(GroupResultReg)
1260 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1261 .addUse(ScopeReg);
1262 addNumImm(ConstOperand.getCImm()->getValue(), MIB);
1263 MIB.addUse(ValueReg);
1264
1265 return true;
1266 }
1267
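// For illustration (assuming cl_khr_kernel_clock builtin names such as
// clock_read_device): the call is lowered roughly to
//   %res = OpReadClockKHR %retTy %Device
// with the scope operand deduced from the "device"/"work_group"/"sub_group"
// suffix of the builtin name, as done below.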
1268 static bool generateKernelClockInst(const SPIRV::IncomingCall *Call,
1269 MachineIRBuilder &MIRBuilder,
1270 SPIRVGlobalRegistry *GR) {
1271 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1272 MachineFunction &MF = MIRBuilder.getMF();
1273 const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
1274 if (!ST->canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock)) {
1275 std::string DiagMsg = std::string(Builtin->Name) +
1276 ": the builtin requires the following SPIR-V "
1277 "extension: SPV_KHR_shader_clock";
1278 report_fatal_error(DiagMsg.c_str(), false);
1279 }
1280
1281 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1282 Register ResultReg = Call->ReturnRegister;
1283 MRI->setRegClass(ResultReg, &SPIRV::IDRegClass);
1284
1285 // Deduce the `Scope` operand from the builtin function name.
1286 SPIRV::Scope::Scope ScopeArg =
1287 StringSwitch<SPIRV::Scope::Scope>(Builtin->Name)
1288 .EndsWith("device", SPIRV::Scope::Scope::Device)
1289 .EndsWith("work_group", SPIRV::Scope::Scope::Workgroup)
1290 .EndsWith("sub_group", SPIRV::Scope::Scope::Subgroup);
1291 Register ScopeReg = buildConstantIntReg(ScopeArg, MIRBuilder, GR);
1292
1293 MIRBuilder.buildInstr(SPIRV::OpReadClockKHR)
1294 .addDef(ResultReg)
1295 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1296 .addUse(ScopeReg);
1297
1298 return true;
1299 }
1300
1301 // These queries ask for a single size_t result for a given dimension index,
1302 // e.g. size_t get_global_id(uint dimindex). In SPIR-V, the builtins
1303 // corresponding to these values are all vec3 types, so we need to extract the
1304 // correct index or return defaultVal (0 or 1 depending on the query). We also
1305 // handle extending or truncating in case size_t does not match the expected
1306 // result type's bitwidth.
1307 //
1308 // For a constant index >= 3 we generate:
1309 // %res = OpConstant %SizeT 0
1310 //
1311 // For other indices we generate:
1312 // %g = OpVariable %ptr_V3_SizeT Input
1313 // OpDecorate %g BuiltIn XXX
1314 // OpDecorate %g LinkageAttributes "__spirv_BuiltInXXX"
1315 // OpDecorate %g Constant
1316 // %loadedVec = OpLoad %V3_SizeT %g
1317 //
1318 // Then, if the index is constant < 3, we generate:
1319 // %res = OpCompositeExtract %SizeT %loadedVec idx
1320 // If the index is dynamic, we generate:
1321 // %tmp = OpVectorExtractDynamic %SizeT %loadedVec %idx
1322 // %cmp = OpULessThan %bool %idx %const_3
1323 // %res = OpSelect %SizeT %cmp %tmp %const_0
1324 //
1325 // If the bitwidth of %res does not match the expected return type, we add an
1326 // extend or truncate.
1327 static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
1328 MachineIRBuilder &MIRBuilder,
1329 SPIRVGlobalRegistry *GR,
1330 SPIRV::BuiltIn::BuiltIn BuiltinValue,
1331 uint64_t DefaultValue) {
1332 Register IndexRegister = Call->Arguments[0];
1333 const unsigned ResultWidth = Call->ReturnType->getOperand(1).getImm();
1334 const unsigned PointerSize = GR->getPointerSize();
1335 const SPIRVType *PointerSizeType =
1336 GR->getOrCreateSPIRVIntegerType(PointerSize, MIRBuilder);
1337 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1338 auto IndexInstruction = getDefInstrMaybeConstant(IndexRegister, MRI);
1339
1340 // Set up the final register to do truncation or extension on at the end.
1341 Register ToTruncate = Call->ReturnRegister;
1342
1343 // If the index is constant, we can statically determine if it is in range.
1344 bool IsConstantIndex =
1345 IndexInstruction->getOpcode() == TargetOpcode::G_CONSTANT;
1346
1347 // If it's out of range (max dimension is 3), we can just return the constant
1348 // default value (0 or 1 depending on which query function).
1349 if (IsConstantIndex && getIConstVal(IndexRegister, MRI) >= 3) {
1350 Register DefaultReg = Call->ReturnRegister;
1351 if (PointerSize != ResultWidth) {
1352 DefaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
1353 MRI->setRegClass(DefaultReg, &SPIRV::IDRegClass);
1354 GR->assignSPIRVTypeToVReg(PointerSizeType, DefaultReg,
1355 MIRBuilder.getMF());
1356 ToTruncate = DefaultReg;
1357 }
1358 auto NewRegister =
1359 GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
1360 MIRBuilder.buildCopy(DefaultReg, NewRegister);
1361 } else { // If it could be in range, we need to load from the given builtin.
1362 auto Vec3Ty =
1363 GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder);
1364 Register LoadedVector =
1365 buildBuiltinVariableLoad(MIRBuilder, Vec3Ty, GR, BuiltinValue,
1366 LLT::fixed_vector(3, PointerSize));
1367 // Set up the vreg to extract the result to (possibly a new temporary one).
1368 Register Extracted = Call->ReturnRegister;
1369 if (!IsConstantIndex || PointerSize != ResultWidth) {
1370 Extracted = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
1371 MRI->setRegClass(Extracted, &SPIRV::IDRegClass);
1372 GR->assignSPIRVTypeToVReg(PointerSizeType, Extracted, MIRBuilder.getMF());
1373 }
1374 // Use Intrinsic::spv_extractelt so dynamic vs static extraction is
1375 // handled later: extr = spv_extractelt LoadedVector, IndexRegister.
1376 MachineInstrBuilder ExtractInst = MIRBuilder.buildIntrinsic(
1377 Intrinsic::spv_extractelt, ArrayRef<Register>{Extracted}, true, false);
1378 ExtractInst.addUse(LoadedVector).addUse(IndexRegister);
1379
1380 // If the index is dynamic, we need to check whether it is < 3 and then use a select.
1381 if (!IsConstantIndex) {
1382 insertAssignInstr(Extracted, nullptr, PointerSizeType, GR, MIRBuilder,
1383 *MRI);
1384
1385 auto IndexType = GR->getSPIRVTypeForVReg(IndexRegister);
1386 auto BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
1387
1388 Register CompareRegister =
1389 MRI->createGenericVirtualRegister(LLT::scalar(1));
1390 MRI->setRegClass(CompareRegister, &SPIRV::IDRegClass);
1391 GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());
1392
1393 // Use G_ICMP to check if idxVReg < 3.
1394 MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
1395 GR->buildConstantInt(3, MIRBuilder, IndexType));
1396
1397 // Get constant for the default value (0 or 1 depending on which
1398 // function).
1399 Register DefaultRegister =
1400 GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
1401
1402 // Get a register for the selection result (possibly a new temporary one).
1403 Register SelectionResult = Call->ReturnRegister;
1404 if (PointerSize != ResultWidth) {
1405 SelectionResult =
1406 MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
1407 MRI->setRegClass(SelectionResult, &SPIRV::IDRegClass);
1408 GR->assignSPIRVTypeToVReg(PointerSizeType, SelectionResult,
1409 MIRBuilder.getMF());
1410 }
1411 // Create the final G_SELECT to return the extracted value or the default.
1412 MIRBuilder.buildSelect(SelectionResult, CompareRegister, Extracted,
1413 DefaultRegister);
1414 ToTruncate = SelectionResult;
1415 } else {
1416 ToTruncate = Extracted;
1417 }
1418 }
1419 // Alter the result's bitwidth if it does not match the SizeT value extracted.
1420 if (PointerSize != ResultWidth)
1421 MIRBuilder.buildZExtOrTrunc(Call->ReturnRegister, ToTruncate);
1422 return true;
1423 }
1424
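// Lowers a "builtin variable" call by materializing the corresponding Input
// OpVariable (decorated with the matching BuiltIn) and loading from it. As a
// rough sketch, a WorkDim query is expected to become:
//   %var = OpVariable %ptr_Input_uint Input  ; OpDecorate %var BuiltIn WorkDim
//   %res = OpLoad %uint %var
// GlobalInvocationId is special-cased through genWorkgroupQuery above.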
1425 static bool generateBuiltinVar(const SPIRV::IncomingCall *Call,
1426 MachineIRBuilder &MIRBuilder,
1427 SPIRVGlobalRegistry *GR) {
1428 // Lookup the builtin variable record.
1429 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1430 SPIRV::BuiltIn::BuiltIn Value =
1431 SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;
1432
1433 if (Value == SPIRV::BuiltIn::GlobalInvocationId)
1434 return genWorkgroupQuery(Call, MIRBuilder, GR, Value, 0);
1435
1436 // Build a load instruction for the builtin variable.
1437 unsigned BitWidth = GR->getScalarOrVectorBitWidth(Call->ReturnType);
1438 LLT LLType;
1439 if (Call->ReturnType->getOpcode() == SPIRV::OpTypeVector)
1440 LLType =
1441 LLT::fixed_vector(Call->ReturnType->getOperand(2).getImm(), BitWidth);
1442 else
1443 LLType = LLT::scalar(BitWidth);
1444
1445 return buildBuiltinVariableLoad(MIRBuilder, Call->ReturnType, GR, Value,
1446 LLType, Call->ReturnRegister);
1447 }
1448
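// For example, "int atomic_fetch_add(volatile atomic_int *p, int v)" is
// expected to lower to an OpAtomicIAdd whose scope and memory-semantics
// operands are materialized as constant registers by buildAtomicRMWInst.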
1449 static bool generateAtomicInst(const SPIRV::IncomingCall *Call,
1450 MachineIRBuilder &MIRBuilder,
1451 SPIRVGlobalRegistry *GR) {
1452 // Lookup the instruction opcode in the TableGen records.
1453 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1454 unsigned Opcode =
1455 SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1456
1457 switch (Opcode) {
1458 case SPIRV::OpStore:
1459 return buildAtomicInitInst(Call, MIRBuilder);
1460 case SPIRV::OpAtomicLoad:
1461 return buildAtomicLoadInst(Call, MIRBuilder, GR);
1462 case SPIRV::OpAtomicStore:
1463 return buildAtomicStoreInst(Call, MIRBuilder, GR);
1464 case SPIRV::OpAtomicCompareExchange:
1465 case SPIRV::OpAtomicCompareExchangeWeak:
1466 return buildAtomicCompareExchangeInst(Call, Builtin, Opcode, MIRBuilder,
1467 GR);
1468 case SPIRV::OpAtomicIAdd:
1469 case SPIRV::OpAtomicISub:
1470 case SPIRV::OpAtomicOr:
1471 case SPIRV::OpAtomicXor:
1472 case SPIRV::OpAtomicAnd:
1473 case SPIRV::OpAtomicExchange:
1474 return buildAtomicRMWInst(Call, Opcode, MIRBuilder, GR);
1475 case SPIRV::OpMemoryBarrier:
1476 return buildBarrierInst(Call, SPIRV::OpMemoryBarrier, MIRBuilder, GR);
1477 case SPIRV::OpAtomicFlagTestAndSet:
1478 case SPIRV::OpAtomicFlagClear:
1479 return buildAtomicFlagInst(Call, Opcode, MIRBuilder, GR);
1480 default:
1481 if (Call->isSpirvOp())
1482 return buildOpFromWrapper(MIRBuilder, Opcode, Call,
1483 GR->getSPIRVTypeID(Call->ReturnType));
1484 return false;
1485 }
1486 }
1487
1488 static bool generateAtomicFloatingInst(const SPIRV::IncomingCall *Call,
1489 MachineIRBuilder &MIRBuilder,
1490 SPIRVGlobalRegistry *GR) {
1491 // Lookup the instruction opcode in the TableGen records.
1492 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1493 unsigned Opcode = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name)->Opcode;
1494
1495 switch (Opcode) {
1496 case SPIRV::OpAtomicFAddEXT:
1497 case SPIRV::OpAtomicFMinEXT:
1498 case SPIRV::OpAtomicFMaxEXT:
1499 return buildAtomicFloatingRMWInst(Call, Opcode, MIRBuilder, GR);
1500 default:
1501 return false;
1502 }
1503 }
1504
1505 static bool generateBarrierInst(const SPIRV::IncomingCall *Call,
1506 MachineIRBuilder &MIRBuilder,
1507 SPIRVGlobalRegistry *GR) {
1508 // Lookup the instruction opcode in the TableGen records.
1509 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1510 unsigned Opcode =
1511 SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1512
1513 return buildBarrierInst(Call, Opcode, MIRBuilder, GR);
1514 }
1515
1516 static bool generateCastToPtrInst(const SPIRV::IncomingCall *Call,
1517 MachineIRBuilder &MIRBuilder) {
1518 MIRBuilder.buildInstr(TargetOpcode::G_ADDRSPACE_CAST)
1519 .addDef(Call->ReturnRegister)
1520 .addUse(Call->Arguments[0]);
1521 return true;
1522 }
1523
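// For example, "float dot(float4 a, float4 b)" is expected to lower to OpDot,
// while the scalar overload "float dot(float a, float b)" becomes a plain
// floating-point multiply (OpFMulS) instead.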
1524 static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call,
1525 MachineIRBuilder &MIRBuilder,
1526 SPIRVGlobalRegistry *GR) {
1527 if (Call->isSpirvOp())
1528 return buildOpFromWrapper(MIRBuilder, SPIRV::OpDot, Call,
1529 GR->getSPIRVTypeID(Call->ReturnType));
1530 unsigned Opcode = GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode();
1531 bool IsVec = Opcode == SPIRV::OpTypeVector;
1532 // Use OpDot for vector arguments and OpFMulS for scalar arguments.
1533 MIRBuilder.buildInstr(IsVec ? SPIRV::OpDot : SPIRV::OpFMulS)
1534 .addDef(Call->ReturnRegister)
1535 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1536 .addUse(Call->Arguments[0])
1537 .addUse(Call->Arguments[1]);
1538 return true;
1539 }
1540
1541 static bool generateWaveInst(const SPIRV::IncomingCall *Call,
1542 MachineIRBuilder &MIRBuilder,
1543 SPIRVGlobalRegistry *GR) {
1544 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1545 SPIRV::BuiltIn::BuiltIn Value =
1546 SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;
1547
1548 // For now, we only support a single Wave intrinsic with a single return type.
1549 assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt);
1550 LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(Call->ReturnType));
1551
1552 return buildBuiltinVariableLoad(
1553 MIRBuilder, Call->ReturnType, GR, Value, LLType, Call->ReturnRegister,
1554 /* isConst= */ false, /* hasLinkageTy= */ false);
1555 }
1556
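// For example, "size_t get_global_size(uint dim)" queries the GlobalSize
// builtin and defaults to 1 for out-of-range dimensions, whereas queries such
// as get_global_id or get_global_offset default to 0.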
1557 static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
1558 MachineIRBuilder &MIRBuilder,
1559 SPIRVGlobalRegistry *GR) {
1560 // Lookup the builtin record.
1561 SPIRV::BuiltIn::BuiltIn Value =
1562 SPIRV::lookupGetBuiltin(Call->Builtin->Name, Call->Builtin->Set)->Value;
1563 uint64_t IsDefault = (Value == SPIRV::BuiltIn::GlobalSize ||
1564 Value == SPIRV::BuiltIn::WorkgroupSize ||
1565 Value == SPIRV::BuiltIn::EnqueuedWorkgroupSize);
1566 return genWorkgroupQuery(Call, MIRBuilder, GR, Value, IsDefault ? 1 : 0);
1567 }
1568
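// For example, "int get_image_width(image2d_t img)" is expected to lower
// roughly to an OpImageQuerySizeLod (with a constant Lod of 0) followed by an
// OpCompositeExtract of component 0 of the returned size vector.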
1569 static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
1570 MachineIRBuilder &MIRBuilder,
1571 SPIRVGlobalRegistry *GR) {
1572 // Lookup the image size query component number in the TableGen records.
1573 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1574 uint32_t Component =
1575 SPIRV::lookupImageQueryBuiltin(Builtin->Name, Builtin->Set)->Component;
1576 // Query result may either be a vector or a scalar. If return type is not a
1577 // vector, expect only a single size component. Otherwise get the number of
1578 // expected components.
1579 SPIRVType *RetTy = Call->ReturnType;
1580 unsigned NumExpectedRetComponents = RetTy->getOpcode() == SPIRV::OpTypeVector
1581 ? RetTy->getOperand(2).getImm()
1582 : 1;
1583 // Get the actual number of query result/size components.
1584 SPIRVType *ImgType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
1585 unsigned NumActualRetComponents = getNumSizeComponents(ImgType);
1586 Register QueryResult = Call->ReturnRegister;
1587 SPIRVType *QueryResultType = Call->ReturnType;
1588 if (NumExpectedRetComponents != NumActualRetComponents) {
1589 QueryResult = MIRBuilder.getMRI()->createGenericVirtualRegister(
1590 LLT::fixed_vector(NumActualRetComponents, 32));
1591 MIRBuilder.getMRI()->setRegClass(QueryResult, &SPIRV::IDRegClass);
1592 SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
1593 QueryResultType = GR->getOrCreateSPIRVVectorType(
1594 IntTy, NumActualRetComponents, MIRBuilder);
1595 GR->assignSPIRVTypeToVReg(QueryResultType, QueryResult, MIRBuilder.getMF());
1596 }
1597 bool IsDimBuf = ImgType->getOperand(2).getImm() == SPIRV::Dim::DIM_Buffer;
1598 unsigned Opcode =
1599 IsDimBuf ? SPIRV::OpImageQuerySize : SPIRV::OpImageQuerySizeLod;
1600 MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
1601 auto MIB = MIRBuilder.buildInstr(Opcode)
1602 .addDef(QueryResult)
1603 .addUse(GR->getSPIRVTypeID(QueryResultType))
1604 .addUse(Call->Arguments[0]);
1605 if (!IsDimBuf)
1606 MIB.addUse(buildConstantIntReg(0, MIRBuilder, GR)); // Lod id.
1607 if (NumExpectedRetComponents == NumActualRetComponents)
1608 return true;
1609 if (NumExpectedRetComponents == 1) {
1610 // Only 1 component is expected, build OpCompositeExtract instruction.
1611 unsigned ExtractedComposite =
1612 Component == 3 ? NumActualRetComponents - 1 : Component;
1613 assert(ExtractedComposite < NumActualRetComponents &&
1614 "Invalid composite index!");
1615 Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
1616 SPIRVType *NewType = nullptr;
1617 if (QueryResultType->getOpcode() == SPIRV::OpTypeVector) {
1618 Register NewTypeReg = QueryResultType->getOperand(1).getReg();
1619 if (TypeReg != NewTypeReg &&
1620 (NewType = GR->getSPIRVTypeForVReg(NewTypeReg)) != nullptr)
1621 TypeReg = NewTypeReg;
1622 }
1623 MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
1624 .addDef(Call->ReturnRegister)
1625 .addUse(TypeReg)
1626 .addUse(QueryResult)
1627 .addImm(ExtractedComposite);
1628 if (NewType != nullptr)
1629 insertAssignInstr(Call->ReturnRegister, nullptr, NewType, GR, MIRBuilder,
1630 MIRBuilder.getMF().getRegInfo());
1631 } else {
1632 // More than 1 component is expected, fill a new vector.
1633 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVectorShuffle)
1634 .addDef(Call->ReturnRegister)
1635 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1636 .addUse(QueryResult)
1637 .addUse(QueryResult);
1638 for (unsigned i = 0; i < NumExpectedRetComponents; ++i)
1639 MIB.addImm(i < NumActualRetComponents ? i : 0xffffffff);
1640 }
1641 return true;
1642 }
1643
1644 static bool generateImageMiscQueryInst(const SPIRV::IncomingCall *Call,
1645 MachineIRBuilder &MIRBuilder,
1646 SPIRVGlobalRegistry *GR) {
1647 assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt &&
1648 "Image samples query result must be of int type!");
1649
1650 // Lookup the instruction opcode in the TableGen records.
1651 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1652 unsigned Opcode =
1653 SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1654
1655 Register Image = Call->Arguments[0];
1656 MIRBuilder.getMRI()->setRegClass(Image, &SPIRV::IDRegClass);
1657 SPIRV::Dim::Dim ImageDimensionality = static_cast<SPIRV::Dim::Dim>(
1658 GR->getSPIRVTypeForVReg(Image)->getOperand(2).getImm());
1659 (void)ImageDimensionality;
1660
1661 switch (Opcode) {
1662 case SPIRV::OpImageQuerySamples:
1663 assert(ImageDimensionality == SPIRV::Dim::DIM_2D &&
1664 "Image must be of 2D dimensionality");
1665 break;
1666 case SPIRV::OpImageQueryLevels:
1667 assert((ImageDimensionality == SPIRV::Dim::DIM_1D ||
1668 ImageDimensionality == SPIRV::Dim::DIM_2D ||
1669 ImageDimensionality == SPIRV::Dim::DIM_3D ||
1670 ImageDimensionality == SPIRV::Dim::DIM_Cube) &&
1671 "Image must be of 1D/2D/3D/Cube dimensionality");
1672 break;
1673 }
1674
1675 MIRBuilder.buildInstr(Opcode)
1676 .addDef(Call->ReturnRegister)
1677 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1678 .addUse(Image);
1679 return true;
1680 }
1681
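// The OpenCL sampler initializer is a bitmask; for example a value of
// CLK_ADDRESS_CLAMP_TO_EDGE | CLK_NORMALIZED_COORDS_TRUE | CLK_FILTER_LINEAR
// maps to SamplerAddressingMode::ClampToEdge, a normalized-coords flag of 1,
// and SamplerFilterMode::Linear in the helpers below.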
1682 // TODO: Move to TableGen.
1683 static SPIRV::SamplerAddressingMode::SamplerAddressingMode
1684 getSamplerAddressingModeFromBitmask(unsigned Bitmask) {
1685 switch (Bitmask & SPIRV::CLK_ADDRESS_MODE_MASK) {
1686 case SPIRV::CLK_ADDRESS_CLAMP:
1687 return SPIRV::SamplerAddressingMode::Clamp;
1688 case SPIRV::CLK_ADDRESS_CLAMP_TO_EDGE:
1689 return SPIRV::SamplerAddressingMode::ClampToEdge;
1690 case SPIRV::CLK_ADDRESS_REPEAT:
1691 return SPIRV::SamplerAddressingMode::Repeat;
1692 case SPIRV::CLK_ADDRESS_MIRRORED_REPEAT:
1693 return SPIRV::SamplerAddressingMode::RepeatMirrored;
1694 case SPIRV::CLK_ADDRESS_NONE:
1695 return SPIRV::SamplerAddressingMode::None;
1696 default:
1697 report_fatal_error("Unknown CL address mode");
1698 }
1699 }
1700
1701 static unsigned getSamplerParamFromBitmask(unsigned Bitmask) {
1702 return (Bitmask & SPIRV::CLK_NORMALIZED_COORDS_TRUE) ? 1 : 0;
1703 }
1704
1705 static SPIRV::SamplerFilterMode::SamplerFilterMode
1706 getSamplerFilterModeFromBitmask(unsigned Bitmask) {
1707 if (Bitmask & SPIRV::CLK_FILTER_LINEAR)
1708 return SPIRV::SamplerFilterMode::Linear;
1709 if (Bitmask & SPIRV::CLK_FILTER_NEAREST)
1710 return SPIRV::SamplerFilterMode::Nearest;
1711 return SPIRV::SamplerFilterMode::Nearest;
1712 }
1713
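// For example, a sampled read such as
//   float4 read_imagef(image2d_t img, sampler_t smp, float2 coord)
// is expected to lower roughly to:
//   %si  = OpSampledImage %SampledImageTy %img %smp
//   %res = OpImageSampleExplicitLod %v4float %si %coord Lod %lod_zero
// while sampler-less and MSAA reads become plain OpImageRead instructions.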
1714 static bool generateReadImageInst(const StringRef DemangledCall,
1715 const SPIRV::IncomingCall *Call,
1716 MachineIRBuilder &MIRBuilder,
1717 SPIRVGlobalRegistry *GR) {
1718 Register Image = Call->Arguments[0];
1719 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1720 MRI->setRegClass(Image, &SPIRV::IDRegClass);
1721 MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
1722 bool HasOclSampler = DemangledCall.contains_insensitive("ocl_sampler");
1723 bool HasMsaa = DemangledCall.contains_insensitive("msaa");
1724 if (HasOclSampler || HasMsaa)
1725 MRI->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
1726 if (HasOclSampler) {
1727 Register Sampler = Call->Arguments[1];
1728
1729 if (!GR->isScalarOfType(Sampler, SPIRV::OpTypeSampler) &&
1730 getDefInstrMaybeConstant(Sampler, MRI)->getOperand(1).isCImm()) {
1731 uint64_t SamplerMask = getIConstVal(Sampler, MRI);
1732 Sampler = GR->buildConstantSampler(
1733 Register(), getSamplerAddressingModeFromBitmask(SamplerMask),
1734 getSamplerParamFromBitmask(SamplerMask),
1735 getSamplerFilterModeFromBitmask(SamplerMask), MIRBuilder,
1736 GR->getSPIRVTypeForVReg(Sampler));
1737 }
1738 SPIRVType *ImageType = GR->getSPIRVTypeForVReg(Image);
1739 SPIRVType *SampledImageType =
1740 GR->getOrCreateOpTypeSampledImage(ImageType, MIRBuilder);
1741 Register SampledImage = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1742
1743 MIRBuilder.buildInstr(SPIRV::OpSampledImage)
1744 .addDef(SampledImage)
1745 .addUse(GR->getSPIRVTypeID(SampledImageType))
1746 .addUse(Image)
1747 .addUse(Sampler);
1748
1749 Register Lod = GR->buildConstantFP(APFloat::getZero(APFloat::IEEEsingle()),
1750 MIRBuilder);
1751 SPIRVType *TempType = Call->ReturnType;
1752 bool NeedsExtraction = false;
1753 if (TempType->getOpcode() != SPIRV::OpTypeVector) {
1754 TempType =
1755 GR->getOrCreateSPIRVVectorType(Call->ReturnType, 4, MIRBuilder);
1756 NeedsExtraction = true;
1757 }
1758 LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(TempType));
1759 Register TempRegister = MRI->createGenericVirtualRegister(LLType);
1760 MRI->setRegClass(TempRegister, &SPIRV::IDRegClass);
1761 GR->assignSPIRVTypeToVReg(TempType, TempRegister, MIRBuilder.getMF());
1762
1763 MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
1764 .addDef(NeedsExtraction ? TempRegister : Call->ReturnRegister)
1765 .addUse(GR->getSPIRVTypeID(TempType))
1766 .addUse(SampledImage)
1767 .addUse(Call->Arguments[2]) // Coordinate.
1768 .addImm(SPIRV::ImageOperand::Lod)
1769 .addUse(Lod);
1770
1771 if (NeedsExtraction)
1772 MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
1773 .addDef(Call->ReturnRegister)
1774 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1775 .addUse(TempRegister)
1776 .addImm(0);
1777 } else if (HasMsaa) {
1778 MIRBuilder.buildInstr(SPIRV::OpImageRead)
1779 .addDef(Call->ReturnRegister)
1780 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1781 .addUse(Image)
1782 .addUse(Call->Arguments[1]) // Coordinate.
1783 .addImm(SPIRV::ImageOperand::Sample)
1784 .addUse(Call->Arguments[2]);
1785 } else {
1786 MIRBuilder.buildInstr(SPIRV::OpImageRead)
1787 .addDef(Call->ReturnRegister)
1788 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
1789 .addUse(Image)
1790 .addUse(Call->Arguments[1]); // Coordinate.
1791 }
1792 return true;
1793 }
1794
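// For example, "void write_imagef(image2d_t img, int2 coord, float4 color)" is
// expected to lower to a single "OpImageWrite %img %coord %color".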
1795 static bool generateWriteImageInst(const SPIRV::IncomingCall *Call,
1796 MachineIRBuilder &MIRBuilder,
1797 SPIRVGlobalRegistry *GR) {
1798 MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
1799 MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
1800 MIRBuilder.getMRI()->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
1801 MIRBuilder.buildInstr(SPIRV::OpImageWrite)
1802 .addUse(Call->Arguments[0]) // Image.
1803 .addUse(Call->Arguments[1]) // Coordinate.
1804 .addUse(Call->Arguments[2]); // Texel.
1805 return true;
1806 }
1807
1808 static bool generateSampleImageInst(const StringRef DemangledCall,
1809 const SPIRV::IncomingCall *Call,
1810 MachineIRBuilder &MIRBuilder,
1811 SPIRVGlobalRegistry *GR) {
1812 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1813 if (Call->Builtin->Name.contains_insensitive(
1814 "__translate_sampler_initializer")) {
1815 // Build sampler literal.
1816 uint64_t Bitmask = getIConstVal(Call->Arguments[0], MRI);
1817 Register Sampler = GR->buildConstantSampler(
1818 Call->ReturnRegister, getSamplerAddressingModeFromBitmask(Bitmask),
1819 getSamplerParamFromBitmask(Bitmask),
1820 getSamplerFilterModeFromBitmask(Bitmask), MIRBuilder, Call->ReturnType);
1821 return Sampler.isValid();
1822 } else if (Call->Builtin->Name.contains_insensitive("__spirv_SampledImage")) {
1823 // Create OpSampledImage.
1824 Register Image = Call->Arguments[0];
1825 SPIRVType *ImageType = GR->getSPIRVTypeForVReg(Image);
1826 SPIRVType *SampledImageType =
1827 GR->getOrCreateOpTypeSampledImage(ImageType, MIRBuilder);
1828 Register SampledImage =
1829 Call->ReturnRegister.isValid()
1830 ? Call->ReturnRegister
1831 : MRI->createVirtualRegister(&SPIRV::IDRegClass);
1832 MIRBuilder.buildInstr(SPIRV::OpSampledImage)
1833 .addDef(SampledImage)
1834 .addUse(GR->getSPIRVTypeID(SampledImageType))
1835 .addUse(Image)
1836 .addUse(Call->Arguments[1]); // Sampler.
1837 return true;
1838 } else if (Call->Builtin->Name.contains_insensitive(
1839 "__spirv_ImageSampleExplicitLod")) {
1840 // Sample an image using an explicit level of detail.
1841 std::string ReturnType = DemangledCall.str();
1842 if (DemangledCall.contains("_R")) {
1843 ReturnType = ReturnType.substr(ReturnType.find("_R") + 2);
1844 ReturnType = ReturnType.substr(0, ReturnType.find('('));
1845 }
1846 SPIRVType *Type =
1847 Call->ReturnType
1848 ? Call->ReturnType
1849 : GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
1850 if (!Type) {
1851 std::string DiagMsg =
1852 "Unable to recognize SPIRV type name: " + ReturnType;
1853 report_fatal_error(DiagMsg.c_str());
1854 }
1855 MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
1856 MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
1857 MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
1858
1859 MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
1860 .addDef(Call->ReturnRegister)
1861 .addUse(GR->getSPIRVTypeID(Type))
1862 .addUse(Call->Arguments[0]) // Image.
1863 .addUse(Call->Arguments[1]) // Coordinate.
1864 .addImm(SPIRV::ImageOperand::Lod)
1865 .addUse(Call->Arguments[3]);
1866 return true;
1867 }
1868 return false;
1869 }
1870
1871 static bool generateSelectInst(const SPIRV::IncomingCall *Call,
1872 MachineIRBuilder &MIRBuilder) {
1873 MIRBuilder.buildSelect(Call->ReturnRegister, Call->Arguments[0],
1874 Call->Arguments[1], Call->Arguments[2]);
1875 return true;
1876 }
1877
1878 static bool generateConstructInst(const SPIRV::IncomingCall *Call,
1879 MachineIRBuilder &MIRBuilder,
1880 SPIRVGlobalRegistry *GR) {
1881 return buildOpFromWrapper(MIRBuilder, SPIRV::OpCompositeConstruct, Call,
1882 GR->getSPIRVTypeID(Call->ReturnType));
1883 }
1884
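// Handles the __spirv_CooperativeMatrix*KHR wrappers. The Load/Store forms
// pass their optional memory-operand constant through ImmArgs as a literal,
// and LengthKHR is emitted with the matrix type operand of its argument.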
1885 static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
1886 MachineIRBuilder &MIRBuilder,
1887 SPIRVGlobalRegistry *GR) {
1888 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1889 unsigned Opcode =
1890 SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1891 bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR;
1892 unsigned ArgSz = Call->Arguments.size();
1893 unsigned LiteralIdx = 0;
1894 if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3)
1895 LiteralIdx = 3;
1896 else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4)
1897 LiteralIdx = 4;
1898 SmallVector<uint32_t, 1> ImmArgs;
1899 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1900 if (LiteralIdx > 0)
1901 ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI));
1902 Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
1903 if (Opcode == SPIRV::OpCooperativeMatrixLengthKHR) {
1904 SPIRVType *CoopMatrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
1905 if (!CoopMatrType)
1906 report_fatal_error("Can't find a register's type definition");
1907 MIRBuilder.buildInstr(Opcode)
1908 .addDef(Call->ReturnRegister)
1909 .addUse(TypeReg)
1910 .addUse(CoopMatrType->getOperand(0).getReg());
1911 return true;
1912 }
1913 return buildOpFromWrapper(MIRBuilder, Opcode, Call,
1914 IsSet ? TypeReg : Register(0), ImmArgs);
1915 }
1916
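// For example, "int __spirv_SpecConstant(int id, int default)" is expected to
// lower roughly to:
//   OpDecorate %res SpecId <id>
//   %res = OpSpecConstant %int <default>
// with boolean results using OpSpecConstantTrue/OpSpecConstantFalse instead.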
1917 static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
1918 MachineIRBuilder &MIRBuilder,
1919 SPIRVGlobalRegistry *GR) {
1920 // Lookup the instruction opcode in the TableGen records.
1921 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1922 unsigned Opcode =
1923 SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1924 const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1925
1926 switch (Opcode) {
1927 case SPIRV::OpSpecConstant: {
1928 // Build the SpecID decoration.
1929 unsigned SpecId =
1930 static_cast<unsigned>(getIConstVal(Call->Arguments[0], MRI));
1931 buildOpDecorate(Call->ReturnRegister, MIRBuilder, SPIRV::Decoration::SpecId,
1932 {SpecId});
1933 // Determine the constant MI.
1934 Register ConstRegister = Call->Arguments[1];
1935 const MachineInstr *Const = getDefInstrMaybeConstant(ConstRegister, MRI);
1936 assert(Const &&
1937 (Const->getOpcode() == TargetOpcode::G_CONSTANT ||
1938 Const->getOpcode() == TargetOpcode::G_FCONSTANT) &&
1939 "Argument should be either an int or floating-point constant");
1940 // Determine the opcode and built the OpSpec MI.
1941 const MachineOperand &ConstOperand = Const->getOperand(1);
1942 if (Call->ReturnType->getOpcode() == SPIRV::OpTypeBool) {
1943 assert(ConstOperand.isCImm() && "Int constant operand is expected");
1944 Opcode = ConstOperand.getCImm()->getValue().getZExtValue()
1945 ? SPIRV::OpSpecConstantTrue
1946 : SPIRV::OpSpecConstantFalse;
1947 }
1948 auto MIB = MIRBuilder.buildInstr(Opcode)
1949 .addDef(Call->ReturnRegister)
1950 .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1951
1952 if (Call->ReturnType->getOpcode() != SPIRV::OpTypeBool) {
1953 if (Const->getOpcode() == TargetOpcode::G_CONSTANT)
1954 addNumImm(ConstOperand.getCImm()->getValue(), MIB);
1955 else
1956 addNumImm(ConstOperand.getFPImm()->getValueAPF().bitcastToAPInt(), MIB);
1957 }
1958 return true;
1959 }
1960 case SPIRV::OpSpecConstantComposite: {
1961 auto MIB = MIRBuilder.buildInstr(Opcode)
1962 .addDef(Call->ReturnRegister)
1963 .addUse(GR->getSPIRVTypeID(Call->ReturnType));
1964 for (unsigned i = 0; i < Call->Arguments.size(); i++)
1965 MIB.addUse(Call->Arguments[i]);
1966 return true;
1967 }
1968 default:
1969 return false;
1970 }
1971 }
1972
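// Builds an OpBuildNDRange plus an OpStore to the destination pointer for the
// ndrange_1D/2D/3D builtins, substituting zero constants for the local-size
// and global-offset operands when the call does not supply them.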
1973 static bool buildNDRange(const SPIRV::IncomingCall *Call,
1974 MachineIRBuilder &MIRBuilder,
1975 SPIRVGlobalRegistry *GR) {
1976 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1977 MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
1978 SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
1979 assert(PtrType->getOpcode() == SPIRV::OpTypePointer &&
1980 PtrType->getOperand(2).isReg());
1981 Register TypeReg = PtrType->getOperand(2).getReg();
1982 SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg);
1983 MachineFunction &MF = MIRBuilder.getMF();
1984 Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1985 GR->assignSPIRVTypeToVReg(StructType, TmpReg, MF);
1986 // Skip the first arg, it's the destination pointer. OpBuildNDRange takes
1987 // three other arguments, so pass a zero constant for any that are absent.
1988 unsigned NumArgs = Call->Arguments.size();
1989 assert(NumArgs >= 2);
1990 Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2];
1991 MRI->setRegClass(GlobalWorkSize, &SPIRV::IDRegClass);
1992 Register LocalWorkSize =
1993 NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3];
1994 if (LocalWorkSize.isValid())
1995 MRI->setRegClass(LocalWorkSize, &SPIRV::IDRegClass);
1996 Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1];
1997 if (GlobalWorkOffset.isValid())
1998 MRI->setRegClass(GlobalWorkOffset, &SPIRV::IDRegClass);
1999 if (NumArgs < 4) {
2000 Register Const;
2001 SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize);
2002 if (SpvTy->getOpcode() == SPIRV::OpTypePointer) {
2003 MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize);
2004 assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) &&
2005 DefInstr->getOperand(3).isReg());
2006 Register GWSPtr = DefInstr->getOperand(3).getReg();
2007 if (!MRI->getRegClassOrNull(GWSPtr))
2008 MRI->setRegClass(GWSPtr, &SPIRV::IDRegClass);
2009 // TODO: Maybe simplify generation of the type of the fields.
2010 unsigned Size = Call->Builtin->Name == "ndrange_3D" ? 3 : 2;
2011 unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
2012 Type *BaseTy = IntegerType::get(MF.getFunction().getContext(), BitWidth);
2013 Type *FieldTy = ArrayType::get(BaseTy, Size);
2014 SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder);
2015 GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass);
2016 GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize, MF);
2017 MIRBuilder.buildInstr(SPIRV::OpLoad)
2018 .addDef(GlobalWorkSize)
2019 .addUse(GR->getSPIRVTypeID(SpvFieldTy))
2020 .addUse(GWSPtr);
2021 const SPIRVSubtarget &ST =
2022 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
2023 Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
2024 SpvFieldTy, *ST.getInstrInfo());
2025 } else {
2026 Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
2027 }
2028 if (!LocalWorkSize.isValid())
2029 LocalWorkSize = Const;
2030 if (!GlobalWorkOffset.isValid())
2031 GlobalWorkOffset = Const;
2032 }
2033 assert(LocalWorkSize.isValid() && GlobalWorkOffset.isValid());
2034 MIRBuilder.buildInstr(SPIRV::OpBuildNDRange)
2035 .addDef(TmpReg)
2036 .addUse(TypeReg)
2037 .addUse(GlobalWorkSize)
2038 .addUse(LocalWorkSize)
2039 .addUse(GlobalWorkOffset);
2040 return MIRBuilder.buildInstr(SPIRV::OpStore)
2041 .addUse(Call->Arguments[0])
2042 .addUse(TmpReg);
2043 }
2044
2045 // TODO: maybe move this to the global registry.
2046 static SPIRVType *
2047 getOrCreateSPIRVDeviceEventPointer(MachineIRBuilder &MIRBuilder,
2048 SPIRVGlobalRegistry *GR) {
2049 LLVMContext &Context = MIRBuilder.getMF().getFunction().getContext();
2050 Type *OpaqueType = StructType::getTypeByName(Context, "spirv.DeviceEvent");
2051 if (!OpaqueType)
2052 OpaqueType = StructType::getTypeByName(Context, "opencl.clk_event_t");
2053 if (!OpaqueType)
2054 OpaqueType = StructType::create(Context, "spirv.DeviceEvent");
2055 unsigned SC0 = storageClassToAddressSpace(SPIRV::StorageClass::Function);
2056 unsigned SC1 = storageClassToAddressSpace(SPIRV::StorageClass::Generic);
2057 Type *PtrType = PointerType::get(PointerType::get(OpaqueType, SC0), SC1);
2058 return GR->getOrCreateSPIRVType(PtrType, MIRBuilder);
2059 }
2060
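// Lowers enqueue_kernel(...) calls to OpEnqueueKernel. When the call carries
// no event arguments, a dummy "0 events" constant and null event pointers are
// inserted, and any local-size varargs are unpacked into separate pointer
// operands via spv_gep before the instruction is built.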
2061 static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
2062 MachineIRBuilder &MIRBuilder,
2063 SPIRVGlobalRegistry *GR) {
2064 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2065 const DataLayout &DL = MIRBuilder.getDataLayout();
2066 bool IsSpirvOp = Call->isSpirvOp();
2067 bool HasEvents = Call->Builtin->Name.contains("events") || IsSpirvOp;
2068 const SPIRVType *Int32Ty = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
2069
2070 // Make vararg instructions before OpEnqueueKernel.
2071 // Local size arguments: sizes of the block invoke arguments. Clang generates
2072 // the local size operands as an array, so we need to unpack them.
2073 SmallVector<Register, 16> LocalSizes;
2074 if (Call->Builtin->Name.contains("_varargs") || IsSpirvOp) {
2075 const unsigned LocalSizeArrayIdx = HasEvents ? 9 : 6;
2076 Register GepReg = Call->Arguments[LocalSizeArrayIdx];
2077 MachineInstr *GepMI = MRI->getUniqueVRegDef(GepReg);
2078 assert(isSpvIntrinsic(*GepMI, Intrinsic::spv_gep) &&
2079 GepMI->getOperand(3).isReg());
2080 Register ArrayReg = GepMI->getOperand(3).getReg();
2081 MachineInstr *ArrayMI = MRI->getUniqueVRegDef(ArrayReg);
2082 const Type *LocalSizeTy = getMachineInstrType(ArrayMI);
2083 assert(LocalSizeTy && "Local size type is expected");
2084 const uint64_t LocalSizeNum =
2085 cast<ArrayType>(LocalSizeTy)->getNumElements();
2086 unsigned SC = storageClassToAddressSpace(SPIRV::StorageClass::Generic);
2087 const LLT LLType = LLT::pointer(SC, GR->getPointerSize());
2088 const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
2089 Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
2090 for (unsigned I = 0; I < LocalSizeNum; ++I) {
2091 Register Reg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
2092 MRI->setType(Reg, LLType);
2093 GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
2094 auto GEPInst = MIRBuilder.buildIntrinsic(
2095 Intrinsic::spv_gep, ArrayRef<Register>{Reg}, true, false);
2096 GEPInst
2097 .addImm(GepMI->getOperand(2).getImm()) // In bound.
2098 .addUse(ArrayMI->getOperand(0).getReg()) // Alloca.
2099 .addUse(buildConstantIntReg(0, MIRBuilder, GR)) // Indices.
2100 .addUse(buildConstantIntReg(I, MIRBuilder, GR));
2101 LocalSizes.push_back(Reg);
2102 }
2103 }
2104
2105 // The SPIR-V OpEnqueueKernel instruction has 10+ arguments.
2106 auto MIB = MIRBuilder.buildInstr(SPIRV::OpEnqueueKernel)
2107 .addDef(Call->ReturnRegister)
2108 .addUse(GR->getSPIRVTypeID(Int32Ty));
2109
2110 // Copy all arguments before block invoke function pointer.
2111 const unsigned BlockFIdx = HasEvents ? 6 : 3;
2112 for (unsigned i = 0; i < BlockFIdx; i++)
2113 MIB.addUse(Call->Arguments[i]);
2114
2115 // If there are no event arguments in the original call, add dummy ones.
2116 if (!HasEvents) {
2117 MIB.addUse(buildConstantIntReg(0, MIRBuilder, GR)); // Dummy num events.
2118 Register NullPtr = GR->getOrCreateConstNullPtr(
2119 MIRBuilder, getOrCreateSPIRVDeviceEventPointer(MIRBuilder, GR));
2120 MIB.addUse(NullPtr); // Dummy wait events.
2121 MIB.addUse(NullPtr); // Dummy ret event.
2122 }
2123
2124 MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
2125 assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
2126 // Invoke: Pointer to invoke function.
2127 MIB.addGlobalAddress(BlockMI->getOperand(1).getGlobal());
2128
2129 Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
2130 // Param: Pointer to block literal.
2131 MIB.addUse(BlockLiteralReg);
2132
2133 Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
2134 // TODO: these numbers should be obtained from block literal structure.
2135 // Param Size: Size of block literal structure.
2136 MIB.addUse(buildConstantIntReg(DL.getTypeStoreSize(PType), MIRBuilder, GR));
2137 // Param Alignment: Alignment of the block literal structure.
2138 MIB.addUse(
2139 buildConstantIntReg(DL.getPrefTypeAlign(PType).value(), MIRBuilder, GR));
2140
2141 for (unsigned i = 0; i < LocalSizes.size(); i++)
2142 MIB.addUse(LocalSizes[i]);
2143 return true;
2144 }
2145
2146 static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
2147 MachineIRBuilder &MIRBuilder,
2148 SPIRVGlobalRegistry *GR) {
2149 // Lookup the instruction opcode in the TableGen records.
2150 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2151 unsigned Opcode =
2152 SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2153
2154 switch (Opcode) {
2155 case SPIRV::OpRetainEvent:
2156 case SPIRV::OpReleaseEvent:
2157 MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
2158 return MIRBuilder.buildInstr(Opcode).addUse(Call->Arguments[0]);
2159 case SPIRV::OpCreateUserEvent:
2160 case SPIRV::OpGetDefaultQueue:
2161 return MIRBuilder.buildInstr(Opcode)
2162 .addDef(Call->ReturnRegister)
2163 .addUse(GR->getSPIRVTypeID(Call->ReturnType));
2164 case SPIRV::OpIsValidEvent:
2165 MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
2166 return MIRBuilder.buildInstr(Opcode)
2167 .addDef(Call->ReturnRegister)
2168 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2169 .addUse(Call->Arguments[0]);
2170 case SPIRV::OpSetUserEventStatus:
2171 MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
2172 MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
2173 return MIRBuilder.buildInstr(Opcode)
2174 .addUse(Call->Arguments[0])
2175 .addUse(Call->Arguments[1]);
2176 case SPIRV::OpCaptureEventProfilingInfo:
2177 MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
2178 MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
2179 MIRBuilder.getMRI()->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
2180 return MIRBuilder.buildInstr(Opcode)
2181 .addUse(Call->Arguments[0])
2182 .addUse(Call->Arguments[1])
2183 .addUse(Call->Arguments[2]);
2184 case SPIRV::OpBuildNDRange:
2185 return buildNDRange(Call, MIRBuilder, GR);
2186 case SPIRV::OpEnqueueKernel:
2187 return buildEnqueueKernel(Call, MIRBuilder, GR);
2188 default:
2189 return false;
2190 }
2191 }
2192
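// For example, "event_t async_work_group_copy(dst, src, num, event)" is
// expected to lower to an OpGroupAsyncCopy with Workgroup scope, a constant
// stride of 1 when no stride argument is provided, and the event passed last.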
2193 static bool generateAsyncCopy(const SPIRV::IncomingCall *Call,
2194 MachineIRBuilder &MIRBuilder,
2195 SPIRVGlobalRegistry *GR) {
2196 // Lookup the instruction opcode in the TableGen records.
2197 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2198 unsigned Opcode =
2199 SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2200
2201 bool IsSet = Opcode == SPIRV::OpGroupAsyncCopy;
2202 Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
2203 if (Call->isSpirvOp())
2204 return buildOpFromWrapper(MIRBuilder, Opcode, Call,
2205 IsSet ? TypeReg : Register(0));
2206
2207 auto Scope = buildConstantIntReg(SPIRV::Scope::Workgroup, MIRBuilder, GR);
2208
2209 switch (Opcode) {
2210 case SPIRV::OpGroupAsyncCopy: {
2211 SPIRVType *NewType =
2212 Call->ReturnType->getOpcode() == SPIRV::OpTypeEvent
2213 ? nullptr
2214 : GR->getOrCreateSPIRVTypeByName("spirv.Event", MIRBuilder);
2215 Register TypeReg = GR->getSPIRVTypeID(NewType ? NewType : Call->ReturnType);
2216 unsigned NumArgs = Call->Arguments.size();
2217 Register EventReg = Call->Arguments[NumArgs - 1];
2218 bool Res = MIRBuilder.buildInstr(Opcode)
2219 .addDef(Call->ReturnRegister)
2220 .addUse(TypeReg)
2221 .addUse(Scope)
2222 .addUse(Call->Arguments[0])
2223 .addUse(Call->Arguments[1])
2224 .addUse(Call->Arguments[2])
2225 .addUse(Call->Arguments.size() > 4
2226 ? Call->Arguments[3]
2227 : buildConstantIntReg(1, MIRBuilder, GR))
2228 .addUse(EventReg);
2229 if (NewType != nullptr)
2230 insertAssignInstr(Call->ReturnRegister, nullptr, NewType, GR, MIRBuilder,
2231 MIRBuilder.getMF().getRegInfo());
2232 return Res;
2233 }
2234 case SPIRV::OpGroupWaitEvents:
2235 return MIRBuilder.buildInstr(Opcode)
2236 .addUse(Scope)
2237 .addUse(Call->Arguments[0])
2238 .addUse(Call->Arguments[1]);
2239 default:
2240 return false;
2241 }
2242 }
2243
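// For example, "uchar convert_uchar_sat(float x)" is expected to lower to an
// OpConvertFToU whose result is decorated with SaturatedConversion; _rte/_rtz
// style suffixes additionally add an FPRoundingMode decoration.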
2244 static bool generateConvertInst(const StringRef DemangledCall,
2245 const SPIRV::IncomingCall *Call,
2246 MachineIRBuilder &MIRBuilder,
2247 SPIRVGlobalRegistry *GR) {
2248 // Lookup the conversion builtin in the TableGen records.
2249 const SPIRV::ConvertBuiltin *Builtin =
2250 SPIRV::lookupConvertBuiltin(Call->Builtin->Name, Call->Builtin->Set);
2251
2252 if (!Builtin && Call->isSpirvOp()) {
2253 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2254 unsigned Opcode =
2255 SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2256 return buildOpFromWrapper(MIRBuilder, Opcode, Call,
2257 GR->getSPIRVTypeID(Call->ReturnType));
2258 }
2259
2260 if (Builtin->IsSaturated)
2261 buildOpDecorate(Call->ReturnRegister, MIRBuilder,
2262 SPIRV::Decoration::SaturatedConversion, {});
2263 if (Builtin->IsRounded)
2264 buildOpDecorate(Call->ReturnRegister, MIRBuilder,
2265 SPIRV::Decoration::FPRoundingMode,
2266 {(unsigned)Builtin->RoundingMode});
2267
2268 std::string NeedExtMsg; // no errors if empty
2269 bool IsRightComponentsNumber = true; // check if input/output accepts vectors
2270 unsigned Opcode = SPIRV::OpNop;
2271 if (GR->isScalarOrVectorOfType(Call->Arguments[0], SPIRV::OpTypeInt)) {
2272 // Int -> ...
2273 if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) {
2274 // Int -> Int
2275 if (Builtin->IsSaturated)
2276 Opcode = Builtin->IsDestinationSigned ? SPIRV::OpSatConvertUToS
2277 : SPIRV::OpSatConvertSToU;
2278 else
2279 Opcode = Builtin->IsDestinationSigned ? SPIRV::OpUConvert
2280 : SPIRV::OpSConvert;
2281 } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
2282 SPIRV::OpTypeFloat)) {
2283 // Int -> Float
2284 if (Builtin->IsBfloat16) {
2285 const auto *ST = static_cast<const SPIRVSubtarget *>(
2286 &MIRBuilder.getMF().getSubtarget());
2287 if (!ST->canUseExtension(
2288 SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
2289 NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
2290 IsRightComponentsNumber =
2291 GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
2292 GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
2293 Opcode = SPIRV::OpConvertBF16ToFINTEL;
2294 } else {
2295 bool IsSourceSigned =
2296 DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
2297 Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
2298 }
2299 }
2300 } else if (GR->isScalarOrVectorOfType(Call->Arguments[0],
2301 SPIRV::OpTypeFloat)) {
2302 // Float -> ...
2303 if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) {
2304 // Float -> Int
2305 if (Builtin->IsBfloat16) {
2306 const auto *ST = static_cast<const SPIRVSubtarget *>(
2307 &MIRBuilder.getMF().getSubtarget());
2308 if (!ST->canUseExtension(
2309 SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
2310 NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
2311 IsRightComponentsNumber =
2312 GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
2313 GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
2314 Opcode = SPIRV::OpConvertFToBF16INTEL;
2315 } else {
2316 Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
2317 : SPIRV::OpConvertFToU;
2318 }
2319 } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
2320 SPIRV::OpTypeFloat)) {
2321 // Float -> Float
2322 Opcode = SPIRV::OpFConvert;
2323 }
2324 }
2325
2326 if (!NeedExtMsg.empty()) {
2327 std::string DiagMsg = std::string(Builtin->Name) +
2328 ": the builtin requires the following SPIR-V "
2329 "extension: " +
2330 NeedExtMsg;
2331 report_fatal_error(DiagMsg.c_str(), false);
2332 }
2333 if (!IsRightComponentsNumber) {
2334 std::string DiagMsg =
2335 std::string(Builtin->Name) +
2336 ": result and argument must have the same number of components";
2337 report_fatal_error(DiagMsg.c_str(), false);
2338 }
2339 assert(Opcode != SPIRV::OpNop &&
2340 "Conversion between the types not implemented!");
2341
2342 MIRBuilder.buildInstr(Opcode)
2343 .addDef(Call->ReturnRegister)
2344 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2345 .addUse(Call->Arguments[0]);
2346 return true;
2347 }
2348
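// For example, "float4 vload4(size_t offset, const float *p)" is expected to
// lower to an OpExtInst from the OpenCL extended instruction set carrying the
// vloadn instruction number and a trailing literal element count of 4.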
2349 static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
2350 MachineIRBuilder &MIRBuilder,
2351 SPIRVGlobalRegistry *GR) {
2352 // Lookup the vector load/store builtin in the TableGen records.
2353 const SPIRV::VectorLoadStoreBuiltin *Builtin =
2354 SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
2355 Call->Builtin->Set);
2356 // Build extended instruction.
2357 auto MIB =
2358 MIRBuilder.buildInstr(SPIRV::OpExtInst)
2359 .addDef(Call->ReturnRegister)
2360 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
2361 .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
2362 .addImm(Builtin->Number);
2363 for (auto Argument : Call->Arguments)
2364 MIB.addUse(Argument);
2365 if (Builtin->Name.contains("load") && Builtin->ElementCount > 1)
2366 MIB.addImm(Builtin->ElementCount);
2367
2368 // The rounding mode should be passed as the last argument in the MI for
2369 // builtins like "vstorea_halfn_r".
2370 if (Builtin->IsRounded)
2371 MIB.addImm(static_cast<uint32_t>(Builtin->RoundingMode));
2372 return true;
2373 }
2374
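// Handles the __spirv_Load/__spirv_Store wrappers; any trailing constant
// arguments (memory-access flags and alignment) are appended to the OpLoad or
// OpStore as immediate operands.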
2375 static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
2376 MachineIRBuilder &MIRBuilder,
2377 SPIRVGlobalRegistry *GR) {
2378 // Lookup the instruction opcode in the TableGen records.
2379 const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
2380 unsigned Opcode =
2381 SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
2382 bool IsLoad = Opcode == SPIRV::OpLoad;
2383 // Build the instruction.
2384 auto MIB = MIRBuilder.buildInstr(Opcode);
2385 if (IsLoad) {
2386 MIB.addDef(Call->ReturnRegister);
2387 MIB.addUse(GR->getSPIRVTypeID(Call->ReturnType));
2388 }
2389 // Add a pointer to the value to load/store.
2390 MIB.addUse(Call->Arguments[0]);
2391 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
2392 MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
2393 // Add a value to store.
2394 if (!IsLoad) {
2395 MIB.addUse(Call->Arguments[1]);
2396 MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
2397 }
2398 // Add optional memory attributes and an alignment.
2399 unsigned NumArgs = Call->Arguments.size();
2400 if ((IsLoad && NumArgs >= 2) || NumArgs >= 3) {
2401 MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 1 : 2], MRI));
2402 MRI->setRegClass(Call->Arguments[IsLoad ? 1 : 2], &SPIRV::IDRegClass);
2403 }
2404 if ((IsLoad && NumArgs >= 3) || NumArgs >= 4) {
2405 MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 2 : 3], MRI));
2406 MRI->setRegClass(Call->Arguments[IsLoad ? 2 : 3], &SPIRV::IDRegClass);
2407 }
2408 return true;
2409 }
2410
2411 namespace SPIRV {
2412 // Try to find a builtin function's attributes by its demangled name and
2413 // return a tuple <builtin group, op code, ext instruction number>, or a special
2414 // tuple value <-1, 0, 0> if the builtin function is not found.
2415 // Not all builtin functions are supported, only those with a ready-to-use op
2416 // code or instruction number defined in TableGen.
2417 // TODO: consider a major rework of mapping demangled calls into builtin
2418 // functions to unify the search and decrease the number of individual cases.
2419 std::tuple<int, unsigned, unsigned>
2420 mapBuiltinToOpcode(const StringRef DemangledCall,
2421 SPIRV::InstructionSet::InstructionSet Set) {
2422 Register Reg;
2423 SmallVector<Register> Args;
2424 std::unique_ptr<const IncomingCall> Call =
2425 lookupBuiltin(DemangledCall, Set, Reg, nullptr, Args);
2426 if (!Call)
2427 return std::make_tuple(-1, 0, 0);
2428
2429 switch (Call->Builtin->Group) {
2430 case SPIRV::Relational:
2431 case SPIRV::Atomic:
2432 case SPIRV::Barrier:
2433 case SPIRV::CastToPtr:
2434 case SPIRV::ImageMiscQuery:
2435 case SPIRV::SpecConstant:
2436 case SPIRV::Enqueue:
2437 case SPIRV::AsyncCopy:
2438 case SPIRV::LoadStore:
2439 case SPIRV::CoopMatr:
2440 if (const auto *R =
2441 SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
2442 return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2443 break;
2444 case SPIRV::Extended:
2445 if (const auto *R = SPIRV::lookupExtendedBuiltin(Call->Builtin->Name,
2446 Call->Builtin->Set))
2447 return std::make_tuple(Call->Builtin->Group, 0, R->Number);
2448 break;
2449 case SPIRV::VectorLoadStore:
2450 if (const auto *R = SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
2451 Call->Builtin->Set))
2452 return std::make_tuple(SPIRV::Extended, 0, R->Number);
2453 break;
2454 case SPIRV::Group:
2455 if (const auto *R = SPIRV::lookupGroupBuiltin(Call->Builtin->Name))
2456 return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2457 break;
2458 case SPIRV::AtomicFloating:
2459 if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Call->Builtin->Name))
2460 return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2461 break;
2462 case SPIRV::IntelSubgroups:
2463 if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Call->Builtin->Name))
2464 return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2465 break;
2466 case SPIRV::GroupUniform:
2467 if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Call->Builtin->Name))
2468 return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
2469 break;
2470 case SPIRV::WriteImage:
2471 return std::make_tuple(Call->Builtin->Group, SPIRV::OpImageWrite, 0);
2472 case SPIRV::Select:
2473 return std::make_tuple(Call->Builtin->Group, TargetOpcode::G_SELECT, 0);
2474 case SPIRV::Construct:
2475 return std::make_tuple(Call->Builtin->Group, SPIRV::OpCompositeConstruct,
2476 0);
2477 case SPIRV::KernelClock:
2478 return std::make_tuple(Call->Builtin->Group, SPIRV::OpReadClockKHR, 0);
2479 default:
2480 return std::make_tuple(-1, 0, 0);
2481 }
2482 return std::make_tuple(-1, 0, 0);
2483 }
2484
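// Entry point for lowering a single demangled builtin call. Returns
// std::nullopt when no builtin record matches the demangled name, otherwise
// true/false depending on whether the matched builtin was lowered successfully.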
2485 std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
2486 SPIRV::InstructionSet::InstructionSet Set,
2487 MachineIRBuilder &MIRBuilder,
2488 const Register OrigRet, const Type *OrigRetTy,
2489 const SmallVectorImpl<Register> &Args,
2490 SPIRVGlobalRegistry *GR) {
2491 LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
2492
2493 // SPIR-V type and return register.
2494 Register ReturnRegister = OrigRet;
2495 SPIRVType *ReturnType = nullptr;
2496 if (OrigRetTy && !OrigRetTy->isVoidTy()) {
2497 ReturnType = GR->assignTypeToVReg(OrigRetTy, OrigRet, MIRBuilder);
2498 if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
2499 MIRBuilder.getMRI()->setRegClass(ReturnRegister, &SPIRV::IDRegClass);
2500 } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
2501 ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
2502 MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(32));
2503 ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
2504 }
2505
2506 // Lookup the builtin in the TableGen records.
2507 std::unique_ptr<const IncomingCall> Call =
2508 lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
2509
2510 if (!Call) {
2511 LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");
2512 return std::nullopt;
2513 }
2514
2515 // TODO: check if the provided args meet the builtin requirements.
2516 assert(Args.size() >= Call->Builtin->MinNumArgs &&
2517 "Too few arguments to generate the builtin");
2518 if (Call->Builtin->MaxNumArgs && Args.size() > Call->Builtin->MaxNumArgs)
2519 LLVM_DEBUG(dbgs() << "More arguments provided than required!\n");
2520
2521 // Match the builtin with implementation based on the grouping.
2522 switch (Call->Builtin->Group) {
2523 case SPIRV::Extended:
2524 return generateExtInst(Call.get(), MIRBuilder, GR);
2525 case SPIRV::Relational:
2526 return generateRelationalInst(Call.get(), MIRBuilder, GR);
2527 case SPIRV::Group:
2528 return generateGroupInst(Call.get(), MIRBuilder, GR);
2529 case SPIRV::Variable:
2530 return generateBuiltinVar(Call.get(), MIRBuilder, GR);
2531 case SPIRV::Atomic:
2532 return generateAtomicInst(Call.get(), MIRBuilder, GR);
2533 case SPIRV::AtomicFloating:
2534 return generateAtomicFloatingInst(Call.get(), MIRBuilder, GR);
2535 case SPIRV::Barrier:
2536 return generateBarrierInst(Call.get(), MIRBuilder, GR);
2537 case SPIRV::CastToPtr:
2538 return generateCastToPtrInst(Call.get(), MIRBuilder);
2539 case SPIRV::Dot:
2540 return generateDotOrFMulInst(Call.get(), MIRBuilder, GR);
2541 case SPIRV::Wave:
2542 return generateWaveInst(Call.get(), MIRBuilder, GR);
2543 case SPIRV::GetQuery:
2544 return generateGetQueryInst(Call.get(), MIRBuilder, GR);
2545 case SPIRV::ImageSizeQuery:
2546 return generateImageSizeQueryInst(Call.get(), MIRBuilder, GR);
2547 case SPIRV::ImageMiscQuery:
2548 return generateImageMiscQueryInst(Call.get(), MIRBuilder, GR);
2549 case SPIRV::ReadImage:
2550 return generateReadImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
2551 case SPIRV::WriteImage:
2552 return generateWriteImageInst(Call.get(), MIRBuilder, GR);
2553 case SPIRV::SampleImage:
2554 return generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
2555 case SPIRV::Select:
2556 return generateSelectInst(Call.get(), MIRBuilder);
2557 case SPIRV::Construct:
2558 return generateConstructInst(Call.get(), MIRBuilder, GR);
2559 case SPIRV::SpecConstant:
2560 return generateSpecConstantInst(Call.get(), MIRBuilder, GR);
2561 case SPIRV::Enqueue:
2562 return generateEnqueueInst(Call.get(), MIRBuilder, GR);
2563 case SPIRV::AsyncCopy:
2564 return generateAsyncCopy(Call.get(), MIRBuilder, GR);
2565 case SPIRV::Convert:
2566 return generateConvertInst(DemangledCall, Call.get(), MIRBuilder, GR);
2567 case SPIRV::VectorLoadStore:
2568 return generateVectorLoadStoreInst(Call.get(), MIRBuilder, GR);
2569 case SPIRV::LoadStore:
2570 return generateLoadStoreInst(Call.get(), MIRBuilder, GR);
2571 case SPIRV::IntelSubgroups:
2572 return generateIntelSubgroupsInst(Call.get(), MIRBuilder, GR);
2573 case SPIRV::GroupUniform:
2574 return generateGroupUniformInst(Call.get(), MIRBuilder, GR);
2575 case SPIRV::KernelClock:
2576 return generateKernelClockInst(Call.get(), MIRBuilder, GR);
2577 case SPIRV::CoopMatr:
2578 return generateCoopMatrInst(Call.get(), MIRBuilder, GR);
2579 }
2580 return false;
2581 }
2582
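// As a rough sketch of the expected behaviour (the exact spelling depends on
// the demangler), given the demangled call
//   "read_imagef(ocl_image2d_ro, ocl_sampler, float2)"
// ArgIdx 0 yields the "opencl.image2d_ro_t" target extension type and
// ArgIdx 2 yields a <2 x float> vector type.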
2583 Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
2584 unsigned ArgIdx, LLVMContext &Ctx) {
2585 SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
2586 StringRef BuiltinArgs =
2587 DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
2588 BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
2589 if (ArgIdx >= BuiltinArgsTypeStrs.size())
2590 return nullptr;
2591 StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
2592
2593 // Parse strings representing OpenCL builtin types.
2594 if (hasBuiltinTypePrefix(TypeStr)) {
2595 // OpenCL builtin types in demangled call strings have the following format:
2596 // e.g. ocl_image2d_ro
2597 [[maybe_unused]] bool IsOCLBuiltinType = TypeStr.consume_front("ocl_");
2598 assert(IsOCLBuiltinType && "Invalid OpenCL builtin prefix");
2599
2600 // Check if this is a pointer to a builtin type and not just a pointer
2601 // representing a builtin type. In case it is a pointer to a builtin type,
2602 // this will require additional handling in the method calling
2603 // parseBuiltinCallArgumentBaseType(...) as this function only retrieves the
2604 // base types.
2605 if (TypeStr.ends_with("*"))
2606 TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" *"));
2607
2608 return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
2609 Ctx);
2610 }
2611
2612 // Parse type name in either "typeN" or "type vector[N]" format, where
2613 // N is the number of elements of the vector.
2614 Type *BaseType;
2615 unsigned VecElts = 0;
2616
2617 BaseType = parseBasicTypeName(TypeStr, Ctx);
2618 if (!BaseType)
2619 // Unable to recognize SPIRV type name.
2620 return nullptr;
2621
2622 // Handle "typeN*" or "type vector[N]*".
2623 TypeStr.consume_back("*");
2624
2625 if (TypeStr.consume_front(" vector["))
2626 TypeStr = TypeStr.substr(0, TypeStr.find(']'));
2627
2628 TypeStr.getAsInteger(10, VecElts);
2629 if (VecElts > 0)
2630 BaseType = VectorType::get(
2631 BaseType->isVoidTy() ? Type::getInt8Ty(Ctx) : BaseType, VecElts, false);
2632
2633 return BaseType;
2634 }
2635
2636 struct BuiltinType {
2637 StringRef Name;
2638 uint32_t Opcode;
2639 };
2640
2641 #define GET_BuiltinTypes_DECL
2642 #define GET_BuiltinTypes_IMPL
2643
2644 struct OpenCLType {
2645 StringRef Name;
2646 StringRef SpirvTypeLiteral;
2647 };
2648
2649 #define GET_OpenCLTypes_DECL
2650 #define GET_OpenCLTypes_IMPL
2651
2652 #include "SPIRVGenTables.inc"
2653 } // namespace SPIRV
2654
2655 //===----------------------------------------------------------------------===//
2656 // Misc functions for parsing builtin types.
2657 //===----------------------------------------------------------------------===//
2658
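// Helper mapping the textual type parameter of a parameterized builtin name
// (e.g. the "void" in "spirv.Image._void_...") to the corresponding LLVM
// type; only the small set of spellings below is recognized.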
2659 static Type *parseTypeString(const StringRef Name, LLVMContext &Context) {
2660 if (Name.starts_with("void"))
2661 return Type::getVoidTy(Context);
2662 else if (Name.starts_with("int") || Name.starts_with("uint"))
2663 return Type::getInt32Ty(Context);
2664 else if (Name.starts_with("float"))
2665 return Type::getFloatTy(Context);
2666 else if (Name.starts_with("half"))
2667 return Type::getHalfTy(Context);
2668 report_fatal_error("Unable to recognize type!");
2669 }
2670
2671 //===----------------------------------------------------------------------===//
2672 // Implementation functions for builtin types.
2673 //===----------------------------------------------------------------------===//
2674
2675 static SPIRVType *getNonParameterizedType(const TargetExtType *ExtensionType,
2676 const SPIRV::BuiltinType *TypeRecord,
2677 MachineIRBuilder &MIRBuilder,
2678 SPIRVGlobalRegistry *GR) {
2679 unsigned Opcode = TypeRecord->Opcode;
2680 // Create or get an existing type from GlobalRegistry.
2681 return GR->getOrCreateOpTypeByOpcode(ExtensionType, MIRBuilder, Opcode);
2682 }
2683
2684 static SPIRVType *getSamplerType(MachineIRBuilder &MIRBuilder,
2685 SPIRVGlobalRegistry *GR) {
2686 // Create or get an existing type from GlobalRegistry.
2687 return GR->getOrCreateOpTypeSampler(MIRBuilder);
2688 }
2689
2690 static SPIRVType *getPipeType(const TargetExtType *ExtensionType,
2691 MachineIRBuilder &MIRBuilder,
2692 SPIRVGlobalRegistry *GR) {
2693 assert(ExtensionType->getNumIntParameters() == 1 &&
2694 "Invalid number of parameters for SPIR-V pipe builtin!");
2695 // Create or get an existing type from GlobalRegistry.
2696 return GR->getOrCreateOpTypePipe(MIRBuilder,
2697 SPIRV::AccessQualifier::AccessQualifier(
2698 ExtensionType->getIntParameter(0)));
2699 }
2700
2701 static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
2702 MachineIRBuilder &MIRBuilder,
2703 SPIRVGlobalRegistry *GR) {
2704 assert(ExtensionType->getNumIntParameters() == 4 &&
2705 "Invalid number of parameters for SPIR-V coop matrices builtin!");
2706 assert(ExtensionType->getNumTypeParameters() == 1 &&
2707 "SPIR-V coop matrices builtin type must have a type parameter!");
2708 const SPIRVType *ElemType =
2709 GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
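  // The four integer parameters are forwarded in the order they appear on
  // the target extension type; following the cooperative-matrix operand
  // order, these are expected to be Scope, Rows, Columns and Use.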
2710 // Create or get an existing type from GlobalRegistry.
2711 return GR->getOrCreateOpTypeCoopMatr(
2712 MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0),
2713 ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
2714 ExtensionType->getIntParameter(3));
2715 }
2716
2717 static SPIRVType *
2718 getImageType(const TargetExtType *ExtensionType,
2719 const SPIRV::AccessQualifier::AccessQualifier Qualifier,
2720 MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) {
2721 assert(ExtensionType->getNumTypeParameters() == 1 &&
2722 "SPIR-V image builtin type must have sampled type parameter!");
2723 const SPIRVType *SampledType =
2724 GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
2725 assert(ExtensionType->getNumIntParameters() == 7 &&
2726 "Invalid number of parameters for SPIR-V image builtin!");
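  // The seven integer parameters are forwarded in the order they appear on
  // the target extension type; following the OpTypeImage operand order, they
  // are expected to encode Dim, Depth, Arrayed, MS, Sampled, Image Format
  // and Access Qualifier.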
2727 // Create or get an existing type from GlobalRegistry.
2728 return GR->getOrCreateOpTypeImage(
2729 MIRBuilder, SampledType,
2730 SPIRV::Dim::Dim(ExtensionType->getIntParameter(0)),
2731 ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
2732 ExtensionType->getIntParameter(3), ExtensionType->getIntParameter(4),
2733 SPIRV::ImageFormat::ImageFormat(ExtensionType->getIntParameter(5)),
2734 Qualifier == SPIRV::AccessQualifier::WriteOnly
2735 ? SPIRV::AccessQualifier::WriteOnly
2736 : SPIRV::AccessQualifier::AccessQualifier(
2737 ExtensionType->getIntParameter(6)));
2738 }
2739
2740 static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
2741 MachineIRBuilder &MIRBuilder,
2742 SPIRVGlobalRegistry *GR) {
2743 SPIRVType *OpaqueImageType = getImageType(
2744 OpaqueType, SPIRV::AccessQualifier::ReadOnly, MIRBuilder, GR);
2745 // Create or get an existing type from GlobalRegistry.
2746 return GR->getOrCreateOpTypeSampledImage(OpaqueImageType, MIRBuilder);
2747 }
2748
2749 namespace SPIRV {
2750 TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
2751 LLVMContext &Context) {
2752 StringRef NameWithParameters = TypeName;
2753
2754 // Pointers-to-opaque-structs representing OpenCL types are first translated
2755 // to equivalent SPIR-V types. OpenCL builtin type names should have the
2756 // following format: e.g. %opencl.event_t
2757 if (NameWithParameters.starts_with("opencl.")) {
2758 const SPIRV::OpenCLType *OCLTypeRecord =
2759 SPIRV::lookupOpenCLType(NameWithParameters);
2760 if (!OCLTypeRecord)
2761 report_fatal_error("Missing TableGen record for OpenCL type: " +
2762 NameWithParameters);
2763 NameWithParameters = OCLTypeRecord->SpirvTypeLiteral;
2764 // Continue with the SPIR-V builtin type...
2765 }
2766
2767   // Names of the opaque structs representing SPIR-V builtins without
2768   // parameters should have the following format: e.g. %spirv.Event
2769 assert(NameWithParameters.starts_with("spirv.") &&
2770 "Unknown builtin opaque type!");
2771
2772   // Parameterized SPIR-V builtin names follow this format:
2773 // e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
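  // For example, "spirv.Image._void_1_0_0_0_0_0_0" decomposes into the base
  // name "spirv.Image", the type parameter "void" and the integer parameters
  // 1, 0, 0, 0, 0, 0 and 0.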
2774 if (!NameWithParameters.contains('_'))
2775 return TargetExtType::get(Context, NameWithParameters);
2776
2777 SmallVector<StringRef> Parameters;
2778 unsigned BaseNameLength = NameWithParameters.find('_') - 1;
2779 SplitString(NameWithParameters.substr(BaseNameLength + 1), Parameters, "_");
2780
2781 SmallVector<Type *, 1> TypeParameters;
2782 bool HasTypeParameter = !isDigit(Parameters[0][0]);
2783 if (HasTypeParameter)
2784 TypeParameters.push_back(parseTypeString(Parameters[0], Context));
2785 SmallVector<unsigned> IntParameters;
2786 for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
2787 unsigned IntParameter = 0;
2788 bool ValidLiteral = !Parameters[i].getAsInteger(10, IntParameter);
2789 (void)ValidLiteral;
2790 assert(ValidLiteral &&
2791 "Invalid format of SPIR-V builtin parameter literal!");
2792 IntParameters.push_back(IntParameter);
2793 }
2794 return TargetExtType::get(Context,
2795 NameWithParameters.substr(0, BaseNameLength),
2796 TypeParameters, IntParameters);
2797 }
2798
2799 SPIRVType *lowerBuiltinType(const Type *OpaqueType,
2800 SPIRV::AccessQualifier::AccessQualifier AccessQual,
2801 MachineIRBuilder &MIRBuilder,
2802 SPIRVGlobalRegistry *GR) {
2803 // In LLVM IR, SPIR-V and OpenCL builtin types are represented as either
2804 // target(...) target extension types or pointers-to-opaque-structs. The
2805 // approach relying on structs is deprecated and works only in the non-opaque
2806 // pointer mode (-opaque-pointers=0).
2807 // In order to maintain compatibility with LLVM IR generated by older versions
2808 // of Clang and LLVM/SPIR-V Translator, the pointers-to-opaque-structs are
2809 // "translated" to target extension types. This translation is temporary and
2810   // will be removed in a future release of LLVM.
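  // For example, a sampler argument may reach this point either as a target
  // extension type (e.g. target("spirv.Sampler")) or, with typed pointers,
  // as a pointer to an opaque struct such as %opencl.sampler_t; both forms
  // are handled below.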
2811 const TargetExtType *BuiltinType = dyn_cast<TargetExtType>(OpaqueType);
2812 if (!BuiltinType)
2813 BuiltinType = parseBuiltinTypeNameToTargetExtType(
2814 OpaqueType->getStructName().str(), MIRBuilder.getContext());
2815
2816 unsigned NumStartingVRegs = MIRBuilder.getMRI()->getNumVirtRegs();
2817
2818 const StringRef Name = BuiltinType->getName();
2819 LLVM_DEBUG(dbgs() << "Lowering builtin type: " << Name << "\n");
2820
2821 // Lookup the demangled builtin type in the TableGen records.
2822 const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
2823 if (!TypeRecord)
2824 report_fatal_error("Missing TableGen record for builtin type: " + Name);
2825
2826 // "Lower" the BuiltinType into TargetType. The following get<...>Type methods
2827 // use the implementation details from TableGen records or TargetExtType
2828 // parameters to either create a new OpType<...> machine instruction or get an
2829 // existing equivalent SPIRVType from GlobalRegistry.
2830 SPIRVType *TargetType;
2831 switch (TypeRecord->Opcode) {
2832 case SPIRV::OpTypeImage:
2833 TargetType = getImageType(BuiltinType, AccessQual, MIRBuilder, GR);
2834 break;
2835 case SPIRV::OpTypePipe:
2836 TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
2837 break;
2838 case SPIRV::OpTypeDeviceEvent:
2839 TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
2840 break;
2841 case SPIRV::OpTypeSampler:
2842 TargetType = getSamplerType(MIRBuilder, GR);
2843 break;
2844 case SPIRV::OpTypeSampledImage:
2845 TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
2846 break;
2847 case SPIRV::OpTypeCooperativeMatrixKHR:
2848 TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
2849 break;
2850 default:
2851 TargetType =
2852 getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
2853 break;
2854 }
2855
2856 // Emit OpName instruction if a new OpType<...> instruction was added
2857 // (equivalent type was not found in GlobalRegistry).
2858 if (NumStartingVRegs < MIRBuilder.getMRI()->getNumVirtRegs())
2859 buildOpName(GR->getSPIRVTypeID(TargetType), Name, MIRBuilder);
2860
2861 return TargetType;
2862 }
2863 } // namespace SPIRV
2864 } // namespace llvm
2865