1 //===- CombinerHelperVectorOps.cpp-----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements CombinerHelper for G_EXTRACT_VECTOR_ELT,
10 // G_INSERT_VECTOR_ELT, and G_VSCALE
11 //
12 //===----------------------------------------------------------------------===//
13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
14 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
15 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
16 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
17 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/GlobalISel/Utils.h"
20 #include "llvm/CodeGen/LowLevelTypeUtils.h"
21 #include "llvm/CodeGen/MachineOperand.h"
22 #include "llvm/CodeGen/MachineRegisterInfo.h"
23 #include "llvm/CodeGen/TargetLowering.h"
24 #include "llvm/CodeGen/TargetOpcodes.h"
25 #include "llvm/Support/Casting.h"
26 #include <optional>
27
28 #define DEBUG_TYPE "gi-combiner"
29
30 using namespace llvm;
31 using namespace MIPatternMatch;
32
matchExtractVectorElement(MachineInstr & MI,BuildFnTy & MatchInfo) const33 bool CombinerHelper::matchExtractVectorElement(MachineInstr &MI,
34 BuildFnTy &MatchInfo) const {
35 GExtractVectorElement *Extract = cast<GExtractVectorElement>(&MI);
36
37 Register Dst = Extract->getReg(0);
38 Register Vector = Extract->getVectorReg();
39 Register Index = Extract->getIndexReg();
40 LLT DstTy = MRI.getType(Dst);
41 LLT VectorTy = MRI.getType(Vector);
42
43 // The vector register can be def'd by various ops that have vector as its
44 // type. They can all be used for constant folding, scalarizing,
45 // canonicalization, or combining based on symmetry.
46 //
47 // vector like ops
48 // * build vector
49 // * build vector trunc
50 // * shuffle vector
51 // * splat vector
52 // * concat vectors
53 // * insert/extract vector element
54 // * insert/extract subvector
55 // * vector loads
56 // * scalable vector loads
57 //
58 // compute like ops
59 // * binary ops
60 // * unary ops
61 // * exts and truncs
62 // * casts
63 // * fneg
64 // * select
65 // * phis
66 // * cmps
67 // * freeze
68 // * bitcast
69 // * undef
70
71 // We try to get the value of the Index register.
72 std::optional<ValueAndVReg> MaybeIndex =
73 getIConstantVRegValWithLookThrough(Index, MRI);
74 std::optional<APInt> IndexC = std::nullopt;
75
76 if (MaybeIndex)
77 IndexC = MaybeIndex->Value;
78
79 // Fold extractVectorElement(Vector, TOOLARGE) -> undef
80 if (IndexC && VectorTy.isFixedVector() &&
81 IndexC->uge(VectorTy.getNumElements()) &&
82 isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
83 // For fixed-length vectors, it's invalid to extract out-of-range elements.
84 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); };
85 return true;
86 }
87
88 return false;
89 }
90
matchExtractVectorElementWithDifferentIndices(const MachineOperand & MO,BuildFnTy & MatchInfo) const91 bool CombinerHelper::matchExtractVectorElementWithDifferentIndices(
92 const MachineOperand &MO, BuildFnTy &MatchInfo) const {
93 MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
94 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root);
95
96 //
97 // %idx1:_(s64) = G_CONSTANT i64 1
98 // %idx2:_(s64) = G_CONSTANT i64 2
99 // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>),
100 // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %insert(<2
101 // x s32>), %idx1(s64)
102 //
103 // -->
104 //
105 // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>),
106 // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x
107 // s32>), %idx1(s64)
108 //
109 //
110
111 Register Index = Extract->getIndexReg();
112
113 // We try to get the value of the Index register.
114 std::optional<ValueAndVReg> MaybeIndex =
115 getIConstantVRegValWithLookThrough(Index, MRI);
116 std::optional<APInt> IndexC = std::nullopt;
117
118 if (!MaybeIndex)
119 return false;
120 else
121 IndexC = MaybeIndex->Value;
122
123 Register Vector = Extract->getVectorReg();
124
125 GInsertVectorElement *Insert =
126 getOpcodeDef<GInsertVectorElement>(Vector, MRI);
127 if (!Insert)
128 return false;
129
130 Register Dst = Extract->getReg(0);
131
132 std::optional<ValueAndVReg> MaybeInsertIndex =
133 getIConstantVRegValWithLookThrough(Insert->getIndexReg(), MRI);
134
135 if (MaybeInsertIndex && MaybeInsertIndex->Value != *IndexC) {
136 // There is no one-use check. We have to keep the insert. When both Index
137 // registers are constants and not equal, we can look into the Vector
138 // register of the insert.
139 MatchInfo = [=](MachineIRBuilder &B) {
140 B.buildExtractVectorElement(Dst, Insert->getVectorReg(), Index);
141 };
142 return true;
143 }
144
145 return false;
146 }
147
matchExtractVectorElementWithBuildVector(const MachineInstr & MI,const MachineInstr & MI2,BuildFnTy & MatchInfo) const148 bool CombinerHelper::matchExtractVectorElementWithBuildVector(
149 const MachineInstr &MI, const MachineInstr &MI2,
150 BuildFnTy &MatchInfo) const {
151 const GExtractVectorElement *Extract = cast<GExtractVectorElement>(&MI);
152 const GBuildVector *Build = cast<GBuildVector>(&MI2);
153
154 //
155 // %zero:_(s64) = G_CONSTANT i64 0
156 // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32)
157 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64)
158 //
159 // -->
160 //
161 // %extract:_(32) = COPY %arg1(s32)
162 //
163 //
164
165 Register Vector = Extract->getVectorReg();
166 LLT VectorTy = MRI.getType(Vector);
167
168 // There is a one-use check. There are more combines on build vectors.
169 EVT Ty(getMVTForLLT(VectorTy));
170 if (!MRI.hasOneNonDBGUse(Build->getReg(0)) ||
171 !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty))
172 return false;
173
174 APInt Index = getIConstantFromReg(Extract->getIndexReg(), MRI);
175
176 // We now know that there is a buildVector def'd on the Vector register and
177 // the index is const. The combine will succeed.
178
179 Register Dst = Extract->getReg(0);
180
181 MatchInfo = [=](MachineIRBuilder &B) {
182 B.buildCopy(Dst, Build->getSourceReg(Index.getZExtValue()));
183 };
184
185 return true;
186 }
187
matchExtractVectorElementWithBuildVectorTrunc(const MachineOperand & MO,BuildFnTy & MatchInfo) const188 bool CombinerHelper::matchExtractVectorElementWithBuildVectorTrunc(
189 const MachineOperand &MO, BuildFnTy &MatchInfo) const {
190 MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
191 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root);
192
193 //
194 // %zero:_(s64) = G_CONSTANT i64 0
195 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
196 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64)
197 //
198 // -->
199 //
200 // %extract:_(32) = G_TRUNC %arg1(s64)
201 //
202 //
203 //
204 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
205 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
206 //
207 // -->
208 //
209 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
210 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
211 //
212
213 Register Vector = Extract->getVectorReg();
214
215 // We expect a buildVectorTrunc on the Vector register.
216 GBuildVectorTrunc *Build = getOpcodeDef<GBuildVectorTrunc>(Vector, MRI);
217 if (!Build)
218 return false;
219
220 LLT VectorTy = MRI.getType(Vector);
221
222 // There is a one-use check. There are more combines on build vectors.
223 EVT Ty(getMVTForLLT(VectorTy));
224 if (!MRI.hasOneNonDBGUse(Build->getReg(0)) ||
225 !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty))
226 return false;
227
228 Register Index = Extract->getIndexReg();
229
230 // If the Index is constant, then we can extract the element from the given
231 // offset.
232 std::optional<ValueAndVReg> MaybeIndex =
233 getIConstantVRegValWithLookThrough(Index, MRI);
234 if (!MaybeIndex)
235 return false;
236
237 // We now know that there is a buildVectorTrunc def'd on the Vector register
238 // and the index is const. The combine will succeed.
239
240 Register Dst = Extract->getReg(0);
241 LLT DstTy = MRI.getType(Dst);
242 LLT SrcTy = MRI.getType(Build->getSourceReg(0));
243
244 // For buildVectorTrunc, the inputs are truncated.
245 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}}))
246 return false;
247
248 MatchInfo = [=](MachineIRBuilder &B) {
249 B.buildTrunc(Dst, Build->getSourceReg(MaybeIndex->Value.getZExtValue()));
250 };
251
252 return true;
253 }
254
matchExtractVectorElementWithShuffleVector(const MachineInstr & MI,const MachineInstr & MI2,BuildFnTy & MatchInfo) const255 bool CombinerHelper::matchExtractVectorElementWithShuffleVector(
256 const MachineInstr &MI, const MachineInstr &MI2,
257 BuildFnTy &MatchInfo) const {
258 const GExtractVectorElement *Extract = cast<GExtractVectorElement>(&MI);
259 const GShuffleVector *Shuffle = cast<GShuffleVector>(&MI2);
260
261 //
262 // %zero:_(s64) = G_CONSTANT i64 0
263 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
264 // shufflemask(0, 0, 0, 0)
265 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %zero(s64)
266 //
267 // -->
268 //
269 // %zero1:_(s64) = G_CONSTANT i64 0
270 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %arg1(<4 x s32>), %zero1(s64)
271 //
272 //
273 //
274 //
275 // %three:_(s64) = G_CONSTANT i64 3
276 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
277 // shufflemask(0, 0, 0, -1)
278 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %three(s64)
279 //
280 // -->
281 //
282 // %extract:_(s32) = G_IMPLICIT_DEF
283 //
284 //
285
286 APInt Index = getIConstantFromReg(Extract->getIndexReg(), MRI);
287
288 ArrayRef<int> Mask = Shuffle->getMask();
289
290 unsigned Offset = Index.getZExtValue();
291 int SrcIdx = Mask[Offset];
292
293 LLT Src1Type = MRI.getType(Shuffle->getSrc1Reg());
294 // At the IR level a <1 x ty> shuffle vector is valid, but we want to extract
295 // from a vector.
296 assert(Src1Type.isVector() && "expected to extract from a vector");
297 unsigned LHSWidth = Src1Type.isVector() ? Src1Type.getNumElements() : 1;
298
299 // Note that there is no one use check.
300 Register Dst = Extract->getReg(0);
301 LLT DstTy = MRI.getType(Dst);
302
303 if (SrcIdx < 0 &&
304 isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
305 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); };
306 return true;
307 }
308
309 // If the legality check failed, then we still have to abort.
310 if (SrcIdx < 0)
311 return false;
312
313 Register NewVector;
314
315 // We check in which vector and at what offset to look through.
316 if (SrcIdx < (int)LHSWidth) {
317 NewVector = Shuffle->getSrc1Reg();
318 // SrcIdx unchanged
319 } else { // SrcIdx >= LHSWidth
320 NewVector = Shuffle->getSrc2Reg();
321 SrcIdx -= LHSWidth;
322 }
323
324 LLT IdxTy = MRI.getType(Extract->getIndexReg());
325 LLT NewVectorTy = MRI.getType(NewVector);
326
327 // We check the legality of the look through.
328 if (!isLegalOrBeforeLegalizer(
329 {TargetOpcode::G_EXTRACT_VECTOR_ELT, {DstTy, NewVectorTy, IdxTy}}) ||
330 !isConstantLegalOrBeforeLegalizer({IdxTy}))
331 return false;
332
333 // We look through the shuffle vector.
334 MatchInfo = [=](MachineIRBuilder &B) {
335 auto Idx = B.buildConstant(IdxTy, SrcIdx);
336 B.buildExtractVectorElement(Dst, NewVector, Idx);
337 };
338
339 return true;
340 }
341
matchInsertVectorElementOOB(MachineInstr & MI,BuildFnTy & MatchInfo) const342 bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,
343 BuildFnTy &MatchInfo) const {
344 GInsertVectorElement *Insert = cast<GInsertVectorElement>(&MI);
345
346 Register Dst = Insert->getReg(0);
347 LLT DstTy = MRI.getType(Dst);
348 Register Index = Insert->getIndexReg();
349
350 if (!DstTy.isFixedVector())
351 return false;
352
353 std::optional<ValueAndVReg> MaybeIndex =
354 getIConstantVRegValWithLookThrough(Index, MRI);
355
356 if (MaybeIndex && MaybeIndex->Value.uge(DstTy.getNumElements()) &&
357 isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
358 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); };
359 return true;
360 }
361
362 return false;
363 }
364
matchAddOfVScale(const MachineOperand & MO,BuildFnTy & MatchInfo) const365 bool CombinerHelper::matchAddOfVScale(const MachineOperand &MO,
366 BuildFnTy &MatchInfo) const {
367 GAdd *Add = cast<GAdd>(MRI.getVRegDef(MO.getReg()));
368 GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Add->getLHSReg()));
369 GVScale *RHSVScale = cast<GVScale>(MRI.getVRegDef(Add->getRHSReg()));
370
371 Register Dst = Add->getReg(0);
372
373 if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)) ||
374 !MRI.hasOneNonDBGUse(RHSVScale->getReg(0)))
375 return false;
376
377 MatchInfo = [=](MachineIRBuilder &B) {
378 B.buildVScale(Dst, LHSVScale->getSrc() + RHSVScale->getSrc());
379 };
380
381 return true;
382 }
383
matchMulOfVScale(const MachineOperand & MO,BuildFnTy & MatchInfo) const384 bool CombinerHelper::matchMulOfVScale(const MachineOperand &MO,
385 BuildFnTy &MatchInfo) const {
386 GMul *Mul = cast<GMul>(MRI.getVRegDef(MO.getReg()));
387 GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Mul->getLHSReg()));
388
389 std::optional<APInt> MaybeRHS = getIConstantVRegVal(Mul->getRHSReg(), MRI);
390 if (!MaybeRHS)
391 return false;
392
393 Register Dst = MO.getReg();
394
395 if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)))
396 return false;
397
398 MatchInfo = [=](MachineIRBuilder &B) {
399 B.buildVScale(Dst, LHSVScale->getSrc() * *MaybeRHS);
400 };
401
402 return true;
403 }
404
matchSubOfVScale(const MachineOperand & MO,BuildFnTy & MatchInfo) const405 bool CombinerHelper::matchSubOfVScale(const MachineOperand &MO,
406 BuildFnTy &MatchInfo) const {
407 GSub *Sub = cast<GSub>(MRI.getVRegDef(MO.getReg()));
408 GVScale *RHSVScale = cast<GVScale>(MRI.getVRegDef(Sub->getRHSReg()));
409
410 Register Dst = MO.getReg();
411 LLT DstTy = MRI.getType(Dst);
412
413 if (!MRI.hasOneNonDBGUse(RHSVScale->getReg(0)) ||
414 !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, DstTy}))
415 return false;
416
417 MatchInfo = [=](MachineIRBuilder &B) {
418 auto VScale = B.buildVScale(DstTy, -RHSVScale->getSrc());
419 B.buildAdd(Dst, Sub->getLHSReg(), VScale, Sub->getFlags());
420 };
421
422 return true;
423 }
424
matchShlOfVScale(const MachineOperand & MO,BuildFnTy & MatchInfo) const425 bool CombinerHelper::matchShlOfVScale(const MachineOperand &MO,
426 BuildFnTy &MatchInfo) const {
427 GShl *Shl = cast<GShl>(MRI.getVRegDef(MO.getReg()));
428 GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Shl->getSrcReg()));
429
430 std::optional<APInt> MaybeRHS = getIConstantVRegVal(Shl->getShiftReg(), MRI);
431 if (!MaybeRHS)
432 return false;
433
434 Register Dst = MO.getReg();
435 LLT DstTy = MRI.getType(Dst);
436
437 if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)) ||
438 !isLegalOrBeforeLegalizer({TargetOpcode::G_VSCALE, DstTy}))
439 return false;
440
441 MatchInfo = [=](MachineIRBuilder &B) {
442 B.buildVScale(Dst, LHSVScale->getSrc().shl(*MaybeRHS));
443 };
444
445 return true;
446 }
447