xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (revision 86461b646df5da8117ddf051d212bcd13b5593f8)
1 //===- InstCombineShifts.cpp ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the visitShl, visitLShr, and visitAShr functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "InstCombineInternal.h"
14 #include "llvm/Analysis/ConstantFolding.h"
15 #include "llvm/Analysis/InstructionSimplify.h"
16 #include "llvm/IR/IntrinsicInst.h"
17 #include "llvm/IR/PatternMatch.h"
18 #include "llvm/Transforms/InstCombine/InstCombiner.h"
19 using namespace llvm;
20 using namespace PatternMatch;
21 
22 #define DEBUG_TYPE "instcombine"
23 
// Given pattern:
//   (x shiftopcode Q) shiftopcode K
// we should rewrite it as
//   x shiftopcode (Q+K)  iff (Q+K) u< bitwidth(x)
//
// This is valid for any shift, but they must be identical, and we must be
// careful in case we have (zext(Q)+zext(K)) and look past extensions,
// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus.
//
// AnalyzeForSignBitExtraction indicates that we will only analyze whether this
// pattern has any 2 right-shifts that sum to 1 less than original bit width.
Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts(
    BinaryOperator *Sh0, const SimplifyQuery &SQ,
    bool AnalyzeForSignBitExtraction) {
  // Look for a shift of some instruction, ignore zext of shift amount if any.
  Instruction *Sh0Op0;
  Value *ShAmt0;
  if (!match(Sh0,
             m_Shift(m_Instruction(Sh0Op0), m_ZExtOrSelf(m_Value(ShAmt0)))))
    return nullptr;

  // If there is a truncation between the two shifts, we must make note of it
  // and look through it. The truncation imposes additional constraints on the
  // transform.
  Instruction *Sh1;
  Value *Trunc = nullptr;
  match(Sh0Op0,
        m_CombineOr(m_CombineAnd(m_Trunc(m_Instruction(Sh1)), m_Value(Trunc)),
                    m_Instruction(Sh1)));

  // Inner shift: (x shiftopcode ShAmt1)
  // Like with other shift, ignore zext of shift amount if any.
  Value *X, *ShAmt1;
  if (!match(Sh1, m_Shift(m_Value(X), m_ZExtOrSelf(m_Value(ShAmt1)))))
    return nullptr;

  // We have two shift amounts from two different shifts. The types of those
  // shift amounts may not match. If that's the case let's bail out now.
  if (ShAmt0->getType() != ShAmt1->getType())
    return nullptr;

  // As input, we have the following pattern:
  //   Sh0 (Sh1 X, Q), K
  // We want to rewrite that as:
  //   Sh x, (Q+K)  iff (Q+K) u< bitwidth(x)
  // While we know that originally (Q+K) would not overflow
  // (because  2 * (N-1) u<= iN -1), we have looked past extensions of
  // shift amounts, so it may now overflow in smaller bitwidth.
  // To ensure that does not happen, we need to ensure that the total maximal
  // shift amount is still representable in that smaller bit width.
  unsigned MaximalPossibleTotalShiftAmount =
      (Sh0->getType()->getScalarSizeInBits() - 1) +
      (Sh1->getType()->getScalarSizeInBits() - 1);
  APInt MaximalRepresentableShiftAmount =
      APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits());
  if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
    return nullptr;

  // We are only looking for signbit extraction if we have two right shifts.
  bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
                           match(Sh1, m_Shr(m_Value(), m_Value()));
  // ... and if it's not two right-shifts, we know the answer already.
  if (AnalyzeForSignBitExtraction && !HadTwoRightShifts)
    return nullptr;

  // The shift opcodes must be identical, unless we are just checking whether
  // this pattern can be interpreted as a sign-bit-extraction.
  Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode();
  bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode();
  if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction)
    return nullptr;

  // If we saw truncation, we'll need to produce extra instruction,
  // and for that one of the operands of the shift must be one-use,
  // unless of course we don't actually plan to produce any instructions here.
  if (Trunc && !AnalyzeForSignBitExtraction &&
      !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
    return nullptr;

  // Can we fold (ShAmt0+ShAmt1) ?
  auto *NewShAmt = dyn_cast_or_null<Constant>(
      SimplifyAddInst(ShAmt0, ShAmt1, /*isNSW=*/false, /*isNUW=*/false,
                      SQ.getWithInstruction(Sh0)));
  if (!NewShAmt)
    return nullptr; // Did not simplify.
  unsigned NewShAmtBitWidth = NewShAmt->getType()->getScalarSizeInBits();
  unsigned XBitWidth = X->getType()->getScalarSizeInBits();
  // Is the new shift amount smaller than the bit width of inner/new shift?
  if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
                                          APInt(NewShAmtBitWidth, XBitWidth))))
    return nullptr; // FIXME: could perform constant-folding.

  // If there was a truncation, and we have a right-shift, we can only fold if
  // we are left with the original sign bit. Likewise, if we were just checking
  // that this is a sign-bit extraction, this is the place to check it.
  // FIXME: zero shift amount is also legal here, but we can't *easily* check
  // more than one predicate so it's not really worth it.
  if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) {
    // If it's not a sign bit extraction, then we're done.
    if (!match(NewShAmt,
               m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
                                  APInt(NewShAmtBitWidth, XBitWidth - 1))))
      return nullptr;
    // If it is, and that was the question, return the base value.
    if (AnalyzeForSignBitExtraction)
      return X;
  }

  assert(IdenticalShOpcodes && "Should not get here with different shifts.");

  // All good, we can do this fold.
  NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType());

  BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt);

  // The flags can only be propagated if there wasn't a trunc.
  if (!Trunc) {
    // If the pattern did not involve trunc, and both of the original shifts
    // had the same flag set, preserve the flag.
    if (ShiftOpcode == Instruction::BinaryOps::Shl) {
      NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() &&
                                     Sh1->hasNoUnsignedWrap());
      NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() &&
                                   Sh1->hasNoSignedWrap());
    } else {
      NewShift->setIsExact(Sh0->isExact() && Sh1->isExact());
    }
  }

  Instruction *Ret = NewShift;
  if (Trunc) {
    // The fold was done in the wide type; re-truncate to the original type.
    Builder.Insert(NewShift);
    Ret = CastInst::Create(Instruction::Trunc, NewShift, Sh0->getType());
  }

  return Ret;
}
161 
// If we have some pattern that leaves only some low bits set, and then performs
// left-shift of those bits, if none of the bits that are left after the final
// shift are modified by the mask, we can omit the mask.
//
// There are many variants to this pattern:
//   a)  (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
//   b)  (x & (~(-1 << MaskShAmt))) << ShiftShAmt
//   c)  (x & (-1 >> MaskShAmt)) << ShiftShAmt
//   d)  (x & ((-1 << MaskShAmt) >> MaskShAmt)) << ShiftShAmt
//   e)  ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt
//   f)  ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt
// All these patterns can be simplified to just:
//   x << ShiftShAmt
// iff:
//   a,b)     (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
//   c,d,e,f) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt)
static Instruction *
dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
                                     const SimplifyQuery &Q,
                                     InstCombiner::BuilderTy &Builder) {
  assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl &&
         "The input must be 'shl'!");

  Value *Masked, *ShiftShAmt;
  match(OuterShift,
        m_Shift(m_Value(Masked), m_ZExtOrSelf(m_Value(ShiftShAmt))));

  // *If* there is a truncation between an outer shift and a possibly-mask,
  // then said truncation *must* be one-use, else we can't perform the fold.
  Value *Trunc;
  if (match(Masked, m_CombineAnd(m_Trunc(m_Value(Masked)), m_Value(Trunc))) &&
      !Trunc->hasOneUse())
    return nullptr;

  Type *NarrowestTy = OuterShift->getType();
  Type *WidestTy = Masked->getType();
  bool HadTrunc = WidestTy != NarrowestTy;

  // The mask must be computed in a type twice as wide to ensure
  // that no bits are lost if the sum-of-shifts is wider than the base type.
  Type *ExtendedTy = WidestTy->getExtendedType();

  Value *MaskShAmt;

  // ((1 << MaskShAmt) - 1)
  auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
  // (~(-1 << maskNbits))
  auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes());
  // (-1 >> MaskShAmt)
  auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt));
  // ((-1 << MaskShAmt) >> MaskShAmt)
  auto MaskD =
      m_Shr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt));

  Value *X;
  Constant *NewMask;

  if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
    // Peek through an optional zext of the shift amount.
    match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));

    // We have two shift amounts from two different shifts. The types of those
    // shift amounts may not match. If that's the case let's bail out now.
    if (MaskShAmt->getType() != ShiftShAmt->getType())
      return nullptr;

    // Can we simplify (MaskShAmt+ShiftShAmt) ?
    auto *SumOfShAmts = dyn_cast_or_null<Constant>(SimplifyAddInst(
        MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
    if (!SumOfShAmts)
      return nullptr; // Did not simplify.
    // In this pattern SumOfShAmts correlates with the number of low bits
    // that shall remain in the root value (OuterShift).

    // An extend of an undef value becomes zero because the high bits are never
    // completely unknown. Replace the `undef` shift amounts with final
    // shift bitwidth to ensure that the value remains undef when creating the
    // subsequent shift op.
    SumOfShAmts = Constant::replaceUndefsWith(
        SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
                                      ExtendedTy->getScalarSizeInBits()));
    auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
    // And compute the mask as usual: ~(-1 << (SumOfShAmts))
    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
    auto *ExtendedInvertedMask =
        ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
    NewMask = ConstantExpr::getNot(ExtendedInvertedMask);
  } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
             match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
                                 m_Deferred(MaskShAmt)))) {
    // Peek through an optional zext of the shift amount.
    match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));

    // We have two shift amounts from two different shifts. The types of those
    // shift amounts may not match. If that's the case let's bail out now.
    if (MaskShAmt->getType() != ShiftShAmt->getType())
      return nullptr;

    // Can we simplify (ShiftShAmt-MaskShAmt) ?
    auto *ShAmtsDiff = dyn_cast_or_null<Constant>(SimplifySubInst(
        ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
    if (!ShAmtsDiff)
      return nullptr; // Did not simplify.
    // In this pattern ShAmtsDiff correlates with the number of high bits that
    // shall be unset in the root value (OuterShift).

    // An extend of an undef value becomes zero because the high bits are never
    // completely unknown. Replace the `undef` shift amounts with negated
    // bitwidth of innermost shift to ensure that the value remains undef when
    // creating the subsequent shift op.
    unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits();
    ShAmtsDiff = Constant::replaceUndefsWith(
        ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(),
                                     -WidestTyBitWidth));
    auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
        ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(),
                                              WidestTyBitWidth,
                                              /*isSigned=*/false),
                             ShAmtsDiff),
        ExtendedTy);
    // And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
    NewMask =
        ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
  } else
    return nullptr; // Don't know anything about this pattern.

  NewMask = ConstantExpr::getTrunc(NewMask, NarrowestTy);

  // Does this mask have any unset bits? If not then we can just not apply it.
  bool NeedMask = !match(NewMask, m_AllOnes());

  // If we need to apply a mask, there are several more restrictions we have.
  if (NeedMask) {
    // The old masking instruction must go away.
    if (!Masked->hasOneUse())
      return nullptr;
    // The original "masking" instruction must not have been `ashr`.
    if (match(Masked, m_AShr(m_Value(), m_Value())))
      return nullptr;
  }

  // If we need to apply truncation, let's do it first, since we can.
  // We have already ensured that the old truncation will go away.
  if (HadTrunc)
    X = Builder.CreateTrunc(X, NarrowestTy);

  // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits.
  // We didn't change the Type of this outermost shift, so we can just do it.
  auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X,
                                          OuterShift->getOperand(1));
  if (!NeedMask)
    return NewShift;

  Builder.Insert(NewShift);
  return BinaryOperator::Create(Instruction::And, NewShift, NewMask);
}
319 
320 /// If we have a shift-by-constant of a bitwise logic op that itself has a
321 /// shift-by-constant operand with identical opcode, we may be able to convert
322 /// that into 2 independent shifts followed by the logic op. This eliminates a
323 /// a use of an intermediate value (reduces dependency chain).
324 static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
325                                             InstCombiner::BuilderTy &Builder) {
326   assert(I.isShift() && "Expected a shift as input");
327   auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0));
328   if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse())
329     return nullptr;
330 
331   Constant *C0, *C1;
332   if (!match(I.getOperand(1), m_Constant(C1)))
333     return nullptr;
334 
335   Instruction::BinaryOps ShiftOpcode = I.getOpcode();
336   Type *Ty = I.getType();
337 
338   // Find a matching one-use shift by constant. The fold is not valid if the sum
339   // of the shift values equals or exceeds bitwidth.
340   // TODO: Remove the one-use check if the other logic operand (Y) is constant.
341   Value *X, *Y;
342   auto matchFirstShift = [&](Value *V) {
343     BinaryOperator *BO;
344     APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits());
345     return match(V, m_BinOp(BO)) && BO->getOpcode() == ShiftOpcode &&
346            match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) &&
347            match(ConstantExpr::getAdd(C0, C1),
348                  m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold));
349   };
350 
351   // Logic ops are commutative, so check each operand for a match.
352   if (matchFirstShift(LogicInst->getOperand(0)))
353     Y = LogicInst->getOperand(1);
354   else if (matchFirstShift(LogicInst->getOperand(1)))
355     Y = LogicInst->getOperand(0);
356   else
357     return nullptr;
358 
359   // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
360   Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1);
361   Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC);
362   Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1));
363   return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2);
364 }
365 
366 Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
367   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
368   assert(Op0->getType() == Op1->getType());
369 
370   // If the shift amount is a one-use `sext`, we can demote it to `zext`.
371   Value *Y;
372   if (match(Op1, m_OneUse(m_SExt(m_Value(Y))))) {
373     Value *NewExt = Builder.CreateZExt(Y, I.getType(), Op1->getName());
374     return BinaryOperator::Create(I.getOpcode(), Op0, NewExt);
375   }
376 
377   // See if we can fold away this shift.
378   if (SimplifyDemandedInstructionBits(I))
379     return &I;
380 
381   // Try to fold constant and into select arguments.
382   if (isa<Constant>(Op0))
383     if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
384       if (Instruction *R = FoldOpIntoSelect(I, SI))
385         return R;
386 
387   if (Constant *CUI = dyn_cast<Constant>(Op1))
388     if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
389       return Res;
390 
391   if (auto *NewShift = cast_or_null<Instruction>(
392           reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)))
393     return NewShift;
394 
395   // (C1 shift (A add C2)) -> (C1 shift C2) shift A)
396   // iff A and C2 are both positive.
397   Value *A;
398   Constant *C;
399   if (match(Op0, m_Constant()) && match(Op1, m_Add(m_Value(A), m_Constant(C))))
400     if (isKnownNonNegative(A, DL, 0, &AC, &I, &DT) &&
401         isKnownNonNegative(C, DL, 0, &AC, &I, &DT))
402       return BinaryOperator::Create(
403           I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), Op0, C), A);
404 
405   // X shift (A srem C) -> X shift (A and (C - 1)) iff C is a power of 2.
406   // Because shifts by negative values (which could occur if A were negative)
407   // are undefined.
408   if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Constant(C))) &&
409       match(C, m_Power2())) {
410     // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't
411     // demand the sign bit (and many others) here??
412     Constant *Mask = ConstantExpr::getSub(C, ConstantInt::get(I.getType(), 1));
413     Value *Rem = Builder.CreateAnd(A, Mask, Op1->getName());
414     return replaceOperand(I, 1, Rem);
415   }
416 
417   if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder))
418     return Logic;
419 
420   return nullptr;
421 }
422 
423 /// Return true if we can simplify two logical (either left or right) shifts
424 /// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
425 static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
426                                     Instruction *InnerShift,
427                                     InstCombinerImpl &IC, Instruction *CxtI) {
428   assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
429 
430   // We need constant scalar or constant splat shifts.
431   const APInt *InnerShiftConst;
432   if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
433     return false;
434 
435   // Two logical shifts in the same direction:
436   // shl (shl X, C1), C2 -->  shl X, C1 + C2
437   // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
438   bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
439   if (IsInnerShl == IsOuterShl)
440     return true;
441 
442   // Equal shift amounts in opposite directions become bitwise 'and':
443   // lshr (shl X, C), C --> and X, C'
444   // shl (lshr X, C), C --> and X, C'
445   if (*InnerShiftConst == OuterShAmt)
446     return true;
447 
448   // If the 2nd shift is bigger than the 1st, we can fold:
449   // lshr (shl X, C1), C2 -->  and (shl X, C1 - C2), C3
450   // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
451   // but it isn't profitable unless we know the and'd out bits are already zero.
452   // Also, check that the inner shift is valid (less than the type width) or
453   // we'll crash trying to produce the bit mask for the 'and'.
454   unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
455   if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
456     unsigned InnerShAmt = InnerShiftConst->getZExtValue();
457     unsigned MaskShift =
458         IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
459     APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
460     if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI))
461       return true;
462   }
463 
464   return false;
465 }
466 
467 /// See if we can compute the specified value, but shifted logically to the left
468 /// or right by some number of bits. This should return true if the expression
469 /// can be computed for the same cost as the current expression tree. This is
470 /// used to eliminate extraneous shifting from things like:
471 ///      %C = shl i128 %A, 64
472 ///      %D = shl i128 %B, 96
473 ///      %E = or i128 %C, %D
474 ///      %F = lshr i128 %E, 64
475 /// where the client will ask if E can be computed shifted right by 64-bits. If
476 /// this succeeds, getShiftedValue() will be called to produce the value.
477 static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
478                                InstCombinerImpl &IC, Instruction *CxtI) {
479   // We can always evaluate constants shifted.
480   if (isa<Constant>(V))
481     return true;
482 
483   Instruction *I = dyn_cast<Instruction>(V);
484   if (!I) return false;
485 
486   // We can't mutate something that has multiple uses: doing so would
487   // require duplicating the instruction in general, which isn't profitable.
488   if (!I->hasOneUse()) return false;
489 
490   switch (I->getOpcode()) {
491   default: return false;
492   case Instruction::And:
493   case Instruction::Or:
494   case Instruction::Xor:
495     // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
496     return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
497            canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
498 
499   case Instruction::Shl:
500   case Instruction::LShr:
501     return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
502 
503   case Instruction::Select: {
504     SelectInst *SI = cast<SelectInst>(I);
505     Value *TrueVal = SI->getTrueValue();
506     Value *FalseVal = SI->getFalseValue();
507     return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
508            canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
509   }
510   case Instruction::PHI: {
511     // We can change a phi if we can change all operands.  Note that we never
512     // get into trouble with cyclic PHIs here because we only consider
513     // instructions with a single use.
514     PHINode *PN = cast<PHINode>(I);
515     for (Value *IncValue : PN->incoming_values())
516       if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
517         return false;
518     return true;
519   }
520   }
521 }
522 
523 /// Fold OuterShift (InnerShift X, C1), C2.
524 /// See canEvaluateShiftedShift() for the constraints on these instructions.
525 static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
526                                bool IsOuterShl,
527                                InstCombiner::BuilderTy &Builder) {
528   bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
529   Type *ShType = InnerShift->getType();
530   unsigned TypeWidth = ShType->getScalarSizeInBits();
531 
532   // We only accept shifts-by-a-constant in canEvaluateShifted().
533   const APInt *C1;
534   match(InnerShift->getOperand(1), m_APInt(C1));
535   unsigned InnerShAmt = C1->getZExtValue();
536 
537   // Change the shift amount and clear the appropriate IR flags.
538   auto NewInnerShift = [&](unsigned ShAmt) {
539     InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt));
540     if (IsInnerShl) {
541       InnerShift->setHasNoUnsignedWrap(false);
542       InnerShift->setHasNoSignedWrap(false);
543     } else {
544       InnerShift->setIsExact(false);
545     }
546     return InnerShift;
547   };
548 
549   // Two logical shifts in the same direction:
550   // shl (shl X, C1), C2 -->  shl X, C1 + C2
551   // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
552   if (IsInnerShl == IsOuterShl) {
553     // If this is an oversized composite shift, then unsigned shifts get 0.
554     if (InnerShAmt + OuterShAmt >= TypeWidth)
555       return Constant::getNullValue(ShType);
556 
557     return NewInnerShift(InnerShAmt + OuterShAmt);
558   }
559 
560   // Equal shift amounts in opposite directions become bitwise 'and':
561   // lshr (shl X, C), C --> and X, C'
562   // shl (lshr X, C), C --> and X, C'
563   if (InnerShAmt == OuterShAmt) {
564     APInt Mask = IsInnerShl
565                      ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt)
566                      : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt);
567     Value *And = Builder.CreateAnd(InnerShift->getOperand(0),
568                                    ConstantInt::get(ShType, Mask));
569     if (auto *AndI = dyn_cast<Instruction>(And)) {
570       AndI->moveBefore(InnerShift);
571       AndI->takeName(InnerShift);
572     }
573     return And;
574   }
575 
576   assert(InnerShAmt > OuterShAmt &&
577          "Unexpected opposite direction logical shift pair");
578 
579   // In general, we would need an 'and' for this transform, but
580   // canEvaluateShiftedShift() guarantees that the masked-off bits are not used.
581   // lshr (shl X, C1), C2 -->  shl X, C1 - C2
582   // shl (lshr X, C1), C2 --> lshr X, C1 - C2
583   return NewInnerShift(InnerShAmt - OuterShAmt);
584 }
585 
586 /// When canEvaluateShifted() returns true for an expression, this function
587 /// inserts the new computation that produces the shifted value.
588 static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
589                               InstCombinerImpl &IC, const DataLayout &DL) {
590   // We can always evaluate constants shifted.
591   if (Constant *C = dyn_cast<Constant>(V)) {
592     if (isLeftShift)
593       return IC.Builder.CreateShl(C, NumBits);
594     else
595       return IC.Builder.CreateLShr(C, NumBits);
596   }
597 
598   Instruction *I = cast<Instruction>(V);
599   IC.addToWorklist(I);
600 
601   switch (I->getOpcode()) {
602   default: llvm_unreachable("Inconsistency with CanEvaluateShifted");
603   case Instruction::And:
604   case Instruction::Or:
605   case Instruction::Xor:
606     // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
607     I->setOperand(
608         0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
609     I->setOperand(
610         1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
611     return I;
612 
613   case Instruction::Shl:
614   case Instruction::LShr:
615     return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
616                             IC.Builder);
617 
618   case Instruction::Select:
619     I->setOperand(
620         1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
621     I->setOperand(
622         2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
623     return I;
624   case Instruction::PHI: {
625     // We can change a phi if we can change all operands.  Note that we never
626     // get into trouble with cyclic PHIs here because we only consider
627     // instructions with a single use.
628     PHINode *PN = cast<PHINode>(I);
629     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
630       PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits,
631                                               isLeftShift, IC, DL));
632     return PN;
633   }
634   }
635 }
636 
637 // If this is a bitwise operator or add with a constant RHS we might be able
638 // to pull it through a shift.
639 static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift,
640                                          BinaryOperator *BO) {
641   switch (BO->getOpcode()) {
642   default:
643     return false; // Do not perform transform!
644   case Instruction::Add:
645     return Shift.getOpcode() == Instruction::Shl;
646   case Instruction::Or:
647   case Instruction::And:
648     return true;
649   case Instruction::Xor:
650     // Do not change a 'not' of logical shift because that would create a normal
651     // 'xor'. The 'not' is likely better for analysis, SCEV, and codegen.
652     return !(Shift.isLogicalShift() && match(BO, m_Not(m_Value())));
653   }
654 }
655 
// Fold a shift instruction 'I' whose shift amount 'Op1' is constant, where
// 'Op0' is the value being shifted. Tries, in order:
//  * propagating the shift into the computation of Op0 (which can eliminate
//    the shift entirely);
//  * folding the shift into a select/phi operand;
//  * hoisting a logical shift above a trunc of an inner shift-by-constant;
//  * when Op0 has one use: distributing a shl over add/and/or/xor/sub whose
//    operand is a matching right-shift, pulling a binop-with-constant-RHS
//    (or sub-with-constant-LHS) out of the shift, and doing the same through
//    a select arm.
// Returns a replacement instruction, or nullptr if no fold applies.
Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
                                                   BinaryOperator &I) {
  bool isLeftShift = I.getOpcode() == Instruction::Shl;

  const APInt *Op1C;
  if (!match(Op1, m_APInt(Op1C)))
    return nullptr;

  // See if we can propagate this shift into the input, this covers the trivial
  // cast of lshr(shl(x,c1),c2) as well as other more complex cases.
  // AShr is excluded: shifting in sign bits cannot be modeled by rewriting
  // the input expression the way logical shifts can.
  if (I.getOpcode() != Instruction::AShr &&
      canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
    LLVM_DEBUG(
        dbgs() << "ICE: GetShiftedValue propagating shift through expression"
                  " to eliminate shift:\n  IN: "
               << *Op0 << "\n  SH: " << I << "\n");

    return replaceInstUsesWith(
        I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
  }

  // See if we can simplify any instructions used by the instruction whose sole
  // purpose is to compute bits we don't care about.
  Type *Ty = I.getType();
  unsigned TypeBits = Ty->getScalarSizeInBits();
  assert(!Op1C->uge(TypeBits) &&
         "Shift over the type width should have been removed already");

  if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I))
    return FoldedShift;

  // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2))
  if (auto *TI = dyn_cast<TruncInst>(Op0)) {
    // If 'shift2' is an ashr, we would have to get the sign bit into a funny
    // place.  Don't try to do this transformation in this case.  Also, we
    // require that the input operand is a shift-by-constant so that we have
    // confidence that the shifts will get folded together.  We could do this
    // xform in more cases, but it is unlikely to be profitable.
    const APInt *TrShiftAmt;
    if (I.isLogicalShift() &&
        match(TI->getOperand(0), m_Shift(m_Value(), m_APInt(TrShiftAmt)))) {
      auto *TrOp = cast<Instruction>(TI->getOperand(0));
      Type *SrcTy = TrOp->getType();

      // Okay, we'll do this xform.  Make the shift of shift.
      // Widen the (small-type) shift amount to the pre-trunc type.
      Constant *ShAmt = ConstantExpr::getZExt(Op1, SrcTy);
      // (shift2 (shift1 & 0x00FF), c2)
      Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName());

      // For logical shifts, the truncation has the effect of making the high
      // part of the register be zeros.  Emulate this by inserting an AND to
      // clear the top bits as needed.  This 'and' will usually be zapped by
      // other xforms later if dead.
      unsigned SrcSize = SrcTy->getScalarSizeInBits();
      Constant *MaskV =
          ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcSize, TypeBits));

      // The mask we constructed says what the trunc would do if occurring
      // between the shifts.  We want to know the effect *after* the second
      // shift.  We know that it is a logical shift by a constant, so adjust the
      // mask as appropriate.
      MaskV = ConstantExpr::get(I.getOpcode(), MaskV, ShAmt);
      // shift1 & 0x00FF
      Value *And = Builder.CreateAnd(NSh, MaskV, TI->getName());
      // Return the value truncated to the interesting size.
      return new TruncInst(And, Ty);
    }
  }

  // The folds below rebuild Op0's computation from its operands, so they are
  // only profitable when the original Op0 will die.
  if (Op0->hasOneUse()) {
    if (BinaryOperator *Op0BO = dyn_cast<BinaryOperator>(Op0)) {
      // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
      Value *V1;
      const APInt *CC;
      switch (Op0BO->getOpcode()) {
      default: break;
      case Instruction::Add:
      case Instruction::And:
      case Instruction::Or:
      case Instruction::Xor: {
        // These operators commute.
        // Turn (Y + (X >> C)) << C  ->  (X + (Y << C)) & (~0 << C)
        if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() &&
            match(Op0BO->getOperand(1), m_Shr(m_Value(V1),
                  m_Specific(Op1)))) {
          Value *YS =         // (Y << C)
            Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
          // (X + (Y << C))
          Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1,
                                         Op0BO->getOperand(1)->getName());
          unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
          // Mask off the low bits that the cancelled shr/shl pair would have
          // cleared.
          APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
          Constant *Mask = ConstantInt::get(Ty, Bits);
          return BinaryOperator::CreateAnd(X, Mask);
        }

        // Turn (Y + ((X >> C) & CC)) << C  ->  ((X & (CC << C)) + (Y << C))
        Value *Op0BOOp1 = Op0BO->getOperand(1);
        if (isLeftShift && Op0BOOp1->hasOneUse() &&
            match(Op0BOOp1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
                                  m_APInt(CC)))) {
          Value *YS = // (Y << C)
              Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
          // X & (CC << C)
          Value *XM = Builder.CreateAnd(
              V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1),
              V1->getName() + ".mask");
          return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM);
        }
        LLVM_FALLTHROUGH;
      }

      // Sub is not commutative, so handle the matching-operand-0 forms here
      // (also reached by fallthrough for the commutative opcodes above).
      case Instruction::Sub: {
        // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
        if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
            match(Op0BO->getOperand(0), m_Shr(m_Value(V1),
                  m_Specific(Op1)))) {
          Value *YS =  // (Y << C)
            Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
          // (X + (Y << C))
          Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS,
                                         Op0BO->getOperand(0)->getName());
          unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
          APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
          Constant *Mask = ConstantInt::get(Ty, Bits);
          return BinaryOperator::CreateAnd(X, Mask);
        }

        // Turn (((X >> C)&CC) + Y) << C  ->  (X + (Y << C)) & (CC << C)
        if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
            match(Op0BO->getOperand(0),
                  m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
                        m_APInt(CC)))) {
          Value *YS = // (Y << C)
              Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
          // X & (CC << C)
          Value *XM = Builder.CreateAnd(
              V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1),
              V1->getName() + ".mask");
          return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS);
        }

        break;
      }
      }

      // If the operand is a bitwise operator with a constant RHS, and the
      // shift is the only use, we can pull it out of the shift.
      const APInt *Op0C;
      if (match(Op0BO->getOperand(1), m_APInt(Op0C))) {
        if (canShiftBinOpWithConstantRHS(I, Op0BO)) {
          // (binop X, C0) shift C --> binop (X shift C), (C0 shift C)
          Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
                                     cast<Constant>(Op0BO->getOperand(1)), Op1);

          Value *NewShift =
            Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1);
          NewShift->takeName(Op0BO);

          return BinaryOperator::Create(Op0BO->getOpcode(), NewShift,
                                        NewRHS);
        }
      }

      // If the operand is a subtract with a constant LHS, and the shift
      // is the only use, we can pull it out of the shift.
      // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2))
      if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub &&
          match(Op0BO->getOperand(0), m_APInt(Op0C))) {
        Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
                                   cast<Constant>(Op0BO->getOperand(0)), Op1);

        Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1);
        NewShift->takeName(Op0BO);

        return BinaryOperator::CreateSub(NewRHS, NewShift);
      }
    }

    // If we have a select that conditionally executes some binary operator,
    // see if we can pull the select and binary operator through the shift.
    //
    // For example, turning:
    //   shl (select C, (add X, C1), X), C2
    // Into:
    //   Y = shl X, C2
    //   select C, (add Y, C1 << C2), Y
    Value *Cond;
    BinaryOperator *TBO;
    Value *FalseVal;
    if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)),
                            m_Value(FalseVal)))) {
      const APInt *C;
      // The binop's first operand must be the other select arm so that the
      // false path is a plain shift of that common value.
      if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal &&
          match(TBO->getOperand(1), m_APInt(C)) &&
          canShiftBinOpWithConstantRHS(I, TBO)) {
        Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
                                       cast<Constant>(TBO->getOperand(1)), Op1);

        Value *NewShift =
          Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1);
        Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift,
                                           NewRHS);
        return SelectInst::Create(Cond, NewOp, NewShift);
      }
    }

    // Mirror of the fold above with the binop in the false arm of the select.
    BinaryOperator *FBO;
    Value *TrueVal;
    if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal),
                            m_OneUse(m_BinOp(FBO))))) {
      const APInt *C;
      if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal &&
          match(FBO->getOperand(1), m_APInt(C)) &&
          canShiftBinOpWithConstantRHS(I, FBO)) {
        Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
                                       cast<Constant>(FBO->getOperand(1)), Op1);

        Value *NewShift =
          Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1);
        Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift,
                                           NewRHS);
        return SelectInst::Create(Cond, NewShift, NewOp);
      }
    }
  }

  return nullptr;
}
884 
// Combine/simplify a 'shl' instruction. After generic simplification and the
// common shift transforms, this handles: constant shift amounts (zext
// narrowing, shr+shl round-trips, shl+shl merging, nuw/nsw flag inference),
// the variable (x >> y) << y fold, constant-operand shl of shl/mul/zext-i1,
// and the 1 << (bitwidth-1 - x) fold. Folds are tried in order; each
// successful one returns a replacement (or &I with updated flags).
Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
  const SimplifyQuery Q = SQ.getWithInstruction(&I);

  if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1),
                                 I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), Q))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  if (Instruction *V = commonShiftTransforms(I))
    return V;

  if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder))
    return V;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  Type *Ty = I.getType();
  unsigned BitWidth = Ty->getScalarSizeInBits();

  // Folds below this point require a constant (splat) shift amount.
  const APInt *ShAmtAPInt;
  if (match(Op1, m_APInt(ShAmtAPInt))) {
    unsigned ShAmt = ShAmtAPInt->getZExtValue();

    // shl (zext X), ShAmt --> zext (shl X, ShAmt)
    // This is only valid if X would have zeros shifted out.
    Value *X;
    if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
      unsigned SrcWidth = X->getType()->getScalarSizeInBits();
      if (ShAmt < SrcWidth &&
          MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I))
        return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty);
    }

    // (X >> C) << C --> X & (-1 << C)
    if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) {
      APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
      return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
    }

    // An 'exact' right-shift loses no bits, so the shr and shl can be
    // combined into a single shift by the amount difference without a mask.
    const APInt *ShOp1;
    if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1)))) &&
        ShOp1->ult(BitWidth)) {
      unsigned ShrAmt = ShOp1->getZExtValue();
      if (ShrAmt < ShAmt) {
        // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt);
        auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
        NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
        NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
        return NewShl;
      }
      if (ShrAmt > ShAmt) {
        // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt);
        auto *NewShr = BinaryOperator::Create(
            cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff);
        NewShr->setIsExact(true);
        return NewShr;
      }
    }

    // Non-exact shr: same combining, but the cleared low bits must be
    // re-applied with a mask, so require one use of the shr.
    if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(ShOp1)))) &&
        ShOp1->ult(BitWidth)) {
      unsigned ShrAmt = ShOp1->getZExtValue();
      if (ShrAmt < ShAmt) {
        // If C1 < C2: (X >>? C1) << C2 --> X << (C2 - C1) & (-1 << C2)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt);
        auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
        NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
        NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
        Builder.Insert(NewShl);
        APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
        return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
      }
      if (ShrAmt > ShAmt) {
        // If C1 > C2: (X >>? C1) << C2 --> X >>? (C1 - C2) & (-1 << C2)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt);
        auto *OldShr = cast<BinaryOperator>(Op0);
        auto *NewShr =
            BinaryOperator::Create(OldShr->getOpcode(), X, ShiftDiff);
        NewShr->setIsExact(OldShr->isExact());
        Builder.Insert(NewShr);
        APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
        return BinaryOperator::CreateAnd(NewShr, ConstantInt::get(Ty, Mask));
      }
    }

    if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) {
      unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
      // Oversized shifts are simplified to zero in InstSimplify.
      if (AmtSum < BitWidth)
        // (X << C1) << C2 --> X << (C1 + C2)
        return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
    }

    // If the shifted-out value is known-zero, then this is a NUW shift.
    if (!I.hasNoUnsignedWrap() &&
        MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) {
      I.setHasNoUnsignedWrap();
      return &I;
    }

    // If the shifted-out value is all signbits, then this is a NSW shift.
    if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) {
      I.setHasNoSignedWrap();
      return &I;
    }
  }

  // Transform  (x >> y) << y  to  x & (-1 << y)
  // Valid for any type of right-shift.
  Value *X;
  if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) {
    Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
    Value *Mask = Builder.CreateShl(AllOnes, Op1);
    return BinaryOperator::CreateAnd(Mask, X);
  }

  // Folds that only need the shift amount to be a constant expression (not
  // necessarily a splat integer).
  Constant *C1;
  if (match(Op1, m_Constant(C1))) {
    Constant *C2;
    Value *X;
    // (C2 << X) << C1 --> (C2 << C1) << X
    if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X)))))
      return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X);

    // (X * C2) << C1 --> X * (C2 << C1)
    if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
      return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));

    // shl (zext i1 X), C1 --> select (X, 1 << C1, 0)
    if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
      auto *NewC = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C1);
      return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
    }
  }

  // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
  if (match(Op0, m_One()) &&
      match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X))))
    return BinaryOperator::CreateLShr(
        ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X);

  return nullptr;
}
1031 
// Combine/simplify an 'lshr' instruction. Handles constant shift amounts
// (ctlz/cttz/ctpop-to-compare, shl+lshr round-trips, narrowing through
// zext/sext, lshr+lshr merging, 'exact' flag inference) and the variable
// (x << y) >> y fold. Folds are tried in order; each successful one returns
// a replacement (or &I with updated flags).
Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
  if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  if (Instruction *R = commonShiftTransforms(I))
    return R;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  Type *Ty = I.getType();
  const APInt *ShAmtAPInt;
  if (match(Op1, m_APInt(ShAmtAPInt))) {
    unsigned ShAmt = ShAmtAPInt->getZExtValue();
    unsigned BitWidth = Ty->getScalarSizeInBits();
    // A bit-counting intrinsic shifted right by log2(BitWidth) isolates its
    // only possible set high bit, i.e. the "all bits counted" case.
    auto *II = dyn_cast<IntrinsicInst>(Op0);
    if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt &&
        (II->getIntrinsicID() == Intrinsic::ctlz ||
         II->getIntrinsicID() == Intrinsic::cttz ||
         II->getIntrinsicID() == Intrinsic::ctpop)) {
      // ctlz.i32(x)>>5  --> zext(x == 0)
      // cttz.i32(x)>>5  --> zext(x == 0)
      // ctpop.i32(x)>>5 --> zext(x == -1)
      bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
      Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0);
      Value *Cmp = Builder.CreateICmpEQ(II->getArgOperand(0), RHS);
      return new ZExtInst(Cmp, Ty);
    }

    Value *X;
    const APInt *ShOp1;
    if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) {
      if (ShOp1->ult(ShAmt)) {
        unsigned ShlAmt = ShOp1->getZExtValue();
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
        if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
          // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1)
          // nuw means the shl lost no bits, so no mask is needed.
          auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff);
          NewLShr->setIsExact(I.isExact());
          return NewLShr;
        }
        // (X << C1) >>u C2  --> (X >>u (C2 - C1)) & (-1 >> C2)
        Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact());
        APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
        return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));
      }
      if (ShOp1->ugt(ShAmt)) {
        unsigned ShlAmt = ShOp1->getZExtValue();
        Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
        if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
          // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2)
          auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
          NewShl->setHasNoUnsignedWrap(true);
          return NewShl;
        }
        // (X << C1) >>u C2  --> X << (C1 - C2) & (-1 >> C2)
        Value *NewShl = Builder.CreateShl(X, ShiftDiff);
        APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
        return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
      }
      assert(*ShOp1 == ShAmt);
      // (X << C) >>u C --> X & (-1 >>u C)
      APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
      return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
    }

    // Narrow the shift to the pre-zext type when the narrow type is legal
    // (or this is a vector, where legality is not a concern).
    if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) &&
        (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
      assert(ShAmt < X->getType()->getScalarSizeInBits() &&
             "Big shift not simplified to zero?");
      // lshr (zext iM X to iN), C --> zext (lshr X, C) to iN
      Value *NewLShr = Builder.CreateLShr(X, ShAmt);
      return new ZExtInst(NewLShr, Ty);
    }

    if (match(Op0, m_SExt(m_Value(X))) &&
        (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
      // Are we moving the sign bit to the low bit and widening with high zeros?
      unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits();
      if (ShAmt == BitWidth - 1) {
        // lshr (sext i1 X to iN), N-1 --> zext X to iN
        if (SrcTyBitWidth == 1)
          return new ZExtInst(X, Ty);

        // lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN
        if (Op0->hasOneUse()) {
          Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1);
          return new ZExtInst(NewLShr, Ty);
        }
      }

      // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN
      if (ShAmt == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) {
        // The new shift amount can't be more than the narrow source type.
        unsigned NewShAmt = std::min(ShAmt, SrcTyBitWidth - 1);
        Value *AShr = Builder.CreateAShr(X, NewShAmt);
        return new ZExtInst(AShr, Ty);
      }
    }

    // lshr i32 (X -nsw Y), 31 --> zext (X < Y)
    Value *Y;
    if (ShAmt == BitWidth - 1 &&
        match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
      return new ZExtInst(Builder.CreateICmpSLT(X, Y), Ty);

    if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) {
      unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
      // Oversized shifts are simplified to zero in InstSimplify.
      if (AmtSum < BitWidth)
        // (X >>u C1) >>u C2 --> X >>u (C1 + C2)
        return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
    }

    // If the shifted-out value is known-zero, then this is an exact shift.
    if (!I.isExact() &&
        MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
      I.setIsExact();
      return &I;
    }
  }

  // Transform  (x << y) >> y  to  x & (-1 >> y)
  Value *X;
  if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
    Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
    Value *Mask = Builder.CreateLShr(AllOnes, Op1);
    return BinaryOperator::CreateAnd(Mask, X);
  }

  return nullptr;
}
1166 
// Try to fold a variable-width "extend in register" idiom rooted at an ashr:
//   (trunc?(X >> (C0 - NBits)) << (C1 - NBits)) a>> (C2 - NBits)
// where C0/C1/C2 are splats of the respective value's bitwidth and NBits is
// the same (possibly zero-extended) value in all three shift amounts. The
// inner right-shift already extracts the high NBits bits, so the outer
// shl/ashr pair either is fully redundant (when the inner shift is also an
// ashr) or can be replaced by applying the outer shift kind directly to the
// inner shift's operands, keeping the trunc if one was present.
Instruction *
InstCombinerImpl::foldVariableSignZeroExtensionOfVariableHighBitExtract(
    BinaryOperator &OldAShr) {
  assert(OldAShr.getOpcode() == Instruction::AShr &&
         "Must be called with arithmetic right-shift instruction only.");

  // Check that constant C is a splat of the element-wise bitwidth of V.
  auto BitWidthSplat = [](Constant *C, Value *V) {
    return match(
        C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
                              APInt(C->getType()->getScalarSizeInBits(),
                                    V->getType()->getScalarSizeInBits())));
  };

  // It should look like variable-length sign-extension on the outside:
  //   (Val << (bitwidth(Val)-Nbits)) a>> (bitwidth(Val)-Nbits)
  Value *NBits;
  Instruction *MaybeTrunc;
  Constant *C1, *C2;
  if (!match(&OldAShr,
             m_AShr(m_Shl(m_Instruction(MaybeTrunc),
                          m_ZExtOrSelf(m_Sub(m_Constant(C1),
                                             m_ZExtOrSelf(m_Value(NBits))))),
                    m_ZExtOrSelf(m_Sub(m_Constant(C2),
                                       m_ZExtOrSelf(m_Deferred(NBits)))))) ||
      !BitWidthSplat(C1, &OldAShr) || !BitWidthSplat(C2, &OldAShr))
    return nullptr;

  // There may or may not be a truncation after outer two shifts.
  Instruction *HighBitExtract;
  match(MaybeTrunc, m_TruncOrSelf(m_Instruction(HighBitExtract)));
  bool HadTrunc = MaybeTrunc != HighBitExtract;

  // And finally, the innermost part of the pattern must be a right-shift.
  Value *X, *NumLowBitsToSkip;
  if (!match(HighBitExtract, m_Shr(m_Value(X), m_Value(NumLowBitsToSkip))))
    return nullptr;

  // Said right-shift must extract high NBits bits - C0 must be its bitwidth.
  Constant *C0;
  if (!match(NumLowBitsToSkip,
             m_ZExtOrSelf(
                 m_Sub(m_Constant(C0), m_ZExtOrSelf(m_Specific(NBits))))) ||
      !BitWidthSplat(C0, HighBitExtract))
    return nullptr;

  // Since the NBits is identical for all shifts, if the outermost and
  // innermost shifts are identical, then outermost shifts are redundant.
  // If we had truncation, do keep it though.
  if (HighBitExtract->getOpcode() == OldAShr.getOpcode())
    return replaceInstUsesWith(OldAShr, MaybeTrunc);

  // Else, if there was a truncation, then we need to ensure that one
  // instruction will go away.
  if (HadTrunc && !match(&OldAShr, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
    return nullptr;

  // Finally, bypass two innermost shifts, and perform the outermost shift on
  // the operands of the innermost shift.
  Instruction *NewAShr =
      BinaryOperator::Create(OldAShr.getOpcode(), X, NumLowBitsToSkip);
  NewAShr->copyIRFlags(HighBitExtract); // We can preserve 'exact'-ness.
  if (!HadTrunc)
    return NewAShr;

  // Re-create the truncation at the new shift's width.
  Builder.Insert(NewAShr);
  return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType());
}
1235 
// Combine/simplify an 'ashr' instruction. Handles constant shift amounts
// (shl-of-zext to sext, nsw-shl round-trips, ashr+ashr merging, narrowing
// through sext, nsw-sub sign extraction, 'exact' inference), the variable
// high-bit-extract fold, and demotion to lshr when the sign bit is known
// zero. Folds are tried in order; each successful one returns a replacement
// (or &I with updated flags).
Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
  if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  if (Instruction *X = foldVectorBinop(I))
    return X;

  if (Instruction *R = commonShiftTransforms(I))
    return R;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  Type *Ty = I.getType();
  unsigned BitWidth = Ty->getScalarSizeInBits();
  const APInt *ShAmtAPInt;
  if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) {
    unsigned ShAmt = ShAmtAPInt->getZExtValue();

    // If the shift amount equals the difference in width of the destination
    // and source scalar types:
    // ashr (shl (zext X), C), C --> sext X
    Value *X;
    if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) &&
        ShAmt == BitWidth - X->getType()->getScalarSizeInBits())
      return new SExtInst(X, Ty);

    // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
    // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
    const APInt *ShOp1;
    if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) &&
        ShOp1->ult(BitWidth)) {
      unsigned ShlAmt = ShOp1->getZExtValue();
      if (ShlAmt < ShAmt) {
        // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
        auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff);
        NewAShr->setIsExact(I.isExact());
        return NewAShr;
      }
      if (ShlAmt > ShAmt) {
        // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2)
        Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
        auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff);
        NewShl->setHasNoSignedWrap(true);
        return NewShl;
      }
    }

    if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) &&
        ShOp1->ult(BitWidth)) {
      unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
      // Oversized arithmetic shifts replicate the sign bit.
      AmtSum = std::min(AmtSum, BitWidth - 1);
      // (X >>s C1) >>s C2 --> X >>s (C1 + C2)
      return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));
    }

    // Narrow the shift to the pre-sext type when the narrow type is legal
    // (or this is a vector, where legality is not a concern).
    if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) &&
        (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) {
      // ashr (sext X), C --> sext (ashr X, C')
      Type *SrcTy = X->getType();
      ShAmt = std::min(ShAmt, SrcTy->getScalarSizeInBits() - 1);
      Value *NewSh = Builder.CreateAShr(X, ConstantInt::get(SrcTy, ShAmt));
      return new SExtInst(NewSh, Ty);
    }

    // ashr i32 (X -nsw Y), 31 --> sext (X < Y)
    Value *Y;
    if (ShAmt == BitWidth - 1 &&
        match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
      return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);

    // If the shifted-out value is known-zero, then this is an exact shift.
    if (!I.isExact() &&
        MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
      I.setIsExact();
      return &I;
    }
  }

  if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I))
    return R;

  // See if we can turn a signed shr into an unsigned shr.
  if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
    return BinaryOperator::CreateLShr(Op0, Op1);

  return nullptr;
}
1325