1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering of LLVM calls to machine code calls for
10 // GlobalISel.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "SPIRVCallLowering.h"
15 #include "MCTargetDesc/SPIRVBaseInfo.h"
16 #include "SPIRV.h"
17 #include "SPIRVBuiltins.h"
18 #include "SPIRVGlobalRegistry.h"
19 #include "SPIRVISelLowering.h"
20 #include "SPIRVMetadata.h"
21 #include "SPIRVRegisterInfo.h"
22 #include "SPIRVSubtarget.h"
23 #include "SPIRVUtils.h"
24 #include "llvm/CodeGen/FunctionLoweringInfo.h"
25 #include "llvm/IR/IntrinsicInst.h"
26 #include "llvm/IR/IntrinsicsSPIRV.h"
27 #include "llvm/Support/ModRef.h"
28
29 using namespace llvm;
30
SPIRVCallLowering(const SPIRVTargetLowering & TLI,SPIRVGlobalRegistry * GR)31 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
32 SPIRVGlobalRegistry *GR)
33 : CallLowering(&TLI), GR(GR) {}
34
lowerReturn(MachineIRBuilder & MIRBuilder,const Value * Val,ArrayRef<Register> VRegs,FunctionLoweringInfo & FLI,Register SwiftErrorVReg) const35 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
36 const Value *Val, ArrayRef<Register> VRegs,
37 FunctionLoweringInfo &FLI,
38 Register SwiftErrorVReg) const {
39 // Maybe run postponed production of types for function pointers
40 if (IndirectCalls.size() > 0) {
41 produceIndirectPtrTypes(MIRBuilder);
42 IndirectCalls.clear();
43 }
44
45 // Currently all return types should use a single register.
46 // TODO: handle the case of multiple registers.
47 if (VRegs.size() > 1)
48 return false;
49 if (Val) {
50 const auto &STI = MIRBuilder.getMF().getSubtarget();
51 return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
52 .addUse(VRegs[0])
53 .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
54 *STI.getRegBankInfo());
55 }
56 MIRBuilder.buildInstr(SPIRV::OpReturn);
57 return true;
58 }
59
60 // Based on the LLVM function attributes, get a SPIR-V FunctionControl.
getFunctionControl(const Function & F)61 static uint32_t getFunctionControl(const Function &F) {
62 MemoryEffects MemEffects = F.getMemoryEffects();
63
64 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
65
66 if (F.hasFnAttribute(Attribute::AttrKind::NoInline))
67 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
68 else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline))
69 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
70
71 if (MemEffects.doesNotAccessMemory())
72 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
73 else if (MemEffects.onlyReadsMemory())
74 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
75
76 return FuncControl;
77 }
78
getConstInt(MDNode * MD,unsigned NumOp)79 static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
80 if (MD->getNumOperands() > NumOp) {
81 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp));
82 if (CMeta)
83 return dyn_cast<ConstantInt>(CMeta->getValue());
84 }
85 return nullptr;
86 }
87
88 // If the function has pointer arguments, we are forced to re-create this
89 // function type from the very beginning, changing PointerType by
90 // TypedPointerType for each pointer argument. Otherwise, the same `Type*`
91 // potentially corresponds to different SPIR-V function type, effectively
92 // invalidating logic behind global registry and duplicates tracker.
93 static FunctionType *
fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry * GR,const Function & F,FunctionType * FTy,const SPIRVType * SRetTy,const SmallVector<SPIRVType *,4> & SArgTys)94 fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
95 FunctionType *FTy, const SPIRVType *SRetTy,
96 const SmallVector<SPIRVType *, 4> &SArgTys) {
97 if (F.getParent()->getNamedMetadata("spv.cloned_funcs"))
98 return FTy;
99
100 bool hasArgPtrs = false;
101 for (auto &Arg : F.args()) {
102 // check if it's an instance of a non-typed PointerType
103 if (Arg.getType()->isPointerTy()) {
104 hasArgPtrs = true;
105 break;
106 }
107 }
108 if (!hasArgPtrs) {
109 Type *RetTy = FTy->getReturnType();
110 // check if it's an instance of a non-typed PointerType
111 if (!RetTy->isPointerTy())
112 return FTy;
113 }
114
115 // re-create function type, using TypedPointerType instead of PointerType to
116 // properly trace argument types
117 const Type *RetTy = GR->getTypeForSPIRVType(SRetTy);
118 SmallVector<Type *, 4> ArgTys;
119 for (auto SArgTy : SArgTys)
120 ArgTys.push_back(const_cast<Type *>(GR->getTypeForSPIRVType(SArgTy)));
121 return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
122 }
123
124 // This code restores function args/retvalue types for composite cases
125 // because the final types should still be aggregate whereas they're i32
126 // during the translation to cope with aggregate flattening etc.
getOriginalFunctionType(const Function & F)127 static FunctionType *getOriginalFunctionType(const Function &F) {
128 auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
129 if (NamedMD == nullptr)
130 return F.getFunctionType();
131
132 Type *RetTy = F.getFunctionType()->getReturnType();
133 SmallVector<Type *, 4> ArgTypes;
134 for (auto &Arg : F.args())
135 ArgTypes.push_back(Arg.getType());
136
137 auto ThisFuncMDIt =
138 std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
139 return isa<MDString>(N->getOperand(0)) &&
140 cast<MDString>(N->getOperand(0))->getString() == F.getName();
141 });
142 // TODO: probably one function can have numerous type mutations,
143 // so we should support this.
144 if (ThisFuncMDIt != NamedMD->op_end()) {
145 auto *ThisFuncMD = *ThisFuncMDIt;
146 MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
147 assert(MD && "MDNode operand is expected");
148 ConstantInt *Const = getConstInt(MD, 0);
149 if (Const) {
150 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
151 assert(CMeta && "ConstantAsMetadata operand is expected");
152 assert(Const->getSExtValue() >= -1);
153 // Currently -1 indicates return value, greater values mean
154 // argument numbers.
155 if (Const->getSExtValue() == -1)
156 RetTy = CMeta->getType();
157 else
158 ArgTypes[Const->getSExtValue()] = CMeta->getType();
159 }
160 }
161
162 return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
163 }
164
165 static SPIRV::AccessQualifier::AccessQualifier
getArgAccessQual(const Function & F,unsigned ArgIdx)166 getArgAccessQual(const Function &F, unsigned ArgIdx) {
167 if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
168 return SPIRV::AccessQualifier::ReadWrite;
169
170 MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
171 if (!ArgAttribute)
172 return SPIRV::AccessQualifier::ReadWrite;
173
174 if (ArgAttribute->getString() == "read_only")
175 return SPIRV::AccessQualifier::ReadOnly;
176 if (ArgAttribute->getString() == "write_only")
177 return SPIRV::AccessQualifier::WriteOnly;
178 return SPIRV::AccessQualifier::ReadWrite;
179 }
180
181 static std::vector<SPIRV::Decoration::Decoration>
getKernelArgTypeQual(const Function & F,unsigned ArgIdx)182 getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
183 MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
184 if (ArgAttribute && ArgAttribute->getString() == "volatile")
185 return {SPIRV::Decoration::Volatile};
186 return {};
187 }
188
getArgSPIRVType(const Function & F,unsigned ArgIdx,SPIRVGlobalRegistry * GR,MachineIRBuilder & MIRBuilder,const SPIRVSubtarget & ST)189 static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
190 SPIRVGlobalRegistry *GR,
191 MachineIRBuilder &MIRBuilder,
192 const SPIRVSubtarget &ST) {
193 // Read argument's access qualifier from metadata or default.
194 SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
195 getArgAccessQual(F, ArgIdx);
196
197 Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
198
199 // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
200 // be legally reassigned later).
201 if (!isPointerTy(OriginalArgType))
202 return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
203
204 Argument *Arg = F.getArg(ArgIdx);
205 Type *ArgType = Arg->getType();
206 if (isTypedPointerTy(ArgType)) {
207 SPIRVType *ElementType = GR->getOrCreateSPIRVType(
208 cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder);
209 return GR->getOrCreateSPIRVPointerType(
210 ElementType, MIRBuilder,
211 addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
212 }
213
214 // In case OriginalArgType is of untyped pointer type, there are three
215 // possibilities:
216 // 1) This is a pointer of an LLVM IR element type, passed byval/byref.
217 // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
218 // intrinsic assigning a TargetExtType.
219 // 3) This is a pointer, try to retrieve pointer element type from a
220 // spv_assign_ptr_type intrinsic or otherwise use default pointer element
221 // type.
222 if (hasPointeeTypeAttr(Arg)) {
223 SPIRVType *ElementType =
224 GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder);
225 return GR->getOrCreateSPIRVPointerType(
226 ElementType, MIRBuilder,
227 addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
228 }
229
230 for (auto User : Arg->users()) {
231 auto *II = dyn_cast<IntrinsicInst>(User);
232 // Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
233 if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
234 MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
235 Type *BuiltinType =
236 cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
237 assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
238 return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual);
239 }
240
241 // Check if this is spv_assign_ptr_type assigning pointer element type.
242 if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type)
243 continue;
244
245 MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
246 Type *ElementTy =
247 toTypedPointer(cast<ConstantAsMetadata>(VMD->getMetadata())->getType());
248 SPIRVType *ElementType = GR->getOrCreateSPIRVType(ElementTy, MIRBuilder);
249 return GR->getOrCreateSPIRVPointerType(
250 ElementType, MIRBuilder,
251 addressSpaceToStorageClass(
252 cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
253 }
254
255 // Replace PointerType with TypedPointerType to be able to map SPIR-V types to
256 // LLVM types in a consistent manner
257 return GR->getOrCreateSPIRVType(toTypedPointer(OriginalArgType), MIRBuilder,
258 ArgAccessQual);
259 }
260
261 static SPIRV::ExecutionModel::ExecutionModel
getExecutionModel(const SPIRVSubtarget & STI,const Function & F)262 getExecutionModel(const SPIRVSubtarget &STI, const Function &F) {
263 if (STI.isOpenCLEnv())
264 return SPIRV::ExecutionModel::Kernel;
265
266 auto attribute = F.getFnAttribute("hlsl.shader");
267 if (!attribute.isValid()) {
268 report_fatal_error(
269 "This entry point lacks mandatory hlsl.shader attribute.");
270 }
271
272 const auto value = attribute.getValueAsString();
273 if (value == "compute")
274 return SPIRV::ExecutionModel::GLCompute;
275
276 report_fatal_error("This HLSL entry point is not supported by this backend.");
277 }
278
lowerFormalArguments(MachineIRBuilder & MIRBuilder,const Function & F,ArrayRef<ArrayRef<Register>> VRegs,FunctionLoweringInfo & FLI) const279 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
280 const Function &F,
281 ArrayRef<ArrayRef<Register>> VRegs,
282 FunctionLoweringInfo &FLI) const {
283 assert(GR && "Must initialize the SPIRV type registry before lowering args.");
284 GR->setCurrentFunc(MIRBuilder.getMF());
285
286 // Get access to information about available extensions
287 const SPIRVSubtarget *ST =
288 static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
289
290 // Assign types and names to all args, and store their types for later.
291 SmallVector<SPIRVType *, 4> ArgTypeVRegs;
292 if (VRegs.size() > 0) {
293 unsigned i = 0;
294 for (const auto &Arg : F.args()) {
295 // Currently formal args should use single registers.
296 // TODO: handle the case of multiple registers.
297 if (VRegs[i].size() > 1)
298 return false;
299 auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
300 GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
301 ArgTypeVRegs.push_back(SpirvTy);
302
303 if (Arg.hasName())
304 buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
305 if (isPointerTy(Arg.getType())) {
306 auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
307 if (DerefBytes != 0)
308 buildOpDecorate(VRegs[i][0], MIRBuilder,
309 SPIRV::Decoration::MaxByteOffset, {DerefBytes});
310 }
311 if (Arg.hasAttribute(Attribute::Alignment)) {
312 auto Alignment = static_cast<unsigned>(
313 Arg.getAttribute(Attribute::Alignment).getValueAsInt());
314 buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
315 {Alignment});
316 }
317 if (Arg.hasAttribute(Attribute::ReadOnly)) {
318 auto Attr =
319 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
320 buildOpDecorate(VRegs[i][0], MIRBuilder,
321 SPIRV::Decoration::FuncParamAttr, {Attr});
322 }
323 if (Arg.hasAttribute(Attribute::ZExt)) {
324 auto Attr =
325 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
326 buildOpDecorate(VRegs[i][0], MIRBuilder,
327 SPIRV::Decoration::FuncParamAttr, {Attr});
328 }
329 if (Arg.hasAttribute(Attribute::NoAlias)) {
330 auto Attr =
331 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
332 buildOpDecorate(VRegs[i][0], MIRBuilder,
333 SPIRV::Decoration::FuncParamAttr, {Attr});
334 }
335 if (Arg.hasAttribute(Attribute::ByVal)) {
336 auto Attr =
337 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal);
338 buildOpDecorate(VRegs[i][0], MIRBuilder,
339 SPIRV::Decoration::FuncParamAttr, {Attr});
340 }
341
342 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
343 std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
344 getKernelArgTypeQual(F, i);
345 for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
346 buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
347 }
348
349 MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
350 if (Node && i < Node->getNumOperands() &&
351 isa<MDNode>(Node->getOperand(i))) {
352 MDNode *MD = cast<MDNode>(Node->getOperand(i));
353 for (const MDOperand &MDOp : MD->operands()) {
354 MDNode *MD2 = dyn_cast<MDNode>(MDOp);
355 assert(MD2 && "Metadata operand is expected");
356 ConstantInt *Const = getConstInt(MD2, 0);
357 assert(Const && "MDOperand should be ConstantInt");
358 auto Dec =
359 static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
360 std::vector<uint32_t> DecVec;
361 for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
362 ConstantInt *Const = getConstInt(MD2, j);
363 assert(Const && "MDOperand should be ConstantInt");
364 DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue()));
365 }
366 buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec);
367 }
368 }
369 ++i;
370 }
371 }
372
373 auto MRI = MIRBuilder.getMRI();
374 Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
375 MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
376 if (F.isDeclaration())
377 GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
378 FunctionType *FTy = getOriginalFunctionType(F);
379 Type *FRetTy = FTy->getReturnType();
380 if (isUntypedPointerTy(FRetTy)) {
381 if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
382 TypedPointerType *DerivedTy = TypedPointerType::get(
383 toTypedPointer(FRetElemTy), getPointerAddressSpace(FRetTy));
384 GR->addReturnType(&F, DerivedTy);
385 FRetTy = DerivedTy;
386 }
387 }
388 SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
389 FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
390 SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
391 FTy, RetTy, ArgTypeVRegs, MIRBuilder);
392 uint32_t FuncControl = getFunctionControl(F);
393
394 // Add OpFunction instruction
395 MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
396 .addDef(FuncVReg)
397 .addUse(GR->getSPIRVTypeID(RetTy))
398 .addImm(FuncControl)
399 .addUse(GR->getSPIRVTypeID(FuncTy));
400 GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0));
401
402 // Add OpFunctionParameter instructions
403 int i = 0;
404 for (const auto &Arg : F.args()) {
405 assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
406 MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
407 MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
408 .addDef(VRegs[i][0])
409 .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
410 if (F.isDeclaration())
411 GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
412 i++;
413 }
414 // Name the function.
415 if (F.hasName())
416 buildOpName(FuncVReg, F.getName(), MIRBuilder);
417
418 // Handle entry points and function linkage.
419 if (isEntryPoint(F)) {
420 const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
421 auto executionModel = getExecutionModel(STI, F);
422 auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
423 .addImm(static_cast<uint32_t>(executionModel))
424 .addUse(FuncVReg);
425 addStringImm(F.getName(), MIB);
426 } else if (F.getLinkage() != GlobalValue::InternalLinkage &&
427 F.getLinkage() != GlobalValue::PrivateLinkage) {
428 SPIRV::LinkageType::LinkageType LnkTy =
429 F.isDeclaration()
430 ? SPIRV::LinkageType::Import
431 : (F.getLinkage() == GlobalValue::LinkOnceODRLinkage &&
432 ST->canUseExtension(
433 SPIRV::Extension::SPV_KHR_linkonce_odr)
434 ? SPIRV::LinkageType::LinkOnceODR
435 : SPIRV::LinkageType::Export);
436 buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
437 {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
438 }
439
440 // Handle function pointers decoration
441 bool hasFunctionPointers =
442 ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
443 if (hasFunctionPointers) {
444 if (F.hasFnAttribute("referenced-indirectly")) {
445 assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
446 "Unexpected 'referenced-indirectly' attribute of the kernel "
447 "function");
448 buildOpDecorate(FuncVReg, MIRBuilder,
449 SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
450 }
451 }
452
453 return true;
454 }
455
456 // Used to postpone producing of indirect function pointer types after all
457 // indirect calls info is collected
458 // TODO:
459 // - add a topological sort of IndirectCalls to ensure the best types knowledge
460 // - we may need to fix function formal parameter types if they are opaque
461 // pointers used as function pointers in these indirect calls
produceIndirectPtrTypes(MachineIRBuilder & MIRBuilder) const462 void SPIRVCallLowering::produceIndirectPtrTypes(
463 MachineIRBuilder &MIRBuilder) const {
464 // Create indirect call data types if any
465 MachineFunction &MF = MIRBuilder.getMF();
466 for (auto const &IC : IndirectCalls) {
467 SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
468 SmallVector<SPIRVType *, 4> SpirvArgTypes;
469 for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
470 SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
471 SpirvArgTypes.push_back(SPIRVTy);
472 if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
473 GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
474 }
475 // SPIR-V function type:
476 FunctionType *FTy =
477 FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false);
478 SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
479 FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
480 // SPIR-V pointer to function type:
481 SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
482 SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
483 // Correct the Callee type
484 GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
485 }
486 }
487
lowerCall(MachineIRBuilder & MIRBuilder,CallLoweringInfo & Info) const488 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
489 CallLoweringInfo &Info) const {
490 // Currently call returns should have single vregs.
491 // TODO: handle the case of multiple registers.
492 if (Info.OrigRet.Regs.size() > 1)
493 return false;
494 MachineFunction &MF = MIRBuilder.getMF();
495 GR->setCurrentFunc(MF);
496 const Function *CF = nullptr;
497 std::string DemangledName;
498 const Type *OrigRetTy = Info.OrigRet.Ty;
499
500 // Emit a regular OpFunctionCall. If it's an externally declared function,
501 // be sure to emit its type and function declaration here. It will be hoisted
502 // globally later.
503 if (Info.Callee.isGlobal()) {
504 std::string FuncName = Info.Callee.getGlobal()->getName().str();
505 DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
506 CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
507 // TODO: support constexpr casts and indirect calls.
508 if (CF == nullptr)
509 return false;
510 if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
511 OrigRetTy = FTy->getReturnType();
512 if (isUntypedPointerTy(OrigRetTy)) {
513 if (auto *DerivedRetTy = GR->findReturnType(CF))
514 OrigRetTy = DerivedRetTy;
515 }
516 }
517 }
518
519 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
520 Register ResVReg =
521 Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
522 const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
523
524 bool isFunctionDecl = CF && CF->isDeclaration();
525 bool canUseOpenCL = ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std);
526 bool canUseGLSL = ST->canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450);
527 assert(canUseGLSL != canUseOpenCL &&
528 "Scenario where both sets are enabled is not supported.");
529
530 if (isFunctionDecl && !DemangledName.empty() &&
531 (canUseGLSL || canUseOpenCL)) {
532 SmallVector<Register, 8> ArgVRegs;
533 for (auto Arg : Info.OrigArgs) {
534 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
535 ArgVRegs.push_back(Arg.Regs[0]);
536 SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
537 if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
538 GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
539 }
540 auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
541 : SPIRV::InstructionSet::GLSL_std_450;
542 if (auto Res =
543 SPIRV::lowerBuiltin(DemangledName, instructionSet, MIRBuilder,
544 ResVReg, OrigRetTy, ArgVRegs, GR))
545 return *Res;
546 }
547
548 if (isFunctionDecl && !GR->find(CF, &MF).isValid()) {
549 // Emit the type info and forward function declaration to the first MBB
550 // to ensure VReg definition dependencies are valid across all MBBs.
551 MachineIRBuilder FirstBlockBuilder;
552 FirstBlockBuilder.setMF(MF);
553 FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));
554
555 SmallVector<ArrayRef<Register>, 8> VRegArgs;
556 SmallVector<SmallVector<Register, 1>, 8> ToInsert;
557 for (const Argument &Arg : CF->args()) {
558 if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
559 continue; // Don't handle zero sized types.
560 Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
561 MRI->setRegClass(Reg, &SPIRV::IDRegClass);
562 ToInsert.push_back({Reg});
563 VRegArgs.push_back(ToInsert.back());
564 }
565 // TODO: Reuse FunctionLoweringInfo
566 FunctionLoweringInfo FuncInfo;
567 lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
568 }
569
570 unsigned CallOp;
571 if (Info.CB->isIndirectCall()) {
572 if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
573 report_fatal_error("An indirect call is encountered but SPIR-V without "
574 "extensions does not support it",
575 false);
576 // Set instruction operation according to SPV_INTEL_function_pointers
577 CallOp = SPIRV::OpFunctionPointerCallINTEL;
578 // Collect information about the indirect call to support possible
579 // specification of opaque ptr types of parent function's parameters
580 Register CalleeReg = Info.Callee.getReg();
581 if (CalleeReg.isValid()) {
582 SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
583 IndirectCall.Callee = CalleeReg;
584 IndirectCall.RetTy = OrigRetTy;
585 for (const auto &Arg : Info.OrigArgs) {
586 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
587 IndirectCall.ArgTys.push_back(Arg.Ty);
588 IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
589 }
590 IndirectCalls.push_back(IndirectCall);
591 }
592 } else {
593 // Emit a regular OpFunctionCall
594 CallOp = SPIRV::OpFunctionCall;
595 }
596
597 // Make sure there's a valid return reg, even for functions returning void.
598 if (!ResVReg.isValid())
599 ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
600 SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
601
602 // Emit the call instruction and its args.
603 auto MIB = MIRBuilder.buildInstr(CallOp)
604 .addDef(ResVReg)
605 .addUse(GR->getSPIRVTypeID(RetType))
606 .add(Info.Callee);
607
608 for (const auto &Arg : Info.OrigArgs) {
609 // Currently call args should have single vregs.
610 if (Arg.Regs.size() > 1)
611 return false;
612 MIB.addUse(Arg.Regs[0]);
613 }
614 return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(),
615 *ST->getRegBankInfo());
616 }
617