xref: /freebsd/contrib/llvm-project/llvm/lib/IR/ConstantRange.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- ConstantRange.cpp - ConstantRange implementation -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Represent a range of possible values that may occur when the program is run
10 // for an integral value.  This keeps track of a lower and upper bound for the
11 // constant, which MAY wrap around the end of the numeric range.  To do this, it
12 // keeps track of a [lower, upper) bound, which specifies an interval just like
13 // STL iterators.  When used with boolean values, the following are important
14 // ranges (other integral ranges use min/max values for special range values):
15 //
16 //  [F, F) = {}     = Empty set
17 //  [T, F) = {T}
18 //  [F, T) = {F}
19 //  [T, T) = {F, T} = Full set
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "llvm/IR/ConstantRange.h"
24 #include "llvm/ADT/APInt.h"
25 #include "llvm/Config/llvm-config.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/InstrTypes.h"
28 #include "llvm/IR/Instruction.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/Intrinsics.h"
31 #include "llvm/IR/Metadata.h"
32 #include "llvm/IR/Operator.h"
33 #include "llvm/Support/Compiler.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/KnownBits.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <algorithm>
39 #include <cassert>
40 #include <cstdint>
41 #include <optional>
42 
43 using namespace llvm;
44 
ConstantRange(uint32_t BitWidth,bool Full)45 ConstantRange::ConstantRange(uint32_t BitWidth, bool Full)
46     : Lower(Full ? APInt::getMaxValue(BitWidth) : APInt::getMinValue(BitWidth)),
47       Upper(Lower) {}
48 
ConstantRange(APInt V)49 ConstantRange::ConstantRange(APInt V)
50     : Lower(std::move(V)), Upper(Lower + 1) {}
51 
ConstantRange(APInt L,APInt U)52 ConstantRange::ConstantRange(APInt L, APInt U)
53     : Lower(std::move(L)), Upper(std::move(U)) {
54   assert(Lower.getBitWidth() == Upper.getBitWidth() &&
55          "ConstantRange with unequal bit widths");
56   assert((Lower != Upper || (Lower.isMaxValue() || Lower.isMinValue())) &&
57          "Lower == Upper, but they aren't min or max value!");
58 }
59 
fromKnownBits(const KnownBits & Known,bool IsSigned)60 ConstantRange ConstantRange::fromKnownBits(const KnownBits &Known,
61                                            bool IsSigned) {
62   if (Known.hasConflict())
63     return getEmpty(Known.getBitWidth());
64   if (Known.isUnknown())
65     return getFull(Known.getBitWidth());
66 
67   // For unsigned ranges, or signed ranges with known sign bit, create a simple
68   // range between the smallest and largest possible value.
69   if (!IsSigned || Known.isNegative() || Known.isNonNegative())
70     return ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1);
71 
72   // If we don't know the sign bit, pick the lower bound as a negative number
73   // and the upper bound as a non-negative one.
74   APInt Lower = Known.getMinValue(), Upper = Known.getMaxValue();
75   Lower.setSignBit();
76   Upper.clearSignBit();
77   return ConstantRange(Lower, Upper + 1);
78 }
79 
toKnownBits() const80 KnownBits ConstantRange::toKnownBits() const {
81   // TODO: We could return conflicting known bits here, but consumers are
82   // likely not prepared for that.
83   if (isEmptySet())
84     return KnownBits(getBitWidth());
85 
86   // We can only retain the top bits that are the same between min and max.
87   APInt Min = getUnsignedMin();
88   APInt Max = getUnsignedMax();
89   KnownBits Known = KnownBits::makeConstant(Min);
90   if (std::optional<unsigned> DifferentBit =
91           APIntOps::GetMostSignificantDifferentBit(Min, Max)) {
92     Known.Zero.clearLowBits(*DifferentBit + 1);
93     Known.One.clearLowBits(*DifferentBit + 1);
94   }
95   return Known;
96 }
97 
splitPosNeg() const98 std::pair<ConstantRange, ConstantRange> ConstantRange::splitPosNeg() const {
99   uint32_t BW = getBitWidth();
100   APInt Zero = APInt::getZero(BW), One = APInt(BW, 1);
101   APInt SignedMin = APInt::getSignedMinValue(BW);
102   // There are no positive 1-bit values. The 1 would get interpreted as -1.
103   ConstantRange PosFilter =
104       BW == 1 ? getEmpty() : ConstantRange(One, SignedMin);
105   ConstantRange NegFilter(SignedMin, Zero);
106   return {intersectWith(PosFilter), intersectWith(NegFilter)};
107 }
108 
makeAllowedICmpRegion(CmpInst::Predicate Pred,const ConstantRange & CR)109 ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred,
110                                                    const ConstantRange &CR) {
111   if (CR.isEmptySet())
112     return CR;
113 
114   uint32_t W = CR.getBitWidth();
115   switch (Pred) {
116   default:
117     llvm_unreachable("Invalid ICmp predicate to makeAllowedICmpRegion()");
118   case CmpInst::ICMP_EQ:
119     return CR;
120   case CmpInst::ICMP_NE:
121     if (CR.isSingleElement())
122       return ConstantRange(CR.getUpper(), CR.getLower());
123     return getFull(W);
124   case CmpInst::ICMP_ULT: {
125     APInt UMax(CR.getUnsignedMax());
126     if (UMax.isMinValue())
127       return getEmpty(W);
128     return ConstantRange(APInt::getMinValue(W), std::move(UMax));
129   }
130   case CmpInst::ICMP_SLT: {
131     APInt SMax(CR.getSignedMax());
132     if (SMax.isMinSignedValue())
133       return getEmpty(W);
134     return ConstantRange(APInt::getSignedMinValue(W), std::move(SMax));
135   }
136   case CmpInst::ICMP_ULE:
137     return getNonEmpty(APInt::getMinValue(W), CR.getUnsignedMax() + 1);
138   case CmpInst::ICMP_SLE:
139     return getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1);
140   case CmpInst::ICMP_UGT: {
141     APInt UMin(CR.getUnsignedMin());
142     if (UMin.isMaxValue())
143       return getEmpty(W);
144     return ConstantRange(std::move(UMin) + 1, APInt::getZero(W));
145   }
146   case CmpInst::ICMP_SGT: {
147     APInt SMin(CR.getSignedMin());
148     if (SMin.isMaxSignedValue())
149       return getEmpty(W);
150     return ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W));
151   }
152   case CmpInst::ICMP_UGE:
153     return getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W));
154   case CmpInst::ICMP_SGE:
155     return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W));
156   }
157 }
158 
makeSatisfyingICmpRegion(CmpInst::Predicate Pred,const ConstantRange & CR)159 ConstantRange ConstantRange::makeSatisfyingICmpRegion(CmpInst::Predicate Pred,
160                                                       const ConstantRange &CR) {
161   // Follows from De-Morgan's laws:
162   //
163   // ~(~A union ~B) == A intersect B.
164   //
165   return makeAllowedICmpRegion(CmpInst::getInversePredicate(Pred), CR)
166       .inverse();
167 }
168 
makeExactICmpRegion(CmpInst::Predicate Pred,const APInt & C)169 ConstantRange ConstantRange::makeExactICmpRegion(CmpInst::Predicate Pred,
170                                                  const APInt &C) {
171   // Computes the exact range that is equal to both the constant ranges returned
172   // by makeAllowedICmpRegion and makeSatisfyingICmpRegion. This is always true
173   // when RHS is a singleton such as an APInt. However for non-singleton RHS,
174   // for example ult [2,5) makeAllowedICmpRegion returns [0,4) but
175   // makeSatisfyICmpRegion returns [0,2).
176   //
177   return makeAllowedICmpRegion(Pred, C);
178 }
179 
areInsensitiveToSignednessOfICmpPredicate(const ConstantRange & CR1,const ConstantRange & CR2)180 bool ConstantRange::areInsensitiveToSignednessOfICmpPredicate(
181     const ConstantRange &CR1, const ConstantRange &CR2) {
182   if (CR1.isEmptySet() || CR2.isEmptySet())
183     return true;
184 
185   return (CR1.isAllNonNegative() && CR2.isAllNonNegative()) ||
186          (CR1.isAllNegative() && CR2.isAllNegative());
187 }
188 
areInsensitiveToSignednessOfInvertedICmpPredicate(const ConstantRange & CR1,const ConstantRange & CR2)189 bool ConstantRange::areInsensitiveToSignednessOfInvertedICmpPredicate(
190     const ConstantRange &CR1, const ConstantRange &CR2) {
191   if (CR1.isEmptySet() || CR2.isEmptySet())
192     return true;
193 
194   return (CR1.isAllNonNegative() && CR2.isAllNegative()) ||
195          (CR1.isAllNegative() && CR2.isAllNonNegative());
196 }
197 
getEquivalentPredWithFlippedSignedness(CmpInst::Predicate Pred,const ConstantRange & CR1,const ConstantRange & CR2)198 CmpInst::Predicate ConstantRange::getEquivalentPredWithFlippedSignedness(
199     CmpInst::Predicate Pred, const ConstantRange &CR1,
200     const ConstantRange &CR2) {
201   assert(CmpInst::isIntPredicate(Pred) && CmpInst::isRelational(Pred) &&
202          "Only for relational integer predicates!");
203 
204   CmpInst::Predicate FlippedSignednessPred =
205       ICmpInst::getFlippedSignednessPredicate(Pred);
206 
207   if (areInsensitiveToSignednessOfICmpPredicate(CR1, CR2))
208     return FlippedSignednessPred;
209 
210   if (areInsensitiveToSignednessOfInvertedICmpPredicate(CR1, CR2))
211     return CmpInst::getInversePredicate(FlippedSignednessPred);
212 
213   return CmpInst::Predicate::BAD_ICMP_PREDICATE;
214 }
215 
getEquivalentICmp(CmpInst::Predicate & Pred,APInt & RHS,APInt & Offset) const216 void ConstantRange::getEquivalentICmp(CmpInst::Predicate &Pred,
217                                       APInt &RHS, APInt &Offset) const {
218   Offset = APInt(getBitWidth(), 0);
219   if (isFullSet() || isEmptySet()) {
220     Pred = isEmptySet() ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE;
221     RHS = APInt(getBitWidth(), 0);
222   } else if (auto *OnlyElt = getSingleElement()) {
223     Pred = CmpInst::ICMP_EQ;
224     RHS = *OnlyElt;
225   } else if (auto *OnlyMissingElt = getSingleMissingElement()) {
226     Pred = CmpInst::ICMP_NE;
227     RHS = *OnlyMissingElt;
228   } else if (getLower().isMinSignedValue() || getLower().isMinValue()) {
229     Pred =
230         getLower().isMinSignedValue() ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
231     RHS = getUpper();
232   } else if (getUpper().isMinSignedValue() || getUpper().isMinValue()) {
233     Pred =
234         getUpper().isMinSignedValue() ? CmpInst::ICMP_SGE : CmpInst::ICMP_UGE;
235     RHS = getLower();
236   } else {
237     Pred = CmpInst::ICMP_ULT;
238     RHS = getUpper() - getLower();
239     Offset = -getLower();
240   }
241 
242   assert(ConstantRange::makeExactICmpRegion(Pred, RHS) == add(Offset) &&
243          "Bad result!");
244 }
245 
getEquivalentICmp(CmpInst::Predicate & Pred,APInt & RHS) const246 bool ConstantRange::getEquivalentICmp(CmpInst::Predicate &Pred,
247                                       APInt &RHS) const {
248   APInt Offset;
249   getEquivalentICmp(Pred, RHS, Offset);
250   return Offset.isZero();
251 }
252 
icmp(CmpInst::Predicate Pred,const ConstantRange & Other) const253 bool ConstantRange::icmp(CmpInst::Predicate Pred,
254                          const ConstantRange &Other) const {
255   if (isEmptySet() || Other.isEmptySet())
256     return true;
257 
258   switch (Pred) {
259   case CmpInst::ICMP_EQ:
260     if (const APInt *L = getSingleElement())
261       if (const APInt *R = Other.getSingleElement())
262         return *L == *R;
263     return false;
264   case CmpInst::ICMP_NE:
265     return inverse().contains(Other);
266   case CmpInst::ICMP_ULT:
267     return getUnsignedMax().ult(Other.getUnsignedMin());
268   case CmpInst::ICMP_ULE:
269     return getUnsignedMax().ule(Other.getUnsignedMin());
270   case CmpInst::ICMP_UGT:
271     return getUnsignedMin().ugt(Other.getUnsignedMax());
272   case CmpInst::ICMP_UGE:
273     return getUnsignedMin().uge(Other.getUnsignedMax());
274   case CmpInst::ICMP_SLT:
275     return getSignedMax().slt(Other.getSignedMin());
276   case CmpInst::ICMP_SLE:
277     return getSignedMax().sle(Other.getSignedMin());
278   case CmpInst::ICMP_SGT:
279     return getSignedMin().sgt(Other.getSignedMax());
280   case CmpInst::ICMP_SGE:
281     return getSignedMin().sge(Other.getSignedMax());
282   default:
283     llvm_unreachable("Invalid ICmp predicate");
284   }
285 }
286 
287 /// Exact mul nuw region for single element RHS.
makeExactMulNUWRegion(const APInt & V)288 static ConstantRange makeExactMulNUWRegion(const APInt &V) {
289   unsigned BitWidth = V.getBitWidth();
290   if (V == 0)
291     return ConstantRange::getFull(V.getBitWidth());
292 
293   return ConstantRange::getNonEmpty(
294       APIntOps::RoundingUDiv(APInt::getMinValue(BitWidth), V,
295                              APInt::Rounding::UP),
296       APIntOps::RoundingUDiv(APInt::getMaxValue(BitWidth), V,
297                              APInt::Rounding::DOWN) + 1);
298 }
299 
300 /// Exact mul nsw region for single element RHS.
makeExactMulNSWRegion(const APInt & V)301 static ConstantRange makeExactMulNSWRegion(const APInt &V) {
302   // Handle 0 and -1 separately to avoid division by zero or overflow.
303   unsigned BitWidth = V.getBitWidth();
304   if (V == 0)
305     return ConstantRange::getFull(BitWidth);
306 
307   APInt MinValue = APInt::getSignedMinValue(BitWidth);
308   APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
309   // e.g. Returning [-127, 127], represented as [-127, -128).
310   if (V.isAllOnes())
311     return ConstantRange(-MaxValue, MinValue);
312 
313   APInt Lower, Upper;
314   if (V.isNegative()) {
315     Lower = APIntOps::RoundingSDiv(MaxValue, V, APInt::Rounding::UP);
316     Upper = APIntOps::RoundingSDiv(MinValue, V, APInt::Rounding::DOWN);
317   } else {
318     Lower = APIntOps::RoundingSDiv(MinValue, V, APInt::Rounding::UP);
319     Upper = APIntOps::RoundingSDiv(MaxValue, V, APInt::Rounding::DOWN);
320   }
321   return ConstantRange::getNonEmpty(Lower, Upper + 1);
322 }
323 
324 ConstantRange
makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp,const ConstantRange & Other,unsigned NoWrapKind)325 ConstantRange::makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp,
326                                           const ConstantRange &Other,
327                                           unsigned NoWrapKind) {
328   using OBO = OverflowingBinaryOperator;
329 
330   assert(Instruction::isBinaryOp(BinOp) && "Binary operators only!");
331 
332   assert((NoWrapKind == OBO::NoSignedWrap ||
333           NoWrapKind == OBO::NoUnsignedWrap) &&
334          "NoWrapKind invalid!");
335 
336   bool Unsigned = NoWrapKind == OBO::NoUnsignedWrap;
337   unsigned BitWidth = Other.getBitWidth();
338 
339   switch (BinOp) {
340   default:
341     llvm_unreachable("Unsupported binary op");
342 
343   case Instruction::Add: {
344     if (Unsigned)
345       return getNonEmpty(APInt::getZero(BitWidth), -Other.getUnsignedMax());
346 
347     APInt SignedMinVal = APInt::getSignedMinValue(BitWidth);
348     APInt SMin = Other.getSignedMin(), SMax = Other.getSignedMax();
349     return getNonEmpty(
350         SMin.isNegative() ? SignedMinVal - SMin : SignedMinVal,
351         SMax.isStrictlyPositive() ? SignedMinVal - SMax : SignedMinVal);
352   }
353 
354   case Instruction::Sub: {
355     if (Unsigned)
356       return getNonEmpty(Other.getUnsignedMax(), APInt::getMinValue(BitWidth));
357 
358     APInt SignedMinVal = APInt::getSignedMinValue(BitWidth);
359     APInt SMin = Other.getSignedMin(), SMax = Other.getSignedMax();
360     return getNonEmpty(
361         SMax.isStrictlyPositive() ? SignedMinVal + SMax : SignedMinVal,
362         SMin.isNegative() ? SignedMinVal + SMin : SignedMinVal);
363   }
364 
365   case Instruction::Mul:
366     if (Unsigned)
367       return makeExactMulNUWRegion(Other.getUnsignedMax());
368 
369     // Avoid one makeExactMulNSWRegion() call for the common case of constants.
370     if (const APInt *C = Other.getSingleElement())
371       return makeExactMulNSWRegion(*C);
372 
373     return makeExactMulNSWRegion(Other.getSignedMin())
374         .intersectWith(makeExactMulNSWRegion(Other.getSignedMax()));
375 
376   case Instruction::Shl: {
377     // For given range of shift amounts, if we ignore all illegal shift amounts
378     // (that always produce poison), what shift amount range is left?
379     ConstantRange ShAmt = Other.intersectWith(
380         ConstantRange(APInt(BitWidth, 0), APInt(BitWidth, (BitWidth - 1) + 1)));
381     if (ShAmt.isEmptySet()) {
382       // If the entire range of shift amounts is already poison-producing,
383       // then we can freely add more poison-producing flags ontop of that.
384       return getFull(BitWidth);
385     }
386     // There are some legal shift amounts, we can compute conservatively-correct
387     // range of no-wrap inputs. Note that by now we have clamped the ShAmtUMax
388     // to be at most bitwidth-1, which results in most conservative range.
389     APInt ShAmtUMax = ShAmt.getUnsignedMax();
390     if (Unsigned)
391       return getNonEmpty(APInt::getZero(BitWidth),
392                          APInt::getMaxValue(BitWidth).lshr(ShAmtUMax) + 1);
393     return getNonEmpty(APInt::getSignedMinValue(BitWidth).ashr(ShAmtUMax),
394                        APInt::getSignedMaxValue(BitWidth).ashr(ShAmtUMax) + 1);
395   }
396   }
397 }
398 
makeExactNoWrapRegion(Instruction::BinaryOps BinOp,const APInt & Other,unsigned NoWrapKind)399 ConstantRange ConstantRange::makeExactNoWrapRegion(Instruction::BinaryOps BinOp,
400                                                    const APInt &Other,
401                                                    unsigned NoWrapKind) {
402   // makeGuaranteedNoWrapRegion() is exact for single-element ranges, as
403   // "for all" and "for any" coincide in this case.
404   return makeGuaranteedNoWrapRegion(BinOp, ConstantRange(Other), NoWrapKind);
405 }
406 
makeMaskNotEqualRange(const APInt & Mask,const APInt & C)407 ConstantRange ConstantRange::makeMaskNotEqualRange(const APInt &Mask,
408                                                    const APInt &C) {
409   unsigned BitWidth = Mask.getBitWidth();
410 
411   if ((Mask & C) != C)
412     return getFull(BitWidth);
413 
414   if (Mask.isZero())
415     return getEmpty(BitWidth);
416 
417   // If (Val & Mask) != C, constrained to the non-equality being
418   // satisfiable, then the value must be larger than the lowest set bit of
419   // Mask, offset by constant C.
420   return ConstantRange::getNonEmpty(
421       APInt::getOneBitSet(BitWidth, Mask.countr_zero()) + C, C);
422 }
423 
isFullSet() const424 bool ConstantRange::isFullSet() const {
425   return Lower == Upper && Lower.isMaxValue();
426 }
427 
isEmptySet() const428 bool ConstantRange::isEmptySet() const {
429   return Lower == Upper && Lower.isMinValue();
430 }
431 
isWrappedSet() const432 bool ConstantRange::isWrappedSet() const {
433   return Lower.ugt(Upper) && !Upper.isZero();
434 }
435 
isUpperWrapped() const436 bool ConstantRange::isUpperWrapped() const {
437   return Lower.ugt(Upper);
438 }
439 
isSignWrappedSet() const440 bool ConstantRange::isSignWrappedSet() const {
441   return Lower.sgt(Upper) && !Upper.isMinSignedValue();
442 }
443 
isUpperSignWrapped() const444 bool ConstantRange::isUpperSignWrapped() const {
445   return Lower.sgt(Upper);
446 }
447 
448 bool
isSizeStrictlySmallerThan(const ConstantRange & Other) const449 ConstantRange::isSizeStrictlySmallerThan(const ConstantRange &Other) const {
450   assert(getBitWidth() == Other.getBitWidth());
451   if (isFullSet())
452     return false;
453   if (Other.isFullSet())
454     return true;
455   return (Upper - Lower).ult(Other.Upper - Other.Lower);
456 }
457 
458 bool
isSizeLargerThan(uint64_t MaxSize) const459 ConstantRange::isSizeLargerThan(uint64_t MaxSize) const {
460   // If this a full set, we need special handling to avoid needing an extra bit
461   // to represent the size.
462   if (isFullSet())
463     return MaxSize == 0 || APInt::getMaxValue(getBitWidth()).ugt(MaxSize - 1);
464 
465   return (Upper - Lower).ugt(MaxSize);
466 }
467 
isAllNegative() const468 bool ConstantRange::isAllNegative() const {
469   // Empty set is all negative, full set is not.
470   if (isEmptySet())
471     return true;
472   if (isFullSet())
473     return false;
474 
475   return !isUpperSignWrapped() && !Upper.isStrictlyPositive();
476 }
477 
isAllNonNegative() const478 bool ConstantRange::isAllNonNegative() const {
479   // Empty and full set are automatically treated correctly.
480   return !isSignWrappedSet() && Lower.isNonNegative();
481 }
482 
isAllPositive() const483 bool ConstantRange::isAllPositive() const {
484   // Empty set is all positive, full set is not.
485   if (isEmptySet())
486     return true;
487   if (isFullSet())
488     return false;
489 
490   return !isSignWrappedSet() && Lower.isStrictlyPositive();
491 }
492 
getUnsignedMax() const493 APInt ConstantRange::getUnsignedMax() const {
494   if (isFullSet() || isUpperWrapped())
495     return APInt::getMaxValue(getBitWidth());
496   return getUpper() - 1;
497 }
498 
getUnsignedMin() const499 APInt ConstantRange::getUnsignedMin() const {
500   if (isFullSet() || isWrappedSet())
501     return APInt::getMinValue(getBitWidth());
502   return getLower();
503 }
504 
getSignedMax() const505 APInt ConstantRange::getSignedMax() const {
506   if (isFullSet() || isUpperSignWrapped())
507     return APInt::getSignedMaxValue(getBitWidth());
508   return getUpper() - 1;
509 }
510 
getSignedMin() const511 APInt ConstantRange::getSignedMin() const {
512   if (isFullSet() || isSignWrappedSet())
513     return APInt::getSignedMinValue(getBitWidth());
514   return getLower();
515 }
516 
contains(const APInt & V) const517 bool ConstantRange::contains(const APInt &V) const {
518   if (Lower == Upper)
519     return isFullSet();
520 
521   if (!isUpperWrapped())
522     return Lower.ule(V) && V.ult(Upper);
523   return Lower.ule(V) || V.ult(Upper);
524 }
525 
contains(const ConstantRange & Other) const526 bool ConstantRange::contains(const ConstantRange &Other) const {
527   if (isFullSet() || Other.isEmptySet()) return true;
528   if (isEmptySet() || Other.isFullSet()) return false;
529 
530   if (!isUpperWrapped()) {
531     if (Other.isUpperWrapped())
532       return false;
533 
534     return Lower.ule(Other.getLower()) && Other.getUpper().ule(Upper);
535   }
536 
537   if (!Other.isUpperWrapped())
538     return Other.getUpper().ule(Upper) ||
539            Lower.ule(Other.getLower());
540 
541   return Other.getUpper().ule(Upper) && Lower.ule(Other.getLower());
542 }
543 
getActiveBits() const544 unsigned ConstantRange::getActiveBits() const {
545   if (isEmptySet())
546     return 0;
547 
548   return getUnsignedMax().getActiveBits();
549 }
550 
getMinSignedBits() const551 unsigned ConstantRange::getMinSignedBits() const {
552   if (isEmptySet())
553     return 0;
554 
555   return std::max(getSignedMin().getSignificantBits(),
556                   getSignedMax().getSignificantBits());
557 }
558 
subtract(const APInt & Val) const559 ConstantRange ConstantRange::subtract(const APInt &Val) const {
560   assert(Val.getBitWidth() == getBitWidth() && "Wrong bit width");
561   // If the set is empty or full, don't modify the endpoints.
562   if (Lower == Upper)
563     return *this;
564   return ConstantRange(Lower - Val, Upper - Val);
565 }
566 
/// Return this range with the elements of \p CR removed, computed as the
/// intersection of this range with the inverse of \p CR.
ConstantRange ConstantRange::difference(const ConstantRange &CR) const {
  const ConstantRange Complement = CR.inverse();
  return intersectWith(Complement);
}
570 
/// Choose between two candidate ranges.
///
/// For the Unsigned (resp. Signed) preference, a candidate that is not
/// wrapped (resp. not sign-wrapped) wins over one that is. If that does not
/// decide it, \p CR1 is returned only when it is strictly smaller than
/// \p CR2; otherwise \p CR2 is returned.
static ConstantRange getPreferredRange(
    const ConstantRange &CR1, const ConstantRange &CR2,
    ConstantRange::PreferredRangeType Type) {
  if (Type == ConstantRange::Unsigned) {
    const bool Wraps1 = CR1.isWrappedSet();
    const bool Wraps2 = CR2.isWrappedSet();
    if (Wraps1 != Wraps2)
      return Wraps1 ? CR2 : CR1;
  } else if (Type == ConstantRange::Signed) {
    const bool Wraps1 = CR1.isSignWrappedSet();
    const bool Wraps2 = CR2.isSignWrappedSet();
    if (Wraps1 != Wraps2)
      return Wraps1 ? CR2 : CR1;
  }

  // Wrappedness did not decide it: fall back to size.
  return CR1.isSizeStrictlySmallerThan(CR2) ? CR1 : CR2;
}
590 
intersectWith(const ConstantRange & CR,PreferredRangeType Type) const591 ConstantRange ConstantRange::intersectWith(const ConstantRange &CR,
592                                            PreferredRangeType Type) const {
593   assert(getBitWidth() == CR.getBitWidth() &&
594          "ConstantRange types don't agree!");
595 
596   // Handle common cases.
597   if (   isEmptySet() || CR.isFullSet()) return *this;
598   if (CR.isEmptySet() ||    isFullSet()) return CR;
599 
600   if (!isUpperWrapped() && CR.isUpperWrapped())
601     return CR.intersectWith(*this, Type);
602 
603   if (!isUpperWrapped() && !CR.isUpperWrapped()) {
604     if (Lower.ult(CR.Lower)) {
605       // L---U       : this
606       //       L---U : CR
607       if (Upper.ule(CR.Lower))
608         return getEmpty();
609 
610       // L---U       : this
611       //   L---U     : CR
612       if (Upper.ult(CR.Upper))
613         return ConstantRange(CR.Lower, Upper);
614 
615       // L-------U   : this
616       //   L---U     : CR
617       return CR;
618     }
619     //   L---U     : this
620     // L-------U   : CR
621     if (Upper.ult(CR.Upper))
622       return *this;
623 
624     //   L-----U   : this
625     // L-----U     : CR
626     if (Lower.ult(CR.Upper))
627       return ConstantRange(Lower, CR.Upper);
628 
629     //       L---U : this
630     // L---U       : CR
631     return getEmpty();
632   }
633 
634   if (isUpperWrapped() && !CR.isUpperWrapped()) {
635     if (CR.Lower.ult(Upper)) {
636       // ------U   L--- : this
637       //  L--U          : CR
638       if (CR.Upper.ult(Upper))
639         return CR;
640 
641       // ------U   L--- : this
642       //  L------U      : CR
643       if (CR.Upper.ule(Lower))
644         return ConstantRange(CR.Lower, Upper);
645 
646       // ------U   L--- : this
647       //  L----------U  : CR
648       return getPreferredRange(*this, CR, Type);
649     }
650     if (CR.Lower.ult(Lower)) {
651       // --U      L---- : this
652       //     L--U       : CR
653       if (CR.Upper.ule(Lower))
654         return getEmpty();
655 
656       // --U      L---- : this
657       //     L------U   : CR
658       return ConstantRange(Lower, CR.Upper);
659     }
660 
661     // --U  L------ : this
662     //        L--U  : CR
663     return CR;
664   }
665 
666   if (CR.Upper.ult(Upper)) {
667     // ------U L-- : this
668     // --U L------ : CR
669     if (CR.Lower.ult(Upper))
670       return getPreferredRange(*this, CR, Type);
671 
672     // ----U   L-- : this
673     // --U   L---- : CR
674     if (CR.Lower.ult(Lower))
675       return ConstantRange(Lower, CR.Upper);
676 
677     // ----U L---- : this
678     // --U     L-- : CR
679     return CR;
680   }
681   if (CR.Upper.ule(Lower)) {
682     // --U     L-- : this
683     // ----U L---- : CR
684     if (CR.Lower.ult(Lower))
685       return *this;
686 
687     // --U   L---- : this
688     // ----U   L-- : CR
689     return ConstantRange(CR.Lower, Upper);
690   }
691 
692   // --U L------ : this
693   // ------U L-- : CR
694   return getPreferredRange(*this, CR, Type);
695 }
696 
unionWith(const ConstantRange & CR,PreferredRangeType Type) const697 ConstantRange ConstantRange::unionWith(const ConstantRange &CR,
698                                        PreferredRangeType Type) const {
699   assert(getBitWidth() == CR.getBitWidth() &&
700          "ConstantRange types don't agree!");
701 
702   if (   isFullSet() || CR.isEmptySet()) return *this;
703   if (CR.isFullSet() ||    isEmptySet()) return CR;
704 
705   if (!isUpperWrapped() && CR.isUpperWrapped())
706     return CR.unionWith(*this, Type);
707 
708   if (!isUpperWrapped() && !CR.isUpperWrapped()) {
709     //        L---U  and  L---U        : this
710     //  L---U                   L---U  : CR
711     // result in one of
712     //  L---------U
713     // -----U L-----
714     if (CR.Upper.ult(Lower) || Upper.ult(CR.Lower))
715       return getPreferredRange(
716           ConstantRange(Lower, CR.Upper), ConstantRange(CR.Lower, Upper), Type);
717 
718     APInt L = CR.Lower.ult(Lower) ? CR.Lower : Lower;
719     APInt U = (CR.Upper - 1).ugt(Upper - 1) ? CR.Upper : Upper;
720 
721     if (L.isZero() && U.isZero())
722       return getFull();
723 
724     return ConstantRange(std::move(L), std::move(U));
725   }
726 
727   if (!CR.isUpperWrapped()) {
728     // ------U   L-----  and  ------U   L----- : this
729     //   L--U                            L--U  : CR
730     if (CR.Upper.ule(Upper) || CR.Lower.uge(Lower))
731       return *this;
732 
733     // ------U   L----- : this
734     //    L---------U   : CR
735     if (CR.Lower.ule(Upper) && Lower.ule(CR.Upper))
736       return getFull();
737 
738     // ----U       L---- : this
739     //       L---U       : CR
740     // results in one of
741     // ----------U L----
742     // ----U L----------
743     if (Upper.ult(CR.Lower) && CR.Upper.ult(Lower))
744       return getPreferredRange(
745           ConstantRange(Lower, CR.Upper), ConstantRange(CR.Lower, Upper), Type);
746 
747     // ----U     L----- : this
748     //        L----U    : CR
749     if (Upper.ult(CR.Lower) && Lower.ule(CR.Upper))
750       return ConstantRange(CR.Lower, Upper);
751 
752     // ------U    L---- : this
753     //    L-----U       : CR
754     assert(CR.Lower.ule(Upper) && CR.Upper.ult(Lower) &&
755            "ConstantRange::unionWith missed a case with one range wrapped");
756     return ConstantRange(Lower, CR.Upper);
757   }
758 
759   // ------U    L----  and  ------U    L---- : this
760   // -U  L-----------  and  ------------U  L : CR
761   if (CR.Lower.ule(Upper) || Lower.ule(CR.Upper))
762     return getFull();
763 
764   APInt L = CR.Lower.ult(Lower) ? CR.Lower : Lower;
765   APInt U = CR.Upper.ugt(Upper) ? CR.Upper : Upper;
766 
767   return ConstantRange(std::move(L), std::move(U));
768 }
769 
770 std::optional<ConstantRange>
exactIntersectWith(const ConstantRange & CR) const771 ConstantRange::exactIntersectWith(const ConstantRange &CR) const {
772   // TODO: This can be implemented more efficiently.
773   ConstantRange Result = intersectWith(CR);
774   if (Result == inverse().unionWith(CR.inverse()).inverse())
775     return Result;
776   return std::nullopt;
777 }
778 
779 std::optional<ConstantRange>
exactUnionWith(const ConstantRange & CR) const780 ConstantRange::exactUnionWith(const ConstantRange &CR) const {
781   // TODO: This can be implemented more efficiently.
782   ConstantRange Result = unionWith(CR);
783   if (Result == inverse().intersectWith(CR.inverse()).inverse())
784     return Result;
785   return std::nullopt;
786 }
787 
castOp(Instruction::CastOps CastOp,uint32_t ResultBitWidth) const788 ConstantRange ConstantRange::castOp(Instruction::CastOps CastOp,
789                                     uint32_t ResultBitWidth) const {
790   switch (CastOp) {
791   default:
792     llvm_unreachable("unsupported cast type");
793   case Instruction::Trunc:
794     return truncate(ResultBitWidth);
795   case Instruction::SExt:
796     return signExtend(ResultBitWidth);
797   case Instruction::ZExt:
798     return zeroExtend(ResultBitWidth);
799   case Instruction::BitCast:
800     return *this;
801   case Instruction::FPToUI:
802   case Instruction::FPToSI:
803     if (getBitWidth() == ResultBitWidth)
804       return *this;
805     else
806       return getFull(ResultBitWidth);
807   case Instruction::UIToFP: {
808     // TODO: use input range if available
809     auto BW = getBitWidth();
810     APInt Min = APInt::getMinValue(BW);
811     APInt Max = APInt::getMaxValue(BW);
812     if (ResultBitWidth > BW) {
813       Min = Min.zext(ResultBitWidth);
814       Max = Max.zext(ResultBitWidth);
815     }
816     return getNonEmpty(std::move(Min), std::move(Max) + 1);
817   }
818   case Instruction::SIToFP: {
819     // TODO: use input range if available
820     auto BW = getBitWidth();
821     APInt SMin = APInt::getSignedMinValue(BW);
822     APInt SMax = APInt::getSignedMaxValue(BW);
823     if (ResultBitWidth > BW) {
824       SMin = SMin.sext(ResultBitWidth);
825       SMax = SMax.sext(ResultBitWidth);
826     }
827     return getNonEmpty(std::move(SMin), std::move(SMax) + 1);
828   }
829   case Instruction::FPTrunc:
830   case Instruction::FPExt:
831   case Instruction::IntToPtr:
832   case Instruction::PtrToInt:
833   case Instruction::AddrSpaceCast:
834     // Conservatively return getFull set.
835     return getFull(ResultBitWidth);
836   };
837 }
838 
zeroExtend(uint32_t DstTySize) const839 ConstantRange ConstantRange::zeroExtend(uint32_t DstTySize) const {
840   if (isEmptySet()) return getEmpty(DstTySize);
841 
842   unsigned SrcTySize = getBitWidth();
843   assert(SrcTySize < DstTySize && "Not a value extension");
844   if (isFullSet() || isUpperWrapped()) {
845     // Change into [0, 1 << src bit width)
846     APInt LowerExt(DstTySize, 0);
847     if (!Upper) // special case: [X, 0) -- not really wrapping around
848       LowerExt = Lower.zext(DstTySize);
849     return ConstantRange(std::move(LowerExt),
850                          APInt::getOneBitSet(DstTySize, SrcTySize));
851   }
852 
853   return ConstantRange(Lower.zext(DstTySize), Upper.zext(DstTySize));
854 }
855 
signExtend(uint32_t DstTySize) const856 ConstantRange ConstantRange::signExtend(uint32_t DstTySize) const {
857   if (isEmptySet()) return getEmpty(DstTySize);
858 
859   unsigned SrcTySize = getBitWidth();
860   assert(SrcTySize < DstTySize && "Not a value extension");
861 
862   // special case: [X, INT_MIN) -- not really wrapping around
863   if (Upper.isMinSignedValue())
864     return ConstantRange(Lower.sext(DstTySize), Upper.zext(DstTySize));
865 
866   if (isFullSet() || isSignWrappedSet()) {
867     return ConstantRange(APInt::getHighBitsSet(DstTySize,DstTySize-SrcTySize+1),
868                          APInt::getLowBitsSet(DstTySize, SrcTySize-1) + 1);
869   }
870 
871   return ConstantRange(Lower.sext(DstTySize), Upper.sext(DstTySize));
872 }
873 
/// Truncate this range to \p DstTySize bits, which must be strictly smaller
/// than the current bit width. Empty and full sets map to the empty and full
/// set of the destination width. Otherwise the bounds are reduced to the
/// destination width where possible, and the full set is returned when no
/// tighter range can be derived.
truncate(uint32_t DstTySize) const874 ConstantRange ConstantRange::truncate(uint32_t DstTySize) const {
875   assert(getBitWidth() > DstTySize && "Not a value truncation");
876   if (isEmptySet())
877     return getEmpty(DstTySize);
878   if (isFullSet())
879     return getFull(DstTySize);
880 
      // Working copies of the bounds (still in the source width), plus an
      // accumulator range in the destination width that is constructed as
      // non-full and is unioned into every non-full result below.
881   APInt LowerDiv(Lower), UpperDiv(Upper);
882   ConstantRange Union(DstTySize, /*isFullSet=*/false);
883 
884   // Analyze wrapped sets in their two parts: [0, Upper) \/ [Lower, MaxValue].
885   // We use the non-wrapped set code to analyze the [Lower, MaxValue) part, and
886   // then we do the union with [MaxValue, Upper).
887   if (isUpperWrapped()) {
888     // If Upper is greater than or equal to MaxValue(DstTy), it covers the whole
889     // truncated range.
890     if (Upper.getActiveBits() > DstTySize || Upper.countr_one() == DstTySize)
891       return getFull(DstTySize);
892 
893     Union = ConstantRange(APInt::getMaxValue(DstTySize),Upper.trunc(DstTySize));
894     UpperDiv.setAllBits();
895 
896     // Union covers the MaxValue case, so return if the remaining range is just
897     // MaxValue(DstTy).
898     if (LowerDiv == UpperDiv)
899       return Union;
900   }
901 
902   // Chop off the most significant bits that are past the destination bitwidth.
903   if (LowerDiv.getActiveBits() > DstTySize) {
904     // Mask to just the significant bits and subtract from LowerDiv/UpperDiv.
905     APInt Adjust = LowerDiv & APInt::getBitsSetFrom(getBitWidth(), DstTySize);
906     LowerDiv -= Adjust;
907     UpperDiv -= Adjust;
908   }
909 
      // If the adjusted upper bound fits in the destination width, both bounds
      // can be truncated directly.
910   unsigned UpperDivWidth = UpperDiv.getActiveBits();
911   if (UpperDivWidth <= DstTySize)
912     return ConstantRange(LowerDiv.trunc(DstTySize),
913                          UpperDiv.trunc(DstTySize)).unionWith(Union);
914 
915   // The truncated value wraps around. Check if we can do better than fullset.
916   if (UpperDivWidth == DstTySize + 1) {
917     // Clear the MSB so that UpperDiv wraps around.
918     UpperDiv.clearBit(DstTySize);
        // Only a wrapped upper bound strictly below the lower bound yields a
        // usable range; otherwise fall through to the full set.
919     if (UpperDiv.ult(LowerDiv))
920       return ConstantRange(LowerDiv.trunc(DstTySize),
921                            UpperDiv.trunc(DstTySize)).unionWith(Union);
922   }
923 
924   return getFull(DstTySize);
925 }
926 
zextOrTrunc(uint32_t DstTySize) const927 ConstantRange ConstantRange::zextOrTrunc(uint32_t DstTySize) const {
928   unsigned SrcTySize = getBitWidth();
929   if (SrcTySize > DstTySize)
930     return truncate(DstTySize);
931   if (SrcTySize < DstTySize)
932     return zeroExtend(DstTySize);
933   return *this;
934 }
935 
sextOrTrunc(uint32_t DstTySize) const936 ConstantRange ConstantRange::sextOrTrunc(uint32_t DstTySize) const {
937   unsigned SrcTySize = getBitWidth();
938   if (SrcTySize > DstTySize)
939     return truncate(DstTySize);
940   if (SrcTySize < DstTySize)
941     return signExtend(DstTySize);
942   return *this;
943 }
944 
binaryOp(Instruction::BinaryOps BinOp,const ConstantRange & Other) const945 ConstantRange ConstantRange::binaryOp(Instruction::BinaryOps BinOp,
946                                       const ConstantRange &Other) const {
947   assert(Instruction::isBinaryOp(BinOp) && "Binary operators only!");
948 
949   switch (BinOp) {
950   case Instruction::Add:
951     return add(Other);
952   case Instruction::Sub:
953     return sub(Other);
954   case Instruction::Mul:
955     return multiply(Other);
956   case Instruction::UDiv:
957     return udiv(Other);
958   case Instruction::SDiv:
959     return sdiv(Other);
960   case Instruction::URem:
961     return urem(Other);
962   case Instruction::SRem:
963     return srem(Other);
964   case Instruction::Shl:
965     return shl(Other);
966   case Instruction::LShr:
967     return lshr(Other);
968   case Instruction::AShr:
969     return ashr(Other);
970   case Instruction::And:
971     return binaryAnd(Other);
972   case Instruction::Or:
973     return binaryOr(Other);
974   case Instruction::Xor:
975     return binaryXor(Other);
976   // Note: floating point operations applied to abstract ranges are just
977   // ideal integer operations with a lossy representation
978   case Instruction::FAdd:
979     return add(Other);
980   case Instruction::FSub:
981     return sub(Other);
982   case Instruction::FMul:
983     return multiply(Other);
984   default:
985     // Conservatively return getFull set.
986     return getFull();
987   }
988 }
989 
overflowingBinaryOp(Instruction::BinaryOps BinOp,const ConstantRange & Other,unsigned NoWrapKind) const990 ConstantRange ConstantRange::overflowingBinaryOp(Instruction::BinaryOps BinOp,
991                                                  const ConstantRange &Other,
992                                                  unsigned NoWrapKind) const {
993   assert(Instruction::isBinaryOp(BinOp) && "Binary operators only!");
994 
995   switch (BinOp) {
996   case Instruction::Add:
997     return addWithNoWrap(Other, NoWrapKind);
998   case Instruction::Sub:
999     return subWithNoWrap(Other, NoWrapKind);
1000   case Instruction::Mul:
1001     return multiplyWithNoWrap(Other, NoWrapKind);
1002   case Instruction::Shl:
1003     return shlWithNoWrap(Other, NoWrapKind);
1004   default:
1005     // Don't know about this Overflowing Binary Operation.
1006     // Conservatively fallback to plain binop handling.
1007     return binaryOp(BinOp, Other);
1008   }
1009 }
1010 
isIntrinsicSupported(Intrinsic::ID IntrinsicID)1011 bool ConstantRange::isIntrinsicSupported(Intrinsic::ID IntrinsicID) {
1012   switch (IntrinsicID) {
1013   case Intrinsic::uadd_sat:
1014   case Intrinsic::usub_sat:
1015   case Intrinsic::sadd_sat:
1016   case Intrinsic::ssub_sat:
1017   case Intrinsic::umin:
1018   case Intrinsic::umax:
1019   case Intrinsic::smin:
1020   case Intrinsic::smax:
1021   case Intrinsic::abs:
1022   case Intrinsic::ctlz:
1023   case Intrinsic::cttz:
1024   case Intrinsic::ctpop:
1025     return true;
1026   default:
1027     return false;
1028   }
1029 }
1030 
intrinsic(Intrinsic::ID IntrinsicID,ArrayRef<ConstantRange> Ops)1031 ConstantRange ConstantRange::intrinsic(Intrinsic::ID IntrinsicID,
1032                                        ArrayRef<ConstantRange> Ops) {
1033   switch (IntrinsicID) {
1034   case Intrinsic::uadd_sat:
1035     return Ops[0].uadd_sat(Ops[1]);
1036   case Intrinsic::usub_sat:
1037     return Ops[0].usub_sat(Ops[1]);
1038   case Intrinsic::sadd_sat:
1039     return Ops[0].sadd_sat(Ops[1]);
1040   case Intrinsic::ssub_sat:
1041     return Ops[0].ssub_sat(Ops[1]);
1042   case Intrinsic::umin:
1043     return Ops[0].umin(Ops[1]);
1044   case Intrinsic::umax:
1045     return Ops[0].umax(Ops[1]);
1046   case Intrinsic::smin:
1047     return Ops[0].smin(Ops[1]);
1048   case Intrinsic::smax:
1049     return Ops[0].smax(Ops[1]);
1050   case Intrinsic::abs: {
1051     const APInt *IntMinIsPoison = Ops[1].getSingleElement();
1052     assert(IntMinIsPoison && "Must be known (immarg)");
1053     assert(IntMinIsPoison->getBitWidth() == 1 && "Must be boolean");
1054     return Ops[0].abs(IntMinIsPoison->getBoolValue());
1055   }
1056   case Intrinsic::ctlz: {
1057     const APInt *ZeroIsPoison = Ops[1].getSingleElement();
1058     assert(ZeroIsPoison && "Must be known (immarg)");
1059     assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
1060     return Ops[0].ctlz(ZeroIsPoison->getBoolValue());
1061   }
1062   case Intrinsic::cttz: {
1063     const APInt *ZeroIsPoison = Ops[1].getSingleElement();
1064     assert(ZeroIsPoison && "Must be known (immarg)");
1065     assert(ZeroIsPoison->getBitWidth() == 1 && "Must be boolean");
1066     return Ops[0].cttz(ZeroIsPoison->getBoolValue());
1067   }
1068   case Intrinsic::ctpop:
1069     return Ops[0].ctpop();
1070   default:
1071     assert(!isIntrinsicSupported(IntrinsicID) && "Shouldn't be supported");
1072     llvm_unreachable("Unsupported intrinsic");
1073   }
1074 }
1075 
1076 ConstantRange
add(const ConstantRange & Other) const1077 ConstantRange::add(const ConstantRange &Other) const {
1078   if (isEmptySet() || Other.isEmptySet())
1079     return getEmpty();
1080   if (isFullSet() || Other.isFullSet())
1081     return getFull();
1082 
1083   APInt NewLower = getLower() + Other.getLower();
1084   APInt NewUpper = getUpper() + Other.getUpper() - 1;
1085   if (NewLower == NewUpper)
1086     return getFull();
1087 
1088   ConstantRange X = ConstantRange(std::move(NewLower), std::move(NewUpper));
1089   if (X.isSizeStrictlySmallerThan(*this) ||
1090       X.isSizeStrictlySmallerThan(Other))
1091     // We've wrapped, therefore, full set.
1092     return getFull();
1093   return X;
1094 }
1095 
/// Calculate the range for "X + Y" which is guaranteed not to wrap (overflow),
/// where X is from this range and Y is from \p Other. \p NoWrapKind is a mask
/// of OverflowingBinaryOperator no-wrap flags and \p RangeType is passed on to
/// intersectWith().
ConstantRange ConstantRange::addWithNoWrap(const ConstantRange &Other,
                                           unsigned NoWrapKind,
                                           PreferredRangeType RangeType) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();
  if (isFullSet() && Other.isFullSet())
    return getFull();

  const bool NoSignedWrap =
      (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) != 0;
  const bool NoUnsignedWrap =
      (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) != 0;

  // If an overflow happens for every value pair in these two constant ranges,
  // we must return Empty set. In this case, we get that for free, because we
  // get lucky that intersection of add() with uadd_sat()/sadd_sat() results
  // in an empty set.
  ConstantRange Sum = add(Other);
  if (NoSignedWrap)
    Sum = Sum.intersectWith(sadd_sat(Other), RangeType);
  if (NoUnsignedWrap)
    Sum = Sum.intersectWith(uadd_sat(Other), RangeType);
  return Sum;
}
1122 
1123 ConstantRange
sub(const ConstantRange & Other) const1124 ConstantRange::sub(const ConstantRange &Other) const {
1125   if (isEmptySet() || Other.isEmptySet())
1126     return getEmpty();
1127   if (isFullSet() || Other.isFullSet())
1128     return getFull();
1129 
1130   APInt NewLower = getLower() - Other.getUpper() + 1;
1131   APInt NewUpper = getUpper() - Other.getLower();
1132   if (NewLower == NewUpper)
1133     return getFull();
1134 
1135   ConstantRange X = ConstantRange(std::move(NewLower), std::move(NewUpper));
1136   if (X.isSizeStrictlySmallerThan(*this) ||
1137       X.isSizeStrictlySmallerThan(Other))
1138     // We've wrapped, therefore, full set.
1139     return getFull();
1140   return X;
1141 }
1142 
/// Calculate the range for "X - Y" which is guaranteed not to wrap (overflow),
/// where X is from this range and Y is from \p Other. \p NoWrapKind is a mask
/// of OverflowingBinaryOperator no-wrap flags and \p RangeType is passed on to
/// intersectWith().
ConstantRange ConstantRange::subWithNoWrap(const ConstantRange &Other,
                                           unsigned NoWrapKind,
                                           PreferredRangeType RangeType) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();
  if (isFullSet() && Other.isFullSet())
    return getFull();

  const bool NoSignedWrap =
      (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap) != 0;
  const bool NoUnsignedWrap =
      (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap) != 0;

  ConstantRange Diff = sub(Other);

  // If an overflow happens for every value pair in these two constant ranges,
  // we must return Empty set. In signed case, we get that for free, because we
  // get lucky that intersection of sub() with ssub_sat() results in an
  // empty set. But for unsigned we must perform the overflow check manually.
  if (NoSignedWrap)
    Diff = Diff.intersectWith(ssub_sat(Other), RangeType);

  if (!NoUnsignedWrap)
    return Diff;

  // Every X is below every Y: the subtraction always overflows.
  if (getUnsignedMax().ult(Other.getUnsignedMin()))
    return getEmpty();
  return Diff.intersectWith(usub_sat(Other), RangeType);
}
1172 
1173 ConstantRange
multiply(const ConstantRange & Other) const1174 ConstantRange::multiply(const ConstantRange &Other) const {
1175   // TODO: If either operand is a single element and the multiply is known to
1176   // be non-wrapping, round the result min and max value to the appropriate
1177   // multiple of that element. If wrapping is possible, at least adjust the
1178   // range according to the greatest power-of-two factor of the single element.
1179 
1180   if (isEmptySet() || Other.isEmptySet())
1181     return getEmpty();
1182 
1183   if (const APInt *C = getSingleElement()) {
1184     if (C->isOne())
1185       return Other;
1186     if (C->isAllOnes())
1187       return ConstantRange(APInt::getZero(getBitWidth())).sub(Other);
1188   }
1189 
1190   if (const APInt *C = Other.getSingleElement()) {
1191     if (C->isOne())
1192       return *this;
1193     if (C->isAllOnes())
1194       return ConstantRange(APInt::getZero(getBitWidth())).sub(*this);
1195   }
1196 
1197   // Multiplication is signedness-independent. However different ranges can be
1198   // obtained depending on how the input ranges are treated. These different
1199   // ranges are all conservatively correct, but one might be better than the
1200   // other. We calculate two ranges; one treating the inputs as unsigned
1201   // and the other signed, then return the smallest of these ranges.
1202 
1203   // Unsigned range first.
1204   APInt this_min = getUnsignedMin().zext(getBitWidth() * 2);
1205   APInt this_max = getUnsignedMax().zext(getBitWidth() * 2);
1206   APInt Other_min = Other.getUnsignedMin().zext(getBitWidth() * 2);
1207   APInt Other_max = Other.getUnsignedMax().zext(getBitWidth() * 2);
1208 
1209   ConstantRange Result_zext = ConstantRange(this_min * Other_min,
1210                                             this_max * Other_max + 1);
1211   ConstantRange UR = Result_zext.truncate(getBitWidth());
1212 
1213   // If the unsigned range doesn't wrap, and isn't negative then it's a range
1214   // from one positive number to another which is as good as we can generate.
1215   // In this case, skip the extra work of generating signed ranges which aren't
1216   // going to be better than this range.
1217   if (!UR.isUpperWrapped() &&
1218       (UR.getUpper().isNonNegative() || UR.getUpper().isMinSignedValue()))
1219     return UR;
1220 
1221   // Now the signed range. Because we could be dealing with negative numbers
1222   // here, the lower bound is the smallest of the cartesian product of the
1223   // lower and upper ranges; for example:
1224   //   [-1,4) * [-2,3) = min(-1*-2, -1*2, 3*-2, 3*2) = -6.
1225   // Similarly for the upper bound, swapping min for max.
1226 
1227   this_min = getSignedMin().sext(getBitWidth() * 2);
1228   this_max = getSignedMax().sext(getBitWidth() * 2);
1229   Other_min = Other.getSignedMin().sext(getBitWidth() * 2);
1230   Other_max = Other.getSignedMax().sext(getBitWidth() * 2);
1231 
1232   auto L = {this_min * Other_min, this_min * Other_max,
1233             this_max * Other_min, this_max * Other_max};
1234   auto Compare = [](const APInt &A, const APInt &B) { return A.slt(B); };
1235   ConstantRange Result_sext(std::min(L, Compare), std::max(L, Compare) + 1);
1236   ConstantRange SR = Result_sext.truncate(getBitWidth());
1237 
1238   return UR.isSizeStrictlySmallerThan(SR) ? UR : SR;
1239 }
1240 
1241 ConstantRange
multiplyWithNoWrap(const ConstantRange & Other,unsigned NoWrapKind,PreferredRangeType RangeType) const1242 ConstantRange::multiplyWithNoWrap(const ConstantRange &Other,
1243                                   unsigned NoWrapKind,
1244                                   PreferredRangeType RangeType) const {
1245   if (isEmptySet() || Other.isEmptySet())
1246     return getEmpty();
1247   if (isFullSet() && Other.isFullSet())
1248     return getFull();
1249 
1250   ConstantRange Result = multiply(Other);
1251 
1252   if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap)
1253     Result = Result.intersectWith(smul_sat(Other), RangeType);
1254 
1255   if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap)
1256     Result = Result.intersectWith(umul_sat(Other), RangeType);
1257 
1258   // mul nsw nuw X, Y s>= 0 if X s> 1 or Y s> 1
1259   if ((NoWrapKind == (OverflowingBinaryOperator::NoSignedWrap |
1260                       OverflowingBinaryOperator::NoUnsignedWrap)) &&
1261       !Result.isAllNonNegative()) {
1262     if (getSignedMin().sgt(1) || Other.getSignedMin().sgt(1))
1263       Result = Result.intersectWith(
1264           getNonEmpty(APInt::getZero(getBitWidth()),
1265                       APInt::getSignedMinValue(getBitWidth())),
1266           RangeType);
1267   }
1268 
1269   return Result;
1270 }
1271 
/// Signed multiplication computed at the original bit width: multiply the
/// four combinations of signed bounds, return the full set if any of them
/// overflows, and otherwise return the range spanned by the products.
ConstantRange ConstantRange::smul_fast(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  const APInt LHSMin = getSignedMin();
  const APInt LHSMax = getSignedMax();
  const APInt RHSMin = Other.getSignedMin();
  const APInt RHSMax = Other.getSignedMax();

  // One overflow flag per corner product; smul_ov() writes each of them.
  bool OvMinMin, OvMinMax, OvMaxMin, OvMaxMax;
  auto Corners = {LHSMin.smul_ov(RHSMin, OvMinMin),
                  LHSMin.smul_ov(RHSMax, OvMinMax),
                  LHSMax.smul_ov(RHSMin, OvMaxMin),
                  LHSMax.smul_ov(RHSMax, OvMaxMax)};
  const bool AnyOverflow = OvMinMin || OvMinMax || OvMaxMin || OvMaxMax;
  if (AnyOverflow)
    return getFull();

  auto SignedLess = [](const APInt &A, const APInt &B) { return A.slt(B); };
  return getNonEmpty(std::min(Corners, SignedLess),
                     std::max(Corners, SignedLess) + 1);
}
1290 
1291 ConstantRange
smax(const ConstantRange & Other) const1292 ConstantRange::smax(const ConstantRange &Other) const {
1293   // X smax Y is: range(smax(X_smin, Y_smin),
1294   //                    smax(X_smax, Y_smax))
1295   if (isEmptySet() || Other.isEmptySet())
1296     return getEmpty();
1297   APInt NewL = APIntOps::smax(getSignedMin(), Other.getSignedMin());
1298   APInt NewU = APIntOps::smax(getSignedMax(), Other.getSignedMax()) + 1;
1299   ConstantRange Res = getNonEmpty(std::move(NewL), std::move(NewU));
1300   if (isSignWrappedSet() || Other.isSignWrappedSet())
1301     return Res.intersectWith(unionWith(Other, Signed), Signed);
1302   return Res;
1303 }
1304 
1305 ConstantRange
umax(const ConstantRange & Other) const1306 ConstantRange::umax(const ConstantRange &Other) const {
1307   // X umax Y is: range(umax(X_umin, Y_umin),
1308   //                    umax(X_umax, Y_umax))
1309   if (isEmptySet() || Other.isEmptySet())
1310     return getEmpty();
1311   APInt NewL = APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin());
1312   APInt NewU = APIntOps::umax(getUnsignedMax(), Other.getUnsignedMax()) + 1;
1313   ConstantRange Res = getNonEmpty(std::move(NewL), std::move(NewU));
1314   if (isWrappedSet() || Other.isWrappedSet())
1315     return Res.intersectWith(unionWith(Other, Unsigned), Unsigned);
1316   return Res;
1317 }
1318 
1319 ConstantRange
smin(const ConstantRange & Other) const1320 ConstantRange::smin(const ConstantRange &Other) const {
1321   // X smin Y is: range(smin(X_smin, Y_smin),
1322   //                    smin(X_smax, Y_smax))
1323   if (isEmptySet() || Other.isEmptySet())
1324     return getEmpty();
1325   APInt NewL = APIntOps::smin(getSignedMin(), Other.getSignedMin());
1326   APInt NewU = APIntOps::smin(getSignedMax(), Other.getSignedMax()) + 1;
1327   ConstantRange Res = getNonEmpty(std::move(NewL), std::move(NewU));
1328   if (isSignWrappedSet() || Other.isSignWrappedSet())
1329     return Res.intersectWith(unionWith(Other, Signed), Signed);
1330   return Res;
1331 }
1332 
1333 ConstantRange
umin(const ConstantRange & Other) const1334 ConstantRange::umin(const ConstantRange &Other) const {
1335   // X umin Y is: range(umin(X_umin, Y_umin),
1336   //                    umin(X_umax, Y_umax))
1337   if (isEmptySet() || Other.isEmptySet())
1338     return getEmpty();
1339   APInt NewL = APIntOps::umin(getUnsignedMin(), Other.getUnsignedMin());
1340   APInt NewU = APIntOps::umin(getUnsignedMax(), Other.getUnsignedMax()) + 1;
1341   ConstantRange Res = getNonEmpty(std::move(NewL), std::move(NewU));
1342   if (isWrappedSet() || Other.isWrappedSet())
1343     return Res.intersectWith(unionWith(Other, Unsigned), Unsigned);
1344   return Res;
1345 }
1346 
1347 ConstantRange
udiv(const ConstantRange & RHS) const1348 ConstantRange::udiv(const ConstantRange &RHS) const {
1349   if (isEmptySet() || RHS.isEmptySet() || RHS.getUnsignedMax().isZero())
1350     return getEmpty();
1351 
1352   APInt Lower = getUnsignedMin().udiv(RHS.getUnsignedMax());
1353 
1354   APInt RHS_umin = RHS.getUnsignedMin();
1355   if (RHS_umin.isZero()) {
1356     // We want the lowest value in RHS excluding zero. Usually that would be 1
1357     // except for a range in the form of [X, 1) in which case it would be X.
1358     if (RHS.getUpper() == 1)
1359       RHS_umin = RHS.getLower();
1360     else
1361       RHS_umin = 1;
1362   }
1363 
1364   APInt Upper = getUnsignedMax().udiv(RHS_umin) + 1;
1365   return getNonEmpty(std::move(Lower), std::move(Upper));
1366 }
1367 
/// Return a range containing the signed quotients of values in this range
/// divided by values in \p RHS. Both operands are split by sign, the positive
/// and negative parts of the result are bounded separately, and the two parts
/// are finally unioned, preferring a non-wrapping signed range.
sdiv(const ConstantRange & RHS) const1368 ConstantRange ConstantRange::sdiv(const ConstantRange &RHS) const {
1369   APInt Zero = APInt::getZero(getBitWidth());
1370   APInt SignedMin = APInt::getSignedMinValue(getBitWidth());
1371 
1372   // We split up the LHS and RHS into positive and negative components
1373   // and then also compute the positive and negative components of the result
1374   // separately by combining division results with the appropriate signs.
1375   auto [PosL, NegL] = splitPosNeg();
1376   auto [PosR, NegR] = RHS.splitPosNeg();
1377 
1378   ConstantRange PosRes = getEmpty();
1379   if (!PosL.isEmptySet() && !PosR.isEmptySet())
1380     // pos / pos = pos.
1381     PosRes = ConstantRange(PosL.Lower.sdiv(PosR.Upper - 1),
1382                            (PosL.Upper - 1).sdiv(PosR.Lower) + 1);
1383 
1384   if (!NegL.isEmptySet() && !NegR.isEmptySet()) {
1385     // neg / neg = pos.
1386     //
1387     // We need to deal with one tricky case here: SignedMin / -1 is UB on the
1388     // IR level, so we'll want to exclude this case when calculating bounds.
1389     // (For APInts the operation is well-defined and yields SignedMin.) We
1390     // handle this by dropping either SignedMin from the LHS or -1 from the RHS.
1391     APInt Lo = (NegL.Upper - 1).sdiv(NegR.Lower);
1392     if (NegL.Lower.isMinSignedValue() && NegR.Upper.isZero()) {
1393       // Remove -1 from the RHS. Skip if it's the only element, as this would
1394       // leave us with an empty set.
1395       if (!NegR.Lower.isAllOnes()) {
1396         APInt AdjNegRUpper;
1397         if (RHS.Lower.isAllOnes())
1398           // Negative part of [-1, X] without -1 is [SignedMin, X].
1399           AdjNegRUpper = RHS.Upper;
1400         else
1401           // [X, -1] without -1 is [X, -2].
1402           AdjNegRUpper = NegR.Upper - 1;
1403 
1404         PosRes = PosRes.unionWith(
1405             ConstantRange(Lo, NegL.Lower.sdiv(AdjNegRUpper - 1) + 1));
1406       }
1407 
1408       // Remove SignedMin from the LHS. Skip if it's the only element, as this
1409       // would leave us with an empty set.
1410       if (NegL.Upper != SignedMin + 1) {
1411         APInt AdjNegLLower;
1412         if (Upper == SignedMin + 1)
1413           // Negative part of [X, SignedMin] without SignedMin is [X, -1].
1414           AdjNegLLower = Lower;
1415         else
1416           // [SignedMin, X] without SignedMin is [SignedMin + 1, X].
1417           AdjNegLLower = NegL.Lower + 1;
1418 
1419         PosRes = PosRes.unionWith(
1420             ConstantRange(std::move(Lo),
1421                           AdjNegLLower.sdiv(NegR.Upper - 1) + 1));
1422       }
1423     } else {
          // No SignedMin / -1 pair is possible here, so use the plain bounds.
1424       PosRes = PosRes.unionWith(
1425           ConstantRange(std::move(Lo), NegL.Lower.sdiv(NegR.Upper - 1) + 1));
1426     }
1427   }
1428 
1429   ConstantRange NegRes = getEmpty();
1430   if (!PosL.isEmptySet() && !NegR.isEmptySet())
1431     // pos / neg = neg.
1432     NegRes = ConstantRange((PosL.Upper - 1).sdiv(NegR.Upper - 1),
1433                            PosL.Lower.sdiv(NegR.Lower) + 1);
1434 
1435   if (!NegL.isEmptySet() && !PosR.isEmptySet())
1436     // neg / pos = neg.
1437     NegRes = NegRes.unionWith(
1438         ConstantRange(NegL.Lower.sdiv(PosR.Lower),
1439                       (NegL.Upper - 1).sdiv(PosR.Upper - 1) + 1));
1440 
1441   // Prefer a non-wrapping signed range here.
1442   ConstantRange Res = NegRes.unionWith(PosRes, PreferredRangeType::Signed);
1443 
1444   // Preserve the zero that we dropped when splitting the LHS by sign. It is
      // only added back when the RHS has at least one non-zero value to divide
      // by.
1445   if (contains(Zero) && (!PosR.isEmptySet() || !NegR.isEmptySet()))
1446     Res = Res.unionWith(ConstantRange(Zero));
1447   return Res;
1448 }
1449 
/// Return a range containing every unsigned remainder X % Y with X in this
/// range and Y in \p RHS. The result is empty when either operand is empty or
/// when \p RHS contains nothing but zero.
ConstantRange ConstantRange::urem(const ConstantRange &RHS) const {
  if (isEmptySet() || RHS.isEmptySet() || RHS.getUnsignedMax().isZero())
    return getEmpty();

  if (const APInt *Divisor = RHS.getSingleElement()) {
    // UREM by null is UB.
    if (Divisor->isZero())
      return getEmpty();
    // Use APInt's implementation of UREM for single element ranges.
    if (const APInt *Dividend = getSingleElement())
      return ConstantRange(Dividend->urem(*Divisor));
  }

  const APInt DividendMax = getUnsignedMax();

  // L % R for L < R is L.
  if (DividendMax.ult(RHS.getUnsignedMin()))
    return *this;

  // L % R is <= L and < R.
  APInt RemainderEnd =
      APIntOps::umin(DividendMax, RHS.getUnsignedMax() - 1) + 1;
  return getNonEmpty(APInt::getZero(getBitWidth()), std::move(RemainderEnd));
}
1471 
/// Return a range containing every signed remainder X % Y with X in this
/// range and Y in \p RHS. The result is empty when either operand is empty or
/// when \p RHS contains nothing but zero.
ConstantRange ConstantRange::srem(const ConstantRange &RHS) const {
  if (isEmptySet() || RHS.isEmptySet())
    return getEmpty();

  if (const APInt *Divisor = RHS.getSingleElement()) {
    // SREM by null is UB.
    if (Divisor->isZero())
      return getEmpty();
    // Use APInt's implementation of SREM for single element ranges.
    if (const APInt *Dividend = getSingleElement())
      return ConstantRange(Dividend->srem(*Divisor));
  }

  // Only the magnitude of the divisor is used for the bounds below.
  const ConstantRange DivisorMagnitude = RHS.abs();
  APInt MinMagnitude = DivisorMagnitude.getUnsignedMin();
  const APInt MaxMagnitude = DivisorMagnitude.getUnsignedMax();

  // Modulus by zero is UB.
  if (MaxMagnitude.isZero())
    return getEmpty();

  // Zero is skipped as a divisor, so the smallest usable magnitude is 1.
  if (MinMagnitude.isZero())
    ++MinMagnitude;

  const APInt DividendMin = getSignedMin();
  const APInt DividendMax = getSignedMax();

  // Exclusive upper bound for non-negative results: L % R is <= L and < R.
  auto PositiveEnd = [&]() -> APInt {
    return APIntOps::umin(DividendMax, MaxMagnitude - 1) + 1;
  };
  // Inclusive lower bound for negative results (mirror image of the above).
  auto NegativeBegin = [&]() -> APInt {
    return APIntOps::umax(DividendMin, -MaxMagnitude + 1);
  };

  if (DividendMin.isNonNegative()) {
    // L % R for L < R is L.
    if (DividendMax.ult(MinMagnitude))
      return *this;

    return ConstantRange(APInt::getZero(getBitWidth()), PositiveEnd());
  }

  // Same basic logic as above, but the result is negative.
  if (DividendMax.isNegative()) {
    if (DividendMin.ugt(-MinMagnitude))
      return *this;

    return ConstantRange(NegativeBegin(), APInt(getBitWidth(), 1));
  }

  // LHS range crosses zero.
  APInt ResultLower = NegativeBegin();
  APInt ResultUpper = PositiveEnd();
  return ConstantRange(std::move(ResultLower), std::move(ResultUpper));
}
1522 
binaryNot() const1523 ConstantRange ConstantRange::binaryNot() const {
1524   return ConstantRange(APInt::getAllOnes(getBitWidth())).sub(*this);
1525 }
1526 
1527 /// Estimate the lower bound of the 'bit-masked AND' operation.
1528 ///
1529 /// E.g., given two ranges as follows (single quotes are separators and
1530 /// have no meaning here),
1531 ///
1532 ///   LHS = [10'00101'1,  ; LLo
1533 ///          10'10000'0]  ; LHi
1534 ///   RHS = [10'11111'0,  ; RLo
1535 ///          10'11111'1]  ; RHi
1536 ///
1537 /// we know that the higher 2 bits of the result are always 10; and we also
1538 /// notice that RHS[1:6] are always 1, so the result[1:6] cannot be less than
1539 /// LHS[1:6] (i.e., 00101). Thus, the lower bound is 10'00101'0.
1540 ///
1541 /// The algorithm is as follows:
1542 /// 1. first calculate a mask to find the higher common bits by
1543 ///       Mask = ~((LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo));
1544 ///       Mask = clear all non-leading-ones bits in Mask;
1545 ///    in the example, the Mask is set to 11'00000'0;
1546 /// 2. calculate a new mask by setting all common leading bits to 1 in RHS, and
1547 ///    keeping the longest run of leading ones (i.e., 11'11111'0 in the
     ///    example);
