1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the implementation of the SPIRVGlobalRegistry class, 10 // which is used to maintain rich type information required for SPIR-V even 11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into 12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds 13 // and supports consistency of constants and global variables. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "SPIRVGlobalRegistry.h" 18 #include "SPIRV.h" 19 #include "SPIRVBuiltins.h" 20 #include "SPIRVSubtarget.h" 21 #include "SPIRVTargetMachine.h" 22 #include "SPIRVUtils.h" 23 #include "llvm/ADT/APInt.h" 24 #include "llvm/IR/Constants.h" 25 #include "llvm/IR/Type.h" 26 #include "llvm/Support/Casting.h" 27 #include <cassert> 28 29 using namespace llvm; 30 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) 31 : PointerSize(PointerSize), Bound(0) {} 32 33 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth, 34 Register VReg, 35 MachineInstr &I, 36 const SPIRVInstrInfo &TII) { 37 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 38 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 39 return SpirvType; 40 } 41 42 SPIRVType * 43 SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg, 44 MachineInstr &I, 45 const SPIRVInstrInfo &TII) { 46 SPIRVType *SpirvType = getOrCreateSPIRVFloatType(BitWidth, I, TII); 47 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 48 return SpirvType; 49 } 50 51 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg( 52 SPIRVType 
*BaseType, unsigned NumElements, Register VReg, MachineInstr &I, 53 const SPIRVInstrInfo &TII) { 54 SPIRVType *SpirvType = 55 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII); 56 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 57 return SpirvType; 58 } 59 60 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( 61 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, 62 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 63 SPIRVType *SpirvType = 64 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); 65 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF()); 66 return SpirvType; 67 } 68 69 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType, 70 Register VReg, 71 MachineFunction &MF) { 72 VRegToTypeMap[&MF][VReg] = SpirvType; 73 } 74 75 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) { 76 auto &MRI = MIRBuilder.getMF().getRegInfo(); 77 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 78 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 79 return Res; 80 } 81 82 static Register createTypeVReg(MachineRegisterInfo &MRI) { 83 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 84 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 85 return Res; 86 } 87 88 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { 89 return MIRBuilder.buildInstr(SPIRV::OpTypeBool) 90 .addDef(createTypeVReg(MIRBuilder)); 91 } 92 93 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const { 94 if (Width > 64) 95 report_fatal_error("Unsupported integer width!"); 96 const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget()); 97 if (ST.canUseExtension( 98 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) 99 return Width; 100 if (Width <= 8) 101 Width = 8; 102 else if (Width <= 16) 103 Width = 16; 104 else if (Width <= 32) 105 Width = 32; 106 else 107 Width = 64; 108 return Width; 109 } 110 111 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width, 112 
MachineIRBuilder &MIRBuilder, 113 bool IsSigned) { 114 Width = adjustOpTypeIntWidth(Width); 115 const SPIRVSubtarget &ST = 116 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget()); 117 if (ST.canUseExtension( 118 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) { 119 MIRBuilder.buildInstr(SPIRV::OpExtension) 120 .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers); 121 MIRBuilder.buildInstr(SPIRV::OpCapability) 122 .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL); 123 } 124 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt) 125 .addDef(createTypeVReg(MIRBuilder)) 126 .addImm(Width) 127 .addImm(IsSigned ? 1 : 0); 128 return MIB; 129 } 130 131 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, 132 MachineIRBuilder &MIRBuilder) { 133 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat) 134 .addDef(createTypeVReg(MIRBuilder)) 135 .addImm(Width); 136 return MIB; 137 } 138 139 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { 140 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) 141 .addDef(createTypeVReg(MIRBuilder)); 142 } 143 144 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, 145 SPIRVType *ElemType, 146 MachineIRBuilder &MIRBuilder) { 147 auto EleOpc = ElemType->getOpcode(); 148 (void)EleOpc; 149 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || 150 EleOpc == SPIRV::OpTypeBool) && 151 "Invalid vector element type"); 152 153 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector) 154 .addDef(createTypeVReg(MIRBuilder)) 155 .addUse(getSPIRVTypeID(ElemType)) 156 .addImm(NumElems); 157 return MIB; 158 } 159 160 std::tuple<Register, ConstantInt *, bool> 161 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType, 162 MachineIRBuilder *MIRBuilder, 163 MachineInstr *I, 164 const SPIRVInstrInfo *TII) { 165 const IntegerType *LLVMIntTy; 166 if (SpvType) 167 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 168 else 169 LLVMIntTy = 
IntegerType::getInt32Ty(CurMF->getFunction().getContext()); 170 bool NewInstr = false; 171 // Find a constant in DT or build a new one. 172 ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 173 Register Res = DT.find(CI, CurMF); 174 if (!Res.isValid()) { 175 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 176 // TODO: handle cases where the type is not 32bit wide 177 // TODO: https://github.com/llvm/llvm-project/issues/88129 178 LLT LLTy = LLT::scalar(32); 179 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 180 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 181 if (MIRBuilder) 182 assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder); 183 else 184 assignIntTypeToVReg(BitWidth, Res, *I, *TII); 185 DT.add(CI, CurMF, Res); 186 NewInstr = true; 187 } 188 return std::make_tuple(Res, CI, NewInstr); 189 } 190 191 std::tuple<Register, ConstantFP *, bool, unsigned> 192 SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType, 193 MachineIRBuilder *MIRBuilder, 194 MachineInstr *I, 195 const SPIRVInstrInfo *TII) { 196 const Type *LLVMFloatTy; 197 LLVMContext &Ctx = CurMF->getFunction().getContext(); 198 unsigned BitWidth = 32; 199 if (SpvType) 200 LLVMFloatTy = getTypeForSPIRVType(SpvType); 201 else { 202 LLVMFloatTy = Type::getFloatTy(Ctx); 203 if (MIRBuilder) 204 SpvType = getOrCreateSPIRVType(LLVMFloatTy, *MIRBuilder); 205 } 206 bool NewInstr = false; 207 // Find a constant in DT or build a new one. 
208 auto *const CI = ConstantFP::get(Ctx, Val); 209 Register Res = DT.find(CI, CurMF); 210 if (!Res.isValid()) { 211 if (SpvType) 212 BitWidth = getScalarOrVectorBitWidth(SpvType); 213 // TODO: handle cases where the type is not 32bit wide 214 // TODO: https://github.com/llvm/llvm-project/issues/88129 215 LLT LLTy = LLT::scalar(32); 216 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 217 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 218 if (MIRBuilder) 219 assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder); 220 else 221 assignFloatTypeToVReg(BitWidth, Res, *I, *TII); 222 DT.add(CI, CurMF, Res); 223 NewInstr = true; 224 } 225 return std::make_tuple(Res, CI, NewInstr, BitWidth); 226 } 227 228 Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I, 229 SPIRVType *SpvType, 230 const SPIRVInstrInfo &TII, 231 bool ZeroAsNull) { 232 assert(SpvType); 233 ConstantFP *CI; 234 Register Res; 235 bool New; 236 unsigned BitWidth; 237 std::tie(Res, CI, New, BitWidth) = 238 getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII); 239 // If we have found Res register which is defined by the passed G_CONSTANT 240 // machine instruction, a new constant instruction should be created. 
241 if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg())) 242 return Res; 243 MachineInstrBuilder MIB; 244 MachineBasicBlock &BB = *I.getParent(); 245 // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0) 246 if (Val.isPosZero() && ZeroAsNull) { 247 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) 248 .addDef(Res) 249 .addUse(getSPIRVTypeID(SpvType)); 250 } else { 251 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantF)) 252 .addDef(Res) 253 .addUse(getSPIRVTypeID(SpvType)); 254 addNumImm( 255 APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()), 256 MIB); 257 } 258 const auto &ST = CurMF->getSubtarget(); 259 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(), 260 *ST.getRegisterInfo(), *ST.getRegBankInfo()); 261 return Res; 262 } 263 264 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, 265 SPIRVType *SpvType, 266 const SPIRVInstrInfo &TII, 267 bool ZeroAsNull) { 268 assert(SpvType); 269 ConstantInt *CI; 270 Register Res; 271 bool New; 272 std::tie(Res, CI, New) = 273 getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII); 274 // If we have found Res register which is defined by the passed G_CONSTANT 275 // machine instruction, a new constant instruction should be created. 
276 if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg())) 277 return Res; 278 MachineInstrBuilder MIB; 279 MachineBasicBlock &BB = *I.getParent(); 280 if (Val || !ZeroAsNull) { 281 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI)) 282 .addDef(Res) 283 .addUse(getSPIRVTypeID(SpvType)); 284 addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB); 285 } else { 286 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) 287 .addDef(Res) 288 .addUse(getSPIRVTypeID(SpvType)); 289 } 290 const auto &ST = CurMF->getSubtarget(); 291 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(), 292 *ST.getRegisterInfo(), *ST.getRegBankInfo()); 293 return Res; 294 } 295 296 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val, 297 MachineIRBuilder &MIRBuilder, 298 SPIRVType *SpvType, 299 bool EmitIR) { 300 auto &MF = MIRBuilder.getMF(); 301 const IntegerType *LLVMIntTy; 302 if (SpvType) 303 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 304 else 305 LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext()); 306 // Find a constant in DT or build a new one. 307 const auto ConstInt = 308 ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 309 Register Res = DT.find(ConstInt, &MF); 310 if (!Res.isValid()) { 311 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 312 LLT LLTy = LLT::scalar(EmitIR ? 
BitWidth : 32); 313 Res = MF.getRegInfo().createGenericVirtualRegister(LLTy); 314 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 315 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder, 316 SPIRV::AccessQualifier::ReadWrite, EmitIR); 317 DT.add(ConstInt, &MIRBuilder.getMF(), Res); 318 if (EmitIR) { 319 MIRBuilder.buildConstant(Res, *ConstInt); 320 } else { 321 if (!SpvType) 322 SpvType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder); 323 MachineInstrBuilder MIB; 324 if (Val) { 325 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) 326 .addDef(Res) 327 .addUse(getSPIRVTypeID(SpvType)); 328 addNumImm(APInt(BitWidth, Val), MIB); 329 } else { 330 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) 331 .addDef(Res) 332 .addUse(getSPIRVTypeID(SpvType)); 333 } 334 const auto &Subtarget = CurMF->getSubtarget(); 335 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 336 *Subtarget.getRegisterInfo(), 337 *Subtarget.getRegBankInfo()); 338 } 339 } 340 return Res; 341 } 342 343 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val, 344 MachineIRBuilder &MIRBuilder, 345 SPIRVType *SpvType) { 346 auto &MF = MIRBuilder.getMF(); 347 auto &Ctx = MF.getFunction().getContext(); 348 if (!SpvType) { 349 const Type *LLVMFPTy = Type::getFloatTy(Ctx); 350 SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder); 351 } 352 // Find a constant in DT or build a new one. 
353 const auto ConstFP = ConstantFP::get(Ctx, Val); 354 Register Res = DT.find(ConstFP, &MF); 355 if (!Res.isValid()) { 356 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32)); 357 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 358 assignSPIRVTypeToVReg(SpvType, Res, MF); 359 DT.add(ConstFP, &MF, Res); 360 361 MachineInstrBuilder MIB; 362 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF) 363 .addDef(Res) 364 .addUse(getSPIRVTypeID(SpvType)); 365 addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB); 366 } 367 368 return Res; 369 } 370 371 Register SPIRVGlobalRegistry::getOrCreateBaseRegister(Constant *Val, 372 MachineInstr &I, 373 SPIRVType *SpvType, 374 const SPIRVInstrInfo &TII, 375 unsigned BitWidth) { 376 SPIRVType *Type = SpvType; 377 if (SpvType->getOpcode() == SPIRV::OpTypeVector || 378 SpvType->getOpcode() == SPIRV::OpTypeArray) { 379 auto EleTypeReg = SpvType->getOperand(1).getReg(); 380 Type = getSPIRVTypeForVReg(EleTypeReg); 381 } 382 if (Type->getOpcode() == SPIRV::OpTypeFloat) { 383 SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII); 384 return getOrCreateConstFP(dyn_cast<ConstantFP>(Val)->getValue(), I, 385 SpvBaseType, TII); 386 } 387 assert(Type->getOpcode() == SPIRV::OpTypeInt); 388 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 389 return getOrCreateConstInt(Val->getUniqueInteger().getSExtValue(), I, 390 SpvBaseType, TII); 391 } 392 393 Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull( 394 Constant *Val, MachineInstr &I, SPIRVType *SpvType, 395 const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth, 396 unsigned ElemCnt, bool ZeroAsNull) { 397 // Find a constant vector or array in DT or build a new one. 398 Register Res = DT.find(CA, CurMF); 399 // If no values are attached, the composite is null constant. 
400 bool IsNull = Val->isNullValue() && ZeroAsNull; 401 if (!Res.isValid()) { 402 // SpvScalConst should be created before SpvVecConst to avoid undefined ID 403 // error on validation. 404 // TODO: can moved below once sorting of types/consts/defs is implemented. 405 Register SpvScalConst; 406 if (!IsNull) 407 SpvScalConst = getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth); 408 409 // TODO: handle cases where the type is not 32bit wide 410 // TODO: https://github.com/llvm/llvm-project/issues/88129 411 LLT LLTy = LLT::scalar(32); 412 Register SpvVecConst = 413 CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 414 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass); 415 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); 416 DT.add(CA, CurMF, SpvVecConst); 417 MachineInstrBuilder MIB; 418 MachineBasicBlock &BB = *I.getParent(); 419 if (!IsNull) { 420 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite)) 421 .addDef(SpvVecConst) 422 .addUse(getSPIRVTypeID(SpvType)); 423 for (unsigned i = 0; i < ElemCnt; ++i) 424 MIB.addUse(SpvScalConst); 425 } else { 426 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) 427 .addDef(SpvVecConst) 428 .addUse(getSPIRVTypeID(SpvType)); 429 } 430 const auto &Subtarget = CurMF->getSubtarget(); 431 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 432 *Subtarget.getRegisterInfo(), 433 *Subtarget.getRegBankInfo()); 434 return SpvVecConst; 435 } 436 return Res; 437 } 438 439 Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val, 440 MachineInstr &I, 441 SPIRVType *SpvType, 442 const SPIRVInstrInfo &TII, 443 bool ZeroAsNull) { 444 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 445 assert(LLVMTy->isVectorTy()); 446 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 447 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 448 assert(LLVMBaseTy->isIntegerTy()); 449 auto *ConstVal = ConstantInt::get(LLVMBaseTy, Val); 450 auto *ConstVec = 
451 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal); 452 unsigned BW = getScalarOrVectorBitWidth(SpvType); 453 return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW, 454 SpvType->getOperand(2).getImm(), 455 ZeroAsNull); 456 } 457 458 Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val, 459 MachineInstr &I, 460 SPIRVType *SpvType, 461 const SPIRVInstrInfo &TII, 462 bool ZeroAsNull) { 463 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 464 assert(LLVMTy->isVectorTy()); 465 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 466 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 467 assert(LLVMBaseTy->isFloatingPointTy()); 468 auto *ConstVal = ConstantFP::get(LLVMBaseTy, Val); 469 auto *ConstVec = 470 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal); 471 unsigned BW = getScalarOrVectorBitWidth(SpvType); 472 return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW, 473 SpvType->getOperand(2).getImm(), 474 ZeroAsNull); 475 } 476 477 Register SPIRVGlobalRegistry::getOrCreateConstIntArray( 478 uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType, 479 const SPIRVInstrInfo &TII) { 480 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 481 assert(LLVMTy->isArrayTy()); 482 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy); 483 Type *LLVMBaseTy = LLVMArrTy->getElementType(); 484 Constant *CI = ConstantInt::get(LLVMBaseTy, Val); 485 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); 486 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy); 487 // The following is reasonably unique key that is better that [Val]. 
The naive 488 // alternative would be something along the lines of: 489 // SmallVector<Constant *> NumCI(Num, CI); 490 // Constant *UniqueKey = 491 // ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI); 492 // that would be a truly unique but dangerous key, because it could lead to 493 // the creation of constants of arbitrary length (that is, the parameter of 494 // memset) which were missing in the original module. 495 Constant *UniqueKey = ConstantStruct::getAnon( 496 {PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)), 497 ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)}); 498 return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW, 499 LLVMArrTy->getNumElements()); 500 } 501 502 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull( 503 uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, 504 Constant *CA, unsigned BitWidth, unsigned ElemCnt) { 505 Register Res = DT.find(CA, CurMF); 506 if (!Res.isValid()) { 507 Register SpvScalConst; 508 if (Val || EmitIR) { 509 SPIRVType *SpvBaseType = 510 getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder); 511 SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR); 512 } 513 LLT LLTy = EmitIR ? 
LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32); 514 Register SpvVecConst = 515 CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 516 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass); 517 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); 518 DT.add(CA, CurMF, SpvVecConst); 519 if (EmitIR) { 520 MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst); 521 } else { 522 if (Val) { 523 auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite) 524 .addDef(SpvVecConst) 525 .addUse(getSPIRVTypeID(SpvType)); 526 for (unsigned i = 0; i < ElemCnt; ++i) 527 MIB.addUse(SpvScalConst); 528 } else { 529 MIRBuilder.buildInstr(SPIRV::OpConstantNull) 530 .addDef(SpvVecConst) 531 .addUse(getSPIRVTypeID(SpvType)); 532 } 533 } 534 return SpvVecConst; 535 } 536 return Res; 537 } 538 539 Register 540 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, 541 MachineIRBuilder &MIRBuilder, 542 SPIRVType *SpvType, bool EmitIR) { 543 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 544 assert(LLVMTy->isVectorTy()); 545 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 546 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 547 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 548 auto ConstVec = 549 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt); 550 unsigned BW = getScalarOrVectorBitWidth(SpvType); 551 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR, 552 ConstVec, BW, 553 SpvType->getOperand(2).getImm()); 554 } 555 556 Register 557 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder, 558 SPIRVType *SpvType) { 559 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 560 const TypedPointerType *LLVMPtrTy = cast<TypedPointerType>(LLVMTy); 561 // Find a constant in DT or build a new one. 
562 Constant *CP = ConstantPointerNull::get(PointerType::get( 563 LLVMPtrTy->getElementType(), LLVMPtrTy->getAddressSpace())); 564 Register Res = DT.find(CP, CurMF); 565 if (!Res.isValid()) { 566 LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize); 567 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 568 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 569 assignSPIRVTypeToVReg(SpvType, Res, *CurMF); 570 MIRBuilder.buildInstr(SPIRV::OpConstantNull) 571 .addDef(Res) 572 .addUse(getSPIRVTypeID(SpvType)); 573 DT.add(CP, CurMF, Res); 574 } 575 return Res; 576 } 577 578 Register SPIRVGlobalRegistry::buildConstantSampler( 579 Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode, 580 MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) { 581 SPIRVType *SampTy; 582 if (SpvType) 583 SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder); 584 else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", 585 MIRBuilder)) == nullptr) 586 report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t"); 587 588 auto Sampler = 589 ResReg.isValid() 590 ? 
ResReg 591 : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); 592 auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler) 593 .addDef(Sampler) 594 .addUse(getSPIRVTypeID(SampTy)) 595 .addImm(AddrMode) 596 .addImm(Param) 597 .addImm(FilerMode); 598 assert(Res->getOperand(0).isReg()); 599 return Res->getOperand(0).getReg(); 600 } 601 602 Register SPIRVGlobalRegistry::buildGlobalVariable( 603 Register ResVReg, SPIRVType *BaseType, StringRef Name, 604 const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage, 605 const MachineInstr *Init, bool IsConst, bool HasLinkageTy, 606 SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, 607 bool IsInstSelector) { 608 const GlobalVariable *GVar = nullptr; 609 if (GV) 610 GVar = cast<const GlobalVariable>(GV); 611 else { 612 // If GV is not passed explicitly, use the name to find or construct 613 // the global variable. 614 Module *M = MIRBuilder.getMF().getFunction().getParent(); 615 GVar = M->getGlobalVariable(Name); 616 if (GVar == nullptr) { 617 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type. 618 // Module takes ownership of the global var. 619 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false, 620 GlobalValue::ExternalLinkage, nullptr, 621 Twine(Name)); 622 } 623 GV = GVar; 624 } 625 Register Reg = DT.find(GVar, &MIRBuilder.getMF()); 626 if (Reg.isValid()) { 627 if (Reg != ResVReg) 628 MIRBuilder.buildCopy(ResVReg, Reg); 629 return ResVReg; 630 } 631 632 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable) 633 .addDef(ResVReg) 634 .addUse(getSPIRVTypeID(BaseType)) 635 .addImm(static_cast<uint32_t>(Storage)); 636 637 if (Init != 0) { 638 MIB.addUse(Init->getOperand(0).getReg()); 639 } 640 641 // ISel may introduce a new register on this step, so we need to add it to 642 // DT and correct its type avoiding fails on the next stage. 
643 if (IsInstSelector) { 644 const auto &Subtarget = CurMF->getSubtarget(); 645 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 646 *Subtarget.getRegisterInfo(), 647 *Subtarget.getRegBankInfo()); 648 } 649 Reg = MIB->getOperand(0).getReg(); 650 DT.add(GVar, &MIRBuilder.getMF(), Reg); 651 652 // Set to Reg the same type as ResVReg has. 653 auto MRI = MIRBuilder.getMRI(); 654 assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected"); 655 if (Reg != ResVReg) { 656 LLT RegLLTy = 657 LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize()); 658 MRI->setType(Reg, RegLLTy); 659 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); 660 } else { 661 // Our knowledge about the type may be updated. 662 // If that's the case, we need to update a type 663 // associated with the register. 664 SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg); 665 if (!DefType || DefType != BaseType) 666 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); 667 } 668 669 // If it's a global variable with name, output OpName for it. 670 if (GVar && GVar->hasName()) 671 buildOpName(Reg, GVar->getName(), MIRBuilder); 672 673 // Output decorations for the GV. 674 // TODO: maybe move to GenerateDecorations pass. 
675 const SPIRVSubtarget &ST = 676 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget()); 677 if (IsConst && ST.isOpenCLEnv()) 678 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {}); 679 680 if (GVar && GVar->getAlign().valueOrOne().value() != 1) { 681 unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value(); 682 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment}); 683 } 684 685 if (HasLinkageTy) 686 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 687 {static_cast<uint32_t>(LinkageType)}, Name); 688 689 SPIRV::BuiltIn::BuiltIn BuiltInId; 690 if (getSpirvBuiltInIdByName(Name, BuiltInId)) 691 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn, 692 {static_cast<uint32_t>(BuiltInId)}); 693 694 // If it's a global variable with "spirv.Decorations" metadata node 695 // recognize it as a SPIR-V friendly LLVM IR and parse "spirv.Decorations" 696 // arguments. 697 MDNode *GVarMD = nullptr; 698 if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr) 699 buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD); 700 701 return Reg; 702 } 703 704 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, 705 SPIRVType *ElemType, 706 MachineIRBuilder &MIRBuilder, 707 bool EmitIR) { 708 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && 709 "Invalid array element type"); 710 Register NumElementsVReg = 711 buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR); 712 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray) 713 .addDef(createTypeVReg(MIRBuilder)) 714 .addUse(getSPIRVTypeID(ElemType)) 715 .addUse(NumElementsVReg); 716 return MIB; 717 } 718 719 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty, 720 MachineIRBuilder &MIRBuilder) { 721 assert(Ty->hasName()); 722 const StringRef Name = Ty->hasName() ? 
Ty->getName() : ""; 723 Register ResVReg = createTypeVReg(MIRBuilder); 724 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg); 725 addStringImm(Name, MIB); 726 buildOpName(ResVReg, Name, MIRBuilder); 727 return MIB; 728 } 729 730 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty, 731 MachineIRBuilder &MIRBuilder, 732 bool EmitIR) { 733 SmallVector<Register, 4> FieldTypes; 734 for (const auto &Elem : Ty->elements()) { 735 SPIRVType *ElemTy = findSPIRVType(toTypedPointer(Elem), MIRBuilder); 736 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid && 737 "Invalid struct element type"); 738 FieldTypes.push_back(getSPIRVTypeID(ElemTy)); 739 } 740 Register ResVReg = createTypeVReg(MIRBuilder); 741 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg); 742 for (const auto &Ty : FieldTypes) 743 MIB.addUse(Ty); 744 if (Ty->hasName()) 745 buildOpName(ResVReg, Ty->getName(), MIRBuilder); 746 if (Ty->isPacked()) 747 buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {}); 748 return MIB; 749 } 750 751 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType( 752 const Type *Ty, MachineIRBuilder &MIRBuilder, 753 SPIRV::AccessQualifier::AccessQualifier AccQual) { 754 assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type"); 755 return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this); 756 } 757 758 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer( 759 SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType, 760 MachineIRBuilder &MIRBuilder, Register Reg) { 761 if (!Reg.isValid()) 762 Reg = createTypeVReg(MIRBuilder); 763 return MIRBuilder.buildInstr(SPIRV::OpTypePointer) 764 .addDef(Reg) 765 .addImm(static_cast<uint32_t>(SC)) 766 .addUse(getSPIRVTypeID(ElemType)); 767 } 768 769 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer( 770 SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) { 771 return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer) 772 
.addUse(createTypeVReg(MIRBuilder)) 773 .addImm(static_cast<uint32_t>(SC)); 774 } 775 776 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( 777 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes, 778 MachineIRBuilder &MIRBuilder) { 779 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) 780 .addDef(createTypeVReg(MIRBuilder)) 781 .addUse(getSPIRVTypeID(RetType)); 782 for (const SPIRVType *ArgType : ArgTypes) 783 MIB.addUse(getSPIRVTypeID(ArgType)); 784 return MIB; 785 } 786 787 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs( 788 const Type *Ty, SPIRVType *RetType, 789 const SmallVectorImpl<SPIRVType *> &ArgTypes, 790 MachineIRBuilder &MIRBuilder) { 791 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 792 if (Reg.isValid()) 793 return getSPIRVTypeForVReg(Reg); 794 SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder); 795 DT.add(Ty, CurMF, getSPIRVTypeID(SpirvType)); 796 return finishCreatingSPIRVType(Ty, SpirvType); 797 } 798 799 SPIRVType *SPIRVGlobalRegistry::findSPIRVType( 800 const Type *Ty, MachineIRBuilder &MIRBuilder, 801 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { 802 Ty = adjustIntTypeByWidth(Ty); 803 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 804 if (Reg.isValid()) 805 return getSPIRVTypeForVReg(Reg); 806 if (ForwardPointerTypes.contains(Ty)) 807 return ForwardPointerTypes[Ty]; 808 return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR); 809 } 810 811 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const { 812 assert(SpirvType && "Attempting to get type id for nullptr type."); 813 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer) 814 return SpirvType->uses().begin()->getReg(); 815 return SpirvType->defs().begin()->getReg(); 816 } 817 818 // We need to use a new LLVM integer type if there is a mismatch between 819 // number of bits in LLVM and SPIRV integer types to let DuplicateTracker 820 // ensure uniqueness of a SPIRV type by the 
// corresponding LLVM type. Without
// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
// same "OpTypeInt 8" type for a series of LLVM integer types with number of
// bits less than 8. This would lead to duplicate type definitions
// eventually due to the method that DuplicateTracker utilizes to reason
// about uniqueness of type records.
//
// Returns Ty unchanged unless it is an integer type wider than one bit whose
// width differs from adjustOpTypeIntWidth(width); in that case the integer
// type of the adjusted width (same LLVMContext) is returned instead.
const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
  if (auto IType = dyn_cast<IntegerType>(Ty)) {
    unsigned SrcBitWidth = IType->getBitWidth();
    // 1-bit integers are left alone; createSPIRVType() maps them to
    // OpTypeBool.
    if (SrcBitWidth > 1) {
      unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
      // Maybe change source LLVM type to keep DuplicateTracker consistent.
      if (SrcBitWidth != BitWidth)
        Ty = IntegerType::get(Ty->getContext(), BitWidth);
    }
  }
  return Ty;
}

