1 //===- InstCombineShifts.cpp ----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the visitShl, visitLShr, and visitAShr functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "InstCombineInternal.h" 14 #include "llvm/Analysis/ConstantFolding.h" 15 #include "llvm/Analysis/InstructionSimplify.h" 16 #include "llvm/IR/IntrinsicInst.h" 17 #include "llvm/IR/PatternMatch.h" 18 using namespace llvm; 19 using namespace PatternMatch; 20 21 #define DEBUG_TYPE "instcombine" 22 23 // Given pattern: 24 // (x shiftopcode Q) shiftopcode K 25 // we should rewrite it as 26 // x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) 27 // This is valid for any shift, but they must be identical. 28 static Instruction * 29 reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0, 30 const SimplifyQuery &SQ) { 31 // Look for: (x shiftopcode ShAmt0) shiftopcode ShAmt1 32 Value *X, *ShAmt1, *ShAmt0; 33 Instruction *Sh1; 34 if (!match(Sh0, m_Shift(m_CombineAnd(m_Shift(m_Value(X), m_Value(ShAmt1)), 35 m_Instruction(Sh1)), 36 m_Value(ShAmt0)))) 37 return nullptr; 38 39 // The shift opcodes must be identical. 40 Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode(); 41 if (ShiftOpcode != Sh1->getOpcode()) 42 return nullptr; 43 // Can we fold (ShAmt0+ShAmt1) ? 44 Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, ShAmt0, ShAmt1, 45 SQ.getWithInstruction(Sh0)); 46 if (!NewShAmt) 47 return nullptr; // Did not simplify. 48 // Is the new shift amount smaller than the bit width? 49 // FIXME: could also rely on ConstantRange. 50 unsigned BitWidth = X->getType()->getScalarSizeInBits(); 51 if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, 52 APInt(BitWidth, BitWidth)))) 53 return nullptr; 54 // All good, we can do this fold. 55 BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt); 56 // If both of the original shifts had the same flag set, preserve the flag. 57 if (ShiftOpcode == Instruction::BinaryOps::Shl) { 58 NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() && 59 Sh1->hasNoUnsignedWrap()); 60 NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() && 61 Sh1->hasNoSignedWrap()); 62 } else { 63 NewShift->setIsExact(Sh0->isExact() && Sh1->isExact()); 64 } 65 return NewShift; 66 } 67 68 Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { 69 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); 70 assert(Op0->getType() == Op1->getType()); 71 72 // See if we can fold away this shift. 73 if (SimplifyDemandedInstructionBits(I)) 74 return &I; 75 76 // Try to fold constant and into select arguments. 77 if (isa<Constant>(Op0)) 78 if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) 79 if (Instruction *R = FoldOpIntoSelect(I, SI)) 80 return R; 81 82 if (Constant *CUI = dyn_cast<Constant>(Op1)) 83 if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) 84 return Res; 85 86 if (Instruction *NewShift = 87 reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)) 88 return NewShift; 89 90 // (C1 shift (A add C2)) -> (C1 shift C2) shift A) 91 // iff A and C2 are both positive. 92 Value *A; 93 Constant *C; 94 if (match(Op0, m_Constant()) && match(Op1, m_Add(m_Value(A), m_Constant(C)))) 95 if (isKnownNonNegative(A, DL, 0, &AC, &I, &DT) && 96 isKnownNonNegative(C, DL, 0, &AC, &I, &DT)) 97 return BinaryOperator::Create( 98 I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), Op0, C), A); 99 100 // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2. 101 // Because shifts by negative values (which could occur if A were negative) 102 // are undefined. 103 const APInt *B; 104 if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) { 105 // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't 106 // demand the sign bit (and many others) here?? 107 Value *Rem = Builder.CreateAnd(A, ConstantInt::get(I.getType(), *B - 1), 108 Op1->getName()); 109 I.setOperand(1, Rem); 110 return &I; 111 } 112 113 return nullptr; 114 } 115 116 /// Return true if we can simplify two logical (either left or right) shifts 117 /// that have constant shift amounts: OuterShift (InnerShift X, C1), C2. 118 static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, 119 Instruction *InnerShift, InstCombiner &IC, 120 Instruction *CxtI) { 121 assert(InnerShift->isLogicalShift() && "Unexpected instruction type"); 122 123 // We need constant scalar or constant splat shifts. 124 const APInt *InnerShiftConst; 125 if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst))) 126 return false; 127 128 // Two logical shifts in the same direction: 129 // shl (shl X, C1), C2 --> shl X, C1 + C2 130 // lshr (lshr X, C1), C2 --> lshr X, C1 + C2 131 bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl; 132 if (IsInnerShl == IsOuterShl) 133 return true; 134 135 // Equal shift amounts in opposite directions become bitwise 'and': 136 // lshr (shl X, C), C --> and X, C' 137 // shl (lshr X, C), C --> and X, C' 138 if (*InnerShiftConst == OuterShAmt) 139 return true; 140 141 // If the 2nd shift is bigger than the 1st, we can fold: 142 // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3 143 // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3 144 // but it isn't profitable unless we know the and'd out bits are already zero. 145 // Also, check that the inner shift is valid (less than the type width) or 146 // we'll crash trying to produce the bit mask for the 'and'. 147 unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits(); 148 if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) { 149 unsigned InnerShAmt = InnerShiftConst->getZExtValue(); 150 unsigned MaskShift = 151 IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt; 152 APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift; 153 if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI)) 154 return true; 155 } 156 157 return false; 158 } 159 160 /// See if we can compute the specified value, but shifted logically to the left 161 /// or right by some number of bits. This should return true if the expression 162 /// can be computed for the same cost as the current expression tree. This is 163 /// used to eliminate extraneous shifting from things like: 164 /// %C = shl i128 %A, 64 165 /// %D = shl i128 %B, 96 166 /// %E = or i128 %C, %D 167 /// %F = lshr i128 %E, 64 168 /// where the client will ask if E can be computed shifted right by 64-bits. If 169 /// this succeeds, getShiftedValue() will be called to produce the value. 170 static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, 171 InstCombiner &IC, Instruction *CxtI) { 172 // We can always evaluate constants shifted. 173 if (isa<Constant>(V)) 174 return true; 175 176 Instruction *I = dyn_cast<Instruction>(V); 177 if (!I) return false; 178 179 // If this is the opposite shift, we can directly reuse the input of the shift 180 // if the needed bits are already zero in the input. This allows us to reuse 181 // the value which means that we don't care if the shift has multiple uses. 182 // TODO: Handle opposite shift by exact value. 183 ConstantInt *CI = nullptr; 184 if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || 185 (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { 186 if (CI->getValue() == NumBits) { 187 // TODO: Check that the input bits are already zero with MaskedValueIsZero 188 #if 0 189 // If this is a truncate of a logical shr, we can truncate it to a smaller 190 // lshr iff we know that the bits we would otherwise be shifting in are 191 // already zeros. 192 uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); 193 uint32_t BitWidth = Ty->getScalarSizeInBits(); 194 if (MaskedValueIsZero(I->getOperand(0), 195 APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) && 196 CI->getLimitedValue(BitWidth) < BitWidth) { 197 return CanEvaluateTruncated(I->getOperand(0), Ty); 198 } 199 #endif 200 201 } 202 } 203 204 // We can't mutate something that has multiple uses: doing so would 205 // require duplicating the instruction in general, which isn't profitable. 206 if (!I->hasOneUse()) return false; 207 208 switch (I->getOpcode()) { 209 default: return false; 210 case Instruction::And: 211 case Instruction::Or: 212 case Instruction::Xor: 213 // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. 214 return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) && 215 canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I); 216 217 case Instruction::Shl: 218 case Instruction::LShr: 219 return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI); 220 221 case Instruction::Select: { 222 SelectInst *SI = cast<SelectInst>(I); 223 Value *TrueVal = SI->getTrueValue(); 224 Value *FalseVal = SI->getFalseValue(); 225 return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) && 226 canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI); 227 } 228 case Instruction::PHI: { 229 // We can change a phi if we can change all operands. Note that we never 230 // get into trouble with cyclic PHIs here because we only consider 231 // instructions with a single use. 232 PHINode *PN = cast<PHINode>(I); 233 for (Value *IncValue : PN->incoming_values()) 234 if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN)) 235 return false; 236 return true; 237 } 238 } 239 } 240 241 /// Fold OuterShift (InnerShift X, C1), C2. 242 /// See canEvaluateShiftedShift() for the constraints on these instructions. 243 static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt, 244 bool IsOuterShl, 245 InstCombiner::BuilderTy &Builder) { 246 bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl; 247 Type *ShType = InnerShift->getType(); 248 unsigned TypeWidth = ShType->getScalarSizeInBits(); 249 250 // We only accept shifts-by-a-constant in canEvaluateShifted(). 251 const APInt *C1; 252 match(InnerShift->getOperand(1), m_APInt(C1)); 253 unsigned InnerShAmt = C1->getZExtValue(); 254 255 // Change the shift amount and clear the appropriate IR flags. 256 auto NewInnerShift = [&](unsigned ShAmt) { 257 InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt)); 258 if (IsInnerShl) { 259 InnerShift->setHasNoUnsignedWrap(false); 260 InnerShift->setHasNoSignedWrap(false); 261 } else { 262 InnerShift->setIsExact(false); 263 } 264 return InnerShift; 265 }; 266 267 // Two logical shifts in the same direction: 268 // shl (shl X, C1), C2 --> shl X, C1 + C2 269 // lshr (lshr X, C1), C2 --> lshr X, C1 + C2 270 if (IsInnerShl == IsOuterShl) { 271 // If this is an oversized composite shift, then unsigned shifts get 0. 272 if (InnerShAmt + OuterShAmt >= TypeWidth) 273 return Constant::getNullValue(ShType); 274 275 return NewInnerShift(InnerShAmt + OuterShAmt); 276 } 277 278 // Equal shift amounts in opposite directions become bitwise 'and': 279 // lshr (shl X, C), C --> and X, C' 280 // shl (lshr X, C), C --> and X, C' 281 if (InnerShAmt == OuterShAmt) { 282 APInt Mask = IsInnerShl 283 ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt) 284 : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt); 285 Value *And = Builder.CreateAnd(InnerShift->getOperand(0), 286 ConstantInt::get(ShType, Mask)); 287 if (auto *AndI = dyn_cast<Instruction>(And)) { 288 AndI->moveBefore(InnerShift); 289 AndI->takeName(InnerShift); 290 } 291 return And; 292 } 293 294 assert(InnerShAmt > OuterShAmt && 295 "Unexpected opposite direction logical shift pair"); 296 297 // In general, we would need an 'and' for this transform, but 298 // canEvaluateShiftedShift() guarantees that the masked-off bits are not used. 299 // lshr (shl X, C1), C2 --> shl X, C1 - C2 300 // shl (lshr X, C1), C2 --> lshr X, C1 - C2 301 return NewInnerShift(InnerShAmt - OuterShAmt); 302 } 303 304 /// When canEvaluateShifted() returns true for an expression, this function 305 /// inserts the new computation that produces the shifted value. 306 static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, 307 InstCombiner &IC, const DataLayout &DL) { 308 // We can always evaluate constants shifted. 309 if (Constant *C = dyn_cast<Constant>(V)) { 310 if (isLeftShift) 311 V = IC.Builder.CreateShl(C, NumBits); 312 else 313 V = IC.Builder.CreateLShr(C, NumBits); 314 // If we got a constantexpr back, try to simplify it with TD info. 315 if (auto *C = dyn_cast<Constant>(V)) 316 if (auto *FoldedC = 317 ConstantFoldConstant(C, DL, &IC.getTargetLibraryInfo())) 318 V = FoldedC; 319 return V; 320 } 321 322 Instruction *I = cast<Instruction>(V); 323 IC.Worklist.Add(I); 324 325 switch (I->getOpcode()) { 326 default: llvm_unreachable("Inconsistency with CanEvaluateShifted"); 327 case Instruction::And: 328 case Instruction::Or: 329 case Instruction::Xor: 330 // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. 331 I->setOperand( 332 0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL)); 333 I->setOperand( 334 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); 335 return I; 336 337 case Instruction::Shl: 338 case Instruction::LShr: 339 return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift, 340 IC.Builder); 341 342 case Instruction::Select: 343 I->setOperand( 344 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); 345 I->setOperand( 346 2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL)); 347 return I; 348 case Instruction::PHI: { 349 // We can change a phi if we can change all operands. Note that we never 350 // get into trouble with cyclic PHIs here because we only consider 351 // instructions with a single use. 352 PHINode *PN = cast<PHINode>(I); 353 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 354 PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits, 355 isLeftShift, IC, DL)); 356 return PN; 357 } 358 } 359 } 360 361 // If this is a bitwise operator or add with a constant RHS we might be able 362 // to pull it through a shift. 363 static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift, 364 BinaryOperator *BO) { 365 switch (BO->getOpcode()) { 366 default: 367 return false; // Do not perform transform! 368 case Instruction::Add: 369 return Shift.getOpcode() == Instruction::Shl; 370 case Instruction::Or: 371 case Instruction::Xor: 372 case Instruction::And: 373 return true; 374 } 375 } 376 377 Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, 378 BinaryOperator &I) { 379 bool isLeftShift = I.getOpcode() == Instruction::Shl; 380 381 const APInt *Op1C; 382 if (!match(Op1, m_APInt(Op1C))) 383 return nullptr; 384 385 // See if we can propagate this shift into the input, this covers the trivial 386 // cast of lshr(shl(x,c1),c2) as well as other more complex cases. 387 if (I.getOpcode() != Instruction::AShr && 388 canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) { 389 LLVM_DEBUG( 390 dbgs() << "ICE: GetShiftedValue propagating shift through expression" 391 " to eliminate shift:\n IN: " 392 << *Op0 << "\n SH: " << I << "\n"); 393 394 return replaceInstUsesWith( 395 I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL)); 396 } 397 398 // See if we can simplify any instructions used by the instruction whose sole 399 // purpose is to compute bits we don't care about. 400 unsigned TypeBits = Op0->getType()->getScalarSizeInBits(); 401 402 assert(!Op1C->uge(TypeBits) && 403 "Shift over the type width should have been removed already"); 404 405 if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I)) 406 return FoldedShift; 407 408 // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2)) 409 if (TruncInst *TI = dyn_cast<TruncInst>(Op0)) { 410 Instruction *TrOp = dyn_cast<Instruction>(TI->getOperand(0)); 411 // If 'shift2' is an ashr, we would have to get the sign bit into a funny 412 // place. Don't try to do this transformation in this case. Also, we 413 // require that the input operand is a shift-by-constant so that we have 414 // confidence that the shifts will get folded together. We could do this 415 // xform in more cases, but it is unlikely to be profitable. 416 if (TrOp && I.isLogicalShift() && TrOp->isShift() && 417 isa<ConstantInt>(TrOp->getOperand(1))) { 418 // Okay, we'll do this xform. Make the shift of shift. 419 Constant *ShAmt = 420 ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType()); 421 // (shift2 (shift1 & 0x00FF), c2) 422 Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName()); 423 424 // For logical shifts, the truncation has the effect of making the high 425 // part of the register be zeros. Emulate this by inserting an AND to 426 // clear the top bits as needed. This 'and' will usually be zapped by 427 // other xforms later if dead. 428 unsigned SrcSize = TrOp->getType()->getScalarSizeInBits(); 429 unsigned DstSize = TI->getType()->getScalarSizeInBits(); 430 APInt MaskV(APInt::getLowBitsSet(SrcSize, DstSize)); 431 432 // The mask we constructed says what the trunc would do if occurring 433 // between the shifts. We want to know the effect *after* the second 434 // shift. We know that it is a logical shift by a constant, so adjust the 435 // mask as appropriate. 436 if (I.getOpcode() == Instruction::Shl) 437 MaskV <<= Op1C->getZExtValue(); 438 else { 439 assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift"); 440 MaskV.lshrInPlace(Op1C->getZExtValue()); 441 } 442 443 // shift1 & 0x00FF 444 Value *And = Builder.CreateAnd(NSh, 445 ConstantInt::get(I.getContext(), MaskV), 446 TI->getName()); 447 448 // Return the value truncated to the interesting size. 449 return new TruncInst(And, I.getType()); 450 } 451 } 452 453 if (Op0->hasOneUse()) { 454 if (BinaryOperator *Op0BO = dyn_cast<BinaryOperator>(Op0)) { 455 // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) 456 Value *V1, *V2; 457 ConstantInt *CC; 458 switch (Op0BO->getOpcode()) { 459 default: break; 460 case Instruction::Add: 461 case Instruction::And: 462 case Instruction::Or: 463 case Instruction::Xor: { 464 // These operators commute. 465 // Turn (Y + (X >> C)) << C -> (X + (Y << C)) & (~0 << C) 466 if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() && 467 match(Op0BO->getOperand(1), m_Shr(m_Value(V1), 468 m_Specific(Op1)))) { 469 Value *YS = // (Y << C) 470 Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); 471 // (X + (Y << C)) 472 Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1, 473 Op0BO->getOperand(1)->getName()); 474 unsigned Op1Val = Op1C->getLimitedValue(TypeBits); 475 476 APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); 477 Constant *Mask = ConstantInt::get(I.getContext(), Bits); 478 if (VectorType *VT = dyn_cast<VectorType>(X->getType())) 479 Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); 480 return BinaryOperator::CreateAnd(X, Mask); 481 } 482 483 // Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C)) 484 Value *Op0BOOp1 = Op0BO->getOperand(1); 485 if (isLeftShift && Op0BOOp1->hasOneUse() && 486 match(Op0BOOp1, 487 m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), 488 m_ConstantInt(CC)))) { 489 Value *YS = // (Y << C) 490 Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); 491 // X & (CC << C) 492 Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1), 493 V1->getName()+".mask"); 494 return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM); 495 } 496 LLVM_FALLTHROUGH; 497 } 498 499 case Instruction::Sub: { 500 // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) 501 if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && 502 match(Op0BO->getOperand(0), m_Shr(m_Value(V1), 503 m_Specific(Op1)))) { 504 Value *YS = // (Y << C) 505 Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); 506 // (X + (Y << C)) 507 Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS, 508 Op0BO->getOperand(0)->getName()); 509 unsigned Op1Val = Op1C->getLimitedValue(TypeBits); 510 511 APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); 512 Constant *Mask = ConstantInt::get(I.getContext(), Bits); 513 if (VectorType *VT = dyn_cast<VectorType>(X->getType())) 514 Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); 515 return BinaryOperator::CreateAnd(X, Mask); 516 } 517 518 // Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C) 519 if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && 520 match(Op0BO->getOperand(0), 521 m_And(m_OneUse(m_Shr(m_Value(V1), m_Value(V2))), 522 m_ConstantInt(CC))) && V2 == Op1) { 523 Value *YS = // (Y << C) 524 Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); 525 // X & (CC << C) 526 Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1), 527 V1->getName()+".mask"); 528 529 return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS); 530 } 531 532 break; 533 } 534 } 535 536 537 // If the operand is a bitwise operator with a constant RHS, and the 538 // shift is the only use, we can pull it out of the shift. 539 const APInt *Op0C; 540 if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { 541 if (canShiftBinOpWithConstantRHS(I, Op0BO)) { 542 Constant *NewRHS = ConstantExpr::get(I.getOpcode(), 543 cast<Constant>(Op0BO->getOperand(1)), Op1); 544 545 Value *NewShift = 546 Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1); 547 NewShift->takeName(Op0BO); 548 549 return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, 550 NewRHS); 551 } 552 } 553 554 // If the operand is a subtract with a constant LHS, and the shift 555 // is the only use, we can pull it out of the shift. 556 // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2)) 557 if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub && 558 match(Op0BO->getOperand(0), m_APInt(Op0C))) { 559 Constant *NewRHS = ConstantExpr::get(I.getOpcode(), 560 cast<Constant>(Op0BO->getOperand(0)), Op1); 561 562 Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1); 563 NewShift->takeName(Op0BO); 564 565 return BinaryOperator::CreateSub(NewRHS, NewShift); 566 } 567 } 568 569 // If we have a select that conditionally executes some binary operator, 570 // see if we can pull it the select and operator through the shift. 571 // 572 // For example, turning: 573 // shl (select C, (add X, C1), X), C2 574 // Into: 575 // Y = shl X, C2 576 // select C, (add Y, C1 << C2), Y 577 Value *Cond; 578 BinaryOperator *TBO; 579 Value *FalseVal; 580 if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)), 581 m_Value(FalseVal)))) { 582 const APInt *C; 583 if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && 584 match(TBO->getOperand(1), m_APInt(C)) && 585 canShiftBinOpWithConstantRHS(I, TBO)) { 586 Constant *NewRHS = ConstantExpr::get(I.getOpcode(), 587 cast<Constant>(TBO->getOperand(1)), Op1); 588 589 Value *NewShift = 590 Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1); 591 Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, 592 NewRHS); 593 return SelectInst::Create(Cond, NewOp, NewShift); 594 } 595 } 596 597 BinaryOperator *FBO; 598 Value *TrueVal; 599 if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal), 600 m_OneUse(m_BinOp(FBO))))) { 601 const APInt *C; 602 if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && 603 match(FBO->getOperand(1), m_APInt(C)) && 604 canShiftBinOpWithConstantRHS(I, FBO)) { 605 Constant *NewRHS = ConstantExpr::get(I.getOpcode(), 606 cast<Constant>(FBO->getOperand(1)), Op1); 607 608 Value *NewShift = 609 Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1); 610 Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, 611 NewRHS); 612 return SelectInst::Create(Cond, NewShift, NewOp); 613 } 614 } 615 } 616 617 return nullptr; 618 } 619 620 Instruction *InstCombiner::visitShl(BinaryOperator &I) { 621 if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), 622 I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), 623 SQ.getWithInstruction(&I))) 624 return replaceInstUsesWith(I, V); 625 626 if (Instruction *X = foldVectorBinop(I)) 627 return X; 628 629 if (Instruction *V = commonShiftTransforms(I)) 630 return V; 631 632 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); 633 Type *Ty = I.getType(); 634 unsigned BitWidth = Ty->getScalarSizeInBits(); 635 636 const APInt *ShAmtAPInt; 637 if (match(Op1, m_APInt(ShAmtAPInt))) { 638 unsigned ShAmt = ShAmtAPInt->getZExtValue(); 639 unsigned BitWidth = Ty->getScalarSizeInBits(); 640 641 // shl (zext X), ShAmt --> zext (shl X, ShAmt) 642 // This is only valid if X would have zeros shifted out. 643 Value *X; 644 if (match(Op0, m_ZExt(m_Value(X)))) { 645 unsigned SrcWidth = X->getType()->getScalarSizeInBits(); 646 if (ShAmt < SrcWidth && 647 MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I)) 648 return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty); 649 } 650 651 // (X >> C) << C --> X & (-1 << C) 652 if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) { 653 APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); 654 return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); 655 } 656 657 // FIXME: we do not yet transform non-exact shr's. The backend (DAGCombine) 658 // needs a few fixes for the rotate pattern recognition first. 659 const APInt *ShOp1; 660 if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) { 661 unsigned ShrAmt = ShOp1->getZExtValue(); 662 if (ShrAmt < ShAmt) { 663 // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1) 664 Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); 665 auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); 666 NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); 667 NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); 668 return NewShl; 669 } 670 if (ShrAmt > ShAmt) { 671 // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2) 672 Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); 673 auto *NewShr = BinaryOperator::Create( 674 cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff); 675 NewShr->setIsExact(true); 676 return NewShr; 677 } 678 } 679 680 if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) { 681 unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); 682 // Oversized shifts are simplified to zero in InstSimplify. 683 if (AmtSum < BitWidth) 684 // (X << C1) << C2 --> X << (C1 + C2) 685 return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum)); 686 } 687 688 // If the shifted-out value is known-zero, then this is a NUW shift. 689 if (!I.hasNoUnsignedWrap() && 690 MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) { 691 I.setHasNoUnsignedWrap(); 692 return &I; 693 } 694 695 // If the shifted-out value is all signbits, then this is a NSW shift. 696 if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) { 697 I.setHasNoSignedWrap(); 698 return &I; 699 } 700 } 701 702 // Transform (x >> y) << y to x & (-1 << y) 703 // Valid for any type of right-shift. 704 Value *X; 705 if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) { 706 Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); 707 Value *Mask = Builder.CreateShl(AllOnes, Op1); 708 return BinaryOperator::CreateAnd(Mask, X); 709 } 710 711 Constant *C1; 712 if (match(Op1, m_Constant(C1))) { 713 Constant *C2; 714 Value *X; 715 // (C2 << X) << C1 --> (C2 << C1) << X 716 if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X))))) 717 return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X); 718 719 // (X * C2) << C1 --> X * (C2 << C1) 720 if (match(Op0, m_Mul(m_Value(X), m_Constant(C2)))) 721 return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1)); 722 } 723 724 // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1 725 if (match(Op0, m_One()) && 726 match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X)))) 727 return BinaryOperator::CreateLShr( 728 ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); 729 730 return nullptr; 731 } 732 733 Instruction *InstCombiner::visitLShr(BinaryOperator &I) { 734 if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), 735 SQ.getWithInstruction(&I))) 736 return replaceInstUsesWith(I, V); 737 738 if (Instruction *X = foldVectorBinop(I)) 739 return X; 740 741 if (Instruction *R = commonShiftTransforms(I)) 742 return R; 743 744 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); 745 Type *Ty = I.getType(); 746 const APInt *ShAmtAPInt; 747 if (match(Op1, m_APInt(ShAmtAPInt))) { 748 unsigned ShAmt = ShAmtAPInt->getZExtValue(); 749 unsigned BitWidth = Ty->getScalarSizeInBits(); 750 auto *II = dyn_cast<IntrinsicInst>(Op0); 751 if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt && 752 (II->getIntrinsicID() == Intrinsic::ctlz || 753 II->getIntrinsicID() == Intrinsic::cttz || 754 II->getIntrinsicID() == Intrinsic::ctpop)) { 755 // ctlz.i32(x)>>5 --> zext(x == 0) 756 // cttz.i32(x)>>5 --> zext(x == 0) 757 // ctpop.i32(x)>>5 --> zext(x == -1) 758 bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop; 759 Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0); 760 Value *Cmp = Builder.CreateICmpEQ(II->getArgOperand(0), RHS); 761 return new ZExtInst(Cmp, Ty); 762 } 763 764 Value *X; 765 const APInt *ShOp1; 766 if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) { 767 if (ShOp1->ult(ShAmt)) { 768 unsigned ShlAmt = ShOp1->getZExtValue(); 769 Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); 770 if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { 771 // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1) 772 auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff); 773 NewLShr->setIsExact(I.isExact()); 774 return NewLShr; 775 } 776 // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >> C2) 777 Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact()); 778 APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); 779 return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask)); 780 } 781 if (ShOp1->ugt(ShAmt)) { 782 unsigned ShlAmt = ShOp1->getZExtValue(); 783 Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); 784 if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { 785 // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2) 786 auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); 787 NewShl->setHasNoUnsignedWrap(true); 788 return NewShl; 789 } 790 // (X << C1) >>u C2 --> X << (C1 - C2) & (-1 >> C2) 791 Value *NewShl = Builder.CreateShl(X, ShiftDiff); 792 APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); 793 return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); 794 } 795 assert(*ShOp1 == ShAmt); 796 // (X << C) >>u C --> X & (-1 >>u C) 797 APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); 798 return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); 799 } 800 801 if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) && 802 (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { 803 assert(ShAmt < X->getType()->getScalarSizeInBits() && 804 "Big shift not simplified to zero?"); 805 // lshr (zext iM X to iN), C --> zext (lshr X, C) to iN 806 Value *NewLShr = Builder.CreateLShr(X, ShAmt); 807 return new ZExtInst(NewLShr, Ty); 808 } 809 810 if (match(Op0, m_SExt(m_Value(X))) && 811 (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { 812 // Are we moving the sign bit to the low bit and widening with high zeros? 813 unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits(); 814 if (ShAmt == BitWidth - 1) { 815 // lshr (sext i1 X to iN), N-1 --> zext X to iN 816 if (SrcTyBitWidth == 1) 817 return new ZExtInst(X, Ty); 818 819 // lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN 820 if (Op0->hasOneUse()) { 821 Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1); 822 return new ZExtInst(NewLShr, Ty); 823 } 824 } 825 826 // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN 827 if (ShAmt == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) { 828 // The new shift amount can't be more than the narrow source type. 829 unsigned NewShAmt = std::min(ShAmt, SrcTyBitWidth - 1); 830 Value *AShr = Builder.CreateAShr(X, NewShAmt); 831 return new ZExtInst(AShr, Ty); 832 } 833 } 834 835 if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) { 836 unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); 837 // Oversized shifts are simplified to zero in InstSimplify. 838 if (AmtSum < BitWidth) 839 // (X >>u C1) >>u C2 --> X >>u (C1 + C2) 840 return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum)); 841 } 842 843 // If the shifted-out value is known-zero, then this is an exact shift. 844 if (!I.isExact() && 845 MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { 846 I.setIsExact(); 847 return &I; 848 } 849 } 850 851 // Transform (x << y) >> y to x & (-1 >> y) 852 Value *X; 853 if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) { 854 Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); 855 Value *Mask = Builder.CreateLShr(AllOnes, Op1); 856 return BinaryOperator::CreateAnd(Mask, X); 857 } 858 859 return nullptr; 860 } 861 862 Instruction *InstCombiner::visitAShr(BinaryOperator &I) { 863 if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), 864 SQ.getWithInstruction(&I))) 865 return replaceInstUsesWith(I, V); 866 867 if (Instruction *X = foldVectorBinop(I)) 868 return X; 869 870 if (Instruction *R = commonShiftTransforms(I)) 871 return R; 872 873 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); 874 Type *Ty = I.getType(); 875 unsigned BitWidth = Ty->getScalarSizeInBits(); 876 const APInt *ShAmtAPInt; 877 if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) { 878 unsigned ShAmt = ShAmtAPInt->getZExtValue(); 879 880 // If the shift amount equals the difference in width of the destination 881 // and source scalar types: 882 // ashr (shl (zext X), C), C --> sext X 883 Value *X; 884 if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) && 885 ShAmt == BitWidth - X->getType()->getScalarSizeInBits()) 886 return new SExtInst(X, Ty); 887 888 // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However, 889 // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. 890 const APInt *ShOp1; 891 if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) && 892 ShOp1->ult(BitWidth)) { 893 unsigned ShlAmt = ShOp1->getZExtValue(); 894 if (ShlAmt < ShAmt) { 895 // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1) 896 Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); 897 auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff); 898 NewAShr->setIsExact(I.isExact()); 899 return NewAShr; 900 } 901 if (ShlAmt > ShAmt) { 902 // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2) 903 Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); 904 auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff); 905 NewShl->setHasNoSignedWrap(true); 906 return NewShl; 907 } 908 } 909 910 if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) && 911 ShOp1->ult(BitWidth)) { 912 unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); 913 // Oversized arithmetic shifts replicate the sign bit. 914 AmtSum = std::min(AmtSum, BitWidth - 1); 915 // (X >>s C1) >>s C2 --> X >>s (C1 + C2) 916 return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum)); 917 } 918 919 if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) && 920 (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) { 921 // ashr (sext X), C --> sext (ashr X, C') 922 Type *SrcTy = X->getType(); 923 ShAmt = std::min(ShAmt, SrcTy->getScalarSizeInBits() - 1); 924 Value *NewSh = Builder.CreateAShr(X, ConstantInt::get(SrcTy, ShAmt)); 925 return new SExtInst(NewSh, Ty); 926 } 927 928 // If the shifted-out value is known-zero, then this is an exact shift. 929 if (!I.isExact() && 930 MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { 931 I.setIsExact(); 932 return &I; 933 } 934 } 935 936 // See if we can turn a signed shr into an unsigned shr. 937 if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) 938 return BinaryOperator::CreateLShr(Op0, Op1); 939 940 return nullptr; 941 } 942