xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp (revision 681ce946f33e75c590e97c53076e86dff1fe8f4a)
1  //===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // This pass implements IR expansion for vector predication intrinsics, allowing
10  // targets to enable vector predication until just before codegen.
11  //
12  //===----------------------------------------------------------------------===//
13  
14  #include "llvm/CodeGen/ExpandVectorPredication.h"
15  #include "llvm/ADT/Statistic.h"
16  #include "llvm/Analysis/TargetTransformInfo.h"
17  #include "llvm/Analysis/ValueTracking.h"
18  #include "llvm/CodeGen/Passes.h"
19  #include "llvm/IR/Constants.h"
20  #include "llvm/IR/Function.h"
21  #include "llvm/IR/IRBuilder.h"
22  #include "llvm/IR/InstIterator.h"
23  #include "llvm/IR/Instructions.h"
24  #include "llvm/IR/IntrinsicInst.h"
25  #include "llvm/IR/Intrinsics.h"
26  #include "llvm/IR/Module.h"
27  #include "llvm/InitializePasses.h"
28  #include "llvm/Pass.h"
29  #include "llvm/Support/CommandLine.h"
30  #include "llvm/Support/Compiler.h"
31  #include "llvm/Support/Debug.h"
32  #include "llvm/Support/MathExtras.h"
33  
34  using namespace llvm;
35  
36  using VPLegalization = TargetTransformInfo::VPLegalization;
37  using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;
38  
// Keep this in sync with TargetTransformInfo::VPLegalization.
// List of the transform names understood by the override options below. Each
// expansion of this list applies whatever VPINTERNAL_CASE is defined as at that
// point to every name in turn.
#define VPINTERNAL_VPLEGAL_CASES                                               \
  VPINTERNAL_CASE(Legal)                                                       \
  VPINTERNAL_CASE(Discard)                                                     \
  VPINTERNAL_CASE(Convert)

// For the option descriptions: turn each name X into the string piece "|X".
#define VPINTERNAL_CASE(X) "|" #X

// Override options.
// Textual override for the transformation applied to the %evl parameter.
// Initialized to the empty string.
static cl::opt<std::string> EVLTransformOverride(
    "expandvp-override-evl-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %evl parameter (Used in "
             "testing)."));

// Textual override for the transformation applied to the %mask parameter.
// Initialized to the empty string.
static cl::opt<std::string> MaskTransformOverride(
    "expandvp-override-mask-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, Ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %mask parameter (Used in "
             "testing)."));
63  
// For option parsing: turn each name X into a StringSwitch case that maps the
// string "X" to the enumerator VPLegalization::X.
#undef VPINTERNAL_CASE
#define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)

/// \returns The VPTransform whose name equals \p TextOpt.
/// NOTE(review): the switch has no .Default() entry, so \p TextOpt presumably
/// has to be one of the listed names (in particular, not the empty string) -
/// confirm against StringSwitch's behavior for unmatched input.
static VPTransform parseOverrideOption(const std::string &TextOpt) {
  return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;
}

// The case list is not expanded again after this point.
#undef VPINTERNAL_VPLEGAL_CASES
72  
73  // Whether any override options are set.
74  static bool anyExpandVPOverridesSet() {
75    return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
76  }
77  
// Debug/statistics category name used by LLVM_DEBUG and STATISTIC in this file.
#define DEBUG_TYPE "expandvp"

