1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the lowering of LLVM calls to machine code calls for 10 // GlobalISel. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "SPIRVCallLowering.h" 15 #include "MCTargetDesc/SPIRVBaseInfo.h" 16 #include "SPIRV.h" 17 #include "SPIRVBuiltins.h" 18 #include "SPIRVGlobalRegistry.h" 19 #include "SPIRVISelLowering.h" 20 #include "SPIRVMetadata.h" 21 #include "SPIRVRegisterInfo.h" 22 #include "SPIRVSubtarget.h" 23 #include "SPIRVUtils.h" 24 #include "llvm/CodeGen/FunctionLoweringInfo.h" 25 #include "llvm/IR/IntrinsicInst.h" 26 #include "llvm/IR/IntrinsicsSPIRV.h" 27 #include "llvm/Support/ModRef.h" 28 29 using namespace llvm; 30 31 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI, 32 SPIRVGlobalRegistry *GR) 33 : CallLowering(&TLI), GR(GR) {} 34 35 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, 36 const Value *Val, ArrayRef<Register> VRegs, 37 FunctionLoweringInfo &FLI, 38 Register SwiftErrorVReg) const { 39 // Maybe run postponed production of types for function pointers 40 if (IndirectCalls.size() > 0) { 41 produceIndirectPtrTypes(MIRBuilder); 42 IndirectCalls.clear(); 43 } 44 45 // Currently all return types should use a single register. 46 // TODO: handle the case of multiple registers. 47 if (VRegs.size() > 1) 48 return false; 49 if (Val) { 50 const auto &STI = MIRBuilder.getMF().getSubtarget(); 51 return MIRBuilder.buildInstr(SPIRV::OpReturnValue) 52 .addUse(VRegs[0]) 53 .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(), 54 *STI.getRegBankInfo()); 55 } 56 MIRBuilder.buildInstr(SPIRV::OpReturn); 57 return true; 58 } 59 60 // Based on the LLVM function attributes, get a SPIR-V FunctionControl. 61 static uint32_t getFunctionControl(const Function &F) { 62 MemoryEffects MemEffects = F.getMemoryEffects(); 63 64 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None); 65 66 if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) 67 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline); 68 else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) 69 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline); 70 71 if (MemEffects.doesNotAccessMemory()) 72 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure); 73 else if (MemEffects.onlyReadsMemory()) 74 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const); 75 76 return FuncControl; 77 } 78 79 static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) { 80 if (MD->getNumOperands() > NumOp) { 81 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp)); 82 if (CMeta) 83 return dyn_cast<ConstantInt>(CMeta->getValue()); 84 } 85 return nullptr; 86 } 87 88 // If the function has pointer arguments, we are forced to re-create this 89 // function type from the very beginning, changing PointerType by 90 // TypedPointerType for each pointer argument. Otherwise, the same `Type*` 91 // potentially corresponds to different SPIR-V function type, effectively 92 // invalidating logic behind global registry and duplicates tracker. 93 static FunctionType * 94 fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F, 95 FunctionType *FTy, const SPIRVType *SRetTy, 96 const SmallVector<SPIRVType *, 4> &SArgTys) { 97 if (F.getParent()->getNamedMetadata("spv.cloned_funcs")) 98 return FTy; 99 100 bool hasArgPtrs = false; 101 for (auto &Arg : F.args()) { 102 // check if it's an instance of a non-typed PointerType 103 if (Arg.getType()->isPointerTy()) { 104 hasArgPtrs = true; 105 break; 106 } 107 } 108 if (!hasArgPtrs) { 109 Type *RetTy = FTy->getReturnType(); 110 // check if it's an instance of a non-typed PointerType 111 if (!RetTy->isPointerTy()) 112 return FTy; 113 } 114 115 // re-create function type, using TypedPointerType instead of PointerType to 116 // properly trace argument types 117 const Type *RetTy = GR->getTypeForSPIRVType(SRetTy); 118 SmallVector<Type *, 4> ArgTys; 119 for (auto SArgTy : SArgTys) 120 ArgTys.push_back(const_cast<Type *>(GR->getTypeForSPIRVType(SArgTy))); 121 return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false); 122 } 123 124 // This code restores function args/retvalue types for composite cases 125 // because the final types should still be aggregate whereas they're i32 126 // during the translation to cope with aggregate flattening etc. 127 static FunctionType *getOriginalFunctionType(const Function &F) { 128 auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs"); 129 if (NamedMD == nullptr) 130 return F.getFunctionType(); 131 132 Type *RetTy = F.getFunctionType()->getReturnType(); 133 SmallVector<Type *, 4> ArgTypes; 134 for (auto &Arg : F.args()) 135 ArgTypes.push_back(Arg.getType()); 136 137 auto ThisFuncMDIt = 138 std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) { 139 return isa<MDString>(N->getOperand(0)) && 140 cast<MDString>(N->getOperand(0))->getString() == F.getName(); 141 }); 142 // TODO: probably one function can have numerous type mutations, 143 // so we should support this. 144 if (ThisFuncMDIt != NamedMD->op_end()) { 145 auto *ThisFuncMD = *ThisFuncMDIt; 146 MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1)); 147 assert(MD && "MDNode operand is expected"); 148 ConstantInt *Const = getConstInt(MD, 0); 149 if (Const) { 150 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1)); 151 assert(CMeta && "ConstantAsMetadata operand is expected"); 152 assert(Const->getSExtValue() >= -1); 153 // Currently -1 indicates return value, greater values mean 154 // argument numbers. 155 if (Const->getSExtValue() == -1) 156 RetTy = CMeta->getType(); 157 else 158 ArgTypes[Const->getSExtValue()] = CMeta->getType(); 159 } 160 } 161 162 return FunctionType::get(RetTy, ArgTypes, F.isVarArg()); 163 } 164 165 static SPIRV::AccessQualifier::AccessQualifier 166 getArgAccessQual(const Function &F, unsigned ArgIdx) { 167 if (F.getCallingConv() != CallingConv::SPIR_KERNEL) 168 return SPIRV::AccessQualifier::ReadWrite; 169 170 MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx); 171 if (!ArgAttribute) 172 return SPIRV::AccessQualifier::ReadWrite; 173 174 if (ArgAttribute->getString() == "read_only") 175 return SPIRV::AccessQualifier::ReadOnly; 176 if (ArgAttribute->getString() == "write_only") 177 return SPIRV::AccessQualifier::WriteOnly; 178 return SPIRV::AccessQualifier::ReadWrite; 179 } 180 181 static std::vector<SPIRV::Decoration::Decoration> 182 getKernelArgTypeQual(const Function &F, unsigned ArgIdx) { 183 MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx); 184 if (ArgAttribute && ArgAttribute->getString() == "volatile") 185 return {SPIRV::Decoration::Volatile}; 186 return {}; 187 } 188 189 static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx, 190 SPIRVGlobalRegistry *GR, 191 MachineIRBuilder &MIRBuilder, 192 const SPIRVSubtarget &ST) { 193 // Read argument's access qualifier from metadata or default. 194 SPIRV::AccessQualifier::AccessQualifier ArgAccessQual = 195 getArgAccessQual(F, ArgIdx); 196 197 Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx); 198 199 // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot 200 // be legally reassigned later). 201 if (!isPointerTy(OriginalArgType)) 202 return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual); 203 204 Argument *Arg = F.getArg(ArgIdx); 205 Type *ArgType = Arg->getType(); 206 if (isTypedPointerTy(ArgType)) { 207 SPIRVType *ElementType = GR->getOrCreateSPIRVType( 208 cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder); 209 return GR->getOrCreateSPIRVPointerType( 210 ElementType, MIRBuilder, 211 addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST)); 212 } 213 214 // In case OriginalArgType is of untyped pointer type, there are three 215 // possibilities: 216 // 1) This is a pointer of an LLVM IR element type, passed byval/byref. 217 // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type 218 // intrinsic assigning a TargetExtType. 219 // 3) This is a pointer, try to retrieve pointer element type from a 220 // spv_assign_ptr_type intrinsic or otherwise use default pointer element 221 // type. 222 if (hasPointeeTypeAttr(Arg)) { 223 SPIRVType *ElementType = 224 GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder); 225 return GR->getOrCreateSPIRVPointerType( 226 ElementType, MIRBuilder, 227 addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST)); 228 } 229 230 for (auto User : Arg->users()) { 231 auto *II = dyn_cast<IntrinsicInst>(User); 232 // Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type. 233 if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) { 234 MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1)); 235 Type *BuiltinType = 236 cast<ConstantAsMetadata>(VMD->getMetadata())->getType(); 237 assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType"); 238 return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual); 239 } 240 241 // Check if this is spv_assign_ptr_type assigning pointer element type. 242 if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type) 243 continue; 244 245 MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1)); 246 Type *ElementTy = 247 toTypedPointer(cast<ConstantAsMetadata>(VMD->getMetadata())->getType()); 248 SPIRVType *ElementType = GR->getOrCreateSPIRVType(ElementTy, MIRBuilder); 249 return GR->getOrCreateSPIRVPointerType( 250 ElementType, MIRBuilder, 251 addressSpaceToStorageClass( 252 cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST)); 253 } 254 255 // Replace PointerType with TypedPointerType to be able to map SPIR-V types to 256 // LLVM types in a consistent manner 257 return GR->getOrCreateSPIRVType(toTypedPointer(OriginalArgType), MIRBuilder, 258 ArgAccessQual); 259 } 260 261 static SPIRV::ExecutionModel::ExecutionModel 262 getExecutionModel(const SPIRVSubtarget &STI, const Function &F) { 263 if (STI.isOpenCLEnv()) 264 return SPIRV::ExecutionModel::Kernel; 265 266 auto attribute = F.getFnAttribute("hlsl.shader"); 267 if (!attribute.isValid()) { 268 report_fatal_error( 269 "This entry point lacks mandatory hlsl.shader attribute."); 270 } 271 272 const auto value = attribute.getValueAsString(); 273 if (value == "compute") 274 return SPIRV::ExecutionModel::GLCompute; 275 276 report_fatal_error("This HLSL entry point is not supported by this backend."); 277 } 278 279 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, 280 const Function &F, 281 ArrayRef<ArrayRef<Register>> VRegs, 282 FunctionLoweringInfo &FLI) const { 283 assert(GR && "Must initialize the SPIRV type registry before lowering args."); 284 GR->setCurrentFunc(MIRBuilder.getMF()); 285 286 // Get access to information about available extensions 287 const SPIRVSubtarget *ST = 288 static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget()); 289 290 // Assign types and names to all args, and store their types for later. 291 SmallVector<SPIRVType *, 4> ArgTypeVRegs; 292 if (VRegs.size() > 0) { 293 unsigned i = 0; 294 for (const auto &Arg : F.args()) { 295 // Currently formal args should use single registers. 296 // TODO: handle the case of multiple registers. 297 if (VRegs[i].size() > 1) 298 return false; 299 auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST); 300 GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF()); 301 ArgTypeVRegs.push_back(SpirvTy); 302 303 if (Arg.hasName()) 304 buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder); 305 if (isPointerTy(Arg.getType())) { 306 auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes()); 307 if (DerefBytes != 0) 308 buildOpDecorate(VRegs[i][0], MIRBuilder, 309 SPIRV::Decoration::MaxByteOffset, {DerefBytes}); 310 } 311 if (Arg.hasAttribute(Attribute::Alignment)) { 312 auto Alignment = static_cast<unsigned>( 313 Arg.getAttribute(Attribute::Alignment).getValueAsInt()); 314 buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment, 315 {Alignment}); 316 } 317 if (Arg.hasAttribute(Attribute::ReadOnly)) { 318 auto Attr = 319 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite); 320 buildOpDecorate(VRegs[i][0], MIRBuilder, 321 SPIRV::Decoration::FuncParamAttr, {Attr}); 322 } 323 if (Arg.hasAttribute(Attribute::ZExt)) { 324 auto Attr = 325 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext); 326 buildOpDecorate(VRegs[i][0], MIRBuilder, 327 SPIRV::Decoration::FuncParamAttr, {Attr}); 328 } 329 if (Arg.hasAttribute(Attribute::NoAlias)) { 330 auto Attr = 331 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias); 332 buildOpDecorate(VRegs[i][0], MIRBuilder, 333 SPIRV::Decoration::FuncParamAttr, {Attr}); 334 } 335 if (Arg.hasAttribute(Attribute::ByVal)) { 336 auto Attr = 337 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal); 338 buildOpDecorate(VRegs[i][0], MIRBuilder, 339 SPIRV::Decoration::FuncParamAttr, {Attr}); 340 } 341 342 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { 343 std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs = 344 getKernelArgTypeQual(F, i); 345 for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs) 346 buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {}); 347 } 348 349 MDNode *Node = F.getMetadata("spirv.ParameterDecorations"); 350 if (Node && i < Node->getNumOperands() && 351 isa<MDNode>(Node->getOperand(i))) { 352 MDNode *MD = cast<MDNode>(Node->getOperand(i)); 353 for (const MDOperand &MDOp : MD->operands()) { 354 MDNode *MD2 = dyn_cast<MDNode>(MDOp); 355 assert(MD2 && "Metadata operand is expected"); 356 ConstantInt *Const = getConstInt(MD2, 0); 357 assert(Const && "MDOperand should be ConstantInt"); 358 auto Dec = 359 static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue()); 360 std::vector<uint32_t> DecVec; 361 for (unsigned j = 1; j < MD2->getNumOperands(); j++) { 362 ConstantInt *Const = getConstInt(MD2, j); 363 assert(Const && "MDOperand should be ConstantInt"); 364 DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue())); 365 } 366 buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec); 367 } 368 } 369 ++i; 370 } 371 } 372 373 auto MRI = MIRBuilder.getMRI(); 374 Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); 375 MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); 376 if (F.isDeclaration()) 377 GR->add(&F, &MIRBuilder.getMF(), FuncVReg); 378 FunctionType *FTy = getOriginalFunctionType(F); 379 Type *FRetTy = FTy->getReturnType(); 380 if (isUntypedPointerTy(FRetTy)) { 381 if (Type *FRetElemTy = GR->findDeducedElementType(&F)) { 382 TypedPointerType *DerivedTy = TypedPointerType::get( 383 toTypedPointer(FRetElemTy), getPointerAddressSpace(FRetTy)); 384 GR->addReturnType(&F, DerivedTy); 385 FRetTy = DerivedTy; 386 } 387 } 388 SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder); 389 FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs); 390 SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs( 391 FTy, RetTy, ArgTypeVRegs, MIRBuilder); 392 uint32_t FuncControl = getFunctionControl(F); 393 394 // Add OpFunction instruction 395 MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction) 396 .addDef(FuncVReg) 397 .addUse(GR->getSPIRVTypeID(RetTy)) 398 .addImm(FuncControl) 399 .addUse(GR->getSPIRVTypeID(FuncTy)); 400 GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0)); 401 402 // Add OpFunctionParameter instructions 403 int i = 0; 404 for (const auto &Arg : F.args()) { 405 assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs"); 406 MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass); 407 MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) 408 .addDef(VRegs[i][0]) 409 .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); 410 if (F.isDeclaration()) 411 GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]); 412 i++; 413 } 414 // Name the function. 415 if (F.hasName()) 416 buildOpName(FuncVReg, F.getName(), MIRBuilder); 417 418 // Handle entry points and function linkage. 419 if (isEntryPoint(F)) { 420 const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>(); 421 auto executionModel = getExecutionModel(STI, F); 422 auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint) 423 .addImm(static_cast<uint32_t>(executionModel)) 424 .addUse(FuncVReg); 425 addStringImm(F.getName(), MIB); 426 } else if (F.getLinkage() != GlobalValue::InternalLinkage && 427 F.getLinkage() != GlobalValue::PrivateLinkage) { 428 SPIRV::LinkageType::LinkageType LnkTy = 429 F.isDeclaration() 430 ? SPIRV::LinkageType::Import 431 : (F.getLinkage() == GlobalValue::LinkOnceODRLinkage && 432 ST->canUseExtension( 433 SPIRV::Extension::SPV_KHR_linkonce_odr) 434 ? SPIRV::LinkageType::LinkOnceODR 435 : SPIRV::LinkageType::Export); 436 buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 437 {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier()); 438 } 439 440 // Handle function pointers decoration 441 bool hasFunctionPointers = 442 ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers); 443 if (hasFunctionPointers) { 444 if (F.hasFnAttribute("referenced-indirectly")) { 445 assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) && 446 "Unexpected 'referenced-indirectly' attribute of the kernel " 447 "function"); 448 buildOpDecorate(FuncVReg, MIRBuilder, 449 SPIRV::Decoration::ReferencedIndirectlyINTEL, {}); 450 } 451 } 452 453 return true; 454 } 455 456 // Used to postpone producing of indirect function pointer types after all 457 // indirect calls info is collected 458 // TODO: 459 // - add a topological sort of IndirectCalls to ensure the best types knowledge 460 // - we may need to fix function formal parameter types if they are opaque 461 // pointers used as function pointers in these indirect calls 462 void SPIRVCallLowering::produceIndirectPtrTypes( 463 MachineIRBuilder &MIRBuilder) const { 464 // Create indirect call data types if any 465 MachineFunction &MF = MIRBuilder.getMF(); 466 for (auto const &IC : IndirectCalls) { 467 SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder); 468 SmallVector<SPIRVType *, 4> SpirvArgTypes; 469 for (size_t i = 0; i < IC.ArgTys.size(); ++i) { 470 SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder); 471 SpirvArgTypes.push_back(SPIRVTy); 472 if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i])) 473 GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF); 474 } 475 // SPIR-V function type: 476 FunctionType *FTy = 477 FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false); 478 SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs( 479 FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder); 480 // SPIR-V pointer to function type: 481 SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType( 482 SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function); 483 // Correct the Callee type 484 GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF); 485 } 486 } 487 488 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, 489 CallLoweringInfo &Info) const { 490 // Currently call returns should have single vregs. 491 // TODO: handle the case of multiple registers. 492 if (Info.OrigRet.Regs.size() > 1) 493 return false; 494 MachineFunction &MF = MIRBuilder.getMF(); 495 GR->setCurrentFunc(MF); 496 const Function *CF = nullptr; 497 std::string DemangledName; 498 const Type *OrigRetTy = Info.OrigRet.Ty; 499 500 // Emit a regular OpFunctionCall. If it's an externally declared function, 501 // be sure to emit its type and function declaration here. It will be hoisted 502 // globally later. 503 if (Info.Callee.isGlobal()) { 504 std::string FuncName = Info.Callee.getGlobal()->getName().str(); 505 DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName); 506 CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal()); 507 // TODO: support constexpr casts and indirect calls. 508 if (CF == nullptr) 509 return false; 510 if (FunctionType *FTy = getOriginalFunctionType(*CF)) { 511 OrigRetTy = FTy->getReturnType(); 512 if (isUntypedPointerTy(OrigRetTy)) { 513 if (auto *DerivedRetTy = GR->findReturnType(CF)) 514 OrigRetTy = DerivedRetTy; 515 } 516 } 517 } 518 519 MachineRegisterInfo *MRI = MIRBuilder.getMRI(); 520 Register ResVReg = 521 Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; 522 const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget()); 523 524 bool isFunctionDecl = CF && CF->isDeclaration(); 525 bool canUseOpenCL = ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std); 526 bool canUseGLSL = ST->canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450); 527 assert(canUseGLSL != canUseOpenCL && 528 "Scenario where both sets are enabled is not supported."); 529 530 if (isFunctionDecl && !DemangledName.empty() && 531 (canUseGLSL || canUseOpenCL)) { 532 SmallVector<Register, 8> ArgVRegs; 533 for (auto Arg : Info.OrigArgs) { 534 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs"); 535 ArgVRegs.push_back(Arg.Regs[0]); 536 SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder); 537 if (!GR->getSPIRVTypeForVReg(Arg.Regs[0])) 538 GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF); 539 } 540 auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std 541 : SPIRV::InstructionSet::GLSL_std_450; 542 if (auto Res = 543 SPIRV::lowerBuiltin(DemangledName, instructionSet, MIRBuilder, 544 ResVReg, OrigRetTy, ArgVRegs, GR)) 545 return *Res; 546 } 547 548 if (isFunctionDecl && !GR->find(CF, &MF).isValid()) { 549 // Emit the type info and forward function declaration to the first MBB 550 // to ensure VReg definition dependencies are valid across all MBBs. 551 MachineIRBuilder FirstBlockBuilder; 552 FirstBlockBuilder.setMF(MF); 553 FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0)); 554 555 SmallVector<ArrayRef<Register>, 8> VRegArgs; 556 SmallVector<SmallVector<Register, 1>, 8> ToInsert; 557 for (const Argument &Arg : CF->args()) { 558 if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) 559 continue; // Don't handle zero sized types. 560 Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32)); 561 MRI->setRegClass(Reg, &SPIRV::IDRegClass); 562 ToInsert.push_back({Reg}); 563 VRegArgs.push_back(ToInsert.back()); 564 } 565 // TODO: Reuse FunctionLoweringInfo 566 FunctionLoweringInfo FuncInfo; 567 lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo); 568 } 569 570 unsigned CallOp; 571 if (Info.CB->isIndirectCall()) { 572 if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) 573 report_fatal_error("An indirect call is encountered but SPIR-V without " 574 "extensions does not support it", 575 false); 576 // Set instruction operation according to SPV_INTEL_function_pointers 577 CallOp = SPIRV::OpFunctionPointerCallINTEL; 578 // Collect information about the indirect call to support possible 579 // specification of opaque ptr types of parent function's parameters 580 Register CalleeReg = Info.Callee.getReg(); 581 if (CalleeReg.isValid()) { 582 SPIRVCallLowering::SPIRVIndirectCall IndirectCall; 583 IndirectCall.Callee = CalleeReg; 584 IndirectCall.RetTy = OrigRetTy; 585 for (const auto &Arg : Info.OrigArgs) { 586 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs"); 587 IndirectCall.ArgTys.push_back(Arg.Ty); 588 IndirectCall.ArgRegs.push_back(Arg.Regs[0]); 589 } 590 IndirectCalls.push_back(IndirectCall); 591 } 592 } else { 593 // Emit a regular OpFunctionCall 594 CallOp = SPIRV::OpFunctionCall; 595 } 596 597 // Make sure there's a valid return reg, even for functions returning void. 598 if (!ResVReg.isValid()) 599 ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); 600 SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder); 601 602 // Emit the call instruction and its args. 603 auto MIB = MIRBuilder.buildInstr(CallOp) 604 .addDef(ResVReg) 605 .addUse(GR->getSPIRVTypeID(RetType)) 606 .add(Info.Callee); 607 608 for (const auto &Arg : Info.OrigArgs) { 609 // Currently call args should have single vregs. 610 if (Arg.Regs.size() > 1) 611 return false; 612 MIB.addUse(Arg.Regs[0]); 613 } 614 return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(), 615 *ST->getRegBankInfo()); 616 } 617