xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (revision 81ad626541db97eb356e2c1d4a20eb2a26a766ab)
1*81ad6265SDimitry Andric //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2*81ad6265SDimitry Andric //
3*81ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*81ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*81ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*81ad6265SDimitry Andric //
7*81ad6265SDimitry Andric //===----------------------------------------------------------------------===//
8*81ad6265SDimitry Andric //
9*81ad6265SDimitry Andric // This file implements the lowering of LLVM calls to machine code calls for
10*81ad6265SDimitry Andric // GlobalISel.
11*81ad6265SDimitry Andric //
12*81ad6265SDimitry Andric //===----------------------------------------------------------------------===//
13*81ad6265SDimitry Andric 
14*81ad6265SDimitry Andric #include "SPIRVCallLowering.h"
15*81ad6265SDimitry Andric #include "MCTargetDesc/SPIRVBaseInfo.h"
16*81ad6265SDimitry Andric #include "SPIRV.h"
17*81ad6265SDimitry Andric #include "SPIRVGlobalRegistry.h"
18*81ad6265SDimitry Andric #include "SPIRVISelLowering.h"
19*81ad6265SDimitry Andric #include "SPIRVRegisterInfo.h"
20*81ad6265SDimitry Andric #include "SPIRVSubtarget.h"
21*81ad6265SDimitry Andric #include "SPIRVUtils.h"
22*81ad6265SDimitry Andric #include "llvm/CodeGen/FunctionLoweringInfo.h"
23*81ad6265SDimitry Andric 
24*81ad6265SDimitry Andric using namespace llvm;
25*81ad6265SDimitry Andric 
26*81ad6265SDimitry Andric SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
27*81ad6265SDimitry Andric                                      const SPIRVSubtarget &ST,
28*81ad6265SDimitry Andric                                      SPIRVGlobalRegistry *GR)
29*81ad6265SDimitry Andric     : CallLowering(&TLI), ST(ST), GR(GR) {}
30*81ad6265SDimitry Andric 
31*81ad6265SDimitry Andric bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
32*81ad6265SDimitry Andric                                     const Value *Val, ArrayRef<Register> VRegs,
33*81ad6265SDimitry Andric                                     FunctionLoweringInfo &FLI,
34*81ad6265SDimitry Andric                                     Register SwiftErrorVReg) const {
35*81ad6265SDimitry Andric   // Currently all return types should use a single register.
36*81ad6265SDimitry Andric   // TODO: handle the case of multiple registers.
37*81ad6265SDimitry Andric   if (VRegs.size() > 1)
38*81ad6265SDimitry Andric     return false;
39*81ad6265SDimitry Andric   if (Val)
40*81ad6265SDimitry Andric     return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
41*81ad6265SDimitry Andric         .addUse(VRegs[0])
42*81ad6265SDimitry Andric         .constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(),
43*81ad6265SDimitry Andric                           *ST.getRegBankInfo());
44*81ad6265SDimitry Andric   MIRBuilder.buildInstr(SPIRV::OpReturn);
45*81ad6265SDimitry Andric   return true;
46*81ad6265SDimitry Andric }
47*81ad6265SDimitry Andric 
48*81ad6265SDimitry Andric // Based on the LLVM function attributes, get a SPIR-V FunctionControl.
49*81ad6265SDimitry Andric static uint32_t getFunctionControl(const Function &F) {
50*81ad6265SDimitry Andric   uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
51*81ad6265SDimitry Andric   if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) {
52*81ad6265SDimitry Andric     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
53*81ad6265SDimitry Andric   }
54*81ad6265SDimitry Andric   if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) {
55*81ad6265SDimitry Andric     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
56*81ad6265SDimitry Andric   }
57*81ad6265SDimitry Andric   if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) {
58*81ad6265SDimitry Andric     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
59*81ad6265SDimitry Andric   }
60*81ad6265SDimitry Andric   if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) {
61*81ad6265SDimitry Andric     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
62*81ad6265SDimitry Andric   }
63*81ad6265SDimitry Andric   return FuncControl;
64*81ad6265SDimitry Andric }
65*81ad6265SDimitry Andric 
66*81ad6265SDimitry Andric bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
67*81ad6265SDimitry Andric                                              const Function &F,
68*81ad6265SDimitry Andric                                              ArrayRef<ArrayRef<Register>> VRegs,
69*81ad6265SDimitry Andric                                              FunctionLoweringInfo &FLI) const {
70*81ad6265SDimitry Andric   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
71*81ad6265SDimitry Andric 
72*81ad6265SDimitry Andric   // Assign types and names to all args, and store their types for later.
73*81ad6265SDimitry Andric   SmallVector<Register, 4> ArgTypeVRegs;
74*81ad6265SDimitry Andric   if (VRegs.size() > 0) {
75*81ad6265SDimitry Andric     unsigned i = 0;
76*81ad6265SDimitry Andric     for (const auto &Arg : F.args()) {
77*81ad6265SDimitry Andric       // Currently formal args should use single registers.
78*81ad6265SDimitry Andric       // TODO: handle the case of multiple registers.
79*81ad6265SDimitry Andric       if (VRegs[i].size() > 1)
80*81ad6265SDimitry Andric         return false;
81*81ad6265SDimitry Andric       auto *SpirvTy =
82*81ad6265SDimitry Andric           GR->assignTypeToVReg(Arg.getType(), VRegs[i][0], MIRBuilder);
83*81ad6265SDimitry Andric       ArgTypeVRegs.push_back(GR->getSPIRVTypeID(SpirvTy));
84*81ad6265SDimitry Andric 
85*81ad6265SDimitry Andric       if (Arg.hasName())
86*81ad6265SDimitry Andric         buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
87*81ad6265SDimitry Andric       if (Arg.getType()->isPointerTy()) {
88*81ad6265SDimitry Andric         auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
89*81ad6265SDimitry Andric         if (DerefBytes != 0)
90*81ad6265SDimitry Andric           buildOpDecorate(VRegs[i][0], MIRBuilder,
91*81ad6265SDimitry Andric                           SPIRV::Decoration::MaxByteOffset, {DerefBytes});
92*81ad6265SDimitry Andric       }
93*81ad6265SDimitry Andric       if (Arg.hasAttribute(Attribute::Alignment)) {
94*81ad6265SDimitry Andric         buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
95*81ad6265SDimitry Andric                         {static_cast<unsigned>(Arg.getParamAlignment())});
96*81ad6265SDimitry Andric       }
97*81ad6265SDimitry Andric       if (Arg.hasAttribute(Attribute::ReadOnly)) {
98*81ad6265SDimitry Andric         auto Attr =
99*81ad6265SDimitry Andric             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
100*81ad6265SDimitry Andric         buildOpDecorate(VRegs[i][0], MIRBuilder,
101*81ad6265SDimitry Andric                         SPIRV::Decoration::FuncParamAttr, {Attr});
102*81ad6265SDimitry Andric       }
103*81ad6265SDimitry Andric       if (Arg.hasAttribute(Attribute::ZExt)) {
104*81ad6265SDimitry Andric         auto Attr =
105*81ad6265SDimitry Andric             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
106*81ad6265SDimitry Andric         buildOpDecorate(VRegs[i][0], MIRBuilder,
107*81ad6265SDimitry Andric                         SPIRV::Decoration::FuncParamAttr, {Attr});
108*81ad6265SDimitry Andric       }
109*81ad6265SDimitry Andric       ++i;
110*81ad6265SDimitry Andric     }
111*81ad6265SDimitry Andric   }
112*81ad6265SDimitry Andric 
113*81ad6265SDimitry Andric   // Generate a SPIR-V type for the function.
114*81ad6265SDimitry Andric   auto MRI = MIRBuilder.getMRI();
115*81ad6265SDimitry Andric   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
116*81ad6265SDimitry Andric   MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
117*81ad6265SDimitry Andric 
118*81ad6265SDimitry Andric   auto *FTy = F.getFunctionType();
119*81ad6265SDimitry Andric   auto FuncTy = GR->assignTypeToVReg(FTy, FuncVReg, MIRBuilder);
120*81ad6265SDimitry Andric 
121*81ad6265SDimitry Andric   // Build the OpTypeFunction declaring it.
122*81ad6265SDimitry Andric   Register ReturnTypeID = FuncTy->getOperand(1).getReg();
123*81ad6265SDimitry Andric   uint32_t FuncControl = getFunctionControl(F);
124*81ad6265SDimitry Andric 
125*81ad6265SDimitry Andric   MIRBuilder.buildInstr(SPIRV::OpFunction)
126*81ad6265SDimitry Andric       .addDef(FuncVReg)
127*81ad6265SDimitry Andric       .addUse(ReturnTypeID)
128*81ad6265SDimitry Andric       .addImm(FuncControl)
129*81ad6265SDimitry Andric       .addUse(GR->getSPIRVTypeID(FuncTy));
130*81ad6265SDimitry Andric 
131*81ad6265SDimitry Andric   // Add OpFunctionParameters.
132*81ad6265SDimitry Andric   const unsigned NumArgs = ArgTypeVRegs.size();
133*81ad6265SDimitry Andric   for (unsigned i = 0; i < NumArgs; ++i) {
134*81ad6265SDimitry Andric     assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
135*81ad6265SDimitry Andric     MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
136*81ad6265SDimitry Andric     MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
137*81ad6265SDimitry Andric         .addDef(VRegs[i][0])
138*81ad6265SDimitry Andric         .addUse(ArgTypeVRegs[i]);
139*81ad6265SDimitry Andric   }
140*81ad6265SDimitry Andric   // Name the function.
141*81ad6265SDimitry Andric   if (F.hasName())
142*81ad6265SDimitry Andric     buildOpName(FuncVReg, F.getName(), MIRBuilder);
143*81ad6265SDimitry Andric 
144*81ad6265SDimitry Andric   // Handle entry points and function linkage.
145*81ad6265SDimitry Andric   if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
146*81ad6265SDimitry Andric     auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
147*81ad6265SDimitry Andric                    .addImm(static_cast<uint32_t>(SPIRV::ExecutionModel::Kernel))
148*81ad6265SDimitry Andric                    .addUse(FuncVReg);
149*81ad6265SDimitry Andric     addStringImm(F.getName(), MIB);
150*81ad6265SDimitry Andric   } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
151*81ad6265SDimitry Andric              F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
152*81ad6265SDimitry Andric     auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import
153*81ad6265SDimitry Andric                                    : SPIRV::LinkageType::Export;
154*81ad6265SDimitry Andric     buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
155*81ad6265SDimitry Andric                     {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
156*81ad6265SDimitry Andric   }
157*81ad6265SDimitry Andric 
158*81ad6265SDimitry Andric   return true;
159*81ad6265SDimitry Andric }
160*81ad6265SDimitry Andric 
161*81ad6265SDimitry Andric bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
162*81ad6265SDimitry Andric                                   CallLoweringInfo &Info) const {
163*81ad6265SDimitry Andric   // Currently call returns should have single vregs.
164*81ad6265SDimitry Andric   // TODO: handle the case of multiple registers.
165*81ad6265SDimitry Andric   if (Info.OrigRet.Regs.size() > 1)
166*81ad6265SDimitry Andric     return false;
167*81ad6265SDimitry Andric 
168*81ad6265SDimitry Andric   Register ResVReg =
169*81ad6265SDimitry Andric       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
170*81ad6265SDimitry Andric   // Emit a regular OpFunctionCall. If it's an externally declared function,
171*81ad6265SDimitry Andric   // be sure to emit its type and function declaration here. It will be
172*81ad6265SDimitry Andric   // hoisted globally later.
173*81ad6265SDimitry Andric   if (Info.Callee.isGlobal()) {
174*81ad6265SDimitry Andric     auto *CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
175*81ad6265SDimitry Andric     // TODO: support constexpr casts and indirect calls.
176*81ad6265SDimitry Andric     if (CF == nullptr)
177*81ad6265SDimitry Andric       return false;
178*81ad6265SDimitry Andric     if (CF->isDeclaration()) {
179*81ad6265SDimitry Andric       // Emit the type info and forward function declaration to the first MBB
180*81ad6265SDimitry Andric       // to ensure VReg definition dependencies are valid across all MBBs.
181*81ad6265SDimitry Andric       MachineBasicBlock::iterator OldII = MIRBuilder.getInsertPt();
182*81ad6265SDimitry Andric       MachineBasicBlock &OldBB = MIRBuilder.getMBB();
183*81ad6265SDimitry Andric       MachineBasicBlock &FirstBB = *MIRBuilder.getMF().getBlockNumbered(0);
184*81ad6265SDimitry Andric       MIRBuilder.setInsertPt(FirstBB, FirstBB.instr_end());
185*81ad6265SDimitry Andric 
186*81ad6265SDimitry Andric       SmallVector<ArrayRef<Register>, 8> VRegArgs;
187*81ad6265SDimitry Andric       SmallVector<SmallVector<Register, 1>, 8> ToInsert;
188*81ad6265SDimitry Andric       for (const Argument &Arg : CF->args()) {
189*81ad6265SDimitry Andric         if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
190*81ad6265SDimitry Andric           continue; // Don't handle zero sized types.
191*81ad6265SDimitry Andric         ToInsert.push_back({MIRBuilder.getMRI()->createGenericVirtualRegister(
192*81ad6265SDimitry Andric             LLT::scalar(32))});
193*81ad6265SDimitry Andric         VRegArgs.push_back(ToInsert.back());
194*81ad6265SDimitry Andric       }
195*81ad6265SDimitry Andric       // TODO: Reuse FunctionLoweringInfo.
196*81ad6265SDimitry Andric       FunctionLoweringInfo FuncInfo;
197*81ad6265SDimitry Andric       lowerFormalArguments(MIRBuilder, *CF, VRegArgs, FuncInfo);
198*81ad6265SDimitry Andric       MIRBuilder.setInsertPt(OldBB, OldII);
199*81ad6265SDimitry Andric     }
200*81ad6265SDimitry Andric   }
201*81ad6265SDimitry Andric 
202*81ad6265SDimitry Andric   // Make sure there's a valid return reg, even for functions returning void.
203*81ad6265SDimitry Andric   if (!ResVReg.isValid()) {
204*81ad6265SDimitry Andric     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
205*81ad6265SDimitry Andric   }
206*81ad6265SDimitry Andric   SPIRVType *RetType =
207*81ad6265SDimitry Andric       GR->assignTypeToVReg(Info.OrigRet.Ty, ResVReg, MIRBuilder);
208*81ad6265SDimitry Andric 
209*81ad6265SDimitry Andric   // Emit the OpFunctionCall and its args.
210*81ad6265SDimitry Andric   auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
211*81ad6265SDimitry Andric                  .addDef(ResVReg)
212*81ad6265SDimitry Andric                  .addUse(GR->getSPIRVTypeID(RetType))
213*81ad6265SDimitry Andric                  .add(Info.Callee);
214*81ad6265SDimitry Andric 
215*81ad6265SDimitry Andric   for (const auto &Arg : Info.OrigArgs) {
216*81ad6265SDimitry Andric     // Currently call args should have single vregs.
217*81ad6265SDimitry Andric     if (Arg.Regs.size() > 1)
218*81ad6265SDimitry Andric       return false;
219*81ad6265SDimitry Andric     MIB.addUse(Arg.Regs[0]);
220*81ad6265SDimitry Andric   }
221*81ad6265SDimitry Andric   return MIB.constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(),
222*81ad6265SDimitry Andric                               *ST.getRegBankInfo());
223*81ad6265SDimitry Andric }
224