xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/ExpandReductions.cpp (revision 5ca8c28cd8c725b81781201cfdb5f9969396f934)
1 //===- ExpandReductions.cpp - Expand reduction intrinsics -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements IR expansion for reduction intrinsics, allowing targets
10 // to enable the intrinsics until just before codegen.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/ExpandReductions.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/CodeGen/Passes.h"
17 #include "llvm/IR/IRBuilder.h"
18 #include "llvm/IR/InstIterator.h"
19 #include "llvm/IR/IntrinsicInst.h"
20 #include "llvm/IR/Intrinsics.h"
21 #include "llvm/InitializePasses.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Transforms/Utils/LoopUtils.h"
24 
25 using namespace llvm;
26 
27 namespace {
28 
29 unsigned getOpcode(Intrinsic::ID ID) {
30   switch (ID) {
31   case Intrinsic::vector_reduce_fadd:
32     return Instruction::FAdd;
33   case Intrinsic::vector_reduce_fmul:
34     return Instruction::FMul;
35   case Intrinsic::vector_reduce_add:
36     return Instruction::Add;
37   case Intrinsic::vector_reduce_mul:
38     return Instruction::Mul;
39   case Intrinsic::vector_reduce_and:
40     return Instruction::And;
41   case Intrinsic::vector_reduce_or:
42     return Instruction::Or;
43   case Intrinsic::vector_reduce_xor:
44     return Instruction::Xor;
45   case Intrinsic::vector_reduce_smax:
46   case Intrinsic::vector_reduce_smin:
47   case Intrinsic::vector_reduce_umax:
48   case Intrinsic::vector_reduce_umin:
49     return Instruction::ICmp;
50   case Intrinsic::vector_reduce_fmax:
51   case Intrinsic::vector_reduce_fmin:
52     return Instruction::FCmp;
53   default:
54     llvm_unreachable("Unexpected ID");
55   }
56 }
57 
58 RecurKind getRK(Intrinsic::ID ID) {
59   switch (ID) {
60   case Intrinsic::vector_reduce_smax:
61     return RecurKind::SMax;
62   case Intrinsic::vector_reduce_smin:
63     return RecurKind::SMin;
64   case Intrinsic::vector_reduce_umax:
65     return RecurKind::UMax;
66   case Intrinsic::vector_reduce_umin:
67     return RecurKind::UMin;
68   case Intrinsic::vector_reduce_fmax:
69     return RecurKind::FMax;
70   case Intrinsic::vector_reduce_fmin:
71     return RecurKind::FMin;
72   default:
73     return RecurKind::None;
74   }
75 }
76 
77 bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
78   bool Changed = false;
79   SmallVector<IntrinsicInst *, 4> Worklist;
80   for (auto &I : instructions(F)) {
81     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
82       switch (II->getIntrinsicID()) {
83       default: break;
84       case Intrinsic::vector_reduce_fadd:
85       case Intrinsic::vector_reduce_fmul:
86       case Intrinsic::vector_reduce_add:
87       case Intrinsic::vector_reduce_mul:
88       case Intrinsic::vector_reduce_and:
89       case Intrinsic::vector_reduce_or:
90       case Intrinsic::vector_reduce_xor:
91       case Intrinsic::vector_reduce_smax:
92       case Intrinsic::vector_reduce_smin:
93       case Intrinsic::vector_reduce_umax:
94       case Intrinsic::vector_reduce_umin:
95       case Intrinsic::vector_reduce_fmax:
96       case Intrinsic::vector_reduce_fmin:
97         if (TTI->shouldExpandReduction(II))
98           Worklist.push_back(II);
99 
100         break;
101       }
102     }
103   }
104 
105   for (auto *II : Worklist) {
106     FastMathFlags FMF =
107         isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
108     Intrinsic::ID ID = II->getIntrinsicID();
109     RecurKind RK = getRK(ID);
110 
111     Value *Rdx = nullptr;
112     IRBuilder<> Builder(II);
113     IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
114     Builder.setFastMathFlags(FMF);
115     switch (ID) {
116     default: llvm_unreachable("Unexpected intrinsic!");
117     case Intrinsic::vector_reduce_fadd:
118     case Intrinsic::vector_reduce_fmul: {
119       // FMFs must be attached to the call, otherwise it's an ordered reduction
120       // and it can't be handled by generating a shuffle sequence.
121       Value *Acc = II->getArgOperand(0);
122       Value *Vec = II->getArgOperand(1);
123       if (!FMF.allowReassoc())
124         Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), RK);
125       else {
126         if (!isPowerOf2_32(
127                 cast<FixedVectorType>(Vec->getType())->getNumElements()))
128           continue;
129 
130         Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
131         Rdx = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(ID),
132                                   Acc, Rdx, "bin.rdx");
133       }
134       break;
135     }
136     case Intrinsic::vector_reduce_and:
137     case Intrinsic::vector_reduce_or: {
138       // Canonicalize logical or/and reductions:
139       // Or reduction for i1 is represented as:
140       // %val = bitcast <ReduxWidth x i1> to iReduxWidth
141       // %res = cmp ne iReduxWidth %val, 0
142       // And reduction for i1 is represented as:
143       // %val = bitcast <ReduxWidth x i1> to iReduxWidth
144       // %res = cmp eq iReduxWidth %val, 11111
145       Value *Vec = II->getArgOperand(0);
146       auto *FTy = cast<FixedVectorType>(Vec->getType());
147       unsigned NumElts = FTy->getNumElements();
148       if (!isPowerOf2_32(NumElts))
149         continue;
150 
151       if (FTy->getElementType() == Builder.getInt1Ty()) {
152         Rdx = Builder.CreateBitCast(Vec, Builder.getIntNTy(NumElts));
153         if (ID == Intrinsic::vector_reduce_and) {
154           Rdx = Builder.CreateICmpEQ(
155               Rdx, ConstantInt::getAllOnesValue(Rdx->getType()));
156         } else {
157           assert(ID == Intrinsic::vector_reduce_or && "Expected or reduction.");
158           Rdx = Builder.CreateIsNotNull(Rdx);
159         }
160         break;
161       }
162 
163       Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
164       break;
165     }
166     case Intrinsic::vector_reduce_add:
167     case Intrinsic::vector_reduce_mul:
168     case Intrinsic::vector_reduce_xor:
169     case Intrinsic::vector_reduce_smax:
170     case Intrinsic::vector_reduce_smin:
171     case Intrinsic::vector_reduce_umax:
172     case Intrinsic::vector_reduce_umin: {
173       Value *Vec = II->getArgOperand(0);
174       if (!isPowerOf2_32(
175               cast<FixedVectorType>(Vec->getType())->getNumElements()))
176         continue;
177 
178       Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
179       break;
180     }
181     case Intrinsic::vector_reduce_fmax:
182     case Intrinsic::vector_reduce_fmin: {
183       // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
184       // semantics of the reduction.
185       Value *Vec = II->getArgOperand(0);
186       if (!isPowerOf2_32(
187               cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
188           !FMF.noNaNs())
189         continue;
190 
191       Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
192       break;
193     }
194     }
195     II->replaceAllUsesWith(Rdx);
196     II->eraseFromParent();
197     Changed = true;
198   }
199   return Changed;
200 }
201 
202 class ExpandReductions : public FunctionPass {
203 public:
204   static char ID;
205   ExpandReductions() : FunctionPass(ID) {
206     initializeExpandReductionsPass(*PassRegistry::getPassRegistry());
207   }
208 
209   bool runOnFunction(Function &F) override {
210     const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
211     return expandReductions(F, TTI);
212   }
213 
214   void getAnalysisUsage(AnalysisUsage &AU) const override {
215     AU.addRequired<TargetTransformInfoWrapperPass>();
216     AU.setPreservesCFG();
217   }
218 };
219 }
220 
221 char ExpandReductions::ID;
222 INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
223                       "Expand reduction intrinsics", false, false)
224 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
225 INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
226                     "Expand reduction intrinsics", false, false)
227 
228 FunctionPass *llvm::createExpandReductionsPass() {
229   return new ExpandReductions();
230 }
231 
232 PreservedAnalyses ExpandReductionsPass::run(Function &F,
233                                             FunctionAnalysisManager &AM) {
234   const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
235   if (!expandReductions(F, &TTI))
236     return PreservedAnalyses::all();
237   PreservedAnalyses PA;
238   PA.preserveSet<CFGAnalyses>();
239   return PA;
240 }
241