xref: /freebsd/contrib/llvm-project/llvm/lib/Support/KnownBits.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a class for representing known zeros and ones used by
10 // computeKnownBits.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Support/KnownBits.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include <cassert>
18 
19 using namespace llvm;
20 
/// Shared helper for the add-with-carry transfer functions.
///
/// Forms the largest sum consistent with the known bits (unknown bits taken as
/// one, carry-in counted unless it is known zero) and the smallest (unknown
/// bits taken as zero, carry-in counted only if known one). A result bit is
/// reported as known only when the matching bits of LHS, RHS and the carry
/// into that position are all known.
static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS,
                                    bool CarryZero, bool CarryOne) {
  // Extreme sums consistent with what is known about the inputs.
  APInt MaxSum = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
  APInt MinSum = LHS.getMinValue() + RHS.getMinValue() + CarryOne;

  // Recover the per-bit carries of the extreme sums by undoing the operand
  // contribution (sum = lhs ^ rhs ^ carry).
  APInt CarryZeroMask = ~(MaxSum ^ LHS.Zero ^ RHS.Zero);
  APInt CarryOneMask = MinSum ^ LHS.One ^ RHS.One;

  // A position is known only when both operand bits and its carry are known.
  APInt KnownMask = (LHS.Zero | LHS.One) & (RHS.Zero | RHS.One) &
                    (CarryZeroMask | CarryOneMask);

  KnownBits Result;
  Result.Zero = ~MaxSum & KnownMask;
  Result.One = MinSum & KnownMask;
  return Result;
}
43 
/// Known bits of LHS + RHS + Carry, where Carry is a 1-bit value.
KnownBits KnownBits::computeForAddCarry(
    const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) {
  assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
  // Translate the 1-bit carry's known state into the helper's two flags.
  const bool CarryIsZero = Carry.Zero.getBoolValue();
  const bool CarryIsOne = Carry.One.getBoolValue();
  return ::computeForAddCarry(LHS, RHS, CarryIsZero, CarryIsOne);
}
50 
/// Compute known bits for LHS + RHS (Add == true) or LHS - RHS (Add == false).
///
/// NSW / NUW assert that the operation does not wrap in the signed / unsigned
/// sense; they are used to derive extra known bits from the operands' min/max
/// values. If the bits derived this way contradict each other, the flags must
/// be violated (the result is poison) and an all-zero value is returned.
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
                                      const KnownBits &LHS,
                                      const KnownBits &RHS) {
  unsigned BitWidth = LHS.getBitWidth();
  KnownBits KnownOut(BitWidth);
  // This can be a relatively expensive helper, so optimistically save some
  // work.
  if (LHS.isUnknown() && RHS.isUnknown())
    return KnownOut;

  // The carry-based bitwise analysis is only run when both operands have some
  // known bits; otherwise KnownOut stays unknown here and only the flag-based
  // reasoning below can add information.
  if (!LHS.isUnknown() && !RHS.isUnknown()) {
    if (Add) {
      // Sum = LHS + RHS + 0
      KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true,
                                      /*CarryOne=*/false);
    } else {
      // Sum = LHS + ~RHS + 1
      // Swapping the masks yields the known bits of ~RHS.
      KnownBits NotRHS = RHS;
      std::swap(NotRHS.Zero, NotRHS.One);
      KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero=*/false,
                                      /*CarryOne=*/true);
    }
  }

  // Handle add/sub given nsw and/or nuw.
  if (NUW) {
    if (Add) {
      // (add nuw X, Y)
      // Smallest possible sum; saturating so a guaranteed wrap does not
      // produce a misleadingly small value.
      APInt MinVal = LHS.getMinValue().uadd_sat(RHS.getMinValue());
      // None of the adds can end up overflowing, so min consecutive highbits
      // in minimum possible of X + Y must all remain set.
      if (NSW) {
        unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
        // If we have NSW as well, we also know we can't overflow the signbit so
        // can start counting from 1 bit back.
        KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      }
      KnownOut.One.setHighBits(MinVal.countl_one());
    } else {
      // (sub nuw X, Y)
      // Largest possible difference, clamped at zero.
      APInt MaxVal = LHS.getMaxValue().usub_sat(RHS.getMinValue());
      // None of the subs can overflow at any point, so any common high bits
      // will subtract away and result in zeros.
      if (NSW) {
        // If we have NSW as well, we also know we can't overflow the signbit so
        // can start counting from 1 bit back.
        unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
        KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      }
      KnownOut.Zero.setHighBits(MaxVal.countl_zero());
    }
  }

  if (NSW) {
    // Signed range [MinVal, MaxVal] of the result, saturated at the limits of
    // the type.
    APInt MinVal;
    APInt MaxVal;
    if (Add) {
      // (add nsw X, Y)
      MinVal = LHS.getSignedMinValue().sadd_sat(RHS.getSignedMinValue());
      MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS.getSignedMaxValue());
    } else {
      // (sub nsw X, Y)
      MinVal = LHS.getSignedMinValue().ssub_sat(RHS.getSignedMaxValue());
      MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS.getSignedMinValue());
    }
    if (MinVal.isNonNegative()) {
      // If min is non-negative, result will always be non-neg (can't overflow
      // around).
      // The leading ones below the sign bit of MinVal are also known ones.
      unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
      KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      KnownOut.Zero.setSignBit();
    }
    if (MaxVal.isNegative()) {
      // If max is negative, result will always be neg (can't overflow around).
      // The leading zeros below the sign bit of MaxVal are also known zeros.
      unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
      KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      KnownOut.One.setSignBit();
    }
  }

  // Just return 0 if the nsw/nuw is violated and we have poison.
  if (KnownOut.hasConflict())
    KnownOut.setAllZero();
  return KnownOut;
}
136 
/// Known bits of LHS - RHS - Borrow, where Borrow is a 1-bit value.
KnownBits KnownBits::computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
                                         const KnownBits &Borrow) {
  assert(Borrow.getBitWidth() == 1 && "Borrow must be 1-bit");

  // Rewrite the subtraction as LHS + ~RHS + (1 - Borrow): complement RHS by
  // exchanging its masks, and invert the borrow to obtain the carry-in.
  std::swap(RHS.Zero, RHS.One);
  const bool CarryIsZero = Borrow.One.getBoolValue();
  const bool CarryIsOne = Borrow.Zero.getBoolValue();
  return ::computeForAddCarry(LHS, RHS, CarryIsZero, CarryIsOne);
}
148 
sextInReg(unsigned SrcBitWidth) const149 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
150   unsigned BitWidth = getBitWidth();
151   assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
152          "Illegal sext-in-register");
153 
154   if (SrcBitWidth == BitWidth)
155     return *this;
156 
157   unsigned ExtBits = BitWidth - SrcBitWidth;
158   KnownBits Result;
159   Result.One = One << ExtBits;
160   Result.Zero = Zero << ExtBits;
161   Result.One.ashrInPlace(ExtBits);
162   Result.Zero.ashrInPlace(ExtBits);
163   return Result;
164 }
165 
makeGE(const APInt & Val) const166 KnownBits KnownBits::makeGE(const APInt &Val) const {
167   // Count the number of leading bit positions where our underlying value is
168   // known to be less than or equal to Val.
169   unsigned N = (Zero | Val).countl_one();
170 
171   // For each of those bit positions, if Val has a 1 in that bit then our
172   // underlying value must also have a 1.
173   APInt MaskedVal(Val);
174   MaskedVal.clearLowBits(getBitWidth() - N);
175   return KnownBits(Zero, One | MaskedVal);
176 }
177 
/// Known bits of the unsigned maximum of LHS and RHS.
KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
  // When one operand is provably >= the other, umax is just that operand.
  // Callers usually fold this already; it is handled here for completeness.
  if (LHS.getMinValue().uge(RHS.getMaxValue()))
    return LHS;
  if (RHS.getMinValue().uge(LHS.getMaxValue()))
    return RHS;

  // The result is either LHS, which must then be >= umin(RHS), or RHS, which
  // must then be >= umin(LHS). Keep what both candidates agree on.
  const KnownBits FromLHS = LHS.makeGE(RHS.getMinValue());
  const KnownBits FromRHS = RHS.makeGE(LHS.getMinValue());
  return FromLHS.intersectWith(FromRHS);
}
195 
/// Known bits of the unsigned minimum of LHS and RHS.
KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
  // Complementing reverses unsigned order ([0, 0xFFFFFFFF] <-> [0xFFFFFFFF,
  // 0]), so umin(a, b) == ~umax(~a, ~b). Complementing known bits simply
  // exchanges the two masks.
  auto Invert = [](const KnownBits &K) { return KnownBits(K.One, K.Zero); };
  const KnownBits InvertedMax = umax(Invert(LHS), Invert(RHS));
  return Invert(InvertedMax);
}
201 
/// Known bits of the signed maximum of LHS and RHS.
KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
  // Toggling the sign bit maps signed order onto unsigned order
  // ([-0x80000000, 0x7FFFFFFF] -> [0, 0xFFFFFFFF]), so
  // smax(a, b) == toggle(umax(toggle(a), toggle(b))).
  auto ToggleSign = [](const KnownBits &K) {
    const unsigned SignBit = K.getBitWidth() - 1;
    KnownBits Out(K.Zero, K.One);
    Out.Zero.setBitVal(SignBit, K.One[SignBit]);
    Out.One.setBitVal(SignBit, K.Zero[SignBit]);
    return Out;
  };
  return ToggleSign(umax(ToggleSign(LHS), ToggleSign(RHS)));
}
214 
/// Known bits of the signed minimum of LHS and RHS.
KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
  // Complementing every bit except the sign bit maps signed order onto
  // reversed unsigned order ([-0x80000000, 0x7FFFFFFF] -> [0xFFFFFFFF, 0]),
  // so smin(a, b) == flip(umax(flip(a), flip(b))).
  auto InvertMagnitude = [](const KnownBits &K) {
    const unsigned SignBit = K.getBitWidth() - 1;
    KnownBits Out(K.One, K.Zero);
    Out.Zero.setBitVal(SignBit, K.Zero[SignBit]);
    Out.One.setBitVal(SignBit, K.One[SignBit]);
    return Out;
  };
  return InvertMagnitude(umax(InvertMagnitude(LHS), InvertMagnitude(RHS)));
}
227 
/// Known bits of the unsigned absolute difference |LHS - RHS|.
KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
  // When the unsigned order of the operands is known, the result is a plain
  // subtraction of the smaller from the larger.
  if (LHS.getMinValue().uge(RHS.getMaxValue()))
    return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
                            RHS);
  if (RHS.getMinValue().uge(LHS.getMaxValue()))
    return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS,
                            LHS);

  // Otherwise the result is one of the two differences, and whichever one is
  // taken never wraps, so both may be modelled as "sub nuw". Keep the bits
  // they have in common.
  auto SubNUW = [](const KnownBits &Minuend, const KnownBits &Subtrahend) {
    return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true,
                            Minuend, Subtrahend);
  };
  return SubNUW(LHS, RHS).intersectWith(SubNUW(RHS, LHS));
}
246 
abds(KnownBits LHS,KnownBits RHS)247 KnownBits KnownBits::abds(KnownBits LHS, KnownBits RHS) {
248   // If we know which argument is larger, return (sub LHS, RHS) or
249   // (sub RHS, LHS) directly.
250   if (LHS.getSignedMinValue().sge(RHS.getSignedMaxValue()))
251     return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
252                             RHS);
253   if (RHS.getSignedMinValue().sge(LHS.getSignedMaxValue()))
254     return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS,
255                             LHS);
256 
257   // Shift both arguments from the signed range to the unsigned range, e.g. from
258   // [-0x80, 0x7F] to [0, 0xFF]. This allows us to use "sub nuw" below just like
259   // abdu does.
260   // Note that we can't just use "sub nsw" instead because abds has signed
261   // inputs but an unsigned result, which makes the overflow conditions
262   // different.
263   unsigned SignBitPosition = LHS.getBitWidth() - 1;
264   for (auto Arg : {&LHS, &RHS}) {
265     bool Tmp = Arg->Zero[SignBitPosition];
266     Arg->Zero.setBitVal(SignBitPosition, Arg->One[SignBitPosition]);
267     Arg->One.setBitVal(SignBitPosition, Tmp);
268   }
269 
270   // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
271   KnownBits Diff0 =
272       computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
273   KnownBits Diff1 =
274       computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, RHS, LHS);
275   return Diff0.intersectWith(Diff1);
276 }
277 
getMaxShiftAmount(const APInt & MaxValue,unsigned BitWidth)278 static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
279   if (isPowerOf2_32(BitWidth))
280     return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0);
281   // This is only an approximate upper bound.
282   return MaxValue.getLimitedValue(BitWidth - 1);
283 }
284 
/// Compute known bits for shl(LHS, RHS).
///
/// NUW / NSW mean the shift is poison on unsigned / signed overflow.
/// ShAmtNonZero means the shift amount is known to be nonzero. The result is
/// the intersection over every shift amount compatible with RHS; if no shift
/// amount is admissible (always poison), an all-zero value is returned.
KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
                         bool NSW, bool ShAmtNonZero) {
  unsigned BitWidth = LHS.getBitWidth();
  // Known bits of LHS shifted left by the constant ShiftAmt, refined using
  // the nsw/nuw flags.
  auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
    KnownBits Known;
    // The ushl_ov overflow flags record whether a known-zero bit
    // (ShiftedOutZero) or a known-one bit (ShiftedOutOne) was shifted out.
    bool ShiftedOutZero, ShiftedOutOne;
    Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero);
    // Zeros are shifted in at the bottom.
    Known.Zero.setLowBits(ShiftAmt);
    Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne);

    // All cases returning poison have been handled by MaxShiftAmount already.
    if (NSW) {
      if (NUW && ShiftAmt != 0)
        // NUW means we can assume anything shifted out was a zero.
        ShiftedOutZero = true;

      // This relies on nsw meaning that the shifted-out bits agree with the
      // result's sign bit.
      if (ShiftedOutZero)
        Known.makeNonNegative();
      else if (ShiftedOutOne)
        Known.makeNegative();
    }
    return Known;
  };

  // Fast path for a common case when LHS is completely unknown.
  KnownBits Known(BitWidth);
  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
  if (MinShiftAmount == 0 && ShAmtNonZero)
    MinShiftAmount = 1;
  if (LHS.isUnknown()) {
    // At least MinShiftAmount zero bits are shifted in at the bottom.
    Known.Zero.setLowBits(MinShiftAmount);
    if (NUW && NSW && MinShiftAmount != 0)
      Known.makeNonNegative();
    return Known;
  }

  // Determine maximum shift amount, taking NUW/NSW flags into account.
  APInt MaxValue = RHS.getMaxValue();
  unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
  // NOTE: if LHS.countMaxLeadingZeros() is 0 the subtraction below wraps to
  // UINT_MAX, making this std::min a no-op; the NUW clamp that follows then
  // limits MaxShiftAmount to 0.
  if (NUW && NSW)
    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros() - 1);
  if (NUW)
    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros());
  if (NSW)
    MaxShiftAmount = std::min(
        MaxShiftAmount,
        std::max(LHS.countMaxLeadingZeros(), LHS.countMaxLeadingOnes()) - 1);

  // Fast path for common case where the shift amount is unknown.
  if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 &&
      isPowerOf2_32(BitWidth)) {
    // Trailing zeros of LHS survive every shift.
    Known.Zero.setLowBits(LHS.countMinTrailingZeros());
    // An all-ones LHS keeps its sign bit set for any in-range shift.
    if (LHS.isAllOnes())
      Known.One.setSignBit();
    // With nsw, the sign of LHS carries over to the result.
    if (NSW) {
      if (LHS.isNonNegative())
        Known.makeNonNegative();
      if (LHS.isNegative())
        Known.makeNegative();
    }
    return Known;
  }

  // Find the common bits from all possible shifts.
  unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
  unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
  // Start from the conflicting "everything known" state so the first
  // intersection adopts the first admissible shift's result.
  Known.Zero.setAllBits();
  Known.One.setAllBits();
  for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
       ++ShiftAmt) {
    // Skip if the shift amount is impossible.
    // It is impossible if it sets a known-zero bit or clears a known-one bit
    // of RHS.
    if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
        (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
      continue;
    Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
    // Nothing more can be learned once every bit is unknown.
    if (Known.isUnknown())
      break;
  }

  // All shift amounts may result in poison.
  // The initial conflicting state survives only if no shift amount was
  // admissible.
  if (Known.hasConflict())
    Known.setAllZero();
  return Known;
}
369 
/// Compute known bits for lshr(LHS, RHS).
///
/// ShAmtNonZero means the shift amount is known to be nonzero. Exact means
/// the shift is poison if it shifts out any one bit. The result is the
/// intersection over every shift amount compatible with RHS; when every such
/// amount is poison, an all-zero value is returned.
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
                          bool ShAmtNonZero, bool Exact) {
  const unsigned BitWidth = LHS.getBitWidth();

  // Known bits of Src logically shifted right by the constant Amt; the
  // vacated high bits are known zero.
  auto ShiftByConst = [](const KnownBits &Src, unsigned Amt) {
    KnownBits Shifted = Src;
    Shifted.Zero.lshrInPlace(Amt);
    Shifted.One.lshrInPlace(Amt);
    Shifted.Zero.setHighBits(Amt);
    return Shifted;
  };

  // Smallest possible shift amount, clamped to BitWidth.
  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
  if (ShAmtNonZero && MinShiftAmount == 0)
    MinShiftAmount = 1;

  // Fast path: with nothing known about LHS, only the zeros shifted in by the
  // minimum shift amount are known.
  if (LHS.isUnknown()) {
    KnownBits Result(BitWidth);
    Result.Zero.setHighBits(MinShiftAmount);
    return Result;
  }

  unsigned MaxShiftAmount = getMaxShiftAmount(RHS.getMaxValue(), BitWidth);

  // An exact shift cannot shift past the lowest bit of LHS that may be one.
  if (Exact) {
    const unsigned FirstOne = LHS.countMaxTrailingZeros();
    if (FirstOne < MinShiftAmount) {
      // Always poison. Return zero because we don't like returning conflict.
      KnownBits Result(BitWidth);
      Result.setAllZero();
      return Result;
    }
    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
  }

  // Intersect the results of every admissible shift amount, starting from the
  // conflicting "everything known" state so the first intersection adopts
  // the first admissible result.
  const unsigned AmtKnownZero = RHS.Zero.zextOrTrunc(32).getZExtValue();
  const unsigned AmtKnownOne = RHS.One.zextOrTrunc(32).getZExtValue();
  KnownBits Known(BitWidth);
  Known.Zero.setAllBits();
  Known.One.setAllBits();
  for (unsigned Amt = MinShiftAmount; Amt <= MaxShiftAmount; ++Amt) {
    // An amount is impossible if it sets a known-zero bit or clears a
    // known-one bit of RHS.
    const bool SetsKnownZero = (AmtKnownZero & Amt) != 0;
    const bool ClearsKnownOne = (AmtKnownOne & ~Amt) != 0;
    if (SetsKnownZero || ClearsKnownOne)
      continue;
    Known = Known.intersectWith(ShiftByConst(LHS, Amt));
    if (Known.isUnknown())
      break; // Nothing more to learn.
  }

  // The initial conflicting state survives only if no amount was admissible,
  // i.e. every shift is poison.
  if (Known.hasConflict())
    Known.setAllZero();
  return Known;
}
427 
/// Compute known bits for ashr(LHS, RHS).
///
/// ShAmtNonZero means the shift amount is known to be nonzero. Exact means
/// the shift is poison if it shifts out any one bit. The result is the
/// intersection over every shift amount compatible with RHS; when every such
/// amount is poison, an all-zero value is returned.
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
                          bool ShAmtNonZero, bool Exact) {
  const unsigned BitWidth = LHS.getBitWidth();

  // Known bits of Src arithmetically shifted right by the constant Amt; the
  // vacated high bits replicate whatever is known about the sign bit.
  auto ShiftByConst = [](const KnownBits &Src, unsigned Amt) {
    KnownBits Shifted = Src;
    Shifted.Zero.ashrInPlace(Amt);
    Shifted.One.ashrInPlace(Amt);
    return Shifted;
  };

  // Smallest possible shift amount, clamped to BitWidth.
  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
  if (ShAmtNonZero && MinShiftAmount == 0)
    MinShiftAmount = 1;

  // Fast path: with nothing known about LHS nothing is known about the
  // result, unless every shift amount is out of range.
  if (LHS.isUnknown()) {
    KnownBits Result(BitWidth);
    if (MinShiftAmount == BitWidth) {
      // Always poison. Return zero because we don't like returning conflict.
      Result.setAllZero();
    }
    return Result;
  }

  unsigned MaxShiftAmount = getMaxShiftAmount(RHS.getMaxValue(), BitWidth);

  // An exact shift cannot shift past the lowest bit of LHS that may be one.
  if (Exact) {
    const unsigned FirstOne = LHS.countMaxTrailingZeros();
    if (FirstOne < MinShiftAmount) {
      // Always poison. Return zero because we don't like returning conflict.
      KnownBits Result(BitWidth);
      Result.setAllZero();
      return Result;
    }
    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
  }

  // Intersect the results of every admissible shift amount, starting from the
  // conflicting "everything known" state so the first intersection adopts
  // the first admissible result.
  const unsigned AmtKnownZero = RHS.Zero.zextOrTrunc(32).getZExtValue();
  const unsigned AmtKnownOne = RHS.One.zextOrTrunc(32).getZExtValue();
  KnownBits Known(BitWidth);
  Known.Zero.setAllBits();
  Known.One.setAllBits();
  for (unsigned Amt = MinShiftAmount; Amt <= MaxShiftAmount; ++Amt) {
    // An amount is impossible if it sets a known-zero bit or clears a
    // known-one bit of RHS.
    const bool SetsKnownZero = (AmtKnownZero & Amt) != 0;
    const bool ClearsKnownOne = (AmtKnownOne & ~Amt) != 0;
    if (SetsKnownZero || ClearsKnownOne)
      continue;
    Known = Known.intersectWith(ShiftByConst(LHS, Amt));
    if (Known.isUnknown())
      break; // Nothing more to learn.
  }

  // The initial conflicting state survives only if no amount was admissible,
  // i.e. every shift is poison.
  if (Known.hasConflict())
    Known.setAllZero();
  return Known;
}
487 
eq(const KnownBits & LHS,const KnownBits & RHS)488 std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
489   if (LHS.isConstant() && RHS.isConstant())
490     return std::optional<bool>(LHS.getConstant() == RHS.getConstant());
491   if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero))
492     return std::optional<bool>(false);
493   return std::nullopt;
494 }
495 
ne(const KnownBits & LHS,const KnownBits & RHS)496 std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
497   if (std::optional<bool> KnownEQ = eq(LHS, RHS))
498     return std::optional<bool>(!*KnownEQ);
499   return std::nullopt;
500 }
501 
ugt(const KnownBits & LHS,const KnownBits & RHS)502 std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
503   // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
504   if (LHS.getMaxValue().ule(RHS.getMinValue()))
505     return std::optional<bool>(false);
506   // LHS >u RHS -> true if umin(LHS) > umax(RHS)
507   if (LHS.getMinValue().ugt(RHS.getMaxValue()))
508     return std::optional<bool>(true);
509   return std::nullopt;
510 }
511 
uge(const KnownBits & LHS,const KnownBits & RHS)512 std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
513   if (std::optional<bool> IsUGT = ugt(RHS, LHS))
514     return std::optional<bool>(!*IsUGT);
515   return std::nullopt;
516 }
517 
ult(const KnownBits & LHS,const KnownBits & RHS)518 std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
519   return ugt(RHS, LHS);
520 }
521 
ule(const KnownBits & LHS,const KnownBits & RHS)522 std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
523   return uge(RHS, LHS);
524 }
525 
sgt(const KnownBits & LHS,const KnownBits & RHS)526 std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
527   // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
528   if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue()))
529     return std::optional<bool>(false);
530   // LHS >s RHS -> true if smin(LHS) > smax(RHS)
531   if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue()))
532     return std::optional<bool>(true);
533   return std::nullopt;
534 }
535 
sge(const KnownBits & LHS,const KnownBits & RHS)536 std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
537   if (std::optional<bool> KnownSGT = sgt(RHS, LHS))
538     return std::optional<bool>(!*KnownSGT);
539   return std::nullopt;
540 }
541 
slt(const KnownBits & LHS,const KnownBits & RHS)542 std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
543   return sgt(RHS, LHS);
544 }
545 
sle(const KnownBits & LHS,const KnownBits & RHS)546 std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
547   return sge(RHS, LHS);
548 }
549 
/// Compute known bits for abs(*this).
///
/// If IntMinIsPoison is set, an INT_MIN input is treated as poison, which
/// lets the result's sign bit be known zero and enables the special cases
/// below.
KnownBits KnownBits::abs(bool IntMinIsPoison) const {
  // If the source's MSB is zero then we know the rest of the bits already.
  if (isNonNegative())
    return *this;

  // Absolute value preserves trailing zero count.
  KnownBits KnownAbs(getBitWidth());

  // If the input is negative, then abs(x) == -x.
  if (isNegative()) {
    KnownBits Tmp = *this;
    // Special case for IntMinIsPoison. We know the sign bit is set and we know
    // all the rest of the bits except one to be zero. Since we have
    // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is
    // INT_MIN.
    // (Zero.popcount() + 2 == BitWidth: the sign bit is a known one, so exactly
    // one bit remains unknown, and it is the lowest non-known-zero bit.)
    if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth())
      Tmp.One.setBit(countMinTrailingZeros());

    // abs(x) = 0 - x; nsw when INT_MIN is poison.
    KnownAbs = computeForAddSub(
        /*Add*/ false, IntMinIsPoison, /*NUW=*/false,
        KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp);

    // One more special case for IntMinIsPoison. If we don't know any ones other
    // than the signbit, we know for certain that all the unknowns can't be
    // zero. So if we know high zero bits, but have unknown low bits, we know
    // for certain those high-zero bits will end up as one. This is because,
    // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up
    // to the high bits. If we know a known INT_MIN input skip this. The result
    // is poison anyways.
    if (IntMinIsPoison && Tmp.countMinPopulation() == 1 &&
        Tmp.countMaxPopulation() != 1) {
      // Drop the sign bit so countMinLeadingZeros() counts the known-zero run
      // just below it.
      Tmp.One.clearSignBit();
      Tmp.Zero.setSignBit();
      KnownAbs.One.setBits(getBitWidth() - Tmp.countMinLeadingZeros(),
                           getBitWidth() - 1);
    }

  } else {
    // Sign bit unknown: the value may be x or -x.
    unsigned MaxTZ = countMaxTrailingZeros();
    unsigned MinTZ = countMinTrailingZeros();

    KnownAbs.Zero.setLowBits(MinTZ);
    // If we know the lowest set 1, then preserve it.
    if (MaxTZ == MinTZ && MaxTZ < getBitWidth())
      KnownAbs.One.setBit(MaxTZ);

    // We only know that the absolute values's MSB will be zero if INT_MIN is
    // poison, or there is a set bit that isn't the sign bit (otherwise it could
    // be INT_MIN).
    if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) {
      KnownAbs.One.clearSignBit();
      KnownAbs.Zero.setSignBit();
    }
  }

  return KnownAbs;
}
607 
computeForSatAddSub(bool Add,bool Signed,const KnownBits & LHS,const KnownBits & RHS)608 static KnownBits computeForSatAddSub(bool Add, bool Signed,
609                                      const KnownBits &LHS,
610                                      const KnownBits &RHS) {
611   // We don't see NSW even for sadd/ssub as we want to check if the result has
612   // signed overflow.
613   KnownBits Res =
614       KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
615   unsigned BitWidth = Res.getBitWidth();
616   auto SignBitKnown = [&](const KnownBits &K) {
617     return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
618   };
619   std::optional<bool> Overflow;
620 
621   if (Signed) {
622     // If we can actually detect overflow do so. Otherwise leave Overflow as
623     // nullopt (we assume it may have happened).
624     if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
625       if (Add) {
626         // sadd.sat
627         Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
628                     Res.isNonNegative() != LHS.isNonNegative());
629       } else {
630         // ssub.sat
631         Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
632                     Res.isNonNegative() != LHS.isNonNegative());
633       }
634     }
635   } else if (Add) {
636     // uadd.sat
637     bool Of;
638     (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of);
639     if (!Of) {
640       Overflow = false;
641     } else {
642       (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of);
643       if (Of)
644         Overflow = true;
645     }
646   } else {
647     // usub.sat
648     bool Of;
649     (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of);
650     if (!Of) {
651       Overflow = false;
652     } else {
653       (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of);
654       if (Of)
655         Overflow = true;
656     }
657   }
658 
659   if (Signed) {
660     if (Add) {
661       if (LHS.isNonNegative() && RHS.isNonNegative()) {
662         // Pos + Pos -> Pos
663         Res.One.clearSignBit();
664         Res.Zero.setSignBit();
665       }
666       if (LHS.isNegative() && RHS.isNegative()) {
667         // Neg + Neg -> Neg
668         Res.One.setSignBit();
669         Res.Zero.clearSignBit();
670       }
671     } else {
672       if (LHS.isNegative() && RHS.isNonNegative()) {
673         // Neg - Pos -> Neg
674         Res.One.setSignBit();
675         Res.Zero.clearSignBit();
676       } else if (LHS.isNonNegative() && RHS.isNegative()) {
677         // Pos - Neg -> Pos
678         Res.One.clearSignBit();
679         Res.Zero.setSignBit();
680       }
681     }
682   } else {
683     // Add: Leading ones of either operand are preserved.
684     // Sub: Leading zeros of LHS and leading ones of RHS are preserved
685     // as leading zeros in the result.
686     unsigned LeadingKnown;
687     if (Add)
688       LeadingKnown =
689           std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
690     else
691       LeadingKnown =
692           std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
693 
694     // We select between the operation result and all-ones/zero
695     // respectively, so we can preserve known ones/zeros.
696     APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
697     if (Add) {
698       Res.One |= Mask;
699       Res.Zero &= ~Mask;
700     } else {
701       Res.Zero |= Mask;
702       Res.One &= ~Mask;
703     }
704   }
705 
706   if (Overflow) {
707     // We know whether or not we overflowed.
708     if (!(*Overflow)) {
709       // No overflow.
710       return Res;
711     }
712 
713     // We overflowed
714     APInt C;
715     if (Signed) {
716       // sadd.sat / ssub.sat
717       assert(SignBitKnown(LHS) &&
718              "We somehow know overflow without knowing input sign");
719       C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
720                            : APInt::getSignedMaxValue(BitWidth);
721     } else if (Add) {
722       // uadd.sat
723       C = APInt::getMaxValue(BitWidth);
724     } else {
725       // uadd.sat
726       C = APInt::getMinValue(BitWidth);
727     }
728 
729     Res.One = C;
730     Res.Zero = ~C;
731     return Res;
732   }
733 
734   // We don't know if we overflowed.
735   if (Signed) {
736     // sadd.sat/ssub.sat
737     // We can keep our information about the sign bits.
738     Res.Zero.clearLowBits(BitWidth - 1);
739     Res.One.clearLowBits(BitWidth - 1);
740   } else if (Add) {
741     // uadd.sat
742     // We need to clear all the known zeros as we can only use the leading ones.
743     Res.Zero.clearAllBits();
744   } else {
745     // usub.sat
746     // We need to clear all the known ones as we can only use the leading zero.
747     Res.One.clearAllBits();
748   }
749 
750   return Res;
751 }
752 
/// Known bits of llvm.sadd.sat(LHS, RHS).
KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
  constexpr bool IsAdd = true;
  constexpr bool IsSigned = true;
  return computeForSatAddSub(IsAdd, IsSigned, LHS, RHS);
}
/// Known bits of llvm.ssub.sat(LHS, RHS).
KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
  constexpr bool IsAdd = false;
  constexpr bool IsSigned = true;
  return computeForSatAddSub(IsAdd, IsSigned, LHS, RHS);
}
/// Known bits of llvm.uadd.sat(LHS, RHS).
KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
  constexpr bool IsAdd = true;
  constexpr bool IsSigned = false;
  return computeForSatAddSub(IsAdd, IsSigned, LHS, RHS);
}
/// Known bits of the unsigned saturating subtraction usub.sat(LHS, RHS).
KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
  const bool IsAdd = false;
  const bool IsSigned = false;
  return computeForSatAddSub(IsAdd, IsSigned, LHS, RHS);
}
765 
avgCompute(KnownBits LHS,KnownBits RHS,bool IsCeil,bool IsSigned)766 static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
767                             bool IsSigned) {
768   unsigned BitWidth = LHS.getBitWidth();
769   LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
770   RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
771   LHS =
772       computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
773   LHS = LHS.extractBits(BitWidth, 1);
774   return LHS;
775 }
776 
/// Known bits of the signed floor average floor((LHS + RHS) / 2).
KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
  const bool IsCeil = false;
  const bool IsSigned = true;
  return avgCompute(LHS, RHS, IsCeil, IsSigned);
}
781 
/// Known bits of the unsigned floor average floor((LHS + RHS) / 2).
KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
  const bool IsCeil = false;
  const bool IsSigned = false;
  return avgCompute(LHS, RHS, IsCeil, IsSigned);
}
786 
/// Known bits of the signed ceiling average ceil((LHS + RHS) / 2).
KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
  const bool IsCeil = true;
  const bool IsSigned = true;
  return avgCompute(LHS, RHS, IsCeil, IsSigned);
}
791 
/// Known bits of the unsigned ceiling average ceil((LHS + RHS) / 2).
KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
  const bool IsCeil = true;
  const bool IsSigned = false;
  return avgCompute(LHS, RHS, IsCeil, IsSigned);
}
796 
mul(const KnownBits & LHS,const KnownBits & RHS,bool NoUndefSelfMultiply)797 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
798                          bool NoUndefSelfMultiply) {
799   unsigned BitWidth = LHS.getBitWidth();
800   assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
801   assert((!NoUndefSelfMultiply || LHS == RHS) &&
802          "Self multiplication knownbits mismatch");
803 
804   // Compute the high known-0 bits by multiplying the unsigned max of each side.
805   // Conservatively, M active bits * N active bits results in M + N bits in the
806   // result. But if we know a value is a power-of-2 for example, then this
807   // computes one more leading zero.
808   // TODO: This could be generalized to number of sign bits (negative numbers).
809   APInt UMaxLHS = LHS.getMaxValue();
810   APInt UMaxRHS = RHS.getMaxValue();
811 
812   // For leading zeros in the result to be valid, the unsigned max product must
813   // fit in the bitwidth (it must not overflow).
814   bool HasOverflow;
815   APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
816   unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
817 
818   // The result of the bottom bits of an integer multiply can be
819   // inferred by looking at the bottom bits of both operands and
820   // multiplying them together.
821   // We can infer at least the minimum number of known trailing bits
822   // of both operands. Depending on number of trailing zeros, we can
823   // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
824   // a and b are divisible by m and n respectively.
825   // We then calculate how many of those bits are inferrable and set
826   // the output. For example, the i8 mul:
827   //  a = XXXX1100 (12)
828   //  b = XXXX1110 (14)
829   // We know the bottom 3 bits are zero since the first can be divided by
830   // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
831   // Applying the multiplication to the trimmed arguments gets:
832   //    XX11 (3)
833   //    X111 (7)
834   // -------
835   //    XX11
836   //   XX11
837   //  XX11
838   // XX11
839   // -------
840   // XXXXX01
841   // Which allows us to infer the 2 LSBs. Since we're multiplying the result
842   // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
843   // The proof for this can be described as:
844   // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
845   //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
846   //                    umin(countTrailingZeros(C2), C6) +
847   //                    umin(C5 - umin(countTrailingZeros(C1), C5),
848   //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
849   // %aa = shl i8 %a, C5
850   // %bb = shl i8 %b, C6
851   // %aaa = or i8 %aa, C1
852   // %bbb = or i8 %bb, C2
853   // %mul = mul i8 %aaa, %bbb
854   // %mask = and i8 %mul, C7
855   //   =>
856   // %mask = i8 ((C1*C2)&C7)
857   // Where C5, C6 describe the known bits of %a, %b
858   // C1, C2 describe the known bottom bits of %a, %b.
859   // C7 describes the mask of the known bits of the result.
860   const APInt &Bottom0 = LHS.One;
861   const APInt &Bottom1 = RHS.One;
862 
863   // How many times we'd be able to divide each argument by 2 (shr by 1).
864   // This gives us the number of trailing zeros on the multiplication result.
865   unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countr_one();
866   unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countr_one();
867   unsigned TrailZero0 = LHS.countMinTrailingZeros();
868   unsigned TrailZero1 = RHS.countMinTrailingZeros();
869   unsigned TrailZ = TrailZero0 + TrailZero1;
870 
871   // Figure out the fewest known-bits operand.
872   unsigned SmallestOperand =
873       std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1);
874   unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
875 
876   APInt BottomKnown =
877       Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1);
878 
879   KnownBits Res(BitWidth);
880   Res.Zero.setHighBits(LeadZ);
881   Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
882   Res.One = BottomKnown.getLoBits(ResultBitsKnown);
883 
884   // If we're self-multiplying then bit[1] is guaranteed to be zero.
885   if (NoUndefSelfMultiply && BitWidth > 1) {
886     assert(Res.One[1] == 0 &&
887            "Self-multiplication failed Quadratic Reciprocity!");
888     Res.Zero.setBit(1);
889   }
890 
891   return Res;
892 }
893 
/// Known bits of the high half of the signed product LHS * RHS.
///
/// Both operands are sign-extended to twice the width and multiplied. The
/// upper BitWidth bits of the full product are returned.
KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
  const unsigned DoubleWidth = 2 * BitWidth;
  const KnownBits FullProduct =
      mul(LHS.sext(DoubleWidth), RHS.sext(DoubleWidth));
  return FullProduct.extractBits(BitWidth, BitWidth);
}
901 
/// Known bits of the high half of the unsigned product LHS * RHS.
///
/// Both operands are zero-extended to twice the width and multiplied. The
/// upper BitWidth bits of the full product are returned.
KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
  const unsigned DoubleWidth = 2 * BitWidth;
  const KnownBits FullProduct =
      mul(LHS.zext(DoubleWidth), RHS.zext(DoubleWidth));
  return FullProduct.extractBits(BitWidth, BitWidth);
}
909 
divComputeLowBit(KnownBits Known,const KnownBits & LHS,const KnownBits & RHS,bool Exact)910 static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS,
911                                   const KnownBits &RHS, bool Exact) {
912 
913   if (!Exact)
914     return Known;
915 
916   // If LHS is Odd, the result is Odd no matter what.
917   // Odd / Odd -> Odd
918   // Odd / Even -> Impossible (because its exact division)
919   if (LHS.One[0])
920     Known.One.setBit(0);
921 
922   int MinTZ =
923       (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros();
924   int MaxTZ =
925       (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros();
926   if (MinTZ >= 0) {
927     // Result has at least MinTZ trailing zeros.
928     Known.Zero.setLowBits(MinTZ);
929     if (MinTZ == MaxTZ) {
930       // Result has exactly MinTZ trailing zeros.
931       Known.One.setBit(MinTZ);
932     }
933   } else if (MaxTZ < 0) {
934     // Poison Result
935     Known.setAllZero();
936   }
937 
938   // In the KnownBits exhaustive tests, we have poison inputs for exact values
939   // a LOT. If we have a conflict, just return all zeros.
940   if (Known.hasConflict())
941     Known.setAllZero();
942 
943   return Known;
944 }
945 
/// Known bits of the signed quotient LHS / RHS (rounding toward zero).
///
/// If both operands are known non-negative this is exactly udiv. Otherwise,
/// when the quotient's sign can be determined, the quotient of largest
/// magnitude is computed from the operand bounds. Its leading zeros (or
/// ones) hold for every possible quotient. If Exact is set, the low bits are
/// refined by divComputeLowBit.
KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
                          bool Exact) {
  // Signed and unsigned division agree on non-negative operands. Such an
  // sdiv would normally have been folded to udiv already.
  if (LHS.isNonNegative() && RHS.isNonNegative())
    return udiv(LHS, RHS, Exact);

  const unsigned BitWidth = LHS.getBitWidth();
  KnownBits Known(BitWidth);

  // A known-zero operand means the result is zero or UB: report zero. This
  // also removes several special cases below.
  if (LHS.isZero() || RHS.isZero()) {
    Known.setAllZero();
    return Known;
  }

  // Quotient with the largest magnitude, if its sign is determined.
  std::optional<APInt> Extreme;

  if (LHS.isNegative()) {
    if (RHS.isNegative()) {
      // negative / negative is non-negative. The largest quotient comes from
      // the most negative dividend and the divisor closest to zero.
      const APInt Num = LHS.getSignedMinValue();
      const APInt Denom = RHS.getSignedMaxValue();
      // INT_MIN / -1 would be poison. Use INT_MAX instead, which still pins
      // the sign bit to zero.
      if (Num.isMinSignedValue() && Denom.isAllOnes())
        Extreme = APInt::getSignedMaxValue(BitWidth);
      else
        Extreme = Num.sdiv(Denom);
    } else if (RHS.isNonNegative()) {
      // negative / non-negative is negative when Exact, or when -LHS u>= RHS
      // for every choice of operands. Otherwise it could truncate to zero.
      if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) {
        const APInt Num = LHS.getSignedMinValue();
        const APInt Denom = RHS.getSignedMinValue();
        // A zero divisor would be UB; use the dividend itself as the bound.
        Extreme = Denom.isZero() ? Num : Num.sdiv(Denom);
      }
    }
  } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
    // positive / negative is negative when Exact, or when LHS u>= -RHS for
    // every choice of operands.
    if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) {
      const APInt Num = LHS.getSignedMaxValue();
      const APInt Denom = RHS.getSignedMaxValue();
      Extreme = Num.sdiv(Denom);
    }
  }

  if (Extreme) {
    if (Extreme->isNonNegative())
      Known.Zero.setHighBits(Extreme->countLeadingZeros());
    else
      Known.One.setHighBits(Extreme->countLeadingOnes());
  }

  return divComputeLowBit(Known, LHS, RHS, Exact);
}
1001 
/// Known bits of the unsigned quotient LHS / RHS.
///
/// The quotient is at most UMax(LHS) / UMin(RHS), so it has at least as many
/// leading zeros as that bound. A smaller numerator or a larger denominator
/// only adds leading zeros. If Exact is set, the low bits are refined by
/// divComputeLowBit.
KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
                          bool Exact) {
  KnownBits Known(LHS.getBitWidth());

  // A known-zero operand means the result is zero or UB: report zero. This
  // also removes several special cases below.
  if (LHS.isZero() || RHS.isZero()) {
    Known.setAllZero();
    return Known;
  }

  const APInt MaxNum = LHS.getMaxValue();
  const APInt MinDenom = RHS.getMinValue();
  // A divisor that may be zero is treated as 1 for the bound.
  const APInt MaxQuotient = MinDenom.isZero() ? MaxNum : MaxNum.udiv(MinDenom);
  Known.Zero.setHighBits(MaxQuotient.countLeadingZeros());

  return divComputeLowBit(Known, LHS, RHS, Exact);
}
1028 
/// Low bits of a remainder (urem or srem) that follow from the operands alone.
///
/// If RHS has N >= 1 guaranteed trailing zeros, RHS is a multiple of 2^N and
/// the remainder keeps the low N bits of LHS. Otherwise nothing is known.
KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();

  // Needs RHS not known to be zero, with bit 0 known clear.
  if (RHS.isZero() || !RHS.Zero[0])
    return KnownBits(BitWidth);

  const APInt LowMask =
      APInt::getLowBitsSet(BitWidth, RHS.countMinTrailingZeros());
  return KnownBits(LHS.Zero & LowMask, LHS.One & LowMask);
}
1041 
/// Known bits of the unsigned remainder LHS urem RHS.
KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
  KnownBits Known = remGetLowBits(LHS, RHS);

  // urem by a constant 2^K keeps only the low K bits. Those were already
  // filled in by remGetLowBits; every bit above them is zero.
  if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
    const APInt LowBits = RHS.getConstant() - 1;
    Known.Zero |= ~LowBits;
    return Known;
  }

  // The remainder never exceeds either operand, so the leading zeros of
  // either operand carry over to the result.
  Known.Zero.setHighBits(
      std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros()));
  return Known;
}
1058 
/// Known bits of the signed remainder LHS srem RHS.
KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
  KnownBits Known = remGetLowBits(LHS, RHS);

  if (!(RHS.isConstant() && RHS.getConstant().isPowerOf2())) {
    // The result takes LHS's sign unless it is zero, and its magnitude is at
    // most that of LHS. So LHS's leading zeros also hold for the result.
    Known.Zero.setHighBits(LHS.countMinLeadingZeros());
    return Known;
  }

  // srem by a constant 2^K: the low K bits were filled in by remGetLowBits.
  // The remaining high bits are all zero or all one, depending on the sign.
  const APInt LowBits = RHS.getConstant() - 1;
  const APInt HighBits = ~LowBits;

  // Non-negative LHS, or LHS with all low bits known zero: the upper bits
  // are zero.
  if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero))
    Known.Zero |= HighBits;

  // Negative LHS with some low bit known set: the upper bits are one.
  if (LHS.isNegative() && LowBits.intersects(LHS.One))
    Known.One |= HighBits;

  return Known;
}
1083 
/// Known bits of the bitwise AND, accumulated into *this.
KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
  // A result bit is one only if both inputs are one...
  One &= RHS.One;
  // ...and zero as soon as either input is zero.
  Zero |= RHS.Zero;
  return *this;
}
1091 
/// Known bits of the bitwise OR, accumulated into *this.
KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
  // A result bit is one as soon as either input is one...
  One |= RHS.One;
  // ...and zero only if both inputs are zero.
  Zero &= RHS.Zero;
  return *this;
}
1099 
/// Known bits of the bitwise XOR, accumulated into *this.
KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
  // Compute both new masks from the old state before assigning either.
  // This keeps the operator correct when RHS aliases *this.
  // Differing known inputs give a one; equal known inputs give a zero.
  APInt NewOne = (Zero & RHS.One) | (One & RHS.Zero);
  APInt NewZero = (Zero & RHS.Zero) | (One & RHS.One);
  Zero = std::move(NewZero);
  One = std::move(NewOne);
  return *this;
}
1108 
blsi() const1109 KnownBits KnownBits::blsi() const {
1110   unsigned BitWidth = getBitWidth();
1111   KnownBits Known(Zero, APInt(BitWidth, 0));
1112   unsigned Max = countMaxTrailingZeros();
1113   Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
1114   unsigned Min = countMinTrailingZeros();
1115   if (Max == Min && Max < BitWidth)
1116     Known.One.setBit(Max);
1117   return Known;
1118 }
1119 
blsmsk() const1120 KnownBits KnownBits::blsmsk() const {
1121   unsigned BitWidth = getBitWidth();
1122   KnownBits Known(BitWidth);
1123   unsigned Max = countMaxTrailingZeros();
1124   Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
1125   unsigned Min = countMinTrailingZeros();
1126   Known.One.setLowBits(std::min(Min + 1, BitWidth));
1127   return Known;
1128 }
1129 
print(raw_ostream & OS) const1130 void KnownBits::print(raw_ostream &OS) const {
1131   unsigned BitWidth = getBitWidth();
1132   for (unsigned I = 0; I < BitWidth; ++I) {
1133     unsigned N = BitWidth - I - 1;
1134     if (Zero[N] && One[N])
1135       OS << "!";
1136     else if (Zero[N])
1137       OS << "0";
1138     else if (One[N])
1139       OS << "1";
1140     else
1141       OS << "?";
1142   }
1143 }
dump() const1144 void KnownBits::dump() const {
1145   print(dbgs());
1146   dbgs() << "\n";
1147 }
1148