xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp (revision 5ca8e32633c4ffbbcd6762e5888b6a4ba0708c6c)
1 //===- InductiveRangeCheckElimination.cpp - -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The InductiveRangeCheckElimination pass splits a loop's iteration space into
// three disjoint ranges.  It does so in such a way that the loop covering the
// middle range provably does not need range checks. As an example, it will
12 // convert
13 //
14 //   len = < known positive >
15 //   for (i = 0; i < n; i++) {
16 //     if (0 <= i && i < len) {
17 //       do_something();
18 //     } else {
19 //       throw_out_of_bounds();
20 //     }
21 //   }
22 //
23 // to
24 //
25 //   len = < known positive >
26 //   limit = smin(n, len)
27 //   // no first segment
28 //   for (i = 0; i < limit; i++) {
29 //     if (0 <= i && i < len) { // this check is fully redundant
30 //       do_something();
31 //     } else {
32 //       throw_out_of_bounds();
33 //     }
34 //   }
35 //   for (i = limit; i < n; i++) {
36 //     if (0 <= i && i < len) {
37 //       do_something();
38 //     } else {
39 //       throw_out_of_bounds();
40 //     }
41 //   }
42 //
43 //===----------------------------------------------------------------------===//
44 
45 #include "llvm/Transforms/Scalar/InductiveRangeCheckElimination.h"
46 #include "llvm/ADT/APInt.h"
47 #include "llvm/ADT/ArrayRef.h"
48 #include "llvm/ADT/PriorityWorklist.h"
49 #include "llvm/ADT/SmallPtrSet.h"
50 #include "llvm/ADT/SmallVector.h"
51 #include "llvm/ADT/StringRef.h"
52 #include "llvm/ADT/Twine.h"
53 #include "llvm/Analysis/BlockFrequencyInfo.h"
54 #include "llvm/Analysis/BranchProbabilityInfo.h"
55 #include "llvm/Analysis/LoopAnalysisManager.h"
56 #include "llvm/Analysis/LoopInfo.h"
57 #include "llvm/Analysis/ScalarEvolution.h"
58 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
59 #include "llvm/IR/BasicBlock.h"
60 #include "llvm/IR/CFG.h"
61 #include "llvm/IR/Constants.h"
62 #include "llvm/IR/DerivedTypes.h"
63 #include "llvm/IR/Dominators.h"
64 #include "llvm/IR/Function.h"
65 #include "llvm/IR/IRBuilder.h"
66 #include "llvm/IR/InstrTypes.h"
67 #include "llvm/IR/Instructions.h"
68 #include "llvm/IR/Metadata.h"
69 #include "llvm/IR/Module.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/IR/Type.h"
72 #include "llvm/IR/Use.h"
73 #include "llvm/IR/User.h"
74 #include "llvm/IR/Value.h"
75 #include "llvm/Support/BranchProbability.h"
76 #include "llvm/Support/Casting.h"
77 #include "llvm/Support/CommandLine.h"
78 #include "llvm/Support/Compiler.h"
79 #include "llvm/Support/Debug.h"
80 #include "llvm/Support/ErrorHandling.h"
81 #include "llvm/Support/raw_ostream.h"
82 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
83 #include "llvm/Transforms/Utils/Cloning.h"
84 #include "llvm/Transforms/Utils/LoopSimplify.h"
85 #include "llvm/Transforms/Utils/LoopUtils.h"
86 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
87 #include "llvm/Transforms/Utils/ValueMapper.h"
88 #include <algorithm>
89 #include <cassert>
90 #include <iterator>
91 #include <limits>
92 #include <optional>
93 #include <utility>
94 #include <vector>
95 
96 using namespace llvm;
97 using namespace llvm::PatternMatch;
98 
99 static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden,
100                                         cl::init(64));
101 
102 static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden,
103                                        cl::init(false));
104 
105 static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden,
106                                       cl::init(false));
107 
108 static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks",
109                                              cl::Hidden, cl::init(false));
110 
111 static cl::opt<unsigned> MinRuntimeIterations("irce-min-runtime-iterations",
112                                               cl::Hidden, cl::init(10));
113 
114 static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch",
115                                                  cl::Hidden, cl::init(true));
116 
117 static cl::opt<bool> AllowNarrowLatchCondition(
118     "irce-allow-narrow-latch", cl::Hidden, cl::init(true),
119     cl::desc("If set to true, IRCE may eliminate wide range checks in loops "
120              "with narrow latch condition."));
121 
122 static cl::opt<unsigned> MaxTypeSizeForOverflowCheck(
123     "irce-max-type-size-for-overflow-check", cl::Hidden, cl::init(32),
124     cl::desc(
125         "Maximum size of range check type for which can be produced runtime "
126         "overflow check of its limit's computation"));
127 
128 static cl::opt<bool>
129     PrintScaledBoundaryRangeChecks("irce-print-scaled-boundary-range-checks",
130                                    cl::Hidden, cl::init(false));
131 
// Tag string identifying loops cloned by IRCE (presumably attached to the
// clones as loop metadata -- confirm at the use sites).  Both the characters
// and the pointer are const so the tag cannot be re-pointed at runtime.
static const char *const ClonedLoopTag = "irce.loop.clone";

