//===----- RISCVCodeGenPrepare.cpp ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a RISC-V specific version of CodeGenPrepare.
// It munges the code in the input function to better prepare it for
// SelectionDAG-based code generation. This works around limitations in its
// basic-block-at-a-time approach.
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;

#define DEBUG_TYPE "riscv-codegenprepare"
#define PASS_NAME "RISC-V CodeGenPrepare"

namespace {

class RISCVCodeGenPrepare : public FunctionPass,
                            public InstVisitor<RISCVCodeGenPrepare, bool> {
  const DataLayout *DL;
  const DominatorTree *DT;
  const RISCVSubtarget *ST;

public:
  static char ID;

  RISCVCodeGenPrepare() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override { return PASS_NAME; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetPassConfig>();
  }

  bool visitInstruction(Instruction &I) { return false; }
  bool visitAnd(BinaryOperator &BO);
  bool visitIntrinsicInst(IntrinsicInst &I);
  bool expandVPStrideLoad(IntrinsicInst &I);
  bool widenVPMerge(IntrinsicInst &I);
};

} // end anonymous namespace

// Try to optimize (i64 (and (zext/sext (i32 X), C1))) if C1 has bit 31 set,
// but bits 63:32 are zero. If we know that bit 31 of X is 0, we can fill
// the upper 32 bits with ones.
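//
// For example (an IR sketch, not taken from the pass itself; the constant is
// chosen so it fits in 32 bits but not simm12, while its sign-extended form
// does):
//
// %a = zext nneg i32 %x to i64
// %b = and i64 %a, 4294965248    ; 0xFFFFF800, too big for an ANDI immediate
// ->
// %a = zext nneg i32 %x to i64
// %b = and i64 %a, -2048         ; simm12, so a single ANDI can be selected
//
// Because the zext is nneg, bit 31 of %x is known zero, so both constants
// yield the same result.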
bool RISCVCodeGenPrepare::visitAnd(BinaryOperator &BO) {
  if (!ST->is64Bit())
    return false;

  if (!BO.getType()->isIntegerTy(64))
    return false;

  using namespace PatternMatch;

  // Left hand side should be a zext nneg.
  Value *LHSSrc;
  if (!match(BO.getOperand(0), m_NNegZExt(m_Value(LHSSrc))))
    return false;

  if (!LHSSrc->getType()->isIntegerTy(32))
    return false;

  // Right hand side should be a constant.
  Value *RHS = BO.getOperand(1);

  auto *CI = dyn_cast<ConstantInt>(RHS);
  if (!CI)
    return false;
  uint64_t C = CI->getZExtValue();

  // Look for constants that fit in 32 bits but not simm12, and can be made
  // into simm12 by sign extending bit 31. This will allow use of ANDI.
  // TODO: Is it worth making simm32?
  if (!isUInt<32>(C) || isInt<12>(C) || !isInt<12>(SignExtend64<32>(C)))
    return false;

  // Sign extend the constant and replace the And operand.
  C = SignExtend64<32>(C);
  BO.setOperand(1, ConstantInt::get(RHS->getType(), C));

  return true;
}

// With EVL tail folding, an AnyOf reduction will generate an i1 vp.merge like
// the following:
//
// loop:
// %phi = phi <vscale x 4 x i1> [ zeroinitializer, %entry ], [ %rec, %loop ]
// %cmp = icmp ...
// %rec = call <vscale x 4 x i1> @llvm.vp.merge(%cmp, i1 true, %phi, %evl)
// ...
// middle:
// %res = call i1 @llvm.vector.reduce.or(<vscale x 4 x i1> %rec)
//
// However, RVV doesn't have any tail-undisturbed mask instructions, so we need
// a convoluted sequence of mask instructions to lower the i1 vp.merge: see
// llvm/test/CodeGen/RISCV/rvv/vpmerge-sdnode.ll.
//
// To avoid that, this widens the i1 vp.merge to an i8 vp.merge, which will
// generate a single vmerge.vim:
//
// loop:
// %phi = phi <vscale x 4 x i8> [ zeroinitializer, %entry ], [ %rec, %loop ]
// %cmp = icmp ...
// %rec = call <vscale x 4 x i8> @llvm.vp.merge(%cmp, i8 true, %phi, %evl)
// %trunc = trunc <vscale x 4 x i8> %rec to <vscale x 4 x i1>
// ...
// middle:
// %res = call i1 @llvm.vector.reduce.or(<vscale x 4 x i1> %trunc)
//
// The trunc will normally be sunk outside of the loop, but even if there are
// users inside the loop it is still profitable.
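//
// For reference, the widened vp.merge in the loop can then be selected as a
// single tail-undisturbed vmerge.vim along these lines (an assembly sketch;
// register assignments are assumed, with the EVL in a0, the phi in v8 and the
// mask in v0):
//
// vsetvli zero, a0, e8, m1, tu, ma
// vmerge.vim v8, v8, 1, v0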
bool RISCVCodeGenPrepare::widenVPMerge(IntrinsicInst &II) {
  if (!II.getType()->getScalarType()->isIntegerTy(1))
    return false;

  Value *Mask, *True, *PhiV, *EVL;
  using namespace PatternMatch;
  if (!match(&II,
             m_Intrinsic<Intrinsic::vp_merge>(m_Value(Mask), m_Value(True),
                                              m_Value(PhiV), m_Value(EVL))))
    return false;

  auto *Phi = dyn_cast<PHINode>(PhiV);
  if (!Phi || !Phi->hasOneUse() || Phi->getNumIncomingValues() != 2 ||
      !match(Phi->getIncomingValue(0), m_Zero()) ||
      Phi->getIncomingValue(1) != &II)
    return false;

  Type *WideTy =
      VectorType::get(IntegerType::getInt8Ty(II.getContext()),
                      cast<VectorType>(II.getType())->getElementCount());

  IRBuilder<> Builder(Phi);
  PHINode *WidePhi = Builder.CreatePHI(WideTy, 2);
  WidePhi->addIncoming(ConstantAggregateZero::get(WideTy),
                       Phi->getIncomingBlock(0));
  Builder.SetInsertPoint(&II);
  Value *WideTrue = Builder.CreateZExt(True, WideTy);
  Value *WideMerge = Builder.CreateIntrinsic(Intrinsic::vp_merge, {WideTy},
                                             {Mask, WideTrue, WidePhi, EVL});
  WidePhi->addIncoming(WideMerge, Phi->getIncomingBlock(1));
  Value *Trunc = Builder.CreateTrunc(WideMerge, II.getType());

  II.replaceAllUsesWith(Trunc);

  // Break the cycle and delete the old chain.
  Phi->setIncomingValue(1, Phi->getIncomingValue(0));
  llvm::RecursivelyDeleteTriviallyDeadInstructions(&II);

  return true;
}

// LLVM vector reduction intrinsics return a scalar result, but on RISC-V,
// vector reduction instructions write the result to the first element of a
// vector register. So when a reduction in a loop uses a scalar phi, we end up
// with unnecessary scalar moves:
//
// loop:
// vfmv.s.f v10, fa0
// vfredosum.vs v8, v8, v10
// vfmv.f.s fa0, v8
//
// This mainly affects ordered fadd reductions and VP reductions that have a
// scalar start value, since other types of reduction typically use
// element-wise vectorisation in the loop body. This tries to vectorize any
// scalar phis that feed into these reductions:
//
// loop:
// %phi = phi float [ ..., %entry ], [ %acc, %loop ]
// %acc = call float @llvm.vector.reduce.fadd.nxv2f32(float %phi,
//                                                    <vscale x 2 x float> %vec)
//
// ->
//
// loop:
// %phi = phi <vscale x 2 x float> [ ..., %entry ], [ %acc.vec, %loop ]
// %phi.scalar = extractelement <vscale x 2 x float> %phi, i64 0
// %acc = call float @llvm.vector.reduce.fadd.nxv2f32(float %phi.scalar,
//                                                    <vscale x 2 x float> %vec)
// %acc.vec = insertelement <vscale x 2 x float> poison, float %acc, i64 0
//
// This eliminates the scalar -> vector -> scalar crossing during instruction
// selection.
bool RISCVCodeGenPrepare::visitIntrinsicInst(IntrinsicInst &I) {
  if (expandVPStrideLoad(I))
    return true;

  if (widenVPMerge(I))
    return true;

  if (I.getIntrinsicID() != Intrinsic::vector_reduce_fadd &&
      !isa<VPReductionIntrinsic>(&I))
    return false;

  auto *PHI = dyn_cast<PHINode>(I.getOperand(0));
  if (!PHI || !PHI->hasOneUse() ||
      !llvm::is_contained(PHI->incoming_values(), &I))
    return false;

  Type *VecTy = I.getOperand(1)->getType();
  IRBuilder<> Builder(PHI);
  auto *VecPHI = Builder.CreatePHI(VecTy, PHI->getNumIncomingValues());

  for (auto *BB : PHI->blocks()) {
    Builder.SetInsertPoint(BB->getTerminator());
    Value *InsertElt = Builder.CreateInsertElement(
        VecTy, PHI->getIncomingValueForBlock(BB), (uint64_t)0);
    VecPHI->addIncoming(InsertElt, BB);
  }

  Builder.SetInsertPoint(&I);
  I.setOperand(0, Builder.CreateExtractElement(VecPHI, (uint64_t)0));

  PHI->eraseFromParent();

  return true;
}

// Always expand zero strided loads so we match more .vx splat patterns, even
// if we have +optimized-zero-stride-loads. RISCVDAGToDAGISel::Select will
// convert it back to a strided load if it's optimized.
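//
// For example (an IR sketch; the types and value names are illustrative):
//
// %v = call <vscale x 2 x i32> @llvm.experimental.vp.strided.load.nxv2i32.p0.i64(
//          ptr %p, i64 0, <vscale x 2 x i1> splat (i1 true), i32 %evl)
// ->
// %x = load i32, ptr %p
// %v = call <vscale x 2 x i32> @llvm.experimental.vp.splat.nxv2i32(
//          i32 %x, <vscale x 2 x i1> splat (i1 true), i32 %evl)
//
// The expansion is guarded by the isKnownNonZero(VL) check below: with an EVL
// of zero the strided load reads nothing, whereas the expanded scalar load
// would still dereference %p.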
bool RISCVCodeGenPrepare::expandVPStrideLoad(IntrinsicInst &II) {
  Value *BasePtr, *VL;

  using namespace PatternMatch;
  if (!match(&II, m_Intrinsic<Intrinsic::experimental_vp_strided_load>(
                      m_Value(BasePtr), m_Zero(), m_AllOnes(), m_Value(VL))))
    return false;

  // If SEW>XLEN then a splat will get lowered as a zero strided load anyway,
  // so avoid expanding here.
  if (II.getType()->getScalarSizeInBits() > ST->getXLen())
    return false;

  if (!isKnownNonZero(VL, {*DL, DT, nullptr, &II}))
    return false;

  auto *VTy = cast<VectorType>(II.getType());

  IRBuilder<> Builder(&II);
  Type *STy = VTy->getElementType();
  Value *Val = Builder.CreateLoad(STy, BasePtr);
  Value *Res = Builder.CreateIntrinsic(Intrinsic::experimental_vp_splat, {VTy},
                                       {Val, II.getOperand(2), VL});

  II.replaceAllUsesWith(Res);
  II.eraseFromParent();
  return true;
}

bool RISCVCodeGenPrepare::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);

  DL = &F.getDataLayout();
  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  bool MadeChange = false;
  for (auto &BB : F)
    for (Instruction &I : llvm::make_early_inc_range(BB))
      MadeChange |= visit(I);

  return MadeChange;
}

INITIALIZE_PASS_BEGIN(RISCVCodeGenPrepare, DEBUG_TYPE, PASS_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(RISCVCodeGenPrepare, DEBUG_TYPE, PASS_NAME, false, false)

char RISCVCodeGenPrepare::ID = 0;

FunctionPass *llvm::createRISCVCodeGenPreparePass() {
  return new RISCVCodeGenPrepare();
}
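
// To inspect this pass's output in isolation, the pass name registered above
// ("riscv-codegenprepare", i.e. DEBUG_TYPE) can be given to llc's -stop-after
// option (an example invocation, assuming an RVV-enabled target; exact flags
// may vary by build):
//
//   llc -mtriple=riscv64 -mattr=+v -stop-after=riscv-codegenprepare in.ll -o -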