1 //===- ExpandReductions.cpp - Expand reduction intrinsics -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass implements IR expansion for reduction intrinsics, allowing targets 10 // to enable the intrinsics until just before codegen. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/CodeGen/ExpandReductions.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/CodeGen/Passes.h" 17 #include "llvm/IR/IRBuilder.h" 18 #include "llvm/IR/InstIterator.h" 19 #include "llvm/IR/IntrinsicInst.h" 20 #include "llvm/IR/Intrinsics.h" 21 #include "llvm/InitializePasses.h" 22 #include "llvm/Pass.h" 23 #include "llvm/Transforms/Utils/LoopUtils.h" 24 25 using namespace llvm; 26 27 namespace { 28 29 bool expandReductions(Function &F, const TargetTransformInfo *TTI) { 30 bool Changed = false; 31 SmallVector<IntrinsicInst *, 4> Worklist; 32 for (auto &I : instructions(F)) { 33 if (auto *II = dyn_cast<IntrinsicInst>(&I)) { 34 switch (II->getIntrinsicID()) { 35 default: break; 36 case Intrinsic::vector_reduce_fadd: 37 case Intrinsic::vector_reduce_fmul: 38 case Intrinsic::vector_reduce_add: 39 case Intrinsic::vector_reduce_mul: 40 case Intrinsic::vector_reduce_and: 41 case Intrinsic::vector_reduce_or: 42 case Intrinsic::vector_reduce_xor: 43 case Intrinsic::vector_reduce_smax: 44 case Intrinsic::vector_reduce_smin: 45 case Intrinsic::vector_reduce_umax: 46 case Intrinsic::vector_reduce_umin: 47 case Intrinsic::vector_reduce_fmax: 48 case Intrinsic::vector_reduce_fmin: 49 if (TTI->shouldExpandReduction(II)) 50 Worklist.push_back(II); 51 52 break; 53 } 54 } 55 } 56 57 for (auto *II : Worklist) { 58 FastMathFlags FMF = 59 isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{}; 60 Intrinsic::ID ID = II->getIntrinsicID(); 61 RecurKind RK = getMinMaxReductionRecurKind(ID); 62 TargetTransformInfo::ReductionShuffle RS = 63 TTI->getPreferredExpandedReductionShuffle(II); 64 65 Value *Rdx = nullptr; 66 IRBuilder<> Builder(II); 67 IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); 68 Builder.setFastMathFlags(FMF); 69 switch (ID) { 70 default: llvm_unreachable("Unexpected intrinsic!"); 71 case Intrinsic::vector_reduce_fadd: 72 case Intrinsic::vector_reduce_fmul: { 73 // FMFs must be attached to the call, otherwise it's an ordered reduction 74 // and it can't be handled by generating a shuffle sequence. 75 Value *Acc = II->getArgOperand(0); 76 Value *Vec = II->getArgOperand(1); 77 unsigned RdxOpcode = getArithmeticReductionInstruction(ID); 78 if (!FMF.allowReassoc()) 79 Rdx = getOrderedReduction(Builder, Acc, Vec, RdxOpcode, RK); 80 else { 81 if (!isPowerOf2_32( 82 cast<FixedVectorType>(Vec->getType())->getNumElements())) 83 continue; 84 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK); 85 Rdx = Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, Acc, Rdx, 86 "bin.rdx"); 87 } 88 break; 89 } 90 case Intrinsic::vector_reduce_and: 91 case Intrinsic::vector_reduce_or: { 92 // Canonicalize logical or/and reductions: 93 // Or reduction for i1 is represented as: 94 // %val = bitcast <ReduxWidth x i1> to iReduxWidth 95 // %res = cmp ne iReduxWidth %val, 0 96 // And reduction for i1 is represented as: 97 // %val = bitcast <ReduxWidth x i1> to iReduxWidth 98 // %res = cmp eq iReduxWidth %val, 11111 99 Value *Vec = II->getArgOperand(0); 100 auto *FTy = cast<FixedVectorType>(Vec->getType()); 101 unsigned NumElts = FTy->getNumElements(); 102 if (!isPowerOf2_32(NumElts)) 103 continue; 104 105 if (FTy->getElementType() == Builder.getInt1Ty()) { 106 Rdx = Builder.CreateBitCast(Vec, Builder.getIntNTy(NumElts)); 107 if (ID == Intrinsic::vector_reduce_and) { 108 Rdx = Builder.CreateICmpEQ( 109 Rdx, ConstantInt::getAllOnesValue(Rdx->getType())); 110 } else { 111 assert(ID == Intrinsic::vector_reduce_or && "Expected or reduction."); 112 Rdx = Builder.CreateIsNotNull(Rdx); 113 } 114 break; 115 } 116 unsigned RdxOpcode = getArithmeticReductionInstruction(ID); 117 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK); 118 break; 119 } 120 case Intrinsic::vector_reduce_add: 121 case Intrinsic::vector_reduce_mul: 122 case Intrinsic::vector_reduce_xor: 123 case Intrinsic::vector_reduce_smax: 124 case Intrinsic::vector_reduce_smin: 125 case Intrinsic::vector_reduce_umax: 126 case Intrinsic::vector_reduce_umin: { 127 Value *Vec = II->getArgOperand(0); 128 if (!isPowerOf2_32( 129 cast<FixedVectorType>(Vec->getType())->getNumElements())) 130 continue; 131 unsigned RdxOpcode = getArithmeticReductionInstruction(ID); 132 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK); 133 break; 134 } 135 case Intrinsic::vector_reduce_fmax: 136 case Intrinsic::vector_reduce_fmin: { 137 // We require "nnan" to use a shuffle reduction; "nsz" is implied by the 138 // semantics of the reduction. 139 Value *Vec = II->getArgOperand(0); 140 if (!isPowerOf2_32( 141 cast<FixedVectorType>(Vec->getType())->getNumElements()) || 142 !FMF.noNaNs()) 143 continue; 144 unsigned RdxOpcode = getArithmeticReductionInstruction(ID); 145 Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK); 146 break; 147 } 148 } 149 II->replaceAllUsesWith(Rdx); 150 II->eraseFromParent(); 151 Changed = true; 152 } 153 return Changed; 154 } 155 156 class ExpandReductions : public FunctionPass { 157 public: 158 static char ID; 159 ExpandReductions() : FunctionPass(ID) { 160 initializeExpandReductionsPass(*PassRegistry::getPassRegistry()); 161 } 162 163 bool runOnFunction(Function &F) override { 164 const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 165 return expandReductions(F, TTI); 166 } 167 168 void getAnalysisUsage(AnalysisUsage &AU) const override { 169 AU.addRequired<TargetTransformInfoWrapperPass>(); 170 AU.setPreservesCFG(); 171 } 172 }; 173 } 174 175 char ExpandReductions::ID; 176 INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions", 177 "Expand reduction intrinsics", false, false) 178 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) 179 INITIALIZE_PASS_END(ExpandReductions, "expand-reductions", 180 "Expand reduction intrinsics", false, false) 181 182 FunctionPass *llvm::createExpandReductionsPass() { 183 return new ExpandReductions(); 184 } 185 186 PreservedAnalyses ExpandReductionsPass::run(Function &F, 187 FunctionAnalysisManager &AM) { 188 const auto &TTI = AM.getResult<TargetIRAnalysis>(F); 189 if (!expandReductions(F, &TTI)) 190 return PreservedAnalyses::all(); 191 PreservedAnalyses PA; 192 PA.preserveSet<CFGAnalyses>(); 193 return PA; 194 } 195