xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopTermFold.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //===----------------------------------------------------------------------===//
9 
10 #include "llvm/Transforms/Scalar/LoopTermFold.h"
11 #include "llvm/ADT/Statistic.h"
12 #include "llvm/Analysis/LoopAnalysisManager.h"
13 #include "llvm/Analysis/LoopInfo.h"
14 #include "llvm/Analysis/LoopPass.h"
15 #include "llvm/Analysis/MemorySSA.h"
16 #include "llvm/Analysis/MemorySSAUpdater.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
19 #include "llvm/Analysis/TargetLibraryInfo.h"
20 #include "llvm/Analysis/TargetTransformInfo.h"
21 #include "llvm/Analysis/ValueTracking.h"
22 #include "llvm/Config/llvm-config.h"
23 #include "llvm/IR/BasicBlock.h"
24 #include "llvm/IR/Dominators.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/InstrTypes.h"
27 #include "llvm/IR/Instruction.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/Type.h"
30 #include "llvm/IR/Value.h"
31 #include "llvm/InitializePasses.h"
32 #include "llvm/Pass.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include "llvm/Transforms/Scalar.h"
36 #include "llvm/Transforms/Utils.h"
37 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
38 #include "llvm/Transforms/Utils/Local.h"
39 #include "llvm/Transforms/Utils/LoopUtils.h"
40 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
41 #include <cassert>
42 #include <optional>
43 
44 using namespace llvm;
45 
46 #define DEBUG_TYPE "loop-term-fold"
47 
// Statistic counter for this pass. It is bumped once in RunTermFold each time
// canFoldTermCondOfLoop reports a foldable loop and the rewrite goes ahead.
STATISTIC(NumTermFold,
          "Number of terminating condition fold recognized and performed");
