1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the implementation of the SPIRVGlobalRegistry class, 10 // which is used to maintain rich type information required for SPIR-V even 11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into 12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds 13 // and supports consistency of constants and global variables. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "SPIRVGlobalRegistry.h" 18 #include "SPIRV.h" 19 #include "SPIRVBuiltins.h" 20 #include "SPIRVSubtarget.h" 21 #include "SPIRVTargetMachine.h" 22 #include "SPIRVUtils.h" 23 24 using namespace llvm; 25 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) 26 : PointerSize(PointerSize) {} 27 28 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth, 29 Register VReg, 30 MachineInstr &I, 31 const SPIRVInstrInfo &TII) { 32 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 33 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 34 return SpirvType; 35 } 36 37 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg( 38 SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I, 39 const SPIRVInstrInfo &TII) { 40 SPIRVType *SpirvType = 41 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII); 42 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 43 return SpirvType; 44 } 45 46 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( 47 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, 48 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 49 50 SPIRVType *SpirvType = 51 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); 52 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF()); 53 return SpirvType; 54 } 55 56 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType, 57 Register VReg, 58 MachineFunction &MF) { 59 VRegToTypeMap[&MF][VReg] = SpirvType; 60 } 61 62 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) { 63 auto &MRI = MIRBuilder.getMF().getRegInfo(); 64 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 65 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 66 return Res; 67 } 68 69 static Register createTypeVReg(MachineRegisterInfo &MRI) { 70 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 71 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 72 return Res; 73 } 74 75 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { 76 return MIRBuilder.buildInstr(SPIRV::OpTypeBool) 77 .addDef(createTypeVReg(MIRBuilder)); 78 } 79 80 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width, 81 MachineIRBuilder &MIRBuilder, 82 bool IsSigned) { 83 assert(Width <= 64 && "Unsupported integer width!"); 84 if (Width <= 8) 85 Width = 8; 86 else if (Width <= 16) 87 Width = 16; 88 else if (Width <= 32) 89 Width = 32; 90 else if (Width <= 64) 91 Width = 64; 92 93 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt) 94 .addDef(createTypeVReg(MIRBuilder)) 95 .addImm(Width) 96 .addImm(IsSigned ? 1 : 0); 97 return MIB; 98 } 99 100 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, 101 MachineIRBuilder &MIRBuilder) { 102 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat) 103 .addDef(createTypeVReg(MIRBuilder)) 104 .addImm(Width); 105 return MIB; 106 } 107 108 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { 109 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) 110 .addDef(createTypeVReg(MIRBuilder)); 111 } 112 113 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, 114 SPIRVType *ElemType, 115 MachineIRBuilder &MIRBuilder) { 116 auto EleOpc = ElemType->getOpcode(); 117 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || 118 EleOpc == SPIRV::OpTypeBool) && 119 "Invalid vector element type"); 120 121 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector) 122 .addDef(createTypeVReg(MIRBuilder)) 123 .addUse(getSPIRVTypeID(ElemType)) 124 .addImm(NumElems); 125 return MIB; 126 } 127 128 std::tuple<Register, ConstantInt *, bool> 129 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType, 130 MachineIRBuilder *MIRBuilder, 131 MachineInstr *I, 132 const SPIRVInstrInfo *TII) { 133 const IntegerType *LLVMIntTy; 134 if (SpvType) 135 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 136 else 137 LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext()); 138 bool NewInstr = false; 139 // Find a constant in DT or build a new one. 140 ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 141 Register Res = DT.find(CI, CurMF); 142 if (!Res.isValid()) { 143 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 144 LLT LLTy = LLT::scalar(32); 145 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 146 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 147 if (MIRBuilder) 148 assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder); 149 else 150 assignIntTypeToVReg(BitWidth, Res, *I, *TII); 151 DT.add(CI, CurMF, Res); 152 NewInstr = true; 153 } 154 return std::make_tuple(Res, CI, NewInstr); 155 } 156 157 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, 158 SPIRVType *SpvType, 159 const SPIRVInstrInfo &TII) { 160 assert(SpvType); 161 ConstantInt *CI; 162 Register Res; 163 bool New; 164 std::tie(Res, CI, New) = 165 getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII); 166 // If we have found Res register which is defined by the passed G_CONSTANT 167 // machine instruction, a new constant instruction should be created. 168 if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg())) 169 return Res; 170 MachineInstrBuilder MIB; 171 MachineBasicBlock &BB = *I.getParent(); 172 if (Val) { 173 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI)) 174 .addDef(Res) 175 .addUse(getSPIRVTypeID(SpvType)); 176 addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB); 177 } else { 178 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) 179 .addDef(Res) 180 .addUse(getSPIRVTypeID(SpvType)); 181 } 182 const auto &ST = CurMF->getSubtarget(); 183 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(), 184 *ST.getRegisterInfo(), *ST.getRegBankInfo()); 185 return Res; 186 } 187 188 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val, 189 MachineIRBuilder &MIRBuilder, 190 SPIRVType *SpvType, 191 bool EmitIR) { 192 auto &MF = MIRBuilder.getMF(); 193 const IntegerType *LLVMIntTy; 194 if (SpvType) 195 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 196 else 197 LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext()); 198 // Find a constant in DT or build a new one. 199 const auto ConstInt = 200 ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 201 Register Res = DT.find(ConstInt, &MF); 202 if (!Res.isValid()) { 203 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 204 LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32); 205 Res = MF.getRegInfo().createGenericVirtualRegister(LLTy); 206 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 207 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder, 208 SPIRV::AccessQualifier::ReadWrite, EmitIR); 209 DT.add(ConstInt, &MIRBuilder.getMF(), Res); 210 if (EmitIR) { 211 MIRBuilder.buildConstant(Res, *ConstInt); 212 } else { 213 MachineInstrBuilder MIB; 214 if (Val) { 215 assert(SpvType); 216 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) 217 .addDef(Res) 218 .addUse(getSPIRVTypeID(SpvType)); 219 addNumImm(APInt(BitWidth, Val), MIB); 220 } else { 221 assert(SpvType); 222 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) 223 .addDef(Res) 224 .addUse(getSPIRVTypeID(SpvType)); 225 } 226 const auto &Subtarget = CurMF->getSubtarget(); 227 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 228 *Subtarget.getRegisterInfo(), 229 *Subtarget.getRegBankInfo()); 230 } 231 } 232 return Res; 233 } 234 235 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val, 236 MachineIRBuilder &MIRBuilder, 237 SPIRVType *SpvType) { 238 auto &MF = MIRBuilder.getMF(); 239 const Type *LLVMFPTy; 240 if (SpvType) { 241 LLVMFPTy = getTypeForSPIRVType(SpvType); 242 assert(LLVMFPTy->isFloatingPointTy()); 243 } else { 244 LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext()); 245 } 246 // Find a constant in DT or build a new one. 247 const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val); 248 Register Res = DT.find(ConstFP, &MF); 249 if (!Res.isValid()) { 250 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 251 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); 252 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 253 assignTypeToVReg(LLVMFPTy, Res, MIRBuilder); 254 DT.add(ConstFP, &MF, Res); 255 MIRBuilder.buildFConstant(Res, *ConstFP); 256 } 257 return Res; 258 } 259 260 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull( 261 uint64_t Val, MachineInstr &I, SPIRVType *SpvType, 262 const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth, 263 unsigned ElemCnt) { 264 // Find a constant vector in DT or build a new one. 265 Register Res = DT.find(CA, CurMF); 266 if (!Res.isValid()) { 267 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 268 // SpvScalConst should be created before SpvVecConst to avoid undefined ID 269 // error on validation. 270 // TODO: can moved below once sorting of types/consts/defs is implemented. 271 Register SpvScalConst; 272 if (Val) 273 SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII); 274 // TODO: maybe use bitwidth of base type. 275 LLT LLTy = LLT::scalar(32); 276 Register SpvVecConst = 277 CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 278 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass); 279 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); 280 DT.add(CA, CurMF, SpvVecConst); 281 MachineInstrBuilder MIB; 282 MachineBasicBlock &BB = *I.getParent(); 283 if (Val) { 284 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite)) 285 .addDef(SpvVecConst) 286 .addUse(getSPIRVTypeID(SpvType)); 287 for (unsigned i = 0; i < ElemCnt; ++i) 288 MIB.addUse(SpvScalConst); 289 } else { 290 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) 291 .addDef(SpvVecConst) 292 .addUse(getSPIRVTypeID(SpvType)); 293 } 294 const auto &Subtarget = CurMF->getSubtarget(); 295 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 296 *Subtarget.getRegisterInfo(), 297 *Subtarget.getRegBankInfo()); 298 return SpvVecConst; 299 } 300 return Res; 301 } 302 303 Register 304 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I, 305 SPIRVType *SpvType, 306 const SPIRVInstrInfo &TII) { 307 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 308 assert(LLVMTy->isVectorTy()); 309 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 310 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 311 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 312 auto ConstVec = 313 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt); 314 unsigned BW = getScalarOrVectorBitWidth(SpvType); 315 return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW, 316 SpvType->getOperand(2).getImm()); 317 } 318 319 Register 320 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I, 321 SPIRVType *SpvType, 322 const SPIRVInstrInfo &TII) { 323 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 324 assert(LLVMTy->isArrayTy()); 325 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy); 326 Type *LLVMBaseTy = LLVMArrTy->getElementType(); 327 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 328 auto ConstArr = 329 ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt}); 330 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); 331 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy); 332 return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW, 333 LLVMArrTy->getNumElements()); 334 } 335 336 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull( 337 uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, 338 Constant *CA, unsigned BitWidth, unsigned ElemCnt) { 339 Register Res = DT.find(CA, CurMF); 340 if (!Res.isValid()) { 341 Register SpvScalConst; 342 if (Val || EmitIR) { 343 SPIRVType *SpvBaseType = 344 getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder); 345 SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR); 346 } 347 LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32); 348 Register SpvVecConst = 349 CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 350 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass); 351 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); 352 DT.add(CA, CurMF, SpvVecConst); 353 if (EmitIR) { 354 MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst); 355 } else { 356 if (Val) { 357 auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite) 358 .addDef(SpvVecConst) 359 .addUse(getSPIRVTypeID(SpvType)); 360 for (unsigned i = 0; i < ElemCnt; ++i) 361 MIB.addUse(SpvScalConst); 362 } else { 363 MIRBuilder.buildInstr(SPIRV::OpConstantNull) 364 .addDef(SpvVecConst) 365 .addUse(getSPIRVTypeID(SpvType)); 366 } 367 } 368 return SpvVecConst; 369 } 370 return Res; 371 } 372 373 Register 374 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, 375 MachineIRBuilder &MIRBuilder, 376 SPIRVType *SpvType, bool EmitIR) { 377 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 378 assert(LLVMTy->isVectorTy()); 379 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 380 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 381 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 382 auto ConstVec = 383 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt); 384 unsigned BW = getScalarOrVectorBitWidth(SpvType); 385 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR, 386 ConstVec, BW, 387 SpvType->getOperand(2).getImm()); 388 } 389 390 Register 391 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, 392 MachineIRBuilder &MIRBuilder, 393 SPIRVType *SpvType, bool EmitIR) { 394 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 395 assert(LLVMTy->isArrayTy()); 396 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy); 397 Type *LLVMBaseTy = LLVMArrTy->getElementType(); 398 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 399 auto ConstArr = 400 ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt}); 401 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); 402 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy); 403 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR, 404 ConstArr, BW, 405 LLVMArrTy->getNumElements()); 406 } 407 408 Register 409 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder, 410 SPIRVType *SpvType) { 411 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 412 const PointerType *LLVMPtrTy = cast<PointerType>(LLVMTy); 413 // Find a constant in DT or build a new one. 414 Constant *CP = ConstantPointerNull::get(const_cast<PointerType *>(LLVMPtrTy)); 415 Register Res = DT.find(CP, CurMF); 416 if (!Res.isValid()) { 417 LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize); 418 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 419 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 420 assignSPIRVTypeToVReg(SpvType, Res, *CurMF); 421 MIRBuilder.buildInstr(SPIRV::OpConstantNull) 422 .addDef(Res) 423 .addUse(getSPIRVTypeID(SpvType)); 424 DT.add(CP, CurMF, Res); 425 } 426 return Res; 427 } 428 429 Register SPIRVGlobalRegistry::buildConstantSampler( 430 Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode, 431 MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) { 432 SPIRVType *SampTy; 433 if (SpvType) 434 SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder); 435 else 436 SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder); 437 438 auto Sampler = 439 ResReg.isValid() 440 ? ResReg 441 : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); 442 auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler) 443 .addDef(Sampler) 444 .addUse(getSPIRVTypeID(SampTy)) 445 .addImm(AddrMode) 446 .addImm(Param) 447 .addImm(FilerMode); 448 assert(Res->getOperand(0).isReg()); 449 return Res->getOperand(0).getReg(); 450 } 451 452 Register SPIRVGlobalRegistry::buildGlobalVariable( 453 Register ResVReg, SPIRVType *BaseType, StringRef Name, 454 const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage, 455 const MachineInstr *Init, bool IsConst, bool HasLinkageTy, 456 SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, 457 bool IsInstSelector) { 458 const GlobalVariable *GVar = nullptr; 459 if (GV) 460 GVar = cast<const GlobalVariable>(GV); 461 else { 462 // If GV is not passed explicitly, use the name to find or construct 463 // the global variable. 464 Module *M = MIRBuilder.getMF().getFunction().getParent(); 465 GVar = M->getGlobalVariable(Name); 466 if (GVar == nullptr) { 467 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type. 468 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false, 469 GlobalValue::ExternalLinkage, nullptr, 470 Twine(Name)); 471 } 472 GV = GVar; 473 } 474 Register Reg = DT.find(GVar, &MIRBuilder.getMF()); 475 if (Reg.isValid()) { 476 if (Reg != ResVReg) 477 MIRBuilder.buildCopy(ResVReg, Reg); 478 return ResVReg; 479 } 480 481 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable) 482 .addDef(ResVReg) 483 .addUse(getSPIRVTypeID(BaseType)) 484 .addImm(static_cast<uint32_t>(Storage)); 485 486 if (Init != 0) { 487 MIB.addUse(Init->getOperand(0).getReg()); 488 } 489 490 // ISel may introduce a new register on this step, so we need to add it to 491 // DT and correct its type avoiding fails on the next stage. 492 if (IsInstSelector) { 493 const auto &Subtarget = CurMF->getSubtarget(); 494 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 495 *Subtarget.getRegisterInfo(), 496 *Subtarget.getRegBankInfo()); 497 } 498 Reg = MIB->getOperand(0).getReg(); 499 DT.add(GVar, &MIRBuilder.getMF(), Reg); 500 501 // Set to Reg the same type as ResVReg has. 502 auto MRI = MIRBuilder.getMRI(); 503 assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected"); 504 if (Reg != ResVReg) { 505 LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32); 506 MRI->setType(Reg, RegLLTy); 507 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); 508 } 509 510 // If it's a global variable with name, output OpName for it. 511 if (GVar && GVar->hasName()) 512 buildOpName(Reg, GVar->getName(), MIRBuilder); 513 514 // Output decorations for the GV. 515 // TODO: maybe move to GenerateDecorations pass. 516 if (IsConst) 517 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {}); 518 519 if (GVar && GVar->getAlign().valueOrOne().value() != 1) { 520 unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value(); 521 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment}); 522 } 523 524 if (HasLinkageTy) 525 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 526 {static_cast<uint32_t>(LinkageType)}, Name); 527 528 SPIRV::BuiltIn::BuiltIn BuiltInId; 529 if (getSpirvBuiltInIdByName(Name, BuiltInId)) 530 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn, 531 {static_cast<uint32_t>(BuiltInId)}); 532 533 return Reg; 534 } 535 536 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, 537 SPIRVType *ElemType, 538 MachineIRBuilder &MIRBuilder, 539 bool EmitIR) { 540 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && 541 "Invalid array element type"); 542 Register NumElementsVReg = 543 buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR); 544 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray) 545 .addDef(createTypeVReg(MIRBuilder)) 546 .addUse(getSPIRVTypeID(ElemType)) 547 .addUse(NumElementsVReg); 548 return MIB; 549 } 550 551 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty, 552 MachineIRBuilder &MIRBuilder) { 553 assert(Ty->hasName()); 554 const StringRef Name = Ty->hasName() ? Ty->getName() : ""; 555 Register ResVReg = createTypeVReg(MIRBuilder); 556 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg); 557 addStringImm(Name, MIB); 558 buildOpName(ResVReg, Name, MIRBuilder); 559 return MIB; 560 } 561 562 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty, 563 MachineIRBuilder &MIRBuilder, 564 bool EmitIR) { 565 SmallVector<Register, 4> FieldTypes; 566 for (const auto &Elem : Ty->elements()) { 567 SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder); 568 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid && 569 "Invalid struct element type"); 570 FieldTypes.push_back(getSPIRVTypeID(ElemTy)); 571 } 572 Register ResVReg = createTypeVReg(MIRBuilder); 573 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg); 574 for (const auto &Ty : FieldTypes) 575 MIB.addUse(Ty); 576 if (Ty->hasName()) 577 buildOpName(ResVReg, Ty->getName(), MIRBuilder); 578 if (Ty->isPacked()) 579 buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {}); 580 return MIB; 581 } 582 583 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType( 584 const Type *Ty, MachineIRBuilder &MIRBuilder, 585 SPIRV::AccessQualifier::AccessQualifier AccQual) { 586 // Some OpenCL and SPIRV builtins like image2d_t are passed in as 587 // pointers, but should be treated as custom types like OpTypeImage. 588 if (auto PType = dyn_cast<PointerType>(Ty)) { 589 assert(!PType->isOpaque()); 590 Ty = PType->getNonOpaquePointerElementType(); 591 } 592 assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type"); 593 return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this); 594 } 595 596 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer( 597 SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType, 598 MachineIRBuilder &MIRBuilder, Register Reg) { 599 if (!Reg.isValid()) 600 Reg = createTypeVReg(MIRBuilder); 601 return MIRBuilder.buildInstr(SPIRV::OpTypePointer) 602 .addDef(Reg) 603 .addImm(static_cast<uint32_t>(SC)) 604 .addUse(getSPIRVTypeID(ElemType)); 605 } 606 607 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer( 608 SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) { 609 return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer) 610 .addUse(createTypeVReg(MIRBuilder)) 611 .addImm(static_cast<uint32_t>(SC)); 612 } 613 614 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( 615 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes, 616 MachineIRBuilder &MIRBuilder) { 617 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) 618 .addDef(createTypeVReg(MIRBuilder)) 619 .addUse(getSPIRVTypeID(RetType)); 620 for (const SPIRVType *ArgType : ArgTypes) 621 MIB.addUse(getSPIRVTypeID(ArgType)); 622 return MIB; 623 } 624 625 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs( 626 const Type *Ty, SPIRVType *RetType, 627 const SmallVectorImpl<SPIRVType *> &ArgTypes, 628 MachineIRBuilder &MIRBuilder) { 629 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 630 if (Reg.isValid()) 631 return getSPIRVTypeForVReg(Reg); 632 SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder); 633 return finishCreatingSPIRVType(Ty, SpirvType); 634 } 635 636 SPIRVType *SPIRVGlobalRegistry::findSPIRVType( 637 const Type *Ty, MachineIRBuilder &MIRBuilder, 638 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { 639 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 640 if (Reg.isValid()) 641 return getSPIRVTypeForVReg(Reg); 642 if (ForwardPointerTypes.find(Ty) != ForwardPointerTypes.end()) 643 return ForwardPointerTypes[Ty]; 644 return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR); 645 } 646 647 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const { 648 assert(SpirvType && "Attempting to get type id for nullptr type."); 649 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer) 650 return SpirvType->uses().begin()->getReg(); 651 return SpirvType->defs().begin()->getReg(); 652 } 653 654 SPIRVType *SPIRVGlobalRegistry::createSPIRVType( 655 const Type *Ty, MachineIRBuilder &MIRBuilder, 656 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { 657 if (isSpecialOpaqueType(Ty)) 658 return getOrCreateSpecialType(Ty, MIRBuilder, AccQual); 659 auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses(); 660 auto t = TypeToSPIRVTypeMap.find(Ty); 661 if (t != TypeToSPIRVTypeMap.end()) { 662 auto tt = t->second.find(&MIRBuilder.getMF()); 663 if (tt != t->second.end()) 664 return getSPIRVTypeForVReg(tt->second); 665 } 666 667 if (auto IType = dyn_cast<IntegerType>(Ty)) { 668 const unsigned Width = IType->getBitWidth(); 669 return Width == 1 ? getOpTypeBool(MIRBuilder) 670 : getOpTypeInt(Width, MIRBuilder, false); 671 } 672 if (Ty->isFloatingPointTy()) 673 return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); 674 if (Ty->isVoidTy()) 675 return getOpTypeVoid(MIRBuilder); 676 if (Ty->isVectorTy()) { 677 SPIRVType *El = 678 findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder); 679 return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El, 680 MIRBuilder); 681 } 682 if (Ty->isArrayTy()) { 683 SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder); 684 return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); 685 } 686 if (auto SType = dyn_cast<StructType>(Ty)) { 687 if (SType->isOpaque()) 688 return getOpTypeOpaque(SType, MIRBuilder); 689 return getOpTypeStruct(SType, MIRBuilder, EmitIR); 690 } 691 if (auto FType = dyn_cast<FunctionType>(Ty)) { 692 SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder); 693 SmallVector<SPIRVType *, 4> ParamTypes; 694 for (const auto &t : FType->params()) { 695 ParamTypes.push_back(findSPIRVType(t, MIRBuilder)); 696 } 697 return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); 698 } 699 if (auto PType = dyn_cast<PointerType>(Ty)) { 700 SPIRVType *SpvElementType; 701 // At the moment, all opaque pointers correspond to i8 element type. 702 // TODO: change the implementation once opaque pointers are supported 703 // in the SPIR-V specification. 704 if (PType->isOpaque()) 705 SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder); 706 else 707 SpvElementType = 708 findSPIRVType(PType->getNonOpaquePointerElementType(), MIRBuilder, 709 SPIRV::AccessQualifier::ReadWrite, EmitIR); 710 auto SC = addressSpaceToStorageClass(PType->getAddressSpace()); 711 // Null pointer means we have a loop in type definitions, make and 712 // return corresponding OpTypeForwardPointer. 713 if (SpvElementType == nullptr) { 714 if (ForwardPointerTypes.find(Ty) == ForwardPointerTypes.end()) 715 ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder); 716 return ForwardPointerTypes[PType]; 717 } 718 Register Reg(0); 719 // If we have forward pointer associated with this type, use its register 720 // operand to create OpTypePointer. 721 if (ForwardPointerTypes.find(PType) != ForwardPointerTypes.end()) 722 Reg = getSPIRVTypeID(ForwardPointerTypes[PType]); 723 724 return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg); 725 } 726 llvm_unreachable("Unable to convert LLVM type to SPIRVType"); 727 } 728 729 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( 730 const Type *Ty, MachineIRBuilder &MIRBuilder, 731 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 732 if (TypesInProcessing.count(Ty) && !Ty->isPointerTy()) 733 return nullptr; 734 TypesInProcessing.insert(Ty); 735 SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); 736 TypesInProcessing.erase(Ty); 737 VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; 738 SPIRVToLLVMType[SpirvType] = Ty; 739 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 740 // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type 741 // will be added later. For special types it is already added to DT. 742 if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() && 743 !isSpecialOpaqueType(Ty)) 744 DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType)); 745 746 return SpirvType; 747 } 748 749 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const { 750 auto t = VRegToTypeMap.find(CurMF); 751 if (t != VRegToTypeMap.end()) { 752 auto tt = t->second.find(VReg); 753 if (tt != t->second.end()) 754 return tt->second; 755 } 756 return nullptr; 757 } 758 759 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( 760 const Type *Ty, MachineIRBuilder &MIRBuilder, 761 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 762 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 763 if (Reg.isValid() && !isSpecialOpaqueType(Ty)) 764 return getSPIRVTypeForVReg(Reg); 765 TypesInProcessing.clear(); 766 SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); 767 // Create normal pointer types for the corresponding OpTypeForwardPointers. 768 for (auto &CU : ForwardPointerTypes) { 769 const Type *Ty2 = CU.first; 770 SPIRVType *STy2 = CU.second; 771 if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid()) 772 STy2 = getSPIRVTypeForVReg(Reg); 773 else 774 STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR); 775 if (Ty == Ty2) 776 STy = STy2; 777 } 778 ForwardPointerTypes.clear(); 779 return STy; 780 } 781 782 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, 783 unsigned TypeOpcode) const { 784 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 785 assert(Type && "isScalarOfType VReg has no type assigned"); 786 return Type->getOpcode() == TypeOpcode; 787 } 788 789 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, 790 unsigned TypeOpcode) const { 791 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 792 assert(Type && "isScalarOrVectorOfType VReg has no type assigned"); 793 if (Type->getOpcode() == TypeOpcode) 794 return true; 795 if (Type->getOpcode() == SPIRV::OpTypeVector) { 796 Register ScalarTypeVReg = Type->getOperand(1).getReg(); 797 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg); 798 return ScalarType->getOpcode() == TypeOpcode; 799 } 800 return false; 801 } 802 803 unsigned 804 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const { 805 assert(Type && "Invalid Type pointer"); 806 if (Type->getOpcode() == SPIRV::OpTypeVector) { 807 auto EleTypeReg = Type->getOperand(1).getReg(); 808 Type = getSPIRVTypeForVReg(EleTypeReg); 809 } 810 if (Type->getOpcode() == SPIRV::OpTypeInt || 811 Type->getOpcode() == SPIRV::OpTypeFloat) 812 return Type->getOperand(1).getImm(); 813 if (Type->getOpcode() == SPIRV::OpTypeBool) 814 return 1; 815 llvm_unreachable("Attempting to get bit width of non-integer/float type."); 816 } 817 818 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { 819 assert(Type && "Invalid Type pointer"); 820 if (Type->getOpcode() == SPIRV::OpTypeVector) { 821 auto EleTypeReg = Type->getOperand(1).getReg(); 822 Type = getSPIRVTypeForVReg(EleTypeReg); 823 } 824 if (Type->getOpcode() == SPIRV::OpTypeInt) 825 return Type->getOperand(2).getImm() != 0; 826 llvm_unreachable("Attempting to get sign of non-integer type."); 827 } 828 829 SPIRV::StorageClass::StorageClass 830 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const { 831 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 832 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer && 833 Type->getOperand(1).isImm() && "Pointer type is expected"); 834 return static_cast<SPIRV::StorageClass::StorageClass>( 835 Type->getOperand(1).getImm()); 836 } 837 838 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage( 839 MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim, 840 uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled, 841 SPIRV::ImageFormat::ImageFormat ImageFormat, 842 SPIRV::AccessQualifier::AccessQualifier AccessQual) { 843 SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth, 844 Arrayed, Multisampled, Sampled, ImageFormat, 845 AccessQual); 846 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 847 return Res; 848 Register ResVReg = createTypeVReg(MIRBuilder); 849 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 850 return MIRBuilder.buildInstr(SPIRV::OpTypeImage) 851 .addDef(ResVReg) 852 .addUse(getSPIRVTypeID(SampledType)) 853 .addImm(Dim) 854 .addImm(Depth) // Depth (whether or not it is a Depth image). 855 .addImm(Arrayed) // Arrayed. 856 .addImm(Multisampled) // Multisampled (0 = only single-sample). 857 .addImm(Sampled) // Sampled (0 = usage known at runtime). 858 .addImm(ImageFormat) 859 .addImm(AccessQual); 860 } 861 862 SPIRVType * 863 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) { 864 SPIRV::SamplerTypeDescriptor TD; 865 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 866 return Res; 867 Register ResVReg = createTypeVReg(MIRBuilder); 868 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 869 return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg); 870 } 871 872 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe( 873 MachineIRBuilder &MIRBuilder, 874 SPIRV::AccessQualifier::AccessQualifier AccessQual) { 875 SPIRV::PipeTypeDescriptor TD(AccessQual); 876 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 877 return Res; 878 Register ResVReg = createTypeVReg(MIRBuilder); 879 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 880 return MIRBuilder.buildInstr(SPIRV::OpTypePipe) 881 .addDef(ResVReg) 882 .addImm(AccessQual); 883 } 884 885 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent( 886 MachineIRBuilder &MIRBuilder) { 887 SPIRV::DeviceEventTypeDescriptor TD; 888 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 889 return Res; 890 Register ResVReg = createTypeVReg(MIRBuilder); 891 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 892 return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg); 893 } 894 895 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage( 896 SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) { 897 SPIRV::SampledImageTypeDescriptor TD( 898 SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef( 899 ImageType->getOperand(1).getReg())), 900 ImageType); 901 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 902 return Res; 903 Register ResVReg = createTypeVReg(MIRBuilder); 904 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 905 return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage) 906 .addDef(ResVReg) 907 .addUse(getSPIRVTypeID(ImageType)); 908 } 909 910 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode( 911 const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) { 912 Register ResVReg = DT.find(Ty, &MIRBuilder.getMF()); 913 if (ResVReg.isValid()) 914 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg); 915 ResVReg = createTypeVReg(MIRBuilder); 916 DT.add(Ty, &MIRBuilder.getMF(), ResVReg); 917 return MIRBuilder.buildInstr(Opcode).addDef(ResVReg); 918 } 919 920 const MachineInstr * 921 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD, 922 MachineIRBuilder &MIRBuilder) { 923 Register Reg = DT.find(TD, &MIRBuilder.getMF()); 924 if (Reg.isValid()) 925 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg); 926 return nullptr; 927 } 928 929 // TODO: maybe use tablegen to implement this. 930 SPIRVType * 931 SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(StringRef TypeStr, 932 MachineIRBuilder &MIRBuilder) { 933 unsigned VecElts = 0; 934 auto &Ctx = MIRBuilder.getMF().getFunction().getContext(); 935 936 // Parse type name in either "typeN" or "type vector[N]" format, where 937 // N is the number of elements of the vector. 938 Type *Type; 939 if (TypeStr.startswith("void")) { 940 Type = Type::getVoidTy(Ctx); 941 TypeStr = TypeStr.substr(strlen("void")); 942 } else if (TypeStr.startswith("int") || TypeStr.startswith("uint")) { 943 Type = Type::getInt32Ty(Ctx); 944 TypeStr = TypeStr.startswith("int") ? TypeStr.substr(strlen("int")) 945 : TypeStr.substr(strlen("uint")); 946 } else if (TypeStr.startswith("float")) { 947 Type = Type::getFloatTy(Ctx); 948 TypeStr = TypeStr.substr(strlen("float")); 949 } else if (TypeStr.startswith("half")) { 950 Type = Type::getHalfTy(Ctx); 951 TypeStr = TypeStr.substr(strlen("half")); 952 } else if (TypeStr.startswith("opencl.sampler_t")) { 953 Type = StructType::create(Ctx, "opencl.sampler_t"); 954 } else 955 llvm_unreachable("Unable to recognize SPIRV type name."); 956 if (TypeStr.startswith(" vector[")) { 957 TypeStr = TypeStr.substr(strlen(" vector[")); 958 TypeStr = TypeStr.substr(0, TypeStr.find(']')); 959 } 960 TypeStr.getAsInteger(10, VecElts); 961 auto SpirvTy = getOrCreateSPIRVType(Type, MIRBuilder); 962 if (VecElts > 0) 963 SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder); 964 return SpirvTy; 965 } 966 967 SPIRVType * 968 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, 969 MachineIRBuilder &MIRBuilder) { 970 return getOrCreateSPIRVType( 971 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), 972 MIRBuilder); 973 } 974 975 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy, 976 SPIRVType *SpirvType) { 977 assert(CurMF == SpirvType->getMF()); 978 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; 979 SPIRVToLLVMType[SpirvType] = LLVMTy; 980 DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType)); 981 return SpirvType; 982 } 983 984 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType( 985 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { 986 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth); 987 Register Reg = DT.find(LLVMTy, CurMF); 988 if (Reg.isValid()) 989 return getSPIRVTypeForVReg(Reg); 990 MachineBasicBlock &BB = *I.getParent(); 991 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt)) 992 .addDef(createTypeVReg(CurMF->getRegInfo())) 993 .addImm(BitWidth) 994 .addImm(0); 995 return finishCreatingSPIRVType(LLVMTy, MIB); 996 } 997 998 SPIRVType * 999 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) { 1000 return getOrCreateSPIRVType( 1001 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), 1002 MIRBuilder); 1003 } 1004 1005 SPIRVType * 1006 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I, 1007 const SPIRVInstrInfo &TII) { 1008 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1); 1009 Register Reg = DT.find(LLVMTy, CurMF); 1010 if (Reg.isValid()) 1011 return getSPIRVTypeForVReg(Reg); 1012 MachineBasicBlock &BB = *I.getParent(); 1013 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool)) 1014 .addDef(createTypeVReg(CurMF->getRegInfo())); 1015 return finishCreatingSPIRVType(LLVMTy, MIB); 1016 } 1017 1018 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 1019 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) { 1020 return getOrCreateSPIRVType( 1021 FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 1022 NumElements), 1023 MIRBuilder); 1024 } 1025 1026 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 1027 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, 1028 const SPIRVInstrInfo &TII) { 1029 Type *LLVMTy = FixedVectorType::get( 1030 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); 1031 Register Reg = DT.find(LLVMTy, CurMF); 1032 if (Reg.isValid()) 1033 return getSPIRVTypeForVReg(Reg); 1034 MachineBasicBlock &BB = *I.getParent(); 1035 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector)) 1036 .addDef(createTypeVReg(CurMF->getRegInfo())) 1037 .addUse(getSPIRVTypeID(BaseType)) 1038 .addImm(NumElements); 1039 return finishCreatingSPIRVType(LLVMTy, MIB); 1040 } 1041 1042 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType( 1043 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, 1044 const SPIRVInstrInfo &TII) { 1045 Type *LLVMTy = ArrayType::get( 1046 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); 1047 Register Reg = DT.find(LLVMTy, CurMF); 1048 if (Reg.isValid()) 1049 return getSPIRVTypeForVReg(Reg); 1050 MachineBasicBlock &BB = *I.getParent(); 1051 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII); 1052 Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII); 1053 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray)) 1054 .addDef(createTypeVReg(CurMF->getRegInfo())) 1055 .addUse(getSPIRVTypeID(BaseType)) 1056 .addUse(Len); 1057 return finishCreatingSPIRVType(LLVMTy, MIB); 1058 } 1059 1060 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( 1061 SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, 1062 SPIRV::StorageClass::StorageClass SClass) { 1063 return getOrCreateSPIRVType( 1064 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 1065 storageClassToAddressSpace(SClass)), 1066 MIRBuilder); 1067 } 1068 1069 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( 1070 SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, 1071 SPIRV::StorageClass::StorageClass SC) { 1072 Type *LLVMTy = 1073 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 1074 storageClassToAddressSpace(SC)); 1075 Register Reg = DT.find(LLVMTy, CurMF); 1076 if (Reg.isValid()) 1077 return getSPIRVTypeForVReg(Reg); 1078 MachineBasicBlock &BB = *I.getParent(); 1079 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer)) 1080 .addDef(createTypeVReg(CurMF->getRegInfo())) 1081 .addImm(static_cast<uint32_t>(SC)) 1082 .addUse(getSPIRVTypeID(BaseType)); 1083 return finishCreatingSPIRVType(LLVMTy, MIB); 1084 } 1085 1086 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I, 1087 SPIRVType *SpvType, 1088 const SPIRVInstrInfo &TII) { 1089 assert(SpvType); 1090 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 1091 assert(LLVMTy); 1092 // Find a constant in DT or build a new one. 1093 UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy)); 1094 Register Res = DT.find(UV, CurMF); 1095 if (Res.isValid()) 1096 return Res; 1097 LLT LLTy = LLT::scalar(32); 1098 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 1099 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 1100 assignSPIRVTypeToVReg(SpvType, Res, *CurMF); 1101 DT.add(UV, CurMF, Res); 1102 1103 MachineInstrBuilder MIB; 1104 MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef)) 1105 .addDef(Res) 1106 .addUse(getSPIRVTypeID(SpvType)); 1107 const auto &ST = CurMF->getSubtarget(); 1108 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(), 1109 *ST.getRegisterInfo(), *ST.getRegBankInfo()); 1110 return Res; 1111 } 1112