1 //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the SPIRVTargetLowering class. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "SPIRVISelLowering.h" 14 #include "SPIRV.h" 15 #include "llvm/IR/IntrinsicsSPIRV.h" 16 17 #define DEBUG_TYPE "spirv-lower" 18 19 using namespace llvm; 20 21 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv( 22 LLVMContext &Context, CallingConv::ID CC, EVT VT) const { 23 // This code avoids CallLowering fail inside getVectorTypeBreakdown 24 // on v3i1 arguments. Maybe we need to return 1 for all types. 25 // TODO: remove it once this case is supported by the default implementation. 26 if (VT.isVector() && VT.getVectorNumElements() == 3 && 27 (VT.getVectorElementType() == MVT::i1 || 28 VT.getVectorElementType() == MVT::i8)) 29 return 1; 30 return getNumRegisters(Context, VT); 31 } 32 33 MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, 34 CallingConv::ID CC, 35 EVT VT) const { 36 // This code avoids CallLowering fail inside getVectorTypeBreakdown 37 // on v3i1 arguments. Maybe we need to return i32 for all types. 38 // TODO: remove it once this case is supported by the default implementation. 39 if (VT.isVector() && VT.getVectorNumElements() == 3) { 40 if (VT.getVectorElementType() == MVT::i1) 41 return MVT::v4i1; 42 else if (VT.getVectorElementType() == MVT::i8) 43 return MVT::v4i8; 44 } 45 return getRegisterType(Context, VT); 46 } 47 48 bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, 49 const CallInst &I, 50 MachineFunction &MF, 51 unsigned Intrinsic) const { 52 unsigned AlignIdx = 3; 53 switch (Intrinsic) { 54 case Intrinsic::spv_load: 55 AlignIdx = 2; 56 LLVM_FALLTHROUGH; 57 case Intrinsic::spv_store: { 58 if (I.getNumOperands() >= AlignIdx + 1) { 59 auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx)); 60 Info.align = Align(AlignOp->getZExtValue()); 61 } 62 Info.flags = static_cast<MachineMemOperand::Flags>( 63 cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue()); 64 Info.memVT = MVT::i64; 65 // TODO: take into account opaque pointers (don't use getElementType). 66 // MVT::getVT(PtrTy->getElementType()); 67 return true; 68 break; 69 } 70 default: 71 break; 72 } 73 return false; 74 } 75