xref: /freebsd/contrib/llvm-project/llvm/lib/Target/X86/X86FixupVectorConstants.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-- X86FixupVectorConstants.cpp - optimize constant generation  -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file examines all full size vector constant pool loads and attempts to
10 // replace them with smaller constant pool entries, including:
11 // * Converting AVX512 memory-fold instructions to their broadcast-fold form.
12 // * Using vzload scalar loads.
13 // * Broadcasting of full width loads.
14 // * Sign/Zero extension of full width loads.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "X86.h"
19 #include "X86InstrFoldTables.h"
20 #include "X86InstrInfo.h"
21 #include "X86Subtarget.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/CodeGen/MachineConstantPool.h"
24 
25 using namespace llvm;
26 
27 #define DEBUG_TYPE "x86-fixup-vector-constants"
28 
29 STATISTIC(NumInstChanges, "Number of instructions changes");
30 
31 namespace {
32 class X86FixupVectorConstantsPass : public MachineFunctionPass {
33 public:
34   static char ID;
35 
X86FixupVectorConstantsPass()36   X86FixupVectorConstantsPass() : MachineFunctionPass(ID) {}
37 
getPassName() const38   StringRef getPassName() const override {
39     return "X86 Fixup Vector Constants";
40   }
41 
42   bool runOnMachineFunction(MachineFunction &MF) override;
43   bool processInstruction(MachineFunction &MF, MachineBasicBlock &MBB,
44                           MachineInstr &MI);
45 
46   // This pass runs after regalloc and doesn't support VReg operands.
getRequiredProperties() const47   MachineFunctionProperties getRequiredProperties() const override {
48     return MachineFunctionProperties().setNoVRegs();
49   }
50 
51 private:
52   const X86InstrInfo *TII = nullptr;
53   const X86Subtarget *ST = nullptr;
54   const MCSchedModel *SM = nullptr;
55 };
56 } // end anonymous namespace
57 
// Out-of-line definition of the pass's static ID member, which the constructor
// hands to the MachineFunctionPass base class.
char X86FixupVectorConstantsPass::ID = 0;

