xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Support/KnownBits.h (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- llvm/Support/KnownBits.h - Stores known zeros/ones -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a class for representing known zeros and ones used by
10 // computeKnownBits.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_SUPPORT_KNOWNBITS_H
15 #define LLVM_SUPPORT_KNOWNBITS_H
16 
17 #include "llvm/ADT/APInt.h"
18 #include <optional>
19 
20 namespace llvm {
21 
22 // Struct for tracking the known zeros and ones of a value.
23 struct KnownBits {
24   APInt Zero;
25   APInt One;
26 
27 private:
28   // Internal constructor for creating a KnownBits from two APInts.
KnownBitsKnownBits29   KnownBits(APInt Zero, APInt One)
30       : Zero(std::move(Zero)), One(std::move(One)) {}
31 
32 public:
33   // Default construct Zero and One.
34   KnownBits() = default;
35 
36   /// Create a known bits object of BitWidth bits initialized to unknown.
KnownBitsKnownBits37   KnownBits(unsigned BitWidth) : Zero(BitWidth, 0), One(BitWidth, 0) {}
38 
39   /// Get the bit width of this value.
getBitWidthKnownBits40   unsigned getBitWidth() const {
41     assert(Zero.getBitWidth() == One.getBitWidth() &&
42            "Zero and One should have the same width!");
43     return Zero.getBitWidth();
44   }
45 
46   /// Returns true if there is conflicting information.
hasConflictKnownBits47   bool hasConflict() const { return Zero.intersects(One); }
48 
49   /// Returns true if we know the value of all bits.
isConstantKnownBits50   bool isConstant() const {
51     return Zero.popcount() + One.popcount() == getBitWidth();
52   }
53 
54   /// Returns the value when all bits have a known value. This just returns One
55   /// with a protective assertion.
getConstantKnownBits56   const APInt &getConstant() const {
57     assert(isConstant() && "Can only get value when all bits are known");
58     return One;
59   }
60 
61   /// Returns true if we don't know any bits.
isUnknownKnownBits62   bool isUnknown() const { return Zero.isZero() && One.isZero(); }
63 
64   /// Returns true if we don't know the sign bit.
isSignUnknownKnownBits65   bool isSignUnknown() const {
66     return !Zero.isSignBitSet() && !One.isSignBitSet();
67   }
68 
69   /// Resets the known state of all bits.
resetAllKnownBits70   void resetAll() {
71     Zero.clearAllBits();
72     One.clearAllBits();
73   }
74 
75   /// Returns true if value is all zero.
isZeroKnownBits76   bool isZero() const { return Zero.isAllOnes(); }
77 
78   /// Returns true if value is all one bits.
isAllOnesKnownBits79   bool isAllOnes() const { return One.isAllOnes(); }
80 
81   /// Make all bits known to be zero and discard any previous information.
setAllZeroKnownBits82   void setAllZero() {
83     Zero.setAllBits();
84     One.clearAllBits();
85   }
86 
87   /// Make all bits known to be one and discard any previous information.
setAllOnesKnownBits88   void setAllOnes() {
89     Zero.clearAllBits();
90     One.setAllBits();
91   }
92 
93   /// Returns true if this value is known to be negative.
isNegativeKnownBits94   bool isNegative() const { return One.isSignBitSet(); }
95 
96   /// Returns true if this value is known to be non-negative.
isNonNegativeKnownBits97   bool isNonNegative() const { return Zero.isSignBitSet(); }
98 
99   /// Returns true if this value is known to be non-zero.
isNonZeroKnownBits100   bool isNonZero() const { return !One.isZero(); }
101 
102   /// Returns true if this value is known to be positive.
isStrictlyPositiveKnownBits103   bool isStrictlyPositive() const {
104     return Zero.isSignBitSet() && !One.isZero();
105   }
106 
107   /// Make this value negative.
makeNegativeKnownBits108   void makeNegative() {
109     One.setSignBit();
110   }
111 
112   /// Make this value non-negative.
makeNonNegativeKnownBits113   void makeNonNegative() {
114     Zero.setSignBit();
115   }
116 
117   /// Return the minimal unsigned value possible given these KnownBits.
getMinValueKnownBits118   APInt getMinValue() const {
119     // Assume that all bits that aren't known-ones are zeros.
120     return One;
121   }
122 
123   /// Return the minimal signed value possible given these KnownBits.
getSignedMinValueKnownBits124   APInt getSignedMinValue() const {
125     // Assume that all bits that aren't known-ones are zeros.
126     APInt Min = One;
127     // Sign bit is unknown.
128     if (Zero.isSignBitClear())
129       Min.setSignBit();
130     return Min;
131   }
132 
133   /// Return the maximal unsigned value possible given these KnownBits.
getMaxValueKnownBits134   APInt getMaxValue() const {
135     // Assume that all bits that aren't known-zeros are ones.
136     return ~Zero;
137   }
138 
139   /// Return the maximal signed value possible given these KnownBits.
getSignedMaxValueKnownBits140   APInt getSignedMaxValue() const {
141     // Assume that all bits that aren't known-zeros are ones.
142     APInt Max = ~Zero;
143     // Sign bit is unknown.
144     if (One.isSignBitClear())
145       Max.clearSignBit();
146     return Max;
147   }
148 
149   /// Return known bits for a truncation of the value we're tracking.
truncKnownBits150   KnownBits trunc(unsigned BitWidth) const {
151     return KnownBits(Zero.trunc(BitWidth), One.trunc(BitWidth));
152   }
153 
154   /// Return known bits for an "any" extension of the value we're tracking,
155   /// where we don't know anything about the extended bits.
anyextKnownBits156   KnownBits anyext(unsigned BitWidth) const {
157     return KnownBits(Zero.zext(BitWidth), One.zext(BitWidth));
158   }
159 
160   /// Return known bits for a zero extension of the value we're tracking.
zextKnownBits161   KnownBits zext(unsigned BitWidth) const {
162     unsigned OldBitWidth = getBitWidth();
163     APInt NewZero = Zero.zext(BitWidth);
164     NewZero.setBitsFrom(OldBitWidth);
165     return KnownBits(NewZero, One.zext(BitWidth));
166   }
167 
168   /// Return known bits for a sign extension of the value we're tracking.
sextKnownBits169   KnownBits sext(unsigned BitWidth) const {
170     return KnownBits(Zero.sext(BitWidth), One.sext(BitWidth));
171   }
172 
173   /// Return known bits for an "any" extension or truncation of the value we're
174   /// tracking.
anyextOrTruncKnownBits175   KnownBits anyextOrTrunc(unsigned BitWidth) const {
176     if (BitWidth > getBitWidth())
177       return anyext(BitWidth);
178     if (BitWidth < getBitWidth())
179       return trunc(BitWidth);
180     return *this;
181   }
182 
183   /// Return known bits for a zero extension or truncation of the value we're
184   /// tracking.
zextOrTruncKnownBits185   KnownBits zextOrTrunc(unsigned BitWidth) const {
186     if (BitWidth > getBitWidth())
187       return zext(BitWidth);
188     if (BitWidth < getBitWidth())
189       return trunc(BitWidth);
190     return *this;
191   }
192 
193   /// Return known bits for a sign extension or truncation of the value we're
194   /// tracking.
sextOrTruncKnownBits195   KnownBits sextOrTrunc(unsigned BitWidth) const {
196     if (BitWidth > getBitWidth())
197       return sext(BitWidth);
198     if (BitWidth < getBitWidth())
199       return trunc(BitWidth);
200     return *this;
201   }
202 
203   /// Return known bits for a in-register sign extension of the value we're
204   /// tracking.
205   KnownBits sextInReg(unsigned SrcBitWidth) const;
206 
207   /// Insert the bits from a smaller known bits starting at bitPosition.
insertBitsKnownBits208   void insertBits(const KnownBits &SubBits, unsigned BitPosition) {
209     Zero.insertBits(SubBits.Zero, BitPosition);
210     One.insertBits(SubBits.One, BitPosition);
211   }
212 
213   /// Return a subset of the known bits from [bitPosition,bitPosition+numBits).
extractBitsKnownBits214   KnownBits extractBits(unsigned NumBits, unsigned BitPosition) const {
215     return KnownBits(Zero.extractBits(NumBits, BitPosition),
216                      One.extractBits(NumBits, BitPosition));
217   }
218 
219   /// Concatenate the bits from \p Lo onto the bottom of *this.  This is
220   /// equivalent to:
221   ///   (this->zext(NewWidth) << Lo.getBitWidth()) | Lo.zext(NewWidth)
concatKnownBits222   KnownBits concat(const KnownBits &Lo) const {
223     return KnownBits(Zero.concat(Lo.Zero), One.concat(Lo.One));
224   }
225 
226   /// Return KnownBits based on this, but updated given that the underlying
227   /// value is known to be greater than or equal to Val.
228   KnownBits makeGE(const APInt &Val) const;
229 
230   /// Returns the minimum number of trailing zero bits.
countMinTrailingZerosKnownBits231   unsigned countMinTrailingZeros() const { return Zero.countr_one(); }
232 
233   /// Returns the minimum number of trailing one bits.
countMinTrailingOnesKnownBits234   unsigned countMinTrailingOnes() const { return One.countr_one(); }
235 
236   /// Returns the minimum number of leading zero bits.
countMinLeadingZerosKnownBits237   unsigned countMinLeadingZeros() const { return Zero.countl_one(); }
238 
239   /// Returns the minimum number of leading one bits.
countMinLeadingOnesKnownBits240   unsigned countMinLeadingOnes() const { return One.countl_one(); }
241 
242   /// Returns the number of times the sign bit is replicated into the other
243   /// bits.
countMinSignBitsKnownBits244   unsigned countMinSignBits() const {
245     if (isNonNegative())
246       return countMinLeadingZeros();
247     if (isNegative())
248       return countMinLeadingOnes();
249     // Every value has at least 1 sign bit.
250     return 1;
251   }
252 
253   /// Returns the maximum number of bits needed to represent all possible
254   /// signed values with these known bits. This is the inverse of the minimum
255   /// number of known sign bits. Examples for bitwidth 5:
256   /// 110?? --> 4
257   /// 0000? --> 2
countMaxSignificantBitsKnownBits258   unsigned countMaxSignificantBits() const {
259     return getBitWidth() - countMinSignBits() + 1;
260   }
261 
262   /// Returns the maximum number of trailing zero bits possible.
countMaxTrailingZerosKnownBits263   unsigned countMaxTrailingZeros() const { return One.countr_zero(); }
264 
265   /// Returns the maximum number of trailing one bits possible.
countMaxTrailingOnesKnownBits266   unsigned countMaxTrailingOnes() const { return Zero.countr_zero(); }
267 
268   /// Returns the maximum number of leading zero bits possible.
countMaxLeadingZerosKnownBits269   unsigned countMaxLeadingZeros() const { return One.countl_zero(); }
270 
271   /// Returns the maximum number of leading one bits possible.
countMaxLeadingOnesKnownBits272   unsigned countMaxLeadingOnes() const { return Zero.countl_zero(); }
273 
274   /// Returns the number of bits known to be one.
countMinPopulationKnownBits275   unsigned countMinPopulation() const { return One.popcount(); }
276 
277   /// Returns the maximum number of bits that could be one.
countMaxPopulationKnownBits278   unsigned countMaxPopulation() const {
279     return getBitWidth() - Zero.popcount();
280   }
281 
282   /// Returns the maximum number of bits needed to represent all possible
283   /// unsigned values with these known bits. This is the inverse of the
284   /// minimum number of leading zeros.
countMaxActiveBitsKnownBits285   unsigned countMaxActiveBits() const {
286     return getBitWidth() - countMinLeadingZeros();
287   }
288 
289   /// Create known bits from a known constant.
makeConstantKnownBits290   static KnownBits makeConstant(const APInt &C) {
291     return KnownBits(~C, C);
292   }
293 
294   /// Returns KnownBits information that is known to be true for both this and
295   /// RHS.
296   ///
297   /// When an operation is known to return one of its operands, this can be used
298   /// to combine information about the known bits of the operands to get the
299   /// information that must be true about the result.
intersectWithKnownBits300   KnownBits intersectWith(const KnownBits &RHS) const {
301     return KnownBits(Zero & RHS.Zero, One & RHS.One);
302   }
303 
304   /// Returns KnownBits information that is known to be true for either this or
305   /// RHS or both.
306   ///
307   /// This can be used to combine different sources of information about the
308   /// known bits of a single value, e.g. information about the low bits and the
309   /// high bits of the result of a multiplication.
unionWithKnownBits310   KnownBits unionWith(const KnownBits &RHS) const {
311     return KnownBits(Zero | RHS.Zero, One | RHS.One);
312   }
313 
314   /// Return true if LHS and RHS have no common bits set.
haveNoCommonBitsSetKnownBits315   static bool haveNoCommonBitsSet(const KnownBits &LHS, const KnownBits &RHS) {
316     return (LHS.Zero | RHS.Zero).isAllOnes();
317   }
318 
319   /// Compute known bits resulting from adding LHS, RHS and a 1-bit Carry.
320   static KnownBits computeForAddCarry(
321       const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry);
322 
323   /// Compute known bits resulting from adding LHS and RHS.
324   static KnownBits computeForAddSub(bool Add, bool NSW, bool NUW,
325                                     const KnownBits &LHS, const KnownBits &RHS);
326 
327   /// Compute known bits results from subtracting RHS from LHS with 1-bit
328   /// Borrow.
329   static KnownBits computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
330                                        const KnownBits &Borrow);
331 
332   /// Compute knownbits resulting from llvm.sadd.sat(LHS, RHS)
333   static KnownBits sadd_sat(const KnownBits &LHS, const KnownBits &RHS);
334 
335   /// Compute knownbits resulting from llvm.uadd.sat(LHS, RHS)
336   static KnownBits uadd_sat(const KnownBits &LHS, const KnownBits &RHS);
337 
338   /// Compute knownbits resulting from llvm.ssub.sat(LHS, RHS)
339   static KnownBits ssub_sat(const KnownBits &LHS, const KnownBits &RHS);
340 
341   /// Compute knownbits resulting from llvm.usub.sat(LHS, RHS)
342   static KnownBits usub_sat(const KnownBits &LHS, const KnownBits &RHS);
343 
344   /// Compute knownbits resulting from APIntOps::avgFloorS
345   static KnownBits avgFloorS(const KnownBits &LHS, const KnownBits &RHS);
346 
347   /// Compute knownbits resulting from APIntOps::avgFloorU
348   static KnownBits avgFloorU(const KnownBits &LHS, const KnownBits &RHS);
349 
350   /// Compute knownbits resulting from APIntOps::avgCeilS
351   static KnownBits avgCeilS(const KnownBits &LHS, const KnownBits &RHS);
352 
353   /// Compute knownbits resulting from APIntOps::avgCeilU
354   static KnownBits avgCeilU(const KnownBits &LHS, const KnownBits &RHS);
355 
356   /// Compute known bits resulting from multiplying LHS and RHS.
357   static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS,
358                        bool NoUndefSelfMultiply = false);
359 
360   /// Compute known bits from sign-extended multiply-hi.
361   static KnownBits mulhs(const KnownBits &LHS, const KnownBits &RHS);
362 
363   /// Compute known bits from zero-extended multiply-hi.
364   static KnownBits mulhu(const KnownBits &LHS, const KnownBits &RHS);
365 
366   /// Compute known bits for sdiv(LHS, RHS).
367   static KnownBits sdiv(const KnownBits &LHS, const KnownBits &RHS,
368                         bool Exact = false);
369 
370   /// Compute known bits for udiv(LHS, RHS).
371   static KnownBits udiv(const KnownBits &LHS, const KnownBits &RHS,
372                         bool Exact = false);
373 
374   /// Compute known bits for urem(LHS, RHS).
375   static KnownBits urem(const KnownBits &LHS, const KnownBits &RHS);
376 
377   /// Compute known bits for srem(LHS, RHS).
378   static KnownBits srem(const KnownBits &LHS, const KnownBits &RHS);
379 
380   /// Compute known bits for umax(LHS, RHS).
381   static KnownBits umax(const KnownBits &LHS, const KnownBits &RHS);
382 
383   /// Compute known bits for umin(LHS, RHS).
384   static KnownBits umin(const KnownBits &LHS, const KnownBits &RHS);
385 
386   /// Compute known bits for smax(LHS, RHS).
387   static KnownBits smax(const KnownBits &LHS, const KnownBits &RHS);
388 
389   /// Compute known bits for smin(LHS, RHS).
390   static KnownBits smin(const KnownBits &LHS, const KnownBits &RHS);
391 
392   /// Compute known bits for abdu(LHS, RHS).
393   static KnownBits abdu(const KnownBits &LHS, const KnownBits &RHS);
394 
395   /// Compute known bits for abds(LHS, RHS).
396   static KnownBits abds(KnownBits LHS, KnownBits RHS);
397 
398   /// Compute known bits for shl(LHS, RHS).
399   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
400   static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS,
401                        bool NUW = false, bool NSW = false,
402                        bool ShAmtNonZero = false);
403 
404   /// Compute known bits for lshr(LHS, RHS).
405   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
406   static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS,
407                         bool ShAmtNonZero = false, bool Exact = false);
408 
409   /// Compute known bits for ashr(LHS, RHS).
410   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
411   static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS,
412                         bool ShAmtNonZero = false, bool Exact = false);
413 
414   /// Determine if these known bits always give the same ICMP_EQ result.
415   static std::optional<bool> eq(const KnownBits &LHS, const KnownBits &RHS);
416 
417   /// Determine if these known bits always give the same ICMP_NE result.
418   static std::optional<bool> ne(const KnownBits &LHS, const KnownBits &RHS);
419 
420   /// Determine if these known bits always give the same ICMP_UGT result.
421   static std::optional<bool> ugt(const KnownBits &LHS, const KnownBits &RHS);
422 
423   /// Determine if these known bits always give the same ICMP_UGE result.
424   static std::optional<bool> uge(const KnownBits &LHS, const KnownBits &RHS);
425 
426   /// Determine if these known bits always give the same ICMP_ULT result.
427   static std::optional<bool> ult(const KnownBits &LHS, const KnownBits &RHS);
428 
429   /// Determine if these known bits always give the same ICMP_ULE result.
430   static std::optional<bool> ule(const KnownBits &LHS, const KnownBits &RHS);
431 
432   /// Determine if these known bits always give the same ICMP_SGT result.
433   static std::optional<bool> sgt(const KnownBits &LHS, const KnownBits &RHS);
434 
435   /// Determine if these known bits always give the same ICMP_SGE result.
436   static std::optional<bool> sge(const KnownBits &LHS, const KnownBits &RHS);
437 
438   /// Determine if these known bits always give the same ICMP_SLT result.
439   static std::optional<bool> slt(const KnownBits &LHS, const KnownBits &RHS);
440 
441   /// Determine if these known bits always give the same ICMP_SLE result.
442   static std::optional<bool> sle(const KnownBits &LHS, const KnownBits &RHS);
443 
444   /// Update known bits based on ANDing with RHS.
445   KnownBits &operator&=(const KnownBits &RHS);
446 
447   /// Update known bits based on ORing with RHS.
448   KnownBits &operator|=(const KnownBits &RHS);
449 
450   /// Update known bits based on XORing with RHS.
451   KnownBits &operator^=(const KnownBits &RHS);
452 
453   /// Compute known bits for the absolute value.
454   KnownBits abs(bool IntMinIsPoison = false) const;
455 
byteSwapKnownBits456   KnownBits byteSwap() const {
457     return KnownBits(Zero.byteSwap(), One.byteSwap());
458   }
459 
reverseBitsKnownBits460   KnownBits reverseBits() const {
461     return KnownBits(Zero.reverseBits(), One.reverseBits());
462   }
463 
464   /// Compute known bits for X & -X, which has only the lowest bit set of X set.
465   /// The name comes from the X86 BMI instruction
466   KnownBits blsi() const;
467 
468   /// Compute known bits for X ^ (X - 1), which has all bits up to and including
469   /// the lowest set bit of X set. The name comes from the X86 BMI instruction.
470   KnownBits blsmsk() const;
471 
472   bool operator==(const KnownBits &Other) const {
473     return Zero == Other.Zero && One == Other.One;
474   }
475 
476   bool operator!=(const KnownBits &Other) const { return !(*this == Other); }
477 
478   void print(raw_ostream &OS) const;
479   void dump() const;
480 
481 private:
482   // Internal helper for getting the initial KnownBits for an `srem` or `urem`
483   // operation with the low-bits set.
484   static KnownBits remGetLowBits(const KnownBits &LHS, const KnownBits &RHS);
485 };
486 
487 inline KnownBits operator&(KnownBits LHS, const KnownBits &RHS) {
488   LHS &= RHS;
489   return LHS;
490 }
491 
492 inline KnownBits operator&(const KnownBits &LHS, KnownBits &&RHS) {
493   RHS &= LHS;
494   return std::move(RHS);
495 }
496 
497 inline KnownBits operator|(KnownBits LHS, const KnownBits &RHS) {
498   LHS |= RHS;
499   return LHS;
500 }
501 
502 inline KnownBits operator|(const KnownBits &LHS, KnownBits &&RHS) {
503   RHS |= LHS;
504   return std::move(RHS);
505 }
506 
507 inline KnownBits operator^(KnownBits LHS, const KnownBits &RHS) {
508   LHS ^= RHS;
509   return LHS;
510 }
511 
512 inline KnownBits operator^(const KnownBits &LHS, KnownBits &&RHS) {
513   RHS ^= LHS;
514   return std::move(RHS);
515 }
516 
517 inline raw_ostream &operator<<(raw_ostream &OS, const KnownBits &Known) {
518   Known.print(OS);
519   return OS;
520 }
521 
522 } // end namespace llvm
523 
524 #endif
525