50 
51 static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
canFoldTermCondOfLoop(Loop * L,ScalarEvolution & SE,DominatorTree & DT,const LoopInfo & LI,const TargetTransformInfo & TTI)52 canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
53                       const LoopInfo &LI, const TargetTransformInfo &TTI) {
54   if (!L->isInnermost()) {
55     LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
56     return std::nullopt;
57   }
58   // Only inspect on simple loop structure
59   if (!L->isLoopSimplifyForm()) {
60     LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
61     return std::nullopt;
62   }
63 
64   if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
65     LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
66     return std::nullopt;
67   }
68 
69   BasicBlock *LoopLatch = L->getLoopLatch();
70   BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
71   if (!BI || BI->isUnconditional())
72     return std::nullopt;
73   auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
74   if (!TermCond) {
75     LLVM_DEBUG(
76         dbgs() << "Cannot fold on branching condition that is not an ICmpInst");
77     return std::nullopt;
78   }
79   if (!TermCond->hasOneUse()) {
80     LLVM_DEBUG(
81         dbgs()
82         << "Cannot replace terminating condition with more than one use\n");
83     return std::nullopt;
84   }
85 
86   BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
87   Value *RHS = TermCond->getOperand(1);
88   if (!LHS || !L->isLoopInvariant(RHS))
89     // We could pattern match the inverse form of the icmp, but that is
90     // non-canonical, and this pass is running *very* late in the pipeline.
91     return std::nullopt;
92 
93   // Find the IV used by the current exit condition.
94   PHINode *ToFold;
95   Value *ToFoldStart, *ToFoldStep;
96   if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
97     return std::nullopt;
98 
99   // Ensure the simple recurrence is a part of the current loop.
100   if (ToFold->getParent() != L->getHeader())
101     return std::nullopt;
102 
103   // If that IV isn't dead after we rewrite the exit condition in terms of
104   // another IV, there's no point in doing the transform.
105   if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
106     return std::nullopt;
107 
108   // Inserting instructions in the preheader has a runtime cost, scale
109   // the allowed cost with the loops trip count as best we can.
110   const unsigned ExpansionBudget = [&]() {
111     unsigned Budget = 2 * SCEVCheapExpansionBudget;
112     if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
113       return std::min(Budget, SmallTC);
114     if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
115       return std::min(Budget, *SmallTC);
116     // Unknown trip count, assume long running by default.
117     return Budget;
118   }();
119 
120   const SCEV *BECount = SE.getBackedgeTakenCount(L);
121   const DataLayout &DL = L->getHeader()->getDataLayout();
122   SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
123 
124   PHINode *ToHelpFold = nullptr;
125   const SCEV *TermValueS = nullptr;
126   bool MustDropPoison = false;
127   auto InsertPt = L->getLoopPreheader()->getTerminator();
128   for (PHINode &PN : L->getHeader()->phis()) {
129     if (ToFold == &PN)
130       continue;
131 
132     if (!SE.isSCEVable(PN.getType())) {
133       LLVM_DEBUG(dbgs() << "IV of phi '" << PN
134                         << "' is not SCEV-able, not qualified for the "
135                            "terminating condition folding.\n");
136       continue;
137     }
138     const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
139     // Only speculate on affine AddRec
140     if (!AddRec || !AddRec->isAffine()) {
141       LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
142                         << "' is not an affine add recursion, not qualified "
143                            "for the terminating condition folding.\n");
144       continue;
145     }
146 
147     // Check that we can compute the value of AddRec on the exiting iteration
148     // without soundness problems.  evaluateAtIteration internally needs
149     // to multiply the stride of the iteration number - which may wrap around.
150     // The issue here is subtle because computing the result accounting for
151     // wrap is insufficient. In order to use the result in an exit test, we
152     // must also know that AddRec doesn't take the same value on any previous
153     // iteration. The simplest case to consider is a candidate IV which is
154     // narrower than the trip count (and thus original IV), but this can
155     // also happen due to non-unit strides on the candidate IVs.
156     if (!AddRec->hasNoSelfWrap() ||
157         !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
158       continue;
159 
160     const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
161     const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
162     if (!Expander.isSafeToExpand(TermValueSLocal)) {
163       LLVM_DEBUG(
164           dbgs() << "Is not safe to expand terminating value for phi node" << PN
165                  << "\n");
166       continue;
167     }
168 
169     if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI,
170                                      InsertPt)) {
171       LLVM_DEBUG(
172           dbgs() << "Is too expensive to expand terminating value for phi node"
173                  << PN << "\n");
174       continue;
175     }
176 
177     // The candidate IV may have been otherwise dead and poison from the
178     // very first iteration.  If we can't disprove that, we can't use the IV.
179     if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
180       LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
181       continue;
182     }
183 
184     // The candidate IV may become poison on the last iteration.  If this
185     // value is not branched on, this is a well defined program.  We're
186     // about to add a new use to this IV, and we have to ensure we don't
187     // insert UB which didn't previously exist.
188     bool MustDropPoisonLocal = false;
189     Instruction *PostIncV =
190         cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
191     if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
192                                        &DT)) {
193       LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
194                         << "\n");
195 
196       // If this is a complex recurrance with multiple instructions computing
197       // the backedge value, we might need to strip poison flags from all of
198       // them.
199       if (PostIncV->getOperand(0) != &PN)
200         continue;
201 
202       // In order to perform the transform, we need to drop the poison
203       // generating flags on this instruction (if any).
204       MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
205     }
206 
207     // We pick the last legal alternate IV.  We could expore choosing an optimal
208     // alternate IV if we had a decent heuristic to do so.
209     ToHelpFold = &PN;
210     TermValueS = TermValueSLocal;
211     MustDropPoison = MustDropPoisonLocal;
212   }
213 
214   LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
215                  << "Cannot find other AddRec IV to help folding\n";);
216 
217   LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()
218              << "\nFound loop that can fold terminating condition\n"
219              << "  BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"
220              << "  TermCond: " << *TermCond << "\n"
221              << "  BrandInst: " << *BI << "\n"
222              << "  ToFold: " << *ToFold << "\n"
223              << "  ToHelpFold: " << *ToHelpFold << "\n");
224 
225   if (!ToFold || !ToHelpFold)
226     return std::nullopt;
227   return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
228 }
229 
RunTermFold(Loop * L,ScalarEvolution & SE,DominatorTree & DT,LoopInfo & LI,const TargetTransformInfo & TTI,TargetLibraryInfo & TLI,MemorySSA * MSSA)230 static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
231                         LoopInfo &LI, const TargetTransformInfo &TTI,
232                         TargetLibraryInfo &TLI, MemorySSA *MSSA) {
233   std::unique_ptr<MemorySSAUpdater> MSSAU;
234   if (MSSA)
235     MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
236 
237   auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
238   if (!Opt)
239     return false;
240 
241   auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
242 
243   NumTermFold++;
244 
245   BasicBlock *LoopPreheader = L->getLoopPreheader();
246   BasicBlock *LoopLatch = L->getLoopLatch();
247 
248   (void)ToFold;
249   LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
250                     << *ToFold << "\n"
251                     << "New term-cond phi-node:\n"
252                     << *ToHelpFold << "\n");
253 
254   Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
255   (void)StartValue;
256   Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
257 
258   // See comment in canFoldTermCondOfLoop on why this is sufficient.
259   if (MustDrop)
260     cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
261 
262   // SCEVExpander for both use in preheader and latch
263   const DataLayout &DL = L->getHeader()->getDataLayout();
264   SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
265 
266   assert(Expander.isSafeToExpand(TermValueS) &&
267          "Terminating value was checked safe in canFoldTerminatingCondition");
268 
269   // Create new terminating value at loop preheader
270   Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
271                                             LoopPreheader->getTerminator());
272 
273   LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
274                     << *StartValue << "\n"
275                     << "Terminating value of new term-cond phi-node:\n"
276                     << *TermValue << "\n");
277 
278   // Create new terminating condition at loop latch
279   BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
280   ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
281   IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
282   Value *NewTermCond =
283       LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
284                               "lsr_fold_term_cond.replaced_term_cond");
285   // Swap successors to exit loop body if IV equals to new TermValue
286   if (BI->getSuccessor(0) == L->getHeader())
287     BI->swapSuccessors();
288 
289   LLVM_DEBUG(dbgs() << "Old term-cond:\n"
290                     << *OldTermCond << "\n"
291                     << "New term-cond:\n"
292                     << *NewTermCond << "\n");
293 
294   BI->setCondition(NewTermCond);
295 
296   Expander.clear();
297   OldTermCond->eraseFromParent();
298   DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
299   return true;
300 }
301 
302 namespace {
303 
304 class LoopTermFold : public LoopPass {
305 public:
306   static char ID; // Pass ID, replacement for typeid
307 
308   LoopTermFold();
309 
310 private:
311   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
312   void getAnalysisUsage(AnalysisUsage &AU) const override;
313 };
314 
315 } // end anonymous namespace
316 
LoopTermFold()317 LoopTermFold::LoopTermFold() : LoopPass(ID) {
318   initializeLoopTermFoldPass(*PassRegistry::getPassRegistry());
319 }
320 
/// Declare the analyses the legacy pass depends on and those it keeps valid.
///
/// LoopInfo, LoopSimplify, DominatorTree and ScalarEvolution are both
/// required and preserved. TargetLibraryInfo and TargetTransformInfo are
/// required only. MemorySSA is preserved but not required; runOnLoop picks
/// it up through getAnalysisIfAvailable when it happens to be present.
void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addPreserved<LoopInfoWrapperPass>();
  AU.addPreservedID(LoopSimplifyID);
  AU.addRequiredID(LoopSimplifyID);
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addPreserved<DominatorTreeWrapperPass>();
  AU.addRequired<ScalarEvolutionWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  AU.addRequired<TargetLibraryInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
  // Preserved only: the pass never asks for MemorySSA to be computed.
  AU.addPreserved<MemorySSAWrapperPass>();
}
334 
runOnLoop(Loop * L,LPPassManager &)335 bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
336   if (skipLoop(L))
337     return false;
338 
339   auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
340   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
341   auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
342   const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
343       *L->getHeader()->getParent());
344   auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
345       *L->getHeader()->getParent());
346   auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
347   MemorySSA *MSSA = nullptr;
348   if (MSSAAnalysis)
349     MSSA = &MSSAAnalysis->getMSSA();
350   return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
351 }
352 
run(Loop & L,LoopAnalysisManager & AM,LoopStandardAnalysisResults & AR,LPMUpdater &)353 PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM,
354                                         LoopStandardAnalysisResults &AR,
355                                         LPMUpdater &) {
356   if (!RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA))
357     return PreservedAnalyses::all();
358 
359   auto PA = getLoopPassPreservedAnalyses();
360   if (AR.MSSA)
361     PA.preserve<MemorySSAAnalysis>();
362   return PA;
363 }
364 
365 char LoopTermFold::ID = 0;
366 
367 INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
368                       false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)369 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
370 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
371 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
372 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
373 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
374 INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
375                     false, false)
376 
377 Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); }
378