// Register the pass: DEBUG_TYPE is supplied for both string arguments and both
// boolean flags are false.
INITIALIZE_PASS(X86FixupVectorConstantsPass, DEBUG_TYPE, DEBUG_TYPE, false, false)
61 
62 FunctionPass *llvm::createX86FixupVectorConstants() {
63   return new X86FixupVectorConstantsPass();
64 }
65 
66 /// Normally, we only allow poison in vector splats. However, as this is part
67 /// of the backend, and working with the DAG representation, which currently
68 /// only natively represents undef values, we need to accept undefs here.
getSplatValueAllowUndef(const ConstantVector * C)69 static Constant *getSplatValueAllowUndef(const ConstantVector *C) {
70   Constant *Res = nullptr;
71   for (Value *Op : C->operands()) {
72     Constant *OpC = cast<Constant>(Op);
73     if (isa<UndefValue>(OpC))
74       continue;
75     if (!Res)
76       Res = OpC;
77     else if (Res != OpC)
78       return nullptr;
79   }
80   return Res;
81 }
82 
83 // Attempt to extract the full width of bits data from the constant.
extractConstantBits(const Constant * C)84 static std::optional<APInt> extractConstantBits(const Constant *C) {
85   unsigned NumBits = C->getType()->getPrimitiveSizeInBits();
86 
87   if (isa<UndefValue>(C))
88     return APInt::getZero(NumBits);
89 
90   if (auto *CInt = dyn_cast<ConstantInt>(C)) {
91     if (isa<VectorType>(CInt->getType()))
92       return APInt::getSplat(NumBits, CInt->getValue());
93 
94     return CInt->getValue();
95   }
96 
97   if (auto *CFP = dyn_cast<ConstantFP>(C)) {
98     if (isa<VectorType>(CFP->getType()))
99       return APInt::getSplat(NumBits, CFP->getValue().bitcastToAPInt());
100 
101     return CFP->getValue().bitcastToAPInt();
102   }
103 
104   if (auto *CV = dyn_cast<ConstantVector>(C)) {
105     if (auto *CVSplat = getSplatValueAllowUndef(CV)) {
106       if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
107         assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
108         return APInt::getSplat(NumBits, *Bits);
109       }
110     }
111 
112     APInt Bits = APInt::getZero(NumBits);
113     for (unsigned I = 0, E = CV->getNumOperands(); I != E; ++I) {
114       Constant *Elt = CV->getOperand(I);
115       std::optional<APInt> SubBits = extractConstantBits(Elt);
116       if (!SubBits)
117         return std::nullopt;
118       assert(NumBits == (E * SubBits->getBitWidth()) &&
119              "Illegal vector element size");
120       Bits.insertBits(*SubBits, I * SubBits->getBitWidth());
121     }
122     return Bits;
123   }
124 
125   if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
126     bool IsInteger = CDS->getElementType()->isIntegerTy();
127     bool IsFloat = CDS->getElementType()->isHalfTy() ||
128                    CDS->getElementType()->isBFloatTy() ||
129                    CDS->getElementType()->isFloatTy() ||
130                    CDS->getElementType()->isDoubleTy();
131     if (IsInteger || IsFloat) {
132       APInt Bits = APInt::getZero(NumBits);
133       unsigned EltBits = CDS->getElementType()->getPrimitiveSizeInBits();
134       for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
135         if (IsInteger)
136           Bits.insertBits(CDS->getElementAsAPInt(I), I * EltBits);
137         else
138           Bits.insertBits(CDS->getElementAsAPFloat(I).bitcastToAPInt(),
139                           I * EltBits);
140       }
141       return Bits;
142     }
143   }
144 
145   return std::nullopt;
146 }
147 
extractConstantBits(const Constant * C,unsigned NumBits)148 static std::optional<APInt> extractConstantBits(const Constant *C,
149                                                 unsigned NumBits) {
150   if (std::optional<APInt> Bits = extractConstantBits(C))
151     return Bits->zextOrTrunc(NumBits);
152   return std::nullopt;
153 }
154 
155 // Attempt to compute the splat width of bits data by normalizing the splat to
156 // remove undefs.
getSplatableConstant(const Constant * C,unsigned SplatBitWidth)157 static std::optional<APInt> getSplatableConstant(const Constant *C,
158                                                  unsigned SplatBitWidth) {
159   const Type *Ty = C->getType();
160   assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
161          "Illegal splat width");
162 
163   if (std::optional<APInt> Bits = extractConstantBits(C))
164     if (Bits->isSplat(SplatBitWidth))
165       return Bits->trunc(SplatBitWidth);
166 
167   // Detect general splats with undefs.
168   // TODO: Do we need to handle NumEltsBits > SplatBitWidth splitting?
169   if (auto *CV = dyn_cast<ConstantVector>(C)) {
170     unsigned NumOps = CV->getNumOperands();
171     unsigned NumEltsBits = Ty->getScalarSizeInBits();
172     unsigned NumScaleOps = SplatBitWidth / NumEltsBits;
173     if ((SplatBitWidth % NumEltsBits) == 0) {
174       // Collect the elements and ensure that within the repeated splat sequence
175       // they either match or are undef.
176       SmallVector<Constant *, 16> Sequence(NumScaleOps, nullptr);
177       for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
178         if (Constant *Elt = CV->getAggregateElement(Idx)) {
179           if (isa<UndefValue>(Elt))
180             continue;
181           unsigned SplatIdx = Idx % NumScaleOps;
182           if (!Sequence[SplatIdx] || Sequence[SplatIdx] == Elt) {
183             Sequence[SplatIdx] = Elt;
184             continue;
185           }
186         }
187         return std::nullopt;
188       }
189       // Extract the constant bits forming the splat and insert into the bits
190       // data, leave undef as zero.
191       APInt SplatBits = APInt::getZero(SplatBitWidth);
192       for (unsigned I = 0; I != NumScaleOps; ++I) {
193         if (!Sequence[I])
194           continue;
195         if (std::optional<APInt> Bits = extractConstantBits(Sequence[I])) {
196           SplatBits.insertBits(*Bits, I * Bits->getBitWidth());
197           continue;
198         }
199         return std::nullopt;
200       }
201       return SplatBits;
202     }
203   }
204 
205   return std::nullopt;
206 }
207 
208 // Split raw bits into a constant vector of elements of a specific bit width.
209 // NOTE: We don't always bother converting to scalars if the vector length is 1.
rebuildConstant(LLVMContext & Ctx,Type * SclTy,const APInt & Bits,unsigned NumSclBits)210 static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
211                                  const APInt &Bits, unsigned NumSclBits) {
212   unsigned BitWidth = Bits.getBitWidth();
213 
214   if (NumSclBits == 8) {
215     SmallVector<uint8_t> RawBits;
216     for (unsigned I = 0; I != BitWidth; I += 8)
217       RawBits.push_back(Bits.extractBits(8, I).getZExtValue());
218     return ConstantDataVector::get(Ctx, RawBits);
219   }
220 
221   if (NumSclBits == 16) {
222     SmallVector<uint16_t> RawBits;
223     for (unsigned I = 0; I != BitWidth; I += 16)
224       RawBits.push_back(Bits.extractBits(16, I).getZExtValue());
225     if (SclTy->is16bitFPTy())
226       return ConstantDataVector::getFP(SclTy, RawBits);
227     return ConstantDataVector::get(Ctx, RawBits);
228   }
229 
230   if (NumSclBits == 32) {
231     SmallVector<uint32_t> RawBits;
232     for (unsigned I = 0; I != BitWidth; I += 32)
233       RawBits.push_back(Bits.extractBits(32, I).getZExtValue());
234     if (SclTy->isFloatTy())
235       return ConstantDataVector::getFP(SclTy, RawBits);
236     return ConstantDataVector::get(Ctx, RawBits);
237   }
238 
239   assert(NumSclBits == 64 && "Unhandled vector element width");
240 
241   SmallVector<uint64_t> RawBits;
242   for (unsigned I = 0; I != BitWidth; I += 64)
243     RawBits.push_back(Bits.extractBits(64, I).getZExtValue());
244   if (SclTy->isDoubleTy())
245     return ConstantDataVector::getFP(SclTy, RawBits);
246   return ConstantDataVector::get(Ctx, RawBits);
247 }
248 
249 // Attempt to rebuild a normalized splat vector constant of the requested splat
250 // width, built up of potentially smaller scalar values.
rebuildSplatCst(const Constant * C,unsigned,unsigned,unsigned SplatBitWidth)251 static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
252                                  unsigned /*NumElts*/, unsigned SplatBitWidth) {
253   // TODO: Truncate to NumBits once ConvertToBroadcastAVX512 support this.
254   std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
255   if (!Splat)
256     return nullptr;
257 
258   // Determine scalar size to use for the constant splat vector, clamping as we
259   // might have found a splat smaller than the original constant data.
260   Type *SclTy = C->getType()->getScalarType();
261   unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
262   NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);
263 
264   // Fallback to i64 / double.
265   NumSclBits = (NumSclBits == 8 || NumSclBits == 16 || NumSclBits == 32)
266                    ? NumSclBits
267                    : 64;
268 
269   // Extract per-element bits.
270   return rebuildConstant(C->getContext(), SclTy, *Splat, NumSclBits);
271 }
272 
// Attempt to rebuild the constant as just its low ScalarBitWidth bits, for use
// by a load that reads ScalarBitWidth bits (the fixup's memory width).
//
// This only succeeds when the constant, viewed as NumBits of data, has every
// bit above ScalarBitWidth equal to zero. Returns nullptr otherwise, or when
// NumBits is not larger than ScalarBitWidth.
static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
                                     unsigned /*NumElts*/,
                                     unsigned ScalarBitWidth) {
  Type *SclTy = C->getType()->getScalarType();
  unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
  LLVMContext &Ctx = C->getContext();

  if (NumBits > ScalarBitWidth) {
    // Determine if the upper bits are all zero.
    if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
      if (Bits->countLeadingZeros() >= (NumBits - ScalarBitWidth)) {
        // If the original constant was made of smaller elements, try to retain
        // those types. Only the low ScalarBitWidth bits are rebuilt: the bits
        // above them were just verified to be zero, and rebuilding from the
        // full NumBits would produce a constant as large as the original.
        if (ScalarBitWidth > NumSclBits && (ScalarBitWidth % NumSclBits) == 0)
          return rebuildConstant(Ctx, SclTy, Bits->trunc(ScalarBitWidth),
                                 NumSclBits);

        // Fallback to raw integer bits.
        APInt RawBits = Bits->zextOrTrunc(ScalarBitWidth);
        return ConstantInt::get(Ctx, RawBits);
      }
    }
  }

  return nullptr;
}
298 
// Attempt to rebuild the constant as NumElts narrower elements of
// SrcEltBitWidth bits each, such that sign-extending (IsSExt) or
// zero-extending (!IsSExt) every element back to NumBits / NumElts bits
// reproduces the original data. Returns nullptr if the bits cannot be
// extracted or if any element does not fit in SrcEltBitWidth bits.
static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
                               unsigned NumBits, unsigned NumElts,
                               unsigned SrcEltBitWidth) {
  unsigned DstEltBitWidth = NumBits / NumElts;
  assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
         (DstEltBitWidth % SrcEltBitWidth) == 0 &&
         (DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");

  std::optional<APInt> WideBits = extractConstantBits(C, NumBits);
  if (!WideBits)
    return nullptr;

  assert((WideBits->getBitWidth() / DstEltBitWidth) == NumElts &&
         (WideBits->getBitWidth() % DstEltBitWidth) == 0 &&
         "Unexpected constant extension");

  // Ensure every vector element can be represented by the src bitwidth.
  APInt NarrowBits = APInt::getZero(NumElts * SrcEltBitWidth);
  for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
    APInt WideElt = WideBits->extractBits(DstEltBitWidth, Idx * DstEltBitWidth);
    // Bits needed to represent this element under the requested extension.
    unsigned BitsNeeded =
        IsSExt ? WideElt.getSignificantBits() : WideElt.getActiveBits();
    if (BitsNeeded > SrcEltBitWidth)
      return nullptr;
    NarrowBits.insertBits(WideElt.trunc(SrcEltBitWidth), Idx * SrcEltBitWidth);
  }

  Type *CstTy = C->getType();
  return rebuildConstant(CstTy->getContext(), CstTy->getScalarType(),
                         NarrowBits, SrcEltBitWidth);
}
// Sign-extension flavour of rebuildExtCst.
static Constant *rebuildSExtCst(const Constant *C, unsigned NumBits,
                                unsigned NumElts, unsigned SrcEltBitWidth) {
  const bool IsSExt = true;
  return rebuildExtCst(C, IsSExt, NumBits, NumElts, SrcEltBitWidth);
}
// Zero-extension flavour of rebuildExtCst.
static Constant *rebuildZExtCst(const Constant *C, unsigned NumBits,
                                unsigned NumElts, unsigned SrcEltBitWidth) {
  const bool IsSExt = false;
  return rebuildExtCst(C, IsSExt, NumBits, NumElts, SrcEltBitWidth);
}
337 
processInstruction(MachineFunction & MF,MachineBasicBlock & MBB,MachineInstr & MI)338 bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
339                                                      MachineBasicBlock &MBB,
340                                                      MachineInstr &MI) {
341   unsigned Opc = MI.getOpcode();
342   MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
343   bool HasSSE2 = ST->hasSSE2();
344   bool HasSSE41 = ST->hasSSE41();
345   bool HasAVX2 = ST->hasAVX2();
346   bool HasDQI = ST->hasDQI();
347   bool HasBWI = ST->hasBWI();
348   bool HasVLX = ST->hasVLX();
349   bool MultiDomain = ST->hasAVX512() || ST->hasNoDomainDelayMov();
350   bool OptSize = MF.getFunction().hasOptSize();
351 
352   struct FixupEntry {
353     int Op;
354     int NumCstElts;
355     int MemBitWidth;
356     std::function<Constant *(const Constant *, unsigned, unsigned, unsigned)>
357         RebuildConstant;
358   };
359 
360   auto NewOpcPreferable = [&](const FixupEntry &Fixup,
361                               unsigned RegBitWidth) -> bool {
362     if (SM->hasInstrSchedModel()) {
363       unsigned NewOpc = Fixup.Op;
364       auto *OldDesc = SM->getSchedClassDesc(TII->get(Opc).getSchedClass());
365       auto *NewDesc = SM->getSchedClassDesc(TII->get(NewOpc).getSchedClass());
366       unsigned BitsSaved = RegBitWidth - (Fixup.NumCstElts * Fixup.MemBitWidth);
367 
368       // Compare tput/lat - avoid any regressions, but allow extra cycle of
369       // latency in exchange for each 128-bit (or less) constant pool reduction
370       // (this is a very simple cost:benefit estimate - there will probably be
371       // better ways to calculate this).
372       double OldTput = MCSchedModel::getReciprocalThroughput(*ST, *OldDesc);
373       double NewTput = MCSchedModel::getReciprocalThroughput(*ST, *NewDesc);
374       if (OldTput != NewTput)
375         return NewTput < OldTput;
376 
377       int LatTol = (BitsSaved + 127) / 128;
378       int OldLat = MCSchedModel::computeInstrLatency(*ST, *OldDesc);
379       int NewLat = MCSchedModel::computeInstrLatency(*ST, *NewDesc);
380       if (OldLat != NewLat)
381         return NewLat < (OldLat + LatTol);
382     }
383 
384     // We either were unable to get tput/lat or all values were equal.
385     // Prefer the new opcode for reduced constant pool size.
386     return true;
387   };
388 
389   auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned RegBitWidth,
390                            unsigned OperandNo) {
391 #ifdef EXPENSIVE_CHECKS
392     assert(llvm::is_sorted(Fixups,
393                            [](const FixupEntry &A, const FixupEntry &B) {
394                              return (A.NumCstElts * A.MemBitWidth) <
395                                     (B.NumCstElts * B.MemBitWidth);
396                            }) &&
397            "Constant fixup table not sorted in ascending constant size");
398 #endif
399     assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
400            "Unexpected number of operands!");
401     if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
402       unsigned CstBitWidth = C->getType()->getPrimitiveSizeInBits();
403       RegBitWidth = RegBitWidth ? RegBitWidth : CstBitWidth;
404       for (const FixupEntry &Fixup : Fixups) {
405         // Always uses the smallest possible constant load with opt/minsize,
406         // otherwise use the smallest instruction that doesn't affect
407         // performance.
408         // TODO: If constant has been hoisted from loop, use smallest constant.
409         if (Fixup.Op && (OptSize || NewOpcPreferable(Fixup, RegBitWidth))) {
410           // Construct a suitable constant and adjust the MI to use the new
411           // constant pool entry.
412           if (Constant *NewCst = Fixup.RebuildConstant(
413                   C, RegBitWidth, Fixup.NumCstElts, Fixup.MemBitWidth)) {
414             unsigned NewCPI =
415                 CP->getConstantPoolIndex(NewCst, Align(Fixup.MemBitWidth / 8));
416             MI.setDesc(TII->get(Fixup.Op));
417             MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
418             return true;
419           }
420         }
421       }
422     }
423     return false;
424   };
425 
426   // Attempt to detect a suitable vzload/broadcast/vextload from increasing
427   // constant bitwidths. Prefer vzload/broadcast/vextload for same bitwidth:
428   // - vzload shouldn't ever need a shuffle port to zero the upper elements and
429   // the fp/int domain versions are equally available so we don't introduce a
430   // domain crossing penalty.
431   // - broadcast sometimes need a shuffle port (especially for 8/16-bit
432   // variants), AVX1 only has fp domain broadcasts but AVX2+ have good fp/int
433   // domain equivalents.
434   // - vextload always needs a shuffle port and is only ever int domain.
435   switch (Opc) {
436   /* FP Loads */
437   case X86::MOVAPDrm:
438   case X86::MOVAPSrm:
439   case X86::MOVUPDrm:
440   case X86::MOVUPSrm: {
441     // TODO: SSE3 MOVDDUP Handling
442     FixupEntry Fixups[] = {
443         {X86::MOVSSrm, 1, 32, rebuildZeroUpperCst},
444         {HasSSE2 ? X86::MOVSDrm : 0, 1, 64, rebuildZeroUpperCst}};
445     return FixupConstant(Fixups, 128, 1);
446   }
447   case X86::VMOVAPDrm:
448   case X86::VMOVAPSrm:
449   case X86::VMOVUPDrm:
450   case X86::VMOVUPSrm: {
451     FixupEntry Fixups[] = {
452         {MultiDomain ? X86::VPMOVSXBQrm : 0, 2, 8, rebuildSExtCst},
453         {MultiDomain ? X86::VPMOVZXBQrm : 0, 2, 8, rebuildZExtCst},
454         {X86::VMOVSSrm, 1, 32, rebuildZeroUpperCst},
455         {X86::VBROADCASTSSrm, 1, 32, rebuildSplatCst},
456         {MultiDomain ? X86::VPMOVSXBDrm : 0, 4, 8, rebuildSExtCst},
457         {MultiDomain ? X86::VPMOVZXBDrm : 0, 4, 8, rebuildZExtCst},
458         {MultiDomain ? X86::VPMOVSXWQrm : 0, 2, 16, rebuildSExtCst},
459         {MultiDomain ? X86::VPMOVZXWQrm : 0, 2, 16, rebuildZExtCst},
460         {X86::VMOVSDrm, 1, 64, rebuildZeroUpperCst},
461         {X86::VMOVDDUPrm, 1, 64, rebuildSplatCst},
462         {MultiDomain ? X86::VPMOVSXWDrm : 0, 4, 16, rebuildSExtCst},
463         {MultiDomain ? X86::VPMOVZXWDrm : 0, 4, 16, rebuildZExtCst},
464         {MultiDomain ? X86::VPMOVSXDQrm : 0, 2, 32, rebuildSExtCst},
465         {MultiDomain ? X86::VPMOVZXDQrm : 0, 2, 32, rebuildZExtCst}};
466     return FixupConstant(Fixups, 128, 1);
467   }
468   case X86::VMOVAPDYrm:
469   case X86::VMOVAPSYrm:
470   case X86::VMOVUPDYrm:
471   case X86::VMOVUPSYrm: {
472     FixupEntry Fixups[] = {
473         {X86::VBROADCASTSSYrm, 1, 32, rebuildSplatCst},
474         {HasAVX2 && MultiDomain ? X86::VPMOVSXBQYrm : 0, 4, 8, rebuildSExtCst},
475         {HasAVX2 && MultiDomain ? X86::VPMOVZXBQYrm : 0, 4, 8, rebuildZExtCst},
476         {X86::VBROADCASTSDYrm, 1, 64, rebuildSplatCst},
477         {HasAVX2 && MultiDomain ? X86::VPMOVSXBDYrm : 0, 8, 8, rebuildSExtCst},
478         {HasAVX2 && MultiDomain ? X86::VPMOVZXBDYrm : 0, 8, 8, rebuildZExtCst},
479         {HasAVX2 && MultiDomain ? X86::VPMOVSXWQYrm : 0, 4, 16, rebuildSExtCst},
480         {HasAVX2 && MultiDomain ? X86::VPMOVZXWQYrm : 0, 4, 16, rebuildZExtCst},
481         {X86::VBROADCASTF128rm, 1, 128, rebuildSplatCst},
482         {HasAVX2 && MultiDomain ? X86::VPMOVSXWDYrm : 0, 8, 16, rebuildSExtCst},
483         {HasAVX2 && MultiDomain ? X86::VPMOVZXWDYrm : 0, 8, 16, rebuildZExtCst},
484         {HasAVX2 && MultiDomain ? X86::VPMOVSXDQYrm : 0, 4, 32, rebuildSExtCst},
485         {HasAVX2 && MultiDomain ? X86::VPMOVZXDQYrm : 0, 4, 32,
486          rebuildZExtCst}};
487     return FixupConstant(Fixups, 256, 1);
488   }
489   case X86::VMOVAPDZ128rm:
490   case X86::VMOVAPSZ128rm:
491   case X86::VMOVUPDZ128rm:
492   case X86::VMOVUPSZ128rm: {
493     FixupEntry Fixups[] = {
494         {MultiDomain ? X86::VPMOVSXBQZ128rm : 0, 2, 8, rebuildSExtCst},
495         {MultiDomain ? X86::VPMOVZXBQZ128rm : 0, 2, 8, rebuildZExtCst},
496         {X86::VMOVSSZrm, 1, 32, rebuildZeroUpperCst},
497         {X86::VBROADCASTSSZ128rm, 1, 32, rebuildSplatCst},
498         {MultiDomain ? X86::VPMOVSXBDZ128rm : 0, 4, 8, rebuildSExtCst},
499         {MultiDomain ? X86::VPMOVZXBDZ128rm : 0, 4, 8, rebuildZExtCst},
500         {MultiDomain ? X86::VPMOVSXWQZ128rm : 0, 2, 16, rebuildSExtCst},
501         {MultiDomain ? X86::VPMOVZXWQZ128rm : 0, 2, 16, rebuildZExtCst},
502         {X86::VMOVSDZrm, 1, 64, rebuildZeroUpperCst},
503         {X86::VMOVDDUPZ128rm, 1, 64, rebuildSplatCst},
504         {MultiDomain ? X86::VPMOVSXWDZ128rm : 0, 4, 16, rebuildSExtCst},
505         {MultiDomain ? X86::VPMOVZXWDZ128rm : 0, 4, 16, rebuildZExtCst},
506         {MultiDomain ? X86::VPMOVSXDQZ128rm : 0, 2, 32, rebuildSExtCst},
507         {MultiDomain ? X86::VPMOVZXDQZ128rm : 0, 2, 32, rebuildZExtCst}};
508     return FixupConstant(Fixups, 128, 1);
509   }
510   case X86::VMOVAPDZ256rm:
511   case X86::VMOVAPSZ256rm:
512   case X86::VMOVUPDZ256rm:
513   case X86::VMOVUPSZ256rm: {
514     FixupEntry Fixups[] = {
515         {X86::VBROADCASTSSZ256rm, 1, 32, rebuildSplatCst},
516         {MultiDomain ? X86::VPMOVSXBQZ256rm : 0, 4, 8, rebuildSExtCst},
517         {MultiDomain ? X86::VPMOVZXBQZ256rm : 0, 4, 8, rebuildZExtCst},
518         {X86::VBROADCASTSDZ256rm, 1, 64, rebuildSplatCst},
519         {MultiDomain ? X86::VPMOVSXBDZ256rm : 0, 8, 8, rebuildSExtCst},
520         {MultiDomain ? X86::VPMOVZXBDZ256rm : 0, 8, 8, rebuildZExtCst},
521         {MultiDomain ? X86::VPMOVSXWQZ256rm : 0, 4, 16, rebuildSExtCst},
522         {MultiDomain ? X86::VPMOVZXWQZ256rm : 0, 4, 16, rebuildZExtCst},
523         {X86::VBROADCASTF32X4Z256rm, 1, 128, rebuildSplatCst},
524         {MultiDomain ? X86::VPMOVSXWDZ256rm : 0, 8, 16, rebuildSExtCst},
525         {MultiDomain ? X86::VPMOVZXWDZ256rm : 0, 8, 16, rebuildZExtCst},
526         {MultiDomain ? X86::VPMOVSXDQZ256rm : 0, 4, 32, rebuildSExtCst},
527         {MultiDomain ? X86::VPMOVZXDQZ256rm : 0, 4, 32, rebuildZExtCst}};
528     return FixupConstant(Fixups, 256, 1);
529   }
530   case X86::VMOVAPDZrm:
531   case X86::VMOVAPSZrm:
532   case X86::VMOVUPDZrm:
533   case X86::VMOVUPSZrm: {
534     FixupEntry Fixups[] = {
535         {X86::VBROADCASTSSZrm, 1, 32, rebuildSplatCst},
536         {X86::VBROADCASTSDZrm, 1, 64, rebuildSplatCst},
537         {MultiDomain ? X86::VPMOVSXBQZrm : 0, 8, 8, rebuildSExtCst},
538         {MultiDomain ? X86::VPMOVZXBQZrm : 0, 8, 8, rebuildZExtCst},
539         {X86::VBROADCASTF32X4Zrm, 1, 128, rebuildSplatCst},
540         {MultiDomain ? X86::VPMOVSXBDZrm : 0, 16, 8, rebuildSExtCst},
541         {MultiDomain ? X86::VPMOVZXBDZrm : 0, 16, 8, rebuildZExtCst},
542         {MultiDomain ? X86::VPMOVSXWQZrm : 0, 8, 16, rebuildSExtCst},
543         {MultiDomain ? X86::VPMOVZXWQZrm : 0, 8, 16, rebuildZExtCst},
544         {X86::VBROADCASTF64X4Zrm, 1, 256, rebuildSplatCst},
545         {MultiDomain ? X86::VPMOVSXWDZrm : 0, 16, 16, rebuildSExtCst},
546         {MultiDomain ? X86::VPMOVZXWDZrm : 0, 16, 16, rebuildZExtCst},
547         {MultiDomain ? X86::VPMOVSXDQZrm : 0, 8, 32, rebuildSExtCst},
548         {MultiDomain ? X86::VPMOVZXDQZrm : 0, 8, 32, rebuildZExtCst}};
549     return FixupConstant(Fixups, 512, 1);
550   }
551     /* Integer Loads */
552   case X86::MOVDQArm:
553   case X86::MOVDQUrm: {
554     FixupEntry Fixups[] = {
555         {HasSSE41 ? X86::PMOVSXBQrm : 0, 2, 8, rebuildSExtCst},
556         {HasSSE41 ? X86::PMOVZXBQrm : 0, 2, 8, rebuildZExtCst},
557         {X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
558         {HasSSE41 ? X86::PMOVSXBDrm : 0, 4, 8, rebuildSExtCst},
559         {HasSSE41 ? X86::PMOVZXBDrm : 0, 4, 8, rebuildZExtCst},
560         {HasSSE41 ? X86::PMOVSXWQrm : 0, 2, 16, rebuildSExtCst},
561         {HasSSE41 ? X86::PMOVZXWQrm : 0, 2, 16, rebuildZExtCst},
562         {X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
563         {HasSSE41 ? X86::PMOVSXBWrm : 0, 8, 8, rebuildSExtCst},
564         {HasSSE41 ? X86::PMOVZXBWrm : 0, 8, 8, rebuildZExtCst},
565         {HasSSE41 ? X86::PMOVSXWDrm : 0, 4, 16, rebuildSExtCst},
566         {HasSSE41 ? X86::PMOVZXWDrm : 0, 4, 16, rebuildZExtCst},
567         {HasSSE41 ? X86::PMOVSXDQrm : 0, 2, 32, rebuildSExtCst},
568         {HasSSE41 ? X86::PMOVZXDQrm : 0, 2, 32, rebuildZExtCst}};
569     return FixupConstant(Fixups, 128, 1);
570   }
571   case X86::VMOVDQArm:
572   case X86::VMOVDQUrm: {
573     FixupEntry Fixups[] = {
574         {HasAVX2 ? X86::VPBROADCASTBrm : 0, 1, 8, rebuildSplatCst},
575         {HasAVX2 ? X86::VPBROADCASTWrm : 0, 1, 16, rebuildSplatCst},
576         {X86::VPMOVSXBQrm, 2, 8, rebuildSExtCst},
577         {X86::VPMOVZXBQrm, 2, 8, rebuildZExtCst},
578         {X86::VMOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
579         {HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 1, 32,
580          rebuildSplatCst},
581         {X86::VPMOVSXBDrm, 4, 8, rebuildSExtCst},
582         {X86::VPMOVZXBDrm, 4, 8, rebuildZExtCst},
583         {X86::VPMOVSXWQrm, 2, 16, rebuildSExtCst},
584         {X86::VPMOVZXWQrm, 2, 16, rebuildZExtCst},
585         {X86::VMOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
586         {HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 1, 64,
587          rebuildSplatCst},
588         {X86::VPMOVSXBWrm, 8, 8, rebuildSExtCst},
589         {X86::VPMOVZXBWrm, 8, 8, rebuildZExtCst},
590         {X86::VPMOVSXWDrm, 4, 16, rebuildSExtCst},
591         {X86::VPMOVZXWDrm, 4, 16, rebuildZExtCst},
592         {X86::VPMOVSXDQrm, 2, 32, rebuildSExtCst},
593         {X86::VPMOVZXDQrm, 2, 32, rebuildZExtCst}};
594     return FixupConstant(Fixups, 128, 1);
595   }
596   case X86::VMOVDQAYrm:
597   case X86::VMOVDQUYrm: {
598     FixupEntry Fixups[] = {
599         {HasAVX2 ? X86::VPBROADCASTBYrm : 0, 1, 8, rebuildSplatCst},
600         {HasAVX2 ? X86::VPBROADCASTWYrm : 0, 1, 16, rebuildSplatCst},
601         {HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 1, 32,
602          rebuildSplatCst},
603         {HasAVX2 ? X86::VPMOVSXBQYrm : 0, 4, 8, rebuildSExtCst},
604         {HasAVX2 ? X86::VPMOVZXBQYrm : 0, 4, 8, rebuildZExtCst},
605         {HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 1, 64,
606          rebuildSplatCst},
607         {HasAVX2 ? X86::VPMOVSXBDYrm : 0, 8, 8, rebuildSExtCst},
608         {HasAVX2 ? X86::VPMOVZXBDYrm : 0, 8, 8, rebuildZExtCst},
609         {HasAVX2 ? X86::VPMOVSXWQYrm : 0, 4, 16, rebuildSExtCst},
610         {HasAVX2 ? X86::VPMOVZXWQYrm : 0, 4, 16, rebuildZExtCst},
611         {HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 1, 128,
612          rebuildSplatCst},
613         {HasAVX2 ? X86::VPMOVSXBWYrm : 0, 16, 8, rebuildSExtCst},
614         {HasAVX2 ? X86::VPMOVZXBWYrm : 0, 16, 8, rebuildZExtCst},
615         {HasAVX2 ? X86::VPMOVSXWDYrm : 0, 8, 16, rebuildSExtCst},
616         {HasAVX2 ? X86::VPMOVZXWDYrm : 0, 8, 16, rebuildZExtCst},
617         {HasAVX2 ? X86::VPMOVSXDQYrm : 0, 4, 32, rebuildSExtCst},
618         {HasAVX2 ? X86::VPMOVZXDQYrm : 0, 4, 32, rebuildZExtCst}};
619     return FixupConstant(Fixups, 256, 1);
620   }
621   case X86::VMOVDQA32Z128rm:
622   case X86::VMOVDQA64Z128rm:
623   case X86::VMOVDQU32Z128rm:
624   case X86::VMOVDQU64Z128rm: {
625     FixupEntry Fixups[] = {
626         {HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1, 8, rebuildSplatCst},
627         {HasBWI ? X86::VPBROADCASTWZ128rm : 0, 1, 16, rebuildSplatCst},
628         {X86::VPMOVSXBQZ128rm, 2, 8, rebuildSExtCst},
629         {X86::VPMOVZXBQZ128rm, 2, 8, rebuildZExtCst},
630         {X86::VMOVDI2PDIZrm, 1, 32, rebuildZeroUpperCst},
631         {X86::VPBROADCASTDZ128rm, 1, 32, rebuildSplatCst},
632         {X86::VPMOVSXBDZ128rm, 4, 8, rebuildSExtCst},
633         {X86::VPMOVZXBDZ128rm, 4, 8, rebuildZExtCst},
634         {X86::VPMOVSXWQZ128rm, 2, 16, rebuildSExtCst},
635         {X86::VPMOVZXWQZ128rm, 2, 16, rebuildZExtCst},
636         {X86::VMOVQI2PQIZrm, 1, 64, rebuildZeroUpperCst},
637         {X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst},
638         {HasBWI ? X86::VPMOVSXBWZ128rm : 0, 8, 8, rebuildSExtCst},
639         {HasBWI ? X86::VPMOVZXBWZ128rm : 0, 8, 8, rebuildZExtCst},
640         {X86::VPMOVSXWDZ128rm, 4, 16, rebuildSExtCst},
641         {X86::VPMOVZXWDZ128rm, 4, 16, rebuildZExtCst},
642         {X86::VPMOVSXDQZ128rm, 2, 32, rebuildSExtCst},
643         {X86::VPMOVZXDQZ128rm, 2, 32, rebuildZExtCst}};
644     return FixupConstant(Fixups, 128, 1);
645   }
646   case X86::VMOVDQA32Z256rm:
647   case X86::VMOVDQA64Z256rm:
648   case X86::VMOVDQU32Z256rm:
649   case X86::VMOVDQU64Z256rm: {
650     FixupEntry Fixups[] = {
651         {HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1, 8, rebuildSplatCst},
652         {HasBWI ? X86::VPBROADCASTWZ256rm : 0, 1, 16, rebuildSplatCst},
653         {X86::VPBROADCASTDZ256rm, 1, 32, rebuildSplatCst},
654         {X86::VPMOVSXBQZ256rm, 4, 8, rebuildSExtCst},
655         {X86::VPMOVZXBQZ256rm, 4, 8, rebuildZExtCst},
656         {X86::VPBROADCASTQZ256rm, 1, 64, rebuildSplatCst},
657         {X86::VPMOVSXBDZ256rm, 8, 8, rebuildSExtCst},
658         {X86::VPMOVZXBDZ256rm, 8, 8, rebuildZExtCst},
659         {X86::VPMOVSXWQZ256rm, 4, 16, rebuildSExtCst},
660         {X86::VPMOVZXWQZ256rm, 4, 16, rebuildZExtCst},
661         {X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst},
662         {HasBWI ? X86::VPMOVSXBWZ256rm : 0, 16, 8, rebuildSExtCst},
663         {HasBWI ? X86::VPMOVZXBWZ256rm : 0, 16, 8, rebuildZExtCst},
664         {X86::VPMOVSXWDZ256rm, 8, 16, rebuildSExtCst},
665         {X86::VPMOVZXWDZ256rm, 8, 16, rebuildZExtCst},
666         {X86::VPMOVSXDQZ256rm, 4, 32, rebuildSExtCst},
667         {X86::VPMOVZXDQZ256rm, 4, 32, rebuildZExtCst}};
668     return FixupConstant(Fixups, 256, 1);
669   }
670   case X86::VMOVDQA32Zrm:
671   case X86::VMOVDQA64Zrm:
672   case X86::VMOVDQU32Zrm:
673   case X86::VMOVDQU64Zrm: {
674     FixupEntry Fixups[] = {
675         {HasBWI ? X86::VPBROADCASTBZrm : 0, 1, 8, rebuildSplatCst},
676         {HasBWI ? X86::VPBROADCASTWZrm : 0, 1, 16, rebuildSplatCst},
677         {X86::VPBROADCASTDZrm, 1, 32, rebuildSplatCst},
678         {X86::VPBROADCASTQZrm, 1, 64, rebuildSplatCst},
679         {X86::VPMOVSXBQZrm, 8, 8, rebuildSExtCst},
680         {X86::VPMOVZXBQZrm, 8, 8, rebuildZExtCst},
681         {X86::VBROADCASTI32X4Zrm, 1, 128, rebuildSplatCst},
682         {X86::VPMOVSXBDZrm, 16, 8, rebuildSExtCst},
683         {X86::VPMOVZXBDZrm, 16, 8, rebuildZExtCst},
684         {X86::VPMOVSXWQZrm, 8, 16, rebuildSExtCst},
685         {X86::VPMOVZXWQZrm, 8, 16, rebuildZExtCst},
686         {X86::VBROADCASTI64X4Zrm, 1, 256, rebuildSplatCst},
687         {HasBWI ? X86::VPMOVSXBWZrm : 0, 32, 8, rebuildSExtCst},
688         {HasBWI ? X86::VPMOVZXBWZrm : 0, 32, 8, rebuildZExtCst},
689         {X86::VPMOVSXWDZrm, 16, 16, rebuildSExtCst},
690         {X86::VPMOVZXWDZrm, 16, 16, rebuildZExtCst},
691         {X86::VPMOVSXDQZrm, 8, 32, rebuildSExtCst},
692         {X86::VPMOVZXDQZrm, 8, 32, rebuildZExtCst}};
693     return FixupConstant(Fixups, 512, 1);
694   }
695   }
696 
697   auto ConvertToBroadcast = [&](unsigned OpSrc, int BW) {
698     if (OpSrc) {
699       if (const X86FoldTableEntry *Mem2Bcst =
700               llvm::lookupBroadcastFoldTableBySize(OpSrc, BW)) {
701         unsigned OpBcst = Mem2Bcst->DstOp;
702         unsigned OpNoBcst = Mem2Bcst->Flags & TB_INDEX_MASK;
703         FixupEntry Fixups[] = {{(int)OpBcst, 1, BW, rebuildSplatCst}};
704         // TODO: Add support for RegBitWidth, but currently rebuildSplatCst
705         // doesn't require it (defaults to Constant::getPrimitiveSizeInBits).
706         return FixupConstant(Fixups, 0, OpNoBcst);
707       }
708     }
709     return false;
710   };
711 
712   // Attempt to find an AVX512 mapping from a full width memory-fold instruction
713   // to a broadcast-fold instruction variant.
714   if ((MI.getDesc().TSFlags & X86II::EncodingMask) == X86II::EVEX)
715     return ConvertToBroadcast(Opc, 32) || ConvertToBroadcast(Opc, 64);
716 
717   // Reverse the X86InstrInfo::setExecutionDomainCustom EVEX->VEX logic
718   // conversion to see if we can convert to a broadcasted (integer) logic op.
719   if (HasVLX && !HasDQI) {
720     unsigned OpSrc32 = 0, OpSrc64 = 0;
721     switch (Opc) {
722     case X86::VANDPDrm:
723     case X86::VANDPSrm:
724     case X86::VPANDrm:
725       OpSrc32 = X86 ::VPANDDZ128rm;
726       OpSrc64 = X86 ::VPANDQZ128rm;
727       break;
728     case X86::VANDPDYrm:
729     case X86::VANDPSYrm:
730     case X86::VPANDYrm:
731       OpSrc32 = X86 ::VPANDDZ256rm;
732       OpSrc64 = X86 ::VPANDQZ256rm;
733       break;
734     case X86::VANDNPDrm:
735     case X86::VANDNPSrm:
736     case X86::VPANDNrm:
737       OpSrc32 = X86 ::VPANDNDZ128rm;
738       OpSrc64 = X86 ::VPANDNQZ128rm;
739       break;
740     case X86::VANDNPDYrm:
741     case X86::VANDNPSYrm:
742     case X86::VPANDNYrm:
743       OpSrc32 = X86 ::VPANDNDZ256rm;
744       OpSrc64 = X86 ::VPANDNQZ256rm;
745       break;
746     case X86::VORPDrm:
747     case X86::VORPSrm:
748     case X86::VPORrm:
749       OpSrc32 = X86 ::VPORDZ128rm;
750       OpSrc64 = X86 ::VPORQZ128rm;
751       break;
752     case X86::VORPDYrm:
753     case X86::VORPSYrm:
754     case X86::VPORYrm:
755       OpSrc32 = X86 ::VPORDZ256rm;
756       OpSrc64 = X86 ::VPORQZ256rm;
757       break;
758     case X86::VXORPDrm:
759     case X86::VXORPSrm:
760     case X86::VPXORrm:
761       OpSrc32 = X86 ::VPXORDZ128rm;
762       OpSrc64 = X86 ::VPXORQZ128rm;
763       break;
764     case X86::VXORPDYrm:
765     case X86::VXORPSYrm:
766     case X86::VPXORYrm:
767       OpSrc32 = X86 ::VPXORDZ256rm;
768       OpSrc64 = X86 ::VPXORQZ256rm;
769       break;
770     }
771     if (OpSrc32 || OpSrc64)
772       return ConvertToBroadcast(OpSrc32, 32) || ConvertToBroadcast(OpSrc64, 64);
773   }
774 
775   return false;
776 }
777 
runOnMachineFunction(MachineFunction & MF)778 bool X86FixupVectorConstantsPass::runOnMachineFunction(MachineFunction &MF) {
779   LLVM_DEBUG(dbgs() << "Start X86FixupVectorConstants\n";);
780   bool Changed = false;
781   ST = &MF.getSubtarget<X86Subtarget>();
782   TII = ST->getInstrInfo();
783   SM = &ST->getSchedModel();
784 
785   for (MachineBasicBlock &MBB : MF) {
786     for (MachineInstr &MI : MBB) {
787       if (processInstruction(MF, MBB, MI)) {
788         ++NumInstChanges;
789         Changed = true;
790       }
791     }
792   }
793   LLVM_DEBUG(dbgs() << "End X86FixupVectorConstants\n";);
794   return Changed;
795 }
796