1 //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a class for representing known zeros and ones used by 10 // computeKnownBits. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Support/KnownBits.h" 15 #include "llvm/Support/Debug.h" 16 #include "llvm/Support/raw_ostream.h" 17 #include <cassert> 18 19 using namespace llvm; 20 21 static KnownBits computeForAddCarry( 22 const KnownBits &LHS, const KnownBits &RHS, 23 bool CarryZero, bool CarryOne) { 24 assert(!(CarryZero && CarryOne) && 25 "Carry can't be zero and one at the same time"); 26 27 APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero; 28 APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne; 29 30 // Compute known bits of the carry. 31 APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero); 32 APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One; 33 34 // Compute set of known bits (where all three relevant bits are known). 35 APInt LHSKnownUnion = LHS.Zero | LHS.One; 36 APInt RHSKnownUnion = RHS.Zero | RHS.One; 37 APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne; 38 APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion; 39 40 assert((PossibleSumZero & Known) == (PossibleSumOne & Known) && 41 "known bits of sum differ"); 42 43 // Compute known bits of the result. 44 KnownBits KnownOut; 45 KnownOut.Zero = ~std::move(PossibleSumZero) & Known; 46 KnownOut.One = std::move(PossibleSumOne) & Known; 47 return KnownOut; 48 } 49 50 KnownBits KnownBits::computeForAddCarry( 51 const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) { 52 assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit"); 53 return ::computeForAddCarry( 54 LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue()); 55 } 56 57 KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, 58 const KnownBits &LHS, KnownBits RHS) { 59 KnownBits KnownOut; 60 if (Add) { 61 // Sum = LHS + RHS + 0 62 KnownOut = ::computeForAddCarry( 63 LHS, RHS, /*CarryZero*/true, /*CarryOne*/false); 64 } else { 65 // Sum = LHS + ~RHS + 1 66 std::swap(RHS.Zero, RHS.One); 67 KnownOut = ::computeForAddCarry( 68 LHS, RHS, /*CarryZero*/false, /*CarryOne*/true); 69 } 70 71 // Are we still trying to solve for the sign bit? 72 if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) { 73 if (NSW) { 74 // Adding two non-negative numbers, or subtracting a negative number from 75 // a non-negative one, can't wrap into negative. 76 if (LHS.isNonNegative() && RHS.isNonNegative()) 77 KnownOut.makeNonNegative(); 78 // Adding two negative numbers, or subtracting a non-negative number from 79 // a negative one, can't wrap into non-negative. 80 else if (LHS.isNegative() && RHS.isNegative()) 81 KnownOut.makeNegative(); 82 } 83 } 84 85 return KnownOut; 86 } 87 88 KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const { 89 unsigned BitWidth = getBitWidth(); 90 assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth && 91 "Illegal sext-in-register"); 92 93 if (SrcBitWidth == BitWidth) 94 return *this; 95 96 unsigned ExtBits = BitWidth - SrcBitWidth; 97 KnownBits Result; 98 Result.One = One << ExtBits; 99 Result.Zero = Zero << ExtBits; 100 Result.One.ashrInPlace(ExtBits); 101 Result.Zero.ashrInPlace(ExtBits); 102 return Result; 103 } 104 105 KnownBits KnownBits::makeGE(const APInt &Val) const { 106 // Count the number of leading bit positions where our underlying value is 107 // known to be less than or equal to Val. 108 unsigned N = (Zero | Val).countl_one(); 109 110 // For each of those bit positions, if Val has a 1 in that bit then our 111 // underlying value must also have a 1. 112 APInt MaskedVal(Val); 113 MaskedVal.clearLowBits(getBitWidth() - N); 114 return KnownBits(Zero, One | MaskedVal); 115 } 116 117 KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) { 118 // If we can prove that LHS >= RHS then use LHS as the result. Likewise for 119 // RHS. Ideally our caller would already have spotted these cases and 120 // optimized away the umax operation, but we handle them here for 121 // completeness. 122 if (LHS.getMinValue().uge(RHS.getMaxValue())) 123 return LHS; 124 if (RHS.getMinValue().uge(LHS.getMaxValue())) 125 return RHS; 126 127 // If the result of the umax is LHS then it must be greater than or equal to 128 // the minimum possible value of RHS. Likewise for RHS. Any known bits that 129 // are common to these two values are also known in the result. 130 KnownBits L = LHS.makeGE(RHS.getMinValue()); 131 KnownBits R = RHS.makeGE(LHS.getMinValue()); 132 return L.intersectWith(R); 133 } 134 135 KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) { 136 // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0] 137 auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); }; 138 return Flip(umax(Flip(LHS), Flip(RHS))); 139 } 140 141 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) { 142 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF] 143 auto Flip = [](const KnownBits &Val) { 144 unsigned SignBitPosition = Val.getBitWidth() - 1; 145 APInt Zero = Val.Zero; 146 APInt One = Val.One; 147 Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]); 148 One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]); 149 return KnownBits(Zero, One); 150 }; 151 return Flip(umax(Flip(LHS), Flip(RHS))); 152 } 153 154 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) { 155 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0] 156 auto Flip = [](const KnownBits &Val) { 157 unsigned SignBitPosition = Val.getBitWidth() - 1; 158 APInt Zero = Val.One; 159 APInt One = Val.Zero; 160 Zero.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]); 161 One.setBitVal(SignBitPosition, Val.One[SignBitPosition]); 162 return KnownBits(Zero, One); 163 }; 164 return Flip(umax(Flip(LHS), Flip(RHS))); 165 } 166 167 static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) { 168 if (isPowerOf2_32(BitWidth)) 169 return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0); 170 // This is only an approximate upper bound. 171 return MaxValue.getLimitedValue(BitWidth - 1); 172 } 173 174 KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW, 175 bool NSW, bool ShAmtNonZero) { 176 unsigned BitWidth = LHS.getBitWidth(); 177 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { 178 KnownBits Known; 179 bool ShiftedOutZero, ShiftedOutOne; 180 Known.Zero = LHS.Zero.ushl_ov(ShiftAmt, ShiftedOutZero); 181 Known.Zero.setLowBits(ShiftAmt); 182 Known.One = LHS.One.ushl_ov(ShiftAmt, ShiftedOutOne); 183 184 // All cases returning poison have been handled by MaxShiftAmount already. 185 if (NSW) { 186 if (NUW && ShiftAmt != 0) 187 // NUW means we can assume anything shifted out was a zero. 188 ShiftedOutZero = true; 189 190 if (ShiftedOutZero) 191 Known.makeNonNegative(); 192 else if (ShiftedOutOne) 193 Known.makeNegative(); 194 } 195 return Known; 196 }; 197 198 // Fast path for a common case when LHS is completely unknown. 199 KnownBits Known(BitWidth); 200 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth); 201 if (MinShiftAmount == 0 && ShAmtNonZero) 202 MinShiftAmount = 1; 203 if (LHS.isUnknown()) { 204 Known.Zero.setLowBits(MinShiftAmount); 205 if (NUW && NSW && MinShiftAmount != 0) 206 Known.makeNonNegative(); 207 return Known; 208 } 209 210 // Determine maximum shift amount, taking NUW/NSW flags into account. 211 APInt MaxValue = RHS.getMaxValue(); 212 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); 213 if (NUW && NSW) 214 MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros() - 1); 215 if (NUW) 216 MaxShiftAmount = std::min(MaxShiftAmount, LHS.countMaxLeadingZeros()); 217 if (NSW) 218 MaxShiftAmount = std::min( 219 MaxShiftAmount, 220 std::max(LHS.countMaxLeadingZeros(), LHS.countMaxLeadingOnes()) - 1); 221 222 // Fast path for common case where the shift amount is unknown. 223 if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 && 224 isPowerOf2_32(BitWidth)) { 225 Known.Zero.setLowBits(LHS.countMinTrailingZeros()); 226 if (LHS.isAllOnes()) 227 Known.One.setSignBit(); 228 if (NSW) { 229 if (LHS.isNonNegative()) 230 Known.makeNonNegative(); 231 if (LHS.isNegative()) 232 Known.makeNegative(); 233 } 234 return Known; 235 } 236 237 // Find the common bits from all possible shifts. 238 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue(); 239 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue(); 240 Known.Zero.setAllBits(); 241 Known.One.setAllBits(); 242 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; 243 ++ShiftAmt) { 244 // Skip if the shift amount is impossible. 245 if ((ShiftAmtZeroMask & ShiftAmt) != 0 || 246 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) 247 continue; 248 Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt)); 249 if (Known.isUnknown()) 250 break; 251 } 252 253 // All shift amounts may result in poison. 254 if (Known.hasConflict()) 255 Known.setAllZero(); 256 return Known; 257 } 258 259 KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS, 260 bool ShAmtNonZero) { 261 unsigned BitWidth = LHS.getBitWidth(); 262 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { 263 KnownBits Known = LHS; 264 Known.Zero.lshrInPlace(ShiftAmt); 265 Known.One.lshrInPlace(ShiftAmt); 266 // High bits are known zero. 267 Known.Zero.setHighBits(ShiftAmt); 268 return Known; 269 }; 270 271 // Fast path for a common case when LHS is completely unknown. 272 KnownBits Known(BitWidth); 273 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth); 274 if (MinShiftAmount == 0 && ShAmtNonZero) 275 MinShiftAmount = 1; 276 if (LHS.isUnknown()) { 277 Known.Zero.setHighBits(MinShiftAmount); 278 return Known; 279 } 280 281 // Find the common bits from all possible shifts. 282 APInt MaxValue = RHS.getMaxValue(); 283 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); 284 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue(); 285 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue(); 286 Known.Zero.setAllBits(); 287 Known.One.setAllBits(); 288 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; 289 ++ShiftAmt) { 290 // Skip if the shift amount is impossible. 291 if ((ShiftAmtZeroMask & ShiftAmt) != 0 || 292 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) 293 continue; 294 Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt)); 295 if (Known.isUnknown()) 296 break; 297 } 298 299 // All shift amounts may result in poison. 300 if (Known.hasConflict()) 301 Known.setAllZero(); 302 return Known; 303 } 304 305 KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS, 306 bool ShAmtNonZero) { 307 unsigned BitWidth = LHS.getBitWidth(); 308 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { 309 KnownBits Known = LHS; 310 Known.Zero.ashrInPlace(ShiftAmt); 311 Known.One.ashrInPlace(ShiftAmt); 312 return Known; 313 }; 314 315 // Fast path for a common case when LHS is completely unknown. 316 KnownBits Known(BitWidth); 317 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(BitWidth); 318 if (MinShiftAmount == 0 && ShAmtNonZero) 319 MinShiftAmount = 1; 320 if (LHS.isUnknown()) { 321 if (MinShiftAmount == BitWidth) { 322 // Always poison. Return zero because we don't like returning conflict. 323 Known.setAllZero(); 324 return Known; 325 } 326 return Known; 327 } 328 329 // Find the common bits from all possible shifts. 330 APInt MaxValue = RHS.getMaxValue(); 331 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); 332 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(32).getZExtValue(); 333 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(32).getZExtValue(); 334 Known.Zero.setAllBits(); 335 Known.One.setAllBits(); 336 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; 337 ++ShiftAmt) { 338 // Skip if the shift amount is impossible. 339 if ((ShiftAmtZeroMask & ShiftAmt) != 0 || 340 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) 341 continue; 342 Known = Known.intersectWith(ShiftByConst(LHS, ShiftAmt)); 343 if (Known.isUnknown()) 344 break; 345 } 346 347 // All shift amounts may result in poison. 348 if (Known.hasConflict()) 349 Known.setAllZero(); 350 return Known; 351 } 352 353 std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) { 354 if (LHS.isConstant() && RHS.isConstant()) 355 return std::optional<bool>(LHS.getConstant() == RHS.getConstant()); 356 if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero)) 357 return std::optional<bool>(false); 358 return std::nullopt; 359 } 360 361 std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) { 362 if (std::optional<bool> KnownEQ = eq(LHS, RHS)) 363 return std::optional<bool>(!*KnownEQ); 364 return std::nullopt; 365 } 366 367 std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) { 368 // LHS >u RHS -> false if umax(LHS) <= umax(RHS) 369 if (LHS.getMaxValue().ule(RHS.getMinValue())) 370 return std::optional<bool>(false); 371 // LHS >u RHS -> true if umin(LHS) > umax(RHS) 372 if (LHS.getMinValue().ugt(RHS.getMaxValue())) 373 return std::optional<bool>(true); 374 return std::nullopt; 375 } 376 377 std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) { 378 if (std::optional<bool> IsUGT = ugt(RHS, LHS)) 379 return std::optional<bool>(!*IsUGT); 380 return std::nullopt; 381 } 382 383 std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) { 384 return ugt(RHS, LHS); 385 } 386 387 std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) { 388 return uge(RHS, LHS); 389 } 390 391 std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) { 392 // LHS >s RHS -> false if smax(LHS) <= smax(RHS) 393 if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue())) 394 return std::optional<bool>(false); 395 // LHS >s RHS -> true if smin(LHS) > smax(RHS) 396 if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue())) 397 return std::optional<bool>(true); 398 return std::nullopt; 399 } 400 401 std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) { 402 if (std::optional<bool> KnownSGT = sgt(RHS, LHS)) 403 return std::optional<bool>(!*KnownSGT); 404 return std::nullopt; 405 } 406 407 std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) { 408 return sgt(RHS, LHS); 409 } 410 411 std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) { 412 return sge(RHS, LHS); 413 } 414 415 KnownBits KnownBits::abs(bool IntMinIsPoison) const { 416 // If the source's MSB is zero then we know the rest of the bits already. 417 if (isNonNegative()) 418 return *this; 419 420 // Absolute value preserves trailing zero count. 421 KnownBits KnownAbs(getBitWidth()); 422 423 // If the input is negative, then abs(x) == -x. 424 if (isNegative()) { 425 KnownBits Tmp = *this; 426 // Special case for IntMinIsPoison. We know the sign bit is set and we know 427 // all the rest of the bits except one to be zero. Since we have 428 // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is 429 // INT_MIN. 430 if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth()) 431 Tmp.One.setBit(countMinTrailingZeros()); 432 433 KnownAbs = computeForAddSub( 434 /*Add*/ false, IntMinIsPoison, 435 KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp); 436 437 // One more special case for IntMinIsPoison. If we don't know any ones other 438 // than the signbit, we know for certain that all the unknowns can't be 439 // zero. So if we know high zero bits, but have unknown low bits, we know 440 // for certain those high-zero bits will end up as one. This is because, 441 // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up 442 // to the high bits. If we know a known INT_MIN input skip this. The result 443 // is poison anyways. 444 if (IntMinIsPoison && Tmp.countMinPopulation() == 1 && 445 Tmp.countMaxPopulation() != 1) { 446 Tmp.One.clearSignBit(); 447 Tmp.Zero.setSignBit(); 448 KnownAbs.One.setBits(getBitWidth() - Tmp.countMinLeadingZeros(), 449 getBitWidth() - 1); 450 } 451 452 } else { 453 unsigned MaxTZ = countMaxTrailingZeros(); 454 unsigned MinTZ = countMinTrailingZeros(); 455 456 KnownAbs.Zero.setLowBits(MinTZ); 457 // If we know the lowest set 1, then preserve it. 458 if (MaxTZ == MinTZ && MaxTZ < getBitWidth()) 459 KnownAbs.One.setBit(MaxTZ); 460 461 // We only know that the absolute values's MSB will be zero if INT_MIN is 462 // poison, or there is a set bit that isn't the sign bit (otherwise it could 463 // be INT_MIN). 464 if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) { 465 KnownAbs.One.clearSignBit(); 466 KnownAbs.Zero.setSignBit(); 467 } 468 } 469 470 assert(!KnownAbs.hasConflict() && "Bad Output"); 471 return KnownAbs; 472 } 473 474 static KnownBits computeForSatAddSub(bool Add, bool Signed, 475 const KnownBits &LHS, 476 const KnownBits &RHS) { 477 assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs"); 478 // We don't see NSW even for sadd/ssub as we want to check if the result has 479 // signed overflow. 480 KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS); 481 unsigned BitWidth = Res.getBitWidth(); 482 auto SignBitKnown = [&](const KnownBits &K) { 483 return K.Zero[BitWidth - 1] || K.One[BitWidth - 1]; 484 }; 485 std::optional<bool> Overflow; 486 487 if (Signed) { 488 // If we can actually detect overflow do so. Otherwise leave Overflow as 489 // nullopt (we assume it may have happened). 490 if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) { 491 if (Add) { 492 // sadd.sat 493 Overflow = (LHS.isNonNegative() == RHS.isNonNegative() && 494 Res.isNonNegative() != LHS.isNonNegative()); 495 } else { 496 // ssub.sat 497 Overflow = (LHS.isNonNegative() != RHS.isNonNegative() && 498 Res.isNonNegative() != LHS.isNonNegative()); 499 } 500 } 501 } else if (Add) { 502 // uadd.sat 503 bool Of; 504 (void)LHS.getMaxValue().uadd_ov(RHS.getMaxValue(), Of); 505 if (!Of) { 506 Overflow = false; 507 } else { 508 (void)LHS.getMinValue().uadd_ov(RHS.getMinValue(), Of); 509 if (Of) 510 Overflow = true; 511 } 512 } else { 513 // usub.sat 514 bool Of; 515 (void)LHS.getMinValue().usub_ov(RHS.getMaxValue(), Of); 516 if (!Of) { 517 Overflow = false; 518 } else { 519 (void)LHS.getMaxValue().usub_ov(RHS.getMinValue(), Of); 520 if (Of) 521 Overflow = true; 522 } 523 } 524 525 if (Signed) { 526 if (Add) { 527 if (LHS.isNonNegative() && RHS.isNonNegative()) { 528 // Pos + Pos -> Pos 529 Res.One.clearSignBit(); 530 Res.Zero.setSignBit(); 531 } 532 if (LHS.isNegative() && RHS.isNegative()) { 533 // Neg + Neg -> Neg 534 Res.One.setSignBit(); 535 Res.Zero.clearSignBit(); 536 } 537 } else { 538 if (LHS.isNegative() && RHS.isNonNegative()) { 539 // Neg - Pos -> Neg 540 Res.One.setSignBit(); 541 Res.Zero.clearSignBit(); 542 } else if (LHS.isNonNegative() && RHS.isNegative()) { 543 // Pos - Neg -> Pos 544 Res.One.clearSignBit(); 545 Res.Zero.setSignBit(); 546 } 547 } 548 } else { 549 // Add: Leading ones of either operand are preserved. 550 // Sub: Leading zeros of LHS and leading ones of RHS are preserved 551 // as leading zeros in the result. 552 unsigned LeadingKnown; 553 if (Add) 554 LeadingKnown = 555 std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes()); 556 else 557 LeadingKnown = 558 std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes()); 559 560 // We select between the operation result and all-ones/zero 561 // respectively, so we can preserve known ones/zeros. 562 APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown); 563 if (Add) { 564 Res.One |= Mask; 565 Res.Zero &= ~Mask; 566 } else { 567 Res.Zero |= Mask; 568 Res.One &= ~Mask; 569 } 570 } 571 572 if (Overflow) { 573 // We know whether or not we overflowed. 574 if (!(*Overflow)) { 575 // No overflow. 576 assert(!Res.hasConflict() && "Bad Output"); 577 return Res; 578 } 579 580 // We overflowed 581 APInt C; 582 if (Signed) { 583 // sadd.sat / ssub.sat 584 assert(SignBitKnown(LHS) && 585 "We somehow know overflow without knowing input sign"); 586 C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth) 587 : APInt::getSignedMaxValue(BitWidth); 588 } else if (Add) { 589 // uadd.sat 590 C = APInt::getMaxValue(BitWidth); 591 } else { 592 // uadd.sat 593 C = APInt::getMinValue(BitWidth); 594 } 595 596 Res.One = C; 597 Res.Zero = ~C; 598 assert(!Res.hasConflict() && "Bad Output"); 599 return Res; 600 } 601 602 // We don't know if we overflowed. 603 if (Signed) { 604 // sadd.sat/ssub.sat 605 // We can keep our information about the sign bits. 606 Res.Zero.clearLowBits(BitWidth - 1); 607 Res.One.clearLowBits(BitWidth - 1); 608 } else if (Add) { 609 // uadd.sat 610 // We need to clear all the known zeros as we can only use the leading ones. 611 Res.Zero.clearAllBits(); 612 } else { 613 // usub.sat 614 // We need to clear all the known ones as we can only use the leading zero. 615 Res.One.clearAllBits(); 616 } 617 618 assert(!Res.hasConflict() && "Bad Output"); 619 return Res; 620 } 621 622 KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) { 623 return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS); 624 } 625 KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) { 626 return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS); 627 } 628 KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) { 629 return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS); 630 } 631 KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) { 632 return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS); 633 } 634 635 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS, 636 bool NoUndefSelfMultiply) { 637 unsigned BitWidth = LHS.getBitWidth(); 638 assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() && 639 !RHS.hasConflict() && "Operand mismatch"); 640 assert((!NoUndefSelfMultiply || LHS == RHS) && 641 "Self multiplication knownbits mismatch"); 642 643 // Compute the high known-0 bits by multiplying the unsigned max of each side. 644 // Conservatively, M active bits * N active bits results in M + N bits in the 645 // result. But if we know a value is a power-of-2 for example, then this 646 // computes one more leading zero. 647 // TODO: This could be generalized to number of sign bits (negative numbers). 648 APInt UMaxLHS = LHS.getMaxValue(); 649 APInt UMaxRHS = RHS.getMaxValue(); 650 651 // For leading zeros in the result to be valid, the unsigned max product must 652 // fit in the bitwidth (it must not overflow). 653 bool HasOverflow; 654 APInt UMaxResult = UMaxLHS.umul_ov(UMaxRHS, HasOverflow); 655 unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero(); 656 657 // The result of the bottom bits of an integer multiply can be 658 // inferred by looking at the bottom bits of both operands and 659 // multiplying them together. 660 // We can infer at least the minimum number of known trailing bits 661 // of both operands. Depending on number of trailing zeros, we can 662 // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming 663 // a and b are divisible by m and n respectively. 664 // We then calculate how many of those bits are inferrable and set 665 // the output. For example, the i8 mul: 666 // a = XXXX1100 (12) 667 // b = XXXX1110 (14) 668 // We know the bottom 3 bits are zero since the first can be divided by 669 // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4). 670 // Applying the multiplication to the trimmed arguments gets: 671 // XX11 (3) 672 // X111 (7) 673 // ------- 674 // XX11 675 // XX11 676 // XX11 677 // XX11 678 // ------- 679 // XXXXX01 680 // Which allows us to infer the 2 LSBs. Since we're multiplying the result 681 // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits. 682 // The proof for this can be described as: 683 // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) && 684 // (C7 == (1 << (umin(countTrailingZeros(C1), C5) + 685 // umin(countTrailingZeros(C2), C6) + 686 // umin(C5 - umin(countTrailingZeros(C1), C5), 687 // C6 - umin(countTrailingZeros(C2), C6)))) - 1) 688 // %aa = shl i8 %a, C5 689 // %bb = shl i8 %b, C6 690 // %aaa = or i8 %aa, C1 691 // %bbb = or i8 %bb, C2 692 // %mul = mul i8 %aaa, %bbb 693 // %mask = and i8 %mul, C7 694 // => 695 // %mask = i8 ((C1*C2)&C7) 696 // Where C5, C6 describe the known bits of %a, %b 697 // C1, C2 describe the known bottom bits of %a, %b. 698 // C7 describes the mask of the known bits of the result. 699 const APInt &Bottom0 = LHS.One; 700 const APInt &Bottom1 = RHS.One; 701 702 // How many times we'd be able to divide each argument by 2 (shr by 1). 703 // This gives us the number of trailing zeros on the multiplication result. 704 unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countr_one(); 705 unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countr_one(); 706 unsigned TrailZero0 = LHS.countMinTrailingZeros(); 707 unsigned TrailZero1 = RHS.countMinTrailingZeros(); 708 unsigned TrailZ = TrailZero0 + TrailZero1; 709 710 // Figure out the fewest known-bits operand. 711 unsigned SmallestOperand = 712 std::min(TrailBitsKnown0 - TrailZero0, TrailBitsKnown1 - TrailZero1); 713 unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth); 714 715 APInt BottomKnown = 716 Bottom0.getLoBits(TrailBitsKnown0) * Bottom1.getLoBits(TrailBitsKnown1); 717 718 KnownBits Res(BitWidth); 719 Res.Zero.setHighBits(LeadZ); 720 Res.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); 721 Res.One = BottomKnown.getLoBits(ResultBitsKnown); 722 723 // If we're self-multiplying then bit[1] is guaranteed to be zero. 724 if (NoUndefSelfMultiply && BitWidth > 1) { 725 assert(Res.One[1] == 0 && 726 "Self-multiplication failed Quadratic Reciprocity!"); 727 Res.Zero.setBit(1); 728 } 729 730 return Res; 731 } 732 733 KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) { 734 unsigned BitWidth = LHS.getBitWidth(); 735 assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() && 736 !RHS.hasConflict() && "Operand mismatch"); 737 KnownBits WideLHS = LHS.sext(2 * BitWidth); 738 KnownBits WideRHS = RHS.sext(2 * BitWidth); 739 return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth); 740 } 741 742 KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) { 743 unsigned BitWidth = LHS.getBitWidth(); 744 assert(BitWidth == RHS.getBitWidth() && !LHS.hasConflict() && 745 !RHS.hasConflict() && "Operand mismatch"); 746 KnownBits WideLHS = LHS.zext(2 * BitWidth); 747 KnownBits WideRHS = RHS.zext(2 * BitWidth); 748 return mul(WideLHS, WideRHS).extractBits(BitWidth, BitWidth); 749 } 750 751 static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS, 752 const KnownBits &RHS, bool Exact) { 753 754 if (!Exact) 755 return Known; 756 757 // If LHS is Odd, the result is Odd no matter what. 758 // Odd / Odd -> Odd 759 // Odd / Even -> Impossible (because its exact division) 760 if (LHS.One[0]) 761 Known.One.setBit(0); 762 763 int MinTZ = 764 (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros(); 765 int MaxTZ = 766 (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros(); 767 if (MinTZ >= 0) { 768 // Result has at least MinTZ trailing zeros. 769 Known.Zero.setLowBits(MinTZ); 770 if (MinTZ == MaxTZ) { 771 // Result has exactly MinTZ trailing zeros. 772 Known.One.setBit(MinTZ); 773 } 774 } else if (MaxTZ < 0) { 775 // Poison Result 776 Known.setAllZero(); 777 } 778 779 // In the KnownBits exhaustive tests, we have poison inputs for exact values 780 // a LOT. If we have a conflict, just return all zeros. 781 if (Known.hasConflict()) 782 Known.setAllZero(); 783 784 return Known; 785 } 786 787 KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS, 788 bool Exact) { 789 // Equivalent of `udiv`. We must have caught this before it was folded. 790 if (LHS.isNonNegative() && RHS.isNonNegative()) 791 return udiv(LHS, RHS, Exact); 792 793 unsigned BitWidth = LHS.getBitWidth(); 794 assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs"); 795 KnownBits Known(BitWidth); 796 797 if (LHS.isZero() || RHS.isZero()) { 798 // Result is either known Zero or UB. Return Zero either way. 799 // Checking this earlier saves us a lot of special cases later on. 800 Known.setAllZero(); 801 return Known; 802 } 803 804 std::optional<APInt> Res; 805 if (LHS.isNegative() && RHS.isNegative()) { 806 // Result non-negative. 807 APInt Denom = RHS.getSignedMaxValue(); 808 APInt Num = LHS.getSignedMinValue(); 809 // INT_MIN/-1 would be a poison result (impossible). Estimate the division 810 // as signed max (we will only set sign bit in the result). 811 Res = (Num.isMinSignedValue() && Denom.isAllOnes()) 812 ? APInt::getSignedMaxValue(BitWidth) 813 : Num.sdiv(Denom); 814 } else if (LHS.isNegative() && RHS.isNonNegative()) { 815 // Result is negative if Exact OR -LHS u>= RHS. 816 if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) { 817 APInt Denom = RHS.getSignedMinValue(); 818 APInt Num = LHS.getSignedMinValue(); 819 Res = Denom.isZero() ? Num : Num.sdiv(Denom); 820 } 821 } else if (LHS.isStrictlyPositive() && RHS.isNegative()) { 822 // Result is negative if Exact OR LHS u>= -RHS. 823 if (Exact || LHS.getSignedMinValue().uge(-RHS.getSignedMinValue())) { 824 APInt Denom = RHS.getSignedMaxValue(); 825 APInt Num = LHS.getSignedMaxValue(); 826 Res = Num.sdiv(Denom); 827 } 828 } 829 830 if (Res) { 831 if (Res->isNonNegative()) { 832 unsigned LeadZ = Res->countLeadingZeros(); 833 Known.Zero.setHighBits(LeadZ); 834 } else { 835 unsigned LeadO = Res->countLeadingOnes(); 836 Known.One.setHighBits(LeadO); 837 } 838 } 839 840 Known = divComputeLowBit(Known, LHS, RHS, Exact); 841 842 assert(!Known.hasConflict() && "Bad Output"); 843 return Known; 844 } 845 846 KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS, 847 bool Exact) { 848 unsigned BitWidth = LHS.getBitWidth(); 849 assert(!LHS.hasConflict() && !RHS.hasConflict()); 850 KnownBits Known(BitWidth); 851 852 if (LHS.isZero() || RHS.isZero()) { 853 // Result is either known Zero or UB. Return Zero either way. 854 // Checking this earlier saves us a lot of special cases later on. 855 Known.setAllZero(); 856 return Known; 857 } 858 859 // We can figure out the minimum number of upper zero bits by doing 860 // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator 861 // gets larger, the number of upper zero bits increases. 862 APInt MinDenom = RHS.getMinValue(); 863 APInt MaxNum = LHS.getMaxValue(); 864 APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(MinDenom); 865 866 unsigned LeadZ = MaxRes.countLeadingZeros(); 867 868 Known.Zero.setHighBits(LeadZ); 869 Known = divComputeLowBit(Known, LHS, RHS, Exact); 870 871 assert(!Known.hasConflict() && "Bad Output"); 872 return Known; 873 } 874 875 KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) { 876 unsigned BitWidth = LHS.getBitWidth(); 877 if (!RHS.isZero() && RHS.Zero[0]) { 878 // rem X, Y where Y[0:N] is zero will preserve X[0:N] in the result. 879 unsigned RHSZeros = RHS.countMinTrailingZeros(); 880 APInt Mask = APInt::getLowBitsSet(BitWidth, RHSZeros); 881 APInt OnesMask = LHS.One & Mask; 882 APInt ZerosMask = LHS.Zero & Mask; 883 return KnownBits(ZerosMask, OnesMask); 884 } 885 return KnownBits(BitWidth); 886 } 887 888 KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) { 889 assert(!LHS.hasConflict() && !RHS.hasConflict()); 890 891 KnownBits Known = remGetLowBits(LHS, RHS); 892 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { 893 // NB: Low bits set in `remGetLowBits`. 894 APInt HighBits = ~(RHS.getConstant() - 1); 895 Known.Zero |= HighBits; 896 return Known; 897 } 898 899 // Since the result is less than or equal to either operand, any leading 900 // zero bits in either operand must also exist in the result. 901 uint32_t Leaders = 902 std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingZeros()); 903 Known.Zero.setHighBits(Leaders); 904 return Known; 905 } 906 907 KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) { 908 assert(!LHS.hasConflict() && !RHS.hasConflict()); 909 910 KnownBits Known = remGetLowBits(LHS, RHS); 911 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { 912 // NB: Low bits are set in `remGetLowBits`. 913 APInt LowBits = RHS.getConstant() - 1; 914 // If the first operand is non-negative or has all low bits zero, then 915 // the upper bits are all zero. 916 if (LHS.isNonNegative() || LowBits.isSubsetOf(LHS.Zero)) 917 Known.Zero |= ~LowBits; 918 919 // If the first operand is negative and not all low bits are zero, then 920 // the upper bits are all one. 921 if (LHS.isNegative() && LowBits.intersects(LHS.One)) 922 Known.One |= ~LowBits; 923 return Known; 924 } 925 926 // The sign bit is the LHS's sign bit, except when the result of the 927 // remainder is zero. The magnitude of the result should be less than or 928 // equal to the magnitude of the LHS. Therefore any leading zeros that exist 929 // in the left hand side must also exist in the result. 930 Known.Zero.setHighBits(LHS.countMinLeadingZeros()); 931 return Known; 932 } 933 934 KnownBits &KnownBits::operator&=(const KnownBits &RHS) { 935 // Result bit is 0 if either operand bit is 0. 936 Zero |= RHS.Zero; 937 // Result bit is 1 if both operand bits are 1. 938 One &= RHS.One; 939 return *this; 940 } 941 942 KnownBits &KnownBits::operator|=(const KnownBits &RHS) { 943 // Result bit is 0 if both operand bits are 0. 944 Zero &= RHS.Zero; 945 // Result bit is 1 if either operand bit is 1. 946 One |= RHS.One; 947 return *this; 948 } 949 950 KnownBits &KnownBits::operator^=(const KnownBits &RHS) { 951 // Result bit is 0 if both operand bits are 0 or both are 1. 952 APInt Z = (Zero & RHS.Zero) | (One & RHS.One); 953 // Result bit is 1 if one operand bit is 0 and the other is 1. 954 One = (Zero & RHS.One) | (One & RHS.Zero); 955 Zero = std::move(Z); 956 return *this; 957 } 958 959 KnownBits KnownBits::blsi() const { 960 unsigned BitWidth = getBitWidth(); 961 KnownBits Known(Zero, APInt(BitWidth, 0)); 962 unsigned Max = countMaxTrailingZeros(); 963 Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth)); 964 unsigned Min = countMinTrailingZeros(); 965 if (Max == Min && Max < BitWidth) 966 Known.One.setBit(Max); 967 return Known; 968 } 969 970 KnownBits KnownBits::blsmsk() const { 971 unsigned BitWidth = getBitWidth(); 972 KnownBits Known(BitWidth); 973 unsigned Max = countMaxTrailingZeros(); 974 Known.Zero.setBitsFrom(std::min(Max + 1, BitWidth)); 975 unsigned Min = countMinTrailingZeros(); 976 Known.One.setLowBits(std::min(Min + 1, BitWidth)); 977 return Known; 978 } 979 980 void KnownBits::print(raw_ostream &OS) const { 981 unsigned BitWidth = getBitWidth(); 982 for (unsigned I = 0; I < BitWidth; ++I) { 983 unsigned N = BitWidth - I - 1; 984 if (Zero[N] && One[N]) 985 OS << "!"; 986 else if (Zero[N]) 987 OS << "0"; 988 else if (One[N]) 989 OS << "1"; 990 else 991 OS << "?"; 992 } 993 } 994 void KnownBits::dump() const { 995 print(dbgs()); 996 dbgs() << "\n"; 997 } 998