// Builds the SPIR-V type instruction that corresponds to the LLVM type Ty.
//
// Special opaque types are delegated to getOrCreateSpecialType(). A type that
// DT already tracks for the machine function of MIRBuilder is reused.
// Otherwise the type is dispatched on its LLVM kind, in this order: integer
// (1 bit wide becomes OpTypeBool), floating point, void, vector, array,
// struct (opaque or not), function, and finally pointer. Any other type is
// reported as a fatal error.
//
// For pointers whose element type could not be created (null), an
// OpTypeForwardPointer is created, remembered in ForwardPointerTypes and
// returned instead of a normal pointer type.
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
  if (isSpecialOpaqueType(Ty))
    return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
  // Reuse the type if DT already has a register for (Ty, this function).
  auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
  auto t = TypeToSPIRVTypeMap.find(Ty);
  if (t != TypeToSPIRVTypeMap.end()) {
    auto tt = t->second.find(&MIRBuilder.getMF());
    if (tt != t->second.end())
      return getSPIRVTypeForVReg(tt->second);
  }

  if (auto IType = dyn_cast<IntegerType>(Ty)) {
    const unsigned Width = IType->getBitWidth();
    return Width == 1 ? getOpTypeBool(MIRBuilder)
                      : getOpTypeInt(Width, MIRBuilder, false);
  }
  if (Ty->isFloatingPointTy())
    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
  if (Ty->isVoidTy())
    return getOpTypeVoid(MIRBuilder);
  if (Ty->isVectorTy()) {
    // NOTE(review): cast<FixedVectorType> assumes the vector is not scalable
    // -- confirm scalable vectors cannot reach this point.
    SPIRVType *El =
        findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
    return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
                           MIRBuilder);
  }
  if (Ty->isArrayTy()) {
    SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
    return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
  }
  if (auto SType = dyn_cast<StructType>(Ty)) {
    if (SType->isOpaque())
      return getOpTypeOpaque(SType, MIRBuilder);
    return getOpTypeStruct(SType, MIRBuilder, EmitIR);
  }
  if (auto FType = dyn_cast<FunctionType>(Ty)) {
    SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
    SmallVector<SPIRVType *, 4> ParamTypes;
    for (const auto &t : FType->params()) {
      ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
    }
    return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
  }
  // Only pointer types are left. The initial value is a placeholder: both
  // pointer branches below overwrite it and everything else is a fatal error.
  unsigned AddrSpace = 0xFFFF;
  if (auto PType = dyn_cast<TypedPointerType>(Ty))
    AddrSpace = PType->getAddressSpace();
  else if (auto PType = dyn_cast<PointerType>(Ty))
    AddrSpace = PType->getAddressSpace();
  else
    report_fatal_error("Unable to convert LLVM type to SPIRVType", true);

  // Typed pointers use their own element type; any other pointer is given an
  // 8-bit integer as its element type.
  SPIRVType *SpvElementType = nullptr;
  if (auto PType = dyn_cast<TypedPointerType>(Ty))
    SpvElementType = getOrCreateSPIRVType(PType->getElementType(), MIRBuilder,
                                          AccQual, EmitIR);
  else
    SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);

  // Get access to information about available extensions
  const SPIRVSubtarget *ST =
      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
  auto SC = addressSpaceToStorageClass(AddrSpace, *ST);
  // Null pointer means we have a loop in type definitions, make and
  // return corresponding OpTypeForwardPointer.
  if (SpvElementType == nullptr) {
    if (!ForwardPointerTypes.contains(Ty))
      ForwardPointerTypes[Ty] = getOpTypeForwardPointer(SC, MIRBuilder);
    return ForwardPointerTypes[Ty];
  }
  // If we have forward pointer associated with this type, use its register
  // operand to create OpTypePointer.
  if (ForwardPointerTypes.contains(Ty)) {
    Register Reg = getSPIRVTypeID(ForwardPointerTypes[Ty]);
    return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
  }

  return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
}

