xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (revision 06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e)
181ad6265SDimitry Andric //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
281ad6265SDimitry Andric //
381ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
481ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
581ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
681ad6265SDimitry Andric //
781ad6265SDimitry Andric //===----------------------------------------------------------------------===//
881ad6265SDimitry Andric //
981ad6265SDimitry Andric // This file implements the SPIRVTargetLowering class.
1081ad6265SDimitry Andric //
1181ad6265SDimitry Andric //===----------------------------------------------------------------------===//
1281ad6265SDimitry Andric 
1381ad6265SDimitry Andric #include "SPIRVISelLowering.h"
1481ad6265SDimitry Andric #include "SPIRV.h"
15bdd1243dSDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h"
1681ad6265SDimitry Andric 
1781ad6265SDimitry Andric #define DEBUG_TYPE "spirv-lower"
1881ad6265SDimitry Andric 
1981ad6265SDimitry Andric using namespace llvm;
2081ad6265SDimitry Andric 
2181ad6265SDimitry Andric unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
2281ad6265SDimitry Andric     LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
2381ad6265SDimitry Andric   // This code avoids CallLowering fail inside getVectorTypeBreakdown
2481ad6265SDimitry Andric   // on v3i1 arguments. Maybe we need to return 1 for all types.
2581ad6265SDimitry Andric   // TODO: remove it once this case is supported by the default implementation.
2681ad6265SDimitry Andric   if (VT.isVector() && VT.getVectorNumElements() == 3 &&
2781ad6265SDimitry Andric       (VT.getVectorElementType() == MVT::i1 ||
2881ad6265SDimitry Andric        VT.getVectorElementType() == MVT::i8))
2981ad6265SDimitry Andric     return 1;
30*06c3fb27SDimitry Andric   if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
31*06c3fb27SDimitry Andric     return 1;
3281ad6265SDimitry Andric   return getNumRegisters(Context, VT);
3381ad6265SDimitry Andric }
3481ad6265SDimitry Andric 
3581ad6265SDimitry Andric MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
3681ad6265SDimitry Andric                                                        CallingConv::ID CC,
3781ad6265SDimitry Andric                                                        EVT VT) const {
3881ad6265SDimitry Andric   // This code avoids CallLowering fail inside getVectorTypeBreakdown
3981ad6265SDimitry Andric   // on v3i1 arguments. Maybe we need to return i32 for all types.
4081ad6265SDimitry Andric   // TODO: remove it once this case is supported by the default implementation.
4181ad6265SDimitry Andric   if (VT.isVector() && VT.getVectorNumElements() == 3) {
4281ad6265SDimitry Andric     if (VT.getVectorElementType() == MVT::i1)
4381ad6265SDimitry Andric       return MVT::v4i1;
4481ad6265SDimitry Andric     else if (VT.getVectorElementType() == MVT::i8)
4581ad6265SDimitry Andric       return MVT::v4i8;
4681ad6265SDimitry Andric   }
4781ad6265SDimitry Andric   return getRegisterType(Context, VT);
4881ad6265SDimitry Andric }
49bdd1243dSDimitry Andric 
50bdd1243dSDimitry Andric bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
51bdd1243dSDimitry Andric                                              const CallInst &I,
52bdd1243dSDimitry Andric                                              MachineFunction &MF,
53bdd1243dSDimitry Andric                                              unsigned Intrinsic) const {
54bdd1243dSDimitry Andric   unsigned AlignIdx = 3;
55bdd1243dSDimitry Andric   switch (Intrinsic) {
56bdd1243dSDimitry Andric   case Intrinsic::spv_load:
57bdd1243dSDimitry Andric     AlignIdx = 2;
58*06c3fb27SDimitry Andric     [[fallthrough]];
59bdd1243dSDimitry Andric   case Intrinsic::spv_store: {
60bdd1243dSDimitry Andric     if (I.getNumOperands() >= AlignIdx + 1) {
61bdd1243dSDimitry Andric       auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
62bdd1243dSDimitry Andric       Info.align = Align(AlignOp->getZExtValue());
63bdd1243dSDimitry Andric     }
64bdd1243dSDimitry Andric     Info.flags = static_cast<MachineMemOperand::Flags>(
65bdd1243dSDimitry Andric         cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
66bdd1243dSDimitry Andric     Info.memVT = MVT::i64;
67bdd1243dSDimitry Andric     // TODO: take into account opaque pointers (don't use getElementType).
68bdd1243dSDimitry Andric     // MVT::getVT(PtrTy->getElementType());
69bdd1243dSDimitry Andric     return true;
70bdd1243dSDimitry Andric     break;
71bdd1243dSDimitry Andric   }
72bdd1243dSDimitry Andric   default:
73bdd1243dSDimitry Andric     break;
74bdd1243dSDimitry Andric   }
75bdd1243dSDimitry Andric   return false;
76bdd1243dSDimitry Andric }
77