//===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the lowering of LLVM calls to machine code calls for
// GlobalISel.
//
//===----------------------------------------------------------------------===//

#include "SPIRVCallLowering.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
#include "SPIRVBuiltins.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVISelLowering.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/Support/ModRef.h"

using namespace llvm;

SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
                                     SPIRVGlobalRegistry *GR)
    : CallLowering(&TLI), GR(GR) {}

bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
                                    const Value *Val, ArrayRef<Register> VRegs,
                                    FunctionLoweringInfo &FLI,
                                    Register SwiftErrorVReg) const {
  // Currently all return types should use a single register.
  // TODO: handle the case of multiple registers.
  if (VRegs.size() > 1)
    return false;
  if (Val) {
    const auto &STI = MIRBuilder.getMF().getSubtarget();
    return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
        .addUse(VRegs[0])
        .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
                          *STI.getRegBankInfo());
  }
  MIRBuilder.buildInstr(SPIRV::OpReturn);
  return true;
}

// Based on the LLVM function attributes, get a SPIR-V FunctionControl.
static uint32_t getFunctionControl(const Function &F) {
  MemoryEffects MemEffects = F.getMemoryEffects();

  uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);

  if (F.hasFnAttribute(Attribute::AttrKind::NoInline))
    FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
  else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline))
    FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);

  if (MemEffects.doesNotAccessMemory())
    FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
  else if (MemEffects.onlyReadsMemory())
    FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);

  return FuncControl;
}

static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
  if (MD->getNumOperands() > NumOp) {
    auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp));
    if (CMeta)
      return dyn_cast<ConstantInt>(CMeta->getValue());
  }
  return nullptr;
}

// This code restores function args/retvalue types for composite cases
// because the final types should still be aggregate whereas they're i32
// during the translation to cope with aggregate flattening etc.
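// An illustrative (not normative) shape of the "spv.cloned_funcs" metadata,
// inferred from the lookup performed below:
//   !spv.cloned_funcs = !{!1}
//   !1 = !{!"foo", !2}              ; operand 0: the original function name
//   !2 = !{i32 -1, %struct.S undef} ; operand 0: index (-1 means return value),
//                                   ; operand 1: constant carrying the original type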
static FunctionType *getOriginalFunctionType(const Function &F) {
  auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
  if (NamedMD == nullptr)
    return F.getFunctionType();

  Type *RetTy = F.getFunctionType()->getReturnType();
  SmallVector<Type *, 4> ArgTypes;
  for (auto &Arg : F.args())
    ArgTypes.push_back(Arg.getType());

  auto ThisFuncMDIt =
      std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
        return isa<MDString>(N->getOperand(0)) &&
               cast<MDString>(N->getOperand(0))->getString() == F.getName();
      });
  // TODO: probably one function can have numerous type mutations,
  // so we should support this.
  if (ThisFuncMDIt != NamedMD->op_end()) {
    auto *ThisFuncMD = *ThisFuncMDIt;
    MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
    assert(MD && "MDNode operand is expected");
    ConstantInt *Const = getConstInt(MD, 0);
    if (Const) {
      auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
      assert(CMeta && "ConstantAsMetadata operand is expected");
      assert(Const->getSExtValue() >= -1);
      // Currently -1 indicates return value, greater values mean
      // argument numbers.
      if (Const->getSExtValue() == -1)
        RetTy = CMeta->getType();
      else
        ArgTypes[Const->getSExtValue()] = CMeta->getType();
    }
  }

  return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
}

static MDString *getKernelArgAttribute(const Function &KernelFunction,
                                       unsigned ArgIdx,
                                       const StringRef AttributeName) {
  assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
         "Kernel attributes are attached/belong only to kernel functions");

  // Look up the argument attribute in metadata attached to the kernel function.
  MDNode *Node = KernelFunction.getMetadata(AttributeName);
  if (Node && ArgIdx < Node->getNumOperands())
    return cast<MDString>(Node->getOperand(ArgIdx));

  // Sometimes metadata containing kernel attributes is not attached to the
  // function, but can be found in the named module-level metadata instead.
  // For example:
  //   !opencl.kernels = !{!0}
  //   !0 = !{void ()* @someKernelFunction, !1, ...}
  //   !1 = !{!"kernel_arg_addr_space", ...}
  // In this case the actual index of the searched argument attribute is
  // ArgIdx + 1, since the first metadata node operand is occupied by the
  // attribute name ("kernel_arg_addr_space" in the example above).
  unsigned MDArgIdx = ArgIdx + 1;
  NamedMDNode *OpenCLKernelsMD =
      KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
  if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
    return nullptr;

  // KernelToMDNodeList contains kernel function declarations followed by
  // corresponding MDNodes for each attribute. Search only MDNodes "belonging"
  // to the currently lowered kernel function.
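  // An illustrative (not normative) layout that the loop below can handle,
  // with several kernels interleaved with their attribute nodes:
  //   !0 = !{void ()* @kernelA, !1, !2, void ()* @kernelB, !3}
  // Attribute nodes between two function operands belong to the first one.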
  MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
  bool FoundLoweredKernelFunction = false;
  for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
    ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
    Function *MaybeFunc =
        MaybeValue ? dyn_cast<Function>(MaybeValue->getValue()) : nullptr;
    if (MaybeFunc && MaybeFunc->getName() == KernelFunction.getName()) {
      FoundLoweredKernelFunction = true;
      continue;
    }
    if (MaybeValue && FoundLoweredKernelFunction)
      return nullptr;

    MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
    if (FoundLoweredKernelFunction && MaybeNode &&
        cast<MDString>(MaybeNode->getOperand(0))->getString() ==
            AttributeName &&
        MDArgIdx < MaybeNode->getNumOperands())
      return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
  }
  return nullptr;
}

static SPIRV::AccessQualifier::AccessQualifier
getArgAccessQual(const Function &F, unsigned ArgIdx) {
  if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
    return SPIRV::AccessQualifier::ReadWrite;

  MDString *ArgAttribute =
      getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
  if (!ArgAttribute)
    return SPIRV::AccessQualifier::ReadWrite;

  if (ArgAttribute->getString().compare("read_only") == 0)
    return SPIRV::AccessQualifier::ReadOnly;
  if (ArgAttribute->getString().compare("write_only") == 0)
    return SPIRV::AccessQualifier::WriteOnly;
  return SPIRV::AccessQualifier::ReadWrite;
}

static std::vector<SPIRV::Decoration::Decoration>
getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
  MDString *ArgAttribute =
      getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
  if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
    return {SPIRV::Decoration::Volatile};
  return {};
}

static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
                                  SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder &MIRBuilder) {
  // Read argument's access qualifier from metadata or default.
  SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
      getArgAccessQual(F, ArgIdx);

  Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);

  // In case of non-kernel SPIR-V function or already TargetExtType, use the
  // original IR type.
  if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
      isSpecialOpaqueType(OriginalArgType))
    return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

  MDString *MDKernelArgType =
      getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
  if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
                           !MDKernelArgType->getString().ends_with("_t")))
    return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

  if (MDKernelArgType->getString().ends_with("*"))
    return GR->getOrCreateSPIRVTypeByName(
        MDKernelArgType->getString(), MIRBuilder,
        addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace()));

  if (MDKernelArgType->getString().ends_with("_t"))
    return GR->getOrCreateSPIRVTypeByName(
        "opencl." + MDKernelArgType->getString().str(), MIRBuilder,
        SPIRV::StorageClass::Function, ArgAccessQual);

  llvm_unreachable("Unable to recognize argument type name.");
}

static bool isEntryPoint(const Function &F) {
  // OpenCL handling: any function with the SPIR_KERNEL
  // calling convention will be a potential entry point.
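  // Functions recognized here are later emitted with an OpEntryPoint
  // instruction in lowerFormalArguments() below.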
  if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
    return true;

  // HLSL handling: special attributes are emitted from the
  // front-end.
  if (F.getFnAttribute("hlsl.shader").isValid())
    return true;

  return false;
}

static SPIRV::ExecutionModel::ExecutionModel
getExecutionModel(const SPIRVSubtarget &STI, const Function &F) {
  if (STI.isOpenCLEnv())
    return SPIRV::ExecutionModel::Kernel;

  auto attribute = F.getFnAttribute("hlsl.shader");
  if (!attribute.isValid()) {
    report_fatal_error(
        "This entry point lacks mandatory hlsl.shader attribute.");
  }

  const auto value = attribute.getValueAsString();
  if (value == "compute")
    return SPIRV::ExecutionModel::GLCompute;

  report_fatal_error("This HLSL entry point is not supported by this backend.");
}

bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
                                             const Function &F,
                                             ArrayRef<ArrayRef<Register>> VRegs,
                                             FunctionLoweringInfo &FLI) const {
  assert(GR && "Must initialize the SPIRV type registry before lowering args.");
  GR->setCurrentFunc(MIRBuilder.getMF());

  // Assign types and names to all args, and store their types for later.
  FunctionType *FTy = getOriginalFunctionType(F);
  SmallVector<SPIRVType *, 4> ArgTypeVRegs;
  if (VRegs.size() > 0) {
    unsigned i = 0;
    for (const auto &Arg : F.args()) {
      // Currently formal args should use single registers.
      // TODO: handle the case of multiple registers.
      if (VRegs[i].size() > 1)
        return false;
      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
      GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
      ArgTypeVRegs.push_back(SpirvTy);

      if (Arg.hasName())
        buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
      if (Arg.getType()->isPointerTy()) {
        auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
        if (DerefBytes != 0)
          buildOpDecorate(VRegs[i][0], MIRBuilder,
                          SPIRV::Decoration::MaxByteOffset, {DerefBytes});
      }
      if (Arg.hasAttribute(Attribute::Alignment)) {
        auto Alignment = static_cast<unsigned>(
            Arg.getAttribute(Attribute::Alignment).getValueAsInt());
        buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
                        {Alignment});
      }
      if (Arg.hasAttribute(Attribute::ReadOnly)) {
        auto Attr =
            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
        buildOpDecorate(VRegs[i][0], MIRBuilder,
                        SPIRV::Decoration::FuncParamAttr, {Attr});
      }
      if (Arg.hasAttribute(Attribute::ZExt)) {
        auto Attr =
            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
        buildOpDecorate(VRegs[i][0], MIRBuilder,
                        SPIRV::Decoration::FuncParamAttr, {Attr});
      }
      if (Arg.hasAttribute(Attribute::NoAlias)) {
        auto Attr =
            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
        buildOpDecorate(VRegs[i][0], MIRBuilder,
                        SPIRV::Decoration::FuncParamAttr, {Attr});
      }

      if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
        std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
            getKernelArgTypeQual(F, i);
        for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
          buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
      }

      MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
      if (Node && i < Node->getNumOperands() &&
          isa<MDNode>(Node->getOperand(i))) {
        MDNode *MD = cast<MDNode>(Node->getOperand(i));
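        // Each operand of MD is itself an MDNode describing one decoration:
        // operand 0 holds the SPIR-V decoration enum value and any remaining
        // operands hold that decoration's literal arguments. Illustrative
        // shape (not a normative format): !{i32 <decoration>, i32 <literal>, ...}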
        for (const MDOperand &MDOp : MD->operands()) {
          MDNode *MD2 = dyn_cast<MDNode>(MDOp);
          assert(MD2 && "Metadata operand is expected");
          ConstantInt *Const = getConstInt(MD2, 0);
          assert(Const && "MDOperand should be ConstantInt");
          auto Dec =
              static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
          std::vector<uint32_t> DecVec;
          for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
            ConstantInt *Const = getConstInt(MD2, j);
            assert(Const && "MDOperand should be ConstantInt");
            DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue()));
          }
          buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec);
        }
      }
      ++i;
    }
  }

  // Generate a SPIR-V type for the function.
  auto MRI = MIRBuilder.getMRI();
  Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
  MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
  if (F.isDeclaration())
    GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
  SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
  SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
      FTy, RetTy, ArgTypeVRegs, MIRBuilder);

  // Build the OpFunction instruction with the appropriate function control.
  uint32_t FuncControl = getFunctionControl(F);

  MIRBuilder.buildInstr(SPIRV::OpFunction)
      .addDef(FuncVReg)
      .addUse(GR->getSPIRVTypeID(RetTy))
      .addImm(FuncControl)
      .addUse(GR->getSPIRVTypeID(FuncTy));

  // Add OpFunctionParameters.
  int i = 0;
  for (const auto &Arg : F.args()) {
    assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
    MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
    MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
        .addDef(VRegs[i][0])
        .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
    if (F.isDeclaration())
      GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
    i++;
  }
  // Name the function.
  if (F.hasName())
    buildOpName(FuncVReg, F.getName(), MIRBuilder);

  // Handle entry points and function linkage.
  if (isEntryPoint(F)) {
    const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
    auto executionModel = getExecutionModel(STI, F);
    auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
                   .addImm(static_cast<uint32_t>(executionModel))
                   .addUse(FuncVReg);
    addStringImm(F.getName(), MIB);
  } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
             F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
    auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import
                                   : SPIRV::LinkageType::Export;
    buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
                    {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
  }

  return true;
}

bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
                                  CallLoweringInfo &Info) const {
  // Currently call returns should have single vregs.
  // TODO: handle the case of multiple registers.
  if (Info.OrigRet.Regs.size() > 1)
    return false;
  MachineFunction &MF = MIRBuilder.getMF();
  GR->setCurrentFunc(MF);
  FunctionType *FTy = nullptr;
  const Function *CF = nullptr;

  // Emit a regular OpFunctionCall. If it's an externally declared function,
  // be sure to emit its type and function declaration here. It will be hoisted
  // globally later.
  if (Info.Callee.isGlobal()) {
    CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
    // TODO: support constexpr casts and indirect calls.
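    // If the callee global is not a plain Function (e.g. a constexpr cast of
    // one), CF stays null and the call is rejected below.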
    if (CF == nullptr)
      return false;
    FTy = getOriginalFunctionType(*CF);
  }

  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
  Register ResVReg =
      Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
  std::string FuncName = Info.Callee.getGlobal()->getName().str();
  std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
  const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
  // TODO: check that it's OCL builtin, then apply OpenCL_std.
  if (!DemangledName.empty() && CF && CF->isDeclaration() &&
      ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
    const Type *OrigRetTy = Info.OrigRet.Ty;
    if (FTy)
      OrigRetTy = FTy->getReturnType();
    SmallVector<Register, 8> ArgVRegs;
    for (auto Arg : Info.OrigArgs) {
      assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
      ArgVRegs.push_back(Arg.Regs[0]);
      SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
      if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
        GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
    }
    if (auto Res = SPIRV::lowerBuiltin(
            DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
            ResVReg, OrigRetTy, ArgVRegs, GR))
      return *Res;
  }
  if (CF && CF->isDeclaration() &&
      !GR->find(CF, &MIRBuilder.getMF()).isValid()) {
    // Emit the type info and forward function declaration to the first MBB
    // to ensure VReg definition dependencies are valid across all MBBs.
    MachineIRBuilder FirstBlockBuilder;
    FirstBlockBuilder.setMF(MF);
    FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));

    SmallVector<ArrayRef<Register>, 8> VRegArgs;
    SmallVector<SmallVector<Register, 1>, 8> ToInsert;
    for (const Argument &Arg : CF->args()) {
      if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
        continue; // Don't handle zero sized types.
      Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
      MRI->setRegClass(Reg, &SPIRV::IDRegClass);
      ToInsert.push_back({Reg});
      VRegArgs.push_back(ToInsert.back());
    }
    // TODO: Reuse FunctionLoweringInfo
    FunctionLoweringInfo FuncInfo;
    lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
  }

  // Make sure there's a valid return reg, even for functions returning void.
  if (!ResVReg.isValid())
    ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
  SPIRVType *RetType =
      GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);

  // Emit the OpFunctionCall and its args.
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
                 .addDef(ResVReg)
                 .addUse(GR->getSPIRVTypeID(RetType))
                 .add(Info.Callee);

  for (const auto &Arg : Info.OrigArgs) {
    // Currently call args should have single vregs.
    if (Arg.Regs.size() > 1)
      return false;
    MIB.addUse(Arg.Regs[0]);
  }
  return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(),
                              *ST->getRegBankInfo());
}