// Creates the SPIR-V type for Ty through createSPIRVType() and records it in
// VRegToTypeMap and SPIRVToLLVMType, and in DT when the conditions below
// hold.
//
// Returns nullptr without creating anything when Ty is a non-pointer type
// that is already present in TypesInProcessing, i.e. its creation has been
// started further up the call chain.
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  if (TypesInProcessing.count(Ty) && !isPointerTy(Ty))
    return nullptr;
  // Mark Ty as being created for the duration of createSPIRVType() only.
  TypesInProcessing.insert(Ty);
  SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  TypesInProcessing.erase(Ty);
  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
  SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
  Register Reg = DT.find(Ty, &MIRBuilder.getMF());
  // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
  // will be added later. For special types it is already added to DT.
  if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
      !isSpecialOpaqueType(Ty)) {
    // Pointer types are keyed in DT by (element type, address space); a
    // pointer that is not a typed pointer uses i8 as the element type.
    if (!isPointerTy(Ty))
      DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
    else if (isTypedPointerTy(Ty))
      DT.add(cast<TypedPointerType>(Ty)->getElementType(),
             getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
             getSPIRVTypeID(SpirvType));
    else
      DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
             getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
             getSPIRVTypeID(SpirvType));
  }

  return SpirvType;
}

// Looks up the SPIR-V type recorded in VRegToTypeMap for VReg. The lookup is
// done in MF, or in CurMF when MF is null. Returns nullptr when no type is
// recorded.
SPIRVType *
SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
                                         const MachineFunction *MF) const {
  auto t = VRegToTypeMap.find(MF ? MF : CurMF);
  if (t != VRegToTypeMap.end()) {
    auto tt = t->second.find(VReg);
    if (tt != t->second.end())
      return tt->second;
  }
  return nullptr;
}

