xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp (revision ee0fe82ee2892f5ece189db0eab38913aaab5f0f)
1 //===- InductiveRangeCheckElimination.cpp - -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The InductiveRangeCheckElimination pass splits a loop's iteration space into
10 // three disjoint ranges.  It does that in a way such that the loop running in
11 // the middle range provably does not need range checks. As an example, it will
12 // convert
13 //
14 //   len = < known positive >
15 //   for (i = 0; i < n; i++) {
16 //     if (0 <= i && i < len) {
17 //       do_something();
18 //     } else {
19 //       throw_out_of_bounds();
20 //     }
21 //   }
22 //
23 // to
24 //
25 //   len = < known positive >
26 //   limit = smin(n, len)
27 //   // no first segment
28 //   for (i = 0; i < limit; i++) {
29 //     if (0 <= i && i < len) { // this check is fully redundant
30 //       do_something();
31 //     } else {
32 //       throw_out_of_bounds();
33 //     }
34 //   }
35 //   for (i = limit; i < n; i++) {
36 //     if (0 <= i && i < len) {
37 //       do_something();
38 //     } else {
39 //       throw_out_of_bounds();
40 //     }
41 //   }
42 //
43 //===----------------------------------------------------------------------===//
44 
45 #include "llvm/Transforms/Scalar/InductiveRangeCheckElimination.h"
46 #include "llvm/ADT/APInt.h"
47 #include "llvm/ADT/ArrayRef.h"
48 #include "llvm/ADT/None.h"
49 #include "llvm/ADT/Optional.h"
50 #include "llvm/ADT/SmallPtrSet.h"
51 #include "llvm/ADT/SmallVector.h"
52 #include "llvm/ADT/StringRef.h"
53 #include "llvm/ADT/Twine.h"
54 #include "llvm/Analysis/BranchProbabilityInfo.h"
55 #include "llvm/Analysis/LoopAnalysisManager.h"
56 #include "llvm/Analysis/LoopInfo.h"
57 #include "llvm/Analysis/LoopPass.h"
58 #include "llvm/Analysis/ScalarEvolution.h"
59 #include "llvm/Analysis/ScalarEvolutionExpander.h"
60 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
61 #include "llvm/IR/BasicBlock.h"
62 #include "llvm/IR/CFG.h"
63 #include "llvm/IR/Constants.h"
64 #include "llvm/IR/DerivedTypes.h"
65 #include "llvm/IR/Dominators.h"
66 #include "llvm/IR/Function.h"
67 #include "llvm/IR/IRBuilder.h"
68 #include "llvm/IR/InstrTypes.h"
69 #include "llvm/IR/Instructions.h"
70 #include "llvm/IR/Metadata.h"
71 #include "llvm/IR/Module.h"
72 #include "llvm/IR/PatternMatch.h"
73 #include "llvm/IR/Type.h"
74 #include "llvm/IR/Use.h"
75 #include "llvm/IR/User.h"
76 #include "llvm/IR/Value.h"
77 #include "llvm/Pass.h"
78 #include "llvm/Support/BranchProbability.h"
79 #include "llvm/Support/Casting.h"
80 #include "llvm/Support/CommandLine.h"
81 #include "llvm/Support/Compiler.h"
82 #include "llvm/Support/Debug.h"
83 #include "llvm/Support/ErrorHandling.h"
84 #include "llvm/Support/raw_ostream.h"
85 #include "llvm/Transforms/Scalar.h"
86 #include "llvm/Transforms/Utils/Cloning.h"
87 #include "llvm/Transforms/Utils/LoopSimplify.h"
88 #include "llvm/Transforms/Utils/LoopUtils.h"
89 #include "llvm/Transforms/Utils/ValueMapper.h"
90 #include <algorithm>
91 #include <cassert>
92 #include <iterator>
93 #include <limits>
94 #include <utility>
95 #include <vector>
96 
97 using namespace llvm;
98 using namespace llvm::PatternMatch;
99 
// Hidden tuning knobs for IRCE.  Descriptions marked NOTE(review) are inferred
// from the option names only and should be confirmed against their uses.

// NOTE(review): presumably an upper bound on the size of loops IRCE will
// consider -- confirm at its use site.
static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden,
                                        cl::init(64));

// NOTE(review): presumably prints each loop IRCE changes -- confirm.
static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden,
                                       cl::init(false));

// NOTE(review): presumably prints the range checks IRCE finds -- confirm.
static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden,
                                      cl::init(false));

// LoopStructure::parseLoopStructure rejects a loop ("short running loop, not
// profitable") when the probability of leaving it through the latch exceeds
// 1 / MaxExitProbReciprocal, unless SkipProfitabilityChecks is set.
static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal",
                                          cl::Hidden, cl::init(10));

// Disables the branch-probability based filters: the "range check is likely
// to pass" filter in extractRangeChecksFromBranch and the latch-exit
// probability filter in parseLoopStructure.
static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks",
                                             cl::Hidden, cl::init(false));

// NOTE(review): presumably lets IRCE accept loops whose latch comparison is
// unsigned -- confirm at its use site.
static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch",
                                                 cl::Hidden, cl::init(true));

static cl::opt<bool> AllowNarrowLatchCondition(
    "irce-allow-narrow-latch", cl::Hidden, cl::init(true),
    cl::desc("If set to true, IRCE may eliminate wide range checks in loops "
             "with narrow latch condition."));

// Metadata kind that parseLoopStructure looks for on a loop's latch
// terminator; loops carrying it are rejected as "already cloned".
// NOTE(review): presumably attached by IRCE to the loops it clones -- the
// attaching code is not in this part of the file.
static const char *ClonedLoopTag = "irce.loop.clone";

// Debug category for the LLVM_DEBUG output in this file.
#define DEBUG_TYPE "irce"
126 
127 namespace {
128 
129 /// An inductive range check is conditional branch in a loop with
130 ///
131 ///  1. a very cold successor (i.e. the branch jumps to that successor very
132 ///     rarely)
133 ///
134 ///  and
135 ///
136 ///  2. a condition that is provably true for some contiguous range of values
137 ///     taken by the containing loop's induction variable.
138 ///
139 class InductiveRangeCheck {
140 
141   const SCEV *Begin = nullptr;
142   const SCEV *Step = nullptr;
143   const SCEV *End = nullptr;
144   Use *CheckUse = nullptr;
145   bool IsSigned = true;
146 
147   static bool parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE,
148                                   Value *&Index, Value *&Length,
149                                   bool &IsSigned);
150 
151   static void
152   extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse,
153                              SmallVectorImpl<InductiveRangeCheck> &Checks,
154                              SmallPtrSetImpl<Value *> &Visited);
155 
156 public:
157   const SCEV *getBegin() const { return Begin; }
158   const SCEV *getStep() const { return Step; }
159   const SCEV *getEnd() const { return End; }
160   bool isSigned() const { return IsSigned; }
161 
162   void print(raw_ostream &OS) const {
163     OS << "InductiveRangeCheck:\n";
164     OS << "  Begin: ";
165     Begin->print(OS);
166     OS << "  Step: ";
167     Step->print(OS);
168     OS << "  End: ";
169     End->print(OS);
170     OS << "\n  CheckUse: ";
171     getCheckUse()->getUser()->print(OS);
172     OS << " Operand: " << getCheckUse()->getOperandNo() << "\n";
173   }
174 
175   LLVM_DUMP_METHOD
176   void dump() {
177     print(dbgs());
178   }
179 
180   Use *getCheckUse() const { return CheckUse; }
181 
182   /// Represents an signed integer range [Range.getBegin(), Range.getEnd()).  If
183   /// R.getEnd() le R.getBegin(), then R denotes the empty range.
184 
185   class Range {
186     const SCEV *Begin;
187     const SCEV *End;
188 
189   public:
190     Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) {
191       assert(Begin->getType() == End->getType() && "ill-typed range!");
192     }
193 
194     Type *getType() const { return Begin->getType(); }
195     const SCEV *getBegin() const { return Begin; }
196     const SCEV *getEnd() const { return End; }
197     bool isEmpty(ScalarEvolution &SE, bool IsSigned) const {
198       if (Begin == End)
199         return true;
200       if (IsSigned)
201         return SE.isKnownPredicate(ICmpInst::ICMP_SGE, Begin, End);
202       else
203         return SE.isKnownPredicate(ICmpInst::ICMP_UGE, Begin, End);
204     }
205   };
206 
207   /// This is the value the condition of the branch needs to evaluate to for the
208   /// branch to take the hot successor (see (1) above).
209   bool getPassingDirection() { return true; }
210 
211   /// Computes a range for the induction variable (IndVar) in which the range
212   /// check is redundant and can be constant-folded away.  The induction
213   /// variable is not required to be the canonical {0,+,1} induction variable.
214   Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE,
215                                             const SCEVAddRecExpr *IndVar,
216                                             bool IsLatchSigned) const;
217 
218   /// Parse out a set of inductive range checks from \p BI and append them to \p
219   /// Checks.
220   ///
221   /// NB! There may be conditions feeding into \p BI that aren't inductive range
222   /// checks, and hence don't end up in \p Checks.
223   static void
224   extractRangeChecksFromBranch(BranchInst *BI, Loop *L, ScalarEvolution &SE,
225                                BranchProbabilityInfo *BPI,
226                                SmallVectorImpl<InductiveRangeCheck> &Checks);
227 };
228 
229 class InductiveRangeCheckElimination {
230   ScalarEvolution &SE;
231   BranchProbabilityInfo *BPI;
232   DominatorTree &DT;
233   LoopInfo &LI;
234 
235 public:
236   InductiveRangeCheckElimination(ScalarEvolution &SE,
237                                  BranchProbabilityInfo *BPI, DominatorTree &DT,
238                                  LoopInfo &LI)
239       : SE(SE), BPI(BPI), DT(DT), LI(LI) {}
240 
241   bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop);
242 };
243 
244 class IRCELegacyPass : public LoopPass {
245 public:
246   static char ID;
247 
248   IRCELegacyPass() : LoopPass(ID) {
249     initializeIRCELegacyPassPass(*PassRegistry::getPassRegistry());
250   }
251 
252   void getAnalysisUsage(AnalysisUsage &AU) const override {
253     AU.addRequired<BranchProbabilityInfoWrapperPass>();
254     getLoopAnalysisUsage(AU);
255   }
256 
257   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
258 };
259 
260 } // end anonymous namespace
261 
// The address of ID, not its value, identifies the pass.
char IRCELegacyPass::ID = 0;

