1 //===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //===----------------------------------------------------------------------===//
9
10 #include "llvm/Transforms/Scalar/LoopTermFold.h"
11 #include "llvm/ADT/Statistic.h"
12 #include "llvm/Analysis/LoopAnalysisManager.h"
13 #include "llvm/Analysis/LoopInfo.h"
14 #include "llvm/Analysis/LoopPass.h"
15 #include "llvm/Analysis/MemorySSA.h"
16 #include "llvm/Analysis/MemorySSAUpdater.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
19 #include "llvm/Analysis/TargetLibraryInfo.h"
20 #include "llvm/Analysis/TargetTransformInfo.h"
21 #include "llvm/Analysis/ValueTracking.h"
22 #include "llvm/Config/llvm-config.h"
23 #include "llvm/IR/BasicBlock.h"
24 #include "llvm/IR/Dominators.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/InstrTypes.h"
27 #include "llvm/IR/Instruction.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/Type.h"
30 #include "llvm/IR/Value.h"
31 #include "llvm/InitializePasses.h"
32 #include "llvm/Pass.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include "llvm/Transforms/Scalar.h"
36 #include "llvm/Transforms/Utils.h"
37 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
38 #include "llvm/Transforms/Utils/Local.h"
39 #include "llvm/Transforms/Utils/LoopUtils.h"
40 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
41 #include <cassert>
42 #include <optional>
43
44 using namespace llvm;
45
46 #define DEBUG_TYPE "loop-term-fold"
47
48 STATISTIC(NumTermFold,
49 "Number of terminating condition fold recognized and performed");
50
51 static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
canFoldTermCondOfLoop(Loop * L,ScalarEvolution & SE,DominatorTree & DT,const LoopInfo & LI,const TargetTransformInfo & TTI)52 canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
53 const LoopInfo &LI, const TargetTransformInfo &TTI) {
54 if (!L->isInnermost()) {
55 LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
56 return std::nullopt;
57 }
58 // Only inspect on simple loop structure
59 if (!L->isLoopSimplifyForm()) {
60 LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
61 return std::nullopt;
62 }
63
64 if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
65 LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
66 return std::nullopt;
67 }
68
69 BasicBlock *LoopLatch = L->getLoopLatch();
70 BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
71 if (!BI || BI->isUnconditional())
72 return std::nullopt;
73 auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
74 if (!TermCond) {
75 LLVM_DEBUG(
76 dbgs() << "Cannot fold on branching condition that is not an ICmpInst");
77 return std::nullopt;
78 }
79 if (!TermCond->hasOneUse()) {
80 LLVM_DEBUG(
81 dbgs()
82 << "Cannot replace terminating condition with more than one use\n");
83 return std::nullopt;
84 }
85
86 BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
87 Value *RHS = TermCond->getOperand(1);
88 if (!LHS || !L->isLoopInvariant(RHS))
89 // We could pattern match the inverse form of the icmp, but that is
90 // non-canonical, and this pass is running *very* late in the pipeline.
91 return std::nullopt;
92
93 // Find the IV used by the current exit condition.
94 PHINode *ToFold;
95 Value *ToFoldStart, *ToFoldStep;
96 if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
97 return std::nullopt;
98
99 // Ensure the simple recurrence is a part of the current loop.
100 if (ToFold->getParent() != L->getHeader())
101 return std::nullopt;
102
103 // If that IV isn't dead after we rewrite the exit condition in terms of
104 // another IV, there's no point in doing the transform.
105 if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
106 return std::nullopt;
107
108 // Inserting instructions in the preheader has a runtime cost, scale
109 // the allowed cost with the loops trip count as best we can.
110 const unsigned ExpansionBudget = [&]() {
111 unsigned Budget = 2 * SCEVCheapExpansionBudget;
112 if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
113 return std::min(Budget, SmallTC);
114 if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
115 return std::min(Budget, *SmallTC);
116 // Unknown trip count, assume long running by default.
117 return Budget;
118 }();
119
120 const SCEV *BECount = SE.getBackedgeTakenCount(L);
121 const DataLayout &DL = L->getHeader()->getDataLayout();
122 SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
123
124 PHINode *ToHelpFold = nullptr;
125 const SCEV *TermValueS = nullptr;
126 bool MustDropPoison = false;
127 auto InsertPt = L->getLoopPreheader()->getTerminator();
128 for (PHINode &PN : L->getHeader()->phis()) {
129 if (ToFold == &PN)
130 continue;
131
132 if (!SE.isSCEVable(PN.getType())) {
133 LLVM_DEBUG(dbgs() << "IV of phi '" << PN
134 << "' is not SCEV-able, not qualified for the "
135 "terminating condition folding.\n");
136 continue;
137 }
138 const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
139 // Only speculate on affine AddRec
140 if (!AddRec || !AddRec->isAffine()) {
141 LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
142 << "' is not an affine add recursion, not qualified "
143 "for the terminating condition folding.\n");
144 continue;
145 }
146
147 // Check that we can compute the value of AddRec on the exiting iteration
148 // without soundness problems. evaluateAtIteration internally needs
149 // to multiply the stride of the iteration number - which may wrap around.
150 // The issue here is subtle because computing the result accounting for
151 // wrap is insufficient. In order to use the result in an exit test, we
152 // must also know that AddRec doesn't take the same value on any previous
153 // iteration. The simplest case to consider is a candidate IV which is
154 // narrower than the trip count (and thus original IV), but this can
155 // also happen due to non-unit strides on the candidate IVs.
156 if (!AddRec->hasNoSelfWrap() ||
157 !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
158 continue;
159
160 const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
161 const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
162 if (!Expander.isSafeToExpand(TermValueSLocal)) {
163 LLVM_DEBUG(
164 dbgs() << "Is not safe to expand terminating value for phi node" << PN
165 << "\n");
166 continue;
167 }
168
169 if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI,
170 InsertPt)) {
171 LLVM_DEBUG(
172 dbgs() << "Is too expensive to expand terminating value for phi node"
173 << PN << "\n");
174 continue;
175 }
176
177 // The candidate IV may have been otherwise dead and poison from the
178 // very first iteration. If we can't disprove that, we can't use the IV.
179 if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
180 LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
181 continue;
182 }
183
184 // The candidate IV may become poison on the last iteration. If this
185 // value is not branched on, this is a well defined program. We're
186 // about to add a new use to this IV, and we have to ensure we don't
187 // insert UB which didn't previously exist.
188 bool MustDropPoisonLocal = false;
189 Instruction *PostIncV =
190 cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
191 if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
192 &DT)) {
193 LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
194 << "\n");
195
196 // If this is a complex recurrance with multiple instructions computing
197 // the backedge value, we might need to strip poison flags from all of
198 // them.
199 if (PostIncV->getOperand(0) != &PN)
200 continue;
201
202 // In order to perform the transform, we need to drop the poison
203 // generating flags on this instruction (if any).
204 MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
205 }
206
207 // We pick the last legal alternate IV. We could expore choosing an optimal
208 // alternate IV if we had a decent heuristic to do so.
209 ToHelpFold = &PN;
210 TermValueS = TermValueSLocal;
211 MustDropPoison = MustDropPoisonLocal;
212 }
213
214 LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
215 << "Cannot find other AddRec IV to help folding\n";);
216
217 LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()
218 << "\nFound loop that can fold terminating condition\n"
219 << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"
220 << " TermCond: " << *TermCond << "\n"
221 << " BrandInst: " << *BI << "\n"
222 << " ToFold: " << *ToFold << "\n"
223 << " ToHelpFold: " << *ToHelpFold << "\n");
224
225 if (!ToFold || !ToHelpFold)
226 return std::nullopt;
227 return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
228 }
229
RunTermFold(Loop * L,ScalarEvolution & SE,DominatorTree & DT,LoopInfo & LI,const TargetTransformInfo & TTI,TargetLibraryInfo & TLI,MemorySSA * MSSA)230 static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
231 LoopInfo &LI, const TargetTransformInfo &TTI,
232 TargetLibraryInfo &TLI, MemorySSA *MSSA) {
233 std::unique_ptr<MemorySSAUpdater> MSSAU;
234 if (MSSA)
235 MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
236
237 auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
238 if (!Opt)
239 return false;
240
241 auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
242
243 NumTermFold++;
244
245 BasicBlock *LoopPreheader = L->getLoopPreheader();
246 BasicBlock *LoopLatch = L->getLoopLatch();
247
248 (void)ToFold;
249 LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
250 << *ToFold << "\n"
251 << "New term-cond phi-node:\n"
252 << *ToHelpFold << "\n");
253
254 Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
255 (void)StartValue;
256 Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
257
258 // See comment in canFoldTermCondOfLoop on why this is sufficient.
259 if (MustDrop)
260 cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
261
262 // SCEVExpander for both use in preheader and latch
263 const DataLayout &DL = L->getHeader()->getDataLayout();
264 SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
265
266 assert(Expander.isSafeToExpand(TermValueS) &&
267 "Terminating value was checked safe in canFoldTerminatingCondition");
268
269 // Create new terminating value at loop preheader
270 Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
271 LoopPreheader->getTerminator());
272
273 LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
274 << *StartValue << "\n"
275 << "Terminating value of new term-cond phi-node:\n"
276 << *TermValue << "\n");
277
278 // Create new terminating condition at loop latch
279 BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
280 ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
281 IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
282 Value *NewTermCond =
283 LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
284 "lsr_fold_term_cond.replaced_term_cond");
285 // Swap successors to exit loop body if IV equals to new TermValue
286 if (BI->getSuccessor(0) == L->getHeader())
287 BI->swapSuccessors();
288
289 LLVM_DEBUG(dbgs() << "Old term-cond:\n"
290 << *OldTermCond << "\n"
291 << "New term-cond:\n"
292 << *NewTermCond << "\n");
293
294 BI->setCondition(NewTermCond);
295
296 Expander.clear();
297 OldTermCond->eraseFromParent();
298 DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
299 return true;
300 }
301
302 namespace {
303
304 class LoopTermFold : public LoopPass {
305 public:
306 static char ID; // Pass ID, replacement for typeid
307
308 LoopTermFold();
309
310 private:
311 bool runOnLoop(Loop *L, LPPassManager &LPM) override;
312 void getAnalysisUsage(AnalysisUsage &AU) const override;
313 };
314
315 } // end anonymous namespace
316
LoopTermFold()317 LoopTermFold::LoopTermFold() : LoopPass(ID) {
318 initializeLoopTermFoldPass(*PassRegistry::getPassRegistry());
319 }
320
getAnalysisUsage(AnalysisUsage & AU) const321 void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
322 AU.addRequired<LoopInfoWrapperPass>();
323 AU.addPreserved<LoopInfoWrapperPass>();
324 AU.addPreservedID(LoopSimplifyID);
325 AU.addRequiredID(LoopSimplifyID);
326 AU.addRequired<DominatorTreeWrapperPass>();
327 AU.addPreserved<DominatorTreeWrapperPass>();
328 AU.addRequired<ScalarEvolutionWrapperPass>();
329 AU.addPreserved<ScalarEvolutionWrapperPass>();
330 AU.addRequired<TargetLibraryInfoWrapperPass>();
331 AU.addRequired<TargetTransformInfoWrapperPass>();
332 AU.addPreserved<MemorySSAWrapperPass>();
333 }
334
runOnLoop(Loop * L,LPPassManager &)335 bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
336 if (skipLoop(L))
337 return false;
338
339 auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
340 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
341 auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
342 const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
343 *L->getHeader()->getParent());
344 auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
345 *L->getHeader()->getParent());
346 auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
347 MemorySSA *MSSA = nullptr;
348 if (MSSAAnalysis)
349 MSSA = &MSSAAnalysis->getMSSA();
350 return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
351 }
352
run(Loop & L,LoopAnalysisManager & AM,LoopStandardAnalysisResults & AR,LPMUpdater &)353 PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM,
354 LoopStandardAnalysisResults &AR,
355 LPMUpdater &) {
356 if (!RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA))
357 return PreservedAnalyses::all();
358
359 auto PA = getLoopPassPreservedAnalyses();
360 if (AR.MSSA)
361 PA.preserve<MemorySSAAnalysis>();
362 return PA;
363 }
364
365 char LoopTermFold::ID = 0;
366
367 INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
368 false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)369 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
370 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
371 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
372 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
373 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
374 INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
375 false, false)
376
377 Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); }
378