// Returns the SPIR-V type for the LLVM type Ty, creating it when DT has no
// entry for it in the machine function of MIRBuilder.
//
// Non-pointer types are looked up after integer width adjustment. Pointer
// types are looked up by (element type, address space), where a pointer that
// is not a typed pointer uses i8 as element type. Special opaque types never
// take the early return, even when DT has an entry.
//
// After the type has been created, every forward pointer collected in
// ForwardPointerTypes is turned into a normal pointer type, and both
// TypesInProcessing (before) and ForwardPointerTypes (after) are reset.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  Register Reg;
  if (!isPointerTy(Ty)) {
    Ty = adjustIntTypeByWidth(Ty);
    Reg = DT.find(Ty, &MIRBuilder.getMF());
  } else if (isTypedPointerTy(Ty)) {
    Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
                  getPointerAddressSpace(Ty), &MIRBuilder.getMF());
  } else {
    Reg =
        DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
                getPointerAddressSpace(Ty), &MIRBuilder.getMF());
  }

  if (Reg.isValid() && !isSpecialOpaqueType(Ty))
    return getSPIRVTypeForVReg(Reg);
  TypesInProcessing.clear();
  SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  // Create normal pointer types for the corresponding OpTypeForwardPointers.
  // NOTE(review): restOfCreateSPIRVType() reaches createSPIRVType(), which can
  // insert into ForwardPointerTypes while this loop iterates over it --
  // confirm that this cannot invalidate the iteration.
  for (auto &CU : ForwardPointerTypes) {
    const Type *Ty2 = CU.first;
    SPIRVType *STy2 = CU.second;
    if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
      STy2 = getSPIRVTypeForVReg(Reg);
    else
      STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
    // If the requested type itself was a forward pointer, hand back the
    // type found or created for it here instead.
    if (Ty == Ty2)
      STy = STy2;
  }
  ForwardPointerTypes.clear();
  return STy;
}

