xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp (revision 52d973f52c07b94909a6487be373c269988dc151)
1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10 #include "llvm/Analysis/LoopAccessAnalysis.h"
11 #include "llvm/Analysis/LoopAnalysisManager.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/LoopIterator.h"
14 #include "llvm/Analysis/LoopPass.h"
15 #include "llvm/Analysis/MemorySSA.h"
16 #include "llvm/Analysis/MemorySSAUpdater.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
19 #include "llvm/IR/PatternMatch.h"
20 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
21 #include "llvm/Transforms/Utils/Cloning.h"
22 #include "llvm/Transforms/Utils/LoopSimplify.h"
23 #include "llvm/Transforms/Utils/LoopUtils.h"
24 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
25 
26 #define DEBUG_TYPE "loop-bound-split"
27 
28 namespace llvm {
29 
30 using namespace PatternMatch;
31 
32 namespace {
33 struct ConditionInfo {
34   /// Branch instruction with this condition
35   BranchInst *BI;
36   /// ICmp instruction with this condition
37   ICmpInst *ICmp;
38   /// Preciate info
39   ICmpInst::Predicate Pred;
40   /// AddRec llvm value
41   Value *AddRecValue;
42   /// Bound llvm value
43   Value *BoundValue;
44   /// AddRec SCEV
45   const SCEV *AddRecSCEV;
46   /// Bound SCEV
47   const SCEV *BoundSCEV;
48 
49   ConditionInfo()
50       : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE),
51         AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr),
52         BoundSCEV(nullptr) {}
53 };
54 } // namespace
55 
56 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
57                         ConditionInfo &Cond) {
58   Cond.ICmp = ICmp;
59   if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
60                          m_Value(Cond.BoundValue)))) {
61     Cond.AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
62     Cond.BoundSCEV = SE.getSCEV(Cond.BoundValue);
63     // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
64     if (isa<SCEVAddRecExpr>(Cond.BoundSCEV) &&
65         !isa<SCEVAddRecExpr>(Cond.AddRecSCEV)) {
66       std::swap(Cond.AddRecValue, Cond.BoundValue);
67       std::swap(Cond.AddRecSCEV, Cond.BoundSCEV);
68       Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
69     }
70   }
71 }
72 
73 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
74                                 ConditionInfo &Cond, bool IsExitCond) {
75   if (IsExitCond) {
76     const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
77     if (isa<SCEVCouldNotCompute>(ExitCount))
78       return false;
79 
80     Cond.BoundSCEV = ExitCount;
81     return true;
82   }
83 
84   // For non-exit condtion, if pred is LT, keep existing bound.
85   if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
86     return true;
87 
88   // For non-exit condition, if pre is LE, try to convert it to LT.
89   //      Range                 Range
90   // AddRec <= Bound  -->  AddRec < Bound + 1
91   if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
92     return false;
93 
94   if (IntegerType *BoundSCEVIntType =
95           dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
96     unsigned BitWidth = BoundSCEVIntType->getBitWidth();
97     APInt Max = ICmpInst::isSigned(Cond.Pred)
98                     ? APInt::getSignedMaxValue(BitWidth)
99                     : APInt::getMaxValue(BitWidth);
100     const SCEV *MaxSCEV = SE.getConstant(Max);
101     // Check Bound < INT_MAX
102     ICmpInst::Predicate Pred =
103         ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
104     if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
105       const SCEV *BoundPlusOneSCEV =
106           SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
107       Cond.BoundSCEV = BoundPlusOneSCEV;
108       Cond.Pred = Pred;
109       return true;
110     }
111   }
112 
113   // ToDo: Support ICMP_NE/EQ.
114 
115   return false;
116 }
117 
118 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
119                                     ICmpInst *ICmp, ConditionInfo &Cond,
120                                     bool IsExitCond) {
121   analyzeICmp(SE, ICmp, Cond);
122 
123   // The BoundSCEV should be evaluated at loop entry.
124   if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
125     return false;
126 
127   const SCEVAddRecExpr *AddRecSCEV = dyn_cast<SCEVAddRecExpr>(Cond.AddRecSCEV);
128   // Allowed AddRec as induction variable.
129   if (!AddRecSCEV)
130     return false;
131 
132   if (!AddRecSCEV->isAffine())
133     return false;
134 
135   const SCEV *StepRecSCEV = AddRecSCEV->getStepRecurrence(SE);
136   // Allowed constant step.
137   if (!isa<SCEVConstant>(StepRecSCEV))
138     return false;
139 
140   ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
141   // Allowed positive step for now.
142   // TODO: Support negative step.
143   if (StepCI->isNegative() || StepCI->isZero())
144     return false;
145 
146   // Calculate upper bound.
147   if (!calculateUpperBound(L, SE, Cond, IsExitCond))
148     return false;
149 
150   return true;
151 }
152 
153 static bool isProcessableCondBI(const ScalarEvolution &SE,
154                                 const BranchInst *BI) {
155   BasicBlock *TrueSucc = nullptr;
156   BasicBlock *FalseSucc = nullptr;
157   ICmpInst::Predicate Pred;
158   Value *LHS, *RHS;
159   if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
160                       m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
161     return false;
162 
163   if (!SE.isSCEVable(LHS->getType()))
164     return false;
165   assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
166 
167   if (TrueSucc == FalseSucc)
168     return false;
169 
170   return true;
171 }
172 
173 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
174                               ScalarEvolution &SE, ConditionInfo &Cond) {
175   // Skip function with optsize.
176   if (L.getHeader()->getParent()->hasOptSize())
177     return false;
178 
179   // Split only innermost loop.
180   if (!L.isInnermost())
181     return false;
182 
183   // Check loop is in simplified form.
184   if (!L.isLoopSimplifyForm())
185     return false;
186 
187   // Check loop is in LCSSA form.
188   if (!L.isLCSSAForm(DT))
189     return false;
190 
191   // Skip loop that cannot be cloned.
192   if (!L.isSafeToClone())
193     return false;
194 
195   BasicBlock *ExitingBB = L.getExitingBlock();
196   // Assumed only one exiting block.
197   if (!ExitingBB)
198     return false;
199 
200   BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
201   if (!ExitingBI)
202     return false;
203 
204   // Allowed only conditional branch with ICmp.
205   if (!isProcessableCondBI(SE, ExitingBI))
206     return false;
207 
208   // Check the condition is processable.
209   ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
210   if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
211     return false;
212 
213   Cond.BI = ExitingBI;
214   return true;
215 }
216 
217 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
218   // If the conditional branch splits a loop into two halves, we could
219   // generally say it is profitable.
220   //
221   // ToDo: Add more profitable cases here.
222 
223   // Check this branch causes diamond CFG.
224   BasicBlock *Succ0 = BI->getSuccessor(0);
225   BasicBlock *Succ1 = BI->getSuccessor(1);
226 
227   BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
228   BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
229   if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
230     return false;
231 
232   // ToDo: Calculate each successor's instruction cost.
233 
234   return true;
235 }
236 
237 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
238                                       ConditionInfo &ExitingCond,
239                                       ConditionInfo &SplitCandidateCond) {
240   for (auto *BB : L.blocks()) {
241     // Skip condition of backedge.
242     if (L.getLoopLatch() == BB)
243       continue;
244 
245     auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
246     if (!BI)
247       continue;
248 
249     // Check conditional branch with ICmp.
250     if (!isProcessableCondBI(SE, BI))
251       continue;
252 
253     // Skip loop invariant condition.
254     if (L.isLoopInvariant(BI->getCondition()))
255       continue;
256 
257     // Check the condition is processable.
258     ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());
259     if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,
260                                  /*IsExitCond*/ false))
261       continue;
262 
263     if (ExitingCond.BoundSCEV->getType() !=
264         SplitCandidateCond.BoundSCEV->getType())
265       continue;
266 
267     SplitCandidateCond.BI = BI;
268     return BI;
269   }
270 
271   return nullptr;
272 }
273 
274 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
275                            ScalarEvolution &SE, LPMUpdater &U) {
276   ConditionInfo SplitCandidateCond;
277   ConditionInfo ExitingCond;
278 
279   // Check we can split this loop's bound.
280   if (!canSplitLoopBound(L, DT, SE, ExitingCond))
281     return false;
282 
283   if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
284     return false;
285 
286   if (!isProfitableToTransform(L, SplitCandidateCond.BI))
287     return false;
288 
289   // Now, we have a split candidate. Let's build a form as below.
290   //    +--------------------+
291   //    |     preheader      |
292   //    |  set up newbound   |
293   //    +--------------------+
294   //             |     /----------------\
295   //    +--------v----v------+          |
296   //    |      header        |---\      |
297   //    | with true condition|   |      |
298   //    +--------------------+   |      |
299   //             |               |      |
300   //    +--------v-----------+   |      |
301   //    |     if.then.BB     |   |      |
302   //    +--------------------+   |      |
303   //             |               |      |
304   //    +--------v-----------<---/      |
305   //    |       latch        >----------/
306   //    |   with newbound    |
307   //    +--------------------+
308   //             |
309   //    +--------v-----------+
310   //    |     preheader2     |--------------\
311   //    | if (AddRec i !=    |              |
312   //    |     org bound)     |              |
313   //    +--------------------+              |
314   //             |     /----------------\   |
315   //    +--------v----v------+          |   |
316   //    |      header2       |---\      |   |
317   //    | conditional branch |   |      |   |
318   //    |with false condition|   |      |   |
319   //    +--------------------+   |      |   |
320   //             |               |      |   |
321   //    +--------v-----------+   |      |   |
322   //    |    if.then.BB2     |   |      |   |
323   //    +--------------------+   |      |   |
324   //             |               |      |   |
325   //    +--------v-----------<---/      |   |
326   //    |       latch2       >----------/   |
327   //    |   with org bound   |              |
328   //    +--------v-----------+              |
329   //             |                          |
330   //             |  +---------------+       |
331   //             +-->     exit      <-------/
332   //                +---------------+
333 
334   // Let's create post loop.
335   SmallVector<BasicBlock *, 8> PostLoopBlocks;
336   Loop *PostLoop;
337   ValueToValueMapTy VMap;
338   BasicBlock *PreHeader = L.getLoopPreheader();
339   BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
340   PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
341                                     ".split", &LI, &DT, PostLoopBlocks);
342   remapInstructionsInBlocks(PostLoopBlocks, VMap);
343 
344   // Add conditional branch to check we can skip post-loop in its preheader.
345   BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
346   IRBuilder<> Builder(PostLoopPreHeader);
347   Instruction *OrigBI = PostLoopPreHeader->getTerminator();
348   ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
349   Value *Cond =
350       Builder.CreateICmp(Pred, ExitingCond.AddRecValue, ExitingCond.BoundValue);
351   Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
352   OrigBI->eraseFromParent();
353 
354   // Create new loop bound and add it into preheader of pre-loop.
355   const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
356   const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
357   NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
358                      ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
359                      : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
360 
361   SCEVExpander Expander(
362       SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split");
363   Instruction *InsertPt = SplitLoopPH->getTerminator();
364   Value *NewBoundValue =
365       Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
366   NewBoundValue->setName("new.bound");
367 
368   // Replace exiting bound value of pre-loop NewBound.
369   ExitingCond.ICmp->setOperand(1, NewBoundValue);
370 
371   // Replace IV's start value of post-loop by NewBound.
372   for (PHINode &PN : L.getHeader()->phis()) {
373     // Find PHI with exiting condition from pre-loop.
374     if (SE.isSCEVable(PN.getType()) && isa<SCEVAddRecExpr>(SE.getSCEV(&PN))) {
375       for (Value *Op : PN.incoming_values()) {
376         if (Op == ExitingCond.AddRecValue) {
377           // Find cloned PHI for post-loop.
378           PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
379           PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader,
380                                                NewBoundValue);
381         }
382       }
383     }
384   }
385 
386   // Replace SplitCandidateCond.BI's condition of pre-loop by True.
387   LLVMContext &Context = PreHeader->getContext();
388   SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
389 
390   // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
391   BranchInst *ClonedSplitCandidateBI =
392       cast<BranchInst>(VMap[SplitCandidateCond.BI]);
393   ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
394 
395   // Replace exit branch target of pre-loop by post-loop's preheader.
396   if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
397     ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
398   else
399     ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
400 
401   // Update dominator tree.
402   DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
403   DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
404 
405   // Invalidate cached SE information.
406   SE.forgetLoop(&L);
407 
408   // Canonicalize loops.
409   // TODO: Try to update LCSSA information according to above change.
410   formLCSSA(L, DT, &LI, &SE);
411   simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
412   formLCSSA(*PostLoop, DT, &LI, &SE);
413   simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
414 
415   // Add new post-loop to loop pass manager.
416   U.addSiblingLoops(PostLoop);
417 
418   return true;
419 }
420 
421 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
422                                           LoopStandardAnalysisResults &AR,
423                                           LPMUpdater &U) {
424   Function &F = *L.getHeader()->getParent();
425   (void)F;
426 
427   LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
428                     << "\n");
429 
430   if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))
431     return PreservedAnalyses::all();
432 
433   assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
434   AR.LI.verify(AR.DT);
435 
436   return getLoopPassPreservedAnalyses();
437 }
438 
439 } // end namespace llvm
440