xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (revision 258a0d760aa8b42899a000e30f610f900a402556)
1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23 
24 using namespace llvm;
25 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
26     : PointerSize(PointerSize) {}
27 
28 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
29                                                     Register VReg,
30                                                     MachineInstr &I,
31                                                     const SPIRVInstrInfo &TII) {
32   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
33   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
34   return SpirvType;
35 }
36 
37 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
38     SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
39     const SPIRVInstrInfo &TII) {
40   SPIRVType *SpirvType =
41       getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
42   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
43   return SpirvType;
44 }
45 
46 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
47     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
48     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
49 
50   SPIRVType *SpirvType =
51       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
52   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
53   return SpirvType;
54 }
55 
56 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
57                                                 Register VReg,
58                                                 MachineFunction &MF) {
59   VRegToTypeMap[&MF][VReg] = SpirvType;
60 }
61 
62 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
63   auto &MRI = MIRBuilder.getMF().getRegInfo();
64   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
65   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
66   return Res;
67 }
68 
69 static Register createTypeVReg(MachineRegisterInfo &MRI) {
70   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
71   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
72   return Res;
73 }
74 
75 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
76   return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
77       .addDef(createTypeVReg(MIRBuilder));
78 }
79 
80 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
81                                              MachineIRBuilder &MIRBuilder,
82                                              bool IsSigned) {
83   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
84                  .addDef(createTypeVReg(MIRBuilder))
85                  .addImm(Width)
86                  .addImm(IsSigned ? 1 : 0);
87   return MIB;
88 }
89 
90 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
91                                                MachineIRBuilder &MIRBuilder) {
92   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
93                  .addDef(createTypeVReg(MIRBuilder))
94                  .addImm(Width);
95   return MIB;
96 }
97 
98 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
99   return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
100       .addDef(createTypeVReg(MIRBuilder));
101 }
102 
103 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
104                                                 SPIRVType *ElemType,
105                                                 MachineIRBuilder &MIRBuilder) {
106   auto EleOpc = ElemType->getOpcode();
107   assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
108           EleOpc == SPIRV::OpTypeBool) &&
109          "Invalid vector element type");
110 
111   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
112                  .addDef(createTypeVReg(MIRBuilder))
113                  .addUse(getSPIRVTypeID(ElemType))
114                  .addImm(NumElems);
115   return MIB;
116 }
117 
118 std::tuple<Register, ConstantInt *, bool>
119 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
120                                             MachineIRBuilder *MIRBuilder,
121                                             MachineInstr *I,
122                                             const SPIRVInstrInfo *TII) {
123   const IntegerType *LLVMIntTy;
124   if (SpvType)
125     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
126   else
127     LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
128   bool NewInstr = false;
129   // Find a constant in DT or build a new one.
130   ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
131   Register Res = DT.find(CI, CurMF);
132   if (!Res.isValid()) {
133     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
134     LLT LLTy = LLT::scalar(32);
135     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
136     if (MIRBuilder)
137       assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
138     else
139       assignIntTypeToVReg(BitWidth, Res, *I, *TII);
140     DT.add(CI, CurMF, Res);
141     NewInstr = true;
142   }
143   return std::make_tuple(Res, CI, NewInstr);
144 }
145 
146 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
147                                                   SPIRVType *SpvType,
148                                                   const SPIRVInstrInfo &TII) {
149   assert(SpvType);
150   ConstantInt *CI;
151   Register Res;
152   bool New;
153   std::tie(Res, CI, New) =
154       getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
155   // If we have found Res register which is defined by the passed G_CONSTANT
156   // machine instruction, a new constant instruction should be created.
157   if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
158     return Res;
159   MachineInstrBuilder MIB;
160   MachineBasicBlock &BB = *I.getParent();
161   if (Val) {
162     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
163               .addDef(Res)
164               .addUse(getSPIRVTypeID(SpvType));
165     addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
166   } else {
167     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
168               .addDef(Res)
169               .addUse(getSPIRVTypeID(SpvType));
170   }
171   const auto &ST = CurMF->getSubtarget();
172   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
173                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
174   return Res;
175 }
176 
177 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
178                                                MachineIRBuilder &MIRBuilder,
179                                                SPIRVType *SpvType,
180                                                bool EmitIR) {
181   auto &MF = MIRBuilder.getMF();
182   const IntegerType *LLVMIntTy;
183   if (SpvType)
184     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
185   else
186     LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
187   // Find a constant in DT or build a new one.
188   const auto ConstInt =
189       ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
190   Register Res = DT.find(ConstInt, &MF);
191   if (!Res.isValid()) {
192     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
193     LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
194     Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
195     assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
196                      SPIRV::AccessQualifier::ReadWrite, EmitIR);
197     DT.add(ConstInt, &MIRBuilder.getMF(), Res);
198     if (EmitIR) {
199       MIRBuilder.buildConstant(Res, *ConstInt);
200     } else {
201       MachineInstrBuilder MIB;
202       if (Val) {
203         assert(SpvType);
204         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
205                   .addDef(Res)
206                   .addUse(getSPIRVTypeID(SpvType));
207         addNumImm(APInt(BitWidth, Val), MIB);
208       } else {
209         assert(SpvType);
210         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
211                   .addDef(Res)
212                   .addUse(getSPIRVTypeID(SpvType));
213       }
214       const auto &Subtarget = CurMF->getSubtarget();
215       constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
216                                        *Subtarget.getRegisterInfo(),
217                                        *Subtarget.getRegBankInfo());
218     }
219   }
220   return Res;
221 }
222 
223 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
224                                               MachineIRBuilder &MIRBuilder,
225                                               SPIRVType *SpvType) {
226   auto &MF = MIRBuilder.getMF();
227   const Type *LLVMFPTy;
228   if (SpvType) {
229     LLVMFPTy = getTypeForSPIRVType(SpvType);
230     assert(LLVMFPTy->isFloatingPointTy());
231   } else {
232     LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext());
233   }
234   // Find a constant in DT or build a new one.
235   const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
236   Register Res = DT.find(ConstFP, &MF);
237   if (!Res.isValid()) {
238     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
239     Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
240     assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
241     DT.add(ConstFP, &MF, Res);
242     MIRBuilder.buildFConstant(Res, *ConstFP);
243   }
244   return Res;
245 }
246 
247 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
248     uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
249     const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
250     unsigned ElemCnt) {
251   // Find a constant vector in DT or build a new one.
252   Register Res = DT.find(CA, CurMF);
253   if (!Res.isValid()) {
254     SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
255     // SpvScalConst should be created before SpvVecConst to avoid undefined ID
256     // error on validation.
257     // TODO: can moved below once sorting of types/consts/defs is implemented.
258     Register SpvScalConst;
259     if (Val)
260       SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
261     // TODO: maybe use bitwidth of base type.
262     LLT LLTy = LLT::scalar(32);
263     Register SpvVecConst =
264         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
265     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
266     DT.add(CA, CurMF, SpvVecConst);
267     MachineInstrBuilder MIB;
268     MachineBasicBlock &BB = *I.getParent();
269     if (Val) {
270       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
271                 .addDef(SpvVecConst)
272                 .addUse(getSPIRVTypeID(SpvType));
273       for (unsigned i = 0; i < ElemCnt; ++i)
274         MIB.addUse(SpvScalConst);
275     } else {
276       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
277                 .addDef(SpvVecConst)
278                 .addUse(getSPIRVTypeID(SpvType));
279     }
280     const auto &Subtarget = CurMF->getSubtarget();
281     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
282                                      *Subtarget.getRegisterInfo(),
283                                      *Subtarget.getRegBankInfo());
284     return SpvVecConst;
285   }
286   return Res;
287 }
288 
289 Register
290 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
291                                               SPIRVType *SpvType,
292                                               const SPIRVInstrInfo &TII) {
293   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
294   assert(LLVMTy->isVectorTy());
295   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
296   Type *LLVMBaseTy = LLVMVecTy->getElementType();
297   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
298   auto ConstVec =
299       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
300   unsigned BW = getScalarOrVectorBitWidth(SpvType);
301   return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW,
302                                        SpvType->getOperand(2).getImm());
303 }
304 
305 Register
306 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
307                                              SPIRVType *SpvType,
308                                              const SPIRVInstrInfo &TII) {
309   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
310   assert(LLVMTy->isArrayTy());
311   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
312   Type *LLVMBaseTy = LLVMArrTy->getElementType();
313   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
314   auto ConstArr =
315       ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
316   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
317   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
318   return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW,
319                                        LLVMArrTy->getNumElements());
320 }
321 
322 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
323     uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
324     Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
325   Register Res = DT.find(CA, CurMF);
326   if (!Res.isValid()) {
327     Register SpvScalConst;
328     if (Val || EmitIR) {
329       SPIRVType *SpvBaseType =
330           getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
331       SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
332     }
333     LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
334     Register SpvVecConst =
335         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
336     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
337     DT.add(CA, CurMF, SpvVecConst);
338     if (EmitIR) {
339       MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst);
340     } else {
341       if (Val) {
342         auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
343                        .addDef(SpvVecConst)
344                        .addUse(getSPIRVTypeID(SpvType));
345         for (unsigned i = 0; i < ElemCnt; ++i)
346           MIB.addUse(SpvScalConst);
347       } else {
348         MIRBuilder.buildInstr(SPIRV::OpConstantNull)
349             .addDef(SpvVecConst)
350             .addUse(getSPIRVTypeID(SpvType));
351       }
352     }
353     return SpvVecConst;
354   }
355   return Res;
356 }
357 
358 Register
359 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
360                                               MachineIRBuilder &MIRBuilder,
361                                               SPIRVType *SpvType, bool EmitIR) {
362   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
363   assert(LLVMTy->isVectorTy());
364   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
365   Type *LLVMBaseTy = LLVMVecTy->getElementType();
366   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
367   auto ConstVec =
368       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
369   unsigned BW = getScalarOrVectorBitWidth(SpvType);
370   return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
371                                        ConstVec, BW,
372                                        SpvType->getOperand(2).getImm());
373 }
374 
375 Register
376 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
377                                              MachineIRBuilder &MIRBuilder,
378                                              SPIRVType *SpvType, bool EmitIR) {
379   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
380   assert(LLVMTy->isArrayTy());
381   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
382   Type *LLVMBaseTy = LLVMArrTy->getElementType();
383   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
384   auto ConstArr =
385       ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
386   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
387   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
388   return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
389                                        ConstArr, BW,
390                                        LLVMArrTy->getNumElements());
391 }
392 
393 Register
394 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
395                                              SPIRVType *SpvType) {
396   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
397   const PointerType *LLVMPtrTy = cast<PointerType>(LLVMTy);
398   // Find a constant in DT or build a new one.
399   Constant *CP = ConstantPointerNull::get(const_cast<PointerType *>(LLVMPtrTy));
400   Register Res = DT.find(CP, CurMF);
401   if (!Res.isValid()) {
402     LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
403     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
404     assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
405     MIRBuilder.buildInstr(SPIRV::OpConstantNull)
406         .addDef(Res)
407         .addUse(getSPIRVTypeID(SpvType));
408     DT.add(CP, CurMF, Res);
409   }
410   return Res;
411 }
412 
413 Register SPIRVGlobalRegistry::buildConstantSampler(
414     Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
415     MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
416   SPIRVType *SampTy;
417   if (SpvType)
418     SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
419   else
420     SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
421 
422   auto Sampler =
423       ResReg.isValid()
424           ? ResReg
425           : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
426   auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
427                  .addDef(Sampler)
428                  .addUse(getSPIRVTypeID(SampTy))
429                  .addImm(AddrMode)
430                  .addImm(Param)
431                  .addImm(FilerMode);
432   assert(Res->getOperand(0).isReg());
433   return Res->getOperand(0).getReg();
434 }
435 
436 Register SPIRVGlobalRegistry::buildGlobalVariable(
437     Register ResVReg, SPIRVType *BaseType, StringRef Name,
438     const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
439     const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
440     SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
441     bool IsInstSelector) {
442   const GlobalVariable *GVar = nullptr;
443   if (GV)
444     GVar = cast<const GlobalVariable>(GV);
445   else {
446     // If GV is not passed explicitly, use the name to find or construct
447     // the global variable.
448     Module *M = MIRBuilder.getMF().getFunction().getParent();
449     GVar = M->getGlobalVariable(Name);
450     if (GVar == nullptr) {
451       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
452       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
453                                 GlobalValue::ExternalLinkage, nullptr,
454                                 Twine(Name));
455     }
456     GV = GVar;
457   }
458   Register Reg = DT.find(GVar, &MIRBuilder.getMF());
459   if (Reg.isValid()) {
460     if (Reg != ResVReg)
461       MIRBuilder.buildCopy(ResVReg, Reg);
462     return ResVReg;
463   }
464 
465   auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
466                  .addDef(ResVReg)
467                  .addUse(getSPIRVTypeID(BaseType))
468                  .addImm(static_cast<uint32_t>(Storage));
469 
470   if (Init != 0) {
471     MIB.addUse(Init->getOperand(0).getReg());
472   }
473 
474   // ISel may introduce a new register on this step, so we need to add it to
475   // DT and correct its type avoiding fails on the next stage.
476   if (IsInstSelector) {
477     const auto &Subtarget = CurMF->getSubtarget();
478     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
479                                      *Subtarget.getRegisterInfo(),
480                                      *Subtarget.getRegBankInfo());
481   }
482   Reg = MIB->getOperand(0).getReg();
483   DT.add(GVar, &MIRBuilder.getMF(), Reg);
484 
485   // Set to Reg the same type as ResVReg has.
486   auto MRI = MIRBuilder.getMRI();
487   assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
488   if (Reg != ResVReg) {
489     LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
490     MRI->setType(Reg, RegLLTy);
491     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
492   }
493 
494   // If it's a global variable with name, output OpName for it.
495   if (GVar && GVar->hasName())
496     buildOpName(Reg, GVar->getName(), MIRBuilder);
497 
498   // Output decorations for the GV.
499   // TODO: maybe move to GenerateDecorations pass.
500   if (IsConst)
501     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
502 
503   if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
504     unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
505     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
506   }
507 
508   if (HasLinkageTy)
509     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
510                     {static_cast<uint32_t>(LinkageType)}, Name);
511 
512   SPIRV::BuiltIn::BuiltIn BuiltInId;
513   if (getSpirvBuiltInIdByName(Name, BuiltInId))
514     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
515                     {static_cast<uint32_t>(BuiltInId)});
516 
517   return Reg;
518 }
519 
520 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
521                                                SPIRVType *ElemType,
522                                                MachineIRBuilder &MIRBuilder,
523                                                bool EmitIR) {
524   assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
525          "Invalid array element type");
526   Register NumElementsVReg =
527       buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
528   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
529                  .addDef(createTypeVReg(MIRBuilder))
530                  .addUse(getSPIRVTypeID(ElemType))
531                  .addUse(NumElementsVReg);
532   return MIB;
533 }
534 
535 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
536                                                 MachineIRBuilder &MIRBuilder) {
537   assert(Ty->hasName());
538   const StringRef Name = Ty->hasName() ? Ty->getName() : "";
539   Register ResVReg = createTypeVReg(MIRBuilder);
540   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
541   addStringImm(Name, MIB);
542   buildOpName(ResVReg, Name, MIRBuilder);
543   return MIB;
544 }
545 
546 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
547                                                 MachineIRBuilder &MIRBuilder,
548                                                 bool EmitIR) {
549   SmallVector<Register, 4> FieldTypes;
550   for (const auto &Elem : Ty->elements()) {
551     SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
552     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
553            "Invalid struct element type");
554     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
555   }
556   Register ResVReg = createTypeVReg(MIRBuilder);
557   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
558   for (const auto &Ty : FieldTypes)
559     MIB.addUse(Ty);
560   if (Ty->hasName())
561     buildOpName(ResVReg, Ty->getName(), MIRBuilder);
562   if (Ty->isPacked())
563     buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
564   return MIB;
565 }
566 
567 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
568     const Type *Ty, MachineIRBuilder &MIRBuilder,
569     SPIRV::AccessQualifier::AccessQualifier AccQual) {
570   // Some OpenCL and SPIRV builtins like image2d_t are passed in as
571   // pointers, but should be treated as custom types like OpTypeImage.
572   if (auto PType = dyn_cast<PointerType>(Ty)) {
573     assert(!PType->isOpaque());
574     Ty = PType->getNonOpaquePointerElementType();
575   }
576   auto SType = cast<StructType>(Ty);
577   assert(isSpecialOpaqueType(SType) && "Not a special opaque builtin type");
578   return SPIRV::lowerBuiltinType(SType, AccQual, MIRBuilder, this);
579 }
580 
581 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
582     SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
583     MachineIRBuilder &MIRBuilder, Register Reg) {
584   if (!Reg.isValid())
585     Reg = createTypeVReg(MIRBuilder);
586   return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
587       .addDef(Reg)
588       .addImm(static_cast<uint32_t>(SC))
589       .addUse(getSPIRVTypeID(ElemType));
590 }
591 
592 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
593     SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
594   return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
595       .addUse(createTypeVReg(MIRBuilder))
596       .addImm(static_cast<uint32_t>(SC));
597 }
598 
599 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
600     SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
601     MachineIRBuilder &MIRBuilder) {
602   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
603                  .addDef(createTypeVReg(MIRBuilder))
604                  .addUse(getSPIRVTypeID(RetType));
605   for (const SPIRVType *ArgType : ArgTypes)
606     MIB.addUse(getSPIRVTypeID(ArgType));
607   return MIB;
608 }
609 
610 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
611     const Type *Ty, SPIRVType *RetType,
612     const SmallVectorImpl<SPIRVType *> &ArgTypes,
613     MachineIRBuilder &MIRBuilder) {
614   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
615   if (Reg.isValid())
616     return getSPIRVTypeForVReg(Reg);
617   SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
618   return finishCreatingSPIRVType(Ty, SpirvType);
619 }
620 
621 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
622     const Type *Ty, MachineIRBuilder &MIRBuilder,
623     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
624   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
625   if (Reg.isValid())
626     return getSPIRVTypeForVReg(Reg);
627   if (ForwardPointerTypes.find(Ty) != ForwardPointerTypes.end())
628     return ForwardPointerTypes[Ty];
629   return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
630 }
631 
632 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
633   assert(SpirvType && "Attempting to get type id for nullptr type.");
634   if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
635     return SpirvType->uses().begin()->getReg();
636   return SpirvType->defs().begin()->getReg();
637 }
638 
639 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
640     const Type *Ty, MachineIRBuilder &MIRBuilder,
641     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
642   if (isSpecialOpaqueType(Ty))
643     return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
644   auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
645   auto t = TypeToSPIRVTypeMap.find(Ty);
646   if (t != TypeToSPIRVTypeMap.end()) {
647     auto tt = t->second.find(&MIRBuilder.getMF());
648     if (tt != t->second.end())
649       return getSPIRVTypeForVReg(tt->second);
650   }
651 
652   if (auto IType = dyn_cast<IntegerType>(Ty)) {
653     const unsigned Width = IType->getBitWidth();
654     return Width == 1 ? getOpTypeBool(MIRBuilder)
655                       : getOpTypeInt(Width, MIRBuilder, false);
656   }
657   if (Ty->isFloatingPointTy())
658     return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
659   if (Ty->isVoidTy())
660     return getOpTypeVoid(MIRBuilder);
661   if (Ty->isVectorTy()) {
662     SPIRVType *El =
663         findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
664     return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
665                            MIRBuilder);
666   }
667   if (Ty->isArrayTy()) {
668     SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
669     return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
670   }
671   if (auto SType = dyn_cast<StructType>(Ty)) {
672     if (SType->isOpaque())
673       return getOpTypeOpaque(SType, MIRBuilder);
674     return getOpTypeStruct(SType, MIRBuilder, EmitIR);
675   }
676   if (auto FType = dyn_cast<FunctionType>(Ty)) {
677     SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
678     SmallVector<SPIRVType *, 4> ParamTypes;
679     for (const auto &t : FType->params()) {
680       ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
681     }
682     return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
683   }
684   if (auto PType = dyn_cast<PointerType>(Ty)) {
685     SPIRVType *SpvElementType;
686     // At the moment, all opaque pointers correspond to i8 element type.
687     // TODO: change the implementation once opaque pointers are supported
688     // in the SPIR-V specification.
689     if (PType->isOpaque())
690       SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
691     else
692       SpvElementType =
693           findSPIRVType(PType->getNonOpaquePointerElementType(), MIRBuilder,
694                         SPIRV::AccessQualifier::ReadWrite, EmitIR);
695     auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
696     // Null pointer means we have a loop in type definitions, make and
697     // return corresponding OpTypeForwardPointer.
698     if (SpvElementType == nullptr) {
699       if (ForwardPointerTypes.find(Ty) == ForwardPointerTypes.end())
700         ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder);
701       return ForwardPointerTypes[PType];
702     }
703     Register Reg(0);
704     // If we have forward pointer associated with this type, use its register
705     // operand to create OpTypePointer.
706     if (ForwardPointerTypes.find(PType) != ForwardPointerTypes.end())
707       Reg = getSPIRVTypeID(ForwardPointerTypes[PType]);
708 
709     return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
710   }
711   llvm_unreachable("Unable to convert LLVM type to SPIRVType");
712 }
713 
714 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
715     const Type *Ty, MachineIRBuilder &MIRBuilder,
716     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
717   if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
718     return nullptr;
719   TypesInProcessing.insert(Ty);
720   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
721   TypesInProcessing.erase(Ty);
722   VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
723   SPIRVToLLVMType[SpirvType] = Ty;
724   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
725   // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
726   // will be added later. For special types it is already added to DT.
727   if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
728       !isSpecialOpaqueType(Ty))
729     DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
730 
731   return SpirvType;
732 }
733 
734 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
735   auto t = VRegToTypeMap.find(CurMF);
736   if (t != VRegToTypeMap.end()) {
737     auto tt = t->second.find(VReg);
738     if (tt != t->second.end())
739       return tt->second;
740   }
741   return nullptr;
742 }
743 
744 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
745     const Type *Ty, MachineIRBuilder &MIRBuilder,
746     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
747   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
748   if (Reg.isValid() && !isSpecialOpaqueType(Ty))
749     return getSPIRVTypeForVReg(Reg);
750   TypesInProcessing.clear();
751   SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
752   // Create normal pointer types for the corresponding OpTypeForwardPointers.
753   for (auto &CU : ForwardPointerTypes) {
754     const Type *Ty2 = CU.first;
755     SPIRVType *STy2 = CU.second;
756     if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
757       STy2 = getSPIRVTypeForVReg(Reg);
758     else
759       STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
760     if (Ty == Ty2)
761       STy = STy2;
762   }
763   ForwardPointerTypes.clear();
764   return STy;
765 }
766 
767 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
768                                          unsigned TypeOpcode) const {
769   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
770   assert(Type && "isScalarOfType VReg has no type assigned");
771   return Type->getOpcode() == TypeOpcode;
772 }
773 
774 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
775                                                  unsigned TypeOpcode) const {
776   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
777   assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
778   if (Type->getOpcode() == TypeOpcode)
779     return true;
780   if (Type->getOpcode() == SPIRV::OpTypeVector) {
781     Register ScalarTypeVReg = Type->getOperand(1).getReg();
782     SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
783     return ScalarType->getOpcode() == TypeOpcode;
784   }
785   return false;
786 }
787 
788 unsigned
789 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
790   assert(Type && "Invalid Type pointer");
791   if (Type->getOpcode() == SPIRV::OpTypeVector) {
792     auto EleTypeReg = Type->getOperand(1).getReg();
793     Type = getSPIRVTypeForVReg(EleTypeReg);
794   }
795   if (Type->getOpcode() == SPIRV::OpTypeInt ||
796       Type->getOpcode() == SPIRV::OpTypeFloat)
797     return Type->getOperand(1).getImm();
798   if (Type->getOpcode() == SPIRV::OpTypeBool)
799     return 1;
800   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
801 }
802 
803 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
804   assert(Type && "Invalid Type pointer");
805   if (Type->getOpcode() == SPIRV::OpTypeVector) {
806     auto EleTypeReg = Type->getOperand(1).getReg();
807     Type = getSPIRVTypeForVReg(EleTypeReg);
808   }
809   if (Type->getOpcode() == SPIRV::OpTypeInt)
810     return Type->getOperand(2).getImm() != 0;
811   llvm_unreachable("Attempting to get sign of non-integer type.");
812 }
813 
814 SPIRV::StorageClass::StorageClass
815 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
816   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
817   assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
818          Type->getOperand(1).isImm() && "Pointer type is expected");
819   return static_cast<SPIRV::StorageClass::StorageClass>(
820       Type->getOperand(1).getImm());
821 }
822 
823 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
824     MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
825     uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
826     SPIRV::ImageFormat::ImageFormat ImageFormat,
827     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
828   SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth,
829                                 Arrayed, Multisampled, Sampled, ImageFormat,
830                                 AccessQual);
831   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
832     return Res;
833   Register ResVReg = createTypeVReg(MIRBuilder);
834   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
835   return MIRBuilder.buildInstr(SPIRV::OpTypeImage)
836       .addDef(ResVReg)
837       .addUse(getSPIRVTypeID(SampledType))
838       .addImm(Dim)
839       .addImm(Depth)        // Depth (whether or not it is a Depth image).
840       .addImm(Arrayed)      // Arrayed.
841       .addImm(Multisampled) // Multisampled (0 = only single-sample).
842       .addImm(Sampled)      // Sampled (0 = usage known at runtime).
843       .addImm(ImageFormat)
844       .addImm(AccessQual);
845 }
846 
847 SPIRVType *
848 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
849   SPIRV::SamplerTypeDescriptor TD;
850   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
851     return Res;
852   Register ResVReg = createTypeVReg(MIRBuilder);
853   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
854   return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
855 }
856 
857 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
858     MachineIRBuilder &MIRBuilder,
859     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
860   SPIRV::PipeTypeDescriptor TD(AccessQual);
861   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
862     return Res;
863   Register ResVReg = createTypeVReg(MIRBuilder);
864   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
865   return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
866       .addDef(ResVReg)
867       .addImm(AccessQual);
868 }
869 
870 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
871     MachineIRBuilder &MIRBuilder) {
872   SPIRV::DeviceEventTypeDescriptor TD;
873   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
874     return Res;
875   Register ResVReg = createTypeVReg(MIRBuilder);
876   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
877   return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg);
878 }
879 
880 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
881     SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
882   SPIRV::SampledImageTypeDescriptor TD(
883       SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
884           ImageType->getOperand(1).getReg())),
885       ImageType);
886   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
887     return Res;
888   Register ResVReg = createTypeVReg(MIRBuilder);
889   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
890   return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
891       .addDef(ResVReg)
892       .addUse(getSPIRVTypeID(ImageType));
893 }
894 
895 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
896     const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
897   Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
898   if (ResVReg.isValid())
899     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
900   ResVReg = createTypeVReg(MIRBuilder);
901   DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
902   return MIRBuilder.buildInstr(Opcode).addDef(ResVReg);
903 }
904 
905 const MachineInstr *
906 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
907                                        MachineIRBuilder &MIRBuilder) {
908   Register Reg = DT.find(TD, &MIRBuilder.getMF());
909   if (Reg.isValid())
910     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
911   return nullptr;
912 }
913 
914 // TODO: maybe use tablegen to implement this.
915 SPIRVType *
916 SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(StringRef TypeStr,
917                                                 MachineIRBuilder &MIRBuilder) {
918   unsigned VecElts = 0;
919   auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
920 
921   // Parse type name in either "typeN" or "type vector[N]" format, where
922   // N is the number of elements of the vector.
923   Type *Type;
924   if (TypeStr.startswith("void")) {
925     Type = Type::getVoidTy(Ctx);
926     TypeStr = TypeStr.substr(strlen("void"));
927   } else if (TypeStr.startswith("int") || TypeStr.startswith("uint")) {
928     Type = Type::getInt32Ty(Ctx);
929     TypeStr = TypeStr.startswith("int") ? TypeStr.substr(strlen("int"))
930                                         : TypeStr.substr(strlen("uint"));
931   } else if (TypeStr.startswith("float")) {
932     Type = Type::getFloatTy(Ctx);
933     TypeStr = TypeStr.substr(strlen("float"));
934   } else if (TypeStr.startswith("half")) {
935     Type = Type::getHalfTy(Ctx);
936     TypeStr = TypeStr.substr(strlen("half"));
937   } else if (TypeStr.startswith("opencl.sampler_t")) {
938     Type = StructType::create(Ctx, "opencl.sampler_t");
939   } else
940     llvm_unreachable("Unable to recognize SPIRV type name.");
941   if (TypeStr.startswith(" vector[")) {
942     TypeStr = TypeStr.substr(strlen(" vector["));
943     TypeStr = TypeStr.substr(0, TypeStr.find(']'));
944   }
945   TypeStr.getAsInteger(10, VecElts);
946   auto SpirvTy = getOrCreateSPIRVType(Type, MIRBuilder);
947   if (VecElts > 0)
948     SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
949   return SpirvTy;
950 }
951 
952 SPIRVType *
953 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
954                                                  MachineIRBuilder &MIRBuilder) {
955   return getOrCreateSPIRVType(
956       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
957       MIRBuilder);
958 }
959 
960 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
961                                                         SPIRVType *SpirvType) {
962   assert(CurMF == SpirvType->getMF());
963   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
964   SPIRVToLLVMType[SpirvType] = LLVMTy;
965   DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType));
966   return SpirvType;
967 }
968 
969 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
970     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
971   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
972   Register Reg = DT.find(LLVMTy, CurMF);
973   if (Reg.isValid())
974     return getSPIRVTypeForVReg(Reg);
975   MachineBasicBlock &BB = *I.getParent();
976   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
977                  .addDef(createTypeVReg(CurMF->getRegInfo()))
978                  .addImm(BitWidth)
979                  .addImm(0);
980   return finishCreatingSPIRVType(LLVMTy, MIB);
981 }
982 
983 SPIRVType *
984 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
985   return getOrCreateSPIRVType(
986       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
987       MIRBuilder);
988 }
989 
990 SPIRVType *
991 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
992                                               const SPIRVInstrInfo &TII) {
993   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
994   Register Reg = DT.find(LLVMTy, CurMF);
995   if (Reg.isValid())
996     return getSPIRVTypeForVReg(Reg);
997   MachineBasicBlock &BB = *I.getParent();
998   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
999                  .addDef(createTypeVReg(CurMF->getRegInfo()));
1000   return finishCreatingSPIRVType(LLVMTy, MIB);
1001 }
1002 
1003 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1004     SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
1005   return getOrCreateSPIRVType(
1006       FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1007                            NumElements),
1008       MIRBuilder);
1009 }
1010 
1011 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1012     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1013     const SPIRVInstrInfo &TII) {
1014   Type *LLVMTy = FixedVectorType::get(
1015       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1016   Register Reg = DT.find(LLVMTy, CurMF);
1017   if (Reg.isValid())
1018     return getSPIRVTypeForVReg(Reg);
1019   MachineBasicBlock &BB = *I.getParent();
1020   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
1021                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1022                  .addUse(getSPIRVTypeID(BaseType))
1023                  .addImm(NumElements);
1024   return finishCreatingSPIRVType(LLVMTy, MIB);
1025 }
1026 
1027 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1028     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1029     const SPIRVInstrInfo &TII) {
1030   Type *LLVMTy = ArrayType::get(
1031       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1032   Register Reg = DT.find(LLVMTy, CurMF);
1033   if (Reg.isValid())
1034     return getSPIRVTypeForVReg(Reg);
1035   MachineBasicBlock &BB = *I.getParent();
1036   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
1037   Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
1038   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
1039                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1040                  .addUse(getSPIRVTypeID(BaseType))
1041                  .addUse(Len);
1042   return finishCreatingSPIRVType(LLVMTy, MIB);
1043 }
1044 
1045 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1046     SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1047     SPIRV::StorageClass::StorageClass SClass) {
1048   return getOrCreateSPIRVType(
1049       PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1050                        storageClassToAddressSpace(SClass)),
1051       MIRBuilder);
1052 }
1053 
1054 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1055     SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
1056     SPIRV::StorageClass::StorageClass SC) {
1057   Type *LLVMTy =
1058       PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1059                        storageClassToAddressSpace(SC));
1060   Register Reg = DT.find(LLVMTy, CurMF);
1061   if (Reg.isValid())
1062     return getSPIRVTypeForVReg(Reg);
1063   MachineBasicBlock &BB = *I.getParent();
1064   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
1065                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1066                  .addImm(static_cast<uint32_t>(SC))
1067                  .addUse(getSPIRVTypeID(BaseType));
1068   return finishCreatingSPIRVType(LLVMTy, MIB);
1069 }
1070 
1071 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1072                                                SPIRVType *SpvType,
1073                                                const SPIRVInstrInfo &TII) {
1074   assert(SpvType);
1075   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
1076   assert(LLVMTy);
1077   // Find a constant in DT or build a new one.
1078   UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
1079   Register Res = DT.find(UV, CurMF);
1080   if (Res.isValid())
1081     return Res;
1082   LLT LLTy = LLT::scalar(32);
1083   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
1084   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
1085   DT.add(UV, CurMF, Res);
1086 
1087   MachineInstrBuilder MIB;
1088   MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
1089             .addDef(Res)
1090             .addUse(getSPIRVTypeID(SpvType));
1091   const auto &ST = CurMF->getSubtarget();
1092   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
1093                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
1094   return Res;
1095 }
1096