// Returns true if the SPIR-V type assigned to VReg has opcode TypeOpcode.
// Asserts that VReg has a type assigned.
bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
                                         unsigned TypeOpcode) const {
  SPIRVType *Type = getSPIRVTypeForVReg(VReg);
  assert(Type && "isScalarOfType VReg has no type assigned");
  return Type->getOpcode() == TypeOpcode;
}

// Returns true if the SPIR-V type assigned to VReg either has opcode
// TypeOpcode itself, or is an OpTypeVector whose element type (operand 1) has
// that opcode. Asserts that VReg has a type assigned.
bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
                                                 unsigned TypeOpcode) const {
  SPIRVType *Type = getSPIRVTypeForVReg(VReg);
  assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
  if (Type->getOpcode() == TypeOpcode)
    return true;
  if (Type->getOpcode() == SPIRV::OpTypeVector) {
    Register ScalarTypeVReg = Type->getOperand(1).getReg();
    // NOTE(review): ScalarType is dereferenced without a null check --
    // confirm the element type register always has a recorded type.
    SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
    return ScalarType->getOpcode() == TypeOpcode;
  }
  return false;
}

// Convenience overload: component count of the type assigned to VReg.
unsigned
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
  return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg));
}

// Returns 0 for a null Type, the immediate in operand 2 for an OpTypeVector,
// and 1 for every other type.
unsigned
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
  if (!Type)
    return 0;
  return Type->getOpcode() == SPIRV::OpTypeVector
             ? static_cast<unsigned>(Type->getOperand(2).getImm())
             : 1;
}

