//===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines some vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_VECTORUTILS_H
#define LLVM_ANALYSIS_VECTORUTILS_H

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/VFABIDemangler.h"
#include "llvm/Support/CheckedArithmetic.h"

namespace llvm {
class TargetLibraryInfo;

/// The Vector Function Database.
///
/// Helper class used to find the vector functions associated with a
/// scalar CallInst.
class VFDatabase {
  /// The Module of the CallInst CI.
  const Module *M;
  /// The CallInst instance being queried for scalar-to-vector mappings.
  const CallInst &CI;
  /// List of vector function descriptors associated with the call
  /// instruction.
  const SmallVector<VFInfo, 8> ScalarToVectorMappings;

  /// Retrieve the scalar-to-vector mappings associated with the rules of
  /// a Vector Function ABI.
  static void getVFABIMappings(const CallInst &CI,
                               SmallVectorImpl<VFInfo> &Mappings) {
    if (!CI.getCalledFunction())
      return;

    const StringRef ScalarName = CI.getCalledFunction()->getName();

    SmallVector<std::string, 8> ListOfStrings;
    // The check for the vector-function-abi-variant attribute is done when
    // retrieving the vector variant names here.
    VFABI::getVectorVariantNames(CI, ListOfStrings);
    if (ListOfStrings.empty())
      return;
    for (const auto &MangledName : ListOfStrings) {
      const std::optional<VFInfo> Shape =
          VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType());
      // A match is found via scalar and vector names, and also by
      // ensuring that the variant described in the attribute has a
      // corresponding definition or declaration of the vector
      // function in the Module M.
      if (Shape && (Shape->ScalarName == ScalarName)) {
        assert(CI.getModule()->getFunction(Shape->VectorName) &&
               "Vector function is missing.");
        Mappings.push_back(*Shape);
      }
    }
  }

public:
  /// Retrieve all the VFInfo instances associated with the CallInst CI.
  static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
    SmallVector<VFInfo, 8> Ret;

    // Get mappings from the Vector Function ABI variants.
    getVFABIMappings(CI, Ret);

    // Other non-VFABI variants should be retrieved here.

    return Ret;
  }

  static bool hasMaskedVariant(const CallInst &CI,
                               std::optional<ElementCount> VF = std::nullopt) {
    // Check whether we have at least one masked vector version of a scalar
    // function. If no VF is specified then we check for any masked variant,
    // otherwise we look for one that matches the supplied VF.
    auto Mappings = VFDatabase::getMappings(CI);
    for (VFInfo Info : Mappings)
      if (!VF || Info.Shape.VF == *VF)
        if (Info.isMasked())
          return true;

    return false;
  }

  /// Constructor, requires a CallInst instance.
  VFDatabase(CallInst &CI)
      : M(CI.getModule()), CI(CI),
        ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}

  /// \defgroup VFDatabase query interface.
  ///
  /// @{
  /// Retrieve the Function with VFShape \p Shape.
  Function *getVectorizedFunction(const VFShape &Shape) const {
    if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
      return CI.getCalledFunction();

    for (const auto &Info : ScalarToVectorMappings)
      if (Info.Shape == Shape)
        return M->getFunction(Info.VectorName);

    return nullptr;
  }
  /// @}
};
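
// Illustrative usage sketch (editorial, not part of the interface): querying
// the database for an unmasked 4-wide variant of a call. The CallInst `Call`
// and the chosen VF are assumptions made for the example only; it relies
// solely on the members declared above.
//
//   VFDatabase DB(Call);
//   for (const VFInfo &Info : VFDatabase::getMappings(Call)) {
//     if (Info.Shape.VF == ElementCount::getFixed(4) && !Info.isMasked()) {
//       // The 4-wide, unmasked variant declared in the module.
//       Function *VecF = DB.getVectorizedFunction(Info.Shape);
//     }
//   }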

template <typename T> class ArrayRef;
class DemandedBits;
template <typename InstTy> class InterleaveGroup;
class IRBuilderBase;
class Loop;
class ScalarEvolution;
class TargetTransformInfo;
class Type;
class Value;

namespace Intrinsic {
typedef unsigned ID;
}

/// A helper function for converting Scalar types to vector types. If
/// the incoming type is void or metadata, it is returned unchanged. If the EC
/// represents a scalar, we return the scalar type.
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
    return Scalar;
  return VectorType::get(Scalar, EC);
}

inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
}
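
// For reference, a few example conversions implied by the definitions above
// (IR type notation; purely illustrative):
//   ToVectorTy(i32, 4)                            --> <4 x i32>
//   ToVectorTy(i32, ElementCount::getScalable(4)) --> <vscale x 4 x i32>
//   ToVectorTy(void, 4)                           --> void
//   ToVectorTy(float, 1)                          --> float  (scalar EC)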

/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                        unsigned ScalarOpdIdx);

/// Identifies if the vector form of the intrinsic is overloaded on the type of
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);

/// Returns the intrinsic ID for the call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if it does not find one, it returns not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);
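
// Illustrative sketch of how these queries are typically combined when
// deciding whether a call can be widened (the CallInst `CI` and the
// TargetLibraryInfo `TLI` are assumed to come from the caller; this is not
// prescribed usage):
//
//   Intrinsic::ID IID = getVectorIntrinsicIDForCall(CI, TLI);
//   if (IID != Intrinsic::not_intrinsic && isTriviallyVectorizable(IID)) {
//     // The call can be widened by calling the same intrinsic on vector
//     // operands; operands identified by isVectorIntrinsicWithScalarOpAtArg
//     // stay scalar.
//   }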

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then extracted
/// from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);
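
// For example (illustrative only):
//   getSplatIndex({3, 3, -1, 3}) == 3    // all defined elements agree
//   getSplatIndex({-1, -1, -1})  == -1   // all elements undefined
//   getSplatIndex({0, 1, 0, 1})  == -1   // mixed non-negative values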

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constant vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If an index element is specified, either
/// every element of the vector is poisoned or the element at that index is not
/// poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);

/// Transform a shuffle mask's output demanded element mask into demanded
/// element masks for the 2 operands, returns false if the mask isn't valid.
/// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
/// \p AllowUndefElts permits "-1" indices to be treated as undef.
bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
                            const APInt &DemandedElts, APInt &DemandedLHS,
                            APInt &DemandedRHS, bool AllowUndefElts = false);
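
// A worked example of the demanded-elements split above (illustrative): for
// SrcWidth = 4 and Mask = <0, 5, 2, 7>, with all output elements demanded
// (DemandedElts = 0b1111), output lanes 0 and 2 read LHS elements 0 and 2,
// while output lanes 1 and 3 read RHS elements 1 and 3, so the function would
// produce DemandedLHS = 0b0101 and DemandedRHS = 0b1010.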

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// Attempt to narrow/widen the \p Mask shuffle mask to the \p NumDstElts target
/// width. Internally this will call narrowShuffleMaskElts/widenShuffleMaskElts.
/// This will assert unless NumDstElts is a multiple of Mask.size() (or
/// vice-versa). Returns false on failure, in which case ScaledMask is left in
/// an undefined state.
bool scaleShuffleMaskElts(unsigned NumDstElts, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);
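
// Illustrative behaviour of the scaling helper above, following the
// narrow/widen examples: with Mask = <1, 0> and NumDstElts = 4 the mask is
// narrowed by a factor of 2 to <2, 3, 0, 1>; with Mask = <2, 3, 0, 1> and
// NumDstElts = 2 it is widened back to <1, 0>. A request such as
// Mask = <0, 2, 1, 3> with NumDstElts = 2 fails (returns false) because the
// adjacent narrow elements do not form contiguous wider elements.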

/// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
/// to get the shuffle mask with widest possible elements.
void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
                                  SmallVectorImpl<int> &ScaledMask);

/// Splits and processes a shuffle mask depending on the number of input and
/// output registers. The function does 2 main things: 1) splits the
/// source/destination vectors into real registers; 2) does the mask analysis
/// to identify which real registers are permuted. Then the function processes
/// the resulting register masks using the provided action items. If no input
/// register is defined, \p NoInputAction is used. If only 1 input register is
/// used, \p SingleInputAction is used, otherwise \p ManyInputsAction is used
/// to process 2 or more input registers and masks.
/// \param Mask Original shuffle mask.
/// \param NumOfSrcRegs Number of source registers.
/// \param NumOfDestRegs Number of destination registers.
/// \param NumOfUsedRegs Number of actually used destination registers.
void processShuffleMasks(
    ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
    unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);

/// Compute the demanded elements mask of horizontal binary operations. A
/// horizontal operation combines two adjacent elements in a vector operand.
/// This function returns a mask for the elements that correspond to the first
/// operand of this horizontal combination. For example, for two vectors
/// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
/// elements X1, X3, Y1, and Y3. To get the other operands, simply shift the
/// result of this function to the left by 1.
///
/// \param VectorBitWidth the total bit width of the vector
/// \param DemandedElts   the demanded elements mask for the operation
/// \param DemandedLHS    the demanded elements mask for the left operand
/// \param DemandedRHS    the demanded elements mask for the right operand
void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
                                         const APInt &DemandedElts,
                                         APInt &DemandedLHS,
                                         APInt &DemandedRHS);

/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these lengthens and truncates when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
MapVector<Instruction *, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                         const TargetTransformInfo *TTI = nullptr);
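
// A minimal usage sketch (illustrative; `TheLoop`, `DB`, and `TTI` are
// assumed to be provided by the caller, e.g. a vectorizer's cost model):
//
//   MapVector<Instruction *, uint64_t> MinBWs =
//       computeMinimumValueSizes(TheLoop->getBlocks(), DB, &TTI);
//   for (const auto &Entry : MinBWs)
//     LLVM_DEBUG(dbgs() << *Entry.first << " can be shrunk to i"
//                       << Entry.second << "\n");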

/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group, MD_mmra].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);

/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose Interleaved-factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and
/// begin at \p Start. The mask contains \p NumInts integers and is padded with
/// \p NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

/// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
/// mask assuming both operands are identical. This assumes that the unary
/// shuffle will use elements from operand 0 (operand 1 will be unused).
llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
                                           unsigned NumElts);
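
// Illustrative example for the unary-mask conversion above: with NumElts = 4,
// a binary shuffle mask <0, 5, 2, 7> (where indices >= 4 refer to the second
// operand) would become the unary mask <0, 1, 2, 3>, since both operands are
// assumed identical and indices are rebased onto operand 0.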

/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into
/// a single large vector. The number of vectors should be greater than one,
/// and their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);
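
// Minimal usage sketch (illustrative; `Builder` is an IRBuilder positioned at
// the insertion point and V0/V1 are assumed to be <4 x float> values):
//
//   Value *Parts[] = {V0, V1};
//   Value *Wide = concatenateVectors(Builder, Parts);  // <8 x float>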

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef.  That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef.  That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if any of the elements of this
/// predicate mask are known to be true or undef.  That is, return true if at
/// least one lane can be assumed active.
bool maskContainsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// with a bit set for each lane that may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);
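
// Illustrative example for the helper above: for a constant mask
// <4 x i1> <i1 1, i1 0, i1 undef, i1 1>, lanes 0, 2, and 3 may be active
// (the undef lane must be assumed live), so the result would be the 4-bit
// APInt 0b1101.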

/// The group of interleaved loads/stores sharing the same stride and
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than the interleave factor, which is equal to the
/// absolute value of the access's stride.
///
/// E.g. An interleaved load group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          a = A[i];                           // Member of index 0
///          b = A[i+1];                         // Member of index 1
///          d = A[i+3];                         // Member of index 3
///          ...
///        }
///
///      An interleaved store group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          ...
///          A[i]   = a;                         // Member of index 0
///          A[i+1] = b;                         // Member of index 1
///          A[i+2] = c;                         // Member of index 2
///          A[i+3] = d;                         // Member of index 3
///        }
///
/// Note: the interleaved load group could have gaps (missing members), but
/// the interleaved store group doesn't allow gaps.
template <typename InstTy> class InterleaveGroup {
public:
  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
        InsertPos(nullptr) {}

  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
      : Alignment(Alignment), InsertPos(Instr) {
    Factor = std::abs(Stride);
    assert(Factor > 1 && "Invalid interleave factor");

    Reverse = Stride < 0;
    Members[0] = Instr;
  }

  bool isReverse() const { return Reverse; }
  uint32_t getFactor() const { return Factor; }
  Align getAlign() const { return Alignment; }
  uint32_t getNumMembers() const { return Members.size(); }

  /// Try to insert a new member \p Instr with index \p Index and
  /// alignment \p NewAlign. The index is relative to the leader and it could
  /// be negative if it is the new leader.
  ///
  /// \returns false if the instruction doesn't belong to the group.
  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
    // Make sure the key fits in an int32_t.
    std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
    if (!MaybeKey)
      return false;
    int32_t Key = *MaybeKey;

    // Skip if the key is used for either the tombstone or empty special values.
    if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
        DenseMapInfo<int32_t>::getEmptyKey() == Key)
      return false;

    // Skip if there is already a member with the same index.
    if (Members.contains(Key))
      return false;

    if (Key > LargestKey) {
      // The largest index is always less than the interleave factor.
      if (Index >= static_cast<int32_t>(Factor))
        return false;

      LargestKey = Key;
    } else if (Key < SmallestKey) {

      // Make sure the largest index fits in an int32_t.
      std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
      if (!MaybeLargestIndex)
        return false;

      // The largest index is always less than the interleave factor.
      if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
        return false;

      SmallestKey = Key;
    }

    // It's always safe to select the minimum alignment.
    Alignment = std::min(Alignment, NewAlign);
    Members[Key] = Instr;
    return true;
  }

  /// Get the member with the given index \p Index.
  ///
  /// \returns nullptr if the group contains no such member.
  InstTy *getMember(uint32_t Index) const {
    int32_t Key = SmallestKey + Index;
    return Members.lookup(Key);
  }

  /// Get the index for the given member. Unlike the key in the member
  /// map, the index starts from 0.
  uint32_t getIndex(const InstTy *Instr) const {
    for (auto I : Members) {
      if (I.second == Instr)
        return I.first - SmallestKey;
    }

    llvm_unreachable("InterleaveGroup contains no such member");
  }

  InstTy *getInsertPos() const { return InsertPos; }
  void setInsertPos(InstTy *Inst) { InsertPos = Inst; }

  /// Add metadata (e.g. alias info) from the instructions in this group to \p
  /// NewInst.
  ///
  /// FIXME: this function currently does not add noalias metadata a la
  /// addNewMedata.  To do that we need to compute the intersection of the
  /// noalias info from all members.
  void addMetadata(InstTy *NewInst) const;

  /// Returns true if this Group requires a scalar iteration to handle gaps.
  bool requiresScalarEpilogue() const {
    // If the last member of the Group exists, then a scalar epilog is not
    // needed for this group.
    if (getMember(getFactor() - 1))
      return false;

    // We have a group with gaps. It therefore can't be a reversed access,
    // because such groups get invalidated (TODO).
    assert(!isReverse() && "Group should have been invalidated");

    // This is a group of loads, with gaps, and without a last member.
    return true;
  }

private:
  uint32_t Factor; // Interleave Factor.
  bool Reverse;
  Align Alignment;
  DenseMap<int32_t, InstTy *> Members;
  int32_t SmallestKey = 0;
  int32_t LargestKey = 0;

  // To avoid breaking dependences, vectorized instructions of an interleave
  // group should be inserted at either the first load or the last store in
  // program order.
  //
  // E.g. %even = load i32             // Insert Position
  //      %add = add i32 %even         // Use of %even
  //      %odd = load i32
  //
  //      store i32 %even
  //      %odd = add i32               // Def of %odd
  //      store i32 %odd               // Insert Position
  InstTy *InsertPos;
};
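
// Illustrative sketch of how an interleave group is typically assembled (the
// loads LoadA/LoadB and the stride/alignment values are assumptions for the
// example only; real groups are built by InterleavedAccessInfo below):
//
//   InterleaveGroup<Instruction> Group(LoadA, /*Stride=*/2, Align(4));
//   if (Group.insertMember(LoadB, /*Index=*/1, Align(4))) {
//     // LoadB is the member at index 1; the factor is abs(Stride) == 2.
//     Instruction *First = Group.getMember(0);   // LoadA
//     bool NeedsEpilogue = Group.requiresScalarEpilogue();
//   }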

/// Drive the analysis of interleaved memory accesses in the loop.
///
/// Use this class to analyze interleaved accesses only when we can vectorize
/// a loop. Otherwise it's meaningless to do the analysis, as vectorization
/// of interleaved accesses is unsafe.
///
/// The analysis collects interleave groups and records the relationships
/// between the member and the group in a map.
class InterleavedAccessInfo {
public:
  InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
                        DominatorTree *DT, LoopInfo *LI,
                        const LoopAccessInfo *LAI)
      : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}

  ~InterleavedAccessInfo() { invalidateGroups(); }

  /// Analyze the interleaved accesses and collect them in interleave
  /// groups. Substitute symbolic strides using \p Strides.
  /// Consider also predicated loads/stores in the analysis if
  /// \p EnableMaskedInterleavedGroup is true.
  void analyzeInterleaving(bool EnableMaskedInterleavedGroup);

  /// Invalidate groups, e.g., in case all blocks in the loop will be
  /// predicated contrary to the original assumption. Although we currently
  /// prevent group formation for predicated accesses, we may be able to relax
  /// this limitation in the future once we handle more complicated blocks.
  /// Returns true if any groups were invalidated.
  bool invalidateGroups() {
    if (InterleaveGroups.empty()) {
      assert(
          !RequiresScalarEpilogue &&
          "RequiresScalarEpilog should not be set without interleave groups");
      return false;
    }

    InterleaveGroupMap.clear();
    for (auto *Ptr : InterleaveGroups)
      delete Ptr;
    InterleaveGroups.clear();
    RequiresScalarEpilogue = false;
    return true;
  }

  /// Check if \p Instr belongs to any interleave group.
  bool isInterleaved(Instruction *Instr) const {
    return InterleaveGroupMap.contains(Instr);
  }

  /// Get the interleave group that \p Instr belongs to.
  ///
  /// \returns nullptr if \p Instr doesn't belong to any group.
  InterleaveGroup<Instruction> *
  getInterleaveGroup(const Instruction *Instr) const {
    return InterleaveGroupMap.lookup(Instr);
  }

  iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
  getInterleaveGroups() {
    return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
  }

  /// Returns true if an interleaved group that may access memory
  /// out-of-bounds requires a scalar epilogue iteration for correctness.
  bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }

  /// Invalidate groups that require a scalar epilogue (due to gaps). This can
  /// happen when optimizing for size forbids a scalar epilogue, and the gap
  /// cannot be filtered by masking the load/store.
  void invalidateGroupsRequiringScalarEpilogue();

  /// Returns true if we have any interleave groups.
  bool hasGroups() const { return !InterleaveGroups.empty(); }

private:
  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
  /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
  /// The interleaved access analysis can also add new predicates (for example
  /// by versioning strides of pointers).
  PredicatedScalarEvolution &PSE;

  Loop *TheLoop;
  DominatorTree *DT;
  LoopInfo *LI;
  const LoopAccessInfo *LAI;

  /// True if the loop may contain non-reversed interleaved groups with
  /// out-of-bounds accesses. We ensure we don't speculatively access memory
  /// out-of-bounds by executing at least one scalar epilogue iteration.
  bool RequiresScalarEpilogue = false;

  /// Holds the relationships between the members and the interleave group.
  DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;

  SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;

  /// Holds dependences among the memory accesses in the loop. It maps a source
  /// access to a set of dependent sink accesses.
  DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;

  /// The descriptor for a strided memory access.
  struct StrideDescriptor {
    StrideDescriptor() = default;
    StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
                     Align Alignment)
        : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}

    // The access's stride. It is negative for a reverse access.
    int64_t Stride = 0;

    // The scalar expression of this access.
    const SCEV *Scev = nullptr;

    // The size of the memory object.
    uint64_t Size = 0;

    // The alignment of this access.
    Align Alignment;
  };

  /// A type for holding instructions and their stride descriptors.
  using StrideEntry = std::pair<Instruction *, StrideDescriptor>;

  /// Create a new interleave group with the given instruction \p Instr,
  /// stride \p Stride and alignment \p Align.
  ///
  /// \returns the newly created interleave group.
  InterleaveGroup<Instruction> *
  createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
    assert(!InterleaveGroupMap.count(Instr) &&
           "Already in an interleaved access group");
    InterleaveGroupMap[Instr] =
        new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
    InterleaveGroups.insert(InterleaveGroupMap[Instr]);
    return InterleaveGroupMap[Instr];
  }

  /// Release the group and remove all the relationships.
  void releaseGroup(InterleaveGroup<Instruction> *Group) {
    InterleaveGroups.erase(Group);
    releaseGroupWithoutRemovingFromSet(Group);
  }

  /// Do everything necessary to release the group, apart from removing it from
  /// the InterleaveGroups set.
  void releaseGroupWithoutRemovingFromSet(InterleaveGroup<Instruction> *Group) {
    for (unsigned i = 0; i < Group->getFactor(); i++)
      if (Instruction *Member = Group->getMember(i))
        InterleaveGroupMap.erase(Member);

    delete Group;
  }

  /// Collect all the accesses with a constant stride in program order.
  void collectConstStrideAccesses(
      MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
      const DenseMap<Value *, const SCEV *> &Strides);

  /// Returns true if \p Stride is allowed in an interleaved group.
  static bool isStrided(int Stride);

  /// Returns true if \p BB is a predicated block.
  bool isPredicated(BasicBlock *BB) const {
    return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
  }

  /// Returns true if LoopAccessInfo can be used for dependence queries.
  bool areDependencesValid() const {
    return LAI && LAI->getDepChecker().getDependences();
  }

  /// Returns true if memory accesses \p A and \p B can be reordered, if
  /// necessary, when constructing interleaved groups.
  ///
  /// \p A must precede \p B in program order. We return false if reordering is
  /// not necessary or is prevented because \p A and \p B may be dependent.
  bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
                                                 StrideEntry *B) const {
    // Code motion for interleaved accesses can potentially hoist strided loads
    // and sink strided stores. The code below checks the legality of the
    // following two conditions:
    //
    // 1. Potentially moving a strided load (B) before any store (A) that
    //    precedes B, or
    //
    // 2. Potentially moving a strided store (A) after any load or store (B)
    //    that A precedes.
    //
    // It's legal to reorder A and B if we know there isn't a dependence from A
    // to B. Note that this determination is conservative since some
    // dependences could potentially be reordered safely.

    // A is potentially the source of a dependence.
    auto *Src = A->first;
    auto SrcDes = A->second;

    // B is potentially the sink of a dependence.
    auto *Sink = B->first;
    auto SinkDes = B->second;

    // Code motion for interleaved accesses can't violate WAR dependences.
    // Thus, reordering is legal if the source isn't a write.
    if (!Src->mayWriteToMemory())
      return true;

    // At least one of the accesses must be strided.
    if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
      return true;

    // If dependence information is not available from LoopAccessInfo,
    // conservatively assume the instructions can't be reordered.
    if (!areDependencesValid())
      return false;

    // If we know there is a dependence from source to sink, assume the
    // instructions can't be reordered. Otherwise, reordering is legal.
    return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
  }

  /// Collect the dependences from LoopAccessInfo.
  ///
  /// We process the dependences once during the interleaved access analysis to
  /// enable constant-time dependence queries.
  void collectDependences() {
    if (!areDependencesValid())
      return;
    const auto &DepChecker = LAI->getDepChecker();
    auto *Deps = DepChecker.getDependences();
    for (auto Dep : *Deps)
      Dependences[Dep.getSource(DepChecker)].insert(
          Dep.getDestination(DepChecker));
  }
};
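
// Illustrative sketch of driving the analysis (the PSE/L/DT/LI/LAI objects
// are assumed to be produced by the caller, e.g. by the loop vectorizer's
// legality checks; this is a sketch, not a complete recipe):
//
//   InterleavedAccessInfo IAI(PSE, L, DT, LI, LAI);
//   IAI.analyzeInterleaving(/*EnableMaskedInterleavedGroup=*/false);
//   if (IAI.isInterleaved(SomeLoadOrStore))
//     if (auto *Group = IAI.getInterleaveGroup(SomeLoadOrStore))
//       unsigned Factor = Group->getFactor();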

} // llvm namespace

#endif