1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the implementation of the SPIRVGlobalRegistry class, 10 // which is used to maintain rich type information required for SPIR-V even 11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into 12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds 13 // and supports consistency of constants and global variables. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "SPIRVGlobalRegistry.h" 18 #include "SPIRV.h" 19 #include "SPIRVBuiltins.h" 20 #include "SPIRVSubtarget.h" 21 #include "SPIRVTargetMachine.h" 22 #include "SPIRVUtils.h" 23 24 using namespace llvm; 25 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) 26 : PointerSize(PointerSize) {} 27 28 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth, 29 Register VReg, 30 MachineInstr &I, 31 const SPIRVInstrInfo &TII) { 32 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 33 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 34 return SpirvType; 35 } 36 37 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg( 38 SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I, 39 const SPIRVInstrInfo &TII) { 40 SPIRVType *SpirvType = 41 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII); 42 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 43 return SpirvType; 44 } 45 46 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( 47 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, 48 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 49 50 SPIRVType *SpirvType = 51 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); 52 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF()); 53 return SpirvType; 54 } 55 56 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType, 57 Register VReg, 58 MachineFunction &MF) { 59 VRegToTypeMap[&MF][VReg] = SpirvType; 60 } 61 62 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) { 63 auto &MRI = MIRBuilder.getMF().getRegInfo(); 64 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 65 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 66 return Res; 67 } 68 69 static Register createTypeVReg(MachineRegisterInfo &MRI) { 70 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 71 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 72 return Res; 73 } 74 75 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { 76 return MIRBuilder.buildInstr(SPIRV::OpTypeBool) 77 .addDef(createTypeVReg(MIRBuilder)); 78 } 79 80 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width, 81 MachineIRBuilder &MIRBuilder, 82 bool IsSigned) { 83 assert(Width <= 64 && "Unsupported integer width!"); 84 const SPIRVSubtarget &ST = 85 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget()); 86 if (ST.canUseExtension( 87 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) { 88 MIRBuilder.buildInstr(SPIRV::OpExtension) 89 .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers); 90 MIRBuilder.buildInstr(SPIRV::OpCapability) 91 .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL); 92 } else if (Width <= 8) 93 Width = 8; 94 else if (Width <= 16) 95 Width = 16; 96 else if (Width <= 32) 97 Width = 32; 98 else if (Width <= 64) 99 Width = 64; 100 101 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt) 102 .addDef(createTypeVReg(MIRBuilder)) 103 .addImm(Width) 104 .addImm(IsSigned ? 1 : 0); 105 return MIB; 106 } 107 108 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, 109 MachineIRBuilder &MIRBuilder) { 110 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat) 111 .addDef(createTypeVReg(MIRBuilder)) 112 .addImm(Width); 113 return MIB; 114 } 115 116 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { 117 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) 118 .addDef(createTypeVReg(MIRBuilder)); 119 } 120 121 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, 122 SPIRVType *ElemType, 123 MachineIRBuilder &MIRBuilder) { 124 auto EleOpc = ElemType->getOpcode(); 125 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || 126 EleOpc == SPIRV::OpTypeBool) && 127 "Invalid vector element type"); 128 129 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector) 130 .addDef(createTypeVReg(MIRBuilder)) 131 .addUse(getSPIRVTypeID(ElemType)) 132 .addImm(NumElems); 133 return MIB; 134 } 135 136 std::tuple<Register, ConstantInt *, bool> 137 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType, 138 MachineIRBuilder *MIRBuilder, 139 MachineInstr *I, 140 const SPIRVInstrInfo *TII) { 141 const IntegerType *LLVMIntTy; 142 if (SpvType) 143 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 144 else 145 LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext()); 146 bool NewInstr = false; 147 // Find a constant in DT or build a new one. 148 ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 149 Register Res = DT.find(CI, CurMF); 150 if (!Res.isValid()) { 151 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 152 LLT LLTy = LLT::scalar(32); 153 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 154 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 155 if (MIRBuilder) 156 assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder); 157 else 158 assignIntTypeToVReg(BitWidth, Res, *I, *TII); 159 DT.add(CI, CurMF, Res); 160 NewInstr = true; 161 } 162 return std::make_tuple(Res, CI, NewInstr); 163 } 164 165 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, 166 SPIRVType *SpvType, 167 const SPIRVInstrInfo &TII) { 168 assert(SpvType); 169 ConstantInt *CI; 170 Register Res; 171 bool New; 172 std::tie(Res, CI, New) = 173 getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII); 174 // If we have found Res register which is defined by the passed G_CONSTANT 175 // machine instruction, a new constant instruction should be created. 176 if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg())) 177 return Res; 178 MachineInstrBuilder MIB; 179 MachineBasicBlock &BB = *I.getParent(); 180 if (Val) { 181 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI)) 182 .addDef(Res) 183 .addUse(getSPIRVTypeID(SpvType)); 184 addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB); 185 } else { 186 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) 187 .addDef(Res) 188 .addUse(getSPIRVTypeID(SpvType)); 189 } 190 const auto &ST = CurMF->getSubtarget(); 191 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(), 192 *ST.getRegisterInfo(), *ST.getRegBankInfo()); 193 return Res; 194 } 195 196 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val, 197 MachineIRBuilder &MIRBuilder, 198 SPIRVType *SpvType, 199 bool EmitIR) { 200 auto &MF = MIRBuilder.getMF(); 201 const IntegerType *LLVMIntTy; 202 if (SpvType) 203 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 204 else 205 LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext()); 206 // Find a constant in DT or build a new one. 207 const auto ConstInt = 208 ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 209 Register Res = DT.find(ConstInt, &MF); 210 if (!Res.isValid()) { 211 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 212 LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32); 213 Res = MF.getRegInfo().createGenericVirtualRegister(LLTy); 214 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 215 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder, 216 SPIRV::AccessQualifier::ReadWrite, EmitIR); 217 DT.add(ConstInt, &MIRBuilder.getMF(), Res); 218 if (EmitIR) { 219 MIRBuilder.buildConstant(Res, *ConstInt); 220 } else { 221 MachineInstrBuilder MIB; 222 if (Val) { 223 assert(SpvType); 224 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) 225 .addDef(Res) 226 .addUse(getSPIRVTypeID(SpvType)); 227 addNumImm(APInt(BitWidth, Val), MIB); 228 } else { 229 assert(SpvType); 230 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) 231 .addDef(Res) 232 .addUse(getSPIRVTypeID(SpvType)); 233 } 234 const auto &Subtarget = CurMF->getSubtarget(); 235 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 236 *Subtarget.getRegisterInfo(), 237 *Subtarget.getRegBankInfo()); 238 } 239 } 240 return Res; 241 } 242 243 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val, 244 MachineIRBuilder &MIRBuilder, 245 SPIRVType *SpvType) { 246 auto &MF = MIRBuilder.getMF(); 247 auto &Ctx = MF.getFunction().getContext(); 248 if (!SpvType) { 249 const Type *LLVMFPTy = Type::getFloatTy(Ctx); 250 SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder); 251 } 252 // Find a constant in DT or build a new one. 253 const auto ConstFP = ConstantFP::get(Ctx, Val); 254 Register Res = DT.find(ConstFP, &MF); 255 if (!Res.isValid()) { 256 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32)); 257 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 258 assignSPIRVTypeToVReg(SpvType, Res, MF); 259 DT.add(ConstFP, &MF, Res); 260 261 MachineInstrBuilder MIB; 262 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF) 263 .addDef(Res) 264 .addUse(getSPIRVTypeID(SpvType)); 265 addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB); 266 } 267 268 return Res; 269 } 270 271 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull( 272 uint64_t Val, MachineInstr &I, SPIRVType *SpvType, 273 const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth, 274 unsigned ElemCnt) { 275 // Find a constant vector in DT or build a new one. 276 Register Res = DT.find(CA, CurMF); 277 if (!Res.isValid()) { 278 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 279 // SpvScalConst should be created before SpvVecConst to avoid undefined ID 280 // error on validation. 281 // TODO: can moved below once sorting of types/consts/defs is implemented. 282 Register SpvScalConst; 283 if (Val) 284 SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII); 285 // TODO: maybe use bitwidth of base type. 286 LLT LLTy = LLT::scalar(32); 287 Register SpvVecConst = 288 CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 289 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass); 290 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); 291 DT.add(CA, CurMF, SpvVecConst); 292 MachineInstrBuilder MIB; 293 MachineBasicBlock &BB = *I.getParent(); 294 if (Val) { 295 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite)) 296 .addDef(SpvVecConst) 297 .addUse(getSPIRVTypeID(SpvType)); 298 for (unsigned i = 0; i < ElemCnt; ++i) 299 MIB.addUse(SpvScalConst); 300 } else { 301 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) 302 .addDef(SpvVecConst) 303 .addUse(getSPIRVTypeID(SpvType)); 304 } 305 const auto &Subtarget = CurMF->getSubtarget(); 306 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 307 *Subtarget.getRegisterInfo(), 308 *Subtarget.getRegBankInfo()); 309 return SpvVecConst; 310 } 311 return Res; 312 } 313 314 Register 315 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I, 316 SPIRVType *SpvType, 317 const SPIRVInstrInfo &TII) { 318 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 319 assert(LLVMTy->isVectorTy()); 320 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 321 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 322 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 323 auto ConstVec = 324 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt); 325 unsigned BW = getScalarOrVectorBitWidth(SpvType); 326 return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW, 327 SpvType->getOperand(2).getImm()); 328 } 329 330 Register 331 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I, 332 SPIRVType *SpvType, 333 const SPIRVInstrInfo &TII) { 334 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 335 assert(LLVMTy->isArrayTy()); 336 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy); 337 Type *LLVMBaseTy = LLVMArrTy->getElementType(); 338 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 339 auto ConstArr = 340 ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt}); 341 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); 342 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy); 343 return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW, 344 LLVMArrTy->getNumElements()); 345 } 346 347 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull( 348 uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, 349 Constant *CA, unsigned BitWidth, unsigned ElemCnt) { 350 Register Res = DT.find(CA, CurMF); 351 if (!Res.isValid()) { 352 Register SpvScalConst; 353 if (Val || EmitIR) { 354 SPIRVType *SpvBaseType = 355 getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder); 356 SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR); 357 } 358 LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32); 359 Register SpvVecConst = 360 CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 361 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass); 362 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); 363 DT.add(CA, CurMF, SpvVecConst); 364 if (EmitIR) { 365 MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst); 366 } else { 367 if (Val) { 368 auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite) 369 .addDef(SpvVecConst) 370 .addUse(getSPIRVTypeID(SpvType)); 371 for (unsigned i = 0; i < ElemCnt; ++i) 372 MIB.addUse(SpvScalConst); 373 } else { 374 MIRBuilder.buildInstr(SPIRV::OpConstantNull) 375 .addDef(SpvVecConst) 376 .addUse(getSPIRVTypeID(SpvType)); 377 } 378 } 379 return SpvVecConst; 380 } 381 return Res; 382 } 383 384 Register 385 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, 386 MachineIRBuilder &MIRBuilder, 387 SPIRVType *SpvType, bool EmitIR) { 388 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 389 assert(LLVMTy->isVectorTy()); 390 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 391 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 392 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 393 auto ConstVec = 394 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt); 395 unsigned BW = getScalarOrVectorBitWidth(SpvType); 396 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR, 397 ConstVec, BW, 398 SpvType->getOperand(2).getImm()); 399 } 400 401 Register 402 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, 403 MachineIRBuilder &MIRBuilder, 404 SPIRVType *SpvType, bool EmitIR) { 405 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 406 assert(LLVMTy->isArrayTy()); 407 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy); 408 Type *LLVMBaseTy = LLVMArrTy->getElementType(); 409 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 410 auto ConstArr = 411 ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt}); 412 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); 413 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy); 414 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR, 415 ConstArr, BW, 416 LLVMArrTy->getNumElements()); 417 } 418 419 Register 420 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder, 421 SPIRVType *SpvType) { 422 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 423 const PointerType *LLVMPtrTy = cast<PointerType>(LLVMTy); 424 // Find a constant in DT or build a new one. 425 Constant *CP = ConstantPointerNull::get(const_cast<PointerType *>(LLVMPtrTy)); 426 Register Res = DT.find(CP, CurMF); 427 if (!Res.isValid()) { 428 LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize); 429 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 430 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 431 assignSPIRVTypeToVReg(SpvType, Res, *CurMF); 432 MIRBuilder.buildInstr(SPIRV::OpConstantNull) 433 .addDef(Res) 434 .addUse(getSPIRVTypeID(SpvType)); 435 DT.add(CP, CurMF, Res); 436 } 437 return Res; 438 } 439 440 Register SPIRVGlobalRegistry::buildConstantSampler( 441 Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode, 442 MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) { 443 SPIRVType *SampTy; 444 if (SpvType) 445 SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder); 446 else 447 SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder); 448 449 auto Sampler = 450 ResReg.isValid() 451 ? ResReg 452 : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); 453 auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler) 454 .addDef(Sampler) 455 .addUse(getSPIRVTypeID(SampTy)) 456 .addImm(AddrMode) 457 .addImm(Param) 458 .addImm(FilerMode); 459 assert(Res->getOperand(0).isReg()); 460 return Res->getOperand(0).getReg(); 461 } 462 463 Register SPIRVGlobalRegistry::buildGlobalVariable( 464 Register ResVReg, SPIRVType *BaseType, StringRef Name, 465 const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage, 466 const MachineInstr *Init, bool IsConst, bool HasLinkageTy, 467 SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, 468 bool IsInstSelector) { 469 const GlobalVariable *GVar = nullptr; 470 if (GV) 471 GVar = cast<const GlobalVariable>(GV); 472 else { 473 // If GV is not passed explicitly, use the name to find or construct 474 // the global variable. 475 Module *M = MIRBuilder.getMF().getFunction().getParent(); 476 GVar = M->getGlobalVariable(Name); 477 if (GVar == nullptr) { 478 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type. 479 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false, 480 GlobalValue::ExternalLinkage, nullptr, 481 Twine(Name)); 482 } 483 GV = GVar; 484 } 485 Register Reg = DT.find(GVar, &MIRBuilder.getMF()); 486 if (Reg.isValid()) { 487 if (Reg != ResVReg) 488 MIRBuilder.buildCopy(ResVReg, Reg); 489 return ResVReg; 490 } 491 492 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable) 493 .addDef(ResVReg) 494 .addUse(getSPIRVTypeID(BaseType)) 495 .addImm(static_cast<uint32_t>(Storage)); 496 497 if (Init != 0) { 498 MIB.addUse(Init->getOperand(0).getReg()); 499 } 500 501 // ISel may introduce a new register on this step, so we need to add it to 502 // DT and correct its type avoiding fails on the next stage. 503 if (IsInstSelector) { 504 const auto &Subtarget = CurMF->getSubtarget(); 505 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 506 *Subtarget.getRegisterInfo(), 507 *Subtarget.getRegBankInfo()); 508 } 509 Reg = MIB->getOperand(0).getReg(); 510 DT.add(GVar, &MIRBuilder.getMF(), Reg); 511 512 // Set to Reg the same type as ResVReg has. 513 auto MRI = MIRBuilder.getMRI(); 514 assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected"); 515 if (Reg != ResVReg) { 516 LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32); 517 MRI->setType(Reg, RegLLTy); 518 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); 519 } 520 521 // If it's a global variable with name, output OpName for it. 522 if (GVar && GVar->hasName()) 523 buildOpName(Reg, GVar->getName(), MIRBuilder); 524 525 // Output decorations for the GV. 526 // TODO: maybe move to GenerateDecorations pass. 527 if (IsConst) 528 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {}); 529 530 if (GVar && GVar->getAlign().valueOrOne().value() != 1) { 531 unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value(); 532 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment}); 533 } 534 535 if (HasLinkageTy) 536 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 537 {static_cast<uint32_t>(LinkageType)}, Name); 538 539 SPIRV::BuiltIn::BuiltIn BuiltInId; 540 if (getSpirvBuiltInIdByName(Name, BuiltInId)) 541 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn, 542 {static_cast<uint32_t>(BuiltInId)}); 543 544 return Reg; 545 } 546 547 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, 548 SPIRVType *ElemType, 549 MachineIRBuilder &MIRBuilder, 550 bool EmitIR) { 551 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && 552 "Invalid array element type"); 553 Register NumElementsVReg = 554 buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR); 555 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray) 556 .addDef(createTypeVReg(MIRBuilder)) 557 .addUse(getSPIRVTypeID(ElemType)) 558 .addUse(NumElementsVReg); 559 return MIB; 560 } 561 562 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty, 563 MachineIRBuilder &MIRBuilder) { 564 assert(Ty->hasName()); 565 const StringRef Name = Ty->hasName() ? Ty->getName() : ""; 566 Register ResVReg = createTypeVReg(MIRBuilder); 567 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg); 568 addStringImm(Name, MIB); 569 buildOpName(ResVReg, Name, MIRBuilder); 570 return MIB; 571 } 572 573 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty, 574 MachineIRBuilder &MIRBuilder, 575 bool EmitIR) { 576 SmallVector<Register, 4> FieldTypes; 577 for (const auto &Elem : Ty->elements()) { 578 SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder); 579 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid && 580 "Invalid struct element type"); 581 FieldTypes.push_back(getSPIRVTypeID(ElemTy)); 582 } 583 Register ResVReg = createTypeVReg(MIRBuilder); 584 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg); 585 for (const auto &Ty : FieldTypes) 586 MIB.addUse(Ty); 587 if (Ty->hasName()) 588 buildOpName(ResVReg, Ty->getName(), MIRBuilder); 589 if (Ty->isPacked()) 590 buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {}); 591 return MIB; 592 } 593 594 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType( 595 const Type *Ty, MachineIRBuilder &MIRBuilder, 596 SPIRV::AccessQualifier::AccessQualifier AccQual) { 597 assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type"); 598 return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this); 599 } 600 601 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer( 602 SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType, 603 MachineIRBuilder &MIRBuilder, Register Reg) { 604 if (!Reg.isValid()) 605 Reg = createTypeVReg(MIRBuilder); 606 return MIRBuilder.buildInstr(SPIRV::OpTypePointer) 607 .addDef(Reg) 608 .addImm(static_cast<uint32_t>(SC)) 609 .addUse(getSPIRVTypeID(ElemType)); 610 } 611 612 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer( 613 SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) { 614 return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer) 615 .addUse(createTypeVReg(MIRBuilder)) 616 .addImm(static_cast<uint32_t>(SC)); 617 } 618 619 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( 620 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes, 621 MachineIRBuilder &MIRBuilder) { 622 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) 623 .addDef(createTypeVReg(MIRBuilder)) 624 .addUse(getSPIRVTypeID(RetType)); 625 for (const SPIRVType *ArgType : ArgTypes) 626 MIB.addUse(getSPIRVTypeID(ArgType)); 627 return MIB; 628 } 629 630 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs( 631 const Type *Ty, SPIRVType *RetType, 632 const SmallVectorImpl<SPIRVType *> &ArgTypes, 633 MachineIRBuilder &MIRBuilder) { 634 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 635 if (Reg.isValid()) 636 return getSPIRVTypeForVReg(Reg); 637 SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder); 638 DT.add(Ty, CurMF, getSPIRVTypeID(SpirvType)); 639 return finishCreatingSPIRVType(Ty, SpirvType); 640 } 641 642 SPIRVType *SPIRVGlobalRegistry::findSPIRVType( 643 const Type *Ty, MachineIRBuilder &MIRBuilder, 644 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { 645 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 646 if (Reg.isValid()) 647 return getSPIRVTypeForVReg(Reg); 648 if (ForwardPointerTypes.contains(Ty)) 649 return ForwardPointerTypes[Ty]; 650 return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR); 651 } 652 653 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const { 654 assert(SpirvType && "Attempting to get type id for nullptr type."); 655 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer) 656 return SpirvType->uses().begin()->getReg(); 657 return SpirvType->defs().begin()->getReg(); 658 } 659 660 SPIRVType *SPIRVGlobalRegistry::createSPIRVType( 661 const Type *Ty, MachineIRBuilder &MIRBuilder, 662 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { 663 if (isSpecialOpaqueType(Ty)) 664 return getOrCreateSpecialType(Ty, MIRBuilder, AccQual); 665 auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses(); 666 auto t = TypeToSPIRVTypeMap.find(Ty); 667 if (t != TypeToSPIRVTypeMap.end()) { 668 auto tt = t->second.find(&MIRBuilder.getMF()); 669 if (tt != t->second.end()) 670 return getSPIRVTypeForVReg(tt->second); 671 } 672 673 if (auto IType = dyn_cast<IntegerType>(Ty)) { 674 const unsigned Width = IType->getBitWidth(); 675 return Width == 1 ? getOpTypeBool(MIRBuilder) 676 : getOpTypeInt(Width, MIRBuilder, false); 677 } 678 if (Ty->isFloatingPointTy()) 679 return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); 680 if (Ty->isVoidTy()) 681 return getOpTypeVoid(MIRBuilder); 682 if (Ty->isVectorTy()) { 683 SPIRVType *El = 684 findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder); 685 return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El, 686 MIRBuilder); 687 } 688 if (Ty->isArrayTy()) { 689 SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder); 690 return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); 691 } 692 if (auto SType = dyn_cast<StructType>(Ty)) { 693 if (SType->isOpaque()) 694 return getOpTypeOpaque(SType, MIRBuilder); 695 return getOpTypeStruct(SType, MIRBuilder, EmitIR); 696 } 697 if (auto FType = dyn_cast<FunctionType>(Ty)) { 698 SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder); 699 SmallVector<SPIRVType *, 4> ParamTypes; 700 for (const auto &t : FType->params()) { 701 ParamTypes.push_back(findSPIRVType(t, MIRBuilder)); 702 } 703 return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); 704 } 705 if (auto PType = dyn_cast<PointerType>(Ty)) { 706 SPIRVType *SpvElementType; 707 // At the moment, all opaque pointers correspond to i8 element type. 708 // TODO: change the implementation once opaque pointers are supported 709 // in the SPIR-V specification. 710 SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder); 711 auto SC = addressSpaceToStorageClass(PType->getAddressSpace()); 712 // Null pointer means we have a loop in type definitions, make and 713 // return corresponding OpTypeForwardPointer. 714 if (SpvElementType == nullptr) { 715 if (!ForwardPointerTypes.contains(Ty)) 716 ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder); 717 return ForwardPointerTypes[PType]; 718 } 719 Register Reg(0); 720 // If we have forward pointer associated with this type, use its register 721 // operand to create OpTypePointer. 722 if (ForwardPointerTypes.contains(PType)) 723 Reg = getSPIRVTypeID(ForwardPointerTypes[PType]); 724 725 return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg); 726 } 727 llvm_unreachable("Unable to convert LLVM type to SPIRVType"); 728 } 729 730 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( 731 const Type *Ty, MachineIRBuilder &MIRBuilder, 732 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 733 if (TypesInProcessing.count(Ty) && !Ty->isPointerTy()) 734 return nullptr; 735 TypesInProcessing.insert(Ty); 736 SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); 737 TypesInProcessing.erase(Ty); 738 VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; 739 SPIRVToLLVMType[SpirvType] = Ty; 740 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 741 // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type 742 // will be added later. For special types it is already added to DT. 743 if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() && 744 !isSpecialOpaqueType(Ty)) { 745 if (!Ty->isPointerTy()) 746 DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType)); 747 else 748 DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()), 749 Ty->getPointerAddressSpace(), &MIRBuilder.getMF(), 750 getSPIRVTypeID(SpirvType)); 751 } 752 753 return SpirvType; 754 } 755 756 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const { 757 auto t = VRegToTypeMap.find(CurMF); 758 if (t != VRegToTypeMap.end()) { 759 auto tt = t->second.find(VReg); 760 if (tt != t->second.end()) 761 return tt->second; 762 } 763 return nullptr; 764 } 765 766 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( 767 const Type *Ty, MachineIRBuilder &MIRBuilder, 768 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 769 Register Reg; 770 if (!Ty->isPointerTy()) 771 Reg = DT.find(Ty, &MIRBuilder.getMF()); 772 else 773 Reg = 774 DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()), 775 Ty->getPointerAddressSpace(), &MIRBuilder.getMF()); 776 777 if (Reg.isValid() && !isSpecialOpaqueType(Ty)) 778 return getSPIRVTypeForVReg(Reg); 779 TypesInProcessing.clear(); 780 SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); 781 // Create normal pointer types for the corresponding OpTypeForwardPointers. 782 for (auto &CU : ForwardPointerTypes) { 783 const Type *Ty2 = CU.first; 784 SPIRVType *STy2 = CU.second; 785 if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid()) 786 STy2 = getSPIRVTypeForVReg(Reg); 787 else 788 STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR); 789 if (Ty == Ty2) 790 STy = STy2; 791 } 792 ForwardPointerTypes.clear(); 793 return STy; 794 } 795 796 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, 797 unsigned TypeOpcode) const { 798 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 799 assert(Type && "isScalarOfType VReg has no type assigned"); 800 return Type->getOpcode() == TypeOpcode; 801 } 802 803 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, 804 unsigned TypeOpcode) const { 805 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 806 assert(Type && "isScalarOrVectorOfType VReg has no type assigned"); 807 if (Type->getOpcode() == TypeOpcode) 808 return true; 809 if (Type->getOpcode() == SPIRV::OpTypeVector) { 810 Register ScalarTypeVReg = Type->getOperand(1).getReg(); 811 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg); 812 return ScalarType->getOpcode() == TypeOpcode; 813 } 814 return false; 815 } 816 817 unsigned 818 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const { 819 assert(Type && "Invalid Type pointer"); 820 if (Type->getOpcode() == SPIRV::OpTypeVector) { 821 auto EleTypeReg = Type->getOperand(1).getReg(); 822 Type = getSPIRVTypeForVReg(EleTypeReg); 823 } 824 if (Type->getOpcode() == SPIRV::OpTypeInt || 825 Type->getOpcode() == SPIRV::OpTypeFloat) 826 return Type->getOperand(1).getImm(); 827 if (Type->getOpcode() == SPIRV::OpTypeBool) 828 return 1; 829 llvm_unreachable("Attempting to get bit width of non-integer/float type."); 830 } 831 832 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { 833 assert(Type && "Invalid Type pointer"); 834 if (Type->getOpcode() == SPIRV::OpTypeVector) { 835 auto EleTypeReg = Type->getOperand(1).getReg(); 836 Type = getSPIRVTypeForVReg(EleTypeReg); 837 } 838 if (Type->getOpcode() == SPIRV::OpTypeInt) 839 return Type->getOperand(2).getImm() != 0; 840 llvm_unreachable("Attempting to get sign of non-integer type."); 841 } 842 843 SPIRV::StorageClass::StorageClass 844 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const { 845 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 846 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer && 847 Type->getOperand(1).isImm() && "Pointer type is expected"); 848 return static_cast<SPIRV::StorageClass::StorageClass>( 849 Type->getOperand(1).getImm()); 850 } 851 852 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage( 853 MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim, 854 uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled, 855 SPIRV::ImageFormat::ImageFormat ImageFormat, 856 SPIRV::AccessQualifier::AccessQualifier AccessQual) { 857 SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth, 858 Arrayed, Multisampled, Sampled, ImageFormat, 859 AccessQual); 860 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 861 return Res; 862 Register ResVReg = createTypeVReg(MIRBuilder); 863 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 864 return MIRBuilder.buildInstr(SPIRV::OpTypeImage) 865 .addDef(ResVReg) 866 .addUse(getSPIRVTypeID(SampledType)) 867 .addImm(Dim) 868 .addImm(Depth) // Depth (whether or not it is a Depth image). 869 .addImm(Arrayed) // Arrayed. 870 .addImm(Multisampled) // Multisampled (0 = only single-sample). 871 .addImm(Sampled) // Sampled (0 = usage known at runtime). 872 .addImm(ImageFormat) 873 .addImm(AccessQual); 874 } 875 876 SPIRVType * 877 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) { 878 SPIRV::SamplerTypeDescriptor TD; 879 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 880 return Res; 881 Register ResVReg = createTypeVReg(MIRBuilder); 882 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 883 return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg); 884 } 885 886 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe( 887 MachineIRBuilder &MIRBuilder, 888 SPIRV::AccessQualifier::AccessQualifier AccessQual) { 889 SPIRV::PipeTypeDescriptor TD(AccessQual); 890 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 891 return Res; 892 Register ResVReg = createTypeVReg(MIRBuilder); 893 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 894 return MIRBuilder.buildInstr(SPIRV::OpTypePipe) 895 .addDef(ResVReg) 896 .addImm(AccessQual); 897 } 898 899 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent( 900 MachineIRBuilder &MIRBuilder) { 901 SPIRV::DeviceEventTypeDescriptor TD; 902 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 903 return Res; 904 Register ResVReg = createTypeVReg(MIRBuilder); 905 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 906 return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg); 907 } 908 909 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage( 910 SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) { 911 SPIRV::SampledImageTypeDescriptor TD( 912 SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef( 913 ImageType->getOperand(1).getReg())), 914 ImageType); 915 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 916 return Res; 917 Register ResVReg = createTypeVReg(MIRBuilder); 918 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 919 return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage) 920 .addDef(ResVReg) 921 .addUse(getSPIRVTypeID(ImageType)); 922 } 923 924 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode( 925 const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) { 926 Register ResVReg = DT.find(Ty, &MIRBuilder.getMF()); 927 if (ResVReg.isValid()) 928 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg); 929 ResVReg = createTypeVReg(MIRBuilder); 930 SPIRVType *SpirvTy = MIRBuilder.buildInstr(Opcode).addDef(ResVReg); 931 DT.add(Ty, &MIRBuilder.getMF(), ResVReg); 932 return SpirvTy; 933 } 934 935 const MachineInstr * 936 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD, 937 MachineIRBuilder &MIRBuilder) { 938 Register Reg = DT.find(TD, &MIRBuilder.getMF()); 939 if (Reg.isValid()) 940 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg); 941 return nullptr; 942 } 943 944 // TODO: maybe use tablegen to implement this. 945 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName( 946 StringRef TypeStr, MachineIRBuilder &MIRBuilder, 947 SPIRV::StorageClass::StorageClass SC, 948 SPIRV::AccessQualifier::AccessQualifier AQ) { 949 unsigned VecElts = 0; 950 auto &Ctx = MIRBuilder.getMF().getFunction().getContext(); 951 952 // Parse strings representing either a SPIR-V or OpenCL builtin type. 953 if (hasBuiltinTypePrefix(TypeStr)) 954 return getOrCreateSPIRVType( 955 SPIRV::parseBuiltinTypeNameToTargetExtType(TypeStr.str(), MIRBuilder), 956 MIRBuilder, AQ); 957 958 // Parse type name in either "typeN" or "type vector[N]" format, where 959 // N is the number of elements of the vector. 960 Type *Ty; 961 962 TypeStr.consume_front("atomic_"); 963 964 if (TypeStr.starts_with("void")) { 965 Ty = Type::getVoidTy(Ctx); 966 TypeStr = TypeStr.substr(strlen("void")); 967 } else if (TypeStr.starts_with("bool")) { 968 Ty = Type::getIntNTy(Ctx, 1); 969 TypeStr = TypeStr.substr(strlen("bool")); 970 } else if (TypeStr.starts_with("char") || TypeStr.starts_with("uchar")) { 971 Ty = Type::getInt8Ty(Ctx); 972 TypeStr = TypeStr.starts_with("char") ? TypeStr.substr(strlen("char")) 973 : TypeStr.substr(strlen("uchar")); 974 } else if (TypeStr.starts_with("short") || TypeStr.starts_with("ushort")) { 975 Ty = Type::getInt16Ty(Ctx); 976 TypeStr = TypeStr.starts_with("short") ? TypeStr.substr(strlen("short")) 977 : TypeStr.substr(strlen("ushort")); 978 } else if (TypeStr.starts_with("int") || TypeStr.starts_with("uint")) { 979 Ty = Type::getInt32Ty(Ctx); 980 TypeStr = TypeStr.starts_with("int") ? TypeStr.substr(strlen("int")) 981 : TypeStr.substr(strlen("uint")); 982 } else if (TypeStr.starts_with("long") || TypeStr.starts_with("ulong")) { 983 Ty = Type::getInt64Ty(Ctx); 984 TypeStr = TypeStr.starts_with("long") ? TypeStr.substr(strlen("long")) 985 : TypeStr.substr(strlen("ulong")); 986 } else if (TypeStr.starts_with("half")) { 987 Ty = Type::getHalfTy(Ctx); 988 TypeStr = TypeStr.substr(strlen("half")); 989 } else if (TypeStr.starts_with("float")) { 990 Ty = Type::getFloatTy(Ctx); 991 TypeStr = TypeStr.substr(strlen("float")); 992 } else if (TypeStr.starts_with("double")) { 993 Ty = Type::getDoubleTy(Ctx); 994 TypeStr = TypeStr.substr(strlen("double")); 995 } else 996 llvm_unreachable("Unable to recognize SPIRV type name."); 997 998 auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ); 999 1000 // Handle "type*" or "type* vector[N]". 1001 if (TypeStr.starts_with("*")) { 1002 SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC); 1003 TypeStr = TypeStr.substr(strlen("*")); 1004 } 1005 1006 // Handle "typeN*" or "type vector[N]*". 1007 bool IsPtrToVec = TypeStr.consume_back("*"); 1008 1009 if (TypeStr.consume_front(" vector[")) { 1010 TypeStr = TypeStr.substr(0, TypeStr.find(']')); 1011 } 1012 TypeStr.getAsInteger(10, VecElts); 1013 if (VecElts > 0) 1014 SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder); 1015 1016 if (IsPtrToVec) 1017 SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC); 1018 1019 return SpirvTy; 1020 } 1021 1022 SPIRVType * 1023 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, 1024 MachineIRBuilder &MIRBuilder) { 1025 return getOrCreateSPIRVType( 1026 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), 1027 MIRBuilder); 1028 } 1029 1030 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy, 1031 SPIRVType *SpirvType) { 1032 assert(CurMF == SpirvType->getMF()); 1033 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; 1034 SPIRVToLLVMType[SpirvType] = LLVMTy; 1035 return SpirvType; 1036 } 1037 1038 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType( 1039 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { 1040 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth); 1041 Register Reg = DT.find(LLVMTy, CurMF); 1042 if (Reg.isValid()) 1043 return getSPIRVTypeForVReg(Reg); 1044 MachineBasicBlock &BB = *I.getParent(); 1045 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt)) 1046 .addDef(createTypeVReg(CurMF->getRegInfo())) 1047 .addImm(BitWidth) 1048 .addImm(0); 1049 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1050 return finishCreatingSPIRVType(LLVMTy, MIB); 1051 } 1052 1053 SPIRVType * 1054 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) { 1055 return getOrCreateSPIRVType( 1056 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), 1057 MIRBuilder); 1058 } 1059 1060 SPIRVType * 1061 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I, 1062 const SPIRVInstrInfo &TII) { 1063 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1); 1064 Register Reg = DT.find(LLVMTy, CurMF); 1065 if (Reg.isValid()) 1066 return getSPIRVTypeForVReg(Reg); 1067 MachineBasicBlock &BB = *I.getParent(); 1068 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool)) 1069 .addDef(createTypeVReg(CurMF->getRegInfo())); 1070 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1071 return finishCreatingSPIRVType(LLVMTy, MIB); 1072 } 1073 1074 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 1075 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) { 1076 return getOrCreateSPIRVType( 1077 FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 1078 NumElements), 1079 MIRBuilder); 1080 } 1081 1082 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 1083 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, 1084 const SPIRVInstrInfo &TII) { 1085 Type *LLVMTy = FixedVectorType::get( 1086 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); 1087 Register Reg = DT.find(LLVMTy, CurMF); 1088 if (Reg.isValid()) 1089 return getSPIRVTypeForVReg(Reg); 1090 MachineBasicBlock &BB = *I.getParent(); 1091 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector)) 1092 .addDef(createTypeVReg(CurMF->getRegInfo())) 1093 .addUse(getSPIRVTypeID(BaseType)) 1094 .addImm(NumElements); 1095 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1096 return finishCreatingSPIRVType(LLVMTy, MIB); 1097 } 1098 1099 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType( 1100 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, 1101 const SPIRVInstrInfo &TII) { 1102 Type *LLVMTy = ArrayType::get( 1103 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); 1104 Register Reg = DT.find(LLVMTy, CurMF); 1105 if (Reg.isValid()) 1106 return getSPIRVTypeForVReg(Reg); 1107 MachineBasicBlock &BB = *I.getParent(); 1108 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII); 1109 Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII); 1110 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray)) 1111 .addDef(createTypeVReg(CurMF->getRegInfo())) 1112 .addUse(getSPIRVTypeID(BaseType)) 1113 .addUse(Len); 1114 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1115 return finishCreatingSPIRVType(LLVMTy, MIB); 1116 } 1117 1118 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( 1119 SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, 1120 SPIRV::StorageClass::StorageClass SC) { 1121 const Type *PointerElementType = getTypeForSPIRVType(BaseType); 1122 unsigned AddressSpace = storageClassToAddressSpace(SC); 1123 Type *LLVMTy = 1124 PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace); 1125 Register Reg = DT.find(PointerElementType, AddressSpace, CurMF); 1126 if (Reg.isValid()) 1127 return getSPIRVTypeForVReg(Reg); 1128 auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(), 1129 MIRBuilder.getDebugLoc(), 1130 MIRBuilder.getTII().get(SPIRV::OpTypePointer)) 1131 .addDef(createTypeVReg(CurMF->getRegInfo())) 1132 .addImm(static_cast<uint32_t>(SC)) 1133 .addUse(getSPIRVTypeID(BaseType)); 1134 DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB)); 1135 return finishCreatingSPIRVType(LLVMTy, MIB); 1136 } 1137 1138 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( 1139 SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, 1140 SPIRV::StorageClass::StorageClass SC) { 1141 const Type *PointerElementType = getTypeForSPIRVType(BaseType); 1142 unsigned AddressSpace = storageClassToAddressSpace(SC); 1143 Type *LLVMTy = 1144 PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace); 1145 Register Reg = DT.find(PointerElementType, AddressSpace, CurMF); 1146 if (Reg.isValid()) 1147 return getSPIRVTypeForVReg(Reg); 1148 MachineBasicBlock &BB = *I.getParent(); 1149 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer)) 1150 .addDef(createTypeVReg(CurMF->getRegInfo())) 1151 .addImm(static_cast<uint32_t>(SC)) 1152 .addUse(getSPIRVTypeID(BaseType)); 1153 DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB)); 1154 return finishCreatingSPIRVType(LLVMTy, MIB); 1155 } 1156 1157 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I, 1158 SPIRVType *SpvType, 1159 const SPIRVInstrInfo &TII) { 1160 assert(SpvType); 1161 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 1162 assert(LLVMTy); 1163 // Find a constant in DT or build a new one. 1164 UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy)); 1165 Register Res = DT.find(UV, CurMF); 1166 if (Res.isValid()) 1167 return Res; 1168 LLT LLTy = LLT::scalar(32); 1169 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 1170 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 1171 assignSPIRVTypeToVReg(SpvType, Res, *CurMF); 1172 DT.add(UV, CurMF, Res); 1173 1174 MachineInstrBuilder MIB; 1175 MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef)) 1176 .addDef(Res) 1177 .addUse(getSPIRVTypeID(SpvType)); 1178 const auto &ST = CurMF->getSubtarget(); 1179 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(), 1180 *ST.getRegisterInfo(), *ST.getRegBankInfo()); 1181 return Res; 1182 } 1183