// Returns the bit width of Type, or of its element type when Type is an
// OpTypeVector: the immediate in operand 1 for OpTypeInt/OpTypeFloat, and 1
// for OpTypeBool. Any other type is unreachable. Type must not be null.
unsigned
SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
  assert(Type && "Invalid Type pointer");
  if (Type->getOpcode() == SPIRV::OpTypeVector) {
    auto EleTypeReg = Type->getOperand(1).getReg();
    // NOTE(review): the lookup result is used below without a null check.
    Type = getSPIRVTypeForVReg(EleTypeReg);
  }
  if (Type->getOpcode() == SPIRV::OpTypeInt ||
      Type->getOpcode() == SPIRV::OpTypeFloat)
    return Type->getOperand(1).getImm();
  if (Type->getOpcode() == SPIRV::OpTypeBool)
    return 1;
  llvm_unreachable("Attempting to get bit width of non-integer/float type.");
}

// Returns the total number of bits of an integer/float scalar, or of a vector
// of such scalars (element count times element width). Returns 0 when the
// scalar/element type is neither OpTypeInt nor OpTypeFloat. Type must not be
// null.
unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
    const SPIRVType *Type) const {
  assert(Type && "Invalid Type pointer");
  unsigned NumElements = 1;
  if (Type->getOpcode() == SPIRV::OpTypeVector) {
    NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
    // NOTE(review): the lookup result is used below without a null check.
    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
  }
  return Type->getOpcode() == SPIRV::OpTypeInt ||
                 Type->getOpcode() == SPIRV::OpTypeFloat
             ? NumElements * Type->getOperand(1).getImm()
             : 0;
}

// Returns the OpTypeInt behind Type: Type itself if it is an OpTypeInt, or the
// element type if Type is an OpTypeVector of OpTypeInt. Returns nullptr in all
// other cases, including a null Type.
const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
    const SPIRVType *Type) const {
  if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
  return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
}

// Returns true if Type is (a vector of) OpTypeInt whose operand 2 immediate
// is non-zero; false otherwise, including for non-integer types.
bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
  const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
  return IntType && IntType->getOperand(2).getImm() != 0;
}

// Returns the type recorded for operand 2 of an OpTypePointer, or nullptr when
// PtrType is null or not an OpTypePointer.
SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
  return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
             ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
             : nullptr;
}

// Returns the opcode of the pointee type of the pointer type assigned to
// PtrReg, or 0 when no pointee type can be determined.
unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
  SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg));
  return ElemType ? ElemType->getOpcode() : 0;
}

