xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Analysis/VectorUtils.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines some vectorizer utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_ANALYSIS_VECTORUTILS_H
14 #define LLVM_ANALYSIS_VECTORUTILS_H
15 
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/LoopAccessAnalysis.h"
19 #include "llvm/IR/Module.h"
20 #include "llvm/IR/VFABIDemangler.h"
21 #include "llvm/IR/VectorTypeUtils.h"
22 #include "llvm/Support/CheckedArithmetic.h"
23 #include "llvm/Support/Compiler.h"
24 
25 namespace llvm {
26 class TargetLibraryInfo;
27 
28 /// The Vector Function Database.
29 ///
30 /// Helper class used to find the vector functions associated to a
31 /// scalar CallInst.
32 class VFDatabase {
33   /// The Module of the CallInst CI.
34   const Module *M;
35   /// The CallInst instance being queried for scalar to vector mappings.
36   const CallInst &CI;
37   /// List of vector functions descriptors associated to the call
38   /// instruction.
39   const SmallVector<VFInfo, 8> ScalarToVectorMappings;
40 
41   /// Retrieve the scalar-to-vector mappings associated to the rule of
42   /// a vector Function ABI.
getVFABIMappings(const CallInst & CI,SmallVectorImpl<VFInfo> & Mappings)43   static void getVFABIMappings(const CallInst &CI,
44                                SmallVectorImpl<VFInfo> &Mappings) {
45     if (!CI.getCalledFunction())
46       return;
47 
48     const StringRef ScalarName = CI.getCalledFunction()->getName();
49 
50     SmallVector<std::string, 8> ListOfStrings;
51     // The check for the vector-function-abi-variant attribute is done when
52     // retrieving the vector variant names here.
53     VFABI::getVectorVariantNames(CI, ListOfStrings);
54     if (ListOfStrings.empty())
55       return;
56     for (const auto &MangledName : ListOfStrings) {
57       const std::optional<VFInfo> Shape =
58           VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType());
59       // A match is found via scalar and vector names, and also by
60       // ensuring that the variant described in the attribute has a
61       // corresponding definition or declaration of the vector
62       // function in the Module M.
63       if (Shape && (Shape->ScalarName == ScalarName)) {
64         assert(CI.getModule()->getFunction(Shape->VectorName) &&
65                "Vector function is missing.");
66         Mappings.push_back(*Shape);
67       }
68     }
69   }
70 
71 public:
72   /// Retrieve all the VFInfo instances associated to the CallInst CI.
getMappings(const CallInst & CI)73   static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
74     SmallVector<VFInfo, 8> Ret;
75 
76     // Get mappings from the Vector Function ABI variants.
77     getVFABIMappings(CI, Ret);
78 
79     // Other non-VFABI variants should be retrieved here.
80 
81     return Ret;
82   }
83 
84   static bool hasMaskedVariant(const CallInst &CI,
85                                std::optional<ElementCount> VF = std::nullopt) {
86     // Check whether we have at least one masked vector version of a scalar
87     // function. If no VF is specified then we check for any masked variant,
88     // otherwise we look for one that matches the supplied VF.
89     auto Mappings = VFDatabase::getMappings(CI);
90     for (VFInfo Info : Mappings)
91       if (!VF || Info.Shape.VF == *VF)
92         if (Info.isMasked())
93           return true;
94 
95     return false;
96   }
97 
98   /// Constructor, requires a CallInst instance.
VFDatabase(CallInst & CI)99   VFDatabase(CallInst &CI)
100       : M(CI.getModule()), CI(CI),
101         ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
102 
103   /// \defgroup VFDatabase query interface.
104   ///
105   /// @{
106   /// Retrieve the Function with VFShape \p Shape.
getVectorizedFunction(const VFShape & Shape)107   Function *getVectorizedFunction(const VFShape &Shape) const {
108     if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
109       return CI.getCalledFunction();
110 
111     for (const auto &Info : ScalarToVectorMappings)
112       if (Info.Shape == Shape)
113         return M->getFunction(Info.VectorName);
114 
115     return nullptr;
116   }
117   /// @}
118 };
119 
120 template <typename T> class ArrayRef;
121 class DemandedBits;
122 template <typename InstTy> class InterleaveGroup;
123 class IRBuilderBase;
124 class Loop;
125 class TargetTransformInfo;
126 class Value;
127 
128 namespace Intrinsic {
129 typedef unsigned ID;
130 }
131 
132 /// Identify if the intrinsic is trivially vectorizable.
133 /// This method returns true if the intrinsic's argument types are all scalars
134 /// for the scalar form of the intrinsic and all vectors (or scalars handled by
135 /// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
136 ///
137 /// Note: isTriviallyVectorizable implies isTriviallyScalarizable.
138 LLVM_ABI bool isTriviallyVectorizable(Intrinsic::ID ID);
139 
/// Identify if the intrinsic is trivially scalarizable.
/// This method returns true following the same predicates of
/// isTriviallyVectorizable.
///
/// Note: There are intrinsics where implementing vectorization for the
/// intrinsic is redundant, but we want to implement scalarization of the
/// vector. To prevent the requirement that an intrinsic also implements
/// vectorization we provide this separate function.
148 LLVM_ABI bool isTriviallyScalarizable(Intrinsic::ID ID,
149                                       const TargetTransformInfo *TTI);
150 
151 /// Identifies if the vector form of the intrinsic has a scalar operand.
152 /// \p TTI is used to consider target specific intrinsics, if no target specific
153 /// intrinsics will be considered then it is appropriate to pass in nullptr.
154 LLVM_ABI bool
155 isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
156                                    const TargetTransformInfo *TTI);
157 
158 /// Identifies if the vector form of the intrinsic is overloaded on the type of
159 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
160 /// \p TTI is used to consider target specific intrinsics, if no target specific
161 /// intrinsics will be considered then it is appropriate to pass in nullptr.
162 LLVM_ABI bool
163 isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
164                                        const TargetTransformInfo *TTI);
165 
/// Identifies if the vector form of the intrinsic that returns a struct is
/// overloaded at the struct element index \p RetIdx.
/// \p TTI is used to consider target specific intrinsics, if no target specific
/// intrinsics will be considered then it is appropriate to pass in nullptr.
170 LLVM_ABI bool isVectorIntrinsicWithStructReturnOverloadAtField(
171     Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);
172 
/// Returns intrinsic ID for call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if none is found, it returns not_intrinsic.
176 LLVM_ABI Intrinsic::ID
177 getVectorIntrinsicIDForCall(const CallInst *CI, const TargetLibraryInfo *TLI);
178 
179 /// Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
180 LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor);
181 
182 /// Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
183 LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor);
184 
185 /// Returns the corresponding factor of llvm.vector.interleaveN intrinsics.
186 LLVM_ABI unsigned getInterleaveIntrinsicFactor(Intrinsic::ID ID);
187 
188 /// Returns the corresponding factor of llvm.vector.deinterleaveN intrinsics.
189 LLVM_ABI unsigned getDeinterleaveIntrinsicFactor(Intrinsic::ID ID);
190 
191 /// Given a vector and an element number, see if the scalar value is
192 /// already around as a register, for example if it were inserted then extracted
193 /// from the vector.
194 LLVM_ABI Value *findScalarElement(Value *V, unsigned EltNo);
195 
196 /// If all non-negative \p Mask elements are the same value, return that value.
197 /// If all elements are negative (undefined) or \p Mask contains different
198 /// non-negative values, return -1.
199 LLVM_ABI int getSplatIndex(ArrayRef<int> Mask);
200 
201 /// Get splat value if the input is a splat vector or return nullptr.
202 /// The value may be extracted from a splat constants vector or from
203 /// a sequence of instructions that broadcast a single value into a vector.
204 LLVM_ABI Value *getSplatValue(const Value *V);
205 
206 /// Return true if each element of the vector value \p V is poisoned or equal to
207 /// every other non-poisoned element. If an index element is specified, either
208 /// every element of the vector is poisoned or the element at that index is not
209 /// poisoned and equal to every other non-poisoned element.
210 /// This may be more powerful than the related getSplatValue() because it is
211 /// not limited by finding a scalar source value to a splatted vector.
212 LLVM_ABI bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
213 
214 /// Transform a shuffle mask's output demanded element mask into demanded
215 /// element masks for the 2 operands, returns false if the mask isn't valid.
216 /// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
217 /// \p AllowUndefElts permits "-1" indices to be treated as undef.
218 LLVM_ABI bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
219                                      const APInt &DemandedElts,
220                                      APInt &DemandedLHS, APInt &DemandedRHS,
221                                      bool AllowUndefElts = false);
222 
223 /// Does this shuffle mask represent either one slide shuffle or a pair of
224 /// two slide shuffles, combined with a select on some constant vector mask?
225 /// A slide is a shuffle mask which shifts some set of elements up or down
226 /// the vector, with all other elements being undefined.  An identity shuffle
/// will be matched as a slide by 0.  The output parameter provides the source
228 /// (-1 means no source), and slide direction for each slide.
229 LLVM_ABI bool isMaskedSlidePair(ArrayRef<int> Mask, int NumElts,
230                                 std::array<std::pair<int, int>, 2> &SrcInfo);
231 
232 /// Replace each shuffle mask index with the scaled sequential indices for an
233 /// equivalent mask of narrowed elements. Mask elements that are less than 0
234 /// (sentinel values) are repeated in the output mask.
235 ///
236 /// Example with Scale = 4:
237 ///   <4 x i32> <3, 2, 0, -1> -->
238 ///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
239 ///
240 /// This is the reverse process of widening shuffle mask elements, but it always
241 /// succeeds because the indexes can always be multiplied (scaled up) to map to
242 /// narrower vector elements.
243 LLVM_ABI void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
244                                     SmallVectorImpl<int> &ScaledMask);
245 
246 /// Try to transform a shuffle mask by replacing elements with the scaled index
247 /// for an equivalent mask of widened elements. If all mask elements that would
248 /// map to a wider element of the new mask are the same negative number
249 /// (sentinel value), that element of the new mask is the same value. If any
250 /// element in a given slice is negative and some other element in that slice is
251 /// not the same value, return false (partial matches with sentinel values are
252 /// not allowed).
253 ///
254 /// Example with Scale = 4:
255 ///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
256 ///   <4 x i32> <3, 2, 0, -1>
257 ///
258 /// This is the reverse process of narrowing shuffle mask elements if it
259 /// succeeds. This transform is not always possible because indexes may not
260 /// divide evenly (scale down) to map to wider vector elements.
261 LLVM_ABI bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
262                                    SmallVectorImpl<int> &ScaledMask);
263 
264 /// A variant of the previous method which is specialized for Scale=2, and
265 /// treats -1 as undef and allows widening when a wider element is partially
266 /// undef in the narrow form of the mask.  This transformation discards
267 /// information about which bytes in the original shuffle were undef.
268 LLVM_ABI bool widenShuffleMaskElts(ArrayRef<int> M,
269                                    SmallVectorImpl<int> &NewMask);
270 
271 /// Attempt to narrow/widen the \p Mask shuffle mask to the \p NumDstElts target
272 /// width. Internally this will call narrowShuffleMaskElts/widenShuffleMaskElts.
273 /// This will assert unless NumDstElts is a multiple of Mask.size (or
274 /// vice-versa). Returns false on failure, and ScaledMask will be in an
275 /// undefined state.
276 LLVM_ABI bool scaleShuffleMaskElts(unsigned NumDstElts, ArrayRef<int> Mask,
277                                    SmallVectorImpl<int> &ScaledMask);
278 
279 /// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
280 /// to get the shuffle mask with widest possible elements.
281 LLVM_ABI void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
282                                            SmallVectorImpl<int> &ScaledMask);
283 
284 /// Splits and processes shuffle mask depending on the number of input and
285 /// output registers. The function does 2 main things: 1) splits the
286 /// source/destination vectors into real registers; 2) do the mask analysis to
287 /// identify which real registers are permuted. Then the function processes
288 /// resulting registers mask using provided action items. If no input register
289 /// is defined, \p NoInputAction action is used. If only 1 input register is
290 /// used, \p SingleInputAction is used, otherwise \p ManyInputsAction is used to
291 /// process > 2 input registers and masks.
292 /// \param Mask Original shuffle mask.
293 /// \param NumOfSrcRegs Number of source registers.
294 /// \param NumOfDestRegs Number of destination registers.
295 /// \param NumOfUsedRegs Number of actually used destination registers.
296 LLVM_ABI void processShuffleMasks(
297     ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
298     unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
299     function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
300     function_ref<void(ArrayRef<int>, unsigned, unsigned, bool)>
301         ManyInputsAction);
302 
303 /// Compute the demanded elements mask of horizontal binary operations. A
304 /// horizontal operation combines two adjacent elements in a vector operand.
305 /// This function returns a mask for the elements that correspond to the first
306 /// operand of this horizontal combination. For example, for two vectors
307 /// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
308 /// elements X1, X3, Y1, and Y3. To get the other operands, simply shift the
309 /// result of this function to the left by 1.
310 ///
311 /// \param VectorBitWidth the total bit width of the vector
312 /// \param DemandedElts   the demanded elements mask for the operation
313 /// \param DemandedLHS    the demanded elements mask for the left operand
314 /// \param DemandedRHS    the demanded elements mask for the right operand
315 LLVM_ABI void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
316                                                   const APInt &DemandedElts,
317                                                   APInt &DemandedLHS,
318                                                   APInt &DemandedRHS);
319 
320 /// Compute a map of integer instructions to their minimum legal type
321 /// size.
322 ///
323 /// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
324 /// type (e.g. i32) whenever arithmetic is performed on them.
325 ///
326 /// For targets with native i8 or i16 operations, usually InstCombine can shrink
327 /// the arithmetic type down again. However InstCombine refuses to create
328 /// illegal types, so for targets without i8 or i16 registers, the lengthening
329 /// and shrinking remains.
330 ///
331 /// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
332 /// their scalar equivalents do not, so during vectorization it is important to
333 /// remove these lengthens and truncates when deciding the profitability of
334 /// vectorization.
335 ///
336 /// This function analyzes the given range of instructions and determines the
337 /// minimum type size each can be converted to. It attempts to remove or
338 /// minimize type size changes across each def-use chain, so for example in the
339 /// following code:
340 ///
341 ///   %1 = load i8, i8*
342 ///   %2 = add i8 %1, 2
343 ///   %3 = load i16, i16*
344 ///   %4 = zext i8 %2 to i32
345 ///   %5 = zext i16 %3 to i32
346 ///   %6 = add i32 %4, %5
347 ///   %7 = trunc i32 %6 to i16
348 ///
349 /// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
350 /// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
351 ///
352 /// If the optional TargetTransformInfo is provided, this function tries harder
353 /// to do less work by only looking at illegal types.
354 LLVM_ABI MapVector<Instruction *, uint64_t>
355 computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
356                          const TargetTransformInfo *TTI = nullptr);
357 
358 /// Compute the union of two access-group lists.
359 ///
360 /// If the list contains just one access group, it is returned directly. If the
361 /// list is empty, returns nullptr.
362 LLVM_ABI MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);
363 
364 /// Compute the access-group list of access groups that @p Inst1 and @p Inst2
365 /// are both in. If either instruction does not access memory at all, it is
366 /// considered to be in every list.
367 ///
368 /// If the list contains just one access group, it is returned directly. If the
369 /// list is empty, returns nullptr.
370 LLVM_ABI MDNode *intersectAccessGroups(const Instruction *Inst1,
371                                        const Instruction *Inst2);
372 
373 /// Add metadata from \p Inst to \p Metadata, if it can be preserved after
374 /// vectorization. It can be preserved after vectorization if the kind is one of
375 /// [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath, MD_nontemporal,
376 /// MD_access_group, MD_mmra].
377 LLVM_ABI void getMetadataToPropagate(
378     Instruction *Inst,
379     SmallVectorImpl<std::pair<unsigned, MDNode *>> &Metadata);
380 
381 /// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
382 /// MD_nontemporal, MD_access_group, MD_mmra].
383 /// For K in Kinds, we get the MDNode for K from each of the
384 /// elements of VL, compute their "intersection" (i.e., the most generic
385 /// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
387 ///
388 /// This function always sets a (possibly null) value for each K in Kinds.
389 LLVM_ABI Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);
390 
391 /// Create a mask that filters the members of an interleave group where there
392 /// are gaps.
393 ///
394 /// For example, the mask for \p Group with interleave-factor 3
395 /// and \p VF 4, that has only its first member present is:
396 ///
397 ///   <1,0,0,1,0,0,1,0,0,1,0,0>
398 ///
399 /// Note: The result is a mask of 0's and 1's, as opposed to the other
400 /// create[*]Mask() utilities which create a shuffle mask (mask that
401 /// consists of indices).
402 LLVM_ABI Constant *
403 createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
404                      const InterleaveGroup<Instruction> &Group);
405 
406 /// Create a mask with replicated elements.
407 ///
408 /// This function creates a shuffle mask for replicating each of the \p VF
409 /// elements in a vector \p ReplicationFactor times. It can be used to
410 /// transform a mask of \p VF elements into a mask of
411 /// \p VF * \p ReplicationFactor elements used by a predicated
412 /// interleaved-group of loads/stores whose Interleaved-factor ==
413 /// \p ReplicationFactor.
414 ///
415 /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
416 ///
417 ///   <0,0,0,1,1,1,2,2,2,3,3,3>
418 LLVM_ABI llvm::SmallVector<int, 16>
419 createReplicatedMask(unsigned ReplicationFactor, unsigned VF);
420 
421 /// Create an interleave shuffle mask.
422 ///
423 /// This function creates a shuffle mask for interleaving \p NumVecs vectors of
424 /// vectorization factor \p VF into a single wide vector. The mask is of the
425 /// form:
426 ///
427 ///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
428 ///
429 /// For example, the mask for VF = 4 and NumVecs = 2 is:
430 ///
431 ///   <0, 4, 1, 5, 2, 6, 3, 7>.
432 LLVM_ABI llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF,
433                                                          unsigned NumVecs);
434 
435 /// Create a stride shuffle mask.
436 ///
437 /// This function creates a shuffle mask whose elements begin at \p Start and
438 /// are incremented by \p Stride. The mask can be used to deinterleave an
439 /// interleaved vector into separate vectors of vectorization factor \p VF. The
440 /// mask is of the form:
441 ///
442 ///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
443 ///
444 /// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
445 ///
446 ///   <0, 2, 4, 6>
447 LLVM_ABI llvm::SmallVector<int, 16>
448 createStrideMask(unsigned Start, unsigned Stride, unsigned VF);
449 
450 /// Create a sequential shuffle mask.
451 ///
452 /// This function creates shuffle mask whose elements are sequential and begin
453 /// at \p Start.  The mask contains \p NumInts integers and is padded with \p
454 /// NumUndefs undef values. The mask is of the form:
455 ///
456 ///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
457 ///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
459 ///
460 ///   <0, 1, 2, 3, undef, undef, undef, undef>
461 LLVM_ABI llvm::SmallVector<int, 16>
462 createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);
463 
464 /// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
465 /// mask assuming both operands are identical. This assumes that the unary
466 /// shuffle will use elements from operand 0 (operand 1 will be unused).
467 LLVM_ABI llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
468                                                     unsigned NumElts);
469 
470 /// Concatenate a list of vectors.
471 ///
/// This function generates code that concatenates the vectors in \p Vecs into a
473 /// single large vector. The number of vectors should be greater than one, and
474 /// their element types should be the same. The number of elements in the
475 /// vectors should also be the same; however, if the last vector has fewer
476 /// elements, it will be padded with undefs.
477 LLVM_ABI Value *concatenateVectors(IRBuilderBase &Builder,
478                                    ArrayRef<Value *> Vecs);
479 
480 /// Given a mask vector of i1, Return true if all of the elements of this
481 /// predicate mask are known to be false or undef.  That is, return true if all
482 /// lanes can be assumed inactive.
483 LLVM_ABI bool maskIsAllZeroOrUndef(Value *Mask);
484 
485 /// Given a mask vector of i1, Return true if all of the elements of this
486 /// predicate mask are known to be true or undef.  That is, return true if all
487 /// lanes can be assumed active.
488 LLVM_ABI bool maskIsAllOneOrUndef(Value *Mask);
489 
490 /// Given a mask vector of i1, Return true if any of the elements of this
491 /// predicate mask are known to be true or undef.  That is, return true if at
492 /// least one lane can be assumed active.
493 LLVM_ABI bool maskContainsAllOneOrUndef(Value *Mask);
494 
495 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
496 /// for each lane which may be active.
497 LLVM_ABI APInt possiblyDemandedEltsInMask(Value *Mask);
498 
499 /// The group of interleaved loads/stores sharing the same stride and
500 /// close to each other.
501 ///
502 /// Each member in this group has an index starting from 0, and the largest
503 /// index should be less than interleaved factor, which is equal to the absolute
504 /// value of the access's stride.
505 ///
506 /// E.g. An interleaved load group of factor 4:
507 ///        for (unsigned i = 0; i < 1024; i+=4) {
508 ///          a = A[i];                           // Member of index 0
509 ///          b = A[i+1];                         // Member of index 1
510 ///          d = A[i+3];                         // Member of index 3
511 ///          ...
512 ///        }
513 ///
514 ///      An interleaved store group of factor 4:
515 ///        for (unsigned i = 0; i < 1024; i+=4) {
516 ///          ...
517 ///          A[i]   = a;                         // Member of index 0
518 ///          A[i+1] = b;                         // Member of index 1
519 ///          A[i+2] = c;                         // Member of index 2
520 ///          A[i+3] = d;                         // Member of index 3
521 ///        }
522 ///
523 /// Note: the interleaved load group could have gaps (missing members), but
524 /// the interleaved store group doesn't allow gaps.
525 template <typename InstTy> class InterleaveGroup {
526 public:
InterleaveGroup(uint32_t Factor,bool Reverse,Align Alignment)527   InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
528       : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
529         InsertPos(nullptr) {}
530 
InterleaveGroup(InstTy * Instr,int32_t Stride,Align Alignment)531   InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
532       : Alignment(Alignment), InsertPos(Instr) {
533     Factor = std::abs(Stride);
534     assert(Factor > 1 && "Invalid interleave factor");
535 
536     Reverse = Stride < 0;
537     Members[0] = Instr;
538   }
539 
isReverse()540   bool isReverse() const { return Reverse; }
getFactor()541   uint32_t getFactor() const { return Factor; }
getAlign()542   Align getAlign() const { return Alignment; }
getNumMembers()543   uint32_t getNumMembers() const { return Members.size(); }
544 
545   /// Try to insert a new member \p Instr with index \p Index and
546   /// alignment \p NewAlign. The index is related to the leader and it could be
547   /// negative if it is the new leader.
548   ///
549   /// \returns false if the instruction doesn't belong to the group.
insertMember(InstTy * Instr,int32_t Index,Align NewAlign)550   bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
551     // Make sure the key fits in an int32_t.
552     std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
553     if (!MaybeKey)
554       return false;
555     int32_t Key = *MaybeKey;
556 
557     // Skip if the key is used for either the tombstone or empty special values.
558     if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
559         DenseMapInfo<int32_t>::getEmptyKey() == Key)
560       return false;
561 
562     // Skip if there is already a member with the same index.
563     if (Members.contains(Key))
564       return false;
565 
566     if (Key > LargestKey) {
567       // The largest index is always less than the interleave factor.
568       if (Index >= static_cast<int32_t>(Factor))
569         return false;
570 
571       LargestKey = Key;
572     } else if (Key < SmallestKey) {
573 
574       // Make sure the largest index fits in an int32_t.
575       std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
576       if (!MaybeLargestIndex)
577         return false;
578 
579       // The largest index is always less than the interleave factor.
580       if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
581         return false;
582 
583       SmallestKey = Key;
584     }
585 
586     // It's always safe to select the minimum alignment.
587     Alignment = std::min(Alignment, NewAlign);
588     Members[Key] = Instr;
589     return true;
590   }
591 
592   /// Get the member with the given index \p Index
593   ///
594   /// \returns nullptr if contains no such member.
getMember(uint32_t Index)595   InstTy *getMember(uint32_t Index) const {
596     int32_t Key = SmallestKey + Index;
597     return Members.lookup(Key);
598   }
599 
600   /// Get the index for the given member. Unlike the key in the member
601   /// map, the index starts from 0.
getIndex(const InstTy * Instr)602   uint32_t getIndex(const InstTy *Instr) const {
603     for (auto I : Members) {
604       if (I.second == Instr)
605         return I.first - SmallestKey;
606     }
607 
608     llvm_unreachable("InterleaveGroup contains no such member");
609   }
610 
getInsertPos()611   InstTy *getInsertPos() const { return InsertPos; }
setInsertPos(InstTy * Inst)612   void setInsertPos(InstTy *Inst) { InsertPos = Inst; }
613 
614   /// Add metadata (e.g. alias info) from the instructions in this group to \p
615   /// NewInst.
616   ///
617   /// FIXME: this function currently does not add noalias metadata a'la
618   /// addNewMedata.  To do that we need to compute the intersection of the
619   /// noalias info from all members.
620   void addMetadata(InstTy *NewInst) const;
621 
622   /// Returns true if this Group requires a scalar iteration to handle gaps.
requiresScalarEpilogue()623   bool requiresScalarEpilogue() const {
624     // If the last member of the Group exists, then a scalar epilog is not
625     // needed for this group.
626     if (getMember(getFactor() - 1))
627       return false;
628 
629     // We have a group with gaps. It therefore can't be a reversed access,
630     // because such groups get invalidated (TODO).
631     assert(!isReverse() && "Group should have been invalidated");
632 
633     // This is a group of loads, with gaps, and without a last-member
634     return true;
635   }
636 
637 private:
638   uint32_t Factor; // Interleave Factor.
639   bool Reverse;
640   Align Alignment;
641   DenseMap<int32_t, InstTy *> Members;
642   int32_t SmallestKey = 0;
643   int32_t LargestKey = 0;
644 
645   // To avoid breaking dependences, vectorized instructions of an interleave
646   // group should be inserted at either the first load or the last store in
647   // program order.
648   //
649   // E.g. %even = load i32             // Insert Position
650   //      %add = add i32 %even         // Use of %even
651   //      %odd = load i32
652   //
653   //      store i32 %even
654   //      %odd = add i32               // Def of %odd
655   //      store i32 %odd               // Insert Position
656   InstTy *InsertPos;
657 };
658 
659 /// Drive the analysis of interleaved memory accesses in the loop.
660 ///
661 /// Use this class to analyze interleaved accesses only when we can vectorize
662 /// a loop. Otherwise it's meaningless to do analysis as the vectorization
663 /// on interleaved accesses is unsafe.
664 ///
665 /// The analysis collects interleave groups and records the relationships
666 /// between the member and the group in a map.
667 class InterleavedAccessInfo {
668 public:
InterleavedAccessInfo(PredicatedScalarEvolution & PSE,Loop * L,DominatorTree * DT,LoopInfo * LI,const LoopAccessInfo * LAI)669   InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
670                         DominatorTree *DT, LoopInfo *LI,
671                         const LoopAccessInfo *LAI)
672       : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}
673 
~InterleavedAccessInfo()674   ~InterleavedAccessInfo() { invalidateGroups(); }
675 
676   /// Analyze the interleaved accesses and collect them in interleave
677   /// groups. Substitute symbolic strides using \p Strides.
678   /// Consider also predicated loads/stores in the analysis if
679   /// \p EnableMaskedInterleavedGroup is true.
680   LLVM_ABI void analyzeInterleaving(bool EnableMaskedInterleavedGroup);
681 
682   /// Invalidate groups, e.g., in case all blocks in loop will be predicated
683   /// contrary to original assumption. Although we currently prevent group
684   /// formation for predicated accesses, we may be able to relax this limitation
685   /// in the future once we handle more complicated blocks. Returns true if any
686   /// groups were invalidated.
invalidateGroups()687   bool invalidateGroups() {
688     if (InterleaveGroups.empty()) {
689       assert(
690           !RequiresScalarEpilogue &&
691           "RequiresScalarEpilog should not be set without interleave groups");
692       return false;
693     }
694 
695     InterleaveGroupMap.clear();
696     for (auto *Ptr : InterleaveGroups)
697       delete Ptr;
698     InterleaveGroups.clear();
699     RequiresScalarEpilogue = false;
700     return true;
701   }
702 
703   /// Check if \p Instr belongs to any interleave group.
isInterleaved(Instruction * Instr)704   bool isInterleaved(Instruction *Instr) const {
705     return InterleaveGroupMap.contains(Instr);
706   }
707 
708   /// Get the interleave group that \p Instr belongs to.
709   ///
710   /// \returns nullptr if doesn't have such group.
711   InterleaveGroup<Instruction> *
getInterleaveGroup(const Instruction * Instr)712   getInterleaveGroup(const Instruction *Instr) const {
713     return InterleaveGroupMap.lookup(Instr);
714   }
715 
716   iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
getInterleaveGroups()717   getInterleaveGroups() {
718     return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
719   }
720 
721   /// Returns true if an interleaved group that may access memory
722   /// out-of-bounds requires a scalar epilogue iteration for correctness.
requiresScalarEpilogue()723   bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }
724 
725   /// Invalidate groups that require a scalar epilogue (due to gaps). This can
726   /// happen when optimizing for size forbids a scalar epilogue, and the gap
727   /// cannot be filtered by masking the load/store.
728   LLVM_ABI void invalidateGroupsRequiringScalarEpilogue();
729 
730   /// Returns true if we have any interleave groups.
hasGroups()731   bool hasGroups() const { return !InterleaveGroups.empty(); }
732 
733 private:
734   /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
735   /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
736   /// The interleaved access analysis can also add new predicates (for example
737   /// by versioning strides of pointers).
738   PredicatedScalarEvolution &PSE;
739 
740   Loop *TheLoop;
741   DominatorTree *DT;
742   LoopInfo *LI;
743   const LoopAccessInfo *LAI;
744 
745   /// True if the loop may contain non-reversed interleaved groups with
746   /// out-of-bounds accesses. We ensure we don't speculatively access memory
747   /// out-of-bounds by executing at least one scalar epilogue iteration.
748   bool RequiresScalarEpilogue = false;
749 
750   /// Holds the relationships between the members and the interleave group.
751   DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;
752 
753   SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;
754 
755   /// Holds dependences among the memory accesses in the loop. It maps a source
756   /// access to a set of dependent sink accesses.
757   DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;
758 
759   /// The descriptor for a strided memory access.
760   struct StrideDescriptor {
761     StrideDescriptor() = default;
StrideDescriptorStrideDescriptor762     StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
763                      Align Alignment)
764         : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}
765 
766     // The access's stride. It is negative for a reverse access.
767     int64_t Stride = 0;
768 
769     // The scalar expression of this access.
770     const SCEV *Scev = nullptr;
771 
772     // The size of the memory object.
773     uint64_t Size = 0;
774 
775     // The alignment of this access.
776     Align Alignment;
777   };
778 
779   /// A type for holding instructions and their stride descriptors.
780   using StrideEntry = std::pair<Instruction *, StrideDescriptor>;
781 
782   /// Create a new interleave group with the given instruction \p Instr,
783   /// stride \p Stride and alignment \p Align.
784   ///
785   /// \returns the newly created interleave group.
786   InterleaveGroup<Instruction> *
createInterleaveGroup(Instruction * Instr,int Stride,Align Alignment)787   createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
788     auto [It, Inserted] = InterleaveGroupMap.try_emplace(Instr);
789     assert(Inserted && "Already in an interleaved access group");
790     It->second = new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
791     InterleaveGroups.insert(It->second);
792     return It->second;
793   }
794 
795   /// Release the group and remove all the relationships.
releaseGroup(InterleaveGroup<Instruction> * Group)796   void releaseGroup(InterleaveGroup<Instruction> *Group) {
797     InterleaveGroups.erase(Group);
798     releaseGroupWithoutRemovingFromSet(Group);
799   }
800 
801   /// Do everything necessary to release the group, apart from removing it from
802   /// the InterleaveGroups set.
releaseGroupWithoutRemovingFromSet(InterleaveGroup<Instruction> * Group)803   void releaseGroupWithoutRemovingFromSet(InterleaveGroup<Instruction> *Group) {
804     for (unsigned i = 0; i < Group->getFactor(); i++)
805       if (Instruction *Member = Group->getMember(i))
806         InterleaveGroupMap.erase(Member);
807 
808     delete Group;
809   }
810 
811   /// Collect all the accesses with a constant stride in program order.
812   void collectConstStrideAccesses(
813       MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
814       const DenseMap<Value *, const SCEV *> &Strides);
815 
816   /// Returns true if \p Stride is allowed in an interleaved group.
817   LLVM_ABI static bool isStrided(int Stride);
818 
819   /// Returns true if \p BB is a predicated block.
isPredicated(BasicBlock * BB)820   bool isPredicated(BasicBlock *BB) const {
821     return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
822   }
823 
824   /// Returns true if LoopAccessInfo can be used for dependence queries.
areDependencesValid()825   bool areDependencesValid() const {
826     return LAI && LAI->getDepChecker().getDependences();
827   }
828 
829   /// Returns true if memory accesses \p A and \p B can be reordered, if
830   /// necessary, when constructing interleaved groups.
831   ///
832   /// \p A must precede \p B in program order. We return false if reordering is
833   /// not necessary or is prevented because \p A and \p B may be dependent.
canReorderMemAccessesForInterleavedGroups(StrideEntry * A,StrideEntry * B)834   bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
835                                                  StrideEntry *B) const {
836     // Code motion for interleaved accesses can potentially hoist strided loads
837     // and sink strided stores. The code below checks the legality of the
838     // following two conditions:
839     //
840     // 1. Potentially moving a strided load (B) before any store (A) that
841     //    precedes B, or
842     //
843     // 2. Potentially moving a strided store (A) after any load or store (B)
844     //    that A precedes.
845     //
846     // It's legal to reorder A and B if we know there isn't a dependence from A
847     // to B. Note that this determination is conservative since some
848     // dependences could potentially be reordered safely.
849 
850     // A is potentially the source of a dependence.
851     auto *Src = A->first;
852     auto SrcDes = A->second;
853 
854     // B is potentially the sink of a dependence.
855     auto *Sink = B->first;
856     auto SinkDes = B->second;
857 
858     // Code motion for interleaved accesses can't violate WAR dependences.
859     // Thus, reordering is legal if the source isn't a write.
860     if (!Src->mayWriteToMemory())
861       return true;
862 
863     // At least one of the accesses must be strided.
864     if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
865       return true;
866 
867     // If dependence information is not available from LoopAccessInfo,
868     // conservatively assume the instructions can't be reordered.
869     if (!areDependencesValid())
870       return false;
871 
872     // If we know there is a dependence from source to sink, assume the
873     // instructions can't be reordered. Otherwise, reordering is legal.
874     return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
875   }
876 
877   /// Collect the dependences from LoopAccessInfo.
878   ///
879   /// We process the dependences once during the interleaved access analysis to
880   /// enable constant-time dependence queries.
collectDependences()881   void collectDependences() {
882     if (!areDependencesValid())
883       return;
884     const auto &DepChecker = LAI->getDepChecker();
885     auto *Deps = DepChecker.getDependences();
886     for (auto Dep : *Deps)
887       Dependences[Dep.getSource(DepChecker)].insert(
888           Dep.getDestination(DepChecker));
889   }
890 };
891 
892 } // llvm namespace
893 
894 #endif
895