xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23 #include "llvm/ADT/APInt.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/Type.h"
26 #include "llvm/Support/Casting.h"
27 #include <cassert>
28 
29 using namespace llvm;
SPIRVGlobalRegistry(unsigned PointerSize)30 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
31     : PointerSize(PointerSize), Bound(0) {}
32 
assignIntTypeToVReg(unsigned BitWidth,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)33 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
34                                                     Register VReg,
35                                                     MachineInstr &I,
36                                                     const SPIRVInstrInfo &TII) {
37   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
38   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
39   return SpirvType;
40 }
41 
42 SPIRVType *
assignFloatTypeToVReg(unsigned BitWidth,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)43 SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
44                                            MachineInstr &I,
45                                            const SPIRVInstrInfo &TII) {
46   SPIRVType *SpirvType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
47   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
48   return SpirvType;
49 }
50 
assignVectTypeToVReg(SPIRVType * BaseType,unsigned NumElements,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)51 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
52     SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
53     const SPIRVInstrInfo &TII) {
54   SPIRVType *SpirvType =
55       getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
56   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
57   return SpirvType;
58 }
59 
assignTypeToVReg(const Type * Type,Register VReg,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)60 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
61     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
62     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
63   SPIRVType *SpirvType =
64       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
65   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
66   return SpirvType;
67 }
68 
assignSPIRVTypeToVReg(SPIRVType * SpirvType,Register VReg,MachineFunction & MF)69 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
70                                                 Register VReg,
71                                                 MachineFunction &MF) {
72   VRegToTypeMap[&MF][VReg] = SpirvType;
73 }
74 
createTypeVReg(MachineIRBuilder & MIRBuilder)75 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
76   auto &MRI = MIRBuilder.getMF().getRegInfo();
77   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
78   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
79   return Res;
80 }
81 
createTypeVReg(MachineRegisterInfo & MRI)82 static Register createTypeVReg(MachineRegisterInfo &MRI) {
83   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
84   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
85   return Res;
86 }
87 
getOpTypeBool(MachineIRBuilder & MIRBuilder)88 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
89   return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
90       .addDef(createTypeVReg(MIRBuilder));
91 }
92 
adjustOpTypeIntWidth(unsigned Width) const93 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
94   if (Width > 64)
95     report_fatal_error("Unsupported integer width!");
96   const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
97   if (ST.canUseExtension(
98           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
99     return Width;
100   if (Width <= 8)
101     Width = 8;
102   else if (Width <= 16)
103     Width = 16;
104   else if (Width <= 32)
105     Width = 32;
106   else
107     Width = 64;
108   return Width;
109 }
110 
getOpTypeInt(unsigned Width,MachineIRBuilder & MIRBuilder,bool IsSigned)111 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
112                                              MachineIRBuilder &MIRBuilder,
113                                              bool IsSigned) {
114   Width = adjustOpTypeIntWidth(Width);
115   const SPIRVSubtarget &ST =
116       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
117   if (ST.canUseExtension(
118           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
119     MIRBuilder.buildInstr(SPIRV::OpExtension)
120         .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
121     MIRBuilder.buildInstr(SPIRV::OpCapability)
122         .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
123   }
124   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
125                  .addDef(createTypeVReg(MIRBuilder))
126                  .addImm(Width)
127                  .addImm(IsSigned ? 1 : 0);
128   return MIB;
129 }
130 
getOpTypeFloat(uint32_t Width,MachineIRBuilder & MIRBuilder)131 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
132                                                MachineIRBuilder &MIRBuilder) {
133   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
134                  .addDef(createTypeVReg(MIRBuilder))
135                  .addImm(Width);
136   return MIB;
137 }
138 
getOpTypeVoid(MachineIRBuilder & MIRBuilder)139 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
140   return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
141       .addDef(createTypeVReg(MIRBuilder));
142 }
143 
getOpTypeVector(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder)144 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
145                                                 SPIRVType *ElemType,
146                                                 MachineIRBuilder &MIRBuilder) {
147   auto EleOpc = ElemType->getOpcode();
148   (void)EleOpc;
149   assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
150           EleOpc == SPIRV::OpTypeBool) &&
151          "Invalid vector element type");
152 
153   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
154                  .addDef(createTypeVReg(MIRBuilder))
155                  .addUse(getSPIRVTypeID(ElemType))
156                  .addImm(NumElems);
157   return MIB;
158 }
159 
160 std::tuple<Register, ConstantInt *, bool>
getOrCreateConstIntReg(uint64_t Val,SPIRVType * SpvType,MachineIRBuilder * MIRBuilder,MachineInstr * I,const SPIRVInstrInfo * TII)161 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
162                                             MachineIRBuilder *MIRBuilder,
163                                             MachineInstr *I,
164                                             const SPIRVInstrInfo *TII) {
165   const IntegerType *LLVMIntTy;
166   if (SpvType)
167     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
168   else
169     LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
170   bool NewInstr = false;
171   // Find a constant in DT or build a new one.
172   ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
173   Register Res = DT.find(CI, CurMF);
174   if (!Res.isValid()) {
175     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
176     // TODO: handle cases where the type is not 32bit wide
177     // TODO: https://github.com/llvm/llvm-project/issues/88129
178     LLT LLTy = LLT::scalar(32);
179     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
180     CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
181     if (MIRBuilder)
182       assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
183     else
184       assignIntTypeToVReg(BitWidth, Res, *I, *TII);
185     DT.add(CI, CurMF, Res);
186     NewInstr = true;
187   }
188   return std::make_tuple(Res, CI, NewInstr);
189 }
190 
191 std::tuple<Register, ConstantFP *, bool, unsigned>
getOrCreateConstFloatReg(APFloat Val,SPIRVType * SpvType,MachineIRBuilder * MIRBuilder,MachineInstr * I,const SPIRVInstrInfo * TII)192 SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
193                                               MachineIRBuilder *MIRBuilder,
194                                               MachineInstr *I,
195                                               const SPIRVInstrInfo *TII) {
196   const Type *LLVMFloatTy;
197   LLVMContext &Ctx = CurMF->getFunction().getContext();
198   unsigned BitWidth = 32;
199   if (SpvType)
200     LLVMFloatTy = getTypeForSPIRVType(SpvType);
201   else {
202     LLVMFloatTy = Type::getFloatTy(Ctx);
203     if (MIRBuilder)
204       SpvType = getOrCreateSPIRVType(LLVMFloatTy, *MIRBuilder);
205   }
206   bool NewInstr = false;
207   // Find a constant in DT or build a new one.
208   auto *const CI = ConstantFP::get(Ctx, Val);
209   Register Res = DT.find(CI, CurMF);
210   if (!Res.isValid()) {
211     if (SpvType)
212       BitWidth = getScalarOrVectorBitWidth(SpvType);
213     // TODO: handle cases where the type is not 32bit wide
214     // TODO: https://github.com/llvm/llvm-project/issues/88129
215     LLT LLTy = LLT::scalar(32);
216     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
217     CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
218     if (MIRBuilder)
219       assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
220     else
221       assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
222     DT.add(CI, CurMF, Res);
223     NewInstr = true;
224   }
225   return std::make_tuple(Res, CI, NewInstr, BitWidth);
226 }
227 
getOrCreateConstFP(APFloat Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)228 Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
229                                                  SPIRVType *SpvType,
230                                                  const SPIRVInstrInfo &TII,
231                                                  bool ZeroAsNull) {
232   assert(SpvType);
233   ConstantFP *CI;
234   Register Res;
235   bool New;
236   unsigned BitWidth;
237   std::tie(Res, CI, New, BitWidth) =
238       getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII);
239   // If we have found Res register which is defined by the passed G_CONSTANT
240   // machine instruction, a new constant instruction should be created.
241   if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
242     return Res;
243   MachineInstrBuilder MIB;
244   MachineBasicBlock &BB = *I.getParent();
245   // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
246   if (Val.isPosZero() && ZeroAsNull) {
247     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
248               .addDef(Res)
249               .addUse(getSPIRVTypeID(SpvType));
250   } else {
251     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantF))
252               .addDef(Res)
253               .addUse(getSPIRVTypeID(SpvType));
254     addNumImm(
255         APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()),
256         MIB);
257   }
258   const auto &ST = CurMF->getSubtarget();
259   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
260                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
261   return Res;
262 }
263 
getOrCreateConstInt(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)264 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
265                                                   SPIRVType *SpvType,
266                                                   const SPIRVInstrInfo &TII,
267                                                   bool ZeroAsNull) {
268   assert(SpvType);
269   ConstantInt *CI;
270   Register Res;
271   bool New;
272   std::tie(Res, CI, New) =
273       getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
274   // If we have found Res register which is defined by the passed G_CONSTANT
275   // machine instruction, a new constant instruction should be created.
276   if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
277     return Res;
278   MachineInstrBuilder MIB;
279   MachineBasicBlock &BB = *I.getParent();
280   if (Val || !ZeroAsNull) {
281     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
282               .addDef(Res)
283               .addUse(getSPIRVTypeID(SpvType));
284     addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
285   } else {
286     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
287               .addDef(Res)
288               .addUse(getSPIRVTypeID(SpvType));
289   }
290   const auto &ST = CurMF->getSubtarget();
291   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
292                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
293   return Res;
294 }
295 
buildConstantInt(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)296 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
297                                                MachineIRBuilder &MIRBuilder,
298                                                SPIRVType *SpvType,
299                                                bool EmitIR) {
300   auto &MF = MIRBuilder.getMF();
301   const IntegerType *LLVMIntTy;
302   if (SpvType)
303     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
304   else
305     LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
306   // Find a constant in DT or build a new one.
307   const auto ConstInt =
308       ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
309   Register Res = DT.find(ConstInt, &MF);
310   if (!Res.isValid()) {
311     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
312     LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
313     Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
314     MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
315     assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
316                      SPIRV::AccessQualifier::ReadWrite, EmitIR);
317     DT.add(ConstInt, &MIRBuilder.getMF(), Res);
318     if (EmitIR) {
319       MIRBuilder.buildConstant(Res, *ConstInt);
320     } else {
321       if (!SpvType)
322         SpvType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
323       MachineInstrBuilder MIB;
324       if (Val) {
325         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
326                   .addDef(Res)
327                   .addUse(getSPIRVTypeID(SpvType));
328         addNumImm(APInt(BitWidth, Val), MIB);
329       } else {
330         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
331                   .addDef(Res)
332                   .addUse(getSPIRVTypeID(SpvType));
333       }
334       const auto &Subtarget = CurMF->getSubtarget();
335       constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
336                                        *Subtarget.getRegisterInfo(),
337                                        *Subtarget.getRegBankInfo());
338     }
339   }
340   return Res;
341 }
342 
buildConstantFP(APFloat Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)343 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
344                                               MachineIRBuilder &MIRBuilder,
345                                               SPIRVType *SpvType) {
346   auto &MF = MIRBuilder.getMF();
347   auto &Ctx = MF.getFunction().getContext();
348   if (!SpvType) {
349     const Type *LLVMFPTy = Type::getFloatTy(Ctx);
350     SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);
351   }
352   // Find a constant in DT or build a new one.
353   const auto ConstFP = ConstantFP::get(Ctx, Val);
354   Register Res = DT.find(ConstFP, &MF);
355   if (!Res.isValid()) {
356     Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
357     MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
358     assignSPIRVTypeToVReg(SpvType, Res, MF);
359     DT.add(ConstFP, &MF, Res);
360 
361     MachineInstrBuilder MIB;
362     MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
363               .addDef(Res)
364               .addUse(getSPIRVTypeID(SpvType));
365     addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
366   }
367 
368   return Res;
369 }
370 
getOrCreateBaseRegister(Constant * Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,unsigned BitWidth)371 Register SPIRVGlobalRegistry::getOrCreateBaseRegister(Constant *Val,
372                                                       MachineInstr &I,
373                                                       SPIRVType *SpvType,
374                                                       const SPIRVInstrInfo &TII,
375                                                       unsigned BitWidth) {
376   SPIRVType *Type = SpvType;
377   if (SpvType->getOpcode() == SPIRV::OpTypeVector ||
378       SpvType->getOpcode() == SPIRV::OpTypeArray) {
379     auto EleTypeReg = SpvType->getOperand(1).getReg();
380     Type = getSPIRVTypeForVReg(EleTypeReg);
381   }
382   if (Type->getOpcode() == SPIRV::OpTypeFloat) {
383     SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
384     return getOrCreateConstFP(dyn_cast<ConstantFP>(Val)->getValue(), I,
385                               SpvBaseType, TII);
386   }
387   assert(Type->getOpcode() == SPIRV::OpTypeInt);
388   SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
389   return getOrCreateConstInt(Val->getUniqueInteger().getSExtValue(), I,
390                              SpvBaseType, TII);
391 }
392 
getOrCreateCompositeOrNull(Constant * Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,Constant * CA,unsigned BitWidth,unsigned ElemCnt,bool ZeroAsNull)393 Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
394     Constant *Val, MachineInstr &I, SPIRVType *SpvType,
395     const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
396     unsigned ElemCnt, bool ZeroAsNull) {
397   // Find a constant vector or array in DT or build a new one.
398   Register Res = DT.find(CA, CurMF);
399   // If no values are attached, the composite is null constant.
400   bool IsNull = Val->isNullValue() && ZeroAsNull;
401   if (!Res.isValid()) {
402     // SpvScalConst should be created before SpvVecConst to avoid undefined ID
403     // error on validation.
404     // TODO: can moved below once sorting of types/consts/defs is implemented.
405     Register SpvScalConst;
406     if (!IsNull)
407       SpvScalConst = getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth);
408 
409     // TODO: handle cases where the type is not 32bit wide
410     // TODO: https://github.com/llvm/llvm-project/issues/88129
411     LLT LLTy = LLT::scalar(32);
412     Register SpvVecConst =
413         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
414     CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
415     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
416     DT.add(CA, CurMF, SpvVecConst);
417     MachineInstrBuilder MIB;
418     MachineBasicBlock &BB = *I.getParent();
419     if (!IsNull) {
420       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
421                 .addDef(SpvVecConst)
422                 .addUse(getSPIRVTypeID(SpvType));
423       for (unsigned i = 0; i < ElemCnt; ++i)
424         MIB.addUse(SpvScalConst);
425     } else {
426       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
427                 .addDef(SpvVecConst)
428                 .addUse(getSPIRVTypeID(SpvType));
429     }
430     const auto &Subtarget = CurMF->getSubtarget();
431     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
432                                      *Subtarget.getRegisterInfo(),
433                                      *Subtarget.getRegBankInfo());
434     return SpvVecConst;
435   }
436   return Res;
437 }
438 
getOrCreateConstVector(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)439 Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
440                                                      MachineInstr &I,
441                                                      SPIRVType *SpvType,
442                                                      const SPIRVInstrInfo &TII,
443                                                      bool ZeroAsNull) {
444   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
445   assert(LLVMTy->isVectorTy());
446   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
447   Type *LLVMBaseTy = LLVMVecTy->getElementType();
448   assert(LLVMBaseTy->isIntegerTy());
449   auto *ConstVal = ConstantInt::get(LLVMBaseTy, Val);
450   auto *ConstVec =
451       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
452   unsigned BW = getScalarOrVectorBitWidth(SpvType);
453   return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
454                                     SpvType->getOperand(2).getImm(),
455                                     ZeroAsNull);
456 }
457 
getOrCreateConstVector(APFloat Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,bool ZeroAsNull)458 Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
459                                                      MachineInstr &I,
460                                                      SPIRVType *SpvType,
461                                                      const SPIRVInstrInfo &TII,
462                                                      bool ZeroAsNull) {
463   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
464   assert(LLVMTy->isVectorTy());
465   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
466   Type *LLVMBaseTy = LLVMVecTy->getElementType();
467   assert(LLVMBaseTy->isFloatingPointTy());
468   auto *ConstVal = ConstantFP::get(LLVMBaseTy, Val);
469   auto *ConstVec =
470       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
471   unsigned BW = getScalarOrVectorBitWidth(SpvType);
472   return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
473                                     SpvType->getOperand(2).getImm(),
474                                     ZeroAsNull);
475 }
476 
getOrCreateConstIntArray(uint64_t Val,size_t Num,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)477 Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
478     uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
479     const SPIRVInstrInfo &TII) {
480   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
481   assert(LLVMTy->isArrayTy());
482   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
483   Type *LLVMBaseTy = LLVMArrTy->getElementType();
484   Constant *CI = ConstantInt::get(LLVMBaseTy, Val);
485   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
486   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
487   // The following is reasonably unique key that is better that [Val]. The naive
488   // alternative would be something along the lines of:
489   //   SmallVector<Constant *> NumCI(Num, CI);
490   //   Constant *UniqueKey =
491   //     ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
492   // that would be a truly unique but dangerous key, because it could lead to
493   // the creation of constants of arbitrary length (that is, the parameter of
494   // memset) which were missing in the original module.
495   Constant *UniqueKey = ConstantStruct::getAnon(
496       {PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)),
497        ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)});
498   return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW,
499                                     LLVMArrTy->getNumElements());
500 }
501 
getOrCreateIntCompositeOrNull(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR,Constant * CA,unsigned BitWidth,unsigned ElemCnt)502 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
503     uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
504     Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
505   Register Res = DT.find(CA, CurMF);
506   if (!Res.isValid()) {
507     Register SpvScalConst;
508     if (Val || EmitIR) {
509       SPIRVType *SpvBaseType =
510           getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
511       SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
512     }
513     LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
514     Register SpvVecConst =
515         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
516     CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
517     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
518     DT.add(CA, CurMF, SpvVecConst);
519     if (EmitIR) {
520       MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst);
521     } else {
522       if (Val) {
523         auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
524                        .addDef(SpvVecConst)
525                        .addUse(getSPIRVTypeID(SpvType));
526         for (unsigned i = 0; i < ElemCnt; ++i)
527           MIB.addUse(SpvScalConst);
528       } else {
529         MIRBuilder.buildInstr(SPIRV::OpConstantNull)
530             .addDef(SpvVecConst)
531             .addUse(getSPIRVTypeID(SpvType));
532       }
533     }
534     return SpvVecConst;
535   }
536   return Res;
537 }
538 
539 Register
getOrCreateConsIntVector(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)540 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
541                                               MachineIRBuilder &MIRBuilder,
542                                               SPIRVType *SpvType, bool EmitIR) {
543   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
544   assert(LLVMTy->isVectorTy());
545   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
546   Type *LLVMBaseTy = LLVMVecTy->getElementType();
547   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
548   auto ConstVec =
549       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
550   unsigned BW = getScalarOrVectorBitWidth(SpvType);
551   return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
552                                        ConstVec, BW,
553                                        SpvType->getOperand(2).getImm());
554 }
555 
556 Register
getOrCreateConstNullPtr(MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)557 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
558                                              SPIRVType *SpvType) {
559   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
560   const TypedPointerType *LLVMPtrTy = cast<TypedPointerType>(LLVMTy);
561   // Find a constant in DT or build a new one.
562   Constant *CP = ConstantPointerNull::get(PointerType::get(
563       LLVMPtrTy->getElementType(), LLVMPtrTy->getAddressSpace()));
564   Register Res = DT.find(CP, CurMF);
565   if (!Res.isValid()) {
566     LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
567     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
568     CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
569     assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
570     MIRBuilder.buildInstr(SPIRV::OpConstantNull)
571         .addDef(Res)
572         .addUse(getSPIRVTypeID(SpvType));
573     DT.add(CP, CurMF, Res);
574   }
575   return Res;
576 }
577 
buildConstantSampler(Register ResReg,unsigned AddrMode,unsigned Param,unsigned FilerMode,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)578 Register SPIRVGlobalRegistry::buildConstantSampler(
579     Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
580     MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
581   SPIRVType *SampTy;
582   if (SpvType)
583     SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
584   else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t",
585                                                 MIRBuilder)) == nullptr)
586     report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t");
587 
588   auto Sampler =
589       ResReg.isValid()
590           ? ResReg
591           : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
592   auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
593                  .addDef(Sampler)
594                  .addUse(getSPIRVTypeID(SampTy))
595                  .addImm(AddrMode)
596                  .addImm(Param)
597                  .addImm(FilerMode);
598   assert(Res->getOperand(0).isReg());
599   return Res->getOperand(0).getReg();
600 }
601 
buildGlobalVariable(Register ResVReg,SPIRVType * BaseType,StringRef Name,const GlobalValue * GV,SPIRV::StorageClass::StorageClass Storage,const MachineInstr * Init,bool IsConst,bool HasLinkageTy,SPIRV::LinkageType::LinkageType LinkageType,MachineIRBuilder & MIRBuilder,bool IsInstSelector)602 Register SPIRVGlobalRegistry::buildGlobalVariable(
603     Register ResVReg, SPIRVType *BaseType, StringRef Name,
604     const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
605     const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
606     SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
607     bool IsInstSelector) {
608   const GlobalVariable *GVar = nullptr;
609   if (GV)
610     GVar = cast<const GlobalVariable>(GV);
611   else {
612     // If GV is not passed explicitly, use the name to find or construct
613     // the global variable.
614     Module *M = MIRBuilder.getMF().getFunction().getParent();
615     GVar = M->getGlobalVariable(Name);
616     if (GVar == nullptr) {
617       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
618       // Module takes ownership of the global var.
619       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
620                                 GlobalValue::ExternalLinkage, nullptr,
621                                 Twine(Name));
622     }
623     GV = GVar;
624   }
625   Register Reg = DT.find(GVar, &MIRBuilder.getMF());
626   if (Reg.isValid()) {
627     if (Reg != ResVReg)
628       MIRBuilder.buildCopy(ResVReg, Reg);
629     return ResVReg;
630   }
631 
632   auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
633                  .addDef(ResVReg)
634                  .addUse(getSPIRVTypeID(BaseType))
635                  .addImm(static_cast<uint32_t>(Storage));
636 
637   if (Init != 0) {
638     MIB.addUse(Init->getOperand(0).getReg());
639   }
640 
641   // ISel may introduce a new register on this step, so we need to add it to
642   // DT and correct its type avoiding fails on the next stage.
643   if (IsInstSelector) {
644     const auto &Subtarget = CurMF->getSubtarget();
645     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
646                                      *Subtarget.getRegisterInfo(),
647                                      *Subtarget.getRegBankInfo());
648   }
649   Reg = MIB->getOperand(0).getReg();
650   DT.add(GVar, &MIRBuilder.getMF(), Reg);
651 
652   // Set to Reg the same type as ResVReg has.
653   auto MRI = MIRBuilder.getMRI();
654   assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
655   if (Reg != ResVReg) {
656     LLT RegLLTy =
657         LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize());
658     MRI->setType(Reg, RegLLTy);
659     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
660   } else {
661     // Our knowledge about the type may be updated.
662     // If that's the case, we need to update a type
663     // associated with the register.
664     SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg);
665     if (!DefType || DefType != BaseType)
666       assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
667   }
668 
669   // If it's a global variable with name, output OpName for it.
670   if (GVar && GVar->hasName())
671     buildOpName(Reg, GVar->getName(), MIRBuilder);
672 
673   // Output decorations for the GV.
674   // TODO: maybe move to GenerateDecorations pass.
675   const SPIRVSubtarget &ST =
676       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
677   if (IsConst && ST.isOpenCLEnv())
678     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
679 
680   if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
681     unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
682     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
683   }
684 
685   if (HasLinkageTy)
686     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
687                     {static_cast<uint32_t>(LinkageType)}, Name);
688 
689   SPIRV::BuiltIn::BuiltIn BuiltInId;
690   if (getSpirvBuiltInIdByName(Name, BuiltInId))
691     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
692                     {static_cast<uint32_t>(BuiltInId)});
693 
694   // If it's a global variable with "spirv.Decorations" metadata node
695   // recognize it as a SPIR-V friendly LLVM IR and parse "spirv.Decorations"
696   // arguments.
697   MDNode *GVarMD = nullptr;
698   if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr)
699     buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD);
700 
701   return Reg;
702 }
703 
getOpTypeArray(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,bool EmitIR)704 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
705                                                SPIRVType *ElemType,
706                                                MachineIRBuilder &MIRBuilder,
707                                                bool EmitIR) {
708   assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
709          "Invalid array element type");
710   Register NumElementsVReg =
711       buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
712   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
713                  .addDef(createTypeVReg(MIRBuilder))
714                  .addUse(getSPIRVTypeID(ElemType))
715                  .addUse(NumElementsVReg);
716   return MIB;
717 }
718 
getOpTypeOpaque(const StructType * Ty,MachineIRBuilder & MIRBuilder)719 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
720                                                 MachineIRBuilder &MIRBuilder) {
721   assert(Ty->hasName());
722   const StringRef Name = Ty->hasName() ? Ty->getName() : "";
723   Register ResVReg = createTypeVReg(MIRBuilder);
724   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
725   addStringImm(Name, MIB);
726   buildOpName(ResVReg, Name, MIRBuilder);
727   return MIB;
728 }
729 
getOpTypeStruct(const StructType * Ty,MachineIRBuilder & MIRBuilder,bool EmitIR)730 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
731                                                 MachineIRBuilder &MIRBuilder,
732                                                 bool EmitIR) {
733   SmallVector<Register, 4> FieldTypes;
734   for (const auto &Elem : Ty->elements()) {
735     SPIRVType *ElemTy = findSPIRVType(toTypedPointer(Elem), MIRBuilder);
736     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
737            "Invalid struct element type");
738     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
739   }
740   Register ResVReg = createTypeVReg(MIRBuilder);
741   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
742   for (const auto &Ty : FieldTypes)
743     MIB.addUse(Ty);
744   if (Ty->hasName())
745     buildOpName(ResVReg, Ty->getName(), MIRBuilder);
746   if (Ty->isPacked())
747     buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
748   return MIB;
749 }
750 
getOrCreateSpecialType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual)751 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
752     const Type *Ty, MachineIRBuilder &MIRBuilder,
753     SPIRV::AccessQualifier::AccessQualifier AccQual) {
754   assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
755   return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);
756 }
757 
getOpTypePointer(SPIRV::StorageClass::StorageClass SC,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,Register Reg)758 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
759     SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
760     MachineIRBuilder &MIRBuilder, Register Reg) {
761   if (!Reg.isValid())
762     Reg = createTypeVReg(MIRBuilder);
763   return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
764       .addDef(Reg)
765       .addImm(static_cast<uint32_t>(SC))
766       .addUse(getSPIRVTypeID(ElemType));
767 }
768 
getOpTypeForwardPointer(SPIRV::StorageClass::StorageClass SC,MachineIRBuilder & MIRBuilder)769 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
770     SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
771   return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
772       .addUse(createTypeVReg(MIRBuilder))
773       .addImm(static_cast<uint32_t>(SC));
774 }
775 
getOpTypeFunction(SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)776 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
777     SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
778     MachineIRBuilder &MIRBuilder) {
779   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
780                  .addDef(createTypeVReg(MIRBuilder))
781                  .addUse(getSPIRVTypeID(RetType));
782   for (const SPIRVType *ArgType : ArgTypes)
783     MIB.addUse(getSPIRVTypeID(ArgType));
784   return MIB;
785 }
786 
getOrCreateOpTypeFunctionWithArgs(const Type * Ty,SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)787 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
788     const Type *Ty, SPIRVType *RetType,
789     const SmallVectorImpl<SPIRVType *> &ArgTypes,
790     MachineIRBuilder &MIRBuilder) {
791   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
792   if (Reg.isValid())
793     return getSPIRVTypeForVReg(Reg);
794   SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
795   DT.add(Ty, CurMF, getSPIRVTypeID(SpirvType));
796   return finishCreatingSPIRVType(Ty, SpirvType);
797 }
798 
findSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,bool EmitIR)799 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
800     const Type *Ty, MachineIRBuilder &MIRBuilder,
801     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
802   Ty = adjustIntTypeByWidth(Ty);
803   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
804   if (Reg.isValid())
805     return getSPIRVTypeForVReg(Reg);
806   if (ForwardPointerTypes.contains(Ty))
807     return ForwardPointerTypes[Ty];
808   return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
809 }
810 
getSPIRVTypeID(const SPIRVType * SpirvType) const811 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
812   assert(SpirvType && "Attempting to get type id for nullptr type.");
813   if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
814     return SpirvType->uses().begin()->getReg();
815   return SpirvType->defs().begin()->getReg();
816 }
817 
818 // We need to use a new LLVM integer type if there is a mismatch between
819 // number of bits in LLVM and SPIRV integer types to let DuplicateTracker
820 // ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
821 // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
822 // same "OpTypeInt 8" type for a series of LLVM integer types with number of
823 // bits less than 8. This would lead to duplicate type definitions
824 // eventually due to the method that DuplicateTracker utilizes to reason
825 // about uniqueness of type records.
adjustIntTypeByWidth(const Type * Ty) const826 const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
827   if (auto IType = dyn_cast<IntegerType>(Ty)) {
828     unsigned SrcBitWidth = IType->getBitWidth();
829     if (SrcBitWidth > 1) {
830       unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
831       // Maybe change source LLVM type to keep DuplicateTracker consistent.
832       if (SrcBitWidth != BitWidth)
833         Ty = IntegerType::get(Ty->getContext(), BitWidth);
834     }
835   }
836   return Ty;
837 }
838 
createSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,bool EmitIR)839 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
840     const Type *Ty, MachineIRBuilder &MIRBuilder,
841     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
842   if (isSpecialOpaqueType(Ty))
843     return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
844   auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
845   auto t = TypeToSPIRVTypeMap.find(Ty);
846   if (t != TypeToSPIRVTypeMap.end()) {
847     auto tt = t->second.find(&MIRBuilder.getMF());
848     if (tt != t->second.end())
849       return getSPIRVTypeForVReg(tt->second);
850   }
851 
852   if (auto IType = dyn_cast<IntegerType>(Ty)) {
853     const unsigned Width = IType->getBitWidth();
854     return Width == 1 ? getOpTypeBool(MIRBuilder)
855                       : getOpTypeInt(Width, MIRBuilder, false);
856   }
857   if (Ty->isFloatingPointTy())
858     return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
859   if (Ty->isVoidTy())
860     return getOpTypeVoid(MIRBuilder);
861   if (Ty->isVectorTy()) {
862     SPIRVType *El =
863         findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
864     return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
865                            MIRBuilder);
866   }
867   if (Ty->isArrayTy()) {
868     SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
869     return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
870   }
871   if (auto SType = dyn_cast<StructType>(Ty)) {
872     if (SType->isOpaque())
873       return getOpTypeOpaque(SType, MIRBuilder);
874     return getOpTypeStruct(SType, MIRBuilder, EmitIR);
875   }
876   if (auto FType = dyn_cast<FunctionType>(Ty)) {
877     SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
878     SmallVector<SPIRVType *, 4> ParamTypes;
879     for (const auto &t : FType->params()) {
880       ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
881     }
882     return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
883   }
884   unsigned AddrSpace = 0xFFFF;
885   if (auto PType = dyn_cast<TypedPointerType>(Ty))
886     AddrSpace = PType->getAddressSpace();
887   else if (auto PType = dyn_cast<PointerType>(Ty))
888     AddrSpace = PType->getAddressSpace();
889   else
890     report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
891 
892   SPIRVType *SpvElementType = nullptr;
893   if (auto PType = dyn_cast<TypedPointerType>(Ty))
894     SpvElementType = getOrCreateSPIRVType(PType->getElementType(), MIRBuilder,
895                                           AccQual, EmitIR);
896   else
897     SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
898 
899   // Get access to information about available extensions
900   const SPIRVSubtarget *ST =
901       static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
902   auto SC = addressSpaceToStorageClass(AddrSpace, *ST);
903   // Null pointer means we have a loop in type definitions, make and
904   // return corresponding OpTypeForwardPointer.
905   if (SpvElementType == nullptr) {
906     if (!ForwardPointerTypes.contains(Ty))
907       ForwardPointerTypes[Ty] = getOpTypeForwardPointer(SC, MIRBuilder);
908     return ForwardPointerTypes[Ty];
909   }
910   // If we have forward pointer associated with this type, use its register
911   // operand to create OpTypePointer.
912   if (ForwardPointerTypes.contains(Ty)) {
913     Register Reg = getSPIRVTypeID(ForwardPointerTypes[Ty]);
914     return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
915   }
916 
917   return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
918 }
919 
restOfCreateSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)920 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
921     const Type *Ty, MachineIRBuilder &MIRBuilder,
922     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
923   if (TypesInProcessing.count(Ty) && !isPointerTy(Ty))
924     return nullptr;
925   TypesInProcessing.insert(Ty);
926   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
927   TypesInProcessing.erase(Ty);
928   VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
929   SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
930   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
931   // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
932   // will be added later. For special types it is already added to DT.
933   if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
934       !isSpecialOpaqueType(Ty)) {
935     if (!isPointerTy(Ty))
936       DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
937     else if (isTypedPointerTy(Ty))
938       DT.add(cast<TypedPointerType>(Ty)->getElementType(),
939              getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
940              getSPIRVTypeID(SpirvType));
941     else
942       DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
943              getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
944              getSPIRVTypeID(SpirvType));
945   }
946 
947   return SpirvType;
948 }
949 
950 SPIRVType *
getSPIRVTypeForVReg(Register VReg,const MachineFunction * MF) const951 SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
952                                          const MachineFunction *MF) const {
953   auto t = VRegToTypeMap.find(MF ? MF : CurMF);
954   if (t != VRegToTypeMap.end()) {
955     auto tt = t->second.find(VReg);
956     if (tt != t->second.end())
957       return tt->second;
958   }
959   return nullptr;
960 }
961 
getOrCreateSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)962 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
963     const Type *Ty, MachineIRBuilder &MIRBuilder,
964     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
965   Register Reg;
966   if (!isPointerTy(Ty)) {
967     Ty = adjustIntTypeByWidth(Ty);
968     Reg = DT.find(Ty, &MIRBuilder.getMF());
969   } else if (isTypedPointerTy(Ty)) {
970     Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
971                   getPointerAddressSpace(Ty), &MIRBuilder.getMF());
972   } else {
973     Reg =
974         DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
975                 getPointerAddressSpace(Ty), &MIRBuilder.getMF());
976   }
977 
978   if (Reg.isValid() && !isSpecialOpaqueType(Ty))
979     return getSPIRVTypeForVReg(Reg);
980   TypesInProcessing.clear();
981   SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
982   // Create normal pointer types for the corresponding OpTypeForwardPointers.
983   for (auto &CU : ForwardPointerTypes) {
984     const Type *Ty2 = CU.first;
985     SPIRVType *STy2 = CU.second;
986     if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
987       STy2 = getSPIRVTypeForVReg(Reg);
988     else
989       STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
990     if (Ty == Ty2)
991       STy = STy2;
992   }
993   ForwardPointerTypes.clear();
994   return STy;
995 }
996 
isScalarOfType(Register VReg,unsigned TypeOpcode) const997 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
998                                          unsigned TypeOpcode) const {
999   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1000   assert(Type && "isScalarOfType VReg has no type assigned");
1001   return Type->getOpcode() == TypeOpcode;
1002 }
1003 
isScalarOrVectorOfType(Register VReg,unsigned TypeOpcode) const1004 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
1005                                                  unsigned TypeOpcode) const {
1006   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1007   assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
1008   if (Type->getOpcode() == TypeOpcode)
1009     return true;
1010   if (Type->getOpcode() == SPIRV::OpTypeVector) {
1011     Register ScalarTypeVReg = Type->getOperand(1).getReg();
1012     SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
1013     return ScalarType->getOpcode() == TypeOpcode;
1014   }
1015   return false;
1016 }
1017 
1018 unsigned
getScalarOrVectorComponentCount(Register VReg) const1019 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
1020   return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg));
1021 }
1022 
1023 unsigned
getScalarOrVectorComponentCount(SPIRVType * Type) const1024 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
1025   if (!Type)
1026     return 0;
1027   return Type->getOpcode() == SPIRV::OpTypeVector
1028              ? static_cast<unsigned>(Type->getOperand(2).getImm())
1029              : 1;
1030 }
1031 
1032 unsigned
getScalarOrVectorBitWidth(const SPIRVType * Type) const1033 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
1034   assert(Type && "Invalid Type pointer");
1035   if (Type->getOpcode() == SPIRV::OpTypeVector) {
1036     auto EleTypeReg = Type->getOperand(1).getReg();
1037     Type = getSPIRVTypeForVReg(EleTypeReg);
1038   }
1039   if (Type->getOpcode() == SPIRV::OpTypeInt ||
1040       Type->getOpcode() == SPIRV::OpTypeFloat)
1041     return Type->getOperand(1).getImm();
1042   if (Type->getOpcode() == SPIRV::OpTypeBool)
1043     return 1;
1044   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
1045 }
1046 
getNumScalarOrVectorTotalBitWidth(const SPIRVType * Type) const1047 unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
1048     const SPIRVType *Type) const {
1049   assert(Type && "Invalid Type pointer");
1050   unsigned NumElements = 1;
1051   if (Type->getOpcode() == SPIRV::OpTypeVector) {
1052     NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
1053     Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
1054   }
1055   return Type->getOpcode() == SPIRV::OpTypeInt ||
1056                  Type->getOpcode() == SPIRV::OpTypeFloat
1057              ? NumElements * Type->getOperand(1).getImm()
1058              : 0;
1059 }
1060 
// Returns the integer type of Type, where Type is either an OpTypeInt itself
// or an OpTypeVector of OpTypeInt. Returns nullptr otherwise, or when Type or
// its component type is null.
const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
    const SPIRVType *Type) const {
  if (!Type)
    return nullptr;
  const SPIRVType *ScalarType =
      Type->getOpcode() == SPIRV::OpTypeVector
          ? getSPIRVTypeForVReg(Type->getOperand(1).getReg())
          : Type;
  if (!ScalarType || ScalarType->getOpcode() != SPIRV::OpTypeInt)
    return nullptr;
  return ScalarType;
}
1067 
isScalarOrVectorSigned(const SPIRVType * Type) const1068 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
1069   const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
1070   return IntType && IntType->getOperand(2).getImm() != 0;
1071 }
1072 
getPointeeType(SPIRVType * PtrType)1073 SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
1074   return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
1075              ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
1076              : nullptr;
1077 }
1078 
getPointeeTypeOp(Register PtrReg)1079 unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
1080   SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg));
1081   return ElemType ? ElemType->getOpcode() : 0;
1082 }
1083 
isBitcastCompatible(const SPIRVType * Type1,const SPIRVType * Type2) const1084 bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
1085                                               const SPIRVType *Type2) const {
1086   if (!Type1 || !Type2)
1087     return false;
1088   auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
1089   // Ignore difference between <1.5 and >=1.5 protocol versions:
1090   // it's valid if either Result Type or Operand is a pointer, and the other
1091   // is a pointer, an integer scalar, or an integer vector.
1092   if (Op1 == SPIRV::OpTypePointer &&
1093       (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
1094     return true;
1095   if (Op2 == SPIRV::OpTypePointer &&
1096       (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
1097     return true;
1098   unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
1099            Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
1100   return Bits1 > 0 && Bits1 == Bits2;
1101 }
1102 
1103 SPIRV::StorageClass::StorageClass
getPointerStorageClass(Register VReg) const1104 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
1105   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1106   assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
1107          Type->getOperand(1).isImm() && "Pointer type is expected");
1108   return static_cast<SPIRV::StorageClass::StorageClass>(
1109       Type->getOperand(1).getImm());
1110 }
1111 
getOrCreateOpTypeImage(MachineIRBuilder & MIRBuilder,SPIRVType * SampledType,SPIRV::Dim::Dim Dim,uint32_t Depth,uint32_t Arrayed,uint32_t Multisampled,uint32_t Sampled,SPIRV::ImageFormat::ImageFormat ImageFormat,SPIRV::AccessQualifier::AccessQualifier AccessQual)1112 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
1113     MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
1114     uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
1115     SPIRV::ImageFormat::ImageFormat ImageFormat,
1116     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1117   auto TD = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim,
1118                                     Depth, Arrayed, Multisampled, Sampled,
1119                                     ImageFormat, AccessQual);
1120   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1121     return Res;
1122   Register ResVReg = createTypeVReg(MIRBuilder);
1123   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
1124   return MIRBuilder.buildInstr(SPIRV::OpTypeImage)
1125       .addDef(ResVReg)
1126       .addUse(getSPIRVTypeID(SampledType))
1127       .addImm(Dim)
1128       .addImm(Depth)        // Depth (whether or not it is a Depth image).
1129       .addImm(Arrayed)      // Arrayed.
1130       .addImm(Multisampled) // Multisampled (0 = only single-sample).
1131       .addImm(Sampled)      // Sampled (0 = usage known at runtime).
1132       .addImm(ImageFormat)
1133       .addImm(AccessQual);
1134 }
1135 
1136 SPIRVType *
getOrCreateOpTypeSampler(MachineIRBuilder & MIRBuilder)1137 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
1138   auto TD = SPIRV::make_descr_sampler();
1139   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1140     return Res;
1141   Register ResVReg = createTypeVReg(MIRBuilder);
1142   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
1143   return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
1144 }
1145 
getOrCreateOpTypePipe(MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual)1146 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
1147     MachineIRBuilder &MIRBuilder,
1148     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
1149   auto TD = SPIRV::make_descr_pipe(AccessQual);
1150   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1151     return Res;
1152   Register ResVReg = createTypeVReg(MIRBuilder);
1153   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
1154   return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
1155       .addDef(ResVReg)
1156       .addImm(AccessQual);
1157 }
1158 
getOrCreateOpTypeDeviceEvent(MachineIRBuilder & MIRBuilder)1159 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
1160     MachineIRBuilder &MIRBuilder) {
1161   auto TD = SPIRV::make_descr_event();
1162   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1163     return Res;
1164   Register ResVReg = createTypeVReg(MIRBuilder);
1165   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
1166   return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg);
1167 }
1168 
getOrCreateOpTypeSampledImage(SPIRVType * ImageType,MachineIRBuilder & MIRBuilder)1169 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
1170     SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
1171   auto TD = SPIRV::make_descr_sampled_image(
1172       SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
1173           ImageType->getOperand(1).getReg())),
1174       ImageType);
1175   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
1176     return Res;
1177   Register ResVReg = createTypeVReg(MIRBuilder);
1178   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
1179   return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
1180       .addDef(ResVReg)
1181       .addUse(getSPIRVTypeID(ImageType));
1182 }
1183 
getOrCreateOpTypeCoopMatr(MachineIRBuilder & MIRBuilder,const TargetExtType * ExtensionType,const SPIRVType * ElemType,uint32_t Scope,uint32_t Rows,uint32_t Columns,uint32_t Use)1184 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
1185     MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
1186     const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
1187     uint32_t Use) {
1188   Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF());
1189   if (ResVReg.isValid())
1190     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
1191   ResVReg = createTypeVReg(MIRBuilder);
1192   SPIRVType *SpirvTy =
1193       MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
1194           .addDef(ResVReg)
1195           .addUse(getSPIRVTypeID(ElemType))
1196           .addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true))
1197           .addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true))
1198           .addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true))
1199           .addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));
1200   DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
1201   return SpirvTy;
1202 }
1203 
getOrCreateOpTypeByOpcode(const Type * Ty,MachineIRBuilder & MIRBuilder,unsigned Opcode)1204 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
1205     const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
1206   Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
1207   if (ResVReg.isValid())
1208     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
1209   ResVReg = createTypeVReg(MIRBuilder);
1210   SPIRVType *SpirvTy = MIRBuilder.buildInstr(Opcode).addDef(ResVReg);
1211   DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
1212   return SpirvTy;
1213 }
1214 
1215 const MachineInstr *
checkSpecialInstr(const SPIRV::SpecialTypeDescriptor & TD,MachineIRBuilder & MIRBuilder)1216 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
1217                                        MachineIRBuilder &MIRBuilder) {
1218   Register Reg = DT.find(TD, &MIRBuilder.getMF());
1219   if (Reg.isValid())
1220     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
1221   return nullptr;
1222 }
1223 
1224 // Returns nullptr if unable to recognize SPIRV type name
getOrCreateSPIRVTypeByName(StringRef TypeStr,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC,SPIRV::AccessQualifier::AccessQualifier AQ)1225 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
1226     StringRef TypeStr, MachineIRBuilder &MIRBuilder,
1227     SPIRV::StorageClass::StorageClass SC,
1228     SPIRV::AccessQualifier::AccessQualifier AQ) {
1229   unsigned VecElts = 0;
1230   auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
1231 
1232   // Parse strings representing either a SPIR-V or OpenCL builtin type.
1233   if (hasBuiltinTypePrefix(TypeStr))
1234     return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType(
1235                                     TypeStr.str(), MIRBuilder.getContext()),
1236                                 MIRBuilder, AQ);
1237 
1238   // Parse type name in either "typeN" or "type vector[N]" format, where
1239   // N is the number of elements of the vector.
1240   Type *Ty;
1241 
1242   Ty = parseBasicTypeName(TypeStr, Ctx);
1243   if (!Ty)
1244     // Unable to recognize SPIRV type name
1245     return nullptr;
1246 
1247   auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);
1248 
1249   // Handle "type*" or  "type* vector[N]".
1250   if (TypeStr.starts_with("*")) {
1251     SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1252     TypeStr = TypeStr.substr(strlen("*"));
1253   }
1254 
1255   // Handle "typeN*" or  "type vector[N]*".
1256   bool IsPtrToVec = TypeStr.consume_back("*");
1257 
1258   if (TypeStr.consume_front(" vector[")) {
1259     TypeStr = TypeStr.substr(0, TypeStr.find(']'));
1260   }
1261   TypeStr.getAsInteger(10, VecElts);
1262   if (VecElts > 0)
1263     SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
1264 
1265   if (IsPtrToVec)
1266     SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1267 
1268   return SpirvTy;
1269 }
1270 
1271 SPIRVType *
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineIRBuilder & MIRBuilder)1272 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
1273                                                  MachineIRBuilder &MIRBuilder) {
1274   return getOrCreateSPIRVType(
1275       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
1276       MIRBuilder);
1277 }
1278 
finishCreatingSPIRVType(const Type * LLVMTy,SPIRVType * SpirvType)1279 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
1280                                                         SPIRVType *SpirvType) {
1281   assert(CurMF == SpirvType->getMF());
1282   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
1283   SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
1284   return SpirvType;
1285 }
1286 
getOrCreateSPIRVType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII,unsigned SPIRVOPcode,Type * LLVMTy)1287 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
1288                                                      MachineInstr &I,
1289                                                      const SPIRVInstrInfo &TII,
1290                                                      unsigned SPIRVOPcode,
1291                                                      Type *LLVMTy) {
1292   Register Reg = DT.find(LLVMTy, CurMF);
1293   if (Reg.isValid())
1294     return getSPIRVTypeForVReg(Reg);
1295   MachineBasicBlock &BB = *I.getParent();
1296   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRVOPcode))
1297                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1298                  .addImm(BitWidth)
1299                  .addImm(0);
1300   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1301   return finishCreatingSPIRVType(LLVMTy, MIB);
1302 }
1303 
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII)1304 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
1305     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1306   // Maybe adjust bit width to keep DuplicateTracker consistent. Without
1307   // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
1308   // example, the same "OpTypeInt 8" type for a series of LLVM integer types
1309   // with number of bits less than 8, causing duplicate type definitions.
1310   BitWidth = adjustOpTypeIntWidth(BitWidth);
1311   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
1312   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
1313 }
1314 
getOrCreateSPIRVFloatType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII)1315 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
1316     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1317   LLVMContext &Ctx = CurMF->getFunction().getContext();
1318   Type *LLVMTy;
1319   switch (BitWidth) {
1320   case 16:
1321     LLVMTy = Type::getHalfTy(Ctx);
1322     break;
1323   case 32:
1324     LLVMTy = Type::getFloatTy(Ctx);
1325     break;
1326   case 64:
1327     LLVMTy = Type::getDoubleTy(Ctx);
1328     break;
1329   default:
1330     llvm_unreachable("Bit width is of unexpected size.");
1331   }
1332   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
1333 }
1334 
1335 SPIRVType *
getOrCreateSPIRVBoolType(MachineIRBuilder & MIRBuilder)1336 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
1337   return getOrCreateSPIRVType(
1338       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
1339       MIRBuilder);
1340 }
1341 
1342 SPIRVType *
getOrCreateSPIRVBoolType(MachineInstr & I,const SPIRVInstrInfo & TII)1343 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
1344                                               const SPIRVInstrInfo &TII) {
1345   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
1346   Register Reg = DT.find(LLVMTy, CurMF);
1347   if (Reg.isValid())
1348     return getSPIRVTypeForVReg(Reg);
1349   MachineBasicBlock &BB = *I.getParent();
1350   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
1351                  .addDef(createTypeVReg(CurMF->getRegInfo()));
1352   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1353   return finishCreatingSPIRVType(LLVMTy, MIB);
1354 }
1355 
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineIRBuilder & MIRBuilder)1356 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1357     SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
1358   return getOrCreateSPIRVType(
1359       FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1360                            NumElements),
1361       MIRBuilder);
1362 }
1363 
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineInstr & I,const SPIRVInstrInfo & TII)1364 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1365     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1366     const SPIRVInstrInfo &TII) {
1367   Type *LLVMTy = FixedVectorType::get(
1368       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1369   Register Reg = DT.find(LLVMTy, CurMF);
1370   if (Reg.isValid())
1371     return getSPIRVTypeForVReg(Reg);
1372   MachineBasicBlock &BB = *I.getParent();
1373   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
1374                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1375                  .addUse(getSPIRVTypeID(BaseType))
1376                  .addImm(NumElements);
1377   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1378   return finishCreatingSPIRVType(LLVMTy, MIB);
1379 }
1380 
getOrCreateSPIRVArrayType(SPIRVType * BaseType,unsigned NumElements,MachineInstr & I,const SPIRVInstrInfo & TII)1381 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1382     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1383     const SPIRVInstrInfo &TII) {
1384   Type *LLVMTy = ArrayType::get(
1385       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1386   Register Reg = DT.find(LLVMTy, CurMF);
1387   if (Reg.isValid())
1388     return getSPIRVTypeForVReg(Reg);
1389   MachineBasicBlock &BB = *I.getParent();
1390   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
1391   Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
1392   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
1393                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1394                  .addUse(getSPIRVTypeID(BaseType))
1395                  .addUse(Len);
1396   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1397   return finishCreatingSPIRVType(LLVMTy, MIB);
1398 }
1399 
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC)1400 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1401     SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1402     SPIRV::StorageClass::StorageClass SC) {
1403   const Type *PointerElementType = getTypeForSPIRVType(BaseType);
1404   unsigned AddressSpace = storageClassToAddressSpace(SC);
1405   Type *LLVMTy = TypedPointerType::get(const_cast<Type *>(PointerElementType),
1406                                        AddressSpace);
1407   // check if this type is already available
1408   Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
1409   if (Reg.isValid())
1410     return getSPIRVTypeForVReg(Reg);
1411   // create a new type
1412   auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
1413                      MIRBuilder.getDebugLoc(),
1414                      MIRBuilder.getTII().get(SPIRV::OpTypePointer))
1415                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1416                  .addImm(static_cast<uint32_t>(SC))
1417                  .addUse(getSPIRVTypeID(BaseType));
1418   DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
1419   return finishCreatingSPIRVType(LLVMTy, MIB);
1420 }
1421 
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineInstr & I,const SPIRVInstrInfo &,SPIRV::StorageClass::StorageClass SC)1422 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1423     SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
1424     SPIRV::StorageClass::StorageClass SC) {
1425   MachineIRBuilder MIRBuilder(I);
1426   return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
1427 }
1428 
getOrCreateUndef(MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)1429 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1430                                                SPIRVType *SpvType,
1431                                                const SPIRVInstrInfo &TII) {
1432   assert(SpvType);
1433   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
1434   assert(LLVMTy);
1435   // Find a constant in DT or build a new one.
1436   UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
1437   Register Res = DT.find(UV, CurMF);
1438   if (Res.isValid())
1439     return Res;
1440   LLT LLTy = LLT::scalar(32);
1441   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
1442   CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
1443   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
1444   DT.add(UV, CurMF, Res);
1445 
1446   MachineInstrBuilder MIB;
1447   MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
1448             .addDef(Res)
1449             .addUse(getSPIRVTypeID(SpvType));
1450   const auto &ST = CurMF->getSubtarget();
1451   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
1452                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
1453   return Res;
1454 }
1455