#define DEBUG_TYPE "irce"
135 
136 namespace {
137 
138 /// An inductive range check is conditional branch in a loop with
139 ///
140 ///  1. a very cold successor (i.e. the branch jumps to that successor very
141 ///     rarely)
142 ///
143 ///  and
144 ///
145 ///  2. a condition that is provably true for some contiguous range of values
146 ///     taken by the containing loop's induction variable.
147 ///
148 class InductiveRangeCheck {
149 
150   const SCEV *Begin = nullptr;
151   const SCEV *Step = nullptr;
152   const SCEV *End = nullptr;
153   Use *CheckUse = nullptr;
154 
155   static bool parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE,
156                                   const SCEVAddRecExpr *&Index,
157                                   const SCEV *&End);
158 
159   static void
160   extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse,
161                              SmallVectorImpl<InductiveRangeCheck> &Checks,
162                              SmallPtrSetImpl<Value *> &Visited);
163 
164   static bool parseIvAgaisntLimit(Loop *L, Value *LHS, Value *RHS,
165                                   ICmpInst::Predicate Pred, ScalarEvolution &SE,
166                                   const SCEVAddRecExpr *&Index,
167                                   const SCEV *&End);
168 
169   static bool reassociateSubLHS(Loop *L, Value *VariantLHS, Value *InvariantRHS,
170                                 ICmpInst::Predicate Pred, ScalarEvolution &SE,
171                                 const SCEVAddRecExpr *&Index, const SCEV *&End);
172 
173 public:
174   const SCEV *getBegin() const { return Begin; }
175   const SCEV *getStep() const { return Step; }
176   const SCEV *getEnd() const { return End; }
177 
178   void print(raw_ostream &OS) const {
179     OS << "InductiveRangeCheck:\n";
180     OS << "  Begin: ";
181     Begin->print(OS);
182     OS << "  Step: ";
183     Step->print(OS);
184     OS << "  End: ";
185     End->print(OS);
186     OS << "\n  CheckUse: ";
187     getCheckUse()->getUser()->print(OS);
188     OS << " Operand: " << getCheckUse()->getOperandNo() << "\n";
189   }
190 
191   LLVM_DUMP_METHOD
192   void dump() {
193     print(dbgs());
194   }
195 
196   Use *getCheckUse() const { return CheckUse; }
197 
198   /// Represents an signed integer range [Range.getBegin(), Range.getEnd()).  If
199   /// R.getEnd() le R.getBegin(), then R denotes the empty range.
200 
201   class Range {
202     const SCEV *Begin;
203     const SCEV *End;
204 
205   public:
206     Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) {
207       assert(Begin->getType() == End->getType() && "ill-typed range!");
208     }
209 
210     Type *getType() const { return Begin->getType(); }
211     const SCEV *getBegin() const { return Begin; }
212     const SCEV *getEnd() const { return End; }
213     bool isEmpty(ScalarEvolution &SE, bool IsSigned) const {
214       if (Begin == End)
215         return true;
216       if (IsSigned)
217         return SE.isKnownPredicate(ICmpInst::ICMP_SGE, Begin, End);
218       else
219         return SE.isKnownPredicate(ICmpInst::ICMP_UGE, Begin, End);
220     }
221   };
222 
223   /// This is the value the condition of the branch needs to evaluate to for the
224   /// branch to take the hot successor (see (1) above).
225   bool getPassingDirection() { return true; }
226 
227   /// Computes a range for the induction variable (IndVar) in which the range
228   /// check is redundant and can be constant-folded away.  The induction
229   /// variable is not required to be the canonical {0,+,1} induction variable.
230   std::optional<Range> computeSafeIterationSpace(ScalarEvolution &SE,
231                                                  const SCEVAddRecExpr *IndVar,
232                                                  bool IsLatchSigned) const;
233 
234   /// Parse out a set of inductive range checks from \p BI and append them to \p
235   /// Checks.
236   ///
237   /// NB! There may be conditions feeding into \p BI that aren't inductive range
238   /// checks, and hence don't end up in \p Checks.
239   static void extractRangeChecksFromBranch(
240       BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI,
241       SmallVectorImpl<InductiveRangeCheck> &Checks, bool &Changed);
242 };
243 
244 struct LoopStructure;
245 
246 class InductiveRangeCheckElimination {
247   ScalarEvolution &SE;
248   BranchProbabilityInfo *BPI;
249   DominatorTree &DT;
250   LoopInfo &LI;
251 
252   using GetBFIFunc =
253       std::optional<llvm::function_ref<llvm::BlockFrequencyInfo &()>>;
254   GetBFIFunc GetBFI;
255 
256   // Returns true if it is profitable to do a transform basing on estimation of
257   // number of iterations.
258   bool isProfitableToTransform(const Loop &L, LoopStructure &LS);
259 
260 public:
261   InductiveRangeCheckElimination(ScalarEvolution &SE,
262                                  BranchProbabilityInfo *BPI, DominatorTree &DT,
263                                  LoopInfo &LI, GetBFIFunc GetBFI = std::nullopt)
264       : SE(SE), BPI(BPI), DT(DT), LI(LI), GetBFI(GetBFI) {}
265 
266   bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop);
267 };
268 
269 } // end anonymous namespace
270 
271 /// Parse a single ICmp instruction, `ICI`, into a range check.  If `ICI` cannot
272 /// be interpreted as a range check, return false.  Otherwise set `Index` to the
273 /// SCEV being range checked, and set `End` to the upper or lower limit `Index`
274 /// is being range checked.
275 bool InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
276                                               ScalarEvolution &SE,
277                                               const SCEVAddRecExpr *&Index,
278                                               const SCEV *&End) {
279   auto IsLoopInvariant = [&SE, L](Value *V) {
280     return SE.isLoopInvariant(SE.getSCEV(V), L);
281   };
282 
283   ICmpInst::Predicate Pred = ICI->getPredicate();
284   Value *LHS = ICI->getOperand(0);
285   Value *RHS = ICI->getOperand(1);
286 
287   // Canonicalize to the `Index Pred Invariant` comparison
288   if (IsLoopInvariant(LHS)) {
289     std::swap(LHS, RHS);
290     Pred = CmpInst::getSwappedPredicate(Pred);
291   } else if (!IsLoopInvariant(RHS))
292     // Both LHS and RHS are loop variant
293     return false;
294 
295   if (parseIvAgaisntLimit(L, LHS, RHS, Pred, SE, Index, End))
296     return true;
297 
298   if (reassociateSubLHS(L, LHS, RHS, Pred, SE, Index, End))
299     return true;
300 
301   // TODO: support ReassociateAddLHS
302   return false;
303 }
304 
305 // Try to parse range check in the form of "IV vs Limit"
306 bool InductiveRangeCheck::parseIvAgaisntLimit(Loop *L, Value *LHS, Value *RHS,
307                                               ICmpInst::Predicate Pred,
308                                               ScalarEvolution &SE,
309                                               const SCEVAddRecExpr *&Index,
310                                               const SCEV *&End) {
311 
312   auto SIntMaxSCEV = [&](Type *T) {
313     unsigned BitWidth = cast<IntegerType>(T)->getBitWidth();
314     return SE.getConstant(APInt::getSignedMaxValue(BitWidth));
315   };
316 
317   const auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(LHS));
318   if (!AddRec)
319     return false;
320 
321   // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L".
322   // We can potentially do much better here.
323   // If we want to adjust upper bound for the unsigned range check as we do it
324   // for signed one, we will need to pick Unsigned max
325   switch (Pred) {
326   default:
327     return false;
328 
329   case ICmpInst::ICMP_SGE:
330     if (match(RHS, m_ConstantInt<0>())) {
331       Index = AddRec;
332       End = SIntMaxSCEV(Index->getType());
333       return true;
334     }
335     return false;
336 
337   case ICmpInst::ICMP_SGT:
338     if (match(RHS, m_ConstantInt<-1>())) {
339       Index = AddRec;
340       End = SIntMaxSCEV(Index->getType());
341       return true;
342     }
343     return false;
344 
345   case ICmpInst::ICMP_SLT:
346   case ICmpInst::ICMP_ULT:
347     Index = AddRec;
348     End = SE.getSCEV(RHS);
349     return true;
350 
351   case ICmpInst::ICMP_SLE:
352   case ICmpInst::ICMP_ULE:
353     const SCEV *One = SE.getOne(RHS->getType());
354     const SCEV *RHSS = SE.getSCEV(RHS);
355     bool Signed = Pred == ICmpInst::ICMP_SLE;
356     if (SE.willNotOverflow(Instruction::BinaryOps::Add, Signed, RHSS, One)) {
357       Index = AddRec;
358       End = SE.getAddExpr(RHSS, One);
359       return true;
360     }
361     return false;
362   }
363 
364   llvm_unreachable("default clause returns!");
365 }
366 
367 // Try to parse range check in the form of "IV - Offset vs Limit" or "Offset -
368 // IV vs Limit"
369 bool InductiveRangeCheck::reassociateSubLHS(
370     Loop *L, Value *VariantLHS, Value *InvariantRHS, ICmpInst::Predicate Pred,
371     ScalarEvolution &SE, const SCEVAddRecExpr *&Index, const SCEV *&End) {
372   Value *LHS, *RHS;
373   if (!match(VariantLHS, m_Sub(m_Value(LHS), m_Value(RHS))))
374     return false;
375 
376   const SCEV *IV = SE.getSCEV(LHS);
377   const SCEV *Offset = SE.getSCEV(RHS);
378   const SCEV *Limit = SE.getSCEV(InvariantRHS);
379 
380   bool OffsetSubtracted = false;
381   if (SE.isLoopInvariant(IV, L))
382     // "Offset - IV vs Limit"
383     std::swap(IV, Offset);
384   else if (SE.isLoopInvariant(Offset, L))
385     // "IV - Offset vs Limit"
386     OffsetSubtracted = true;
387   else
388     return false;
389 
390   const auto *AddRec = dyn_cast<SCEVAddRecExpr>(IV);
391   if (!AddRec)
392     return false;
393 
394   // In order to turn "IV - Offset < Limit" into "IV < Limit + Offset", we need
395   // to be able to freely move values from left side of inequality to right side
396   // (just as in normal linear arithmetics). Overflows make things much more
397   // complicated, so we want to avoid this.
398   //
399   // Let's prove that the initial subtraction doesn't overflow with all IV's
400   // values from the safe range constructed for that check.
401   //
402   // [Case 1] IV - Offset < Limit
403   // It doesn't overflow if:
404   //     SINT_MIN <= IV - Offset <= SINT_MAX
405   // In terms of scaled SINT we need to prove:
406   //     SINT_MIN + Offset <= IV <= SINT_MAX + Offset
407   // Safe range will be constructed:
408   //     0 <= IV < Limit + Offset
409   // It means that 'IV - Offset' doesn't underflow, because:
410   //     SINT_MIN + Offset < 0 <= IV
411   // and doesn't overflow:
412   //     IV < Limit + Offset <= SINT_MAX + Offset
413   //
414   // [Case 2] Offset - IV > Limit
415   // It doesn't overflow if:
416   //     SINT_MIN <= Offset - IV <= SINT_MAX
417   // In terms of scaled SINT we need to prove:
418   //     -SINT_MIN >= IV - Offset >= -SINT_MAX
419   //     Offset - SINT_MIN >= IV >= Offset - SINT_MAX
420   // Safe range will be constructed:
421   //     0 <= IV < Offset - Limit
422   // It means that 'Offset - IV' doesn't underflow, because
423   //     Offset - SINT_MAX < 0 <= IV
424   // and doesn't overflow:
425   //     IV < Offset - Limit <= Offset - SINT_MIN
426   //
427   // For the computed upper boundary of the IV's range (Offset +/- Limit) we
428   // don't know exactly whether it overflows or not. So if we can't prove this
429   // fact at compile time, we scale boundary computations to a wider type with
430   // the intention to add runtime overflow check.
431 
432   auto getExprScaledIfOverflow = [&](Instruction::BinaryOps BinOp,
433                                      const SCEV *LHS,
434                                      const SCEV *RHS) -> const SCEV * {
435     const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
436                                               SCEV::NoWrapFlags, unsigned);
437     switch (BinOp) {
438     default:
439       llvm_unreachable("Unsupported binary op");
440     case Instruction::Add:
441       Operation = &ScalarEvolution::getAddExpr;
442       break;
443     case Instruction::Sub:
444       Operation = &ScalarEvolution::getMinusSCEV;
445       break;
446     }
447 
448     if (SE.willNotOverflow(BinOp, ICmpInst::isSigned(Pred), LHS, RHS,
449                            cast<Instruction>(VariantLHS)))
450       return (SE.*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0);
451 
452     // We couldn't prove that the expression does not overflow.
453     // Than scale it to a wider type to check overflow at runtime.
454     auto *Ty = cast<IntegerType>(LHS->getType());
455     if (Ty->getBitWidth() > MaxTypeSizeForOverflowCheck)
456       return nullptr;
457 
458     auto WideTy = IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
459     return (SE.*Operation)(SE.getSignExtendExpr(LHS, WideTy),
460                            SE.getSignExtendExpr(RHS, WideTy), SCEV::FlagAnyWrap,
461                            0);
462   };
463 
464   if (OffsetSubtracted)
465     // "IV - Offset < Limit" -> "IV" < Offset + Limit
466     Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Add, Offset, Limit);
467   else {
468     // "Offset - IV > Limit" -> "IV" < Offset - Limit
469     Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Sub, Offset, Limit);
470     Pred = ICmpInst::getSwappedPredicate(Pred);
471   }
472 
473   if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
474     // "Expr <= Limit" -> "Expr < Limit + 1"
475     if (Pred == ICmpInst::ICMP_SLE && Limit)
476       Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Add, Limit,
477                                       SE.getOne(Limit->getType()));
478     if (Limit) {
479       Index = AddRec;
480       End = Limit;
481       return true;
482     }
483   }
484   return false;
485 }
486 
487 void InductiveRangeCheck::extractRangeChecksFromCond(
488     Loop *L, ScalarEvolution &SE, Use &ConditionUse,
489     SmallVectorImpl<InductiveRangeCheck> &Checks,
490     SmallPtrSetImpl<Value *> &Visited) {
491   Value *Condition = ConditionUse.get();
492   if (!Visited.insert(Condition).second)
493     return;
494 
495   // TODO: Do the same for OR, XOR, NOT etc?
496   if (match(Condition, m_LogicalAnd(m_Value(), m_Value()))) {
497     extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(0),
498                                Checks, Visited);
499     extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(1),
500                                Checks, Visited);
501     return;
502   }
503 
504   ICmpInst *ICI = dyn_cast<ICmpInst>(Condition);
505   if (!ICI)
506     return;
507 
508   const SCEV *End = nullptr;
509   const SCEVAddRecExpr *IndexAddRec = nullptr;
510   if (!parseRangeCheckICmp(L, ICI, SE, IndexAddRec, End))
511     return;
512 
513   assert(IndexAddRec && "IndexAddRec was not computed");
514   assert(End && "End was not computed");
515 
516   if ((IndexAddRec->getLoop() != L) || !IndexAddRec->isAffine())
517     return;
518 
519   InductiveRangeCheck IRC;
520   IRC.End = End;
521   IRC.Begin = IndexAddRec->getStart();
522   IRC.Step = IndexAddRec->getStepRecurrence(SE);
523   IRC.CheckUse = &ConditionUse;
524   Checks.push_back(IRC);
525 }
526 
527 void InductiveRangeCheck::extractRangeChecksFromBranch(
528     BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI,
529     SmallVectorImpl<InductiveRangeCheck> &Checks, bool &Changed) {
530   if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch())
531     return;
532 
533   unsigned IndexLoopSucc = L->contains(BI->getSuccessor(0)) ? 0 : 1;
534   assert(L->contains(BI->getSuccessor(IndexLoopSucc)) &&
535          "No edges coming to loop?");
536   BranchProbability LikelyTaken(15, 16);
537 
538   if (!SkipProfitabilityChecks && BPI &&
539       BPI->getEdgeProbability(BI->getParent(), IndexLoopSucc) < LikelyTaken)
540     return;
541 
542   // IRCE expects branch's true edge comes to loop. Invert branch for opposite
543   // case.
544   if (IndexLoopSucc != 0) {
545     IRBuilder<> Builder(BI);
546     InvertBranch(BI, Builder);
547     if (BPI)
548       BPI->swapSuccEdgesProbabilities(BI->getParent());
549     Changed = true;
550   }
551 
552   SmallPtrSet<Value *, 8> Visited;
553   InductiveRangeCheck::extractRangeChecksFromCond(L, SE, BI->getOperandUse(0),
554                                                   Checks, Visited);
555 }
556 
557 // Add metadata to the loop L to disable loop optimizations. Callers need to
558 // confirm that optimizing loop L is not beneficial.
559 static void DisableAllLoopOptsOnLoop(Loop &L) {
560   // We do not care about any existing loopID related metadata for L, since we
561   // are setting all loop metadata to false.
562   LLVMContext &Context = L.getHeader()->getContext();
563   // Reserve first location for self reference to the LoopID metadata node.
564   MDNode *Dummy = MDNode::get(Context, {});
565   MDNode *DisableUnroll = MDNode::get(
566       Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
567   Metadata *FalseVal =
568       ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
569   MDNode *DisableVectorize = MDNode::get(
570       Context,
571       {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
572   MDNode *DisableLICMVersioning = MDNode::get(
573       Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
574   MDNode *DisableDistribution= MDNode::get(
575       Context,
576       {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
577   MDNode *NewLoopID =
578       MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
579                             DisableLICMVersioning, DisableDistribution});
580   // Set operand 0 to refer to the loop id itself.
581   NewLoopID->replaceOperandWith(0, NewLoopID);
582   L.setLoopID(NewLoopID);
583 }
584 
585 namespace {
586 
587 // Keeps track of the structure of a loop.  This is similar to llvm::Loop,
588 // except that it is more lightweight and can track the state of a loop through
589 // changing and potentially invalid IR.  This structure also formalizes the
590 // kinds of loops we can deal with -- ones that have a single latch that is also
591 // an exiting block *and* have a canonical induction variable.
592 struct LoopStructure {
593   const char *Tag = "";
594 
595   BasicBlock *Header = nullptr;
596   BasicBlock *Latch = nullptr;
597 
598   // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th
599   // successor is `LatchExit', the exit block of the loop.
600   BranchInst *LatchBr = nullptr;
601   BasicBlock *LatchExit = nullptr;
602   unsigned LatchBrExitIdx = std::numeric_limits<unsigned>::max();
603 
604   // The loop represented by this instance of LoopStructure is semantically
605   // equivalent to:
606   //
607   // intN_ty inc = IndVarIncreasing ? 1 : -1;
608   // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT;
609   //
610   // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarBase)
611   //   ... body ...
612 
613   Value *IndVarBase = nullptr;
614   Value *IndVarStart = nullptr;
615   Value *IndVarStep = nullptr;
616   Value *LoopExitAt = nullptr;
617   bool IndVarIncreasing = false;
618   bool IsSignedPredicate = true;
619 
620   LoopStructure() = default;
621 
622   template <typename M> LoopStructure map(M Map) const {
623     LoopStructure Result;
624     Result.Tag = Tag;
625     Result.Header = cast<BasicBlock>(Map(Header));
626     Result.Latch = cast<BasicBlock>(Map(Latch));
627     Result.LatchBr = cast<BranchInst>(Map(LatchBr));
628     Result.LatchExit = cast<BasicBlock>(Map(LatchExit));
629     Result.LatchBrExitIdx = LatchBrExitIdx;
630     Result.IndVarBase = Map(IndVarBase);
631     Result.IndVarStart = Map(IndVarStart);
632     Result.IndVarStep = Map(IndVarStep);
633     Result.LoopExitAt = Map(LoopExitAt);
634     Result.IndVarIncreasing = IndVarIncreasing;
635     Result.IsSignedPredicate = IsSignedPredicate;
636     return Result;
637   }
638 
639   static std::optional<LoopStructure> parseLoopStructure(ScalarEvolution &,
640                                                          Loop &, const char *&);
641 };
642 
643 /// This class is used to constrain loops to run within a given iteration space.
644 /// The algorithm this class implements is given a Loop and a range [Begin,
645 /// End).  The algorithm then tries to break out a "main loop" out of the loop
646 /// it is given in a way that the "main loop" runs with the induction variable
647 /// in a subset of [Begin, End).  The algorithm emits appropriate pre and post
648 /// loops to run any remaining iterations.  The pre loop runs any iterations in
649 /// which the induction variable is < Begin, and the post loop runs any
650 /// iterations in which the induction variable is >= End.
651 class LoopConstrainer {
652   // The representation of a clone of the original loop we started out with.
653   struct ClonedLoop {
654     // The cloned blocks
655     std::vector<BasicBlock *> Blocks;
656 
657     // `Map` maps values in the clonee into values in the cloned version
658     ValueToValueMapTy Map;
659 
660     // An instance of `LoopStructure` for the cloned loop
661     LoopStructure Structure;
662   };
663 
664   // Result of rewriting the range of a loop.  See changeIterationSpaceEnd for
665   // more details on what these fields mean.
666   struct RewrittenRangeInfo {
667     BasicBlock *PseudoExit = nullptr;
668     BasicBlock *ExitSelector = nullptr;
669     std::vector<PHINode *> PHIValuesAtPseudoExit;
670     PHINode *IndVarEnd = nullptr;
671 
672     RewrittenRangeInfo() = default;
673   };
674 
675   // Calculated subranges we restrict the iteration space of the main loop to.
676   // See the implementation of `calculateSubRanges' for more details on how
677   // these fields are computed.  `LowLimit` is std::nullopt if there is no
678   // restriction on low end of the restricted iteration space of the main loop.
679   // `HighLimit` is std::nullopt if there is no restriction on high end of the
680   // restricted iteration space of the main loop.
681 
682   struct SubRanges {
683     std::optional<const SCEV *> LowLimit;
684     std::optional<const SCEV *> HighLimit;
685   };
686 
687   // Compute a safe set of limits for the main loop to run in -- effectively the
688   // intersection of `Range' and the iteration space of the original loop.
689   // Return std::nullopt if unable to compute the set of subranges.
690   std::optional<SubRanges> calculateSubRanges(bool IsSignedPredicate) const;
691 
692   // Clone `OriginalLoop' and return the result in CLResult.  The IR after
693   // running `cloneLoop' is well formed except for the PHI nodes in CLResult --
694   // the PHI nodes say that there is an incoming edge from `OriginalPreheader`
695   // but there is no such edge.
696   void cloneLoop(ClonedLoop &CLResult, const char *Tag) const;
697 
698   // Create the appropriate loop structure needed to describe a cloned copy of
699   // `Original`.  The clone is described by `VM`.
700   Loop *createClonedLoopStructure(Loop *Original, Loop *Parent,
701                                   ValueToValueMapTy &VM, bool IsSubloop);
702 
703   // Rewrite the iteration space of the loop denoted by (LS, Preheader). The
704   // iteration space of the rewritten loop ends at ExitLoopAt.  The start of the
705   // iteration space is not changed.  `ExitLoopAt' is assumed to be slt
706   // `OriginalHeaderCount'.
707   //
708   // If there are iterations left to execute, control is made to jump to
709   // `ContinuationBlock', otherwise they take the normal loop exit.  The
710   // returned `RewrittenRangeInfo' object is populated as follows:
711   //
712   //  .PseudoExit is a basic block that unconditionally branches to
713   //      `ContinuationBlock'.
714   //
715   //  .ExitSelector is a basic block that decides, on exit from the loop,
716   //      whether to branch to the "true" exit or to `PseudoExit'.
717   //
718   //  .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value
719   //      for each PHINode in the loop header on taking the pseudo exit.
720   //
721   // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate
722   // preheader because it is made to branch to the loop header only
723   // conditionally.
724   RewrittenRangeInfo
725   changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader,
726                           Value *ExitLoopAt,
727                           BasicBlock *ContinuationBlock) const;
728 
729   // The loop denoted by `LS' has `OldPreheader' as its preheader.  This
730   // function creates a new preheader for `LS' and returns it.
731   BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader,
732                               const char *Tag) const;
733 
734   // `ContinuationBlockAndPreheader' was the continuation block for some call to
735   // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'.
736   // This function rewrites the PHI nodes in `LS.Header' to start with the
737   // correct value.
738   void rewriteIncomingValuesForPHIs(
739       LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader,
740       const LoopConstrainer::RewrittenRangeInfo &RRI) const;
741 
742   // Even though we do not preserve any passes at this time, we at least need to
743   // keep the parent loop structure consistent.  The `LPPassManager' seems to
744   // verify this after running a loop pass.  This function adds the list of
745   // blocks denoted by BBs to this loops parent loop if required.
746   void addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs);
747 
748   // Some global state.
749   Function &F;
750   LLVMContext &Ctx;
751   ScalarEvolution &SE;
752   DominatorTree &DT;
753   LoopInfo &LI;
754   function_ref<void(Loop *, bool)> LPMAddNewLoop;
755 
756   // Information about the original loop we started out with.
757   Loop &OriginalLoop;
758 
759   const IntegerType *ExitCountTy = nullptr;
760   BasicBlock *OriginalPreheader = nullptr;
761 
762   // The preheader of the main loop.  This may or may not be different from
763   // `OriginalPreheader'.
764   BasicBlock *MainLoopPreheader = nullptr;
765 
766   // The range we need to run the main loop in.
767   InductiveRangeCheck::Range Range;
768 
769   // The structure of the main loop (see comment at the beginning of this class
770   // for a definition)
771   LoopStructure MainLoopStructure;
772 
773 public:
774   LoopConstrainer(Loop &L, LoopInfo &LI,
775                   function_ref<void(Loop *, bool)> LPMAddNewLoop,
776                   const LoopStructure &LS, ScalarEvolution &SE,
777                   DominatorTree &DT, InductiveRangeCheck::Range R)
778       : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()),
779         SE(SE), DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L),
780         Range(R), MainLoopStructure(LS) {}
781 
782   // Entry point for the algorithm.  Returns true on success.
783   bool run();
784 };
785 
786 } // end anonymous namespace
787 
788 /// Given a loop with an deccreasing induction variable, is it possible to
789 /// safely calculate the bounds of a new loop using the given Predicate.
790 static bool isSafeDecreasingBound(const SCEV *Start,
791                                   const SCEV *BoundSCEV, const SCEV *Step,
792                                   ICmpInst::Predicate Pred,
793                                   unsigned LatchBrExitIdx,
794                                   Loop *L, ScalarEvolution &SE) {
795   if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
796       Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
797     return false;
798 
799   if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
800     return false;
801 
802   assert(SE.isKnownNegative(Step) && "expecting negative step");
803 
804   LLVM_DEBUG(dbgs() << "irce: isSafeDecreasingBound with:\n");
805   LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n");
806   LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n");
807   LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n");
808   LLVM_DEBUG(dbgs() << "irce: Pred: " << Pred << "\n");
809   LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n");
810 
811   bool IsSigned = ICmpInst::isSigned(Pred);
812   // The predicate that we need to check that the induction variable lies
813   // within bounds.
814   ICmpInst::Predicate BoundPred =
815     IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
816 
817   if (LatchBrExitIdx == 1)
818     return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
819 
820   assert(LatchBrExitIdx == 0 &&
821          "LatchBrExitIdx should be either 0 or 1");
822 
823   const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
824   unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
825   APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth) :
826     APInt::getMinValue(BitWidth);
827   const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
828 
829   const SCEV *MinusOne =
830     SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType()));
831 
832   return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) &&
833          SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit);
834 
835 }
836 
837 /// Given a loop with an increasing induction variable, is it possible to
838 /// safely calculate the bounds of a new loop using the given Predicate.
839 static bool isSafeIncreasingBound(const SCEV *Start,
840                                   const SCEV *BoundSCEV, const SCEV *Step,
841                                   ICmpInst::Predicate Pred,
842                                   unsigned LatchBrExitIdx,
843                                   Loop *L, ScalarEvolution &SE) {
844   if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
845       Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
846     return false;
847 
848   if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
849     return false;
850 
851   LLVM_DEBUG(dbgs() << "irce: isSafeIncreasingBound with:\n");
852   LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n");
853   LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n");
854   LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n");
855   LLVM_DEBUG(dbgs() << "irce: Pred: " << Pred << "\n");
856   LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n");
857 
858   bool IsSigned = ICmpInst::isSigned(Pred);
859   // The predicate that we need to check that the induction variable lies
860   // within bounds.
861   ICmpInst::Predicate BoundPred =
862       IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
863 
864   if (LatchBrExitIdx == 1)
865     return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
866 
867   assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
868 
869   const SCEV *StepMinusOne =
870     SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
871   unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
872   APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth) :
873     APInt::getMaxValue(BitWidth);
874   const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
875 
876   return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start,
877                                       SE.getAddExpr(BoundSCEV, Step)) &&
878           SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit));
879 }
880 
881 /// Returns estimate for max latch taken count of the loop of the narrowest
882 /// available type. If the latch block has such estimate, it is returned.
883 /// Otherwise, we use max exit count of whole loop (that is potentially of wider
884 /// type than latch check itself), which is still better than no estimate.
885 static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
886                                                           const Loop &L) {
887   const SCEV *FromBlock =
888       SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
889   if (isa<SCEVCouldNotCompute>(FromBlock))
890     return SE.getSymbolicMaxBackedgeTakenCount(&L);
891   return FromBlock;
892 }
893 
// Try to recognize `L' as a loop IRCE can transform, and describe it as a
// LoopStructure.  Requirements: LoopSimplify form, not already cloned by IRCE,
// an exiting latch ending in a conditional branch on an integer icmp against
// an affine add recurrence of `L' with a constant step, a computable latch
// count, and bounds that isSafeIncreasingBound / isSafeDecreasingBound accept.
// eq/ne latch predicates are rewritten into strict relational ones where
// possible.
//
// On failure, returns std::nullopt and points `FailureReason' at a static
// description.  On success, `FailureReason' is set to nullptr.  Note that
// success may insert instructions (the start value, and the adjusted exit
// bound if one is needed) before the preheader's terminator via SCEVExpander.
std::optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
                                  const char *&FailureReason) {
  if (!L.isLoopSimplifyForm()) {
    FailureReason = "loop not in LoopSimplify form";
    return std::nullopt;
  }

  BasicBlock *Latch = L.getLoopLatch();
  assert(Latch && "Simplified loops only have one latch!");

  // Loops produced by LoopConstrainer::cloneLoop carry this tag; never
  // process them again.
  if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
    FailureReason = "loop has already been cloned";
    return std::nullopt;
  }

  // NOTE(review): a latch does exist here; this rejects loops whose latch is
  // not an exiting block, so the message below is somewhat misleading.
  if (!L.isLoopExiting(Latch)) {
    FailureReason = "no loop latch";
    return std::nullopt;
  }

  BasicBlock *Header = L.getHeader();
  BasicBlock *Preheader = L.getLoopPreheader();
  if (!Preheader) {
    FailureReason = "no preheader";
    return std::nullopt;
  }

  BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
  if (!LatchBr || LatchBr->isUnconditional()) {
    FailureReason = "latch terminator not conditional branch";
    return std::nullopt;
  }

  // Index of the latch successor that leaves the loop (the one that is not
  // the header).
  unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;

  ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
  if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
    FailureReason = "latch terminator branch not conditional on integral icmp";
    return std::nullopt;
  }

  const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
  if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
    FailureReason = "could not compute latch count";
    return std::nullopt;
  }
  assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
             ScalarEvolution::LoopInvariant &&
         "loop variant exit count doesn't make sense!");

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LeftValue = ICI->getOperand(0);
  const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
  IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());

  Value *RightValue = ICI->getOperand(1);
  const SCEV *RightSCEV = SE.getSCEV(RightValue);

  // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
  if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
    if (isa<SCEVAddRecExpr>(RightSCEV)) {
      std::swap(LeftSCEV, RightSCEV);
      std::swap(LeftValue, RightValue);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    } else {
      FailureReason = "no add recurrences in the icmp";
      return std::nullopt;
    }
  }

  // True if `AR' is known not to sign-wrap: either it already has the NSW
  // flag, or sign-extending it to twice its width yields an AddRec whose start
  // and step are the sign-extended start and step.
  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
    if (AR->getNoWrapFlags(SCEV::FlagNSW))
      return true;

    IntegerType *Ty = cast<IntegerType>(AR->getType());
    IntegerType *WideTy =
        IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);

    const SCEVAddRecExpr *ExtendAfterOp =
        dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
    if (ExtendAfterOp) {
      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
      const SCEV *ExtendedStep =
          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);

      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;

      if (NoSignedWrap)
        return true;
    }

    // We may have proved this when computing the sign extension above.
    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
  };

  // `ICI` is interpreted as taking the backedge if the *next* value of the
  // induction variable satisfies some constraint.

  const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
  if (IndVarBase->getLoop() != &L) {
    FailureReason = "LHS in cmp is not an AddRec for this loop";
    return std::nullopt;
  }
  if (!IndVarBase->isAffine()) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  const SCEV* StepRec = IndVarBase->getStepRecurrence(SE);
  if (!isa<SCEVConstant>(StepRec)) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();

  if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
    FailureReason = "LHS in icmp needs nsw for equality predicates";
    return std::nullopt;
  }

  assert(!StepCI->isZero() && "Zero step?");
  bool IsIncreasing = !StepCI->isNegative();
  bool IsSignedPredicate;
  // IndVarBase is the value compared at the latch (the *next* IV value, per
  // the comment above), so its start minus one step is used as the initial
  // value of the induction variable.
  const SCEV *StartNext = IndVarBase->getStart();
  const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
  const SCEV *Step = SE.getSCEV(StepCI);

  // When non-null, the exit bound is re-expanded from this SCEV in the
  // preheader instead of reusing RightValue directly.
  const SCEV *FixedRightSCEV = nullptr;

  // If RightValue resides within loop (but still being loop invariant),
  // regenerate it as preheader.
  if (auto *I = dyn_cast<Instruction>(RightValue))
    if (L.contains(I->getParent()))
      FixedRightSCEV = RightSCEV;

  if (IsIncreasing) {
    bool DecreasedRightValueByOne = false;
    if (StepCI->isOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (++i != len) {         while (++i < len) {
        //   ...                 --->     ...
        // }                            }
        // If both parts are known non-negative, it is profitable to use
        // unsigned comparison in increasing loop. This allows us to make the
        // comparison check against "RightSCEV + 1" more optimistic.
        if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
            isKnownNonNegativeInLoop(RightSCEV, &L, SE))
          Pred = ICmpInst::ICMP_ULT;
        else
          Pred = ICmpInst::ICMP_SLT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (++i == len)     --->     if (++i > len - 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/false)) {
          Pred = ICmpInst::ICMP_UGT;
          RightSCEV = SE.getMinusSCEV(RightSCEV,
                                      SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/true)) {
          Pred = ICmpInst::ICMP_SGT;
          RightSCEV = SE.getMinusSCEV(RightSCEV,
                                      SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        }
      }
    }

    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
    bool FoundExpectedPred =
        (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp slt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate = ICmpInst::isSigned(Pred);
    if (!IsSignedPredicate && !AllowUnsignedLatchCondition) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe loop bounds";
      return std::nullopt;
    }
    if (LatchBrExitIdx == 0) {
      // We need to increase the right value unless we have already decreased
      // it virtually when we replaced EQ with SGT.
      if (!DecreasedRightValueByOne)
        FixedRightSCEV =
            SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!DecreasedRightValueByOne &&
             "Right value can be decreased only for LatchBrExitIdx == 0!");
    }
  } else {
    bool IncreasedRightValueByOne = false;
    if (StepCI->isMinusOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (--i != len) {         while (--i > len) {
        //   ...                 --->     ...
        // }                            }
        // We intentionally don't turn the predicate into UGT even if we know
        // that both operands are non-negative, because it will only pessimize
        // our check against "RightSCEV - 1".
        Pred = ICmpInst::ICMP_SGT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (--i == len)     --->     if (--i < len + 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
          Pred = ICmpInst::ICMP_ULT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
          Pred = ICmpInst::ICMP_SLT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        }
      }
    }

    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);

    bool FoundExpectedPred =
        (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp sgt semantically, found something else";
      return std::nullopt;
    }

    // Pred is one of slt/ult/sgt/ugt here, so this matches
    // ICmpInst::isSigned(Pred) as used in the increasing branch.
    IsSignedPredicate =
        Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;

    if (!IsSignedPredicate && !AllowUnsignedLatchCondition) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe bounds";
      return std::nullopt;
    }

    if (LatchBrExitIdx == 0) {
      // We need to decrease the right value unless we have already increased
      // it virtually when we replaced EQ with SLT.
      if (!IncreasedRightValueByOne)
        FixedRightSCEV =
            SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!IncreasedRightValueByOne &&
             "Right value can be increased only for LatchBrExitIdx == 0!");
    }
  }
  BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);

  assert(!L.contains(LatchExit) && "expected an exit block!");
  // Materialize the (possibly adjusted) exit bound and the start value right
  // before the preheader's terminator.
  const DataLayout &DL = Preheader->getModule()->getDataLayout();
  SCEVExpander Expander(SE, DL, "irce");
  Instruction *Ins = Preheader->getTerminator();

  if (FixedRightSCEV)
    RightValue =
        Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);

  Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
  IndVarStartV->setName("indvar.start");

  LoopStructure Result;

  Result.Tag = "main";
  Result.Header = Header;
  Result.Latch = Latch;
  Result.LatchBr = LatchBr;
  Result.LatchExit = LatchExit;
  Result.LatchBrExitIdx = LatchBrExitIdx;
  Result.IndVarStart = IndVarStartV;
  Result.IndVarStep = StepCI;
  Result.IndVarBase = LeftValue;
  Result.IndVarIncreasing = IsIncreasing;
  Result.LoopExitAt = RightValue;
  Result.IsSignedPredicate = IsSignedPredicate;

  FailureReason = nullptr;

  return Result;
}
1198 }
1199 
1200 /// If the type of \p S matches with \p Ty, return \p S. Otherwise, return
1201 /// signed or unsigned extension of \p S to type \p Ty.
1202 static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE,
1203                                 bool Signed) {
1204   return Signed ? SE.getNoopOrSignExtend(S, Ty) : SE.getNoopOrZeroExtend(S, Ty);
1205 }
1206 
// Compute where the pre-loop and post-loop must stop and start so that the
// main loop only covers `Range'.  All values are extended to the type of
// `Range' (sign- or zero-extended per `IsSignedPredicate').  Result.LowLimit /
// Result.HighLimit are assigned only when SCEV cannot prove that the
// corresponding pre-/post-loop is unnecessary.  Returns std::nullopt when the
// range-check type is narrower than the latch exit-count type, or differs
// from it while narrow latch conditions are disallowed.
std::optional<LoopConstrainer::SubRanges>
LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const {
  auto *RTy = cast<IntegerType>(Range.getType());
  // We only support wide range checks and narrow latches.
  if (!AllowNarrowLatchCondition && RTy != ExitCountTy)
    return std::nullopt;
  if (RTy->getBitWidth() < ExitCountTy->getBitWidth())
    return std::nullopt;

  LoopConstrainer::SubRanges Result;

  // I think we can be more aggressive here and make this nuw / nsw if the
  // addition that feeds into the icmp for the latch's terminating branch is nuw
  // / nsw.  In any case, a wrapping 2's complement addition is safe.
  const SCEV *Start = NoopOrExtend(SE.getSCEV(MainLoopStructure.IndVarStart),
                                   RTy, SE, IsSignedPredicate);
  const SCEV *End = NoopOrExtend(SE.getSCEV(MainLoopStructure.LoopExitAt), RTy,
                                 SE, IsSignedPredicate);

  bool Increasing = MainLoopStructure.IndVarIncreasing;

  // We compute `Smallest` and `Greatest` such that [Smallest, Greatest), or
  // [Smallest, GreatestSeen] is the range of values the induction variable
  // takes.

  const SCEV *Smallest = nullptr, *Greatest = nullptr, *GreatestSeen = nullptr;

  const SCEV *One = SE.getOne(RTy);
  if (Increasing) {
    Smallest = Start;
    Greatest = End;
    // No overflow, because the range [Smallest, GreatestSeen] is not empty.
    GreatestSeen = SE.getMinusSCEV(End, One);
  } else {
    // These two computations may sign-overflow.  Here is why that is okay:
    //
    // We know that the induction variable does not sign-overflow on any
    // iteration except the last one, and it starts at `Start` and ends at
    // `End`, decrementing by one every time.
    //
    //  * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the
    //    induction variable is decreasing we know that the smallest value
    //    the loop body is actually executed with is `INT_SMIN` == `Smallest`.
    //
    //  * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`.  In
    //    that case, `Clamp` will always return `Smallest` and
    //    [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`)
    //    will be an empty range.  Returning an empty range is always safe.

    Smallest = SE.getAddExpr(End, One);
    Greatest = SE.getAddExpr(Start, One);
    GreatestSeen = Start;
  }

  // Clamp S into [Smallest, Greatest] using the predicate's signedness.
  auto Clamp = [this, Smallest, Greatest, IsSignedPredicate](const SCEV *S) {
    return IsSignedPredicate
               ? SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S))
               : SE.getUMaxExpr(Smallest, SE.getUMinExpr(Greatest, S));
  };

  // In some cases we can prove that we don't need a pre or post loop.
  ICmpInst::Predicate PredLE =
      IsSignedPredicate ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
  ICmpInst::Predicate PredLT =
      IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;

  // No pre-loop if the range begins at or before the first IV value.
  bool ProvablyNoPreloop =
      SE.isKnownPredicate(PredLE, Range.getBegin(), Smallest);
  if (!ProvablyNoPreloop)
    Result.LowLimit = Clamp(Range.getBegin());

  // No post-loop if the last IV value is already inside the range.
  bool ProvablyNoPostLoop =
      SE.isKnownPredicate(PredLT, GreatestSeen, Range.getEnd());
  if (!ProvablyNoPostLoop)
    Result.HighLimit = Clamp(Range.getEnd());

  return Result;
}
1284 }
1285 
1286 void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
1287                                 const char *Tag) const {
1288   for (BasicBlock *BB : OriginalLoop.getBlocks()) {
1289     BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
1290     Result.Blocks.push_back(Clone);
1291     Result.Map[BB] = Clone;
1292   }
1293 
1294   auto GetClonedValue = [&Result](Value *V) {
1295     assert(V && "null values not in domain!");
1296     auto It = Result.Map.find(V);
1297     if (It == Result.Map.end())
1298       return V;
1299     return static_cast<Value *>(It->second);
1300   };
1301 
1302   auto *ClonedLatch =
1303       cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
1304   ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
1305                                             MDNode::get(Ctx, {}));
1306 
1307   Result.Structure = MainLoopStructure.map(GetClonedValue);
1308   Result.Structure.Tag = Tag;
1309 
1310   for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
1311     BasicBlock *ClonedBB = Result.Blocks[i];
1312     BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
1313 
1314     assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
1315 
1316     for (Instruction &I : *ClonedBB)
1317       RemapInstruction(&I, Result.Map,
1318                        RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
1319 
1320     // Exit blocks will now have one more predecessor and their PHI nodes need
1321     // to be edited to reflect that.  No phi nodes need to be introduced because
1322     // the loop is in LCSSA.
1323 
1324     for (auto *SBB : successors(OriginalBB)) {
1325       if (OriginalLoop.contains(SBB))
1326         continue; // not an exit block
1327 
1328       for (PHINode &PN : SBB->phis()) {
1329         Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
1330         PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
1331         SE.forgetValue(&PN);
1332       }
1333     }
1334   }
1335 }
1336 
// Rewrite the loop `LS' (whose preheader is `Preheader') so that it only runs
// while its induction variable, compared with the latch predicate's direction
// and signedness, has not reached `ExitSubloopAt'.  Once it has, control falls
// into a new "<Tag>.pseudo.exit" block that branches to `ContinuationBlock'.
// The latch now exits into a new "<Tag>.exit.selector" block.  That block goes
// to the original latch exit when the original bound `LS.LoopExitAt' is also
// reached, and to the pseudo exit otherwise.  Values are widened to the type
// of `Range' where needed.  The returned RewrittenRangeInfo holds the two new
// blocks, one PHI per header PHI giving its value at the pseudo exit, and the
// induction variable's value there (`IndVarEnd').
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
    BasicBlock *ContinuationBlock) const {
  // We start with a loop with a single latch:
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      |
  //    |                    |
  //    +--------+-----------+
  //             |      ----------------\
  //             |     /                |
  //    +--------v----v------+          |
  //    |                    |          |
  //    |      header        |          |
  //    |                    |          |
  //    +--------------------+          |
  //                                    |
  //            .....                   |
  //                                    |
  //    +--------------------+          |
  //    |                    |          |
  //    |       latch        >----------/
  //    |                    |
  //    +-------v------------+
  //            |
  //            |
  //            |   +--------------------+
  //            |   |                    |
  //            +--->   original exit    |
  //                |                    |
  //                +--------------------+
  //
  // We change the control flow to look like
  //
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      >-------------------------+
  //    |                    |                         |
  //    +--------v-----------+                         |
  //             |    /-------------+                  |
  //             |   /              |                  |
  //    +--------v--v--------+      |                  |
  //    |                    |      |                  |
  //    |      header        |      |   +--------+     |
  //    |                    |      |   |        |     |
  //    +--------------------+      |   |  +-----v-----v-----------+
  //                                |   |  |                       |
  //                                |   |  |     .pseudo.exit      |
  //                                |   |  |                       |
  //                                |   |  +-----------v-----------+
  //                                |   |              |
  //            .....               |   |              |
  //                                |   |     +--------v-------------+
  //    +--------------------+      |   |     |                      |
  //    |                    |      |   |     |   ContinuationBlock  |
  //    |       latch        >------+   |     |                      |
  //    |                    |          |     +----------------------+
  //    +---------v----------+          |
  //              |                     |
  //              |                     |
  //              |     +---------------^-----+
  //              |     |                     |
  //              +----->    .exit.selector   |
  //                    |                     |
  //                    +----------v----------+
  //                               |
  //     +--------------------+    |
  //     |                    |    |
  //     |   original exit    <----+
  //     |                    |
  //     +--------------------+

  RewrittenRangeInfo RRI;

  // Both new blocks are placed right after the latch in the function's block
  // list.
  BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, BBInsertLocation);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      BBInsertLocation);

  BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
  bool Increasing = LS.IndVarIncreasing;
  bool IsSignedPredicate = LS.IsSignedPredicate;

  IRBuilder<> B(PreheaderJump);
  auto *RangeTy = Range.getBegin()->getType();
  // Widen V to the range-check type (sext for signed predicates, zext
  // otherwise), emitting the cast at B's current insertion point.
  auto NoopOrExt = [&](Value *V) {
    if (V->getType() == RangeTy)
      return V;
    return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
                             : B.CreateZExt(V, RangeTy, "wide." + V->getName());
  };

  // EnterLoopCond - is it okay to start executing this `LS'?
  Value *EnterLoopCond = nullptr;
  auto Pred =
      Increasing
          ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
          : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
  Value *IndVarStart = NoopOrExt(LS.IndVarStart);
  EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);

  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  // The latch now leaves through RRI.ExitSelector, and takes the backedge
  // only while the IV has not reached ExitSubloopAt.
  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
  B.SetInsertPoint(LS.LatchBr);
  Value *IndVarBase = NoopOrExt(LS.IndVarBase);
  Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);

  Value *CondForBranch = LS.LatchBrExitIdx == 1
                             ? TakeBackedgeLoopCond
                             : B.CreateNot(TakeBackedgeLoopCond);

  LS.LatchBr->setCondition(CondForBranch);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable?  If not, we branch to the "real"
  // exit.
  Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
  Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
  B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);

  BranchInst *BranchToContinuation =
      BranchInst::Create(ContinuationBlock, RRI.PseudoExit);

  // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
  // each of the PHI nodes in the loop header.  This feeds into the initial
  // value of the same PHI nodes if/when we continue execution.
  for (PHINode &PN : LS.Header->phis()) {
    PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
                                      BranchToContinuation);

    NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
    NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
                        RRI.ExitSelector);
    RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
  }

  RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
                                  BranchToContinuation);
  RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
  RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);

  // The latch exit now has a branch from `RRI.ExitSelector' instead of
  // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
  LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);

  return RRI;
}
1490 }
1491 
1492 void LoopConstrainer::rewriteIncomingValuesForPHIs(
1493     LoopStructure &LS, BasicBlock *ContinuationBlock,
1494     const LoopConstrainer::RewrittenRangeInfo &RRI) const {
1495   unsigned PHIIndex = 0;
1496   for (PHINode &PN : LS.Header->phis())
1497     PN.setIncomingValueForBlock(ContinuationBlock,
1498                                 RRI.PHIValuesAtPseudoExit[PHIIndex++]);
1499 
1500   LS.IndVarStart = RRI.IndVarEnd;
1501 }
1502 
1503 BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
1504                                              BasicBlock *OldPreheader,
1505                                              const char *Tag) const {
1506   BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
1507   BranchInst::Create(LS.Header, Preheader);
1508 
1509   LS.Header->replacePhiUsesWith(OldPreheader, Preheader);
1510 
1511   return Preheader;
1512 }
1513 
1514 void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
1515   Loop *ParentLoop = OriginalLoop.getParentLoop();
1516   if (!ParentLoop)
1517     return;
1518 
1519   for (BasicBlock *BB : BBs)
1520     ParentLoop->addBasicBlockToLoop(BB, LI);
1521 }
1522 
1523 Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
1524                                                  ValueToValueMapTy &VM,
1525                                                  bool IsSubloop) {
1526   Loop &New = *LI.AllocateLoop();
1527   if (Parent)
1528     Parent->addChildLoop(&New);
1529   else
1530     LI.addTopLevelLoop(&New);
1531   LPMAddNewLoop(&New, IsSubloop);
1532 
1533   // Add all of the blocks in Original to the new loop.
1534   for (auto *BB : Original->blocks())
1535     if (LI.getLoopFor(BB) == Original)
1536       New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);
1537 
1538   // Add all of the subloops to the new loop.
1539   for (Loop *SubLoop : *Original)
1540     createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);
1541 
1542   return &New;
1543 }
1544 
1545 bool LoopConstrainer::run() {
1546   BasicBlock *Preheader = nullptr;
1547   const SCEV *MaxBETakenCount =
1548       getNarrowestLatchMaxTakenCountEstimate(SE, OriginalLoop);
1549   Preheader = OriginalLoop.getLoopPreheader();
1550   assert(!isa<SCEVCouldNotCompute>(MaxBETakenCount) && Preheader != nullptr &&
1551          "preconditions!");
1552   ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());
1553 
1554   OriginalPreheader = Preheader;
1555   MainLoopPreheader = Preheader;
1556 
1557   bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
1558   std::optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate);
1559   if (!MaybeSR) {
1560     LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n");
1561     return false;
1562   }
1563 
1564   SubRanges SR = *MaybeSR;
1565   bool Increasing = MainLoopStructure.IndVarIncreasing;
1566   IntegerType *IVTy =
1567       cast<IntegerType>(Range.getBegin()->getType());
1568 
1569   SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce");
1570   Instruction *InsertPt = OriginalPreheader->getTerminator();
1571 
1572   // It would have been better to make `PreLoop' and `PostLoop'
1573   // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
1574   // constructor.
1575   ClonedLoop PreLoop, PostLoop;
1576   bool NeedsPreLoop =
1577       Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
1578   bool NeedsPostLoop =
1579       Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();
1580 
1581   Value *ExitPreLoopAt = nullptr;
1582   Value *ExitMainLoopAt = nullptr;
1583   const SCEVConstant *MinusOneS =
1584       cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
1585 
1586   if (NeedsPreLoop) {
1587     const SCEV *ExitPreLoopAtSCEV = nullptr;
1588 
1589     if (Increasing)
1590       ExitPreLoopAtSCEV = *SR.LowLimit;
1591     else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
1592                                IsSignedPredicate))
1593       ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
1594     else {
1595       LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
1596                         << "preloop exit limit.  HighLimit = "
1597                         << *(*SR.HighLimit) << "\n");
1598       return false;
1599     }
1600 
1601     if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
1602       LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the"
1603                         << " preloop exit limit " << *ExitPreLoopAtSCEV
1604                         << " at block " << InsertPt->getParent()->getName()
1605                         << "\n");
1606       return false;
1607     }
1608 
1609     ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
1610     ExitPreLoopAt->setName("exit.preloop.at");
1611   }
1612 
1613   if (NeedsPostLoop) {
1614     const SCEV *ExitMainLoopAtSCEV = nullptr;
1615 
1616     if (Increasing)
1617       ExitMainLoopAtSCEV = *SR.HighLimit;
1618     else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
1619                                IsSignedPredicate))
1620       ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
1621     else {
1622       LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
1623                         << "mainloop exit limit.  LowLimit = "
1624                         << *(*SR.LowLimit) << "\n");
1625       return false;
1626     }
1627 
1628     if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
1629       LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the"
1630                         << " main loop exit limit " << *ExitMainLoopAtSCEV
1631                         << " at block " << InsertPt->getParent()->getName()
1632                         << "\n");
1633       return false;
1634     }
1635 
1636     ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
1637     ExitMainLoopAt->setName("exit.mainloop.at");
1638   }
1639 
1640   // We clone these ahead of time so that we don't have to deal with changing
1641   // and temporarily invalid IR as we transform the loops.
1642   if (NeedsPreLoop)
1643     cloneLoop(PreLoop, "preloop");
1644   if (NeedsPostLoop)
1645     cloneLoop(PostLoop, "postloop");
1646 
1647   RewrittenRangeInfo PreLoopRRI;
1648 
1649   if (NeedsPreLoop) {
1650     Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
1651                                                   PreLoop.Structure.Header);
1652 
1653     MainLoopPreheader =
1654         createPreheader(MainLoopStructure, Preheader, "mainloop");
1655     PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
1656                                          ExitPreLoopAt, MainLoopPreheader);
1657     rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
1658                                  PreLoopRRI);
1659   }
1660 
1661   BasicBlock *PostLoopPreheader = nullptr;
1662   RewrittenRangeInfo PostLoopRRI;
1663 
1664   if (NeedsPostLoop) {
1665     PostLoopPreheader =
1666         createPreheader(PostLoop.Structure, Preheader, "postloop");
1667     PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
1668                                           ExitMainLoopAt, PostLoopPreheader);
1669     rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
1670                                  PostLoopRRI);
1671   }
1672 
1673   BasicBlock *NewMainLoopPreheader =
1674       MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
1675   BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
1676                              PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
1677                              PostLoopRRI.ExitSelector, NewMainLoopPreheader};
1678 
1679   // Some of the above may be nullptr, filter them out before passing to
1680   // addToParentLoopIfNeeded.
1681   auto NewBlocksEnd =
1682       std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
1683 
1684   addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));
1685 
1686   DT.recalculate(F);
1687 
1688   // We need to first add all the pre and post loop blocks into the loop
1689   // structures (as part of createClonedLoopStructure), and then update the
1690   // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
1691   // LI when LoopSimplifyForm is generated.
1692   Loop *PreL = nullptr, *PostL = nullptr;
1693   if (!PreLoop.Blocks.empty()) {
1694     PreL = createClonedLoopStructure(&OriginalLoop,
1695                                      OriginalLoop.getParentLoop(), PreLoop.Map,
1696                                      /* IsSubLoop */ false);
1697   }
1698 
1699   if (!PostLoop.Blocks.empty()) {
1700     PostL =
1701         createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
1702                                   PostLoop.Map, /* IsSubLoop */ false);
1703   }
1704 
1705   // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
1706   auto CanonicalizeLoop = [&] (Loop *L, bool IsOriginalLoop) {
1707     formLCSSARecursively(*L, DT, &LI, &SE);
1708     simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
1709     // Pre/post loops are slow paths, we do not need to perform any loop
1710     // optimizations on them.
1711     if (!IsOriginalLoop)
1712       DisableAllLoopOptsOnLoop(*L);
1713   };
1714   if (PreL)
1715     CanonicalizeLoop(PreL, false);
1716   if (PostL)
1717     CanonicalizeLoop(PostL, false);
1718   CanonicalizeLoop(&OriginalLoop, true);
1719 
1720   /// At this point:
1721   /// - We've broken a "main loop" out of the loop in a way that the "main loop"
1722   /// runs with the induction variable in a subset of [Begin, End).
1723   /// - There is no overflow when computing "main loop" exit limit.
1724   /// - Max latch taken count of the loop is limited.
1725   /// It guarantees that induction variable will not overflow iterating in the
1726   /// "main loop".
1727   if (auto BO = dyn_cast<BinaryOperator>(MainLoopStructure.IndVarBase))
1728     if (IsSignedPredicate)
1729       BO->setHasNoSignedWrap(true);
1730   /// TODO: support unsigned predicate.
1731   /// To add NUW flag we need to prove that both operands of BO are
1732   /// non-negative. E.g:
1733   /// ...
1734   /// %iv.next = add nsw i32 %iv, -1
1735   /// %cmp = icmp ult i32 %iv.next, %n
1736   /// br i1 %cmp, label %loopexit, label %loop
1737   ///
1738   /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
1739   /// overflow, therefore NUW flag is not legal here.
1740 
1741   return true;
1742 }
1743 
1744 /// Computes and returns a range of values for the induction variable (IndVar)
1745 /// in which the range check can be safely elided.  If it cannot compute such a
1746 /// range, returns std::nullopt.
1747 std::optional<InductiveRangeCheck::Range>
1748 InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
1749                                                const SCEVAddRecExpr *IndVar,
1750                                                bool IsLatchSigned) const {
1751   // We can deal when types of latch check and range checks don't match in case
1752   // if latch check is more narrow.
1753   auto *IVType = dyn_cast<IntegerType>(IndVar->getType());
1754   auto *RCType = dyn_cast<IntegerType>(getBegin()->getType());
1755   auto *EndType = dyn_cast<IntegerType>(getEnd()->getType());
1756   // Do not work with pointer types.
1757   if (!IVType || !RCType)
1758     return std::nullopt;
1759   if (IVType->getBitWidth() > RCType->getBitWidth())
1760     return std::nullopt;
1761 
1762   // IndVar is of the form "A + B * I" (where "I" is the canonical induction
1763   // variable, that may or may not exist as a real llvm::Value in the loop) and
1764   // this inductive range check is a range check on the "C + D * I" ("C" is
1765   // getBegin() and "D" is getStep()).  We rewrite the value being range
1766   // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA".
1767   //
1768   // The actual inequalities we solve are of the form
1769   //
1770   //   0 <= M + 1 * IndVar < L given L >= 0  (i.e. N == 1)
1771   //
1772   // Here L stands for upper limit of the safe iteration space.
1773   // The inequality is satisfied by (0 - M) <= IndVar < (L - M). To avoid
1774   // overflows when calculating (0 - M) and (L - M) we, depending on type of
1775   // IV's iteration space, limit the calculations by borders of the iteration
1776   // space. For example, if IndVar is unsigned, (0 - M) overflows for any M > 0.
1777   // If we figured out that "anything greater than (-M) is safe", we strengthen
1778   // this to "everything greater than 0 is safe", assuming that values between
1779   // -M and 0 just do not exist in unsigned iteration space, and we don't want
1780   // to deal with overflown values.
1781 
1782   if (!IndVar->isAffine())
1783     return std::nullopt;
1784 
1785   const SCEV *A = NoopOrExtend(IndVar->getStart(), RCType, SE, IsLatchSigned);
1786   const SCEVConstant *B = dyn_cast<SCEVConstant>(
1787       NoopOrExtend(IndVar->getStepRecurrence(SE), RCType, SE, IsLatchSigned));
1788   if (!B)
1789     return std::nullopt;
1790   assert(!B->isZero() && "Recurrence with zero step?");
1791 
1792   const SCEV *C = getBegin();
1793   const SCEVConstant *D = dyn_cast<SCEVConstant>(getStep());
1794   if (D != B)
1795     return std::nullopt;
1796 
1797   assert(!D->getValue()->isZero() && "Recurrence with zero step?");
1798   unsigned BitWidth = RCType->getBitWidth();
1799   const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
1800   const SCEV *SIntMin = SE.getConstant(APInt::getSignedMinValue(BitWidth));
1801 
1802   // Subtract Y from X so that it does not go through border of the IV
1803   // iteration space. Mathematically, it is equivalent to:
1804   //
1805   //    ClampedSubtract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX).        [1]
1806   //
1807   // In [1], 'X - Y' is a mathematical subtraction (result is not bounded to
1808   // any width of bit grid). But after we take min/max, the result is
1809   // guaranteed to be within [INT_MIN, INT_MAX].
1810   //
1811   // In [1], INT_MAX and INT_MIN are respectively signed and unsigned max/min
1812   // values, depending on type of latch condition that defines IV iteration
1813   // space.
1814   auto ClampedSubtract = [&](const SCEV *X, const SCEV *Y) {
1815     // FIXME: The current implementation assumes that X is in [0, SINT_MAX].
1816     // This is required to ensure that SINT_MAX - X does not overflow signed and
1817     // that X - Y does not overflow unsigned if Y is negative. Can we lift this
1818     // restriction and make it work for negative X either?
1819     if (IsLatchSigned) {
1820       // X is a number from signed range, Y is interpreted as signed.
1821       // Even if Y is SINT_MAX, (X - Y) does not reach SINT_MIN. So the only
1822       // thing we should care about is that we didn't cross SINT_MAX.
1823       // So, if Y is positive, we subtract Y safely.
1824       //   Rule 1: Y > 0 ---> Y.
1825       // If 0 <= -Y <= (SINT_MAX - X), we subtract Y safely.
1826       //   Rule 2: Y >=s (X - SINT_MAX) ---> Y.
1827       // If 0 <= (SINT_MAX - X) < -Y, we can only subtract (X - SINT_MAX).
1828       //   Rule 3: Y <s (X - SINT_MAX) ---> (X - SINT_MAX).
1829       // It gives us smax(Y, X - SINT_MAX) to subtract in all cases.
1830       const SCEV *XMinusSIntMax = SE.getMinusSCEV(X, SIntMax);
1831       return SE.getMinusSCEV(X, SE.getSMaxExpr(Y, XMinusSIntMax),
1832                              SCEV::FlagNSW);
1833     } else
1834       // X is a number from unsigned range, Y is interpreted as signed.
1835       // Even if Y is SINT_MIN, (X - Y) does not reach UINT_MAX. So the only
1836       // thing we should care about is that we didn't cross zero.
1837       // So, if Y is negative, we subtract Y safely.
1838       //   Rule 1: Y <s 0 ---> Y.
1839       // If 0 <= Y <= X, we subtract Y safely.
1840       //   Rule 2: Y <=s X ---> Y.
1841       // If 0 <= X < Y, we should stop at 0 and can only subtract X.
1842       //   Rule 3: Y >s X ---> X.
1843       // It gives us smin(X, Y) to subtract in all cases.
1844       return SE.getMinusSCEV(X, SE.getSMinExpr(X, Y), SCEV::FlagNUW);
1845   };
1846   const SCEV *M = SE.getMinusSCEV(C, A);
1847   const SCEV *Zero = SE.getZero(M->getType());
1848 
1849   // This function returns SCEV equal to 1 if X is non-negative 0 otherwise.
1850   auto SCEVCheckNonNegative = [&](const SCEV *X) {
1851     const Loop *L = IndVar->getLoop();
1852     const SCEV *Zero = SE.getZero(X->getType());
1853     const SCEV *One = SE.getOne(X->getType());
1854     // Can we trivially prove that X is a non-negative or negative value?
1855     if (isKnownNonNegativeInLoop(X, L, SE))
1856       return One;
1857     else if (isKnownNegativeInLoop(X, L, SE))
1858       return Zero;
1859     // If not, we will have to figure it out during the execution.
1860     // Function smax(smin(X, 0), -1) + 1 equals to 1 if X >= 0 and 0 if X < 0.
1861     const SCEV *NegOne = SE.getNegativeSCEV(One);
1862     return SE.getAddExpr(SE.getSMaxExpr(SE.getSMinExpr(X, Zero), NegOne), One);
1863   };
1864 
1865   // This function returns SCEV equal to 1 if X will not overflow in terms of
1866   // range check type, 0 otherwise.
1867   auto SCEVCheckWillNotOverflow = [&](const SCEV *X) {
1868     // X doesn't overflow if SINT_MAX >= X.
1869     // Then if (SINT_MAX - X) >= 0, X doesn't overflow
1870     const SCEV *SIntMaxExt = SE.getSignExtendExpr(SIntMax, X->getType());
1871     const SCEV *OverflowCheck =
1872         SCEVCheckNonNegative(SE.getMinusSCEV(SIntMaxExt, X));
1873 
1874     // X doesn't underflow if X >= SINT_MIN.
1875     // Then if (X - SINT_MIN) >= 0, X doesn't underflow
1876     const SCEV *SIntMinExt = SE.getSignExtendExpr(SIntMin, X->getType());
1877     const SCEV *UnderflowCheck =
1878         SCEVCheckNonNegative(SE.getMinusSCEV(X, SIntMinExt));
1879 
1880     return SE.getMulExpr(OverflowCheck, UnderflowCheck);
1881   };
1882 
1883   // FIXME: Current implementation of ClampedSubtract implicitly assumes that
1884   // X is non-negative (in sense of a signed value). We need to re-implement
1885   // this function in a way that it will correctly handle negative X as well.
1886   // We use it twice: for X = 0 everything is fine, but for X = getEnd() we can
1887   // end up with a negative X and produce wrong results. So currently we ensure
1888   // that if getEnd() is negative then both ends of the safe range are zero.
1889   // Note that this may pessimize elimination of unsigned range checks against
1890   // negative values.
1891   const SCEV *REnd = getEnd();
1892   const SCEV *EndWillNotOverflow = SE.getOne(RCType);
1893 
1894   auto PrintRangeCheck = [&](raw_ostream &OS) {
1895     auto L = IndVar->getLoop();
1896     OS << "irce: in function ";
1897     OS << L->getHeader()->getParent()->getName();
1898     OS << ", in ";
1899     L->print(OS);
1900     OS << "there is range check with scaled boundary:\n";
1901     print(OS);
1902   };
1903 
1904   if (EndType->getBitWidth() > RCType->getBitWidth()) {
1905     assert(EndType->getBitWidth() == RCType->getBitWidth() * 2);
1906     if (PrintScaledBoundaryRangeChecks)
1907       PrintRangeCheck(errs());
1908     // End is computed with extended type but will be truncated to a narrow one
1909     // type of range check. Therefore we need a check that the result will not
1910     // overflow in terms of narrow type.
1911     EndWillNotOverflow =
1912         SE.getTruncateExpr(SCEVCheckWillNotOverflow(REnd), RCType);
1913     REnd = SE.getTruncateExpr(REnd, RCType);
1914   }
1915 
1916   const SCEV *RuntimeChecks =
1917       SE.getMulExpr(SCEVCheckNonNegative(REnd), EndWillNotOverflow);
1918   const SCEV *Begin = SE.getMulExpr(ClampedSubtract(Zero, M), RuntimeChecks);
1919   const SCEV *End = SE.getMulExpr(ClampedSubtract(REnd, M), RuntimeChecks);
1920 
1921   return InductiveRangeCheck::Range(Begin, End);
1922 }
1923 
1924 static std::optional<InductiveRangeCheck::Range>
1925 IntersectSignedRange(ScalarEvolution &SE,
1926                      const std::optional<InductiveRangeCheck::Range> &R1,
1927                      const InductiveRangeCheck::Range &R2) {
1928   if (R2.isEmpty(SE, /* IsSigned */ true))
1929     return std::nullopt;
1930   if (!R1)
1931     return R2;
1932   auto &R1Value = *R1;
1933   // We never return empty ranges from this function, and R1 is supposed to be
1934   // a result of intersection. Thus, R1 is never empty.
1935   assert(!R1Value.isEmpty(SE, /* IsSigned */ true) &&
1936          "We should never have empty R1!");
1937 
1938   // TODO: we could widen the smaller range and have this work; but for now we
1939   // bail out to keep things simple.
1940   if (R1Value.getType() != R2.getType())
1941     return std::nullopt;
1942 
1943   const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin());
1944   const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd());
1945 
1946   // If the resulting range is empty, just return std::nullopt.
1947   auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd);
1948   if (Ret.isEmpty(SE, /* IsSigned */ true))
1949     return std::nullopt;
1950   return Ret;
1951 }
1952 
1953 static std::optional<InductiveRangeCheck::Range>
1954 IntersectUnsignedRange(ScalarEvolution &SE,
1955                        const std::optional<InductiveRangeCheck::Range> &R1,
1956                        const InductiveRangeCheck::Range &R2) {
1957   if (R2.isEmpty(SE, /* IsSigned */ false))
1958     return std::nullopt;
1959   if (!R1)
1960     return R2;
1961   auto &R1Value = *R1;
1962   // We never return empty ranges from this function, and R1 is supposed to be
1963   // a result of intersection. Thus, R1 is never empty.
1964   assert(!R1Value.isEmpty(SE, /* IsSigned */ false) &&
1965          "We should never have empty R1!");
1966 
1967   // TODO: we could widen the smaller range and have this work; but for now we
1968   // bail out to keep things simple.
1969   if (R1Value.getType() != R2.getType())
1970     return std::nullopt;
1971 
1972   const SCEV *NewBegin = SE.getUMaxExpr(R1Value.getBegin(), R2.getBegin());
1973   const SCEV *NewEnd = SE.getUMinExpr(R1Value.getEnd(), R2.getEnd());
1974 
1975   // If the resulting range is empty, just return std::nullopt.
1976   auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd);
1977   if (Ret.isEmpty(SE, /* IsSigned */ false))
1978     return std::nullopt;
1979   return Ret;
1980 }
1981 
1982 PreservedAnalyses IRCEPass::run(Function &F, FunctionAnalysisManager &AM) {
1983   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
1984   LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
1985   // There are no loops in the function. Return before computing other expensive
1986   // analyses.
1987   if (LI.empty())
1988     return PreservedAnalyses::all();
1989   auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
1990   auto &BPI = AM.getResult<BranchProbabilityAnalysis>(F);
1991 
1992   // Get BFI analysis result on demand. Please note that modification of
1993   // CFG invalidates this analysis and we should handle it.
1994   auto getBFI = [&F, &AM ]()->BlockFrequencyInfo & {
1995     return AM.getResult<BlockFrequencyAnalysis>(F);
1996   };
1997   InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI, { getBFI });
1998 
1999   bool Changed = false;
2000   {
2001     bool CFGChanged = false;
2002     for (const auto &L : LI) {
2003       CFGChanged |= simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr,
2004                                  /*PreserveLCSSA=*/false);
2005       Changed |= formLCSSARecursively(*L, DT, &LI, &SE);
2006     }
2007     Changed |= CFGChanged;
2008 
2009     if (CFGChanged && !SkipProfitabilityChecks) {
2010       PreservedAnalyses PA = PreservedAnalyses::all();
2011       PA.abandon<BlockFrequencyAnalysis>();
2012       AM.invalidate(F, PA);
2013     }
2014   }
2015 
2016   SmallPriorityWorklist<Loop *, 4> Worklist;
2017   appendLoopsToWorklist(LI, Worklist);
2018   auto LPMAddNewLoop = [&Worklist](Loop *NL, bool IsSubloop) {
2019     if (!IsSubloop)
2020       appendLoopsToWorklist(*NL, Worklist);
2021   };
2022 
2023   while (!Worklist.empty()) {
2024     Loop *L = Worklist.pop_back_val();
2025     if (IRCE.run(L, LPMAddNewLoop)) {
2026       Changed = true;
2027       if (!SkipProfitabilityChecks) {
2028         PreservedAnalyses PA = PreservedAnalyses::all();
2029         PA.abandon<BlockFrequencyAnalysis>();
2030         AM.invalidate(F, PA);
2031       }
2032     }
2033   }
2034 
2035   if (!Changed)
2036     return PreservedAnalyses::all();
2037   return getLoopPassPreservedAnalyses();
2038 }
2039 
2040 bool
2041 InductiveRangeCheckElimination::isProfitableToTransform(const Loop &L,
2042                                                         LoopStructure &LS) {
2043   if (SkipProfitabilityChecks)
2044     return true;
2045   if (GetBFI) {
2046     BlockFrequencyInfo &BFI = (*GetBFI)();
2047     uint64_t hFreq = BFI.getBlockFreq(LS.Header).getFrequency();
2048     uint64_t phFreq = BFI.getBlockFreq(L.getLoopPreheader()).getFrequency();
2049     if (phFreq != 0 && hFreq != 0 && (hFreq / phFreq < MinRuntimeIterations)) {
2050       LLVM_DEBUG(dbgs() << "irce: could not prove profitability: "
2051                         << "the estimated number of iterations basing on "
2052                            "frequency info is " << (hFreq / phFreq) << "\n";);
2053       return false;
2054     }
2055     return true;
2056   }
2057 
2058   if (!BPI)
2059     return true;
2060   BranchProbability ExitProbability =
2061       BPI->getEdgeProbability(LS.Latch, LS.LatchBrExitIdx);
2062   if (ExitProbability > BranchProbability(1, MinRuntimeIterations)) {
2063     LLVM_DEBUG(dbgs() << "irce: could not prove profitability: "
2064                       << "the exit probability is too big " << ExitProbability
2065                       << "\n";);
2066     return false;
2067   }
2068   return true;
2069 }
2070 
2071 bool InductiveRangeCheckElimination::run(
2072     Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop) {
2073   if (L->getBlocks().size() >= LoopSizeCutoff) {
2074     LLVM_DEBUG(dbgs() << "irce: giving up constraining loop, too large\n");
2075     return false;
2076   }
2077 
2078   BasicBlock *Preheader = L->getLoopPreheader();
2079   if (!Preheader) {
2080     LLVM_DEBUG(dbgs() << "irce: loop has no preheader, leaving\n");
2081     return false;
2082   }
2083 
2084   LLVMContext &Context = Preheader->getContext();
2085   SmallVector<InductiveRangeCheck, 16> RangeChecks;
2086   bool Changed = false;
2087 
2088   for (auto *BBI : L->getBlocks())
2089     if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))
2090       InductiveRangeCheck::extractRangeChecksFromBranch(TBI, L, SE, BPI,
2091                                                         RangeChecks, Changed);
2092 
2093   if (RangeChecks.empty())
2094     return Changed;
2095 
2096   auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) {
2097     OS << "irce: looking at loop "; L->print(OS);
2098     OS << "irce: loop has " << RangeChecks.size()
2099        << " inductive range checks: \n";
2100     for (InductiveRangeCheck &IRC : RangeChecks)
2101       IRC.print(OS);
2102   };
2103 
2104   LLVM_DEBUG(PrintRecognizedRangeChecks(dbgs()));
2105 
2106   if (PrintRangeChecks)
2107     PrintRecognizedRangeChecks(errs());
2108 
2109   const char *FailureReason = nullptr;
2110   std::optional<LoopStructure> MaybeLoopStructure =
2111       LoopStructure::parseLoopStructure(SE, *L, FailureReason);
2112   if (!MaybeLoopStructure) {
2113     LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: "
2114                       << FailureReason << "\n";);
2115     return Changed;
2116   }
2117   LoopStructure LS = *MaybeLoopStructure;
2118   if (!isProfitableToTransform(*L, LS))
2119     return Changed;
2120   const SCEVAddRecExpr *IndVar =
2121       cast<SCEVAddRecExpr>(SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep)));
2122 
2123   std::optional<InductiveRangeCheck::Range> SafeIterRange;
2124 
2125   SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate;
2126   // Basing on the type of latch predicate, we interpret the IV iteration range
2127   // as signed or unsigned range. We use different min/max functions (signed or
2128   // unsigned) when intersecting this range with safe iteration ranges implied
2129   // by range checks.
2130   auto IntersectRange =
2131       LS.IsSignedPredicate ? IntersectSignedRange : IntersectUnsignedRange;
2132 
2133   for (InductiveRangeCheck &IRC : RangeChecks) {
2134     auto Result = IRC.computeSafeIterationSpace(SE, IndVar,
2135                                                 LS.IsSignedPredicate);
2136     if (Result) {
2137       auto MaybeSafeIterRange = IntersectRange(SE, SafeIterRange, *Result);
2138       if (MaybeSafeIterRange) {
2139         assert(!MaybeSafeIterRange->isEmpty(SE, LS.IsSignedPredicate) &&
2140                "We should never return empty ranges!");
2141         RangeChecksToEliminate.push_back(IRC);
2142         SafeIterRange = *MaybeSafeIterRange;
2143       }
2144     }
2145   }
2146 
2147   if (!SafeIterRange)
2148     return Changed;
2149 
2150   LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, *SafeIterRange);
2151 
2152   if (LC.run()) {
2153     Changed = true;
2154 
2155     auto PrintConstrainedLoopInfo = [L]() {
2156       dbgs() << "irce: in function ";
2157       dbgs() << L->getHeader()->getParent()->getName() << ": ";
2158       dbgs() << "constrained ";
2159       L->print(dbgs());
2160     };
2161 
2162     LLVM_DEBUG(PrintConstrainedLoopInfo());
2163 
2164     if (PrintChangedLoops)
2165       PrintConstrainedLoopInfo();
2166 
2167     // Optimize away the now-redundant range checks.
2168 
2169     for (InductiveRangeCheck &IRC : RangeChecksToEliminate) {
2170       ConstantInt *FoldedRangeCheck = IRC.getPassingDirection()
2171                                           ? ConstantInt::getTrue(Context)
2172                                           : ConstantInt::getFalse(Context);
2173       IRC.getCheckUse()->set(FoldedRangeCheck);
2174     }
2175   }
2176 
2177   return Changed;
2178 }
2179