1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the lowering of LLVM calls to machine code calls for 10 // GlobalISel. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "SPIRVCallLowering.h" 15 #include "MCTargetDesc/SPIRVBaseInfo.h" 16 #include "SPIRV.h" 17 #include "SPIRVBuiltins.h" 18 #include "SPIRVGlobalRegistry.h" 19 #include "SPIRVISelLowering.h" 20 #include "SPIRVRegisterInfo.h" 21 #include "SPIRVSubtarget.h" 22 #include "SPIRVUtils.h" 23 #include "llvm/CodeGen/FunctionLoweringInfo.h" 24 #include "llvm/Support/ModRef.h" 25 26 using namespace llvm; 27 28 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI, 29 SPIRVGlobalRegistry *GR) 30 : CallLowering(&TLI), GR(GR) {} 31 32 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, 33 const Value *Val, ArrayRef<Register> VRegs, 34 FunctionLoweringInfo &FLI, 35 Register SwiftErrorVReg) const { 36 // Currently all return types should use a single register. 37 // TODO: handle the case of multiple registers. 38 if (VRegs.size() > 1) 39 return false; 40 if (Val) { 41 const auto &STI = MIRBuilder.getMF().getSubtarget(); 42 return MIRBuilder.buildInstr(SPIRV::OpReturnValue) 43 .addUse(VRegs[0]) 44 .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(), 45 *STI.getRegBankInfo()); 46 } 47 MIRBuilder.buildInstr(SPIRV::OpReturn); 48 return true; 49 } 50 51 // Based on the LLVM function attributes, get a SPIR-V FunctionControl. 52 static uint32_t getFunctionControl(const Function &F) { 53 MemoryEffects MemEffects = F.getMemoryEffects(); 54 55 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None); 56 57 if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) 58 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline); 59 else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) 60 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline); 61 62 if (MemEffects.doesNotAccessMemory()) 63 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure); 64 else if (MemEffects.onlyReadsMemory()) 65 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const); 66 67 return FuncControl; 68 } 69 70 static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) { 71 if (MD->getNumOperands() > NumOp) { 72 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp)); 73 if (CMeta) 74 return dyn_cast<ConstantInt>(CMeta->getValue()); 75 } 76 return nullptr; 77 } 78 79 // This code restores function args/retvalue types for composite cases 80 // because the final types should still be aggregate whereas they're i32 81 // during the translation to cope with aggregate flattening etc. 82 static FunctionType *getOriginalFunctionType(const Function &F) { 83 auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs"); 84 if (NamedMD == nullptr) 85 return F.getFunctionType(); 86 87 Type *RetTy = F.getFunctionType()->getReturnType(); 88 SmallVector<Type *, 4> ArgTypes; 89 for (auto &Arg : F.args()) 90 ArgTypes.push_back(Arg.getType()); 91 92 auto ThisFuncMDIt = 93 std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) { 94 return isa<MDString>(N->getOperand(0)) && 95 cast<MDString>(N->getOperand(0))->getString() == F.getName(); 96 }); 97 // TODO: probably one function can have numerous type mutations, 98 // so we should support this. 99 if (ThisFuncMDIt != NamedMD->op_end()) { 100 auto *ThisFuncMD = *ThisFuncMDIt; 101 MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1)); 102 assert(MD && "MDNode operand is expected"); 103 ConstantInt *Const = getConstInt(MD, 0); 104 if (Const) { 105 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1)); 106 assert(CMeta && "ConstantAsMetadata operand is expected"); 107 assert(Const->getSExtValue() >= -1); 108 // Currently -1 indicates return value, greater values mean 109 // argument numbers. 110 if (Const->getSExtValue() == -1) 111 RetTy = CMeta->getType(); 112 else 113 ArgTypes[Const->getSExtValue()] = CMeta->getType(); 114 } 115 } 116 117 return FunctionType::get(RetTy, ArgTypes, F.isVarArg()); 118 } 119 120 static MDString *getKernelArgAttribute(const Function &KernelFunction, 121 unsigned ArgIdx, 122 const StringRef AttributeName) { 123 assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL && 124 "Kernel attributes are attached/belong only to kernel functions"); 125 126 // Lookup the argument attribute in metadata attached to the kernel function. 127 MDNode *Node = KernelFunction.getMetadata(AttributeName); 128 if (Node && ArgIdx < Node->getNumOperands()) 129 return cast<MDString>(Node->getOperand(ArgIdx)); 130 131 // Sometimes metadata containing kernel attributes is not attached to the 132 // function, but can be found in the named module-level metadata instead. 133 // For example: 134 // !opencl.kernels = !{!0} 135 // !0 = !{void ()* @someKernelFunction, !1, ...} 136 // !1 = !{!"kernel_arg_addr_space", ...} 137 // In this case the actual index of searched argument attribute is ArgIdx + 1, 138 // since the first metadata node operand is occupied by attribute name 139 // ("kernel_arg_addr_space" in the example above). 140 unsigned MDArgIdx = ArgIdx + 1; 141 NamedMDNode *OpenCLKernelsMD = 142 KernelFunction.getParent()->getNamedMetadata("opencl.kernels"); 143 if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0) 144 return nullptr; 145 146 // KernelToMDNodeList contains kernel function declarations followed by 147 // corresponding MDNodes for each attribute. Search only MDNodes "belonging" 148 // to the currently lowered kernel function. 149 MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0); 150 bool FoundLoweredKernelFunction = false; 151 for (const MDOperand &Operand : KernelToMDNodeList->operands()) { 152 ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand); 153 if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() == 154 KernelFunction.getName()) { 155 FoundLoweredKernelFunction = true; 156 continue; 157 } 158 if (MaybeValue && FoundLoweredKernelFunction) 159 return nullptr; 160 161 MDNode *MaybeNode = dyn_cast<MDNode>(Operand); 162 if (FoundLoweredKernelFunction && MaybeNode && 163 cast<MDString>(MaybeNode->getOperand(0))->getString() == 164 AttributeName && 165 MDArgIdx < MaybeNode->getNumOperands()) 166 return cast<MDString>(MaybeNode->getOperand(MDArgIdx)); 167 } 168 return nullptr; 169 } 170 171 static SPIRV::AccessQualifier::AccessQualifier 172 getArgAccessQual(const Function &F, unsigned ArgIdx) { 173 if (F.getCallingConv() != CallingConv::SPIR_KERNEL) 174 return SPIRV::AccessQualifier::ReadWrite; 175 176 MDString *ArgAttribute = 177 getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual"); 178 if (!ArgAttribute) 179 return SPIRV::AccessQualifier::ReadWrite; 180 181 if (ArgAttribute->getString().compare("read_only") == 0) 182 return SPIRV::AccessQualifier::ReadOnly; 183 if (ArgAttribute->getString().compare("write_only") == 0) 184 return SPIRV::AccessQualifier::WriteOnly; 185 return SPIRV::AccessQualifier::ReadWrite; 186 } 187 188 static std::vector<SPIRV::Decoration::Decoration> 189 getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) { 190 MDString *ArgAttribute = 191 getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual"); 192 if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0) 193 return {SPIRV::Decoration::Volatile}; 194 return {}; 195 } 196 197 static Type *getArgType(const Function &F, unsigned ArgIdx) { 198 Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx); 199 if (F.getCallingConv() != CallingConv::SPIR_KERNEL || 200 isSpecialOpaqueType(OriginalArgType)) 201 return OriginalArgType; 202 203 MDString *MDKernelArgType = 204 getKernelArgAttribute(F, ArgIdx, "kernel_arg_type"); 205 if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t")) 206 return OriginalArgType; 207 208 std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str(); 209 Type *ExistingOpaqueType = 210 StructType::getTypeByName(F.getContext(), KernelArgTypeStr); 211 return ExistingOpaqueType 212 ? ExistingOpaqueType 213 : StructType::create(F.getContext(), KernelArgTypeStr); 214 } 215 216 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, 217 const Function &F, 218 ArrayRef<ArrayRef<Register>> VRegs, 219 FunctionLoweringInfo &FLI) const { 220 assert(GR && "Must initialize the SPIRV type registry before lowering args."); 221 GR->setCurrentFunc(MIRBuilder.getMF()); 222 223 // Assign types and names to all args, and store their types for later. 224 FunctionType *FTy = getOriginalFunctionType(F); 225 SmallVector<SPIRVType *, 4> ArgTypeVRegs; 226 if (VRegs.size() > 0) { 227 unsigned i = 0; 228 for (const auto &Arg : F.args()) { 229 // Currently formal args should use single registers. 230 // TODO: handle the case of multiple registers. 231 if (VRegs[i].size() > 1) 232 return false; 233 SPIRV::AccessQualifier::AccessQualifier ArgAccessQual = 234 getArgAccessQual(F, i); 235 auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0], 236 MIRBuilder, ArgAccessQual); 237 ArgTypeVRegs.push_back(SpirvTy); 238 239 if (Arg.hasName()) 240 buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder); 241 if (Arg.getType()->isPointerTy()) { 242 auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes()); 243 if (DerefBytes != 0) 244 buildOpDecorate(VRegs[i][0], MIRBuilder, 245 SPIRV::Decoration::MaxByteOffset, {DerefBytes}); 246 } 247 if (Arg.hasAttribute(Attribute::Alignment)) { 248 auto Alignment = static_cast<unsigned>( 249 Arg.getAttribute(Attribute::Alignment).getValueAsInt()); 250 buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment, 251 {Alignment}); 252 } 253 if (Arg.hasAttribute(Attribute::ReadOnly)) { 254 auto Attr = 255 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite); 256 buildOpDecorate(VRegs[i][0], MIRBuilder, 257 SPIRV::Decoration::FuncParamAttr, {Attr}); 258 } 259 if (Arg.hasAttribute(Attribute::ZExt)) { 260 auto Attr = 261 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext); 262 buildOpDecorate(VRegs[i][0], MIRBuilder, 263 SPIRV::Decoration::FuncParamAttr, {Attr}); 264 } 265 if (Arg.hasAttribute(Attribute::NoAlias)) { 266 auto Attr = 267 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias); 268 buildOpDecorate(VRegs[i][0], MIRBuilder, 269 SPIRV::Decoration::FuncParamAttr, {Attr}); 270 } 271 272 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { 273 std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs = 274 getKernelArgTypeQual(F, i); 275 for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs) 276 buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {}); 277 } 278 279 MDNode *Node = F.getMetadata("spirv.ParameterDecorations"); 280 if (Node && i < Node->getNumOperands() && 281 isa<MDNode>(Node->getOperand(i))) { 282 MDNode *MD = cast<MDNode>(Node->getOperand(i)); 283 for (const MDOperand &MDOp : MD->operands()) { 284 MDNode *MD2 = dyn_cast<MDNode>(MDOp); 285 assert(MD2 && "Metadata operand is expected"); 286 ConstantInt *Const = getConstInt(MD2, 0); 287 assert(Const && "MDOperand should be ConstantInt"); 288 auto Dec = 289 static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue()); 290 std::vector<uint32_t> DecVec; 291 for (unsigned j = 1; j < MD2->getNumOperands(); j++) { 292 ConstantInt *Const = getConstInt(MD2, j); 293 assert(Const && "MDOperand should be ConstantInt"); 294 DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue())); 295 } 296 buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec); 297 } 298 } 299 ++i; 300 } 301 } 302 303 // Generate a SPIR-V type for the function. 304 auto MRI = MIRBuilder.getMRI(); 305 Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); 306 MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); 307 if (F.isDeclaration()) 308 GR->add(&F, &MIRBuilder.getMF(), FuncVReg); 309 SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder); 310 SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs( 311 FTy, RetTy, ArgTypeVRegs, MIRBuilder); 312 313 // Build the OpTypeFunction declaring it. 314 uint32_t FuncControl = getFunctionControl(F); 315 316 MIRBuilder.buildInstr(SPIRV::OpFunction) 317 .addDef(FuncVReg) 318 .addUse(GR->getSPIRVTypeID(RetTy)) 319 .addImm(FuncControl) 320 .addUse(GR->getSPIRVTypeID(FuncTy)); 321 322 // Add OpFunctionParameters. 323 int i = 0; 324 for (const auto &Arg : F.args()) { 325 assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs"); 326 MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass); 327 MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) 328 .addDef(VRegs[i][0]) 329 .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); 330 if (F.isDeclaration()) 331 GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]); 332 i++; 333 } 334 // Name the function. 335 if (F.hasName()) 336 buildOpName(FuncVReg, F.getName(), MIRBuilder); 337 338 // Handle entry points and function linkage. 339 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { 340 auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint) 341 .addImm(static_cast<uint32_t>(SPIRV::ExecutionModel::Kernel)) 342 .addUse(FuncVReg); 343 addStringImm(F.getName(), MIB); 344 } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage || 345 F.getLinkage() == GlobalValue::LinkOnceODRLinkage) { 346 auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import 347 : SPIRV::LinkageType::Export; 348 buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 349 {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier()); 350 } 351 352 return true; 353 } 354 355 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, 356 CallLoweringInfo &Info) const { 357 // Currently call returns should have single vregs. 358 // TODO: handle the case of multiple registers. 359 if (Info.OrigRet.Regs.size() > 1) 360 return false; 361 MachineFunction &MF = MIRBuilder.getMF(); 362 GR->setCurrentFunc(MF); 363 FunctionType *FTy = nullptr; 364 const Function *CF = nullptr; 365 366 // Emit a regular OpFunctionCall. If it's an externally declared function, 367 // be sure to emit its type and function declaration here. It will be hoisted 368 // globally later. 369 if (Info.Callee.isGlobal()) { 370 CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal()); 371 // TODO: support constexpr casts and indirect calls. 372 if (CF == nullptr) 373 return false; 374 FTy = getOriginalFunctionType(*CF); 375 } 376 377 MachineRegisterInfo *MRI = MIRBuilder.getMRI(); 378 Register ResVReg = 379 Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; 380 std::string FuncName = Info.Callee.getGlobal()->getName().str(); 381 std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName); 382 const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget()); 383 // TODO: check that it's OCL builtin, then apply OpenCL_std. 384 if (!DemangledName.empty() && CF && CF->isDeclaration() && 385 ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) { 386 const Type *OrigRetTy = Info.OrigRet.Ty; 387 if (FTy) 388 OrigRetTy = FTy->getReturnType(); 389 SmallVector<Register, 8> ArgVRegs; 390 for (auto Arg : Info.OrigArgs) { 391 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs"); 392 ArgVRegs.push_back(Arg.Regs[0]); 393 SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder); 394 GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF()); 395 } 396 if (auto Res = SPIRV::lowerBuiltin( 397 DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder, 398 ResVReg, OrigRetTy, ArgVRegs, GR)) 399 return *Res; 400 } 401 if (CF && CF->isDeclaration() && 402 !GR->find(CF, &MIRBuilder.getMF()).isValid()) { 403 // Emit the type info and forward function declaration to the first MBB 404 // to ensure VReg definition dependencies are valid across all MBBs. 405 MachineIRBuilder FirstBlockBuilder; 406 FirstBlockBuilder.setMF(MF); 407 FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0)); 408 409 SmallVector<ArrayRef<Register>, 8> VRegArgs; 410 SmallVector<SmallVector<Register, 1>, 8> ToInsert; 411 for (const Argument &Arg : CF->args()) { 412 if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) 413 continue; // Don't handle zero sized types. 414 Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32)); 415 MRI->setRegClass(Reg, &SPIRV::IDRegClass); 416 ToInsert.push_back({Reg}); 417 VRegArgs.push_back(ToInsert.back()); 418 } 419 // TODO: Reuse FunctionLoweringInfo 420 FunctionLoweringInfo FuncInfo; 421 lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo); 422 } 423 424 // Make sure there's a valid return reg, even for functions returning void. 425 if (!ResVReg.isValid()) 426 ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); 427 SPIRVType *RetType = 428 GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder); 429 430 // Emit the OpFunctionCall and its args. 431 auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall) 432 .addDef(ResVReg) 433 .addUse(GR->getSPIRVTypeID(RetType)) 434 .add(Info.Callee); 435 436 for (const auto &Arg : Info.OrigArgs) { 437 // Currently call args should have single vregs. 438 if (Arg.Regs.size() > 1) 439 return false; 440 MIB.addUse(Arg.Regs[0]); 441 } 442 return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(), 443 *ST->getRegBankInfo()); 444 } 445