xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/VectorUtils.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===----------- VectorUtils.cpp - Vectorizer utility functions -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines vectorizer utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Analysis/VectorUtils.h"
14 #include "llvm/ADT/EquivalenceClasses.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/Analysis/DemandedBits.h"
17 #include "llvm/Analysis/LoopInfo.h"
18 #include "llvm/Analysis/LoopIterator.h"
19 #include "llvm/Analysis/ScalarEvolution.h"
20 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
21 #include "llvm/Analysis/TargetTransformInfo.h"
22 #include "llvm/Analysis/ValueTracking.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/DerivedTypes.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/MemoryModelRelaxationAnnotations.h"
27 #include "llvm/IR/PatternMatch.h"
28 #include "llvm/IR/Value.h"
29 #include "llvm/Support/CommandLine.h"
30 
31 #define DEBUG_TYPE "vectorutils"
32 
33 using namespace llvm;
34 using namespace llvm::PatternMatch;
35 
36 /// Maximum factor for an interleaved memory access.
37 static cl::opt<unsigned> MaxInterleaveGroupFactor(
38     "max-interleave-group-factor", cl::Hidden,
39     cl::desc("Maximum factor for an interleaved access group (default = 8)"),
40     cl::init(8));
41 
42 /// Return true if all of the intrinsic's arguments and return type are scalars
43 /// for the scalar form of the intrinsic, and vectors for the vector form of the
44 /// intrinsic (except operands that are marked as always being scalar by
45 /// isVectorIntrinsicWithScalarOpAtArg).
46 bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
47   switch (ID) {
48   case Intrinsic::abs:   // Begin integer bit-manipulation.
49   case Intrinsic::bswap:
50   case Intrinsic::bitreverse:
51   case Intrinsic::ctpop:
52   case Intrinsic::ctlz:
53   case Intrinsic::cttz:
54   case Intrinsic::fshl:
55   case Intrinsic::fshr:
56   case Intrinsic::smax:
57   case Intrinsic::smin:
58   case Intrinsic::umax:
59   case Intrinsic::umin:
60   case Intrinsic::sadd_sat:
61   case Intrinsic::ssub_sat:
62   case Intrinsic::uadd_sat:
63   case Intrinsic::usub_sat:
64   case Intrinsic::smul_fix:
65   case Intrinsic::smul_fix_sat:
66   case Intrinsic::umul_fix:
67   case Intrinsic::umul_fix_sat:
68   case Intrinsic::sqrt: // Begin floating-point.
69   case Intrinsic::asin:
70   case Intrinsic::acos:
71   case Intrinsic::atan:
72   case Intrinsic::atan2:
73   case Intrinsic::sin:
74   case Intrinsic::cos:
75   case Intrinsic::sincos:
76   case Intrinsic::sincospi:
77   case Intrinsic::tan:
78   case Intrinsic::sinh:
79   case Intrinsic::cosh:
80   case Intrinsic::tanh:
81   case Intrinsic::exp:
82   case Intrinsic::exp10:
83   case Intrinsic::exp2:
84   case Intrinsic::log:
85   case Intrinsic::log10:
86   case Intrinsic::log2:
87   case Intrinsic::fabs:
88   case Intrinsic::minnum:
89   case Intrinsic::maxnum:
90   case Intrinsic::minimum:
91   case Intrinsic::maximum:
92   case Intrinsic::minimumnum:
93   case Intrinsic::maximumnum:
94   case Intrinsic::modf:
95   case Intrinsic::copysign:
96   case Intrinsic::floor:
97   case Intrinsic::ceil:
98   case Intrinsic::trunc:
99   case Intrinsic::rint:
100   case Intrinsic::nearbyint:
101   case Intrinsic::round:
102   case Intrinsic::roundeven:
103   case Intrinsic::pow:
104   case Intrinsic::fma:
105   case Intrinsic::fmuladd:
106   case Intrinsic::is_fpclass:
107   case Intrinsic::powi:
108   case Intrinsic::canonicalize:
109   case Intrinsic::fptosi_sat:
110   case Intrinsic::fptoui_sat:
111   case Intrinsic::lrint:
112   case Intrinsic::llrint:
113   case Intrinsic::ucmp:
114   case Intrinsic::scmp:
115     return true;
116   default:
117     return false;
118   }
119 }
120 
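/// Return true if the intrinsic can be scalarized element by element: it is
/// either trivially vectorizable (see above), a target intrinsic that TTI
/// reports as trivially scalarizable, or one of the frexp / *.with.overflow
/// intrinsics listed below.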
121 bool llvm::isTriviallyScalarizable(Intrinsic::ID ID,
122                                    const TargetTransformInfo *TTI) {
123   if (isTriviallyVectorizable(ID))
124     return true;
125 
126   if (TTI && Intrinsic::isTargetIntrinsic(ID))
127     return TTI->isTargetIntrinsicTriviallyScalarizable(ID);
128 
129   // TODO: Move frexp to isTriviallyVectorizable.
130   // https://github.com/llvm/llvm-project/issues/112408
131   switch (ID) {
132   case Intrinsic::frexp:
133   case Intrinsic::uadd_with_overflow:
134   case Intrinsic::sadd_with_overflow:
135   case Intrinsic::ssub_with_overflow:
136   case Intrinsic::usub_with_overflow:
137   case Intrinsic::umul_with_overflow:
138   case Intrinsic::smul_with_overflow:
139     return true;
140   }
141   return false;
142 }
143 
144 /// Identifies if the vector form of the intrinsic has a scalar operand.
145 bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
146                                               unsigned ScalarOpdIdx,
147                                               const TargetTransformInfo *TTI) {
148 
149   if (TTI && Intrinsic::isTargetIntrinsic(ID))
150     return TTI->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
151 
152   // Vector predication intrinsics have the EVL as the last operand.
153   if (VPIntrinsic::getVectorLengthParamPos(ID) == ScalarOpdIdx)
154     return true;
155 
156   switch (ID) {
157   case Intrinsic::abs:
158   case Intrinsic::vp_abs:
159   case Intrinsic::ctlz:
160   case Intrinsic::vp_ctlz:
161   case Intrinsic::cttz:
162   case Intrinsic::vp_cttz:
163   case Intrinsic::is_fpclass:
164   case Intrinsic::vp_is_fpclass:
165   case Intrinsic::powi:
166     return (ScalarOpdIdx == 1);
167   case Intrinsic::smul_fix:
168   case Intrinsic::smul_fix_sat:
169   case Intrinsic::umul_fix:
170   case Intrinsic::umul_fix_sat:
171     return (ScalarOpdIdx == 2);
172   case Intrinsic::experimental_vp_splice:
173     return ScalarOpdIdx == 2 || ScalarOpdIdx == 4;
174   default:
175     return false;
176   }
177 }
178 
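/// Identifies if the vector form of the intrinsic is overloaded on the type
/// of the operand at index \p OpdIdx, or on the return type if \p OpdIdx is
/// -1.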
179 bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
180     Intrinsic::ID ID, int OpdIdx, const TargetTransformInfo *TTI) {
181   assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");
182 
183   if (TTI && Intrinsic::isTargetIntrinsic(ID))
184     return TTI->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
185 
186   if (VPCastIntrinsic::isVPCast(ID))
187     return OpdIdx == -1 || OpdIdx == 0;
188 
189   switch (ID) {
190   case Intrinsic::fptosi_sat:
191   case Intrinsic::fptoui_sat:
192   case Intrinsic::lrint:
193   case Intrinsic::llrint:
194   case Intrinsic::vp_lrint:
195   case Intrinsic::vp_llrint:
196   case Intrinsic::ucmp:
197   case Intrinsic::scmp:
198     return OpdIdx == -1 || OpdIdx == 0;
199   case Intrinsic::modf:
200   case Intrinsic::sincos:
201   case Intrinsic::sincospi:
202   case Intrinsic::is_fpclass:
203   case Intrinsic::vp_is_fpclass:
204     return OpdIdx == 0;
205   case Intrinsic::powi:
206     return OpdIdx == -1 || OpdIdx == 1;
207   default:
208     return OpdIdx == -1;
209   }
210 }
211 
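/// Identifies if the vector form of the intrinsic, which returns a struct,
/// is overloaded on the type of the struct element at index \p RetIdx.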
212 bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(
213     Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI) {
214 
215   if (TTI && Intrinsic::isTargetIntrinsic(ID))
216     return TTI->isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
217 
218   switch (ID) {
219   case Intrinsic::frexp:
220     return RetIdx == 0 || RetIdx == 1;
221   default:
222     return RetIdx == 0;
223   }
224 }
225 
226 /// Returns intrinsic ID for call.
227 /// For the given call instruction, find the corresponding intrinsic and return
228 /// its ID; if no mapping is found, return Intrinsic::not_intrinsic.
229 Intrinsic::ID llvm::getVectorIntrinsicIDForCall(const CallInst *CI,
230                                                 const TargetLibraryInfo *TLI) {
231   Intrinsic::ID ID = getIntrinsicForCallSite(*CI, TLI);
232   if (ID == Intrinsic::not_intrinsic)
233     return Intrinsic::not_intrinsic;
234 
235   if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start ||
236       ID == Intrinsic::lifetime_end || ID == Intrinsic::assume ||
237       ID == Intrinsic::experimental_noalias_scope_decl ||
238       ID == Intrinsic::sideeffect || ID == Intrinsic::pseudoprobe)
239     return ID;
240   return Intrinsic::not_intrinsic;
241 }
242 
243 struct InterleaveIntrinsic {
244   Intrinsic::ID Interleave, Deinterleave;
245 };
246 
247 static InterleaveIntrinsic InterleaveIntrinsics[] = {
248     {Intrinsic::vector_interleave2, Intrinsic::vector_deinterleave2},
249     {Intrinsic::vector_interleave3, Intrinsic::vector_deinterleave3},
250     {Intrinsic::vector_interleave4, Intrinsic::vector_deinterleave4},
251     {Intrinsic::vector_interleave5, Intrinsic::vector_deinterleave5},
252     {Intrinsic::vector_interleave6, Intrinsic::vector_deinterleave6},
253     {Intrinsic::vector_interleave7, Intrinsic::vector_deinterleave7},
254     {Intrinsic::vector_interleave8, Intrinsic::vector_deinterleave8},
255 };
256 
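// Factors 2 through 8 map to entries 0 through 6 of the table above, hence
// the "Factor - 2" indexing in the lookups below.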
257 Intrinsic::ID llvm::getInterleaveIntrinsicID(unsigned Factor) {
258   assert(Factor >= 2 && Factor <= 8 && "Unexpected factor");
259   return InterleaveIntrinsics[Factor - 2].Interleave;
260 }
261 
262 Intrinsic::ID llvm::getDeinterleaveIntrinsicID(unsigned Factor) {
263   assert(Factor >= 2 && Factor <= 8 && "Unexpected factor");
264   return InterleaveIntrinsics[Factor - 2].Deinterleave;
265 }
266 
267 unsigned llvm::getInterleaveIntrinsicFactor(Intrinsic::ID ID) {
268   switch (ID) {
269   case Intrinsic::vector_interleave2:
270     return 2;
271   case Intrinsic::vector_interleave3:
272     return 3;
273   case Intrinsic::vector_interleave4:
274     return 4;
275   case Intrinsic::vector_interleave5:
276     return 5;
277   case Intrinsic::vector_interleave6:
278     return 6;
279   case Intrinsic::vector_interleave7:
280     return 7;
281   case Intrinsic::vector_interleave8:
282     return 8;
283   default:
284     return 0;
285   }
286 }
287 
288 unsigned llvm::getDeinterleaveIntrinsicFactor(Intrinsic::ID ID) {
289   switch (ID) {
290   case Intrinsic::vector_deinterleave2:
291     return 2;
292   case Intrinsic::vector_deinterleave3:
293     return 3;
294   case Intrinsic::vector_deinterleave4:
295     return 4;
296   case Intrinsic::vector_deinterleave5:
297     return 5;
298   case Intrinsic::vector_deinterleave6:
299     return 6;
300   case Intrinsic::vector_deinterleave7:
301     return 7;
302   case Intrinsic::vector_deinterleave8:
303     return 8;
304   default:
305     return 0;
306   }
307 }
308 
309 /// Given a vector and an element number, see if the scalar value is
310 /// already around as a register, for example if it were inserted then extracted
311 /// from the vector.
312 Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
313   assert(V->getType()->isVectorTy() && "Not looking at a vector?");
314   VectorType *VTy = cast<VectorType>(V->getType());
315   // For a fixed-length vector, return poison for an out-of-range access.
316   if (auto *FVTy = dyn_cast<FixedVectorType>(VTy)) {
317     unsigned Width = FVTy->getNumElements();
318     if (EltNo >= Width)
319       return PoisonValue::get(FVTy->getElementType());
320   }
321 
322   if (Constant *C = dyn_cast<Constant>(V))
323     return C->getAggregateElement(EltNo);
324 
325   if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
326     // If this is an insert to a variable element, we don't know what it is.
327     if (!isa<ConstantInt>(III->getOperand(2)))
328       return nullptr;
329     unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue();
330 
331     // If this is an insert to the element we are looking for, return the
332     // inserted value.
333     if (EltNo == IIElt)
334       return III->getOperand(1);
335 
336     // Guard against infinite loop on malformed, unreachable IR.
337     if (III == III->getOperand(0))
338       return nullptr;
339 
340     // Otherwise, the insertelement doesn't modify the value, recurse on its
341     // vector input.
342     return findScalarElement(III->getOperand(0), EltNo);
343   }
344 
345   ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V);
346   // Restrict the following transformation to fixed-length vectors.
347   if (SVI && isa<FixedVectorType>(SVI->getType())) {
348     unsigned LHSWidth =
349         cast<FixedVectorType>(SVI->getOperand(0)->getType())->getNumElements();
350     int InEl = SVI->getMaskValue(EltNo);
351     if (InEl < 0)
352       return PoisonValue::get(VTy->getElementType());
353     if (InEl < (int)LHSWidth)
354       return findScalarElement(SVI->getOperand(0), InEl);
355     return findScalarElement(SVI->getOperand(1), InEl - LHSWidth);
356   }
357 
358   // Extract a value from a vector add operation with a constant zero.
359   // TODO: Use getBinOpIdentity() to generalize this.
360   Value *Val; Constant *C;
361   if (match(V, m_Add(m_Value(Val), m_Constant(C))))
362     if (Constant *Elt = C->getAggregateElement(EltNo))
363       if (Elt->isNullValue())
364         return findScalarElement(Val, EltNo);
365 
366   // If the vector is a splat then we can trivially find the scalar element.
367   if (isa<ScalableVectorType>(VTy))
368     if (Value *Splat = getSplatValue(V))
369       if (EltNo < VTy->getElementCount().getKnownMinValue())
370         return Splat;
371 
372   // Otherwise, we don't know.
373   return nullptr;
374 }
375 
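/// If every defined (non-negative) element of \p Mask has the same value,
/// return that value; otherwise return -1. An all-undef mask also returns -1.
/// For example, <3, -1, 3, 3> yields 3, while <0, 1, 0, 1> yields -1.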
376 int llvm::getSplatIndex(ArrayRef<int> Mask) {
377   int SplatIndex = -1;
378   for (int M : Mask) {
379     // Ignore invalid (undefined) mask elements.
380     if (M < 0)
381       continue;
382 
383     // There can be only 1 non-negative mask element value if this is a splat.
384     if (SplatIndex != -1 && SplatIndex != M)
385       return -1;
386 
387     // Initialize the splat index to the 1st non-negative mask element.
388     SplatIndex = M;
389   }
390   assert((SplatIndex == -1 || SplatIndex >= 0) && "Negative index?");
391   return SplatIndex;
392 }
393 
394 /// Get splat value if the input is a splat vector or return nullptr.
395 /// This function is not fully general. It checks only 2 cases:
396 /// the input value is (1) a splat constant vector or (2) a sequence
397 /// of instructions that broadcasts a scalar at element 0.
398 Value *llvm::getSplatValue(const Value *V) {
399   if (isa<VectorType>(V->getType()))
400     if (auto *C = dyn_cast<Constant>(V))
401       return C->getSplatValue();
402 
403   // shuf (inselt ?, Splat, 0), ?, <0, undef, 0, ...>
404   Value *Splat;
405   if (match(V,
406             m_Shuffle(m_InsertElt(m_Value(), m_Value(Splat), m_ZeroInt()),
407                       m_Value(), m_ZeroMask())))
408     return Splat;
409 
410   return nullptr;
411 }
412 
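/// Return true if \p V is known to have all lanes equal (optionally requiring
/// the splat to be defined at lane \p Index). Handles splat constants,
/// shuffles with a uniform mask, and recurses through binary operators and
/// selects up to MaxAnalysisRecursionDepth.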
413 bool llvm::isSplatValue(const Value *V, int Index, unsigned Depth) {
414   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
415 
416   if (isa<VectorType>(V->getType())) {
417     if (isa<UndefValue>(V))
418       return true;
419     // FIXME: We can allow undefs, but if Index was specified, we may want to
420     //        check that the constant is defined at that index.
421     if (auto *C = dyn_cast<Constant>(V))
422       return C->getSplatValue() != nullptr;
423   }
424 
425   if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
426     // FIXME: We can safely allow undefs here. If Index was specified, we will
427     //        check that the mask elt is defined at the required index.
428     if (!all_equal(Shuf->getShuffleMask()))
429       return false;
430 
431     // Match any index.
432     if (Index == -1)
433       return true;
434 
435     // Match a specific element. The mask should be defined at and match the
436     // specified index.
437     return Shuf->getMaskValue(Index) == Index;
438   }
439 
440   // The remaining tests are all recursive, so bail out if we hit the limit.
441   if (Depth++ == MaxAnalysisRecursionDepth)
442     return false;
443 
444   // If both operands of a binop are splats, the result is a splat.
445   Value *X, *Y, *Z;
446   if (match(V, m_BinOp(m_Value(X), m_Value(Y))))
447     return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth);
448 
449   // If all operands of a select are splats, the result is a splat.
450   if (match(V, m_Select(m_Value(X), m_Value(Y), m_Value(Z))))
451     return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth) &&
452            isSplatValue(Z, Index, Depth);
453 
454   // TODO: Add support for unary ops (fneg), casts, intrinsics (overflow ops).
455 
456   return false;
457 }
458 
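/// Given a shuffle mask over two sources of \p SrcWidth elements and the
/// demanded elements of the shuffle result, compute which source elements are
/// demanded in \p DemandedLHS / \p DemandedRHS. For example, with
/// SrcWidth = 4, Mask = <0, 5, 2, 7> and all result elements demanded,
/// DemandedLHS = 0b0101 and DemandedRHS = 0b1010. Returns false if a demanded
/// result element maps to an undef mask entry (unless \p AllowUndefElts).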
459 bool llvm::getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
460                                   const APInt &DemandedElts, APInt &DemandedLHS,
461                                   APInt &DemandedRHS, bool AllowUndefElts) {
462   DemandedLHS = DemandedRHS = APInt::getZero(SrcWidth);
463 
464   // Early out if we don't demand any elements.
465   if (DemandedElts.isZero())
466     return true;
467 
468   // Simple case of a shuffle with zeroinitializer.
469   if (all_of(Mask, [](int Elt) { return Elt == 0; })) {
470     DemandedLHS.setBit(0);
471     return true;
472   }
473 
474   for (unsigned I = 0, E = Mask.size(); I != E; ++I) {
475     int M = Mask[I];
476     assert((-1 <= M) && (M < (SrcWidth * 2)) &&
477            "Invalid shuffle mask constant");
478 
479     if (!DemandedElts[I] || (AllowUndefElts && (M < 0)))
480       continue;
481 
482     // For undef elements, we don't know anything about the common state of
483     // the shuffle result.
484     if (M < 0)
485       return false;
486 
487     if (M < SrcWidth)
488       DemandedLHS.setBit(M);
489     else
490       DemandedRHS.setBit(M - SrcWidth);
491   }
492 
493   return true;
494 }
495 
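/// Try to decompose \p Mask into at most two "slides": every defined element
/// must satisfy i - (Mask[i] % NumElts) == Diff for one of the two
/// (source, Diff) pairs recorded in \p SrcInfo. Returns false for an all-undef
/// mask or when a third distinct pair would be required.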
496 bool llvm::isMaskedSlidePair(ArrayRef<int> Mask, int NumElts,
497                              std::array<std::pair<int, int>, 2> &SrcInfo) {
498   const int SignalValue = NumElts * 2;
499   SrcInfo[0] = {-1, SignalValue};
500   SrcInfo[1] = {-1, SignalValue};
501   for (auto [i, M] : enumerate(Mask)) {
502     if (M < 0)
503       continue;
504     int Src = M >= (int)NumElts;
505     int Diff = (int)i - (M % NumElts);
506     bool Match = false;
507     for (int j = 0; j < 2; j++) {
508       auto &[SrcE, DiffE] = SrcInfo[j];
509       if (SrcE == -1) {
510         assert(DiffE == SignalValue);
511         SrcE = Src;
512         DiffE = Diff;
513       }
514       if (SrcE == Src && DiffE == Diff) {
515         Match = true;
516         break;
517       }
518     }
519     if (!Match)
520       return false;
521   }
522   // Avoid all undef masks
523   return SrcInfo[0].first != -1;
524 }
525 
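/// Replace each mask element with \p Scale consecutive indices into a vector
/// with \p Scale times as many (narrower) elements; undef (negative) elements
/// are replicated. E.g. narrowing <1, 0> by Scale = 2 gives <2, 3, 0, 1>.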
526 void llvm::narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
527                                  SmallVectorImpl<int> &ScaledMask) {
528   assert(Scale > 0 && "Unexpected scaling factor");
529 
530   // Fast-path: if no scaling, then it is just a copy.
531   if (Scale == 1) {
532     ScaledMask.assign(Mask.begin(), Mask.end());
533     return;
534   }
535 
536   ScaledMask.clear();
537   for (int MaskElt : Mask) {
538     if (MaskElt >= 0) {
539       assert(((uint64_t)Scale * MaskElt + (Scale - 1)) <= INT32_MAX &&
540              "Overflowed 32-bits");
541     }
542     for (int SliceElt = 0; SliceElt != Scale; ++SliceElt)
543       ScaledMask.push_back(MaskElt < 0 ? MaskElt : Scale * MaskElt + SliceElt);
544   }
545 }
546 
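/// Attempt the inverse of narrowShuffleMaskElts: fold each run of \p Scale
/// consecutive mask elements into a single wide-element index. E.g. widening
/// <2, 3, 0, 1> by Scale = 2 gives <1, 0>; returns false if the mask cannot
/// be regrouped this way (e.g. <0, 2, 1, 3>).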
547 bool llvm::widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
548                                 SmallVectorImpl<int> &ScaledMask) {
549   assert(Scale > 0 && "Unexpected scaling factor");
550 
551   // Fast-path: if no scaling, then it is just a copy.
552   if (Scale == 1) {
553     ScaledMask.assign(Mask.begin(), Mask.end());
554     return true;
555   }
556 
557   // We must map the original elements down evenly to a type with fewer elements.
558   int NumElts = Mask.size();
559   if (NumElts % Scale != 0)
560     return false;
561 
562   ScaledMask.clear();
563   ScaledMask.reserve(NumElts / Scale);
564 
565   // Step through the input mask by splitting into Scale-sized slices.
566   do {
567     ArrayRef<int> MaskSlice = Mask.take_front(Scale);
568     assert((int)MaskSlice.size() == Scale && "Expected Scale-sized slice.");
569 
570     // The first element of the slice determines how we evaluate this slice.
571     int SliceFront = MaskSlice.front();
572     if (SliceFront < 0) {
573       // Negative values (undef or other "sentinel" values) must be equal across
574       // the entire slice.
575       if (!all_equal(MaskSlice))
576         return false;
577       ScaledMask.push_back(SliceFront);
578     } else {
579       // A positive mask element must be cleanly divisible.
580       if (SliceFront % Scale != 0)
581         return false;
582       // Elements of the slice must be consecutive.
583       for (int i = 1; i < Scale; ++i)
584         if (MaskSlice[i] != SliceFront + i)
585           return false;
586       ScaledMask.push_back(SliceFront / Scale);
587     }
588     Mask = Mask.drop_front(Scale);
589   } while (!Mask.empty());
590 
591   assert((int)ScaledMask.size() * Scale == NumElts && "Unexpected scaled mask");
592 
593   // All elements of the original mask can be scaled down to map to the elements
594   // of a mask with wider elements.
595   return true;
596 }
597 
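/// A more permissive widening by a factor of 2: a pair may also widen when
/// one of its elements is undef, provided the defined element lies in the
/// correct half of its pair. E.g. <-1, 1, 2, -1> widens to <0, 1>.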
598 bool llvm::widenShuffleMaskElts(ArrayRef<int> M,
599                                 SmallVectorImpl<int> &NewMask) {
600   unsigned NumElts = M.size();
601   if (NumElts % 2 != 0)
602     return false;
603 
604   NewMask.clear();
605   for (unsigned i = 0; i < NumElts; i += 2) {
606     int M0 = M[i];
607     int M1 = M[i + 1];
608 
609     // If both elements are undef, new mask is undef too.
610     if (M0 == -1 && M1 == -1) {
611       NewMask.push_back(-1);
612       continue;
613     }
614 
615     if (M0 == -1 && M1 != -1 && (M1 % 2) == 1) {
616       NewMask.push_back(M1 / 2);
617       continue;
618     }
619 
620     if (M0 != -1 && (M0 % 2) == 0 && ((M0 + 1) == M1 || M1 == -1)) {
621       NewMask.push_back(M0 / 2);
622       continue;
623     }
624 
625     NewMask.clear();
626     return false;
627   }
628 
629   assert(NewMask.size() == NumElts / 2 && "Incorrect size for mask!");
630   return true;
631 }
632 
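/// Scale \p Mask to exactly \p NumDstElts elements by dispatching to
/// widenShuffleMaskElts or narrowShuffleMaskElts; one element count must
/// evenly divide the other.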
633 bool llvm::scaleShuffleMaskElts(unsigned NumDstElts, ArrayRef<int> Mask,
634                                 SmallVectorImpl<int> &ScaledMask) {
635   unsigned NumSrcElts = Mask.size();
636   assert(NumSrcElts > 0 && NumDstElts > 0 && "Unexpected scaling factor");
637 
638   // Fast-path: if no scaling, then it is just a copy.
639   if (NumSrcElts == NumDstElts) {
640     ScaledMask.assign(Mask.begin(), Mask.end());
641     return true;
642   }
643 
644   // Ensure we can find a whole scale factor.
645   assert(((NumSrcElts % NumDstElts) == 0 || (NumDstElts % NumSrcElts) == 0) &&
646          "Unexpected scaling factor");
647 
648   if (NumSrcElts > NumDstElts) {
649     int Scale = NumSrcElts / NumDstElts;
650     return widenShuffleMaskElts(Scale, Mask, ScaledMask);
651   }
652 
653   int Scale = NumDstElts / NumSrcElts;
654   narrowShuffleMaskElts(Scale, Mask, ScaledMask);
655   return true;
656 }
657 
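/// Repeatedly widen \p Mask by every factor that succeeds, producing an
/// equivalent mask with as few (i.e. as wide) elements as possible.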
658 void llvm::getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
659                                         SmallVectorImpl<int> &ScaledMask) {
660   std::array<SmallVector<int, 16>, 2> TmpMasks;
661   SmallVectorImpl<int> *Output = &TmpMasks[0], *Tmp = &TmpMasks[1];
662   ArrayRef<int> InputMask = Mask;
663   for (unsigned Scale = 2; Scale <= InputMask.size(); ++Scale) {
664     while (widenShuffleMaskElts(Scale, InputMask, *Output)) {
665       InputMask = *Output;
666       std::swap(Output, Tmp);
667     }
668   }
669   ScaledMask.assign(InputMask.begin(), InputMask.end());
670 }
671 
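/// Split \p Mask, which spans \p NumOfDestRegs destination registers drawing
/// from \p NumOfSrcRegs source registers, into per-destination sub-masks and
/// invoke the matching callback for each of the first \p NumOfUsedRegs
/// destinations: \p NoInputAction when no source contributes,
/// \p SingleInputAction for a single-source permute, and \p ManyInputsAction
/// (possibly several times) when multiple sources must be combined.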
672 void llvm::processShuffleMasks(
673     ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
674     unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
675     function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
676     function_ref<void(ArrayRef<int>, unsigned, unsigned, bool)>
677         ManyInputsAction) {
678   SmallVector<SmallVector<SmallVector<int>>> Res(NumOfDestRegs);
679   // Try to perform better estimation of the permutation.
680   // 1. Split the source/destination vectors into real registers.
681   // 2. Do the mask analysis to identify which real registers are
682   // permuted.
683   int Sz = Mask.size();
684   unsigned SzDest = Sz / NumOfDestRegs;
685   unsigned SzSrc = Sz / NumOfSrcRegs;
686   for (unsigned I = 0; I < NumOfDestRegs; ++I) {
687     auto &RegMasks = Res[I];
688     RegMasks.assign(2 * NumOfSrcRegs, {});
689     // Determine which source register each value in this dest register
690     // comes from.
691     for (unsigned K = 0; K < SzDest; ++K) {
692       int Idx = I * SzDest + K;
693       if (Idx == Sz)
694         break;
695       if (Mask[Idx] >= 2 * Sz || Mask[Idx] == PoisonMaskElem)
696         continue;
697       int MaskIdx = Mask[Idx] % Sz;
698       int SrcRegIdx = MaskIdx / SzSrc + (Mask[Idx] >= Sz ? NumOfSrcRegs : 0);
699       // Add a cost of PermuteTwoSrc for each new source register permute,
700       // if we have more than one source register.
701       if (RegMasks[SrcRegIdx].empty())
702         RegMasks[SrcRegIdx].assign(SzDest, PoisonMaskElem);
703       RegMasks[SrcRegIdx][K] = MaskIdx % SzSrc;
704     }
705   }
706   // Process split mask.
707   for (unsigned I : seq<unsigned>(NumOfUsedRegs)) {
708     auto &Dest = Res[I];
709     int NumSrcRegs =
710         count_if(Dest, [](ArrayRef<int> Mask) { return !Mask.empty(); });
711     switch (NumSrcRegs) {
712     case 0:
713       // No input vectors were used!
714       NoInputAction();
715       break;
716     case 1: {
717       // Find the single non-empty source mask.
718       auto *It =
719           find_if(Dest, [](ArrayRef<int> Mask) { return !Mask.empty(); });
720       unsigned SrcReg = std::distance(Dest.begin(), It);
721       SingleInputAction(*It, SrcReg, I);
722       break;
723     }
724     default: {
725       // The first mask is a permutation of a single register. Since we have
726       // >= 2 input registers to shuffle, we merge the masks for the first 2
727       // registers and generate a shuffle of 2 registers rather than reordering
728       // the first register and then shuffling it with the second register.
729       // Next, we generate the shuffles of the resulting register + the
730       // remaining registers from the list.
731       auto &&CombineMasks = [](MutableArrayRef<int> FirstMask,
732                                ArrayRef<int> SecondMask) {
733         for (int Idx = 0, VF = FirstMask.size(); Idx < VF; ++Idx) {
734           if (SecondMask[Idx] != PoisonMaskElem) {
735             assert(FirstMask[Idx] == PoisonMaskElem &&
736                    "Expected undefined mask element.");
737             FirstMask[Idx] = SecondMask[Idx] + VF;
738           }
739         }
740       };
741       auto &&NormalizeMask = [](MutableArrayRef<int> Mask) {
742         for (int Idx = 0, VF = Mask.size(); Idx < VF; ++Idx) {
743           if (Mask[Idx] != PoisonMaskElem)
744             Mask[Idx] = Idx;
745         }
746       };
747       int SecondIdx;
748       bool NewReg = true;
749       do {
750         int FirstIdx = -1;
751         SecondIdx = -1;
752         MutableArrayRef<int> FirstMask, SecondMask;
753         for (unsigned I : seq<unsigned>(2 * NumOfSrcRegs)) {
754           SmallVectorImpl<int> &RegMask = Dest[I];
755           if (RegMask.empty())
756             continue;
757 
758           if (FirstIdx == SecondIdx) {
759             FirstIdx = I;
760             FirstMask = RegMask;
761             continue;
762           }
763           SecondIdx = I;
764           SecondMask = RegMask;
765           CombineMasks(FirstMask, SecondMask);
766           ManyInputsAction(FirstMask, FirstIdx, SecondIdx, NewReg);
767           NewReg = false;
768           NormalizeMask(FirstMask);
769           RegMask.clear();
770           SecondMask = FirstMask;
771           SecondIdx = FirstIdx;
772         }
773         if (FirstIdx != SecondIdx && SecondIdx >= 0) {
774           CombineMasks(SecondMask, FirstMask);
775           ManyInputsAction(SecondMask, SecondIdx, FirstIdx, NewReg);
776           NewReg = false;
777           Dest[FirstIdx].clear();
778           NormalizeMask(SecondMask);
779         }
780       } while (SecondIdx >= 0);
781       break;
782     }
783     }
784   }
785 }
786 
787 void llvm::getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
788                                                const APInt &DemandedElts,
789                                                APInt &DemandedLHS,
790                                                APInt &DemandedRHS) {
791   assert(VectorBitWidth >= 128 && "Vectors smaller than 128 bit not supported");
792   int NumLanes = VectorBitWidth / 128;
793   int NumElts = DemandedElts.getBitWidth();
794   int NumEltsPerLane = NumElts / NumLanes;
795   int HalfEltsPerLane = NumEltsPerLane / 2;
796 
797   DemandedLHS = APInt::getZero(NumElts);
798   DemandedRHS = APInt::getZero(NumElts);
799 
800   // Map DemandedElts to the horizontal operands.
801   for (int Idx = 0; Idx != NumElts; ++Idx) {
802     if (!DemandedElts[Idx])
803       continue;
804     int LaneIdx = (Idx / NumEltsPerLane) * NumEltsPerLane;
805     int LocalIdx = Idx % NumEltsPerLane;
806     if (LocalIdx < HalfEltsPerLane) {
807       DemandedLHS.setBit(LaneIdx + 2 * LocalIdx);
808     } else {
809       LocalIdx -= HalfEltsPerLane;
810       DemandedRHS.setBit(LaneIdx + 2 * LocalIdx);
811     }
812   }
813 }
814 
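/// Compute a map from the integer instructions in \p Blocks to the minimum
/// bit width each one actually requires, using DemandedBits and keeping every
/// connected use-def chain at a single width so no extra casts are needed.
/// PHIs, calls, and pointer/bit casts are never shrunk.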
815 MapVector<Instruction *, uint64_t>
816 llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
817                                const TargetTransformInfo *TTI) {
818 
819   // DemandedBits will give us every value's live-out bits. But we want
820   // to ensure no extra casts would need to be inserted, so every DAG
821   // of connected values must have the same minimum bitwidth.
822   EquivalenceClasses<Value *> ECs;
823   SmallVector<Instruction *, 16> Worklist;
824   SmallPtrSet<Instruction *, 4> Roots;
825   SmallPtrSet<Instruction *, 16> Visited;
826   DenseMap<Value *, uint64_t> DBits;
827   SmallPtrSet<Instruction *, 4> InstructionSet;
828   MapVector<Instruction *, uint64_t> MinBWs;
829 
830   // Determine the roots. We work bottom-up, from truncs or icmps.
831   bool SeenExtFromIllegalType = false;
832   for (auto *BB : Blocks)
833     for (auto &I : *BB) {
834       InstructionSet.insert(&I);
835 
836       if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) &&
837           !TTI->isTypeLegal(I.getOperand(0)->getType()))
838         SeenExtFromIllegalType = true;
839 
840       // Only deal with non-vector integers up to 64-bits wide.
841       if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) &&
842           !I.getType()->isVectorTy() &&
843           I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) {
844         // Don't make work for ourselves. If we know the loaded type is legal,
845         // don't add it to the worklist.
846         if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType()))
847           continue;
848 
849         Worklist.push_back(&I);
850         Roots.insert(&I);
851       }
852     }
853   // Early exit.
854   if (Worklist.empty() || (TTI && !SeenExtFromIllegalType))
855     return MinBWs;
856 
857   // Now proceed breadth-first, unioning values together.
858   while (!Worklist.empty()) {
859     Instruction *I = Worklist.pop_back_val();
860     Value *Leader = ECs.getOrInsertLeaderValue(I);
861 
862     if (!Visited.insert(I).second)
863       continue;
864 
865     // If we encounter a type that is larger than 64 bits, we can't represent
866     // it so bail out.
867     if (DB.getDemandedBits(I).getBitWidth() > 64)
868       return MapVector<Instruction *, uint64_t>();
869 
870     uint64_t V = DB.getDemandedBits(I).getZExtValue();
871     DBits[Leader] |= V;
872     DBits[I] = V;
873 
874     // Casts, loads and instructions outside of our range terminate a chain
875     // successfully.
876     if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) ||
877         !InstructionSet.count(I))
878       continue;
879 
880     // Unsafe casts terminate a chain unsuccessfully. We can't do anything
881     // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to
882     // transform anything that relies on them.
883     if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) ||
884         !I->getType()->isIntegerTy()) {
885       DBits[Leader] |= ~0ULL;
886       continue;
887     }
888 
889     // We don't modify the types of PHIs. Reductions will already have been
890     // truncated if possible, and inductions' sizes will have been chosen by
891     // indvars.
892     if (isa<PHINode>(I))
893       continue;
894 
895     // Don't modify the types of operands of a call, as doing that would cause a
896     // signature mismatch.
897     if (isa<CallBase>(I))
898       continue;
899 
900     if (DBits[Leader] == ~0ULL)
901       // All bits demanded, no point continuing.
902       continue;
903 
904     for (Value *O : I->operands()) {
905       ECs.unionSets(Leader, O);
906       if (auto *OI = dyn_cast<Instruction>(O))
907         Worklist.push_back(OI);
908     }
909   }
910 
911   // Now we've discovered all values, walk them to see if there are
912   // any users we didn't see. If there are, we can't optimize that
913   // chain.
914   for (auto &I : DBits)
915     for (auto *U : I.first->users())
916       if (U->getType()->isIntegerTy() && DBits.count(U) == 0)
917         DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL;
918 
919   for (const auto &E : ECs) {
920     if (!E->isLeader())
921       continue;
922     uint64_t LeaderDemandedBits = 0;
923     for (Value *M : ECs.members(*E))
924       LeaderDemandedBits |= DBits[M];
925 
926     uint64_t MinBW = llvm::bit_width(LeaderDemandedBits);
927     // Round up to a power of 2
928     MinBW = llvm::bit_ceil(MinBW);
929 
930     // We don't modify the types of PHIs. Reductions will already have been
931     // truncated if possible, and inductions' sizes will have been chosen by
932     // indvars.
933     // If we are required to shrink a PHI, abandon this entire equivalence class.
934     bool Abort = false;
935     for (Value *M : ECs.members(*E))
936       if (isa<PHINode>(M) && MinBW < M->getType()->getScalarSizeInBits()) {
937         Abort = true;
938         break;
939       }
940     if (Abort)
941       continue;
942 
943     for (Value *M : ECs.members(*E)) {
944       auto *MI = dyn_cast<Instruction>(M);
945       if (!MI)
946         continue;
947       Type *Ty = M->getType();
948       if (Roots.count(MI))
949         Ty = MI->getOperand(0)->getType();
950 
951       if (MinBW >= Ty->getScalarSizeInBits())
952         continue;
953 
954       // If any of M's operands demand more bits than MinBW then M cannot be
955       // performed safely in MinBW.
956       auto *Call = dyn_cast<CallBase>(MI);
957       auto Ops = Call ? Call->args() : MI->operands();
958       if (any_of(Ops, [&DB, MinBW](Use &U) {
959             auto *CI = dyn_cast<ConstantInt>(U);
960             // For constant shift amounts, check if the shift would result in
961             // poison.
962             if (CI &&
963                 isa<ShlOperator, LShrOperator, AShrOperator>(U.getUser()) &&
964                 U.getOperandNo() == 1)
965               return CI->uge(MinBW);
966             uint64_t BW = bit_width(DB.getDemandedBits(&U).getZExtValue());
967             return bit_ceil(BW) > MinBW;
968           }))
969         continue;
970 
971       MinBWs[MI] = MinBW;
972     }
973   }
974 
975   return MinBWs;
976 }
977 
978 /// Add all access groups in @p AccGroups to @p List.
979 template <typename ListT>
980 static void addToAccessGroupList(ListT &List, MDNode *AccGroups) {
981   // Interpret an access group as a list containing itself.
982   if (AccGroups->getNumOperands() == 0) {
983     assert(isValidAsAccessGroup(AccGroups) && "Node must be an access group");
984     List.insert(AccGroups);
985     return;
986   }
987 
988   for (const auto &AccGroupListOp : AccGroups->operands()) {
989     auto *Item = cast<MDNode>(AccGroupListOp.get());
990     assert(isValidAsAccessGroup(Item) && "List item must be an access group");
991     List.insert(Item);
992   }
993 }
994 
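/// Return an access-group list containing every access group that appears in
/// either argument; a null argument contributes nothing, and a lone access
/// group is treated as a one-element list.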
995 MDNode *llvm::uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2) {
996   if (!AccGroups1)
997     return AccGroups2;
998   if (!AccGroups2)
999     return AccGroups1;
1000   if (AccGroups1 == AccGroups2)
1001     return AccGroups1;
1002 
1003   SmallSetVector<Metadata *, 4> Union;
1004   addToAccessGroupList(Union, AccGroups1);
1005   addToAccessGroupList(Union, AccGroups2);
1006 
1007   if (Union.size() == 0)
1008     return nullptr;
1009   if (Union.size() == 1)
1010     return cast<MDNode>(Union.front());
1011 
1012   LLVMContext &Ctx = AccGroups1->getContext();
1013   return MDNode::get(Ctx, Union.getArrayRef());
1014 }
1015 
1016 MDNode *llvm::intersectAccessGroups(const Instruction *Inst1,
1017                                     const Instruction *Inst2) {
1018   bool MayAccessMem1 = Inst1->mayReadOrWriteMemory();
1019   bool MayAccessMem2 = Inst2->mayReadOrWriteMemory();
1020 
1021   if (!MayAccessMem1 && !MayAccessMem2)
1022     return nullptr;
1023   if (!MayAccessMem1)
1024     return Inst2->getMetadata(LLVMContext::MD_access_group);
1025   if (!MayAccessMem2)
1026     return Inst1->getMetadata(LLVMContext::MD_access_group);
1027 
1028   MDNode *MD1 = Inst1->getMetadata(LLVMContext::MD_access_group);
1029   MDNode *MD2 = Inst2->getMetadata(LLVMContext::MD_access_group);
1030   if (!MD1 || !MD2)
1031     return nullptr;
1032   if (MD1 == MD2)
1033     return MD1;
1034 
1035   // Use set for scalable 'contains' check.
1036   SmallPtrSet<Metadata *, 4> AccGroupSet2;
1037   addToAccessGroupList(AccGroupSet2, MD2);
1038 
1039   SmallVector<Metadata *, 4> Intersection;
1040   if (MD1->getNumOperands() == 0) {
1041     assert(isValidAsAccessGroup(MD1) && "Node must be an access group");
1042     if (AccGroupSet2.count(MD1))
1043       Intersection.push_back(MD1);
1044   } else {
1045     for (const MDOperand &Node : MD1->operands()) {
1046       auto *Item = cast<MDNode>(Node.get());
1047       assert(isValidAsAccessGroup(Item) && "List item must be an access group");
1048       if (AccGroupSet2.count(Item))
1049         Intersection.push_back(Item);
1050     }
1051   }
1052 
1053   if (Intersection.size() == 0)
1054     return nullptr;
1055   if (Intersection.size() == 1)
1056     return cast<MDNode>(Intersection.front());
1057 
1058   LLVMContext &Ctx = Inst1->getContext();
1059   return MDNode::get(Ctx, Intersection);
1060 }
1061 
1062 /// Add metadata from \p Inst to \p Metadata, if it can be preserved after
1063 /// vectorization.
1064 void llvm::getMetadataToPropagate(
1065     Instruction *Inst,
1066     SmallVectorImpl<std::pair<unsigned, MDNode *>> &Metadata) {
1067   Inst->getAllMetadataOtherThanDebugLoc(Metadata);
1068   static const unsigned SupportedIDs[] = {
1069       LLVMContext::MD_tbaa,         LLVMContext::MD_alias_scope,
1070       LLVMContext::MD_noalias,      LLVMContext::MD_fpmath,
1071       LLVMContext::MD_nontemporal,  LLVMContext::MD_invariant_load,
1072       LLVMContext::MD_access_group, LLVMContext::MD_mmra};
1073 
1074   // Remove any unsupported metadata kinds from Metadata.
1075   for (unsigned Idx = 0; Idx != Metadata.size();) {
1076     if (is_contained(SupportedIDs, Metadata[Idx].first)) {
1077       ++Idx;
1078     } else {
1079       // Swap element to end and remove it.
1080       std::swap(Metadata[Idx], Metadata.back());
1081       Metadata.pop_back();
1082     }
1083   }
1084 }
1085 
1086 /// \returns \p I after propagating metadata from \p VL.
1087 Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) {
1088   if (VL.empty())
1089     return Inst;
1090   SmallVector<std::pair<unsigned, MDNode *>> Metadata;
1091   getMetadataToPropagate(cast<Instruction>(VL[0]), Metadata);
1092 
1093   for (auto &[Kind, MD] : Metadata) {
1094     for (int J = 1, E = VL.size(); MD && J != E; ++J) {
1095       const Instruction *IJ = cast<Instruction>(VL[J]);
1096       MDNode *IMD = IJ->getMetadata(Kind);
1097 
1098       switch (Kind) {
1099       case LLVMContext::MD_mmra: {
1100         MD = MMRAMetadata::combine(Inst->getContext(), MD, IMD);
1101         break;
1102       }
1103       case LLVMContext::MD_tbaa:
1104         MD = MDNode::getMostGenericTBAA(MD, IMD);
1105         break;
1106       case LLVMContext::MD_alias_scope:
1107         MD = MDNode::getMostGenericAliasScope(MD, IMD);
1108         break;
1109       case LLVMContext::MD_fpmath:
1110         MD = MDNode::getMostGenericFPMath(MD, IMD);
1111         break;
1112       case LLVMContext::MD_noalias:
1113       case LLVMContext::MD_nontemporal:
1114       case LLVMContext::MD_invariant_load:
1115         MD = MDNode::intersect(MD, IMD);
1116         break;
1117       case LLVMContext::MD_access_group:
1118         MD = intersectAccessGroups(Inst, IJ);
1119         break;
1120       default:
1121         llvm_unreachable("unhandled metadata");
1122       }
1123     }
1124 
1125     Inst->setMetadata(Kind, MD);
1126   }
1127 
1128   return Inst;
1129 }
1130 
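/// Build a constant vector of VF * Factor i1 elements that enables only the
/// lanes whose interleave-group member exists; returns nullptr when the group
/// has no gaps.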
1131 Constant *
1132 llvm::createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
1133                            const InterleaveGroup<Instruction> &Group) {
1134   // All 1's means mask is not needed.
1135   if (Group.getNumMembers() == Group.getFactor())
1136     return nullptr;
1137 
1138   // TODO: support reversed access.
1139   assert(!Group.isReverse() && "Reversed group not supported.");
1140 
1141   SmallVector<Constant *, 16> Mask;
1142   for (unsigned i = 0; i < VF; i++)
1143     for (unsigned j = 0; j < Group.getFactor(); ++j) {
1144       unsigned HasMember = Group.getMember(j) ? 1 : 0;
1145       Mask.push_back(Builder.getInt1(HasMember));
1146     }
1147 
1148   return ConstantVector::get(Mask);
1149 }
1150 
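/// E.g. createReplicatedMask(/*ReplicationFactor=*/3, /*VF=*/2) returns
/// <0, 0, 0, 1, 1, 1>.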
1151 llvm::SmallVector<int, 16>
1152 llvm::createReplicatedMask(unsigned ReplicationFactor, unsigned VF) {
1153   SmallVector<int, 16> MaskVec;
1154   for (unsigned i = 0; i < VF; i++)
1155     for (unsigned j = 0; j < ReplicationFactor; j++)
1156       MaskVec.push_back(i);
1157 
1158   return MaskVec;
1159 }
1160 
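/// E.g. createInterleaveMask(/*VF=*/4, /*NumVecs=*/2) returns
/// <0, 4, 1, 5, 2, 6, 3, 7>, interleaving two 4-element vectors.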
1161 llvm::SmallVector<int, 16> llvm::createInterleaveMask(unsigned VF,
1162                                                       unsigned NumVecs) {
1163   SmallVector<int, 16> Mask;
1164   for (unsigned i = 0; i < VF; i++)
1165     for (unsigned j = 0; j < NumVecs; j++)
1166       Mask.push_back(j * VF + i);
1167 
1168   return Mask;
1169 }
1170 
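/// E.g. createStrideMask(/*Start=*/0, /*Stride=*/2, /*VF=*/4) returns
/// <0, 2, 4, 6>, selecting every other element.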
1171 llvm::SmallVector<int, 16>
1172 llvm::createStrideMask(unsigned Start, unsigned Stride, unsigned VF) {
1173   SmallVector<int, 16> Mask;
1174   for (unsigned i = 0; i < VF; i++)
1175     Mask.push_back(Start + i * Stride);
1176 
1177   return Mask;
1178 }
1179 
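/// E.g. createSequentialMask(/*Start=*/2, /*NumInts=*/3, /*NumUndefs=*/1)
/// returns <2, 3, 4, -1>.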
1180 llvm::SmallVector<int, 16> llvm::createSequentialMask(unsigned Start,
1181                                                       unsigned NumInts,
1182                                                       unsigned NumUndefs) {
1183   SmallVector<int, 16> Mask;
1184   for (unsigned i = 0; i < NumInts; i++)
1185     Mask.push_back(Start + i);
1186 
1187   for (unsigned i = 0; i < NumUndefs; i++)
1188     Mask.push_back(-1);
1189 
1190   return Mask;
1191 }
1192 
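/// E.g. with NumElts = 4, the two-operand mask <0, 5, 2, 7> becomes the
/// unary mask <0, 1, 2, 3>.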
1193 llvm::SmallVector<int, 16> llvm::createUnaryMask(ArrayRef<int> Mask,
1194                                                  unsigned NumElts) {
1195   // Avoid casts in the loop and make sure we have a reasonable number.
1196   int NumEltsSigned = NumElts;
1197   assert(NumEltsSigned > 0 && "Expected a positive element count");
1198 
1199   // If the mask chooses an element from operand 1, reduce it to choose from the
1200   // corresponding element of operand 0. Undef mask elements are unchanged.
1201   SmallVector<int, 16> UnaryMask;
1202   for (int MaskElt : Mask) {
1203     assert((MaskElt < NumEltsSigned * 2) && "Expected valid shuffle mask");
1204     int UnaryElt = MaskElt >= NumEltsSigned ? MaskElt - NumEltsSigned : MaskElt;
1205     UnaryMask.push_back(UnaryElt);
1206   }
1207   return UnaryMask;
1208 }
1209 
1210 /// A helper function for concatenating vectors. This function concatenates two
1211 /// vectors having the same element type. If the second vector has fewer
1212 /// elements than the first, it is padded with undefs.
1213 static Value *concatenateTwoVectors(IRBuilderBase &Builder, Value *V1,
1214                                     Value *V2) {
1215   VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType());
1216   VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType());
1217   assert(VecTy1 && VecTy2 &&
1218          VecTy1->getScalarType() == VecTy2->getScalarType() &&
1219          "Expect two vectors with the same element type");
1220 
1221   unsigned NumElts1 = cast<FixedVectorType>(VecTy1)->getNumElements();
1222   unsigned NumElts2 = cast<FixedVectorType>(VecTy2)->getNumElements();
1223   assert(NumElts1 >= NumElts2 && "Unexpected: first vector has fewer elements");
1224 
1225   if (NumElts1 > NumElts2) {
1226     // Extend with UNDEFs.
1227     V2 = Builder.CreateShuffleVector(
1228         V2, createSequentialMask(0, NumElts2, NumElts1 - NumElts2));
1229   }
1230 
1231   return Builder.CreateShuffleVector(
1232       V1, V2, createSequentialMask(0, NumElts1 + NumElts2, 0));
1233 }
1234 
1235 Value *llvm::concatenateVectors(IRBuilderBase &Builder,
1236                                 ArrayRef<Value *> Vecs) {
1237   unsigned NumVecs = Vecs.size();
1238   assert(NumVecs > 1 && "Should be at least two vectors");
1239 
1240   SmallVector<Value *, 8> ResList;
1241   ResList.append(Vecs.begin(), Vecs.end());
1242   do {
1243     SmallVector<Value *, 8> TmpList;
1244     for (unsigned i = 0; i < NumVecs - 1; i += 2) {
1245       Value *V0 = ResList[i], *V1 = ResList[i + 1];
1246       assert((V0->getType() == V1->getType() || i == NumVecs - 2) &&
1247              "Only the last vector may have a different type");
1248 
1249       TmpList.push_back(concatenateTwoVectors(Builder, V0, V1));
1250     }
1251 
1252     // Push the last vector if the total number of vectors is odd.
1253     if (NumVecs % 2 != 0)
1254       TmpList.push_back(ResList[NumVecs - 1]);
1255 
1256     ResList = TmpList;
1257     NumVecs = ResList.size();
1258   } while (NumVecs > 1);
1259 
1260   return ResList[0];
1261 }
1262 
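/// Return true if \p Mask is a constant i1 vector in which every lane is
/// known to be zero or undef (i.e. no lane is definitely enabled);
/// non-constant masks conservatively return false.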
1263 bool llvm::maskIsAllZeroOrUndef(Value *Mask) {
1264   assert(isa<VectorType>(Mask->getType()) &&
1265          isa<IntegerType>(Mask->getType()->getScalarType()) &&
1266          cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() ==
1267              1 &&
1268          "Mask must be a vector of i1");
1269 
1270   auto *ConstMask = dyn_cast<Constant>(Mask);
1271   if (!ConstMask)
1272     return false;
1273   if (ConstMask->isNullValue() || isa<UndefValue>(ConstMask))
1274     return true;
1275   if (isa<ScalableVectorType>(ConstMask->getType()))
1276     return false;
1277   for (unsigned
1278            I = 0,
1279            E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
1280        I != E; ++I) {
1281     if (auto *MaskElt = ConstMask->getAggregateElement(I))
1282       if (MaskElt->isNullValue() || isa<UndefValue>(MaskElt))
1283         continue;
1284     return false;
1285   }
1286   return true;
1287 }
1288 
1289 bool llvm::maskIsAllOneOrUndef(Value *Mask) {
1290   assert(isa<VectorType>(Mask->getType()) &&
1291          isa<IntegerType>(Mask->getType()->getScalarType()) &&
1292          cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() ==
1293              1 &&
1294          "Mask must be a vector of i1");
1295 
1296   auto *ConstMask = dyn_cast<Constant>(Mask);
1297   if (!ConstMask)
1298     return false;
1299   if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
1300     return true;
1301   if (isa<ScalableVectorType>(ConstMask->getType()))
1302     return false;
1303   for (unsigned
1304            I = 0,
1305            E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
1306        I != E; ++I) {
1307     if (auto *MaskElt = ConstMask->getAggregateElement(I))
1308       if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt))
1309         continue;
1310     return false;
1311   }
1312   return true;
1313 }
1314 
1315 bool llvm::maskContainsAllOneOrUndef(Value *Mask) {
1316   assert(isa<VectorType>(Mask->getType()) &&
1317          isa<IntegerType>(Mask->getType()->getScalarType()) &&
1318          cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() ==
1319              1 &&
1320          "Mask must be a vector of i1");
1321 
1322   auto *ConstMask = dyn_cast<Constant>(Mask);
1323   if (!ConstMask)
1324     return false;
1325   if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
1326     return true;
1327   if (isa<ScalableVectorType>(ConstMask->getType()))
1328     return false;
1329   for (unsigned
1330            I = 0,
1331            E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
1332        I != E; ++I) {
1333     if (auto *MaskElt = ConstMask->getAggregateElement(I))
1334       if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt))
1335         return true;
1336   }
1337   return false;
1338 }
1339 
1340 /// TODO: This is a lot like known bits, but for
1341 /// vectors.  Is there something we can common this with?
1342 APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
1343   assert(isa<FixedVectorType>(Mask->getType()) &&
1344          isa<IntegerType>(Mask->getType()->getScalarType()) &&
1345          cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() ==
1346              1 &&
1347          "Mask must be a fixed width vector of i1");
1348 
1349   const unsigned VWidth =
1350       cast<FixedVectorType>(Mask->getType())->getNumElements();
1351   APInt DemandedElts = APInt::getAllOnes(VWidth);
1352   if (auto *CV = dyn_cast<ConstantVector>(Mask))
1353     for (unsigned i = 0; i < VWidth; i++)
1354       if (CV->getAggregateElement(i)->isNullValue())
1355         DemandedElts.clearBit(i);
1356   return DemandedElts;
1357 }
1358 
1359 bool InterleavedAccessInfo::isStrided(int Stride) {
1360   unsigned Factor = std::abs(Stride);
1361   return Factor >= 2 && Factor <= MaxInterleaveGroupFactor;
1362 }
1363 
1364 void InterleavedAccessInfo::collectConstStrideAccesses(
1365     MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
1366     const DenseMap<Value*, const SCEV*> &Strides) {
1367   auto &DL = TheLoop->getHeader()->getDataLayout();
1368 
1369   // Since it's desired that the load/store instructions be maintained in
1370   // "program order" for the interleaved access analysis, we have to visit the
1371   // blocks in the loop in reverse postorder (i.e., in a topological order).
1372   // Such an ordering will ensure that any load/store that may be executed
1373   // before a second load/store will precede the second load/store in
1374   // AccessStrideInfo.
1375   LoopBlocksDFS DFS(TheLoop);
1376   DFS.perform(LI);
1377   for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO()))
1378     for (auto &I : *BB) {
1379       Value *Ptr = getLoadStorePointerOperand(&I);
1380       if (!Ptr)
1381         continue;
1382       Type *ElementTy = getLoadStoreType(&I);
1383 
1384       // Currently, codegen doesn't support cases where the type size doesn't
1385       // match the alloc size. Skip them for now.
1386       uint64_t Size = DL.getTypeAllocSize(ElementTy);
1387       if (Size * 8 != DL.getTypeSizeInBits(ElementTy))
1388         continue;
1389 
1390       // We don't check wrapping here because we don't know yet if Ptr will be
1391       // part of a full group or a group with gaps. Checking wrapping for all
1392       // pointers (even those that end up in groups with no gaps) will be overly
1393       // conservative. For full groups, wrapping should be ok since if we would
1394       // wrap around the address space we would do a memory access at nullptr
1395       // even without the transformation. The wrapping checks are therefore
1396       // deferred until after we've formed the interleaved groups.
1397       int64_t Stride =
1398         getPtrStride(PSE, ElementTy, Ptr, TheLoop, Strides,
1399                      /*Assume=*/true, /*ShouldCheckWrap=*/false).value_or(0);
1400 
1401       const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
1402       AccessStrideInfo[&I] = StrideDescriptor(Stride, Scev, Size,
1403                                               getLoadStoreAlignment(&I));
1404     }
1405 }
1406 
1407 // Analyze interleaved accesses and collect them into interleaved load and
1408 // store groups.
1409 //
1410 // When generating code for an interleaved load group, we effectively hoist all
1411 // loads in the group to the location of the first load in program order. When
1412 // generating code for an interleaved store group, we sink all stores to the
1413 // location of the last store. This code motion can change the order of load
1414 // and store instructions and may break dependences.
1415 //
1416 // The code generation strategy mentioned above ensures that we won't violate
1417 // any write-after-read (WAR) dependences.
1418 //
1419 // E.g., for the WAR dependence:  a = A[i];      // (1)
1420 //                                A[i] = b;      // (2)
1421 //
1422 // The store group of (2) is always inserted at or below (2), and the load
1423 // group of (1) is always inserted at or above (1). Thus, the instructions will
1424 // never be reordered. All other dependences are checked to ensure the
1425 // correctness of the instruction reordering.
1426 //
1427 // The algorithm visits all memory accesses in the loop in bottom-up program
1428 // order. Program order is established by traversing the blocks in the loop in
1429 // reverse postorder when collecting the accesses.
1430 //
1431 // We visit the memory accesses in bottom-up order because it can simplify the
1432 // construction of store groups in the presence of write-after-write (WAW)
1433 // dependences.
1434 //
1435 // E.g., for the WAW dependence:  A[i] = a;      // (1)
1436 //                                A[i] = b;      // (2)
1437 //                                A[i + 1] = c;  // (3)
1438 //
1439 // We will first create a store group with (3) and (2). (1) can't be added to
1440 // this group because it and (2) are dependent. However, (1) can be grouped
1441 // with other accesses that may precede it in program order. Note that a
1442 // bottom-up order does not imply that WAW dependences should not be checked.
1443 void InterleavedAccessInfo::analyzeInterleaving(
1444                                  bool EnablePredicatedInterleavedMemAccesses) {
1445   LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n");
1446   const auto &Strides = LAI->getSymbolicStrides();
1447 
1448   // Holds all accesses with a constant stride.
1449   MapVector<Instruction *, StrideDescriptor> AccessStrideInfo;
1450   collectConstStrideAccesses(AccessStrideInfo, Strides);
1451 
1452   if (AccessStrideInfo.empty())
1453     return;
1454 
1455   // Collect the dependences in the loop.
1456   collectDependences();
1457 
1458   // Holds all interleaved store groups temporarily.
1459   SmallSetVector<InterleaveGroup<Instruction> *, 4> StoreGroups;
1460   // Holds all interleaved load groups temporarily.
1461   SmallSetVector<InterleaveGroup<Instruction> *, 4> LoadGroups;
1462   // Groups added to this set cannot have new members added.
1463   SmallPtrSet<InterleaveGroup<Instruction> *, 4> CompletedLoadGroups;
1464 
1465   // Search in bottom-up program order for pairs of accesses (A and B) that can
1466   // form interleaved load or store groups. In the algorithm below, access A
1467   // precedes access B in program order. We initialize a group for B in the
1468   // outer loop of the algorithm, and then in the inner loop, we attempt to
1469   // insert each A into B's group if:
1470   //
1471   //  1. A and B have the same stride,
1472   //  2. A and B have the same memory object size, and
1473   //  3. A belongs in B's group according to its distance from B.
1474   //
1475   // Special care is taken to ensure group formation will not break any
1476   // dependences.
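  //
  // Illustrative example (editor's addition, not in the original comment):
  // in a stride-2 loop over i32 elements where access A (to A[i + 1])
  // precedes access B (to A[i]) in program order, rules 1 and 2 hold (same
  // stride, same 4-byte size), and A's distance to B is +4 bytes, so rule 3
  // places A one slot after B in the group.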
1477   for (auto BI = AccessStrideInfo.rbegin(), E = AccessStrideInfo.rend();
1478        BI != E; ++BI) {
1479     Instruction *B = BI->first;
1480     StrideDescriptor DesB = BI->second;
1481 
1482     // Initialize a group for B if it has an allowable stride. Even if we don't
1483     // create a group for B, we continue with the bottom-up algorithm to ensure
1484     // we don't break any of B's dependences.
1485     InterleaveGroup<Instruction> *GroupB = nullptr;
1486     if (isStrided(DesB.Stride) &&
1487         (!isPredicated(B->getParent()) || EnablePredicatedInterleavedMemAccesses)) {
1488       GroupB = getInterleaveGroup(B);
1489       if (!GroupB) {
1490         LLVM_DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B
1491                           << '\n');
1492         GroupB = createInterleaveGroup(B, DesB.Stride, DesB.Alignment);
1493         if (B->mayWriteToMemory())
1494           StoreGroups.insert(GroupB);
1495         else
1496           LoadGroups.insert(GroupB);
1497       }
1498     }
1499 
1500     for (auto AI = std::next(BI); AI != E; ++AI) {
1501       Instruction *A = AI->first;
1502       StrideDescriptor DesA = AI->second;
1503 
1504       // Our code motion strategy implies that we can't have dependences
1505       // between accesses in an interleaved group and other accesses located
1506       // between the first and last member of the group. Note that this also
1507       // means that a group can't have more than one member at a given offset.
1508       // The accesses in a group can have dependences with other accesses, but
1509       // we must ensure we don't extend the boundaries of the group such that
1510       // we encompass those dependent accesses.
1511       //
1512       // For example, assume we have the sequence of accesses shown below in a
1513       // stride-2 loop:
1514       //
1515       //  (1, 2) is a group | A[i]   = a;  // (1)
1516       //                    | A[i-1] = b;  // (2) |
1517       //                      A[i-3] = c;  // (3)
1518       //                      A[i]   = d;  // (4) | (2, 4) is not a group
1519       //
1520       // Because accesses (2) and (3) are dependent, we can group (2) with (1)
1521       // but not with (4). If we did, the dependent access (3) would be within
1522       // the boundaries of the (2, 4) group.
1523       auto DependentMember = [&](InterleaveGroup<Instruction> *Group,
1524                                  StrideEntry *A) -> Instruction * {
1525         for (uint32_t Index = 0; Index < Group->getFactor(); ++Index) {
1526           Instruction *MemberOfGroupB = Group->getMember(Index);
1527           if (MemberOfGroupB && !canReorderMemAccessesForInterleavedGroups(
1528                                     A, &*AccessStrideInfo.find(MemberOfGroupB)))
1529             return MemberOfGroupB;
1530         }
1531         return nullptr;
1532       };
1533 
1534       auto GroupA = getInterleaveGroup(A);
1535       // If A is a load, dependencies are tolerable, so there's nothing to
1536       // do here. If both A and B belong to the same (store) group, they
1537       // are independent, even if dependencies have not been recorded. If
1538       // both GroupA and GroupB are null, there's nothing to do here.
1539       if (A->mayWriteToMemory() && GroupA != GroupB) {
1540         Instruction *DependentInst = nullptr;
1541         // If GroupB is a load group, we have to compare AI against all
1542         // members of GroupB because if any load within GroupB has a dependency
1543         // on AI, we need to mark GroupB as complete and also release the
1544         // store GroupA (if A belongs to one). The former prevents incorrect
1545         // hoisting of load B above store A while the latter prevents incorrect
1546         // sinking of store A below load B.
1547         if (GroupB && LoadGroups.contains(GroupB))
1548           DependentInst = DependentMember(GroupB, &*AI);
1549         else if (!canReorderMemAccessesForInterleavedGroups(&*AI, &*BI))
1550           DependentInst = B;
1551 
1552         if (DependentInst) {
1553           // A has a store dependence on B (or on some load within GroupB) and
1554           // is part of a store group. Release A's group to prevent illegal
1555           // sinking of A below B. A will then be free to form another group
1556           // with instructions that precede it.
1557           if (GroupA && StoreGroups.contains(GroupA)) {
1558             LLVM_DEBUG(dbgs() << "LV: Invalidated store group due to "
1559                                  "dependence between "
1560                               << *A << " and " << *DependentInst << '\n');
1561             StoreGroups.remove(GroupA);
1562             releaseGroup(GroupA);
1563           }
1564           // If B is a load and part of an interleave group, no earlier loads
1565           // can be added to B's interleave group, because this would mean the
1566           // DependentInst would move across store A. Mark the interleave group
1567           // as complete.
1568           if (GroupB && LoadGroups.contains(GroupB)) {
1569             LLVM_DEBUG(dbgs() << "LV: Marking interleave group for " << *B
1570                               << " as complete.\n");
1571             CompletedLoadGroups.insert(GroupB);
1572           }
1573         }
1574       }
1575       if (CompletedLoadGroups.contains(GroupB)) {
1576         // Skip trying to add A to B; continue to look for other
1577         // conflicting A's in groups to be released.
1578         continue;
1579       }
1580 
1581       // At this point, we've checked for illegal code motion. If either A or B
1582       // isn't strided, there's nothing left to do.
1583       if (!isStrided(DesA.Stride) || !isStrided(DesB.Stride))
1584         continue;
1585 
1586       // Ignore A if it's already in a group or isn't the same kind of memory
1587       // operation as B.
1588       // Note that mayReadFromMemory() isn't mutually exclusive with
1589       // mayWriteToMemory() in the case of atomic loads. We shouldn't see
1590       // those here; canVectorizeMemory() should have returned false -
1591       // except when optimization remarks were requested.
1592       if (isInterleaved(A) ||
1593           (A->mayReadFromMemory() != B->mayReadFromMemory()) ||
1594           (A->mayWriteToMemory() != B->mayWriteToMemory()))
1595         continue;
1596 
1597       // Check rules 1 and 2. Ignore A if its stride or size is different from
1598       // that of B.
1599       if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size)
1600         continue;
1601 
1602       // Ignore A if the memory objects of A and B don't belong to the
1603       // same address space.
1604       if (getLoadStoreAddressSpace(A) != getLoadStoreAddressSpace(B))
1605         continue;
1606 
1607       // Calculate the distance from A to B.
1608       const SCEVConstant *DistToB = dyn_cast<SCEVConstant>(
1609           PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev));
1610       if (!DistToB)
1611         continue;
1612       int64_t DistanceToB = DistToB->getAPInt().getSExtValue();
1613 
1614       // Check rule 3. Ignore A if its distance to B is not a multiple of the
1615       // size.
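      // (Editor's illustration: with DesB.Size == 4 bytes, a distance of 8 or
      // -4 keeps A as a candidate, whereas a distance of 6 is rejected here.)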
1616       if (DistanceToB % static_cast<int64_t>(DesB.Size))
1617         continue;
1618 
1619       // All members of a predicated interleave-group must have the same predicate,
1620       // and currently must reside in the same BB.
1621       BasicBlock *BlockA = A->getParent();
1622       BasicBlock *BlockB = B->getParent();
1623       if ((isPredicated(BlockA) || isPredicated(BlockB)) &&
1624           (!EnablePredicatedInterleavedMemAccesses || BlockA != BlockB))
1625         continue;
1626 
1627       // The index of A is the index of B plus A's distance to B in multiples
1628       // of the size.
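      // (Editor's illustration: if B is at index 0 in its group, DesB.Size is
      // 4, and DistanceToB is -4, then IndexA is -1, i.e. A is a candidate
      // for the slot just before B.)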
1629       int IndexA =
1630           GroupB->getIndex(B) + DistanceToB / static_cast<int64_t>(DesB.Size);
1631 
1632       // Try to insert A into B's group.
1633       if (GroupB->insertMember(A, IndexA, DesA.Alignment)) {
1634         LLVM_DEBUG(dbgs() << "LV: Inserted:" << *A << '\n'
1635                           << "    into the interleave group with" << *B
1636                           << '\n');
1637         InterleaveGroupMap[A] = GroupB;
1638 
1639         // Set the first load in program order as the insert position.
1640         if (A->mayReadFromMemory())
1641           GroupB->setInsertPos(A);
1642       }
1643     } // Iteration over A accesses.
1644   }   // Iteration over B accesses.
1645 
1646   auto InvalidateGroupIfMemberMayWrap = [&](InterleaveGroup<Instruction> *Group,
1647                                             int Index,
1648                                             const char *FirstOrLast) -> bool {
1649     Instruction *Member = Group->getMember(Index);
1650     assert(Member && "Group member does not exist");
1651     Value *MemberPtr = getLoadStorePointerOperand(Member);
1652     Type *AccessTy = getLoadStoreType(Member);
1653     if (getPtrStride(PSE, AccessTy, MemberPtr, TheLoop, Strides,
1654                      /*Assume=*/false, /*ShouldCheckWrap=*/true).value_or(0))
1655       return false;
1656     LLVM_DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
1657                       << FirstOrLast
1658                       << " group member potentially pointer-wrapping.\n");
1659     releaseGroup(Group);
1660     return true;
1661   };
1662 
1663   // Remove interleaved groups with gaps whose memory
1664   // accesses may wrap around. We have to revisit the getPtrStride analysis,
1665   // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does
1666   // not check wrapping (see documentation there).
1667   // For now, we use Assume=false;
1668   // TODO: Change to Assume=true but making sure we don't exceed the threshold
1669   // of runtime SCEV assumptions checks (thereby potentially failing to
1670   // vectorize altogether).
1671   // Additional optional optimizations:
1672   // TODO: If we are peeling the loop and we know that the first pointer doesn't
1673   // wrap then we can deduce that all pointers in the group don't wrap.
1674   // This means that we can forcefully peel the loop in order to only have to
1675   // check the first pointer for no-wrap. Once we change to Assume=true,
1676   // we'll only need at most one runtime check per interleaved group.
1677   for (auto *Group : LoadGroups) {
1678     // Case 1: A full group. We can skip the checks; for full groups, if
1679     // the wide load would wrap around the address space, we would do a
1680     // memory access at nullptr even without the transformation.
1681     if (Group->getNumMembers() == Group->getFactor())
1682       continue;
1683 
1684     // Case 2: If the first and last members of the group don't wrap, this
1685     // implies that all the pointers in the group don't wrap. So we check
1686     // only group member 0 (which is always guaranteed to exist) and group
1687     // member Factor - 1; if the latter doesn't exist, we rely on peeling
1688     // (if it is a non-reversed access -- see Case 3).
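    // (Editor's illustration: in a factor-4 load group containing only
    // members 0 and 2, member Factor - 1 is missing, so a non-reversed group
    // falls through to Case 3 below and requires the scalar epilogue.)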
1689     if (InvalidateGroupIfMemberMayWrap(Group, 0, "first"))
1690       continue;
1691     if (Group->getMember(Group->getFactor() - 1))
1692       InvalidateGroupIfMemberMayWrap(Group, Group->getFactor() - 1, "last");
1693     else {
1694       // Case 3: A non-reversed interleaved load group with gaps. We need
1695       // to execute at least one scalar epilogue iteration. This will ensure
1696       // we don't speculatively access memory out-of-bounds. We only need
1697       // to look for a member at index factor - 1, since every group must have
1698       // a member at index zero.
1699       if (Group->isReverse()) {
1700         LLVM_DEBUG(
1701             dbgs() << "LV: Invalidate candidate interleaved group due to "
1702                       "a reverse access with gaps.\n");
1703         releaseGroup(Group);
1704         continue;
1705       }
1706       LLVM_DEBUG(
1707           dbgs() << "LV: Interleaved group requires epilogue iteration.\n");
1708       RequiresScalarEpilogue = true;
1709     }
1710   }
1711 
1712   for (auto *Group : StoreGroups) {
1713     // Case 1: A full group. We can skip the checks; for full groups, if
1714     // the wide store would wrap around the address space, we would do a
1715     // memory access at nullptr even without the transformation.
1716     if (Group->getNumMembers() == Group->getFactor())
1717       continue;
1718 
1719     // An interleave store group with gaps is implemented using a masked
1720     // wide store. Remove interleaved store groups with gaps if masked
1721     // interleaved accesses are not enabled by the target.
1722     if (!EnablePredicatedInterleavedMemAccesses) {
1723       LLVM_DEBUG(
1724           dbgs() << "LV: Invalidate candidate interleaved store group due "
1725                     "to gaps.\n");
1726       releaseGroup(Group);
1727       continue;
1728     }
1729 
1730     // Case 2: If the first and last members of the group don't wrap, this
1731     // implies that all the pointers in the group don't wrap. So we check
1732     // only group member 0 (which is always guaranteed to exist) and the
1733     // last group member. Case 3 (scalar epilogue) is not relevant for
1734     // stores with gaps, which are implemented with a masked store (rather
1735     // than speculative access, as in loads).
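    // (Editor's illustration: in a factor-4 store group with members only at
    // indices 0 and 2, the descending search below treats index 2 as the
    // "last" member for the wrap check.)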
1736     if (InvalidateGroupIfMemberMayWrap(Group, 0, "first"))
1737       continue;
1738     for (int Index = Group->getFactor() - 1; Index > 0; Index--)
1739       if (Group->getMember(Index)) {
1740         InvalidateGroupIfMemberMayWrap(Group, Index, "last");
1741         break;
1742       }
1743   }
1744 }
1745 
1746 void InterleavedAccessInfo::invalidateGroupsRequiringScalarEpilogue() {
1747   // If no group had triggered the requirement to create an epilogue loop,
1748   // there is nothing to do.
1749   if (!requiresScalarEpilogue())
1750     return;
1751 
1752   // Release groups requiring scalar epilogues. Note that this also removes them
1753   // from InterleaveGroups.
1754   bool ReleasedGroup = InterleaveGroups.remove_if([&](auto *Group) {
1755     if (!Group->requiresScalarEpilogue())
1756       return false;
1757     LLVM_DEBUG(
1758         dbgs()
1759         << "LV: Invalidate candidate interleaved group due to gaps that "
1760            "require a scalar epilogue (not allowed under optsize) and cannot "
1761            "be masked (not enabled).\n");
1762     releaseGroupWithoutRemovingFromSet(Group);
1763     return true;
1764   });
1765   assert(ReleasedGroup && "At least one group must be invalidated, as a "
1766                           "scalar epilogue was required");
1767   (void)ReleasedGroup;
1768   RequiresScalarEpilogue = false;
1769 }
1770 
1771 template <typename InstT>
1772 void InterleaveGroup<InstT>::addMetadata(InstT *NewInst) const {
1773   llvm_unreachable("addMetadata can only be used for Instruction");
1774 }
1775 
1776 namespace llvm {
1777 template <>
1778 void InterleaveGroup<Instruction>::addMetadata(Instruction *NewInst) const {
1779   SmallVector<Value *, 4> VL(make_second_range(Members));
1780   propagateMetadata(NewInst, VL);
1781 }
1782 } // namespace llvm
1783