xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the targeting of the MachineLegalizer class for SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "SPIRVLegalizerInfo.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVSubtarget.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
22 
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26 
typeOfExtendedScalars(unsigned TypeIdx,bool IsExtendedInts)27 LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
28   return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
29     const LLT Ty = Query.Types[TypeIdx];
30     return IsExtendedInts && Ty.isValid() && Ty.isScalar();
31   };
32 }
33 
/// Builds the GlobalISel legalization rule tables for the SPIR-V target.
///
/// Rules are declared per generic opcode (or opcode group) through
/// getActionDefinitionsBuilder(). The actions chained on each builder are
/// consulted in the order written, so the order of calls in every chain is
/// significant. The tables are finalized with computeTables() and checked
/// with verify() at the end.
///
/// \param ST the SPIR-V subtarget. Its pointer size, usable extensions and
///        extended instruction sets decide which rules are installed.
SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
  using namespace TargetOpcode;

  // Cache the subtarget and its global registry; legalizeCustom() uses both.
  this->ST = &ST;
  GR = ST.getSPIRVGlobalRegistry();

  // Scalar types.
  const LLT s1 = LLT::scalar(1);
  const LLT s8 = LLT::scalar(8);
  const LLT s16 = LLT::scalar(16);
  const LLT s32 = LLT::scalar(32);
  const LLT s64 = LLT::scalar(64);

  // Fixed vectors with 16, 8, 4, 3 and 2 elements of each scalar width.
  const LLT v16s64 = LLT::fixed_vector(16, 64);
  const LLT v16s32 = LLT::fixed_vector(16, 32);
  const LLT v16s16 = LLT::fixed_vector(16, 16);
  const LLT v16s8 = LLT::fixed_vector(16, 8);
  const LLT v16s1 = LLT::fixed_vector(16, 1);

  const LLT v8s64 = LLT::fixed_vector(8, 64);
  const LLT v8s32 = LLT::fixed_vector(8, 32);
  const LLT v8s16 = LLT::fixed_vector(8, 16);
  const LLT v8s8 = LLT::fixed_vector(8, 8);
  const LLT v8s1 = LLT::fixed_vector(8, 1);

  const LLT v4s64 = LLT::fixed_vector(4, 64);
  const LLT v4s32 = LLT::fixed_vector(4, 32);
  const LLT v4s16 = LLT::fixed_vector(4, 16);
  const LLT v4s8 = LLT::fixed_vector(4, 8);
  const LLT v4s1 = LLT::fixed_vector(4, 1);

  const LLT v3s64 = LLT::fixed_vector(3, 64);
  const LLT v3s32 = LLT::fixed_vector(3, 32);
  const LLT v3s16 = LLT::fixed_vector(3, 16);
  const LLT v3s8 = LLT::fixed_vector(3, 8);
  const LLT v3s1 = LLT::fixed_vector(3, 1);

  const LLT v2s64 = LLT::fixed_vector(2, 64);
  const LLT v2s32 = LLT::fixed_vector(2, 32);
  const LLT v2s16 = LLT::fixed_vector(2, 16);
  const LLT v2s8 = LLT::fixed_vector(2, 8);
  const LLT v2s1 = LLT::fixed_vector(2, 1);

  // One pointer type per address space, all of the subtarget's pointer
  // width. The trailing comments name the storage class for each space.
  // NOTE(review): p5 and p7 are both annotated "Input"; confirm against the
  // address-space <-> storage-class mapping which annotation is current.
  const unsigned PSize = ST.getPointerSize();
  const LLT p0 = LLT::pointer(0, PSize); // Function
  const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
  const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
  const LLT p3 = LLT::pointer(3, PSize); // Workgroup
  const LLT p4 = LLT::pointer(4, PSize); // Generic
  const LLT p5 =
      LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
  const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
  const LLT p7 = LLT::pointer(7, PSize); // Input
  const LLT p8 = LLT::pointer(8, PSize); // Output
  const LLT p10 = LLT::pointer(10, PSize); // Private
  const LLT p11 = LLT::pointer(11, PSize); // StorageBuffer
  const LLT p12 = LLT::pointer(12, PSize); // Uniform

  // Type sets used by the rules below. Each `auto X = {...}` deduces a
  // std::initializer_list<LLT>, the form accepted by legalFor/typeInSet.
  // TODO: remove copy-pasting here by using concatenation in some way.
  auto allPtrsScalarsAndVectors = {
      p0,    p1,    p2,    p3,     p4,     p5,    p6,    p7,    p8,
      p10,   p11,   p12,   s1,     s8,     s16,   s32,   s64,   v2s1,
      v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,  v3s16, v3s32, v3s64,
      v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,  v8s8,  v8s16, v8s32,
      v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};

  auto allVectors = {v2s1,  v2s8,   v2s16,  v2s32, v2s64, v3s1,  v3s8,
                     v3s16, v3s32,  v3s64,  v4s1,  v4s8,  v4s16, v4s32,
                     v4s64, v8s1,   v8s8,   v8s16, v8s32, v8s64, v16s1,
                     v16s8, v16s16, v16s32, v16s64};

  auto allScalarsAndVectors = {
      s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
      v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
      v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};

  // Same as allScalarsAndVectors minus the 1-bit (boolean) element types.
  auto allIntScalarsAndVectors = {s8,    s16,   s32,   s64,    v2s8,   v2s16,
                                  v2s32, v2s64, v3s8,  v3s16,  v3s32,  v3s64,
                                  v4s8,  v4s16, v4s32, v4s64,  v8s8,   v8s16,
                                  v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};

  auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};

  auto allIntScalars = {s8, s16, s32, s64};

  auto allFloatScalars = {s16, s32, s64};

  auto allFloatScalarsAndVectors = {
      s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
      v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};

  auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,  p2,  p3,
                                       p4, p5,  p6,  p7,  p8, p10, p11, p12};

  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};

  // True when any of these integer-related extensions is usable. When set,
  // the predicates below accept types outside the explicit lists above: any
  // valid type at index 0 (excluding pointers where noted).
  bool IsExtendedInts =
      ST.canUseExtension(
          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
      ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
      ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
  // Any valid, non-pointer type at index 0.
  auto extendedScalarsAndVectors =
      [IsExtendedInts](const LegalityQuery &Query) {
        const LLT Ty = Query.Types[0];
        return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
      };
  // Any pair of valid, non-pointer types at indices 0 and 1.
  auto extendedScalarsAndVectorsProduct = [IsExtendedInts](
                                              const LegalityQuery &Query) {
    const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1];
    return IsExtendedInts && Ty1.isValid() && Ty2.isValid() &&
           !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector();
  };
  // Any valid type at index 0, pointers included.
  auto extendedPtrsScalarsAndVectors =
      [IsExtendedInts](const LegalityQuery &Query) {
        const LLT Ty = Query.Types[0];
        return IsExtendedInts && Ty.isValid();
      };

  // Every type-folding opcode is routed to legalizeCustom().
  for (auto Opc : getTypeFoldingSupportedOpcodes())
    getActionDefinitionsBuilder(Opc).custom();

  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();

  // TODO: add proper rules for vectors legalization.
  getActionDefinitionsBuilder(
      {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
      .alwaysLegal();

  // Vector Reduction Operations
  // NOTE(review): legalFor(list) tests type index 0, which for G_VECREDUCE_*
  // is the reduction result rather than the source vector (index 1). Confirm
  // whether reductions are meant to stay legal or always be scalarized.
  getActionDefinitionsBuilder(
      {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
       G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
       G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
       G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
      .legalFor(allVectors)
      .scalarize(1)
      .lower();

  // Sequential reductions: scalarize type index 2, then lower.
  getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
      .scalarize(2)
      .lower();

  // Merge/Unmerge
  // TODO: add proper legalization rules.
  getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();

  // Memory intrinsics: pointer operands must be in one of the address spaces.
  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
      .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));

  getActionDefinitionsBuilder(G_MEMSET).legalIf(
      all(typeInSet(0, allPtrs), typeInSet(1, allIntScalars)));

  getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .legalForCartesianProduct(allPtrs, allPtrs);

  // Loads/stores: only the pointer operand (type index 1) is constrained.
  getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));

  // Integer min/max/abs/bitreverse/saturating/three-way-compare operations.
  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
                               G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
                               G_USUBSAT, G_SCMP, G_UCMP})
      .legalFor(allIntScalarsAndVectors)
      .legalIf(extendedScalarsAndVectors);

  getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
      .legalFor(allFloatScalarsAndVectors);

  getActionDefinitionsBuilder(G_STRICT_FLDEXP)
      .legalForCartesianProduct(allFloatScalarsAndVectors, allIntScalars);

  // FP <-> integer conversions.
  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
      .legalForCartesianProduct(allIntScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allScalarsAndVectors);

  getActionDefinitionsBuilder(G_CTPOP)
      .legalForCartesianProduct(allIntScalarsAndVectors)
      .legalIf(extendedScalarsAndVectorsProduct);

  // Extensions.
  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
      .legalForCartesianProduct(allScalarsAndVectors)
      .legalIf(extendedScalarsAndVectorsProduct);

  getActionDefinitionsBuilder(G_PHI)
      .legalFor(allPtrsScalarsAndVectors)
      .legalIf(extendedPtrsScalarsAndVectors);

  getActionDefinitionsBuilder(G_BITCAST).legalIf(
      all(typeInSet(0, allPtrsScalarsAndVectors),
          typeInSet(1, allPtrsScalarsAndVectors)));

  getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();

  getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();

  // Pointer <-> integer operations: the integer side may also be any scalar
  // when extended integers are available.
  getActionDefinitionsBuilder(G_INTTOPTR)
      .legalForCartesianProduct(allPtrs, allIntScalars)
      .legalIf(
          all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
  getActionDefinitionsBuilder(G_PTRTOINT)
      .legalForCartesianProduct(allIntScalars, allPtrs)
      .legalIf(
          all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)));
  getActionDefinitionsBuilder(G_PTR_ADD)
      .legalForCartesianProduct(allPtrs, allIntScalars)
      .legalIf(
          all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));

  // ST.canDirectlyComparePointers() for pointer args is supported in
  // legalizeCustom().
  getActionDefinitionsBuilder(G_ICMP).customIf(
      all(typeInSet(0, allBoolScalarsAndVectors),
          typeInSet(1, allPtrsScalarsAndVectors)));

  getActionDefinitionsBuilder(G_FCMP).legalIf(
      all(typeInSet(0, allBoolScalarsAndVectors),
          typeInSet(1, allFloatScalarsAndVectors)));

  // Atomic read-modify-write: (value type, pointer type) pairs.
  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
                               G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
                               G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
                               G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
      .legalForCartesianProduct(allIntScalars, allPtrs);

  getActionDefinitionsBuilder(
      {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
      .legalForCartesianProduct(allFloatScalars, allPtrs);

  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
      .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs);

  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
  // TODO: add proper legalization rules.
  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();

  // Overflow-reporting arithmetic.
  getActionDefinitionsBuilder(
      {G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO})
      .alwaysLegal();

  // FP conversions.
  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
      .legalForCartesianProduct(allFloatScalarsAndVectors);

  // Pointer-handling.
  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});

  // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});

  // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
  // tighten these requirements. Many of these math functions are only legal on
  // specific bitwidths, so they are not selectable for
  // allFloatScalarsAndVectors.
  getActionDefinitionsBuilder({G_STRICT_FSQRT,
                               G_FPOW,
                               G_FEXP,
                               G_FEXP2,
                               G_FLOG,
                               G_FLOG2,
                               G_FLOG10,
                               G_FABS,
                               G_FMINNUM,
                               G_FMAXNUM,
                               G_FCEIL,
                               G_FCOS,
                               G_FSIN,
                               G_FTAN,
                               G_FACOS,
                               G_FASIN,
                               G_FATAN,
                               G_FATAN2,
                               G_FCOSH,
                               G_FSINH,
                               G_FTANH,
                               G_FSQRT,
                               G_FFLOOR,
                               G_FRINT,
                               G_FNEARBYINT,
                               G_INTRINSIC_ROUND,
                               G_INTRINSIC_TRUNC,
                               G_FMINIMUM,
                               G_FMAXIMUM,
                               G_INTRINSIC_ROUNDEVEN})
      .legalFor(allFloatScalarsAndVectors);

  getActionDefinitionsBuilder(G_FCOPYSIGN)
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
      allFloatScalarsAndVectors, allIntScalarsAndVectors);

  // These rules are only installed when the OpenCL.std extended instruction
  // set is available on the subtarget.
  if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
    getActionDefinitionsBuilder(
        {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
        .legalForCartesianProduct(allIntScalarsAndVectors,
                                  allIntScalarsAndVectors);

    // Struct return types become a single scalar, so cannot easily legalize.
    getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
  }

  // Finalize the rule tables and check them against the instruction info.
  getLegacyLegalizerInfo().computeTables();
  verify(*ST.getInstrInfo());
}
341 
convertPtrToInt(Register Reg,LLT ConvTy,SPIRVType * SpvType,LegalizerHelper & Helper,MachineRegisterInfo & MRI,SPIRVGlobalRegistry * GR)342 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
343                                 LegalizerHelper &Helper,
344                                 MachineRegisterInfo &MRI,
345                                 SPIRVGlobalRegistry *GR) {
346   Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
347   MRI.setRegClass(ConvReg, GR->getRegClass(SpvType));
348   GR->assignSPIRVTypeToVReg(SpvType, ConvReg, Helper.MIRBuilder.getMF());
349   Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
350       .addDef(ConvReg)
351       .addUse(Reg);
352   return ConvReg;
353 }
354 
legalizeCustom(LegalizerHelper & Helper,MachineInstr & MI,LostDebugLocObserver & LocObserver) const355 bool SPIRVLegalizerInfo::legalizeCustom(
356     LegalizerHelper &Helper, MachineInstr &MI,
357     LostDebugLocObserver &LocObserver) const {
358   auto Opc = MI.getOpcode();
359   MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
360   if (Opc == TargetOpcode::G_ICMP) {
361     assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
362     auto &Op0 = MI.getOperand(2);
363     auto &Op1 = MI.getOperand(3);
364     Register Reg0 = Op0.getReg();
365     Register Reg1 = Op1.getReg();
366     CmpInst::Predicate Cond =
367         static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
368     if ((!ST->canDirectlyComparePointers() ||
369          (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
370         MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
371       LLT ConvT = LLT::scalar(ST->getPointerSize());
372       Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
373                                       ST->getPointerSize());
374       SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(
375           LLVMTy, Helper.MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
376       Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
377       Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
378     }
379     return true;
380   }
381   // TODO: implement legalization for other opcodes.
382   return true;
383 }
384