1 //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the SPIRVTargetLowering class. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "SPIRVISelLowering.h" 14 #include "SPIRV.h" 15 #include "llvm/IR/IntrinsicsSPIRV.h" 16 17 #define DEBUG_TYPE "spirv-lower" 18 19 using namespace llvm; 20 21 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv( 22 LLVMContext &Context, CallingConv::ID CC, EVT VT) const { 23 // This code avoids CallLowering fail inside getVectorTypeBreakdown 24 // on v3i1 arguments. Maybe we need to return 1 for all types. 25 // TODO: remove it once this case is supported by the default implementation. 26 if (VT.isVector() && VT.getVectorNumElements() == 3 && 27 (VT.getVectorElementType() == MVT::i1 || 28 VT.getVectorElementType() == MVT::i8)) 29 return 1; 30 if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64) 31 return 1; 32 return getNumRegisters(Context, VT); 33 } 34 35 MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, 36 CallingConv::ID CC, 37 EVT VT) const { 38 // This code avoids CallLowering fail inside getVectorTypeBreakdown 39 // on v3i1 arguments. Maybe we need to return i32 for all types. 40 // TODO: remove it once this case is supported by the default implementation. 41 if (VT.isVector() && VT.getVectorNumElements() == 3) { 42 if (VT.getVectorElementType() == MVT::i1) 43 return MVT::v4i1; 44 else if (VT.getVectorElementType() == MVT::i8) 45 return MVT::v4i8; 46 } 47 return getRegisterType(Context, VT); 48 } 49 50 bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, 51 const CallInst &I, 52 MachineFunction &MF, 53 unsigned Intrinsic) const { 54 unsigned AlignIdx = 3; 55 switch (Intrinsic) { 56 case Intrinsic::spv_load: 57 AlignIdx = 2; 58 [[fallthrough]]; 59 case Intrinsic::spv_store: { 60 if (I.getNumOperands() >= AlignIdx + 1) { 61 auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx)); 62 Info.align = Align(AlignOp->getZExtValue()); 63 } 64 Info.flags = static_cast<MachineMemOperand::Flags>( 65 cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue()); 66 Info.memVT = MVT::i64; 67 // TODO: take into account opaque pointers (don't use getElementType). 68 // MVT::getVT(PtrTy->getElementType()); 69 return true; 70 break; 71 } 72 default: 73 break; 74 } 75 return false; 76 } 77