xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===- InstCombineSimplifyDemanded.cpp ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains logic for simplifying instructions based on information
10 // about how they are used.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "InstCombineInternal.h"
15 #include "llvm/Analysis/ValueTracking.h"
16 #include "llvm/IR/GetElementPtrTypeIterator.h"
17 #include "llvm/IR/IntrinsicInst.h"
18 #include "llvm/IR/PatternMatch.h"
19 #include "llvm/Support/KnownBits.h"
20 #include "llvm/Transforms/InstCombine/InstCombiner.h"
21 
22 using namespace llvm;
23 using namespace llvm::PatternMatch;
24 
25 #define DEBUG_TYPE "instcombine"
26 
27 static cl::opt<bool>
28     VerifyKnownBits("instcombine-verify-known-bits",
29                     cl::desc("Verify that computeKnownBits() and "
30                              "SimplifyDemandedBits() are consistent"),
31                     cl::Hidden, cl::init(false));
32 
33 /// Check to see if the specified operand of the specified instruction is a
34 /// constant integer. If so, check to see if there are any bits set in the
35 /// constant that are not demanded. If so, shrink the constant and return true.
36 static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
37                                    const APInt &Demanded) {
38   assert(I && "No instruction?");
39   assert(OpNo < I->getNumOperands() && "Operand index too large");
40 
41   // The operand must be a constant integer or splat integer.
42   Value *Op = I->getOperand(OpNo);
43   const APInt *C;
44   if (!match(Op, m_APInt(C)))
45     return false;
46 
47   // If there are no bits set that aren't demanded, nothing to do.
48   if (C->isSubsetOf(Demanded))
49     return false;
50 
51   // This instruction is producing bits that are not demanded. Shrink the RHS.
52   I->setOperand(OpNo, ConstantInt::get(Op->getType(), *C & Demanded));
53 
54   return true;
55 }
56 
57 /// Returns the bitwidth of the given scalar or pointer type. For vector types,
58 /// returns the element type's bitwidth.
59 static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
60   if (unsigned BitWidth = Ty->getScalarSizeInBits())
61     return BitWidth;
62 
63   return DL.getPointerTypeSizeInBits(Ty);
64 }
65 
66 /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
67 /// the instruction has any properties that allow us to simplify its operands.
68 bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst,
69                                                        KnownBits &Known) {
70   APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth()));
71   Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known,
72                                      0, SQ.getWithInstruction(&Inst));
73   if (!V) return false;
74   if (V == &Inst) return true;
75   replaceInstUsesWith(Inst, V);
76   return true;
77 }
78 
79 /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
80 /// the instruction has any properties that allow us to simplify its operands.
81 bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
82   KnownBits Known(getBitWidth(Inst.getType(), DL));
83   return SimplifyDemandedInstructionBits(Inst, Known);
84 }
85 
86 /// This form of SimplifyDemandedBits simplifies the specified instruction
87 /// operand if possible, updating it in place. It returns true if it made any
88 /// change and false otherwise.
89 bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
90                                             const APInt &DemandedMask,
91                                             KnownBits &Known, unsigned Depth,
92                                             const SimplifyQuery &Q) {
93   Use &U = I->getOperandUse(OpNo);
94   Value *V = U.get();
95   if (isa<Constant>(V)) {
96     llvm::computeKnownBits(V, Known, Depth, Q);
97     return false;
98   }
99 
100   Known.resetAll();
101   if (DemandedMask.isZero()) {
102     // Not demanding any bits from V.
103     replaceUse(U, UndefValue::get(V->getType()));
104     return true;
105   }
106 
107   if (Depth == MaxAnalysisRecursionDepth)
108     return false;
109 
110   Instruction *VInst = dyn_cast<Instruction>(V);
111   if (!VInst) {
112     llvm::computeKnownBits(V, Known, Depth, Q);
113     return false;
114   }
115 
116   Value *NewVal;
117   if (VInst->hasOneUse()) {
118     // If the instruction has one use, we can directly simplify it.
119     NewVal = SimplifyDemandedUseBits(VInst, DemandedMask, Known, Depth, Q);
120   } else {
121     // If there are multiple uses of this instruction, then we can simplify
122     // VInst to some other value, but not modify the instruction.
123     NewVal =
124         SimplifyMultipleUseDemandedBits(VInst, DemandedMask, Known, Depth, Q);
125   }
126   if (!NewVal) return false;
127   if (Instruction* OpInst = dyn_cast<Instruction>(U))
128     salvageDebugInfo(*OpInst);
129 
130   replaceUse(U, NewVal);
131   return true;
132 }
133 
/// This function attempts to replace V (the instruction I) with a simpler value
/// based on the demanded bits. When this function is called, it is known that
/// only the bits set in DemandedMask of the result of V are ever used
/// downstream. Consequently, depending on the mask and V, it may be possible to
/// replace V with a constant, one of its operands, or a newly built simpler
/// instruction; in such cases this function returns that replacement value.
/// Otherwise, it analyzes the expression and records in Known.One the bits
/// known to be one and in Known.Zero the bits known to be zero in the
/// expression. These are provided to potentially allow the caller (which might
/// recursively be SimplifyDemandedBits itself) to simplify the expression.
/// Known.One and Known.Zero always follow the invariant that:
///   Known.One & Known.Zero == 0.
/// That is, a bit can't be both 1 and 0. The bits in Known.One and Known.Zero
/// are accurate even for bits not in DemandedMask. Note also that the bitwidth
/// of V, DemandedMask, Known.Zero and Known.One must all be the same.
///
/// This returns null if it did not change anything and it permits no
/// simplification. This returns V itself if it did some simplification of V's
/// operands based on the information about what bits are demanded. This
/// returns some other non-null value if it found out that V is equal to
/// another value in the context where the specified bits are demanded, but not
/// for all users.
157 Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I,
158                                                  const APInt &DemandedMask,
159                                                  KnownBits &Known,
160                                                  unsigned Depth,
161                                                  const SimplifyQuery &Q) {
162   assert(I != nullptr && "Null pointer of Value???");
163   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
164   uint32_t BitWidth = DemandedMask.getBitWidth();
165   Type *VTy = I->getType();
166   assert(
167       (!VTy->isIntOrIntVectorTy() || VTy->getScalarSizeInBits() == BitWidth) &&
168       Known.getBitWidth() == BitWidth &&
169       "Value *V, DemandedMask and Known must have same BitWidth");
170 
171   KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth);
172 
173   // Update flags after simplifying an operand based on the fact that some high
174   // order bits are not demanded.
175   auto disableWrapFlagsBasedOnUnusedHighBits = [](Instruction *I,
176                                                   unsigned NLZ) {
177     if (NLZ > 0) {
178       // Disable the nsw and nuw flags here: We can no longer guarantee that
179       // we won't wrap after simplification. Removing the nsw/nuw flags is
180       // legal here because the top bit is not demanded.
181       I->setHasNoSignedWrap(false);
182       I->setHasNoUnsignedWrap(false);
183     }
184     return I;
185   };
186 
187   // If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care
188   // about the high bits of the operands.
189   auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) {
190     unsigned NLZ = DemandedMask.countl_zero();
191     // Right fill the mask of bits for the operands to demand the most
192     // significant bit and all those below it.
193     DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
194     if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
195         SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1, Q) ||
196         ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
197         SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q)) {
198       disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
199       return true;
200     }
201     return false;
202   };
203 
204   switch (I->getOpcode()) {
205   default:
206     llvm::computeKnownBits(I, Known, Depth, Q);
207     break;
208   case Instruction::And: {
209     // If either the LHS or the RHS are Zero, the result is zero.
210     if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) ||
211         SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.Zero, LHSKnown,
212                              Depth + 1, Q))
213       return I;
214 
215     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
216                                          Depth, Q);
217 
218     // If the client is only demanding bits that we know, return the known
219     // constant.
220     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
221       return Constant::getIntegerValue(VTy, Known.One);
222 
223     // If all of the demanded bits are known 1 on one side, return the other.
224     // These bits cannot contribute to the result of the 'and'.
225     if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
226       return I->getOperand(0);
227     if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
228       return I->getOperand(1);
229 
230     // If the RHS is a constant, see if we can simplify it.
231     if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnown.Zero))
232       return I;
233 
234     break;
235   }
236   case Instruction::Or: {
237     // If either the LHS or the RHS are One, the result is One.
238     if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) ||
239         SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown,
240                              Depth + 1, Q)) {
241       // Disjoint flag may not longer hold.
242       I->dropPoisonGeneratingFlags();
243       return I;
244     }
245 
246     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
247                                          Depth, Q);
248 
249     // If the client is only demanding bits that we know, return the known
250     // constant.
251     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
252       return Constant::getIntegerValue(VTy, Known.One);
253 
254     // If all of the demanded bits are known zero on one side, return the other.
255     // These bits cannot contribute to the result of the 'or'.
256     if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
257       return I->getOperand(0);
258     if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
259       return I->getOperand(1);
260 
261     // If the RHS is a constant, see if we can simplify it.
262     if (ShrinkDemandedConstant(I, 1, DemandedMask))
263       return I;
264 
265     // Infer disjoint flag if no common bits are set.
266     if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) {
267       WithCache<const Value *> LHSCache(I->getOperand(0), LHSKnown),
268           RHSCache(I->getOperand(1), RHSKnown);
269       if (haveNoCommonBitsSet(LHSCache, RHSCache, Q)) {
270         cast<PossiblyDisjointInst>(I)->setIsDisjoint(true);
271         return I;
272       }
273     }
274 
275     break;
276   }
277   case Instruction::Xor: {
278     if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) ||
279         SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1, Q))
280       return I;
281     Value *LHS, *RHS;
282     if (DemandedMask == 1 &&
283         match(I->getOperand(0), m_Intrinsic<Intrinsic::ctpop>(m_Value(LHS))) &&
284         match(I->getOperand(1), m_Intrinsic<Intrinsic::ctpop>(m_Value(RHS)))) {
285       // (ctpop(X) ^ ctpop(Y)) & 1 --> ctpop(X^Y) & 1
286       IRBuilderBase::InsertPointGuard Guard(Builder);
287       Builder.SetInsertPoint(I);
288       auto *Xor = Builder.CreateXor(LHS, RHS);
289       return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, Xor);
290     }
291 
292     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
293                                          Depth, Q);
294 
295     // If the client is only demanding bits that we know, return the known
296     // constant.
297     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
298       return Constant::getIntegerValue(VTy, Known.One);
299 
300     // If all of the demanded bits are known zero on one side, return the other.
301     // These bits cannot contribute to the result of the 'xor'.
302     if (DemandedMask.isSubsetOf(RHSKnown.Zero))
303       return I->getOperand(0);
304     if (DemandedMask.isSubsetOf(LHSKnown.Zero))
305       return I->getOperand(1);
306 
307     // If all of the demanded bits are known to be zero on one side or the
308     // other, turn this into an *inclusive* or.
309     //    e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
310     if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) {
311       Instruction *Or =
312           BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1));
313       if (DemandedMask.isAllOnes())
314         cast<PossiblyDisjointInst>(Or)->setIsDisjoint(true);
315       Or->takeName(I);
316       return InsertNewInstWith(Or, I->getIterator());
317     }
318 
319     // If all of the demanded bits on one side are known, and all of the set
320     // bits on that side are also known to be set on the other side, turn this
321     // into an AND, as we know the bits will be cleared.
322     //    e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
323     if (DemandedMask.isSubsetOf(RHSKnown.Zero|RHSKnown.One) &&
324         RHSKnown.One.isSubsetOf(LHSKnown.One)) {
325       Constant *AndC = Constant::getIntegerValue(VTy,
326                                                  ~RHSKnown.One & DemandedMask);
327       Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
328       return InsertNewInstWith(And, I->getIterator());
329     }
330 
331     // If the RHS is a constant, see if we can change it. Don't alter a -1
332     // constant because that's a canonical 'not' op, and that is better for
333     // combining, SCEV, and codegen.
334     const APInt *C;
335     if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) {
336       if ((*C | ~DemandedMask).isAllOnes()) {
337         // Force bits to 1 to create a 'not' op.
338         I->setOperand(1, ConstantInt::getAllOnesValue(VTy));
339         return I;
340       }
341       // If we can't turn this into a 'not', try to shrink the constant.
342       if (ShrinkDemandedConstant(I, 1, DemandedMask))
343         return I;
344     }
345 
346     // If our LHS is an 'and' and if it has one use, and if any of the bits we
347     // are flipping are known to be set, then the xor is just resetting those
348     // bits to zero.  We can just knock out bits from the 'and' and the 'xor',
349     // simplifying both of them.
350     if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0))) {
351       ConstantInt *AndRHS, *XorRHS;
352       if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() &&
353           match(I->getOperand(1), m_ConstantInt(XorRHS)) &&
354           match(LHSInst->getOperand(1), m_ConstantInt(AndRHS)) &&
355           (LHSKnown.One & RHSKnown.One & DemandedMask) != 0) {
356         APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask);
357 
358         Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue());
359         Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
360         InsertNewInstWith(NewAnd, I->getIterator());
361 
362         Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue());
363         Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC);
364         return InsertNewInstWith(NewXor, I->getIterator());
365       }
366     }
367     break;
368   }
369   case Instruction::Select: {
370     if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1, Q) ||
371         SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1, Q))
372       return I;
373 
374     // If the operands are constants, see if we can simplify them.
375     // This is similar to ShrinkDemandedConstant, but for a select we want to
376     // try to keep the selected constants the same as icmp value constants, if
377     // we can. This helps not break apart (or helps put back together)
378     // canonical patterns like min and max.
379     auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo,
380                                          const APInt &DemandedMask) {
381       const APInt *SelC;
382       if (!match(I->getOperand(OpNo), m_APInt(SelC)))
383         return false;
384 
385       // Get the constant out of the ICmp, if there is one.
386       // Only try this when exactly 1 operand is a constant (if both operands
387       // are constant, the icmp should eventually simplify). Otherwise, we may
388       // invert the transform that reduces set bits and infinite-loop.
389       Value *X;
390       const APInt *CmpC;
391       ICmpInst::Predicate Pred;
392       if (!match(I->getOperand(0), m_ICmp(Pred, m_Value(X), m_APInt(CmpC))) ||
393           isa<Constant>(X) || CmpC->getBitWidth() != SelC->getBitWidth())
394         return ShrinkDemandedConstant(I, OpNo, DemandedMask);
395 
396       // If the constant is already the same as the ICmp, leave it as-is.
397       if (*CmpC == *SelC)
398         return false;
399       // If the constants are not already the same, but can be with the demand
400       // mask, use the constant value from the ICmp.
401       if ((*CmpC & DemandedMask) == (*SelC & DemandedMask)) {
402         I->setOperand(OpNo, ConstantInt::get(I->getType(), *CmpC));
403         return true;
404       }
405       return ShrinkDemandedConstant(I, OpNo, DemandedMask);
406     };
407     if (CanonicalizeSelectConstant(I, 1, DemandedMask) ||
408         CanonicalizeSelectConstant(I, 2, DemandedMask))
409       return I;
410 
411     // Only known if known in both the LHS and RHS.
412     adjustKnownBitsForSelectArm(LHSKnown, I->getOperand(0), I->getOperand(1),
413                                 /*Invert=*/false, Depth, Q);
414     adjustKnownBitsForSelectArm(RHSKnown, I->getOperand(0), I->getOperand(2),
415                                 /*Invert=*/true, Depth, Q);
416     Known = LHSKnown.intersectWith(RHSKnown);
417     break;
418   }
419   case Instruction::Trunc: {
420     // If we do not demand the high bits of a right-shifted and truncated value,
421     // then we may be able to truncate it before the shift.
422     Value *X;
423     const APInt *C;
424     if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) {
425       // The shift amount must be valid (not poison) in the narrow type, and
426       // it must not be greater than the high bits demanded of the result.
427       if (C->ult(VTy->getScalarSizeInBits()) &&
428           C->ule(DemandedMask.countl_zero())) {
429         // trunc (lshr X, C) --> lshr (trunc X), C
430         IRBuilderBase::InsertPointGuard Guard(Builder);
431         Builder.SetInsertPoint(I);
432         Value *Trunc = Builder.CreateTrunc(X, VTy);
433         return Builder.CreateLShr(Trunc, C->getZExtValue());
434       }
435     }
436   }
437     [[fallthrough]];
438   case Instruction::ZExt: {
439     unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
440 
441     APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth);
442     KnownBits InputKnown(SrcBitWidth);
443     if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1,
444                              Q)) {
445       // For zext nneg, we may have dropped the instruction which made the
446       // input non-negative.
447       I->dropPoisonGeneratingFlags();
448       return I;
449     }
450     assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?");
451     if (I->getOpcode() == Instruction::ZExt && I->hasNonNeg() &&
452         !InputKnown.isNegative())
453       InputKnown.makeNonNegative();
454     Known = InputKnown.zextOrTrunc(BitWidth);
455 
456     break;
457   }
458   case Instruction::SExt: {
459     // Compute the bits in the result that are not present in the input.
460     unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
461 
462     APInt InputDemandedBits = DemandedMask.trunc(SrcBitWidth);
463 
464     // If any of the sign extended bits are demanded, we know that the sign
465     // bit is demanded.
466     if (DemandedMask.getActiveBits() > SrcBitWidth)
467       InputDemandedBits.setBit(SrcBitWidth-1);
468 
469     KnownBits InputKnown(SrcBitWidth);
470     if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Depth + 1, Q))
471       return I;
472 
473     // If the input sign bit is known zero, or if the NewBits are not demanded
474     // convert this into a zero extension.
475     if (InputKnown.isNonNegative() ||
476         DemandedMask.getActiveBits() <= SrcBitWidth) {
477       // Convert to ZExt cast.
478       CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy);
479       NewCast->takeName(I);
480       return InsertNewInstWith(NewCast, I->getIterator());
481     }
482 
483     // If the sign bit of the input is known set or clear, then we know the
484     // top bits of the result.
485     Known = InputKnown.sext(BitWidth);
486     break;
487   }
488   case Instruction::Add: {
489     if ((DemandedMask & 1) == 0) {
490       // If we do not need the low bit, try to convert bool math to logic:
491       // add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN
492       Value *X, *Y;
493       if (match(I, m_c_Add(m_OneUse(m_ZExt(m_Value(X))),
494                            m_OneUse(m_SExt(m_Value(Y))))) &&
495           X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) {
496         // Truth table for inputs and output signbits:
497         //       X:0 | X:1
498         //      ----------
499         // Y:0  |  0 | 0 |
500         // Y:1  | -1 | 0 |
501         //      ----------
502         IRBuilderBase::InsertPointGuard Guard(Builder);
503         Builder.SetInsertPoint(I);
504         Value *AndNot = Builder.CreateAnd(Builder.CreateNot(X), Y);
505         return Builder.CreateSExt(AndNot, VTy);
506       }
507 
508       // add iN (sext i1 X), (sext i1 Y) --> sext (X | Y) to iN
509       if (match(I, m_Add(m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
510           X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() &&
511           (I->getOperand(0)->hasOneUse() || I->getOperand(1)->hasOneUse())) {
512 
513         // Truth table for inputs and output signbits:
514         //       X:0 | X:1
515         //      -----------
516         // Y:0  | -1 | -1 |
517         // Y:1  | -1 |  0 |
518         //      -----------
519         IRBuilderBase::InsertPointGuard Guard(Builder);
520         Builder.SetInsertPoint(I);
521         Value *Or = Builder.CreateOr(X, Y);
522         return Builder.CreateSExt(Or, VTy);
523       }
524     }
525 
526     // Right fill the mask of bits for the operands to demand the most
527     // significant bit and all those below it.
528     unsigned NLZ = DemandedMask.countl_zero();
529     APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
530     if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
531         SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q))
532       return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
533 
534     // If low order bits are not demanded and known to be zero in one operand,
535     // then we don't need to demand them from the other operand, since they
536     // can't cause overflow into any bits that are demanded in the result.
537     unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one();
538     APInt DemandedFromLHS = DemandedFromOps;
539     DemandedFromLHS.clearLowBits(NTZ);
540     if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
541         SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1, Q))
542       return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
543 
544     // If we are known to be adding zeros to every bit below
545     // the highest demanded bit, we just return the other side.
546     if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
547       return I->getOperand(0);
548     if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
549       return I->getOperand(1);
550 
551     // (add X, C) --> (xor X, C) IFF C is equal to the top bit of the DemandMask
552     {
553       const APInt *C;
554       if (match(I->getOperand(1), m_APInt(C)) &&
555           C->isOneBitSet(DemandedMask.getActiveBits() - 1)) {
556         IRBuilderBase::InsertPointGuard Guard(Builder);
557         Builder.SetInsertPoint(I);
558         return Builder.CreateXor(I->getOperand(0), ConstantInt::get(VTy, *C));
559       }
560     }
561 
562     // Otherwise just compute the known bits of the result.
563     bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
564     bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
565     Known = KnownBits::computeForAddSub(true, NSW, NUW, LHSKnown, RHSKnown);
566     break;
567   }
568   case Instruction::Sub: {
569     // Right fill the mask of bits for the operands to demand the most
570     // significant bit and all those below it.
571     unsigned NLZ = DemandedMask.countl_zero();
572     APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
573     if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
574         SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q))
575       return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
576 
577     // If low order bits are not demanded and are known to be zero in RHS,
578     // then we don't need to demand them from LHS, since they can't cause a
579     // borrow from any bits that are demanded in the result.
580     unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one();
581     APInt DemandedFromLHS = DemandedFromOps;
582     DemandedFromLHS.clearLowBits(NTZ);
583     if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
584         SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1, Q))
585       return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
586 
587     // If we are known to be subtracting zeros from every bit below
588     // the highest demanded bit, we just return the other side.
589     if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
590       return I->getOperand(0);
591     // We can't do this with the LHS for subtraction, unless we are only
592     // demanding the LSB.
593     if (DemandedFromOps.isOne() && DemandedFromOps.isSubsetOf(LHSKnown.Zero))
594       return I->getOperand(1);
595 
596     // Otherwise just compute the known bits of the result.
597     bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
598     bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
599     Known = KnownBits::computeForAddSub(false, NSW, NUW, LHSKnown, RHSKnown);
600     break;
601   }
602   case Instruction::Mul: {
603     APInt DemandedFromOps;
604     if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps))
605       return I;
606 
607     if (DemandedMask.isPowerOf2()) {
608       // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
609       // If we demand exactly one bit N and we have "X * (C' << N)" where C' is
610       // odd (has LSB set), then the left-shifted low bit of X is the answer.
611       unsigned CTZ = DemandedMask.countr_zero();
612       const APInt *C;
613       if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) {
614         Constant *ShiftC = ConstantInt::get(VTy, CTZ);
615         Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC);
616         return InsertNewInstWith(Shl, I->getIterator());
617       }
618     }
619     // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because:
620     // X * X is odd iff X is odd.
621     // 'Quadratic Reciprocity': X * X -> 0 for bit[1]
622     if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) {
623       Constant *One = ConstantInt::get(VTy, 1);
624       Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One);
625       return InsertNewInstWith(And1, I->getIterator());
626     }
627 
628     llvm::computeKnownBits(I, Known, Depth, Q);
629     break;
630   }
631   case Instruction::Shl: {
632     const APInt *SA;
633     if (match(I->getOperand(1), m_APInt(SA))) {
634       const APInt *ShrAmt;
635       if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt))))
636         if (Instruction *Shr = dyn_cast<Instruction>(I->getOperand(0)))
637           if (Value *R = simplifyShrShlDemandedBits(Shr, *ShrAmt, I, *SA,
638                                                     DemandedMask, Known))
639             return R;
640 
641       // Do not simplify if shl is part of funnel-shift pattern
642       if (I->hasOneUse()) {
643         auto *Inst = dyn_cast<Instruction>(I->user_back());
644         if (Inst && Inst->getOpcode() == BinaryOperator::Or) {
645           if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) {
646             auto [IID, FShiftArgs] = *Opt;
647             if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
648                 FShiftArgs[0] == FShiftArgs[1]) {
649               llvm::computeKnownBits(I, Known, Depth, Q);
650               break;
651             }
652           }
653         }
654       }
655 
656       // We only want bits that already match the signbit then we don't
657       // need to shift.
658       uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1);
659       if (DemandedMask.countr_zero() >= ShiftAmt) {
660         if (I->hasNoSignedWrap()) {
661           unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
662           unsigned SignBits =
663               ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI);
664           if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumHiDemandedBits)
665             return I->getOperand(0);
666         }
667 
668         // If we can pre-shift a right-shifted constant to the left without
669         // losing any high bits and we don't demand the low bits, then eliminate
670         // the left-shift:
671         // (C >> X) << LeftShiftAmtC --> (C << LeftShiftAmtC) >> X
672         Value *X;
673         Constant *C;
674         if (match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
675           Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
676           Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
677                                                         LeftShiftAmtC, DL);
678           if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC,
679                                            LeftShiftAmtC, DL) == C) {
680             Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
681             return InsertNewInstWith(Lshr, I->getIterator());
682           }
683         }
684       }
685 
686       APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt));
687 
688       // If the shift is NUW/NSW, then it does demand the high bits.
689       ShlOperator *IOp = cast<ShlOperator>(I);
690       if (IOp->hasNoSignedWrap())
691         DemandedMaskIn.setHighBits(ShiftAmt+1);
692       else if (IOp->hasNoUnsignedWrap())
693         DemandedMaskIn.setHighBits(ShiftAmt);
694 
695       if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q))
696         return I;
697 
698       Known = KnownBits::shl(Known,
699                              KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)),
700                              /* NUW */ IOp->hasNoUnsignedWrap(),
701                              /* NSW */ IOp->hasNoSignedWrap());
702     } else {
703       // This is a variable shift, so we can't shift the demand mask by a known
704       // amount. But if we are not demanding high bits, then we are not
705       // demanding those bits from the pre-shifted operand either.
706       if (unsigned CTLZ = DemandedMask.countl_zero()) {
707         APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
708         if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1, Q)) {
709           // We can't guarantee that nsw/nuw hold after simplifying the operand.
710           I->dropPoisonGeneratingFlags();
711           return I;
712         }
713       }
714       llvm::computeKnownBits(I, Known, Depth, Q);
715     }
716     break;
717   }
718   case Instruction::LShr: {
719     const APInt *SA;
720     if (match(I->getOperand(1), m_APInt(SA))) {
721       uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
722 
723       // Do not simplify if lshr is part of funnel-shift pattern
724       if (I->hasOneUse()) {
725         auto *Inst = dyn_cast<Instruction>(I->user_back());
726         if (Inst && Inst->getOpcode() == BinaryOperator::Or) {
727           if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) {
728             auto [IID, FShiftArgs] = *Opt;
729             if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
730                 FShiftArgs[0] == FShiftArgs[1]) {
731               llvm::computeKnownBits(I, Known, Depth, Q);
732               break;
733             }
734           }
735         }
736       }
737 
738       // If we are just demanding the shifted sign bit and below, then this can
739       // be treated as an ASHR in disguise.
740       if (DemandedMask.countl_zero() >= ShiftAmt) {
741         // If we only want bits that already match the signbit then we don't
742         // need to shift.
743         unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
744         unsigned SignBits =
745             ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI);
746         if (SignBits >= NumHiDemandedBits)
747           return I->getOperand(0);
748 
749         // If we can pre-shift a left-shifted constant to the right without
750         // losing any low bits (we already know we don't demand the high bits),
751         // then eliminate the right-shift:
752         // (C << X) >> RightShiftAmtC --> (C >> RightShiftAmtC) << X
753         Value *X;
754         Constant *C;
755         if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) {
756           Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
757           Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::LShr, C,
758                                                         RightShiftAmtC, DL);
759           if (ConstantFoldBinaryOpOperands(Instruction::Shl, NewC,
760                                            RightShiftAmtC, DL) == C) {
761             Instruction *Shl = BinaryOperator::CreateShl(NewC, X);
762             return InsertNewInstWith(Shl, I->getIterator());
763           }
764         }
765       }
766 
767       // Unsigned shift right.
768       APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
769       if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) {
770         // exact flag may not longer hold.
771         I->dropPoisonGeneratingFlags();
772         return I;
773       }
774       Known.Zero.lshrInPlace(ShiftAmt);
775       Known.One.lshrInPlace(ShiftAmt);
776       if (ShiftAmt)
777         Known.Zero.setHighBits(ShiftAmt);  // high bits known zero.
778     } else {
779       llvm::computeKnownBits(I, Known, Depth, Q);
780     }
781     break;
782   }
783   case Instruction::AShr: {
784     unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI);
785 
786     // If we only want bits that already match the signbit then we don't need
787     // to shift.
788     unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
789     if (SignBits >= NumHiDemandedBits)
790       return I->getOperand(0);
791 
792     // If this is an arithmetic shift right and only the low-bit is set, we can
793     // always convert this into a logical shr, even if the shift amount is
794     // variable.  The low bit of the shift cannot be an input sign bit unless
795     // the shift amount is >= the size of the datatype, which is undefined.
796     if (DemandedMask.isOne()) {
797       // Perform the logical shift right.
798       Instruction *NewVal = BinaryOperator::CreateLShr(
799                         I->getOperand(0), I->getOperand(1), I->getName());
800       return InsertNewInstWith(NewVal, I->getIterator());
801     }
802 
803     const APInt *SA;
804     if (match(I->getOperand(1), m_APInt(SA))) {
805       uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
806 
807       // Signed shift right.
808       APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
809       // If any of the bits being shifted in are demanded, then we should set
810       // the sign bit as demanded.
811       bool ShiftedInBitsDemanded = DemandedMask.countl_zero() < ShiftAmt;
812       if (ShiftedInBitsDemanded)
813         DemandedMaskIn.setSignBit();
814       if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) {
815         // exact flag may not longer hold.
816         I->dropPoisonGeneratingFlags();
817         return I;
818       }
819 
820       // If the input sign bit is known to be zero, or if none of the shifted in
821       // bits are demanded, turn this into an unsigned shift right.
822       if (Known.Zero[BitWidth - 1] || !ShiftedInBitsDemanded) {
823         BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0),
824                                                           I->getOperand(1));
825         LShr->setIsExact(cast<BinaryOperator>(I)->isExact());
826         LShr->takeName(I);
827         return InsertNewInstWith(LShr, I->getIterator());
828       }
829 
830       Known = KnownBits::ashr(
831           Known, KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)),
832           ShiftAmt != 0, I->isExact());
833     } else {
834       llvm::computeKnownBits(I, Known, Depth, Q);
835     }
836     break;
837   }
838   case Instruction::UDiv: {
839     // UDiv doesn't demand low bits that are zero in the divisor.
840     const APInt *SA;
841     if (match(I->getOperand(1), m_APInt(SA))) {
842       // TODO: Take the demanded mask of the result into account.
843       unsigned RHSTrailingZeros = SA->countr_zero();
844       APInt DemandedMaskIn =
845           APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros);
846       if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1, Q)) {
847         // We can't guarantee that "exact" is still true after changing the
848         // the dividend.
849         I->dropPoisonGeneratingFlags();
850         return I;
851       }
852 
853       Known = KnownBits::udiv(LHSKnown, KnownBits::makeConstant(*SA),
854                               cast<BinaryOperator>(I)->isExact());
855     } else {
856       llvm::computeKnownBits(I, Known, Depth, Q);
857     }
858     break;
859   }
860   case Instruction::SRem: {
861     const APInt *Rem;
862     if (match(I->getOperand(1), m_APInt(Rem))) {
863       // X % -1 demands all the bits because we don't want to introduce
864       // INT_MIN % -1 (== undef) by accident.
865       if (Rem->isAllOnes())
866         break;
867       APInt RA = Rem->abs();
868       if (RA.isPowerOf2()) {
869         if (DemandedMask.ult(RA))    // srem won't affect demanded bits
870           return I->getOperand(0);
871 
872         APInt LowBits = RA - 1;
873         APInt Mask2 = LowBits | APInt::getSignMask(BitWidth);
874         if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Depth + 1, Q))
875           return I;
876 
877         // The low bits of LHS are unchanged by the srem.
878         Known.Zero = LHSKnown.Zero & LowBits;
879         Known.One = LHSKnown.One & LowBits;
880 
881         // If LHS is non-negative or has all low bits zero, then the upper bits
882         // are all zero.
883         if (LHSKnown.isNonNegative() || LowBits.isSubsetOf(LHSKnown.Zero))
884           Known.Zero |= ~LowBits;
885 
886         // If LHS is negative and not all low bits are zero, then the upper bits
887         // are all one.
888         if (LHSKnown.isNegative() && LowBits.intersects(LHSKnown.One))
889           Known.One |= ~LowBits;
890 
891         break;
892       }
893     }
894 
895     llvm::computeKnownBits(I, Known, Depth, Q);
896     break;
897   }
898   case Instruction::Call: {
899     bool KnownBitsComputed = false;
900     if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
901       switch (II->getIntrinsicID()) {
902       case Intrinsic::abs: {
903         if (DemandedMask == 1)
904           return II->getArgOperand(0);
905         break;
906       }
907       case Intrinsic::ctpop: {
908         // Checking if the number of clear bits is odd (parity)? If the type has
909         // an even number of bits, that's the same as checking if the number of
910         // set bits is odd, so we can eliminate the 'not' op.
911         Value *X;
912         if (DemandedMask == 1 && VTy->getScalarSizeInBits() % 2 == 0 &&
913             match(II->getArgOperand(0), m_Not(m_Value(X)))) {
914           Function *Ctpop = Intrinsic::getDeclaration(
915               II->getModule(), Intrinsic::ctpop, VTy);
916           return InsertNewInstWith(CallInst::Create(Ctpop, {X}), I->getIterator());
917         }
918         break;
919       }
920       case Intrinsic::bswap: {
921         // If the only bits demanded come from one byte of the bswap result,
922         // just shift the input byte into position to eliminate the bswap.
923         unsigned NLZ = DemandedMask.countl_zero();
924         unsigned NTZ = DemandedMask.countr_zero();
925 
926         // Round NTZ down to the next byte.  If we have 11 trailing zeros, then
927         // we need all the bits down to bit 8.  Likewise, round NLZ.  If we
928         // have 14 leading zeros, round to 8.
929         NLZ = alignDown(NLZ, 8);
930         NTZ = alignDown(NTZ, 8);
931         // If we need exactly one byte, we can do this transformation.
932         if (BitWidth - NLZ - NTZ == 8) {
933           // Replace this with either a left or right shift to get the byte into
934           // the right place.
935           Instruction *NewVal;
936           if (NLZ > NTZ)
937             NewVal = BinaryOperator::CreateLShr(
938                 II->getArgOperand(0), ConstantInt::get(VTy, NLZ - NTZ));
939           else
940             NewVal = BinaryOperator::CreateShl(
941                 II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ));
942           NewVal->takeName(I);
943           return InsertNewInstWith(NewVal, I->getIterator());
944         }
945         break;
946       }
947       case Intrinsic::ptrmask: {
948         unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits();
949         RHSKnown = KnownBits(MaskWidth);
950         // If either the LHS or the RHS are Zero, the result is zero.
951         if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1, Q) ||
952             SimplifyDemandedBits(
953                 I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth),
954                 RHSKnown, Depth + 1, Q))
955           return I;
956 
957         // TODO: Should be 1-extend
958         RHSKnown = RHSKnown.anyextOrTrunc(BitWidth);
959 
960         Known = LHSKnown & RHSKnown;
961         KnownBitsComputed = true;
962 
963         // If the client is only demanding bits we know to be zero, return
964         // `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer
965         // provenance, but making the mask zero will be easily optimizable in
966         // the backend.
967         if (DemandedMask.isSubsetOf(Known.Zero) &&
968             !match(I->getOperand(1), m_Zero()))
969           return replaceOperand(
970               *I, 1, Constant::getNullValue(I->getOperand(1)->getType()));
971 
972         // Mask in demanded space does nothing.
973         // NOTE: We may have attributes associated with the return value of the
974         // llvm.ptrmask intrinsic that will be lost when we just return the
975         // operand. We should try to preserve them.
976         if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
977           return I->getOperand(0);
978 
979         // If the RHS is a constant, see if we can simplify it.
980         if (ShrinkDemandedConstant(
981                 I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth)))
982           return I;
983 
984         // Combine:
985         // (ptrmask (getelementptr i8, ptr p, imm i), imm mask)
986         //   -> (ptrmask (getelementptr i8, ptr p, imm (i & mask)), imm mask)
987         // where only the low bits known to be zero in the pointer are changed
988         Value *InnerPtr;
989         uint64_t GEPIndex;
990         uint64_t PtrMaskImmediate;
991         if (match(I, m_Intrinsic<Intrinsic::ptrmask>(
992                          m_PtrAdd(m_Value(InnerPtr), m_ConstantInt(GEPIndex)),
993                          m_ConstantInt(PtrMaskImmediate)))) {
994 
995           LHSKnown = computeKnownBits(InnerPtr, Depth + 1, I);
996           if (!LHSKnown.isZero()) {
997             const unsigned trailingZeros = LHSKnown.countMinTrailingZeros();
998             uint64_t PointerAlignBits = (uint64_t(1) << trailingZeros) - 1;
999 
1000             uint64_t HighBitsGEPIndex = GEPIndex & ~PointerAlignBits;
1001             uint64_t MaskedLowBitsGEPIndex =
1002                 GEPIndex & PointerAlignBits & PtrMaskImmediate;
1003 
1004             uint64_t MaskedGEPIndex = HighBitsGEPIndex | MaskedLowBitsGEPIndex;
1005 
1006             if (MaskedGEPIndex != GEPIndex) {
1007               auto *GEP = cast<GetElementPtrInst>(II->getArgOperand(0));
1008               Builder.SetInsertPoint(I);
1009               Type *GEPIndexType =
1010                   DL.getIndexType(GEP->getPointerOperand()->getType());
1011               Value *MaskedGEP = Builder.CreateGEP(
1012                   GEP->getSourceElementType(), InnerPtr,
1013                   ConstantInt::get(GEPIndexType, MaskedGEPIndex),
1014                   GEP->getName(), GEP->isInBounds());
1015 
1016               replaceOperand(*I, 0, MaskedGEP);
1017               return I;
1018             }
1019           }
1020         }
1021 
1022         break;
1023       }
1024 
1025       case Intrinsic::fshr:
1026       case Intrinsic::fshl: {
1027         const APInt *SA;
1028         if (!match(I->getOperand(2), m_APInt(SA)))
1029           break;
1030 
1031         // Normalize to funnel shift left. APInt shifts of BitWidth are well-
1032         // defined, so no need to special-case zero shifts here.
1033         uint64_t ShiftAmt = SA->urem(BitWidth);
1034         if (II->getIntrinsicID() == Intrinsic::fshr)
1035           ShiftAmt = BitWidth - ShiftAmt;
1036 
1037         APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt));
1038         APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt));
1039         if (I->getOperand(0) != I->getOperand(1)) {
1040           if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown,
1041                                    Depth + 1, Q) ||
1042               SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1,
1043                                    Q))
1044             return I;
1045         } else { // fshl is a rotate
1046           // Avoid converting rotate into funnel shift.
1047           // Only simplify if one operand is constant.
1048           LHSKnown = computeKnownBits(I->getOperand(0), Depth + 1, I);
1049           if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) &&
1050               !match(I->getOperand(0), m_SpecificInt(LHSKnown.One))) {
1051             replaceOperand(*I, 0, Constant::getIntegerValue(VTy, LHSKnown.One));
1052             return I;
1053           }
1054 
1055           RHSKnown = computeKnownBits(I->getOperand(1), Depth + 1, I);
1056           if (DemandedMaskRHS.isSubsetOf(RHSKnown.Zero | RHSKnown.One) &&
1057               !match(I->getOperand(1), m_SpecificInt(RHSKnown.One))) {
1058             replaceOperand(*I, 1, Constant::getIntegerValue(VTy, RHSKnown.One));
1059             return I;
1060           }
1061         }
1062 
1063         Known.Zero = LHSKnown.Zero.shl(ShiftAmt) |
1064                      RHSKnown.Zero.lshr(BitWidth - ShiftAmt);
1065         Known.One = LHSKnown.One.shl(ShiftAmt) |
1066                     RHSKnown.One.lshr(BitWidth - ShiftAmt);
1067         KnownBitsComputed = true;
1068         break;
1069       }
1070       case Intrinsic::umax: {
1071         // UMax(A, C) == A if ...
1072         // The lowest non-zero bit of DemandMask is higher than the highest
1073         // non-zero bit of C.
1074         const APInt *C;
1075         unsigned CTZ = DemandedMask.countr_zero();
1076         if (match(II->getArgOperand(1), m_APInt(C)) &&
1077             CTZ >= C->getActiveBits())
1078           return II->getArgOperand(0);
1079         break;
1080       }
1081       case Intrinsic::umin: {
1082         // UMin(A, C) == A if ...
1083         // The lowest non-zero bit of DemandMask is higher than the highest
1084         // non-one bit of C.
1085         // This comes from using DeMorgans on the above umax example.
1086         const APInt *C;
1087         unsigned CTZ = DemandedMask.countr_zero();
1088         if (match(II->getArgOperand(1), m_APInt(C)) &&
1089             CTZ >= C->getBitWidth() - C->countl_one())
1090           return II->getArgOperand(0);
1091         break;
1092       }
1093       default: {
1094         // Handle target specific intrinsics
1095         std::optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic(
1096             *II, DemandedMask, Known, KnownBitsComputed);
1097         if (V)
1098           return *V;
1099         break;
1100       }
1101       }
1102     }
1103 
1104     if (!KnownBitsComputed)
1105       llvm::computeKnownBits(I, Known, Depth, Q);
1106     break;
1107   }
1108   }
1109 
1110   if (I->getType()->isPointerTy()) {
1111     Align Alignment = I->getPointerAlignment(DL);
1112     Known.Zero.setLowBits(Log2(Alignment));
1113   }
1114 
1115   // If the client is only demanding bits that we know, return the known
1116   // constant. We can't directly simplify pointers as a constant because of
1117   // pointer provenance.
1118   // TODO: We could return `(inttoptr const)` for pointers.
1119   if (!I->getType()->isPointerTy() &&
1120       DemandedMask.isSubsetOf(Known.Zero | Known.One))
1121     return Constant::getIntegerValue(VTy, Known.One);
1122 
1123   if (VerifyKnownBits) {
1124     KnownBits ReferenceKnown = llvm::computeKnownBits(I, Depth, Q);
1125     if (Known != ReferenceKnown) {
1126       errs() << "Mismatched known bits for " << *I << " in "
1127              << I->getFunction()->getName() << "\n";
1128       errs() << "computeKnownBits(): " << ReferenceKnown << "\n";
1129       errs() << "SimplifyDemandedBits(): " << Known << "\n";
1130       std::abort();
1131     }
1132   }
1133 
1134   return nullptr;
1135 }
1136 
1137 /// Helper routine of SimplifyDemandedUseBits. It computes Known
1138 /// bits. It also tries to handle simplifications that can be done based on
1139 /// DemandedMask, but without modifying the Instruction.
1140 Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
1141     Instruction *I, const APInt &DemandedMask, KnownBits &Known, unsigned Depth,
1142     const SimplifyQuery &Q) {
1143   unsigned BitWidth = DemandedMask.getBitWidth();
1144   Type *ITy = I->getType();
1145 
1146   KnownBits LHSKnown(BitWidth);
1147   KnownBits RHSKnown(BitWidth);
1148 
1149   // Despite the fact that we can't simplify this instruction in all User's
1150   // context, we can at least compute the known bits, and we can
1151   // do simplifications that apply to *just* the one user if we know that
1152   // this instruction has a simpler value in that context.
1153   switch (I->getOpcode()) {
1154   case Instruction::And: {
1155     llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
1156     llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
1157     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
1158                                          Depth, Q);
1159     computeKnownBitsFromContext(I, Known, Depth, Q);
1160 
1161     // If the client is only demanding bits that we know, return the known
1162     // constant.
1163     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
1164       return Constant::getIntegerValue(ITy, Known.One);
1165 
1166     // If all of the demanded bits are known 1 on one side, return the other.
1167     // These bits cannot contribute to the result of the 'and' in this context.
1168     if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
1169       return I->getOperand(0);
1170     if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
1171       return I->getOperand(1);
1172 
1173     break;
1174   }
1175   case Instruction::Or: {
1176     llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
1177     llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
1178     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
1179                                          Depth, Q);
1180     computeKnownBitsFromContext(I, Known, Depth, Q);
1181 
1182     // If the client is only demanding bits that we know, return the known
1183     // constant.
1184     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
1185       return Constant::getIntegerValue(ITy, Known.One);
1186 
1187     // We can simplify (X|Y) -> X or Y in the user's context if we know that
1188     // only bits from X or Y are demanded.
1189     // If all of the demanded bits are known zero on one side, return the other.
1190     // These bits cannot contribute to the result of the 'or' in this context.
1191     if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
1192       return I->getOperand(0);
1193     if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
1194       return I->getOperand(1);
1195 
1196     break;
1197   }
1198   case Instruction::Xor: {
1199     llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
1200     llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
1201     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
1202                                          Depth, Q);
1203     computeKnownBitsFromContext(I, Known, Depth, Q);
1204 
1205     // If the client is only demanding bits that we know, return the known
1206     // constant.
1207     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
1208       return Constant::getIntegerValue(ITy, Known.One);
1209 
1210     // We can simplify (X^Y) -> X or Y in the user's context if we know that
1211     // only bits from X or Y are demanded.
1212     // If all of the demanded bits are known zero on one side, return the other.
1213     if (DemandedMask.isSubsetOf(RHSKnown.Zero))
1214       return I->getOperand(0);
1215     if (DemandedMask.isSubsetOf(LHSKnown.Zero))
1216       return I->getOperand(1);
1217 
1218     break;
1219   }
1220   case Instruction::Add: {
1221     unsigned NLZ = DemandedMask.countl_zero();
1222     APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
1223 
1224     // If an operand adds zeros to every bit below the highest demanded bit,
1225     // that operand doesn't change the result. Return the other side.
1226     llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
1227     if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
1228       return I->getOperand(0);
1229 
1230     llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
1231     if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
1232       return I->getOperand(1);
1233 
1234     bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
1235     bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
1236     Known =
1237         KnownBits::computeForAddSub(/*Add=*/true, NSW, NUW, LHSKnown, RHSKnown);
1238     computeKnownBitsFromContext(I, Known, Depth, Q);
1239     break;
1240   }
1241   case Instruction::Sub: {
1242     unsigned NLZ = DemandedMask.countl_zero();
1243     APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
1244 
1245     // If an operand subtracts zeros from every bit below the highest demanded
1246     // bit, that operand doesn't change the result. Return the other side.
1247     llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
1248     if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
1249       return I->getOperand(0);
1250 
1251     bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
1252     bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
1253     llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
1254     Known = KnownBits::computeForAddSub(/*Add=*/false, NSW, NUW, LHSKnown,
1255                                         RHSKnown);
1256     computeKnownBitsFromContext(I, Known, Depth, Q);
1257     break;
1258   }
1259   case Instruction::AShr: {
1260     // Compute the Known bits to simplify things downstream.
1261     llvm::computeKnownBits(I, Known, Depth, Q);
1262 
1263     // If this user is only demanding bits that we know, return the known
1264     // constant.
1265     if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
1266       return Constant::getIntegerValue(ITy, Known.One);
1267 
1268     // If the right shift operand 0 is a result of a left shift by the same
1269     // amount, this is probably a zero/sign extension, which may be unnecessary,
1270     // if we do not demand any of the new sign bits. So, return the original
1271     // operand instead.
1272     const APInt *ShiftRC;
1273     const APInt *ShiftLC;
1274     Value *X;
1275     unsigned BitWidth = DemandedMask.getBitWidth();
1276     if (match(I,
1277               m_AShr(m_Shl(m_Value(X), m_APInt(ShiftLC)), m_APInt(ShiftRC))) &&
1278         ShiftLC == ShiftRC && ShiftLC->ult(BitWidth) &&
1279         DemandedMask.isSubsetOf(APInt::getLowBitsSet(
1280             BitWidth, BitWidth - ShiftRC->getZExtValue()))) {
1281       return X;
1282     }
1283 
1284     break;
1285   }
1286   default:
1287     // Compute the Known bits to simplify things downstream.
1288     llvm::computeKnownBits(I, Known, Depth, Q);
1289 
1290     // If this user is only demanding bits that we know, return the known
1291     // constant.
1292     if (DemandedMask.isSubsetOf(Known.Zero|Known.One))
1293       return Constant::getIntegerValue(ITy, Known.One);
1294 
1295     break;
1296   }
1297 
1298   return nullptr;
1299 }
1300 
1301 /// Helper routine of SimplifyDemandedUseBits. It tries to simplify
1302 /// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into
1303 /// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign
1304 /// of "C2-C1".
1305 ///
1306 /// Suppose E1 and E2 are generally different in bits S={bm, bm+1,
1307 /// ..., bn}, without considering the specific value X is holding.
1308 /// This transformation is legal iff one of following conditions is hold:
1309 ///  1) All the bit in S are 0, in this case E1 == E2.
1310 ///  2) We don't care those bits in S, per the input DemandedMask.
1311 ///  3) Combination of 1) and 2). Some bits in S are 0, and we don't care the
1312 ///     rest bits.
1313 ///
1314 /// Currently we only test condition 2).
1315 ///
1316 /// As with SimplifyDemandedUseBits, it returns NULL if the simplification was
1317 /// not successful.
1318 Value *InstCombinerImpl::simplifyShrShlDemandedBits(
1319     Instruction *Shr, const APInt &ShrOp1, Instruction *Shl,
1320     const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known) {
1321   if (!ShlOp1 || !ShrOp1)
1322     return nullptr; // No-op.
1323 
1324   Value *VarX = Shr->getOperand(0);
1325   Type *Ty = VarX->getType();
1326   unsigned BitWidth = Ty->getScalarSizeInBits();
1327   if (ShlOp1.uge(BitWidth) || ShrOp1.uge(BitWidth))
1328     return nullptr; // Undef.
1329 
1330   unsigned ShlAmt = ShlOp1.getZExtValue();
1331   unsigned ShrAmt = ShrOp1.getZExtValue();
1332 
1333   Known.One.clearAllBits();
1334   Known.Zero.setLowBits(ShlAmt - 1);
1335   Known.Zero &= DemandedMask;
1336 
1337   APInt BitMask1(APInt::getAllOnes(BitWidth));
1338   APInt BitMask2(APInt::getAllOnes(BitWidth));
1339 
1340   bool isLshr = (Shr->getOpcode() == Instruction::LShr);
1341   BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) :
1342                       (BitMask1.ashr(ShrAmt) << ShlAmt);
1343 
1344   if (ShrAmt <= ShlAmt) {
1345     BitMask2 <<= (ShlAmt - ShrAmt);
1346   } else {
1347     BitMask2 = isLshr ? BitMask2.lshr(ShrAmt - ShlAmt):
1348                         BitMask2.ashr(ShrAmt - ShlAmt);
1349   }
1350 
1351   // Check if condition-2 (see the comment to this function) is satified.
1352   if ((BitMask1 & DemandedMask) == (BitMask2 & DemandedMask)) {
1353     if (ShrAmt == ShlAmt)
1354       return VarX;
1355 
1356     if (!Shr->hasOneUse())
1357       return nullptr;
1358 
1359     BinaryOperator *New;
1360     if (ShrAmt < ShlAmt) {
1361       Constant *Amt = ConstantInt::get(VarX->getType(), ShlAmt - ShrAmt);
1362       New = BinaryOperator::CreateShl(VarX, Amt);
1363       BinaryOperator *Orig = cast<BinaryOperator>(Shl);
1364       New->setHasNoSignedWrap(Orig->hasNoSignedWrap());
1365       New->setHasNoUnsignedWrap(Orig->hasNoUnsignedWrap());
1366     } else {
1367       Constant *Amt = ConstantInt::get(VarX->getType(), ShrAmt - ShlAmt);
1368       New = isLshr ? BinaryOperator::CreateLShr(VarX, Amt) :
1369                      BinaryOperator::CreateAShr(VarX, Amt);
1370       if (cast<BinaryOperator>(Shr)->isExact())
1371         New->setIsExact(true);
1372     }
1373 
1374     return InsertNewInstWith(New, Shl->getIterator());
1375   }
1376 
1377   return nullptr;
1378 }
1379 
/// Simplify the fixed-width vector value \p V given that only the lanes in
/// \p DemandedElts are observed, and report which lanes of the result are
/// known to be poison in \p PoisonElts.
///
/// DemandedElts contains the set of elements that are actually used by the
/// caller, and by default (AllowMultipleUsers equals false) the value is
/// simplified only if it has a single user. If AllowMultipleUsers is set
/// to true, DemandedElts refers to the union of sets of elements that are
/// used by all users.
///
/// \param V            The vector value to analyze. Scalable vectors are not
///                     handled: null is returned immediately and PoisonElts
///                     is not written.
/// \param DemandedElts Lanes of \p V the caller observes; must not contain
///                     bits beyond the vector width (asserted).
/// \param PoisonElts   [out] Lanes known to be poison in the value that the
///                     caller ends up using.
/// \param Depth        Current recursion depth; non-constant values stop
///                     being analyzed at depth 10.
/// \param AllowMultipleUsers See above.
///
/// \returns If the information about demanded elements can be used to
/// simplify the operation, the simplified value is returned. This may be
/// \p V's own instruction when only its operands or shuffle mask were
/// rewritten in place, a poison value, or a different existing or newly
/// created value. This returns null if no change was made.
Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
                                                    APInt DemandedElts,
                                                    APInt &PoisonElts,
                                                    unsigned Depth,
                                                    bool AllowMultipleUsers) {
  // Cannot analyze scalable type. The number of vector elements is not a
  // compile-time constant.
  if (isa<ScalableVectorType>(V->getType()))
    return nullptr;

  unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements();
  APInt EltMask(APInt::getAllOnes(VWidth));
  assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");

  if (match(V, m_Poison())) {
    // If the entire vector is poison, just return this info.
    PoisonElts = EltMask;
    return nullptr;
  }

  if (DemandedElts.isZero()) { // If nothing is demanded, provide poison.
    PoisonElts = EltMask;
    return PoisonValue::get(V->getType());
  }

  PoisonElts = 0;

  if (auto *C = dyn_cast<Constant>(V)) {
    // Check if this is identity. If so, return null since we are not
    // simplifying anything. Note that PoisonElts stays empty in this case,
    // even if some lanes of C are poison.
    if (DemandedElts.isAllOnes())
      return nullptr;

    // Rebuild the constant with every non-demanded lane replaced by poison,
    // recording which lanes of the rebuilt constant are poison.
    Type *EltTy = cast<VectorType>(V->getType())->getElementType();
    Constant *Poison = PoisonValue::get(EltTy);
    SmallVector<Constant*, 16> Elts;
    for (unsigned i = 0; i != VWidth; ++i) {
      if (!DemandedElts[i]) {   // If not demanded, set to poison.
        Elts.push_back(Poison);
        PoisonElts.setBit(i);
        continue;
      }

      // Bail out if the lane of this constant cannot be extracted.
      Constant *Elt = C->getAggregateElement(i);
      if (!Elt) return nullptr;

      Elts.push_back(Elt);
      if (isa<PoisonValue>(Elt)) // Already poison.
        PoisonElts.setBit(i);
    }

    // If we changed the constant, return it.
    Constant *NewCV = ConstantVector::get(Elts);
    return NewCV != C ? NewCV : nullptr;
  }

  // Limit search depth. Constants are handled above regardless of depth.
  if (Depth == 10)
    return nullptr;

  if (!AllowMultipleUsers) {
    // If multiple users are using the root value, proceed with
    // simplification conservatively assuming that all elements
    // are needed.
    if (!V->hasOneUse()) {
      // Quit if we find multiple users of a non-root value though.
      // They'll be handled when it's their turn to be visited by
      // the main instcombine process.
      if (Depth != 0)
        // TODO: Just compute the PoisonElts information recursively.
        return nullptr;

      // Conservatively assume that all elements are needed.
      DemandedElts = EltMask;
    }
  }

  Instruction *I = dyn_cast<Instruction>(V);
  if (!I) return nullptr;        // Only analyze instructions.

  bool MadeChange = false;
  // Recursively simplify operand OpNum of Inst with the given demanded lanes.
  // The operand's poison lanes are written to Undef. If the recursion yields
  // a replacement, it is installed as the operand and MadeChange is set.
  // Intrinsic calls are indexed by argument number rather than operand number.
  auto simplifyAndSetOp = [&](Instruction *Inst, unsigned OpNum,
                              APInt Demanded, APInt &Undef) {
    auto *II = dyn_cast<IntrinsicInst>(Inst);
    Value *Op = II ? II->getArgOperand(OpNum) : Inst->getOperand(OpNum);
    if (Value *V = SimplifyDemandedVectorElts(Op, Demanded, Undef, Depth + 1)) {
      replaceOperand(*Inst, OpNum, V);
      MadeChange = true;
    }
  };

  // Scratch poison-lane masks for the second and third operands.
  APInt PoisonElts2(VWidth, 0);
  APInt PoisonElts3(VWidth, 0);
  switch (I->getOpcode()) {
  default: break;

  case Instruction::GetElementPtr: {
    // The LangRef requires that struct geps have all constant indices.  As
    // such, we can't convert any operand to partial undef.
    auto mayIndexStructType = [](GetElementPtrInst &GEP) {
      for (auto I = gep_type_begin(GEP), E = gep_type_end(GEP);
           I != E; I++)
        if (I.isStruct())
          return true;
      return false;
    };
    if (mayIndexStructType(cast<GetElementPtrInst>(*I)))
      break;

    // Conservatively track the demanded elements back through any vector
    // operands we may have.  We know there must be at least one, or we
    // wouldn't have a vector result to get here. Note that we intentionally
    // merge the undef bits here since gepping with either a poison base or
    // index results in poison.
    for (unsigned i = 0; i < I->getNumOperands(); i++) {
      if (i == 0 ? match(I->getOperand(i), m_Undef())
                 : match(I->getOperand(i), m_Poison())) {
        // If the entire vector is undefined, just return this info.
        PoisonElts = EltMask;
        return nullptr;
      }
      if (I->getOperand(i)->getType()->isVectorTy()) {
        APInt PoisonEltsOp(VWidth, 0);
        simplifyAndSetOp(I, i, DemandedElts, PoisonEltsOp);
        // gep(x, undef) is not undef, so skip considering idx ops here
        // Note that we could propagate poison, but we can't distinguish between
        // undef & poison bits ATM
        if (i == 0)
          PoisonElts |= PoisonEltsOp;
      }
    }

    break;
  }
  case Instruction::InsertElement: {
    // If this is a variable index, we don't know which element it overwrites.
    // demand exactly the same input as we produce.
    ConstantInt *Idx = dyn_cast<ConstantInt>(I->getOperand(2));
    if (!Idx) {
      // Note that we can't propagate undef elt info, because we don't know
      // which elt is getting updated.
      simplifyAndSetOp(I, 0, DemandedElts, PoisonElts2);
      break;
    }

    // The element inserted overwrites whatever was there, so the input demanded
    // set is simpler than the output set.
    unsigned IdxNo = Idx->getZExtValue();
    APInt PreInsertDemandedElts = DemandedElts;
    if (IdxNo < VWidth)
      PreInsertDemandedElts.clearBit(IdxNo);

    // If we only demand the element that is being inserted and that element
    // was extracted from the same index in another vector with the same type,
    // replace this insert with that other vector.
    // Note: This is attempted before the call to simplifyAndSetOp because that
    //       may change PoisonElts to a value that does not match with Vec.
    Value *Vec;
    if (PreInsertDemandedElts == 0 &&
        match(I->getOperand(1),
              m_ExtractElt(m_Value(Vec), m_SpecificInt(IdxNo))) &&
        Vec->getType() == I->getType()) {
      return Vec;
    }

    simplifyAndSetOp(I, 0, PreInsertDemandedElts, PoisonElts);

    // If this is inserting an element that isn't demanded, remove this
    // insertelement.
    if (IdxNo >= VWidth || !DemandedElts[IdxNo]) {
      Worklist.push(I);
      return I->getOperand(0);
    }

    // The inserted element is defined.
    PoisonElts.clearBit(IdxNo);
    break;
  }
  case Instruction::ShuffleVector: {
    auto *Shuffle = cast<ShuffleVectorInst>(I);
    assert(Shuffle->getOperand(0)->getType() ==
           Shuffle->getOperand(1)->getType() &&
           "Expected shuffle operands to have same type");
    unsigned OpWidth = cast<FixedVectorType>(Shuffle->getOperand(0)->getType())
                           ->getNumElements();
    // Handle trivial case of a splat. Only check the first element of LHS
    // operand. With an all-zero mask and every lane demanded, only LHS lane 0
    // is read, so the RHS is replaced by poison, and the result is all-poison
    // exactly when LHS lane 0 is poison.
    if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) &&
        DemandedElts.isAllOnes()) {
      if (!isa<PoisonValue>(I->getOperand(1))) {
        I->setOperand(1, PoisonValue::get(I->getOperand(1)->getType()));
        MadeChange = true;
      }
      APInt LeftDemanded(OpWidth, 1);
      APInt LHSPoisonElts(OpWidth, 0);
      simplifyAndSetOp(I, 0, LeftDemanded, LHSPoisonElts);
      if (LHSPoisonElts[0])
        PoisonElts = EltMask;
      else
        PoisonElts.clearAllBits();
      break;
    }

    // Map each demanded result lane back to the source lane it reads. Mask
    // values below OpWidth select from the LHS and the rest from the RHS.
    APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0);
    for (unsigned i = 0; i < VWidth; i++) {
      if (DemandedElts[i]) {
        unsigned MaskVal = Shuffle->getMaskValue(i);
        if (MaskVal != -1u) {
          assert(MaskVal < OpWidth * 2 &&
                 "shufflevector mask index out of range!");
          if (MaskVal < OpWidth)
            LeftDemanded.setBit(MaskVal);
          else
            RightDemanded.setBit(MaskVal - OpWidth);
        }
      }
    }

    APInt LHSPoisonElts(OpWidth, 0);
    simplifyAndSetOp(I, 0, LeftDemanded, LHSPoisonElts);

    APInt RHSPoisonElts(OpWidth, 0);
    simplifyAndSetOp(I, 1, RightDemanded, RHSPoisonElts);

    // If this shuffle does not change the vector length and the elements
    // demanded by this shuffle are an identity mask, then this shuffle is
    // unnecessary.
    //
    // We are assuming canonical form for the mask, so the source vector is
    // operand 0 and operand 1 is not used.
    //
    // Note that if an element is demanded and this shuffle mask is undefined
    // for that element, then the shuffle is not considered an identity
    // operation. The shuffle prevents poison from the operand vector from
    // leaking to the result by replacing poison with an undefined value.
    if (VWidth == OpWidth) {
      bool IsIdentityShuffle = true;
      for (unsigned i = 0; i < VWidth; i++) {
        unsigned MaskVal = Shuffle->getMaskValue(i);
        if (DemandedElts[i] && i != MaskVal) {
          IsIdentityShuffle = false;
          break;
        }
      }
      if (IsIdentityShuffle)
        return Shuffle->getOperand(0);
    }

    // Compute the result's poison lanes and, per source operand, track the
    // demanded non-poison lanes read from it. LHSIdx/RHSIdx (result lane) and
    // LHSValIdx/RHSValIdx (source lane) use -1u for "no lane seen yet", the
    // lane index when exactly one lane was seen, and OpWidth when more than
    // one lane was seen (so a later `< OpWidth` test fails). LHSUniform and
    // RHSUniform stay true while every such lane is read from the same
    // position in its source operand.
    bool NewPoisonElts = false;
    unsigned LHSIdx = -1u, LHSValIdx = -1u;
    unsigned RHSIdx = -1u, RHSValIdx = -1u;
    bool LHSUniform = true;
    bool RHSUniform = true;
    for (unsigned i = 0; i < VWidth; i++) {
      unsigned MaskVal = Shuffle->getMaskValue(i);
      if (MaskVal == -1u) {
        PoisonElts.setBit(i);
      } else if (!DemandedElts[i]) {
        NewPoisonElts = true;
        PoisonElts.setBit(i);
      } else if (MaskVal < OpWidth) {
        if (LHSPoisonElts[MaskVal]) {
          NewPoisonElts = true;
          PoisonElts.setBit(i);
        } else {
          LHSIdx = LHSIdx == -1u ? i : OpWidth;
          LHSValIdx = LHSValIdx == -1u ? MaskVal : OpWidth;
          LHSUniform = LHSUniform && (MaskVal == i);
        }
      } else {
        if (RHSPoisonElts[MaskVal - OpWidth]) {
          NewPoisonElts = true;
          PoisonElts.setBit(i);
        } else {
          RHSIdx = RHSIdx == -1u ? i : OpWidth;
          RHSValIdx = RHSValIdx == -1u ? MaskVal - OpWidth : OpWidth;
          RHSUniform = RHSUniform && (MaskVal - OpWidth == i);
        }
      }
    }

    // Try to transform shuffle with constant vector and single element from
    // this constant vector to single insertelement instruction.
    // shufflevector V, C, <v1, v2, .., ci, .., vm> ->
    // insertelement V, C[ci], ci-n
    // This applies when exactly one demanded lane comes from a ConstantVector
    // operand and every demanded lane from the other operand stays in place.
    // If both operands qualify, the RHS match below takes precedence.
    if (OpWidth ==
        cast<FixedVectorType>(Shuffle->getType())->getNumElements()) {
      Value *Op = nullptr;
      Constant *Value = nullptr;
      unsigned Idx = -1u;

      // Find constant vector with the single element in shuffle (LHS or RHS).
      if (LHSIdx < OpWidth && RHSUniform) {
        if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) {
          Op = Shuffle->getOperand(1);
          Value = CV->getOperand(LHSValIdx);
          Idx = LHSIdx;
        }
      }
      if (RHSIdx < OpWidth && LHSUniform) {
        if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) {
          Op = Shuffle->getOperand(0);
          Value = CV->getOperand(RHSValIdx);
          Idx = RHSIdx;
        }
      }
      // Found constant vector with single element - convert to insertelement.
      if (Op && Value) {
        Instruction *New = InsertElementInst::Create(
            Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx),
            Shuffle->getName());
        InsertNewInstWith(New, Shuffle->getIterator());
        return New;
      }
    }
    if (NewPoisonElts) {
      // Add additional discovered undefs.
      SmallVector<int, 16> Elts;
      for (unsigned i = 0; i < VWidth; ++i) {
        if (PoisonElts[i])
          Elts.push_back(PoisonMaskElem);
        else
          Elts.push_back(Shuffle->getMaskValue(i));
      }
      Shuffle->setShuffleMask(Elts);
      MadeChange = true;
    }
    break;
  }
  case Instruction::Select: {
    // If this is a vector select, try to transform the select condition based
    // on the current demanded elements.
    SelectInst *Sel = cast<SelectInst>(I);
    if (Sel->getCondition()->getType()->isVectorTy()) {
      // TODO: We are not doing anything with PoisonElts based on this call.
      // It is overwritten below based on the other select operands. If an
      // element of the select condition is known undef, then we are free to
      // choose the output value from either arm of the select. If we know that
      // one of those values is undef, then the output can be undef.
      simplifyAndSetOp(I, 0, DemandedElts, PoisonElts);
    }

    // Next, see if we can transform the arms of the select.
    // With a constant condition, a lane whose condition is zero reads only the
    // false arm (operand 2); any other lane reads only the true arm (operand 1).
    APInt DemandedLHS(DemandedElts), DemandedRHS(DemandedElts);
    if (auto *CV = dyn_cast<ConstantVector>(Sel->getCondition())) {
      for (unsigned i = 0; i < VWidth; i++) {
        // isNullValue() always returns false when called on a ConstantExpr.
        // Skip constant expressions to avoid propagating incorrect information.
        Constant *CElt = CV->getAggregateElement(i);
        if (isa<ConstantExpr>(CElt))
          continue;
        // TODO: If a select condition element is undef, we can demand from
        // either side. If one side is known undef, choosing that side would
        // propagate undef.
        if (CElt->isNullValue())
          DemandedLHS.clearBit(i);
        else
          DemandedRHS.clearBit(i);
      }
    }

    simplifyAndSetOp(I, 1, DemandedLHS, PoisonElts2);
    simplifyAndSetOp(I, 2, DemandedRHS, PoisonElts3);

    // Output elements are undefined if the element from each arm is undefined.
    // TODO: This can be improved. See comment in select condition handling.
    PoisonElts = PoisonElts2 & PoisonElts3;
    break;
  }
  case Instruction::BitCast: {
    // Vector->vector casts only.
    VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType());
    if (!VTy) break;
    unsigned InVWidth = cast<FixedVectorType>(VTy)->getNumElements();
    APInt InputDemandedElts(InVWidth, 0);
    PoisonElts2 = APInt(InVWidth, 0);
    // Number of narrower lanes that make up one wider lane; only widths that
    // divide evenly are handled.
    unsigned Ratio;

    if (VWidth == InVWidth) {
      // If we are converting from <4 x i32> -> <4 x f32>, we demand the same
      // elements as are demanded of us.
      Ratio = 1;
      InputDemandedElts = DemandedElts;
    } else if ((VWidth % InVWidth) == 0) {
      // If the number of elements in the output is a multiple of the number of
      // elements in the input then an input element is live if any of the
      // corresponding output elements are live.
      Ratio = VWidth / InVWidth;
      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
        if (DemandedElts[OutIdx])
          InputDemandedElts.setBit(OutIdx / Ratio);
    } else if ((InVWidth % VWidth) == 0) {
      // If the number of elements in the input is a multiple of the number of
      // elements in the output then an input element is live if the
      // corresponding output element is live.
      Ratio = InVWidth / VWidth;
      for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx)
        if (DemandedElts[InIdx / Ratio])
          InputDemandedElts.setBit(InIdx);
    } else {
      // Unsupported so far.
      break;
    }

    simplifyAndSetOp(I, 0, InputDemandedElts, PoisonElts2);

    if (VWidth == InVWidth) {
      PoisonElts = PoisonElts2;
    } else if ((VWidth % InVWidth) == 0) {
      // If the number of elements in the output is a multiple of the number of
      // elements in the input then an output element is undef if the
      // corresponding input element is undef.
      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
        if (PoisonElts2[OutIdx / Ratio])
          PoisonElts.setBit(OutIdx);
    } else if ((InVWidth % VWidth) == 0) {
      // If the number of elements in the input is a multiple of the number of
      // elements in the output then an output element is undef if all of the
      // corresponding input elements are undef.
      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
        APInt SubUndef = PoisonElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio);
        if (SubUndef.popcount() == Ratio)
          PoisonElts.setBit(OutIdx);
      }
    } else {
      llvm_unreachable("Unimp");
    }
    break;
  }
  case Instruction::FPTrunc:
  case Instruction::FPExt:
    // Lane-wise conversions: result lane i depends only on source lane i.
    simplifyAndSetOp(I, 0, DemandedElts, PoisonElts);
    break;

  case Instruction::Call: {
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
    if (!II) break;
    switch (II->getIntrinsicID()) {
    case Intrinsic::masked_gather: // fallthrough
    case Intrinsic::masked_load: {
      // Subtlety: If we load from a pointer, the pointer must be valid
      // regardless of whether the element is demanded.  Doing otherwise risks
      // segfaults which didn't exist in the original program.
      // Operand 2 is used as the mask and operand 3 as the pass-through. For a
      // constant mask, lanes whose mask element is zero need no pointer, and
      // lanes whose mask element is all-ones never read the pass-through.
      // Only masked_gather has a vector of pointers to simplify (operand 0).
      APInt DemandedPtrs(APInt::getAllOnes(VWidth)),
          DemandedPassThrough(DemandedElts);
      if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2)))
        for (unsigned i = 0; i < VWidth; i++) {
          Constant *CElt = CV->getAggregateElement(i);
          if (CElt->isNullValue())
            DemandedPtrs.clearBit(i);
          else if (CElt->isAllOnesValue())
            DemandedPassThrough.clearBit(i);
        }
      if (II->getIntrinsicID() == Intrinsic::masked_gather)
        simplifyAndSetOp(II, 0, DemandedPtrs, PoisonElts2);
      simplifyAndSetOp(II, 3, DemandedPassThrough, PoisonElts3);

      // Output elements are undefined if the element from both sources are.
      // TODO: can strengthen via mask as well.
      PoisonElts = PoisonElts2 & PoisonElts3;
      break;
    }
    default: {
      // Handle target specific intrinsics
      std::optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic(
          *II, DemandedElts, PoisonElts, PoisonElts2, PoisonElts3,
          simplifyAndSetOp);
      if (V)
        return *V;
      break;
    }
    } // switch on IntrinsicID
    break;
  } // case Call
  } // switch on Opcode

  // TODO: We bail completely on integer div/rem and shifts because they have
  // UB/poison potential, but that should be refined.
  BinaryOperator *BO;
  if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) {
    Value *X = BO->getOperand(0);
    Value *Y = BO->getOperand(1);

    // Look for an equivalent binop except that one operand has been shuffled.
    // If the demand for this binop only includes elements that are the same as
    // the other binop, then we may be able to replace this binop with a use of
    // the earlier one.
    //
    // Example:
    // %other_bo = bo (shuf X, {0}), Y
    // %this_extracted_bo = extelt (bo X, Y), 0
    // -->
    // %other_bo = bo (shuf X, {0}), Y
    // %this_extracted_bo = extelt %other_bo, 0
    //
    // TODO: Handle demand of an arbitrary single element or more than one
    //       element instead of just element 0.
    // TODO: Unlike general demanded elements transforms, this should be safe
    //       for any (div/rem/shift) opcode too.
    if (DemandedElts == 1 && !X->hasOneUse() && !Y->hasOneUse() &&
        BO->hasOneUse() ) {

      auto findShufBO = [&](bool MatchShufAsOp0) -> User * {
        // Try to use shuffle-of-operand in place of an operand:
        // bo X, Y --> bo (shuf X), Y
        // bo X, Y --> bo X, (shuf Y)
        // The candidate's shuffle mask must be all-zero with a defined lane 0,
        // and the candidate must dominate this binop to be reusable here.
        BinaryOperator::BinaryOps Opcode = BO->getOpcode();
        Value *ShufOp = MatchShufAsOp0 ? X : Y;
        Value *OtherOp = MatchShufAsOp0 ? Y : X;
        for (User *U : OtherOp->users()) {
          ArrayRef<int> Mask;
          auto Shuf = m_Shuffle(m_Specific(ShufOp), m_Value(), m_Mask(Mask));
          if (BO->isCommutative()
                  ? match(U, m_c_BinOp(Opcode, Shuf, m_Specific(OtherOp)))
                  : MatchShufAsOp0
                        ? match(U, m_BinOp(Opcode, Shuf, m_Specific(OtherOp)))
                        : match(U, m_BinOp(Opcode, m_Specific(OtherOp), Shuf)))
            if (match(Mask, m_ZeroMask()) && Mask[0] != PoisonMaskElem)
              if (DT.dominates(U, I))
                return U;
        }
        return nullptr;
      };

      if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ true))
        return ShufBO;
      if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ false))
        return ShufBO;
    }

    simplifyAndSetOp(I, 0, DemandedElts, PoisonElts);
    simplifyAndSetOp(I, 1, DemandedElts, PoisonElts2);

    // Output elements are undefined if both are undefined. Consider things
    // like undef & 0. The result is known zero, not undef.
    PoisonElts &= PoisonElts2;
  }

  // If we've proven all of the lanes poison, return a poison value.
  // TODO: Intersect w/demanded lanes
  if (PoisonElts.isAllOnes())
    return PoisonValue::get(I->getType());

  return MadeChange ? I : nullptr;
}
1938 
1939 /// For floating-point classes that resolve to a single bit pattern, return that
1940 /// value.
1941 static Constant *getFPClassConstant(Type *Ty, FPClassTest Mask) {
1942   switch (Mask) {
1943   case fcPosZero:
1944     return ConstantFP::getZero(Ty);
1945   case fcNegZero:
1946     return ConstantFP::getZero(Ty, true);
1947   case fcPosInf:
1948     return ConstantFP::getInfinity(Ty);
1949   case fcNegInf:
1950     return ConstantFP::getInfinity(Ty, true);
1951   case fcNone:
1952     return PoisonValue::get(Ty);
1953   default:
1954     return nullptr;
1955   }
1956 }
1957 
1958 Value *InstCombinerImpl::SimplifyDemandedUseFPClass(
1959     Value *V, const FPClassTest DemandedMask, KnownFPClass &Known,
1960     unsigned Depth, Instruction *CxtI) {
1961   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
1962   Type *VTy = V->getType();
1963 
1964   assert(Known == KnownFPClass() && "expected uninitialized state");
1965 
1966   if (DemandedMask == fcNone)
1967     return isa<UndefValue>(V) ? nullptr : PoisonValue::get(VTy);
1968 
1969   if (Depth == MaxAnalysisRecursionDepth)
1970     return nullptr;
1971 
1972   Instruction *I = dyn_cast<Instruction>(V);
1973   if (!I) {
1974     // Handle constants and arguments
1975     Known = computeKnownFPClass(V, fcAllFlags, CxtI, Depth + 1);
1976     Value *FoldedToConst =
1977         getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses);
1978     return FoldedToConst == V ? nullptr : FoldedToConst;
1979   }
1980 
1981   if (!I->hasOneUse())
1982     return nullptr;
1983 
1984   // TODO: Should account for nofpclass/FastMathFlags on current instruction
1985   switch (I->getOpcode()) {
1986   case Instruction::FNeg: {
1987     if (SimplifyDemandedFPClass(I, 0, llvm::fneg(DemandedMask), Known,
1988                                 Depth + 1))
1989       return I;
1990     Known.fneg();
1991     break;
1992   }
1993   case Instruction::Call: {
1994     CallInst *CI = cast<CallInst>(I);
1995     switch (CI->getIntrinsicID()) {
1996     case Intrinsic::fabs:
1997       if (SimplifyDemandedFPClass(I, 0, llvm::inverse_fabs(DemandedMask), Known,
1998                                   Depth + 1))
1999         return I;
2000       Known.fabs();
2001       break;
2002     case Intrinsic::arithmetic_fence:
2003       if (SimplifyDemandedFPClass(I, 0, DemandedMask, Known, Depth + 1))
2004         return I;
2005       break;
2006     case Intrinsic::copysign: {
2007       // Flip on more potentially demanded classes
2008       const FPClassTest DemandedMaskAnySign = llvm::unknown_sign(DemandedMask);
2009       if (SimplifyDemandedFPClass(I, 0, DemandedMaskAnySign, Known, Depth + 1))
2010         return I;
2011 
2012       if ((DemandedMask & fcPositive) == fcNone) {
2013         // Roundabout way of replacing with fneg(fabs)
2014         I->setOperand(1, ConstantFP::get(VTy, -1.0));
2015         return I;
2016       }
2017 
2018       if ((DemandedMask & fcNegative) == fcNone) {
2019         // Roundabout way of replacing with fabs
2020         I->setOperand(1, ConstantFP::getZero(VTy));
2021         return I;
2022       }
2023 
2024       KnownFPClass KnownSign =
2025           computeKnownFPClass(I->getOperand(1), fcAllFlags, CxtI, Depth + 1);
2026       Known.copysign(KnownSign);
2027       break;
2028     }
2029     default:
2030       Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1);
2031       break;
2032     }
2033 
2034     break;
2035   }
2036   case Instruction::Select: {
2037     KnownFPClass KnownLHS, KnownRHS;
2038     if (SimplifyDemandedFPClass(I, 2, DemandedMask, KnownRHS, Depth + 1) ||
2039         SimplifyDemandedFPClass(I, 1, DemandedMask, KnownLHS, Depth + 1))
2040       return I;
2041 
2042     if (KnownLHS.isKnownNever(DemandedMask))
2043       return I->getOperand(2);
2044     if (KnownRHS.isKnownNever(DemandedMask))
2045       return I->getOperand(1);
2046 
2047     // TODO: Recognize clamping patterns
2048     Known = KnownLHS | KnownRHS;
2049     break;
2050   }
2051   default:
2052     Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1);
2053     break;
2054   }
2055 
2056   return getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses);
2057 }
2058 
2059 bool InstCombinerImpl::SimplifyDemandedFPClass(Instruction *I, unsigned OpNo,
2060                                                FPClassTest DemandedMask,
2061                                                KnownFPClass &Known,
2062                                                unsigned Depth) {
2063   Use &U = I->getOperandUse(OpNo);
2064   Value *NewVal =
2065       SimplifyDemandedUseFPClass(U.get(), DemandedMask, Known, Depth, I);
2066   if (!NewVal)
2067     return false;
2068   if (Instruction *OpInst = dyn_cast<Instruction>(U))
2069     salvageDebugInfo(*OpInst);
2070 
2071   replaceUse(U, NewVal);
2072   return true;
2073 }
2074