xref: /freebsd/contrib/llvm-project/llvm/lib/Support/KnownBits.cpp (revision 43e29d03f416d7dda52112a29600a7c82ee1a91e)
1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a class for representing known zeros and ones used by
10 // computeKnownBits.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Support/KnownBits.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include <cassert>
18 
19 using namespace llvm;
20 
21 static KnownBits computeForAddCarry(
22     const KnownBits &LHS, const KnownBits &RHS,
23     bool CarryZero, bool CarryOne) {
24   assert(!(CarryZero && CarryOne) &&
25          "Carry can't be zero and one at the same time");
26 
27   APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
28   APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne;
29 
30   // Compute known bits of the carry.
31   APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
32   APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
33 
34   // Compute set of known bits (where all three relevant bits are known).
35   APInt LHSKnownUnion = LHS.Zero | LHS.One;
36   APInt RHSKnownUnion = RHS.Zero | RHS.One;
37   APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne;
38   APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion;
39 
40   assert((PossibleSumZero & Known) == (PossibleSumOne & Known) &&
41          "known bits of sum differ");
42 
43   // Compute known bits of the result.
44   KnownBits KnownOut;
45   KnownOut.Zero = ~std::move(PossibleSumZero) & Known;
46   KnownOut.One = std::move(PossibleSumOne) & Known;
47   return KnownOut;
48 }
49 
50 KnownBits KnownBits::computeForAddCarry(
51     const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) {
52   assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
53   return ::computeForAddCarry(
54       LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
55 }
56 
57 KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
58                                       const KnownBits &LHS, KnownBits RHS) {
59   KnownBits KnownOut;
60   if (Add) {
61     // Sum = LHS + RHS + 0
62     KnownOut = ::computeForAddCarry(
63         LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
64   } else {
65     // Sum = LHS + ~RHS + 1
66     std::swap(RHS.Zero, RHS.One);
67     KnownOut = ::computeForAddCarry(
68         LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
69   }
70 
71   // Are we still trying to solve for the sign bit?
72   if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
73     if (NSW) {
74       // Adding two non-negative numbers, or subtracting a negative number from
75       // a non-negative one, can't wrap into negative.
76       if (LHS.isNonNegative() && RHS.isNonNegative())
77         KnownOut.makeNonNegative();
78       // Adding two negative numbers, or subtracting a non-negative number from
79       // a negative one, can't wrap into non-negative.
80       else if (LHS.isNegative() && RHS.isNegative())
81         KnownOut.makeNegative();
82     }
83   }
84 
85   return KnownOut;
86 }
87 
88 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
89   unsigned BitWidth = getBitWidth();
90   assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
91          "Illegal sext-in-register");
92 
93   if (SrcBitWidth == BitWidth)
94     return *this;
95 
96   unsigned ExtBits = BitWidth - SrcBitWidth;
97   KnownBits Result;
98   Result.One = One << ExtBits;
99   Result.Zero = Zero << ExtBits;
100   Result.One.ashrInPlace(ExtBits);
101   Result.Zero.ashrInPlace(ExtBits);
102   return Result;
103 }
104 
105 KnownBits KnownBits::makeGE(const APInt &Val) const {
106   // Count the number of leading bit positions where our underlying value is
107   // known to be less than or equal to Val.
108   unsigned N = (Zero | Val).countLeadingOnes();
109 
110   // For each of those bit positions, if Val has a 1 in that bit then our
111   // underlying value must also have a 1.
112   APInt MaskedVal(Val);
113   MaskedVal.clearLowBits(getBitWidth() - N);
114   return KnownBits(Zero, One | MaskedVal);
115 }
116 
117 KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
118   // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
119   // RHS. Ideally our caller would already have spotted these cases and
120   // optimized away the umax operation, but we handle them here for
121   // completeness.
122   if (LHS.getMinValue().uge(RHS.getMaxValue()))
123     return LHS;
124   if (RHS.getMinValue().uge(LHS.getMaxValue()))
125     return RHS;
126 
127   // If the result of the umax is LHS then it must be greater than or equal to
128   // the minimum possible value of RHS. Likewise for RHS. Any known bits that
129   // are common to these two values are also known in the result.
130   KnownBits L = LHS.makeGE(RHS.getMinValue());
131   KnownBits R = RHS.makeGE(LHS.getMinValue());
132   return KnownBits::commonBits(L, R);
133 }
134 
135 KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
136   // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
137   auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); };
138   return Flip(umax(Flip(LHS), Flip(RHS)));
139 }
140 
141 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
142   // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
143   auto Flip = [](const KnownBits &Val) {
144     unsigned SignBitPosition = Val.getBitWidth() - 1;
145     APInt Zero = Val.Zero;
146     APInt One = Val.One;
147     Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
148     One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
149     return KnownBits(Zero, One);
150   };
151   return Flip(umax(Flip(LHS), Flip(RHS)));
152 }
153 
154 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
155   // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
156   auto Flip = [](const KnownBits &Val) {
157     unsigned SignBitPosition = Val.getBitWidth() - 1;
158     APInt Zero = Val.One;
159     APInt One = Val.Zero;
160     Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
161     One.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
162     return KnownBits(Zero, One);
163   };
164   return Flip(umax(Flip(LHS), Flip(RHS)));
165 }
166 
167 KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
168   unsigned BitWidth = LHS.getBitWidth();
169   KnownBits Known(BitWidth);
170 
171   // If the shift amount is a valid constant then transform LHS directly.
172   if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
173     unsigned Shift = RHS.getConstant().getZExtValue();
174     Known = LHS;
175     Known.Zero <<= Shift;
176     Known.One <<= Shift;
177     // Low bits are known zero.
178     Known.Zero.setLowBits(Shift);
179     return Known;
180   }
181 
182   // No matter the shift amount, the trailing zeros will stay zero.
183   unsigned MinTrailingZeros = LHS.countMinTrailingZeros();
184 
185   // Minimum shift amount low bits are known zero.
186   APInt MinShiftAmount = RHS.getMinValue();
187   if (MinShiftAmount.ult(BitWidth)) {
188     MinTrailingZeros += MinShiftAmount.getZExtValue();
189     MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
190   }
191 
192   // If the maximum shift is in range, then find the common bits from all
193   // possible shifts.
194   APInt MaxShiftAmount = RHS.getMaxValue();
195   if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
196     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
197     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
198     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
199     Known.Zero.setAllBits();
200     Known.One.setAllBits();
201     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
202                   MaxShiftAmt = MaxShiftAmount.getZExtValue();
203          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
204       // Skip if the shift amount is impossible.
205       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
206           (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
207         continue;
208       KnownBits SpecificShift;
209       SpecificShift.Zero = LHS.Zero << ShiftAmt;
210       SpecificShift.One = LHS.One << ShiftAmt;
211       Known = KnownBits::commonBits(Known, SpecificShift);
212       if (Known.isUnknown())
213         break;
214     }
215   }
216 
217   Known.Zero.setLowBits(MinTrailingZeros);
218   return Known;
219 }
220 
221 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) {
222   unsigned BitWidth = LHS.getBitWidth();
223   KnownBits Known(BitWidth);
224 
225   if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
226     unsigned Shift = RHS.getConstant().getZExtValue();
227     Known = LHS;
228     Known.Zero.lshrInPlace(Shift);
229     Known.One.lshrInPlace(Shift);
230     // High bits are known zero.
231     Known.Zero.setHighBits(Shift);
232     return Known;
233   }
234 
235   // No matter the shift amount, the leading zeros will stay zero.
236   unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
237 
238   // Minimum shift amount high bits are known zero.
239   APInt MinShiftAmount = RHS.getMinValue();
240   if (MinShiftAmount.ult(BitWidth)) {
241     MinLeadingZeros += MinShiftAmount.getZExtValue();
242     MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
243   }
244 
245   // If the maximum shift is in range, then find the common bits from all
246   // possible shifts.
247   APInt MaxShiftAmount = RHS.getMaxValue();
248   if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
249     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
250     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
251     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
252     Known.Zero.setAllBits();
253     Known.One.setAllBits();
254     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
255                   MaxShiftAmt = MaxShiftAmount.getZExtValue();
256          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
257       // Skip if the shift amount is impossible.
258       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
259           (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
260         continue;
261       KnownBits SpecificShift = LHS;
262       SpecificShift.Zero.lshrInPlace(ShiftAmt);
263       SpecificShift.One.lshrInPlace(ShiftAmt);
264       Known = KnownBits::commonBits(Known, SpecificShift);
265       if (Known.isUnknown())
266         break;
267     }
268   }
269 
270   Known.Zero.setHighBits(MinLeadingZeros);
271   return Known;
272 }
273 
274 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) {
275   unsigned BitWidth = LHS.getBitWidth();
276   KnownBits Known(BitWidth);
277 
278   if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
279     unsigned Shift = RHS.getConstant().getZExtValue();
280     Known = LHS;
281     Known.Zero.ashrInPlace(Shift);
282     Known.One.ashrInPlace(Shift);
283     return Known;
284   }
285 
286   // No matter the shift amount, the leading sign bits will stay.
287   unsigned MinLeadingZeros = LHS.countMinLeadingZeros();
288   unsigned MinLeadingOnes = LHS.countMinLeadingOnes();
289 
290   // Minimum shift amount high bits are known sign bits.
291   APInt MinShiftAmount = RHS.getMinValue();
292   if (MinShiftAmount.ult(BitWidth)) {
293     if (MinLeadingZeros) {
294       MinLeadingZeros += MinShiftAmount.getZExtValue();
295       MinLeadingZeros = std::min(MinLeadingZeros, BitWidth);
296     }
297     if (MinLeadingOnes) {
298       MinLeadingOnes += MinShiftAmount.getZExtValue();
299       MinLeadingOnes = std::min(MinLeadingOnes, BitWidth);
300     }
301   }
302 
303   // If the maximum shift is in range, then find the common bits from all
304   // possible shifts.
305   APInt MaxShiftAmount = RHS.getMaxValue();
306   if (MaxShiftAmount.ult(BitWidth) && !LHS.isUnknown()) {
307     uint64_t ShiftAmtZeroMask = (~RHS.Zero).getZExtValue();
308     uint64_t ShiftAmtOneMask = RHS.One.getZExtValue();
309     assert(MinShiftAmount.ult(MaxShiftAmount) && "Illegal shift range");
310     Known.Zero.setAllBits();
311     Known.One.setAllBits();
312     for (uint64_t ShiftAmt = MinShiftAmount.getZExtValue(),
313                   MaxShiftAmt = MaxShiftAmount.getZExtValue();
314          ShiftAmt <= MaxShiftAmt; ++ShiftAmt) {
315       // Skip if the shift amount is impossible.
316       if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
317           (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
318         continue;
319       KnownBits SpecificShift = LHS;
320       SpecificShift.Zero.ashrInPlace(ShiftAmt);
321       SpecificShift.One.ashrInPlace(ShiftAmt);
322       Known = KnownBits::commonBits(Known, SpecificShift);
323       if (Known.isUnknown())
324         break;
325     }
326   }
327 
328   Known.Zero.setHighBits(MinLeadingZeros);
329   Known.One.setHighBits(MinLeadingOnes);
330   return Known;
331 }
332 
333 std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
334   if (LHS.isConstant() && RHS.isConstant())
335     return std::optional<bool>(LHS.getConstant() == RHS.getConstant());
336   if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero))
337     return std::optional<bool>(false);
338   return std::nullopt;
339 }
340 
341 std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
342   if (std::optional<bool> KnownEQ = eq(LHS, RHS))
343     return std::optional<bool>(!*KnownEQ);
344   return std::nullopt;
345 }
346 
347 std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
348   // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
349   if (LHS.getMaxValue().ule(RHS.getMinValue()))
350     return std::optional<bool>(false);
351   // LHS >u RHS -> true if umin(LHS) > umax(RHS)
352   if (LHS.getMinValue().ugt(RHS.getMaxValue()))
353     return std::optional<bool>(true);
354   return std::nullopt;
355 }
356 
357 std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
358   if (std::optional<bool> IsUGT = ugt(RHS, LHS))
359     return std::optional<bool>(!*IsUGT);
360   return std::nullopt;
361 }
362 
363 std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
364   return ugt(RHS, LHS);
365 }
366 
367 std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
368   return uge(RHS, LHS);
369 }
370 
371 std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
372   // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
373   if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue()))
374     return std::optional<bool>(false);
375   // LHS >s RHS -> true if smin(LHS) > smax(RHS)
376   if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue()))
377     return std::optional<bool>(true);
378   return std::nullopt;
379 }
380 
381 std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
382   if (std::optional<bool> KnownSGT = sgt(RHS, LHS))
383     return std::optional<bool>(!*KnownSGT);
384   return std::nullopt;
385 }
386 
387 std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
388   return sgt(RHS, LHS);
389 }
390 
391 std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
392   return sge(RHS, LHS);
393 }
394 
395 KnownBits KnownBits::abs(bool IntMinIsPoison) const {
396   // If the source's MSB is zero then we know the rest of the bits already.
397   if (isNonNegative())
398     return *this;
399 
400   // Absolute value preserves trailing zero count.
401   KnownBits KnownAbs(getBitWidth());
402   KnownAbs.Zero.setLowBits(countMinTrailingZeros());
403 
404   // We only know that the absolute values's MSB will be zero if INT_MIN is
405   // poison, or there is a set bit that isn't the sign bit (otherwise it could
406   // be INT_MIN).
407   if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue()))
408     KnownAbs.Zero.setSignBit();
409 
410   // FIXME: Handle known negative input?
411   // FIXME: Calculate the negated Known bits and combine them?
412   return KnownAbs;
413 }
414 
415 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
416                          bool NoUndefSelfMultiply) {
417   unsigned BitWidth = LHS.getBitWidth();
418   assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
419          !RHS.hasConflict() && "Operand mismatch");
420   assert((!NoUndefSelfMultiply || LHS == RHS) &&
421          "Self multiplication knownbits mismatch");
422 
423   // Compute the high known-0 bits by multiplying the unsigned max of each side.
424   // Conservatively, M active bits * N active bits results in M + N bits in the
425   // result. But if we know a value is a power-of-2 for example, then this
426   // computes one more leading zero.
427   // TODO: This could be generalized to number of sign bits (negative numbers).
428   APInt UMaxLHS = LHS.getMaxValue();
429   APInt UMaxRHS = RHS.getMaxValue();
430 
431   // For leading zeros in the result to be valid, the unsigned max product must
432   // fit in the bitwidth (it must not overflow).
433   bool HasOverflow;
434   APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow);
435   unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countLeadingZeros();
436 
437   // The result of the bottom bits of an integer multiply can be
438   // inferred by looking at the bottom bits of both operands and
439   // multiplying them together.
440   // We can infer at least the minimum number of known trailing bits
441   // of both operands. Depending on number of trailing zeros, we can
442   // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
443   // a and b are divisible by m and n respectively.
444   // We then calculate how many of those bits are inferrable and set
445   // the output. For example, the i8 mul:
446   //  a = XXXX1100 (12)
447   //  b = XXXX1110 (14)
448   // We know the bottom 3 bits are zero since the first can be divided by
449   // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
450   // Applying the multiplication to the trimmed arguments gets:
451   //    XX11 (3)
452   //    X111 (7)
453   // -------
454   //    XX11
455   //   XX11
456   //  XX11
457   // XX11
458   // -------
459   // XXXXX01
460   // Which allows us to infer the 2 LSBs. Since we're multiplying the result
461   // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
462   // The proof for this can be described as:
463   // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
464   //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
465   //                    umin(countTrailingZeros(C2), C6) +
466   //                    umin(C5 - umin(countTrailingZeros(C1), C5),
467   //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
468   // %aa = shl i8 %a, C5
469   // %bb = shl i8 %b, C6
470   // %aaa = or i8 %aa, C1
471   // %bbb = or i8 %bb, C2
472   // %mul = mul i8 %aaa, %bbb
473   // %mask = and i8 %mul, C7
474   //   =>
475   // %mask = i8 ((C1*C2)&C7)
476   // Where C5, C6 describe the known bits of %a, %b
477   // C1, C2 describe the known bottom bits of %a, %b.
478   // C7 describes the mask of the known bits of the result.
479   const APInt &Bottom0 = LHS.One;
480   const APInt &Bottom1 = RHS.One;
481 
482   // How many times we'd be able to divide each argument by 2 (shr by 1).
483   // This gives us the number of trailing zeros on the multiplication result.
484   unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countTrailingOnes();
485   unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countTrailingOnes();
486   unsigned TrailZero0 = LHS.countMinTrailingZeros();
487   unsigned TrailZero1 = RHS.countMinTrailingZeros();
488   unsigned TrailZ = TrailZero0 + TrailZero1;
489 
490   // Figure out the fewest known-bits operand.
491   unsigned SmallestOperand =
492       std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1);
493   unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
494 
495   APInt BottomKnown =
496       Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1);
497 
498   KnownBits Res(BitWidth);
499   Res.Zero.setHighBits(LeadZ);
500   Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
501   Res.One = BottomKnown.getLoBits(ResultBitsKnown);
502 
503   // If we're self-multiplying then bit[1] is guaranteed to be zero.
504   if (NoUndefSelfMultiply && BitWidth > 1) {
505     assert(Res.One[1] == 0 &&
506            "Self-multiplication failed Quadratic Reciprocity!");
507     Res.Zero.setBit(1);
508   }
509 
510   return Res;
511 }
512 
513 KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
514   unsigned BitWidth = LHS.getBitWidth();
515   assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
516          !RHS.hasConflict() && "Operand mismatch");
517   KnownBits WideLHS = LHS.sext(2 * BitWidth);
518   KnownBits WideRHS = RHS.sext(2 * BitWidth);
519   return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
520 }
521 
522 KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
523   unsigned BitWidth = LHS.getBitWidth();
524   assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() &&
525          !RHS.hasConflict() && "Operand mismatch");
526   KnownBits WideLHS = LHS.zext(2 * BitWidth);
527   KnownBits WideRHS = RHS.zext(2 * BitWidth);
528   return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth);
529 }
530 
531 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS) {
532   unsigned BitWidth = LHS.getBitWidth();
533   assert(!LHS.hasConflict() && !RHS.hasConflict());
534   KnownBits Known(BitWidth);
535 
536   // For the purposes of computing leading zeros we can conservatively
537   // treat a udiv as a logical right shift by the power of 2 known to
538   // be less than the denominator.
539   unsigned LeadZ = LHS.countMinLeadingZeros();
540   unsigned RHSMaxLeadingZeros = RHS.countMaxLeadingZeros();
541 
542   if (RHSMaxLeadingZeros != BitWidth)
543     LeadZ = std::min(BitWidth, LeadZ + BitWidth - RHSMaxLeadingZeros - 1);
544 
545   Known.Zero.setHighBits(LeadZ);
546   return Known;
547 }
548 
549 KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
550   unsigned BitWidth = LHS.getBitWidth();
551   assert(!LHS.hasConflict() && !RHS.hasConflict());
552   KnownBits Known(BitWidth);
553 
554   if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
555     // The upper bits are all zero, the lower ones are unchanged.
556     APInt LowBits = RHS.getConstant() - 1;
557     Known.Zero = LHS.Zero | ~LowBits;
558     Known.One = LHS.One & LowBits;
559     return Known;
560   }
561 
562   // Since the result is less than or equal to either operand, any leading
563   // zero bits in either operand must also exist in the result.
564   uint32_t Leaders =
565       std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros());
566   Known.Zero.setHighBits(Leaders);
567   return Known;
568 }
569 
570 KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
571   unsigned BitWidth = LHS.getBitWidth();
572   assert(!LHS.hasConflict() && !RHS.hasConflict());
573   KnownBits Known(BitWidth);
574 
575   if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
576     // The low bits of the first operand are unchanged by the srem.
577     APInt LowBits = RHS.getConstant() - 1;
578     Known.Zero = LHS.Zero & LowBits;
579     Known.One = LHS.One & LowBits;
580 
581     // If the first operand is non-negative or has all low bits zero, then
582     // the upper bits are all zero.
583     if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero))
584       Known.Zero |= ~LowBits;
585 
586     // If the first operand is negative and not all low bits are zero, then
587     // the upper bits are all one.
588     if (LHS.isNegative() && LowBits.intersects(LHS.One))
589       Known.One |= ~LowBits;
590     return Known;
591   }
592 
593   // The sign bit is the LHS's sign bit, except when the result of the
594   // remainder is zero. The magnitude of the result should be less than or
595   // equal to the magnitude of the LHS. Therefore any leading zeros that exist
596   // in the left hand side must also exist in the result.
597   Known.Zero.setHighBits(LHS.countMinLeadingZeros());
598   return Known;
599 }
600 
601 KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
602   // Result bit is 0 if either operand bit is 0.
603   Zero |= RHS.Zero;
604   // Result bit is 1 if both operand bits are 1.
605   One &= RHS.One;
606   return *this;
607 }
608 
609 KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
610   // Result bit is 0 if both operand bits are 0.
611   Zero &= RHS.Zero;
612   // Result bit is 1 if either operand bit is 1.
613   One |= RHS.One;
614   return *this;
615 }
616 
617 KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
618   // Result bit is 0 if both operand bits are 0 or both are 1.
619   APInt Z = (Zero & RHS.Zero) | (One & RHS.One);
620   // Result bit is 1 if one operand bit is 0 and the other is 1.
621   One = (Zero & RHS.One) | (One & RHS.Zero);
622   Zero = std::move(Z);
623   return *this;
624 }
625 
626 void KnownBits::print(raw_ostream &OS) const {
627   OS << "{Zero=" << Zero << ", One=" << One << "}";
628 }
629 void KnownBits::dump() const {
630   print(dbgs());
631   dbgs() << "\n";
632 }
633