1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a class for representing known zeros and ones used by
10 // computeKnownBits.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Support/KnownBits.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/raw_ostream.h"
17 #include <cassert>
18
19 using namespace llvm;
20
/// Compute the known bits of LHS + RHS + CarryIn for a carry-in described by
/// two flags: \p CarryZero means the carry-in is known to be 0, \p CarryOne
/// means it is known to be 1; with neither set the carry-in is unknown.
static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS,
                                    bool CarryZero, bool CarryOne) {

  // The two extreme sums: the largest values consistent with the known bits
  // (plus the carry-in unless it is known 0) and the smallest values (plus the
  // carry-in only if it is known 1).
  APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
  APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne;

  // Compute known bits of the carry.
  // At every position sum = lhs ^ rhs ^ carry, so XOR-ing an extreme sum with
  // the operand masks recovers the carry into each position of that sum. A
  // carry of 0 in the largest sum (or of 1 in the smallest sum) therefore
  // holds for every possible sum.
  APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
  APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;

  // Compute set of known bits (where all three relevant bits are known).
  APInt LHSKnownUnion = LHS.Zero | LHS.One;
  APInt RHSKnownUnion = RHS.Zero | RHS.One;
  APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne;
  APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion;

  // Compute known bits of the result.
  // On the positions in Known the sum bit is fully determined, so it can be
  // read from either extreme sum (zeros from the largest, ones from the
  // smallest).
  KnownBits KnownOut;
  KnownOut.Zero = ~std::move(PossibleSumZero) & Known;
  KnownOut.One = std::move(PossibleSumOne) & Known;
  return KnownOut;
}
43
/// Known bits of LHS + RHS + Carry, where Carry is a 1-bit KnownBits value.
KnownBits KnownBits::computeForAddCarry(
    const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) {
  assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
  // Translate the single carry bit into "known 0" / "known 1" flags for the
  // file-local helper.
  const bool CarryIsZero = Carry.Zero.getBoolValue();
  const bool CarryIsOne = Carry.One.getBoolValue();
  return ::computeForAddCarry(LHS, RHS, CarryIsZero, CarryIsOne);
}
50
/// Compute known bits for LHS + RHS (\p Add) or LHS - RHS (!\p Add).
///
/// \p NSW / \p NUW state that the operation does not wrap in the signed /
/// unsigned sense; the extra facts derived from them only hold under that
/// assumption. If the refined result ends up with conflicting bits (the flags
/// cannot hold, so the result is poison), all-zero known bits are returned.
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
                                      const KnownBits &LHS,
                                      const KnownBits &RHS) {
  unsigned BitWidth = LHS.getBitWidth();
  KnownBits KnownOut(BitWidth);
  // This can be a relatively expensive helper, so optimistically save some
  // work.
  if (LHS.isUnknown() && RHS.isUnknown())
    return KnownOut;

  // Bit-level propagation is only attempted when both operands carry some
  // information; otherwise only the flag-based refinements below apply.
  if (!LHS.isUnknown() && !RHS.isUnknown()) {
    if (Add) {
      // Sum = LHS + RHS + 0
      KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true,
                                      /*CarryOne=*/false);
    } else {
      // Sum = LHS + ~RHS + 1
      KnownBits NotRHS = RHS;
      std::swap(NotRHS.Zero, NotRHS.One);
      KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero=*/false,
                                      /*CarryOne=*/true);
    }
  }

  // Handle add/sub given nsw and/or nuw.
  if (NUW) {
    if (Add) {
      // (add nuw X, Y)
      // The result is unsigned >= MinVal, so MinVal's leading ones are set.
      APInt MinVal = LHS.getMinValue().uadd_sat(RHS.getMinValue());
      // None of the adds can end up overflowing, so min consecutive highbits
      // in minimum possible of X + Y must all remain set.
      if (NSW) {
        unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
        // If we have NSW as well, we also know we can't overflow the signbit so
        // can start counting from 1 bit back.
        KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      }
      KnownOut.One.setHighBits(MinVal.countl_one());
    } else {
      // (sub nuw X, Y)
      // The result is unsigned <= MaxVal, so MaxVal's leading zeros are clear.
      APInt MaxVal = LHS.getMaxValue().usub_sat(RHS.getMinValue());
      // None of the subs can overflow at any point, so any common high bits
      // will subtract away and result in zeros.
      if (NSW) {
        // If we have NSW as well, we also know we can't overflow the signbit so
        // can start counting from 1 bit back.
        unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
        KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      }
      KnownOut.Zero.setHighBits(MaxVal.countl_zero());
    }
  }

  if (NSW) {
    // Signed range [MinVal, MaxVal] of the result, clamped with saturating
    // arithmetic.
    APInt MinVal;
    APInt MaxVal;
    if (Add) {
      // (add nsw X, Y)
      MinVal = LHS.getSignedMinValue().sadd_sat(RHS.getSignedMinValue());
      MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS.getSignedMaxValue());
    } else {
      // (sub nsw X, Y)
      MinVal = LHS.getSignedMinValue().ssub_sat(RHS.getSignedMaxValue());
      MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS.getSignedMinValue());
    }
    if (MinVal.isNonNegative()) {
      // If min is non-negative, result will always be non-neg (can't overflow
      // around).
      unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
      KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      KnownOut.Zero.setSignBit();
    }
    if (MaxVal.isNegative()) {
      // If max is negative, result will always be neg (can't overflow around).
      unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
      KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
      KnownOut.One.setSignBit();
    }
  }

  // Just return 0 if the nsw/nuw is violated and we have poison.
  if (KnownOut.hasConflict())
    KnownOut.setAllZero();
  return KnownOut;
}
136
/// Known bits of LHS - RHS - Borrow, where Borrow is a 1-bit KnownBits value.
KnownBits KnownBits::computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
                                         const KnownBits &Borrow) {
  assert(Borrow.getBitWidth() == 1 && "Borrow must be 1-bit");

  // Rewrite as an addition: LHS - RHS - Borrow == LHS + ~RHS + (1 - Borrow).
  // The carry-in is the inverted borrow: a known-zero borrow is a known-one
  // carry and vice versa.
  const bool CarryInKnownZero = Borrow.One.getBoolValue();
  const bool CarryInKnownOne = Borrow.Zero.getBoolValue();

  // Inverting RHS exchanges its known zeros and known ones (RHS is our copy).
  std::swap(RHS.Zero, RHS.One);
  return ::computeForAddCarry(LHS, RHS, CarryInKnownZero, CarryInKnownOne);
}
148
sextInReg(unsigned SrcBitWidth) const149 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
150 unsigned BitWidth = getBitWidth();
151 assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
152 "Illegal sext-in-register");
153
154 if (SrcBitWidth == BitWidth)
155 return *this;
156
157 unsigned ExtBits = BitWidth - SrcBitWidth;
158 KnownBits Result;
159 Result.One = One << ExtBits;
160 Result.Zero = Zero << ExtBits;
161 Result.One.ashrInPlace(ExtBits);
162 Result.Zero.ashrInPlace(ExtBits);
163 return Result;
164 }
165
makeGE(const APInt & Val) const166 KnownBits KnownBits::makeGE(const APInt &Val) const {
167 // Count the number of leading bit positions where our underlying value is
168 // known to be less than or equal to Val.
169 unsigned N = (Zero | Val).countl_one();
170
171 // For each of those bit positions, if Val has a 1 in that bit then our
172 // underlying value must also have a 1.
173 APInt MaskedVal(Val);
174 MaskedVal.clearLowBits(getBitWidth() - N);
175 return KnownBits(Zero, One | MaskedVal);
176 }
177
/// Known bits of the unsigned maximum of two values.
KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
  // If one side is provably >= the other, it is the result. Callers normally
  // fold such umax operations away; this handles them for completeness.
  if (LHS.getMinValue().uge(RHS.getMaxValue()))
    return LHS;
  if (RHS.getMinValue().uge(LHS.getMaxValue()))
    return RHS;

  // Whichever operand wins is at least the other operand's minimum. Refine
  // each side with that bound and keep only the bits both candidates share.
  KnownBits LHSIfChosen = LHS.makeGE(RHS.getMinValue());
  KnownBits RHSIfChosen = RHS.makeGE(LHS.getMinValue());
  return LHSIfChosen.intersectWith(RHSIfChosen);
}
195
/// Known bits of the unsigned minimum of two values.
KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
  // umin(a, b) == ~umax(~a, ~b). Bitwise-inverting a KnownBits just exchanges
  // its known-zero and known-one masks.
  auto Invert = [](const KnownBits &K) { return KnownBits(K.One, K.Zero); };
  KnownBits InvertedMax = umax(Invert(LHS), Invert(RHS));
  return Invert(InvertedMax);
}
201
/// Known bits of the signed maximum of two values.
KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
  // Toggling the sign bit maps the signed order onto the unsigned order
  // ([-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]), so
  // smax(a, b) == umax(a ^ SignMask, b ^ SignMask) ^ SignMask.
  // Toggling a bit of a KnownBits exchanges that bit's Zero/One entries.
  auto ToggleSignBit = [](const KnownBits &K) {
    const unsigned SignBit = K.getBitWidth() - 1;
    KnownBits Toggled = K;
    Toggled.Zero.setBitVal(SignBit, K.One[SignBit]);
    Toggled.One.setBitVal(SignBit, K.Zero[SignBit]);
    return Toggled;
  };
  KnownBits Max = umax(ToggleSignBit(LHS), ToggleSignBit(RHS));
  return ToggleSignBit(Max);
}
214
/// Known bits of the signed minimum of two values.
KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
  // Inverting every bit except the sign bit maps the signed order onto the
  // reversed unsigned order ([-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]),
  // so smin(a, b) == umax(a ^ ~SignMask, b ^ ~SignMask) ^ ~SignMask.
  auto ToggleMagnitude = [](const KnownBits &K) {
    const unsigned SignBit = K.getBitWidth() - 1;
    // Exchange Zero and One everywhere, then restore the sign bit's entries.
    KnownBits Toggled(K.One, K.Zero);
    Toggled.Zero.setBitVal(SignBit, K.Zero[SignBit]);
    Toggled.One.setBitVal(SignBit, K.One[SignBit]);
    return Toggled;
  };
  KnownBits Max = umax(ToggleMagnitude(LHS), ToggleMagnitude(RHS));
  return ToggleMagnitude(Max);
}
227
/// Known bits of the unsigned absolute difference |LHS - RHS|.
KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
  // Subtraction helper; NoUnsignedWrap selects "sub nuw".
  auto Sub = [](const KnownBits &A, const KnownBits &B, bool NoUnsignedWrap) {
    return computeForAddSub(/*Add=*/false, /*NSW=*/false, NoUnsignedWrap, A,
                            B);
  };

  // When the larger operand is known, the result is a plain subtraction.
  if (LHS.getMinValue().uge(RHS.getMaxValue()))
    return Sub(LHS, RHS, /*NoUnsignedWrap=*/false);
  if (RHS.getMinValue().uge(LHS.getMaxValue()))
    return Sub(RHS, LHS, /*NoUnsignedWrap=*/false);

  // Otherwise the result is (larger - smaller), which never wraps unsigned;
  // keep only bits common to both orderings under that assumption.
  KnownBits LHSMinusRHS = Sub(LHS, RHS, /*NoUnsignedWrap=*/true);
  KnownBits RHSMinusLHS = Sub(RHS, LHS, /*NoUnsignedWrap=*/true);
  return LHSMinusRHS.intersectWith(RHSMinusLHS);
}
246
abds(KnownBits LHS,KnownBits RHS)247 KnownBits KnownBits::abds(KnownBits LHS, KnownBits RHS) {
248 // If we know which argument is larger, return (sub LHS, RHS) or
249 // (sub RHS, LHS) directly.
250 if (LHS.getSignedMinValue().sge(RHS.getSignedMaxValue()))
251 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
252 RHS);
253 if (RHS.getSignedMinValue().sge(LHS.getSignedMaxValue()))
254 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS,
255 LHS);
256
257 // Shift both arguments from the signed range to the unsigned range, e.g. from
258 // [-0x80, 0x7F] to [0, 0xFF]. This allows us to use "sub nuw" below just like
259 // abdu does.
260 // Note that we can't just use "sub nsw" instead because abds has signed
261 // inputs but an unsigned result, which makes the overflow conditions
262 // different.
263 unsigned SignBitPosition = LHS.getBitWidth() - 1;
264 for (auto Arg : {&LHS, &RHS}) {
265 bool Tmp = Arg->Zero[SignBitPosition];
266 Arg->Zero.setBitVal(SignBitPosition, Arg->One[SignBitPosition]);
267 Arg->One.setBitVal(SignBitPosition, Tmp);
268 }
269
270 // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
271 KnownBits Diff0 =
272 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
273 KnownBits Diff1 =
274 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, RHS, LHS);
275 return Diff0.intersectWith(Diff1);
276 }
277
getMaxShiftAmount(const APInt & MaxValue,unsigned BitWidth)278 static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
279 if (isPowerOf2_32(BitWidth))
280 return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0);
281 // This is only an approximate upper bound.
282 return MaxValue.getLimitedValue(BitWidth - 1);
283 }
284
/// Compute known bits for LHS << RHS.
///
/// \p NUW / \p NSW state that the shift has no unsigned / signed wrap, and
/// \p ShAmtNonZero states that the shift amount is non-zero. Shift amounts
/// that would make the result poison under these flags are not considered; if
/// no shift amount remains, all-zero known bits are returned.
KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
                         bool NSW, bool ShAmtNonZero) {
  unsigned BitWidth = LHS.getBitWidth();
  // Known bits of LHS shifted left by the constant ShiftAmt. Vacated low bits
  // are known zero. The ushl_ov overflow flags are used as "some known-zero
  // (resp. known-one) bit was shifted out".
  auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
    KnownBits Known;
    bool ShiftedOutZero, ShiftedOutOne;
    Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero);
    Known.Zero.setLowBits(ShiftAmt);
    Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne);

    // All cases returning poison have been handled by MaxShiftAmount already.
    // Under nsw, every shifted-out bit must equal the result's sign bit, so a
    // known shifted-out bit determines the sign.
    if (NSW) {
      if (NUW && ShiftAmt != 0)
        // NUW means we can assume anything shifted out was a zero.
        ShiftedOutZero = true;

      if (ShiftedOutZero)
        Known.makeNonNegative();
      else if (ShiftedOutOne)
        Known.makeNegative();
    }
    return Known;
  };

  // Fast path for a common case when LHS is completely unknown.
  KnownBits Known(BitWidth);
  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
  if (MinShiftAmount == 0 && ShAmtNonZero)
    MinShiftAmount = 1;
  if (LHS.isUnknown()) {
    Known.Zero.setLowBits(MinShiftAmount);
    if (NUW && NSW && MinShiftAmount != 0)
      Known.makeNonNegative();
    return Known;
  }

  // Determine maximum shift amount, taking NUW/NSW flags into account.
  // NUW: shifting out a bit that may be one is poison, so the amount is
  // bounded by the possible leading zeros. NSW: the bits shifted out plus the
  // new sign bit must all agree, hence the "- 1". If countMaxLeadingZeros() is
  // 0, the unsigned "- 1" wraps and that std::min is a no-op; the NUW bound
  // that follows still clamps the amount to 0.
  APInt MaxValue = RHS.getMaxValue();
  unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
  if (NUW && NSW)
    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros() - 1);
  if (NUW)
    MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros());
  if (NSW)
    MaxShiftAmount = std::min(
        MaxShiftAmount,
        std::max(LHS.countMaxLeadingZeros(), LHS.countMaxLeadingOnes()) - 1);

  // Fast path for common case where the shift amount is unknown.
  // Any shift keeps LHS's trailing zeros. An all-ones LHS keeps its sign bit
  // set for every in-range amount, and NSW preserves a known sign.
  if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 &&
      isPowerOf2_32(BitWidth)) {
    Known.Zero.setLowBits(LHS.countMinTrailingZeros());
    if (LHS.isAllOnes())
      Known.One.setSignBit();
    if (NSW) {
      if (LHS.isNonNegative())
        Known.makeNonNegative();
      if (LHS.isNegative())
        Known.makeNegative();
    }
    return Known;
  }

  // Find the common bits from all possible shifts.
  // Known starts as all-zero-and-all-one (a conflict), which is the identity
  // for intersectWith. If no amount survives, the conflict remains and is
  // handled below.
  unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
  unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
  Known.Zero.setAllBits();
  Known.One.setAllBits();
  for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
       ++ShiftAmt) {
    // Skip if the shift amount is impossible.
    // That is: ShiftAmt sets a bit RHS knows is zero, or misses a bit RHS
    // knows is one.
    if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
        (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
      continue;
    Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
    if (Known.isUnknown())
      break;
  }

  // All shift amounts may result in poison.
  if (Known.hasConflict())
    Known.setAllZero();
  return Known;
}
369
/// Compute known bits for the logical shift right LHS >>u RHS.
///
/// \p ShAmtNonZero states that the shift amount is non-zero. \p Exact states
/// that no set bit is shifted out. Amounts that would be poison under
/// \p Exact are not considered; if no amount remains, all-zero known bits are
/// returned.
KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
                          bool ShAmtNonZero, bool Exact) {
  unsigned BitWidth = LHS.getBitWidth();
  // Known bits of LHS logically shifted right by the constant ShiftAmt.
  auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
    KnownBits Known = LHS;
    Known.Zero.lshrInPlace(ShiftAmt);
    Known.One.lshrInPlace(ShiftAmt);
    // High bits are known zero.
    Known.Zero.setHighBits(ShiftAmt);
    return Known;
  };

  // Fast path for a common case when LHS is completely unknown.
  // Only the minimum shift amount's worth of high zeros can be inferred.
  KnownBits Known(BitWidth);
  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
  if (MinShiftAmount == 0 && ShAmtNonZero)
    MinShiftAmount = 1;
  if (LHS.isUnknown()) {
    Known.Zero.setHighBits(MinShiftAmount);
    return Known;
  }

  // Find the common bits from all possible shifts.
  APInt MaxValue = RHS.getMaxValue();
  unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);

  // If exact, bound MaxShiftAmount to first known 1 in LHS.
  // Shifting past the lowest known-one bit would shift out a set bit.
  if (Exact) {
    unsigned FirstOne = LHS.countMaxTrailingZeros();
    if (FirstOne < MinShiftAmount) {
      // Always poison. Return zero because we don't like returning conflict.
      Known.setAllZero();
      return Known;
    }
    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
  }

  // Known starts in the conflicting (all-zero-and-all-one) state, the identity
  // for intersectWith; it stays that way only if no amount is feasible.
  unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
  unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
  Known.Zero.setAllBits();
  Known.One.setAllBits();
  for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
       ++ShiftAmt) {
    // Skip if the shift amount is impossible.
    // (It sets a bit RHS knows is zero or misses a bit RHS knows is one.)
    if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
        (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
      continue;
    Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
    if (Known.isUnknown())
      break;
  }

  // All shift amounts may result in poison.
  if (Known.hasConflict())
    Known.setAllZero();
  return Known;
}
427
/// Compute known bits for the arithmetic shift right LHS >>s RHS.
///
/// \p ShAmtNonZero states that the shift amount is non-zero. \p Exact states
/// that no set bit is shifted out. Amounts that would be poison are not
/// considered; if no amount remains, all-zero known bits are returned.
KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
                          bool ShAmtNonZero, bool Exact) {
  unsigned BitWidth = LHS.getBitWidth();
  // Known bits of LHS arithmetically shifted right by the constant ShiftAmt.
  // Shifting both masks arithmetically replicates the sign bit's known state.
  auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
    KnownBits Known = LHS;
    Known.Zero.ashrInPlace(ShiftAmt);
    Known.One.ashrInPlace(ShiftAmt);
    return Known;
  };

  // Fast path for a common case when LHS is completely unknown.
  KnownBits Known(BitWidth);
  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth);
  if (MinShiftAmount == 0 && ShAmtNonZero)
    MinShiftAmount = 1;
  if (LHS.isUnknown()) {
    // MinShiftAmount is clamped to BitWidth above, so this means every
    // possible amount is out of range.
    if (MinShiftAmount == BitWidth) {
      // Always poison. Return zero because we don't like returning conflict.
      Known.setAllZero();
      return Known;
    }
    return Known;
  }

  // Find the common bits from all possible shifts.
  APInt MaxValue = RHS.getMaxValue();
  unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);

  // If exact, bound MaxShiftAmount to first known 1 in LHS.
  // Shifting past the lowest known-one bit would shift out a set bit.
  if (Exact) {
    unsigned FirstOne = LHS.countMaxTrailingZeros();
    if (FirstOne < MinShiftAmount) {
      // Always poison. Return zero because we don't like returning conflict.
      Known.setAllZero();
      return Known;
    }
    MaxShiftAmount = std::min(MaxShiftAmount, FirstOne);
  }

  // Known starts in the conflicting (all-zero-and-all-one) state, the identity
  // for intersectWith; it stays that way only if no amount is feasible.
  unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue();
  unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue();
  Known.Zero.setAllBits();
  Known.One.setAllBits();
  for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
       ++ShiftAmt) {
    // Skip if the shift amount is impossible.
    // (It sets a bit RHS knows is zero or misses a bit RHS knows is one.)
    if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
        (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
      continue;
    Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt));
    if (Known.isUnknown())
      break;
  }

  // All shift amounts may result in poison.
  if (Known.hasConflict())
    Known.setAllZero();
  return Known;
}
487
eq(const KnownBits & LHS,const KnownBits & RHS)488 std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
489 if (LHS.isConstant() && RHS.isConstant())
490 return std::optional<bool>(LHS.getConstant() == RHS.getConstant());
491 if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero))
492 return std::optional<bool>(false);
493 return std::nullopt;
494 }
495
ne(const KnownBits & LHS,const KnownBits & RHS)496 std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
497 if (std::optional<bool> KnownEQ = eq(LHS, RHS))
498 return std::optional<bool>(!*KnownEQ);
499 return std::nullopt;
500 }
501
ugt(const KnownBits & LHS,const KnownBits & RHS)502 std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
503 // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
504 if (LHS.getMaxValue().ule(RHS.getMinValue()))
505 return std::optional<bool>(false);
506 // LHS >u RHS -> true if umin(LHS) > umax(RHS)
507 if (LHS.getMinValue().ugt(RHS.getMaxValue()))
508 return std::optional<bool>(true);
509 return std::nullopt;
510 }
511
uge(const KnownBits & LHS,const KnownBits & RHS)512 std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
513 if (std::optional<bool> IsUGT = ugt(RHS, LHS))
514 return std::optional<bool>(!*IsUGT);
515 return std::nullopt;
516 }
517
ult(const KnownBits & LHS,const KnownBits & RHS)518 std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
519 return ugt(RHS, LHS);
520 }
521
ule(const KnownBits & LHS,const KnownBits & RHS)522 std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
523 return uge(RHS, LHS);
524 }
525
sgt(const KnownBits & LHS,const KnownBits & RHS)526 std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
527 // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
528 if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue()))
529 return std::optional<bool>(false);
530 // LHS >s RHS -> true if smin(LHS) > smax(RHS)
531 if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue()))
532 return std::optional<bool>(true);
533 return std::nullopt;
534 }
535
sge(const KnownBits & LHS,const KnownBits & RHS)536 std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
537 if (std::optional<bool> KnownSGT = sgt(RHS, LHS))
538 return std::optional<bool>(!*KnownSGT);
539 return std::nullopt;
540 }
541
slt(const KnownBits & LHS,const KnownBits & RHS)542 std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
543 return sgt(RHS, LHS);
544 }
545
sle(const KnownBits & LHS,const KnownBits & RHS)546 std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
547 return sge(RHS, LHS);
548 }
549
/// Compute known bits for abs(x) of this value.
///
/// \p IntMinIsPoison allows assuming the input is not INT_MIN, which lets
/// the result's sign bit (and, in some cases, more high bits) be inferred.
KnownBits KnownBits::abs(bool IntMinIsPoison) const {
  // If the source's MSB is zero then we know the rest of the bits already.
  if (isNonNegative())
    return *this;

  // Start from fully unknown bits of the same width. (Absolute value preserves
  // the trailing zero count; the unknown-sign branch below relies on that.)
  KnownBits KnownAbs(getBitWidth());

  // If the input is negative, then abs(x) == -x.
  if (isNegative()) {
    KnownBits Tmp = *this;
    // Special case for IntMinIsPoison. We know the sign bit is set and we know
    // all the rest of the bits except one to be zero. Since we have
    // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is
    // INT_MIN.
    if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth())
      Tmp.One.setBit(countMinTrailingZeros());

    // -x computed as 0 - x; NSW is passed exactly when INT_MIN is poison.
    KnownAbs = computeForAddSub(
        /*Add*/ false, IntMinIsPoison, /*NUW=*/false,
        KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp);

    // One more special case for IntMinIsPoison. If we don't know any ones other
    // than the signbit, we know for certain that all the unknowns can't be
    // zero. So if we know high zero bits, but have unknown low bits, we know
    // for certain those high-zero bits will end up as one. This is because,
    // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up
    // to the high bits. If we know a known INT_MIN input skip this. The result
    // is poison anyways.
    if (IntMinIsPoison && Tmp.countMinPopulation() == 1 &&
        Tmp.countMaxPopulation() != 1) {
      Tmp.One.clearSignBit();
      Tmp.Zero.setSignBit();
      KnownAbs.One.setBits(getBitWidth() - Tmp.countMinLeadingZeros(),
                           getBitWidth() - 1);
    }

  } else {
    // Sign unknown: only properties shared by x and -x survive.
    unsigned MaxTZ = countMaxTrailingZeros();
    unsigned MinTZ = countMinTrailingZeros();

    KnownAbs.Zero.setLowBits(MinTZ);
    // If we know the lowest set 1, then preserve it.
    if (MaxTZ == MinTZ && MaxTZ < getBitWidth())
      KnownAbs.One.setBit(MaxTZ);

    // We only know that the absolute values's MSB will be zero if INT_MIN is
    // poison, or there is a set bit that isn't the sign bit (otherwise it could
    // be INT_MIN).
    if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) {
      KnownAbs.One.clearSignBit();
      KnownAbs.Zero.setSignBit();
    }
  }

  return KnownAbs;
}
607
/// Compute known bits for a saturating add (\p Add) or sub (!\p Add), signed
/// (\p Signed: sadd.sat / ssub.sat) or unsigned (uadd.sat / usub.sat).
///
/// The wrapping result is computed first. Then, depending on whether overflow
/// is known not to happen, known to happen, or undetermined, it is returned
/// as is, replaced by the saturation constant, or weakened.
static KnownBits computeForSatAddSub(bool Add, bool Signed,
                                     const KnownBits &LHS,
                                     const KnownBits &RHS) {
  // We don't see NSW even for sadd/ssub as we want to check if the result has
  // signed overflow.
  KnownBits Res =
      KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
  unsigned BitWidth = Res.getBitWidth();
  auto SignBitKnown = [&](const KnownBits &K) {
    return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
  };
  // true: always overflows; false: never overflows; nullopt: unknown.
  std::optional<bool> Overflow;

  if (Signed) {
    // If we can actually detect overflow do so. Otherwise leave Overflow as
    // nullopt (we assume it may have happened).
    if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
      if (Add) {
        // sadd.sat
        Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
                    Res.isNonNegative() != LHS.isNonNegative());
      } else {
        // ssub.sat
        Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
                    Res.isNonNegative() != LHS.isNonNegative());
      }
    }
  } else if (Add) {
    // uadd.sat
    // No overflow if even the largest operands don't overflow; certain
    // overflow if even the smallest operands do.
    bool Of;
    (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of);
    if (!Of) {
      Overflow = false;
    } else {
      (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of);
      if (Of)
        Overflow = true;
    }
  } else {
    // usub.sat
    // No overflow if even min(LHS) - max(RHS) doesn't wrap; certain overflow
    // if even max(LHS) - min(RHS) wraps.
    bool Of;
    (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of);
    if (!Of) {
      Overflow = false;
    } else {
      (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of);
      if (Of)
        Overflow = true;
    }
  }

  // Facts about the saturated result that hold whether or not overflow
  // happens.
  if (Signed) {
    if (Add) {
      if (LHS.isNonNegative() && RHS.isNonNegative()) {
        // Pos + Pos -> Pos
        Res.One.clearSignBit();
        Res.Zero.setSignBit();
      }
      if (LHS.isNegative() && RHS.isNegative()) {
        // Neg + Neg -> Neg
        Res.One.setSignBit();
        Res.Zero.clearSignBit();
      }
    } else {
      if (LHS.isNegative() && RHS.isNonNegative()) {
        // Neg - Pos -> Neg
        Res.One.setSignBit();
        Res.Zero.clearSignBit();
      } else if (LHS.isNonNegative() && RHS.isNegative()) {
        // Pos - Neg -> Pos
        Res.One.clearSignBit();
        Res.Zero.setSignBit();
      }
    }
  } else {
    // Add: Leading ones of either operand are preserved.
    // Sub: Leading zeros of LHS and leading ones of RHS are preserved
    // as leading zeros in the result.
    unsigned LeadingKnown;
    if (Add)
      LeadingKnown =
          std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
    else
      LeadingKnown =
          std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());

    // We select between the operation result and all-ones/zero
    // respectively, so we can preserve known ones/zeros.
    APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
    if (Add) {
      Res.One |= Mask;
      Res.Zero &= ~Mask;
    } else {
      Res.Zero |= Mask;
      Res.One &= ~Mask;
    }
  }

  if (Overflow) {
    // We know whether or not we overflowed.
    if (!(*Overflow)) {
      // No overflow.
      return Res;
    }

    // We overflowed
    // The result is exactly the saturation constant C.
    APInt C;
    if (Signed) {
      // sadd.sat / ssub.sat
      assert(SignBitKnown(LHS) &&
             "We somehow know overflow without knowing input sign");
      C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
                           : APInt::getSignedMaxValue(BitWidth);
    } else if (Add) {
      // uadd.sat
      C = APInt::getMaxValue(BitWidth);
    } else {
      // usub.sat
      C = APInt::getMinValue(BitWidth);
    }

    Res.One = C;
    Res.Zero = ~C;
    return Res;
  }

  // We don't know if we overflowed.
  if (Signed) {
    // sadd.sat/ssub.sat
    // We can keep our information about the sign bits.
    Res.Zero.clearLowBits(BitWidth - 1);
    Res.One.clearLowBits(BitWidth - 1);
  } else if (Add) {
    // uadd.sat
    // We need to clear all the known zeros as we can only use the leading ones.
    Res.Zero.clearAllBits();
  } else {
    // usub.sat
    // We need to clear all the known ones as we can only use the leading zero.
    Res.One.clearAllBits();
  }

  return Res;
}
752
/// Known bits of the signed saturating addition LHS + RHS.
KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
  constexpr bool IsAdd = true;
  constexpr bool IsSigned = true;
  return computeForSatAddSub(IsAdd, IsSigned, LHS, RHS);
}
/// Known bits of the signed saturating subtraction LHS - RHS.
KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
  constexpr bool IsAdd = false;
  constexpr bool IsSigned = true;
  return computeForSatAddSub(IsAdd, IsSigned, LHS, RHS);
}
/// Known bits of the unsigned saturating addition LHS + RHS.
KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
  constexpr bool IsAdd = true;
  constexpr bool IsSigned = false;
  return computeForSatAddSub(IsAdd, IsSigned, LHS, RHS);
}
/// Known bits of the unsigned saturating subtraction LHS - RHS.
KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
  constexpr bool IsAdd = false;
  constexpr bool IsSigned = false;
  return computeForSatAddSub(IsAdd, IsSigned, LHS, RHS);
}
765
avgCompute(KnownBits LHS,KnownBits RHS,bool IsCeil,bool IsSigned)766 static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
767 bool IsSigned) {
768 unsigned BitWidth = LHS.getBitWidth();
769 LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
770 RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
771 LHS =
772 computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
773 LHS = LHS.extractBits(BitWidth, 1);
774 return LHS;
775 }
776
/// Known bits of the signed average of LHS and RHS, rounded down.
KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
  constexpr bool RoundUp = false;
  constexpr bool TreatAsSigned = true;
  return avgCompute(LHS, RHS, RoundUp, TreatAsSigned);
}
781
/// Known bits of the unsigned average of LHS and RHS, rounded down.
KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
  constexpr bool RoundUp = false;
  constexpr bool TreatAsSigned = false;
  return avgCompute(LHS, RHS, RoundUp, TreatAsSigned);
}
786
/// Known bits of the signed average of LHS and RHS, rounded up.
KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
  constexpr bool RoundUp = true;
  constexpr bool TreatAsSigned = true;
  return avgCompute(LHS, RHS, RoundUp, TreatAsSigned);
}
791
/// Known bits of the unsigned average of LHS and RHS, rounded up.
KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
  constexpr bool RoundUp = true;
  constexpr bool TreatAsSigned = false;
  return avgCompute(LHS, RHS, RoundUp, TreatAsSigned);
}
796
/// Compute the known bits of LHS * RHS (wrapping, BitWidth bits wide).
///
/// Two independent facts are combined:
///  - High bits: if the product of the two unsigned maxima does not wrap, its
///    leading zeros are leading zeros of every possible product.
///  - Low bits: the low bits of a product depend only on the low bits of the
///    operands, so known trailing bits of both operands yield known trailing
///    bits of the result (derivation below).
///
/// If \p NoUndefSelfMultiply is set, LHS and RHS must describe the same
/// non-undef value (x * x). That additionally pins bit 1 of the result to 0.
KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
                         bool NoUndefSelfMultiply) {
  const unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
  assert((!NoUndefSelfMultiply || LHS == RHS) &&
         "Self multiplication knownbits mismatch");

  // ---- High bits ----------------------------------------------------------
  // Conservatively, M active bits times N active bits need M + N result bits.
  // Multiplying the actual unsigned maxima is sharper: for example, it finds
  // one extra leading zero when an operand is a power of two. The leading
  // zeros are only meaningful if that maximal product did not wrap.
  // TODO: This could be generalized to number of sign bits (negative numbers).
  bool MaxProductWraps = false;
  const APInt MaxProduct =
      LHS.getMaxValue().umul_ov(RHS.getMaxValue(), MaxProductWraps);
  const unsigned HighZeros = MaxProductWraps ? 0 : MaxProduct.countl_zero();

  // ---- Low bits -----------------------------------------------------------
  // If a is divisible by m and b by n, then a*b == ((a/m) * (b/n)) * (m*n).
  // So the product has (at least) the sum of the operands' trailing zeros.
  // Above those zeros, it has as many known bits as the operand with the
  // shorter run of known bits above its own trailing zeros.
  //
  // Worked i8 example:
  //   a = XXXX1100 (12): 2 trailing zeros, 2 known bits above them -> XX11 (3)
  //   b = XXXX1110 (14): 1 trailing zero,  3 known bits above it   -> X111 (7)
  //   XX11 * X111 ends in ...01, so 2 bits are known above the 2 + 1 = 3
  //   trailing zeros of a*b: 5 known low bits in total.
  //
  // As a rewrite proof:
  //   Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
  //        (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
  //                      umin(countTrailingZeros(C2), C6) +
  //                      umin(C5 - umin(countTrailingZeros(C1), C5),
  //                           C6 - umin(countTrailingZeros(C2), C6)))) - 1)
  //   %aa = shl i8 %a, C5
  //   %bb = shl i8 %b, C6
  //   %aaa = or i8 %aa, C1
  //   %bbb = or i8 %bb, C2
  //   %mul = mul i8 %aaa, %bbb
  //   %mask = and i8 %mul, C7
  //     =>
  //   %mask = i8 ((C1*C2)&C7)
  //   C5, C6 describe the known bits of %a, %b; C1, C2 describe the known
  //   bottom bits of %a, %b; C7 is the mask of the known bits of the result.
  const unsigned KnownLowL = (LHS.Zero | LHS.One).countr_one();
  const unsigned KnownLowR = (RHS.Zero | RHS.One).countr_one();
  const unsigned TrailZerosL = LHS.countMinTrailingZeros();
  const unsigned TrailZerosR = RHS.countMinTrailingZeros();
  const unsigned ProductTrailZeros = TrailZerosL + TrailZerosR;

  // Known bits above each operand's trailing zeros; the shorter run limits us.
  const unsigned ShorterKnownRun =
      std::min(KnownLowL - TrailZerosL, KnownLowR - TrailZerosR);
  const unsigned LowBitsKnown =
      std::min(ShorterKnownRun + ProductTrailZeros, BitWidth);

  // Multiply just the fully known low bit runs of both operands.
  const APInt LowProduct =
      LHS.One.getLoBits(KnownLowL) * RHS.One.getLoBits(KnownLowR);

  KnownBits Res(BitWidth);
  Res.Zero.setHighBits(HighZeros);
  Res.Zero |= (~LowProduct).getLoBits(LowBitsKnown);
  Res.One = LowProduct.getLoBits(LowBitsKnown);

  // A square is congruent to 0 or 1 modulo 4, so bit 1 of x * x is always 0.
  if (NoUndefSelfMultiply && BitWidth > 1) {
    assert(Res.One[1] == 0 &&
           "Self-multiplication failed Quadratic Reciprocity!");
    Res.Zero.setBit(1);
  }

  return Res;
}
893
/// Known bits of the high half of the signed full-width product: bits
/// [BitWidth, 2 * BitWidth) of sext(LHS) * sext(RHS).
KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
  const unsigned DoubleWidth = 2 * BitWidth;
  return mul(LHS.sext(DoubleWidth), RHS.sext(DoubleWidth))
      .extractBits(BitWidth, BitWidth);
}
901
/// Known bits of the high half of the unsigned full-width product: bits
/// [BitWidth, 2 * BitWidth) of zext(LHS) * zext(RHS).
KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
  const unsigned DoubleWidth = 2 * BitWidth;
  return mul(LHS.zext(DoubleWidth), RHS.zext(DoubleWidth))
      .extractBits(BitWidth, BitWidth);
}
909
divComputeLowBit(KnownBits Known,const KnownBits & LHS,const KnownBits & RHS,bool Exact)910 static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS,
911 const KnownBits &RHS, bool Exact) {
912
913 if (!Exact)
914 return Known;
915
916 // If LHS is Odd, the result is Odd no matter what.
917 // Odd / Odd -> Odd
918 // Odd / Even -> Impossible (because its exact division)
919 if (LHS.One[0])
920 Known.One.setBit(0);
921
922 int MinTZ =
923 (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros();
924 int MaxTZ =
925 (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros();
926 if (MinTZ >= 0) {
927 // Result has at least MinTZ trailing zeros.
928 Known.Zero.setLowBits(MinTZ);
929 if (MinTZ == MaxTZ) {
930 // Result has exactly MinTZ trailing zeros.
931 Known.One.setBit(MinTZ);
932 }
933 } else if (MaxTZ < 0) {
934 // Poison Result
935 Known.setAllZero();
936 }
937
938 // In the KnownBits exhaustive tests, we have poison inputs for exact values
939 // a LOT. If we have a conflict, just return all zeros.
940 if (Known.hasConflict())
941 Known.setAllZero();
942
943 return Known;
944 }
945
/// Known bits of the signed quotient LHS / RHS; \p Exact means the division
/// is known to leave no remainder.
///
/// When the operand signs fix the sign of the quotient, the quotient of
/// largest magnitude is computed from extreme operand values. Its leading
/// zeros (non-negative) or leading ones (negative) then hold for every
/// possible quotient. Low bits are refined by divComputeLowBit.
KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
                          bool Exact) {
  // Both operands non-negative: this is just an unsigned division.
  // Equivalent of `udiv`. We must have caught this before it was folded.
  if (LHS.isNonNegative() && RHS.isNonNegative())
    return udiv(LHS, RHS, Exact);

  const unsigned BitWidth = LHS.getBitWidth();
  KnownBits Known(BitWidth);

  // A zero dividend gives zero and a zero divisor is UB; report zero for both.
  // Handling this up front removes special cases from the code below.
  if (LHS.isZero() || RHS.isZero()) {
    Known.setAllZero();
    return Known;
  }

  const bool LHSNegative = LHS.isNegative();
  const bool RHSNegative = RHS.isNegative();

  // Quotient of largest magnitude, set only when the result's sign is known.
  std::optional<APInt> Extreme;
  if (LHSNegative && RHSNegative) {
    // neg / neg is non-negative. The magnitude peaks for the most negative
    // dividend and the divisor closest to zero.
    APInt Divisor = RHS.getSignedMaxValue();
    APInt Dividend = LHS.getSignedMinValue();
    if (Dividend.isMinSignedValue() && Divisor.isAllOnes()) {
      // INT_MIN / -1 is poison (impossible). Signed max still says the sign
      // bit is clear, and that is the only bit it lets us set.
      Extreme = APInt::getSignedMaxValue(BitWidth);
    } else {
      Extreme = Dividend.sdiv(Divisor);
    }
  } else if (LHSNegative && RHS.isNonNegative()) {
    // neg / non-neg is negative (not truncated to zero) when the division is
    // exact, or when even the smallest |LHS| is u>= the largest RHS.
    if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) {
      APInt Divisor = RHS.getSignedMinValue();
      APInt Dividend = LHS.getSignedMinValue();
      // A zero divisor would be UB; dividing by 1 instead yields the dividend.
      if (Divisor.isZero())
        Extreme = Dividend;
      else
        Extreme = Dividend.sdiv(Divisor);
    }
  } else if (LHS.isStrictlyPositive() && RHSNegative) {
    // pos / neg is negative when the division is exact, or when even the
    // smallest LHS is u>= the largest |RHS|.
    if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) {
      APInt Divisor = RHS.getSignedMaxValue();
      APInt Dividend = LHS.getSignedMaxValue();
      Extreme = Dividend.sdiv(Divisor);
    }
  }

  if (Extreme) {
    if (Extreme->isNonNegative())
      Known.Zero.setHighBits(Extreme->countl_zero());
    else
      Known.One.setHighBits(Extreme->countl_one());
  }

  return divComputeLowBit(Known, LHS, RHS, Exact);
}
1001
/// Known bits of the unsigned quotient LHS / RHS; \p Exact means the division
/// is known to leave no remainder.
///
/// Every quotient is at most max(LHS) / min(RHS), so the leading zeros of
/// that bound are leading zeros of the result. Low bits are refined by
/// divComputeLowBit.
KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
                          bool Exact) {
  KnownBits Known(LHS.getBitWidth());

  // A zero dividend gives zero and a zero divisor is UB; report zero for both.
  // Handling this up front removes special cases from the code below.
  if (LHS.isZero() || RHS.isZero()) {
    Known.setAllZero();
    return Known;
  }

  // A smaller numerator or a larger denominator can only add upper zero bits,
  // so the largest numerator over the smallest denominator bounds all
  // quotients. A zero minimum denominator is treated as dividing by one.
  const APInt SmallestDivisor = RHS.getMinValue();
  const APInt LargestDividend = LHS.getMaxValue();
  const APInt QuotientBound = SmallestDivisor.isZero()
                                  ? LargestDividend
                                  : LargestDividend.udiv(SmallestDivisor);

  Known.Zero.setHighBits(QuotientBound.countl_zero());
  return divComputeLowBit(Known, LHS, RHS, Exact);
}
1028
/// Low-bit knowledge shared by urem and srem.
///
/// If RHS has at least N trailing zeros, then X rem RHS agrees with X in bits
/// [0, N), so LHS's known bits in that range carry over. All other bits are
/// returned as unknown.
KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();

  // Nothing carries over unless bit 0 of RHS is known zero; a known-zero RHS
  // is excluded as well.
  if (RHS.isZero() || !RHS.Zero[0])
    return KnownBits(BitWidth);

  const APInt Preserved =
      APInt::getLowBitsSet(BitWidth, RHS.countMinTrailingZeros());
  return KnownBits(LHS.Zero & Preserved, LHS.One & Preserved);
}
1041
/// Known bits of the unsigned remainder LHS urem RHS.
KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
  KnownBits Known = remGetLowBits(LHS, RHS);

  // Constant power-of-two divisor: remGetLowBits already supplied the low
  // bits, and every bit at or above the divisor's set bit is zero.
  if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
    Known.Zero |= ~(RHS.getConstant() - 1);
    return Known;
  }

  // The remainder is no larger than either operand, so it inherits the
  // larger of their guaranteed leading-zero counts.
  const unsigned SharedLeadZeros =
      std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros());
  Known.Zero.setHighBits(SharedLeadZeros);
  return Known;
}
1058
/// Known bits of the signed remainder LHS srem RHS. The result carries the
/// sign of LHS unless it is zero.
KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
  KnownBits Known = remGetLowBits(LHS, RHS);

  const bool DivisorIsPowerOf2 =
      RHS.isConstant() && RHS.getConstant().isPowerOf2();
  if (!DivisorIsPowerOf2) {
    // The sign bit is the LHS's sign bit, except when the remainder is zero,
    // and |result| <= |LHS|. Therefore any leading zeros of LHS are also
    // leading zeros of the result.
    Known.Zero.setHighBits(LHS.countMinLeadingZeros());
    return Known;
  }

  // Constant power-of-two divisor: remGetLowBits already supplied the low
  // bits; only the bits above the low mask remain to be decided.
  const APInt LowMask = RHS.getConstant() - 1;
  const APInt HighMask = ~LowMask;

  // Non-negative LHS, or LHS whose low bits are all zero: high bits are zero.
  if (LHS.isNonNegative() || LowMask.isSubsetOf(LHS.Zero))
    Known.Zero |= HighMask;

  // Negative LHS with some low bit known set: high bits are all ones.
  if (LHS.isNegative() && LowMask.intersects(LHS.One))
    Known.One |= HighMask;

  return Known;
}
1083
/// In-place bitwise AND of two known-bit sets.
KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
  // A result bit is 1 only if it is 1 on both sides...
  One &= RHS.One;
  // ...and is 0 as soon as either side has it 0.
  Zero |= RHS.Zero;
  return *this;
}
1091
/// In-place bitwise OR of two known-bit sets.
KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
  // A result bit is 1 as soon as either side has it 1...
  One |= RHS.One;
  // ...and is 0 only if it is 0 on both sides.
  Zero &= RHS.Zero;
  return *this;
}
1099
/// In-place bitwise XOR of two known-bit sets.
KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
  // Both new masks depend on the old Zero and One, so compute them before
  // assigning either member.
  // Result is 1 where the operands are known to differ.
  APInt NewOne = (Zero & RHS.One) | (One & RHS.Zero);
  // Result is 0 where the operands are known to agree.
  APInt NewZero = (Zero & RHS.Zero) | (One & RHS.One);
  Zero = std::move(NewZero);
  One = std::move(NewOne);
  return *this;
}
1108
blsi() const1109 KnownBits KnownBits::blsi() const {
1110 unsigned BitWidth = getBitWidth();
1111 KnownBits Known(Zero, APInt(BitWidth, 0));
1112 unsigned Max = countMaxTrailingZeros();
1113 Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
1114 unsigned Min = countMinTrailingZeros();
1115 if (Max == Min && Max < BitWidth)
1116 Known.One.setBit(Max);
1117 return Known;
1118 }
1119
blsmsk() const1120 KnownBits KnownBits::blsmsk() const {
1121 unsigned BitWidth = getBitWidth();
1122 KnownBits Known(BitWidth);
1123 unsigned Max = countMaxTrailingZeros();
1124 Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth));
1125 unsigned Min = countMinTrailingZeros();
1126 Known.One.setLowBits(std::min(Min + 1, BitWidth));
1127 return Known;
1128 }
1129
print(raw_ostream & OS) const1130 void KnownBits::print(raw_ostream &OS) const {
1131 unsigned BitWidth = getBitWidth();
1132 for (unsigned I = 0; I < BitWidth; ++I) {
1133 unsigned N = BitWidth - I - 1;
1134 if (Zero[N] && One[N])
1135 OS << "!";
1136 else if (Zero[N])
1137 OS << "0";
1138 else if (One[N])
1139 OS << "1";
1140 else
1141 OS << "?";
1142 }
1143 }
dump() const1144 void KnownBits::dump() const {
1145 print(dbgs());
1146 dbgs() << "\n";
1147 }
1148