xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering of LLVM calls to machine code calls for
10 // GlobalISel.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "SPIRVCallLowering.h"
15 #include "MCTargetDesc/SPIRVBaseInfo.h"
16 #include "SPIRV.h"
17 #include "SPIRVBuiltins.h"
18 #include "SPIRVGlobalRegistry.h"
19 #include "SPIRVISelLowering.h"
20 #include "SPIRVMetadata.h"
21 #include "SPIRVRegisterInfo.h"
22 #include "SPIRVSubtarget.h"
23 #include "SPIRVUtils.h"
24 #include "llvm/CodeGen/FunctionLoweringInfo.h"
25 #include "llvm/IR/IntrinsicInst.h"
26 #include "llvm/IR/IntrinsicsSPIRV.h"
27 #include "llvm/Support/ModRef.h"
28 
29 using namespace llvm;
30 
SPIRVCallLowering(const SPIRVTargetLowering & TLI,SPIRVGlobalRegistry * GR)31 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
32                                      SPIRVGlobalRegistry *GR)
33     : CallLowering(&TLI), GR(GR) {}
34 
lowerReturn(MachineIRBuilder & MIRBuilder,const Value * Val,ArrayRef<Register> VRegs,FunctionLoweringInfo & FLI,Register SwiftErrorVReg) const35 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
36                                     const Value *Val, ArrayRef<Register> VRegs,
37                                     FunctionLoweringInfo &FLI,
38                                     Register SwiftErrorVReg) const {
39   // Ignore if called from the internal service function
40   if (MIRBuilder.getMF()
41           .getFunction()
42           .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
43           .isValid())
44     return true;
45 
46   // Maybe run postponed production of types for function pointers
47   if (IndirectCalls.size() > 0) {
48     produceIndirectPtrTypes(MIRBuilder);
49     IndirectCalls.clear();
50   }
51 
52   // Currently all return types should use a single register.
53   // TODO: handle the case of multiple registers.
54   if (VRegs.size() > 1)
55     return false;
56   if (Val) {
57     const auto &STI = MIRBuilder.getMF().getSubtarget();
58     return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
59         .addUse(VRegs[0])
60         .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
61                           *STI.getRegBankInfo());
62   }
63   MIRBuilder.buildInstr(SPIRV::OpReturn);
64   return true;
65 }
66 
67 // Based on the LLVM function attributes, get a SPIR-V FunctionControl.
getFunctionControl(const Function & F,const SPIRVSubtarget * ST)68 static uint32_t getFunctionControl(const Function &F,
69                                    const SPIRVSubtarget *ST) {
70   MemoryEffects MemEffects = F.getMemoryEffects();
71 
72   uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
73 
74   if (F.hasFnAttribute(Attribute::AttrKind::NoInline))
75     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
76   else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline))
77     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
78 
79   if (MemEffects.doesNotAccessMemory())
80     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
81   else if (MemEffects.onlyReadsMemory())
82     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
83 
84   if (ST->canUseExtension(SPIRV::Extension::SPV_INTEL_optnone) ||
85       ST->canUseExtension(SPIRV::Extension::SPV_EXT_optnone))
86     if (F.hasFnAttribute(Attribute::OptimizeNone))
87       FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::OptNoneEXT);
88 
89   return FuncControl;
90 }
91 
getConstInt(MDNode * MD,unsigned NumOp)92 static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
93   if (MD->getNumOperands() > NumOp) {
94     auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp));
95     if (CMeta)
96       return dyn_cast<ConstantInt>(CMeta->getValue());
97   }
98   return nullptr;
99 }
100 
101 // If the function has pointer arguments, we are forced to re-create this
102 // function type from the very beginning, changing PointerType by
103 // TypedPointerType for each pointer argument. Otherwise, the same `Type*`
104 // potentially corresponds to different SPIR-V function type, effectively
105 // invalidating logic behind global registry and duplicates tracker.
106 static FunctionType *
fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry * GR,const Function & F,FunctionType * FTy,const SPIRVType * SRetTy,const SmallVector<SPIRVType *,4> & SArgTys)107 fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
108                          FunctionType *FTy, const SPIRVType *SRetTy,
109                          const SmallVector<SPIRVType *, 4> &SArgTys) {
110   bool hasArgPtrs = false;
111   for (auto &Arg : F.args()) {
112     // check if it's an instance of a non-typed PointerType
113     if (Arg.getType()->isPointerTy()) {
114       hasArgPtrs = true;
115       break;
116     }
117   }
118   if (!hasArgPtrs) {
119     Type *RetTy = FTy->getReturnType();
120     // check if it's an instance of a non-typed PointerType
121     if (!RetTy->isPointerTy())
122       return FTy;
123   }
124 
125   // re-create function type, using TypedPointerType instead of PointerType to
126   // properly trace argument types
127   const Type *RetTy = GR->getTypeForSPIRVType(SRetTy);
128   SmallVector<Type *, 4> ArgTys;
129   for (auto SArgTy : SArgTys)
130     ArgTys.push_back(const_cast<Type *>(GR->getTypeForSPIRVType(SArgTy)));
131   return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
132 }
133 
134 // This code restores function args/retvalue types for composite cases
135 // because the final types should still be aggregate whereas they're i32
136 // during the translation to cope with aggregate flattening etc.
getOriginalFunctionType(const Function & F)137 static FunctionType *getOriginalFunctionType(const Function &F) {
138   auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
139   if (NamedMD == nullptr)
140     return F.getFunctionType();
141 
142   Type *RetTy = F.getFunctionType()->getReturnType();
143   SmallVector<Type *, 4> ArgTypes;
144   for (auto &Arg : F.args())
145     ArgTypes.push_back(Arg.getType());
146 
147   auto ThisFuncMDIt =
148       std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
149         return isa<MDString>(N->getOperand(0)) &&
150                cast<MDString>(N->getOperand(0))->getString() == F.getName();
151       });
152   // TODO: probably one function can have numerous type mutations,
153   // so we should support this.
154   if (ThisFuncMDIt != NamedMD->op_end()) {
155     auto *ThisFuncMD = *ThisFuncMDIt;
156     MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
157     assert(MD && "MDNode operand is expected");
158     ConstantInt *Const = getConstInt(MD, 0);
159     if (Const) {
160       auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
161       assert(CMeta && "ConstantAsMetadata operand is expected");
162       assert(Const->getSExtValue() >= -1);
163       // Currently -1 indicates return value, greater values mean
164       // argument numbers.
165       if (Const->getSExtValue() == -1)
166         RetTy = CMeta->getType();
167       else
168         ArgTypes[Const->getSExtValue()] = CMeta->getType();
169     }
170   }
171 
172   return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
173 }
174 
175 static SPIRV::AccessQualifier::AccessQualifier
getArgAccessQual(const Function & F,unsigned ArgIdx)176 getArgAccessQual(const Function &F, unsigned ArgIdx) {
177   if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
178     return SPIRV::AccessQualifier::ReadWrite;
179 
180   MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
181   if (!ArgAttribute)
182     return SPIRV::AccessQualifier::ReadWrite;
183 
184   if (ArgAttribute->getString() == "read_only")
185     return SPIRV::AccessQualifier::ReadOnly;
186   if (ArgAttribute->getString() == "write_only")
187     return SPIRV::AccessQualifier::WriteOnly;
188   return SPIRV::AccessQualifier::ReadWrite;
189 }
190 
191 static std::vector<SPIRV::Decoration::Decoration>
getKernelArgTypeQual(const Function & F,unsigned ArgIdx)192 getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
193   MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
194   if (ArgAttribute && ArgAttribute->getString() == "volatile")
195     return {SPIRV::Decoration::Volatile};
196   return {};
197 }
198 
getArgSPIRVType(const Function & F,unsigned ArgIdx,SPIRVGlobalRegistry * GR,MachineIRBuilder & MIRBuilder,const SPIRVSubtarget & ST)199 static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
200                                   SPIRVGlobalRegistry *GR,
201                                   MachineIRBuilder &MIRBuilder,
202                                   const SPIRVSubtarget &ST) {
203   // Read argument's access qualifier from metadata or default.
204   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
205       getArgAccessQual(F, ArgIdx);
206 
207   Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
208 
209   // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
210   // be legally reassigned later).
211   if (!isPointerTy(OriginalArgType))
212     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual,
213                                     true);
214 
215   Argument *Arg = F.getArg(ArgIdx);
216   Type *ArgType = Arg->getType();
217   if (isTypedPointerTy(ArgType)) {
218     return GR->getOrCreateSPIRVPointerType(
219         cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder,
220         addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
221   }
222 
223   // In case OriginalArgType is of untyped pointer type, there are three
224   // possibilities:
225   // 1) This is a pointer of an LLVM IR element type, passed byval/byref.
226   // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
227   //    intrinsic assigning a TargetExtType.
228   // 3) This is a pointer, try to retrieve pointer element type from a
229   // spv_assign_ptr_type intrinsic or otherwise use default pointer element
230   // type.
231   if (hasPointeeTypeAttr(Arg)) {
232     return GR->getOrCreateSPIRVPointerType(
233         getPointeeTypeByAttr(Arg), MIRBuilder,
234         addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
235   }
236 
237   for (auto User : Arg->users()) {
238     auto *II = dyn_cast<IntrinsicInst>(User);
239     // Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
240     if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
241       MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
242       Type *BuiltinType =
243           cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
244       assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
245       return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual,
246                                       true);
247     }
248 
249     // Check if this is spv_assign_ptr_type assigning pointer element type.
250     if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type)
251       continue;
252 
253     MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
254     Type *ElementTy =
255         toTypedPointer(cast<ConstantAsMetadata>(VMD->getMetadata())->getType());
256     return GR->getOrCreateSPIRVPointerType(
257         ElementTy, MIRBuilder,
258         addressSpaceToStorageClass(
259             cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
260   }
261 
262   // Replace PointerType with TypedPointerType to be able to map SPIR-V types to
263   // LLVM types in a consistent manner
264   return GR->getOrCreateSPIRVType(toTypedPointer(OriginalArgType), MIRBuilder,
265                                   ArgAccessQual, true);
266 }
267 
268 static SPIRV::ExecutionModel::ExecutionModel
getExecutionModel(const SPIRVSubtarget & STI,const Function & F)269 getExecutionModel(const SPIRVSubtarget &STI, const Function &F) {
270   if (STI.isKernel())
271     return SPIRV::ExecutionModel::Kernel;
272 
273   if (STI.isShader()) {
274     auto attribute = F.getFnAttribute("hlsl.shader");
275     if (!attribute.isValid()) {
276       report_fatal_error(
277           "This entry point lacks mandatory hlsl.shader attribute.");
278     }
279 
280     const auto value = attribute.getValueAsString();
281     if (value == "compute")
282       return SPIRV::ExecutionModel::GLCompute;
283     if (value == "vertex")
284       return SPIRV::ExecutionModel::Vertex;
285     if (value == "pixel")
286       return SPIRV::ExecutionModel::Fragment;
287 
288     report_fatal_error(
289         "This HLSL entry point is not supported by this backend.");
290   }
291 
292   assert(STI.getEnv() == SPIRVSubtarget::Unknown);
293   // "hlsl.shader" attribute is mandatory for Vulkan, so we can set Env to
294   // Shader whenever we find it, and to Kernel otherwise.
295 
296   // We will now change the Env based on the attribute, so we need to strip
297   // `const` out of the ref to STI.
298   SPIRVSubtarget *NonConstSTI = const_cast<SPIRVSubtarget *>(&STI);
299   auto attribute = F.getFnAttribute("hlsl.shader");
300   if (!attribute.isValid()) {
301     NonConstSTI->setEnv(SPIRVSubtarget::Kernel);
302     return SPIRV::ExecutionModel::Kernel;
303   }
304   NonConstSTI->setEnv(SPIRVSubtarget::Shader);
305 
306   const auto value = attribute.getValueAsString();
307   if (value == "compute")
308     return SPIRV::ExecutionModel::GLCompute;
309   if (value == "vertex")
310     return SPIRV::ExecutionModel::Vertex;
311   if (value == "pixel")
312     return SPIRV::ExecutionModel::Fragment;
313 
314   report_fatal_error("This HLSL entry point is not supported by this backend.");
315 }
316 
lowerFormalArguments(MachineIRBuilder & MIRBuilder,const Function & F,ArrayRef<ArrayRef<Register>> VRegs,FunctionLoweringInfo & FLI) const317 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
318                                              const Function &F,
319                                              ArrayRef<ArrayRef<Register>> VRegs,
320                                              FunctionLoweringInfo &FLI) const {
321   // Discard the internal service function
322   if (F.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid())
323     return true;
324 
325   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
326   GR->setCurrentFunc(MIRBuilder.getMF());
327 
328   // Get access to information about available extensions
329   const SPIRVSubtarget *ST =
330       static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
331 
332   // Assign types and names to all args, and store their types for later.
333   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
334   if (VRegs.size() > 0) {
335     unsigned i = 0;
336     for (const auto &Arg : F.args()) {
337       // Currently formal args should use single registers.
338       // TODO: handle the case of multiple registers.
339       if (VRegs[i].size() > 1)
340         return false;
341       auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
342       GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
343       ArgTypeVRegs.push_back(SpirvTy);
344 
345       if (Arg.hasName())
346         buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
347       if (isPointerTyOrWrapper(Arg.getType())) {
348         auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
349         if (DerefBytes != 0)
350           buildOpDecorate(VRegs[i][0], MIRBuilder,
351                           SPIRV::Decoration::MaxByteOffset, {DerefBytes});
352       }
353       if (Arg.hasAttribute(Attribute::Alignment) && !ST->isShader()) {
354         auto Alignment = static_cast<unsigned>(
355             Arg.getAttribute(Attribute::Alignment).getValueAsInt());
356         buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
357                         {Alignment});
358       }
359       if (Arg.hasAttribute(Attribute::ReadOnly)) {
360         auto Attr =
361             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
362         buildOpDecorate(VRegs[i][0], MIRBuilder,
363                         SPIRV::Decoration::FuncParamAttr, {Attr});
364       }
365       if (Arg.hasAttribute(Attribute::ZExt)) {
366         auto Attr =
367             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
368         buildOpDecorate(VRegs[i][0], MIRBuilder,
369                         SPIRV::Decoration::FuncParamAttr, {Attr});
370       }
371       if (Arg.hasAttribute(Attribute::NoAlias)) {
372         auto Attr =
373             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
374         buildOpDecorate(VRegs[i][0], MIRBuilder,
375                         SPIRV::Decoration::FuncParamAttr, {Attr});
376       }
377       if (Arg.hasAttribute(Attribute::ByVal)) {
378         auto Attr =
379             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal);
380         buildOpDecorate(VRegs[i][0], MIRBuilder,
381                         SPIRV::Decoration::FuncParamAttr, {Attr});
382       }
383       if (Arg.hasAttribute(Attribute::StructRet)) {
384         auto Attr =
385             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Sret);
386         buildOpDecorate(VRegs[i][0], MIRBuilder,
387                         SPIRV::Decoration::FuncParamAttr, {Attr});
388       }
389 
390       if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
391         std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
392             getKernelArgTypeQual(F, i);
393         for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
394           buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
395       }
396 
397       MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
398       if (Node && i < Node->getNumOperands() &&
399           isa<MDNode>(Node->getOperand(i))) {
400         MDNode *MD = cast<MDNode>(Node->getOperand(i));
401         for (const MDOperand &MDOp : MD->operands()) {
402           MDNode *MD2 = dyn_cast<MDNode>(MDOp);
403           assert(MD2 && "Metadata operand is expected");
404           ConstantInt *Const = getConstInt(MD2, 0);
405           assert(Const && "MDOperand should be ConstantInt");
406           auto Dec =
407               static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
408           std::vector<uint32_t> DecVec;
409           for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
410             ConstantInt *Const = getConstInt(MD2, j);
411             assert(Const && "MDOperand should be ConstantInt");
412             DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue()));
413           }
414           buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec);
415         }
416       }
417       ++i;
418     }
419   }
420 
421   auto MRI = MIRBuilder.getMRI();
422   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
423   MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass);
424   FunctionType *FTy = getOriginalFunctionType(F);
425   Type *FRetTy = FTy->getReturnType();
426   if (isUntypedPointerTy(FRetTy)) {
427     if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
428       TypedPointerType *DerivedTy = TypedPointerType::get(
429           toTypedPointer(FRetElemTy), getPointerAddressSpace(FRetTy));
430       GR->addReturnType(&F, DerivedTy);
431       FRetTy = DerivedTy;
432     }
433   }
434   SPIRVType *RetTy = GR->getOrCreateSPIRVType(
435       FRetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
436   FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
437   SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
438       FTy, RetTy, ArgTypeVRegs, MIRBuilder);
439   uint32_t FuncControl = getFunctionControl(F, ST);
440 
441   // Add OpFunction instruction
442   MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
443                                .addDef(FuncVReg)
444                                .addUse(GR->getSPIRVTypeID(RetTy))
445                                .addImm(FuncControl)
446                                .addUse(GR->getSPIRVTypeID(FuncTy));
447   GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0));
448   GR->addGlobalObject(&F, &MIRBuilder.getMF(), FuncVReg);
449   if (F.isDeclaration())
450     GR->add(&F, MB);
451 
452   // Add OpFunctionParameter instructions
453   int i = 0;
454   for (const auto &Arg : F.args()) {
455     assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
456     Register ArgReg = VRegs[i][0];
457     MRI->setRegClass(ArgReg, GR->getRegClass(ArgTypeVRegs[i]));
458     MRI->setType(ArgReg, GR->getRegType(ArgTypeVRegs[i]));
459     auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
460                    .addDef(ArgReg)
461                    .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
462     if (F.isDeclaration())
463       GR->add(&Arg, MIB);
464     GR->addGlobalObject(&Arg, &MIRBuilder.getMF(), ArgReg);
465     i++;
466   }
467   // Name the function.
468   if (F.hasName())
469     buildOpName(FuncVReg, F.getName(), MIRBuilder);
470 
471   // Handle entry points and function linkage.
472   if (isEntryPoint(F)) {
473     // EntryPoints can help us to determine the environment we're working on.
474     // Therefore, we need a non-const pointer to SPIRVSubtarget to update the
475     // environment if we need to.
476     const SPIRVSubtarget *ST =
477         static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
478     auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
479                    .addImm(static_cast<uint32_t>(getExecutionModel(*ST, F)))
480                    .addUse(FuncVReg);
481     addStringImm(F.getName(), MIB);
482   } else if (F.getLinkage() != GlobalValue::InternalLinkage &&
483              F.getLinkage() != GlobalValue::PrivateLinkage &&
484              F.getVisibility() != GlobalValue::HiddenVisibility) {
485     SPIRV::LinkageType::LinkageType LnkTy =
486         F.isDeclaration()
487             ? SPIRV::LinkageType::Import
488             : (F.getLinkage() == GlobalValue::LinkOnceODRLinkage &&
489                        ST->canUseExtension(
490                            SPIRV::Extension::SPV_KHR_linkonce_odr)
491                    ? SPIRV::LinkageType::LinkOnceODR
492                    : SPIRV::LinkageType::Export);
493     buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
494                     {static_cast<uint32_t>(LnkTy)}, F.getName());
495   }
496 
497   // Handle function pointers decoration
498   bool hasFunctionPointers =
499       ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
500   if (hasFunctionPointers) {
501     if (F.hasFnAttribute("referenced-indirectly")) {
502       assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
503              "Unexpected 'referenced-indirectly' attribute of the kernel "
504              "function");
505       buildOpDecorate(FuncVReg, MIRBuilder,
506                       SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
507     }
508   }
509 
510   return true;
511 }
512 
513 // Used to postpone producing of indirect function pointer types after all
514 // indirect calls info is collected
515 // TODO:
516 // - add a topological sort of IndirectCalls to ensure the best types knowledge
517 // - we may need to fix function formal parameter types if they are opaque
518 //   pointers used as function pointers in these indirect calls
produceIndirectPtrTypes(MachineIRBuilder & MIRBuilder) const519 void SPIRVCallLowering::produceIndirectPtrTypes(
520     MachineIRBuilder &MIRBuilder) const {
521   // Create indirect call data types if any
522   MachineFunction &MF = MIRBuilder.getMF();
523   for (auto const &IC : IndirectCalls) {
524     SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(
525         IC.RetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
526     SmallVector<SPIRVType *, 4> SpirvArgTypes;
527     for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
528       SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(
529           IC.ArgTys[i], MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
530       SpirvArgTypes.push_back(SPIRVTy);
531       if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
532         GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
533     }
534     // SPIR-V function type:
535     FunctionType *FTy =
536         FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false);
537     SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
538         FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
539     // SPIR-V pointer to function type:
540     SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
541         SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
542     // Correct the Callee type
543     GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
544   }
545 }
546 
lowerCall(MachineIRBuilder & MIRBuilder,CallLoweringInfo & Info) const547 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
548                                   CallLoweringInfo &Info) const {
549   // Currently call returns should have single vregs.
550   // TODO: handle the case of multiple registers.
551   if (Info.OrigRet.Regs.size() > 1)
552     return false;
553   MachineFunction &MF = MIRBuilder.getMF();
554   GR->setCurrentFunc(MF);
555   const Function *CF = nullptr;
556   std::string DemangledName;
557   const Type *OrigRetTy = Info.OrigRet.Ty;
558 
559   // Emit a regular OpFunctionCall. If it's an externally declared function,
560   // be sure to emit its type and function declaration here. It will be hoisted
561   // globally later.
562   if (Info.Callee.isGlobal()) {
563     std::string FuncName = Info.Callee.getGlobal()->getName().str();
564     DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
565     CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
566     // TODO: support constexpr casts and indirect calls.
567     if (CF == nullptr)
568       return false;
569     if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
570       OrigRetTy = FTy->getReturnType();
571       if (isUntypedPointerTy(OrigRetTy)) {
572         if (auto *DerivedRetTy = GR->findReturnType(CF))
573           OrigRetTy = DerivedRetTy;
574       }
575     }
576   }
577 
578   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
579   Register ResVReg =
580       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
581   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
582 
583   bool isFunctionDecl = CF && CF->isDeclaration();
584   if (isFunctionDecl && !DemangledName.empty()) {
585     if (ResVReg.isValid()) {
586       if (!GR->getSPIRVTypeForVReg(ResVReg)) {
587         const Type *RetTy = OrigRetTy;
588         if (auto *PtrRetTy = dyn_cast<PointerType>(OrigRetTy)) {
589           const Value *OrigValue = Info.OrigRet.OrigValue;
590           if (!OrigValue)
591             OrigValue = Info.CB;
592           if (OrigValue)
593             if (Type *ElemTy = GR->findDeducedElementType(OrigValue))
594               RetTy =
595                   TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
596         }
597         setRegClassType(ResVReg, RetTy, GR, MIRBuilder,
598                         SPIRV::AccessQualifier::ReadWrite, true);
599       }
600     } else {
601       ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder,
602                                       SPIRV::AccessQualifier::ReadWrite, true);
603     }
604     SmallVector<Register, 8> ArgVRegs;
605     for (auto Arg : Info.OrigArgs) {
606       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
607       Register ArgReg = Arg.Regs[0];
608       ArgVRegs.push_back(ArgReg);
609       SPIRVType *SpvType = GR->getSPIRVTypeForVReg(ArgReg);
610       if (!SpvType) {
611         Type *ArgTy = nullptr;
612         if (auto *PtrArgTy = dyn_cast<PointerType>(Arg.Ty)) {
613           // If Arg.Ty is an untyped pointer (i.e., ptr [addrspace(...)]) and we
614           // don't have access to original value in LLVM IR or info about
615           // deduced pointee type, then we should wait with setting the type for
616           // the virtual register until pre-legalizer step when we access
617           // @llvm.spv.assign.ptr.type.p...(...)'s info.
618           if (Arg.OrigValue)
619             if (Type *ElemTy = GR->findDeducedElementType(Arg.OrigValue))
620               ArgTy =
621                   TypedPointerType::get(ElemTy, PtrArgTy->getAddressSpace());
622         } else {
623           ArgTy = Arg.Ty;
624         }
625         if (ArgTy) {
626           SpvType = GR->getOrCreateSPIRVType(
627               ArgTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
628           GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
629         }
630       }
631       if (!MRI->getRegClassOrNull(ArgReg)) {
632         // Either we have SpvType created, or Arg.Ty is an untyped pointer and
633         // we know its virtual register's class and type even if we don't know
634         // pointee type.
635         MRI->setRegClass(ArgReg, SpvType ? GR->getRegClass(SpvType)
636                                          : &SPIRV::pIDRegClass);
637         MRI->setType(
638             ArgReg,
639             SpvType ? GR->getRegType(SpvType)
640                     : LLT::pointer(cast<PointerType>(Arg.Ty)->getAddressSpace(),
641                                    GR->getPointerSize()));
642       }
643     }
644     if (auto Res =
645             SPIRV::lowerBuiltin(DemangledName, ST->getPreferredInstructionSet(),
646                                 MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR))
647       return *Res;
648   }
649 
650   if (isFunctionDecl && !GR->find(CF, &MF).isValid()) {
651     // Emit the type info and forward function declaration to the first MBB
652     // to ensure VReg definition dependencies are valid across all MBBs.
653     MachineIRBuilder FirstBlockBuilder;
654     FirstBlockBuilder.setMF(MF);
655     FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));
656 
657     SmallVector<ArrayRef<Register>, 8> VRegArgs;
658     SmallVector<SmallVector<Register, 1>, 8> ToInsert;
659     for (const Argument &Arg : CF->args()) {
660       if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
661         continue; // Don't handle zero sized types.
662       Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(64));
663       MRI->setRegClass(Reg, &SPIRV::iIDRegClass);
664       ToInsert.push_back({Reg});
665       VRegArgs.push_back(ToInsert.back());
666     }
667     // TODO: Reuse FunctionLoweringInfo
668     FunctionLoweringInfo FuncInfo;
669     lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
670   }
671 
672   // Ignore the call if it's called from the internal service function
673   if (MIRBuilder.getMF()
674           .getFunction()
675           .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
676           .isValid()) {
677     // insert a no-op
678     MIRBuilder.buildTrap();
679     return true;
680   }
681 
682   unsigned CallOp;
683   if (Info.CB->isIndirectCall()) {
684     if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
685       report_fatal_error("An indirect call is encountered but SPIR-V without "
686                          "extensions does not support it",
687                          false);
688     // Set instruction operation according to SPV_INTEL_function_pointers
689     CallOp = SPIRV::OpFunctionPointerCallINTEL;
690     // Collect information about the indirect call to support possible
691     // specification of opaque ptr types of parent function's parameters
692     Register CalleeReg = Info.Callee.getReg();
693     if (CalleeReg.isValid()) {
694       SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
695       IndirectCall.Callee = CalleeReg;
696       IndirectCall.RetTy = OrigRetTy;
697       for (const auto &Arg : Info.OrigArgs) {
698         assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
699         IndirectCall.ArgTys.push_back(Arg.Ty);
700         IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
701       }
702       IndirectCalls.push_back(IndirectCall);
703     }
704   } else {
705     // Emit a regular OpFunctionCall
706     CallOp = SPIRV::OpFunctionCall;
707   }
708 
709   // Make sure there's a valid return reg, even for functions returning void.
710   if (!ResVReg.isValid())
711     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
712   SPIRVType *RetType = GR->assignTypeToVReg(
713       OrigRetTy, ResVReg, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
714 
715   // Emit the call instruction and its args.
716   auto MIB = MIRBuilder.buildInstr(CallOp)
717                  .addDef(ResVReg)
718                  .addUse(GR->getSPIRVTypeID(RetType))
719                  .add(Info.Callee);
720 
721   for (const auto &Arg : Info.OrigArgs) {
722     // Currently call args should have single vregs.
723     if (Arg.Regs.size() > 1)
724       return false;
725     MIB.addUse(Arg.Regs[0]);
726   }
727 
728   if (ST->canUseExtension(SPIRV::Extension::SPV_INTEL_memory_access_aliasing)) {
729     // Process aliasing metadata.
730     const CallBase *CI = Info.CB;
731     if (CI && CI->hasMetadata()) {
732       if (MDNode *MD = CI->getMetadata(LLVMContext::MD_alias_scope))
733         GR->buildMemAliasingOpDecorate(ResVReg, MIRBuilder,
734                                        SPIRV::Decoration::AliasScopeINTEL, MD);
735       if (MDNode *MD = CI->getMetadata(LLVMContext::MD_noalias))
736         GR->buildMemAliasingOpDecorate(ResVReg, MIRBuilder,
737                                        SPIRV::Decoration::NoAliasINTEL, MD);
738     }
739   }
740 
741   return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(),
742                               *ST->getRegBankInfo());
743 }
744