xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp (revision 38a52bd3b5cac3da6f7f6eef3dd050e6aa08ebb3)
1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10 #include "llvm/ADT/Sequence.h"
11 #include "llvm/Analysis/LoopAccessAnalysis.h"
12 #include "llvm/Analysis/LoopAnalysisManager.h"
13 #include "llvm/Analysis/LoopInfo.h"
14 #include "llvm/Analysis/LoopIterator.h"
15 #include "llvm/Analysis/LoopPass.h"
16 #include "llvm/Analysis/MemorySSA.h"
17 #include "llvm/Analysis/MemorySSAUpdater.h"
18 #include "llvm/Analysis/ScalarEvolution.h"
19 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
20 #include "llvm/IR/PatternMatch.h"
21 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
22 #include "llvm/Transforms/Utils/Cloning.h"
23 #include "llvm/Transforms/Utils/LoopSimplify.h"
24 #include "llvm/Transforms/Utils/LoopUtils.h"
25 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
26 
27 #define DEBUG_TYPE "loop-bound-split"
28 
29 namespace llvm {
30 
31 using namespace PatternMatch;
32 
33 namespace {
34 struct ConditionInfo {
35   /// Branch instruction with this condition
36   BranchInst *BI;
37   /// ICmp instruction with this condition
38   ICmpInst *ICmp;
39   /// Preciate info
40   ICmpInst::Predicate Pred;
41   /// AddRec llvm value
42   Value *AddRecValue;
43   /// Non PHI AddRec llvm value
44   Value *NonPHIAddRecValue;
45   /// Bound llvm value
46   Value *BoundValue;
47   /// AddRec SCEV
48   const SCEVAddRecExpr *AddRecSCEV;
49   /// Bound SCEV
50   const SCEV *BoundSCEV;
51 
52   ConditionInfo()
53       : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE),
54         AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr),
55         BoundSCEV(nullptr) {}
56 };
57 } // namespace
58 
59 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
60                         ConditionInfo &Cond, const Loop &L) {
61   Cond.ICmp = ICmp;
62   if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
63                          m_Value(Cond.BoundValue)))) {
64     const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
65     const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue);
66     const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
67     const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV);
68     // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
69     if (!LHSAddRecSCEV && RHSAddRecSCEV) {
70       std::swap(Cond.AddRecValue, Cond.BoundValue);
71       std::swap(AddRecSCEV, BoundSCEV);
72       Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
73     }
74 
75     Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
76     Cond.BoundSCEV = BoundSCEV;
77     Cond.NonPHIAddRecValue = Cond.AddRecValue;
78 
79     // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with
80     // value from backedge.
81     if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) {
82       PHINode *PN = cast<PHINode>(Cond.AddRecValue);
83       Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch());
84     }
85   }
86 }
87 
88 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
89                                 ConditionInfo &Cond, bool IsExitCond) {
90   if (IsExitCond) {
91     const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
92     if (isa<SCEVCouldNotCompute>(ExitCount))
93       return false;
94 
95     Cond.BoundSCEV = ExitCount;
96     return true;
97   }
98 
99   // For non-exit condtion, if pred is LT, keep existing bound.
100   if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
101     return true;
102 
103   // For non-exit condition, if pre is LE, try to convert it to LT.
104   //      Range                 Range
105   // AddRec <= Bound  -->  AddRec < Bound + 1
106   if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
107     return false;
108 
109   if (IntegerType *BoundSCEVIntType =
110           dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
111     unsigned BitWidth = BoundSCEVIntType->getBitWidth();
112     APInt Max = ICmpInst::isSigned(Cond.Pred)
113                     ? APInt::getSignedMaxValue(BitWidth)
114                     : APInt::getMaxValue(BitWidth);
115     const SCEV *MaxSCEV = SE.getConstant(Max);
116     // Check Bound < INT_MAX
117     ICmpInst::Predicate Pred =
118         ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
119     if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
120       const SCEV *BoundPlusOneSCEV =
121           SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
122       Cond.BoundSCEV = BoundPlusOneSCEV;
123       Cond.Pred = Pred;
124       return true;
125     }
126   }
127 
128   // ToDo: Support ICMP_NE/EQ.
129 
130   return false;
131 }
132 
133 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
134                                     ICmpInst *ICmp, ConditionInfo &Cond,
135                                     bool IsExitCond) {
136   analyzeICmp(SE, ICmp, Cond, L);
137 
138   // The BoundSCEV should be evaluated at loop entry.
139   if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
140     return false;
141 
142   // Allowed AddRec as induction variable.
143   if (!Cond.AddRecSCEV)
144     return false;
145 
146   if (!Cond.AddRecSCEV->isAffine())
147     return false;
148 
149   const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE);
150   // Allowed constant step.
151   if (!isa<SCEVConstant>(StepRecSCEV))
152     return false;
153 
154   ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
155   // Allowed positive step for now.
156   // TODO: Support negative step.
157   if (StepCI->isNegative() || StepCI->isZero())
158     return false;
159 
160   // Calculate upper bound.
161   if (!calculateUpperBound(L, SE, Cond, IsExitCond))
162     return false;
163 
164   return true;
165 }
166 
167 static bool isProcessableCondBI(const ScalarEvolution &SE,
168                                 const BranchInst *BI) {
169   BasicBlock *TrueSucc = nullptr;
170   BasicBlock *FalseSucc = nullptr;
171   ICmpInst::Predicate Pred;
172   Value *LHS, *RHS;
173   if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
174                       m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
175     return false;
176 
177   if (!SE.isSCEVable(LHS->getType()))
178     return false;
179   assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
180 
181   if (TrueSucc == FalseSucc)
182     return false;
183 
184   return true;
185 }
186 
187 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
188                               ScalarEvolution &SE, ConditionInfo &Cond) {
189   // Skip function with optsize.
190   if (L.getHeader()->getParent()->hasOptSize())
191     return false;
192 
193   // Split only innermost loop.
194   if (!L.isInnermost())
195     return false;
196 
197   // Check loop is in simplified form.
198   if (!L.isLoopSimplifyForm())
199     return false;
200 
201   // Check loop is in LCSSA form.
202   if (!L.isLCSSAForm(DT))
203     return false;
204 
205   // Skip loop that cannot be cloned.
206   if (!L.isSafeToClone())
207     return false;
208 
209   BasicBlock *ExitingBB = L.getExitingBlock();
210   // Assumed only one exiting block.
211   if (!ExitingBB)
212     return false;
213 
214   BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
215   if (!ExitingBI)
216     return false;
217 
218   // Allowed only conditional branch with ICmp.
219   if (!isProcessableCondBI(SE, ExitingBI))
220     return false;
221 
222   // Check the condition is processable.
223   ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
224   if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
225     return false;
226 
227   Cond.BI = ExitingBI;
228   return true;
229 }
230 
231 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
232   // If the conditional branch splits a loop into two halves, we could
233   // generally say it is profitable.
234   //
235   // ToDo: Add more profitable cases here.
236 
237   // Check this branch causes diamond CFG.
238   BasicBlock *Succ0 = BI->getSuccessor(0);
239   BasicBlock *Succ1 = BI->getSuccessor(1);
240 
241   BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
242   BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
243   if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
244     return false;
245 
246   // ToDo: Calculate each successor's instruction cost.
247 
248   return true;
249 }
250 
251 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
252                                       ConditionInfo &ExitingCond,
253                                       ConditionInfo &SplitCandidateCond) {
254   for (auto *BB : L.blocks()) {
255     // Skip condition of backedge.
256     if (L.getLoopLatch() == BB)
257       continue;
258 
259     auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
260     if (!BI)
261       continue;
262 
263     // Check conditional branch with ICmp.
264     if (!isProcessableCondBI(SE, BI))
265       continue;
266 
267     // Skip loop invariant condition.
268     if (L.isLoopInvariant(BI->getCondition()))
269       continue;
270 
271     // Check the condition is processable.
272     ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());
273     if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,
274                                  /*IsExitCond*/ false))
275       continue;
276 
277     if (ExitingCond.BoundSCEV->getType() !=
278         SplitCandidateCond.BoundSCEV->getType())
279       continue;
280 
281     // After transformation, we assume the split condition of the pre-loop is
282     // always true. In order to guarantee it, we need to check the start value
283     // of the split cond AddRec satisfies the split condition.
284     if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred,
285                                      SplitCandidateCond.AddRecSCEV->getStart(),
286                                      SplitCandidateCond.BoundSCEV))
287       continue;
288 
289     SplitCandidateCond.BI = BI;
290     return BI;
291   }
292 
293   return nullptr;
294 }
295 
296 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
297                            ScalarEvolution &SE, LPMUpdater &U) {
298   ConditionInfo SplitCandidateCond;
299   ConditionInfo ExitingCond;
300 
301   // Check we can split this loop's bound.
302   if (!canSplitLoopBound(L, DT, SE, ExitingCond))
303     return false;
304 
305   if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
306     return false;
307 
308   if (!isProfitableToTransform(L, SplitCandidateCond.BI))
309     return false;
310 
311   // Now, we have a split candidate. Let's build a form as below.
312   //    +--------------------+
313   //    |     preheader      |
314   //    |  set up newbound   |
315   //    +--------------------+
316   //             |     /----------------\
317   //    +--------v----v------+          |
318   //    |      header        |---\      |
319   //    | with true condition|   |      |
320   //    +--------------------+   |      |
321   //             |               |      |
322   //    +--------v-----------+   |      |
323   //    |     if.then.BB     |   |      |
324   //    +--------------------+   |      |
325   //             |               |      |
326   //    +--------v-----------<---/      |
327   //    |       latch        >----------/
328   //    |   with newbound    |
329   //    +--------------------+
330   //             |
331   //    +--------v-----------+
332   //    |     preheader2     |--------------\
333   //    | if (AddRec i !=    |              |
334   //    |     org bound)     |              |
335   //    +--------------------+              |
336   //             |     /----------------\   |
337   //    +--------v----v------+          |   |
338   //    |      header2       |---\      |   |
339   //    | conditional branch |   |      |   |
340   //    |with false condition|   |      |   |
341   //    +--------------------+   |      |   |
342   //             |               |      |   |
343   //    +--------v-----------+   |      |   |
344   //    |    if.then.BB2     |   |      |   |
345   //    +--------------------+   |      |   |
346   //             |               |      |   |
347   //    +--------v-----------<---/      |   |
348   //    |       latch2       >----------/   |
349   //    |   with org bound   |              |
350   //    +--------v-----------+              |
351   //             |                          |
352   //             |  +---------------+       |
353   //             +-->     exit      <-------/
354   //                +---------------+
355 
356   // Let's create post loop.
357   SmallVector<BasicBlock *, 8> PostLoopBlocks;
358   Loop *PostLoop;
359   ValueToValueMapTy VMap;
360   BasicBlock *PreHeader = L.getLoopPreheader();
361   BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
362   PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
363                                     ".split", &LI, &DT, PostLoopBlocks);
364   remapInstructionsInBlocks(PostLoopBlocks, VMap);
365 
366   BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
367   IRBuilder<> Builder(&PostLoopPreHeader->front());
368 
369   // Update phi nodes in header of post-loop.
370   bool isExitingLatch =
371       (L.getExitingBlock() == L.getLoopLatch()) ? true : false;
372   Value *ExitingCondLCSSAPhi = nullptr;
373   for (PHINode &PN : L.getHeader()->phis()) {
374     // Create LCSSA phi node in preheader of post-loop.
375     PHINode *LCSSAPhi =
376         Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
377     LCSSAPhi->setDebugLoc(PN.getDebugLoc());
378     // If the exiting block is loop latch, the phi does not have the update at
379     // last iteration. In this case, update lcssa phi with value from backedge.
380     LCSSAPhi->addIncoming(
381         isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN,
382         L.getExitingBlock());
383 
384     // Update the start value of phi node in post-loop with the LCSSA phi node.
385     PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
386     PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi);
387 
388     // Find PHI with exiting condition from pre-loop. The PHI should be
389     // SCEVAddRecExpr and have same incoming value from backedge with
390     // ExitingCond.
391     if (!SE.isSCEVable(PN.getType()))
392       continue;
393 
394     const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
395     if (PhiSCEV && ExitingCond.NonPHIAddRecValue ==
396                        PN.getIncomingValueForBlock(L.getLoopLatch()))
397       ExitingCondLCSSAPhi = LCSSAPhi;
398   }
399 
400   // Add conditional branch to check we can skip post-loop in its preheader.
401   Instruction *OrigBI = PostLoopPreHeader->getTerminator();
402   ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
403   Value *Cond =
404       Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue);
405   Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
406   OrigBI->eraseFromParent();
407 
408   // Create new loop bound and add it into preheader of pre-loop.
409   const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
410   const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
411   NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
412                      ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
413                      : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
414 
415   SCEVExpander Expander(
416       SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split");
417   Instruction *InsertPt = SplitLoopPH->getTerminator();
418   Value *NewBoundValue =
419       Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
420   NewBoundValue->setName("new.bound");
421 
422   // Replace exiting bound value of pre-loop NewBound.
423   ExitingCond.ICmp->setOperand(1, NewBoundValue);
424 
425   // Replace SplitCandidateCond.BI's condition of pre-loop by True.
426   LLVMContext &Context = PreHeader->getContext();
427   SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
428 
429   // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
430   BranchInst *ClonedSplitCandidateBI =
431       cast<BranchInst>(VMap[SplitCandidateCond.BI]);
432   ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
433 
434   // Replace exit branch target of pre-loop by post-loop's preheader.
435   if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
436     ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
437   else
438     ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
439 
440   // Update phi node in exit block of post-loop.
441   Builder.SetInsertPoint(&PostLoopPreHeader->front());
442   for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
443     for (auto i : seq<int>(0, PN.getNumOperands())) {
444       // Check incoming block is pre-loop's exiting block.
445       if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
446         Value *IncomingValue = PN.getIncomingValue(i);
447 
448         // Create LCSSA phi node for incoming value.
449         PHINode *LCSSAPhi =
450             Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
451         LCSSAPhi->setDebugLoc(PN.getDebugLoc());
452         LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i));
453 
454         // Replace pre-loop's exiting block by post-loop's preheader.
455         PN.setIncomingBlock(i, PostLoopPreHeader);
456         // Replace incoming value by LCSSAPhi.
457         PN.setIncomingValue(i, LCSSAPhi);
458         // Add a new incoming value with post-loop's exiting block.
459         PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock());
460       }
461     }
462   }
463 
464   // Update dominator tree.
465   DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
466   DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
467 
468   // Invalidate cached SE information.
469   SE.forgetLoop(&L);
470 
471   // Canonicalize loops.
472   simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
473   simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
474 
475   // Add new post-loop to loop pass manager.
476   U.addSiblingLoops(PostLoop);
477 
478   return true;
479 }
480 
481 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
482                                           LoopStandardAnalysisResults &AR,
483                                           LPMUpdater &U) {
484   Function &F = *L.getHeader()->getParent();
485   (void)F;
486 
487   LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
488                     << "\n");
489 
490   if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))
491     return PreservedAnalyses::all();
492 
493   assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
494   AR.LI.verify(AR.DT);
495 
496   return getLoopPassPreservedAnalyses();
497 }
498 
499 } // end namespace llvm
500