1 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines some vectorizer utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_ANALYSIS_VECTORUTILS_H 14 #define LLVM_ANALYSIS_VECTORUTILS_H 15 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/SmallVector.h" 18 #include "llvm/Analysis/LoopAccessAnalysis.h" 19 #include "llvm/IR/Module.h" 20 #include "llvm/IR/VFABIDemangler.h" 21 #include "llvm/IR/VectorTypeUtils.h" 22 #include "llvm/Support/CheckedArithmetic.h" 23 #include "llvm/Support/Compiler.h" 24 25 namespace llvm { 26 class TargetLibraryInfo; 27 28 /// The Vector Function Database. 29 /// 30 /// Helper class used to find the vector functions associated to a 31 /// scalar CallInst. 32 class VFDatabase { 33 /// The Module of the CallInst CI. 34 const Module *M; 35 /// The CallInst instance being queried for scalar to vector mappings. 36 const CallInst &CI; 37 /// List of vector functions descriptors associated to the call 38 /// instruction. 39 const SmallVector<VFInfo, 8> ScalarToVectorMappings; 40 41 /// Retrieve the scalar-to-vector mappings associated to the rule of 42 /// a vector Function ABI. getVFABIMappings(const CallInst & CI,SmallVectorImpl<VFInfo> & Mappings)43 static void getVFABIMappings(const CallInst &CI, 44 SmallVectorImpl<VFInfo> &Mappings) { 45 if (!CI.getCalledFunction()) 46 return; 47 48 const StringRef ScalarName = CI.getCalledFunction()->getName(); 49 50 SmallVector<std::string, 8> ListOfStrings; 51 // The check for the vector-function-abi-variant attribute is done when 52 // retrieving the vector variant names here. 
53 VFABI::getVectorVariantNames(CI, ListOfStrings); 54 if (ListOfStrings.empty()) 55 return; 56 for (const auto &MangledName : ListOfStrings) { 57 const std::optional<VFInfo> Shape = 58 VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType()); 59 // A match is found via scalar and vector names, and also by 60 // ensuring that the variant described in the attribute has a 61 // corresponding definition or declaration of the vector 62 // function in the Module M. 63 if (Shape && (Shape->ScalarName == ScalarName)) { 64 assert(CI.getModule()->getFunction(Shape->VectorName) && 65 "Vector function is missing."); 66 Mappings.push_back(*Shape); 67 } 68 } 69 } 70 71 public: 72 /// Retrieve all the VFInfo instances associated to the CallInst CI. getMappings(const CallInst & CI)73 static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) { 74 SmallVector<VFInfo, 8> Ret; 75 76 // Get mappings from the Vector Function ABI variants. 77 getVFABIMappings(CI, Ret); 78 79 // Other non-VFABI variants should be retrieved here. 80 81 return Ret; 82 } 83 84 static bool hasMaskedVariant(const CallInst &CI, 85 std::optional<ElementCount> VF = std::nullopt) { 86 // Check whether we have at least one masked vector version of a scalar 87 // function. If no VF is specified then we check for any masked variant, 88 // otherwise we look for one that matches the supplied VF. 89 auto Mappings = VFDatabase::getMappings(CI); 90 for (VFInfo Info : Mappings) 91 if (!VF || Info.Shape.VF == *VF) 92 if (Info.isMasked()) 93 return true; 94 95 return false; 96 } 97 98 /// Constructor, requires a CallInst instance. VFDatabase(CallInst & CI)99 VFDatabase(CallInst &CI) 100 : M(CI.getModule()), CI(CI), 101 ScalarToVectorMappings(VFDatabase::getMappings(CI)) {} 102 103 /// \defgroup VFDatabase query interface. 104 /// 105 /// @{ 106 /// Retrieve the Function with VFShape \p Shape. 
getVectorizedFunction(const VFShape & Shape)107 Function *getVectorizedFunction(const VFShape &Shape) const { 108 if (Shape == VFShape::getScalarShape(CI.getFunctionType())) 109 return CI.getCalledFunction(); 110 111 for (const auto &Info : ScalarToVectorMappings) 112 if (Info.Shape == Shape) 113 return M->getFunction(Info.VectorName); 114 115 return nullptr; 116 } 117 /// @} 118 }; 119 120 template <typename T> class ArrayRef; 121 class DemandedBits; 122 template <typename InstTy> class InterleaveGroup; 123 class IRBuilderBase; 124 class Loop; 125 class TargetTransformInfo; 126 class Value; 127 128 namespace Intrinsic { 129 typedef unsigned ID; 130 } 131 132 /// Identify if the intrinsic is trivially vectorizable. 133 /// This method returns true if the intrinsic's argument types are all scalars 134 /// for the scalar form of the intrinsic and all vectors (or scalars handled by 135 /// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic. 136 /// 137 /// Note: isTriviallyVectorizable implies isTriviallyScalarizable. 138 LLVM_ABI bool isTriviallyVectorizable(Intrinsic::ID ID); 139 140 /// Identify if the intrinsic is trivially scalarizable. 141 /// This method returns true following the same predicates of 142 /// isTriviallyVectorizable. 143 144 /// Note: There are intrinsics where implementing vectorization for the 145 /// intrinsic is redundant, but we want to implement scalarization of the 146 /// vector. To prevent the requirement that an intrinsic also implements 147 /// vectorization we provide this seperate function. 148 LLVM_ABI bool isTriviallyScalarizable(Intrinsic::ID ID, 149 const TargetTransformInfo *TTI); 150 151 /// Identifies if the vector form of the intrinsic has a scalar operand. 152 /// \p TTI is used to consider target specific intrinsics, if no target specific 153 /// intrinsics will be considered then it is appropriate to pass in nullptr. 
LLVM_ABI bool
isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
                                   const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic is overloaded on the type of
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
/// \p TTI is used to consider target specific intrinsics, if no target specific
/// intrinsics will be considered then it is appropriate to pass in nullptr.
LLVM_ABI bool
isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
                                       const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic that returns a struct is
