xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (revision 357378bbdedf24ce2b90e9bd831af4a9db3ec70a)
1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the GlobalISel legalization rules (LegalizerInfo) for the SPIR-V target.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22 
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26 
27 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28     TargetOpcode::G_ADD,
29     TargetOpcode::G_FADD,
30     TargetOpcode::G_SUB,
31     TargetOpcode::G_FSUB,
32     TargetOpcode::G_MUL,
33     TargetOpcode::G_FMUL,
34     TargetOpcode::G_SDIV,
35     TargetOpcode::G_UDIV,
36     TargetOpcode::G_FDIV,
37     TargetOpcode::G_SREM,
38     TargetOpcode::G_UREM,
39     TargetOpcode::G_FREM,
40     TargetOpcode::G_FNEG,
41     TargetOpcode::G_CONSTANT,
42     TargetOpcode::G_FCONSTANT,
43     TargetOpcode::G_AND,
44     TargetOpcode::G_OR,
45     TargetOpcode::G_XOR,
46     TargetOpcode::G_SHL,
47     TargetOpcode::G_ASHR,
48     TargetOpcode::G_LSHR,
49     TargetOpcode::G_SELECT,
50     TargetOpcode::G_EXTRACT_VECTOR_ELT,
51 };
52 
53 bool isTypeFoldingSupported(unsigned Opcode) {
54   return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55 }
56 
57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
58   using namespace TargetOpcode;
59 
60   this->ST = &ST;
61   GR = ST.getSPIRVGlobalRegistry();
62 
63   const LLT s1 = LLT::scalar(1);
64   const LLT s8 = LLT::scalar(8);
65   const LLT s16 = LLT::scalar(16);
66   const LLT s32 = LLT::scalar(32);
67   const LLT s64 = LLT::scalar(64);
68 
69   const LLT v16s64 = LLT::fixed_vector(16, 64);
70   const LLT v16s32 = LLT::fixed_vector(16, 32);
71   const LLT v16s16 = LLT::fixed_vector(16, 16);
72   const LLT v16s8 = LLT::fixed_vector(16, 8);
73   const LLT v16s1 = LLT::fixed_vector(16, 1);
74 
75   const LLT v8s64 = LLT::fixed_vector(8, 64);
76   const LLT v8s32 = LLT::fixed_vector(8, 32);
77   const LLT v8s16 = LLT::fixed_vector(8, 16);
78   const LLT v8s8 = LLT::fixed_vector(8, 8);
79   const LLT v8s1 = LLT::fixed_vector(8, 1);
80 
81   const LLT v4s64 = LLT::fixed_vector(4, 64);
82   const LLT v4s32 = LLT::fixed_vector(4, 32);
83   const LLT v4s16 = LLT::fixed_vector(4, 16);
84   const LLT v4s8 = LLT::fixed_vector(4, 8);
85   const LLT v4s1 = LLT::fixed_vector(4, 1);
86 
87   const LLT v3s64 = LLT::fixed_vector(3, 64);
88   const LLT v3s32 = LLT::fixed_vector(3, 32);
89   const LLT v3s16 = LLT::fixed_vector(3, 16);
90   const LLT v3s8 = LLT::fixed_vector(3, 8);
91   const LLT v3s1 = LLT::fixed_vector(3, 1);
92 
93   const LLT v2s64 = LLT::fixed_vector(2, 64);
94   const LLT v2s32 = LLT::fixed_vector(2, 32);
95   const LLT v2s16 = LLT::fixed_vector(2, 16);
96   const LLT v2s8 = LLT::fixed_vector(2, 8);
97   const LLT v2s1 = LLT::fixed_vector(2, 1);
98 
99   const unsigned PSize = ST.getPointerSize();
100   const LLT p0 = LLT::pointer(0, PSize); // Function
101   const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102   const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103   const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104   const LLT p4 = LLT::pointer(4, PSize); // Generic
105   const LLT p5 = LLT::pointer(5, PSize); // Input
106 
107   // TODO: remove copy-pasting here by using concatenation in some way.
108   auto allPtrsScalarsAndVectors = {
109       p0,    p1,    p2,    p3,    p4,    p5,    s1,     s8,     s16,
110       s32,   s64,   v2s1,  v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,
111       v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,
112       v8s8,  v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113 
114   auto allScalarsAndVectors = {
115       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
116       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
117       v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
118 
119   auto allIntScalarsAndVectors = {s8,    s16,   s32,   s64,    v2s8,   v2s16,
120                                   v2s32, v2s64, v3s8,  v3s16,  v3s32,  v3s64,
121                                   v4s8,  v4s16, v4s32, v4s64,  v8s8,   v8s16,
122                                   v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
123 
124   auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
125 
126   auto allIntScalars = {s8, s16, s32, s64};
127 
128   auto allFloatScalarsAndVectors = {
129       s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
130       v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
131 
132   auto allFloatAndIntScalars = allIntScalars;
133 
134   auto allPtrs = {p0, p1, p2, p3, p4, p5};
135   auto allWritablePtrs = {p0, p1, p3, p4};
136 
137   for (auto Opc : TypeFoldingSupportingOpcs)
138     getActionDefinitionsBuilder(Opc).custom();
139 
140   getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
141 
142   // TODO: add proper rules for vectors legalization.
143   getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
144 
145   getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
146       .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
147 
148   getActionDefinitionsBuilder(G_MEMSET).legalIf(
149       all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
150 
151   getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
152       .legalForCartesianProduct(allPtrs, allPtrs);
153 
154   getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
155 
156   getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
157 
158   getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
159 
160   getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
161       .legalForCartesianProduct(allIntScalarsAndVectors,
162                                 allFloatScalarsAndVectors);
163 
164   getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
165       .legalForCartesianProduct(allFloatScalarsAndVectors,
166                                 allScalarsAndVectors);
167 
168   getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
169       .legalFor(allIntScalarsAndVectors);
170 
171   getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
172       allIntScalarsAndVectors, allIntScalarsAndVectors);
173 
174   getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
175 
176   getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
177       typeInSet(0, allPtrsScalarsAndVectors),
178       typeInSet(1, allPtrsScalarsAndVectors),
179       LegalityPredicate(([=](const LegalityQuery &Query) {
180         return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
181       }))));
182 
183   getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
184 
185   getActionDefinitionsBuilder(G_INTTOPTR)
186       .legalForCartesianProduct(allPtrs, allIntScalars);
187   getActionDefinitionsBuilder(G_PTRTOINT)
188       .legalForCartesianProduct(allIntScalars, allPtrs);
189   getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
190       allPtrs, allIntScalars);
191 
192   // ST.canDirectlyComparePointers() for pointer args is supported in
193   // legalizeCustom().
194   getActionDefinitionsBuilder(G_ICMP).customIf(
195       all(typeInSet(0, allBoolScalarsAndVectors),
196           typeInSet(1, allPtrsScalarsAndVectors)));
197 
198   getActionDefinitionsBuilder(G_FCMP).legalIf(
199       all(typeInSet(0, allBoolScalarsAndVectors),
200           typeInSet(1, allFloatScalarsAndVectors)));
201 
202   getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
203                                G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
204                                G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
205                                G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
206       .legalForCartesianProduct(allIntScalars, allWritablePtrs);
207 
208   getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
209       .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
210 
211   getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
212   // TODO: add proper legalization rules.
213   getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
214 
215   getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
216       .alwaysLegal();
217 
218   // Extensions.
219   getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
220       .legalForCartesianProduct(allScalarsAndVectors);
221 
222   // FP conversions.
223   getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
224       .legalForCartesianProduct(allFloatScalarsAndVectors);
225 
226   // Pointer-handling.
227   getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
228 
229   // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
230   getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
231 
232   // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
233   // tighten these requirements. Many of these math functions are only legal on
234   // specific bitwidths, so they are not selectable for
235   // allFloatScalarsAndVectors.
236   getActionDefinitionsBuilder({G_FPOW,
237                                G_FEXP,
238                                G_FEXP2,
239                                G_FLOG,
240                                G_FLOG2,
241                                G_FLOG10,
242                                G_FABS,
243                                G_FMINNUM,
244                                G_FMAXNUM,
245                                G_FCEIL,
246                                G_FCOS,
247                                G_FSIN,
248                                G_FSQRT,
249                                G_FFLOOR,
250                                G_FRINT,
251                                G_FNEARBYINT,
252                                G_INTRINSIC_ROUND,
253                                G_INTRINSIC_TRUNC,
254                                G_FMINIMUM,
255                                G_FMAXIMUM,
256                                G_INTRINSIC_ROUNDEVEN})
257       .legalFor(allFloatScalarsAndVectors);
258 
259   getActionDefinitionsBuilder(G_FCOPYSIGN)
260       .legalForCartesianProduct(allFloatScalarsAndVectors,
261                                 allFloatScalarsAndVectors);
262 
263   getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
264       allFloatScalarsAndVectors, allIntScalarsAndVectors);
265 
266   if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
267     getActionDefinitionsBuilder(
268         {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
269         .legalForCartesianProduct(allIntScalarsAndVectors,
270                                   allIntScalarsAndVectors);
271 
272     // Struct return types become a single scalar, so cannot easily legalize.
273     getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
274   }
275 
276   getLegacyLegalizerInfo().computeTables();
277   verify(*ST.getInstrInfo());
278 }
279 
280 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
281                                 LegalizerHelper &Helper,
282                                 MachineRegisterInfo &MRI,
283                                 SPIRVGlobalRegistry *GR) {
284   Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
285   GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
286   Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
287       .addDef(ConvReg)
288       .addUse(Reg);
289   return ConvReg;
290 }
291 
292 bool SPIRVLegalizerInfo::legalizeCustom(
293     LegalizerHelper &Helper, MachineInstr &MI,
294     LostDebugLocObserver &LocObserver) const {
295   auto Opc = MI.getOpcode();
296   MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
297   if (!isTypeFoldingSupported(Opc)) {
298     assert(Opc == TargetOpcode::G_ICMP);
299     assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
300     auto &Op0 = MI.getOperand(2);
301     auto &Op1 = MI.getOperand(3);
302     Register Reg0 = Op0.getReg();
303     Register Reg1 = Op1.getReg();
304     CmpInst::Predicate Cond =
305         static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
306     if ((!ST->canDirectlyComparePointers() ||
307          (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
308         MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
309       LLT ConvT = LLT::scalar(ST->getPointerSize());
310       Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
311                                       ST->getPointerSize());
312       SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
313       Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
314       Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
315     }
316     return true;
317   }
318   // TODO: implement legalization for other opcodes.
319   return true;
320 }
321