1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26
27 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28 TargetOpcode::G_ADD,
29 TargetOpcode::G_FADD,
30 TargetOpcode::G_SUB,
31 TargetOpcode::G_FSUB,
32 TargetOpcode::G_MUL,
33 TargetOpcode::G_FMUL,
34 TargetOpcode::G_SDIV,
35 TargetOpcode::G_UDIV,
36 TargetOpcode::G_FDIV,
37 TargetOpcode::G_SREM,
38 TargetOpcode::G_UREM,
39 TargetOpcode::G_FREM,
40 TargetOpcode::G_FNEG,
41 TargetOpcode::G_CONSTANT,
42 TargetOpcode::G_FCONSTANT,
43 TargetOpcode::G_AND,
44 TargetOpcode::G_OR,
45 TargetOpcode::G_XOR,
46 TargetOpcode::G_SHL,
47 TargetOpcode::G_ASHR,
48 TargetOpcode::G_LSHR,
49 TargetOpcode::G_SELECT,
50 TargetOpcode::G_EXTRACT_VECTOR_ELT,
51 };
52
isTypeFoldingSupported(unsigned Opcode)53 bool isTypeFoldingSupported(unsigned Opcode) {
54 return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55 }
56
SPIRVLegalizerInfo(const SPIRVSubtarget & ST)57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
58 using namespace TargetOpcode;
59
60 this->ST = &ST;
61 GR = ST.getSPIRVGlobalRegistry();
62
63 const LLT s1 = LLT::scalar(1);
64 const LLT s8 = LLT::scalar(8);
65 const LLT s16 = LLT::scalar(16);
66 const LLT s32 = LLT::scalar(32);
67 const LLT s64 = LLT::scalar(64);
68
69 const LLT v16s64 = LLT::fixed_vector(16, 64);
70 const LLT v16s32 = LLT::fixed_vector(16, 32);
71 const LLT v16s16 = LLT::fixed_vector(16, 16);
72 const LLT v16s8 = LLT::fixed_vector(16, 8);
73 const LLT v16s1 = LLT::fixed_vector(16, 1);
74
75 const LLT v8s64 = LLT::fixed_vector(8, 64);
76 const LLT v8s32 = LLT::fixed_vector(8, 32);
77 const LLT v8s16 = LLT::fixed_vector(8, 16);
78 const LLT v8s8 = LLT::fixed_vector(8, 8);
79 const LLT v8s1 = LLT::fixed_vector(8, 1);
80
81 const LLT v4s64 = LLT::fixed_vector(4, 64);
82 const LLT v4s32 = LLT::fixed_vector(4, 32);
83 const LLT v4s16 = LLT::fixed_vector(4, 16);
84 const LLT v4s8 = LLT::fixed_vector(4, 8);
85 const LLT v4s1 = LLT::fixed_vector(4, 1);
86
87 const LLT v3s64 = LLT::fixed_vector(3, 64);
88 const LLT v3s32 = LLT::fixed_vector(3, 32);
89 const LLT v3s16 = LLT::fixed_vector(3, 16);
90 const LLT v3s8 = LLT::fixed_vector(3, 8);
91 const LLT v3s1 = LLT::fixed_vector(3, 1);
92
93 const LLT v2s64 = LLT::fixed_vector(2, 64);
94 const LLT v2s32 = LLT::fixed_vector(2, 32);
95 const LLT v2s16 = LLT::fixed_vector(2, 16);
96 const LLT v2s8 = LLT::fixed_vector(2, 8);
97 const LLT v2s1 = LLT::fixed_vector(2, 1);
98
99 const unsigned PSize = ST.getPointerSize();
100 const LLT p0 = LLT::pointer(0, PSize); // Function
101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104 const LLT p4 = LLT::pointer(4, PSize); // Generic
105 const LLT p5 =
106 LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
107 const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
108
109 // TODO: remove copy-pasting here by using concatenation in some way.
110 auto allPtrsScalarsAndVectors = {
111 p0, p1, p2, p3, p4, p5, p6, s1, s8, s16,
112 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,
113 v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
114 v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
115
116 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
117 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
118 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
119 v16s8, v16s16, v16s32, v16s64};
120
121 auto allScalarsAndVectors = {
122 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
123 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
124 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
125
126 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
127 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
128 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
129 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
130
131 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
132
133 auto allIntScalars = {s8, s16, s32, s64};
134
135 auto allFloatScalars = {s16, s32, s64};
136
137 auto allFloatScalarsAndVectors = {
138 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
139 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
140
141 auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,
142 p2, p3, p4, p5, p6};
143
144 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
145 auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};
146
147 for (auto Opc : TypeFoldingSupportingOpcs)
148 getActionDefinitionsBuilder(Opc).custom();
149
150 getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
151
152 // TODO: add proper rules for vectors legalization.
153 getActionDefinitionsBuilder(
154 {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
155 .alwaysLegal();
156
157 // Vector Reduction Operations
158 getActionDefinitionsBuilder(
159 {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
160 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
161 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
162 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
163 .legalFor(allVectors)
164 .scalarize(1)
165 .lower();
166
167 getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
168 .scalarize(2)
169 .lower();
170
171 // Merge/Unmerge
172 // TODO: add proper legalization rules.
173 getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
174
175 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
176 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
177
178 getActionDefinitionsBuilder(G_MEMSET).legalIf(
179 all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
180
181 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
182 .legalForCartesianProduct(allPtrs, allPtrs);
183
184 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
185
186 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allIntScalarsAndVectors);
187
188 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
189
190 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
191 .legalForCartesianProduct(allIntScalarsAndVectors,
192 allFloatScalarsAndVectors);
193
194 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
195 .legalForCartesianProduct(allFloatScalarsAndVectors,
196 allScalarsAndVectors);
197
198 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
199 .legalFor(allIntScalarsAndVectors);
200
201 getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
202 allIntScalarsAndVectors, allIntScalarsAndVectors);
203
204 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
205
206 getActionDefinitionsBuilder(G_BITCAST).legalIf(
207 all(typeInSet(0, allPtrsScalarsAndVectors),
208 typeInSet(1, allPtrsScalarsAndVectors)));
209
210 getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
211
212 getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
213
214 getActionDefinitionsBuilder(G_INTTOPTR)
215 .legalForCartesianProduct(allPtrs, allIntScalars);
216 getActionDefinitionsBuilder(G_PTRTOINT)
217 .legalForCartesianProduct(allIntScalars, allPtrs);
218 getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
219 allPtrs, allIntScalars);
220
221 // ST.canDirectlyComparePointers() for pointer args is supported in
222 // legalizeCustom().
223 getActionDefinitionsBuilder(G_ICMP).customIf(
224 all(typeInSet(0, allBoolScalarsAndVectors),
225 typeInSet(1, allPtrsScalarsAndVectors)));
226
227 getActionDefinitionsBuilder(G_FCMP).legalIf(
228 all(typeInSet(0, allBoolScalarsAndVectors),
229 typeInSet(1, allFloatScalarsAndVectors)));
230
231 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
232 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
233 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
234 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
235 .legalForCartesianProduct(allIntScalars, allWritablePtrs);
236
237 getActionDefinitionsBuilder(
238 {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
239 .legalForCartesianProduct(allFloatScalars, allWritablePtrs);
240
241 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
242 .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allWritablePtrs);
243
244 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
245 // TODO: add proper legalization rules.
246 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
247
248 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
249 .alwaysLegal();
250
251 // Extensions.
252 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
253 .legalForCartesianProduct(allScalarsAndVectors);
254
255 // FP conversions.
256 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
257 .legalForCartesianProduct(allFloatScalarsAndVectors);
258
259 // Pointer-handling.
260 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
261
262 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
263 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
264
265 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
266 // tighten these requirements. Many of these math functions are only legal on
267 // specific bitwidths, so they are not selectable for
268 // allFloatScalarsAndVectors.
269 getActionDefinitionsBuilder({G_FPOW,
270 G_FEXP,
271 G_FEXP2,
272 G_FLOG,
273 G_FLOG2,
274 G_FLOG10,
275 G_FABS,
276 G_FMINNUM,
277 G_FMAXNUM,
278 G_FCEIL,
279 G_FCOS,
280 G_FSIN,
281 G_FTAN,
282 G_FACOS,
283 G_FASIN,
284 G_FATAN,
285 G_FCOSH,
286 G_FSINH,
287 G_FTANH,
288 G_FSQRT,
289 G_FFLOOR,
290 G_FRINT,
291 G_FNEARBYINT,
292 G_INTRINSIC_ROUND,
293 G_INTRINSIC_TRUNC,
294 G_FMINIMUM,
295 G_FMAXIMUM,
296 G_INTRINSIC_ROUNDEVEN})
297 .legalFor(allFloatScalarsAndVectors);
298
299 getActionDefinitionsBuilder(G_FCOPYSIGN)
300 .legalForCartesianProduct(allFloatScalarsAndVectors,
301 allFloatScalarsAndVectors);
302
303 getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
304 allFloatScalarsAndVectors, allIntScalarsAndVectors);
305
306 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
307 getActionDefinitionsBuilder(
308 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
309 .legalForCartesianProduct(allIntScalarsAndVectors,
310 allIntScalarsAndVectors);
311
312 // Struct return types become a single scalar, so cannot easily legalize.
313 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
314
315 // supported saturation arithmetic
316 getActionDefinitionsBuilder({G_SADDSAT, G_UADDSAT, G_SSUBSAT, G_USUBSAT})
317 .legalFor(allIntScalarsAndVectors);
318 }
319
320 getLegacyLegalizerInfo().computeTables();
321 verify(*ST.getInstrInfo());
322 }
323
convertPtrToInt(Register Reg,LLT ConvTy,SPIRVType * SpirvType,LegalizerHelper & Helper,MachineRegisterInfo & MRI,SPIRVGlobalRegistry * GR)324 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
325 LegalizerHelper &Helper,
326 MachineRegisterInfo &MRI,
327 SPIRVGlobalRegistry *GR) {
328 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
329 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
330 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
331 .addDef(ConvReg)
332 .addUse(Reg);
333 return ConvReg;
334 }
335
legalizeCustom(LegalizerHelper & Helper,MachineInstr & MI,LostDebugLocObserver & LocObserver) const336 bool SPIRVLegalizerInfo::legalizeCustom(
337 LegalizerHelper &Helper, MachineInstr &MI,
338 LostDebugLocObserver &LocObserver) const {
339 auto Opc = MI.getOpcode();
340 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
341 if (!isTypeFoldingSupported(Opc)) {
342 assert(Opc == TargetOpcode::G_ICMP);
343 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
344 auto &Op0 = MI.getOperand(2);
345 auto &Op1 = MI.getOperand(3);
346 Register Reg0 = Op0.getReg();
347 Register Reg1 = Op1.getReg();
348 CmpInst::Predicate Cond =
349 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
350 if ((!ST->canDirectlyComparePointers() ||
351 (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
352 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
353 LLT ConvT = LLT::scalar(ST->getPointerSize());
354 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
355 ST->getPointerSize());
356 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
357 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
358 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
359 }
360 return true;
361 }
362 // TODO: implement legalization for other opcodes.
363 return true;
364 }
365