1 //===- InstCombineSimplifyDemanded.cpp ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains logic for simplifying instructions based on information 10 // about how they are used. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "InstCombineInternal.h" 15 #include "llvm/Analysis/ValueTracking.h" 16 #include "llvm/IR/GetElementPtrTypeIterator.h" 17 #include "llvm/IR/IntrinsicInst.h" 18 #include "llvm/IR/PatternMatch.h" 19 #include "llvm/Support/KnownBits.h" 20 #include "llvm/Transforms/InstCombine/InstCombiner.h" 21 22 using namespace llvm; 23 using namespace llvm::PatternMatch; 24 25 #define DEBUG_TYPE "instcombine" 26 27 static cl::opt<bool> 28 VerifyKnownBits("instcombine-verify-known-bits", 29 cl::desc("Verify that computeKnownBits() and " 30 "SimplifyDemandedBits() are consistent"), 31 cl::Hidden, cl::init(false)); 32 33 /// Check to see if the specified operand of the specified instruction is a 34 /// constant integer. If so, check to see if there are any bits set in the 35 /// constant that are not demanded. If so, shrink the constant and return true. 36 static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, 37 const APInt &Demanded) { 38 assert(I && "No instruction?"); 39 assert(OpNo < I->getNumOperands() && "Operand index too large"); 40 41 // The operand must be a constant integer or splat integer. 42 Value *Op = I->getOperand(OpNo); 43 const APInt *C; 44 if (!match(Op, m_APInt(C))) 45 return false; 46 47 // If there are no bits set that aren't demanded, nothing to do. 48 if (C->isSubsetOf(Demanded)) 49 return false; 50 51 // This instruction is producing bits that are not demanded. Shrink the RHS. 52 I->setOperand(OpNo, ConstantInt::get(Op->getType(), *C & Demanded)); 53 54 return true; 55 } 56 57 /// Returns the bitwidth of the given scalar or pointer type. For vector types, 58 /// returns the element type's bitwidth. 59 static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { 60 if (unsigned BitWidth = Ty->getScalarSizeInBits()) 61 return BitWidth; 62 63 return DL.getPointerTypeSizeInBits(Ty); 64 } 65 66 /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if 67 /// the instruction has any properties that allow us to simplify its operands. 68 bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst, 69 KnownBits &Known) { 70 APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth())); 71 Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known, 72 0, SQ.getWithInstruction(&Inst)); 73 if (!V) return false; 74 if (V == &Inst) return true; 75 replaceInstUsesWith(Inst, V); 76 return true; 77 } 78 79 /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if 80 /// the instruction has any properties that allow us to simplify its operands. 81 bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { 82 KnownBits Known(getBitWidth(Inst.getType(), DL)); 83 return SimplifyDemandedInstructionBits(Inst, Known); 84 } 85 86 /// This form of SimplifyDemandedBits simplifies the specified instruction 87 /// operand if possible, updating it in place. It returns true if it made any 88 /// change and false otherwise. 89 bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo, 90 const APInt &DemandedMask, 91 KnownBits &Known, unsigned Depth, 92 const SimplifyQuery &Q) { 93 Use &U = I->getOperandUse(OpNo); 94 Value *V = U.get(); 95 if (isa<Constant>(V)) { 96 llvm::computeKnownBits(V, Known, Depth, Q); 97 return false; 98 } 99 100 Known.resetAll(); 101 if (DemandedMask.isZero()) { 102 // Not demanding any bits from V. 103 replaceUse(U, UndefValue::get(V->getType())); 104 return true; 105 } 106 107 if (Depth == MaxAnalysisRecursionDepth) 108 return false; 109 110 Instruction *VInst = dyn_cast<Instruction>(V); 111 if (!VInst) { 112 llvm::computeKnownBits(V, Known, Depth, Q); 113 return false; 114 } 115 116 Value *NewVal; 117 if (VInst->hasOneUse()) { 118 // If the instruction has one use, we can directly simplify it. 119 NewVal = SimplifyDemandedUseBits(VInst, DemandedMask, Known, Depth, Q); 120 } else { 121 // If there are multiple uses of this instruction, then we can simplify 122 // VInst to some other value, but not modify the instruction. 123 NewVal = 124 SimplifyMultipleUseDemandedBits(VInst, DemandedMask, Known, Depth, Q); 125 } 126 if (!NewVal) return false; 127 if (Instruction* OpInst = dyn_cast<Instruction>(U)) 128 salvageDebugInfo(*OpInst); 129 130 replaceUse(U, NewVal); 131 return true; 132 } 133 134 /// This function attempts to replace V with a simpler value based on the 135 /// demanded bits. When this function is called, it is known that only the bits 136 /// set in DemandedMask of the result of V are ever used downstream. 137 /// Consequently, depending on the mask and V, it may be possible to replace V 138 /// with a constant or one of its operands. In such cases, this function does 139 /// the replacement and returns true. In all other cases, it returns false after 140 /// analyzing the expression and setting KnownOne and known to be one in the 141 /// expression. Known.Zero contains all the bits that are known to be zero in 142 /// the expression. These are provided to potentially allow the caller (which 143 /// might recursively be SimplifyDemandedBits itself) to simplify the 144 /// expression. 145 /// Known.One and Known.Zero always follow the invariant that: 146 /// Known.One & Known.Zero == 0. 147 /// That is, a bit can't be both 1 and 0. The bits in Known.One and Known.Zero 148 /// are accurate even for bits not in DemandedMask. Note 149 /// also that the bitwidth of V, DemandedMask, Known.Zero and Known.One must all 150 /// be the same. 151 /// 152 /// This returns null if it did not change anything and it permits no 153 /// simplification. This returns V itself if it did some simplification of V's 154 /// operands based on the information about what bits are demanded. This returns 155 /// some other non-null value if it found out that V is equal to another value 156 /// in the context where the specified bits are demanded, but not for all users. 157 Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I, 158 const APInt &DemandedMask, 159 KnownBits &Known, 160 unsigned Depth, 161 const SimplifyQuery &Q) { 162 assert(I != nullptr && "Null pointer of Value???"); 163 assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); 164 uint32_t BitWidth = DemandedMask.getBitWidth(); 165 Type *VTy = I->getType(); 166 assert( 167 (!VTy->isIntOrIntVectorTy() || VTy->getScalarSizeInBits() == BitWidth) && 168 Known.getBitWidth() == BitWidth && 169 "Value *V, DemandedMask and Known must have same BitWidth"); 170 171 KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth); 172 173 // Update flags after simplifying an operand based on the fact that some high 174 // order bits are not demanded. 175 auto disableWrapFlagsBasedOnUnusedHighBits = [](Instruction *I, 176 unsigned NLZ) { 177 if (NLZ > 0) { 178 // Disable the nsw and nuw flags here: We can no longer guarantee that 179 // we won't wrap after simplification. Removing the nsw/nuw flags is 180 // legal here because the top bit is not demanded. 181 I->setHasNoSignedWrap(false); 182 I->setHasNoUnsignedWrap(false); 183 } 184 return I; 185 }; 186 187 // If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care 188 // about the high bits of the operands. 189 auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) { 190 unsigned NLZ = DemandedMask.countl_zero(); 191 // Right fill the mask of bits for the operands to demand the most 192 // significant bit and all those below it. 193 DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); 194 if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || 195 SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1, Q) || 196 ShrinkDemandedConstant(I, 1, DemandedFromOps) || 197 SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q)) { 198 disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); 199 return true; 200 } 201 return false; 202 }; 203 204 switch (I->getOpcode()) { 205 default: 206 llvm::computeKnownBits(I, Known, Depth, Q); 207 break; 208 case Instruction::And: { 209 // If either the LHS or the RHS are Zero, the result is zero. 210 if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) || 211 SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.Zero, LHSKnown, 212 Depth + 1, Q)) 213 return I; 214 215 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 216 Depth, Q); 217 218 // If the client is only demanding bits that we know, return the known 219 // constant. 220 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 221 return Constant::getIntegerValue(VTy, Known.One); 222 223 // If all of the demanded bits are known 1 on one side, return the other. 224 // These bits cannot contribute to the result of the 'and'. 225 if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One)) 226 return I->getOperand(0); 227 if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One)) 228 return I->getOperand(1); 229 230 // If the RHS is a constant, see if we can simplify it. 231 if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnown.Zero)) 232 return I; 233 234 break; 235 } 236 case Instruction::Or: { 237 // If either the LHS or the RHS are One, the result is One. 238 if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) || 239 SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown, 240 Depth + 1, Q)) { 241 // Disjoint flag may not longer hold. 242 I->dropPoisonGeneratingFlags(); 243 return I; 244 } 245 246 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 247 Depth, Q); 248 249 // If the client is only demanding bits that we know, return the known 250 // constant. 251 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 252 return Constant::getIntegerValue(VTy, Known.One); 253 254 // If all of the demanded bits are known zero on one side, return the other. 255 // These bits cannot contribute to the result of the 'or'. 256 if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero)) 257 return I->getOperand(0); 258 if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) 259 return I->getOperand(1); 260 261 // If the RHS is a constant, see if we can simplify it. 262 if (ShrinkDemandedConstant(I, 1, DemandedMask)) 263 return I; 264 265 // Infer disjoint flag if no common bits are set. 266 if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) { 267 WithCache<const Value *> LHSCache(I->getOperand(0), LHSKnown), 268 RHSCache(I->getOperand(1), RHSKnown); 269 if (haveNoCommonBitsSet(LHSCache, RHSCache, Q)) { 270 cast<PossiblyDisjointInst>(I)->setIsDisjoint(true); 271 return I; 272 } 273 } 274 275 break; 276 } 277 case Instruction::Xor: { 278 if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) || 279 SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1, Q)) 280 return I; 281 Value *LHS, *RHS; 282 if (DemandedMask == 1 && 283 match(I->getOperand(0), m_Intrinsic<Intrinsic::ctpop>(m_Value(LHS))) && 284 match(I->getOperand(1), m_Intrinsic<Intrinsic::ctpop>(m_Value(RHS)))) { 285 // (ctpop(X) ^ ctpop(Y)) & 1 --> ctpop(X^Y) & 1 286 IRBuilderBase::InsertPointGuard Guard(Builder); 287 Builder.SetInsertPoint(I); 288 auto *Xor = Builder.CreateXor(LHS, RHS); 289 return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, Xor); 290 } 291 292 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 293 Depth, Q); 294 295 // If the client is only demanding bits that we know, return the known 296 // constant. 297 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 298 return Constant::getIntegerValue(VTy, Known.One); 299 300 // If all of the demanded bits are known zero on one side, return the other. 301 // These bits cannot contribute to the result of the 'xor'. 302 if (DemandedMask.isSubsetOf(RHSKnown.Zero)) 303 return I->getOperand(0); 304 if (DemandedMask.isSubsetOf(LHSKnown.Zero)) 305 return I->getOperand(1); 306 307 // If all of the demanded bits are known to be zero on one side or the 308 // other, turn this into an *inclusive* or. 309 // e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0 310 if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) { 311 Instruction *Or = 312 BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1)); 313 if (DemandedMask.isAllOnes()) 314 cast<PossiblyDisjointInst>(Or)->setIsDisjoint(true); 315 Or->takeName(I); 316 return InsertNewInstWith(Or, I->getIterator()); 317 } 318 319 // If all of the demanded bits on one side are known, and all of the set 320 // bits on that side are also known to be set on the other side, turn this 321 // into an AND, as we know the bits will be cleared. 322 // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2 323 if (DemandedMask.isSubsetOf(RHSKnown.Zero|RHSKnown.One) && 324 RHSKnown.One.isSubsetOf(LHSKnown.One)) { 325 Constant *AndC = Constant::getIntegerValue(VTy, 326 ~RHSKnown.One & DemandedMask); 327 Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC); 328 return InsertNewInstWith(And, I->getIterator()); 329 } 330 331 // If the RHS is a constant, see if we can change it. Don't alter a -1 332 // constant because that's a canonical 'not' op, and that is better for 333 // combining, SCEV, and codegen. 334 const APInt *C; 335 if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) { 336 if ((*C | ~DemandedMask).isAllOnes()) { 337 // Force bits to 1 to create a 'not' op. 338 I->setOperand(1, ConstantInt::getAllOnesValue(VTy)); 339 return I; 340 } 341 // If we can't turn this into a 'not', try to shrink the constant. 342 if (ShrinkDemandedConstant(I, 1, DemandedMask)) 343 return I; 344 } 345 346 // If our LHS is an 'and' and if it has one use, and if any of the bits we 347 // are flipping are known to be set, then the xor is just resetting those 348 // bits to zero. We can just knock out bits from the 'and' and the 'xor', 349 // simplifying both of them. 350 if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0))) { 351 ConstantInt *AndRHS, *XorRHS; 352 if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() && 353 match(I->getOperand(1), m_ConstantInt(XorRHS)) && 354 match(LHSInst->getOperand(1), m_ConstantInt(AndRHS)) && 355 (LHSKnown.One & RHSKnown.One & DemandedMask) != 0) { 356 APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask); 357 358 Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue()); 359 Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC); 360 InsertNewInstWith(NewAnd, I->getIterator()); 361 362 Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue()); 363 Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC); 364 return InsertNewInstWith(NewXor, I->getIterator()); 365 } 366 } 367 break; 368 } 369 case Instruction::Select: { 370 if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1, Q) || 371 SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1, Q)) 372 return I; 373 374 // If the operands are constants, see if we can simplify them. 375 // This is similar to ShrinkDemandedConstant, but for a select we want to 376 // try to keep the selected constants the same as icmp value constants, if 377 // we can. This helps not break apart (or helps put back together) 378 // canonical patterns like min and max. 379 auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo, 380 const APInt &DemandedMask) { 381 const APInt *SelC; 382 if (!match(I->getOperand(OpNo), m_APInt(SelC))) 383 return false; 384 385 // Get the constant out of the ICmp, if there is one. 386 // Only try this when exactly 1 operand is a constant (if both operands 387 // are constant, the icmp should eventually simplify). Otherwise, we may 388 // invert the transform that reduces set bits and infinite-loop. 389 Value *X; 390 const APInt *CmpC; 391 ICmpInst::Predicate Pred; 392 if (!match(I->getOperand(0), m_ICmp(Pred, m_Value(X), m_APInt(CmpC))) || 393 isa<Constant>(X) || CmpC->getBitWidth() != SelC->getBitWidth()) 394 return ShrinkDemandedConstant(I, OpNo, DemandedMask); 395 396 // If the constant is already the same as the ICmp, leave it as-is. 397 if (*CmpC == *SelC) 398 return false; 399 // If the constants are not already the same, but can be with the demand 400 // mask, use the constant value from the ICmp. 401 if ((*CmpC & DemandedMask) == (*SelC & DemandedMask)) { 402 I->setOperand(OpNo, ConstantInt::get(I->getType(), *CmpC)); 403 return true; 404 } 405 return ShrinkDemandedConstant(I, OpNo, DemandedMask); 406 }; 407 if (CanonicalizeSelectConstant(I, 1, DemandedMask) || 408 CanonicalizeSelectConstant(I, 2, DemandedMask)) 409 return I; 410 411 // Only known if known in both the LHS and RHS. 412 adjustKnownBitsForSelectArm(LHSKnown, I->getOperand(0), I->getOperand(1), 413 /*Invert=*/false, Depth, Q); 414 adjustKnownBitsForSelectArm(RHSKnown, I->getOperand(0), I->getOperand(2), 415 /*Invert=*/true, Depth, Q); 416 Known = LHSKnown.intersectWith(RHSKnown); 417 break; 418 } 419 case Instruction::Trunc: { 420 // If we do not demand the high bits of a right-shifted and truncated value, 421 // then we may be able to truncate it before the shift. 422 Value *X; 423 const APInt *C; 424 if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { 425 // The shift amount must be valid (not poison) in the narrow type, and 426 // it must not be greater than the high bits demanded of the result. 427 if (C->ult(VTy->getScalarSizeInBits()) && 428 C->ule(DemandedMask.countl_zero())) { 429 // trunc (lshr X, C) --> lshr (trunc X), C 430 IRBuilderBase::InsertPointGuard Guard(Builder); 431 Builder.SetInsertPoint(I); 432 Value *Trunc = Builder.CreateTrunc(X, VTy); 433 return Builder.CreateLShr(Trunc, C->getZExtValue()); 434 } 435 } 436 } 437 [[fallthrough]]; 438 case Instruction::ZExt: { 439 unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); 440 441 APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); 442 KnownBits InputKnown(SrcBitWidth); 443 if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1, 444 Q)) { 445 // For zext nneg, we may have dropped the instruction which made the 446 // input non-negative. 447 I->dropPoisonGeneratingFlags(); 448 return I; 449 } 450 assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?"); 451 if (I->getOpcode() == Instruction::ZExt && I->hasNonNeg() && 452 !InputKnown.isNegative()) 453 InputKnown.makeNonNegative(); 454 Known = InputKnown.zextOrTrunc(BitWidth); 455 456 break; 457 } 458 case Instruction::SExt: { 459 // Compute the bits in the result that are not present in the input. 460 unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); 461 462 APInt InputDemandedBits = DemandedMask.trunc(SrcBitWidth); 463 464 // If any of the sign extended bits are demanded, we know that the sign 465 // bit is demanded. 466 if (DemandedMask.getActiveBits() > SrcBitWidth) 467 InputDemandedBits.setBit(SrcBitWidth-1); 468 469 KnownBits InputKnown(SrcBitWidth); 470 if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Depth + 1, Q)) 471 return I; 472 473 // If the input sign bit is known zero, or if the NewBits are not demanded 474 // convert this into a zero extension. 475 if (InputKnown.isNonNegative() || 476 DemandedMask.getActiveBits() <= SrcBitWidth) { 477 // Convert to ZExt cast. 478 CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy); 479 NewCast->takeName(I); 480 return InsertNewInstWith(NewCast, I->getIterator()); 481 } 482 483 // If the sign bit of the input is known set or clear, then we know the 484 // top bits of the result. 485 Known = InputKnown.sext(BitWidth); 486 break; 487 } 488 case Instruction::Add: { 489 if ((DemandedMask & 1) == 0) { 490 // If we do not need the low bit, try to convert bool math to logic: 491 // add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN 492 Value *X, *Y; 493 if (match(I, m_c_Add(m_OneUse(m_ZExt(m_Value(X))), 494 m_OneUse(m_SExt(m_Value(Y))))) && 495 X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) { 496 // Truth table for inputs and output signbits: 497 // X:0 | X:1 498 // ---------- 499 // Y:0 | 0 | 0 | 500 // Y:1 | -1 | 0 | 501 // ---------- 502 IRBuilderBase::InsertPointGuard Guard(Builder); 503 Builder.SetInsertPoint(I); 504 Value *AndNot = Builder.CreateAnd(Builder.CreateNot(X), Y); 505 return Builder.CreateSExt(AndNot, VTy); 506 } 507 508 // add iN (sext i1 X), (sext i1 Y) --> sext (X | Y) to iN 509 if (match(I, m_Add(m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) && 510 X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() && 511 (I->getOperand(0)->hasOneUse() || I->getOperand(1)->hasOneUse())) { 512 513 // Truth table for inputs and output signbits: 514 // X:0 | X:1 515 // ----------- 516 // Y:0 | -1 | -1 | 517 // Y:1 | -1 | 0 | 518 // ----------- 519 IRBuilderBase::InsertPointGuard Guard(Builder); 520 Builder.SetInsertPoint(I); 521 Value *Or = Builder.CreateOr(X, Y); 522 return Builder.CreateSExt(Or, VTy); 523 } 524 } 525 526 // Right fill the mask of bits for the operands to demand the most 527 // significant bit and all those below it. 528 unsigned NLZ = DemandedMask.countl_zero(); 529 APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); 530 if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || 531 SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q)) 532 return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); 533 534 // If low order bits are not demanded and known to be zero in one operand, 535 // then we don't need to demand them from the other operand, since they 536 // can't cause overflow into any bits that are demanded in the result. 537 unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one(); 538 APInt DemandedFromLHS = DemandedFromOps; 539 DemandedFromLHS.clearLowBits(NTZ); 540 if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || 541 SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1, Q)) 542 return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); 543 544 // If we are known to be adding zeros to every bit below 545 // the highest demanded bit, we just return the other side. 546 if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) 547 return I->getOperand(0); 548 if (DemandedFromOps.isSubsetOf(LHSKnown.Zero)) 549 return I->getOperand(1); 550 551 // (add X, C) --> (xor X, C) IFF C is equal to the top bit of the DemandMask 552 { 553 const APInt *C; 554 if (match(I->getOperand(1), m_APInt(C)) && 555 C->isOneBitSet(DemandedMask.getActiveBits() - 1)) { 556 IRBuilderBase::InsertPointGuard Guard(Builder); 557 Builder.SetInsertPoint(I); 558 return Builder.CreateXor(I->getOperand(0), ConstantInt::get(VTy, *C)); 559 } 560 } 561 562 // Otherwise just compute the known bits of the result. 563 bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); 564 bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); 565 Known = KnownBits::computeForAddSub(true, NSW, NUW, LHSKnown, RHSKnown); 566 break; 567 } 568 case Instruction::Sub: { 569 // Right fill the mask of bits for the operands to demand the most 570 // significant bit and all those below it. 571 unsigned NLZ = DemandedMask.countl_zero(); 572 APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); 573 if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || 574 SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q)) 575 return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); 576 577 // If low order bits are not demanded and are known to be zero in RHS, 578 // then we don't need to demand them from LHS, since they can't cause a 579 // borrow from any bits that are demanded in the result. 580 unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one(); 581 APInt DemandedFromLHS = DemandedFromOps; 582 DemandedFromLHS.clearLowBits(NTZ); 583 if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || 584 SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1, Q)) 585 return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); 586 587 // If we are known to be subtracting zeros from every bit below 588 // the highest demanded bit, we just return the other side. 589 if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) 590 return I->getOperand(0); 591 // We can't do this with the LHS for subtraction, unless we are only 592 // demanding the LSB. 593 if (DemandedFromOps.isOne() && DemandedFromOps.isSubsetOf(LHSKnown.Zero)) 594 return I->getOperand(1); 595 596 // Otherwise just compute the known bits of the result. 597 bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); 598 bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); 599 Known = KnownBits::computeForAddSub(false, NSW, NUW, LHSKnown, RHSKnown); 600 break; 601 } 602 case Instruction::Mul: { 603 APInt DemandedFromOps; 604 if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps)) 605 return I; 606 607 if (DemandedMask.isPowerOf2()) { 608 // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1. 609 // If we demand exactly one bit N and we have "X * (C' << N)" where C' is 610 // odd (has LSB set), then the left-shifted low bit of X is the answer. 611 unsigned CTZ = DemandedMask.countr_zero(); 612 const APInt *C; 613 if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) { 614 Constant *ShiftC = ConstantInt::get(VTy, CTZ); 615 Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC); 616 return InsertNewInstWith(Shl, I->getIterator()); 617 } 618 } 619 // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because: 620 // X * X is odd iff X is odd. 621 // 'Quadratic Reciprocity': X * X -> 0 for bit[1] 622 if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) { 623 Constant *One = ConstantInt::get(VTy, 1); 624 Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One); 625 return InsertNewInstWith(And1, I->getIterator()); 626 } 627 628 llvm::computeKnownBits(I, Known, Depth, Q); 629 break; 630 } 631 case Instruction::Shl: { 632 const APInt *SA; 633 if (match(I->getOperand(1), m_APInt(SA))) { 634 const APInt *ShrAmt; 635 if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt)))) 636 if (Instruction *Shr = dyn_cast<Instruction>(I->getOperand(0))) 637 if (Value *R = simplifyShrShlDemandedBits(Shr, *ShrAmt, I, *SA, 638 DemandedMask, Known)) 639 return R; 640 641 // Do not simplify if shl is part of funnel-shift pattern 642 if (I->hasOneUse()) { 643 auto *Inst = dyn_cast<Instruction>(I->user_back()); 644 if (Inst && Inst->getOpcode() == BinaryOperator::Or) { 645 if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) { 646 auto [IID, FShiftArgs] = *Opt; 647 if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) && 648 FShiftArgs[0] == FShiftArgs[1]) { 649 llvm::computeKnownBits(I, Known, Depth, Q); 650 break; 651 } 652 } 653 } 654 } 655 656 // We only want bits that already match the signbit then we don't 657 // need to shift. 658 uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1); 659 if (DemandedMask.countr_zero() >= ShiftAmt) { 660 if (I->hasNoSignedWrap()) { 661 unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); 662 unsigned SignBits = 663 ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI); 664 if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumHiDemandedBits) 665 return I->getOperand(0); 666 } 667 668 // If we can pre-shift a right-shifted constant to the left without 669 // losing any high bits and we don't demand the low bits, then eliminate 670 // the left-shift: 671 // (C >> X) << LeftShiftAmtC --> (C << LeftShiftAmtC) >> X 672 Value *X; 673 Constant *C; 674 if (match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) { 675 Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt); 676 Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C, 677 LeftShiftAmtC, DL); 678 if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC, 679 LeftShiftAmtC, DL) == C) { 680 Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X); 681 return InsertNewInstWith(Lshr, I->getIterator()); 682 } 683 } 684 } 685 686 APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt)); 687 688 // If the shift is NUW/NSW, then it does demand the high bits. 689 ShlOperator *IOp = cast<ShlOperator>(I); 690 if (IOp->hasNoSignedWrap()) 691 DemandedMaskIn.setHighBits(ShiftAmt+1); 692 else if (IOp->hasNoUnsignedWrap()) 693 DemandedMaskIn.setHighBits(ShiftAmt); 694 695 if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) 696 return I; 697 698 Known = KnownBits::shl(Known, 699 KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)), 700 /* NUW */ IOp->hasNoUnsignedWrap(), 701 /* NSW */ IOp->hasNoSignedWrap()); 702 } else { 703 // This is a variable shift, so we can't shift the demand mask by a known 704 // amount. But if we are not demanding high bits, then we are not 705 // demanding those bits from the pre-shifted operand either. 706 if (unsigned CTLZ = DemandedMask.countl_zero()) { 707 APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ)); 708 if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1, Q)) { 709 // We can't guarantee that nsw/nuw hold after simplifying the operand. 710 I->dropPoisonGeneratingFlags(); 711 return I; 712 } 713 } 714 llvm::computeKnownBits(I, Known, Depth, Q); 715 } 716 break; 717 } 718 case Instruction::LShr: { 719 const APInt *SA; 720 if (match(I->getOperand(1), m_APInt(SA))) { 721 uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); 722 723 // Do not simplify if lshr is part of funnel-shift pattern 724 if (I->hasOneUse()) { 725 auto *Inst = dyn_cast<Instruction>(I->user_back()); 726 if (Inst && Inst->getOpcode() == BinaryOperator::Or) { 727 if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) { 728 auto [IID, FShiftArgs] = *Opt; 729 if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) && 730 FShiftArgs[0] == FShiftArgs[1]) { 731 llvm::computeKnownBits(I, Known, Depth, Q); 732 break; 733 } 734 } 735 } 736 } 737 738 // If we are just demanding the shifted sign bit and below, then this can 739 // be treated as an ASHR in disguise. 740 if (DemandedMask.countl_zero() >= ShiftAmt) { 741 // If we only want bits that already match the signbit then we don't 742 // need to shift. 743 unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); 744 unsigned SignBits = 745 ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI); 746 if (SignBits >= NumHiDemandedBits) 747 return I->getOperand(0); 748 749 // If we can pre-shift a left-shifted constant to the right without 750 // losing any low bits (we already know we don't demand the high bits), 751 // then eliminate the right-shift: 752 // (C << X) >> RightShiftAmtC --> (C >> RightShiftAmtC) << X 753 Value *X; 754 Constant *C; 755 if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) { 756 Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt); 757 Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::LShr, C, 758 RightShiftAmtC, DL); 759 if (ConstantFoldBinaryOpOperands(Instruction::Shl, NewC, 760 RightShiftAmtC, DL) == C) { 761 Instruction *Shl = BinaryOperator::CreateShl(NewC, X); 762 return InsertNewInstWith(Shl, I->getIterator()); 763 } 764 } 765 } 766 767 // Unsigned shift right. 768 APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); 769 if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) { 770 // exact flag may not longer hold. 771 I->dropPoisonGeneratingFlags(); 772 return I; 773 } 774 Known.Zero.lshrInPlace(ShiftAmt); 775 Known.One.lshrInPlace(ShiftAmt); 776 if (ShiftAmt) 777 Known.Zero.setHighBits(ShiftAmt); // high bits known zero. 778 } else { 779 llvm::computeKnownBits(I, Known, Depth, Q); 780 } 781 break; 782 } 783 case Instruction::AShr: { 784 unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI); 785 786 // If we only want bits that already match the signbit then we don't need 787 // to shift. 788 unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); 789 if (SignBits >= NumHiDemandedBits) 790 return I->getOperand(0); 791 792 // If this is an arithmetic shift right and only the low-bit is set, we can 793 // always convert this into a logical shr, even if the shift amount is 794 // variable. The low bit of the shift cannot be an input sign bit unless 795 // the shift amount is >= the size of the datatype, which is undefined. 796 if (DemandedMask.isOne()) { 797 // Perform the logical shift right. 798 Instruction *NewVal = BinaryOperator::CreateLShr( 799 I->getOperand(0), I->getOperand(1), I->getName()); 800 return InsertNewInstWith(NewVal, I->getIterator()); 801 } 802 803 const APInt *SA; 804 if (match(I->getOperand(1), m_APInt(SA))) { 805 uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1); 806 807 // Signed shift right. 808 APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); 809 // If any of the bits being shifted in are demanded, then we should set 810 // the sign bit as demanded. 811 bool ShiftedInBitsDemanded = DemandedMask.countl_zero() < ShiftAmt; 812 if (ShiftedInBitsDemanded) 813 DemandedMaskIn.setSignBit(); 814 if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) { 815 // exact flag may not longer hold. 816 I->dropPoisonGeneratingFlags(); 817 return I; 818 } 819 820 // If the input sign bit is known to be zero, or if none of the shifted in 821 // bits are demanded, turn this into an unsigned shift right. 822 if (Known.Zero[BitWidth - 1] || !ShiftedInBitsDemanded) { 823 BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0), 824 I->getOperand(1)); 825 LShr->setIsExact(cast<BinaryOperator>(I)->isExact()); 826 LShr->takeName(I); 827 return InsertNewInstWith(LShr, I->getIterator()); 828 } 829 830 Known = KnownBits::ashr( 831 Known, KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)), 832 ShiftAmt != 0, I->isExact()); 833 } else { 834 llvm::computeKnownBits(I, Known, Depth, Q); 835 } 836 break; 837 } 838 case Instruction::UDiv: { 839 // UDiv doesn't demand low bits that are zero in the divisor. 840 const APInt *SA; 841 if (match(I->getOperand(1), m_APInt(SA))) { 842 // TODO: Take the demanded mask of the result into account. 843 unsigned RHSTrailingZeros = SA->countr_zero(); 844 APInt DemandedMaskIn = 845 APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros); 846 if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1, Q)) { 847 // We can't guarantee that "exact" is still true after changing the 848 // the dividend. 849 I->dropPoisonGeneratingFlags(); 850 return I; 851 } 852 853 Known = KnownBits::udiv(LHSKnown, KnownBits::makeConstant(*SA), 854 cast<BinaryOperator>(I)->isExact()); 855 } else { 856 llvm::computeKnownBits(I, Known, Depth, Q); 857 } 858 break; 859 } 860 case Instruction::SRem: { 861 const APInt *Rem; 862 if (match(I->getOperand(1), m_APInt(Rem))) { 863 // X % -1 demands all the bits because we don't want to introduce 864 // INT_MIN % -1 (== undef) by accident. 865 if (Rem->isAllOnes()) 866 break; 867 APInt RA = Rem->abs(); 868 if (RA.isPowerOf2()) { 869 if (DemandedMask.ult(RA)) // srem won't affect demanded bits 870 return I->getOperand(0); 871 872 APInt LowBits = RA - 1; 873 APInt Mask2 = LowBits | APInt::getSignMask(BitWidth); 874 if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Depth + 1, Q)) 875 return I; 876 877 // The low bits of LHS are unchanged by the srem. 878 Known.Zero = LHSKnown.Zero & LowBits; 879 Known.One = LHSKnown.One & LowBits; 880 881 // If LHS is non-negative or has all low bits zero, then the upper bits 882 // are all zero. 883 if (LHSKnown.isNonNegative() || LowBits.isSubsetOf(LHSKnown.Zero)) 884 Known.Zero |= ~LowBits; 885 886 // If LHS is negative and not all low bits are zero, then the upper bits 887 // are all one. 888 if (LHSKnown.isNegative() && LowBits.intersects(LHSKnown.One)) 889 Known.One |= ~LowBits; 890 891 break; 892 } 893 } 894 895 llvm::computeKnownBits(I, Known, Depth, Q); 896 break; 897 } 898 case Instruction::Call: { 899 bool KnownBitsComputed = false; 900 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { 901 switch (II->getIntrinsicID()) { 902 case Intrinsic::abs: { 903 if (DemandedMask == 1) 904 return II->getArgOperand(0); 905 break; 906 } 907 case Intrinsic::ctpop: { 908 // Checking if the number of clear bits is odd (parity)? If the type has 909 // an even number of bits, that's the same as checking if the number of 910 // set bits is odd, so we can eliminate the 'not' op. 911 Value *X; 912 if (DemandedMask == 1 && VTy->getScalarSizeInBits() % 2 == 0 && 913 match(II->getArgOperand(0), m_Not(m_Value(X)))) { 914 Function *Ctpop = Intrinsic::getDeclaration( 915 II->getModule(), Intrinsic::ctpop, VTy); 916 return InsertNewInstWith(CallInst::Create(Ctpop, {X}), I->getIterator()); 917 } 918 break; 919 } 920 case Intrinsic::bswap: { 921 // If the only bits demanded come from one byte of the bswap result, 922 // just shift the input byte into position to eliminate the bswap. 923 unsigned NLZ = DemandedMask.countl_zero(); 924 unsigned NTZ = DemandedMask.countr_zero(); 925 926 // Round NTZ down to the next byte. If we have 11 trailing zeros, then 927 // we need all the bits down to bit 8. Likewise, round NLZ. If we 928 // have 14 leading zeros, round to 8. 929 NLZ = alignDown(NLZ, 8); 930 NTZ = alignDown(NTZ, 8); 931 // If we need exactly one byte, we can do this transformation. 932 if (BitWidth - NLZ - NTZ == 8) { 933 // Replace this with either a left or right shift to get the byte into 934 // the right place. 935 Instruction *NewVal; 936 if (NLZ > NTZ) 937 NewVal = BinaryOperator::CreateLShr( 938 II->getArgOperand(0), ConstantInt::get(VTy, NLZ - NTZ)); 939 else 940 NewVal = BinaryOperator::CreateShl( 941 II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ)); 942 NewVal->takeName(I); 943 return InsertNewInstWith(NewVal, I->getIterator()); 944 } 945 break; 946 } 947 case Intrinsic::ptrmask: { 948 unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits(); 949 RHSKnown = KnownBits(MaskWidth); 950 // If either the LHS or the RHS are Zero, the result is zero. 951 if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1, Q) || 952 SimplifyDemandedBits( 953 I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth), 954 RHSKnown, Depth + 1, Q)) 955 return I; 956 957 // TODO: Should be 1-extend 958 RHSKnown = RHSKnown.anyextOrTrunc(BitWidth); 959 960 Known = LHSKnown & RHSKnown; 961 KnownBitsComputed = true; 962 963 // If the client is only demanding bits we know to be zero, return 964 // `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer 965 // provenance, but making the mask zero will be easily optimizable in 966 // the backend. 967 if (DemandedMask.isSubsetOf(Known.Zero) && 968 !match(I->getOperand(1), m_Zero())) 969 return replaceOperand( 970 *I, 1, Constant::getNullValue(I->getOperand(1)->getType())); 971 972 // Mask in demanded space does nothing. 973 // NOTE: We may have attributes associated with the return value of the 974 // llvm.ptrmask intrinsic that will be lost when we just return the 975 // operand. We should try to preserve them. 976 if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) 977 return I->getOperand(0); 978 979 // If the RHS is a constant, see if we can simplify it. 980 if (ShrinkDemandedConstant( 981 I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth))) 982 return I; 983 984 // Combine: 985 // (ptrmask (getelementptr i8, ptr p, imm i), imm mask) 986 // -> (ptrmask (getelementptr i8, ptr p, imm (i & mask)), imm mask) 987 // where only the low bits known to be zero in the pointer are changed 988 Value *InnerPtr; 989 uint64_t GEPIndex; 990 uint64_t PtrMaskImmediate; 991 if (match(I, m_Intrinsic<Intrinsic::ptrmask>( 992 m_PtrAdd(m_Value(InnerPtr), m_ConstantInt(GEPIndex)), 993 m_ConstantInt(PtrMaskImmediate)))) { 994 995 LHSKnown = computeKnownBits(InnerPtr, Depth + 1, I); 996 if (!LHSKnown.isZero()) { 997 const unsigned trailingZeros = LHSKnown.countMinTrailingZeros(); 998 uint64_t PointerAlignBits = (uint64_t(1) << trailingZeros) - 1; 999 1000 uint64_t HighBitsGEPIndex = GEPIndex & ~PointerAlignBits; 1001 uint64_t MaskedLowBitsGEPIndex = 1002 GEPIndex & PointerAlignBits & PtrMaskImmediate; 1003 1004 uint64_t MaskedGEPIndex = HighBitsGEPIndex | MaskedLowBitsGEPIndex; 1005 1006 if (MaskedGEPIndex != GEPIndex) { 1007 auto *GEP = cast<GetElementPtrInst>(II->getArgOperand(0)); 1008 Builder.SetInsertPoint(I); 1009 Type *GEPIndexType = 1010 DL.getIndexType(GEP->getPointerOperand()->getType()); 1011 Value *MaskedGEP = Builder.CreateGEP( 1012 GEP->getSourceElementType(), InnerPtr, 1013 ConstantInt::get(GEPIndexType, MaskedGEPIndex), 1014 GEP->getName(), GEP->isInBounds()); 1015 1016 replaceOperand(*I, 0, MaskedGEP); 1017 return I; 1018 } 1019 } 1020 } 1021 1022 break; 1023 } 1024 1025 case Intrinsic::fshr: 1026 case Intrinsic::fshl: { 1027 const APInt *SA; 1028 if (!match(I->getOperand(2), m_APInt(SA))) 1029 break; 1030 1031 // Normalize to funnel shift left. APInt shifts of BitWidth are well- 1032 // defined, so no need to special-case zero shifts here. 1033 uint64_t ShiftAmt = SA->urem(BitWidth); 1034 if (II->getIntrinsicID() == Intrinsic::fshr) 1035 ShiftAmt = BitWidth - ShiftAmt; 1036 1037 APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt)); 1038 APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt)); 1039 if (I->getOperand(0) != I->getOperand(1)) { 1040 if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, 1041 Depth + 1, Q) || 1042 SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1, 1043 Q)) 1044 return I; 1045 } else { // fshl is a rotate 1046 // Avoid converting rotate into funnel shift. 1047 // Only simplify if one operand is constant. 1048 LHSKnown = computeKnownBits(I->getOperand(0), Depth + 1, I); 1049 if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) && 1050 !match(I->getOperand(0), m_SpecificInt(LHSKnown.One))) { 1051 replaceOperand(*I, 0, Constant::getIntegerValue(VTy, LHSKnown.One)); 1052 return I; 1053 } 1054 1055 RHSKnown = computeKnownBits(I->getOperand(1), Depth + 1, I); 1056 if (DemandedMaskRHS.isSubsetOf(RHSKnown.Zero | RHSKnown.One) && 1057 !match(I->getOperand(1), m_SpecificInt(RHSKnown.One))) { 1058 replaceOperand(*I, 1, Constant::getIntegerValue(VTy, RHSKnown.One)); 1059 return I; 1060 } 1061 } 1062 1063 Known.Zero = LHSKnown.Zero.shl(ShiftAmt) | 1064 RHSKnown.Zero.lshr(BitWidth - ShiftAmt); 1065 Known.One = LHSKnown.One.shl(ShiftAmt) | 1066 RHSKnown.One.lshr(BitWidth - ShiftAmt); 1067 KnownBitsComputed = true; 1068 break; 1069 } 1070 case Intrinsic::umax: { 1071 // UMax(A, C) == A if ... 1072 // The lowest non-zero bit of DemandMask is higher than the highest 1073 // non-zero bit of C. 1074 const APInt *C; 1075 unsigned CTZ = DemandedMask.countr_zero(); 1076 if (match(II->getArgOperand(1), m_APInt(C)) && 1077 CTZ >= C->getActiveBits()) 1078 return II->getArgOperand(0); 1079 break; 1080 } 1081 case Intrinsic::umin: { 1082 // UMin(A, C) == A if ... 1083 // The lowest non-zero bit of DemandMask is higher than the highest 1084 // non-one bit of C. 1085 // This comes from using DeMorgans on the above umax example. 1086 const APInt *C; 1087 unsigned CTZ = DemandedMask.countr_zero(); 1088 if (match(II->getArgOperand(1), m_APInt(C)) && 1089 CTZ >= C->getBitWidth() - C->countl_one()) 1090 return II->getArgOperand(0); 1091 break; 1092 } 1093 default: { 1094 // Handle target specific intrinsics 1095 std::optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( 1096 *II, DemandedMask, Known, KnownBitsComputed); 1097 if (V) 1098 return *V; 1099 break; 1100 } 1101 } 1102 } 1103 1104 if (!KnownBitsComputed) 1105 llvm::computeKnownBits(I, Known, Depth, Q); 1106 break; 1107 } 1108 } 1109 1110 if (I->getType()->isPointerTy()) { 1111 Align Alignment = I->getPointerAlignment(DL); 1112 Known.Zero.setLowBits(Log2(Alignment)); 1113 } 1114 1115 // If the client is only demanding bits that we know, return the known 1116 // constant. We can't directly simplify pointers as a constant because of 1117 // pointer provenance. 1118 // TODO: We could return `(inttoptr const)` for pointers. 1119 if (!I->getType()->isPointerTy() && 1120 DemandedMask.isSubsetOf(Known.Zero | Known.One)) 1121 return Constant::getIntegerValue(VTy, Known.One); 1122 1123 if (VerifyKnownBits) { 1124 KnownBits ReferenceKnown = llvm::computeKnownBits(I, Depth, Q); 1125 if (Known != ReferenceKnown) { 1126 errs() << "Mismatched known bits for " << *I << " in " 1127 << I->getFunction()->getName() << "\n"; 1128 errs() << "computeKnownBits(): " << ReferenceKnown << "\n"; 1129 errs() << "SimplifyDemandedBits(): " << Known << "\n"; 1130 std::abort(); 1131 } 1132 } 1133 1134 return nullptr; 1135 } 1136 1137 /// Helper routine of SimplifyDemandedUseBits. It computes Known 1138 /// bits. It also tries to handle simplifications that can be done based on 1139 /// DemandedMask, but without modifying the Instruction. 1140 Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( 1141 Instruction *I, const APInt &DemandedMask, KnownBits &Known, unsigned Depth, 1142 const SimplifyQuery &Q) { 1143 unsigned BitWidth = DemandedMask.getBitWidth(); 1144 Type *ITy = I->getType(); 1145 1146 KnownBits LHSKnown(BitWidth); 1147 KnownBits RHSKnown(BitWidth); 1148 1149 // Despite the fact that we can't simplify this instruction in all User's 1150 // context, we can at least compute the known bits, and we can 1151 // do simplifications that apply to *just* the one user if we know that 1152 // this instruction has a simpler value in that context. 1153 switch (I->getOpcode()) { 1154 case Instruction::And: { 1155 llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q); 1156 llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q); 1157 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 1158 Depth, Q); 1159 computeKnownBitsFromContext(I, Known, Depth, Q); 1160 1161 // If the client is only demanding bits that we know, return the known 1162 // constant. 1163 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 1164 return Constant::getIntegerValue(ITy, Known.One); 1165 1166 // If all of the demanded bits are known 1 on one side, return the other. 1167 // These bits cannot contribute to the result of the 'and' in this context. 1168 if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One)) 1169 return I->getOperand(0); 1170 if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One)) 1171 return I->getOperand(1); 1172 1173 break; 1174 } 1175 case Instruction::Or: { 1176 llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q); 1177 llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q); 1178 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 1179 Depth, Q); 1180 computeKnownBitsFromContext(I, Known, Depth, Q); 1181 1182 // If the client is only demanding bits that we know, return the known 1183 // constant. 1184 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 1185 return Constant::getIntegerValue(ITy, Known.One); 1186 1187 // We can simplify (X|Y) -> X or Y in the user's context if we know that 1188 // only bits from X or Y are demanded. 1189 // If all of the demanded bits are known zero on one side, return the other. 1190 // These bits cannot contribute to the result of the 'or' in this context. 1191 if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero)) 1192 return I->getOperand(0); 1193 if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) 1194 return I->getOperand(1); 1195 1196 break; 1197 } 1198 case Instruction::Xor: { 1199 llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q); 1200 llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q); 1201 Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, 1202 Depth, Q); 1203 computeKnownBitsFromContext(I, Known, Depth, Q); 1204 1205 // If the client is only demanding bits that we know, return the known 1206 // constant. 1207 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 1208 return Constant::getIntegerValue(ITy, Known.One); 1209 1210 // We can simplify (X^Y) -> X or Y in the user's context if we know that 1211 // only bits from X or Y are demanded. 1212 // If all of the demanded bits are known zero on one side, return the other. 1213 if (DemandedMask.isSubsetOf(RHSKnown.Zero)) 1214 return I->getOperand(0); 1215 if (DemandedMask.isSubsetOf(LHSKnown.Zero)) 1216 return I->getOperand(1); 1217 1218 break; 1219 } 1220 case Instruction::Add: { 1221 unsigned NLZ = DemandedMask.countl_zero(); 1222 APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); 1223 1224 // If an operand adds zeros to every bit below the highest demanded bit, 1225 // that operand doesn't change the result. Return the other side. 1226 llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q); 1227 if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) 1228 return I->getOperand(0); 1229 1230 llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q); 1231 if (DemandedFromOps.isSubsetOf(LHSKnown.Zero)) 1232 return I->getOperand(1); 1233 1234 bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); 1235 bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); 1236 Known = 1237 KnownBits::computeForAddSub(/*Add=*/true, NSW, NUW, LHSKnown, RHSKnown); 1238 computeKnownBitsFromContext(I, Known, Depth, Q); 1239 break; 1240 } 1241 case Instruction::Sub: { 1242 unsigned NLZ = DemandedMask.countl_zero(); 1243 APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); 1244 1245 // If an operand subtracts zeros from every bit below the highest demanded 1246 // bit, that operand doesn't change the result. Return the other side. 1247 llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q); 1248 if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) 1249 return I->getOperand(0); 1250 1251 bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); 1252 bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); 1253 llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q); 1254 Known = KnownBits::computeForAddSub(/*Add=*/false, NSW, NUW, LHSKnown, 1255 RHSKnown); 1256 computeKnownBitsFromContext(I, Known, Depth, Q); 1257 break; 1258 } 1259 case Instruction::AShr: { 1260 // Compute the Known bits to simplify things downstream. 1261 llvm::computeKnownBits(I, Known, Depth, Q); 1262 1263 // If this user is only demanding bits that we know, return the known 1264 // constant. 1265 if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) 1266 return Constant::getIntegerValue(ITy, Known.One); 1267 1268 // If the right shift operand 0 is a result of a left shift by the same 1269 // amount, this is probably a zero/sign extension, which may be unnecessary, 1270 // if we do not demand any of the new sign bits. So, return the original 1271 // operand instead. 1272 const APInt *ShiftRC; 1273 const APInt *ShiftLC; 1274 Value *X; 1275 unsigned BitWidth = DemandedMask.getBitWidth(); 1276 if (match(I, 1277 m_AShr(m_Shl(m_Value(X), m_APInt(ShiftLC)), m_APInt(ShiftRC))) && 1278 ShiftLC == ShiftRC && ShiftLC->ult(BitWidth) && 1279 DemandedMask.isSubsetOf(APInt::getLowBitsSet( 1280 BitWidth, BitWidth - ShiftRC->getZExtValue()))) { 1281 return X; 1282 } 1283 1284 break; 1285 } 1286 default: 1287 // Compute the Known bits to simplify things downstream. 1288 llvm::computeKnownBits(I, Known, Depth, Q); 1289 1290 // If this user is only demanding bits that we know, return the known 1291 // constant. 1292 if (DemandedMask.isSubsetOf(Known.Zero|Known.One)) 1293 return Constant::getIntegerValue(ITy, Known.One); 1294 1295 break; 1296 } 1297 1298 return nullptr; 1299 } 1300 1301 /// Helper routine of SimplifyDemandedUseBits. It tries to simplify 1302 /// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into 1303 /// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign 1304 /// of "C2-C1". 1305 /// 1306 /// Suppose E1 and E2 are generally different in bits S={bm, bm+1, 1307 /// ..., bn}, without considering the specific value X is holding. 1308 /// This transformation is legal iff one of following conditions is hold: 1309 /// 1) All the bit in S are 0, in this case E1 == E2. 1310 /// 2) We don't care those bits in S, per the input DemandedMask. 1311 /// 3) Combination of 1) and 2). Some bits in S are 0, and we don't care the 1312 /// rest bits. 1313 /// 1314 /// Currently we only test condition 2). 1315 /// 1316 /// As with SimplifyDemandedUseBits, it returns NULL if the simplification was 1317 /// not successful. 1318 Value *InstCombinerImpl::simplifyShrShlDemandedBits( 1319 Instruction *Shr, const APInt &ShrOp1, Instruction *Shl, 1320 const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known) { 1321 if (!ShlOp1 || !ShrOp1) 1322 return nullptr; // No-op. 1323 1324 Value *VarX = Shr->getOperand(0); 1325 Type *Ty = VarX->getType(); 1326 unsigned BitWidth = Ty->getScalarSizeInBits(); 1327 if (ShlOp1.uge(BitWidth) || ShrOp1.uge(BitWidth)) 1328 return nullptr; // Undef. 1329 1330 unsigned ShlAmt = ShlOp1.getZExtValue(); 1331 unsigned ShrAmt = ShrOp1.getZExtValue(); 1332 1333 Known.One.clearAllBits(); 1334 Known.Zero.setLowBits(ShlAmt - 1); 1335 Known.Zero &= DemandedMask; 1336 1337 APInt BitMask1(APInt::getAllOnes(BitWidth)); 1338 APInt BitMask2(APInt::getAllOnes(BitWidth)); 1339 1340 bool isLshr = (Shr->getOpcode() == Instruction::LShr); 1341 BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) : 1342 (BitMask1.ashr(ShrAmt) << ShlAmt); 1343 1344 if (ShrAmt <= ShlAmt) { 1345 BitMask2 <<= (ShlAmt - ShrAmt); 1346 } else { 1347 BitMask2 = isLshr ? BitMask2.lshr(ShrAmt - ShlAmt): 1348 BitMask2.ashr(ShrAmt - ShlAmt); 1349 } 1350 1351 // Check if condition-2 (see the comment to this function) is satified. 1352 if ((BitMask1 & DemandedMask) == (BitMask2 & DemandedMask)) { 1353 if (ShrAmt == ShlAmt) 1354 return VarX; 1355 1356 if (!Shr->hasOneUse()) 1357 return nullptr; 1358 1359 BinaryOperator *New; 1360 if (ShrAmt < ShlAmt) { 1361 Constant *Amt = ConstantInt::get(VarX->getType(), ShlAmt - ShrAmt); 1362 New = BinaryOperator::CreateShl(VarX, Amt); 1363 BinaryOperator *Orig = cast<BinaryOperator>(Shl); 1364 New->setHasNoSignedWrap(Orig->hasNoSignedWrap()); 1365 New->setHasNoUnsignedWrap(Orig->hasNoUnsignedWrap()); 1366 } else { 1367 Constant *Amt = ConstantInt::get(VarX->getType(), ShrAmt - ShlAmt); 1368 New = isLshr ? BinaryOperator::CreateLShr(VarX, Amt) : 1369 BinaryOperator::CreateAShr(VarX, Amt); 1370 if (cast<BinaryOperator>(Shr)->isExact()) 1371 New->setIsExact(true); 1372 } 1373 1374 return InsertNewInstWith(New, Shl->getIterator()); 1375 } 1376 1377 return nullptr; 1378 } 1379 1380 /// The specified value produces a vector with any number of elements. 1381 /// This method analyzes which elements of the operand are poison and 1382 /// returns that information in PoisonElts. 1383 /// 1384 /// DemandedElts contains the set of elements that are actually used by the 1385 /// caller, and by default (AllowMultipleUsers equals false) the value is 1386 /// simplified only if it has a single caller. If AllowMultipleUsers is set 1387 /// to true, DemandedElts refers to the union of sets of elements that are 1388 /// used by all callers. 1389 /// 1390 /// If the information about demanded elements can be used to simplify the 1391 /// operation, the operation is simplified, then the resultant value is 1392 /// returned. This returns null if no change was made. 1393 Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, 1394 APInt DemandedElts, 1395 APInt &PoisonElts, 1396 unsigned Depth, 1397 bool AllowMultipleUsers) { 1398 // Cannot analyze scalable type. The number of vector elements is not a 1399 // compile-time constant. 1400 if (isa<ScalableVectorType>(V->getType())) 1401 return nullptr; 1402 1403 unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); 1404 APInt EltMask(APInt::getAllOnes(VWidth)); 1405 assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); 1406 1407 if (match(V, m_Poison())) { 1408 // If the entire vector is poison, just return this info. 1409 PoisonElts = EltMask; 1410 return nullptr; 1411 } 1412 1413 if (DemandedElts.isZero()) { // If nothing is demanded, provide poison. 1414 PoisonElts = EltMask; 1415 return PoisonValue::get(V->getType()); 1416 } 1417 1418 PoisonElts = 0; 1419 1420 if (auto *C = dyn_cast<Constant>(V)) { 1421 // Check if this is identity. If so, return 0 since we are not simplifying 1422 // anything. 1423 if (DemandedElts.isAllOnes()) 1424 return nullptr; 1425 1426 Type *EltTy = cast<VectorType>(V->getType())->getElementType(); 1427 Constant *Poison = PoisonValue::get(EltTy); 1428 SmallVector<Constant*, 16> Elts; 1429 for (unsigned i = 0; i != VWidth; ++i) { 1430 if (!DemandedElts[i]) { // If not demanded, set to poison. 1431 Elts.push_back(Poison); 1432 PoisonElts.setBit(i); 1433 continue; 1434 } 1435 1436 Constant *Elt = C->getAggregateElement(i); 1437 if (!Elt) return nullptr; 1438 1439 Elts.push_back(Elt); 1440 if (isa<PoisonValue>(Elt)) // Already poison. 1441 PoisonElts.setBit(i); 1442 } 1443 1444 // If we changed the constant, return it. 1445 Constant *NewCV = ConstantVector::get(Elts); 1446 return NewCV != C ? NewCV : nullptr; 1447 } 1448 1449 // Limit search depth. 1450 if (Depth == 10) 1451 return nullptr; 1452 1453 if (!AllowMultipleUsers) { 1454 // If multiple users are using the root value, proceed with 1455 // simplification conservatively assuming that all elements 1456 // are needed. 1457 if (!V->hasOneUse()) { 1458 // Quit if we find multiple users of a non-root value though. 1459 // They'll be handled when it's their turn to be visited by 1460 // the main instcombine process. 1461 if (Depth != 0) 1462 // TODO: Just compute the PoisonElts information recursively. 1463 return nullptr; 1464 1465 // Conservatively assume that all elements are needed. 1466 DemandedElts = EltMask; 1467 } 1468 } 1469 1470 Instruction *I = dyn_cast<Instruction>(V); 1471 if (!I) return nullptr; // Only analyze instructions. 1472 1473 bool MadeChange = false; 1474 auto simplifyAndSetOp = [&](Instruction *Inst, unsigned OpNum, 1475 APInt Demanded, APInt &Undef) { 1476 auto *II = dyn_cast<IntrinsicInst>(Inst); 1477 Value *Op = II ? II->getArgOperand(OpNum) : Inst->getOperand(OpNum); 1478 if (Value *V = SimplifyDemandedVectorElts(Op, Demanded, Undef, Depth + 1)) { 1479 replaceOperand(*Inst, OpNum, V); 1480 MadeChange = true; 1481 } 1482 }; 1483 1484 APInt PoisonElts2(VWidth, 0); 1485 APInt PoisonElts3(VWidth, 0); 1486 switch (I->getOpcode()) { 1487 default: break; 1488 1489 case Instruction::GetElementPtr: { 1490 // The LangRef requires that struct geps have all constant indices. As 1491 // such, we can't convert any operand to partial undef. 1492 auto mayIndexStructType = [](GetElementPtrInst &GEP) { 1493 for (auto I = gep_type_begin(GEP), E = gep_type_end(GEP); 1494 I != E; I++) 1495 if (I.isStruct()) 1496 return true; 1497 return false; 1498 }; 1499 if (mayIndexStructType(cast<GetElementPtrInst>(*I))) 1500 break; 1501 1502 // Conservatively track the demanded elements back through any vector 1503 // operands we may have. We know there must be at least one, or we 1504 // wouldn't have a vector result to get here. Note that we intentionally 1505 // merge the undef bits here since gepping with either an poison base or 1506 // index results in poison. 1507 for (unsigned i = 0; i < I->getNumOperands(); i++) { 1508 if (i == 0 ? match(I->getOperand(i), m_Undef()) 1509 : match(I->getOperand(i), m_Poison())) { 1510 // If the entire vector is undefined, just return this info. 1511 PoisonElts = EltMask; 1512 return nullptr; 1513 } 1514 if (I->getOperand(i)->getType()->isVectorTy()) { 1515 APInt PoisonEltsOp(VWidth, 0); 1516 simplifyAndSetOp(I, i, DemandedElts, PoisonEltsOp); 1517 // gep(x, undef) is not undef, so skip considering idx ops here 1518 // Note that we could propagate poison, but we can't distinguish between 1519 // undef & poison bits ATM 1520 if (i == 0) 1521 PoisonElts |= PoisonEltsOp; 1522 } 1523 } 1524 1525 break; 1526 } 1527 case Instruction::InsertElement: { 1528 // If this is a variable index, we don't know which element it overwrites. 1529 // demand exactly the same input as we produce. 1530 ConstantInt *Idx = dyn_cast<ConstantInt>(I->getOperand(2)); 1531 if (!Idx) { 1532 // Note that we can't propagate undef elt info, because we don't know 1533 // which elt is getting updated. 1534 simplifyAndSetOp(I, 0, DemandedElts, PoisonElts2); 1535 break; 1536 } 1537 1538 // The element inserted overwrites whatever was there, so the input demanded 1539 // set is simpler than the output set. 1540 unsigned IdxNo = Idx->getZExtValue(); 1541 APInt PreInsertDemandedElts = DemandedElts; 1542 if (IdxNo < VWidth) 1543 PreInsertDemandedElts.clearBit(IdxNo); 1544 1545 // If we only demand the element that is being inserted and that element 1546 // was extracted from the same index in another vector with the same type, 1547 // replace this insert with that other vector. 1548 // Note: This is attempted before the call to simplifyAndSetOp because that 1549 // may change PoisonElts to a value that does not match with Vec. 1550 Value *Vec; 1551 if (PreInsertDemandedElts == 0 && 1552 match(I->getOperand(1), 1553 m_ExtractElt(m_Value(Vec), m_SpecificInt(IdxNo))) && 1554 Vec->getType() == I->getType()) { 1555 return Vec; 1556 } 1557 1558 simplifyAndSetOp(I, 0, PreInsertDemandedElts, PoisonElts); 1559 1560 // If this is inserting an element that isn't demanded, remove this 1561 // insertelement. 1562 if (IdxNo >= VWidth || !DemandedElts[IdxNo]) { 1563 Worklist.push(I); 1564 return I->getOperand(0); 1565 } 1566 1567 // The inserted element is defined. 1568 PoisonElts.clearBit(IdxNo); 1569 break; 1570 } 1571 case Instruction::ShuffleVector: { 1572 auto *Shuffle = cast<ShuffleVectorInst>(I); 1573 assert(Shuffle->getOperand(0)->getType() == 1574 Shuffle->getOperand(1)->getType() && 1575 "Expected shuffle operands to have same type"); 1576 unsigned OpWidth = cast<FixedVectorType>(Shuffle->getOperand(0)->getType()) 1577 ->getNumElements(); 1578 // Handle trivial case of a splat. Only check the first element of LHS 1579 // operand. 1580 if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) && 1581 DemandedElts.isAllOnes()) { 1582 if (!isa<PoisonValue>(I->getOperand(1))) { 1583 I->setOperand(1, PoisonValue::get(I->getOperand(1)->getType())); 1584 MadeChange = true; 1585 } 1586 APInt LeftDemanded(OpWidth, 1); 1587 APInt LHSPoisonElts(OpWidth, 0); 1588 simplifyAndSetOp(I, 0, LeftDemanded, LHSPoisonElts); 1589 if (LHSPoisonElts[0]) 1590 PoisonElts = EltMask; 1591 else 1592 PoisonElts.clearAllBits(); 1593 break; 1594 } 1595 1596 APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0); 1597 for (unsigned i = 0; i < VWidth; i++) { 1598 if (DemandedElts[i]) { 1599 unsigned MaskVal = Shuffle->getMaskValue(i); 1600 if (MaskVal != -1u) { 1601 assert(MaskVal < OpWidth * 2 && 1602 "shufflevector mask index out of range!"); 1603 if (MaskVal < OpWidth) 1604 LeftDemanded.setBit(MaskVal); 1605 else 1606 RightDemanded.setBit(MaskVal - OpWidth); 1607 } 1608 } 1609 } 1610 1611 APInt LHSPoisonElts(OpWidth, 0); 1612 simplifyAndSetOp(I, 0, LeftDemanded, LHSPoisonElts); 1613 1614 APInt RHSPoisonElts(OpWidth, 0); 1615 simplifyAndSetOp(I, 1, RightDemanded, RHSPoisonElts); 1616 1617 // If this shuffle does not change the vector length and the elements 1618 // demanded by this shuffle are an identity mask, then this shuffle is 1619 // unnecessary. 1620 // 1621 // We are assuming canonical form for the mask, so the source vector is 1622 // operand 0 and operand 1 is not used. 1623 // 1624 // Note that if an element is demanded and this shuffle mask is undefined 1625 // for that element, then the shuffle is not considered an identity 1626 // operation. The shuffle prevents poison from the operand vector from 1627 // leaking to the result by replacing poison with an undefined value. 1628 if (VWidth == OpWidth) { 1629 bool IsIdentityShuffle = true; 1630 for (unsigned i = 0; i < VWidth; i++) { 1631 unsigned MaskVal = Shuffle->getMaskValue(i); 1632 if (DemandedElts[i] && i != MaskVal) { 1633 IsIdentityShuffle = false; 1634 break; 1635 } 1636 } 1637 if (IsIdentityShuffle) 1638 return Shuffle->getOperand(0); 1639 } 1640 1641 bool NewPoisonElts = false; 1642 unsigned LHSIdx = -1u, LHSValIdx = -1u; 1643 unsigned RHSIdx = -1u, RHSValIdx = -1u; 1644 bool LHSUniform = true; 1645 bool RHSUniform = true; 1646 for (unsigned i = 0; i < VWidth; i++) { 1647 unsigned MaskVal = Shuffle->getMaskValue(i); 1648 if (MaskVal == -1u) { 1649 PoisonElts.setBit(i); 1650 } else if (!DemandedElts[i]) { 1651 NewPoisonElts = true; 1652 PoisonElts.setBit(i); 1653 } else if (MaskVal < OpWidth) { 1654 if (LHSPoisonElts[MaskVal]) { 1655 NewPoisonElts = true; 1656 PoisonElts.setBit(i); 1657 } else { 1658 LHSIdx = LHSIdx == -1u ? i : OpWidth; 1659 LHSValIdx = LHSValIdx == -1u ? MaskVal : OpWidth; 1660 LHSUniform = LHSUniform && (MaskVal == i); 1661 } 1662 } else { 1663 if (RHSPoisonElts[MaskVal - OpWidth]) { 1664 NewPoisonElts = true; 1665 PoisonElts.setBit(i); 1666 } else { 1667 RHSIdx = RHSIdx == -1u ? i : OpWidth; 1668 RHSValIdx = RHSValIdx == -1u ? MaskVal - OpWidth : OpWidth; 1669 RHSUniform = RHSUniform && (MaskVal - OpWidth == i); 1670 } 1671 } 1672 } 1673 1674 // Try to transform shuffle with constant vector and single element from 1675 // this constant vector to single insertelement instruction. 1676 // shufflevector V, C, <v1, v2, .., ci, .., vm> -> 1677 // insertelement V, C[ci], ci-n 1678 if (OpWidth == 1679 cast<FixedVectorType>(Shuffle->getType())->getNumElements()) { 1680 Value *Op = nullptr; 1681 Constant *Value = nullptr; 1682 unsigned Idx = -1u; 1683 1684 // Find constant vector with the single element in shuffle (LHS or RHS). 1685 if (LHSIdx < OpWidth && RHSUniform) { 1686 if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) { 1687 Op = Shuffle->getOperand(1); 1688 Value = CV->getOperand(LHSValIdx); 1689 Idx = LHSIdx; 1690 } 1691 } 1692 if (RHSIdx < OpWidth && LHSUniform) { 1693 if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) { 1694 Op = Shuffle->getOperand(0); 1695 Value = CV->getOperand(RHSValIdx); 1696 Idx = RHSIdx; 1697 } 1698 } 1699 // Found constant vector with single element - convert to insertelement. 1700 if (Op && Value) { 1701 Instruction *New = InsertElementInst::Create( 1702 Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx), 1703 Shuffle->getName()); 1704 InsertNewInstWith(New, Shuffle->getIterator()); 1705 return New; 1706 } 1707 } 1708 if (NewPoisonElts) { 1709 // Add additional discovered undefs. 1710 SmallVector<int, 16> Elts; 1711 for (unsigned i = 0; i < VWidth; ++i) { 1712 if (PoisonElts[i]) 1713 Elts.push_back(PoisonMaskElem); 1714 else 1715 Elts.push_back(Shuffle->getMaskValue(i)); 1716 } 1717 Shuffle->setShuffleMask(Elts); 1718 MadeChange = true; 1719 } 1720 break; 1721 } 1722 case Instruction::Select: { 1723 // If this is a vector select, try to transform the select condition based 1724 // on the current demanded elements. 1725 SelectInst *Sel = cast<SelectInst>(I); 1726 if (Sel->getCondition()->getType()->isVectorTy()) { 1727 // TODO: We are not doing anything with PoisonElts based on this call. 1728 // It is overwritten below based on the other select operands. If an 1729 // element of the select condition is known undef, then we are free to 1730 // choose the output value from either arm of the select. If we know that 1731 // one of those values is undef, then the output can be undef. 1732 simplifyAndSetOp(I, 0, DemandedElts, PoisonElts); 1733 } 1734 1735 // Next, see if we can transform the arms of the select. 1736 APInt DemandedLHS(DemandedElts), DemandedRHS(DemandedElts); 1737 if (auto *CV = dyn_cast<ConstantVector>(Sel->getCondition())) { 1738 for (unsigned i = 0; i < VWidth; i++) { 1739 // isNullValue() always returns false when called on a ConstantExpr. 1740 // Skip constant expressions to avoid propagating incorrect information. 1741 Constant *CElt = CV->getAggregateElement(i); 1742 if (isa<ConstantExpr>(CElt)) 1743 continue; 1744 // TODO: If a select condition element is undef, we can demand from 1745 // either side. If one side is known undef, choosing that side would 1746 // propagate undef. 1747 if (CElt->isNullValue()) 1748 DemandedLHS.clearBit(i); 1749 else 1750 DemandedRHS.clearBit(i); 1751 } 1752 } 1753 1754 simplifyAndSetOp(I, 1, DemandedLHS, PoisonElts2); 1755 simplifyAndSetOp(I, 2, DemandedRHS, PoisonElts3); 1756 1757 // Output elements are undefined if the element from each arm is undefined. 1758 // TODO: This can be improved. See comment in select condition handling. 1759 PoisonElts = PoisonElts2 & PoisonElts3; 1760 break; 1761 } 1762 case Instruction::BitCast: { 1763 // Vector->vector casts only. 1764 VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType()); 1765 if (!VTy) break; 1766 unsigned InVWidth = cast<FixedVectorType>(VTy)->getNumElements(); 1767 APInt InputDemandedElts(InVWidth, 0); 1768 PoisonElts2 = APInt(InVWidth, 0); 1769 unsigned Ratio; 1770 1771 if (VWidth == InVWidth) { 1772 // If we are converting from <4 x i32> -> <4 x f32>, we demand the same 1773 // elements as are demanded of us. 1774 Ratio = 1; 1775 InputDemandedElts = DemandedElts; 1776 } else if ((VWidth % InVWidth) == 0) { 1777 // If the number of elements in the output is a multiple of the number of 1778 // elements in the input then an input element is live if any of the 1779 // corresponding output elements are live. 1780 Ratio = VWidth / InVWidth; 1781 for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) 1782 if (DemandedElts[OutIdx]) 1783 InputDemandedElts.setBit(OutIdx / Ratio); 1784 } else if ((InVWidth % VWidth) == 0) { 1785 // If the number of elements in the input is a multiple of the number of 1786 // elements in the output then an input element is live if the 1787 // corresponding output element is live. 1788 Ratio = InVWidth / VWidth; 1789 for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx) 1790 if (DemandedElts[InIdx / Ratio]) 1791 InputDemandedElts.setBit(InIdx); 1792 } else { 1793 // Unsupported so far. 1794 break; 1795 } 1796 1797 simplifyAndSetOp(I, 0, InputDemandedElts, PoisonElts2); 1798 1799 if (VWidth == InVWidth) { 1800 PoisonElts = PoisonElts2; 1801 } else if ((VWidth % InVWidth) == 0) { 1802 // If the number of elements in the output is a multiple of the number of 1803 // elements in the input then an output element is undef if the 1804 // corresponding input element is undef. 1805 for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) 1806 if (PoisonElts2[OutIdx / Ratio]) 1807 PoisonElts.setBit(OutIdx); 1808 } else if ((InVWidth % VWidth) == 0) { 1809 // If the number of elements in the input is a multiple of the number of 1810 // elements in the output then an output element is undef if all of the 1811 // corresponding input elements are undef. 1812 for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) { 1813 APInt SubUndef = PoisonElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio); 1814 if (SubUndef.popcount() == Ratio) 1815 PoisonElts.setBit(OutIdx); 1816 } 1817 } else { 1818 llvm_unreachable("Unimp"); 1819 } 1820 break; 1821 } 1822 case Instruction::FPTrunc: 1823 case Instruction::FPExt: 1824 simplifyAndSetOp(I, 0, DemandedElts, PoisonElts); 1825 break; 1826 1827 case Instruction::Call: { 1828 IntrinsicInst *II = dyn_cast<IntrinsicInst>(I); 1829 if (!II) break; 1830 switch (II->getIntrinsicID()) { 1831 case Intrinsic::masked_gather: // fallthrough 1832 case Intrinsic::masked_load: { 1833 // Subtlety: If we load from a pointer, the pointer must be valid 1834 // regardless of whether the element is demanded. Doing otherwise risks 1835 // segfaults which didn't exist in the original program. 1836 APInt DemandedPtrs(APInt::getAllOnes(VWidth)), 1837 DemandedPassThrough(DemandedElts); 1838 if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2))) 1839 for (unsigned i = 0; i < VWidth; i++) { 1840 Constant *CElt = CV->getAggregateElement(i); 1841 if (CElt->isNullValue()) 1842 DemandedPtrs.clearBit(i); 1843 else if (CElt->isAllOnesValue()) 1844 DemandedPassThrough.clearBit(i); 1845 } 1846 if (II->getIntrinsicID() == Intrinsic::masked_gather) 1847 simplifyAndSetOp(II, 0, DemandedPtrs, PoisonElts2); 1848 simplifyAndSetOp(II, 3, DemandedPassThrough, PoisonElts3); 1849 1850 // Output elements are undefined if the element from both sources are. 1851 // TODO: can strengthen via mask as well. 1852 PoisonElts = PoisonElts2 & PoisonElts3; 1853 break; 1854 } 1855 default: { 1856 // Handle target specific intrinsics 1857 std::optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic( 1858 *II, DemandedElts, PoisonElts, PoisonElts2, PoisonElts3, 1859 simplifyAndSetOp); 1860 if (V) 1861 return *V; 1862 break; 1863 } 1864 } // switch on IntrinsicID 1865 break; 1866 } // case Call 1867 } // switch on Opcode 1868 1869 // TODO: We bail completely on integer div/rem and shifts because they have 1870 // UB/poison potential, but that should be refined. 1871 BinaryOperator *BO; 1872 if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) { 1873 Value *X = BO->getOperand(0); 1874 Value *Y = BO->getOperand(1); 1875 1876 // Look for an equivalent binop except that one operand has been shuffled. 1877 // If the demand for this binop only includes elements that are the same as 1878 // the other binop, then we may be able to replace this binop with a use of 1879 // the earlier one. 1880 // 1881 // Example: 1882 // %other_bo = bo (shuf X, {0}), Y 1883 // %this_extracted_bo = extelt (bo X, Y), 0 1884 // --> 1885 // %other_bo = bo (shuf X, {0}), Y 1886 // %this_extracted_bo = extelt %other_bo, 0 1887 // 1888 // TODO: Handle demand of an arbitrary single element or more than one 1889 // element instead of just element 0. 1890 // TODO: Unlike general demanded elements transforms, this should be safe 1891 // for any (div/rem/shift) opcode too. 1892 if (DemandedElts == 1 && !X->hasOneUse() && !Y->hasOneUse() && 1893 BO->hasOneUse() ) { 1894 1895 auto findShufBO = [&](bool MatchShufAsOp0) -> User * { 1896 // Try to use shuffle-of-operand in place of an operand: 1897 // bo X, Y --> bo (shuf X), Y 1898 // bo X, Y --> bo X, (shuf Y) 1899 BinaryOperator::BinaryOps Opcode = BO->getOpcode(); 1900 Value *ShufOp = MatchShufAsOp0 ? X : Y; 1901 Value *OtherOp = MatchShufAsOp0 ? Y : X; 1902 for (User *U : OtherOp->users()) { 1903 ArrayRef<int> Mask; 1904 auto Shuf = m_Shuffle(m_Specific(ShufOp), m_Value(), m_Mask(Mask)); 1905 if (BO->isCommutative() 1906 ? match(U, m_c_BinOp(Opcode, Shuf, m_Specific(OtherOp))) 1907 : MatchShufAsOp0 1908 ? match(U, m_BinOp(Opcode, Shuf, m_Specific(OtherOp))) 1909 : match(U, m_BinOp(Opcode, m_Specific(OtherOp), Shuf))) 1910 if (match(Mask, m_ZeroMask()) && Mask[0] != PoisonMaskElem) 1911 if (DT.dominates(U, I)) 1912 return U; 1913 } 1914 return nullptr; 1915 }; 1916 1917 if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ true)) 1918 return ShufBO; 1919 if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ false)) 1920 return ShufBO; 1921 } 1922 1923 simplifyAndSetOp(I, 0, DemandedElts, PoisonElts); 1924 simplifyAndSetOp(I, 1, DemandedElts, PoisonElts2); 1925 1926 // Output elements are undefined if both are undefined. Consider things 1927 // like undef & 0. The result is known zero, not undef. 1928 PoisonElts &= PoisonElts2; 1929 } 1930 1931 // If we've proven all of the lanes poison, return a poison value. 1932 // TODO: Intersect w/demanded lanes 1933 if (PoisonElts.isAllOnes()) 1934 return PoisonValue::get(I->getType()); 1935 1936 return MadeChange ? I : nullptr; 1937 } 1938 1939 /// For floating-point classes that resolve to a single bit pattern, return that 1940 /// value. 1941 static Constant *getFPClassConstant(Type *Ty, FPClassTest Mask) { 1942 switch (Mask) { 1943 case fcPosZero: 1944 return ConstantFP::getZero(Ty); 1945 case fcNegZero: 1946 return ConstantFP::getZero(Ty, true); 1947 case fcPosInf: 1948 return ConstantFP::getInfinity(Ty); 1949 case fcNegInf: 1950 return ConstantFP::getInfinity(Ty, true); 1951 case fcNone: 1952 return PoisonValue::get(Ty); 1953 default: 1954 return nullptr; 1955 } 1956 } 1957 1958 Value *InstCombinerImpl::SimplifyDemandedUseFPClass( 1959 Value *V, const FPClassTest DemandedMask, KnownFPClass &Known, 1960 unsigned Depth, Instruction *CxtI) { 1961 assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); 1962 Type *VTy = V->getType(); 1963 1964 assert(Known == KnownFPClass() && "expected uninitialized state"); 1965 1966 if (DemandedMask == fcNone) 1967 return isa<UndefValue>(V) ? nullptr : PoisonValue::get(VTy); 1968 1969 if (Depth == MaxAnalysisRecursionDepth) 1970 return nullptr; 1971 1972 Instruction *I = dyn_cast<Instruction>(V); 1973 if (!I) { 1974 // Handle constants and arguments 1975 Known = computeKnownFPClass(V, fcAllFlags, CxtI, Depth + 1); 1976 Value *FoldedToConst = 1977 getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses); 1978 return FoldedToConst == V ? nullptr : FoldedToConst; 1979 } 1980 1981 if (!I->hasOneUse()) 1982 return nullptr; 1983 1984 // TODO: Should account for nofpclass/FastMathFlags on current instruction 1985 switch (I->getOpcode()) { 1986 case Instruction::FNeg: { 1987 if (SimplifyDemandedFPClass(I, 0, llvm::fneg(DemandedMask), Known, 1988 Depth + 1)) 1989 return I; 1990 Known.fneg(); 1991 break; 1992 } 1993 case Instruction::Call: { 1994 CallInst *CI = cast<CallInst>(I); 1995 switch (CI->getIntrinsicID()) { 1996 case Intrinsic::fabs: 1997 if (SimplifyDemandedFPClass(I, 0, llvm::inverse_fabs(DemandedMask), Known, 1998 Depth + 1)) 1999 return I; 2000 Known.fabs(); 2001 break; 2002 case Intrinsic::arithmetic_fence: 2003 if (SimplifyDemandedFPClass(I, 0, DemandedMask, Known, Depth + 1)) 2004 return I; 2005 break; 2006 case Intrinsic::copysign: { 2007 // Flip on more potentially demanded classes 2008 const FPClassTest DemandedMaskAnySign = llvm::unknown_sign(DemandedMask); 2009 if (SimplifyDemandedFPClass(I, 0, DemandedMaskAnySign, Known, Depth + 1)) 2010 return I; 2011 2012 if ((DemandedMask & fcPositive) == fcNone) { 2013 // Roundabout way of replacing with fneg(fabs) 2014 I->setOperand(1, ConstantFP::get(VTy, -1.0)); 2015 return I; 2016 } 2017 2018 if ((DemandedMask & fcNegative) == fcNone) { 2019 // Roundabout way of replacing with fabs 2020 I->setOperand(1, ConstantFP::getZero(VTy)); 2021 return I; 2022 } 2023 2024 KnownFPClass KnownSign = 2025 computeKnownFPClass(I->getOperand(1), fcAllFlags, CxtI, Depth + 1); 2026 Known.copysign(KnownSign); 2027 break; 2028 } 2029 default: 2030 Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1); 2031 break; 2032 } 2033 2034 break; 2035 } 2036 case Instruction::Select: { 2037 KnownFPClass KnownLHS, KnownRHS; 2038 if (SimplifyDemandedFPClass(I, 2, DemandedMask, KnownRHS, Depth + 1) || 2039 SimplifyDemandedFPClass(I, 1, DemandedMask, KnownLHS, Depth + 1)) 2040 return I; 2041 2042 if (KnownLHS.isKnownNever(DemandedMask)) 2043 return I->getOperand(2); 2044 if (KnownRHS.isKnownNever(DemandedMask)) 2045 return I->getOperand(1); 2046 2047 // TODO: Recognize clamping patterns 2048 Known = KnownLHS | KnownRHS; 2049 break; 2050 } 2051 default: 2052 Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1); 2053 break; 2054 } 2055 2056 return getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses); 2057 } 2058 2059 bool InstCombinerImpl::SimplifyDemandedFPClass(Instruction *I, unsigned OpNo, 2060 FPClassTest DemandedMask, 2061 KnownFPClass &Known, 2062 unsigned Depth) { 2063 Use &U = I->getOperandUse(OpNo); 2064 Value *NewVal = 2065 SimplifyDemandedUseFPClass(U.get(), DemandedMask, Known, Depth, I); 2066 if (!NewVal) 2067 return false; 2068 if (Instruction *OpInst = dyn_cast<Instruction>(U)) 2069 salvageDebugInfo(*OpInst); 2070 2071 replaceUse(U, NewVal); 2072 return true; 2073 } 2074