xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Vectorize/EVLIndVarSimplify.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===---- EVLIndVarSimplify.cpp - Optimize vectorized loops w/ EVL IV------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass optimizes a vectorized loop with canonical IV to using EVL-based
10 // IV if it was tail-folded by predicated EVL.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Transforms/Vectorize/EVLIndVarSimplify.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/IVDescriptors.h"
17 #include "llvm/Analysis/LoopInfo.h"
18 #include "llvm/Analysis/LoopPass.h"
19 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
20 #include "llvm/Analysis/ScalarEvolution.h"
21 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
22 #include "llvm/Analysis/ValueTracking.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/PatternMatch.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/MathExtras.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/Transforms/Scalar/LoopPassManager.h"
30 #include "llvm/Transforms/Utils/Local.h"
31 
#define DEBUG_TYPE "evl-iv-simplify"

using namespace llvm;

// Counts loops whose canonical IV was successfully replaced by the EVL-based
// IV (reported via -stats).
STATISTIC(NumEliminatedCanonicalIV, "Number of canonical IVs we eliminated");

// Hidden escape hatch to turn the transformation off; enabled by default.
static cl::opt<bool> EnableEVLIndVarSimplify(
    "enable-evl-indvar-simplify",
    cl::desc("Enable EVL-based induction variable simplify Pass"), cl::Hidden,
    cl::init(true));
namespace {
/// Shared implementation of the EVL induction-variable simplification,
/// invoked by the new-pass-manager wrapper at the bottom of this file.
struct EVLIndVarSimplifyImpl {
  ScalarEvolution &SE;
  /// Optional remark emitter; may stay null when no cached analysis result
  /// exists for the parent function, in which case no remarks are emitted.
  OptimizationRemarkEmitter *ORE = nullptr;

  EVLIndVarSimplifyImpl(LoopStandardAnalysisResults &LAR,
                        OptimizationRemarkEmitter *ORE)
      : SE(LAR.SE), ORE(ORE) {}

