1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the implementation of the SPIRVGlobalRegistry class, 10 // which is used to maintain rich type information required for SPIR-V even 11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into 12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds 13 // and supports consistency of constants and global variables. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "SPIRVGlobalRegistry.h" 18 #include "SPIRV.h" 19 #include "SPIRVSubtarget.h" 20 #include "SPIRVTargetMachine.h" 21 #include "SPIRVUtils.h" 22 23 using namespace llvm; 24 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) 25 : PointerSize(PointerSize) {} 26 27 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth, 28 Register VReg, 29 MachineInstr &I, 30 const SPIRVInstrInfo &TII) { 31 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 32 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 33 return SpirvType; 34 } 35 36 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg( 37 SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I, 38 const SPIRVInstrInfo &TII) { 39 SPIRVType *SpirvType = 40 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII); 41 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 42 return SpirvType; 43 } 44 45 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( 46 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, 47 SPIRV::AccessQualifier AccessQual, bool EmitIR) { 48 49 SPIRVType *SpirvType = 50 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); 51 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF()); 52 return SpirvType; 53 } 54 55 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType, 56 Register VReg, 57 MachineFunction &MF) { 58 VRegToTypeMap[&MF][VReg] = SpirvType; 59 } 60 61 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) { 62 auto &MRI = MIRBuilder.getMF().getRegInfo(); 63 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 64 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 65 return Res; 66 } 67 68 static Register createTypeVReg(MachineRegisterInfo &MRI) { 69 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 70 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 71 return Res; 72 } 73 74 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { 75 return MIRBuilder.buildInstr(SPIRV::OpTypeBool) 76 .addDef(createTypeVReg(MIRBuilder)); 77 } 78 79 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width, 80 MachineIRBuilder &MIRBuilder, 81 bool IsSigned) { 82 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt) 83 .addDef(createTypeVReg(MIRBuilder)) 84 .addImm(Width) 85 .addImm(IsSigned ? 1 : 0); 86 return MIB; 87 } 88 89 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, 90 MachineIRBuilder &MIRBuilder) { 91 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat) 92 .addDef(createTypeVReg(MIRBuilder)) 93 .addImm(Width); 94 return MIB; 95 } 96 97 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { 98 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) 99 .addDef(createTypeVReg(MIRBuilder)); 100 } 101 102 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, 103 SPIRVType *ElemType, 104 MachineIRBuilder &MIRBuilder) { 105 auto EleOpc = ElemType->getOpcode(); 106 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || 107 EleOpc == SPIRV::OpTypeBool) && 108 "Invalid vector element type"); 109 110 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector) 111 .addDef(createTypeVReg(MIRBuilder)) 112 .addUse(getSPIRVTypeID(ElemType)) 113 .addImm(NumElems); 114 return MIB; 115 } 116 117 std::tuple<Register, ConstantInt *, bool> 118 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType, 119 MachineIRBuilder *MIRBuilder, 120 MachineInstr *I, 121 const SPIRVInstrInfo *TII) { 122 const IntegerType *LLVMIntTy; 123 if (SpvType) 124 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 125 else 126 LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext()); 127 bool NewInstr = false; 128 // Find a constant in DT or build a new one. 129 ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 130 Register Res = DT.find(CI, CurMF); 131 if (!Res.isValid()) { 132 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 133 LLT LLTy = LLT::scalar(32); 134 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 135 if (MIRBuilder) 136 assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder); 137 else 138 assignIntTypeToVReg(BitWidth, Res, *I, *TII); 139 DT.add(CI, CurMF, Res); 140 NewInstr = true; 141 } 142 return std::make_tuple(Res, CI, NewInstr); 143 } 144 145 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, 146 SPIRVType *SpvType, 147 const SPIRVInstrInfo &TII) { 148 assert(SpvType); 149 ConstantInt *CI; 150 Register Res; 151 bool New; 152 std::tie(Res, CI, New) = 153 getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII); 154 // If we have found Res register which is defined by the passed G_CONSTANT 155 // machine instruction, a new constant instruction should be created. 156 if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg())) 157 return Res; 158 MachineInstrBuilder MIB; 159 MachineBasicBlock &BB = *I.getParent(); 160 if (Val) { 161 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI)) 162 .addDef(Res) 163 .addUse(getSPIRVTypeID(SpvType)); 164 addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB); 165 } else { 166 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) 167 .addDef(Res) 168 .addUse(getSPIRVTypeID(SpvType)); 169 } 170 const auto &ST = CurMF->getSubtarget(); 171 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(), 172 *ST.getRegisterInfo(), *ST.getRegBankInfo()); 173 return Res; 174 } 175 176 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val, 177 MachineIRBuilder &MIRBuilder, 178 SPIRVType *SpvType, 179 bool EmitIR) { 180 auto &MF = MIRBuilder.getMF(); 181 const IntegerType *LLVMIntTy; 182 if (SpvType) 183 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 184 else 185 LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext()); 186 // Find a constant in DT or build a new one. 187 const auto ConstInt = 188 ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 189 Register Res = DT.find(ConstInt, &MF); 190 if (!Res.isValid()) { 191 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 192 LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32); 193 Res = MF.getRegInfo().createGenericVirtualRegister(LLTy); 194 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder, 195 SPIRV::AccessQualifier::ReadWrite, EmitIR); 196 DT.add(ConstInt, &MIRBuilder.getMF(), Res); 197 if (EmitIR) { 198 MIRBuilder.buildConstant(Res, *ConstInt); 199 } else { 200 MachineInstrBuilder MIB; 201 if (Val) { 202 assert(SpvType); 203 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) 204 .addDef(Res) 205 .addUse(getSPIRVTypeID(SpvType)); 206 addNumImm(APInt(BitWidth, Val), MIB); 207 } else { 208 assert(SpvType); 209 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) 210 .addDef(Res) 211 .addUse(getSPIRVTypeID(SpvType)); 212 } 213 const auto &Subtarget = CurMF->getSubtarget(); 214 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 215 *Subtarget.getRegisterInfo(), 216 *Subtarget.getRegBankInfo()); 217 } 218 } 219 return Res; 220 } 221 222 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val, 223 MachineIRBuilder &MIRBuilder, 224 SPIRVType *SpvType) { 225 auto &MF = MIRBuilder.getMF(); 226 const Type *LLVMFPTy; 227 if (SpvType) { 228 LLVMFPTy = getTypeForSPIRVType(SpvType); 229 assert(LLVMFPTy->isFloatingPointTy()); 230 } else { 231 LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext()); 232 } 233 // Find a constant in DT or build a new one. 234 const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val); 235 Register Res = DT.find(ConstFP, &MF); 236 if (!Res.isValid()) { 237 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 238 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); 239 assignTypeToVReg(LLVMFPTy, Res, MIRBuilder); 240 DT.add(ConstFP, &MF, Res); 241 MIRBuilder.buildFConstant(Res, *ConstFP); 242 } 243 return Res; 244 } 245 246 Register 247 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I, 248 SPIRVType *SpvType, 249 const SPIRVInstrInfo &TII) { 250 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 251 assert(LLVMTy->isVectorTy()); 252 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 253 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 254 // Find a constant vector in DT or build a new one. 255 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 256 auto ConstVec = 257 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt); 258 Register Res = DT.find(ConstVec, CurMF); 259 if (!Res.isValid()) { 260 unsigned BitWidth = getScalarOrVectorBitWidth(SpvType); 261 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 262 // SpvScalConst should be created before SpvVecConst to avoid undefined ID 263 // error on validation. 264 // TODO: can moved below once sorting of types/consts/defs is implemented. 265 Register SpvScalConst; 266 if (Val) 267 SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII); 268 // TODO: maybe use bitwidth of base type. 269 LLT LLTy = LLT::scalar(32); 270 Register SpvVecConst = 271 CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 272 const unsigned ElemCnt = SpvType->getOperand(2).getImm(); 273 assignVectTypeToVReg(SpvBaseType, ElemCnt, SpvVecConst, I, TII); 274 DT.add(ConstVec, CurMF, SpvVecConst); 275 MachineInstrBuilder MIB; 276 MachineBasicBlock &BB = *I.getParent(); 277 if (Val) { 278 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite)) 279 .addDef(SpvVecConst) 280 .addUse(getSPIRVTypeID(SpvType)); 281 for (unsigned i = 0; i < ElemCnt; ++i) 282 MIB.addUse(SpvScalConst); 283 } else { 284 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) 285 .addDef(SpvVecConst) 286 .addUse(getSPIRVTypeID(SpvType)); 287 } 288 const auto &Subtarget = CurMF->getSubtarget(); 289 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 290 *Subtarget.getRegisterInfo(), 291 *Subtarget.getRegBankInfo()); 292 return SpvVecConst; 293 } 294 return Res; 295 } 296 297 Register SPIRVGlobalRegistry::buildGlobalVariable( 298 Register ResVReg, SPIRVType *BaseType, StringRef Name, 299 const GlobalValue *GV, SPIRV::StorageClass Storage, 300 const MachineInstr *Init, bool IsConst, bool HasLinkageTy, 301 SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, 302 bool IsInstSelector) { 303 const GlobalVariable *GVar = nullptr; 304 if (GV) 305 GVar = cast<const GlobalVariable>(GV); 306 else { 307 // If GV is not passed explicitly, use the name to find or construct 308 // the global variable. 309 Module *M = MIRBuilder.getMF().getFunction().getParent(); 310 GVar = M->getGlobalVariable(Name); 311 if (GVar == nullptr) { 312 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type. 313 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false, 314 GlobalValue::ExternalLinkage, nullptr, 315 Twine(Name)); 316 } 317 GV = GVar; 318 } 319 Register Reg = DT.find(GVar, &MIRBuilder.getMF()); 320 if (Reg.isValid()) { 321 if (Reg != ResVReg) 322 MIRBuilder.buildCopy(ResVReg, Reg); 323 return ResVReg; 324 } 325 326 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable) 327 .addDef(ResVReg) 328 .addUse(getSPIRVTypeID(BaseType)) 329 .addImm(static_cast<uint32_t>(Storage)); 330 331 if (Init != 0) { 332 MIB.addUse(Init->getOperand(0).getReg()); 333 } 334 335 // ISel may introduce a new register on this step, so we need to add it to 336 // DT and correct its type avoiding fails on the next stage. 337 if (IsInstSelector) { 338 const auto &Subtarget = CurMF->getSubtarget(); 339 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 340 *Subtarget.getRegisterInfo(), 341 *Subtarget.getRegBankInfo()); 342 } 343 Reg = MIB->getOperand(0).getReg(); 344 DT.add(GVar, &MIRBuilder.getMF(), Reg); 345 346 // Set to Reg the same type as ResVReg has. 347 auto MRI = MIRBuilder.getMRI(); 348 assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected"); 349 if (Reg != ResVReg) { 350 LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32); 351 MRI->setType(Reg, RegLLTy); 352 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); 353 } 354 355 // If it's a global variable with name, output OpName for it. 356 if (GVar && GVar->hasName()) 357 buildOpName(Reg, GVar->getName(), MIRBuilder); 358 359 // Output decorations for the GV. 360 // TODO: maybe move to GenerateDecorations pass. 361 if (IsConst) 362 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {}); 363 364 if (GVar && GVar->getAlign().valueOrOne().value() != 1) 365 buildOpDecorate( 366 Reg, MIRBuilder, SPIRV::Decoration::Alignment, 367 {static_cast<uint32_t>(GVar->getAlign().valueOrOne().value())}); 368 369 if (HasLinkageTy) 370 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 371 {static_cast<uint32_t>(LinkageType)}, Name); 372 return Reg; 373 } 374 375 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, 376 SPIRVType *ElemType, 377 MachineIRBuilder &MIRBuilder, 378 bool EmitIR) { 379 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && 380 "Invalid array element type"); 381 Register NumElementsVReg = 382 buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR); 383 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray) 384 .addDef(createTypeVReg(MIRBuilder)) 385 .addUse(getSPIRVTypeID(ElemType)) 386 .addUse(NumElementsVReg); 387 return MIB; 388 } 389 390 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty, 391 MachineIRBuilder &MIRBuilder) { 392 assert(Ty->hasName()); 393 const StringRef Name = Ty->hasName() ? Ty->getName() : ""; 394 Register ResVReg = createTypeVReg(MIRBuilder); 395 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg); 396 addStringImm(Name, MIB); 397 buildOpName(ResVReg, Name, MIRBuilder); 398 return MIB; 399 } 400 401 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty, 402 MachineIRBuilder &MIRBuilder, 403 bool EmitIR) { 404 SmallVector<Register, 4> FieldTypes; 405 for (const auto &Elem : Ty->elements()) { 406 SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder); 407 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid && 408 "Invalid struct element type"); 409 FieldTypes.push_back(getSPIRVTypeID(ElemTy)); 410 } 411 Register ResVReg = createTypeVReg(MIRBuilder); 412 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg); 413 for (const auto &Ty : FieldTypes) 414 MIB.addUse(Ty); 415 if (Ty->hasName()) 416 buildOpName(ResVReg, Ty->getName(), MIRBuilder); 417 if (Ty->isPacked()) 418 buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {}); 419 return MIB; 420 } 421 422 static bool isOpenCLBuiltinType(const StructType *SType) { 423 return SType->isOpaque() && SType->hasName() && 424 SType->getName().startswith("opencl."); 425 } 426 427 static bool isSPIRVBuiltinType(const StructType *SType) { 428 return SType->isOpaque() && SType->hasName() && 429 SType->getName().startswith("spirv."); 430 } 431 432 static bool isSpecialType(const Type *Ty) { 433 if (auto PType = dyn_cast<PointerType>(Ty)) { 434 if (!PType->isOpaque()) 435 Ty = PType->getNonOpaquePointerElementType(); 436 } 437 if (auto SType = dyn_cast<StructType>(Ty)) 438 return isOpenCLBuiltinType(SType) || isSPIRVBuiltinType(SType); 439 return false; 440 } 441 442 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC, 443 SPIRVType *ElemType, 444 MachineIRBuilder &MIRBuilder, 445 Register Reg) { 446 if (!Reg.isValid()) 447 Reg = createTypeVReg(MIRBuilder); 448 return MIRBuilder.buildInstr(SPIRV::OpTypePointer) 449 .addDef(Reg) 450 .addImm(static_cast<uint32_t>(SC)) 451 .addUse(getSPIRVTypeID(ElemType)); 452 } 453 454 SPIRVType * 455 SPIRVGlobalRegistry::getOpTypeForwardPointer(SPIRV::StorageClass SC, 456 MachineIRBuilder &MIRBuilder) { 457 return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer) 458 .addUse(createTypeVReg(MIRBuilder)) 459 .addImm(static_cast<uint32_t>(SC)); 460 } 461 462 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( 463 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes, 464 MachineIRBuilder &MIRBuilder) { 465 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) 466 .addDef(createTypeVReg(MIRBuilder)) 467 .addUse(getSPIRVTypeID(RetType)); 468 for (const SPIRVType *ArgType : ArgTypes) 469 MIB.addUse(getSPIRVTypeID(ArgType)); 470 return MIB; 471 } 472 473 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs( 474 const Type *Ty, SPIRVType *RetType, 475 const SmallVectorImpl<SPIRVType *> &ArgTypes, 476 MachineIRBuilder &MIRBuilder) { 477 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 478 if (Reg.isValid()) 479 return getSPIRVTypeForVReg(Reg); 480 SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder); 481 return finishCreatingSPIRVType(Ty, SpirvType); 482 } 483 484 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(const Type *Ty, 485 MachineIRBuilder &MIRBuilder, 486 SPIRV::AccessQualifier AccQual, 487 bool EmitIR) { 488 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 489 if (Reg.isValid()) 490 return getSPIRVTypeForVReg(Reg); 491 if (ForwardPointerTypes.find(Ty) != ForwardPointerTypes.end()) 492 return ForwardPointerTypes[Ty]; 493 return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR); 494 } 495 496 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const { 497 assert(SpirvType && "Attempting to get type id for nullptr type."); 498 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer) 499 return SpirvType->uses().begin()->getReg(); 500 return SpirvType->defs().begin()->getReg(); 501 } 502 503 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty, 504 MachineIRBuilder &MIRBuilder, 505 SPIRV::AccessQualifier AccQual, 506 bool EmitIR) { 507 assert(!isSpecialType(Ty)); 508 auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses(); 509 auto t = TypeToSPIRVTypeMap.find(Ty); 510 if (t != TypeToSPIRVTypeMap.end()) { 511 auto tt = t->second.find(&MIRBuilder.getMF()); 512 if (tt != t->second.end()) 513 return getSPIRVTypeForVReg(tt->second); 514 } 515 516 if (auto IType = dyn_cast<IntegerType>(Ty)) { 517 const unsigned Width = IType->getBitWidth(); 518 return Width == 1 ? getOpTypeBool(MIRBuilder) 519 : getOpTypeInt(Width, MIRBuilder, false); 520 } 521 if (Ty->isFloatingPointTy()) 522 return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); 523 if (Ty->isVoidTy()) 524 return getOpTypeVoid(MIRBuilder); 525 if (Ty->isVectorTy()) { 526 SPIRVType *El = 527 findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder); 528 return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El, 529 MIRBuilder); 530 } 531 if (Ty->isArrayTy()) { 532 SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder); 533 return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); 534 } 535 if (auto SType = dyn_cast<StructType>(Ty)) { 536 if (SType->isOpaque()) 537 return getOpTypeOpaque(SType, MIRBuilder); 538 return getOpTypeStruct(SType, MIRBuilder, EmitIR); 539 } 540 if (auto FType = dyn_cast<FunctionType>(Ty)) { 541 SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder); 542 SmallVector<SPIRVType *, 4> ParamTypes; 543 for (const auto &t : FType->params()) { 544 ParamTypes.push_back(findSPIRVType(t, MIRBuilder)); 545 } 546 return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); 547 } 548 if (auto PType = dyn_cast<PointerType>(Ty)) { 549 SPIRVType *SpvElementType; 550 // At the moment, all opaque pointers correspond to i8 element type. 551 // TODO: change the implementation once opaque pointers are supported 552 // in the SPIR-V specification. 553 if (PType->isOpaque()) 554 SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder); 555 else 556 SpvElementType = 557 findSPIRVType(PType->getNonOpaquePointerElementType(), MIRBuilder, 558 SPIRV::AccessQualifier::ReadWrite, EmitIR); 559 auto SC = addressSpaceToStorageClass(PType->getAddressSpace()); 560 // Null pointer means we have a loop in type definitions, make and 561 // return corresponding OpTypeForwardPointer. 562 if (SpvElementType == nullptr) { 563 if (ForwardPointerTypes.find(Ty) == ForwardPointerTypes.end()) 564 ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder); 565 return ForwardPointerTypes[PType]; 566 } 567 Register Reg(0); 568 // If we have forward pointer associated with this type, use its register 569 // operand to create OpTypePointer. 570 if (ForwardPointerTypes.find(PType) != ForwardPointerTypes.end()) 571 Reg = getSPIRVTypeID(ForwardPointerTypes[PType]); 572 573 return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg); 574 } 575 llvm_unreachable("Unable to convert LLVM type to SPIRVType"); 576 } 577 578 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( 579 const Type *Ty, MachineIRBuilder &MIRBuilder, 580 SPIRV::AccessQualifier AccessQual, bool EmitIR) { 581 if (TypesInProcessing.count(Ty) && !Ty->isPointerTy()) 582 return nullptr; 583 TypesInProcessing.insert(Ty); 584 SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); 585 TypesInProcessing.erase(Ty); 586 VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; 587 SPIRVToLLVMType[SpirvType] = Ty; 588 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 589 // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type 590 // will be added later. For special types it is already added to DT. 591 if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() && 592 !isSpecialType(Ty)) 593 DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType)); 594 595 return SpirvType; 596 } 597 598 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const { 599 auto t = VRegToTypeMap.find(CurMF); 600 if (t != VRegToTypeMap.end()) { 601 auto tt = t->second.find(VReg); 602 if (tt != t->second.end()) 603 return tt->second; 604 } 605 return nullptr; 606 } 607 608 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( 609 const Type *Ty, MachineIRBuilder &MIRBuilder, 610 SPIRV::AccessQualifier AccessQual, bool EmitIR) { 611 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 612 if (Reg.isValid()) 613 return getSPIRVTypeForVReg(Reg); 614 TypesInProcessing.clear(); 615 SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); 616 // Create normal pointer types for the corresponding OpTypeForwardPointers. 617 for (auto &CU : ForwardPointerTypes) { 618 const Type *Ty2 = CU.first; 619 SPIRVType *STy2 = CU.second; 620 if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid()) 621 STy2 = getSPIRVTypeForVReg(Reg); 622 else 623 STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR); 624 if (Ty == Ty2) 625 STy = STy2; 626 } 627 ForwardPointerTypes.clear(); 628 return STy; 629 } 630 631 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, 632 unsigned TypeOpcode) const { 633 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 634 assert(Type && "isScalarOfType VReg has no type assigned"); 635 return Type->getOpcode() == TypeOpcode; 636 } 637 638 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, 639 unsigned TypeOpcode) const { 640 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 641 assert(Type && "isScalarOrVectorOfType VReg has no type assigned"); 642 if (Type->getOpcode() == TypeOpcode) 643 return true; 644 if (Type->getOpcode() == SPIRV::OpTypeVector) { 645 Register ScalarTypeVReg = Type->getOperand(1).getReg(); 646 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg); 647 return ScalarType->getOpcode() == TypeOpcode; 648 } 649 return false; 650 } 651 652 unsigned 653 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const { 654 assert(Type && "Invalid Type pointer"); 655 if (Type->getOpcode() == SPIRV::OpTypeVector) { 656 auto EleTypeReg = Type->getOperand(1).getReg(); 657 Type = getSPIRVTypeForVReg(EleTypeReg); 658 } 659 if (Type->getOpcode() == SPIRV::OpTypeInt || 660 Type->getOpcode() == SPIRV::OpTypeFloat) 661 return Type->getOperand(1).getImm(); 662 if (Type->getOpcode() == SPIRV::OpTypeBool) 663 return 1; 664 llvm_unreachable("Attempting to get bit width of non-integer/float type."); 665 } 666 667 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { 668 assert(Type && "Invalid Type pointer"); 669 if (Type->getOpcode() == SPIRV::OpTypeVector) { 670 auto EleTypeReg = Type->getOperand(1).getReg(); 671 Type = getSPIRVTypeForVReg(EleTypeReg); 672 } 673 if (Type->getOpcode() == SPIRV::OpTypeInt) 674 return Type->getOperand(2).getImm() != 0; 675 llvm_unreachable("Attempting to get sign of non-integer type."); 676 } 677 678 SPIRV::StorageClass 679 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const { 680 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 681 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer && 682 Type->getOperand(1).isImm() && "Pointer type is expected"); 683 return static_cast<SPIRV::StorageClass>(Type->getOperand(1).getImm()); 684 } 685 686 SPIRVType * 687 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, 688 MachineIRBuilder &MIRBuilder) { 689 return getOrCreateSPIRVType( 690 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), 691 MIRBuilder); 692 } 693 694 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy, 695 SPIRVType *SpirvType) { 696 assert(CurMF == SpirvType->getMF()); 697 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; 698 SPIRVToLLVMType[SpirvType] = LLVMTy; 699 DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType)); 700 return SpirvType; 701 } 702 703 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType( 704 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { 705 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth); 706 Register Reg = DT.find(LLVMTy, CurMF); 707 if (Reg.isValid()) 708 return getSPIRVTypeForVReg(Reg); 709 MachineBasicBlock &BB = *I.getParent(); 710 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt)) 711 .addDef(createTypeVReg(CurMF->getRegInfo())) 712 .addImm(BitWidth) 713 .addImm(0); 714 return finishCreatingSPIRVType(LLVMTy, MIB); 715 } 716 717 SPIRVType * 718 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) { 719 return getOrCreateSPIRVType( 720 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), 721 MIRBuilder); 722 } 723 724 SPIRVType * 725 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I, 726 const SPIRVInstrInfo &TII) { 727 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1); 728 Register Reg = DT.find(LLVMTy, CurMF); 729 if (Reg.isValid()) 730 return getSPIRVTypeForVReg(Reg); 731 MachineBasicBlock &BB = *I.getParent(); 732 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool)) 733 .addDef(createTypeVReg(CurMF->getRegInfo())); 734 return finishCreatingSPIRVType(LLVMTy, MIB); 735 } 736 737 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 738 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) { 739 return getOrCreateSPIRVType( 740 FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 741 NumElements), 742 MIRBuilder); 743 } 744 745 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 746 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, 747 const SPIRVInstrInfo &TII) { 748 Type *LLVMTy = FixedVectorType::get( 749 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); 750 Register Reg = DT.find(LLVMTy, CurMF); 751 if (Reg.isValid()) 752 return getSPIRVTypeForVReg(Reg); 753 MachineBasicBlock &BB = *I.getParent(); 754 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector)) 755 .addDef(createTypeVReg(CurMF->getRegInfo())) 756 .addUse(getSPIRVTypeID(BaseType)) 757 .addImm(NumElements); 758 return finishCreatingSPIRVType(LLVMTy, MIB); 759 } 760 761 SPIRVType * 762 SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType, 763 MachineIRBuilder &MIRBuilder, 764 SPIRV::StorageClass SClass) { 765 return getOrCreateSPIRVType( 766 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 767 storageClassToAddressSpace(SClass)), 768 MIRBuilder); 769 } 770 771 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( 772 SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, 773 SPIRV::StorageClass SC) { 774 Type *LLVMTy = 775 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 776 storageClassToAddressSpace(SC)); 777 Register Reg = DT.find(LLVMTy, CurMF); 778 if (Reg.isValid()) 779 return getSPIRVTypeForVReg(Reg); 780 MachineBasicBlock &BB = *I.getParent(); 781 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer)) 782 .addDef(createTypeVReg(CurMF->getRegInfo())) 783 .addImm(static_cast<uint32_t>(SC)) 784 .addUse(getSPIRVTypeID(BaseType)); 785 return finishCreatingSPIRVType(LLVMTy, MIB); 786 } 787 788 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I, 789 SPIRVType *SpvType, 790 const SPIRVInstrInfo &TII) { 791 assert(SpvType); 792 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 793 assert(LLVMTy); 794 // Find a constant in DT or build a new one. 795 UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy)); 796 Register Res = DT.find(UV, CurMF); 797 if (Res.isValid()) 798 return Res; 799 LLT LLTy = LLT::scalar(32); 800 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 801 assignSPIRVTypeToVReg(SpvType, Res, *CurMF); 802 DT.add(UV, CurMF, Res); 803 804 MachineInstrBuilder MIB; 805 MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef)) 806 .addDef(Res) 807 .addUse(getSPIRVTypeID(SpvType)); 808 const auto &ST = CurMF->getSubtarget(); 809 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(), 810 *ST.getRegisterInfo(), *ST.getRegBankInfo()); 811 return Res; 812 } 813