/// overloaded at the struct element index \p RetIdx. \p TTI is used to
/// consider target specific intrinsics, if no target specific intrinsics
/// will be considered then it is appropriate to pass in nullptr.
LLVM_ABI bool isVectorIntrinsicWithStructReturnOverloadAtField(
    Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);

/// Returns the intrinsic ID for a call.
/// For the input call instruction, finds the matching intrinsic and returns
/// its intrinsic ID. If no such intrinsic is found, returns not_intrinsic.
LLVM_ABI Intrinsic::ID
getVectorIntrinsicIDForCall(const CallInst *CI, const TargetLibraryInfo *TLI);

/// Returns the corresponding llvm.vector.interleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getInterleaveIntrinsicID(unsigned Factor);

/// Returns the corresponding llvm.vector.deinterleaveN intrinsic for factor N.
LLVM_ABI Intrinsic::ID getDeinterleaveIntrinsicID(unsigned Factor);

/// Returns the corresponding factor of llvm.vector.interleaveN intrinsics.
LLVM_ABI unsigned getInterleaveIntrinsicFactor(Intrinsic::ID ID);

/// Returns the corresponding factor of llvm.vector.deinterleaveN intrinsics.
LLVM_ABI unsigned getDeinterleaveIntrinsicFactor(Intrinsic::ID ID);

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then extracted
/// from the vector.
LLVM_ABI Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
LLVM_ABI int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
LLVM_ABI Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If an index element is specified, either
/// every element of the vector is poisoned or the element at that index is not
/// poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
LLVM_ABI bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);

/// Transform a shuffle mask's output demanded element mask into demanded
/// element masks for the 2 operands, returns false if the mask isn't valid.
/// Both \p DemandedLHS and \p DemandedRHS are initialised to \p SrcWidth bits.
/// \p AllowUndefElts permits "-1" indices to be treated as undef.
LLVM_ABI bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
                                     const APInt &DemandedElts,
                                     APInt &DemandedLHS, APInt &DemandedRHS,
                                     bool AllowUndefElts = false);

