1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the lowering of LLVM calls to machine code calls for 10 // GlobalISel. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "SPIRVCallLowering.h" 15 #include "MCTargetDesc/SPIRVBaseInfo.h" 16 #include "SPIRV.h" 17 #include "SPIRVGlobalRegistry.h" 18 #include "SPIRVISelLowering.h" 19 #include "SPIRVRegisterInfo.h" 20 #include "SPIRVSubtarget.h" 21 #include "SPIRVUtils.h" 22 #include "llvm/CodeGen/FunctionLoweringInfo.h" 23 24 using namespace llvm; 25 26 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI, 27 SPIRVGlobalRegistry *GR) 28 : CallLowering(&TLI), GR(GR) {} 29 30 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, 31 const Value *Val, ArrayRef<Register> VRegs, 32 FunctionLoweringInfo &FLI, 33 Register SwiftErrorVReg) const { 34 // Currently all return types should use a single register. 35 // TODO: handle the case of multiple registers. 36 if (VRegs.size() > 1) 37 return false; 38 if (Val) { 39 const auto &STI = MIRBuilder.getMF().getSubtarget(); 40 return MIRBuilder.buildInstr(SPIRV::OpReturnValue) 41 .addUse(VRegs[0]) 42 .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(), 43 *STI.getRegBankInfo()); 44 } 45 MIRBuilder.buildInstr(SPIRV::OpReturn); 46 return true; 47 } 48 49 // Based on the LLVM function attributes, get a SPIR-V FunctionControl. 50 static uint32_t getFunctionControl(const Function &F) { 51 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None); 52 if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) { 53 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline); 54 } 55 if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) { 56 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure); 57 } 58 if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) { 59 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const); 60 } 61 if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) { 62 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline); 63 } 64 return FuncControl; 65 } 66 67 static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) { 68 if (MD->getNumOperands() > NumOp) { 69 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp)); 70 if (CMeta) 71 return dyn_cast<ConstantInt>(CMeta->getValue()); 72 } 73 return nullptr; 74 } 75 76 // This code restores function args/retvalue types for composite cases 77 // because the final types should still be aggregate whereas they're i32 78 // during the translation to cope with aggregate flattening etc. 79 static FunctionType *getOriginalFunctionType(const Function &F) { 80 auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs"); 81 if (NamedMD == nullptr) 82 return F.getFunctionType(); 83 84 Type *RetTy = F.getFunctionType()->getReturnType(); 85 SmallVector<Type *, 4> ArgTypes; 86 for (auto &Arg : F.args()) 87 ArgTypes.push_back(Arg.getType()); 88 89 auto ThisFuncMDIt = 90 std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) { 91 return isa<MDString>(N->getOperand(0)) && 92 cast<MDString>(N->getOperand(0))->getString() == F.getName(); 93 }); 94 // TODO: probably one function can have numerous type mutations, 95 // so we should support this. 96 if (ThisFuncMDIt != NamedMD->op_end()) { 97 auto *ThisFuncMD = *ThisFuncMDIt; 98 MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1)); 99 assert(MD && "MDNode operand is expected"); 100 ConstantInt *Const = getConstInt(MD, 0); 101 if (Const) { 102 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1)); 103 assert(CMeta && "ConstantAsMetadata operand is expected"); 104 assert(Const->getSExtValue() >= -1); 105 // Currently -1 indicates return value, greater values mean 106 // argument numbers. 107 if (Const->getSExtValue() == -1) 108 RetTy = CMeta->getType(); 109 else 110 ArgTypes[Const->getSExtValue()] = CMeta->getType(); 111 } 112 } 113 114 return FunctionType::get(RetTy, ArgTypes, F.isVarArg()); 115 } 116 117 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, 118 const Function &F, 119 ArrayRef<ArrayRef<Register>> VRegs, 120 FunctionLoweringInfo &FLI) const { 121 assert(GR && "Must initialize the SPIRV type registry before lowering args."); 122 GR->setCurrentFunc(MIRBuilder.getMF()); 123 124 // Assign types and names to all args, and store their types for later. 125 FunctionType *FTy = getOriginalFunctionType(F); 126 SmallVector<SPIRVType *, 4> ArgTypeVRegs; 127 if (VRegs.size() > 0) { 128 unsigned i = 0; 129 for (const auto &Arg : F.args()) { 130 // Currently formal args should use single registers. 131 // TODO: handle the case of multiple registers. 132 if (VRegs[i].size() > 1) 133 return false; 134 Type *ArgTy = FTy->getParamType(i); 135 SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite; 136 MDNode *Node = F.getMetadata("kernel_arg_access_qual"); 137 if (Node && i < Node->getNumOperands()) { 138 StringRef AQString = cast<MDString>(Node->getOperand(i))->getString(); 139 if (AQString.compare("read_only") == 0) 140 AQ = SPIRV::AccessQualifier::ReadOnly; 141 else if (AQString.compare("write_only") == 0) 142 AQ = SPIRV::AccessQualifier::WriteOnly; 143 } 144 auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ); 145 ArgTypeVRegs.push_back(SpirvTy); 146 147 if (Arg.hasName()) 148 buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder); 149 if (Arg.getType()->isPointerTy()) { 150 auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes()); 151 if (DerefBytes != 0) 152 buildOpDecorate(VRegs[i][0], MIRBuilder, 153 SPIRV::Decoration::MaxByteOffset, {DerefBytes}); 154 } 155 if (Arg.hasAttribute(Attribute::Alignment)) { 156 auto Alignment = static_cast<unsigned>( 157 Arg.getAttribute(Attribute::Alignment).getValueAsInt()); 158 buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment, 159 {Alignment}); 160 } 161 if (Arg.hasAttribute(Attribute::ReadOnly)) { 162 auto Attr = 163 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite); 164 buildOpDecorate(VRegs[i][0], MIRBuilder, 165 SPIRV::Decoration::FuncParamAttr, {Attr}); 166 } 167 if (Arg.hasAttribute(Attribute::ZExt)) { 168 auto Attr = 169 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext); 170 buildOpDecorate(VRegs[i][0], MIRBuilder, 171 SPIRV::Decoration::FuncParamAttr, {Attr}); 172 } 173 if (Arg.hasAttribute(Attribute::NoAlias)) { 174 auto Attr = 175 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias); 176 buildOpDecorate(VRegs[i][0], MIRBuilder, 177 SPIRV::Decoration::FuncParamAttr, {Attr}); 178 } 179 Node = F.getMetadata("kernel_arg_type_qual"); 180 if (Node && i < Node->getNumOperands()) { 181 StringRef TypeQual = cast<MDString>(Node->getOperand(i))->getString(); 182 if (TypeQual.compare("volatile") == 0) 183 buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile, 184 {}); 185 } 186 Node = F.getMetadata("spirv.ParameterDecorations"); 187 if (Node && i < Node->getNumOperands() && 188 isa<MDNode>(Node->getOperand(i))) { 189 MDNode *MD = cast<MDNode>(Node->getOperand(i)); 190 for (const MDOperand &MDOp : MD->operands()) { 191 MDNode *MD2 = dyn_cast<MDNode>(MDOp); 192 assert(MD2 && "Metadata operand is expected"); 193 ConstantInt *Const = getConstInt(MD2, 0); 194 assert(Const && "MDOperand should be ConstantInt"); 195 auto Dec = static_cast<SPIRV::Decoration>(Const->getZExtValue()); 196 std::vector<uint32_t> DecVec; 197 for (unsigned j = 1; j < MD2->getNumOperands(); j++) { 198 ConstantInt *Const = getConstInt(MD2, j); 199 assert(Const && "MDOperand should be ConstantInt"); 200 DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue())); 201 } 202 buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec); 203 } 204 } 205 ++i; 206 } 207 } 208 209 // Generate a SPIR-V type for the function. 210 auto MRI = MIRBuilder.getMRI(); 211 Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); 212 MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); 213 if (F.isDeclaration()) 214 GR->add(&F, &MIRBuilder.getMF(), FuncVReg); 215 SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder); 216 SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs( 217 FTy, RetTy, ArgTypeVRegs, MIRBuilder); 218 219 // Build the OpTypeFunction declaring it. 220 uint32_t FuncControl = getFunctionControl(F); 221 222 MIRBuilder.buildInstr(SPIRV::OpFunction) 223 .addDef(FuncVReg) 224 .addUse(GR->getSPIRVTypeID(RetTy)) 225 .addImm(FuncControl) 226 .addUse(GR->getSPIRVTypeID(FuncTy)); 227 228 // Add OpFunctionParameters. 229 int i = 0; 230 for (const auto &Arg : F.args()) { 231 assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs"); 232 MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass); 233 MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) 234 .addDef(VRegs[i][0]) 235 .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); 236 if (F.isDeclaration()) 237 GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]); 238 i++; 239 } 240 // Name the function. 241 if (F.hasName()) 242 buildOpName(FuncVReg, F.getName(), MIRBuilder); 243 244 // Handle entry points and function linkage. 245 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { 246 auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint) 247 .addImm(static_cast<uint32_t>(SPIRV::ExecutionModel::Kernel)) 248 .addUse(FuncVReg); 249 addStringImm(F.getName(), MIB); 250 } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage || 251 F.getLinkage() == GlobalValue::LinkOnceODRLinkage) { 252 auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import 253 : SPIRV::LinkageType::Export; 254 buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 255 {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier()); 256 } 257 258 return true; 259 } 260 261 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, 262 CallLoweringInfo &Info) const { 263 // Currently call returns should have single vregs. 264 // TODO: handle the case of multiple registers. 265 if (Info.OrigRet.Regs.size() > 1) 266 return false; 267 MachineFunction &MF = MIRBuilder.getMF(); 268 GR->setCurrentFunc(MF); 269 FunctionType *FTy = nullptr; 270 const Function *CF = nullptr; 271 272 // Emit a regular OpFunctionCall. If it's an externally declared function, 273 // be sure to emit its type and function declaration here. It will be hoisted 274 // globally later. 275 if (Info.Callee.isGlobal()) { 276 CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal()); 277 // TODO: support constexpr casts and indirect calls. 278 if (CF == nullptr) 279 return false; 280 FTy = getOriginalFunctionType(*CF); 281 } 282 283 Register ResVReg = 284 Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; 285 if (CF && CF->isDeclaration() && 286 !GR->find(CF, &MIRBuilder.getMF()).isValid()) { 287 // Emit the type info and forward function declaration to the first MBB 288 // to ensure VReg definition dependencies are valid across all MBBs. 289 MachineIRBuilder FirstBlockBuilder; 290 FirstBlockBuilder.setMF(MF); 291 FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0)); 292 293 SmallVector<ArrayRef<Register>, 8> VRegArgs; 294 SmallVector<SmallVector<Register, 1>, 8> ToInsert; 295 for (const Argument &Arg : CF->args()) { 296 if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) 297 continue; // Don't handle zero sized types. 298 ToInsert.push_back( 299 {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))}); 300 VRegArgs.push_back(ToInsert.back()); 301 } 302 // TODO: Reuse FunctionLoweringInfo 303 FunctionLoweringInfo FuncInfo; 304 lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo); 305 } 306 307 // Make sure there's a valid return reg, even for functions returning void. 308 if (!ResVReg.isValid()) 309 ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); 310 SPIRVType *RetType = 311 GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder); 312 313 // Emit the OpFunctionCall and its args. 314 auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall) 315 .addDef(ResVReg) 316 .addUse(GR->getSPIRVTypeID(RetType)) 317 .add(Info.Callee); 318 319 for (const auto &Arg : Info.OrigArgs) { 320 // Currently call args should have single vregs. 321 if (Arg.Regs.size() > 1) 322 return false; 323 MIB.addUse(Arg.Regs[0]); 324 } 325 const auto &STI = MF.getSubtarget(); 326 return MIB.constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(), 327 *STI.getRegBankInfo()); 328 } 329