xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/ExpandReductions.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===- ExpandReductions.cpp - Expand reduction intrinsics -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements IR expansion for reduction intrinsics, allowing targets
10 // to enable the intrinsics until just before codegen.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/ExpandReductions.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/CodeGen/Passes.h"
17 #include "llvm/IR/IRBuilder.h"
18 #include "llvm/IR/InstIterator.h"
19 #include "llvm/IR/IntrinsicInst.h"
20 #include "llvm/IR/Intrinsics.h"
21 #include "llvm/InitializePasses.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Transforms/Utils/LoopUtils.h"
24 
25 using namespace llvm;
26 
27 namespace {
28 
29 bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
30   bool Changed = false;
31   SmallVector<IntrinsicInst *, 4> Worklist;
32   for (auto &I : instructions(F)) {
33     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
34       switch (II->getIntrinsicID()) {
35       default: break;
36       case Intrinsic::vector_reduce_fadd:
37       case Intrinsic::vector_reduce_fmul:
38       case Intrinsic::vector_reduce_add:
39       case Intrinsic::vector_reduce_mul:
40       case Intrinsic::vector_reduce_and:
41       case Intrinsic::vector_reduce_or:
42       case Intrinsic::vector_reduce_xor:
43       case Intrinsic::vector_reduce_smax:
44       case Intrinsic::vector_reduce_smin:
45       case Intrinsic::vector_reduce_umax:
46       case Intrinsic::vector_reduce_umin:
47       case Intrinsic::vector_reduce_fmax:
48       case Intrinsic::vector_reduce_fmin:
49         if (TTI->shouldExpandReduction(II))
50           Worklist.push_back(II);
51 
52         break;
53       }
54     }
55   }
56 
57   for (auto *II : Worklist) {
58     FastMathFlags FMF =
59         isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
60     Intrinsic::ID ID = II->getIntrinsicID();
61     RecurKind RK = getMinMaxReductionRecurKind(ID);
62     TargetTransformInfo::ReductionShuffle RS =
63         TTI->getPreferredExpandedReductionShuffle(II);
64 
65     Value *Rdx = nullptr;
66     IRBuilder<> Builder(II);
67     IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
68     Builder.setFastMathFlags(FMF);
69     switch (ID) {
70     default: llvm_unreachable("Unexpected intrinsic!");
71     case Intrinsic::vector_reduce_fadd:
72     case Intrinsic::vector_reduce_fmul: {
73       // FMFs must be attached to the call, otherwise it's an ordered reduction
74       // and it can't be handled by generating a shuffle sequence.
75       Value *Acc = II->getArgOperand(0);
76       Value *Vec = II->getArgOperand(1);
77       unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
78       if (!FMF.allowReassoc())
79         Rdx = getOrderedReduction(Builder, Acc, Vec, RdxOpcode, RK);
80       else {
81         if (!isPowerOf2_32(
82                 cast<FixedVectorType>(Vec->getType())->getNumElements()))
83           continue;
84         Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
85         Rdx = Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, Acc, Rdx,
86                                   "bin.rdx");
87       }
88       break;
89     }
90     case Intrinsic::vector_reduce_and:
91     case Intrinsic::vector_reduce_or: {
92       // Canonicalize logical or/and reductions:
93       // Or reduction for i1 is represented as:
94       // %val = bitcast <ReduxWidth x i1> to iReduxWidth
95       // %res = cmp ne iReduxWidth %val, 0
96       // And reduction for i1 is represented as:
97       // %val = bitcast <ReduxWidth x i1> to iReduxWidth
98       // %res = cmp eq iReduxWidth %val, 11111
99       Value *Vec = II->getArgOperand(0);
100       auto *FTy = cast<FixedVectorType>(Vec->getType());
101       unsigned NumElts = FTy->getNumElements();
102       if (!isPowerOf2_32(NumElts))
103         continue;
104 
105       if (FTy->getElementType() == Builder.getInt1Ty()) {
106         Rdx = Builder.CreateBitCast(Vec, Builder.getIntNTy(NumElts));
107         if (ID == Intrinsic::vector_reduce_and) {
108           Rdx = Builder.CreateICmpEQ(
109               Rdx, ConstantInt::getAllOnesValue(Rdx->getType()));
110         } else {
111           assert(ID == Intrinsic::vector_reduce_or && "Expected or reduction.");
112           Rdx = Builder.CreateIsNotNull(Rdx);
113         }
114         break;
115       }
116       unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
117       Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
118       break;
119     }
120     case Intrinsic::vector_reduce_add:
121     case Intrinsic::vector_reduce_mul:
122     case Intrinsic::vector_reduce_xor:
123     case Intrinsic::vector_reduce_smax:
124     case Intrinsic::vector_reduce_smin:
125     case Intrinsic::vector_reduce_umax:
126     case Intrinsic::vector_reduce_umin: {
127       Value *Vec = II->getArgOperand(0);
128       if (!isPowerOf2_32(
129               cast<FixedVectorType>(Vec->getType())->getNumElements()))
130         continue;
131       unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
132       Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
133       break;
134     }
135     case Intrinsic::vector_reduce_fmax:
136     case Intrinsic::vector_reduce_fmin: {
137       // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
138       // semantics of the reduction.
139       Value *Vec = II->getArgOperand(0);
140       if (!isPowerOf2_32(
141               cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
142           !FMF.noNaNs())
143         continue;
144       unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
145       Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
146       break;
147     }
148     }
149     II->replaceAllUsesWith(Rdx);
150     II->eraseFromParent();
151     Changed = true;
152   }
153   return Changed;
154 }
155 
156 class ExpandReductions : public FunctionPass {
157 public:
158   static char ID;
159   ExpandReductions() : FunctionPass(ID) {
160     initializeExpandReductionsPass(*PassRegistry::getPassRegistry());
161   }
162 
163   bool runOnFunction(Function &F) override {
164     const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
165     return expandReductions(F, TTI);
166   }
167 
168   void getAnalysisUsage(AnalysisUsage &AU) const override {
169     AU.addRequired<TargetTransformInfoWrapperPass>();
170     AU.setPreservesCFG();
171   }
172 };
173 }
174 
175 char ExpandReductions::ID;
176 INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
177                       "Expand reduction intrinsics", false, false)
178 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
179 INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
180                     "Expand reduction intrinsics", false, false)
181 
182 FunctionPass *llvm::createExpandReductionsPass() {
183   return new ExpandReductions();
184 }
185 
186 PreservedAnalyses ExpandReductionsPass::run(Function &F,
187                                             FunctionAnalysisManager &AM) {
188   const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
189   if (!expandReductions(F, &TTI))
190     return PreservedAnalyses::all();
191   PreservedAnalyses PA;
192   PA.preserveSet<CFGAnalyses>();
193   return PA;
194 }
195