/// Does this shuffle mask represent either one slide shuffle or a pair of
/// two slide shuffles, combined with a select on some constant vector mask?
/// A slide is a shuffle mask which shifts some set of elements up or down
/// the vector, with all other elements being undefined. An identity shuffle
/// will be matched as a slide by 0. The output parameter provides the source
/// (-1 means no source), and slide direction for each slide.
LLVM_ABI bool isMaskedSlidePair(ArrayRef<int> Mask, int NumElts,
                                std::array<std::pair<int, int>, 2> &SrcInfo);

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements.
LLVM_ABI void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                                    SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
LLVM_ABI bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                                   SmallVectorImpl<int> &ScaledMask);

/// A variant of the previous method which is specialized for Scale=2, and
/// treats -1 as undef and allows widening when a wider element is partially
/// undef in the narrow form of the mask. This transformation discards
/// information about which bytes in the original shuffle were undef.
LLVM_ABI bool widenShuffleMaskElts(ArrayRef<int> M,
                                   SmallVectorImpl<int> &NewMask);

/// Attempt to narrow/widen the \p Mask shuffle mask to the \p NumDstElts target
/// width. Internally this will call narrowShuffleMaskElts/widenShuffleMaskElts.
/// This will assert unless NumDstElts is a multiple of Mask.size (or
/// vice-versa). Returns false on failure, and ScaledMask will be in an
/// undefined state.
LLVM_ABI bool scaleShuffleMaskElts(unsigned NumDstElts, ArrayRef<int> Mask,
                                   SmallVectorImpl<int> &ScaledMask);

/// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
/// to get the shuffle mask with widest possible elements.
LLVM_ABI void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
                                           SmallVectorImpl<int> &ScaledMask);

/// Splits and processes shuffle mask depending on the number of input and
/// output registers. The function does 2 main things: 1) splits the
/// source/destination vectors into real registers; 2) do the mask analysis to
/// identify which real registers are permuted. Then the function processes
/// resulting registers mask using provided action items. If no input register
/// is defined, \p NoInputAction action is used. If only 1 input register is
/// used, \p SingleInputAction is used, otherwise \p ManyInputsAction is used to
/// process two or more input registers and masks.
/// \param Mask Original shuffle mask.
/// \param NumOfSrcRegs Number of source registers.
/// \param NumOfDestRegs Number of destination registers.
/// \param NumOfUsedRegs Number of actually used destination registers.
/// \param NoInputAction Action used when no input register is defined.
/// \param SingleInputAction Action used when exactly one input register is
///        used.
/// \param ManyInputsAction Action used when two or more input registers are
///        used.
LLVM_ABI void processShuffleMasks(
    ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
    unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned, bool)>
        ManyInputsAction);

/// Compute the demanded elements mask of horizontal binary operations. A
/// horizontal operation combines two adjacent elements in a vector operand.
/// This function returns a mask for the elements that correspond to the first
/// operand of this horizontal combination. For example, for two vectors
/// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
/// elements X1, X3, Y1, and Y3. To get the other operands, simply shift the
/// result of this function to the left by 1.
///
/// \param VectorBitWidth the total bit width of the vector
/// \param DemandedElts   the demanded elements mask for the operation
/// \param DemandedLHS    the demanded elements mask for the left operand
/// \param DemandedRHS    the demanded elements mask for the right operand
LLVM_ABI void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
                                                  const APInt &DemandedElts,
                                                  APInt &DemandedLHS,
                                                  APInt &DemandedRHS);

/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these lengthens and truncates when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
LLVM_ABI MapVector<Instruction *, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                         const TargetTransformInfo *TTI = nullptr);

/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
LLVM_ABI MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
LLVM_ABI MDNode *intersectAccessGroups(const Instruction *Inst1,
                                       const Instruction *Inst2);

/// Add metadata from \p Inst to \p Metadata, if it can be preserved after
/// vectorization. It can be preserved after vectorization if the kind is one of
/// [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath, MD_nontemporal,
/// MD_access_group, MD_mmra].
LLVM_ABI void getMetadataToPropagate(
    Instruction *Inst,
    SmallVectorImpl<std::pair<unsigned, MDNode *>> &Metadata);

/// Set the metadata of \p I from the common metadata of the values in \p VL.
///
/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group, MD_mmra].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
LLVM_ABI Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);