1548 /// 3. return (LLo & new mask) as the lower bound;
1549 /// 4. repeat steps 2 and 3 with LHS and RHS swapped, and update the lower
1550 ///    bound with the larger one.
estimateBitMaskedAndLowerBound(const ConstantRange & LHS,const ConstantRange & RHS)1551 static APInt estimateBitMaskedAndLowerBound(const ConstantRange &LHS,
1552                                             const ConstantRange &RHS) {
1553   auto BitWidth = LHS.getBitWidth();
1554   // If either is full set or unsigned wrapped, then the range must contain '0'
1555   // which leads the lower bound to 0.
1556   if ((LHS.isFullSet() || RHS.isFullSet()) ||
1557       (LHS.isWrappedSet() || RHS.isWrappedSet()))
1558     return APInt::getZero(BitWidth);
1559 
       // Inclusive bounds of both ranges (the stored upper bound is exclusive).
1560   auto LLo = LHS.getLower();
1561   auto LHi = LHS.getUpper() - 1;
1562   auto RLo = RHS.getLower();
1563   auto RHi = RHS.getUpper() - 1;
1564 
1565   // Calculate the mask for the higher common bits: a bit is set where all
       // four bounds agree, and only the leading run of such bits is kept.
1566   auto Mask = ~((LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo));
1567   unsigned LeadingOnes = Mask.countLeadingOnes();
1568   Mask.clearLowBits(BitWidth - LeadingOnes);
1569 
       // Steps 2 and 3 above: count the leading bits that are either common to
       // all bounds (Mask) or set in both bounds of B, then keep only those
       // leading bits of A's lower bound. ALo is taken by value because it is
       // modified in place.
1570   auto estimateBound = [BitWidth, &Mask](APInt ALo, const APInt &BLo,
1571                                          const APInt &BHi) {
1572     unsigned LeadingOnes = ((BLo & BHi) | Mask).countLeadingOnes();
1573     unsigned StartBit = BitWidth - LeadingOnes;
1574     ALo.clearLowBits(StartBit);
1575     return ALo;
1576   };
1577 
1578   auto LowerBoundByLHS = estimateBound(LLo, RLo, RHi);
1579   auto LowerBoundByRHS = estimateBound(RLo, LLo, LHi);
1580 
       // Both estimates are valid lower bounds, so take the larger one.
1581   return APIntOps::umax(LowerBoundByLHS, LowerBoundByRHS);
1582 }
1583 
/// Return a range containing `X & Y` for X in this range and Y in \p Other.
///
/// Two independent approximations are intersected: one derived from the known
/// bits of both operands, and one spanning from the estimated bit-masked lower
/// bound up to the smaller of the two unsigned maxima.
ConstantRange ConstantRange::binaryAnd(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  // Approximation 1: combine the known bits of both sides.
  const KnownBits Known = toKnownBits() & Other.toKnownBits();
  const ConstantRange BitsRange = fromKnownBits(Known, /*IsSigned=*/false);

  // Approximation 2: [estimated lower bound, umin(max, other max)].
  APInt Lo = estimateBitMaskedAndLowerBound(*this, Other);
  APInt Hi = APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()) + 1;
  const ConstantRange BoundsRange = getNonEmpty(std::move(Lo), std::move(Hi));

  return BitsRange.intersectWith(BoundsRange);
}
1595 
/// Return a range containing `X | Y` for X in this range and Y in \p Other.
///
/// The known-bits approximation is intersected with a range that starts at the
/// larger unsigned minimum and ends at an upper bound obtained from the AND
/// lower-bound estimate applied to the complemented operands.
ConstantRange ConstantRange::binaryOr(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  const KnownBits Known = toKnownBits() | Other.toKnownBits();
  const ConstantRange BitsRange = fromKnownBits(Known, /*IsSigned=*/false);

  // De Morgan turns a lower bound on (~a & ~b) into an upper bound on (a | b):
  //      ~a & ~b    >= x
  // <=>  ~(~a & ~b) <= ~x
  // <=>  a | b      <= ~x
  // <=>  a | b      <  ~x + 1 = -x
  // hence the exclusive upper bound of (a | b) is -LowerBound(~a & ~b).
  const ConstantRange NotLHS = binaryNot();
  const ConstantRange NotRHS = Other.binaryNot();
  APInt ExclusiveUpper = -estimateBitMaskedAndLowerBound(NotLHS, NotRHS);

  // The result is at least as large as either operand's unsigned minimum.
  // The resulting range may be upper-wrapped.
  APInt InclusiveLower =
      APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin());
  const ConstantRange BoundsRange =
      getNonEmpty(std::move(InclusiveLower), std::move(ExclusiveUpper));

  return BitsRange.intersectWith(BoundsRange);
}
1615 
/// Return a range containing `X ^ Y` for X in this range and Y in \p Other.
ConstantRange ConstantRange::binaryXor(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  const APInt *LHSElt = getSingleElement();
  const APInt *RHSElt = Other.getSingleElement();

  // Two constants: defer to APInt's XOR for an exact answer.
  if (LHSElt && RHSElt)
    return ConstantRange(*LHSElt ^ *RHSElt);

  // XOR with all-ones is a binary complement, which has a precise result.
  if (RHSElt && RHSElt->isAllOnes())
    return binaryNot();
  if (LHSElt && LHSElt->isAllOnes())
    return Other.binaryNot();

  const KnownBits LHSKnown = toKnownBits();
  const KnownBits RHSKnown = Other.toKnownBits();
  ConstantRange Result =
      fromKnownBits(LHSKnown ^ RHSKnown, /*IsSigned=*/false);

  // For 1-bit ranges the refinement below typically does not improve the
  // result, so stop here.
  if (getBitWidth() == 1)
    return Result;

  // When every bit that may be set on one side is known to be set on the
  // other, the XOR behaves like a non-wrapping subtraction of the subset
  // operand from the superset operand, so intersect with that subtraction.
  const bool LHSIsSubsetOfRHS = (~LHSKnown.Zero).isSubsetOf(RHSKnown.One);
  if (LHSIsSubsetOfRHS)
    return Result.intersectWith(Other.sub(*this),
                                PreferredRangeType::Unsigned);

  const bool RHSIsSubsetOfLHS = (~RHSKnown.Zero).isSubsetOf(LHSKnown.One);
  if (RHSIsSubsetOfLHS)
    return Result.intersectWith(this->sub(Other),
                                PreferredRangeType::Unsigned);

  return Result;
}
1647 
1648 ConstantRange
shl(const ConstantRange & Other) const1649 ConstantRange::shl(const ConstantRange &Other) const {
1650   if (isEmptySet() || Other.isEmptySet())
1651     return getEmpty();
1652 
1653   APInt Min = getUnsignedMin();
1654   APInt Max = getUnsignedMax();
1655   if (const APInt *RHS = Other.getSingleElement()) {
1656     unsigned BW = getBitWidth();
1657     if (RHS->uge(BW))
1658       return getEmpty();
1659 
1660     unsigned EqualLeadingBits = (Min ^ Max).countl_zero();
1661     if (RHS->ule(EqualLeadingBits))
1662       return getNonEmpty(Min << *RHS, (Max << *RHS) + 1);
1663 
1664     return getNonEmpty(APInt::getZero(BW),
1665                        APInt::getBitsSetFrom(BW, RHS->getZExtValue()) + 1);
1666   }
1667 
1668   APInt OtherMax = Other.getUnsignedMax();
1669   if (isAllNegative() && OtherMax.ule(Min.countl_one())) {
1670     // For negative numbers, if the shift does not overflow in a signed sense,
1671     // a larger shift will make the number smaller.
1672     Max <<= Other.getUnsignedMin();
1673     Min <<= OtherMax;
1674     return ConstantRange::getNonEmpty(std::move(Min), std::move(Max) + 1);
1675   }
1676 
1677   // There's overflow!
1678   if (OtherMax.ugt(Max.countl_zero()))
1679     return getFull();
1680 
1681   // FIXME: implement the other tricky cases
1682 
1683   Min <<= Other.getUnsignedMin();
1684   Max <<= OtherMax;
1685 
1686   return ConstantRange::getNonEmpty(std::move(Min), std::move(Max) + 1);
1687 }
1688 
/// Compute the range of `LHS << RHS` for a shift that may not wrap unsigned.
///
/// Returns the empty range when even the smallest value shifted by the
/// smallest amount overflows.
static ConstantRange computeShlNUW(const ConstantRange &LHS,
                                   const ConstantRange &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();
  const APInt LHSMin = LHS.getUnsignedMin();
  const APInt LHSMax = LHS.getUnsignedMax();
  // Shift amounts are clamped to the bit width.
  const unsigned ShMin = RHS.getUnsignedMin().getLimitedValue(BitWidth);
  const unsigned ShMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);

  // Lower bound: smallest value shifted by the smallest amount.
  bool Overflow;
  const APInt MinShl = LHSMin.ushl_ov(ShMin, Overflow);
  if (Overflow)
    return ConstantRange::getEmpty(BitWidth);

  // Upper bound candidate 1: LHSMax shifted by the largest permitted amount
  // that does not exceed its number of leading zeros.
  APInt MaxShl = MinShl;
  const unsigned MaxShAmt = LHSMax.countLeadingZeros();
  if (ShMin <= MaxShAmt)
    MaxShl = LHSMax << std::min(ShMax, MaxShAmt);

  // Upper bound candidate 2: shift amounts beyond MaxShAmt that are still
  // within LHSMin's leading zeros; the smallest of them bounds the result by
  // the value with the top (BitWidth - amount) bits set.
  const unsigned BigShMin = std::max(ShMin, MaxShAmt + 1);
  const unsigned BigShMax = std::min(ShMax, LHSMin.countLeadingZeros());
  if (BigShMin <= BigShMax)
    MaxShl = APIntOps::umax(
        MaxShl, APInt::getHighBitsSet(BitWidth, BitWidth - BigShMin));

  return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
}
1711 
computeShlNSWWithNNegLHS(const APInt & LHSMin,const APInt & LHSMax,unsigned RHSMin,unsigned RHSMax)1712 static ConstantRange computeShlNSWWithNNegLHS(const APInt &LHSMin,
1713                                               const APInt &LHSMax,
1714                                               unsigned RHSMin,
1715                                               unsigned RHSMax) {
1716   unsigned BitWidth = LHSMin.getBitWidth();
1717   bool Overflow;
1718   APInt MinShl = LHSMin.sshl_ov(RHSMin, Overflow);
1719   if (Overflow)
1720     return ConstantRange::getEmpty(BitWidth);
1721   APInt MaxShl = MinShl;
1722   unsigned MaxShAmt = LHSMax.countLeadingZeros() - 1;
1723   if (RHSMin <= MaxShAmt)
1724     MaxShl = LHSMax << std::min(RHSMax, MaxShAmt);
1725   RHSMin = std::max(RHSMin, MaxShAmt + 1);
1726   RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros() - 1);
1727   if (RHSMin <= RHSMax)
1728     MaxShl = APIntOps::umax(MaxShl,
1729                             APInt::getBitsSet(BitWidth, RHSMin, BitWidth - 1));
1730   return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
1731 }
1732 
computeShlNSWWithNegLHS(const APInt & LHSMin,const APInt & LHSMax,unsigned RHSMin,unsigned RHSMax)1733 static ConstantRange computeShlNSWWithNegLHS(const APInt &LHSMin,
1734                                              const APInt &LHSMax,
1735                                              unsigned RHSMin, unsigned RHSMax) {
1736   unsigned BitWidth = LHSMin.getBitWidth();
1737   bool Overflow;
1738   APInt MaxShl = LHSMax.sshl_ov(RHSMin, Overflow);
1739   if (Overflow)
1740     return ConstantRange::getEmpty(BitWidth);
1741   APInt MinShl = MaxShl;
1742   unsigned MaxShAmt = LHSMin.countLeadingOnes() - 1;
1743   if (RHSMin <= MaxShAmt)
1744     MinShl = LHSMin.shl(std::min(RHSMax, MaxShAmt));
1745   RHSMin = std::max(RHSMin, MaxShAmt + 1);
1746   RHSMax = std::min(RHSMax, LHSMax.countLeadingOnes() - 1);
1747   if (RHSMin <= RHSMax)
1748     MinShl = APInt::getSignMask(BitWidth);
1749   return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
1750 }
1751 
/// Compute the range of `LHS << RHS` for a shift that may not wrap signed.
///
/// Dispatches on the sign of the left-hand side: an all-non-negative or
/// all-negative LHS is handled by the matching helper; a LHS that spans both
/// signs is split at zero and the two partial results are unioned.
static ConstantRange computeShlNSW(const ConstantRange &LHS,
                                   const ConstantRange &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();
  // Shift amounts are clamped to the bit width.
  const unsigned ShMin = RHS.getUnsignedMin().getLimitedValue(BitWidth);
  const unsigned ShMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);
  const APInt SMin = LHS.getSignedMin();
  const APInt SMax = LHS.getSignedMax();

  if (SMin.isNonNegative())
    return computeShlNSWWithNNegLHS(SMin, SMax, ShMin, ShMax);

  if (SMax.isNegative())
    return computeShlNSWWithNegLHS(SMin, SMax, ShMin, ShMax);

  // Mixed signs: handle [0, SMax] and [SMin, -1] separately.
  const ConstantRange NonNegPart = computeShlNSWWithNNegLHS(
      APInt::getZero(BitWidth), SMax, ShMin, ShMax);
  const ConstantRange NegPart = computeShlNSWWithNegLHS(
      SMin, APInt::getAllOnes(BitWidth), ShMin, ShMax);
  return NonNegPart.unionWith(NegPart, ConstantRange::Signed);
}
1769 
/// Return the range of `X << Y` under the no-wrap flags in \p NoWrapKind.
///
/// \p NoWrapKind is 0 or a combination of
/// OverflowingBinaryOperator::NoSignedWrap / NoUnsignedWrap. When both flags
/// are set, the two individual results are intersected using \p RangeType.
ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
                                           unsigned NoWrapKind,
                                           PreferredRangeType RangeType) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  using OBO = OverflowingBinaryOperator;

  switch (NoWrapKind) {
  case 0:
    // No flags: plain shl.
    return shl(Other);
  case OBO::NoSignedWrap:
    return computeShlNSW(*this, Other);
  case OBO::NoUnsignedWrap:
    return computeShlNUW(*this, Other);
  case OBO::NoSignedWrap | OBO::NoUnsignedWrap: {
    const ConstantRange NSWRange = computeShlNSW(*this, Other);
    const ConstantRange NUWRange = computeShlNUW(*this, Other);
    return NSWRange.intersectWith(NUWRange, RangeType);
  }
  default:
    llvm_unreachable("Invalid NoWrapKind");
  }
}
1791 
1792 ConstantRange
lshr(const ConstantRange & Other) const1793 ConstantRange::lshr(const ConstantRange &Other) const {
1794   if (isEmptySet() || Other.isEmptySet())
1795     return getEmpty();
1796 
1797   APInt max = getUnsignedMax().lshr(Other.getUnsignedMin()) + 1;
1798   APInt min = getUnsignedMin().lshr(Other.getUnsignedMax());
1799   return getNonEmpty(std::move(min), std::move(max));
1800 }
1801 
1802 ConstantRange
ashr(const ConstantRange & Other) const1803 ConstantRange::ashr(const ConstantRange &Other) const {
1804   if (isEmptySet() || Other.isEmptySet())
1805     return getEmpty();
1806 
1807   // May straddle zero, so handle both positive and negative cases.
1808   // 'PosMax' is the upper bound of the result of the ashr
1809   // operation, when Upper of the LHS of ashr is a non-negative.
1810   // number. Since ashr of a non-negative number will result in a
1811   // smaller number, the Upper value of LHS is shifted right with
1812   // the minimum value of 'Other' instead of the maximum value.
1813   APInt PosMax = getSignedMax().ashr(Other.getUnsignedMin()) + 1;
1814 
1815   // 'PosMin' is the lower bound of the result of the ashr
1816   // operation, when Lower of the LHS is a non-negative number.
1817   // Since ashr of a non-negative number will result in a smaller
1818   // number, the Lower value of LHS is shifted right with the
1819   // maximum value of 'Other'.
1820   APInt PosMin = getSignedMin().ashr(Other.getUnsignedMax());
1821 
1822   // 'NegMax' is the upper bound of the result of the ashr
1823   // operation, when Upper of the LHS of ashr is a negative number.
1824   // Since 'ashr' of a negative number will result in a bigger
1825   // number, the Upper value of LHS is shifted right with the
1826   // maximum value of 'Other'.
1827   APInt NegMax = getSignedMax().ashr(Other.getUnsignedMax()) + 1;
1828 
1829   // 'NegMin' is the lower bound of the result of the ashr
1830   // operation, when Lower of the LHS of ashr is a negative number.
1831   // Since 'ashr' of a negative number will result in a bigger
1832   // number, the Lower value of LHS is shifted right with the
1833   // minimum value of 'Other'.
1834   APInt NegMin = getSignedMin().ashr(Other.getUnsignedMin());
1835 
1836   APInt max, min;
1837   if (getSignedMin().isNonNegative()) {
1838     // Upper and Lower of LHS are non-negative.
1839     min = PosMin;
1840     max = PosMax;
1841   } else if (getSignedMax().isNegative()) {
1842     // Upper and Lower of LHS are negative.
1843     min = NegMin;
1844     max = NegMax;
1845   } else {
1846     // Upper is non-negative and Lower is negative.
1847     min = NegMin;
1848     max = PosMax;
1849   }
1850   return getNonEmpty(std::move(min), std::move(max));
1851 }
1852 
/// Return a range containing the unsigned saturating sum of a value from this
/// range and a value from \p Other.
///
/// The bounds are the saturating sums of the two unsigned minima and of the
/// two unsigned maxima.
ConstantRange ConstantRange::uadd_sat(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  APInt InclusiveLo = getUnsignedMin().uadd_sat(Other.getUnsignedMin());
  APInt ExclusiveHi = getUnsignedMax().uadd_sat(Other.getUnsignedMax()) + 1;
  return getNonEmpty(std::move(InclusiveLo), std::move(ExclusiveHi));
}
1861 
/// Return a range containing the signed saturating sum of a value from this
/// range and a value from \p Other.
///
/// The bounds are the saturating sums of the two signed minima and of the two
/// signed maxima.
ConstantRange ConstantRange::sadd_sat(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  APInt InclusiveLo = getSignedMin().sadd_sat(Other.getSignedMin());
  APInt ExclusiveHi = getSignedMax().sadd_sat(Other.getSignedMax()) + 1;
  return getNonEmpty(std::move(InclusiveLo), std::move(ExclusiveHi));
}
1870 
/// Return a range containing the unsigned saturating difference of a value
/// from this range and a value from \p Other.
///
/// The lower bound subtracts the largest subtrahend from the smallest minuend;
/// the upper bound subtracts the smallest subtrahend from the largest minuend.
ConstantRange ConstantRange::usub_sat(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  APInt InclusiveLo = getUnsignedMin().usub_sat(Other.getUnsignedMax());
  APInt ExclusiveHi = getUnsignedMax().usub_sat(Other.getUnsignedMin()) + 1;
  return getNonEmpty(std::move(InclusiveLo), std::move(ExclusiveHi));
}
1879 
/// Return a range containing the signed saturating difference of a value from
/// this range and a value from \p Other.
///
/// The lower bound subtracts the largest subtrahend from the smallest minuend;
/// the upper bound subtracts the smallest subtrahend from the largest minuend.
ConstantRange ConstantRange::ssub_sat(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  APInt InclusiveLo = getSignedMin().ssub_sat(Other.getSignedMax());
  APInt ExclusiveHi = getSignedMax().ssub_sat(Other.getSignedMin()) + 1;
  return getNonEmpty(std::move(InclusiveLo), std::move(ExclusiveHi));
}
1888 
/// Return a range containing the unsigned saturating product of a value from
/// this range and a value from \p Other.
///
/// The bounds are the saturating products of the two unsigned minima and of
/// the two unsigned maxima.
ConstantRange ConstantRange::umul_sat(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  APInt InclusiveLo = getUnsignedMin().umul_sat(Other.getUnsignedMin());
  APInt ExclusiveHi = getUnsignedMax().umul_sat(Other.getUnsignedMax()) + 1;
  return getNonEmpty(std::move(InclusiveLo), std::move(ExclusiveHi));
}
1897 
/// Return a range containing the signed saturating product of a value from
/// this range and a value from \p Other.
ConstantRange ConstantRange::smul_sat(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  // With negative values in play, either extreme of the result can come from
  // any pairing of the operands' signed bounds. Evaluate all four corner
  // products and take their signed minimum and maximum; for example:
  //   [-1,4) * [-2,3) = min(-1*-2, -1*2, 3*-2, 3*2) = -6.
  const APInt LHSMin = getSignedMin();
  const APInt LHSMax = getSignedMax();
  const APInt RHSMin = Other.getSignedMin();
  const APInt RHSMax = Other.getSignedMax();

  auto Corners = {LHSMin.smul_sat(RHSMin), LHSMin.smul_sat(RHSMax),
                  LHSMax.smul_sat(RHSMin), LHSMax.smul_sat(RHSMax)};
  auto SignedLess = [](const APInt &L, const APInt &R) { return L.slt(R); };

  APInt InclusiveLo = std::min(Corners, SignedLess);
  APInt ExclusiveHi = std::max(Corners, SignedLess) + 1;
  return getNonEmpty(std::move(InclusiveLo), std::move(ExclusiveHi));
}
1918 
/// Return a range containing the unsigned saturating left shift of a value
/// from this range by an amount from \p Other.
///
/// The bounds are the unsigned minimum shifted by the smallest amount and the
/// unsigned maximum shifted by the largest amount.
ConstantRange ConstantRange::ushl_sat(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  APInt InclusiveLo = getUnsignedMin().ushl_sat(Other.getUnsignedMin());
  APInt ExclusiveHi = getUnsignedMax().ushl_sat(Other.getUnsignedMax()) + 1;
  return getNonEmpty(std::move(InclusiveLo), std::move(ExclusiveHi));
}
1927 
/// Return a range containing the signed saturating left shift of a value from
/// this range by an amount from \p Other.
///
/// The shift amount paired with each signed bound depends on that bound's
/// sign: a non-negative minimum (and a negative maximum) uses the smallest
/// amount, while a negative minimum (and a non-negative maximum) uses the
/// largest amount.
ConstantRange ConstantRange::sshl_sat(const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return getEmpty();

  const APInt SMin = getSignedMin();
  const APInt SMax = getSignedMax();
  const APInt SmallShift = Other.getUnsignedMin();
  const APInt LargeShift = Other.getUnsignedMax();

  const APInt &LoShift = SMin.isNonNegative() ? SmallShift : LargeShift;
  const APInt &HiShift = SMax.isNegative() ? SmallShift : LargeShift;

  APInt InclusiveLo = SMin.sshl_sat(LoShift);
  APInt ExclusiveHi = SMax.sshl_sat(HiShift) + 1;
  return getNonEmpty(std::move(InclusiveLo), std::move(ExclusiveHi));
}
1938 
inverse() const1939 ConstantRange ConstantRange::inverse() const {
1940   if (isFullSet())
1941     return getEmpty();
1942   if (isEmptySet())
1943     return getFull();
1944   return ConstantRange(Upper, Lower);
1945 }
1946 
abs(bool IntMinIsPoison) const1947 ConstantRange ConstantRange::abs(bool IntMinIsPoison) const {
1948   if (isEmptySet())
1949     return getEmpty();
1950 
1951   if (isSignWrappedSet()) {
1952     APInt Lo;
1953     // Check whether the range crosses zero.
1954     if (Upper.isStrictlyPositive() || !Lower.isStrictlyPositive())
1955       Lo = APInt::getZero(getBitWidth());
1956     else
1957       Lo = APIntOps::umin(Lower, -Upper + 1);
1958 
1959     // If SignedMin is not poison, then it is included in the result range.
1960     if (IntMinIsPoison)
1961       return ConstantRange(Lo, APInt::getSignedMinValue(getBitWidth()));
1962     else
1963       return ConstantRange(Lo, APInt::getSignedMinValue(getBitWidth()) + 1);
1964   }
1965 
1966   APInt SMin = getSignedMin(), SMax = getSignedMax();
1967 
1968   // Skip SignedMin if it is poison.
1969   if (IntMinIsPoison && SMin.isMinSignedValue()) {
1970     // The range may become empty if it *only* contains SignedMin.
1971     if (SMax.isMinSignedValue())
1972       return getEmpty();
1973     ++SMin;
1974   }
1975 
1976   // All non-negative.
1977   if (SMin.isNonNegative())
1978     return ConstantRange(SMin, SMax + 1);
1979 
1980   // All negative.
1981   if (SMax.isNegative())
1982     return ConstantRange(-SMax, -SMin + 1);
1983 
1984   // Range crosses zero.
1985   return ConstantRange::getNonEmpty(APInt::getZero(getBitWidth()),
1986                                     APIntOps::umax(-SMin, SMax) + 1);
1987 }
1988 
ctlz(bool ZeroIsPoison) const1989 ConstantRange ConstantRange::ctlz(bool ZeroIsPoison) const {
1990   if (isEmptySet())
1991     return getEmpty();
1992 
1993   APInt Zero = APInt::getZero(getBitWidth());
1994   if (ZeroIsPoison && contains(Zero)) {
1995     // ZeroIsPoison is set, and zero is contained. We discern three cases, in
1996     // which a zero can appear:
1997     // 1) Lower is zero, handling cases of kind [0, 1), [0, 2), etc.
1998     // 2) Upper is zero, wrapped set, handling cases of kind [3, 0], etc.
1999     // 3) Zero contained in a wrapped set, e.g., [3, 2), [3, 1), etc.
2000 
2001     if (getLower().isZero()) {
2002       if ((getUpper() - 1).isZero()) {
2003         // We have in input interval of kind [0, 1). In this case we cannot
2004         // really help but return empty-set.
2005         return getEmpty();
2006       }
2007 
2008       // Compute the resulting range by excluding zero from Lower.
2009       return ConstantRange(
2010           APInt(getBitWidth(), (getUpper() - 1).countl_zero()),
2011           APInt(getBitWidth(), (getLower() + 1).countl_zero() + 1));
2012     } else if ((getUpper() - 1).isZero()) {
2013       // Compute the resulting range by excluding zero from Upper.
2014       return ConstantRange(Zero,
2015                            APInt(getBitWidth(), getLower().countl_zero() + 1));
2016     } else {
2017       return ConstantRange(Zero, APInt(getBitWidth(), getBitWidth()));
2018     }
2019   }
2020 
2021   // Zero is either safe or not in the range. The output range is composed by
2022   // the result of countLeadingZero of the two extremes.
2023   return getNonEmpty(APInt(getBitWidth(), getUnsignedMax().countl_zero()),
2024                      APInt(getBitWidth(), getUnsignedMin().countl_zero()) + 1);
2025 }
2026 
getUnsignedCountTrailingZerosRange(const APInt & Lower,const APInt & Upper)2027 static ConstantRange getUnsignedCountTrailingZerosRange(const APInt &Lower,
2028                                                         const APInt &Upper) {
2029   assert(!ConstantRange(Lower, Upper).isWrappedSet() &&
2030          "Unexpected wrapped set.");
2031   assert(Lower != Upper && "Unexpected empty set.");
2032   unsigned BitWidth = Lower.getBitWidth();
2033   if (Lower + 1 == Upper)
2034     return ConstantRange(APInt(BitWidth, Lower.countr_zero()));
2035   if (Lower.isZero())
2036     return ConstantRange(APInt::getZero(BitWidth),
2037                          APInt(BitWidth, BitWidth + 1));
2038 
2039   // Calculate longest common prefix.
2040   unsigned LCPLength = (Lower ^ (Upper - 1)).countl_zero();
2041   // If Lower is {LCP, 000...}, the maximum is Lower.countr_zero().
2042   // Otherwise, the maximum is BitWidth - LCPLength - 1 ({LCP, 100...}).
2043   return ConstantRange(
2044       APInt::getZero(BitWidth),
2045       APInt(BitWidth,
2046             std::max(BitWidth - LCPLength - 1, Lower.countr_zero()) + 1));
2047 }
2048 
cttz(bool ZeroIsPoison) const2049 ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const {
2050   if (isEmptySet())
2051     return getEmpty();
2052 
2053   unsigned BitWidth = getBitWidth();
2054   APInt Zero = APInt::getZero(BitWidth);
2055   if (ZeroIsPoison && contains(Zero)) {
2056     // ZeroIsPoison is set, and zero is contained. We discern three cases, in
2057     // which a zero can appear:
2058     // 1) Lower is zero, handling cases of kind [0, 1), [0, 2), etc.
2059     // 2) Upper is zero, wrapped set, handling cases of kind [3, 0], etc.
2060     // 3) Zero contained in a wrapped set, e.g., [3, 2), [3, 1), etc.
2061 
2062     if (Lower.isZero()) {
2063       if (Upper == 1) {
2064         // We have in input interval of kind [0, 1). In this case we cannot
2065         // really help but return empty-set.
2066         return getEmpty();
2067       }
2068 
2069       // Compute the resulting range by excluding zero from Lower.
2070       return getUnsignedCountTrailingZerosRange(APInt(BitWidth, 1), Upper);
2071     } else if (Upper == 1) {
2072       // Compute the resulting range by excluding zero from Upper.
2073       return getUnsignedCountTrailingZerosRange(Lower, Zero);
2074     } else {
2075       ConstantRange CR1 = getUnsignedCountTrailingZerosRange(Lower, Zero);
2076       ConstantRange CR2 =
2077           getUnsignedCountTrailingZerosRange(APInt(BitWidth, 1), Upper);
2078       return CR1.unionWith(CR2);
2079     }
2080   }
2081 
2082   if (isFullSet())
2083     return getNonEmpty(Zero, APInt(BitWidth, BitWidth) + 1);
2084   if (!isWrappedSet())
2085     return getUnsignedCountTrailingZerosRange(Lower, Upper);
2086   // The range is wrapped. We decompose it into two ranges, [0, Upper) and
2087   // [Lower, 0).
2088   // Handle [Lower, 0)
2089   ConstantRange CR1 = getUnsignedCountTrailingZerosRange(Lower, Zero);
2090   // Handle [0, Upper)
2091   ConstantRange CR2 = getUnsignedCountTrailingZerosRange(Zero, Upper);
2092   return CR1.unionWith(CR2);
2093 }
2094 
getUnsignedPopCountRange(const APInt & Lower,const APInt & Upper)2095 static ConstantRange getUnsignedPopCountRange(const APInt &Lower,
2096                                               const APInt &Upper) {
2097   assert(!ConstantRange(Lower, Upper).isWrappedSet() &&
2098          "Unexpected wrapped set.");
2099   assert(Lower != Upper && "Unexpected empty set.");
2100   unsigned BitWidth = Lower.getBitWidth();
2101   if (Lower + 1 == Upper)
2102     return ConstantRange(APInt(BitWidth, Lower.popcount()));
2103 
2104   APInt Max = Upper - 1;
2105   // Calculate longest common prefix.
2106   unsigned LCPLength = (Lower ^ Max).countl_zero();
2107   unsigned LCPPopCount = Lower.getHiBits(LCPLength).popcount();
2108   // If Lower is {LCP, 000...}, the minimum is the popcount of LCP.
2109   // Otherwise, the minimum is the popcount of LCP + 1.
2110   unsigned MinBits =
2111       LCPPopCount + (Lower.countr_zero() < BitWidth - LCPLength ? 1 : 0);
2112   // If Max is {LCP, 111...}, the maximum is the popcount of LCP + (BitWidth -
2113   // length of LCP).
2114   // Otherwise, the minimum is the popcount of LCP + (BitWidth -
2115   // length of LCP - 1).
2116   unsigned MaxBits = LCPPopCount + (BitWidth - LCPLength) -
2117                      (Max.countr_one() < BitWidth - LCPLength ? 1 : 0);
2118   return ConstantRange(APInt(BitWidth, MinBits), APInt(BitWidth, MaxBits + 1));
2119 }
2120 
ctpop() const2121 ConstantRange ConstantRange::ctpop() const {
2122   if (isEmptySet())
2123     return getEmpty();
2124 
2125   unsigned BitWidth = getBitWidth();
2126   APInt Zero = APInt::getZero(BitWidth);
2127   if (isFullSet())
2128     return getNonEmpty(Zero, APInt(BitWidth, BitWidth) + 1);
2129   if (!isWrappedSet())
2130     return getUnsignedPopCountRange(Lower, Upper);
2131   // The range is wrapped. We decompose it into two ranges, [0, Upper) and
2132   // [Lower, 0).
2133   // Handle [Lower, 0) == [Lower, Max]
2134   ConstantRange CR1 = ConstantRange(APInt(BitWidth, Lower.countl_one()),
2135                                     APInt(BitWidth, BitWidth + 1));
2136   // Handle [0, Upper)
2137   ConstantRange CR2 = getUnsignedPopCountRange(Zero, Upper);
2138   return CR1.unionWith(CR2);
2139 }
2140 
/// Classify whether an unsigned add of a value from this range and a value
/// from \p Other overflows. Empty operands yield MayOverflow.
ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow(
    const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return OverflowResult::MayOverflow;

  // a u+ b overflows high iff a u> ~b.
  const auto AddOverflows = [](const APInt &A, const APInt &B) {
    return A.ugt(~B);
  };

  // Even the two smallest values overflow: every sum does.
  if (AddOverflows(getUnsignedMin(), Other.getUnsignedMin()))
    return OverflowResult::AlwaysOverflowsHigh;

  // The two largest values overflow: some sum does.
  if (AddOverflows(getUnsignedMax(), Other.getUnsignedMax()))
    return OverflowResult::MayOverflow;

  return OverflowResult::NeverOverflows;
}
2156 
/// Classify whether a signed add of a value from this range and a value from
/// \p Other overflows. Empty operands yield MayOverflow.
ConstantRange::OverflowResult ConstantRange::signedAddMayOverflow(
    const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return OverflowResult::MayOverflow;

  const APInt LHSMin = getSignedMin(), LHSMax = getSignedMax();
  const APInt RHSMin = Other.getSignedMin(), RHSMax = Other.getSignedMax();

  const APInt SMinValue = APInt::getSignedMinValue(getBitWidth());
  const APInt SMaxValue = APInt::getSignedMaxValue(getBitWidth());

  // a s+ b overflows high iff a s>=0 && b s>= 0 && a s> smax - b.
  const auto OverflowsHigh = [&SMaxValue](const APInt &A, const APInt &B) {
    return A.isNonNegative() && B.isNonNegative() && A.sgt(SMaxValue - B);
  };
  // a s+ b overflows low iff a s< 0 && b s< 0 && a s< smin - b.
  const auto OverflowsLow = [&SMinValue](const APInt &A, const APInt &B) {
    return A.isNegative() && B.isNegative() && A.slt(SMinValue - B);
  };

  // If the mildest pairing already overflows, every pairing does.
  if (OverflowsHigh(LHSMin, RHSMin))
    return OverflowResult::AlwaysOverflowsHigh;
  if (OverflowsLow(LHSMax, RHSMax))
    return OverflowResult::AlwaysOverflowsLow;

  // If the most extreme pairing overflows in either direction, some do.
  if (OverflowsHigh(LHSMax, RHSMax) || OverflowsLow(LHSMin, RHSMin))
    return OverflowResult::MayOverflow;

  return OverflowResult::NeverOverflows;
}
2186 
/// Classify whether an unsigned subtract of a value from \p Other from a value
/// in this range overflows. Empty operands yield MayOverflow.
ConstantRange::OverflowResult ConstantRange::unsignedSubMayOverflow(
    const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return OverflowResult::MayOverflow;

  // a u- b overflows low iff a u< b.

  // Largest minuend below smallest subtrahend: every difference overflows.
  if (getUnsignedMax().ult(Other.getUnsignedMin()))
    return OverflowResult::AlwaysOverflowsLow;

  // Smallest minuend below largest subtrahend: some difference overflows.
  if (getUnsignedMin().ult(Other.getUnsignedMax()))
    return OverflowResult::MayOverflow;

  return OverflowResult::NeverOverflows;
}
2202 
/// Classify whether a signed subtract of a value from \p Other from a value in
/// this range overflows. Empty operands yield MayOverflow.
ConstantRange::OverflowResult ConstantRange::signedSubMayOverflow(
    const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return OverflowResult::MayOverflow;

  const APInt LHSMin = getSignedMin(), LHSMax = getSignedMax();
  const APInt RHSMin = Other.getSignedMin(), RHSMax = Other.getSignedMax();

  const APInt SMinValue = APInt::getSignedMinValue(getBitWidth());
  const APInt SMaxValue = APInt::getSignedMaxValue(getBitWidth());

  // a s- b overflows high iff a s>=0 && b s< 0 && a s> smax + b.
  const auto OverflowsHigh = [&SMaxValue](const APInt &A, const APInt &B) {
    return A.isNonNegative() && B.isNegative() && A.sgt(SMaxValue + B);
  };
  // a s- b overflows low iff a s< 0 && b s>= 0 && a s< smin + b.
  const auto OverflowsLow = [&SMinValue](const APInt &A, const APInt &B) {
    return A.isNegative() && B.isNonNegative() && A.slt(SMinValue + B);
  };

  // If the mildest pairing already overflows, every pairing does.
  if (OverflowsHigh(LHSMin, RHSMax))
    return OverflowResult::AlwaysOverflowsHigh;
  if (OverflowsLow(LHSMax, RHSMin))
    return OverflowResult::AlwaysOverflowsLow;

  // If the most extreme pairing overflows in either direction, some do.
  if (OverflowsHigh(LHSMax, RHSMin) || OverflowsLow(LHSMin, RHSMax))
    return OverflowResult::MayOverflow;

  return OverflowResult::NeverOverflows;
}
2232 
/// Classify whether an unsigned multiply of a value from this range and a
/// value from \p Other overflows. Empty operands yield MayOverflow.
ConstantRange::OverflowResult ConstantRange::unsignedMulMayOverflow(
    const ConstantRange &Other) const {
  if (isEmptySet() || Other.isEmptySet())
    return OverflowResult::MayOverflow;

  // Only the overflow flag of the product matters; the value is discarded.
  const auto MulOverflows = [](const APInt &A, const APInt &B) {
    bool Overflowed;
    (void)A.umul_ov(B, Overflowed);
    return Overflowed;
  };

  // Even the two smallest values overflow: every product does.
  if (MulOverflows(getUnsignedMin(), Other.getUnsignedMin()))
    return OverflowResult::AlwaysOverflowsHigh;

  // The two largest values overflow: some product does.
  if (MulOverflows(getUnsignedMax(), Other.getUnsignedMax()))
    return OverflowResult::MayOverflow;

  return OverflowResult::NeverOverflows;
}
2252 
print(raw_ostream & OS) const2253 void ConstantRange::print(raw_ostream &OS) const {
2254   if (isFullSet())
2255     OS << "full-set";
2256   else if (isEmptySet())
2257     OS << "empty-set";
2258   else
2259     OS << "[" << Lower << "," << Upper << ")";
2260 }
2261 
2262 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
dump() const2263 LLVM_DUMP_METHOD void ConstantRange::dump() const {
2264   print(dbgs());
2265 }
2266 #endif
2267 
getConstantRangeFromMetadata(const MDNode & Ranges)2268 ConstantRange llvm::getConstantRangeFromMetadata(const MDNode &Ranges) {
2269   const unsigned NumRanges = Ranges.getNumOperands() / 2;
2270   assert(NumRanges >= 1 && "Must have at least one range!");
2271   assert(Ranges.getNumOperands() % 2 == 0 && "Must be a sequence of pairs");
2272 
2273   auto *FirstLow = mdconst::extract<ConstantInt>(Ranges.getOperand(0));
2274   auto *FirstHigh = mdconst::extract<ConstantInt>(Ranges.getOperand(1));
2275 
2276   ConstantRange CR(FirstLow->getValue(), FirstHigh->getValue());
2277 
2278   for (unsigned i = 1; i < NumRanges; ++i) {
2279     auto *Low = mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 0));
2280     auto *High = mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1));
2281 
2282     // Note: unionWith will potentially create a range that contains values not
2283     // contained in any of the original N ranges.
2284     CR = CR.unionWith(ConstantRange(Low->getValue(), High->getValue()));
2285   }
2286 
2287   return CR;
2288 }
2289