1 //===- InstCombineSimplifyDemanded.cpp ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains logic for simplifying instructions based on information 10 // about how they are used. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "InstCombineInternal.h" 15 #include "llvm/Analysis/ValueTracking.h" 16 #include "llvm/IR/GetElementPtrTypeIterator.h" 17 #include "llvm/IR/IntrinsicInst.h" 18 #include "llvm/IR/PatternMatch.h" 19 #include "llvm/Support/KnownBits.h" 20 #include "llvm/Transforms/InstCombine/InstCombiner.h" 21 22 using namespace llvm; 23 using namespace llvm::PatternMatch; 24 25 #define DEBUG_TYPE "instcombine" 26 27 static cl::opt<bool> 28 VerifyKnownBits("instcombine-verify-known-bits", 29 cl::desc("Verify that computeKnownBits() and " 30 "SimplifyDemandedBits() are consistent"), 31 cl::Hidden, cl::init(false)); 32 33 static cl::opt<unsigned> SimplifyDemandedVectorEltsDepthLimit( 34 "instcombine-simplify-vector-elts-depth", 35 cl::desc( 36 "Depth limit when simplifying vector instructions and their operands"), 37 cl::Hidden, cl::init(10)); 38 39 /// Check to see if the specified operand of the specified instruction is a 40 /// constant integer. If so, check to see if there are any bits set in the 41 /// constant that are not demanded. If so, shrink the constant and return true. 42 static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, 43 const APInt &Demanded) { 44 assert(I && "No instruction?"); 45 assert(OpNo < I->getNumOperands() && "Operand index too large"); 46 47 // The operand must be a constant integer or splat integer. 48 Value *Op = I->getOperand(OpNo); 49 const APInt *C; 50 if (!match(Op, m_APInt(C))) 51 return false; 52 53 // If there are no bits set that aren't demanded, nothing to do. 54 if (C->isSubsetOf(Demanded)) 55 return false; 56 57 // This instruction is producing bits that are not demanded. Shrink the RHS. 58 I->setOperand(OpNo, ConstantInt::get(Op->getType(), *C & Demanded)); 59 60 return true; 61 } 62 63 /// Returns the bitwidth of the given scalar or pointer type. For vector types, 64 /// returns the element type's bitwidth. 65 static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { 66 if (unsigned BitWidth = Ty->getScalarSizeInBits()) 67 return BitWidth; 68 69 return DL.getPointerTypeSizeInBits(Ty); 70 } 71 72 /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if 73 /// the instruction has any properties that allow us to simplify its operands. 74 bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst, 75 KnownBits &Known) { 76 APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth())); 77 Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known, 78 SQ.getWithInstruction(&Inst)); 79 if (!V) return false; 80 if (V == &Inst) return true; 81 replaceInstUsesWith(Inst, V); 82 return true; 83 } 84 85 /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if 86 /// the instruction has any properties that allow us to simplify its operands. 87 bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { 88 KnownBits Known(getBitWidth(Inst.getType(), DL)); 89 return SimplifyDemandedInstructionBits(Inst, Known); 90 } 91 92 /// This form of SimplifyDemandedBits simplifies the specified instruction 93 /// operand if possible, updating it in place. It returns true if it made any 94 /// change and false otherwise. 95 bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo, 96 const APInt &DemandedMask, 97 KnownBits &Known, 98 const SimplifyQuery &Q, 99 unsigned Depth) { 100 Use &U = I->getOperandUse(OpNo); 101 Value *V = U.get(); 102 if (isa<Constant>(V)) { 103 llvm::computeKnownBits(V, Known, Q, Depth); 104 return false; 105 } 106 107 Known.resetAll(); 108 if (DemandedMask.isZero()) { 109 // Not demanding any bits from V. 110 replaceUse(U, UndefValue::get(V->getType())); 111 return true; 112 } 113 114 Instruction *VInst = dyn_cast<Instruction>(V); 115 if (!VInst) { 116 llvm::computeKnownBits(V, Known, Q, Depth); 117 return false; 118 } 119 120 if (Depth == MaxAnalysisRecursionDepth) 121 return false; 122 123 Value *NewVal; 124 if (VInst->hasOneUse()) { 125 // If the instruction has one use, we can directly simplify it. 126 NewVal = SimplifyDemandedUseBits(VInst, DemandedMask, Known, Q, Depth); 127 } else { 128 // If there are multiple uses of this instruction, then we can simplify 129 // VInst to some other value, but not modify the instruction. 130 NewVal = 131 SimplifyMultipleUseDemandedBits(VInst, DemandedMask, Known, Q, Depth); 132 } 133 if (!NewVal) return false; 134 if (Instruction* OpInst = dyn_cast<Instruction>(U)) 135 salvageDebugInfo(*OpInst); 136 137 replaceUse(U, NewVal); 138 return true; 139 } 140 141 /// This function attempts to replace V with a simpler value based on the 142 /// demanded bits. When this function is called, it is known that only the bits 143 /// set in DemandedMask of the result of V are ever used downstream. 144 /// Consequently, depending on the mask and V, it may be possible to replace V 145 /// with a constant or one of its operands. In such cases, this function does 146 /// the replacement and returns true. In all other cases, it returns false after 147 /// analyzing the expression and setting KnownOne and known to be one in the 148 /// expression. Known.Zero contains all the bits that are known to be zero in 149 /// the expression. These are provided to potentially allow the caller (which 150 /// might recursively be SimplifyDemandedBits itself) to simplify the 151 /// expression. 152 /// Known.One and Known.Zero always follow the invariant that: 153 /// Known.One & Known.Zero == 0. 154 /// That is, a bit can't be both 1 and 0. The bits in Known.One and Known.Zero 155 /// are accurate even for bits not in DemandedMask. Note 156 /// also that the bitwidth of V, DemandedMask, Known.Zero and Known.One must all 157 /// be the same. 158 /// 159 /// This returns null if it did not change anything and it permits no 160 /// simplification. This returns V itself if it did some simplification of V's 161 /// operands based on the information about what bits are demanded. This returns 162 /// some other non-null value if it found out that V is equal to another value 163 /// in the context where the specified bits are demanded, but not for all users. 164 Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I, 165 const APInt &DemandedMask, 166 KnownBits &Known, 167 const SimplifyQuery &Q, 168 unsigned Depth) { 169 assert(I != nullptr && "Null pointer of Value???"); 170 assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); 171 uint32_t BitWidth = DemandedMask.getBitWidth(); 172 Type *VTy = I->getType(); 173 assert( 174 (!VTy->isIntOrIntVectorTy() || VTy->getScalarSizeInBits() == BitWidth) && 175 Known.getBitWidth() == BitWidth && 176 "Value *V, DemandedMask and Known must have same BitWidth"); 177 178 KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth); 179 180 // Update flags after simplifying an operand based on the fact that some high 181 // order bits are not demanded. 182 auto disableWrapFlagsBasedOnUnusedHighBits = [](Instruction *I, 183 unsigned NLZ) { 184 if (NLZ > 0) { 185 // Disable the nsw and nuw flags here: We can no longer guarantee that 186 // we won't wrap after simplification. Removing the nsw/nuw flags is 187 // legal here because the top bit is not demanded. 188 I->setHasNoSignedWrap(false); 189 I->setHasNoUnsignedWrap(false); 190 } 191 return I; 192 }; 193 194 // If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care 195 // about the high bits of the operands. 196 auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) { 197 unsigned NLZ = DemandedMask.countl_zero(); 198 // Right fill the mask of bits for the operands to demand the most 199 // significant bit and all those below it. 200 DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); 201 if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || 202 SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Q, Depth + 1) || 203 ShrinkDemandedConstant(I, 1, DemandedFromOps) || 204 SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Q, Depth + 1)) { 205 disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); 206 return true; 207 } 208 return false; 209 }; 210 211 switch (I->getOpcode()) { 212 default: 213 llvm::computeKnownBits(I, Known, Q, Depth); 214 break; 215 case Instruction::And: { 216 // If either the LHS or the RHS are Zero, the result is zero. 217 if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Q, Depth + 1) || 218 SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.Zero, LHSKnown, Q, 219 Depth + 1)) 220 return I; 221 222 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 223 Q, Depth); 224 225 // If the client is only demanding bits that we know, return the known 226 // constant. 227 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 228 return Constant::getIntegerValue(VTy, Known.One); 229 230 // If all of the demanded bits are known 1 on one side, return the other. 231 // These bits cannot contribute to the result of the 'and'. 232 if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One)) 233 return I->getOperand(0); 234 if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One)) 235 return I->getOperand(1); 236 237 // If the RHS is a constant, see if we can simplify it. 238 if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnown.Zero)) 239 return I; 240 241 break; 242 } 243 case Instruction::Or: { 244 // If either the LHS or the RHS are One, the result is One. 245 if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Q, Depth + 1) || 246 SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown, Q, 247 Depth + 1)) { 248 // Disjoint flag may not longer hold. 249 I->dropPoisonGeneratingFlags(); 250 return I; 251 } 252 253 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 254 Q, Depth); 255 256 // If the client is only demanding bits that we know, return the known 257 // constant. 258 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 259 return Constant::getIntegerValue(VTy, Known.One); 260 261 // If all of the demanded bits are known zero on one side, return the other. 262 // These bits cannot contribute to the result of the 'or'. 263 if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero)) 264 return I->getOperand(0); 265 if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) 266 return I->getOperand(1); 267 268 // If the RHS is a constant, see if we can simplify it. 269 if (ShrinkDemandedConstant(I, 1, DemandedMask)) 270 return I; 271 272 // Infer disjoint flag if no common bits are set. 273 if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) { 274 WithCache<const Value *> LHSCache(I->getOperand(0), LHSKnown), 275 RHSCache(I->getOperand(1), RHSKnown); 276 if (haveNoCommonBitsSet(LHSCache, RHSCache, Q)) { 277 cast<PossiblyDisjointInst>(I)->setIsDisjoint(true); 278 return I; 279 } 280 } 281 282 break; 283 } 284 case Instruction::Xor: { 285 if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Q, Depth + 1) || 286 SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Q, Depth + 1)) 287 return I; 288 Value *LHS, *RHS; 289 if (DemandedMask == 1 && 290 match(I->getOperand(0), m_Intrinsic<Intrinsic::ctpop>(m_Value(LHS))) && 291 match(I->getOperand(1), m_Intrinsic<Intrinsic::ctpop>(m_Value(RHS)))) { 292 // (ctpop(X) ^ ctpop(Y)) & 1 --> ctpop(X^Y) & 1 293 IRBuilderBase::InsertPointGuard Guard(Builder); 294 Builder.SetInsertPoint(I); 295 auto *Xor = Builder.CreateXor(LHS, RHS); 296 return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, Xor); 297 } 298 299 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 300 Q, Depth); 301 302 // If the client is only demanding bits that we know, return the known 303 // constant. 304 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 305 return Constant::getIntegerValue(VTy, Known.One); 306 307 // If all of the demanded bits are known zero on one side, return the other. 308 // These bits cannot contribute to the result of the 'xor'. 309 if (DemandedMask.isSubsetOf(RHSKnown.Zero)) 310 return I->getOperand(0); 311 if (DemandedMask.isSubsetOf(LHSKnown.Zero)) 312 return I->getOperand(1); 313 314 // If all of the demanded bits are known to be zero on one side or the 315 // other, turn this into an *inclusive* or. 316 // e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0 317 if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) { 318 Instruction *Or = 319 BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1)); 320 if (DemandedMask.isAllOnes()) 321 cast<PossiblyDisjointInst>(Or)->setIsDisjoint(true); 322 Or->takeName(I); 323 return InsertNewInstWith(Or, I->getIterator()); 324 } 325 326 // If all of the demanded bits on one side are known, and all of the set 327 // bits on that side are also known to be set on the other side, turn this 328 // into an AND, as we know the bits will be cleared. 329 // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2 330 if (DemandedMask.isSubsetOf(RHSKnown.Zero|RHSKnown.One) && 331 RHSKnown.One.isSubsetOf(LHSKnown.One)) { 332 Constant *AndC = Constant::getIntegerValue(VTy, 333 ~RHSKnown.One & DemandedMask); 334 Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC); 335 return InsertNewInstWith(And, I->getIterator()); 336 } 337 338 // If the RHS is a constant, see if we can change it. Don't alter a -1 339 // constant because that's a canonical 'not' op, and that is better for 340 // combining, SCEV, and codegen. 341 const APInt *C; 342 if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) { 343 if ((*C | ~DemandedMask).isAllOnes()) { 344 // Force bits to 1 to create a 'not' op. 345 I->setOperand(1, ConstantInt::getAllOnesValue(VTy)); 346 return I; 347 } 348 // If we can't turn this into a 'not', try to shrink the constant. 349 if (ShrinkDemandedConstant(I, 1, DemandedMask)) 350 return I; 351 } 352 353 // If our LHS is an 'and' and if it has one use, and if any of the bits we 354 // are flipping are known to be set, then the xor is just resetting those 355 // bits to zero. We can just knock out bits from the 'and' and the 'xor', 356 // simplifying both of them. 357 if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0))) { 358 ConstantInt *AndRHS, *XorRHS; 359 if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() && 360 match(I->getOperand(1), m_ConstantInt(XorRHS)) && 361 match(LHSInst->getOperand(1), m_ConstantInt(AndRHS)) && 362 (LHSKnown.One & RHSKnown.One & DemandedMask) != 0) { 363 APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask); 364 365 Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue()); 366 Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC); 367 InsertNewInstWith(NewAnd, I->getIterator()); 368 369 Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue()); 370 Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC); 371 return InsertNewInstWith(NewXor, I->getIterator()); 372 } 373 } 374 break; 375 } 376 case Instruction::Select: { 377 if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Q, Depth + 1) || 378 SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Q, Depth + 1)) 379 return I; 380 381 // If the operands are constants, see if we can simplify them. 382 // This is similar to ShrinkDemandedConstant, but for a select we want to 383 // try to keep the selected constants the same as icmp value constants, if 384 // we can. This helps not break apart (or helps put back together) 385 // canonical patterns like min and max. 386 auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo, 387 const APInt &DemandedMask) { 388 const APInt *SelC; 389 if (!match(I->getOperand(OpNo), m_APInt(SelC))) 390 return false; 391 392 // Get the constant out of the ICmp, if there is one. 393 // Only try this when exactly 1 operand is a constant (if both operands 394 // are constant, the icmp should eventually simplify). Otherwise, we may 395 // invert the transform that reduces set bits and infinite-loop. 396 Value *X; 397 const APInt *CmpC; 398 if (!match(I->getOperand(0), m_ICmp(m_Value(X), m_APInt(CmpC))) || 399 isa<Constant>(X) || CmpC->getBitWidth() != SelC->getBitWidth()) 400 return ShrinkDemandedConstant(I, OpNo, DemandedMask); 401 402 // If the constant is already the same as the ICmp, leave it as-is. 403 if (*CmpC == *SelC) 404 return false; 405 // If the constants are not already the same, but can be with the demand 406 // mask, use the constant value from the ICmp. 407 if ((*CmpC & DemandedMask) == (*SelC & DemandedMask)) { 408 I->setOperand(OpNo, ConstantInt::get(I->getType(), *CmpC)); 409 return true; 410 } 411 return ShrinkDemandedConstant(I, OpNo, DemandedMask); 412 }; 413 if (CanonicalizeSelectConstant(I, 1, DemandedMask) || 414 CanonicalizeSelectConstant(I, 2, DemandedMask)) 415 return I; 416 417 // Only known if known in both the LHS and RHS. 418 adjustKnownBitsForSelectArm(LHSKnown, I->getOperand(0), I->getOperand(1), 419 /*Invert=*/false, Q, Depth); 420 adjustKnownBitsForSelectArm(RHSKnown, I->getOperand(0), I->getOperand(2), 421 /*Invert=*/true, Q, Depth); 422 Known = LHSKnown.intersectWith(RHSKnown); 423 break; 424 } 425 case Instruction::Trunc: { 426 // If we do not demand the high bits of a right-shifted and truncated value, 427 // then we may be able to truncate it before the shift. 428 Value *X; 429 const APInt *C; 430 if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { 431 // The shift amount must be valid (not poison) in the narrow type, and 432 // it must not be greater than the high bits demanded of the result. 433 if (C->ult(VTy->getScalarSizeInBits()) && 434 C->ule(DemandedMask.countl_zero())) { 435 // trunc (lshr X, C) --> lshr (trunc X), C 436 IRBuilderBase::InsertPointGuard Guard(Builder); 437 Builder.SetInsertPoint(I); 438 Value *Trunc = Builder.CreateTrunc(X, VTy); 439 return Builder.CreateLShr(Trunc, C->getZExtValue()); 440 } 441 } 442 } 443 [[fallthrough]]; 444 case Instruction::ZExt: { 445 unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); 446 447 APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); 448 KnownBits InputKnown(SrcBitWidth); 449 if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Q, 450 Depth + 1)) { 451 // For zext nneg, we may have dropped the instruction which made the 452 // input non-negative. 453 I->dropPoisonGeneratingFlags(); 454 return I; 455 } 456 assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?"); 457 if (I->getOpcode() == Instruction::ZExt && I->hasNonNeg() && 458 !InputKnown.isNegative()) 459 InputKnown.makeNonNegative(); 460 Known = InputKnown.zextOrTrunc(BitWidth); 461 462 break; 463 } 464 case Instruction::SExt: { 465 // Compute the bits in the result that are not present in the input. 466 unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); 467 468 APInt InputDemandedBits = DemandedMask.trunc(SrcBitWidth); 469 470 // If any of the sign extended bits are demanded, we know that the sign 471 // bit is demanded. 472 if (DemandedMask.getActiveBits() > SrcBitWidth) 473 InputDemandedBits.setBit(SrcBitWidth-1); 474 475 KnownBits InputKnown(SrcBitWidth); 476 if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Q, Depth + 1)) 477 return I; 478 479 // If the input sign bit is known zero, or if the NewBits are not demanded 480 // convert this into a zero extension. 481 if (InputKnown.isNonNegative() || 482 DemandedMask.getActiveBits() <= SrcBitWidth) { 483 // Convert to ZExt cast. 484 CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy); 485 NewCast->takeName(I); 486 return InsertNewInstWith(NewCast, I->getIterator()); 487 } 488 489 // If the sign bit of the input is known set or clear, then we know the 490 // top bits of the result. 491 Known = InputKnown.sext(BitWidth); 492 break; 493 } 494 case Instruction::Add: { 495 if ((DemandedMask & 1) == 0) { 496 // If we do not need the low bit, try to convert bool math to logic: 497 // add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN 498 Value *X, *Y; 499 if (match(I, m_c_Add(m_OneUse(m_ZExt(m_Value(X))), 500 m_OneUse(m_SExt(m_Value(Y))))) && 501 X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) { 502 // Truth table for inputs and output signbits: 503 // X:0 | X:1 504 // ---------- 505 // Y:0 | 0 | 0 | 506 // Y:1 | -1 | 0 | 507 // ---------- 508 IRBuilderBase::InsertPointGuard Guard(Builder); 509 Builder.SetInsertPoint(I); 510 Value *AndNot = Builder.CreateAnd(Builder.CreateNot(X), Y); 511 return Builder.CreateSExt(AndNot, VTy); 512 } 513 514 // add iN (sext i1 X), (sext i1 Y) --> sext (X | Y) to iN 515 if (match(I, m_Add(m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) && 516 X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() && 517 (I->getOperand(0)->hasOneUse() || I->getOperand(1)->hasOneUse())) { 518 519 // Truth table for inputs and output signbits: 520 // X:0 | X:1 521 // ----------- 522 // Y:0 | -1 | -1 | 523 // Y:1 | -1 | 0 | 524 // ----------- 525 IRBuilderBase::InsertPointGuard Guard(Builder); 526 Builder.SetInsertPoint(I); 527 Value *Or = Builder.CreateOr(X, Y); 528 return Builder.CreateSExt(Or, VTy); 529 } 530 } 531 532 // Right fill the mask of bits for the operands to demand the most 533 // significant bit and all those below it. 534 unsigned NLZ = DemandedMask.countl_zero(); 535 APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); 536 if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || 537 SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Q, Depth + 1)) 538 return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); 539 540 // If low order bits are not demanded and known to be zero in one operand, 541 // then we don't need to demand them from the other operand, since they 542 // can't cause overflow into any bits that are demanded in the result. 543 unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one(); 544 APInt DemandedFromLHS = DemandedFromOps; 545 DemandedFromLHS.clearLowBits(NTZ); 546 if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || 547 SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Q, Depth + 1)) 548 return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); 549 550 // If we are known to be adding zeros to every bit below 551 // the highest demanded bit, we just return the other side. 552 if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) 553 return I->getOperand(0); 554 if (DemandedFromOps.isSubsetOf(LHSKnown.Zero)) 555 return I->getOperand(1); 556 557 // (add X, C) --> (xor X, C) IFF C is equal to the top bit of the DemandMask 558 { 559 const APInt *C; 560 if (match(I->getOperand(1), m_APInt(C)) && 561 C->isOneBitSet(DemandedMask.getActiveBits() - 1)) { 562 IRBuilderBase::InsertPointGuard Guard(Builder); 563 Builder.SetInsertPoint(I); 564 return Builder.CreateXor(I->getOperand(0), ConstantInt::get(VTy, *C)); 565 } 566 } 567 568 // Otherwise just compute the known bits of the result. 569 bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); 570 bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); 571 Known = KnownBits::add(LHSKnown, RHSKnown, NSW, NUW); 572 break; 573 } 574 case Instruction::Sub: { 575 // Right fill the mask of bits for the operands to demand the most 576 // significant bit and all those below it. 577 unsigned NLZ = DemandedMask.countl_zero(); 578 APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); 579 if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || 580 SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Q, Depth + 1)) 581 return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); 582 583 // If low order bits are not demanded and are known to be zero in RHS, 584 // then we don't need to demand them from LHS, since they can't cause a 585 // borrow from any bits that are demanded in the result. 586 unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one(); 587 APInt DemandedFromLHS = DemandedFromOps; 588 DemandedFromLHS.clearLowBits(NTZ); 589 if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || 590 SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Q, Depth + 1)) 591 return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); 592 593 // If we are known to be subtracting zeros from every bit below 594 // the highest demanded bit, we just return the other side. 595 if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) 596 return I->getOperand(0); 597 // We can't do this with the LHS for subtraction, unless we are only 598 // demanding the LSB. 599 if (DemandedFromOps.isOne() && DemandedFromOps.isSubsetOf(LHSKnown.Zero)) 600 return I->getOperand(1); 601 602 // Canonicalize sub mask, X -> ~X 603 const APInt *LHSC; 604 if (match(I->getOperand(0), m_LowBitMask(LHSC)) && 605 DemandedFromOps.isSubsetOf(*LHSC)) { 606 IRBuilderBase::InsertPointGuard Guard(Builder); 607 Builder.SetInsertPoint(I); 608 return Builder.CreateNot(I->getOperand(1)); 609 } 610 611 // Otherwise just compute the known bits of the result. 612 bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); 613 bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); 614 Known = KnownBits::sub(LHSKnown, RHSKnown, NSW, NUW); 615 break; 616 } 617 case Instruction::Mul: { 618 APInt DemandedFromOps; 619 if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps)) 620 return I; 621 622 if (DemandedMask.isPowerOf2()) { 623 // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1. 624 // If we demand exactly one bit N and we have "X * (C' << N)" where C' is 625 // odd (has LSB set), then the left-shifted low bit of X is the answer. 626 unsigned CTZ = DemandedMask.countr_zero(); 627 const APInt *C; 628 if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) { 629 Constant *ShiftC = ConstantInt::get(VTy, CTZ); 630 Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC); 631 return InsertNewInstWith(Shl, I->getIterator()); 632 } 633 } 634 // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because: 635 // X * X is odd iff X is odd. 636 // 'Quadratic Reciprocity': X * X -> 0 for bit[1] 637 if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) { 638 Constant *One = ConstantInt::get(VTy, 1); 639 Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One); 640 return InsertNewInstWith(And1, I->getIterator()); 641 } 642 643 llvm::computeKnownBits(I, Known, Q, Depth); 644 break; 645 } 646 case Instruction::Shl: { 647 const APInt *SA; 648 if (match(I->getOperand(1), m_APInt(SA))) { 649 const APInt *ShrAmt; 650 if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt)))) 651 if (Instruction *Shr = dyn_cast<Instruction>(I->getOperand(0))) 652 if (Value *R = simplifyShrShlDemandedBits(Shr, *ShrAmt, I, *SA, 653 DemandedMask, Known)) 654 return R; 655 656 // Do not simplify if shl is part of funnel-shift pattern 657 if (I->hasOneUse()) { 658 auto *Inst = dyn_cast<Instruction>(I->user_back()); 659 if (Inst && Inst->getOpcode() == BinaryOperator::Or) { 660 if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) { 661 auto [IID, FShiftArgs] = *Opt; 662 if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) && 663 FShiftArgs[0] == FShiftArgs[1]) { 664 llvm::computeKnownBits(I, Known, Q, Depth); 665 break; 666 } 667 } 668 } 669 } 670 671 // We only want bits that already match the signbit then we don't 672 // need to shift. 673 uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1); 674 if (DemandedMask.countr_zero() >= ShiftAmt) { 675 if (I->hasNoSignedWrap()) { 676 unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); 677 unsigned SignBits = 678 ComputeNumSignBits(I->getOperand(0), Q.CxtI, Depth + 1); 679 if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumHiDemandedBits) 680 return I->getOperand(0); 681 } 682 683 // If we can pre-shift a right-shifted constant to the left without 684 // losing any high bits and we don't demand the low bits, then eliminate 685 // the left-shift: 686 // (C >> X) << LeftShiftAmtC --> (C << LeftShiftAmtC) >> X 687 Value *X; 688 Constant *C; 689 if (match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) { 690 Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt); 691 Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C, 692 LeftShiftAmtC, DL); 693 if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC, 694 LeftShiftAmtC, DL) == C) { 695 Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X); 696 return InsertNewInstWith(Lshr, I->getIterator()); 697 } 698 } 699 } 700 701 APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt)); 702 703 // If the shift is NUW/NSW, then it does demand the high bits. 704 ShlOperator *IOp = cast<ShlOperator>(I); 705 if (IOp->hasNoSignedWrap()) 706 DemandedMaskIn.setHighBits(ShiftAmt+1); 707 else if (IOp->hasNoUnsignedWrap()) 708 DemandedMaskIn.setHighBits(ShiftAmt); 709 710 if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Q, Depth + 1)) 711 return I; 712 713 Known = KnownBits::shl(Known, 714 KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)), 715 /* NUW */ IOp->hasNoUnsignedWrap(), 716 /* NSW */ IOp->hasNoSignedWrap()); 717 } else { 718 // This is a variable shift, so we can't shift the demand mask by a known 719 // amount. But if we are not demanding high bits, then we are not 720 // demanding those bits from the pre-shifted operand either. 721 if (unsigned CTLZ = DemandedMask.countl_zero()) { 722 APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ)); 723 if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Q, Depth + 1)) { 724 // We can't guarantee that nsw/nuw hold after simplifying the operand. 725 I->dropPoisonGeneratingFlags(); 726 return I; 727 } 728 } 729 llvm::computeKnownBits(I, Known, Q, Depth); 730 } 731 break; 732 } 733 case Instruction::LShr: { 734 const APInt *SA; 735 if (match(I->getOperand(1), m_APInt(SA))) { 736 uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); 737 738 // Do not simplify if lshr is part of funnel-shift pattern 739 if (I->hasOneUse()) { 740 auto *Inst = dyn_cast<Instruction>(I->user_back()); 741 if (Inst && Inst->getOpcode() == BinaryOperator::Or) { 742 if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) { 743 auto [IID, FShiftArgs] = *Opt; 744 if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) && 745 FShiftArgs[0] == FShiftArgs[1]) { 746 llvm::computeKnownBits(I, Known, Q, Depth); 747 break; 748 } 749 } 750 } 751 } 752 753 // If we are just demanding the shifted sign bit and below, then this can 754 // be treated as an ASHR in disguise. 755 if (DemandedMask.countl_zero() >= ShiftAmt) { 756 // If we only want bits that already match the signbit then we don't 757 // need to shift. 758 unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); 759 unsigned SignBits = 760 ComputeNumSignBits(I->getOperand(0), Q.CxtI, Depth + 1); 761 if (SignBits >= NumHiDemandedBits) 762 return I->getOperand(0); 763 764 // If we can pre-shift a left-shifted constant to the right without 765 // losing any low bits (we already know we don't demand the high bits), 766 // then eliminate the right-shift: 767 // (C << X) >> RightShiftAmtC --> (C >> RightShiftAmtC) << X 768 Value *X; 769 Constant *C; 770 if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) { 771 Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt); 772 Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::LShr, C, 773 RightShiftAmtC, DL); 774 if (ConstantFoldBinaryOpOperands(Instruction::Shl, NewC, 775 RightShiftAmtC, DL) == C) { 776 Instruction *Shl = BinaryOperator::CreateShl(NewC, X); 777 return InsertNewInstWith(Shl, I->getIterator()); 778 } 779 } 780 781 const APInt *Factor; 782 if (match(I->getOperand(0), 783 m_OneUse(m_Mul(m_Value(X), m_APInt(Factor)))) && 784 Factor->countr_zero() >= ShiftAmt) { 785 BinaryOperator *Mul = BinaryOperator::CreateMul( 786 X, ConstantInt::get(X->getType(), Factor->lshr(ShiftAmt))); 787 return InsertNewInstWith(Mul, I->getIterator()); 788 } 789 } 790 791 // Unsigned shift right. 792 APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); 793 if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Q, Depth + 1)) { 794 // exact flag may not longer hold. 795 I->dropPoisonGeneratingFlags(); 796 return I; 797 } 798 Known.Zero.lshrInPlace(ShiftAmt); 799 Known.One.lshrInPlace(ShiftAmt); 800 if (ShiftAmt) 801 Known.Zero.setHighBits(ShiftAmt); // high bits known zero. 802 } else { 803 llvm::computeKnownBits(I, Known, Q, Depth); 804 } 805 break; 806 } 807 case Instruction::AShr: { 808 unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Q.CxtI, Depth + 1); 809 810 // If we only want bits that already match the signbit then we don't need 811 // to shift. 812 unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); 813 if (SignBits >= NumHiDemandedBits) 814 return I->getOperand(0); 815 816 // If this is an arithmetic shift right and only the low-bit is set, we can 817 // always convert this into a logical shr, even if the shift amount is 818 // variable. The low bit of the shift cannot be an input sign bit unless 819 // the shift amount is >= the size of the datatype, which is undefined. 820 if (DemandedMask.isOne()) { 821 // Perform the logical shift right. 822 Instruction *NewVal = BinaryOperator::CreateLShr( 823 I->getOperand(0), I->getOperand(1), I->getName()); 824 return InsertNewInstWith(NewVal, I->getIterator()); 825 } 826 827 const APInt *SA; 828 if (match(I->getOperand(1), m_APInt(SA))) { 829 uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1); 830 831 // Signed shift right. 832 APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); 833 // If any of the bits being shifted in are demanded, then we should set 834 // the sign bit as demanded. 835 bool ShiftedInBitsDemanded = DemandedMask.countl_zero() < ShiftAmt; 836 if (ShiftedInBitsDemanded) 837 DemandedMaskIn.setSignBit(); 838 if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Q, Depth + 1)) { 839 // exact flag may not longer hold. 840 I->dropPoisonGeneratingFlags(); 841 return I; 842 } 843 844 // If the input sign bit is known to be zero, or if none of the shifted in 845 // bits are demanded, turn this into an unsigned shift right. 846 if (Known.Zero[BitWidth - 1] || !ShiftedInBitsDemanded) { 847 BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0), 848 I->getOperand(1)); 849 LShr->setIsExact(cast<BinaryOperator>(I)->isExact()); 850 LShr->takeName(I); 851 return InsertNewInstWith(LShr, I->getIterator()); 852 } 853 854 Known = KnownBits::ashr( 855 Known, KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)), 856 ShiftAmt != 0, I->isExact()); 857 } else { 858 llvm::computeKnownBits(I, Known, Q, Depth); 859 } 860 break; 861 } 862 case Instruction::UDiv: { 863 // UDiv doesn't demand low bits that are zero in the divisor. 864 const APInt *SA; 865 if (match(I->getOperand(1), m_APInt(SA))) { 866 // TODO: Take the demanded mask of the result into account. 867 unsigned RHSTrailingZeros = SA->countr_zero(); 868 APInt DemandedMaskIn = 869 APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros); 870 if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Q, Depth + 1)) { 871 // We can't guarantee that "exact" is still true after changing the 872 // the dividend. 873 I->dropPoisonGeneratingFlags(); 874 return I; 875 } 876 877 Known = KnownBits::udiv(LHSKnown, KnownBits::makeConstant(*SA), 878 cast<BinaryOperator>(I)->isExact()); 879 } else { 880 llvm::computeKnownBits(I, Known, Q, Depth); 881 } 882 break; 883 } 884 case Instruction::SRem: { 885 const APInt *Rem; 886 if (match(I->getOperand(1), m_APInt(Rem)) && Rem->isPowerOf2()) { 887 if (DemandedMask.ult(*Rem)) // srem won't affect demanded bits 888 return I->getOperand(0); 889 890 APInt LowBits = *Rem - 1; 891 APInt Mask2 = LowBits | APInt::getSignMask(BitWidth); 892 if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Q, Depth + 1)) 893 return I; 894 Known = KnownBits::srem(LHSKnown, KnownBits::makeConstant(*Rem)); 895 break; 896 } 897 898 llvm::computeKnownBits(I, Known, Q, Depth); 899 break; 900 } 901 case Instruction::Call: { 902 bool KnownBitsComputed = false; 903 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { 904 switch (II->getIntrinsicID()) { 905 case Intrinsic::abs: { 906 if (DemandedMask == 1) 907 return II->getArgOperand(0); 908 break; 909 } 910 case Intrinsic::ctpop: { 911 // Checking if the number of clear bits is odd (parity)? If the type has 912 // an even number of bits, that's the same as checking if the number of 913 // set bits is odd, so we can eliminate the 'not' op. 914 Value *X; 915 if (DemandedMask == 1 && VTy->getScalarSizeInBits() % 2 == 0 && 916 match(II->getArgOperand(0), m_Not(m_Value(X)))) { 917 Function *Ctpop = Intrinsic::getOrInsertDeclaration( 918 II->getModule(), Intrinsic::ctpop, VTy); 919 return InsertNewInstWith(CallInst::Create(Ctpop, {X}), I->getIterator()); 920 } 921 break; 922 } 923 case Intrinsic::bswap: { 924 // If the only bits demanded come from one byte of the bswap result, 925 // just shift the input byte into position to eliminate the bswap. 926 unsigned NLZ = DemandedMask.countl_zero(); 927 unsigned NTZ = DemandedMask.countr_zero(); 928 929 // Round NTZ down to the next byte. If we have 11 trailing zeros, then 930 // we need all the bits down to bit 8. Likewise, round NLZ. If we 931 // have 14 leading zeros, round to 8. 932 NLZ = alignDown(NLZ, 8); 933 NTZ = alignDown(NTZ, 8); 934 // If we need exactly one byte, we can do this transformation. 935 if (BitWidth - NLZ - NTZ == 8) { 936 // Replace this with either a left or right shift to get the byte into 937 // the right place. 938 Instruction *NewVal; 939 if (NLZ > NTZ) 940 NewVal = BinaryOperator::CreateLShr( 941 II->getArgOperand(0), ConstantInt::get(VTy, NLZ - NTZ)); 942 else 943 NewVal = BinaryOperator::CreateShl( 944 II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ)); 945 NewVal->takeName(I); 946 return InsertNewInstWith(NewVal, I->getIterator()); 947 } 948 break; 949 } 950 case Intrinsic::ptrmask: { 951 unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits(); 952 RHSKnown = KnownBits(MaskWidth); 953 // If either the LHS or the RHS are Zero, the result is zero. 954 if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Q, Depth + 1) || 955 SimplifyDemandedBits( 956 I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth), 957 RHSKnown, Q, Depth + 1)) 958 return I; 959 960 // TODO: Should be 1-extend 961 RHSKnown = RHSKnown.anyextOrTrunc(BitWidth); 962 963 Known = LHSKnown & RHSKnown; 964 KnownBitsComputed = true; 965 966 // If the client is only demanding bits we know to be zero, return 967 // `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer 968 // provenance, but making the mask zero will be easily optimizable in 969 // the backend. 970 if (DemandedMask.isSubsetOf(Known.Zero) && 971 !match(I->getOperand(1), m_Zero())) 972 return replaceOperand( 973 *I, 1, Constant::getNullValue(I->getOperand(1)->getType())); 974 975 // Mask in demanded space does nothing. 976 // NOTE: We may have attributes associated with the return value of the 977 // llvm.ptrmask intrinsic that will be lost when we just return the 978 // operand. We should try to preserve them. 979 if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) 980 return I->getOperand(0); 981 982 // If the RHS is a constant, see if we can simplify it. 983 if (ShrinkDemandedConstant( 984 I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth))) 985 return I; 986 987 // Combine: 988 // (ptrmask (getelementptr i8, ptr p, imm i), imm mask) 989 // -> (ptrmask (getelementptr i8, ptr p, imm (i & mask)), imm mask) 990 // where only the low bits known to be zero in the pointer are changed 991 Value *InnerPtr; 992 uint64_t GEPIndex; 993 uint64_t PtrMaskImmediate; 994 if (match(I, m_Intrinsic<Intrinsic::ptrmask>( 995 m_PtrAdd(m_Value(InnerPtr), m_ConstantInt(GEPIndex)), 996 m_ConstantInt(PtrMaskImmediate)))) { 997 998 LHSKnown = computeKnownBits(InnerPtr, I, Depth + 1); 999 if (!LHSKnown.isZero()) { 1000 const unsigned trailingZeros = LHSKnown.countMinTrailingZeros(); 1001 uint64_t PointerAlignBits = (uint64_t(1) << trailingZeros) - 1; 1002 1003 uint64_t HighBitsGEPIndex = GEPIndex & ~PointerAlignBits; 1004 uint64_t MaskedLowBitsGEPIndex = 1005 GEPIndex & PointerAlignBits & PtrMaskImmediate; 1006 1007 uint64_t MaskedGEPIndex = HighBitsGEPIndex | MaskedLowBitsGEPIndex; 1008 1009 if (MaskedGEPIndex != GEPIndex) { 1010 auto *GEP = cast<GEPOperator>(II->getArgOperand(0)); 1011 Builder.SetInsertPoint(I); 1012 Type *GEPIndexType = 1013 DL.getIndexType(GEP->getPointerOperand()->getType()); 1014 Value *MaskedGEP = Builder.CreateGEP( 1015 GEP->getSourceElementType(), InnerPtr, 1016 ConstantInt::get(GEPIndexType, MaskedGEPIndex), 1017 GEP->getName(), GEP->isInBounds()); 1018 1019 replaceOperand(*I, 0, MaskedGEP); 1020 return I; 1021 } 1022 } 1023 } 1024 1025 break; 1026 } 1027 1028 case Intrinsic::fshr: 1029 case Intrinsic::fshl: { 1030 const APInt *SA; 1031 if (!match(I->getOperand(2), m_APInt(SA))) 1032 break; 1033 1034 // Normalize to funnel shift left. APInt shifts of BitWidth are well- 1035 // defined, so no need to special-case zero shifts here. 1036 uint64_t ShiftAmt = SA->urem(BitWidth); 1037 if (II->getIntrinsicID() == Intrinsic::fshr) 1038 ShiftAmt = BitWidth - ShiftAmt; 1039 1040 APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt)); 1041 APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt)); 1042 if (I->getOperand(0) != I->getOperand(1)) { 1043 if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, Q, 1044 Depth + 1) || 1045 SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Q, 1046 Depth + 1)) { 1047 // Range attribute may no longer hold. 1048 I->dropPoisonGeneratingReturnAttributes(); 1049 return I; 1050 } 1051 } else { // fshl is a rotate 1052 // Avoid converting rotate into funnel shift. 1053 // Only simplify if one operand is constant. 1054 LHSKnown = computeKnownBits(I->getOperand(0), I, Depth + 1); 1055 if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) && 1056 !match(I->getOperand(0), m_SpecificInt(LHSKnown.One))) { 1057 replaceOperand(*I, 0, Constant::getIntegerValue(VTy, LHSKnown.One)); 1058 return I; 1059 } 1060 1061 RHSKnown = computeKnownBits(I->getOperand(1), I, Depth + 1); 1062 if (DemandedMaskRHS.isSubsetOf(RHSKnown.Zero | RHSKnown.One) && 1063 !match(I->getOperand(1), m_SpecificInt(RHSKnown.One))) { 1064 replaceOperand(*I, 1, Constant::getIntegerValue(VTy, RHSKnown.One)); 1065 return I; 1066 } 1067 } 1068 1069 Known.Zero = LHSKnown.Zero.shl(ShiftAmt) | 1070 RHSKnown.Zero.lshr(BitWidth - ShiftAmt); 1071 Known.One = LHSKnown.One.shl(ShiftAmt) | 1072 RHSKnown.One.lshr(BitWidth - ShiftAmt); 1073 KnownBitsComputed = true; 1074 break; 1075 } 1076 case Intrinsic::umax: { 1077 // UMax(A, C) == A if ... 1078 // The lowest non-zero bit of DemandMask is higher than the highest 1079 // non-zero bit of C. 1080 const APInt *C; 1081 unsigned CTZ = DemandedMask.countr_zero(); 1082 if (match(II->getArgOperand(1), m_APInt(C)) && 1083 CTZ >= C->getActiveBits()) 1084 return II->getArgOperand(0); 1085 break; 1086 } 1087 case Intrinsic::umin: { 1088 // UMin(A, C) == A if ... 1089 // The lowest non-zero bit of DemandMask is higher than the highest 1090 // non-one bit of C. 1091 // This comes from using DeMorgans on the above umax example. 1092 const APInt *C; 1093 unsigned CTZ = DemandedMask.countr_zero(); 1094 if (match(II->getArgOperand(1), m_APInt(C)) && 1095 CTZ >= C->getBitWidth() - C->countl_one()) 1096 return II->getArgOperand(0); 1097 break; 1098 } 1099 default: { 1100 // Handle target specific intrinsics 1101 std::optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( 1102 *II, DemandedMask, Known, KnownBitsComputed); 1103 if (V) 1104 return *V; 1105 break; 1106 } 1107 } 1108 } 1109 1110 if (!KnownBitsComputed) 1111 llvm::computeKnownBits(I, Known, Q, Depth); 1112 break; 1113 } 1114 } 1115 1116 if (I->getType()->isPointerTy()) { 1117 Align Alignment = I->getPointerAlignment(DL); 1118 Known.Zero.setLowBits(Log2(Alignment)); 1119 } 1120 1121 // If the client is only demanding bits that we know, return the known 1122 // constant. We can't directly simplify pointers as a constant because of 1123 // pointer provenance. 1124 // TODO: We could return `(inttoptr const)` for pointers. 1125 if (!I->getType()->isPointerTy() && 1126 DemandedMask.isSubsetOf(Known.Zero | Known.One)) 1127 return Constant::getIntegerValue(VTy, Known.One); 1128 1129 if (VerifyKnownBits) { 1130 KnownBits ReferenceKnown = llvm::computeKnownBits(I, Q, Depth); 1131 if (Known != ReferenceKnown) { 1132 errs() << "Mismatched known bits for " << *I << " in " 1133 << I->getFunction()->getName() << "\n"; 1134 errs() << "computeKnownBits(): " << ReferenceKnown << "\n"; 1135 errs() << "SimplifyDemandedBits(): " << Known << "\n"; 1136 std::abort(); 1137 } 1138 } 1139 1140 return nullptr; 1141 } 1142 1143 /// Helper routine of SimplifyDemandedUseBits. It computes Known 1144 /// bits. It also tries to handle simplifications that can be done based on 1145 /// DemandedMask, but without modifying the Instruction. 1146 Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( 1147 Instruction *I, const APInt &DemandedMask, KnownBits &Known, 1148 const SimplifyQuery &Q, unsigned Depth) { 1149 unsigned BitWidth = DemandedMask.getBitWidth(); 1150 Type *ITy = I->getType(); 1151 1152 KnownBits LHSKnown(BitWidth); 1153 KnownBits RHSKnown(BitWidth); 1154 1155 // Despite the fact that we can't simplify this instruction in all User's 1156 // context, we can at least compute the known bits, and we can 1157 // do simplifications that apply to *just* the one user if we know that 1158 // this instruction has a simpler value in that context. 1159 switch (I->getOpcode()) { 1160 case Instruction::And: { 1161 llvm::computeKnownBits(I->getOperand(1), RHSKnown, Q, Depth + 1); 1162 llvm::computeKnownBits(I->getOperand(0), LHSKnown, Q, Depth + 1); 1163 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 1164 Q, Depth); 1165 computeKnownBitsFromContext(I, Known, Q, Depth); 1166 1167 // If the client is only demanding bits that we know, return the known 1168 // constant. 1169 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 1170 return Constant::getIntegerValue(ITy, Known.One); 1171 1172 // If all of the demanded bits are known 1 on one side, return the other. 1173 // These bits cannot contribute to the result of the 'and' in this context. 1174 if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One)) 1175 return I->getOperand(0); 1176 if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One)) 1177 return I->getOperand(1); 1178 1179 break; 1180 } 1181 case Instruction::Or: { 1182 llvm::computeKnownBits(I->getOperand(1), RHSKnown, Q, Depth + 1); 1183 llvm::computeKnownBits(I->getOperand(0), LHSKnown, Q, Depth + 1); 1184 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 1185 Q, Depth); 1186 computeKnownBitsFromContext(I, Known, Q, Depth); 1187 1188 // If the client is only demanding bits that we know, return the known 1189 // constant. 1190 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 1191 return Constant::getIntegerValue(ITy, Known.One); 1192 1193 // We can simplify (X|Y) -> X or Y in the user's context if we know that 1194 // only bits from X or Y are demanded. 1195 // If all of the demanded bits are known zero on one side, return the other. 1196 // These bits cannot contribute to the result of the 'or' in this context. 1197 if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero)) 1198 return I->getOperand(0); 1199 if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) 1200 return I->getOperand(1); 1201 1202 break; 1203 } 1204 case Instruction::Xor: { 1205 llvm::computeKnownBits(I->getOperand(1), RHSKnown, Q, Depth + 1); 1206 llvm::computeKnownBits(I->getOperand(0), LHSKnown, Q, Depth + 1); 1207 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 1208 Q, Depth); 1209 computeKnownBitsFromContext(I, Known, Q, Depth); 1210 1211 // If the client is only demanding bits that we know, return the known 1212 // constant. 1213 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 1214 return Constant::getIntegerValue(ITy, Known.One); 1215 1216 // We can simplify (X^Y) -> X or Y in the user's context if we know that 1217 // only bits from X or Y are demanded. 1218 // If all of the demanded bits are known zero on one side, return the other. 1219 if (DemandedMask.isSubsetOf(RHSKnown.Zero)) 1220 return I->getOperand(0); 1221 if (DemandedMask.isSubsetOf(LHSKnown.Zero)) 1222 return I->getOperand(1); 1223 1224 break; 1225 } 1226 case Instruction::Add: { 1227 unsigned NLZ = DemandedMask.countl_zero(); 1228 APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); 1229 1230 // If an operand adds zeros to every bit below the highest demanded bit, 1231 // that operand doesn't change the result. Return the other side. 1232 llvm::computeKnownBits(I->getOperand(1), RHSKnown, Q, Depth + 1); 1233 if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) 1234 return I->getOperand(0); 1235 1236 llvm::computeKnownBits(I->getOperand(0), LHSKnown, Q, Depth + 1); 1237 if (DemandedFromOps.isSubsetOf(LHSKnown.Zero)) 1238 return I->getOperand(1); 1239 1240 bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); 1241 bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); 1242 Known = KnownBits::add(LHSKnown, RHSKnown, NSW, NUW); 1243 computeKnownBitsFromContext(I, Known, Q, Depth); 1244 break; 1245 } 1246 case Instruction::Sub: { 1247 unsigned NLZ = DemandedMask.countl_zero(); 1248 APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); 1249 1250 // If an operand subtracts zeros from every bit below the highest demanded 1251 // bit, that operand doesn't change the result. Return the other side. 1252 llvm::computeKnownBits(I->getOperand(1), RHSKnown, Q, Depth + 1); 1253 if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) 1254 return I->getOperand(0); 1255 1256 bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); 1257 bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); 1258 llvm::computeKnownBits(I->getOperand(0), LHSKnown, Q, Depth + 1); 1259 Known = KnownBits::sub(LHSKnown, RHSKnown, NSW, NUW); 1260 computeKnownBitsFromContext(I, Known, Q, Depth); 1261 break; 1262 } 1263 case Instruction::AShr: { 1264 // Compute the Known bits to simplify things downstream. 1265 llvm::computeKnownBits(I, Known, Q, Depth); 1266 1267 // If this user is only demanding bits that we know, return the known 1268 // constant. 1269 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 1270 return Constant::getIntegerValue(ITy, Known.One); 1271 1272 // If the right shift operand 0 is a result of a left shift by the same 1273 // amount, this is probably a zero/sign extension, which may be unnecessary, 1274 // if we do not demand any of the new sign bits. So, return the original 1275 // operand instead. 1276 const APInt *ShiftRC; 1277 const APInt *ShiftLC; 1278 Value *X; 1279 unsigned BitWidth = DemandedMask.getBitWidth(); 1280 if (match(I, 1281 m_AShr(m_Shl(m_Value(X), m_APInt(ShiftLC)), m_APInt(ShiftRC))) && 1282 ShiftLC == ShiftRC && ShiftLC->ult(BitWidth) && 1283 DemandedMask.isSubsetOf(APInt::getLowBitsSet( 1284 BitWidth, BitWidth - ShiftRC->getZExtValue()))) { 1285 return X; 1286 } 1287 1288 break; 1289 } 1290 default: 1291 // Compute the Known bits to simplify things downstream. 1292 llvm::computeKnownBits(I, Known, Q, Depth); 1293 1294 // If this user is only demanding bits that we know, return the known 1295 // constant. 1296 if (DemandedMask.isSubsetOf(Known.Zero|Known.One)) 1297 return Constant::getIntegerValue(ITy, Known.One); 1298 1299 break; 1300 } 1301 1302 return nullptr; 1303 } 1304 1305 /// Helper routine of SimplifyDemandedUseBits. It tries to simplify 1306 /// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into 1307 /// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign 1308 /// of "C2-C1". 1309 /// 1310 /// Suppose E1 and E2 are generally different in bits S={bm, bm+1, 1311 /// ..., bn}, without considering the specific value X is holding. 1312 /// This transformation is legal iff one of following conditions is hold: 1313 /// 1) All the bit in S are 0, in this case E1 == E2. 1314 /// 2) We don't care those bits in S, per the input DemandedMask. 1315 /// 3) Combination of 1) and 2). Some bits in S are 0, and we don't care the 1316 /// rest bits. 1317 /// 1318 /// Currently we only test condition 2). 1319 /// 1320 /// As with SimplifyDemandedUseBits, it returns NULL if the simplification was 1321 /// not successful. 1322 Value *InstCombinerImpl::simplifyShrShlDemandedBits( 1323 Instruction *Shr, const APInt &ShrOp1, Instruction *Shl, 1324 const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known) { 1325 if (!ShlOp1 || !ShrOp1) 1326 return nullptr; // No-op. 1327 1328 Value *VarX = Shr->getOperand(0); 1329 Type *Ty = VarX->getType(); 1330 unsigned BitWidth = Ty->getScalarSizeInBits(); 1331 if (ShlOp1.uge(BitWidth) || ShrOp1.uge(BitWidth)) 1332 return nullptr; // Undef. 1333 1334 unsigned ShlAmt = ShlOp1.getZExtValue(); 1335 unsigned ShrAmt = ShrOp1.getZExtValue(); 1336 1337 Known.One.clearAllBits(); 1338 Known.Zero.setLowBits(ShlAmt - 1); 1339 Known.Zero &= DemandedMask; 1340 1341 APInt BitMask1(APInt::getAllOnes(BitWidth)); 1342 APInt BitMask2(APInt::getAllOnes(BitWidth)); 1343 1344 bool isLshr = (Shr->getOpcode() == Instruction::LShr); 1345 BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) : 1346 (BitMask1.ashr(ShrAmt) << ShlAmt); 1347 1348 if (ShrAmt <= ShlAmt) { 1349 BitMask2 <<= (ShlAmt - ShrAmt); 1350 } else { 1351 BitMask2 = isLshr ? BitMask2.lshr(ShrAmt - ShlAmt): 1352 BitMask2.ashr(ShrAmt - ShlAmt); 1353 } 1354 1355 // Check if condition-2 (see the comment to this function) is satified. 1356 if ((BitMask1 & DemandedMask) == (BitMask2 & DemandedMask)) { 1357 if (ShrAmt == ShlAmt) 1358 return VarX; 1359 1360 if (!Shr->hasOneUse()) 1361 return nullptr; 1362 1363 BinaryOperator *New; 1364 if (ShrAmt < ShlAmt) { 1365 Constant *Amt = ConstantInt::get(VarX->getType(), ShlAmt - ShrAmt); 1366 New = BinaryOperator::CreateShl(VarX, Amt); 1367 BinaryOperator *Orig = cast<BinaryOperator>(Shl); 1368 New->setHasNoSignedWrap(Orig->hasNoSignedWrap()); 1369 New->setHasNoUnsignedWrap(Orig->hasNoUnsignedWrap()); 1370 } else { 1371 Constant *Amt = ConstantInt::get(VarX->getType(), ShrAmt - ShlAmt); 1372 New = isLshr ? BinaryOperator::CreateLShr(VarX, Amt) : 1373 BinaryOperator::CreateAShr(VarX, Amt); 1374 if (cast<BinaryOperator>(Shr)->isExact()) 1375 New->setIsExact(true); 1376 } 1377 1378 return InsertNewInstWith(New, Shl->getIterator()); 1379 } 1380 1381 return nullptr; 1382 } 1383 1384 /// The specified value produces a vector with any number of elements. 1385 /// This method analyzes which elements of the operand are poison and 1386 /// returns that information in PoisonElts. 1387 /// 1388 /// DemandedElts contains the set of elements that are actually used by the 1389 /// caller, and by default (AllowMultipleUsers equals false) the value is 1390 /// simplified only if it has a single caller. If AllowMultipleUsers is set 1391 /// to true, DemandedElts refers to the union of sets of elements that are 1392 /// used by all callers. 1393 /// 1394 /// If the information about demanded elements can be used to simplify the 1395 /// operation, the operation is simplified, then the resultant value is 1396 /// returned. This returns null if no change was made. 1397 Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, 1398 APInt DemandedElts, 1399 APInt &PoisonElts, 1400 unsigned Depth, 1401 bool AllowMultipleUsers) { 1402 // Cannot analyze scalable type. The number of vector elements is not a 1403 // compile-time constant. 1404 if (isa<ScalableVectorType>(V->getType())) 1405 return nullptr; 1406 1407 unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); 1408 APInt EltMask(APInt::getAllOnes(VWidth)); 1409 assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); 1410 1411 if (match(V, m_Poison())) { 1412 // If the entire vector is poison, just return this info. 1413 PoisonElts = EltMask; 1414 return nullptr; 1415 } 1416 1417 if (DemandedElts.isZero()) { // If nothing is demanded, provide poison. 1418 PoisonElts = EltMask; 1419 return PoisonValue::get(V->getType()); 1420 } 1421 1422 PoisonElts = 0; 1423 1424 if (auto *C = dyn_cast<Constant>(V)) { 1425 // Check if this is identity. If so, return 0 since we are not simplifying 1426 // anything. 1427 if (DemandedElts.isAllOnes()) 1428 return nullptr; 1429 1430 Type *EltTy = cast<VectorType>(V->getType())->getElementType(); 1431 Constant *Poison = PoisonValue::get(EltTy); 1432 SmallVector<Constant*, 16> Elts; 1433 for (unsigned i = 0; i != VWidth; ++i) { 1434 if (!DemandedElts[i]) { // If not demanded, set to poison. 1435 Elts.push_back(Poison); 1436 PoisonElts.setBit(i); 1437 continue; 1438 } 1439 1440 Constant *Elt = C->getAggregateElement(i); 1441 if (!Elt) return nullptr; 1442 1443 Elts.push_back(Elt); 1444 if (isa<PoisonValue>(Elt)) // Already poison. 1445 PoisonElts.setBit(i); 1446 } 1447 1448 // If we changed the constant, return it. 1449 Constant *NewCV = ConstantVector::get(Elts); 1450 return NewCV != C ? NewCV : nullptr; 1451 } 1452 1453 // Limit search depth. 1454 if (Depth == SimplifyDemandedVectorEltsDepthLimit) 1455 return nullptr; 1456 1457 if (!AllowMultipleUsers) { 1458 // If multiple users are using the root value, proceed with 1459 // simplification conservatively assuming that all elements 1460 // are needed. 1461 if (!V->hasOneUse()) { 1462 // Quit if we find multiple users of a non-root value though. 1463 // They'll be handled when it's their turn to be visited by 1464 // the main instcombine process. 1465 if (Depth != 0) 1466 // TODO: Just compute the PoisonElts information recursively. 1467 return nullptr; 1468 1469 // Conservatively assume that all elements are needed. 1470 DemandedElts = EltMask; 1471 } 1472 } 1473 1474 Instruction *I = dyn_cast<Instruction>(V); 1475 if (!I) return nullptr; // Only analyze instructions. 1476 1477 bool MadeChange = false; 1478 auto simplifyAndSetOp = [&](Instruction *Inst, unsigned OpNum, 1479 APInt Demanded, APInt &Undef) { 1480 auto *II = dyn_cast<IntrinsicInst>(Inst); 1481 Value *Op = II ? II->getArgOperand(OpNum) : Inst->getOperand(OpNum); 1482 if (Value *V = SimplifyDemandedVectorElts(Op, Demanded, Undef, Depth + 1)) { 1483 replaceOperand(*Inst, OpNum, V); 1484 MadeChange = true; 1485 } 1486 }; 1487 1488 APInt PoisonElts2(VWidth, 0); 1489 APInt PoisonElts3(VWidth, 0); 1490 switch (I->getOpcode()) { 1491 default: break; 1492 1493 case Instruction::GetElementPtr: { 1494 // The LangRef requires that struct geps have all constant indices. As 1495 // such, we can't convert any operand to partial undef. 1496 auto mayIndexStructType = [](GetElementPtrInst &GEP) { 1497 for (auto I = gep_type_begin(GEP), E = gep_type_end(GEP); 1498 I != E; I++) 1499 if (I.isStruct()) 1500 return true; 1501 return false; 1502 }; 1503 if (mayIndexStructType(cast<GetElementPtrInst>(*I))) 1504 break; 1505 1506 // Conservatively track the demanded elements back through any vector 1507 // operands we may have. We know there must be at least one, or we 1508 // wouldn't have a vector result to get here. Note that we intentionally 1509 // merge the undef bits here since gepping with either an poison base or 1510 // index results in poison. 1511 for (unsigned i = 0; i < I->getNumOperands(); i++) { 1512 if (i == 0 ? match(I->getOperand(i), m_Undef()) 1513 : match(I->getOperand(i), m_Poison())) { 1514 // If the entire vector is undefined, just return this info. 1515 PoisonElts = EltMask; 1516 return nullptr; 1517 } 1518 if (I->getOperand(i)->getType()->isVectorTy()) { 1519 APInt PoisonEltsOp(VWidth, 0); 1520 simplifyAndSetOp(I, i, DemandedElts, PoisonEltsOp); 1521 // gep(x, undef) is not undef, so skip considering idx ops here 1522 // Note that we could propagate poison, but we can't distinguish between 1523 // undef & poison bits ATM 1524 if (i == 0) 1525 PoisonElts |= PoisonEltsOp; 1526 } 1527 } 1528 1529 break; 1530 } 1531 case Instruction::InsertElement: { 1532 // If this is a variable index, we don't know which element it overwrites. 1533 // demand exactly the same input as we produce. 1534 ConstantInt *Idx = dyn_cast<ConstantInt>(I->getOperand(2)); 1535 if (!Idx) { 1536 // Note that we can't propagate undef elt info, because we don't know 1537 // which elt is getting updated. 1538 simplifyAndSetOp(I, 0, DemandedElts, PoisonElts2); 1539 break; 1540 } 1541 1542 // The element inserted overwrites whatever was there, so the input demanded 1543 // set is simpler than the output set. 1544 unsigned IdxNo = Idx->getZExtValue(); 1545 APInt PreInsertDemandedElts = DemandedElts; 1546 if (IdxNo < VWidth) 1547 PreInsertDemandedElts.clearBit(IdxNo); 1548 1549 // If we only demand the element that is being inserted and that element 1550 // was extracted from the same index in another vector with the same type, 1551 // replace this insert with that other vector. 1552 // Note: This is attempted before the call to simplifyAndSetOp because that 1553 // may change PoisonElts to a value that does not match with Vec. 1554 Value *Vec; 1555 if (PreInsertDemandedElts == 0 && 1556 match(I->getOperand(1), 1557 m_ExtractElt(m_Value(Vec), m_SpecificInt(IdxNo))) && 1558 Vec->getType() == I->getType()) { 1559 return Vec; 1560 } 1561 1562 simplifyAndSetOp(I, 0, PreInsertDemandedElts, PoisonElts); 1563 1564 // If this is inserting an element that isn't demanded, remove this 1565 // insertelement. 1566 if (IdxNo >= VWidth || !DemandedElts[IdxNo]) { 1567 Worklist.push(I); 1568 return I->getOperand(0); 1569 } 1570 1571 // The inserted element is defined. 1572 PoisonElts.clearBit(IdxNo); 1573 break; 1574 } 1575 case Instruction::ShuffleVector: { 1576 auto *Shuffle = cast<ShuffleVectorInst>(I); 1577 assert(Shuffle->getOperand(0)->getType() == 1578 Shuffle->getOperand(1)->getType() && 1579 "Expected shuffle operands to have same type"); 1580 unsigned OpWidth = cast<FixedVectorType>(Shuffle->getOperand(0)->getType()) 1581 ->getNumElements(); 1582 // Handle trivial case of a splat. Only check the first element of LHS 1583 // operand. 1584 if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) && 1585 DemandedElts.isAllOnes()) { 1586 if (!isa<PoisonValue>(I->getOperand(1))) { 1587 I->setOperand(1, PoisonValue::get(I->getOperand(1)->getType())); 1588 MadeChange = true; 1589 } 1590 APInt LeftDemanded(OpWidth, 1); 1591 APInt LHSPoisonElts(OpWidth, 0); 1592 simplifyAndSetOp(I, 0, LeftDemanded, LHSPoisonElts); 1593 if (LHSPoisonElts[0]) 1594 PoisonElts = EltMask; 1595 else 1596 PoisonElts.clearAllBits(); 1597 break; 1598 } 1599 1600 APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0); 1601 for (unsigned i = 0; i < VWidth; i++) { 1602 if (DemandedElts[i]) { 1603 unsigned MaskVal = Shuffle->getMaskValue(i); 1604 if (MaskVal != -1u) { 1605 assert(MaskVal < OpWidth * 2 && 1606 "shufflevector mask index out of range!"); 1607 if (MaskVal < OpWidth) 1608 LeftDemanded.setBit(MaskVal); 1609 else 1610 RightDemanded.setBit(MaskVal - OpWidth); 1611 } 1612 } 1613 } 1614 1615 APInt LHSPoisonElts(OpWidth, 0); 1616 simplifyAndSetOp(I, 0, LeftDemanded, LHSPoisonElts); 1617 1618 APInt RHSPoisonElts(OpWidth, 0); 1619 simplifyAndSetOp(I, 1, RightDemanded, RHSPoisonElts); 1620 1621 // If this shuffle does not change the vector length and the elements 1622 // demanded by this shuffle are an identity mask, then this shuffle is 1623 // unnecessary. 1624 // 1625 // We are assuming canonical form for the mask, so the source vector is 1626 // operand 0 and operand 1 is not used. 1627 // 1628 // Note that if an element is demanded and this shuffle mask is undefined 1629 // for that element, then the shuffle is not considered an identity 1630 // operation. The shuffle prevents poison from the operand vector from 1631 // leaking to the result by replacing poison with an undefined value. 1632 if (VWidth == OpWidth) { 1633 bool IsIdentityShuffle = true; 1634 for (unsigned i = 0; i < VWidth; i++) { 1635 unsigned MaskVal = Shuffle->getMaskValue(i); 1636 if (DemandedElts[i] && i != MaskVal) { 1637 IsIdentityShuffle = false; 1638 break; 1639 } 1640 } 1641 if (IsIdentityShuffle) 1642 return Shuffle->getOperand(0); 1643 } 1644 1645 bool NewPoisonElts = false; 1646 unsigned LHSIdx = -1u, LHSValIdx = -1u; 1647 unsigned RHSIdx = -1u, RHSValIdx = -1u; 1648 bool LHSUniform = true; 1649 bool RHSUniform = true; 1650 for (unsigned i = 0; i < VWidth; i++) { 1651 unsigned MaskVal = Shuffle->getMaskValue(i); 1652 if (MaskVal == -1u) { 1653 PoisonElts.setBit(i); 1654 } else if (!DemandedElts[i]) { 1655 NewPoisonElts = true; 1656 PoisonElts.setBit(i); 1657 } else if (MaskVal < OpWidth) { 1658 if (LHSPoisonElts[MaskVal]) { 1659 NewPoisonElts = true; 1660 PoisonElts.setBit(i); 1661 } else { 1662 LHSIdx = LHSIdx == -1u ? i : OpWidth; 1663 LHSValIdx = LHSValIdx == -1u ? MaskVal : OpWidth; 1664 LHSUniform = LHSUniform && (MaskVal == i); 1665 } 1666 } else { 1667 if (RHSPoisonElts[MaskVal - OpWidth]) { 1668 NewPoisonElts = true; 1669 PoisonElts.setBit(i); 1670 } else { 1671 RHSIdx = RHSIdx == -1u ? i : OpWidth; 1672 RHSValIdx = RHSValIdx == -1u ? MaskVal - OpWidth : OpWidth; 1673 RHSUniform = RHSUniform && (MaskVal - OpWidth == i); 1674 } 1675 } 1676 } 1677 1678 // Try to transform shuffle with constant vector and single element from 1679 // this constant vector to single insertelement instruction. 1680 // shufflevector V, C, <v1, v2, .., ci, .., vm> -> 1681 // insertelement V, C[ci], ci-n 1682 if (OpWidth == 1683 cast<FixedVectorType>(Shuffle->getType())->getNumElements()) { 1684 Value *Op = nullptr; 1685 Constant *Value = nullptr; 1686 unsigned Idx = -1u; 1687 1688 // Find constant vector with the single element in shuffle (LHS or RHS). 1689 if (LHSIdx < OpWidth && RHSUniform) { 1690 if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) { 1691 Op = Shuffle->getOperand(1); 1692 Value = CV->getOperand(LHSValIdx); 1693 Idx = LHSIdx; 1694 } 1695 } 1696 if (RHSIdx < OpWidth && LHSUniform) { 1697 if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) { 1698 Op = Shuffle->getOperand(0); 1699 Value = CV->getOperand(RHSValIdx); 1700 Idx = RHSIdx; 1701 } 1702 } 1703 // Found constant vector with single element - convert to insertelement. 1704 if (Op && Value) { 1705 Instruction *New = InsertElementInst::Create( 1706 Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx), 1707 Shuffle->getName()); 1708 InsertNewInstWith(New, Shuffle->getIterator()); 1709 return New; 1710 } 1711 } 1712 if (NewPoisonElts) { 1713 // Add additional discovered undefs. 1714 SmallVector<int, 16> Elts; 1715 for (unsigned i = 0; i < VWidth; ++i) { 1716 if (PoisonElts[i]) 1717 Elts.push_back(PoisonMaskElem); 1718 else 1719 Elts.push_back(Shuffle->getMaskValue(i)); 1720 } 1721 Shuffle->setShuffleMask(Elts); 1722 MadeChange = true; 1723 } 1724 break; 1725 } 1726 case Instruction::Select: { 1727 // If this is a vector select, try to transform the select condition based 1728 // on the current demanded elements. 1729 SelectInst *Sel = cast<SelectInst>(I); 1730 if (Sel->getCondition()->getType()->isVectorTy()) { 1731 // TODO: We are not doing anything with PoisonElts based on this call. 1732 // It is overwritten below based on the other select operands. If an 1733 // element of the select condition is known undef, then we are free to 1734 // choose the output value from either arm of the select. If we know that 1735 // one of those values is undef, then the output can be undef. 1736 simplifyAndSetOp(I, 0, DemandedElts, PoisonElts); 1737 } 1738 1739 // Next, see if we can transform the arms of the select. 1740 APInt DemandedLHS(DemandedElts), DemandedRHS(DemandedElts); 1741 if (auto *CV = dyn_cast<ConstantVector>(Sel->getCondition())) { 1742 for (unsigned i = 0; i < VWidth; i++) { 1743 Constant *CElt = CV->getAggregateElement(i); 1744 1745 // isNullValue() always returns false when called on a ConstantExpr. 1746 if (CElt->isNullValue()) 1747 DemandedLHS.clearBit(i); 1748 else if (CElt->isOneValue()) 1749 DemandedRHS.clearBit(i); 1750 } 1751 } 1752 1753 simplifyAndSetOp(I, 1, DemandedLHS, PoisonElts2); 1754 simplifyAndSetOp(I, 2, DemandedRHS, PoisonElts3); 1755 1756 // Output elements are undefined if the element from each arm is undefined. 1757 // TODO: This can be improved. See comment in select condition handling. 1758 PoisonElts = PoisonElts2 & PoisonElts3; 1759 break; 1760 } 1761 case Instruction::BitCast: { 1762 // Vector->vector casts only. 1763 VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType()); 1764 if (!VTy) break; 1765 unsigned InVWidth = cast<FixedVectorType>(VTy)->getNumElements(); 1766 APInt InputDemandedElts(InVWidth, 0); 1767 PoisonElts2 = APInt(InVWidth, 0); 1768 unsigned Ratio; 1769 1770 if (VWidth == InVWidth) { 1771 // If we are converting from <4 x i32> -> <4 x f32>, we demand the same 1772 // elements as are demanded of us. 1773 Ratio = 1; 1774 InputDemandedElts = DemandedElts; 1775 } else if ((VWidth % InVWidth) == 0) { 1776 // If the number of elements in the output is a multiple of the number of 1777 // elements in the input then an input element is live if any of the 1778 // corresponding output elements are live. 1779 Ratio = VWidth / InVWidth; 1780 for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) 1781 if (DemandedElts[OutIdx]) 1782 InputDemandedElts.setBit(OutIdx / Ratio); 1783 } else if ((InVWidth % VWidth) == 0) { 1784 // If the number of elements in the input is a multiple of the number of 1785 // elements in the output then an input element is live if the 1786 // corresponding output element is live. 1787 Ratio = InVWidth / VWidth; 1788 for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx) 1789 if (DemandedElts[InIdx / Ratio]) 1790 InputDemandedElts.setBit(InIdx); 1791 } else { 1792 // Unsupported so far. 1793 break; 1794 } 1795 1796 simplifyAndSetOp(I, 0, InputDemandedElts, PoisonElts2); 1797 1798 if (VWidth == InVWidth) { 1799 PoisonElts = PoisonElts2; 1800 } else if ((VWidth % InVWidth) == 0) { 1801 // If the number of elements in the output is a multiple of the number of 1802 // elements in the input then an output element is undef if the 1803 // corresponding input element is undef. 1804 for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) 1805 if (PoisonElts2[OutIdx / Ratio]) 1806 PoisonElts.setBit(OutIdx); 1807 } else if ((InVWidth % VWidth) == 0) { 1808 // If the number of elements in the input is a multiple of the number of 1809 // elements in the output then an output element is undef if all of the 1810 // corresponding input elements are undef. 1811 for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) { 1812 APInt SubUndef = PoisonElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio); 1813 if (SubUndef.popcount() == Ratio) 1814 PoisonElts.setBit(OutIdx); 1815 } 1816 } else { 1817 llvm_unreachable("Unimp"); 1818 } 1819 break; 1820 } 1821 case Instruction::FPTrunc: 1822 case Instruction::FPExt: 1823 simplifyAndSetOp(I, 0, DemandedElts, PoisonElts); 1824 break; 1825 1826 case Instruction::Call: { 1827 IntrinsicInst *II = dyn_cast<IntrinsicInst>(I); 1828 if (!II) break; 1829 switch (II->getIntrinsicID()) { 1830 case Intrinsic::masked_gather: // fallthrough 1831 case Intrinsic::masked_load: { 1832 // Subtlety: If we load from a pointer, the pointer must be valid 1833 // regardless of whether the element is demanded. Doing otherwise risks 1834 // segfaults which didn't exist in the original program. 1835 APInt DemandedPtrs(APInt::getAllOnes(VWidth)), 1836 DemandedPassThrough(DemandedElts); 1837 if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2))) 1838 for (unsigned i = 0; i < VWidth; i++) { 1839 Constant *CElt = CV->getAggregateElement(i); 1840 if (CElt->isNullValue()) 1841 DemandedPtrs.clearBit(i); 1842 else if (CElt->isAllOnesValue()) 1843 DemandedPassThrough.clearBit(i); 1844 } 1845 if (II->getIntrinsicID() == Intrinsic::masked_gather) 1846 simplifyAndSetOp(II, 0, DemandedPtrs, PoisonElts2); 1847 simplifyAndSetOp(II, 3, DemandedPassThrough, PoisonElts3); 1848 1849 // Output elements are undefined if the element from both sources are. 1850 // TODO: can strengthen via mask as well. 1851 PoisonElts = PoisonElts2 & PoisonElts3; 1852 break; 1853 } 1854 default: { 1855 // Handle target specific intrinsics 1856 std::optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic( 1857 *II, DemandedElts, PoisonElts, PoisonElts2, PoisonElts3, 1858 simplifyAndSetOp); 1859 if (V) 1860 return *V; 1861 break; 1862 } 1863 } // switch on IntrinsicID 1864 break; 1865 } // case Call 1866 } // switch on Opcode 1867 1868 // TODO: We bail completely on integer div/rem and shifts because they have 1869 // UB/poison potential, but that should be refined. 1870 BinaryOperator *BO; 1871 if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) { 1872 Value *X = BO->getOperand(0); 1873 Value *Y = BO->getOperand(1); 1874 1875 // Look for an equivalent binop except that one operand has been shuffled. 1876 // If the demand for this binop only includes elements that are the same as 1877 // the other binop, then we may be able to replace this binop with a use of 1878 // the earlier one. 1879 // 1880 // Example: 1881 // %other_bo = bo (shuf X, {0}), Y 1882 // %this_extracted_bo = extelt (bo X, Y), 0 1883 // --> 1884 // %other_bo = bo (shuf X, {0}), Y 1885 // %this_extracted_bo = extelt %other_bo, 0 1886 // 1887 // TODO: Handle demand of an arbitrary single element or more than one 1888 // element instead of just element 0. 1889 // TODO: Unlike general demanded elements transforms, this should be safe 1890 // for any (div/rem/shift) opcode too. 1891 if (DemandedElts == 1 && !X->hasOneUse() && !Y->hasOneUse() && 1892 BO->hasOneUse() ) { 1893 1894 auto findShufBO = [&](bool MatchShufAsOp0) -> User * { 1895 // Try to use shuffle-of-operand in place of an operand: 1896 // bo X, Y --> bo (shuf X), Y 1897 // bo X, Y --> bo X, (shuf Y) 1898 1899 Value *OtherOp = MatchShufAsOp0 ? Y : X; 1900 if (!OtherOp->hasUseList()) 1901 return nullptr; 1902 1903 BinaryOperator::BinaryOps Opcode = BO->getOpcode(); 1904 Value *ShufOp = MatchShufAsOp0 ? X : Y; 1905 1906 for (User *U : OtherOp->users()) { 1907 ArrayRef<int> Mask; 1908 auto Shuf = m_Shuffle(m_Specific(ShufOp), m_Value(), m_Mask(Mask)); 1909 if (BO->isCommutative() 1910 ? match(U, m_c_BinOp(Opcode, Shuf, m_Specific(OtherOp))) 1911 : MatchShufAsOp0 1912 ? match(U, m_BinOp(Opcode, Shuf, m_Specific(OtherOp))) 1913 : match(U, m_BinOp(Opcode, m_Specific(OtherOp), Shuf))) 1914 if (match(Mask, m_ZeroMask()) && Mask[0] != PoisonMaskElem) 1915 if (DT.dominates(U, I)) 1916 return U; 1917 } 1918 return nullptr; 1919 }; 1920 1921 if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ true)) 1922 return ShufBO; 1923 if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ false)) 1924 return ShufBO; 1925 } 1926 1927 simplifyAndSetOp(I, 0, DemandedElts, PoisonElts); 1928 simplifyAndSetOp(I, 1, DemandedElts, PoisonElts2); 1929 1930 // Output elements are undefined if both are undefined. Consider things 1931 // like undef & 0. The result is known zero, not undef. 1932 PoisonElts &= PoisonElts2; 1933 } 1934 1935 // If we've proven all of the lanes poison, return a poison value. 1936 // TODO: Intersect w/demanded lanes 1937 if (PoisonElts.isAllOnes()) 1938 return PoisonValue::get(I->getType()); 1939 1940 return MadeChange ? I : nullptr; 1941 } 1942 1943 /// For floating-point classes that resolve to a single bit pattern, return that 1944 /// value. 1945 static Constant *getFPClassConstant(Type *Ty, FPClassTest Mask) { 1946 if (Mask == fcNone) 1947 return PoisonValue::get(Ty); 1948 1949 if (Mask == fcPosZero) 1950 return Constant::getNullValue(Ty); 1951 1952 // TODO: Support aggregate types that are allowed by FPMathOperator. 1953 if (Ty->isAggregateType()) 1954 return nullptr; 1955 1956 switch (Mask) { 1957 case fcNegZero: 1958 return ConstantFP::getZero(Ty, true); 1959 case fcPosInf: 1960 return ConstantFP::getInfinity(Ty); 1961 case fcNegInf: 1962 return ConstantFP::getInfinity(Ty, true); 1963 default: 1964 return nullptr; 1965 } 1966 } 1967 1968 Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Value *V, 1969 FPClassTest DemandedMask, 1970 KnownFPClass &Known, 1971 Instruction *CxtI, 1972 unsigned Depth) { 1973 assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); 1974 Type *VTy = V->getType(); 1975 1976 assert(Known == KnownFPClass() && "expected uninitialized state"); 1977 1978 if (DemandedMask == fcNone) 1979 return isa<UndefValue>(V) ? nullptr : PoisonValue::get(VTy); 1980 1981 if (Depth == MaxAnalysisRecursionDepth) 1982 return nullptr; 1983 1984 Instruction *I = dyn_cast<Instruction>(V); 1985 if (!I) { 1986 // Handle constants and arguments 1987 Known = computeKnownFPClass(V, fcAllFlags, CxtI, Depth + 1); 1988 Value *FoldedToConst = 1989 getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses); 1990 return FoldedToConst == V ? nullptr : FoldedToConst; 1991 } 1992 1993 if (!I->hasOneUse()) 1994 return nullptr; 1995 1996 if (auto *FPOp = dyn_cast<FPMathOperator>(I)) { 1997 if (FPOp->hasNoNaNs()) 1998 DemandedMask &= ~fcNan; 1999 if (FPOp->hasNoInfs()) 2000 DemandedMask &= ~fcInf; 2001 } 2002 switch (I->getOpcode()) { 2003 case Instruction::FNeg: { 2004 if (SimplifyDemandedFPClass(I, 0, llvm::fneg(DemandedMask), Known, 2005 Depth + 1)) 2006 return I; 2007 Known.fneg(); 2008 break; 2009 } 2010 case Instruction::Call: { 2011 CallInst *CI = cast<CallInst>(I); 2012 switch (CI->getIntrinsicID()) { 2013 case Intrinsic::fabs: 2014 if (SimplifyDemandedFPClass(I, 0, llvm::inverse_fabs(DemandedMask), Known, 2015 Depth + 1)) 2016 return I; 2017 Known.fabs(); 2018 break; 2019 case Intrinsic::arithmetic_fence: 2020 if (SimplifyDemandedFPClass(I, 0, DemandedMask, Known, Depth + 1)) 2021 return I; 2022 break; 2023 case Intrinsic::copysign: { 2024 // Flip on more potentially demanded classes 2025 const FPClassTest DemandedMaskAnySign = llvm::unknown_sign(DemandedMask); 2026 if (SimplifyDemandedFPClass(I, 0, DemandedMaskAnySign, Known, Depth + 1)) 2027 return I; 2028 2029 if ((DemandedMask & fcNegative) == DemandedMask) { 2030 // Roundabout way of replacing with fneg(fabs) 2031 I->setOperand(1, ConstantFP::get(VTy, -1.0)); 2032 return I; 2033 } 2034 2035 if ((DemandedMask & fcPositive) == DemandedMask) { 2036 // Roundabout way of replacing with fabs 2037 I->setOperand(1, ConstantFP::getZero(VTy)); 2038 return I; 2039 } 2040 2041 KnownFPClass KnownSign = 2042 computeKnownFPClass(I->getOperand(1), fcAllFlags, CxtI, Depth + 1); 2043 Known.copysign(KnownSign); 2044 break; 2045 } 2046 default: 2047 Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1); 2048 break; 2049 } 2050 2051 break; 2052 } 2053 case Instruction::Select: { 2054 KnownFPClass KnownLHS, KnownRHS; 2055 if (SimplifyDemandedFPClass(I, 2, DemandedMask, KnownRHS, Depth + 1) || 2056 SimplifyDemandedFPClass(I, 1, DemandedMask, KnownLHS, Depth + 1)) 2057 return I; 2058 2059 if (KnownLHS.isKnownNever(DemandedMask)) 2060 return I->getOperand(2); 2061 if (KnownRHS.isKnownNever(DemandedMask)) 2062 return I->getOperand(1); 2063 2064 // TODO: Recognize clamping patterns 2065 Known = KnownLHS | KnownRHS; 2066 break; 2067 } 2068 default: 2069 Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1); 2070 break; 2071 } 2072 2073 return getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses); 2074 } 2075 2076 bool InstCombinerImpl::SimplifyDemandedFPClass(Instruction *I, unsigned OpNo, 2077 FPClassTest DemandedMask, 2078 KnownFPClass &Known, 2079 unsigned Depth) { 2080 Use &U = I->getOperandUse(OpNo); 2081 Value *NewVal = 2082 SimplifyDemandedUseFPClass(U.get(), DemandedMask, Known, I, Depth); 2083 if (!NewVal) 2084 return false; 2085 if (Instruction *OpInst = dyn_cast<Instruction>(U)) 2086 salvageDebugInfo(*OpInst); 2087 2088 replaceUse(U, NewVal); 2089 return true; 2090 } 2091