xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (revision 9f23cbd6cae82fd77edfad7173432fa8dccd0a95)
1 //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIRVTargetLowering class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVISelLowering.h"
14 #include "SPIRV.h"
15 #include "llvm/IR/IntrinsicsSPIRV.h"
16 
17 #define DEBUG_TYPE "spirv-lower"
18 
19 using namespace llvm;
20 
21 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
22     LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
23   // This code avoids CallLowering fail inside getVectorTypeBreakdown
24   // on v3i1 arguments. Maybe we need to return 1 for all types.
25   // TODO: remove it once this case is supported by the default implementation.
26   if (VT.isVector() && VT.getVectorNumElements() == 3 &&
27       (VT.getVectorElementType() == MVT::i1 ||
28        VT.getVectorElementType() == MVT::i8))
29     return 1;
30   return getNumRegisters(Context, VT);
31 }
32 
33 MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
34                                                        CallingConv::ID CC,
35                                                        EVT VT) const {
36   // This code avoids CallLowering fail inside getVectorTypeBreakdown
37   // on v3i1 arguments. Maybe we need to return i32 for all types.
38   // TODO: remove it once this case is supported by the default implementation.
39   if (VT.isVector() && VT.getVectorNumElements() == 3) {
40     if (VT.getVectorElementType() == MVT::i1)
41       return MVT::v4i1;
42     else if (VT.getVectorElementType() == MVT::i8)
43       return MVT::v4i8;
44   }
45   return getRegisterType(Context, VT);
46 }
47 
48 bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
49                                              const CallInst &I,
50                                              MachineFunction &MF,
51                                              unsigned Intrinsic) const {
52   unsigned AlignIdx = 3;
53   switch (Intrinsic) {
54   case Intrinsic::spv_load:
55     AlignIdx = 2;
56     LLVM_FALLTHROUGH;
57   case Intrinsic::spv_store: {
58     if (I.getNumOperands() >= AlignIdx + 1) {
59       auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
60       Info.align = Align(AlignOp->getZExtValue());
61     }
62     Info.flags = static_cast<MachineMemOperand::Flags>(
63         cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
64     Info.memVT = MVT::i64;
65     // TODO: take into account opaque pointers (don't use getElementType).
66     // MVT::getVT(PtrTy->getElementType());
67     return true;
68     break;
69   }
70   default:
71     break;
72   }
73   return false;
74 }
75