xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering of LLVM calls to machine code calls for
10 // GlobalISel.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "SPIRVCallLowering.h"
15 #include "MCTargetDesc/SPIRVBaseInfo.h"
16 #include "SPIRV.h"
17 #include "SPIRVBuiltins.h"
18 #include "SPIRVGlobalRegistry.h"
19 #include "SPIRVISelLowering.h"
20 #include "SPIRVMetadata.h"
21 #include "SPIRVRegisterInfo.h"
22 #include "SPIRVSubtarget.h"
23 #include "SPIRVUtils.h"
24 #include "llvm/CodeGen/FunctionLoweringInfo.h"
25 #include "llvm/IR/IntrinsicInst.h"
26 #include "llvm/IR/IntrinsicsSPIRV.h"
27 #include "llvm/Support/ModRef.h"
28 
29 using namespace llvm;
30 
31 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
32                                      SPIRVGlobalRegistry *GR)
33     : CallLowering(&TLI), GR(GR) {}
34 
35 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
36                                     const Value *Val, ArrayRef<Register> VRegs,
37                                     FunctionLoweringInfo &FLI,
38                                     Register SwiftErrorVReg) const {
39   // Maybe run postponed production of types for function pointers
40   if (IndirectCalls.size() > 0) {
41     produceIndirectPtrTypes(MIRBuilder);
42     IndirectCalls.clear();
43   }
44 
45   // Currently all return types should use a single register.
46   // TODO: handle the case of multiple registers.
47   if (VRegs.size() > 1)
48     return false;
49   if (Val) {
50     const auto &STI = MIRBuilder.getMF().getSubtarget();
51     return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
52         .addUse(VRegs[0])
53         .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
54                           *STI.getRegBankInfo());
55   }
56   MIRBuilder.buildInstr(SPIRV::OpReturn);
57   return true;
58 }
59 
60 // Based on the LLVM function attributes, get a SPIR-V FunctionControl.
61 static uint32_t getFunctionControl(const Function &F) {
62   MemoryEffects MemEffects = F.getMemoryEffects();
63 
64   uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
65 
66   if (F.hasFnAttribute(Attribute::AttrKind::NoInline))
67     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
68   else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline))
69     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
70 
71   if (MemEffects.doesNotAccessMemory())
72     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
73   else if (MemEffects.onlyReadsMemory())
74     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
75 
76   return FuncControl;
77 }
78 
79 static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
80   if (MD->getNumOperands() > NumOp) {
81     auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp));
82     if (CMeta)
83       return dyn_cast<ConstantInt>(CMeta->getValue());
84   }
85   return nullptr;
86 }
87 
88 // If the function has pointer arguments, we are forced to re-create this
89 // function type from the very beginning, changing PointerType by
90 // TypedPointerType for each pointer argument. Otherwise, the same `Type*`
91 // potentially corresponds to different SPIR-V function type, effectively
92 // invalidating logic behind global registry and duplicates tracker.
93 static FunctionType *
94 fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
95                          FunctionType *FTy, const SPIRVType *SRetTy,
96                          const SmallVector<SPIRVType *, 4> &SArgTys) {
97   if (F.getParent()->getNamedMetadata("spv.cloned_funcs"))
98     return FTy;
99 
100   bool hasArgPtrs = false;
101   for (auto &Arg : F.args()) {
102     // check if it's an instance of a non-typed PointerType
103     if (Arg.getType()->isPointerTy()) {
104       hasArgPtrs = true;
105       break;
106     }
107   }
108   if (!hasArgPtrs) {
109     Type *RetTy = FTy->getReturnType();
110     // check if it's an instance of a non-typed PointerType
111     if (!RetTy->isPointerTy())
112       return FTy;
113   }
114 
115   // re-create function type, using TypedPointerType instead of PointerType to
116   // properly trace argument types
117   const Type *RetTy = GR->getTypeForSPIRVType(SRetTy);
118   SmallVector<Type *, 4> ArgTys;
119   for (auto SArgTy : SArgTys)
120     ArgTys.push_back(const_cast<Type *>(GR->getTypeForSPIRVType(SArgTy)));
121   return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
122 }
123 
124 // This code restores function args/retvalue types for composite cases
125 // because the final types should still be aggregate whereas they're i32
126 // during the translation to cope with aggregate flattening etc.
127 static FunctionType *getOriginalFunctionType(const Function &F) {
128   auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
129   if (NamedMD == nullptr)
130     return F.getFunctionType();
131 
132   Type *RetTy = F.getFunctionType()->getReturnType();
133   SmallVector<Type *, 4> ArgTypes;
134   for (auto &Arg : F.args())
135     ArgTypes.push_back(Arg.getType());
136 
137   auto ThisFuncMDIt =
138       std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
139         return isa<MDString>(N->getOperand(0)) &&
140                cast<MDString>(N->getOperand(0))->getString() == F.getName();
141       });
142   // TODO: probably one function can have numerous type mutations,
143   // so we should support this.
144   if (ThisFuncMDIt != NamedMD->op_end()) {
145     auto *ThisFuncMD = *ThisFuncMDIt;
146     MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
147     assert(MD && "MDNode operand is expected");
148     ConstantInt *Const = getConstInt(MD, 0);
149     if (Const) {
150       auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
151       assert(CMeta && "ConstantAsMetadata operand is expected");
152       assert(Const->getSExtValue() >= -1);
153       // Currently -1 indicates return value, greater values mean
154       // argument numbers.
155       if (Const->getSExtValue() == -1)
156         RetTy = CMeta->getType();
157       else
158         ArgTypes[Const->getSExtValue()] = CMeta->getType();
159     }
160   }
161 
162   return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
163 }
164 
165 static SPIRV::AccessQualifier::AccessQualifier
166 getArgAccessQual(const Function &F, unsigned ArgIdx) {
167   if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
168     return SPIRV::AccessQualifier::ReadWrite;
169 
170   MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
171   if (!ArgAttribute)
172     return SPIRV::AccessQualifier::ReadWrite;
173 
174   if (ArgAttribute->getString() == "read_only")
175     return SPIRV::AccessQualifier::ReadOnly;
176   if (ArgAttribute->getString() == "write_only")
177     return SPIRV::AccessQualifier::WriteOnly;
178   return SPIRV::AccessQualifier::ReadWrite;
179 }
180 
181 static std::vector<SPIRV::Decoration::Decoration>
182 getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
183   MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
184   if (ArgAttribute && ArgAttribute->getString() == "volatile")
185     return {SPIRV::Decoration::Volatile};
186   return {};
187 }
188 
189 static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
190                                   SPIRVGlobalRegistry *GR,
191                                   MachineIRBuilder &MIRBuilder,
192                                   const SPIRVSubtarget &ST) {
193   // Read argument's access qualifier from metadata or default.
194   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
195       getArgAccessQual(F, ArgIdx);
196 
197   Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
198 
199   // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
200   // be legally reassigned later).
201   if (!isPointerTy(OriginalArgType))
202     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
203 
204   Argument *Arg = F.getArg(ArgIdx);
205   Type *ArgType = Arg->getType();
206   if (isTypedPointerTy(ArgType)) {
207     SPIRVType *ElementType = GR->getOrCreateSPIRVType(
208         cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder);
209     return GR->getOrCreateSPIRVPointerType(
210         ElementType, MIRBuilder,
211         addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
212   }
213 
214   // In case OriginalArgType is of untyped pointer type, there are three
215   // possibilities:
216   // 1) This is a pointer of an LLVM IR element type, passed byval/byref.
217   // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
218   //    intrinsic assigning a TargetExtType.
219   // 3) This is a pointer, try to retrieve pointer element type from a
220   // spv_assign_ptr_type intrinsic or otherwise use default pointer element
221   // type.
222   if (hasPointeeTypeAttr(Arg)) {
223     SPIRVType *ElementType =
224         GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder);
225     return GR->getOrCreateSPIRVPointerType(
226         ElementType, MIRBuilder,
227         addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
228   }
229 
230   for (auto User : Arg->users()) {
231     auto *II = dyn_cast<IntrinsicInst>(User);
232     // Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
233     if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
234       MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
235       Type *BuiltinType =
236           cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
237       assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
238       return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual);
239     }
240 
241     // Check if this is spv_assign_ptr_type assigning pointer element type.
242     if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type)
243       continue;
244 
245     MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
246     Type *ElementTy =
247         toTypedPointer(cast<ConstantAsMetadata>(VMD->getMetadata())->getType());
248     SPIRVType *ElementType = GR->getOrCreateSPIRVType(ElementTy, MIRBuilder);
249     return GR->getOrCreateSPIRVPointerType(
250         ElementType, MIRBuilder,
251         addressSpaceToStorageClass(
252             cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
253   }
254 
255   // Replace PointerType with TypedPointerType to be able to map SPIR-V types to
256   // LLVM types in a consistent manner
257   return GR->getOrCreateSPIRVType(toTypedPointer(OriginalArgType), MIRBuilder,
258                                   ArgAccessQual);
259 }
260 
261 static SPIRV::ExecutionModel::ExecutionModel
262 getExecutionModel(const SPIRVSubtarget &STI, const Function &F) {
263   if (STI.isOpenCLEnv())
264     return SPIRV::ExecutionModel::Kernel;
265 
266   auto attribute = F.getFnAttribute("hlsl.shader");
267   if (!attribute.isValid()) {
268     report_fatal_error(
269         "This entry point lacks mandatory hlsl.shader attribute.");
270   }
271 
272   const auto value = attribute.getValueAsString();
273   if (value == "compute")
274     return SPIRV::ExecutionModel::GLCompute;
275 
276   report_fatal_error("This HLSL entry point is not supported by this backend.");
277 }
278 
279 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
280                                              const Function &F,
281                                              ArrayRef<ArrayRef<Register>> VRegs,
282                                              FunctionLoweringInfo &FLI) const {
283   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
284   GR->setCurrentFunc(MIRBuilder.getMF());
285 
286   // Get access to information about available extensions
287   const SPIRVSubtarget *ST =
288       static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
289 
290   // Assign types and names to all args, and store their types for later.
291   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
292   if (VRegs.size() > 0) {
293     unsigned i = 0;
294     for (const auto &Arg : F.args()) {
295       // Currently formal args should use single registers.
296       // TODO: handle the case of multiple registers.
297       if (VRegs[i].size() > 1)
298         return false;
299       auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
300       GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
301       ArgTypeVRegs.push_back(SpirvTy);
302 
303       if (Arg.hasName())
304         buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
305       if (isPointerTy(Arg.getType())) {
306         auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
307         if (DerefBytes != 0)
308           buildOpDecorate(VRegs[i][0], MIRBuilder,
309                           SPIRV::Decoration::MaxByteOffset, {DerefBytes});
310       }
311       if (Arg.hasAttribute(Attribute::Alignment)) {
312         auto Alignment = static_cast<unsigned>(
313             Arg.getAttribute(Attribute::Alignment).getValueAsInt());
314         buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
315                         {Alignment});
316       }
317       if (Arg.hasAttribute(Attribute::ReadOnly)) {
318         auto Attr =
319             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
320         buildOpDecorate(VRegs[i][0], MIRBuilder,
321                         SPIRV::Decoration::FuncParamAttr, {Attr});
322       }
323       if (Arg.hasAttribute(Attribute::ZExt)) {
324         auto Attr =
325             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
326         buildOpDecorate(VRegs[i][0], MIRBuilder,
327                         SPIRV::Decoration::FuncParamAttr, {Attr});
328       }
329       if (Arg.hasAttribute(Attribute::NoAlias)) {
330         auto Attr =
331             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
332         buildOpDecorate(VRegs[i][0], MIRBuilder,
333                         SPIRV::Decoration::FuncParamAttr, {Attr});
334       }
335       if (Arg.hasAttribute(Attribute::ByVal)) {
336         auto Attr =
337             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal);
338         buildOpDecorate(VRegs[i][0], MIRBuilder,
339                         SPIRV::Decoration::FuncParamAttr, {Attr});
340       }
341 
342       if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
343         std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
344             getKernelArgTypeQual(F, i);
345         for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
346           buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
347       }
348 
349       MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
350       if (Node && i < Node->getNumOperands() &&
351           isa<MDNode>(Node->getOperand(i))) {
352         MDNode *MD = cast<MDNode>(Node->getOperand(i));
353         for (const MDOperand &MDOp : MD->operands()) {
354           MDNode *MD2 = dyn_cast<MDNode>(MDOp);
355           assert(MD2 && "Metadata operand is expected");
356           ConstantInt *Const = getConstInt(MD2, 0);
357           assert(Const && "MDOperand should be ConstantInt");
358           auto Dec =
359               static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
360           std::vector<uint32_t> DecVec;
361           for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
362             ConstantInt *Const = getConstInt(MD2, j);
363             assert(Const && "MDOperand should be ConstantInt");
364             DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue()));
365           }
366           buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec);
367         }
368       }
369       ++i;
370     }
371   }
372 
373   auto MRI = MIRBuilder.getMRI();
374   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
375   MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
376   if (F.isDeclaration())
377     GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
378   FunctionType *FTy = getOriginalFunctionType(F);
379   Type *FRetTy = FTy->getReturnType();
380   if (isUntypedPointerTy(FRetTy)) {
381     if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
382       TypedPointerType *DerivedTy = TypedPointerType::get(
383           toTypedPointer(FRetElemTy), getPointerAddressSpace(FRetTy));
384       GR->addReturnType(&F, DerivedTy);
385       FRetTy = DerivedTy;
386     }
387   }
388   SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
389   FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
390   SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
391       FTy, RetTy, ArgTypeVRegs, MIRBuilder);
392   uint32_t FuncControl = getFunctionControl(F);
393 
394   // Add OpFunction instruction
395   MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
396                                .addDef(FuncVReg)
397                                .addUse(GR->getSPIRVTypeID(RetTy))
398                                .addImm(FuncControl)
399                                .addUse(GR->getSPIRVTypeID(FuncTy));
400   GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0));
401 
402   // Add OpFunctionParameter instructions
403   int i = 0;
404   for (const auto &Arg : F.args()) {
405     assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
406     MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
407     MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
408         .addDef(VRegs[i][0])
409         .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
410     if (F.isDeclaration())
411       GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
412     i++;
413   }
414   // Name the function.
415   if (F.hasName())
416     buildOpName(FuncVReg, F.getName(), MIRBuilder);
417 
418   // Handle entry points and function linkage.
419   if (isEntryPoint(F)) {
420     const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
421     auto executionModel = getExecutionModel(STI, F);
422     auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
423                    .addImm(static_cast<uint32_t>(executionModel))
424                    .addUse(FuncVReg);
425     addStringImm(F.getName(), MIB);
426   } else if (F.getLinkage() != GlobalValue::InternalLinkage &&
427              F.getLinkage() != GlobalValue::PrivateLinkage) {
428     SPIRV::LinkageType::LinkageType LnkTy =
429         F.isDeclaration()
430             ? SPIRV::LinkageType::Import
431             : (F.getLinkage() == GlobalValue::LinkOnceODRLinkage &&
432                        ST->canUseExtension(
433                            SPIRV::Extension::SPV_KHR_linkonce_odr)
434                    ? SPIRV::LinkageType::LinkOnceODR
435                    : SPIRV::LinkageType::Export);
436     buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
437                     {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
438   }
439 
440   // Handle function pointers decoration
441   bool hasFunctionPointers =
442       ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
443   if (hasFunctionPointers) {
444     if (F.hasFnAttribute("referenced-indirectly")) {
445       assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
446              "Unexpected 'referenced-indirectly' attribute of the kernel "
447              "function");
448       buildOpDecorate(FuncVReg, MIRBuilder,
449                       SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
450     }
451   }
452 
453   return true;
454 }
455 
456 // Used to postpone producing of indirect function pointer types after all
457 // indirect calls info is collected
458 // TODO:
459 // - add a topological sort of IndirectCalls to ensure the best types knowledge
460 // - we may need to fix function formal parameter types if they are opaque
461 //   pointers used as function pointers in these indirect calls
462 void SPIRVCallLowering::produceIndirectPtrTypes(
463     MachineIRBuilder &MIRBuilder) const {
464   // Create indirect call data types if any
465   MachineFunction &MF = MIRBuilder.getMF();
466   for (auto const &IC : IndirectCalls) {
467     SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
468     SmallVector<SPIRVType *, 4> SpirvArgTypes;
469     for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
470       SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
471       SpirvArgTypes.push_back(SPIRVTy);
472       if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
473         GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
474     }
475     // SPIR-V function type:
476     FunctionType *FTy =
477         FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false);
478     SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
479         FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
480     // SPIR-V pointer to function type:
481     SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
482         SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
483     // Correct the Callee type
484     GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
485   }
486 }
487 
488 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
489                                   CallLoweringInfo &Info) const {
490   // Currently call returns should have single vregs.
491   // TODO: handle the case of multiple registers.
492   if (Info.OrigRet.Regs.size() > 1)
493     return false;
494   MachineFunction &MF = MIRBuilder.getMF();
495   GR->setCurrentFunc(MF);
496   const Function *CF = nullptr;
497   std::string DemangledName;
498   const Type *OrigRetTy = Info.OrigRet.Ty;
499 
500   // Emit a regular OpFunctionCall. If it's an externally declared function,
501   // be sure to emit its type and function declaration here. It will be hoisted
502   // globally later.
503   if (Info.Callee.isGlobal()) {
504     std::string FuncName = Info.Callee.getGlobal()->getName().str();
505     DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
506     CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
507     // TODO: support constexpr casts and indirect calls.
508     if (CF == nullptr)
509       return false;
510     if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
511       OrigRetTy = FTy->getReturnType();
512       if (isUntypedPointerTy(OrigRetTy)) {
513         if (auto *DerivedRetTy = GR->findReturnType(CF))
514           OrigRetTy = DerivedRetTy;
515       }
516     }
517   }
518 
519   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
520   Register ResVReg =
521       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
522   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
523 
524   bool isFunctionDecl = CF && CF->isDeclaration();
525   bool canUseOpenCL = ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std);
526   bool canUseGLSL = ST->canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450);
527   assert(canUseGLSL != canUseOpenCL &&
528          "Scenario where both sets are enabled is not supported.");
529 
530   if (isFunctionDecl && !DemangledName.empty() &&
531       (canUseGLSL || canUseOpenCL)) {
532     SmallVector<Register, 8> ArgVRegs;
533     for (auto Arg : Info.OrigArgs) {
534       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
535       ArgVRegs.push_back(Arg.Regs[0]);
536       SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
537       if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
538         GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
539     }
540     auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
541                                        : SPIRV::InstructionSet::GLSL_std_450;
542     if (auto Res =
543             SPIRV::lowerBuiltin(DemangledName, instructionSet, MIRBuilder,
544                                 ResVReg, OrigRetTy, ArgVRegs, GR))
545       return *Res;
546   }
547 
548   if (isFunctionDecl && !GR->find(CF, &MF).isValid()) {
549     // Emit the type info and forward function declaration to the first MBB
550     // to ensure VReg definition dependencies are valid across all MBBs.
551     MachineIRBuilder FirstBlockBuilder;
552     FirstBlockBuilder.setMF(MF);
553     FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));
554 
555     SmallVector<ArrayRef<Register>, 8> VRegArgs;
556     SmallVector<SmallVector<Register, 1>, 8> ToInsert;
557     for (const Argument &Arg : CF->args()) {
558       if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
559         continue; // Don't handle zero sized types.
560       Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
561       MRI->setRegClass(Reg, &SPIRV::IDRegClass);
562       ToInsert.push_back({Reg});
563       VRegArgs.push_back(ToInsert.back());
564     }
565     // TODO: Reuse FunctionLoweringInfo
566     FunctionLoweringInfo FuncInfo;
567     lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
568   }
569 
570   unsigned CallOp;
571   if (Info.CB->isIndirectCall()) {
572     if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
573       report_fatal_error("An indirect call is encountered but SPIR-V without "
574                          "extensions does not support it",
575                          false);
576     // Set instruction operation according to SPV_INTEL_function_pointers
577     CallOp = SPIRV::OpFunctionPointerCallINTEL;
578     // Collect information about the indirect call to support possible
579     // specification of opaque ptr types of parent function's parameters
580     Register CalleeReg = Info.Callee.getReg();
581     if (CalleeReg.isValid()) {
582       SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
583       IndirectCall.Callee = CalleeReg;
584       IndirectCall.RetTy = OrigRetTy;
585       for (const auto &Arg : Info.OrigArgs) {
586         assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
587         IndirectCall.ArgTys.push_back(Arg.Ty);
588         IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
589       }
590       IndirectCalls.push_back(IndirectCall);
591     }
592   } else {
593     // Emit a regular OpFunctionCall
594     CallOp = SPIRV::OpFunctionCall;
595   }
596 
597   // Make sure there's a valid return reg, even for functions returning void.
598   if (!ResVReg.isValid())
599     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
600   SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
601 
602   // Emit the call instruction and its args.
603   auto MIB = MIRBuilder.buildInstr(CallOp)
604                  .addDef(ResVReg)
605                  .addUse(GR->getSPIRVTypeID(RetType))
606                  .add(Info.Callee);
607 
608   for (const auto &Arg : Info.OrigArgs) {
609     // Currently call args should have single vregs.
610     if (Arg.Regs.size() > 1)
611       return false;
612     MIB.addUse(Arg.Regs[0]);
613   }
614   return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(),
615                               *ST->getRegBankInfo());
616 }
617