// Register the legacy pass under the pass argument "irce", together with the
// passes it depends on.  The trailing (false, false) are the CFG-only and
// is-analysis flags of the INITIALIZE_PASS macros.
INITIALIZE_PASS_BEGIN(IRCELegacyPass, "irce",
                      "Inductive range check elimination", false, false)
INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopPass)
INITIALIZE_PASS_END(IRCELegacyPass, "irce", "Inductive range check elimination",
                    false, false)
270 
271 /// Parse a single ICmp instruction, `ICI`, into a range check.  If `ICI` cannot
272 /// be interpreted as a range check, return false and set `Index` and `Length`
273 /// to `nullptr`.  Otherwise set `Index` to the value being range checked, and
274 /// set `Length` to the upper limit `Index` is being range checked.
275 bool
276 InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
277                                          ScalarEvolution &SE, Value *&Index,
278                                          Value *&Length, bool &IsSigned) {
279   auto IsLoopInvariant = [&SE, L](Value *V) {
280     return SE.isLoopInvariant(SE.getSCEV(V), L);
281   };
282 
283   ICmpInst::Predicate Pred = ICI->getPredicate();
284   Value *LHS = ICI->getOperand(0);
285   Value *RHS = ICI->getOperand(1);
286 
287   switch (Pred) {
288   default:
289     return false;
290 
291   case ICmpInst::ICMP_SLE:
292     std::swap(LHS, RHS);
293     LLVM_FALLTHROUGH;
294   case ICmpInst::ICMP_SGE:
295     IsSigned = true;
296     if (match(RHS, m_ConstantInt<0>())) {
297       Index = LHS;
298       return true; // Lower.
299     }
300     return false;
301 
302   case ICmpInst::ICMP_SLT:
303     std::swap(LHS, RHS);
304     LLVM_FALLTHROUGH;
305   case ICmpInst::ICMP_SGT:
306     IsSigned = true;
307     if (match(RHS, m_ConstantInt<-1>())) {
308       Index = LHS;
309       return true; // Lower.
310     }
311 
312     if (IsLoopInvariant(LHS)) {
313       Index = RHS;
314       Length = LHS;
315       return true; // Upper.
316     }
317     return false;
318 
319   case ICmpInst::ICMP_ULT:
320     std::swap(LHS, RHS);
321     LLVM_FALLTHROUGH;
322   case ICmpInst::ICMP_UGT:
323     IsSigned = false;
324     if (IsLoopInvariant(LHS)) {
325       Index = RHS;
326       Length = LHS;
327       return true; // Both lower and upper.
328     }
329     return false;
330   }
331 
332   llvm_unreachable("default clause returns!");
333 }
334 
335 void InductiveRangeCheck::extractRangeChecksFromCond(
336     Loop *L, ScalarEvolution &SE, Use &ConditionUse,
337     SmallVectorImpl<InductiveRangeCheck> &Checks,
338     SmallPtrSetImpl<Value *> &Visited) {
339   Value *Condition = ConditionUse.get();
340   if (!Visited.insert(Condition).second)
341     return;
342 
343   // TODO: Do the same for OR, XOR, NOT etc?
344   if (match(Condition, m_And(m_Value(), m_Value()))) {
345     extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(0),
346                                Checks, Visited);
347     extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(1),
348                                Checks, Visited);
349     return;
350   }
351 
352   ICmpInst *ICI = dyn_cast<ICmpInst>(Condition);
353   if (!ICI)
354     return;
355 
356   Value *Length = nullptr, *Index;
357   bool IsSigned;
358   if (!parseRangeCheckICmp(L, ICI, SE, Index, Length, IsSigned))
359     return;
360 
361   const auto *IndexAddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Index));
362   bool IsAffineIndex =
363       IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine();
364 
365   if (!IsAffineIndex)
366     return;
367 
368   const SCEV *End = nullptr;
369   // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L".
370   // We can potentially do much better here.
371   if (Length)
372     End = SE.getSCEV(Length);
373   else {
374     // So far we can only reach this point for Signed range check. This may
375     // change in future. In this case we will need to pick Unsigned max for the
376     // unsigned range check.
377     unsigned BitWidth = cast<IntegerType>(IndexAddRec->getType())->getBitWidth();
378     const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
379     End = SIntMax;
380   }
381 
382   InductiveRangeCheck IRC;
383   IRC.End = End;
384   IRC.Begin = IndexAddRec->getStart();
385   IRC.Step = IndexAddRec->getStepRecurrence(SE);
386   IRC.CheckUse = &ConditionUse;
387   IRC.IsSigned = IsSigned;
388   Checks.push_back(IRC);
389 }
390 
391 void InductiveRangeCheck::extractRangeChecksFromBranch(
392     BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI,
393     SmallVectorImpl<InductiveRangeCheck> &Checks) {
394   if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch())
395     return;
396 
397   BranchProbability LikelyTaken(15, 16);
398 
399   if (!SkipProfitabilityChecks && BPI &&
400       BPI->getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken)
401     return;
402 
403   SmallPtrSet<Value *, 8> Visited;
404   InductiveRangeCheck::extractRangeChecksFromCond(L, SE, BI->getOperandUse(0),
405                                                   Checks, Visited);
406 }
407 
408 // Add metadata to the loop L to disable loop optimizations. Callers need to
409 // confirm that optimizing loop L is not beneficial.
410 static void DisableAllLoopOptsOnLoop(Loop &L) {
411   // We do not care about any existing loopID related metadata for L, since we
412   // are setting all loop metadata to false.
413   LLVMContext &Context = L.getHeader()->getContext();
414   // Reserve first location for self reference to the LoopID metadata node.
415   MDNode *Dummy = MDNode::get(Context, {});
416   MDNode *DisableUnroll = MDNode::get(
417       Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
418   Metadata *FalseVal =
419       ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
420   MDNode *DisableVectorize = MDNode::get(
421       Context,
422       {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
423   MDNode *DisableLICMVersioning = MDNode::get(
424       Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
425   MDNode *DisableDistribution= MDNode::get(
426       Context,
427       {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
428   MDNode *NewLoopID =
429       MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
430                             DisableLICMVersioning, DisableDistribution});
431   // Set operand 0 to refer to the loop id itself.
432   NewLoopID->replaceOperandWith(0, NewLoopID);
433   L.setLoopID(NewLoopID);
434 }
435 
436 namespace {
437 
// Keeps track of the structure of a loop.  This is similar to llvm::Loop,
// except that it is more lightweight and can track the state of a loop through
// changing and potentially invalid IR.  This structure also formalizes the
// kinds of loops we can deal with -- ones that have a single latch that is also
// an exiting block *and* have a canonical induction variable.
struct LoopStructure {
  // Free-form tag for this loop; map() copies it unchanged.
  const char *Tag = "";

  BasicBlock *Header = nullptr;
  BasicBlock *Latch = nullptr;

  // `Latch's terminator instruction is `LatchBr', and its `LatchBrExitIdx'th
  // successor is `LatchExit', the exit block of the loop.
  BranchInst *LatchBr = nullptr;
  BasicBlock *LatchExit = nullptr;
  // UINT_MAX until filled in; a parsed loop uses successor index 0 or 1.
  unsigned LatchBrExitIdx = std::numeric_limits<unsigned>::max();

  // The loop represented by this instance of LoopStructure is semantically
  // equivalent to:
  //
  // intN_ty inc = IndVarIncreasing ? 1 : -1;
  // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT;
  //
  // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarBase)
  //   ... body ...
  //
  // NOTE(review): the sketch shows signed predicates and a unit step only.
  // IsSignedPredicate records whether the latch comparison is actually signed,
  // and IndVarStep presumably holds the real step -- confirm against
  // parseLoopStructure.

  Value *IndVarBase = nullptr;
  Value *IndVarStart = nullptr;
  Value *IndVarStep = nullptr;
  Value *LoopExitAt = nullptr;
  bool IndVarIncreasing = false;
  bool IsSignedPredicate = true;

  LoopStructure() = default;

  // Returns a copy in which every IR object has been passed through `Map`
  // (e.g. to describe a cloned loop); scalar fields are copied unchanged.
  // cast<> checks that Map preserves the kind of each block and branch.
  // Map is called in field order, which is kept stable in case Map has side
  // effects (such as an inserting value-map lookup).
  template <typename M> LoopStructure map(M Map) const {
    LoopStructure Result;
    Result.Tag = Tag;
    Result.Header = cast<BasicBlock>(Map(Header));
    Result.Latch = cast<BasicBlock>(Map(Latch));
    Result.LatchBr = cast<BranchInst>(Map(LatchBr));
    Result.LatchExit = cast<BasicBlock>(Map(LatchExit));
    Result.LatchBrExitIdx = LatchBrExitIdx;
    Result.IndVarBase = Map(IndVarBase);
    Result.IndVarStart = Map(IndVarStart);
    Result.IndVarStep = Map(IndVarStep);
    Result.LoopExitAt = Map(LoopExitAt);
    Result.IndVarIncreasing = IndVarIncreasing;
    Result.IsSignedPredicate = IsSignedPredicate;
    return Result;
  }

  // Tries to recognise the given loop as one this pass can handle.  On failure
  // returns None and stores a human-readable reason in the last argument.
  static Optional<LoopStructure> parseLoopStructure(ScalarEvolution &,
                                                    BranchProbabilityInfo *BPI,
                                                    Loop &, const char *&);
};
494 
/// This class is used to constrain loops to run within a given iteration space.
/// The algorithm this class implements is given a Loop and a range [Begin,
/// End).  The algorithm then tries to break out a "main loop" out of the loop
/// it is given in a way that the "main loop" runs with the induction variable
/// in a subset of [Begin, End).  The algorithm emits appropriate pre and post
/// loops to run any remaining iterations.  The pre loop runs any iterations in
/// which the induction variable is < Begin, and the post loop runs any
/// iterations in which the induction variable is >= End.
///
/// Usage: construct it with the loop, its LoopStructure and the Range, then
/// call run().  Apart from the constructor, the member functions are only
/// declared here and are defined out of line.
class LoopConstrainer {
  // The representation of a clone of the original loop we started out with.
  struct ClonedLoop {
    // The cloned blocks
    std::vector<BasicBlock *> Blocks;

    // `Map` maps values in the clonee into values in the cloned version
    ValueToValueMapTy Map;

    // An instance of `LoopStructure` for the cloned loop
    LoopStructure Structure;
  };

  // Result of rewriting the range of a loop.  See changeIterationSpaceEnd for
  // more details on what these fields mean.
  struct RewrittenRangeInfo {
    BasicBlock *PseudoExit = nullptr;
    BasicBlock *ExitSelector = nullptr;
    std::vector<PHINode *> PHIValuesAtPseudoExit;
    // NOTE(review): not described in the changeIterationSpaceEnd comment;
    // presumably the PHI carrying the induction variable's value on exit from
    // the rewritten loop -- confirm at its definition.
    PHINode *IndVarEnd = nullptr;

    RewrittenRangeInfo() = default;
  };

  // Calculated subranges we restrict the iteration space of the main loop to.
  // See the implementation of `calculateSubRanges' for more details on how
  // these fields are computed.  `LowLimit` is None if there is no restriction
  // on low end of the restricted iteration space of the main loop.  `HighLimit`
  // is None if there is no restriction on high end of the restricted iteration
  // space of the main loop.

  struct SubRanges {
    Optional<const SCEV *> LowLimit;
    Optional<const SCEV *> HighLimit;
  };

  // Compute a safe set of limits for the main loop to run in -- effectively the
  // intersection of `Range' and the iteration space of the original loop.
  // Return None if unable to compute the set of subranges.
  Optional<SubRanges> calculateSubRanges(bool IsSignedPredicate) const;

  // Clone `OriginalLoop' and return the result in CLResult.  The IR after
  // running `cloneLoop' is well formed except for the PHI nodes in CLResult --
  // the PHI nodes say that there is an incoming edge from `OriginalPreheader`
  // but there is no such edge.
  void cloneLoop(ClonedLoop &CLResult, const char *Tag) const;

  // Create the appropriate loop structure needed to describe a cloned copy of
  // `Original`.  The clone is described by `VM`.
  Loop *createClonedLoopStructure(Loop *Original, Loop *Parent,
                                  ValueToValueMapTy &VM, bool IsSubloop);

  // Rewrite the iteration space of the loop denoted by (LS, Preheader). The
  // iteration space of the rewritten loop ends at ExitLoopAt.  The start of the
  // iteration space is not changed.
  // NOTE(review): an older version of this comment required `ExitLoopAt' to be
  // slt `OriginalHeaderCount', but no such member exists any more; presumably
  // ExitLoopAt must lie within the original loop's iteration space -- confirm
  // at the definition.
  //
  // If there are iterations left to execute, control is made to jump to
  // `ContinuationBlock', otherwise they take the normal loop exit.  The
  // returned `RewrittenRangeInfo' object is populated as follows:
  //
  //  .PseudoExit is a basic block that unconditionally branches to
  //      `ContinuationBlock'.
  //
  //  .ExitSelector is a basic block that decides, on exit from the loop,
  //      whether to branch to the "true" exit or to `PseudoExit'.
  //
  //  .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value
  //      for each PHINode in the loop header on taking the pseudo exit.
  //
  // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate
  // preheader because it is made to branch to the loop header only
  // conditionally.
  RewrittenRangeInfo
  changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader,
                          Value *ExitLoopAt,
                          BasicBlock *ContinuationBlock) const;

  // The loop denoted by `LS' has `OldPreheader' as its preheader.  This
  // function creates a new preheader for `LS' and returns it.
  BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader,
                              const char *Tag) const;

  // `ContinuationBlockAndPreheader' was the continuation block for some call to
  // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'.
  // This function rewrites the PHI nodes in `LS.Header' to start with the
  // correct value.
  void rewriteIncomingValuesForPHIs(
      LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader,
      const LoopConstrainer::RewrittenRangeInfo &RRI) const;

  // Even though we do not preserve any passes at this time, we at least need to
  // keep the parent loop structure consistent.  The `LPPassManager' seems to
  // verify this after running a loop pass.  This function adds the list of
  // blocks denoted by BBs to this loops parent loop if required.
  void addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs);

  // Some global state.
  Function &F;
  LLVMContext &Ctx;
  ScalarEvolution &SE;
  DominatorTree &DT;
  LoopInfo &LI;
  function_ref<void(Loop *, bool)> LPMAddNewLoop;

  // Information about the original loop we started out with.
  Loop &OriginalLoop;

  // Not set by the constructor.  NOTE(review): presumably filled in by run()
  // before use -- confirm.
  const SCEV *LatchTakenCount = nullptr;
  BasicBlock *OriginalPreheader = nullptr;

  // The preheader of the main loop.  This may or may not be different from
  // `OriginalPreheader'.
  BasicBlock *MainLoopPreheader = nullptr;

  // The range we need to run the main loop in.
  InductiveRangeCheck::Range Range;

  // The structure of the main loop (see comment at the beginning of this class
  // for a definition)
  LoopStructure MainLoopStructure;

public:
  // F and Ctx are derived from L's header block.
  LoopConstrainer(Loop &L, LoopInfo &LI,
                  function_ref<void(Loop *, bool)> LPMAddNewLoop,
                  const LoopStructure &LS, ScalarEvolution &SE,
                  DominatorTree &DT, InductiveRangeCheck::Range R)
      : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()),
        SE(SE), DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L),
        Range(R), MainLoopStructure(LS) {}

  // Entry point for the algorithm.  Returns true on success.
  bool run();
};
637 
638 } // end anonymous namespace
639 
640 /// Given a loop with an deccreasing induction variable, is it possible to
641 /// safely calculate the bounds of a new loop using the given Predicate.
642 static bool isSafeDecreasingBound(const SCEV *Start,
643                                   const SCEV *BoundSCEV, const SCEV *Step,
644                                   ICmpInst::Predicate Pred,
645                                   unsigned LatchBrExitIdx,
646                                   Loop *L, ScalarEvolution &SE) {
647   if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
648       Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
649     return false;
650 
651   if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
652     return false;
653 
654   assert(SE.isKnownNegative(Step) && "expecting negative step");
655 
656   LLVM_DEBUG(dbgs() << "irce: isSafeDecreasingBound with:\n");
657   LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n");
658   LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n");
659   LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n");
660   LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred)
661                     << "\n");
662   LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n");
663 
664   bool IsSigned = ICmpInst::isSigned(Pred);
665   // The predicate that we need to check that the induction variable lies
666   // within bounds.
667   ICmpInst::Predicate BoundPred =
668     IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
669 
670   if (LatchBrExitIdx == 1)
671     return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
672 
673   assert(LatchBrExitIdx == 0 &&
674          "LatchBrExitIdx should be either 0 or 1");
675 
676   const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
677   unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
678   APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth) :
679     APInt::getMinValue(BitWidth);
680   const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
681 
682   const SCEV *MinusOne =
683     SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType()));
684 
685   return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) &&
686          SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit);
687 
688 }
689 
690 /// Given a loop with an increasing induction variable, is it possible to
691 /// safely calculate the bounds of a new loop using the given Predicate.
692 static bool isSafeIncreasingBound(const SCEV *Start,
693                                   const SCEV *BoundSCEV, const SCEV *Step,
694                                   ICmpInst::Predicate Pred,
695                                   unsigned LatchBrExitIdx,
696                                   Loop *L, ScalarEvolution &SE) {
697   if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
698       Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
699     return false;
700 
701   if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
702     return false;
703 
704   LLVM_DEBUG(dbgs() << "irce: isSafeIncreasingBound with:\n");
705   LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n");
706   LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n");
707   LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n");
708   LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred)
709                     << "\n");
710   LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n");
711 
712   bool IsSigned = ICmpInst::isSigned(Pred);
713   // The predicate that we need to check that the induction variable lies
714   // within bounds.
715   ICmpInst::Predicate BoundPred =
716       IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
717 
718   if (LatchBrExitIdx == 1)
719     return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
720 
721   assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
722 
723   const SCEV *StepMinusOne =
724     SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
725   unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
726   APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth) :
727     APInt::getMaxValue(BitWidth);
728   const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
729 
730   return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start,
731                                       SE.getAddExpr(BoundSCEV, Step)) &&
732           SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit));
733 }
734 
/// Tries to recognize \p L as a loop that IRCE can transform and, on success,
/// returns a LoopStructure describing it.
///
/// Requirements checked here:
///  * L is in LoopSimplify form, has a preheader, and its latch is exiting
///    and ends in a conditional branch on an integer icmp;
///  * the latch was not produced by a previous IRCE clone (ClonedLoopTag);
///  * the exit edge of the latch is not too likely (unless
///    SkipProfitabilityChecks is set; with a null \p BPI the exit probability
///    is taken to be zero);
///  * the latch exit count is computable;
///  * one icmp operand is an affine add recurrence with a constant step, and
///    the predicate can be (re)interpreted as a lt/gt test in the direction of
///    the step.
///
/// On failure returns None and points \p FailureReason at a static string
/// describing the reason. On success sets \p FailureReason to nullptr. On the
/// success path this may insert instructions before the preheader's
/// terminator: an adjusted copy of the latch bound (RightValue +/- 1) and the
/// expansion of the induction variable's start value ("indvar.start").
Optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE,
                                  BranchProbabilityInfo *BPI, Loop &L,
                                  const char *&FailureReason) {
  if (!L.isLoopSimplifyForm()) {
    FailureReason = "loop not in LoopSimplify form";
    return None;
  }

  BasicBlock *Latch = L.getLoopLatch();
  assert(Latch && "Simplified loops only have one latch!");

  // Never re-process a loop that IRCE itself created by cloning.
  if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
    FailureReason = "loop has already been cloned";
    return None;
  }

  // NOTE(review): the latch exists at this point (asserted above); this bails
  // out because the latch does not exit the loop, so the message below is
  // somewhat misleading.
  if (!L.isLoopExiting(Latch)) {
    FailureReason = "no loop latch";
    return None;
  }

  BasicBlock *Header = L.getHeader();
  BasicBlock *Preheader = L.getLoopPreheader();
  if (!Preheader) {
    FailureReason = "no preheader";
    return None;
  }

  BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
  if (!LatchBr || LatchBr->isUnconditional()) {
    FailureReason = "latch terminator not conditional branch";
    return None;
  }

  // Index of the latch successor that leaves the loop (the other one is the
  // backedge to the header).
  unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;

  // Without branch probability info the exit is assumed never taken, which
  // always passes the profitability check below.
  BranchProbability ExitProbability =
      BPI ? BPI->getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx)
          : BranchProbability::getZero();

  if (!SkipProfitabilityChecks &&
      ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) {
    FailureReason = "short running loop, not profitable";
    return None;
  }

  ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
  if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
    FailureReason = "latch terminator branch not conditional on integral icmp";
    return None;
  }

  const SCEV *LatchCount = SE.getExitCount(&L, Latch);
  if (isa<SCEVCouldNotCompute>(LatchCount)) {
    FailureReason = "could not compute latch count";
    return None;
  }

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LeftValue = ICI->getOperand(0);
  const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
  // Both icmp operands share one type, so computing this before the possible
  // operand swap below is fine.
  IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());

  Value *RightValue = ICI->getOperand(1);
  const SCEV *RightSCEV = SE.getSCEV(RightValue);

  // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
  if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
    if (isa<SCEVAddRecExpr>(RightSCEV)) {
      std::swap(LeftSCEV, RightSCEV);
      std::swap(LeftValue, RightValue);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    } else {
      FailureReason = "no add recurrences in the icmp";
      return None;
    }
  }

  // Returns true if AR is known not to wrap in the signed sense: either it
  // already carries NSW, or sign-extending it to twice its width yields an add
  // recurrence whose start and step are exactly the sign-extended start and
  // step of AR.
  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
    if (AR->getNoWrapFlags(SCEV::FlagNSW))
      return true;

    IntegerType *Ty = cast<IntegerType>(AR->getType());
    IntegerType *WideTy =
        IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);

    const SCEVAddRecExpr *ExtendAfterOp =
        dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
    if (ExtendAfterOp) {
      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
      const SCEV *ExtendedStep =
          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);

      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;

      if (NoSignedWrap)
        return true;
    }

    // We may have proved this when computing the sign extension above.
    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
  };

  // `ICI` is interpreted as taking the backedge if the *next* value of the
  // induction variable satisfies some constraint.

  const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
  if (!IndVarBase->isAffine()) {
    FailureReason = "LHS in icmp not induction variable";
    return None;
  }
  const SCEV* StepRec = IndVarBase->getStepRecurrence(SE);
  if (!isa<SCEVConstant>(StepRec)) {
    FailureReason = "LHS in icmp not induction variable";
    return None;
  }
  ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();

  // eq/ne latches are only rewritten into lt/gt forms below; that is only
  // sound when the induction variable cannot wrap.
  if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
    FailureReason = "LHS in icmp needs nsw for equality predicates";
    return None;
  }

  assert(!StepCI->isZero() && "Zero step?");
  bool IsIncreasing = !StepCI->isNegative();
  bool IsSignedPredicate;
  // Given the interpretation above (the icmp sees the *next* value), the
  // recurrence starts one step ahead; subtract one step to obtain the value
  // the induction variable has on loop entry.
  const SCEV *StartNext = IndVarBase->getStart();
  const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
  const SCEV *Step = SE.getSCEV(StepCI);

  ConstantInt *One = ConstantInt::get(IndVarTy, 1);
  if (IsIncreasing) {
    bool DecreasedRightValueByOne = false;
    if (StepCI->isOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (++i != len) {         while (++i < len) {
        //   ...                 --->     ...
        // }                            }
        // If both parts are known non-negative, it is profitable to use
        // unsigned comparison in increasing loop. This allows us to make the
        // comparison check against "RightSCEV + 1" more optimistic.
        if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
            isKnownNonNegativeInLoop(RightSCEV, &L, SE))
          Pred = ICmpInst::ICMP_ULT;
        else
          Pred = ICmpInst::ICMP_SLT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (++i == len)     --->     if (++i > len - 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        // `len - 1` must not wrap, hence the cannotBeMinInLoop checks.
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/false)) {
          Pred = ICmpInst::ICMP_UGT;
          RightSCEV = SE.getMinusSCEV(RightSCEV,
                                      SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/true)) {
          Pred = ICmpInst::ICMP_SGT;
          RightSCEV = SE.getMinusSCEV(RightSCEV,
                                      SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        }
      }
    }

    // An increasing IV must either stay in the loop while `next < bound`
    // (exit is successor 1) or leave it once `next > bound` (exit is
    // successor 0).
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
    bool FoundExpectedPred =
        (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp slt semantically, found something else";
      return None;
    }

    IsSignedPredicate = ICmpInst::isSigned(Pred);
    if (!IsSignedPredicate && !AllowUnsignedLatchCondition) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return None;
    }

    if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe loop bounds";
      return None;
    }
    if (LatchBrExitIdx == 0) {
      // We need to increase the right value unless we have already decreased
      // it virtually when we replaced EQ with SGT/UGT. After this, the loop
      // keeps running while `next < RightValue`.
      if (!DecreasedRightValueByOne) {
        IRBuilder<> B(Preheader->getTerminator());
        RightValue = B.CreateAdd(RightValue, One);
      }
    } else {
      assert(!DecreasedRightValueByOne &&
             "Right value can be decreased only for LatchBrExitIdx == 0!");
    }
  } else {
    bool IncreasedRightValueByOne = false;
    if (StepCI->isMinusOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (--i != len) {         while (--i > len) {
        //   ...                 --->     ...
        // }                            }
        // We intentionally don't turn the predicate into UGT even if we know
        // that both operands are non-negative, because it will only pessimize
        // our check against "RightSCEV - 1".
        Pred = ICmpInst::ICMP_SGT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (--i == len)     --->     if (--i < len + 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        // `len + 1` must not wrap, hence the cannotBeMaxInLoop checks.
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
          Pred = ICmpInst::ICMP_ULT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
          Pred = ICmpInst::ICMP_SLT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        }
      }
    }

    // Mirror image of the increasing case: stay in the loop while
    // `next > bound` (exit is successor 1) or leave once `next < bound`
    // (exit is successor 0).
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);

    bool FoundExpectedPred =
        (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp sgt semantically, found something else";
      return None;
    }

    IsSignedPredicate =
        Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;

    if (!IsSignedPredicate && !AllowUnsignedLatchCondition) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return None;
    }

    if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe bounds";
      return None;
    }

    if (LatchBrExitIdx == 0) {
      // We need to decrease the right value unless we have already increased
      // it virtually when we replaced EQ with SLT/ULT. After this, the loop
      // keeps running while `next > RightValue`.
      if (!IncreasedRightValueByOne) {
        IRBuilder<> B(Preheader->getTerminator());
        RightValue = B.CreateSub(RightValue, One);
      }
    } else {
      assert(!IncreasedRightValueByOne &&
             "Right value can be increased only for LatchBrExitIdx == 0!");
    }
  }
  BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);

  assert(SE.getLoopDisposition(LatchCount, &L) ==
             ScalarEvolution::LoopInvariant &&
         "loop variant exit count doesn't make sense!");

  assert(!L.contains(LatchExit) && "expected an exit block!");
  // Materialize the IV's entry value in the preheader so later stages can
  // use it as an IR value.
  const DataLayout &DL = Preheader->getModule()->getDataLayout();
  Value *IndVarStartV =
      SCEVExpander(SE, DL, "irce")
          .expandCodeFor(IndVarStart, IndVarTy, Preheader->getTerminator());
  IndVarStartV->setName("indvar.start");

  LoopStructure Result;

  Result.Tag = "main";
  Result.Header = Header;
  Result.Latch = Latch;
  Result.LatchBr = LatchBr;
  Result.LatchExit = LatchExit;
  Result.LatchBrExitIdx = LatchBrExitIdx;
  Result.IndVarStart = IndVarStartV;
  Result.IndVarStep = StepCI;
  Result.IndVarBase = LeftValue;
  Result.IndVarIncreasing = IsIncreasing;
  Result.LoopExitAt = RightValue;
  Result.IsSignedPredicate = IsSignedPredicate;

  FailureReason = nullptr;

  return Result;
}
1038 
1039 /// If the type of \p S matches with \p Ty, return \p S. Otherwise, return
1040 /// signed or unsigned extension of \p S to type \p Ty.
1041 static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE,
1042                                 bool Signed) {
1043   return Signed ? SE.getNoopOrSignExtend(S, Ty) : SE.getNoopOrZeroExtend(S, Ty);
1044 }
1045 
1046 Optional<LoopConstrainer::SubRanges>
1047 LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const {
1048   IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType());
1049 
1050   auto *RTy = cast<IntegerType>(Range.getType());
1051 
1052   // We only support wide range checks and narrow latches.
1053   if (!AllowNarrowLatchCondition && RTy != Ty)
1054     return None;
1055   if (RTy->getBitWidth() < Ty->getBitWidth())
1056     return None;
1057 
1058   LoopConstrainer::SubRanges Result;
1059 
1060   // I think we can be more aggressive here and make this nuw / nsw if the
1061   // addition that feeds into the icmp for the latch's terminating branch is nuw
1062   // / nsw.  In any case, a wrapping 2's complement addition is safe.
1063   const SCEV *Start = NoopOrExtend(SE.getSCEV(MainLoopStructure.IndVarStart),
1064                                    RTy, SE, IsSignedPredicate);
1065   const SCEV *End = NoopOrExtend(SE.getSCEV(MainLoopStructure.LoopExitAt), RTy,
1066                                  SE, IsSignedPredicate);
1067 
1068   bool Increasing = MainLoopStructure.IndVarIncreasing;
1069 
1070   // We compute `Smallest` and `Greatest` such that [Smallest, Greatest), or
1071   // [Smallest, GreatestSeen] is the range of values the induction variable
1072   // takes.
1073 
1074   const SCEV *Smallest = nullptr, *Greatest = nullptr, *GreatestSeen = nullptr;
1075 
1076   const SCEV *One = SE.getOne(RTy);
1077   if (Increasing) {
1078     Smallest = Start;
1079     Greatest = End;
1080     // No overflow, because the range [Smallest, GreatestSeen] is not empty.
1081     GreatestSeen = SE.getMinusSCEV(End, One);
1082   } else {
1083     // These two computations may sign-overflow.  Here is why that is okay:
1084     //
1085     // We know that the induction variable does not sign-overflow on any
1086     // iteration except the last one, and it starts at `Start` and ends at
1087     // `End`, decrementing by one every time.
1088     //
1089     //  * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the
1090     //    induction variable is decreasing we know that that the smallest value
1091     //    the loop body is actually executed with is `INT_SMIN` == `Smallest`.
1092     //
1093     //  * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`.  In
1094     //    that case, `Clamp` will always return `Smallest` and
1095     //    [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`)
1096     //    will be an empty range.  Returning an empty range is always safe.
1097 
1098     Smallest = SE.getAddExpr(End, One);
1099     Greatest = SE.getAddExpr(Start, One);
1100     GreatestSeen = Start;
1101   }
1102 
1103   auto Clamp = [this, Smallest, Greatest, IsSignedPredicate](const SCEV *S) {
1104     return IsSignedPredicate
1105                ? SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S))
1106                : SE.getUMaxExpr(Smallest, SE.getUMinExpr(Greatest, S));
1107   };
1108 
1109   // In some cases we can prove that we don't need a pre or post loop.
1110   ICmpInst::Predicate PredLE =
1111       IsSignedPredicate ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
1112   ICmpInst::Predicate PredLT =
1113       IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
1114 
1115   bool ProvablyNoPreloop =
1116       SE.isKnownPredicate(PredLE, Range.getBegin(), Smallest);
1117   if (!ProvablyNoPreloop)
1118     Result.LowLimit = Clamp(Range.getBegin());
1119 
1120   bool ProvablyNoPostLoop =
1121       SE.isKnownPredicate(PredLT, GreatestSeen, Range.getEnd());
1122   if (!ProvablyNoPostLoop)
1123     Result.HighLimit = Clamp(Range.getEnd());
1124 
1125   return Result;
1126 }
1127 
1128 void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
1129                                 const char *Tag) const {
1130   for (BasicBlock *BB : OriginalLoop.getBlocks()) {
1131     BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
1132     Result.Blocks.push_back(Clone);
1133     Result.Map[BB] = Clone;
1134   }
1135 
1136   auto GetClonedValue = [&Result](Value *V) {
1137     assert(V && "null values not in domain!");
1138     auto It = Result.Map.find(V);
1139     if (It == Result.Map.end())
1140       return V;
1141     return static_cast<Value *>(It->second);
1142   };
1143 
1144   auto *ClonedLatch =
1145       cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
1146   ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
1147                                             MDNode::get(Ctx, {}));
1148 
1149   Result.Structure = MainLoopStructure.map(GetClonedValue);
1150   Result.Structure.Tag = Tag;
1151 
1152   for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
1153     BasicBlock *ClonedBB = Result.Blocks[i];
1154     BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
1155 
1156     assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
1157 
1158     for (Instruction &I : *ClonedBB)
1159       RemapInstruction(&I, Result.Map,
1160                        RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
1161 
1162     // Exit blocks will now have one more predecessor and their PHI nodes need
1163     // to be edited to reflect that.  No phi nodes need to be introduced because
1164     // the loop is in LCSSA.
1165 
1166     for (auto *SBB : successors(OriginalBB)) {
1167       if (OriginalLoop.contains(SBB))
1168         continue; // not an exit block
1169 
1170       for (PHINode &PN : SBB->phis()) {
1171         Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
1172         PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
1173       }
1174     }
1175   }
1176 }
1177 
/// Rewires loop \p LS so that it stops iterating once its induction variable
/// reaches \p ExitSubloopAt, then continues in \p ContinuationBlock instead of
/// running its original iteration space to completion.
///
/// Two new blocks are created after the latch:
///  * "<Tag>.exit.selector" receives the latch's exit edge. It branches to the
///    original exit if the original bound (`LS.LoopExitAt`) has been reached,
///    and to the pseudo exit otherwise.
///  * "<Tag>.pseudo.exit" collects, in PHIs, the latest value of each header
///    PHI and of the induction variable ("indvar.end"), then branches to
///    \p ContinuationBlock.
///
/// The terminator of \p Preheader is replaced by a guard that enters the loop
/// only when its start value is already on the in-range side of
/// \p ExitSubloopAt, and goes to the pseudo exit otherwise.
///
/// Returns the new blocks and PHI nodes.
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
    BasicBlock *ContinuationBlock) const {
  // We start with a loop with a single latch:
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      |
  //    |                    |
  //    +--------+-----------+
  //             |      ----------------\
  //             |     /                |
  //    +--------v----v------+          |
  //    |                    |          |
  //    |      header        |          |
  //    |                    |          |
  //    +--------------------+          |
  //                                    |
  //            .....                   |
  //                                    |
  //    +--------------------+          |
  //    |                    |          |
  //    |       latch        >----------/
  //    |                    |
  //    +-------v------------+
  //            |
  //            |
  //            |   +--------------------+
  //            |   |                    |
  //            +--->   original exit    |
  //                |                    |
  //                +--------------------+
  //
  // We change the control flow to look like
  //
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      >-------------------------+
  //    |                    |                         |
  //    +--------v-----------+                         |
  //             |    /-------------+                  |
  //             |   /              |                  |
  //    +--------v--v--------+      |                  |
  //    |                    |      |                  |
  //    |      header        |      |   +--------+     |
  //    |                    |      |   |        |     |
  //    +--------------------+      |   |  +-----v-----v-----------+
  //                                |   |  |                       |
  //                                |   |  |     .pseudo.exit      |
  //                                |   |  |                       |
  //                                |   |  +-----------v-----------+
  //                                |   |              |
  //            .....               |   |              |
  //                                |   |     +--------v-------------+
  //    +--------------------+      |   |     |                      |
  //    |                    |      |   |     |   ContinuationBlock  |
  //    |       latch        >------+   |     |                      |
  //    |                    |          |     +----------------------+
  //    +---------v----------+          |
  //              |                     |
  //              |                     |
  //              |     +---------------^-----+
  //              |     |                     |
  //              +----->    .exit.selector   |
  //                    |                     |
  //                    +----------v----------+
  //                               |
  //     +--------------------+    |
  //     |                    |    |
  //     |   original exit    <----+
  //     |                    |
  //     +--------------------+

  RewrittenRangeInfo RRI;

  // Both new blocks are placed right after the latch in the function's block
  // list.
  BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, BBInsertLocation);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      BBInsertLocation);

  BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
  bool Increasing = LS.IndVarIncreasing;
  bool IsSignedPredicate = LS.IsSignedPredicate;

  IRBuilder<> B(PreheaderJump);
  // Widens V to the range check's type (sext for signed latches, zext for
  // unsigned ones) unless it already has that type. Any extension is emitted
  // at B's current insertion point, which must therefore be dominated by V.
  auto *RangeTy = Range.getBegin()->getType();
  auto NoopOrExt = [&](Value *V) {
    if (V->getType() == RangeTy)
      return V;
    return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
                             : B.CreateZExt(V, RangeTy, "wide." + V->getName());
  };

  // EnterLoopCond - is it okay to start executing this `LS'?
  // `Pred` holds while the IV is still short of a limit: lt for increasing
  // IVs, gt for decreasing ones. It is reused for all three compares below.
  Value *EnterLoopCond = nullptr;
  auto Pred =
      Increasing
          ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
          : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
  Value *IndVarStart = NoopOrExt(LS.IndVarStart);
  EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);

  // Replace the preheader's unconditional jump with the entry guard.
  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  // The latch now exits to the exit selector. It takes the backedge only
  // while the IV has not reached ExitSubloopAt. When the exit is successor 0,
  // the branch condition must be the negation of "take the backedge".
  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
  B.SetInsertPoint(LS.LatchBr);
  Value *IndVarBase = NoopOrExt(LS.IndVarBase);
  Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);

  Value *CondForBranch = LS.LatchBrExitIdx == 1
                             ? TakeBackedgeLoopCond
                             : B.CreateNot(TakeBackedgeLoopCond);

  LS.LatchBr->setCondition(CondForBranch);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable?  If not, we branch to the "real"
  // exit.
  Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
  Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
  B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);

  BranchInst *BranchToContinuation =
      BranchInst::Create(ContinuationBlock, RRI.PseudoExit);

  // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
  // each of the PHI nodes in the loop header.  This feeds into the initial
  // value of the same PHI nodes if/when we continue execution.
  for (PHINode &PN : LS.Header->phis()) {
    PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
                                      BranchToContinuation);

    NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
    NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
                        RRI.ExitSelector);
    RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
  }

  // Value of the (possibly widened) IV on arrival at the pseudo exit: the
  // start value if the guard skipped the loop, otherwise the value the latch
  // last compared.
  RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
                                  BranchToContinuation);
  RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
  RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);

  // The latch exit now has a branch from `RRI.ExitSelector' instead of
  // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
  LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);

  return RRI;
}
1332 
1333 void LoopConstrainer::rewriteIncomingValuesForPHIs(
1334     LoopStructure &LS, BasicBlock *ContinuationBlock,
1335     const LoopConstrainer::RewrittenRangeInfo &RRI) const {
1336   unsigned PHIIndex = 0;
1337   for (PHINode &PN : LS.Header->phis())
1338     PN.setIncomingValueForBlock(ContinuationBlock,
1339                                 RRI.PHIValuesAtPseudoExit[PHIIndex++]);
1340 
1341   LS.IndVarStart = RRI.IndVarEnd;
1342 }
1343 
1344 BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
1345                                              BasicBlock *OldPreheader,
1346                                              const char *Tag) const {
1347   BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
1348   BranchInst::Create(LS.Header, Preheader);
1349 
1350   LS.Header->replacePhiUsesWith(OldPreheader, Preheader);
1351 
1352   return Preheader;
1353 }
1354 
1355 void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
1356   Loop *ParentLoop = OriginalLoop.getParentLoop();
1357   if (!ParentLoop)
1358     return;
1359 
1360   for (BasicBlock *BB : BBs)
1361     ParentLoop->addBasicBlockToLoop(BB, LI);
1362 }
1363 
1364 Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
1365                                                  ValueToValueMapTy &VM,
1366                                                  bool IsSubloop) {
1367   Loop &New = *LI.AllocateLoop();
1368   if (Parent)
1369     Parent->addChildLoop(&New);
1370   else
1371     LI.addTopLevelLoop(&New);
1372   LPMAddNewLoop(&New, IsSubloop);
1373 
1374   // Add all of the blocks in Original to the new loop.
1375   for (auto *BB : Original->blocks())
1376     if (LI.getLoopFor(BB) == Original)
1377       New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);
1378 
1379   // Add all of the subloops to the new loop.
1380   for (Loop *SubLoop : *Original)
1381     createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);
1382 
1383   return &New;
1384 }
1385 
1386 bool LoopConstrainer::run() {
1387   BasicBlock *Preheader = nullptr;
1388   LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch);
1389   Preheader = OriginalLoop.getLoopPreheader();
1390   assert(!isa<SCEVCouldNotCompute>(LatchTakenCount) && Preheader != nullptr &&
1391          "preconditions!");
1392 
1393   OriginalPreheader = Preheader;
1394   MainLoopPreheader = Preheader;
1395 
1396   bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
1397   Optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate);
1398   if (!MaybeSR.hasValue()) {
1399     LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n");
1400     return false;
1401   }
1402 
1403   SubRanges SR = MaybeSR.getValue();
1404   bool Increasing = MainLoopStructure.IndVarIncreasing;
1405   IntegerType *IVTy =
1406       cast<IntegerType>(Range.getBegin()->getType());
1407 
1408   SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce");
1409   Instruction *InsertPt = OriginalPreheader->getTerminator();
1410 
1411   // It would have been better to make `PreLoop' and `PostLoop'
1412   // `Optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
1413   // constructor.
1414   ClonedLoop PreLoop, PostLoop;
1415   bool NeedsPreLoop =
1416       Increasing ? SR.LowLimit.hasValue() : SR.HighLimit.hasValue();
1417   bool NeedsPostLoop =
1418       Increasing ? SR.HighLimit.hasValue() : SR.LowLimit.hasValue();
1419 
1420   Value *ExitPreLoopAt = nullptr;
1421   Value *ExitMainLoopAt = nullptr;
1422   const SCEVConstant *MinusOneS =
1423       cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
1424 
1425   if (NeedsPreLoop) {
1426     const SCEV *ExitPreLoopAtSCEV = nullptr;
1427 
1428     if (Increasing)
1429       ExitPreLoopAtSCEV = *SR.LowLimit;
1430     else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
1431                                IsSignedPredicate))
1432       ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
1433     else {
1434       LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
1435                         << "preloop exit limit.  HighLimit = "
1436                         << *(*SR.HighLimit) << "\n");
1437       return false;
1438     }
1439 
1440     if (!isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt, SE)) {
1441       LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the"
1442                         << " preloop exit limit " << *ExitPreLoopAtSCEV
1443                         << " at block " << InsertPt->getParent()->getName()
1444                         << "\n");
1445       return false;
1446     }
1447 
1448     ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
1449     ExitPreLoopAt->setName("exit.preloop.at");
1450   }
1451 
1452   if (NeedsPostLoop) {
1453     const SCEV *ExitMainLoopAtSCEV = nullptr;
1454 
1455     if (Increasing)
1456       ExitMainLoopAtSCEV = *SR.HighLimit;
1457     else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
1458                                IsSignedPredicate))
1459       ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
1460     else {
1461       LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
1462                         << "mainloop exit limit.  LowLimit = "
1463                         << *(*SR.LowLimit) << "\n");
1464       return false;
1465     }
1466 
1467     if (!isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt, SE)) {
1468       LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the"
1469                         << " main loop exit limit " << *ExitMainLoopAtSCEV
1470                         << " at block " << InsertPt->getParent()->getName()
1471                         << "\n");
1472       return false;
1473     }
1474 
1475     ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
1476     ExitMainLoopAt->setName("exit.mainloop.at");
1477   }
1478 
1479   // We clone these ahead of time so that we don't have to deal with changing
1480   // and temporarily invalid IR as we transform the loops.
1481   if (NeedsPreLoop)
1482     cloneLoop(PreLoop, "preloop");
1483   if (NeedsPostLoop)
1484     cloneLoop(PostLoop, "postloop");
1485 
1486   RewrittenRangeInfo PreLoopRRI;
1487 
1488   if (NeedsPreLoop) {
1489     Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
1490                                                   PreLoop.Structure.Header);
1491 
1492     MainLoopPreheader =
1493         createPreheader(MainLoopStructure, Preheader, "mainloop");
1494     PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
1495                                          ExitPreLoopAt, MainLoopPreheader);
1496     rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
1497                                  PreLoopRRI);
1498   }
1499 
1500   BasicBlock *PostLoopPreheader = nullptr;
1501   RewrittenRangeInfo PostLoopRRI;
1502 
1503   if (NeedsPostLoop) {
1504     PostLoopPreheader =
1505         createPreheader(PostLoop.Structure, Preheader, "postloop");
1506     PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
1507                                           ExitMainLoopAt, PostLoopPreheader);
1508     rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
1509                                  PostLoopRRI);
1510   }
1511 
1512   BasicBlock *NewMainLoopPreheader =
1513       MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
1514   BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
1515                              PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
1516                              PostLoopRRI.ExitSelector, NewMainLoopPreheader};
1517 
1518   // Some of the above may be nullptr, filter them out before passing to
1519   // addToParentLoopIfNeeded.
1520   auto NewBlocksEnd =
1521       std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
1522 
1523   addToParentLoopIfNeeded(makeArrayRef(std::begin(NewBlocks), NewBlocksEnd));
1524 
1525   DT.recalculate(F);
1526 
1527   // We need to first add all the pre and post loop blocks into the loop
1528   // structures (as part of createClonedLoopStructure), and then update the
1529   // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
1530   // LI when LoopSimplifyForm is generated.
1531   Loop *PreL = nullptr, *PostL = nullptr;
1532   if (!PreLoop.Blocks.empty()) {
1533     PreL = createClonedLoopStructure(&OriginalLoop,
1534                                      OriginalLoop.getParentLoop(), PreLoop.Map,
1535                                      /* IsSubLoop */ false);
1536   }
1537 
1538   if (!PostLoop.Blocks.empty()) {
1539     PostL =
1540         createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
1541                                   PostLoop.Map, /* IsSubLoop */ false);
1542   }
1543 
1544   // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
1545   auto CanonicalizeLoop = [&] (Loop *L, bool IsOriginalLoop) {
1546     formLCSSARecursively(*L, DT, &LI, &SE);
1547     simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
1548     // Pre/post loops are slow paths, we do not need to perform any loop
1549     // optimizations on them.
1550     if (!IsOriginalLoop)
1551       DisableAllLoopOptsOnLoop(*L);
1552   };
1553   if (PreL)
1554     CanonicalizeLoop(PreL, false);
1555   if (PostL)
1556     CanonicalizeLoop(PostL, false);
1557   CanonicalizeLoop(&OriginalLoop, true);
1558 
1559   return true;
1560 }
1561 
1562 /// Computes and returns a range of values for the induction variable (IndVar)
1563 /// in which the range check can be safely elided.  If it cannot compute such a
1564 /// range, returns None.
1565 Optional<InductiveRangeCheck::Range>
1566 InductiveRangeCheck::computeSafeIterationSpace(
1567     ScalarEvolution &SE, const SCEVAddRecExpr *IndVar,
1568     bool IsLatchSigned) const {
1569   // We can deal when types of latch check and range checks don't match in case
1570   // if latch check is more narrow.
1571   auto *IVType = cast<IntegerType>(IndVar->getType());
1572   auto *RCType = cast<IntegerType>(getBegin()->getType());
1573   if (IVType->getBitWidth() > RCType->getBitWidth())
1574     return None;
1575   // IndVar is of the form "A + B * I" (where "I" is the canonical induction
1576   // variable, that may or may not exist as a real llvm::Value in the loop) and
1577   // this inductive range check is a range check on the "C + D * I" ("C" is
1578   // getBegin() and "D" is getStep()).  We rewrite the value being range
1579   // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA".
1580   //
1581   // The actual inequalities we solve are of the form
1582   //
1583   //   0 <= M + 1 * IndVar < L given L >= 0  (i.e. N == 1)
1584   //
1585   // Here L stands for upper limit of the safe iteration space.
1586   // The inequality is satisfied by (0 - M) <= IndVar < (L - M). To avoid
1587   // overflows when calculating (0 - M) and (L - M) we, depending on type of
1588   // IV's iteration space, limit the calculations by borders of the iteration
1589   // space. For example, if IndVar is unsigned, (0 - M) overflows for any M > 0.
1590   // If we figured out that "anything greater than (-M) is safe", we strengthen
1591   // this to "everything greater than 0 is safe", assuming that values between
1592   // -M and 0 just do not exist in unsigned iteration space, and we don't want
1593   // to deal with overflown values.
1594 
1595   if (!IndVar->isAffine())
1596     return None;
1597 
1598   const SCEV *A = NoopOrExtend(IndVar->getStart(), RCType, SE, IsLatchSigned);
1599   const SCEVConstant *B = dyn_cast<SCEVConstant>(
1600       NoopOrExtend(IndVar->getStepRecurrence(SE), RCType, SE, IsLatchSigned));
1601   if (!B)
1602     return None;
1603   assert(!B->isZero() && "Recurrence with zero step?");
1604 
1605   const SCEV *C = getBegin();
1606   const SCEVConstant *D = dyn_cast<SCEVConstant>(getStep());
1607   if (D != B)
1608     return None;
1609 
1610   assert(!D->getValue()->isZero() && "Recurrence with zero step?");
1611   unsigned BitWidth = RCType->getBitWidth();
1612   const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
1613 
1614   // Subtract Y from X so that it does not go through border of the IV
1615   // iteration space. Mathematically, it is equivalent to:
1616   //
1617   //    ClampedSubtract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX).        [1]
1618   //
1619   // In [1], 'X - Y' is a mathematical subtraction (result is not bounded to
1620   // any width of bit grid). But after we take min/max, the result is
1621   // guaranteed to be within [INT_MIN, INT_MAX].
1622   //
1623   // In [1], INT_MAX and INT_MIN are respectively signed and unsigned max/min
1624   // values, depending on type of latch condition that defines IV iteration
1625   // space.
1626   auto ClampedSubtract = [&](const SCEV *X, const SCEV *Y) {
1627     // FIXME: The current implementation assumes that X is in [0, SINT_MAX].
1628     // This is required to ensure that SINT_MAX - X does not overflow signed and
1629     // that X - Y does not overflow unsigned if Y is negative. Can we lift this
1630     // restriction and make it work for negative X either?
1631     if (IsLatchSigned) {
1632       // X is a number from signed range, Y is interpreted as signed.
1633       // Even if Y is SINT_MAX, (X - Y) does not reach SINT_MIN. So the only
1634       // thing we should care about is that we didn't cross SINT_MAX.
1635       // So, if Y is positive, we subtract Y safely.
1636       //   Rule 1: Y > 0 ---> Y.
1637       // If 0 <= -Y <= (SINT_MAX - X), we subtract Y safely.
1638       //   Rule 2: Y >=s (X - SINT_MAX) ---> Y.
1639       // If 0 <= (SINT_MAX - X) < -Y, we can only subtract (X - SINT_MAX).
1640       //   Rule 3: Y <s (X - SINT_MAX) ---> (X - SINT_MAX).
1641       // It gives us smax(Y, X - SINT_MAX) to subtract in all cases.
1642       const SCEV *XMinusSIntMax = SE.getMinusSCEV(X, SIntMax);
1643       return SE.getMinusSCEV(X, SE.getSMaxExpr(Y, XMinusSIntMax),
1644                              SCEV::FlagNSW);
1645     } else
1646       // X is a number from unsigned range, Y is interpreted as signed.
1647       // Even if Y is SINT_MIN, (X - Y) does not reach UINT_MAX. So the only
1648       // thing we should care about is that we didn't cross zero.
1649       // So, if Y is negative, we subtract Y safely.
1650       //   Rule 1: Y <s 0 ---> Y.
1651       // If 0 <= Y <= X, we subtract Y safely.
1652       //   Rule 2: Y <=s X ---> Y.
1653       // If 0 <= X < Y, we should stop at 0 and can only subtract X.
1654       //   Rule 3: Y >s X ---> X.
1655       // It gives us smin(X, Y) to subtract in all cases.
1656       return SE.getMinusSCEV(X, SE.getSMinExpr(X, Y), SCEV::FlagNUW);
1657   };
1658   const SCEV *M = SE.getMinusSCEV(C, A);
1659   const SCEV *Zero = SE.getZero(M->getType());
1660 
1661   // This function returns SCEV equal to 1 if X is non-negative 0 otherwise.
1662   auto SCEVCheckNonNegative = [&](const SCEV *X) {
1663     const Loop *L = IndVar->getLoop();
1664     const SCEV *One = SE.getOne(X->getType());
1665     // Can we trivially prove that X is a non-negative or negative value?
1666     if (isKnownNonNegativeInLoop(X, L, SE))
1667       return One;
1668     else if (isKnownNegativeInLoop(X, L, SE))
1669       return Zero;
1670     // If not, we will have to figure it out during the execution.
1671     // Function smax(smin(X, 0), -1) + 1 equals to 1 if X >= 0 and 0 if X < 0.
1672     const SCEV *NegOne = SE.getNegativeSCEV(One);
1673     return SE.getAddExpr(SE.getSMaxExpr(SE.getSMinExpr(X, Zero), NegOne), One);
1674   };
1675   // FIXME: Current implementation of ClampedSubtract implicitly assumes that
1676   // X is non-negative (in sense of a signed value). We need to re-implement
1677   // this function in a way that it will correctly handle negative X as well.
1678   // We use it twice: for X = 0 everything is fine, but for X = getEnd() we can
1679   // end up with a negative X and produce wrong results. So currently we ensure
1680   // that if getEnd() is negative then both ends of the safe range are zero.
1681   // Note that this may pessimize elimination of unsigned range checks against
1682   // negative values.
1683   const SCEV *REnd = getEnd();
1684   const SCEV *EndIsNonNegative = SCEVCheckNonNegative(REnd);
1685 
1686   const SCEV *Begin = SE.getMulExpr(ClampedSubtract(Zero, M), EndIsNonNegative);
1687   const SCEV *End = SE.getMulExpr(ClampedSubtract(REnd, M), EndIsNonNegative);
1688   return InductiveRangeCheck::Range(Begin, End);
1689 }
1690 
1691 static Optional<InductiveRangeCheck::Range>
1692 IntersectSignedRange(ScalarEvolution &SE,
1693                      const Optional<InductiveRangeCheck::Range> &R1,
1694                      const InductiveRangeCheck::Range &R2) {
1695   if (R2.isEmpty(SE, /* IsSigned */ true))
1696     return None;
1697   if (!R1.hasValue())
1698     return R2;
1699   auto &R1Value = R1.getValue();
1700   // We never return empty ranges from this function, and R1 is supposed to be
1701   // a result of intersection. Thus, R1 is never empty.
1702   assert(!R1Value.isEmpty(SE, /* IsSigned */ true) &&
1703          "We should never have empty R1!");
1704 
1705   // TODO: we could widen the smaller range and have this work; but for now we
1706   // bail out to keep things simple.
1707   if (R1Value.getType() != R2.getType())
1708     return None;
1709 
1710   const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin());
1711   const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd());
1712 
1713   // If the resulting range is empty, just return None.
1714   auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd);
1715   if (Ret.isEmpty(SE, /* IsSigned */ true))
1716     return None;
1717   return Ret;
1718 }
1719 
1720 static Optional<InductiveRangeCheck::Range>
1721 IntersectUnsignedRange(ScalarEvolution &SE,
1722                        const Optional<InductiveRangeCheck::Range> &R1,
1723                        const InductiveRangeCheck::Range &R2) {
1724   if (R2.isEmpty(SE, /* IsSigned */ false))
1725     return None;
1726   if (!R1.hasValue())
1727     return R2;
1728   auto &R1Value = R1.getValue();
1729   // We never return empty ranges from this function, and R1 is supposed to be
1730   // a result of intersection. Thus, R1 is never empty.
1731   assert(!R1Value.isEmpty(SE, /* IsSigned */ false) &&
1732          "We should never have empty R1!");
1733 
1734   // TODO: we could widen the smaller range and have this work; but for now we
1735   // bail out to keep things simple.
1736   if (R1Value.getType() != R2.getType())
1737     return None;
1738 
1739   const SCEV *NewBegin = SE.getUMaxExpr(R1Value.getBegin(), R2.getBegin());
1740   const SCEV *NewEnd = SE.getUMinExpr(R1Value.getEnd(), R2.getEnd());
1741 
1742   // If the resulting range is empty, just return None.
1743   auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd);
1744   if (Ret.isEmpty(SE, /* IsSigned */ false))
1745     return None;
1746   return Ret;
1747 }
1748 
1749 PreservedAnalyses IRCEPass::run(Loop &L, LoopAnalysisManager &AM,
1750                                 LoopStandardAnalysisResults &AR,
1751                                 LPMUpdater &U) {
1752   Function *F = L.getHeader()->getParent();
1753   const auto &FAM =
1754       AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
1755   auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F);
1756   InductiveRangeCheckElimination IRCE(AR.SE, BPI, AR.DT, AR.LI);
1757   auto LPMAddNewLoop = [&U](Loop *NL, bool IsSubloop) {
1758     if (!IsSubloop)
1759       U.addSiblingLoops(NL);
1760   };
1761   bool Changed = IRCE.run(&L, LPMAddNewLoop);
1762   if (!Changed)
1763     return PreservedAnalyses::all();
1764 
1765   return getLoopPassPreservedAnalyses();
1766 }
1767 
1768 bool IRCELegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
1769   if (skipLoop(L))
1770     return false;
1771 
1772   ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
1773   BranchProbabilityInfo &BPI =
1774       getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
1775   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
1776   auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1777   InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI);
1778   auto LPMAddNewLoop = [&LPM](Loop *NL, bool /* IsSubLoop */) {
1779     LPM.addLoop(*NL);
1780   };
1781   return IRCE.run(L, LPMAddNewLoop);
1782 }
1783 
1784 bool InductiveRangeCheckElimination::run(
1785     Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop) {
1786   if (L->getBlocks().size() >= LoopSizeCutoff) {
1787     LLVM_DEBUG(dbgs() << "irce: giving up constraining loop, too large\n");
1788     return false;
1789   }
1790 
1791   BasicBlock *Preheader = L->getLoopPreheader();
1792   if (!Preheader) {
1793     LLVM_DEBUG(dbgs() << "irce: loop has no preheader, leaving\n");
1794     return false;
1795   }
1796 
1797   LLVMContext &Context = Preheader->getContext();
1798   SmallVector<InductiveRangeCheck, 16> RangeChecks;
1799 
1800   for (auto BBI : L->getBlocks())
1801     if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))
1802       InductiveRangeCheck::extractRangeChecksFromBranch(TBI, L, SE, BPI,
1803                                                         RangeChecks);
1804 
1805   if (RangeChecks.empty())
1806     return false;
1807 
1808   auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) {
1809     OS << "irce: looking at loop "; L->print(OS);
1810     OS << "irce: loop has " << RangeChecks.size()
1811        << " inductive range checks: \n";
1812     for (InductiveRangeCheck &IRC : RangeChecks)
1813       IRC.print(OS);
1814   };
1815 
1816   LLVM_DEBUG(PrintRecognizedRangeChecks(dbgs()));
1817 
1818   if (PrintRangeChecks)
1819     PrintRecognizedRangeChecks(errs());
1820 
1821   const char *FailureReason = nullptr;
1822   Optional<LoopStructure> MaybeLoopStructure =
1823       LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason);
1824   if (!MaybeLoopStructure.hasValue()) {
1825     LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: "
1826                       << FailureReason << "\n";);
1827     return false;
1828   }
1829   LoopStructure LS = MaybeLoopStructure.getValue();
1830   const SCEVAddRecExpr *IndVar =
1831       cast<SCEVAddRecExpr>(SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep)));
1832 
1833   Optional<InductiveRangeCheck::Range> SafeIterRange;
1834   Instruction *ExprInsertPt = Preheader->getTerminator();
1835 
1836   SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate;
1837   // Basing on the type of latch predicate, we interpret the IV iteration range
1838   // as signed or unsigned range. We use different min/max functions (signed or
1839   // unsigned) when intersecting this range with safe iteration ranges implied
1840   // by range checks.
1841   auto IntersectRange =
1842       LS.IsSignedPredicate ? IntersectSignedRange : IntersectUnsignedRange;
1843 
1844   IRBuilder<> B(ExprInsertPt);
1845   for (InductiveRangeCheck &IRC : RangeChecks) {
1846     auto Result = IRC.computeSafeIterationSpace(SE, IndVar,
1847                                                 LS.IsSignedPredicate);
1848     if (Result.hasValue()) {
1849       auto MaybeSafeIterRange =
1850           IntersectRange(SE, SafeIterRange, Result.getValue());
1851       if (MaybeSafeIterRange.hasValue()) {
1852         assert(
1853             !MaybeSafeIterRange.getValue().isEmpty(SE, LS.IsSignedPredicate) &&
1854             "We should never return empty ranges!");
1855         RangeChecksToEliminate.push_back(IRC);
1856         SafeIterRange = MaybeSafeIterRange.getValue();
1857       }
1858     }
1859   }
1860 
1861   if (!SafeIterRange.hasValue())
1862     return false;
1863 
1864   LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT,
1865                      SafeIterRange.getValue());
1866   bool Changed = LC.run();
1867 
1868   if (Changed) {
1869     auto PrintConstrainedLoopInfo = [L]() {
1870       dbgs() << "irce: in function ";
1871       dbgs() << L->getHeader()->getParent()->getName() << ": ";
1872       dbgs() << "constrained ";
1873       L->print(dbgs());
1874     };
1875 
1876     LLVM_DEBUG(PrintConstrainedLoopInfo());
1877 
1878     if (PrintChangedLoops)
1879       PrintConstrainedLoopInfo();
1880 
1881     // Optimize away the now-redundant range checks.
1882 
1883     for (InductiveRangeCheck &IRC : RangeChecksToEliminate) {
1884       ConstantInt *FoldedRangeCheck = IRC.getPassingDirection()
1885                                           ? ConstantInt::getTrue(Context)
1886                                           : ConstantInt::getFalse(Context);
1887       IRC.getCheckUse()->set(FoldedRangeCheck);
1888     }
1889   }
1890 
1891   return Changed;
1892 }
1893 
1894 Pass *llvm::createInductiveRangeCheckEliminationPass() {
1895   return new IRCELegacyPass();
1896 }
1897