xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (revision 1db9f3b21e39176dd5b67cf8ac378633b172463e)
1 //===- InstCombineCasts.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the visit functions for cast operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "InstCombineInternal.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/Analysis/ConstantFolding.h"
16 #include "llvm/IR/DataLayout.h"
17 #include "llvm/IR/DebugInfo.h"
18 #include "llvm/IR/PatternMatch.h"
19 #include "llvm/Support/KnownBits.h"
20 #include "llvm/Transforms/InstCombine/InstCombiner.h"
21 #include <optional>
22 
23 using namespace llvm;
24 using namespace PatternMatch;
25 
26 #define DEBUG_TYPE "instcombine"
27 
28 /// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
29 /// true for, actually insert the code to evaluate the expression.
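/// For example (an illustrative case): given
///   %w = zext i32 %x to i64
///   %s = add i64 %w, 42
/// and a requested type of i32, this rebuilds the expression as
///   %s = add i32 %x, 42
/// reusing %x directly and folding the constant into the narrow type.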
30 Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
31                                                  bool isSigned) {
32   if (Constant *C = dyn_cast<Constant>(V))
33     return ConstantFoldIntegerCast(C, Ty, isSigned, DL);
34 
35   // Otherwise, it must be an instruction.
36   Instruction *I = cast<Instruction>(V);
37   Instruction *Res = nullptr;
38   unsigned Opc = I->getOpcode();
39   switch (Opc) {
40   case Instruction::Add:
41   case Instruction::Sub:
42   case Instruction::Mul:
43   case Instruction::And:
44   case Instruction::Or:
45   case Instruction::Xor:
46   case Instruction::AShr:
47   case Instruction::LShr:
48   case Instruction::Shl:
49   case Instruction::UDiv:
50   case Instruction::URem: {
51     Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
52     Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
53     Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS);
54     break;
55   }
56   case Instruction::Trunc:
57   case Instruction::ZExt:
58   case Instruction::SExt:
59     // If the source type of the cast is the type we're trying for then we can
60     // just return the source.  There's no need to insert it because it is not
61     // new.
62     if (I->getOperand(0)->getType() == Ty)
63       return I->getOperand(0);
64 
65     // Otherwise, must be the same type of cast, so just reinsert a new one.
66     // This also handles the case of zext(trunc(x)) -> zext(x).
67     Res = CastInst::CreateIntegerCast(I->getOperand(0), Ty,
68                                       Opc == Instruction::SExt);
69     break;
70   case Instruction::Select: {
71     Value *True = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
72     Value *False = EvaluateInDifferentType(I->getOperand(2), Ty, isSigned);
73     Res = SelectInst::Create(I->getOperand(0), True, False);
74     break;
75   }
76   case Instruction::PHI: {
77     PHINode *OPN = cast<PHINode>(I);
78     PHINode *NPN = PHINode::Create(Ty, OPN->getNumIncomingValues());
79     for (unsigned i = 0, e = OPN->getNumIncomingValues(); i != e; ++i) {
80       Value *V =
81           EvaluateInDifferentType(OPN->getIncomingValue(i), Ty, isSigned);
82       NPN->addIncoming(V, OPN->getIncomingBlock(i));
83     }
84     Res = NPN;
85     break;
86   }
87   case Instruction::FPToUI:
88   case Instruction::FPToSI:
89     Res = CastInst::Create(
90       static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty);
91     break;
92   case Instruction::Call:
93     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
94       switch (II->getIntrinsicID()) {
95       default:
96         llvm_unreachable("Unsupported call!");
97       case Intrinsic::vscale: {
98         Function *Fn =
99             Intrinsic::getDeclaration(I->getModule(), Intrinsic::vscale, {Ty});
100         Res = CallInst::Create(Fn->getFunctionType(), Fn);
101         break;
102       }
103       }
104     }
105     break;
106   default:
107     // TODO: Can handle more cases here.
108     llvm_unreachable("Unreachable!");
109   }
110 
111   Res->takeName(I);
112   return InsertNewInstWith(Res, I->getIterator());
113 }
114 
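/// Determine whether the two casts CI1 and CI2 (with CI1 feeding CI2) can be
/// folded into a single cast, and return that cast's opcode, or 0 if the pair
/// cannot be simplified. For example, a zext i8 -> i16 followed by a
/// zext i16 -> i32 folds into a single zext i8 -> i32.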
115 Instruction::CastOps
116 InstCombinerImpl::isEliminableCastPair(const CastInst *CI1,
117                                        const CastInst *CI2) {
118   Type *SrcTy = CI1->getSrcTy();
119   Type *MidTy = CI1->getDestTy();
120   Type *DstTy = CI2->getDestTy();
121 
122   Instruction::CastOps firstOp = CI1->getOpcode();
123   Instruction::CastOps secondOp = CI2->getOpcode();
124   Type *SrcIntPtrTy =
125       SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr;
126   Type *MidIntPtrTy =
127       MidTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(MidTy) : nullptr;
128   Type *DstIntPtrTy =
129       DstTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(DstTy) : nullptr;
130   unsigned Res = CastInst::isEliminableCastPair(firstOp, secondOp, SrcTy, MidTy,
131                                                 DstTy, SrcIntPtrTy, MidIntPtrTy,
132                                                 DstIntPtrTy);
133 
134   // We don't want to form an inttoptr or ptrtoint that converts to an integer
135   // type that differs from the pointer size.
136   if ((Res == Instruction::IntToPtr && SrcTy != DstIntPtrTy) ||
137       (Res == Instruction::PtrToInt && DstTy != SrcIntPtrTy))
138     Res = 0;
139 
140   return Instruction::CastOps(Res);
141 }
142 
143 /// Implement the transforms common to all CastInst visitors.
144 Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
145   Value *Src = CI.getOperand(0);
146   Type *Ty = CI.getType();
147 
148   if (auto *SrcC = dyn_cast<Constant>(Src))
149     if (Constant *Res = ConstantFoldCastOperand(CI.getOpcode(), SrcC, Ty, DL))
150       return replaceInstUsesWith(CI, Res);
151 
152   // Try to eliminate a cast of a cast.
153   if (auto *CSrc = dyn_cast<CastInst>(Src)) {   // A->B->C cast
154     if (Instruction::CastOps NewOpc = isEliminableCastPair(CSrc, &CI)) {
155       // The first cast (CSrc) is eliminable so we need to fix up or replace
156       // the second cast (CI). CSrc will then have a good chance of being dead.
157       auto *Res = CastInst::Create(NewOpc, CSrc->getOperand(0), Ty);
158       // Point debug users of the dying cast to the new one.
159       if (CSrc->hasOneUse())
160         replaceAllDbgUsesWith(*CSrc, *Res, CI, DT);
161       return Res;
162     }
163   }
164 
165   if (auto *Sel = dyn_cast<SelectInst>(Src)) {
166     // We are casting a select. Try to fold the cast into the select if the
167     // select does not have a compare instruction with matching operand types
168     // or the select is likely better done in a narrow type.
169     // Creating a select with operands that are different sizes than its
170     // condition may inhibit other folds and lead to worse codegen.
171     auto *Cmp = dyn_cast<CmpInst>(Sel->getCondition());
172     if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType() ||
173         (CI.getOpcode() == Instruction::Trunc &&
174          shouldChangeType(CI.getSrcTy(), CI.getType()))) {
175       if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) {
176         replaceAllDbgUsesWith(*Sel, *NV, CI, DT);
177         return NV;
178       }
179     }
180   }
181 
182   // If we are casting a PHI, then fold the cast into the PHI.
183   if (auto *PN = dyn_cast<PHINode>(Src)) {
184     // Don't do this if it would create a PHI node with an illegal type from a
185     // legal type.
186     if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() ||
187         shouldChangeType(CI.getSrcTy(), CI.getType()))
188       if (Instruction *NV = foldOpIntoPhi(CI, PN))
189         return NV;
190   }
191 
192   // Canonicalize a unary shuffle after the cast if neither operation changes
193   // the size or element size of the input vector.
194   // TODO: We could allow size-changing ops if that doesn't harm codegen.
195   // cast (shuffle X, Mask) --> shuffle (cast X), Mask
196   Value *X;
197   ArrayRef<int> Mask;
198   if (match(Src, m_OneUse(m_Shuffle(m_Value(X), m_Undef(), m_Mask(Mask))))) {
199     // TODO: Allow scalable vectors?
200     auto *SrcTy = dyn_cast<FixedVectorType>(X->getType());
201     auto *DestTy = dyn_cast<FixedVectorType>(Ty);
202     if (SrcTy && DestTy &&
203         SrcTy->getNumElements() == DestTy->getNumElements() &&
204         SrcTy->getPrimitiveSizeInBits() == DestTy->getPrimitiveSizeInBits()) {
205       Value *CastX = Builder.CreateCast(CI.getOpcode(), X, DestTy);
206       return new ShuffleVectorInst(CastX, Mask);
207     }
208   }
209 
210   return nullptr;
211 }
212 
213 /// Constants and extensions/truncates from the destination type are always
214 /// free to be evaluated in that type. This is a helper for canEvaluate*.
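/// For example, an immediate constant, or (zext i8 %x to i32) when Ty is i8:
/// the zext's source %x can simply be reused.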
215 static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
216   if (isa<Constant>(V))
217     return match(V, m_ImmConstant());
218 
219   Value *X;
220   if ((match(V, m_ZExtOrSExt(m_Value(X))) || match(V, m_Trunc(m_Value(X)))) &&
221       X->getType() == Ty)
222     return true;
223 
224   return false;
225 }
226 
227 /// Filter out values that we cannot evaluate in the destination type for free.
228 /// This is a helper for canEvaluate*.
229 static bool canNotEvaluateInType(Value *V, Type *Ty) {
230   if (!isa<Instruction>(V))
231     return true;
232   // We don't extend or shrink something that has multiple uses --  doing so
233   // would require duplicating the instruction which isn't profitable.
234   if (!V->hasOneUse())
235     return true;
236 
237   return false;
238 }
239 
240 /// Return true if we can evaluate the specified expression tree as type Ty
241 /// instead of its larger type, and arrive with the same value.
242 /// This is used by code that tries to eliminate truncates.
243 ///
244 /// Ty will always be smaller than the type of V.  We should return true if trunc(V)
245 /// can be computed by computing V in the smaller type.  If V is an instruction,
246 /// then trunc(inst(x,y)) can be computed as inst(trunc(x),trunc(y)), which only
247 /// makes sense if x and y can be efficiently truncated.
248 ///
249 /// This function works on both vectors and scalars.
250 ///
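/// For example, with V = (and i32 (zext i8 %x to i32), 15) and Ty = i8, the
/// same truncated value is produced by (and i8 %x, 15), so this returns true.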
251 static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
252                                  Instruction *CxtI) {
253   if (canAlwaysEvaluateInType(V, Ty))
254     return true;
255   if (canNotEvaluateInType(V, Ty))
256     return false;
257 
258   auto *I = cast<Instruction>(V);
259   Type *OrigTy = V->getType();
260   switch (I->getOpcode()) {
261   case Instruction::Add:
262   case Instruction::Sub:
263   case Instruction::Mul:
264   case Instruction::And:
265   case Instruction::Or:
266   case Instruction::Xor:
267     // These operators can all arbitrarily be extended or truncated.
268     return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
269            canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
270 
271   case Instruction::UDiv:
272   case Instruction::URem: {
273     // UDiv and URem can be truncated if all the truncated bits are zero.
274     uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
275     uint32_t BitWidth = Ty->getScalarSizeInBits();
276     assert(BitWidth < OrigBitWidth && "Unexpected bitwidths!");
277     APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
278     if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, CxtI) &&
279         IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, CxtI)) {
280       return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
281              canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
282     }
283     break;
284   }
285   case Instruction::Shl: {
286     // If we are truncating the result of this SHL, and if it's a shift of an
287     // in-range amount, we can always perform a SHL in a smaller type.
288     uint32_t BitWidth = Ty->getScalarSizeInBits();
289     KnownBits AmtKnownBits =
290         llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
291     if (AmtKnownBits.getMaxValue().ult(BitWidth))
292       return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
293              canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
294     break;
295   }
296   case Instruction::LShr: {
297     // If this is a truncate of a logical shr, we can truncate it to a smaller
298     // lshr iff we know that the bits we would otherwise be shifting in are
299     // already zeros.
300     // TODO: It is enough to check that the bits we would be shifting in are
301     //       zero - use AmtKnownBits.getMaxValue().
302     uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
303     uint32_t BitWidth = Ty->getScalarSizeInBits();
304     KnownBits AmtKnownBits =
305         llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
306     APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
307     if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
308         IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI)) {
309       return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
310              canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
311     }
312     break;
313   }
314   case Instruction::AShr: {
315     // If this is a truncate of an arithmetic shr, we can truncate it to a
316     // smaller ashr iff we know that all the bits from the sign bit of the
317     // original type down to the sign bit of the truncated type are identical.
318     // TODO: It is enough to check that the bits we would be shifting in are
319     //       equal to the sign bit of the truncated type.
320     uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
321     uint32_t BitWidth = Ty->getScalarSizeInBits();
322     KnownBits AmtKnownBits =
323         llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
324     unsigned ShiftedBits = OrigBitWidth - BitWidth;
325     if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
326         ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
327       return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
328              canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
329     break;
330   }
331   case Instruction::Trunc:
332     // trunc(trunc(x)) -> trunc(x)
333     return true;
334   case Instruction::ZExt:
335   case Instruction::SExt:
336     // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest
337     // trunc(ext(x)) -> trunc(x) if the source type is larger than the new dest
338     return true;
339   case Instruction::Select: {
340     SelectInst *SI = cast<SelectInst>(I);
341     return canEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) &&
342            canEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI);
343   }
344   case Instruction::PHI: {
345     // We can change a phi if we can change all operands.  Note that we never
346     // get into trouble with cyclic PHIs here because we only consider
347     // instructions with a single use.
348     PHINode *PN = cast<PHINode>(I);
349     for (Value *IncValue : PN->incoming_values())
350       if (!canEvaluateTruncated(IncValue, Ty, IC, CxtI))
351         return false;
352     return true;
353   }
354   case Instruction::FPToUI:
355   case Instruction::FPToSI: {
356     // If the integer type can hold the max FP value, it is safe to cast
357     // directly to that type. Otherwise, we may create poison via overflow
358     // that did not exist in the original code.
359     Type *InputTy = I->getOperand(0)->getType()->getScalarType();
360     const fltSemantics &Semantics = InputTy->getFltSemantics();
361     uint32_t MinBitWidth =
362       APFloatBase::semanticsIntSizeInBits(Semantics,
363           I->getOpcode() == Instruction::FPToSI);
364     return Ty->getScalarSizeInBits() >= MinBitWidth;
365   }
366   default:
367     // TODO: Can handle more cases here.
368     break;
369   }
370 
371   return false;
372 }
373 
374 /// Given a vector that is bitcast to an integer, optionally logically
375 /// right-shifted, and truncated, convert it to an extractelement.
376 /// Example (big endian):
377 ///   trunc (lshr (bitcast <4 x i32> %X to i128), 32) to i32
378 ///   --->
379 ///   extractelement <4 x i32> %X, 1
380 static Instruction *foldVecTruncToExtElt(TruncInst &Trunc,
381                                          InstCombinerImpl &IC) {
382   Value *TruncOp = Trunc.getOperand(0);
383   Type *DestType = Trunc.getType();
384   if (!TruncOp->hasOneUse() || !isa<IntegerType>(DestType))
385     return nullptr;
386 
387   Value *VecInput = nullptr;
388   ConstantInt *ShiftVal = nullptr;
389   if (!match(TruncOp, m_CombineOr(m_BitCast(m_Value(VecInput)),
390                                   m_LShr(m_BitCast(m_Value(VecInput)),
391                                          m_ConstantInt(ShiftVal)))) ||
392       !isa<VectorType>(VecInput->getType()))
393     return nullptr;
394 
395   VectorType *VecType = cast<VectorType>(VecInput->getType());
396   unsigned VecWidth = VecType->getPrimitiveSizeInBits();
397   unsigned DestWidth = DestType->getPrimitiveSizeInBits();
398   unsigned ShiftAmount = ShiftVal ? ShiftVal->getZExtValue() : 0;
399 
400   if ((VecWidth % DestWidth != 0) || (ShiftAmount % DestWidth != 0))
401     return nullptr;
402 
403   // If the element type of the vector doesn't match the result type,
404   // bitcast it to a vector type that we can extract from.
405   unsigned NumVecElts = VecWidth / DestWidth;
406   if (VecType->getElementType() != DestType) {
407     VecType = FixedVectorType::get(DestType, NumVecElts);
408     VecInput = IC.Builder.CreateBitCast(VecInput, VecType, "bc");
409   }
410 
411   unsigned Elt = ShiftAmount / DestWidth;
412   if (IC.getDataLayout().isBigEndian())
413     Elt = NumVecElts - 1 - Elt;
414 
415   return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
416 }
417 
418 /// Funnel/Rotate left/right may occur in a wider type than necessary because of
419 /// type promotion rules. Try to narrow the inputs and convert to funnel shift.
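/// For example (one rotate pattern this can handle), with %w = zext i8 %x to
/// i32 and a shift amount %a:
///   trunc (or (shl %w, %a), (lshr %w, (sub 8, %a))) to i8
/// becomes
///   call i8 @llvm.fshl.i8(i8 %x, i8 %x, i8 (trunc %a))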
420 Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
421   assert((isa<VectorType>(Trunc.getSrcTy()) ||
422           shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) &&
423          "Don't narrow to an illegal scalar type");
424 
425   // Bail out on strange types. It is possible to handle some of these patterns
426   // even with non-power-of-2 sizes, but it is not a likely scenario.
427   Type *DestTy = Trunc.getType();
428   unsigned NarrowWidth = DestTy->getScalarSizeInBits();
429   unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits();
430   if (!isPowerOf2_32(NarrowWidth))
431     return nullptr;
432 
433   // First, find an or'd pair of opposite shifts:
434   // trunc (or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1))
435   BinaryOperator *Or0, *Or1;
436   if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1)))))
437     return nullptr;
438 
439   Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
440   if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
441       !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
442       Or0->getOpcode() == Or1->getOpcode())
443     return nullptr;
444 
445   // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
446   if (Or0->getOpcode() == BinaryOperator::LShr) {
447     std::swap(Or0, Or1);
448     std::swap(ShVal0, ShVal1);
449     std::swap(ShAmt0, ShAmt1);
450   }
451   assert(Or0->getOpcode() == BinaryOperator::Shl &&
452          Or1->getOpcode() == BinaryOperator::LShr &&
453          "Illegal or(shift,shift) pair");
454 
455   // Match the shift amount operands for a funnel/rotate pattern. This always
456   // matches a subtraction on the R operand.
457   auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
458     // The shift amounts may add up to the narrow bit width:
459     // (shl ShVal0, L) | (lshr ShVal1, Width - L)
460     // If this is a funnel shift (different operands are shifted), then the
461     // shift amount can not over-shift (create poison) in the narrow type.
462     unsigned MaxShiftAmountWidth = Log2_32(NarrowWidth);
463     APInt HiBitMask = ~APInt::getLowBitsSet(WideWidth, MaxShiftAmountWidth);
464     if (ShVal0 == ShVal1 || MaskedValueIsZero(L, HiBitMask))
465       if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L)))))
466         return L;
467 
468     // The following patterns currently only work for rotation patterns.
469     // TODO: Add more general funnel-shift compatible patterns.
470     if (ShVal0 != ShVal1)
471       return nullptr;
472 
473     // The shift amount may be masked with negation:
474     // (shl ShVal0, (X & (Width - 1))) | (lshr ShVal1, ((-X) & (Width - 1)))
475     Value *X;
476     unsigned Mask = Width - 1;
477     if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
478         match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
479       return X;
480 
481     // Same as above, but the shift amount may be extended after masking:
482     if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
483         match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
484       return X;
485 
486     return nullptr;
487   };
488 
489   Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth);
490   bool IsFshl = true; // Sub on LSHR.
491   if (!ShAmt) {
492     ShAmt = matchShiftAmount(ShAmt1, ShAmt0, NarrowWidth);
493     IsFshl = false; // Sub on SHL.
494   }
495   if (!ShAmt)
496     return nullptr;
497 
498   // The right-shifted value must have high zeros in the wide type (for example
499   // from 'zext', 'and' or 'shift'). High bits of the left-shifted value are
500   // truncated, so those do not matter.
501   APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth);
502   if (!MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc))
503     return nullptr;
504 
505   // Adjust the width of ShAmt for narrowed funnel shift operation:
506   // - Zero-extend if ShAmt is narrower than the destination type.
507   // - Truncate if ShAmt is wider, discarding non-significant high-order bits.
508   // This prepares ShAmt for llvm.fshl.i8(trunc(ShVal), trunc(ShVal),
509   // zext/trunc(ShAmt)).
510   Value *NarrowShAmt = Builder.CreateZExtOrTrunc(ShAmt, DestTy);
511 
512   Value *X, *Y;
513   X = Y = Builder.CreateTrunc(ShVal0, DestTy);
514   if (ShVal0 != ShVal1)
515     Y = Builder.CreateTrunc(ShVal1, DestTy);
516   Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
517   Function *F = Intrinsic::getDeclaration(Trunc.getModule(), IID, DestTy);
518   return CallInst::Create(F, {X, Y, NarrowShAmt});
519 }
520 
521 /// Try to narrow the width of math or bitwise logic instructions by pulling a
522 /// truncate ahead of binary operators.
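/// For example, trunc (add i64 %x, 42) to i32 --> add i32 (trunc %x), 42,
/// assuming the wide add has no other uses and i32 is a desirable type here.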
523 Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) {
524   Type *SrcTy = Trunc.getSrcTy();
525   Type *DestTy = Trunc.getType();
526   unsigned SrcWidth = SrcTy->getScalarSizeInBits();
527   unsigned DestWidth = DestTy->getScalarSizeInBits();
528 
529   if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
530     return nullptr;
531 
532   BinaryOperator *BinOp;
533   if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(BinOp))))
534     return nullptr;
535 
536   Value *BinOp0 = BinOp->getOperand(0);
537   Value *BinOp1 = BinOp->getOperand(1);
538   switch (BinOp->getOpcode()) {
539   case Instruction::And:
540   case Instruction::Or:
541   case Instruction::Xor:
542   case Instruction::Add:
543   case Instruction::Sub:
544   case Instruction::Mul: {
545     Constant *C;
546     if (match(BinOp0, m_Constant(C))) {
547       // trunc (binop C, X) --> binop C', (trunc X)
548       Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy);
549       Value *TruncX = Builder.CreateTrunc(BinOp1, DestTy);
550       return BinaryOperator::Create(BinOp->getOpcode(), NarrowC, TruncX);
551     }
552     if (match(BinOp1, m_Constant(C))) {
553       // trunc (binop X, C) --> binop (trunc X), C'
554       Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy);
555       Value *TruncX = Builder.CreateTrunc(BinOp0, DestTy);
556       return BinaryOperator::Create(BinOp->getOpcode(), TruncX, NarrowC);
557     }
558     Value *X;
559     if (match(BinOp0, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) {
560       // trunc (binop (ext X), Y) --> binop X, (trunc Y)
561       Value *NarrowOp1 = Builder.CreateTrunc(BinOp1, DestTy);
562       return BinaryOperator::Create(BinOp->getOpcode(), X, NarrowOp1);
563     }
564     if (match(BinOp1, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) {
565       // trunc (binop Y, (ext X)) --> binop (trunc Y), X
566       Value *NarrowOp0 = Builder.CreateTrunc(BinOp0, DestTy);
567       return BinaryOperator::Create(BinOp->getOpcode(), NarrowOp0, X);
568     }
569     break;
570   }
571   case Instruction::LShr:
572   case Instruction::AShr: {
573     // trunc (*shr (trunc A), C) --> trunc(*shr A, C)
574     Value *A;
575     Constant *C;
576     if (match(BinOp0, m_Trunc(m_Value(A))) && match(BinOp1, m_Constant(C))) {
577       unsigned MaxShiftAmt = SrcWidth - DestWidth;
578       // If the shift is small enough, all zero/sign bits created by the shift
579       // are removed by the trunc.
580       if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
581                                       APInt(SrcWidth, MaxShiftAmt)))) {
582         auto *OldShift = cast<Instruction>(Trunc.getOperand(0));
583         bool IsExact = OldShift->isExact();
584         if (Constant *ShAmt = ConstantFoldIntegerCast(C, A->getType(),
585                                                       /*IsSigned*/ true, DL)) {
586           ShAmt = Constant::mergeUndefsWith(ShAmt, C);
587           Value *Shift =
588               OldShift->getOpcode() == Instruction::AShr
589                   ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)
590                   : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact);
591           return CastInst::CreateTruncOrBitCast(Shift, DestTy);
592         }
593       }
594     }
595     break;
596   }
597   default: break;
598   }
599 
600   if (Instruction *NarrowOr = narrowFunnelShift(Trunc))
601     return NarrowOr;
602 
603   return nullptr;
604 }
605 
606 /// Try to narrow the width of a splat shuffle. This could be generalized to any
607 /// shuffle with a constant operand, but we limit the transform to avoid
608 /// creating a shuffle type that targets may not be able to lower effectively.
609 static Instruction *shrinkSplatShuffle(TruncInst &Trunc,
610                                        InstCombiner::BuilderTy &Builder) {
611   auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0));
612   if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) &&
613       all_equal(Shuf->getShuffleMask()) &&
614       Shuf->getType() == Shuf->getOperand(0)->getType()) {
615     // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask
616     // trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask
617     Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType());
618     return new ShuffleVectorInst(NarrowOp, Shuf->getShuffleMask());
619   }
620 
621   return nullptr;
622 }
623 
624 /// Try to narrow the width of an insert element. This could be generalized for
625 /// any vector constant, but we limit the transform to insertion into undef to
626 /// avoid potential backend problems from unsupported insertion widths. This
627 /// could also be extended to handle the case of inserting a scalar constant
628 /// into a vector variable.
629 static Instruction *shrinkInsertElt(CastInst &Trunc,
630                                     InstCombiner::BuilderTy &Builder) {
631   Instruction::CastOps Opcode = Trunc.getOpcode();
632   assert((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
633          "Unexpected instruction for shrinking");
634 
635   auto *InsElt = dyn_cast<InsertElementInst>(Trunc.getOperand(0));
636   if (!InsElt || !InsElt->hasOneUse())
637     return nullptr;
638 
639   Type *DestTy = Trunc.getType();
640   Type *DestScalarTy = DestTy->getScalarType();
641   Value *VecOp = InsElt->getOperand(0);
642   Value *ScalarOp = InsElt->getOperand(1);
643   Value *Index = InsElt->getOperand(2);
644 
645   if (match(VecOp, m_Undef())) {
646     // trunc   (inselt undef, X, Index) --> inselt undef,   (trunc X), Index
647     // fptrunc (inselt undef, X, Index) --> inselt undef, (fptrunc X), Index
648     UndefValue *NarrowUndef = UndefValue::get(DestTy);
649     Value *NarrowOp = Builder.CreateCast(Opcode, ScalarOp, DestScalarTy);
650     return InsertElementInst::Create(NarrowUndef, NarrowOp, Index);
651   }
652 
653   return nullptr;
654 }
655 
656 Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
657   if (Instruction *Result = commonCastTransforms(Trunc))
658     return Result;
659 
660   Value *Src = Trunc.getOperand(0);
661   Type *DestTy = Trunc.getType(), *SrcTy = Src->getType();
662   unsigned DestWidth = DestTy->getScalarSizeInBits();
663   unsigned SrcWidth = SrcTy->getScalarSizeInBits();
664 
665   // Attempt to truncate the entire input expression tree to the destination
666   // type.   Only do this if the dest type is a simple type, don't convert the
667   // expression tree to something weird like i93 unless the source is also
668   // strange.
669   if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
670       canEvaluateTruncated(Src, DestTy, *this, &Trunc)) {
671 
672     // If this cast is a truncate, evaluating in a different type always
673     // eliminates the cast, so it is always a win.
674     LLVM_DEBUG(
675         dbgs() << "ICE: EvaluateInDifferentType converting expression type"
676                   " to avoid cast: "
677                << Trunc << '\n');
678     Value *Res = EvaluateInDifferentType(Src, DestTy, false);
679     assert(Res->getType() == DestTy);
680     return replaceInstUsesWith(Trunc, Res);
681   }
682 
683   // For integer types, check if we can shorten the entire input expression to
684   // DestWidth * 2, which won't allow removing the truncate, but reducing the
685   // width may enable further optimizations, e.g. allowing for larger
686   // vectorization factors.
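  // For example, a trunc i64 %v to i8 whose operand tree cannot be evaluated
  // directly in i8 may still be evaluable in i16; the tree is then rewritten
  // in i16 and a trunc i16 -> i8 is emitted instead.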
687   if (auto *DestITy = dyn_cast<IntegerType>(DestTy)) {
688     if (DestWidth * 2 < SrcWidth) {
689       auto *NewDestTy = DestITy->getExtendedType();
690       if (shouldChangeType(SrcTy, NewDestTy) &&
691           canEvaluateTruncated(Src, NewDestTy, *this, &Trunc)) {
692         LLVM_DEBUG(
693             dbgs() << "ICE: EvaluateInDifferentType converting expression type"
694                       " to reduce the width of the operand of"
695                    << Trunc << '\n');
696         Value *Res = EvaluateInDifferentType(Src, NewDestTy, false);
697         return new TruncInst(Res, DestTy);
698       }
699     }
700   }
701 
702   // Test if the trunc is the user of a select which is part of a
703   // minimum or maximum operation. If so, don't do any more simplification.
704   // Even simplifying demanded bits can break the canonical form of a
705   // min/max.
706   Value *LHS, *RHS;
707   if (SelectInst *Sel = dyn_cast<SelectInst>(Src))
708     if (matchSelectPattern(Sel, LHS, RHS).Flavor != SPF_UNKNOWN)
709       return nullptr;
710 
711   // See if we can simplify any instructions used by the input whose sole
712   // purpose is to compute bits we don't care about.
713   if (SimplifyDemandedInstructionBits(Trunc))
714     return &Trunc;
715 
716   if (DestWidth == 1) {
717     Value *Zero = Constant::getNullValue(SrcTy);
718     if (DestTy->isIntegerTy()) {
719       // Canonicalize trunc x to i1 -> icmp ne (and x, 1), 0 (scalar only).
720       // TODO: We canonicalize to more instructions here because we are probably
721       // lacking equivalent analysis for trunc relative to icmp. There may also
722       // be codegen concerns. If those trunc limitations were removed, we could
723       // remove this transform.
724       Value *And = Builder.CreateAnd(Src, ConstantInt::get(SrcTy, 1));
725       return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
726     }
727 
728     // For vectors, we do not canonicalize all truncs to icmp, so optimize
729     // patterns that would be covered within visitICmpInst.
730     Value *X;
731     Constant *C;
732     if (match(Src, m_OneUse(m_LShr(m_Value(X), m_Constant(C))))) {
733       // trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0
734       Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1));
735       Constant *MaskC = ConstantExpr::getShl(One, C);
736       Value *And = Builder.CreateAnd(X, MaskC);
737       return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
738     }
739     if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_ImmConstant(C)),
740                                    m_Deferred(X))))) {
741       // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0
742       Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1));
743       Constant *MaskC = ConstantExpr::getShl(One, C);
744       Value *And = Builder.CreateAnd(X, Builder.CreateOr(MaskC, One));
745       return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
746     }
747   }
748 
749   Value *A, *B;
750   Constant *C;
751   if (match(Src, m_LShr(m_SExt(m_Value(A)), m_Constant(C)))) {
752     unsigned AWidth = A->getType()->getScalarSizeInBits();
753     unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth);
754     auto *OldSh = cast<Instruction>(Src);
755     bool IsExact = OldSh->isExact();
756 
757     // If the shift is small enough, all zero bits created by the shift are
758     // removed by the trunc.
759     if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
760                                     APInt(SrcWidth, MaxShiftAmt)))) {
761       auto GetNewShAmt = [&](unsigned Width) {
762         Constant *MaxAmt = ConstantInt::get(SrcTy, Width - 1, false);
763         Constant *Cmp =
764             ConstantFoldCompareInstOperands(ICmpInst::ICMP_ULT, C, MaxAmt, DL);
765         Constant *ShAmt = ConstantFoldSelectInstruction(Cmp, C, MaxAmt);
766         return ConstantFoldCastOperand(Instruction::Trunc, ShAmt, A->getType(),
767                                        DL);
768       };
769 
770       // trunc (lshr (sext A), C) --> ashr A, C
771       if (A->getType() == DestTy) {
772         Constant *ShAmt = GetNewShAmt(DestWidth);
773         ShAmt = Constant::mergeUndefsWith(ShAmt, C);
774         return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt)
775                        : BinaryOperator::CreateAShr(A, ShAmt);
776       }
777       // The types are mismatched, so create a cast after shifting:
778       // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C)
779       if (Src->hasOneUse()) {
780         Constant *ShAmt = GetNewShAmt(AWidth);
781         Value *Shift = Builder.CreateAShr(A, ShAmt, "", IsExact);
782         return CastInst::CreateIntegerCast(Shift, DestTy, true);
783       }
784     }
785     // TODO: Mask high bits with 'and'.
786   }
787 
788   if (Instruction *I = narrowBinOp(Trunc))
789     return I;
790 
791   if (Instruction *I = shrinkSplatShuffle(Trunc, Builder))
792     return I;
793 
794   if (Instruction *I = shrinkInsertElt(Trunc, Builder))
795     return I;
796 
797   if (Src->hasOneUse() &&
798       (isa<VectorType>(SrcTy) || shouldChangeType(SrcTy, DestTy))) {
799     // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the
800     // dest type is native and cst < dest size.
801     if (match(Src, m_Shl(m_Value(A), m_Constant(C))) &&
802         !match(A, m_Shr(m_Value(), m_Constant()))) {
803       // Skip a shift of a shift by a constant: narrowing would undo a combine in
804       // FoldShiftByConstant and break the extend-in-register pattern.
805       APInt Threshold = APInt(C->getType()->getScalarSizeInBits(), DestWidth);
806       if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold))) {
807         Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr");
808         return BinaryOperator::Create(Instruction::Shl, NewTrunc,
809                                       ConstantExpr::getTrunc(C, DestTy));
810       }
811     }
812   }
813 
814   if (Instruction *I = foldVecTruncToExtElt(Trunc, *this))
815     return I;
816 
817   // Whenever an element is extracted from a vector, and then truncated,
818   // canonicalize by converting it to a bitcast followed by an
819   // extractelement.
820   //
821   // Example (little endian):
822   //   trunc (extractelement <4 x i64> %X, 0) to i32
823   //   --->
824   //   extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
825   Value *VecOp;
826   ConstantInt *Cst;
827   if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) {
828     auto *VecOpTy = cast<VectorType>(VecOp->getType());
829     auto VecElts = VecOpTy->getElementCount();
830 
831     // A badly fit destination size would result in an invalid cast.
832     if (SrcWidth % DestWidth == 0) {
833       uint64_t TruncRatio = SrcWidth / DestWidth;
834       uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
835       uint64_t VecOpIdx = Cst->getZExtValue();
836       uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1
837                                          : VecOpIdx * TruncRatio;
838       assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
839              "overflow 32-bits");
840 
841       auto *BitCastTo =
842           VectorType::get(DestTy, BitCastNumElts, VecElts.isScalable());
843       Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo);
844       return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx));
845     }
846   }
847 
848   // trunc (ctlz_i32(zext(A), B)) --> add(ctlz_i16(A, B), C)
849   if (match(Src, m_OneUse(m_Intrinsic<Intrinsic::ctlz>(m_ZExt(m_Value(A)),
850                                                        m_Value(B))))) {
851     unsigned AWidth = A->getType()->getScalarSizeInBits();
852     if (AWidth == DestWidth && AWidth > Log2_32(SrcWidth)) {
853       Value *WidthDiff = ConstantInt::get(A->getType(), SrcWidth - AWidth);
854       Value *NarrowCtlz =
855           Builder.CreateIntrinsic(Intrinsic::ctlz, {Trunc.getType()}, {A, B});
856       return BinaryOperator::CreateAdd(NarrowCtlz, WidthDiff);
857     }
858   }
859 
860   if (match(Src, m_VScale())) {
861     if (Trunc.getFunction() &&
862         Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
863       Attribute Attr =
864           Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange);
865       if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
866         if (Log2_32(*MaxVScale) < DestWidth) {
867           Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
868           return replaceInstUsesWith(Trunc, VScale);
869         }
870       }
871     }
872   }
873 
874   return nullptr;
875 }
876 
877 Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp,
878                                                  ZExtInst &Zext) {
879   // If we are just checking for an icmp eq of a single bit and zext'ing it
880   // to an integer, then shift the bit to the appropriate place and then
881   // cast to integer to avoid the comparison.
882 
883   // FIXME: This set of transforms does not check for extra uses and/or creates
884   //        an extra instruction (an optional final cast is not included
885   //        in the transform comments). We may also want to favor icmp over
886   //        shifts in cases of equal instructions because icmp has better
887   //        analysis in general (invert the transform).
888 
889   const APInt *Op1CV;
890   if (match(Cmp->getOperand(1), m_APInt(Op1CV))) {
891 
892     // zext (x <s  0) to i32 --> x>>u31      true if signbit set.
893     if (Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isZero()) {
894       Value *In = Cmp->getOperand(0);
895       Value *Sh = ConstantInt::get(In->getType(),
896                                    In->getType()->getScalarSizeInBits() - 1);
897       In = Builder.CreateLShr(In, Sh, In->getName() + ".lobit");
898       if (In->getType() != Zext.getType())
899         In = Builder.CreateIntCast(In, Zext.getType(), false /*ZExt*/);
900 
901       return replaceInstUsesWith(Zext, In);
902     }
903 
904     // zext (X == 0) to i32 --> X^1      iff X has only the low bit set.
905     // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set.
906     // zext (X != 0) to i32 --> X        iff X has only the low bit set.
907     // zext (X != 0) to i32 --> X>>1     iff X has only the 2nd bit set.
908 
909     if (Op1CV->isZero() && Cmp->isEquality()) {
910       // Exactly 1 possible 1? But not the high-bit because that is
911       // canonicalized to this form.
912       KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext);
913       APInt KnownZeroMask(~Known.Zero);
914       uint32_t ShAmt = KnownZeroMask.logBase2();
915       bool IsExpectShAmt = KnownZeroMask.isPowerOf2() &&
916                            (Zext.getType()->getScalarSizeInBits() != ShAmt + 1);
917       if (IsExpectShAmt &&
918           (Cmp->getOperand(0)->getType() == Zext.getType() ||
919            Cmp->getPredicate() == ICmpInst::ICMP_NE || ShAmt == 0)) {
920         Value *In = Cmp->getOperand(0);
921         if (ShAmt) {
922           // Perform a logical shr by shiftamt.
923           // Insert the shift to put the result in the low bit.
924           In = Builder.CreateLShr(In, ConstantInt::get(In->getType(), ShAmt),
925                                   In->getName() + ".lobit");
926         }
927 
928         // Toggle the low bit for "X == 0".
929         if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
930           In = Builder.CreateXor(In, ConstantInt::get(In->getType(), 1));
931 
932         if (Zext.getType() == In->getType())
933           return replaceInstUsesWith(Zext, In);
934 
935         Value *IntCast = Builder.CreateIntCast(In, Zext.getType(), false);
936         return replaceInstUsesWith(Zext, IntCast);
937       }
938     }
939   }
940 
941   if (Cmp->isEquality() && Zext.getType() == Cmp->getOperand(0)->getType()) {
942     // Test if a bit is clear/set using a shifted-one mask:
943     // zext (icmp eq (and X, (1 << ShAmt)), 0) --> and (lshr (not X), ShAmt), 1
944     // zext (icmp ne (and X, (1 << ShAmt)), 0) --> and (lshr X, ShAmt), 1
945     Value *X, *ShAmt;
946     if (Cmp->hasOneUse() && match(Cmp->getOperand(1), m_ZeroInt()) &&
947         match(Cmp->getOperand(0),
948               m_OneUse(m_c_And(m_Shl(m_One(), m_Value(ShAmt)), m_Value(X))))) {
949       if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
950         X = Builder.CreateNot(X);
951       Value *Lshr = Builder.CreateLShr(X, ShAmt);
952       Value *And1 = Builder.CreateAnd(Lshr, ConstantInt::get(X->getType(), 1));
953       return replaceInstUsesWith(Zext, And1);
954     }
955   }
956 
957   return nullptr;
958 }
959 
960 /// Determine if the specified value can be computed in the specified wider type
961 /// and produce the same low bits. If not, return false.
962 ///
963 /// If this function returns true, it can also return a non-zero number of bits
964 /// (in BitsToClear) which indicates that the value it computes is correct for
965 /// the zero extend, but that the additional BitsToClear bits need to be zero'd
966 /// out.  For example, to promote something like:
967 ///
968 ///   %B = trunc i64 %A to i32
969 ///   %C = lshr i32 %B, 8
970 ///   %E = zext i32 %C to i64
971 ///
972 /// CanEvaluateZExtd for the 'lshr' will return true, and BitsToClear will be
973 /// set to 8 to indicate that the promoted value needs to have bits 24-31
974 /// cleared in addition to bits 32-63.  Since an 'and' will be generated to
975 /// clear the top bits anyway, doing this has no extra cost.
976 ///
977 /// This function works on both vectors and scalars.
978 static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
979                              InstCombinerImpl &IC, Instruction *CxtI) {
980   BitsToClear = 0;
981   if (canAlwaysEvaluateInType(V, Ty))
982     return true;
983   if (canNotEvaluateInType(V, Ty))
984     return false;
985 
986   auto *I = cast<Instruction>(V);
987   unsigned Tmp;
988   switch (I->getOpcode()) {
989   case Instruction::ZExt:  // zext(zext(x)) -> zext(x).
990   case Instruction::SExt:  // zext(sext(x)) -> sext(x).
991   case Instruction::Trunc: // zext(trunc(x)) -> trunc(x) or zext(x)
992     return true;
993   case Instruction::And:
994   case Instruction::Or:
995   case Instruction::Xor:
996   case Instruction::Add:
997   case Instruction::Sub:
998   case Instruction::Mul:
999     if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI) ||
1000         !canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI))
1001       return false;
1002     // These can all be promoted if neither operand has 'bits to clear'.
1003     if (BitsToClear == 0 && Tmp == 0)
1004       return true;
1005 
1006     // If the operation is an AND/OR/XOR and the bits to clear are zero in the
1007     // other side, BitsToClear is ok.
1008     if (Tmp == 0 && I->isBitwiseLogicOp()) {
1009       // We use MaskedValueIsZero here for generality, but the case we care
1010       // about the most is constant RHS.
1011       unsigned VSize = V->getType()->getScalarSizeInBits();
1012       if (IC.MaskedValueIsZero(I->getOperand(1),
1013                                APInt::getHighBitsSet(VSize, BitsToClear),
1014                                0, CxtI)) {
1015         // If this is an And instruction and all of the BitsToClear are
1016         // known to be zero we can reset BitsToClear.
1017         if (I->getOpcode() == Instruction::And)
1018           BitsToClear = 0;
1019         return true;
1020       }
1021     }
1022 
1023     // Otherwise, we don't know how to analyze this BitsToClear case yet.
1024     return false;
1025 
1026   case Instruction::Shl: {
1027     // We can promote shl(x, cst) if we can promote x.  Since shl overwrites the
1028     // upper bits we can reduce BitsToClear by the shift amount.
1029     const APInt *Amt;
1030     if (match(I->getOperand(1), m_APInt(Amt))) {
1031       if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
1032         return false;
1033       uint64_t ShiftAmt = Amt->getZExtValue();
1034       BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0;
1035       return true;
1036     }
1037     return false;
1038   }
1039   case Instruction::LShr: {
1040     // We can promote lshr(x, cst) if we can promote x.  This requires the
1041     // ultimate 'and' to clear out the high zero bits we're clearing out though.
1042     const APInt *Amt;
1043     if (match(I->getOperand(1), m_APInt(Amt))) {
1044       if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
1045         return false;
1046       BitsToClear += Amt->getZExtValue();
1047       if (BitsToClear > V->getType()->getScalarSizeInBits())
1048         BitsToClear = V->getType()->getScalarSizeInBits();
1049       return true;
1050     }
1051     // Cannot promote variable LSHR.
1052     return false;
1053   }
1054   case Instruction::Select:
1055     if (!canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) ||
1056         !canEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) ||
1057         // TODO: If important, we could handle the case when the BitsToClear are
1058         // known zero in the disagreeing side.
1059         Tmp != BitsToClear)
1060       return false;
1061     return true;
1062 
1063   case Instruction::PHI: {
1064     // We can change a phi if we can change all operands.  Note that we never
1065     // get into trouble with cyclic PHIs here because we only consider
1066     // instructions with a single use.
1067     PHINode *PN = cast<PHINode>(I);
1068     if (!canEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear, IC, CxtI))
1069       return false;
1070     for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i)
1071       if (!canEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp, IC, CxtI) ||
1072           // TODO: If important, we could handle the case when the BitsToClear
1073           // are known zero in the disagreeing input.
1074           Tmp != BitsToClear)
1075         return false;
1076     return true;
1077   }
1078   case Instruction::Call:
1079     // llvm.vscale() can always be executed in a larger type, because the
1080     // value is automatically zero-extended.
1081     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
1082       if (II->getIntrinsicID() == Intrinsic::vscale)
1083         return true;
1084     return false;
1085   default:
1086     // TODO: Can handle more cases here.
1087     return false;
1088   }
1089 }
1090 
1091 Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
1092   // If this zero extend is only used by a truncate, let the truncate be
1093   // eliminated before we try to optimize this zext.
1094   if (Zext.hasOneUse() && isa<TruncInst>(Zext.user_back()) &&
1095       !isa<Constant>(Zext.getOperand(0)))
1096     return nullptr;
1097 
1098   // If one of the common conversions will work, do it.
1099   if (Instruction *Result = commonCastTransforms(Zext))
1100     return Result;
1101 
1102   Value *Src = Zext.getOperand(0);
1103   Type *SrcTy = Src->getType(), *DestTy = Zext.getType();
1104 
1105   // Try to extend the entire expression tree to the wide destination type.
1106   unsigned BitsToClear;
1107   if (shouldChangeType(SrcTy, DestTy) &&
1108       canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &Zext)) {
1109     assert(BitsToClear <= SrcTy->getScalarSizeInBits() &&
1110            "Can't clear more bits than in SrcTy");
1111 
1112     // Okay, we can transform this!  Insert the new expression now.
1113     LLVM_DEBUG(
1114         dbgs() << "ICE: EvaluateInDifferentType converting expression type"
1115                   " to avoid zero extend: "
1116                << Zext << '\n');
1117     Value *Res = EvaluateInDifferentType(Src, DestTy, false);
1118     assert(Res->getType() == DestTy);
1119 
1120     // Preserve debug values referring to Src if the zext is its last use.
1121     if (auto *SrcOp = dyn_cast<Instruction>(Src))
1122       if (SrcOp->hasOneUse())
1123         replaceAllDbgUsesWith(*SrcOp, *Res, Zext, DT);
1124 
1125     uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits() - BitsToClear;
1126     uint32_t DestBitSize = DestTy->getScalarSizeInBits();
1127 
1128     // If the high bits are already filled with zeros, just replace this
1129     // cast with the result.
1130     if (MaskedValueIsZero(Res,
1131                           APInt::getHighBitsSet(DestBitSize,
1132                                                 DestBitSize - SrcBitsKept),
1133                              0, &Zext))
1134       return replaceInstUsesWith(Zext, Res);
1135 
1136     // We need to emit an AND to clear the high bits.
1137     Constant *C = ConstantInt::get(Res->getType(),
1138                                APInt::getLowBitsSet(DestBitSize, SrcBitsKept));
1139     return BinaryOperator::CreateAnd(Res, C);
1140   }
1141 
1142   // If this is a TRUNC followed by a ZEXT then we are dealing with integral
1143   // types and if the sizes are just right we can convert this into a logical
1144   // 'and' which will be much cheaper than the pair of casts.
1145   if (auto *CSrc = dyn_cast<TruncInst>(Src)) {   // A->B->C cast
1146     // TODO: Subsume this into EvaluateInDifferentType.
1147 
1148     // Get the sizes of the types involved.  We know that the intermediate type
1149     // will be smaller than A or C, but don't know the relation between A and C.
1150     Value *A = CSrc->getOperand(0);
1151     unsigned SrcSize = A->getType()->getScalarSizeInBits();
1152     unsigned MidSize = CSrc->getType()->getScalarSizeInBits();
1153     unsigned DstSize = DestTy->getScalarSizeInBits();
1154     // If we're actually extending zero bits, then if
1155     // SrcSize <  DstSize: zext(a & mask)
1156     // SrcSize == DstSize: a & mask
1157     // SrcSize  > DstSize: trunc(a) & mask
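    // For example, zext (trunc i32 %a to i8) to i16 --> and i16 (trunc %a), 255.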
1158     if (SrcSize < DstSize) {
1159       APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize));
1160       Constant *AndConst = ConstantInt::get(A->getType(), AndValue);
1161       Value *And = Builder.CreateAnd(A, AndConst, CSrc->getName() + ".mask");
1162       return new ZExtInst(And, DestTy);
1163     }
1164 
1165     if (SrcSize == DstSize) {
1166       APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize));
1167       return BinaryOperator::CreateAnd(A, ConstantInt::get(A->getType(),
1168                                                            AndValue));
1169     }
1170     if (SrcSize > DstSize) {
1171       Value *Trunc = Builder.CreateTrunc(A, DestTy);
1172       APInt AndValue(APInt::getLowBitsSet(DstSize, MidSize));
1173       return BinaryOperator::CreateAnd(Trunc,
1174                                        ConstantInt::get(Trunc->getType(),
1175                                                         AndValue));
1176     }
1177   }
1178 
1179   if (auto *Cmp = dyn_cast<ICmpInst>(Src))
1180     return transformZExtICmp(Cmp, Zext);
1181 
1182   // zext(trunc(X) & C) -> (X & zext(C)).
1183   Constant *C;
1184   Value *X;
1185   if (match(Src, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) &&
1186       X->getType() == DestTy)
1187     return BinaryOperator::CreateAnd(X, Builder.CreateZExt(C, DestTy));
1188 
1189   // zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)).
1190   Value *And;
1191   if (match(Src, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) &&
1192       match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) &&
1193       X->getType() == DestTy) {
1194     Value *ZC = Builder.CreateZExt(C, DestTy);
1195     return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC);
1196   }
1197 
1198   // If we are truncating, masking, and then zexting back to the original type,
1199   // that's just a mask. This is not handled by canEvaluateZextd if the
1200   // intermediate values have extra uses. This could be generalized further for
1201   // a non-constant mask operand.
1202   // zext (and (trunc X), C) --> and X, (zext C)
1203   if (match(Src, m_And(m_Trunc(m_Value(X)), m_Constant(C))) &&
1204       X->getType() == DestTy) {
1205     Value *ZextC = Builder.CreateZExt(C, DestTy);
1206     return BinaryOperator::CreateAnd(X, ZextC);
1207   }
1208 
1209   if (match(Src, m_VScale())) {
1210     if (Zext.getFunction() &&
1211         Zext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
1212       Attribute Attr =
1213           Zext.getFunction()->getFnAttribute(Attribute::VScaleRange);
1214       if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
1215         unsigned TypeWidth = Src->getType()->getScalarSizeInBits();
1216         if (Log2_32(*MaxVScale) < TypeWidth) {
1217           Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
1218           return replaceInstUsesWith(Zext, VScale);
1219         }
1220       }
1221     }
1222   }
1223 
1224   if (!Zext.hasNonNeg()) {
1225     // If this zero extend is only used as a shift amount, add the nneg flag.
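    // (Rationale: any shift amount that does not already make the shift poison
    // is < DestWidth, and with SrcTy wider than ceil(log2(DestWidth)) bits such
    // a value cannot have its sign bit set, so the operand is non-negative.)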
1226     if (Zext.hasOneUse() &&
1227         SrcTy->getScalarSizeInBits() >
1228             Log2_64_Ceil(DestTy->getScalarSizeInBits()) &&
1229         match(Zext.user_back(), m_Shift(m_Value(), m_Specific(&Zext)))) {
1230       Zext.setNonNeg();
1231       return &Zext;
1232     }
1233 
1234     if (isKnownNonNegative(Src, SQ.getWithInstruction(&Zext))) {
1235       Zext.setNonNeg();
1236       return &Zext;
1237     }
1238   }
1239 
1240   return nullptr;
1241 }
1242 
1243 /// Transform (sext icmp) to bitwise / integer operations to eliminate the icmp.
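/// For example, sext (icmp slt i32 %x, 0) to i32 --> ashr i32 %x, 31.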
1244 Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp,
1245                                                  SExtInst &Sext) {
1246   Value *Op0 = Cmp->getOperand(0), *Op1 = Cmp->getOperand(1);
1247   ICmpInst::Predicate Pred = Cmp->getPredicate();
1248 
1249   // Don't bother if Op1 isn't of vector or integer type.
1250   if (!Op1->getType()->isIntOrIntVectorTy())
1251     return nullptr;
1252 
1253   if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_ZeroInt())) {
1254     // sext (x <s 0) --> ashr x, 31 (all ones if negative)
1255     Value *Sh = ConstantInt::get(Op0->getType(),
1256                                  Op0->getType()->getScalarSizeInBits() - 1);
1257     Value *In = Builder.CreateAShr(Op0, Sh, Op0->getName() + ".lobit");
1258     if (In->getType() != Sext.getType())
1259       In = Builder.CreateIntCast(In, Sext.getType(), true /*SExt*/);
1260 
1261     return replaceInstUsesWith(Sext, In);
1262   }
1263 
1264   if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
1265     // If we know that only one bit of the LHS of the icmp can be set and we
1266     // have an equality comparison with zero or a power of 2, we can transform
1267     // the icmp and sext into bitwise/integer operations.
1268     if (Cmp->hasOneUse() &&
1269         Cmp->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){
1270       KnownBits Known = computeKnownBits(Op0, 0, &Sext);
1271 
1272       APInt KnownZeroMask(~Known.Zero);
1273       if (KnownZeroMask.isPowerOf2()) {
1274         Value *In = Cmp->getOperand(0);
1275 
1276         // If the icmp tests for a known zero bit we can constant fold it.
1277         if (!Op1C->isZero() && Op1C->getValue() != KnownZeroMask) {
1278           Value *V = Pred == ICmpInst::ICMP_NE ?
1279                        ConstantInt::getAllOnesValue(Sext.getType()) :
1280                        ConstantInt::getNullValue(Sext.getType());
1281           return replaceInstUsesWith(Sext, V);
1282         }
1283 
1284         if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) {
1285           // sext ((x & 2^n) == 0)   -> (x >> n) - 1
1286           // sext ((x & 2^n) != 2^n) -> (x >> n) - 1
1287           unsigned ShiftAmt = KnownZeroMask.countr_zero();
1288           // Perform a right shift to place the desired bit in the LSB.
1289           if (ShiftAmt)
1290             In = Builder.CreateLShr(In,
1291                                     ConstantInt::get(In->getType(), ShiftAmt));
1292 
1293           // At this point "In" is either 1 or 0. Subtract 1 to turn
1294           // {1, 0} -> {0, -1}.
1295           In = Builder.CreateAdd(In,
1296                                  ConstantInt::getAllOnesValue(In->getType()),
1297                                  "sext");
1298         } else {
1299           // sext ((x & 2^n) != 0)   -> (x << bitwidth-n) a>> bitwidth-1
1300           // sext ((x & 2^n) == 2^n) -> (x << bitwidth-n) a>> bitwidth-1
1301           unsigned ShiftAmt = KnownZeroMask.countl_zero();
1302           // Perform a left shift to place the desired bit in the MSB.
1303           if (ShiftAmt)
1304             In = Builder.CreateShl(In,
1305                                    ConstantInt::get(In->getType(), ShiftAmt));
1306 
1307           // Distribute the bit over the whole bit width.
1308           In = Builder.CreateAShr(In, ConstantInt::get(In->getType(),
1309                                   KnownZeroMask.getBitWidth() - 1), "sext");
1310         }
1311 
1312         if (Sext.getType() == In->getType())
1313           return replaceInstUsesWith(Sext, In);
1314         return CastInst::CreateIntegerCast(In, Sext.getType(), true/*SExt*/);
1315       }
1316     }
1317   }
1318 
1319   return nullptr;
1320 }
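
// Illustrative example (not taken from a test case) of the first fold in
// transformSExtICmp, where a sign-bit test becomes an arithmetic shift:
//   %c = icmp slt i32 %x, 0
//   %s = sext i1 %c to i32
// -->
//   %s = ashr i32 %x, 31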
1321 
1322 /// Return true if we can take the specified value and return it as type Ty
1323 /// without inserting any new casts and without changing the value of the common
1324 /// low bits.  This is used by code that tries to promote integer operations to
1325 /// a wider type so that the extension can be eliminated.
1326 ///
1327 /// This function works on both vectors and scalars.
1328 ///
1329 static bool canEvaluateSExtd(Value *V, Type *Ty) {
1330   assert(V->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits() &&
1331          "Can't sign extend type to a smaller type");
1332   if (canAlwaysEvaluateInType(V, Ty))
1333     return true;
1334   if (canNotEvaluateInType(V, Ty))
1335     return false;
1336 
1337   auto *I = cast<Instruction>(V);
1338   switch (I->getOpcode()) {
1339   case Instruction::SExt:  // sext(sext(x)) -> sext(x)
1340   case Instruction::ZExt:  // sext(zext(x)) -> zext(x)
1341   case Instruction::Trunc: // sext(trunc(x)) -> trunc(x) or sext(x)
1342     return true;
1343   case Instruction::And:
1344   case Instruction::Or:
1345   case Instruction::Xor:
1346   case Instruction::Add:
1347   case Instruction::Sub:
1348   case Instruction::Mul:
1349     // These operators can all arbitrarily be extended if their inputs can.
1350     return canEvaluateSExtd(I->getOperand(0), Ty) &&
1351            canEvaluateSExtd(I->getOperand(1), Ty);
1352 
1353   //case Instruction::Shl:   TODO
1354   //case Instruction::LShr:  TODO
1355 
1356   case Instruction::Select:
1357     return canEvaluateSExtd(I->getOperand(1), Ty) &&
1358            canEvaluateSExtd(I->getOperand(2), Ty);
1359 
1360   case Instruction::PHI: {
1361     // We can change a phi if we can change all operands.  Note that we never
1362     // get into trouble with cyclic PHIs here because we only consider
1363     // instructions with a single use.
1364     PHINode *PN = cast<PHINode>(I);
1365     for (Value *IncValue : PN->incoming_values())
1366       if (!canEvaluateSExtd(IncValue, Ty)) return false;
1367     return true;
1368   }
1369   default:
1370     // TODO: Can handle more cases here.
1371     break;
1372   }
1373 
1374   return false;
1375 }
1376 
1377 Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
1378   // If this sign extend is only used by a truncate, let the truncate be
1379   // eliminated before we try to optimize this sext.
1380   if (Sext.hasOneUse() && isa<TruncInst>(Sext.user_back()))
1381     return nullptr;
1382 
1383   if (Instruction *I = commonCastTransforms(Sext))
1384     return I;
1385 
1386   Value *Src = Sext.getOperand(0);
1387   Type *SrcTy = Src->getType(), *DestTy = Sext.getType();
1388   unsigned SrcBitSize = SrcTy->getScalarSizeInBits();
1389   unsigned DestBitSize = DestTy->getScalarSizeInBits();
1390 
1391   // If the value being extended is zero or positive, use a zext instead.
1392   if (isKnownNonNegative(Src, SQ.getWithInstruction(&Sext))) {
1393     auto CI = CastInst::Create(Instruction::ZExt, Src, DestTy);
1394     CI->setNonNeg(true);
1395     return CI;
1396   }
1397 
1398   // Try to extend the entire expression tree to the wide destination type.
1399   if (shouldChangeType(SrcTy, DestTy) && canEvaluateSExtd(Src, DestTy)) {
1400     // Okay, we can transform this!  Insert the new expression now.
1401     LLVM_DEBUG(
1402         dbgs() << "ICE: EvaluateInDifferentType converting expression type"
1403                   " to avoid sign extend: "
1404                << Sext << '\n');
1405     Value *Res = EvaluateInDifferentType(Src, DestTy, true);
1406     assert(Res->getType() == DestTy);
1407 
1408     // If the high bits are already filled with sign bit, just replace this
1409     // cast with the result.
1410     if (ComputeNumSignBits(Res, 0, &Sext) > DestBitSize - SrcBitSize)
1411       return replaceInstUsesWith(Sext, Res);
1412 
1413     // We need to emit a shl + ashr to do the sign extend.
1414     Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize);
1415     return BinaryOperator::CreateAShr(Builder.CreateShl(Res, ShAmt, "sext"),
1416                                       ShAmt);
1417   }
1418 
1419   Value *X;
1420   if (match(Src, m_Trunc(m_Value(X)))) {
1421     // If the input has more sign bits than bits truncated, then convert
1422     // directly to final type.
1423     unsigned XBitSize = X->getType()->getScalarSizeInBits();
1424     if (ComputeNumSignBits(X, 0, &Sext) > XBitSize - SrcBitSize)
1425       return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
1426 
1427     // If input is a trunc from the destination type, then convert into shifts.
1428     if (Src->hasOneUse() && X->getType() == DestTy) {
1429       // sext (trunc X) --> ashr (shl X, C), C
1430       Constant *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize);
1431       return BinaryOperator::CreateAShr(Builder.CreateShl(X, ShAmt), ShAmt);
1432     }
1433 
1434     // If we are replacing shifted-in high zero bits with sign bits, convert
1435     // the logic shift to arithmetic shift and eliminate the cast to
1436     // intermediate type:
1437     // sext (trunc (lshr Y, C)) --> sext/trunc (ashr Y, C)
1438     Value *Y;
1439     if (Src->hasOneUse() &&
1440         match(X, m_LShr(m_Value(Y),
1441                         m_SpecificIntAllowUndef(XBitSize - SrcBitSize)))) {
1442       Value *Ashr = Builder.CreateAShr(Y, XBitSize - SrcBitSize);
1443       return CastInst::CreateIntegerCast(Ashr, DestTy, /* isSigned */ true);
1444     }
1445   }
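
  // For illustration (hypothetical values; the same result may also be
  // reached through the wide-evaluation path above):
  //   %t = trunc i32 %x to i8
  //   %s = sext i8 %t to i32
  // -->
  //   %shl = shl i32 %x, 24
  //   %s   = ashr i32 %shl, 24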
1446 
1447   if (auto *Cmp = dyn_cast<ICmpInst>(Src))
1448     return transformSExtICmp(Cmp, Sext);
1449 
1450   // If the input is a shl/ashr pair with the same constant, then this is a sign
1451   // extension from a smaller value.  If we could trust arbitrary bitwidth
1452   // integers, we could turn this into a truncate to the smaller bit and then
1453   // use a sext for the whole extension.  Since we don't, look deeper and check
1454   // for a truncate.  If the source and dest are the same type, eliminate the
1455   // trunc and extend and just do shifts.  For example, turn:
1456   //   %a = trunc i32 %i to i8
1457   //   %b = shl i8 %a, C
1458   //   %c = ashr i8 %b, C
1459   //   %d = sext i8 %c to i32
1460   // into:
1461   //   %a = shl i32 %i, 32-(8-C)
1462   //   %d = ashr i32 %a, 32-(8-C)
1463   Value *A = nullptr;
1464   // TODO: Eventually this could be subsumed by EvaluateInDifferentType.
1465   Constant *BA = nullptr, *CA = nullptr;
1466   if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_Constant(BA)),
1467                         m_ImmConstant(CA))) &&
1468       BA->isElementWiseEqual(CA) && A->getType() == DestTy) {
1469     Constant *WideCurrShAmt =
1470         ConstantFoldCastOperand(Instruction::SExt, CA, DestTy, DL);
1471     assert(WideCurrShAmt && "Constant folding of ImmConstant cannot fail");
1472     Constant *NumLowbitsLeft = ConstantExpr::getSub(
1473         ConstantInt::get(DestTy, SrcTy->getScalarSizeInBits()), WideCurrShAmt);
1474     Constant *NewShAmt = ConstantExpr::getSub(
1475         ConstantInt::get(DestTy, DestTy->getScalarSizeInBits()),
1476         NumLowbitsLeft);
1477     NewShAmt =
1478         Constant::mergeUndefsWith(Constant::mergeUndefsWith(NewShAmt, BA), CA);
1479     A = Builder.CreateShl(A, NewShAmt, Sext.getName());
1480     return BinaryOperator::CreateAShr(A, NewShAmt);
1481   }
1482 
1483   // Splatting a bit of constant-index across a value:
1484   // sext (ashr (trunc iN X to iM), M-1) to iN --> ashr (shl X, N-M), N-1
1485   // If the dest type is different, use a cast (adjust use check).
1486   if (match(Src, m_OneUse(m_AShr(m_Trunc(m_Value(X)),
1487                                  m_SpecificInt(SrcBitSize - 1))))) {
1488     Type *XTy = X->getType();
1489     unsigned XBitSize = XTy->getScalarSizeInBits();
1490     Constant *ShlAmtC = ConstantInt::get(XTy, XBitSize - SrcBitSize);
1491     Constant *AshrAmtC = ConstantInt::get(XTy, XBitSize - 1);
1492     if (XTy == DestTy)
1493       return BinaryOperator::CreateAShr(Builder.CreateShl(X, ShlAmtC),
1494                                         AshrAmtC);
1495     if (cast<BinaryOperator>(Src)->getOperand(0)->hasOneUse()) {
1496       Value *Ashr = Builder.CreateAShr(Builder.CreateShl(X, ShlAmtC), AshrAmtC);
1497       return CastInst::CreateIntegerCast(Ashr, DestTy, /* isSigned */ true);
1498     }
1499   }
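
  // Concrete instance of the bit-splat fold above (illustrative only):
  //   %t = trunc i32 %x to i8
  //   %a = ashr i8 %t, 7
  //   %s = sext i8 %a to i32
  // -->
  //   %shl = shl i32 %x, 24
  //   %s   = ashr i32 %shl, 31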
1500 
1501   if (match(Src, m_VScale())) {
1502     if (Sext.getFunction() &&
1503         Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
1504       Attribute Attr =
1505           Sext.getFunction()->getFnAttribute(Attribute::VScaleRange);
1506       if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
1507         if (Log2_32(*MaxVScale) < (SrcBitSize - 1)) {
1508           Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
1509           return replaceInstUsesWith(Sext, VScale);
1510         }
1511       }
1512     }
1513   }
1514 
1515   return nullptr;
1516 }
1517 
1518 /// Return true if the specified floating-point constant fits in the FP type
1519 /// described by Sem without changing its value.
1520 static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
1521   bool losesInfo;
1522   APFloat F = CFP->getValueAPF();
1523   (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo);
1524   return !losesInfo;
1525 }
1526 
1527 static Type *shrinkFPConstant(ConstantFP *CFP) {
1528   if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
1529     return nullptr;  // No constant folding of this.
1530   // See if the value can be truncated to half and then reextended.
1531   if (fitsInFPType(CFP, APFloat::IEEEhalf()))
1532     return Type::getHalfTy(CFP->getContext());
1533   // See if the value can be truncated to float and then reextended.
1534   if (fitsInFPType(CFP, APFloat::IEEEsingle()))
1535     return Type::getFloatTy(CFP->getContext());
1536   if (CFP->getType()->isDoubleTy())
1537     return nullptr;  // Won't shrink.
1538   if (fitsInFPType(CFP, APFloat::IEEEdouble()))
1539     return Type::getDoubleTy(CFP->getContext());
1540   // Don't try to shrink to various long double types.
1541   return nullptr;
1542 }
1543 
1544 // Determine if this is a vector of ConstantFPs and if so, return the minimal
1545 // type we can safely truncate all elements to.
1546 static Type *shrinkFPConstantVector(Value *V) {
1547   auto *CV = dyn_cast<Constant>(V);
1548   auto *CVVTy = dyn_cast<FixedVectorType>(V->getType());
1549   if (!CV || !CVVTy)
1550     return nullptr;
1551 
1552   Type *MinType = nullptr;
1553 
1554   unsigned NumElts = CVVTy->getNumElements();
1555 
1556   // For fixed-width vectors we find the minimal type by looking
1557   // through the constant values of the vector.
1558   for (unsigned i = 0; i != NumElts; ++i) {
1559     if (isa<UndefValue>(CV->getAggregateElement(i)))
1560       continue;
1561 
1562     auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i));
1563     if (!CFP)
1564       return nullptr;
1565 
1566     Type *T = shrinkFPConstant(CFP);
1567     if (!T)
1568       return nullptr;
1569 
1570     // If we haven't found a type yet or this type has a larger mantissa than
1571     // our previous type, this is our new minimal type.
1572     if (!MinType || T->getFPMantissaWidth() > MinType->getFPMantissaWidth())
1573       MinType = T;
1574   }
1575 
1576   // Make a vector type from the minimal type.
1577   return MinType ? FixedVectorType::get(MinType, NumElts) : nullptr;
1578 }
1579 
1580 /// Find the minimum FP type we can safely truncate to.
1581 static Type *getMinimumFPType(Value *V) {
1582   if (auto *FPExt = dyn_cast<FPExtInst>(V))
1583     return FPExt->getOperand(0)->getType();
1584 
1585   // If this value is a constant, return the constant in the smallest FP type
1586   // that can accurately represent it.  This allows us to turn
1587   // (float)((double)X+2.0) into x+2.0f.
1588   if (auto *CFP = dyn_cast<ConstantFP>(V))
1589     if (Type *T = shrinkFPConstant(CFP))
1590       return T;
1591 
1592   // We can only correctly find a minimum type for a scalable vector when it is
1593   // a splat. For splats of constant values the fpext is wrapped up as a
1594   // ConstantExpr.
1595   if (auto *FPCExt = dyn_cast<ConstantExpr>(V))
1596     if (FPCExt->getOpcode() == Instruction::FPExt)
1597       return FPCExt->getOperand(0)->getType();
1598 
1599   // Try to shrink a vector of FP constants. This returns nullptr on scalable
1600   // vectors.
1601   if (Type *T = shrinkFPConstantVector(V))
1602     return T;
1603 
1604   return V->getType();
1605 }
1606 
1607 /// Return true if the cast from integer to FP can be proven to be exact for all
1608 /// possible inputs (the conversion does not lose any precision).
1609 static bool isKnownExactCastIntToFP(CastInst &I, InstCombinerImpl &IC) {
1610   CastInst::CastOps Opcode = I.getOpcode();
1611   assert((Opcode == CastInst::SIToFP || Opcode == CastInst::UIToFP) &&
1612          "Unexpected cast");
1613   Value *Src = I.getOperand(0);
1614   Type *SrcTy = Src->getType();
1615   Type *FPTy = I.getType();
1616   bool IsSigned = Opcode == Instruction::SIToFP;
1617   int SrcSize = (int)SrcTy->getScalarSizeInBits() - IsSigned;
1618 
1619   // Easy case - if the source integer type has fewer bits than the FP mantissa,
1620   // then the cast must be exact.
1621   int DestNumSigBits = FPTy->getFPMantissaWidth();
1622   if (SrcSize <= DestNumSigBits)
1623     return true;
1624 
1625   // Cast from FP to integer and back to FP is independent of the intermediate
1626   // integer width because of poison on overflow.
1627   Value *F;
1628   if (match(Src, m_FPToSI(m_Value(F))) || match(Src, m_FPToUI(m_Value(F)))) {
1629     // If this is uitofp (fptosi F), the source needs an extra bit to avoid
1630     // potential rounding of negative FP input values.
1631     int SrcNumSigBits = F->getType()->getFPMantissaWidth();
1632     if (!IsSigned && match(Src, m_FPToSI(m_Value())))
1633       SrcNumSigBits++;
1634 
1635     // [su]itofp (fpto[su]i F) --> exact if the source type has no more
1636     // significant bits than the destination (and make sure neither type is
1637     // weird -- ppc_fp128).
1638     if (SrcNumSigBits > 0 && DestNumSigBits > 0 &&
1639         SrcNumSigBits <= DestNumSigBits)
1640       return true;
1641   }
1642 
1643   // TODO:
1644   // Try harder to find if the source integer type has less significant bits.
1645   // For example, compute number of sign bits.
1646   KnownBits SrcKnown = IC.computeKnownBits(Src, 0, &I);
1647   int SigBits = (int)SrcTy->getScalarSizeInBits() -
1648                 SrcKnown.countMinLeadingZeros() -
1649                 SrcKnown.countMinTrailingZeros();
1650   if (SigBits <= DestNumSigBits)
1651     return true;
1652 
1653   return false;
1654 }
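
// Worked examples for isKnownExactCastIntToFP (illustrative): sitofp i16 to
// float is always exact because 15 value bits <= float's 24-bit significand,
// whereas uitofp i32 to float is not (32 > 24) unless the known bits of the
// source reduce its effective width.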
1655 
1656 Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
1657   if (Instruction *I = commonCastTransforms(FPT))
1658     return I;
1659 
1660   // If we have fptrunc(OpI (fpextend x), (fpextend y)), we would like to
1661   // simplify this expression to avoid one or more of the trunc/extend
1662   // operations if we can do so without changing the numerical results.
1663   //
1664   // The exact manner in which the widths of the operands interact to limit
1665   // what we can and cannot do safely varies from operation to operation, and
1666   // is explained below in the various case statements.
1667   Type *Ty = FPT.getType();
1668   auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
1669   if (BO && BO->hasOneUse()) {
1670     Type *LHSMinType = getMinimumFPType(BO->getOperand(0));
1671     Type *RHSMinType = getMinimumFPType(BO->getOperand(1));
1672     unsigned OpWidth = BO->getType()->getFPMantissaWidth();
1673     unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
1674     unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
1675     unsigned SrcWidth = std::max(LHSWidth, RHSWidth);
1676     unsigned DstWidth = Ty->getFPMantissaWidth();
1677     switch (BO->getOpcode()) {
1678       default: break;
1679       case Instruction::FAdd:
1680       case Instruction::FSub:
1681         // For addition and subtraction, the infinitely precise result can
1682         // essentially be arbitrarily wide; proving that double rounding
1683         // will not occur because the result of OpI is exact (as we will for
1684         // FMul, for example) is hopeless.  However, we *can* nonetheless
1685         // frequently know that double rounding cannot occur (or that it is
1686         // innocuous) by taking advantage of the specific structure of
1687         // infinitely-precise results that admit double rounding.
1688         //
1689         // Specifically, if OpWidth >= 2*DstWidth+1 and DstWidth is sufficient
1690         // to represent both sources, we can guarantee that the double
1691         // rounding is innocuous (See p50 of Figueroa's 2000 PhD thesis,
1692         // "A Rigorous Framework for Fully Supporting the IEEE Standard ..."
1693         // for proof of this fact).
1694         //
1695         // Note: Figueroa does not consider the case where DstFormat !=
1696         // SrcFormat.  It's possible (likely even!) that this analysis
1697         // could be tightened for those cases, but they are rare (the main
1698         // case of interest here is (float)((double)float + float)).
1699         if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) {
1700           Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
1701           Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
1702           Instruction *RI = BinaryOperator::Create(BO->getOpcode(), LHS, RHS);
1703           RI->copyFastMathFlags(BO);
1704           return RI;
1705         }
1706         break;
1707       case Instruction::FMul:
1708         // For multiplication, the infinitely precise result has at most
1709         // LHSWidth + RHSWidth significant bits; if OpWidth is sufficient
1710         // that such a value can be exactly represented, then no double
1711         // rounding can possibly occur; we can safely perform the operation
1712         // in the destination format if it can represent both sources.
1713         if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) {
1714           Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
1715           Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
1716           return BinaryOperator::CreateFMulFMF(LHS, RHS, BO);
1717         }
1718         break;
1719       case Instruction::FDiv:
1720         // For division, we again use the bound from Figueroa's
1721         // dissertation.  I am entirely certain that this bound can be
1722         // tightened in the unbalanced operand case by an analysis based on
1723         // the diophantine rational approximation bound, but the well-known
1724         // condition used here is a good conservative first pass.
1725         // TODO: Tighten bound via rigorous analysis of the unbalanced case.
1726         if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) {
1727           Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
1728           Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
1729           return BinaryOperator::CreateFDivFMF(LHS, RHS, BO);
1730         }
1731         break;
1732       case Instruction::FRem: {
1733         // Remainder is straightforward.  Remainder is always exact, so the
1734         // type of OpI doesn't enter into things at all.  We simply evaluate
1735         // in whichever source type is larger, then convert to the
1736         // destination type.
1737         if (SrcWidth == OpWidth)
1738           break;
1739         Value *LHS, *RHS;
1740         if (LHSWidth == SrcWidth) {
1741            LHS = Builder.CreateFPTrunc(BO->getOperand(0), LHSMinType);
1742            RHS = Builder.CreateFPTrunc(BO->getOperand(1), LHSMinType);
1743         } else {
1744            LHS = Builder.CreateFPTrunc(BO->getOperand(0), RHSMinType);
1745            RHS = Builder.CreateFPTrunc(BO->getOperand(1), RHSMinType);
1746         }
1747 
1748         Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, BO);
1749         return CastInst::CreateFPCast(ExactResult, Ty);
1750       }
1751     }
1752   }
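
  // For example (illustrative, the "(float)((double)x + 2.0)" case mentioned
  // in getMinimumFPType):
  //   %ext = fpext float %x to double
  //   %add = fadd double %ext, 2.000000e+00
  //   %r   = fptrunc double %add to float
  // -->
  //   %r = fadd float %x, 2.000000e+00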
1753 
1754   // (fptrunc (fneg x)) -> (fneg (fptrunc x))
1755   Value *X;
1756   Instruction *Op = dyn_cast<Instruction>(FPT.getOperand(0));
1757   if (Op && Op->hasOneUse()) {
1758     // FIXME: The FMF should propagate from the fptrunc, not the source op.
1759     IRBuilder<>::FastMathFlagGuard FMFG(Builder);
1760     if (isa<FPMathOperator>(Op))
1761       Builder.setFastMathFlags(Op->getFastMathFlags());
1762 
1763     if (match(Op, m_FNeg(m_Value(X)))) {
1764       Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty);
1765 
1766       return UnaryOperator::CreateFNegFMF(InnerTrunc, Op);
1767     }
1768 
1769     // If we are truncating a select that has an extended operand, we can
1770     // narrow the other operand and do the select as a narrow op.
1771     Value *Cond, *X, *Y;
1772     if (match(Op, m_Select(m_Value(Cond), m_FPExt(m_Value(X)), m_Value(Y))) &&
1773         X->getType() == Ty) {
1774       // fptrunc (select Cond, (fpext X), Y) --> select Cond, X, (fptrunc Y)
1775       Value *NarrowY = Builder.CreateFPTrunc(Y, Ty);
1776       Value *Sel = Builder.CreateSelect(Cond, X, NarrowY, "narrow.sel", Op);
1777       return replaceInstUsesWith(FPT, Sel);
1778     }
1779     if (match(Op, m_Select(m_Value(Cond), m_Value(Y), m_FPExt(m_Value(X)))) &&
1780         X->getType() == Ty) {
1781       // fptrunc (select Cond, Y, (fpext X)) --> select Cond, (fptrunc Y), X
1782       Value *NarrowY = Builder.CreateFPTrunc(Y, Ty);
1783       Value *Sel = Builder.CreateSelect(Cond, NarrowY, X, "narrow.sel", Op);
1784       return replaceInstUsesWith(FPT, Sel);
1785     }
1786   }
1787 
1788   if (auto *II = dyn_cast<IntrinsicInst>(FPT.getOperand(0))) {
1789     switch (II->getIntrinsicID()) {
1790     default: break;
1791     case Intrinsic::ceil:
1792     case Intrinsic::fabs:
1793     case Intrinsic::floor:
1794     case Intrinsic::nearbyint:
1795     case Intrinsic::rint:
1796     case Intrinsic::round:
1797     case Intrinsic::roundeven:
1798     case Intrinsic::trunc: {
1799       Value *Src = II->getArgOperand(0);
1800       if (!Src->hasOneUse())
1801         break;
1802 
1803       // Except for fabs, this transformation requires the input of the unary FP
1804       // operation to be itself an fpext from the type to which we're
1805       // truncating.
1806       if (II->getIntrinsicID() != Intrinsic::fabs) {
1807         FPExtInst *FPExtSrc = dyn_cast<FPExtInst>(Src);
1808         if (!FPExtSrc || FPExtSrc->getSrcTy() != Ty)
1809           break;
1810       }
1811 
1812       // Do unary FP operation on smaller type.
1813       // (fptrunc (fabs x)) -> (fabs (fptrunc x))
1814       Value *InnerTrunc = Builder.CreateFPTrunc(Src, Ty);
1815       Function *Overload = Intrinsic::getDeclaration(FPT.getModule(),
1816                                                      II->getIntrinsicID(), Ty);
1817       SmallVector<OperandBundleDef, 1> OpBundles;
1818       II->getOperandBundlesAsDefs(OpBundles);
1819       CallInst *NewCI =
1820           CallInst::Create(Overload, {InnerTrunc}, OpBundles, II->getName());
1821       NewCI->copyFastMathFlags(II);
1822       return NewCI;
1823     }
1824     }
1825   }
1826 
1827   if (Instruction *I = shrinkInsertElt(FPT, Builder))
1828     return I;
1829 
1830   Value *Src = FPT.getOperand(0);
1831   if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) {
1832     auto *FPCast = cast<CastInst>(Src);
1833     if (isKnownExactCastIntToFP(*FPCast, *this))
1834       return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty);
1835   }
1836 
1837   return nullptr;
1838 }
1839 
1840 Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) {
1841   // If the source operand is a cast from integer to FP and known exact, then
1842   // cast the integer operand directly to the destination type.
1843   Type *Ty = FPExt.getType();
1844   Value *Src = FPExt.getOperand(0);
1845   if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) {
1846     auto *FPCast = cast<CastInst>(Src);
1847     if (isKnownExactCastIntToFP(*FPCast, *this))
1848       return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty);
1849   }
1850 
1851   return commonCastTransforms(FPExt);
1852 }
1853 
1854 /// fpto{s/u}i({u/s}itofp(X)) --> X or zext(X) or sext(X) or trunc(X)
1855 /// This is safe if the intermediate type has enough bits in its mantissa to
1856 /// accurately represent all values of X.  For example, this won't work with
1857 /// i64 -> float -> i64.
1858 Instruction *InstCombinerImpl::foldItoFPtoI(CastInst &FI) {
1859   if (!isa<UIToFPInst>(FI.getOperand(0)) && !isa<SIToFPInst>(FI.getOperand(0)))
1860     return nullptr;
1861 
1862   auto *OpI = cast<CastInst>(FI.getOperand(0));
1863   Value *X = OpI->getOperand(0);
1864   Type *XType = X->getType();
1865   Type *DestType = FI.getType();
1866   bool IsOutputSigned = isa<FPToSIInst>(FI);
1867 
1868   // Since we can assume the conversion won't overflow, our decision as to
1869   // whether the input will fit in the float should depend on the minimum
1870   // of the input range and output range.
1871 
1872   // This means this is also safe for a signed input and unsigned output, since
1873   // a negative input would lead to undefined behavior.
1874   if (!isKnownExactCastIntToFP(*OpI, *this)) {
1875     // The first cast may not round exactly based on the source integer width
1876     // and FP width, but the overflow UB rules can still allow this to fold.
1877     // If the destination type is narrow, that means the intermediate FP value
1878     // must be large enough to hold the source value exactly.
1879     // For example, (uint8_t)(float)(uint32_t)16777217 is undefined behavior.
1880     int OutputSize = (int)DestType->getScalarSizeInBits();
1881     if (OutputSize > OpI->getType()->getFPMantissaWidth())
1882       return nullptr;
1883   }
1884 
1885   if (DestType->getScalarSizeInBits() > XType->getScalarSizeInBits()) {
1886     bool IsInputSigned = isa<SIToFPInst>(OpI);
1887     if (IsInputSigned && IsOutputSigned)
1888       return new SExtInst(X, DestType);
1889     return new ZExtInst(X, DestType);
1890   }
1891   if (DestType->getScalarSizeInBits() < XType->getScalarSizeInBits())
1892     return new TruncInst(X, DestType);
1893 
1894   assert(XType == DestType && "Unexpected types for int to FP to int casts");
1895   return replaceInstUsesWith(FI, X);
1896 }
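
// Typical instances of foldItoFPtoI (illustrative sketches):
//   fptosi (sitofp i8 %x to float) to i32   -->  sext i8 %x to i32
//   fptoui (uitofp i32 %x to double) to i16 -->  trunc i32 %x to i16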
1897 
1898 Instruction *InstCombinerImpl::visitFPToUI(FPToUIInst &FI) {
1899   if (Instruction *I = foldItoFPtoI(FI))
1900     return I;
1901 
1902   return commonCastTransforms(FI);
1903 }
1904 
1905 Instruction *InstCombinerImpl::visitFPToSI(FPToSIInst &FI) {
1906   if (Instruction *I = foldItoFPtoI(FI))
1907     return I;
1908 
1909   return commonCastTransforms(FI);
1910 }
1911 
1912 Instruction *InstCombinerImpl::visitUIToFP(CastInst &CI) {
1913   return commonCastTransforms(CI);
1914 }
1915 
1916 Instruction *InstCombinerImpl::visitSIToFP(CastInst &CI) {
1917   return commonCastTransforms(CI);
1918 }
1919 
1920 Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) {
1921   // If the source integer type is not the intptr_t type for this target, do a
1922   // trunc or zext to the intptr_t type, then inttoptr of it.  This allows the
1923   // cast to be exposed to other transforms.
1924   unsigned AS = CI.getAddressSpace();
1925   if (CI.getOperand(0)->getType()->getScalarSizeInBits() !=
1926       DL.getPointerSizeInBits(AS)) {
1927     Type *Ty = CI.getOperand(0)->getType()->getWithNewType(
1928         DL.getIntPtrType(CI.getContext(), AS));
1929     Value *P = Builder.CreateZExtOrTrunc(CI.getOperand(0), Ty);
1930     return new IntToPtrInst(P, CI.getType());
1931   }
1932 
1933   if (Instruction *I = commonCastTransforms(CI))
1934     return I;
1935 
1936   return nullptr;
1937 }
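
// Example of the intptr_t canonicalization above, assuming a target with
// 64-bit pointers:
//   %p = inttoptr i32 %x to ptr
// -->
//   %w = zext i32 %x to i64
//   %p = inttoptr i64 %w to ptr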
1938 
1939 Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
1940   // If the destination integer type is not the intptr_t type for this target,
1941   // do a ptrtoint to intptr_t then do a trunc or zext.  This allows the cast
1942   // to be exposed to other transforms.
1943   Value *SrcOp = CI.getPointerOperand();
1944   Type *SrcTy = SrcOp->getType();
1945   Type *Ty = CI.getType();
1946   unsigned AS = CI.getPointerAddressSpace();
1947   unsigned TySize = Ty->getScalarSizeInBits();
1948   unsigned PtrSize = DL.getPointerSizeInBits(AS);
1949   if (TySize != PtrSize) {
1950     Type *IntPtrTy =
1951         SrcTy->getWithNewType(DL.getIntPtrType(CI.getContext(), AS));
1952     Value *P = Builder.CreatePtrToInt(SrcOp, IntPtrTy);
1953     return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
1954   }
1955 
1956   // (ptrtoint (ptrmask P, M))
1957   //    -> (and (ptrtoint P), M)
1958   // This is generally beneficial as `and` is better supported than `ptrmask`.
1959   Value *Ptr, *Mask;
1960   if (match(SrcOp, m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(Ptr),
1961                                                             m_Value(Mask)))) &&
1962       Mask->getType() == Ty)
1963     return BinaryOperator::CreateAnd(Builder.CreatePtrToInt(Ptr, Ty), Mask);
1964 
1965   if (auto *GEP = dyn_cast<GetElementPtrInst>(SrcOp)) {
1966     // Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use.
1967     // While this can increase the number of instructions it doesn't actually
1968     // increase the overall complexity since the arithmetic is just part of
1969     // the GEP otherwise.
1970     if (GEP->hasOneUse() &&
1971         isa<ConstantPointerNull>(GEP->getPointerOperand())) {
1972       return replaceInstUsesWith(CI,
1973                                  Builder.CreateIntCast(EmitGEPOffset(GEP), Ty,
1974                                                        /*isSigned=*/false));
1975     }
1976   }
1977 
1978   Value *Vec, *Scalar, *Index;
1979   if (match(SrcOp, m_OneUse(m_InsertElt(m_IntToPtr(m_Value(Vec)),
1980                                         m_Value(Scalar), m_Value(Index)))) &&
1981       Vec->getType() == Ty) {
1982     assert(Vec->getType()->getScalarSizeInBits() == PtrSize && "Wrong type");
1983     // Convert the scalar to int followed by insert to eliminate one cast:
1984     // p2i (ins (i2p Vec), Scalar, Index --> ins Vec, (p2i Scalar), Index
1985     Value *NewCast = Builder.CreatePtrToInt(Scalar, Ty->getScalarType());
1986     return InsertElementInst::Create(Vec, NewCast, Index);
1987   }
1988 
1989   return commonCastTransforms(CI);
1990 }
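
// Illustrative example of the ptrmask fold in visitPtrToInt (64-bit pointers
// assumed):
//   %m = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -16)
//   %i = ptrtoint ptr %m to i64
// -->
//   %pi = ptrtoint ptr %p to i64
//   %i  = and i64 %pi, -16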
1991 
1992 /// This input value (which is known to have vector type) is being zero extended
1993 /// or truncated to the specified vector type. Since the zext/trunc is done
1994 /// using an integer type, we have a (bitcast(cast(bitcast))) pattern, so
1995 /// endianness will impact which end of the vector is extended or
1996 /// truncated.
1997 ///
1998 /// A vector is always stored with index 0 at the lowest address, which
1999 /// corresponds to the most significant bits for a big endian stored integer and
2000 /// the least significant bits for little endian. A trunc/zext of an integer
2001 /// impacts the big end of the integer. Thus, we need to add/remove elements at
2002 /// the front of the vector for big endian targets, and the back of the vector
2003 /// for little endian targets.
2004 ///
2005 /// Try to replace it with a shuffle (and vector/vector bitcast) if possible.
2006 ///
2007 /// The source and destination vector types may have different element types.
2008 static Instruction *
2009 optimizeVectorResizeWithIntegerBitCasts(Value *InVal, VectorType *DestTy,
2010                                         InstCombinerImpl &IC) {
2011   // We can only do this optimization if the output is a multiple of the input
2012   // element size, or the input is a multiple of the output element size.
2013   // Convert the input type to have the same element type as the output.
2014   VectorType *SrcTy = cast<VectorType>(InVal->getType());
2015 
2016   if (SrcTy->getElementType() != DestTy->getElementType()) {
2017     // The input types don't need to be identical, but for now they must be the
2018     // same size.  There is no specific reason we couldn't handle things like
2019     // <4 x i16> -> <4 x i32> by bitcasting to <2 x i32> but haven't gotten
2020     // there yet.
2021     if (SrcTy->getElementType()->getPrimitiveSizeInBits() !=
2022         DestTy->getElementType()->getPrimitiveSizeInBits())
2023       return nullptr;
2024 
2025     SrcTy =
2026         FixedVectorType::get(DestTy->getElementType(),
2027                              cast<FixedVectorType>(SrcTy)->getNumElements());
2028     InVal = IC.Builder.CreateBitCast(InVal, SrcTy);
2029   }
2030 
2031   bool IsBigEndian = IC.getDataLayout().isBigEndian();
2032   unsigned SrcElts = cast<FixedVectorType>(SrcTy)->getNumElements();
2033   unsigned DestElts = cast<FixedVectorType>(DestTy)->getNumElements();
2034 
2035   assert(SrcElts != DestElts && "Element counts should be different.");
2036 
2037   // Now that the element types match, get the shuffle mask and RHS of the
2038   // shuffle to use, which depends on whether we're increasing or decreasing the
2039   // size of the input.
2040   auto ShuffleMaskStorage = llvm::to_vector<16>(llvm::seq<int>(0, SrcElts));
2041   ArrayRef<int> ShuffleMask;
2042   Value *V2;
2043 
2044   if (SrcElts > DestElts) {
2045     // If we're shrinking the number of elements (rewriting an integer
2046     // truncate), just shuffle in the elements corresponding to the least
2047     // significant bits from the input and use poison as the second shuffle
2048     // input.
2049     V2 = PoisonValue::get(SrcTy);
2050     // Make sure the shuffle mask selects the "least significant bits" by
2051     // keeping elements from back of the src vector for big endian, and from the
2052     // front for little endian.
2053     ShuffleMask = ShuffleMaskStorage;
2054     if (IsBigEndian)
2055       ShuffleMask = ShuffleMask.take_back(DestElts);
2056     else
2057       ShuffleMask = ShuffleMask.take_front(DestElts);
2058   } else {
2059     // If we're increasing the number of elements (rewriting an integer zext),
2060     // shuffle in all of the elements from InVal. Fill the rest of the result
2061     // elements with zeros from a constant zero.
2062     V2 = Constant::getNullValue(SrcTy);
2063     // Use first elt from V2 when indicating zero in the shuffle mask.
2064     uint32_t NullElt = SrcElts;
2065     // Extend with null values in the "most significant bits" by adding elements
2066     // in front of the src vector for big endian, and at the back for little
2067     // endian.
2068     unsigned DeltaElts = DestElts - SrcElts;
2069     if (IsBigEndian)
2070       ShuffleMaskStorage.insert(ShuffleMaskStorage.begin(), DeltaElts, NullElt);
2071     else
2072       ShuffleMaskStorage.append(DeltaElts, NullElt);
2073     ShuffleMask = ShuffleMaskStorage;
2074   }
2075 
2076   return new ShuffleVectorInst(InVal, V2, ShuffleMask);
2077 }
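
// Small illustration (hypothetical types, little-endian target): shrinking
// <4 x i32> to <2 x i32> through an integer truncate keeps the two low-order
// elements, so the whole bitcast/trunc/bitcast chain becomes:
//   %r = shufflevector <4 x i32> %in, <4 x i32> poison, <2 x i32> <i32 0, i32 1>
// On a big-endian target the mask would select elements 2 and 3 instead.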
2078 
2079 static bool isMultipleOfTypeSize(unsigned Value, Type *Ty) {
2080   return Value % Ty->getPrimitiveSizeInBits() == 0;
2081 }
2082 
2083 static unsigned getTypeSizeIndex(unsigned Value, Type *Ty) {
2084   return Value / Ty->getPrimitiveSizeInBits();
2085 }
2086 
2087 /// V is a value which is inserted into a vector of VecEltTy.
2088 /// Look through the value to see if we can decompose it into
2089 /// insertions into the vector.  See the example in the comment for
2090 /// OptimizeIntegerToVectorInsertions for the pattern this handles.
2091 /// The type of V is always a non-zero multiple of VecEltTy's size.
2092 /// Shift is the number of bits between the lsb of V and the lsb of
2093 /// the vector.
2094 ///
2095 /// This returns false if the pattern can't be matched or true if it can,
2096 /// filling in Elements with the elements found here.
2097 static bool collectInsertionElements(Value *V, unsigned Shift,
2098                                      SmallVectorImpl<Value *> &Elements,
2099                                      Type *VecEltTy, bool isBigEndian) {
2100   assert(isMultipleOfTypeSize(Shift, VecEltTy) &&
2101          "Shift should be a multiple of the element type size");
2102 
2103   // Undef values never contribute useful bits to the result.
2104   if (isa<UndefValue>(V)) return true;
2105 
2106   // If we got down to a value of the right type, we win; try inserting into the
2107   // right element.
2108   if (V->getType() == VecEltTy) {
2109     // Inserting null doesn't actually insert any elements.
2110     if (Constant *C = dyn_cast<Constant>(V))
2111       if (C->isNullValue())
2112         return true;
2113 
2114     unsigned ElementIndex = getTypeSizeIndex(Shift, VecEltTy);
2115     if (isBigEndian)
2116       ElementIndex = Elements.size() - ElementIndex - 1;
2117 
2118     // Fail if multiple elements are inserted into this slot.
2119     if (Elements[ElementIndex])
2120       return false;
2121 
2122     Elements[ElementIndex] = V;
2123     return true;
2124   }
2125 
2126   if (Constant *C = dyn_cast<Constant>(V)) {
2127     // Figure out the # elements this provides, and bitcast it or slice it up
2128     // as required.
2129     unsigned NumElts = getTypeSizeIndex(C->getType()->getPrimitiveSizeInBits(),
2130                                         VecEltTy);
2131     // If the constant is the size of a vector element, we just need to bitcast
2132     // it to the right type so it gets properly inserted.
2133     if (NumElts == 1)
2134       return collectInsertionElements(ConstantExpr::getBitCast(C, VecEltTy),
2135                                       Shift, Elements, VecEltTy, isBigEndian);
2136 
2137     // Okay, this is a constant that covers multiple elements.  Slice it up into
2138     // pieces and insert each element-sized piece into the vector.
2139     if (!isa<IntegerType>(C->getType()))
2140       C = ConstantExpr::getBitCast(C, IntegerType::get(V->getContext(),
2141                                        C->getType()->getPrimitiveSizeInBits()));
2142     unsigned ElementSize = VecEltTy->getPrimitiveSizeInBits();
2143     Type *ElementIntTy = IntegerType::get(C->getContext(), ElementSize);
2144 
2145     for (unsigned i = 0; i != NumElts; ++i) {
2146       unsigned ShiftI = Shift + i * ElementSize;
2147       Constant *Piece = ConstantFoldBinaryInstruction(
2148           Instruction::LShr, C, ConstantInt::get(C->getType(), ShiftI));
2149       if (!Piece)
2150         return false;
2151 
2152       Piece = ConstantExpr::getTrunc(Piece, ElementIntTy);
2153       if (!collectInsertionElements(Piece, ShiftI, Elements, VecEltTy,
2154                                     isBigEndian))
2155         return false;
2156     }
2157     return true;
2158   }
2159 
2160   if (!V->hasOneUse()) return false;
2161 
2162   Instruction *I = dyn_cast<Instruction>(V);
2163   if (!I) return false;
2164   switch (I->getOpcode()) {
2165   default: return false; // Unhandled case.
2166   case Instruction::BitCast:
2167     if (I->getOperand(0)->getType()->isVectorTy())
2168       return false;
2169     return collectInsertionElements(I->getOperand(0), Shift, Elements, VecEltTy,
2170                                     isBigEndian);
2171   case Instruction::ZExt:
2172     if (!isMultipleOfTypeSize(
2173                           I->getOperand(0)->getType()->getPrimitiveSizeInBits(),
2174                               VecEltTy))
2175       return false;
2176     return collectInsertionElements(I->getOperand(0), Shift, Elements, VecEltTy,
2177                                     isBigEndian);
2178   case Instruction::Or:
2179     return collectInsertionElements(I->getOperand(0), Shift, Elements, VecEltTy,
2180                                     isBigEndian) &&
2181            collectInsertionElements(I->getOperand(1), Shift, Elements, VecEltTy,
2182                                     isBigEndian);
2183   case Instruction::Shl: {
2184     // Must be shifting by a constant that is a multiple of the element size.
2185     ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
2186     if (!CI) return false;
2187     Shift += CI->getZExtValue();
2188     if (!isMultipleOfTypeSize(Shift, VecEltTy)) return false;
2189     return collectInsertionElements(I->getOperand(0), Shift, Elements, VecEltTy,
2190                                     isBigEndian);
2191   }
2192 
2193   }
2194 }
2195 
2196 
2197 /// If the input is an 'or' instruction, we may be doing shifts and ors to
2198 /// assemble the elements of the vector manually.
2199 /// Try to rip the code out and replace it with insertelements.  This is to
2200 /// optimize code like this:
2201 ///
2202 ///    %tmp37 = bitcast float %inc to i32
2203 ///    %tmp38 = zext i32 %tmp37 to i64
2204 ///    %tmp31 = bitcast float %inc5 to i32
2205 ///    %tmp32 = zext i32 %tmp31 to i64
2206 ///    %tmp33 = shl i64 %tmp32, 32
2207 ///    %ins35 = or i64 %tmp33, %tmp38
2208 ///    %tmp43 = bitcast i64 %ins35 to <2 x float>
2209 ///
2210 /// Into two insertelements that do "buildvector{%inc, %inc5}".
2211 static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI,
2212                                                 InstCombinerImpl &IC) {
2213   auto *DestVecTy = cast<FixedVectorType>(CI.getType());
2214   Value *IntInput = CI.getOperand(0);
2215 
2216   SmallVector<Value*, 8> Elements(DestVecTy->getNumElements());
2217   if (!collectInsertionElements(IntInput, 0, Elements,
2218                                 DestVecTy->getElementType(),
2219                                 IC.getDataLayout().isBigEndian()))
2220     return nullptr;
2221 
2222   // If we succeeded, we know that all of the element are specified by Elements
2223   // or are zero if Elements has a null entry.  Recast this as a set of
2224   // insertions.
2225   Value *Result = Constant::getNullValue(CI.getType());
2226   for (unsigned i = 0, e = Elements.size(); i != e; ++i) {
2227     if (!Elements[i]) continue;  // Unset element.
2228 
2229     Result = IC.Builder.CreateInsertElement(Result, Elements[i],
2230                                             IC.Builder.getInt32(i));
2231   }
2232 
2233   return Result;
2234 }
2235 
2236 /// Canonicalize scalar bitcasts of extracted elements into a bitcast of the
2237 /// vector followed by extract element. The backend tends to handle bitcasts of
2238 /// vectors better than bitcasts of scalars because vector registers are
2239 /// usually not type-specific like scalar integer or scalar floating-point.
2240 static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast,
2241                                               InstCombinerImpl &IC) {
2242   Value *VecOp, *Index;
2243   if (!match(BitCast.getOperand(0),
2244              m_OneUse(m_ExtractElt(m_Value(VecOp), m_Value(Index)))))
2245     return nullptr;
2246 
2247   // The bitcast must be to a vectorizable type, otherwise we can't make a new
2248   // type to extract from.
2249   Type *DestType = BitCast.getType();
2250   VectorType *VecType = cast<VectorType>(VecOp->getType());
2251   if (VectorType::isValidElementType(DestType)) {
2252     auto *NewVecType = VectorType::get(DestType, VecType);
2253     auto *NewBC = IC.Builder.CreateBitCast(VecOp, NewVecType, "bc");
2254     return ExtractElementInst::Create(NewBC, Index);
2255   }
2256 
2257   // Only handle a vector DestType to avoid the inverse transform in visitBitCast.
2258   // bitcast (extractelement <1 x elt>, dest) -> bitcast(<1 x elt>, dest)
2259   auto *FixedVType = dyn_cast<FixedVectorType>(VecType);
2260   if (DestType->isVectorTy() && FixedVType && FixedVType->getNumElements() == 1)
2261     return CastInst::Create(Instruction::BitCast, VecOp, DestType);
2262 
2263   return nullptr;
2264 }
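
// For instance (illustrative), canonicalizeBitCastExtElt rewrites:
//   %e = extractelement <4 x i32> %v, i64 %i
//   %f = bitcast i32 %e to float
// -->
//   %bc = bitcast <4 x i32> %v to <4 x float>
//   %f  = extractelement <4 x float> %bc, i64 %i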
2265 
2266 /// Change the type of a bitwise logic operation if we can eliminate a bitcast.
2267 static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast,
2268                                             InstCombiner::BuilderTy &Builder) {
2269   Type *DestTy = BitCast.getType();
2270   BinaryOperator *BO;
2271 
2272   if (!match(BitCast.getOperand(0), m_OneUse(m_BinOp(BO))) ||
2273       !BO->isBitwiseLogicOp())
2274     return nullptr;
2275 
2276   // FIXME: This transform is restricted to vector types to avoid backend
2277   // problems caused by creating potentially illegal operations. If a fix-up is
2278   // added to handle that situation, we can remove this check.
2279   if (!DestTy->isVectorTy() || !BO->getType()->isVectorTy())
2280     return nullptr;
2281 
2282   if (DestTy->isFPOrFPVectorTy()) {
2283     Value *X, *Y;
2284     // bitcast(logic(bitcast(X), bitcast(Y))) -> bitcast'(logic(bitcast'(X), Y))
2285     if (match(BO->getOperand(0), m_OneUse(m_BitCast(m_Value(X)))) &&
2286         match(BO->getOperand(1), m_OneUse(m_BitCast(m_Value(Y))))) {
2287       if (X->getType()->isFPOrFPVectorTy() &&
2288           Y->getType()->isIntOrIntVectorTy()) {
2289         Value *CastedOp =
2290             Builder.CreateBitCast(BO->getOperand(0), Y->getType());
2291         Value *NewBO = Builder.CreateBinOp(BO->getOpcode(), CastedOp, Y);
2292         return CastInst::CreateBitOrPointerCast(NewBO, DestTy);
2293       }
2294       if (X->getType()->isIntOrIntVectorTy() &&
2295           Y->getType()->isFPOrFPVectorTy()) {
2296         Value *CastedOp =
2297             Builder.CreateBitCast(BO->getOperand(1), X->getType());
2298         Value *NewBO = Builder.CreateBinOp(BO->getOpcode(), CastedOp, X);
2299         return CastInst::CreateBitOrPointerCast(NewBO, DestTy);
2300       }
2301     }
2302     return nullptr;
2303   }
2304 
2305   if (!DestTy->isIntOrIntVectorTy())
2306     return nullptr;
2307 
2308   Value *X;
2309   if (match(BO->getOperand(0), m_OneUse(m_BitCast(m_Value(X)))) &&
2310       X->getType() == DestTy && !isa<Constant>(X)) {
2311     // bitcast(logic(bitcast(X), Y)) --> logic'(X, bitcast(Y))
2312     Value *CastedOp1 = Builder.CreateBitCast(BO->getOperand(1), DestTy);
2313     return BinaryOperator::Create(BO->getOpcode(), X, CastedOp1);
2314   }
2315 
2316   if (match(BO->getOperand(1), m_OneUse(m_BitCast(m_Value(X)))) &&
2317       X->getType() == DestTy && !isa<Constant>(X)) {
2318     // bitcast(logic(Y, bitcast(X))) --> logic'(bitcast(Y), X)
2319     Value *CastedOp0 = Builder.CreateBitCast(BO->getOperand(0), DestTy);
2320     return BinaryOperator::Create(BO->getOpcode(), CastedOp0, X);
2321   }
2322 
2323   // Canonicalize vector bitcasts to come before vector bitwise logic with a
2324   // constant. This eases recognition of special constants for later ops.
2325   // Example:
2326   // icmp u/s (a ^ signmask), (b ^ signmask) --> icmp s/u a, b
2327   Constant *C;
2328   if (match(BO->getOperand(1), m_Constant(C))) {
2329     // bitcast (logic X, C) --> logic (bitcast X, C')
2330     Value *CastedOp0 = Builder.CreateBitCast(BO->getOperand(0), DestTy);
2331     Value *CastedC = Builder.CreateBitCast(C, DestTy);
2332     return BinaryOperator::Create(BO->getOpcode(), CastedOp0, CastedC);
2333   }
2334 
2335   return nullptr;
2336 }
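
// Sketch of the constant canonicalization at the end of
// foldBitCastBitwiseLogic (hypothetical types and values):
//   %a = and <2 x i64> %x, <i64 -1, i64 0>
//   %c = bitcast <2 x i64> %a to <4 x i32>
// -->
//   %bx = bitcast <2 x i64> %x to <4 x i32>
//   %c  = and <4 x i32> %bx, <i32 -1, i32 -1, i32 0, i32 0>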
2337 
2338 /// Change the type of a select if we can eliminate a bitcast.
2339 static Instruction *foldBitCastSelect(BitCastInst &BitCast,
2340                                       InstCombiner::BuilderTy &Builder) {
2341   Value *Cond, *TVal, *FVal;
2342   if (!match(BitCast.getOperand(0),
2343              m_OneUse(m_Select(m_Value(Cond), m_Value(TVal), m_Value(FVal)))))
2344     return nullptr;
2345 
2346   // A vector select must maintain the same number of elements in its operands.
2347   Type *CondTy = Cond->getType();
2348   Type *DestTy = BitCast.getType();
2349   if (auto *CondVTy = dyn_cast<VectorType>(CondTy))
2350     if (!DestTy->isVectorTy() ||
2351         CondVTy->getElementCount() !=
2352             cast<VectorType>(DestTy)->getElementCount())
2353       return nullptr;
2354 
2355   // FIXME: This transform is restricted from changing the select between
2356   // scalars and vectors to avoid backend problems caused by creating
2357   // potentially illegal operations. If a fix-up is added to handle that
2358   // situation, we can remove this check.
2359   if (DestTy->isVectorTy() != TVal->getType()->isVectorTy())
2360     return nullptr;
2361 
2362   auto *Sel = cast<Instruction>(BitCast.getOperand(0));
2363   Value *X;
2364   if (match(TVal, m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestTy &&
2365       !isa<Constant>(X)) {
2366     // bitcast(select(Cond, bitcast(X), Y)) --> select'(Cond, X, bitcast(Y))
2367     Value *CastedVal = Builder.CreateBitCast(FVal, DestTy);
2368     return SelectInst::Create(Cond, X, CastedVal, "", nullptr, Sel);
2369   }
2370 
2371   if (match(FVal, m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestTy &&
2372       !isa<Constant>(X)) {
2373     // bitcast(select(Cond, Y, bitcast(X))) --> select'(Cond, bitcast(Y), X)
2374     Value *CastedVal = Builder.CreateBitCast(TVal, DestTy);
2375     return SelectInst::Create(Cond, CastedVal, X, "", nullptr, Sel);
2376   }
2377 
2378   return nullptr;
2379 }
2380 
2381 /// Check if all users of CI are StoreInsts.
2382 static bool hasStoreUsersOnly(CastInst &CI) {
2383   for (User *U : CI.users()) {
2384     if (!isa<StoreInst>(U))
2385       return false;
2386   }
2387   return true;
2388 }
2389 
2390 /// This function handles the following case:
2391 ///
2392 ///     A  ->  B    cast
2393 ///     PHI
2394 ///     B  ->  A    cast
2395 ///
2396 /// All the related PHI nodes can be replaced by new PHI nodes with type A.
2397 /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
2398 Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI,
2399                                                       PHINode *PN) {
2400   // BitCast used by Store can be handled in InstCombineLoadStoreAlloca.cpp.
2401   if (hasStoreUsersOnly(CI))
2402     return nullptr;
2403 
2404   Value *Src = CI.getOperand(0);
2405   Type *SrcTy = Src->getType();         // Type B
2406   Type *DestTy = CI.getType();          // Type A
2407 
2408   SmallVector<PHINode *, 4> PhiWorklist;
2409   SmallSetVector<PHINode *, 4> OldPhiNodes;
2410 
2411   // Find all of the A->B casts and PHI nodes.
2412   // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
2413   // OldPhiNodes is used to track all known PHI nodes; before adding a new
2414   // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
2415   PhiWorklist.push_back(PN);
2416   OldPhiNodes.insert(PN);
2417   while (!PhiWorklist.empty()) {
2418     auto *OldPN = PhiWorklist.pop_back_val();
2419     for (Value *IncValue : OldPN->incoming_values()) {
2420       if (isa<Constant>(IncValue))
2421         continue;
2422 
2423       if (auto *LI = dyn_cast<LoadInst>(IncValue)) {
2424         // If there is a sequence of one or more load instructions, each loaded
2425         // value is used as address of later load instruction, bitcast is
2426         // necessary to change the value type, don't optimize it. For
2427         // simplicity we give up if the load address comes from another load.
2428         Value *Addr = LI->getOperand(0);
2429         if (Addr == &CI || isa<LoadInst>(Addr))
2430           return nullptr;
2431         // Don't transform "load <256 x i32>, <256 x i32>*" to
2432         // "load x86_amx, x86_amx*", because x86_amx* is invalid.
2433         // TODO: Remove this check when bitcast between vector and x86_amx
2434         // is replaced with a specific intrinsic.
2435         if (DestTy->isX86_AMXTy())
2436           return nullptr;
2437         if (LI->hasOneUse() && LI->isSimple())
2438           continue;
2439         // If a LoadInst has more than one use, changing the type of loaded
2440         // value may create another bitcast.
2441         return nullptr;
2442       }
2443 
2444       if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
2445         if (OldPhiNodes.insert(PNode))
2446           PhiWorklist.push_back(PNode);
2447         continue;
2448       }
2449 
2450       auto *BCI = dyn_cast<BitCastInst>(IncValue);
2451       // We can't handle other instructions.
2452       if (!BCI)
2453         return nullptr;
2454 
2455       // Verify it's a A->B cast.
2456       Type *TyA = BCI->getOperand(0)->getType();
2457       Type *TyB = BCI->getType();
2458       if (TyA != DestTy || TyB != SrcTy)
2459         return nullptr;
2460     }
2461   }
2462 
2463   // Check that each user of each old PHI node is something that we can
2464   // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
2465   for (auto *OldPN : OldPhiNodes) {
2466     for (User *V : OldPN->users()) {
2467       if (auto *SI = dyn_cast<StoreInst>(V)) {
2468         if (!SI->isSimple() || SI->getOperand(0) != OldPN)
2469           return nullptr;
2470       } else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
2471         // Verify it's a B->A cast.
2472         Type *TyB = BCI->getOperand(0)->getType();
2473         Type *TyA = BCI->getType();
2474         if (TyA != DestTy || TyB != SrcTy)
2475           return nullptr;
2476       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
2477         // If the user is another old PHI node, then even if we don't rewrite
2478         // it, the PHI web we're considering won't have any users outside
2479         // itself, so it will be dead.
2480         if (!OldPhiNodes.contains(PHI))
2481           return nullptr;
2482       } else {
2483         return nullptr;
2484       }
2485     }
2486   }
2487 
2488   // For each old PHI node, create a corresponding new PHI node of type A.
2489   SmallDenseMap<PHINode *, PHINode *> NewPNodes;
2490   for (auto *OldPN : OldPhiNodes) {
2491     Builder.SetInsertPoint(OldPN);
2492     PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
2493     NewPNodes[OldPN] = NewPN;
2494   }
2495 
2496   // Fill in the operands of new PHI nodes.
2497   for (auto *OldPN : OldPhiNodes) {
2498     PHINode *NewPN = NewPNodes[OldPN];
2499     for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
2500       Value *V = OldPN->getOperand(j);
2501       Value *NewV = nullptr;
2502       if (auto *C = dyn_cast<Constant>(V)) {
2503         NewV = ConstantExpr::getBitCast(C, DestTy);
2504       } else if (auto *LI = dyn_cast<LoadInst>(V)) {
2505         // Explicitly perform load combine to make sure no opposing transform
2506         // can remove the bitcast in the meantime and trigger an infinite loop.
2507         Builder.SetInsertPoint(LI);
2508         NewV = combineLoadToNewType(*LI, DestTy);
2509         // Remove the old load and its use in the old phi, which itself becomes
2510         // dead once the whole transform finishes.
2511         replaceInstUsesWith(*LI, PoisonValue::get(LI->getType()));
2512         eraseInstFromFunction(*LI);
2513       } else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
2514         NewV = BCI->getOperand(0);
2515       } else if (auto *PrevPN = dyn_cast<PHINode>(V)) {
2516         NewV = NewPNodes[PrevPN];
2517       }
2518       assert(NewV);
2519       NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
2520     }
2521   }
2522 
2523   // Traverse all accumulated PHI nodes and process their users, which are
2524   // Stores and BitCasts. Without this processing, new PHI nodes could be
2525   // replicated and lead to extra moves generated after DeSSA.
2526   // Stores of an old PHI node are rewritten to store a bitcast of the new
2527   // PHI node back to type B.
2528 
2529   // Replace users of BitCast B->A with NewPHI. These will help later to get
2530   // rid of a closure formed by the old PHI nodes.
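       // Illustrative sketch, assuming A = <2 x i32> and B = i64:
       //   store i64 %oldphi, ptr %p
       // becomes a store of "bitcast <2 x i32> %newphi to i64", while
       //   %r = bitcast i64 %oldphi to <2 x i32>
       // has its uses replaced by %newphi directly.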
2532   Instruction *RetVal = nullptr;
2533   for (auto *OldPN : OldPhiNodes) {
2534     PHINode *NewPN = NewPNodes[OldPN];
2535     for (User *V : make_early_inc_range(OldPN->users())) {
2536       if (auto *SI = dyn_cast<StoreInst>(V)) {
2537         assert(SI->isSimple() && SI->getOperand(0) == OldPN);
2538         Builder.SetInsertPoint(SI);
2539         auto *NewBC =
2540           cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy));
2541         SI->setOperand(0, NewBC);
2542         Worklist.push(SI);
2543         assert(hasStoreUsersOnly(*NewBC));
2544       } else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
2546         Type *TyB = BCI->getOperand(0)->getType();
2547         Type *TyA = BCI->getType();
2548         assert(TyA == DestTy && TyB == SrcTy);
2549         (void) TyA;
2550         (void) TyB;
2551         Instruction *I = replaceInstUsesWith(*BCI, NewPN);
2552         if (BCI == &CI)
2553           RetVal = I;
2554       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
2555         assert(OldPhiNodes.contains(PHI));
2556         (void) PHI;
2557       } else {
2558         llvm_unreachable("all uses should be handled");
2559       }
2560     }
2561   }
2562 
2563   return RetVal;
2564 }
2565 
2566 Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
2567   // If the operands are integer typed, then apply the integer transforms;
2568   // otherwise just apply the common ones.
2569   Value *Src = CI.getOperand(0);
2570   Type *SrcTy = Src->getType();
2571   Type *DestTy = CI.getType();
2572 
2573   // Get rid of casts from one type to the same type. These are useless and can
2574   // be replaced by the operand.
2575   if (DestTy == Src->getType())
2576     return replaceInstUsesWith(CI, Src);
2577 
2578   if (FixedVectorType *DestVTy = dyn_cast<FixedVectorType>(DestTy)) {
2579     // Beware: messing with this target-specific oddity may cause trouble.
2580     if (DestVTy->getNumElements() == 1 && SrcTy->isX86_MMXTy()) {
2581       Value *Elem = Builder.CreateBitCast(Src, DestVTy->getElementType());
2582       return InsertElementInst::Create(PoisonValue::get(DestTy), Elem,
2583                      Constant::getNullValue(Type::getInt32Ty(CI.getContext())));
2584     }
2585 
2586     if (isa<IntegerType>(SrcTy)) {
2587       // If this is a cast from an integer to a vector, check to see if the
2588       // input is a trunc or zext of a bitcast from a vector.  If so, we can
2589       // replace all the casts with a shuffle and (potentially) a bitcast.
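           // Illustrative sketch (little-endian; the types are only an example):
           //   %i = bitcast <8 x i16> %X to i128
           //   %t = trunc i128 %i to i64
           //   %v = bitcast i64 %t to <4 x i16>
           // can become a shufflevector selecting the low four elements of %X.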
2590       if (isa<TruncInst>(Src) || isa<ZExtInst>(Src)) {
2591         CastInst *SrcCast = cast<CastInst>(Src);
2592         if (BitCastInst *BCIn = dyn_cast<BitCastInst>(SrcCast->getOperand(0)))
2593           if (isa<VectorType>(BCIn->getOperand(0)->getType()))
2594             if (Instruction *I = optimizeVectorResizeWithIntegerBitCasts(
2595                     BCIn->getOperand(0), cast<VectorType>(DestTy), *this))
2596               return I;
2597       }
2598 
2599       // If the input is an 'or' instruction, we may be doing shifts and ors to
2600       // assemble the elements of the vector manually.  Try to rip the code out
2601       // and replace it with insertelements.
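           // Illustrative sketch (little-endian; names and types are examples):
           //   %lo = zext i32 %a to i64
           //   %hw = zext i32 %b to i64
           //   %hi = shl i64 %hw, 32
           //   %or = or i64 %lo, %hi
           //   %v  = bitcast i64 %or to <2 x i32>
           // can become insertelement instructions placing %a and %b directly.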
2602       if (Value *V = optimizeIntegerToVectorInsertions(CI, *this))
2603         return replaceInstUsesWith(CI, V);
2604     }
2605   }
2606 
2607   if (FixedVectorType *SrcVTy = dyn_cast<FixedVectorType>(SrcTy)) {
2608     if (SrcVTy->getNumElements() == 1) {
2609       // If our destination is not a vector, then make this a straight
2610       // scalar-scalar cast.
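           // For example (types chosen only for illustration):
           //   bitcast <1 x i64> %V to double
           // becomes an extractelement of %V followed by a scalar bitcast.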
2611       if (!DestTy->isVectorTy()) {
2612         Value *Elem =
2613           Builder.CreateExtractElement(Src,
2614                      Constant::getNullValue(Type::getInt32Ty(CI.getContext())));
2615         return CastInst::Create(Instruction::BitCast, Elem, DestTy);
2616       }
2617 
2618       // Otherwise, see if our source is an insert. If so, then use the scalar
2619       // component directly:
2620       // bitcast (inselt <1 x elt> V, X, 0) to <n x m> --> bitcast X to <n x m>
2621       if (auto *InsElt = dyn_cast<InsertElementInst>(Src))
2622         return new BitCastInst(InsElt->getOperand(1), DestTy);
2623     }
2624 
2625     // Convert an artificial vector insert into more analyzable bitwise logic.
2626     unsigned BitWidth = DestTy->getScalarSizeInBits();
2627     Value *X, *Y;
2628     uint64_t IndexC;
2629     if (match(Src, m_OneUse(m_InsertElt(m_OneUse(m_BitCast(m_Value(X))),
2630                                         m_Value(Y), m_ConstantInt(IndexC)))) &&
2631         DestTy->isIntegerTy() && X->getType() == DestTy &&
2632         Y->getType()->isIntegerTy() && isDesirableIntType(BitWidth)) {
2633       // Adjust for big endian - the LSBs are at the high index.
2634       if (DL.isBigEndian())
2635         IndexC = SrcVTy->getNumElements() - 1 - IndexC;
2636 
2637       // We only handle (endian-normalized) insert to index 0. Any other insert
2638       // would require a left-shift, so that is an extra instruction.
2639       if (IndexC == 0) {
2640         // bitcast (inselt (bitcast X), Y, 0) --> or (and X, MaskC), (zext Y)
2641         unsigned EltWidth = Y->getType()->getScalarSizeInBits();
2642         APInt MaskC = APInt::getHighBitsSet(BitWidth, BitWidth - EltWidth);
2643         Value *AndX = Builder.CreateAnd(X, MaskC);
2644         Value *ZextY = Builder.CreateZExt(Y, DestTy);
2645         return BinaryOperator::CreateOr(AndX, ZextY);
2646       }
2647     }
2648   }
2649 
2650   if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Src)) {
2651     // Okay, we have (bitcast (shuffle ..)).  Check to see if this is
2652     // a bitcast to a vector with the same # elts.
2653     Value *ShufOp0 = Shuf->getOperand(0);
2654     Value *ShufOp1 = Shuf->getOperand(1);
2655     auto ShufElts = cast<VectorType>(Shuf->getType())->getElementCount();
2656     auto SrcVecElts = cast<VectorType>(ShufOp0->getType())->getElementCount();
2657     if (Shuf->hasOneUse() && DestTy->isVectorTy() &&
2658         cast<VectorType>(DestTy)->getElementCount() == ShufElts &&
2659         ShufElts == SrcVecElts) {
2660       BitCastInst *Tmp;
2661       // If either of the operands is a cast from CI.getType(), then
2662       // evaluating the shuffle in the casted destination's type will allow
2663       // us to eliminate at least one cast.
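           // Illustrative sketch (types are only an example):
           //   %op0 = bitcast <4 x float> %X to <4 x i32>
           //   %s   = shufflevector <4 x i32> %op0, <4 x i32> %Y, <mask>
           //   %r   = bitcast <4 x i32> %s to <4 x float>
           // can become a shuffle of %X and (bitcast %Y to <4 x float>) performed
           // directly in <4 x float>.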
2664       if (((Tmp = dyn_cast<BitCastInst>(ShufOp0)) &&
2665            Tmp->getOperand(0)->getType() == DestTy) ||
2666           ((Tmp = dyn_cast<BitCastInst>(ShufOp1)) &&
2667            Tmp->getOperand(0)->getType() == DestTy)) {
2668         Value *LHS = Builder.CreateBitCast(ShufOp0, DestTy);
2669         Value *RHS = Builder.CreateBitCast(ShufOp1, DestTy);
2670         // Return a new shuffle vector.  Use the same element IDs, as we
2671         // know the vector types have the same number of elements.
2672         return new ShuffleVectorInst(LHS, RHS, Shuf->getShuffleMask());
2673       }
2674     }
2675 
2676     // A bitcasted-to-scalar and byte/bit reversing shuffle is better recognized
2677     // as a byte/bit swap:
2678     // bitcast <N x i8> (shuf X, undef, <N-1, N-2,...0>) -> bswap (bitcast X)
2679     // bitcast <N x i1> (shuf X, undef, <N-1, N-2,...0>) -> bitreverse (bitcast X)
2680     if (DestTy->isIntegerTy() && ShufElts.getKnownMinValue() % 2 == 0 &&
2681         Shuf->hasOneUse() && Shuf->isReverse()) {
2682       unsigned IntrinsicNum = 0;
2683       if (DL.isLegalInteger(DestTy->getScalarSizeInBits()) &&
2684           SrcTy->getScalarSizeInBits() == 8) {
2685         IntrinsicNum = Intrinsic::bswap;
2686       } else if (SrcTy->getScalarSizeInBits() == 1) {
2687         IntrinsicNum = Intrinsic::bitreverse;
2688       }
2689       if (IntrinsicNum != 0) {
2690         assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask");
2691         assert(match(ShufOp1, m_Undef()) && "Unexpected shuffle op");
2692         Function *BswapOrBitreverse =
2693             Intrinsic::getDeclaration(CI.getModule(), IntrinsicNum, DestTy);
2694         Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy);
2695         return CallInst::Create(BswapOrBitreverse, {ScalarX});
2696       }
2697     }
2698   }
2699 
2700   // Handle the A->B->A cast when there is an intervening PHI node.
2701   if (PHINode *PN = dyn_cast<PHINode>(Src))
2702     if (Instruction *I = optimizeBitCastFromPhi(CI, PN))
2703       return I;
2704 
2705   if (Instruction *I = canonicalizeBitCastExtElt(CI, *this))
2706     return I;
2707 
2708   if (Instruction *I = foldBitCastBitwiseLogic(CI, Builder))
2709     return I;
2710 
2711   if (Instruction *I = foldBitCastSelect(CI, Builder))
2712     return I;
2713 
2714   return commonCastTransforms(CI);
2715 }
2716 
2717 Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) {
2718   return commonCastTransforms(CI);
2719 }
2720