xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVUtils.h"
22 #include "llvm/ADT/APInt.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/IR/Intrinsics.h"
26 #include "llvm/IR/IntrinsicsSPIRV.h"
27 #include "llvm/IR/Type.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/MathExtras.h"
30 #include <cassert>
31 #include <functional>
32 
33 using namespace llvm;
34 
allowEmitFakeUse(const Value * Arg)35 static bool allowEmitFakeUse(const Value *Arg) {
36   if (isSpvIntrinsic(Arg))
37     return false;
38   if (isa<AtomicCmpXchgInst, InsertValueInst, UndefValue>(Arg))
39     return false;
40   if (const auto *LI = dyn_cast<LoadInst>(Arg))
41     if (LI->getType()->isAggregateType())
42       return false;
43   return true;
44 }
45 
typeToAddressSpace(const Type * Ty)46 static unsigned typeToAddressSpace(const Type *Ty) {
47   if (auto PType = dyn_cast<TypedPointerType>(Ty))
48     return PType->getAddressSpace();
49   if (auto PType = dyn_cast<PointerType>(Ty))
50     return PType->getAddressSpace();
51   if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
52       ExtTy && isTypedPointerWrapper(ExtTy))
53     return ExtTy->getIntParameter(0);
54   reportFatalInternalError("Unable to convert LLVM type to SPIRVType");
55 }
56 
57 static bool
storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC)58 storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC) {
59   switch (SC) {
60   case SPIRV::StorageClass::Uniform:
61   case SPIRV::StorageClass::PushConstant:
62   case SPIRV::StorageClass::StorageBuffer:
63   case SPIRV::StorageClass::PhysicalStorageBufferEXT:
64     return true;
65   case SPIRV::StorageClass::UniformConstant:
66   case SPIRV::StorageClass::Input:
67   case SPIRV::StorageClass::Output:
68   case SPIRV::StorageClass::Workgroup:
69   case SPIRV::StorageClass::CrossWorkgroup:
70   case SPIRV::StorageClass::Private:
71   case SPIRV::StorageClass::Function:
72   case SPIRV::StorageClass::Generic:
73   case SPIRV::StorageClass::AtomicCounter:
74   case SPIRV::StorageClass::Image:
75   case SPIRV::StorageClass::CallableDataNV:
76   case SPIRV::StorageClass::IncomingCallableDataNV:
77   case SPIRV::StorageClass::RayPayloadNV:
78   case SPIRV::StorageClass::HitAttributeNV:
79   case SPIRV::StorageClass::IncomingRayPayloadNV:
80   case SPIRV::StorageClass::ShaderRecordBufferNV:
81   case SPIRV::StorageClass::CodeSectionINTEL:
82   case SPIRV::StorageClass::DeviceOnlyINTEL:
83   case SPIRV::StorageClass::HostOnlyINTEL:
84     return false;
85   }
86   llvm_unreachable("Unknown SPIRV::StorageClass enum");
87 }
88 
SPIRVGlobalRegistry(unsigned PointerSize)89 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
90     : PointerSize(PointerSize), Bound(0) {}
91 
assignIntTypeToVReg(unsigned BitWidth,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)92 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
93                                                     Register VReg,
94                                                     MachineInstr &I,
95                                                     const SPIRVInstrInfo &TII) {
96   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
97   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
98   return SpirvType;
99 }
100 
101 SPIRVType *
assignFloatTypeToVReg(unsigned BitWidth,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)102 SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
103                                            MachineInstr &I,
104                                            const SPIRVInstrInfo &TII) {
105   SPIRVType *SpirvType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
106   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
107   return SpirvType;
108 }
109 
assignVectTypeToVReg(SPIRVType * BaseType,unsigned NumElements,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)110 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
111     SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
112     const SPIRVInstrInfo &TII) {
113   SPIRVType *SpirvType =
114       getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
115   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
116   return SpirvType;
117 }
118 
assignTypeToVReg(const Type * Type,Register VReg,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)119 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
120     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
121     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
122   SPIRVType *SpirvType =
123       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
124   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
125   return SpirvType;
126 }
127 
assignSPIRVTypeToVReg(SPIRVType * SpirvType,Register VReg,const MachineFunction & MF)128 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
129                                                 Register VReg,
130                                                 const MachineFunction &MF) {
131   VRegToTypeMap[&MF][VReg] = SpirvType;
132 }
133 
createTypeVReg(MachineRegisterInfo & MRI)134 static Register createTypeVReg(MachineRegisterInfo &MRI) {
135   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(64));
136   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
137   return Res;
138 }
139 
createTypeVReg(MachineIRBuilder & MIRBuilder)140 inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
141   return createTypeVReg(MIRBuilder.getMF().getRegInfo());
142 }
143 
getOpTypeBool(MachineIRBuilder & MIRBuilder)144 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
145   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
146     return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
147         .addDef(createTypeVReg(MIRBuilder));
148   });
149 }
150 
adjustOpTypeIntWidth(unsigned Width) const151 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
152   if (Width > 64)
153     report_fatal_error("Unsupported integer width!");
154   const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
155   if (ST.canUseExtension(
156           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
157       ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
158     return Width;
159   if (Width <= 8)
160     Width = 8;
161   else if (Width <= 16)
162     Width = 16;
163   else if (Width <= 32)
164     Width = 32;
165   else
166     Width = 64;
167   return Width;
168 }
169 
getOpTypeInt(unsigned Width,MachineIRBuilder & MIRBuilder,bool IsSigned)170 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
171                                              MachineIRBuilder &MIRBuilder,
172                                              bool IsSigned) {
173   Width = adjustOpTypeIntWidth(Width);
174   const SPIRVSubtarget &ST =
175       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
176   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
177     if (Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
178       MIRBuilder.buildInstr(SPIRV::OpExtension)
179           .addImm(SPIRV::Extension::SPV_INTEL_int4);
180       MIRBuilder.buildInstr(SPIRV::OpCapability)
181           .addImm(SPIRV::Capability::Int4TypeINTEL);
182     } else if ((!isPowerOf2_32(Width) || Width < 8) &&
183                ST.canUseExtension(
184                    SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
185       MIRBuilder.buildInstr(SPIRV::OpExtension)
186           .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
187       MIRBuilder.buildInstr(SPIRV::OpCapability)
188           .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
189     }
190     return MIRBuilder.buildInstr(SPIRV::OpTypeInt)
191         .addDef(createTypeVReg(MIRBuilder))
192         .addImm(Width)
193         .addImm(IsSigned ? 1 : 0);
194   });
195 }
196 
getOpTypeFloat(uint32_t Width,MachineIRBuilder & MIRBuilder)197 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
198                                                MachineIRBuilder &MIRBuilder) {
199   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
200     return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
201         .addDef(createTypeVReg(MIRBuilder))
202         .addImm(Width);
203   });
204 }
205 
getOpTypeVoid(MachineIRBuilder & MIRBuilder)206 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
207   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
208     return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
209         .addDef(createTypeVReg(MIRBuilder));
210   });
211 }
212 
invalidateMachineInstr(MachineInstr * MI)213 void SPIRVGlobalRegistry::invalidateMachineInstr(MachineInstr *MI) {
214   // TODO:
215   // - review other data structure wrt. possible issues related to removal
216   //   of a machine instruction during instruction selection.
217   const MachineFunction *MF = MI->getMF();
218   auto It = LastInsertedTypeMap.find(MF);
219   if (It == LastInsertedTypeMap.end())
220     return;
221   if (It->second == MI)
222     LastInsertedTypeMap.erase(MF);
223   // remove from the duplicate tracker to avoid incorrect reuse
224   erase(MI);
225 }
226 
createOpType(MachineIRBuilder & MIRBuilder,std::function<MachineInstr * (MachineIRBuilder &)> Op)227 SPIRVType *SPIRVGlobalRegistry::createOpType(
228     MachineIRBuilder &MIRBuilder,
229     std::function<MachineInstr *(MachineIRBuilder &)> Op) {
230   auto oldInsertPoint = MIRBuilder.getInsertPt();
231   MachineBasicBlock *OldMBB = &MIRBuilder.getMBB();
232   MachineBasicBlock *NewMBB = &*MIRBuilder.getMF().begin();
233 
234   auto LastInsertedType = LastInsertedTypeMap.find(CurMF);
235   if (LastInsertedType != LastInsertedTypeMap.end()) {
236     auto It = LastInsertedType->second->getIterator();
237     // It might happen that this instruction was removed from the first MBB,
238     // hence the Parent's check.
239     MachineBasicBlock::iterator InsertAt;
240     if (It->getParent() != NewMBB)
241       InsertAt = oldInsertPoint->getParent() == NewMBB
242                      ? oldInsertPoint
243                      : getInsertPtValidEnd(NewMBB);
244     else if (It->getNextNode())
245       InsertAt = It->getNextNode()->getIterator();
246     else
247       InsertAt = getInsertPtValidEnd(NewMBB);
248     MIRBuilder.setInsertPt(*NewMBB, InsertAt);
249   } else {
250     MIRBuilder.setInsertPt(*NewMBB, NewMBB->begin());
251     auto Result = LastInsertedTypeMap.try_emplace(CurMF, nullptr);
252     assert(Result.second);
253     LastInsertedType = Result.first;
254   }
255 
256   MachineInstr *Type = Op(MIRBuilder);
257   // We expect all users of this function to insert definitions at the insertion
258   // point set above that is always the first MBB.
259   assert(Type->getParent() == NewMBB);
260   LastInsertedType->second = Type;
261 
262   MIRBuilder.setInsertPt(*OldMBB, oldInsertPoint);
263   return Type;
264 }
265 
getOpTypeVector(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder)266 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
267                                                 SPIRVType *ElemType,
268                                                 MachineIRBuilder &MIRBuilder) {
269   auto EleOpc = ElemType->getOpcode();
270   (void)EleOpc;
271   assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
272           EleOpc == SPIRV::OpTypeBool) &&
273          "Invalid vector element type");
274 
275   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
276     return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
277         .addDef(createTypeVReg(MIRBuilder))
278         .addUse(getSPIRVTypeID(ElemType))
279         .addImm(NumElems);
280   });
281 }
282 
getOrCreateConstFP(APFloat Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)283 Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
284                                                  SPIRVType *SpvType,
285                                                  const SPIRVInstrInfo &TII,
286                                                  bool ZeroAsNull) {
287   LLVMContext &Ctx = CurMF->getFunction().getContext();
288   auto *const CF = ConstantFP::get(Ctx, Val);
289   const MachineInstr *MI = findMI(CF, CurMF);
290   if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
291              MI->getOpcode() == SPIRV::OpConstantF))
292     return MI->getOperand(0).getReg();
293   return createConstFP(CF, I, SpvType, TII, ZeroAsNull);
294 }
295 
createConstFP(const ConstantFP * CF,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)296 Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
297                                             MachineInstr &I, SPIRVType *SpvType,
298                                             const SPIRVInstrInfo &TII,
299                                             bool ZeroAsNull) {
300   unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
301   LLT LLTy = LLT::scalar(BitWidth);
302   Register Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
303   CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
304   assignFloatTypeToVReg(BitWidth, Res, I, TII);
305 
306   MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
307   MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
308   SPIRVType *NewType =
309       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
310         MachineInstrBuilder MIB;
311         // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
312         if (CF->getValue().isPosZero() && ZeroAsNull) {
313           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
314                     .addDef(Res)
315                     .addUse(getSPIRVTypeID(SpvType));
316         } else {
317           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
318                     .addDef(Res)
319                     .addUse(getSPIRVTypeID(SpvType));
320           addNumImm(APInt(BitWidth,
321                           CF->getValueAPF().bitcastToAPInt().getZExtValue()),
322                     MIB);
323         }
324         const auto &ST = CurMF->getSubtarget();
325         constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
326                                          *ST.getRegisterInfo(),
327                                          *ST.getRegBankInfo());
328         return MIB;
329       });
330   add(CF, NewType);
331   return Res;
332 }
333 
getOrCreateConstInt(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)334 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
335                                                   SPIRVType *SpvType,
336                                                   const SPIRVInstrInfo &TII,
337                                                   bool ZeroAsNull) {
338   const IntegerType *Ty = cast<IntegerType>(getTypeForSPIRVType(SpvType));
339   auto *const CI = ConstantInt::get(const_cast<IntegerType *>(Ty), Val);
340   const MachineInstr *MI = findMI(CI, CurMF);
341   if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
342              MI->getOpcode() == SPIRV::OpConstantI))
343     return MI->getOperand(0).getReg();
344   return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
345 }
346 
createConstInt(const ConstantInt * CI,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)347 Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
348                                              MachineInstr &I,
349                                              SPIRVType *SpvType,
350                                              const SPIRVInstrInfo &TII,
351                                              bool ZeroAsNull) {
352   unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
353   LLT LLTy = LLT::scalar(BitWidth);
354   Register Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
355   CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
356   assignIntTypeToVReg(BitWidth, Res, I, TII);
357 
358   MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
359   MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
360   SPIRVType *NewType =
361       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
362         MachineInstrBuilder MIB;
363         if (BitWidth == 1) {
364           MIB = MIRBuilder
365                     .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse
366                                              : SPIRV::OpConstantTrue)
367                     .addDef(Res)
368                     .addUse(getSPIRVTypeID(SpvType));
369         } else if (!CI->isZero() || !ZeroAsNull) {
370           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
371                     .addDef(Res)
372                     .addUse(getSPIRVTypeID(SpvType));
373           addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
374         } else {
375           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
376                     .addDef(Res)
377                     .addUse(getSPIRVTypeID(SpvType));
378         }
379         const auto &ST = CurMF->getSubtarget();
380         constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
381                                          *ST.getRegisterInfo(),
382                                          *ST.getRegBankInfo());
383         return MIB;
384       });
385   add(CI, NewType);
386   return Res;
387 }
388 
buildConstantInt(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR,bool ZeroAsNull)389 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
390                                                MachineIRBuilder &MIRBuilder,
391                                                SPIRVType *SpvType, bool EmitIR,
392                                                bool ZeroAsNull) {
393   assert(SpvType);
394   auto &MF = MIRBuilder.getMF();
395   const IntegerType *Ty = cast<IntegerType>(getTypeForSPIRVType(SpvType));
396   auto *const CI = ConstantInt::get(const_cast<IntegerType *>(Ty), Val);
397   Register Res = find(CI, &MF);
398   if (Res.isValid())
399     return Res;
400 
401   unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
402   LLT LLTy = LLT::scalar(BitWidth);
403   MachineRegisterInfo &MRI = MF.getRegInfo();
404   Res = MRI.createGenericVirtualRegister(LLTy);
405   MRI.setRegClass(Res, &SPIRV::iIDRegClass);
406   assignTypeToVReg(Ty, Res, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
407                    EmitIR);
408 
409   SPIRVType *NewType =
410       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
411         if (EmitIR)
412           return MIRBuilder.buildConstant(Res, *CI);
413         Register SpvTypeReg = getSPIRVTypeID(SpvType);
414         MachineInstrBuilder MIB;
415         if (Val || !ZeroAsNull) {
416           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
417                     .addDef(Res)
418                     .addUse(SpvTypeReg);
419           addNumImm(APInt(BitWidth, Val), MIB);
420         } else {
421           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
422                     .addDef(Res)
423                     .addUse(SpvTypeReg);
424         }
425         const auto &Subtarget = CurMF->getSubtarget();
426         constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
427                                          *Subtarget.getRegisterInfo(),
428                                          *Subtarget.getRegBankInfo());
429         return MIB;
430       });
431   add(CI, NewType);
432   return Res;
433 }
434 
buildConstantFP(APFloat Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)435 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
436                                               MachineIRBuilder &MIRBuilder,
437                                               SPIRVType *SpvType) {
438   auto &MF = MIRBuilder.getMF();
439   LLVMContext &Ctx = MF.getFunction().getContext();
440   if (!SpvType)
441     SpvType = getOrCreateSPIRVType(Type::getFloatTy(Ctx), MIRBuilder,
442                                    SPIRV::AccessQualifier::ReadWrite, true);
443   auto *const CF = ConstantFP::get(Ctx, Val);
444   Register Res = find(CF, &MF);
445   if (Res.isValid())
446     return Res;
447 
448   LLT LLTy = LLT::scalar(getScalarOrVectorBitWidth(SpvType));
449   Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
450   MF.getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
451   assignSPIRVTypeToVReg(SpvType, Res, MF);
452 
453   SPIRVType *NewType =
454       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
455         MachineInstrBuilder MIB;
456         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
457                   .addDef(Res)
458                   .addUse(getSPIRVTypeID(SpvType));
459         addNumImm(CF->getValueAPF().bitcastToAPInt(), MIB);
460         return MIB;
461       });
462   add(CF, NewType);
463   return Res;
464 }
465 
getOrCreateBaseRegister(Constant * Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,unsigned BitWidth,bool ZeroAsNull)466 Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
467     Constant *Val, MachineInstr &I, SPIRVType *SpvType,
468     const SPIRVInstrInfo &TII, unsigned BitWidth, bool ZeroAsNull) {
469   SPIRVType *Type = SpvType;
470   if (SpvType->getOpcode() == SPIRV::OpTypeVector ||
471       SpvType->getOpcode() == SPIRV::OpTypeArray) {
472     auto EleTypeReg = SpvType->getOperand(1).getReg();
473     Type = getSPIRVTypeForVReg(EleTypeReg);
474   }
475   if (Type->getOpcode() == SPIRV::OpTypeFloat) {
476     SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
477     return getOrCreateConstFP(dyn_cast<ConstantFP>(Val)->getValue(), I,
478                               SpvBaseType, TII, ZeroAsNull);
479   }
480   assert(Type->getOpcode() == SPIRV::OpTypeInt);
481   SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
482   return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I,
483                              SpvBaseType, TII, ZeroAsNull);
484 }
485 
getOrCreateCompositeOrNull(Constant * Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,Constant * CA,unsigned BitWidth,unsigned ElemCnt,bool ZeroAsNull)486 Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
487     Constant *Val, MachineInstr &I, SPIRVType *SpvType,
488     const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
489     unsigned ElemCnt, bool ZeroAsNull) {
490   if (Register R = find(CA, CurMF); R.isValid())
491     return R;
492 
493   bool IsNull = Val->isNullValue() && ZeroAsNull;
494   Register ElemReg;
495   if (!IsNull)
496     ElemReg =
497         getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth, ZeroAsNull);
498 
499   LLT LLTy = LLT::scalar(64);
500   Register Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
501   CurMF->getRegInfo().setRegClass(Res, getRegClass(SpvType));
502   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
503 
504   MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
505   MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
506   const MachineInstr *NewMI =
507       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
508         MachineInstrBuilder MIB;
509         if (!IsNull) {
510           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
511                     .addDef(Res)
512                     .addUse(getSPIRVTypeID(SpvType));
513           for (unsigned i = 0; i < ElemCnt; ++i)
514             MIB.addUse(ElemReg);
515         } else {
516           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
517                     .addDef(Res)
518                     .addUse(getSPIRVTypeID(SpvType));
519         }
520         const auto &Subtarget = CurMF->getSubtarget();
521         constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
522                                          *Subtarget.getRegisterInfo(),
523                                          *Subtarget.getRegBankInfo());
524         return MIB;
525       });
526   add(CA, NewMI);
527   return Res;
528 }
529 
getOrCreateConstVector(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)530 Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
531                                                      MachineInstr &I,
532                                                      SPIRVType *SpvType,
533                                                      const SPIRVInstrInfo &TII,
534                                                      bool ZeroAsNull) {
535   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
536   assert(LLVMTy->isVectorTy());
537   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
538   Type *LLVMBaseTy = LLVMVecTy->getElementType();
539   assert(LLVMBaseTy->isIntegerTy());
540   auto *ConstVal = ConstantInt::get(LLVMBaseTy, Val);
541   auto *ConstVec =
542       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
543   unsigned BW = getScalarOrVectorBitWidth(SpvType);
544   return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
545                                     SpvType->getOperand(2).getImm(),
546                                     ZeroAsNull);
547 }
548 
getOrCreateConstVector(APFloat Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)549 Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
550                                                      MachineInstr &I,
551                                                      SPIRVType *SpvType,
552                                                      const SPIRVInstrInfo &TII,
553                                                      bool ZeroAsNull) {
554   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
555   assert(LLVMTy->isVectorTy());
556   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
557   Type *LLVMBaseTy = LLVMVecTy->getElementType();
558   assert(LLVMBaseTy->isFloatingPointTy());
559   auto *ConstVal = ConstantFP::get(LLVMBaseTy, Val);
560   auto *ConstVec =
561       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
562   unsigned BW = getScalarOrVectorBitWidth(SpvType);
563   return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
564                                     SpvType->getOperand(2).getImm(),
565                                     ZeroAsNull);
566 }
567 
getOrCreateConstIntArray(uint64_t Val,size_t Num,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)568 Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
569     uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
570     const SPIRVInstrInfo &TII) {
571   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
572   assert(LLVMTy->isArrayTy());
573   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
574   Type *LLVMBaseTy = LLVMArrTy->getElementType();
575   Constant *CI = ConstantInt::get(LLVMBaseTy, Val);
576   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
577   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
578   // The following is reasonably unique key that is better that [Val]. The naive
579   // alternative would be something along the lines of:
580   //   SmallVector<Constant *> NumCI(Num, CI);
581   //   Constant *UniqueKey =
582   //     ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
583   // that would be a truly unique but dangerous key, because it could lead to
584   // the creation of constants of arbitrary length (that is, the parameter of
585   // memset) which were missing in the original module.
586   Constant *UniqueKey = ConstantStruct::getAnon(
587       {PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)),
588        ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)});
589   return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW,
590                                     LLVMArrTy->getNumElements());
591 }
592 
getOrCreateIntCompositeOrNull(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR,Constant * CA,unsigned BitWidth,unsigned ElemCnt)593 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
594     uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
595     Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
596   if (Register R = find(CA, CurMF); R.isValid())
597     return R;
598 
599   Register ElemReg;
600   if (Val || EmitIR) {
601     SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
602     ElemReg = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
603   }
604   LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(64);
605   Register Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
606   CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
607   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
608 
609   const MachineInstr *NewMI =
610       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
611         if (EmitIR)
612           return MIRBuilder.buildSplatBuildVector(Res, ElemReg);
613 
614         if (Val) {
615           auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
616                          .addDef(Res)
617                          .addUse(getSPIRVTypeID(SpvType));
618           for (unsigned i = 0; i < ElemCnt; ++i)
619             MIB.addUse(ElemReg);
620           return MIB;
621         }
622 
623         return MIRBuilder.buildInstr(SPIRV::OpConstantNull)
624             .addDef(Res)
625             .addUse(getSPIRVTypeID(SpvType));
626       });
627   add(CA, NewMI);
628   return Res;
629 }
630 
631 Register
getOrCreateConsIntVector(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)632 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
633                                               MachineIRBuilder &MIRBuilder,
634                                               SPIRVType *SpvType, bool EmitIR) {
635   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
636   assert(LLVMTy->isVectorTy());
637   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
638   Type *LLVMBaseTy = LLVMVecTy->getElementType();
639   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
640   auto ConstVec =
641       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
642   unsigned BW = getScalarOrVectorBitWidth(SpvType);
643   return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
644                                        ConstVec, BW,
645                                        SpvType->getOperand(2).getImm());
646 }
647 
648 Register
getOrCreateConstNullPtr(MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)649 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
650                                              SPIRVType *SpvType) {
651   const Type *Ty = getTypeForSPIRVType(SpvType);
652   unsigned AddressSpace = typeToAddressSpace(Ty);
653   Type *ElemTy = ::getPointeeType(Ty);
654   assert(ElemTy);
655   const Constant *CP = ConstantTargetNone::get(
656       dyn_cast<TargetExtType>(getTypedPointerWrapper(ElemTy, AddressSpace)));
657   Register Res = find(CP, CurMF);
658   if (Res.isValid())
659     return Res;
660 
661   LLT LLTy = LLT::pointer(AddressSpace, PointerSize);
662   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
663   CurMF->getRegInfo().setRegClass(Res, &SPIRV::pIDRegClass);
664   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
665 
666   const MachineInstr *NewMI =
667       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
668         return MIRBuilder.buildInstr(SPIRV::OpConstantNull)
669             .addDef(Res)
670             .addUse(getSPIRVTypeID(SpvType));
671       });
672   add(CP, NewMI);
673   return Res;
674 }
675 
676 Register
buildConstantSampler(Register ResReg,unsigned AddrMode,unsigned Param,unsigned FilerMode,MachineIRBuilder & MIRBuilder)677 SPIRVGlobalRegistry::buildConstantSampler(Register ResReg, unsigned AddrMode,
678                                           unsigned Param, unsigned FilerMode,
679                                           MachineIRBuilder &MIRBuilder) {
680   auto Sampler =
681       ResReg.isValid()
682           ? ResReg
683           : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
684   SPIRVType *TypeSampler = getOrCreateOpTypeSampler(MIRBuilder);
685   Register TypeSamplerReg = getSPIRVTypeID(TypeSampler);
686   // We cannot use createOpType() logic here, because of the
687   // GlobalISel/IRTranslator.cpp check for a tail call that expects that
688   // MIRBuilder.getInsertPt() has a previous instruction. If this constant is
689   // inserted as a result of "__translate_sampler_initializer()" this would
690   // break this IRTranslator assumption.
691   MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
692       .addDef(Sampler)
693       .addUse(TypeSamplerReg)
694       .addImm(AddrMode)
695       .addImm(Param)
696       .addImm(FilerMode);
697   return Sampler;
698 }
699 
buildGlobalVariable(Register ResVReg,SPIRVType * BaseType,StringRef Name,const GlobalValue * GV,SPIRV::StorageClass::StorageClass Storage,const MachineInstr * Init,bool IsConst,bool HasLinkageTy,SPIRV::LinkageType::LinkageType LinkageType,MachineIRBuilder & MIRBuilder,bool IsInstSelector)700 Register SPIRVGlobalRegistry::buildGlobalVariable(
701     Register ResVReg, SPIRVType *BaseType, StringRef Name,
702     const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
703     const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
704     SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
705     bool IsInstSelector) {
706   const GlobalVariable *GVar = nullptr;
707   if (GV) {
708     GVar = cast<const GlobalVariable>(GV);
709   } else {
710     // If GV is not passed explicitly, use the name to find or construct
711     // the global variable.
712     Module *M = MIRBuilder.getMF().getFunction().getParent();
713     GVar = M->getGlobalVariable(Name);
714     if (GVar == nullptr) {
715       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
716       // Module takes ownership of the global var.
717       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
718                                 GlobalValue::ExternalLinkage, nullptr,
719                                 Twine(Name));
720     }
721     GV = GVar;
722   }
723 
724   const MachineFunction *MF = &MIRBuilder.getMF();
725   Register Reg = find(GVar, MF);
726   if (Reg.isValid()) {
727     if (Reg != ResVReg)
728       MIRBuilder.buildCopy(ResVReg, Reg);
729     return ResVReg;
730   }
731 
732   auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
733                  .addDef(ResVReg)
734                  .addUse(getSPIRVTypeID(BaseType))
735                  .addImm(static_cast<uint32_t>(Storage));
736   if (Init != 0)
737     MIB.addUse(Init->getOperand(0).getReg());
738   // ISel may introduce a new register on this step, so we need to add it to
739   // DT and correct its type avoiding fails on the next stage.
740   if (IsInstSelector) {
741     const auto &Subtarget = CurMF->getSubtarget();
742     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
743                                      *Subtarget.getRegisterInfo(),
744                                      *Subtarget.getRegBankInfo());
745   }
746   add(GVar, MIB);
747 
748   Reg = MIB->getOperand(0).getReg();
749   addGlobalObject(GVar, MF, Reg);
750 
751   // Set to Reg the same type as ResVReg has.
752   auto MRI = MIRBuilder.getMRI();
753   if (Reg != ResVReg) {
754     LLT RegLLTy =
755         LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize());
756     MRI->setType(Reg, RegLLTy);
757     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
758   } else {
759     // Our knowledge about the type may be updated.
760     // If that's the case, we need to update a type
761     // associated with the register.
762     SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg);
763     if (!DefType || DefType != BaseType)
764       assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
765   }
766 
767   // If it's a global variable with name, output OpName for it.
768   if (GVar && GVar->hasName())
769     buildOpName(Reg, GVar->getName(), MIRBuilder);
770 
771   // Output decorations for the GV.
772   // TODO: maybe move to GenerateDecorations pass.
773   const SPIRVSubtarget &ST =
774       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
775   if (IsConst && !ST.isShader())
776     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
777 
778   if (GVar && GVar->getAlign().valueOrOne().value() != 1 && !ST.isShader()) {
779     unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
780     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
781   }
782 
783   if (HasLinkageTy)
784     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
785                     {static_cast<uint32_t>(LinkageType)}, Name);
786 
787   SPIRV::BuiltIn::BuiltIn BuiltInId;
788   if (getSpirvBuiltInIdByName(Name, BuiltInId))
789     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
790                     {static_cast<uint32_t>(BuiltInId)});
791 
792   // If it's a global variable with "spirv.Decorations" metadata node
793   // recognize it as a SPIR-V friendly LLVM IR and parse "spirv.Decorations"
794   // arguments.
795   MDNode *GVarMD = nullptr;
796   if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr)
797     buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD);
798 
799   return Reg;
800 }
801 
802 // Returns a name based on the Type. Notes that this does not look at
803 // decorations, and will return the same string for two types that are the same
804 // except for decorations.
getOrCreateGlobalVariableWithBinding(const SPIRVType * VarType,uint32_t Set,uint32_t Binding,StringRef Name,MachineIRBuilder & MIRBuilder)805 Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding(
806     const SPIRVType *VarType, uint32_t Set, uint32_t Binding, StringRef Name,
807     MachineIRBuilder &MIRBuilder) {
808   Register VarReg =
809       MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
810 
811   buildGlobalVariable(VarReg, VarType, Name, nullptr,
812                       getPointerStorageClass(VarType), nullptr, false, false,
813                       SPIRV::LinkageType::Import, MIRBuilder, false);
814 
815   buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::DescriptorSet, {Set});
816   buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::Binding, {Binding});
817   return VarReg;
818 }
819 
820 // TODO: Double check the calls to getOpTypeArray to make sure that `ElemType`
821 // is explicitly laid out when required.
getOpTypeArray(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,bool ExplicitLayoutRequired,bool EmitIR)822 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
823                                                SPIRVType *ElemType,
824                                                MachineIRBuilder &MIRBuilder,
825                                                bool ExplicitLayoutRequired,
826                                                bool EmitIR) {
827   assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
828          "Invalid array element type");
829   SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
830   SPIRVType *ArrayType = nullptr;
831   if (NumElems != 0) {
832     Register NumElementsVReg =
833         buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR);
834     ArrayType = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
835       return MIRBuilder.buildInstr(SPIRV::OpTypeArray)
836           .addDef(createTypeVReg(MIRBuilder))
837           .addUse(getSPIRVTypeID(ElemType))
838           .addUse(NumElementsVReg);
839     });
840   } else {
841     ArrayType = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
842       return MIRBuilder.buildInstr(SPIRV::OpTypeRuntimeArray)
843           .addDef(createTypeVReg(MIRBuilder))
844           .addUse(getSPIRVTypeID(ElemType));
845     });
846   }
847 
848   if (ExplicitLayoutRequired && !isResourceType(ElemType)) {
849     Type *ET = const_cast<Type *>(getTypeForSPIRVType(ElemType));
850     addArrayStrideDecorations(ArrayType->defs().begin()->getReg(), ET,
851                               MIRBuilder);
852   }
853 
854   return ArrayType;
855 }
856 
getOpTypeOpaque(const StructType * Ty,MachineIRBuilder & MIRBuilder)857 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
858                                                 MachineIRBuilder &MIRBuilder) {
859   assert(Ty->hasName());
860   const StringRef Name = Ty->hasName() ? Ty->getName() : "";
861   Register ResVReg = createTypeVReg(MIRBuilder);
862   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
863     auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
864     addStringImm(Name, MIB);
865     buildOpName(ResVReg, Name, MIRBuilder);
866     return MIB;
867   });
868 }
869 
getOpTypeStruct(const StructType * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,StructOffsetDecorator Decorator,bool EmitIR)870 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(
871     const StructType *Ty, MachineIRBuilder &MIRBuilder,
872     SPIRV::AccessQualifier::AccessQualifier AccQual,
873     StructOffsetDecorator Decorator, bool EmitIR) {
874   const SPIRVSubtarget &ST =
875       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
876   SmallVector<Register, 4> FieldTypes;
877   constexpr unsigned MaxWordCount = UINT16_MAX;
878   const size_t NumElements = Ty->getNumElements();
879 
880   size_t MaxNumElements = MaxWordCount - 2;
881   size_t SPIRVStructNumElements = NumElements;
882   if (NumElements > MaxNumElements) {
883     // Do adjustments for continued instructions.
884     SPIRVStructNumElements = MaxNumElements;
885     MaxNumElements = MaxWordCount - 1;
886   }
887 
888   for (const auto &Elem : Ty->elements()) {
889     SPIRVType *ElemTy = findSPIRVType(
890         toTypedPointer(Elem), MIRBuilder, AccQual,
891         /* ExplicitLayoutRequired= */ Decorator != nullptr, EmitIR);
892     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
893            "Invalid struct element type");
894     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
895   }
896   Register ResVReg = createTypeVReg(MIRBuilder);
897   if (Ty->hasName())
898     buildOpName(ResVReg, Ty->getName(), MIRBuilder);
899   if (Ty->isPacked() && !ST.isShader())
900     buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
901 
902   SPIRVType *SPVType =
903       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
904         auto MIBStruct =
905             MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
906         for (size_t I = 0; I < SPIRVStructNumElements; ++I)
907           MIBStruct.addUse(FieldTypes[I]);
908         for (size_t I = SPIRVStructNumElements; I < NumElements;
909              I += MaxNumElements) {
910           auto MIBCont =
911               MIRBuilder.buildInstr(SPIRV::OpTypeStructContinuedINTEL);
912           for (size_t J = I; J < std::min(I + MaxNumElements, NumElements); ++J)
913             MIBCont.addUse(FieldTypes[I]);
914         }
915         return MIBStruct;
916       });
917 
918   if (Decorator)
919     Decorator(SPVType->defs().begin()->getReg());
920 
921   return SPVType;
922 }
923 
getOrCreateSpecialType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual)924 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
925     const Type *Ty, MachineIRBuilder &MIRBuilder,
926     SPIRV::AccessQualifier::AccessQualifier AccQual) {
927   assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
928   return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);
929 }
930 
getOpTypePointer(SPIRV::StorageClass::StorageClass SC,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,Register Reg)931 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
932     SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
933     MachineIRBuilder &MIRBuilder, Register Reg) {
934   if (!Reg.isValid())
935     Reg = createTypeVReg(MIRBuilder);
936 
937   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
938     return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
939         .addDef(Reg)
940         .addImm(static_cast<uint32_t>(SC))
941         .addUse(getSPIRVTypeID(ElemType));
942   });
943 }
944 
getOpTypeForwardPointer(SPIRV::StorageClass::StorageClass SC,MachineIRBuilder & MIRBuilder)945 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
946     SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
947   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
948     return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
949         .addUse(createTypeVReg(MIRBuilder))
950         .addImm(static_cast<uint32_t>(SC));
951   });
952 }
953 
getOpTypeFunction(SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)954 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
955     SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
956     MachineIRBuilder &MIRBuilder) {
957   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
958     auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
959                    .addDef(createTypeVReg(MIRBuilder))
960                    .addUse(getSPIRVTypeID(RetType));
961     for (const SPIRVType *ArgType : ArgTypes)
962       MIB.addUse(getSPIRVTypeID(ArgType));
963     return MIB;
964   });
965 }
966 
getOrCreateOpTypeFunctionWithArgs(const Type * Ty,SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)967 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
968     const Type *Ty, SPIRVType *RetType,
969     const SmallVectorImpl<SPIRVType *> &ArgTypes,
970     MachineIRBuilder &MIRBuilder) {
971   if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF()))
972     return MI;
973   const MachineInstr *NewMI = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
974   add(Ty, false, NewMI);
975   return finishCreatingSPIRVType(Ty, NewMI);
976 }
977 
findSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,bool ExplicitLayoutRequired,bool EmitIR)978 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
979     const Type *Ty, MachineIRBuilder &MIRBuilder,
980     SPIRV::AccessQualifier::AccessQualifier AccQual,
981     bool ExplicitLayoutRequired, bool EmitIR) {
982   Ty = adjustIntTypeByWidth(Ty);
983   // TODO: findMI needs to know if a layout is required.
984   if (const MachineInstr *MI =
985           findMI(Ty, ExplicitLayoutRequired, &MIRBuilder.getMF()))
986     return MI;
987   if (auto It = ForwardPointerTypes.find(Ty); It != ForwardPointerTypes.end())
988     return It->second;
989   return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, ExplicitLayoutRequired,
990                                EmitIR);
991 }
992 
getSPIRVTypeID(const SPIRVType * SpirvType) const993 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
994   assert(SpirvType && "Attempting to get type id for nullptr type.");
995   if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer ||
996       SpirvType->getOpcode() == SPIRV::OpTypeStructContinuedINTEL)
997     return SpirvType->uses().begin()->getReg();
998   return SpirvType->defs().begin()->getReg();
999 }
1000 
1001 // We need to use a new LLVM integer type if there is a mismatch between
1002 // number of bits in LLVM and SPIRV integer types to let DuplicateTracker
1003 // ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
1004 // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
1005 // same "OpTypeInt 8" type for a series of LLVM integer types with number of
1006 // bits less than 8. This would lead to duplicate type definitions
1007 // eventually due to the method that DuplicateTracker utilizes to reason
1008 // about uniqueness of type records.
adjustIntTypeByWidth(const Type * Ty) const1009 const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
1010   if (auto IType = dyn_cast<IntegerType>(Ty)) {
1011     unsigned SrcBitWidth = IType->getBitWidth();
1012     if (SrcBitWidth > 1) {
1013       unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
1014       // Maybe change source LLVM type to keep DuplicateTracker consistent.
1015       if (SrcBitWidth != BitWidth)
1016         Ty = IntegerType::get(Ty->getContext(), BitWidth);
1017     }
1018   }
1019   return Ty;
1020 }
1021 
createSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,bool ExplicitLayoutRequired,bool EmitIR)1022 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
1023     const Type *Ty, MachineIRBuilder &MIRBuilder,
1024     SPIRV::AccessQualifier::AccessQualifier AccQual,
1025     bool ExplicitLayoutRequired, bool EmitIR) {
1026   if (isSpecialOpaqueType(Ty))
1027     return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
1028 
1029   if (const MachineInstr *MI =
1030           findMI(Ty, ExplicitLayoutRequired, &MIRBuilder.getMF()))
1031     return MI;
1032 
1033   if (auto IType = dyn_cast<IntegerType>(Ty)) {
1034     const unsigned Width = IType->getBitWidth();
1035     return Width == 1 ? getOpTypeBool(MIRBuilder)
1036                       : getOpTypeInt(Width, MIRBuilder, false);
1037   }
1038   if (Ty->isFloatingPointTy())
1039     return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
1040   if (Ty->isVoidTy())
1041     return getOpTypeVoid(MIRBuilder);
1042   if (Ty->isVectorTy()) {
1043     SPIRVType *El =
1044         findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder,
1045                       AccQual, ExplicitLayoutRequired, EmitIR);
1046     return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
1047                            MIRBuilder);
1048   }
1049   if (Ty->isArrayTy()) {
1050     SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder,
1051                                   AccQual, ExplicitLayoutRequired, EmitIR);
1052     return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder,
1053                           ExplicitLayoutRequired, EmitIR);
1054   }
1055   if (auto SType = dyn_cast<StructType>(Ty)) {
1056     if (SType->isOpaque())
1057       return getOpTypeOpaque(SType, MIRBuilder);
1058 
1059     StructOffsetDecorator Decorator = nullptr;
1060     if (ExplicitLayoutRequired) {
1061       Decorator = [&MIRBuilder, SType, this](Register Reg) {
1062         addStructOffsetDecorations(Reg, const_cast<StructType *>(SType),
1063                                    MIRBuilder);
1064       };
1065     }
1066     return getOpTypeStruct(SType, MIRBuilder, AccQual, Decorator, EmitIR);
1067   }
1068   if (auto FType = dyn_cast<FunctionType>(Ty)) {
1069     SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder,
1070                                      AccQual, ExplicitLayoutRequired, EmitIR);
1071     SmallVector<SPIRVType *, 4> ParamTypes;
1072     for (const auto &ParamTy : FType->params())
1073       ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual,
1074                                          ExplicitLayoutRequired, EmitIR));
1075     return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
1076   }
1077 
1078   unsigned AddrSpace = typeToAddressSpace(Ty);
1079   SPIRVType *SpvElementType = nullptr;
1080   if (Type *ElemTy = ::getPointeeType(Ty))
1081     SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR);
1082   else
1083     SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
1084 
1085   // Get access to information about available extensions
1086   const SPIRVSubtarget *ST =
1087       static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
1088   auto SC = addressSpaceToStorageClass(AddrSpace, *ST);
1089 
1090   Type *ElemTy = ::getPointeeType(Ty);
1091   if (!ElemTy) {
1092     ElemTy = Type::getInt8Ty(MIRBuilder.getContext());
1093   }
1094 
1095   // If we have forward pointer associated with this type, use its register
1096   // operand to create OpTypePointer.
1097   if (auto It = ForwardPointerTypes.find(Ty); It != ForwardPointerTypes.end()) {
1098     Register Reg = getSPIRVTypeID(It->second);
1099     // TODO: what does getOpTypePointer do?
1100     return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
1101   }
1102 
1103   return getOrCreateSPIRVPointerType(ElemTy, MIRBuilder, SC);
1104 }
1105 
restOfCreateSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool ExplicitLayoutRequired,bool EmitIR)1106 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
1107     const Type *Ty, MachineIRBuilder &MIRBuilder,
1108     SPIRV::AccessQualifier::AccessQualifier AccessQual,
1109     bool ExplicitLayoutRequired, bool EmitIR) {
1110   // TODO: Could this create a problem if one requires an explicit layout, and
1111   // the next time it does not?
1112   if (TypesInProcessing.count(Ty) && !isPointerTyOrWrapper(Ty))
1113     return nullptr;
1114   TypesInProcessing.insert(Ty);
1115   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
1116                                          ExplicitLayoutRequired, EmitIR);
1117   TypesInProcessing.erase(Ty);
1118   VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
1119 
1120   // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
1121   // Is that a problem?
1122   SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
1123 
1124   if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer ||
1125       findMI(Ty, false, &MIRBuilder.getMF()) || isSpecialOpaqueType(Ty))
1126     return SpirvType;
1127 
1128   if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
1129       ExtTy && isTypedPointerWrapper(ExtTy))
1130     add(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0), SpirvType);
1131   else if (!isPointerTy(Ty))
1132     add(Ty, ExplicitLayoutRequired, SpirvType);
1133   else if (isTypedPointerTy(Ty))
1134     add(cast<TypedPointerType>(Ty)->getElementType(),
1135         getPointerAddressSpace(Ty), SpirvType);
1136   else
1137     add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
1138         getPointerAddressSpace(Ty), SpirvType);
1139   return SpirvType;
1140 }
1141 
1142 SPIRVType *
getSPIRVTypeForVReg(Register VReg,const MachineFunction * MF) const1143 SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
1144                                          const MachineFunction *MF) const {
1145   auto t = VRegToTypeMap.find(MF ? MF : CurMF);
1146   if (t != VRegToTypeMap.end()) {
1147     auto tt = t->second.find(VReg);
1148     if (tt != t->second.end())
1149       return tt->second;
1150   }
1151   return nullptr;
1152 }
1153 
getResultType(Register VReg,MachineFunction * MF)1154 SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg,
1155                                               MachineFunction *MF) {
1156   if (!MF)
1157     MF = CurMF;
1158   MachineInstr *Instr = getVRegDef(MF->getRegInfo(), VReg);
1159   return getSPIRVTypeForVReg(Instr->getOperand(1).getReg(), MF);
1160 }
1161 
getOrCreateSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool ExplicitLayoutRequired,bool EmitIR)1162 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
1163     const Type *Ty, MachineIRBuilder &MIRBuilder,
1164     SPIRV::AccessQualifier::AccessQualifier AccessQual,
1165     bool ExplicitLayoutRequired, bool EmitIR) {
1166   const MachineFunction *MF = &MIRBuilder.getMF();
1167   Register Reg;
1168   if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
1169       ExtTy && isTypedPointerWrapper(ExtTy))
1170     Reg = find(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0), MF);
1171   else if (!isPointerTy(Ty))
1172     Reg = find(Ty = adjustIntTypeByWidth(Ty), ExplicitLayoutRequired, MF);
1173   else if (isTypedPointerTy(Ty))
1174     Reg = find(cast<TypedPointerType>(Ty)->getElementType(),
1175                getPointerAddressSpace(Ty), MF);
1176   else
1177     Reg = find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
1178                getPointerAddressSpace(Ty), MF);
1179   if (Reg.isValid() && !isSpecialOpaqueType(Ty))
1180     return getSPIRVTypeForVReg(Reg);
1181 
1182   TypesInProcessing.clear();
1183   SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual,
1184                                          ExplicitLayoutRequired, EmitIR);
1185   // Create normal pointer types for the corresponding OpTypeForwardPointers.
1186   for (auto &CU : ForwardPointerTypes) {
1187     // Pointer type themselves do not require an explicit layout. The types
1188     // they pointer to might, but that is taken care of when creating the type.
1189     bool PtrNeedsLayout = false;
1190     const Type *Ty2 = CU.first;
1191     SPIRVType *STy2 = CU.second;
1192     if ((Reg = find(Ty2, PtrNeedsLayout, MF)).isValid())
1193       STy2 = getSPIRVTypeForVReg(Reg);
1194     else
1195       STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, PtrNeedsLayout,
1196                                    EmitIR);
1197     if (Ty == Ty2)
1198       STy = STy2;
1199   }
1200   ForwardPointerTypes.clear();
1201   return STy;
1202 }
1203 
isScalarOfType(Register VReg,unsigned TypeOpcode) const1204 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
1205                                          unsigned TypeOpcode) const {
1206   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1207   assert(Type && "isScalarOfType VReg has no type assigned");
1208   return Type->getOpcode() == TypeOpcode;
1209 }
1210 
isScalarOrVectorOfType(Register VReg,unsigned TypeOpcode) const1211 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
1212                                                  unsigned TypeOpcode) const {
1213   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1214   assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
1215   if (Type->getOpcode() == TypeOpcode)
1216     return true;
1217   if (Type->getOpcode() == SPIRV::OpTypeVector) {
1218     Register ScalarTypeVReg = Type->getOperand(1).getReg();
1219     SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
1220     return ScalarType->getOpcode() == TypeOpcode;
1221   }
1222   return false;
1223 }
1224 
isResourceType(SPIRVType * Type) const1225 bool SPIRVGlobalRegistry::isResourceType(SPIRVType *Type) const {
1226   switch (Type->getOpcode()) {
1227   case SPIRV::OpTypeImage:
1228   case SPIRV::OpTypeSampler:
1229   case SPIRV::OpTypeSampledImage:
1230     return true;
1231   case SPIRV::OpTypeStruct:
1232     return hasBlockDecoration(Type);
1233   default:
1234     return false;
1235   }
1236   return false;
1237 }
1238 unsigned
getScalarOrVectorComponentCount(Register VReg) const1239 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
1240   return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg));
1241 }
1242 
1243 unsigned
getScalarOrVectorComponentCount(SPIRVType * Type) const1244 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
1245   if (!Type)
1246     return 0;
1247   return Type->getOpcode() == SPIRV::OpTypeVector
1248              ? static_cast<unsigned>(Type->getOperand(2).getImm())
1249              : 1;
1250 }
1251 
1252 SPIRVType *
getScalarOrVectorComponentType(Register VReg) const1253 SPIRVGlobalRegistry::getScalarOrVectorComponentType(Register VReg) const {
1254   return getScalarOrVectorComponentType(getSPIRVTypeForVReg(VReg));
1255 }
1256 
1257 SPIRVType *
getScalarOrVectorComponentType(SPIRVType * Type) const1258 SPIRVGlobalRegistry::getScalarOrVectorComponentType(SPIRVType *Type) const {
1259   if (!Type)
1260     return nullptr;
1261   Register ScalarReg = Type->getOpcode() == SPIRV::OpTypeVector
1262                            ? Type->getOperand(1).getReg()
1263                            : Type->getOperand(0).getReg();
1264   SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarReg);
1265   assert(isScalarOrVectorOfType(Type->getOperand(0).getReg(),
1266                                 ScalarType->getOpcode()));
1267   return ScalarType;
1268 }
1269 
1270 unsigned
getScalarOrVectorBitWidth(const SPIRVType * Type) const1271 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
1272   assert(Type && "Invalid Type pointer");
1273   if (Type->getOpcode() == SPIRV::OpTypeVector) {
1274     auto EleTypeReg = Type->getOperand(1).getReg();
1275     Type = getSPIRVTypeForVReg(EleTypeReg);
1276   }
1277   if (Type->getOpcode() == SPIRV::OpTypeInt ||
1278       Type->getOpcode() == SPIRV::OpTypeFloat)
1279     return Type->getOperand(1).getImm();
1280   if (Type->getOpcode() == SPIRV::OpTypeBool)
1281     return 1;
1282   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
1283 }
1284 
getNumScalarOrVectorTotalBitWidth(const SPIRVType * Type) const1285 unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
1286     const SPIRVType *Type) const {
1287   assert(Type && "Invalid Type pointer");
1288   unsigned NumElements = 1;
1289   if (Type->getOpcode() == SPIRV::OpTypeVector) {
1290     NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
1291     Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
1292   }
1293   return Type->getOpcode() == SPIRV::OpTypeInt ||
1294                  Type->getOpcode() == SPIRV::OpTypeFloat
1295              ? NumElements * Type->getOperand(1).getImm()
1296              : 0;
1297 }
1298 
// Return the OpTypeInt behind Type: Type itself if it is an integer type, or
// the component type of an OpTypeVector if that is an integer type. Returns
// nullptr in every other case, including a null Type.
const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
    const SPIRVType *Type) const {
  const SPIRVType *ScalarTy = Type;
  if (ScalarTy && ScalarTy->getOpcode() == SPIRV::OpTypeVector)
    ScalarTy = getSPIRVTypeForVReg(ScalarTy->getOperand(1).getReg());

  if (!ScalarTy || ScalarTy->getOpcode() != SPIRV::OpTypeInt)
    return nullptr;
  return ScalarTy;
}
1305 
isScalarOrVectorSigned(const SPIRVType * Type) const1306 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
1307   const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
1308   return IntType && IntType->getOperand(2).getImm() != 0;
1309 }
1310 
getPointeeType(SPIRVType * PtrType)1311 SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
1312   return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
1313              ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
1314              : nullptr;
1315 }
1316 
getPointeeTypeOp(Register PtrReg)1317 unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
1318   SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg));
1319   return ElemType ? ElemType->getOpcode() : 0;
1320 }
1321 
isBitcastCompatible(const SPIRVType * Type1,const SPIRVType * Type2) const1322 bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
1323                                               const SPIRVType *Type2) const {
1324   if (!Type1 || !Type2)
1325     return false;
1326   auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
1327   // Ignore difference between <1.5 and >=1.5 protocol versions:
1328   // it's valid if either Result Type or Operand is a pointer, and the other
1329   // is a pointer, an integer scalar, or an integer vector.
1330   if (Op1 == SPIRV::OpTypePointer &&
1331       (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
1332     return true;
1333   if (Op2 == SPIRV::OpTypePointer &&
1334       (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
1335     return true;
1336   unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
1337            Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
1338   return Bits1 > 0 && Bits1 == Bits2;
1339 }
1340 
1341 SPIRV::StorageClass::StorageClass
getPointerStorageClass(Register VReg) const1342 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
1343   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1344   assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
1345          Type->getOperand(1).isImm() && "Pointer type is expected");
1346   return getPointerStorageClass(Type);
1347 }
1348 
1349 SPIRV::StorageClass::StorageClass
getPointerStorageClass(const SPIRVType * Type) const1350 SPIRVGlobalRegistry::getPointerStorageClass(const SPIRVType *Type) const {
1351   return static_cast<SPIRV::StorageClass::StorageClass>(
1352       Type->getOperand(1).getImm());
1353 }
1354 
getOrCreateVulkanBufferType(MachineIRBuilder & MIRBuilder,Type * ElemType,SPIRV::StorageClass::StorageClass SC,bool IsWritable,bool EmitIr)1355 SPIRVType *SPIRVGlobalRegistry::getOrCreateVulkanBufferType(
1356     MachineIRBuilder &MIRBuilder, Type *ElemType,
1357     SPIRV::StorageClass::StorageClass SC, bool IsWritable, bool EmitIr) {
1358   auto Key = SPIRV::irhandle_vkbuffer(ElemType, SC, IsWritable);
1359   if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1360     return MI;
1361 
1362   bool ExplicitLayoutRequired = storageClassRequiresExplictLayout(SC);
1363   // We need to get the SPIR-V type for the element here, so we can add the
1364   // decoration to it.
1365   auto *T = StructType::create(ElemType);
1366   auto *BlockType =
1367       getOrCreateSPIRVType(T, MIRBuilder, SPIRV::AccessQualifier::None,
1368                            ExplicitLayoutRequired, EmitIr);
1369 
1370   buildOpDecorate(BlockType->defs().begin()->getReg(), MIRBuilder,
1371                   SPIRV::Decoration::Block, {});
1372 
1373   if (!IsWritable) {
1374     buildOpMemberDecorate(BlockType->defs().begin()->getReg(), MIRBuilder,
1375                           SPIRV::Decoration::NonWritable, 0, {});
1376   }
1377 
1378   SPIRVType *R = getOrCreateSPIRVPointerTypeInternal(BlockType, MIRBuilder, SC);
1379   add(Key, R);
1380   return R;
1381 }
1382 
getOrCreateLayoutType(MachineIRBuilder & MIRBuilder,const TargetExtType * T,bool EmitIr)1383 SPIRVType *SPIRVGlobalRegistry::getOrCreateLayoutType(
1384     MachineIRBuilder &MIRBuilder, const TargetExtType *T, bool EmitIr) {
1385   auto Key = SPIRV::handle(T);
1386   if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1387     return MI;
1388 
1389   StructType *ST = cast<StructType>(T->getTypeParameter(0));
1390   ArrayRef<uint32_t> Offsets = T->int_params().slice(1);
1391   assert(ST->getNumElements() == Offsets.size());
1392 
1393   StructOffsetDecorator Decorator = [&MIRBuilder, &Offsets](Register Reg) {
1394     for (uint32_t I = 0; I < Offsets.size(); ++I) {
1395       buildOpMemberDecorate(Reg, MIRBuilder, SPIRV::Decoration::Offset, I,
1396                             {Offsets[I]});
1397     }
1398   };
1399 
1400   // We need a new OpTypeStruct instruction because decorations will be
1401   // different from a struct with an explicit layout created from a different
1402   // entry point.
1403   SPIRVType *SPIRVStructType = getOpTypeStruct(
1404       ST, MIRBuilder, SPIRV::AccessQualifier::None, Decorator, EmitIr);
1405   add(Key, SPIRVStructType);
1406   return SPIRVStructType;
1407 }
1408 
getImageType(const TargetExtType * ExtensionType,const SPIRV::AccessQualifier::AccessQualifier Qualifier,MachineIRBuilder & MIRBuilder)1409 SPIRVType *SPIRVGlobalRegistry::getImageType(
1410     const TargetExtType *ExtensionType,
1411     const SPIRV::AccessQualifier::AccessQualifier Qualifier,
1412     MachineIRBuilder &MIRBuilder) {
1413   assert(ExtensionType->getNumTypeParameters() == 1 &&
1414          "SPIR-V image builtin type must have sampled type parameter!");
1415   const SPIRVType *SampledType =
1416       getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder,
1417                            SPIRV::AccessQualifier::ReadWrite, true);
1418   assert((ExtensionType->getNumIntParameters() == 7 ||
1419           ExtensionType->getNumIntParameters() == 6) &&
1420          "Invalid number of parameters for SPIR-V image builtin!");
1421 
1422   SPIRV::AccessQualifier::AccessQualifier accessQualifier =
1423       SPIRV::AccessQualifier::None;
1424   if (ExtensionType->getNumIntParameters() == 7) {
1425     accessQualifier = Qualifier == SPIRV::AccessQualifier::WriteOnly
1426                           ? SPIRV::AccessQualifier::WriteOnly
1427                           : SPIRV::AccessQualifier::AccessQualifier(
1428                                 ExtensionType->getIntParameter(6));
1429   }
1430 
1431   // Create or get an existing type from GlobalRegistry.
1432   SPIRVType *R = getOrCreateOpTypeImage(
1433       MIRBuilder, SampledType,
1434       SPIRV::Dim::Dim(ExtensionType->getIntParameter(0)),
1435       ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
1436       ExtensionType->getIntParameter(3), ExtensionType->getIntParameter(4),
1437       SPIRV::ImageFormat::ImageFormat(ExtensionType->getIntParameter(5)),
1438       accessQualifier);
1439   SPIRVToLLVMType[R] = ExtensionType;
1440   return R;
1441 }
1442 
getOrCreateOpTypeImage(MachineIRBuilder & MIRBuilder,SPIRVType * SampledType,SPIRV::Dim::Dim Dim,uint32_t Depth,uint32_t Arrayed,uint32_t Multisampled,uint32_t Sampled,SPIRV::ImageFormat::ImageFormat ImageFormat,SPIRV::AccessQualifier::AccessQualifier AccessQual)1443 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
1444     MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
1445     uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
1446     SPIRV::ImageFormat::ImageFormat ImageFormat,
1447     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1448   auto Key = SPIRV::irhandle_image(SPIRVToLLVMType.lookup(SampledType), Dim,
1449                                    Depth, Arrayed, Multisampled, Sampled,
1450                                    ImageFormat, AccessQual);
1451   if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1452     return MI;
1453   const MachineInstr *NewMI =
1454       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1455         auto MIB =
1456             MIRBuilder.buildInstr(SPIRV::OpTypeImage)
1457                 .addDef(createTypeVReg(MIRBuilder))
1458                 .addUse(getSPIRVTypeID(SampledType))
1459                 .addImm(Dim)
1460                 .addImm(Depth)   // Depth (whether or not it is a Depth image).
1461                 .addImm(Arrayed) // Arrayed.
1462                 .addImm(Multisampled) // Multisampled (0 = only single-sample).
1463                 .addImm(Sampled)      // Sampled (0 = usage known at runtime).
1464                 .addImm(ImageFormat);
1465         if (AccessQual != SPIRV::AccessQualifier::None)
1466           MIB.addImm(AccessQual);
1467         return MIB;
1468       });
1469   add(Key, NewMI);
1470   return NewMI;
1471 }
1472 
1473 SPIRVType *
getOrCreateOpTypeSampler(MachineIRBuilder & MIRBuilder)1474 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
1475   auto Key = SPIRV::irhandle_sampler();
1476   const MachineFunction *MF = &MIRBuilder.getMF();
1477   if (const MachineInstr *MI = findMI(Key, MF))
1478     return MI;
1479   const MachineInstr *NewMI =
1480       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1481         return MIRBuilder.buildInstr(SPIRV::OpTypeSampler)
1482             .addDef(createTypeVReg(MIRBuilder));
1483       });
1484   add(Key, NewMI);
1485   return NewMI;
1486 }
1487 
getOrCreateOpTypePipe(MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual)1488 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
1489     MachineIRBuilder &MIRBuilder,
1490     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1491   auto Key = SPIRV::irhandle_pipe(AccessQual);
1492   if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1493     return MI;
1494   const MachineInstr *NewMI =
1495       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1496         return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
1497             .addDef(createTypeVReg(MIRBuilder))
1498             .addImm(AccessQual);
1499       });
1500   add(Key, NewMI);
1501   return NewMI;
1502 }
1503 
getOrCreateOpTypeDeviceEvent(MachineIRBuilder & MIRBuilder)1504 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
1505     MachineIRBuilder &MIRBuilder) {
1506   auto Key = SPIRV::irhandle_event();
1507   if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1508     return MI;
1509   const MachineInstr *NewMI =
1510       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1511         return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent)
1512             .addDef(createTypeVReg(MIRBuilder));
1513       });
1514   add(Key, NewMI);
1515   return NewMI;
1516 }
1517 
getOrCreateOpTypeSampledImage(SPIRVType * ImageType,MachineIRBuilder & MIRBuilder)1518 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
1519     SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
1520   auto Key = SPIRV::irhandle_sampled_image(
1521       SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
1522           ImageType->getOperand(1).getReg())),
1523       ImageType);
1524   if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
1525     return MI;
1526   const MachineInstr *NewMI =
1527       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1528         return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
1529             .addDef(createTypeVReg(MIRBuilder))
1530             .addUse(getSPIRVTypeID(ImageType));
1531       });
1532   add(Key, NewMI);
1533   return NewMI;
1534 }
1535 
getOrCreateOpTypeCoopMatr(MachineIRBuilder & MIRBuilder,const TargetExtType * ExtensionType,const SPIRVType * ElemType,uint32_t Scope,uint32_t Rows,uint32_t Columns,uint32_t Use,bool EmitIR)1536 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
1537     MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
1538     const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
1539     uint32_t Use, bool EmitIR) {
1540   if (const MachineInstr *MI =
1541           findMI(ExtensionType, false, &MIRBuilder.getMF()))
1542     return MI;
1543   const MachineInstr *NewMI =
1544       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1545         SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
1546         const Type *ET = getTypeForSPIRVType(ElemType);
1547         if (ET->isIntegerTy() && ET->getIntegerBitWidth() == 4 &&
1548             cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
1549                 .canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
1550           MIRBuilder.buildInstr(SPIRV::OpCapability)
1551               .addImm(SPIRV::Capability::Int4CooperativeMatrixINTEL);
1552         }
1553         return MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
1554             .addDef(createTypeVReg(MIRBuilder))
1555             .addUse(getSPIRVTypeID(ElemType))
1556             .addUse(buildConstantInt(Scope, MIRBuilder, SpvTypeInt32, EmitIR))
1557             .addUse(buildConstantInt(Rows, MIRBuilder, SpvTypeInt32, EmitIR))
1558             .addUse(buildConstantInt(Columns, MIRBuilder, SpvTypeInt32, EmitIR))
1559             .addUse(buildConstantInt(Use, MIRBuilder, SpvTypeInt32, EmitIR));
1560       });
1561   add(ExtensionType, false, NewMI);
1562   return NewMI;
1563 }
1564 
getOrCreateOpTypeByOpcode(const Type * Ty,MachineIRBuilder & MIRBuilder,unsigned Opcode)1565 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
1566     const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
1567   if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF()))
1568     return MI;
1569   const MachineInstr *NewMI =
1570       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1571         return MIRBuilder.buildInstr(Opcode).addDef(createTypeVReg(MIRBuilder));
1572       });
1573   add(Ty, false, NewMI);
1574   return NewMI;
1575 }
1576 
getOrCreateUnknownType(const Type * Ty,MachineIRBuilder & MIRBuilder,unsigned Opcode,const ArrayRef<MCOperand> Operands)1577 SPIRVType *SPIRVGlobalRegistry::getOrCreateUnknownType(
1578     const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode,
1579     const ArrayRef<MCOperand> Operands) {
1580   if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF()))
1581     return MI;
1582   Register ResVReg = createTypeVReg(MIRBuilder);
1583   const MachineInstr *NewMI =
1584       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1585         MachineInstrBuilder MIB = MIRBuilder.buildInstr(SPIRV::UNKNOWN_type)
1586                                       .addDef(ResVReg)
1587                                       .addImm(Opcode);
1588         for (MCOperand Operand : Operands) {
1589           if (Operand.isReg()) {
1590             MIB.addUse(Operand.getReg());
1591           } else if (Operand.isImm()) {
1592             MIB.addImm(Operand.getImm());
1593           }
1594         }
1595         return MIB;
1596       });
1597   add(Ty, false, NewMI);
1598   return NewMI;
1599 }
1600 
1601 // Returns nullptr if unable to recognize SPIRV type name
getOrCreateSPIRVTypeByName(StringRef TypeStr,MachineIRBuilder & MIRBuilder,bool EmitIR,SPIRV::StorageClass::StorageClass SC,SPIRV::AccessQualifier::AccessQualifier AQ)1602 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
1603     StringRef TypeStr, MachineIRBuilder &MIRBuilder, bool EmitIR,
1604     SPIRV::StorageClass::StorageClass SC,
1605     SPIRV::AccessQualifier::AccessQualifier AQ) {
1606   unsigned VecElts = 0;
1607   auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
1608 
1609   // Parse strings representing either a SPIR-V or OpenCL builtin type.
1610   if (hasBuiltinTypePrefix(TypeStr))
1611     return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType(
1612                                     TypeStr.str(), MIRBuilder.getContext()),
1613                                 MIRBuilder, AQ, false, true);
1614 
1615   // Parse type name in either "typeN" or "type vector[N]" format, where
1616   // N is the number of elements of the vector.
1617   Type *Ty;
1618 
1619   Ty = parseBasicTypeName(TypeStr, Ctx);
1620   if (!Ty)
1621     // Unable to recognize SPIRV type name
1622     return nullptr;
1623 
1624   const SPIRVType *SpirvTy =
1625       getOrCreateSPIRVType(Ty, MIRBuilder, AQ, false, true);
1626 
1627   // Handle "type*" or  "type* vector[N]".
1628   if (TypeStr.consume_front("*"))
1629     SpirvTy = getOrCreateSPIRVPointerType(Ty, MIRBuilder, SC);
1630 
1631   // Handle "typeN*" or  "type vector[N]*".
1632   bool IsPtrToVec = TypeStr.consume_back("*");
1633 
1634   if (TypeStr.consume_front(" vector[")) {
1635     TypeStr = TypeStr.substr(0, TypeStr.find(']'));
1636   }
1637   TypeStr.getAsInteger(10, VecElts);
1638   if (VecElts > 0)
1639     SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder, EmitIR);
1640 
1641   if (IsPtrToVec)
1642     SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1643 
1644   return SpirvTy;
1645 }
1646 
1647 SPIRVType *
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineIRBuilder & MIRBuilder)1648 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
1649                                                  MachineIRBuilder &MIRBuilder) {
1650   return getOrCreateSPIRVType(
1651       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
1652       MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, true);
1653 }
1654 
finishCreatingSPIRVType(const Type * LLVMTy,SPIRVType * SpirvType)1655 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
1656                                                         SPIRVType *SpirvType) {
1657   assert(CurMF == SpirvType->getMF());
1658   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
1659   SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
1660   return SpirvType;
1661 }
1662 
getOrCreateSPIRVType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII,unsigned SPIRVOPcode,Type * Ty)1663 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
1664                                                      MachineInstr &I,
1665                                                      const SPIRVInstrInfo &TII,
1666                                                      unsigned SPIRVOPcode,
1667                                                      Type *Ty) {
1668   if (const MachineInstr *MI = findMI(Ty, false, CurMF))
1669     return MI;
1670   MachineBasicBlock &DepMBB = I.getMF()->front();
1671   MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
1672   const MachineInstr *NewMI =
1673       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1674         return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1675                        MIRBuilder.getDL(), TII.get(SPIRVOPcode))
1676             .addDef(createTypeVReg(CurMF->getRegInfo()))
1677             .addImm(BitWidth)
1678             .addImm(0);
1679       });
1680   add(Ty, false, NewMI);
1681   return finishCreatingSPIRVType(Ty, NewMI);
1682 }
1683 
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII)1684 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
1685     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1686   // Maybe adjust bit width to keep DuplicateTracker consistent. Without
1687   // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
1688   // example, the same "OpTypeInt 8" type for a series of LLVM integer types
1689   // with number of bits less than 8, causing duplicate type definitions.
1690   if (BitWidth > 1)
1691     BitWidth = adjustOpTypeIntWidth(BitWidth);
1692   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
1693   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
1694 }
1695 
getOrCreateSPIRVFloatType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII)1696 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
1697     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1698   LLVMContext &Ctx = CurMF->getFunction().getContext();
1699   Type *LLVMTy;
1700   switch (BitWidth) {
1701   case 16:
1702     LLVMTy = Type::getHalfTy(Ctx);
1703     break;
1704   case 32:
1705     LLVMTy = Type::getFloatTy(Ctx);
1706     break;
1707   case 64:
1708     LLVMTy = Type::getDoubleTy(Ctx);
1709     break;
1710   default:
1711     llvm_unreachable("Bit width is of unexpected size.");
1712   }
1713   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
1714 }
1715 
1716 SPIRVType *
getOrCreateSPIRVBoolType(MachineIRBuilder & MIRBuilder,bool EmitIR)1717 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder,
1718                                               bool EmitIR) {
1719   return getOrCreateSPIRVType(
1720       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
1721       MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, EmitIR);
1722 }
1723 
1724 SPIRVType *
getOrCreateSPIRVBoolType(MachineInstr & I,const SPIRVInstrInfo & TII)1725 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
1726                                               const SPIRVInstrInfo &TII) {
1727   Type *Ty = IntegerType::get(CurMF->getFunction().getContext(), 1);
1728   if (const MachineInstr *MI = findMI(Ty, false, CurMF))
1729     return MI;
1730   MachineBasicBlock &DepMBB = I.getMF()->front();
1731   MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
1732   const MachineInstr *NewMI =
1733       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1734         return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1735                        MIRBuilder.getDL(), TII.get(SPIRV::OpTypeBool))
1736             .addDef(createTypeVReg(CurMF->getRegInfo()));
1737       });
1738   add(Ty, false, NewMI);
1739   return finishCreatingSPIRVType(Ty, NewMI);
1740 }
1741 
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineIRBuilder & MIRBuilder,bool EmitIR)1742 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1743     SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder,
1744     bool EmitIR) {
1745   return getOrCreateSPIRVType(
1746       FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1747                            NumElements),
1748       MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, EmitIR);
1749 }
1750 
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineInstr & I,const SPIRVInstrInfo & TII)1751 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1752     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1753     const SPIRVInstrInfo &TII) {
1754   Type *Ty = FixedVectorType::get(
1755       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1756   if (const MachineInstr *MI = findMI(Ty, false, CurMF))
1757     return MI;
1758   MachineInstr *DepMI = const_cast<MachineInstr *>(BaseType);
1759   MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
1760   const MachineInstr *NewMI =
1761       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1762         return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1763                        MIRBuilder.getDL(), TII.get(SPIRV::OpTypeVector))
1764             .addDef(createTypeVReg(CurMF->getRegInfo()))
1765             .addUse(getSPIRVTypeID(BaseType))
1766             .addImm(NumElements);
1767       });
1768   add(Ty, false, NewMI);
1769   return finishCreatingSPIRVType(Ty, NewMI);
1770 }
1771 
getOrCreateSPIRVPointerType(const Type * BaseType,MachineInstr & I,SPIRV::StorageClass::StorageClass SC)1772 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1773     const Type *BaseType, MachineInstr &I,
1774     SPIRV::StorageClass::StorageClass SC) {
1775   MachineIRBuilder MIRBuilder(I);
1776   return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
1777 }
1778 
getOrCreateSPIRVPointerType(const Type * BaseType,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC)1779 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1780     const Type *BaseType, MachineIRBuilder &MIRBuilder,
1781     SPIRV::StorageClass::StorageClass SC) {
1782   // TODO: Need to check if EmitIr should always be true.
1783   SPIRVType *SpirvBaseType = getOrCreateSPIRVType(
1784       BaseType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
1785       storageClassRequiresExplictLayout(SC), true);
1786   assert(SpirvBaseType);
1787   return getOrCreateSPIRVPointerTypeInternal(SpirvBaseType, MIRBuilder, SC);
1788 }
1789 
changePointerStorageClass(SPIRVType * PtrType,SPIRV::StorageClass::StorageClass SC,MachineInstr & I)1790 SPIRVType *SPIRVGlobalRegistry::changePointerStorageClass(
1791     SPIRVType *PtrType, SPIRV::StorageClass::StorageClass SC, MachineInstr &I) {
1792   [[maybe_unused]] SPIRV::StorageClass::StorageClass OldSC =
1793       getPointerStorageClass(PtrType);
1794   assert(storageClassRequiresExplictLayout(OldSC) ==
1795          storageClassRequiresExplictLayout(SC));
1796 
1797   SPIRVType *PointeeType = getPointeeType(PtrType);
1798   MachineIRBuilder MIRBuilder(I);
1799   return getOrCreateSPIRVPointerTypeInternal(PointeeType, MIRBuilder, SC);
1800 }
1801 
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC)1802 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1803     SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1804     SPIRV::StorageClass::StorageClass SC) {
1805   const Type *LLVMType = getTypeForSPIRVType(BaseType);
1806   assert(!storageClassRequiresExplictLayout(SC));
1807   SPIRVType *R = getOrCreateSPIRVPointerType(LLVMType, MIRBuilder, SC);
1808   assert(
1809       getPointeeType(R) == BaseType &&
1810       "The base type was not correctly laid out for the given storage class.");
1811   return R;
1812 }
1813 
getOrCreateSPIRVPointerTypeInternal(SPIRVType * BaseType,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC)1814 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerTypeInternal(
1815     SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1816     SPIRV::StorageClass::StorageClass SC) {
1817   const Type *PointerElementType = getTypeForSPIRVType(BaseType);
1818   unsigned AddressSpace = storageClassToAddressSpace(SC);
1819   if (const MachineInstr *MI = findMI(PointerElementType, AddressSpace, CurMF))
1820     return MI;
1821   Type *Ty = TypedPointerType::get(const_cast<Type *>(PointerElementType),
1822                                    AddressSpace);
1823   const MachineInstr *NewMI =
1824       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1825         return BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
1826                        MIRBuilder.getDebugLoc(),
1827                        MIRBuilder.getTII().get(SPIRV::OpTypePointer))
1828             .addDef(createTypeVReg(CurMF->getRegInfo()))
1829             .addImm(static_cast<uint32_t>(SC))
1830             .addUse(getSPIRVTypeID(BaseType));
1831       });
1832   add(PointerElementType, AddressSpace, NewMI);
1833   return finishCreatingSPIRVType(Ty, NewMI);
1834 }
1835 
getOrCreateUndef(MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)1836 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1837                                                SPIRVType *SpvType,
1838                                                const SPIRVInstrInfo &TII) {
1839   UndefValue *UV =
1840       UndefValue::get(const_cast<Type *>(getTypeForSPIRVType(SpvType)));
1841   Register Res = find(UV, CurMF);
1842   if (Res.isValid())
1843     return Res;
1844 
1845   LLT LLTy = LLT::scalar(64);
1846   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
1847   CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
1848   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
1849 
1850   MachineInstr *DepMI = const_cast<MachineInstr *>(SpvType);
1851   MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
1852   const MachineInstr *NewMI =
1853       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
1854         auto MIB = BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
1855                            MIRBuilder.getDL(), TII.get(SPIRV::OpUndef))
1856                        .addDef(Res)
1857                        .addUse(getSPIRVTypeID(SpvType));
1858         const auto &ST = CurMF->getSubtarget();
1859         constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
1860                                          *ST.getRegisterInfo(),
1861                                          *ST.getRegBankInfo());
1862         return MIB;
1863       });
1864   add(UV, NewMI);
1865   return Res;
1866 }
1867 
1868 const TargetRegisterClass *
getRegClass(SPIRVType * SpvType) const1869 SPIRVGlobalRegistry::getRegClass(SPIRVType *SpvType) const {
1870   unsigned Opcode = SpvType->getOpcode();
1871   switch (Opcode) {
1872   case SPIRV::OpTypeFloat:
1873     return &SPIRV::fIDRegClass;
1874   case SPIRV::OpTypePointer:
1875     return &SPIRV::pIDRegClass;
1876   case SPIRV::OpTypeVector: {
1877     SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
1878     unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0;
1879     if (ElemOpcode == SPIRV::OpTypeFloat)
1880       return &SPIRV::vfIDRegClass;
1881     if (ElemOpcode == SPIRV::OpTypePointer)
1882       return &SPIRV::vpIDRegClass;
1883     return &SPIRV::vIDRegClass;
1884   }
1885   }
1886   return &SPIRV::iIDRegClass;
1887 }
1888 
getAS(SPIRVType * SpvType)1889 inline unsigned getAS(SPIRVType *SpvType) {
1890   return storageClassToAddressSpace(
1891       static_cast<SPIRV::StorageClass::StorageClass>(
1892           SpvType->getOperand(1).getImm()));
1893 }
1894 
getRegType(SPIRVType * SpvType) const1895 LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const {
1896   unsigned Opcode = SpvType ? SpvType->getOpcode() : 0;
1897   switch (Opcode) {
1898   case SPIRV::OpTypeInt:
1899   case SPIRV::OpTypeFloat:
1900   case SPIRV::OpTypeBool:
1901     return LLT::scalar(getScalarOrVectorBitWidth(SpvType));
1902   case SPIRV::OpTypePointer:
1903     return LLT::pointer(getAS(SpvType), getPointerSize());
1904   case SPIRV::OpTypeVector: {
1905     SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
1906     LLT ET;
1907     switch (ElemType ? ElemType->getOpcode() : 0) {
1908     case SPIRV::OpTypePointer:
1909       ET = LLT::pointer(getAS(ElemType), getPointerSize());
1910       break;
1911     case SPIRV::OpTypeInt:
1912     case SPIRV::OpTypeFloat:
1913     case SPIRV::OpTypeBool:
1914       ET = LLT::scalar(getScalarOrVectorBitWidth(ElemType));
1915       break;
1916     default:
1917       ET = LLT::scalar(64);
1918     }
1919     return LLT::fixed_vector(
1920         static_cast<unsigned>(SpvType->getOperand(2).getImm()), ET);
1921   }
1922   }
1923   return LLT::scalar(64);
1924 }
1925 
// Aliasing list MD contains several scope MD nodes within it. Each scope MD
// has a self-reference and an extra MD node for aliasing domain and also it
1928 // can contain an optional string operand. Domain MD contains a self-reference
1929 // with an optional string operand. Here we unfold the list, creating SPIR-V
1930 // aliasing instructions.
1931 // TODO: add support for an optional string operand.
getOrAddMemAliasingINTELInst(MachineIRBuilder & MIRBuilder,const MDNode * AliasingListMD)1932 MachineInstr *SPIRVGlobalRegistry::getOrAddMemAliasingINTELInst(
1933     MachineIRBuilder &MIRBuilder, const MDNode *AliasingListMD) {
1934   if (AliasingListMD->getNumOperands() == 0)
1935     return nullptr;
1936   if (auto L = AliasInstMDMap.find(AliasingListMD); L != AliasInstMDMap.end())
1937     return L->second;
1938 
1939   SmallVector<MachineInstr *> ScopeList;
1940   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1941   for (const MDOperand &MDListOp : AliasingListMD->operands()) {
1942     if (MDNode *ScopeMD = dyn_cast<MDNode>(MDListOp)) {
1943       if (ScopeMD->getNumOperands() < 2)
1944         return nullptr;
1945       MDNode *DomainMD = dyn_cast<MDNode>(ScopeMD->getOperand(1));
1946       if (!DomainMD)
1947         return nullptr;
1948       auto *Domain = [&] {
1949         auto D = AliasInstMDMap.find(DomainMD);
1950         if (D != AliasInstMDMap.end())
1951           return D->second;
1952         const Register Ret = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1953         auto MIB =
1954             MIRBuilder.buildInstr(SPIRV::OpAliasDomainDeclINTEL).addDef(Ret);
1955         return MIB.getInstr();
1956       }();
1957       AliasInstMDMap.insert(std::make_pair(DomainMD, Domain));
1958       auto *Scope = [&] {
1959         auto S = AliasInstMDMap.find(ScopeMD);
1960         if (S != AliasInstMDMap.end())
1961           return S->second;
1962         const Register Ret = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1963         auto MIB = MIRBuilder.buildInstr(SPIRV::OpAliasScopeDeclINTEL)
1964                        .addDef(Ret)
1965                        .addUse(Domain->getOperand(0).getReg());
1966         return MIB.getInstr();
1967       }();
1968       AliasInstMDMap.insert(std::make_pair(ScopeMD, Scope));
1969       ScopeList.push_back(Scope);
1970     }
1971   }
1972 
1973   const Register Ret = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1974   auto MIB =
1975       MIRBuilder.buildInstr(SPIRV::OpAliasScopeListDeclINTEL).addDef(Ret);
1976   for (auto *Scope : ScopeList)
1977     MIB.addUse(Scope->getOperand(0).getReg());
1978   auto List = MIB.getInstr();
1979   AliasInstMDMap.insert(std::make_pair(AliasingListMD, List));
1980   return List;
1981 }
1982 
buildMemAliasingOpDecorate(Register Reg,MachineIRBuilder & MIRBuilder,uint32_t Dec,const MDNode * AliasingListMD)1983 void SPIRVGlobalRegistry::buildMemAliasingOpDecorate(
1984     Register Reg, MachineIRBuilder &MIRBuilder, uint32_t Dec,
1985     const MDNode *AliasingListMD) {
1986   MachineInstr *AliasList =
1987       getOrAddMemAliasingINTELInst(MIRBuilder, AliasingListMD);
1988   if (!AliasList)
1989     return;
1990   MIRBuilder.buildInstr(SPIRV::OpDecorate)
1991       .addUse(Reg)
1992       .addImm(Dec)
1993       .addUse(AliasList->getOperand(0).getReg());
1994 }
replaceAllUsesWith(Value * Old,Value * New,bool DeleteOld)1995 void SPIRVGlobalRegistry::replaceAllUsesWith(Value *Old, Value *New,
1996                                              bool DeleteOld) {
1997   Old->replaceAllUsesWith(New);
1998   updateIfExistDeducedElementType(Old, New, DeleteOld);
1999   updateIfExistAssignPtrTypeInstr(Old, New, DeleteOld);
2000 }
2001 
buildAssignType(IRBuilder<> & B,Type * Ty,Value * Arg)2002 void SPIRVGlobalRegistry::buildAssignType(IRBuilder<> &B, Type *Ty,
2003                                           Value *Arg) {
2004   Value *OfType = getNormalizedPoisonValue(Ty);
2005   CallInst *AssignCI = nullptr;
2006   if (Arg->getType()->isAggregateType() && Ty->isAggregateType() &&
2007       allowEmitFakeUse(Arg)) {
2008     LLVMContext &Ctx = Arg->getContext();
2009     SmallVector<Metadata *, 2> ArgMDs{
2010         MDNode::get(Ctx, ValueAsMetadata::getConstant(OfType)),
2011         MDString::get(Ctx, Arg->getName())};
2012     B.CreateIntrinsic(Intrinsic::spv_value_md,
2013                       {MetadataAsValue::get(Ctx, MDTuple::get(Ctx, ArgMDs))});
2014     AssignCI = B.CreateIntrinsic(Intrinsic::fake_use, {Arg});
2015   } else {
2016     AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type, {Arg->getType()},
2017                                OfType, Arg, {}, B);
2018   }
2019   addAssignPtrTypeInstr(Arg, AssignCI);
2020 }
2021 
buildAssignPtr(IRBuilder<> & B,Type * ElemTy,Value * Arg)2022 void SPIRVGlobalRegistry::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
2023                                          Value *Arg) {
2024   Value *OfType = PoisonValue::get(ElemTy);
2025   CallInst *AssignPtrTyCI = findAssignPtrTypeInstr(Arg);
2026   Function *CurrF =
2027       B.GetInsertBlock() ? B.GetInsertBlock()->getParent() : nullptr;
2028   if (AssignPtrTyCI == nullptr ||
2029       AssignPtrTyCI->getParent()->getParent() != CurrF) {
2030     AssignPtrTyCI = buildIntrWithMD(
2031         Intrinsic::spv_assign_ptr_type, {Arg->getType()}, OfType, Arg,
2032         {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
2033     addDeducedElementType(AssignPtrTyCI, ElemTy);
2034     addDeducedElementType(Arg, ElemTy);
2035     addAssignPtrTypeInstr(Arg, AssignPtrTyCI);
2036   } else {
2037     updateAssignType(AssignPtrTyCI, Arg, OfType);
2038   }
2039 }
2040 
updateAssignType(CallInst * AssignCI,Value * Arg,Value * OfType)2041 void SPIRVGlobalRegistry::updateAssignType(CallInst *AssignCI, Value *Arg,
2042                                            Value *OfType) {
2043   AssignCI->setArgOperand(1, buildMD(OfType));
2044   if (cast<IntrinsicInst>(AssignCI)->getIntrinsicID() !=
2045       Intrinsic::spv_assign_ptr_type)
2046     return;
2047 
2048   // update association with the pointee type
2049   Type *ElemTy = OfType->getType();
2050   addDeducedElementType(AssignCI, ElemTy);
2051   addDeducedElementType(Arg, ElemTy);
2052 }
2053 
addStructOffsetDecorations(Register Reg,StructType * Ty,MachineIRBuilder & MIRBuilder)2054 void SPIRVGlobalRegistry::addStructOffsetDecorations(
2055     Register Reg, StructType *Ty, MachineIRBuilder &MIRBuilder) {
2056   DataLayout DL;
2057   ArrayRef<TypeSize> Offsets = DL.getStructLayout(Ty)->getMemberOffsets();
2058   for (uint32_t I = 0; I < Ty->getNumElements(); ++I) {
2059     buildOpMemberDecorate(Reg, MIRBuilder, SPIRV::Decoration::Offset, I,
2060                           {static_cast<uint32_t>(Offsets[I])});
2061   }
2062 }
2063 
addArrayStrideDecorations(Register Reg,Type * ElementType,MachineIRBuilder & MIRBuilder)2064 void SPIRVGlobalRegistry::addArrayStrideDecorations(
2065     Register Reg, Type *ElementType, MachineIRBuilder &MIRBuilder) {
2066   uint32_t SizeInBytes = DataLayout().getTypeSizeInBits(ElementType) / 8;
2067   buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::ArrayStride,
2068                   {SizeInBytes});
2069 }
2070 
hasBlockDecoration(SPIRVType * Type) const2071 bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
2072   Register Def = getSPIRVTypeID(Type);
2073   for (const MachineInstr &Use :
2074        Type->getMF()->getRegInfo().use_instructions(Def)) {
2075     if (Use.getOpcode() != SPIRV::OpDecorate)
2076       continue;
2077 
2078     if (Use.getOperand(1).getImm() == SPIRV::Decoration::Block)
2079       return true;
2080   }
2081   return false;
2082 }
2083