xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 #include "llvm/Transforms/Utils/LoopConstrainer.h"
2 #include "llvm/Analysis/LoopInfo.h"
3 #include "llvm/Analysis/ScalarEvolution.h"
4 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
5 #include "llvm/IR/Dominators.h"
6 #include "llvm/Transforms/Utils/Cloning.h"
7 #include "llvm/Transforms/Utils/LoopSimplify.h"
8 #include "llvm/Transforms/Utils/LoopUtils.h"
9 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
10 
11 using namespace llvm;
12 
/// Name of the metadata kind attached to the latch terminator of every loop
/// clone produced by this utility; its presence is what makes
/// parseLoopStructure refuse a loop that "has already been cloned". The
/// pointer itself is const so the tag cannot be re-pointed by accident.
static const char *const ClonedLoopTag = "loop_constrainer.loop.clone";
14 
15 #define DEBUG_TYPE "loop-constrainer"
16 
17 /// Given a loop with an deccreasing induction variable, is it possible to
18 /// safely calculate the bounds of a new loop using the given Predicate.
19 static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
20                                   const SCEV *Step, ICmpInst::Predicate Pred,
21                                   unsigned LatchBrExitIdx, Loop *L,
22                                   ScalarEvolution &SE) {
23   if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
24       Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
25     return false;
26 
27   if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
28     return false;
29 
30   assert(SE.isKnownNegative(Step) && "expecting negative step");
31 
32   LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
33   LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
34   LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
35   LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
36   LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
37   LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
38 
39   bool IsSigned = ICmpInst::isSigned(Pred);
40   // The predicate that we need to check that the induction variable lies
41   // within bounds.
42   ICmpInst::Predicate BoundPred =
43       IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
44 
45   auto StartLG = SE.applyLoopGuards(Start, L);
46   auto BoundLG = SE.applyLoopGuards(BoundSCEV, L);
47 
48   if (LatchBrExitIdx == 1)
49     return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG);
50 
51   assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");
52 
53   const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
54   unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
55   APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)
56                        : APInt::getMinValue(BitWidth);
57   const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
58 
59   const SCEV *MinusOne =
60       SE.getMinusSCEV(BoundLG, SE.getOne(BoundLG->getType()));
61 
62   return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, MinusOne) &&
63          SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit);
64 }
65 
66 /// Given a loop with an increasing induction variable, is it possible to
67 /// safely calculate the bounds of a new loop using the given Predicate.
68 static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
69                                   const SCEV *Step, ICmpInst::Predicate Pred,
70                                   unsigned LatchBrExitIdx, Loop *L,
71                                   ScalarEvolution &SE) {
72   if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
73       Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
74     return false;
75 
76   if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
77     return false;
78 
79   LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
80   LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
81   LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
82   LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
83   LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
84   LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
85 
86   bool IsSigned = ICmpInst::isSigned(Pred);
87   // The predicate that we need to check that the induction variable lies
88   // within bounds.
89   ICmpInst::Predicate BoundPred =
90       IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
91 
92   auto StartLG = SE.applyLoopGuards(Start, L);
93   auto BoundLG = SE.applyLoopGuards(BoundSCEV, L);
94 
95   if (LatchBrExitIdx == 1)
96     return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG);
97 
98   assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
99 
100   const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
101   unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
102   APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)
103                        : APInt::getMaxValue(BitWidth);
104   const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
105 
106   return (SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG,
107                                       SE.getAddExpr(BoundLG, Step)) &&
108           SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit));
109 }
110 
111 /// Returns estimate for max latch taken count of the loop of the narrowest
112 /// available type. If the latch block has such estimate, it is returned.
113 /// Otherwise, we use max exit count of whole loop (that is potentially of wider
114 /// type than latch check itself), which is still better than no estimate.
115 static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
116                                                           const Loop &L) {
117   const SCEV *FromBlock =
118       SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
119   if (isa<SCEVCouldNotCompute>(FromBlock))
120     return SE.getSymbolicMaxBackedgeTakenCount(&L);
121   return FromBlock;
122 }
123 
124 std::optional<LoopStructure>
125 LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
126                                   bool AllowUnsignedLatchCond,
127                                   const char *&FailureReason) {
128   if (!L.isLoopSimplifyForm()) {
129     FailureReason = "loop not in LoopSimplify form";
130     return std::nullopt;
131   }
132 
133   BasicBlock *Latch = L.getLoopLatch();
134   assert(Latch && "Simplified loops only have one latch!");
135 
136   if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
137     FailureReason = "loop has already been cloned";
138     return std::nullopt;
139   }
140 
141   if (!L.isLoopExiting(Latch)) {
142     FailureReason = "no loop latch";
143     return std::nullopt;
144   }
145 
146   BasicBlock *Header = L.getHeader();
147   BasicBlock *Preheader = L.getLoopPreheader();
148   if (!Preheader) {
149     FailureReason = "no preheader";
150     return std::nullopt;
151   }
152 
153   BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
154   if (!LatchBr || LatchBr->isUnconditional()) {
155     FailureReason = "latch terminator not conditional branch";
156     return std::nullopt;
157   }
158 
159   unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;
160 
161   ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
162   if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
163     FailureReason = "latch terminator branch not conditional on integral icmp";
164     return std::nullopt;
165   }
166 
167   const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
168   if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
169     FailureReason = "could not compute latch count";
170     return std::nullopt;
171   }
172   assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
173              ScalarEvolution::LoopInvariant &&
174          "loop variant exit count doesn't make sense!");
175 
176   ICmpInst::Predicate Pred = ICI->getPredicate();
177   Value *LeftValue = ICI->getOperand(0);
178   const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
179   IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());
180 
181   Value *RightValue = ICI->getOperand(1);
182   const SCEV *RightSCEV = SE.getSCEV(RightValue);
183 
184   // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
185   if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
186     if (isa<SCEVAddRecExpr>(RightSCEV)) {
187       std::swap(LeftSCEV, RightSCEV);
188       std::swap(LeftValue, RightValue);
189       Pred = ICmpInst::getSwappedPredicate(Pred);
190     } else {
191       FailureReason = "no add recurrences in the icmp";
192       return std::nullopt;
193     }
194   }
195 
196   auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
197     if (AR->getNoWrapFlags(SCEV::FlagNSW))
198       return true;
199 
200     IntegerType *Ty = cast<IntegerType>(AR->getType());
201     IntegerType *WideTy =
202         IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
203 
204     const SCEVAddRecExpr *ExtendAfterOp =
205         dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
206     if (ExtendAfterOp) {
207       const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
208       const SCEV *ExtendedStep =
209           SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
210 
211       bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
212                           ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
213 
214       if (NoSignedWrap)
215         return true;
216     }
217 
218     // We may have proved this when computing the sign extension above.
219     return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
220   };
221 
222   // `ICI` is interpreted as taking the backedge if the *next* value of the
223   // induction variable satisfies some constraint.
224 
225   const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
226   if (IndVarBase->getLoop() != &L) {
227     FailureReason = "LHS in cmp is not an AddRec for this loop";
228     return std::nullopt;
229   }
230   if (!IndVarBase->isAffine()) {
231     FailureReason = "LHS in icmp not induction variable";
232     return std::nullopt;
233   }
234   const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
235   if (!isa<SCEVConstant>(StepRec)) {
236     FailureReason = "LHS in icmp not induction variable";
237     return std::nullopt;
238   }
239   ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();
240 
241   if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
242     FailureReason = "LHS in icmp needs nsw for equality predicates";
243     return std::nullopt;
244   }
245 
246   assert(!StepCI->isZero() && "Zero step?");
247   bool IsIncreasing = !StepCI->isNegative();
248   bool IsSignedPredicate;
249   const SCEV *StartNext = IndVarBase->getStart();
250   const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
251   const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
252   const SCEV *Step = SE.getSCEV(StepCI);
253 
254   const SCEV *FixedRightSCEV = nullptr;
255 
256   // If RightValue resides within loop (but still being loop invariant),
257   // regenerate it as preheader.
258   if (auto *I = dyn_cast<Instruction>(RightValue))
259     if (L.contains(I->getParent()))
260       FixedRightSCEV = RightSCEV;
261 
262   if (IsIncreasing) {
263     bool DecreasedRightValueByOne = false;
264     if (StepCI->isOne()) {
265       // Try to turn eq/ne predicates to those we can work with.
266       if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
267         // while (++i != len) {         while (++i < len) {
268         //   ...                 --->     ...
269         // }                            }
270         // If both parts are known non-negative, it is profitable to use
271         // unsigned comparison in increasing loop. This allows us to make the
272         // comparison check against "RightSCEV + 1" more optimistic.
273         if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
274             isKnownNonNegativeInLoop(RightSCEV, &L, SE))
275           Pred = ICmpInst::ICMP_ULT;
276         else
277           Pred = ICmpInst::ICMP_SLT;
278       else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
279         // while (true) {               while (true) {
280         //   if (++i == len)     --->     if (++i > len - 1)
281         //     break;                       break;
282         //   ...                          ...
283         // }                            }
284         if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
285             cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
286           Pred = ICmpInst::ICMP_UGT;
287           RightSCEV =
288               SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
289           DecreasedRightValueByOne = true;
290         } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
291           Pred = ICmpInst::ICMP_SGT;
292           RightSCEV =
293               SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
294           DecreasedRightValueByOne = true;
295         }
296       }
297     }
298 
299     bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
300     bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
301     bool FoundExpectedPred =
302         (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);
303 
304     if (!FoundExpectedPred) {
305       FailureReason = "expected icmp slt semantically, found something else";
306       return std::nullopt;
307     }
308 
309     IsSignedPredicate = ICmpInst::isSigned(Pred);
310     if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
311       FailureReason = "unsigned latch conditions are explicitly prohibited";
312       return std::nullopt;
313     }
314 
315     if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
316                                LatchBrExitIdx, &L, SE)) {
317       FailureReason = "Unsafe loop bounds";
318       return std::nullopt;
319     }
320     if (LatchBrExitIdx == 0) {
321       // We need to increase the right value unless we have already decreased
322       // it virtually when we replaced EQ with SGT.
323       if (!DecreasedRightValueByOne)
324         FixedRightSCEV =
325             SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
326     } else {
327       assert(!DecreasedRightValueByOne &&
328              "Right value can be decreased only for LatchBrExitIdx == 0!");
329     }
330   } else {
331     bool IncreasedRightValueByOne = false;
332     if (StepCI->isMinusOne()) {
333       // Try to turn eq/ne predicates to those we can work with.
334       if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
335         // while (--i != len) {         while (--i > len) {
336         //   ...                 --->     ...
337         // }                            }
338         // We intentionally don't turn the predicate into UGT even if we know
339         // that both operands are non-negative, because it will only pessimize
340         // our check against "RightSCEV - 1".
341         Pred = ICmpInst::ICMP_SGT;
342       else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
343         // while (true) {               while (true) {
344         //   if (--i == len)     --->     if (--i < len + 1)
345         //     break;                       break;
346         //   ...                          ...
347         // }                            }
348         if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
349             cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
350           Pred = ICmpInst::ICMP_ULT;
351           RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
352           IncreasedRightValueByOne = true;
353         } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
354           Pred = ICmpInst::ICMP_SLT;
355           RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
356           IncreasedRightValueByOne = true;
357         }
358       }
359     }
360 
361     bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
362     bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
363 
364     bool FoundExpectedPred =
365         (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);
366 
367     if (!FoundExpectedPred) {
368       FailureReason = "expected icmp sgt semantically, found something else";
369       return std::nullopt;
370     }
371 
372     IsSignedPredicate =
373         Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;
374 
375     if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
376       FailureReason = "unsigned latch conditions are explicitly prohibited";
377       return std::nullopt;
378     }
379 
380     if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
381                                LatchBrExitIdx, &L, SE)) {
382       FailureReason = "Unsafe bounds";
383       return std::nullopt;
384     }
385 
386     if (LatchBrExitIdx == 0) {
387       // We need to decrease the right value unless we have already increased
388       // it virtually when we replaced EQ with SLT.
389       if (!IncreasedRightValueByOne)
390         FixedRightSCEV =
391             SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
392     } else {
393       assert(!IncreasedRightValueByOne &&
394              "Right value can be increased only for LatchBrExitIdx == 0!");
395     }
396   }
397   BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
398 
399   assert(!L.contains(LatchExit) && "expected an exit block!");
400   const DataLayout &DL = Preheader->getDataLayout();
401   SCEVExpander Expander(SE, DL, "loop-constrainer");
402   Instruction *Ins = Preheader->getTerminator();
403 
404   if (FixedRightSCEV)
405     RightValue =
406         Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);
407 
408   Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
409   IndVarStartV->setName("indvar.start");
410 
411   LoopStructure Result;
412 
413   Result.Tag = "main";
414   Result.Header = Header;
415   Result.Latch = Latch;
416   Result.LatchBr = LatchBr;
417   Result.LatchExit = LatchExit;
418   Result.LatchBrExitIdx = LatchBrExitIdx;
419   Result.IndVarStart = IndVarStartV;
420   Result.IndVarStep = StepCI;
421   Result.IndVarBase = LeftValue;
422   Result.IndVarIncreasing = IsIncreasing;
423   Result.LoopExitAt = RightValue;
424   Result.IsSignedPredicate = IsSignedPredicate;
425   Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());
426 
427   FailureReason = nullptr;
428 
429   return Result;
430 }
431 
432 // Add metadata to the loop L to disable loop optimizations. Callers need to
433 // confirm that optimizing loop L is not beneficial.
434 static void DisableAllLoopOptsOnLoop(Loop &L) {
435   // We do not care about any existing loopID related metadata for L, since we
436   // are setting all loop metadata to false.
437   LLVMContext &Context = L.getHeader()->getContext();
438   // Reserve first location for self reference to the LoopID metadata node.
439   MDNode *Dummy = MDNode::get(Context, {});
440   MDNode *DisableUnroll = MDNode::get(
441       Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
442   Metadata *FalseVal =
443       ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
444   MDNode *DisableVectorize = MDNode::get(
445       Context,
446       {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
447   MDNode *DisableLICMVersioning = MDNode::get(
448       Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
449   MDNode *DisableDistribution = MDNode::get(
450       Context,
451       {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
452   MDNode *NewLoopID =
453       MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
454                             DisableLICMVersioning, DisableDistribution});
455   // Set operand 0 to refer to the loop id itself.
456   NewLoopID->replaceOperandWith(0, NewLoopID);
457   L.setLoopID(NewLoopID);
458 }
459 
460 LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
461                                  function_ref<void(Loop *, bool)> LPMAddNewLoop,
462                                  const LoopStructure &LS, ScalarEvolution &SE,
463                                  DominatorTree &DT, Type *T, SubRanges SR)
464     : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
465       DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
466       MainLoopStructure(LS), SR(SR) {}
467 
468 void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
469                                 const char *Tag) const {
470   for (BasicBlock *BB : OriginalLoop.getBlocks()) {
471     BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
472     Result.Blocks.push_back(Clone);
473     Result.Map[BB] = Clone;
474   }
475 
476   auto GetClonedValue = [&Result](Value *V) {
477     assert(V && "null values not in domain!");
478     auto It = Result.Map.find(V);
479     if (It == Result.Map.end())
480       return V;
481     return static_cast<Value *>(It->second);
482   };
483 
484   auto *ClonedLatch =
485       cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
486   ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
487                                             MDNode::get(Ctx, {}));
488 
489   Result.Structure = MainLoopStructure.map(GetClonedValue);
490   Result.Structure.Tag = Tag;
491 
492   for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
493     BasicBlock *ClonedBB = Result.Blocks[i];
494     BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
495 
496     assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
497 
498     for (Instruction &I : *ClonedBB)
499       RemapInstruction(&I, Result.Map,
500                        RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
501 
502     // Exit blocks will now have one more predecessor and their PHI nodes need
503     // to be edited to reflect that.  No phi nodes need to be introduced because
504     // the loop is in LCSSA.
505 
506     for (auto *SBB : successors(OriginalBB)) {
507       if (OriginalLoop.contains(SBB))
508         continue; // not an exit block
509 
510       for (PHINode &PN : SBB->phis()) {
511         Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
512         PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
513         SE.forgetValue(&PN);
514       }
515     }
516   }
517 }
518 
519 LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
520     const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
521     BasicBlock *ContinuationBlock) const {
522   // We start with a loop with a single latch:
523   //
524   //    +--------------------+
525   //    |                    |
526   //    |     preheader      |
527   //    |                    |
528   //    +--------+-----------+
529   //             |      ----------------\
530   //             |     /                |
531   //    +--------v----v------+          |
532   //    |                    |          |
533   //    |      header        |          |
534   //    |                    |          |
535   //    +--------------------+          |
536   //                                    |
537   //            .....                   |
538   //                                    |
539   //    +--------------------+          |
540   //    |                    |          |
541   //    |       latch        >----------/
542   //    |                    |
543   //    +-------v------------+
544   //            |
545   //            |
546   //            |   +--------------------+
547   //            |   |                    |
548   //            +--->   original exit    |
549   //                |                    |
550   //                +--------------------+
551   //
552   // We change the control flow to look like
553   //
554   //
555   //    +--------------------+
556   //    |                    |
557   //    |     preheader      >-------------------------+
558   //    |                    |                         |
559   //    +--------v-----------+                         |
560   //             |    /-------------+                  |
561   //             |   /              |                  |
562   //    +--------v--v--------+      |                  |
563   //    |                    |      |                  |
564   //    |      header        |      |   +--------+     |
565   //    |                    |      |   |        |     |
566   //    +--------------------+      |   |  +-----v-----v-----------+
567   //                                |   |  |                       |
568   //                                |   |  |     .pseudo.exit      |
569   //                                |   |  |                       |
570   //                                |   |  +-----------v-----------+
571   //                                |   |              |
572   //            .....               |   |              |
573   //                                |   |     +--------v-------------+
574   //    +--------------------+      |   |     |                      |
575   //    |                    |      |   |     |   ContinuationBlock  |
576   //    |       latch        >------+   |     |                      |
577   //    |                    |          |     +----------------------+
578   //    +---------v----------+          |
579   //              |                     |
580   //              |                     |
581   //              |     +---------------^-----+
582   //              |     |                     |
583   //              +----->    .exit.selector   |
584   //                    |                     |
585   //                    +----------v----------+
586   //                               |
587   //     +--------------------+    |
588   //     |                    |    |
589   //     |   original exit    <----+
590   //     |                    |
591   //     +--------------------+
592 
593   RewrittenRangeInfo RRI;
594 
595   BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
596   RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
597                                         &F, BBInsertLocation);
598   RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
599                                       BBInsertLocation);
600 
601   BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
602   bool Increasing = LS.IndVarIncreasing;
603   bool IsSignedPredicate = LS.IsSignedPredicate;
604 
605   IRBuilder<> B(PreheaderJump);
606   auto NoopOrExt = [&](Value *V) {
607     if (V->getType() == RangeTy)
608       return V;
609     return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
610                              : B.CreateZExt(V, RangeTy, "wide." + V->getName());
611   };
612 
613   // EnterLoopCond - is it okay to start executing this `LS'?
614   Value *EnterLoopCond = nullptr;
615   auto Pred =
616       Increasing
617           ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
618           : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
619   Value *IndVarStart = NoopOrExt(LS.IndVarStart);
620   EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);
621 
622   B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
623   PreheaderJump->eraseFromParent();
624 
625   LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
626   B.SetInsertPoint(LS.LatchBr);
627   Value *IndVarBase = NoopOrExt(LS.IndVarBase);
628   Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);
629 
630   Value *CondForBranch = LS.LatchBrExitIdx == 1
631                              ? TakeBackedgeLoopCond
632                              : B.CreateNot(TakeBackedgeLoopCond);
633 
634   LS.LatchBr->setCondition(CondForBranch);
635 
636   B.SetInsertPoint(RRI.ExitSelector);
637 
638   // IterationsLeft - are there any more iterations left, given the original
639   // upper bound on the induction variable?  If not, we branch to the "real"
640   // exit.
641   Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
642   Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
643   B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);
644 
645   BranchInst *BranchToContinuation =
646       BranchInst::Create(ContinuationBlock, RRI.PseudoExit);
647 
648   // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
649   // each of the PHI nodes in the loop header.  This feeds into the initial
650   // value of the same PHI nodes if/when we continue execution.
651   for (PHINode &PN : LS.Header->phis()) {
652     PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
653                                       BranchToContinuation->getIterator());
654 
655     NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
656     NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
657                         RRI.ExitSelector);
658     RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
659   }
660 
661   RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
662                                   BranchToContinuation->getIterator());
663   RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
664   RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);
665 
666   // The latch exit now has a branch from `RRI.ExitSelector' instead of
667   // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
668   LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);
669 
670   return RRI;
671 }
672 
673 void LoopConstrainer::rewriteIncomingValuesForPHIs(
674     LoopStructure &LS, BasicBlock *ContinuationBlock,
675     const LoopConstrainer::RewrittenRangeInfo &RRI) const {
676   unsigned PHIIndex = 0;
677   for (PHINode &PN : LS.Header->phis())
678     PN.setIncomingValueForBlock(ContinuationBlock,
679                                 RRI.PHIValuesAtPseudoExit[PHIIndex++]);
680 
681   LS.IndVarStart = RRI.IndVarEnd;
682 }
683 
684 BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
685                                              BasicBlock *OldPreheader,
686                                              const char *Tag) const {
687   BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
688   BranchInst::Create(LS.Header, Preheader);
689 
690   LS.Header->replacePhiUsesWith(OldPreheader, Preheader);
691 
692   return Preheader;
693 }
694 
695 void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
696   Loop *ParentLoop = OriginalLoop.getParentLoop();
697   if (!ParentLoop)
698     return;
699 
700   for (BasicBlock *BB : BBs)
701     ParentLoop->addBasicBlockToLoop(BB, LI);
702 }
703 
704 Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
705                                                  ValueToValueMapTy &VM,
706                                                  bool IsSubloop) {
707   Loop &New = *LI.AllocateLoop();
708   if (Parent)
709     Parent->addChildLoop(&New);
710   else
711     LI.addTopLevelLoop(&New);
712   LPMAddNewLoop(&New, IsSubloop);
713 
714   // Add all of the blocks in Original to the new loop.
715   for (auto *BB : Original->blocks())
716     if (LI.getLoopFor(BB) == Original)
717       New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);
718 
719   // Add all of the subloops to the new loop.
720   for (Loop *SubLoop : *Original)
721     createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);
722 
723   return &New;
724 }
725 
/// Perform the loop-constraining transformation on OriginalLoop.
///
/// Depending on which of SR.LowLimit / SR.HighLimit are present, the loop is
/// cloned into an optional "preloop" and/or "postloop"; the main loop's (and
/// pre-loop's) iteration space is then cut at exit limits expanded in the
/// original preheader. Afterwards the dominator tree is recalculated, the
/// clones are registered in LoopInfo, and all resulting loops are put into
/// LCSSA and loop-simplify form.
///
/// Requires OriginalLoop to have a preheader (asserted).
///
/// \returns false if an exit limit cannot be formed without a possible
/// overflow, or cannot be safely expanded at the preheader terminator; true
/// once the transformation has been carried out.
bool LoopConstrainer::run() {
  BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
  assert(Preheader != nullptr && "precondition!");

  OriginalPreheader = Preheader;
  // Until a pre-loop is inserted, the main loop is entered straight from the
  // original preheader.
  MainLoopPreheader = Preheader;
  bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
  bool Increasing = MainLoopStructure.IndVarIncreasing;
  IntegerType *IVTy = cast<IntegerType>(RangeTy);

  // Exit limits are expanded at the terminator of the original preheader.
  SCEVExpander Expander(SE, F.getDataLayout(), "loop-constrainer");
  Instruction *InsertPt = OriginalPreheader->getTerminator();

  // It would have been better to make `PreLoop' and `PostLoop'
  // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
  // constructor.
  ClonedLoop PreLoop, PostLoop;
  // For an increasing induction variable the pre-loop is driven by LowLimit
  // and the post-loop by HighLimit; for a decreasing one the roles swap.
  bool NeedsPreLoop =
      Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
  bool NeedsPostLoop =
      Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();

  Value *ExitPreLoopAt = nullptr;
  Value *ExitMainLoopAt = nullptr;
  // Constant -1 of the induction variable type; added to a limit in the
  // decreasing-IV cases below.
  const SCEVConstant *MinusOneS =
      cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));

  if (NeedsPreLoop) {
    const SCEV *ExitPreLoopAtSCEV = nullptr;

    // Increasing: exit the pre-loop at LowLimit. Decreasing: exit at
    // HighLimit - 1, which is only formed when HighLimit is known not to be
    // the minimum value (otherwise the subtraction could overflow).
    if (Increasing)
      ExitPreLoopAtSCEV = *SR.LowLimit;
    else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "preloop exit limit.  HighLimit = "
                        << *(*SR.HighLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " preloop exit limit " << *ExitPreLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
    ExitPreLoopAt->setName("exit.preloop.at");
  }

  if (NeedsPostLoop) {
    const SCEV *ExitMainLoopAtSCEV = nullptr;

    // Mirror of the pre-loop case: HighLimit for an increasing IV,
    // LowLimit - 1 (guarded against overflow) for a decreasing one.
    // NOTE(review): the early returns below can be taken after the pre-loop
    // limit has already been expanded above, so code may have been emitted
    // into the preheader even though this function reports false -- confirm
    // that callers tolerate this.
    if (Increasing)
      ExitMainLoopAtSCEV = *SR.HighLimit;
    else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "mainloop exit limit.  LowLimit = "
                        << *(*SR.LowLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " main loop exit limit " << *ExitMainLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
    ExitMainLoopAt->setName("exit.mainloop.at");
  }

  // We clone these ahead of time so that we don't have to deal with changing
  // and temporarily invalid IR as we transform the loops.
  if (NeedsPreLoop)
    cloneLoop(PreLoop, "preloop");
  if (NeedsPostLoop)
    cloneLoop(PostLoop, "postloop");

  RewrittenRangeInfo PreLoopRRI;

  if (NeedsPreLoop) {
    // Enter the pre-loop instead of the main loop from the original
    // preheader.
    Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
                                                  PreLoop.Structure.Header);

    // The main loop gets a fresh preheader, which also serves as the
    // continuation block passed to changeIterationSpaceEnd for the pre-loop.
    MainLoopPreheader =
        createPreheader(MainLoopStructure, Preheader, "mainloop");
    PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
                                         ExitPreLoopAt, MainLoopPreheader);
    // The main loop's header PHIs now start from the pre-loop's exit values.
    rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
                                 PreLoopRRI);
  }

  BasicBlock *PostLoopPreheader = nullptr;
  RewrittenRangeInfo PostLoopRRI;

  if (NeedsPostLoop) {
    // Same wiring as above, with the main loop handing over to the post-loop
    // through a new "postloop" preheader.
    PostLoopPreheader =
        createPreheader(PostLoop.Structure, Preheader, "postloop");
    PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
                                          ExitMainLoopAt, PostLoopPreheader);
    rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
                                 PostLoopRRI);
  }

  // The main loop preheader only counts as a new block if it was replaced
  // above (i.e. a pre-loop was created).
  BasicBlock *NewMainLoopPreheader =
      MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
  BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
                             PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
                             PostLoopRRI.ExitSelector, NewMainLoopPreheader};

  // Some of the above may be nullptr, filter them out before passing to
  // addToParentLoopIfNeeded.
  auto NewBlocksEnd =
      std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);

  addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));

  // Recompute the dominator tree for the whole function after the CFG edits.
  DT.recalculate(F);

  // We need to first add all the pre and post loop blocks into the loop
  // structures (as part of createClonedLoopStructure), and then update the
  // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
  // LI when LoopSimplifyForm is generated.
  Loop *PreL = nullptr, *PostL = nullptr;
  if (!PreLoop.Blocks.empty()) {
    PreL = createClonedLoopStructure(&OriginalLoop,
                                     OriginalLoop.getParentLoop(), PreLoop.Map,
                                     /* IsSubLoop */ false);
  }

  if (!PostLoop.Blocks.empty()) {
    PostL =
        createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
                                  PostLoop.Map, /* IsSubLoop */ false);
  }

  // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
  auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
    formLCSSARecursively(*L, DT, &LI, &SE);
    simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
    // Pre/post loops are slow paths, we do not need to perform any loop
    // optimizations on them.
    if (!IsOriginalLoop)
      DisableAllLoopOptsOnLoop(*L);
  };
  if (PreL)
    CanonicalizeLoop(PreL, false);
  if (PostL)
    CanonicalizeLoop(PostL, false);
  CanonicalizeLoop(&OriginalLoop, true);

  /// At this point:
  /// - We've broken a "main loop" out of the loop in a way that the "main loop"
  /// runs with the induction variable in a subset of [Begin, End).
  /// - There is no overflow when computing "main loop" exit limit.
  /// - Max latch taken count of the loop is limited.
  /// It guarantees that induction variable will not overflow iterating in the
  /// "main loop".
  // Hence the NSW flag can be set on the IV base for a signed predicate.
  if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
    if (IsSignedPredicate)
      cast<BinaryOperator>(MainLoopStructure.IndVarBase)
          ->setHasNoSignedWrap(true);
  /// TODO: support unsigned predicate.
  /// To add NUW flag we need to prove that both operands of BO are
  /// non-negative. E.g:
  /// ...
  /// %iv.next = add nsw i32 %iv, -1
  /// %cmp = icmp ult i32 %iv.next, %n
  /// br i1 %cmp, label %loopexit, label %loop
  ///
  /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
  /// overflow, therefore NUW flag is not legal here.

  return true;
}
911