/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
LLVM_ABI Constant *
createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                     const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose Interleaved-factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
LLVM_ABI llvm::SmallVector<int, 16>
createReplicatedMask(unsigned ReplicationFactor, unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
LLVM_ABI llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF,
                                                         unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
LLVM_ABI llvm::SmallVector<int, 16>
createStrideMask(unsigned Start, unsigned Stride, unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates shuffle mask whose elements are sequential and begin
/// at \p Start. The mask contains \p NumInts integers and is padded with \p
/// NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
LLVM_ABI llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

/// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
/// mask assuming both operands are identical. This assumes that the unary
/// shuffle will use elements from operand 0 (operand 1 will be unused).
LLVM_ABI llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
                                                    unsigned NumElts);

/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into
/// a single large vector. The number of vectors should be greater than one, and
/// their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
LLVM_ABI Value *concatenateVectors(IRBuilderBase &Builder,
                                   ArrayRef<Value *> Vecs);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef. That is, return true if all
/// lanes can be assumed inactive.
LLVM_ABI bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef. That is, return true if all
/// lanes can be assumed active.
LLVM_ABI bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if any of the elements of this
/// predicate mask are known to be true or undef. That is, return true if at
/// least one lane can be assumed active.
LLVM_ABI bool maskContainsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// with a bit set for each lane which may be active.
LLVM_ABI APInt possiblyDemandedEltsInMask(Value *Mask);

/// The group of interleaved loads/stores sharing the same stride and
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than interleaved factor, which is equal to the absolute
/// value of the access's stride.
///
/// E.g. An interleaved load group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          a = A[i];                           // Member of index 0
///          b = A[i+1];                         // Member of index 1
///          d = A[i+3];                         // Member of index 3
///          ...
///        }
///
///      An interleaved store group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          ...
///          A[i]   = a;                         // Member of index 0
///          A[i+1] = b;                         // Member of index 1
///          A[i+2] = c;                         // Member of index 2
///          A[i+3] = d;                         // Member of index 3
///        }
///
/// Note: the interleaved load group could have gaps (missing members), but
/// the interleaved store group doesn't allow gaps.
525 template <typename InstTy> class InterleaveGroup { 526 public: InterleaveGroup(uint32_t Factor,bool Reverse,Align Alignment)527 InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment) 528 : Factor(Factor), Reverse(Reverse), Alignment(Alignment), 529 InsertPos(nullptr) {} 530 InterleaveGroup(InstTy * Instr,int32_t Stride,Align Alignment)531 InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment) 532 : Alignment(Alignment), InsertPos(Instr) { 533 Factor = std::abs(Stride); 534 assert(Factor > 1 && "Invalid interleave factor"); 535 536 Reverse = Stride < 0; 537 Members[0] = Instr; 538 } 539 isReverse()540 bool isReverse() const { return Reverse; } getFactor()541 uint32_t getFactor() const { return Factor; } getAlign()542 Align getAlign() const { return Alignment; } getNumMembers()543 uint32_t getNumMembers() const { return Members.size(); } 544 545 /// Try to insert a new member \p Instr with index \p Index and 546 /// alignment \p NewAlign. The index is related to the leader and it could be 547 /// negative if it is the new leader. 548 /// 549 /// \returns false if the instruction doesn't belong to the group. insertMember(InstTy * Instr,int32_t Index,Align NewAlign)550 bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) { 551 // Make sure the key fits in an int32_t. 552 std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey); 553 if (!MaybeKey) 554 return false; 555 int32_t Key = *MaybeKey; 556 557 // Skip if the key is used for either the tombstone or empty special values. 558 if (DenseMapInfo<int32_t>::getTombstoneKey() == Key || 559 DenseMapInfo<int32_t>::getEmptyKey() == Key) 560 return false; 561 562 // Skip if there is already a member with the same index. 563 if (Members.contains(Key)) 564 return false; 565 566 if (Key > LargestKey) { 567 // The largest index is always less than the interleave factor. 
568 if (Index >= static_cast<int32_t>(Factor)) 569 return false; 570 571 LargestKey = Key; 572 } else if (Key < SmallestKey) { 573 574 // Make sure the largest index fits in an int32_t. 575 std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key); 576 if (!MaybeLargestIndex) 577 return false; 578 579 // The largest index is always less than the interleave factor. 580 if (*MaybeLargestIndex >= static_cast<int64_t>(Factor)) 581 return false; 582 583 SmallestKey = Key; 584 } 585 586 // It's always safe to select the minimum alignment. 587 Alignment = std::min(Alignment, NewAlign); 588 Members[Key] = Instr; 589 return true; 590 } 591 592 /// Get the member with the given index \p Index 593 /// 594 /// \returns nullptr if contains no such member. getMember(uint32_t Index)595 InstTy *getMember(uint32_t Index) const { 596 int32_t Key = SmallestKey + Index; 597 return Members.lookup(Key); 598 } 599 600 /// Get the index for the given member. Unlike the key in the member 601 /// map, the index starts from 0. getIndex(const InstTy * Instr)602 uint32_t getIndex(const InstTy *Instr) const { 603 for (auto I : Members) { 604 if (I.second == Instr) 605 return I.first - SmallestKey; 606 } 607 608 llvm_unreachable("InterleaveGroup contains no such member"); 609 } 610 getInsertPos()611 InstTy *getInsertPos() const { return InsertPos; } setInsertPos(InstTy * Inst)612 void setInsertPos(InstTy *Inst) { InsertPos = Inst; } 613 614 /// Add metadata (e.g. alias info) from the instructions in this group to \p 615 /// NewInst. 616 /// 617 /// FIXME: this function currently does not add noalias metadata a'la 618 /// addNewMedata. To do that we need to compute the intersection of the 619 /// noalias info from all members. 620 void addMetadata(InstTy *NewInst) const; 621 622 /// Returns true if this Group requires a scalar iteration to handle gaps. 
requiresScalarEpilogue()623 bool requiresScalarEpilogue() const { 624 // If the last member of the Group exists, then a scalar epilog is not 625 // needed for this group. 626 if (getMember(getFactor() - 1)) 627 return false; 628 629 // We have a group with gaps. It therefore can't be a reversed access, 630 // because such groups get invalidated (TODO). 631 assert(!isReverse() && "Group should have been invalidated"); 632 633 // This is a group of loads, with gaps, and without a last-member 634 return true; 635 } 636 637 private: 638 uint32_t Factor; // Interleave Factor. 639 bool Reverse; 640 Align Alignment; 641 DenseMap<int32_t, InstTy *> Members; 642 int32_t SmallestKey = 0; 643 int32_t LargestKey = 0; 644 645 // To avoid breaking dependences, vectorized instructions of an interleave 646 // group should be inserted at either the first load or the last store in 647 // program order. 648 // 649 // E.g. %even = load i32 // Insert Position 650 // %add = add i32 %even // Use of %even 651 // %odd = load i32 652 // 653 // store i32 %even 654 // %odd = add i32 // Def of %odd 655 // store i32 %odd // Insert Position 656 InstTy *InsertPos; 657 }; 658 659 /// Drive the analysis of interleaved memory accesses in the loop. 660 /// 661 /// Use this class to analyze interleaved accesses only when we can vectorize 662 /// a loop. Otherwise it's meaningless to do analysis as the vectorization 663 /// on interleaved accesses is unsafe. 664 /// 665 /// The analysis collects interleave groups and records the relationships 666 /// between the member and the group in a map. 
667 class InterleavedAccessInfo { 668 public: InterleavedAccessInfo(PredicatedScalarEvolution & PSE,Loop * L,DominatorTree * DT,LoopInfo * LI,const LoopAccessInfo * LAI)669 InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L, 670 DominatorTree *DT, LoopInfo *LI, 671 const LoopAccessInfo *LAI) 672 : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {} 673 ~InterleavedAccessInfo()674 ~InterleavedAccessInfo() { invalidateGroups(); } 675 676 /// Analyze the interleaved accesses and collect them in interleave 677 /// groups. Substitute symbolic strides using \p Strides. 678 /// Consider also predicated loads/stores in the analysis if 679 /// \p EnableMaskedInterleavedGroup is true. 680 LLVM_ABI void analyzeInterleaving(bool EnableMaskedInterleavedGroup); 681 682 /// Invalidate groups, e.g., in case all blocks in loop will be predicated 683 /// contrary to original assumption. Although we currently prevent group 684 /// formation for predicated accesses, we may be able to relax this limitation 685 /// in the future once we handle more complicated blocks. Returns true if any 686 /// groups were invalidated. invalidateGroups()687 bool invalidateGroups() { 688 if (InterleaveGroups.empty()) { 689 assert( 690 !RequiresScalarEpilogue && 691 "RequiresScalarEpilog should not be set without interleave groups"); 692 return false; 693 } 694 695 InterleaveGroupMap.clear(); 696 for (auto *Ptr : InterleaveGroups) 697 delete Ptr; 698 InterleaveGroups.clear(); 699 RequiresScalarEpilogue = false; 700 return true; 701 } 702 703 /// Check if \p Instr belongs to any interleave group. isInterleaved(Instruction * Instr)704 bool isInterleaved(Instruction *Instr) const { 705 return InterleaveGroupMap.contains(Instr); 706 } 707 708 /// Get the interleave group that \p Instr belongs to. 709 /// 710 /// \returns nullptr if doesn't have such group. 
711 InterleaveGroup<Instruction> * getInterleaveGroup(const Instruction * Instr)712 getInterleaveGroup(const Instruction *Instr) const { 713 return InterleaveGroupMap.lookup(Instr); 714 } 715 716 iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>> getInterleaveGroups()717 getInterleaveGroups() { 718 return make_range(InterleaveGroups.begin(), InterleaveGroups.end()); 719 } 720 721 /// Returns true if an interleaved group that may access memory 722 /// out-of-bounds requires a scalar epilogue iteration for correctness. requiresScalarEpilogue()723 bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; } 724 725 /// Invalidate groups that require a scalar epilogue (due to gaps). This can 726 /// happen when optimizing for size forbids a scalar epilogue, and the gap 727 /// cannot be filtered by masking the load/store. 728 LLVM_ABI void invalidateGroupsRequiringScalarEpilogue(); 729 730 /// Returns true if we have any interleave groups. hasGroups()731 bool hasGroups() const { return !InterleaveGroups.empty(); } 732 733 private: 734 /// A wrapper around ScalarEvolution, used to add runtime SCEV checks. 735 /// Simplifies SCEV expressions in the context of existing SCEV assumptions. 736 /// The interleaved access analysis can also add new predicates (for example 737 /// by versioning strides of pointers). 738 PredicatedScalarEvolution &PSE; 739 740 Loop *TheLoop; 741 DominatorTree *DT; 742 LoopInfo *LI; 743 const LoopAccessInfo *LAI; 744 745 /// True if the loop may contain non-reversed interleaved groups with 746 /// out-of-bounds accesses. We ensure we don't speculatively access memory 747 /// out-of-bounds by executing at least one scalar epilogue iteration. 748 bool RequiresScalarEpilogue = false; 749 750 /// Holds the relationships between the members and the interleave group. 
751 DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap; 752 753 SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups; 754 755 /// Holds dependences among the memory accesses in the loop. It maps a source 756 /// access to a set of dependent sink accesses. 757 DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences; 758 759 /// The descriptor for a strided memory access. 760 struct StrideDescriptor { 761 StrideDescriptor() = default; StrideDescriptorStrideDescriptor762 StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size, 763 Align Alignment) 764 : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {} 765 766 // The access's stride. It is negative for a reverse access. 767 int64_t Stride = 0; 768 769 // The scalar expression of this access. 770 const SCEV *Scev = nullptr; 771 772 // The size of the memory object. 773 uint64_t Size = 0; 774 775 // The alignment of this access. 776 Align Alignment; 777 }; 778 779 /// A type for holding instructions and their stride descriptors. 780 using StrideEntry = std::pair<Instruction *, StrideDescriptor>; 781 782 /// Create a new interleave group with the given instruction \p Instr, 783 /// stride \p Stride and alignment \p Align. 784 /// 785 /// \returns the newly created interleave group. 786 InterleaveGroup<Instruction> * createInterleaveGroup(Instruction * Instr,int Stride,Align Alignment)787 createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) { 788 auto [It, Inserted] = InterleaveGroupMap.try_emplace(Instr); 789 assert(Inserted && "Already in an interleaved access group"); 790 It->second = new InterleaveGroup<Instruction>(Instr, Stride, Alignment); 791 InterleaveGroups.insert(It->second); 792 return It->second; 793 } 794 795 /// Release the group and remove all the relationships. 
releaseGroup(InterleaveGroup<Instruction> * Group)796 void releaseGroup(InterleaveGroup<Instruction> *Group) { 797 InterleaveGroups.erase(Group); 798 releaseGroupWithoutRemovingFromSet(Group); 799 } 800 801 /// Do everything necessary to release the group, apart from removing it from 802 /// the InterleaveGroups set. releaseGroupWithoutRemovingFromSet(InterleaveGroup<Instruction> * Group)803 void releaseGroupWithoutRemovingFromSet(InterleaveGroup<Instruction> *Group) { 804 for (unsigned i = 0; i < Group->getFactor(); i++) 805 if (Instruction *Member = Group->getMember(i)) 806 InterleaveGroupMap.erase(Member); 807 808 delete Group; 809 } 810 811 /// Collect all the accesses with a constant stride in program order. 812 void collectConstStrideAccesses( 813 MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, 814 const DenseMap<Value *, const SCEV *> &Strides); 815 816 /// Returns true if \p Stride is allowed in an interleaved group. 817 LLVM_ABI static bool isStrided(int Stride); 818 819 /// Returns true if \p BB is a predicated block. isPredicated(BasicBlock * BB)820 bool isPredicated(BasicBlock *BB) const { 821 return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); 822 } 823 824 /// Returns true if LoopAccessInfo can be used for dependence queries. areDependencesValid()825 bool areDependencesValid() const { 826 return LAI && LAI->getDepChecker().getDependences(); 827 } 828 829 /// Returns true if memory accesses \p A and \p B can be reordered, if 830 /// necessary, when constructing interleaved groups. 831 /// 832 /// \p A must precede \p B in program order. We return false if reordering is 833 /// not necessary or is prevented because \p A and \p B may be dependent. canReorderMemAccessesForInterleavedGroups(StrideEntry * A,StrideEntry * B)834 bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A, 835 StrideEntry *B) const { 836 // Code motion for interleaved accesses can potentially hoist strided loads 837 // and sink strided stores. 
The code below checks the legality of the 838 // following two conditions: 839 // 840 // 1. Potentially moving a strided load (B) before any store (A) that 841 // precedes B, or 842 // 843 // 2. Potentially moving a strided store (A) after any load or store (B) 844 // that A precedes. 845 // 846 // It's legal to reorder A and B if we know there isn't a dependence from A 847 // to B. Note that this determination is conservative since some 848 // dependences could potentially be reordered safely. 849 850 // A is potentially the source of a dependence. 851 auto *Src = A->first; 852 auto SrcDes = A->second; 853 854 // B is potentially the sink of a dependence. 855 auto *Sink = B->first; 856 auto SinkDes = B->second; 857 858 // Code motion for interleaved accesses can't violate WAR dependences. 859 // Thus, reordering is legal if the source isn't a write. 860 if (!Src->mayWriteToMemory()) 861 return true; 862 863 // At least one of the accesses must be strided. 864 if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride)) 865 return true; 866 867 // If dependence information is not available from LoopAccessInfo, 868 // conservatively assume the instructions can't be reordered. 869 if (!areDependencesValid()) 870 return false; 871 872 // If we know there is a dependence from source to sink, assume the 873 // instructions can't be reordered. Otherwise, reordering is legal. 874 return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink); 875 } 876 877 /// Collect the dependences from LoopAccessInfo. 878 /// 879 /// We process the dependences once during the interleaved access analysis to 880 /// enable constant-time dependence queries. 
collectDependences()881 void collectDependences() { 882 if (!areDependencesValid()) 883 return; 884 const auto &DepChecker = LAI->getDepChecker(); 885 auto *Deps = DepChecker.getDependences(); 886 for (auto Dep : *Deps) 887 Dependences[Dep.getSource(DepChecker)].insert( 888 Dep.getDestination(DepChecker)); 889 } 890 }; 891 892 } // llvm namespace 893 894 #endif 895