xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp (revision 5b56413d04e608379c9a306373554a8e4d321bc0)
1 #include "llvm/Transforms/Utils/LoopConstrainer.h"
2 #include "llvm/Analysis/LoopInfo.h"
3 #include "llvm/Analysis/ScalarEvolution.h"
4 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
5 #include "llvm/IR/Dominators.h"
6 #include "llvm/Transforms/Utils/Cloning.h"
7 #include "llvm/Transforms/Utils/LoopSimplify.h"
8 #include "llvm/Transforms/Utils/LoopUtils.h"
9 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
10 
11 using namespace llvm;
12 
/// Metadata kind attached to the latch terminator of every loop produced by
/// LoopConstrainer::cloneLoop. parseLoopStructure rejects loops whose latch
/// terminator carries it, so an already-cloned loop is not processed again.
/// The pointer itself is const so the tag cannot be re-pointed at runtime.
static const char *const ClonedLoopTag = "loop_constrainer.loop.clone";

#define DEBUG_TYPE "loop-constrainer"
16 
17 /// Given a loop with an deccreasing induction variable, is it possible to
18 /// safely calculate the bounds of a new loop using the given Predicate.
19 static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
20                                   const SCEV *Step, ICmpInst::Predicate Pred,
21                                   unsigned LatchBrExitIdx, Loop *L,
22                                   ScalarEvolution &SE) {
23   if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
24       Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
25     return false;
26 
27   if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
28     return false;
29 
30   assert(SE.isKnownNegative(Step) && "expecting negative step");
31 
32   LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
33   LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
34   LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
35   LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
36   LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
37   LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
38 
39   bool IsSigned = ICmpInst::isSigned(Pred);
40   // The predicate that we need to check that the induction variable lies
41   // within bounds.
42   ICmpInst::Predicate BoundPred =
43       IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
44 
45   if (LatchBrExitIdx == 1)
46     return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
47 
48   assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");
49 
50   const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
51   unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
52   APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)
53                        : APInt::getMinValue(BitWidth);
54   const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
55 
56   const SCEV *MinusOne =
57       SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType()));
58 
59   return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) &&
60          SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit);
61 }
62 
63 /// Given a loop with an increasing induction variable, is it possible to
64 /// safely calculate the bounds of a new loop using the given Predicate.
65 static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
66                                   const SCEV *Step, ICmpInst::Predicate Pred,
67                                   unsigned LatchBrExitIdx, Loop *L,
68                                   ScalarEvolution &SE) {
69   if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
70       Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
71     return false;
72 
73   if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
74     return false;
75 
76   LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
77   LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
78   LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
79   LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
80   LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
81   LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
82 
83   bool IsSigned = ICmpInst::isSigned(Pred);
84   // The predicate that we need to check that the induction variable lies
85   // within bounds.
86   ICmpInst::Predicate BoundPred =
87       IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
88 
89   if (LatchBrExitIdx == 1)
90     return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
91 
92   assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
93 
94   const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
95   unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
96   APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)
97                        : APInt::getMaxValue(BitWidth);
98   const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
99 
100   return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start,
101                                       SE.getAddExpr(BoundSCEV, Step)) &&
102           SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit));
103 }
104 
105 /// Returns estimate for max latch taken count of the loop of the narrowest
106 /// available type. If the latch block has such estimate, it is returned.
107 /// Otherwise, we use max exit count of whole loop (that is potentially of wider
108 /// type than latch check itself), which is still better than no estimate.
109 static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
110                                                           const Loop &L) {
111   const SCEV *FromBlock =
112       SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
113   if (isa<SCEVCouldNotCompute>(FromBlock))
114     return SE.getSymbolicMaxBackedgeTakenCount(&L);
115   return FromBlock;
116 }
117 
118 std::optional<LoopStructure>
119 LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
120                                   bool AllowUnsignedLatchCond,
121                                   const char *&FailureReason) {
122   if (!L.isLoopSimplifyForm()) {
123     FailureReason = "loop not in LoopSimplify form";
124     return std::nullopt;
125   }
126 
127   BasicBlock *Latch = L.getLoopLatch();
128   assert(Latch && "Simplified loops only have one latch!");
129 
130   if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
131     FailureReason = "loop has already been cloned";
132     return std::nullopt;
133   }
134 
135   if (!L.isLoopExiting(Latch)) {
136     FailureReason = "no loop latch";
137     return std::nullopt;
138   }
139 
140   BasicBlock *Header = L.getHeader();
141   BasicBlock *Preheader = L.getLoopPreheader();
142   if (!Preheader) {
143     FailureReason = "no preheader";
144     return std::nullopt;
145   }
146 
147   BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
148   if (!LatchBr || LatchBr->isUnconditional()) {
149     FailureReason = "latch terminator not conditional branch";
150     return std::nullopt;
151   }
152 
153   unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;
154 
155   ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
156   if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
157     FailureReason = "latch terminator branch not conditional on integral icmp";
158     return std::nullopt;
159   }
160 
161   const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
162   if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
163     FailureReason = "could not compute latch count";
164     return std::nullopt;
165   }
166   assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
167              ScalarEvolution::LoopInvariant &&
168          "loop variant exit count doesn't make sense!");
169 
170   ICmpInst::Predicate Pred = ICI->getPredicate();
171   Value *LeftValue = ICI->getOperand(0);
172   const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
173   IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());
174 
175   Value *RightValue = ICI->getOperand(1);
176   const SCEV *RightSCEV = SE.getSCEV(RightValue);
177 
178   // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
179   if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
180     if (isa<SCEVAddRecExpr>(RightSCEV)) {
181       std::swap(LeftSCEV, RightSCEV);
182       std::swap(LeftValue, RightValue);
183       Pred = ICmpInst::getSwappedPredicate(Pred);
184     } else {
185       FailureReason = "no add recurrences in the icmp";
186       return std::nullopt;
187     }
188   }
189 
190   auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
191     if (AR->getNoWrapFlags(SCEV::FlagNSW))
192       return true;
193 
194     IntegerType *Ty = cast<IntegerType>(AR->getType());
195     IntegerType *WideTy =
196         IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
197 
198     const SCEVAddRecExpr *ExtendAfterOp =
199         dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
200     if (ExtendAfterOp) {
201       const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
202       const SCEV *ExtendedStep =
203           SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
204 
205       bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
206                           ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
207 
208       if (NoSignedWrap)
209         return true;
210     }
211 
212     // We may have proved this when computing the sign extension above.
213     return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
214   };
215 
216   // `ICI` is interpreted as taking the backedge if the *next* value of the
217   // induction variable satisfies some constraint.
218 
219   const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
220   if (IndVarBase->getLoop() != &L) {
221     FailureReason = "LHS in cmp is not an AddRec for this loop";
222     return std::nullopt;
223   }
224   if (!IndVarBase->isAffine()) {
225     FailureReason = "LHS in icmp not induction variable";
226     return std::nullopt;
227   }
228   const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
229   if (!isa<SCEVConstant>(StepRec)) {
230     FailureReason = "LHS in icmp not induction variable";
231     return std::nullopt;
232   }
233   ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();
234 
235   if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
236     FailureReason = "LHS in icmp needs nsw for equality predicates";
237     return std::nullopt;
238   }
239 
240   assert(!StepCI->isZero() && "Zero step?");
241   bool IsIncreasing = !StepCI->isNegative();
242   bool IsSignedPredicate;
243   const SCEV *StartNext = IndVarBase->getStart();
244   const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
245   const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
246   const SCEV *Step = SE.getSCEV(StepCI);
247 
248   const SCEV *FixedRightSCEV = nullptr;
249 
250   // If RightValue resides within loop (but still being loop invariant),
251   // regenerate it as preheader.
252   if (auto *I = dyn_cast<Instruction>(RightValue))
253     if (L.contains(I->getParent()))
254       FixedRightSCEV = RightSCEV;
255 
256   if (IsIncreasing) {
257     bool DecreasedRightValueByOne = false;
258     if (StepCI->isOne()) {
259       // Try to turn eq/ne predicates to those we can work with.
260       if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
261         // while (++i != len) {         while (++i < len) {
262         //   ...                 --->     ...
263         // }                            }
264         // If both parts are known non-negative, it is profitable to use
265         // unsigned comparison in increasing loop. This allows us to make the
266         // comparison check against "RightSCEV + 1" more optimistic.
267         if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
268             isKnownNonNegativeInLoop(RightSCEV, &L, SE))
269           Pred = ICmpInst::ICMP_ULT;
270         else
271           Pred = ICmpInst::ICMP_SLT;
272       else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
273         // while (true) {               while (true) {
274         //   if (++i == len)     --->     if (++i > len - 1)
275         //     break;                       break;
276         //   ...                          ...
277         // }                            }
278         if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
279             cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
280           Pred = ICmpInst::ICMP_UGT;
281           RightSCEV =
282               SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
283           DecreasedRightValueByOne = true;
284         } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
285           Pred = ICmpInst::ICMP_SGT;
286           RightSCEV =
287               SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
288           DecreasedRightValueByOne = true;
289         }
290       }
291     }
292 
293     bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
294     bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
295     bool FoundExpectedPred =
296         (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);
297 
298     if (!FoundExpectedPred) {
299       FailureReason = "expected icmp slt semantically, found something else";
300       return std::nullopt;
301     }
302 
303     IsSignedPredicate = ICmpInst::isSigned(Pred);
304     if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
305       FailureReason = "unsigned latch conditions are explicitly prohibited";
306       return std::nullopt;
307     }
308 
309     if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
310                                LatchBrExitIdx, &L, SE)) {
311       FailureReason = "Unsafe loop bounds";
312       return std::nullopt;
313     }
314     if (LatchBrExitIdx == 0) {
315       // We need to increase the right value unless we have already decreased
316       // it virtually when we replaced EQ with SGT.
317       if (!DecreasedRightValueByOne)
318         FixedRightSCEV =
319             SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
320     } else {
321       assert(!DecreasedRightValueByOne &&
322              "Right value can be decreased only for LatchBrExitIdx == 0!");
323     }
324   } else {
325     bool IncreasedRightValueByOne = false;
326     if (StepCI->isMinusOne()) {
327       // Try to turn eq/ne predicates to those we can work with.
328       if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
329         // while (--i != len) {         while (--i > len) {
330         //   ...                 --->     ...
331         // }                            }
332         // We intentionally don't turn the predicate into UGT even if we know
333         // that both operands are non-negative, because it will only pessimize
334         // our check against "RightSCEV - 1".
335         Pred = ICmpInst::ICMP_SGT;
336       else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
337         // while (true) {               while (true) {
338         //   if (--i == len)     --->     if (--i < len + 1)
339         //     break;                       break;
340         //   ...                          ...
341         // }                            }
342         if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
343             cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
344           Pred = ICmpInst::ICMP_ULT;
345           RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
346           IncreasedRightValueByOne = true;
347         } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
348           Pred = ICmpInst::ICMP_SLT;
349           RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
350           IncreasedRightValueByOne = true;
351         }
352       }
353     }
354 
355     bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
356     bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
357 
358     bool FoundExpectedPred =
359         (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);
360 
361     if (!FoundExpectedPred) {
362       FailureReason = "expected icmp sgt semantically, found something else";
363       return std::nullopt;
364     }
365 
366     IsSignedPredicate =
367         Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;
368 
369     if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
370       FailureReason = "unsigned latch conditions are explicitly prohibited";
371       return std::nullopt;
372     }
373 
374     if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
375                                LatchBrExitIdx, &L, SE)) {
376       FailureReason = "Unsafe bounds";
377       return std::nullopt;
378     }
379 
380     if (LatchBrExitIdx == 0) {
381       // We need to decrease the right value unless we have already increased
382       // it virtually when we replaced EQ with SLT.
383       if (!IncreasedRightValueByOne)
384         FixedRightSCEV =
385             SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
386     } else {
387       assert(!IncreasedRightValueByOne &&
388              "Right value can be increased only for LatchBrExitIdx == 0!");
389     }
390   }
391   BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
392 
393   assert(!L.contains(LatchExit) && "expected an exit block!");
394   const DataLayout &DL = Preheader->getModule()->getDataLayout();
395   SCEVExpander Expander(SE, DL, "loop-constrainer");
396   Instruction *Ins = Preheader->getTerminator();
397 
398   if (FixedRightSCEV)
399     RightValue =
400         Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);
401 
402   Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
403   IndVarStartV->setName("indvar.start");
404 
405   LoopStructure Result;
406 
407   Result.Tag = "main";
408   Result.Header = Header;
409   Result.Latch = Latch;
410   Result.LatchBr = LatchBr;
411   Result.LatchExit = LatchExit;
412   Result.LatchBrExitIdx = LatchBrExitIdx;
413   Result.IndVarStart = IndVarStartV;
414   Result.IndVarStep = StepCI;
415   Result.IndVarBase = LeftValue;
416   Result.IndVarIncreasing = IsIncreasing;
417   Result.LoopExitAt = RightValue;
418   Result.IsSignedPredicate = IsSignedPredicate;
419   Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());
420 
421   FailureReason = nullptr;
422 
423   return Result;
424 }
425 
426 // Add metadata to the loop L to disable loop optimizations. Callers need to
427 // confirm that optimizing loop L is not beneficial.
428 static void DisableAllLoopOptsOnLoop(Loop &L) {
429   // We do not care about any existing loopID related metadata for L, since we
430   // are setting all loop metadata to false.
431   LLVMContext &Context = L.getHeader()->getContext();
432   // Reserve first location for self reference to the LoopID metadata node.
433   MDNode *Dummy = MDNode::get(Context, {});
434   MDNode *DisableUnroll = MDNode::get(
435       Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
436   Metadata *FalseVal =
437       ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
438   MDNode *DisableVectorize = MDNode::get(
439       Context,
440       {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
441   MDNode *DisableLICMVersioning = MDNode::get(
442       Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
443   MDNode *DisableDistribution = MDNode::get(
444       Context,
445       {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
446   MDNode *NewLoopID =
447       MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
448                             DisableLICMVersioning, DisableDistribution});
449   // Set operand 0 to refer to the loop id itself.
450   NewLoopID->replaceOperandWith(0, NewLoopID);
451   L.setLoopID(NewLoopID);
452 }
453 
454 LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
455                                  function_ref<void(Loop *, bool)> LPMAddNewLoop,
456                                  const LoopStructure &LS, ScalarEvolution &SE,
457                                  DominatorTree &DT, Type *T, SubRanges SR)
458     : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
459       DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
460       MainLoopStructure(LS), SR(SR) {}
461 
462 void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
463                                 const char *Tag) const {
464   for (BasicBlock *BB : OriginalLoop.getBlocks()) {
465     BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
466     Result.Blocks.push_back(Clone);
467     Result.Map[BB] = Clone;
468   }
469 
470   auto GetClonedValue = [&Result](Value *V) {
471     assert(V && "null values not in domain!");
472     auto It = Result.Map.find(V);
473     if (It == Result.Map.end())
474       return V;
475     return static_cast<Value *>(It->second);
476   };
477 
478   auto *ClonedLatch =
479       cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
480   ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
481                                             MDNode::get(Ctx, {}));
482 
483   Result.Structure = MainLoopStructure.map(GetClonedValue);
484   Result.Structure.Tag = Tag;
485 
486   for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
487     BasicBlock *ClonedBB = Result.Blocks[i];
488     BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
489 
490     assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
491 
492     for (Instruction &I : *ClonedBB)
493       RemapInstruction(&I, Result.Map,
494                        RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
495 
496     // Exit blocks will now have one more predecessor and their PHI nodes need
497     // to be edited to reflect that.  No phi nodes need to be introduced because
498     // the loop is in LCSSA.
499 
500     for (auto *SBB : successors(OriginalBB)) {
501       if (OriginalLoop.contains(SBB))
502         continue; // not an exit block
503 
504       for (PHINode &PN : SBB->phis()) {
505         Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
506         PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
507         SE.forgetValue(&PN);
508       }
509     }
510   }
511 }
512 
513 LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
514     const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
515     BasicBlock *ContinuationBlock) const {
516   // We start with a loop with a single latch:
517   //
518   //    +--------------------+
519   //    |                    |
520   //    |     preheader      |
521   //    |                    |
522   //    +--------+-----------+
523   //             |      ----------------\
524   //             |     /                |
525   //    +--------v----v------+          |
526   //    |                    |          |
527   //    |      header        |          |
528   //    |                    |          |
529   //    +--------------------+          |
530   //                                    |
531   //            .....                   |
532   //                                    |
533   //    +--------------------+          |
534   //    |                    |          |
535   //    |       latch        >----------/
536   //    |                    |
537   //    +-------v------------+
538   //            |
539   //            |
540   //            |   +--------------------+
541   //            |   |                    |
542   //            +--->   original exit    |
543   //                |                    |
544   //                +--------------------+
545   //
546   // We change the control flow to look like
547   //
548   //
549   //    +--------------------+
550   //    |                    |
551   //    |     preheader      >-------------------------+
552   //    |                    |                         |
553   //    +--------v-----------+                         |
554   //             |    /-------------+                  |
555   //             |   /              |                  |
556   //    +--------v--v--------+      |                  |
557   //    |                    |      |                  |
558   //    |      header        |      |   +--------+     |
559   //    |                    |      |   |        |     |
560   //    +--------------------+      |   |  +-----v-----v-----------+
561   //                                |   |  |                       |
562   //                                |   |  |     .pseudo.exit      |
563   //                                |   |  |                       |
564   //                                |   |  +-----------v-----------+
565   //                                |   |              |
566   //            .....               |   |              |
567   //                                |   |     +--------v-------------+
568   //    +--------------------+      |   |     |                      |
569   //    |                    |      |   |     |   ContinuationBlock  |
570   //    |       latch        >------+   |     |                      |
571   //    |                    |          |     +----------------------+
572   //    +---------v----------+          |
573   //              |                     |
574   //              |                     |
575   //              |     +---------------^-----+
576   //              |     |                     |
577   //              +----->    .exit.selector   |
578   //                    |                     |
579   //                    +----------v----------+
580   //                               |
581   //     +--------------------+    |
582   //     |                    |    |
583   //     |   original exit    <----+
584   //     |                    |
585   //     +--------------------+
586 
587   RewrittenRangeInfo RRI;
588 
589   BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
590   RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
591                                         &F, BBInsertLocation);
592   RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
593                                       BBInsertLocation);
594 
595   BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
596   bool Increasing = LS.IndVarIncreasing;
597   bool IsSignedPredicate = LS.IsSignedPredicate;
598 
599   IRBuilder<> B(PreheaderJump);
600   auto NoopOrExt = [&](Value *V) {
601     if (V->getType() == RangeTy)
602       return V;
603     return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
604                              : B.CreateZExt(V, RangeTy, "wide." + V->getName());
605   };
606 
607   // EnterLoopCond - is it okay to start executing this `LS'?
608   Value *EnterLoopCond = nullptr;
609   auto Pred =
610       Increasing
611           ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
612           : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
613   Value *IndVarStart = NoopOrExt(LS.IndVarStart);
614   EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);
615 
616   B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
617   PreheaderJump->eraseFromParent();
618 
619   LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
620   B.SetInsertPoint(LS.LatchBr);
621   Value *IndVarBase = NoopOrExt(LS.IndVarBase);
622   Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);
623 
624   Value *CondForBranch = LS.LatchBrExitIdx == 1
625                              ? TakeBackedgeLoopCond
626                              : B.CreateNot(TakeBackedgeLoopCond);
627 
628   LS.LatchBr->setCondition(CondForBranch);
629 
630   B.SetInsertPoint(RRI.ExitSelector);
631 
632   // IterationsLeft - are there any more iterations left, given the original
633   // upper bound on the induction variable?  If not, we branch to the "real"
634   // exit.
635   Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
636   Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
637   B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);
638 
639   BranchInst *BranchToContinuation =
640       BranchInst::Create(ContinuationBlock, RRI.PseudoExit);
641 
642   // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
643   // each of the PHI nodes in the loop header.  This feeds into the initial
644   // value of the same PHI nodes if/when we continue execution.
645   for (PHINode &PN : LS.Header->phis()) {
646     PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
647                                       BranchToContinuation);
648 
649     NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
650     NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
651                         RRI.ExitSelector);
652     RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
653   }
654 
655   RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
656                                   BranchToContinuation);
657   RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
658   RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);
659 
660   // The latch exit now has a branch from `RRI.ExitSelector' instead of
661   // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
662   LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);
663 
664   return RRI;
665 }
666 
667 void LoopConstrainer::rewriteIncomingValuesForPHIs(
668     LoopStructure &LS, BasicBlock *ContinuationBlock,
669     const LoopConstrainer::RewrittenRangeInfo &RRI) const {
670   unsigned PHIIndex = 0;
671   for (PHINode &PN : LS.Header->phis())
672     PN.setIncomingValueForBlock(ContinuationBlock,
673                                 RRI.PHIValuesAtPseudoExit[PHIIndex++]);
674 
675   LS.IndVarStart = RRI.IndVarEnd;
676 }
677 
678 BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
679                                              BasicBlock *OldPreheader,
680                                              const char *Tag) const {
681   BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
682   BranchInst::Create(LS.Header, Preheader);
683 
684   LS.Header->replacePhiUsesWith(OldPreheader, Preheader);
685 
686   return Preheader;
687 }
688 
689 void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
690   Loop *ParentLoop = OriginalLoop.getParentLoop();
691   if (!ParentLoop)
692     return;
693 
694   for (BasicBlock *BB : BBs)
695     ParentLoop->addBasicBlockToLoop(BB, LI);
696 }
697 
698 Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
699                                                  ValueToValueMapTy &VM,
700                                                  bool IsSubloop) {
701   Loop &New = *LI.AllocateLoop();
702   if (Parent)
703     Parent->addChildLoop(&New);
704   else
705     LI.addTopLevelLoop(&New);
706   LPMAddNewLoop(&New, IsSubloop);
707 
708   // Add all of the blocks in Original to the new loop.
709   for (auto *BB : Original->blocks())
710     if (LI.getLoopFor(BB) == Original)
711       New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);
712 
713   // Add all of the subloops to the new loop.
714   for (Loop *SubLoop : *Original)
715     createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);
716 
717   return &New;
718 }
719 
720 bool LoopConstrainer::run() {
721   BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
722   assert(Preheader != nullptr && "precondition!");
723 
724   OriginalPreheader = Preheader;
725   MainLoopPreheader = Preheader;
726   bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
727   bool Increasing = MainLoopStructure.IndVarIncreasing;
728   IntegerType *IVTy = cast<IntegerType>(RangeTy);
729 
730   SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "loop-constrainer");
731   Instruction *InsertPt = OriginalPreheader->getTerminator();
732 
733   // It would have been better to make `PreLoop' and `PostLoop'
734   // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
735   // constructor.
736   ClonedLoop PreLoop, PostLoop;
737   bool NeedsPreLoop =
738       Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
739   bool NeedsPostLoop =
740       Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();
741 
742   Value *ExitPreLoopAt = nullptr;
743   Value *ExitMainLoopAt = nullptr;
744   const SCEVConstant *MinusOneS =
745       cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
746 
747   if (NeedsPreLoop) {
748     const SCEV *ExitPreLoopAtSCEV = nullptr;
749 
750     if (Increasing)
751       ExitPreLoopAtSCEV = *SR.LowLimit;
752     else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
753                                IsSignedPredicate))
754       ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
755     else {
756       LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
757                         << "preloop exit limit.  HighLimit = "
758                         << *(*SR.HighLimit) << "\n");
759       return false;
760     }
761 
762     if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
763       LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
764                         << " preloop exit limit " << *ExitPreLoopAtSCEV
765                         << " at block " << InsertPt->getParent()->getName()
766                         << "\n");
767       return false;
768     }
769 
770     ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
771     ExitPreLoopAt->setName("exit.preloop.at");
772   }
773 
774   if (NeedsPostLoop) {
775     const SCEV *ExitMainLoopAtSCEV = nullptr;
776 
777     if (Increasing)
778       ExitMainLoopAtSCEV = *SR.HighLimit;
779     else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
780                                IsSignedPredicate))
781       ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
782     else {
783       LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
784                         << "mainloop exit limit.  LowLimit = "
785                         << *(*SR.LowLimit) << "\n");
786       return false;
787     }
788 
789     if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
790       LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
791                         << " main loop exit limit " << *ExitMainLoopAtSCEV
792                         << " at block " << InsertPt->getParent()->getName()
793                         << "\n");
794       return false;
795     }
796 
797     ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
798     ExitMainLoopAt->setName("exit.mainloop.at");
799   }
800 
801   // We clone these ahead of time so that we don't have to deal with changing
802   // and temporarily invalid IR as we transform the loops.
803   if (NeedsPreLoop)
804     cloneLoop(PreLoop, "preloop");
805   if (NeedsPostLoop)
806     cloneLoop(PostLoop, "postloop");
807 
808   RewrittenRangeInfo PreLoopRRI;
809 
810   if (NeedsPreLoop) {
811     Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
812                                                   PreLoop.Structure.Header);
813 
814     MainLoopPreheader =
815         createPreheader(MainLoopStructure, Preheader, "mainloop");
816     PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
817                                          ExitPreLoopAt, MainLoopPreheader);
818     rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
819                                  PreLoopRRI);
820   }
821 
822   BasicBlock *PostLoopPreheader = nullptr;
823   RewrittenRangeInfo PostLoopRRI;
824 
825   if (NeedsPostLoop) {
826     PostLoopPreheader =
827         createPreheader(PostLoop.Structure, Preheader, "postloop");
828     PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
829                                           ExitMainLoopAt, PostLoopPreheader);
830     rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
831                                  PostLoopRRI);
832   }
833 
834   BasicBlock *NewMainLoopPreheader =
835       MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
836   BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
837                              PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
838                              PostLoopRRI.ExitSelector, NewMainLoopPreheader};
839 
840   // Some of the above may be nullptr, filter them out before passing to
841   // addToParentLoopIfNeeded.
842   auto NewBlocksEnd =
843       std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
844 
845   addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));
846 
847   DT.recalculate(F);
848 
849   // We need to first add all the pre and post loop blocks into the loop
850   // structures (as part of createClonedLoopStructure), and then update the
851   // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
852   // LI when LoopSimplifyForm is generated.
853   Loop *PreL = nullptr, *PostL = nullptr;
854   if (!PreLoop.Blocks.empty()) {
855     PreL = createClonedLoopStructure(&OriginalLoop,
856                                      OriginalLoop.getParentLoop(), PreLoop.Map,
857                                      /* IsSubLoop */ false);
858   }
859 
860   if (!PostLoop.Blocks.empty()) {
861     PostL =
862         createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
863                                   PostLoop.Map, /* IsSubLoop */ false);
864   }
865 
866   // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
867   auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
868     formLCSSARecursively(*L, DT, &LI, &SE);
869     simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
870     // Pre/post loops are slow paths, we do not need to perform any loop
871     // optimizations on them.
872     if (!IsOriginalLoop)
873       DisableAllLoopOptsOnLoop(*L);
874   };
875   if (PreL)
876     CanonicalizeLoop(PreL, false);
877   if (PostL)
878     CanonicalizeLoop(PostL, false);
879   CanonicalizeLoop(&OriginalLoop, true);
880 
881   /// At this point:
882   /// - We've broken a "main loop" out of the loop in a way that the "main loop"
883   /// runs with the induction variable in a subset of [Begin, End).
884   /// - There is no overflow when computing "main loop" exit limit.
885   /// - Max latch taken count of the loop is limited.
886   /// It guarantees that induction variable will not overflow iterating in the
887   /// "main loop".
888   if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
889     if (IsSignedPredicate)
890       cast<BinaryOperator>(MainLoopStructure.IndVarBase)
891           ->setHasNoSignedWrap(true);
892   /// TODO: support unsigned predicate.
893   /// To add NUW flag we need to prove that both operands of BO are
894   /// non-negative. E.g:
895   /// ...
896   /// %iv.next = add nsw i32 %iv, -1
897   /// %cmp = icmp ult i32 %iv.next, %n
898   /// br i1 %cmp, label %loopexit, label %loop
899   ///
900   /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
901   /// overflow, therefore NUW flag is not legal here.
902 
903   return true;
904 }
905