//===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for vector predication intrinsics, allowing
// targets to enable vector predication until just before codegen.
//
//===----------------------------------------------------------------------===//
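//
// A sketch of the overall effect (operand names and types assumed): a target
// that cannot handle vector predication at all may have
//
//   %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %a, <4 x i32> %b,
//                                          <4 x i1> %m, i32 %evl)
//
// rewritten to the unpredicated
//
//   %r = add <4 x i32> %a, %b
//
// because integer adds may be speculated on every lane; non-speculatable
// operations instead have %evl folded into the mask before they are expanded.
//
//===----------------------------------------------------------------------===//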

#include "llvm/CodeGen/ExpandVectorPredication.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include <optional>

using namespace llvm;

using VPLegalization = TargetTransformInfo::VPLegalization;
using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;

// Keep this in sync with TargetTransformInfo::VPLegalization.
#define VPINTERNAL_VPLEGAL_CASES                                               \
  VPINTERNAL_CASE(Legal)                                                       \
  VPINTERNAL_CASE(Discard)                                                     \
  VPINTERNAL_CASE(Convert)

#define VPINTERNAL_CASE(X) "|" #X

// Override options.
static cl::opt<std::string> EVLTransformOverride(
    "expandvp-override-evl-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %evl parameter (Used in "
             "testing)."));

static cl::opt<std::string> MaskTransformOverride(
    "expandvp-override-mask-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %mask parameter (Used in "
             "testing)."));

#undef VPINTERNAL_CASE
#define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)

static VPTransform parseOverrideOption(const std::string &TextOpt) {
  return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;
}

#undef VPINTERNAL_VPLEGAL_CASES

// Whether any override options are set.
static bool anyExpandVPOverridesSet() {
  return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
}

#define DEBUG_TYPE "expandvp"

STATISTIC(NumFoldedVL, "Number of folded vector length params");
STATISTIC(NumLoweredVPOps, "Number of lowered vector predication operations");

///// Helpers {

/// \returns Whether the vector mask \p MaskVal has all lane bits set.
static bool isAllTrueMask(Value *MaskVal) {
  if (Value *SplattedVal = getSplatValue(MaskVal))
    if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
      return ConstValue->isAllOnesValue();

  return false;
}

/// \returns A non-excepting divisor constant for this type.
static Constant *getSafeDivisor(Type *DivTy) {
  assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");
  return ConstantInt::get(DivTy, 1u, false);
}
99  
100  /// Transfer operation properties from \p OldVPI to \p NewVal.
101  static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
102    auto *NewInst = dyn_cast<Instruction>(&NewVal);
103    if (!NewInst || !isa<FPMathOperator>(NewVal))
104      return;
105  
106    auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
107    if (!OldFMOp)
108      return;
109  
110    NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
111  }
112  
113  /// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
114  /// OldVP gets erased.
115  static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
116    transferDecorations(NewOp, OldOp);
117    OldOp.replaceAllUsesWith(&NewOp);
118    OldOp.eraseFromParent();
119  }
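/// \returns Whether \p VPI may be executed speculatively on all lanes, i.e.
/// whether its %mask and %evl parameters can be ignored without introducing
/// undefined behavior. VP reductions never qualify because their result
/// depends on %mask and %evl.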
static bool maySpeculateLanes(VPIntrinsic &VPI) {
  // The result of VP reductions depends on the mask and evl.
  if (isa<VPReductionIntrinsic>(VPI))
    return false;
  // Fall back to whether the intrinsic is speculatable.
  if (auto IntrID = VPI.getFunctionalIntrinsicID())
    return Intrinsic::getAttributes(VPI.getContext(), *IntrID)
        .hasFnAttr(Attribute::AttrKind::Speculatable);
  if (auto Opc = VPI.getFunctionalOpcode())
    return isSafeToSpeculativelyExecuteWithOpcode(*Opc, &VPI);
  return false;
}

//// } Helpers

namespace {

// Expansion pass state at function scope.
struct CachingVPExpander {
  Function &F;
  const TargetTransformInfo &TTI;

  /// \returns A (fixed length) vector with ascending integer indices
  /// (<0, 1, ..., NumElems-1>).
  /// \p Builder
  ///    Used for instruction creation.
  /// \p LaneTy
  ///    Integer element type of the result vector.
  /// \p NumElems
  ///    Number of vector elements.
  Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                          unsigned NumElems);

  /// \returns A bitmask that is true where the lane position is less than \p
  /// EVLParam.
  ///
  /// \p Builder
  ///    Used for instruction creation.
  /// \p EVLParam
  ///    The explicit vector length parameter to test against the lane
  ///    positions.
  /// \p ElemCount
  ///    Static (potentially scalable) number of vector elements.
  Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
                          ElementCount ElemCount);

  Value *foldEVLIntoMask(VPIntrinsic &VPI);

  /// "Remove" the %evl parameter of \p PI by setting it to the static vector
  /// length of the operation.
  void discardEVLParameter(VPIntrinsic &PI);

  /// Lower this VP binary operator to an unpredicated binary operator.
  Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                           VPIntrinsic &PI);

  /// Lower this VP int call to an unpredicated int call.
  Value *expandPredicationToIntCall(IRBuilder<> &Builder, VPIntrinsic &PI,
                                    unsigned UnpredicatedIntrinsicID);

  /// Lower this VP fp call to an unpredicated fp call.
  Value *expandPredicationToFPCall(IRBuilder<> &Builder, VPIntrinsic &PI,
                                   unsigned UnpredicatedIntrinsicID);

  /// Lower this VP reduction to a call to an unpredicated reduction intrinsic.
  Value *expandPredicationInReduction(IRBuilder<> &Builder,
                                      VPReductionIntrinsic &PI);

  /// Lower this VP cast operation to an unpredicated cast instruction.
  Value *expandPredicationToCastIntrinsic(IRBuilder<> &Builder,
                                          VPIntrinsic &VPI);

  /// Lower this VP memory operation to a non-VP intrinsic.
  Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                            VPIntrinsic &VPI);

  /// Lower this VP comparison to an unpredicated comparison instruction.
  Value *expandPredicationInComparison(IRBuilder<> &Builder,
                                       VPCmpIntrinsic &PI);

  /// Query TTI and expand the vector predication in \p PI accordingly.
  Value *expandPredication(VPIntrinsic &PI);

  /// Determine how and whether the VPIntrinsic \p VPI shall be expanded. This
  /// overrides TTI with the cl::opts listed at the top of this file.
  VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
  bool UsingTTIOverrides;

public:
  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
      : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}

  bool expandVectorPredication();
};

//// CachingVPExpander {

Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                                           unsigned NumElems) {
  // TODO add caching
  SmallVector<Constant *, 16> ConstElems;

  for (unsigned Idx = 0; Idx < NumElems; ++Idx)
    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));

  return ConstantVector::get(ConstElems);
}

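// Convert %evl into an equivalent lane mask. A sketch of the fixed-width case
// (element count and %evl value assumed for illustration): with four lanes and
// %evl == 3 the emitted compare evaluates to
//
//   icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, (splat of %evl)
//     ==> <i1 true, i1 true, i1 true, i1 false>
//
// For scalable vectors the same predicate is produced with the
// llvm.get.active.lane.mask intrinsic, using a base of 0 and a limit of %evl.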
Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
                                           Value *EVLParam,
                                           ElementCount ElemCount) {
  // TODO add caching
  // Scalable vector %evl conversion.
  if (ElemCount.isScalable()) {
    auto *M = Builder.GetInsertBlock()->getModule();
    Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
    Function *ActiveMaskFunc = Intrinsic::getDeclaration(
        M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
    // `get_active_lane_mask` performs an implicit less-than comparison.
    Value *ConstZero = Builder.getInt32(0);
    return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
  }

  // Fixed vector %evl conversion.
  Type *LaneTy = EVLParam->getType();
  unsigned NumElems = ElemCount.getFixedValue();
  Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);
  Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);
  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
}

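// Expand a VP binary operator by dropping the predicate. Only the division and
// remainder opcodes need extra care: their masked-off lanes are first blended
// with a safe divisor of 1. A sketch for vp.udiv (operand names assumed):
//
//   %safe.divisor = select <4 x i1> %mask, <4 x i32> %b,
//                          <4 x i32> <i32 1, i32 1, i32 1, i32 1>
//   %result       = udiv <4 x i32> %a, %safe.divisor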
Value *
CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                                     VPIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  auto OC = static_cast<Instruction::BinaryOps>(*VPI.getFunctionalOpcode());
  assert(Instruction::isBinaryOp(OC));

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  Value *Mask = VPI.getMaskParam();

  // Blend in safe operands.
  if (Mask && !isAllTrueMask(Mask)) {
    switch (OC) {
    default:
      // Can safely ignore the predicate.
      break;

    // Division operators need a safe divisor on masked-off lanes (1).
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::URem:
    case Instruction::SRem:
      // 2nd operand must not be zero.
      Value *SafeDivisor = getSafeDivisor(VPI.getType());
      Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);
    }
  }

  Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());

  replaceOperation(*NewBinOp, VPI);
  return NewBinOp;
}

Value *CachingVPExpander::expandPredicationToIntCall(
    IRBuilder<> &Builder, VPIntrinsic &VPI, unsigned UnpredicatedIntrinsicID) {
  switch (UnpredicatedIntrinsicID) {
  case Intrinsic::abs:
  case Intrinsic::smax:
  case Intrinsic::smin:
  case Intrinsic::umax:
  case Intrinsic::umin: {
    Value *Op0 = VPI.getOperand(0);
    Value *Op1 = VPI.getOperand(1);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op0, Op1}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  case Intrinsic::bswap:
  case Intrinsic::bitreverse: {
    Value *Op = VPI.getOperand(0);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  }
  return nullptr;
}

Value *CachingVPExpander::expandPredicationToFPCall(
    IRBuilder<> &Builder, VPIntrinsic &VPI, unsigned UnpredicatedIntrinsicID) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  switch (UnpredicatedIntrinsicID) {
  case Intrinsic::fabs:
  case Intrinsic::sqrt: {
    Value *Op0 = VPI.getOperand(0);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op0}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  case Intrinsic::maxnum:
  case Intrinsic::minnum: {
    Value *Op0 = VPI.getOperand(0);
    Value *Op1 = VPI.getOperand(1);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op0, Op1}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  case Intrinsic::experimental_constrained_fma:
  case Intrinsic::experimental_constrained_fmuladd: {
    Value *Op0 = VPI.getOperand(0);
    Value *Op1 = VPI.getOperand(1);
    Value *Op2 = VPI.getOperand(2);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp =
        Builder.CreateConstrainedFPCall(Fn, {Op0, Op1, Op2}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  }

  return nullptr;
}

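/// \returns The neutral element of the reduction \p VPI for element type
/// \p EltTy, i.e. the value that can be placed in masked-off lanes without
/// changing the result of the reduction.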
static Value *getNeutralReductionElement(const VPReductionIntrinsic &VPI,
                                         Type *EltTy) {
  bool Negative = false;
  unsigned EltBits = EltTy->getScalarSizeInBits();
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Expecting a VP reduction intrinsic");
  case Intrinsic::vp_reduce_add:
  case Intrinsic::vp_reduce_or:
  case Intrinsic::vp_reduce_xor:
  case Intrinsic::vp_reduce_umax:
    return Constant::getNullValue(EltTy);
  case Intrinsic::vp_reduce_mul:
    return ConstantInt::get(EltTy, 1, /*IsSigned*/ false);
  case Intrinsic::vp_reduce_and:
  case Intrinsic::vp_reduce_umin:
    return ConstantInt::getAllOnesValue(EltTy);
  case Intrinsic::vp_reduce_smin:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMaxValue(EltBits));
  case Intrinsic::vp_reduce_smax:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMinValue(EltBits));
  case Intrinsic::vp_reduce_fmax:
    Negative = true;
    [[fallthrough]];
  case Intrinsic::vp_reduce_fmin: {
    FastMathFlags Flags = VPI.getFastMathFlags();
    const fltSemantics &Semantics = EltTy->getFltSemantics();
    return !Flags.noNaNs() ? ConstantFP::getQNaN(EltTy, Negative)
           : !Flags.noInfs()
               ? ConstantFP::getInfinity(EltTy, Negative)
               : ConstantFP::get(EltTy,
                                 APFloat::getLargest(Semantics, Negative));
  }
  case Intrinsic::vp_reduce_fadd:
    return ConstantFP::getNegativeZero(EltTy);
  case Intrinsic::vp_reduce_fmul:
    return ConstantFP::get(EltTy, 1.0);
  }
}

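// Expand a VP reduction to an unpredicated reduction: masked-off lanes are
// first replaced with the neutral element, then the matching
// llvm.vector.reduce.* intrinsic is emitted and merged with the start value.
// A sketch for vp.reduce.add (value names assumed):
//
//   %masked.in = select <4 x i1> %mask, <4 x i32> %in,
//                       <4 x i32> zeroinitializer
//   %reduced   = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %masked.in)
//   %result    = add i32 %reduced, %start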
Value *
CachingVPExpander::expandPredicationInReduction(IRBuilder<> &Builder,
                                                VPReductionIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  Value *Mask = VPI.getMaskParam();
  Value *RedOp = VPI.getOperand(VPI.getVectorParamPos());

  // Insert neutral element in masked-out positions
  if (Mask && !isAllTrueMask(Mask)) {
    auto *NeutralElt = getNeutralReductionElement(VPI, VPI.getType());
    auto *NeutralVector = Builder.CreateVectorSplat(
        cast<VectorType>(RedOp->getType())->getElementCount(), NeutralElt);
    RedOp = Builder.CreateSelect(Mask, RedOp, NeutralVector);
  }

  Value *Reduction;
  Value *Start = VPI.getOperand(VPI.getStartParamPos());

  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Impossible reduction kind");
  case Intrinsic::vp_reduce_add:
    Reduction = Builder.CreateAddReduce(RedOp);
    Reduction = Builder.CreateAdd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_mul:
    Reduction = Builder.CreateMulReduce(RedOp);
    Reduction = Builder.CreateMul(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_and:
    Reduction = Builder.CreateAndReduce(RedOp);
    Reduction = Builder.CreateAnd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_or:
    Reduction = Builder.CreateOrReduce(RedOp);
    Reduction = Builder.CreateOr(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_xor:
    Reduction = Builder.CreateXorReduce(RedOp);
    Reduction = Builder.CreateXor(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_smax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_smin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmax:
    Reduction = Builder.CreateFPMaxReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmin:
    Reduction = Builder.CreateFPMinReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::minnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fadd:
    Reduction = Builder.CreateFAddReduce(Start, RedOp);
    break;
  case Intrinsic::vp_reduce_fmul:
    Reduction = Builder.CreateFMulReduce(Start, RedOp);
    break;
  }

  replaceOperation(*Reduction, VPI);
  return Reduction;
}

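// VP casts are expanded by converting the whole vector with the corresponding
// unpredicated cast instruction; cast lanes may be computed speculatively, so
// the mask does not need to be honored here. E.g. (sketch, types assumed)
// vp.sext becomes
//
//   %result = sext <4 x i16> %x to <4 x i32>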
Value *CachingVPExpander::expandPredicationToCastIntrinsic(IRBuilder<> &Builder,
                                                           VPIntrinsic &VPI) {
  Value *CastOp = nullptr;
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Not a VP cast intrinsic");
  case Intrinsic::vp_sext:
    CastOp =
        Builder.CreateSExt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_zext:
    CastOp =
        Builder.CreateZExt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_trunc:
    CastOp =
        Builder.CreateTrunc(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_inttoptr:
    CastOp =
        Builder.CreateIntToPtr(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_ptrtoint:
    CastOp =
        Builder.CreatePtrToInt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fptosi:
    CastOp =
        Builder.CreateFPToSI(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fptoui:
    CastOp =
        Builder.CreateFPToUI(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_sitofp:
    CastOp =
        Builder.CreateSIToFP(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_uitofp:
    CastOp =
        Builder.CreateUIToFP(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fptrunc:
    CastOp =
        Builder.CreateFPTrunc(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fpext:
    CastOp =
        Builder.CreateFPExt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  }
  replaceOperation(*CastOp, VPI);
  return CastOp;
}

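// VP memory operations must not touch masked-off lanes, so they are expanded
// to the corresponding llvm.masked.* intrinsics, or to a plain load/store when
// the mask is known to be all-true. A sketch for a masked vp.load (pointer
// type, alignment, and pass-through value assumed for illustration):
//
//   %v = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %p, i32 4,
//                                                  <4 x i1> %mask,
//                                                  <4 x i32> undef)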
Value *
CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                                      VPIntrinsic &VPI) {
  assert(VPI.canIgnoreVectorLengthParam());

  const auto &DL = F.getParent()->getDataLayout();

  Value *MaskParam = VPI.getMaskParam();
  Value *PtrParam = VPI.getMemoryPointerParam();
  Value *DataParam = VPI.getMemoryDataParam();
  bool IsUnmasked = isAllTrueMask(MaskParam);

  MaybeAlign AlignOpt = VPI.getPointerAlignment();

  Value *NewMemoryInst = nullptr;
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Not a VP memory intrinsic");
  case Intrinsic::vp_store:
    if (IsUnmasked) {
      StoreInst *NewStore =
          Builder.CreateStore(DataParam, PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewStore->setAlignment(*AlignOpt);
      NewMemoryInst = NewStore;
    } else
      NewMemoryInst = Builder.CreateMaskedStore(
          DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam);

    break;
  case Intrinsic::vp_load:
    if (IsUnmasked) {
      LoadInst *NewLoad =
          Builder.CreateLoad(VPI.getType(), PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewLoad->setAlignment(*AlignOpt);
      NewMemoryInst = NewLoad;
    } else
      NewMemoryInst = Builder.CreateMaskedLoad(
          VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam);

    break;
  case Intrinsic::vp_scatter: {
    auto *ElementType =
        cast<VectorType>(DataParam->getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedScatter(
        DataParam, PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam);
    break;
  }
  case Intrinsic::vp_gather: {
    auto *ElementType = cast<VectorType>(VPI.getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedGather(
        VPI.getType(), PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam, nullptr,
        VPI.getName());
    break;
  }
  }

  assert(NewMemoryInst);
  replaceOperation(*NewMemoryInst, VPI);
  return NewMemoryInst;
}

Value *CachingVPExpander::expandPredicationInComparison(IRBuilder<> &Builder,
                                                        VPCmpIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  assert(*VPI.getFunctionalOpcode() == Instruction::ICmp ||
         *VPI.getFunctionalOpcode() == Instruction::FCmp);

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  auto Pred = VPI.getPredicate();

  auto *NewCmp = Builder.CreateCmp(Pred, Op0, Op1);

  replaceOperation(*NewCmp, VPI);
  return NewCmp;
}

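// Discarding %evl means pinning it to the full vector length, which makes it
// ineffective. A sketch of the scalable case, assuming a known minimum of 4
// lanes:
//
//   %vscale        = call i32 @llvm.vscale.i32()
//   %scalable_size = mul nuw i32 %vscale, 4
//
// and %scalable_size becomes the new %evl operand. Fixed-width operations
// simply receive the constant element count.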
void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");

  if (VPI.canIgnoreVectorLengthParam())
    return;

  Value *EVLParam = VPI.getVectorLengthParam();
  if (!EVLParam)
    return;

  ElementCount StaticElemCount = VPI.getStaticVectorLength();
  Value *MaxEVL = nullptr;
  Type *Int32Ty = Type::getInt32Ty(VPI.getContext());
  if (StaticElemCount.isScalable()) {
    // TODO add caching
    auto *M = VPI.getModule();
    Function *VScaleFunc =
        Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
    IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
    Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
    Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
    MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
                               /*NUW*/ true, /*NSW*/ false);
  } else {
    MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
  }
  VPI.setVectorLengthParam(MaxEVL);
}

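// Folding %evl into the mask preserves the predicating effect of %evl while
// making the parameter itself redundant. A sketch for a fixed four-lane
// operation (value names assumed):
//
//   %evl.mask = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %evl.splat
//   %new.mask = and <4 x i1> %evl.mask, %old.mask
//
// The combined mask replaces the old mask parameter and %evl is then
// discarded via discardEVLParameter().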
Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // The %evl parameter is ineffective, so there is nothing to do here.
  if (VPI.canIgnoreVectorLengthParam())
    return &VPI;

  // Only VP intrinsics can have an %evl parameter.
  Value *OldMaskParam = VPI.getMaskParam();
  Value *OldEVLParam = VPI.getVectorLengthParam();
  assert(OldMaskParam && "no mask param to fold the vl param into");
  assert(OldEVLParam && "no EVL param to fold away");

  LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
  LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');

  // Convert the %evl predication into vector mask predication.
  ElementCount ElemCount = VPI.getStaticVectorLength();
  Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
  Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
  VPI.setMaskParam(NewMaskParam);

  // Drop the %evl parameter.
  discardEVLParameter(VPI);
  assert(VPI.canIgnoreVectorLengthParam() &&
         "transformation did not render the evl param ineffective!");

  // Reassess the modified instruction.
  return &VPI;
}

Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // Try lowering to a LLVM instruction first.
  auto OC = VPI.getFunctionalOpcode();

  if (OC && Instruction::isBinaryOp(*OC))
    return expandPredicationInBinaryOperator(Builder, VPI);

  if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))
    return expandPredicationInReduction(Builder, *VPRI);

  if (auto *VPCmp = dyn_cast<VPCmpIntrinsic>(&VPI))
    return expandPredicationInComparison(Builder, *VPCmp);

  if (VPCastIntrinsic::isVPCast(VPI.getIntrinsicID())) {
    return expandPredicationToCastIntrinsic(Builder, VPI);
  }

  switch (VPI.getIntrinsicID()) {
  default:
    break;
  case Intrinsic::vp_fneg: {
    Value *NewNegOp = Builder.CreateFNeg(VPI.getOperand(0), VPI.getName());
    replaceOperation(*NewNegOp, VPI);
    return NewNegOp;
  }
  case Intrinsic::vp_abs:
  case Intrinsic::vp_smax:
  case Intrinsic::vp_smin:
  case Intrinsic::vp_umax:
  case Intrinsic::vp_umin:
  case Intrinsic::vp_bswap:
  case Intrinsic::vp_bitreverse:
    return expandPredicationToIntCall(Builder, VPI,
                                      VPI.getFunctionalIntrinsicID().value());
  case Intrinsic::vp_fabs:
  case Intrinsic::vp_sqrt:
  case Intrinsic::vp_maxnum:
  case Intrinsic::vp_minnum:
  case Intrinsic::vp_maximum:
  case Intrinsic::vp_minimum:
    return expandPredicationToFPCall(Builder, VPI,
                                     VPI.getFunctionalIntrinsicID().value());
  case Intrinsic::vp_load:
  case Intrinsic::vp_store:
  case Intrinsic::vp_gather:
  case Intrinsic::vp_scatter:
    return expandPredicationInMemoryIntrinsic(Builder, VPI);
  }

  if (auto CID = VPI.getConstrainedIntrinsicID())
    if (Value *Call = expandPredicationToFPCall(Builder, VPI, *CID))
      return Call;

  return &VPI;
}

//// } CachingVPExpander

struct TransformJob {
  VPIntrinsic *PI;
  TargetTransformInfo::VPLegalization Strategy;
  TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
      : PI(PI), Strategy(InitStrat) {}

  bool isDone() const { return Strategy.shouldDoNothing(); }
};

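/// Reconcile the requested legalization with the semantics of \p VPI: a
/// speculatable operation that is about to be converted can simply have %evl
/// discarded, while a non-speculatable operation must never lose the
/// predicating effect of %evl, so %evl is folded into the mask instead.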
void sanitizeStrategy(VPIntrinsic &VPI, VPLegalization &LegalizeStrat) {
  // Operations with speculatable lanes do not strictly need predication.
  if (maySpeculateLanes(VPI)) {
    // Converting a speculatable VP intrinsic means dropping %mask and %evl.
    // No need to expand %evl into the %mask only to ignore that code.
    if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
      LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
    return;
  }

  // We have to preserve the predicating effect of %evl for this
  // non-speculatable VP intrinsic.
  // 1) Never discard %evl.
  // 2) If this VP intrinsic will be expanded to non-VP code, make sure that
  //    %evl gets folded into %mask.
  if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||
      (LegalizeStrat.OpStrategy == VPLegalization::Convert)) {
    LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
  }
}

VPLegalization
CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
  auto VPStrat = TTI.getVPLegalizationStrategy(VPI);
  if (LLVM_LIKELY(!UsingTTIOverrides)) {
    // No overrides - we are in production.
    return VPStrat;
  }

  // Overrides set - we are in testing; the following does not need to be
  // efficient.
  VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
  VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
  return VPStrat;
}

/// Expand llvm.vp.* intrinsics as requested by \p TTI.
bool CachingVPExpander::expandVectorPredication() {
  SmallVector<TransformJob, 16> Worklist;

  // Collect all VPIntrinsics that need expansion and determine their expansion
  // strategy.
  for (auto &I : instructions(F)) {
    auto *VPI = dyn_cast<VPIntrinsic>(&I);
    if (!VPI)
      continue;
    auto VPStrat = getVPLegalizationStrategy(*VPI);
    sanitizeStrategy(*VPI, VPStrat);
    if (!VPStrat.shouldDoNothing())
      Worklist.emplace_back(VPI, VPStrat);
  }
  if (Worklist.empty())
    return false;

  // Transform all VPIntrinsics on the worklist.
  LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()
                    << " instructions ::::\n");
  for (TransformJob Job : Worklist) {
    // Transform the EVL parameter.
    switch (Job.Strategy.EVLParamStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      discardEVLParameter(*Job.PI);
      break;
    case VPLegalization::Convert:
      if (foldEVLIntoMask(*Job.PI))
        ++NumFoldedVL;
      break;
    }
    Job.Strategy.EVLParamStrategy = VPLegalization::Legal;

    // Replace with a non-predicated operation.
    switch (Job.Strategy.OpStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      llvm_unreachable("Invalid strategy for operators.");
    case VPLegalization::Convert:
      expandPredication(*Job.PI);
      ++NumLoweredVPOps;
      break;
    }
    Job.Strategy.OpStrategy = VPLegalization::Legal;

    assert(Job.isDone() && "incomplete transformation");
  }

  return true;
}

class ExpandVectorPredication : public FunctionPass {
public:
  static char ID;
  ExpandVectorPredication() : FunctionPass(ID) {
    initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    CachingVPExpander VPExpander(F, *TTI);
    return VPExpander.expandVectorPredication();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

char ExpandVectorPredication::ID;
INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",
                      "Expand vector predication intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",
                    "Expand vector predication intrinsics", false, false)

FunctionPass *llvm::createExpandVectorPredicationPass() {
  return new ExpandVectorPredication();
}

PreservedAnalyses
ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  CachingVPExpander VPExpander(F, TTI);
  if (!VPExpander.expandVectorPredication())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}