xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (revision f126890ac5386406dadf7c4cfa9566cbb56537c5)
1 //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SPIRVTargetLowering class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVISelLowering.h"
14 #include "SPIRV.h"
15 #include "llvm/IR/IntrinsicsSPIRV.h"
16 
17 #define DEBUG_TYPE "spirv-lower"
18 
19 using namespace llvm;
20 
21 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
22     LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
23   // This code avoids CallLowering fail inside getVectorTypeBreakdown
24   // on v3i1 arguments. Maybe we need to return 1 for all types.
25   // TODO: remove it once this case is supported by the default implementation.
26   if (VT.isVector() && VT.getVectorNumElements() == 3 &&
27       (VT.getVectorElementType() == MVT::i1 ||
28        VT.getVectorElementType() == MVT::i8))
29     return 1;
30   if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
31     return 1;
32   return getNumRegisters(Context, VT);
33 }
34 
35 MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
36                                                        CallingConv::ID CC,
37                                                        EVT VT) const {
38   // This code avoids CallLowering fail inside getVectorTypeBreakdown
39   // on v3i1 arguments. Maybe we need to return i32 for all types.
40   // TODO: remove it once this case is supported by the default implementation.
41   if (VT.isVector() && VT.getVectorNumElements() == 3) {
42     if (VT.getVectorElementType() == MVT::i1)
43       return MVT::v4i1;
44     else if (VT.getVectorElementType() == MVT::i8)
45       return MVT::v4i8;
46   }
47   return getRegisterType(Context, VT);
48 }
49 
50 bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
51                                              const CallInst &I,
52                                              MachineFunction &MF,
53                                              unsigned Intrinsic) const {
54   unsigned AlignIdx = 3;
55   switch (Intrinsic) {
56   case Intrinsic::spv_load:
57     AlignIdx = 2;
58     [[fallthrough]];
59   case Intrinsic::spv_store: {
60     if (I.getNumOperands() >= AlignIdx + 1) {
61       auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
62       Info.align = Align(AlignOp->getZExtValue());
63     }
64     Info.flags = static_cast<MachineMemOperand::Flags>(
65         cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
66     Info.memVT = MVT::i64;
67     // TODO: take into account opaque pointers (don't use getElementType).
68     // MVT::getVT(PtrTy->getElementType());
69     return true;
70     break;
71   }
72   default:
73     break;
74   }
75   return false;
76 }
77