xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (revision b3e7694832e81d7a904a10f525f8797b753bf0d3)
1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22 
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26 
27 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28     TargetOpcode::G_ADD,
29     TargetOpcode::G_FADD,
30     TargetOpcode::G_SUB,
31     TargetOpcode::G_FSUB,
32     TargetOpcode::G_MUL,
33     TargetOpcode::G_FMUL,
34     TargetOpcode::G_SDIV,
35     TargetOpcode::G_UDIV,
36     TargetOpcode::G_FDIV,
37     TargetOpcode::G_SREM,
38     TargetOpcode::G_UREM,
39     TargetOpcode::G_FREM,
40     TargetOpcode::G_FNEG,
41     TargetOpcode::G_CONSTANT,
42     TargetOpcode::G_FCONSTANT,
43     TargetOpcode::G_AND,
44     TargetOpcode::G_OR,
45     TargetOpcode::G_XOR,
46     TargetOpcode::G_SHL,
47     TargetOpcode::G_ASHR,
48     TargetOpcode::G_LSHR,
49     TargetOpcode::G_SELECT,
50     TargetOpcode::G_EXTRACT_VECTOR_ELT,
51 };
52 
53 bool isTypeFoldingSupported(unsigned Opcode) {
54   return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55 }
56 
57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
58   using namespace TargetOpcode;
59 
60   this->ST = &ST;
61   GR = ST.getSPIRVGlobalRegistry();
62 
63   const LLT s1 = LLT::scalar(1);
64   const LLT s8 = LLT::scalar(8);
65   const LLT s16 = LLT::scalar(16);
66   const LLT s32 = LLT::scalar(32);
67   const LLT s64 = LLT::scalar(64);
68 
69   const LLT v16s64 = LLT::fixed_vector(16, 64);
70   const LLT v16s32 = LLT::fixed_vector(16, 32);
71   const LLT v16s16 = LLT::fixed_vector(16, 16);
72   const LLT v16s8 = LLT::fixed_vector(16, 8);
73   const LLT v16s1 = LLT::fixed_vector(16, 1);
74 
75   const LLT v8s64 = LLT::fixed_vector(8, 64);
76   const LLT v8s32 = LLT::fixed_vector(8, 32);
77   const LLT v8s16 = LLT::fixed_vector(8, 16);
78   const LLT v8s8 = LLT::fixed_vector(8, 8);
79   const LLT v8s1 = LLT::fixed_vector(8, 1);
80 
81   const LLT v4s64 = LLT::fixed_vector(4, 64);
82   const LLT v4s32 = LLT::fixed_vector(4, 32);
83   const LLT v4s16 = LLT::fixed_vector(4, 16);
84   const LLT v4s8 = LLT::fixed_vector(4, 8);
85   const LLT v4s1 = LLT::fixed_vector(4, 1);
86 
87   const LLT v3s64 = LLT::fixed_vector(3, 64);
88   const LLT v3s32 = LLT::fixed_vector(3, 32);
89   const LLT v3s16 = LLT::fixed_vector(3, 16);
90   const LLT v3s8 = LLT::fixed_vector(3, 8);
91   const LLT v3s1 = LLT::fixed_vector(3, 1);
92 
93   const LLT v2s64 = LLT::fixed_vector(2, 64);
94   const LLT v2s32 = LLT::fixed_vector(2, 32);
95   const LLT v2s16 = LLT::fixed_vector(2, 16);
96   const LLT v2s8 = LLT::fixed_vector(2, 8);
97   const LLT v2s1 = LLT::fixed_vector(2, 1);
98 
99   const unsigned PSize = ST.getPointerSize();
100   const LLT p0 = LLT::pointer(0, PSize); // Function
101   const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102   const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103   const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104   const LLT p4 = LLT::pointer(4, PSize); // Generic
105   const LLT p5 = LLT::pointer(5, PSize); // Input
106 
107   // TODO: remove copy-pasting here by using concatenation in some way.
108   auto allPtrsScalarsAndVectors = {
109       p0,    p1,    p2,    p3,    p4,    p5,    s1,     s8,     s16,
110       s32,   s64,   v2s1,  v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,
111       v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,
112       v8s8,  v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113 
114   auto allScalarsAndVectors = {
115       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
116       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
117       v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
118 
119   auto allIntScalarsAndVectors = {s8,    s16,   s32,   s64,    v2s8,   v2s16,
120                                   v2s32, v2s64, v3s8,  v3s16,  v3s32,  v3s64,
121                                   v4s8,  v4s16, v4s32, v4s64,  v8s8,   v8s16,
122                                   v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
123 
124   auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
125 
126   auto allIntScalars = {s8, s16, s32, s64};
127 
128   auto allFloatScalarsAndVectors = {
129       s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
130       v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
131 
132   auto allFloatAndIntScalars = allIntScalars;
133 
134   auto allPtrs = {p0, p1, p2, p3, p4, p5};
135   auto allWritablePtrs = {p0, p1, p3, p4};
136 
137   for (auto Opc : TypeFoldingSupportingOpcs)
138     getActionDefinitionsBuilder(Opc).custom();
139 
140   getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
141 
142   // TODO: add proper rules for vectors legalization.
143   getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
144 
145   getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
146       .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
147 
148   getActionDefinitionsBuilder(G_MEMSET).legalIf(
149       all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
150 
151   getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
152       .legalForCartesianProduct(allPtrs, allPtrs);
153 
154   getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
155 
156   getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
157 
158   getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
159 
160   getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
161       .legalForCartesianProduct(allIntScalarsAndVectors,
162                                 allFloatScalarsAndVectors);
163 
164   getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
165       .legalForCartesianProduct(allFloatScalarsAndVectors,
166                                 allScalarsAndVectors);
167 
168   getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
169       .legalFor(allIntScalarsAndVectors);
170 
171   getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
172       allIntScalarsAndVectors, allIntScalarsAndVectors);
173 
174   getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
175 
176   getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
177       typeInSet(0, allPtrsScalarsAndVectors),
178       typeInSet(1, allPtrsScalarsAndVectors),
179       LegalityPredicate(([=](const LegalityQuery &Query) {
180         return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
181       }))));
182 
183   getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
184 
185   getActionDefinitionsBuilder(G_INTTOPTR)
186       .legalForCartesianProduct(allPtrs, allIntScalars);
187   getActionDefinitionsBuilder(G_PTRTOINT)
188       .legalForCartesianProduct(allIntScalars, allPtrs);
189   getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
190       allPtrs, allIntScalars);
191 
192   // ST.canDirectlyComparePointers() for pointer args is supported in
193   // legalizeCustom().
194   getActionDefinitionsBuilder(G_ICMP).customIf(
195       all(typeInSet(0, allBoolScalarsAndVectors),
196           typeInSet(1, allPtrsScalarsAndVectors)));
197 
198   getActionDefinitionsBuilder(G_FCMP).legalIf(
199       all(typeInSet(0, allBoolScalarsAndVectors),
200           typeInSet(1, allFloatScalarsAndVectors)));
201 
202   getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
203                                G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
204                                G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
205                                G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
206       .legalForCartesianProduct(allIntScalars, allWritablePtrs);
207 
208   getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
209       .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
210 
211   getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
212   // TODO: add proper legalization rules.
213   getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
214 
215   getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
216       .alwaysLegal();
217 
218   // Extensions.
219   getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
220       .legalForCartesianProduct(allScalarsAndVectors);
221 
222   // FP conversions.
223   getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
224       .legalForCartesianProduct(allFloatScalarsAndVectors);
225 
226   // Pointer-handling.
227   getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
228 
229   // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
230   getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
231 
232   getActionDefinitionsBuilder({G_FPOW,
233                                G_FEXP,
234                                G_FEXP2,
235                                G_FLOG,
236                                G_FLOG2,
237                                G_FABS,
238                                G_FMINNUM,
239                                G_FMAXNUM,
240                                G_FCEIL,
241                                G_FCOS,
242                                G_FSIN,
243                                G_FSQRT,
244                                G_FFLOOR,
245                                G_FRINT,
246                                G_FNEARBYINT,
247                                G_INTRINSIC_ROUND,
248                                G_INTRINSIC_TRUNC,
249                                G_FMINIMUM,
250                                G_FMAXIMUM,
251                                G_INTRINSIC_ROUNDEVEN})
252       .legalFor(allFloatScalarsAndVectors);
253 
254   getActionDefinitionsBuilder(G_FCOPYSIGN)
255       .legalForCartesianProduct(allFloatScalarsAndVectors,
256                                 allFloatScalarsAndVectors);
257 
258   getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
259       allFloatScalarsAndVectors, allIntScalarsAndVectors);
260 
261   if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
262     getActionDefinitionsBuilder(G_FLOG10).legalFor(allFloatScalarsAndVectors);
263 
264     getActionDefinitionsBuilder(
265         {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
266         .legalForCartesianProduct(allIntScalarsAndVectors,
267                                   allIntScalarsAndVectors);
268 
269     // Struct return types become a single scalar, so cannot easily legalize.
270     getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
271   }
272 
273   getLegacyLegalizerInfo().computeTables();
274   verify(*ST.getInstrInfo());
275 }
276 
277 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
278                                 LegalizerHelper &Helper,
279                                 MachineRegisterInfo &MRI,
280                                 SPIRVGlobalRegistry *GR) {
281   Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
282   GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
283   Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
284       .addDef(ConvReg)
285       .addUse(Reg);
286   return ConvReg;
287 }
288 
289 bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper,
290                                         MachineInstr &MI) const {
291   auto Opc = MI.getOpcode();
292   MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
293   if (!isTypeFoldingSupported(Opc)) {
294     assert(Opc == TargetOpcode::G_ICMP);
295     assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
296     auto &Op0 = MI.getOperand(2);
297     auto &Op1 = MI.getOperand(3);
298     Register Reg0 = Op0.getReg();
299     Register Reg1 = Op1.getReg();
300     CmpInst::Predicate Cond =
301         static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
302     if ((!ST->canDirectlyComparePointers() ||
303          (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
304         MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
305       LLT ConvT = LLT::scalar(ST->getPointerSize());
306       Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
307                                       ST->getPointerSize());
308       SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
309       Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
310       Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
311     }
312     return true;
313   }
314   // TODO: implement legalization for other opcodes.
315   return true;
316 }
317