// Incremented in expandVectorPredication() for each job whose %evl parameter
// went through foldEVLIntoMask().
STATISTIC(NumFoldedVL, "Number of folded vector length params");
// Incremented in expandVectorPredication() for each job passed to
// expandPredication().
// NOTE(review): the description says "folded" while the counter tracks lowered
// operations - confirm whether the wording is intended.
STATISTIC(NumLoweredVPOps, "Number of folded vector predication operations");
82  
83  ///// Helpers {
84  
85  /// \returns Whether the vector mask \p MaskVal has all lane bits set.
86  static bool isAllTrueMask(Value *MaskVal) {
87    auto *ConstVec = dyn_cast<ConstantVector>(MaskVal);
88    return ConstVec && ConstVec->isAllOnesValue();
89  }
90  
91  /// \returns A non-excepting divisor constant for this type.
92  static Constant *getSafeDivisor(Type *DivTy) {
93    assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");
94    return ConstantInt::get(DivTy, 1u, false);
95  }
96  
97  /// Transfer operation properties from \p OldVPI to \p NewVal.
98  static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
99    auto *NewInst = dyn_cast<Instruction>(&NewVal);
100    if (!NewInst || !isa<FPMathOperator>(NewVal))
101      return;
102  
103    auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
104    if (!OldFMOp)
105      return;
106  
107    NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
108  }
109  
/// Transfer all properties from \p OldOp to \p NewOp and replace all uses of
/// \p OldOp with \p NewOp. \p OldOp gets erased from its parent, so it must
/// not be used after this call.
static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
  // Copy properties first: OldOp is gone after the erase below.
  transferDecorations(NewOp, OldOp);
  OldOp.replaceAllUsesWith(&NewOp);
  OldOp.eraseFromParent();
}
117  
118  //// } Helpers
119  
120  namespace {
121  
// Expansion pass state at function scope.
struct CachingVPExpander {
  /// The function whose VP intrinsics are expanded.
  Function &F;
  /// Target information queried for the per-intrinsic legalization strategy.
  const TargetTransformInfo &TTI;

  /// \returns A (fixed length) vector with ascending integer indices
  /// (<0, 1, ..., NumElems-1>).
  /// \p Builder
  ///    Used for instruction creation.
  /// \p LaneTy
  ///    Integer element type of the result vector.
  /// \p NumElems
  ///    Number of vector elements.
  Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                          unsigned NumElems);

  /// \returns A bitmask that is true where the lane position is less-than \p
  /// EVLParam
  ///
  /// \p Builder
  ///    Used for instruction creation.
  /// \p EVLParam
  ///    The explicit vector length parameter to test against the lane
  ///    positions.
  /// \p ElemCount
  ///    Static (potentially scalable) number of vector elements.
  Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
                          ElementCount ElemCount);

  /// Fold the %evl parameter of \p VPI into its %mask parameter and then
  /// discard the %evl parameter.
  /// \returns \p VPI.
  Value *foldEVLIntoMask(VPIntrinsic &VPI);

  /// "Remove" the %evl parameter of \p PI by setting it to the static vector
  /// length of the operation.
  void discardEVLParameter(VPIntrinsic &PI);

  /// Lower this VP binary operator to an unpredicated binary operator.
  Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                           VPIntrinsic &PI);

  /// Expand the vector predication in \p PI. Only intrinsics with a binary
  /// operator as functional opcode are lowered; anything else is returned
  /// unchanged.
  Value *expandPredication(VPIntrinsic &PI);

  /// Determine how and whether the VPIntrinsic \p VPI shall be
  /// expanded. This overrides TTI with the cl::opts listed at the top of this
  /// file.
  VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
  /// True if any override option was non-empty when this object was
  /// constructed.
  bool UsingTTIOverrides;

