//===-- llvm/CodeGen/GlobalISel/LegalizerHelper.cpp -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This file implements the LegalizerHelper class to legalize
/// individual instructions and the LegalizeMachineIR wrapper pass for the
/// primary legalization.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/CallLowering.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
#include "llvm/CodeGen/GlobalISel/LostDebugLocObserver.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RuntimeLibcallUtil.h"
#include "llvm/CodeGen/TargetFrameLowering.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include <numeric>
#include <optional>

#define DEBUG_TYPE "legalizer"

using namespace llvm;
using namespace LegalizeActions;
using namespace MIPatternMatch;

/// Try to break down \p OrigTy into \p NarrowTy sized pieces.
///
/// Returns the number of \p NarrowTy elements needed to reconstruct \p OrigTy,
/// with any leftover piece as type \p LeftoverTy.
///
/// Returns -1 in the first element of the pair if the breakdown is not
/// satisfiable.
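///
/// Illustrative examples (not from the original source): breaking s96 into
/// s32 pieces yields {3, 0} with no leftover; breaking s70 into s32 pieces
/// yields {2, 1} with \p LeftoverTy = s6.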
static std::pair<int, int>
getNarrowTypeBreakDown(LLT OrigTy, LLT NarrowTy, LLT &LeftoverTy) {
  assert(!LeftoverTy.isValid() && "this is an out argument");

  unsigned Size = OrigTy.getSizeInBits();
  unsigned NarrowSize = NarrowTy.getSizeInBits();
  unsigned NumParts = Size / NarrowSize;
  unsigned LeftoverSize = Size - NumParts * NarrowSize;
  assert(Size > NarrowSize);

  if (LeftoverSize == 0)
    return {NumParts, 0};

  if (NarrowTy.isVector()) {
    unsigned EltSize = OrigTy.getScalarSizeInBits();
    if (LeftoverSize % EltSize != 0)
      return {-1, -1};
    LeftoverTy =
        LLT::scalarOrVector(ElementCount::getFixed(LeftoverSize / EltSize),
                            OrigTy.getElementType());
  } else {
    LeftoverTy = LLT::scalar(LeftoverSize);
  }

  int NumLeftover = LeftoverSize / LeftoverTy.getSizeInBits();
  return std::make_pair(NumParts, NumLeftover);
}

static Type *getFloatTypeForLLT(LLVMContext &Ctx, LLT Ty) {

  if (!Ty.isScalar())
    return nullptr;

  switch (Ty.getSizeInBits()) {
  case 16:
    return Type::getHalfTy(Ctx);
  case 32:
    return Type::getFloatTy(Ctx);
  case 64:
    return Type::getDoubleTy(Ctx);
  case 80:
    return Type::getX86_FP80Ty(Ctx);
  case 128:
    return Type::getFP128Ty(Ctx);
  default:
    return nullptr;
  }
}

LegalizerHelper::LegalizerHelper(MachineFunction &MF,
                                 GISelChangeObserver &Observer,
                                 MachineIRBuilder &Builder)
    : MIRBuilder(Builder), Observer(Observer), MRI(MF.getRegInfo()),
      LI(*MF.getSubtarget().getLegalizerInfo()),
      TLI(*MF.getSubtarget().getTargetLowering()), KB(nullptr) {}

LegalizerHelper::LegalizerHelper(MachineFunction &MF, const LegalizerInfo &LI,
                                 GISelChangeObserver &Observer,
                                 MachineIRBuilder &B, GISelKnownBits *KB)
    : MIRBuilder(B), Observer(Observer), MRI(MF.getRegInfo()), LI(LI),
      TLI(*MF.getSubtarget().getTargetLowering()), KB(KB) {}

LegalizerHelper::LegalizeResult
LegalizerHelper::legalizeInstrStep(MachineInstr &MI,
                                   LostDebugLocObserver &LocObserver) {
  LLVM_DEBUG(dbgs() << "Legalizing: " << MI);

  MIRBuilder.setInstrAndDebugLoc(MI);

  if (isa<GIntrinsic>(MI))
    return LI.legalizeIntrinsic(*this, MI) ? Legalized : UnableToLegalize;
  auto Step = LI.getAction(MI, MRI);
  switch (Step.Action) {
  case Legal:
    LLVM_DEBUG(dbgs() << ".. Already legal\n");
    return AlreadyLegal;
  case Libcall:
    LLVM_DEBUG(dbgs() << ".. Convert to libcall\n");
    return libcall(MI, LocObserver);
  case NarrowScalar:
    LLVM_DEBUG(dbgs() << ".. Narrow scalar\n");
    return narrowScalar(MI, Step.TypeIdx, Step.NewType);
  case WidenScalar:
    LLVM_DEBUG(dbgs() << ".. Widen scalar\n");
    return widenScalar(MI, Step.TypeIdx, Step.NewType);
  case Bitcast:
    LLVM_DEBUG(dbgs() << ".. Bitcast type\n");
    return bitcast(MI, Step.TypeIdx, Step.NewType);
  case Lower:
    LLVM_DEBUG(dbgs() << ".. Lower\n");
    return lower(MI, Step.TypeIdx, Step.NewType);
  case FewerElements:
    LLVM_DEBUG(dbgs() << ".. Reduce number of elements\n");
    return fewerElementsVector(MI, Step.TypeIdx, Step.NewType);
  case MoreElements:
    LLVM_DEBUG(dbgs() << ".. Increase number of elements\n");
    return moreElementsVector(MI, Step.TypeIdx, Step.NewType);
  case Custom:
    LLVM_DEBUG(dbgs() << ".. Custom legalization\n");
    return LI.legalizeCustom(*this, MI, LocObserver) ? Legalized
                                                     : UnableToLegalize;
  default:
    LLVM_DEBUG(dbgs() << ".. Unable to legalize\n");
    return UnableToLegalize;
  }
}

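// Illustrative note (not from the original source): insertParts is the
// inverse of splitting a value into NarrowTy-sized pieces plus an optional
// leftover; e.g. two s32 parts and an s6 leftover are recombined into an s70
// result.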
void LegalizerHelper::insertParts(Register DstReg,
                                  LLT ResultTy, LLT PartTy,
                                  ArrayRef<Register> PartRegs,
                                  LLT LeftoverTy,
                                  ArrayRef<Register> LeftoverRegs) {
  if (!LeftoverTy.isValid()) {
    assert(LeftoverRegs.empty());

    if (!ResultTy.isVector()) {
      MIRBuilder.buildMergeLikeInstr(DstReg, PartRegs);
      return;
    }

    if (PartTy.isVector())
      MIRBuilder.buildConcatVectors(DstReg, PartRegs);
    else
      MIRBuilder.buildBuildVector(DstReg, PartRegs);
    return;
  }

  // Merge sub-vectors with different numbers of elements and insert into DstReg.
  if (ResultTy.isVector()) {
    assert(LeftoverRegs.size() == 1 && "Expected one leftover register");
    SmallVector<Register, 8> AllRegs;
    for (auto Reg : concat<const Register>(PartRegs, LeftoverRegs))
      AllRegs.push_back(Reg);
    return mergeMixedSubvectors(DstReg, AllRegs);
  }

  SmallVector<Register> GCDRegs;
  LLT GCDTy = getGCDType(getGCDType(ResultTy, LeftoverTy), PartTy);
  for (auto PartReg : concat<const Register>(PartRegs, LeftoverRegs))
    extractGCDType(GCDRegs, GCDTy, PartReg);
  LLT ResultLCMTy = buildLCMMergePieces(ResultTy, LeftoverTy, GCDTy, GCDRegs);
  buildWidenedRemergeToDst(DstReg, ResultLCMTy, GCDRegs);
}

void LegalizerHelper::appendVectorElts(SmallVectorImpl<Register> &Elts,
                                       Register Reg) {
  LLT Ty = MRI.getType(Reg);
  SmallVector<Register, 8> RegElts;
  extractParts(Reg, Ty.getScalarType(), Ty.getNumElements(), RegElts,
               MIRBuilder, MRI);
  Elts.append(RegElts);
}

/// Merge \p PartRegs with different types into \p DstReg.
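///
/// Illustrative example (not from the original source): two <2 x s32> parts
/// followed by an s32 leftover are flattened into five s32 elements and then
/// merged into a <5 x s32> destination.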
void LegalizerHelper::mergeMixedSubvectors(Register DstReg,
                                           ArrayRef<Register> PartRegs) {
  SmallVector<Register, 8> AllElts;
  for (unsigned i = 0; i < PartRegs.size() - 1; ++i)
    appendVectorElts(AllElts, PartRegs[i]);

  Register Leftover = PartRegs[PartRegs.size() - 1];
  if (!MRI.getType(Leftover).isVector())
    AllElts.push_back(Leftover);
  else
    appendVectorElts(AllElts, Leftover);

  MIRBuilder.buildMergeLikeInstr(DstReg, AllElts);
}

/// Append the result registers of G_UNMERGE_VALUES \p MI to \p Regs.
static void getUnmergeResults(SmallVectorImpl<Register> &Regs,
                              const MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES);

  const int StartIdx = Regs.size();
  const int NumResults = MI.getNumOperands() - 1;
  Regs.resize(Regs.size() + NumResults);
  for (int I = 0; I != NumResults; ++I)
    Regs[StartIdx + I] = MI.getOperand(I).getReg();
}

void LegalizerHelper::extractGCDType(SmallVectorImpl<Register> &Parts,
                                     LLT GCDTy, Register SrcReg) {
  LLT SrcTy = MRI.getType(SrcReg);
  if (SrcTy == GCDTy) {
    // If the source already evenly divides the result type, we don't need to do
    // anything.
    Parts.push_back(SrcReg);
  } else {
    // Need to split into common type sized pieces.
    auto Unmerge = MIRBuilder.buildUnmerge(GCDTy, SrcReg);
    getUnmergeResults(Parts, *Unmerge);
  }
}

LLT LegalizerHelper::extractGCDType(SmallVectorImpl<Register> &Parts, LLT DstTy,
                                    LLT NarrowTy, Register SrcReg) {
  LLT SrcTy = MRI.getType(SrcReg);
  LLT GCDTy = getGCDType(getGCDType(SrcTy, NarrowTy), DstTy);
  extractGCDType(Parts, GCDTy, SrcReg);
  return GCDTy;
}

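// Illustrative example for the LCM re-merge strategy (not from the original
// source): with DstTy = s64, NarrowTy = s48 and GCDTy = s16, LCMTy is s192,
// so NumParts = 192/48 = 4 and NumSubParts = 48/16 = 3; any of the 12 s16
// slots not covered by the original sources is filled with padding chosen by
// PadStrategy.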
LLT LegalizerHelper::buildLCMMergePieces(LLT DstTy, LLT NarrowTy, LLT GCDTy,
                                         SmallVectorImpl<Register> &VRegs,
                                         unsigned PadStrategy) {
  LLT LCMTy = getLCMType(DstTy, NarrowTy);

  int NumParts = LCMTy.getSizeInBits() / NarrowTy.getSizeInBits();
  int NumSubParts = NarrowTy.getSizeInBits() / GCDTy.getSizeInBits();
  int NumOrigSrc = VRegs.size();

  Register PadReg;

  // Get a value we can use to pad the source value if the sources won't evenly
  // cover the result type.
  if (NumOrigSrc < NumParts * NumSubParts) {
    if (PadStrategy == TargetOpcode::G_ZEXT)
      PadReg = MIRBuilder.buildConstant(GCDTy, 0).getReg(0);
    else if (PadStrategy == TargetOpcode::G_ANYEXT)
      PadReg = MIRBuilder.buildUndef(GCDTy).getReg(0);
    else {
      assert(PadStrategy == TargetOpcode::G_SEXT);

      // Shift the sign bit of the low register through the high register.
      auto ShiftAmt =
        MIRBuilder.buildConstant(LLT::scalar(64), GCDTy.getSizeInBits() - 1);
      PadReg = MIRBuilder.buildAShr(GCDTy, VRegs.back(), ShiftAmt).getReg(0);
    }
  }

  // Registers for the final merge to be produced.
  SmallVector<Register, 4> Remerge(NumParts);

  // Registers needed for intermediate merges, which will be merged into a
  // source for Remerge.
  SmallVector<Register, 4> SubMerge(NumSubParts);

  // Once we've fully read off the end of the original source bits, we can reuse
  // the same high bits for remaining padding elements.
  Register AllPadReg;

  // Build merges to the LCM type to cover the original result type.
  for (int I = 0; I != NumParts; ++I) {
    bool AllMergePartsArePadding = true;

    // Build the requested merges to the requested type.
    for (int J = 0; J != NumSubParts; ++J) {
      int Idx = I * NumSubParts + J;
      if (Idx >= NumOrigSrc) {
        SubMerge[J] = PadReg;
        continue;
      }

      SubMerge[J] = VRegs[Idx];

      // There are meaningful bits here we can't reuse later.
      AllMergePartsArePadding = false;
    }

    // If we've filled up a complete piece with padding bits, we can directly
    // emit the natural sized constant if applicable, rather than a merge of
    // smaller constants.
    if (AllMergePartsArePadding && !AllPadReg) {
      if (PadStrategy == TargetOpcode::G_ANYEXT)
        AllPadReg = MIRBuilder.buildUndef(NarrowTy).getReg(0);
      else if (PadStrategy == TargetOpcode::G_ZEXT)
        AllPadReg = MIRBuilder.buildConstant(NarrowTy, 0).getReg(0);

      // If this is a sign extension, we can't materialize a trivial constant
      // with the right type and have to produce a merge.
    }

    if (AllPadReg) {
      // Avoid creating additional instructions if we're just adding additional
      // copies of padding bits.
      Remerge[I] = AllPadReg;
      continue;
    }

    if (NumSubParts == 1)
      Remerge[I] = SubMerge[0];
    else
      Remerge[I] = MIRBuilder.buildMergeLikeInstr(NarrowTy, SubMerge).getReg(0);

    // In the sign extend padding case, re-use the first all-signbit merge.
    if (AllMergePartsArePadding && !AllPadReg)
      AllPadReg = Remerge[I];
  }

  VRegs = std::move(Remerge);
  return LCMTy;
}

void LegalizerHelper::buildWidenedRemergeToDst(Register DstReg, LLT LCMTy,
                                               ArrayRef<Register> RemergeRegs) {
  LLT DstTy = MRI.getType(DstReg);

  // Create the merge to the widened source, and extract the relevant bits into
  // the result.

  if (DstTy == LCMTy) {
    MIRBuilder.buildMergeLikeInstr(DstReg, RemergeRegs);
    return;
  }

  auto Remerge = MIRBuilder.buildMergeLikeInstr(LCMTy, RemergeRegs);
  if (DstTy.isScalar() && LCMTy.isScalar()) {
    MIRBuilder.buildTrunc(DstReg, Remerge);
    return;
  }

  if (LCMTy.isVector()) {
    unsigned NumDefs = LCMTy.getSizeInBits() / DstTy.getSizeInBits();
    SmallVector<Register, 8> UnmergeDefs(NumDefs);
    UnmergeDefs[0] = DstReg;
    for (unsigned I = 1; I != NumDefs; ++I)
      UnmergeDefs[I] = MRI.createGenericVirtualRegister(DstTy);

    MIRBuilder.buildUnmerge(UnmergeDefs,
                            MIRBuilder.buildMergeLikeInstr(LCMTy, RemergeRegs));
    return;
  }

  llvm_unreachable("unhandled case");
}

static RTLIB::Libcall getRTLibDesc(unsigned Opcode, unsigned Size) {
#define RTLIBCASE_INT(LibcallPrefix)                                           \
  do {                                                                         \
    switch (Size) {                                                            \
    case 32:                                                                   \
      return RTLIB::LibcallPrefix##32;                                         \
    case 64:                                                                   \
      return RTLIB::LibcallPrefix##64;                                         \
    case 128:                                                                  \
      return RTLIB::LibcallPrefix##128;                                        \
    default:                                                                   \
      llvm_unreachable("unexpected size");                                     \
    }                                                                          \
  } while (0)

#define RTLIBCASE(LibcallPrefix)                                               \
  do {                                                                         \
    switch (Size) {                                                            \
    case 32:                                                                   \
      return RTLIB::LibcallPrefix##32;                                         \
    case 64:                                                                   \
      return RTLIB::LibcallPrefix##64;                                         \
    case 80:                                                                   \
      return RTLIB::LibcallPrefix##80;                                         \
    case 128:                                                                  \
      return RTLIB::LibcallPrefix##128;                                        \
    default:                                                                   \
      llvm_unreachable("unexpected size");                                     \
    }                                                                          \
  } while (0)

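  // For example, RTLIBCASE(ADD_F) selects RTLIB::ADD_F32/ADD_F64/ADD_F80/
  // ADD_F128 based on Size; RTLIBCASE_INT is the integer analogue without the
  // 80-bit case.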
  switch (Opcode) {
  case TargetOpcode::G_MUL:
    RTLIBCASE_INT(MUL_I);
  case TargetOpcode::G_SDIV:
    RTLIBCASE_INT(SDIV_I);
  case TargetOpcode::G_UDIV:
    RTLIBCASE_INT(UDIV_I);
  case TargetOpcode::G_SREM:
    RTLIBCASE_INT(SREM_I);
  case TargetOpcode::G_UREM:
    RTLIBCASE_INT(UREM_I);
  case TargetOpcode::G_CTLZ_ZERO_UNDEF:
    RTLIBCASE_INT(CTLZ_I);
  case TargetOpcode::G_FADD:
    RTLIBCASE(ADD_F);
  case TargetOpcode::G_FSUB:
    RTLIBCASE(SUB_F);
  case TargetOpcode::G_FMUL:
    RTLIBCASE(MUL_F);
  case TargetOpcode::G_FDIV:
    RTLIBCASE(DIV_F);
  case TargetOpcode::G_FEXP:
    RTLIBCASE(EXP_F);
  case TargetOpcode::G_FEXP2:
    RTLIBCASE(EXP2_F);
  case TargetOpcode::G_FEXP10:
    RTLIBCASE(EXP10_F);
  case TargetOpcode::G_FREM:
    RTLIBCASE(REM_F);
  case TargetOpcode::G_FPOW:
    RTLIBCASE(POW_F);
  case TargetOpcode::G_FPOWI:
    RTLIBCASE(POWI_F);
  case TargetOpcode::G_FMA:
    RTLIBCASE(FMA_F);
  case TargetOpcode::G_FSIN:
    RTLIBCASE(SIN_F);
  case TargetOpcode::G_FCOS:
    RTLIBCASE(COS_F);
  case TargetOpcode::G_FTAN:
    RTLIBCASE(TAN_F);
  case TargetOpcode::G_FASIN:
    RTLIBCASE(ASIN_F);
  case TargetOpcode::G_FACOS:
    RTLIBCASE(ACOS_F);
  case TargetOpcode::G_FATAN:
    RTLIBCASE(ATAN_F);
  case TargetOpcode::G_FSINH:
    RTLIBCASE(SINH_F);
  case TargetOpcode::G_FCOSH:
    RTLIBCASE(COSH_F);
  case TargetOpcode::G_FTANH:
    RTLIBCASE(TANH_F);
  case TargetOpcode::G_FLOG10:
    RTLIBCASE(LOG10_F);
  case TargetOpcode::G_FLOG:
    RTLIBCASE(LOG_F);
  case TargetOpcode::G_FLOG2:
    RTLIBCASE(LOG2_F);
  case TargetOpcode::G_FLDEXP:
    RTLIBCASE(LDEXP_F);
  case TargetOpcode::G_FCEIL:
    RTLIBCASE(CEIL_F);
  case TargetOpcode::G_FFLOOR:
    RTLIBCASE(FLOOR_F);
  case TargetOpcode::G_FMINNUM:
    RTLIBCASE(FMIN_F);
  case TargetOpcode::G_FMAXNUM:
    RTLIBCASE(FMAX_F);
  case TargetOpcode::G_FSQRT:
    RTLIBCASE(SQRT_F);
  case TargetOpcode::G_FRINT:
    RTLIBCASE(RINT_F);
  case TargetOpcode::G_FNEARBYINT:
    RTLIBCASE(NEARBYINT_F);
  case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
    RTLIBCASE(ROUNDEVEN_F);
  case TargetOpcode::G_INTRINSIC_LRINT:
    RTLIBCASE(LRINT_F);
  case TargetOpcode::G_INTRINSIC_LLRINT:
    RTLIBCASE(LLRINT_F);
  }
  llvm_unreachable("Unknown libcall function");
}

/// True if an instruction is in tail position in its caller. Intended for
/// legalizing libcalls as tail calls when possible.
static bool isLibCallInTailPosition(const CallLowering::ArgInfo &Result,
                                    MachineInstr &MI,
                                    const TargetInstrInfo &TII,
                                    MachineRegisterInfo &MRI) {
  MachineBasicBlock &MBB = *MI.getParent();
  const Function &F = MBB.getParent()->getFunction();

  // Conservatively require the attributes of the call to match those of
  // the return. Ignore NoAlias and NonNull because they don't affect the
  // call sequence.
  AttributeList CallerAttrs = F.getAttributes();
  if (AttrBuilder(F.getContext(), CallerAttrs.getRetAttrs())
          .removeAttribute(Attribute::NoAlias)
          .removeAttribute(Attribute::NonNull)
          .hasAttributes())
    return false;

  // It's not safe to eliminate the sign / zero extension of the return value.
  if (CallerAttrs.hasRetAttr(Attribute::ZExt) ||
      CallerAttrs.hasRetAttr(Attribute::SExt))
    return false;

  // Only tail call if the following instruction is a standard return or if we
  // have a `thisreturn` callee, and a sequence like:
  //
  //   G_MEMCPY %0, %1, %2
  //   $x0 = COPY %0
  //   RET_ReallyLR implicit $x0
  auto Next = next_nodbg(MI.getIterator(), MBB.instr_end());
  if (Next != MBB.instr_end() && Next->isCopy()) {
    if (MI.getOpcode() == TargetOpcode::G_BZERO)
      return false;

    // For MEMCPY/MEMMOVE/MEMSET these will be the first use (the dst), as the
    // memcpy/etc. routines return the same parameter. For others it will be
    // the returned value.
    Register VReg = MI.getOperand(0).getReg();
    if (!VReg.isVirtual() || VReg != Next->getOperand(1).getReg())
      return false;

    Register PReg = Next->getOperand(0).getReg();
    if (!PReg.isPhysical())
      return false;

    auto Ret = next_nodbg(Next, MBB.instr_end());
    if (Ret == MBB.instr_end() || !Ret->isReturn())
      return false;

    if (Ret->getNumImplicitOperands() != 1)
      return false;

    if (!Ret->getOperand(0).isReg() || PReg != Ret->getOperand(0).getReg())
      return false;

    // Skip over the COPY that we just validated.
    Next = Ret;
  }

  if (Next == MBB.instr_end() || TII.isTailCall(*Next) || !Next->isReturn())
    return false;

  return true;
}

LegalizerHelper::LegalizeResult
llvm::createLibcall(MachineIRBuilder &MIRBuilder, const char *Name,
                    const CallLowering::ArgInfo &Result,
                    ArrayRef<CallLowering::ArgInfo> Args,
                    const CallingConv::ID CC, LostDebugLocObserver &LocObserver,
                    MachineInstr *MI) {
  auto &CLI = *MIRBuilder.getMF().getSubtarget().getCallLowering();

  CallLowering::CallLoweringInfo Info;
  Info.CallConv = CC;
  Info.Callee = MachineOperand::CreateES(Name);
  Info.OrigRet = Result;
  if (MI)
    Info.IsTailCall =
        (Result.Ty->isVoidTy() ||
         Result.Ty == MIRBuilder.getMF().getFunction().getReturnType()) &&
        isLibCallInTailPosition(Result, *MI, MIRBuilder.getTII(),
                                *MIRBuilder.getMRI());

  std::copy(Args.begin(), Args.end(), std::back_inserter(Info.OrigArgs));
  if (!CLI.lowerCall(MIRBuilder, Info))
    return LegalizerHelper::UnableToLegalize;

  if (MI && Info.LoweredTailCall) {
    assert(Info.IsTailCall && "Lowered tail call when it wasn't a tail call?");

    // Check debug locations before removing the return.
    LocObserver.checkpoint(true);

    // We must have a return following the call (or debug insts) to get past
    // isLibCallInTailPosition.
    do {
      MachineInstr *Next = MI->getNextNode();
      assert(Next &&
             (Next->isCopy() || Next->isReturn() || Next->isDebugInstr()) &&
             "Expected instr following MI to be return or debug inst?");
      // We lowered a tail call, so the call is now the return from the block.
      // Delete the old return.
      Next->eraseFromParent();
    } while (MI->getNextNode());

    // We expect to lose the debug location from the return.
    LocObserver.checkpoint(false);
  }
  return LegalizerHelper::Legalized;
}

LegalizerHelper::LegalizeResult
llvm::createLibcall(MachineIRBuilder &MIRBuilder, RTLIB::Libcall Libcall,
                    const CallLowering::ArgInfo &Result,
                    ArrayRef<CallLowering::ArgInfo> Args,
                    LostDebugLocObserver &LocObserver, MachineInstr *MI) {
  auto &TLI = *MIRBuilder.getMF().getSubtarget().getTargetLowering();
  const char *Name = TLI.getLibcallName(Libcall);
  if (!Name)
    return LegalizerHelper::UnableToLegalize;
  const CallingConv::ID CC = TLI.getLibcallCallingConv(Libcall);
  return createLibcall(MIRBuilder, Name, Result, Args, CC, LocObserver, MI);
}

// Useful for libcalls where all operands have the same type.
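// For example, a 64-bit G_FREM maps to RTLIB::REM_F64, i.e. a call to fmod.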
static LegalizerHelper::LegalizeResult
simpleLibcall(MachineInstr &MI, MachineIRBuilder &MIRBuilder, unsigned Size,
              Type *OpType, LostDebugLocObserver &LocObserver) {
  auto Libcall = getRTLibDesc(MI.getOpcode(), Size);

  // FIXME: What does the original arg index mean here?
  SmallVector<CallLowering::ArgInfo, 3> Args;
  for (const MachineOperand &MO : llvm::drop_begin(MI.operands()))
    Args.push_back({MO.getReg(), OpType, 0});
  return createLibcall(MIRBuilder, Libcall,
                       {MI.getOperand(0).getReg(), OpType, 0}, Args,
                       LocObserver, &MI);
}

LegalizerHelper::LegalizeResult
llvm::createMemLibcall(MachineIRBuilder &MIRBuilder, MachineRegisterInfo &MRI,
                       MachineInstr &MI, LostDebugLocObserver &LocObserver) {
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  SmallVector<CallLowering::ArgInfo, 3> Args;
  // Add all the args, except for the last which is an imm denoting 'tail'.
  for (unsigned i = 0; i < MI.getNumOperands() - 1; ++i) {
    Register Reg = MI.getOperand(i).getReg();

    // Need to derive an IR type for call lowering.
    LLT OpLLT = MRI.getType(Reg);
    Type *OpTy = nullptr;
    if (OpLLT.isPointer())
      OpTy = PointerType::get(Ctx, OpLLT.getAddressSpace());
    else
      OpTy = IntegerType::get(Ctx, OpLLT.getSizeInBits());
    Args.push_back({Reg, OpTy, 0});
  }

  auto &CLI = *MIRBuilder.getMF().getSubtarget().getCallLowering();
  auto &TLI = *MIRBuilder.getMF().getSubtarget().getTargetLowering();
  RTLIB::Libcall RTLibcall;
  unsigned Opc = MI.getOpcode();
  switch (Opc) {
  case TargetOpcode::G_BZERO:
    RTLibcall = RTLIB::BZERO;
    break;
  case TargetOpcode::G_MEMCPY:
    RTLibcall = RTLIB::MEMCPY;
    Args[0].Flags[0].setReturned();
    break;
  case TargetOpcode::G_MEMMOVE:
    RTLibcall = RTLIB::MEMMOVE;
    Args[0].Flags[0].setReturned();
    break;
  case TargetOpcode::G_MEMSET:
    RTLibcall = RTLIB::MEMSET;
    Args[0].Flags[0].setReturned();
    break;
  default:
    llvm_unreachable("unsupported opcode");
  }
  const char *Name = TLI.getLibcallName(RTLibcall);

  // Unsupported libcall on the target.
  if (!Name) {
    LLVM_DEBUG(dbgs() << ".. .. Could not find libcall name for "
                      << MIRBuilder.getTII().getName(Opc) << "\n");
    return LegalizerHelper::UnableToLegalize;
  }

  CallLowering::CallLoweringInfo Info;
  Info.CallConv = TLI.getLibcallCallingConv(RTLibcall);
  Info.Callee = MachineOperand::CreateES(Name);
  Info.OrigRet = CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0);
  Info.IsTailCall =
      MI.getOperand(MI.getNumOperands() - 1).getImm() &&
      isLibCallInTailPosition(Info.OrigRet, MI, MIRBuilder.getTII(), MRI);

  std::copy(Args.begin(), Args.end(), std::back_inserter(Info.OrigArgs));
  if (!CLI.lowerCall(MIRBuilder, Info))
    return LegalizerHelper::UnableToLegalize;

  if (Info.LoweredTailCall) {
    assert(Info.IsTailCall && "Lowered tail call when it wasn't a tail call?");

    // Check debug locations before removing the return.
    LocObserver.checkpoint(true);

    // We must have a return following the call (or debug insts) to get past
    // isLibCallInTailPosition.
    do {
      MachineInstr *Next = MI.getNextNode();
      assert(Next &&
             (Next->isCopy() || Next->isReturn() || Next->isDebugInstr()) &&
             "Expected instr following MI to be return or debug inst?");
      // We lowered a tail call, so the call is now the return from the block.
      // Delete the old return.
      Next->eraseFromParent();
    } while (MI.getNextNode());

    // We expect to lose the debug location from the return.
    LocObserver.checkpoint(false);
  }

  return LegalizerHelper::Legalized;
}

static RTLIB::Libcall getOutlineAtomicLibcall(MachineInstr &MI) {
  unsigned Opc = MI.getOpcode();
  auto &AtomicMI = cast<GMemOperation>(MI);
  auto &MMO = AtomicMI.getMMO();
  auto Ordering = MMO.getMergedOrdering();
  LLT MemType = MMO.getMemoryType();
  uint64_t MemSize = MemType.getSizeInBytes();
  if (MemType.isVector())
    return RTLIB::UNKNOWN_LIBCALL;

#define LCALLS(A, B)                                                           \
  { A##B##_RELAX, A##B##_ACQ, A##B##_REL, A##B##_ACQ_REL }
#define LCALL5(A)                                                              \
  LCALLS(A, 1), LCALLS(A, 2), LCALLS(A, 4), LCALLS(A, 8), LCALLS(A, 16)
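  // Illustrative note (not from the original source): LCALL5 builds a row of
  // five size variants (1, 2, 4, 8, 16 bytes), each with four memory-ordering
  // variants (relaxed, acquire, release, acq_rel), matching the [5][4] tables
  // below.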
  switch (Opc) {
  case TargetOpcode::G_ATOMIC_CMPXCHG:
  case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_CAS)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  case TargetOpcode::G_ATOMICRMW_XCHG: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_SWP)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  case TargetOpcode::G_ATOMICRMW_ADD:
  case TargetOpcode::G_ATOMICRMW_SUB: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_LDADD)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  case TargetOpcode::G_ATOMICRMW_AND: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_LDCLR)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  case TargetOpcode::G_ATOMICRMW_OR: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_LDSET)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  case TargetOpcode::G_ATOMICRMW_XOR: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_LDEOR)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  default:
    return RTLIB::UNKNOWN_LIBCALL;
  }
#undef LCALLS
#undef LCALL5
}

static LegalizerHelper::LegalizeResult
createAtomicLibcall(MachineIRBuilder &MIRBuilder, MachineInstr &MI) {
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  Type *RetTy;
  SmallVector<Register> RetRegs;
  SmallVector<CallLowering::ArgInfo, 3> Args;
  unsigned Opc = MI.getOpcode();
  switch (Opc) {
  case TargetOpcode::G_ATOMIC_CMPXCHG:
  case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
    Register Success;
    LLT SuccessLLT;
    auto [Ret, RetLLT, Mem, MemLLT, Cmp, CmpLLT, New, NewLLT] =
        MI.getFirst4RegLLTs();
    RetRegs.push_back(Ret);
    RetTy = IntegerType::get(Ctx, RetLLT.getSizeInBits());
    if (Opc == TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS) {
      std::tie(Ret, RetLLT, Success, SuccessLLT, Mem, MemLLT, Cmp, CmpLLT, New,
               NewLLT) = MI.getFirst5RegLLTs();
      RetRegs.push_back(Success);
      RetTy = StructType::get(
          Ctx, {RetTy, IntegerType::get(Ctx, SuccessLLT.getSizeInBits())});
    }
    Args.push_back({Cmp, IntegerType::get(Ctx, CmpLLT.getSizeInBits()), 0});
    Args.push_back({New, IntegerType::get(Ctx, NewLLT.getSizeInBits()), 0});
    Args.push_back({Mem, PointerType::get(Ctx, MemLLT.getAddressSpace()), 0});
    break;
  }
  case TargetOpcode::G_ATOMICRMW_XCHG:
  case TargetOpcode::G_ATOMICRMW_ADD:
  case TargetOpcode::G_ATOMICRMW_SUB:
  case TargetOpcode::G_ATOMICRMW_AND:
  case TargetOpcode::G_ATOMICRMW_OR:
  case TargetOpcode::G_ATOMICRMW_XOR: {
    auto [Ret, RetLLT, Mem, MemLLT, Val, ValLLT] = MI.getFirst3RegLLTs();
    RetRegs.push_back(Ret);
    RetTy = IntegerType::get(Ctx, RetLLT.getSizeInBits());
    if (Opc == TargetOpcode::G_ATOMICRMW_AND)
      Val =
          MIRBuilder.buildXor(ValLLT, MIRBuilder.buildConstant(ValLLT, -1), Val)
              .getReg(0);
    else if (Opc == TargetOpcode::G_ATOMICRMW_SUB)
      Val =
          MIRBuilder.buildSub(ValLLT, MIRBuilder.buildConstant(ValLLT, 0), Val)
              .getReg(0);
    Args.push_back({Val, IntegerType::get(Ctx, ValLLT.getSizeInBits()), 0});
    Args.push_back({Mem, PointerType::get(Ctx, MemLLT.getAddressSpace()), 0});
    break;
  }
  default:
    llvm_unreachable("unsupported opcode");
  }

  auto &CLI = *MIRBuilder.getMF().getSubtarget().getCallLowering();
  auto &TLI = *MIRBuilder.getMF().getSubtarget().getTargetLowering();
  RTLIB::Libcall RTLibcall = getOutlineAtomicLibcall(MI);
  const char *Name = TLI.getLibcallName(RTLibcall);

  // Unsupported libcall on the target.
  if (!Name) {
    LLVM_DEBUG(dbgs() << ".. .. Could not find libcall name for "
                      << MIRBuilder.getTII().getName(Opc) << "\n");
    return LegalizerHelper::UnableToLegalize;
  }

  CallLowering::CallLoweringInfo Info;
  Info.CallConv = TLI.getLibcallCallingConv(RTLibcall);
  Info.Callee = MachineOperand::CreateES(Name);
  Info.OrigRet = CallLowering::ArgInfo(RetRegs, RetTy, 0);

  std::copy(Args.begin(), Args.end(), std::back_inserter(Info.OrigArgs));
  if (!CLI.lowerCall(MIRBuilder, Info))
    return LegalizerHelper::UnableToLegalize;

  return LegalizerHelper::Legalized;
}

static RTLIB::Libcall getConvRTLibDesc(unsigned Opcode, Type *ToType,
                                       Type *FromType) {
  auto ToMVT = MVT::getVT(ToType);
  auto FromMVT = MVT::getVT(FromType);

  switch (Opcode) {
  case TargetOpcode::G_FPEXT:
    return RTLIB::getFPEXT(FromMVT, ToMVT);
  case TargetOpcode::G_FPTRUNC:
    return RTLIB::getFPROUND(FromMVT, ToMVT);
  case TargetOpcode::G_FPTOSI:
    return RTLIB::getFPTOSINT(FromMVT, ToMVT);
  case TargetOpcode::G_FPTOUI:
    return RTLIB::getFPTOUINT(FromMVT, ToMVT);
  case TargetOpcode::G_SITOFP:
    return RTLIB::getSINTTOFP(FromMVT, ToMVT);
  case TargetOpcode::G_UITOFP:
    return RTLIB::getUINTTOFP(FromMVT, ToMVT);
  }
  llvm_unreachable("Unsupported libcall function");
}

static LegalizerHelper::LegalizeResult
conversionLibcall(MachineInstr &MI, MachineIRBuilder &MIRBuilder, Type *ToType,
                  Type *FromType, LostDebugLocObserver &LocObserver) {
  RTLIB::Libcall Libcall = getConvRTLibDesc(MI.getOpcode(), ToType, FromType);
  return createLibcall(
      MIRBuilder, Libcall, {MI.getOperand(0).getReg(), ToType, 0},
      {{MI.getOperand(1).getReg(), FromType, 0}}, LocObserver, &MI);
}

static RTLIB::Libcall
getStateLibraryFunctionFor(MachineInstr &MI, const TargetLowering &TLI) {
  RTLIB::Libcall RTLibcall;
  switch (MI.getOpcode()) {
  case TargetOpcode::G_GET_FPENV:
    RTLibcall = RTLIB::FEGETENV;
    break;
  case TargetOpcode::G_SET_FPENV:
  case TargetOpcode::G_RESET_FPENV:
    RTLibcall = RTLIB::FESETENV;
    break;
  case TargetOpcode::G_GET_FPMODE:
    RTLibcall = RTLIB::FEGETMODE;
    break;
  case TargetOpcode::G_SET_FPMODE:
  case TargetOpcode::G_RESET_FPMODE:
    RTLibcall = RTLIB::FESETMODE;
    break;
  default:
    llvm_unreachable("Unexpected opcode");
  }
  return RTLibcall;
}

// Some library functions that read FP state (fegetmode, fegetenv) write the
// state into a region in memory. IR intrinsics that do the same operations
// (get_fpmode, get_fpenv) return the state as an integer value. To implement
// these intrinsics via the library functions, we need to use a temporary
// variable, for example:
//
//     %0:_(s32) = G_GET_FPMODE
//
// is transformed to:
//
//     %1:_(p0) = G_FRAME_INDEX %stack.0
//     BL &fegetmode
//     %0:_(s32) = G_LOAD %1
//
LegalizerHelper::LegalizeResult
LegalizerHelper::createGetStateLibcall(MachineIRBuilder &MIRBuilder,
                                       MachineInstr &MI,
                                       LostDebugLocObserver &LocObserver) {
  const DataLayout &DL = MIRBuilder.getDataLayout();
  auto &MF = MIRBuilder.getMF();
  auto &MRI = *MIRBuilder.getMRI();
  auto &Ctx = MF.getFunction().getContext();

  // Create a temporary where the library function will put the read state.
  Register Dst = MI.getOperand(0).getReg();
  LLT StateTy = MRI.getType(Dst);
  TypeSize StateSize = StateTy.getSizeInBytes();
  Align TempAlign = getStackTemporaryAlignment(StateTy);
  MachinePointerInfo TempPtrInfo;
  auto Temp = createStackTemporary(StateSize, TempAlign, TempPtrInfo);

  // Create a call to the library function, with the temporary as an argument.
  unsigned TempAddrSpace = DL.getAllocaAddrSpace();
  Type *StatePtrTy = PointerType::get(Ctx, TempAddrSpace);
  RTLIB::Libcall RTLibcall = getStateLibraryFunctionFor(MI, TLI);
  auto Res =
      createLibcall(MIRBuilder, RTLibcall,
                    CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0),
                    CallLowering::ArgInfo({Temp.getReg(0), StatePtrTy, 0}),
                    LocObserver, nullptr);
  if (Res != LegalizerHelper::Legalized)
    return Res;

  // Create a load from the temporary.
  MachineMemOperand *MMO = MF.getMachineMemOperand(
      TempPtrInfo, MachineMemOperand::MOLoad, StateTy, TempAlign);
  MIRBuilder.buildLoadInstr(TargetOpcode::G_LOAD, Dst, Temp, *MMO);

  return LegalizerHelper::Legalized;
}

// Similar to `createGetStateLibcall`, this function calls a library function
// using transient space on the stack. In this case the library function reads
// the content of the memory region.
LegalizerHelper::LegalizeResult
LegalizerHelper::createSetStateLibcall(MachineIRBuilder &MIRBuilder,
                                       MachineInstr &MI,
                                       LostDebugLocObserver &LocObserver) {
  const DataLayout &DL = MIRBuilder.getDataLayout();
  auto &MF = MIRBuilder.getMF();
  auto &MRI = *MIRBuilder.getMRI();
  auto &Ctx = MF.getFunction().getContext();

  // Create a temporary where the library function will get the new state.
  Register Src = MI.getOperand(0).getReg();
  LLT StateTy = MRI.getType(Src);
  TypeSize StateSize = StateTy.getSizeInBytes();
  Align TempAlign = getStackTemporaryAlignment(StateTy);
  MachinePointerInfo TempPtrInfo;
  auto Temp = createStackTemporary(StateSize, TempAlign, TempPtrInfo);

  // Put the new state into the temporary.
  MachineMemOperand *MMO = MF.getMachineMemOperand(
      TempPtrInfo, MachineMemOperand::MOStore, StateTy, TempAlign);
  MIRBuilder.buildStore(Src, Temp, *MMO);

  // Create a call to the library function, with the temporary as an argument.
  unsigned TempAddrSpace = DL.getAllocaAddrSpace();
  Type *StatePtrTy = PointerType::get(Ctx, TempAddrSpace);
  RTLIB::Libcall RTLibcall = getStateLibraryFunctionFor(MI, TLI);
  return createLibcall(MIRBuilder, RTLibcall,
                       CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0),
                       CallLowering::ArgInfo({Temp.getReg(0), StatePtrTy, 0}),
                       LocObserver, nullptr);
}

// This function is used to legalize operations that set the default
// environment state. In the C library a call like `fesetmode(FE_DFL_MODE)` is
// used for that. On most targets supported by glibc, FE_DFL_MODE is defined as
// `((const femode_t *) -1)`; that assumption is used here. If it does not hold
// for some target, the target must provide custom lowering.
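//
// Roughly (illustrative sketch, not from the original source), on a target
// with 64-bit pointers this emits:
//
//     %0:_(s64) = G_CONSTANT i64 -1
//     %1:_(p0) = G_INTTOPTR %0
//     BL &fesetmode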
LegalizerHelper::LegalizeResult
LegalizerHelper::createResetStateLibcall(MachineIRBuilder &MIRBuilder,
                                         MachineInstr &MI,
                                         LostDebugLocObserver &LocObserver) {
  const DataLayout &DL = MIRBuilder.getDataLayout();
  auto &MF = MIRBuilder.getMF();
  auto &Ctx = MF.getFunction().getContext();

  // Create an argument for the library function.
  unsigned AddrSpace = DL.getDefaultGlobalsAddressSpace();
  Type *StatePtrTy = PointerType::get(Ctx, AddrSpace);
  unsigned PtrSize = DL.getPointerSizeInBits(AddrSpace);
  LLT MemTy = LLT::pointer(AddrSpace, PtrSize);
  auto DefValue = MIRBuilder.buildConstant(LLT::scalar(PtrSize), -1LL);
  DstOp Dest(MRI.createGenericVirtualRegister(MemTy));
  MIRBuilder.buildIntToPtr(Dest, DefValue);

  RTLIB::Libcall RTLibcall = getStateLibraryFunctionFor(MI, TLI);
  return createLibcall(MIRBuilder, RTLibcall,
                       CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0),
                       CallLowering::ArgInfo({Dest.getReg(), StatePtrTy, 0}),
                       LocObserver, &MI);
}

LegalizerHelper::LegalizeResult
LegalizerHelper::libcall(MachineInstr &MI, LostDebugLocObserver &LocObserver) {
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  switch (MI.getOpcode()) {
  default:
    return UnableToLegalize;
  case TargetOpcode::G_MUL:
  case TargetOpcode::G_SDIV:
  case TargetOpcode::G_UDIV:
  case TargetOpcode::G_SREM:
  case TargetOpcode::G_UREM:
  case TargetOpcode::G_CTLZ_ZERO_UNDEF: {
    LLT LLTy = MRI.getType(MI.getOperand(0).getReg());
    unsigned Size = LLTy.getSizeInBits();
    Type *HLTy = IntegerType::get(Ctx, Size);
    auto Status = simpleLibcall(MI, MIRBuilder, Size, HLTy, LocObserver);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_FADD:
  case TargetOpcode::G_FSUB:
  case TargetOpcode::G_FMUL:
  case TargetOpcode::G_FDIV:
  case TargetOpcode::G_FMA:
  case TargetOpcode::G_FPOW:
  case TargetOpcode::G_FREM:
  case TargetOpcode::G_FCOS:
  case TargetOpcode::G_FSIN:
  case TargetOpcode::G_FTAN:
  case TargetOpcode::G_FACOS:
  case TargetOpcode::G_FASIN:
  case TargetOpcode::G_FATAN:
  case TargetOpcode::G_FCOSH:
  case TargetOpcode::G_FSINH:
  case TargetOpcode::G_FTANH:
  case TargetOpcode::G_FLOG10:
  case TargetOpcode::G_FLOG:
  case TargetOpcode::G_FLOG2:
  case TargetOpcode::G_FLDEXP:
  case TargetOpcode::G_FEXP:
  case TargetOpcode::G_FEXP2:
  case TargetOpcode::G_FEXP10:
  case TargetOpcode::G_FCEIL:
  case TargetOpcode::G_FFLOOR:
  case TargetOpcode::G_FMINNUM:
  case TargetOpcode::G_FMAXNUM:
  case TargetOpcode::G_FSQRT:
  case TargetOpcode::G_FRINT:
  case TargetOpcode::G_FNEARBYINT:
  case TargetOpcode::G_INTRINSIC_ROUNDEVEN: {
    LLT LLTy = MRI.getType(MI.getOperand(0).getReg());
    unsigned Size = LLTy.getSizeInBits();
    Type *HLTy = getFloatTypeForLLT(Ctx, LLTy);
    if (!HLTy || (Size != 32 && Size != 64 && Size != 80 && Size != 128)) {
      LLVM_DEBUG(dbgs() << "No libcall available for type " << LLTy << ".\n");
      return UnableToLegalize;
    }
    auto Status = simpleLibcall(MI, MIRBuilder, Size, HLTy, LocObserver);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_INTRINSIC_LRINT:
  case TargetOpcode::G_INTRINSIC_LLRINT: {
    LLT LLTy = MRI.getType(MI.getOperand(1).getReg());
    unsigned Size = LLTy.getSizeInBits();
    Type *HLTy = getFloatTypeForLLT(Ctx, LLTy);
    Type *ITy = IntegerType::get(
        Ctx, MRI.getType(MI.getOperand(0).getReg()).getSizeInBits());
    if (!HLTy || (Size != 32 && Size != 64 && Size != 80 && Size != 128)) {
      LLVM_DEBUG(dbgs() << "No libcall available for type " << LLTy << ".\n");
      return UnableToLegalize;
    }
    auto Libcall = getRTLibDesc(MI.getOpcode(), Size);
    LegalizeResult Status =
        createLibcall(MIRBuilder, Libcall, {MI.getOperand(0).getReg(), ITy, 0},
                      {{MI.getOperand(1).getReg(), HLTy, 0}}, LocObserver, &MI);
    if (Status != Legalized)
      return Status;
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_FPOWI: {
    LLT LLTy = MRI.getType(MI.getOperand(0).getReg());
    unsigned Size = LLTy.getSizeInBits();
    Type *HLTy = getFloatTypeForLLT(Ctx, LLTy);
    Type *ITy = IntegerType::get(
        Ctx, MRI.getType(MI.getOperand(2).getReg()).getSizeInBits());
    if (!HLTy || (Size != 32 && Size != 64 && Size != 80 && Size != 128)) {
      LLVM_DEBUG(dbgs() << "No libcall available for type " << LLTy << ".\n");
      return UnableToLegalize;
    }
    auto Libcall = getRTLibDesc(MI.getOpcode(), Size);
    std::initializer_list<CallLowering::ArgInfo> Args = {
        {MI.getOperand(1).getReg(), HLTy, 0},
        {MI.getOperand(2).getReg(), ITy, 1}};
    LegalizeResult Status =
        createLibcall(MIRBuilder, Libcall, {MI.getOperand(0).getReg(), HLTy, 0},
                      Args, LocObserver, &MI);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_FPEXT:
  case TargetOpcode::G_FPTRUNC: {
    Type *FromTy =
        getFloatTypeForLLT(Ctx, MRI.getType(MI.getOperand(1).getReg()));
    Type *ToTy =
        getFloatTypeForLLT(Ctx, MRI.getType(MI.getOperand(0).getReg()));
    if (!FromTy || !ToTy)
      return UnableToLegalize;
    LegalizeResult Status =
        conversionLibcall(MI, MIRBuilder, ToTy, FromTy, LocObserver);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_FPTOSI:
  case TargetOpcode::G_FPTOUI: {
    // FIXME: Support other types
    Type *FromTy =
        getFloatTypeForLLT(Ctx, MRI.getType(MI.getOperand(1).getReg()));
    unsigned ToSize = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
    if ((ToSize != 32 && ToSize != 64 && ToSize != 128) || !FromTy)
      return UnableToLegalize;
    LegalizeResult Status = conversionLibcall(
        MI, MIRBuilder, Type::getIntNTy(Ctx, ToSize), FromTy, LocObserver);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_SITOFP:
  case TargetOpcode::G_UITOFP: {
    unsigned FromSize = MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
    Type *ToTy =
        getFloatTypeForLLT(Ctx, MRI.getType(MI.getOperand(0).getReg()));
    if ((FromSize != 32 && FromSize != 64 && FromSize != 128) || !ToTy)
      return UnableToLegalize;
    LegalizeResult Status = conversionLibcall(
        MI, MIRBuilder, ToTy, Type::getIntNTy(Ctx, FromSize), LocObserver);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_ATOMICRMW_XCHG:
  case TargetOpcode::G_ATOMICRMW_ADD:
  case TargetOpcode::G_ATOMICRMW_SUB:
  case TargetOpcode::G_ATOMICRMW_AND:
  case TargetOpcode::G_ATOMICRMW_OR:
  case TargetOpcode::G_ATOMICRMW_XOR:
  case TargetOpcode::G_ATOMIC_CMPXCHG:
  case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
    auto Status = createAtomicLibcall(MIRBuilder, MI);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_BZERO:
  case TargetOpcode::G_MEMCPY:
  case TargetOpcode::G_MEMMOVE:
  case TargetOpcode::G_MEMSET: {
    LegalizeResult Result =
        createMemLibcall(MIRBuilder, *MIRBuilder.getMRI(), MI, LocObserver);
    if (Result != Legalized)
      return Result;
    MI.eraseFromParent();
    return Result;
  }
  case TargetOpcode::G_GET_FPENV:
  case TargetOpcode::G_GET_FPMODE: {
    LegalizeResult Result = createGetStateLibcall(MIRBuilder, MI, LocObserver);
    if (Result != Legalized)
      return Result;
    break;
  }
  case TargetOpcode::G_SET_FPENV:
  case TargetOpcode::G_SET_FPMODE: {
    LegalizeResult Result = createSetStateLibcall(MIRBuilder, MI, LocObserver);
    if (Result != Legalized)
      return Result;
    break;
  }
  case TargetOpcode::G_RESET_FPENV:
  case TargetOpcode::G_RESET_FPMODE: {
    LegalizeResult Result =
        createResetStateLibcall(MIRBuilder, MI, LocObserver);
    if (Result != Legalized)
      return Result;
    break;
  }
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI,
                                                              unsigned TypeIdx,
                                                              LLT NarrowTy) {
  uint64_t SizeOp0 = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
  uint64_t NarrowSize = NarrowTy.getSizeInBits();

  switch (MI.getOpcode()) {
  default:
    return UnableToLegalize;
  case TargetOpcode::G_IMPLICIT_DEF: {
    Register DstReg = MI.getOperand(0).getReg();
    LLT DstTy = MRI.getType(DstReg);

    // If SizeOp0 is not an exact multiple of NarrowSize, emit
    // G_ANYEXT(G_IMPLICIT_DEF). Cast result to vector if needed.
    // FIXME: Although this would also be legal for the general case, it causes
    //  a lot of regressions in the emitted code (superfluous COPYs, artifact
    //  combines not being hit). This seems to be a problem related to the
    //  artifact combiner.
    if (SizeOp0 % NarrowSize != 0) {
      LLT ImplicitTy = NarrowTy;
      if (DstTy.isVector())
        ImplicitTy = LLT::vector(DstTy.getElementCount(), ImplicitTy);

      Register ImplicitReg = MIRBuilder.buildUndef(ImplicitTy).getReg(0);
      MIRBuilder.buildAnyExt(DstReg, ImplicitReg);

      MI.eraseFromParent();
      return Legalized;
    }

    int NumParts = SizeOp0 / NarrowSize;

    SmallVector<Register, 2> DstRegs;
    for (int i = 0; i < NumParts; ++i)
      DstRegs.push_back(MIRBuilder.buildUndef(NarrowTy).getReg(0));

    if (DstTy.isVector())
      MIRBuilder.buildBuildVector(DstReg, DstRegs);
    else
      MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
    MI.eraseFromParent();
    return Legalized;
  }
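  // Illustrative example (not from the original source): narrowing
  //   %0:_(s70) = G_CONSTANT i70 C
  // with NarrowTy = s32 emits two s32 constants holding bits [31:0] and
  // [63:32] of C, plus an s6 leftover constant holding bits [69:64], which
  // insertParts then recombines into the s70 result.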
1264   case TargetOpcode::G_CONSTANT: {
1265     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
1266     const APInt &Val = MI.getOperand(1).getCImm()->getValue();
1267     unsigned TotalSize = Ty.getSizeInBits();
1268     unsigned NarrowSize = NarrowTy.getSizeInBits();
1269     int NumParts = TotalSize / NarrowSize;
1270 
1271     SmallVector<Register, 4> PartRegs;
1272     for (int I = 0; I != NumParts; ++I) {
1273       unsigned Offset = I * NarrowSize;
1274       auto K = MIRBuilder.buildConstant(NarrowTy,
1275                                         Val.lshr(Offset).trunc(NarrowSize));
1276       PartRegs.push_back(K.getReg(0));
1277     }
1278 
1279     LLT LeftoverTy;
1280     unsigned LeftoverBits = TotalSize - NumParts * NarrowSize;
1281     SmallVector<Register, 1> LeftoverRegs;
1282     if (LeftoverBits != 0) {
1283       LeftoverTy = LLT::scalar(LeftoverBits);
1284       auto K = MIRBuilder.buildConstant(
1285         LeftoverTy,
1286         Val.lshr(NumParts * NarrowSize).trunc(LeftoverBits));
1287       LeftoverRegs.push_back(K.getReg(0));
1288     }
1289 
1290     insertParts(MI.getOperand(0).getReg(),
1291                 Ty, NarrowTy, PartRegs, LeftoverTy, LeftoverRegs);
1292 
1293     MI.eraseFromParent();
1294     return Legalized;
1295   }
1296   case TargetOpcode::G_SEXT:
1297   case TargetOpcode::G_ZEXT:
1298   case TargetOpcode::G_ANYEXT:
1299     return narrowScalarExt(MI, TypeIdx, NarrowTy);
1300   case TargetOpcode::G_TRUNC: {
1301     if (TypeIdx != 1)
1302       return UnableToLegalize;
1303 
1304     uint64_t SizeOp1 = MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
1305     if (NarrowTy.getSizeInBits() * 2 != SizeOp1) {
1306       LLVM_DEBUG(dbgs() << "Can't narrow trunc to type " << NarrowTy << "\n");
1307       return UnableToLegalize;
1308     }
1309 
1310     auto Unmerge = MIRBuilder.buildUnmerge(NarrowTy, MI.getOperand(1));
1311     MIRBuilder.buildCopy(MI.getOperand(0), Unmerge.getReg(0));
1312     MI.eraseFromParent();
1313     return Legalized;
1314   }
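  // E.g. (illustrative) for %0:_(s32) = G_TRUNC %1:_(s64) with NarrowTy = s32:
  //   %2:_(s32), %3:_(s32) = G_UNMERGE_VALUES %1:_(s64)
  //   %0:_(s32) = COPY %2:_(s32)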
1315   case TargetOpcode::G_CONSTANT_FOLD_BARRIER:
1316   case TargetOpcode::G_FREEZE: {
1317     if (TypeIdx != 0)
1318       return UnableToLegalize;
1319 
1320     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
1321     // Should widen scalar first
1322     if (Ty.getSizeInBits() % NarrowTy.getSizeInBits() != 0)
1323       return UnableToLegalize;
1324 
1325     auto Unmerge = MIRBuilder.buildUnmerge(NarrowTy, MI.getOperand(1).getReg());
1326     SmallVector<Register, 8> Parts;
1327     for (unsigned i = 0; i < Unmerge->getNumDefs(); ++i) {
1328       Parts.push_back(
1329           MIRBuilder.buildInstr(MI.getOpcode(), {NarrowTy}, {Unmerge.getReg(i)})
1330               .getReg(0));
1331     }
1332 
1333     MIRBuilder.buildMergeLikeInstr(MI.getOperand(0).getReg(), Parts);
1334     MI.eraseFromParent();
1335     return Legalized;
1336   }
1337   case TargetOpcode::G_ADD:
1338   case TargetOpcode::G_SUB:
1339   case TargetOpcode::G_SADDO:
1340   case TargetOpcode::G_SSUBO:
1341   case TargetOpcode::G_SADDE:
1342   case TargetOpcode::G_SSUBE:
1343   case TargetOpcode::G_UADDO:
1344   case TargetOpcode::G_USUBO:
1345   case TargetOpcode::G_UADDE:
1346   case TargetOpcode::G_USUBE:
1347     return narrowScalarAddSub(MI, TypeIdx, NarrowTy);
1348   case TargetOpcode::G_MUL:
1349   case TargetOpcode::G_UMULH:
1350     return narrowScalarMul(MI, NarrowTy);
1351   case TargetOpcode::G_EXTRACT:
1352     return narrowScalarExtract(MI, TypeIdx, NarrowTy);
1353   case TargetOpcode::G_INSERT:
1354     return narrowScalarInsert(MI, TypeIdx, NarrowTy);
1355   case TargetOpcode::G_LOAD: {
1356     auto &LoadMI = cast<GLoad>(MI);
1357     Register DstReg = LoadMI.getDstReg();
1358     LLT DstTy = MRI.getType(DstReg);
1359     if (DstTy.isVector())
1360       return UnableToLegalize;
1361 
1362     if (8 * LoadMI.getMemSize().getValue() != DstTy.getSizeInBits()) {
1363       Register TmpReg = MRI.createGenericVirtualRegister(NarrowTy);
1364       MIRBuilder.buildLoad(TmpReg, LoadMI.getPointerReg(), LoadMI.getMMO());
1365       MIRBuilder.buildAnyExt(DstReg, TmpReg);
1366       LoadMI.eraseFromParent();
1367       return Legalized;
1368     }
1369 
1370     return reduceLoadStoreWidth(LoadMI, TypeIdx, NarrowTy);
1371   }
1372   case TargetOpcode::G_ZEXTLOAD:
1373   case TargetOpcode::G_SEXTLOAD: {
1374     auto &LoadMI = cast<GExtLoad>(MI);
1375     Register DstReg = LoadMI.getDstReg();
1376     Register PtrReg = LoadMI.getPointerReg();
1377 
1378     Register TmpReg = MRI.createGenericVirtualRegister(NarrowTy);
1379     auto &MMO = LoadMI.getMMO();
1380     unsigned MemSize = MMO.getSizeInBits().getValue();
1381 
1382     if (MemSize == NarrowSize) {
1383       MIRBuilder.buildLoad(TmpReg, PtrReg, MMO);
1384     } else if (MemSize < NarrowSize) {
1385       MIRBuilder.buildLoadInstr(LoadMI.getOpcode(), TmpReg, PtrReg, MMO);
1386     } else if (MemSize > NarrowSize) {
1387       // FIXME: Need to split the load.
1388       return UnableToLegalize;
1389     }
1390 
1391     if (isa<GZExtLoad>(LoadMI))
1392       MIRBuilder.buildZExt(DstReg, TmpReg);
1393     else
1394       MIRBuilder.buildSExt(DstReg, TmpReg);
1395 
1396     LoadMI.eraseFromParent();
1397     return Legalized;
1398   }
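  // E.g. (illustrative) a G_ZEXTLOAD of 2 bytes into an s64 result, narrowed
  // with NarrowTy = s32: the 2 bytes are zext-loaded into a temporary s32,
  // and a G_ZEXT widens that temporary to the s64 destination.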
1399   case TargetOpcode::G_STORE: {
1400     auto &StoreMI = cast<GStore>(MI);
1401 
1402     Register SrcReg = StoreMI.getValueReg();
1403     LLT SrcTy = MRI.getType(SrcReg);
1404     if (SrcTy.isVector())
1405       return UnableToLegalize;
1406 
1407     int NumParts = SizeOp0 / NarrowSize;
1408     unsigned HandledSize = NumParts * NarrowTy.getSizeInBits();
1409     unsigned LeftoverBits = SrcTy.getSizeInBits() - HandledSize;
1410     if (SrcTy.isVector() && LeftoverBits != 0)
1411       return UnableToLegalize;
1412 
1413     if (8 * StoreMI.getMemSize().getValue() != SrcTy.getSizeInBits()) {
1414       Register TmpReg = MRI.createGenericVirtualRegister(NarrowTy);
1415       MIRBuilder.buildTrunc(TmpReg, SrcReg);
1416       MIRBuilder.buildStore(TmpReg, StoreMI.getPointerReg(), StoreMI.getMMO());
1417       StoreMI.eraseFromParent();
1418       return Legalized;
1419     }
1420 
1421     return reduceLoadStoreWidth(StoreMI, 0, NarrowTy);
1422   }
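  // E.g. (illustrative) storing %0:_(s64) with NarrowTy = s32: a 4-byte MMO
  // takes the trunc-and-store path above, while a full 8-byte store is left
  // to reduceLoadStoreWidth to split into two s32 stores.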
1423   case TargetOpcode::G_SELECT:
1424     return narrowScalarSelect(MI, TypeIdx, NarrowTy);
1425   case TargetOpcode::G_AND:
1426   case TargetOpcode::G_OR:
1427   case TargetOpcode::G_XOR: {
1428     // Legalize bitwise operation:
1429     // A = BinOp<Ty> B, C
1430     // into:
1431     // B1, ..., BN = G_UNMERGE_VALUES B
1432     // C1, ..., CN = G_UNMERGE_VALUES C
1433     // A1 = BinOp<Ty/N> B1, C1
1434     // ...
1435     // AN = BinOp<Ty/N> BN, CN
1436     // A = G_MERGE_VALUES A1, ..., AN
1437     return narrowScalarBasic(MI, TypeIdx, NarrowTy);
1438   }
1439   case TargetOpcode::G_SHL:
1440   case TargetOpcode::G_LSHR:
1441   case TargetOpcode::G_ASHR:
1442     return narrowScalarShift(MI, TypeIdx, NarrowTy);
1443   case TargetOpcode::G_CTLZ:
1444   case TargetOpcode::G_CTLZ_ZERO_UNDEF:
1445   case TargetOpcode::G_CTTZ:
1446   case TargetOpcode::G_CTTZ_ZERO_UNDEF:
1447   case TargetOpcode::G_CTPOP:
1448     if (TypeIdx == 1)
1449       switch (MI.getOpcode()) {
1450       case TargetOpcode::G_CTLZ:
1451       case TargetOpcode::G_CTLZ_ZERO_UNDEF:
1452         return narrowScalarCTLZ(MI, TypeIdx, NarrowTy);
1453       case TargetOpcode::G_CTTZ:
1454       case TargetOpcode::G_CTTZ_ZERO_UNDEF:
1455         return narrowScalarCTTZ(MI, TypeIdx, NarrowTy);
1456       case TargetOpcode::G_CTPOP:
1457         return narrowScalarCTPOP(MI, TypeIdx, NarrowTy);
1458       default:
1459         return UnableToLegalize;
1460       }
1461 
1462     Observer.changingInstr(MI);
1463     narrowScalarDst(MI, NarrowTy, 0, TargetOpcode::G_ZEXT);
1464     Observer.changedInstr(MI);
1465     return Legalized;
1466   case TargetOpcode::G_INTTOPTR:
1467     if (TypeIdx != 1)
1468       return UnableToLegalize;
1469 
1470     Observer.changingInstr(MI);
1471     narrowScalarSrc(MI, NarrowTy, 1);
1472     Observer.changedInstr(MI);
1473     return Legalized;
1474   case TargetOpcode::G_PTRTOINT:
1475     if (TypeIdx != 0)
1476       return UnableToLegalize;
1477 
1478     Observer.changingInstr(MI);
1479     narrowScalarDst(MI, NarrowTy, 0, TargetOpcode::G_ZEXT);
1480     Observer.changedInstr(MI);
1481     return Legalized;
1482   case TargetOpcode::G_PHI: {
1483     // FIXME: add support for when SizeOp0 isn't an exact multiple of
1484     // NarrowSize.
1485     if (SizeOp0 % NarrowSize != 0)
1486       return UnableToLegalize;
1487 
1488     unsigned NumParts = SizeOp0 / NarrowSize;
1489     SmallVector<Register, 2> DstRegs(NumParts);
1490     SmallVector<SmallVector<Register, 2>, 2> SrcRegs(MI.getNumOperands() / 2);
1491     Observer.changingInstr(MI);
1492     for (unsigned i = 1; i < MI.getNumOperands(); i += 2) {
1493       MachineBasicBlock &OpMBB = *MI.getOperand(i + 1).getMBB();
1494       MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
1495       extractParts(MI.getOperand(i).getReg(), NarrowTy, NumParts,
1496                    SrcRegs[i / 2], MIRBuilder, MRI);
1497     }
1498     MachineBasicBlock &MBB = *MI.getParent();
1499     MIRBuilder.setInsertPt(MBB, MI);
1500     for (unsigned i = 0; i < NumParts; ++i) {
1501       DstRegs[i] = MRI.createGenericVirtualRegister(NarrowTy);
1502       MachineInstrBuilder MIB =
1503           MIRBuilder.buildInstr(TargetOpcode::G_PHI).addDef(DstRegs[i]);
1504       for (unsigned j = 1; j < MI.getNumOperands(); j += 2)
1505         MIB.addUse(SrcRegs[j / 2][i]).add(MI.getOperand(j + 1));
1506     }
1507     MIRBuilder.setInsertPt(MBB, MBB.getFirstNonPHI());
1508     MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), DstRegs);
1509     Observer.changedInstr(MI);
1510     MI.eraseFromParent();
1511     return Legalized;
1512   }
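  // Schematic result (illustrative) for a two-predecessor s64 G_PHI at
  // NarrowTy = s32, with each input unmerged in its predecessor block:
  //   %lo:_(s32) = G_PHI %a.lo(s32), %bb.1, %b.lo(s32), %bb.2
  //   %hi:_(s32) = G_PHI %a.hi(s32), %bb.1, %b.hi(s32), %bb.2
  //   %0:_(s64) = G_MERGE_VALUES %lo(s32), %hi(s32)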
1513   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
1514   case TargetOpcode::G_INSERT_VECTOR_ELT: {
1515     if (TypeIdx != 2)
1516       return UnableToLegalize;
1517 
1518     int OpIdx = MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3;
1519     Observer.changingInstr(MI);
1520     narrowScalarSrc(MI, NarrowTy, OpIdx);
1521     Observer.changedInstr(MI);
1522     return Legalized;
1523   }
1524   case TargetOpcode::G_ICMP: {
1525     Register LHS = MI.getOperand(2).getReg();
1526     LLT SrcTy = MRI.getType(LHS);
1527     uint64_t SrcSize = SrcTy.getSizeInBits();
1528     CmpInst::Predicate Pred =
1529         static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
1530 
1531     // TODO: Handle the non-equality case for weird sizes.
1532     if (NarrowSize * 2 != SrcSize && !ICmpInst::isEquality(Pred))
1533       return UnableToLegalize;
1534 
1535     LLT LeftoverTy; // Example: s88 -> s64 (NarrowTy) + s24 (leftover)
1536     SmallVector<Register, 4> LHSPartRegs, LHSLeftoverRegs;
1537     if (!extractParts(LHS, SrcTy, NarrowTy, LeftoverTy, LHSPartRegs,
1538                       LHSLeftoverRegs, MIRBuilder, MRI))
1539       return UnableToLegalize;
1540 
1541     LLT Unused; // Matches LeftoverTy; G_ICMP LHS and RHS are the same type.
1542     SmallVector<Register, 4> RHSPartRegs, RHSLeftoverRegs;
1543     if (!extractParts(MI.getOperand(3).getReg(), SrcTy, NarrowTy, Unused,
1544                       RHSPartRegs, RHSLeftoverRegs, MIRBuilder, MRI))
1545       return UnableToLegalize;
1546 
1547     // We now have the LHS and RHS of the compare split into narrow-type
1548     // registers, plus potentially some leftover type.
1549     Register Dst = MI.getOperand(0).getReg();
1550     LLT ResTy = MRI.getType(Dst);
1551     if (ICmpInst::isEquality(Pred)) {
1552       // For each part on the LHS and RHS, keep track of the result of XOR-ing
1553       // them together. For each equal part, the result should be all 0s. For
1554       // each non-equal part, we'll get at least one 1.
1555       auto Zero = MIRBuilder.buildConstant(NarrowTy, 0);
1556       SmallVector<Register, 4> Xors;
1557       for (auto LHSAndRHS : zip(LHSPartRegs, RHSPartRegs)) {
1558         auto LHS = std::get<0>(LHSAndRHS);
1559         auto RHS = std::get<1>(LHSAndRHS);
1560         auto Xor = MIRBuilder.buildXor(NarrowTy, LHS, RHS).getReg(0);
1561         Xors.push_back(Xor);
1562       }
1563 
1564       // Build a G_XOR for each leftover register. Each G_XOR must be widened
1565       // to the desired narrow type so that we can OR them together later.
1566       SmallVector<Register, 4> WidenedXors;
1567       for (auto LHSAndRHS : zip(LHSLeftoverRegs, RHSLeftoverRegs)) {
1568         auto LHS = std::get<0>(LHSAndRHS);
1569         auto RHS = std::get<1>(LHSAndRHS);
1570         auto Xor = MIRBuilder.buildXor(LeftoverTy, LHS, RHS).getReg(0);
1571         LLT GCDTy = extractGCDType(WidenedXors, NarrowTy, LeftoverTy, Xor);
1572         buildLCMMergePieces(LeftoverTy, NarrowTy, GCDTy, WidenedXors,
1573                             /* PadStrategy = */ TargetOpcode::G_ZEXT);
1574         Xors.insert(Xors.end(), WidenedXors.begin(), WidenedXors.end());
1575       }
1576 
1577       // Now, for each pair of parts we broke up, we know whether they are equal
1578       // based on the G_XOR. We can OR these all together and compare against
1579       // 0 to get the result.
1580       assert(Xors.size() >= 2 && "Should have gotten at least two Xors?");
1581       auto Or = MIRBuilder.buildOr(NarrowTy, Xors[0], Xors[1]);
1582       for (unsigned I = 2, E = Xors.size(); I < E; ++I)
1583         Or = MIRBuilder.buildOr(NarrowTy, Or, Xors[I]);
1584       MIRBuilder.buildICmp(Pred, Dst, Or, Zero);
1585     } else {
1586       // TODO: Handle non-power-of-two types.
1587       assert(LHSPartRegs.size() == 2 && "Expected exactly 2 LHS part regs?");
1588       assert(RHSPartRegs.size() == 2 && "Expected exactly 2 RHS part regs?");
1589       Register LHSL = LHSPartRegs[0];
1590       Register LHSH = LHSPartRegs[1];
1591       Register RHSL = RHSPartRegs[0];
1592       Register RHSH = RHSPartRegs[1];
1593       MachineInstrBuilder CmpH = MIRBuilder.buildICmp(Pred, ResTy, LHSH, RHSH);
1594       MachineInstrBuilder CmpHEQ =
1595           MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, ResTy, LHSH, RHSH);
1596       MachineInstrBuilder CmpLU = MIRBuilder.buildICmp(
1597           ICmpInst::getUnsignedPredicate(Pred), ResTy, LHSL, RHSL);
1598       MIRBuilder.buildSelect(Dst, CmpHEQ, CmpLU, CmpH);
1599     }
1600     MI.eraseFromParent();
1601     return Legalized;
1602   }
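  // Schematic expansions (illustrative) for an s64 compare split at s32:
  //   eq:  icmp eq (or (xor LHS.lo, RHS.lo), (xor LHS.hi, RHS.hi)), 0
  //   slt: select (icmp eq LHS.hi, RHS.hi),
  //               (icmp ult LHS.lo, RHS.lo),
  //               (icmp slt LHS.hi, RHS.hi)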
1603   case TargetOpcode::G_FCMP:
1604     if (TypeIdx != 0)
1605       return UnableToLegalize;
1606 
1607     Observer.changingInstr(MI);
1608     narrowScalarDst(MI, NarrowTy, 0, TargetOpcode::G_ZEXT);
1609     Observer.changedInstr(MI);
1610     return Legalized;
1611 
1612   case TargetOpcode::G_SEXT_INREG: {
1613     if (TypeIdx != 0)
1614       return UnableToLegalize;
1615 
1616     int64_t SizeInBits = MI.getOperand(2).getImm();
1617 
1618     // So long as the new type has more bits than the bits we're extending, we
1619     // don't need to break it apart.
1620     if (NarrowTy.getScalarSizeInBits() > SizeInBits) {
1621       Observer.changingInstr(MI);
1622       // We don't lose any non-extension bits by truncating the src and
1623       // sign-extending the dst.
1624       MachineOperand &MO1 = MI.getOperand(1);
1625       auto TruncMIB = MIRBuilder.buildTrunc(NarrowTy, MO1);
1626       MO1.setReg(TruncMIB.getReg(0));
1627 
1628       MachineOperand &MO2 = MI.getOperand(0);
1629       Register DstExt = MRI.createGenericVirtualRegister(NarrowTy);
1630       MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
1631       MIRBuilder.buildSExt(MO2, DstExt);
1632       MO2.setReg(DstExt);
1633       Observer.changedInstr(MI);
1634       return Legalized;
1635     }
1636 
1637     // Break it apart. Components below the extension point are unmodified. The
1638     // component containing the extension point becomes a narrower SEXT_INREG.
1639     // Components above it are ashr'd from the component containing the
1640     // extension point.
1641     if (SizeOp0 % NarrowSize != 0)
1642       return UnableToLegalize;
1643     int NumParts = SizeOp0 / NarrowSize;
1644 
1645     // List the registers where the destination will be scattered.
1646     SmallVector<Register, 2> DstRegs;
1647     // List the registers where the source will be split.
1648     SmallVector<Register, 2> SrcRegs;
1649 
1650     // Create all the temporary registers.
1651     for (int i = 0; i < NumParts; ++i) {
1652       Register SrcReg = MRI.createGenericVirtualRegister(NarrowTy);
1653 
1654       SrcRegs.push_back(SrcReg);
1655     }
1656 
1657     // Explode the big arguments into smaller chunks.
1658     MIRBuilder.buildUnmerge(SrcRegs, MI.getOperand(1));
1659 
1660     Register AshrCstReg =
1661         MIRBuilder.buildConstant(NarrowTy, NarrowTy.getScalarSizeInBits() - 1)
1662             .getReg(0);
1663     Register FullExtensionReg;
1664     Register PartialExtensionReg;
1665 
1666     // Do the operation on each small part.
1667     for (int i = 0; i < NumParts; ++i) {
1668       if ((i + 1) * NarrowTy.getScalarSizeInBits() <= SizeInBits) {
1669         DstRegs.push_back(SrcRegs[i]);
1670         PartialExtensionReg = DstRegs.back();
1671       } else if (i * NarrowTy.getScalarSizeInBits() >= SizeInBits) {
1672         assert(PartialExtensionReg &&
1673                "Expected to visit partial extension before full");
1674         if (FullExtensionReg) {
1675           DstRegs.push_back(FullExtensionReg);
1676           continue;
1677         }
1678         DstRegs.push_back(
1679             MIRBuilder.buildAShr(NarrowTy, PartialExtensionReg, AshrCstReg)
1680                 .getReg(0));
1681         FullExtensionReg = DstRegs.back();
1682       } else {
1683         DstRegs.push_back(
1684             MIRBuilder
1685                 .buildInstr(
1686                     TargetOpcode::G_SEXT_INREG, {NarrowTy},
1687                     {SrcRegs[i], SizeInBits % NarrowTy.getScalarSizeInBits()})
1688                 .getReg(0));
1689         PartialExtensionReg = DstRegs.back();
1690       }
1691     }
1692 
1693     // Gather the destination registers into the final destination.
1694     Register DstReg = MI.getOperand(0).getReg();
1695     MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
1696     MI.eraseFromParent();
1697     return Legalized;
1698   }
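  // E.g. (illustrative) %0:_(s64) = G_SEXT_INREG %1, 40 at NarrowTy = s32:
  // part 0 (bits [0,32)) sits below the extension point and passes through
  // unchanged, and part 1 becomes G_SEXT_INREG %hi, 8; any parts above the
  // extension point would be a G_ASHR of that partial part by 31.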
1699   case TargetOpcode::G_BSWAP:
1700   case TargetOpcode::G_BITREVERSE: {
1701     if (SizeOp0 % NarrowSize != 0)
1702       return UnableToLegalize;
1703 
1704     Observer.changingInstr(MI);
1705     SmallVector<Register, 2> SrcRegs, DstRegs;
1706     unsigned NumParts = SizeOp0 / NarrowSize;
1707     extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs,
1708                  MIRBuilder, MRI);
1709 
1710     for (unsigned i = 0; i < NumParts; ++i) {
1711       auto DstPart = MIRBuilder.buildInstr(MI.getOpcode(), {NarrowTy},
1712                                            {SrcRegs[NumParts - 1 - i]});
1713       DstRegs.push_back(DstPart.getReg(0));
1714     }
1715 
1716     MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), DstRegs);
1717 
1718     Observer.changedInstr(MI);
1719     MI.eraseFromParent();
1720     return Legalized;
1721   }
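  // E.g. (illustrative, schematic) an s64 G_BSWAP at NarrowTy = s32
  // byte-swaps each half and swaps the halves when re-merging:
  //   %lo:_(s32), %hi:_(s32) = G_UNMERGE_VALUES %src(s64)
  //   %0:_(s64) = G_MERGE_VALUES (G_BSWAP %hi), (G_BSWAP %lo)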
1722   case TargetOpcode::G_PTR_ADD:
1723   case TargetOpcode::G_PTRMASK: {
1724     if (TypeIdx != 1)
1725       return UnableToLegalize;
1726     Observer.changingInstr(MI);
1727     narrowScalarSrc(MI, NarrowTy, 2);
1728     Observer.changedInstr(MI);
1729     return Legalized;
1730   }
1731   case TargetOpcode::G_FPTOUI:
1732   case TargetOpcode::G_FPTOSI:
1733     return narrowScalarFPTOI(MI, TypeIdx, NarrowTy);
1734   case TargetOpcode::G_FPEXT:
1735     if (TypeIdx != 0)
1736       return UnableToLegalize;
1737     Observer.changingInstr(MI);
1738     narrowScalarDst(MI, NarrowTy, 0, TargetOpcode::G_FPEXT);
1739     Observer.changedInstr(MI);
1740     return Legalized;
1741   case TargetOpcode::G_FLDEXP:
1742   case TargetOpcode::G_STRICT_FLDEXP:
1743     return narrowScalarFLDEXP(MI, TypeIdx, NarrowTy);
1744   case TargetOpcode::G_VSCALE: {
1745     Register Dst = MI.getOperand(0).getReg();
1746     LLT Ty = MRI.getType(Dst);
1747 
1748     // Assume VSCALE(1) fits into a legal integer
1749     const APInt One(NarrowTy.getSizeInBits(), 1);
1750     auto VScaleBase = MIRBuilder.buildVScale(NarrowTy, One);
1751     auto ZExt = MIRBuilder.buildZExt(Ty, VScaleBase);
1752     auto C = MIRBuilder.buildConstant(Ty, *MI.getOperand(1).getCImm());
1753     MIRBuilder.buildMul(Dst, ZExt, C);
1754 
1755     MI.eraseFromParent();
1756     return Legalized;
1757   }
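  // E.g. (illustrative, schematic) an s64 G_VSCALE with multiplier C at
  // NarrowTy = s32:
  //   %v:_(s32) = G_VSCALE i32 1
  //   %z:_(s64) = G_ZEXT %v(s32)
  //   %0:_(s64) = G_MUL %z, (G_CONSTANT i64 C)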
1758   }
1759 }
1760 
1761 Register LegalizerHelper::coerceToScalar(Register Val) {
1762   LLT Ty = MRI.getType(Val);
1763   if (Ty.isScalar())
1764     return Val;
1765 
1766   const DataLayout &DL = MIRBuilder.getDataLayout();
1767   LLT NewTy = LLT::scalar(Ty.getSizeInBits());
1768   if (Ty.isPointer()) {
1769     if (DL.isNonIntegralAddressSpace(Ty.getAddressSpace()))
1770       return Register();
1771     return MIRBuilder.buildPtrToInt(NewTy, Val).getReg(0);
1772   }
1773 
1774   Register NewVal = Val;
1775 
1776   assert(Ty.isVector());
1777   if (Ty.isPointerVector())
1778     NewVal = MIRBuilder.buildPtrToInt(NewTy, NewVal).getReg(0);
1779   return MIRBuilder.buildBitcast(NewTy, NewVal).getReg(0);
1780 }
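// E.g. (illustrative, 64-bit pointers): coerceToScalar returns p0 values as
// s64 via G_PTRTOINT and <2 x s32> values as s64 via G_BITCAST; pointers in
// a non-integral address space yield an invalid Register.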
1781 
1782 void LegalizerHelper::widenScalarSrc(MachineInstr &MI, LLT WideTy,
1783                                      unsigned OpIdx, unsigned ExtOpcode) {
1784   MachineOperand &MO = MI.getOperand(OpIdx);
1785   auto ExtB = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {MO});
1786   MO.setReg(ExtB.getReg(0));
1787 }
1788 
1789 void LegalizerHelper::narrowScalarSrc(MachineInstr &MI, LLT NarrowTy,
1790                                       unsigned OpIdx) {
1791   MachineOperand &MO = MI.getOperand(OpIdx);
1792   auto ExtB = MIRBuilder.buildTrunc(NarrowTy, MO);
1793   MO.setReg(ExtB.getReg(0));
1794 }
1795 
1796 void LegalizerHelper::widenScalarDst(MachineInstr &MI, LLT WideTy,
1797                                      unsigned OpIdx, unsigned TruncOpcode) {
1798   MachineOperand &MO = MI.getOperand(OpIdx);
1799   Register DstExt = MRI.createGenericVirtualRegister(WideTy);
1800   MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
1801   MIRBuilder.buildInstr(TruncOpcode, {MO}, {DstExt});
1802   MO.setReg(DstExt);
1803 }
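// E.g. (illustrative) widenScalarDst on a %0:_(s16) def with WideTy = s32:
// the instruction is rewritten to define a fresh s32 register, and a
// %0:_(s16) = G_TRUNC of that register is inserted right after it.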
1804 
1805 void LegalizerHelper::narrowScalarDst(MachineInstr &MI, LLT NarrowTy,
1806                                       unsigned OpIdx, unsigned ExtOpcode) {
1807   MachineOperand &MO = MI.getOperand(OpIdx);
1808   Register DstTrunc = MRI.createGenericVirtualRegister(NarrowTy);
1809   MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
1810   MIRBuilder.buildInstr(ExtOpcode, {MO}, {DstTrunc});
1811   MO.setReg(DstTrunc);
1812 }
1813 
1814 void LegalizerHelper::moreElementsVectorDst(MachineInstr &MI, LLT WideTy,
1815                                             unsigned OpIdx) {
1816   MachineOperand &MO = MI.getOperand(OpIdx);
1817   MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
1818   Register Dst = MO.getReg();
1819   Register DstExt = MRI.createGenericVirtualRegister(WideTy);
1820   MO.setReg(DstExt);
1821   MIRBuilder.buildDeleteTrailingVectorElements(Dst, DstExt);
1822 }
1823 
1824 void LegalizerHelper::moreElementsVectorSrc(MachineInstr &MI, LLT MoreTy,
1825                                             unsigned OpIdx) {
1826   MachineOperand &MO = MI.getOperand(OpIdx);
1827   SmallVector<Register, 8> Regs;
1828   MO.setReg(MIRBuilder.buildPadVectorWithUndefElements(MoreTy, MO).getReg(0));
1829 }
1830 
1831 void LegalizerHelper::bitcastSrc(MachineInstr &MI, LLT CastTy, unsigned OpIdx) {
1832   MachineOperand &Op = MI.getOperand(OpIdx);
1833   Op.setReg(MIRBuilder.buildBitcast(CastTy, Op).getReg(0));
1834 }
1835 
1836 void LegalizerHelper::bitcastDst(MachineInstr &MI, LLT CastTy, unsigned OpIdx) {
1837   MachineOperand &MO = MI.getOperand(OpIdx);
1838   Register CastDst = MRI.createGenericVirtualRegister(CastTy);
1839   MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
1840   MIRBuilder.buildBitcast(MO, CastDst);
1841   MO.setReg(CastDst);
1842 }
1843 
1844 LegalizerHelper::LegalizeResult
1845 LegalizerHelper::widenScalarMergeValues(MachineInstr &MI, unsigned TypeIdx,
1846                                         LLT WideTy) {
1847   if (TypeIdx != 1)
1848     return UnableToLegalize;
1849 
1850   auto [DstReg, DstTy, Src1Reg, Src1Ty] = MI.getFirst2RegLLTs();
1851   if (DstTy.isVector())
1852     return UnableToLegalize;
1853 
1854   LLT SrcTy = MRI.getType(Src1Reg);
1855   const int DstSize = DstTy.getSizeInBits();
1856   const int SrcSize = SrcTy.getSizeInBits();
1857   const int WideSize = WideTy.getSizeInBits();
1858   const int NumMerge = (DstSize + WideSize - 1) / WideSize;
1859 
1860   unsigned NumOps = MI.getNumOperands();
1861   unsigned NumSrc = MI.getNumOperands() - 1;
1862   unsigned PartSize = DstTy.getSizeInBits() / NumSrc;
1863 
1864   if (WideSize >= DstSize) {
1865     // Directly pack the bits in the target type.
1866     Register ResultReg = MIRBuilder.buildZExt(WideTy, Src1Reg).getReg(0);
1867 
1868     for (unsigned I = 2; I != NumOps; ++I) {
1869       const unsigned Offset = (I - 1) * PartSize;
1870 
1871       Register SrcReg = MI.getOperand(I).getReg();
1872       assert(MRI.getType(SrcReg) == LLT::scalar(PartSize));
1873 
1874       auto ZextInput = MIRBuilder.buildZExt(WideTy, SrcReg);
1875 
1876       Register NextResult = I + 1 == NumOps && WideTy == DstTy ? DstReg :
1877         MRI.createGenericVirtualRegister(WideTy);
1878 
1879       auto ShiftAmt = MIRBuilder.buildConstant(WideTy, Offset);
1880       auto Shl = MIRBuilder.buildShl(WideTy, ZextInput, ShiftAmt);
1881       MIRBuilder.buildOr(NextResult, ResultReg, Shl);
1882       ResultReg = NextResult;
1883     }
1884 
1885     if (WideSize > DstSize)
1886       MIRBuilder.buildTrunc(DstReg, ResultReg);
1887     else if (DstTy.isPointer())
1888       MIRBuilder.buildIntToPtr(DstReg, ResultReg);
1889 
1890     MI.eraseFromParent();
1891     return Legalized;
1892   }
1893 
1894   // Unmerge the original values to the GCD type, and recombine to the next
1895   // multiple greater than the original type.
1896   //
1897   // %3:_(s12) = G_MERGE_VALUES %0:_(s4), %1:_(s4), %2:_(s4) -> s6
1898   // %4:_(s2), %5:_(s2) = G_UNMERGE_VALUES %0
1899   // %6:_(s2), %7:_(s2) = G_UNMERGE_VALUES %1
1900   // %8:_(s2), %9:_(s2) = G_UNMERGE_VALUES %2
1901   // %10:_(s6) = G_MERGE_VALUES %4, %5, %6
1902   // %11:_(s6) = G_MERGE_VALUES %7, %8, %9
1903   // %12:_(s12) = G_MERGE_VALUES %10, %11
1904   //
1905   // Padding with undef if necessary:
1906   //
1907   // %2:_(s8) = G_MERGE_VALUES %0:_(s4), %1:_(s4) -> s6
1908   // %3:_(s2), %4:_(s2) = G_UNMERGE_VALUES %0
1909   // %5:_(s2), %6:_(s2) = G_UNMERGE_VALUES %1
1910   // %7:_(s2) = G_IMPLICIT_DEF
1911   // %8:_(s6) = G_MERGE_VALUES %3, %4, %5
1912   // %9:_(s6) = G_MERGE_VALUES %6, %7, %7
1913   // %10:_(s12) = G_MERGE_VALUES %8, %9
1914 
1915   const int GCD = std::gcd(SrcSize, WideSize);
1916   LLT GCDTy = LLT::scalar(GCD);
1917 
1918   SmallVector<Register, 8> Parts;
1919   SmallVector<Register, 8> NewMergeRegs;
1920   SmallVector<Register, 8> Unmerges;
1921   LLT WideDstTy = LLT::scalar(NumMerge * WideSize);
1922 
1923   // Decompose the original operands if they don't evenly divide.
1924   for (const MachineOperand &MO : llvm::drop_begin(MI.operands())) {
1925     Register SrcReg = MO.getReg();
1926     if (GCD == SrcSize) {
1927       Unmerges.push_back(SrcReg);
1928     } else {
1929       auto Unmerge = MIRBuilder.buildUnmerge(GCDTy, SrcReg);
1930       for (int J = 0, JE = Unmerge->getNumOperands() - 1; J != JE; ++J)
1931         Unmerges.push_back(Unmerge.getReg(J));
1932     }
1933   }
1934 
1935   // Pad with undef to the next size that is a multiple of the requested size.
1936   if (static_cast<int>(Unmerges.size()) != NumMerge * WideSize) {
1937     Register UndefReg = MIRBuilder.buildUndef(GCDTy).getReg(0);
1938     for (int I = Unmerges.size(); I != NumMerge * WideSize; ++I)
1939       Unmerges.push_back(UndefReg);
1940   }
1941 
1942   const int PartsPerGCD = WideSize / GCD;
1943 
1944   // Build merges of each piece.
1945   ArrayRef<Register> Slicer(Unmerges);
1946   for (int I = 0; I != NumMerge; ++I, Slicer = Slicer.drop_front(PartsPerGCD)) {
1947     auto Merge =
1948         MIRBuilder.buildMergeLikeInstr(WideTy, Slicer.take_front(PartsPerGCD));
1949     NewMergeRegs.push_back(Merge.getReg(0));
1950   }
1951 
1952   // A truncate may be necessary if the requested type doesn't evenly divide the
1953   // original result type.
1954   if (DstTy.getSizeInBits() == WideDstTy.getSizeInBits()) {
1955     MIRBuilder.buildMergeLikeInstr(DstReg, NewMergeRegs);
1956   } else {
1957     auto FinalMerge = MIRBuilder.buildMergeLikeInstr(WideDstTy, NewMergeRegs);
1958     MIRBuilder.buildTrunc(DstReg, FinalMerge.getReg(0));
1959   }
1960 
1961   MI.eraseFromParent();
1962   return Legalized;
1963 }
1964 
1965 LegalizerHelper::LegalizeResult
1966 LegalizerHelper::widenScalarUnmergeValues(MachineInstr &MI, unsigned TypeIdx,
1967                                           LLT WideTy) {
1968   if (TypeIdx != 0)
1969     return UnableToLegalize;
1970 
1971   int NumDst = MI.getNumOperands() - 1;
1972   Register SrcReg = MI.getOperand(NumDst).getReg();
1973   LLT SrcTy = MRI.getType(SrcReg);
1974   if (SrcTy.isVector())
1975     return UnableToLegalize;
1976 
1977   Register Dst0Reg = MI.getOperand(0).getReg();
1978   LLT DstTy = MRI.getType(Dst0Reg);
1979   if (!DstTy.isScalar())
1980     return UnableToLegalize;
1981 
1982   if (WideTy.getSizeInBits() >= SrcTy.getSizeInBits()) {
1983     if (SrcTy.isPointer()) {
1984       const DataLayout &DL = MIRBuilder.getDataLayout();
1985       if (DL.isNonIntegralAddressSpace(SrcTy.getAddressSpace())) {
1986         LLVM_DEBUG(
1987             dbgs() << "Not casting non-integral address space integer\n");
1988         return UnableToLegalize;
1989       }
1990 
1991       SrcTy = LLT::scalar(SrcTy.getSizeInBits());
1992       SrcReg = MIRBuilder.buildPtrToInt(SrcTy, SrcReg).getReg(0);
1993     }
1994 
1995     // Widen SrcTy to WideTy. This does not affect the result, but since the
1996     // user requested this size, WideTy is probably better supported than SrcTy
1997     // and should reduce the total number of legalization artifacts.
1998     if (WideTy.getSizeInBits() > SrcTy.getSizeInBits()) {
1999       SrcTy = WideTy;
2000       SrcReg = MIRBuilder.buildAnyExt(WideTy, SrcReg).getReg(0);
2001     }
2002 
2003     // There's no unmerge type to target. Directly extract the bits from the
2004     // source type.
2005     unsigned DstSize = DstTy.getSizeInBits();
2006 
2007     MIRBuilder.buildTrunc(Dst0Reg, SrcReg);
2008     for (int I = 1; I != NumDst; ++I) {
2009       auto ShiftAmt = MIRBuilder.buildConstant(SrcTy, DstSize * I);
2010       auto Shr = MIRBuilder.buildLShr(SrcTy, SrcReg, ShiftAmt);
2011       MIRBuilder.buildTrunc(MI.getOperand(I), Shr);
2012     }
2013 
2014     MI.eraseFromParent();
2015     return Legalized;
2016   }
2017 
2018   // Extend the source to a wider type.
2019   LLT LCMTy = getLCMType(SrcTy, WideTy);
2020 
2021   Register WideSrc = SrcReg;
2022   if (LCMTy.getSizeInBits() != SrcTy.getSizeInBits()) {
2023     // TODO: If this is an integral address space, cast to integer and anyext.
2024     if (SrcTy.isPointer()) {
2025       LLVM_DEBUG(dbgs() << "Widening pointer source types not implemented\n");
2026       return UnableToLegalize;
2027     }
2028 
2029     WideSrc = MIRBuilder.buildAnyExt(LCMTy, WideSrc).getReg(0);
2030   }
2031 
2032   auto Unmerge = MIRBuilder.buildUnmerge(WideTy, WideSrc);
2033 
2034   // Create a sequence of unmerges and merges to the original results. Since we
2035   // may have widened the source, we will need to pad the results with dead defs
2036   // to cover the source register.
2037   // e.g. widen s48 to s64:
2038   // %1:_(s48), %2:_(s48) = G_UNMERGE_VALUES %0:_(s96)
2039   //
2040   // =>
2041   //  %4:_(s192) = G_ANYEXT %0:_(s96)
2042   //  %5:_(s64), %6, %7 = G_UNMERGE_VALUES %4 ; Requested unmerge
2043   //  ; unpack to GCD type, with extra dead defs
2044   //  %8:_(s16), %9, %10, %11 = G_UNMERGE_VALUES %5:_(s64)
2045   //  %12:_(s16), %13, dead %14, dead %15 = G_UNMERGE_VALUES %6:_(s64)
2046   //  dead %16:_(s16), dead %17, dead %18, dead %19 = G_UNMERGE_VALUES %7:_(s64)
2047   //  %1:_(s48) = G_MERGE_VALUES %8:_(s16), %9, %10   ; Remerge to destination
2048   //  %2:_(s48) = G_MERGE_VALUES %11:_(s16), %12, %13 ; Remerge to destination
2049   const LLT GCDTy = getGCDType(WideTy, DstTy);
2050   const int NumUnmerge = Unmerge->getNumOperands() - 1;
2051   const int PartsPerRemerge = DstTy.getSizeInBits() / GCDTy.getSizeInBits();
2052 
2053   // Directly unmerge to the destination without going through a GCD type
2054   // if possible
2055   if (PartsPerRemerge == 1) {
2056     const int PartsPerUnmerge = WideTy.getSizeInBits() / DstTy.getSizeInBits();
2057 
2058     for (int I = 0; I != NumUnmerge; ++I) {
2059       auto MIB = MIRBuilder.buildInstr(TargetOpcode::G_UNMERGE_VALUES);
2060 
2061       for (int J = 0; J != PartsPerUnmerge; ++J) {
2062         int Idx = I * PartsPerUnmerge + J;
2063         if (Idx < NumDst)
2064           MIB.addDef(MI.getOperand(Idx).getReg());
2065         else {
2066           // Create dead def for excess components.
2067           MIB.addDef(MRI.createGenericVirtualRegister(DstTy));
2068         }
2069       }
2070 
2071       MIB.addUse(Unmerge.getReg(I));
2072     }
2073   } else {
2074     SmallVector<Register, 16> Parts;
2075     for (int J = 0; J != NumUnmerge; ++J)
2076       extractGCDType(Parts, GCDTy, Unmerge.getReg(J));
2077 
2078     SmallVector<Register, 8> RemergeParts;
2079     for (int I = 0; I != NumDst; ++I) {
2080       for (int J = 0; J < PartsPerRemerge; ++J) {
2081         const int Idx = I * PartsPerRemerge + J;
2082         RemergeParts.emplace_back(Parts[Idx]);
2083       }
2084 
2085       MIRBuilder.buildMergeLikeInstr(MI.getOperand(I).getReg(), RemergeParts);
2086       RemergeParts.clear();
2087     }
2088   }
2089 
2090   MI.eraseFromParent();
2091   return Legalized;
2092 }
2093 
2094 LegalizerHelper::LegalizeResult
2095 LegalizerHelper::widenScalarExtract(MachineInstr &MI, unsigned TypeIdx,
2096                                     LLT WideTy) {
2097   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
2098   unsigned Offset = MI.getOperand(2).getImm();
2099 
2100   if (TypeIdx == 0) {
2101     if (SrcTy.isVector() || DstTy.isVector())
2102       return UnableToLegalize;
2103 
2104     SrcOp Src(SrcReg);
2105     if (SrcTy.isPointer()) {
2106       // Extracts from pointers can be handled only if they are really just
2107       // simple integers.
2108       const DataLayout &DL = MIRBuilder.getDataLayout();
2109       if (DL.isNonIntegralAddressSpace(SrcTy.getAddressSpace()))
2110         return UnableToLegalize;
2111 
2112       LLT SrcAsIntTy = LLT::scalar(SrcTy.getSizeInBits());
2113       Src = MIRBuilder.buildPtrToInt(SrcAsIntTy, Src);
2114       SrcTy = SrcAsIntTy;
2115     }
2116 
2117     if (DstTy.isPointer())
2118       return UnableToLegalize;
2119 
2120     if (Offset == 0) {
2121       // Avoid a shift in the degenerate case.
2122       MIRBuilder.buildTrunc(DstReg,
2123                             MIRBuilder.buildAnyExtOrTrunc(WideTy, Src));
2124       MI.eraseFromParent();
2125       return Legalized;
2126     }
2127 
2128     // Do a shift in the source type.
2129     LLT ShiftTy = SrcTy;
2130     if (WideTy.getSizeInBits() > SrcTy.getSizeInBits()) {
2131       Src = MIRBuilder.buildAnyExt(WideTy, Src);
2132       ShiftTy = WideTy;
2133     }
2134 
2135     auto LShr = MIRBuilder.buildLShr(
2136       ShiftTy, Src, MIRBuilder.buildConstant(ShiftTy, Offset));
2137     MIRBuilder.buildTrunc(DstReg, LShr);
2138     MI.eraseFromParent();
2139     return Legalized;
2140   }
2141 
2142   if (SrcTy.isScalar()) {
2143     Observer.changingInstr(MI);
2144     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2145     Observer.changedInstr(MI);
2146     return Legalized;
2147   }
2148 
2149   if (!SrcTy.isVector())
2150     return UnableToLegalize;
2151 
2152   if (DstTy != SrcTy.getElementType())
2153     return UnableToLegalize;
2154 
2155   if (Offset % SrcTy.getScalarSizeInBits() != 0)
2156     return UnableToLegalize;
2157 
2158   Observer.changingInstr(MI);
2159   widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2160 
2161   MI.getOperand(2).setImm((WideTy.getSizeInBits() / SrcTy.getSizeInBits()) *
2162                           Offset);
2163   widenScalarDst(MI, WideTy.getScalarType(), 0);
2164   Observer.changedInstr(MI);
2165   return Legalized;
2166 }
2167 
2168 LegalizerHelper::LegalizeResult
2169 LegalizerHelper::widenScalarInsert(MachineInstr &MI, unsigned TypeIdx,
2170                                    LLT WideTy) {
2171   if (TypeIdx != 0 || WideTy.isVector())
2172     return UnableToLegalize;
2173   Observer.changingInstr(MI);
2174   widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2175   widenScalarDst(MI, WideTy);
2176   Observer.changedInstr(MI);
2177   return Legalized;
2178 }
2179 
2180 LegalizerHelper::LegalizeResult
2181 LegalizerHelper::widenScalarAddSubOverflow(MachineInstr &MI, unsigned TypeIdx,
2182                                            LLT WideTy) {
2183   unsigned Opcode;
2184   unsigned ExtOpcode;
2185   std::optional<Register> CarryIn;
2186   switch (MI.getOpcode()) {
2187   default:
2188     llvm_unreachable("Unexpected opcode!");
2189   case TargetOpcode::G_SADDO:
2190     Opcode = TargetOpcode::G_ADD;
2191     ExtOpcode = TargetOpcode::G_SEXT;
2192     break;
2193   case TargetOpcode::G_SSUBO:
2194     Opcode = TargetOpcode::G_SUB;
2195     ExtOpcode = TargetOpcode::G_SEXT;
2196     break;
2197   case TargetOpcode::G_UADDO:
2198     Opcode = TargetOpcode::G_ADD;
2199     ExtOpcode = TargetOpcode::G_ZEXT;
2200     break;
2201   case TargetOpcode::G_USUBO:
2202     Opcode = TargetOpcode::G_SUB;
2203     ExtOpcode = TargetOpcode::G_ZEXT;
2204     break;
2205   case TargetOpcode::G_SADDE:
2206     Opcode = TargetOpcode::G_UADDE;
2207     ExtOpcode = TargetOpcode::G_SEXT;
2208     CarryIn = MI.getOperand(4).getReg();
2209     break;
2210   case TargetOpcode::G_SSUBE:
2211     Opcode = TargetOpcode::G_USUBE;
2212     ExtOpcode = TargetOpcode::G_SEXT;
2213     CarryIn = MI.getOperand(4).getReg();
2214     break;
2215   case TargetOpcode::G_UADDE:
2216     Opcode = TargetOpcode::G_UADDE;
2217     ExtOpcode = TargetOpcode::G_ZEXT;
2218     CarryIn = MI.getOperand(4).getReg();
2219     break;
2220   case TargetOpcode::G_USUBE:
2221     Opcode = TargetOpcode::G_USUBE;
2222     ExtOpcode = TargetOpcode::G_ZEXT;
2223     CarryIn = MI.getOperand(4).getReg();
2224     break;
2225   }
2226 
2227   if (TypeIdx == 1) {
2228     unsigned BoolExtOp = MIRBuilder.getBoolExtOp(WideTy.isVector(), false);
2229 
2230     Observer.changingInstr(MI);
2231     if (CarryIn)
2232       widenScalarSrc(MI, WideTy, 4, BoolExtOp);
2233     widenScalarDst(MI, WideTy, 1);
2234 
2235     Observer.changedInstr(MI);
2236     return Legalized;
2237   }
2238 
2239   auto LHSExt = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {MI.getOperand(2)});
2240   auto RHSExt = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {MI.getOperand(3)});
2241   // Do the arithmetic in the larger type.
2242   Register NewOp;
2243   if (CarryIn) {
2244     LLT CarryOutTy = MRI.getType(MI.getOperand(1).getReg());
2245     NewOp = MIRBuilder
2246                 .buildInstr(Opcode, {WideTy, CarryOutTy},
2247                             {LHSExt, RHSExt, *CarryIn})
2248                 .getReg(0);
2249   } else {
2250     NewOp = MIRBuilder.buildInstr(Opcode, {WideTy}, {LHSExt, RHSExt}).getReg(0);
2251   }
2252   LLT OrigTy = MRI.getType(MI.getOperand(0).getReg());
2253   auto TruncOp = MIRBuilder.buildTrunc(OrigTy, NewOp);
2254   auto ExtOp = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {TruncOp});
2255   // There is no overflow if the ExtOp is the same as NewOp.
2256   MIRBuilder.buildICmp(CmpInst::ICMP_NE, MI.getOperand(1), NewOp, ExtOp);
2257   // Now trunc the NewOp to the original result.
2258   MIRBuilder.buildTrunc(MI.getOperand(0), NewOp);
2259   MI.eraseFromParent();
2260   return Legalized;
2261 }
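// E.g. (illustrative) G_SADDO on s8 widened to s32: sign-extend both inputs,
// add in s32, and report overflow if the s32 sum differs from
// sext(trunc(sum)), i.e. the sum no longer fits in s8.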
2262 
2263 LegalizerHelper::LegalizeResult
2264 LegalizerHelper::widenScalarAddSubShlSat(MachineInstr &MI, unsigned TypeIdx,
2265                                          LLT WideTy) {
2266   bool IsSigned = MI.getOpcode() == TargetOpcode::G_SADDSAT ||
2267                   MI.getOpcode() == TargetOpcode::G_SSUBSAT ||
2268                   MI.getOpcode() == TargetOpcode::G_SSHLSAT;
2269   bool IsShift = MI.getOpcode() == TargetOpcode::G_SSHLSAT ||
2270                  MI.getOpcode() == TargetOpcode::G_USHLSAT;
2271   // We can convert this to:
2272   //   1. Any extend iN to iM
2273   //   2. SHL by M-N
2274   //   3. [US][ADD|SUB|SHL]SAT
2275   //   4. L/ASHR by M-N
2276   //
2277   // It may be more efficient to lower this to a min and a max operation in
2278   // the higher precision arithmetic if the promoted operation isn't legal,
2279   // but this decision is up to the target's lowering request.
2280   Register DstReg = MI.getOperand(0).getReg();
2281 
2282   unsigned NewBits = WideTy.getScalarSizeInBits();
2283   unsigned SHLAmount = NewBits - MRI.getType(DstReg).getScalarSizeInBits();
2284 
2285   // Shifts must zero-extend the RHS to preserve the unsigned quantity, and
2286   // must not left shift the RHS to preserve the shift amount.
2287   auto LHS = MIRBuilder.buildAnyExt(WideTy, MI.getOperand(1));
2288   auto RHS = IsShift ? MIRBuilder.buildZExt(WideTy, MI.getOperand(2))
2289                      : MIRBuilder.buildAnyExt(WideTy, MI.getOperand(2));
2290   auto ShiftK = MIRBuilder.buildConstant(WideTy, SHLAmount);
2291   auto ShiftL = MIRBuilder.buildShl(WideTy, LHS, ShiftK);
2292   auto ShiftR = IsShift ? RHS : MIRBuilder.buildShl(WideTy, RHS, ShiftK);
2293 
2294   auto WideInst = MIRBuilder.buildInstr(MI.getOpcode(), {WideTy},
2295                                         {ShiftL, ShiftR}, MI.getFlags());
2296 
2297   // Use a shift that will preserve the number of sign bits when the trunc is
2298   // folded away.
2299   auto Result = IsSigned ? MIRBuilder.buildAShr(WideTy, WideInst, ShiftK)
2300                          : MIRBuilder.buildLShr(WideTy, WideInst, ShiftK);
2301 
2302   MIRBuilder.buildTrunc(DstReg, Result);
2303   MI.eraseFromParent();
2304   return Legalized;
2305 }
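// E.g. (illustrative) G_UADDSAT on s8 widened to s32: both operands are
// shifted into the top byte (shl by 24), the saturating add then clips
// exactly as the s8 operation would, and an lshr by 24 plus trunc recovers
// the s8 result.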
2306 
2307 LegalizerHelper::LegalizeResult
2308 LegalizerHelper::widenScalarMulo(MachineInstr &MI, unsigned TypeIdx,
2309                                  LLT WideTy) {
2310   if (TypeIdx == 1) {
2311     Observer.changingInstr(MI);
2312     widenScalarDst(MI, WideTy, 1);
2313     Observer.changedInstr(MI);
2314     return Legalized;
2315   }
2316 
2317   bool IsSigned = MI.getOpcode() == TargetOpcode::G_SMULO;
2318   auto [Result, OriginalOverflow, LHS, RHS] = MI.getFirst4Regs();
2319   LLT SrcTy = MRI.getType(LHS);
2320   LLT OverflowTy = MRI.getType(OriginalOverflow);
2321   unsigned SrcBitWidth = SrcTy.getScalarSizeInBits();
2322 
2323   // To determine if the result overflowed in the larger type, we extend the
2324   // input to the larger type, do the multiply (checking if it overflows),
2325   // then also check the high bits of the result to see if overflow happened
2326   // there.
2327   unsigned ExtOp = IsSigned ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
2328   auto LeftOperand = MIRBuilder.buildInstr(ExtOp, {WideTy}, {LHS});
2329   auto RightOperand = MIRBuilder.buildInstr(ExtOp, {WideTy}, {RHS});
2330 
2331   // Multiplication cannot overflow if WideTy is >= 2 * the original width,
2332   // so we don't need to check the overflow result of the larger-type Mulo.
2333   bool WideMulCanOverflow = WideTy.getScalarSizeInBits() < 2 * SrcBitWidth;
2334 
2335   unsigned MulOpc =
2336       WideMulCanOverflow ? MI.getOpcode() : (unsigned)TargetOpcode::G_MUL;
2337 
2338   MachineInstrBuilder Mulo;
2339   if (WideMulCanOverflow)
2340     Mulo = MIRBuilder.buildInstr(MulOpc, {WideTy, OverflowTy},
2341                                  {LeftOperand, RightOperand});
2342   else
2343     Mulo = MIRBuilder.buildInstr(MulOpc, {WideTy}, {LeftOperand, RightOperand});
2344 
2345   auto Mul = Mulo->getOperand(0);
2346   MIRBuilder.buildTrunc(Result, Mul);
2347 
2348   MachineInstrBuilder ExtResult;
2349   // Overflow occurred if it occurred in the larger type, or if the high part
2350   // of the result does not zero/sign-extend the low part.  Check this second
2351   // possibility first.
2352   if (IsSigned) {
2353     // For signed, overflow occurred when the high part does not sign-extend
2354     // the low part.
2355     ExtResult = MIRBuilder.buildSExtInReg(WideTy, Mul, SrcBitWidth);
2356   } else {
2357     // Unsigned overflow occurred when the high part does not zero-extend the
2358     // low part.
2359     ExtResult = MIRBuilder.buildZExtInReg(WideTy, Mul, SrcBitWidth);
2360   }
2361 
2362   if (WideMulCanOverflow) {
2363     auto Overflow =
2364         MIRBuilder.buildICmp(CmpInst::ICMP_NE, OverflowTy, Mul, ExtResult);
2365     // Finally check if the multiplication in the larger type itself overflowed.
2366     MIRBuilder.buildOr(OriginalOverflow, Mulo->getOperand(1), Overflow);
2367   } else {
2368     MIRBuilder.buildICmp(CmpInst::ICMP_NE, OriginalOverflow, Mul, ExtResult);
2369   }
2370   MI.eraseFromParent();
2371   return Legalized;
2372 }
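// E.g. (illustrative) G_UMULO on s16 widened to s32: 32 >= 2 * 16, so the
// wide multiply cannot overflow and a plain G_MUL is used; the overflow flag
// is set iff the product no longer zero-extends its low 16 bits.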
2373 
2374 LegalizerHelper::LegalizeResult
2375 LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
2376   switch (MI.getOpcode()) {
2377   default:
2378     return UnableToLegalize;
2379   case TargetOpcode::G_ATOMICRMW_XCHG:
2380   case TargetOpcode::G_ATOMICRMW_ADD:
2381   case TargetOpcode::G_ATOMICRMW_SUB:
2382   case TargetOpcode::G_ATOMICRMW_AND:
2383   case TargetOpcode::G_ATOMICRMW_OR:
2384   case TargetOpcode::G_ATOMICRMW_XOR:
2385   case TargetOpcode::G_ATOMICRMW_MIN:
2386   case TargetOpcode::G_ATOMICRMW_MAX:
2387   case TargetOpcode::G_ATOMICRMW_UMIN:
2388   case TargetOpcode::G_ATOMICRMW_UMAX:
2389     assert(TypeIdx == 0 && "atomicrmw with second scalar type");
2390     Observer.changingInstr(MI);
2391     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
2392     widenScalarDst(MI, WideTy, 0);
2393     Observer.changedInstr(MI);
2394     return Legalized;
2395   case TargetOpcode::G_ATOMIC_CMPXCHG:
2396     assert(TypeIdx == 0 && "G_ATOMIC_CMPXCHG with second scalar type");
2397     Observer.changingInstr(MI);
2398     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
2399     widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ANYEXT);
2400     widenScalarDst(MI, WideTy, 0);
2401     Observer.changedInstr(MI);
2402     return Legalized;
2403   case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS:
2404     if (TypeIdx == 0) {
2405       Observer.changingInstr(MI);
2406       widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ANYEXT);
2407       widenScalarSrc(MI, WideTy, 4, TargetOpcode::G_ANYEXT);
2408       widenScalarDst(MI, WideTy, 0);
2409       Observer.changedInstr(MI);
2410       return Legalized;
2411     }
2412     assert(TypeIdx == 1 &&
2413            "G_ATOMIC_CMPXCHG_WITH_SUCCESS with third scalar type");
2414     Observer.changingInstr(MI);
2415     widenScalarDst(MI, WideTy, 1);
2416     Observer.changedInstr(MI);
2417     return Legalized;
2418   case TargetOpcode::G_EXTRACT:
2419     return widenScalarExtract(MI, TypeIdx, WideTy);
2420   case TargetOpcode::G_INSERT:
2421     return widenScalarInsert(MI, TypeIdx, WideTy);
2422   case TargetOpcode::G_MERGE_VALUES:
2423     return widenScalarMergeValues(MI, TypeIdx, WideTy);
2424   case TargetOpcode::G_UNMERGE_VALUES:
2425     return widenScalarUnmergeValues(MI, TypeIdx, WideTy);
2426   case TargetOpcode::G_SADDO:
2427   case TargetOpcode::G_SSUBO:
2428   case TargetOpcode::G_UADDO:
2429   case TargetOpcode::G_USUBO:
2430   case TargetOpcode::G_SADDE:
2431   case TargetOpcode::G_SSUBE:
2432   case TargetOpcode::G_UADDE:
2433   case TargetOpcode::G_USUBE:
2434     return widenScalarAddSubOverflow(MI, TypeIdx, WideTy);
2435   case TargetOpcode::G_UMULO:
2436   case TargetOpcode::G_SMULO:
2437     return widenScalarMulo(MI, TypeIdx, WideTy);
2438   case TargetOpcode::G_SADDSAT:
2439   case TargetOpcode::G_SSUBSAT:
2440   case TargetOpcode::G_SSHLSAT:
2441   case TargetOpcode::G_UADDSAT:
2442   case TargetOpcode::G_USUBSAT:
2443   case TargetOpcode::G_USHLSAT:
2444     return widenScalarAddSubShlSat(MI, TypeIdx, WideTy);
2445   case TargetOpcode::G_CTTZ:
2446   case TargetOpcode::G_CTTZ_ZERO_UNDEF:
2447   case TargetOpcode::G_CTLZ:
2448   case TargetOpcode::G_CTLZ_ZERO_UNDEF:
2449   case TargetOpcode::G_CTPOP: {
2450     if (TypeIdx == 0) {
2451       Observer.changingInstr(MI);
2452       widenScalarDst(MI, WideTy, 0);
2453       Observer.changedInstr(MI);
2454       return Legalized;
2455     }
2456 
2457     Register SrcReg = MI.getOperand(1).getReg();
2458 
2459     // First extend the input.
2460     unsigned ExtOpc = MI.getOpcode() == TargetOpcode::G_CTTZ ||
2461                               MI.getOpcode() == TargetOpcode::G_CTTZ_ZERO_UNDEF
2462                           ? TargetOpcode::G_ANYEXT
2463                           : TargetOpcode::G_ZEXT;
2464     auto MIBSrc = MIRBuilder.buildInstr(ExtOpc, {WideTy}, {SrcReg});
2465     LLT CurTy = MRI.getType(SrcReg);
2466     unsigned NewOpc = MI.getOpcode();
2467     if (NewOpc == TargetOpcode::G_CTTZ) {
2468       // The count is the same in the larger type except if the original
2469       // value was zero.  This can be handled by setting the bit just off
2470       // the top of the original type.
2471       auto TopBit =
2472           APInt::getOneBitSet(WideTy.getSizeInBits(), CurTy.getSizeInBits());
2473       MIBSrc = MIRBuilder.buildOr(
2474         WideTy, MIBSrc, MIRBuilder.buildConstant(WideTy, TopBit));
2475       // Now we know the operand is non-zero, use the more relaxed opcode.
2476       NewOpc = TargetOpcode::G_CTTZ_ZERO_UNDEF;
2477     }
2478 
2479     unsigned SizeDiff = WideTy.getSizeInBits() - CurTy.getSizeInBits();
2480 
2481     if (MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) {
2482       // An optimization where the result is the CTLZ after a left shift by the
2483       // difference in width between WideTy and CurTy, that is,
2484       // MIBSrc = MIBSrc << (sizeinbits(WideTy) - sizeinbits(CurTy))
2485       // Result = ctlz MIBSrc
2486       MIBSrc = MIRBuilder.buildShl(WideTy, MIBSrc,
2487                                    MIRBuilder.buildConstant(WideTy, SizeDiff));
2488     }
2489 
2490     // Perform the operation at the larger size.
2491     auto MIBNewOp = MIRBuilder.buildInstr(NewOpc, {WideTy}, {MIBSrc});
2492     // This is already the correct result for CTPOP and the CTTZ variants.
2493     if (MI.getOpcode() == TargetOpcode::G_CTLZ) {
2494       // The correct result is NewOp - (difference in width between WideTy and CurTy).
2495       MIBNewOp = MIRBuilder.buildSub(
2496           WideTy, MIBNewOp, MIRBuilder.buildConstant(WideTy, SizeDiff));
2497     }
2498 
2499     MIRBuilder.buildZExtOrTrunc(MI.getOperand(0), MIBNewOp);
2500     MI.eraseFromParent();
2501     return Legalized;
2502   }
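  // E.g. (illustrative) G_CTTZ on s8 widened to s32: the input is
  // any-extended, bit 8 is OR'd in so the count can never exceed 8, and the
  // count is computed with G_CTTZ_ZERO_UNDEF at s32.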
2503   case TargetOpcode::G_BSWAP: {
2504     Observer.changingInstr(MI);
2505     Register DstReg = MI.getOperand(0).getReg();
2506 
2507     Register ShrReg = MRI.createGenericVirtualRegister(WideTy);
2508     Register DstExt = MRI.createGenericVirtualRegister(WideTy);
2509     Register ShiftAmtReg = MRI.createGenericVirtualRegister(WideTy);
2510     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2511 
2512     MI.getOperand(0).setReg(DstExt);
2513 
2514     MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
2515 
2516     LLT Ty = MRI.getType(DstReg);
2517     unsigned DiffBits = WideTy.getScalarSizeInBits() - Ty.getScalarSizeInBits();
2518     MIRBuilder.buildConstant(ShiftAmtReg, DiffBits);
2519     MIRBuilder.buildLShr(ShrReg, DstExt, ShiftAmtReg);
2520 
2521     MIRBuilder.buildTrunc(DstReg, ShrReg);
2522     Observer.changedInstr(MI);
2523     return Legalized;
2524   }
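  // E.g. (illustrative) G_BSWAP on s16 widened to s32: the swapped bytes land
  // in the top half of the s32, so the G_LSHR by 16 and G_TRUNC above recover
  // the s16 result.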
2525   case TargetOpcode::G_BITREVERSE: {
2526     Observer.changingInstr(MI);
2527 
2528     Register DstReg = MI.getOperand(0).getReg();
2529     LLT Ty = MRI.getType(DstReg);
2530     unsigned DiffBits = WideTy.getScalarSizeInBits() - Ty.getScalarSizeInBits();
2531 
2532     Register DstExt = MRI.createGenericVirtualRegister(WideTy);
2533     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2534     MI.getOperand(0).setReg(DstExt);
2535     MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
2536 
2537     auto ShiftAmt = MIRBuilder.buildConstant(WideTy, DiffBits);
2538     auto Shift = MIRBuilder.buildLShr(WideTy, DstExt, ShiftAmt);
2539     MIRBuilder.buildTrunc(DstReg, Shift);
2540     Observer.changedInstr(MI);
2541     return Legalized;
2542   }
2543   case TargetOpcode::G_FREEZE:
2544   case TargetOpcode::G_CONSTANT_FOLD_BARRIER:
2545     Observer.changingInstr(MI);
2546     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2547     widenScalarDst(MI, WideTy);
2548     Observer.changedInstr(MI);
2549     return Legalized;
2550 
2551   case TargetOpcode::G_ABS:
2552     Observer.changingInstr(MI);
2553     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_SEXT);
2554     widenScalarDst(MI, WideTy);
2555     Observer.changedInstr(MI);
2556     return Legalized;
2557 
2558   case TargetOpcode::G_ADD:
2559   case TargetOpcode::G_AND:
2560   case TargetOpcode::G_MUL:
2561   case TargetOpcode::G_OR:
2562   case TargetOpcode::G_XOR:
2563   case TargetOpcode::G_SUB:
2564   case TargetOpcode::G_SHUFFLE_VECTOR:
2565     // Perform operation at larger width (any extension is fine here, high bits
2566     // don't affect the result) and then truncate the result back to the
2567     // original type.
2568     Observer.changingInstr(MI);
2569     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2570     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
2571     widenScalarDst(MI, WideTy);
2572     Observer.changedInstr(MI);
2573     return Legalized;
2574 
2575   case TargetOpcode::G_SBFX:
2576   case TargetOpcode::G_UBFX:
2577     Observer.changingInstr(MI);
2578 
2579     if (TypeIdx == 0) {
2580       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2581       widenScalarDst(MI, WideTy);
2582     } else {
2583       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2584       widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ZEXT);
2585     }
2586 
2587     Observer.changedInstr(MI);
2588     return Legalized;
2589 
2590   case TargetOpcode::G_SHL:
2591     Observer.changingInstr(MI);
2592 
2593     if (TypeIdx == 0) {
2594       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2595       widenScalarDst(MI, WideTy);
2596     } else {
2597       assert(TypeIdx == 1);
2598       // The "number of bits to shift" operand must preserve its value as an
2599       // unsigned integer:
2600       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2601     }
2602 
2603     Observer.changedInstr(MI);
2604     return Legalized;
2605 
2606   case TargetOpcode::G_ROTR:
2607   case TargetOpcode::G_ROTL:
2608     if (TypeIdx != 1)
2609       return UnableToLegalize;
2610 
2611     Observer.changingInstr(MI);
2612     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2613     Observer.changedInstr(MI);
2614     return Legalized;
2615 
2616   case TargetOpcode::G_SDIV:
2617   case TargetOpcode::G_SREM:
2618   case TargetOpcode::G_SMIN:
2619   case TargetOpcode::G_SMAX:
2620     Observer.changingInstr(MI);
2621     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_SEXT);
2622     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
2623     widenScalarDst(MI, WideTy);
2624     Observer.changedInstr(MI);
2625     return Legalized;
2626 
2627   case TargetOpcode::G_SDIVREM:
2628     Observer.changingInstr(MI);
2629     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
2630     widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_SEXT);
2631     widenScalarDst(MI, WideTy);
2632     widenScalarDst(MI, WideTy, 1);
2633     Observer.changedInstr(MI);
2634     return Legalized;
2635 
2636   case TargetOpcode::G_ASHR:
2637   case TargetOpcode::G_LSHR:
2638     Observer.changingInstr(MI);
2639 
2640     if (TypeIdx == 0) {
2641       unsigned CvtOp = MI.getOpcode() == TargetOpcode::G_ASHR ?
2642         TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
2643 
2644       widenScalarSrc(MI, WideTy, 1, CvtOp);
2645       widenScalarDst(MI, WideTy);
2646     } else {
2647       assert(TypeIdx == 1);
2648       // The "number of bits to shift" operand must preserve its value as an
2649       // unsigned integer:
2650       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2651     }
2652 
2653     Observer.changedInstr(MI);
2654     return Legalized;
2655   case TargetOpcode::G_UDIV:
2656   case TargetOpcode::G_UREM:
2657   case TargetOpcode::G_UMIN:
2658   case TargetOpcode::G_UMAX:
2659     Observer.changingInstr(MI);
2660     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ZEXT);
2661     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2662     widenScalarDst(MI, WideTy);
2663     Observer.changedInstr(MI);
2664     return Legalized;
2665 
2666   case TargetOpcode::G_UDIVREM:
2667     Observer.changingInstr(MI);
2668     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2669     widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ZEXT);
2670     widenScalarDst(MI, WideTy);
2671     widenScalarDst(MI, WideTy, 1);
2672     Observer.changedInstr(MI);
2673     return Legalized;
2674 
2675   case TargetOpcode::G_SELECT:
2676     Observer.changingInstr(MI);
2677     if (TypeIdx == 0) {
2678       // Perform operation at larger width (any extension is fine here, high
2679       // bits don't affect the result) and then truncate the result back to the
2680       // original type.
2681       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
2682       widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ANYEXT);
2683       widenScalarDst(MI, WideTy);
2684     } else {
2685       bool IsVec = MRI.getType(MI.getOperand(1).getReg()).isVector();
2686       // Explicit extension is required here since high bits affect the result.
2687       widenScalarSrc(MI, WideTy, 1, MIRBuilder.getBoolExtOp(IsVec, false));
2688     }
2689     Observer.changedInstr(MI);
2690     return Legalized;
2691 
2692   case TargetOpcode::G_FPTOSI:
2693   case TargetOpcode::G_FPTOUI:
2694   case TargetOpcode::G_INTRINSIC_LRINT:
2695   case TargetOpcode::G_INTRINSIC_LLRINT:
2696   case TargetOpcode::G_IS_FPCLASS:
2697     Observer.changingInstr(MI);
2698 
2699     if (TypeIdx == 0)
2700       widenScalarDst(MI, WideTy);
2701     else
2702       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_FPEXT);
2703 
2704     Observer.changedInstr(MI);
2705     return Legalized;
2706   case TargetOpcode::G_SITOFP:
2707     Observer.changingInstr(MI);
2708 
2709     if (TypeIdx == 0)
2710       widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
2711     else
2712       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_SEXT);
2713 
2714     Observer.changedInstr(MI);
2715     return Legalized;
2716   case TargetOpcode::G_UITOFP:
2717     Observer.changingInstr(MI);
2718 
2719     if (TypeIdx == 0)
2720       widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
2721     else
2722       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ZEXT);
2723 
2724     Observer.changedInstr(MI);
2725     return Legalized;
2726   case TargetOpcode::G_LOAD:
2727   case TargetOpcode::G_SEXTLOAD:
2728   case TargetOpcode::G_ZEXTLOAD:
2729     Observer.changingInstr(MI);
2730     widenScalarDst(MI, WideTy);
2731     Observer.changedInstr(MI);
2732     return Legalized;
2733 
2734   case TargetOpcode::G_STORE: {
2735     if (TypeIdx != 0)
2736       return UnableToLegalize;
2737 
2738     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
2739     if (!Ty.isScalar())
2740       return UnableToLegalize;
2741 
2742     Observer.changingInstr(MI);
2743 
2744     unsigned ExtType = Ty.getScalarSizeInBits() == 1 ?
2745       TargetOpcode::G_ZEXT : TargetOpcode::G_ANYEXT;
2746     widenScalarSrc(MI, WideTy, 0, ExtType);
2747 
2748     Observer.changedInstr(MI);
2749     return Legalized;
2750   }
2751   case TargetOpcode::G_CONSTANT: {
2752     MachineOperand &SrcMO = MI.getOperand(1);
2753     LLVMContext &Ctx = MIRBuilder.getMF().getFunction().getContext();
2754     unsigned ExtOpc = LI.getExtOpcodeForWideningConstant(
2755         MRI.getType(MI.getOperand(0).getReg()));
2756     assert((ExtOpc == TargetOpcode::G_ZEXT || ExtOpc == TargetOpcode::G_SEXT ||
2757             ExtOpc == TargetOpcode::G_ANYEXT) &&
2758            "Illegal Extend");
2759     const APInt &SrcVal = SrcMO.getCImm()->getValue();
2760     const APInt &Val = (ExtOpc == TargetOpcode::G_SEXT)
2761                            ? SrcVal.sext(WideTy.getSizeInBits())
2762                            : SrcVal.zext(WideTy.getSizeInBits());
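    // For example (illustrative): an s8 G_CONSTANT -1 widened to s32 becomes
    // 0xFFFFFFFF under G_SEXT but 0x000000FF under G_ZEXT; the hook lets the
    // target pick whichever constant it can materialize more cheaply.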
2763     Observer.changingInstr(MI);
2764     SrcMO.setCImm(ConstantInt::get(Ctx, Val));
2765 
2766     widenScalarDst(MI, WideTy);
2767     Observer.changedInstr(MI);
2768     return Legalized;
2769   }
2770   case TargetOpcode::G_FCONSTANT: {
2771     // To avoid changing the bits of the constant due to extension to a larger
2772     // type and then using G_FPTRUNC, we simply convert to a G_CONSTANT.
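    // Sketch (illustrative): an s16 G_FCONSTANT of half 1.0 (bits 0x3C00)
    // widened to s32 becomes
    //   %1:_(s32) = G_CONSTANT i32 0x3C00
    //   %0:_(s16) = G_TRUNC %1
    // preserving the exact bit pattern.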
2773     MachineOperand &SrcMO = MI.getOperand(1);
2774     APInt Val = SrcMO.getFPImm()->getValueAPF().bitcastToAPInt();
2775     MIRBuilder.setInstrAndDebugLoc(MI);
2776     auto IntCst = MIRBuilder.buildConstant(MI.getOperand(0).getReg(), Val);
2777     widenScalarDst(*IntCst, WideTy, 0, TargetOpcode::G_TRUNC);
2778     MI.eraseFromParent();
2779     return Legalized;
2780   }
2781   case TargetOpcode::G_IMPLICIT_DEF: {
2782     Observer.changingInstr(MI);
2783     widenScalarDst(MI, WideTy);
2784     Observer.changedInstr(MI);
2785     return Legalized;
2786   }
2787   case TargetOpcode::G_BRCOND:
2788     Observer.changingInstr(MI);
2789     widenScalarSrc(MI, WideTy, 0, MIRBuilder.getBoolExtOp(false, false));
2790     Observer.changedInstr(MI);
2791     return Legalized;
2792 
2793   case TargetOpcode::G_FCMP:
2794     Observer.changingInstr(MI);
2795     if (TypeIdx == 0)
2796       widenScalarDst(MI, WideTy);
2797     else {
2798       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_FPEXT);
2799       widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_FPEXT);
2800     }
2801     Observer.changedInstr(MI);
2802     return Legalized;
2803 
2804   case TargetOpcode::G_ICMP:
2805     Observer.changingInstr(MI);
2806     if (TypeIdx == 0)
2807       widenScalarDst(MI, WideTy);
2808     else {
2809       unsigned ExtOpcode = CmpInst::isSigned(static_cast<CmpInst::Predicate>(
2810                                MI.getOperand(1).getPredicate()))
2811                                ? TargetOpcode::G_SEXT
2812                                : TargetOpcode::G_ZEXT;
2813       widenScalarSrc(MI, WideTy, 2, ExtOpcode);
2814       widenScalarSrc(MI, WideTy, 3, ExtOpcode);
2815     }
2816     Observer.changedInstr(MI);
2817     return Legalized;
2818 
2819   case TargetOpcode::G_PTR_ADD:
2820     assert(TypeIdx == 1 && "unable to legalize pointer of G_PTR_ADD");
2821     Observer.changingInstr(MI);
2822     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
2823     Observer.changedInstr(MI);
2824     return Legalized;
2825 
2826   case TargetOpcode::G_PHI: {
2827     assert(TypeIdx == 0 && "Expecting only Idx 0");
2828 
2829     Observer.changingInstr(MI);
2830     for (unsigned I = 1; I < MI.getNumOperands(); I += 2) {
2831       MachineBasicBlock &OpMBB = *MI.getOperand(I + 1).getMBB();
2832       MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
2833       widenScalarSrc(MI, WideTy, I, TargetOpcode::G_ANYEXT);
2834     }
2835 
2836     MachineBasicBlock &MBB = *MI.getParent();
2837     MIRBuilder.setInsertPt(MBB, --MBB.getFirstNonPHI());
2838     widenScalarDst(MI, WideTy);
2839     Observer.changedInstr(MI);
2840     return Legalized;
2841   }
2842   case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
2843     if (TypeIdx == 0) {
2844       Register VecReg = MI.getOperand(1).getReg();
2845       LLT VecTy = MRI.getType(VecReg);
2846       Observer.changingInstr(MI);
2847 
2848       widenScalarSrc(
2849           MI, LLT::vector(VecTy.getElementCount(), WideTy.getSizeInBits()), 1,
2850           TargetOpcode::G_ANYEXT);
2851 
2852       widenScalarDst(MI, WideTy, 0);
2853       Observer.changedInstr(MI);
2854       return Legalized;
2855     }
2856 
2857     if (TypeIdx != 2)
2858       return UnableToLegalize;
2859     Observer.changingInstr(MI);
2860     // TODO: Probably should be zext
2861     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
2862     Observer.changedInstr(MI);
2863     return Legalized;
2864   }
2865   case TargetOpcode::G_INSERT_VECTOR_ELT: {
2866     if (TypeIdx == 0) {
2867       Observer.changingInstr(MI);
2868       const LLT WideEltTy = WideTy.getElementType();
2869 
2870       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2871       widenScalarSrc(MI, WideEltTy, 2, TargetOpcode::G_ANYEXT);
2872       widenScalarDst(MI, WideTy, 0);
2873       Observer.changedInstr(MI);
2874       return Legalized;
2875     }
2876 
2877     if (TypeIdx == 1) {
2878       Observer.changingInstr(MI);
2879 
2880       Register VecReg = MI.getOperand(1).getReg();
2881       LLT VecTy = MRI.getType(VecReg);
2882       LLT WideVecTy = LLT::vector(VecTy.getElementCount(), WideTy);
2883 
2884       widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_ANYEXT);
2885       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
2886       widenScalarDst(MI, WideVecTy, 0);
2887       Observer.changedInstr(MI);
2888       return Legalized;
2889     }
2890 
2891     if (TypeIdx == 2) {
2892       Observer.changingInstr(MI);
2893       // TODO: Probably should be zext
2894       widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_SEXT);
2895       Observer.changedInstr(MI);
2896       return Legalized;
2897     }
2898 
2899     return UnableToLegalize;
2900   }
2901   case TargetOpcode::G_FADD:
2902   case TargetOpcode::G_FMUL:
2903   case TargetOpcode::G_FSUB:
2904   case TargetOpcode::G_FMA:
2905   case TargetOpcode::G_FMAD:
2906   case TargetOpcode::G_FNEG:
2907   case TargetOpcode::G_FABS:
2908   case TargetOpcode::G_FCANONICALIZE:
2909   case TargetOpcode::G_FMINNUM:
2910   case TargetOpcode::G_FMAXNUM:
2911   case TargetOpcode::G_FMINNUM_IEEE:
2912   case TargetOpcode::G_FMAXNUM_IEEE:
2913   case TargetOpcode::G_FMINIMUM:
2914   case TargetOpcode::G_FMAXIMUM:
2915   case TargetOpcode::G_FDIV:
2916   case TargetOpcode::G_FREM:
2917   case TargetOpcode::G_FCEIL:
2918   case TargetOpcode::G_FFLOOR:
2919   case TargetOpcode::G_FCOS:
2920   case TargetOpcode::G_FSIN:
2921   case TargetOpcode::G_FTAN:
2922   case TargetOpcode::G_FACOS:
2923   case TargetOpcode::G_FASIN:
2924   case TargetOpcode::G_FATAN:
2925   case TargetOpcode::G_FCOSH:
2926   case TargetOpcode::G_FSINH:
2927   case TargetOpcode::G_FTANH:
2928   case TargetOpcode::G_FLOG10:
2929   case TargetOpcode::G_FLOG:
2930   case TargetOpcode::G_FLOG2:
2931   case TargetOpcode::G_FRINT:
2932   case TargetOpcode::G_FNEARBYINT:
2933   case TargetOpcode::G_FSQRT:
2934   case TargetOpcode::G_FEXP:
2935   case TargetOpcode::G_FEXP2:
2936   case TargetOpcode::G_FEXP10:
2937   case TargetOpcode::G_FPOW:
2938   case TargetOpcode::G_INTRINSIC_TRUNC:
2939   case TargetOpcode::G_INTRINSIC_ROUND:
2940   case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
2941     assert(TypeIdx == 0);
2942     Observer.changingInstr(MI);
2943 
2944     for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I)
2945       widenScalarSrc(MI, WideTy, I, TargetOpcode::G_FPEXT);
2946 
2947     widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
2948     Observer.changedInstr(MI);
2949     return Legalized;
2950   case TargetOpcode::G_FPOWI:
2951   case TargetOpcode::G_FLDEXP:
2952   case TargetOpcode::G_STRICT_FLDEXP: {
2953     if (TypeIdx == 0) {
2954       if (MI.getOpcode() == TargetOpcode::G_STRICT_FLDEXP)
2955         return UnableToLegalize;
2956 
2957       Observer.changingInstr(MI);
2958       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_FPEXT);
2959       widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
2960       Observer.changedInstr(MI);
2961       return Legalized;
2962     }
2963 
2964     if (TypeIdx == 1) {
2965       // For some reason SelectionDAG tries to promote to a libcall without
2966       // actually changing the integer type for promotion.
2967       Observer.changingInstr(MI);
2968       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
2969       Observer.changedInstr(MI);
2970       return Legalized;
2971     }
2972 
2973     return UnableToLegalize;
2974   }
2975   case TargetOpcode::G_FFREXP: {
2976     Observer.changingInstr(MI);
2977 
2978     if (TypeIdx == 0) {
2979       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_FPEXT);
2980       widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
2981     } else {
2982       widenScalarDst(MI, WideTy, 1);
2983     }
2984 
2985     Observer.changedInstr(MI);
2986     return Legalized;
2987   }
2988   case TargetOpcode::G_INTTOPTR:
2989     if (TypeIdx != 1)
2990       return UnableToLegalize;
2991 
2992     Observer.changingInstr(MI);
2993     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ZEXT);
2994     Observer.changedInstr(MI);
2995     return Legalized;
2996   case TargetOpcode::G_PTRTOINT:
2997     if (TypeIdx != 0)
2998       return UnableToLegalize;
2999 
3000     Observer.changingInstr(MI);
3001     widenScalarDst(MI, WideTy, 0);
3002     Observer.changedInstr(MI);
3003     return Legalized;
3004   case TargetOpcode::G_BUILD_VECTOR: {
3005     Observer.changingInstr(MI);
3006 
3007     const LLT WideEltTy = TypeIdx == 1 ? WideTy : WideTy.getElementType();
3008     for (int I = 1, E = MI.getNumOperands(); I != E; ++I)
3009       widenScalarSrc(MI, WideEltTy, I, TargetOpcode::G_ANYEXT);
3010 
3011     // Avoid changing the result vector type if the source element type was
3012     // requested.
3013     if (TypeIdx == 1) {
3014       MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BUILD_VECTOR_TRUNC));
3015     } else {
3016       widenScalarDst(MI, WideTy, 0);
3017     }
3018 
3019     Observer.changedInstr(MI);
3020     return Legalized;
3021   }
3022   case TargetOpcode::G_SEXT_INREG:
3023     if (TypeIdx != 0)
3024       return UnableToLegalize;
3025 
3026     Observer.changingInstr(MI);
3027     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
3028     widenScalarDst(MI, WideTy, 0, TargetOpcode::G_TRUNC);
3029     Observer.changedInstr(MI);
3030     return Legalized;
3031   case TargetOpcode::G_PTRMASK: {
3032     if (TypeIdx != 1)
3033       return UnableToLegalize;
3034     Observer.changingInstr(MI);
3035     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
3036     Observer.changedInstr(MI);
3037     return Legalized;
3038   }
3039   case TargetOpcode::G_VECREDUCE_FADD:
3040   case TargetOpcode::G_VECREDUCE_FMUL:
3041   case TargetOpcode::G_VECREDUCE_FMIN:
3042   case TargetOpcode::G_VECREDUCE_FMAX:
3043   case TargetOpcode::G_VECREDUCE_FMINIMUM:
3044   case TargetOpcode::G_VECREDUCE_FMAXIMUM: {
3045     if (TypeIdx != 0)
3046       return UnableToLegalize;
3047     Observer.changingInstr(MI);
3048     Register VecReg = MI.getOperand(1).getReg();
3049     LLT VecTy = MRI.getType(VecReg);
3050     LLT WideVecTy = VecTy.isVector()
3051                         ? LLT::vector(VecTy.getElementCount(), WideTy)
3052                         : WideTy;
3053     widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_FPEXT);
3054     widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
3055     Observer.changedInstr(MI);
3056     return Legalized;
3057   }
3058   case TargetOpcode::G_VSCALE: {
3059     MachineOperand &SrcMO = MI.getOperand(1);
3060     LLVMContext &Ctx = MIRBuilder.getMF().getFunction().getContext();
3061     const APInt &SrcVal = SrcMO.getCImm()->getValue();
3062     // The CImm is always a signed value.
3063     const APInt Val = SrcVal.sext(WideTy.getSizeInBits());
3064     Observer.changingInstr(MI);
3065     SrcMO.setCImm(ConstantInt::get(Ctx, Val));
3066     widenScalarDst(MI, WideTy);
3067     Observer.changedInstr(MI);
3068     return Legalized;
3069   }
3070   case TargetOpcode::G_SPLAT_VECTOR: {
3071     if (TypeIdx != 1)
3072       return UnableToLegalize;
3073 
3074     Observer.changingInstr(MI);
3075     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
3076     Observer.changedInstr(MI);
3077     return Legalized;
3078   }
3079   }
3080 }
3081 
3082 static void getUnmergePieces(SmallVectorImpl<Register> &Pieces,
3083                              MachineIRBuilder &B, Register Src, LLT Ty) {
3084   auto Unmerge = B.buildUnmerge(Ty, Src);
3085   for (int I = 0, E = Unmerge->getNumOperands() - 1; I != E; ++I)
3086     Pieces.push_back(Unmerge.getReg(I));
3087 }
3088 
3089 static void emitLoadFromConstantPool(Register DstReg, const Constant *ConstVal,
3090                                      MachineIRBuilder &MIRBuilder) {
3091   MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
3092   MachineFunction &MF = MIRBuilder.getMF();
3093   const DataLayout &DL = MIRBuilder.getDataLayout();
3094   unsigned AddrSpace = DL.getDefaultGlobalsAddressSpace();
3095   LLT AddrPtrTy = LLT::pointer(AddrSpace, DL.getPointerSizeInBits(AddrSpace));
3096   LLT DstLLT = MRI.getType(DstReg);
3097 
3098   Align Alignment(DL.getABITypeAlign(ConstVal->getType()));
3099 
3100   auto Addr = MIRBuilder.buildConstantPool(
3101       AddrPtrTy,
3102       MF.getConstantPool()->getConstantPoolIndex(ConstVal, Alignment));
3103 
3104   MachineMemOperand *MMO =
3105       MF.getMachineMemOperand(MachinePointerInfo::getConstantPool(MF),
3106                               MachineMemOperand::MOLoad, DstLLT, Alignment);
3107 
3108   MIRBuilder.buildLoadInstr(TargetOpcode::G_LOAD, DstReg, Addr, *MMO);
3109 }
3110 
3111 LegalizerHelper::LegalizeResult
3112 LegalizerHelper::lowerConstant(MachineInstr &MI) {
3113   const MachineOperand &ConstOperand = MI.getOperand(1);
3114   const Constant *ConstantVal = ConstOperand.getCImm();
3115 
3116   emitLoadFromConstantPool(MI.getOperand(0).getReg(), ConstantVal, MIRBuilder);
3117   MI.eraseFromParent();
3118 
3119   return Legalized;
3120 }
3121 
3122 LegalizerHelper::LegalizeResult
3123 LegalizerHelper::lowerFConstant(MachineInstr &MI) {
3124   const MachineOperand &ConstOperand = MI.getOperand(1);
3125   const Constant *ConstantVal = ConstOperand.getFPImm();
3126 
3127   emitLoadFromConstantPool(MI.getOperand(0).getReg(), ConstantVal, MIRBuilder);
3128   MI.eraseFromParent();
3129 
3130   return Legalized;
3131 }
3132 
3133 LegalizerHelper::LegalizeResult
3134 LegalizerHelper::lowerBitcast(MachineInstr &MI) {
3135   auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
3136   if (SrcTy.isVector()) {
3137     LLT SrcEltTy = SrcTy.getElementType();
3138     SmallVector<Register, 8> SrcRegs;
3139 
3140     if (DstTy.isVector()) {
3141       int NumDstElt = DstTy.getNumElements();
3142       int NumSrcElt = SrcTy.getNumElements();
3143 
3144       LLT DstEltTy = DstTy.getElementType();
3145       LLT DstCastTy = DstEltTy; // Intermediate bitcast result type
3146       LLT SrcPartTy = SrcEltTy; // Original unmerge result type.
3147 
3148       // If there's an element size mismatch, insert intermediate casts to match
3149       // the result element type.
3150       if (NumSrcElt < NumDstElt) { // Source element type is larger.
3151         // %1:_(<4 x s8>) = G_BITCAST %0:_(<2 x s16>)
3152         //
3153         // =>
3154         //
3155         // %2:_(s16), %3:_(s16) = G_UNMERGE_VALUES %0
3156         // %4:_(<2 x s8>) = G_BITCAST %2
3157         // %5:_(<2 x s8>) = G_BITCAST %3
3158         // %1:_(<4 x s8>) = G_CONCAT_VECTORS %4, %5
3159         DstCastTy = LLT::fixed_vector(NumDstElt / NumSrcElt, DstEltTy);
3160         SrcPartTy = SrcEltTy;
3161       } else if (NumSrcElt > NumDstElt) { // Source element type is smaller.
3162         //
3163         // %1:_(<2 x s16>) = G_BITCAST %0:_(<4 x s8>)
3164         //
3165         // =>
3166         //
3167         // %2:_(<2 x s8>), %3:_(<2 x s8>) = G_UNMERGE_VALUES %0
3168         // %4:_(s16) = G_BITCAST %2
3169         // %5:_(s16) = G_BITCAST %3
3170         // %1:_(<2 x s16>) = G_BUILD_VECTOR %4, %5
3171         SrcPartTy = LLT::fixed_vector(NumSrcElt / NumDstElt, SrcEltTy);
3172         DstCastTy = DstEltTy;
3173       }
3174 
3175       getUnmergePieces(SrcRegs, MIRBuilder, Src, SrcPartTy);
3176       for (Register &SrcReg : SrcRegs)
3177         SrcReg = MIRBuilder.buildBitcast(DstCastTy, SrcReg).getReg(0);
3178     } else
3179       getUnmergePieces(SrcRegs, MIRBuilder, Src, SrcEltTy);
3180 
3181     MIRBuilder.buildMergeLikeInstr(Dst, SrcRegs);
3182     MI.eraseFromParent();
3183     return Legalized;
3184   }
3185 
3186   if (DstTy.isVector()) {
3187     SmallVector<Register, 8> SrcRegs;
3188     getUnmergePieces(SrcRegs, MIRBuilder, Src, DstTy.getElementType());
3189     MIRBuilder.buildMergeLikeInstr(Dst, SrcRegs);
3190     MI.eraseFromParent();
3191     return Legalized;
3192   }
3193 
3194   return UnableToLegalize;
3195 }
3196 
3197 /// Figure out the bit offset into a register when coercing a vector index for
3198 /// the wide element type. This is only for the case when promoting a vector
3199 /// to one with larger elements.
3200 ///
3201 ///
3202 /// %offset_idx = G_AND %idx, ~(-1 << Log2(DstEltSize / SrcEltSize))
3203 /// %offset_bits = G_SHL %offset_idx, Log2(SrcEltSize)
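///
/// For example (illustrative), with s8 elements viewed inside s32 words
/// (NewEltSize = 32, OldEltSize = 8) and %idx = 6:
///   %offset_idx  = 6 & 3  = 2
///   %offset_bits = 2 << 3 = 16, i.e. the third byte within the word.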
3204 static Register getBitcastWiderVectorElementOffset(MachineIRBuilder &B,
3205                                                    Register Idx,
3206                                                    unsigned NewEltSize,
3207                                                    unsigned OldEltSize) {
3208   const unsigned Log2EltRatio = Log2_32(NewEltSize / OldEltSize);
3209   LLT IdxTy = B.getMRI()->getType(Idx);
3210 
3211   // Now figure out the amount we need to shift to get the target bits.
3212   auto OffsetMask = B.buildConstant(
3213       IdxTy, ~(APInt::getAllOnes(IdxTy.getSizeInBits()) << Log2EltRatio));
3214   auto OffsetIdx = B.buildAnd(IdxTy, Idx, OffsetMask);
3215   return B.buildShl(IdxTy, OffsetIdx,
3216                     B.buildConstant(IdxTy, Log2_32(OldEltSize))).getReg(0);
3217 }
3218 
3219 /// Perform a G_EXTRACT_VECTOR_ELT in a different sized vector element. If this
3220 /// is casting to a vector with a smaller element size, perform multiple element
3221 /// extracts and merge the results. If this is coercing to a vector with larger
3222 /// elements, index the bitcasted vector and extract the target element with bit
3223 /// operations. This is intended to force the indexing in the native register
3224 /// size for architectures that can dynamically index the register file.
3225 LegalizerHelper::LegalizeResult
3226 LegalizerHelper::bitcastExtractVectorElt(MachineInstr &MI, unsigned TypeIdx,
3227                                          LLT CastTy) {
3228   if (TypeIdx != 1)
3229     return UnableToLegalize;
3230 
3231   auto [Dst, DstTy, SrcVec, SrcVecTy, Idx, IdxTy] = MI.getFirst3RegLLTs();
3232 
3233   LLT SrcEltTy = SrcVecTy.getElementType();
3234   unsigned NewNumElts = CastTy.isVector() ? CastTy.getNumElements() : 1;
3235   unsigned OldNumElts = SrcVecTy.getNumElements();
3236 
3237   LLT NewEltTy = CastTy.isVector() ? CastTy.getElementType() : CastTy;
3238   Register CastVec = MIRBuilder.buildBitcast(CastTy, SrcVec).getReg(0);
3239 
3240   const unsigned NewEltSize = NewEltTy.getSizeInBits();
3241   const unsigned OldEltSize = SrcEltTy.getSizeInBits();
3242   if (NewNumElts > OldNumElts) {
3243     // Decreasing the vector element size
3244     //
3245     // e.g. i64 = extract_vector_elt x:v2i64, y:i32
3246     //  =>
3247     //  v4i32:castx = bitcast x:v2i64
3248     //
3249     // i64 = bitcast
3250     //   (v2i32 build_vector (i32 (extract_vector_elt castx, (2 * y))),
3251     //                       (i32 (extract_vector_elt castx, (2 * y + 1))))
3252     //
3253     if (NewNumElts % OldNumElts != 0)
3254       return UnableToLegalize;
3255 
3256     // Type of the intermediate result vector.
3257     const unsigned NewEltsPerOldElt = NewNumElts / OldNumElts;
3258     LLT MidTy =
3259         LLT::scalarOrVector(ElementCount::getFixed(NewEltsPerOldElt), NewEltTy);
3260 
3261     auto NewEltsPerOldEltK = MIRBuilder.buildConstant(IdxTy, NewEltsPerOldElt);
3262 
3263     SmallVector<Register, 8> NewOps(NewEltsPerOldElt);
3264     auto NewBaseIdx = MIRBuilder.buildMul(IdxTy, Idx, NewEltsPerOldEltK);
3265 
3266     for (unsigned I = 0; I < NewEltsPerOldElt; ++I) {
3267       auto IdxOffset = MIRBuilder.buildConstant(IdxTy, I);
3268       auto TmpIdx = MIRBuilder.buildAdd(IdxTy, NewBaseIdx, IdxOffset);
3269       auto Elt = MIRBuilder.buildExtractVectorElement(NewEltTy, CastVec, TmpIdx);
3270       NewOps[I] = Elt.getReg(0);
3271     }
3272 
3273     auto NewVec = MIRBuilder.buildBuildVector(MidTy, NewOps);
3274     MIRBuilder.buildBitcast(Dst, NewVec);
3275     MI.eraseFromParent();
3276     return Legalized;
3277   }
3278 
3279   if (NewNumElts < OldNumElts) {
3280     if (NewEltSize % OldEltSize != 0)
3281       return UnableToLegalize;
3282 
3283     // This only depends on powers of 2 because we use bit tricks to figure out
3284     // the bit offset we need to shift to get the target element. A general
3285     // expansion could emit division/multiply.
3286     if (!isPowerOf2_32(NewEltSize / OldEltSize))
3287       return UnableToLegalize;
3288 
3289     // Increasing the vector element size.
3290     // %elt:_(small_elt) = G_EXTRACT_VECTOR_ELT %vec:_(<N x small_elt>), %idx
3291     //
3292     //   =>
3293     //
3294     // %cast = G_BITCAST %vec
3295     // %scaled_idx = G_LSHR %idx, Log2(DstEltSize / SrcEltSize)
3296     // %wide_elt  = G_EXTRACT_VECTOR_ELT %cast, %scaled_idx
3297     // %offset_idx = G_AND %idx, ~(-1 << Log2(DstEltSize / SrcEltSize))
3298     // %offset_bits = G_SHL %offset_idx, Log2(SrcEltSize)
3299     // %elt_bits = G_LSHR %wide_elt, %offset_bits
3300     // %elt = G_TRUNC %elt_bits
3301 
3302     const unsigned Log2EltRatio = Log2_32(NewEltSize / OldEltSize);
3303     auto Log2Ratio = MIRBuilder.buildConstant(IdxTy, Log2EltRatio);
3304 
3305     // Divide to get the index in the wider element type.
3306     auto ScaledIdx = MIRBuilder.buildLShr(IdxTy, Idx, Log2Ratio);
3307 
3308     Register WideElt = CastVec;
3309     if (CastTy.isVector()) {
3310       WideElt = MIRBuilder.buildExtractVectorElement(NewEltTy, CastVec,
3311                                                      ScaledIdx).getReg(0);
3312     }
3313 
3314     // Compute the bit offset into the register of the target element.
3315     Register OffsetBits = getBitcastWiderVectorElementOffset(
3316       MIRBuilder, Idx, NewEltSize, OldEltSize);
3317 
3318     // Shift the wide element to get the target element.
3319     auto ExtractedBits = MIRBuilder.buildLShr(NewEltTy, WideElt, OffsetBits);
3320     MIRBuilder.buildTrunc(Dst, ExtractedBits);
3321     MI.eraseFromParent();
3322     return Legalized;
3323   }
3324 
3325   return UnableToLegalize;
3326 }
3327 
3328 /// Emit code to insert \p InsertReg into \p TargetReg at \p OffsetBits,
3329 /// while preserving the other bits in \p TargetReg.
3330 ///
3331 /// (InsertReg << Offset) | (TargetReg & ~(((1 << InsertReg.size()) - 1) << Offset))
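///
/// For example (illustrative): inserting an s8 value 0xAB into an s32
/// TargetReg of 0x12345678 at Offset = 8 gives
///   (0xAB << 8) | (0x12345678 & ~0x0000FF00) = 0x1234AB78.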
3332 static Register buildBitFieldInsert(MachineIRBuilder &B,
3333                                     Register TargetReg, Register InsertReg,
3334                                     Register OffsetBits) {
3335   LLT TargetTy = B.getMRI()->getType(TargetReg);
3336   LLT InsertTy = B.getMRI()->getType(InsertReg);
3337   auto ZextVal = B.buildZExt(TargetTy, InsertReg);
3338   auto ShiftedInsertVal = B.buildShl(TargetTy, ZextVal, OffsetBits);
3339 
3340   // Produce a bitmask of the value to insert
3341   auto EltMask = B.buildConstant(
3342     TargetTy, APInt::getLowBitsSet(TargetTy.getSizeInBits(),
3343                                    InsertTy.getSizeInBits()));
3344   // Shift it into position
3345   auto ShiftedMask = B.buildShl(TargetTy, EltMask, OffsetBits);
3346   auto InvShiftedMask = B.buildNot(TargetTy, ShiftedMask);
3347 
3348   // Clear out the bits in the wide element
3349   auto MaskedOldElt = B.buildAnd(TargetTy, TargetReg, InvShiftedMask);
3350 
3351   // The shifted insert value already has zeros everywhere else thanks to the
3352   // zext, so OR it into the masked wide element.
3353   return B.buildOr(TargetTy, MaskedOldElt, ShiftedInsertVal).getReg(0);
3354 }
3355 
3356 /// Perform a G_INSERT_VECTOR_ELT in a different sized vector element. If this
3357 /// is increasing the element size, perform the indexing in the target element
3358 /// type, and use bit operations to insert at the element position. This is
3359 /// intended for architectures that can dynamically index the register file and
3360 /// want to force indexing in the native register size.
3361 LegalizerHelper::LegalizeResult
3362 LegalizerHelper::bitcastInsertVectorElt(MachineInstr &MI, unsigned TypeIdx,
3363                                         LLT CastTy) {
3364   if (TypeIdx != 0)
3365     return UnableToLegalize;
3366 
3367   auto [Dst, DstTy, SrcVec, SrcVecTy, Val, ValTy, Idx, IdxTy] =
3368       MI.getFirst4RegLLTs();
3369   LLT VecTy = DstTy;
3370 
3371   LLT VecEltTy = VecTy.getElementType();
3372   LLT NewEltTy = CastTy.isVector() ? CastTy.getElementType() : CastTy;
3373   const unsigned NewEltSize = NewEltTy.getSizeInBits();
3374   const unsigned OldEltSize = VecEltTy.getSizeInBits();
3375 
3376   unsigned NewNumElts = CastTy.isVector() ? CastTy.getNumElements() : 1;
3377   unsigned OldNumElts = VecTy.getNumElements();
3378 
3379   Register CastVec = MIRBuilder.buildBitcast(CastTy, SrcVec).getReg(0);
3380   if (NewNumElts < OldNumElts) {
3381     if (NewEltSize % OldEltSize != 0)
3382       return UnableToLegalize;
3383 
3384     // This only depends on powers of 2 because we use bit tricks to figure out
3385     // the bit offset we need to shift to get the target element. A general
3386     // expansion could emit division/multiply.
3387     if (!isPowerOf2_32(NewEltSize / OldEltSize))
3388       return UnableToLegalize;
3389 
3390     const unsigned Log2EltRatio = Log2_32(NewEltSize / OldEltSize);
3391     auto Log2Ratio = MIRBuilder.buildConstant(IdxTy, Log2EltRatio);
3392 
3393     // Divide to get the index in the wider element type.
3394     auto ScaledIdx = MIRBuilder.buildLShr(IdxTy, Idx, Log2Ratio);
3395 
3396     Register ExtractedElt = CastVec;
3397     if (CastTy.isVector()) {
3398       ExtractedElt = MIRBuilder.buildExtractVectorElement(NewEltTy, CastVec,
3399                                                           ScaledIdx).getReg(0);
3400     }
3401 
3402     // Compute the bit offset into the register of the target element.
3403     Register OffsetBits = getBitcastWiderVectorElementOffset(
3404       MIRBuilder, Idx, NewEltSize, OldEltSize);
3405 
3406     Register InsertedElt = buildBitFieldInsert(MIRBuilder, ExtractedElt,
3407                                                Val, OffsetBits);
3408     if (CastTy.isVector()) {
3409       InsertedElt = MIRBuilder.buildInsertVectorElement(
3410         CastTy, CastVec, InsertedElt, ScaledIdx).getReg(0);
3411     }
3412 
3413     MIRBuilder.buildBitcast(Dst, InsertedElt);
3414     MI.eraseFromParent();
3415     return Legalized;
3416   }
3417 
3418   return UnableToLegalize;
3419 }
3420 
3421 // This attempts to handle G_CONCAT_VECTORS with illegal operands, particularly
3422 // those that have smaller than legal operands.
3423 //
3424 // <16 x s8> = G_CONCAT_VECTORS <4 x s8>, <4 x s8>, <4 x s8>, <4 x s8>
3425 //
3426 // ===>
3427 //
3428 // s32 = G_BITCAST <4 x s8>
3429 // s32 = G_BITCAST <4 x s8>
3430 // s32 = G_BITCAST <4 x s8>
3431 // s32 = G_BITCAST <4 x s8>
3432 // <4 x s32> = G_BUILD_VECTOR s32, s32, s32, s32
3433 // <16 x s8> = G_BITCAST <4 x s32>
3434 LegalizerHelper::LegalizeResult
3435 LegalizerHelper::bitcastConcatVector(MachineInstr &MI, unsigned TypeIdx,
3436                                      LLT CastTy) {
3437   // Only G_CONCAT_VECTORS is handled here; bail out on anything else.
3438   auto ConcatMI = dyn_cast<GConcatVectors>(&MI);
3439   if (!ConcatMI) {
3440     return UnableToLegalize;
3441   }
3442 
3443   // Compute the scalar type each source vector will be bitcast to.
3444   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
3445   LLT SrcScalTy = LLT::scalar(SrcTy.getSizeInBits());
3446 
3447   // Check if the build vector is Legal
3448   if (!LI.isLegal({TargetOpcode::G_BUILD_VECTOR, {CastTy, SrcScalTy}})) {
3449     return UnableToLegalize;
3450   }
3451 
3452   // Bitcast the sources
3453   SmallVector<Register> BitcastRegs;
3454   for (unsigned i = 0; i < ConcatMI->getNumSources(); i++) {
3455     BitcastRegs.push_back(
3456         MIRBuilder.buildBitcast(SrcScalTy, ConcatMI->getSourceReg(i))
3457             .getReg(0));
3458   }
3459 
3460   // Build the scalar values into a vector
3461   Register BuildReg =
3462       MIRBuilder.buildBuildVector(CastTy, BitcastRegs).getReg(0);
3463   MIRBuilder.buildBitcast(DstReg, BuildReg);
3464 
3465   MI.eraseFromParent();
3466   return Legalized;
3467 }
3468 
3469 LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
3470   // Lower to a memory-width G_LOAD and a G_SEXT/G_ZEXT/G_ANYEXT
3471   Register DstReg = LoadMI.getDstReg();
3472   Register PtrReg = LoadMI.getPointerReg();
3473   LLT DstTy = MRI.getType(DstReg);
3474   MachineMemOperand &MMO = LoadMI.getMMO();
3475   LLT MemTy = MMO.getMemoryType();
3476   MachineFunction &MF = MIRBuilder.getMF();
3477 
3478   unsigned MemSizeInBits = MemTy.getSizeInBits();
3479   unsigned MemStoreSizeInBits = 8 * MemTy.getSizeInBytes();
3480 
3481   if (MemSizeInBits != MemStoreSizeInBits) {
3482     if (MemTy.isVector())
3483       return UnableToLegalize;
3484 
3485     // Promote to a byte-sized load if not loading an integral number of
3486     // bytes.  For example, promote EXTLOAD:i20 -> EXTLOAD:i24.
3487     LLT WideMemTy = LLT::scalar(MemStoreSizeInBits);
3488     MachineMemOperand *NewMMO =
3489         MF.getMachineMemOperand(&MMO, MMO.getPointerInfo(), WideMemTy);
3490 
3491     Register LoadReg = DstReg;
3492     LLT LoadTy = DstTy;
3493 
3494     // If this wasn't already an extending load, we need to widen the result
3495     // register to avoid creating a load with a narrower result than the source.
3496     if (MemStoreSizeInBits > DstTy.getSizeInBits()) {
3497       LoadTy = WideMemTy;
3498       LoadReg = MRI.createGenericVirtualRegister(WideMemTy);
3499     }
3500 
3501     if (isa<GSExtLoad>(LoadMI)) {
3502       auto NewLoad = MIRBuilder.buildLoad(LoadTy, PtrReg, *NewMMO);
3503       MIRBuilder.buildSExtInReg(LoadReg, NewLoad, MemSizeInBits);
3504     } else if (isa<GZExtLoad>(LoadMI) || WideMemTy == LoadTy) {
3505       auto NewLoad = MIRBuilder.buildLoad(LoadTy, PtrReg, *NewMMO);
3506       // The extra bits are guaranteed to be zero, since we stored them that
3507       // way.  A zext load from Wide thus automatically gives zext from MemVT.
3508       MIRBuilder.buildAssertZExt(LoadReg, NewLoad, MemSizeInBits);
3509     } else {
3510       MIRBuilder.buildLoad(LoadReg, PtrReg, *NewMMO);
3511     }
3512 
3513     if (DstTy != LoadTy)
3514       MIRBuilder.buildTrunc(DstReg, LoadReg);
3515 
3516     LoadMI.eraseFromParent();
3517     return Legalized;
3518   }
3519 
3520   // Big endian lowering not implemented.
3521   if (MIRBuilder.getDataLayout().isBigEndian())
3522     return UnableToLegalize;
3523 
3524   // This load needs splitting into power of 2 sized loads.
3525   //
3526   // Our strategy here is to generate anyextending loads for the smaller
3527   // pieces at the next power-of-2 result type, then combine the two
3528   // results together before truncating back down to the non-pow-2
3529   // type.
3530   // E.g. v1 = i24 load =>
3531   // v2 = i32 zextload (2 byte)
3532   // v3 = i32 load (1 byte)
3533   // v4 = i32 shl v3, 16
3534   // v5 = i32 or v4, v2
3535   // v1 = i24 trunc v5
3536   // By doing this we generate the correct truncate which should get
3537   // combined away as an artifact with a matching extend.
3538 
3539   uint64_t LargeSplitSize, SmallSplitSize;
3540 
3541   if (!isPowerOf2_32(MemSizeInBits)) {
3542     // This load needs splitting into power of 2 sized loads.
3543     LargeSplitSize = llvm::bit_floor(MemSizeInBits);
3544     SmallSplitSize = MemSizeInBits - LargeSplitSize;
3545   } else {
3546     // This is already a power of 2, but we still need to split this in half.
3547     //
3548     // Assume we're being asked to decompose an unaligned load.
3549     // TODO: If this requires multiple splits, handle them all at once.
3550     auto &Ctx = MF.getFunction().getContext();
3551     if (TLI.allowsMemoryAccess(Ctx, MIRBuilder.getDataLayout(), MemTy, MMO))
3552       return UnableToLegalize;
3553 
3554     SmallSplitSize = LargeSplitSize = MemSizeInBits / 2;
3555   }
3556 
3557   if (MemTy.isVector()) {
3558     // TODO: Handle vector extloads
3559     if (MemTy != DstTy)
3560       return UnableToLegalize;
3561 
3562     // TODO: We can do better than scalarizing the vector and at least split it
3563     // in half.
3564     return reduceLoadStoreWidth(LoadMI, 0, DstTy.getElementType());
3565   }
3566 
3567   MachineMemOperand *LargeMMO =
3568       MF.getMachineMemOperand(&MMO, 0, LargeSplitSize / 8);
3569   MachineMemOperand *SmallMMO =
3570       MF.getMachineMemOperand(&MMO, LargeSplitSize / 8, SmallSplitSize / 8);
3571 
3572   LLT PtrTy = MRI.getType(PtrReg);
3573   unsigned AnyExtSize = PowerOf2Ceil(DstTy.getSizeInBits());
3574   LLT AnyExtTy = LLT::scalar(AnyExtSize);
3575   auto LargeLoad = MIRBuilder.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, AnyExtTy,
3576                                              PtrReg, *LargeMMO);
3577 
3578   auto OffsetCst = MIRBuilder.buildConstant(LLT::scalar(PtrTy.getSizeInBits()),
3579                                             LargeSplitSize / 8);
3580   Register PtrAddReg = MRI.createGenericVirtualRegister(PtrTy);
3581   auto SmallPtr = MIRBuilder.buildPtrAdd(PtrAddReg, PtrReg, OffsetCst);
3582   auto SmallLoad = MIRBuilder.buildLoadInstr(LoadMI.getOpcode(), AnyExtTy,
3583                                              SmallPtr, *SmallMMO);
3584 
3585   auto ShiftAmt = MIRBuilder.buildConstant(AnyExtTy, LargeSplitSize);
3586   auto Shift = MIRBuilder.buildShl(AnyExtTy, SmallLoad, ShiftAmt);
3587 
3588   if (AnyExtTy == DstTy)
3589     MIRBuilder.buildOr(DstReg, Shift, LargeLoad);
3590   else if (AnyExtTy.getSizeInBits() != DstTy.getSizeInBits()) {
3591     auto Or = MIRBuilder.buildOr(AnyExtTy, Shift, LargeLoad);
3592     MIRBuilder.buildTrunc(DstReg, {Or});
3593   } else {
3594     assert(DstTy.isPointer() && "expected pointer");
3595     auto Or = MIRBuilder.buildOr(AnyExtTy, Shift, LargeLoad);
3596 
3597     // FIXME: We currently consider this to be illegal for non-integral address
3598     // spaces, but we still need a way to reinterpret the bits.
3599     MIRBuilder.buildIntToPtr(DstReg, Or);
3600   }
3601 
3602   LoadMI.eraseFromParent();
3603   return Legalized;
3604 }
3605 
3606 LegalizerHelper::LegalizeResult LegalizerHelper::lowerStore(GStore &StoreMI) {
3607   // Lower a non-power of 2 store into multiple pow-2 stores.
3608   // E.g. split an i24 store into an i16 store + i8 store.
3609   // We do this by first extending the stored value to the next largest power
3610   // of 2 type, and then using truncating stores to store the components.
3611   // By doing this, as with G_LOAD, we generate an extend that can be
3612   // artifact-combined away instead of leaving behind extracts.
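  // Sketch (illustrative) for a little-endian s24 store:
  //   %ext:_(s32) = G_ANYEXT %val:_(s24)
  //   G_STORE %ext(s32), %ptr              ; 16-bit truncating store
  //   %hi:_(s32)  = G_LSHR %ext, 16
  //   %ptr2:_(p0) = G_PTR_ADD %ptr, 2
  //   G_STORE %hi(s32), %ptr2              ; 8-bit truncating store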
3613   Register SrcReg = StoreMI.getValueReg();
3614   Register PtrReg = StoreMI.getPointerReg();
3615   LLT SrcTy = MRI.getType(SrcReg);
3616   MachineFunction &MF = MIRBuilder.getMF();
3617   MachineMemOperand &MMO = **StoreMI.memoperands_begin();
3618   LLT MemTy = MMO.getMemoryType();
3619 
3620   unsigned StoreWidth = MemTy.getSizeInBits();
3621   unsigned StoreSizeInBits = 8 * MemTy.getSizeInBytes();
3622 
3623   if (StoreWidth != StoreSizeInBits) {
3624     if (SrcTy.isVector())
3625       return UnableToLegalize;
3626 
3627     // Promote to a byte-sized store with upper bits zero if not
3628     // storing an integral number of bytes.  For example, promote
3629     // TRUNCSTORE:i1 X -> TRUNCSTORE:i8 (and X, 1)
3630     LLT WideTy = LLT::scalar(StoreSizeInBits);
3631 
3632     if (StoreSizeInBits > SrcTy.getSizeInBits()) {
3633       // Avoid creating a store with a narrower source than result.
3634       SrcReg = MIRBuilder.buildAnyExt(WideTy, SrcReg).getReg(0);
3635       SrcTy = WideTy;
3636     }
3637 
3638     auto ZextInReg = MIRBuilder.buildZExtInReg(SrcTy, SrcReg, StoreWidth);
3639 
3640     MachineMemOperand *NewMMO =
3641         MF.getMachineMemOperand(&MMO, MMO.getPointerInfo(), WideTy);
3642     MIRBuilder.buildStore(ZextInReg, PtrReg, *NewMMO);
3643     StoreMI.eraseFromParent();
3644     return Legalized;
3645   }
3646 
3647   if (MemTy.isVector()) {
3648     // TODO: Handle vector trunc stores
3649     if (MemTy != SrcTy)
3650       return UnableToLegalize;
3651 
3652     // TODO: We can do better than scalarizing the vector and at least split it
3653     // in half.
3654     return reduceLoadStoreWidth(StoreMI, 0, SrcTy.getElementType());
3655   }
3656 
3657   unsigned MemSizeInBits = MemTy.getSizeInBits();
3658   uint64_t LargeSplitSize, SmallSplitSize;
3659 
3660   if (!isPowerOf2_32(MemSizeInBits)) {
3661     LargeSplitSize = llvm::bit_floor<uint64_t>(MemTy.getSizeInBits());
3662     SmallSplitSize = MemTy.getSizeInBits() - LargeSplitSize;
3663   } else {
3664     auto &Ctx = MF.getFunction().getContext();
3665     if (TLI.allowsMemoryAccess(Ctx, MIRBuilder.getDataLayout(), MemTy, MMO))
3666       return UnableToLegalize; // Don't know what we're being asked to do.
3667 
3668     SmallSplitSize = LargeSplitSize = MemSizeInBits / 2;
3669   }
3670 
3671   // Extend to the next pow-2. If this store was itself the result of lowering,
3672   // e.g. an s56 store being broken into s32 + s24, we might have a stored type
3673   // that's wider than the stored size.
3674   unsigned AnyExtSize = PowerOf2Ceil(MemTy.getSizeInBits());
3675   const LLT NewSrcTy = LLT::scalar(AnyExtSize);
3676 
3677   if (SrcTy.isPointer()) {
3678     const LLT IntPtrTy = LLT::scalar(SrcTy.getSizeInBits());
3679     SrcReg = MIRBuilder.buildPtrToInt(IntPtrTy, SrcReg).getReg(0);
3680   }
3681 
3682   auto ExtVal = MIRBuilder.buildAnyExtOrTrunc(NewSrcTy, SrcReg);
3683 
3684   // Obtain the smaller value by shifting away the larger value.
3685   auto ShiftAmt = MIRBuilder.buildConstant(NewSrcTy, LargeSplitSize);
3686   auto SmallVal = MIRBuilder.buildLShr(NewSrcTy, ExtVal, ShiftAmt);
3687 
3688   // Generate the PtrAdd and truncating stores.
3689   LLT PtrTy = MRI.getType(PtrReg);
3690   auto OffsetCst = MIRBuilder.buildConstant(
3691     LLT::scalar(PtrTy.getSizeInBits()), LargeSplitSize / 8);
3692   auto SmallPtr =
3693     MIRBuilder.buildPtrAdd(PtrTy, PtrReg, OffsetCst);
3694 
3695   MachineMemOperand *LargeMMO =
3696     MF.getMachineMemOperand(&MMO, 0, LargeSplitSize / 8);
3697   MachineMemOperand *SmallMMO =
3698     MF.getMachineMemOperand(&MMO, LargeSplitSize / 8, SmallSplitSize / 8);
3699   MIRBuilder.buildStore(ExtVal, PtrReg, *LargeMMO);
3700   MIRBuilder.buildStore(SmallVal, SmallPtr, *SmallMMO);
3701   StoreMI.eraseFromParent();
3702   return Legalized;
3703 }
3704 
3705 LegalizerHelper::LegalizeResult
3706 LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) {
3707   switch (MI.getOpcode()) {
3708   case TargetOpcode::G_LOAD: {
3709     if (TypeIdx != 0)
3710       return UnableToLegalize;
3711     MachineMemOperand &MMO = **MI.memoperands_begin();
3712 
3713     // Not sure how to interpret a bitcast of an extending load.
3714     if (MMO.getMemoryType().getSizeInBits() != CastTy.getSizeInBits())
3715       return UnableToLegalize;
3716 
3717     Observer.changingInstr(MI);
3718     bitcastDst(MI, CastTy, 0);
3719     MMO.setType(CastTy);
3720     // The range metadata is no longer valid when reinterpreted as a different
3721     // type.
3722     MMO.clearRanges();
3723     Observer.changedInstr(MI);
3724     return Legalized;
3725   }
3726   case TargetOpcode::G_STORE: {
3727     if (TypeIdx != 0)
3728       return UnableToLegalize;
3729 
3730     MachineMemOperand &MMO = **MI.memoperands_begin();
3731 
3732     // Not sure how to interpret a bitcast of a truncating store.
3733     if (MMO.getMemoryType().getSizeInBits() != CastTy.getSizeInBits())
3734       return UnableToLegalize;
3735 
3736     Observer.changingInstr(MI);
3737     bitcastSrc(MI, CastTy, 0);
3738     MMO.setType(CastTy);
3739     Observer.changedInstr(MI);
3740     return Legalized;
3741   }
3742   case TargetOpcode::G_SELECT: {
3743     if (TypeIdx != 0)
3744       return UnableToLegalize;
3745 
3746     if (MRI.getType(MI.getOperand(1).getReg()).isVector()) {
3747       LLVM_DEBUG(
3748           dbgs() << "bitcast action not implemented for vector select\n");
3749       return UnableToLegalize;
3750     }
3751 
3752     Observer.changingInstr(MI);
3753     bitcastSrc(MI, CastTy, 2);
3754     bitcastSrc(MI, CastTy, 3);
3755     bitcastDst(MI, CastTy, 0);
3756     Observer.changedInstr(MI);
3757     return Legalized;
3758   }
3759   case TargetOpcode::G_AND:
3760   case TargetOpcode::G_OR:
3761   case TargetOpcode::G_XOR: {
3762     Observer.changingInstr(MI);
3763     bitcastSrc(MI, CastTy, 1);
3764     bitcastSrc(MI, CastTy, 2);
3765     bitcastDst(MI, CastTy, 0);
3766     Observer.changedInstr(MI);
3767     return Legalized;
3768   }
3769   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
3770     return bitcastExtractVectorElt(MI, TypeIdx, CastTy);
3771   case TargetOpcode::G_INSERT_VECTOR_ELT:
3772     return bitcastInsertVectorElt(MI, TypeIdx, CastTy);
3773   case TargetOpcode::G_CONCAT_VECTORS:
3774     return bitcastConcatVector(MI, TypeIdx, CastTy);
3775   default:
3776     return UnableToLegalize;
3777   }
3778 }
3779 
3780 // Legalize an instruction by changing the opcode in place.
3781 void LegalizerHelper::changeOpcode(MachineInstr &MI, unsigned NewOpcode) {
3782   Observer.changingInstr(MI);
3783   MI.setDesc(MIRBuilder.getTII().get(NewOpcode));
3784   Observer.changedInstr(MI);
3785 }
3786 
3787 LegalizerHelper::LegalizeResult
3788 LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
3789   using namespace TargetOpcode;
3790 
3791   switch(MI.getOpcode()) {
3792   default:
3793     return UnableToLegalize;
3794   case TargetOpcode::G_FCONSTANT:
3795     return lowerFConstant(MI);
3796   case TargetOpcode::G_BITCAST:
3797     return lowerBitcast(MI);
3798   case TargetOpcode::G_SREM:
3799   case TargetOpcode::G_UREM: {
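    // Lower the remainder as x - (x / y) * y; e.g. (illustrative)
    // 7 urem 3 becomes 7 - (7 / 3) * 3 = 1.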
3800     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
3801     auto Quot =
3802         MIRBuilder.buildInstr(MI.getOpcode() == G_SREM ? G_SDIV : G_UDIV, {Ty},
3803                               {MI.getOperand(1), MI.getOperand(2)});
3804 
3805     auto Prod = MIRBuilder.buildMul(Ty, Quot, MI.getOperand(2));
3806     MIRBuilder.buildSub(MI.getOperand(0), MI.getOperand(1), Prod);
3807     MI.eraseFromParent();
3808     return Legalized;
3809   }
3810   case TargetOpcode::G_SADDO:
3811   case TargetOpcode::G_SSUBO:
3812     return lowerSADDO_SSUBO(MI);
3813   case TargetOpcode::G_UMULH:
3814   case TargetOpcode::G_SMULH:
3815     return lowerSMULH_UMULH(MI);
3816   case TargetOpcode::G_SMULO:
3817   case TargetOpcode::G_UMULO: {
3818     // Generate G_UMULH/G_SMULH to check for overflow and a normal G_MUL for the
3819     // result.
3820     auto [Res, Overflow, LHS, RHS] = MI.getFirst4Regs();
3821     LLT Ty = MRI.getType(Res);
3822 
3823     unsigned Opcode = MI.getOpcode() == TargetOpcode::G_SMULO
3824                           ? TargetOpcode::G_SMULH
3825                           : TargetOpcode::G_UMULH;
3826 
3827     Observer.changingInstr(MI);
3828     const auto &TII = MIRBuilder.getTII();
3829     MI.setDesc(TII.get(TargetOpcode::G_MUL));
3830     MI.removeOperand(1);
3831     Observer.changedInstr(MI);
3832 
3833     auto HiPart = MIRBuilder.buildInstr(Opcode, {Ty}, {LHS, RHS});
3834     auto Zero = MIRBuilder.buildConstant(Ty, 0);
3835 
3836     // Move insert point forward so we can use the Res register if needed.
3837     MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
3838 
3839     // For *signed* multiply, overflow is detected by checking:
3840     // (hi != (lo >> bitwidth-1))
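    // E.g. (illustrative, s8): 64 * 2 = 128 gives lo = 0x80 and hi = 0x00,
    // while lo >> 7 (arithmetic) = 0xFF; hi differs from the sign-splat, so
    // the signed multiply overflowed.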
3841     if (Opcode == TargetOpcode::G_SMULH) {
3842       auto ShiftAmt = MIRBuilder.buildConstant(Ty, Ty.getSizeInBits() - 1);
3843       auto Shifted = MIRBuilder.buildAShr(Ty, Res, ShiftAmt);
3844       MIRBuilder.buildICmp(CmpInst::ICMP_NE, Overflow, HiPart, Shifted);
3845     } else {
3846       MIRBuilder.buildICmp(CmpInst::ICMP_NE, Overflow, HiPart, Zero);
3847     }
3848     return Legalized;
3849   }
3850   case TargetOpcode::G_FNEG: {
3851     auto [Res, SubByReg] = MI.getFirst2Regs();
3852     LLT Ty = MRI.getType(Res);
3853 
3854     // TODO: Handle vector types once we are able to
3855     // represent them.
3856     if (Ty.isVector())
3857       return UnableToLegalize;
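    // FNEG is an integer xor with the sign bit; e.g. (illustrative, s32):
    // -x = x ^ 0x80000000, flipping only the IEEE sign bit.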
3858     auto SignMask =
3859         MIRBuilder.buildConstant(Ty, APInt::getSignMask(Ty.getSizeInBits()));
3860     MIRBuilder.buildXor(Res, SubByReg, SignMask);
3861     MI.eraseFromParent();
3862     return Legalized;
3863   }
3864   case TargetOpcode::G_FSUB:
3865   case TargetOpcode::G_STRICT_FSUB: {
3866     auto [Res, LHS, RHS] = MI.getFirst3Regs();
3867     LLT Ty = MRI.getType(Res);
3868 
3869     // Lower (G_FSUB LHS, RHS) to (G_FADD LHS, (G_FNEG RHS)).
3870     auto Neg = MIRBuilder.buildFNeg(Ty, RHS);
3871 
3872     if (MI.getOpcode() == TargetOpcode::G_STRICT_FSUB)
3873       MIRBuilder.buildStrictFAdd(Res, LHS, Neg, MI.getFlags());
3874     else
3875       MIRBuilder.buildFAdd(Res, LHS, Neg, MI.getFlags());
3876 
3877     MI.eraseFromParent();
3878     return Legalized;
3879   }
3880   case TargetOpcode::G_FMAD:
3881     return lowerFMad(MI);
3882   case TargetOpcode::G_FFLOOR:
3883     return lowerFFloor(MI);
3884   case TargetOpcode::G_INTRINSIC_ROUND:
3885     return lowerIntrinsicRound(MI);
3886   case TargetOpcode::G_FRINT: {
3887     // Since round even is the assumed rounding mode for unconstrained FP
3888     // operations, rint and roundeven are the same operation.
3889     changeOpcode(MI, TargetOpcode::G_INTRINSIC_ROUNDEVEN);
3890     return Legalized;
3891   }
3892   case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
3893     auto [OldValRes, SuccessRes, Addr, CmpVal, NewVal] = MI.getFirst5Regs();
3894     Register NewOldValRes = MRI.cloneVirtualRegister(OldValRes);
3895     MIRBuilder.buildAtomicCmpXchg(NewOldValRes, Addr, CmpVal, NewVal,
3896                                   **MI.memoperands_begin());
3897     MIRBuilder.buildICmp(CmpInst::ICMP_EQ, SuccessRes, NewOldValRes, CmpVal);
3898     MIRBuilder.buildCopy(OldValRes, NewOldValRes);
3899     MI.eraseFromParent();
3900     return Legalized;
3901   }
3902   case TargetOpcode::G_LOAD:
3903   case TargetOpcode::G_SEXTLOAD:
3904   case TargetOpcode::G_ZEXTLOAD:
3905     return lowerLoad(cast<GAnyLoad>(MI));
3906   case TargetOpcode::G_STORE:
3907     return lowerStore(cast<GStore>(MI));
3908   case TargetOpcode::G_CTLZ_ZERO_UNDEF:
3909   case TargetOpcode::G_CTTZ_ZERO_UNDEF:
3910   case TargetOpcode::G_CTLZ:
3911   case TargetOpcode::G_CTTZ:
3912   case TargetOpcode::G_CTPOP:
3913     return lowerBitCount(MI);
3914   case G_UADDO: {
3915     auto [Res, CarryOut, LHS, RHS] = MI.getFirst4Regs();
3916 
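    // Carry out iff the wrapped sum is smaller than an operand; e.g.
    // (illustrative, s8): 0xFF + 0x02 = 0x01 and 0x01 < 0x02, so
    // CarryOut = 1.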
3917     Register NewRes = MRI.cloneVirtualRegister(Res);
3918 
3919     MIRBuilder.buildAdd(NewRes, LHS, RHS);
3920     MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CarryOut, NewRes, RHS);
3921 
3922     MIRBuilder.buildCopy(Res, NewRes);
3923 
3924     MI.eraseFromParent();
3925     return Legalized;
3926   }
3927   case G_UADDE: {
3928     auto [Res, CarryOut, LHS, RHS, CarryIn] = MI.getFirst5Regs();
3929     const LLT CondTy = MRI.getType(CarryOut);
3930     const LLT Ty = MRI.getType(Res);
3931 
3932     Register NewRes = MRI.cloneVirtualRegister(Res);
3933 
3934     // Initial add of the two operands.
3935     auto TmpRes = MIRBuilder.buildAdd(Ty, LHS, RHS);
3936 
3937     // Initial check for carry.
3938     auto Carry = MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CondTy, TmpRes, LHS);
3939 
3940     // Add the sum and the carry.
3941     auto ZExtCarryIn = MIRBuilder.buildZExt(Ty, CarryIn);
3942     MIRBuilder.buildAdd(NewRes, TmpRes, ZExtCarryIn);
3943 
3944     // Second check for carry. We can only carry if the initial sum is all 1s
3945     // and the carry is set, resulting in a new sum of 0.
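    // E.g. (illustrative, s8): LHS = 0xFE, RHS = 0x01, CarryIn = 1 gives
    // TmpRes = 0xFF with no initial carry; adding CarryIn wraps NewRes to
    // 0x00, and this second check reports the carry.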
3946     auto Zero = MIRBuilder.buildConstant(Ty, 0);
3947     auto ResEqZero =
3948         MIRBuilder.buildICmp(CmpInst::ICMP_EQ, CondTy, NewRes, Zero);
3949     auto Carry2 = MIRBuilder.buildAnd(CondTy, ResEqZero, CarryIn);
3950     MIRBuilder.buildOr(CarryOut, Carry, Carry2);
3951 
3952     MIRBuilder.buildCopy(Res, NewRes);
3953 
3954     MI.eraseFromParent();
3955     return Legalized;
3956   }
3957   case G_USUBO: {
3958     auto [Res, BorrowOut, LHS, RHS] = MI.getFirst4Regs();
3959 
3960     MIRBuilder.buildSub(Res, LHS, RHS);
3961     MIRBuilder.buildICmp(CmpInst::ICMP_ULT, BorrowOut, LHS, RHS);
3962 
3963     MI.eraseFromParent();
3964     return Legalized;
3965   }
3966   case G_USUBE: {
3967     auto [Res, BorrowOut, LHS, RHS, BorrowIn] = MI.getFirst5Regs();
3968     const LLT CondTy = MRI.getType(BorrowOut);
3969     const LLT Ty = MRI.getType(Res);
3970 
3971     // Initial subtract of the two operands.
3972     auto TmpRes = MIRBuilder.buildSub(Ty, LHS, RHS);
3973 
3974     // Initial check for borrow.
3975     auto Borrow = MIRBuilder.buildICmp(CmpInst::ICMP_UGT, CondTy, TmpRes, LHS);
3976 
3977     // Subtract the borrow from the first subtract.
3978     auto ZExtBorrowIn = MIRBuilder.buildZExt(Ty, BorrowIn);
3979     MIRBuilder.buildSub(Res, TmpRes, ZExtBorrowIn);
3980 
3981     // Second check for borrow. We can only borrow if the initial difference is
3982     // 0 and the borrow is set, resulting in a new difference of all 1s.
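    // E.g. (illustrative, s8): LHS = RHS = 0x05 with BorrowIn = 1 gives
    // TmpRes = 0x00 and no initial borrow; subtracting BorrowIn wraps the
    // result to 0xFF, and this second check reports the borrow.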
3983     auto Zero = MIRBuilder.buildConstant(Ty, 0);
3984     auto TmpResEqZero =
3985         MIRBuilder.buildICmp(CmpInst::ICMP_EQ, CondTy, TmpRes, Zero);
3986     auto Borrow2 = MIRBuilder.buildAnd(CondTy, TmpResEqZero, BorrowIn);
3987     MIRBuilder.buildOr(BorrowOut, Borrow, Borrow2);
3988 
3989     MI.eraseFromParent();
3990     return Legalized;
3991   }
3992   case G_UITOFP:
3993     return lowerUITOFP(MI);
3994   case G_SITOFP:
3995     return lowerSITOFP(MI);
3996   case G_FPTOUI:
3997     return lowerFPTOUI(MI);
3998   case G_FPTOSI:
3999     return lowerFPTOSI(MI);
4000   case G_FPTRUNC:
4001     return lowerFPTRUNC(MI);
4002   case G_FPOWI:
4003     return lowerFPOWI(MI);
4004   case G_SMIN:
4005   case G_SMAX:
4006   case G_UMIN:
4007   case G_UMAX:
4008     return lowerMinMax(MI);
4009   case G_SCMP:
4010   case G_UCMP:
4011     return lowerThreewayCompare(MI);
4012   case G_FCOPYSIGN:
4013     return lowerFCopySign(MI);
4014   case G_FMINNUM:
4015   case G_FMAXNUM:
4016     return lowerFMinNumMaxNum(MI);
4017   case G_MERGE_VALUES:
4018     return lowerMergeValues(MI);
4019   case G_UNMERGE_VALUES:
4020     return lowerUnmergeValues(MI);
4021   case TargetOpcode::G_SEXT_INREG: {
4022     assert(MI.getOperand(2).isImm() && "Expected immediate");
4023     int64_t SizeInBits = MI.getOperand(2).getImm();
4024 
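    // Lower as a shift pair; e.g. (illustrative) for s32 with SizeInBits = 8:
    //   %t:_(s32) = G_SHL  %src, 24
    //   %d:_(s32) = G_ASHR %t, 24
    // The arithmetic shift replicates bit 7 across the upper 24 bits.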
4025     auto [DstReg, SrcReg] = MI.getFirst2Regs();
4026     LLT DstTy = MRI.getType(DstReg);
4027     Register TmpRes = MRI.createGenericVirtualRegister(DstTy);
4028 
4029     auto MIBSz = MIRBuilder.buildConstant(DstTy, DstTy.getScalarSizeInBits() - SizeInBits);
4030     MIRBuilder.buildShl(TmpRes, SrcReg, MIBSz->getOperand(0));
4031     MIRBuilder.buildAShr(DstReg, TmpRes, MIBSz->getOperand(0));
4032     MI.eraseFromParent();
4033     return Legalized;
4034   }
4035   case G_EXTRACT_VECTOR_ELT:
4036   case G_INSERT_VECTOR_ELT:
4037     return lowerExtractInsertVectorElt(MI);
4038   case G_SHUFFLE_VECTOR:
4039     return lowerShuffleVector(MI);
4040   case G_VECTOR_COMPRESS:
4041     return lowerVECTOR_COMPRESS(MI);
4042   case G_DYN_STACKALLOC:
4043     return lowerDynStackAlloc(MI);
4044   case G_STACKSAVE:
4045     return lowerStackSave(MI);
4046   case G_STACKRESTORE:
4047     return lowerStackRestore(MI);
4048   case G_EXTRACT:
4049     return lowerExtract(MI);
4050   case G_INSERT:
4051     return lowerInsert(MI);
4052   case G_BSWAP:
4053     return lowerBswap(MI);
4054   case G_BITREVERSE:
4055     return lowerBitreverse(MI);
4056   case G_READ_REGISTER:
4057   case G_WRITE_REGISTER:
4058     return lowerReadWriteRegister(MI);
4059   case G_UADDSAT:
4060   case G_USUBSAT: {
4061     // Try to make a reasonable guess about which lowering strategy to use. The
4062     // target can override this with custom lowering and calling the
4063     // implementation functions.
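    // For instance, when unsigned min is legal, the saturating forms reduce
    // to simple identities (illustrative): usubsat(x, y) = x - umin(x, y) and
    // uaddsat(x, y) = umin(x, ~y) + y.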
4064     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
4065     if (LI.isLegalOrCustom({G_UMIN, Ty}))
4066       return lowerAddSubSatToMinMax(MI);
4067     return lowerAddSubSatToAddoSubo(MI);
4068   }
4069   case G_SADDSAT:
4070   case G_SSUBSAT: {
4071     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
4072 
4073     // FIXME: It would probably make more sense to see if G_SADDO is preferred,
4074     // since it's a shorter expansion. However, we would need to figure out the
4075     // preferred boolean type for the carry out for the query.
4076     if (LI.isLegalOrCustom({G_SMIN, Ty}) && LI.isLegalOrCustom({G_SMAX, Ty}))
4077       return lowerAddSubSatToMinMax(MI);
4078     return lowerAddSubSatToAddoSubo(MI);
4079   }
4080   case G_SSHLSAT:
4081   case G_USHLSAT:
4082     return lowerShlSat(MI);
4083   case G_ABS:
4084     return lowerAbsToAddXor(MI);
4085   case G_SELECT:
4086     return lowerSelect(MI);
4087   case G_IS_FPCLASS:
4088     return lowerISFPCLASS(MI);
4089   case G_SDIVREM:
4090   case G_UDIVREM:
4091     return lowerDIVREM(MI);
4092   case G_FSHL:
4093   case G_FSHR:
4094     return lowerFunnelShift(MI);
4095   case G_ROTL:
4096   case G_ROTR:
4097     return lowerRotate(MI);
4098   case G_MEMSET:
4099   case G_MEMCPY:
4100   case G_MEMMOVE:
4101     return lowerMemCpyFamily(MI);
4102   case G_MEMCPY_INLINE:
4103     return lowerMemcpyInline(MI);
4104   case G_ZEXT:
4105   case G_SEXT:
4106   case G_ANYEXT:
4107     return lowerEXT(MI);
4108   case G_TRUNC:
4109     return lowerTRUNC(MI);
4110   GISEL_VECREDUCE_CASES_NONSEQ
4111     return lowerVectorReduction(MI);
4112   case G_VAARG:
4113     return lowerVAArg(MI);
4114   }
4115 }
4116 
4117 Align LegalizerHelper::getStackTemporaryAlignment(LLT Ty,
4118                                                   Align MinAlign) const {
4119   // FIXME: We're missing a way to go back from LLT to llvm::Type to query the
4120   // datalayout for the preferred alignment. Also there should be a target hook
4121   // for this to allow targets to reduce the alignment and ignore the
4122   // datalayout. e.g. AMDGPU should always use a 4-byte alignment, regardless of
4123   // the type.
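  // For example (illustrative): an s96 temporary occupies 12 bytes, so it
  // gets the next power-of-2 alignment, Align(16), unless MinAlign is larger.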
4124   return std::max(Align(PowerOf2Ceil(Ty.getSizeInBytes())), MinAlign);
4125 }
4126 
4127 MachineInstrBuilder
4128 LegalizerHelper::createStackTemporary(TypeSize Bytes, Align Alignment,
4129                                       MachinePointerInfo &PtrInfo) {
4130   MachineFunction &MF = MIRBuilder.getMF();
4131   const DataLayout &DL = MIRBuilder.getDataLayout();
4132   int FrameIdx = MF.getFrameInfo().CreateStackObject(Bytes, Alignment, false);
4133 
4134   unsigned AddrSpace = DL.getAllocaAddrSpace();
4135   LLT FramePtrTy = LLT::pointer(AddrSpace, DL.getPointerSizeInBits(AddrSpace));
4136 
4137   PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIdx);
4138   return MIRBuilder.buildFrameIndex(FramePtrTy, FrameIdx);
4139 }
4140 
4141 static Register clampVectorIndex(MachineIRBuilder &B, Register IdxReg,
4142                                  LLT VecTy) {
4143   LLT IdxTy = B.getMRI()->getType(IdxReg);
4144   unsigned NElts = VecTy.getNumElements();
4145 
4146   int64_t IdxVal;
4147   if (mi_match(IdxReg, *B.getMRI(), m_ICst(IdxVal))) {
4148     if (IdxVal < VecTy.getNumElements())
4149       return IdxReg;
4150     // If a constant index would be out of bounds, clamp it as well.
4151   }
4152 
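  // A power-of-2 element count can be clamped with a cheap mask; e.g.
  // (illustrative) an index into <4 x s32> is ANDed with 3, while an index
  // into <3 x s32> falls through to the umin against NElts - 1 below.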
4153   if (isPowerOf2_32(NElts)) {
4154     APInt Imm = APInt::getLowBitsSet(IdxTy.getSizeInBits(), Log2_32(NElts));
4155     return B.buildAnd(IdxTy, IdxReg, B.buildConstant(IdxTy, Imm)).getReg(0);
4156   }
4157 
4158   return B.buildUMin(IdxTy, IdxReg, B.buildConstant(IdxTy, NElts - 1))
4159       .getReg(0);
4160 }
4161 
4162 Register LegalizerHelper::getVectorElementPointer(Register VecPtr, LLT VecTy,
4163                                                   Register Index) {
4164   LLT EltTy = VecTy.getElementType();
4165 
4166   // Calculate the element offset and add it to the pointer.
4167   unsigned EltSize = EltTy.getSizeInBits() / 8; // FIXME: should be ABI size.
4168   assert(EltSize * 8 == EltTy.getSizeInBits() &&
4169          "Converting bits to bytes lost precision");
4170 
4171   Index = clampVectorIndex(MIRBuilder, Index, VecTy);
4172 
4173   // Convert index to the correct size for the address space.
4174   const DataLayout &DL = MIRBuilder.getDataLayout();
4175   unsigned AS = MRI.getType(VecPtr).getAddressSpace();
4176   unsigned IndexSizeInBits = DL.getIndexSize(AS) * 8;
4177   LLT IdxTy = MRI.getType(Index).changeElementSize(IndexSizeInBits);
4178   if (IdxTy != MRI.getType(Index))
4179     Index = MIRBuilder.buildSExtOrTrunc(IdxTy, Index).getReg(0);
4180 
4181   auto Mul = MIRBuilder.buildMul(IdxTy, Index,
4182                                  MIRBuilder.buildConstant(IdxTy, EltSize));
4183 
4184   LLT PtrTy = MRI.getType(VecPtr);
4185   return MIRBuilder.buildPtrAdd(PtrTy, VecPtr, Mul).getReg(0);
4186 }
4187 
4188 #ifndef NDEBUG
4189 /// Check that all vector operands have the same number of elements. Other
4190 /// operands should be listed in \p NonVecOpIndices.
4191 static bool hasSameNumEltsOnAllVectorOperands(
4192     GenericMachineInstr &MI, MachineRegisterInfo &MRI,
4193     std::initializer_list<unsigned> NonVecOpIndices) {
4194   if (MI.getNumMemOperands() != 0)
4195     return false;
4196 
4197   LLT VecTy = MRI.getType(MI.getReg(0));
4198   if (!VecTy.isVector())
4199     return false;
4200   unsigned NumElts = VecTy.getNumElements();
4201 
4202   for (unsigned OpIdx = 1; OpIdx < MI.getNumOperands(); ++OpIdx) {
4203     MachineOperand &Op = MI.getOperand(OpIdx);
4204     if (!Op.isReg()) {
4205       if (!is_contained(NonVecOpIndices, OpIdx))
4206         return false;
4207       continue;
4208     }
4209 
4210     LLT Ty = MRI.getType(Op.getReg());
4211     if (!Ty.isVector()) {
4212       if (!is_contained(NonVecOpIndices, OpIdx))
4213         return false;
4214       continue;
4215     }
4216 
4217     if (Ty.getNumElements() != NumElts)
4218       return false;
4219   }
4220 
4221   return true;
4222 }
4223 #endif
4224 
4225 /// Fill \p DstOps with DstOps that, combined, have the same number of elements
4226 /// as \p Ty. These DstOps have either scalar type when \p NumElts = 1 or are
4227 /// vectors with \p NumElts elements. When Ty.getNumElements() is not a multiple
4228 /// of \p NumElts, the last DstOp (leftover) has fewer than \p NumElts elements.
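/// For example (illustrative): Ty = <7 x s8> with NumElts = 2 yields three
/// <2 x s8> DstOps plus a leftover s8.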
4229 static void makeDstOps(SmallVectorImpl<DstOp> &DstOps, LLT Ty,
4230                        unsigned NumElts) {
4231   LLT LeftoverTy;
4232   assert(Ty.isVector() && "Expected vector type");
4233   LLT EltTy = Ty.getElementType();
4234   LLT NarrowTy = (NumElts == 1) ? EltTy : LLT::fixed_vector(NumElts, EltTy);
4235   int NumParts, NumLeftover;
4236   std::tie(NumParts, NumLeftover) =
4237       getNarrowTypeBreakDown(Ty, NarrowTy, LeftoverTy);
4238 
4239   assert(NumParts > 0 && "Error in getNarrowTypeBreakDown");
4240   for (int i = 0; i < NumParts; ++i) {
4241     DstOps.push_back(NarrowTy);
4242   }
4243 
4244   if (LeftoverTy.isValid()) {
4245     assert(NumLeftover == 1 && "expected exactly one leftover");
4246     DstOps.push_back(LeftoverTy);
4247   }
4248 }
4249 
4250 /// Operand \p Op is used on \p N sub-instructions. Fill \p Ops with \p N SrcOps
4251 /// made from \p Op depending on operand type.
4252 static void broadcastSrcOp(SmallVectorImpl<SrcOp> &Ops, unsigned N,
4253                            MachineOperand &Op) {
4254   for (unsigned i = 0; i < N; ++i) {
4255     if (Op.isReg())
4256       Ops.push_back(Op.getReg());
4257     else if (Op.isImm())
4258       Ops.push_back(Op.getImm());
4259     else if (Op.isPredicate())
4260       Ops.push_back(static_cast<CmpInst::Predicate>(Op.getPredicate()));
4261     else
4262       llvm_unreachable("Unsupported type");
4263   }
4264 }
4265 
4266 // Handle splitting vector operations which need to have the same number of
4267 // elements in each type index, but each type index may have a different element
4268 // type.
4269 //
4270 // e.g.  <4 x s64> = G_SHL <4 x s64>, <4 x s32> ->
4271 //       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
4272 //       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
4273 //
4274 // Also handles some irregular breakdown cases, e.g.
4275 // e.g.  <3 x s64> = G_SHL <3 x s64>, <3 x s32> ->
4276 //       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
4277 //             s64 = G_SHL s64, s32
4278 LegalizerHelper::LegalizeResult
4279 LegalizerHelper::fewerElementsVectorMultiEltType(
4280     GenericMachineInstr &MI, unsigned NumElts,
4281     std::initializer_list<unsigned> NonVecOpIndices) {
4282   assert(hasSameNumEltsOnAllVectorOperands(MI, MRI, NonVecOpIndices) &&
4283          "Non-compatible opcode or not specified non-vector operands");
4284   unsigned OrigNumElts = MRI.getType(MI.getReg(0)).getNumElements();
4285 
4286   unsigned NumInputs = MI.getNumOperands() - MI.getNumDefs();
4287   unsigned NumDefs = MI.getNumDefs();
4288 
4289   // Create DstOps (sub-vectors with NumElts elts + leftover) for each output.
4290   // Build instructions with DstOps so that an instruction found by CSE can be
4291   // used directly; CSE copies the found instruction into the given vreg when building with a vreg dest.
4292   SmallVector<SmallVector<DstOp, 8>, 2> OutputOpsPieces(NumDefs);
4293   // Output registers will be taken from created instructions.
4294   SmallVector<SmallVector<Register, 8>, 2> OutputRegs(NumDefs);
4295   for (unsigned i = 0; i < NumDefs; ++i) {
4296     makeDstOps(OutputOpsPieces[i], MRI.getType(MI.getReg(i)), NumElts);
4297   }
4298 
4299   // Split vector input operands into sub-vectors with NumElts elts + Leftover.
4300   // Operands listed in NonVecOpIndices will be used as is without splitting;
4301   // examples: compare predicate in icmp and fcmp (op 1), vector select with i1
4302   // scalar condition (op 1), immediate in sext_inreg (op 2).
4303   SmallVector<SmallVector<SrcOp, 8>, 3> InputOpsPieces(NumInputs);
4304   for (unsigned UseIdx = NumDefs, UseNo = 0; UseIdx < MI.getNumOperands();
4305        ++UseIdx, ++UseNo) {
4306     if (is_contained(NonVecOpIndices, UseIdx)) {
4307       broadcastSrcOp(InputOpsPieces[UseNo], OutputOpsPieces[0].size(),
4308                      MI.getOperand(UseIdx));
4309     } else {
4310       SmallVector<Register, 8> SplitPieces;
4311       extractVectorParts(MI.getReg(UseIdx), NumElts, SplitPieces, MIRBuilder,
4312                          MRI);
4313       for (auto Reg : SplitPieces)
4314         InputOpsPieces[UseNo].push_back(Reg);
4315     }
4316   }
4317 
4318   unsigned NumLeftovers = OrigNumElts % NumElts ? 1 : 0;
4319 
4320   // Take i-th piece of each input operand split and build sub-vector/scalar
4321   // instruction. Set i-th DstOp(s) from OutputOpsPieces as destination(s).
4322   for (unsigned i = 0; i < OrigNumElts / NumElts + NumLeftovers; ++i) {
4323     SmallVector<DstOp, 2> Defs;
4324     for (unsigned DstNo = 0; DstNo < NumDefs; ++DstNo)
4325       Defs.push_back(OutputOpsPieces[DstNo][i]);
4326 
4327     SmallVector<SrcOp, 3> Uses;
4328     for (unsigned InputNo = 0; InputNo < NumInputs; ++InputNo)
4329       Uses.push_back(InputOpsPieces[InputNo][i]);
4330 
4331     auto I = MIRBuilder.buildInstr(MI.getOpcode(), Defs, Uses, MI.getFlags());
4332     for (unsigned DstNo = 0; DstNo < NumDefs; ++DstNo)
4333       OutputRegs[DstNo].push_back(I.getReg(DstNo));
4334   }
4335 
4336   // Merge small outputs into MI's output for each def operand.
4337   if (NumLeftovers) {
4338     for (unsigned i = 0; i < NumDefs; ++i)
4339       mergeMixedSubvectors(MI.getReg(i), OutputRegs[i]);
4340   } else {
4341     for (unsigned i = 0; i < NumDefs; ++i)
4342       MIRBuilder.buildMergeLikeInstr(MI.getReg(i), OutputRegs[i]);
4343   }
4344 
4345   MI.eraseFromParent();
4346   return Legalized;
4347 }
4348 
4349 LegalizerHelper::LegalizeResult
4350 LegalizerHelper::fewerElementsVectorPhi(GenericMachineInstr &MI,
4351                                         unsigned NumElts) {
4352   unsigned OrigNumElts = MRI.getType(MI.getReg(0)).getNumElements();
4353 
4354   unsigned NumInputs = MI.getNumOperands() - MI.getNumDefs();
4355   unsigned NumDefs = MI.getNumDefs();
4356 
4357   SmallVector<DstOp, 8> OutputOpsPieces;
4358   SmallVector<Register, 8> OutputRegs;
4359   makeDstOps(OutputOpsPieces, MRI.getType(MI.getReg(0)), NumElts);
4360 
4361   // Instructions that perform the register split will be inserted in the basic
4362   // block where the register is defined (the basic block is in the next operand).
4363   SmallVector<SmallVector<Register, 8>, 3> InputOpsPieces(NumInputs / 2);
4364   for (unsigned UseIdx = NumDefs, UseNo = 0; UseIdx < MI.getNumOperands();
4365        UseIdx += 2, ++UseNo) {
4366     MachineBasicBlock &OpMBB = *MI.getOperand(UseIdx + 1).getMBB();
4367     MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
4368     extractVectorParts(MI.getReg(UseIdx), NumElts, InputOpsPieces[UseNo],
4369                        MIRBuilder, MRI);
4370   }
4371 
4372   // Build PHIs with fewer elements.
4373   unsigned NumLeftovers = OrigNumElts % NumElts ? 1 : 0;
4374   MIRBuilder.setInsertPt(*MI.getParent(), MI);
4375   for (unsigned i = 0; i < OrigNumElts / NumElts + NumLeftovers; ++i) {
4376     auto Phi = MIRBuilder.buildInstr(TargetOpcode::G_PHI);
4377     Phi.addDef(
4378         MRI.createGenericVirtualRegister(OutputOpsPieces[i].getLLTTy(MRI)));
4379     OutputRegs.push_back(Phi.getReg(0));
4380 
4381     for (unsigned j = 0; j < NumInputs / 2; ++j) {
4382       Phi.addUse(InputOpsPieces[j][i]);
4383       Phi.add(MI.getOperand(1 + j * 2 + 1));
4384     }
4385   }
4386 
4387   // Set the insert point after the existing PHIs
4388   MachineBasicBlock &MBB = *MI.getParent();
4389   MIRBuilder.setInsertPt(MBB, MBB.getFirstNonPHI());
4390 
4391   // Merge small outputs into MI's def.
4392   if (NumLeftovers) {
4393     mergeMixedSubvectors(MI.getReg(0), OutputRegs);
4394   } else {
4395     MIRBuilder.buildMergeLikeInstr(MI.getReg(0), OutputRegs);
4396   }
4397 
4398   MI.eraseFromParent();
4399   return Legalized;
4400 }
4401 
4402 LegalizerHelper::LegalizeResult
4403 LegalizerHelper::fewerElementsVectorUnmergeValues(MachineInstr &MI,
4404                                                   unsigned TypeIdx,
4405                                                   LLT NarrowTy) {
4406   const int NumDst = MI.getNumOperands() - 1;
4407   const Register SrcReg = MI.getOperand(NumDst).getReg();
4408   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
4409   LLT SrcTy = MRI.getType(SrcReg);
4410 
4411   if (TypeIdx != 1 || NarrowTy == DstTy)
4412     return UnableToLegalize;
4413 
4414   // Requires compatible types. Otherwise SrcReg should have been defined by a
4415   // merge-like instruction that would get artifact-combined. Most likely the
4416   // instruction that defines SrcReg has to perform more/fewer-elements
4417   // legalization compatible with NarrowTy.
4418   assert(SrcTy.isVector() && NarrowTy.isVector() && "Expected vector types");
4419   assert((SrcTy.getScalarType() == NarrowTy.getScalarType()) && "bad type");
4420 
4421   if ((SrcTy.getSizeInBits() % NarrowTy.getSizeInBits() != 0) ||
4422       (NarrowTy.getSizeInBits() % DstTy.getSizeInBits() != 0))
4423     return UnableToLegalize;
4424 
4425   // This is most likely DstTy (smaller than register size) packed in SrcTy
4426   // (larger than register size), and since the unmerge was not combined it will
4427   // be lowered to bit sequence extracts from a register. Unpack SrcTy to
4428   // NarrowTy (register size) pieces first, then unpack each NarrowTy piece to DstTy.
4429 
4430   // %1:_(DstTy), %2, %3, %4 = G_UNMERGE_VALUES %0:_(SrcTy)
4431   //
4432   // %5:_(NarrowTy), %6 = G_UNMERGE_VALUES %0:_(SrcTy) - reg sequence
4433   // %1:_(DstTy), %2 = G_UNMERGE_VALUES %5:_(NarrowTy) - sequence of bits in reg
4434   // %3:_(DstTy), %4 = G_UNMERGE_VALUES %6:_(NarrowTy)
4435   auto Unmerge = MIRBuilder.buildUnmerge(NarrowTy, SrcReg);
4436   const int NumUnmerge = Unmerge->getNumOperands() - 1;
4437   const int PartsPerUnmerge = NumDst / NumUnmerge;
4438 
4439   for (int I = 0; I != NumUnmerge; ++I) {
4440     auto MIB = MIRBuilder.buildInstr(TargetOpcode::G_UNMERGE_VALUES);
4441 
4442     for (int J = 0; J != PartsPerUnmerge; ++J)
4443       MIB.addDef(MI.getOperand(I * PartsPerUnmerge + J).getReg());
4444     MIB.addUse(Unmerge.getReg(I));
4445   }
4446 
4447   MI.eraseFromParent();
4448   return Legalized;
4449 }
4450 
4451 LegalizerHelper::LegalizeResult
4452 LegalizerHelper::fewerElementsVectorMerge(MachineInstr &MI, unsigned TypeIdx,
4453                                           LLT NarrowTy) {
4454   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
4455   // Requires compatible types. Otherwise the user of DstReg did not perform an
4456   // unmerge that should have been artifact-combined. Most likely the instruction
4457   // that uses DstReg has to do more/fewer-elements legalization compatible with NarrowTy.
4458   assert(DstTy.isVector() && NarrowTy.isVector() && "Expected vector types");
4459   assert((DstTy.getScalarType() == NarrowTy.getScalarType()) && "bad type");
4460   if (NarrowTy == SrcTy)
4461     return UnableToLegalize;
4462 
4463   // This attempts to lower part of an LCMTy merge/unmerge sequence. The intended
4464   // use is for old MIR tests. Since the changes to more/fewer-elements legalization,
4465   // it should no longer be possible to generate MIR like this when starting from
4466   // LLVM IR, because the LCMTy approach was replaced with merge/unmerge to vector elements.
4467   if (TypeIdx == 1) {
4468     assert(SrcTy.isVector() && "Expected vector types");
4469     assert((SrcTy.getScalarType() == NarrowTy.getScalarType()) && "bad type");
4470     if ((DstTy.getSizeInBits() % NarrowTy.getSizeInBits() != 0) ||
4471         (NarrowTy.getNumElements() >= SrcTy.getNumElements()))
4472       return UnableToLegalize;
4473     // %2:_(DstTy) = G_CONCAT_VECTORS %0:_(SrcTy), %1:_(SrcTy)
4474     //
4475     // %3:_(EltTy), %4, %5 = G_UNMERGE_VALUES %0:_(SrcTy)
4476     // %6:_(EltTy), %7, %8 = G_UNMERGE_VALUES %1:_(SrcTy)
4477     // %9:_(NarrowTy) = G_BUILD_VECTOR %3:_(EltTy), %4
4478     // %10:_(NarrowTy) = G_BUILD_VECTOR %5:_(EltTy), %6
4479     // %11:_(NarrowTy) = G_BUILD_VECTOR %7:_(EltTy), %8
4480     // %2:_(DstTy) = G_CONCAT_VECTORS %9:_(NarrowTy), %10, %11
4481 
4482     SmallVector<Register, 8> Elts;
4483     LLT EltTy = MRI.getType(MI.getOperand(1).getReg()).getScalarType();
4484     for (unsigned i = 1; i < MI.getNumOperands(); ++i) {
4485       auto Unmerge = MIRBuilder.buildUnmerge(EltTy, MI.getOperand(i).getReg());
4486       for (unsigned j = 0; j < Unmerge->getNumDefs(); ++j)
4487         Elts.push_back(Unmerge.getReg(j));
4488     }
4489 
4490     SmallVector<Register, 8> NarrowTyElts;
4491     unsigned NumNarrowTyElts = NarrowTy.getNumElements();
4492     unsigned NumNarrowTyPieces = DstTy.getNumElements() / NumNarrowTyElts;
4493     for (unsigned i = 0, Offset = 0; i < NumNarrowTyPieces;
4494          ++i, Offset += NumNarrowTyElts) {
4495       ArrayRef<Register> Pieces(&Elts[Offset], NumNarrowTyElts);
4496       NarrowTyElts.push_back(
4497           MIRBuilder.buildMergeLikeInstr(NarrowTy, Pieces).getReg(0));
4498     }
4499 
4500     MIRBuilder.buildMergeLikeInstr(DstReg, NarrowTyElts);
4501     MI.eraseFromParent();
4502     return Legalized;
4503   }
4504 
4505   assert(TypeIdx == 0 && "Bad type index");
4506   if ((NarrowTy.getSizeInBits() % SrcTy.getSizeInBits() != 0) ||
4507       (DstTy.getSizeInBits() % NarrowTy.getSizeInBits() != 0))
4508     return UnableToLegalize;
4509 
4510   // This is most likely SrcTy (smaller than register size) packed in DstTy
4511   // (larger than register size), and since the merge was not combined it will
4512   // be lowered to bit sequence packing into a register. Merge SrcTy to NarrowTy
4513   // (register size) pieces first, then merge each NarrowTy piece to DstTy.
4514 
4515   // %0:_(DstTy) = G_MERGE_VALUES %1:_(SrcTy), %2, %3, %4
4516   //
4517   // %5:_(NarrowTy) = G_MERGE_VALUES %1:_(SrcTy), %2 - sequence of bits in reg
4518   // %6:_(NarrowTy) = G_MERGE_VALUES %3:_(SrcTy), %4
4519   // %0:_(DstTy)  = G_MERGE_VALUES %5:_(NarrowTy), %6 - reg sequence
4520   SmallVector<Register, 8> NarrowTyElts;
4521   unsigned NumParts = DstTy.getNumElements() / NarrowTy.getNumElements();
4522   unsigned NumSrcElts = SrcTy.isVector() ? SrcTy.getNumElements() : 1;
4523   unsigned NumElts = NarrowTy.getNumElements() / NumSrcElts;
4524   for (unsigned i = 0; i < NumParts; ++i) {
4525     SmallVector<Register, 8> Sources;
4526     for (unsigned j = 0; j < NumElts; ++j)
4527       Sources.push_back(MI.getOperand(1 + i * NumElts + j).getReg());
4528     NarrowTyElts.push_back(
4529         MIRBuilder.buildMergeLikeInstr(NarrowTy, Sources).getReg(0));
4530   }
4531 
4532   MIRBuilder.buildMergeLikeInstr(DstReg, NarrowTyElts);
4533   MI.eraseFromParent();
4534   return Legalized;
4535 }
4536 
4537 LegalizerHelper::LegalizeResult
4538 LegalizerHelper::fewerElementsVectorExtractInsertVectorElt(MachineInstr &MI,
4539                                                            unsigned TypeIdx,
4540                                                            LLT NarrowVecTy) {
4541   auto [DstReg, SrcVec] = MI.getFirst2Regs();
4542   Register InsertVal;
4543   bool IsInsert = MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT;
4544 
4545   assert((IsInsert ? TypeIdx == 0 : TypeIdx == 1) && "not a vector type index");
4546   if (IsInsert)
4547     InsertVal = MI.getOperand(2).getReg();
4548 
4549   Register Idx = MI.getOperand(MI.getNumOperands() - 1).getReg();
4550 
4551   // TODO: Handle total scalarization case.
4552   if (!NarrowVecTy.isVector())
4553     return UnableToLegalize;
4554 
4555   LLT VecTy = MRI.getType(SrcVec);
4556 
4557   // If the index is a constant, we can really break this down as you would
4558   // expect, and index into the target-size pieces.
4559   int64_t IdxVal;
4560   auto MaybeCst = getIConstantVRegValWithLookThrough(Idx, MRI);
4561   if (MaybeCst) {
4562     IdxVal = MaybeCst->Value.getSExtValue();
4563     // Avoid out of bounds indexing the pieces.
4564     if (IdxVal >= VecTy.getNumElements()) {
4565       MIRBuilder.buildUndef(DstReg);
4566       MI.eraseFromParent();
4567       return Legalized;
4568     }
4569 
4570     SmallVector<Register, 8> VecParts;
4571     LLT GCDTy = extractGCDType(VecParts, VecTy, NarrowVecTy, SrcVec);
4572 
4573     // Build a sequence of NarrowTy pieces in VecParts for this operand.
4574     LLT LCMTy = buildLCMMergePieces(VecTy, NarrowVecTy, GCDTy, VecParts,
4575                                     TargetOpcode::G_ANYEXT);
4576 
4577     unsigned NewNumElts = NarrowVecTy.getNumElements();
4578 
4579     LLT IdxTy = MRI.getType(Idx);
4580     int64_t PartIdx = IdxVal / NewNumElts;
4581     auto NewIdx =
4582         MIRBuilder.buildConstant(IdxTy, IdxVal - NewNumElts * PartIdx);
4583 
4584     if (IsInsert) {
4585       LLT PartTy = MRI.getType(VecParts[PartIdx]);
4586 
4587       // Use the adjusted index to insert into one of the subvectors.
4588       auto InsertPart = MIRBuilder.buildInsertVectorElement(
4589           PartTy, VecParts[PartIdx], InsertVal, NewIdx);
4590       VecParts[PartIdx] = InsertPart.getReg(0);
4591 
4592       // Recombine the inserted subvector with the others to reform the result
4593       // vector.
4594       buildWidenedRemergeToDst(DstReg, LCMTy, VecParts);
4595     } else {
4596       MIRBuilder.buildExtractVectorElement(DstReg, VecParts[PartIdx], NewIdx);
4597     }
4598 
4599     MI.eraseFromParent();
4600     return Legalized;
4601   }
4602 
4603   // With a variable index, we can't perform the operation in a smaller type, so
4604   // we're forced to expand this.
4605   //
4606   // TODO: We could emit a chain of compare/select to figure out which piece to
4607   // index.
4608   return lowerExtractInsertVectorElt(MI);
4609 }
4610 
4611 LegalizerHelper::LegalizeResult
4612 LegalizerHelper::reduceLoadStoreWidth(GLoadStore &LdStMI, unsigned TypeIdx,
4613                                       LLT NarrowTy) {
4614   // FIXME: Don't know how to handle secondary types yet.
4615   if (TypeIdx != 0)
4616     return UnableToLegalize;
4617 
4618   // This implementation doesn't work for atomics. Give up instead of doing
4619   // something invalid.
4620   if (LdStMI.isAtomic())
4621     return UnableToLegalize;
4622 
4623   bool IsLoad = isa<GLoad>(LdStMI);
4624   Register ValReg = LdStMI.getReg(0);
4625   Register AddrReg = LdStMI.getPointerReg();
4626   LLT ValTy = MRI.getType(ValReg);
4627 
4628   // FIXME: Do we need a distinct NarrowMemory legalize action?
4629   if (ValTy.getSizeInBits() != 8 * LdStMI.getMemSize().getValue()) {
4630     LLVM_DEBUG(dbgs() << "Can't narrow extload/truncstore\n");
4631     return UnableToLegalize;
4632   }
4633 
4634   int NumParts = -1;
4635   int NumLeftover = -1;
4636   LLT LeftoverTy;
4637   SmallVector<Register, 8> NarrowRegs, NarrowLeftoverRegs;
4638   if (IsLoad) {
4639     std::tie(NumParts, NumLeftover) = getNarrowTypeBreakDown(ValTy, NarrowTy, LeftoverTy);
4640   } else {
4641     if (extractParts(ValReg, ValTy, NarrowTy, LeftoverTy, NarrowRegs,
4642                      NarrowLeftoverRegs, MIRBuilder, MRI)) {
4643       NumParts = NarrowRegs.size();
4644       NumLeftover = NarrowLeftoverRegs.size();
4645     }
4646   }
4647 
4648   if (NumParts == -1)
4649     return UnableToLegalize;
4650 
4651   LLT PtrTy = MRI.getType(AddrReg);
4652   const LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
4653 
4654   unsigned TotalSize = ValTy.getSizeInBits();
4655 
4656   // Split the load/store into PartTy-sized pieces starting at Offset. If this
4657   // is a load, return the new registers in ValRegs. For a store, each element
4658   // of ValRegs should be PartTy. Returns the next offset that needs to be
4659   // handled.
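  // For example (illustrative): narrowing an s96 load with NarrowTy = s64 on
  // a little-endian target emits an s64 load at byte offset 0 followed by an
  // s32 leftover load at byte offset 8.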
4660   bool isBigEndian = MIRBuilder.getDataLayout().isBigEndian();
4661   auto MMO = LdStMI.getMMO();
4662   auto splitTypePieces = [=](LLT PartTy, SmallVectorImpl<Register> &ValRegs,
4663                              unsigned NumParts, unsigned Offset) -> unsigned {
4664     MachineFunction &MF = MIRBuilder.getMF();
4665     unsigned PartSize = PartTy.getSizeInBits();
4666     for (unsigned Idx = 0, E = NumParts; Idx != E && Offset < TotalSize;
4667          ++Idx) {
4668       unsigned ByteOffset = Offset / 8;
4669       Register NewAddrReg;
4670 
4671       MIRBuilder.materializePtrAdd(NewAddrReg, AddrReg, OffsetTy, ByteOffset);
4672 
4673       MachineMemOperand *NewMMO =
4674           MF.getMachineMemOperand(&MMO, ByteOffset, PartTy);
4675 
4676       if (IsLoad) {
4677         Register Dst = MRI.createGenericVirtualRegister(PartTy);
4678         ValRegs.push_back(Dst);
4679         MIRBuilder.buildLoad(Dst, NewAddrReg, *NewMMO);
4680       } else {
4681         MIRBuilder.buildStore(ValRegs[Idx], NewAddrReg, *NewMMO);
4682       }
4683       Offset = isBigEndian ? Offset - PartSize : Offset + PartSize;
4684     }
4685 
4686     return Offset;
4687   };
4688 
4689   unsigned Offset = isBigEndian ? TotalSize - NarrowTy.getSizeInBits() : 0;
4690   unsigned HandledOffset =
4691       splitTypePieces(NarrowTy, NarrowRegs, NumParts, Offset);
4692 
4693   // Handle the rest of the register if this isn't an even type breakdown.
4694   if (LeftoverTy.isValid())
4695     splitTypePieces(LeftoverTy, NarrowLeftoverRegs, NumLeftover, HandledOffset);
4696 
4697   if (IsLoad) {
4698     insertParts(ValReg, ValTy, NarrowTy, NarrowRegs,
4699                 LeftoverTy, NarrowLeftoverRegs);
4700   }
4701 
4702   LdStMI.eraseFromParent();
4703   return Legalized;
4704 }
4705 
4706 LegalizerHelper::LegalizeResult
4707 LegalizerHelper::fewerElementsVector(MachineInstr &MI, unsigned TypeIdx,
4708                                      LLT NarrowTy) {
4709   using namespace TargetOpcode;
4710   GenericMachineInstr &GMI = cast<GenericMachineInstr>(MI);
4711   unsigned NumElts = NarrowTy.isVector() ? NarrowTy.getNumElements() : 1;
4712 
4713   switch (MI.getOpcode()) {
4714   case G_IMPLICIT_DEF:
4715   case G_TRUNC:
4716   case G_AND:
4717   case G_OR:
4718   case G_XOR:
4719   case G_ADD:
4720   case G_SUB:
4721   case G_MUL:
4722   case G_PTR_ADD:
4723   case G_SMULH:
4724   case G_UMULH:
4725   case G_FADD:
4726   case G_FMUL:
4727   case G_FSUB:
4728   case G_FNEG:
4729   case G_FABS:
4730   case G_FCANONICALIZE:
4731   case G_FDIV:
4732   case G_FREM:
4733   case G_FMA:
4734   case G_FMAD:
4735   case G_FPOW:
4736   case G_FEXP:
4737   case G_FEXP2:
4738   case G_FEXP10:
4739   case G_FLOG:
4740   case G_FLOG2:
4741   case G_FLOG10:
4742   case G_FLDEXP:
4743   case G_FNEARBYINT:
4744   case G_FCEIL:
4745   case G_FFLOOR:
4746   case G_FRINT:
4747   case G_INTRINSIC_ROUND:
4748   case G_INTRINSIC_ROUNDEVEN:
4749   case G_INTRINSIC_TRUNC:
4750   case G_FCOS:
4751   case G_FSIN:
4752   case G_FTAN:
4753   case G_FACOS:
4754   case G_FASIN:
4755   case G_FATAN:
4756   case G_FCOSH:
4757   case G_FSINH:
4758   case G_FTANH:
4759   case G_FSQRT:
4760   case G_BSWAP:
4761   case G_BITREVERSE:
4762   case G_SDIV:
4763   case G_UDIV:
4764   case G_SREM:
4765   case G_UREM:
4766   case G_SDIVREM:
4767   case G_UDIVREM:
4768   case G_SMIN:
4769   case G_SMAX:
4770   case G_UMIN:
4771   case G_UMAX:
4772   case G_ABS:
4773   case G_FMINNUM:
4774   case G_FMAXNUM:
4775   case G_FMINNUM_IEEE:
4776   case G_FMAXNUM_IEEE:
4777   case G_FMINIMUM:
4778   case G_FMAXIMUM:
4779   case G_FSHL:
4780   case G_FSHR:
4781   case G_ROTL:
4782   case G_ROTR:
4783   case G_FREEZE:
4784   case G_SADDSAT:
4785   case G_SSUBSAT:
4786   case G_UADDSAT:
4787   case G_USUBSAT:
4788   case G_UMULO:
4789   case G_SMULO:
4790   case G_SHL:
4791   case G_LSHR:
4792   case G_ASHR:
4793   case G_SSHLSAT:
4794   case G_USHLSAT:
4795   case G_CTLZ:
4796   case G_CTLZ_ZERO_UNDEF:
4797   case G_CTTZ:
4798   case G_CTTZ_ZERO_UNDEF:
4799   case G_CTPOP:
4800   case G_FCOPYSIGN:
4801   case G_ZEXT:
4802   case G_SEXT:
4803   case G_ANYEXT:
4804   case G_FPEXT:
4805   case G_FPTRUNC:
4806   case G_SITOFP:
4807   case G_UITOFP:
4808   case G_FPTOSI:
4809   case G_FPTOUI:
4810   case G_INTTOPTR:
4811   case G_PTRTOINT:
4812   case G_ADDRSPACE_CAST:
4813   case G_UADDO:
4814   case G_USUBO:
4815   case G_UADDE:
4816   case G_USUBE:
4817   case G_SADDO:
4818   case G_SSUBO:
4819   case G_SADDE:
4820   case G_SSUBE:
4821   case G_STRICT_FADD:
4822   case G_STRICT_FSUB:
4823   case G_STRICT_FMUL:
4824   case G_STRICT_FMA:
4825   case G_STRICT_FLDEXP:
4826   case G_FFREXP:
4827     return fewerElementsVectorMultiEltType(GMI, NumElts);
4828   case G_ICMP:
4829   case G_FCMP:
4830     return fewerElementsVectorMultiEltType(GMI, NumElts, {1 /*cmp predicate*/});
4831   case G_IS_FPCLASS:
4832     return fewerElementsVectorMultiEltType(GMI, NumElts, {2, 3 /*mask,fpsem*/});
4833   case G_SELECT:
4834     if (MRI.getType(MI.getOperand(1).getReg()).isVector())
4835       return fewerElementsVectorMultiEltType(GMI, NumElts);
4836     return fewerElementsVectorMultiEltType(GMI, NumElts, {1 /*scalar cond*/});
4837   case G_PHI:
4838     return fewerElementsVectorPhi(GMI, NumElts);
4839   case G_UNMERGE_VALUES:
4840     return fewerElementsVectorUnmergeValues(MI, TypeIdx, NarrowTy);
4841   case G_BUILD_VECTOR:
4842     assert(TypeIdx == 0 && "not a vector type index");
4843     return fewerElementsVectorMerge(MI, TypeIdx, NarrowTy);
4844   case G_CONCAT_VECTORS:
4845     if (TypeIdx != 1) // TODO: This probably does work as expected already.
4846       return UnableToLegalize;
4847     return fewerElementsVectorMerge(MI, TypeIdx, NarrowTy);
4848   case G_EXTRACT_VECTOR_ELT:
4849   case G_INSERT_VECTOR_ELT:
4850     return fewerElementsVectorExtractInsertVectorElt(MI, TypeIdx, NarrowTy);
4851   case G_LOAD:
4852   case G_STORE:
4853     return reduceLoadStoreWidth(cast<GLoadStore>(MI), TypeIdx, NarrowTy);
4854   case G_SEXT_INREG:
4855     return fewerElementsVectorMultiEltType(GMI, NumElts, {2 /*imm*/});
4856   GISEL_VECREDUCE_CASES_NONSEQ
4857     return fewerElementsVectorReductions(MI, TypeIdx, NarrowTy);
4858   case TargetOpcode::G_VECREDUCE_SEQ_FADD:
4859   case TargetOpcode::G_VECREDUCE_SEQ_FMUL:
4860     return fewerElementsVectorSeqReductions(MI, TypeIdx, NarrowTy);
4861   case G_SHUFFLE_VECTOR:
4862     return fewerElementsVectorShuffle(MI, TypeIdx, NarrowTy);
4863   case G_FPOWI:
4864     return fewerElementsVectorMultiEltType(GMI, NumElts, {2 /*pow*/});
4865   case G_BITCAST:
4866     return fewerElementsBitcast(MI, TypeIdx, NarrowTy);
4867   case G_INTRINSIC_FPTRUNC_ROUND:
4868     return fewerElementsVectorMultiEltType(GMI, NumElts, {2});
4869   default:
4870     return UnableToLegalize;
4871   }
4872 }
4873 
4874 LegalizerHelper::LegalizeResult
4875 LegalizerHelper::fewerElementsBitcast(MachineInstr &MI, unsigned int TypeIdx,
4876                                       LLT NarrowTy) {
4877   assert(MI.getOpcode() == TargetOpcode::G_BITCAST &&
4878          "Not a bitcast operation");
4879 
4880   if (TypeIdx != 0)
4881     return UnableToLegalize;
4882 
4883   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
4884 
4885   unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
4886   LLT SrcNarrowTy =
4887       LLT::fixed_vector(NarrowTy.getSizeInBits() / SrcScalSize, SrcScalSize);
4888 
4889   // Split the Src and Dst Reg into smaller registers
4890   SmallVector<Register> SrcVRegs, BitcastVRegs;
4891   if (extractGCDType(SrcVRegs, DstTy, SrcNarrowTy, SrcReg) != SrcNarrowTy)
4892     return UnableToLegalize;
4893 
4894   // Build new smaller bitcast instructions.
4895   // Leftover types are not supported for now, but eventually will have to be.
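  // e.g. (illustrative): bitcasting <4 x s32> to <8 x s16> with
  // NarrowTy = <4 x s16> splits the source into two <2 x s32> pieces, bitcasts
  // each piece to <4 x s16>, and merges the two results into the <8 x s16>.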
4896   for (unsigned i = 0; i < SrcVRegs.size(); i++)
4897     BitcastVRegs.push_back(
4898         MIRBuilder.buildBitcast(NarrowTy, SrcVRegs[i]).getReg(0));
4899 
4900   MIRBuilder.buildMergeLikeInstr(DstReg, BitcastVRegs);
4901   MI.eraseFromParent();
4902   return Legalized;
4903 }
4904 
4905 LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorShuffle(
4906     MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
4907   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
4908   if (TypeIdx != 0)
4909     return UnableToLegalize;
4910 
4911   auto [DstReg, DstTy, Src1Reg, Src1Ty, Src2Reg, Src2Ty] =
4912       MI.getFirst3RegLLTs();
4913   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
4914   // The shuffle should be canonicalized by now.
4915   if (DstTy != Src1Ty)
4916     return UnableToLegalize;
4917   if (DstTy != Src2Ty)
4918     return UnableToLegalize;
4919 
4920   if (!isPowerOf2_32(DstTy.getNumElements()))
4921     return UnableToLegalize;
4922 
4923   // We only support splitting a shuffle into 2, so adjust NarrowTy accordingly.
4924   // Further legalization attempts will be needed to split it further.
4925   NarrowTy =
4926       DstTy.changeElementCount(DstTy.getElementCount().divideCoefficientBy(2));
4927   unsigned NewElts = NarrowTy.getNumElements();
4928 
4929   SmallVector<Register> SplitSrc1Regs, SplitSrc2Regs;
4930   extractParts(Src1Reg, NarrowTy, 2, SplitSrc1Regs, MIRBuilder, MRI);
4931   extractParts(Src2Reg, NarrowTy, 2, SplitSrc2Regs, MIRBuilder, MRI);
4932   Register Inputs[4] = {SplitSrc1Regs[0], SplitSrc1Regs[1], SplitSrc2Regs[0],
4933                         SplitSrc2Regs[1]};
4934 
4935   Register Hi, Lo;
4936 
4937   // If Lo or Hi uses elements from at most two of the four input vectors, then
4938   // express it as a vector shuffle of those two inputs.  Otherwise extract the
4939   // input elements by hand and construct the Lo/Hi output using a BUILD_VECTOR.
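  // For example (illustrative): splitting a <4 x s32> shuffle with mask
  // <0, 4, 1, 5> gives Lo = shuffle(lo(Src1), lo(Src2), <0, 2>) and
  // Hi = shuffle(lo(Src1), lo(Src2), <1, 3>).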
4940   SmallVector<int, 16> Ops;
4941   for (unsigned High = 0; High < 2; ++High) {
4942     Register &Output = High ? Hi : Lo;
4943 
4944     // Build a shuffle mask for the output, discovering on the fly which
4945     // input vectors to use as shuffle operands (recorded in InputUsed).
4946     // If building a suitable shuffle vector proves too hard, then bail
4947     // out with UseBuildVector set.
4948     unsigned InputUsed[2] = {-1U, -1U}; // Not yet discovered.
4949     unsigned FirstMaskIdx = High * NewElts;
4950     bool UseBuildVector = false;
4951     for (unsigned MaskOffset = 0; MaskOffset < NewElts; ++MaskOffset) {
4952       // The mask element.  This indexes into the input.
4953       int Idx = Mask[FirstMaskIdx + MaskOffset];
4954 
4955       // The input vector this mask element indexes into.
4956       unsigned Input = (unsigned)Idx / NewElts;
4957 
4958       if (Input >= std::size(Inputs)) {
4959         // The mask element does not index into any input vector.
4960         Ops.push_back(-1);
4961         continue;
4962       }
4963 
4964       // Turn the index into an offset from the start of the input vector.
4965       Idx -= Input * NewElts;
4966 
4967       // Find or create a shuffle vector operand to hold this input.
4968       unsigned OpNo;
4969       for (OpNo = 0; OpNo < std::size(InputUsed); ++OpNo) {
4970         if (InputUsed[OpNo] == Input) {
4971           // This input vector is already an operand.
4972           break;
4973         } else if (InputUsed[OpNo] == -1U) {
4974           // Create a new operand for this input vector.
4975           InputUsed[OpNo] = Input;
4976           break;
4977         }
4978       }
4979 
4980       if (OpNo >= std::size(InputUsed)) {
4981         // More than two input vectors used!  Give up on trying to create a
4982         // shuffle vector.  Insert all elements into a BUILD_VECTOR instead.
4983         UseBuildVector = true;
4984         break;
4985       }
4986 
4987       // Add the mask index for the new shuffle vector.
4988       Ops.push_back(Idx + OpNo * NewElts);
4989     }
4990 
4991     if (UseBuildVector) {
4992       LLT EltTy = NarrowTy.getElementType();
4993       SmallVector<Register, 16> SVOps;
4994 
4995       // Extract the input elements by hand.
4996       for (unsigned MaskOffset = 0; MaskOffset < NewElts; ++MaskOffset) {
4997         // The mask element.  This indexes into the input.
4998         int Idx = Mask[FirstMaskIdx + MaskOffset];
4999 
5000         // The input vector this mask element indexes into.
5001         unsigned Input = (unsigned)Idx / NewElts;
5002 
5003         if (Input >= std::size(Inputs)) {
5004           // The mask element is "undef" or indexes off the end of the input.
5005           SVOps.push_back(MIRBuilder.buildUndef(EltTy).getReg(0));
5006           continue;
5007         }
5008 
5009         // Turn the index into an offset from the start of the input vector.
5010         Idx -= Input * NewElts;
5011 
5012         // Extract the vector element by hand.
5013         SVOps.push_back(MIRBuilder
5014                             .buildExtractVectorElement(
5015                                 EltTy, Inputs[Input],
5016                                 MIRBuilder.buildConstant(LLT::scalar(32), Idx))
5017                             .getReg(0));
5018       }
5019 
5020       // Construct the Lo/Hi output using a G_BUILD_VECTOR.
5021       Output = MIRBuilder.buildBuildVector(NarrowTy, SVOps).getReg(0);
5022     } else if (InputUsed[0] == -1U) {
5023       // No input vectors were used! The result is undefined.
5024       Output = MIRBuilder.buildUndef(NarrowTy).getReg(0);
5025     } else {
5026       Register Op0 = Inputs[InputUsed[0]];
5027       // If only one input was used, use an undefined vector for the other.
5028       Register Op1 = InputUsed[1] == -1U
5029                          ? MIRBuilder.buildUndef(NarrowTy).getReg(0)
5030                          : Inputs[InputUsed[1]];
5031       // At least one input vector was used. Create a new shuffle vector.
5032       Output = MIRBuilder.buildShuffleVector(NarrowTy, Op0, Op1, Ops).getReg(0);
5033     }
5034 
5035     Ops.clear();
5036   }
5037 
5038   MIRBuilder.buildConcatVectors(DstReg, {Lo, Hi});
5039   MI.eraseFromParent();
5040   return Legalized;
5041 }
5042 
5043 LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
5044     MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
5045   auto &RdxMI = cast<GVecReduce>(MI);
5046 
5047   if (TypeIdx != 1)
5048     return UnableToLegalize;
5049 
5050   // The semantics of the normal non-sequential reductions allow us to freely
5051   // re-associate the operation.
5052   auto [DstReg, DstTy, SrcReg, SrcTy] = RdxMI.getFirst2RegLLTs();
5053 
5054   if (NarrowTy.isVector() &&
5055       (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0))
5056     return UnableToLegalize;
5057 
5058   unsigned ScalarOpc = RdxMI.getScalarOpcForReduction();
5059   SmallVector<Register> SplitSrcs;
5060   // If NarrowTy is a scalar then we're being asked to scalarize.
5061   const unsigned NumParts =
5062       NarrowTy.isVector() ? SrcTy.getNumElements() / NarrowTy.getNumElements()
5063                           : SrcTy.getNumElements();
5064 
5065   extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs, MIRBuilder, MRI);
5066   if (NarrowTy.isScalar()) {
5067     if (DstTy != NarrowTy)
5068       return UnableToLegalize; // FIXME: handle implicit extensions.
5069 
5070     if (isPowerOf2_32(NumParts)) {
5071       // Generate a tree of scalar operations to reduce the critical path.
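      // For example (illustrative): with four scalar parts, an add reduction
      // becomes t0 = e0 + e1, t1 = e2 + e3, result = t0 + t1, so the critical
      // path has depth 2 instead of 3.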
5072       SmallVector<Register> PartialResults;
5073       unsigned NumPartsLeft = NumParts;
5074       while (NumPartsLeft > 1) {
5075         for (unsigned Idx = 0; Idx < NumPartsLeft - 1; Idx += 2) {
5076           PartialResults.emplace_back(
5077               MIRBuilder
5078                   .buildInstr(ScalarOpc, {NarrowTy},
5079                               {SplitSrcs[Idx], SplitSrcs[Idx + 1]})
5080                   .getReg(0));
5081         }
5082         SplitSrcs = PartialResults;
5083         PartialResults.clear();
5084         NumPartsLeft = SplitSrcs.size();
5085       }
5086       assert(SplitSrcs.size() == 1);
5087       MIRBuilder.buildCopy(DstReg, SplitSrcs[0]);
5088       MI.eraseFromParent();
5089       return Legalized;
5090     }
5091     // If we can't generate a tree, then just do sequential operations.
5092     Register Acc = SplitSrcs[0];
5093     for (unsigned Idx = 1; Idx < NumParts; ++Idx)
5094       Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[Idx]})
5095                 .getReg(0);
5096     MIRBuilder.buildCopy(DstReg, Acc);
5097     MI.eraseFromParent();
5098     return Legalized;
5099   }
5100   SmallVector<Register> PartialReductions;
5101   for (unsigned Part = 0; Part < NumParts; ++Part) {
5102     PartialReductions.push_back(
5103         MIRBuilder.buildInstr(RdxMI.getOpcode(), {DstTy}, {SplitSrcs[Part]})
5104             .getReg(0));
5105   }
5106 
5107   // If the types involved are powers of 2, we can generate intermediate vector
5108   // ops, before generating a final reduction operation.
5109   if (isPowerOf2_32(SrcTy.getNumElements()) &&
5110       isPowerOf2_32(NarrowTy.getNumElements())) {
5111     return tryNarrowPow2Reduction(MI, SrcReg, SrcTy, NarrowTy, ScalarOpc);
5112   }
5113 
5114   Register Acc = PartialReductions[0];
5115   for (unsigned Part = 1; Part < NumParts; ++Part) {
5116     if (Part == NumParts - 1) {
5117       MIRBuilder.buildInstr(ScalarOpc, {DstReg},
5118                             {Acc, PartialReductions[Part]});
5119     } else {
5120       Acc = MIRBuilder
5121                 .buildInstr(ScalarOpc, {DstTy}, {Acc, PartialReductions[Part]})
5122                 .getReg(0);
5123     }
5124   }
5125   MI.eraseFromParent();
5126   return Legalized;
5127 }
5128 
5129 LegalizerHelper::LegalizeResult
5130 LegalizerHelper::fewerElementsVectorSeqReductions(MachineInstr &MI,
5131                                                   unsigned int TypeIdx,
5132                                                   LLT NarrowTy) {
5133   auto [DstReg, DstTy, ScalarReg, ScalarTy, SrcReg, SrcTy] =
5134       MI.getFirst3RegLLTs();
5135   if (!NarrowTy.isScalar() || TypeIdx != 2 || DstTy != ScalarTy ||
5136       DstTy != NarrowTy)
5137     return UnableToLegalize;
5138 
5139   assert((MI.getOpcode() == TargetOpcode::G_VECREDUCE_SEQ_FADD ||
5140           MI.getOpcode() == TargetOpcode::G_VECREDUCE_SEQ_FMUL) &&
5141          "Unexpected vecreduce opcode");
5142   unsigned ScalarOpc = MI.getOpcode() == TargetOpcode::G_VECREDUCE_SEQ_FADD
5143                            ? TargetOpcode::G_FADD
5144                            : TargetOpcode::G_FMUL;
5145 
5146   SmallVector<Register> SplitSrcs;
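  // Sequential reductions must preserve the evaluation order, so the
  // expansion is a strict left-to-right chain (illustrative):
  //   Acc = ((ScalarReg op e0) op e1) ... op e[N-1]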
5147   unsigned NumParts = SrcTy.getNumElements();
5148   extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs, MIRBuilder, MRI);
5149   Register Acc = ScalarReg;
5150   for (unsigned i = 0; i < NumParts; i++)
5151     Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[i]})
5152               .getReg(0);
5153 
5154   MIRBuilder.buildCopy(DstReg, Acc);
5155   MI.eraseFromParent();
5156   return Legalized;
5157 }
5158 
5159 LegalizerHelper::LegalizeResult
5160 LegalizerHelper::tryNarrowPow2Reduction(MachineInstr &MI, Register SrcReg,
5161                                         LLT SrcTy, LLT NarrowTy,
5162                                         unsigned ScalarOpc) {
5163   SmallVector<Register> SplitSrcs;
5164   // Split the sources into NarrowTy size pieces.
5165   extractParts(SrcReg, NarrowTy,
5166                SrcTy.getNumElements() / NarrowTy.getNumElements(), SplitSrcs,
5167                MIRBuilder, MRI);
5168   // We're going to do a tree reduction using vector operations until we have
5169   // one NarrowTy size value left.
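  // For example (illustrative): reducing <8 x s32> with NarrowTy = <2 x s32>
  // starts from four pieces; two vector ops produce two partial results, one
  // more produces one, and the final reduction then runs on a single <2 x s32>.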
5170   while (SplitSrcs.size() > 1) {
5171     SmallVector<Register> PartialRdxs;
5172     for (unsigned Idx = 0; Idx < SplitSrcs.size()-1; Idx += 2) {
5173       Register LHS = SplitSrcs[Idx];
5174       Register RHS = SplitSrcs[Idx + 1];
5175       // Create the intermediate vector op.
5176       Register Res =
5177           MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {LHS, RHS}).getReg(0);
5178       PartialRdxs.push_back(Res);
5179     }
5180     SplitSrcs = std::move(PartialRdxs);
5181   }
5182   // Finally generate the requested NarrowTy based reduction.
5183   Observer.changingInstr(MI);
5184   MI.getOperand(1).setReg(SplitSrcs[0]);
5185   Observer.changedInstr(MI);
5186   return Legalized;
5187 }
5188 
5189 LegalizerHelper::LegalizeResult
5190 LegalizerHelper::narrowScalarShiftByConstant(MachineInstr &MI, const APInt &Amt,
5191                                              const LLT HalfTy, const LLT AmtTy) {
5192 
5193   Register InL = MRI.createGenericVirtualRegister(HalfTy);
5194   Register InH = MRI.createGenericVirtualRegister(HalfTy);
5195   MIRBuilder.buildUnmerge({InL, InH}, MI.getOperand(1));
5196 
5197   if (Amt.isZero()) {
5198     MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), {InL, InH});
5199     MI.eraseFromParent();
5200     return Legalized;
5201   }
5202 
5203   LLT NVT = HalfTy;
5204   unsigned NVTBits = HalfTy.getSizeInBits();
5205   unsigned VTBits = 2 * NVTBits;
5206 
5207   SrcOp Lo(Register(0)), Hi(Register(0));
5208   if (MI.getOpcode() == TargetOpcode::G_SHL) {
5209     if (Amt.ugt(VTBits)) {
5210       Lo = Hi = MIRBuilder.buildConstant(NVT, 0);
5211     } else if (Amt.ugt(NVTBits)) {
5212       Lo = MIRBuilder.buildConstant(NVT, 0);
5213       Hi = MIRBuilder.buildShl(NVT, InL,
5214                                MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
5215     } else if (Amt == NVTBits) {
5216       Lo = MIRBuilder.buildConstant(NVT, 0);
5217       Hi = InL;
5218     } else {
5219       Lo = MIRBuilder.buildShl(NVT, InL, MIRBuilder.buildConstant(AmtTy, Amt));
5220       auto OrLHS =
5221           MIRBuilder.buildShl(NVT, InH, MIRBuilder.buildConstant(AmtTy, Amt));
5222       auto OrRHS = MIRBuilder.buildLShr(
5223           NVT, InL, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits));
5224       Hi = MIRBuilder.buildOr(NVT, OrLHS, OrRHS);
5225     }
5226   } else if (MI.getOpcode() == TargetOpcode::G_LSHR) {
5227     if (Amt.ugt(VTBits)) {
5228       Lo = Hi = MIRBuilder.buildConstant(NVT, 0);
5229     } else if (Amt.ugt(NVTBits)) {
5230       Lo = MIRBuilder.buildLShr(NVT, InH,
5231                                 MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
5232       Hi = MIRBuilder.buildConstant(NVT, 0);
5233     } else if (Amt == NVTBits) {
5234       Lo = InH;
5235       Hi = MIRBuilder.buildConstant(NVT, 0);
5236     } else {
5237       auto ShiftAmtConst = MIRBuilder.buildConstant(AmtTy, Amt);
5238 
5239       auto OrLHS = MIRBuilder.buildLShr(NVT, InL, ShiftAmtConst);
5240       auto OrRHS = MIRBuilder.buildShl(
5241           NVT, InH, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits));
5242 
5243       Lo = MIRBuilder.buildOr(NVT, OrLHS, OrRHS);
5244       Hi = MIRBuilder.buildLShr(NVT, InH, ShiftAmtConst);
5245     }
5246   } else {
5247     if (Amt.ugt(VTBits)) {
5248       Hi = Lo = MIRBuilder.buildAShr(
5249           NVT, InH, MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
5250     } else if (Amt.ugt(NVTBits)) {
5251       Lo = MIRBuilder.buildAShr(NVT, InH,
5252                                 MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
5253       Hi = MIRBuilder.buildAShr(NVT, InH,
5254                                 MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
5255     } else if (Amt == NVTBits) {
5256       Lo = InH;
5257       Hi = MIRBuilder.buildAShr(NVT, InH,
5258                                 MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
5259     } else {
5260       auto ShiftAmtConst = MIRBuilder.buildConstant(AmtTy, Amt);
5261 
5262       auto OrLHS = MIRBuilder.buildLShr(NVT, InL, ShiftAmtConst);
5263       auto OrRHS = MIRBuilder.buildShl(
5264           NVT, InH, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits));
5265 
5266       Lo = MIRBuilder.buildOr(NVT, OrLHS, OrRHS);
5267       Hi = MIRBuilder.buildAShr(NVT, InH, ShiftAmtConst);
5268     }
5269   }
5270 
5271   MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), {Lo, Hi});
5272   MI.eraseFromParent();
5273 
5274   return Legalized;
5275 }
5276 
5277 // TODO: Optimize if constant shift amount.
5278 LegalizerHelper::LegalizeResult
5279 LegalizerHelper::narrowScalarShift(MachineInstr &MI, unsigned TypeIdx,
5280                                    LLT RequestedTy) {
5281   if (TypeIdx == 1) {
5282     Observer.changingInstr(MI);
5283     narrowScalarSrc(MI, RequestedTy, 2);
5284     Observer.changedInstr(MI);
5285     return Legalized;
5286   }
5287 
5288   Register DstReg = MI.getOperand(0).getReg();
5289   LLT DstTy = MRI.getType(DstReg);
5290   if (DstTy.isVector())
5291     return UnableToLegalize;
5292 
5293   Register Amt = MI.getOperand(2).getReg();
5294   LLT ShiftAmtTy = MRI.getType(Amt);
5295   const unsigned DstEltSize = DstTy.getScalarSizeInBits();
5296   if (DstEltSize % 2 != 0)
5297     return UnableToLegalize;
5298 
5299   // Ignore the input type. We can only go to exactly half the size of the
5300   // input. If that isn't small enough, the resulting pieces will be further
5301   // legalized.
5302   const unsigned NewBitSize = DstEltSize / 2;
5303   const LLT HalfTy = LLT::scalar(NewBitSize);
5304   const LLT CondTy = LLT::scalar(1);
5305 
5306   if (auto VRegAndVal = getIConstantVRegValWithLookThrough(Amt, MRI)) {
5307     return narrowScalarShiftByConstant(MI, VRegAndVal->Value, HalfTy,
5308                                        ShiftAmtTy);
5309   }
5310 
5311   // TODO: Expand with known bits.
5312 
5313   // Handle the fully general expansion by an unknown amount.
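  // A sketch of the G_SHL result built below (illustrative), with
  // W = NewBitSize:
  //   Lo = IsShort ? (InL << Amt) : 0
  //   Hi = IsZero  ? InH
  //                : IsShort ? ((InH << Amt) | (InL >> (W - Amt)))
  //                          : (InL << (Amt - W))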
5314   auto NewBits = MIRBuilder.buildConstant(ShiftAmtTy, NewBitSize);
5315 
5316   Register InL = MRI.createGenericVirtualRegister(HalfTy);
5317   Register InH = MRI.createGenericVirtualRegister(HalfTy);
5318   MIRBuilder.buildUnmerge({InL, InH}, MI.getOperand(1));
5319 
5320   auto AmtExcess = MIRBuilder.buildSub(ShiftAmtTy, Amt, NewBits);
5321   auto AmtLack = MIRBuilder.buildSub(ShiftAmtTy, NewBits, Amt);
5322 
5323   auto Zero = MIRBuilder.buildConstant(ShiftAmtTy, 0);
5324   auto IsShort = MIRBuilder.buildICmp(ICmpInst::ICMP_ULT, CondTy, Amt, NewBits);
5325   auto IsZero = MIRBuilder.buildICmp(ICmpInst::ICMP_EQ, CondTy, Amt, Zero);
5326 
5327   Register ResultRegs[2];
5328   switch (MI.getOpcode()) {
5329   case TargetOpcode::G_SHL: {
5330     // Short: ShAmt < NewBitSize
5331     auto LoS = MIRBuilder.buildShl(HalfTy, InL, Amt);
5332 
5333     auto LoOr = MIRBuilder.buildLShr(HalfTy, InL, AmtLack);
5334     auto HiOr = MIRBuilder.buildShl(HalfTy, InH, Amt);
5335     auto HiS = MIRBuilder.buildOr(HalfTy, LoOr, HiOr);
5336 
5337     // Long: ShAmt >= NewBitSize
5338     auto LoL = MIRBuilder.buildConstant(HalfTy, 0);         // Lo part is zero.
5339     auto HiL = MIRBuilder.buildShl(HalfTy, InL, AmtExcess); // Hi from Lo part.
5340 
5341     auto Lo = MIRBuilder.buildSelect(HalfTy, IsShort, LoS, LoL);
5342     auto Hi = MIRBuilder.buildSelect(
5343         HalfTy, IsZero, InH, MIRBuilder.buildSelect(HalfTy, IsShort, HiS, HiL));
5344 
5345     ResultRegs[0] = Lo.getReg(0);
5346     ResultRegs[1] = Hi.getReg(0);
5347     break;
5348   }
5349   case TargetOpcode::G_LSHR:
5350   case TargetOpcode::G_ASHR: {
5351     // Short: ShAmt < NewBitSize
5352     auto HiS = MIRBuilder.buildInstr(MI.getOpcode(), {HalfTy}, {InH, Amt});
5353 
5354     auto LoOr = MIRBuilder.buildLShr(HalfTy, InL, Amt);
5355     auto HiOr = MIRBuilder.buildShl(HalfTy, InH, AmtLack);
5356     auto LoS = MIRBuilder.buildOr(HalfTy, LoOr, HiOr);
5357 
5358     // Long: ShAmt >= NewBitSize
5359     MachineInstrBuilder HiL;
5360     if (MI.getOpcode() == TargetOpcode::G_LSHR) {
5361       HiL = MIRBuilder.buildConstant(HalfTy, 0);            // Hi part is zero.
5362     } else {
5363       auto ShiftAmt = MIRBuilder.buildConstant(ShiftAmtTy, NewBitSize - 1);
5364       HiL = MIRBuilder.buildAShr(HalfTy, InH, ShiftAmt);    // Sign of Hi part.
5365     }
5366     auto LoL = MIRBuilder.buildInstr(MI.getOpcode(), {HalfTy},
5367                                      {InH, AmtExcess});     // Lo from Hi part.
5368 
5369     auto Lo = MIRBuilder.buildSelect(
5370         HalfTy, IsZero, InL, MIRBuilder.buildSelect(HalfTy, IsShort, LoS, LoL));
5371 
5372     auto Hi = MIRBuilder.buildSelect(HalfTy, IsShort, HiS, HiL);
5373 
5374     ResultRegs[0] = Lo.getReg(0);
5375     ResultRegs[1] = Hi.getReg(0);
5376     break;
5377   }
5378   default:
5379     llvm_unreachable("not a shift");
5380   }
5381 
5382   MIRBuilder.buildMergeLikeInstr(DstReg, ResultRegs);
5383   MI.eraseFromParent();
5384   return Legalized;
5385 }
5386 
5387 LegalizerHelper::LegalizeResult
5388 LegalizerHelper::moreElementsVectorPhi(MachineInstr &MI, unsigned TypeIdx,
5389                                        LLT MoreTy) {
5390   assert(TypeIdx == 0 && "Expecting only Idx 0");
5391 
5392   Observer.changingInstr(MI);
5393   for (unsigned I = 1, E = MI.getNumOperands(); I != E; I += 2) {
5394     MachineBasicBlock &OpMBB = *MI.getOperand(I + 1).getMBB();
5395     MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminator());
5396     moreElementsVectorSrc(MI, MoreTy, I);
5397   }
5398 
5399   MachineBasicBlock &MBB = *MI.getParent();
5400   MIRBuilder.setInsertPt(MBB, --MBB.getFirstNonPHI());
5401   moreElementsVectorDst(MI, MoreTy, 0);
5402   Observer.changedInstr(MI);
5403   return Legalized;
5404 }
5405 
5406 MachineInstrBuilder LegalizerHelper::getNeutralElementForVecReduce(
5407     unsigned Opcode, MachineIRBuilder &MIRBuilder, LLT Ty) {
5408   assert(Ty.isScalar() && "Expected a scalar type to make the neutral element");
5409 
5410   switch (Opcode) {
5411   default:
5412     llvm_unreachable(
5413         "getNeutralElementForVecReduce called with invalid opcode!");
5414   case TargetOpcode::G_VECREDUCE_ADD:
5415   case TargetOpcode::G_VECREDUCE_OR:
5416   case TargetOpcode::G_VECREDUCE_XOR:
5417   case TargetOpcode::G_VECREDUCE_UMAX:
5418     return MIRBuilder.buildConstant(Ty, 0);
5419   case TargetOpcode::G_VECREDUCE_MUL:
5420     return MIRBuilder.buildConstant(Ty, 1);
5421   case TargetOpcode::G_VECREDUCE_AND:
5422   case TargetOpcode::G_VECREDUCE_UMIN:
5423     return MIRBuilder.buildConstant(
5424         Ty, APInt::getAllOnes(Ty.getScalarSizeInBits()));
5425   case TargetOpcode::G_VECREDUCE_SMAX:
5426     return MIRBuilder.buildConstant(
5427         Ty, APInt::getSignedMinValue(Ty.getSizeInBits()));
5428   case TargetOpcode::G_VECREDUCE_SMIN:
5429     return MIRBuilder.buildConstant(
5430         Ty, APInt::getSignedMaxValue(Ty.getSizeInBits()));
5431   case TargetOpcode::G_VECREDUCE_FADD:
5432     return MIRBuilder.buildFConstant(Ty, -0.0);
5433   case TargetOpcode::G_VECREDUCE_FMUL:
5434     return MIRBuilder.buildFConstant(Ty, 1.0);
5435   case TargetOpcode::G_VECREDUCE_FMINIMUM:
5436   case TargetOpcode::G_VECREDUCE_FMAXIMUM:
5437     assert(false && "getNeutralElementForVecReduce unimplemented for "
5438                     "G_VECREDUCE_FMINIMUM and G_VECREDUCE_FMAXIMUM!");
5439   }
5440   llvm_unreachable("switch expected to return!");
5441 }
5442 
5443 LegalizerHelper::LegalizeResult
5444 LegalizerHelper::moreElementsVector(MachineInstr &MI, unsigned TypeIdx,
5445                                     LLT MoreTy) {
5446   unsigned Opc = MI.getOpcode();
5447   switch (Opc) {
5448   case TargetOpcode::G_IMPLICIT_DEF:
5449   case TargetOpcode::G_LOAD: {
5450     if (TypeIdx != 0)
5451       return UnableToLegalize;
5452     Observer.changingInstr(MI);
5453     moreElementsVectorDst(MI, MoreTy, 0);
5454     Observer.changedInstr(MI);
5455     return Legalized;
5456   }
5457   case TargetOpcode::G_STORE:
5458     if (TypeIdx != 0)
5459       return UnableToLegalize;
5460     Observer.changingInstr(MI);
5461     moreElementsVectorSrc(MI, MoreTy, 0);
5462     Observer.changedInstr(MI);
5463     return Legalized;
5464   case TargetOpcode::G_AND:
5465   case TargetOpcode::G_OR:
5466   case TargetOpcode::G_XOR:
5467   case TargetOpcode::G_ADD:
5468   case TargetOpcode::G_SUB:
5469   case TargetOpcode::G_MUL:
5470   case TargetOpcode::G_FADD:
5471   case TargetOpcode::G_FSUB:
5472   case TargetOpcode::G_FMUL:
5473   case TargetOpcode::G_FDIV:
5474   case TargetOpcode::G_FCOPYSIGN:
5475   case TargetOpcode::G_UADDSAT:
5476   case TargetOpcode::G_USUBSAT:
5477   case TargetOpcode::G_SADDSAT:
5478   case TargetOpcode::G_SSUBSAT:
5479   case TargetOpcode::G_SMIN:
5480   case TargetOpcode::G_SMAX:
5481   case TargetOpcode::G_UMIN:
5482   case TargetOpcode::G_UMAX:
5483   case TargetOpcode::G_FMINNUM:
5484   case TargetOpcode::G_FMAXNUM:
5485   case TargetOpcode::G_FMINNUM_IEEE:
5486   case TargetOpcode::G_FMAXNUM_IEEE:
5487   case TargetOpcode::G_FMINIMUM:
5488   case TargetOpcode::G_FMAXIMUM:
5489   case TargetOpcode::G_STRICT_FADD:
5490   case TargetOpcode::G_STRICT_FSUB:
5491   case TargetOpcode::G_STRICT_FMUL:
5492   case TargetOpcode::G_SHL:
5493   case TargetOpcode::G_ASHR:
5494   case TargetOpcode::G_LSHR: {
5495     Observer.changingInstr(MI);
5496     moreElementsVectorSrc(MI, MoreTy, 1);
5497     moreElementsVectorSrc(MI, MoreTy, 2);
5498     moreElementsVectorDst(MI, MoreTy, 0);
5499     Observer.changedInstr(MI);
5500     return Legalized;
5501   }
5502   case TargetOpcode::G_FMA:
5503   case TargetOpcode::G_STRICT_FMA:
5504   case TargetOpcode::G_FSHR:
5505   case TargetOpcode::G_FSHL: {
5506     Observer.changingInstr(MI);
5507     moreElementsVectorSrc(MI, MoreTy, 1);
5508     moreElementsVectorSrc(MI, MoreTy, 2);
5509     moreElementsVectorSrc(MI, MoreTy, 3);
5510     moreElementsVectorDst(MI, MoreTy, 0);
5511     Observer.changedInstr(MI);
5512     return Legalized;
5513   }
5514   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
5515   case TargetOpcode::G_EXTRACT:
5516     if (TypeIdx != 1)
5517       return UnableToLegalize;
5518     Observer.changingInstr(MI);
5519     moreElementsVectorSrc(MI, MoreTy, 1);
5520     Observer.changedInstr(MI);
5521     return Legalized;
5522   case TargetOpcode::G_INSERT:
5523   case TargetOpcode::G_INSERT_VECTOR_ELT:
5524   case TargetOpcode::G_FREEZE:
5525   case TargetOpcode::G_FNEG:
5526   case TargetOpcode::G_FABS:
5527   case TargetOpcode::G_FSQRT:
5528   case TargetOpcode::G_FCEIL:
5529   case TargetOpcode::G_FFLOOR:
5530   case TargetOpcode::G_FNEARBYINT:
5531   case TargetOpcode::G_FRINT:
5532   case TargetOpcode::G_INTRINSIC_ROUND:
5533   case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
5534   case TargetOpcode::G_INTRINSIC_TRUNC:
5535   case TargetOpcode::G_BSWAP:
5536   case TargetOpcode::G_FCANONICALIZE:
5537   case TargetOpcode::G_SEXT_INREG:
5538   case TargetOpcode::G_ABS:
5539     if (TypeIdx != 0)
5540       return UnableToLegalize;
5541     Observer.changingInstr(MI);
5542     moreElementsVectorSrc(MI, MoreTy, 1);
5543     moreElementsVectorDst(MI, MoreTy, 0);
5544     Observer.changedInstr(MI);
5545     return Legalized;
5546   case TargetOpcode::G_SELECT: {
5547     auto [DstReg, DstTy, CondReg, CondTy] = MI.getFirst2RegLLTs();
5548     if (TypeIdx == 1) {
5549       if (!CondTy.isScalar() ||
5550           DstTy.getElementCount() != MoreTy.getElementCount())
5551         return UnableToLegalize;
5552 
5553       // This is turning a scalar select of vectors into a vector
5554       // select. Broadcast the select condition.
5555       auto ShufSplat = MIRBuilder.buildShuffleSplat(MoreTy, CondReg);
5556       Observer.changingInstr(MI);
5557       MI.getOperand(1).setReg(ShufSplat.getReg(0));
5558       Observer.changedInstr(MI);
5559       return Legalized;
5560     }
5561 
5562     if (CondTy.isVector())
5563       return UnableToLegalize;
5564 
5565     Observer.changingInstr(MI);
5566     moreElementsVectorSrc(MI, MoreTy, 2);
5567     moreElementsVectorSrc(MI, MoreTy, 3);
5568     moreElementsVectorDst(MI, MoreTy, 0);
5569     Observer.changedInstr(MI);
5570     return Legalized;
5571   }
5572   case TargetOpcode::G_UNMERGE_VALUES:
5573     return UnableToLegalize;
5574   case TargetOpcode::G_PHI:
5575     return moreElementsVectorPhi(MI, TypeIdx, MoreTy);
5576   case TargetOpcode::G_SHUFFLE_VECTOR:
5577     return moreElementsVectorShuffle(MI, TypeIdx, MoreTy);
5578   case TargetOpcode::G_BUILD_VECTOR: {
5579     SmallVector<SrcOp, 8> Elts;
5580     for (auto Op : MI.uses()) {
5581       Elts.push_back(Op.getReg());
5582     }
5583 
5584     for (unsigned i = Elts.size(); i < MoreTy.getNumElements(); ++i) {
5585       Elts.push_back(MIRBuilder.buildUndef(MoreTy.getScalarType()));
5586     }
5587 
5588     MIRBuilder.buildDeleteTrailingVectorElements(
5589         MI.getOperand(0).getReg(), MIRBuilder.buildInstr(Opc, {MoreTy}, Elts));
5590     MI.eraseFromParent();
5591     return Legalized;
5592   }
5593   case TargetOpcode::G_SEXT:
5594   case TargetOpcode::G_ZEXT:
5595   case TargetOpcode::G_ANYEXT:
5596   case TargetOpcode::G_TRUNC:
5597   case TargetOpcode::G_FPTRUNC:
5598   case TargetOpcode::G_FPEXT:
5599   case TargetOpcode::G_FPTOSI:
5600   case TargetOpcode::G_FPTOUI:
5601   case TargetOpcode::G_SITOFP:
5602   case TargetOpcode::G_UITOFP: {
5603     Observer.changingInstr(MI);
5604     LLT SrcExtTy;
5605     LLT DstExtTy;
5606     if (TypeIdx == 0) {
5607       DstExtTy = MoreTy;
5608       SrcExtTy = LLT::fixed_vector(
5609           MoreTy.getNumElements(),
5610           MRI.getType(MI.getOperand(1).getReg()).getElementType());
5611     } else {
5612       DstExtTy = LLT::fixed_vector(
5613           MoreTy.getNumElements(),
5614           MRI.getType(MI.getOperand(0).getReg()).getElementType());
5615       SrcExtTy = MoreTy;
5616     }
5617     moreElementsVectorSrc(MI, SrcExtTy, 1);
5618     moreElementsVectorDst(MI, DstExtTy, 0);
5619     Observer.changedInstr(MI);
5620     return Legalized;
5621   }
5622   case TargetOpcode::G_ICMP:
5623   case TargetOpcode::G_FCMP: {
5624     if (TypeIdx != 1)
5625       return UnableToLegalize;
5626 
5627     Observer.changingInstr(MI);
5628     moreElementsVectorSrc(MI, MoreTy, 2);
5629     moreElementsVectorSrc(MI, MoreTy, 3);
5630     LLT CondTy = LLT::fixed_vector(
5631         MoreTy.getNumElements(),
5632         MRI.getType(MI.getOperand(0).getReg()).getElementType());
5633     moreElementsVectorDst(MI, CondTy, 0);
5634     Observer.changedInstr(MI);
5635     return Legalized;
5636   }
5637   case TargetOpcode::G_BITCAST: {
5638     if (TypeIdx != 0)
5639       return UnableToLegalize;
5640 
5641     LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
5642     LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
5643 
5644     unsigned coefficient = SrcTy.getNumElements() * MoreTy.getNumElements();
5645     if (coefficient % DstTy.getNumElements() != 0)
5646       return UnableToLegalize;
5647 
5648     coefficient = coefficient / DstTy.getNumElements();
5649 
5650     LLT NewTy = SrcTy.changeElementCount(
5651         ElementCount::get(coefficient, MoreTy.isScalable()));
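         // E.g. bitcasting <4 x s16> to <2 x s32> while widening the result
         // to <4 x s32>: the source is widened to <8 x s16>
         // (coefficient = 4 * 4 / 2), so both sides keep the same bit width.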
5652     Observer.changingInstr(MI);
5653     moreElementsVectorSrc(MI, NewTy, 1);
5654     moreElementsVectorDst(MI, MoreTy, 0);
5655     Observer.changedInstr(MI);
5656     return Legalized;
5657   }
5658   case TargetOpcode::G_VECREDUCE_FADD:
5659   case TargetOpcode::G_VECREDUCE_FMUL:
5660   case TargetOpcode::G_VECREDUCE_ADD:
5661   case TargetOpcode::G_VECREDUCE_MUL:
5662   case TargetOpcode::G_VECREDUCE_AND:
5663   case TargetOpcode::G_VECREDUCE_OR:
5664   case TargetOpcode::G_VECREDUCE_XOR:
5665   case TargetOpcode::G_VECREDUCE_SMAX:
5666   case TargetOpcode::G_VECREDUCE_SMIN:
5667   case TargetOpcode::G_VECREDUCE_UMAX:
5668   case TargetOpcode::G_VECREDUCE_UMIN: {
5669     LLT OrigTy = MRI.getType(MI.getOperand(1).getReg());
5670     MachineOperand &MO = MI.getOperand(1);
5671     auto NewVec = MIRBuilder.buildPadVectorWithUndefElements(MoreTy, MO);
5672     auto NeutralElement = getNeutralElementForVecReduce(
5673         MI.getOpcode(), MIRBuilder, MoreTy.getElementType());
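         // E.g. widening G_VECREDUCE_ADD of <3 x s32> to <4 x s32> fills the
         // extra lane with 0; G_VECREDUCE_AND pads with all-ones and
         // G_VECREDUCE_SMAX with the signed minimum, leaving the reduction
         // result unchanged.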
5674 
5675     LLT IdxTy(TLI.getVectorIdxTy(MIRBuilder.getDataLayout()));
5676     for (size_t i = OrigTy.getNumElements(), e = MoreTy.getNumElements();
5677          i != e; i++) {
5678       auto Idx = MIRBuilder.buildConstant(IdxTy, i);
5679       NewVec = MIRBuilder.buildInsertVectorElement(MoreTy, NewVec,
5680                                                    NeutralElement, Idx);
5681     }
5682 
5683     Observer.changingInstr(MI);
5684     MO.setReg(NewVec.getReg(0));
5685     Observer.changedInstr(MI);
5686     return Legalized;
5687   }
5688 
5689   default:
5690     return UnableToLegalize;
5691   }
5692 }
5693 
5694 LegalizerHelper::LegalizeResult
5695 LegalizerHelper::equalizeVectorShuffleLengths(MachineInstr &MI) {
5696   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
5697   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
5698   unsigned MaskNumElts = Mask.size();
5699   unsigned SrcNumElts = SrcTy.getNumElements();
5700   LLT DestEltTy = DstTy.getElementType();
5701 
5702   if (MaskNumElts == SrcNumElts)
5703     return Legalized;
5704 
5705   if (MaskNumElts < SrcNumElts) {
5706     // Extend mask to match new destination vector size with
5707     // undef values.
5708     SmallVector<int, 16> NewMask(Mask);
5709     for (unsigned I = MaskNumElts; I < SrcNumElts; ++I)
5710       NewMask.push_back(-1);
5711 
5712     moreElementsVectorDst(MI, SrcTy, 0);
5713     MIRBuilder.setInstrAndDebugLoc(MI);
5714     MIRBuilder.buildShuffleVector(MI.getOperand(0).getReg(),
5715                                   MI.getOperand(1).getReg(),
5716                                   MI.getOperand(2).getReg(), NewMask);
5717     MI.eraseFromParent();
5718 
5719     return Legalized;
5720   }
5721 
5722   unsigned PaddedMaskNumElts = alignTo(MaskNumElts, SrcNumElts);
5723   unsigned NumConcat = PaddedMaskNumElts / SrcNumElts;
5724   LLT PaddedTy = LLT::fixed_vector(PaddedMaskNumElts, DestEltTy);
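       // E.g. a 6-element mask over 4-element sources is padded to 8 lanes
       // (NumConcat = 2): each source is concatenated with an undef vector,
       // second-source indices are shifted up by 4, and the first 6 lanes of
       // the shuffled result are extracted.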
5725 
5726   // Create new source vectors by concatenating the initial
5727   // source vectors with undefined vectors of the same size.
5728   auto Undef = MIRBuilder.buildUndef(SrcTy);
5729   SmallVector<Register, 8> MOps1(NumConcat, Undef.getReg(0));
5730   SmallVector<Register, 8> MOps2(NumConcat, Undef.getReg(0));
5731   MOps1[0] = MI.getOperand(1).getReg();
5732   MOps2[0] = MI.getOperand(2).getReg();
5733 
5734   auto Src1 = MIRBuilder.buildConcatVectors(PaddedTy, MOps1);
5735   auto Src2 = MIRBuilder.buildConcatVectors(PaddedTy, MOps2);
5736 
5737   // Readjust mask for new input vector length.
5738   SmallVector<int, 8> MappedOps(PaddedMaskNumElts, -1);
5739   for (unsigned I = 0; I != MaskNumElts; ++I) {
5740     int Idx = Mask[I];
5741     if (Idx >= static_cast<int>(SrcNumElts))
5742       Idx += PaddedMaskNumElts - SrcNumElts;
5743     MappedOps[I] = Idx;
5744   }
5745 
5746   // If we got more elements than required, extract subvector.
5747   if (MaskNumElts != PaddedMaskNumElts) {
5748     auto Shuffle =
5749         MIRBuilder.buildShuffleVector(PaddedTy, Src1, Src2, MappedOps);
5750 
5751     SmallVector<Register, 16> Elts(MaskNumElts);
5752     for (unsigned I = 0; I < MaskNumElts; ++I) {
5753       Elts[I] =
5754           MIRBuilder.buildExtractVectorElementConstant(DestEltTy, Shuffle, I)
5755               .getReg(0);
5756     }
5757     MIRBuilder.buildBuildVector(DstReg, Elts);
5758   } else {
5759     MIRBuilder.buildShuffleVector(DstReg, Src1, Src2, MappedOps);
5760   }
5761 
5762   MI.eraseFromParent();
5763   return LegalizerHelper::LegalizeResult::Legalized;
5764 }
5765 
5766 LegalizerHelper::LegalizeResult
5767 LegalizerHelper::moreElementsVectorShuffle(MachineInstr &MI,
5768                                            unsigned int TypeIdx, LLT MoreTy) {
5769   auto [DstTy, Src1Ty, Src2Ty] = MI.getFirst3LLTs();
5770   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
5771   unsigned NumElts = DstTy.getNumElements();
5772   unsigned WidenNumElts = MoreTy.getNumElements();
5773 
5774   if (DstTy.isVector() && Src1Ty.isVector() &&
5775       DstTy.getNumElements() != Src1Ty.getNumElements()) {
5776     return equalizeVectorShuffleLengths(MI);
5777   }
5778 
5779   if (TypeIdx != 0)
5780     return UnableToLegalize;
5781 
5782   // Expect a canonicalized shuffle.
5783   if (DstTy != Src1Ty || DstTy != Src2Ty)
5784     return UnableToLegalize;
5785 
5786   moreElementsVectorSrc(MI, MoreTy, 1);
5787   moreElementsVectorSrc(MI, MoreTy, 2);
5788 
5789   // Adjust mask based on new input vector length.
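       // E.g. widening a <2 x ...> shuffle to <4 x ...>: mask [1,2] becomes
       // [1,4,-1,-1], since indices into the second source move up by
       // (WidenNumElts - NumElts) and the new trailing lanes are undef.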
5790   SmallVector<int, 16> NewMask;
5791   for (unsigned I = 0; I != NumElts; ++I) {
5792     int Idx = Mask[I];
5793     if (Idx < static_cast<int>(NumElts))
5794       NewMask.push_back(Idx);
5795     else
5796       NewMask.push_back(Idx - NumElts + WidenNumElts);
5797   }
5798   for (unsigned I = NumElts; I != WidenNumElts; ++I)
5799     NewMask.push_back(-1);
5800   moreElementsVectorDst(MI, MoreTy, 0);
5801   MIRBuilder.setInstrAndDebugLoc(MI);
5802   MIRBuilder.buildShuffleVector(MI.getOperand(0).getReg(),
5803                                 MI.getOperand(1).getReg(),
5804                                 MI.getOperand(2).getReg(), NewMask);
5805   MI.eraseFromParent();
5806   return Legalized;
5807 }
5808 
5809 void LegalizerHelper::multiplyRegisters(SmallVectorImpl<Register> &DstRegs,
5810                                         ArrayRef<Register> Src1Regs,
5811                                         ArrayRef<Register> Src2Regs,
5812                                         LLT NarrowTy) {
5813   MachineIRBuilder &B = MIRBuilder;
5814   unsigned SrcParts = Src1Regs.size();
5815   unsigned DstParts = DstRegs.size();
5816 
5817   unsigned DstIdx = 0; // Low bits of the result.
5818   Register FactorSum =
5819       B.buildMul(NarrowTy, Src1Regs[DstIdx], Src2Regs[DstIdx]).getReg(0);
5820   DstRegs[DstIdx] = FactorSum;
5821 
5822   unsigned CarrySumPrevDstIdx;
5823   SmallVector<Register, 4> Factors;
5824 
5825   for (DstIdx = 1; DstIdx < DstParts; DstIdx++) {
5826     // Collect low parts of muls for DstIdx.
5827     for (unsigned i = DstIdx + 1 < SrcParts ? 0 : DstIdx - SrcParts + 1;
5828          i <= std::min(DstIdx, SrcParts - 1); ++i) {
5829       MachineInstrBuilder Mul =
5830           B.buildMul(NarrowTy, Src1Regs[DstIdx - i], Src2Regs[i]);
5831       Factors.push_back(Mul.getReg(0));
5832     }
5833     // Collect high parts of muls from previous DstIdx.
5834     for (unsigned i = DstIdx < SrcParts ? 0 : DstIdx - SrcParts;
5835          i <= std::min(DstIdx - 1, SrcParts - 1); ++i) {
5836       MachineInstrBuilder Umulh =
5837           B.buildUMulH(NarrowTy, Src1Regs[DstIdx - 1 - i], Src2Regs[i]);
5838       Factors.push_back(Umulh.getReg(0));
5839     }
5840     // Add CarrySum from additions calculated for previous DstIdx.
5841     if (DstIdx != 1) {
5842       Factors.push_back(CarrySumPrevDstIdx);
5843     }
5844 
5845     Register CarrySum;
5846     // Add all factors and accumulate all carries into CarrySum.
5847     if (DstIdx != DstParts - 1) {
5848       MachineInstrBuilder Uaddo =
5849           B.buildUAddo(NarrowTy, LLT::scalar(1), Factors[0], Factors[1]);
5850       FactorSum = Uaddo.getReg(0);
5851       CarrySum = B.buildZExt(NarrowTy, Uaddo.getReg(1)).getReg(0);
5852       for (unsigned i = 2; i < Factors.size(); ++i) {
5853         MachineInstrBuilder Uaddo =
5854             B.buildUAddo(NarrowTy, LLT::scalar(1), FactorSum, Factors[i]);
5855         FactorSum = Uaddo.getReg(0);
5856         MachineInstrBuilder Carry = B.buildZExt(NarrowTy, Uaddo.getReg(1));
5857         CarrySum = B.buildAdd(NarrowTy, CarrySum, Carry).getReg(0);
5858       }
5859     } else {
5860       // Since the value for the next index is not calculated, neither is CarrySum.
5861       FactorSum = B.buildAdd(NarrowTy, Factors[0], Factors[1]).getReg(0);
5862       for (unsigned i = 2; i < Factors.size(); ++i)
5863         FactorSum = B.buildAdd(NarrowTy, FactorSum, Factors[i]).getReg(0);
5864     }
5865 
5866     CarrySumPrevDstIdx = CarrySum;
5867     DstRegs[DstIdx] = FactorSum;
5868     Factors.clear();
5869   }
5870 }
5871 
5872 LegalizerHelper::LegalizeResult
5873 LegalizerHelper::narrowScalarAddSub(MachineInstr &MI, unsigned TypeIdx,
5874                                     LLT NarrowTy) {
5875   if (TypeIdx != 0)
5876     return UnableToLegalize;
5877 
5878   Register DstReg = MI.getOperand(0).getReg();
5879   LLT DstType = MRI.getType(DstReg);
5880   // FIXME: add support for vector types
5881   if (DstType.isVector())
5882     return UnableToLegalize;
5883 
5884   unsigned Opcode = MI.getOpcode();
5885   unsigned OpO, OpE, OpF;
5886   switch (Opcode) {
5887   case TargetOpcode::G_SADDO:
5888   case TargetOpcode::G_SADDE:
5889   case TargetOpcode::G_UADDO:
5890   case TargetOpcode::G_UADDE:
5891   case TargetOpcode::G_ADD:
5892     OpO = TargetOpcode::G_UADDO;
5893     OpE = TargetOpcode::G_UADDE;
5894     OpF = TargetOpcode::G_UADDE;
5895     if (Opcode == TargetOpcode::G_SADDO || Opcode == TargetOpcode::G_SADDE)
5896       OpF = TargetOpcode::G_SADDE;
5897     break;
5898   case TargetOpcode::G_SSUBO:
5899   case TargetOpcode::G_SSUBE:
5900   case TargetOpcode::G_USUBO:
5901   case TargetOpcode::G_USUBE:
5902   case TargetOpcode::G_SUB:
5903     OpO = TargetOpcode::G_USUBO;
5904     OpE = TargetOpcode::G_USUBE;
5905     OpF = TargetOpcode::G_USUBE;
5906     if (Opcode == TargetOpcode::G_SSUBO || Opcode == TargetOpcode::G_SSUBE)
5907       OpF = TargetOpcode::G_SSUBE;
5908     break;
5909   default:
5910     llvm_unreachable("Unexpected add/sub opcode!");
5911   }
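       // E.g. a 64-bit G_UADDO narrowed to 32 bits becomes:
       //   Lo, C0 = G_UADDO a0, b0
       //   Hi, Ov = G_UADDE a1, b1, C0
       // where Ov is forwarded to the original overflow definition.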
5912 
5913   // 1 for a plain add/sub, 2 if this is an operation with a carry-out.
5914   unsigned NumDefs = MI.getNumExplicitDefs();
5915   Register Src1 = MI.getOperand(NumDefs).getReg();
5916   Register Src2 = MI.getOperand(NumDefs + 1).getReg();
5917   Register CarryDst, CarryIn;
5918   if (NumDefs == 2)
5919     CarryDst = MI.getOperand(1).getReg();
5920   if (MI.getNumOperands() == NumDefs + 3)
5921     CarryIn = MI.getOperand(NumDefs + 2).getReg();
5922 
5923   LLT RegTy = MRI.getType(MI.getOperand(0).getReg());
5924   LLT LeftoverTy, DummyTy;
5925   SmallVector<Register, 2> Src1Regs, Src2Regs, Src1Left, Src2Left, DstRegs;
5926   extractParts(Src1, RegTy, NarrowTy, LeftoverTy, Src1Regs, Src1Left,
5927                MIRBuilder, MRI);
5928   extractParts(Src2, RegTy, NarrowTy, DummyTy, Src2Regs, Src2Left, MIRBuilder,
5929                MRI);
5930 
5931   int NarrowParts = Src1Regs.size();
5932   for (int I = 0, E = Src1Left.size(); I != E; ++I) {
5933     Src1Regs.push_back(Src1Left[I]);
5934     Src2Regs.push_back(Src2Left[I]);
5935   }
5936   DstRegs.reserve(Src1Regs.size());
5937 
5938   for (int i = 0, e = Src1Regs.size(); i != e; ++i) {
5939     Register DstReg =
5940         MRI.createGenericVirtualRegister(MRI.getType(Src1Regs[i]));
5941     Register CarryOut = MRI.createGenericVirtualRegister(LLT::scalar(1));
5942     // Forward the final carry-out to the destination register
5943     if (i == e - 1 && CarryDst)
5944       CarryOut = CarryDst;
5945 
5946     if (!CarryIn) {
5947       MIRBuilder.buildInstr(OpO, {DstReg, CarryOut},
5948                             {Src1Regs[i], Src2Regs[i]});
5949     } else if (i == e - 1) {
5950       MIRBuilder.buildInstr(OpF, {DstReg, CarryOut},
5951                             {Src1Regs[i], Src2Regs[i], CarryIn});
5952     } else {
5953       MIRBuilder.buildInstr(OpE, {DstReg, CarryOut},
5954                             {Src1Regs[i], Src2Regs[i], CarryIn});
5955     }
5956 
5957     DstRegs.push_back(DstReg);
5958     CarryIn = CarryOut;
5959   }
5960   insertParts(MI.getOperand(0).getReg(), RegTy, NarrowTy,
5961               ArrayRef(DstRegs).take_front(NarrowParts), LeftoverTy,
5962               ArrayRef(DstRegs).drop_front(NarrowParts));
5963 
5964   MI.eraseFromParent();
5965   return Legalized;
5966 }
5967 
5968 LegalizerHelper::LegalizeResult
5969 LegalizerHelper::narrowScalarMul(MachineInstr &MI, LLT NarrowTy) {
5970   auto [DstReg, Src1, Src2] = MI.getFirst3Regs();
5971 
5972   LLT Ty = MRI.getType(DstReg);
5973   if (Ty.isVector())
5974     return UnableToLegalize;
5975 
5976   unsigned Size = Ty.getSizeInBits();
5977   unsigned NarrowSize = NarrowTy.getSizeInBits();
5978   if (Size % NarrowSize != 0)
5979     return UnableToLegalize;
5980 
5981   unsigned NumParts = Size / NarrowSize;
5982   bool IsMulHigh = MI.getOpcode() == TargetOpcode::G_UMULH;
5983   unsigned DstTmpParts = NumParts * (IsMulHigh ? 2 : 1);
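       // E.g. a 64-bit G_UMULH over s32 parts computes all four 32-bit pieces
       // of the full 128-bit product and keeps only the upper two.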
5984 
5985   SmallVector<Register, 2> Src1Parts, Src2Parts;
5986   SmallVector<Register, 2> DstTmpRegs(DstTmpParts);
5987   extractParts(Src1, NarrowTy, NumParts, Src1Parts, MIRBuilder, MRI);
5988   extractParts(Src2, NarrowTy, NumParts, Src2Parts, MIRBuilder, MRI);
5989   multiplyRegisters(DstTmpRegs, Src1Parts, Src2Parts, NarrowTy);
5990 
5991   // Take only high half of registers if this is high mul.
5992   ArrayRef<Register> DstRegs(&DstTmpRegs[DstTmpParts - NumParts], NumParts);
5993   MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
5994   MI.eraseFromParent();
5995   return Legalized;
5996 }
5997 
5998 LegalizerHelper::LegalizeResult
5999 LegalizerHelper::narrowScalarFPTOI(MachineInstr &MI, unsigned TypeIdx,
6000                                    LLT NarrowTy) {
6001   if (TypeIdx != 0)
6002     return UnableToLegalize;
6003 
6004   bool IsSigned = MI.getOpcode() == TargetOpcode::G_FPTOSI;
6005 
6006   Register Src = MI.getOperand(1).getReg();
6007   LLT SrcTy = MRI.getType(Src);
6008 
6009   // If all finite floats fit into the narrowed integer type, we can just swap
6010   // out the result type. This is practically only useful for conversions from
6011   // half to at least 16-bits, so just handle the one case.
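       // (Half's largest finite magnitude is 65504, which fits in 16 unsigned
       // bits but needs 17 bits once a sign bit is required.)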
6012   if (SrcTy.getScalarType() != LLT::scalar(16) ||
6013       NarrowTy.getScalarSizeInBits() < (IsSigned ? 17u : 16u))
6014     return UnableToLegalize;
6015 
6016   Observer.changingInstr(MI);
6017   narrowScalarDst(MI, NarrowTy, 0,
6018                   IsSigned ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT);
6019   Observer.changedInstr(MI);
6020   return Legalized;
6021 }
6022 
6023 LegalizerHelper::LegalizeResult
6024 LegalizerHelper::narrowScalarExtract(MachineInstr &MI, unsigned TypeIdx,
6025                                      LLT NarrowTy) {
6026   if (TypeIdx != 1)
6027     return UnableToLegalize;
6028 
6029   uint64_t NarrowSize = NarrowTy.getSizeInBits();
6030 
6031   int64_t SizeOp1 = MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
6032   // FIXME: add support for when SizeOp1 isn't an exact multiple of
6033   // NarrowSize.
6034   if (SizeOp1 % NarrowSize != 0)
6035     return UnableToLegalize;
6036   int NumParts = SizeOp1 / NarrowSize;
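       // E.g. extracting s16 at bit offset 24 from an s64 split into s32
       // parts: take bits [24,32) of part 0 and bits [0,8) of part 1, then
       // merge the two s8 segments into the s16 result.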
6037 
6038   SmallVector<Register, 2> SrcRegs, DstRegs;
6039   SmallVector<uint64_t, 2> Indexes;
6040   extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs,
6041                MIRBuilder, MRI);
6042 
6043   Register OpReg = MI.getOperand(0).getReg();
6044   uint64_t OpStart = MI.getOperand(2).getImm();
6045   uint64_t OpSize = MRI.getType(OpReg).getSizeInBits();
6046   for (int i = 0; i < NumParts; ++i) {
6047     unsigned SrcStart = i * NarrowSize;
6048 
6049     if (SrcStart + NarrowSize <= OpStart || SrcStart >= OpStart + OpSize) {
6050       // No part of the extract uses this subregister, ignore it.
6051       continue;
6052     } else if (SrcStart == OpStart && NarrowTy == MRI.getType(OpReg)) {
6053       // The entire subregister is extracted, forward the value.
6054       DstRegs.push_back(SrcRegs[i]);
6055       continue;
6056     }
6057 
6058     // Compute how the extracted range overlaps this piece: ExtractOffset
6059     // is the offset within the piece, SegSize the number of overlapping bits.
6060     int64_t ExtractOffset;
6061     uint64_t SegSize;
6062     if (OpStart < SrcStart) {
6063       ExtractOffset = 0;
6064       SegSize = std::min(NarrowSize, OpStart + OpSize - SrcStart);
6065     } else {
6066       ExtractOffset = OpStart - SrcStart;
6067       SegSize = std::min(SrcStart + NarrowSize - OpStart, OpSize);
6068     }
6069 
6070     Register SegReg = SrcRegs[i];
6071     if (ExtractOffset != 0 || SegSize != NarrowSize) {
6072       // A genuine extract is needed.
6073       SegReg = MRI.createGenericVirtualRegister(LLT::scalar(SegSize));
6074       MIRBuilder.buildExtract(SegReg, SrcRegs[i], ExtractOffset);
6075     }
6076 
6077     DstRegs.push_back(SegReg);
6078   }
6079 
6080   Register DstReg = MI.getOperand(0).getReg();
6081   if (MRI.getType(DstReg).isVector())
6082     MIRBuilder.buildBuildVector(DstReg, DstRegs);
6083   else if (DstRegs.size() > 1)
6084     MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
6085   else
6086     MIRBuilder.buildCopy(DstReg, DstRegs[0]);
6087   MI.eraseFromParent();
6088   return Legalized;
6089 }
6090 
6091 LegalizerHelper::LegalizeResult
6092 LegalizerHelper::narrowScalarInsert(MachineInstr &MI, unsigned TypeIdx,
6093                                     LLT NarrowTy) {
6094   // FIXME: Don't know how to handle secondary types yet.
6095   if (TypeIdx != 0)
6096     return UnableToLegalize;
6097 
6098   SmallVector<Register, 2> SrcRegs, LeftoverRegs, DstRegs;
6099   SmallVector<uint64_t, 2> Indexes;
6100   LLT RegTy = MRI.getType(MI.getOperand(0).getReg());
6101   LLT LeftoverTy;
6102   extractParts(MI.getOperand(1).getReg(), RegTy, NarrowTy, LeftoverTy, SrcRegs,
6103                LeftoverRegs, MIRBuilder, MRI);
6104 
6105   for (Register Reg : LeftoverRegs)
6106     SrcRegs.push_back(Reg);
6107 
6108   uint64_t NarrowSize = NarrowTy.getSizeInBits();
6109   Register OpReg = MI.getOperand(2).getReg();
6110   uint64_t OpStart = MI.getOperand(3).getImm();
6111   uint64_t OpSize = MRI.getType(OpReg).getSizeInBits();
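       // E.g. inserting s16 at bit offset 24 into an s64 split into s32
       // parts: bits [0,8) of the value land at offset 24 of part 0, bits
       // [8,16) at offset 0 of part 1; unaffected parts are forwarded as-is.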
6112   for (int I = 0, E = SrcRegs.size(); I != E; ++I) {
6113     unsigned DstStart = I * NarrowSize;
6114 
6115     if (DstStart == OpStart && NarrowTy == MRI.getType(OpReg)) {
6116       // The entire subregister is defined by this insert, forward the new
6117       // value.
6118       DstRegs.push_back(OpReg);
6119       continue;
6120     }
6121 
6122     Register SrcReg = SrcRegs[I];
6123     if (MRI.getType(SrcRegs[I]) == LeftoverTy) {
6124       // The leftover reg is smaller than NarrowTy, so we need to extend it.
6125       SrcReg = MRI.createGenericVirtualRegister(NarrowTy);
6126       MIRBuilder.buildAnyExt(SrcReg, SrcRegs[I]);
6127     }
6128 
6129     if (DstStart + NarrowSize <= OpStart || DstStart >= OpStart + OpSize) {
6130       // No part of the insert affects this subregister, forward the original.
6131       DstRegs.push_back(SrcReg);
6132       continue;
6133     }
6134 
6135     // Compute how the inserted value overlaps this destination piece:
6136     // ExtractOffset/SegSize select its bits, InsertOffset places them.
6137     int64_t ExtractOffset, InsertOffset;
6138     uint64_t SegSize;
6139     if (OpStart < DstStart) {
6140       InsertOffset = 0;
6141       ExtractOffset = DstStart - OpStart;
6142       SegSize = std::min(NarrowSize, OpStart + OpSize - DstStart);
6143     } else {
6144       InsertOffset = OpStart - DstStart;
6145       ExtractOffset = 0;
6146       SegSize =
6147         std::min(NarrowSize - InsertOffset, OpStart + OpSize - DstStart);
6148     }
6149 
6150     Register SegReg = OpReg;
6151     if (ExtractOffset != 0 || SegSize != OpSize) {
6152       // A genuine extract is needed.
6153       SegReg = MRI.createGenericVirtualRegister(LLT::scalar(SegSize));
6154       MIRBuilder.buildExtract(SegReg, OpReg, ExtractOffset);
6155     }
6156 
6157     Register DstReg = MRI.createGenericVirtualRegister(NarrowTy);
6158     MIRBuilder.buildInsert(DstReg, SrcReg, SegReg, InsertOffset);
6159     DstRegs.push_back(DstReg);
6160   }
6161 
6162   uint64_t WideSize = DstRegs.size() * NarrowSize;
6163   Register DstReg = MI.getOperand(0).getReg();
6164   if (WideSize > RegTy.getSizeInBits()) {
6165     Register MergeReg = MRI.createGenericVirtualRegister(LLT::scalar(WideSize));
6166     MIRBuilder.buildMergeLikeInstr(MergeReg, DstRegs);
6167     MIRBuilder.buildTrunc(DstReg, MergeReg);
6168   } else
6169     MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
6170 
6171   MI.eraseFromParent();
6172   return Legalized;
6173 }
6174 
6175 LegalizerHelper::LegalizeResult
6176 LegalizerHelper::narrowScalarBasic(MachineInstr &MI, unsigned TypeIdx,
6177                                    LLT NarrowTy) {
6178   Register DstReg = MI.getOperand(0).getReg();
6179   LLT DstTy = MRI.getType(DstReg);
6180 
6181   assert(MI.getNumOperands() == 3 && TypeIdx == 0);
6182 
6183   SmallVector<Register, 4> DstRegs, DstLeftoverRegs;
6184   SmallVector<Register, 4> Src0Regs, Src0LeftoverRegs;
6185   SmallVector<Register, 4> Src1Regs, Src1LeftoverRegs;
6186   LLT LeftoverTy;
6187   if (!extractParts(MI.getOperand(1).getReg(), DstTy, NarrowTy, LeftoverTy,
6188                     Src0Regs, Src0LeftoverRegs, MIRBuilder, MRI))
6189     return UnableToLegalize;
6190 
6191   LLT Unused;
6192   if (!extractParts(MI.getOperand(2).getReg(), DstTy, NarrowTy, Unused,
6193                     Src1Regs, Src1LeftoverRegs, MIRBuilder, MRI))
6194     llvm_unreachable("inconsistent extractParts result");
6195 
6196   for (unsigned I = 0, E = Src1Regs.size(); I != E; ++I) {
6197     auto Inst = MIRBuilder.buildInstr(MI.getOpcode(), {NarrowTy},
6198                                         {Src0Regs[I], Src1Regs[I]});
6199     DstRegs.push_back(Inst.getReg(0));
6200   }
6201 
6202   for (unsigned I = 0, E = Src1LeftoverRegs.size(); I != E; ++I) {
6203     auto Inst = MIRBuilder.buildInstr(
6204       MI.getOpcode(),
6205       {LeftoverTy}, {Src0LeftoverRegs[I], Src1LeftoverRegs[I]});
6206     DstLeftoverRegs.push_back(Inst.getReg(0));
6207   }
6208 
6209   insertParts(DstReg, DstTy, NarrowTy, DstRegs,
6210               LeftoverTy, DstLeftoverRegs);
6211 
6212   MI.eraseFromParent();
6213   return Legalized;
6214 }
6215 
6216 LegalizerHelper::LegalizeResult
6217 LegalizerHelper::narrowScalarExt(MachineInstr &MI, unsigned TypeIdx,
6218                                  LLT NarrowTy) {
6219   if (TypeIdx != 0)
6220     return UnableToLegalize;
6221 
6222   auto [DstReg, SrcReg] = MI.getFirst2Regs();
6223 
6224   LLT DstTy = MRI.getType(DstReg);
6225   if (DstTy.isVector())
6226     return UnableToLegalize;
6227 
6228   SmallVector<Register, 8> Parts;
6229   LLT GCDTy = extractGCDType(Parts, DstTy, NarrowTy, SrcReg);
6230   LLT LCMTy = buildLCMMergePieces(DstTy, NarrowTy, GCDTy, Parts, MI.getOpcode());
6231   buildWidenedRemergeToDst(DstReg, LCMTy, Parts);
6232 
6233   MI.eraseFromParent();
6234   return Legalized;
6235 }
6236 
6237 LegalizerHelper::LegalizeResult
6238 LegalizerHelper::narrowScalarSelect(MachineInstr &MI, unsigned TypeIdx,
6239                                     LLT NarrowTy) {
6240   if (TypeIdx != 0)
6241     return UnableToLegalize;
6242 
6243   Register CondReg = MI.getOperand(1).getReg();
6244   LLT CondTy = MRI.getType(CondReg);
6245   if (CondTy.isVector()) // TODO: Handle vselect
6246     return UnableToLegalize;
6247 
6248   Register DstReg = MI.getOperand(0).getReg();
6249   LLT DstTy = MRI.getType(DstReg);
6250 
6251   SmallVector<Register, 4> DstRegs, DstLeftoverRegs;
6252   SmallVector<Register, 4> Src1Regs, Src1LeftoverRegs;
6253   SmallVector<Register, 4> Src2Regs, Src2LeftoverRegs;
6254   LLT LeftoverTy;
6255   if (!extractParts(MI.getOperand(2).getReg(), DstTy, NarrowTy, LeftoverTy,
6256                     Src1Regs, Src1LeftoverRegs, MIRBuilder, MRI))
6257     return UnableToLegalize;
6258 
6259   LLT Unused;
6260   if (!extractParts(MI.getOperand(3).getReg(), DstTy, NarrowTy, Unused,
6261                     Src2Regs, Src2LeftoverRegs, MIRBuilder, MRI))
6262     llvm_unreachable("inconsistent extractParts result");
6263 
6264   for (unsigned I = 0, E = Src1Regs.size(); I != E; ++I) {
6265     auto Select = MIRBuilder.buildSelect(NarrowTy,
6266                                          CondReg, Src1Regs[I], Src2Regs[I]);
6267     DstRegs.push_back(Select.getReg(0));
6268   }
6269 
6270   for (unsigned I = 0, E = Src1LeftoverRegs.size(); I != E; ++I) {
6271     auto Select = MIRBuilder.buildSelect(
6272       LeftoverTy, CondReg, Src1LeftoverRegs[I], Src2LeftoverRegs[I]);
6273     DstLeftoverRegs.push_back(Select.getReg(0));
6274   }
6275 
6276   insertParts(DstReg, DstTy, NarrowTy, DstRegs,
6277               LeftoverTy, DstLeftoverRegs);
6278 
6279   MI.eraseFromParent();
6280   return Legalized;
6281 }
6282 
6283 LegalizerHelper::LegalizeResult
6284 LegalizerHelper::narrowScalarCTLZ(MachineInstr &MI, unsigned TypeIdx,
6285                                   LLT NarrowTy) {
6286   if (TypeIdx != 1)
6287     return UnableToLegalize;
6288 
6289   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
6290   unsigned NarrowSize = NarrowTy.getSizeInBits();
6291 
6292   if (SrcTy.isScalar() && SrcTy.getSizeInBits() == 2 * NarrowSize) {
6293     const bool IsUndef = MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF;
6294 
6295     MachineIRBuilder &B = MIRBuilder;
6296     auto UnmergeSrc = B.buildUnmerge(NarrowTy, SrcReg);
6297     // ctlz(Hi:Lo) -> Hi == 0 ? (NarrowSize + ctlz(Lo)) : ctlz(Hi)
6298     auto C_0 = B.buildConstant(NarrowTy, 0);
6299     auto HiIsZero = B.buildICmp(CmpInst::ICMP_EQ, LLT::scalar(1),
6300                                 UnmergeSrc.getReg(1), C_0);
6301     auto LoCTLZ = IsUndef ?
6302       B.buildCTLZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(0)) :
6303       B.buildCTLZ(DstTy, UnmergeSrc.getReg(0));
6304     auto C_NarrowSize = B.buildConstant(DstTy, NarrowSize);
6305     auto HiIsZeroCTLZ = B.buildAdd(DstTy, LoCTLZ, C_NarrowSize);
6306     auto HiCTLZ = B.buildCTLZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(1));
6307     B.buildSelect(DstReg, HiIsZero, HiIsZeroCTLZ, HiCTLZ);
6308 
6309     MI.eraseFromParent();
6310     return Legalized;
6311   }
6312 
6313   return UnableToLegalize;
6314 }
6315 
6316 LegalizerHelper::LegalizeResult
6317 LegalizerHelper::narrowScalarCTTZ(MachineInstr &MI, unsigned TypeIdx,
6318                                   LLT NarrowTy) {
6319   if (TypeIdx != 1)
6320     return UnableToLegalize;
6321 
6322   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
6323   unsigned NarrowSize = NarrowTy.getSizeInBits();
6324 
6325   if (SrcTy.isScalar() && SrcTy.getSizeInBits() == 2 * NarrowSize) {
6326     const bool IsUndef = MI.getOpcode() == TargetOpcode::G_CTTZ_ZERO_UNDEF;
6327 
6328     MachineIRBuilder &B = MIRBuilder;
6329     auto UnmergeSrc = B.buildUnmerge(NarrowTy, SrcReg);
6330     // cttz(Hi:Lo) -> Lo == 0 ? (cttz(Hi) + NarrowSize) : cttz(Lo)
6331     auto C_0 = B.buildConstant(NarrowTy, 0);
6332     auto LoIsZero = B.buildICmp(CmpInst::ICMP_EQ, LLT::scalar(1),
6333                                 UnmergeSrc.getReg(0), C_0);
6334     auto HiCTTZ = IsUndef ?
6335       B.buildCTTZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(1)) :
6336       B.buildCTTZ(DstTy, UnmergeSrc.getReg(1));
6337     auto C_NarrowSize = B.buildConstant(DstTy, NarrowSize);
6338     auto LoIsZeroCTTZ = B.buildAdd(DstTy, HiCTTZ, C_NarrowSize);
6339     auto LoCTTZ = B.buildCTTZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(0));
6340     B.buildSelect(DstReg, LoIsZero, LoIsZeroCTTZ, LoCTTZ);
6341 
6342     MI.eraseFromParent();
6343     return Legalized;
6344   }
6345 
6346   return UnableToLegalize;
6347 }
6348 
6349 LegalizerHelper::LegalizeResult
6350 LegalizerHelper::narrowScalarCTPOP(MachineInstr &MI, unsigned TypeIdx,
6351                                    LLT NarrowTy) {
6352   if (TypeIdx != 1)
6353     return UnableToLegalize;
6354 
6355   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
6356   unsigned NarrowSize = NarrowTy.getSizeInBits();
6357 
6358   if (SrcTy.isScalar() && SrcTy.getSizeInBits() == 2 * NarrowSize) {
6359     auto UnmergeSrc = MIRBuilder.buildUnmerge(NarrowTy, MI.getOperand(1));
6360 
6361     auto LoCTPOP = MIRBuilder.buildCTPOP(DstTy, UnmergeSrc.getReg(0));
6362     auto HiCTPOP = MIRBuilder.buildCTPOP(DstTy, UnmergeSrc.getReg(1));
6363     MIRBuilder.buildAdd(DstReg, HiCTPOP, LoCTPOP);
6364 
6365     MI.eraseFromParent();
6366     return Legalized;
6367   }
6368 
6369   return UnableToLegalize;
6370 }
6371 
6372 LegalizerHelper::LegalizeResult
6373 LegalizerHelper::narrowScalarFLDEXP(MachineInstr &MI, unsigned TypeIdx,
6374                                     LLT NarrowTy) {
6375   if (TypeIdx != 1)
6376     return UnableToLegalize;
6377 
6378   MachineIRBuilder &B = MIRBuilder;
6379   Register ExpReg = MI.getOperand(2).getReg();
6380   LLT ExpTy = MRI.getType(ExpReg);
6381 
6382   unsigned ClampSize = NarrowTy.getScalarSizeInBits();
6383 
6384   // Clamp the exponent to the range of the target type.
6385   auto MinExp = B.buildConstant(ExpTy, minIntN(ClampSize));
6386   auto ClampMin = B.buildSMax(ExpTy, ExpReg, MinExp);
6387   auto MaxExp = B.buildConstant(ExpTy, maxIntN(ClampSize));
6388   auto Clamp = B.buildSMin(ExpTy, ClampMin, MaxExp);
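       // E.g. for NarrowTy == s16 this clamps the exponent to
       // [-32768, 32767]; exponents outside that range already over- or
       // underflow every supported floating-point type, so the truncation
       // below does not change the result.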
6389 
6390   auto Trunc = B.buildTrunc(NarrowTy, Clamp);
6391   Observer.changingInstr(MI);
6392   MI.getOperand(2).setReg(Trunc.getReg(0));
6393   Observer.changedInstr(MI);
6394   return Legalized;
6395 }
6396 
6397 LegalizerHelper::LegalizeResult
6398 LegalizerHelper::lowerBitCount(MachineInstr &MI) {
6399   unsigned Opc = MI.getOpcode();
6400   const auto &TII = MIRBuilder.getTII();
6401   auto isSupported = [this](const LegalityQuery &Q) {
6402     auto QAction = LI.getAction(Q).Action;
6403     return QAction == Legal || QAction == Libcall || QAction == Custom;
6404   };
6405   switch (Opc) {
6406   default:
6407     return UnableToLegalize;
6408   case TargetOpcode::G_CTLZ_ZERO_UNDEF: {
6409     // This trivially expands to CTLZ.
6410     Observer.changingInstr(MI);
6411     MI.setDesc(TII.get(TargetOpcode::G_CTLZ));
6412     Observer.changedInstr(MI);
6413     return Legalized;
6414   }
6415   case TargetOpcode::G_CTLZ: {
6416     auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
6417     unsigned Len = SrcTy.getSizeInBits();
6418 
6419     if (isSupported({TargetOpcode::G_CTLZ_ZERO_UNDEF, {DstTy, SrcTy}})) {
6420       // If CTLZ_ZERO_UNDEF is supported, emit that and a select for zero.
6421       auto CtlzZU = MIRBuilder.buildCTLZ_ZERO_UNDEF(DstTy, SrcReg);
6422       auto ZeroSrc = MIRBuilder.buildConstant(SrcTy, 0);
6423       auto ICmp = MIRBuilder.buildICmp(
6424           CmpInst::ICMP_EQ, SrcTy.changeElementSize(1), SrcReg, ZeroSrc);
6425       auto LenConst = MIRBuilder.buildConstant(DstTy, Len);
6426       MIRBuilder.buildSelect(DstReg, ICmp, LenConst, CtlzZU);
6427       MI.eraseFromParent();
6428       return Legalized;
6429     }
6430     // for now, we do this:
6431     // NewLen = NextPowerOf2(Len);
6432     // x = x | (x >> 1);
6433     // x = x | (x >> 2);
6434     // ...
6435     // x = x | (x >>16);
6436     // x = x | (x >>32); // for 64-bit input
6437     // Up to NewLen/2
6438     // return Len - popcount(x);
6439     //
6440     // Ref: "Hacker's Delight" by Henry Warren
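         // E.g. Len = 32, x = 0x00F0F000: the or-chain smears the leading one
         // downward to give 0x00FFFFFF, popcount 24, so ctlz = 32 - 24 = 8.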
6441     Register Op = SrcReg;
6442     unsigned NewLen = PowerOf2Ceil(Len);
6443     for (unsigned i = 0; (1U << i) <= (NewLen / 2); ++i) {
6444       auto MIBShiftAmt = MIRBuilder.buildConstant(SrcTy, 1ULL << i);
6445       auto MIBOp = MIRBuilder.buildOr(
6446           SrcTy, Op, MIRBuilder.buildLShr(SrcTy, Op, MIBShiftAmt));
6447       Op = MIBOp.getReg(0);
6448     }
6449     auto MIBPop = MIRBuilder.buildCTPOP(DstTy, Op);
6450     MIRBuilder.buildSub(MI.getOperand(0), MIRBuilder.buildConstant(DstTy, Len),
6451                         MIBPop);
6452     MI.eraseFromParent();
6453     return Legalized;
6454   }
6455   case TargetOpcode::G_CTTZ_ZERO_UNDEF: {
6456     // This trivially expands to CTTZ.
6457     Observer.changingInstr(MI);
6458     MI.setDesc(TII.get(TargetOpcode::G_CTTZ));
6459     Observer.changedInstr(MI);
6460     return Legalized;
6461   }
6462   case TargetOpcode::G_CTTZ: {
6463     auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
6464 
6465     unsigned Len = SrcTy.getSizeInBits();
6466     if (isSupported({TargetOpcode::G_CTTZ_ZERO_UNDEF, {DstTy, SrcTy}})) {
6467       // If CTTZ_ZERO_UNDEF is legal or custom, emit that and a select with
6468       // zero.
6469       auto CttzZU = MIRBuilder.buildCTTZ_ZERO_UNDEF(DstTy, SrcReg);
6470       auto Zero = MIRBuilder.buildConstant(SrcTy, 0);
6471       auto ICmp = MIRBuilder.buildICmp(
6472           CmpInst::ICMP_EQ, DstTy.changeElementSize(1), SrcReg, Zero);
6473       auto LenConst = MIRBuilder.buildConstant(DstTy, Len);
6474       MIRBuilder.buildSelect(DstReg, ICmp, LenConst, CttzZU);
6475       MI.eraseFromParent();
6476       return Legalized;
6477     }
6478     // for now, we use: { return popcount(~x & (x - 1)); }
6479     // unless the target has ctlz but not ctpop, in which case we use:
6480     // { return 32 - nlz(~x & (x-1)); }
6481     // Ref: "Hacker's Delight" by Henry Warren
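         // E.g. x = 0b01000: ~x & (x - 1) = 0b00111, whose popcount is 3,
         // matching cttz(x).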
6482     auto MIBCstNeg1 = MIRBuilder.buildConstant(SrcTy, -1);
6483     auto MIBNot = MIRBuilder.buildXor(SrcTy, SrcReg, MIBCstNeg1);
6484     auto MIBTmp = MIRBuilder.buildAnd(
6485         SrcTy, MIBNot, MIRBuilder.buildAdd(SrcTy, SrcReg, MIBCstNeg1));
6486     if (!isSupported({TargetOpcode::G_CTPOP, {SrcTy, SrcTy}}) &&
6487         isSupported({TargetOpcode::G_CTLZ, {SrcTy, SrcTy}})) {
6488       auto MIBCstLen = MIRBuilder.buildConstant(SrcTy, Len);
6489       MIRBuilder.buildSub(MI.getOperand(0), MIBCstLen,
6490                           MIRBuilder.buildCTLZ(SrcTy, MIBTmp));
6491       MI.eraseFromParent();
6492       return Legalized;
6493     }
6494     Observer.changingInstr(MI);
6495     MI.setDesc(TII.get(TargetOpcode::G_CTPOP));
6496     MI.getOperand(1).setReg(MIBTmp.getReg(0));
6497     Observer.changedInstr(MI);
6498     return Legalized;
6499   }
6500   case TargetOpcode::G_CTPOP: {
6501     Register SrcReg = MI.getOperand(1).getReg();
6502     LLT Ty = MRI.getType(SrcReg);
6503     unsigned Size = Ty.getSizeInBits();
6504     MachineIRBuilder &B = MIRBuilder;
6505 
6506     // Count set bits in blocks of 2 bits. The default approach would be
6507     // B2Count = { val & 0x55555555 } + { (val >> 1) & 0x55555555 }
6508     // We use the following formula instead:
6509     // B2Count = val - { (val >> 1) & 0x55555555 }
6510     // since it gives the same result in blocks of 2 with one instruction fewer.
6511     auto C_1 = B.buildConstant(Ty, 1);
6512     auto B2Set1LoTo1Hi = B.buildLShr(Ty, SrcReg, C_1);
6513     APInt B2Mask1HiTo0 = APInt::getSplat(Size, APInt(8, 0x55));
6514     auto C_B2Mask1HiTo0 = B.buildConstant(Ty, B2Mask1HiTo0);
6515     auto B2Count1Hi = B.buildAnd(Ty, B2Set1LoTo1Hi, C_B2Mask1HiTo0);
6516     auto B2Count = B.buildSub(Ty, SrcReg, B2Count1Hi);
6517 
6518     // To get the count in blocks of 4, add the values from adjacent blocks of 2.
6519     // B4Count = { B2Count & 0x33333333 } + { (B2Count >> 2) & 0x33333333 }
6520     auto C_2 = B.buildConstant(Ty, 2);
6521     auto B4Set2LoTo2Hi = B.buildLShr(Ty, B2Count, C_2);
6522     APInt B4Mask2HiTo0 = APInt::getSplat(Size, APInt(8, 0x33));
6523     auto C_B4Mask2HiTo0 = B.buildConstant(Ty, B4Mask2HiTo0);
6524     auto B4HiB2Count = B.buildAnd(Ty, B4Set2LoTo2Hi, C_B4Mask2HiTo0);
6525     auto B4LoB2Count = B.buildAnd(Ty, B2Count, C_B4Mask2HiTo0);
6526     auto B4Count = B.buildAdd(Ty, B4HiB2Count, B4LoB2Count);
6527 
6528     // For the count in blocks of 8 bits we don't have to mask the high 4 bits
6529     // before the addition, since each count sits in the range {0,...,8} and 4
6530     // bits are enough to hold such values. After the addition the high 4 bits
6531     // still hold the count of the high 4-bit block; zero them to get the 8-bit result.
6532     // B8Count = { B4Count + (B4Count >> 4) } & 0x0F0F0F0F
6533     auto C_4 = B.buildConstant(Ty, 4);
6534     auto B8HiB4Count = B.buildLShr(Ty, B4Count, C_4);
6535     auto B8CountDirty4Hi = B.buildAdd(Ty, B8HiB4Count, B4Count);
6536     APInt B8Mask4HiTo0 = APInt::getSplat(Size, APInt(8, 0x0F));
6537     auto C_B8Mask4HiTo0 = B.buildConstant(Ty, B8Mask4HiTo0);
6538     auto B8Count = B.buildAnd(Ty, B8CountDirty4Hi, C_B8Mask4HiTo0);
6539 
6540     assert(Size <= 128 && "Scalar size is too large for the CTPOP lowering");
6541     // 8 bits can hold the CTPOP result of a 128-bit or smaller int. Multiplying
6542     // by this bitmask sets the 8 MSBs of ResTmp to the sum of all 8-bit B8Counts.
6543     auto MulMask = B.buildConstant(Ty, APInt::getSplat(Size, APInt(8, 0x01)));
6544 
6545     // Shift count result from 8 high bits to low bits.
6546     auto C_SizeM8 = B.buildConstant(Ty, Size - 8);
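         // E.g. for Size == 32: B8Count * 0x01010101 accumulates the four
         // byte counts into bits [31:24], and the shift by 24 extracts the
         // total.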
6547 
6548     auto IsMulSupported = [this](const LLT Ty) {
6549       auto Action = LI.getAction({TargetOpcode::G_MUL, {Ty}}).Action;
6550       return Action == Legal || Action == WidenScalar || Action == Custom;
6551     };
6552     if (IsMulSupported(Ty)) {
6553       auto ResTmp = B.buildMul(Ty, B8Count, MulMask);
6554       B.buildLShr(MI.getOperand(0).getReg(), ResTmp, C_SizeM8);
6555     } else {
6556       auto ResTmp = B8Count;
6557       for (unsigned Shift = 8; Shift < Size; Shift *= 2) {
6558         auto ShiftC = B.buildConstant(Ty, Shift);
6559         auto Shl = B.buildShl(Ty, ResTmp, ShiftC);
6560         ResTmp = B.buildAdd(Ty, ResTmp, Shl);
6561       }
6562       B.buildLShr(MI.getOperand(0).getReg(), ResTmp, C_SizeM8);
6563     }
6564     MI.eraseFromParent();
6565     return Legalized;
6566   }
6567   }
6568 }
6569 
6570 // Check that (every element of) Reg is undef or not an exact multiple of BW.
6571 static bool isNonZeroModBitWidthOrUndef(const MachineRegisterInfo &MRI,
6572                                         Register Reg, unsigned BW) {
6573   return matchUnaryPredicate(
6574       MRI, Reg,
6575       [=](const Constant *C) {
6576         // Null constant here means an undef.
6577         const ConstantInt *CI = dyn_cast_or_null<ConstantInt>(C);
6578         return !CI || CI->getValue().urem(BW) != 0;
6579       },
6580       /*AllowUndefs*/ true);
6581 }
6582 
6583 LegalizerHelper::LegalizeResult
6584 LegalizerHelper::lowerFunnelShiftWithInverse(MachineInstr &MI) {
6585   auto [Dst, X, Y, Z] = MI.getFirst4Regs();
6586   LLT Ty = MRI.getType(Dst);
6587   LLT ShTy = MRI.getType(Z);
6588 
6589   unsigned BW = Ty.getScalarSizeInBits();
6590 
6591   if (!isPowerOf2_32(BW))
6592     return UnableToLegalize;
6593 
6594   const bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;
6595   unsigned RevOpcode = IsFSHL ? TargetOpcode::G_FSHR : TargetOpcode::G_FSHL;
6596 
6597   if (isNonZeroModBitWidthOrUndef(MRI, Z, BW)) {
6598     // fshl X, Y, Z -> fshr X, Y, -Z
6599     // fshr X, Y, Z -> fshl X, Y, -Z
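         // This rewrite is valid precisely because Z % BW != 0 here: at
         // Z % BW == 0, fshl yields X while fshr with amount 0 yields Y.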
6600     auto Zero = MIRBuilder.buildConstant(ShTy, 0);
6601     Z = MIRBuilder.buildSub(ShTy, Zero, Z).getReg(0);
6602   } else {
6603     // fshl X, Y, Z -> fshr (srl X, 1), (fshr X, Y, 1), ~Z
6604     // fshr X, Y, Z -> fshl (fshl X, Y, 1), (shl Y, 1), ~Z
6605     auto One = MIRBuilder.buildConstant(ShTy, 1);
6606     if (IsFSHL) {
6607       Y = MIRBuilder.buildInstr(RevOpcode, {Ty}, {X, Y, One}).getReg(0);
6608       X = MIRBuilder.buildLShr(Ty, X, One).getReg(0);
6609     } else {
6610       X = MIRBuilder.buildInstr(RevOpcode, {Ty}, {X, Y, One}).getReg(0);
6611       Y = MIRBuilder.buildShl(Ty, Y, One).getReg(0);
6612     }
6613 
6614     Z = MIRBuilder.buildNot(ShTy, Z).getReg(0);
6615   }
6616 
6617   MIRBuilder.buildInstr(RevOpcode, {Dst}, {X, Y, Z});
6618   MI.eraseFromParent();
6619   return Legalized;
6620 }
6621 
6622 LegalizerHelper::LegalizeResult
6623 LegalizerHelper::lowerFunnelShiftAsShifts(MachineInstr &MI) {
6624   auto [Dst, X, Y, Z] = MI.getFirst4Regs();
6625   LLT Ty = MRI.getType(Dst);
6626   LLT ShTy = MRI.getType(Z);
6627 
6628   const unsigned BW = Ty.getScalarSizeInBits();
6629   const bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;
6630 
6631   Register ShX, ShY;
6632   Register ShAmt, InvShAmt;
6633 
6634   // FIXME: Emit optimized urem by constant instead of letting it expand later.
6635   if (isNonZeroModBitWidthOrUndef(MRI, Z, BW)) {
6636     // fshl: X << C | Y >> (BW - C)
6637     // fshr: X << (BW - C) | Y >> C
6638     // where C = Z % BW is not zero
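         // E.g. BW = 8, C = 3: fshl = (X << 3) | (Y >> 5) and
         //                     fshr = (X << 5) | (Y >> 3).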
6639     auto BitWidthC = MIRBuilder.buildConstant(ShTy, BW);
6640     ShAmt = MIRBuilder.buildURem(ShTy, Z, BitWidthC).getReg(0);
6641     InvShAmt = MIRBuilder.buildSub(ShTy, BitWidthC, ShAmt).getReg(0);
6642     ShX = MIRBuilder.buildShl(Ty, X, IsFSHL ? ShAmt : InvShAmt).getReg(0);
6643     ShY = MIRBuilder.buildLShr(Ty, Y, IsFSHL ? InvShAmt : ShAmt).getReg(0);
6644   } else {
6645     // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
6646     // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
6647     auto Mask = MIRBuilder.buildConstant(ShTy, BW - 1);
6648     if (isPowerOf2_32(BW)) {
6649       // Z % BW -> Z & (BW - 1)
6650       ShAmt = MIRBuilder.buildAnd(ShTy, Z, Mask).getReg(0);
6651       // (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
6652       auto NotZ = MIRBuilder.buildNot(ShTy, Z);
6653       InvShAmt = MIRBuilder.buildAnd(ShTy, NotZ, Mask).getReg(0);
6654     } else {
6655       auto BitWidthC = MIRBuilder.buildConstant(ShTy, BW);
6656       ShAmt = MIRBuilder.buildURem(ShTy, Z, BitWidthC).getReg(0);
6657       InvShAmt = MIRBuilder.buildSub(ShTy, Mask, ShAmt).getReg(0);
6658     }
6659 
6660     auto One = MIRBuilder.buildConstant(ShTy, 1);
6661     if (IsFSHL) {
6662       ShX = MIRBuilder.buildShl(Ty, X, ShAmt).getReg(0);
6663       auto ShY1 = MIRBuilder.buildLShr(Ty, Y, One);
6664       ShY = MIRBuilder.buildLShr(Ty, ShY1, InvShAmt).getReg(0);
6665     } else {
6666       auto ShX1 = MIRBuilder.buildShl(Ty, X, One);
6667       ShX = MIRBuilder.buildShl(Ty, ShX1, InvShAmt).getReg(0);
6668       ShY = MIRBuilder.buildLShr(Ty, Y, ShAmt).getReg(0);
6669     }
6670   }
6671 
6672   MIRBuilder.buildOr(Dst, ShX, ShY);
6673   MI.eraseFromParent();
6674   return Legalized;
6675 }
6676 
6677 LegalizerHelper::LegalizeResult
6678 LegalizerHelper::lowerFunnelShift(MachineInstr &MI) {
6679   // These operations approximately do the following (while avoiding undefined
6680   // shifts by BW):
6681   // G_FSHL: (X << (Z % BW)) | (Y >> (BW - (Z % BW)))
6682   // G_FSHR: (X << (BW - (Z % BW))) | (Y >> (Z % BW))
6683   Register Dst = MI.getOperand(0).getReg();
6684   LLT Ty = MRI.getType(Dst);
6685   LLT ShTy = MRI.getType(MI.getOperand(3).getReg());
6686 
6687   bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;
6688   unsigned RevOpcode = IsFSHL ? TargetOpcode::G_FSHR : TargetOpcode::G_FSHL;
6689 
6690   // TODO: Use smarter heuristic that accounts for vector legalization.
6691   if (LI.getAction({RevOpcode, {Ty, ShTy}}).Action == Lower)
6692     return lowerFunnelShiftAsShifts(MI);
6693 
6694   // This only works for powers of 2; fall back to shifts if it fails.
6695   LegalizerHelper::LegalizeResult Result = lowerFunnelShiftWithInverse(MI);
6696   if (Result == UnableToLegalize)
6697     return lowerFunnelShiftAsShifts(MI);
6698   return Result;
6699 }
6700 
6701 LegalizerHelper::LegalizeResult LegalizerHelper::lowerEXT(MachineInstr &MI) {
6702   auto [Dst, Src] = MI.getFirst2Regs();
6703   LLT DstTy = MRI.getType(Dst);
6704   LLT SrcTy = MRI.getType(Src);
6705 
6706   uint32_t DstTySize = DstTy.getSizeInBits();
6707   uint32_t DstTyScalarSize = DstTy.getScalarSizeInBits();
6708   uint32_t SrcTyScalarSize = SrcTy.getScalarSizeInBits();
6709 
6710   if (!isPowerOf2_32(DstTySize) || !isPowerOf2_32(DstTyScalarSize) ||
6711       !isPowerOf2_32(SrcTyScalarSize))
6712     return UnableToLegalize;
6713 
6714   // The step between the extends is too large; split it by creating an
6715   // intermediate extend instruction.
6716   if (SrcTyScalarSize * 2 < DstTyScalarSize) {
6717     LLT MidTy = SrcTy.changeElementSize(SrcTyScalarSize * 2);
6718     // If the destination type is illegal, split it into multiple statements
6719     // zext x -> zext(merge(zext(unmerge), zext(unmerge)))
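         // E.g. zext <4 x s8> to <4 x s32>: first zext to <4 x s16>, unmerge
         // into two <2 x s16> halves, zext each to <2 x s32>, and concatenate
         // the results.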
6720     auto NewExt = MIRBuilder.buildInstr(MI.getOpcode(), {MidTy}, {Src});
6721     // Unmerge the vector
6722     LLT EltTy = MidTy.changeElementCount(
6723         MidTy.getElementCount().divideCoefficientBy(2));
6724     auto UnmergeSrc = MIRBuilder.buildUnmerge(EltTy, NewExt);
6725 
6726     // ZExt the vectors
6727     LLT ZExtResTy = DstTy.changeElementCount(
6728         DstTy.getElementCount().divideCoefficientBy(2));
6729     auto ZExtRes1 = MIRBuilder.buildInstr(MI.getOpcode(), {ZExtResTy},
6730                                           {UnmergeSrc.getReg(0)});
6731     auto ZExtRes2 = MIRBuilder.buildInstr(MI.getOpcode(), {ZExtResTy},
6732                                           {UnmergeSrc.getReg(1)});
6733 
6734     // Merge the ending vectors
6735     MIRBuilder.buildMergeLikeInstr(Dst, {ZExtRes1, ZExtRes2});
6736 
6737     MI.eraseFromParent();
6738     return Legalized;
6739   }
6740   return UnableToLegalize;
6741 }
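
// For illustration (types chosen arbitrarily, register names are
// illustrative), %dst(<4 x s32>) = G_ZEXT %src(<4 x s8>) becomes:
//   %mid(<4 x s16>) = G_ZEXT %src
//   %lo(<2 x s16>), %hi(<2 x s16>) = G_UNMERGE_VALUES %mid
//   %zlo(<2 x s32>) = G_ZEXT %lo
//   %zhi(<2 x s32>) = G_ZEXT %hi
//   %dst(<4 x s32>) = G_CONCAT_VECTORS %zlo, %zhi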
6742 
6743 LegalizerHelper::LegalizeResult LegalizerHelper::lowerTRUNC(MachineInstr &MI) {
6745   MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
6746   // Similar to how operand splitting is done in SelectionDAG, we can handle
6747   // %res(v8s8) = G_TRUNC %in(v8s32) by generating:
6748   //   %inlo(<4x s32>), %inhi(<4 x s32>) = G_UNMERGE %in(<8 x s32>)
6749   //   %lo16(<4 x s16>) = G_TRUNC %inlo
6750   //   %hi16(<4 x s16>) = G_TRUNC %inhi
6751   //   %in16(<8 x s16>) = G_CONCAT_VECTORS %lo16, %hi16
6752   //   %res(<8 x s8>) = G_TRUNC %in16
6753 
6754   assert(MI.getOpcode() == TargetOpcode::G_TRUNC);
6755 
6756   Register DstReg = MI.getOperand(0).getReg();
6757   Register SrcReg = MI.getOperand(1).getReg();
6758   LLT DstTy = MRI.getType(DstReg);
6759   LLT SrcTy = MRI.getType(SrcReg);
6760 
6761   if (DstTy.isVector() && isPowerOf2_32(DstTy.getNumElements()) &&
6762       isPowerOf2_32(DstTy.getScalarSizeInBits()) &&
6763       isPowerOf2_32(SrcTy.getNumElements()) &&
6764       isPowerOf2_32(SrcTy.getScalarSizeInBits())) {
6765     // Split input type.
6766     LLT SplitSrcTy = SrcTy.changeElementCount(
6767         SrcTy.getElementCount().divideCoefficientBy(2));
6768 
6769     // First, split the source into two smaller vectors.
6770     SmallVector<Register, 2> SplitSrcs;
6771     extractParts(SrcReg, SplitSrcTy, 2, SplitSrcs, MIRBuilder, MRI);
6772 
6773     // Truncate the splits into intermediate narrower elements.
6774     LLT InterTy;
6775     if (DstTy.getScalarSizeInBits() * 2 < SrcTy.getScalarSizeInBits())
6776       InterTy = SplitSrcTy.changeElementSize(DstTy.getScalarSizeInBits() * 2);
6777     else
6778       InterTy = SplitSrcTy.changeElementSize(DstTy.getScalarSizeInBits());
6779     for (unsigned I = 0; I < SplitSrcs.size(); ++I) {
6780       SplitSrcs[I] = MIRBuilder.buildTrunc(InterTy, SplitSrcs[I]).getReg(0);
6781     }
6782 
6783     // Combine the new truncates into one vector
6784     auto Merge = MIRBuilder.buildMergeLikeInstr(
6785         DstTy.changeElementSize(InterTy.getScalarSizeInBits()), SplitSrcs);
6786 
6787     // Truncate the new vector to the final result type
6788     if (DstTy.getScalarSizeInBits() * 2 < SrcTy.getScalarSizeInBits())
6789       MIRBuilder.buildTrunc(MI.getOperand(0).getReg(), Merge.getReg(0));
6790     else
6791       MIRBuilder.buildCopy(MI.getOperand(0).getReg(), Merge.getReg(0));
6792 
6793     MI.eraseFromParent();
6794 
6795     return Legalized;
6796   }
6797   return UnableToLegalize;
6798 }
6799 
6800 LegalizerHelper::LegalizeResult
6801 LegalizerHelper::lowerRotateWithReverseRotate(MachineInstr &MI) {
6802   auto [Dst, DstTy, Src, SrcTy, Amt, AmtTy] = MI.getFirst3RegLLTs();
6803   auto Zero = MIRBuilder.buildConstant(AmtTy, 0);
6804   bool IsLeft = MI.getOpcode() == TargetOpcode::G_ROTL;
6805   unsigned RevRot = IsLeft ? TargetOpcode::G_ROTR : TargetOpcode::G_ROTL;
6806   auto Neg = MIRBuilder.buildSub(AmtTy, Zero, Amt);
6807   MIRBuilder.buildInstr(RevRot, {Dst}, {Src, Neg});
6808   MI.eraseFromParent();
6809   return Legalized;
6810 }
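
// For illustration, with a bit width of 8: rotl(x, 3) == rotr(x, -3), since
// -3 is congruent to 5 modulo 8 and rotating left by 3 is the same as
// rotating right by 5. The equivalence relies on the amount being reduced
// modulo the bit width, which is why the caller checks isPowerOf2_32.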
6811 
6812 LegalizerHelper::LegalizeResult LegalizerHelper::lowerRotate(MachineInstr &MI) {
6813   auto [Dst, DstTy, Src, SrcTy, Amt, AmtTy] = MI.getFirst3RegLLTs();
6814 
6815   unsigned EltSizeInBits = DstTy.getScalarSizeInBits();
6816   bool IsLeft = MI.getOpcode() == TargetOpcode::G_ROTL;
6817 
6818   MIRBuilder.setInstrAndDebugLoc(MI);
6819 
6820   // If a rotate in the other direction is supported, use it.
6821   unsigned RevRot = IsLeft ? TargetOpcode::G_ROTR : TargetOpcode::G_ROTL;
6822   if (LI.isLegalOrCustom({RevRot, {DstTy, SrcTy}}) &&
6823       isPowerOf2_32(EltSizeInBits))
6824     return lowerRotateWithReverseRotate(MI);
6825 
6826   // If a funnel shift is supported, use it.
6827   unsigned FShOpc = IsLeft ? TargetOpcode::G_FSHL : TargetOpcode::G_FSHR;
6828   unsigned RevFsh = !IsLeft ? TargetOpcode::G_FSHL : TargetOpcode::G_FSHR;
6829   bool IsFShLegal = false;
6830   if ((IsFShLegal = LI.isLegalOrCustom({FShOpc, {DstTy, AmtTy}})) ||
6831       LI.isLegalOrCustom({RevFsh, {DstTy, AmtTy}})) {
6832     auto buildFunnelShift = [&](unsigned Opc, Register R1, Register R2,
6833                                 Register R3) {
6834       MIRBuilder.buildInstr(Opc, {R1}, {R2, R2, R3});
6835       MI.eraseFromParent();
6836       return Legalized;
6837     };
6838     // If a funnel shift in the other direction is supported, use it.
6839     if (IsFShLegal) {
6840       return buildFunnelShift(FShOpc, Dst, Src, Amt);
6841     } else if (isPowerOf2_32(EltSizeInBits)) {
6842       Amt = MIRBuilder.buildNeg(DstTy, Amt).getReg(0);
6843       return buildFunnelShift(RevFsh, Dst, Src, Amt);
6844     }
6845   }
6846 
6847   auto Zero = MIRBuilder.buildConstant(AmtTy, 0);
6848   unsigned ShOpc = IsLeft ? TargetOpcode::G_SHL : TargetOpcode::G_LSHR;
6849   unsigned RevShiftOpc = IsLeft ? TargetOpcode::G_LSHR : TargetOpcode::G_SHL;
6850   auto BitWidthMinusOneC = MIRBuilder.buildConstant(AmtTy, EltSizeInBits - 1);
6851   Register ShVal;
6852   Register RevShiftVal;
6853   if (isPowerOf2_32(EltSizeInBits)) {
6854     // (rotl x, c) -> x << (c & (w - 1)) | x >> (-c & (w - 1))
6855     // (rotr x, c) -> x >> (c & (w - 1)) | x << (-c & (w - 1))
6856     auto NegAmt = MIRBuilder.buildSub(AmtTy, Zero, Amt);
6857     auto ShAmt = MIRBuilder.buildAnd(AmtTy, Amt, BitWidthMinusOneC);
6858     ShVal = MIRBuilder.buildInstr(ShOpc, {DstTy}, {Src, ShAmt}).getReg(0);
6859     auto RevAmt = MIRBuilder.buildAnd(AmtTy, NegAmt, BitWidthMinusOneC);
6860     RevShiftVal =
6861         MIRBuilder.buildInstr(RevShiftOpc, {DstTy}, {Src, RevAmt}).getReg(0);
6862   } else {
6863     // (rotl x, c) -> x << (c % w) | x >> 1 >> (w - 1 - (c % w))
6864     // (rotr x, c) -> x >> (c % w) | x << 1 << (w - 1 - (c % w))
6865     auto BitWidthC = MIRBuilder.buildConstant(AmtTy, EltSizeInBits);
6866     auto ShAmt = MIRBuilder.buildURem(AmtTy, Amt, BitWidthC);
6867     ShVal = MIRBuilder.buildInstr(ShOpc, {DstTy}, {Src, ShAmt}).getReg(0);
6868     auto RevAmt = MIRBuilder.buildSub(AmtTy, BitWidthMinusOneC, ShAmt);
6869     auto One = MIRBuilder.buildConstant(AmtTy, 1);
6870     auto Inner = MIRBuilder.buildInstr(RevShiftOpc, {DstTy}, {Src, One});
6871     RevShiftVal =
6872         MIRBuilder.buildInstr(RevShiftOpc, {DstTy}, {Inner, RevAmt}).getReg(0);
6873   }
6874   MIRBuilder.buildOr(Dst, ShVal, RevShiftVal);
6875   MI.eraseFromParent();
6876   return Legalized;
6877 }
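
// For illustration, with w = 8 and c = 3, the power-of-2 path above computes
//   rotl(x, 3) = (x << (3 & 7)) | (x >> (-3 & 7)) = (x << 3) | (x >> 5),
// where -3 & 7 == 5, so neither shift amount can reach w (which would be an
// undefined shift).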
6878 
6879 // Expand s32 = G_UITOFP s64 using bit operations to an IEEE float
6880 // representation.
6881 LegalizerHelper::LegalizeResult
6882 LegalizerHelper::lowerU64ToF32BitOps(MachineInstr &MI) {
6883   auto [Dst, Src] = MI.getFirst2Regs();
6884   const LLT S64 = LLT::scalar(64);
6885   const LLT S32 = LLT::scalar(32);
6886   const LLT S1 = LLT::scalar(1);
6887 
6888   assert(MRI.getType(Src) == S64 && MRI.getType(Dst) == S32);
6889 
6890   // unsigned cul2f(ulong u) {
6891   //   uint lz = clz(u);
6892   //   uint e = (u != 0) ? 127U + 63U - lz : 0;
6893   //   u = (u << lz) & 0x7fffffffffffffffUL;
6894   //   ulong t = u & 0xffffffffffUL;
6895   //   uint v = (e << 23) | (uint)(u >> 40);
6896   //   uint r = t > 0x8000000000UL ? 1U : (t == 0x8000000000UL ? v & 1U : 0U);
6897   //   return as_float(v + r);
6898   // }
6899 
6900   auto Zero32 = MIRBuilder.buildConstant(S32, 0);
6901   auto Zero64 = MIRBuilder.buildConstant(S64, 0);
6902 
6903   auto LZ = MIRBuilder.buildCTLZ_ZERO_UNDEF(S32, Src);
6904 
6905   auto K = MIRBuilder.buildConstant(S32, 127U + 63U);
6906   auto Sub = MIRBuilder.buildSub(S32, K, LZ);
6907 
6908   auto NotZero = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, Src, Zero64);
6909   auto E = MIRBuilder.buildSelect(S32, NotZero, Sub, Zero32);
6910 
6911   auto Mask0 = MIRBuilder.buildConstant(S64, (-1ULL) >> 1);
6912   auto ShlLZ = MIRBuilder.buildShl(S64, Src, LZ);
6913 
6914   auto U = MIRBuilder.buildAnd(S64, ShlLZ, Mask0);
6915 
6916   auto Mask1 = MIRBuilder.buildConstant(S64, 0xffffffffffULL);
6917   auto T = MIRBuilder.buildAnd(S64, U, Mask1);
6918 
6919   auto UShl = MIRBuilder.buildLShr(S64, U, MIRBuilder.buildConstant(S64, 40));
6920   auto ShlE = MIRBuilder.buildShl(S32, E, MIRBuilder.buildConstant(S32, 23));
6921   auto V = MIRBuilder.buildOr(S32, ShlE, MIRBuilder.buildTrunc(S32, UShl));
6922 
6923   auto C = MIRBuilder.buildConstant(S64, 0x8000000000ULL);
6924   auto RCmp = MIRBuilder.buildICmp(CmpInst::ICMP_UGT, S1, T, C);
6925   auto TCmp = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, S1, T, C);
6926   auto One = MIRBuilder.buildConstant(S32, 1);
6927 
6928   auto VTrunc1 = MIRBuilder.buildAnd(S32, V, One);
6929   auto Select0 = MIRBuilder.buildSelect(S32, TCmp, VTrunc1, Zero32);
6930   auto R = MIRBuilder.buildSelect(S32, RCmp, One, Select0);
6931   MIRBuilder.buildAdd(Dst, V, R);
6932 
6933   MI.eraseFromParent();
6934   return Legalized;
6935 }
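
// For illustration, cul2f(1): lz = 63, so e = 127 + 63 - 63 = 127;
// u = (1 << 63) & 0x7fffffffffffffff = 0, so t = 0 and
// v = (127 << 23) | 0 = 0x3f800000; t is neither greater than nor equal to
// 0x8000000000, so r = 0, and as_float(0x3f800000) is exactly 1.0f.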
6936 
6937 LegalizerHelper::LegalizeResult LegalizerHelper::lowerUITOFP(MachineInstr &MI) {
6938   auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
6939 
6940   if (SrcTy == LLT::scalar(1)) {
6941     auto True = MIRBuilder.buildFConstant(DstTy, 1.0);
6942     auto False = MIRBuilder.buildFConstant(DstTy, 0.0);
6943     MIRBuilder.buildSelect(Dst, Src, True, False);
6944     MI.eraseFromParent();
6945     return Legalized;
6946   }
6947 
6948   if (SrcTy != LLT::scalar(64))
6949     return UnableToLegalize;
6950 
6951   if (DstTy == LLT::scalar(32)) {
6952     // TODO: SelectionDAG has several alternative expansions to port which may
6953     // be more reasonable depending on the available instructions. If a target
6954     // has sitofp, does not have CTLZ, or can efficiently use f64 as an
6955     // intermediate type, this is probably worse.
6956     return lowerU64ToF32BitOps(MI);
6957   }
6958 
6959   return UnableToLegalize;
6960 }
6961 
6962 LegalizerHelper::LegalizeResult LegalizerHelper::lowerSITOFP(MachineInstr &MI) {
6963   auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
6964 
6965   const LLT S64 = LLT::scalar(64);
6966   const LLT S32 = LLT::scalar(32);
6967   const LLT S1 = LLT::scalar(1);
6968 
6969   if (SrcTy == S1) {
6970     auto True = MIRBuilder.buildFConstant(DstTy, -1.0);
6971     auto False = MIRBuilder.buildFConstant(DstTy, 0.0);
6972     MIRBuilder.buildSelect(Dst, Src, True, False);
6973     MI.eraseFromParent();
6974     return Legalized;
6975   }
6976 
6977   if (SrcTy != S64)
6978     return UnableToLegalize;
6979 
6980   if (DstTy == S32) {
6981     // signed cl2f(long l) {
6982     //   long s = l >> 63;
6983     //   float r = cul2f((l + s) ^ s);
6984     //   return s ? -r : r;
6985     // }
6986     Register L = Src;
6987     auto SignBit = MIRBuilder.buildConstant(S64, 63);
6988     auto S = MIRBuilder.buildAShr(S64, L, SignBit);
6989 
6990     auto LPlusS = MIRBuilder.buildAdd(S64, L, S);
6991     auto Xor = MIRBuilder.buildXor(S64, LPlusS, S);
6992     auto R = MIRBuilder.buildUITOFP(S32, Xor);
6993 
6994     auto RNeg = MIRBuilder.buildFNeg(S32, R);
6995     auto SignNotZero = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, S,
6996                                             MIRBuilder.buildConstant(S64, 0));
6997     MIRBuilder.buildSelect(Dst, SignNotZero, RNeg, R);
6998     MI.eraseFromParent();
6999     return Legalized;
7000   }
7001 
7002   return UnableToLegalize;
7003 }
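
// For illustration, cl2f(-1): s = -1 >> 63 = -1 (arithmetic shift), so
// (l + s) ^ s = (-2) ^ (-1) = 1; r = cul2f(1) = 1.0f, and since s != 0 the
// select yields -r = -1.0f.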
7004 
7005 LegalizerHelper::LegalizeResult LegalizerHelper::lowerFPTOUI(MachineInstr &MI) {
7006   auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
7007   const LLT S64 = LLT::scalar(64);
7008   const LLT S32 = LLT::scalar(32);
7009 
7010   if (SrcTy != S64 && SrcTy != S32)
7011     return UnableToLegalize;
7012   if (DstTy != S32 && DstTy != S64)
7013     return UnableToLegalize;
7014 
7015   // FPTOSI gives the same result as FPTOUI for positive signed integers.
7016   // FPTOUI needs to deal with fp values that convert to unsigned integers
7017   // greater than or equal to 2^31 for float or 2^63 for double (2^Exp for short).
7018 
7019   APInt TwoPExpInt = APInt::getSignMask(DstTy.getSizeInBits());
7020   APFloat TwoPExpFP(SrcTy.getSizeInBits() == 32 ? APFloat::IEEEsingle()
7021                                                 : APFloat::IEEEdouble(),
7022                     APInt::getZero(SrcTy.getSizeInBits()));
7023   TwoPExpFP.convertFromAPInt(TwoPExpInt, false, APFloat::rmNearestTiesToEven);
7024 
7025   MachineInstrBuilder FPTOSI = MIRBuilder.buildFPTOSI(DstTy, Src);
7026 
7027   MachineInstrBuilder Threshold = MIRBuilder.buildFConstant(SrcTy, TwoPExpFP);
7028   // For an fp Value greater than or equal to Threshold (2^Exp), we use FPTOSI
7029   // on (Value - 2^Exp) and add 2^Exp back by setting the highest bit in the result.
7030   MachineInstrBuilder FSub = MIRBuilder.buildFSub(SrcTy, Src, Threshold);
7031   MachineInstrBuilder ResLowBits = MIRBuilder.buildFPTOSI(DstTy, FSub);
7032   MachineInstrBuilder ResHighBit = MIRBuilder.buildConstant(DstTy, TwoPExpInt);
7033   MachineInstrBuilder Res = MIRBuilder.buildXor(DstTy, ResLowBits, ResHighBit);
7034 
7035   const LLT S1 = LLT::scalar(1);
7036 
7037   MachineInstrBuilder FCMP =
7038       MIRBuilder.buildFCmp(CmpInst::FCMP_ULT, S1, Src, Threshold);
7039   MIRBuilder.buildSelect(Dst, FCMP, FPTOSI, Res);
7040 
7041   MI.eraseFromParent();
7042   return Legalized;
7043 }
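
// For illustration, converting the f32 value 3e9 to u32: Threshold is
// 2^31 = 2147483648.0, and 3e9 >= Threshold, so the select takes
// FPTOSI(3e9 - 2^31) = 852516352 and XORs in the high bit:
// 852516352 ^ 0x80000000 = 3000000000. (Both 3e9 and the difference happen
// to be exactly representable in f32, so no rounding occurs here.)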
7044 
7045 LegalizerHelper::LegalizeResult LegalizerHelper::lowerFPTOSI(MachineInstr &MI) {
7046   auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
7047   const LLT S64 = LLT::scalar(64);
7048   const LLT S32 = LLT::scalar(32);
7049 
7050   // FIXME: Only f32 to i64 conversions are supported.
7051   if (SrcTy.getScalarType() != S32 || DstTy.getScalarType() != S64)
7052     return UnableToLegalize;
7053 
7054   // Expand f32 -> i64 conversion
7055   // This algorithm comes from compiler-rt's implementation of fixsfdi:
7056   // https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/builtins/fixsfdi.c
7057 
7058   unsigned SrcEltBits = SrcTy.getScalarSizeInBits();
7059 
7060   auto ExponentMask = MIRBuilder.buildConstant(SrcTy, 0x7F800000);
7061   auto ExponentLoBit = MIRBuilder.buildConstant(SrcTy, 23);
7062 
7063   auto AndExpMask = MIRBuilder.buildAnd(SrcTy, Src, ExponentMask);
7064   auto ExponentBits = MIRBuilder.buildLShr(SrcTy, AndExpMask, ExponentLoBit);
7065 
7066   auto SignMask = MIRBuilder.buildConstant(SrcTy,
7067                                            APInt::getSignMask(SrcEltBits));
7068   auto AndSignMask = MIRBuilder.buildAnd(SrcTy, Src, SignMask);
7069   auto SignLowBit = MIRBuilder.buildConstant(SrcTy, SrcEltBits - 1);
7070   auto Sign = MIRBuilder.buildAShr(SrcTy, AndSignMask, SignLowBit);
7071   Sign = MIRBuilder.buildSExt(DstTy, Sign);
7072 
7073   auto MantissaMask = MIRBuilder.buildConstant(SrcTy, 0x007FFFFF);
7074   auto AndMantissaMask = MIRBuilder.buildAnd(SrcTy, Src, MantissaMask);
7075   auto K = MIRBuilder.buildConstant(SrcTy, 0x00800000);
7076 
7077   auto R = MIRBuilder.buildOr(SrcTy, AndMantissaMask, K);
7078   R = MIRBuilder.buildZExt(DstTy, R);
7079 
7080   auto Bias = MIRBuilder.buildConstant(SrcTy, 127);
7081   auto Exponent = MIRBuilder.buildSub(SrcTy, ExponentBits, Bias);
7082   auto SubExponent = MIRBuilder.buildSub(SrcTy, Exponent, ExponentLoBit);
7083   auto ExponentSub = MIRBuilder.buildSub(SrcTy, ExponentLoBit, Exponent);
7084 
7085   auto Shl = MIRBuilder.buildShl(DstTy, R, SubExponent);
7086   auto Srl = MIRBuilder.buildLShr(DstTy, R, ExponentSub);
7087 
7088   const LLT S1 = LLT::scalar(1);
7089   auto CmpGt = MIRBuilder.buildICmp(CmpInst::ICMP_SGT,
7090                                     S1, Exponent, ExponentLoBit);
7091 
7092   R = MIRBuilder.buildSelect(DstTy, CmpGt, Shl, Srl);
7093 
7094   auto XorSign = MIRBuilder.buildXor(DstTy, R, Sign);
7095   auto Ret = MIRBuilder.buildSub(DstTy, XorSign, Sign);
7096 
7097   auto ZeroSrcTy = MIRBuilder.buildConstant(SrcTy, 0);
7098 
7099   auto ExponentLt0 = MIRBuilder.buildICmp(CmpInst::ICMP_SLT,
7100                                           S1, Exponent, ZeroSrcTy);
7101 
7102   auto ZeroDstTy = MIRBuilder.buildConstant(DstTy, 0);
7103   MIRBuilder.buildSelect(Dst, ExponentLt0, ZeroDstTy, Ret);
7104 
7105   MI.eraseFromParent();
7106   return Legalized;
7107 }
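
// For illustration, converting the f32 value 1.5 (bits 0x3fc00000) to i64:
// ExponentBits = 127, so Exponent = 0; R = mantissa | implicit bit
// = 0x00400000 | 0x00800000 = 0x00c00000. Exponent(0) is not > 23, so the
// select takes Srl = R >> (23 - 0) = 1; Sign is 0, so the result is 1.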
7108 
7109 // f64 -> f16 conversion using round-to-nearest-even rounding mode.
7110 LegalizerHelper::LegalizeResult
7111 LegalizerHelper::lowerFPTRUNC_F64_TO_F16(MachineInstr &MI) {
7112   const LLT S1 = LLT::scalar(1);
7113   const LLT S32 = LLT::scalar(32);
7114 
7115   auto [Dst, Src] = MI.getFirst2Regs();
7116   assert(MRI.getType(Dst).getScalarType() == LLT::scalar(16) &&
7117          MRI.getType(Src).getScalarType() == LLT::scalar(64));
7118 
7119   if (MRI.getType(Src).isVector()) // TODO: Handle vectors directly.
7120     return UnableToLegalize;
7121 
7122   if (MIRBuilder.getMF().getTarget().Options.UnsafeFPMath) {
7123     unsigned Flags = MI.getFlags();
7124     auto Src32 = MIRBuilder.buildFPTrunc(S32, Src, Flags);
7125     MIRBuilder.buildFPTrunc(Dst, Src32, Flags);
7126     MI.eraseFromParent();
7127     return Legalized;
7128   }
7129 
7130   const unsigned ExpMask = 0x7ff;
7131   const unsigned ExpBiasf64 = 1023;
7132   const unsigned ExpBiasf16 = 15;
7133 
7134   auto Unmerge = MIRBuilder.buildUnmerge(S32, Src);
7135   Register U = Unmerge.getReg(0);
7136   Register UH = Unmerge.getReg(1);
7137 
7138   auto E = MIRBuilder.buildLShr(S32, UH, MIRBuilder.buildConstant(S32, 20));
7139   E = MIRBuilder.buildAnd(S32, E, MIRBuilder.buildConstant(S32, ExpMask));
7140 
7141   // Subtract the fp64 exponent bias (1023) to get the real exponent and
7142   // add the f16 bias (15) to get the biased exponent for the f16 format.
7143   E = MIRBuilder.buildAdd(
7144     S32, E, MIRBuilder.buildConstant(S32, -ExpBiasf64 + ExpBiasf16));
7145 
7146   auto M = MIRBuilder.buildLShr(S32, UH, MIRBuilder.buildConstant(S32, 8));
7147   M = MIRBuilder.buildAnd(S32, M, MIRBuilder.buildConstant(S32, 0xffe));
7148 
7149   auto MaskedSig = MIRBuilder.buildAnd(S32, UH,
7150                                        MIRBuilder.buildConstant(S32, 0x1ff));
7151   MaskedSig = MIRBuilder.buildOr(S32, MaskedSig, U);
7152 
7153   auto Zero = MIRBuilder.buildConstant(S32, 0);
7154   auto SigCmpNE0 = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, MaskedSig, Zero);
7155   auto Lo40Set = MIRBuilder.buildZExt(S32, SigCmpNE0);
7156   M = MIRBuilder.buildOr(S32, M, Lo40Set);
7157 
7158   // (M != 0 ? 0x0200 : 0) | 0x7c00;
7159   auto Bits0x200 = MIRBuilder.buildConstant(S32, 0x0200);
7160   auto CmpM_NE0 = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, M, Zero);
7161   auto SelectCC = MIRBuilder.buildSelect(S32, CmpM_NE0, Bits0x200, Zero);
7162 
7163   auto Bits0x7c00 = MIRBuilder.buildConstant(S32, 0x7c00);
7164   auto I = MIRBuilder.buildOr(S32, SelectCC, Bits0x7c00);
7165 
7166   // N = M | (E << 12);
7167   auto EShl12 = MIRBuilder.buildShl(S32, E, MIRBuilder.buildConstant(S32, 12));
7168   auto N = MIRBuilder.buildOr(S32, M, EShl12);
7169 
7170   // B = clamp(1-E, 0, 13);
7171   auto One = MIRBuilder.buildConstant(S32, 1);
7172   auto OneSubExp = MIRBuilder.buildSub(S32, One, E);
7173   auto B = MIRBuilder.buildSMax(S32, OneSubExp, Zero);
7174   B = MIRBuilder.buildSMin(S32, B, MIRBuilder.buildConstant(S32, 13));
7175 
7176   auto SigSetHigh = MIRBuilder.buildOr(S32, M,
7177                                        MIRBuilder.buildConstant(S32, 0x1000));
7178 
7179   auto D = MIRBuilder.buildLShr(S32, SigSetHigh, B);
7180   auto D0 = MIRBuilder.buildShl(S32, D, B);
7181 
7182   auto D0_NE_SigSetHigh = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1,
7183                                              D0, SigSetHigh);
7184   auto D1 = MIRBuilder.buildZExt(S32, D0_NE_SigSetHigh);
7185   D = MIRBuilder.buildOr(S32, D, D1);
7186 
7187   auto CmpELtOne = MIRBuilder.buildICmp(CmpInst::ICMP_SLT, S1, E, One);
7188   auto V = MIRBuilder.buildSelect(S32, CmpELtOne, D, N);
7189 
7190   auto VLow3 = MIRBuilder.buildAnd(S32, V, MIRBuilder.buildConstant(S32, 7));
7191   V = MIRBuilder.buildLShr(S32, V, MIRBuilder.buildConstant(S32, 2));
7192 
7193   auto VLow3Eq3 = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, S1, VLow3,
7194                                        MIRBuilder.buildConstant(S32, 3));
7195   auto V0 = MIRBuilder.buildZExt(S32, VLow3Eq3);
7196 
7197   auto VLow3Gt5 = MIRBuilder.buildICmp(CmpInst::ICMP_SGT, S1, VLow3,
7198                                        MIRBuilder.buildConstant(S32, 5));
7199   auto V1 = MIRBuilder.buildZExt(S32, VLow3Gt5);
7200 
7201   V1 = MIRBuilder.buildOr(S32, V0, V1);
7202   V = MIRBuilder.buildAdd(S32, V, V1);
7203 
7204   auto CmpEGt30 = MIRBuilder.buildICmp(CmpInst::ICMP_SGT,  S1,
7205                                        E, MIRBuilder.buildConstant(S32, 30));
7206   V = MIRBuilder.buildSelect(S32, CmpEGt30,
7207                              MIRBuilder.buildConstant(S32, 0x7c00), V);
7208 
7209   auto CmpEGt1039 = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, S1,
7210                                          E, MIRBuilder.buildConstant(S32, 1039));
7211   V = MIRBuilder.buildSelect(S32, CmpEGt1039, I, V);
7212 
7213   // Extract the sign bit.
7214   auto Sign = MIRBuilder.buildLShr(S32, UH, MIRBuilder.buildConstant(S32, 16));
7215   Sign = MIRBuilder.buildAnd(S32, Sign, MIRBuilder.buildConstant(S32, 0x8000));
7216 
7217   // Insert the sign bit
7218   V = MIRBuilder.buildOr(S32, Sign, V);
7219 
7220   MIRBuilder.buildTrunc(Dst, V);
7221   MI.eraseFromParent();
7222   return Legalized;
7223 }
7224 
7225 LegalizerHelper::LegalizeResult
7226 LegalizerHelper::lowerFPTRUNC(MachineInstr &MI) {
7227   auto [DstTy, SrcTy] = MI.getFirst2LLTs();
7228   const LLT S64 = LLT::scalar(64);
7229   const LLT S16 = LLT::scalar(16);
7230 
7231   if (DstTy.getScalarType() == S16 && SrcTy.getScalarType() == S64)
7232     return lowerFPTRUNC_F64_TO_F16(MI);
7233 
7234   return UnableToLegalize;
7235 }
7236 
7237 LegalizerHelper::LegalizeResult LegalizerHelper::lowerFPOWI(MachineInstr &MI) {
7238   auto [Dst, Src0, Src1] = MI.getFirst3Regs();
7239   LLT Ty = MRI.getType(Dst);
7240 
7241   auto CvtSrc1 = MIRBuilder.buildSITOFP(Ty, Src1);
7242   MIRBuilder.buildFPow(Dst, Src0, CvtSrc1, MI.getFlags());
7243   MI.eraseFromParent();
7244   return Legalized;
7245 }
7246 
7247 static CmpInst::Predicate minMaxToCompare(unsigned Opc) {
7248   switch (Opc) {
7249   case TargetOpcode::G_SMIN:
7250     return CmpInst::ICMP_SLT;
7251   case TargetOpcode::G_SMAX:
7252     return CmpInst::ICMP_SGT;
7253   case TargetOpcode::G_UMIN:
7254     return CmpInst::ICMP_ULT;
7255   case TargetOpcode::G_UMAX:
7256     return CmpInst::ICMP_UGT;
7257   default:
7258     llvm_unreachable("not in integer min/max");
7259   }
7260 }
7261 
7262 LegalizerHelper::LegalizeResult LegalizerHelper::lowerMinMax(MachineInstr &MI) {
7263   auto [Dst, Src0, Src1] = MI.getFirst3Regs();
7264 
7265   const CmpInst::Predicate Pred = minMaxToCompare(MI.getOpcode());
7266   LLT CmpType = MRI.getType(Dst).changeElementSize(1);
7267 
7268   auto Cmp = MIRBuilder.buildICmp(Pred, CmpType, Src0, Src1);
7269   MIRBuilder.buildSelect(Dst, Cmp, Src0, Src1);
7270 
7271   MI.eraseFromParent();
7272   return Legalized;
7273 }
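
// For illustration (register names are illustrative), G_SMIN expands to:
//   %c(s1) = G_ICMP intpred(slt), %src0, %src1
//   %dst   = G_SELECT %c, %src0, %src1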
7274 
7275 LegalizerHelper::LegalizeResult
7276 LegalizerHelper::lowerThreewayCompare(MachineInstr &MI) {
7277   GSUCmp *Cmp = cast<GSUCmp>(&MI);
7278 
7279   Register Dst = Cmp->getReg(0);
7280   LLT DstTy = MRI.getType(Dst);
7281   LLT CmpTy = DstTy.changeElementSize(1);
7282 
7283   CmpInst::Predicate LTPredicate = Cmp->isSigned()
7284                                        ? CmpInst::Predicate::ICMP_SLT
7285                                        : CmpInst::Predicate::ICMP_ULT;
7286   CmpInst::Predicate GTPredicate = Cmp->isSigned()
7287                                        ? CmpInst::Predicate::ICMP_SGT
7288                                        : CmpInst::Predicate::ICMP_UGT;
7289 
7290   auto One = MIRBuilder.buildConstant(DstTy, 1);
7291   auto Zero = MIRBuilder.buildConstant(DstTy, 0);
7292   auto IsGT = MIRBuilder.buildICmp(GTPredicate, CmpTy, Cmp->getLHSReg(),
7293                                    Cmp->getRHSReg());
7294   auto SelectZeroOrOne = MIRBuilder.buildSelect(DstTy, IsGT, One, Zero);
7295 
7296   auto MinusOne = MIRBuilder.buildConstant(DstTy, -1);
7297   auto IsLT = MIRBuilder.buildICmp(LTPredicate, CmpTy, Cmp->getLHSReg(),
7298                                    Cmp->getRHSReg());
7299   MIRBuilder.buildSelect(Dst, IsLT, MinusOne, SelectZeroOrOne);
7300 
7301   MI.eraseFromParent();
7302   return Legalized;
7303 }
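
// For illustration, signed scmp(2, 5): IsGT (2 > 5) is false, so the inner
// select yields 0; IsLT (2 < 5) is true, so the outer select yields -1,
// matching the -1/0/1 three-way result.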
7304 
7305 LegalizerHelper::LegalizeResult
7306 LegalizerHelper::lowerFCopySign(MachineInstr &MI) {
7307   auto [Dst, DstTy, Src0, Src0Ty, Src1, Src1Ty] = MI.getFirst3RegLLTs();
7308   const int Src0Size = Src0Ty.getScalarSizeInBits();
7309   const int Src1Size = Src1Ty.getScalarSizeInBits();
7310 
7311   auto SignBitMask = MIRBuilder.buildConstant(
7312     Src0Ty, APInt::getSignMask(Src0Size));
7313 
7314   auto NotSignBitMask = MIRBuilder.buildConstant(
7315     Src0Ty, APInt::getLowBitsSet(Src0Size, Src0Size - 1));
7316 
7317   Register And0 = MIRBuilder.buildAnd(Src0Ty, Src0, NotSignBitMask).getReg(0);
7318   Register And1;
7319   if (Src0Ty == Src1Ty) {
7320     And1 = MIRBuilder.buildAnd(Src1Ty, Src1, SignBitMask).getReg(0);
7321   } else if (Src0Size > Src1Size) {
7322     auto ShiftAmt = MIRBuilder.buildConstant(Src0Ty, Src0Size - Src1Size);
7323     auto Zext = MIRBuilder.buildZExt(Src0Ty, Src1);
7324     auto Shift = MIRBuilder.buildShl(Src0Ty, Zext, ShiftAmt);
7325     And1 = MIRBuilder.buildAnd(Src0Ty, Shift, SignBitMask).getReg(0);
7326   } else {
7327     auto ShiftAmt = MIRBuilder.buildConstant(Src1Ty, Src1Size - Src0Size);
7328     auto Shift = MIRBuilder.buildLShr(Src1Ty, Src1, ShiftAmt);
7329     auto Trunc = MIRBuilder.buildTrunc(Src0Ty, Shift);
7330     And1 = MIRBuilder.buildAnd(Src0Ty, Trunc, SignBitMask).getReg(0);
7331   }
7332 
7333   // Be careful about setting nsz/nnan/ninf on every instruction, since the
7334   // constants are a nan and -0.0, but the final result should preserve
7335   // everything.
7336   unsigned Flags = MI.getFlags();
7337 
7338   // We masked the sign bit and the not-sign bit, so these are disjoint.
7339   Flags |= MachineInstr::Disjoint;
7340 
7341   MIRBuilder.buildOr(Dst, And0, And1, Flags);
7342 
7343   MI.eraseFromParent();
7344   return Legalized;
7345 }
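
// For illustration (same-type path), fcopysign(1.5f, -2.0f):
// And0 = 0x3fc00000 & 0x7fffffff = 0x3fc00000,
// And1 = 0xc0000000 & 0x80000000 = 0x80000000,
// and the disjoint OR gives 0xbfc00000, i.e. -1.5f.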
7346 
7347 LegalizerHelper::LegalizeResult
7348 LegalizerHelper::lowerFMinNumMaxNum(MachineInstr &MI) {
7349   unsigned NewOp = MI.getOpcode() == TargetOpcode::G_FMINNUM ?
7350     TargetOpcode::G_FMINNUM_IEEE : TargetOpcode::G_FMAXNUM_IEEE;
7351 
7352   auto [Dst, Src0, Src1] = MI.getFirst3Regs();
7353   LLT Ty = MRI.getType(Dst);
7354 
7355   if (!MI.getFlag(MachineInstr::FmNoNans)) {
7356     // Insert canonicalizes if it's possible we need to quiet an sNaN to get
7357     // correct behavior.
7358 
7359     // Note this must be done here, and not as an optimization combine in the
7360     // absence of a dedicated quiet-snan instruction, as we're using an
7361     // omni-purpose G_FCANONICALIZE.
7362     if (!isKnownNeverSNaN(Src0, MRI))
7363       Src0 = MIRBuilder.buildFCanonicalize(Ty, Src0, MI.getFlags()).getReg(0);
7364 
7365     if (!isKnownNeverSNaN(Src1, MRI))
7366       Src1 = MIRBuilder.buildFCanonicalize(Ty, Src1, MI.getFlags()).getReg(0);
7367   }
7368 
7369   // If there are no nans, it's safe to simply replace this with the non-IEEE
7370   // version.
7371   MIRBuilder.buildInstr(NewOp, {Dst}, {Src0, Src1}, MI.getFlags());
7372   MI.eraseFromParent();
7373   return Legalized;
7374 }
7375 
7376 LegalizerHelper::LegalizeResult LegalizerHelper::lowerFMad(MachineInstr &MI) {
7377   // Expand G_FMAD a, b, c -> G_FADD (G_FMUL a, b), c
7378   Register DstReg = MI.getOperand(0).getReg();
7379   LLT Ty = MRI.getType(DstReg);
7380   unsigned Flags = MI.getFlags();
7381 
7382   auto Mul = MIRBuilder.buildFMul(Ty, MI.getOperand(1), MI.getOperand(2),
7383                                   Flags);
7384   MIRBuilder.buildFAdd(DstReg, Mul, MI.getOperand(3), Flags);
7385   MI.eraseFromParent();
7386   return Legalized;
7387 }
7388 
7389 LegalizerHelper::LegalizeResult
7390 LegalizerHelper::lowerIntrinsicRound(MachineInstr &MI) {
7391   auto [DstReg, X] = MI.getFirst2Regs();
7392   const unsigned Flags = MI.getFlags();
7393   const LLT Ty = MRI.getType(DstReg);
7394   const LLT CondTy = Ty.changeElementSize(1);
7395 
7396   // round(x) =>
7397   //  t = trunc(x);
7398   //  d = fabs(x - t);
7399   //  o = copysign(d >= 0.5 ? 1.0 : 0.0, x);
7400   //  return t + o;
7401 
7402   auto T = MIRBuilder.buildIntrinsicTrunc(Ty, X, Flags);
7403 
7404   auto Diff = MIRBuilder.buildFSub(Ty, X, T, Flags);
7405   auto AbsDiff = MIRBuilder.buildFAbs(Ty, Diff, Flags);
7406 
7407   auto Half = MIRBuilder.buildFConstant(Ty, 0.5);
7408   auto Cmp =
7409       MIRBuilder.buildFCmp(CmpInst::FCMP_OGE, CondTy, AbsDiff, Half, Flags);
7410 
7411   // Could emit G_UITOFP instead
7412   auto One = MIRBuilder.buildFConstant(Ty, 1.0);
7413   auto Zero = MIRBuilder.buildFConstant(Ty, 0.0);
7414   auto BoolFP = MIRBuilder.buildSelect(Ty, Cmp, One, Zero);
7415   auto SignedOffset = MIRBuilder.buildFCopysign(Ty, BoolFP, X);
7416 
7417   MIRBuilder.buildFAdd(DstReg, T, SignedOffset, Flags);
7418 
7419   MI.eraseFromParent();
7420   return Legalized;
7421 }
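
// For illustration, round(-2.5): t = -2.0, d = fabs(-2.5 - (-2.0)) = 0.5;
// d >= 0.5, so o = copysign(1.0, -2.5) = -1.0 and the result is
// t + o = -3.0, i.e. ties round away from zero.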
7422 
7423 LegalizerHelper::LegalizeResult LegalizerHelper::lowerFFloor(MachineInstr &MI) {
7424   auto [DstReg, SrcReg] = MI.getFirst2Regs();
7425   unsigned Flags = MI.getFlags();
7426   LLT Ty = MRI.getType(DstReg);
7427   const LLT CondTy = Ty.changeElementSize(1);
7428 
7429   // result = trunc(src);
7430   // if (src < 0.0 && src != result)
7431   //   result += -1.0.
7432 
7433   auto Trunc = MIRBuilder.buildIntrinsicTrunc(Ty, SrcReg, Flags);
7434   auto Zero = MIRBuilder.buildFConstant(Ty, 0.0);
7435 
7436   auto Lt0 = MIRBuilder.buildFCmp(CmpInst::FCMP_OLT, CondTy,
7437                                   SrcReg, Zero, Flags);
7438   auto NeTrunc = MIRBuilder.buildFCmp(CmpInst::FCMP_ONE, CondTy,
7439                                       SrcReg, Trunc, Flags);
7440   auto And = MIRBuilder.buildAnd(CondTy, Lt0, NeTrunc);
7441   auto AddVal = MIRBuilder.buildSITOFP(Ty, And);
7442 
7443   MIRBuilder.buildFAdd(DstReg, Trunc, AddVal, Flags);
7444   MI.eraseFromParent();
7445   return Legalized;
7446 }
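
// For illustration, floor(-1.25): Trunc = -1.0; -1.25 < 0.0 and
// -1.25 != -1.0, so And is true. G_SITOFP of the s1 value true (-1 as a
// signed integer) is -1.0, giving -1.0 + -1.0 = -2.0.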
7447 
7448 LegalizerHelper::LegalizeResult
7449 LegalizerHelper::lowerMergeValues(MachineInstr &MI) {
7450   const unsigned NumOps = MI.getNumOperands();
7451   auto [DstReg, DstTy, Src0Reg, Src0Ty] = MI.getFirst2RegLLTs();
7452   unsigned PartSize = Src0Ty.getSizeInBits();
7453 
7454   LLT WideTy = LLT::scalar(DstTy.getSizeInBits());
7455   Register ResultReg = MIRBuilder.buildZExt(WideTy, Src0Reg).getReg(0);
7456 
7457   for (unsigned I = 2; I != NumOps; ++I) {
7458     const unsigned Offset = (I - 1) * PartSize;
7459 
7460     Register SrcReg = MI.getOperand(I).getReg();
7461     auto ZextInput = MIRBuilder.buildZExt(WideTy, SrcReg);
7462 
7463     Register NextResult = I + 1 == NumOps && WideTy == DstTy ? DstReg :
7464       MRI.createGenericVirtualRegister(WideTy);
7465 
7466     auto ShiftAmt = MIRBuilder.buildConstant(WideTy, Offset);
7467     auto Shl = MIRBuilder.buildShl(WideTy, ZextInput, ShiftAmt);
7468     MIRBuilder.buildOr(NextResult, ResultReg, Shl);
7469     ResultReg = NextResult;
7470   }
7471 
7472   if (DstTy.isPointer()) {
7473     if (MIRBuilder.getDataLayout().isNonIntegralAddressSpace(
7474           DstTy.getAddressSpace())) {
7475       LLVM_DEBUG(dbgs() << "Not casting nonintegral address space\n");
7476       return UnableToLegalize;
7477     }
7478 
7479     MIRBuilder.buildIntToPtr(DstReg, ResultReg);
7480   }
7481 
7482   MI.eraseFromParent();
7483   return Legalized;
7484 }
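
// For illustration (register names are illustrative),
// %dst(s32) = G_MERGE_VALUES %a(s16), %b(s16) becomes:
//   %r(s32)   = G_ZEXT %a
//   %zb(s32)  = G_ZEXT %b
//   %shl(s32) = G_SHL %zb, 16
//   %dst(s32) = G_OR %r, %shl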
7485 
7486 LegalizerHelper::LegalizeResult
7487 LegalizerHelper::lowerUnmergeValues(MachineInstr &MI) {
7488   const unsigned NumDst = MI.getNumOperands() - 1;
7489   Register SrcReg = MI.getOperand(NumDst).getReg();
7490   Register Dst0Reg = MI.getOperand(0).getReg();
7491   LLT DstTy = MRI.getType(Dst0Reg);
7492   if (DstTy.isPointer())
7493     return UnableToLegalize; // TODO
7494 
7495   SrcReg = coerceToScalar(SrcReg);
7496   if (!SrcReg)
7497     return UnableToLegalize;
7498 
7499   // Expand scalarizing unmerge as bitcast to integer and shift.
7500   LLT IntTy = MRI.getType(SrcReg);
7501 
7502   MIRBuilder.buildTrunc(Dst0Reg, SrcReg);
7503 
7504   const unsigned DstSize = DstTy.getSizeInBits();
7505   unsigned Offset = DstSize;
7506   for (unsigned I = 1; I != NumDst; ++I, Offset += DstSize) {
7507     auto ShiftAmt = MIRBuilder.buildConstant(IntTy, Offset);
7508     auto Shift = MIRBuilder.buildLShr(IntTy, SrcReg, ShiftAmt);
7509     MIRBuilder.buildTrunc(MI.getOperand(I), Shift);
7510   }
7511 
7512   MI.eraseFromParent();
7513   return Legalized;
7514 }
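
// For illustration (register names are illustrative),
// %a(s16), %b(s16) = G_UNMERGE_VALUES %x(s32) becomes:
//   %a(s16)  = G_TRUNC %x
//   %sh(s32) = G_LSHR %x, 16
//   %b(s16)  = G_TRUNC %sh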
7515 
7516 /// Lower a vector extract or insert by writing the vector to a stack temporary
7517 /// and reloading the element or vector.
7518 ///
7519 /// %dst = G_EXTRACT_VECTOR_ELT %vec, %idx
7520 ///  =>
7521 ///  %stack_temp = G_FRAME_INDEX
7522 ///  G_STORE %vec, %stack_temp
7523 ///  %idx = clamp(%idx, %vec.getNumElements())
7524 ///  %element_ptr = G_PTR_ADD %stack_temp, %idx
7525 ///  %dst = G_LOAD %element_ptr
7526 LegalizerHelper::LegalizeResult
7527 LegalizerHelper::lowerExtractInsertVectorElt(MachineInstr &MI) {
7528   Register DstReg = MI.getOperand(0).getReg();
7529   Register SrcVec = MI.getOperand(1).getReg();
7530   Register InsertVal;
7531   if (MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
7532     InsertVal = MI.getOperand(2).getReg();
7533 
7534   Register Idx = MI.getOperand(MI.getNumOperands() - 1).getReg();
7535 
7536   LLT VecTy = MRI.getType(SrcVec);
7537   LLT EltTy = VecTy.getElementType();
7538   unsigned NumElts = VecTy.getNumElements();
7539 
7540   int64_t IdxVal;
7541   if (mi_match(Idx, MRI, m_ICst(IdxVal)) && IdxVal <= NumElts) {
7542     SmallVector<Register, 8> SrcRegs;
7543     extractParts(SrcVec, EltTy, NumElts, SrcRegs, MIRBuilder, MRI);
7544 
7545     if (InsertVal) {
7546       SrcRegs[IdxVal] = MI.getOperand(2).getReg();
7547       MIRBuilder.buildMergeLikeInstr(DstReg, SrcRegs);
7548     } else {
7549       MIRBuilder.buildCopy(DstReg, SrcRegs[IdxVal]);
7550     }
7551 
7552     MI.eraseFromParent();
7553     return Legalized;
7554   }
7555 
7556   if (!EltTy.isByteSized()) { // Not implemented.
7557     LLVM_DEBUG(dbgs() << "Can't handle non-byte element vectors yet\n");
7558     return UnableToLegalize;
7559   }
7560 
7561   unsigned EltBytes = EltTy.getSizeInBytes();
7562   Align VecAlign = getStackTemporaryAlignment(VecTy);
7563   Align EltAlign;
7564 
7565   MachinePointerInfo PtrInfo;
7566   auto StackTemp = createStackTemporary(
7567       TypeSize::getFixed(VecTy.getSizeInBytes()), VecAlign, PtrInfo);
7568   MIRBuilder.buildStore(SrcVec, StackTemp, PtrInfo, VecAlign);
7569 
7570   // Get the pointer to the element, and be sure not to hit undefined behavior
7571   // if the index is out of bounds.
7572   Register EltPtr = getVectorElementPointer(StackTemp.getReg(0), VecTy, Idx);
7573 
7574   if (mi_match(Idx, MRI, m_ICst(IdxVal))) {
7575     int64_t Offset = IdxVal * EltBytes;
7576     PtrInfo = PtrInfo.getWithOffset(Offset);
7577     EltAlign = commonAlignment(VecAlign, Offset);
7578   } else {
7579     // We lose information with a variable offset.
7580     EltAlign = getStackTemporaryAlignment(EltTy);
7581     PtrInfo = MachinePointerInfo(MRI.getType(EltPtr).getAddressSpace());
7582   }
7583 
7584   if (InsertVal) {
7585     // Write the inserted element
7586     MIRBuilder.buildStore(InsertVal, EltPtr, PtrInfo, EltAlign);
7587 
7588     // Reload the whole vector.
7589     MIRBuilder.buildLoad(DstReg, StackTemp, PtrInfo, VecAlign);
7590   } else {
7591     MIRBuilder.buildLoad(DstReg, EltPtr, PtrInfo, EltAlign);
7592   }
7593 
7594   MI.eraseFromParent();
7595   return Legalized;
7596 }
7597 
7598 LegalizerHelper::LegalizeResult
7599 LegalizerHelper::lowerShuffleVector(MachineInstr &MI) {
7600   auto [DstReg, DstTy, Src0Reg, Src0Ty, Src1Reg, Src1Ty] =
7601       MI.getFirst3RegLLTs();
7602   LLT IdxTy = LLT::scalar(32);
7603 
7604   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
7605   Register Undef;
7606   SmallVector<Register, 32> BuildVec;
7607   LLT EltTy = DstTy.getScalarType();
7608 
7609   for (int Idx : Mask) {
7610     if (Idx < 0) {
7611       if (!Undef.isValid())
7612         Undef = MIRBuilder.buildUndef(EltTy).getReg(0);
7613       BuildVec.push_back(Undef);
7614       continue;
7615     }
7616 
7617     if (Src0Ty.isScalar()) {
7618       BuildVec.push_back(Idx == 0 ? Src0Reg : Src1Reg);
7619     } else {
7620       int NumElts = Src0Ty.getNumElements();
7621       Register SrcVec = Idx < NumElts ? Src0Reg : Src1Reg;
7622       int ExtractIdx = Idx < NumElts ? Idx : Idx - NumElts;
7623       auto IdxK = MIRBuilder.buildConstant(IdxTy, ExtractIdx);
7624       auto Extract = MIRBuilder.buildExtractVectorElement(EltTy, SrcVec, IdxK);
7625       BuildVec.push_back(Extract.getReg(0));
7626     }
7627   }
7628 
7629   if (DstTy.isScalar())
7630     MIRBuilder.buildCopy(DstReg, BuildVec[0]);
7631   else
7632     MIRBuilder.buildBuildVector(DstReg, BuildVec);
7633   MI.eraseFromParent();
7634   return Legalized;
7635 }
7636 
7637 LegalizerHelper::LegalizeResult
7638 LegalizerHelper::lowerVECTOR_COMPRESS(llvm::MachineInstr &MI) {
7639   auto [Dst, DstTy, Vec, VecTy, Mask, MaskTy, Passthru, PassthruTy] =
7640       MI.getFirst4RegLLTs();
7641 
7642   if (VecTy.isScalableVector())
7643     report_fatal_error("Cannot expand masked_compress for scalable vectors.");
7644 
7645   Align VecAlign = getStackTemporaryAlignment(VecTy);
7646   MachinePointerInfo PtrInfo;
7647   Register StackPtr =
7648       createStackTemporary(TypeSize::getFixed(VecTy.getSizeInBytes()), VecAlign,
7649                            PtrInfo)
7650           .getReg(0);
7651   MachinePointerInfo ValPtrInfo =
7652       MachinePointerInfo::getUnknownStack(*MI.getMF());
7653 
7654   LLT IdxTy = LLT::scalar(32);
7655   LLT ValTy = VecTy.getElementType();
7656   Align ValAlign = getStackTemporaryAlignment(ValTy);
7657 
7658   auto OutPos = MIRBuilder.buildConstant(IdxTy, 0);
7659 
7660   bool HasPassthru =
7661       MRI.getVRegDef(Passthru)->getOpcode() != TargetOpcode::G_IMPLICIT_DEF;
7662 
7663   if (HasPassthru)
7664     MIRBuilder.buildStore(Passthru, StackPtr, PtrInfo, VecAlign);
7665 
7666   Register LastWriteVal;
7667   std::optional<APInt> PassthruSplatVal =
7668       isConstantOrConstantSplatVector(*MRI.getVRegDef(Passthru), MRI);
7669 
7670   if (PassthruSplatVal.has_value()) {
7671     LastWriteVal =
7672         MIRBuilder.buildConstant(ValTy, PassthruSplatVal.value()).getReg(0);
7673   } else if (HasPassthru) {
7674     auto Popcount = MIRBuilder.buildZExt(MaskTy.changeElementSize(32), Mask);
7675     Popcount = MIRBuilder.buildInstr(TargetOpcode::G_VECREDUCE_ADD,
7676                                      {LLT::scalar(32)}, {Popcount});
7677 
7678     Register LastElmtPtr =
7679         getVectorElementPointer(StackPtr, VecTy, Popcount.getReg(0));
7680     LastWriteVal =
7681         MIRBuilder.buildLoad(ValTy, LastElmtPtr, ValPtrInfo, ValAlign)
7682             .getReg(0);
7683   }
7684 
7685   unsigned NumElmts = VecTy.getNumElements();
7686   for (unsigned I = 0; I < NumElmts; ++I) {
7687     auto Idx = MIRBuilder.buildConstant(IdxTy, I);
7688     auto Val = MIRBuilder.buildExtractVectorElement(ValTy, Vec, Idx);
7689     Register ElmtPtr =
7690         getVectorElementPointer(StackPtr, VecTy, OutPos.getReg(0));
7691     MIRBuilder.buildStore(Val, ElmtPtr, ValPtrInfo, ValAlign);
7692 
7693     LLT MaskITy = MaskTy.getElementType();
7694     auto MaskI = MIRBuilder.buildExtractVectorElement(MaskITy, Mask, Idx);
7695     if (MaskITy.getSizeInBits() > 1)
7696       MaskI = MIRBuilder.buildTrunc(LLT::scalar(1), MaskI);
7697 
7698     MaskI = MIRBuilder.buildZExt(IdxTy, MaskI);
7699     OutPos = MIRBuilder.buildAdd(IdxTy, OutPos, MaskI);
7700 
7701     if (HasPassthru && I == NumElmts - 1) {
7702       auto EndOfVector =
7703           MIRBuilder.buildConstant(IdxTy, VecTy.getNumElements() - 1);
7704       auto AllLanesSelected = MIRBuilder.buildICmp(
7705           CmpInst::ICMP_UGT, LLT::scalar(1), OutPos, EndOfVector);
7706       OutPos = MIRBuilder.buildInstr(TargetOpcode::G_UMIN, {IdxTy},
7707                                      {OutPos, EndOfVector});
7708       ElmtPtr = getVectorElementPointer(StackPtr, VecTy, OutPos.getReg(0));
7709 
7710       LastWriteVal =
7711           MIRBuilder.buildSelect(ValTy, AllLanesSelected, Val, LastWriteVal)
7712               .getReg(0);
7713       MIRBuilder.buildStore(LastWriteVal, ElmtPtr, ValPtrInfo, ValAlign);
7714     }
7715   }
7716 
7717   // TODO: Use StackPtr's FrameIndex alignment.
7718   MIRBuilder.buildLoad(Dst, StackPtr, PtrInfo, VecAlign);
7719 
7720   MI.eraseFromParent();
7721   return Legalized;
7722 }
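
// For illustration, with %vec = [a, b, c, d], %mask = [1, 0, 1, 1], and a
// non-undef %passthru: the selected elements are stored contiguously from
// slot 0, giving [a, c, d, p] in the stack slot, where the trailing lane p
// comes from the passthru (stored before the loop, with the last-lane store
// above handling the boundary case).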
7723 
7724 Register LegalizerHelper::getDynStackAllocTargetPtr(Register SPReg,
7725                                                     Register AllocSize,
7726                                                     Align Alignment,
7727                                                     LLT PtrTy) {
7728   LLT IntPtrTy = LLT::scalar(PtrTy.getSizeInBits());
7729 
7730   auto SPTmp = MIRBuilder.buildCopy(PtrTy, SPReg);
7731   SPTmp = MIRBuilder.buildCast(IntPtrTy, SPTmp);
7732 
7733   // Subtract the final alloc from the SP. We use G_PTRTOINT here so we don't
7734   // have to generate an extra instruction to negate the alloc and then use
7735   // G_PTR_ADD to add the negative offset.
7736   auto Alloc = MIRBuilder.buildSub(IntPtrTy, SPTmp, AllocSize);
7737   if (Alignment > Align(1)) {
7738     APInt AlignMask(IntPtrTy.getSizeInBits(), Alignment.value(), true);
7739     AlignMask.negate();
7740     auto AlignCst = MIRBuilder.buildConstant(IntPtrTy, AlignMask);
7741     Alloc = MIRBuilder.buildAnd(IntPtrTy, Alloc, AlignCst);
7742   }
7743 
7744   return MIRBuilder.buildCast(PtrTy, Alloc).getReg(0);
7745 }
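
// For illustration, with SP = 0x1000, AllocSize = 0x18, and Alignment = 16:
// Alloc = 0x1000 - 0x18 = 0xfe8, and 0xfe8 & ~0xf = 0xfe0, so the returned
// pointer is rounded down to the requested alignment.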
7746 
7747 LegalizerHelper::LegalizeResult
7748 LegalizerHelper::lowerDynStackAlloc(MachineInstr &MI) {
7749   const auto &MF = *MI.getMF();
7750   const auto &TFI = *MF.getSubtarget().getFrameLowering();
7751   if (TFI.getStackGrowthDirection() == TargetFrameLowering::StackGrowsUp)
7752     return UnableToLegalize;
7753 
7754   Register Dst = MI.getOperand(0).getReg();
7755   Register AllocSize = MI.getOperand(1).getReg();
7756   Align Alignment = assumeAligned(MI.getOperand(2).getImm());
7757 
7758   LLT PtrTy = MRI.getType(Dst);
7759   Register SPReg = TLI.getStackPointerRegisterToSaveRestore();
7760   Register SPTmp =
7761       getDynStackAllocTargetPtr(SPReg, AllocSize, Alignment, PtrTy);
7762 
7763   MIRBuilder.buildCopy(SPReg, SPTmp);
7764   MIRBuilder.buildCopy(Dst, SPTmp);
7765 
7766   MI.eraseFromParent();
7767   return Legalized;
7768 }
7769 
7770 LegalizerHelper::LegalizeResult
7771 LegalizerHelper::lowerStackSave(MachineInstr &MI) {
7772   Register StackPtr = TLI.getStackPointerRegisterToSaveRestore();
7773   if (!StackPtr)
7774     return UnableToLegalize;
7775 
7776   MIRBuilder.buildCopy(MI.getOperand(0), StackPtr);
7777   MI.eraseFromParent();
7778   return Legalized;
7779 }
7780 
7781 LegalizerHelper::LegalizeResult
7782 LegalizerHelper::lowerStackRestore(MachineInstr &MI) {
7783   Register StackPtr = TLI.getStackPointerRegisterToSaveRestore();
7784   if (!StackPtr)
7785     return UnableToLegalize;
7786 
7787   MIRBuilder.buildCopy(StackPtr, MI.getOperand(0));
7788   MI.eraseFromParent();
7789   return Legalized;
7790 }
7791 
7792 LegalizerHelper::LegalizeResult
7793 LegalizerHelper::lowerExtract(MachineInstr &MI) {
7794   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
7795   unsigned Offset = MI.getOperand(2).getImm();
7796 
7797   // Extract sub-vector or one element
7798   if (SrcTy.isVector()) {
7799     unsigned SrcEltSize = SrcTy.getElementType().getSizeInBits();
7800     unsigned DstSize = DstTy.getSizeInBits();
7801 
7802     if ((Offset % SrcEltSize == 0) && (DstSize % SrcEltSize == 0) &&
7803         (Offset + DstSize <= SrcTy.getSizeInBits())) {
7804       // Unmerge and allow access to each Src element for the artifact combiner.
7805       auto Unmerge = MIRBuilder.buildUnmerge(SrcTy.getElementType(), SrcReg);
7806 
7807       // Take element(s) we need to extract and copy it (merge them).
7808       SmallVector<Register, 8> SubVectorElts;
7809       for (unsigned Idx = Offset / SrcEltSize;
7810            Idx < (Offset + DstSize) / SrcEltSize; ++Idx) {
7811         SubVectorElts.push_back(Unmerge.getReg(Idx));
7812       }
7813       if (SubVectorElts.size() == 1)
7814         MIRBuilder.buildCopy(DstReg, SubVectorElts[0]);
7815       else
7816         MIRBuilder.buildMergeLikeInstr(DstReg, SubVectorElts);
7817 
7818       MI.eraseFromParent();
7819       return Legalized;
7820     }
7821   }
7822 
7823   if (DstTy.isScalar() &&
7824       (SrcTy.isScalar() ||
7825        (SrcTy.isVector() && DstTy == SrcTy.getElementType()))) {
7826     LLT SrcIntTy = SrcTy;
7827     if (!SrcTy.isScalar()) {
7828       SrcIntTy = LLT::scalar(SrcTy.getSizeInBits());
7829       SrcReg = MIRBuilder.buildBitcast(SrcIntTy, SrcReg).getReg(0);
7830     }
7831 
7832     if (Offset == 0)
7833       MIRBuilder.buildTrunc(DstReg, SrcReg);
7834     else {
7835       auto ShiftAmt = MIRBuilder.buildConstant(SrcIntTy, Offset);
7836       auto Shr = MIRBuilder.buildLShr(SrcIntTy, SrcReg, ShiftAmt);
7837       MIRBuilder.buildTrunc(DstReg, Shr);
7838     }
7839 
7840     MI.eraseFromParent();
7841     return Legalized;
7842   }
7843 
7844   return UnableToLegalize;
7845 }
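
// For illustration (scalar path, register names are illustrative),
// %dst(s16) = G_EXTRACT %src(s64), 16 becomes:
//   %sh(s64)  = G_LSHR %src, 16
//   %dst(s16) = G_TRUNC %sh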
7846 
7847 LegalizerHelper::LegalizeResult LegalizerHelper::lowerInsert(MachineInstr &MI) {
7848   auto [Dst, Src, InsertSrc] = MI.getFirst3Regs();
7849   uint64_t Offset = MI.getOperand(3).getImm();
7850 
7851   LLT DstTy = MRI.getType(Src);
7852   LLT InsertTy = MRI.getType(InsertSrc);
7853 
7854   // Insert sub-vector or one element
7855   if (DstTy.isVector() && !InsertTy.isPointer()) {
7856     LLT EltTy = DstTy.getElementType();
7857     unsigned EltSize = EltTy.getSizeInBits();
7858     unsigned InsertSize = InsertTy.getSizeInBits();
7859 
7860     if ((Offset % EltSize == 0) && (InsertSize % EltSize == 0) &&
7861         (Offset + InsertSize <= DstTy.getSizeInBits())) {
7862       auto UnmergeSrc = MIRBuilder.buildUnmerge(EltTy, Src);
7863       SmallVector<Register, 8> DstElts;
7864       unsigned Idx = 0;
7865       // Elements from Src before insert start Offset
7866       for (; Idx < Offset / EltSize; ++Idx) {
7867         DstElts.push_back(UnmergeSrc.getReg(Idx));
7868       }
7869 
7870       // Replace elements in Src with elements from InsertSrc
7871       if (InsertTy.getSizeInBits() > EltSize) {
7872         auto UnmergeInsertSrc = MIRBuilder.buildUnmerge(EltTy, InsertSrc);
7873         for (unsigned i = 0; Idx < (Offset + InsertSize) / EltSize;
7874              ++Idx, ++i) {
7875           DstElts.push_back(UnmergeInsertSrc.getReg(i));
7876         }
7877       } else {
7878         DstElts.push_back(InsertSrc);
7879         ++Idx;
7880       }
7881 
7882       // Remaining elements from Src after insert
7883       for (; Idx < DstTy.getNumElements(); ++Idx) {
7884         DstElts.push_back(UnmergeSrc.getReg(Idx));
7885       }
7886 
7887       MIRBuilder.buildMergeLikeInstr(Dst, DstElts);
7888       MI.eraseFromParent();
7889       return Legalized;
7890     }
7891   }
7892 
7893   if (InsertTy.isVector() ||
7894       (DstTy.isVector() && DstTy.getElementType() != InsertTy))
7895     return UnableToLegalize;
7896 
7897   const DataLayout &DL = MIRBuilder.getDataLayout();
7898   if ((DstTy.isPointer() &&
7899        DL.isNonIntegralAddressSpace(DstTy.getAddressSpace())) ||
7900       (InsertTy.isPointer() &&
7901        DL.isNonIntegralAddressSpace(InsertTy.getAddressSpace()))) {
7902     LLVM_DEBUG(dbgs() << "Not casting non-integral address space integer\n");
7903     return UnableToLegalize;
7904   }
7905 
7906   LLT IntDstTy = DstTy;
7907 
7908   if (!DstTy.isScalar()) {
7909     IntDstTy = LLT::scalar(DstTy.getSizeInBits());
7910     Src = MIRBuilder.buildCast(IntDstTy, Src).getReg(0);
7911   }
7912 
7913   if (!InsertTy.isScalar()) {
7914     const LLT IntInsertTy = LLT::scalar(InsertTy.getSizeInBits());
7915     InsertSrc = MIRBuilder.buildPtrToInt(IntInsertTy, InsertSrc).getReg(0);
7916   }
7917 
7918   Register ExtInsSrc = MIRBuilder.buildZExt(IntDstTy, InsertSrc).getReg(0);
7919   if (Offset != 0) {
7920     auto ShiftAmt = MIRBuilder.buildConstant(IntDstTy, Offset);
7921     ExtInsSrc = MIRBuilder.buildShl(IntDstTy, ExtInsSrc, ShiftAmt).getReg(0);
7922   }
7923 
7924   APInt MaskVal = APInt::getBitsSetWithWrap(
7925       DstTy.getSizeInBits(), Offset + InsertTy.getSizeInBits(), Offset);
7926 
7927   auto Mask = MIRBuilder.buildConstant(IntDstTy, MaskVal);
7928   auto MaskedSrc = MIRBuilder.buildAnd(IntDstTy, Src, Mask);
7929   auto Or = MIRBuilder.buildOr(IntDstTy, MaskedSrc, ExtInsSrc);
7930 
7931   MIRBuilder.buildCast(Dst, Or);
7932   MI.eraseFromParent();
7933   return Legalized;
7934 }
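
// For illustration (scalar path, register names are illustrative),
// %dst(s32) = G_INSERT %src(s32), %ins(s8), 8 becomes, with
// MaskVal = 0xffff00ff:
//   %z(s32)   = G_ZEXT %ins
//   %shl(s32) = G_SHL %z, 8
//   %m(s32)   = G_AND %src, 0xffff00ff
//   %dst(s32) = G_OR %m, %shl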
7935 
7936 LegalizerHelper::LegalizeResult
7937 LegalizerHelper::lowerSADDO_SSUBO(MachineInstr &MI) {
7938   auto [Dst0, Dst0Ty, Dst1, Dst1Ty, LHS, LHSTy, RHS, RHSTy] =
7939       MI.getFirst4RegLLTs();
7940   const bool IsAdd = MI.getOpcode() == TargetOpcode::G_SADDO;
7941 
7942   LLT Ty = Dst0Ty;
7943   LLT BoolTy = Dst1Ty;
7944 
7945   Register NewDst0 = MRI.cloneVirtualRegister(Dst0);
7946 
7947   if (IsAdd)
7948     MIRBuilder.buildAdd(NewDst0, LHS, RHS);
7949   else
7950     MIRBuilder.buildSub(NewDst0, LHS, RHS);
7951 
7952   // TODO: If SADDSAT/SSUBSAT is legal, compare results to detect overflow.
7953 
7954   auto Zero = MIRBuilder.buildConstant(Ty, 0);
7955 
7956   // For an addition, the result should be less than one of the operands (LHS)
7957   // if and only if the other operand (RHS) is negative, otherwise there will
7958   // be overflow.
7959   // For a subtraction, the result should be less than one of the operands
7960   // (LHS) if and only if the other operand (RHS) is (non-zero) positive,
7961   // otherwise there will be overflow.
7962   auto ResultLowerThanLHS =
7963       MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, NewDst0, LHS);
7964   auto ConditionRHS = MIRBuilder.buildICmp(
7965       IsAdd ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGT, BoolTy, RHS, Zero);
7966 
7967   MIRBuilder.buildXor(Dst1, ConditionRHS, ResultLowerThanLHS);
7968 
7969   MIRBuilder.buildCopy(Dst0, NewDst0);
7970   MI.eraseFromParent();
7971 
7972   return Legalized;
7973 }
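
// For illustration, 8-bit G_SADDO with LHS = 100, RHS = 50: the add wraps to
// -106. ResultLowerThanLHS (-106 < 100) is true while ConditionRHS (50 < 0)
// is false, so their XOR reports overflow; with RHS = -50 both are true and
// no overflow is reported.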
7974 
7975 LegalizerHelper::LegalizeResult
7976 LegalizerHelper::lowerAddSubSatToMinMax(MachineInstr &MI) {
7977   auto [Res, LHS, RHS] = MI.getFirst3Regs();
7978   LLT Ty = MRI.getType(Res);
7979   bool IsSigned;
7980   bool IsAdd;
7981   unsigned BaseOp;
7982   switch (MI.getOpcode()) {
7983   default:
7984     llvm_unreachable("unexpected addsat/subsat opcode");
7985   case TargetOpcode::G_UADDSAT:
7986     IsSigned = false;
7987     IsAdd = true;
7988     BaseOp = TargetOpcode::G_ADD;
7989     break;
7990   case TargetOpcode::G_SADDSAT:
7991     IsSigned = true;
7992     IsAdd = true;
7993     BaseOp = TargetOpcode::G_ADD;
7994     break;
7995   case TargetOpcode::G_USUBSAT:
7996     IsSigned = false;
7997     IsAdd = false;
7998     BaseOp = TargetOpcode::G_SUB;
7999     break;
8000   case TargetOpcode::G_SSUBSAT:
8001     IsSigned = true;
8002     IsAdd = false;
8003     BaseOp = TargetOpcode::G_SUB;
8004     break;
8005   }
8006 
8007   if (IsSigned) {
8008     // sadd.sat(a, b) ->
8009     //   hi = 0x7fffffff - smax(a, 0)
8010     //   lo = 0x80000000 - smin(a, 0)
8011     //   a + smin(smax(lo, b), hi)
8012     // ssub.sat(a, b) ->
8013     //   lo = smax(a, -1) - 0x7fffffff
8014     //   hi = smin(a, -1) - 0x80000000
8015     //   a - smin(smax(lo, b), hi)
8016     // TODO: AMDGPU can use a "median of 3" instruction here:
8017     //   a +/- med3(lo, b, hi)
8018     uint64_t NumBits = Ty.getScalarSizeInBits();
8019     auto MaxVal =
8020         MIRBuilder.buildConstant(Ty, APInt::getSignedMaxValue(NumBits));
8021     auto MinVal =
8022         MIRBuilder.buildConstant(Ty, APInt::getSignedMinValue(NumBits));
8023     MachineInstrBuilder Hi, Lo;
8024     if (IsAdd) {
8025       auto Zero = MIRBuilder.buildConstant(Ty, 0);
8026       Hi = MIRBuilder.buildSub(Ty, MaxVal, MIRBuilder.buildSMax(Ty, LHS, Zero));
8027       Lo = MIRBuilder.buildSub(Ty, MinVal, MIRBuilder.buildSMin(Ty, LHS, Zero));
8028     } else {
8029       auto NegOne = MIRBuilder.buildConstant(Ty, -1);
8030       Lo = MIRBuilder.buildSub(Ty, MIRBuilder.buildSMax(Ty, LHS, NegOne),
8031                                MaxVal);
8032       Hi = MIRBuilder.buildSub(Ty, MIRBuilder.buildSMin(Ty, LHS, NegOne),
8033                                MinVal);
8034     }
8035     auto RHSClamped =
8036         MIRBuilder.buildSMin(Ty, MIRBuilder.buildSMax(Ty, Lo, RHS), Hi);
8037     MIRBuilder.buildInstr(BaseOp, {Res}, {LHS, RHSClamped});
8038   } else {
8039     // uadd.sat(a, b) -> a + umin(~a, b)
8040     // usub.sat(a, b) -> a - umin(a, b)
8041     Register Not = IsAdd ? MIRBuilder.buildNot(Ty, LHS).getReg(0) : LHS;
8042     auto Min = MIRBuilder.buildUMin(Ty, Not, RHS);
8043     MIRBuilder.buildInstr(BaseOp, {Res}, {LHS, Min});
8044   }
8045 
8046   MI.eraseFromParent();
8047   return Legalized;
8048 }
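
// For illustration, 8-bit sadd.sat(100, 50): hi = 127 - smax(100, 0) = 27 and
// lo = -128 - smin(100, 0) = -128, so RHS is clamped to
// smin(smax(-128, 50), 27) = 27 and the result is 100 + 27 = 127 (saturated).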
8049 
8050 LegalizerHelper::LegalizeResult
8051 LegalizerHelper::lowerAddSubSatToAddoSubo(MachineInstr &MI) {
8052   auto [Res, LHS, RHS] = MI.getFirst3Regs();
8053   LLT Ty = MRI.getType(Res);
8054   LLT BoolTy = Ty.changeElementSize(1);
8055   bool IsSigned;
8056   bool IsAdd;
8057   unsigned OverflowOp;
8058   switch (MI.getOpcode()) {
8059   default:
8060     llvm_unreachable("unexpected addsat/subsat opcode");
8061   case TargetOpcode::G_UADDSAT:
8062     IsSigned = false;
8063     IsAdd = true;
8064     OverflowOp = TargetOpcode::G_UADDO;
8065     break;
8066   case TargetOpcode::G_SADDSAT:
8067     IsSigned = true;
8068     IsAdd = true;
8069     OverflowOp = TargetOpcode::G_SADDO;
8070     break;
8071   case TargetOpcode::G_USUBSAT:
8072     IsSigned = false;
8073     IsAdd = false;
8074     OverflowOp = TargetOpcode::G_USUBO;
8075     break;
8076   case TargetOpcode::G_SSUBSAT:
8077     IsSigned = true;
8078     IsAdd = false;
8079     OverflowOp = TargetOpcode::G_SSUBO;
8080     break;
8081   }
8082 
8083   auto OverflowRes =
8084       MIRBuilder.buildInstr(OverflowOp, {Ty, BoolTy}, {LHS, RHS});
8085   Register Tmp = OverflowRes.getReg(0);
8086   Register Ov = OverflowRes.getReg(1);
8087   MachineInstrBuilder Clamp;
8088   if (IsSigned) {
8089     // sadd.sat(a, b) ->
8090     //   {tmp, ov} = saddo(a, b)
8091     //   ov ? (tmp >>s 31) + 0x80000000 : r
8092     // ssub.sat(a, b) ->
8093     //   {tmp, ov} = ssubo(a, b)
8094     //   ov ? (tmp >>s 31) + 0x80000000 : r
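    // Worked example (illustrative, s32): sadd.sat(0x7fffffff, 1)
    //   tmp wraps to 0x80000000 and ov = 1;
    //   sign = tmp >>s 31 = 0xffffffff (-1);
    //   -1 + 0x80000000 = 0x7fffffff, the correct positive saturation value.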
    uint64_t NumBits = Ty.getScalarSizeInBits();
    auto ShiftAmount = MIRBuilder.buildConstant(Ty, NumBits - 1);
    auto Sign = MIRBuilder.buildAShr(Ty, Tmp, ShiftAmount);
    auto MinVal =
        MIRBuilder.buildConstant(Ty, APInt::getSignedMinValue(NumBits));
    Clamp = MIRBuilder.buildAdd(Ty, Sign, MinVal);
  } else {
    // uadd.sat(a, b) ->
    //   {tmp, ov} = uaddo(a, b)
    //   ov ? 0xffffffff : tmp
    // usub.sat(a, b) ->
    //   {tmp, ov} = usubo(a, b)
    //   ov ? 0 : tmp
    Clamp = MIRBuilder.buildConstant(Ty, IsAdd ? -1 : 0);
  }
  MIRBuilder.buildSelect(Res, Ov, Clamp, Tmp);

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerShlSat(MachineInstr &MI) {
  assert((MI.getOpcode() == TargetOpcode::G_SSHLSAT ||
          MI.getOpcode() == TargetOpcode::G_USHLSAT) &&
         "Expected shlsat opcode!");
  bool IsSigned = MI.getOpcode() == TargetOpcode::G_SSHLSAT;
  auto [Res, LHS, RHS] = MI.getFirst3Regs();
  LLT Ty = MRI.getType(Res);
  LLT BoolTy = Ty.changeElementSize(1);

  unsigned BW = Ty.getScalarSizeInBits();
  auto Result = MIRBuilder.buildShl(Ty, LHS, RHS);
  auto Orig = IsSigned ? MIRBuilder.buildAShr(Ty, Result, RHS)
                       : MIRBuilder.buildLShr(Ty, Result, RHS);
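  // Overflow is detected by shifting the result back and comparing with the
  // original LHS below. Illustrative example (s8, ushl.sat): 0x40 << 2
  // truncates to 0x00; shifting back gives 0x00 != 0x40, so the result
  // saturates to 0xff.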

  MachineInstrBuilder SatVal;
  if (IsSigned) {
    auto SatMin = MIRBuilder.buildConstant(Ty, APInt::getSignedMinValue(BW));
    auto SatMax = MIRBuilder.buildConstant(Ty, APInt::getSignedMaxValue(BW));
    auto Cmp = MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, LHS,
                                    MIRBuilder.buildConstant(Ty, 0));
    SatVal = MIRBuilder.buildSelect(Ty, Cmp, SatMin, SatMax);
  } else {
    SatVal = MIRBuilder.buildConstant(Ty, APInt::getMaxValue(BW));
  }
  auto Ov = MIRBuilder.buildICmp(CmpInst::ICMP_NE, BoolTy, LHS, Orig);
  MIRBuilder.buildSelect(Res, Ov, SatVal, Result);

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerBswap(MachineInstr &MI) {
  auto [Dst, Src] = MI.getFirst2Regs();
  const LLT Ty = MRI.getType(Src);
  unsigned SizeInBytes = (Ty.getScalarSizeInBits() + 7) / 8;
  unsigned BaseShiftAmt = (SizeInBytes - 1) * 8;

  // Swap most and least significant byte, set remaining bytes in Res to zero.
  auto ShiftAmt = MIRBuilder.buildConstant(Ty, BaseShiftAmt);
  auto LSByteShiftedLeft = MIRBuilder.buildShl(Ty, Src, ShiftAmt);
  auto MSByteShiftedRight = MIRBuilder.buildLShr(Ty, Src, ShiftAmt);
  auto Res = MIRBuilder.buildOr(Ty, MSByteShiftedRight, LSByteShiftedLeft);
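  // Illustrative walk-through (s32): for Src = 0xAABBCCDD, Res is now
  // 0xDD0000AA; the loop below ORs in 0x00CC0000 (low middle byte moved up)
  // and 0x0000BB00 (high middle byte moved down), giving 0xDDCCBBAA.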

  // Set i-th high/low byte in Res to i-th low/high byte from Src.
  for (unsigned i = 1; i < SizeInBytes / 2; ++i) {
    // AND with Mask leaves byte i unchanged and sets remaining bytes to 0.
    // Use a 64-bit constant so the shift stays well-defined for byte
    // indices >= 4.
    APInt APMask(SizeInBytes * 8, (uint64_t)0xFF << (i * 8));
    auto Mask = MIRBuilder.buildConstant(Ty, APMask);
    auto ShiftAmt = MIRBuilder.buildConstant(Ty, BaseShiftAmt - 16 * i);
    // Low byte shifted left to place of high byte: (Src & Mask) << ShiftAmt.
    auto LoByte = MIRBuilder.buildAnd(Ty, Src, Mask);
    auto LoShiftedLeft = MIRBuilder.buildShl(Ty, LoByte, ShiftAmt);
    Res = MIRBuilder.buildOr(Ty, Res, LoShiftedLeft);
    // High byte shifted right to place of low byte: (Src >> ShiftAmt) & Mask.
    auto SrcShiftedRight = MIRBuilder.buildLShr(Ty, Src, ShiftAmt);
    auto HiShiftedRight = MIRBuilder.buildAnd(Ty, SrcShiftedRight, Mask);
    Res = MIRBuilder.buildOr(Ty, Res, HiShiftedRight);
  }
  Res.getInstr()->getOperand(0).setReg(Dst);

  MI.eraseFromParent();
  return Legalized;
}

//{ (Src & Mask) >> N } | { (Src << N) & Mask }
static MachineInstrBuilder SwapN(unsigned N, DstOp Dst, MachineIRBuilder &B,
                                 MachineInstrBuilder Src, const APInt &Mask) {
  const LLT Ty = Dst.getLLTTy(*B.getMRI());
  MachineInstrBuilder C_N = B.buildConstant(Ty, N);
  MachineInstrBuilder MaskLoNTo0 = B.buildConstant(Ty, Mask);
  auto LHS = B.buildLShr(Ty, B.buildAnd(Ty, Src, MaskLoNTo0), C_N);
  auto RHS = B.buildAnd(Ty, B.buildShl(Ty, Src, C_N), MaskLoNTo0);
  return B.buildOr(Dst, LHS, RHS);
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerBitreverse(MachineInstr &MI) {
  auto [Dst, Src] = MI.getFirst2Regs();
  const LLT Ty = MRI.getType(Src);
  unsigned Size = Ty.getScalarSizeInBits();

  if (Size >= 8) {
    MachineInstrBuilder BSWAP =
        MIRBuilder.buildInstr(TargetOpcode::G_BSWAP, {Ty}, {Src});

    // swap high and low 4 bits in 8 bit blocks 7654|3210 -> 3210|7654
    //    [(val & 0xF0F0F0F0) >> 4] | [(val & 0x0F0F0F0F) << 4]
    // -> [(val & 0xF0F0F0F0) >> 4] | [(val << 4) & 0xF0F0F0F0]
    MachineInstrBuilder Swap4 =
        SwapN(4, Ty, MIRBuilder, BSWAP, APInt::getSplat(Size, APInt(8, 0xF0)));

    // swap high and low 2 bits in 4 bit blocks 32|10 76|54 -> 10|32 54|76
    //    [(val & 0xCCCCCCCC) >> 2] | [(val & 0x33333333) << 2]
    // -> [(val & 0xCCCCCCCC) >> 2] | [(val << 2) & 0xCCCCCCCC]
    MachineInstrBuilder Swap2 =
        SwapN(2, Ty, MIRBuilder, Swap4, APInt::getSplat(Size, APInt(8, 0xCC)));

    // swap high and low 1 bit in 2 bit blocks 1|0 3|2 5|4 7|6 -> 0|1 2|3 4|5
    // 6|7
    //    [(val & 0xAAAAAAAA) >> 1] | [(val & 0x55555555) << 1]
    // -> [(val & 0xAAAAAAAA) >> 1] | [(val << 1) & 0xAAAAAAAA]
    SwapN(1, Dst, MIRBuilder, Swap2, APInt::getSplat(Size, APInt(8, 0xAA)));
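    // Illustrative trace (s8, where the G_BSWAP is a no-op): Src = 0b10110010
    //   after Swap4: 0b00101011
    //   after Swap2: 0b10001110
    //   after Swap1: 0b01001101, the fully bit-reversed value.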
  } else {
    // Expand bitreverse for types smaller than 8 bits.
    MachineInstrBuilder Tmp;
    for (unsigned I = 0, J = Size - 1; I < Size; ++I, --J) {
      MachineInstrBuilder Tmp2;
      if (I < J) {
        auto ShAmt = MIRBuilder.buildConstant(Ty, J - I);
        Tmp2 = MIRBuilder.buildShl(Ty, Src, ShAmt);
      } else {
        auto ShAmt = MIRBuilder.buildConstant(Ty, I - J);
        Tmp2 = MIRBuilder.buildLShr(Ty, Src, ShAmt);
      }

      auto Mask = MIRBuilder.buildConstant(Ty, 1ULL << J);
      Tmp2 = MIRBuilder.buildAnd(Ty, Tmp2, Mask);
      if (I == 0)
        Tmp = Tmp2;
      else
        Tmp = MIRBuilder.buildOr(Ty, Tmp, Tmp2);
    }
    MIRBuilder.buildCopy(Dst, Tmp);
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerReadWriteRegister(MachineInstr &MI) {
  MachineFunction &MF = MIRBuilder.getMF();

  bool IsRead = MI.getOpcode() == TargetOpcode::G_READ_REGISTER;
  int NameOpIdx = IsRead ? 1 : 0;
  int ValRegIndex = IsRead ? 0 : 1;

  Register ValReg = MI.getOperand(ValRegIndex).getReg();
  const LLT Ty = MRI.getType(ValReg);
  const MDString *RegStr = cast<MDString>(
    cast<MDNode>(MI.getOperand(NameOpIdx).getMetadata())->getOperand(0));
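  // The metadata string is the register name carried over from the IR, e.g.
  // the !"sp" operand of an llvm.read_register call; the target decides below
  // whether that name maps to a physical register.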

  Register PhysReg = TLI.getRegisterByName(RegStr->getString().data(), Ty, MF);
  if (!PhysReg.isValid())
    return UnableToLegalize;

  if (IsRead)
    MIRBuilder.buildCopy(ValReg, PhysReg);
  else
    MIRBuilder.buildCopy(PhysReg, ValReg);

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerSMULH_UMULH(MachineInstr &MI) {
  bool IsSigned = MI.getOpcode() == TargetOpcode::G_SMULH;
  unsigned ExtOp = IsSigned ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
  Register Result = MI.getOperand(0).getReg();
  LLT OrigTy = MRI.getType(Result);
  auto SizeInBits = OrigTy.getScalarSizeInBits();
  LLT WideTy = OrigTy.changeElementSize(SizeInBits * 2);
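  // The high half of the product is computed in a type twice as wide: extend,
  // multiply, shift the product right by the original bit width, truncate.
  // Illustrative example (s32 G_SMULH): smulh(0x80000000, 2)
  // = trunc((-2^31 * 2) >>s 32) = -1.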

  auto LHS = MIRBuilder.buildInstr(ExtOp, {WideTy}, {MI.getOperand(1)});
  auto RHS = MIRBuilder.buildInstr(ExtOp, {WideTy}, {MI.getOperand(2)});
  auto Mul = MIRBuilder.buildMul(WideTy, LHS, RHS);
  unsigned ShiftOp = IsSigned ? TargetOpcode::G_ASHR : TargetOpcode::G_LSHR;

  auto ShiftAmt = MIRBuilder.buildConstant(WideTy, SizeInBits);
  auto Shifted = MIRBuilder.buildInstr(ShiftOp, {WideTy}, {Mul, ShiftAmt});
  MIRBuilder.buildTrunc(Result, Shifted);

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerISFPCLASS(MachineInstr &MI) {
  auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
  FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm());

  if (Mask == fcNone) {
    MIRBuilder.buildConstant(DstReg, 0);
    MI.eraseFromParent();
    return Legalized;
  }
  if (Mask == fcAllFlags) {
    MIRBuilder.buildConstant(DstReg, 1);
    MI.eraseFromParent();
    return Legalized;
  }

  // TODO: Try inverting the test with getInvertedFPClassTest like the DAG
  // version.

  unsigned BitSize = SrcTy.getScalarSizeInBits();
  const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());

  LLT IntTy = LLT::scalar(BitSize);
  if (SrcTy.isVector())
    IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
  auto AsInt = MIRBuilder.buildCopy(IntTy, SrcReg);

  // Various masks.
  APInt SignBit = APInt::getSignMask(BitSize);
  APInt ValueMask = APInt::getSignedMaxValue(BitSize);     // All bits but sign.
  APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
  APInt ExpMask = Inf;
  APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
  APInt QNaNBitMask =
      APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
  APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());
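  // For reference, the f32 values of these masks: SignBit = 0x80000000,
  // ValueMask = 0x7fffffff, Inf/ExpMask = 0x7f800000,
  // AllOneMantissa = 0x007fffff, QNaNBitMask = 0x00400000.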

  auto SignBitC = MIRBuilder.buildConstant(IntTy, SignBit);
  auto ValueMaskC = MIRBuilder.buildConstant(IntTy, ValueMask);
  auto InfC = MIRBuilder.buildConstant(IntTy, Inf);
  auto ExpMaskC = MIRBuilder.buildConstant(IntTy, ExpMask);
  auto ZeroC = MIRBuilder.buildConstant(IntTy, 0);

  auto Abs = MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC);
  auto Sign =
      MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs);

  auto Res = MIRBuilder.buildConstant(DstTy, 0);
  // Clang doesn't support capture of structured bindings:
  LLT DstTyCopy = DstTy;
  const auto appendToRes = [&](MachineInstrBuilder ToAppend) {
    Res = MIRBuilder.buildOr(DstTyCopy, Res, ToAppend);
  };

  // Tests that involve more than one class should be processed first.
  if ((Mask & fcFinite) == fcFinite) {
    // finite(V) ==> abs(V) u< exp_mask
    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
                                     ExpMaskC));
    Mask &= ~fcFinite;
  } else if ((Mask & fcFinite) == fcPosFinite) {
    // finite(V) && V > 0 ==> V u< exp_mask
    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
                                     ExpMaskC));
    Mask &= ~fcPosFinite;
  } else if ((Mask & fcFinite) == fcNegFinite) {
    // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
    auto Cmp = MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
                                    ExpMaskC);
    auto And = MIRBuilder.buildAnd(DstTy, Cmp, Sign);
    appendToRes(And);
    Mask &= ~fcNegFinite;
  }

  if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
    // fcZero | fcSubnormal => test all exponent bits are 0
    // TODO: Handle sign bit specific cases
    // TODO: Handle inverted case
    if (PartialCheck == (fcZero | fcSubnormal)) {
      auto ExpBits = MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC);
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       ExpBits, ZeroC));
      Mask &= ~PartialCheck;
    }
  }

  // Check for individual classes.
  if (FPClassTest PartialCheck = Mask & fcZero) {
    if (PartialCheck == fcPosZero)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, ZeroC));
    else if (PartialCheck == fcZero)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
    else // fcNegZero
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, SignBitC));
  }

  if (FPClassTest PartialCheck = Mask & fcSubnormal) {
    // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
    // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
    auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
    auto OneC = MIRBuilder.buildConstant(IntTy, 1);
    auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
    auto SubnormalRes =
        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
                             MIRBuilder.buildConstant(IntTy, AllOneMantissa));
    if (PartialCheck == fcNegSubnormal)
      SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
    appendToRes(SubnormalRes);
  }

  if (FPClassTest PartialCheck = Mask & fcInf) {
    if (PartialCheck == fcPosInf)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, InfC));
    else if (PartialCheck == fcInf)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
    else { // fcNegInf
      APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
      auto NegInfC = MIRBuilder.buildConstant(IntTy, NegInf);
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, NegInfC));
    }
  }

  if (FPClassTest PartialCheck = Mask & fcNan) {
    auto InfWithQnanBitC = MIRBuilder.buildConstant(IntTy, Inf | QNaNBitMask);
    if (PartialCheck == fcNan) {
      // isnan(V) ==> abs(V) u> int(inf)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
    } else if (PartialCheck == fcQNan) {
      // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
                                       InfWithQnanBitC));
    } else { // fcSNan
      // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
      //                    abs(V) u< (unsigned(Inf) | quiet_bit)
      auto IsNan =
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC);
      auto IsNotQnan = MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy,
                                            Abs, InfWithQnanBitC);
      appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
    }
  }

  if (FPClassTest PartialCheck = Mask & fcNormal) {
    // isnormal(V) ==> (0 u< exp u< max_exp) ==>
    //   (unsigned(exp - 1) u< (max_exp - 1))
    APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
    auto ExpMinusOne = MIRBuilder.buildSub(
        IntTy, Abs, MIRBuilder.buildConstant(IntTy, ExpLSB));
    APInt MaxExpMinusOne = ExpMask - ExpLSB;
    auto NormalRes =
        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
                             MIRBuilder.buildConstant(IntTy, MaxExpMinusOne));
    if (PartialCheck == fcNegNormal)
      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
    else if (PartialCheck == fcPosNormal) {
      auto PosSign = MIRBuilder.buildXor(
          DstTy, Sign, MIRBuilder.buildConstant(DstTy, InversionMask));
      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
    }
    appendToRes(NormalRes);
  }

  MIRBuilder.buildCopy(DstReg, Res);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerSelect(MachineInstr &MI) {
  // Implement G_SELECT in terms of XOR, AND, OR.
  auto [DstReg, DstTy, MaskReg, MaskTy, Op1Reg, Op1Ty, Op2Reg, Op2Ty] =
      MI.getFirst4RegLLTs();

  bool IsEltPtr = DstTy.isPointerOrPointerVector();
  if (IsEltPtr) {
    LLT ScalarPtrTy = LLT::scalar(DstTy.getScalarSizeInBits());
    LLT NewTy = DstTy.changeElementType(ScalarPtrTy);
    Op1Reg = MIRBuilder.buildPtrToInt(NewTy, Op1Reg).getReg(0);
    Op2Reg = MIRBuilder.buildPtrToInt(NewTy, Op2Reg).getReg(0);
    DstTy = NewTy;
  }

  if (MaskTy.isScalar()) {
    // Turn the scalar condition into a vector condition mask if needed.

    Register MaskElt = MaskReg;

    // The condition was potentially zero extended before, but we want a sign
    // extended boolean.
    if (MaskTy != LLT::scalar(1))
      MaskElt = MIRBuilder.buildSExtInReg(MaskTy, MaskElt, 1).getReg(0);

    // Continue the sign extension (or truncate) to match the data type.
    MaskElt =
        MIRBuilder.buildSExtOrTrunc(DstTy.getScalarType(), MaskElt).getReg(0);

    if (DstTy.isVector()) {
      // Generate a vector splat idiom.
      auto ShufSplat = MIRBuilder.buildShuffleSplat(DstTy, MaskElt);
      MaskReg = ShufSplat.getReg(0);
    } else {
      MaskReg = MaskElt;
    }
    MaskTy = DstTy;
  } else if (!DstTy.isVector()) {
    // Cannot handle the case that mask is a vector and dst is a scalar.
    return UnableToLegalize;
  }

  if (MaskTy.getSizeInBits() != DstTy.getSizeInBits()) {
    return UnableToLegalize;
  }
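  // The select is computed as (Op1 & Mask) | (Op2 & ~Mask): each lane of a
  // properly sign-extended mask is either all-ones (takes Op1) or all-zeros
  // (takes Op2).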

  auto NotMask = MIRBuilder.buildNot(MaskTy, MaskReg);
  auto NewOp1 = MIRBuilder.buildAnd(MaskTy, Op1Reg, MaskReg);
  auto NewOp2 = MIRBuilder.buildAnd(MaskTy, Op2Reg, NotMask);
  if (IsEltPtr) {
    auto Or = MIRBuilder.buildOr(DstTy, NewOp1, NewOp2);
    MIRBuilder.buildIntToPtr(DstReg, Or);
  } else {
    MIRBuilder.buildOr(DstReg, NewOp1, NewOp2);
  }
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerDIVREM(MachineInstr &MI) {
  // Split DIVREM into individual instructions.
  unsigned Opcode = MI.getOpcode();

  MIRBuilder.buildInstr(
      Opcode == TargetOpcode::G_SDIVREM ? TargetOpcode::G_SDIV
                                        : TargetOpcode::G_UDIV,
      {MI.getOperand(0).getReg()}, {MI.getOperand(2), MI.getOperand(3)});
  MIRBuilder.buildInstr(
      Opcode == TargetOpcode::G_SDIVREM ? TargetOpcode::G_SREM
                                        : TargetOpcode::G_UREM,
      {MI.getOperand(1).getReg()}, {MI.getOperand(2), MI.getOperand(3)});
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerAbsToAddXor(MachineInstr &MI) {
  // Expand %res = G_ABS %a into:
  // %v1 = G_ASHR %a, scalar_size-1
  // %v2 = G_ADD %a, %v1
  // %res = G_XOR %v2, %v1
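  // Worked example (illustrative, s32): for %a = -5, %v1 = -5 >>s 31 = -1
  // (all ones), %v2 = -5 + -1 = -6, and -6 ^ -1 = 5. For non-negative %a,
  // %v1 = 0 and both the add and the xor are no-ops.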
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  Register OpReg = MI.getOperand(1).getReg();
  auto ShiftAmt =
      MIRBuilder.buildConstant(DstTy, DstTy.getScalarSizeInBits() - 1);
  auto Shift = MIRBuilder.buildAShr(DstTy, OpReg, ShiftAmt);
  auto Add = MIRBuilder.buildAdd(DstTy, OpReg, Shift);
  MIRBuilder.buildXor(MI.getOperand(0).getReg(), Add, Shift);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerAbsToMaxNeg(MachineInstr &MI) {
  // Expand %res = G_ABS %a into:
  // %v1 = G_CONSTANT 0
  // %v2 = G_SUB %v1, %a
  // %res = G_SMAX %a, %v2
  Register SrcReg = MI.getOperand(1).getReg();
  LLT Ty = MRI.getType(SrcReg);
  auto Zero = MIRBuilder.buildConstant(Ty, 0);
  auto Sub = MIRBuilder.buildSub(Ty, Zero, SrcReg);
  MIRBuilder.buildSMax(MI.getOperand(0), SrcReg, Sub);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerAbsToCNeg(MachineInstr &MI) {
  Register SrcReg = MI.getOperand(1).getReg();
  Register DestReg = MI.getOperand(0).getReg();
  LLT Ty = MRI.getType(SrcReg), IType = LLT::scalar(1);
  auto Zero = MIRBuilder.buildConstant(Ty, 0).getReg(0);
  auto Sub = MIRBuilder.buildSub(Ty, Zero, SrcReg).getReg(0);
  auto ICmp = MIRBuilder.buildICmp(CmpInst::ICMP_SGT, IType, SrcReg, Zero);
  MIRBuilder.buildSelect(DestReg, ICmp, SrcReg, Sub);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerVectorReduction(MachineInstr &MI) {
  Register SrcReg = MI.getOperand(1).getReg();
  LLT SrcTy = MRI.getType(SrcReg);
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  // The source could be a scalar if the IR type was <1 x sN>.
  if (SrcTy.isScalar()) {
    if (DstTy.getSizeInBits() > SrcTy.getSizeInBits())
      return UnableToLegalize; // FIXME: handle extension.
    // This can be just a plain copy.
    Observer.changingInstr(MI);
    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::COPY));
    Observer.changedInstr(MI);
    return Legalized;
  }
  return UnableToLegalize;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerVAArg(MachineInstr &MI) {
  MachineFunction &MF = *MI.getMF();
  const DataLayout &DL = MIRBuilder.getDataLayout();
  LLVMContext &Ctx = MF.getFunction().getContext();
  Register ListPtr = MI.getOperand(1).getReg();
  LLT PtrTy = MRI.getType(ListPtr);

  // ListPtr is a pointer to the head of the list. Get the address
  // of the head of the list.
  Align PtrAlignment = DL.getABITypeAlign(getTypeForLLT(PtrTy, Ctx));
  MachineMemOperand *PtrLoadMMO = MF.getMachineMemOperand(
      MachinePointerInfo(), MachineMemOperand::MOLoad, PtrTy, PtrAlignment);
  auto VAList = MIRBuilder.buildLoad(PtrTy, ListPtr, *PtrLoadMMO).getReg(0);

  const Align A(MI.getOperand(2).getImm());
  LLT PtrTyAsScalarTy = LLT::scalar(PtrTy.getSizeInBits());
  if (A > TLI.getMinStackArgumentAlignment()) {
    Register AlignAmt =
        MIRBuilder.buildConstant(PtrTyAsScalarTy, A.value() - 1).getReg(0);
    auto AddDst = MIRBuilder.buildPtrAdd(PtrTy, VAList, AlignAmt);
    auto AndDst = MIRBuilder.buildMaskLowPtrBits(PtrTy, AddDst, Log2(A));
    VAList = AndDst.getReg(0);
  }
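  // This is the usual align-up idiom, i.e.
  // VAList = (VAList + A - 1) & ~(A - 1).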

  // Increment the pointer, VAList, to the next vaarg.
  // The list should be bumped by the size of the element in the current head
  // of the list.
  Register Dst = MI.getOperand(0).getReg();
  LLT LLTTy = MRI.getType(Dst);
  Type *Ty = getTypeForLLT(LLTTy, Ctx);
  auto IncAmt =
      MIRBuilder.buildConstant(PtrTyAsScalarTy, DL.getTypeAllocSize(Ty));
  auto Succ = MIRBuilder.buildPtrAdd(PtrTy, VAList, IncAmt);

  // Store the incremented VAList to the legalized pointer.
  MachineMemOperand *StoreMMO = MF.getMachineMemOperand(
      MachinePointerInfo(), MachineMemOperand::MOStore, PtrTy, PtrAlignment);
  MIRBuilder.buildStore(Succ, ListPtr, *StoreMMO);
  // Load the actual argument out of the pointer VAList.
  Align EltAlignment = DL.getABITypeAlign(Ty);
  MachineMemOperand *EltLoadMMO = MF.getMachineMemOperand(
      MachinePointerInfo(), MachineMemOperand::MOLoad, LLTTy, EltAlignment);
  MIRBuilder.buildLoad(Dst, VAList, *EltLoadMMO);

  MI.eraseFromParent();
  return Legalized;
}

static bool shouldLowerMemFuncForSize(const MachineFunction &MF) {
  // On Darwin, -Os means optimize for size without hurting performance, so
  // only really optimize for size when -Oz (MinSize) is used.
  if (MF.getTarget().getTargetTriple().isOSDarwin())
    return MF.getFunction().hasMinSize();
  return MF.getFunction().hasOptSize();
}

// Returns a list of types to use for memory op lowering in MemOps. A partial
// port of findOptimalMemOpLowering in TargetLowering.
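// Illustrative outcome (assuming, for the example, a target that prefers s64
// and reports fast misaligned accesses): an 11-byte copy becomes [s64, s32]
// with the s32 access overlapping the s64 by one byte; if overlapping is not
// allowed, it becomes [s64, s16, s8] instead.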
static bool findGISelOptimalMemOpLowering(std::vector<LLT> &MemOps,
                                          unsigned Limit, const MemOp &Op,
                                          unsigned DstAS, unsigned SrcAS,
                                          const AttributeList &FuncAttributes,
                                          const TargetLowering &TLI) {
  if (Op.isMemcpyWithFixedDstAlign() && Op.getSrcAlign() < Op.getDstAlign())
    return false;

  LLT Ty = TLI.getOptimalMemOpLLT(Op, FuncAttributes);

  if (Ty == LLT()) {
    // Use the largest scalar type whose alignment constraints are satisfied.
    // We only need to check DstAlign here as SrcAlign is always greater or
    // equal to DstAlign (or zero).
    Ty = LLT::scalar(64);
    if (Op.isFixedDstAlign())
      while (Op.getDstAlign() < Ty.getSizeInBytes() &&
             !TLI.allowsMisalignedMemoryAccesses(Ty, DstAS, Op.getDstAlign()))
        Ty = LLT::scalar(Ty.getSizeInBytes());
    assert(Ty.getSizeInBits() > 0 && "Could not find valid type");
    // FIXME: check for the largest legal type we can load/store to.
  }

  unsigned NumMemOps = 0;
  uint64_t Size = Op.size();
  while (Size) {
    unsigned TySize = Ty.getSizeInBytes();
    while (TySize > Size) {
      // For now, only use non-vector load / stores for the left-over pieces.
      LLT NewTy = Ty;
      // FIXME: check for mem op safety and legality of the types. Not all of
      // SDAGisms map cleanly to GISel concepts.
      if (NewTy.isVector())
        NewTy = NewTy.getSizeInBits() > 64 ? LLT::scalar(64) : LLT::scalar(32);
      NewTy = LLT::scalar(llvm::bit_floor(NewTy.getSizeInBits() - 1));
      unsigned NewTySize = NewTy.getSizeInBytes();
      assert(NewTySize > 0 && "Could not find appropriate type");

      // If the new LLT cannot cover all of the remaining bits, then consider
      // issuing a (or a pair of) unaligned and overlapping load / store.
      unsigned Fast;
      // Need to get a VT equivalent for allowMisalignedMemoryAccesses().
      MVT VT = getMVTForLLT(Ty);
      if (NumMemOps && Op.allowOverlap() && NewTySize < Size &&
          TLI.allowsMisalignedMemoryAccesses(
              VT, DstAS, Op.isFixedDstAlign() ? Op.getDstAlign() : Align(1),
              MachineMemOperand::MONone, &Fast) &&
          Fast)
        TySize = Size;
      else {
        Ty = NewTy;
        TySize = NewTySize;
      }
    }

    if (++NumMemOps > Limit)
      return false;

    MemOps.push_back(Ty);
    Size -= TySize;
  }

  return true;
}

// Get a vectorized representation of the memset value operand, GISel edition.
static Register getMemsetValue(Register Val, LLT Ty, MachineIRBuilder &MIB) {
  MachineRegisterInfo &MRI = *MIB.getMRI();
  unsigned NumBits = Ty.getScalarSizeInBits();
  auto ValVRegAndVal = getIConstantVRegValWithLookThrough(Val, MRI);
  if (!Ty.isVector() && ValVRegAndVal) {
    APInt Scalar = ValVRegAndVal->Value.trunc(8);
    APInt SplatVal = APInt::getSplat(NumBits, Scalar);
    return MIB.buildConstant(Ty, SplatVal).getReg(0);
  }

  // Extend the byte value to the larger type, and then multiply by a magic
  // value 0x010101... in order to replicate it across every byte.
  // Unless it's zero, in which case just emit a larger G_CONSTANT 0.
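  // E.g. splatting the byte 0xAB into an s32 yields
  // 0xAB * 0x01010101 = 0xABABABAB.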
  if (ValVRegAndVal && ValVRegAndVal->Value == 0) {
    return MIB.buildConstant(Ty, 0).getReg(0);
  }

  LLT ExtType = Ty.getScalarType();
  auto ZExt = MIB.buildZExtOrTrunc(ExtType, Val);
  if (NumBits > 8) {
    APInt Magic = APInt::getSplat(NumBits, APInt(8, 0x01));
    auto MagicMI = MIB.buildConstant(ExtType, Magic);
    Val = MIB.buildMul(ExtType, ZExt, MagicMI).getReg(0);
  }

  // For vector types create a G_BUILD_VECTOR.
  if (Ty.isVector())
    Val = MIB.buildSplatBuildVector(Ty, Val).getReg(0);

  return Val;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemset(MachineInstr &MI, Register Dst, Register Val,
                             uint64_t KnownLen, Align Alignment,
                             bool IsVolatile) {
  auto &MF = *MI.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();
  auto &DL = MF.getDataLayout();
  LLVMContext &C = MF.getFunction().getContext();

  assert(KnownLen != 0 && "Have a zero length memset length!");

  bool DstAlignCanChange = false;
  MachineFrameInfo &MFI = MF.getFrameInfo();
  bool OptSize = shouldLowerMemFuncForSize(MF);

  MachineInstr *FIDef = getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Dst, MRI);
  if (FIDef && !MFI.isFixedObjectIndex(FIDef->getOperand(1).getIndex()))
    DstAlignCanChange = true;

  unsigned Limit = TLI.getMaxStoresPerMemset(OptSize);
  std::vector<LLT> MemOps;

  const auto &DstMMO = **MI.memoperands_begin();
  MachinePointerInfo DstPtrInfo = DstMMO.getPointerInfo();

  auto ValVRegAndVal = getIConstantVRegValWithLookThrough(Val, MRI);
  bool IsZeroVal = ValVRegAndVal && ValVRegAndVal->Value == 0;

  if (!findGISelOptimalMemOpLowering(MemOps, Limit,
                                     MemOp::Set(KnownLen, DstAlignCanChange,
                                                Alignment,
                                                /*IsZeroMemset=*/IsZeroVal,
                                                /*IsVolatile=*/IsVolatile),
                                     DstPtrInfo.getAddrSpace(), ~0u,
                                     MF.getFunction().getAttributes(), TLI))
    return UnableToLegalize;

  if (DstAlignCanChange) {
    // Get an estimate of the type from the LLT.
    Type *IRTy = getTypeForLLT(MemOps[0], C);
    Align NewAlign = DL.getABITypeAlign(IRTy);
    if (NewAlign > Alignment) {
      Alignment = NewAlign;
      unsigned FI = FIDef->getOperand(1).getIndex();
      // Give the stack frame object a larger alignment if needed.
      if (MFI.getObjectAlign(FI) < Alignment)
        MFI.setObjectAlignment(FI, Alignment);
    }
  }

  MachineIRBuilder MIB(MI);
  // Find the largest store and generate the bit pattern for it.
  LLT LargestTy = MemOps[0];
  for (unsigned i = 1; i < MemOps.size(); i++)
    if (MemOps[i].getSizeInBits() > LargestTy.getSizeInBits())
      LargestTy = MemOps[i];

  // The memset stored value is always defined as an s8, so in order to make it
  // work with larger store types we need to repeat the bit pattern across the
  // wider type.
  Register MemSetValue = getMemsetValue(Val, LargestTy, MIB);

  if (!MemSetValue)
    return UnableToLegalize;

  // Generate the stores. For each store type in the list, we generate the
  // matching store of that type to the destination address.
  LLT PtrTy = MRI.getType(Dst);
  unsigned DstOff = 0;
  unsigned Size = KnownLen;
  for (unsigned I = 0; I < MemOps.size(); I++) {
    LLT Ty = MemOps[I];
    unsigned TySize = Ty.getSizeInBytes();
    if (TySize > Size) {
      // Issuing an unaligned load / store pair that overlaps with the previous
      // pair. Adjust the offset accordingly.
      assert(I == MemOps.size() - 1 && I != 0);
      DstOff -= TySize - Size;
    }
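    // E.g. for KnownLen = 10 with MemOps [s64, s32], the s32 store's offset
    // is pulled back from 8 to 6 so it covers bytes 6-9, overlapping the s64
    // store by two bytes.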

    // If this store is smaller than the largest store see whether we can get
    // the smaller value for free with a truncate.
    Register Value = MemSetValue;
    if (Ty.getSizeInBits() < LargestTy.getSizeInBits()) {
      MVT VT = getMVTForLLT(Ty);
      MVT LargestVT = getMVTForLLT(LargestTy);
      if (!LargestTy.isVector() && !Ty.isVector() &&
          TLI.isTruncateFree(LargestVT, VT))
        Value = MIB.buildTrunc(Ty, MemSetValue).getReg(0);
      else
        Value = getMemsetValue(Val, Ty, MIB);
      if (!Value)
        return UnableToLegalize;
    }

    auto *StoreMMO = MF.getMachineMemOperand(&DstMMO, DstOff, Ty);

    Register Ptr = Dst;
    if (DstOff != 0) {
      auto Offset =
          MIB.buildConstant(LLT::scalar(PtrTy.getSizeInBits()), DstOff);
      Ptr = MIB.buildPtrAdd(PtrTy, Dst, Offset).getReg(0);
    }

    MIB.buildStore(Value, Ptr, *StoreMMO);
    DstOff += Ty.getSizeInBytes();
    Size -= TySize;
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemcpyInline(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_MEMCPY_INLINE);

  auto [Dst, Src, Len] = MI.getFirst3Regs();

  const auto *MMOIt = MI.memoperands_begin();
  const MachineMemOperand *MemOp = *MMOIt;
  bool IsVolatile = MemOp->isVolatile();

  // See if this is a constant length copy.
  auto LenVRegAndVal = getIConstantVRegValWithLookThrough(Len, MRI);
  // FIXME: support dynamically sized G_MEMCPY_INLINE
  assert(LenVRegAndVal &&
         "inline memcpy with dynamic size is not yet supported");
  uint64_t KnownLen = LenVRegAndVal->Value.getZExtValue();
  if (KnownLen == 0) {
    MI.eraseFromParent();
    return Legalized;
  }

  const auto &DstMMO = **MI.memoperands_begin();
  const auto &SrcMMO = **std::next(MI.memoperands_begin());
  Align DstAlign = DstMMO.getBaseAlign();
  Align SrcAlign = SrcMMO.getBaseAlign();

  return lowerMemcpyInline(MI, Dst, Src, KnownLen, DstAlign, SrcAlign,
                           IsVolatile);
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemcpyInline(MachineInstr &MI, Register Dst, Register Src,
                                   uint64_t KnownLen, Align DstAlign,
                                   Align SrcAlign, bool IsVolatile) {
  assert(MI.getOpcode() == TargetOpcode::G_MEMCPY_INLINE);
  return lowerMemcpy(MI, Dst, Src, KnownLen,
                     std::numeric_limits<uint64_t>::max(), DstAlign, SrcAlign,
                     IsVolatile);
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemcpy(MachineInstr &MI, Register Dst, Register Src,
                             uint64_t KnownLen, uint64_t Limit, Align DstAlign,
                             Align SrcAlign, bool IsVolatile) {
  auto &MF = *MI.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();
  auto &DL = MF.getDataLayout();
  LLVMContext &C = MF.getFunction().getContext();

  assert(KnownLen != 0 && "Have a zero length memcpy length!");

  bool DstAlignCanChange = false;
  MachineFrameInfo &MFI = MF.getFrameInfo();
  Align Alignment = std::min(DstAlign, SrcAlign);

  MachineInstr *FIDef = getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Dst, MRI);
  if (FIDef && !MFI.isFixedObjectIndex(FIDef->getOperand(1).getIndex()))
    DstAlignCanChange = true;

  // FIXME: infer better src pointer alignment like SelectionDAG does here.
  // FIXME: also use the equivalent of isMemSrcFromConstant and alwaysinlining
  // if the memcpy is in a tail call position.

  std::vector<LLT> MemOps;

  const auto &DstMMO = **MI.memoperands_begin();
  const auto &SrcMMO = **std::next(MI.memoperands_begin());
  MachinePointerInfo DstPtrInfo = DstMMO.getPointerInfo();
  MachinePointerInfo SrcPtrInfo = SrcMMO.getPointerInfo();

  if (!findGISelOptimalMemOpLowering(
          MemOps, Limit,
          MemOp::Copy(KnownLen, DstAlignCanChange, Alignment, SrcAlign,
                      IsVolatile),
          DstPtrInfo.getAddrSpace(), SrcPtrInfo.getAddrSpace(),
          MF.getFunction().getAttributes(), TLI))
    return UnableToLegalize;

  if (DstAlignCanChange) {
    // Get an estimate of the type from the LLT.
    Type *IRTy = getTypeForLLT(MemOps[0], C);
    Align NewAlign = DL.getABITypeAlign(IRTy);

    // Don't promote to an alignment that would require dynamic stack
    // realignment.
    const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
    if (!TRI->hasStackRealignment(MF))
      while (NewAlign > Alignment && DL.exceedsNaturalStackAlignment(NewAlign))
        NewAlign = NewAlign.previous();

    if (NewAlign > Alignment) {
      Alignment = NewAlign;
      unsigned FI = FIDef->getOperand(1).getIndex();
      // Give the stack frame object a larger alignment if needed.
      if (MFI.getObjectAlign(FI) < Alignment)
        MFI.setObjectAlignment(FI, Alignment);
    }
  }

  LLVM_DEBUG(dbgs() << "Inlining memcpy: " << MI << " into loads & stores\n");

  MachineIRBuilder MIB(MI);
  // Now we need to emit a pair of loads and stores for each of the types we've
  // collected. I.e. for each type, generate a load from the source pointer of
  // that type width, and then generate a corresponding store to the dest buffer
  // of that value loaded. This can result in a sequence of loads and stores of
  // mixed types, depending on what the target specifies as good types to use.
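  // E.g. a 15-byte copy may become an s64 load/store at offset 0 followed by
  // an s64 load/store at offset 7, the second pair overlapping the first by
  // one byte (assuming the target allows overlapping unaligned accesses).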
  unsigned CurrOffset = 0;
  unsigned Size = KnownLen;
  for (auto CopyTy : MemOps) {
    // Issuing an unaligned load / store pair that overlaps with the previous
    // pair. Adjust the offset accordingly.
    if (CopyTy.getSizeInBytes() > Size)
      CurrOffset -= CopyTy.getSizeInBytes() - Size;

    // Construct MMOs for the accesses.
    auto *LoadMMO =
        MF.getMachineMemOperand(&SrcMMO, CurrOffset, CopyTy.getSizeInBytes());
    auto *StoreMMO =
        MF.getMachineMemOperand(&DstMMO, CurrOffset, CopyTy.getSizeInBytes());

    // Create the load.
    Register LoadPtr = Src;
    Register Offset;
    if (CurrOffset != 0) {
      LLT SrcTy = MRI.getType(Src);
      Offset = MIB.buildConstant(LLT::scalar(SrcTy.getSizeInBits()), CurrOffset)
                   .getReg(0);
      LoadPtr = MIB.buildPtrAdd(SrcTy, Src, Offset).getReg(0);
    }
    auto LdVal = MIB.buildLoad(CopyTy, LoadPtr, *LoadMMO);

    // Create the store.
    Register StorePtr = Dst;
    if (CurrOffset != 0) {
      LLT DstTy = MRI.getType(Dst);
      StorePtr = MIB.buildPtrAdd(DstTy, Dst, Offset).getReg(0);
    }
    MIB.buildStore(LdVal, StorePtr, *StoreMMO);
    CurrOffset += CopyTy.getSizeInBytes();
    Size -= CopyTy.getSizeInBytes();
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemmove(MachineInstr &MI, Register Dst, Register Src,
                              uint64_t KnownLen, Align DstAlign, Align SrcAlign,
                              bool IsVolatile) {
  auto &MF = *MI.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();
  auto &DL = MF.getDataLayout();
  LLVMContext &C = MF.getFunction().getContext();

  assert(KnownLen != 0 && "Have a zero length memmove length!");

  bool DstAlignCanChange = false;
  MachineFrameInfo &MFI = MF.getFrameInfo();
  bool OptSize = shouldLowerMemFuncForSize(MF);
  Align Alignment = std::min(DstAlign, SrcAlign);

  MachineInstr *FIDef = getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Dst, MRI);
  if (FIDef && !MFI.isFixedObjectIndex(FIDef->getOperand(1).getIndex()))
    DstAlignCanChange = true;

  unsigned Limit = TLI.getMaxStoresPerMemmove(OptSize);
  std::vector<LLT> MemOps;

  const auto &DstMMO = **MI.memoperands_begin();
  const auto &SrcMMO = **std::next(MI.memoperands_begin());
  MachinePointerInfo DstPtrInfo = DstMMO.getPointerInfo();
  MachinePointerInfo SrcPtrInfo = SrcMMO.getPointerInfo();

  // FIXME: SelectionDAG always passes false for 'AllowOverlap', apparently due
  // to a bug in its findOptimalMemOpLowering implementation. For now do the
  // same thing here.
  if (!findGISelOptimalMemOpLowering(
          MemOps, Limit,
          MemOp::Copy(KnownLen, DstAlignCanChange, Alignment, SrcAlign,
                      /*IsVolatile*/ true),
          DstPtrInfo.getAddrSpace(), SrcPtrInfo.getAddrSpace(),
          MF.getFunction().getAttributes(), TLI))
    return UnableToLegalize;

  if (DstAlignCanChange) {
    // Get an estimate of the type from the LLT.
    Type *IRTy = getTypeForLLT(MemOps[0], C);
    Align NewAlign = DL.getABITypeAlign(IRTy);

    // Don't promote to an alignment that would require dynamic stack
    // realignment.
    const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
    if (!TRI->hasStackRealignment(MF))
      while (NewAlign > Alignment && DL.exceedsNaturalStackAlignment(NewAlign))
        NewAlign = NewAlign.previous();

    if (NewAlign > Alignment) {
      Alignment = NewAlign;
      unsigned FI = FIDef->getOperand(1).getIndex();
      // Give the stack frame object a larger alignment if needed.
      if (MFI.getObjectAlign(FI) < Alignment)
        MFI.setObjectAlignment(FI, Alignment);
    }
  }

  LLVM_DEBUG(dbgs() << "Inlining memmove: " << MI << " into loads & stores\n");

  MachineIRBuilder MIB(MI);
  // Memmove requires that we perform the loads first before issuing the stores.
  // Apart from that, this loop is pretty much doing the same thing as the
  // memcpy codegen function.
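  // The loads must all be issued first because the ranges may overlap: if the
  // destination lies within the source range, an eager store could clobber
  // source bytes that have not been read yet.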
  unsigned CurrOffset = 0;
  SmallVector<Register, 16> LoadVals;
  for (auto CopyTy : MemOps) {
    // Construct MMO for the load.
    auto *LoadMMO =
        MF.getMachineMemOperand(&SrcMMO, CurrOffset, CopyTy.getSizeInBytes());

    // Create the load.
    Register LoadPtr = Src;
    if (CurrOffset != 0) {
      LLT SrcTy = MRI.getType(Src);
      auto Offset =
          MIB.buildConstant(LLT::scalar(SrcTy.getSizeInBits()), CurrOffset);
      LoadPtr = MIB.buildPtrAdd(SrcTy, Src, Offset).getReg(0);
    }
    LoadVals.push_back(MIB.buildLoad(CopyTy, LoadPtr, *LoadMMO).getReg(0));
    CurrOffset += CopyTy.getSizeInBytes();
  }

  CurrOffset = 0;
  for (unsigned I = 0; I < MemOps.size(); ++I) {
    LLT CopyTy = MemOps[I];
    // Now store the values loaded.
    auto *StoreMMO =
        MF.getMachineMemOperand(&DstMMO, CurrOffset, CopyTy.getSizeInBytes());

    Register StorePtr = Dst;
    if (CurrOffset != 0) {
      LLT DstTy = MRI.getType(Dst);
      auto Offset =
          MIB.buildConstant(LLT::scalar(DstTy.getSizeInBits()), CurrOffset);
      StorePtr = MIB.buildPtrAdd(DstTy, Dst, Offset).getReg(0);
    }
    MIB.buildStore(LoadVals[I], StorePtr, *StoreMMO);
    CurrOffset += CopyTy.getSizeInBytes();
  }
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemCpyFamily(MachineInstr &MI, unsigned MaxLen) {
  const unsigned Opc = MI.getOpcode();
  // This combine is fairly complex so it's not written with a separate
  // matcher function.
  assert((Opc == TargetOpcode::G_MEMCPY || Opc == TargetOpcode::G_MEMMOVE ||
          Opc == TargetOpcode::G_MEMSET) &&
         "Expected memcpy like instruction");

  auto MMOIt = MI.memoperands_begin();
  const MachineMemOperand *MemOp = *MMOIt;

  Align DstAlign = MemOp->getBaseAlign();
  Align SrcAlign;
  auto [Dst, Src, Len] = MI.getFirst3Regs();

  if (Opc != TargetOpcode::G_MEMSET) {
    assert(MMOIt != MI.memoperands_end() && "Expected a second MMO on MI");
    MemOp = *(++MMOIt);
    SrcAlign = MemOp->getBaseAlign();
  }

  // See if this is a constant length copy.
  auto LenVRegAndVal = getIConstantVRegValWithLookThrough(Len, MRI);
  if (!LenVRegAndVal)
    return UnableToLegalize;
  uint64_t KnownLen = LenVRegAndVal->Value.getZExtValue();

  if (KnownLen == 0) {
    MI.eraseFromParent();
    return Legalized;
  }

  bool IsVolatile = MemOp->isVolatile();
  if (Opc == TargetOpcode::G_MEMCPY_INLINE)
    return lowerMemcpyInline(MI, Dst, Src, KnownLen, DstAlign, SrcAlign,
                             IsVolatile);

  // Don't try to optimize volatile.
  if (IsVolatile)
    return UnableToLegalize;

  if (MaxLen && KnownLen > MaxLen)
    return UnableToLegalize;

  if (Opc == TargetOpcode::G_MEMCPY) {
    auto &MF = *MI.getParent()->getParent();
    const auto &TLI = *MF.getSubtarget().getTargetLowering();
    bool OptSize = shouldLowerMemFuncForSize(MF);
    uint64_t Limit = TLI.getMaxStoresPerMemcpy(OptSize);
    return lowerMemcpy(MI, Dst, Src, KnownLen, Limit, DstAlign, SrcAlign,
                       IsVolatile);
  }
  if (Opc == TargetOpcode::G_MEMMOVE)
    return lowerMemmove(MI, Dst, Src, KnownLen, DstAlign, SrcAlign, IsVolatile);
  if (Opc == TargetOpcode::G_MEMSET)
    return lowerMemset(MI, Dst, Src, KnownLen, DstAlign, IsVolatile);
  return UnableToLegalize;
}