xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22 
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26 
27 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28     TargetOpcode::G_ADD,
29     TargetOpcode::G_FADD,
30     TargetOpcode::G_SUB,
31     TargetOpcode::G_FSUB,
32     TargetOpcode::G_MUL,
33     TargetOpcode::G_FMUL,
34     TargetOpcode::G_SDIV,
35     TargetOpcode::G_UDIV,
36     TargetOpcode::G_FDIV,
37     TargetOpcode::G_SREM,
38     TargetOpcode::G_UREM,
39     TargetOpcode::G_FREM,
40     TargetOpcode::G_FNEG,
41     TargetOpcode::G_CONSTANT,
42     TargetOpcode::G_FCONSTANT,
43     TargetOpcode::G_AND,
44     TargetOpcode::G_OR,
45     TargetOpcode::G_XOR,
46     TargetOpcode::G_SHL,
47     TargetOpcode::G_ASHR,
48     TargetOpcode::G_LSHR,
49     TargetOpcode::G_SELECT,
50     TargetOpcode::G_EXTRACT_VECTOR_ELT,
51 };
52 
isTypeFoldingSupported(unsigned Opcode)53 bool isTypeFoldingSupported(unsigned Opcode) {
54   return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55 }
56 
SPIRVLegalizerInfo(const SPIRVSubtarget & ST)57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
58   using namespace TargetOpcode;
59 
60   this->ST = &ST;
61   GR = ST.getSPIRVGlobalRegistry();
62 
63   const LLT s1 = LLT::scalar(1);
64   const LLT s8 = LLT::scalar(8);
65   const LLT s16 = LLT::scalar(16);
66   const LLT s32 = LLT::scalar(32);
67   const LLT s64 = LLT::scalar(64);
68 
69   const LLT v16s64 = LLT::fixed_vector(16, 64);
70   const LLT v16s32 = LLT::fixed_vector(16, 32);
71   const LLT v16s16 = LLT::fixed_vector(16, 16);
72   const LLT v16s8 = LLT::fixed_vector(16, 8);
73   const LLT v16s1 = LLT::fixed_vector(16, 1);
74 
75   const LLT v8s64 = LLT::fixed_vector(8, 64);
76   const LLT v8s32 = LLT::fixed_vector(8, 32);
77   const LLT v8s16 = LLT::fixed_vector(8, 16);
78   const LLT v8s8 = LLT::fixed_vector(8, 8);
79   const LLT v8s1 = LLT::fixed_vector(8, 1);
80 
81   const LLT v4s64 = LLT::fixed_vector(4, 64);
82   const LLT v4s32 = LLT::fixed_vector(4, 32);
83   const LLT v4s16 = LLT::fixed_vector(4, 16);
84   const LLT v4s8 = LLT::fixed_vector(4, 8);
85   const LLT v4s1 = LLT::fixed_vector(4, 1);
86 
87   const LLT v3s64 = LLT::fixed_vector(3, 64);
88   const LLT v3s32 = LLT::fixed_vector(3, 32);
89   const LLT v3s16 = LLT::fixed_vector(3, 16);
90   const LLT v3s8 = LLT::fixed_vector(3, 8);
91   const LLT v3s1 = LLT::fixed_vector(3, 1);
92 
93   const LLT v2s64 = LLT::fixed_vector(2, 64);
94   const LLT v2s32 = LLT::fixed_vector(2, 32);
95   const LLT v2s16 = LLT::fixed_vector(2, 16);
96   const LLT v2s8 = LLT::fixed_vector(2, 8);
97   const LLT v2s1 = LLT::fixed_vector(2, 1);
98 
99   const unsigned PSize = ST.getPointerSize();
100   const LLT p0 = LLT::pointer(0, PSize); // Function
101   const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102   const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103   const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104   const LLT p4 = LLT::pointer(4, PSize); // Generic
105   const LLT p5 =
106       LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
107   const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
108 
109   // TODO: remove copy-pasting here by using concatenation in some way.
110   auto allPtrsScalarsAndVectors = {
111       p0,    p1,    p2,    p3,    p4,     p5,     p6,    s1,   s8,   s16,
112       s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64, v3s1, v3s8, v3s16,
113       v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64, v8s1, v8s8, v8s16,
114       v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
115 
116   auto allVectors = {v2s1,  v2s8,   v2s16,  v2s32, v2s64, v3s1,  v3s8,
117                      v3s16, v3s32,  v3s64,  v4s1,  v4s8,  v4s16, v4s32,
118                      v4s64, v8s1,   v8s8,   v8s16, v8s32, v8s64, v16s1,
119                      v16s8, v16s16, v16s32, v16s64};
120 
121   auto allScalarsAndVectors = {
122       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
123       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
124       v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
125 
126   auto allIntScalarsAndVectors = {s8,    s16,   s32,   s64,    v2s8,   v2s16,
127                                   v2s32, v2s64, v3s8,  v3s16,  v3s32,  v3s64,
128                                   v4s8,  v4s16, v4s32, v4s64,  v8s8,   v8s16,
129                                   v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
130 
131   auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
132 
133   auto allIntScalars = {s8, s16, s32, s64};
134 
135   auto allFloatScalars = {s16, s32, s64};
136 
137   auto allFloatScalarsAndVectors = {
138       s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
139       v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
140 
141   auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,
142                                        p2, p3,  p4,  p5,  p6};
143 
144   auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
145   auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};
146 
147   for (auto Opc : TypeFoldingSupportingOpcs)
148     getActionDefinitionsBuilder(Opc).custom();
149 
150   getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
151 
152   // TODO: add proper rules for vectors legalization.
153   getActionDefinitionsBuilder(
154       {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
155       .alwaysLegal();
156 
157   // Vector Reduction Operations
158   getActionDefinitionsBuilder(
159       {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
160        G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
161        G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
162        G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
163       .legalFor(allVectors)
164       .scalarize(1)
165       .lower();
166 
167   getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
168       .scalarize(2)
169       .lower();
170 
171   // Merge/Unmerge
172   // TODO: add proper legalization rules.
173   getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
174 
175   getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
176       .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
177 
178   getActionDefinitionsBuilder(G_MEMSET).legalIf(
179       all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
180 
181   getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
182       .legalForCartesianProduct(allPtrs, allPtrs);
183 
184   getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
185 
186   getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allIntScalarsAndVectors);
187 
188   getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
189 
190   getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
191       .legalForCartesianProduct(allIntScalarsAndVectors,
192                                 allFloatScalarsAndVectors);
193 
194   getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
195       .legalForCartesianProduct(allFloatScalarsAndVectors,
196                                 allScalarsAndVectors);
197 
198   getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
199       .legalFor(allIntScalarsAndVectors);
200 
201   getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
202       allIntScalarsAndVectors, allIntScalarsAndVectors);
203 
204   getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
205 
206   getActionDefinitionsBuilder(G_BITCAST).legalIf(
207       all(typeInSet(0, allPtrsScalarsAndVectors),
208           typeInSet(1, allPtrsScalarsAndVectors)));
209 
210   getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
211 
212   getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
213 
214   getActionDefinitionsBuilder(G_INTTOPTR)
215       .legalForCartesianProduct(allPtrs, allIntScalars);
216   getActionDefinitionsBuilder(G_PTRTOINT)
217       .legalForCartesianProduct(allIntScalars, allPtrs);
218   getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
219       allPtrs, allIntScalars);
220 
221   // ST.canDirectlyComparePointers() for pointer args is supported in
222   // legalizeCustom().
223   getActionDefinitionsBuilder(G_ICMP).customIf(
224       all(typeInSet(0, allBoolScalarsAndVectors),
225           typeInSet(1, allPtrsScalarsAndVectors)));
226 
227   getActionDefinitionsBuilder(G_FCMP).legalIf(
228       all(typeInSet(0, allBoolScalarsAndVectors),
229           typeInSet(1, allFloatScalarsAndVectors)));
230 
231   getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
232                                G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
233                                G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
234                                G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
235       .legalForCartesianProduct(allIntScalars, allWritablePtrs);
236 
237   getActionDefinitionsBuilder(
238       {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
239       .legalForCartesianProduct(allFloatScalars, allWritablePtrs);
240 
241   getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
242       .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allWritablePtrs);
243 
244   getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
245   // TODO: add proper legalization rules.
246   getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
247 
248   getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
249       .alwaysLegal();
250 
251   // Extensions.
252   getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
253       .legalForCartesianProduct(allScalarsAndVectors);
254 
255   // FP conversions.
256   getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
257       .legalForCartesianProduct(allFloatScalarsAndVectors);
258 
259   // Pointer-handling.
260   getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
261 
262   // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
263   getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
264 
265   // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
266   // tighten these requirements. Many of these math functions are only legal on
267   // specific bitwidths, so they are not selectable for
268   // allFloatScalarsAndVectors.
269   getActionDefinitionsBuilder({G_FPOW,
270                                G_FEXP,
271                                G_FEXP2,
272                                G_FLOG,
273                                G_FLOG2,
274                                G_FLOG10,
275                                G_FABS,
276                                G_FMINNUM,
277                                G_FMAXNUM,
278                                G_FCEIL,
279                                G_FCOS,
280                                G_FSIN,
281                                G_FTAN,
282                                G_FACOS,
283                                G_FASIN,
284                                G_FATAN,
285                                G_FCOSH,
286                                G_FSINH,
287                                G_FTANH,
288                                G_FSQRT,
289                                G_FFLOOR,
290                                G_FRINT,
291                                G_FNEARBYINT,
292                                G_INTRINSIC_ROUND,
293                                G_INTRINSIC_TRUNC,
294                                G_FMINIMUM,
295                                G_FMAXIMUM,
296                                G_INTRINSIC_ROUNDEVEN})
297       .legalFor(allFloatScalarsAndVectors);
298 
299   getActionDefinitionsBuilder(G_FCOPYSIGN)
300       .legalForCartesianProduct(allFloatScalarsAndVectors,
301                                 allFloatScalarsAndVectors);
302 
303   getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
304       allFloatScalarsAndVectors, allIntScalarsAndVectors);
305 
306   if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
307     getActionDefinitionsBuilder(
308         {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
309         .legalForCartesianProduct(allIntScalarsAndVectors,
310                                   allIntScalarsAndVectors);
311 
312     // Struct return types become a single scalar, so cannot easily legalize.
313     getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
314 
315     // supported saturation arithmetic
316     getActionDefinitionsBuilder({G_SADDSAT, G_UADDSAT, G_SSUBSAT, G_USUBSAT})
317         .legalFor(allIntScalarsAndVectors);
318   }
319 
320   getLegacyLegalizerInfo().computeTables();
321   verify(*ST.getInstrInfo());
322 }
323 
convertPtrToInt(Register Reg,LLT ConvTy,SPIRVType * SpirvType,LegalizerHelper & Helper,MachineRegisterInfo & MRI,SPIRVGlobalRegistry * GR)324 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
325                                 LegalizerHelper &Helper,
326                                 MachineRegisterInfo &MRI,
327                                 SPIRVGlobalRegistry *GR) {
328   Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
329   GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
330   Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
331       .addDef(ConvReg)
332       .addUse(Reg);
333   return ConvReg;
334 }
335 
legalizeCustom(LegalizerHelper & Helper,MachineInstr & MI,LostDebugLocObserver & LocObserver) const336 bool SPIRVLegalizerInfo::legalizeCustom(
337     LegalizerHelper &Helper, MachineInstr &MI,
338     LostDebugLocObserver &LocObserver) const {
339   auto Opc = MI.getOpcode();
340   MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
341   if (!isTypeFoldingSupported(Opc)) {
342     assert(Opc == TargetOpcode::G_ICMP);
343     assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
344     auto &Op0 = MI.getOperand(2);
345     auto &Op1 = MI.getOperand(3);
346     Register Reg0 = Op0.getReg();
347     Register Reg1 = Op1.getReg();
348     CmpInst::Predicate Cond =
349         static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
350     if ((!ST->canDirectlyComparePointers() ||
351          (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
352         MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
353       LLT ConvT = LLT::scalar(ST->getPointerSize());
354       Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
355                                       ST->getPointerSize());
356       SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
357       Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
358       Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
359     }
360     return true;
361   }
362   // TODO: implement legalization for other opcodes.
363   return true;
364 }
365