xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===- CombinerHelperVectorOps.cpp-----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements CombinerHelper for G_EXTRACT_VECTOR_ELT,
10 // G_INSERT_VECTOR_ELT, and G_VSCALE
11 //
12 //===----------------------------------------------------------------------===//
13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
14 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
15 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
16 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
17 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/GlobalISel/Utils.h"
20 #include "llvm/CodeGen/LowLevelTypeUtils.h"
21 #include "llvm/CodeGen/MachineOperand.h"
22 #include "llvm/CodeGen/MachineRegisterInfo.h"
23 #include "llvm/CodeGen/TargetLowering.h"
24 #include "llvm/CodeGen/TargetOpcodes.h"
25 #include "llvm/Support/Casting.h"
26 #include <optional>
27 
28 #define DEBUG_TYPE "gi-combiner"
29 
30 using namespace llvm;
31 using namespace MIPatternMatch;
32 
33 bool CombinerHelper::matchExtractVectorElement(MachineInstr &MI,
34                                                BuildFnTy &MatchInfo) {
35   GExtractVectorElement *Extract = cast<GExtractVectorElement>(&MI);
36 
37   Register Dst = Extract->getReg(0);
38   Register Vector = Extract->getVectorReg();
39   Register Index = Extract->getIndexReg();
40   LLT DstTy = MRI.getType(Dst);
41   LLT VectorTy = MRI.getType(Vector);
42 
43   // The vector register can be def'd by various ops that have vector as its
44   // type. They can all be used for constant folding, scalarizing,
45   // canonicalization, or combining based on symmetry.
46   //
47   // vector like ops
48   // * build vector
49   // * build vector trunc
50   // * shuffle vector
51   // * splat vector
52   // * concat vectors
53   // * insert/extract vector element
54   // * insert/extract subvector
55   // * vector loads
56   // * scalable vector loads
57   //
58   // compute like ops
59   // * binary ops
60   // * unary ops
61   //  * exts and truncs
62   //  * casts
63   //  * fneg
64   // * select
65   // * phis
66   // * cmps
67   // * freeze
68   // * bitcast
69   // * undef
70 
71   // We try to get the value of the Index register.
72   std::optional<ValueAndVReg> MaybeIndex =
73       getIConstantVRegValWithLookThrough(Index, MRI);
74   std::optional<APInt> IndexC = std::nullopt;
75 
76   if (MaybeIndex)
77     IndexC = MaybeIndex->Value;
78 
79   // Fold extractVectorElement(Vector, TOOLARGE) -> undef
80   if (IndexC && VectorTy.isFixedVector() &&
81       IndexC->uge(VectorTy.getNumElements()) &&
82       isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
83     // For fixed-length vectors, it's invalid to extract out-of-range elements.
84     MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); };
85     return true;
86   }
87 
88   return false;
89 }
90 
91 bool CombinerHelper::matchExtractVectorElementWithDifferentIndices(
92     const MachineOperand &MO, BuildFnTy &MatchInfo) {
93   MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
94   GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root);
95 
96   //
97   //  %idx1:_(s64) = G_CONSTANT i64 1
98   //  %idx2:_(s64) = G_CONSTANT i64 2
99   //  %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>),
100   //  %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %insert(<2
101   //  x s32>), %idx1(s64)
102   //
103   //  -->
104   //
105   //  %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>),
106   //  %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x
107   //  s32>), %idx1(s64)
108   //
109   //
110 
111   Register Index = Extract->getIndexReg();
112 
113   // We try to get the value of the Index register.
114   std::optional<ValueAndVReg> MaybeIndex =
115       getIConstantVRegValWithLookThrough(Index, MRI);
116   std::optional<APInt> IndexC = std::nullopt;
117 
118   if (!MaybeIndex)
119     return false;
120   else
121     IndexC = MaybeIndex->Value;
122 
123   Register Vector = Extract->getVectorReg();
124 
125   GInsertVectorElement *Insert =
126       getOpcodeDef<GInsertVectorElement>(Vector, MRI);
127   if (!Insert)
128     return false;
129 
130   Register Dst = Extract->getReg(0);
131 
132   std::optional<ValueAndVReg> MaybeInsertIndex =
133       getIConstantVRegValWithLookThrough(Insert->getIndexReg(), MRI);
134 
135   if (MaybeInsertIndex && MaybeInsertIndex->Value != *IndexC) {
136     // There is no one-use check. We have to keep the insert. When both Index
137     // registers are constants and not equal, we can look into the Vector
138     // register of the insert.
139     MatchInfo = [=](MachineIRBuilder &B) {
140       B.buildExtractVectorElement(Dst, Insert->getVectorReg(), Index);
141     };
142     return true;
143   }
144 
145   return false;
146 }
147 
148 bool CombinerHelper::matchExtractVectorElementWithBuildVector(
149     const MachineOperand &MO, BuildFnTy &MatchInfo) {
150   MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
151   GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root);
152 
153   //
154   //  %zero:_(s64) = G_CONSTANT i64 0
155   //  %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32)
156   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64)
157   //
158   //  -->
159   //
160   //  %extract:_(32) = COPY %arg1(s32)
161   //
162   //
163   //
164   //  %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32)
165   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
166   //
167   //  -->
168   //
169   //  %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32)
170   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
171   //
172 
173   Register Vector = Extract->getVectorReg();
174 
175   // We expect a buildVector on the Vector register.
176   GBuildVector *Build = getOpcodeDef<GBuildVector>(Vector, MRI);
177   if (!Build)
178     return false;
179 
180   LLT VectorTy = MRI.getType(Vector);
181 
182   // There is a one-use check. There are more combines on build vectors.
183   EVT Ty(getMVTForLLT(VectorTy));
184   if (!MRI.hasOneNonDBGUse(Build->getReg(0)) ||
185       !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty))
186     return false;
187 
188   Register Index = Extract->getIndexReg();
189 
190   // If the Index is constant, then we can extract the element from the given
191   // offset.
192   std::optional<ValueAndVReg> MaybeIndex =
193       getIConstantVRegValWithLookThrough(Index, MRI);
194   if (!MaybeIndex)
195     return false;
196 
197   // We now know that there is a buildVector def'd on the Vector register and
198   // the index is const. The combine will succeed.
199 
200   Register Dst = Extract->getReg(0);
201 
202   MatchInfo = [=](MachineIRBuilder &B) {
203     B.buildCopy(Dst, Build->getSourceReg(MaybeIndex->Value.getZExtValue()));
204   };
205 
206   return true;
207 }
208 
209 bool CombinerHelper::matchExtractVectorElementWithBuildVectorTrunc(
210     const MachineOperand &MO, BuildFnTy &MatchInfo) {
211   MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
212   GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root);
213 
214   //
215   //  %zero:_(s64) = G_CONSTANT i64 0
216   //  %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
217   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64)
218   //
219   //  -->
220   //
221   //  %extract:_(32) = G_TRUNC %arg1(s64)
222   //
223   //
224   //
225   //  %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
226   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
227   //
228   //  -->
229   //
230   //  %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64)
231   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64)
232   //
233 
234   Register Vector = Extract->getVectorReg();
235 
236   // We expect a buildVectorTrunc on the Vector register.
237   GBuildVectorTrunc *Build = getOpcodeDef<GBuildVectorTrunc>(Vector, MRI);
238   if (!Build)
239     return false;
240 
241   LLT VectorTy = MRI.getType(Vector);
242 
243   // There is a one-use check. There are more combines on build vectors.
244   EVT Ty(getMVTForLLT(VectorTy));
245   if (!MRI.hasOneNonDBGUse(Build->getReg(0)) ||
246       !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty))
247     return false;
248 
249   Register Index = Extract->getIndexReg();
250 
251   // If the Index is constant, then we can extract the element from the given
252   // offset.
253   std::optional<ValueAndVReg> MaybeIndex =
254       getIConstantVRegValWithLookThrough(Index, MRI);
255   if (!MaybeIndex)
256     return false;
257 
258   // We now know that there is a buildVectorTrunc def'd on the Vector register
259   // and the index is const. The combine will succeed.
260 
261   Register Dst = Extract->getReg(0);
262   LLT DstTy = MRI.getType(Dst);
263   LLT SrcTy = MRI.getType(Build->getSourceReg(0));
264 
265   // For buildVectorTrunc, the inputs are truncated.
266   if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}}))
267     return false;
268 
269   MatchInfo = [=](MachineIRBuilder &B) {
270     B.buildTrunc(Dst, Build->getSourceReg(MaybeIndex->Value.getZExtValue()));
271   };
272 
273   return true;
274 }
275 
276 bool CombinerHelper::matchExtractVectorElementWithShuffleVector(
277     const MachineOperand &MO, BuildFnTy &MatchInfo) {
278   GExtractVectorElement *Extract =
279       cast<GExtractVectorElement>(getDefIgnoringCopies(MO.getReg(), MRI));
280 
281   //
282   //  %zero:_(s64) = G_CONSTANT i64 0
283   //  %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
284   //                     shufflemask(0, 0, 0, 0)
285   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %zero(s64)
286   //
287   //  -->
288   //
289   //  %zero1:_(s64) = G_CONSTANT i64 0
290   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %arg1(<4 x s32>), %zero1(s64)
291   //
292   //
293   //
294   //
295   //  %three:_(s64) = G_CONSTANT i64 3
296   //  %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
297   //                     shufflemask(0, 0, 0, -1)
298   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %three(s64)
299   //
300   //  -->
301   //
302   //  %extract:_(s32) = G_IMPLICIT_DEF
303   //
304   //
305   //
306   //
307   //
308   //  %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
309   //                     shufflemask(0, 0, 0, -1)
310   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %opaque(s64)
311   //
312   //  -->
313   //
314   //  %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>),
315   //                     shufflemask(0, 0, 0, -1)
316   //  %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %opaque(s64)
317   //
318 
319   // We try to get the value of the Index register.
320   std::optional<ValueAndVReg> MaybeIndex =
321       getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
322   if (!MaybeIndex)
323     return false;
324 
325   GShuffleVector *Shuffle =
326       cast<GShuffleVector>(getDefIgnoringCopies(Extract->getVectorReg(), MRI));
327 
328   ArrayRef<int> Mask = Shuffle->getMask();
329 
330   unsigned Offset = MaybeIndex->Value.getZExtValue();
331   int SrcIdx = Mask[Offset];
332 
333   LLT Src1Type = MRI.getType(Shuffle->getSrc1Reg());
334   // At the IR level a <1 x ty> shuffle  vector is valid, but we want to extract
335   // from a vector.
336   assert(Src1Type.isVector() && "expected to extract from a vector");
337   unsigned LHSWidth = Src1Type.isVector() ? Src1Type.getNumElements() : 1;
338 
339   // Note that there is no one use check.
340   Register Dst = Extract->getReg(0);
341   LLT DstTy = MRI.getType(Dst);
342 
343   if (SrcIdx < 0 &&
344       isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
345     MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); };
346     return true;
347   }
348 
349   // If the legality check failed, then we still have to abort.
350   if (SrcIdx < 0)
351     return false;
352 
353   Register NewVector;
354 
355   // We check in which vector and at what offset to look through.
356   if (SrcIdx < (int)LHSWidth) {
357     NewVector = Shuffle->getSrc1Reg();
358     // SrcIdx unchanged
359   } else { // SrcIdx >= LHSWidth
360     NewVector = Shuffle->getSrc2Reg();
361     SrcIdx -= LHSWidth;
362   }
363 
364   LLT IdxTy = MRI.getType(Extract->getIndexReg());
365   LLT NewVectorTy = MRI.getType(NewVector);
366 
367   // We check the legality of the look through.
368   if (!isLegalOrBeforeLegalizer(
369           {TargetOpcode::G_EXTRACT_VECTOR_ELT, {DstTy, NewVectorTy, IdxTy}}) ||
370       !isConstantLegalOrBeforeLegalizer({IdxTy}))
371     return false;
372 
373   // We look through the shuffle vector.
374   MatchInfo = [=](MachineIRBuilder &B) {
375     auto Idx = B.buildConstant(IdxTy, SrcIdx);
376     B.buildExtractVectorElement(Dst, NewVector, Idx);
377   };
378 
379   return true;
380 }
381 
382 bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,
383                                                  BuildFnTy &MatchInfo) {
384   GInsertVectorElement *Insert = cast<GInsertVectorElement>(&MI);
385 
386   Register Dst = Insert->getReg(0);
387   LLT DstTy = MRI.getType(Dst);
388   Register Index = Insert->getIndexReg();
389 
390   if (!DstTy.isFixedVector())
391     return false;
392 
393   std::optional<ValueAndVReg> MaybeIndex =
394       getIConstantVRegValWithLookThrough(Index, MRI);
395 
396   if (MaybeIndex && MaybeIndex->Value.uge(DstTy.getNumElements()) &&
397       isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) {
398     MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); };
399     return true;
400   }
401 
402   return false;
403 }
404 
405 bool CombinerHelper::matchAddOfVScale(const MachineOperand &MO,
406                                       BuildFnTy &MatchInfo) {
407   GAdd *Add = cast<GAdd>(MRI.getVRegDef(MO.getReg()));
408   GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Add->getLHSReg()));
409   GVScale *RHSVScale = cast<GVScale>(MRI.getVRegDef(Add->getRHSReg()));
410 
411   Register Dst = Add->getReg(0);
412 
413   if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)) ||
414       !MRI.hasOneNonDBGUse(RHSVScale->getReg(0)))
415     return false;
416 
417   MatchInfo = [=](MachineIRBuilder &B) {
418     B.buildVScale(Dst, LHSVScale->getSrc() + RHSVScale->getSrc());
419   };
420 
421   return true;
422 }
423 
424 bool CombinerHelper::matchMulOfVScale(const MachineOperand &MO,
425                                       BuildFnTy &MatchInfo) {
426   GMul *Mul = cast<GMul>(MRI.getVRegDef(MO.getReg()));
427   GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Mul->getLHSReg()));
428 
429   std::optional<APInt> MaybeRHS = getIConstantVRegVal(Mul->getRHSReg(), MRI);
430   if (!MaybeRHS)
431     return false;
432 
433   Register Dst = MO.getReg();
434 
435   if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)))
436     return false;
437 
438   MatchInfo = [=](MachineIRBuilder &B) {
439     B.buildVScale(Dst, LHSVScale->getSrc() * *MaybeRHS);
440   };
441 
442   return true;
443 }
444 
445 bool CombinerHelper::matchSubOfVScale(const MachineOperand &MO,
446                                       BuildFnTy &MatchInfo) {
447   GSub *Sub = cast<GSub>(MRI.getVRegDef(MO.getReg()));
448   GVScale *RHSVScale = cast<GVScale>(MRI.getVRegDef(Sub->getRHSReg()));
449 
450   Register Dst = MO.getReg();
451   LLT DstTy = MRI.getType(Dst);
452 
453   if (!MRI.hasOneNonDBGUse(RHSVScale->getReg(0)) ||
454       !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, DstTy}))
455     return false;
456 
457   MatchInfo = [=](MachineIRBuilder &B) {
458     auto VScale = B.buildVScale(DstTy, -RHSVScale->getSrc());
459     B.buildAdd(Dst, Sub->getLHSReg(), VScale, Sub->getFlags());
460   };
461 
462   return true;
463 }
464 
465 bool CombinerHelper::matchShlOfVScale(const MachineOperand &MO,
466                                       BuildFnTy &MatchInfo) {
467   GShl *Shl = cast<GShl>(MRI.getVRegDef(MO.getReg()));
468   GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Shl->getSrcReg()));
469 
470   std::optional<APInt> MaybeRHS = getIConstantVRegVal(Shl->getShiftReg(), MRI);
471   if (!MaybeRHS)
472     return false;
473 
474   Register Dst = MO.getReg();
475   LLT DstTy = MRI.getType(Dst);
476 
477   if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)) ||
478       !isLegalOrBeforeLegalizer({TargetOpcode::G_VSCALE, DstTy}))
479     return false;
480 
481   MatchInfo = [=](MachineIRBuilder &B) {
482     B.buildVScale(Dst, LHSVScale->getSrc().shl(*MaybeRHS));
483   };
484 
485   return true;
486 }
487