// Returns true if a bitcast between Type1 and Type2 is considered valid:
// either one side is a pointer and the other is a pointer or an integer
// scalar/vector, or both sides have the same non-zero total bit width as
// computed by getNumScalarOrVectorTotalBitWidth(). Null types are never
// compatible.
bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
                                              const SPIRVType *Type2) const {
  if (!Type1 || !Type2)
    return false;
  auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
  // Ignore difference between <1.5 and >=1.5 protocol versions:
  // it's valid if either Result Type or Operand is a pointer, and the other
  // is a pointer, an integer scalar, or an integer vector.
  if (Op1 == SPIRV::OpTypePointer &&
      (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
    return true;
  if (Op2 == SPIRV::OpTypePointer &&
      (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
    return true;
  unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
           Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
  return Bits1 > 0 && Bits1 == Bits2;
}

// Returns the storage class stored as immediate operand 1 of the
// OpTypePointer assigned to VReg. Asserts that such a pointer type exists.
SPIRV::StorageClass::StorageClass
SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
  SPIRVType *Type = getSPIRVTypeForVReg(VReg);
  assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
         Type->getOperand(1).isImm() && "Pointer type is expected");
  return static_cast<SPIRV::StorageClass::StorageClass>(
      Type->getOperand(1).getImm());
}

// Returns the OpTypeImage described by the given parameters. An instruction
// already tracked in DT for the same descriptor is reused; otherwise a new one
// is emitted and its result register is added to DT.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
    MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
    uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
    SPIRV::ImageFormat::ImageFormat ImageFormat,
    SPIRV::AccessQualifier::AccessQualifier AccessQual) {
  auto TD = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim,
                                    Depth, Arrayed, Multisampled, Sampled,
                                    ImageFormat, AccessQual);
  if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
    return Res;
  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  return MIRBuilder.buildInstr(SPIRV::OpTypeImage)
      .addDef(ResVReg)
      .addUse(getSPIRVTypeID(SampledType))
      .addImm(Dim)
      .addImm(Depth)        // Depth (whether or not it is a Depth image).
      .addImm(Arrayed)      // Arrayed.
      .addImm(Multisampled) // Multisampled (0 = only single-sample).
      .addImm(Sampled)      // Sampled (0 = usage known at runtime).
      .addImm(ImageFormat)
      .addImm(AccessQual);
}

// Returns the OpTypeSampler for the current machine function, reusing the
// instruction tracked in DT or emitting and registering a new one.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
  auto TD = SPIRV::make_descr_sampler();
  if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
    return Res;
  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
}

// Returns the OpTypePipe with access qualifier AccessQual, reusing the
// instruction tracked in DT or emitting and registering a new one.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
    MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual) {
  auto TD = SPIRV::make_descr_pipe(AccessQual);
  if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
    return Res;
  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
      .addDef(ResVReg)
      .addImm(AccessQual);
}

// Returns the OpTypeDeviceEvent for the current machine function, reusing the
// instruction tracked in DT or emitting and registering a new one.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
    MachineIRBuilder &MIRBuilder) {
  auto TD = SPIRV::make_descr_event();
  if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
    return Res;
  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg);
}

// Returns the OpTypeSampledImage wrapping ImageType. The descriptor is built
// from ImageType and the LLVM type recorded for the instruction defining
// operand 1 of ImageType. An instruction already tracked in DT is reused.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
    SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
  auto TD = SPIRV::make_descr_sampled_image(
      SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
          ImageType->getOperand(1).getReg())),
      ImageType);
  if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
    return Res;
  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
      .addDef(ResVReg)
      .addUse(getSPIRVTypeID(ImageType));
}

// Returns the OpTypeCooperativeMatrixKHR registered in DT for ExtensionType,
// or emits a new one. Scope, Rows, Columns and Use are passed to the
// instruction as registers produced by buildConstantInt().
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
    MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
    const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
    uint32_t Use) {
  Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF());
  if (ResVReg.isValid())
    return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
  ResVReg = createTypeVReg(MIRBuilder);
  SPIRVType *SpirvTy =
      MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
          .addDef(ResVReg)
          .addUse(getSPIRVTypeID(ElemType))
          .addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true))
          .addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true))
          .addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true))
          .addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));
  DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
  return SpirvTy;
}

// Returns the type instruction registered in DT for Ty, or emits an
// instruction with the given Opcode that has only a result register and
// registers it in DT.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
    const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
  Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
  if (ResVReg.isValid())
    return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
  ResVReg = createTypeVReg(MIRBuilder);
  SPIRVType *SpirvTy = MIRBuilder.buildInstr(Opcode).addDef(ResVReg);
  DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
  return SpirvTy;
}

// Returns the unique defining instruction of the register that DT tracks for
// the special type descriptor TD in the machine function of MIRBuilder, or
// nullptr when DT has no such entry.
const MachineInstr *
SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
                                       MachineIRBuilder &MIRBuilder) {
  Register Reg = DT.find(TD, &MIRBuilder.getMF());
  if (Reg.isValid())
    return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
  return nullptr;
}

// Builds a SPIR-V type from its textual name TypeStr. Pointer types created
// on the way use storage class SC; AQ is forwarded to getOrCreateSPIRVType().
// Returns nullptr if unable to recognize SPIRV type name
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
    StringRef TypeStr, MachineIRBuilder &MIRBuilder,
    SPIRV::StorageClass::StorageClass SC,
    SPIRV::AccessQualifier::AccessQualifier AQ) {
  unsigned VecElts = 0;
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  // Parse strings representing either a SPIR-V or OpenCL builtin type.
  if (hasBuiltinTypePrefix(TypeStr))
    return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType(
                                    TypeStr.str(), MIRBuilder.getContext()),
                                MIRBuilder, AQ);

  // Parse type name in either "typeN" or "type vector[N]" format, where
  // N is the number of elements of the vector.
  Type *Ty;

  // NOTE(review): the checks on TypeStr below suggest that
  // parseBasicTypeName() consumes the base type name from TypeStr -- confirm
  // against its definition.
  Ty = parseBasicTypeName(TypeStr, Ctx);
  if (!Ty)
    // Unable to recognize SPIRV type name
    return nullptr;

  auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);

  // Handle "type*" or "type* vector[N]".
  if (TypeStr.starts_with("*")) {
    SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
    TypeStr = TypeStr.substr(strlen("*"));
  }

  // Handle "typeN*" or "type vector[N]*".
  bool IsPtrToVec = TypeStr.consume_back("*");

  if (TypeStr.consume_front(" vector[")) {
    TypeStr = TypeStr.substr(0, TypeStr.find(']'));
  }
  // The result of getAsInteger() is not checked; VecElts was initialised to 0
  // and only a value greater than 0 creates a vector type below.
  TypeStr.getAsInteger(10, VecElts);
  if (VecElts > 0)
    SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);

  if (IsPtrToVec)
    SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);

  return SpirvTy;
}

// Returns the SPIR-V type for the LLVM integer type of width BitWidth, created
// through getOrCreateSPIRVType() at the insertion point of MIRBuilder.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
                                                 MachineIRBuilder &MIRBuilder) {
  return getOrCreateSPIRVType(
      IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
      MIRBuilder);
}

