1 //===---- EVLIndVarSimplify.cpp - Optimize vectorized loops w/ EVL IV------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass optimizes a vectorized loop with canonical IV to using EVL-based 10 // IV if it was tail-folded by predicated EVL. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Transforms/Vectorize/EVLIndVarSimplify.h" 15 #include "llvm/ADT/Statistic.h" 16 #include "llvm/Analysis/IVDescriptors.h" 17 #include "llvm/Analysis/LoopInfo.h" 18 #include "llvm/Analysis/LoopPass.h" 19 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 20 #include "llvm/Analysis/ScalarEvolution.h" 21 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 22 #include "llvm/Analysis/ValueTracking.h" 23 #include "llvm/IR/IRBuilder.h" 24 #include "llvm/IR/PatternMatch.h" 25 #include "llvm/Support/CommandLine.h" 26 #include "llvm/Support/Debug.h" 27 #include "llvm/Support/MathExtras.h" 28 #include "llvm/Support/raw_ostream.h" 29 #include "llvm/Transforms/Scalar/LoopPassManager.h" 30 #include "llvm/Transforms/Utils/Local.h" 31 32 #define DEBUG_TYPE "evl-iv-simplify" 33 34 using namespace llvm; 35 36 STATISTIC(NumEliminatedCanonicalIV, "Number of canonical IVs we eliminated"); 37 38 static cl::opt<bool> EnableEVLIndVarSimplify( 39 "enable-evl-indvar-simplify", 40 cl::desc("Enable EVL-based induction variable simplify Pass"), cl::Hidden, 41 cl::init(true)); 42 43 namespace { 44 struct EVLIndVarSimplifyImpl { 45 ScalarEvolution &SE; 46 OptimizationRemarkEmitter *ORE = nullptr; 47 48 EVLIndVarSimplifyImpl(LoopStandardAnalysisResults &LAR, 49 OptimizationRemarkEmitter *ORE) 50 : SE(LAR.SE), ORE(ORE) {} 51 52 /// Returns true if modify the loop. 53 bool run(Loop &L); 54 }; 55 } // anonymous namespace 56 57 /// Returns the constant part of vectorization factor from the induction 58 /// variable's step value SCEV expression. 59 static uint32_t getVFFromIndVar(const SCEV *Step, const Function &F) { 60 if (!Step) 61 return 0U; 62 63 // Looking for loops with IV step value in the form of `(<constant VF> x 64 // vscale)`. 65 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Step)) { 66 if (Mul->getNumOperands() == 2) { 67 const SCEV *LHS = Mul->getOperand(0); 68 const SCEV *RHS = Mul->getOperand(1); 69 if (const auto *Const = dyn_cast<SCEVConstant>(LHS); 70 Const && isa<SCEVVScale>(RHS)) { 71 uint64_t V = Const->getAPInt().getLimitedValue(); 72 if (llvm::isUInt<32>(V)) 73 return V; 74 } 75 } 76 } 77 78 // If not, see if the vscale_range of the parent function is a fixed value, 79 // which makes the step value to be replaced by a constant. 80 if (F.hasFnAttribute(Attribute::VScaleRange)) 81 if (const auto *ConstStep = dyn_cast<SCEVConstant>(Step)) { 82 APInt V = ConstStep->getAPInt().abs(); 83 ConstantRange CR = llvm::getVScaleRange(&F, 64); 84 if (const APInt *Fixed = CR.getSingleElement()) { 85 V = V.zextOrTrunc(Fixed->getBitWidth()); 86 uint64_t VF = V.udiv(*Fixed).getLimitedValue(); 87 if (VF && llvm::isUInt<32>(VF) && 88 // Make sure step is divisible by vscale. 89 V.urem(*Fixed).isZero()) 90 return VF; 91 } 92 } 93 94 return 0U; 95 } 96 97 bool EVLIndVarSimplifyImpl::run(Loop &L) { 98 if (!EnableEVLIndVarSimplify) 99 return false; 100 101 if (!getBooleanLoopAttribute(&L, "llvm.loop.isvectorized")) 102 return false; 103 const MDOperand *EVLMD = 104 findStringMetadataForLoop(&L, "llvm.loop.isvectorized.tailfoldingstyle") 105 .value_or(nullptr); 106 if (!EVLMD || !EVLMD->equalsStr("evl")) 107 return false; 108 109 BasicBlock *LatchBlock = L.getLoopLatch(); 110 ICmpInst *OrigLatchCmp = L.getLatchCmpInst(); 111 if (!LatchBlock || !OrigLatchCmp) 112 return false; 113 114 InductionDescriptor IVD; 115 PHINode *IndVar = L.getInductionVariable(SE); 116 if (!IndVar || !L.getInductionDescriptor(SE, IVD)) { 117 const char *Reason = (IndVar ? "induction descriptor is not available" 118 : "cannot recognize induction variable"); 119 LLVM_DEBUG(dbgs() << "Cannot retrieve IV from loop " << L.getName() 120 << " because" << Reason << "\n"); 121 if (ORE) { 122 ORE->emit([&]() { 123 return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar", 124 L.getStartLoc(), L.getHeader()) 125 << "Cannot retrieve IV because " << ore::NV("Reason", Reason); 126 }); 127 } 128 return false; 129 } 130 131 BasicBlock *InitBlock, *BackEdgeBlock; 132 if (!L.getIncomingAndBackEdge(InitBlock, BackEdgeBlock)) { 133 LLVM_DEBUG(dbgs() << "Expect unique incoming and backedge in " 134 << L.getName() << "\n"); 135 if (ORE) { 136 ORE->emit([&]() { 137 return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure", 138 L.getStartLoc(), L.getHeader()) 139 << "Does not have a unique incoming and backedge"; 140 }); 141 } 142 return false; 143 } 144 145 // Retrieve the loop bounds. 146 std::optional<Loop::LoopBounds> Bounds = L.getBounds(SE); 147 if (!Bounds) { 148 LLVM_DEBUG(dbgs() << "Could not obtain the bounds for loop " << L.getName() 149 << "\n"); 150 if (ORE) { 151 ORE->emit([&]() { 152 return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure", 153 L.getStartLoc(), L.getHeader()) 154 << "Could not obtain the loop bounds"; 155 }); 156 } 157 return false; 158 } 159 Value *CanonicalIVInit = &Bounds->getInitialIVValue(); 160 Value *CanonicalIVFinal = &Bounds->getFinalIVValue(); 161 162 const SCEV *StepV = IVD.getStep(); 163 uint32_t VF = getVFFromIndVar(StepV, *L.getHeader()->getParent()); 164 if (!VF) { 165 LLVM_DEBUG(dbgs() << "Could not infer VF from IndVar step '" << *StepV 166 << "'\n"); 167 if (ORE) { 168 ORE->emit([&]() { 169 return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar", 170 L.getStartLoc(), L.getHeader()) 171 << "Could not infer VF from IndVar step " 172 << ore::NV("Step", StepV); 173 }); 174 } 175 return false; 176 } 177 LLVM_DEBUG(dbgs() << "Using VF=" << VF << " for loop " << L.getName() 178 << "\n"); 179 180 // Try to find the EVL-based induction variable. 181 using namespace PatternMatch; 182 BasicBlock *BB = IndVar->getParent(); 183 184 Value *EVLIndVar = nullptr; 185 Value *RemTC = nullptr; 186 Value *TC = nullptr; 187 auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>( 188 m_Value(RemTC), m_SpecificInt(VF), 189 /*Scalable=*/m_SpecificInt(1)); 190 for (PHINode &PN : BB->phis()) { 191 if (&PN == IndVar) 192 continue; 193 194 // Check 1: it has to contain both incoming (init) & backedge blocks 195 // from IndVar. 196 if (PN.getBasicBlockIndex(InitBlock) < 0 || 197 PN.getBasicBlockIndex(BackEdgeBlock) < 0) 198 continue; 199 // Check 2: EVL index is always increasing, thus its inital value has to be 200 // equal to either the initial IV value (when the canonical IV is also 201 // increasing) or the last IV value (when canonical IV is decreasing). 202 Value *Init = PN.getIncomingValueForBlock(InitBlock); 203 using Direction = Loop::LoopBounds::Direction; 204 switch (Bounds->getDirection()) { 205 case Direction::Increasing: 206 if (Init != CanonicalIVInit) 207 continue; 208 break; 209 case Direction::Decreasing: 210 if (Init != CanonicalIVFinal) 211 continue; 212 break; 213 case Direction::Unknown: 214 // To be more permissive and see if either the initial or final IV value 215 // matches PN's init value. 216 if (Init != CanonicalIVInit && Init != CanonicalIVFinal) 217 continue; 218 break; 219 } 220 Value *RecValue = PN.getIncomingValueForBlock(BackEdgeBlock); 221 assert(RecValue && "expect recurrent IndVar value"); 222 223 LLVM_DEBUG(dbgs() << "Found candidate PN of EVL-based IndVar: " << PN 224 << "\n"); 225 226 // Check 3: Pattern match to find the EVL-based index and total trip count 227 // (TC). 228 if (match(RecValue, 229 m_c_Add(m_ZExtOrSelf(IntrinsicMatch), m_Specific(&PN))) && 230 match(RemTC, m_Sub(m_Value(TC), m_Specific(&PN)))) { 231 EVLIndVar = RecValue; 232 break; 233 } 234 } 235 236 if (!EVLIndVar || !TC) 237 return false; 238 239 LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n"); 240 if (ORE) { 241 ORE->emit([&]() { 242 DebugLoc DL; 243 BasicBlock *Region = nullptr; 244 if (auto *I = dyn_cast<Instruction>(EVLIndVar)) { 245 DL = I->getDebugLoc(); 246 Region = I->getParent(); 247 } else { 248 DL = L.getStartLoc(); 249 Region = L.getHeader(); 250 } 251 return OptimizationRemark(DEBUG_TYPE, "UseEVLIndVar", DL, Region) 252 << "Using " << ore::NV("EVLIndVar", EVLIndVar) 253 << " for EVL-based IndVar"; 254 }); 255 } 256 257 // Create an EVL-based comparison and replace the branch to use it as 258 // predicate. 259 260 // Loop::getLatchCmpInst check at the beginning of this function has ensured 261 // that latch block ends in a conditional branch. 262 auto *LatchBranch = cast<BranchInst>(LatchBlock->getTerminator()); 263 assert(LatchBranch->isConditional() && 264 "expect the loop latch to be ended with a conditional branch"); 265 ICmpInst::Predicate Pred; 266 if (LatchBranch->getSuccessor(0) == L.getHeader()) 267 Pred = ICmpInst::ICMP_NE; 268 else 269 Pred = ICmpInst::ICMP_EQ; 270 271 IRBuilder<> Builder(OrigLatchCmp); 272 auto *NewLatchCmp = Builder.CreateICmp(Pred, EVLIndVar, TC); 273 OrigLatchCmp->replaceAllUsesWith(NewLatchCmp); 274 275 // llvm::RecursivelyDeleteDeadPHINode only deletes cycles whose values are 276 // not used outside the cycles. However, in this case the now-RAUW-ed 277 // OrigLatchCmp will be considered a use outside the cycle while in reality 278 // it's practically dead. Thus we need to remove it before calling 279 // RecursivelyDeleteDeadPHINode. 280 (void)RecursivelyDeleteTriviallyDeadInstructions(OrigLatchCmp); 281 if (llvm::RecursivelyDeleteDeadPHINode(IndVar)) 282 LLVM_DEBUG(dbgs() << "Removed original IndVar\n"); 283 284 ++NumEliminatedCanonicalIV; 285 286 return true; 287 } 288 289 PreservedAnalyses EVLIndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &LAM, 290 LoopStandardAnalysisResults &AR, 291 LPMUpdater &U) { 292 Function &F = *L.getHeader()->getParent(); 293 auto &FAMProxy = LAM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR); 294 OptimizationRemarkEmitter *ORE = 295 FAMProxy.getCachedResult<OptimizationRemarkEmitterAnalysis>(F); 296 297 if (EVLIndVarSimplifyImpl(AR, ORE).run(L)) 298 return PreservedAnalyses::allInSet<CFGAnalyses>(); 299 return PreservedAnalyses::all(); 300 } 301