xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (revision 577b62c2bacc7dfa228591ca3da361e1bc398301)
1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering of LLVM calls to machine code calls for
10 // GlobalISel.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "SPIRVCallLowering.h"
15 #include "MCTargetDesc/SPIRVBaseInfo.h"
16 #include "SPIRV.h"
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRVISelLowering.h"
19 #include "SPIRVRegisterInfo.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVUtils.h"
22 #include "llvm/CodeGen/FunctionLoweringInfo.h"
23 
24 using namespace llvm;
25 
26 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
27                                      SPIRVGlobalRegistry *GR)
28     : CallLowering(&TLI), GR(GR) {}
29 
30 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
31                                     const Value *Val, ArrayRef<Register> VRegs,
32                                     FunctionLoweringInfo &FLI,
33                                     Register SwiftErrorVReg) const {
34   // Currently all return types should use a single register.
35   // TODO: handle the case of multiple registers.
36   if (VRegs.size() > 1)
37     return false;
38   if (Val) {
39     const auto &STI = MIRBuilder.getMF().getSubtarget();
40     return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
41         .addUse(VRegs[0])
42         .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
43                           *STI.getRegBankInfo());
44   }
45   MIRBuilder.buildInstr(SPIRV::OpReturn);
46   return true;
47 }
48 
49 // Based on the LLVM function attributes, get a SPIR-V FunctionControl.
50 static uint32_t getFunctionControl(const Function &F) {
51   uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
52   if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) {
53     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
54   }
55   if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) {
56     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
57   }
58   if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) {
59     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
60   }
61   if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) {
62     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
63   }
64   return FuncControl;
65 }
66 
67 static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
68   if (MD->getNumOperands() > NumOp) {
69     auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp));
70     if (CMeta)
71       return dyn_cast<ConstantInt>(CMeta->getValue());
72   }
73   return nullptr;
74 }
75 
76 // This code restores function args/retvalue types for composite cases
77 // because the final types should still be aggregate whereas they're i32
78 // during the translation to cope with aggregate flattening etc.
79 static FunctionType *getOriginalFunctionType(const Function &F) {
80   auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
81   if (NamedMD == nullptr)
82     return F.getFunctionType();
83 
84   Type *RetTy = F.getFunctionType()->getReturnType();
85   SmallVector<Type *, 4> ArgTypes;
86   for (auto &Arg : F.args())
87     ArgTypes.push_back(Arg.getType());
88 
89   auto ThisFuncMDIt =
90       std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
91         return isa<MDString>(N->getOperand(0)) &&
92                cast<MDString>(N->getOperand(0))->getString() == F.getName();
93       });
94   // TODO: probably one function can have numerous type mutations,
95   // so we should support this.
96   if (ThisFuncMDIt != NamedMD->op_end()) {
97     auto *ThisFuncMD = *ThisFuncMDIt;
98     MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
99     assert(MD && "MDNode operand is expected");
100     ConstantInt *Const = getConstInt(MD, 0);
101     if (Const) {
102       auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
103       assert(CMeta && "ConstantAsMetadata operand is expected");
104       assert(Const->getSExtValue() >= -1);
105       // Currently -1 indicates return value, greater values mean
106       // argument numbers.
107       if (Const->getSExtValue() == -1)
108         RetTy = CMeta->getType();
109       else
110         ArgTypes[Const->getSExtValue()] = CMeta->getType();
111     }
112   }
113 
114   return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
115 }
116 
117 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
118                                              const Function &F,
119                                              ArrayRef<ArrayRef<Register>> VRegs,
120                                              FunctionLoweringInfo &FLI) const {
121   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
122   GR->setCurrentFunc(MIRBuilder.getMF());
123 
124   // Assign types and names to all args, and store their types for later.
125   FunctionType *FTy = getOriginalFunctionType(F);
126   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
127   if (VRegs.size() > 0) {
128     unsigned i = 0;
129     for (const auto &Arg : F.args()) {
130       // Currently formal args should use single registers.
131       // TODO: handle the case of multiple registers.
132       if (VRegs[i].size() > 1)
133         return false;
134       Type *ArgTy = FTy->getParamType(i);
135       SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite;
136       MDNode *Node = F.getMetadata("kernel_arg_access_qual");
137       if (Node && i < Node->getNumOperands()) {
138         StringRef AQString = cast<MDString>(Node->getOperand(i))->getString();
139         if (AQString.compare("read_only") == 0)
140           AQ = SPIRV::AccessQualifier::ReadOnly;
141         else if (AQString.compare("write_only") == 0)
142           AQ = SPIRV::AccessQualifier::WriteOnly;
143       }
144       auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ);
145       ArgTypeVRegs.push_back(SpirvTy);
146 
147       if (Arg.hasName())
148         buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
149       if (Arg.getType()->isPointerTy()) {
150         auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
151         if (DerefBytes != 0)
152           buildOpDecorate(VRegs[i][0], MIRBuilder,
153                           SPIRV::Decoration::MaxByteOffset, {DerefBytes});
154       }
155       if (Arg.hasAttribute(Attribute::Alignment)) {
156         auto Alignment = static_cast<unsigned>(
157             Arg.getAttribute(Attribute::Alignment).getValueAsInt());
158         buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
159                         {Alignment});
160       }
161       if (Arg.hasAttribute(Attribute::ReadOnly)) {
162         auto Attr =
163             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
164         buildOpDecorate(VRegs[i][0], MIRBuilder,
165                         SPIRV::Decoration::FuncParamAttr, {Attr});
166       }
167       if (Arg.hasAttribute(Attribute::ZExt)) {
168         auto Attr =
169             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
170         buildOpDecorate(VRegs[i][0], MIRBuilder,
171                         SPIRV::Decoration::FuncParamAttr, {Attr});
172       }
173       if (Arg.hasAttribute(Attribute::NoAlias)) {
174         auto Attr =
175             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
176         buildOpDecorate(VRegs[i][0], MIRBuilder,
177                         SPIRV::Decoration::FuncParamAttr, {Attr});
178       }
179       Node = F.getMetadata("kernel_arg_type_qual");
180       if (Node && i < Node->getNumOperands()) {
181         StringRef TypeQual = cast<MDString>(Node->getOperand(i))->getString();
182         if (TypeQual.compare("volatile") == 0)
183           buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile,
184                           {});
185       }
186       Node = F.getMetadata("spirv.ParameterDecorations");
187       if (Node && i < Node->getNumOperands() &&
188           isa<MDNode>(Node->getOperand(i))) {
189         MDNode *MD = cast<MDNode>(Node->getOperand(i));
190         for (const MDOperand &MDOp : MD->operands()) {
191           MDNode *MD2 = dyn_cast<MDNode>(MDOp);
192           assert(MD2 && "Metadata operand is expected");
193           ConstantInt *Const = getConstInt(MD2, 0);
194           assert(Const && "MDOperand should be ConstantInt");
195           auto Dec = static_cast<SPIRV::Decoration>(Const->getZExtValue());
196           std::vector<uint32_t> DecVec;
197           for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
198             ConstantInt *Const = getConstInt(MD2, j);
199             assert(Const && "MDOperand should be ConstantInt");
200             DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue()));
201           }
202           buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec);
203         }
204       }
205       ++i;
206     }
207   }
208 
209   // Generate a SPIR-V type for the function.
210   auto MRI = MIRBuilder.getMRI();
211   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
212   MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
213   if (F.isDeclaration())
214     GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
215   SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
216   SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
217       FTy, RetTy, ArgTypeVRegs, MIRBuilder);
218 
219   // Build the OpTypeFunction declaring it.
220   uint32_t FuncControl = getFunctionControl(F);
221 
222   MIRBuilder.buildInstr(SPIRV::OpFunction)
223       .addDef(FuncVReg)
224       .addUse(GR->getSPIRVTypeID(RetTy))
225       .addImm(FuncControl)
226       .addUse(GR->getSPIRVTypeID(FuncTy));
227 
228   // Add OpFunctionParameters.
229   int i = 0;
230   for (const auto &Arg : F.args()) {
231     assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
232     MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
233     MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
234         .addDef(VRegs[i][0])
235         .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
236     if (F.isDeclaration())
237       GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
238     i++;
239   }
240   // Name the function.
241   if (F.hasName())
242     buildOpName(FuncVReg, F.getName(), MIRBuilder);
243 
244   // Handle entry points and function linkage.
245   if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
246     auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
247                    .addImm(static_cast<uint32_t>(SPIRV::ExecutionModel::Kernel))
248                    .addUse(FuncVReg);
249     addStringImm(F.getName(), MIB);
250   } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
251              F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
252     auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import
253                                    : SPIRV::LinkageType::Export;
254     buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
255                     {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
256   }
257 
258   return true;
259 }
260 
261 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
262                                   CallLoweringInfo &Info) const {
263   // Currently call returns should have single vregs.
264   // TODO: handle the case of multiple registers.
265   if (Info.OrigRet.Regs.size() > 1)
266     return false;
267   MachineFunction &MF = MIRBuilder.getMF();
268   GR->setCurrentFunc(MF);
269   FunctionType *FTy = nullptr;
270   const Function *CF = nullptr;
271 
272   // Emit a regular OpFunctionCall. If it's an externally declared function,
273   // be sure to emit its type and function declaration here. It will be hoisted
274   // globally later.
275   if (Info.Callee.isGlobal()) {
276     CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
277     // TODO: support constexpr casts and indirect calls.
278     if (CF == nullptr)
279       return false;
280     FTy = getOriginalFunctionType(*CF);
281   }
282 
283   Register ResVReg =
284       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
285   if (CF && CF->isDeclaration() &&
286       !GR->find(CF, &MIRBuilder.getMF()).isValid()) {
287     // Emit the type info and forward function declaration to the first MBB
288     // to ensure VReg definition dependencies are valid across all MBBs.
289     MachineIRBuilder FirstBlockBuilder;
290     FirstBlockBuilder.setMF(MF);
291     FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));
292 
293     SmallVector<ArrayRef<Register>, 8> VRegArgs;
294     SmallVector<SmallVector<Register, 1>, 8> ToInsert;
295     for (const Argument &Arg : CF->args()) {
296       if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
297         continue; // Don't handle zero sized types.
298       ToInsert.push_back(
299           {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))});
300       VRegArgs.push_back(ToInsert.back());
301     }
302     // TODO: Reuse FunctionLoweringInfo
303     FunctionLoweringInfo FuncInfo;
304     lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
305   }
306 
307   // Make sure there's a valid return reg, even for functions returning void.
308   if (!ResVReg.isValid())
309     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
310   SPIRVType *RetType =
311       GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
312 
313   // Emit the OpFunctionCall and its args.
314   auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
315                  .addDef(ResVReg)
316                  .addUse(GR->getSPIRVTypeID(RetType))
317                  .add(Info.Callee);
318 
319   for (const auto &Arg : Info.OrigArgs) {
320     // Currently call args should have single vregs.
321     if (Arg.Regs.size() > 1)
322       return false;
323     MIB.addUse(Arg.Regs[0]);
324   }
325   const auto &STI = MF.getSubtarget();
326   return MIB.constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
327                               *STI.getRegBankInfo());
328 }
329