public:
  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
      : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}

  /// Expand the VP intrinsics in the function.
  /// \returns False if no VP intrinsic needed a transformation, true otherwise.
  bool expandVectorPredication();
};
176  
177  //// CachingVPExpander {
178  
179  Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,
180                                             unsigned NumElems) {
181    // TODO add caching
182    SmallVector<Constant *, 16> ConstElems;
183  
184    for (unsigned Idx = 0; Idx < NumElems; ++Idx)
185      ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));
186  
187    return ConstantVector::get(ConstElems);
188  }
189  
190  Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
191                                             Value *EVLParam,
192                                             ElementCount ElemCount) {
193    // TODO add caching
194    // Scalable vector %evl conversion.
195    if (ElemCount.isScalable()) {
196      auto *M = Builder.GetInsertBlock()->getModule();
197      Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
198      Function *ActiveMaskFunc = Intrinsic::getDeclaration(
199          M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
200      // `get_active_lane_mask` performs an implicit less-than comparison.
201      Value *ConstZero = Builder.getInt32(0);
202      return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
203    }
204  
205    // Fixed vector %evl conversion.
206    Type *LaneTy = EVLParam->getType();
207    unsigned NumElems = ElemCount.getFixedValue();
208    Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);
209    Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);
210    return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
211  }
212  
213  Value *
214  CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
215                                                       VPIntrinsic &VPI) {
216    assert((isSafeToSpeculativelyExecute(&VPI) ||
217            VPI.canIgnoreVectorLengthParam()) &&
218           "Implicitly dropping %evl in non-speculatable operator!");
219  
220    auto OC = static_cast<Instruction::BinaryOps>(*VPI.getFunctionalOpcode());
221    assert(Instruction::isBinaryOp(OC));
222  
223    Value *Op0 = VPI.getOperand(0);
224    Value *Op1 = VPI.getOperand(1);
225    Value *Mask = VPI.getMaskParam();
226  
227    // Blend in safe operands.
228    if (Mask && !isAllTrueMask(Mask)) {
229      switch (OC) {
230      default:
231        // Can safely ignore the predicate.
232        break;
233  
234      // Division operators need a safe divisor on masked-off lanes (1).
235      case Instruction::UDiv:
236      case Instruction::SDiv:
237      case Instruction::URem:
238      case Instruction::SRem:
239        // 2nd operand must not be zero.
240        Value *SafeDivisor = getSafeDivisor(VPI.getType());
241        Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);
242      }
243    }
244  
245    Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());
246  
247    replaceOperation(*NewBinOp, VPI);
248    return NewBinOp;
249  }
250  
251  void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
252    LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");
253  
254    if (VPI.canIgnoreVectorLengthParam())
255      return;
256  
257    Value *EVLParam = VPI.getVectorLengthParam();
258    if (!EVLParam)
259      return;
260  
261    ElementCount StaticElemCount = VPI.getStaticVectorLength();
262    Value *MaxEVL = nullptr;
263    Type *Int32Ty = Type::getInt32Ty(VPI.getContext());
264    if (StaticElemCount.isScalable()) {
265      // TODO add caching
266      auto *M = VPI.getModule();
267      Function *VScaleFunc =
268          Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
269      IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
270      Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
271      Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
272      MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
273                                 /*NUW*/ true, /*NSW*/ false);
274    } else {
275      MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
276    }
277    VPI.setVectorLengthParam(MaxEVL);
278  }
279  
280  Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
281    LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');
282  
283    IRBuilder<> Builder(&VPI);
284  
285    // Ineffective %evl parameter and so nothing to do here.
286    if (VPI.canIgnoreVectorLengthParam())
287      return &VPI;
288  
289    // Only VP intrinsics can have an %evl parameter.
290    Value *OldMaskParam = VPI.getMaskParam();
291    Value *OldEVLParam = VPI.getVectorLengthParam();
292    assert(OldMaskParam && "no mask param to fold the vl param into");
293    assert(OldEVLParam && "no EVL param to fold away");
294  
295    LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
296    LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');
297  
298    // Convert the %evl predication into vector mask predication.
299    ElementCount ElemCount = VPI.getStaticVectorLength();
300    Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
301    Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
302    VPI.setMaskParam(NewMaskParam);
303  
304    // Drop the %evl parameter.
305    discardEVLParameter(VPI);
306    assert(VPI.canIgnoreVectorLengthParam() &&
307           "transformation did not render the evl param ineffective!");
308  
309    // Reassess the modified instruction.
310    return &VPI;
311  }
312  
313  Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
314    LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');
315  
316    IRBuilder<> Builder(&VPI);
317  
318    // Try lowering to a LLVM instruction first.
319    auto OC = VPI.getFunctionalOpcode();
320  
321    if (OC && Instruction::isBinaryOp(*OC))
322      return expandPredicationInBinaryOperator(Builder, VPI);
323  
324    return &VPI;
325  }
326  
327  //// } CachingVPExpander
328  
329  struct TransformJob {
330    VPIntrinsic *PI;
331    TargetTransformInfo::VPLegalization Strategy;
332    TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
333        : PI(PI), Strategy(InitStrat) {}
334  
335    bool isDone() const { return Strategy.shouldDoNothing(); }
336  };
337  
338  void sanitizeStrategy(Instruction &I, VPLegalization &LegalizeStrat) {
339    // Speculatable instructions do not strictly need predication.
340    if (isSafeToSpeculativelyExecute(&I)) {
341      // Converting a speculatable VP intrinsic means dropping %mask and %evl.
342      // No need to expand %evl into the %mask only to ignore that code.
343      if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
344        LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
345      return;
346    }
347  
348    // We have to preserve the predicating effect of %evl for this
349    // non-speculatable VP intrinsic.
350    // 1) Never discard %evl.
351    // 2) If this VP intrinsic will be expanded to non-VP code, make sure that
352    //    %evl gets folded into %mask.
353    if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||
354        (LegalizeStrat.OpStrategy == VPLegalization::Convert)) {
355      LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
356    }
357  }
358  
359  VPLegalization
360  CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
361    auto VPStrat = TTI.getVPLegalizationStrategy(VPI);
362    if (LLVM_LIKELY(!UsingTTIOverrides)) {
363      // No overrides - we are in production.
364      return VPStrat;
365    }
366  
367    // Overrides set - we are in testing, the following does not need to be
368    // efficient.
369    VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
370    VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
371    return VPStrat;
372  }
373  
374  /// \brief Expand llvm.vp.* intrinsics as requested by \p TTI.
375  bool CachingVPExpander::expandVectorPredication() {
376    SmallVector<TransformJob, 16> Worklist;
377  
378    // Collect all VPIntrinsics that need expansion and determine their expansion
379    // strategy.
380    for (auto &I : instructions(F)) {
381      auto *VPI = dyn_cast<VPIntrinsic>(&I);
382      if (!VPI)
383        continue;
384      auto VPStrat = getVPLegalizationStrategy(*VPI);
385      sanitizeStrategy(I, VPStrat);
386      if (!VPStrat.shouldDoNothing())
387        Worklist.emplace_back(VPI, VPStrat);
388    }
389    if (Worklist.empty())
390      return false;
391  
392    // Transform all VPIntrinsics on the worklist.
393    LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()
394                      << " instructions ::::\n");
395    for (TransformJob Job : Worklist) {
396      // Transform the EVL parameter.
397      switch (Job.Strategy.EVLParamStrategy) {
398      case VPLegalization::Legal:
399        break;
400      case VPLegalization::Discard:
401        discardEVLParameter(*Job.PI);
402        break;
403      case VPLegalization::Convert:
404        if (foldEVLIntoMask(*Job.PI))
405          ++NumFoldedVL;
406        break;
407      }
408      Job.Strategy.EVLParamStrategy = VPLegalization::Legal;
409  
410      // Replace with a non-predicated operation.
411      switch (Job.Strategy.OpStrategy) {
412      case VPLegalization::Legal:
413        break;
414      case VPLegalization::Discard:
415        llvm_unreachable("Invalid strategy for operators.");
416      case VPLegalization::Convert:
417        expandPredication(*Job.PI);
418        ++NumLoweredVPOps;
419        break;
420      }
421      Job.Strategy.OpStrategy = VPLegalization::Legal;
422  
423      assert(Job.isDone() && "incomplete transformation");
424    }
425  
426    return true;
427  }
428  class ExpandVectorPredication : public FunctionPass {
429  public:
430    static char ID;
431    ExpandVectorPredication() : FunctionPass(ID) {
432      initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
433    }
434  
435    bool runOnFunction(Function &F) override {
436      const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
437      CachingVPExpander VPExpander(F, *TTI);
438      return VPExpander.expandVectorPredication();
439    }
440  
441    void getAnalysisUsage(AnalysisUsage &AU) const override {
442      AU.addRequired<TargetTransformInfoWrapperPass>();
443      AU.setPreservesCFG();
444    }
445  };
446  } // namespace
447  
// Definition of the pass ID declared in the class; passed to FunctionPass in
// the constructor.
char ExpandVectorPredication::ID;
// Legacy pass registration under the name "expandvp".
// NOTE(review): DominatorTreeWrapperPass is listed as a dependency here, but
// getAnalysisUsage() only requires TargetTransformInfoWrapperPass - confirm
// whether this dependency is still needed.
INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",
                      "Expand vector predication intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",
                    "Expand vector predication intrinsics", false, false)
455  
/// \returns A newly allocated instance of the legacy ExpandVectorPredication
/// pass.
FunctionPass *llvm::createExpandVectorPredicationPass() {
  return new ExpandVectorPredication();
}
459  
460  PreservedAnalyses
461  ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
462    const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
463    CachingVPExpander VPExpander(F, TTI);
464    if (!VPExpander.expandVectorPredication())
465      return PreservedAnalyses::all();
466    PreservedAnalyses PA;
467    PA.preserveSet<CFGAnalyses>();
468    return PA;
469  }
470