xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/ExpandReductions.cpp (revision a2fda816eb054d5873be223ef2461741dfcc253c)
1  //===- ExpandReductions.cpp - Expand reduction intrinsics -----------------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // This pass implements IR expansion for reduction intrinsics, allowing targets
10  // to enable the intrinsics until just before codegen.
11  //
12  //===----------------------------------------------------------------------===//
13  
14  #include "llvm/CodeGen/ExpandReductions.h"
15  #include "llvm/Analysis/TargetTransformInfo.h"
16  #include "llvm/CodeGen/Passes.h"
17  #include "llvm/IR/IRBuilder.h"
18  #include "llvm/IR/InstIterator.h"
19  #include "llvm/IR/IntrinsicInst.h"
20  #include "llvm/IR/Intrinsics.h"
21  #include "llvm/InitializePasses.h"
22  #include "llvm/Pass.h"
23  #include "llvm/Transforms/Utils/LoopUtils.h"
24  
25  using namespace llvm;
26  
27  namespace {
28  
29  unsigned getOpcode(Intrinsic::ID ID) {
30    switch (ID) {
31    case Intrinsic::vector_reduce_fadd:
32      return Instruction::FAdd;
33    case Intrinsic::vector_reduce_fmul:
34      return Instruction::FMul;
35    case Intrinsic::vector_reduce_add:
36      return Instruction::Add;
37    case Intrinsic::vector_reduce_mul:
38      return Instruction::Mul;
39    case Intrinsic::vector_reduce_and:
40      return Instruction::And;
41    case Intrinsic::vector_reduce_or:
42      return Instruction::Or;
43    case Intrinsic::vector_reduce_xor:
44      return Instruction::Xor;
45    case Intrinsic::vector_reduce_smax:
46    case Intrinsic::vector_reduce_smin:
47    case Intrinsic::vector_reduce_umax:
48    case Intrinsic::vector_reduce_umin:
49      return Instruction::ICmp;
50    case Intrinsic::vector_reduce_fmax:
51    case Intrinsic::vector_reduce_fmin:
52      return Instruction::FCmp;
53    default:
54      llvm_unreachable("Unexpected ID");
55    }
56  }
57  
58  RecurKind getRK(Intrinsic::ID ID) {
59    switch (ID) {
60    case Intrinsic::vector_reduce_smax:
61      return RecurKind::SMax;
62    case Intrinsic::vector_reduce_smin:
63      return RecurKind::SMin;
64    case Intrinsic::vector_reduce_umax:
65      return RecurKind::UMax;
66    case Intrinsic::vector_reduce_umin:
67      return RecurKind::UMin;
68    case Intrinsic::vector_reduce_fmax:
69      return RecurKind::FMax;
70    case Intrinsic::vector_reduce_fmin:
71      return RecurKind::FMin;
72    default:
73      return RecurKind::None;
74    }
75  }
76  
77  bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
78    bool Changed = false;
79    SmallVector<IntrinsicInst *, 4> Worklist;
80    for (auto &I : instructions(F)) {
81      if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
82        switch (II->getIntrinsicID()) {
83        default: break;
84        case Intrinsic::vector_reduce_fadd:
85        case Intrinsic::vector_reduce_fmul:
86        case Intrinsic::vector_reduce_add:
87        case Intrinsic::vector_reduce_mul:
88        case Intrinsic::vector_reduce_and:
89        case Intrinsic::vector_reduce_or:
90        case Intrinsic::vector_reduce_xor:
91        case Intrinsic::vector_reduce_smax:
92        case Intrinsic::vector_reduce_smin:
93        case Intrinsic::vector_reduce_umax:
94        case Intrinsic::vector_reduce_umin:
95        case Intrinsic::vector_reduce_fmax:
96        case Intrinsic::vector_reduce_fmin:
97          if (TTI->shouldExpandReduction(II))
98            Worklist.push_back(II);
99  
100          break;
101        }
102      }
103    }
104  
105    for (auto *II : Worklist) {
106      FastMathFlags FMF =
107          isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
108      Intrinsic::ID ID = II->getIntrinsicID();
109      RecurKind RK = getRK(ID);
110  
111      Value *Rdx = nullptr;
112      IRBuilder<> Builder(II);
113      IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
114      Builder.setFastMathFlags(FMF);
115      switch (ID) {
116      default: llvm_unreachable("Unexpected intrinsic!");
117      case Intrinsic::vector_reduce_fadd:
118      case Intrinsic::vector_reduce_fmul: {
119        // FMFs must be attached to the call, otherwise it's an ordered reduction
120        // and it can't be handled by generating a shuffle sequence.
121        Value *Acc = II->getArgOperand(0);
122        Value *Vec = II->getArgOperand(1);
123        if (!FMF.allowReassoc())
124          Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), RK);
125        else {
126          if (!isPowerOf2_32(
127                  cast<FixedVectorType>(Vec->getType())->getNumElements()))
128            continue;
129  
130          Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
131          Rdx = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(ID),
132                                    Acc, Rdx, "bin.rdx");
133        }
134        break;
135      }
136      case Intrinsic::vector_reduce_and:
137      case Intrinsic::vector_reduce_or: {
138        // Canonicalize logical or/and reductions:
139        // Or reduction for i1 is represented as:
140        // %val = bitcast <ReduxWidth x i1> to iReduxWidth
141        // %res = cmp ne iReduxWidth %val, 0
142        // And reduction for i1 is represented as:
143        // %val = bitcast <ReduxWidth x i1> to iReduxWidth
144        // %res = cmp eq iReduxWidth %val, 11111
145        Value *Vec = II->getArgOperand(0);
146        auto *FTy = cast<FixedVectorType>(Vec->getType());
147        unsigned NumElts = FTy->getNumElements();
148        if (!isPowerOf2_32(NumElts))
149          continue;
150  
151        if (FTy->getElementType() == Builder.getInt1Ty()) {
152          Rdx = Builder.CreateBitCast(Vec, Builder.getIntNTy(NumElts));
153          if (ID == Intrinsic::vector_reduce_and) {
154            Rdx = Builder.CreateICmpEQ(
155                Rdx, ConstantInt::getAllOnesValue(Rdx->getType()));
156          } else {
157            assert(ID == Intrinsic::vector_reduce_or && "Expected or reduction.");
158            Rdx = Builder.CreateIsNotNull(Rdx);
159          }
160          break;
161        }
162  
163        Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
164        break;
165      }
166      case Intrinsic::vector_reduce_add:
167      case Intrinsic::vector_reduce_mul:
168      case Intrinsic::vector_reduce_xor:
169      case Intrinsic::vector_reduce_smax:
170      case Intrinsic::vector_reduce_smin:
171      case Intrinsic::vector_reduce_umax:
172      case Intrinsic::vector_reduce_umin: {
173        Value *Vec = II->getArgOperand(0);
174        if (!isPowerOf2_32(
175                cast<FixedVectorType>(Vec->getType())->getNumElements()))
176          continue;
177  
178        Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
179        break;
180      }
181      case Intrinsic::vector_reduce_fmax:
182      case Intrinsic::vector_reduce_fmin: {
183        // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
184        // semantics of the reduction.
185        Value *Vec = II->getArgOperand(0);
186        if (!isPowerOf2_32(
187                cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
188            !FMF.noNaNs())
189          continue;
190  
191        Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
192        break;
193      }
194      }
195      II->replaceAllUsesWith(Rdx);
196      II->eraseFromParent();
197      Changed = true;
198    }
199    return Changed;
200  }
201  
202  class ExpandReductions : public FunctionPass {
203  public:
204    static char ID;
205    ExpandReductions() : FunctionPass(ID) {
206      initializeExpandReductionsPass(*PassRegistry::getPassRegistry());
207    }
208  
209    bool runOnFunction(Function &F) override {
210      const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
211      return expandReductions(F, TTI);
212    }
213  
214    void getAnalysisUsage(AnalysisUsage &AU) const override {
215      AU.addRequired<TargetTransformInfoWrapperPass>();
216      AU.setPreservesCFG();
217    }
218  };
219  }
220  
221  char ExpandReductions::ID;
222  INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
223                        "Expand reduction intrinsics", false, false)
224  INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
225  INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
226                      "Expand reduction intrinsics", false, false)
227  
228  FunctionPass *llvm::createExpandReductionsPass() {
229    return new ExpandReductions();
230  }
231  
232  PreservedAnalyses ExpandReductionsPass::run(Function &F,
233                                              FunctionAnalysisManager &AM) {
234    const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
235    if (!expandReductions(F, &TTI))
236      return PreservedAnalyses::all();
237    PreservedAnalyses PA;
238    PA.preserveSet<CFGAnalyses>();
239    return PA;
240  }
241