  /// Attempt the rewrite on \p L. Returns true if the loop was modified.
  bool run(Loop &L);
};
} // anonymous namespace
56 
57 /// Returns the constant part of vectorization factor from the induction
58 /// variable's step value SCEV expression.
59 static uint32_t getVFFromIndVar(const SCEV *Step, const Function &F) {
60   if (!Step)
61     return 0U;
62 
63   // Looking for loops with IV step value in the form of `(<constant VF> x
64   // vscale)`.
65   if (const auto *Mul = dyn_cast<SCEVMulExpr>(Step)) {
66     if (Mul->getNumOperands() == 2) {
67       const SCEV *LHS = Mul->getOperand(0);
68       const SCEV *RHS = Mul->getOperand(1);
69       if (const auto *Const = dyn_cast<SCEVConstant>(LHS);
70           Const && isa<SCEVVScale>(RHS)) {
71         uint64_t V = Const->getAPInt().getLimitedValue();
72         if (llvm::isUInt<32>(V))
73           return V;
74       }
75     }
76   }
77 
78   // If not, see if the vscale_range of the parent function is a fixed value,
79   // which makes the step value to be replaced by a constant.
80   if (F.hasFnAttribute(Attribute::VScaleRange))
81     if (const auto *ConstStep = dyn_cast<SCEVConstant>(Step)) {
82       APInt V = ConstStep->getAPInt().abs();
83       ConstantRange CR = llvm::getVScaleRange(&F, 64);
84       if (const APInt *Fixed = CR.getSingleElement()) {
85         V = V.zextOrTrunc(Fixed->getBitWidth());
86         uint64_t VF = V.udiv(*Fixed).getLimitedValue();
87         if (VF && llvm::isUInt<32>(VF) &&
88             // Make sure step is divisible by vscale.
89             V.urem(*Fixed).isZero())
90           return VF;
91       }
92     }
93 
94   return 0U;
95 }
96 
97 bool EVLIndVarSimplifyImpl::run(Loop &L) {
98   if (!EnableEVLIndVarSimplify)
99     return false;
100 
101   if (!getBooleanLoopAttribute(&L, "llvm.loop.isvectorized"))
102     return false;
103   const MDOperand *EVLMD =
104       findStringMetadataForLoop(&L, "llvm.loop.isvectorized.tailfoldingstyle")
105           .value_or(nullptr);
106   if (!EVLMD || !EVLMD->equalsStr("evl"))
107     return false;
108 
109   BasicBlock *LatchBlock = L.getLoopLatch();
110   ICmpInst *OrigLatchCmp = L.getLatchCmpInst();
111   if (!LatchBlock || !OrigLatchCmp)
112     return false;
113 
114   InductionDescriptor IVD;
115   PHINode *IndVar = L.getInductionVariable(SE);
116   if (!IndVar || !L.getInductionDescriptor(SE, IVD)) {
117     const char *Reason = (IndVar ? "induction descriptor is not available"
118                                  : "cannot recognize induction variable");
119     LLVM_DEBUG(dbgs() << "Cannot retrieve IV from loop " << L.getName()
120                       << " because" << Reason << "\n");
121     if (ORE) {
122       ORE->emit([&]() {
123         return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar",
124                                         L.getStartLoc(), L.getHeader())
125                << "Cannot retrieve IV because " << ore::NV("Reason", Reason);
126       });
127     }
128     return false;
129   }
130 
131   BasicBlock *InitBlock, *BackEdgeBlock;
132   if (!L.getIncomingAndBackEdge(InitBlock, BackEdgeBlock)) {
133     LLVM_DEBUG(dbgs() << "Expect unique incoming and backedge in "
134                       << L.getName() << "\n");
135     if (ORE) {
136       ORE->emit([&]() {
137         return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure",
138                                         L.getStartLoc(), L.getHeader())
139                << "Does not have a unique incoming and backedge";
140       });
141     }
142     return false;
143   }
144 
145   // Retrieve the loop bounds.
146   std::optional<Loop::LoopBounds> Bounds = L.getBounds(SE);
147   if (!Bounds) {
148     LLVM_DEBUG(dbgs() << "Could not obtain the bounds for loop " << L.getName()
149                       << "\n");
150     if (ORE) {
151       ORE->emit([&]() {
152         return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure",
153                                         L.getStartLoc(), L.getHeader())
154                << "Could not obtain the loop bounds";
155       });
156     }
157     return false;
158   }
159   Value *CanonicalIVInit = &Bounds->getInitialIVValue();
160   Value *CanonicalIVFinal = &Bounds->getFinalIVValue();
161 
162   const SCEV *StepV = IVD.getStep();
163   uint32_t VF = getVFFromIndVar(StepV, *L.getHeader()->getParent());
164   if (!VF) {
165     LLVM_DEBUG(dbgs() << "Could not infer VF from IndVar step '" << *StepV
166                       << "'\n");
167     if (ORE) {
168       ORE->emit([&]() {
169         return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar",
170                                         L.getStartLoc(), L.getHeader())
171                << "Could not infer VF from IndVar step "
172                << ore::NV("Step", StepV);
173       });
174     }
175     return false;
176   }
177   LLVM_DEBUG(dbgs() << "Using VF=" << VF << " for loop " << L.getName()
178                     << "\n");
179 
180   // Try to find the EVL-based induction variable.
181   using namespace PatternMatch;
182   BasicBlock *BB = IndVar->getParent();
183 
184   Value *EVLIndVar = nullptr;
185   Value *RemTC = nullptr;
186   Value *TC = nullptr;
187   auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>(
188       m_Value(RemTC), m_SpecificInt(VF),
189       /*Scalable=*/m_SpecificInt(1));
190   for (PHINode &PN : BB->phis()) {
191     if (&PN == IndVar)
192       continue;
193 
194     // Check 1: it has to contain both incoming (init) & backedge blocks
195     // from IndVar.
196     if (PN.getBasicBlockIndex(InitBlock) < 0 ||
197         PN.getBasicBlockIndex(BackEdgeBlock) < 0)
198       continue;
199     // Check 2: EVL index is always increasing, thus its inital value has to be
200     // equal to either the initial IV value (when the canonical IV is also
201     // increasing) or the last IV value (when canonical IV is decreasing).
202     Value *Init = PN.getIncomingValueForBlock(InitBlock);
203     using Direction = Loop::LoopBounds::Direction;
204     switch (Bounds->getDirection()) {
205     case Direction::Increasing:
206       if (Init != CanonicalIVInit)
207         continue;
208       break;
209     case Direction::Decreasing:
210       if (Init != CanonicalIVFinal)
211         continue;
212       break;
213     case Direction::Unknown:
214       // To be more permissive and see if either the initial or final IV value
215       // matches PN's init value.
216       if (Init != CanonicalIVInit && Init != CanonicalIVFinal)
217         continue;
218       break;
219     }
220     Value *RecValue = PN.getIncomingValueForBlock(BackEdgeBlock);
221     assert(RecValue && "expect recurrent IndVar value");
222 
223     LLVM_DEBUG(dbgs() << "Found candidate PN of EVL-based IndVar: " << PN
224                       << "\n");
225 
226     // Check 3: Pattern match to find the EVL-based index and total trip count
227     // (TC).
228     if (match(RecValue,
229               m_c_Add(m_ZExtOrSelf(IntrinsicMatch), m_Specific(&PN))) &&
230         match(RemTC, m_Sub(m_Value(TC), m_Specific(&PN)))) {
231       EVLIndVar = RecValue;
232       break;
233     }
234   }
235 
236   if (!EVLIndVar || !TC)
237     return false;
238 
239   LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n");
240   if (ORE) {
241     ORE->emit([&]() {
242       DebugLoc DL;
243       BasicBlock *Region = nullptr;
244       if (auto *I = dyn_cast<Instruction>(EVLIndVar)) {
245         DL = I->getDebugLoc();
246         Region = I->getParent();
247       } else {
248         DL = L.getStartLoc();
249         Region = L.getHeader();
250       }
251       return OptimizationRemark(DEBUG_TYPE, "UseEVLIndVar", DL, Region)
252              << "Using " << ore::NV("EVLIndVar", EVLIndVar)
253              << " for EVL-based IndVar";
254     });
255   }
256 
257   // Create an EVL-based comparison and replace the branch to use it as
258   // predicate.
259 
260   // Loop::getLatchCmpInst check at the beginning of this function has ensured
261   // that latch block ends in a conditional branch.
262   auto *LatchBranch = cast<BranchInst>(LatchBlock->getTerminator());
263   assert(LatchBranch->isConditional() &&
264          "expect the loop latch to be ended with a conditional branch");
265   ICmpInst::Predicate Pred;
266   if (LatchBranch->getSuccessor(0) == L.getHeader())
267     Pred = ICmpInst::ICMP_NE;
268   else
269     Pred = ICmpInst::ICMP_EQ;
270 
271   IRBuilder<> Builder(OrigLatchCmp);
272   auto *NewLatchCmp = Builder.CreateICmp(Pred, EVLIndVar, TC);
273   OrigLatchCmp->replaceAllUsesWith(NewLatchCmp);
274 
275   // llvm::RecursivelyDeleteDeadPHINode only deletes cycles whose values are
276   // not used outside the cycles. However, in this case the now-RAUW-ed
277   // OrigLatchCmp will be considered a use outside the cycle while in reality
278   // it's practically dead. Thus we need to remove it before calling
279   // RecursivelyDeleteDeadPHINode.
280   (void)RecursivelyDeleteTriviallyDeadInstructions(OrigLatchCmp);
281   if (llvm::RecursivelyDeleteDeadPHINode(IndVar))
282     LLVM_DEBUG(dbgs() << "Removed original IndVar\n");
283 
284   ++NumEliminatedCanonicalIV;
285 
286   return true;
287 }
288 
289 PreservedAnalyses EVLIndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &LAM,
290                                              LoopStandardAnalysisResults &AR,
291                                              LPMUpdater &U) {
292   Function &F = *L.getHeader()->getParent();
293   auto &FAMProxy = LAM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR);
294   OptimizationRemarkEmitter *ORE =
295       FAMProxy.getCachedResult<OptimizationRemarkEmitterAnalysis>(F);
296 
297   if (EVLIndVarSimplifyImpl(AR, ORE).run(L))
298     return PreservedAnalyses::allInSet<CFGAnalyses>();
299   return PreservedAnalyses::all();
300 }
301