1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a class for representing known zeros and ones used by 10 // computeKnownBits. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Support/KnownBits.h" 15 #include "llvm/Support/Debug.h" 16 #include "llvm/Support/raw_ostream.h" 17 #include <cassert> 18 19 using namespace llvm; 20 21 static KnownBits computeForAddCarry( 22 const KnownBits &LHS, const KnownBits &RHS, 23 bool CarryZero, bool CarryOne) { 24 assert(!(CarryZero && CarryOne) && 25 "Carry can't be zero and one at the same time"); 26 27 APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero; 28 APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne; 29 30 // Compute known bits of the carry. 31 APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero); 32 APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One; 33 34 // Compute set of known bits (where all three relevant bits are known). 35 APInt LHSKnownUnion = LHS.Zero | LHS.One; 36 APInt RHSKnownUnion = RHS.Zero | RHS.One; 37 APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne; 38 APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion; 39 40 assert((PossibleSumZero & Known) == (PossibleSumOne & Known) && 41 "known bits of sum differ"); 42 43 // Compute known bits of the result. 44 KnownBits KnownOut; 45 KnownOut.Zero = ~std::move(PossibleSumZero) & Known; 46 KnownOut.One = std::move(PossibleSumOne) & Known; 47 return KnownOut; 48 } 49 50 KnownBits KnownBits::computeForAddCarry( 51 const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) { 52 assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit"); 53 return ::computeForAddCarry( 54 LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue()); 55 } 56 57 KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, 58 const KnownBits &LHS, KnownBits RHS) { 59 KnownBits KnownOut; 60 if (Add) { 61 // Sum = LHS + RHS + 0 62 KnownOut = ::computeForAddCarry( 63 LHS, RHS, /*CarryZero*/true, /*CarryOne*/false); 64 } else { 65 // Sum = LHS + ~RHS + 1 66 std::swap(RHS.Zero, RHS.One); 67 KnownOut = ::computeForAddCarry( 68 LHS, RHS, /*CarryZero*/false, /*CarryOne*/true); 69 } 70 71 // Are we still trying to solve for the sign bit? 72 if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) { 73 if (NSW) { 74 // Adding two non-negative numbers, or subtracting a negative number from 75 // a non-negative one, can't wrap into negative. 76 if (LHS.isNonNegative() && RHS.isNonNegative()) 77 KnownOut.makeNonNegative(); 78 // Adding two negative numbers, or subtracting a non-negative number from 79 // a negative one, can't wrap into non-negative. 80 else if (LHS.isNegative() && RHS.isNegative()) 81 KnownOut.makeNegative(); 82 } 83 } 84 85 return KnownOut; 86 } 87 88 KnownBits KnownBits::computeForSubBorrow(const KnownBits &LHS, KnownBits RHS, 89 const KnownBits &Borrow) { 90 assert(Borrow.getBitWidth() == 1 && "Borrow must be 1-bit"); 91 92 // LHS - RHS = LHS + ~RHS + 1 93 // Carry 1 - Borrow in ::computeForAddCarry 94 std::swap(RHS.Zero, RHS.One); 95 return ::computeForAddCarry(LHS, RHS, 96 /*CarryZero=*/Borrow.One.getBoolValue(), 97 /*CarryOne=*/Borrow.Zero.getBoolValue()); 98 } 99 100 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const { 101 unsigned BitWidth = getBitWidth(); 102 assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth && 103 "Illegal sext-in-register"); 104 105 if (SrcBitWidth == BitWidth) 106 return *this; 107 108 unsigned ExtBits = BitWidth - SrcBitWidth; 109 KnownBits Result; 110 Result.One = One << ExtBits; 111 Result.Zero = Zero << ExtBits; 112 Result.One.ashrInPlace(ExtBits); 113 Result.Zero.ashrInPlace(ExtBits); 114 return Result; 115 } 116 117 KnownBits KnownBits::makeGE(const APInt &Val) const { 118 // Count the number of leading bit positions where our underlying value is 119 // known to be less than or equal to Val. 120 unsigned N = (Zero | Val).countl_one(); 121 122 // For each of those bit positions, if Val has a 1 in that bit then our 123 // underlying value must also have a 1. 124 APInt MaskedVal(Val); 125 MaskedVal.clearLowBits(getBitWidth() - N); 126 return KnownBits(Zero, One | MaskedVal); 127 } 128 129 KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) { 130 // If we can prove that LHS >= RHS then use LHS as the result. Likewise for 131 // RHS. Ideally our caller would already have spotted these cases and 132 // optimized away the umax operation, but we handle them here for 133 // completeness. 134 if (LHS.getMinValue().uge(RHS.getMaxValue())) 135 return LHS; 136 if (RHS.getMinValue().uge(LHS.getMaxValue())) 137 return RHS; 138 139 // If the result of the umax is LHS then it must be greater than or equal to 140 // the minimum possible value of RHS. Likewise for RHS. Any known bits that 141 // are common to these two values are also known in the result. 142 KnownBits L = LHS.makeGE(RHS.getMinValue()); 143 KnownBits R = RHS.makeGE(LHS.getMinValue()); 144 return L.intersectWith(R); 145 } 146 147 KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) { 148 // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0] 149 auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); }; 150 return Flip(umax(Flip(LHS), Flip(RHS))); 151 } 152 153 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) { 154 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF] 155 auto Flip = [](const KnownBits &Val) { 156 unsigned SignBitPosition = Val.getBitWidth() - 1; 157 APInt Zero = Val.Zero; 158 APInt One = Val.One; 159 Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]); 160 One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]); 161 return KnownBits(Zero, One); 162 }; 163 return Flip(umax(Flip(LHS), Flip(RHS))); 164 } 165 166 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) { 167 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0] 168 auto Flip = [](const KnownBits &Val) { 169 unsigned SignBitPosition = Val.getBitWidth() - 1; 170 APInt Zero = Val.One; 171 APInt One = Val.Zero; 172 Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]); 173 One.setBitVal(SignBitPosition, Val.One[SignBitPosition]); 174 return KnownBits(Zero, One); 175 }; 176 return Flip(umax(Flip(LHS), Flip(RHS))); 177 } 178 179 static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) { 180 if (isPowerOf2_32(BitWidth)) 181 return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0); 182 // This is only an approximate upper bound. 183 return MaxValue.getLimitedValue(BitWidth - 1); 184 } 185 186 KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW, 187 bool NSW, bool ShAmtNonZero) { 188 unsigned BitWidth = LHS.getBitWidth(); 189 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { 190 KnownBits Known; 191 bool ShiftedOutZero, ShiftedOutOne; 192 Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero); 193 Known.Zero.setLowBits(ShiftAmt); 194 Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne); 195 196 // All cases returning poison have been handled by MaxShiftAmount already. 197 if (NSW) { 198 if (NUW && ShiftAmt != 0) 199 // NUW means we can assume anything shifted out was a zero. 200 ShiftedOutZero = true; 201 202 if (ShiftedOutZero) 203 Known.makeNonNegative(); 204 else if (ShiftedOutOne) 205 Known.makeNegative(); 206 } 207 return Known; 208 }; 209 210 // Fast path for a common case when LHS is completely unknown. 211 KnownBits Known(BitWidth); 212 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth); 213 if (MinShiftAmount == 0 && ShAmtNonZero) 214 MinShiftAmount = 1; 215 if (LHS.isUnknown()) { 216 Known.Zero.setLowBits(MinShiftAmount); 217 if (NUW && NSW && MinShiftAmount != 0) 218 Known.makeNonNegative(); 219 return Known; 220 } 221 222 // Determine maximum shift amount, taking NUW/NSW flags into account. 223 APInt MaxValue = RHS.getMaxValue(); 224 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); 225 if (NUW && NSW) 226 MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros() - 1); 227 if (NUW) 228 MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros()); 229 if (NSW) 230 MaxShiftAmount = std::min( 231 MaxShiftAmount, 232 std::max(LHS.countMaxLeadingZeros(), LHS.countMaxLeadingOnes()) - 1); 233 234 // Fast path for common case where the shift amount is unknown. 235 if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 && 236 isPowerOf2_32(BitWidth)) { 237 Known.Zero.setLowBits(LHS.countMinTrailingZeros()); 238 if (LHS.isAllOnes()) 239 Known.One.setSignBit(); 240 if (NSW) { 241 if (LHS.isNonNegative()) 242 Known.makeNonNegative(); 243 if (LHS.isNegative()) 244 Known.makeNegative(); 245 } 246 return Known; 247 } 248 249 // Find the common bits from all possible shifts. 250 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue(); 251 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue(); 252 Known.Zero.setAllBits(); 253 Known.One.setAllBits(); 254 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; 255 ++ShiftAmt) { 256 // Skip if the shift amount is impossible. 257 if ((ShiftAmtZeroMask & ShiftAmt) != 0 || 258 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) 259 continue; 260 Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt)); 261 if (Known.isUnknown()) 262 break; 263 } 264 265 // All shift amounts may result in poison. 266 if (Known.hasConflict()) 267 Known.setAllZero(); 268 return Known; 269 } 270 271 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS, 272 bool ShAmtNonZero) { 273 unsigned BitWidth = LHS.getBitWidth(); 274 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { 275 KnownBits Known = LHS; 276 Known.Zero.lshrInPlace(ShiftAmt); 277 Known.One.lshrInPlace(ShiftAmt); 278 // High bits are known zero. 279 Known.Zero.setHighBits(ShiftAmt); 280 return Known; 281 }; 282 283 // Fast path for a common case when LHS is completely unknown. 284 KnownBits Known(BitWidth); 285 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth); 286 if (MinShiftAmount == 0 && ShAmtNonZero) 287 MinShiftAmount = 1; 288 if (LHS.isUnknown()) { 289 Known.Zero.setHighBits(MinShiftAmount); 290 return Known; 291 } 292 293 // Find the common bits from all possible shifts. 294 APInt MaxValue = RHS.getMaxValue(); 295 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); 296 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue(); 297 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue(); 298 Known.Zero.setAllBits(); 299 Known.One.setAllBits(); 300 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; 301 ++ShiftAmt) { 302 // Skip if the shift amount is impossible. 303 if ((ShiftAmtZeroMask & ShiftAmt) != 0 || 304 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) 305 continue; 306 Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt)); 307 if (Known.isUnknown()) 308 break; 309 } 310 311 // All shift amounts may result in poison. 312 if (Known.hasConflict()) 313 Known.setAllZero(); 314 return Known; 315 } 316 317 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS, 318 bool ShAmtNonZero) { 319 unsigned BitWidth = LHS.getBitWidth(); 320 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { 321 KnownBits Known = LHS; 322 Known.Zero.ashrInPlace(ShiftAmt); 323 Known.One.ashrInPlace(ShiftAmt); 324 return Known; 325 }; 326 327 // Fast path for a common case when LHS is completely unknown. 328 KnownBits Known(BitWidth); 329 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth); 330 if (MinShiftAmount == 0 && ShAmtNonZero) 331 MinShiftAmount = 1; 332 if (LHS.isUnknown()) { 333 if (MinShiftAmount == BitWidth) { 334 // Always poison. Return zero because we don't like returning conflict. 335 Known.setAllZero(); 336 return Known; 337 } 338 return Known; 339 } 340 341 // Find the common bits from all possible shifts. 342 APInt MaxValue = RHS.getMaxValue(); 343 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); 344 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue(); 345 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue(); 346 Known.Zero.setAllBits(); 347 Known.One.setAllBits(); 348 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; 349 ++ShiftAmt) { 350 // Skip if the shift amount is impossible. 351 if ((ShiftAmtZeroMask & ShiftAmt) != 0 || 352 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) 353 continue; 354 Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt)); 355 if (Known.isUnknown()) 356 break; 357 } 358 359 // All shift amounts may result in poison. 360 if (Known.hasConflict()) 361 Known.setAllZero(); 362 return Known; 363 } 364 365 std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) { 366 if (LHS.isConstant() && RHS.isConstant()) 367 return std::optional<bool>(LHS.getConstant() == RHS.getConstant()); 368 if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero)) 369 return std::optional<bool>(false); 370 return std::nullopt; 371 } 372 373 std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) { 374 if (std::optional<bool> KnownEQ = eq(LHS, RHS)) 375 return std::optional<bool>(!*KnownEQ); 376 return std::nullopt; 377 } 378 379 std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) { 380 // LHS >u RHS -> false if umax(LHS) <= umax(RHS) 381 if (LHS.getMaxValue().ule(RHS.getMinValue())) 382 return std::optional<bool>(false); 383 // LHS >u RHS -> true if umin(LHS) > umax(RHS) 384 if (LHS.getMinValue().ugt(RHS.getMaxValue())) 385 return std::optional<bool>(true); 386 return std::nullopt; 387 } 388 389 std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) { 390 if (std::optional<bool> IsUGT = ugt(RHS, LHS)) 391 return std::optional<bool>(!*IsUGT); 392 return std::nullopt; 393 } 394 395 std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) { 396 return ugt(RHS, LHS); 397 } 398 399 std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) { 400 return uge(RHS, LHS); 401 } 402 403 std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) { 404 // LHS >s RHS -> false if smax(LHS) <= smax(RHS) 405 if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue())) 406 return std::optional<bool>(false); 407 // LHS >s RHS -> true if smin(LHS) > smax(RHS) 408 if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue())) 409 return std::optional<bool>(true); 410 return std::nullopt; 411 } 412 413 std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) { 414 if (std::optional<bool> KnownSGT = sgt(RHS, LHS)) 415 return std::optional<bool>(!*KnownSGT); 416 return std::nullopt; 417 } 418 419 std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) { 420 return sgt(RHS, LHS); 421 } 422 423 std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) { 424 return sge(RHS, LHS); 425 } 426 427 KnownBits KnownBits::abs(bool IntMinIsPoison) const { 428 // If the source's MSB is zero then we know the rest of the bits already. 429 if (isNonNegative()) 430 return *this; 431 432 // Absolute value preserves trailing zero count. 433 KnownBits KnownAbs(getBitWidth()); 434 435 // If the input is negative, then abs(x) == -x. 436 if (isNegative()) { 437 KnownBits Tmp = *this; 438 // Special case for IntMinIsPoison. We know the sign bit is set and we know 439 // all the rest of the bits except one to be zero. Since we have 440 // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is 441 // INT_MIN. 442 if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth()) 443 Tmp.One.setBit(countMinTrailingZeros()); 444 445 KnownAbs = computeForAddSub( 446 /*Add*/ false, IntMinIsPoison, 447 KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp); 448 449 // One more special case for IntMinIsPoison. If we don't know any ones other 450 // than the signbit, we know for certain that all the unknowns can't be 451 // zero. So if we know high zero bits, but have unknown low bits, we know 452 // for certain those high-zero bits will end up as one. This is because, 453 // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up 454 // to the high bits. If we know a known INT_MIN input skip this. The result 455 // is poison anyways. 456 if (IntMinIsPoison && Tmp.countMinPopulation() == 1 && 457 Tmp.countMaxPopulation() != 1) { 458 Tmp.One.clearSignBit(); 459 Tmp.Zero.setSignBit(); 460 KnownAbs.One.setBits(getBitWidth() - Tmp.countMinLeadingZeros(), 461 getBitWidth() - 1); 462 } 463 464 } else { 465 unsigned MaxTZ = countMaxTrailingZeros(); 466 unsigned MinTZ = countMinTrailingZeros(); 467 468 KnownAbs.Zero.setLowBits(MinTZ); 469 // If we know the lowest set 1, then preserve it. 470 if (MaxTZ == MinTZ && MaxTZ < getBitWidth()) 471 KnownAbs.One.setBit(MaxTZ); 472 473 // We only know that the absolute values's MSB will be zero if INT_MIN is 474 // poison, or there is a set bit that isn't the sign bit (otherwise it could 475 // be INT_MIN). 476 if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) { 477 KnownAbs.One.clearSignBit(); 478 KnownAbs.Zero.setSignBit(); 479 } 480 } 481 482 assert(!KnownAbs.hasConflict() && "Bad Output"); 483 return KnownAbs; 484 } 485 486 static KnownBits computeForSatAddSub(bool Add, bool Signed, 487 const KnownBits &LHS, 488 const KnownBits &RHS) { 489 assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs"); 490 // We don't see NSW even for sadd/ssub as we want to check if the result has 491 // signed overflow. 492 KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS); 493 unsigned BitWidth = Res.getBitWidth(); 494 auto SignBitKnown = [&](const KnownBits &K) { 495 return K.Zero[BitWidth - 1] || K.One[BitWidth - 1]; 496 }; 497 std::optional<bool> Overflow; 498 499 if (Signed) { 500 // If we can actually detect overflow do so. Otherwise leave Overflow as 501 // nullopt (we assume it may have happened). 502 if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) { 503 if (Add) { 504 // sadd.sat 505 Overflow = (LHS.isNonNegative() == RHS.isNonNegative() && 506 Res.isNonNegative() != LHS.isNonNegative()); 507 } else { 508 // ssub.sat 509 Overflow = (LHS.isNonNegative() != RHS.isNonNegative() && 510 Res.isNonNegative() != LHS.isNonNegative()); 511 } 512 } 513 } else if (Add) { 514 // uadd.sat 515 bool Of; 516 (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of); 517 if (!Of) { 518 Overflow = false; 519 } else { 520 (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of); 521 if (Of) 522 Overflow = true; 523 } 524 } else { 525 // usub.sat 526 bool Of; 527 (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of); 528 if (!Of) { 529 Overflow = false; 530 } else { 531 (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of); 532 if (Of) 533 Overflow = true; 534 } 535 } 536 537 if (Signed) { 538 if (Add) { 539 if (LHS.isNonNegative() && RHS.isNonNegative()) { 540 // Pos + Pos -> Pos 541 Res.One.clearSignBit(); 542 Res.Zero.setSignBit(); 543 } 544 if (LHS.isNegative() && RHS.isNegative()) { 545 // Neg + Neg -> Neg 546 Res.One.setSignBit(); 547 Res.Zero.clearSignBit(); 548 } 549 } else { 550 if (LHS.isNegative() && RHS.isNonNegative()) { 551 // Neg - Pos -> Neg 552 Res.One.setSignBit(); 553 Res.Zero.clearSignBit(); 554 } else if (LHS.isNonNegative() && RHS.isNegative()) { 555 // Pos - Neg -> Pos 556 Res.One.clearSignBit(); 557 Res.Zero.setSignBit(); 558 } 559 } 560 } else { 561 // Add: Leading ones of either operand are preserved. 562 // Sub: Leading zeros of LHS and leading ones of RHS are preserved 563 // as leading zeros in the result. 564 unsigned LeadingKnown; 565 if (Add) 566 LeadingKnown = 567 std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes()); 568 else 569 LeadingKnown = 570 std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes()); 571 572 // We select between the operation result and all-ones/zero 573 // respectively, so we can preserve known ones/zeros. 574 APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown); 575 if (Add) { 576 Res.One |= Mask; 577 Res.Zero &= ~Mask; 578 } else { 579 Res.Zero |= Mask; 580 Res.One &= ~Mask; 581 } 582 } 583 584 if (Overflow) { 585 // We know whether or not we overflowed. 586 if (!(*Overflow)) { 587 // No overflow. 588 assert(!Res.hasConflict() && "Bad Output"); 589 return Res; 590 } 591 592 // We overflowed 593 APInt C; 594 if (Signed) { 595 // sadd.sat / ssub.sat 596 assert(SignBitKnown(LHS) && 597 "We somehow know overflow without knowing input sign"); 598 C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth) 599 : APInt::getSignedMaxValue(BitWidth); 600 } else if (Add) { 601 // uadd.sat 602 C = APInt::getMaxValue(BitWidth); 603 } else { 604 // uadd.sat 605 C = APInt::getMinValue(BitWidth); 606 } 607 608 Res.One = C; 609 Res.Zero = ~C; 610 assert(!Res.hasConflict() && "Bad Output"); 611 return Res; 612 } 613 614 // We don't know if we overflowed. 615 if (Signed) { 616 // sadd.sat/ssub.sat 617 // We can keep our information about the sign bits. 618 Res.Zero.clearLowBits(BitWidth - 1); 619 Res.One.clearLowBits(BitWidth - 1); 620 } else if (Add) { 621 // uadd.sat 622 // We need to clear all the known zeros as we can only use the leading ones. 623 Res.Zero.clearAllBits(); 624 } else { 625 // usub.sat 626 // We need to clear all the known ones as we can only use the leading zero. 627 Res.One.clearAllBits(); 628 } 629 630 assert(!Res.hasConflict() && "Bad Output"); 631 return Res; 632 } 633 634 KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) { 635 return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS); 636 } 637 KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) { 638 return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS); 639 } 640 KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) { 641 return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS); 642 } 643 KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) { 644 return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS); 645 } 646 647 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS, 648 bool NoUndefSelfMultiply) { 649 unsigned BitWidth = LHS.getBitWidth(); 650 assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() && 651 !RHS.hasConflict() && "Operand mismatch"); 652 assert((!NoUndefSelfMultiply || LHS == RHS) && 653 "Self multiplication knownbits mismatch"); 654 655 // Compute the high known-0 bits by multiplying the unsigned max of each side. 656 // Conservatively, M active bits * N active bits results in M + N bits in the 657 // result. But if we know a value is a power-of-2 for example, then this 658 // computes one more leading zero. 659 // TODO: This could be generalized to number of sign bits (negative numbers). 660 APInt UMaxLHS = LHS.getMaxValue(); 661 APInt UMaxRHS = RHS.getMaxValue(); 662 663 // For leading zeros in the result to be valid, the unsigned max product must 664 // fit in the bitwidth (it must not overflow). 665 bool HasOverflow; 666 APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow); 667 unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero(); 668 669 // The result of the bottom bits of an integer multiply can be 670 // inferred by looking at the bottom bits of both operands and 671 // multiplying them together. 672 // We can infer at least the minimum number of known trailing bits 673 // of both operands. Depending on number of trailing zeros, we can 674 // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming 675 // a and b are divisible by m and n respectively. 676 // We then calculate how many of those bits are inferrable and set 677 // the output. For example, the i8 mul: 678 // a = XXXX1100 (12) 679 // b = XXXX1110 (14) 680 // We know the bottom 3 bits are zero since the first can be divided by 681 // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4). 682 // Applying the multiplication to the trimmed arguments gets: 683 // XX11 (3) 684 // X111 (7) 685 // ------- 686 // XX11 687 // XX11 688 // XX11 689 // XX11 690 // ------- 691 // XXXXX01 692 // Which allows us to infer the 2 LSBs. Since we're multiplying the result 693 // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits. 694 // The proof for this can be described as: 695 // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) && 696 // (C7 == (1 << (umin(countTrailingZeros(C1), C5) + 697 // umin(countTrailingZeros(C2), C6) + 698 // umin(C5 - umin(countTrailingZeros(C1), C5), 699 // C6 - umin(countTrailingZeros(C2), C6)))) - 1) 700 // %aa = shl i8 %a, C5 701 // %bb = shl i8 %b, C6 702 // %aaa = or i8 %aa, C1 703 // %bbb = or i8 %bb, C2 704 // %mul = mul i8 %aaa, %bbb 705 // %mask = and i8 %mul, C7 706 // => 707 // %mask = i8 ((C1*C2)&C7) 708 // Where C5, C6 describe the known bits of %a, %b 709 // C1, C2 describe the known bottom bits of %a, %b. 710 // C7 describes the mask of the known bits of the result. 711 const APInt &Bottom0 = LHS.One; 712 const APInt &Bottom1 = RHS.One; 713 714 // How many times we'd be able to divide each argument by 2 (shr by 1). 715 // This gives us the number of trailing zeros on the multiplication result. 716 unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countr_one(); 717 unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countr_one(); 718 unsigned TrailZero0 = LHS.countMinTrailingZeros(); 719 unsigned TrailZero1 = RHS.countMinTrailingZeros(); 720 unsigned TrailZ = TrailZero0 + TrailZero1; 721 722 // Figure out the fewest known-bits operand. 723 unsigned SmallestOperand = 724 std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1); 725 unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth); 726 727 APInt BottomKnown = 728 Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1); 729 730 KnownBits Res(BitWidth); 731 Res.Zero.setHighBits(LeadZ); 732 Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); 733 Res.One = BottomKnown.getLoBits(ResultBitsKnown); 734 735 // If we're self-multiplying then bit[1] is guaranteed to be zero. 736 if (NoUndefSelfMultiply && BitWidth > 1) { 737 assert(Res.One[1] == 0 && 738 "Self-multiplication failed Quadratic Reciprocity!"); 739 Res.Zero.setBit(1); 740 } 741 742 return Res; 743 } 744 745 KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) { 746 unsigned BitWidth = LHS.getBitWidth(); 747 assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() && 748 !RHS.hasConflict() && "Operand mismatch"); 749 KnownBits WideLHS = LHS.sext(2 * BitWidth); 750 KnownBits WideRHS = RHS.sext(2 * BitWidth); 751 return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth); 752 } 753 754 KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) { 755 unsigned BitWidth = LHS.getBitWidth(); 756 assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() && 757 !RHS.hasConflict() && "Operand mismatch"); 758 KnownBits WideLHS = LHS.zext(2 * BitWidth); 759 KnownBits WideRHS = RHS.zext(2 * BitWidth); 760 return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth); 761 } 762 763 static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS, 764 const KnownBits &RHS, bool Exact) { 765 766 if (!Exact) 767 return Known; 768 769 // If LHS is Odd, the result is Odd no matter what. 770 // Odd / Odd -> Odd 771 // Odd / Even -> Impossible (because its exact division) 772 if (LHS.One[0]) 773 Known.One.setBit(0); 774 775 int MinTZ = 776 (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros(); 777 int MaxTZ = 778 (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros(); 779 if (MinTZ >= 0) { 780 // Result has at least MinTZ trailing zeros. 781 Known.Zero.setLowBits(MinTZ); 782 if (MinTZ == MaxTZ) { 783 // Result has exactly MinTZ trailing zeros. 784 Known.One.setBit(MinTZ); 785 } 786 } else if (MaxTZ < 0) { 787 // Poison Result 788 Known.setAllZero(); 789 } 790 791 // In the KnownBits exhaustive tests, we have poison inputs for exact values 792 // a LOT. If we have a conflict, just return all zeros. 793 if (Known.hasConflict()) 794 Known.setAllZero(); 795 796 return Known; 797 } 798 799 KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS, 800 bool Exact) { 801 // Equivalent of `udiv`. We must have caught this before it was folded. 802 if (LHS.isNonNegative() && RHS.isNonNegative()) 803 return udiv(LHS, RHS, Exact); 804 805 unsigned BitWidth = LHS.getBitWidth(); 806 assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs"); 807 KnownBits Known(BitWidth); 808 809 if (LHS.isZero() || RHS.isZero()) { 810 // Result is either known Zero or UB. Return Zero either way. 811 // Checking this earlier saves us a lot of special cases later on. 812 Known.setAllZero(); 813 return Known; 814 } 815 816 std::optional<APInt> Res; 817 if (LHS.isNegative() && RHS.isNegative()) { 818 // Result non-negative. 819 APInt Denom = RHS.getSignedMaxValue(); 820 APInt Num = LHS.getSignedMinValue(); 821 // INT_MIN/-1 would be a poison result (impossible). Estimate the division 822 // as signed max (we will only set sign bit in the result). 823 Res = (Num.isMinSignedValue() && Denom.isAllOnes()) 824 ? APInt::getSignedMaxValue(BitWidth) 825 : Num.sdiv(Denom); 826 } else if (LHS.isNegative() && RHS.isNonNegative()) { 827 // Result is negative if Exact OR -LHS u>= RHS. 828 if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) { 829 APInt Denom = RHS.getSignedMinValue(); 830 APInt Num = LHS.getSignedMinValue(); 831 Res = Denom.isZero() ? Num : Num.sdiv(Denom); 832 } 833 } else if (LHS.isStrictlyPositive() && RHS.isNegative()) { 834 // Result is negative if Exact OR LHS u>= -RHS. 835 if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) { 836 APInt Denom = RHS.getSignedMaxValue(); 837 APInt Num = LHS.getSignedMaxValue(); 838 Res = Num.sdiv(Denom); 839 } 840 } 841 842 if (Res) { 843 if (Res->isNonNegative()) { 844 unsigned LeadZ = Res->countLeadingZeros(); 845 Known.Zero.setHighBits(LeadZ); 846 } else { 847 unsigned LeadO = Res->countLeadingOnes(); 848 Known.One.setHighBits(LeadO); 849 } 850 } 851 852 Known = divComputeLowBit(Known, LHS, RHS, Exact); 853 854 assert(!Known.hasConflict() && "Bad Output"); 855 return Known; 856 } 857 858 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS, 859 bool Exact) { 860 unsigned BitWidth = LHS.getBitWidth(); 861 assert(!LHS.hasConflict() && !RHS.hasConflict()); 862 KnownBits Known(BitWidth); 863 864 if (LHS.isZero() || RHS.isZero()) { 865 // Result is either known Zero or UB. Return Zero either way. 866 // Checking this earlier saves us a lot of special cases later on. 867 Known.setAllZero(); 868 return Known; 869 } 870 871 // We can figure out the minimum number of upper zero bits by doing 872 // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator 873 // gets larger, the number of upper zero bits increases. 874 APInt MinDenom = RHS.getMinValue(); 875 APInt MaxNum = LHS.getMaxValue(); 876 APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(MinDenom); 877 878 unsigned LeadZ = MaxRes.countLeadingZeros(); 879 880 Known.Zero.setHighBits(LeadZ); 881 Known = divComputeLowBit(Known, LHS, RHS, Exact); 882 883 assert(!Known.hasConflict() && "Bad Output"); 884 return Known; 885 } 886 887 KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) { 888 unsigned BitWidth = LHS.getBitWidth(); 889 if (!RHS.isZero() && RHS.Zero[0]) { 890 // rem X, Y where Y[0:N] is zero will preserve X[0:N] in the result. 891 unsigned RHSZeros = RHS.countMinTrailingZeros(); 892 APInt Mask = APInt::getLowBitsSet(BitWidth, RHSZeros); 893 APInt OnesMask = LHS.One & Mask; 894 APInt ZerosMask = LHS.Zero & Mask; 895 return KnownBits(ZerosMask, OnesMask); 896 } 897 return KnownBits(BitWidth); 898 } 899 900 KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) { 901 assert(!LHS.hasConflict() && !RHS.hasConflict()); 902 903 KnownBits Known = remGetLowBits(LHS, RHS); 904 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { 905 // NB: Low bits set in `remGetLowBits`. 906 APInt HighBits = ~(RHS.getConstant() - 1); 907 Known.Zero |= HighBits; 908 return Known; 909 } 910 911 // Since the result is less than or equal to either operand, any leading 912 // zero bits in either operand must also exist in the result. 913 uint32_t Leaders = 914 std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros()); 915 Known.Zero.setHighBits(Leaders); 916 return Known; 917 } 918 919 KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) { 920 assert(!LHS.hasConflict() && !RHS.hasConflict()); 921 922 KnownBits Known = remGetLowBits(LHS, RHS); 923 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { 924 // NB: Low bits are set in `remGetLowBits`. 925 APInt LowBits = RHS.getConstant() - 1; 926 // If the first operand is non-negative or has all low bits zero, then 927 // the upper bits are all zero. 928 if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero)) 929 Known.Zero |= ~LowBits; 930 931 // If the first operand is negative and not all low bits are zero, then 932 // the upper bits are all one. 933 if (LHS.isNegative() && LowBits.intersects(LHS.One)) 934 Known.One |= ~LowBits; 935 return Known; 936 } 937 938 // The sign bit is the LHS's sign bit, except when the result of the 939 // remainder is zero. The magnitude of the result should be less than or 940 // equal to the magnitude of the LHS. Therefore any leading zeros that exist 941 // in the left hand side must also exist in the result. 942 Known.Zero.setHighBits(LHS.countMinLeadingZeros()); 943 return Known; 944 } 945 946 KnownBits &KnownBits::operator&=(const KnownBits &RHS) { 947 // Result bit is 0 if either operand bit is 0. 948 Zero |= RHS.Zero; 949 // Result bit is 1 if both operand bits are 1. 950 One &= RHS.One; 951 return *this; 952 } 953 954 KnownBits &KnownBits::operator|=(const KnownBits &RHS) { 955 // Result bit is 0 if both operand bits are 0. 956 Zero &= RHS.Zero; 957 // Result bit is 1 if either operand bit is 1. 958 One |= RHS.One; 959 return *this; 960 } 961 962 KnownBits &KnownBits::operator^=(const KnownBits &RHS) { 963 // Result bit is 0 if both operand bits are 0 or both are 1. 964 APInt Z = (Zero & RHS.Zero) | (One & RHS.One); 965 // Result bit is 1 if one operand bit is 0 and the other is 1. 966 One = (Zero & RHS.One) | (One & RHS.Zero); 967 Zero = std::move(Z); 968 return *this; 969 } 970 971 KnownBits KnownBits::blsi() const { 972 unsigned BitWidth = getBitWidth(); 973 KnownBits Known(Zero, APInt(BitWidth, 0)); 974 unsigned Max = countMaxTrailingZeros(); 975 Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth)); 976 unsigned Min = countMinTrailingZeros(); 977 if (Max == Min && Max < BitWidth) 978 Known.One.setBit(Max); 979 return Known; 980 } 981 982 KnownBits KnownBits::blsmsk() const { 983 unsigned BitWidth = getBitWidth(); 984 KnownBits Known(BitWidth); 985 unsigned Max = countMaxTrailingZeros(); 986 Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth)); 987 unsigned Min = countMinTrailingZeros(); 988 Known.One.setLowBits(std::min(Min + 1, BitWidth)); 989 return Known; 990 } 991 992 void KnownBits::print(raw_ostream &OS) const { 993 unsigned BitWidth = getBitWidth(); 994 for (unsigned I = 0; I < BitWidth; ++I) { 995 unsigned N = BitWidth - I - 1; 996 if (Zero[N] && One[N]) 997 OS << "!"; 998 else if (Zero[N]) 999 OS << "0"; 1000 else if (One[N]) 1001 OS << "1"; 1002 else 1003 OS << "?"; 1004 } 1005 } 1006 void KnownBits::dump() const { 1007 print(dbgs()); 1008 dbgs() << "\n"; 1009 } 1010