1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVUtils.h"
22 #include "llvm/ADT/APInt.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/IR/Intrinsics.h"
26 #include "llvm/IR/IntrinsicsSPIRV.h"
27 #include "llvm/IR/Type.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/MathExtras.h"
30 #include <cassert>
31 #include <functional>
32
33 using namespace llvm;
34
allowEmitFakeUse(const Value * Arg)35 static bool allowEmitFakeUse(const Value *Arg) {
36 if (isSpvIntrinsic(Arg))
37 return false;
38 if (isa<AtomicCmpXchgInst, InsertValueInst, UndefValue>(Arg))
39 return false;
40 if (const auto *LI = dyn_cast<LoadInst>(Arg))
41 if (LI->getType()->isAggregateType())
42 return false;
43 return true;
44 }
45
typeToAddressSpace(const Type * Ty)46 static unsigned typeToAddressSpace(const Type *Ty) {
47 if (auto PType = dyn_cast<TypedPointerType>(Ty))
48 return PType->getAddressSpace();
49 if (auto PType = dyn_cast<PointerType>(Ty))
50 return PType->getAddressSpace();
51 if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
52 ExtTy && isTypedPointerWrapper(ExtTy))
53 return ExtTy->getIntParameter(0);
54 reportFatalInternalError("Unable to convert LLVM type to SPIRVType");
55 }
56
57 static bool
storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC)58 storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC) {
59 switch (SC) {
60 case SPIRV::StorageClass::Uniform:
61 case SPIRV::StorageClass::PushConstant:
62 case SPIRV::StorageClass::StorageBuffer:
63 case SPIRV::StorageClass::PhysicalStorageBufferEXT:
64 return true;
65 case SPIRV::StorageClass::UniformConstant:
66 case SPIRV::StorageClass::Input:
67 case SPIRV::StorageClass::Output:
68 case SPIRV::StorageClass::Workgroup:
69 case SPIRV::StorageClass::CrossWorkgroup:
70 case SPIRV::StorageClass::Private:
71 case SPIRV::StorageClass::Function:
72 case SPIRV::StorageClass::Generic:
73 case SPIRV::StorageClass::AtomicCounter:
74 case SPIRV::StorageClass::Image:
75 case SPIRV::StorageClass::CallableDataNV:
76 case SPIRV::StorageClass::IncomingCallableDataNV:
77 case SPIRV::StorageClass::RayPayloadNV:
78 case SPIRV::StorageClass::HitAttributeNV:
79 case SPIRV::StorageClass::IncomingRayPayloadNV:
80 case SPIRV::StorageClass::ShaderRecordBufferNV:
81 case SPIRV::StorageClass::CodeSectionINTEL:
82 case SPIRV::StorageClass::DeviceOnlyINTEL:
83 case SPIRV::StorageClass::HostOnlyINTEL:
84 return false;
85 }
86 llvm_unreachable("Unknown SPIRV::StorageClass enum");
87 }
88
SPIRVGlobalRegistry(unsigned PointerSize)89 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
90 : PointerSize(PointerSize), Bound(0) {}
91
assignIntTypeToVReg(unsigned BitWidth,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)92 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
93 Register VReg,
94 MachineInstr &I,
95 const SPIRVInstrInfo &TII) {
96 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
97 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
98 return SpirvType;
99 }
100
101 SPIRVType *
assignFloatTypeToVReg(unsigned BitWidth,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)102 SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
103 MachineInstr &I,
104 const SPIRVInstrInfo &TII) {
105 SPIRVType *SpirvType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
106 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
107 return SpirvType;
108 }
109
assignVectTypeToVReg(SPIRVType * BaseType,unsigned NumElements,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)110 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
111 SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
112 const SPIRVInstrInfo &TII) {
113 SPIRVType *SpirvType =
114 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
115 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
116 return SpirvType;
117 }
118
assignTypeToVReg(const Type * Type,Register VReg,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)119 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
120 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
121 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
122 SPIRVType *SpirvType =
123 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
124 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
125 return SpirvType;
126 }
127
assignSPIRVTypeToVReg(SPIRVType * SpirvType,Register VReg,const MachineFunction & MF)128 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
129 Register VReg,
130 const MachineFunction &MF) {
131 VRegToTypeMap[&MF][VReg] = SpirvType;
132 }
133
createTypeVReg(MachineRegisterInfo & MRI)134 static Register createTypeVReg(MachineRegisterInfo &MRI) {
135 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(64));
136 MRI.setRegClass(Res, &SPIRV::TYPERegClass);
137 return Res;
138 }
139
createTypeVReg(MachineIRBuilder & MIRBuilder)140 inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
141 return createTypeVReg(MIRBuilder.getMF().getRegInfo());
142 }
143
getOpTypeBool(MachineIRBuilder & MIRBuilder)144 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
145 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
146 return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
147 .addDef(createTypeVReg(MIRBuilder));
148 });
149 }
150
adjustOpTypeIntWidth(unsigned Width) const151 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
152 if (Width > 64)
153 report_fatal_error("Unsupported integer width!");
154 const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
155 if (ST.canUseExtension(
156 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
157 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
158 return Width;
159 if (Width <= 8)
160 Width = 8;
161 else if (Width <= 16)
162 Width = 16;
163 else if (Width <= 32)
164 Width = 32;
165 else
166 Width = 64;
167 return Width;
168 }
169
getOpTypeInt(unsigned Width,MachineIRBuilder & MIRBuilder,bool IsSigned)170 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
171 MachineIRBuilder &MIRBuilder,
172 bool IsSigned) {
173 Width = adjustOpTypeIntWidth(Width);
174 const SPIRVSubtarget &ST =
175 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
176 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
177 if (Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
178 MIRBuilder.buildInstr(SPIRV::OpExtension)
179 .addImm(SPIRV::Extension::SPV_INTEL_int4);
180 MIRBuilder.buildInstr(SPIRV::OpCapability)
181 .addImm(SPIRV::Capability::Int4TypeINTEL);
182 } else if ((!isPowerOf2_32(Width) || Width < 8) &&
183 ST.canUseExtension(
184 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
185 MIRBuilder.buildInstr(SPIRV::OpExtension)
186 .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
187 MIRBuilder.buildInstr(SPIRV::OpCapability)
188 .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
189 }
190 return MIRBuilder.buildInstr(SPIRV::OpTypeInt)
191 .addDef(createTypeVReg(MIRBuilder))
192 .addImm(Width)
193 .addImm(IsSigned ? 1 : 0);
194 });
195 }
196
getOpTypeFloat(uint32_t Width,MachineIRBuilder & MIRBuilder)197 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
198 MachineIRBuilder &MIRBuilder) {
199 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
200 return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
201 .addDef(createTypeVReg(MIRBuilder))
202 .addImm(Width);
203 });
204 }
205
getOpTypeVoid(MachineIRBuilder & MIRBuilder)206 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
207 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
208 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
209 .addDef(createTypeVReg(MIRBuilder));
210 });
211 }
212
invalidateMachineInstr(MachineInstr * MI)213 void SPIRVGlobalRegistry::invalidateMachineInstr(MachineInstr *MI) {
214 // TODO:
215 // - review other data structure wrt. possible issues related to removal
216 // of a machine instruction during instruction selection.
217 const MachineFunction *MF = MI->getMF();
218 auto It = LastInsertedTypeMap.find(MF);
219 if (It == LastInsertedTypeMap.end())
220 return;
221 if (It->second == MI)
222 LastInsertedTypeMap.erase(MF);
223 // remove from the duplicate tracker to avoid incorrect reuse
224 erase(MI);
225 }
226
createOpType(MachineIRBuilder & MIRBuilder,std::function<MachineInstr * (MachineIRBuilder &)> Op)227 SPIRVType *SPIRVGlobalRegistry::createOpType(
228 MachineIRBuilder &MIRBuilder,
229 std::function<MachineInstr *(MachineIRBuilder &)> Op) {
230 auto oldInsertPoint = MIRBuilder.getInsertPt();
231 MachineBasicBlock *OldMBB = &MIRBuilder.getMBB();
232 MachineBasicBlock *NewMBB = &*MIRBuilder.getMF().begin();
233
234 auto LastInsertedType = LastInsertedTypeMap.find(CurMF);
235 if (LastInsertedType != LastInsertedTypeMap.end()) {
236 auto It = LastInsertedType->second->getIterator();
237 // It might happen that this instruction was removed from the first MBB,
238 // hence the Parent's check.
239 MachineBasicBlock::iterator InsertAt;
240 if (It->getParent() != NewMBB)
241 InsertAt = oldInsertPoint->getParent() == NewMBB
242 ? oldInsertPoint
243 : getInsertPtValidEnd(NewMBB);
244 else if (It->getNextNode())
245 InsertAt = It->getNextNode()->getIterator();
246 else
247 InsertAt = getInsertPtValidEnd(NewMBB);
248 MIRBuilder.setInsertPt(*NewMBB, InsertAt);
249 } else {
250 MIRBuilder.setInsertPt(*NewMBB, NewMBB->begin());
251 auto Result = LastInsertedTypeMap.try_emplace(CurMF, nullptr);
252 assert(Result.second);
253 LastInsertedType = Result.first;
254 }
255
256 MachineInstr *Type = Op(MIRBuilder);
257 // We expect all users of this function to insert definitions at the insertion
258 // point set above that is always the first MBB.
259 assert(Type->getParent() == NewMBB);
260 LastInsertedType->second = Type;
261
262 MIRBuilder.setInsertPt(*OldMBB, oldInsertPoint);
263 return Type;
264 }
265
getOpTypeVector(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder)266 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
267 SPIRVType *ElemType,
268 MachineIRBuilder &MIRBuilder) {
269 auto EleOpc = ElemType->getOpcode();
270 (void)EleOpc;
271 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
272 EleOpc == SPIRV::OpTypeBool) &&
273 "Invalid vector element type");
274
275 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
276 return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
277 .addDef(createTypeVReg(MIRBuilder))
278 .addUse(getSPIRVTypeID(ElemType))
279 .addImm(NumElems);
280 });
281 }
282
getOrCreateConstFP(APFloat Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)283 Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
284 SPIRVType *SpvType,
285 const SPIRVInstrInfo &TII,
286 bool ZeroAsNull) {
287 LLVMContext &Ctx = CurMF->getFunction().getContext();
288 auto *const CF = ConstantFP::get(Ctx, Val);
289 const MachineInstr *MI = findMI(CF, CurMF);
290 if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
291 MI->getOpcode() == SPIRV::OpConstantF))
292 return MI->getOperand(0).getReg();
293 return createConstFP(CF, I, SpvType, TII, ZeroAsNull);
294 }
295
createConstFP(const ConstantFP * CF,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)296 Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
297 MachineInstr &I, SPIRVType *SpvType,
298 const SPIRVInstrInfo &TII,
299 bool ZeroAsNull) {
300 unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
301 LLT LLTy = LLT::scalar(BitWidth);
302 Register Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
303 CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
304 assignFloatTypeToVReg(BitWidth, Res, I, TII);
305
306 MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
307 MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
308 SPIRVType *NewType =
309 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
310 MachineInstrBuilder MIB;
311 // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
312 if (CF->getValue().isPosZero() && ZeroAsNull) {
313 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
314 .addDef(Res)
315 .addUse(getSPIRVTypeID(SpvType));
316 } else {
317 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
318 .addDef(Res)
319 .addUse(getSPIRVTypeID(SpvType));
320 addNumImm(APInt(BitWidth,
321 CF->getValueAPF().bitcastToAPInt().getZExtValue()),
322 MIB);
323 }
324 const auto &ST = CurMF->getSubtarget();
325 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
326 *ST.getRegisterInfo(),
327 *ST.getRegBankInfo());
328 return MIB;
329 });
330 add(CF, NewType);
331 return Res;
332 }
333
getOrCreateConstInt(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)334 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
335 SPIRVType *SpvType,
336 const SPIRVInstrInfo &TII,
337 bool ZeroAsNull) {
338 const IntegerType *Ty = cast<IntegerType>(getTypeForSPIRVType(SpvType));
339 auto *const CI = ConstantInt::get(const_cast<IntegerType *>(Ty), Val);
340 const MachineInstr *MI = findMI(CI, CurMF);
341 if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
342 MI->getOpcode() == SPIRV::OpConstantI))
343 return MI->getOperand(0).getReg();
344 return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
345 }
346
createConstInt(const ConstantInt * CI,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)347 Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
348 MachineInstr &I,
349 SPIRVType *SpvType,
350 const SPIRVInstrInfo &TII,
351 bool ZeroAsNull) {
352 unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
353 LLT LLTy = LLT::scalar(BitWidth);
354 Register Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
355 CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
356 assignIntTypeToVReg(BitWidth, Res, I, TII);
357
358 MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
359 MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
360 SPIRVType *NewType =
361 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
362 MachineInstrBuilder MIB;
363 if (BitWidth == 1) {
364 MIB = MIRBuilder
365 .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse
366 : SPIRV::OpConstantTrue)
367 .addDef(Res)
368 .addUse(getSPIRVTypeID(SpvType));
369 } else if (!CI->isZero() || !ZeroAsNull) {
370 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
371 .addDef(Res)
372 .addUse(getSPIRVTypeID(SpvType));
373 addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
374 } else {
375 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
376 .addDef(Res)
377 .addUse(getSPIRVTypeID(SpvType));
378 }
379 const auto &ST = CurMF->getSubtarget();
380 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
381 *ST.getRegisterInfo(),
382 *ST.getRegBankInfo());
383 return MIB;
384 });
385 add(CI, NewType);
386 return Res;
387 }
388
buildConstantInt(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR,bool ZeroAsNull)389 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
390 MachineIRBuilder &MIRBuilder,
391 SPIRVType *SpvType, bool EmitIR,
392 bool ZeroAsNull) {
393 assert(SpvType);
394 auto &MF = MIRBuilder.getMF();
395 const IntegerType *Ty = cast<IntegerType>(getTypeForSPIRVType(SpvType));
396 auto *const CI = ConstantInt::get(const_cast<IntegerType *>(Ty), Val);
397 Register Res = find(CI, &MF);
398 if (Res.isValid())
399 return Res;
400
401 unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
402 LLT LLTy = LLT::scalar(BitWidth);
403 MachineRegisterInfo &MRI = MF.getRegInfo();
404 Res = MRI.createGenericVirtualRegister(LLTy);
405 MRI.setRegClass(Res, &SPIRV::iIDRegClass);
406 assignTypeToVReg(Ty, Res, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
407 EmitIR);
408
409 SPIRVType *NewType =
410 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
411 if (EmitIR)
412 return MIRBuilder.buildConstant(Res, *CI);
413 Register SpvTypeReg = getSPIRVTypeID(SpvType);
414 MachineInstrBuilder MIB;
415 if (Val || !ZeroAsNull) {
416 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
417 .addDef(Res)
418 .addUse(SpvTypeReg);
419 addNumImm(APInt(BitWidth, Val), MIB);
420 } else {
421 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
422 .addDef(Res)
423 .addUse(SpvTypeReg);
424 }
425 const auto &Subtarget = CurMF->getSubtarget();
426 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
427 *Subtarget.getRegisterInfo(),
428 *Subtarget.getRegBankInfo());
429 return MIB;
430 });
431 add(CI, NewType);
432 return Res;
433 }
434
buildConstantFP(APFloat Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)435 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
436 MachineIRBuilder &MIRBuilder,
437 SPIRVType *SpvType) {
438 auto &MF = MIRBuilder.getMF();
439 LLVMContext &Ctx = MF.getFunction().getContext();
440 if (!SpvType)
441 SpvType = getOrCreateSPIRVType(Type::getFloatTy(Ctx), MIRBuilder,
442 SPIRV::AccessQualifier::ReadWrite, true);
443 auto *const CF = ConstantFP::get(Ctx, Val);
444 Register Res = find(CF, &MF);
445 if (Res.isValid())
446 return Res;
447
448 LLT LLTy = LLT::scalar(getScalarOrVectorBitWidth(SpvType));
449 Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
450 MF.getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
451 assignSPIRVTypeToVReg(SpvType, Res, MF);
452
453 SPIRVType *NewType =
454 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
455 MachineInstrBuilder MIB;
456 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
457 .addDef(Res)
458 .addUse(getSPIRVTypeID(SpvType));
459 addNumImm(CF->getValueAPF().bitcastToAPInt(), MIB);
460 return MIB;
461 });
462 add(CF, NewType);
463 return Res;
464 }
465
getOrCreateBaseRegister(Constant * Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,unsigned BitWidth,bool ZeroAsNull)466 Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
467 Constant *Val, MachineInstr &I, SPIRVType *SpvType,
468 const SPIRVInstrInfo &TII, unsigned BitWidth, bool ZeroAsNull) {
469 SPIRVType *Type = SpvType;
470 if (SpvType->getOpcode() == SPIRV::OpTypeVector ||
471 SpvType->getOpcode() == SPIRV::OpTypeArray) {
472 auto EleTypeReg = SpvType->getOperand(1).getReg();
473 Type = getSPIRVTypeForVReg(EleTypeReg);
474 }
475 if (Type->getOpcode() == SPIRV::OpTypeFloat) {
476 SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
477 return getOrCreateConstFP(dyn_cast<ConstantFP>(Val)->getValue(), I,
478 SpvBaseType, TII, ZeroAsNull);
479 }
480 assert(Type->getOpcode() == SPIRV::OpTypeInt);
481 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
482 return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I,
483 SpvBaseType, TII, ZeroAsNull);
484 }
485
getOrCreateCompositeOrNull(Constant * Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,Constant * CA,unsigned BitWidth,unsigned ElemCnt,bool ZeroAsNull)486 Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
487 Constant *Val, MachineInstr &I, SPIRVType *SpvType,
488 const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
489 unsigned ElemCnt, bool ZeroAsNull) {
490 if (Register R = find(CA, CurMF); R.isValid())
491 return R;
492
493 bool IsNull = Val->isNullValue() && ZeroAsNull;
494 Register ElemReg;
495 if (!IsNull)
496 ElemReg =
497 getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth, ZeroAsNull);
498
499 LLT LLTy = LLT::scalar(64);
500 Register Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
501 CurMF->getRegInfo().setRegClass(Res, getRegClass(SpvType));
502 assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
503
504 MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
505 MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
506 const MachineInstr *NewMI =
507 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
508 MachineInstrBuilder MIB;
509 if (!IsNull) {
510 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
511 .addDef(Res)
512 .addUse(getSPIRVTypeID(SpvType));
513 for (unsigned i = 0; i < ElemCnt; ++i)
514 MIB.addUse(ElemReg);
515 } else {
516 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
517 .addDef(Res)
518 .addUse(getSPIRVTypeID(SpvType));
519 }
520 const auto &Subtarget = CurMF->getSubtarget();
521 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
522 *Subtarget.getRegisterInfo(),
523 *Subtarget.getRegBankInfo());
524 return MIB;
525 });
526 add(CA, NewMI);
527 return Res;
528 }
529
getOrCreateConstVector(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)530 Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
531 MachineInstr &I,
532 SPIRVType *SpvType,
533 const SPIRVInstrInfo &TII,
534 bool ZeroAsNull) {
535 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
536 assert(LLVMTy->isVectorTy());
537 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
538 Type *LLVMBaseTy = LLVMVecTy->getElementType();
539 assert(LLVMBaseTy->isIntegerTy());
540 auto *ConstVal = ConstantInt::get(LLVMBaseTy, Val);
541 auto *ConstVec =
542 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
543 unsigned BW = getScalarOrVectorBitWidth(SpvType);
544 return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
545 SpvType->getOperand(2).getImm(),
546 ZeroAsNull);
547 }
548
getOrCreateConstVector(APFloat Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)549 Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
550 MachineInstr &I,
551 SPIRVType *SpvType,
552 const SPIRVInstrInfo &TII,
553 bool ZeroAsNull) {
554 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
555 assert(LLVMTy->isVectorTy());
556 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
557 Type *LLVMBaseTy = LLVMVecTy->getElementType();
558 assert(LLVMBaseTy->isFloatingPointTy());
559 auto *ConstVal = ConstantFP::get(LLVMBaseTy, Val);
560 auto *ConstVec =
561 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
562 unsigned BW = getScalarOrVectorBitWidth(SpvType);
563 return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
564 SpvType->getOperand(2).getImm(),
565 ZeroAsNull);
566 }
567
getOrCreateConstIntArray(uint64_t Val,size_t Num,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)568 Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
569 uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
570 const SPIRVInstrInfo &TII) {
571 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
572 assert(LLVMTy->isArrayTy());
573 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
574 Type *LLVMBaseTy = LLVMArrTy->getElementType();
575 Constant *CI = ConstantInt::get(LLVMBaseTy, Val);
576 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
577 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
578 // The following is reasonably unique key that is better that [Val]. The naive
579 // alternative would be something along the lines of:
580 // SmallVector<Constant *> NumCI(Num, CI);
581 // Constant *UniqueKey =
582 // ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
583 // that would be a truly unique but dangerous key, because it could lead to
584 // the creation of constants of arbitrary length (that is, the parameter of
585 // memset) which were missing in the original module.
586 Constant *UniqueKey = ConstantStruct::getAnon(
587 {PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)),
588 ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)});
589 return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW,
590 LLVMArrTy->getNumElements());
591 }
592
getOrCreateIntCompositeOrNull(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR,Constant * CA,unsigned BitWidth,unsigned ElemCnt)593 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
594 uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
595 Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
596 if (Register R = find(CA, CurMF); R.isValid())
597 return R;
598
599 Register ElemReg;
600 if (Val || EmitIR) {
601 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
602 ElemReg = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
603 }
604 LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(64);
605 Register Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
606 CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
607 assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
608
609 const MachineInstr *NewMI =
610 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
611 if (EmitIR)
612 return MIRBuilder.buildSplatBuildVector(Res, ElemReg);
613
614 if (Val) {
615 auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
616 .addDef(Res)
617 .addUse(getSPIRVTypeID(SpvType));
618 for (unsigned i = 0; i < ElemCnt; ++i)
619 MIB.addUse(ElemReg);
620 return MIB;
621 }
622
623 return MIRBuilder.buildInstr(SPIRV::OpConstantNull)
624 .addDef(Res)
625 .addUse(getSPIRVTypeID(SpvType));
626 });
627 add(CA, NewMI);
628 return Res;
629 }
630
631 Register
getOrCreateConsIntVector(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)632 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
633 MachineIRBuilder &MIRBuilder,
634 SPIRVType *SpvType, bool EmitIR) {
635 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
636 assert(LLVMTy->isVectorTy());
637 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
638 Type *LLVMBaseTy = LLVMVecTy->getElementType();
639 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
640 auto ConstVec =
641 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
642 unsigned BW = getScalarOrVectorBitWidth(SpvType);
643 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
644 ConstVec, BW,
645 SpvType->getOperand(2).getImm());
646 }
647
648 Register
getOrCreateConstNullPtr(MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)649 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
650 SPIRVType *SpvType) {
651 const Type *Ty = getTypeForSPIRVType(SpvType);
652 unsigned AddressSpace = typeToAddressSpace(Ty);
653 Type *ElemTy = ::getPointeeType(Ty);
654 assert(ElemTy);
655 const Constant *CP = ConstantTargetNone::get(
656 dyn_cast<TargetExtType>(getTypedPointerWrapper(ElemTy, AddressSpace)));
657 Register Res = find(CP, CurMF);
658 if (Res.isValid())
659 return Res;
660
661 LLT LLTy = LLT::pointer(AddressSpace, PointerSize);
662 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
663 CurMF->getRegInfo().setRegClass(Res, &SPIRV::pIDRegClass);
664 assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
665
666 const MachineInstr *NewMI =
667 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
668 return MIRBuilder.buildInstr(SPIRV::OpConstantNull)
669 .addDef(Res)
670 .addUse(getSPIRVTypeID(SpvType));
671 });
672 add(CP, NewMI);
673 return Res;
674 }
675
676 Register
buildConstantSampler(Register ResReg,unsigned AddrMode,unsigned Param,unsigned FilerMode,MachineIRBuilder & MIRBuilder)677 SPIRVGlobalRegistry::buildConstantSampler(Register ResReg, unsigned AddrMode,
678 unsigned Param, unsigned FilerMode,
679 MachineIRBuilder &MIRBuilder) {
680 auto Sampler =
681 ResReg.isValid()
682 ? ResReg
683 : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
684 SPIRVType *TypeSampler = getOrCreateOpTypeSampler(MIRBuilder);
685 Register TypeSamplerReg = getSPIRVTypeID(TypeSampler);
686 // We cannot use createOpType() logic here, because of the
687 // GlobalISel/IRTranslator.cpp check for a tail call that expects that
688 // MIRBuilder.getInsertPt() has a previous instruction. If this constant is
689 // inserted as a result of "__translate_sampler_initializer()" this would
690 // break this IRTranslator assumption.
691 MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
692 .addDef(Sampler)
693 .addUse(TypeSamplerReg)
694 .addImm(AddrMode)
695 .addImm(Param)
696 .addImm(FilerMode);
697 return Sampler;
698 }
699
buildGlobalVariable(Register ResVReg,SPIRVType * BaseType,StringRef Name,const GlobalValue * GV,SPIRV::StorageClass::StorageClass Storage,const MachineInstr * Init,bool IsConst,bool HasLinkageTy,SPIRV::LinkageType::LinkageType LinkageType,MachineIRBuilder & MIRBuilder,bool IsInstSelector)700 Register SPIRVGlobalRegistry::buildGlobalVariable(
701 Register ResVReg, SPIRVType *BaseType, StringRef Name,
702 const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
703 const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
704 SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
705 bool IsInstSelector) {
706 const GlobalVariable *GVar = nullptr;
707 if (GV) {
708 GVar = cast<const GlobalVariable>(GV);
709 } else {
710 // If GV is not passed explicitly, use the name to find or construct
711 // the global variable.
712 Module *M = MIRBuilder.getMF().getFunction().getParent();
713 GVar = M->getGlobalVariable(Name);
714 if (GVar == nullptr) {
715 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
716 // Module takes ownership of the global var.
717 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
718 GlobalValue::ExternalLinkage, nullptr,
719 Twine(Name));
720 }
721 GV = GVar;
722 }
723
724 const MachineFunction *MF = &MIRBuilder.getMF();
725 Register Reg = find(GVar, MF);
726 if (Reg.isValid()) {
727 if (Reg != ResVReg)
728 MIRBuilder.buildCopy(ResVReg, Reg);
729 return ResVReg;
730 }
731
732 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
733 .addDef(ResVReg)
734 .addUse(getSPIRVTypeID(BaseType))
735 .addImm(static_cast<uint32_t>(Storage));
736 if (Init != 0)
737 MIB.addUse(Init->getOperand(0).getReg());
738 // ISel may introduce a new register on this step, so we need to add it to
739 // DT and correct its type avoiding fails on the next stage.
740 if (IsInstSelector) {
741 const auto &Subtarget = CurMF->getSubtarget();
742 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
743 *Subtarget.getRegisterInfo(),
744 *Subtarget.getRegBankInfo());
745 }
746 add(GVar, MIB);
747
748 Reg = MIB->getOperand(0).getReg();
749 addGlobalObject(GVar, MF, Reg);
750
751 // Set to Reg the same type as ResVReg has.
752 auto MRI = MIRBuilder.getMRI();
753 if (Reg != ResVReg) {
754 LLT RegLLTy =
755 LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize());
756 MRI->setType(Reg, RegLLTy);
757 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
758 } else {
759 // Our knowledge about the type may be updated.
760 // If that's the case, we need to update a type
761 // associated with the register.
762 SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg);
763 if (!DefType || DefType != BaseType)
764 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
765 }
766
767 // If it's a global variable with name, output OpName for it.
768 if (GVar && GVar->hasName())
769 buildOpName(Reg, GVar->getName(), MIRBuilder);
770
771 // Output decorations for the GV.
772 // TODO: maybe move to GenerateDecorations pass.
773 const SPIRVSubtarget &ST =
774 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
775 if (IsConst && !ST.isShader())
776 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
777
778 if (GVar && GVar->getAlign().valueOrOne().value() != 1 && !ST.isShader()) {
779 unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
780 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
781 }
782
783 if (HasLinkageTy)
784 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
785 {static_cast<uint32_t>(LinkageType)}, Name);
786
787 SPIRV::BuiltIn::BuiltIn BuiltInId;
788 if (getSpirvBuiltInIdByName(Name, BuiltInId))
789 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
790 {static_cast<uint32_t>(BuiltInId)});
791
792 // If it's a global variable with "spirv.Decorations" metadata node
793 // recognize it as a SPIR-V friendly LLVM IR and parse "spirv.Decorations"
794 // arguments.
795 MDNode *GVarMD = nullptr;
796 if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr)
797 buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD);
798
799 return Reg;
800 }
801
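// Returns a name based on the Type. Note that this does not look at
// decorations, and will return the same string for two types that are the same
// except for decorations.
// NOTE(review): this comment does not describe the function that follows
// (getOrCreateGlobalVariableWithBinding returns a Register, not a name); it
// appears to have been left behind by a type-naming helper.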
getOrCreateGlobalVariableWithBinding(const SPIRVType * VarType,uint32_t Set,uint32_t Binding,StringRef Name,MachineIRBuilder & MIRBuilder)805 Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding(
806 const SPIRVType *VarType, uint32_t Set, uint32_t Binding, StringRef Name,
807 MachineIRBuilder &MIRBuilder) {
808 Register VarReg =
809 MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
810
811 buildGlobalVariable(VarReg, VarType, Name, nullptr,
812 getPointerStorageClass(VarType), nullptr, false, false,
813 SPIRV::LinkageType::Import, MIRBuilder, false);
814
815 buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::DescriptorSet, {Set});
816 buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::Binding, {Binding});
817 return VarReg;
818 }
819
820 // TODO: Double check the calls to getOpTypeArray to make sure that `ElemType`
821 // is explicitly laid out when required.
getOpTypeArray(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,bool ExplicitLayoutRequired,bool EmitIR)822 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
823 SPIRVType *ElemType,
824 MachineIRBuilder &MIRBuilder,
825 bool ExplicitLayoutRequired,
826 bool EmitIR) {
827 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
828 "Invalid array element type");
829 SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
830 SPIRVType *ArrayType = nullptr;
831 if (NumElems != 0) {
832 Register NumElementsVReg =
833 buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR);
834 ArrayType = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
835 return MIRBuilder.buildInstr(SPIRV::OpTypeArray)
836 .addDef(createTypeVReg(MIRBuilder))
837 .addUse(getSPIRVTypeID(ElemType))
838 .addUse(NumElementsVReg);
839 });
840 } else {
841 ArrayType = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
842 return MIRBuilder.buildInstr(SPIRV::OpTypeRuntimeArray)
843 .addDef(createTypeVReg(MIRBuilder))
844 .addUse(getSPIRVTypeID(ElemType));
845 });
846 }
847
848 if (ExplicitLayoutRequired && !isResourceType(ElemType)) {
849 Type *ET = const_cast<Type *>(getTypeForSPIRVType(ElemType));
850 addArrayStrideDecorations(ArrayType->defs().begin()->getReg(), ET,
851 MIRBuilder);
852 }
853
854 return ArrayType;
855 }
856
getOpTypeOpaque(const StructType * Ty,MachineIRBuilder & MIRBuilder)857 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
858 MachineIRBuilder &MIRBuilder) {
859 assert(Ty->hasName());
860 const StringRef Name = Ty->hasName() ? Ty->getName() : "";
861 Register ResVReg = createTypeVReg(MIRBuilder);
862 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
863 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
864 addStringImm(Name, MIB);
865 buildOpName(ResVReg, Name, MIRBuilder);
866 return MIB;
867 });
868 }
869
getOpTypeStruct(const StructType * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,StructOffsetDecorator Decorator,bool EmitIR)870 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(
871 const StructType *Ty, MachineIRBuilder &MIRBuilder,
872 SPIRV::AccessQualifier::AccessQualifier AccQual,
873 StructOffsetDecorator Decorator, bool EmitIR) {
874 const SPIRVSubtarget &ST =
875 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
876 SmallVector<Register, 4> FieldTypes;
877 constexpr unsigned MaxWordCount = UINT16_MAX;
878 const size_t NumElements = Ty->getNumElements();
879
880 size_t MaxNumElements = MaxWordCount - 2;
881 size_t SPIRVStructNumElements = NumElements;
882 if (NumElements > MaxNumElements) {
883 // Do adjustments for continued instructions.
884 SPIRVStructNumElements = MaxNumElements;
885 MaxNumElements = MaxWordCount - 1;
886 }
887
888 for (const auto &Elem : Ty->elements()) {
889 SPIRVType *ElemTy = findSPIRVType(
890 toTypedPointer(Elem), MIRBuilder, AccQual,
891 /* ExplicitLayoutRequired= */ Decorator != nullptr, EmitIR);
892 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
893 "Invalid struct element type");
894 FieldTypes.push_back(getSPIRVTypeID(ElemTy));
895 }
896 Register ResVReg = createTypeVReg(MIRBuilder);
897 if (Ty->hasName())
898 buildOpName(ResVReg, Ty->getName(), MIRBuilder);
899 if (Ty->isPacked() && !ST.isShader())
900 buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
901
902 SPIRVType *SPVType =
903 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
904 auto MIBStruct =
905 MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
906 for (size_t I = 0; I < SPIRVStructNumElements; ++I)
907 MIBStruct.addUse(FieldTypes[I]);
908 for (size_t I = SPIRVStructNumElements; I < NumElements;
909 I += MaxNumElements) {
910 auto MIBCont =
911 MIRBuilder.buildInstr(SPIRV::OpTypeStructContinuedINTEL);
912 for (size_t J = I; J < std::min(I + MaxNumElements, NumElements); ++J)
913 MIBCont.addUse(FieldTypes[I]);
914 }
915 return MIBStruct;
916 });
917
918 if (Decorator)
919 Decorator(SPVType->defs().begin()->getReg());
920
921 return SPVType;
922 }
923
getOrCreateSpecialType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual)924 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
925 const Type *Ty, MachineIRBuilder &MIRBuilder,
926 SPIRV::AccessQualifier::AccessQualifier AccQual) {
927 assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
928 return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);
929 }
930
getOpTypePointer(SPIRV::StorageClass::StorageClass SC,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,Register Reg)931 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
932 SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
933 MachineIRBuilder &MIRBuilder, Register Reg) {
934 if (!Reg.isValid())
935 Reg = createTypeVReg(MIRBuilder);
936
937 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
938 return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
939 .addDef(Reg)
940 .addImm(static_cast<uint32_t>(SC))
941 .addUse(getSPIRVTypeID(ElemType));
942 });
943 }
944
getOpTypeForwardPointer(SPIRV::StorageClass::StorageClass SC,MachineIRBuilder & MIRBuilder)945 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
946 SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
947 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
948 return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
949 .addUse(createTypeVReg(MIRBuilder))
950 .addImm(static_cast<uint32_t>(SC));
951 });
952 }
953
getOpTypeFunction(SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)954 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
955 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
956 MachineIRBuilder &MIRBuilder) {
957 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
958 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
959 .addDef(createTypeVReg(MIRBuilder))
960 .addUse(getSPIRVTypeID(RetType));
961 for (const SPIRVType *ArgType : ArgTypes)
962 MIB.addUse(getSPIRVTypeID(ArgType));
963 return MIB;
964 });
965 }
966
getOrCreateOpTypeFunctionWithArgs(const Type * Ty,SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)967 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
968 const Type *Ty, SPIRVType *RetType,
969 const SmallVectorImpl<SPIRVType *> &ArgTypes,
970 MachineIRBuilder &MIRBuilder) {
971 if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF()))
972 return MI;
973 const MachineInstr *NewMI = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
974 add(Ty, false, NewMI);
975 return finishCreatingSPIRVType(Ty, NewMI);
976 }
977
findSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,bool ExplicitLayoutRequired,bool EmitIR)978 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
979 const Type *Ty, MachineIRBuilder &MIRBuilder,
980 SPIRV::AccessQualifier::AccessQualifier AccQual,
981 bool ExplicitLayoutRequired, bool EmitIR) {
982 Ty = adjustIntTypeByWidth(Ty);
983 // TODO: findMI needs to know if a layout is required.
984 if (const MachineInstr *MI =
985 findMI(Ty, ExplicitLayoutRequired, &MIRBuilder.getMF()))
986 return MI;
987 if (auto It = ForwardPointerTypes.find(Ty); It != ForwardPointerTypes.end())
988 return It->second;
989 return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, ExplicitLayoutRequired,
990 EmitIR);
991 }
992
getSPIRVTypeID(const SPIRVType * SpirvType) const993 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
994 assert(SpirvType && "Attempting to get type id for nullptr type.");
995 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer ||
996 SpirvType->getOpcode() == SPIRV::OpTypeStructContinuedINTEL)
997 return SpirvType->uses().begin()->getReg();
998 return SpirvType->defs().begin()->getReg();
999 }
1000
1001 // We need to use a new LLVM integer type if there is a mismatch between
1002 // number of bits in LLVM and SPIRV integer types to let DuplicateTracker
1003 // ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
1004 // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
1005 // same "OpTypeInt 8" type for a series of LLVM integer types with number of
1006 // bits less than 8. This would lead to duplicate type definitions
1007 // eventually due to the method that DuplicateTracker utilizes to reason
1008 // about uniqueness of type records.
const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
  const auto *IntTy = dyn_cast<IntegerType>(Ty);
  // Non-integers and i1 (lowered to OpTypeBool) are returned unchanged.
  if (!IntTy || IntTy->getBitWidth() <= 1)
    return Ty;
  const unsigned SrcWidth = IntTy->getBitWidth();
  const unsigned SpvWidth = adjustOpTypeIntWidth(SrcWidth);
  // Substitute an LLVM integer of the adjusted width so that there is one LLVM
  // type per SPIR-V integer type, keeping DuplicateTracker consistent.
  return SrcWidth == SpvWidth ? Ty
                              : IntegerType::get(Ty->getContext(), SpvWidth);
}
1021
/// Builds the SPIR-V type instruction for the LLVM type \p Ty.
///
/// This function is called from restOfCreateSPIRVType, which then records the
/// result in VRegToTypeMap / SPIRVToLLVMType. Dispatch order:
///   - special opaque builtin types  -> getOrCreateSpecialType
///   - already registered for (Ty, ExplicitLayoutRequired) -> existing type
///   - integers (i1 -> OpTypeBool), floats, void
///   - fixed vectors, arrays, structs (opaque or not), function types; member
///     and element types are resolved recursively with findSPIRVType
///   - everything else takes the pointer path at the end
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccQual,
    bool ExplicitLayoutRequired, bool EmitIR) {
  if (isSpecialOpaqueType(Ty))
    return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);

  if (const MachineInstr *MI =
          findMI(Ty, ExplicitLayoutRequired, &MIRBuilder.getMF()))
    return MI;

  if (auto IType = dyn_cast<IntegerType>(Ty)) {
    const unsigned Width = IType->getBitWidth();
    // i1 becomes OpTypeBool; other widths go to getOpTypeInt with `false` as
    // the third argument (presumably "not signed" -- confirm in the header).
    return Width == 1 ? getOpTypeBool(MIRBuilder)
                      : getOpTypeInt(Width, MIRBuilder, false);
  }
  if (Ty->isFloatingPointTy())
    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
  if (Ty->isVoidTy())
    return getOpTypeVoid(MIRBuilder);
  if (Ty->isVectorTy()) {
    // NOTE(review): cast<FixedVectorType> asserts on scalable vectors; this
    // path assumes only fixed-width vectors reach it.
    SPIRVType *El =
        findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder,
                      AccQual, ExplicitLayoutRequired, EmitIR);
    return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
                           MIRBuilder);
  }
  if (Ty->isArrayTy()) {
    SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder,
                                  AccQual, ExplicitLayoutRequired, EmitIR);
    return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder,
                          ExplicitLayoutRequired, EmitIR);
  }
  if (auto SType = dyn_cast<StructType>(Ty)) {
    if (SType->isOpaque())
      return getOpTypeOpaque(SType, MIRBuilder);

    // With an explicit layout, getOpTypeStruct invokes this decorator with the
    // new struct's id so that addStructOffsetDecorations can decorate it. A
    // non-null decorator also makes member types be requested with layout.
    StructOffsetDecorator Decorator = nullptr;
    if (ExplicitLayoutRequired) {
      Decorator = [&MIRBuilder, SType, this](Register Reg) {
        addStructOffsetDecorations(Reg, const_cast<StructType *>(SType),
                                   MIRBuilder);
      };
    }
    return getOpTypeStruct(SType, MIRBuilder, AccQual, Decorator, EmitIR);
  }
  if (auto FType = dyn_cast<FunctionType>(Ty)) {
    SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder,
                                     AccQual, ExplicitLayoutRequired, EmitIR);
    SmallVector<SPIRVType *, 4> ParamTypes;
    for (const auto &ParamTy : FType->params())
      ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual,
                                         ExplicitLayoutRequired, EmitIR));
    return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
  }

  // Pointer path: typed/untyped pointers and pointer wrappers.
  unsigned AddrSpace = typeToAddressSpace(Ty);
  // `::getPointeeType` is the file-scope helper taking an LLVM Type, not the
  // member getPointeeType(SPIRVType *). Without a pointee, i8 is used.
  // SpvElementType is only consumed on the forward-pointer path below, but it
  // is computed (and the element type possibly created) unconditionally.
  SPIRVType *SpvElementType = nullptr;
  if (Type *ElemTy = ::getPointeeType(Ty))
    SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR);
  else
    SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);

  // Get access to information about available extensions
  const SPIRVSubtarget *ST =
      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
  auto SC = addressSpaceToStorageClass(AddrSpace, *ST);

  Type *ElemTy = ::getPointeeType(Ty);
  if (!ElemTy) {
    ElemTy = Type::getInt8Ty(MIRBuilder.getContext());
  }

  // If we have forward pointer associated with this type, use its register
  // operand to create OpTypePointer.
  if (auto It = ForwardPointerTypes.find(Ty); It != ForwardPointerTypes.end()) {
    Register Reg = getSPIRVTypeID(It->second);
    // getOpTypePointer emits OpTypePointer using Reg (the id reserved by the
    // OpTypeForwardPointer) as its result id instead of a fresh register.
    return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
  }

  return getOrCreateSPIRVPointerType(ElemTy, MIRBuilder, SC);
}
1105
/// Creates the SPIR-V type for \p Ty via createSPIRVType and registers it.
///
/// The result is recorded in VRegToTypeMap (type id -> type, for the current
/// MachineFunction) and in SPIRVToLLVMType. Unless the type was already known,
/// it is a forward pointer, or it is a special opaque type, it is also added
/// to the duplicate tracker under a key chosen by the kind of \p Ty.
///
/// Returns nullptr if \p Ty (a non-pointer) is already being created further
/// up the recursion, as tracked by TypesInProcessing.
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual,
    bool ExplicitLayoutRequired, bool EmitIR) {
  // TODO: Could this create a problem if one requires an explicit layout, and
  // the next time it does not?
  // Recursion guard: re-entering a non-pointer type yields nullptr; pointers
  // (and pointer wrappers) are allowed to re-enter.
  if (TypesInProcessing.count(Ty) && !isPointerTyOrWrapper(Ty))
    return nullptr;
  TypesInProcessing.insert(Ty);
  SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
                                         ExplicitLayoutRequired, EmitIR);
  TypesInProcessing.erase(Ty);
  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;

  // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
  // Is that a problem?
  SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);

  // NOTE(review): the "already known" check uses the layout-free key (false)
  // even when ExplicitLayoutRequired is true, so a laid-out variant is not
  // registered if a layout-free one exists -- confirm this is intended.
  if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer ||
      findMI(Ty, false, &MIRBuilder.getMF()) || isSpecialOpaqueType(Ty))
    return SpirvType;

  // Registration keys (mirrored by the lookup in getOrCreateSPIRVType):
  //   pointer wrapper -> (type parameter 0, int parameter 0; presumably the
  //                       address space)
  //   non-pointer     -> (Ty, ExplicitLayoutRequired)
  //   typed pointer   -> (element type, address space)
  //   untyped pointer -> (i8, address space)
  if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
      ExtTy && isTypedPointerWrapper(ExtTy))
    add(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0), SpirvType);
  else if (!isPointerTy(Ty))
    add(Ty, ExplicitLayoutRequired, SpirvType);
  else if (isTypedPointerTy(Ty))
    add(cast<TypedPointerType>(Ty)->getElementType(),
        getPointerAddressSpace(Ty), SpirvType);
  else
    add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
        getPointerAddressSpace(Ty), SpirvType);
  return SpirvType;
}
1141
1142 SPIRVType *
getSPIRVTypeForVReg(Register VReg,const MachineFunction * MF) const1143 SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
1144 const MachineFunction *MF) const {
1145 auto t = VRegToTypeMap.find(MF ? MF : CurMF);
1146 if (t != VRegToTypeMap.end()) {
1147 auto tt = t->second.find(VReg);
1148 if (tt != t->second.end())
1149 return tt->second;
1150 }
1151 return nullptr;
1152 }
1153
getResultType(Register VReg,MachineFunction * MF)1154 SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg,
1155 MachineFunction *MF) {
1156 if (!MF)
1157 MF = CurMF;
1158 MachineInstr *Instr = getVRegDef(MF->getRegInfo(), VReg);
1159 return getSPIRVTypeForVReg(Instr->getOperand(1).getReg(), MF);
1160 }
1161
/// Returns the SPIR-V type for \p Ty, creating it if needed.
///
/// The lookup key mirrors the registration scheme in restOfCreateSPIRVType.
/// Special opaque types always go through creation, even when a register is
/// found. After creation, every pending OpTypeForwardPointer is resolved into
/// a real pointer type and ForwardPointerTypes is cleared.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual,
    bool ExplicitLayoutRequired, bool EmitIR) {
  const MachineFunction *MF = &MIRBuilder.getMF();
  Register Reg;
  if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
      ExtTy && isTypedPointerWrapper(ExtTy))
    Reg = find(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0), MF);
  else if (!isPointerTy(Ty))
    // Note: Ty itself is replaced by its width-adjusted form here, so the rest
    // of the function (including creation) uses the normalized type.
    Reg = find(Ty = adjustIntTypeByWidth(Ty), ExplicitLayoutRequired, MF);
  else if (isTypedPointerTy(Ty))
    Reg = find(cast<TypedPointerType>(Ty)->getElementType(),
               getPointerAddressSpace(Ty), MF);
  else
    Reg = find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
               getPointerAddressSpace(Ty), MF);
  // NOTE(review): this lookup uses the default function (CurMF) rather than
  // MF; confirm they always coincide here.
  if (Reg.isValid() && !isSpecialOpaqueType(Ty))
    return getSPIRVTypeForVReg(Reg);

  // Start a fresh recursion-guard set for this top-level creation.
  TypesInProcessing.clear();
  SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual,
                                         ExplicitLayoutRequired, EmitIR);
  // Create normal pointer types for the corresponding OpTypeForwardPointers.
  // If one of them is Ty itself, that resolved pointer type is returned
  // instead of the forward declaration.
  for (auto &CU : ForwardPointerTypes) {
    // Pointer types themselves do not require an explicit layout. The types
    // they point to might, but that is taken care of when creating the type.
    bool PtrNeedsLayout = false;
    const Type *Ty2 = CU.first;
    SPIRVType *STy2 = CU.second;
    if ((Reg = find(Ty2, PtrNeedsLayout, MF)).isValid())
      STy2 = getSPIRVTypeForVReg(Reg);
    else
      STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, PtrNeedsLayout,
                                   EmitIR);
    if (Ty == Ty2)
      STy = STy2;
  }
  ForwardPointerTypes.clear();
  return STy;
}
1203
isScalarOfType(Register VReg,unsigned TypeOpcode) const1204 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
1205 unsigned TypeOpcode) const {
1206 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1207 assert(Type && "isScalarOfType VReg has no type assigned");
1208 return Type->getOpcode() == TypeOpcode;
1209 }
1210
isScalarOrVectorOfType(Register VReg,unsigned TypeOpcode) const1211 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
1212 unsigned TypeOpcode) const {
1213 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1214 assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
1215 if (Type->getOpcode() == TypeOpcode)
1216 return true;
1217 if (Type->getOpcode() == SPIRV::OpTypeVector) {
1218 Register ScalarTypeVReg = Type->getOperand(1).getReg();
1219 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
1220 return ScalarType->getOpcode() == TypeOpcode;
1221 }
1222 return false;
1223 }
1224
isResourceType(SPIRVType * Type) const1225 bool SPIRVGlobalRegistry::isResourceType(SPIRVType *Type) const {
1226 switch (Type->getOpcode()) {
1227 case SPIRV::OpTypeImage:
1228 case SPIRV::OpTypeSampler:
1229 case SPIRV::OpTypeSampledImage:
1230 return true;
1231 case SPIRV::OpTypeStruct:
1232 return hasBlockDecoration(Type);
1233 default:
1234 return false;
1235 }
1236 return false;
1237 }
1238 unsigned
getScalarOrVectorComponentCount(Register VReg) const1239 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
1240 return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg));
1241 }
1242
1243 unsigned
getScalarOrVectorComponentCount(SPIRVType * Type) const1244 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
1245 if (!Type)
1246 return 0;
1247 return Type->getOpcode() == SPIRV::OpTypeVector
1248 ? static_cast<unsigned>(Type->getOperand(2).getImm())
1249 : 1;
1250 }
1251
1252 SPIRVType *
getScalarOrVectorComponentType(Register VReg) const1253 SPIRVGlobalRegistry::getScalarOrVectorComponentType(Register VReg) const {
1254 return getScalarOrVectorComponentType(getSPIRVTypeForVReg(VReg));
1255 }
1256
1257 SPIRVType *
getScalarOrVectorComponentType(SPIRVType * Type) const1258 SPIRVGlobalRegistry::getScalarOrVectorComponentType(SPIRVType *Type) const {
1259 if (!Type)
1260 return nullptr;
1261 Register ScalarReg = Type->getOpcode() == SPIRV::OpTypeVector
1262 ? Type->getOperand(1).getReg()
1263 : Type->getOperand(0).getReg();
1264 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarReg);
1265 assert(isScalarOrVectorOfType(Type->getOperand(0).getReg(),
1266 ScalarType->getOpcode()));
1267 return ScalarType;
1268 }
1269
1270 unsigned
getScalarOrVectorBitWidth(const SPIRVType * Type) const1271 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
1272 assert(Type && "Invalid Type pointer");
1273 if (Type->getOpcode() == SPIRV::OpTypeVector) {
1274 auto EleTypeReg = Type->getOperand(1).getReg();
1275 Type = getSPIRVTypeForVReg(EleTypeReg);
1276 }
1277 if (Type->getOpcode() == SPIRV::OpTypeInt ||
1278 Type->getOpcode() == SPIRV::OpTypeFloat)
1279 return Type->getOperand(1).getImm();
1280 if (Type->getOpcode() == SPIRV::OpTypeBool)
1281 return 1;
1282 llvm_unreachable("Attempting to get bit width of non-integer/float type.");
1283 }
1284
getNumScalarOrVectorTotalBitWidth(const SPIRVType * Type) const1285 unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
1286 const SPIRVType *Type) const {
1287 assert(Type && "Invalid Type pointer");
1288 unsigned NumElements = 1;
1289 if (Type->getOpcode() == SPIRV::OpTypeVector) {
1290 NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
1291 Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
1292 }
1293 return Type->getOpcode() == SPIRV::OpTypeInt ||
1294 Type->getOpcode() == SPIRV::OpTypeFloat
1295 ? NumElements * Type->getOperand(1).getImm()
1296 : 0;
1297 }
1298
/// Returns the OpTypeInt behind \p Type (the type itself or its vector
/// component). Returns nullptr if there is none or \p Type is null.
const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
    const SPIRVType *Type) const {
  if (!Type)
    return nullptr;
  const SPIRVType *Elem =
      Type->getOpcode() == SPIRV::OpTypeVector
          ? getSPIRVTypeForVReg(Type->getOperand(1).getReg())
          : Type;
  if (Elem && Elem->getOpcode() == SPIRV::OpTypeInt)
    return Elem;
  return nullptr;
}
1305
isScalarOrVectorSigned(const SPIRVType * Type) const1306 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
1307 const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
1308 return IntType && IntType->getOperand(2).getImm() != 0;
1309 }
1310
getPointeeType(SPIRVType * PtrType)1311 SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
1312 return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
1313 ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
1314 : nullptr;
1315 }
1316
getPointeeTypeOp(Register PtrReg)1317 unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
1318 SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg));
1319 return ElemType ? ElemType->getOpcode() : 0;
1320 }
1321
isBitcastCompatible(const SPIRVType * Type1,const SPIRVType * Type2) const1322 bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
1323 const SPIRVType *Type2) const {
1324 if (!Type1 || !Type2)
1325 return false;
1326 auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
1327 // Ignore difference between <1.5 and >=1.5 protocol versions:
1328 // it's valid if either Result Type or Operand is a pointer, and the other
1329 // is a pointer, an integer scalar, or an integer vector.
1330 if (Op1 == SPIRV::OpTypePointer &&
1331 (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
1332 return true;
1333 if (Op2 == SPIRV::OpTypePointer &&
1334 (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
1335 return true;
1336 unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
1337 Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
1338 return Bits1 > 0 && Bits1 == Bits2;
1339 }
1340
1341 SPIRV::StorageClass::StorageClass
getPointerStorageClass(Register VReg) const1342 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
1343 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1344 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
1345 Type->getOperand(1).isImm() && "Pointer type is expected");
1346 return getPointerStorageClass(Type);
1347 }
1348
1349 SPIRV::StorageClass::StorageClass
getPointerStorageClass(const SPIRVType * Type) const1350 SPIRVGlobalRegistry::getPointerStorageClass(const SPIRVType *Type) const {
1351 return static_cast<SPIRV::StorageClass::StorageClass>(
1352 Type->getOperand(1).getImm());
1353 }
1354
getOrCreateVulkanBufferType(MachineIRBuilder & MIRBuilder,Type * ElemType,SPIRV::StorageClass::StorageClass SC,bool IsWritable,bool EmitIr)1355 SPIRVType *SPIRVGlobalRegistry::getOrCreateVulkanBufferType(
1356 MachineIRBuilder &MIRBuilder, Type *ElemType,
1357 SPIRV::StorageClass::StorageClass SC, bool IsWritable, bool EmitIr) {
1358 auto Key = SPIRV::irhandle_vkbuffer(ElemType, SC, IsWritable);
1359 if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1360 return MI;
1361
1362 bool ExplicitLayoutRequired = storageClassRequiresExplictLayout(SC);
1363 // We need to get the SPIR-V type for the element here, so we can add the
1364 // decoration to it.
1365 auto *T = StructType::create(ElemType);
1366 auto *BlockType =
1367 getOrCreateSPIRVType(T, MIRBuilder, SPIRV::AccessQualifier::None,
1368 ExplicitLayoutRequired, EmitIr);
1369
1370 buildOpDecorate(BlockType->defs().begin()->getReg(), MIRBuilder,
1371 SPIRV::Decoration::Block, {});
1372
1373 if (!IsWritable) {
1374 buildOpMemberDecorate(BlockType->defs().begin()->getReg(), MIRBuilder,
1375 SPIRV::Decoration::NonWritable, 0, {});
1376 }
1377
1378 SPIRVType *R = getOrCreateSPIRVPointerTypeInternal(BlockType, MIRBuilder, SC);
1379 add(Key, R);
1380 return R;
1381 }
1382
getOrCreateLayoutType(MachineIRBuilder & MIRBuilder,const TargetExtType * T,bool EmitIr)1383 SPIRVType *SPIRVGlobalRegistry::getOrCreateLayoutType(
1384 MachineIRBuilder &MIRBuilder, const TargetExtType *T, bool EmitIr) {
1385 auto Key = SPIRV::handle(T);
1386 if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1387 return MI;
1388
1389 StructType *ST = cast<StructType>(T->getTypeParameter(0));
1390 ArrayRef<uint32_t> Offsets = T->int_params().slice(1);
1391 assert(ST->getNumElements() == Offsets.size());
1392
1393 StructOffsetDecorator Decorator = [&MIRBuilder, &Offsets](Register Reg) {
1394 for (uint32_t I = 0; I < Offsets.size(); ++I) {
1395 buildOpMemberDecorate(Reg, MIRBuilder, SPIRV::Decoration::Offset, I,
1396 {Offsets[I]});
1397 }
1398 };
1399
1400 // We need a new OpTypeStruct instruction because decorations will be
1401 // different from a struct with an explicit layout created from a different
1402 // entry point.
1403 SPIRVType *SPIRVStructType = getOpTypeStruct(
1404 ST, MIRBuilder, SPIRV::AccessQualifier::None, Decorator, EmitIr);
1405 add(Key, SPIRVStructType);
1406 return SPIRVStructType;
1407 }
1408
getImageType(const TargetExtType * ExtensionType,const SPIRV::AccessQualifier::AccessQualifier Qualifier,MachineIRBuilder & MIRBuilder)1409 SPIRVType *SPIRVGlobalRegistry::getImageType(
1410 const TargetExtType *ExtensionType,
1411 const SPIRV::AccessQualifier::AccessQualifier Qualifier,
1412 MachineIRBuilder &MIRBuilder) {
1413 assert(ExtensionType->getNumTypeParameters() == 1 &&
1414 "SPIR-V image builtin type must have sampled type parameter!");
1415 const SPIRVType *SampledType =
1416 getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder,
1417 SPIRV::AccessQualifier::ReadWrite, true);
1418 assert((ExtensionType->getNumIntParameters() == 7 ||
1419 ExtensionType->getNumIntParameters() == 6) &&
1420 "Invalid number of parameters for SPIR-V image builtin!");
1421
1422 SPIRV::AccessQualifier::AccessQualifier accessQualifier =
1423 SPIRV::AccessQualifier::None;
1424 if (ExtensionType->getNumIntParameters() == 7) {
1425 accessQualifier = Qualifier == SPIRV::AccessQualifier::WriteOnly
1426 ? SPIRV::AccessQualifier::WriteOnly
1427 : SPIRV::AccessQualifier::AccessQualifier(
1428 ExtensionType->getIntParameter(6));
1429 }
1430
1431 // Create or get an existing type from GlobalRegistry.
1432 SPIRVType *R = getOrCreateOpTypeImage(
1433 MIRBuilder, SampledType,
1434 SPIRV::Dim::Dim(ExtensionType->getIntParameter(0)),
1435 ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
1436 ExtensionType->getIntParameter(3), ExtensionType->getIntParameter(4),
1437 SPIRV::ImageFormat::ImageFormat(ExtensionType->getIntParameter(5)),
1438 accessQualifier);
1439 SPIRVToLLVMType[R] = ExtensionType;
1440 return R;
1441 }
1442
getOrCreateOpTypeImage(MachineIRBuilder & MIRBuilder,SPIRVType * SampledType,SPIRV::Dim::Dim Dim,uint32_t Depth,uint32_t Arrayed,uint32_t Multisampled,uint32_t Sampled,SPIRV::ImageFormat::ImageFormat ImageFormat,SPIRV::AccessQualifier::AccessQualifier AccessQual)1443 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
1444 MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
1445 uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
1446 SPIRV::ImageFormat::ImageFormat ImageFormat,
1447 SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1448 auto Key = SPIRV::irhandle_image(SPIRVToLLVMType.lookup(SampledType), Dim,
1449 Depth, Arrayed, Multisampled, Sampled,
1450 ImageFormat, AccessQual);
1451 if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1452 return MI;
1453 const MachineInstr *NewMI =
1454 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1455 auto MIB =
1456 MIRBuilder.buildInstr(SPIRV::OpTypeImage)
1457 .addDef(createTypeVReg(MIRBuilder))
1458 .addUse(getSPIRVTypeID(SampledType))
1459 .addImm(Dim)
1460 .addImm(Depth) // Depth (whether or not it is a Depth image).
1461 .addImm(Arrayed) // Arrayed.
1462 .addImm(Multisampled) // Multisampled (0 = only single-sample).
1463 .addImm(Sampled) // Sampled (0 = usage known at runtime).
1464 .addImm(ImageFormat);
1465 if (AccessQual != SPIRV::AccessQualifier::None)
1466 MIB.addImm(AccessQual);
1467 return MIB;
1468 });
1469 add(Key, NewMI);
1470 return NewMI;
1471 }
1472
1473 SPIRVType *
getOrCreateOpTypeSampler(MachineIRBuilder & MIRBuilder)1474 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
1475 auto Key = SPIRV::irhandle_sampler();
1476 const MachineFunction *MF = &MIRBuilder.getMF();
1477 if (const MachineInstr *MI = findMI(Key, MF))
1478 return MI;
1479 const MachineInstr *NewMI =
1480 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1481 return MIRBuilder.buildInstr(SPIRV::OpTypeSampler)
1482 .addDef(createTypeVReg(MIRBuilder));
1483 });
1484 add(Key, NewMI);
1485 return NewMI;
1486 }
1487
getOrCreateOpTypePipe(MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual)1488 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
1489 MachineIRBuilder &MIRBuilder,
1490 SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1491 auto Key = SPIRV::irhandle_pipe(AccessQual);
1492 if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1493 return MI;
1494 const MachineInstr *NewMI =
1495 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1496 return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
1497 .addDef(createTypeVReg(MIRBuilder))
1498 .addImm(AccessQual);
1499 });
1500 add(Key, NewMI);
1501 return NewMI;
1502 }
1503
getOrCreateOpTypeDeviceEvent(MachineIRBuilder & MIRBuilder)1504 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
1505 MachineIRBuilder &MIRBuilder) {
1506 auto Key = SPIRV::irhandle_event();
1507 if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1508 return MI;
1509 const MachineInstr *NewMI =
1510 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1511 return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent)
1512 .addDef(createTypeVReg(MIRBuilder));
1513 });
1514 add(Key, NewMI);
1515 return NewMI;
1516 }
1517
getOrCreateOpTypeSampledImage(SPIRVType * ImageType,MachineIRBuilder & MIRBuilder)1518 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
1519 SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
1520 auto Key = SPIRV::irhandle_sampled_image(
1521 SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
1522 ImageType->getOperand(1).getReg())),
1523 ImageType);
1524 if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1525 return MI;
1526 const MachineInstr *NewMI =
1527 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1528 return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
1529 .addDef(createTypeVReg(MIRBuilder))
1530 .addUse(getSPIRVTypeID(ImageType));
1531 });
1532 add(Key, NewMI);
1533 return NewMI;
1534 }
1535
getOrCreateOpTypeCoopMatr(MachineIRBuilder & MIRBuilder,const TargetExtType * ExtensionType,const SPIRVType * ElemType,uint32_t Scope,uint32_t Rows,uint32_t Columns,uint32_t Use,bool EmitIR)1536 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
1537 MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
1538 const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
1539 uint32_t Use, bool EmitIR) {
1540 if (const MachineInstr *MI =
1541 findMI(ExtensionType, false, &MIRBuilder.getMF()))
1542 return MI;
1543 const MachineInstr *NewMI =
1544 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1545 SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
1546 const Type *ET = getTypeForSPIRVType(ElemType);
1547 if (ET->isIntegerTy() && ET->getIntegerBitWidth() == 4 &&
1548 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
1549 .canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
1550 MIRBuilder.buildInstr(SPIRV::OpCapability)
1551 .addImm(SPIRV::Capability::Int4CooperativeMatrixINTEL);
1552 }
1553 return MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
1554 .addDef(createTypeVReg(MIRBuilder))
1555 .addUse(getSPIRVTypeID(ElemType))
1556 .addUse(buildConstantInt(Scope, MIRBuilder, SpvTypeInt32, EmitIR))
1557 .addUse(buildConstantInt(Rows, MIRBuilder, SpvTypeInt32, EmitIR))
1558 .addUse(buildConstantInt(Columns, MIRBuilder, SpvTypeInt32, EmitIR))
1559 .addUse(buildConstantInt(Use, MIRBuilder, SpvTypeInt32, EmitIR));
1560 });
1561 add(ExtensionType, false, NewMI);
1562 return NewMI;
1563 }
1564
getOrCreateOpTypeByOpcode(const Type * Ty,MachineIRBuilder & MIRBuilder,unsigned Opcode)1565 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
1566 const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
1567 if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF()))
1568 return MI;
1569 const MachineInstr *NewMI =
1570 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1571 return MIRBuilder.buildInstr(Opcode).addDef(createTypeVReg(MIRBuilder));
1572 });
1573 add(Ty, false, NewMI);
1574 return NewMI;
1575 }
1576
getOrCreateUnknownType(const Type * Ty,MachineIRBuilder & MIRBuilder,unsigned Opcode,const ArrayRef<MCOperand> Operands)1577 SPIRVType *SPIRVGlobalRegistry::getOrCreateUnknownType(
1578 const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode,
1579 const ArrayRef<MCOperand> Operands) {
1580 if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF()))
1581 return MI;
1582 Register ResVReg = createTypeVReg(MIRBuilder);
1583 const MachineInstr *NewMI =
1584 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1585 MachineInstrBuilder MIB = MIRBuilder.buildInstr(SPIRV::UNKNOWN_type)
1586 .addDef(ResVReg)
1587 .addImm(Opcode);
1588 for (MCOperand Operand : Operands) {
1589 if (Operand.isReg()) {
1590 MIB.addUse(Operand.getReg());
1591 } else if (Operand.isImm()) {
1592 MIB.addImm(Operand.getImm());
1593 }
1594 }
1595 return MIB;
1596 });
1597 add(Ty, false, NewMI);
1598 return NewMI;
1599 }
1600
1601 // Returns nullptr if unable to recognize SPIRV type name
getOrCreateSPIRVTypeByName(StringRef TypeStr,MachineIRBuilder & MIRBuilder,bool EmitIR,SPIRV::StorageClass::StorageClass SC,SPIRV::AccessQualifier::AccessQualifier AQ)1602 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
1603 StringRef TypeStr, MachineIRBuilder &MIRBuilder, bool EmitIR,
1604 SPIRV::StorageClass::StorageClass SC,
1605 SPIRV::AccessQualifier::AccessQualifier AQ) {
1606 unsigned VecElts = 0;
1607 auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
1608
1609 // Parse strings representing either a SPIR-V or OpenCL builtin type.
1610 if (hasBuiltinTypePrefix(TypeStr))
1611 return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType(
1612 TypeStr.str(), MIRBuilder.getContext()),
1613 MIRBuilder, AQ, false, true);
1614
1615 // Parse type name in either "typeN" or "type vector[N]" format, where
1616 // N is the number of elements of the vector.
1617 Type *Ty;
1618
1619 Ty = parseBasicTypeName(TypeStr, Ctx);
1620 if (!Ty)
1621 // Unable to recognize SPIRV type name
1622 return nullptr;
1623
1624 const SPIRVType *SpirvTy =
1625 getOrCreateSPIRVType(Ty, MIRBuilder, AQ, false, true);
1626
1627 // Handle "type*" or "type* vector[N]".
1628 if (TypeStr.consume_front("*"))
1629 SpirvTy = getOrCreateSPIRVPointerType(Ty, MIRBuilder, SC);
1630
1631 // Handle "typeN*" or "type vector[N]*".
1632 bool IsPtrToVec = TypeStr.consume_back("*");
1633
1634 if (TypeStr.consume_front(" vector[")) {
1635 TypeStr = TypeStr.substr(0, TypeStr.find(']'));
1636 }
1637 TypeStr.getAsInteger(10, VecElts);
1638 if (VecElts > 0)
1639 SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder, EmitIR);
1640
1641 if (IsPtrToVec)
1642 SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1643
1644 return SpirvTy;
1645 }
1646
1647 SPIRVType *
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineIRBuilder & MIRBuilder)1648 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
1649 MachineIRBuilder &MIRBuilder) {
1650 return getOrCreateSPIRVType(
1651 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
1652 MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, true);
1653 }
1654
finishCreatingSPIRVType(const Type * LLVMTy,SPIRVType * SpirvType)1655 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
1656 SPIRVType *SpirvType) {
1657 assert(CurMF == SpirvType->getMF());
1658 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
1659 SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
1660 return SpirvType;
1661 }
1662
getOrCreateSPIRVType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII,unsigned SPIRVOPcode,Type * Ty)1663 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
1664 MachineInstr &I,
1665 const SPIRVInstrInfo &TII,
1666 unsigned SPIRVOPcode,
1667 Type *Ty) {
1668 if (const MachineInstr *MI = findMI(Ty, false, CurMF))
1669 return MI;
1670 MachineBasicBlock &DepMBB = I.getMF()->front();
1671 MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
1672 const MachineInstr *NewMI =
1673 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1674 return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1675 MIRBuilder.getDL(), TII.get(SPIRVOPcode))
1676 .addDef(createTypeVReg(CurMF->getRegInfo()))
1677 .addImm(BitWidth)
1678 .addImm(0);
1679 });
1680 add(Ty, false, NewMI);
1681 return finishCreatingSPIRVType(Ty, NewMI);
1682 }
1683
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII)1684 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
1685 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1686 // Maybe adjust bit width to keep DuplicateTracker consistent. Without
1687 // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
1688 // example, the same "OpTypeInt 8" type for a series of LLVM integer types
1689 // with number of bits less than 8, causing duplicate type definitions.
1690 if (BitWidth > 1)
1691 BitWidth = adjustOpTypeIntWidth(BitWidth);
1692 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
1693 return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
1694 }
1695
getOrCreateSPIRVFloatType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII)1696 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
1697 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1698 LLVMContext &Ctx = CurMF->getFunction().getContext();
1699 Type *LLVMTy;
1700 switch (BitWidth) {
1701 case 16:
1702 LLVMTy = Type::getHalfTy(Ctx);
1703 break;
1704 case 32:
1705 LLVMTy = Type::getFloatTy(Ctx);
1706 break;
1707 case 64:
1708 LLVMTy = Type::getDoubleTy(Ctx);
1709 break;
1710 default:
1711 llvm_unreachable("Bit width is of unexpected size.");
1712 }
1713 return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
1714 }
1715
1716 SPIRVType *
getOrCreateSPIRVBoolType(MachineIRBuilder & MIRBuilder,bool EmitIR)1717 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder,
1718 bool EmitIR) {
1719 return getOrCreateSPIRVType(
1720 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
1721 MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, EmitIR);
1722 }
1723
1724 SPIRVType *
getOrCreateSPIRVBoolType(MachineInstr & I,const SPIRVInstrInfo & TII)1725 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
1726 const SPIRVInstrInfo &TII) {
1727 Type *Ty = IntegerType::get(CurMF->getFunction().getContext(), 1);
1728 if (const MachineInstr *MI = findMI(Ty, false, CurMF))
1729 return MI;
1730 MachineBasicBlock &DepMBB = I.getMF()->front();
1731 MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
1732 const MachineInstr *NewMI =
1733 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1734 return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1735 MIRBuilder.getDL(), TII.get(SPIRV::OpTypeBool))
1736 .addDef(createTypeVReg(CurMF->getRegInfo()));
1737 });
1738 add(Ty, false, NewMI);
1739 return finishCreatingSPIRVType(Ty, NewMI);
1740 }
1741
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineIRBuilder & MIRBuilder,bool EmitIR)1742 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1743 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder,
1744 bool EmitIR) {
1745 return getOrCreateSPIRVType(
1746 FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1747 NumElements),
1748 MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, EmitIR);
1749 }
1750
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineInstr & I,const SPIRVInstrInfo & TII)1751 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1752 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1753 const SPIRVInstrInfo &TII) {
1754 Type *Ty = FixedVectorType::get(
1755 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1756 if (const MachineInstr *MI = findMI(Ty, false, CurMF))
1757 return MI;
1758 MachineInstr *DepMI = const_cast<MachineInstr *>(BaseType);
1759 MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
1760 const MachineInstr *NewMI =
1761 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1762 return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1763 MIRBuilder.getDL(), TII.get(SPIRV::OpTypeVector))
1764 .addDef(createTypeVReg(CurMF->getRegInfo()))
1765 .addUse(getSPIRVTypeID(BaseType))
1766 .addImm(NumElements);
1767 });
1768 add(Ty, false, NewMI);
1769 return finishCreatingSPIRVType(Ty, NewMI);
1770 }
1771
getOrCreateSPIRVPointerType(const Type * BaseType,MachineInstr & I,SPIRV::StorageClass::StorageClass SC)1772 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1773 const Type *BaseType, MachineInstr &I,
1774 SPIRV::StorageClass::StorageClass SC) {
1775 MachineIRBuilder MIRBuilder(I);
1776 return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
1777 }
1778
getOrCreateSPIRVPointerType(const Type * BaseType,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC)1779 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1780 const Type *BaseType, MachineIRBuilder &MIRBuilder,
1781 SPIRV::StorageClass::StorageClass SC) {
1782 // TODO: Need to check if EmitIr should always be true.
1783 SPIRVType *SpirvBaseType = getOrCreateSPIRVType(
1784 BaseType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
1785 storageClassRequiresExplictLayout(SC), true);
1786 assert(SpirvBaseType);
1787 return getOrCreateSPIRVPointerTypeInternal(SpirvBaseType, MIRBuilder, SC);
1788 }
1789
changePointerStorageClass(SPIRVType * PtrType,SPIRV::StorageClass::StorageClass SC,MachineInstr & I)1790 SPIRVType *SPIRVGlobalRegistry::changePointerStorageClass(
1791 SPIRVType *PtrType, SPIRV::StorageClass::StorageClass SC, MachineInstr &I) {
1792 [[maybe_unused]] SPIRV::StorageClass::StorageClass OldSC =
1793 getPointerStorageClass(PtrType);
1794 assert(storageClassRequiresExplictLayout(OldSC) ==
1795 storageClassRequiresExplictLayout(SC));
1796
1797 SPIRVType *PointeeType = getPointeeType(PtrType);
1798 MachineIRBuilder MIRBuilder(I);
1799 return getOrCreateSPIRVPointerTypeInternal(PointeeType, MIRBuilder, SC);
1800 }
1801
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC)1802 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1803 SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1804 SPIRV::StorageClass::StorageClass SC) {
1805 const Type *LLVMType = getTypeForSPIRVType(BaseType);
1806 assert(!storageClassRequiresExplictLayout(SC));
1807 SPIRVType *R = getOrCreateSPIRVPointerType(LLVMType, MIRBuilder, SC);
1808 assert(
1809 getPointeeType(R) == BaseType &&
1810 "The base type was not correctly laid out for the given storage class.");
1811 return R;
1812 }
1813
getOrCreateSPIRVPointerTypeInternal(SPIRVType * BaseType,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC)1814 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerTypeInternal(
1815 SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1816 SPIRV::StorageClass::StorageClass SC) {
1817 const Type *PointerElementType = getTypeForSPIRVType(BaseType);
1818 unsigned AddressSpace = storageClassToAddressSpace(SC);
1819 if (const MachineInstr *MI = findMI(PointerElementType, AddressSpace, CurMF))
1820 return MI;
1821 Type *Ty = TypedPointerType::get(const_cast<Type *>(PointerElementType),
1822 AddressSpace);
1823 const MachineInstr *NewMI =
1824 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1825 return BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
1826 MIRBuilder.getDebugLoc(),
1827 MIRBuilder.getTII().get(SPIRV::OpTypePointer))
1828 .addDef(createTypeVReg(CurMF->getRegInfo()))
1829 .addImm(static_cast<uint32_t>(SC))
1830 .addUse(getSPIRVTypeID(BaseType));
1831 });
1832 add(PointerElementType, AddressSpace, NewMI);
1833 return finishCreatingSPIRVType(Ty, NewMI);
1834 }
1835
getOrCreateUndef(MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)1836 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1837 SPIRVType *SpvType,
1838 const SPIRVInstrInfo &TII) {
1839 UndefValue *UV =
1840 UndefValue::get(const_cast<Type *>(getTypeForSPIRVType(SpvType)));
1841 Register Res = find(UV, CurMF);
1842 if (Res.isValid())
1843 return Res;
1844
1845 LLT LLTy = LLT::scalar(64);
1846 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
1847 CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
1848 assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
1849
1850 MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
1851 MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
1852 const MachineInstr *NewMI =
1853 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1854 auto MIB = BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1855 MIRBuilder.getDL(), TII.get(SPIRV::OpUndef))
1856 .addDef(Res)
1857 .addUse(getSPIRVTypeID(SpvType));
1858 const auto &ST = CurMF->getSubtarget();
1859 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
1860 *ST.getRegisterInfo(),
1861 *ST.getRegBankInfo());
1862 return MIB;
1863 });
1864 add(UV, NewMI);
1865 return Res;
1866 }
1867
1868 const TargetRegisterClass *
getRegClass(SPIRVType * SpvType) const1869 SPIRVGlobalRegistry::getRegClass(SPIRVType *SpvType) const {
1870 unsigned Opcode = SpvType->getOpcode();
1871 switch (Opcode) {
1872 case SPIRV::OpTypeFloat:
1873 return &SPIRV::fIDRegClass;
1874 case SPIRV::OpTypePointer:
1875 return &SPIRV::pIDRegClass;
1876 case SPIRV::OpTypeVector: {
1877 SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
1878 unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0;
1879 if (ElemOpcode == SPIRV::OpTypeFloat)
1880 return &SPIRV::vfIDRegClass;
1881 if (ElemOpcode == SPIRV::OpTypePointer)
1882 return &SPIRV::vpIDRegClass;
1883 return &SPIRV::vIDRegClass;
1884 }
1885 }
1886 return &SPIRV::iIDRegClass;
1887 }
1888
getAS(SPIRVType * SpvType)1889 inline unsigned getAS(SPIRVType *SpvType) {
1890 return storageClassToAddressSpace(
1891 static_cast<SPIRV::StorageClass::StorageClass>(
1892 SpvType->getOperand(1).getImm()));
1893 }
1894
getRegType(SPIRVType * SpvType) const1895 LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const {
1896 unsigned Opcode = SpvType ? SpvType->getOpcode() : 0;
1897 switch (Opcode) {
1898 case SPIRV::OpTypeInt:
1899 case SPIRV::OpTypeFloat:
1900 case SPIRV::OpTypeBool:
1901 return LLT::scalar(getScalarOrVectorBitWidth(SpvType));
1902 case SPIRV::OpTypePointer:
1903 return LLT::pointer(getAS(SpvType), getPointerSize());
1904 case SPIRV::OpTypeVector: {
1905 SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
1906 LLT ET;
1907 switch (ElemType ? ElemType->getOpcode() : 0) {
1908 case SPIRV::OpTypePointer:
1909 ET = LLT::pointer(getAS(ElemType), getPointerSize());
1910 break;
1911 case SPIRV::OpTypeInt:
1912 case SPIRV::OpTypeFloat:
1913 case SPIRV::OpTypeBool:
1914 ET = LLT::scalar(getScalarOrVectorBitWidth(ElemType));
1915 break;
1916 default:
1917 ET = LLT::scalar(64);
1918 }
1919 return LLT::fixed_vector(
1920 static_cast<unsigned>(SpvType->getOperand(2).getImm()), ET);
1921 }
1922 }
1923 return LLT::scalar(64);
1924 }
1925
1926 // Aliasing list MD contains several scope MD nodes whithin it. Each scope MD
1927 // has a selfreference and an extra MD node for aliasing domain and also it
1928 // can contain an optional string operand. Domain MD contains a self-reference
1929 // with an optional string operand. Here we unfold the list, creating SPIR-V
1930 // aliasing instructions.
1931 // TODO: add support for an optional string operand.
getOrAddMemAliasingINTELInst(MachineIRBuilder & MIRBuilder,const MDNode * AliasingListMD)1932 MachineInstr *SPIRVGlobalRegistry::getOrAddMemAliasingINTELInst(
1933 MachineIRBuilder &MIRBuilder, const MDNode *AliasingListMD) {
1934 if (AliasingListMD->getNumOperands() == 0)
1935 return nullptr;
1936 if (auto L = AliasInstMDMap.find(AliasingListMD); L != AliasInstMDMap.end())
1937 return L->second;
1938
1939 SmallVector<MachineInstr *> ScopeList;
1940 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1941 for (const MDOperand &MDListOp : AliasingListMD->operands()) {
1942 if (MDNode *ScopeMD = dyn_cast<MDNode>(MDListOp)) {
1943 if (ScopeMD->getNumOperands() < 2)
1944 return nullptr;
1945 MDNode *DomainMD = dyn_cast<MDNode>(ScopeMD->getOperand(1));
1946 if (!DomainMD)
1947 return nullptr;
1948 auto *Domain = [&] {
1949 auto D = AliasInstMDMap.find(DomainMD);
1950 if (D != AliasInstMDMap.end())
1951 return D->second;
1952 const Register Ret = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1953 auto MIB =
1954 MIRBuilder.buildInstr(SPIRV::OpAliasDomainDeclINTEL).addDef(Ret);
1955 return MIB.getInstr();
1956 }();
1957 AliasInstMDMap.insert(std::make_pair(DomainMD, Domain));
1958 auto *Scope = [&] {
1959 auto S = AliasInstMDMap.find(ScopeMD);
1960 if (S != AliasInstMDMap.end())
1961 return S->second;
1962 const Register Ret = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1963 auto MIB = MIRBuilder.buildInstr(SPIRV::OpAliasScopeDeclINTEL)
1964 .addDef(Ret)
1965 .addUse(Domain->getOperand(0).getReg());
1966 return MIB.getInstr();
1967 }();
1968 AliasInstMDMap.insert(std::make_pair(ScopeMD, Scope));
1969 ScopeList.push_back(Scope);
1970 }
1971 }
1972
1973 const Register Ret = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1974 auto MIB =
1975 MIRBuilder.buildInstr(SPIRV::OpAliasScopeListDeclINTEL).addDef(Ret);
1976 for (auto *Scope : ScopeList)
1977 MIB.addUse(Scope->getOperand(0).getReg());
1978 auto List = MIB.getInstr();
1979 AliasInstMDMap.insert(std::make_pair(AliasingListMD, List));
1980 return List;
1981 }
1982
buildMemAliasingOpDecorate(Register Reg,MachineIRBuilder & MIRBuilder,uint32_t Dec,const MDNode * AliasingListMD)1983 void SPIRVGlobalRegistry::buildMemAliasingOpDecorate(
1984 Register Reg, MachineIRBuilder &MIRBuilder, uint32_t Dec,
1985 const MDNode *AliasingListMD) {
1986 MachineInstr *AliasList =
1987 getOrAddMemAliasingINTELInst(MIRBuilder, AliasingListMD);
1988 if (!AliasList)
1989 return;
1990 MIRBuilder.buildInstr(SPIRV::OpDecorate)
1991 .addUse(Reg)
1992 .addImm(Dec)
1993 .addUse(AliasList->getOperand(0).getReg());
1994 }
replaceAllUsesWith(Value * Old,Value * New,bool DeleteOld)1995 void SPIRVGlobalRegistry::replaceAllUsesWith(Value *Old, Value *New,
1996 bool DeleteOld) {
1997 Old->replaceAllUsesWith(New);
1998 updateIfExistDeducedElementType(Old, New, DeleteOld);
1999 updateIfExistAssignPtrTypeInstr(Old, New, DeleteOld);
2000 }
2001
buildAssignType(IRBuilder<> & B,Type * Ty,Value * Arg)2002 void SPIRVGlobalRegistry::buildAssignType(IRBuilder<> &B, Type *Ty,
2003 Value *Arg) {
2004 Value *OfType = getNormalizedPoisonValue(Ty);
2005 CallInst *AssignCI = nullptr;
2006 if (Arg->getType()->isAggregateType() && Ty->isAggregateType() &&
2007 allowEmitFakeUse(Arg)) {
2008 LLVMContext &Ctx = Arg->getContext();
2009 SmallVector<Metadata *, 2> ArgMDs{
2010 MDNode::get(Ctx, ValueAsMetadata::getConstant(OfType)),
2011 MDString::get(Ctx, Arg->getName())};
2012 B.CreateIntrinsic(Intrinsic::spv_value_md,
2013 {MetadataAsValue::get(Ctx, MDTuple::get(Ctx, ArgMDs))});
2014 AssignCI = B.CreateIntrinsic(Intrinsic::fake_use, {Arg});
2015 } else {
2016 AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type, {Arg->getType()},
2017 OfType, Arg, {}, B);
2018 }
2019 addAssignPtrTypeInstr(Arg, AssignCI);
2020 }
2021
buildAssignPtr(IRBuilder<> & B,Type * ElemTy,Value * Arg)2022 void SPIRVGlobalRegistry::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
2023 Value *Arg) {
2024 Value *OfType = PoisonValue::get(ElemTy);
2025 CallInst *AssignPtrTyCI = findAssignPtrTypeInstr(Arg);
2026 Function *CurrF =
2027 B.GetInsertBlock() ? B.GetInsertBlock()->getParent() : nullptr;
2028 if (AssignPtrTyCI == nullptr ||
2029 AssignPtrTyCI->getParent()->getParent() != CurrF) {
2030 AssignPtrTyCI = buildIntrWithMD(
2031 Intrinsic::spv_assign_ptr_type, {Arg->getType()}, OfType, Arg,
2032 {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
2033 addDeducedElementType(AssignPtrTyCI, ElemTy);
2034 addDeducedElementType(Arg, ElemTy);
2035 addAssignPtrTypeInstr(Arg, AssignPtrTyCI);
2036 } else {
2037 updateAssignType(AssignPtrTyCI, Arg, OfType);
2038 }
2039 }
2040
updateAssignType(CallInst * AssignCI,Value * Arg,Value * OfType)2041 void SPIRVGlobalRegistry::updateAssignType(CallInst *AssignCI, Value *Arg,
2042 Value *OfType) {
2043 AssignCI->setArgOperand(1, buildMD(OfType));
2044 if (cast<IntrinsicInst>(AssignCI)->getIntrinsicID() !=
2045 Intrinsic::spv_assign_ptr_type)
2046 return;
2047
2048 // update association with the pointee type
2049 Type *ElemTy = OfType->getType();
2050 addDeducedElementType(AssignCI, ElemTy);
2051 addDeducedElementType(Arg, ElemTy);
2052 }
2053
addStructOffsetDecorations(Register Reg,StructType * Ty,MachineIRBuilder & MIRBuilder)2054 void SPIRVGlobalRegistry::addStructOffsetDecorations(
2055 Register Reg, StructType *Ty, MachineIRBuilder &MIRBuilder) {
2056 DataLayout DL;
2057 ArrayRef<TypeSize> Offsets = DL.getStructLayout(Ty)->getMemberOffsets();
2058 for (uint32_t I = 0; I < Ty->getNumElements(); ++I) {
2059 buildOpMemberDecorate(Reg, MIRBuilder, SPIRV::Decoration::Offset, I,
2060 {static_cast<uint32_t>(Offsets[I])});
2061 }
2062 }
2063
addArrayStrideDecorations(Register Reg,Type * ElementType,MachineIRBuilder & MIRBuilder)2064 void SPIRVGlobalRegistry::addArrayStrideDecorations(
2065 Register Reg, Type *ElementType, MachineIRBuilder &MIRBuilder) {
2066 uint32_t SizeInBytes = DataLayout().getTypeSizeInBits(ElementType) / 8;
2067 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::ArrayStride,
2068 {SizeInBytes});
2069 }
2070
hasBlockDecoration(SPIRVType * Type) const2071 bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
2072 Register Def = getSPIRVTypeID(Type);
2073 for (const MachineInstr &Use :
2074 Type->getMF()->getRegInfo().use_instructions(Def)) {
2075 if (Use.getOpcode() != SPIRV::OpDecorate)
2076 continue;
2077
2078 if (Use.getOperand(1).getImm() == SPIRV::Decoration::Block)
2079 return true;
2080 }
2081 return false;
2082 }
2083