//===- ValueTracking.cpp - Walk computations to compute properties --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains routines that help analyze properties that chains of
// computations have.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DomConditionCache.h"
#include "llvm/Analysis/FloatingPointPredicateUtils.h"
#include "llvm/Analysis/GuardUtils.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/Analysis/WithCache.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/KnownFPClass.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/TargetParser/RISCVTargetParser.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

using namespace llvm;
using namespace llvm::PatternMatch;

// Controls the number of uses of the value searched for possible
// dominating comparisons.
static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses",
                                              cl::Hidden, cl::init(20));

/// Returns the bitwidth of the given scalar or pointer type. For vector types,
/// returns the element type's bitwidth.
static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
  if (unsigned BitWidth = Ty->getScalarSizeInBits())
    return BitWidth;

  return DL.getPointerTypeSizeInBits(Ty);
}

// Given the provided Value and, potentially, a context instruction, return
// the preferred context instruction (if any).
static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
  // If we've been provided with a context instruction, then use that (provided
  // it has been inserted).
  if (CxtI && CxtI->getParent())
    return CxtI;

  // If the value is really an already-inserted instruction, then use that.
  CxtI = dyn_cast<Instruction>(V);
  if (CxtI && CxtI->getParent())
    return CxtI;

  return nullptr;
}

static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
                                   const APInt &DemandedElts,
                                   APInt &DemandedLHS, APInt &DemandedRHS) {
  if (isa<ScalableVectorType>(Shuf->getType())) {
    assert(DemandedElts == APInt(1, 1));
    DemandedLHS = DemandedRHS = DemandedElts;
    return true;
  }

  int NumElts =
      cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
  return llvm::getShuffleDemandedElts(NumElts, Shuf->getShuffleMask(),
                                      DemandedElts, DemandedLHS, DemandedRHS);
}

static void computeKnownBits(const Value *V, const APInt &DemandedElts,
                             KnownBits &Known, const SimplifyQuery &Q,
                             unsigned Depth);

void llvm::computeKnownBits(const Value *V, KnownBits &Known,
                            const SimplifyQuery &Q, unsigned Depth) {
  // Since the number of lanes in a scalable vector is unknown at compile time,
  // we track one bit which is implicitly broadcast to all lanes.  This means
  // that all lanes in a scalable vector are considered demanded.
  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
  APInt DemandedElts =
      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
  ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
}

void llvm::computeKnownBits(const Value *V, KnownBits &Known,
                            const DataLayout &DL, AssumptionCache *AC,
                            const Instruction *CxtI, const DominatorTree *DT,
                            bool UseInstrInfo, unsigned Depth) {
  computeKnownBits(V, Known,
                   SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo),
                   Depth);
}

KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
                                 AssumptionCache *AC, const Instruction *CxtI,
                                 const DominatorTree *DT, bool UseInstrInfo,
                                 unsigned Depth) {
  return computeKnownBits(
      V, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
}

KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
                                 const DataLayout &DL, AssumptionCache *AC,
                                 const Instruction *CxtI,
                                 const DominatorTree *DT, bool UseInstrInfo,
                                 unsigned Depth) {
  return computeKnownBits(
      V, DemandedElts,
      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
}

static bool haveNoCommonBitsSetSpecialCases(const Value *LHS, const Value *RHS,
                                            const SimplifyQuery &SQ) {
  // Look for an inverted mask: (X & ~M) op (Y & M).
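  // E.g. in (X & ~M) + (Y & M), the left side can only set bits where M is 0
  // and the right side only where M is 1, so no bit can be set on both sides.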
  {
    Value *M;
    if (match(LHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
        match(RHS, m_c_And(m_Specific(M), m_Value())) &&
        isGuaranteedNotToBeUndef(M, SQ.AC, SQ.CxtI, SQ.DT))
      return true;
  }

  // X op (Y & ~X)
  if (match(RHS, m_c_And(m_Not(m_Specific(LHS)), m_Value())) &&
      isGuaranteedNotToBeUndef(LHS, SQ.AC, SQ.CxtI, SQ.DT))
    return true;

  // X op ((X & Y) ^ Y) -- this is the canonical form of the previous pattern
  // for constant Y.
  Value *Y;
  if (match(RHS,
            m_c_Xor(m_c_And(m_Specific(LHS), m_Value(Y)), m_Deferred(Y))) &&
      isGuaranteedNotToBeUndef(LHS, SQ.AC, SQ.CxtI, SQ.DT) &&
      isGuaranteedNotToBeUndef(Y, SQ.AC, SQ.CxtI, SQ.DT))
    return true;

  // Peek through extends to find a 'not' of the other side:
  // (ext Y) op ext(~Y)
  if (match(LHS, m_ZExtOrSExt(m_Value(Y))) &&
      match(RHS, m_ZExtOrSExt(m_Not(m_Specific(Y)))) &&
      isGuaranteedNotToBeUndef(Y, SQ.AC, SQ.CxtI, SQ.DT))
    return true;

  // Look for: (A & B) op ~(A | B)
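  // By De Morgan, ~(A | B) == ~A & ~B, so the right side only sets bits that
  // are clear in both A and B, while the left side only sets bits that are
  // set in both.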
  {
    Value *A, *B;
    if (match(LHS, m_And(m_Value(A), m_Value(B))) &&
        match(RHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))) &&
        isGuaranteedNotToBeUndef(A, SQ.AC, SQ.CxtI, SQ.DT) &&
        isGuaranteedNotToBeUndef(B, SQ.AC, SQ.CxtI, SQ.DT))
      return true;
  }

  // Look for: (X << V) op (Y >> (BitWidth - V))
  // or        (X >> V) op (Y << (BitWidth - V))
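  // The side shifted right by V can only set bits below BitWidth - V, while
  // the side shifted left by R - V can only set bits at or above R - V; with
  // R >= BitWidth those two ranges cannot overlap.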
  {
    const Value *V;
    const APInt *R;
    if (((match(RHS, m_Shl(m_Value(), m_Sub(m_APInt(R), m_Value(V)))) &&
          match(LHS, m_LShr(m_Value(), m_Specific(V)))) ||
         (match(RHS, m_LShr(m_Value(), m_Sub(m_APInt(R), m_Value(V)))) &&
          match(LHS, m_Shl(m_Value(), m_Specific(V))))) &&
        R->uge(LHS->getType()->getScalarSizeInBits()))
      return true;
  }

  return false;
}

bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
                               const WithCache<const Value *> &RHSCache,
                               const SimplifyQuery &SQ) {
  const Value *LHS = LHSCache.getValue();
  const Value *RHS = RHSCache.getValue();

  assert(LHS->getType() == RHS->getType() &&
         "LHS and RHS should have the same type");
  assert(LHS->getType()->isIntOrIntVectorTy() &&
         "LHS and RHS should be integers");

  if (haveNoCommonBitsSetSpecialCases(LHS, RHS, SQ) ||
      haveNoCommonBitsSetSpecialCases(RHS, LHS, SQ))
    return true;

  return KnownBits::haveNoCommonBitsSet(LHSCache.getKnownBits(SQ),
                                        RHSCache.getKnownBits(SQ));
}

bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
  return !I->user_empty() && all_of(I->users(), [](const User *U) {
    return match(U, m_ICmp(m_Value(), m_Zero()));
  });
}

bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
  return !I->user_empty() && all_of(I->users(), [](const User *U) {
    CmpPredicate P;
    return match(U, m_ICmp(P, m_Value(), m_Zero())) && ICmpInst::isEquality(P);
  });
}

bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
                                  bool OrZero, AssumptionCache *AC,
                                  const Instruction *CxtI,
                                  const DominatorTree *DT, bool UseInstrInfo,
                                  unsigned Depth) {
  return ::isKnownToBeAPowerOfTwo(
      V, OrZero, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo),
      Depth);
}

static bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
                           const SimplifyQuery &Q, unsigned Depth);

bool llvm::isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
                              unsigned Depth) {
  return computeKnownBits(V, SQ, Depth).isNonNegative();
}

bool llvm::isKnownPositive(const Value *V, const SimplifyQuery &SQ,
                           unsigned Depth) {
  if (auto *CI = dyn_cast<ConstantInt>(V))
    return CI->getValue().isStrictlyPositive();

  // If `isKnownNonNegative` ever becomes more sophisticated, make sure to keep
  // this updated.
  KnownBits Known = computeKnownBits(V, SQ, Depth);
  return Known.isNonNegative() &&
         (Known.isNonZero() || isKnownNonZero(V, SQ, Depth));
}

bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
                           unsigned Depth) {
  return computeKnownBits(V, SQ, Depth).isNegative();
}

static bool isKnownNonEqual(const Value *V1, const Value *V2,
                            const APInt &DemandedElts, const SimplifyQuery &Q,
                            unsigned Depth);

bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
                           const SimplifyQuery &Q, unsigned Depth) {
  // We don't support looking through casts.
  if (V1 == V2 || V1->getType() != V2->getType())
    return false;
  auto *FVTy = dyn_cast<FixedVectorType>(V1->getType());
  APInt DemandedElts =
      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
  return ::isKnownNonEqual(V1, V2, DemandedElts, Q, Depth);
}

bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask,
                             const SimplifyQuery &SQ, unsigned Depth) {
  KnownBits Known(Mask.getBitWidth());
  computeKnownBits(V, Known, SQ, Depth);
  return Mask.isSubsetOf(Known.Zero);
}

static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
                                   const SimplifyQuery &Q, unsigned Depth);

static unsigned ComputeNumSignBits(const Value *V, const SimplifyQuery &Q,
                                   unsigned Depth = 0) {
  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
  APInt DemandedElts =
      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
  return ComputeNumSignBits(V, DemandedElts, Q, Depth);
}

unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL,
                                  AssumptionCache *AC, const Instruction *CxtI,
                                  const DominatorTree *DT, bool UseInstrInfo,
                                  unsigned Depth) {
  return ::ComputeNumSignBits(
      V, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
}

unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
                                         AssumptionCache *AC,
                                         const Instruction *CxtI,
                                         const DominatorTree *DT,
                                         unsigned Depth) {
  unsigned SignBits =
      ComputeNumSignBits(V, DL, AC, CxtI, DT, /*UseInstrInfo=*/true, Depth);
  return V->getType()->getScalarSizeInBits() - SignBits + 1;
}

static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
                                   bool NSW, bool NUW,
                                   const APInt &DemandedElts,
                                   KnownBits &KnownOut, KnownBits &Known2,
                                   const SimplifyQuery &Q, unsigned Depth) {
  computeKnownBits(Op1, DemandedElts, KnownOut, Q, Depth + 1);

  // If one operand is unknown and we have no nowrap information,
  // the result will be unknown independently of the second operand.
  if (KnownOut.isUnknown() && !NSW && !NUW)
    return;

  computeKnownBits(Op0, DemandedElts, Known2, Q, Depth + 1);
  KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, Known2, KnownOut);

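  // For sub nsw, a dominating condition Op1 <=s Op0 makes the mathematical
  // difference non-negative, and nsw guarantees the IR result matches it.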
367       isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL)
368           .value_or(false))
369     KnownOut.makeNonNegative();
370 }
371 
372 static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
373                                 bool NUW, const APInt &DemandedElts,
374                                 KnownBits &Known, KnownBits &Known2,
375                                 const SimplifyQuery &Q, unsigned Depth) {
376   computeKnownBits(Op1, DemandedElts, Known, Q, Depth + 1);
377   computeKnownBits(Op0, DemandedElts, Known2, Q, Depth + 1);
378 
379   bool isKnownNegative = false;
380   bool isKnownNonNegative = false;
381   // If the multiplication is known not to overflow, compute the sign bit.
382   if (NSW) {
383     if (Op0 == Op1) {
384       // The product of a number with itself is non-negative.
385       isKnownNonNegative = true;
386     } else {
387       bool isKnownNonNegativeOp1 = Known.isNonNegative();
388       bool isKnownNonNegativeOp0 = Known2.isNonNegative();
389       bool isKnownNegativeOp1 = Known.isNegative();
390       bool isKnownNegativeOp0 = Known2.isNegative();
391       // The product of two numbers with the same sign is non-negative.
392       isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) ||
393                            (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
394       if (!isKnownNonNegative && NUW) {
395         // mul nuw nsw with a factor > 1 is non-negative.
396         KnownBits One = KnownBits::makeConstant(APInt(Known.getBitWidth(), 1));
397         isKnownNonNegative = KnownBits::sgt(Known, One).value_or(false) ||
398                              KnownBits::sgt(Known2, One).value_or(false);
399       }
400 
401       // The product of a negative number and a non-negative number is either
402       // negative or zero.
403       if (!isKnownNonNegative)
404         isKnownNegative =
405             (isKnownNegativeOp1 && isKnownNonNegativeOp0 &&
406              Known2.isNonZero()) ||
407             (isKnownNegativeOp0 && isKnownNonNegativeOp1 && Known.isNonZero());
408     }
409   }
410 
411   bool SelfMultiply = Op0 == Op1;
412   if (SelfMultiply)
413     SelfMultiply &=
414         isGuaranteedNotToBeUndef(Op0, Q.AC, Q.CxtI, Q.DT, Depth + 1);
415   Known = KnownBits::mul(Known, Known2, SelfMultiply);
416 
417   // Only make use of no-wrap flags if we failed to compute the sign bit
418   // directly.  This matters if the multiplication always overflows, in
419   // which case we prefer to follow the result of the direct computation,
420   // though as the program is invoking undefined behaviour we can choose
421   // whatever we like here.
422   if (isKnownNonNegative && !Known.isNegative())
423     Known.makeNonNegative();
424   else if (isKnownNegative && !Known.isNonNegative())
425     Known.makeNegative();
426 }
427 
428 void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
429                                              KnownBits &Known) {
430   unsigned BitWidth = Known.getBitWidth();
431   unsigned NumRanges = Ranges.getNumOperands() / 2;
432   assert(NumRanges >= 1);
433 
434   Known.Zero.setAllBits();
435   Known.One.setAllBits();
436 
437   for (unsigned i = 0; i < NumRanges; ++i) {
438     ConstantInt *Lower =
439         mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 0));
440     ConstantInt *Upper =
441         mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1));
442     ConstantRange Range(Lower->getValue(), Upper->getValue());
443     // BitWidth must equal the Ranges BitWidth for the correct number of high
444     // bits to be set.
445     assert(BitWidth == Range.getBitWidth() &&
446            "Known bit width must match range bit width!");
447 
448     // The first CommonPrefixBits of all values in Range are equal.
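    // E.g. for an 8-bit range [0x50, 0x58) the unsigned min and max are 0x50
    // and 0x57; they differ only in the low three bits, so the top five bits
    // of every value in the range are known to be 01010.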
    unsigned CommonPrefixBits =
        (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countl_zero();
    APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits);
    APInt UnsignedMax = Range.getUnsignedMax().zextOrTrunc(BitWidth);
    Known.One &= UnsignedMax & Mask;
    Known.Zero &= ~UnsignedMax & Mask;
  }
}

static bool isEphemeralValueOf(const Instruction *I, const Value *E) {
  SmallVector<const Instruction *, 16> WorkSet(1, I);
  SmallPtrSet<const Instruction *, 32> Visited;
  SmallPtrSet<const Instruction *, 16> EphValues;

  // The instruction defining an assumption's condition itself is always
  // considered ephemeral to that assumption (even if it has other
  // non-ephemeral users). See r246696's test case for an example.
  if (is_contained(I->operands(), E))
    return true;

  while (!WorkSet.empty()) {
    const Instruction *V = WorkSet.pop_back_val();
    if (!Visited.insert(V).second)
      continue;

    // If all uses of this value are ephemeral, then so is this value.
    if (all_of(V->users(), [&](const User *U) {
          return EphValues.count(cast<Instruction>(U));
        })) {
      if (V == E)
        return true;

      if (V == I || (!V->mayHaveSideEffects() && !V->isTerminator())) {
        EphValues.insert(V);

        if (const User *U = dyn_cast<User>(V)) {
          for (const Use &U : U->operands()) {
            if (const auto *I = dyn_cast<Instruction>(U.get()))
              WorkSet.push_back(I);
          }
        }
      }
    }
  }

  return false;
}

// Is this an intrinsic that cannot be speculated but also cannot trap?
bool llvm::isAssumeLikeIntrinsic(const Instruction *I) {
  if (const IntrinsicInst *CI = dyn_cast<IntrinsicInst>(I))
    return CI->isAssumeLikeIntrinsic();

  return false;
}

bool llvm::isValidAssumeForContext(const Instruction *Inv,
                                   const Instruction *CxtI,
                                   const DominatorTree *DT,
                                   bool AllowEphemerals) {
  // There are two restrictions on the use of an assume:
  //  1. The assume must dominate the context (or the control flow must
  //     reach the assume whenever it reaches the context).
  //  2. The context must not be in the assume's set of ephemeral values
  //     (otherwise we will use the assume to prove that the condition
  //     feeding the assume is trivially true, thus causing the removal of
  //     the assume).
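  // For example, in:
  //   %cmp = icmp ne ptr %p, null
  //   call void @llvm.assume(i1 %cmp)
  //   %v = load i8, ptr %p        ; context instruction
  // the assume is a valid context for the load, while an assume placed after
  // the load would not (in general) be valid for it.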

  if (Inv->getParent() == CxtI->getParent()) {
    // If Inv and CxtI are in the same block, check if the assume (Inv) is first
    // in the BB.
    if (Inv->comesBefore(CxtI))
      return true;

    // Don't let an assume affect itself - this would cause the problems
    // `isEphemeralValueOf` is trying to prevent, and it would also make
    // the loop below go out of bounds.
    if (!AllowEphemerals && Inv == CxtI)
      return false;

    // The context comes first, but they're both in the same block.
    // Make sure there is nothing in between that might interrupt
    // the control flow, not even CxtI itself.
    // We limit the scan distance between the assume and its context instruction
    // to avoid a compile-time explosion. This limit is chosen arbitrarily, so
    // it can be adjusted if needed (could be turned into a cl::opt).
    auto Range = make_range(CxtI->getIterator(), Inv->getIterator());
    if (!isGuaranteedToTransferExecutionToSuccessor(Range, 15))
      return false;

    return AllowEphemerals || !isEphemeralValueOf(Inv, CxtI);
  }

  // Inv and CxtI are in different blocks.
  if (DT) {
    if (DT->dominates(Inv, CxtI))
      return true;
  } else if (Inv->getParent() == CxtI->getParent()->getSinglePredecessor() ||
             Inv->getParent()->isEntryBlock()) {
    // We don't have a DT, but this trivially dominates.
    return true;
  }

  return false;
}

// TODO: cmpExcludesZero misses many cases where `RHS` is non-constant but
// we still have enough information about `RHS` to conclude non-zero. For
// example, Pred=EQ, RHS=isKnownNonZero. cmpExcludesZero is called in loops,
// so the extra compile time may not be worth it, but possibly a second API
// should be created for use outside of loops.
static bool cmpExcludesZero(CmpInst::Predicate Pred, const Value *RHS) {
  // v u> y implies v != 0.
  if (Pred == ICmpInst::ICMP_UGT)
    return true;

  // Special-case v != 0 to also handle v != null.
  if (Pred == ICmpInst::ICMP_NE)
    return match(RHS, m_Zero());

  // All other predicates - rely on generic ConstantRange handling.
  const APInt *C;
  auto Zero = APInt::getZero(RHS->getType()->getScalarSizeInBits());
  if (match(RHS, m_APInt(C))) {
    ConstantRange TrueValues = ConstantRange::makeExactICmpRegion(Pred, *C);
    return !TrueValues.contains(Zero);
  }

  auto *VC = dyn_cast<ConstantDataVector>(RHS);
  if (VC == nullptr)
    return false;

  for (unsigned ElemIdx = 0, NElem = VC->getNumElements(); ElemIdx < NElem;
       ++ElemIdx) {
    ConstantRange TrueValues = ConstantRange::makeExactICmpRegion(
        Pred, VC->getElementAsAPInt(ElemIdx));
    if (TrueValues.contains(Zero))
      return false;
  }
  return true;
}

static void breakSelfRecursivePHI(const Use *U, const PHINode *PHI,
                                  Value *&ValOut, Instruction *&CtxIOut,
                                  const PHINode **PhiOut = nullptr) {
  ValOut = U->get();
  if (ValOut == PHI)
    return;
  CtxIOut = PHI->getIncomingBlock(*U)->getTerminator();
  if (PhiOut)
    *PhiOut = PHI;
  Value *V;
  // If the Use is a select of this phi, compute the analysis on the other arm
  // to break the recursion.
  // TODO: Min/Max
  if (match(ValOut, m_Select(m_Value(), m_Specific(PHI), m_Value(V))) ||
      match(ValOut, m_Select(m_Value(), m_Value(V), m_Specific(PHI))))
    ValOut = V;

  // Likewise, if the incoming value is itself a 2-operand phi, compute the
  // analysis on its other incoming value to break the recursion.
  // TODO: We could handle any number of incoming edges as long as we only have
  // two unique values.
  if (auto *IncPhi = dyn_cast<PHINode>(ValOut);
      IncPhi && IncPhi->getNumIncomingValues() == 2) {
    for (int Idx = 0; Idx < 2; ++Idx) {
      if (IncPhi->getIncomingValue(Idx) == PHI) {
        ValOut = IncPhi->getIncomingValue(1 - Idx);
        if (PhiOut)
          *PhiOut = IncPhi;
        CtxIOut = IncPhi->getIncomingBlock(1 - Idx)->getTerminator();
        break;
      }
    }
  }
}

static bool isKnownNonZeroFromAssume(const Value *V, const SimplifyQuery &Q) {
  // Use of assumptions is context-sensitive. If we don't have a context, we
  // cannot use them!
  if (!Q.AC || !Q.CxtI)
    return false;

  for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(V)) {
    if (!Elem.Assume)
      continue;

    AssumeInst *I = cast<AssumeInst>(Elem.Assume);
    assert(I->getFunction() == Q.CxtI->getFunction() &&
           "Got assumption for the wrong function!");

    if (Elem.Index != AssumptionCache::ExprResultIdx) {
      if (!V->getType()->isPointerTy())
        continue;
      if (RetainedKnowledge RK = getKnowledgeFromBundle(
              *I, I->bundle_op_info_begin()[Elem.Index])) {
        if (RK.WasOn == V &&
            (RK.AttrKind == Attribute::NonNull ||
             (RK.AttrKind == Attribute::Dereferenceable &&
              !NullPointerIsDefined(Q.CxtI->getFunction(),
                                    V->getType()->getPointerAddressSpace()))) &&
            isValidAssumeForContext(I, Q.CxtI, Q.DT))
          return true;
      }
      continue;
    }

    // Warning: This loop can end up being somewhat performance-sensitive.
    // We run it once for each value queried, resulting in a runtime of
    // ~O(#assumes * #values).

    Value *RHS;
    CmpPredicate Pred;
    auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
    if (!match(I->getArgOperand(0), m_c_ICmp(Pred, m_V, m_Value(RHS))))
      continue;

    if (cmpExcludesZero(Pred, RHS) && isValidAssumeForContext(I, Q.CxtI, Q.DT))
      return true;
  }

  return false;
}

static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
                                    Value *LHS, Value *RHS, KnownBits &Known,
                                    const SimplifyQuery &Q) {
  if (RHS->getType()->isPointerTy()) {
    // Handle comparison of pointer to null explicitly, as it will not be
    // covered by the m_APInt() logic below.
    if (LHS == V && match(RHS, m_Zero())) {
      switch (Pred) {
      case ICmpInst::ICMP_EQ:
        Known.setAllZero();
        break;
      case ICmpInst::ICMP_SGE:
      case ICmpInst::ICMP_SGT:
        Known.makeNonNegative();
        break;
      case ICmpInst::ICMP_SLT:
        Known.makeNegative();
        break;
      default:
        break;
      }
    }
    return;
  }

  unsigned BitWidth = Known.getBitWidth();
  auto m_V =
      m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));

  Value *Y;
  const APInt *Mask, *C;
  if (!match(RHS, m_APInt(C)))
    return;

  uint64_t ShAmt;
  switch (Pred) {
  case ICmpInst::ICMP_EQ:
    // assume(V = C)
    if (match(LHS, m_V)) {
      Known = Known.unionWith(KnownBits::makeConstant(*C));
      // assume(V & Mask = C)
    } else if (match(LHS, m_c_And(m_V, m_Value(Y)))) {
      // For one bits in Mask, we can propagate bits from C to V.
      Known.One |= *C;
      if (match(Y, m_APInt(Mask)))
        Known.Zero |= ~*C & *Mask;
      // assume(V | Mask = C)
    } else if (match(LHS, m_c_Or(m_V, m_Value(Y)))) {
      // For zero bits in Mask, we can propagate bits from C to V.
      Known.Zero |= ~*C;
      if (match(Y, m_APInt(Mask)))
        Known.One |= *C & ~*Mask;
      // assume(V << ShAmt = C)
    } else if (match(LHS, m_Shl(m_V, m_ConstantInt(ShAmt))) &&
               ShAmt < BitWidth) {
      // For those bits in C that are known, we can propagate them to known
      // bits in V shifted to the right by ShAmt.
      KnownBits RHSKnown = KnownBits::makeConstant(*C);
      RHSKnown.Zero.lshrInPlace(ShAmt);
      RHSKnown.One.lshrInPlace(ShAmt);
      Known = Known.unionWith(RHSKnown);
      // assume(V >> ShAmt = C)
    } else if (match(LHS, m_Shr(m_V, m_ConstantInt(ShAmt))) &&
               ShAmt < BitWidth) {
      KnownBits RHSKnown = KnownBits::makeConstant(*C);
      // For those bits in RHS that are known, we can propagate them to known
      // bits in V shifted to the left by ShAmt.
      Known.Zero |= RHSKnown.Zero << ShAmt;
      Known.One |= RHSKnown.One << ShAmt;
    }
    break;
  case ICmpInst::ICMP_NE: {
    // assume (V & B != 0) where B is a power of 2
    const APInt *BPow2;
    if (C->isZero() && match(LHS, m_And(m_V, m_Power2(BPow2))))
      Known.One |= *BPow2;
    break;
  }
  default: {
    const APInt *Offset = nullptr;
    if (match(LHS, m_CombineOr(m_V, m_AddLike(m_V, m_APInt(Offset))))) {
      ConstantRange LHSRange = ConstantRange::makeAllowedICmpRegion(Pred, *C);
      if (Offset)
        LHSRange = LHSRange.sub(*Offset);
      Known = Known.unionWith(LHSRange.toKnownBits());
    }
    if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
      // X & Y u> C     -> X u> C && Y u> C
      // X nuw- Y u> C  -> X u> C
      if (match(LHS, m_c_And(m_V, m_Value())) ||
          match(LHS, m_NUWSub(m_V, m_Value())))
        Known.One.setHighBits(
            (*C + (Pred == ICmpInst::ICMP_UGT)).countLeadingOnes());
    }
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
      // X | Y u< C    -> X u< C && Y u< C
      // X nuw+ Y u< C -> X u< C && Y u< C
      if (match(LHS, m_c_Or(m_V, m_Value())) ||
          match(LHS, m_c_NUWAdd(m_V, m_Value()))) {
        Known.Zero.setHighBits(
            (*C - (Pred == ICmpInst::ICMP_ULT)).countLeadingZeros());
      }
    }
  } break;
  }
}

static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
                                         KnownBits &Known,
                                         const SimplifyQuery &SQ, bool Invert) {
  ICmpInst::Predicate Pred =
      Invert ? Cmp->getInversePredicate() : Cmp->getPredicate();
  Value *LHS = Cmp->getOperand(0);
  Value *RHS = Cmp->getOperand(1);

  // Handle icmp pred (trunc V), C
  if (match(LHS, m_Trunc(m_Specific(V)))) {
    KnownBits DstKnown(LHS->getType()->getScalarSizeInBits());
    computeKnownBitsFromCmp(LHS, Pred, LHS, RHS, DstKnown, SQ);
    if (cast<TruncInst>(LHS)->hasNoUnsignedWrap())
      Known = Known.unionWith(DstKnown.zext(Known.getBitWidth()));
    else
      Known = Known.unionWith(DstKnown.anyext(Known.getBitWidth()));
    return;
  }

  computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, SQ);
}

static void computeKnownBitsFromCond(const Value *V, Value *Cond,
                                     KnownBits &Known, const SimplifyQuery &SQ,
                                     bool Invert, unsigned Depth) {
  Value *A, *B;
  if (Depth < MaxAnalysisRecursionDepth &&
      match(Cond, m_LogicalOp(m_Value(A), m_Value(B)))) {
    KnownBits Known2(Known.getBitWidth());
    KnownBits Known3(Known.getBitWidth());
    computeKnownBitsFromCond(V, A, Known2, SQ, Invert, Depth + 1);
    computeKnownBitsFromCond(V, B, Known3, SQ, Invert, Depth + 1);
    if (Invert ? match(Cond, m_LogicalOr(m_Value(), m_Value()))
               : match(Cond, m_LogicalAnd(m_Value(), m_Value())))
      Known2 = Known2.unionWith(Known3);
    else
      Known2 = Known2.intersectWith(Known3);
    Known = Known.unionWith(Known2);
    return;
  }

  if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
    computeKnownBitsFromICmpCond(V, Cmp, Known, SQ, Invert);
    return;
  }

  if (match(Cond, m_Trunc(m_Specific(V)))) {
    KnownBits DstKnown(1);
    if (Invert) {
      DstKnown.setAllZero();
    } else {
      DstKnown.setAllOnes();
    }
    if (cast<TruncInst>(Cond)->hasNoUnsignedWrap()) {
      Known = Known.unionWith(DstKnown.zext(Known.getBitWidth()));
      return;
    }
    Known = Known.unionWith(DstKnown.anyext(Known.getBitWidth()));
    return;
  }

  if (Depth < MaxAnalysisRecursionDepth && match(Cond, m_Not(m_Value(A))))
    computeKnownBitsFromCond(V, A, Known, SQ, !Invert, Depth + 1);
}

void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
                                       const SimplifyQuery &Q, unsigned Depth) {
  // Handle injected condition.
  if (Q.CC && Q.CC->AffectedValues.contains(V))
    computeKnownBitsFromCond(V, Q.CC->Cond, Known, Q, Q.CC->Invert, Depth);

  if (!Q.CxtI)
    return;

  if (Q.DC && Q.DT) {
    // Handle dominating conditions.
    for (BranchInst *BI : Q.DC->conditionsFor(V)) {
      BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
      if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
        computeKnownBitsFromCond(V, BI->getCondition(), Known, Q,
                                 /*Invert*/ false, Depth);

      BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
      if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
        computeKnownBitsFromCond(V, BI->getCondition(), Known, Q,
                                 /*Invert*/ true, Depth);
    }

    if (Known.hasConflict())
      Known.resetAll();
  }

  if (!Q.AC)
    return;

  unsigned BitWidth = Known.getBitWidth();

  // Note that the patterns below need to be kept in sync with the code
  // in AssumptionCache::updateAffectedValues.

  for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(V)) {
    if (!Elem.Assume)
      continue;

    AssumeInst *I = cast<AssumeInst>(Elem.Assume);
    assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
           "Got assumption for the wrong function!");

    if (Elem.Index != AssumptionCache::ExprResultIdx) {
      if (!V->getType()->isPointerTy())
        continue;
      if (RetainedKnowledge RK = getKnowledgeFromBundle(
              *I, I->bundle_op_info_begin()[Elem.Index])) {
        // Allow AllowEphemerals in isValidAssumeForContext, as the CxtI might
        // be the producer of the pointer in the bundle. At the moment, align
        // assumptions aren't optimized away.
        if (RK.WasOn == V && RK.AttrKind == Attribute::Alignment &&
            isPowerOf2_64(RK.ArgValue) &&
            isValidAssumeForContext(I, Q.CxtI, Q.DT, /*AllowEphemerals*/ true))
          Known.Zero.setLowBits(Log2_64(RK.ArgValue));
      }
      continue;
    }

    // Warning: This loop can end up being somewhat performance-sensitive.
    // We run it once for each value queried, resulting in a runtime of
    // ~O(#assumes * #values).

    Value *Arg = I->getArgOperand(0);

    if (Arg == V && isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
      assert(BitWidth == 1 && "assume operand is not i1?");
      (void)BitWidth;
      Known.setAllOnes();
      return;
    }
    if (match(Arg, m_Not(m_Specific(V))) &&
        isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
      assert(BitWidth == 1 && "assume operand is not i1?");
      (void)BitWidth;
      Known.setAllZero();
      return;
    }
    auto *Trunc = dyn_cast<TruncInst>(Arg);
    if (Trunc && Trunc->getOperand(0) == V &&
        isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
      if (Trunc->hasNoUnsignedWrap()) {
        Known = KnownBits::makeConstant(APInt(BitWidth, 1));
        return;
      }
      Known.One.setBit(0);
      return;
    }

    // The remaining tests are all recursive, so bail out if we hit the limit.
    if (Depth == MaxAnalysisRecursionDepth)
      continue;

    ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
    if (!Cmp)
      continue;

    if (!isValidAssumeForContext(I, Q.CxtI, Q.DT))
      continue;

    computeKnownBitsFromICmpCond(V, Cmp, Known, Q, /*Invert=*/false);
  }

  // Conflicting assumption: Undefined behavior will occur on this execution
  // path.
  if (Known.hasConflict())
    Known.resetAll();
}

/// Compute known bits from a shift operator, including those with a
/// non-constant shift amount. Known is the output of this function. Known2 is a
/// pre-allocated temporary with the same bit width as Known and on return
/// contains the known bits of the shift value source. KF is an
/// operator-specific function that, given the known bits and a shift amount,
/// computes the implied known bits of the shift operator's result for that
/// shift amount. The results from calling KF are conservatively combined for
/// all permitted shift amounts.
static void computeKnownBitsFromShiftOperator(
    const Operator *I, const APInt &DemandedElts, KnownBits &Known,
    KnownBits &Known2, const SimplifyQuery &Q, unsigned Depth,
    function_ref<KnownBits(const KnownBits &, const KnownBits &, bool)> KF) {
  computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
  computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
  // To limit compile-time impact, only query isKnownNonZero() if we know at
  // least something about the shift amount.
  bool ShAmtNonZero =
      Known.isNonZero() ||
      (Known.getMaxValue().ult(Known.getBitWidth()) &&
       isKnownNonZero(I->getOperand(1), DemandedElts, Q, Depth + 1));
  Known = KF(Known2, Known, ShAmtNonZero);
}

static KnownBits
getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
                         const KnownBits &KnownLHS, const KnownBits &KnownRHS,
                         const SimplifyQuery &Q, unsigned Depth) {
  unsigned BitWidth = KnownLHS.getBitWidth();
  KnownBits KnownOut(BitWidth);
  bool IsAnd = false;
  bool HasKnownOne = !KnownLHS.One.isZero() || !KnownRHS.One.isZero();
  Value *X = nullptr, *Y = nullptr;

  switch (I->getOpcode()) {
  case Instruction::And:
    KnownOut = KnownLHS & KnownRHS;
    IsAnd = true;
    // and(x, -x) is a common idiom that will clear all but the lowest set
    // bit. If we have a single known bit in x, we can clear all bits
    // above it.
    // TODO: instcombine often reassociates independent `and` which can hide
    // this pattern. Try to match and(x, and(-x, y)) / and(and(x, y), -x).
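    // E.g. if x's low three bits are known zero and bit 3 is known one
    // (x = ...?1000), then x & -x is exactly 0b1000.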
    if (HasKnownOne && match(I, m_c_And(m_Value(X), m_Neg(m_Deferred(X))))) {
      // -(-x) == x, so use whichever of LHS/RHS gets us the better result.
      if (KnownLHS.countMaxTrailingZeros() <= KnownRHS.countMaxTrailingZeros())
        KnownOut = KnownLHS.blsi();
      else
        KnownOut = KnownRHS.blsi();
    }
    break;
  case Instruction::Or:
    KnownOut = KnownLHS | KnownRHS;
    break;
  case Instruction::Xor:
    KnownOut = KnownLHS ^ KnownRHS;
    // xor(x, x-1) is a common idiom that will clear all but the lowest set
    // bit. If we have a single known bit in x, we can clear all bits
    // above it.
    // TODO: xor(x, x-1) is often rewritten as xor(x, x-C) where C !=
    // -1, but for the purpose of demanded bits (xor(x, x-C) &
    // Demanded) == (xor(x, x-1) & Demanded). Extend the xor pattern
    // to use arbitrary C when xor(x, x-C) is the same as xor(x, x-1).
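    // E.g. if x = ...?1000, then x ^ (x - 1) is exactly 0b1111: every bit
    // above the lowest set bit cancels, and every bit at or below it is set.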
    if (HasKnownOne &&
        match(I, m_c_Xor(m_Value(X), m_Add(m_Deferred(X), m_AllOnes())))) {
      const KnownBits &XBits = I->getOperand(0) == X ? KnownLHS : KnownRHS;
      KnownOut = XBits.blsmsk();
    }
    break;
  default:
    llvm_unreachable("Invalid Op used in 'analyzeKnownBitsFromAndXorOr'");
  }

  // and(x, add (x, -1)) is a common idiom that always clears the low bit;
  // xor/or(x, add (x, -1)) is an idiom that will always set the low bit.
  // Here we handle the more general case of adding any odd number by
  // matching the form and/xor/or(x, add(x, y)) where y is odd.
  // TODO: This could be generalized to clearing any bit set in y where the
  // following bit is known to be unset in y.
  if (!KnownOut.Zero[0] && !KnownOut.One[0] &&
      (match(I, m_c_BinOp(m_Value(X), m_c_Add(m_Deferred(X), m_Value(Y)))) ||
       match(I, m_c_BinOp(m_Value(X), m_Sub(m_Deferred(X), m_Value(Y)))) ||
       match(I, m_c_BinOp(m_Value(X), m_Sub(m_Value(Y), m_Deferred(X)))))) {
    KnownBits KnownY(BitWidth);
    computeKnownBits(Y, DemandedElts, KnownY, Q, Depth + 1);
    if (KnownY.countMinTrailingOnes() > 0) {
      if (IsAnd)
        KnownOut.Zero.setBit(0);
      else
        KnownOut.One.setBit(0);
    }
  }
  return KnownOut;
}

static KnownBits computeKnownBitsForHorizontalOperation(
    const Operator *I, const APInt &DemandedElts, const SimplifyQuery &Q,
    unsigned Depth,
    const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
        KnownBitsFunc) {
  APInt DemandedEltsLHS, DemandedEltsRHS;
  getHorizDemandedEltsForFirstOperand(Q.DL.getTypeSizeInBits(I->getType()),
                                      DemandedElts, DemandedEltsLHS,
                                      DemandedEltsRHS);

  const auto ComputeForSingleOpFunc =
      [Depth, &Q, KnownBitsFunc](const Value *Op, APInt &DemandedEltsOp) {
        return KnownBitsFunc(
            computeKnownBits(Op, DemandedEltsOp, Q, Depth + 1),
            computeKnownBits(Op, DemandedEltsOp << 1, Q, Depth + 1));
      };

  if (DemandedEltsRHS.isZero())
    return ComputeForSingleOpFunc(I->getOperand(0), DemandedEltsLHS);
  if (DemandedEltsLHS.isZero())
    return ComputeForSingleOpFunc(I->getOperand(1), DemandedEltsRHS);

  return ComputeForSingleOpFunc(I->getOperand(0), DemandedEltsLHS)
      .intersectWith(ComputeForSingleOpFunc(I->getOperand(1), DemandedEltsRHS));
}

// Public so this can be used in `SimplifyDemandedUseBits`.
KnownBits llvm::analyzeKnownBitsFromAndXorOr(const Operator *I,
                                             const KnownBits &KnownLHS,
                                             const KnownBits &KnownRHS,
                                             const SimplifyQuery &SQ,
                                             unsigned Depth) {
  auto *FVTy = dyn_cast<FixedVectorType>(I->getType());
  APInt DemandedElts =
      FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);

  return getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS, KnownRHS, SQ,
                                  Depth);
}

ConstantRange llvm::getVScaleRange(const Function *F, unsigned BitWidth) {
  Attribute Attr = F->getFnAttribute(Attribute::VScaleRange);
  // Without vscale_range, we only know that vscale is non-zero.
  if (!Attr.isValid())
    return ConstantRange(APInt(BitWidth, 1), APInt::getZero(BitWidth));

  unsigned AttrMin = Attr.getVScaleRangeMin();
  // Minimum is larger than vscale width, result is always poison.
  if ((unsigned)llvm::bit_width(AttrMin) > BitWidth)
    return ConstantRange::getEmpty(BitWidth);

  APInt Min(BitWidth, AttrMin);
  std::optional<unsigned> AttrMax = Attr.getVScaleRangeMax();
  if (!AttrMax || (unsigned)llvm::bit_width(*AttrMax) > BitWidth)
    return ConstantRange(Min, APInt::getZero(BitWidth));

  return ConstantRange(Min, APInt(BitWidth, *AttrMax) + 1);
}

void llvm::adjustKnownBitsForSelectArm(KnownBits &Known, Value *Cond,
                                       Value *Arm, bool Invert,
                                       const SimplifyQuery &Q, unsigned Depth) {
  // If we have a constant arm, we are done.
  if (Known.isConstant())
    return;

  // See what the condition implies about the bits of the select arm.
  KnownBits CondRes(Known.getBitWidth());
  computeKnownBitsFromCond(Arm, Cond, CondRes, Q, Invert, Depth + 1);
  // If we don't get any information from the condition, no reason to
  // proceed.
  if (CondRes.isUnknown())
    return;

  // We can have a conflict if the condition is dead, i.e. if we have
  // (x | 64) < 32 ? (x | 64) : y
  // we will have a conflict at bit 6 from the condition/the `or`.
  // In that case just return. It's not particularly important what we do, as
  // this select is going to be simplified soon.
  CondRes = CondRes.unionWith(Known);
  if (CondRes.hasConflict())
    return;

  // Finally make sure the information we found is valid. This is relatively
  // expensive, so it's left for the very end.
  if (!isGuaranteedNotToBeUndef(Arm, Q.AC, Q.CxtI, Q.DT, Depth + 1))
    return;

  // Finally, we know we get information from the condition and it's valid,
  // so return it.
  Known = CondRes;
}

// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow).
// Returns the input and lower/upper bounds.
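// E.g. smax(smin(%x, 255), 0) clamps %x to [0, 255]: In is %x, CLow is 0 and
// CHigh is 255.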
1143 static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
1144                                 const APInt *&CLow, const APInt *&CHigh) {
1145   assert(isa<Operator>(Select) &&
1146          cast<Operator>(Select)->getOpcode() == Instruction::Select &&
1147          "Input should be a Select!");
1148 
1149   const Value *LHS = nullptr, *RHS = nullptr;
1150   SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor;
1151   if (SPF != SPF_SMAX && SPF != SPF_SMIN)
1152     return false;
1153 
1154   if (!match(RHS, m_APInt(CLow)))
1155     return false;
1156 
1157   const Value *LHS2 = nullptr, *RHS2 = nullptr;
1158   SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor;
1159   if (getInverseMinMaxFlavor(SPF) != SPF2)
1160     return false;
1161 
1162   if (!match(RHS2, m_APInt(CHigh)))
1163     return false;
1164 
1165   if (SPF == SPF_SMIN)
1166     std::swap(CLow, CHigh);
1167 
1168   In = LHS2;
1169   return CLow->sle(*CHigh);
1170 }
1171 
1172 static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II,
1173                                          const APInt *&CLow,
1174                                          const APInt *&CHigh) {
1175   assert((II->getIntrinsicID() == Intrinsic::smin ||
1176           II->getIntrinsicID() == Intrinsic::smax) &&
1177          "Must be smin/smax");
1178 
1179   Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
1180   auto *InnerII = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
1181   if (!InnerII || InnerII->getIntrinsicID() != InverseID ||
1182       !match(II->getArgOperand(1), m_APInt(CLow)) ||
1183       !match(InnerII->getArgOperand(1), m_APInt(CHigh)))
1184     return false;
1185 
1186   if (II->getIntrinsicID() == Intrinsic::smin)
1187     std::swap(CLow, CHigh);
1188   return CLow->sle(*CHigh);
1189 }
1190 
1191 static void unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II,
1192                                           KnownBits &Known) {
1193   const APInt *CLow, *CHigh;
1194   if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
1195     Known = Known.unionWith(
1196         ConstantRange::getNonEmpty(*CLow, *CHigh + 1).toKnownBits());
1197 }
1198 
1199 static void computeKnownBitsFromOperator(const Operator *I,
1200                                          const APInt &DemandedElts,
1201                                          KnownBits &Known,
1202                                          const SimplifyQuery &Q,
1203                                          unsigned Depth) {
1204   unsigned BitWidth = Known.getBitWidth();
1205 
1206   KnownBits Known2(BitWidth);
1207   switch (I->getOpcode()) {
1208   default: break;
1209   case Instruction::Load:
1210     if (MDNode *MD =
1211             Q.IIQ.getMetadata(cast<LoadInst>(I), LLVMContext::MD_range))
1212       computeKnownBitsFromRangeMetadata(*MD, Known);
1213     break;
1214   case Instruction::And:
1215     computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
1216     computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
1217 
1218     Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q, Depth);
1219     break;
1220   case Instruction::Or:
1221     computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
1222     computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
1223 
1224     Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q, Depth);
1225     break;
1226   case Instruction::Xor:
1227     computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
1228     computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
1229 
1230     Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q, Depth);
1231     break;
1232   case Instruction::Mul: {
1233     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
1234     bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
1235     computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, NUW,
1236                         DemandedElts, Known, Known2, Q, Depth);
1237     break;
1238   }
1239   case Instruction::UDiv: {
1240     computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1241     computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1242     Known =
1243         KnownBits::udiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
1244     break;
1245   }
1246   case Instruction::SDiv: {
1247     computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1248     computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1249     Known =
1250         KnownBits::sdiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
1251     break;
1252   }
1253   case Instruction::Select: {
1254     auto ComputeForArm = [&](Value *Arm, bool Invert) {
1255       KnownBits Res(Known.getBitWidth());
1256       computeKnownBits(Arm, DemandedElts, Res, Q, Depth + 1);
1257       adjustKnownBitsForSelectArm(Res, I->getOperand(0), Arm, Invert, Q, Depth);
1258       return Res;
1259     };
1260     // Only known if known in both the LHS and RHS.
1261     Known =
1262         ComputeForArm(I->getOperand(1), /*Invert=*/false)
1263             .intersectWith(ComputeForArm(I->getOperand(2), /*Invert=*/true));
1264     break;
1265   }
1266   case Instruction::FPTrunc:
1267   case Instruction::FPExt:
1268   case Instruction::FPToUI:
1269   case Instruction::FPToSI:
1270   case Instruction::SIToFP:
1271   case Instruction::UIToFP:
1272     break; // Can't work with floating point.
1273   case Instruction::PtrToInt:
1274   case Instruction::IntToPtr:
1275     // Fall through and handle them the same as zext/trunc.
1276     [[fallthrough]];
1277   case Instruction::ZExt:
1278   case Instruction::Trunc: {
1279     Type *SrcTy = I->getOperand(0)->getType();
1280 
1281     unsigned SrcBitWidth;
1282     // Note that we handle pointer operands here because of inttoptr/ptrtoint
1283     // which fall through here.
1284     Type *ScalarTy = SrcTy->getScalarType();
1285     SrcBitWidth = ScalarTy->isPointerTy() ?
1286       Q.DL.getPointerTypeSizeInBits(ScalarTy) :
1287       Q.DL.getTypeSizeInBits(ScalarTy);
1288 
1289     assert(SrcBitWidth && "SrcBitWidth can't be zero");
1290     Known = Known.anyextOrTrunc(SrcBitWidth);
1291     computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1292     if (auto *Inst = dyn_cast<PossiblyNonNegInst>(I);
1293         Inst && Inst->hasNonNeg() && !Known.isNegative())
1294       Known.makeNonNegative();
1295     Known = Known.zextOrTrunc(BitWidth);
1296     break;
1297   }
1298   case Instruction::BitCast: {
1299     Type *SrcTy = I->getOperand(0)->getType();
1300     if (SrcTy->isIntOrPtrTy() &&
1301         // TODO: For now, not handling conversions like:
1302         // (bitcast i64 %x to <2 x i32>)
1303         !I->getType()->isVectorTy()) {
1304       computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
1305       break;
1306     }
1307 
1308     const Value *V;
1309     // Handle bitcast from floating point to integer.
1310     if (match(I, m_ElementWiseBitCast(m_Value(V))) &&
1311         V->getType()->isFPOrFPVectorTy()) {
1312       Type *FPType = V->getType()->getScalarType();
1313       KnownFPClass Result =
1314           computeKnownFPClass(V, DemandedElts, fcAllFlags, Q, Depth + 1);
1315       FPClassTest FPClasses = Result.KnownFPClasses;
1316 
1317       // TODO: Treat it as zero/poison if the use of I is unreachable.
1318       if (FPClasses == fcNone)
1319         break;
1320 
1321       if (Result.isKnownNever(fcNormal | fcSubnormal | fcNan)) {
1322         Known.Zero.setAllBits();
1323         Known.One.setAllBits();
1324 
1325         if (FPClasses & fcInf)
1326           Known = Known.intersectWith(KnownBits::makeConstant(
1327               APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt()));
1328 
1329         if (FPClasses & fcZero)
1330           Known = Known.intersectWith(KnownBits::makeConstant(
1331               APInt::getZero(FPType->getScalarSizeInBits())));
1332 
1333         Known.Zero.clearSignBit();
1334         Known.One.clearSignBit();
1335       }
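           // E.g. (illustrative): if the source is known to be +/-0.0, every
           // bit of the result except the sign bit is known zero here; the
           // sign bit itself is handled below when Result.SignBit is known.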
1336 
1337       if (Result.SignBit) {
1338         if (*Result.SignBit)
1339           Known.makeNegative();
1340         else
1341           Known.makeNonNegative();
1342       }
1343 
1344       break;
1345     }
1346 
1347     // Handle cast from vector integer type to scalar or vector integer.
1348     auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcTy);
1349     if (!SrcVecTy || !SrcVecTy->getElementType()->isIntegerTy() ||
1350         !I->getType()->isIntOrIntVectorTy() ||
1351         isa<ScalableVectorType>(I->getType()))
1352       break;
1353 
1354     // Look through a cast from narrow vector elements to wider type.
1355     // Examples: v4i32 -> v2i64, v3i8 -> i24
1356     unsigned SubBitWidth = SrcVecTy->getScalarSizeInBits();
1357     if (BitWidth % SubBitWidth == 0) {
1358       // Known bits are automatically intersected across demanded elements of a
1359       // vector. So for example, if a bit is computed as known zero, it must be
1360       // zero across all demanded elements of the vector.
1361       //
1362       // For this bitcast, each demanded element of the output is sub-divided
1363       // across a set of smaller vector elements in the source vector. To get
1364       // the known bits for an entire element of the output, compute the known
1365       // bits for each sub-element sequentially. This is done by shifting the
1366       // one-set-bit demanded elements parameter across the sub-elements for
1367       // consecutive calls to computeKnownBits. We are using the demanded
1368       // elements parameter as a mask operator.
1369       //
1370       // The known bits of each sub-element are then inserted into place
1371       // (dependent on endian) to form the full result of known bits.
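           // E.g. (illustrative): for a <4 x i32> -> <2 x i64> bitcast on a
           // little-endian target with only element 0 of the result demanded,
           // the two computeKnownBits calls below use sub-element masks
           // 0b0001 and 0b0010, and their results are inserted at bit
           // offsets 0 and 32 of the i64 element, respectively.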
1372       unsigned NumElts = DemandedElts.getBitWidth();
1373       unsigned SubScale = BitWidth / SubBitWidth;
1374       APInt SubDemandedElts = APInt::getZero(NumElts * SubScale);
1375       for (unsigned i = 0; i != NumElts; ++i) {
1376         if (DemandedElts[i])
1377           SubDemandedElts.setBit(i * SubScale);
1378       }
1379 
1380       KnownBits KnownSrc(SubBitWidth);
1381       for (unsigned i = 0; i != SubScale; ++i) {
1382         computeKnownBits(I->getOperand(0), SubDemandedElts.shl(i), KnownSrc, Q,
1383                          Depth + 1);
1384         unsigned ShiftElt = Q.DL.isLittleEndian() ? i : SubScale - 1 - i;
1385         Known.insertBits(KnownSrc, ShiftElt * SubBitWidth);
1386       }
1387     }
1388     break;
1389   }
1390   case Instruction::SExt: {
1391     // Compute the bits in the result that are not present in the input.
1392     unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
1393 
1394     Known = Known.trunc(SrcBitWidth);
1395     computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1396     // If the sign bit of the input is known set or clear, then we know the
1397     // top bits of the result.
1398     Known = Known.sext(BitWidth);
1399     break;
1400   }
1401   case Instruction::Shl: {
1402     bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
1403     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
1404     auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1405                          bool ShAmtNonZero) {
1406       return KnownBits::shl(KnownVal, KnownAmt, NUW, NSW, ShAmtNonZero);
1407     };
1408     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1409                                       KF);
1410     // Trailing zeros of a left-shifted constant never decrease.
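         // E.g. (illustrative): in shl i8 12, %x the constant 12 (0b00001100)
         // has two trailing zeros, so the result always has at least two
         // trailing zeros (a result of zero trivially satisfies this too).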
1411     const APInt *C;
1412     if (match(I->getOperand(0), m_APInt(C)))
1413       Known.Zero.setLowBits(C->countr_zero());
1414     break;
1415   }
1416   case Instruction::LShr: {
1417     bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
1418     auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1419                       bool ShAmtNonZero) {
1420       return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
1421     };
1422     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1423                                       KF);
1424     // Leading zeros of a right-shifted constant never decrease.
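         // E.g. (illustrative): in lshr i8 12, %x the constant 12
         // (0b00001100) has four leading zeros, so the result always has at
         // least four leading zeros.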
1425     const APInt *C;
1426     if (match(I->getOperand(0), m_APInt(C)))
1427       Known.Zero.setHighBits(C->countl_zero());
1428     break;
1429   }
1430   case Instruction::AShr: {
1431     bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
1432     auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1433                       bool ShAmtNonZero) {
1434       return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
1435     };
1436     computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1437                                       KF);
1438     break;
1439   }
1440   case Instruction::Sub: {
1441     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
1442     bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
1443     computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, NUW,
1444                            DemandedElts, Known, Known2, Q, Depth);
1445     break;
1446   }
1447   case Instruction::Add: {
1448     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
1449     bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
1450     computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, NUW,
1451                            DemandedElts, Known, Known2, Q, Depth);
1452     break;
1453   }
1454   case Instruction::SRem:
1455     computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1456     computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1457     Known = KnownBits::srem(Known, Known2);
1458     break;
1459 
1460   case Instruction::URem:
1461     computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1462     computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1463     Known = KnownBits::urem(Known, Known2);
1464     break;
1465   case Instruction::Alloca:
1466     Known.Zero.setLowBits(Log2(cast<AllocaInst>(I)->getAlign()));
1467     break;
1468   case Instruction::GetElementPtr: {
1469     // Analyze all of the subscripts of this getelementptr instruction
1470     // to determine if we can prove known low zero bits.
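         // E.g. (illustrative): for a GEP over i32 elements with an
         // 8-byte-aligned base pointer, each index contributes a multiple of
         // 4 bytes, so the low two bits of the result stay known zero.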
1471     computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
1472     // Accumulate the constant indices in a separate variable
1473     // to minimize the number of calls to computeForAddSub.
1474     unsigned IndexWidth = Q.DL.getIndexTypeSizeInBits(I->getType());
1475     APInt AccConstIndices(IndexWidth, 0);
1476 
1477     auto AddIndexToKnown = [&](KnownBits IndexBits) {
1478       if (IndexWidth == BitWidth) {
1479         // Note that inbounds does *not* guarantee nsw for the addition, as only
1480         // the offset is signed, while the base address is unsigned.
1481         Known = KnownBits::add(Known, IndexBits);
1482       } else {
1483         // If the index width is smaller than the pointer width, only add the
1484         // value to the low bits.
1485         assert(IndexWidth < BitWidth &&
1486                "Index width can't be larger than pointer width");
1487         Known.insertBits(KnownBits::add(Known.trunc(IndexWidth), IndexBits), 0);
1488       }
1489     };
1490 
1491     gep_type_iterator GTI = gep_type_begin(I);
1492     for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) {
1493       // Known bits can only get weaker; short-circuit once nothing is known.
1494       if (Known.isUnknown())
1495         break;
1496 
1497       Value *Index = I->getOperand(i);
1498 
1499       // Handle case when index is zero.
1500       Constant *CIndex = dyn_cast<Constant>(Index);
1501       if (CIndex && CIndex->isZeroValue())
1502         continue;
1503 
1504       if (StructType *STy = GTI.getStructTypeOrNull()) {
1505         // Handle struct member offset arithmetic.
1506 
1507         assert(CIndex &&
1508                "Access to structure field must be known at compile time");
1509 
1510         if (CIndex->getType()->isVectorTy())
1511           Index = CIndex->getSplatValue();
1512 
1513         unsigned Idx = cast<ConstantInt>(Index)->getZExtValue();
1514         const StructLayout *SL = Q.DL.getStructLayout(STy);
1515         uint64_t Offset = SL->getElementOffset(Idx);
1516         AccConstIndices += Offset;
1517         continue;
1518       }
1519 
1520       // Handle array index arithmetic.
1521       Type *IndexedTy = GTI.getIndexedType();
1522       if (!IndexedTy->isSized()) {
1523         Known.resetAll();
1524         break;
1525       }
1526 
1527       TypeSize Stride = GTI.getSequentialElementStride(Q.DL);
1528       uint64_t StrideInBytes = Stride.getKnownMinValue();
1529       if (!Stride.isScalable()) {
1530         // Fast path for constant offset.
1531         if (auto *CI = dyn_cast<ConstantInt>(Index)) {
1532           AccConstIndices +=
1533               CI->getValue().sextOrTrunc(IndexWidth) * StrideInBytes;
1534           continue;
1535         }
1536       }
1537 
1538       KnownBits IndexBits =
1539           computeKnownBits(Index, Q, Depth + 1).sextOrTrunc(IndexWidth);
1540       KnownBits ScalingFactor(IndexWidth);
1541       // Multiply by current sizeof type.
1542       // &A[i] == A + i * sizeof(*A[i]).
1543       if (Stride.isScalable()) {
1544         // For scalable types the only thing we know about sizeof is
1545         // that it is a multiple of the minimum size.
1546         ScalingFactor.Zero.setLowBits(llvm::countr_zero(StrideInBytes));
1547       } else {
1548         ScalingFactor =
1549             KnownBits::makeConstant(APInt(IndexWidth, StrideInBytes));
1550       }
1551       AddIndexToKnown(KnownBits::mul(IndexBits, ScalingFactor));
1552     }
1553     if (!Known.isUnknown() && !AccConstIndices.isZero())
1554       AddIndexToKnown(KnownBits::makeConstant(AccConstIndices));
1555     break;
1556   }
1557   case Instruction::PHI: {
1558     const PHINode *P = cast<PHINode>(I);
1559     BinaryOperator *BO = nullptr;
1560     Value *R = nullptr, *L = nullptr;
1561     if (matchSimpleRecurrence(P, BO, R, L)) {
1562       // Handle the case of a simple two-predecessor recurrence PHI.
1563       // There's a lot more that could theoretically be done here, but
1564       // this is sufficient to catch some interesting cases.
1565       unsigned Opcode = BO->getOpcode();
1566 
1567       switch (Opcode) {
1568       // If this is a shift recurrence, we know the bits being shifted in. We
1569       // can combine that with information about the start value of the
1570       // recurrence to conclude facts about the result. If this is a udiv
1571       // recurrence, we know that the result can never exceed either the
1572       // numerator or the start value, whichever is greater.
1573       case Instruction::LShr:
1574       case Instruction::AShr:
1575       case Instruction::Shl:
1576       case Instruction::UDiv:
1577         if (BO->getOperand(0) != I)
1578           break;
1579         [[fallthrough]];
1580 
1581       // For a urem recurrence, the result can never exceed the start value. The
1582       // phi could either be the numerator or the denominator.
1583       case Instruction::URem: {
1584         // We have matched a recurrence of the form:
1585         // %iv = [R, %entry], [%iv.next, %backedge]
1586         // %iv.next = shift_op %iv, L
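             // E.g. (illustrative): for %iv.next = lshr %iv, %amt with a
             // start value R known to fit in the low 8 bits of an i32, each
             // iteration only shifts right, so the result keeps R's 24
             // leading zeros.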
1587 
1588         // Recurse with the phi context to avoid concern about whether facts
1589         // inferred hold at the original context instruction.  TODO: It may be
1590         // correct to use the original context.  If warranted, explore and
1591         // add sufficient tests to cover.
1592         SimplifyQuery RecQ = Q.getWithoutCondContext();
1593         RecQ.CxtI = P;
1594         computeKnownBits(R, DemandedElts, Known2, RecQ, Depth + 1);
1595         switch (Opcode) {
1596         case Instruction::Shl:
1597           // A shl recurrence will only increase the trailing zeros.
1598           Known.Zero.setLowBits(Known2.countMinTrailingZeros());
1599           break;
1600         case Instruction::LShr:
1601         case Instruction::UDiv:
1602         case Instruction::URem:
1603           // lshr, udiv, and urem recurrences will preserve the leading zeros of
1604           // the start value.
1605           Known.Zero.setHighBits(Known2.countMinLeadingZeros());
1606           break;
1607         case Instruction::AShr:
1608           // An ashr recurrence will extend the initial sign bit
1609           Known.Zero.setHighBits(Known2.countMinLeadingZeros());
1610           Known.One.setHighBits(Known2.countMinLeadingOnes());
1611           break;
1612         }
1613         break;
1614       }
1615 
1616       // Check for operations that have the property that if
1617       // both their operands have low zero bits, the result
1618       // will have low zero bits.
1619       case Instruction::Add:
1620       case Instruction::Sub:
1621       case Instruction::And:
1622       case Instruction::Or:
1623       case Instruction::Mul: {
1624         // Change the context instruction to the "edge" that flows into the
1625         // phi. This is important because that is where the value is actually
1626         // "evaluated" even though it is used later somewhere else. (see also
1627         // D69571).
1628         SimplifyQuery RecQ = Q.getWithoutCondContext();
1629 
1630         unsigned OpNum = P->getOperand(0) == R ? 0 : 1;
1631         Instruction *RInst = P->getIncomingBlock(OpNum)->getTerminator();
1632         Instruction *LInst = P->getIncomingBlock(1 - OpNum)->getTerminator();
1633 
1634         // Ok, we have a PHI of the form L op= R. Check for low
1635         // zero bits.
1636         RecQ.CxtI = RInst;
1637         computeKnownBits(R, DemandedElts, Known2, RecQ, Depth + 1);
1638 
1639         // We need to take the minimum number of known trailing zeros.
1640         KnownBits Known3(BitWidth);
1641         RecQ.CxtI = LInst;
1642         computeKnownBits(L, DemandedElts, Known3, RecQ, Depth + 1);
1643 
1644         Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(),
1645                                        Known3.countMinTrailingZeros()));
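             // E.g. (illustrative): if the start value is a multiple of 8
             // (three trailing zeros) and the step is a multiple of 4 (two
             // trailing zeros), every value of an add recurrence is a
             // multiple of 4.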
1646 
1647         auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(BO);
1648         if (!OverflowOp || !Q.IIQ.hasNoSignedWrap(OverflowOp))
1649           break;
1650 
1651         switch (Opcode) {
1652         // If initial value of recurrence is nonnegative, and we are adding
1653         // a nonnegative number with nsw, the result can only be nonnegative
1654         // or poison value regardless of the number of times we execute the
1655         // add in phi recurrence. If initial value is negative and we are
1656         // adding a negative number with nsw, the result can only be
1657         // negative or poison value. Similar arguments apply to sub and mul.
1658         //
1659         // (add non-negative, non-negative) --> non-negative
1660         // (add negative, negative) --> negative
1661         case Instruction::Add: {
1662           if (Known2.isNonNegative() && Known3.isNonNegative())
1663             Known.makeNonNegative();
1664           else if (Known2.isNegative() && Known3.isNegative())
1665             Known.makeNegative();
1666           break;
1667         }
1668 
1669         // (sub nsw non-negative, negative) --> non-negative
1670         // (sub nsw negative, non-negative) --> negative
1671         case Instruction::Sub: {
1672           if (BO->getOperand(0) != I)
1673             break;
1674           if (Known2.isNonNegative() && Known3.isNegative())
1675             Known.makeNonNegative();
1676           else if (Known2.isNegative() && Known3.isNonNegative())
1677             Known.makeNegative();
1678           break;
1679         }
1680 
1681         // (mul nsw non-negative, non-negative) --> non-negative
1682         case Instruction::Mul:
1683           if (Known2.isNonNegative() && Known3.isNonNegative())
1684             Known.makeNonNegative();
1685           break;
1686 
1687         default:
1688           break;
1689         }
1690         break;
1691       }
1692 
1693       default:
1694         break;
1695       }
1696     }
1697 
1698     // Unreachable blocks may have zero-operand PHI nodes.
1699     if (P->getNumIncomingValues() == 0)
1700       break;
1701 
1702     // Otherwise take the intersection of the operands' known bit sets,
1703     // taking conservative care to avoid excessive recursion.
1704     if (Depth < MaxAnalysisRecursionDepth - 1 && Known.isUnknown()) {
1705       // Skip if every incoming value refers back to the PHI itself.
1706       if (isa_and_nonnull<UndefValue>(P->hasConstantValue()))
1707         break;
1708 
1709       Known.Zero.setAllBits();
1710       Known.One.setAllBits();
1711       for (const Use &U : P->operands()) {
1712         Value *IncValue;
1713         const PHINode *CxtPhi;
1714         Instruction *CxtI;
1715         breakSelfRecursivePHI(&U, P, IncValue, CxtI, &CxtPhi);
1716         // Skip direct self references.
1717         if (IncValue == P)
1718           continue;
1719 
1720         // Change the context instruction to the "edge" that flows into the
1721         // phi. This is important because that is where the value is actually
1722         // "evaluated" even though it is used later somewhere else. (see also
1723         // D69571).
1724         SimplifyQuery RecQ = Q.getWithoutCondContext().getWithInstruction(CxtI);
1725 
1726         Known2 = KnownBits(BitWidth);
1727 
1728         // Recurse, but cap the recursion to one level, because we don't
1729         // want to waste time spinning around in loops.
1730         // TODO: See if we can base recursion limiter on number of incoming phi
1731         // edges so we don't overly clamp analysis.
1732         computeKnownBits(IncValue, DemandedElts, Known2, RecQ,
1733                          MaxAnalysisRecursionDepth - 1);
1734 
1735         // See if we can further use a conditional branch into the phi
1736         // to help us determine the range of the value.
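             // E.g. (illustrative): given
             //   br i1 (icmp ult i32 %inc, 8), label %phi.bb, label %other.bb
             // on the edge into the phi, %inc u< 8 holds, so bits 3..31 of
             // %inc are known zero and can be unioned into Known2 below.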
1737         if (!Known2.isConstant()) {
1738           CmpPredicate Pred;
1739           const APInt *RHSC;
1740           BasicBlock *TrueSucc, *FalseSucc;
1741           // TODO: Use RHS Value and compute range from its known bits.
1742           if (match(RecQ.CxtI,
1743                     m_Br(m_c_ICmp(Pred, m_Specific(IncValue), m_APInt(RHSC)),
1744                          m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) {
1745             // Check for cases of duplicate successors.
1746             if ((TrueSucc == CxtPhi->getParent()) !=
1747                 (FalseSucc == CxtPhi->getParent())) {
1748               // If we're using the false successor, invert the predicate.
1749               if (FalseSucc == CxtPhi->getParent())
1750                 Pred = CmpInst::getInversePredicate(Pred);
1751               // Get the knownbits implied by the incoming phi condition.
1752               auto CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
1753               KnownBits KnownUnion = Known2.unionWith(CR.toKnownBits());
1754               // We can have conflicts here if we are analyzing dead code (it's
1755               // impossible for us to reach this BB based on the icmp).
1756               if (KnownUnion.hasConflict()) {
1757                 // No reason to continue analyzing in a known dead region, so
1758                 // just resetAll and break. This will cause us to also exit the
1759                 // outer loop.
1760                 Known.resetAll();
1761                 break;
1762               }
1763               Known2 = KnownUnion;
1764             }
1765           }
1766         }
1767 
1768         Known = Known.intersectWith(Known2);
1769         // If all bits have been ruled out, there's no need to check
1770         // more operands.
1771         if (Known.isUnknown())
1772           break;
1773       }
1774     }
1775     break;
1776   }
1777   case Instruction::Call:
1778   case Instruction::Invoke: {
1779     // If range metadata is attached to this call, set known bits from that,
1780     // and then intersect with known bits based on other properties of the
1781     // function.
1782     if (MDNode *MD =
1783             Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range))
1784       computeKnownBitsFromRangeMetadata(*MD, Known);
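         // E.g. (illustrative): !range !{i32 0, i32 256} restricts an i32
         // result to [0, 256), so its top 24 bits become known zero here.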
1785 
1786     const auto *CB = cast<CallBase>(I);
1787 
1788     if (std::optional<ConstantRange> Range = CB->getRange())
1789       Known = Known.unionWith(Range->toKnownBits());
1790 
1791     if (const Value *RV = CB->getReturnedArgOperand()) {
1792       if (RV->getType() == I->getType()) {
1793         computeKnownBits(RV, Known2, Q, Depth + 1);
1794         Known = Known.unionWith(Known2);
1795         // If the function doesn't return properly for all input values
1796         // (e.g. unreachable exits) then there might be conflicts between the
1797         // argument value and the range metadata. Simply discard the known bits
1798         // in case of conflicts.
1799         if (Known.hasConflict())
1800           Known.resetAll();
1801       }
1802     }
1803     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
1804       switch (II->getIntrinsicID()) {
1805       default:
1806         break;
1807       case Intrinsic::abs: {
1808         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
1809         bool IntMinIsPoison = match(II->getArgOperand(1), m_One());
1810         Known = Known2.abs(IntMinIsPoison);
1811         break;
1812       }
1813       case Intrinsic::bitreverse:
1814         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
1815         Known.Zero |= Known2.Zero.reverseBits();
1816         Known.One |= Known2.One.reverseBits();
1817         break;
1818       case Intrinsic::bswap:
1819         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
1820         Known.Zero |= Known2.Zero.byteSwap();
1821         Known.One |= Known2.One.byteSwap();
1822         break;
1823       case Intrinsic::ctlz: {
1824         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
1825         // If we have a known 1, its position is our upper bound.
1826         unsigned PossibleLZ = Known2.countMaxLeadingZeros();
1827         // If this call is poison for 0 input, the result will be less than 2^n.
1828         if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext()))
1829           PossibleLZ = std::min(PossibleLZ, BitWidth - 1);
1830         unsigned LowBits = llvm::bit_width(PossibleLZ);
1831         Known.Zero.setBitsFrom(LowBits);
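             // E.g. (illustrative): if the i32 operand has a known one at bit
             // 20, PossibleLZ is at most 11, so the result fits in
             // bit_width(11) == 4 bits and bits 4..31 are known zero.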
1832         break;
1833       }
1834       case Intrinsic::cttz: {
1835         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
1836         // If we have a known 1, its position is our upper bound.
1837         unsigned PossibleTZ = Known2.countMaxTrailingZeros();
1838         // If this call is poison for 0 input, the result will be less than 2^n.
1839         if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext()))
1840           PossibleTZ = std::min(PossibleTZ, BitWidth - 1);
1841         unsigned LowBits = llvm::bit_width(PossibleTZ);
1842         Known.Zero.setBitsFrom(LowBits);
1843         break;
1844       }
1845       case Intrinsic::ctpop: {
1846         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
1847         // We can bound the space the count needs.  Also, bits known to be zero
1848         // can't contribute to the population.
1849         unsigned BitsPossiblySet = Known2.countMaxPopulation();
1850         unsigned LowBits = llvm::bit_width(BitsPossiblySet);
1851         Known.Zero.setBitsFrom(LowBits);
1852         // TODO: we could bound KnownOne using the lower bound on the number
1853         // of bits which might be set provided by popcnt KnownOne2.
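             // E.g. (illustrative): if at most 5 bits of the operand can be
             // set, the population count fits in bit_width(5) == 3 bits, so
             // bits 3 and up of the result are known zero.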
1854         break;
1855       }
1856       case Intrinsic::fshr:
1857       case Intrinsic::fshl: {
1858         const APInt *SA;
1859         if (!match(I->getOperand(2), m_APInt(SA)))
1860           break;
1861 
1862         // Normalize to funnel shift left.
1863         uint64_t ShiftAmt = SA->urem(BitWidth);
1864         if (II->getIntrinsicID() == Intrinsic::fshr)
1865           ShiftAmt = BitWidth - ShiftAmt;
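             // E.g. (illustrative): fshr(%a, %b, 3) on i8 equals
             // fshl(%a, %b, 5): the result is (%a << 5) | (%b >> 3), so the
             // top three bits come from %a and the bottom five bits from %b.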
1866 
1867         KnownBits Known3(BitWidth);
1868         computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
1869         computeKnownBits(I->getOperand(1), DemandedElts, Known3, Q, Depth + 1);
1870 
1871         Known.Zero =
1872             Known2.Zero.shl(ShiftAmt) | Known3.Zero.lshr(BitWidth - ShiftAmt);
1873         Known.One =
1874             Known2.One.shl(ShiftAmt) | Known3.One.lshr(BitWidth - ShiftAmt);
1875         break;
1876       }
1877       case Intrinsic::uadd_sat:
1878         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1879         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1880         Known = KnownBits::uadd_sat(Known, Known2);
1881         break;
1882       case Intrinsic::usub_sat:
1883         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1884         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1885         Known = KnownBits::usub_sat(Known, Known2);
1886         break;
1887       case Intrinsic::sadd_sat:
1888         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1889         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1890         Known = KnownBits::sadd_sat(Known, Known2);
1891         break;
1892       case Intrinsic::ssub_sat:
1893         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1894         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1895         Known = KnownBits::ssub_sat(Known, Known2);
1896         break;
1897         // Vec reverse preserves bits from input vec.
1898       case Intrinsic::vector_reverse:
1899         computeKnownBits(I->getOperand(0), DemandedElts.reverseBits(), Known, Q,
1900                          Depth + 1);
1901         break;
1902         // For min/max/and/or reductions, any bit that is known in every
1903         // element of the input vec is also known in the output.
1904       case Intrinsic::vector_reduce_and:
1905       case Intrinsic::vector_reduce_or:
1906       case Intrinsic::vector_reduce_umax:
1907       case Intrinsic::vector_reduce_umin:
1908       case Intrinsic::vector_reduce_smax:
1909       case Intrinsic::vector_reduce_smin:
1910         computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
1911         break;
1912       case Intrinsic::vector_reduce_xor: {
1913         computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
1914         // The zeros common to all elements are zero in the output.
1915         // If the number of elements is odd, then the common ones remain. If the
1916         // number of elements is even, then the common ones become zeros.
1917         auto *VecTy = cast<VectorType>(I->getOperand(0)->getType());
1918         // Even, so the ones become zeros.
1919         bool EvenCnt = VecTy->getElementCount().isKnownEven();
1920         if (EvenCnt)
1921           Known.Zero |= Known.One;
1922         // The element count may be even, so we need to clear the common ones.
1923         if (VecTy->isScalableTy() || EvenCnt)
1924           Known.One.clearAllBits();
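             // E.g. (illustrative): xor-reducing four elements that all have
             // bit 0 known one gives 1^1^1^1 == 0 in bit 0, while reducing
             // three such elements would leave bit 0 known one.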
1925         break;
1926       }
1927       case Intrinsic::umin:
1928         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1929         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1930         Known = KnownBits::umin(Known, Known2);
1931         break;
1932       case Intrinsic::umax:
1933         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1934         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1935         Known = KnownBits::umax(Known, Known2);
1936         break;
1937       case Intrinsic::smin:
1938         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1939         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1940         Known = KnownBits::smin(Known, Known2);
1941         unionWithMinMaxIntrinsicClamp(II, Known);
1942         break;
1943       case Intrinsic::smax:
1944         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1945         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1946         Known = KnownBits::smax(Known, Known2);
1947         unionWithMinMaxIntrinsicClamp(II, Known);
1948         break;
1949       case Intrinsic::ptrmask: {
1950         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1951 
1952         const Value *Mask = I->getOperand(1);
1953         Known2 = KnownBits(Mask->getType()->getScalarSizeInBits());
1954         computeKnownBits(Mask, DemandedElts, Known2, Q, Depth + 1);
1955         // TODO: 1-extend would be more precise.
1956         Known &= Known2.anyextOrTrunc(BitWidth);
1957         break;
1958       }
1959       case Intrinsic::x86_sse2_pmulh_w:
1960       case Intrinsic::x86_avx2_pmulh_w:
1961       case Intrinsic::x86_avx512_pmulh_w_512:
1962         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1963         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1964         Known = KnownBits::mulhs(Known, Known2);
1965         break;
1966       case Intrinsic::x86_sse2_pmulhu_w:
1967       case Intrinsic::x86_avx2_pmulhu_w:
1968       case Intrinsic::x86_avx512_pmulhu_w_512:
1969         computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
1970         computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
1971         Known = KnownBits::mulhu(Known, Known2);
1972         break;
1973       case Intrinsic::x86_sse42_crc32_64_64:
1974         Known.Zero.setBitsFrom(32);
1975         break;
1976       case Intrinsic::x86_ssse3_phadd_d_128:
1977       case Intrinsic::x86_ssse3_phadd_w_128:
1978       case Intrinsic::x86_avx2_phadd_d:
1979       case Intrinsic::x86_avx2_phadd_w: {
1980         Known = computeKnownBitsForHorizontalOperation(
1981             I, DemandedElts, Q, Depth,
1982             [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
1983               return KnownBits::add(KnownLHS, KnownRHS);
1984             });
1985         break;
1986       }
1987       case Intrinsic::x86_ssse3_phadd_sw_128:
1988       case Intrinsic::x86_avx2_phadd_sw: {
1989         Known = computeKnownBitsForHorizontalOperation(
1990             I, DemandedElts, Q, Depth, KnownBits::sadd_sat);
1991         break;
1992       }
1993       case Intrinsic::x86_ssse3_phsub_d_128:
1994       case Intrinsic::x86_ssse3_phsub_w_128:
1995       case Intrinsic::x86_avx2_phsub_d:
1996       case Intrinsic::x86_avx2_phsub_w: {
1997         Known = computeKnownBitsForHorizontalOperation(
1998             I, DemandedElts, Q, Depth,
1999             [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
2000               return KnownBits::sub(KnownLHS, KnownRHS);
2001             });
2002         break;
2003       }
2004       case Intrinsic::x86_ssse3_phsub_sw_128:
2005       case Intrinsic::x86_avx2_phsub_sw: {
2006         Known = computeKnownBitsForHorizontalOperation(
2007             I, DemandedElts, Q, Depth, KnownBits::ssub_sat);
2008         break;
2009       }
2010       case Intrinsic::riscv_vsetvli:
2011       case Intrinsic::riscv_vsetvlimax: {
2012         bool HasAVL = II->getIntrinsicID() == Intrinsic::riscv_vsetvli;
2013         const ConstantRange Range = getVScaleRange(II->getFunction(), BitWidth);
2014         uint64_t SEW = RISCVVType::decodeVSEW(
2015             cast<ConstantInt>(II->getArgOperand(HasAVL))->getZExtValue());
2016         RISCVVType::VLMUL VLMUL = static_cast<RISCVVType::VLMUL>(
2017             cast<ConstantInt>(II->getArgOperand(1 + HasAVL))->getZExtValue());
2018         uint64_t MaxVLEN =
2019             Range.getUnsignedMax().getZExtValue() * RISCV::RVVBitsPerBlock;
2020         uint64_t MaxVL = MaxVLEN / RISCVVType::getSEWLMULRatio(SEW, VLMUL);
2021 
2022         // The result of vsetvli must not be larger than AVL.
2023         if (HasAVL)
2024           if (auto *CI = dyn_cast<ConstantInt>(II->getArgOperand(0)))
2025             MaxVL = std::min(MaxVL, CI->getZExtValue());
2026 
2027         unsigned KnownZeroFirstBit = Log2_32(MaxVL) + 1;
2028         if (BitWidth > KnownZeroFirstBit)
2029           Known.Zero.setBitsFrom(KnownZeroFirstBit);
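             // E.g. (illustrative, assuming VLMAX = VLEN * LMUL / SEW): with
             // vscale_range bounding VLEN at 512, SEW=32 and LMUL=2, MaxVL is
             // 512 / (32 / 2) == 32, so bits 6 and up of the result are
             // known zero.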
2030         break;
2031       }
2032       case Intrinsic::vscale: {
2033         if (!II->getParent() || !II->getFunction())
2034           break;
2035 
2036         Known = getVScaleRange(II->getFunction(), BitWidth).toKnownBits();
2037         break;
2038       }
2039       }
2040     }
2041     break;
2042   }
2043   case Instruction::ShuffleVector: {
2044     auto *Shuf = dyn_cast<ShuffleVectorInst>(I);
2045     // FIXME: Do we need to handle ConstantExpr involving shufflevectors?
2046     if (!Shuf) {
2047       Known.resetAll();
2048       return;
2049     }
2050     // For undef elements, we don't know anything about the common state of
2051     // the shuffle result.
2052     APInt DemandedLHS, DemandedRHS;
2053     if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS)) {
2054       Known.resetAll();
2055       return;
2056     }
2057     Known.One.setAllBits();
2058     Known.Zero.setAllBits();
2059     if (!!DemandedLHS) {
2060       const Value *LHS = Shuf->getOperand(0);
2061       computeKnownBits(LHS, DemandedLHS, Known, Q, Depth + 1);
2062       // If we don't know any bits, early out.
2063       if (Known.isUnknown())
2064         break;
2065     }
2066     if (!!DemandedRHS) {
2067       const Value *RHS = Shuf->getOperand(1);
2068       computeKnownBits(RHS, DemandedRHS, Known2, Q, Depth + 1);
2069       Known = Known.intersectWith(Known2);
2070     }
2071     break;
2072   }
2073   case Instruction::InsertElement: {
2074     if (isa<ScalableVectorType>(I->getType())) {
2075       Known.resetAll();
2076       return;
2077     }
2078     const Value *Vec = I->getOperand(0);
2079     const Value *Elt = I->getOperand(1);
2080     auto *CIdx = dyn_cast<ConstantInt>(I->getOperand(2));
2081     unsigned NumElts = DemandedElts.getBitWidth();
2082     APInt DemandedVecElts = DemandedElts;
2083     bool NeedsElt = true;
2084     // If we know the index we are inserting to, clear it from the Vec check.
2085     if (CIdx && CIdx->getValue().ult(NumElts)) {
2086       DemandedVecElts.clearBit(CIdx->getZExtValue());
2087       NeedsElt = DemandedElts[CIdx->getZExtValue()];
2088     }
2089 
2090     Known.One.setAllBits();
2091     Known.Zero.setAllBits();
2092     if (NeedsElt) {
2093       computeKnownBits(Elt, Known, Q, Depth + 1);
2094       // If we don't know any bits, early out.
2095       if (Known.isUnknown())
2096         break;
2097     }
2098 
2099     if (!DemandedVecElts.isZero()) {
2100       computeKnownBits(Vec, DemandedVecElts, Known2, Q, Depth + 1);
2101       Known = Known.intersectWith(Known2);
2102     }
2103     break;
2104   }
2105   case Instruction::ExtractElement: {
2106     // Look through extract element. If the index is non-constant or
2107     // out-of-range, demand all elements; otherwise just the extracted element.
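         // E.g. (illustrative): extractelement <4 x i32> %v, i64 2 only needs
         // the known bits of element 2, so DemandedVecElts becomes 0b0100.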
2108     const Value *Vec = I->getOperand(0);
2109     const Value *Idx = I->getOperand(1);
2110     auto *CIdx = dyn_cast<ConstantInt>(Idx);
2111     if (isa<ScalableVectorType>(Vec->getType())) {
2112       // FIXME: there's probably *something* we can do with scalable vectors
2113       Known.resetAll();
2114       break;
2115     }
2116     unsigned NumElts = cast<FixedVectorType>(Vec->getType())->getNumElements();
2117     APInt DemandedVecElts = APInt::getAllOnes(NumElts);
2118     if (CIdx && CIdx->getValue().ult(NumElts))
2119       DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
2120     computeKnownBits(Vec, DemandedVecElts, Known, Q, Depth + 1);
2121     break;
2122   }
2123   case Instruction::ExtractValue:
2124     if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I->getOperand(0))) {
2125       const ExtractValueInst *EVI = cast<ExtractValueInst>(I);
2126       if (EVI->getNumIndices() != 1) break;
2127       if (EVI->getIndices()[0] == 0) {
2128         switch (II->getIntrinsicID()) {
2129         default: break;
2130         case Intrinsic::uadd_with_overflow:
2131         case Intrinsic::sadd_with_overflow:
2132           computeKnownBitsAddSub(
2133               true, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
2134               /* NUW=*/false, DemandedElts, Known, Known2, Q, Depth);
2135           break;
2136         case Intrinsic::usub_with_overflow:
2137         case Intrinsic::ssub_with_overflow:
2138           computeKnownBitsAddSub(
2139               false, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
2140               /* NUW=*/false, DemandedElts, Known, Known2, Q, Depth);
2141           break;
2142         case Intrinsic::umul_with_overflow:
2143         case Intrinsic::smul_with_overflow:
2144           computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false,
2145                               false, DemandedElts, Known, Known2, Q, Depth);
2146           break;
2147         }
2148       }
2149     }
2150     break;
2151   case Instruction::Freeze:
2152     if (isGuaranteedNotToBePoison(I->getOperand(0), Q.AC, Q.CxtI, Q.DT,
2153                                   Depth + 1))
2154       computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
2155     break;
2156   }
2157 }
2158 
2159 /// Determine which bits of V are known to be either zero or one and return
2160 /// them.
2161 KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
2162                                  const SimplifyQuery &Q, unsigned Depth) {
2163   KnownBits Known(getBitWidth(V->getType(), Q.DL));
2164   ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
2165   return Known;
2166 }
2167 
2168 /// Determine which bits of V are known to be either zero or one and return
2169 /// them.
2170 KnownBits llvm::computeKnownBits(const Value *V, const SimplifyQuery &Q,
2171                                  unsigned Depth) {
2172   KnownBits Known(getBitWidth(V->getType(), Q.DL));
2173   computeKnownBits(V, Known, Q, Depth);
2174   return Known;
2175 }
2176 
2177 /// Determine which bits of V are known to be either zero or one and return
2178 /// them in the Known bit set.
2179 ///
2180 /// NOTE: we cannot consider 'undef' to be "IsZero" here.  The problem is that
2181 /// we cannot optimize based on the assumption that it is zero without changing
2182 /// it to be an explicit zero.  If we don't change it to zero, other code could
2183 /// be optimized based on the contradictory assumption that it is non-zero.
2184 /// Because instcombine aggressively folds operations with undef args anyway,
2185 /// this won't lose us code quality.
2186 ///
2187 /// This function is defined on values with integer type, values with pointer
2188 /// type, and vectors of integers.  In the case
2189 /// where V is a vector, the known zero and known one values are the
2190 /// same width as the vector element, and a bit is set only if it is true
2191 /// for all of the demanded elements in the vector specified by DemandedElts.
2192 void computeKnownBits(const Value *V, const APInt &DemandedElts,
2193                       KnownBits &Known, const SimplifyQuery &Q,
2194                       unsigned Depth) {
2195   if (!DemandedElts) {
2196     // No demanded elts, better to assume we don't know anything.
2197     Known.resetAll();
2198     return;
2199   }
2200 
2201   assert(V && "No Value?");
2202   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
2203 
2204 #ifndef NDEBUG
2205   Type *Ty = V->getType();
2206   unsigned BitWidth = Known.getBitWidth();
2207 
2208   assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) &&
2209          "Not integer or pointer type!");
2210 
2211   if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
2212     assert(
2213         FVTy->getNumElements() == DemandedElts.getBitWidth() &&
2214         "DemandedElt width should equal the fixed vector number of elements");
2215   } else {
2216     assert(DemandedElts == APInt(1, 1) &&
2217            "DemandedElt width should be 1 for scalars or scalable vectors");
2218   }
2219 
2220   Type *ScalarTy = Ty->getScalarType();
2221   if (ScalarTy->isPointerTy()) {
2222     assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) &&
2223            "V and Known should have same BitWidth");
2224   } else {
2225     assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) &&
2226            "V and Known should have same BitWidth");
2227   }
2228 #endif
2229 
2230   const APInt *C;
2231   if (match(V, m_APInt(C))) {
2232     // We know all of the bits for a scalar constant or a splat vector constant!
2233     Known = KnownBits::makeConstant(*C);
2234     return;
2235   }
2236   // Null and aggregate-zero are all-zeros.
2237   if (isa<ConstantPointerNull>(V) || isa<ConstantAggregateZero>(V)) {
2238     Known.setAllZero();
2239     return;
2240   }
2241   // Handle a constant vector by taking the intersection of the known bits of
2242   // each element.
2243   if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(V)) {
2244     assert(!isa<ScalableVectorType>(V->getType()));
2245     // We know that CDV must be a vector of integers. Take the intersection of
2246     // each element.
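         // E.g. (illustrative): for <2 x i8> <i8 5, i8 7> the elements are
         // 0b00000101 and 0b00000111, so the result is known to be 0b000001?1
         // (bit 1 unknown, every other bit known).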
2247     Known.Zero.setAllBits(); Known.One.setAllBits();
2248     for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) {
2249       if (!DemandedElts[i])
2250         continue;
2251       APInt Elt = CDV->getElementAsAPInt(i);
2252       Known.Zero &= ~Elt;
2253       Known.One &= Elt;
2254     }
2255     if (Known.hasConflict())
2256       Known.resetAll();
2257     return;
2258   }
2259 
2260   if (const auto *CV = dyn_cast<ConstantVector>(V)) {
2261     assert(!isa<ScalableVectorType>(V->getType()));
2262     // We know that CV must be a vector of integers. Take the intersection of
2263     // each element.
2264     Known.Zero.setAllBits(); Known.One.setAllBits();
2265     for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
2266       if (!DemandedElts[i])
2267         continue;
2268       Constant *Element = CV->getAggregateElement(i);
2269       if (isa<PoisonValue>(Element))
2270         continue;
2271       auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element);
2272       if (!ElementCI) {
2273         Known.resetAll();
2274         return;
2275       }
2276       const APInt &Elt = ElementCI->getValue();
2277       Known.Zero &= ~Elt;
2278       Known.One &= Elt;
2279     }
2280     if (Known.hasConflict())
2281       Known.resetAll();
2282     return;
2283   }
2284 
2285   // Start out not knowing anything.
2286   Known.resetAll();
2287 
2288   // We can't infer anything about undefs.
2289   if (isa<UndefValue>(V))
2290     return;
2291 
2292   // There's no point in looking through other users of ConstantData for
2293   // assumptions.  Confirm that we've handled them all.
2294   assert(!isa<ConstantData>(V) && "Unhandled constant data!");
2295 
2296   if (const auto *A = dyn_cast<Argument>(V))
2297     if (std::optional<ConstantRange> Range = A->getRange())
2298       Known = Range->toKnownBits();
2299 
2300   // All recursive calls that increase depth must come after this.
2301   if (Depth == MaxAnalysisRecursionDepth)
2302     return;
2303 
2304   // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has
2305   // the bits of its aliasee.
2306   if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
2307     if (!GA->isInterposable())
2308       computeKnownBits(GA->getAliasee(), Known, Q, Depth + 1);
2309     return;
2310   }
2311 
2312   if (const Operator *I = dyn_cast<Operator>(V))
2313     computeKnownBitsFromOperator(I, DemandedElts, Known, Q, Depth);
2314   else if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
2315     if (std::optional<ConstantRange> CR = GV->getAbsoluteSymbolRange())
2316       Known = CR->toKnownBits();
2317   }
2318 
2319   // Aligned pointers have trailing zeros - refine Known.Zero set
2320   if (isa<PointerType>(V->getType())) {
2321     Align Alignment = V->getPointerAlignment(Q.DL);
2322     Known.Zero.setLowBits(Log2(Alignment));
2323   }
2324 
2325   // computeKnownBitsFromContext strictly refines Known.
2326   // Therefore, we run them after computeKnownBitsFromOperator.
2327 
2328   // Check whether we can determine known bits from context such as assumes.
2329   computeKnownBitsFromContext(V, Known, Q, Depth);
2330 }
2331 
2332 /// Try to detect a recurrence where the value of the induction variable is
2333 /// always a power of two (or zero).
2334 static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
2335                                    SimplifyQuery &Q, unsigned Depth) {
2336   BinaryOperator *BO = nullptr;
2337   Value *Start = nullptr, *Step = nullptr;
2338   if (!matchSimpleRecurrence(PN, BO, Start, Step))
2339     return false;
2340 
2341   // Initial value must be a power of two.
2342   for (const Use &U : PN->operands()) {
2343     if (U.get() == Start) {
2344       // Initial value comes from a different BB, need to adjust context
2345       // instruction for analysis.
2346       Q.CxtI = PN->getIncomingBlock(U)->getTerminator();
2347       if (!isKnownToBeAPowerOfTwo(Start, OrZero, Q, Depth))
2348         return false;
2349     }
2350   }
2351 
2352   // Except for Mul, the induction variable must be on the left side of the
2353   // increment expression, otherwise its value can be arbitrary.
2354   if (BO->getOpcode() != Instruction::Mul && BO->getOperand(1) != Step)
2355     return false;
2356 
2357   Q.CxtI = BO->getParent()->getTerminator();
2358   switch (BO->getOpcode()) {
2359   case Instruction::Mul:
2360     // Power of two is closed under multiplication.
2361     return (OrZero || Q.IIQ.hasNoUnsignedWrap(BO) ||
2362             Q.IIQ.hasNoSignedWrap(BO)) &&
2363            isKnownToBeAPowerOfTwo(Step, OrZero, Q, Depth);
2364   case Instruction::SDiv:
2365     // Start value must not be signmask for signed division, so simply being a
2366     // power of two is not sufficient, and it has to be a constant.
2367     if (!match(Start, m_Power2()) || match(Start, m_SignMask()))
2368       return false;
2369     [[fallthrough]];
2370   case Instruction::UDiv:
2371     // Divisor must be a power of two.
2372     // If OrZero is false, we cannot guarantee the induction variable is non-zero
2373     // after division (same for Shr), unless it is an exact division.
2374     return (OrZero || Q.IIQ.isExact(BO)) &&
2375            isKnownToBeAPowerOfTwo(Step, false, Q, Depth);
2376   case Instruction::Shl:
2377     return OrZero || Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO);
2378   case Instruction::AShr:
2379     if (!match(Start, m_Power2()) || match(Start, m_SignMask()))
2380       return false;
2381     [[fallthrough]];
2382   case Instruction::LShr:
2383     return OrZero || Q.IIQ.isExact(BO);
2384   default:
2385     return false;
2386   }
2387 }
2388 
2389 /// Return true if we can infer that \p V is known to be a power of 2 from
2390 /// dominating condition \p Cond (e.g., ctpop(V) == 1).
2391 static bool isImpliedToBeAPowerOfTwoFromCond(const Value *V, bool OrZero,
2392                                              const Value *Cond,
2393                                              bool CondIsTrue) {
2394   CmpPredicate Pred;
2395   const APInt *RHSC;
2396   if (!match(Cond, m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Specific(V)),
2397                           m_APInt(RHSC))))
2398     return false;
2399   if (!CondIsTrue)
2400     Pred = ICmpInst::getInversePredicate(Pred);
2401   // ctpop(V) u< 2
2402   if (OrZero && Pred == ICmpInst::ICMP_ULT && *RHSC == 2)
2403     return true;
2404   // ctpop(V) == 1
2405   return Pred == ICmpInst::ICMP_EQ && *RHSC == 1;
2406 }
2407 
2408 /// Return true if the given value is known to have exactly one
2409 /// bit set when defined. For vectors return true if every element is known to
2410 /// be a power of two when defined. Supports values with integer or pointer
2411 /// types and vectors of integers.
2412 bool llvm::isKnownToBeAPowerOfTwo(const Value *V, bool OrZero,
2413                                   const SimplifyQuery &Q, unsigned Depth) {
2414   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
2415 
2416   if (isa<Constant>(V))
2417     return OrZero ? match(V, m_Power2OrZero()) : match(V, m_Power2());
2418 
2419   // i1 is by definition a power of 2 or zero.
2420   if (OrZero && V->getType()->getScalarSizeInBits() == 1)
2421     return true;
2422 
2423   // Try to infer from assumptions.
2424   if (Q.AC && Q.CxtI) {
2425     for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
2426       if (!AssumeVH)
2427         continue;
2428       CallInst *I = cast<CallInst>(AssumeVH);
2429       if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, I->getArgOperand(0),
2430                                            /*CondIsTrue=*/true) &&
2431           isValidAssumeForContext(I, Q.CxtI, Q.DT))
2432         return true;
2433     }
2434   }
2435 
2436   // Handle dominating conditions.
2437   if (Q.DC && Q.CxtI && Q.DT) {
2438     for (BranchInst *BI : Q.DC->conditionsFor(V)) {
2439       Value *Cond = BI->getCondition();
2440 
2441       BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
2442       if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond,
2443                                            /*CondIsTrue=*/true) &&
2444           Q.DT->dominates(Edge0, Q.CxtI->getParent()))
2445         return true;
2446 
2447       BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
2448       if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond,
2449                                            /*CondIsTrue=*/false) &&
2450           Q.DT->dominates(Edge1, Q.CxtI->getParent()))
2451         return true;
2452     }
2453   }
2454 
2455   auto *I = dyn_cast<Instruction>(V);
2456   if (!I)
2457     return false;
2458 
2459   if (Q.CxtI && match(V, m_VScale())) {
2460     const Function *F = Q.CxtI->getFunction();
2461     // The vscale_range indicates vscale is a power-of-two.
2462     return F->hasFnAttribute(Attribute::VScaleRange);
2463   }
2464 
2465   // 1 << X is clearly a power of two if the one is not shifted off the end.  If
2466   // it is shifted off the end then the result is undefined.
2467   if (match(I, m_Shl(m_One(), m_Value())))
2468     return true;
2469 
2470   // (signmask) >>l X is clearly a power of two if the one is not shifted off
2471   // the bottom.  If it is shifted off the bottom then the result is undefined.
2472   if (match(I, m_LShr(m_SignMask(), m_Value())))
2473     return true;
2474 
2475   // The remaining tests are all recursive, so bail out if we hit the limit.
2476   if (Depth++ == MaxAnalysisRecursionDepth)
2477     return false;
2478 
2479   switch (I->getOpcode()) {
2480   case Instruction::ZExt:
2481     return isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Q, Depth);
2482   case Instruction::Trunc:
2483     return OrZero && isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Q, Depth);
2484   case Instruction::Shl:
2485     if (OrZero || Q.IIQ.hasNoUnsignedWrap(I) || Q.IIQ.hasNoSignedWrap(I))
2486       return isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Q, Depth);
2487     return false;
2488   case Instruction::LShr:
2489     if (OrZero || Q.IIQ.isExact(cast<BinaryOperator>(I)))
2490       return isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Q, Depth);
2491     return false;
2492   case Instruction::UDiv:
2493     if (Q.IIQ.isExact(cast<BinaryOperator>(I)))
2494       return isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Q, Depth);
2495     return false;
2496   case Instruction::Mul:
2497     return isKnownToBeAPowerOfTwo(I->getOperand(1), OrZero, Q, Depth) &&
2498            isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Q, Depth) &&
2499            (OrZero || isKnownNonZero(I, Q, Depth));
2500   case Instruction::And:
2501     // A power of two and'd with anything is a power of two or zero.
2502     if (OrZero &&
2503         (isKnownToBeAPowerOfTwo(I->getOperand(1), /*OrZero*/ true, Q, Depth) ||
2504          isKnownToBeAPowerOfTwo(I->getOperand(0), /*OrZero*/ true, Q, Depth)))
2505       return true;
2506     // X & (-X) is always a power of two or zero.
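         // E.g. (illustrative): 12 & -12 == 0b00001100 & 0b11110100 == 4,
         // isolating the lowest set bit of x; for x == 0 the result is 0.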
2507     if (match(I->getOperand(0), m_Neg(m_Specific(I->getOperand(1)))) ||
2508         match(I->getOperand(1), m_Neg(m_Specific(I->getOperand(0)))))
2509       return OrZero || isKnownNonZero(I->getOperand(0), Q, Depth);
2510     return false;
2511   case Instruction::Add: {
2512     // Adding a power-of-two or zero to the same power-of-two or zero yields
2513     // either the original power-of-two, a larger power-of-two or zero.
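         // E.g. (illustrative): if X is a power of two or zero, X + (X & Y)
         // is either X (that bit of X is clear in Y) or X + X == 2*X (it is
         // set), which is again a power of two or zero.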
2514     const OverflowingBinaryOperator *VOBO = cast<OverflowingBinaryOperator>(V);
2515     if (OrZero || Q.IIQ.hasNoUnsignedWrap(VOBO) ||
2516         Q.IIQ.hasNoSignedWrap(VOBO)) {
2517       if (match(I->getOperand(0),
2518                 m_c_And(m_Specific(I->getOperand(1)), m_Value())) &&
2519           isKnownToBeAPowerOfTwo(I->getOperand(1), OrZero, Q, Depth))
2520         return true;
2521       if (match(I->getOperand(1),
2522                 m_c_And(m_Specific(I->getOperand(0)), m_Value())) &&
2523           isKnownToBeAPowerOfTwo(I->getOperand(0), OrZero, Q, Depth))
2524         return true;
2525 
2526       unsigned BitWidth = V->getType()->getScalarSizeInBits();
2527       KnownBits LHSBits(BitWidth);
2528       computeKnownBits(I->getOperand(0), LHSBits, Q, Depth);
2529 
2530       KnownBits RHSBits(BitWidth);
2531       computeKnownBits(I->getOperand(1), RHSBits, Q, Depth);
2532       // If i8 V is a power of two or zero:
2533       //  ZeroBits: 1 1 1 0 1 1 1 1
2534       // ~ZeroBits: 0 0 0 1 0 0 0 0
2535       if ((~(LHSBits.Zero & RHSBits.Zero)).isPowerOf2())
2536         // If OrZero isn't set, we cannot give back a zero result.
2537         // Make sure either the LHS or RHS has a bit set.
2538         if (OrZero || RHSBits.One.getBoolValue() || LHSBits.One.getBoolValue())
2539           return true;
2540     }
2541 
2542     // LShr(UINT_MAX, Y) + 1 is a power of two (if add is nuw) or zero.
2543     if (OrZero || Q.IIQ.hasNoUnsignedWrap(VOBO))
2544       if (match(I, m_Add(m_LShr(m_AllOnes(), m_Value()), m_One())))
2545         return true;
2546     return false;
2547   }
2548   case Instruction::Select:
2549     return isKnownToBeAPowerOfTwo(I->getOperand(1), OrZero, Q, Depth) &&
2550            isKnownToBeAPowerOfTwo(I->getOperand(2), OrZero, Q, Depth);
2551   case Instruction::PHI: {
2552     // A PHI node is power of two if all incoming values are power of two, or if
2553     // it is an induction variable where in each step its value is a power of
2554     // two.
2555     auto *PN = cast<PHINode>(I);
2556     SimplifyQuery RecQ = Q.getWithoutCondContext();
2557 
2558     // Check if it is an induction variable that is always a power of two.
2559     if (isPowerOfTwoRecurrence(PN, OrZero, RecQ, Depth))
2560       return true;
2561 
2562     // Recursively check all incoming values. Limit recursion to 2 levels, so
2563     // that search complexity is limited to number of operands^2.
2564     unsigned NewDepth = std::max(Depth, MaxAnalysisRecursionDepth - 1);
2565     return llvm::all_of(PN->operands(), [&](const Use &U) {
2566       // Value is a power of 2 if it comes from the PHI node itself by induction.
2567       if (U.get() == PN)
2568         return true;
2569 
2570       // Change the context instruction to the incoming block where it is
2571       // evaluated.
2572       RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
2573       return isKnownToBeAPowerOfTwo(U.get(), OrZero, RecQ, NewDepth);
2574     });
2575   }
2576   case Instruction::Invoke:
2577   case Instruction::Call: {
2578     if (auto *II = dyn_cast<IntrinsicInst>(I)) {
2579       switch (II->getIntrinsicID()) {
2580       case Intrinsic::umax:
2581       case Intrinsic::smax:
2582       case Intrinsic::umin:
2583       case Intrinsic::smin:
2584         return isKnownToBeAPowerOfTwo(II->getArgOperand(1), OrZero, Q, Depth) &&
2585                isKnownToBeAPowerOfTwo(II->getArgOperand(0), OrZero, Q, Depth);
2586       // bswap/bitreverse just move bits around, but don't change any 1s/0s,
2587       // and thus don't change pow2/non-pow2 status.
2588       case Intrinsic::bitreverse:
2589       case Intrinsic::bswap:
2590         return isKnownToBeAPowerOfTwo(II->getArgOperand(0), OrZero, Q, Depth);
2591       case Intrinsic::fshr:
2592       case Intrinsic::fshl:
2593         // If Op0 == Op1, this is a rotate. is_pow2(rotate(x, y)) == is_pow2(x)
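             // E.g. (hypothetical values): fshl(i8 16, i8 16, i8 3) rotates
             // 0b00010000 left by 3 to 0b10000000; the single set bit is
             // preserved, so the power-of-two property is unchanged.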
2594         if (II->getArgOperand(0) == II->getArgOperand(1))
2595           return isKnownToBeAPowerOfTwo(II->getArgOperand(0), OrZero, Q, Depth);
2596         break;
2597       default:
2598         break;
2599       }
2600     }
2601     return false;
2602   }
2603   default:
2604     return false;
2605   }
2606 }
2607 
2608 /// Test whether a GEP's result is known to be non-null.
2609 ///
2610 /// Uses properties inherent in a GEP to try to determine whether it is known
2611 /// to be non-null.
2612 ///
2613 /// Currently this routine does not support vector GEPs.
2614 static bool isGEPKnownNonNull(const GEPOperator *GEP, const SimplifyQuery &Q,
2615                               unsigned Depth) {
2616   const Function *F = nullptr;
2617   if (const Instruction *I = dyn_cast<Instruction>(GEP))
2618     F = I->getFunction();
2619 
2620   // If the GEP is nuw, or is inbounds with the null pointer undefined, then
2621   // the GEP may be null iff the base pointer is null and the offset is zero.
2622   if (!GEP->hasNoUnsignedWrap() &&
2623       !(GEP->isInBounds() &&
2624         !NullPointerIsDefined(F, GEP->getPointerAddressSpace())))
2625     return false;
2626 
2627   // FIXME: Support vector-GEPs.
2628   assert(GEP->getType()->isPointerTy() && "We only support plain pointer GEP");
2629 
2630   // If the base pointer is non-null, we cannot walk to a null address with an
2631   // inbounds GEP in address space zero.
2632   if (isKnownNonZero(GEP->getPointerOperand(), Q, Depth))
2633     return true;
2634 
2635   // Walk the GEP operands and see if any operand introduces a non-zero offset.
2636   // If so, then the GEP cannot produce a null pointer, as doing so would
2637   // inherently violate the inbounds contract within address space zero.
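       // For example (hypothetical IR): given the inbounds check above, in
       //   %p = getelementptr inbounds {i32, i32}, ptr %base, i64 0, i32 1
       // the constant 4-byte field offset alone makes %p non-null in address
       // space zero, whatever %base is.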
2638   for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
2639        GTI != GTE; ++GTI) {
2640     // Struct types are easy -- they must always be indexed by a constant.
2641     if (StructType *STy = GTI.getStructTypeOrNull()) {
2642       ConstantInt *OpC = cast<ConstantInt>(GTI.getOperand());
2643       unsigned ElementIdx = OpC->getZExtValue();
2644       const StructLayout *SL = Q.DL.getStructLayout(STy);
2645       uint64_t ElementOffset = SL->getElementOffset(ElementIdx);
2646       if (ElementOffset > 0)
2647         return true;
2648       continue;
2649     }
2650 
2651     // If we have a zero-sized type, the index doesn't matter. Keep looping.
2652     if (GTI.getSequentialElementStride(Q.DL).isZero())
2653       continue;
2654 
2655     // Fast path the constant operand case both for efficiency and so we don't
2656     // increment Depth when just zipping down an all-constant GEP.
2657     if (ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand())) {
2658       if (!OpC->isZero())
2659         return true;
2660       continue;
2661     }
2662 
2663     // We post-increment Depth here because while isKnownNonZero increments it
2664     // as well, when we pop back up that increment won't persist. We don't want
2665     // to recurse 10k times just because we have 10k GEP operands. We don't
2666     // bail out completely because we want to handle constant GEPs regardless
2667     // of depth.
2668     if (Depth++ >= MaxAnalysisRecursionDepth)
2669       continue;
2670 
2671     if (isKnownNonZero(GTI.getOperand(), Q, Depth))
2672       return true;
2673   }
2674 
2675   return false;
2676 }
2677 
2678 static bool isKnownNonNullFromDominatingCondition(const Value *V,
2679                                                   const Instruction *CtxI,
2680                                                   const DominatorTree *DT) {
2681   assert(!isa<Constant>(V) && "Called for constant?");
2682 
2683   if (!CtxI || !DT)
2684     return false;
2685 
2686   unsigned NumUsesExplored = 0;
2687   for (auto &U : V->uses()) {
2688     // Avoid massive lists
2689     if (NumUsesExplored >= DomConditionsMaxUses)
2690       break;
2691     NumUsesExplored++;
2692 
2693     const Instruction *UI = cast<Instruction>(U.getUser());
2694     // If the value is used as an argument to a call or invoke, then argument
2695     // attributes may provide an answer about null-ness.
2696     if (V->getType()->isPointerTy()) {
2697       if (const auto *CB = dyn_cast<CallBase>(UI)) {
2698         if (CB->isArgOperand(&U) &&
2699             CB->paramHasNonNullAttr(CB->getArgOperandNo(&U),
2700                                     /*AllowUndefOrPoison=*/false) &&
2701             DT->dominates(CB, CtxI))
2702           return true;
2703       }
2704     }
2705 
2706     // If the value is used as a load/store pointer, then it must be non-null.
2707     if (V == getLoadStorePointerOperand(UI)) {
2708       if (!NullPointerIsDefined(UI->getFunction(),
2709                                 V->getType()->getPointerAddressSpace()) &&
2710           DT->dominates(UI, CtxI))
2711         return true;
2712     }
2713 
2714     if ((match(UI, m_IDiv(m_Value(), m_Specific(V))) ||
2715          match(UI, m_IRem(m_Value(), m_Specific(V)))) &&
2716         isValidAssumeForContext(UI, CtxI, DT))
2717       return true;
2718 
2719     // Consider only compare instructions uniquely controlling a branch
2720     Value *RHS;
2721     CmpPredicate Pred;
2722     if (!match(UI, m_c_ICmp(Pred, m_Specific(V), m_Value(RHS))))
2723       continue;
2724 
2725     bool NonNullIfTrue;
2726     if (cmpExcludesZero(Pred, RHS))
2727       NonNullIfTrue = true;
2728     else if (cmpExcludesZero(CmpInst::getInversePredicate(Pred), RHS))
2729       NonNullIfTrue = false;
2730     else
2731       continue;
2732 
2733     SmallVector<const User *, 4> WorkList;
2734     SmallPtrSet<const User *, 4> Visited;
2735     for (const auto *CmpU : UI->users()) {
2736       assert(WorkList.empty() && "Should be!");
2737       if (Visited.insert(CmpU).second)
2738         WorkList.push_back(CmpU);
2739 
2740       while (!WorkList.empty()) {
2741         auto *Curr = WorkList.pop_back_val();
2742 
2743         // If a user is an AND, add all its users to the work list. We only
2744         // propagate the "pred != null" condition through AND because it is only
2745         // correct to assume that all of AND's conditions hold in the true branch.
2746         // TODO: Support similar logic for OR and the EQ predicate?
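             // E.g. (hypothetical): given %c = icmp ne ptr %p, null and
             // %and = and i1 %c, %other, branching on %and still proves %p is
             // non-null in the true successor.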
2747         if (NonNullIfTrue)
2748           if (match(Curr, m_LogicalAnd(m_Value(), m_Value()))) {
2749             for (const auto *CurrU : Curr->users())
2750               if (Visited.insert(CurrU).second)
2751                 WorkList.push_back(CurrU);
2752             continue;
2753           }
2754 
2755         if (const BranchInst *BI = dyn_cast<BranchInst>(Curr)) {
2756           assert(BI->isConditional() && "uses a comparison!");
2757 
2758           BasicBlock *NonNullSuccessor =
2759               BI->getSuccessor(NonNullIfTrue ? 0 : 1);
2760           BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor);
2761           if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent()))
2762             return true;
2763         } else if (NonNullIfTrue && isGuard(Curr) &&
2764                    DT->dominates(cast<Instruction>(Curr), CtxI)) {
2765           return true;
2766         }
2767       }
2768     }
2769   }
2770 
2771   return false;
2772 }
2773 
2774 /// Does the 'Range' metadata (which must be a valid MD_range operand list)
2775 /// ensure that the value it's attached to is never Value?
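     /// For example, hypothetical metadata !{i32 1, i32 257} describes the range
     /// [1, 257), which does not contain 0, so a value carrying it is known to be
     /// non-zero.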
2777 static bool rangeMetadataExcludesValue(const MDNode *Ranges,
                                                const APInt &Value) {
2778   const unsigned NumRanges = Ranges->getNumOperands() / 2;
2779   assert(NumRanges >= 1);
2780   for (unsigned i = 0; i < NumRanges; ++i) {
2781     ConstantInt *Lower =
2782         mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 0));
2783     ConstantInt *Upper =
2784         mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 1));
2785     ConstantRange Range(Lower->getValue(), Upper->getValue());
2786     if (Range.contains(Value))
2787       return false;
2788   }
2789   return true;
2790 }
2791 
2792 /// Try to detect a recurrence that monotonically increases/decreases from a
2793 /// non-zero starting value. These are common as induction variables.
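     /// A minimal hypothetical instance:
     ///   %iv = phi i32 [ 1, %entry ], [ %iv.next, %loop ]
     ///   %iv.next = add nuw i32 %iv, 2
     /// starts non-zero and steps away from zero with nuw, so it can never wrap
     /// back to zero.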
2794 static bool isNonZeroRecurrence(const PHINode *PN) {
2795   BinaryOperator *BO = nullptr;
2796   Value *Start = nullptr, *Step = nullptr;
2797   const APInt *StartC, *StepC;
2798   if (!matchSimpleRecurrence(PN, BO, Start, Step) ||
2799       !match(Start, m_APInt(StartC)) || StartC->isZero())
2800     return false;
2801 
2802   switch (BO->getOpcode()) {
2803   case Instruction::Add:
2804     // Starting from non-zero and stepping away from zero can never wrap back
2805     // to zero.
2806     return BO->hasNoUnsignedWrap() ||
2807            (BO->hasNoSignedWrap() && match(Step, m_APInt(StepC)) &&
2808             StartC->isNegative() == StepC->isNegative());
2809   case Instruction::Mul:
2810     return (BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap()) &&
2811            match(Step, m_APInt(StepC)) && !StepC->isZero();
2812   case Instruction::Shl:
2813     return BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap();
2814   case Instruction::AShr:
2815   case Instruction::LShr:
2816     return BO->isExact();
2817   default:
2818     return false;
2819   }
2820 }
2821 
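     // Match the idiom where one operand is a zext/sext of (Other == 0), as in
     // this hypothetical IR:
     //   %c = icmp eq i32 %x, 0
     //   %e = zext i1 %c to i32
     //   %r = add i32 %x, %e  ; non-zero, since %e is 1 exactly when %x is 0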
2822 static bool matchOpWithOpEqZero(Value *Op0, Value *Op1) {
2823   return match(Op0, m_ZExtOrSExt(m_SpecificICmp(ICmpInst::ICMP_EQ,
2824                                                 m_Specific(Op1), m_Zero()))) ||
2825          match(Op1, m_ZExtOrSExt(m_SpecificICmp(ICmpInst::ICMP_EQ,
2826                                                 m_Specific(Op0), m_Zero())));
2827 }
2828 
2829 static bool isNonZeroAdd(const APInt &DemandedElts, const SimplifyQuery &Q,
2830                          unsigned BitWidth, Value *X, Value *Y, bool NSW,
2831                          bool NUW, unsigned Depth) {
2832   // (X + (X != 0)) is non-zero
2833   if (matchOpWithOpEqZero(X, Y))
2834     return true;
2835 
2836   if (NUW)
2837     return isKnownNonZero(Y, DemandedElts, Q, Depth) ||
2838            isKnownNonZero(X, DemandedElts, Q, Depth);
2839 
2840   KnownBits XKnown = computeKnownBits(X, DemandedElts, Q, Depth);
2841   KnownBits YKnown = computeKnownBits(Y, DemandedElts, Q, Depth);
2842 
2843   // If X and Y are both non-negative (as signed values) then their sum is not
2844   // zero unless both X and Y are zero.
2845   if (XKnown.isNonNegative() && YKnown.isNonNegative())
2846     if (isKnownNonZero(Y, DemandedElts, Q, Depth) ||
2847         isKnownNonZero(X, DemandedElts, Q, Depth))
2848       return true;
2849 
2850   // If X and Y are both negative (as signed values) then their sum is not
2851   // zero unless both X and Y equal INT_MIN.
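       // (In i8, for a hypothetical concrete case: -128 + -128 wraps to 0, but
       // any other pair of negative values has a non-zero sum mod 256.)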
2852   if (XKnown.isNegative() && YKnown.isNegative()) {
2853     APInt Mask = APInt::getSignedMaxValue(BitWidth);
2854     // The sign bit of X is set.  If some other bit is set then X is not equal
2855     // to INT_MIN.
2856     if (XKnown.One.intersects(Mask))
2857       return true;
2858     // The sign bit of Y is set.  If some other bit is set then Y is not equal
2859     // to INT_MIN.
2860     if (YKnown.One.intersects(Mask))
2861       return true;
2862   }
2863 
2864   // The sum of a non-negative number and a power of two is not zero.
2865   if (XKnown.isNonNegative() &&
2866       isKnownToBeAPowerOfTwo(Y, /*OrZero*/ false, Q, Depth))
2867     return true;
2868   if (YKnown.isNonNegative() &&
2869       isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Q, Depth))
2870     return true;
2871 
2872   return KnownBits::add(XKnown, YKnown, NSW, NUW).isNonZero();
2873 }
2874 
2875 static bool isNonZeroSub(const APInt &DemandedElts, const SimplifyQuery &Q,
2876                          unsigned BitWidth, Value *X, Value *Y,
2877                          unsigned Depth) {
2878   // (X - (X != 0)) is non-zero
2879   // ((X != 0) - X) is non-zero
2880   if (matchOpWithOpEqZero(X, Y))
2881     return true;
2882 
2883   // TODO: Move this case into isKnownNonEqual().
2884   if (auto *C = dyn_cast<Constant>(X))
2885     if (C->isNullValue() && isKnownNonZero(Y, DemandedElts, Q, Depth))
2886       return true;
2887 
2888   return ::isKnownNonEqual(X, Y, DemandedElts, Q, Depth);
2889 }
2890 
2891 static bool isNonZeroMul(const APInt &DemandedElts, const SimplifyQuery &Q,
2892                          unsigned BitWidth, Value *X, Value *Y, bool NSW,
2893                          bool NUW, unsigned Depth) {
2894   // If X and Y are non-zero then so is X * Y as long as the multiplication
2895   // does not overflow.
2896   if (NSW || NUW)
2897     return isKnownNonZero(X, DemandedElts, Q, Depth) &&
2898            isKnownNonZero(Y, DemandedElts, Q, Depth);
2899 
2900   // If either X or Y is odd, then if the other is non-zero the result can't
2901   // be zero.
2902   KnownBits XKnown = computeKnownBits(X, DemandedElts, Q, Depth);
2903   if (XKnown.One[0])
2904     return isKnownNonZero(Y, DemandedElts, Q, Depth);
2905 
2906   KnownBits YKnown = computeKnownBits(Y, DemandedElts, Q, Depth);
2907   if (YKnown.One[0])
2908     return XKnown.isNonZero() || isKnownNonZero(X, DemandedElts, Q, Depth);
2909 
2910   // If there exists any subset of X (sX) and subset of Y (sY) s.t. sX * sY is
2911   // non-zero, then X * Y is non-zero. We can find sX and sY by just taking
2912   // the lowest known one bit of X and Y. If those are non-zero, the result
2913   // must be non-zero. We can check that LSB(X) * LSB(Y) != 0 by checking
2914   // X.countMaxTrailingZeros() + Y.countMaxTrailingZeros() < BitWidth.
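       // Worked example (i8, hypothetical known bits): if X has a known one at
       // bit 2 and Y at bit 3, the product's lowest set bit is at position at
       // most 2 + 3 = 5 < 8, so X * Y cannot be zero.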
2915   return (XKnown.countMaxTrailingZeros() + YKnown.countMaxTrailingZeros()) <
2916          BitWidth;
2917 }
2918 
2919 static bool isNonZeroShift(const Operator *I, const APInt &DemandedElts,
2920                            const SimplifyQuery &Q, const KnownBits &KnownVal,
2921                            unsigned Depth) {
2922   auto ShiftOp = [&](const APInt &Lhs, const APInt &Rhs) {
2923     switch (I->getOpcode()) {
2924     case Instruction::Shl:
2925       return Lhs.shl(Rhs);
2926     case Instruction::LShr:
2927       return Lhs.lshr(Rhs);
2928     case Instruction::AShr:
2929       return Lhs.ashr(Rhs);
2930     default:
2931       llvm_unreachable("Unknown Shift Opcode");
2932     }
2933   };
2934 
2935   auto InvShiftOp = [&](const APInt &Lhs, const APInt &Rhs) {
2936     switch (I->getOpcode()) {
2937     case Instruction::Shl:
2938       return Lhs.lshr(Rhs);
2939     case Instruction::LShr:
2940     case Instruction::AShr:
2941       return Lhs.shl(Rhs);
2942     default:
2943       llvm_unreachable("Unknown Shift Opcode");
2944     }
2945   };
2946 
2947   if (KnownVal.isUnknown())
2948     return false;
2949 
2950   KnownBits KnownCnt =
2951       computeKnownBits(I->getOperand(1), DemandedElts, Q, Depth);
2952   APInt MaxShift = KnownCnt.getMaxValue();
2953   unsigned NumBits = KnownVal.getBitWidth();
2954   if (MaxShift.uge(NumBits))
2955     return false;
2956 
2957   if (!ShiftOp(KnownVal.One, MaxShift).isZero())
2958     return true;
2959 
2960   // If all of the bits shifted out are known to be zero, and Val is known
2961   // non-zero, then at least one non-zero bit must remain.
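       // E.g. (i8 shl, hypothetical): if the shift amount is known to be at
       // most 2 and the top 2 bits of Val are known zero, no set bit can be
       // shifted out, so a non-zero Val stays non-zero after the shift.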
2962   if (InvShiftOp(KnownVal.Zero, NumBits - MaxShift)
2963           .eq(InvShiftOp(APInt::getAllOnes(NumBits), NumBits - MaxShift)) &&
2964       isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth))
2965     return true;
2966 
2967   return false;
2968 }
2969 
2970 static bool isKnownNonZeroFromOperator(const Operator *I,
2971                                        const APInt &DemandedElts,
2972                                        const SimplifyQuery &Q, unsigned Depth) {
2973   unsigned BitWidth = getBitWidth(I->getType()->getScalarType(), Q.DL);
2974   switch (I->getOpcode()) {
2975   case Instruction::Alloca:
2976     // Alloca never returns null, malloc might.
2977     return I->getType()->getPointerAddressSpace() == 0;
2978   case Instruction::GetElementPtr:
2979     if (I->getType()->isPointerTy())
2980       return isGEPKnownNonNull(cast<GEPOperator>(I), Q, Depth);
2981     break;
2982   case Instruction::BitCast: {
2983     // We need to be a bit careful here. We can only peek through the bitcast
2984     // if the scalar size of the operand's elements is smaller than, and
2985     // divides, the scalar size they are cast to. Take three cases:
2986     //
2987     // 1) Unsafe:
2988     //        bitcast <2 x i16> %NonZero to <4 x i8>
2989     //
2990     //    %NonZero can have 2 non-zero i16 elements, but isKnownNonZero on a
2991     //    <4 x i8> requires that all 4 i8 elements be non-zero which isn't
2992     //    guranteed (imagine just sign bit set in the 2 i16 elements).
2993     //
2994     // 2) Unsafe:
2995     //        bitcast <4 x i3> %NonZero to <3 x i4>
2996     //
2997     //    Even though the scalar size of the src (`i3`) is smaller than the
2998     //    scalar size of the dst (`i4`), because `i4` is not a multiple of
2999     //    `i3` it's possible for the <3 x i4> elements to be zero, as some
3000     //    elements in the destination don't contain any full src
3001     //    element.
3002     //
3003     // 3) Safe:
3004     //        bitcast <4 x i8> %NonZero to <2 x i16>
3005     //
3006     //    This is always safe as non-zero in the 4 i8 elements implies
3007     //    non-zero in the combination of any two adjacent ones. Since i16 is a
3008     //    multiple of i8, each i16 is guaranteed to contain 2 full i8 elements.
3009     //    This all implies the 2 i16 elements are non-zero.
3010     Type *FromTy = I->getOperand(0)->getType();
3011     if ((FromTy->isIntOrIntVectorTy() || FromTy->isPtrOrPtrVectorTy()) &&
3012         (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0)
3013       return isKnownNonZero(I->getOperand(0), Q, Depth);
3014   } break;
3015   case Instruction::IntToPtr:
3016     // Note that we have to take special care to avoid looking through
3017     // truncating casts, e.g., int2ptr/ptr2int with appropriate sizes, as well
3018     // as casts that can alter the value, e.g., AddrSpaceCasts.
3019     if (!isa<ScalableVectorType>(I->getType()) &&
3020         Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
3021             Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
3022       return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3023     break;
3024   case Instruction::PtrToInt:
3025     // Similar to int2ptr above, we can look through ptr2int here if the cast
3026     // is a no-op or an extend and not a truncate.
3027     if (!isa<ScalableVectorType>(I->getType()) &&
3028         Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
3029             Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
3030       return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3031     break;
3032   case Instruction::Trunc:
3033     // nuw/nsw trunc preserves zero/non-zero status of input.
3034     if (auto *TI = dyn_cast<TruncInst>(I))
3035       if (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap())
3036         return isKnownNonZero(TI->getOperand(0), DemandedElts, Q, Depth);
3037     break;
3038 
3039   // x ^ y != 0 iff x - y != 0 (both hold exactly when x != y),
3040   // so we can apply the exact same checks to both.
3041   case Instruction::Xor:
3042   case Instruction::Sub:
3043     return isNonZeroSub(DemandedElts, Q, BitWidth, I->getOperand(0),
3044                         I->getOperand(1), Depth);
3045   case Instruction::Or:
3046     // (X | (X != 0)) is non-zero
3047     if (matchOpWithOpEqZero(I->getOperand(0), I->getOperand(1)))
3048       return true;
3049     // X | Y != 0 if X != Y.
3050     if (isKnownNonEqual(I->getOperand(0), I->getOperand(1), DemandedElts, Q,
3051                         Depth))
3052       return true;
3053     // X | Y != 0 if X != 0 or Y != 0.
3054     return isKnownNonZero(I->getOperand(1), DemandedElts, Q, Depth) ||
3055            isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3056   case Instruction::SExt:
3057   case Instruction::ZExt:
3058     // ext X != 0 if X != 0.
3059     return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3060 
3061   case Instruction::Shl: {
3062     // shl nsw/nuw can't remove any non-zero bits.
3063     const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I);
3064     if (Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO))
3065       return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3066 
3067     // shl X, Y != 0 if X is odd.  Note that the value of the shift is undefined
3068     // if the lowest bit is shifted off the end.
3069     KnownBits Known(BitWidth);
3070     computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth);
3071     if (Known.One[0])
3072       return true;
3073 
3074     return isNonZeroShift(I, DemandedElts, Q, Known, Depth);
3075   }
3076   case Instruction::LShr:
3077   case Instruction::AShr: {
3078     // shr exact can only shift out zero bits.
3079     const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(I);
3080     if (BO->isExact())
3081       return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3082 
3083     // shr X, Y != 0 if X is negative.  Note that the value of the shift is not
3084     // defined if the sign bit is shifted off the end.
3085     KnownBits Known =
3086         computeKnownBits(I->getOperand(0), DemandedElts, Q, Depth);
3087     if (Known.isNegative())
3088       return true;
3089 
3090     return isNonZeroShift(I, DemandedElts, Q, Known, Depth);
3091   }
3092   case Instruction::UDiv:
3093   case Instruction::SDiv: {
3094     // X / Y
3095     // div exact can only produce a zero if the dividend is zero.
3096     if (cast<PossiblyExactOperator>(I)->isExact())
3097       return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
3098 
3099     KnownBits XKnown =
3100         computeKnownBits(I->getOperand(0), DemandedElts, Q, Depth);
3101     // If X is fully unknown we won't be able to figure anything out, so don't
3102     // bother computing known bits for Y.
3103     if (XKnown.isUnknown())
3104       return false;
3105 
3106     KnownBits YKnown =
3107         computeKnownBits(I->getOperand(1), DemandedElts, Q, Depth);
3108     if (I->getOpcode() == Instruction::SDiv) {
3109       // For signed division need to compare abs value of the operands.
3110       XKnown = XKnown.abs(/*IntMinIsPoison*/ false);
3111       YKnown = YKnown.abs(/*IntMinIsPoison*/ false);
3112     }
3113     // If X u>= Y then div is non zero (0/0 is UB).
3114     std::optional<bool> XUgeY = KnownBits::uge(XKnown, YKnown);
3115     // If X is totally unknown or X u< Y, we won't be able to prove non-zero
3116     // with computeKnownBits, so just return early.
3117     return XUgeY && *XUgeY;
3118   }
3119   case Instruction::Add: {
3120     // X + Y.
3121 
3122     // If the add has the nuw flag, then if either X or Y is non-zero the
3123     // result is non-zero.
3124     auto *BO = cast<OverflowingBinaryOperator>(I);
3125     return isNonZeroAdd(DemandedElts, Q, BitWidth, I->getOperand(0),
3126                         I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO),
3127                         Q.IIQ.hasNoUnsignedWrap(BO), Depth);
3128   }
3129   case Instruction::Mul: {
3130     const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I);
3131     return isNonZeroMul(DemandedElts, Q, BitWidth, I->getOperand(0),
3132                         I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO),
3133                         Q.IIQ.hasNoUnsignedWrap(BO), Depth);
3134   }
3135   case Instruction::Select: {
3136     // (C ? X : Y) != 0 if X != 0 and Y != 0.
3137 
3138     // First check if the arm is non-zero using `isKnownNonZero`. If that fails,
3139     // then see if the select condition implies the arm is non-zero. For example
3140     // (X != 0 ? X : Y), we know the true arm is non-zero as the `X` "return" is
3141     // dominated by `X != 0`.
3142     auto SelectArmIsNonZero = [&](bool IsTrueArm) {
3143       Value *Op = IsTrueArm ? I->getOperand(1) : I->getOperand(2);
3145       // Op is trivially non-zero.
3146       if (isKnownNonZero(Op, DemandedElts, Q, Depth))
3147         return true;
3148 
3149       // The condition of the select dominates the true/false arm. Check if the
3150       // condition implies that a given arm is non-zero.
3151       Value *X;
3152       CmpPredicate Pred;
3153       if (!match(I->getOperand(0), m_c_ICmp(Pred, m_Specific(Op), m_Value(X))))
3154         return false;
3155 
3156       if (!IsTrueArm)
3157         Pred = ICmpInst::getInversePredicate(Pred);
3158 
3159       return cmpExcludesZero(Pred, X);
3160     };
3161 
3162     if (SelectArmIsNonZero(/* IsTrueArm */ true) &&
3163         SelectArmIsNonZero(/* IsTrueArm */ false))
3164       return true;
3165     break;
3166   }
3167   case Instruction::PHI: {
3168     auto *PN = cast<PHINode>(I);
3169     if (Q.IIQ.UseInstrInfo && isNonZeroRecurrence(PN))
3170       return true;
3171 
3172     // Check if all incoming values are non-zero using recursion.
3173     SimplifyQuery RecQ = Q.getWithoutCondContext();
3174     unsigned NewDepth = std::max(Depth, MaxAnalysisRecursionDepth - 1);
3175     return llvm::all_of(PN->operands(), [&](const Use &U) {
3176       if (U.get() == PN)
3177         return true;
3178       RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
3179       // Check if the branch on the phi excludes zero.
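           // E.g. (hypothetical): br i1 (icmp ne i32 %u, 0), label %phi.bb,
           // label %other proves %u != 0 along the edge into the phi's block.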
3180       CmpPredicate Pred;
3181       Value *X;
3182       BasicBlock *TrueSucc, *FalseSucc;
3183       if (match(RecQ.CxtI,
3184                 m_Br(m_c_ICmp(Pred, m_Specific(U.get()), m_Value(X)),
3185                      m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) {
3186         // Check for cases of duplicate successors.
3187         if ((TrueSucc == PN->getParent()) != (FalseSucc == PN->getParent())) {
3188           // If we're using the false successor, invert the predicate.
3189           if (FalseSucc == PN->getParent())
3190             Pred = CmpInst::getInversePredicate(Pred);
3191           if (cmpExcludesZero(Pred, X))
3192             return true;
3193         }
3194       }
3195       // Finally recurse on the edge and check it directly.
3196       return isKnownNonZero(U.get(), DemandedElts, RecQ, NewDepth);
3197     });
3198   }
3199   case Instruction::InsertElement: {
3200     if (isa<ScalableVectorType>(I->getType()))
3201       break;
3202 
3203     const Value *Vec = I->getOperand(0);
3204     const Value *Elt = I->getOperand(1);
3205     auto *CIdx = dyn_cast<ConstantInt>(I->getOperand(2));
3206 
3207     unsigned NumElts = DemandedElts.getBitWidth();
3208     APInt DemandedVecElts = DemandedElts;
3209     bool SkipElt = false;
3210     // If we know the index we are inserting to, clear it from the Vec check.
3211     if (CIdx && CIdx->getValue().ult(NumElts)) {
3212       DemandedVecElts.clearBit(CIdx->getZExtValue());
3213       SkipElt = !DemandedElts[CIdx->getZExtValue()];
3214     }
3215 
3216     // Result is non-zero if Elt is non-zero and the rest of the demanded elts
3217     // in Vec are non-zero.
3218     return (SkipElt || isKnownNonZero(Elt, Q, Depth)) &&
3219            (DemandedVecElts.isZero() ||
3220             isKnownNonZero(Vec, DemandedVecElts, Q, Depth));
3221   }
3222   case Instruction::ExtractElement:
3223     if (const auto *EEI = dyn_cast<ExtractElementInst>(I)) {
3224       const Value *Vec = EEI->getVectorOperand();
3225       const Value *Idx = EEI->getIndexOperand();
3226       auto *CIdx = dyn_cast<ConstantInt>(Idx);
3227       if (auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType())) {
3228         unsigned NumElts = VecTy->getNumElements();
3229         APInt DemandedVecElts = APInt::getAllOnes(NumElts);
3230         if (CIdx && CIdx->getValue().ult(NumElts))
3231           DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
3232         return isKnownNonZero(Vec, DemandedVecElts, Q, Depth);
3233       }
3234     }
3235     break;
3236   case Instruction::ShuffleVector: {
3237     auto *Shuf = dyn_cast<ShuffleVectorInst>(I);
3238     if (!Shuf)
3239       break;
3240     APInt DemandedLHS, DemandedRHS;
3241     // For undef elements, we don't know anything about the common state of
3242     // the shuffle result.
3243     if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
3244       break;
3245     // If demanded elements for both vecs are non-zero, the shuffle is non-zero.
3246     return (DemandedRHS.isZero() ||
3247             isKnownNonZero(Shuf->getOperand(1), DemandedRHS, Q, Depth)) &&
3248            (DemandedLHS.isZero() ||
3249             isKnownNonZero(Shuf->getOperand(0), DemandedLHS, Q, Depth));
3250   }
3251   case Instruction::Freeze:
3252     return isKnownNonZero(I->getOperand(0), Q, Depth) &&
3253            isGuaranteedNotToBePoison(I->getOperand(0), Q.AC, Q.CxtI, Q.DT,
3254                                      Depth);
3255   case Instruction::Load: {
3256     auto *LI = cast<LoadInst>(I);
3257     // A load tagged with nonnull, or with dereferenceable when the null
3258     // pointer is undefined, is never null.
3259     if (auto *PtrT = dyn_cast<PointerType>(I->getType())) {
3260       if (Q.IIQ.getMetadata(LI, LLVMContext::MD_nonnull) ||
3261           (Q.IIQ.getMetadata(LI, LLVMContext::MD_dereferenceable) &&
3262            !NullPointerIsDefined(LI->getFunction(), PtrT->getAddressSpace())))
3263         return true;
3264     } else if (MDNode *Ranges = Q.IIQ.getMetadata(LI, LLVMContext::MD_range)) {
3265       return rangeMetadataExcludesValue(Ranges, APInt::getZero(BitWidth));
3266     }
3267 
3268     // No need to fall through to computeKnownBits as range metadata is already
3269     // handled in isKnownNonZero.
3270     return false;
3271   }
3272   case Instruction::ExtractValue: {
3273     const WithOverflowInst *WO;
3274     if (match(I, m_ExtractValue<0>(m_WithOverflowInst(WO)))) {
3275       switch (WO->getBinaryOp()) {
3276       default:
3277         break;
3278       case Instruction::Add:
3279         return isNonZeroAdd(DemandedElts, Q, BitWidth, WO->getArgOperand(0),
3280                             WO->getArgOperand(1),
3281                             /*NSW=*/false,
3282                             /*NUW=*/false, Depth);
3283       case Instruction::Sub:
3284         return isNonZeroSub(DemandedElts, Q, BitWidth, WO->getArgOperand(0),
3285                             WO->getArgOperand(1), Depth);
3286       case Instruction::Mul:
3287         return isNonZeroMul(DemandedElts, Q, BitWidth, WO->getArgOperand(0),
3288                             WO->getArgOperand(1),
3289                             /*NSW=*/false, /*NUW=*/false, Depth);
3290         break;
3291       }
3292     }
3293     break;
3294   }
3295   case Instruction::Call:
3296   case Instruction::Invoke: {
3297     const auto *Call = cast<CallBase>(I);
3298     if (I->getType()->isPointerTy()) {
3299       if (Call->isReturnNonNull())
3300         return true;
3301       if (const auto *RP = getArgumentAliasingToReturnedPointer(Call, true))
3302         return isKnownNonZero(RP, Q, Depth);
3303     } else {
3304       if (MDNode *Ranges = Q.IIQ.getMetadata(Call, LLVMContext::MD_range))
3305         return rangeMetadataExcludesValue(Ranges, APInt::getZero(BitWidth));
3306       if (std::optional<ConstantRange> Range = Call->getRange()) {
3307         const APInt ZeroValue(Range->getBitWidth(), 0);
3308         if (!Range->contains(ZeroValue))
3309           return true;
3310       }
3311       if (const Value *RV = Call->getReturnedArgOperand())
3312         if (RV->getType() == I->getType() && isKnownNonZero(RV, Q, Depth))
3313           return true;
3314     }
3315 
3316     if (auto *II = dyn_cast<IntrinsicInst>(I)) {
3317       switch (II->getIntrinsicID()) {
3318       case Intrinsic::sshl_sat:
3319       case Intrinsic::ushl_sat:
3320       case Intrinsic::abs:
3321       case Intrinsic::bitreverse:
3322       case Intrinsic::bswap:
3323       case Intrinsic::ctpop:
3324         return isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth);
3325         // NB: We don't handle usub_sat here because, in any case where we
3326         // can prove it non-zero, InstCombine will fold it to `sub nuw`.
3327       case Intrinsic::ssub_sat:
3328         return isNonZeroSub(DemandedElts, Q, BitWidth, II->getArgOperand(0),
3329                             II->getArgOperand(1), Depth);
3330       case Intrinsic::sadd_sat:
3331         return isNonZeroAdd(DemandedElts, Q, BitWidth, II->getArgOperand(0),
3332                             II->getArgOperand(1),
3333                             /*NSW=*/true, /* NUW=*/false, Depth);
3334         // Vec reverse preserves zero/non-zero status from input vec.
3335       case Intrinsic::vector_reverse:
3336         return isKnownNonZero(II->getArgOperand(0), DemandedElts.reverseBits(),
3337                               Q, Depth);
3338         // or/umax/umin/smax/smin of all non-zero elements is always non-zero.
3339       case Intrinsic::vector_reduce_or:
3340       case Intrinsic::vector_reduce_umax:
3341       case Intrinsic::vector_reduce_umin:
3342       case Intrinsic::vector_reduce_smax:
3343       case Intrinsic::vector_reduce_smin:
3344         return isKnownNonZero(II->getArgOperand(0), Q, Depth);
3345       case Intrinsic::umax:
3346       case Intrinsic::uadd_sat:
3347         // umax(X, (X != 0)) is non-zero
3348         // X +usat (X != 0) is non-zero
3349         if (matchOpWithOpEqZero(II->getArgOperand(0), II->getArgOperand(1)))
3350           return true;
3351 
3352         return isKnownNonZero(II->getArgOperand(1), DemandedElts, Q, Depth) ||
3353                isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth);
3354       case Intrinsic::smax: {
3355         // If either arg is strictly positive the result is non-zero. Otherwise
3356         // the result is non-zero if both ops are non-zero.
3357         auto IsNonZero = [&](Value *Op, std::optional<bool> &OpNonZero,
3358                              const KnownBits &OpKnown) {
3359           if (!OpNonZero.has_value())
3360             OpNonZero = OpKnown.isNonZero() ||
3361                         isKnownNonZero(Op, DemandedElts, Q, Depth);
3362           return *OpNonZero;
3363         };
3364         // Avoid re-computing isKnownNonZero.
3365         std::optional<bool> Op0NonZero, Op1NonZero;
3366         KnownBits Op1Known =
3367             computeKnownBits(II->getArgOperand(1), DemandedElts, Q, Depth);
3368         if (Op1Known.isNonNegative() &&
3369             IsNonZero(II->getArgOperand(1), Op1NonZero, Op1Known))
3370           return true;
3371         KnownBits Op0Known =
3372             computeKnownBits(II->getArgOperand(0), DemandedElts, Q, Depth);
3373         if (Op0Known.isNonNegative() &&
3374             IsNonZero(II->getArgOperand(0), Op0NonZero, Op0Known))
3375           return true;
3376         return IsNonZero(II->getArgOperand(1), Op1NonZero, Op1Known) &&
3377                IsNonZero(II->getArgOperand(0), Op0NonZero, Op0Known);
3378       }
3379       case Intrinsic::smin: {
3380         // If either arg is negative the result is non-zero. Otherwise
3381         // the result is non-zero if both ops are non-zero.
3382         KnownBits Op1Known =
3383             computeKnownBits(II->getArgOperand(1), DemandedElts, Q, Depth);
3384         if (Op1Known.isNegative())
3385           return true;
3386         KnownBits Op0Known =
3387             computeKnownBits(II->getArgOperand(0), DemandedElts, Q, Depth);
3388         if (Op0Known.isNegative())
3389           return true;
3390 
3391         if (Op1Known.isNonZero() && Op0Known.isNonZero())
3392           return true;
3393       }
3394         [[fallthrough]];
3395       case Intrinsic::umin:
3396         return isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth) &&
3397                isKnownNonZero(II->getArgOperand(1), DemandedElts, Q, Depth);
3398       case Intrinsic::cttz:
3399         return computeKnownBits(II->getArgOperand(0), DemandedElts, Q, Depth)
3400             .Zero[0];
3401       case Intrinsic::ctlz:
3402         return computeKnownBits(II->getArgOperand(0), DemandedElts, Q, Depth)
3403             .isNonNegative();
3404       case Intrinsic::fshr:
3405       case Intrinsic::fshl:
3406         // If Op0 == Op1, this is a rotate. rotate(x, y) != 0 iff x != 0.
3407         if (II->getArgOperand(0) == II->getArgOperand(1))
3408           return isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth);
3409         break;
3410       case Intrinsic::vscale:
3411         return true;
3412       case Intrinsic::experimental_get_vector_length:
3413         return isKnownNonZero(I->getOperand(0), Q, Depth);
3414       default:
3415         break;
3416       }
3417       break;
3418     }
3419 
3420     return false;
3421   }
3422   }
3423 
3424   KnownBits Known(BitWidth);
3425   computeKnownBits(I, DemandedElts, Known, Q, Depth);
3426   return Known.One != 0;
3427 }
3428 
3429 /// Return true if the given value is known to be non-zero when defined. For
3430 /// vectors, return true if every demanded element is known to be non-zero when
3431 /// defined. For pointers, if the context instruction and dominator tree are
3432 /// specified, perform context-sensitive analysis and return true if the
3433 /// pointer couldn't possibly be null at the specified instruction.
3434 /// Supports values with integer or pointer type and vectors of integers.
3435 bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
3436                     const SimplifyQuery &Q, unsigned Depth) {
3437   Type *Ty = V->getType();
3438 
3439 #ifndef NDEBUG
3440   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
3441 
3442   if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
3443     assert(
3444         FVTy->getNumElements() == DemandedElts.getBitWidth() &&
3445         "DemandedElt width should equal the fixed vector number of elements");
3446   } else {
3447     assert(DemandedElts == APInt(1, 1) &&
3448            "DemandedElt width should be 1 for scalars");
3449   }
3450 #endif
3451 
3452   if (auto *C = dyn_cast<Constant>(V)) {
3453     if (C->isNullValue())
3454       return false;
3455     if (isa<ConstantInt>(C))
3456       // Must be non-zero due to null test above.
3457       return true;
3458 
3459     // For constant vectors, check that all elements are poison or known
3460     // non-zero to determine that the whole vector is known non-zero.
3461     if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
3462       for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) {
3463         if (!DemandedElts[i])
3464           continue;
3465         Constant *Elt = C->getAggregateElement(i);
3466         if (!Elt || Elt->isNullValue())
3467           return false;
3468         if (!isa<PoisonValue>(Elt) && !isa<ConstantInt>(Elt))
3469           return false;
3470       }
3471       return true;
3472     }
3473 
3474     // A constant ptrauth can be null iff its base pointer can be.
3475     if (auto *CPA = dyn_cast<ConstantPtrAuth>(V))
3476       return isKnownNonZero(CPA->getPointer(), DemandedElts, Q, Depth);
3477 
3478     // A global variable in address space 0 is non-null unless it is extern
3479     // weak or an absolute symbol reference. Other address spaces may have null
3480     // as a valid address for a global, so we can't assume anything.
3481     if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
3482       if (!GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() &&
3483           GV->getType()->getAddressSpace() == 0)
3484         return true;
3485     }
3486 
3487     // For constant expressions, fall through to the Operator code below.
3488     if (!isa<ConstantExpr>(V))
3489       return false;
3490   }
3491 
3492   if (const auto *A = dyn_cast<Argument>(V))
3493     if (std::optional<ConstantRange> Range = A->getRange()) {
3494       const APInt ZeroValue(Range->getBitWidth(), 0);
3495       if (!Range->contains(ZeroValue))
3496         return true;
3497     }
3498 
3499   if (!isa<Constant>(V) && isKnownNonZeroFromAssume(V, Q))
3500     return true;
3501 
3502   // Some of the tests below are recursive, so bail out if we hit the limit.
3503   if (Depth++ >= MaxAnalysisRecursionDepth)
3504     return false;
3505 
3506   // Check for pointer simplifications.
3507 
3508   if (PointerType *PtrTy = dyn_cast<PointerType>(Ty)) {
3509     // A byval or inalloca argument may not be null in a non-default address
3510     // space. A nonnull argument is assumed to never be 0.
3511     if (const Argument *A = dyn_cast<Argument>(V)) {
3512       if (((A->hasPassPointeeByValueCopyAttr() &&
3513             !NullPointerIsDefined(A->getParent(), PtrTy->getAddressSpace())) ||
3514            A->hasNonNullAttr()))
3515         return true;
3516     }
3517   }
3518 
3519   if (const auto *I = dyn_cast<Operator>(V))
3520     if (isKnownNonZeroFromOperator(I, DemandedElts, Q, Depth))
3521       return true;
3522 
3523   if (!isa<Constant>(V) &&
3524       isKnownNonNullFromDominatingCondition(V, Q.CxtI, Q.DT))
3525     return true;
3526 
3527   if (const Value *Stripped = stripNullTest(V))
3528     return isKnownNonZero(Stripped, DemandedElts, Q, Depth);
3529 
3530   return false;
3531 }
3532 
3533 bool llvm::isKnownNonZero(const Value *V, const SimplifyQuery &Q,
3534                           unsigned Depth) {
3535   auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
3536   APInt DemandedElts =
3537       FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
3538   return ::isKnownNonZero(V, DemandedElts, Q, Depth);
3539 }
3540 
3541 /// If the pair of operators are the same invertible function, return
3542 /// the operands of the function corresponding to each input. Otherwise,
3543 /// return std::nullopt.  An invertible function is one that is 1-to-1 and maps
3544 /// every input value to exactly one output value.  This is equivalent to
3545 /// saying that Op1 and Op2 are equal exactly when the specified pair of
3546 /// operands are equal, (except that Op1 and Op2 may be poison more often.)
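     /// E.g. (hypothetical): for %a = add i32 %x, %z and %b = add i32 %y, %z,
     /// the pair (%x, %y) is returned, since %a == %b exactly when %x == %y.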
3547 static std::optional<std::pair<Value*, Value*>>
3548 getInvertibleOperands(const Operator *Op1,
3549                       const Operator *Op2) {
3550   if (Op1->getOpcode() != Op2->getOpcode())
3551     return std::nullopt;
3552 
3553   auto getOperands = [&](unsigned OpNum) -> auto {
3554     return std::make_pair(Op1->getOperand(OpNum), Op2->getOperand(OpNum));
3555   };
3556 
3557   switch (Op1->getOpcode()) {
3558   default:
3559     break;
3560   case Instruction::Or:
3561     if (!cast<PossiblyDisjointInst>(Op1)->isDisjoint() ||
3562         !cast<PossiblyDisjointInst>(Op2)->isDisjoint())
3563       break;
3564     [[fallthrough]];
3565   case Instruction::Xor:
3566   case Instruction::Add: {
3567     Value *Other;
3568     if (match(Op2, m_c_BinOp(m_Specific(Op1->getOperand(0)), m_Value(Other))))
3569       return std::make_pair(Op1->getOperand(1), Other);
3570     if (match(Op2, m_c_BinOp(m_Specific(Op1->getOperand(1)), m_Value(Other))))
3571       return std::make_pair(Op1->getOperand(0), Other);
3572     break;
3573   }
3574   case Instruction::Sub:
3575     if (Op1->getOperand(0) == Op2->getOperand(0))
3576       return getOperands(1);
3577     if (Op1->getOperand(1) == Op2->getOperand(1))
3578       return getOperands(0);
3579     break;
3580   case Instruction::Mul: {
3581     // invertible if A * B == (A * B) mod 2^N where A and B are integers
3582     // and N is the bitwidth.  The nsw case is non-obvious, but proven by
3583     // alive2: https://alive2.llvm.org/ce/z/Z6D5qK
3584     auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
3585     auto *OBO2 = cast<OverflowingBinaryOperator>(Op2);
3586     if ((!OBO1->hasNoUnsignedWrap() || !OBO2->hasNoUnsignedWrap()) &&
3587         (!OBO1->hasNoSignedWrap() || !OBO2->hasNoSignedWrap()))
3588       break;
3589 
3590     // Assume operand order has been canonicalized
3591     if (Op1->getOperand(1) == Op2->getOperand(1) &&
3592         isa<ConstantInt>(Op1->getOperand(1)) &&
3593         !cast<ConstantInt>(Op1->getOperand(1))->isZero())
3594       return getOperands(0);
3595     break;
3596   }
3597   case Instruction::Shl: {
3598     // Same as multiplies, with the difference that we don't need to check
3599     // for a non-zero multiply. Shifts always multiply by non-zero.
3600     auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
3601     auto *OBO2 = cast<OverflowingBinaryOperator>(Op2);
3602     if ((!OBO1->hasNoUnsignedWrap() || !OBO2->hasNoUnsignedWrap()) &&
3603         (!OBO1->hasNoSignedWrap() || !OBO2->hasNoSignedWrap()))
3604       break;
3605 
3606     if (Op1->getOperand(1) == Op2->getOperand(1))
3607       return getOperands(0);
3608     break;
3609   }
3610   case Instruction::AShr:
3611   case Instruction::LShr: {
3612     auto *PEO1 = cast<PossiblyExactOperator>(Op1);
3613     auto *PEO2 = cast<PossiblyExactOperator>(Op2);
3614     if (!PEO1->isExact() || !PEO2->isExact())
3615       break;
3616 
3617     if (Op1->getOperand(1) == Op2->getOperand(1))
3618       return getOperands(0);
3619     break;
3620   }
3621   case Instruction::SExt:
3622   case Instruction::ZExt:
3623     if (Op1->getOperand(0)->getType() == Op2->getOperand(0)->getType())
3624       return getOperands(0);
3625     break;
3626   case Instruction::PHI: {
3627     const PHINode *PN1 = cast<PHINode>(Op1);
3628     const PHINode *PN2 = cast<PHINode>(Op2);
3629 
3630     // If PN1 and PN2 are both recurrences, can we prove the entire recurrences
3631     // are a single invertible function of the start values? Note that repeated
3632     // application of an invertible function is also invertible
3633     BinaryOperator *BO1 = nullptr;
3634     Value *Start1 = nullptr, *Step1 = nullptr;
3635     BinaryOperator *BO2 = nullptr;
3636     Value *Start2 = nullptr, *Step2 = nullptr;
3637     if (PN1->getParent() != PN2->getParent() ||
3638         !matchSimpleRecurrence(PN1, BO1, Start1, Step1) ||
3639         !matchSimpleRecurrence(PN2, BO2, Start2, Step2))
3640       break;
3641 
3642     auto Values = getInvertibleOperands(cast<Operator>(BO1),
3643                                         cast<Operator>(BO2));
3644     if (!Values)
3645        break;
3646 
3647     // We have to be careful of mutually defined recurrences here.  Ex:
3648     // * X_i = X_(i-1) OP Y_(i-1), and Y_i = X_(i-1) OP V
3649     // * X_i = Y_i = X_(i-1) OP Y_(i-1)
3650     // The invertibility of these is complicated, and not worth reasoning
3651     // about (yet?).
3652     if (Values->first != PN1 || Values->second != PN2)
3653       break;
3654 
3655     return std::make_pair(Start1, Start2);
3656   }
3657   }
3658   return std::nullopt;
3659 }
3660 
3661 /// Return true if V1 == (binop V2, X), where X is known non-zero.
3662 /// Only handle a small subset of binops where (binop V2, X) with non-zero X
3663 /// implies V2 != V1.
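     /// E.g. (hypothetical): %v1 = add i32 %v2, %x with a known non-zero %x
     /// implies %v1 != %v2, since adding a non-zero value mod 2^N has no fixed
     /// point.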
3664 static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
3665                                       const APInt &DemandedElts,
3666                                       const SimplifyQuery &Q, unsigned Depth) {
3667   const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1);
3668   if (!BO)
3669     return false;
3670   switch (BO->getOpcode()) {
3671   default:
3672     break;
3673   case Instruction::Or:
3674     if (!cast<PossiblyDisjointInst>(V1)->isDisjoint())
3675       break;
3676     [[fallthrough]];
3677   case Instruction::Xor:
3678   case Instruction::Add:
3679     Value *Op = nullptr;
3680     if (V2 == BO->getOperand(0))
3681       Op = BO->getOperand(1);
3682     else if (V2 == BO->getOperand(1))
3683       Op = BO->getOperand(0);
3684     else
3685       return false;
3686     return isKnownNonZero(Op, DemandedElts, Q, Depth + 1);
3687   }
3688   return false;
3689 }
3690 
3691 /// Return true if V2 == V1 * C, where V1 is known non-zero, C is not 0/1 and
3692 /// the multiplication is nuw or nsw.
3693 static bool isNonEqualMul(const Value *V1, const Value *V2,
3694                           const APInt &DemandedElts, const SimplifyQuery &Q,
3695                           unsigned Depth) {
3696   if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
3697     const APInt *C;
3698     return match(OBO, m_Mul(m_Specific(V1), m_APInt(C))) &&
3699            (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
3700            !C->isZero() && !C->isOne() &&
3701            isKnownNonZero(V1, DemandedElts, Q, Depth + 1);
3702   }
3703   return false;
3704 }
3705 
3706 /// Return true if V2 == V1 << C, where V1 is known non-zero, C is not 0 and
3707 /// the shift is nuw or nsw.
3708 static bool isNonEqualShl(const Value *V1, const Value *V2,
3709                           const APInt &DemandedElts, const SimplifyQuery &Q,
3710                           unsigned Depth) {
3711   if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
3712     const APInt *C;
3713     return match(OBO, m_Shl(m_Specific(V1), m_APInt(C))) &&
3714            (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
3715            !C->isZero() && isKnownNonZero(V1, DemandedElts, Q, Depth + 1);
3716   }
3717   return false;
3718 }
3719 
3720 static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2,
3721                            const APInt &DemandedElts, const SimplifyQuery &Q,
3722                            unsigned Depth) {
3723   // Check two PHIs are in same block.
3724   if (PN1->getParent() != PN2->getParent())
3725     return false;
3726 
3727   SmallPtrSet<const BasicBlock *, 8> VisitedBBs;
3728   bool UsedFullRecursion = false;
3729   for (const BasicBlock *IncomBB : PN1->blocks()) {
3730     if (!VisitedBBs.insert(IncomBB).second)
3731       continue; // Don't reprocess blocks that we have dealt with already.
3732     const Value *IV1 = PN1->getIncomingValueForBlock(IncomBB);
3733     const Value *IV2 = PN2->getIncomingValueForBlock(IncomBB);
3734     const APInt *C1, *C2;
3735     if (match(IV1, m_APInt(C1)) && match(IV2, m_APInt(C2)) && *C1 != *C2)
3736       continue;
3737 
3738     // Only one pair of phi operands is allowed for full recursion.
3739     if (UsedFullRecursion)
3740       return false;
3741 
3742     SimplifyQuery RecQ = Q.getWithoutCondContext();
3743     RecQ.CxtI = IncomBB->getTerminator();
3744     if (!isKnownNonEqual(IV1, IV2, DemandedElts, RecQ, Depth + 1))
3745       return false;
3746     UsedFullRecursion = true;
3747   }
3748   return true;
3749 }
3750 
3751 static bool isNonEqualSelect(const Value *V1, const Value *V2,
3752                              const APInt &DemandedElts, const SimplifyQuery &Q,
3753                              unsigned Depth) {
3754   const SelectInst *SI1 = dyn_cast<SelectInst>(V1);
3755   if (!SI1)
3756     return false;
3757 
3758   if (const SelectInst *SI2 = dyn_cast<SelectInst>(V2)) {
3759     const Value *Cond1 = SI1->getCondition();
3760     const Value *Cond2 = SI2->getCondition();
3761     if (Cond1 == Cond2)
3762       return isKnownNonEqual(SI1->getTrueValue(), SI2->getTrueValue(),
3763                              DemandedElts, Q, Depth + 1) &&
3764              isKnownNonEqual(SI1->getFalseValue(), SI2->getFalseValue(),
3765                              DemandedElts, Q, Depth + 1);
3766   }
3767   return isKnownNonEqual(SI1->getTrueValue(), V2, DemandedElts, Q, Depth + 1) &&
3768          isKnownNonEqual(SI1->getFalseValue(), V2, DemandedElts, Q, Depth + 1);
3769 }
3770 
3771 // Check whether A is a GEP that is also the incoming value for a PHI in the
3772 // loop, and B is either a ptr or another GEP. If the PHI has 2 incoming
3773 // values, one of them the recursive GEP A and the other a pointer with the
3774 // same base as B at the same or a higher offset, then each iteration only
3775 // moves the pointer further from B while the recursive GEP's offset is
     // positive (and symmetrically for a negative step).
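     // A hypothetical instance:
     //   %p = phi ptr [ %start, %entry ], [ %next, %loop ]
     //   %next = getelementptr inbounds i8, ptr %p, i64 4
     // If %start shares B's base at an equal or higher offset, %next only moves
     // away from B and can never equal it.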
3776 static bool isNonEqualPointersWithRecursiveGEP(const Value *A, const Value *B,
3777                                                const SimplifyQuery &Q) {
3778   if (!A->getType()->isPointerTy() || !B->getType()->isPointerTy())
3779     return false;
3780 
3781   auto *GEPA = dyn_cast<GEPOperator>(A);
3782   if (!GEPA || GEPA->getNumIndices() != 1 || !isa<Constant>(GEPA->idx_begin()))
3783     return false;
3784 
3785   // Handle 2 incoming PHI values with one being a recursive GEP.
3786   auto *PN = dyn_cast<PHINode>(GEPA->getPointerOperand());
3787   if (!PN || PN->getNumIncomingValues() != 2)
3788     return false;
3789 
3790   // Search for the recursive GEP as an incoming operand, and record that as
3791   // Step.
3792   Value *Start = nullptr;
3793   Value *Step = const_cast<Value *>(A);
3794   if (PN->getIncomingValue(0) == Step)
3795     Start = PN->getIncomingValue(1);
3796   else if (PN->getIncomingValue(1) == Step)
3797     Start = PN->getIncomingValue(0);
3798   else
3799     return false;
3800 
3801   // The other incoming value's base should match B's base. A and B are
3802   // non-equal if either of the following holds:
3803   //   StartOffset >= OffsetB && StepOffset > 0, or
3804   //   StartOffset <= OffsetB && StepOffset < 0.
3805   // We use stripAndAccumulateInBoundsConstantOffsets to restrict the
3806   // optimization to inbounds GEPs only.
3807   unsigned IndexWidth = Q.DL.getIndexTypeSizeInBits(Start->getType());
3808   APInt StartOffset(IndexWidth, 0);
3809   Start = Start->stripAndAccumulateInBoundsConstantOffsets(Q.DL, StartOffset);
3810   APInt StepOffset(IndexWidth, 0);
3811   Step = Step->stripAndAccumulateInBoundsConstantOffsets(Q.DL, StepOffset);
3812 
3813   // Check if Base Pointer of Step matches the PHI.
3814   if (Step != PN)
3815     return false;
3816   APInt OffsetB(IndexWidth, 0);
3817   B = B->stripAndAccumulateInBoundsConstantOffsets(Q.DL, OffsetB);
3818   return Start == B &&
3819          ((StartOffset.sge(OffsetB) && StepOffset.isStrictlyPositive()) ||
3820           (StartOffset.sle(OffsetB) && StepOffset.isNegative()));
3821 }
3822 
3823 static bool isKnownNonEqualFromContext(const Value *V1, const Value *V2,
3824                                        const SimplifyQuery &Q, unsigned Depth) {
3825   if (!Q.CxtI)
3826     return false;
3827 
3828   // Try to infer NonEqual based on information from dominating conditions.
3829   if (Q.DC && Q.DT) {
3830     auto IsKnownNonEqualFromDominatingCondition = [&](const Value *V) {
3831       for (BranchInst *BI : Q.DC->conditionsFor(V)) {
3832         Value *Cond = BI->getCondition();
3833         BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
3834         if (Q.DT->dominates(Edge0, Q.CxtI->getParent()) &&
3835             isImpliedCondition(Cond, ICmpInst::ICMP_NE, V1, V2, Q.DL,
3836                                /*LHSIsTrue=*/true, Depth)
3837                 .value_or(false))
3838           return true;
3839 
3840         BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
3841         if (Q.DT->dominates(Edge1, Q.CxtI->getParent()) &&
3842             isImpliedCondition(Cond, ICmpInst::ICMP_NE, V1, V2, Q.DL,
3843                                /*LHSIsTrue=*/false, Depth)
3844                 .value_or(false))
3845           return true;
3846       }
3847 
3848       return false;
3849     };
3850 
3851     if (IsKnownNonEqualFromDominatingCondition(V1) ||
3852         IsKnownNonEqualFromDominatingCondition(V2))
3853       return true;
3854   }
3855 
3856   if (!Q.AC)
3857     return false;
3858 
3859   // Try to infer NonEqual based on information from assumptions.
3860   for (auto &AssumeVH : Q.AC->assumptionsFor(V1)) {
3861     if (!AssumeVH)
3862       continue;
3863     CallInst *I = cast<CallInst>(AssumeVH);
3864 
3865     assert(I->getFunction() == Q.CxtI->getFunction() &&
3866            "Got assumption for the wrong function!");
3867     assert(I->getIntrinsicID() == Intrinsic::assume &&
3868            "must be an assume intrinsic");
3869 
3870     if (isImpliedCondition(I->getArgOperand(0), ICmpInst::ICMP_NE, V1, V2, Q.DL,
3871                            /*LHSIsTrue=*/true, Depth)
3872             .value_or(false) &&
3873         isValidAssumeForContext(I, Q.CxtI, Q.DT))
3874       return true;
3875   }
3876 
3877   return false;
3878 }
3879 
3880 /// Return true if it is known that V1 != V2.
3881 static bool isKnownNonEqual(const Value *V1, const Value *V2,
3882                             const APInt &DemandedElts, const SimplifyQuery &Q,
3883                             unsigned Depth) {
3884   if (V1 == V2)
3885     return false;
3886   if (V1->getType() != V2->getType())
3887     // We can't look through casts yet.
3888     return false;
3889 
3890   if (Depth >= MaxAnalysisRecursionDepth)
3891     return false;
3892 
3893   // See if we can recurse through (exactly one of) our operands.  This
3894   // requires our operation be 1-to-1 and map every input value to exactly
3895   // one output value.  Such an operation is invertible.
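       // For example, X + Z and Y + Z are equal exactly when X and Y are equal,
       // so comparing the sums reduces to comparing X against Y.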
3896   auto *O1 = dyn_cast<Operator>(V1);
3897   auto *O2 = dyn_cast<Operator>(V2);
3898   if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
3899     if (auto Values = getInvertibleOperands(O1, O2))
3900       return isKnownNonEqual(Values->first, Values->second, DemandedElts, Q,
3901                              Depth + 1);
3902 
3903     if (const PHINode *PN1 = dyn_cast<PHINode>(V1)) {
3904       const PHINode *PN2 = cast<PHINode>(V2);
3905       // FIXME: This is missing a generalization to handle the case where one
3906       // is a PHI and the other isn't.
3907       if (isNonEqualPHIs(PN1, PN2, DemandedElts, Q, Depth))
3908         return true;
3909     }
3910   }
3911 
3912   if (isModifyingBinopOfNonZero(V1, V2, DemandedElts, Q, Depth) ||
3913       isModifyingBinopOfNonZero(V2, V1, DemandedElts, Q, Depth))
3914     return true;
3915 
3916   if (isNonEqualMul(V1, V2, DemandedElts, Q, Depth) ||
3917       isNonEqualMul(V2, V1, DemandedElts, Q, Depth))
3918     return true;
3919 
3920   if (isNonEqualShl(V1, V2, DemandedElts, Q, Depth) ||
3921       isNonEqualShl(V2, V1, DemandedElts, Q, Depth))
3922     return true;
3923 
3924   if (V1->getType()->isIntOrIntVectorTy()) {
3925     // Are any known bits in V1 contradictory to known bits in V2? If V1
3926     // has a known zero where V2 has a known one, they must not be equal.
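         // (For example, a value with its low bit known zero can never equal a
         // value with its low bit known one.)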
3927     KnownBits Known1 = computeKnownBits(V1, DemandedElts, Q, Depth);
3928     if (!Known1.isUnknown()) {
3929       KnownBits Known2 = computeKnownBits(V2, DemandedElts, Q, Depth);
3930       if (Known1.Zero.intersects(Known2.One) ||
3931           Known2.Zero.intersects(Known1.One))
3932         return true;
3933     }
3934   }
3935 
3936   if (isNonEqualSelect(V1, V2, DemandedElts, Q, Depth) ||
3937       isNonEqualSelect(V2, V1, DemandedElts, Q, Depth))
3938     return true;
3939 
3940   if (isNonEqualPointersWithRecursiveGEP(V1, V2, Q) ||
3941       isNonEqualPointersWithRecursiveGEP(V2, V1, Q))
3942     return true;
3943 
3944   Value *A, *B;
3945   // PtrToInts are NonEqual if their Ptrs are NonEqual, provided the casts
3946   // don't truncate: check that the PtrToInt type matches the pointer size.
3947   if (match(V1, m_PtrToIntSameSize(Q.DL, m_Value(A))) &&
3948       match(V2, m_PtrToIntSameSize(Q.DL, m_Value(B))))
3949     return isKnownNonEqual(A, B, DemandedElts, Q, Depth + 1);
3950 
3951   if (isKnownNonEqualFromContext(V1, V2, Q, Depth))
3952     return true;
3953 
3954   return false;
3955 }
3956 
3957 /// For vector constants, loop over the elements and find the constant with the
3958 /// minimum number of sign bits. Return 0 if the value is not a vector constant
3959 /// or if any element was not analyzed; otherwise, return the count for the
3960 /// element with the minimum number of sign bits.
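     /// For example, <2 x i8> <i8 1, i8 -1> has elements with 7 and 8 sign
     /// bits respectively, so we return 7.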
3961 static unsigned computeNumSignBitsVectorConstant(const Value *V,
3962                                                  const APInt &DemandedElts,
3963                                                  unsigned TyBits) {
3964   const auto *CV = dyn_cast<Constant>(V);
3965   if (!CV || !isa<FixedVectorType>(CV->getType()))
3966     return 0;
3967 
3968   unsigned MinSignBits = TyBits;
3969   unsigned NumElts = cast<FixedVectorType>(CV->getType())->getNumElements();
3970   for (unsigned i = 0; i != NumElts; ++i) {
3971     if (!DemandedElts[i])
3972       continue;
3973     // If we find a non-ConstantInt, bail out.
3974     auto *Elt = dyn_cast_or_null<ConstantInt>(CV->getAggregateElement(i));
3975     if (!Elt)
3976       return 0;
3977 
3978     MinSignBits = std::min(MinSignBits, Elt->getValue().getNumSignBits());
3979   }
3980 
3981   return MinSignBits;
3982 }
3983 
3984 static unsigned ComputeNumSignBitsImpl(const Value *V,
3985                                        const APInt &DemandedElts,
3986                                        const SimplifyQuery &Q, unsigned Depth);
3987 
3988 static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
3989                                    const SimplifyQuery &Q, unsigned Depth) {
3990   unsigned Result = ComputeNumSignBitsImpl(V, DemandedElts, Q, Depth);
3991   assert(Result > 0 && "At least one sign bit needs to be present!");
3992   return Result;
3993 }
3994 
3995 /// Return the number of times the sign bit of the register is replicated into
3996 /// the other bits. We know that at least 1 bit is always equal to the sign bit
3997 /// (itself), but other cases can give us information. For example, immediately
3998 /// after an "ashr X, 2", we know that the top 3 bits are all equal to each
3999 /// other, so we return 3. For vectors, return the number of sign bits for the
4000 /// vector element with the minimum number of known sign bits of the demanded
4001 /// elements in the vector specified by DemandedElts.
4002 static unsigned ComputeNumSignBitsImpl(const Value *V,
4003                                        const APInt &DemandedElts,
4004                                        const SimplifyQuery &Q, unsigned Depth) {
4005   Type *Ty = V->getType();
4006 #ifndef NDEBUG
4007   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
4008 
4009   if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
4010     assert(
4011         FVTy->getNumElements() == DemandedElts.getBitWidth() &&
4012         "DemandedElt width should equal the fixed vector number of elements");
4013   } else {
4014     assert(DemandedElts == APInt(1, 1) &&
4015            "DemandedElt width should be 1 for scalars");
4016   }
4017 #endif
4018 
4019   // We return the minimum number of sign bits that are guaranteed to be present
4020   // in V, so for undef we have to conservatively return 1.  We don't have the
4021   // same behavior for poison though -- that's a FIXME today.
4022 
4023   Type *ScalarTy = Ty->getScalarType();
4024   unsigned TyBits = ScalarTy->isPointerTy() ?
4025     Q.DL.getPointerTypeSizeInBits(ScalarTy) :
4026     Q.DL.getTypeSizeInBits(ScalarTy);
4027 
4028   unsigned Tmp, Tmp2;
4029   unsigned FirstAnswer = 1;
4030 
4031   // Note that ConstantInt is handled by the general computeKnownBits case
4032   // below.
4033 
4034   if (Depth == MaxAnalysisRecursionDepth)
4035     return 1;
4036 
4037   if (auto *U = dyn_cast<Operator>(V)) {
4038     switch (Operator::getOpcode(V)) {
4039     default: break;
4040     case Instruction::BitCast: {
4041       Value *Src = U->getOperand(0);
4042       Type *SrcTy = Src->getType();
4043 
4044       // Skip if the source type is not an integer or integer vector type;
4045       // we can only reason about integer-like sources here.
4046       if (!SrcTy->isIntOrIntVectorTy())
4047         break;
4048 
4049       unsigned SrcBits = SrcTy->getScalarSizeInBits();
4050 
4051       // Bitcast 'large element' scalar/vector to 'small element' vector.
4052       if ((SrcBits % TyBits) != 0)
4053         break;
4054 
4055       // Only proceed if the destination type is a fixed-size vector
4056       if (isa<FixedVectorType>(Ty)) {
4057         // Fast case: the sign splat can simply be split across the small
4058         // elements. This works for both vector and scalar sources.
4059         Tmp = ComputeNumSignBits(Src, Q, Depth + 1);
4060         if (Tmp == SrcBits)
4061           return TyBits;
4062       }
4063       break;
4064     }
4065     case Instruction::SExt:
4066       Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits();
4067       return ComputeNumSignBits(U->getOperand(0), DemandedElts, Q, Depth + 1) +
4068              Tmp;
4069 
4070     case Instruction::SDiv: {
4071       const APInt *Denominator;
4072       // sdiv X, C -> adds log(C) sign bits.
4073       if (match(U->getOperand(1), m_APInt(Denominator))) {
4074 
4075         // Ignore non-positive denominator.
4076         if (!Denominator->isStrictlyPositive())
4077           break;
4078 
4079         // Calculate the incoming numerator bits.
4080         unsigned NumBits =
4081             ComputeNumSignBits(U->getOperand(0), DemandedElts, Q, Depth + 1);
4082 
4083         // Add floor(log(C)) bits to the numerator bits.
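             // For example, dividing a numerator with 5 known sign bits by 16
             // guarantees at least min(TyBits, 5 + 4) sign bits in the result.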
4084         return std::min(TyBits, NumBits + Denominator->logBase2());
4085       }
4086       break;
4087     }
4088 
4089     case Instruction::SRem: {
4090       Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Q, Depth + 1);
4091 
4092       const APInt *Denominator;
4093       // srem X, C -> we know that the result is within [-C+1,C) when C is a
4094       // positive constant.  This lets us put a lower bound on the number of
4095       // sign bits.
4096       if (match(U->getOperand(1), m_APInt(Denominator))) {
4097 
4098         // Ignore non-positive denominator.
4099         if (Denominator->isStrictlyPositive()) {
4100           // Calculate the leading sign bit constraints by examining the
4101           // denominator.  Given that the denominator is positive, there are two
4102           // cases:
4103           //
4104           //  1. The numerator is positive. The result range is [0,C) and
4105           //     [0,C) u< (1 << ceilLogBase2(C)).
4106           //
4107           //  2. The numerator is negative. Then the result range is (-C,0] and
4108           //     integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)).
4109           //
4110           // Thus a lower bound on the number of sign bits is `TyBits -
4111           // ceilLogBase2(C)`.
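               // For example, `srem i8 %x, 4` has a result in (-4, 4), so at
               // least 8 - ceilLogBase2(4) = 6 sign bits are known.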
4112 
4113           unsigned ResBits = TyBits - Denominator->ceilLogBase2();
4114           Tmp = std::max(Tmp, ResBits);
4115         }
4116       }
4117       return Tmp;
4118     }
4119 
4120     case Instruction::AShr: {
4121       Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Q, Depth + 1);
4122       // ashr X, C   -> adds C sign bits.  Vectors too.
4123       const APInt *ShAmt;
4124       if (match(U->getOperand(1), m_APInt(ShAmt))) {
4125         if (ShAmt->uge(TyBits))
4126           break; // Bad shift.
4127         unsigned ShAmtLimited = ShAmt->getZExtValue();
4128         Tmp += ShAmtLimited;
4129         if (Tmp > TyBits) Tmp = TyBits;
4130       }
4131       return Tmp;
4132     }
4133     case Instruction::Shl: {
4134       const APInt *ShAmt;
4135       Value *X = nullptr;
4136       if (match(U->getOperand(1), m_APInt(ShAmt))) {
4137         // shl destroys sign bits.
4138         if (ShAmt->uge(TyBits))
4139           break; // Bad shift.
4140         // We can look through a zext (more or less treating it as a sext) if
4141         // all extended bits are shifted out.
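             // For example, in `shl (zext i8 %x to i32), 24` all 24 extended
             // bits are shifted out, so the sign bits of %x carry over intact.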
4142         if (match(U->getOperand(0), m_ZExt(m_Value(X))) &&
4143             ShAmt->uge(TyBits - X->getType()->getScalarSizeInBits())) {
4144           Tmp = ComputeNumSignBits(X, DemandedElts, Q, Depth + 1);
4145           Tmp += TyBits - X->getType()->getScalarSizeInBits();
4146         } else
4147           Tmp =
4148               ComputeNumSignBits(U->getOperand(0), DemandedElts, Q, Depth + 1);
4149         if (ShAmt->uge(Tmp))
4150           break; // Shifted all sign bits out.
4151         Tmp2 = ShAmt->getZExtValue();
4152         return Tmp - Tmp2;
4153       }
4154       break;
4155     }
4156     case Instruction::And:
4157     case Instruction::Or:
4158     case Instruction::Xor: // NOT is handled here.
4159       // Logical binary ops preserve the number of sign bits at worst.
4160       Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Q, Depth + 1);
4161       if (Tmp != 1) {
4162         Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Q, Depth + 1);
4163         FirstAnswer = std::min(Tmp, Tmp2);
4164         // We computed what we know about the sign bits as our first
4165         // answer. Now proceed to the generic code that uses
4166         // computeKnownBits, and pick whichever answer is better.
4167       }
4168       break;
4169 
4170     case Instruction::Select: {
4171       // If we have a clamp pattern, we know that the number of sign bits will
4172       // be the minimum of the clamp min/max range.
4173       const Value *X;
4174       const APInt *CLow, *CHigh;
4175       if (isSignedMinMaxClamp(U, X, CLow, CHigh))
4176         return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits());
4177 
4178       Tmp = ComputeNumSignBits(U->getOperand(1), DemandedElts, Q, Depth + 1);
4179       if (Tmp == 1)
4180         break;
4181       Tmp2 = ComputeNumSignBits(U->getOperand(2), DemandedElts, Q, Depth + 1);
4182       return std::min(Tmp, Tmp2);
4183     }
4184 
4185     case Instruction::Add:
4186       // Add can have at most one carry bit.  Thus we know that the output
4187       // is, at worst, one more bit than the inputs.
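           // For example, two i8 operands with 3 sign bits each lie in
           // [-32, 31]; their sum lies in [-64, 62], keeping 3 - 1 = 2 sign
           // bits.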
4188       Tmp = ComputeNumSignBits(U->getOperand(0), Q, Depth + 1);
4189       if (Tmp == 1) break;
4190 
4191       // Special case decrementing a value (ADD X, -1):
4192       if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1)))
4193         if (CRHS->isAllOnesValue()) {
4194           KnownBits Known(TyBits);
4195           computeKnownBits(U->getOperand(0), DemandedElts, Known, Q, Depth + 1);
4196 
4197           // If the input is known to be 0 or 1, the output is 0/-1, which is
4198           // all sign bits set.
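               // ((Known.Zero | 1).isAllOnes() says every bit other than the
               // low bit is known zero, i.e. the value is 0 or 1.)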
4199           if ((Known.Zero | 1).isAllOnes())
4200             return TyBits;
4201 
4202           // If we are subtracting one from a positive number, there is no carry
4203           // out of the result.
4204           if (Known.isNonNegative())
4205             return Tmp;
4206         }
4207 
4208       Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Q, Depth + 1);
4209       if (Tmp2 == 1)
4210         break;
4211       return std::min(Tmp, Tmp2) - 1;
4212 
4213     case Instruction::Sub:
4214       Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Q, Depth + 1);
4215       if (Tmp2 == 1)
4216         break;
4217 
4218       // Handle NEG.
4219       if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0)))
4220         if (CLHS->isNullValue()) {
4221           KnownBits Known(TyBits);
4222           computeKnownBits(U->getOperand(1), DemandedElts, Known, Q, Depth + 1);
4223           // If the input is known to be 0 or 1, the output is 0/-1, which is
4224           // all sign bits set.
4225           if ((Known.Zero | 1).isAllOnes())
4226             return TyBits;
4227 
4228           // If the input is known to be positive (the sign bit is known clear),
4229           // the output of the NEG has the same number of sign bits as the
4230           // input.
4231           if (Known.isNonNegative())
4232             return Tmp2;
4233 
4234           // Otherwise, we treat this like a SUB.
4235         }
4236 
4237       // Sub can have at most one carry bit.  Thus we know that the output
4238       // is, at worst, one more bit than the inputs.
4239       Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Q, Depth + 1);
4240       if (Tmp == 1)
4241         break;
4242       return std::min(Tmp, Tmp2) - 1;
4243 
4244     case Instruction::Mul: {
4245       // The output of the Mul can be at most twice the valid bits in the
4246       // inputs.
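           // For example, i32 operands with 28 and 30 sign bits carry 5 and 3
           // valid bits; the product needs at most 8 valid bits, leaving
           // 32 - 8 + 1 = 25 known sign bits.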
4247       unsigned SignBitsOp0 =
4248           ComputeNumSignBits(U->getOperand(0), DemandedElts, Q, Depth + 1);
4249       if (SignBitsOp0 == 1)
4250         break;
4251       unsigned SignBitsOp1 =
4252           ComputeNumSignBits(U->getOperand(1), DemandedElts, Q, Depth + 1);
4253       if (SignBitsOp1 == 1)
4254         break;
4255       unsigned OutValidBits =
4256           (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1);
4257       return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
4258     }
4259 
4260     case Instruction::PHI: {
4261       const PHINode *PN = cast<PHINode>(U);
4262       unsigned NumIncomingValues = PN->getNumIncomingValues();
4263       // Don't analyze large in-degree PHIs.
4264       if (NumIncomingValues > 4) break;
4265       // Unreachable blocks may have zero-operand PHI nodes.
4266       if (NumIncomingValues == 0) break;
4267 
4268       // Take the minimum of all incoming values.  This can't infinitely loop
4269       // because of our depth threshold.
4270       SimplifyQuery RecQ = Q.getWithoutCondContext();
4271       Tmp = TyBits;
4272       for (unsigned i = 0, e = NumIncomingValues; i != e; ++i) {
4273         if (Tmp == 1) return Tmp;
4274         RecQ.CxtI = PN->getIncomingBlock(i)->getTerminator();
4275         Tmp = std::min(Tmp, ComputeNumSignBits(PN->getIncomingValue(i),
4276                                                DemandedElts, RecQ, Depth + 1));
4277       }
4278       return Tmp;
4279     }
4280 
4281     case Instruction::Trunc: {
4282       // If the input contained enough sign bits that some remain after the
4283       // truncation, then we can make use of that. Otherwise we don't know
4284       // anything.
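           // For example, truncating an i32 with 20 known sign bits to i16
           // leaves 20 - (32 - 16) = 4 known sign bits.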
4285       Tmp = ComputeNumSignBits(U->getOperand(0), Q, Depth + 1);
4286       unsigned OperandTyBits = U->getOperand(0)->getType()->getScalarSizeInBits();
4287       if (Tmp > (OperandTyBits - TyBits))
4288         return Tmp - (OperandTyBits - TyBits);
4289 
4290       return 1;
4291     }
4292 
4293     case Instruction::ExtractElement:
4294       // Look through extract element. At the moment we keep this simple and
4295       // skip tracking the specific element. But at least we might find
4296       // information valid for all elements of the vector (for example if the
4297       // vector is sign extended, shifted, etc.).
4298       return ComputeNumSignBits(U->getOperand(0), Q, Depth + 1);
4299 
4300     case Instruction::ShuffleVector: {
4301       // Collect the minimum number of sign bits that are shared by every vector
4302       // element referenced by the shuffle.
4303       auto *Shuf = dyn_cast<ShuffleVectorInst>(U);
4304       if (!Shuf) {
4305         // FIXME: Add support for shufflevector constant expressions.
4306         return 1;
4307       }
4308       APInt DemandedLHS, DemandedRHS;
4309       // For undef elements, we don't know anything about the common state of
4310       // the shuffle result.
4311       if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
4312         return 1;
4313       Tmp = std::numeric_limits<unsigned>::max();
4314       if (!!DemandedLHS) {
4315         const Value *LHS = Shuf->getOperand(0);
4316         Tmp = ComputeNumSignBits(LHS, DemandedLHS, Q, Depth + 1);
4317       }
4318       // If we don't know anything, early out and try computeKnownBits
4319       // fall-back.
4320       if (Tmp == 1)
4321         break;
4322       if (!!DemandedRHS) {
4323         const Value *RHS = Shuf->getOperand(1);
4324         Tmp2 = ComputeNumSignBits(RHS, DemandedRHS, Q, Depth + 1);
4325         Tmp = std::min(Tmp, Tmp2);
4326       }
4327       // If we don't know anything, early out and try computeKnownBits
4328       // fall-back.
4329       if (Tmp == 1)
4330         break;
4331       assert(Tmp <= TyBits && "Failed to determine minimum sign bits");
4332       return Tmp;
4333     }
4334     case Instruction::Call: {
4335       if (const auto *II = dyn_cast<IntrinsicInst>(U)) {
4336         switch (II->getIntrinsicID()) {
4337         default:
4338           break;
4339         case Intrinsic::abs:
4340           Tmp =
4341               ComputeNumSignBits(U->getOperand(0), DemandedElts, Q, Depth + 1);
4342           if (Tmp == 1)
4343             break;
4344 
4345           // Absolute value reduces number of sign bits by at most 1.
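               // For example, an i8 value in [-4, 3] has 6 sign bits; its
               // absolute value lies in [0, 4], which still has 5.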
4346           return Tmp - 1;
4347         case Intrinsic::smin:
4348         case Intrinsic::smax: {
4349           const APInt *CLow, *CHigh;
4350           if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
4351             return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits());
4352         }
4353         }
4354       }
4355     }
4356     }
4357   }
4358 
4359   // Finally, if we can prove that the top bits of the result are 0's or 1's,
4360   // use this information.
4361 
4362   // If we can examine all elements of a vector constant successfully, we're
4363   // done (we can't do any better than that). If not, keep trying.
4364   if (unsigned VecSignBits =
4365           computeNumSignBitsVectorConstant(V, DemandedElts, TyBits))
4366     return VecSignBits;
4367 
4368   KnownBits Known(TyBits);
4369   computeKnownBits(V, DemandedElts, Known, Q, Depth);
4370 
4371   // If we know that the sign bit is either zero or one, determine the number of
4372   // identical bits in the top of the input value.
4373   return std::max(FirstAnswer, Known.countMinSignBits());
4374 }
4375 
4376 Intrinsic::ID llvm::getIntrinsicForCallSite(const CallBase &CB,
4377                                             const TargetLibraryInfo *TLI) {
4378   const Function *F = CB.getCalledFunction();
4379   if (!F)
4380     return Intrinsic::not_intrinsic;
4381 
4382   if (F->isIntrinsic())
4383     return F->getIntrinsicID();
4384 
4385   // We are going to infer semantics of a library function based on mapping it
4386   // to an LLVM intrinsic. Check that the library function is available from
4387   // this call site and in this environment.
4388   LibFunc Func;
4389   if (F->hasLocalLinkage() || !TLI || !TLI->getLibFunc(CB, Func) ||
4390       !CB.onlyReadsMemory())
4391     return Intrinsic::not_intrinsic;
4392 
4393   switch (Func) {
4394   default:
4395     break;
4396   case LibFunc_sin:
4397   case LibFunc_sinf:
4398   case LibFunc_sinl:
4399     return Intrinsic::sin;
4400   case LibFunc_cos:
4401   case LibFunc_cosf:
4402   case LibFunc_cosl:
4403     return Intrinsic::cos;
4404   case LibFunc_tan:
4405   case LibFunc_tanf:
4406   case LibFunc_tanl:
4407     return Intrinsic::tan;
4408   case LibFunc_asin:
4409   case LibFunc_asinf:
4410   case LibFunc_asinl:
4411     return Intrinsic::asin;
4412   case LibFunc_acos:
4413   case LibFunc_acosf:
4414   case LibFunc_acosl:
4415     return Intrinsic::acos;
4416   case LibFunc_atan:
4417   case LibFunc_atanf:
4418   case LibFunc_atanl:
4419     return Intrinsic::atan;
4420   case LibFunc_atan2:
4421   case LibFunc_atan2f:
4422   case LibFunc_atan2l:
4423     return Intrinsic::atan2;
4424   case LibFunc_sinh:
4425   case LibFunc_sinhf:
4426   case LibFunc_sinhl:
4427     return Intrinsic::sinh;
4428   case LibFunc_cosh:
4429   case LibFunc_coshf:
4430   case LibFunc_coshl:
4431     return Intrinsic::cosh;
4432   case LibFunc_tanh:
4433   case LibFunc_tanhf:
4434   case LibFunc_tanhl:
4435     return Intrinsic::tanh;
4436   case LibFunc_exp:
4437   case LibFunc_expf:
4438   case LibFunc_expl:
4439     return Intrinsic::exp;
4440   case LibFunc_exp2:
4441   case LibFunc_exp2f:
4442   case LibFunc_exp2l:
4443     return Intrinsic::exp2;
4444   case LibFunc_exp10:
4445   case LibFunc_exp10f:
4446   case LibFunc_exp10l:
4447     return Intrinsic::exp10;
4448   case LibFunc_log:
4449   case LibFunc_logf:
4450   case LibFunc_logl:
4451     return Intrinsic::log;
4452   case LibFunc_log10:
4453   case LibFunc_log10f:
4454   case LibFunc_log10l:
4455     return Intrinsic::log10;
4456   case LibFunc_log2:
4457   case LibFunc_log2f:
4458   case LibFunc_log2l:
4459     return Intrinsic::log2;
4460   case LibFunc_fabs:
4461   case LibFunc_fabsf:
4462   case LibFunc_fabsl:
4463     return Intrinsic::fabs;
4464   case LibFunc_fmin:
4465   case LibFunc_fminf:
4466   case LibFunc_fminl:
4467     return Intrinsic::minnum;
4468   case LibFunc_fmax:
4469   case LibFunc_fmaxf:
4470   case LibFunc_fmaxl:
4471     return Intrinsic::maxnum;
4472   case LibFunc_copysign:
4473   case LibFunc_copysignf:
4474   case LibFunc_copysignl:
4475     return Intrinsic::copysign;
4476   case LibFunc_floor:
4477   case LibFunc_floorf:
4478   case LibFunc_floorl:
4479     return Intrinsic::floor;
4480   case LibFunc_ceil:
4481   case LibFunc_ceilf:
4482   case LibFunc_ceill:
4483     return Intrinsic::ceil;
4484   case LibFunc_trunc:
4485   case LibFunc_truncf:
4486   case LibFunc_truncl:
4487     return Intrinsic::trunc;
4488   case LibFunc_rint:
4489   case LibFunc_rintf:
4490   case LibFunc_rintl:
4491     return Intrinsic::rint;
4492   case LibFunc_nearbyint:
4493   case LibFunc_nearbyintf:
4494   case LibFunc_nearbyintl:
4495     return Intrinsic::nearbyint;
4496   case LibFunc_round:
4497   case LibFunc_roundf:
4498   case LibFunc_roundl:
4499     return Intrinsic::round;
4500   case LibFunc_roundeven:
4501   case LibFunc_roundevenf:
4502   case LibFunc_roundevenl:
4503     return Intrinsic::roundeven;
4504   case LibFunc_pow:
4505   case LibFunc_powf:
4506   case LibFunc_powl:
4507     return Intrinsic::pow;
4508   case LibFunc_sqrt:
4509   case LibFunc_sqrtf:
4510   case LibFunc_sqrtl:
4511     return Intrinsic::sqrt;
4512   }
4513 
4514   return Intrinsic::not_intrinsic;
4515 }
4516 
4517 static bool outputDenormalIsIEEEOrPosZero(const Function &F, const Type *Ty) {
4518   Ty = Ty->getScalarType();
4519   DenormalMode Mode = F.getDenormalMode(Ty->getFltSemantics());
4520   return Mode.Output == DenormalMode::IEEE ||
4521          Mode.Output == DenormalMode::PositiveZero;
4522 }
4523 /// Given an exploded icmp instruction, return true if the comparison only
4524 /// checks the sign bit. In that case, set TrueIfSigned to true if the
4525 /// comparison is true exactly when the sign bit of the input is set.
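     /// For example, for i8 the comparison (icmp ugt %x, 127) only checks the
     /// sign bit: it is true exactly when the sign bit of %x is set, so
     /// TrueIfSigned is set to true.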
4526 bool llvm::isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
4527                           bool &TrueIfSigned) {
4528   switch (Pred) {
4529   case ICmpInst::ICMP_SLT: // True if LHS s< 0
4530     TrueIfSigned = true;
4531     return RHS.isZero();
4532   case ICmpInst::ICMP_SLE: // True if LHS s<= -1
4533     TrueIfSigned = true;
4534     return RHS.isAllOnes();
4535   case ICmpInst::ICMP_SGT: // True if LHS s> -1
4536     TrueIfSigned = false;
4537     return RHS.isAllOnes();
4538   case ICmpInst::ICMP_SGE: // True if LHS s>= 0
4539     TrueIfSigned = false;
4540     return RHS.isZero();
4541   case ICmpInst::ICMP_UGT:
4542     // True if LHS u> RHS and RHS == sign-bit-mask - 1
4543     TrueIfSigned = true;
4544     return RHS.isMaxSignedValue();
4545   case ICmpInst::ICMP_UGE:
4546     // True if LHS u>= RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
4547     TrueIfSigned = true;
4548     return RHS.isMinSignedValue();
4549   case ICmpInst::ICMP_ULT:
4550     // True if LHS u< RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
4551     TrueIfSigned = false;
4552     return RHS.isMinSignedValue();
4553   case ICmpInst::ICMP_ULE:
4554     // True if LHS u<= RHS and RHS == sign-bit-mask - 1
4555     TrueIfSigned = false;
4556     return RHS.isMaxSignedValue();
4557   default:
4558     return false;
4559   }
4560 }
4561 
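     /// Refine \p KnownFromContext for \p V, given that the condition \p Cond
     /// is known to evaluate to \p CondIsTrue at \p CxtI. Looks through logical
     /// and/or/not conditions up to MaxAnalysisRecursionDepth, and recognizes
     /// fcmp against a constant, llvm.is.fpclass tests, and integer sign-bit
     /// checks on the bitcast value.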
4562 static void computeKnownFPClassFromCond(const Value *V, Value *Cond,
4563                                         bool CondIsTrue,
4564                                         const Instruction *CxtI,
4565                                         KnownFPClass &KnownFromContext,
4566                                         unsigned Depth = 0) {
4567   Value *A, *B;
4568   if (Depth < MaxAnalysisRecursionDepth &&
4569       (CondIsTrue ? match(Cond, m_LogicalAnd(m_Value(A), m_Value(B)))
4570                   : match(Cond, m_LogicalOr(m_Value(A), m_Value(B))))) {
4571     computeKnownFPClassFromCond(V, A, CondIsTrue, CxtI, KnownFromContext,
4572                                 Depth + 1);
4573     computeKnownFPClassFromCond(V, B, CondIsTrue, CxtI, KnownFromContext,
4574                                 Depth + 1);
4575     return;
4576   }
4577   if (Depth < MaxAnalysisRecursionDepth && match(Cond, m_Not(m_Value(A)))) {
4578     computeKnownFPClassFromCond(V, A, !CondIsTrue, CxtI, KnownFromContext,
4579                                 Depth + 1);
4580     return;
4581   }
4582   CmpPredicate Pred;
4583   Value *LHS;
4584   uint64_t ClassVal = 0;
4585   const APFloat *CRHS;
4586   const APInt *RHS;
4587   if (match(Cond, m_FCmp(Pred, m_Value(LHS), m_APFloat(CRHS)))) {
4588     auto [CmpVal, MaskIfTrue, MaskIfFalse] = fcmpImpliesClass(
4589         Pred, *CxtI->getParent()->getParent(), LHS, *CRHS, LHS != V);
4590     if (CmpVal == V)
4591       KnownFromContext.knownNot(~(CondIsTrue ? MaskIfTrue : MaskIfFalse));
4592   } else if (match(Cond, m_Intrinsic<Intrinsic::is_fpclass>(
4593                              m_Specific(V), m_ConstantInt(ClassVal)))) {
4594     FPClassTest Mask = static_cast<FPClassTest>(ClassVal);
4595     KnownFromContext.knownNot(CondIsTrue ? ~Mask : Mask);
4596   } else if (match(Cond, m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(V)),
4597                                 m_APInt(RHS)))) {
4598     bool TrueIfSigned;
4599     if (!isSignBitCheck(Pred, *RHS, TrueIfSigned))
4600       return;
4601     if (TrueIfSigned == CondIsTrue)
4602       KnownFromContext.signBitMustBeOne();
4603     else
4604       KnownFromContext.signBitMustBeZero();
4605   }
4606 }
4607 
4608 static KnownFPClass computeKnownFPClassFromContext(const Value *V,
4609                                                    const SimplifyQuery &Q) {
4610   KnownFPClass KnownFromContext;
4611 
4612   if (Q.CC && Q.CC->AffectedValues.contains(V))
4613     computeKnownFPClassFromCond(V, Q.CC->Cond, !Q.CC->Invert, Q.CxtI,
4614                                 KnownFromContext);
4615 
4616   if (!Q.CxtI)
4617     return KnownFromContext;
4618 
4619   if (Q.DC && Q.DT) {
4620     // Handle dominating conditions.
4621     for (BranchInst *BI : Q.DC->conditionsFor(V)) {
4622       Value *Cond = BI->getCondition();
4623 
4624       BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
4625       if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
4626         computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/true, Q.CxtI,
4627                                     KnownFromContext);
4628 
4629       BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
4630       if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
4631         computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/false, Q.CxtI,
4632                                     KnownFromContext);
4633     }
4634   }
4635 
4636   if (!Q.AC)
4637     return KnownFromContext;
4638 
4639   // Try to restrict the floating-point classes based on information from
4640   // assumptions.
4641   for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
4642     if (!AssumeVH)
4643       continue;
4644     CallInst *I = cast<CallInst>(AssumeVH);
4645 
4646     assert(I->getFunction() == Q.CxtI->getParent()->getParent() &&
4647            "Got assumption for the wrong function!");
4648     assert(I->getIntrinsicID() == Intrinsic::assume &&
4649            "must be an assume intrinsic");
4650 
4651     if (!isValidAssumeForContext(I, Q.CxtI, Q.DT))
4652       continue;
4653 
4654     computeKnownFPClassFromCond(V, I->getArgOperand(0),
4655                                 /*CondIsTrue=*/true, Q.CxtI, KnownFromContext);
4656   }
4657 
4658   return KnownFromContext;
4659 }
4660 
4661 void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
4662                          FPClassTest InterestedClasses, KnownFPClass &Known,
4663                          const SimplifyQuery &Q, unsigned Depth);
4664 
4665 static void computeKnownFPClass(const Value *V, KnownFPClass &Known,
4666                                 FPClassTest InterestedClasses,
4667                                 const SimplifyQuery &Q, unsigned Depth) {
4668   auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
4669   APInt DemandedElts =
4670       FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
4671   computeKnownFPClass(V, DemandedElts, InterestedClasses, Known, Q, Depth);
4672 }
4673 
4674 static void computeKnownFPClassForFPTrunc(const Operator *Op,
4675                                           const APInt &DemandedElts,
4676                                           FPClassTest InterestedClasses,
4677                                           KnownFPClass &Known,
4678                                           const SimplifyQuery &Q,
4679                                           unsigned Depth) {
4680   if ((InterestedClasses &
4681        (KnownFPClass::OrderedLessThanZeroMask | fcNan)) == fcNone)
4682     return;
4683 
4684   KnownFPClass KnownSrc;
4685   computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedClasses,
4686                       KnownSrc, Q, Depth + 1);
4687 
4688   // Sign should be preserved
4689   // TODO: Handle the cannot-be-ordered-greater-than-zero case.
4690   if (KnownSrc.cannotBeOrderedLessThanZero())
4691     Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
4692 
4693   Known.propagateNaN(KnownSrc, true);
4694 
4695   // Infinity needs a range check.
4696 }
4697 
4698 void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
4699                          FPClassTest InterestedClasses, KnownFPClass &Known,
4700                          const SimplifyQuery &Q, unsigned Depth) {
4701   assert(Known.isUnknown() && "should not be called with known information");
4702 
4703   if (!DemandedElts) {
4704     // No demanded elts, better to assume we don't know anything.
4705     Known.resetAll();
4706     return;
4707   }
4708 
4709   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
4710 
4711   if (auto *CFP = dyn_cast<ConstantFP>(V)) {
4712     Known.KnownFPClasses = CFP->getValueAPF().classify();
4713     Known.SignBit = CFP->isNegative();
4714     return;
4715   }
4716 
4717   if (isa<ConstantAggregateZero>(V)) {
4718     Known.KnownFPClasses = fcPosZero;
4719     Known.SignBit = false;
4720     return;
4721   }
4722 
4723   if (isa<PoisonValue>(V)) {
4724     Known.KnownFPClasses = fcNone;
4725     Known.SignBit = false;
4726     return;
4727   }
4728 
4729   // Try to handle fixed width vector constants
4730   auto *VFVTy = dyn_cast<FixedVectorType>(V->getType());
4731   const Constant *CV = dyn_cast<Constant>(V);
4732   if (VFVTy && CV) {
4733     Known.KnownFPClasses = fcNone;
4734     bool SignBitAllZero = true;
4735     bool SignBitAllOne = true;
4736 
4737     // For vectors, merge the known FP classes of all demanded elements.
4738     unsigned NumElts = VFVTy->getNumElements();
4739     for (unsigned i = 0; i != NumElts; ++i) {
4740       if (!DemandedElts[i])
4741         continue;
4742 
4743       Constant *Elt = CV->getAggregateElement(i);
4744       if (!Elt) {
4745         Known = KnownFPClass();
4746         return;
4747       }
4748       if (isa<PoisonValue>(Elt))
4749         continue;
4750       auto *CElt = dyn_cast<ConstantFP>(Elt);
4751       if (!CElt) {
4752         Known = KnownFPClass();
4753         return;
4754       }
4755 
4756       const APFloat &C = CElt->getValueAPF();
4757       Known.KnownFPClasses |= C.classify();
4758       if (C.isNegative())
4759         SignBitAllZero = false;
4760       else
4761         SignBitAllOne = false;
4762     }
4763     if (SignBitAllOne != SignBitAllZero)
4764       Known.SignBit = SignBitAllOne;
4765     return;
4766   }
4767 
4768   FPClassTest KnownNotFromFlags = fcNone;
4769   if (const auto *CB = dyn_cast<CallBase>(V))
4770     KnownNotFromFlags |= CB->getRetNoFPClass();
4771   else if (const auto *Arg = dyn_cast<Argument>(V))
4772     KnownNotFromFlags |= Arg->getNoFPClass();
4773 
4774   const Operator *Op = dyn_cast<Operator>(V);
4775   if (const FPMathOperator *FPOp = dyn_cast_or_null<FPMathOperator>(Op)) {
4776     if (FPOp->hasNoNaNs())
4777       KnownNotFromFlags |= fcNan;
4778     if (FPOp->hasNoInfs())
4779       KnownNotFromFlags |= fcInf;
4780   }
4781 
4782   KnownFPClass AssumedClasses = computeKnownFPClassFromContext(V, Q);
4783   KnownNotFromFlags |= ~AssumedClasses.KnownFPClasses;
4784 
4785   // We no longer need to find out about these bits from inputs if we can
4786   // assume this from flags/attributes.
4787   InterestedClasses &= ~KnownNotFromFlags;
4788 
4789   auto ClearClassesFromFlags = make_scope_exit([=, &Known] {
4790     Known.knownNot(KnownNotFromFlags);
4791     if (!Known.SignBit && AssumedClasses.SignBit) {
4792       if (*AssumedClasses.SignBit)
4793         Known.signBitMustBeOne();
4794       else
4795         Known.signBitMustBeZero();
4796     }
4797   });
4798 
4799   if (!Op)
4800     return;
4801 
4802   // All recursive calls that increase depth must come after this.
4803   if (Depth == MaxAnalysisRecursionDepth)
4804     return;
4805 
4806   const unsigned Opc = Op->getOpcode();
4807   switch (Opc) {
4808   case Instruction::FNeg: {
4809     computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedClasses,
4810                         Known, Q, Depth + 1);
4811     Known.fneg();
4812     break;
4813   }
4814   case Instruction::Select: {
4815     Value *Cond = Op->getOperand(0);
4816     Value *LHS = Op->getOperand(1);
4817     Value *RHS = Op->getOperand(2);
4818 
4819     FPClassTest FilterLHS = fcAllFlags;
4820     FPClassTest FilterRHS = fcAllFlags;
4821 
4822     Value *TestedValue = nullptr;
4823     FPClassTest MaskIfTrue = fcAllFlags;
4824     FPClassTest MaskIfFalse = fcAllFlags;
4825     uint64_t ClassVal = 0;
4826     const Function *F = cast<Instruction>(Op)->getFunction();
4827     CmpPredicate Pred;
4828     Value *CmpLHS, *CmpRHS;
4829     if (F && match(Cond, m_FCmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
4830       // If the select filters out a value based on the class, it no longer
4831       // participates in the class of the result
4832 
4833       // TODO: In some degenerate cases we can infer something if we try again
4834       // without looking through sign operations.
4835       bool LookThroughFAbsFNeg = CmpLHS != LHS && CmpLHS != RHS;
4836       std::tie(TestedValue, MaskIfTrue, MaskIfFalse) =
4837           fcmpImpliesClass(Pred, *F, CmpLHS, CmpRHS, LookThroughFAbsFNeg);
4838     } else if (match(Cond,
4839                      m_Intrinsic<Intrinsic::is_fpclass>(
4840                          m_Value(TestedValue), m_ConstantInt(ClassVal)))) {
4841       FPClassTest TestedMask = static_cast<FPClassTest>(ClassVal);
4842       MaskIfTrue = TestedMask;
4843       MaskIfFalse = ~TestedMask;
4844     }
4845 
4846     if (TestedValue == LHS) {
4847       // match !isnan(x) ? x : y
4848       FilterLHS = MaskIfTrue;
4849     } else if (TestedValue == RHS) { // && IsExactClass
4850       // match !isnan(x) ? y : x
4851       FilterRHS = MaskIfFalse;
4852     }
4853 
4854     KnownFPClass Known2;
4855     computeKnownFPClass(LHS, DemandedElts, InterestedClasses & FilterLHS, Known,
4856                         Q, Depth + 1);
4857     Known.KnownFPClasses &= FilterLHS;
4858 
4859     computeKnownFPClass(RHS, DemandedElts, InterestedClasses & FilterRHS,
4860                         Known2, Q, Depth + 1);
4861     Known2.KnownFPClasses &= FilterRHS;
4862 
4863     Known |= Known2;
4864     break;
4865   }
4866   case Instruction::Call: {
4867     const CallInst *II = cast<CallInst>(Op);
4868     const Intrinsic::ID IID = II->getIntrinsicID();
4869     switch (IID) {
4870     case Intrinsic::fabs: {
4871       if ((InterestedClasses & (fcNan | fcPositive)) != fcNone) {
4872         // If we only care about the sign bit we don't need to inspect the
4873         // operand.
4874         computeKnownFPClass(II->getArgOperand(0), DemandedElts,
4875                             InterestedClasses, Known, Q, Depth + 1);
4876       }
4877 
4878       Known.fabs();
4879       break;
4880     }
4881     case Intrinsic::copysign: {
4882       KnownFPClass KnownSign;
4883 
4884       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
4885                           Known, Q, Depth + 1);
4886       computeKnownFPClass(II->getArgOperand(1), DemandedElts, InterestedClasses,
4887                           KnownSign, Q, Depth + 1);
4888       Known.copysign(KnownSign);
4889       break;
4890     }
4891     case Intrinsic::fma:
4892     case Intrinsic::fmuladd: {
4893       if ((InterestedClasses & fcNegative) == fcNone)
4894         break;
4895 
4896       if (II->getArgOperand(0) != II->getArgOperand(1))
4897         break;
4898 
4899       // The multiply cannot be -0 and therefore the add can't be -0
4900       Known.knownNot(fcNegZero);
4901 
4902       // x * x + y is non-negative if y is non-negative.
4903       KnownFPClass KnownAddend;
4904       computeKnownFPClass(II->getArgOperand(2), DemandedElts, InterestedClasses,
4905                           KnownAddend, Q, Depth + 1);
4906 
4907       if (KnownAddend.cannotBeOrderedLessThanZero())
4908         Known.knownNot(fcNegative);
4909       break;
4910     }
4911     case Intrinsic::sqrt:
4912     case Intrinsic::experimental_constrained_sqrt: {
4913       KnownFPClass KnownSrc;
4914       FPClassTest InterestedSrcs = InterestedClasses;
4915       if (InterestedClasses & fcNan)
4916         InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask;
4917 
4918       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs,
4919                           KnownSrc, Q, Depth + 1);
4920 
4921       if (KnownSrc.isKnownNeverPosInfinity())
4922         Known.knownNot(fcPosInf);
4923       if (KnownSrc.isKnownNever(fcSNan))
4924         Known.knownNot(fcSNan);
4925 
4926       // Any negative value besides -0 returns a nan.
4927       if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
4928         Known.knownNot(fcNan);
4929 
4930       // The only negative value that can be returned is -0 for -0 inputs.
4931       Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal);
4932 
4933       // If the input denormal mode could be PreserveSign, a negative
4934       // subnormal input could produce a negative zero output.
4935       const Function *F = II->getFunction();
4936       const fltSemantics &FltSem =
4937           II->getType()->getScalarType()->getFltSemantics();
4938 
4939       if (Q.IIQ.hasNoSignedZeros(II) ||
4940           (F &&
4941            KnownSrc.isKnownNeverLogicalNegZero(F->getDenormalMode(FltSem))))
4942         Known.knownNot(fcNegZero);
4943 
4944       break;
4945     }
4946     case Intrinsic::sin:
4947     case Intrinsic::cos: {
4948       // Return NaN on infinite inputs.
4949       KnownFPClass KnownSrc;
4950       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
4951                           KnownSrc, Q, Depth + 1);
4952       Known.knownNot(fcInf);
4953       if (KnownSrc.isKnownNeverNaN() && KnownSrc.isKnownNeverInfinity())
4954         Known.knownNot(fcNan);
4955       break;
4956     }
4957     case Intrinsic::maxnum:
4958     case Intrinsic::minnum:
4959     case Intrinsic::minimum:
4960     case Intrinsic::maximum:
4961     case Intrinsic::minimumnum:
4962     case Intrinsic::maximumnum: {
4963       KnownFPClass KnownLHS, KnownRHS;
4964       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
4965                           KnownLHS, Q, Depth + 1);
4966       computeKnownFPClass(II->getArgOperand(1), DemandedElts, InterestedClasses,
4967                           KnownRHS, Q, Depth + 1);
4968 
4969       bool NeverNaN = KnownLHS.isKnownNeverNaN() || KnownRHS.isKnownNeverNaN();
4970       Known = KnownLHS | KnownRHS;
4971 
4972       // If either operand is not NaN, the result is not NaN.
4973       if (NeverNaN &&
4974           (IID == Intrinsic::minnum || IID == Intrinsic::maxnum ||
4975            IID == Intrinsic::minimumnum || IID == Intrinsic::maximumnum))
4976         Known.knownNot(fcNan);
4977 
4978       if (IID == Intrinsic::maxnum || IID == Intrinsic::maximumnum) {
4979         // If at least one operand is known to be positive, the result must be
4980         // positive.
4981         if ((KnownLHS.cannotBeOrderedLessThanZero() &&
4982              KnownLHS.isKnownNeverNaN()) ||
4983             (KnownRHS.cannotBeOrderedLessThanZero() &&
4984              KnownRHS.isKnownNeverNaN()))
4985           Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
4986       } else if (IID == Intrinsic::maximum) {
4987         // If at least one operand is known to be positive, the result must be
4988         // positive.
4989         if (KnownLHS.cannotBeOrderedLessThanZero() ||
4990             KnownRHS.cannotBeOrderedLessThanZero())
4991           Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
4992       } else if (IID == Intrinsic::minnum || IID == Intrinsic::minimumnum) {
4993         // If at least one operand is known to be negative, the result must be
4994         // negative.
4995         if ((KnownLHS.cannotBeOrderedGreaterThanZero() &&
4996              KnownLHS.isKnownNeverNaN()) ||
4997             (KnownRHS.cannotBeOrderedGreaterThanZero() &&
4998              KnownRHS.isKnownNeverNaN()))
4999           Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask);
5000       } else if (IID == Intrinsic::minimum) {
5001         // If at least one operand is known to be negative, the result must be
5002         // negative.
5003         if (KnownLHS.cannotBeOrderedGreaterThanZero() ||
5004             KnownRHS.cannotBeOrderedGreaterThanZero())
5005           Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask);
5006       } else
5007         llvm_unreachable("unhandled intrinsic");
5008 
5009       // Fixup zero handling if denormals could be returned as a zero.
5010       //
5011       // As there's no spec for denormal flushing, be conservative with the
5012       // treatment of denormals that could be flushed to zero. For older
5013       // subtargets on AMDGPU the min/max instructions would not flush the
5014       // output and return the original value.
5015       //
5016       if ((Known.KnownFPClasses & fcZero) != fcNone &&
5017           !Known.isKnownNeverSubnormal()) {
5018         const Function *Parent = II->getFunction();
5019         if (!Parent)
5020           break;
5021 
5022         DenormalMode Mode = Parent->getDenormalMode(
5023             II->getType()->getScalarType()->getFltSemantics());
5024         if (Mode != DenormalMode::getIEEE())
5025           Known.KnownFPClasses |= fcZero;
5026       }
5027 
5028       if (Known.isKnownNeverNaN()) {
5029         if (KnownLHS.SignBit && KnownRHS.SignBit &&
5030             *KnownLHS.SignBit == *KnownRHS.SignBit) {
5031           if (*KnownLHS.SignBit)
5032             Known.signBitMustBeOne();
5033           else
5034             Known.signBitMustBeZero();
5035         } else if ((IID == Intrinsic::maximum || IID == Intrinsic::minimum ||
5036                     IID == Intrinsic::maximumnum ||
5037                     IID == Intrinsic::minimumnum) ||
5038                    // FIXME: Should be using logical zero versions
5039                    ((KnownLHS.isKnownNeverNegZero() ||
5040                      KnownRHS.isKnownNeverPosZero()) &&
5041                     (KnownLHS.isKnownNeverPosZero() ||
5042                      KnownRHS.isKnownNeverNegZero()))) {
5043           if ((IID == Intrinsic::maximum || IID == Intrinsic::maximumnum ||
5044                IID == Intrinsic::maxnum) &&
5045               (KnownLHS.SignBit == false || KnownRHS.SignBit == false))
5046             Known.signBitMustBeZero();
5047           else if ((IID == Intrinsic::minimum || IID == Intrinsic::minimumnum ||
5048                     IID == Intrinsic::minnum) &&
5049                    (KnownLHS.SignBit == true || KnownRHS.SignBit == true))
5050             Known.signBitMustBeOne();
5051         }
5052       }
5053       break;
5054     }
5055     case Intrinsic::canonicalize: {
5056       KnownFPClass KnownSrc;
5057       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
5058                           KnownSrc, Q, Depth + 1);
5059 
5060       // This is essentially a stronger form of
5061       // propagateCanonicalizingSrc. Other "canonicalizing" operations don't
5062       // actually have an IR canonicalization guarantee.
5063 
5064       // Canonicalize may flush denormals to zero, so we have to consider the
5065       // denormal mode to preserve known-not-0 knowledge.
5066       Known.KnownFPClasses = KnownSrc.KnownFPClasses | fcZero | fcQNan;
5067 
5068       // Stronger version of propagateNaN
5069       // Canonicalize is guaranteed to quiet signaling nans.
5070       if (KnownSrc.isKnownNeverNaN())
5071         Known.knownNot(fcNan);
5072       else
5073         Known.knownNot(fcSNan);
5074 
5075       const Function *F = II->getFunction();
5076       if (!F)
5077         break;
5078 
5079       // If the parent function flushes denormals, the canonical output cannot
5080       // be a denormal.
5081       const fltSemantics &FPType =
5082           II->getType()->getScalarType()->getFltSemantics();
5083       DenormalMode DenormMode = F->getDenormalMode(FPType);
5084       if (DenormMode == DenormalMode::getIEEE()) {
5085         if (KnownSrc.isKnownNever(fcPosZero))
5086           Known.knownNot(fcPosZero);
5087         if (KnownSrc.isKnownNever(fcNegZero))
5088           Known.knownNot(fcNegZero);
5089         break;
5090       }
5091 
5092       if (DenormMode.inputsAreZero() || DenormMode.outputsAreZero())
5093         Known.knownNot(fcSubnormal);
5094 
5095       if (DenormMode.Input == DenormalMode::PositiveZero ||
5096           (DenormMode.Output == DenormalMode::PositiveZero &&
5097            DenormMode.Input == DenormalMode::IEEE))
5098         Known.knownNot(fcNegZero);
5099 
5100       break;
5101     }
5102     case Intrinsic::vector_reduce_fmax:
5103     case Intrinsic::vector_reduce_fmin:
5104     case Intrinsic::vector_reduce_fmaximum:
5105     case Intrinsic::vector_reduce_fminimum: {
5106       // reduce min/max will choose an element from one of the vector elements,
5107       // so we can infer any class information that is common to all elements.
5108       Known = computeKnownFPClass(II->getArgOperand(0), II->getFastMathFlags(),
5109                                   InterestedClasses, Q, Depth + 1);
5110       // Can only propagate sign if output is never NaN.
5111       if (!Known.isKnownNeverNaN())
5112         Known.SignBit.reset();
5113       break;
5114     }
5115       // reverse preserves all characteristics of the input vector's elements.
5116     case Intrinsic::vector_reverse:
5117       Known = computeKnownFPClass(
5118           II->getArgOperand(0), DemandedElts.reverseBits(),
5119           II->getFastMathFlags(), InterestedClasses, Q, Depth + 1);
5120       break;
5121     case Intrinsic::trunc:
5122     case Intrinsic::floor:
5123     case Intrinsic::ceil:
5124     case Intrinsic::rint:
5125     case Intrinsic::nearbyint:
5126     case Intrinsic::round:
5127     case Intrinsic::roundeven: {
5128       KnownFPClass KnownSrc;
5129       FPClassTest InterestedSrcs = InterestedClasses;
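           // If any class from fcPosFinite (fcNegFinite) is interesting, ask
           // about all of them in the source, since rounding moves values
           // between same-signed finite classes (e.g. trunc maps 0.3 to +0.0).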
5130       if (InterestedSrcs & fcPosFinite)
5131         InterestedSrcs |= fcPosFinite;
5132       if (InterestedSrcs & fcNegFinite)
5133         InterestedSrcs |= fcNegFinite;
5134       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs,
5135                           KnownSrc, Q, Depth + 1);
5136 
5137       // Integer results cannot be subnormal.
5138       Known.knownNot(fcSubnormal);
5139 
5140       Known.propagateNaN(KnownSrc, true);
5141 
5142       // Pass through infinities, except PPC_FP128 is a special case for
5143       // intrinsics other than trunc.
5144       if (IID == Intrinsic::trunc || !V->getType()->isMultiUnitFPType()) {
5145         if (KnownSrc.isKnownNeverPosInfinity())
5146           Known.knownNot(fcPosInf);
5147         if (KnownSrc.isKnownNeverNegInfinity())
5148           Known.knownNot(fcNegInf);
5149       }
5150 
5151       // A negative value that rounds up to zero produces -0.
5152       if (KnownSrc.isKnownNever(fcPosFinite))
5153         Known.knownNot(fcPosFinite);
5154       if (KnownSrc.isKnownNever(fcNegFinite))
5155         Known.knownNot(fcNegFinite);
5156 
5157       break;
5158     }
5159     case Intrinsic::exp:
5160     case Intrinsic::exp2:
5161     case Intrinsic::exp10: {
5162       Known.knownNot(fcNegative);
5163       if ((InterestedClasses & fcNan) == fcNone)
5164         break;
5165 
5166       KnownFPClass KnownSrc;
5167       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
5168                           KnownSrc, Q, Depth + 1);
5169       if (KnownSrc.isKnownNeverNaN()) {
5170         Known.knownNot(fcNan);
5171         Known.signBitMustBeZero();
5172       }
5173 
5174       break;
5175     }
5176     case Intrinsic::fptrunc_round: {
5177       computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, Known,
5178                                     Q, Depth);
5179       break;
5180     }
5181     case Intrinsic::log:
5182     case Intrinsic::log10:
5183     case Intrinsic::log2:
5184     case Intrinsic::experimental_constrained_log:
5185     case Intrinsic::experimental_constrained_log10:
5186     case Intrinsic::experimental_constrained_log2: {
5187       // log(+inf) -> +inf
5188       // log([+-]0.0) -> -inf
5189       // log(-inf) -> nan
5190       // log(-x) -> nan
5191       if ((InterestedClasses & (fcNan | fcInf)) == fcNone)
5192         break;
5193 
5194       FPClassTest InterestedSrcs = InterestedClasses;
5195       if ((InterestedClasses & fcNegInf) != fcNone)
5196         InterestedSrcs |= fcZero | fcSubnormal;
5197       if ((InterestedClasses & fcNan) != fcNone)
5198         InterestedSrcs |= fcNan | (fcNegative & ~fcNan);
5199 
5200       KnownFPClass KnownSrc;
5201       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs,
5202                           KnownSrc, Q, Depth + 1);
5203 
5204       if (KnownSrc.isKnownNeverPosInfinity())
5205         Known.knownNot(fcPosInf);
5206 
5207       if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
5208         Known.knownNot(fcNan);
5209 
5210       const Function *F = II->getFunction();
5211 
5212       if (!F)
5213         break;
5214 
5215       const fltSemantics &FltSem =
5216           II->getType()->getScalarType()->getFltSemantics();
5217       DenormalMode Mode = F->getDenormalMode(FltSem);
5218 
5219       if (KnownSrc.isKnownNeverLogicalZero(Mode))
5220         Known.knownNot(fcNegInf);
5221 
5222       break;
5223     }
5224     case Intrinsic::powi: {
5225       if ((InterestedClasses & fcNegative) == fcNone)
5226         break;
5227 
5228       const Value *Exp = II->getArgOperand(1);
5229       Type *ExpTy = Exp->getType();
5230       unsigned BitWidth = ExpTy->getScalarType()->getIntegerBitWidth();
5231       KnownBits ExponentKnownBits(BitWidth);
5232       computeKnownBits(Exp, isa<VectorType>(ExpTy) ? DemandedElts : APInt(1, 1),
5233                        ExponentKnownBits, Q, Depth + 1);
5234 
5235       if (ExponentKnownBits.Zero[0]) { // Is even
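             // powi(x, 2k) is nan or non-negative regardless of the sign of x.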
5236         Known.knownNot(fcNegative);
5237         break;
5238       }
5239 
5240       // Given that exp is an integer, here are the
5241       // ways that pow can return a negative value:
5242       //
5243       //   pow(-x, exp)   --> negative if exp is odd and x is negative.
5244       //   pow(-0, exp)   --> -inf if exp is negative odd.
5245       //   pow(-0, exp)   --> -0 if exp is positive odd.
5246       //   pow(-inf, exp) --> -0 if exp is negative odd.
5247       //   pow(-inf, exp) --> -inf if exp is positive odd.
5248       KnownFPClass KnownSrc;
5249       computeKnownFPClass(II->getArgOperand(0), DemandedElts, fcNegative,
5250                           KnownSrc, Q, Depth + 1);
5251       if (KnownSrc.isKnownNever(fcNegative))
5252         Known.knownNot(fcNegative);
5253       break;
5254     }
5255     case Intrinsic::ldexp: {
5256       KnownFPClass KnownSrc;
5257       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
5258                           KnownSrc, Q, Depth + 1);
5259       Known.propagateNaN(KnownSrc, /*PropagateSign=*/true);
5260 
5261       // Sign is preserved, but underflows may produce zeroes.
5262       if (KnownSrc.isKnownNever(fcNegative))
5263         Known.knownNot(fcNegative);
5264       else if (KnownSrc.cannotBeOrderedLessThanZero())
5265         Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
5266 
5267       if (KnownSrc.isKnownNever(fcPositive))
5268         Known.knownNot(fcPositive);
5269       else if (KnownSrc.cannotBeOrderedGreaterThanZero())
5270         Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask);
5271 
5272       // Can refine inf/zero handling based on the exponent operand.
5273       const FPClassTest ExpInfoMask = fcZero | fcSubnormal | fcInf;
5274       if ((InterestedClasses & ExpInfoMask) == fcNone)
5275         break;
5276       if ((KnownSrc.KnownFPClasses & ExpInfoMask) == fcNone)
5277         break;
5278 
5279       const fltSemantics &Flt =
5280           II->getType()->getScalarType()->getFltSemantics();
5281       unsigned Precision = APFloat::semanticsPrecision(Flt);
5282       const Value *ExpArg = II->getArgOperand(1);
5283       ConstantRange ExpRange = computeConstantRange(
5284           ExpArg, true, Q.IIQ.UseInstrInfo, Q.AC, Q.CxtI, Q.DT, Depth + 1);
5285 
5286       const int MantissaBits = Precision - 1;
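           // Scaling by at least MantissaBits binades lifts even the smallest
           // subnormal into the normal range; e.g. for double (precision 53),
           // ldexp(0x1p-1074, 52) == 0x1p-1022, the smallest normal.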
5287       if (ExpRange.getSignedMin().sge(static_cast<int64_t>(MantissaBits)))
5288         Known.knownNot(fcSubnormal);
5289 
5290       const Function *F = II->getFunction();
5291       const APInt *ConstVal = ExpRange.getSingleElement();
5292       const fltSemantics &FltSem =
5293           II->getType()->getScalarType()->getFltSemantics();
5294       if (F && ConstVal && ConstVal->isZero()) {
5295         // ldexp(x, 0) -> x, so propagate everything.
5296         Known.propagateCanonicalizingSrc(KnownSrc, F->getDenormalMode(FltSem));
5297       } else if (ExpRange.isAllNegative()) {
5298         // If we know the power is <= 0, can't introduce inf
5299         if (KnownSrc.isKnownNeverPosInfinity())
5300           Known.knownNot(fcPosInf);
5301         if (KnownSrc.isKnownNeverNegInfinity())
5302           Known.knownNot(fcNegInf);
5303       } else if (ExpRange.isAllNonNegative()) {
5304         // If we know the power is >= 0, can't introduce subnormal or zero
5305         if (KnownSrc.isKnownNeverPosSubnormal())
5306           Known.knownNot(fcPosSubnormal);
5307         if (KnownSrc.isKnownNeverNegSubnormal())
5308           Known.knownNot(fcNegSubnormal);
5309         if (F &&
5310             KnownSrc.isKnownNeverLogicalPosZero(F->getDenormalMode(FltSem)))
5311           Known.knownNot(fcPosZero);
5312         if (F &&
5313             KnownSrc.isKnownNeverLogicalNegZero(F->getDenormalMode(FltSem)))
5314           Known.knownNot(fcNegZero);
5315       }
5316 
5317       break;
5318     }
5319     case Intrinsic::arithmetic_fence: {
5320       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedClasses,
5321                           Known, Q, Depth + 1);
5322       break;
5323     }
5324     case Intrinsic::experimental_constrained_sitofp:
5325     case Intrinsic::experimental_constrained_uitofp:
5326       // Cannot produce nan
5327       Known.knownNot(fcNan);
5328 
5329       // sitofp and uitofp turn into +0.0 for zero.
5330       Known.knownNot(fcNegZero);
5331 
5332       // Integers cannot be subnormal
5333       Known.knownNot(fcSubnormal);
5334 
5335       if (IID == Intrinsic::experimental_constrained_uitofp)
5336         Known.signBitMustBeZero();
5337 
5338       // TODO: Copy inf handling from instructions
5339       break;
5340     default:
5341       break;
5342     }
5343 
5344     break;
5345   }
5346   case Instruction::FAdd:
5347   case Instruction::FSub: {
5348     KnownFPClass KnownLHS, KnownRHS;
5349     bool WantNegative =
5350         Op->getOpcode() == Instruction::FAdd &&
5351         (InterestedClasses & KnownFPClass::OrderedLessThanZeroMask) != fcNone;
5352     bool WantNaN = (InterestedClasses & fcNan) != fcNone;
5353     bool WantNegZero = (InterestedClasses & fcNegZero) != fcNone;
5354 
5355     if (!WantNaN && !WantNegative && !WantNegZero)
5356       break;
5357 
5358     FPClassTest InterestedSrcs = InterestedClasses;
5359     if (WantNegative)
5360       InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask;
5361     if (InterestedClasses & fcNan)
5362       InterestedSrcs |= fcInf;
5363     computeKnownFPClass(Op->getOperand(1), DemandedElts, InterestedSrcs,
5364                         KnownRHS, Q, Depth + 1);
5365 
5366     if ((WantNaN && KnownRHS.isKnownNeverNaN()) ||
5367         (WantNegative && KnownRHS.cannotBeOrderedLessThanZero()) ||
5368         WantNegZero || Opc == Instruction::FSub) {
5369 
5370       // RHS is canonically cheaper to compute. Skip inspecting the LHS if
5371       // there's no point.
5372       computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedSrcs,
5373                           KnownLHS, Q, Depth + 1);
5374       // Adding positive and negative infinity produces NaN.
5375       // TODO: Check sign of infinities.
5376       if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() &&
5377           (KnownLHS.isKnownNeverInfinity() || KnownRHS.isKnownNeverInfinity()))
5378         Known.knownNot(fcNan);
5379 
5380       // FIXME: Context function should always be passed in separately
5381       const Function *F = cast<Instruction>(Op)->getFunction();
5382 
5383       if (Op->getOpcode() == Instruction::FAdd) {
5384         if (KnownLHS.cannotBeOrderedLessThanZero() &&
5385             KnownRHS.cannotBeOrderedLessThanZero())
5386           Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
5387         if (!F)
5388           break;
5389 
5390         const fltSemantics &FltSem =
5391             Op->getType()->getScalarType()->getFltSemantics();
5392         DenormalMode Mode = F->getDenormalMode(FltSem);
5393 
5394         // (fadd x, 0.0) is guaranteed to return +0.0, not -0.0.
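             // Under the default rounding mode, (-0.0) + (+0.0) == +0.0, so a
             // -0 sum requires both addends to be -0, or a negative subnormal
             // result that flushes to -0 (checked below).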
5395         if ((KnownLHS.isKnownNeverLogicalNegZero(Mode) ||
5396              KnownRHS.isKnownNeverLogicalNegZero(Mode)) &&
5397             // Make sure output negative denormal can't flush to -0
5398             outputDenormalIsIEEEOrPosZero(*F, Op->getType()))
5399           Known.knownNot(fcNegZero);
5400       } else {
5401         if (!F)
5402           break;
5403 
5404         const fltSemantics &FltSem =
5405             Op->getType()->getScalarType()->getFltSemantics();
5406         DenormalMode Mode = F->getDenormalMode(FltSem);
5407 
5408         // Only fsub -0, +0 can return -0
5409         if ((KnownLHS.isKnownNeverLogicalNegZero(Mode) ||
5410              KnownRHS.isKnownNeverLogicalPosZero(Mode)) &&
5411             // Make sure output negative denormal can't flush to -0
5412             outputDenormalIsIEEEOrPosZero(*F, Op->getType()))
5413           Known.knownNot(fcNegZero);
5414       }
5415     }
5416 
5417     break;
5418   }
5419   case Instruction::FMul: {
5420     // X * X is always non-negative or a NaN.
5421     if (Op->getOperand(0) == Op->getOperand(1))
5422       Known.knownNot(fcNegative);
5423 
5424     if ((InterestedClasses & fcNan) != fcNan)
5425       break;
5426 
5427     // fcSubnormal is only needed in case of DAZ.
5428     const FPClassTest NeedForNan = fcNan | fcInf | fcZero | fcSubnormal;
5429 
5430     KnownFPClass KnownLHS, KnownRHS;
5431     computeKnownFPClass(Op->getOperand(1), DemandedElts, NeedForNan, KnownRHS,
5432                         Q, Depth + 1);
5433     if (!KnownRHS.isKnownNeverNaN())
5434       break;
5435 
5436     computeKnownFPClass(Op->getOperand(0), DemandedElts, NeedForNan, KnownLHS,
5437                         Q, Depth + 1);
5438     if (!KnownLHS.isKnownNeverNaN())
5439       break;
5440 
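         // The sign of a non-NaN product is the XOR of the operands' signs.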
5441     if (KnownLHS.SignBit && KnownRHS.SignBit) {
5442       if (*KnownLHS.SignBit == *KnownRHS.SignBit)
5443         Known.signBitMustBeZero();
5444       else
5445         Known.signBitMustBeOne();
5446     }
5447 
5448     // 0 * +/-inf produces NaN; with no infinities, fmul cannot create NaN.
5449     if (KnownLHS.isKnownNeverInfinity() && KnownRHS.isKnownNeverInfinity()) {
5450       Known.knownNot(fcNan);
5451       break;
5452     }
5453 
5454     const Function *F = cast<Instruction>(Op)->getFunction();
5455     if (!F)
5456       break;
5457 
5458     Type *OpTy = Op->getType()->getScalarType();
5459     const fltSemantics &FltSem = OpTy->getFltSemantics();
5460     DenormalMode Mode = F->getDenormalMode(FltSem);
5461 
5462     if ((KnownRHS.isKnownNeverInfinity() ||
5463          KnownLHS.isKnownNeverLogicalZero(Mode)) &&
5464         (KnownLHS.isKnownNeverInfinity() ||
5465          KnownRHS.isKnownNeverLogicalZero(Mode)))
5466       Known.knownNot(fcNan);
5467 
5468     break;
5469   }
5470   case Instruction::FDiv:
5471   case Instruction::FRem: {
5472     if (Op->getOperand(0) == Op->getOperand(1)) {
5473       // TODO: Could filter out snan if we inspect the operand
5474       if (Op->getOpcode() == Instruction::FDiv) {
5475         // X / X is always exactly 1.0 or a NaN.
5476         Known.KnownFPClasses = fcNan | fcPosNormal;
5477       } else {
5478         // X % X is always exactly [+-]0.0 or a NaN.
5479         Known.KnownFPClasses = fcNan | fcZero;
5480       }
5481 
5482       break;
5483     }
5484 
5485     const bool WantNan = (InterestedClasses & fcNan) != fcNone;
5486     const bool WantNegative = (InterestedClasses & fcNegative) != fcNone;
5487     const bool WantPositive =
5488         Opc == Instruction::FRem && (InterestedClasses & fcPositive) != fcNone;
5489     if (!WantNan && !WantNegative && !WantPositive)
5490       break;
5491 
5492     KnownFPClass KnownLHS, KnownRHS;
5493 
5494     computeKnownFPClass(Op->getOperand(1), DemandedElts,
5495                         fcNan | fcInf | fcZero | fcNegative, KnownRHS, Q,
5496                         Depth + 1);
5497 
5498     bool KnowSomethingUseful =
5499         KnownRHS.isKnownNeverNaN() || KnownRHS.isKnownNever(fcNegative);
5500 
5501     if (KnowSomethingUseful || WantPositive) {
5502       const FPClassTest InterestedLHS =
5503           WantPositive ? fcAllFlags
5504                        : fcNan | fcInf | fcZero | fcSubnormal | fcNegative;
5505 
5506       computeKnownFPClass(Op->getOperand(0), DemandedElts,
5507                           InterestedClasses & InterestedLHS, KnownLHS, Q,
5508                           Depth + 1);
5509     }
5510 
5511     const Function *F = cast<Instruction>(Op)->getFunction();
5512     const fltSemantics &FltSem =
5513         Op->getType()->getScalarType()->getFltSemantics();
5514 
5515     if (Op->getOpcode() == Instruction::FDiv) {
5516       // Only 0/0, Inf/Inf produce NaN.
5517       if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() &&
5518           (KnownLHS.isKnownNeverInfinity() ||
5519            KnownRHS.isKnownNeverInfinity()) &&
5520           ((F &&
5521             KnownLHS.isKnownNeverLogicalZero(F->getDenormalMode(FltSem))) ||
5522            (F &&
5523             KnownRHS.isKnownNeverLogicalZero(F->getDenormalMode(FltSem))))) {
5524         Known.knownNot(fcNan);
5525       }
5526 
5527       // X / -0.0 is -Inf (or NaN).
5528       // If neither operand can be negative, neither can the quotient.
5529       if (KnownLHS.isKnownNever(fcNegative) && KnownRHS.isKnownNever(fcNegative))
5530         Known.knownNot(fcNegative);
5531     } else {
5532       // Inf REM x and x REM 0 produce NaN.
5533       if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() &&
5534           KnownLHS.isKnownNeverInfinity() && F &&
5535           KnownRHS.isKnownNeverLogicalZero(F->getDenormalMode(FltSem))) {
5536         Known.knownNot(fcNan);
5537       }
5538 
5539       // The sign for frem is the same as the first operand.
5540       if (KnownLHS.cannotBeOrderedLessThanZero())
5541         Known.knownNot(KnownFPClass::OrderedLessThanZeroMask);
5542       if (KnownLHS.cannotBeOrderedGreaterThanZero())
5543         Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask);
5544 
5545       // See if we can be more aggressive about the sign of 0.
5546       if (KnownLHS.isKnownNever(fcNegative))
5547         Known.knownNot(fcNegative);
5548       if (KnownLHS.isKnownNever(fcPositive))
5549         Known.knownNot(fcPositive);
5550     }
5551 
5552     break;
5553   }
5554   case Instruction::FPExt: {
5555     // Infinity, nan and zero propagate from source.
5556     computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedClasses,
5557                         Known, Q, Depth + 1);
5558 
5559     const fltSemantics &DstTy =
5560         Op->getType()->getScalarType()->getFltSemantics();
5561     const fltSemantics &SrcTy =
5562         Op->getOperand(0)->getType()->getScalarType()->getFltSemantics();
5563 
5564     // All subnormal inputs should be in the normal range in the result type.
5565     if (APFloat::isRepresentableAsNormalIn(SrcTy, DstTy)) {
5566       if (Known.KnownFPClasses & fcPosSubnormal)
5567         Known.KnownFPClasses |= fcPosNormal;
5568       if (Known.KnownFPClasses & fcNegSubnormal)
5569         Known.KnownFPClasses |= fcNegNormal;
5570       Known.knownNot(fcSubnormal);
5571     }
5572 
5573     // Sign bit of a nan isn't guaranteed.
5574     if (!Known.isKnownNeverNaN())
5575       Known.SignBit = std::nullopt;
5576     break;
5577   }
5578   case Instruction::FPTrunc: {
5579     computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, Known, Q,
5580                                   Depth);
5581     break;
5582   }
5583   case Instruction::SIToFP:
5584   case Instruction::UIToFP: {
5585     // Cannot produce nan
5586     Known.knownNot(fcNan);
5587 
5588     // Integers cannot be subnormal
5589     Known.knownNot(fcSubnormal);
5590 
5591     // sitofp and uitofp turn into +0.0 for zero.
5592     Known.knownNot(fcNegZero);
5593     if (Op->getOpcode() == Instruction::UIToFP)
5594       Known.signBitMustBeZero();
5595 
5596     if (InterestedClasses & fcInf) {
5597       // Get width of largest magnitude integer (remove a bit if signed).
5598       // This still works for a signed minimum value because the largest FP
5599       // value is scaled by some fraction close to 2.0 (1.0 + 0.xxxx).
5600       int IntSize = Op->getOperand(0)->getType()->getScalarSizeInBits();
5601       if (Op->getOpcode() == Instruction::SIToFP)
5602         --IntSize;
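           // e.g. for sitofp i32 to float, IntSize becomes 31 while
           // ilogb(FLT_MAX) == 127, so the check below proves the result is
           // always finite.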
5603 
5604       // If the exponent of the largest finite FP value can hold the largest
5605       // integer, the result of the cast must be finite.
5606       Type *FPTy = Op->getType()->getScalarType();
5607       if (ilogb(APFloat::getLargest(FPTy->getFltSemantics())) >= IntSize)
5608         Known.knownNot(fcInf);
5609     }
5610 
5611     break;
5612   }
5613   case Instruction::ExtractElement: {
5614     // Look through extract element. If the index is non-constant or
5615     // out-of-range, demand all elements; otherwise just the extracted element.
5616     const Value *Vec = Op->getOperand(0);
5617 
5618     APInt DemandedVecElts;
5619     if (auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType())) {
5620       unsigned NumElts = VecTy->getNumElements();
5621       DemandedVecElts = APInt::getAllOnes(NumElts);
5622       auto *CIdx = dyn_cast<ConstantInt>(Op->getOperand(1));
5623       if (CIdx && CIdx->getValue().ult(NumElts))
5624         DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
5625     } else {
5626       DemandedVecElts = APInt(1, 1);
5627     }
5628 
5629     return computeKnownFPClass(Vec, DemandedVecElts, InterestedClasses, Known,
5630                                Q, Depth + 1);
5631   }
5632   case Instruction::InsertElement: {
5633     if (isa<ScalableVectorType>(Op->getType()))
5634       return;
5635 
5636     const Value *Vec = Op->getOperand(0);
5637     const Value *Elt = Op->getOperand(1);
5638     auto *CIdx = dyn_cast<ConstantInt>(Op->getOperand(2));
5639     unsigned NumElts = DemandedElts.getBitWidth();
5640     APInt DemandedVecElts = DemandedElts;
5641     bool NeedsElt = true;
5642     // If we know the index we are inserting at, clear it from the Vec check.
5643     if (CIdx && CIdx->getValue().ult(NumElts)) {
5644       DemandedVecElts.clearBit(CIdx->getZExtValue());
5645       NeedsElt = DemandedElts[CIdx->getZExtValue()];
5646     }
5647 
5648     // Do we demand the inserted element?
5649     if (NeedsElt) {
5650       computeKnownFPClass(Elt, Known, InterestedClasses, Q, Depth + 1);
5651       // If we don't know any bits, early out.
5652       if (Known.isUnknown())
5653         break;
5654     } else {
5655       Known.KnownFPClasses = fcNone;
5656     }
5657 
5658     // Do we need any more elements from Vec?
5659     if (!DemandedVecElts.isZero()) {
5660       KnownFPClass Known2;
5661       computeKnownFPClass(Vec, DemandedVecElts, InterestedClasses, Known2, Q,
5662                           Depth + 1);
5663       Known |= Known2;
5664     }
5665 
5666     break;
5667   }
5668   case Instruction::ShuffleVector: {
5669     // For undef elements, we don't know anything about the common state of
5670     // the shuffle result.
5671     APInt DemandedLHS, DemandedRHS;
5672     auto *Shuf = dyn_cast<ShuffleVectorInst>(Op);
5673     if (!Shuf ||
             !getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
5674       return;
5675 
5676     if (!!DemandedLHS) {
5677       const Value *LHS = Shuf->getOperand(0);
5678       computeKnownFPClass(LHS, DemandedLHS, InterestedClasses, Known, Q,
5679                           Depth + 1);
5680 
5681       // If we don't know any bits, early out.
5682       if (Known.isUnknown())
5683         break;
5684     } else {
5685       Known.KnownFPClasses = fcNone;
5686     }
5687 
5688     if (!!DemandedRHS) {
5689       KnownFPClass Known2;
5690       const Value *RHS = Shuf->getOperand(1);
5691       computeKnownFPClass(RHS, DemandedRHS, InterestedClasses, Known2, Q,
5692                           Depth + 1);
5693       Known |= Known2;
5694     }
5695 
5696     break;
5697   }
5698   case Instruction::ExtractValue: {
5699     const ExtractValueInst *Extract = cast<ExtractValueInst>(Op);
5700     ArrayRef<unsigned> Indices = Extract->getIndices();
5701     const Value *Src = Extract->getAggregateOperand();
5702     if (isa<StructType>(Src->getType()) && Indices.size() == 1 &&
5703         Indices[0] == 0) {
5704       if (const auto *II = dyn_cast<IntrinsicInst>(Src)) {
5705         switch (II->getIntrinsicID()) {
5706         case Intrinsic::frexp: {
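               // For finite nonzero inputs, the fraction returned by frexp has
               // magnitude in [0.5, 1.0), so it can never be subnormal; the
               // remaining classes follow the source operand below.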
5707           Known.knownNot(fcSubnormal);
5708 
5709           KnownFPClass KnownSrc;
5710           computeKnownFPClass(II->getArgOperand(0), DemandedElts,
5711                               InterestedClasses, KnownSrc, Q, Depth + 1);
5712 
5713           const Function *F = cast<Instruction>(Op)->getFunction();
5714           const fltSemantics &FltSem =
5715               Op->getType()->getScalarType()->getFltSemantics();
5716 
5717           if (KnownSrc.isKnownNever(fcNegative))
5718             Known.knownNot(fcNegative);
5719           else {
5720             if (F &&
5721                 KnownSrc.isKnownNeverLogicalNegZero(F->getDenormalMode(FltSem)))
5722               Known.knownNot(fcNegZero);
5723             if (KnownSrc.isKnownNever(fcNegInf))
5724               Known.knownNot(fcNegInf);
5725           }
5726 
5727           if (KnownSrc.isKnownNever(fcPositive))
5728             Known.knownNot(fcPositive);
5729           else {
5730             if (F &&
5731                 KnownSrc.isKnownNeverLogicalPosZero(F->getDenormalMode(FltSem)))
5732               Known.knownNot(fcPosZero);
5733             if (KnownSrc.isKnownNever(fcPosInf))
5734               Known.knownNot(fcPosInf);
5735           }
5736 
5737           Known.propagateNaN(KnownSrc);
5738           return;
5739         }
5740         default:
5741           break;
5742         }
5743       }
5744     }
5745 
5746     computeKnownFPClass(Src, DemandedElts, InterestedClasses, Known, Q,
5747                         Depth + 1);
5748     break;
5749   }
5750   case Instruction::PHI: {
5751     const PHINode *P = cast<PHINode>(Op);
5752     // Unreachable blocks may have zero-operand PHI nodes.
5753     if (P->getNumIncomingValues() == 0)
5754       break;
5755 
5756     // Otherwise take the union of the known FP classes of the operands,
5757     // taking conservative care to avoid excessive recursion.
5758     const unsigned PhiRecursionLimit = MaxAnalysisRecursionDepth - 2;
5759 
5760     if (Depth < PhiRecursionLimit) {
5761       // Skip if every incoming value refers back to the PHI itself.
5762       if (isa_and_nonnull<UndefValue>(P->hasConstantValue()))
5763         break;
5764 
5765       bool First = true;
5766 
5767       for (const Use &U : P->operands()) {
5768         Value *IncValue;
5769         Instruction *CxtI;
5770         breakSelfRecursivePHI(&U, P, IncValue, CxtI);
5771         // Skip direct self references.
5772         if (IncValue == P)
5773           continue;
5774 
5775         KnownFPClass KnownSrc;
5776         // Recurse, but cap the recursion to two levels, because we don't want
5777         // to waste time spinning around in loops. We need at least depth 2 to
5778         // detect known sign bits.
5779         computeKnownFPClass(IncValue, DemandedElts, InterestedClasses, KnownSrc,
5780                             Q.getWithoutCondContext().getWithInstruction(CxtI),
5781                             PhiRecursionLimit);
5782 
5783         if (First) {
5784           Known = KnownSrc;
5785           First = false;
5786         } else {
5787           Known |= KnownSrc;
5788         }
5789 
5790         if (Known.KnownFPClasses == fcAllFlags)
5791           break;
5792       }
5793     }
5794 
5795     break;
5796   }
5797   case Instruction::BitCast: {
5798     const Value *Src;
5799     if (!match(Op, m_ElementWiseBitCast(m_Value(Src))) ||
5800         !Src->getType()->isIntOrIntVectorTy())
5801       break;
5802 
5803     const Type *Ty = Op->getType()->getScalarType();
5804     KnownBits Bits(Ty->getScalarSizeInBits());
5805     computeKnownBits(Src, DemandedElts, Bits, Q, Depth + 1);
5806 
5807     // Transfer information from the sign bit.
5808     if (Bits.isNonNegative())
5809       Known.signBitMustBeZero();
5810     else if (Bits.isNegative())
5811       Known.signBitMustBeOne();
5812 
5813     if (Ty->isIEEELikeFPTy()) {
5814       // IEEE floats are NaN when all bits of the exponent plus at least one of
5815       // the fraction bits are 1. This means:
5816       //   - If we assume unknown bits are 0 and the value is NaN, it will
5817       //     always be NaN
5818       //   - If we assume unknown bits are 1 and the value is not NaN, it can
5819       //     never be NaN
5820       // Note: These properties do not hold for the x86_fp80 format.
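           // e.g. for float: the value is NaN iff
           // (bits & 0x7f800000) == 0x7f800000 && (bits & 0x007fffff) != 0.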
5821       if (APFloat(Ty->getFltSemantics(), Bits.One).isNaN())
5822         Known.KnownFPClasses = fcNan;
5823       else if (!APFloat(Ty->getFltSemantics(), ~Bits.Zero).isNaN())
5824         Known.knownNot(fcNan);
5825 
5826       // Build KnownBits representing Inf and check if it must be equal or
5827       // unequal to this value.
5828       auto InfKB = KnownBits::makeConstant(
5829           APFloat::getInf(Ty->getFltSemantics()).bitcastToAPInt());
5830       InfKB.Zero.clearSignBit();
5831       if (const auto InfResult = KnownBits::eq(Bits, InfKB)) {
5832         assert(!InfResult.value());
5833         Known.knownNot(fcInf);
5834       } else if (Bits == InfKB) {
5835         Known.KnownFPClasses = fcInf;
5836       }
5837 
5838       // Build KnownBits representing Zero and check if it must be equal or
5839       // unequal to this value.
5840       auto ZeroKB = KnownBits::makeConstant(
5841           APFloat::getZero(Ty->getFltSemantics()).bitcastToAPInt());
5842       ZeroKB.Zero.clearSignBit();
5843       if (const auto ZeroResult = KnownBits::eq(Bits, ZeroKB)) {
5844         assert(!ZeroResult.value());
5845         Known.knownNot(fcZero);
5846       } else if (Bits == ZeroKB) {
5847         Known.KnownFPClasses = fcZero;
5848       }
5849     }
5850 
5851     break;
5852   }
5853   default:
5854     break;
5855   }
5856 }
5857 
5858 KnownFPClass llvm::computeKnownFPClass(const Value *V,
5859                                        const APInt &DemandedElts,
5860                                        FPClassTest InterestedClasses,
5861                                        const SimplifyQuery &SQ,
5862                                        unsigned Depth) {
5863   KnownFPClass KnownClasses;
5864   ::computeKnownFPClass(V, DemandedElts, InterestedClasses, KnownClasses, SQ,
5865                         Depth);
5866   return KnownClasses;
5867 }
5868 
5869 KnownFPClass llvm::computeKnownFPClass(const Value *V,
5870                                        FPClassTest InterestedClasses,
5871                                        const SimplifyQuery &SQ,
5872                                        unsigned Depth) {
5873   KnownFPClass Known;
5874   ::computeKnownFPClass(V, Known, InterestedClasses, SQ, Depth);
5875   return Known;
5876 }
5877 
5878 KnownFPClass llvm::computeKnownFPClass(
5879     const Value *V, const DataLayout &DL, FPClassTest InterestedClasses,
5880     const TargetLibraryInfo *TLI, AssumptionCache *AC, const Instruction *CxtI,
5881     const DominatorTree *DT, bool UseInstrInfo, unsigned Depth) {
5882   return computeKnownFPClass(V, InterestedClasses,
5883                              SimplifyQuery(DL, TLI, DT, AC, CxtI, UseInstrInfo),
5884                              Depth);
5885 }
5886 
5887 KnownFPClass
5888 llvm::computeKnownFPClass(const Value *V, const APInt &DemandedElts,
5889                           FastMathFlags FMF, FPClassTest InterestedClasses,
5890                           const SimplifyQuery &SQ, unsigned Depth) {
5891   if (FMF.noNaNs())
5892     InterestedClasses &= ~fcNan;
5893   if (FMF.noInfs())
5894     InterestedClasses &= ~fcInf;
5895 
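       // Under nnan/ninf a NaN or infinity value would be poison, so it is
       // also safe to mask those classes out of the computed result below.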
5896   KnownFPClass Result =
5897       computeKnownFPClass(V, DemandedElts, InterestedClasses, SQ, Depth);
5898 
5899   if (FMF.noNaNs())
5900     Result.KnownFPClasses &= ~fcNan;
5901   if (FMF.noInfs())
5902     Result.KnownFPClasses &= ~fcInf;
5903   return Result;
5904 }
5905 
5906 KnownFPClass llvm::computeKnownFPClass(const Value *V, FastMathFlags FMF,
5907                                        FPClassTest InterestedClasses,
5908                                        const SimplifyQuery &SQ,
5909                                        unsigned Depth) {
5910   auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
5911   APInt DemandedElts =
5912       FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
5913   return computeKnownFPClass(V, DemandedElts, FMF, InterestedClasses, SQ,
5914                              Depth);
5915 }
5916 
5917 bool llvm::cannotBeNegativeZero(const Value *V, const SimplifyQuery &SQ,
5918                                 unsigned Depth) {
5919   KnownFPClass Known = computeKnownFPClass(V, fcNegZero, SQ, Depth);
5920   return Known.isKnownNeverNegZero();
5921 }
5922 
5923 bool llvm::cannotBeOrderedLessThanZero(const Value *V, const SimplifyQuery &SQ,
5924                                        unsigned Depth) {
5925   KnownFPClass Known =
5926       computeKnownFPClass(V, KnownFPClass::OrderedLessThanZeroMask, SQ, Depth);
5927   return Known.cannotBeOrderedLessThanZero();
5928 }
5929 
5930 bool llvm::isKnownNeverInfinity(const Value *V, const SimplifyQuery &SQ,
5931                                 unsigned Depth) {
5932   KnownFPClass Known = computeKnownFPClass(V, fcInf, SQ, Depth);
5933   return Known.isKnownNeverInfinity();
5934 }
5935 
5936 /// Return true if the floating-point value can never contain a NaN or infinity.
5937 bool llvm::isKnownNeverInfOrNaN(const Value *V, const SimplifyQuery &SQ,
5938                                 unsigned Depth) {
5939   KnownFPClass Known = computeKnownFPClass(V, fcInf | fcNan, SQ, Depth);
5940   return Known.isKnownNeverNaN() && Known.isKnownNeverInfinity();
5941 }
5942 
5943 /// Return true if the floating-point scalar value is not a NaN or if the
5944 /// floating-point vector value has no NaN elements. Return false if a value
5945 /// could ever be NaN.
5946 bool llvm::isKnownNeverNaN(const Value *V, const SimplifyQuery &SQ,
5947                            unsigned Depth) {
5948   KnownFPClass Known = computeKnownFPClass(V, fcNan, SQ, Depth);
5949   return Known.isKnownNeverNaN();
5950 }
5951 
5952 /// Return false if we can prove that the specified FP value's sign bit is 0.
5953 /// Return true if we can prove that the specified FP value's sign bit is 1.
5954 /// Otherwise return std::nullopt.
5955 std::optional<bool> llvm::computeKnownFPSignBit(const Value *V,
5956                                                 const SimplifyQuery &SQ,
5957                                                 unsigned Depth) {
5958   KnownFPClass Known = computeKnownFPClass(V, fcAllFlags, SQ, Depth);
5959   return Known.SignBit;
5960 }
5961 
5962 bool llvm::canIgnoreSignBitOfZero(const Use &U) {
5963   auto *User = cast<Instruction>(U.getUser());
5964   if (auto *FPOp = dyn_cast<FPMathOperator>(User)) {
5965     if (FPOp->hasNoSignedZeros())
5966       return true;
5967   }
5968 
5969   switch (User->getOpcode()) {
5970   case Instruction::FPToSI:
5971   case Instruction::FPToUI:
5972     return true;
5973   case Instruction::FCmp:
5974     // fcmp treats both positive and negative zero as equal.
5975     return true;
5976   case Instruction::Call:
5977     if (auto *II = dyn_cast<IntrinsicInst>(User)) {
5978       switch (II->getIntrinsicID()) {
5979       case Intrinsic::fabs:
5980         return true;
5981       case Intrinsic::copysign:
5982         return U.getOperandNo() == 0;
5983       case Intrinsic::is_fpclass:
5984       case Intrinsic::vp_is_fpclass: {
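             // The sign of zero only matters if the class test distinguishes
             // +0 from -0, i.e. checks exactly one of fcPosZero and fcNegZero.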
5985         auto Test =
5986             static_cast<FPClassTest>(
5987                 cast<ConstantInt>(II->getArgOperand(1))->getZExtValue()) &
5988             FPClassTest::fcZero;
5989         return Test == FPClassTest::fcZero || Test == FPClassTest::fcNone;
5990       }
5991       default:
5992         return false;
5993       }
5994     }
5995     return false;
5996   default:
5997     return false;
5998   }
5999 }
6000 
6001 bool llvm::canIgnoreSignBitOfNaN(const Use &U) {
6002   auto *User = cast<Instruction>(U.getUser());
6003   if (auto *FPOp = dyn_cast<FPMathOperator>(User)) {
6004     if (FPOp->hasNoNaNs())
6005       return true;
6006   }
6007 
6008   switch (User->getOpcode()) {
6009   case Instruction::FPToSI:
6010   case Instruction::FPToUI:
6011     return true;
6012   // Proper FP math operations ignore the sign bit of NaN.
6013   case Instruction::FAdd:
6014   case Instruction::FSub:
6015   case Instruction::FMul:
6016   case Instruction::FDiv:
6017   case Instruction::FRem:
6018   case Instruction::FPTrunc:
6019   case Instruction::FPExt:
6020   case Instruction::FCmp:
6021     return true;
6022   // Bitwise FP operations should preserve the sign bit of NaN.
6023   case Instruction::FNeg:
6024   case Instruction::Select:
6025   case Instruction::PHI:
6026     return false;
6027   case Instruction::Ret:
6028     return User->getFunction()->getAttributes().getRetNoFPClass() &
6029            FPClassTest::fcNan;
6030   case Instruction::Call:
6031   case Instruction::Invoke: {
6032     if (auto *II = dyn_cast<IntrinsicInst>(User)) {
6033       switch (II->getIntrinsicID()) {
6034       case Intrinsic::fabs:
6035         return true;
6036       case Intrinsic::copysign:
6037         return U.getOperandNo() == 0;
6038       // Other proper FP math intrinsics ignore the sign bit of NaN.
6039       case Intrinsic::maxnum:
6040       case Intrinsic::minnum:
6041       case Intrinsic::maximum:
6042       case Intrinsic::minimum:
6043       case Intrinsic::maximumnum:
6044       case Intrinsic::minimumnum:
6045       case Intrinsic::canonicalize:
6046       case Intrinsic::fma:
6047       case Intrinsic::fmuladd:
6048       case Intrinsic::sqrt:
6049       case Intrinsic::pow:
6050       case Intrinsic::powi:
6051       case Intrinsic::fptoui_sat:
6052       case Intrinsic::fptosi_sat:
6053       case Intrinsic::is_fpclass:
6054       case Intrinsic::vp_is_fpclass:
6055         return true;
6056       default:
6057         return false;
6058       }
6059     }
6060 
6061     FPClassTest NoFPClass =
6062         cast<CallBase>(User)->getParamNoFPClass(U.getOperandNo());
6063     return NoFPClass & FPClassTest::fcNan;
6064   }
6065   default:
6066     return false;
6067   }
6068 }
6069 
6070 Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) {
6071 
6072   // All byte-wide stores are splatable, even of arbitrary variables.
6073   if (V->getType()->isIntegerTy(8))
6074     return V;
6075 
6076   LLVMContext &Ctx = V->getContext();
6077 
6078   // Undef doesn't care; any byte value will do.
6079   auto *UndefInt8 = UndefValue::get(Type::getInt8Ty(Ctx));
6080   if (isa<UndefValue>(V))
6081     return UndefInt8;
6082 
6083   // Return poison for zero-sized type.
6084   if (DL.getTypeStoreSize(V->getType()).isZero())
6085     return PoisonValue::get(Type::getInt8Ty(Ctx));
6086 
6087   Constant *C = dyn_cast<Constant>(V);
6088   if (!C) {
6089     // Conceptually, we could handle things like:
6090     //   %a = zext i8 %X to i16
6091     //   %b = shl i16 %a, 8
6092     //   %c = or i16 %a, %b
6093     // but until there is an example that actually needs this, it doesn't seem
6094     // worth worrying about.
6095     return nullptr;
6096   }
6097 
6098   // Handle 'null' ConstantAggregateZero etc.
6099   if (C->isNullValue())
6100     return Constant::getNullValue(Type::getInt8Ty(Ctx));
6101 
6102   // Constant floating-point values can be handled as integer values if the
6103   // corresponding integer value is "byteable".  An important case is 0.0.
6104   if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
6105     Type *Ty = nullptr;
6106     if (CFP->getType()->isHalfTy())
6107       Ty = Type::getInt16Ty(Ctx);
6108     else if (CFP->getType()->isFloatTy())
6109       Ty = Type::getInt32Ty(Ctx);
6110     else if (CFP->getType()->isDoubleTy())
6111       Ty = Type::getInt64Ty(Ctx);
6112     // Don't handle long double formats, which have strange constraints.
6113     return Ty ? isBytewiseValue(ConstantExpr::getBitCast(CFP, Ty), DL)
6114               : nullptr;
6115   }
6116 
6117   // We can handle constant integers whose width is a multiple of 8 bits.
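       // e.g. the i32 constant 0xABABABAB is the byte 0xAB repeated four
       // times, while 0x01020304 has no single repeating byte.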
6118   if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
6119     if (CI->getBitWidth() % 8 == 0) {
6120       assert(CI->getBitWidth() > 8 && "8 bits should be handled above!");
6121       if (!CI->getValue().isSplat(8))
6122         return nullptr;
6123       return ConstantInt::get(Ctx, CI->getValue().trunc(8));
6124     }
6125   }
6126 
6127   if (auto *CE = dyn_cast<ConstantExpr>(C)) {
6128     if (CE->getOpcode() == Instruction::IntToPtr) {
6129       if (auto *PtrTy = dyn_cast<PointerType>(CE->getType())) {
6130         unsigned BitWidth = DL.getPointerSizeInBits(PtrTy->getAddressSpace());
6131         if (Constant *Op = ConstantFoldIntegerCast(
6132                 CE->getOperand(0), Type::getIntNTy(Ctx, BitWidth), false, DL))
6133           return isBytewiseValue(Op, DL);
6134       }
6135     }
6136   }
6137 
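       // Merge the byte values of two subobjects: identical values merge to
       // themselves, undef defers to the other side, and any other mismatch
       // means there is no common byte.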
6138   auto Merge = [&](Value *LHS, Value *RHS) -> Value * {
6139     if (LHS == RHS)
6140       return LHS;
6141     if (!LHS || !RHS)
6142       return nullptr;
6143     if (LHS == UndefInt8)
6144       return RHS;
6145     if (RHS == UndefInt8)
6146       return LHS;
6147     return nullptr;
6148   };
6149 
6150   if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(C)) {
6151     Value *Val = UndefInt8;
6152     for (uint64_t I = 0, E = CA->getNumElements(); I != E; ++I)
6153       if (!(Val = Merge(Val, isBytewiseValue(CA->getElementAsConstant(I), DL))))
6154         return nullptr;
6155     return Val;
6156   }
6157 
6158   if (isa<ConstantAggregate>(C)) {
6159     Value *Val = UndefInt8;
6160     for (Value *Op : C->operands())
6161       if (!(Val = Merge(Val, isBytewiseValue(Op, DL))))
6162         return nullptr;
6163     return Val;
6164   }
6165 
6166   // Don't try to handle the handful of other constants.
6167   return nullptr;
6168 }
6169 
6170 // This is the recursive version of BuildSubAggregate. It takes a few different
6171 // arguments. Idxs are the indices within the nested struct From that we are
6172 // looking at now (which is of type IndexedType). IdxSkip is the number of
6173 // indices from Idxs that should be left out when inserting into the resulting
6174 // struct. To is the result struct built so far, new insertvalue instructions
6175 // build on that.
6176 static Value *BuildSubAggregate(Value *From, Value *To, Type *IndexedType,
6177                                 SmallVectorImpl<unsigned> &Idxs,
6178                                 unsigned IdxSkip,
6179                                 BasicBlock::iterator InsertBefore) {
6180   StructType *STy = dyn_cast<StructType>(IndexedType);
6181   if (STy) {
6182     // Save the original To argument so we can modify it
6183     Value *OrigTo = To;
6184     // General case, the type indexed by Idxs is a struct
6185     for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
6186       // Process each struct element recursively
6187       Idxs.push_back(i);
6188       Value *PrevTo = To;
6189       To = BuildSubAggregate(From, To, STy->getElementType(i), Idxs, IdxSkip,
6190                              InsertBefore);
6191       Idxs.pop_back();
6192       if (!To) {
6193         // Couldn't find any inserted value for this index? Cleanup
6194         while (PrevTo != OrigTo) {
6195           InsertValueInst* Del = cast<InsertValueInst>(PrevTo);
6196           PrevTo = Del->getAggregateOperand();
6197           Del->eraseFromParent();
6198         }
6199         // Stop processing elements
6200         break;
6201       }
6202     }
6203     // If we successfully found a value for each of our subaggregates
6204     if (To)
6205       return To;
6206   }
6207   // Base case, the type indexed by Idxs is not a struct, or not all of
6208   // the struct's elements had a value that was inserted directly. In the latter
6209   // case, perhaps we can't determine each of the subelements individually, but
6210   // we might be able to find the complete struct somewhere.
6211 
6212   // Find the value that is at that particular spot
6213   Value *V = FindInsertedValue(From, Idxs);
6214 
6215   if (!V)
6216     return nullptr;
6217 
6218   // Insert the value in the new (sub) aggregate
6219   return InsertValueInst::Create(To, V, ArrayRef(Idxs).slice(IdxSkip), "tmp",
6220                                  InsertBefore);
6221 }
6222 
6223 // This helper takes a nested struct and extracts a part of it (which is again a
6224 // struct) into a new value. For example, given the struct:
6225 // { a, { b, { c, d }, e } }
6226 // and the indices "1, 1" this returns
6227 // { c, d }.
6228 //
6229 // It does this by inserting an insertvalue for each element in the resulting
6230 // struct, as opposed to just inserting a single struct. This will only work if
6231 // each of the elements of the substruct is known (i.e., inserted into From
6232 // by an insertvalue instruction somewhere).
6233 //
6234 // All inserted insertvalue instructions are inserted before InsertBefore
6235 static Value *BuildSubAggregate(Value *From, ArrayRef<unsigned> idx_range,
6236                                 BasicBlock::iterator InsertBefore) {
6237   Type *IndexedType = ExtractValueInst::getIndexedType(From->getType(),
6238                                                              idx_range);
6239   Value *To = PoisonValue::get(IndexedType);
6240   SmallVector<unsigned, 10> Idxs(idx_range);
6241   unsigned IdxSkip = Idxs.size();
6242 
6243   return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore);
6244 }
6245 
6246 /// Given an aggregate and a sequence of indices, see if the scalar value
6247 /// indexed is already around as a register, for example if it was inserted
6248 /// directly into the aggregate.
6249 ///
6250 /// If InsertBefore is provided, this function will duplicate (modified)
6251 /// insertvalues when a part of a nested struct is extracted.
6252 Value *
6253 llvm::FindInsertedValue(Value *V, ArrayRef<unsigned> idx_range,
6254                         std::optional<BasicBlock::iterator> InsertBefore) {
6255   // Nothing to index? Just return V then (this is useful at the end of our
6256   // recursion).
6257   if (idx_range.empty())
6258     return V;
6259   // We have indices, so V should have an indexable type.
6260   assert((V->getType()->isStructTy() || V->getType()->isArrayTy()) &&
6261          "Not looking at a struct or array?");
6262   assert(ExtractValueInst::getIndexedType(V->getType(), idx_range) &&
6263          "Invalid indices for type?");
6264 
6265   if (Constant *C = dyn_cast<Constant>(V)) {
6266     C = C->getAggregateElement(idx_range[0]);
6267     if (!C) return nullptr;
6268     return FindInsertedValue(C, idx_range.slice(1), InsertBefore);
6269   }
6270 
6271   if (InsertValueInst *I = dyn_cast<InsertValueInst>(V)) {
6272     // Loop the indices for the insertvalue instruction in parallel with the
6273     // requested indices
6274     const unsigned *req_idx = idx_range.begin();
6275     for (const unsigned *i = I->idx_begin(), *e = I->idx_end();
6276          i != e; ++i, ++req_idx) {
6277       if (req_idx == idx_range.end()) {
6278         // We can't handle this without inserting insertvalues
6279         if (!InsertBefore)
6280           return nullptr;
6281 
6282         // The requested index identifies a part of a nested aggregate. Handle
6283         // this specially. For example,
6284         // %A = insertvalue { i32, {i32, i32 } } undef, i32 10, 1, 0
6285         // %B = insertvalue { i32, {i32, i32 } } %A, i32 11, 1, 1
6286         // %C = extractvalue {i32, { i32, i32 } } %B, 1
6287         // This can be changed into
6288         // %A = insertvalue {i32, i32 } undef, i32 10, 0
6289         // %C = insertvalue {i32, i32 } %A, i32 11, 1
6290         // which allows the unused 0,0 element from the nested struct to be
6291         // removed.
6292         return BuildSubAggregate(V, ArrayRef(idx_range.begin(), req_idx),
6293                                  *InsertBefore);
6294       }
6295 
6296       // This insertvalue inserts something other than what we are looking for.
6297       // See if the (aggregate) value inserted into has the value we are
6298       // looking for, then.
6299       if (*req_idx != *i)
6300         return FindInsertedValue(I->getAggregateOperand(), idx_range,
6301                                  InsertBefore);
6302     }
6303     // If we end up here, the indices of the insertvalue match with those
6304     // requested (though possibly only partially). Now we recursively look at
6305     // the inserted value, passing any remaining indices.
6306     return FindInsertedValue(I->getInsertedValueOperand(),
6307                              ArrayRef(req_idx, idx_range.end()), InsertBefore);
6308   }
6309 
6310   if (ExtractValueInst *I = dyn_cast<ExtractValueInst>(V)) {
6311     // If we're extracting a value from an aggregate that was extracted from
6312     // something else, we can extract from that something else directly instead.
6313     // However, we will need to chain I's indices with the requested indices.
6314 
6315     // Calculate the number of indices required
6316     unsigned size = I->getNumIndices() + idx_range.size();
6317     // Allocate some space to put the new indices in
6318     SmallVector<unsigned, 5> Idxs;
6319     Idxs.reserve(size);
6320     // Add indices from the extract value instruction
6321     Idxs.append(I->idx_begin(), I->idx_end());
6322 
6323     // Add requested indices
6324     Idxs.append(idx_range.begin(), idx_range.end());
6325 
6326     assert(Idxs.size() == size
6327            && "Number of indices added not correct?");
6328 
6329     return FindInsertedValue(I->getAggregateOperand(), Idxs, InsertBefore);
6330   }
6331   // Otherwise, we don't know (such as, extracting from a function return value
6332   // or load instruction)
6333   return nullptr;
6334 }
6335 
6336 bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP,
6337                                        unsigned CharSize) {
6338   // Make sure the GEP has exactly three arguments.
6339   if (GEP->getNumOperands() != 3)
6340     return false;
6341 
6342   // Make sure the indexed operand is a pointer to an array of \p CharSize
6343   // integers.
6344   ArrayType *AT = dyn_cast<ArrayType>(GEP->getSourceElementType());
6345   if (!AT || !AT->getElementType()->isIntegerTy(CharSize))
6346     return false;
6347 
6348   // Check to make sure that the first operand of the GEP is an integer and
6349   // has value 0 so that we are sure we're indexing into the initializer.
6350   const ConstantInt *FirstIdx = dyn_cast<ConstantInt>(GEP->getOperand(1));
6351   if (!FirstIdx || !FirstIdx->isZero())
6352     return false;
6353 
6354   return true;
6355 }
6356 
6357 // If V refers to an initialized global constant, set Slice either to
6358 // its initializer if the size of its elements equals ElementSize, or,
6359 // for ElementSize == 8, to its representation as an array of unsigned
6360 // char. Return true on success.
6361 // Offset is measured in units of ElementSize-sized elements.
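     // For example (illustrative IR), given
     //   @s = private constant [5 x i8] c"abcd\00"
     // a query with ElementSize == 8 and Offset == 1 yields a Slice with
     // Slice.Offset == 1 and Slice.Length == 4.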
6362 bool llvm::getConstantDataArrayInfo(const Value *V,
6363                                     ConstantDataArraySlice &Slice,
6364                                     unsigned ElementSize, uint64_t Offset) {
6365   assert(V && "V should not be null.");
6366   assert((ElementSize % 8) == 0 &&
6367          "ElementSize expected to be a multiple of the size of a byte.");
6368   unsigned ElementSizeInBytes = ElementSize / 8;
6369 
6370   // Drill down into the pointer expression V, ignoring any intervening
6371   // casts, and determine the identity of the object it references along
6372   // with the cumulative byte offset into it.
6373   const GlobalVariable *GV =
6374     dyn_cast<GlobalVariable>(getUnderlyingObject(V));
6375   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
6376     // Fail if V is not based on constant global object.
6377     return false;
6378 
6379   const DataLayout &DL = GV->getDataLayout();
6380   APInt Off(DL.getIndexTypeSizeInBits(V->getType()), 0);
6381 
6382   if (GV != V->stripAndAccumulateConstantOffsets(DL, Off,
6383                                                  /*AllowNonInbounds*/ true))
6384     // Fail if a constant offset could not be determined.
6385     return false;
6386 
6387   uint64_t StartIdx = Off.getLimitedValue();
6388   if (StartIdx == UINT64_MAX)
6389     // Fail if the constant offset is excessive.
6390     return false;
6391 
6392   // Off/StartIdx is in units of bytes, so we need to convert to the number of
6393   // elements. Simply bail out if that isn't possible.
6394   if ((StartIdx % ElementSizeInBytes) != 0)
6395     return false;
6396 
6397   Offset += StartIdx / ElementSizeInBytes;
6398   ConstantDataArray *Array = nullptr;
6399   ArrayType *ArrayTy = nullptr;
6400 
6401   if (GV->getInitializer()->isNullValue()) {
6402     Type *GVTy = GV->getValueType();
6403     uint64_t SizeInBytes = DL.getTypeStoreSize(GVTy).getFixedValue();
6404     uint64_t Length = SizeInBytes / ElementSizeInBytes;
6405 
6406     Slice.Array = nullptr;
6407     Slice.Offset = 0;
6408     // Return an empty Slice for undersized constants to let callers
6409     // transform even undefined library calls into simpler, well-defined
6410     // expressions.  This is preferable to making the calls, even though it
6411     // prevents sanitizers from detecting such calls.
6412     Slice.Length = Length < Offset ? 0 : Length - Offset;
6413     return true;
6414   }
6415 
6416   auto *Init = const_cast<Constant *>(GV->getInitializer());
6417   if (auto *ArrayInit = dyn_cast<ConstantDataArray>(Init)) {
6418     Type *InitElTy = ArrayInit->getElementType();
6419     if (InitElTy->isIntegerTy(ElementSize)) {
6420       // If Init is an initializer for an array of the expected type
6421       // and size, use it as is.
6422       Array = ArrayInit;
6423       ArrayTy = ArrayInit->getType();
6424     }
6425   }
6426 
6427   if (!Array) {
6428     if (ElementSize != 8)
6429       // TODO: Handle conversions to larger integral types.
6430       return false;
6431 
6432     // Otherwise extract the portion of the initializer starting
6433     // at Offset as an array of bytes, and reset Offset.
6434     Init = ReadByteArrayFromGlobal(GV, Offset);
6435     if (!Init)
6436       return false;
6437 
6438     Offset = 0;
6439     Array = dyn_cast<ConstantDataArray>(Init);
6440     ArrayTy = dyn_cast<ArrayType>(Init->getType());
6441   }
6442 
6443   uint64_t NumElts = ArrayTy->getArrayNumElements();
6444   if (Offset > NumElts)
6445     return false;
6446 
6447   Slice.Array = Array;
6448   Slice.Offset = Offset;
6449   Slice.Length = NumElts - Offset;
6450   return true;
6451 }
6452 
6453 /// Extract bytes from the initializer of the constant array V, which need
6454 /// not be a nul-terminated string.  On success, store the bytes in Str and
6455 /// return true.  When TrimAtNul is set, Str will contain only the bytes up
6456 /// to but not including the first nul.  Return false on failure.
6457 bool llvm::getConstantStringInfo(const Value *V, StringRef &Str,
6458                                  bool TrimAtNul) {
6459   ConstantDataArraySlice Slice;
6460   if (!getConstantDataArrayInfo(V, Slice, 8))
6461     return false;
6462 
6463   if (Slice.Array == nullptr) {
6464     if (TrimAtNul) {
6465       // Return a nul-terminated string even for an empty Slice.  This is
6466       // safe because all existing SimplifyLibcalls callers require string
6467       // arguments and the behavior of the functions they fold is undefined
6468       // otherwise.  Folding the calls this way is preferable to making
6469       // the undefined library calls, even though it prevents sanitizers
6470       // from reporting such calls.
6471       Str = StringRef();
6472       return true;
6473     }
6474     if (Slice.Length == 1) {
6475       Str = StringRef("", 1);
6476       return true;
6477     }
6478     // We cannot instantiate a StringRef as we do not have an appropriate string
6479     // of 0s at hand.
6480     return false;
6481   }
6482 
6483   // Start out with the entire array in the StringRef.
6484   Str = Slice.Array->getAsString();
6485   // Skip over 'offset' bytes.
6486   Str = Str.substr(Slice.Offset);
6487 
6488   if (TrimAtNul) {
6489     // Trim off the \0 and anything after it.  If the array is not nul
6490     // terminated, we just return the rest of the string.  The client may know
6491     // some other way that the string is length-bound.
6492     Str = Str.substr(0, Str.find('\0'));
6493   }
6494   return true;
6495 }
6496 
6497 // These next two are very similar to the above, but also look through PHI
6498 // nodes.
6499 // TODO: See if we can integrate these two together.
6500 
6501 /// If we can compute the length of the string pointed to by
6502 /// the specified pointer, return 'len+1'.  If we can't, return 0.
6503 static uint64_t GetStringLengthH(const Value *V,
6504                                  SmallPtrSetImpl<const PHINode*> &PHIs,
6505                                  unsigned CharSize) {
6506   // Look through noop bitcast instructions.
6507   V = V->stripPointerCasts();
6508 
6509   // If this is a PHI node, there are two cases: either we have already seen it
6510   // or we haven't.
6511   if (const PHINode *PN = dyn_cast<PHINode>(V)) {
6512     if (!PHIs.insert(PN).second)
6513       return ~0ULL;  // already in the set.
6514 
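         // A result of ~0ULL means "no constraint yet" (e.g. a self-cycle),
         // while 0 means the length is unknown.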
6515     // If it was new, see if all the input strings are the same length.
6516     uint64_t LenSoFar = ~0ULL;
6517     for (Value *IncValue : PN->incoming_values()) {
6518       uint64_t Len = GetStringLengthH(IncValue, PHIs, CharSize);
6519       if (Len == 0) return 0; // Unknown length -> unknown.
6520 
6521       if (Len == ~0ULL) continue;
6522 
6523       if (Len != LenSoFar && LenSoFar != ~0ULL)
6524         return 0;    // Disagree -> unknown.
6525       LenSoFar = Len;
6526     }
6527 
6528     // Success, all agree.
6529     return LenSoFar;
6530   }
6531 
6532   // strlen(select(c, x, y)) is known only when strlen(x) and strlen(y) agree.
6533   if (const SelectInst *SI = dyn_cast<SelectInst>(V)) {
6534     uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs, CharSize);
6535     if (Len1 == 0) return 0;
6536     uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs, CharSize);
6537     if (Len2 == 0) return 0;
6538     if (Len1 == ~0ULL) return Len2;
6539     if (Len2 == ~0ULL) return Len1;
6540     if (Len1 != Len2) return 0;
6541     return Len1;
6542   }
6543 
6544   // Otherwise, see if we can read the string.
6545   ConstantDataArraySlice Slice;
6546   if (!getConstantDataArrayInfo(V, Slice, CharSize))
6547     return 0;
6548 
6549   if (Slice.Array == nullptr)
6550     // Zeroinitializer (including an empty one).
6551     return 1;
6552 
6553   // Search for the first nul character.  Return a conservative result even
6554   // when there is no nul.  This is safe since otherwise the string function
6555   // being folded such as strlen is undefined, and can be preferable to
6556   // being folded, such as strlen, is undefined; folding is preferable to
6557   // making the undefined library call.
6558   for (unsigned E = Slice.Length; NullIndex < E; ++NullIndex) {
6559     if (Slice.Array->getElementAsInteger(Slice.Offset + NullIndex) == 0)
6560       break;
6561   }
6562 
6563   return NullIndex + 1;
6564 }
6565 
6566 /// If we can compute the length of the string pointed to by
6567 /// the specified pointer, return 'len+1'.  If we can't, return 0.
6568 uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) {
6569   if (!V->getType()->isPointerTy())
6570     return 0;
6571 
6572   SmallPtrSet<const PHINode*, 32> PHIs;
6573   uint64_t Len = GetStringLengthH(V, PHIs, CharSize);
6574   // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so
6575   // return 1, the length of an empty string (just the nul terminator).
6576   return Len == ~0ULL ? 1 : Len;
6577 }
6578 
6579 const Value *
6580 llvm::getArgumentAliasingToReturnedPointer(const CallBase *Call,
6581                                            bool MustPreserveNullness) {
6582   assert(Call &&
6583          "getArgumentAliasingToReturnedPointer only works on nonnull calls");
6584   if (const Value *RV = Call->getReturnedArgOperand())
6585     return RV;
6586   // This can be used only as an aliasing property.
6587   if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
6588           Call, MustPreserveNullness))
6589     return Call->getArgOperand(0);
6590   return nullptr;
6591 }
6592 
6593 bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
6594     const CallBase *Call, bool MustPreserveNullness) {
6595   switch (Call->getIntrinsicID()) {
6596   case Intrinsic::launder_invariant_group:
6597   case Intrinsic::strip_invariant_group:
6598   case Intrinsic::aarch64_irg:
6599   case Intrinsic::aarch64_tagp:
6600   // The amdgcn_make_buffer_rsrc function does not alter the address of the
6601   // input pointer (and thus preserves null-ness for the purposes of escape
6602   // analysis, which is where the MustPreserveNullness flag comes into play).
6603   // However, it will not necessarily map ptr addrspace(N) null to ptr
6604   // addrspace(8) null, aka the "null descriptor", which has "all loads return
6605   // 0, all stores are dropped" semantics. Given the context of this intrinsic
6606   // list, no one should be relying on such a strict interpretation of
6607   // MustPreserveNullness (and, at time of writing, they are not), but we
6608   // document this fact out of an abundance of caution.
6609   case Intrinsic::amdgcn_make_buffer_rsrc:
6610     return true;
6611   case Intrinsic::ptrmask:
6612     return !MustPreserveNullness;
6613   case Intrinsic::threadlocal_address:
6614     // The underlying variable changes with the thread ID, and the thread ID
6615     // may change at coroutine suspend points.
6616     return !Call->getParent()->getParent()->isPresplitCoroutine();
6617   default:
6618     return false;
6619   }
6620 }
6621 
6622 /// \p PN defines a loop-variant pointer to an object.  Check if the
6623 /// previous iteration of the loop was referring to the same object as \p PN.
6624 static bool isSameUnderlyingObjectInLoop(const PHINode *PN,
6625                                          const LoopInfo *LI) {
6626   // Find the loop-defined value.
6627   Loop *L = LI->getLoopFor(PN->getParent());
6628   if (PN->getNumIncomingValues() != 2)
6629     return true;
6630 
6631   // Find the value from previous iteration.
6632   auto *PrevValue = dyn_cast<Instruction>(PN->getIncomingValue(0));
6633   if (!PrevValue || LI->getLoopFor(PrevValue->getParent()) != L)
6634     PrevValue = dyn_cast<Instruction>(PN->getIncomingValue(1));
6635   if (!PrevValue || LI->getLoopFor(PrevValue->getParent()) != L)
6636     return true;
6637 
6638   // If a new pointer is loaded in the loop, the pointer references a different
6639   // object in every iteration.  E.g.:
6640   //    for (i)
6641   //       int *p = a[i];
6642   //       ...
6643   if (auto *Load = dyn_cast<LoadInst>(PrevValue))
6644     if (!L->isLoopInvariant(Load->getPointerOperand()))
6645       return false;
6646   return true;
6647 }
6648 
6649 const Value *llvm::getUnderlyingObject(const Value *V, unsigned MaxLookup) {
6650   for (unsigned Count = 0; MaxLookup == 0 || Count < MaxLookup; ++Count) {
6651     if (auto *GEP = dyn_cast<GEPOperator>(V)) {
6652       const Value *PtrOp = GEP->getPointerOperand();
6653       if (!PtrOp->getType()->isPointerTy()) // Only handle scalar pointer base.
6654         return V;
6655       V = PtrOp;
6656     } else if (Operator::getOpcode(V) == Instruction::BitCast ||
6657                Operator::getOpcode(V) == Instruction::AddrSpaceCast) {
6658       Value *NewV = cast<Operator>(V)->getOperand(0);
6659       if (!NewV->getType()->isPointerTy())
6660         return V;
6661       V = NewV;
6662     } else if (auto *GA = dyn_cast<GlobalAlias>(V)) {
6663       if (GA->isInterposable())
6664         return V;
6665       V = GA->getAliasee();
6666     } else {
6667       if (auto *PHI = dyn_cast<PHINode>(V)) {
6668         // Look through single-arg phi nodes created by LCSSA.
6669         if (PHI->getNumIncomingValues() == 1) {
6670           V = PHI->getIncomingValue(0);
6671           continue;
6672         }
6673       } else if (auto *Call = dyn_cast<CallBase>(V)) {
6674         // CaptureTracking knows about the special capturing properties of
6675         // some intrinsics, like launder.invariant.group, that can't be
6676         // expressed with attributes but still return an aliasing pointer.
6677         // Because an analysis may assume that a nocapture'd pointer is not
6678         // returned from a special intrinsic (since the function would have to
6679         // be marked with the 'returned' attribute), it is crucial to use this
6680         // function, which is kept in sync with CaptureTracking. Not using it
6681         // may cause subtle miscompilations where two aliasing pointers are
6682         // assumed not to alias.
6683         if (auto *RP = getArgumentAliasingToReturnedPointer(Call, false)) {
6684           V = RP;
6685           continue;
6686         }
6687       }
6688 
6689       return V;
6690     }
6691     assert(V->getType()->isPointerTy() && "Unexpected operand type!");
6692   }
6693   return V;
6694 }
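// Illustrative example (added annotation, not part of the original source):
//   %a = alloca [16 x i8]
//   %g = getelementptr inbounds [16 x i8], ptr %a, i64 0, i64 4
// getUnderlyingObject(%g) looks through the GEP and returns %a.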
6695 
6696 void llvm::getUnderlyingObjects(const Value *V,
6697                                 SmallVectorImpl<const Value *> &Objects,
6698                                 const LoopInfo *LI, unsigned MaxLookup) {
6699   SmallPtrSet<const Value *, 4> Visited;
6700   SmallVector<const Value *, 4> Worklist;
6701   Worklist.push_back(V);
6702   do {
6703     const Value *P = Worklist.pop_back_val();
6704     P = getUnderlyingObject(P, MaxLookup);
6705 
6706     if (!Visited.insert(P).second)
6707       continue;
6708 
6709     if (auto *SI = dyn_cast<SelectInst>(P)) {
6710       Worklist.push_back(SI->getTrueValue());
6711       Worklist.push_back(SI->getFalseValue());
6712       continue;
6713     }
6714 
6715     if (auto *PN = dyn_cast<PHINode>(P)) {
6716       // If this PHI changes the underlying object in every iteration of the
6717       // loop, don't look through it.  Consider:
6718       //   int **A;
6719       //   for (i) {
6720       //     Prev = Curr;     // Prev = PHI (Prev_0, Curr)
6721       //     Curr = A[i];
6722       //     *Prev, *Curr;
6723       //
6724       // Prev is tracking Curr one iteration behind so they refer to different
6725       // underlying objects.
6726       if (!LI || !LI->isLoopHeader(PN->getParent()) ||
6727           isSameUnderlyingObjectInLoop(PN, LI))
6728         append_range(Worklist, PN->incoming_values());
6729       else
6730         Objects.push_back(P);
6731       continue;
6732     }
6733 
6734     Objects.push_back(P);
6735   } while (!Worklist.empty());
6736 }
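// Illustrative example (added annotation, not part of the original source):
// unlike the scalar variant above, for
//   %p = select i1 %c, ptr %a, ptr %b
// getUnderlyingObjects collects both %a and %b into Objects.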
6737 
6738 const Value *llvm::getUnderlyingObjectAggressive(const Value *V) {
6739   const unsigned MaxVisited = 8;
6740 
6741   SmallPtrSet<const Value *, 8> Visited;
6742   SmallVector<const Value *, 8> Worklist;
6743   Worklist.push_back(V);
6744   const Value *Object = nullptr;
6745   // Used as a fallback if we can't find a common underlying object through
6746   // recursion.
6747   bool First = true;
6748   const Value *FirstObject = getUnderlyingObject(V);
6749   do {
6750     const Value *P = Worklist.pop_back_val();
6751     P = First ? FirstObject : getUnderlyingObject(P);
6752     First = false;
6753 
6754     if (!Visited.insert(P).second)
6755       continue;
6756 
6757     if (Visited.size() == MaxVisited)
6758       return FirstObject;
6759 
6760     if (auto *SI = dyn_cast<SelectInst>(P)) {
6761       Worklist.push_back(SI->getTrueValue());
6762       Worklist.push_back(SI->getFalseValue());
6763       continue;
6764     }
6765 
6766     if (auto *PN = dyn_cast<PHINode>(P)) {
6767       append_range(Worklist, PN->incoming_values());
6768       continue;
6769     }
6770 
6771     if (!Object)
6772       Object = P;
6773     else if (Object != P)
6774       return FirstObject;
6775   } while (!Worklist.empty());
6776 
6777   return Object ? Object : FirstObject;
6778 }
6779 
6780 /// This is the function that does the work of looking through basic
6781 /// ptrtoint+arithmetic+inttoptr sequences.
6782 static const Value *getUnderlyingObjectFromInt(const Value *V) {
6783   do {
6784     if (const Operator *U = dyn_cast<Operator>(V)) {
6785       // If we find a ptrtoint, we can transfer control back to the
6786       // regular, pointer-based getUnderlyingObject walk.
6787       if (U->getOpcode() == Instruction::PtrToInt)
6788         return U->getOperand(0);
6789       // If we find an add of a constant, a multiplied value, or a phi, it's
6790       // likely that the other operand will lead us to the base
6791       // object. We don't have to worry about the case where the
6792       // object address is somehow being computed by the multiply,
6793       // because our callers only care when the result is an
6794       // identifiable object.
6795       if (U->getOpcode() != Instruction::Add ||
6796           (!isa<ConstantInt>(U->getOperand(1)) &&
6797            Operator::getOpcode(U->getOperand(1)) != Instruction::Mul &&
6798            !isa<PHINode>(U->getOperand(1))))
6799         return V;
6800       V = U->getOperand(0);
6801     } else {
6802       return V;
6803     }
6804     assert(V->getType()->isIntegerTy() && "Unexpected operand type!");
6805   } while (true);
6806 }
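// Illustrative example (added annotation, not part of the original source):
//   %i = ptrtoint ptr %base to i64
//   %j = add i64 %i, 16
//   %p = inttoptr i64 %j to ptr
// Starting from %j, the walk looks through the add, finds the ptrtoint, and
// hands %base back to the pointer-based walk.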
6807 
6808 /// This is a wrapper around getUnderlyingObjects and adds support for basic
6809 /// ptrtoint+arithmetic+inttoptr sequences.
6810 /// It returns false if an unidentified object is found by getUnderlyingObjects.
6811 bool llvm::getUnderlyingObjectsForCodeGen(const Value *V,
6812                                           SmallVectorImpl<Value *> &Objects) {
6813   SmallPtrSet<const Value *, 16> Visited;
6814   SmallVector<const Value *, 4> Working(1, V);
6815   do {
6816     V = Working.pop_back_val();
6817 
6818     SmallVector<const Value *, 4> Objs;
6819     getUnderlyingObjects(V, Objs);
6820 
6821     for (const Value *V : Objs) {
6822       if (!Visited.insert(V).second)
6823         continue;
6824       if (Operator::getOpcode(V) == Instruction::IntToPtr) {
6825         const Value *O =
6826           getUnderlyingObjectFromInt(cast<User>(V)->getOperand(0));
6827         if (O->getType()->isPointerTy()) {
6828           Working.push_back(O);
6829           continue;
6830         }
6831       }
6832       // If getUnderlyingObjects fails to find an identifiable object,
6833       // getUnderlyingObjectsForCodeGen also fails for safety.
6834       if (!isIdentifiedObject(V)) {
6835         Objects.clear();
6836         return false;
6837       }
6838       Objects.push_back(const_cast<Value *>(V));
6839     }
6840   } while (!Working.empty());
6841   return true;
6842 }
6843 
6844 AllocaInst *llvm::findAllocaForValue(Value *V, bool OffsetZero) {
6845   AllocaInst *Result = nullptr;
6846   SmallPtrSet<Value *, 4> Visited;
6847   SmallVector<Value *, 4> Worklist;
6848 
6849   auto AddWork = [&](Value *V) {
6850     if (Visited.insert(V).second)
6851       Worklist.push_back(V);
6852   };
6853 
6854   AddWork(V);
6855   do {
6856     V = Worklist.pop_back_val();
6857     assert(Visited.count(V));
6858 
6859     if (AllocaInst *AI = dyn_cast<AllocaInst>(V)) {
6860       if (Result && Result != AI)
6861         return nullptr;
6862       Result = AI;
6863     } else if (CastInst *CI = dyn_cast<CastInst>(V)) {
6864       AddWork(CI->getOperand(0));
6865     } else if (PHINode *PN = dyn_cast<PHINode>(V)) {
6866       for (Value *IncValue : PN->incoming_values())
6867         AddWork(IncValue);
6868     } else if (auto *SI = dyn_cast<SelectInst>(V)) {
6869       AddWork(SI->getTrueValue());
6870       AddWork(SI->getFalseValue());
6871     } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
6872       if (OffsetZero && !GEP->hasAllZeroIndices())
6873         return nullptr;
6874       AddWork(GEP->getPointerOperand());
6875     } else if (CallBase *CB = dyn_cast<CallBase>(V)) {
6876       Value *Returned = CB->getReturnedArgOperand();
6877       if (Returned)
6878         AddWork(Returned);
6879       else
6880         return nullptr;
6881     } else {
6882       return nullptr;
6883     }
6884   } while (!Worklist.empty());
6885 
6886   return Result;
6887 }
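// Illustrative example (added annotation, not part of the original source):
//   %a = alloca i32
//   %g = getelementptr i32, ptr %a, i64 1
// findAllocaForValue(%g) returns %a, but findAllocaForValue(%g,
// /*OffsetZero=*/true) returns nullptr because the GEP has a non-zero index.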
6888 
6889 static bool onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
6890     const Value *V, bool AllowLifetime, bool AllowDroppable) {
6891   for (const User *U : V->users()) {
6892     const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
6893     if (!II)
6894       return false;
6895 
6896     if (AllowLifetime && II->isLifetimeStartOrEnd())
6897       continue;
6898 
6899     if (AllowDroppable && II->isDroppable())
6900       continue;
6901 
6902     return false;
6903   }
6904   return true;
6905 }
6906 
6907 bool llvm::onlyUsedByLifetimeMarkers(const Value *V) {
6908   return onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
6909       V, /* AllowLifetime */ true, /* AllowDroppable */ false);
6910 }
6911 bool llvm::onlyUsedByLifetimeMarkersOrDroppableInsts(const Value *V) {
6912   return onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
6913       V, /* AllowLifetime */ true, /* AllowDroppable */ true);
6914 }
6915 
6916 bool llvm::isNotCrossLaneOperation(const Instruction *I) {
6917   if (auto *II = dyn_cast<IntrinsicInst>(I))
6918     return isTriviallyVectorizable(II->getIntrinsicID());
6919   auto *Shuffle = dyn_cast<ShuffleVectorInst>(I);
6920   return (!Shuffle || Shuffle->isSelect()) &&
6921          !isa<CallBase, BitCastInst, ExtractElementInst>(I);
6922 }
6923 
6924 bool llvm::isSafeToSpeculativelyExecute(
6925     const Instruction *Inst, const Instruction *CtxI, AssumptionCache *AC,
6926     const DominatorTree *DT, const TargetLibraryInfo *TLI, bool UseVariableInfo,
6927     bool IgnoreUBImplyingAttrs) {
6928   return isSafeToSpeculativelyExecuteWithOpcode(Inst->getOpcode(), Inst, CtxI,
6929                                                 AC, DT, TLI, UseVariableInfo,
6930                                                 IgnoreUBImplyingAttrs);
6931 }
6932 
6933 bool llvm::isSafeToSpeculativelyExecuteWithOpcode(
6934     unsigned Opcode, const Instruction *Inst, const Instruction *CtxI,
6935     AssumptionCache *AC, const DominatorTree *DT, const TargetLibraryInfo *TLI,
6936     bool UseVariableInfo, bool IgnoreUBImplyingAttrs) {
6937 #ifndef NDEBUG
6938   if (Inst->getOpcode() != Opcode) {
6939     // Check that the operands are actually compatible with the Opcode override.
6940     auto hasEqualReturnAndLeadingOperandTypes =
6941         [](const Instruction *Inst, unsigned NumLeadingOperands) {
6942           if (Inst->getNumOperands() < NumLeadingOperands)
6943             return false;
6944           const Type *ExpectedType = Inst->getType();
6945           for (unsigned ItOp = 0; ItOp < NumLeadingOperands; ++ItOp)
6946             if (Inst->getOperand(ItOp)->getType() != ExpectedType)
6947               return false;
6948           return true;
6949         };
6950     assert(!Instruction::isBinaryOp(Opcode) ||
6951            hasEqualReturnAndLeadingOperandTypes(Inst, 2));
6952     assert(!Instruction::isUnaryOp(Opcode) ||
6953            hasEqualReturnAndLeadingOperandTypes(Inst, 1));
6954   }
6955 #endif
6956 
6957   switch (Opcode) {
6958   default:
6959     return true;
6960   case Instruction::UDiv:
6961   case Instruction::URem: {
6962     // x / y is undefined if y == 0.
6963     const APInt *V;
6964     if (match(Inst->getOperand(1), m_APInt(V)))
6965       return *V != 0;
6966     return false;
6967   }
6968   case Instruction::SDiv:
6969   case Instruction::SRem: {
6970     // x / y is undefined if y == 0, or if x == INT_MIN and y == -1.
6971     const APInt *Numerator, *Denominator;
6972     if (!match(Inst->getOperand(1), m_APInt(Denominator)))
6973       return false;
6974     // We cannot hoist this division if the denominator is 0.
6975     if (*Denominator == 0)
6976       return false;
6977     // It's safe to hoist if the denominator is not 0 or -1.
6978     if (!Denominator->isAllOnes())
6979       return true;
6980     // At this point we know that the denominator is -1.  It is safe to hoist
6981     // as long as we know that the numerator is not INT_MIN.
6982     if (match(Inst->getOperand(0), m_APInt(Numerator)))
6983       return !Numerator->isMinSignedValue();
6984     // The numerator *might* be MinSignedValue.
6985     return false;
6986   }
6987   case Instruction::Load: {
6988     if (!UseVariableInfo)
6989       return false;
6990 
6991     const LoadInst *LI = dyn_cast<LoadInst>(Inst);
6992     if (!LI)
6993       return false;
6994     if (mustSuppressSpeculation(*LI))
6995       return false;
6996     const DataLayout &DL = LI->getDataLayout();
6997     return isDereferenceableAndAlignedPointer(LI->getPointerOperand(),
6998                                               LI->getType(), LI->getAlign(), DL,
6999                                               CtxI, AC, DT, TLI);
7000   }
7001   case Instruction::Call: {
7002     auto *CI = dyn_cast<const CallInst>(Inst);
7003     if (!CI)
7004       return false;
7005     const Function *Callee = CI->getCalledFunction();
7006 
7007     // The called function could have undefined behavior or side-effects, even
7008     // if marked readnone nounwind.
7009     if (!Callee || !Callee->isSpeculatable())
7010       return false;
7011     // Since the operands may be changed after hoisting, undefined behavior may
7012     // be triggered by some UB-implying attributes.
7013     return IgnoreUBImplyingAttrs || !CI->hasUBImplyingAttrs();
7014   }
7015   case Instruction::VAArg:
7016   case Instruction::Alloca:
7017   case Instruction::Invoke:
7018   case Instruction::CallBr:
7019   case Instruction::PHI:
7020   case Instruction::Store:
7021   case Instruction::Ret:
7022   case Instruction::Br:
7023   case Instruction::IndirectBr:
7024   case Instruction::Switch:
7025   case Instruction::Unreachable:
7026   case Instruction::Fence:
7027   case Instruction::AtomicRMW:
7028   case Instruction::AtomicCmpXchg:
7029   case Instruction::LandingPad:
7030   case Instruction::Resume:
7031   case Instruction::CatchSwitch:
7032   case Instruction::CatchPad:
7033   case Instruction::CatchRet:
7034   case Instruction::CleanupPad:
7035   case Instruction::CleanupRet:
7036     return false; // Misc instructions which have effects
7037   }
7038 }
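// Illustrative example (added annotation, not part of the original source):
// `sdiv i32 %x, 7` is safe to speculate (the denominator is a constant that
// is neither 0 nor -1), whereas `sdiv i32 %x, %y` is not, because %y might
// be 0 at runtime.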
7039 
7040 bool llvm::mayHaveNonDefUseDependency(const Instruction &I) {
7041   if (I.mayReadOrWriteMemory())
7042     // Memory dependency possible
7043     return true;
7044   if (!isSafeToSpeculativelyExecute(&I))
7045     // Can't move above a maythrow call or infinite loop.  Or if an
7046     // inalloca alloca, above a stacksave call.
7047     return true;
7048   if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7049     // 1) Can't reorder two inf-loop calls, even if readonly
7050     // 2) Also can't reorder an inf-loop call below an instruction which isn't
7051     //    safe to speculatively execute.  (Inverse of above)
7052     return true;
7053   return false;
7054 }
7055 
7056 /// Convert ConstantRange OverflowResult into ValueTracking OverflowResult.
7057 static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
7058   switch (OR) {
7059     case ConstantRange::OverflowResult::MayOverflow:
7060       return OverflowResult::MayOverflow;
7061     case ConstantRange::OverflowResult::AlwaysOverflowsLow:
7062       return OverflowResult::AlwaysOverflowsLow;
7063     case ConstantRange::OverflowResult::AlwaysOverflowsHigh:
7064       return OverflowResult::AlwaysOverflowsHigh;
7065     case ConstantRange::OverflowResult::NeverOverflows:
7066       return OverflowResult::NeverOverflows;
7067   }
7068   llvm_unreachable("Unknown OverflowResult");
7069 }
7070 
7071 /// Combine constant ranges from computeConstantRange() and computeKnownBits().
7072 ConstantRange
7073 llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
7074                                              bool ForSigned,
7075                                              const SimplifyQuery &SQ) {
7076   ConstantRange CR1 =
7077       ConstantRange::fromKnownBits(V.getKnownBits(SQ), ForSigned);
7078   ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo);
7079   ConstantRange::PreferredRangeType RangeType =
7080       ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
7081   return CR1.intersectWith(CR2, RangeType);
7082 }
7083 
7084 OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
7085                                                    const Value *RHS,
7086                                                    const SimplifyQuery &SQ,
7087                                                    bool IsNSW) {
7088   KnownBits LHSKnown = computeKnownBits(LHS, SQ);
7089   KnownBits RHSKnown = computeKnownBits(RHS, SQ);
7090 
7091   // mul nsw of two non-negative numbers is also nuw.
7092   if (IsNSW && LHSKnown.isNonNegative() && RHSKnown.isNonNegative())
7093     return OverflowResult::NeverOverflows;
7094 
7095   ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false);
7096   ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false);
7097   return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange));
7098 }
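// Illustrative worked example (added annotation, not part of the original
// source): on i16, if the known bits prove LHS <= 255 and RHS <= 255, then
// LHS * RHS <= 65025 < 2^16, so the ranges yield NeverOverflows.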
7099 
7100 OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
7101                                                  const Value *RHS,
7102                                                  const SimplifyQuery &SQ) {
7103   // Multiplying n * m significant bits yields a result of n + m significant
7104   // bits. If the total number of significant bits does not exceed the
7105   // result bit width (minus 1), there is no overflow.
7106   // This means if we have enough leading sign bits in the operands
7107   // we can guarantee that the result does not overflow.
7108   // Ref: "Hacker's Delight" by Henry Warren
7109   unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
7110 
7111   // Note that underestimating the number of sign bits gives a more
7112   // conservative answer.
7113   unsigned SignBits =
7114       ::ComputeNumSignBits(LHS, SQ) + ::ComputeNumSignBits(RHS, SQ);
7115 
7116   // First handle the easy case: if we have enough sign bits there's
7117   // definitely no overflow.
7118   if (SignBits > BitWidth + 1)
7119     return OverflowResult::NeverOverflows;
7120 
7121   // There are two ambiguous cases where there can be no overflow:
7122   //   SignBits == BitWidth + 1    and
7123   //   SignBits == BitWidth
7124   // The second case is difficult to check, therefore we only handle the
7125   // first case.
7126   if (SignBits == BitWidth + 1) {
7127     // It overflows only when both arguments are negative and the true
7128     // product is exactly the minimum negative number.
7129     // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
7130     // For simplicity we just check if at least one side is not negative.
7131     KnownBits LHSKnown = computeKnownBits(LHS, SQ);
7132     KnownBits RHSKnown = computeKnownBits(RHS, SQ);
7133     if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
7134       return OverflowResult::NeverOverflows;
7135   }
7136   return OverflowResult::MayOverflow;
7137 }
7138 
7139 OverflowResult
7140 llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
7141                                     const WithCache<const Value *> &RHS,
7142                                     const SimplifyQuery &SQ) {
7143   ConstantRange LHSRange =
7144       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
7145   ConstantRange RHSRange =
7146       computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
7147   return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange));
7148 }
7149 
7150 static OverflowResult
7151 computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
7152                             const WithCache<const Value *> &RHS,
7153                             const AddOperator *Add, const SimplifyQuery &SQ) {
7154   if (Add && Add->hasNoSignedWrap()) {
7155     return OverflowResult::NeverOverflows;
7156   }
7157 
7158   // If LHS and RHS each have at least two sign bits, the addition will look
7159   // like
7160   //
7161   // XX..... +
7162   // YY.....
7163   //
7164   // If the carry into the most significant position is 0, X and Y can't both
7165   // be 1 and therefore the carry out of the addition is also 0.
7166   //
7167   // If the carry into the most significant position is 1, X and Y can't both
7168   // be 0 and therefore the carry out of the addition is also 1.
7169   //
7170   // Since the carry into the most significant position is always equal to
7171   // the carry out of the addition, there is no signed overflow.
7172   if (::ComputeNumSignBits(LHS, SQ) > 1 && ::ComputeNumSignBits(RHS, SQ) > 1)
7173     return OverflowResult::NeverOverflows;
7174 
7175   ConstantRange LHSRange =
7176       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/true, SQ);
7177   ConstantRange RHSRange =
7178       computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/true, SQ);
7179   OverflowResult OR =
7180       mapOverflowResult(LHSRange.signedAddMayOverflow(RHSRange));
7181   if (OR != OverflowResult::MayOverflow)
7182     return OR;
7183 
7184   // The remaining code needs Add to be available. Return early if it is not.
7185   if (!Add)
7186     return OverflowResult::MayOverflow;
7187 
7188   // If the sign of Add is the same as at least one of the operands, this add
7189   // CANNOT overflow. If this can be determined from the known bits of the
7190   // operands the above signedAddMayOverflow() check will have already done so.
7191   // The only other way to improve on the known bits is from an assumption, so
7192   // call computeKnownBitsFromContext() directly.
7193   bool LHSOrRHSKnownNonNegative =
7194       (LHSRange.isAllNonNegative() || RHSRange.isAllNonNegative());
7195   bool LHSOrRHSKnownNegative =
7196       (LHSRange.isAllNegative() || RHSRange.isAllNegative());
7197   if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) {
7198     KnownBits AddKnown(LHSRange.getBitWidth());
7199     computeKnownBitsFromContext(Add, AddKnown, SQ);
7200     if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) ||
7201         (AddKnown.isNegative() && LHSOrRHSKnownNegative))
7202       return OverflowResult::NeverOverflows;
7203   }
7204 
7205   return OverflowResult::MayOverflow;
7206 }
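// Illustrative worked example (added annotation, not part of the original
// source): on i8, values with at least two sign bits lie in [-64, 63], so
// their sum lies in [-128, 126] and cannot wrap around the i8 range.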
7207 
7208 OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
7209                                                    const Value *RHS,
7210                                                    const SimplifyQuery &SQ) {
7211   // X - (X % ?)
7212   // The remainder of a value can't have greater magnitude than itself,
7213   // so the subtraction can't overflow.
7214 
7215   // X - (X -nuw ?)
7216   // In the minimal case, this would simplify to "?", so there's no subtract
7217   // at all. But if this analysis is used to peek through casts, for example,
7218   // then determining no-overflow may allow other transforms.
7219 
7220   // TODO: There are other patterns like this.
7221   //       See simplifyICmpWithBinOpOnLHS() for candidates.
7222   if (match(RHS, m_URem(m_Specific(LHS), m_Value())) ||
7223       match(RHS, m_NUWSub(m_Specific(LHS), m_Value())))
7224     if (isGuaranteedNotToBeUndef(LHS, SQ.AC, SQ.CxtI, SQ.DT))
7225       return OverflowResult::NeverOverflows;
7226 
7227   if (auto C = isImpliedByDomCondition(CmpInst::ICMP_UGE, LHS, RHS, SQ.CxtI,
7228                                        SQ.DL)) {
7229     if (*C)
7230       return OverflowResult::NeverOverflows;
7231     return OverflowResult::AlwaysOverflowsLow;
7232   }
7233 
7234   ConstantRange LHSRange =
7235       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
7236   ConstantRange RHSRange =
7237       computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
7238   return mapOverflowResult(LHSRange.unsignedSubMayOverflow(RHSRange));
7239 }
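// Illustrative example (added annotation, not part of the original source):
//   %r = urem i32 %x, %y
//   %d = sub i32 %x, %r
// The remainder can never exceed %x, so the subtraction is reported as
// NeverOverflows (provided %x is known not to be undef).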
7240 
7241 OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
7242                                                  const Value *RHS,
7243                                                  const SimplifyQuery &SQ) {
7244   // X - (X % ?)
7245   // The remainder of a value can't have greater magnitude than itself,
7246   // so the subtraction can't overflow.
7247 
7248   // X - (X -nsw ?)
7249   // In the minimal case, this would simplify to "?", so there's no subtract
7250   // at all. But if this analysis is used to peek through casts, for example,
7251   // then determining no-overflow may allow other transforms.
7252   if (match(RHS, m_SRem(m_Specific(LHS), m_Value())) ||
7253       match(RHS, m_NSWSub(m_Specific(LHS), m_Value())))
7254     if (isGuaranteedNotToBeUndef(LHS, SQ.AC, SQ.CxtI, SQ.DT))
7255       return OverflowResult::NeverOverflows;
7256 
7257   // If LHS and RHS each have at least two sign bits, the subtraction
7258   // cannot overflow.
7259   if (::ComputeNumSignBits(LHS, SQ) > 1 && ::ComputeNumSignBits(RHS, SQ) > 1)
7260     return OverflowResult::NeverOverflows;
7261 
7262   ConstantRange LHSRange =
7263       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/true, SQ);
7264   ConstantRange RHSRange =
7265       computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/true, SQ);
7266   return mapOverflowResult(LHSRange.signedSubMayOverflow(RHSRange));
7267 }
7268 
7269 bool llvm::isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
7270                                      const DominatorTree &DT) {
7271   SmallVector<const BranchInst *, 2> GuardingBranches;
7272   SmallVector<const ExtractValueInst *, 2> Results;
7273 
7274   for (const User *U : WO->users()) {
7275     if (const auto *EVI = dyn_cast<ExtractValueInst>(U)) {
7276       assert(EVI->getNumIndices() == 1 && "Obvious from WO's type");
7277 
7278       if (EVI->getIndices()[0] == 0)
7279         Results.push_back(EVI);
7280       else {
7281         assert(EVI->getIndices()[0] == 1 && "Obvious from WO's type");
7282 
7283         for (const auto *U : EVI->users())
7284           if (const auto *B = dyn_cast<BranchInst>(U)) {
7285             assert(B->isConditional() && "How else is it using an i1?");
7286             GuardingBranches.push_back(B);
7287           }
7288       }
7289     } else {
7290       // We are using the aggregate directly in a way we don't want to analyze
7291       // here (storing it to a global, say).
7292       return false;
7293     }
7294   }
7295 
7296   auto AllUsesGuardedByBranch = [&](const BranchInst *BI) {
7297     BasicBlockEdge NoWrapEdge(BI->getParent(), BI->getSuccessor(1));
7298     if (!NoWrapEdge.isSingleEdge())
7299       return false;
7300 
7301     // Check if all users of the add are provably no-wrap.
7302     for (const auto *Result : Results) {
7303       // If the extractvalue itself is not executed on overflow, then we don't
7304       // need to check each use separately, since domination is transitive.
7305       if (DT.dominates(NoWrapEdge, Result->getParent()))
7306         continue;
7307 
7308       for (const auto &RU : Result->uses())
7309         if (!DT.dominates(NoWrapEdge, RU))
7310           return false;
7311     }
7312 
7313     return true;
7314   };
7315 
7316   return llvm::any_of(GuardingBranches, AllUsesGuardedByBranch);
7317 }
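// Illustrative example (added annotation, not part of the original source):
//   %s = call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a, i32 %b)
//   %ov = extractvalue { i32, i1 } %s, 1
//   br i1 %ov, label %trap, label %ok
// Uses of the math result (extractvalue index 0) dominated by the edge to
// %ok are known not to have wrapped.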
7318 
7319 /// Shifts return poison if the shift amount is not less than the bitwidth.
7320 static bool shiftAmountKnownInRange(const Value *ShiftAmount) {
7321   auto *C = dyn_cast<Constant>(ShiftAmount);
7322   if (!C)
7323     return false;
7324 
7325   // Shifts return poison if the shift amount is not less than the bitwidth.
7326   SmallVector<const Constant *, 4> ShiftAmounts;
7327   if (auto *FVTy = dyn_cast<FixedVectorType>(C->getType())) {
7328     unsigned NumElts = FVTy->getNumElements();
7329     for (unsigned i = 0; i < NumElts; ++i)
7330       ShiftAmounts.push_back(C->getAggregateElement(i));
7331   } else if (isa<ScalableVectorType>(C->getType()))
7332     return false; // Can't tell, just return false to be safe
7333   else
7334     ShiftAmounts.push_back(C);
7335 
7336   bool Safe = llvm::all_of(ShiftAmounts, [](const Constant *C) {
7337     auto *CI = dyn_cast_or_null<ConstantInt>(C);
7338     return CI && CI->getValue().ult(C->getType()->getIntegerBitWidth());
7339   });
7340 
7341   return Safe;
7342 }
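// Illustrative example (added annotation, not part of the original source):
// the shift amount in `shl i32 %x, 31` is known to be in range, whereas
// `shl <2 x i32> %v, <i32 1, i32 32>` is not, because lane 1 shifts by the
// full bitwidth and therefore yields poison.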
7343 
7344 enum class UndefPoisonKind {
7345   PoisonOnly = (1 << 0),
7346   UndefOnly = (1 << 1),
7347   UndefOrPoison = PoisonOnly | UndefOnly,
7348 };
7349 
7350 static bool includesPoison(UndefPoisonKind Kind) {
7351   return (unsigned(Kind) & unsigned(UndefPoisonKind::PoisonOnly)) != 0;
7352 }
7353 
7354 static bool includesUndef(UndefPoisonKind Kind) {
7355   return (unsigned(Kind) & unsigned(UndefPoisonKind::UndefOnly)) != 0;
7356 }
7357 
7358 static bool canCreateUndefOrPoison(const Operator *Op, UndefPoisonKind Kind,
7359                                    bool ConsiderFlagsAndMetadata) {
7360 
7361   if (ConsiderFlagsAndMetadata && includesPoison(Kind) &&
7362       Op->hasPoisonGeneratingAnnotations())
7363     return true;
7364 
7365   unsigned Opcode = Op->getOpcode();
7366 
7367   // Check whether opcode is a poison/undef-generating operation
7368   switch (Opcode) {
7369   case Instruction::Shl:
7370   case Instruction::AShr:
7371   case Instruction::LShr:
7372     return includesPoison(Kind) && !shiftAmountKnownInRange(Op->getOperand(1));
7373   case Instruction::FPToSI:
7374   case Instruction::FPToUI:
7375     // fptosi/ui yields poison if the resulting value does not fit in the
7376     // destination type.
7377     return true;
7378   case Instruction::Call:
7379     if (auto *II = dyn_cast<IntrinsicInst>(Op)) {
7380       switch (II->getIntrinsicID()) {
7381       // TODO: Add more intrinsics.
7382       case Intrinsic::ctlz:
7383       case Intrinsic::cttz:
7384       case Intrinsic::abs:
7385         if (cast<ConstantInt>(II->getArgOperand(1))->isNullValue())
7386           return false;
7387         break;
7388       case Intrinsic::ctpop:
7389       case Intrinsic::bswap:
7390       case Intrinsic::bitreverse:
7391       case Intrinsic::fshl:
7392       case Intrinsic::fshr:
7393       case Intrinsic::smax:
7394       case Intrinsic::smin:
7395       case Intrinsic::umax:
7396       case Intrinsic::umin:
7397       case Intrinsic::ptrmask:
7398       case Intrinsic::fptoui_sat:
7399       case Intrinsic::fptosi_sat:
7400       case Intrinsic::sadd_with_overflow:
7401       case Intrinsic::ssub_with_overflow:
7402       case Intrinsic::smul_with_overflow:
7403       case Intrinsic::uadd_with_overflow:
7404       case Intrinsic::usub_with_overflow:
7405       case Intrinsic::umul_with_overflow:
7406       case Intrinsic::sadd_sat:
7407       case Intrinsic::uadd_sat:
7408       case Intrinsic::ssub_sat:
7409       case Intrinsic::usub_sat:
7410         return false;
7411       case Intrinsic::sshl_sat:
7412       case Intrinsic::ushl_sat:
7413         return includesPoison(Kind) &&
7414                !shiftAmountKnownInRange(II->getArgOperand(1));
7415       case Intrinsic::fma:
7416       case Intrinsic::fmuladd:
7417       case Intrinsic::sqrt:
7418       case Intrinsic::powi:
7419       case Intrinsic::sin:
7420       case Intrinsic::cos:
7421       case Intrinsic::pow:
7422       case Intrinsic::log:
7423       case Intrinsic::log10:
7424       case Intrinsic::log2:
7425       case Intrinsic::exp:
7426       case Intrinsic::exp2:
7427       case Intrinsic::exp10:
7428       case Intrinsic::fabs:
7429       case Intrinsic::copysign:
7430       case Intrinsic::floor:
7431       case Intrinsic::ceil:
7432       case Intrinsic::trunc:
7433       case Intrinsic::rint:
7434       case Intrinsic::nearbyint:
7435       case Intrinsic::round:
7436       case Intrinsic::roundeven:
7437       case Intrinsic::fptrunc_round:
7438       case Intrinsic::canonicalize:
7439       case Intrinsic::arithmetic_fence:
7440       case Intrinsic::minnum:
7441       case Intrinsic::maxnum:
7442       case Intrinsic::minimum:
7443       case Intrinsic::maximum:
7444       case Intrinsic::minimumnum:
7445       case Intrinsic::maximumnum:
7446       case Intrinsic::is_fpclass:
7447       case Intrinsic::ldexp:
7448       case Intrinsic::frexp:
7449         return false;
7450       case Intrinsic::lround:
7451       case Intrinsic::llround:
7452       case Intrinsic::lrint:
7453       case Intrinsic::llrint:
7454         // If the value doesn't fit, an unspecified value is returned (but this
7455         // is not poison).
7456         return false;
7457       }
7458     }
7459     [[fallthrough]];
7460   case Instruction::CallBr:
7461   case Instruction::Invoke: {
7462     const auto *CB = cast<CallBase>(Op);
7463     return !CB->hasRetAttr(Attribute::NoUndef);
7464   }
7465   case Instruction::InsertElement:
7466   case Instruction::ExtractElement: {
7467     // If the index exceeds the length of the vector, the result is poison.
7468     auto *VTy = cast<VectorType>(Op->getOperand(0)->getType());
7469     unsigned IdxOp = Op->getOpcode() == Instruction::InsertElement ? 2 : 1;
7470     auto *Idx = dyn_cast<ConstantInt>(Op->getOperand(IdxOp));
7471     if (includesPoison(Kind))
7472       return !Idx ||
7473              Idx->getValue().uge(VTy->getElementCount().getKnownMinValue());
7474     return false;
7475   }
7476   case Instruction::ShuffleVector: {
7477     ArrayRef<int> Mask = isa<ConstantExpr>(Op)
7478                              ? cast<ConstantExpr>(Op)->getShuffleMask()
7479                              : cast<ShuffleVectorInst>(Op)->getShuffleMask();
7480     return includesPoison(Kind) && is_contained(Mask, PoisonMaskElem);
7481   }
7482   case Instruction::FNeg:
7483   case Instruction::PHI:
7484   case Instruction::Select:
7485   case Instruction::ExtractValue:
7486   case Instruction::InsertValue:
7487   case Instruction::Freeze:
7488   case Instruction::ICmp:
7489   case Instruction::FCmp:
7490   case Instruction::GetElementPtr:
7491     return false;
7492   case Instruction::AddrSpaceCast:
7493     return true;
7494   default: {
7495     const auto *CE = dyn_cast<ConstantExpr>(Op);
7496     if (isa<CastInst>(Op) || (CE && CE->isCast()))
7497       return false;
7498     else if (Instruction::isBinaryOp(Opcode))
7499       return false;
7500     // Be conservative and return true.
7501     return true;
7502   }
7503   }
7504 }
7505 
7506 bool llvm::canCreateUndefOrPoison(const Operator *Op,
7507                                   bool ConsiderFlagsAndMetadata) {
7508   return ::canCreateUndefOrPoison(Op, UndefPoisonKind::UndefOrPoison,
7509                                   ConsiderFlagsAndMetadata);
7510 }
7511 
7512 bool llvm::canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata) {
7513   return ::canCreateUndefOrPoison(Op, UndefPoisonKind::PoisonOnly,
7514                                   ConsiderFlagsAndMetadata);
7515 }
7516 
7517 static bool directlyImpliesPoison(const Value *ValAssumedPoison, const Value *V,
7518                                   unsigned Depth) {
7519   if (ValAssumedPoison == V)
7520     return true;
7521 
7522   const unsigned MaxDepth = 2;
7523   if (Depth >= MaxDepth)
7524     return false;
7525 
7526   if (const auto *I = dyn_cast<Instruction>(V)) {
7527     if (any_of(I->operands(), [=](const Use &Op) {
7528           return propagatesPoison(Op) &&
7529                  directlyImpliesPoison(ValAssumedPoison, Op, Depth + 1);
7530         }))
7531       return true;
7532 
7533     // V  = extractvalue V0, idx
7534     // V2 = extractvalue V0, idx2
7535     // V0's elements are either all poison or all non-poison (e.g., add_with_overflow)
7536     const WithOverflowInst *II;
7537     if (match(I, m_ExtractValue(m_WithOverflowInst(II))) &&
7538         (match(ValAssumedPoison, m_ExtractValue(m_Specific(II))) ||
7539          llvm::is_contained(II->args(), ValAssumedPoison)))
7540       return true;
7541   }
7542   return false;
7543 }
7544 
7545 static bool impliesPoison(const Value *ValAssumedPoison, const Value *V,
7546                           unsigned Depth) {
7547   if (isGuaranteedNotToBePoison(ValAssumedPoison))
7548     return true;
7549 
7550   if (directlyImpliesPoison(ValAssumedPoison, V, /* Depth */ 0))
7551     return true;
7552 
7553   const unsigned MaxDepth = 2;
7554   if (Depth >= MaxDepth)
7555     return false;
7556 
7557   const auto *I = dyn_cast<Instruction>(ValAssumedPoison);
7558   if (I && !canCreatePoison(cast<Operator>(I))) {
7559     return all_of(I->operands(), [=](const Value *Op) {
7560       return impliesPoison(Op, V, Depth + 1);
7561     });
7562   }
7563   return false;
7564 }
7565 
7566 bool llvm::impliesPoison(const Value *ValAssumedPoison, const Value *V) {
7567   return ::impliesPoison(ValAssumedPoison, V, /* Depth */ 0);
7568 }
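// Illustrative example (added annotation, not part of the original source):
// given
//   %add = add i32 %x, 1
// impliesPoison(%x, %add) is true: add propagates poison from its operands,
// so %add is poison whenever %x is.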
7569 
7570 static bool programUndefinedIfUndefOrPoison(const Value *V, bool PoisonOnly);
7571 
7572 static bool isGuaranteedNotToBeUndefOrPoison(
7573     const Value *V, AssumptionCache *AC, const Instruction *CtxI,
7574     const DominatorTree *DT, unsigned Depth, UndefPoisonKind Kind) {
7575   if (Depth >= MaxAnalysisRecursionDepth)
7576     return false;
7577 
7578   if (isa<MetadataAsValue>(V))
7579     return false;
7580 
7581   if (const auto *A = dyn_cast<Argument>(V)) {
7582     if (A->hasAttribute(Attribute::NoUndef) ||
7583         A->hasAttribute(Attribute::Dereferenceable) ||
7584         A->hasAttribute(Attribute::DereferenceableOrNull))
7585       return true;
7586   }
7587 
7588   if (auto *C = dyn_cast<Constant>(V)) {
7589     if (isa<PoisonValue>(C))
7590       return !includesPoison(Kind);
7591 
7592     if (isa<UndefValue>(C))
7593       return !includesUndef(Kind);
7594 
7595     if (isa<ConstantInt>(C) || isa<GlobalVariable>(C) || isa<ConstantFP>(C) ||
7596         isa<ConstantPointerNull>(C) || isa<Function>(C))
7597       return true;
7598 
7599     if (C->getType()->isVectorTy()) {
7600       if (isa<ConstantExpr>(C)) {
7601         // Scalable vectors can use a ConstantExpr to build a splat.
7602         if (Constant *SplatC = C->getSplatValue())
7603           if (isa<ConstantInt>(SplatC) || isa<ConstantFP>(SplatC))
7604             return true;
7605       } else {
7606         if (includesUndef(Kind) && C->containsUndefElement())
7607           return false;
7608         if (includesPoison(Kind) && C->containsPoisonElement())
7609           return false;
7610         return !C->containsConstantExpression();
7611       }
7612     }
7613   }
7614 
7615   // Strip cast operations from a pointer value.
7616   // Note that stripPointerCastsSameRepresentation can strip off getelementptr
7617   // inbounds with zero offset. To guarantee that the result isn't poison, the
7618   // stripped pointer is checked: it has to point into an allocated object or
7619   // be null, which ensures that `inbounds` getelementptrs with a zero offset
7620   // could not have produced poison.
7621   // It can also strip off addrspacecasts that do not change the bit
7622   // representation; we believe such an addrspacecast is equivalent to a no-op.
7623   auto *StrippedV = V->stripPointerCastsSameRepresentation();
7624   if (isa<AllocaInst>(StrippedV) || isa<GlobalVariable>(StrippedV) ||
7625       isa<Function>(StrippedV) || isa<ConstantPointerNull>(StrippedV))
7626     return true;
7627 
7628   auto OpCheck = [&](const Value *V) {
7629     return isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth + 1, Kind);
7630   };
7631 
7632   if (auto *Opr = dyn_cast<Operator>(V)) {
7633     // If the value is a freeze instruction, then it can never
7634     // be undef or poison.
7635     if (isa<FreezeInst>(V))
7636       return true;
7637 
7638     if (const auto *CB = dyn_cast<CallBase>(V)) {
7639       if (CB->hasRetAttr(Attribute::NoUndef) ||
7640           CB->hasRetAttr(Attribute::Dereferenceable) ||
7641           CB->hasRetAttr(Attribute::DereferenceableOrNull))
7642         return true;
7643     }
7644 
7645     if (const auto *PN = dyn_cast<PHINode>(V)) {
7646       unsigned Num = PN->getNumIncomingValues();
7647       bool IsWellDefined = true;
7648       for (unsigned i = 0; i < Num; ++i) {
7649         if (PN == PN->getIncomingValue(i))
7650           continue;
7651         auto *TI = PN->getIncomingBlock(i)->getTerminator();
7652         if (!isGuaranteedNotToBeUndefOrPoison(PN->getIncomingValue(i), AC, TI,
7653                                               DT, Depth + 1, Kind)) {
7654           IsWellDefined = false;
7655           break;
7656         }
7657       }
7658       if (IsWellDefined)
7659         return true;
7660     } else if (!::canCreateUndefOrPoison(Opr, Kind,
7661                                          /*ConsiderFlagsAndMetadata*/ true) &&
7662                all_of(Opr->operands(), OpCheck))
7663       return true;
7664   }
7665 
7666   if (auto *I = dyn_cast<LoadInst>(V))
7667     if (I->hasMetadata(LLVMContext::MD_noundef) ||
7668         I->hasMetadata(LLVMContext::MD_dereferenceable) ||
7669         I->hasMetadata(LLVMContext::MD_dereferenceable_or_null))
7670       return true;
7671 
7672   if (programUndefinedIfUndefOrPoison(V, !includesUndef(Kind)))
7673     return true;
7674 
7675   // CxtI may be null or a cloned instruction.
7676   if (!CtxI || !CtxI->getParent() || !DT)
7677     return false;
7678 
7679   auto *DNode = DT->getNode(CtxI->getParent());
7680   if (!DNode)
7681     // Unreachable block
7682     return false;
7683 
7684   // If V is used as a branch condition before reaching CtxI, V cannot be
7685   // undef or poison.
7686   //   br V, BB1, BB2
7687   // BB1:
7688   //   CtxI ; V cannot be undef or poison here
7689   auto *Dominator = DNode->getIDom();
7690   // This check is purely for compile time reasons: we can skip the IDom walk
7691   // if what we are checking for includes undef and the value is not an integer.
7692   if (!includesUndef(Kind) || V->getType()->isIntegerTy())
7693     while (Dominator) {
7694       auto *TI = Dominator->getBlock()->getTerminator();
7695 
7696       Value *Cond = nullptr;
7697       if (auto BI = dyn_cast_or_null<BranchInst>(TI)) {
7698         if (BI->isConditional())
7699           Cond = BI->getCondition();
7700       } else if (auto SI = dyn_cast_or_null<SwitchInst>(TI)) {
7701         Cond = SI->getCondition();
7702       }
7703 
7704       if (Cond) {
7705         if (Cond == V)
7706           return true;
7707         else if (!includesUndef(Kind) && isa<Operator>(Cond)) {
7708           // For poison, we can analyze further
7709           auto *Opr = cast<Operator>(Cond);
7710           if (any_of(Opr->operands(), [V](const Use &U) {
7711                 return V == U && propagatesPoison(U);
7712               }))
7713             return true;
7714         }
7715       }
7716 
7717       Dominator = Dominator->getIDom();
7718     }
7719 
7720   if (AC && getKnowledgeValidInContext(V, {Attribute::NoUndef}, *AC, CtxI, DT))
7721     return true;
7722 
7723   return false;
7724 }
7725 
7726 bool llvm::isGuaranteedNotToBeUndefOrPoison(const Value *V, AssumptionCache *AC,
7727                                             const Instruction *CtxI,
7728                                             const DominatorTree *DT,
7729                                             unsigned Depth) {
7730   return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
7731                                             UndefPoisonKind::UndefOrPoison);
7732 }
7733 
7734 bool llvm::isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC,
7735                                      const Instruction *CtxI,
7736                                      const DominatorTree *DT, unsigned Depth) {
7737   return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
7738                                             UndefPoisonKind::PoisonOnly);
7739 }
7740 
7741 bool llvm::isGuaranteedNotToBeUndef(const Value *V, AssumptionCache *AC,
7742                                     const Instruction *CtxI,
7743                                     const DominatorTree *DT, unsigned Depth) {
7744   return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
7745                                             UndefPoisonKind::UndefOnly);
7746 }
7747 
7748 /// Return true if undefined behavior would provably be executed on the path to
7749 /// OnPathTo if Root produced a posion result.  Note that this doesn't say
7750 /// anything about whether OnPathTo is actually executed or whether Root is
7751 /// actually poison.  This can be used to assess whether a new use of Root can
7752 /// be added at a location which is control equivalent with OnPathTo (such as
7753 /// immediately before it) without introducing UB which didn't previously
7754 /// exist.  Note that a false result conveys no information.
7755 bool llvm::mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
7756                                          Instruction *OnPathTo,
7757                                          DominatorTree *DT) {
7758   // Basic approach is to assume Root is poison, propagate poison forward
7759   // through all users we can easily track, and then check whether any of those
7760   // users are provable UB and must execute before out exiting block might
7761   // exit.
7762 
7763   // The set of all recursive users we've visited (which are assumed to all be
7764   // poison because of said visit)
7765   SmallSet<const Value *, 16> KnownPoison;
7766   SmallVector<const Instruction*, 16> Worklist;
7767   Worklist.push_back(Root);
7768   while (!Worklist.empty()) {
7769     const Instruction *I = Worklist.pop_back_val();
7770 
7771     // If this is known to trigger UB on a path leading to our target, we're done.
7772     if (mustTriggerUB(I, KnownPoison) && DT->dominates(I, OnPathTo))
7773       return true;
7774 
7775     // If we can't analyze propagation through this instruction, just skip it
7776     //    and its transitive users.  Safe, as false is a conservative result.
7777     if (I != Root && !any_of(I->operands(), [&KnownPoison](const Use &U) {
7778           return KnownPoison.contains(U) && propagatesPoison(U);
7779         }))
7780       continue;
7781 
7782     if (KnownPoison.insert(I).second)
7783       for (const User *User : I->users())
7784         Worklist.push_back(cast<Instruction>(User));
7785   }
7786 
7787   // Might be non-UB, or might have a path we couldn't prove must execute on
7788   // the way to the exiting bb.
7789   return false;
7790 }
7791 
7792 OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
7793                                                  const SimplifyQuery &SQ) {
7794   return ::computeOverflowForSignedAdd(Add->getOperand(0), Add->getOperand(1),
7795                                        Add, SQ);
7796 }
7797 
7798 OverflowResult
7799 llvm::computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
7800                                   const WithCache<const Value *> &RHS,
7801                                   const SimplifyQuery &SQ) {
7802   return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, SQ);
7803 }
7804 
7805 bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) {
7806   // Note: An atomic operation isn't guaranteed to return in a reasonable amount
7807   // of time because it's possible for another thread to interfere with it for an
7808   // arbitrary length of time, but programs aren't allowed to rely on that.
7809 
7810   // If there is no successor, then execution can't transfer to it.
7811   if (isa<ReturnInst>(I))
7812     return false;
7813   if (isa<UnreachableInst>(I))
7814     return false;
7815 
7816   // Note: Do not add new checks here; instead, change Instruction::mayThrow or
7817   // Instruction::willReturn.
7818   //
7819   // FIXME: Move this check into Instruction::willReturn.
7820   if (isa<CatchPadInst>(I)) {
7821     switch (classifyEHPersonality(I->getFunction()->getPersonalityFn())) {
7822     default:
7823       // A catchpad may invoke exception object constructors and such, which
7824       // in some languages can be arbitrary code, so be conservative by default.
7825       return false;
7826     case EHPersonality::CoreCLR:
7827       // For CoreCLR, it just involves a type test.
7828       return true;
7829     }
7830   }
7831 
7832   // An instruction that returns without throwing must transfer control flow
7833   // to a successor.
7834   return !I->mayThrow() && I->willReturn();
7835 }
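// Illustrative example (added annotation, not part of the original source):
// a plain `add i32 %a, %b` transfers execution to its successor, as does a
// call marked both nounwind and willreturn, while a call that may throw or
// loop forever does not.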
7836 
7837 bool llvm::isGuaranteedToTransferExecutionToSuccessor(const BasicBlock *BB) {
7838   // TODO: This is slightly conservative for invoke instructions, since exiting
7839   // via an exception *is* normal control flow for them.
7840   for (const Instruction &I : *BB)
7841     if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7842       return false;
7843   return true;
7844 }
7845 
7846 bool llvm::isGuaranteedToTransferExecutionToSuccessor(
7847    BasicBlock::const_iterator Begin, BasicBlock::const_iterator End,
7848    unsigned ScanLimit) {
7849   return isGuaranteedToTransferExecutionToSuccessor(make_range(Begin, End),
7850                                                     ScanLimit);
7851 }
7852 
7853 bool llvm::isGuaranteedToTransferExecutionToSuccessor(
7854    iterator_range<BasicBlock::const_iterator> Range, unsigned ScanLimit) {
7855   assert(ScanLimit && "scan limit must be non-zero");
7856   for (const Instruction &I : Range) {
7857     if (--ScanLimit == 0)
7858       return false;
7859     if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7860       return false;
7861   }
7862   return true;
7863 }
7864 
7865 bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I,
7866                                                   const Loop *L) {
7867   // The loop header is guaranteed to be executed for every iteration.
7868   //
7869   // FIXME: Relax this constraint to cover all basic blocks that are
7870   // guaranteed to be executed at every iteration.
7871   if (I->getParent() != L->getHeader()) return false;
7872 
7873   for (const Instruction &LI : *L->getHeader()) {
7874     if (&LI == I) return true;
7875     if (!isGuaranteedToTransferExecutionToSuccessor(&LI)) return false;
7876   }
7877   llvm_unreachable("Instruction not contained in its own parent basic block.");
7878 }
7879 
7880 bool llvm::intrinsicPropagatesPoison(Intrinsic::ID IID) {
7881   switch (IID) {
7882   // TODO: Add more intrinsics.
7883   case Intrinsic::sadd_with_overflow:
7884   case Intrinsic::ssub_with_overflow:
7885   case Intrinsic::smul_with_overflow:
7886   case Intrinsic::uadd_with_overflow:
7887   case Intrinsic::usub_with_overflow:
7888   case Intrinsic::umul_with_overflow:
7889     // If an input is a vector containing a poison element, the corresponding
7890     // lanes of the two output vectors (calculated results, overflow bits)
7891     // are poison.
7892     return true;
7893   case Intrinsic::ctpop:
7894   case Intrinsic::ctlz:
7895   case Intrinsic::cttz:
7896   case Intrinsic::abs:
7897   case Intrinsic::smax:
7898   case Intrinsic::smin:
7899   case Intrinsic::umax:
7900   case Intrinsic::umin:
7901   case Intrinsic::scmp:
7902   case Intrinsic::is_fpclass:
7903   case Intrinsic::ptrmask:
7904   case Intrinsic::ucmp:
7905   case Intrinsic::bitreverse:
7906   case Intrinsic::bswap:
7907   case Intrinsic::sadd_sat:
7908   case Intrinsic::ssub_sat:
7909   case Intrinsic::sshl_sat:
7910   case Intrinsic::uadd_sat:
7911   case Intrinsic::usub_sat:
7912   case Intrinsic::ushl_sat:
7913   case Intrinsic::smul_fix:
7914   case Intrinsic::smul_fix_sat:
7915   case Intrinsic::pow:
7916   case Intrinsic::powi:
7917   case Intrinsic::canonicalize:
7918   case Intrinsic::sqrt:
7919     return true;
7920   default:
7921     return false;
7922   }
7923 }
7924 
7925 bool llvm::propagatesPoison(const Use &PoisonOp) {
7926   const Operator *I = cast<Operator>(PoisonOp.getUser());
7927   switch (I->getOpcode()) {
7928   case Instruction::Freeze:
7929   case Instruction::PHI:
7930   case Instruction::Invoke:
7931     return false;
7932   case Instruction::Select:
7933     return PoisonOp.getOperandNo() == 0;
7934   case Instruction::Call:
7935     if (auto *II = dyn_cast<IntrinsicInst>(I))
7936       return intrinsicPropagatesPoison(II->getIntrinsicID());
7937     return false;
7938   case Instruction::ICmp:
7939   case Instruction::FCmp:
7940   case Instruction::GetElementPtr:
7941     return true;
7942   default:
7943     if (isa<BinaryOperator>(I) || isa<UnaryOperator>(I) || isa<CastInst>(I))
7944       return true;
7945 
7946     // Be conservative and return false.
7947     return false;
7948   }
7949 }
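// Illustrative example (added annotation, not part of the original source):
// for
//   %s = select i1 %c, i32 %a, i32 %b
// only the use in the condition (operand 0) propagates poison: a poison %a
// makes %s poison only when %c selects it, so that use does not qualify.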
7950 
7951 /// Enumerates all operands of \p I that are guaranteed to not be undef or
7952 /// poison. If the callback \p Handle returns true, stop processing and return
7953 /// true. Otherwise, return false.
7954 template <typename CallableT>
7955 static bool handleGuaranteedWellDefinedOps(const Instruction *I,
7956                                            const CallableT &Handle) {
7957   switch (I->getOpcode()) {
7958     case Instruction::Store:
7959       if (Handle(cast<StoreInst>(I)->getPointerOperand()))
7960         return true;
7961       break;
7962 
7963     case Instruction::Load:
7964       if (Handle(cast<LoadInst>(I)->getPointerOperand()))
7965         return true;
7966       break;
7967 
7968     // Since the dereferenceable attribute implies noundef, atomic operations
7969     // also implicitly have noundef pointers.
7970     case Instruction::AtomicCmpXchg:
7971       if (Handle(cast<AtomicCmpXchgInst>(I)->getPointerOperand()))
7972         return true;
7973       break;
7974 
7975     case Instruction::AtomicRMW:
7976       if (Handle(cast<AtomicRMWInst>(I)->getPointerOperand()))
7977         return true;
7978       break;
7979 
7980     case Instruction::Call:
7981     case Instruction::Invoke: {
7982       const CallBase *CB = cast<CallBase>(I);
7983       if (CB->isIndirectCall() && Handle(CB->getCalledOperand()))
7984         return true;
7985       for (unsigned i = 0; i < CB->arg_size(); ++i)
7986         if ((CB->paramHasAttr(i, Attribute::NoUndef) ||
7987              CB->paramHasAttr(i, Attribute::Dereferenceable) ||
7988              CB->paramHasAttr(i, Attribute::DereferenceableOrNull)) &&
7989             Handle(CB->getArgOperand(i)))
7990           return true;
7991       break;
7992     }
7993     case Instruction::Ret:
7994       if (I->getFunction()->hasRetAttribute(Attribute::NoUndef) &&
7995           Handle(I->getOperand(0)))
7996         return true;
7997       break;
7998     case Instruction::Switch:
7999       if (Handle(cast<SwitchInst>(I)->getCondition()))
8000         return true;
8001       break;
8002     case Instruction::Br: {
8003       auto *BR = cast<BranchInst>(I);
8004       if (BR->isConditional() && Handle(BR->getCondition()))
8005         return true;
8006       break;
8007     }
8008     default:
8009       break;
8010   }
8011 
8012   return false;
8013 }
8014 
8015 /// Enumerates all operands of \p I that are guaranteed to not be poison.
8016 template <typename CallableT>
8017 static bool handleGuaranteedNonPoisonOps(const Instruction *I,
8018                                          const CallableT &Handle) {
8019   if (handleGuaranteedWellDefinedOps(I, Handle))
8020     return true;
8021   switch (I->getOpcode()) {
8022   // Divisors of these operations are allowed to be partially undef.
8023   case Instruction::UDiv:
8024   case Instruction::SDiv:
8025   case Instruction::URem:
8026   case Instruction::SRem:
8027     return Handle(I->getOperand(1));
8028   default:
8029     return false;
8030   }
8031 }
8032 
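// For example (a sketch): with KnownPoison = { %d }, the instruction
//   %q = udiv i32 %x, %d
// must trigger UB, since the divisor of a udiv must not be poison.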
8033 bool llvm::mustTriggerUB(const Instruction *I,
8034                          const SmallPtrSetImpl<const Value *> &KnownPoison) {
8035   return handleGuaranteedNonPoisonOps(
8036       I, [&](const Value *V) { return KnownPoison.count(V); });
8037 }
8038 
8039 static bool programUndefinedIfUndefOrPoison(const Value *V,
8040                                             bool PoisonOnly) {
8041   // We currently only look for uses of values within the same basic
8042   // block, as that makes it easier to guarantee that the uses will be
8043   // executed given that V is executed.
8044   //
8045   // FIXME: Expand this to consider uses beyond the same basic block. To do
8046   // this, look out for the distinction between post-dominance and strong
8047   // post-dominance.
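  //
  // For example (a sketch): if V is a poison pointer and its block contains
  //   %g = getelementptr i8, ptr %V, i64 1   ; propagates poison
  //   store i8 0, ptr %g                     ; pointer operand must be well-defined
  // then executing the block triggers UB, and we can return true.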
8048   const BasicBlock *BB = nullptr;
8049   BasicBlock::const_iterator Begin;
8050   if (const auto *Inst = dyn_cast<Instruction>(V)) {
8051     BB = Inst->getParent();
8052     Begin = Inst->getIterator();
8053     Begin++;
8054   } else if (const auto *Arg = dyn_cast<Argument>(V)) {
8055     if (Arg->getParent()->isDeclaration())
8056       return false;
8057     BB = &Arg->getParent()->getEntryBlock();
8058     Begin = BB->begin();
8059   } else {
8060     return false;
8061   }
8062 
8063   // Limit number of instructions we look at, to avoid scanning through large
8064   // blocks. The current limit is chosen arbitrarily.
8065   unsigned ScanLimit = 32;
8066   BasicBlock::const_iterator End = BB->end();
8067 
8068   if (!PoisonOnly) {
8069     // Since undef does not propagate eagerly, be conservative and just check
8070     // whether a value is directly passed to an instruction that must take
8071     // well-defined operands.
8072 
8073     for (const auto &I : make_range(Begin, End)) {
8074       if (--ScanLimit == 0)
8075         break;
8076 
8077       if (handleGuaranteedWellDefinedOps(&I, [V](const Value *WellDefinedOp) {
8078             return WellDefinedOp == V;
8079           }))
8080         return true;
8081 
8082       if (!isGuaranteedToTransferExecutionToSuccessor(&I))
8083         break;
8084     }
8085     return false;
8086   }
8087 
8088   // Set of instructions that we have proved will yield poison if V
8089   // does.
8090   SmallSet<const Value *, 16> YieldsPoison;
8091   SmallSet<const BasicBlock *, 4> Visited;
8092 
8093   YieldsPoison.insert(V);
8094   Visited.insert(BB);
8095 
8096   while (true) {
8097     for (const auto &I : make_range(Begin, End)) {
8098       if (--ScanLimit == 0)
8099         return false;
8100       if (mustTriggerUB(&I, YieldsPoison))
8101         return true;
8102       if (!isGuaranteedToTransferExecutionToSuccessor(&I))
8103         return false;
8104 
8105       // If an operand is poison and propagates it, mark I as yielding poison.
8106       for (const Use &Op : I.operands()) {
8107         if (YieldsPoison.count(Op) && propagatesPoison(Op)) {
8108           YieldsPoison.insert(&I);
8109           break;
8110         }
8111       }
8112 
8113       // Special handling for select, which returns poison if its operand 0 is
8114       // poison (handled in the loop above) *or* if both its true/false operands
8115       // are poison (handled here).
8116       if (I.getOpcode() == Instruction::Select &&
8117           YieldsPoison.count(I.getOperand(1)) &&
8118           YieldsPoison.count(I.getOperand(2))) {
8119         YieldsPoison.insert(&I);
8120       }
8121     }
8122 
8123     BB = BB->getSingleSuccessor();
8124     if (!BB || !Visited.insert(BB).second)
8125       break;
8126 
8127     Begin = BB->getFirstNonPHIIt();
8128     End = BB->end();
8129   }
8130   return false;
8131 }
8132 
8133 bool llvm::programUndefinedIfUndefOrPoison(const Instruction *Inst) {
8134   return ::programUndefinedIfUndefOrPoison(Inst, false);
8135 }
8136 
8137 bool llvm::programUndefinedIfPoison(const Instruction *Inst) {
8138   return ::programUndefinedIfUndefOrPoison(Inst, true);
8139 }
8140 
8141 static bool isKnownNonNaN(const Value *V, FastMathFlags FMF) {
8142   if (FMF.noNaNs())
8143     return true;
8144 
8145   if (auto *C = dyn_cast<ConstantFP>(V))
8146     return !C->isNaN();
8147 
8148   if (auto *C = dyn_cast<ConstantDataVector>(V)) {
8149     if (!C->getElementType()->isFloatingPointTy())
8150       return false;
8151     for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
8152       if (C->getElementAsAPFloat(I).isNaN())
8153         return false;
8154     }
8155     return true;
8156   }
8157 
8158   if (isa<ConstantAggregateZero>(V))
8159     return true;
8160 
8161   return false;
8162 }
8163 
8164 static bool isKnownNonZero(const Value *V) {
8165   if (auto *C = dyn_cast<ConstantFP>(V))
8166     return !C->isZero();
8167 
8168   if (auto *C = dyn_cast<ConstantDataVector>(V)) {
8169     if (!C->getElementType()->isFloatingPointTy())
8170       return false;
8171     for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
8172       if (C->getElementAsAPFloat(I).isZero())
8173         return false;
8174     }
8175     return true;
8176   }
8177 
8178   return false;
8179 }
8180 
8181 /// Match clamp pattern for float types without regard to NaNs or signed zeros.
8182 /// Given a non-min/max outer cmp/select from the clamp pattern, this
8183 /// function recognizes whether it can be substituted by a "canonical" min/max
8184 /// pattern.
8185 static SelectPatternResult matchFastFloatClamp(CmpInst::Predicate Pred,
8186                                                Value *CmpLHS, Value *CmpRHS,
8187                                                Value *TrueVal, Value *FalseVal,
8188                                                Value *&LHS, Value *&RHS) {
8189   // Try to match
8190   //   X < C1 ? C1 : Min(X, C2) --> Max(C1, Min(X, C2))
8191   //   X > C1 ? C1 : Max(X, C2) --> Min(C1, Max(X, C2))
8192   // and return description of the outer Max/Min.
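  //
  // For example (a schematic sketch, with the inner min in select form):
  //   %cmp = fcmp olt float %x, 1.0
  //   %le  = fcmp ole float %x, 2.0
  //   %min = select i1 %le, float %x, float 2.0
  //   %sel = select i1 %cmp, float 1.0, float %min
  // is recognized as the outer max: FMAXNUM(1.0, fmin(%x, 2.0)).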
8193 
8194   // First, check if select has inverse order:
8195   if (CmpRHS == FalseVal) {
8196     std::swap(TrueVal, FalseVal);
8197     Pred = CmpInst::getInversePredicate(Pred);
8198   }
8199 
8200   // Assume success now. If there's no match, callers should not use these anyway.
8201   LHS = TrueVal;
8202   RHS = FalseVal;
8203 
8204   const APFloat *FC1;
8205   if (CmpRHS != TrueVal || !match(CmpRHS, m_APFloat(FC1)) || !FC1->isFinite())
8206     return {SPF_UNKNOWN, SPNB_NA, false};
8207 
8208   const APFloat *FC2;
8209   switch (Pred) {
8210   case CmpInst::FCMP_OLT:
8211   case CmpInst::FCMP_OLE:
8212   case CmpInst::FCMP_ULT:
8213   case CmpInst::FCMP_ULE:
8214     if (match(FalseVal, m_OrdOrUnordFMin(m_Specific(CmpLHS), m_APFloat(FC2))) &&
8215         *FC1 < *FC2)
8216       return {SPF_FMAXNUM, SPNB_RETURNS_ANY, false};
8217     break;
8218   case CmpInst::FCMP_OGT:
8219   case CmpInst::FCMP_OGE:
8220   case CmpInst::FCMP_UGT:
8221   case CmpInst::FCMP_UGE:
8222     if (match(FalseVal, m_OrdOrUnordFMax(m_Specific(CmpLHS), m_APFloat(FC2))) &&
8223         *FC1 > *FC2)
8224       return {SPF_FMINNUM, SPNB_RETURNS_ANY, false};
8225     break;
8226   default:
8227     break;
8228   }
8229 
8230   return {SPF_UNKNOWN, SPNB_NA, false};
8231 }
8232 
8233 /// Recognize variations of:
8234 ///   CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? (h) : (v)))
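/// For example (a sketch, with the inner min in select form):
///   %lt  = icmp slt i32 %x, 10
///   %min = select i1 %lt, i32 %x, i32 10
///   %cmp = icmp slt i32 %x, 0
///   %sel = select i1 %cmp, i32 0, i32 %min   ; clamp(%x, 0, 10)
/// is recognized as SPF_SMAX of the inner SMIN.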
8235 static SelectPatternResult matchClamp(CmpInst::Predicate Pred,
8236                                       Value *CmpLHS, Value *CmpRHS,
8237                                       Value *TrueVal, Value *FalseVal) {
8238   // Swap the select operands and predicate to match the patterns below.
8239   if (CmpRHS != TrueVal) {
8240     Pred = ICmpInst::getSwappedPredicate(Pred);
8241     std::swap(TrueVal, FalseVal);
8242   }
8243   const APInt *C1;
8244   if (CmpRHS == TrueVal && match(CmpRHS, m_APInt(C1))) {
8245     const APInt *C2;
8246     // (X <s C1) ? C1 : SMIN(X, C2) ==> SMAX(SMIN(X, C2), C1)
8247     if (match(FalseVal, m_SMin(m_Specific(CmpLHS), m_APInt(C2))) &&
8248         C1->slt(*C2) && Pred == CmpInst::ICMP_SLT)
8249       return {SPF_SMAX, SPNB_NA, false};
8250 
8251     // (X >s C1) ? C1 : SMAX(X, C2) ==> SMIN(SMAX(X, C2), C1)
8252     if (match(FalseVal, m_SMax(m_Specific(CmpLHS), m_APInt(C2))) &&
8253         C1->sgt(*C2) && Pred == CmpInst::ICMP_SGT)
8254       return {SPF_SMIN, SPNB_NA, false};
8255 
8256     // (X <u C1) ? C1 : UMIN(X, C2) ==> UMAX(UMIN(X, C2), C1)
8257     if (match(FalseVal, m_UMin(m_Specific(CmpLHS), m_APInt(C2))) &&
8258         C1->ult(*C2) && Pred == CmpInst::ICMP_ULT)
8259       return {SPF_UMAX, SPNB_NA, false};
8260 
8261     // (X >u C1) ? C1 : UMAX(X, C2) ==> UMIN(UMAX(X, C2), C1)
8262     if (match(FalseVal, m_UMax(m_Specific(CmpLHS), m_APInt(C2))) &&
8263         C1->ugt(*C2) && Pred == CmpInst::ICMP_UGT)
8264       return {SPF_UMIN, SPNB_NA, false};
8265   }
8266   return {SPF_UNKNOWN, SPNB_NA, false};
8267 }
8268 
8269 /// Recognize variations of:
8270 ///   a < c ? min(a,b) : min(b,c) ==> min(min(a,b),min(b,c))
8271 static SelectPatternResult matchMinMaxOfMinMax(CmpInst::Predicate Pred,
8272                                                Value *CmpLHS, Value *CmpRHS,
8273                                                Value *TVal, Value *FVal,
8274                                                unsigned Depth) {
8275   // TODO: Allow FP min/max with nnan/nsz.
8276   assert(CmpInst::isIntPredicate(Pred) && "Expected integer comparison");
8277 
8278   Value *A = nullptr, *B = nullptr;
8279   SelectPatternResult L = matchSelectPattern(TVal, A, B, nullptr, Depth + 1);
8280   if (!SelectPatternResult::isMinOrMax(L.Flavor))
8281     return {SPF_UNKNOWN, SPNB_NA, false};
8282 
8283   Value *C = nullptr, *D = nullptr;
8284   SelectPatternResult R = matchSelectPattern(FVal, C, D, nullptr, Depth + 1);
8285   if (L.Flavor != R.Flavor)
8286     return {SPF_UNKNOWN, SPNB_NA, false};
8287 
8288   // We have something like: x Pred y ? min(a, b) : min(c, d).
8289   // Try to match the compare to the min/max operations of the select operands.
8290   // First, make sure we have the right compare predicate.
8291   switch (L.Flavor) {
8292   case SPF_SMIN:
8293     if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) {
8294       Pred = ICmpInst::getSwappedPredicate(Pred);
8295       std::swap(CmpLHS, CmpRHS);
8296     }
8297     if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
8298       break;
8299     return {SPF_UNKNOWN, SPNB_NA, false};
8300   case SPF_SMAX:
8301     if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
8302       Pred = ICmpInst::getSwappedPredicate(Pred);
8303       std::swap(CmpLHS, CmpRHS);
8304     }
8305     if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE)
8306       break;
8307     return {SPF_UNKNOWN, SPNB_NA, false};
8308   case SPF_UMIN:
8309     if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
8310       Pred = ICmpInst::getSwappedPredicate(Pred);
8311       std::swap(CmpLHS, CmpRHS);
8312     }
8313     if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
8314       break;
8315     return {SPF_UNKNOWN, SPNB_NA, false};
8316   case SPF_UMAX:
8317     if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
8318       Pred = ICmpInst::getSwappedPredicate(Pred);
8319       std::swap(CmpLHS, CmpRHS);
8320     }
8321     if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE)
8322       break;
8323     return {SPF_UNKNOWN, SPNB_NA, false};
8324   default:
8325     return {SPF_UNKNOWN, SPNB_NA, false};
8326   }
8327 
8328   // If there is a common operand in the already matched min/max and the other
8329   // min/max operands match the compare operands (either directly or inverted),
8330   // then this is min/max of the same flavor.
8331 
8332   // a pred c ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
8333   // ~c pred ~a ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
8334   if (D == B) {
8335     if ((CmpLHS == A && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) &&
8336                                          match(A, m_Not(m_Specific(CmpRHS)))))
8337       return {L.Flavor, SPNB_NA, false};
8338   }
8339   // a pred d ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
8340   // ~d pred ~a ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
8341   if (C == B) {
8342     if ((CmpLHS == A && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) &&
8343                                          match(A, m_Not(m_Specific(CmpRHS)))))
8344       return {L.Flavor, SPNB_NA, false};
8345   }
8346   // b pred c ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
8347   // ~c pred ~b ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
8348   if (D == A) {
8349     if ((CmpLHS == B && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) &&
8350                                          match(B, m_Not(m_Specific(CmpRHS)))))
8351       return {L.Flavor, SPNB_NA, false};
8352   }
8353   // b pred d ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
8354   // ~d pred ~b ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
8355   if (C == A) {
8356     if ((CmpLHS == B && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) &&
8357                                          match(B, m_Not(m_Specific(CmpRHS)))))
8358       return {L.Flavor, SPNB_NA, false};
8359   }
8360 
8361   return {SPF_UNKNOWN, SPNB_NA, false};
8362 }
8363 
8364 /// If the input value is the result of a 'not' op, constant integer, or vector
8365 /// splat of a constant integer, return the bitwise-not source value.
8366 /// TODO: This could be extended to handle non-splat vector integer constants.
8367 static Value *getNotValue(Value *V) {
8368   Value *NotV;
8369   if (match(V, m_Not(m_Value(NotV))))
8370     return NotV;
8371 
8372   const APInt *C;
8373   if (match(V, m_APInt(C)))
8374     return ConstantInt::get(V->getType(), ~(*C));
8375 
8376   return nullptr;
8377 }
8378 
8379 /// Match non-obvious integer minimum and maximum sequences.
8380 static SelectPatternResult matchMinMax(CmpInst::Predicate Pred,
8381                                        Value *CmpLHS, Value *CmpRHS,
8382                                        Value *TrueVal, Value *FalseVal,
8383                                        Value *&LHS, Value *&RHS,
8384                                        unsigned Depth) {
8385   // Assume success. If there's no match, callers should not use these anyway.
8386   LHS = TrueVal;
8387   RHS = FalseVal;
8388 
8389   SelectPatternResult SPR = matchClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal);
8390   if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
8391     return SPR;
8392 
8393   SPR = matchMinMaxOfMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, Depth);
8394   if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
8395     return SPR;
8396 
8397   // Look through 'not' ops to find disguised min/max.
8398   // (X > Y) ? ~X : ~Y ==> (~X < ~Y) ? ~X : ~Y ==> MIN(~X, ~Y)
8399   // (X < Y) ? ~X : ~Y ==> (~X > ~Y) ? ~X : ~Y ==> MAX(~X, ~Y)
8400   if (CmpLHS == getNotValue(TrueVal) && CmpRHS == getNotValue(FalseVal)) {
8401     switch (Pred) {
8402     case CmpInst::ICMP_SGT: return {SPF_SMIN, SPNB_NA, false};
8403     case CmpInst::ICMP_SLT: return {SPF_SMAX, SPNB_NA, false};
8404     case CmpInst::ICMP_UGT: return {SPF_UMIN, SPNB_NA, false};
8405     case CmpInst::ICMP_ULT: return {SPF_UMAX, SPNB_NA, false};
8406     default: break;
8407     }
8408   }
8409 
8410   // (X > Y) ? ~Y : ~X ==> (~X < ~Y) ? ~Y : ~X ==> MAX(~Y, ~X)
8411   // (X < Y) ? ~Y : ~X ==> (~X > ~Y) ? ~Y : ~X ==> MIN(~Y, ~X)
8412   if (CmpLHS == getNotValue(FalseVal) && CmpRHS == getNotValue(TrueVal)) {
8413     switch (Pred) {
8414     case CmpInst::ICMP_SGT: return {SPF_SMAX, SPNB_NA, false};
8415     case CmpInst::ICMP_SLT: return {SPF_SMIN, SPNB_NA, false};
8416     case CmpInst::ICMP_UGT: return {SPF_UMAX, SPNB_NA, false};
8417     case CmpInst::ICMP_ULT: return {SPF_UMIN, SPNB_NA, false};
8418     default: break;
8419     }
8420   }
8421 
8422   if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT)
8423     return {SPF_UNKNOWN, SPNB_NA, false};
8424 
8425   const APInt *C1;
8426   if (!match(CmpRHS, m_APInt(C1)))
8427     return {SPF_UNKNOWN, SPNB_NA, false};
8428 
8429   // An unsigned min/max can be written with a signed compare.
8430   const APInt *C2;
8431   if ((CmpLHS == TrueVal && match(FalseVal, m_APInt(C2))) ||
8432       (CmpLHS == FalseVal && match(TrueVal, m_APInt(C2)))) {
8433     // Is the sign bit set?
8434     // (X <s 0) ? X : MAXVAL ==> (X >u MAXVAL) ? X : MAXVAL ==> UMAX
8435     // (X <s 0) ? MAXVAL : X ==> (X >u MAXVAL) ? MAXVAL : X ==> UMIN
8436     if (Pred == CmpInst::ICMP_SLT && C1->isZero() && C2->isMaxSignedValue())
8437       return {CmpLHS == TrueVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false};
8438 
8439     // Is the sign bit clear?
8440     // (X >s -1) ? MINVAL : X ==> (X <u MINVAL) ? MINVAL : X ==> UMAX
8441     // (X >s -1) ? X : MINVAL ==> (X <u MINVAL) ? X : MINVAL ==> UMIN
8442     if (Pred == CmpInst::ICMP_SGT && C1->isAllOnes() && C2->isMinSignedValue())
8443       return {CmpLHS == FalseVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false};
8444   }
8445 
8446   return {SPF_UNKNOWN, SPNB_NA, false};
8447 }
8448 
8449 bool llvm::isKnownNegation(const Value *X, const Value *Y, bool NeedNSW,
8450                            bool AllowPoison) {
8451   assert(X && Y && "Invalid operand");
8452 
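  // For example (a sketch): X = "sub nsw i32 0, %y" is a negation of Y = %y.
  // With NeedNSW the nsw flag is required; with AllowPoison == false the zero
  // operand must be a true all-zero constant (no poison lanes).
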
8453   auto IsNegationOf = [&](const Value *X, const Value *Y) {
8454     if (!match(X, m_Neg(m_Specific(Y))))
8455       return false;
8456 
8457     auto *BO = cast<BinaryOperator>(X);
8458     if (NeedNSW && !BO->hasNoSignedWrap())
8459       return false;
8460 
8461     auto *Zero = cast<Constant>(BO->getOperand(0));
8462     if (!AllowPoison && !Zero->isNullValue())
8463       return false;
8464 
8465     return true;
8466   };
8467 
8468   // X = -Y or Y = -X
8469   if (IsNegationOf(X, Y) || IsNegationOf(Y, X))
8470     return true;
8471 
8472   // X = sub (A, B), Y = sub (B, A) || X = sub nsw (A, B), Y = sub nsw (B, A)
8473   Value *A, *B;
8474   return (!NeedNSW && (match(X, m_Sub(m_Value(A), m_Value(B))) &&
8475                         match(Y, m_Sub(m_Specific(B), m_Specific(A))))) ||
8476          (NeedNSW && (match(X, m_NSWSub(m_Value(A), m_Value(B))) &&
8477                        match(Y, m_NSWSub(m_Specific(B), m_Specific(A)))));
8478 }
8479 
8480 bool llvm::isKnownInversion(const Value *X, const Value *Y) {
8481   // Handle X = icmp pred A, B, Y = icmp pred A, C.
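  // For example (a sketch): X = "icmp ult i32 %a, 10" is the inversion of
  // Y = "icmp uge i32 %a, 10", and also of Y = "icmp ugt i32 %a, 9" via the
  // constant-range check below.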
8482   Value *A, *B, *C;
8483   CmpPredicate Pred1, Pred2;
8484   if (!match(X, m_ICmp(Pred1, m_Value(A), m_Value(B))) ||
8485       !match(Y, m_c_ICmp(Pred2, m_Specific(A), m_Value(C))))
8486     return false;
8487 
8488   // They must both have samesign flag or not.
8489   if (Pred1.hasSameSign() != Pred2.hasSameSign())
8490     return false;
8491 
8492   if (B == C)
8493     return Pred1 == ICmpInst::getInversePredicate(Pred2);
8494 
8495   // Try to infer the relationship from constant ranges.
8496   const APInt *RHSC1, *RHSC2;
8497   if (!match(B, m_APInt(RHSC1)) || !match(C, m_APInt(RHSC2)))
8498     return false;
8499 
8500   // Sign bits of two RHSCs should match.
8501   if (Pred1.hasSameSign() && RHSC1->isNonNegative() != RHSC2->isNonNegative())
8502     return false;
8503 
8504   const auto CR1 = ConstantRange::makeExactICmpRegion(Pred1, *RHSC1);
8505   const auto CR2 = ConstantRange::makeExactICmpRegion(Pred2, *RHSC2);
8506 
8507   return CR1.inverse() == CR2;
8508 }
8509 
8510 SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
8511                                            SelectPatternNaNBehavior NaNBehavior,
8512                                            bool Ordered) {
8513   switch (Pred) {
8514   default:
8515     return {SPF_UNKNOWN, SPNB_NA, false}; // Equality.
8516   case ICmpInst::ICMP_UGT:
8517   case ICmpInst::ICMP_UGE:
8518     return {SPF_UMAX, SPNB_NA, false};
8519   case ICmpInst::ICMP_SGT:
8520   case ICmpInst::ICMP_SGE:
8521     return {SPF_SMAX, SPNB_NA, false};
8522   case ICmpInst::ICMP_ULT:
8523   case ICmpInst::ICMP_ULE:
8524     return {SPF_UMIN, SPNB_NA, false};
8525   case ICmpInst::ICMP_SLT:
8526   case ICmpInst::ICMP_SLE:
8527     return {SPF_SMIN, SPNB_NA, false};
8528   case FCmpInst::FCMP_UGT:
8529   case FCmpInst::FCMP_UGE:
8530   case FCmpInst::FCMP_OGT:
8531   case FCmpInst::FCMP_OGE:
8532     return {SPF_FMAXNUM, NaNBehavior, Ordered};
8533   case FCmpInst::FCMP_ULT:
8534   case FCmpInst::FCMP_ULE:
8535   case FCmpInst::FCMP_OLT:
8536   case FCmpInst::FCMP_OLE:
8537     return {SPF_FMINNUM, NaNBehavior, Ordered};
8538   }
8539 }
8540 
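// A strict relational predicate can be flipped to its non-strict counterpart
// (and vice versa) by nudging the constant, e.g. (a sketch):
//   icmp ult i32 %x, 5  <=>  icmp ule i32 %x, 4
//   icmp sgt i32 %x, 7  <=>  icmp sge i32 %x, 8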
8541 std::optional<std::pair<CmpPredicate, Constant *>>
8542 llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) {
8543   assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
8544          "Only for relational integer predicates.");
8545   if (isa<UndefValue>(C))
8546     return std::nullopt;
8547 
8548   Type *Type = C->getType();
8549   bool IsSigned = ICmpInst::isSigned(Pred);
8550 
8551   CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
8552   bool WillIncrement =
8553       UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
8554 
8555   // Check if the constant operand can be safely incremented/decremented
8556   // without overflowing/underflowing.
8557   auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
8558     return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
8559   };
8560 
8561   Constant *SafeReplacementConstant = nullptr;
8562   if (auto *CI = dyn_cast<ConstantInt>(C)) {
8563     // Bail out if the constant can't be safely incremented/decremented.
8564     if (!ConstantIsOk(CI))
8565       return std::nullopt;
8566   } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
8567     unsigned NumElts = FVTy->getNumElements();
8568     for (unsigned i = 0; i != NumElts; ++i) {
8569       Constant *Elt = C->getAggregateElement(i);
8570       if (!Elt)
8571         return std::nullopt;
8572 
8573       if (isa<UndefValue>(Elt))
8574         continue;
8575 
8576       // Bail out if we can't determine if this constant is min/max or if we
8577       // know that this constant is min/max.
8578       auto *CI = dyn_cast<ConstantInt>(Elt);
8579       if (!CI || !ConstantIsOk(CI))
8580         return std::nullopt;
8581 
8582       if (!SafeReplacementConstant)
8583         SafeReplacementConstant = CI;
8584     }
8585   } else if (isa<VectorType>(C->getType())) {
8586     // Handle scalable splat
8587     Value *SplatC = C->getSplatValue();
8588     auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
8589     // Bail out if the constant can't be safely incremented/decremented.
8590     if (!CI || !ConstantIsOk(CI))
8591       return std::nullopt;
8592   } else {
8593     // ConstantExpr?
8594     return std::nullopt;
8595   }
8596 
8597   // It may not be safe to change a compare predicate in the presence of
8598   // undefined elements, so replace those elements with the first safe constant
8599   // that we found.
8600   // TODO: in case of poison, it is safe; let's replace undefs only.
8601   if (C->containsUndefOrPoisonElement()) {
8602     assert(SafeReplacementConstant && "Replacement constant not set");
8603     C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
8604   }
8605 
8606   CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
8607 
8608   // Increment or decrement the constant.
8609   Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
8610   Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
8611 
8612   return std::make_pair(NewPred, NewC);
8613 }
8614 
8615 static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
8616                                               FastMathFlags FMF,
8617                                               Value *CmpLHS, Value *CmpRHS,
8618                                               Value *TrueVal, Value *FalseVal,
8619                                               Value *&LHS, Value *&RHS,
8620                                               unsigned Depth) {
8621   bool HasMismatchedZeros = false;
8622   if (CmpInst::isFPPredicate(Pred)) {
8623     // IEEE-754 ignores the sign of 0.0 in comparisons. So if the select has one
8624     // 0.0 operand, set the compare's 0.0 operands to that same value for the
8625     // purpose of identifying min/max. Disregard vector constants with undefined
8626     // elements because those can not be back-propagated for analysis.
8627     // elements because those cannot be back-propagated for analysis.
8628     if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) &&
8629         !cast<Constant>(TrueVal)->containsUndefOrPoisonElement())
8630       OutputZeroVal = TrueVal;
8631     else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) &&
8632              !cast<Constant>(FalseVal)->containsUndefOrPoisonElement())
8633       OutputZeroVal = FalseVal;
8634 
8635     if (OutputZeroVal) {
8636       if (match(CmpLHS, m_AnyZeroFP()) && CmpLHS != OutputZeroVal) {
8637         HasMismatchedZeros = true;
8638         CmpLHS = OutputZeroVal;
8639       }
8640       if (match(CmpRHS, m_AnyZeroFP()) && CmpRHS != OutputZeroVal) {
8641         HasMismatchedZeros = true;
8642         CmpRHS = OutputZeroVal;
8643       }
8644     }
8645   }
8646 
8647   LHS = CmpLHS;
8648   RHS = CmpRHS;
8649 
8650   // Operations on signed zero may return inconsistent results between implementations.
8651   //  (0.0 <= -0.0) ? 0.0 : -0.0 // Returns 0.0
8652   //  minNum(0.0, -0.0)          // May return -0.0 or 0.0 (IEEE 754-2008 5.3.1)
8653   // Therefore, we behave conservatively and only proceed if at least one of the
8654   // operands is known to not be zero or if we don't care about signed zero.
8655   switch (Pred) {
8656   default: break;
8657   case CmpInst::FCMP_OGT: case CmpInst::FCMP_OLT:
8658   case CmpInst::FCMP_UGT: case CmpInst::FCMP_ULT:
8659     if (!HasMismatchedZeros)
8660       break;
8661     [[fallthrough]];
8662   case CmpInst::FCMP_OGE: case CmpInst::FCMP_OLE:
8663   case CmpInst::FCMP_UGE: case CmpInst::FCMP_ULE:
8664     if (!FMF.noSignedZeros() && !isKnownNonZero(CmpLHS) &&
8665         !isKnownNonZero(CmpRHS))
8666       return {SPF_UNKNOWN, SPNB_NA, false};
8667   }
8668 
8669   SelectPatternNaNBehavior NaNBehavior = SPNB_NA;
8670   bool Ordered = false;
8671 
8672   // When given one NaN and one non-NaN input:
8673   //   - maxnum/minnum (C99 fmaxf()/fminf()) return the non-NaN input.
8674   //   - A simple C99 (a < b ? a : b) construction will return 'b' (as the
8675   //     ordered comparison fails), which could be NaN or non-NaN.
8676   // so here we discover exactly what NaN behavior is required/accepted.
8677   if (CmpInst::isFPPredicate(Pred)) {
8678     bool LHSSafe = isKnownNonNaN(CmpLHS, FMF);
8679     bool RHSSafe = isKnownNonNaN(CmpRHS, FMF);
8680 
8681     if (LHSSafe && RHSSafe) {
8682       // Both operands are known non-NaN.
8683       NaNBehavior = SPNB_RETURNS_ANY;
8684       Ordered = CmpInst::isOrdered(Pred);
8685     } else if (CmpInst::isOrdered(Pred)) {
8686       // An ordered comparison will return false when given a NaN, so it
8687       // returns the RHS.
8688       Ordered = true;
8689       if (LHSSafe)
8690         // LHS is non-NaN, so if RHS is NaN then NaN will be returned.
8691         NaNBehavior = SPNB_RETURNS_NAN;
8692       else if (RHSSafe)
8693         NaNBehavior = SPNB_RETURNS_OTHER;
8694       else
8695         // Completely unsafe.
8696         return {SPF_UNKNOWN, SPNB_NA, false};
8697     } else {
8698       Ordered = false;
8699       // An unordered comparison will return true when given a NaN, so it
8700       // returns the LHS.
8701       if (LHSSafe)
8702         // LHS is non-NaN, so if RHS is NaN then non-NaN will be returned.
8703         NaNBehavior = SPNB_RETURNS_OTHER;
8704       else if (RHSSafe)
8705         NaNBehavior = SPNB_RETURNS_NAN;
8706       else
8707         // Completely unsafe.
8708         return {SPF_UNKNOWN, SPNB_NA, false};
8709     }
8710   }
8711 
8712   if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
8713     std::swap(CmpLHS, CmpRHS);
8714     Pred = CmpInst::getSwappedPredicate(Pred);
8715     if (NaNBehavior == SPNB_RETURNS_NAN)
8716       NaNBehavior = SPNB_RETURNS_OTHER;
8717     else if (NaNBehavior == SPNB_RETURNS_OTHER)
8718       NaNBehavior = SPNB_RETURNS_NAN;
8719     Ordered = !Ordered;
8720   }
8721 
8722   // ([if]cmp X, Y) ? X : Y
8723   if (TrueVal == CmpLHS && FalseVal == CmpRHS)
8724     return getSelectPattern(Pred, NaNBehavior, Ordered);
8725 
8726   if (isKnownNegation(TrueVal, FalseVal)) {
8727     // Sign-extending LHS does not change its sign, so TrueVal/FalseVal can
8728     // match against either LHS or sext(LHS).
8729     auto MaybeSExtCmpLHS =
8730         m_CombineOr(m_Specific(CmpLHS), m_SExt(m_Specific(CmpLHS)));
8731     auto ZeroOrAllOnes = m_CombineOr(m_ZeroInt(), m_AllOnes());
8732     auto ZeroOrOne = m_CombineOr(m_ZeroInt(), m_One());
8733     if (match(TrueVal, MaybeSExtCmpLHS)) {
8734       // Set the return values. If the compare uses the negated value (-X >s 0),
8735       // swap the return values because the negated value is always 'RHS'.
8736       LHS = TrueVal;
8737       RHS = FalseVal;
8738       if (match(CmpLHS, m_Neg(m_Specific(FalseVal))))
8739         std::swap(LHS, RHS);
8740 
8741       // (X >s 0) ? X : -X or (X >s -1) ? X : -X --> ABS(X)
8742       // (-X >s 0) ? -X : X or (-X >s -1) ? -X : X --> ABS(X)
8743       if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes))
8744         return {SPF_ABS, SPNB_NA, false};
8745 
8746       // (X >=s 0) ? X : -X or (X >=s 1) ? X : -X --> ABS(X)
8747       if (Pred == ICmpInst::ICMP_SGE && match(CmpRHS, ZeroOrOne))
8748         return {SPF_ABS, SPNB_NA, false};
8749 
8750       // (X <s 0) ? X : -X or (X <s 1) ? X : -X --> NABS(X)
8751       // (-X <s 0) ? -X : X or (-X <s 1) ? -X : X --> NABS(X)
8752       if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne))
8753         return {SPF_NABS, SPNB_NA, false};
8754     }
8755     else if (match(FalseVal, MaybeSExtCmpLHS)) {
8756       // Set the return values. If the compare uses the negated value (-X >s 0),
8757       // swap the return values because the negated value is always 'RHS'.
8758       LHS = FalseVal;
8759       RHS = TrueVal;
8760       if (match(CmpLHS, m_Neg(m_Specific(TrueVal))))
8761         std::swap(LHS, RHS);
8762 
8763       // (X >s 0) ? -X : X or (X >s -1) ? -X : X --> NABS(X)
8764       // (-X >s 0) ? X : -X or (-X >s -1) ? X : -X --> NABS(X)
8765       if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes))
8766         return {SPF_NABS, SPNB_NA, false};
8767 
8768       // (X <s 0) ? -X : X or (X <s 1) ? -X : X --> ABS(X)
8769       // (-X <s 0) ? X : -X or (-X <s 1) ? X : -X --> ABS(X)
8770       if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne))
8771         return {SPF_ABS, SPNB_NA, false};
8772     }
8773   }
8774 
8775   if (CmpInst::isIntPredicate(Pred))
8776     return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS, Depth);
8777 
8778   // According to (IEEE 754-2008 5.3.1), minNum(0.0, -0.0) and similar
8779   // may return either -0.0 or 0.0, so fcmp/select pair has stricter
8780   // semantics than minNum. Be conservative in such case.
8781   if (NaNBehavior != SPNB_RETURNS_ANY ||
8782       (!FMF.noSignedZeros() && !isKnownNonZero(CmpLHS) &&
8783        !isKnownNonZero(CmpRHS)))
8784     return {SPF_UNKNOWN, SPNB_NA, false};
8785 
8786   return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
8787 }
8788 
8789 static Value *lookThroughCastConst(CmpInst *CmpI, Type *SrcTy, Constant *C,
8790                                    Instruction::CastOps *CastOp) {
8791   const DataLayout &DL = CmpI->getDataLayout();
8792 
8793   Constant *CastedTo = nullptr;
8794   switch (*CastOp) {
8795   case Instruction::ZExt:
8796     if (CmpI->isUnsigned())
8797       CastedTo = ConstantExpr::getTrunc(C, SrcTy);
8798     break;
8799   case Instruction::SExt:
8800     if (CmpI->isSigned())
8801       CastedTo = ConstantExpr::getTrunc(C, SrcTy, true);
8802     break;
8803   case Instruction::Trunc:
8804     Constant *CmpConst;
8805     if (match(CmpI->getOperand(1), m_Constant(CmpConst)) &&
8806         CmpConst->getType() == SrcTy) {
8807       // Here we have the following case:
8808       //
8809       //   %cond = cmp iN %x, CmpConst
8810       //   %tr = trunc iN %x to iK
8811       //   %narrowsel = select i1 %cond, iK %t, iK C
8812       //
8813       // We can always move trunc after select operation:
8814       //
8815       //   %cond = cmp iN %x, CmpConst
8816       //   %widesel = select i1 %cond, iN %x, iN CmpConst
8817       //   %tr = trunc iN %widesel to iK
8818       //
8819       // Note that C could be extended in any way because we don't care about
8820       // the upper bits after truncation. It can't be an abs pattern, because
8821       // that would look like:
8822       //
8823       //   select i1 %cond, x, -x.
8824       //
8825       // So only a min/max pattern can be matched. Such a match requires the
8826       // widened C to equal CmpConst; that is why we set the widened C to
8827       // CmpConst, and the condition trunc(CmpConst) == C is checked below.
8828       CastedTo = CmpConst;
8829     } else {
8830       unsigned ExtOp = CmpI->isSigned() ? Instruction::SExt : Instruction::ZExt;
8831       CastedTo = ConstantFoldCastOperand(ExtOp, C, SrcTy, DL);
8832     }
8833     break;
8834   case Instruction::FPTrunc:
8835     CastedTo = ConstantFoldCastOperand(Instruction::FPExt, C, SrcTy, DL);
8836     break;
8837   case Instruction::FPExt:
8838     CastedTo = ConstantFoldCastOperand(Instruction::FPTrunc, C, SrcTy, DL);
8839     break;
8840   case Instruction::FPToUI:
8841     CastedTo = ConstantFoldCastOperand(Instruction::UIToFP, C, SrcTy, DL);
8842     break;
8843   case Instruction::FPToSI:
8844     CastedTo = ConstantFoldCastOperand(Instruction::SIToFP, C, SrcTy, DL);
8845     break;
8846   case Instruction::UIToFP:
8847     CastedTo = ConstantFoldCastOperand(Instruction::FPToUI, C, SrcTy, DL);
8848     break;
8849   case Instruction::SIToFP:
8850     CastedTo = ConstantFoldCastOperand(Instruction::FPToSI, C, SrcTy, DL);
8851     break;
8852   default:
8853     break;
8854   }
8855 
8856   if (!CastedTo)
8857     return nullptr;
8858 
8859   // Make sure the cast doesn't lose any information.
8860   Constant *CastedBack =
8861       ConstantFoldCastOperand(*CastOp, CastedTo, C->getType(), DL);
8862   if (CastedBack && CastedBack != C)
8863     return nullptr;
8864 
8865   return CastedTo;
8866 }
8867 
8868 /// Helps to match a select pattern in case of a type mismatch.
8869 ///
8870 /// The function handles the case when the types of the true and false values
8871 /// of a select instruction differ from the types of the cmp instruction
8872 /// operands because of a cast instruction. It checks whether it is legal to
8873 /// move the cast operation after the "select". If so, it returns the new
8874 /// second value of the "select" (assuming the cast is moved):
8875 /// 1. The operand of the cast instruction when both values of the "select"
8876 /// are the same cast instruction.
8877 /// 2. The restored constant (obtained by applying the reverse cast operation)
8878 /// when the first value of the "select" is a cast operation and the second
8879 /// value is a constant. This is implemented in lookThroughCastConst().
8880 /// 3. When one operand is a cast instruction and the other is not: the
8881 /// operands in sel(cmp) are integers of different widths.
8882 /// NOTE: We return only the new second value because the first value can be
8883 /// accessed as the operand of the cast instruction.
8884 static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
8885                               Instruction::CastOps *CastOp) {
8886   auto *Cast1 = dyn_cast<CastInst>(V1);
8887   if (!Cast1)
8888     return nullptr;
8889 
8890   *CastOp = Cast1->getOpcode();
8891   Type *SrcTy = Cast1->getSrcTy();
8892   if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
8893     // If V1 and V2 are both the same cast from the same type, look through V1.
8894     if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
8895       return Cast2->getOperand(0);
8896     return nullptr;
8897   }
8898 
8899   auto *C = dyn_cast<Constant>(V2);
8900   if (C)
8901     return lookThroughCastConst(CmpI, SrcTy, C, CastOp);
8902 
8903   Value *CastedTo = nullptr;
8904   if (*CastOp == Instruction::Trunc) {
8905     if (match(CmpI->getOperand(1), m_ZExtOrSExt(m_Specific(V2)))) {
8906       // Here we have the following case:
8907       //   %y_ext = sext iK %y to iN
8908       //   %cond = cmp iN %x, %y_ext
8909       //   %tr = trunc iN %x to iK
8910       //   %narrowsel = select i1 %cond, iK %tr, iK %y
8911       //
8912       // We can always move trunc after select operation:
8913       //   %y_ext = sext iK %y to iN
8914       //   %cond = cmp iN %x, %y_ext
8915       //   %widesel = select i1 %cond, iN %x, iN %y_ext
8916       //   %tr = trunc iN %widesel to iK
8917       assert(V2->getType() == Cast1->getType() &&
8918              "V2 and Cast1 should be the same type.");
8919       CastedTo = CmpI->getOperand(1);
8920     }
8921   }
8922 
8923   return CastedTo;
8924 }

8925 SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
8926                                              Instruction::CastOps *CastOp,
8927                                              unsigned Depth) {
8928   if (Depth >= MaxAnalysisRecursionDepth)
8929     return {SPF_UNKNOWN, SPNB_NA, false};
8930 
8931   SelectInst *SI = dyn_cast<SelectInst>(V);
8932   if (!SI) return {SPF_UNKNOWN, SPNB_NA, false};
8933 
8934   CmpInst *CmpI = dyn_cast<CmpInst>(SI->getCondition());
8935   if (!CmpI) return {SPF_UNKNOWN, SPNB_NA, false};
8936 
8937   Value *TrueVal = SI->getTrueValue();
8938   Value *FalseVal = SI->getFalseValue();
8939 
8940   return llvm::matchDecomposedSelectPattern(
8941       CmpI, TrueVal, FalseVal, LHS, RHS,
8942       isa<FPMathOperator>(SI) ? SI->getFastMathFlags() : FastMathFlags(),
8943       CastOp, Depth);
8944 }
8945 
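// A minimal usage sketch (hypothetical caller code):
//   Value *LHS, *RHS;
//   SelectPatternResult Res = matchSelectPattern(Sel, LHS, RHS);
//   if (Res.Flavor == SPF_SMAX)
//     ; // Sel computes smax(LHS, RHS).
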
8946 SelectPatternResult llvm::matchDecomposedSelectPattern(
8947     CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS,
8948     FastMathFlags FMF, Instruction::CastOps *CastOp, unsigned Depth) {
8949   CmpInst::Predicate Pred = CmpI->getPredicate();
8950   Value *CmpLHS = CmpI->getOperand(0);
8951   Value *CmpRHS = CmpI->getOperand(1);
8952   if (isa<FPMathOperator>(CmpI) && CmpI->hasNoNaNs())
8953     FMF.setNoNaNs();
8954 
8955   // Bail out early.
8956   if (CmpI->isEquality())
8957     return {SPF_UNKNOWN, SPNB_NA, false};
8958 
8959   // Deal with type mismatches.
8960   if (CastOp && CmpLHS->getType() != TrueVal->getType()) {
8961     if (Value *C = lookThroughCast(CmpI, TrueVal, FalseVal, CastOp)) {
8962       // If this is a potential fmin/fmax with a cast to integer, then ignore
8963       // -0.0 because there is no corresponding integer value.
8964       if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
8965         FMF.setNoSignedZeros();
8966       return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
8967                                   cast<CastInst>(TrueVal)->getOperand(0), C,
8968                                   LHS, RHS, Depth);
8969     }
8970     if (Value *C = lookThroughCast(CmpI, FalseVal, TrueVal, CastOp)) {
8971       // If this is a potential fmin/fmax with a cast to integer, then ignore
8972       // -0.0 because there is no corresponding integer value.
8973       if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
8974         FMF.setNoSignedZeros();
8975       return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
8976                                   C, cast<CastInst>(FalseVal)->getOperand(0),
8977                                   LHS, RHS, Depth);
8978     }
8979   }
8980   return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, TrueVal, FalseVal,
8981                               LHS, RHS, Depth);
8982 }
8983 
8984 CmpInst::Predicate llvm::getMinMaxPred(SelectPatternFlavor SPF, bool Ordered) {
8985   if (SPF == SPF_SMIN) return ICmpInst::ICMP_SLT;
8986   if (SPF == SPF_UMIN) return ICmpInst::ICMP_ULT;
8987   if (SPF == SPF_SMAX) return ICmpInst::ICMP_SGT;
8988   if (SPF == SPF_UMAX) return ICmpInst::ICMP_UGT;
8989   if (SPF == SPF_FMINNUM)
8990     return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT;
8991   if (SPF == SPF_FMAXNUM)
8992     return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT;
8993   llvm_unreachable("unhandled!");
8994 }
8995 
8996 Intrinsic::ID llvm::getMinMaxIntrinsic(SelectPatternFlavor SPF) {
8997   switch (SPF) {
8998   case SelectPatternFlavor::SPF_UMIN:
8999     return Intrinsic::umin;
9000   case SelectPatternFlavor::SPF_UMAX:
9001     return Intrinsic::umax;
9002   case SelectPatternFlavor::SPF_SMIN:
9003     return Intrinsic::smin;
9004   case SelectPatternFlavor::SPF_SMAX:
9005     return Intrinsic::smax;
9006   default:
9007     llvm_unreachable("Unexpected SPF");
9008   }
9009 }
9010 
9011 SelectPatternFlavor llvm::getInverseMinMaxFlavor(SelectPatternFlavor SPF) {
9012   if (SPF == SPF_SMIN) return SPF_SMAX;
9013   if (SPF == SPF_UMIN) return SPF_UMAX;
9014   if (SPF == SPF_SMAX) return SPF_SMIN;
9015   if (SPF == SPF_UMAX) return SPF_UMIN;
9016   llvm_unreachable("unhandled!");
9017 }
9018 
9019 Intrinsic::ID llvm::getInverseMinMaxIntrinsic(Intrinsic::ID MinMaxID) {
9020   switch (MinMaxID) {
9021   case Intrinsic::smax: return Intrinsic::smin;
9022   case Intrinsic::smin: return Intrinsic::smax;
9023   case Intrinsic::umax: return Intrinsic::umin;
9024   case Intrinsic::umin: return Intrinsic::umax;
9025   // Please note that next four intrinsics may produce the same result for
9026   // Note that the next four intrinsics may produce the same result for the
9027   // original and inverted cases even if X != Y, because NaN is handled specially.
9028   case Intrinsic::minimum: return Intrinsic::maximum;
9029   case Intrinsic::maxnum: return Intrinsic::minnum;
9030   case Intrinsic::minnum: return Intrinsic::maxnum;
9031   default: llvm_unreachable("Unexpected intrinsic");
9032   }
9033 }
9034 
9035 APInt llvm::getMinMaxLimit(SelectPatternFlavor SPF, unsigned BitWidth) {
9036   switch (SPF) {
9037   case SPF_SMAX: return APInt::getSignedMaxValue(BitWidth);
9038   case SPF_SMIN: return APInt::getSignedMinValue(BitWidth);
9039   case SPF_UMAX: return APInt::getMaxValue(BitWidth);
9040   case SPF_UMIN: return APInt::getMinValue(BitWidth);
9041   default: llvm_unreachable("Unexpected flavor");
9042   }
9043 }
9044 
9045 std::pair<Intrinsic::ID, bool>
9046 llvm::canConvertToMinOrMaxIntrinsic(ArrayRef<Value *> VL) {
9047   // Check if VL contains select instructions that can be folded into a min/max
9048   // vector intrinsic, and return the intrinsic if that is possible.
9049   // TODO: Support floating point min/max.
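  // For example (a sketch): if every value in VL is a select implementing a
  // signed max, this returns {Intrinsic::smax, AllCmpSingleUse}.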
9050   bool AllCmpSingleUse = true;
9051   SelectPatternResult SelectPattern;
9052   SelectPattern.Flavor = SPF_UNKNOWN;
9053   if (all_of(VL, [&SelectPattern, &AllCmpSingleUse](Value *I) {
9054         Value *LHS, *RHS;
9055         auto CurrentPattern = matchSelectPattern(I, LHS, RHS);
9056         if (!SelectPatternResult::isMinOrMax(CurrentPattern.Flavor))
9057           return false;
9058         if (SelectPattern.Flavor != SPF_UNKNOWN &&
9059             SelectPattern.Flavor != CurrentPattern.Flavor)
9060           return false;
9061         SelectPattern = CurrentPattern;
9062         AllCmpSingleUse &=
9063             match(I, m_Select(m_OneUse(m_Value()), m_Value(), m_Value()));
9064         return true;
9065       })) {
9066     switch (SelectPattern.Flavor) {
9067     case SPF_SMIN:
9068       return {Intrinsic::smin, AllCmpSingleUse};
9069     case SPF_UMIN:
9070       return {Intrinsic::umin, AllCmpSingleUse};
9071     case SPF_SMAX:
9072       return {Intrinsic::smax, AllCmpSingleUse};
9073     case SPF_UMAX:
9074       return {Intrinsic::umax, AllCmpSingleUse};
9075     case SPF_FMAXNUM:
9076       return {Intrinsic::maxnum, AllCmpSingleUse};
9077     case SPF_FMINNUM:
9078       return {Intrinsic::minnum, AllCmpSingleUse};
9079     default:
9080       llvm_unreachable("unexpected select pattern flavor");
9081     }
9082   }
9083   return {Intrinsic::not_intrinsic, false};
9084 }
9085 
9086 template <typename InstTy>
9087 static bool matchTwoInputRecurrence(const PHINode *PN, InstTy *&Inst,
9088                                     Value *&Init, Value *&OtherOp) {
9089   // Handle the case of a simple two-predecessor recurrence PHI.
9090   // There's a lot more that could theoretically be done here, but
9091   // this is sufficient to catch some interesting cases.
9092   // TODO: Expand list -- gep, uadd.sat etc.
9093   if (PN->getNumIncomingValues() != 2)
9094     return false;
9095 
9096   for (unsigned I = 0; I != 2; ++I) {
9097     if (auto *Operation = dyn_cast<InstTy>(PN->getIncomingValue(I))) {
9098       Value *LHS = Operation->getOperand(0);
9099       Value *RHS = Operation->getOperand(1);
9100       if (LHS != PN && RHS != PN)
9101         continue;
9102 
9103       Inst = Operation;
9104       Init = PN->getIncomingValue(!I);
9105       OtherOp = (LHS == PN) ? RHS : LHS;
9106       return true;
9107     }
9108   }
9109   return false;
9110 }
9111 
9112 bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
9113                                  Value *&Start, Value *&Step) {
9114   // We try to match a recurrence of the form:
9115   //   %iv = [Start, %entry], [%iv.next, %backedge]
9116   //   %iv.next = binop %iv, Step
9117   // Or:
9118   //   %iv = [Start, %entry], [%iv.next, %backedge]
9119   //   %iv.next = binop Step, %iv
9120   return matchTwoInputRecurrence(P, BO, Start, Step);
9121 }
9122 
9123 bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
9124                                  Value *&Start, Value *&Step) {
9125   BinaryOperator *BO = nullptr;
9126   P = dyn_cast<PHINode>(I->getOperand(0));
9127   if (!P)
9128     P = dyn_cast<PHINode>(I->getOperand(1));
9129   return P && matchSimpleRecurrence(P, BO, Start, Step) && BO == I;
9130 }
9131 
9132 bool llvm::matchSimpleBinaryIntrinsicRecurrence(const IntrinsicInst *I,
9133                                                 PHINode *&P, Value *&Init,
9134                                                 Value *&OtherOp) {
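  // We try to match a recurrence of the form (a sketch; umax is just one
  // example of a qualifying binary intrinsic):
  //   %iv = phi i32 [ %init, %entry ], [ %iv.next, %loop ]
  //   %iv.next = call i32 @llvm.umax.i32(i32 %iv, i32 %other)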
9135   // Only binary intrinsics are supported for now.
9136   if (I->arg_size() != 2 || I->getType() != I->getArgOperand(0)->getType() ||
9137       I->getType() != I->getArgOperand(1)->getType())
9138     return false;
9139 
9140   IntrinsicInst *II = nullptr;
9141   P = dyn_cast<PHINode>(I->getArgOperand(0));
9142   if (!P)
9143     P = dyn_cast<PHINode>(I->getArgOperand(1));
9144 
9145   return P && matchTwoInputRecurrence(P, II, Init, OtherOp) && II == I;
9146 }
9147 
9148 /// Return true if "icmp Pred LHS RHS" is always true.
9149 static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
9150                             const Value *RHS) {
9151   if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS)
9152     return true;
9153 
9154   switch (Pred) {
9155   default:
9156     return false;
9157 
9158   case CmpInst::ICMP_SLE: {
9159     const APInt *C;
9160 
9161     // LHS s<= LHS +_{nsw} C   if C >= 0
9162     // LHS s<= LHS | C         if C >= 0
9163     if (match(RHS, m_NSWAdd(m_Specific(LHS), m_APInt(C))) ||
9164         match(RHS, m_Or(m_Specific(LHS), m_APInt(C))))
9165       return !C->isNegative();
9166 
9167     // LHS s<= smax(LHS, V) for any V
9168     if (match(RHS, m_c_SMax(m_Specific(LHS), m_Value())))
9169       return true;
9170 
9171     // smin(RHS, V) s<= RHS for any V
9172     if (match(LHS, m_c_SMin(m_Specific(RHS), m_Value())))
9173       return true;
9174 
9175     // Match A to (X +_{nsw} CA) and B to (X +_{nsw} CB)
9176     const Value *X;
9177     const APInt *CLHS, *CRHS;
9178     if (match(LHS, m_NSWAddLike(m_Value(X), m_APInt(CLHS))) &&
9179         match(RHS, m_NSWAddLike(m_Specific(X), m_APInt(CRHS))))
9180       return CLHS->sle(*CRHS);
9181 
9182     return false;
9183   }
9184 
9185   case CmpInst::ICMP_ULE: {
9186     // LHS u<= LHS +_{nuw} V for any V
9187     if (match(RHS, m_c_Add(m_Specific(LHS), m_Value())) &&
9188         cast<OverflowingBinaryOperator>(RHS)->hasNoUnsignedWrap())
9189       return true;
9190 
9191     // LHS u<= LHS | V for any V
9192     if (match(RHS, m_c_Or(m_Specific(LHS), m_Value())))
9193       return true;
9194 
9195     // LHS u<= umax(LHS, V) for any V
9196     if (match(RHS, m_c_UMax(m_Specific(LHS), m_Value())))
9197       return true;
9198 
9199     // RHS >> V u<= RHS for any V
9200     if (match(LHS, m_LShr(m_Specific(RHS), m_Value())))
9201       return true;
9202 
9203     // RHS u/ C_ugt_1 u<= RHS
9204     const APInt *C;
9205     if (match(LHS, m_UDiv(m_Specific(RHS), m_APInt(C))) && C->ugt(1))
9206       return true;
9207 
9208     // RHS & V u<= RHS for any V
9209     if (match(LHS, m_c_And(m_Specific(RHS), m_Value())))
9210       return true;
9211 
9212     // umin(RHS, V) u<= RHS for any V
9213     if (match(LHS, m_c_UMin(m_Specific(RHS), m_Value())))
9214       return true;
9215 
9216     // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB)
9217     const Value *X;
9218     const APInt *CLHS, *CRHS;
9219     if (match(LHS, m_NUWAddLike(m_Value(X), m_APInt(CLHS))) &&
9220         match(RHS, m_NUWAddLike(m_Specific(X), m_APInt(CRHS))))
9221       return CLHS->ule(*CRHS);
9222 
9223     return false;
9224   }
9225   }
9226 }
9227 
9228 /// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred
9229 /// ALHS ARHS" is true.  Otherwise, return std::nullopt.
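/// For example (a sketch), "x s< y" implies "x s< (y +nsw 1)": x s<= x always
/// holds, and isTruePredicate(SLE, y, y +nsw 1) holds because the added
/// constant is non-negative.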
9230 static std::optional<bool>
9231 isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
9232                       const Value *ARHS, const Value *BLHS, const Value *BRHS) {
9233   switch (Pred) {
9234   default:
9235     return std::nullopt;
9236 
9237   case CmpInst::ICMP_SLT:
9238   case CmpInst::ICMP_SLE:
9239     if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS) &&
9240         isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS))
9241       return true;
9242     return std::nullopt;
9243 
9244   case CmpInst::ICMP_SGT:
9245   case CmpInst::ICMP_SGE:
9246     if (isTruePredicate(CmpInst::ICMP_SLE, ALHS, BLHS) &&
9247         isTruePredicate(CmpInst::ICMP_SLE, BRHS, ARHS))
9248       return true;
9249     return std::nullopt;
9250 
9251   case CmpInst::ICMP_ULT:
9252   case CmpInst::ICMP_ULE:
9253     if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS) &&
9254         isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS))
9255       return true;
9256     return std::nullopt;
9257 
9258   case CmpInst::ICMP_UGT:
9259   case CmpInst::ICMP_UGE:
9260     if (isTruePredicate(CmpInst::ICMP_ULE, ALHS, BLHS) &&
9261         isTruePredicate(CmpInst::ICMP_ULE, BRHS, ARHS))
9262       return true;
9263     return std::nullopt;
9264   }
9265 }
9266 
9267 /// Return true if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is true.
9268 /// Return false if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is false.
9269 /// Otherwise, return std::nullopt if we can't infer anything.
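/// For example (a sketch), "X u> 5" implies "X != 0": the allowed range
/// [6, UMAX] for X does not contain 0.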
9270 static std::optional<bool>
9271 isImpliedCondCommonOperandWithCR(CmpPredicate LPred, const ConstantRange &LCR,
9272                                  CmpPredicate RPred, const ConstantRange &RCR) {
9273   auto CRImpliesPred = [&](ConstantRange CR,
9274                            CmpInst::Predicate Pred) -> std::optional<bool> {
9275     // If every value satisfying lhs also satisfies rhs, lhs implies rhs
9276     if (CR.icmp(Pred, RCR))
9277       return true;
9278 
9279     // If there is no overlap, lhs implies not rhs
9280     if (CR.icmp(CmpInst::getInversePredicate(Pred), RCR))
9281       return false;
9282 
9283     return std::nullopt;
9284   };
9285   if (auto Res = CRImpliesPred(ConstantRange::makeAllowedICmpRegion(LPred, LCR),
9286                                RPred))
9287     return Res;
9288   if (LPred.hasSameSign() ^ RPred.hasSameSign()) {
9289     LPred = LPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(LPred)
9290                                 : LPred.dropSameSign();
9291     RPred = RPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(RPred)
9292                                 : RPred.dropSameSign();
9293     return CRImpliesPred(ConstantRange::makeAllowedICmpRegion(LPred, LCR),
9294                          RPred);
9295   }
9296   return std::nullopt;
9297 }
9298 
9299 /// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
9300 /// is true.  Return false if LHS implies RHS is false. Otherwise, return
9301 /// std::nullopt if we can't infer anything.
9302 static std::optional<bool>
9303 isImpliedCondICmps(CmpPredicate LPred, const Value *L0, const Value *L1,
9304                    CmpPredicate RPred, const Value *R0, const Value *R1,
9305                    const DataLayout &DL, bool LHSIsTrue) {
9306   // The rest of the logic assumes the LHS condition is true.  If that's not the
9307   // case, invert the predicate to make it so.
9308   if (!LHSIsTrue)
9309     LPred = ICmpInst::getInverseCmpPredicate(LPred);
9310 
9311   // We can have non-canonical operands, so try to normalize any common operand
9312   // to L0/R0.
9313   if (L0 == R1) {
9314     std::swap(R0, R1);
9315     RPred = ICmpInst::getSwappedCmpPredicate(RPred);
9316   }
9317   if (R0 == L1) {
9318     std::swap(L0, L1);
9319     LPred = ICmpInst::getSwappedCmpPredicate(LPred);
9320   }
9321   if (L1 == R1) {
9322     // If we have L0 == R0 and L1 == R1, then make L1/R1 the constants.
9323     if (L0 != R0 || match(L0, m_ImmConstant())) {
9324       std::swap(L0, L1);
9325       LPred = ICmpInst::getSwappedCmpPredicate(LPred);
9326       std::swap(R0, R1);
9327       RPred = ICmpInst::getSwappedCmpPredicate(RPred);
9328     }
9329   }
9330 
9331   // See if we can infer anything if operand-0 matches and we have at least one
9332   // constant.
9333   const APInt *Unused;
9334   if (L0 == R0 && (match(L1, m_APInt(Unused)) || match(R1, m_APInt(Unused)))) {
9335     // Potential TODO: We could also use the constant range of L0/R0 to
9336     // further constrain the constant ranges. At the moment this leads to
9337     // several regressions related to not transforming `multi_use(A + C0) eq/ne
9338     // C1` (see discussion: D58633).
9339     ConstantRange LCR = computeConstantRange(
9340         L1, ICmpInst::isSigned(LPred), /* UseInstrInfo=*/true, /*AC=*/nullptr,
9341         /*CxtI=*/nullptr, /*DT=*/nullptr, MaxAnalysisRecursionDepth - 1);
9342     ConstantRange RCR = computeConstantRange(
9343         R1, ICmpInst::isSigned(RPred), /* UseInstrInfo=*/true, /*AC=*/nullptr,
9344         /*CxtI=*/nullptr, /*DT=*/nullptr, MaxAnalysisRecursionDepth - 1);
9345     // Even if L1/R1 are not both constant, we can still sometimes deduce
9346     // relationship from a single constant. For example X u> Y implies X != 0.
9347     if (auto R = isImpliedCondCommonOperandWithCR(LPred, LCR, RPred, RCR))
9348       return R;
9349     // If both L1/R1 were exact constant ranges and we didn't get anything
9350     // here, we won't be able to deduce this.
9351     if (match(L1, m_APInt(Unused)) && match(R1, m_APInt(Unused)))
9352       return std::nullopt;
9353   }
9354 
9355   // Can we infer anything when the two compares have matching operands?
9356   if (L0 == R0 && L1 == R1)
9357     return ICmpInst::isImpliedByMatchingCmp(LPred, RPred);
9358 
9359   // It only really makes sense in the context of signed comparison: "X - Y
9360   // must be non-negative if X >= Y and there is no overflow".
9361   // Take SGT as an example:  L0:x > L1:y and C >= 0
9362   //                      ==> R0:(x -nsw y) < R1:(-C) is false
9363   CmpInst::Predicate SignedLPred = LPred.getPreferredSignedPredicate();
9364   if ((SignedLPred == ICmpInst::ICMP_SGT ||
9365        SignedLPred == ICmpInst::ICMP_SGE) &&
9366       match(R0, m_NSWSub(m_Specific(L0), m_Specific(L1)))) {
9367     if (match(R1, m_NonPositive()) &&
9368         ICmpInst::isImpliedByMatchingCmp(SignedLPred, RPred) == false)
9369       return false;
9370   }
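  // E.g. if "x s> y" is true and the subtraction is NSW, then x - y >= 1, so
  // an RHS such as "(x -nsw y) s< 0" is known to be false.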
9371 
9372   // Take SLT as an example:  L0:x < L1:y and C <= 0
9373   //                      ==> R0:(x -nsw y) < R1:(-C) is true
9374   if ((SignedLPred == ICmpInst::ICMP_SLT ||
9375        SignedLPred == ICmpInst::ICMP_SLE) &&
9376       match(R0, m_NSWSub(m_Specific(L0), m_Specific(L1)))) {
9377     if (match(R1, m_NonNegative()) &&
9378         ICmpInst::isImpliedByMatchingCmp(SignedLPred, RPred) == true)
9379       return true;
9380   }
9381 
9382   // L0 = R0 = L1 + R1, L0 >=u L1 implies R0 >=u R1, L0 <u L1 implies R0 <u R1
9383   if (L0 == R0 &&
9384       (LPred == ICmpInst::ICMP_ULT || LPred == ICmpInst::ICMP_UGE) &&
9385       (RPred == ICmpInst::ICMP_ULT || RPred == ICmpInst::ICMP_UGE) &&
9386       match(L0, m_c_Add(m_Specific(L1), m_Specific(R1))))
9387     return CmpPredicate::getMatching(LPred, RPred).has_value();
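  // For example, with L0 == R0 == (a + b), "(a + b) u< a" holds exactly when
  // the addition wraps, which is also exactly when "(a + b) u< b" holds; so
  // matching predicates imply true, and mismatched ones imply false.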
9388 
9389   if (auto P = CmpPredicate::getMatching(LPred, RPred))
9390     return isImpliedCondOperands(*P, L0, L1, R0, R1);
9391 
9392   return std::nullopt;
9393 }
9394 
9395 /// Return true if LHS implies RHS is true.  Return false if LHS implies RHS is
9396 /// false.  Otherwise, return std::nullopt if we can't infer anything.  We
9397 /// expect the RHS to be an icmp and the LHS to be an 'and', 'or', or a 'select'
9398 /// instruction.
9399 static std::optional<bool>
9400 isImpliedCondAndOr(const Instruction *LHS, CmpPredicate RHSPred,
9401                    const Value *RHSOp0, const Value *RHSOp1,
9402                    const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
9403   // The LHS must be an 'or', 'and', or a 'select' instruction.
9404   assert((LHS->getOpcode() == Instruction::And ||
9405           LHS->getOpcode() == Instruction::Or ||
9406           LHS->getOpcode() == Instruction::Select) &&
9407          "Expected LHS to be 'and', 'or', or 'select'.");
9408 
9409   assert(Depth <= MaxAnalysisRecursionDepth && "Hit recursion limit");
9410 
9411   // If the result of an 'or' is false, then we know both legs of the 'or' are
9412   // false.  Similarly, if the result of an 'and' is true, then we know both
9413   // legs of the 'and' are true.
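  // E.g. if "a && b" is known true, both a and b are true, so it suffices for
  // either leg to imply the RHS.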
9414   const Value *ALHS, *ARHS;
9415   if ((!LHSIsTrue && match(LHS, m_LogicalOr(m_Value(ALHS), m_Value(ARHS)))) ||
9416       (LHSIsTrue && match(LHS, m_LogicalAnd(m_Value(ALHS), m_Value(ARHS))))) {
9417     // FIXME: Make this non-recursive.
9418     if (std::optional<bool> Implication = isImpliedCondition(
9419             ALHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth + 1))
9420       return Implication;
9421     if (std::optional<bool> Implication = isImpliedCondition(
9422             ARHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth + 1))
9423       return Implication;
9424     return std::nullopt;
9425   }
9426   return std::nullopt;
9427 }
9428 
9429 std::optional<bool>
9430 llvm::isImpliedCondition(const Value *LHS, CmpPredicate RHSPred,
9431                          const Value *RHSOp0, const Value *RHSOp1,
9432                          const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
9433   // Bail out when we hit the limit.
9434   if (Depth == MaxAnalysisRecursionDepth)
9435     return std::nullopt;
9436 
9437   // A mismatch occurs when we compare a scalar cmp to a vector cmp, for
9438   // example.
9439   if (RHSOp0->getType()->isVectorTy() != LHS->getType()->isVectorTy())
9440     return std::nullopt;
9441 
9442   assert(LHS->getType()->isIntOrIntVectorTy(1) &&
9443          "Expected integer type only!");
9444 
9445   // Match not
9446   if (match(LHS, m_Not(m_Value(LHS))))
9447     LHSIsTrue = !LHSIsTrue;
9448 
9449   // Both LHS and RHS are icmps.
9450   if (const auto *LHSCmp = dyn_cast<ICmpInst>(LHS))
9451     return isImpliedCondICmps(LHSCmp->getCmpPredicate(), LHSCmp->getOperand(0),
9452                               LHSCmp->getOperand(1), RHSPred, RHSOp0, RHSOp1,
9453                               DL, LHSIsTrue);
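  // A 'trunc nuw' to i1 guarantees that the dropped bits are zero, so the
  // result is true exactly when the source value is nonzero; model it below
  // as "V != 0".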
9454   const Value *V;
9455   if (match(LHS, m_NUWTrunc(m_Value(V))))
9456     return isImpliedCondICmps(CmpInst::ICMP_NE, V,
9457                               ConstantInt::get(V->getType(), 0), RHSPred,
9458                               RHSOp0, RHSOp1, DL, LHSIsTrue);
9459 
9460   // The LHS should be an 'or', 'and', or a 'select' instruction.  We expect
9461   // the RHS to be an icmp.
9462   // FIXME: Add support for and/or/select on the RHS.
9463   if (const Instruction *LHSI = dyn_cast<Instruction>(LHS)) {
9464     if ((LHSI->getOpcode() == Instruction::And ||
9465          LHSI->getOpcode() == Instruction::Or ||
9466          LHSI->getOpcode() == Instruction::Select))
9467       return isImpliedCondAndOr(LHSI, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue,
9468                                 Depth);
9469   }
9470   return std::nullopt;
9471 }
9472 
9473 std::optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS,
9474                                              const DataLayout &DL,
9475                                              bool LHSIsTrue, unsigned Depth) {
9476   // LHS ==> RHS by definition
9477   if (LHS == RHS)
9478     return LHSIsTrue;
9479 
9480   // Match not
9481   bool InvertRHS = false;
9482   if (match(RHS, m_Not(m_Value(RHS)))) {
9483     if (LHS == RHS)
9484       return !LHSIsTrue;
9485     InvertRHS = true;
9486   }
9487 
9488   if (const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS)) {
9489     if (auto Implied = isImpliedCondition(
9490             LHS, RHSCmp->getCmpPredicate(), RHSCmp->getOperand(0),
9491             RHSCmp->getOperand(1), DL, LHSIsTrue, Depth))
9492       return InvertRHS ? !*Implied : *Implied;
9493     return std::nullopt;
9494   }
9495 
9496   const Value *V;
9497   if (match(RHS, m_NUWTrunc(m_Value(V)))) {
9498     if (auto Implied = isImpliedCondition(LHS, CmpInst::ICMP_NE, V,
9499                                           ConstantInt::get(V->getType(), 0), DL,
9500                                           LHSIsTrue, Depth))
9501       return InvertRHS ? !*Implied : *Implied;
9502     return std::nullopt;
9503   }
9504 
9505   if (Depth == MaxAnalysisRecursionDepth)
9506     return std::nullopt;
9507 
9508   // LHS ==> (RHS1 || RHS2) if LHS ==> RHS1 or LHS ==> RHS2
9509   // LHS ==> !(RHS1 && RHS2) if LHS ==> !RHS1 or LHS ==> !RHS2
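  // E.g. "x u< 4" implies "(x u< 8) || b" is true, because it already implies
  // the first disjunct.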
9510   const Value *RHS1, *RHS2;
9511   if (match(RHS, m_LogicalOr(m_Value(RHS1), m_Value(RHS2)))) {
9512     if (std::optional<bool> Imp =
9513             isImpliedCondition(LHS, RHS1, DL, LHSIsTrue, Depth + 1))
9514       if (*Imp == true)
9515         return !InvertRHS;
9516     if (std::optional<bool> Imp =
9517             isImpliedCondition(LHS, RHS2, DL, LHSIsTrue, Depth + 1))
9518       if (*Imp == true)
9519         return !InvertRHS;
9520   }
9521   if (match(RHS, m_LogicalAnd(m_Value(RHS1), m_Value(RHS2)))) {
9522     if (std::optional<bool> Imp =
9523             isImpliedCondition(LHS, RHS1, DL, LHSIsTrue, Depth + 1))
9524       if (*Imp == false)
9525         return InvertRHS;
9526     if (std::optional<bool> Imp =
9527             isImpliedCondition(LHS, RHS2, DL, LHSIsTrue, Depth + 1))
9528       if (*Imp == false)
9529         return InvertRHS;
9530   }
9531 
9532   return std::nullopt;
9533 }
9534 
9535 // Returns a pair (Condition, ConditionIsTrue), where Condition is a branch
9536 // condition dominating ContextI, or nullptr if no condition is found.
9537 static std::pair<Value *, bool>
9538 getDomPredecessorCondition(const Instruction *ContextI) {
9539   if (!ContextI || !ContextI->getParent())
9540     return {nullptr, false};
9541 
9542   // TODO: This is a poor/cheap way to determine dominance. Should we use a
9543   // dominator tree (e.g., from a SimplifyQuery) instead?
9544   const BasicBlock *ContextBB = ContextI->getParent();
9545   const BasicBlock *PredBB = ContextBB->getSinglePredecessor();
9546   if (!PredBB)
9547     return {nullptr, false};
9548 
9549   // We need a conditional branch in the predecessor.
9550   Value *PredCond;
9551   BasicBlock *TrueBB, *FalseBB;
9552   if (!match(PredBB->getTerminator(), m_Br(m_Value(PredCond), TrueBB, FalseBB)))
9553     return {nullptr, false};
9554 
9555   // The branch should get simplified. Don't bother simplifying this condition.
9556   if (TrueBB == FalseBB)
9557     return {nullptr, false};
9558 
9559   assert((TrueBB == ContextBB || FalseBB == ContextBB) &&
9560          "Predecessor block does not point to successor?");
9561 
9562   // Is this condition implied by the predecessor condition?
9563   return {PredCond, TrueBB == ContextBB};
9564 }
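// For example, if %ctx's single predecessor ends in
//   br i1 %c, label %ctx, label %other
// this returns {%c, true} for a context instruction in %ctx: on entry to
// %ctx, the condition %c is known to be true.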
9565 
9566 std::optional<bool> llvm::isImpliedByDomCondition(const Value *Cond,
9567                                                   const Instruction *ContextI,
9568                                                   const DataLayout &DL) {
9569   assert(Cond->getType()->isIntOrIntVectorTy(1) && "Condition must be bool");
9570   auto PredCond = getDomPredecessorCondition(ContextI);
9571   if (PredCond.first)
9572     return isImpliedCondition(PredCond.first, Cond, DL, PredCond.second);
9573   return std::nullopt;
9574 }
9575 
9576 std::optional<bool> llvm::isImpliedByDomCondition(CmpPredicate Pred,
9577                                                   const Value *LHS,
9578                                                   const Value *RHS,
9579                                                   const Instruction *ContextI,
9580                                                   const DataLayout &DL) {
9581   auto PredCond = getDomPredecessorCondition(ContextI);
9582   if (PredCond.first)
9583     return isImpliedCondition(PredCond.first, Pred, LHS, RHS, DL,
9584                               PredCond.second);
9585   return std::nullopt;
9586 }
9587 
9588 static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower,
9589                               APInt &Upper, const InstrInfoQuery &IIQ,
9590                               bool PreferSignedRange) {
9591   unsigned Width = Lower.getBitWidth();
9592   const APInt *C;
9593   switch (BO.getOpcode()) {
9594   case Instruction::Sub:
9595     if (match(BO.getOperand(0), m_APInt(C))) {
9596       bool HasNSW = IIQ.hasNoSignedWrap(&BO);
9597       bool HasNUW = IIQ.hasNoUnsignedWrap(&BO);
9598 
9599       // If the caller expects a signed compare, then try to use a signed range.
9600       // Otherwise if both no-wraps are set, use the unsigned range because it
9601       // is never larger than the signed range. Example:
9602       // "sub nuw nsw i8 -2, x" is unsigned [0, 254] vs. signed [-128, 126].
9603       // "sub nuw nsw i8 2, x" is unsigned [0, 2] vs. signed [-125, 127].
9604       if (PreferSignedRange && HasNSW && HasNUW)
9605         HasNUW = false;
9606 
9607       if (HasNUW) {
9608         // 'sub nuw c, x' produces [0, C].
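        // E.g. 'sub nuw i8 10, %x' requires %x u<= 10 and so yields [0, 10].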
9609         Upper = *C + 1;
9610       } else if (HasNSW) {
9611         if (C->isNegative()) {
9612           // 'sub nsw -C, x' produces [SINT_MIN, -C - SINT_MIN].
9613           Lower = APInt::getSignedMinValue(Width);
9614           Upper = *C - APInt::getSignedMaxValue(Width);
9615         } else {
9616           // Note that 'sub 0, INT_MIN' is not NSW; it technically is a
9617           // signed wrap. 'sub nsw C, x' produces [C - SINT_MAX, SINT_MAX].
9618           Lower = *C - APInt::getSignedMaxValue(Width);
9619           Upper = APInt::getSignedMinValue(Width);
9620         }
9621       }
9622     }
9623     break;
9624   case Instruction::Add:
9625     if (match(BO.getOperand(1), m_APInt(C)) && !C->isZero()) {
9626       bool HasNSW = IIQ.hasNoSignedWrap(&BO);
9627       bool HasNUW = IIQ.hasNoUnsignedWrap(&BO);
9628 
9629       // If the caller expects a signed compare, then try to use a signed
9630       // range. Otherwise if both no-wraps are set, use the unsigned range
9631       // because it is never larger than the signed range. Example: "add nuw
9632       // nsw i8 X, -2" is unsigned [254,255] vs. signed [-128, 125].
9633       if (PreferSignedRange && HasNSW && HasNUW)
9634         HasNUW = false;
9635 
9636       if (HasNUW) {
9637         // 'add nuw x, C' produces [C, UINT_MAX].
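        // E.g. 'add nuw i8 %x, 16' yields [16, UINT_MAX].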
9638         Lower = *C;
9639       } else if (HasNSW) {
9640         if (C->isNegative()) {
9641           // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C].
9642           Lower = APInt::getSignedMinValue(Width);
9643           Upper = APInt::getSignedMaxValue(Width) + *C + 1;
9644         } else {
9645           // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX].
9646           Lower = APInt::getSignedMinValue(Width) + *C;
9647           Upper = APInt::getSignedMaxValue(Width) + 1;
9648         }
9649       }
9650     }
9651     break;
9652 
9653   case Instruction::And:
9654     if (match(BO.getOperand(1), m_APInt(C)))
9655       // 'and x, C' produces [0, C].
9656       Upper = *C + 1;
9657     // X & -X is a power of two or zero, so we can cap the value at the max
9658     // power of two.
9659     if (match(BO.getOperand(0), m_Neg(m_Specific(BO.getOperand(1)))) ||
9660         match(BO.getOperand(1), m_Neg(m_Specific(BO.getOperand(0)))))
9661       Upper = APInt::getSignedMinValue(Width) + 1;
9662     break;
9663 
9664   case Instruction::Or:
9665     if (match(BO.getOperand(1), m_APInt(C)))
9666       // 'or x, C' produces [C, UINT_MAX].
9667       Lower = *C;
9668     break;
9669 
9670   case Instruction::AShr:
9671     if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) {
9672       // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C].
9673       Lower = APInt::getSignedMinValue(Width).ashr(*C);
9674       Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1;
9675     } else if (match(BO.getOperand(0), m_APInt(C))) {
9676       unsigned ShiftAmount = Width - 1;
9677       if (!C->isZero() && IIQ.isExact(&BO))
9678         ShiftAmount = C->countr_zero();
9679       if (C->isNegative()) {
9680         // 'ashr C, x' produces [C, C >> (Width-1)]
9681         Lower = *C;
9682         Upper = C->ashr(ShiftAmount) + 1;
9683       } else {
9684         // 'ashr C, x' produces [C >> (Width-1), C]
9685         Lower = C->ashr(ShiftAmount);
9686         Upper = *C + 1;
9687       }
9688     }
9689     break;
9690 
9691   case Instruction::LShr:
9692     if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) {
9693       // 'lshr x, C' produces [0, UINT_MAX >> C].
9694       Upper = APInt::getAllOnes(Width).lshr(*C) + 1;
9695     } else if (match(BO.getOperand(0), m_APInt(C))) {
9696       // 'lshr C, x' produces [C >> (Width-1), C].
9697       unsigned ShiftAmount = Width - 1;
9698       if (!C->isZero() && IIQ.isExact(&BO))
9699         ShiftAmount = C->countr_zero();
9700       Lower = C->lshr(ShiftAmount);
9701       Upper = *C + 1;
9702     }
9703     break;
9704 
9705   case Instruction::Shl:
9706     if (match(BO.getOperand(0), m_APInt(C))) {
9707       if (IIQ.hasNoUnsignedWrap(&BO)) {
9708         // 'shl nuw C, x' produces [C, C << CLZ(C)]
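        // E.g. 'shl nuw i8 3, %x' yields [3, 192], since
        // 3 << countl_zero(3) == 3 << 6 == 192.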
9709         Lower = *C;
9710         Upper = Lower.shl(Lower.countl_zero()) + 1;
9711       } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw?
9712         if (C->isNegative()) {
9713           // 'shl nsw C, x' produces [C << CLO(C)-1, C]
9714           unsigned ShiftAmount = C->countl_one() - 1;
9715           Lower = C->shl(ShiftAmount);
9716           Upper = *C + 1;
9717         } else {
9718           // 'shl nsw C, x' produces [C, C << CLZ(C)-1]
9719           unsigned ShiftAmount = C->countl_zero() - 1;
9720           Lower = *C;
9721           Upper = C->shl(ShiftAmount) + 1;
9722         }
9723       } else {
9724         // If the low bit is set, the value can never be zero.
9725         if ((*C)[0])
9726           Lower = APInt::getOneBitSet(Width, 0);
9727         // If we are shifting a constant, the largest the result can be is when
9728         // the longest sequence of consecutive ones is shifted into the high
9729         // bits (breaking ties toward the higher sequence). At the moment we
9730         // take a liberal upper bound on this by just popcounting the constant.
9731         // TODO: There may be a bitwise trick to find the longest/highest
9732         // consecutive sequence of ones (the naive method is an O(Width) loop).
9733         Upper = APInt::getHighBitsSet(Width, C->popcount()) + 1;
9734       }
9735     } else if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) {
9736       Upper = APInt::getBitsSetFrom(Width, C->getZExtValue()) + 1;
9737     }
9738     break;
9739 
9740   case Instruction::SDiv:
9741     if (match(BO.getOperand(1), m_APInt(C))) {
9742       APInt IntMin = APInt::getSignedMinValue(Width);
9743       APInt IntMax = APInt::getSignedMaxValue(Width);
9744       if (C->isAllOnes()) {
9745         // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX].
9747         Lower = IntMin + 1;
9748         Upper = IntMax + 1;
9749       } else if (C->countl_zero() < Width - 1) {
9750         // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C]
9751         //    where C != -1 and C != 0 and C != 1
9752         Lower = IntMin.sdiv(*C);
9753         Upper = IntMax.sdiv(*C);
9754         if (Lower.sgt(Upper))
9755           std::swap(Lower, Upper);
9756         Upper = Upper + 1;
9757         assert(Upper != Lower && "Upper part of range has wrapped!");
9758       }
9759     } else if (match(BO.getOperand(0), m_APInt(C))) {
9760       if (C->isMinSignedValue()) {
9761         // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2].
9762         Lower = *C;
9763         Upper = Lower.lshr(1) + 1;
9764       } else {
9765         // 'sdiv C, x' produces [-|C|, |C|].
9766         Upper = C->abs() + 1;
9767         Lower = (-Upper) + 1;
9768       }
9769     }
9770     break;
9771 
9772   case Instruction::UDiv:
9773     if (match(BO.getOperand(1), m_APInt(C)) && !C->isZero()) {
9774       // 'udiv x, C' produces [0, UINT_MAX / C].
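      // E.g. 'udiv i8 %x, 3' yields [0, 255 / 3] == [0, 85].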
9775       Upper = APInt::getMaxValue(Width).udiv(*C) + 1;
9776     } else if (match(BO.getOperand(0), m_APInt(C))) {
9777       // 'udiv C, x' produces [0, C].
9778       Upper = *C + 1;
9779     }
9780     break;
9781 
9782   case Instruction::SRem:
9783     if (match(BO.getOperand(1), m_APInt(C))) {
9784       // 'srem x, C' produces (-|C|, |C|).
9785       Upper = C->abs();
9786       Lower = (-Upper) + 1;
9787     } else if (match(BO.getOperand(0), m_APInt(C))) {
9788       if (C->isNegative()) {
9789         // 'srem -|C|, x' produces [-|C|, 0].
9790         Upper = 1;
9791         Lower = *C;
9792       } else {
9793         // 'srem |C|, x' produces [0, |C|].
9794         Upper = *C + 1;
9795       }
9796     }
9797     break;
9798 
9799   case Instruction::URem:
9800     if (match(BO.getOperand(1), m_APInt(C)))
9801       // 'urem x, C' produces [0, C).
9802       Upper = *C;
9803     else if (match(BO.getOperand(0), m_APInt(C)))
9804       // 'urem C, x' produces [0, C].
9805       Upper = *C + 1;
9806     break;
9807 
9808   default:
9809     break;
9810   }
9811 }
9812 
9813 static ConstantRange getRangeForIntrinsic(const IntrinsicInst &II,
9814                                           bool UseInstrInfo) {
9815   unsigned Width = II.getType()->getScalarSizeInBits();
9816   const APInt *C;
9817   switch (II.getIntrinsicID()) {
9818   case Intrinsic::ctlz:
9819   case Intrinsic::cttz: {
9820     APInt Upper(Width, Width);
9821     if (!UseInstrInfo || !match(II.getArgOperand(1), m_One()))
9822       Upper += 1;
9823     // Maximum of set/clear bits is the bit width.
9824     return ConstantRange::getNonEmpty(APInt::getZero(Width), Upper);
9825   }
9826   case Intrinsic::ctpop:
9827     // Maximum of set/clear bits is the bit width.
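    // E.g. ctpop on i8 yields [0, 8], i.e. getNonEmpty(0, 9).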
9828     return ConstantRange::getNonEmpty(APInt::getZero(Width),
9829                                       APInt(Width, Width) + 1);
9830   case Intrinsic::uadd_sat:
9831     // uadd.sat(x, C) produces [C, UINT_MAX].
9832     if (match(II.getOperand(0), m_APInt(C)) ||
9833         match(II.getOperand(1), m_APInt(C)))
9834       return ConstantRange::getNonEmpty(*C, APInt::getZero(Width));
9835     break;
9836   case Intrinsic::sadd_sat:
9837     if (match(II.getOperand(0), m_APInt(C)) ||
9838         match(II.getOperand(1), m_APInt(C))) {
9839       if (C->isNegative())
9840         // sadd.sat(x, -C) produces [SINT_MIN, SINT_MAX + (-C)].
9841         return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width),
9842                                           APInt::getSignedMaxValue(Width) + *C +
9843                                               1);
9844 
9845       // sadd.sat(x, +C) produces [SINT_MIN + C, SINT_MAX].
9846       return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width) + *C,
9847                                         APInt::getSignedMaxValue(Width) + 1);
9848     }
9849     break;
9850   case Intrinsic::usub_sat:
9851     // usub.sat(C, x) produces [0, C].
9852     if (match(II.getOperand(0), m_APInt(C)))
9853       return ConstantRange::getNonEmpty(APInt::getZero(Width), *C + 1);
9854 
9855     // usub.sat(x, C) produces [0, UINT_MAX - C].
9856     if (match(II.getOperand(1), m_APInt(C)))
9857       return ConstantRange::getNonEmpty(APInt::getZero(Width),
9858                                         APInt::getMaxValue(Width) - *C + 1);
9859     break;
9860   case Intrinsic::ssub_sat:
9861     if (match(II.getOperand(0), m_APInt(C))) {
9862       if (C->isNegative())
9863         // ssub.sat(-C, x) produces [SINT_MIN, -SINT_MIN + (-C)].
9864         return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width),
9865                                           *C - APInt::getSignedMinValue(Width) +
9866                                               1);
9867 
9868       // ssub.sat(+C, x) produces [-SINT_MAX + C, SINT_MAX].
9869       return ConstantRange::getNonEmpty(*C - APInt::getSignedMaxValue(Width),
9870                                         APInt::getSignedMaxValue(Width) + 1);
9871     } else if (match(II.getOperand(1), m_APInt(C))) {
9872       if (C->isNegative())
9873         // ssub.sat(x, -C) produces [SINT_MIN - (-C), SINT_MAX]:
9874         return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width) - *C,
9875                                           APInt::getSignedMaxValue(Width) + 1);
9876 
9877       // ssub.sat(x, +C) produces [SINT_MIN, SINT_MAX - C].
9878       return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width),
9879                                         APInt::getSignedMaxValue(Width) - *C +
9880                                             1);
9881     }
9882     break;
9883   case Intrinsic::umin:
9884   case Intrinsic::umax:
9885   case Intrinsic::smin:
9886   case Intrinsic::smax:
9887     if (!match(II.getOperand(0), m_APInt(C)) &&
9888         !match(II.getOperand(1), m_APInt(C)))
9889       break;
9890 
9891     switch (II.getIntrinsicID()) {
9892     case Intrinsic::umin:
9893       return ConstantRange::getNonEmpty(APInt::getZero(Width), *C + 1);
9894     case Intrinsic::umax:
9895       return ConstantRange::getNonEmpty(*C, APInt::getZero(Width));
9896     case Intrinsic::smin:
9897       return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width),
9898                                         *C + 1);
9899     case Intrinsic::smax:
9900       return ConstantRange::getNonEmpty(*C,
9901                                         APInt::getSignedMaxValue(Width) + 1);
9902     default:
9903       llvm_unreachable("Must be min/max intrinsic");
9904     }
9905     break;
9906   case Intrinsic::abs:
9907     // If abs of SIGNED_MIN is poison, then the result is [0..SIGNED_MAX],
9908     // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN.
9909     if (match(II.getOperand(1), m_One()))
9910       return ConstantRange::getNonEmpty(APInt::getZero(Width),
9911                                         APInt::getSignedMaxValue(Width) + 1);
9912 
9913     return ConstantRange::getNonEmpty(APInt::getZero(Width),
9914                                       APInt::getSignedMinValue(Width) + 1);
9915   case Intrinsic::vscale:
9916     if (!II.getParent() || !II.getFunction())
9917       break;
9918     return getVScaleRange(II.getFunction(), Width);
9919   default:
9920     break;
9921   }
9922 
9923   return ConstantRange::getFull(Width);
9924 }
9925 
9926 static ConstantRange getRangeForSelectPattern(const SelectInst &SI,
9927                                               const InstrInfoQuery &IIQ) {
9928   unsigned BitWidth = SI.getType()->getScalarSizeInBits();
9929   const Value *LHS = nullptr, *RHS = nullptr;
9930   SelectPatternResult R = matchSelectPattern(&SI, LHS, RHS);
9931   if (R.Flavor == SPF_UNKNOWN)
9932     return ConstantRange::getFull(BitWidth);
9933 
9934   if (R.Flavor == SelectPatternFlavor::SPF_ABS) {
9935     // If the negation part of the abs (in RHS) has the NSW flag,
9936     // then the result of abs(X) is [0..SIGNED_MAX],
9937     // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN.
9938     if (match(RHS, m_Neg(m_Specific(LHS))) &&
9939         IIQ.hasNoSignedWrap(cast<Instruction>(RHS)))
9940       return ConstantRange::getNonEmpty(APInt::getZero(BitWidth),
9941                                         APInt::getSignedMaxValue(BitWidth) + 1);
9942 
9943     return ConstantRange::getNonEmpty(APInt::getZero(BitWidth),
9944                                       APInt::getSignedMinValue(BitWidth) + 1);
9945   }
9946 
9947   if (R.Flavor == SelectPatternFlavor::SPF_NABS) {
9948     // The result of -abs(X) is <= 0.
9949     return ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
9950                                       APInt(BitWidth, 1));
9951   }
9952 
9953   const APInt *C;
9954   if (!match(LHS, m_APInt(C)) && !match(RHS, m_APInt(C)))
9955     return ConstantRange::getFull(BitWidth);
9956 
9957   switch (R.Flavor) {
9958   case SPF_UMIN:
9959     return ConstantRange::getNonEmpty(APInt::getZero(BitWidth), *C + 1);
9960   case SPF_UMAX:
9961     return ConstantRange::getNonEmpty(*C, APInt::getZero(BitWidth));
9962   case SPF_SMIN:
9963     return ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
9964                                       *C + 1);
9965   case SPF_SMAX:
9966     return ConstantRange::getNonEmpty(*C,
9967                                       APInt::getSignedMaxValue(BitWidth) + 1);
9968   default:
9969     return ConstantRange::getFull(BitWidth);
9970   }
9971 }
9972 
9973 static void setLimitForFPToI(const Instruction *I, APInt &Lower, APInt &Upper) {
9974   // The maximum representable value of a half is 65504. For floats, the
9975   // maximum value is 3.4e38, which requires roughly 129 bits.
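  // E.g. 'fptosi half %h to i32' can only produce values in [-65504, 65504],
  // and 'fptoui half %h to i32' only values in [0, 65504] (out-of-range
  // inputs are poison).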
9976   unsigned BitWidth = I->getType()->getScalarSizeInBits();
9977   if (!I->getOperand(0)->getType()->getScalarType()->isHalfTy())
9978     return;
9979   if (isa<FPToSIInst>(I) && BitWidth >= 17) {
9980     Lower = APInt(BitWidth, -65504, true);
9981     Upper = APInt(BitWidth, 65505);
9982   }
9983 
9984   if (isa<FPToUIInst>(I) && BitWidth >= 16) {
9985     // For a fptoui the lower limit is left as 0.
9986     Upper = APInt(BitWidth, 65505);
9987   }
9988 }
9989 
9990 ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
9991                                          bool UseInstrInfo, AssumptionCache *AC,
9992                                          const Instruction *CtxI,
9993                                          const DominatorTree *DT,
9994                                          unsigned Depth) {
9995   assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction");
9996 
9997   if (Depth == MaxAnalysisRecursionDepth)
9998     return ConstantRange::getFull(V->getType()->getScalarSizeInBits());
9999 
10000   if (auto *C = dyn_cast<Constant>(V))
10001     return C->toConstantRange();
10002 
10003   unsigned BitWidth = V->getType()->getScalarSizeInBits();
10004   InstrInfoQuery IIQ(UseInstrInfo);
10005   ConstantRange CR = ConstantRange::getFull(BitWidth);
10006   if (auto *BO = dyn_cast<BinaryOperator>(V)) {
10007     APInt Lower = APInt(BitWidth, 0);
10008     APInt Upper = APInt(BitWidth, 0);
10009     // TODO: Return ConstantRange.
10010     setLimitsForBinOp(*BO, Lower, Upper, IIQ, ForSigned);
10011     CR = ConstantRange::getNonEmpty(Lower, Upper);
10012   } else if (auto *II = dyn_cast<IntrinsicInst>(V))
10013     CR = getRangeForIntrinsic(*II, UseInstrInfo);
10014   else if (auto *SI = dyn_cast<SelectInst>(V)) {
10015     ConstantRange CRTrue = computeConstantRange(
10016         SI->getTrueValue(), ForSigned, UseInstrInfo, AC, CtxI, DT, Depth + 1);
10017     ConstantRange CRFalse = computeConstantRange(
10018         SI->getFalseValue(), ForSigned, UseInstrInfo, AC, CtxI, DT, Depth + 1);
10019     CR = CRTrue.unionWith(CRFalse);
10020     CR = CR.intersectWith(getRangeForSelectPattern(*SI, IIQ));
10021   } else if (isa<FPToUIInst>(V) || isa<FPToSIInst>(V)) {
10022     APInt Lower = APInt(BitWidth, 0);
10023     APInt Upper = APInt(BitWidth, 0);
10024     // TODO: Return ConstantRange.
10025     setLimitForFPToI(cast<Instruction>(V), Lower, Upper);
10026     CR = ConstantRange::getNonEmpty(Lower, Upper);
10027   } else if (const auto *A = dyn_cast<Argument>(V))
10028     if (std::optional<ConstantRange> Range = A->getRange())
10029       CR = *Range;
10030 
10031   if (auto *I = dyn_cast<Instruction>(V)) {
10032     if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range))
10033       CR = CR.intersectWith(getConstantRangeFromMetadata(*Range));
10034 
10035     if (const auto *CB = dyn_cast<CallBase>(V))
10036       if (std::optional<ConstantRange> Range = CB->getRange())
10037         CR = CR.intersectWith(*Range);
10038   }
10039 
10040   if (CtxI && AC) {
10041     // Try to restrict the range based on information from assumptions.
10042     for (auto &AssumeVH : AC->assumptionsFor(V)) {
10043       if (!AssumeVH)
10044         continue;
10045       CallInst *I = cast<CallInst>(AssumeVH);
10046       assert(I->getParent()->getParent() == CtxI->getParent()->getParent() &&
10047              "Got assumption for the wrong function!");
10048       assert(I->getIntrinsicID() == Intrinsic::assume &&
10049              "must be an assume intrinsic");
10050 
10051       if (!isValidAssumeForContext(I, CtxI, DT))
10052         continue;
10053       Value *Arg = I->getArgOperand(0);
10054       ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
10055       // Currently we just use information from comparisons.
10056       if (!Cmp || Cmp->getOperand(0) != V)
10057         continue;
10058       // TODO: Set "ForSigned" parameter via Cmp->isSigned()?
10059       ConstantRange RHS =
10060           computeConstantRange(Cmp->getOperand(1), /* ForSigned */ false,
10061                                UseInstrInfo, AC, I, DT, Depth + 1);
10062       CR = CR.intersectWith(
10063           ConstantRange::makeAllowedICmpRegion(Cmp->getPredicate(), RHS));
10064     }
10065   }
10066 
10067   return CR;
10068 }
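// Illustrative use (with a hypothetical value %a): for
//   %a = add nuw i8 %x, 16
// computeConstantRange(%a, /*ForSigned=*/false, ...) yields [16, 255] via
// setLimitsForBinOp above.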
10069 
10070 static void
10071 addValueAffectedByCondition(Value *V,
10072                             function_ref<void(Value *)> InsertAffected) {
10073   assert(V != nullptr);
10074   if (isa<Argument>(V) || isa<GlobalValue>(V)) {
10075     InsertAffected(V);
10076   } else if (auto *I = dyn_cast<Instruction>(V)) {
10077     InsertAffected(V);
10078 
10079     // Peek through unary operators to find the source of the condition.
10080     Value *Op;
10081     if (match(I, m_CombineOr(m_PtrToInt(m_Value(Op)), m_Trunc(m_Value(Op))))) {
10082       if (isa<Instruction>(Op) || isa<Argument>(Op))
10083         InsertAffected(Op);
10084     }
10085   }
10086 }
10087 
10088 void llvm::findValuesAffectedByCondition(
10089     Value *Cond, bool IsAssume, function_ref<void(Value *)> InsertAffected) {
10090   auto AddAffected = [&InsertAffected](Value *V) {
10091     addValueAffectedByCondition(V, InsertAffected);
10092   };
10093 
10094   auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS) {
10095     if (IsAssume) {
10096       AddAffected(LHS);
10097       AddAffected(RHS);
10098     } else if (match(RHS, m_Constant()))
10099       AddAffected(LHS);
10100   };
10101 
10102   SmallVector<Value *, 8> Worklist;
10103   SmallPtrSet<Value *, 8> Visited;
10104   Worklist.push_back(Cond);
10105   while (!Worklist.empty()) {
10106     Value *V = Worklist.pop_back_val();
10107     if (!Visited.insert(V).second)
10108       continue;
10109 
10110     CmpPredicate Pred;
10111     Value *A, *B, *X;
10112 
10113     if (IsAssume) {
10114       AddAffected(V);
10115       if (match(V, m_Not(m_Value(X))))
10116         AddAffected(X);
10117     }
10118 
10119     if (match(V, m_LogicalOp(m_Value(A), m_Value(B)))) {
10120       // assume(A && B) is split to -> assume(A); assume(B);
10121       // assume(!(A || B)) is split to -> assume(!A); assume(!B);
10122       // Finally, assume(A || B) / assume(!(A && B)) generally don't provide
10123       // enough information to be worth handling (intersection of information as
10124       // opposed to union).
10125       if (!IsAssume) {
10126         Worklist.push_back(A);
10127         Worklist.push_back(B);
10128       }
10129     } else if (match(V, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
10130       bool HasRHSC = match(B, m_ConstantInt());
10131       if (ICmpInst::isEquality(Pred)) {
10132         AddAffected(A);
10133         if (IsAssume)
10134           AddAffected(B);
10135         if (HasRHSC) {
10136           Value *Y;
10137           // (X & C) or (X | C).
10138           // (X << C) or (X >>_s C) or (X >>_u C).
10139           if (match(A, m_Shift(m_Value(X), m_ConstantInt())))
10140             AddAffected(X);
10141           else if (match(A, m_And(m_Value(X), m_Value(Y))) ||
10142                    match(A, m_Or(m_Value(X), m_Value(Y)))) {
10143             AddAffected(X);
10144             AddAffected(Y);
10145           }
10146         }
10147       } else {
10148         AddCmpOperands(A, B);
10149         if (HasRHSC) {
10150           // Handle (A + C1) u< C2, which is the canonical form of
10151           // A > C3 && A < C4.
10152           if (match(A, m_AddLike(m_Value(X), m_ConstantInt())))
10153             AddAffected(X);
10154 
10155           if (ICmpInst::isUnsigned(Pred)) {
10156             Value *Y;
10157             // X & Y u> C    -> X u> C && Y u> C
10158             // X | Y u< C    -> X u< C && Y u< C
10159             // X nuw+ Y u< C -> X u< C && Y u< C
10160             if (match(A, m_And(m_Value(X), m_Value(Y))) ||
10161                 match(A, m_Or(m_Value(X), m_Value(Y))) ||
10162                 match(A, m_NUWAdd(m_Value(X), m_Value(Y)))) {
10163               AddAffected(X);
10164               AddAffected(Y);
10165             }
10166             // X nuw- Y u> C -> X u> C
10167             if (match(A, m_NUWSub(m_Value(X), m_Value())))
10168               AddAffected(X);
10169           }
10170         }
10171 
10172         // Handle icmp slt/sgt (bitcast X to int), 0/-1, which is supported
10173         // by computeKnownFPClass().
10174         if (match(A, m_ElementWiseBitCast(m_Value(X)))) {
10175           if (Pred == ICmpInst::ICMP_SLT && match(B, m_Zero()))
10176             InsertAffected(X);
10177           else if (Pred == ICmpInst::ICMP_SGT && match(B, m_AllOnes()))
10178             InsertAffected(X);
10179         }
10180       }
10181 
10182       if (HasRHSC && match(A, m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))
10183         AddAffected(X);
10184     } else if (match(V, m_FCmp(Pred, m_Value(A), m_Value(B)))) {
10185       AddCmpOperands(A, B);
10186 
10187       // fcmp fneg(x), y
10188       // fcmp fabs(x), y
10189       // fcmp fneg(fabs(x)), y
10190       if (match(A, m_FNeg(m_Value(A))))
10191         AddAffected(A);
10192       if (match(A, m_FAbs(m_Value(A))))
10193         AddAffected(A);
10194 
10195     } else if (match(V, m_Intrinsic<Intrinsic::is_fpclass>(m_Value(A),
10196                                                            m_Value()))) {
10197       // Handle patterns that computeKnownFPClass() supports.
10198       AddAffected(A);
10199     } else if (!IsAssume && match(V, m_Trunc(m_Value(X)))) {
10200       // IsAssume is checked here because, for assumes, X was already added
10201       // above via addValueAffectedByCondition.
10202       AddAffected(X);
10203     } else if (!IsAssume && match(V, m_Not(m_Value(X)))) {
10204       // IsAssume is checked here to avoid issues with ephemeral values.
10205       Worklist.push_back(X);
10206     }
10207   }
10208 }
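// For example, for a branch condition "icmp ult (add i8 %x, 1), 10" (not an
// assume), both the add and %x are reported as affected: the add directly,
// and %x via the m_AddLike peek above.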
10209 
10210 const Value *llvm::stripNullTest(const Value *V) {
10211   // Match (X >> C) or/add zext((X & mask(C)) != 0); nonzero iff X != 0.
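  // E.g. (X lshr 3) | zext((X & 7) != 0), where popcount(7) == 3 covers
  // exactly the bits shifted out.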
10212   if (const auto *BO = dyn_cast<BinaryOperator>(V)) {
10213     if (BO->getOpcode() == Instruction::Add ||
10214         BO->getOpcode() == Instruction::Or) {
10215       const Value *X;
10216       const APInt *C1, *C2;
10217       if (match(BO, m_c_BinOp(m_LShr(m_Value(X), m_APInt(C1)),
10218                               m_ZExt(m_SpecificICmp(
10219                                   ICmpInst::ICMP_NE,
10220                                   m_And(m_Deferred(X), m_LowBitMask(C2)),
10221                                   m_Zero())))) &&
10222           C2->popcount() == C1->getZExtValue())
10223         return X;
10224     }
10225   }
10226   return nullptr;
10227 }
10228 
10229 Value *llvm::stripNullTest(Value *V) {
10230   return const_cast<Value *>(stripNullTest(const_cast<const Value *>(V)));
10231 }
10232