xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelperCasts.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- CombinerHelperCasts.cpp---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements CombinerHelper for G_ANYEXT, G_SEXT, G_TRUNC, and
10 // G_ZEXT
11 //
12 //===----------------------------------------------------------------------===//
13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
14 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
15 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
16 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
17 #include "llvm/CodeGen/GlobalISel/Utils.h"
18 #include "llvm/CodeGen/LowLevelTypeUtils.h"
19 #include "llvm/CodeGen/MachineOperand.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22 #include "llvm/Support/Casting.h"
23 
24 #define DEBUG_TYPE "gi-combiner"
25 
26 using namespace llvm;
27 
28 bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO,
29                                       BuildFnTy &MatchInfo) const {
30   GSext *Sext = cast<GSext>(getDefIgnoringCopies(MO.getReg(), MRI));
31   GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Sext->getSrcReg(), MRI));
32 
33   Register Dst = Sext->getReg(0);
34   Register Src = Trunc->getSrcReg();
35 
36   LLT DstTy = MRI.getType(Dst);
37   LLT SrcTy = MRI.getType(Src);
38 
39   // Combines without nsw trunc.
40   if (!Trunc->getFlag(MachineInstr::NoSWrap)) {
41     if (DstTy != SrcTy ||
42         !isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT_INREG, {DstTy, SrcTy}}))
43       return false;
44 
45     // Do this for 8 bit values and up. We don't want to do it for e.g. G_TRUNC
46     // to i1.
47     unsigned TruncWidth = MRI.getType(Trunc->getReg(0)).getScalarSizeInBits();
48     if (TruncWidth < 8)
49       return false;
50 
51     MatchInfo = [=](MachineIRBuilder &B) {
52       B.buildSExtInReg(Dst, Src, TruncWidth);
53     };
54     return true;
55   }
56 
57   // Combines for nsw trunc.
58 
59   if (DstTy == SrcTy) {
60     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
61     return true;
62   }
63 
64   if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
65       isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
66     MatchInfo = [=](MachineIRBuilder &B) {
67       B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoSWrap);
68     };
69     return true;
70   }
71 
72   if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
73       isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}})) {
74     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
75     return true;
76   }
77 
78   return false;
79 }
80 
81 bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO,
82                                       BuildFnTy &MatchInfo) const {
83   GZext *Zext = cast<GZext>(getDefIgnoringCopies(MO.getReg(), MRI));
84   GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Zext->getSrcReg(), MRI));
85 
86   Register Dst = Zext->getReg(0);
87   Register Src = Trunc->getSrcReg();
88 
89   LLT DstTy = MRI.getType(Dst);
90   LLT SrcTy = MRI.getType(Src);
91 
92   if (DstTy == SrcTy) {
93     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
94     return true;
95   }
96 
97   if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
98       isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
99     MatchInfo = [=](MachineIRBuilder &B) {
100       B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoUWrap);
101     };
102     return true;
103   }
104 
105   if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
106       isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) {
107     MatchInfo = [=](MachineIRBuilder &B) {
108       B.buildZExt(Dst, Src, MachineInstr::MIFlag::NonNeg);
109     };
110     return true;
111   }
112 
113   return false;
114 }
115 
116 bool CombinerHelper::matchNonNegZext(const MachineOperand &MO,
117                                      BuildFnTy &MatchInfo) const {
118   GZext *Zext = cast<GZext>(MRI.getVRegDef(MO.getReg()));
119 
120   Register Dst = Zext->getReg(0);
121   Register Src = Zext->getSrcReg();
122 
123   LLT DstTy = MRI.getType(Dst);
124   LLT SrcTy = MRI.getType(Src);
125   const auto &TLI = getTargetLowering();
126 
127   // Convert zext nneg to sext if sext is the preferred form for the target.
128   if (isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}}) &&
129       TLI.isSExtCheaperThanZExt(getMVTForLLT(SrcTy), getMVTForLLT(DstTy))) {
130     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
131     return true;
132   }
133 
134   return false;
135 }
136 
137 bool CombinerHelper::matchTruncateOfExt(const MachineInstr &Root,
138                                         const MachineInstr &ExtMI,
139                                         BuildFnTy &MatchInfo) const {
140   const GTrunc *Trunc = cast<GTrunc>(&Root);
141   const GExtOp *Ext = cast<GExtOp>(&ExtMI);
142 
143   if (!MRI.hasOneNonDBGUse(Ext->getReg(0)))
144     return false;
145 
146   Register Dst = Trunc->getReg(0);
147   Register Src = Ext->getSrcReg();
148   LLT DstTy = MRI.getType(Dst);
149   LLT SrcTy = MRI.getType(Src);
150 
151   if (SrcTy == DstTy) {
152     // The source and the destination are equally sized. We need to copy.
153     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
154 
155     return true;
156   }
157 
158   if (SrcTy.getScalarSizeInBits() < DstTy.getScalarSizeInBits()) {
159     // If the source is smaller than the destination, we need to extend.
160 
161     if (!isLegalOrBeforeLegalizer({Ext->getOpcode(), {DstTy, SrcTy}}))
162       return false;
163 
164     MatchInfo = [=](MachineIRBuilder &B) {
165       B.buildInstr(Ext->getOpcode(), {Dst}, {Src});
166     };
167 
168     return true;
169   }
170 
171   if (SrcTy.getScalarSizeInBits() > DstTy.getScalarSizeInBits()) {
172     // If the source is larger than the destination, then we need to truncate.
173 
174     if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}}))
175       return false;
176 
177     MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Dst, Src); };
178 
179     return true;
180   }
181 
182   return false;
183 }
184 
185 bool CombinerHelper::isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const {
186   const TargetLowering &TLI = getTargetLowering();
187   LLVMContext &Ctx = getContext();
188 
189   switch (Opcode) {
190   case TargetOpcode::G_ANYEXT:
191   case TargetOpcode::G_ZEXT:
192     return TLI.isZExtFree(FromTy, ToTy, Ctx);
193   case TargetOpcode::G_TRUNC:
194     return TLI.isTruncateFree(FromTy, ToTy, Ctx);
195   default:
196     return false;
197   }
198 }
199 
200 bool CombinerHelper::matchCastOfSelect(const MachineInstr &CastMI,
201                                        const MachineInstr &SelectMI,
202                                        BuildFnTy &MatchInfo) const {
203   const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
204   const GSelect *Select = cast<GSelect>(&SelectMI);
205 
206   if (!MRI.hasOneNonDBGUse(Select->getReg(0)))
207     return false;
208 
209   Register Dst = Cast->getReg(0);
210   LLT DstTy = MRI.getType(Dst);
211   LLT CondTy = MRI.getType(Select->getCondReg());
212   Register TrueReg = Select->getTrueReg();
213   Register FalseReg = Select->getFalseReg();
214   LLT SrcTy = MRI.getType(TrueReg);
215   Register Cond = Select->getCondReg();
216 
217   if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SELECT, {DstTy, CondTy}}))
218     return false;
219 
220   if (!isCastFree(Cast->getOpcode(), DstTy, SrcTy))
221     return false;
222 
223   MatchInfo = [=](MachineIRBuilder &B) {
224     auto True = B.buildInstr(Cast->getOpcode(), {DstTy}, {TrueReg});
225     auto False = B.buildInstr(Cast->getOpcode(), {DstTy}, {FalseReg});
226     B.buildSelect(Dst, Cond, True, False);
227   };
228 
229   return true;
230 }
231 
232 bool CombinerHelper::matchExtOfExt(const MachineInstr &FirstMI,
233                                    const MachineInstr &SecondMI,
234                                    BuildFnTy &MatchInfo) const {
235   const GExtOp *First = cast<GExtOp>(&FirstMI);
236   const GExtOp *Second = cast<GExtOp>(&SecondMI);
237 
238   Register Dst = First->getReg(0);
239   Register Src = Second->getSrcReg();
240   LLT DstTy = MRI.getType(Dst);
241   LLT SrcTy = MRI.getType(Src);
242 
243   if (!MRI.hasOneNonDBGUse(Second->getReg(0)))
244     return false;
245 
246   // ext of ext -> later ext
247   if (First->getOpcode() == Second->getOpcode() &&
248       isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) {
249     if (Second->getOpcode() == TargetOpcode::G_ZEXT) {
250       MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
251       if (Second->getFlag(MachineInstr::MIFlag::NonNeg))
252         Flag = MachineInstr::MIFlag::NonNeg;
253       MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
254       return true;
255     }
256     // not zext -> no flags
257     MatchInfo = [=](MachineIRBuilder &B) {
258       B.buildInstr(Second->getOpcode(), {Dst}, {Src});
259     };
260     return true;
261   }
262 
263   // anyext of sext/zext  -> sext/zext
264   // -> pick anyext as second ext, then ext of ext
265   if (First->getOpcode() == TargetOpcode::G_ANYEXT &&
266       isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) {
267     if (Second->getOpcode() == TargetOpcode::G_ZEXT) {
268       MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
269       if (Second->getFlag(MachineInstr::MIFlag::NonNeg))
270         Flag = MachineInstr::MIFlag::NonNeg;
271       MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
272       return true;
273     }
274     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
275     return true;
276   }
277 
278   // sext/zext of anyext -> sext/zext
279   // -> pick anyext as first ext, then ext of ext
280   if (Second->getOpcode() == TargetOpcode::G_ANYEXT &&
281       isLegalOrBeforeLegalizer({First->getOpcode(), {DstTy, SrcTy}})) {
282     if (First->getOpcode() == TargetOpcode::G_ZEXT) {
283       MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
284       if (First->getFlag(MachineInstr::MIFlag::NonNeg))
285         Flag = MachineInstr::MIFlag::NonNeg;
286       MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
287       return true;
288     }
289     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
290     return true;
291   }
292 
293   return false;
294 }
295 
296 bool CombinerHelper::matchCastOfBuildVector(const MachineInstr &CastMI,
297                                             const MachineInstr &BVMI,
298                                             BuildFnTy &MatchInfo) const {
299   const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
300   const GBuildVector *BV = cast<GBuildVector>(&BVMI);
301 
302   if (!MRI.hasOneNonDBGUse(BV->getReg(0)))
303     return false;
304 
305   Register Dst = Cast->getReg(0);
306   // The type of the new build vector.
307   LLT DstTy = MRI.getType(Dst);
308   // The scalar or element type of the new build vector.
309   LLT ElemTy = DstTy.getScalarType();
310   // The scalar or element type of the old build vector.
311   LLT InputElemTy = MRI.getType(BV->getReg(0)).getElementType();
312 
313   // Check legality of new build vector, the scalar casts, and profitability of
314   // the many casts.
315   if (!isLegalOrBeforeLegalizer(
316           {TargetOpcode::G_BUILD_VECTOR, {DstTy, ElemTy}}) ||
317       !isLegalOrBeforeLegalizer({Cast->getOpcode(), {ElemTy, InputElemTy}}) ||
318       !isCastFree(Cast->getOpcode(), ElemTy, InputElemTy))
319     return false;
320 
321   MatchInfo = [=](MachineIRBuilder &B) {
322     SmallVector<Register> Casts;
323     unsigned Elements = BV->getNumSources();
324     for (unsigned I = 0; I < Elements; ++I) {
325       auto CastI =
326           B.buildInstr(Cast->getOpcode(), {ElemTy}, {BV->getSourceReg(I)});
327       Casts.push_back(CastI.getReg(0));
328     }
329 
330     B.buildBuildVector(Dst, Casts);
331   };
332 
333   return true;
334 }
335 
336 bool CombinerHelper::matchNarrowBinop(const MachineInstr &TruncMI,
337                                       const MachineInstr &BinopMI,
338                                       BuildFnTy &MatchInfo) const {
339   const GTrunc *Trunc = cast<GTrunc>(&TruncMI);
340   const GBinOp *BinOp = cast<GBinOp>(&BinopMI);
341 
342   if (!MRI.hasOneNonDBGUse(BinOp->getReg(0)))
343     return false;
344 
345   Register Dst = Trunc->getReg(0);
346   LLT DstTy = MRI.getType(Dst);
347 
348   // Is narrow binop legal?
349   if (!isLegalOrBeforeLegalizer({BinOp->getOpcode(), {DstTy}}))
350     return false;
351 
352   MatchInfo = [=](MachineIRBuilder &B) {
353     auto LHS = B.buildTrunc(DstTy, BinOp->getLHSReg());
354     auto RHS = B.buildTrunc(DstTy, BinOp->getRHSReg());
355     B.buildInstr(BinOp->getOpcode(), {Dst}, {LHS, RHS});
356   };
357 
358   return true;
359 }
360 
361 bool CombinerHelper::matchCastOfInteger(const MachineInstr &CastMI,
362                                         APInt &MatchInfo) const {
363   const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
364 
365   APInt Input = getIConstantFromReg(Cast->getSrcReg(), MRI);
366 
367   LLT DstTy = MRI.getType(Cast->getReg(0));
368 
369   if (!isConstantLegalOrBeforeLegalizer(DstTy))
370     return false;
371 
372   switch (Cast->getOpcode()) {
373   case TargetOpcode::G_TRUNC: {
374     MatchInfo = Input.trunc(DstTy.getScalarSizeInBits());
375     return true;
376   }
377   default:
378     return false;
379   }
380 }
381 
382 bool CombinerHelper::matchRedundantSextInReg(MachineInstr &Root,
383                                              MachineInstr &Other,
384                                              BuildFnTy &MatchInfo) const {
385   assert(Root.getOpcode() == TargetOpcode::G_SEXT_INREG &&
386          Other.getOpcode() == TargetOpcode::G_SEXT_INREG);
387 
388   unsigned RootWidth = Root.getOperand(2).getImm();
389   unsigned OtherWidth = Other.getOperand(2).getImm();
390 
391   Register Dst = Root.getOperand(0).getReg();
392   Register OtherDst = Other.getOperand(0).getReg();
393   Register Src = Other.getOperand(1).getReg();
394 
395   if (RootWidth >= OtherWidth) {
396     // The root sext_inreg is entirely redundant because the other one
397     // is narrower.
398     if (!canReplaceReg(Dst, OtherDst, MRI))
399       return false;
400 
401     MatchInfo = [=](MachineIRBuilder &B) {
402       Observer.changingAllUsesOfReg(MRI, Dst);
403       MRI.replaceRegWith(Dst, OtherDst);
404       Observer.finishedChangingAllUsesOfReg();
405     };
406   } else {
407     // RootWidth < OtherWidth, rewrite this G_SEXT_INREG with the source of the
408     // other G_SEXT_INREG.
409     MatchInfo = [=](MachineIRBuilder &B) {
410       B.buildSExtInReg(Dst, Src, RootWidth);
411     };
412   }
413 
414   return true;
415 }
416