// Records a freshly built type instruction: maps its id register to the
// instruction in VRegToTypeMap for CurMF and maps the instruction to
// unifyPtrType(LLVMTy) in SPIRVToLLVMType. Asserts that SpirvType belongs to
// CurMF. Returns SpirvType.
SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
                                                        SPIRVType *SpirvType) {
  assert(CurMF == SpirvType->getMF());
  VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
  SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
  return SpirvType;
}

// Returns the type registered in DT for LLVMTy in CurMF. If there is none, an
// instruction with opcode SPIRVOPcode is inserted before I with a fresh result
// register and the immediates BitWidth and 0, then registered in DT and
// recorded through finishCreatingSPIRVType().
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
                                                     MachineInstr &I,
                                                     const SPIRVInstrInfo &TII,
                                                     unsigned SPIRVOPcode,
                                                     Type *LLVMTy) {
  Register Reg = DT.find(LLVMTy, CurMF);
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  MachineBasicBlock &BB = *I.getParent();
  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRVOPcode))
                 .addDef(createTypeVReg(CurMF->getRegInfo()))
                 .addImm(BitWidth)
                 .addImm(0);
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}

// Returns the OpTypeInt for BitWidth (after width adjustment), inserting it
// before I if it does not exist yet in CurMF.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
  // Maybe adjust bit width to keep DuplicateTracker consistent. Without
  // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
  // example, the same "OpTypeInt 8" type for a series of LLVM integer types
  // with number of bits less than 8, causing duplicate type definitions.
  BitWidth = adjustOpTypeIntWidth(BitWidth);
  Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
}

// Returns the OpTypeFloat for BitWidth, inserting it before I if it does not
// exist yet in CurMF. Only widths 16 (half), 32 (float) and 64 (double) are
// handled; any other width is unreachable.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
  LLVMContext &Ctx = CurMF->getFunction().getContext();
  Type *LLVMTy;
  switch (BitWidth) {
  case 16:
    LLVMTy = Type::getHalfTy(Ctx);
    break;
  case 32:
    LLVMTy = Type::getFloatTy(Ctx);
    break;
  case 64:
    LLVMTy = Type::getDoubleTy(Ctx);
    break;
  default:
    llvm_unreachable("Bit width is of unexpected size.");
  }
  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
}

// Returns the SPIR-V type for the 1-bit LLVM integer type, created through
// getOrCreateSPIRVType() at the insertion point of MIRBuilder.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
  return getOrCreateSPIRVType(
      IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
      MIRBuilder);
}

// Returns the OpTypeBool registered in DT for the 1-bit LLVM integer type in
// CurMF, inserting a new OpTypeBool before I if there is none.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
                                              const SPIRVInstrInfo &TII) {
  Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
  Register Reg = DT.find(LLVMTy, CurMF);
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  MachineBasicBlock &BB = *I.getParent();
  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
                 .addDef(createTypeVReg(CurMF->getRegInfo()));
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}

// Returns the SPIR-V type for a fixed vector of NumElements elements whose
// element type is the LLVM type recorded for BaseType, created through
// getOrCreateSPIRVType() at the insertion point of MIRBuilder.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
    SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
  return getOrCreateSPIRVType(
      FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
                           NumElements),
      MIRBuilder);
}

// Returns the OpTypeVector of NumElements x BaseType registered in DT for
// CurMF, inserting a new one before I if there is none. The new instruction
// has the element type id as operand 1 and NumElements as operand 2.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
    SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
    const SPIRVInstrInfo &TII) {
  Type *LLVMTy = FixedVectorType::get(
      const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
  Register Reg = DT.find(LLVMTy, CurMF);
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  MachineBasicBlock &BB = *I.getParent();
  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
                 .addDef(createTypeVReg(CurMF->getRegInfo()))
                 .addUse(getSPIRVTypeID(BaseType))
                 .addImm(NumElements);
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}

// Returns the OpTypeArray of NumElements x BaseType registered in DT for
// CurMF, inserting a new one before I if there is none. The array length is
// passed as a register holding a 32-bit integer constant.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
    SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
    const SPIRVInstrInfo &TII) {
  Type *LLVMTy = ArrayType::get(
      const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
  Register Reg = DT.find(LLVMTy, CurMF);
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  MachineBasicBlock &BB = *I.getParent();
  SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
  Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
                 .addDef(createTypeVReg(CurMF->getRegInfo()))
                 .addUse(getSPIRVTypeID(BaseType))
                 .addUse(Len);
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}

// Returns the OpTypePointer to BaseType in storage class SC. DT is keyed by
// (LLVM type of BaseType, address space derived from SC) in CurMF. A new
// instruction is inserted at the insertion point of MIRBuilder, with the
// storage class as operand 1 and the pointee type id as operand 2.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
    SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
    SPIRV::StorageClass::StorageClass SC) {
  const Type *PointerElementType = getTypeForSPIRVType(BaseType);
  unsigned AddressSpace = storageClassToAddressSpace(SC);
  Type *LLVMTy = TypedPointerType::get(const_cast<Type *>(PointerElementType),
                                       AddressSpace);
  // check if this type is already available
  Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  // create a new type
  auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
                     MIRBuilder.getDebugLoc(),
                     MIRBuilder.getTII().get(SPIRV::OpTypePointer))
                 .addDef(createTypeVReg(CurMF->getRegInfo()))
                 .addImm(static_cast<uint32_t>(SC))
                 .addUse(getSPIRVTypeID(BaseType));
  DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}

// Convenience overload: builds a MachineIRBuilder from I and forwards to the
// MachineIRBuilder overload. The SPIRVInstrInfo argument is unused.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
    SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
    SPIRV::StorageClass::StorageClass SC) {
  MachineIRBuilder MIRBuilder(I);
  return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
}

// Returns a register holding an undef value of type SpvType. The register
// tracked in DT for the corresponding UndefValue in CurMF is reused;
// otherwise a new register is created, typed, registered in DT and defined by
// an OpUndef inserted before I. SpvType must be non-null and must have an
// associated LLVM type.
Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
                                               SPIRVType *SpvType,
                                               const SPIRVInstrInfo &TII) {
  assert(SpvType);
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy);
  // Find a constant in DT or build a new one.
  UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
  Register Res = DT.find(UV, CurMF);
  if (Res.isValid())
    return Res;
  // The result register is created as a 32-bit scalar generic register and
  // then placed in the ID register class.
  LLT LLTy = LLT::scalar(32);
  Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
  CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
  assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
  DT.add(UV, CurMF, Res);

  MachineInstrBuilder MIB;
  MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
            .addDef(Res)
            .addUse(getSPIRVTypeID(SpvType));
  const auto &ST = CurMF->getSubtarget();
  constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
                                   *ST.getRegisterInfo(), *ST.getRegBankInfo());
  return Res;
}