//===- InstCombineCompares.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the visitICmp and visitFCmp functions.
//
//===----------------------------------------------------------------------===//

#include "InstCombineInternal.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/Utils/Local.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <bitset>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "instcombine"

// How many times is a select replaced by one of its operands?
STATISTIC(NumSel, "Number of select opts");

/// Compute Result = In1+In2, returning true if the result overflowed for this
/// type.
static bool addWithOverflow(APInt &Result, const APInt &In1,
                            const APInt &In2, bool IsSigned = false) {
  bool Overflow;
  if (IsSigned)
    Result = In1.sadd_ov(In2, Overflow);
  else
    Result = In1.uadd_ov(In2, Overflow);

  return Overflow;
}

/// Compute Result = In1-In2, returning true if the result overflowed for this
/// type.
static bool subWithOverflow(APInt &Result, const APInt &In1,
                            const APInt &In2, bool IsSigned = false) {
  bool Overflow;
  if (IsSigned)
    Result = In1.ssub_ov(In2, Overflow);
  else
    Result = In1.usub_ov(In2, Overflow);

  return Overflow;
}

/// Given an icmp instruction, return true if any use of this comparison is a
/// branch on sign bit comparison.
static bool hasBranchUse(ICmpInst &I) {
  for (auto *U : I.users())
    if (isa<BranchInst>(U))
      return true;
  return false;
}

/// Returns true if the exploded icmp can be expressed as a signed comparison
/// to zero and updates the predicate accordingly.
/// The signedness of the comparison is preserved.
/// TODO: Refactor with decomposeBitTestICmp()?
static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) {
  if (!ICmpInst::isSigned(Pred))
    return false;

  if (C.isZero())
    return ICmpInst::isRelational(Pred);

  if (C.isOne()) {
    if (Pred == ICmpInst::ICMP_SLT) {
      Pred = ICmpInst::ICMP_SLE;
      return true;
    }
  } else if (C.isAllOnes()) {
    if (Pred == ICmpInst::ICMP_SGT) {
      Pred = ICmpInst::ICMP_SGE;
      return true;
    }
  }

  return false;
}

/// This is called when we see this pattern:
///   cmp pred (load (gep GV, ...)), cmpcst
/// where GV is a global variable with a constant initializer. Try to simplify
/// this into some simple computation that does not need the load. For example
/// we can optimize "icmp eq (load (gep "foo", 0, i)), 0" into "icmp eq i, 3".
///
/// If AndCst is non-null, then the loaded value is masked with that constant
/// before doing the comparison. This handles cases like "A[i]&4 == 0".
Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
    LoadInst *LI, GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI,
    ConstantInt *AndCst) {
  if (LI->isVolatile() || LI->getType() != GEP->getResultElementType() ||
      GV->getValueType() != GEP->getSourceElementType() || !GV->isConstant() ||
      !GV->hasDefinitiveInitializer())
    return nullptr;

  Constant *Init = GV->getInitializer();
  if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init))
    return nullptr;

  uint64_t ArrayElementCount = Init->getType()->getArrayNumElements();
  // Don't blow up on huge arrays.
  if (ArrayElementCount > MaxArraySizeForCombine)
    return nullptr;

  // There are many forms of this optimization we can handle; for now, just do
  // the simple index into a single-dimensional array.
  //
  // Require: GEP GV, 0, i {{, constant indices}}
  if (GEP->getNumOperands() < 3 || !isa<ConstantInt>(GEP->getOperand(1)) ||
      !cast<ConstantInt>(GEP->getOperand(1))->isZero() ||
      isa<Constant>(GEP->getOperand(2)))
    return nullptr;

  // Check that indices after the variable are constants and in-range for the
  // type they index. Collect the indices. This is typically for arrays of
  // structs.
  SmallVector<unsigned, 4> LaterIndices;

  Type *EltTy = Init->getType()->getArrayElementType();
  for (unsigned i = 3, e = GEP->getNumOperands(); i != e; ++i) {
    ConstantInt *Idx = dyn_cast<ConstantInt>(GEP->getOperand(i));
    if (!Idx)
      return nullptr; // Variable index.

    uint64_t IdxVal = Idx->getZExtValue();
    if ((unsigned)IdxVal != IdxVal)
      return nullptr; // Too large array index.

    if (StructType *STy = dyn_cast<StructType>(EltTy))
      EltTy = STy->getElementType(IdxVal);
    else if (ArrayType *ATy = dyn_cast<ArrayType>(EltTy)) {
      if (IdxVal >= ATy->getNumElements())
        return nullptr;
      EltTy = ATy->getElementType();
    } else {
      return nullptr; // Unknown type.
    }

    LaterIndices.push_back(IdxVal);
  }

  enum { Overdefined = -3, Undefined = -2 };

  // Variables for our state machines.

  // FirstTrueElement/SecondTrueElement - Used to emit a comparison of the form
  // "i == 47 | i == 87", where 47 is the first index the condition is true for,
  // and 87 is the second (and last) index. FirstTrueElement is -2 when
  // undefined, otherwise set to the first true element. SecondTrueElement is
  // -2 when undefined, -3 when overdefined and >= 0 when that index is true.
  int FirstTrueElement = Undefined, SecondTrueElement = Undefined;

  // FirstFalseElement/SecondFalseElement - Used to emit a comparison of the
  // form "i != 47 & i != 87". Same state transitions as for true elements.
  int FirstFalseElement = Undefined, SecondFalseElement = Undefined;

  /// TrueRangeEnd/FalseRangeEnd - In conjunction with First*Element, these
  /// define a state machine that triggers for ranges of values that the index
  /// is true or false for. This triggers on things like "abbbbc"[i] == 'b'.
  /// This is -2 when undefined, -3 when overdefined, and otherwise the last
  /// index in the range (inclusive). We use -2 for undefined here because we
  /// use relative comparisons and don't want 0-1 to match -1.
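  ///
  /// For example, scanning "aabbbccc"[i] == 'b' leaves FirstTrueElement == 2
  /// and TrueRangeEnd == 4 (with both pair state machines overdefined), so
  /// the range check below emits (i - 2) u< 3.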
  int TrueRangeEnd = Undefined, FalseRangeEnd = Undefined;

  // MagicBitvector - This is a magic bitvector where we set a bit if the
  // comparison is true for element 'i'. If there are 64 elements or less in
  // the array, this will fully represent all the comparison results.
  uint64_t MagicBitvector = 0;

  // Scan the array and see if one of our patterns matches.
  Constant *CompareRHS = cast<Constant>(ICI.getOperand(1));
  for (unsigned i = 0, e = ArrayElementCount; i != e; ++i) {
    Constant *Elt = Init->getAggregateElement(i);
    if (!Elt)
      return nullptr;

    // If this is indexing an array of structures, get the structure element.
    if (!LaterIndices.empty()) {
      Elt = ConstantFoldExtractValueInstruction(Elt, LaterIndices);
      if (!Elt)
        return nullptr;
    }

    // If the element is masked, handle it.
    if (AndCst) {
      Elt = ConstantFoldBinaryOpOperands(Instruction::And, Elt, AndCst, DL);
      if (!Elt)
        return nullptr;
    }

    // Find out if the comparison would be true or false for the i'th element.
    Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt,
                                                  CompareRHS, DL, &TLI);
    // If the result is undef for this element, ignore it.
    if (isa<UndefValue>(C)) {
      // Extend range state machines to cover this element in case there is an
      // undef in the middle of the range.
      if (TrueRangeEnd == (int)i - 1)
        TrueRangeEnd = i;
      if (FalseRangeEnd == (int)i - 1)
        FalseRangeEnd = i;
      continue;
    }

    // If we can't compute the result for any of the elements, we have to give
    // up evaluating the entire conditional.
    if (!isa<ConstantInt>(C))
      return nullptr;

    // Otherwise, we know if the comparison is true or false for this element,
    // update our state machines.
    bool IsTrueForElt = !cast<ConstantInt>(C)->isZero();

    // State machine for single/double/range index comparison.
    if (IsTrueForElt) {
      // Update the TrueElement state machine.
      if (FirstTrueElement == Undefined)
        FirstTrueElement = TrueRangeEnd = i; // First true element.
      else {
        // Update double-compare state machine.
        if (SecondTrueElement == Undefined)
          SecondTrueElement = i;
        else
          SecondTrueElement = Overdefined;

        // Update range state machine.
        if (TrueRangeEnd == (int)i - 1)
          TrueRangeEnd = i;
        else
          TrueRangeEnd = Overdefined;
      }
    } else {
      // Update the FalseElement state machine.
      if (FirstFalseElement == Undefined)
        FirstFalseElement = FalseRangeEnd = i; // First false element.
      else {
        // Update double-compare state machine.
        if (SecondFalseElement == Undefined)
          SecondFalseElement = i;
        else
          SecondFalseElement = Overdefined;

        // Update range state machine.
        if (FalseRangeEnd == (int)i - 1)
          FalseRangeEnd = i;
        else
          FalseRangeEnd = Overdefined;
      }
    }

    // If this element is in range, update our magic bitvector.
    if (i < 64 && IsTrueForElt)
      MagicBitvector |= 1ULL << i;

    // If all of our states become overdefined, bail out early. Since the
    // predicate is expensive, only check it every 8 elements. This is only
    // really useful for really huge arrays.
    if ((i & 8) == 0 && i >= 64 && SecondTrueElement == Overdefined &&
        SecondFalseElement == Overdefined && TrueRangeEnd == Overdefined &&
        FalseRangeEnd == Overdefined)
      return nullptr;
  }

  // Now that we've scanned the entire array, emit our new comparison(s). We
  // order the state machines in complexity of the generated code.
  Value *Idx = GEP->getOperand(2);

  // If the index is larger than the pointer offset size of the target, truncate
  // the index down like the GEP would do implicitly. We don't have to do this
  // for an inbounds GEP because the index can't be out of range.
  if (!GEP->isInBounds()) {
    Type *PtrIdxTy = DL.getIndexType(GEP->getType());
    unsigned OffsetSize = PtrIdxTy->getIntegerBitWidth();
    if (Idx->getType()->getPrimitiveSizeInBits().getFixedValue() > OffsetSize)
      Idx = Builder.CreateTrunc(Idx, PtrIdxTy);
  }

  // If inbounds keyword is not present, Idx * ElementSize can overflow.
  // Let's assume that ElementSize is 2 and the wanted value is at offset 0.
  // Then, there are two possible values for Idx to match offset 0:
  // 0x00..00, 0x80..00.
  // Emitting 'icmp eq Idx, 0' isn't correct in this case because the
  // comparison is false if Idx was 0x80..00.
  // We need to erase the highest countTrailingZeros(ElementSize) bits of Idx.
  unsigned ElementSize =
      DL.getTypeAllocSize(Init->getType()->getArrayElementType());
  auto MaskIdx = [&](Value *Idx) {
    if (!GEP->isInBounds() && llvm::countr_zero(ElementSize) != 0) {
      Value *Mask = ConstantInt::get(Idx->getType(), -1);
      Mask = Builder.CreateLShr(Mask, llvm::countr_zero(ElementSize));
      Idx = Builder.CreateAnd(Idx, Mask);
    }
    return Idx;
  };

  // If the comparison is only true for one or two elements, emit direct
  // comparisons.
  if (SecondTrueElement != Overdefined) {
    Idx = MaskIdx(Idx);
    // None true -> false.
    if (FirstTrueElement == Undefined)
      return replaceInstUsesWith(ICI, Builder.getFalse());

    Value *FirstTrueIdx = ConstantInt::get(Idx->getType(), FirstTrueElement);

    // True for one element -> 'i == 47'.
    if (SecondTrueElement == Undefined)
      return new ICmpInst(ICmpInst::ICMP_EQ, Idx, FirstTrueIdx);

    // True for two elements -> 'i == 47 | i == 72'.
    Value *C1 = Builder.CreateICmpEQ(Idx, FirstTrueIdx);
    Value *SecondTrueIdx = ConstantInt::get(Idx->getType(), SecondTrueElement);
    Value *C2 = Builder.CreateICmpEQ(Idx, SecondTrueIdx);
    return BinaryOperator::CreateOr(C1, C2);
  }

  // If the comparison is only false for one or two elements, emit direct
  // comparisons.
  if (SecondFalseElement != Overdefined) {
    Idx = MaskIdx(Idx);
    // None false -> true.
    if (FirstFalseElement == Undefined)
      return replaceInstUsesWith(ICI, Builder.getTrue());

    Value *FirstFalseIdx = ConstantInt::get(Idx->getType(), FirstFalseElement);

    // False for one element -> 'i != 47'.
    if (SecondFalseElement == Undefined)
      return new ICmpInst(ICmpInst::ICMP_NE, Idx, FirstFalseIdx);

    // False for two elements -> 'i != 47 & i != 72'.
    Value *C1 = Builder.CreateICmpNE(Idx, FirstFalseIdx);
    Value *SecondFalseIdx =
        ConstantInt::get(Idx->getType(), SecondFalseElement);
    Value *C2 = Builder.CreateICmpNE(Idx, SecondFalseIdx);
    return BinaryOperator::CreateAnd(C1, C2);
  }

  // If the comparison can be replaced with a range comparison for the elements
  // where it is true, emit the range check.
  if (TrueRangeEnd != Overdefined) {
    assert(TrueRangeEnd != FirstTrueElement && "Should emit single compare");
    Idx = MaskIdx(Idx);

    // Generate (i-FirstTrue) <u (TrueRangeEnd-FirstTrue+1).
    if (FirstTrueElement) {
      Value *Offs = ConstantInt::get(Idx->getType(), -FirstTrueElement);
      Idx = Builder.CreateAdd(Idx, Offs);
    }

    Value *End =
        ConstantInt::get(Idx->getType(), TrueRangeEnd - FirstTrueElement + 1);
    return new ICmpInst(ICmpInst::ICMP_ULT, Idx, End);
  }

  // False range check.
  if (FalseRangeEnd != Overdefined) {
    assert(FalseRangeEnd != FirstFalseElement && "Should emit single compare");
    Idx = MaskIdx(Idx);
    // Generate (i-FirstFalse) >u (FalseRangeEnd-FirstFalse).
    if (FirstFalseElement) {
      Value *Offs = ConstantInt::get(Idx->getType(), -FirstFalseElement);
      Idx = Builder.CreateAdd(Idx, Offs);
    }

    Value *End =
        ConstantInt::get(Idx->getType(), FalseRangeEnd - FirstFalseElement);
    return new ICmpInst(ICmpInst::ICMP_UGT, Idx, End);
  }

  // If a magic bitvector captures the entire comparison state
  // of this load, replace it with computation that does:
  //   ((magic_cst >> i) & 1) != 0
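  //
  // For example, if the comparison is true exactly for the even elements of
  // an eight-element array, MagicBitvector is 0b01010101 and the load becomes
  // ((0x55 >> i) & 1) != 0.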
  {
    Type *Ty = nullptr;

    // Look for an appropriate type:
    // - The type of Idx if the magic fits
    // - The smallest fitting legal type
    if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth())
      Ty = Idx->getType();
    else
      Ty = DL.getSmallestLegalIntType(Init->getContext(), ArrayElementCount);

    if (Ty) {
      Idx = MaskIdx(Idx);
      Value *V = Builder.CreateIntCast(Idx, Ty, false);
      V = Builder.CreateLShr(ConstantInt::get(Ty, MagicBitvector), V);
      V = Builder.CreateAnd(ConstantInt::get(Ty, 1), V);
      return new ICmpInst(ICmpInst::ICMP_NE, V, ConstantInt::get(Ty, 0));
    }
  }

  return nullptr;
}

/// Returns true if we can rewrite Start as a GEP with pointer Base
/// and some integer offset. The nodes that need to be re-written
/// for this transformation will be added to Explored.
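///
/// For example, if Start is "gep inbounds i8, ptr %p, i64 4" and %p is in
/// turn "gep inbounds i8, ptr Base, i64 8", both GEPs are added to Explored
/// and Start can later be rewritten as Base plus an offset of 12.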
static bool canRewriteGEPAsOffset(Value *Start, Value *Base,
                                  const DataLayout &DL,
                                  SetVector<Value *> &Explored) {
  SmallVector<Value *, 16> WorkList(1, Start);
  Explored.insert(Base);

  // The following traversal gives us an order which can be used
  // when doing the final transformation. Since in the final
  // transformation we create the PHI replacement instructions first,
  // we don't have to get them in any particular order.
  //
  // However, for other instructions we will have to traverse the
  // operands of an instruction first, which means that we have to
  // do a post-order traversal.
  while (!WorkList.empty()) {
    SetVector<PHINode *> PHIs;

    while (!WorkList.empty()) {
      if (Explored.size() >= 100)
        return false;

      Value *V = WorkList.back();

      if (Explored.contains(V)) {
        WorkList.pop_back();
        continue;
      }

      if (!isa<GetElementPtrInst>(V) && !isa<PHINode>(V))
        // We've found some value that we can't explore which is different from
        // the base. Therefore we can't do this transformation.
        return false;

      if (auto *GEP = dyn_cast<GEPOperator>(V)) {
        // Only allow inbounds GEPs with at most one variable offset.
        auto IsNonConst = [](Value *V) { return !isa<ConstantInt>(V); };
        if (!GEP->isInBounds() || count_if(GEP->indices(), IsNonConst) > 1)
          return false;

        if (!Explored.contains(GEP->getOperand(0)))
          WorkList.push_back(GEP->getOperand(0));
      }

      if (WorkList.back() == V) {
        WorkList.pop_back();
        // We've finished visiting this node, mark it as such.
        Explored.insert(V);
      }

      if (auto *PN = dyn_cast<PHINode>(V)) {
        // We cannot transform PHIs on unsplittable basic blocks.
        if (isa<CatchSwitchInst>(PN->getParent()->getTerminator()))
          return false;
        Explored.insert(PN);
        PHIs.insert(PN);
      }
    }

    // Explore the PHI nodes further.
    for (auto *PN : PHIs)
      for (Value *Op : PN->incoming_values())
        if (!Explored.contains(Op))
          WorkList.push_back(Op);
  }

  // Make sure that we can do this. Since we can't insert GEPs in a basic
  // block before a PHI node, we can't easily do this transformation if
  // we have PHI node users of transformed instructions.
  for (Value *Val : Explored) {
    for (Value *Use : Val->uses()) {
      auto *PHI = dyn_cast<PHINode>(Use);
      auto *Inst = dyn_cast<Instruction>(Val);

      if (Inst == Base || Inst == PHI || !Inst || !PHI ||
          !Explored.contains(PHI))
        continue;

      if (PHI->getParent() == Inst->getParent())
        return false;
    }
  }
  return true;
}

// Sets the appropriate insert point on Builder where we can add
// a replacement Instruction for V (if that is possible).
static void setInsertionPoint(IRBuilder<> &Builder, Value *V,
                              bool Before = true) {
  if (auto *PHI = dyn_cast<PHINode>(V)) {
    BasicBlock *Parent = PHI->getParent();
    Builder.SetInsertPoint(Parent, Parent->getFirstInsertionPt());
    return;
  }
  if (auto *I = dyn_cast<Instruction>(V)) {
    if (!Before)
      I = &*std::next(I->getIterator());
    Builder.SetInsertPoint(I);
    return;
  }
  if (auto *A = dyn_cast<Argument>(V)) {
    // Set the insertion point in the entry block.
    BasicBlock &Entry = A->getParent()->getEntryBlock();
    Builder.SetInsertPoint(&Entry, Entry.getFirstInsertionPt());
    return;
  }
  // Otherwise, this is a constant and we don't need to set a new
  // insertion point.
  assert(isa<Constant>(V) && "Setting insertion point for unknown value!");
}

/// Returns a re-written value of Start as an indexed GEP using Base as a
/// pointer.
static Value *rewriteGEPAsOffset(Value *Start, Value *Base,
                                 const DataLayout &DL,
                                 SetVector<Value *> &Explored,
                                 InstCombiner &IC) {
  // Perform all the substitutions. This is a bit tricky because we can
  // have cycles in our use-def chains.
  // 1. Create the PHI nodes without any incoming values.
  // 2. Create all the other values.
  // 3. Add the edges for the PHI nodes.
  // 4. Emit GEPs to get the original pointers.
  // 5. Remove the original instructions.
  Type *IndexType = IntegerType::get(
      Base->getContext(), DL.getIndexTypeSizeInBits(Start->getType()));

  DenseMap<Value *, Value *> NewInsts;
  NewInsts[Base] = ConstantInt::getNullValue(IndexType);

  // Create the new PHI nodes, without adding any incoming values.
  for (Value *Val : Explored) {
    if (Val == Base)
      continue;
    // Create empty phi nodes. This avoids cyclic dependencies when creating
    // the remaining instructions.
    if (auto *PHI = dyn_cast<PHINode>(Val))
      NewInsts[PHI] = PHINode::Create(IndexType, PHI->getNumIncomingValues(),
                                      PHI->getName() + ".idx", PHI);
  }
  IRBuilder<> Builder(Base->getContext());

  // Create all the other instructions.
  for (Value *Val : Explored) {
    if (NewInsts.contains(Val))
      continue;

    if (auto *GEP = dyn_cast<GEPOperator>(Val)) {
      setInsertionPoint(Builder, GEP);
      Value *Op = NewInsts[GEP->getOperand(0)];
      Value *OffsetV = emitGEPOffset(&Builder, DL, GEP);
      if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero())
        NewInsts[GEP] = OffsetV;
      else
        NewInsts[GEP] = Builder.CreateNSWAdd(
            Op, OffsetV, GEP->getOperand(0)->getName() + ".add");
      continue;
    }
    if (isa<PHINode>(Val))
      continue;

    llvm_unreachable("Unexpected instruction type");
  }

  // Add the incoming values to the PHI nodes.
  for (Value *Val : Explored) {
    if (Val == Base)
      continue;
    // All the instructions have been created, we can now add edges to the
    // phi nodes.
    if (auto *PHI = dyn_cast<PHINode>(Val)) {
      PHINode *NewPhi = static_cast<PHINode *>(NewInsts[PHI]);
      for (unsigned I = 0, E = PHI->getNumIncomingValues(); I < E; ++I) {
        Value *NewIncoming = PHI->getIncomingValue(I);

        if (NewInsts.contains(NewIncoming))
          NewIncoming = NewInsts[NewIncoming];

        NewPhi->addIncoming(NewIncoming, PHI->getIncomingBlock(I));
      }
    }
  }

  for (Value *Val : Explored) {
    if (Val == Base)
      continue;

    setInsertionPoint(Builder, Val, false);
    // Create GEP for external users.
    Value *NewVal = Builder.CreateInBoundsGEP(
        Builder.getInt8Ty(), Base, NewInsts[Val], Val->getName() + ".ptr");
    IC.replaceInstUsesWith(*cast<Instruction>(Val), NewVal);
    // Add old instruction to worklist for DCE. We don't directly remove it
    // here because the original compare is one of the users.
    IC.addToWorklist(cast<Instruction>(Val));
  }

  return NewInsts[Start];
}

/// Converts (CMP GEPLHS, RHS) if this change would make RHS a constant.
/// We can look through PHIs, GEPs and casts in order to determine a common base
/// between GEPLHS and RHS.
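///
/// For example, if GEPLHS is an inbounds GEP of %base with constant offset 8
/// and RHS is a PHI over inbounds GEPs of %base, the pointer compare becomes
/// a signed compare of 8 against the PHI of the rewritten RHS offsets.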
static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
                                              ICmpInst::Predicate Cond,
                                              const DataLayout &DL,
                                              InstCombiner &IC) {
  // FIXME: Support vector of pointers.
  if (GEPLHS->getType()->isVectorTy())
    return nullptr;

  if (!GEPLHS->hasAllConstantIndices())
    return nullptr;

  APInt Offset(DL.getIndexTypeSizeInBits(GEPLHS->getType()), 0);
  Value *PtrBase =
      GEPLHS->stripAndAccumulateConstantOffsets(DL, Offset,
                                                /*AllowNonInbounds*/ false);

  // Bail if we looked through addrspacecast.
  if (PtrBase->getType() != GEPLHS->getType())
    return nullptr;

  // The set of nodes that will take part in this transformation.
  SetVector<Value *> Nodes;

  if (!canRewriteGEPAsOffset(RHS, PtrBase, DL, Nodes))
    return nullptr;

  // We know we can re-write this as
  //   ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2))
  // Since we've only looked through inbounds GEPs we know that we
  // can't have overflow on either side. We can therefore re-write
  // this as:
  //   OFFSET1 cmp OFFSET2
  Value *NewRHS = rewriteGEPAsOffset(RHS, PtrBase, DL, Nodes, IC);

  // RewriteGEPAsOffset has replaced RHS and all of its uses with a re-written
  // GEP having PtrBase as the pointer base, and has returned in NewRHS the
  // offset. Since Index is the offset of LHS to the base pointer, we will now
  // compare the offsets instead of comparing the pointers.
  return new ICmpInst(ICmpInst::getSignedPredicate(Cond),
                      IC.Builder.getInt(Offset), NewRHS);
}

/// Fold comparisons between a GEP instruction and something else. At this point
/// we know that the GEP is on the LHS of the comparison.
Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
                                           ICmpInst::Predicate Cond,
                                           Instruction &I) {
  // Don't transform signed compares of GEPs into index compares. Even if the
  // GEP is inbounds, the final add of the base pointer can have signed overflow
  // and would change the result of the icmp.
  // e.g. "&foo[0] <s &foo[1]" can't be folded to "true" because "foo" could be
  // the maximum signed value for the pointer type.
  if (ICmpInst::isSigned(Cond))
    return nullptr;

  // Look through bitcasts and addrspacecasts. We do not however want to remove
  // 0 GEPs.
  if (!isa<GetElementPtrInst>(RHS))
    RHS = RHS->stripPointerCasts();

  Value *PtrBase = GEPLHS->getOperand(0);
  if (PtrBase == RHS && (GEPLHS->isInBounds() || ICmpInst::isEquality(Cond))) {
    // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0).
    Value *Offset = EmitGEPOffset(GEPLHS);
    return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset,
                        Constant::getNullValue(Offset->getType()));
  }

  if (GEPLHS->isInBounds() && ICmpInst::isEquality(Cond) &&
      isa<Constant>(RHS) && cast<Constant>(RHS)->isNullValue() &&
      !NullPointerIsDefined(I.getFunction(),
                            RHS->getType()->getPointerAddressSpace())) {
    // For most address spaces, an allocation can't be placed at null, but null
    // itself is treated as a 0 size allocation in the in bounds rules. Thus,
    // the only valid inbounds address derived from null, is null itself.
    // Thus, we have four cases to consider:
    // 1) Base == nullptr, Offset == 0 -> inbounds, null
    // 2) Base == nullptr, Offset != 0 -> poison as the result is out of bounds
    // 3) Base != nullptr, Offset == (-base) -> poison (crossing allocations)
    // 4) Base != nullptr, Offset != (-base) -> nonnull (and possibly poison)
    //
    // (Note if we're indexing a type of size 0, that simply collapses into one
    // of the buckets above.)
    //
    // In general, we're allowed to make values less poison (i.e. remove
    // sources of full UB), so in this case, we just select between the two
    // non-poison cases (1 and 4 above).
    //
    // For vectors, we apply the same reasoning on a per-lane basis.
    auto *Base = GEPLHS->getPointerOperand();
    if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) {
      auto EC = cast<VectorType>(GEPLHS->getType())->getElementCount();
      Base = Builder.CreateVectorSplat(EC, Base);
    }
    return new ICmpInst(Cond, Base,
                        ConstantExpr::getPointerBitCastOrAddrSpaceCast(
                            cast<Constant>(RHS), Base->getType()));
  } else if (GEPOperator *GEPRHS = dyn_cast<GEPOperator>(RHS)) {
    // If the base pointers are different, but the indices are the same, just
    // compare the base pointer.
    if (PtrBase != GEPRHS->getOperand(0)) {
      bool IndicesTheSame =
          GEPLHS->getNumOperands() == GEPRHS->getNumOperands() &&
          GEPLHS->getPointerOperand()->getType() ==
              GEPRHS->getPointerOperand()->getType() &&
          GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType();
      if (IndicesTheSame)
        for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i)
          if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) {
            IndicesTheSame = false;
            break;
          }

      // If all indices are the same, just compare the base pointers.
      Type *BaseType = GEPLHS->getOperand(0)->getType();
      if (IndicesTheSame && CmpInst::makeCmpResultType(BaseType) == I.getType())
        return new ICmpInst(Cond, GEPLHS->getOperand(0), GEPRHS->getOperand(0));

      // If we're comparing GEPs with two base pointers that only differ in type
      // and both GEPs have only constant indices or just one use, then fold
      // the compare with the adjusted indices.
      // FIXME: Support vector of pointers.
      if (GEPLHS->isInBounds() && GEPRHS->isInBounds() &&
          (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) &&
          (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) &&
          PtrBase->stripPointerCasts() ==
              GEPRHS->getOperand(0)->stripPointerCasts() &&
          !GEPLHS->getType()->isVectorTy()) {
        Value *LOffset = EmitGEPOffset(GEPLHS);
        Value *ROffset = EmitGEPOffset(GEPRHS);

        // If we looked through an addrspacecast between different sized address
        // spaces, the LHS and RHS pointers are different sized
        // integers. Truncate to the smaller one.
        Type *LHSIndexTy = LOffset->getType();
        Type *RHSIndexTy = ROffset->getType();
        if (LHSIndexTy != RHSIndexTy) {
          if (LHSIndexTy->getPrimitiveSizeInBits().getFixedValue() <
              RHSIndexTy->getPrimitiveSizeInBits().getFixedValue()) {
            ROffset = Builder.CreateTrunc(ROffset, LHSIndexTy);
          } else
            LOffset = Builder.CreateTrunc(LOffset, RHSIndexTy);
        }

        Value *Cmp = Builder.CreateICmp(ICmpInst::getSignedPredicate(Cond),
                                        LOffset, ROffset);
        return replaceInstUsesWith(I, Cmp);
      }

      // Otherwise, the base pointers are different and the indices are
      // different. Try to convert this to an indexed compare by looking
      // through PHIs/casts.
      return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this);
    }

    bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds();
    if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands() &&
        GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType()) {
      // If the GEPs only differ by one index, compare it.
      unsigned NumDifferences = 0; // Keep track of # differences.
      unsigned DiffOperand = 0;    // The operand that differs.
      for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i)
        if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) {
          Type *LHSType = GEPLHS->getOperand(i)->getType();
          Type *RHSType = GEPRHS->getOperand(i)->getType();
          // FIXME: Better support for vector of pointers.
          if (LHSType->getPrimitiveSizeInBits() !=
                  RHSType->getPrimitiveSizeInBits() ||
              (GEPLHS->getType()->isVectorTy() &&
               (!LHSType->isVectorTy() || !RHSType->isVectorTy()))) {
            // Irreconcilable differences.
            NumDifferences = 2;
            break;
          }

          if (NumDifferences++)
            break;
          DiffOperand = i;
        }

      if (NumDifferences == 0) // SAME GEP?
        return replaceInstUsesWith(I, // No comparison is needed here.
                                   ConstantInt::get(I.getType(),
                                                    ICmpInst::isTrueWhenEqual(Cond)));

      else if (NumDifferences == 1 && GEPsInBounds) {
        Value *LHSV = GEPLHS->getOperand(DiffOperand);
        Value *RHSV = GEPRHS->getOperand(DiffOperand);
        // Make sure we do a signed comparison here.
        return new ICmpInst(ICmpInst::getSignedPredicate(Cond), LHSV, RHSV);
      }
    }

    // Only lower this if the icmp is the only user of the GEP or if we expect
    // the result to fold to a constant!
    if ((GEPsInBounds || CmpInst::isEquality(Cond)) &&
        (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) &&
        (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse())) {
      // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2)) ---> (OFFSET1 cmp OFFSET2)
      Value *L = EmitGEPOffset(GEPLHS);
      Value *R = EmitGEPOffset(GEPRHS);
      return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R);
    }
  }

  // Try to convert this to an indexed compare by looking through PHIs/casts
  // as a last resort.
  return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this);
}

bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
  // It would be tempting to fold away comparisons between allocas and any
  // pointer not based on that alloca (e.g. an argument). However, even
  // though such pointers cannot alias, they can still compare equal.
  //
  // But LLVM doesn't specify where allocas get their memory, so if the alloca
  // doesn't escape we can argue that it's impossible to guess its value, and we
  // can therefore act as if any such guesses are wrong.
  //
  // However, we need to ensure that this folding is consistent: We can't fold
  // one comparison to false, and then leave a different comparison against the
  // same value alone (as it might evaluate to true at runtime, leading to a
  // contradiction). As such, this code ensures that all comparisons are folded
  // at the same time, and there are no other escapes.

  struct CmpCaptureTracker : public CaptureTracker {
    AllocaInst *Alloca;
    bool Captured = false;
    /// The value of the map is a bit mask of which icmp operands the alloca is
    /// used in.
    SmallMapVector<ICmpInst *, unsigned, 4> ICmps;

    CmpCaptureTracker(AllocaInst *Alloca) : Alloca(Alloca) {}

    void tooManyUses() override { Captured = true; }

    bool captured(const Use *U) override {
      auto *ICmp = dyn_cast<ICmpInst>(U->getUser());
      // We need to check that U is based *only* on the alloca, and doesn't
      // have other contributions from a select/phi operand.
      // TODO: We could check whether getUnderlyingObjects() reduces to one
      // object, which would allow looking through phi nodes.
      if (ICmp && ICmp->isEquality() && getUnderlyingObject(*U) == Alloca) {
        // Collect equality icmps of the alloca, and don't treat them as
        // captures.
        auto Res = ICmps.insert({ICmp, 0});
        Res.first->second |= 1u << U->getOperandNo();
        return false;
      }

      Captured = true;
      return true;
    }
  };

  CmpCaptureTracker Tracker(Alloca);
  PointerMayBeCaptured(Alloca, &Tracker);
  if (Tracker.Captured)
    return false;

  bool Changed = false;
  for (auto [ICmp, Operands] : Tracker.ICmps) {
    switch (Operands) {
    case 1:
    case 2: {
      // The alloca is only used in one icmp operand. Assume that the
      // equality is false.
      auto *Res = ConstantInt::get(
          ICmp->getType(), ICmp->getPredicate() == ICmpInst::ICMP_NE);
      replaceInstUsesWith(*ICmp, Res);
      eraseInstFromFunction(*ICmp);
      Changed = true;
      break;
    }
    case 3:
      // Both icmp operands are based on the alloca, so this is comparing
      // pointer offsets, without leaking any information about the address
      // of the alloca. Ignore such comparisons.
      break;
    default:
      llvm_unreachable("Cannot happen");
    }
  }

  return Changed;
}

/// Fold "icmp pred (X+C), X".
Instruction *InstCombinerImpl::foldICmpAddOpConst(Value *X, const APInt &C,
                                                  ICmpInst::Predicate Pred) {
  // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0,
  // so the values can never be equal. Similarly for all other "or equals"
  // operators.
  assert(!!C && "C should not be zero!");

  // (X+1) <u X        --> X >u (MAXUINT-1)        --> X == 255
  // (X+2) <u X        --> X >u (MAXUINT-2)        --> X > 253
  // (X+MAXUINT) <u X  --> X >u (MAXUINT-MAXUINT)  --> X != 0
  if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
    Constant *R = ConstantInt::get(X->getType(),
                                   APInt::getMaxValue(C.getBitWidth()) - C);
    return new ICmpInst(ICmpInst::ICMP_UGT, X, R);
  }

  // (X+1) >u X        --> X <u (0-1)        --> X != 255
  // (X+2) >u X        --> X <u (0-2)        --> X <u 254
  // (X+MAXUINT) >u X  --> X <u (0-MAXUINT)  --> X <u 1  --> X == 0
  if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE)
    return new ICmpInst(ICmpInst::ICMP_ULT, X,
                        ConstantInt::get(X->getType(), -C));

  APInt SMax = APInt::getSignedMaxValue(C.getBitWidth());

  // (X+ 1) <s X       --> X >s (MAXSINT-1)        --> X == 127
  // (X+ 2) <s X       --> X >s (MAXSINT-2)        --> X >s 125
  // (X+MAXSINT) <s X  --> X >s (MAXSINT-MAXSINT)  --> X >s 0
  // (X+MINSINT) <s X  --> X >s (MAXSINT-MINSINT)  --> X >s -1
  // (X+ -2) <s X      --> X >s (MAXSINT- -2)      --> X >s 126
  // (X+ -1) <s X      --> X >s (MAXSINT- -1)      --> X != 127
  if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
    return new ICmpInst(ICmpInst::ICMP_SGT, X,
                        ConstantInt::get(X->getType(), SMax - C));

  // (X+ 1) >s X       --> X <s (MAXSINT-(1-1))       --> X != 127
  // (X+ 2) >s X       --> X <s (MAXSINT-(2-1))       --> X <s 126
  // (X+MAXSINT) >s X  --> X <s (MAXSINT-(MAXSINT-1)) --> X <s 1
  // (X+MINSINT) >s X  --> X <s (MAXSINT-(MINSINT-1)) --> X <s -2
  // (X+ -2) >s X      --> X <s (MAXSINT-(-2-1))      --> X <s -126
  // (X+ -1) >s X      --> X <s (MAXSINT-(-1-1))      --> X == -128

  assert(Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE);
  return new ICmpInst(ICmpInst::ICMP_SLT, X,
                      ConstantInt::get(X->getType(), SMax - (C - 1)));
}

/// Handle "(icmp eq/ne (ashr/lshr AP2, A), AP1)" ->
/// (icmp eq/ne A, Log2(AP2/AP1)) ->
/// (icmp eq/ne A, Log2(AP2) - Log2(AP1)).
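///
/// For example, with i8 operands, "icmp eq (lshr 24, A), 3" becomes
/// "icmp eq A, 3": 24 has three leading zeros and 3 has six, so Shift below
/// is 3, and 24 u>> 3 == 3.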
Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A,
                                                     const APInt &AP1,
                                                     const APInt &AP2) {
  assert(I.isEquality() && "Cannot fold icmp gt/lt");

  auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
    if (I.getPredicate() == I.ICMP_NE)
      Pred = CmpInst::getInversePredicate(Pred);
    return new ICmpInst(Pred, LHS, RHS);
  };

  // Don't bother doing any work for cases which InstSimplify handles.
  if (AP2.isZero())
    return nullptr;

  bool IsAShr = isa<AShrOperator>(I.getOperand(0));
  if (IsAShr) {
    if (AP2.isAllOnes())
      return nullptr;
    if (AP2.isNegative() != AP1.isNegative())
      return nullptr;
    if (AP2.sgt(AP1))
      return nullptr;
  }

  if (!AP1)
    // 'A' must be large enough to shift out the highest set bit.
    return getICmp(I.ICMP_UGT, A,
                   ConstantInt::get(A->getType(), AP2.logBase2()));

  if (AP1 == AP2)
    return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));

  int Shift;
  if (IsAShr && AP1.isNegative())
    Shift = AP1.countl_one() - AP2.countl_one();
  else
    Shift = AP1.countl_zero() - AP2.countl_zero();

  if (Shift > 0) {
    if (IsAShr && AP1 == AP2.ashr(Shift)) {
      // There are multiple solutions if we are comparing against -1 and the LHS
      // of the ashr is not a power of two.
      if (AP1.isAllOnes() && !AP2.isPowerOf2())
        return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift));
      return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
    } else if (AP1 == AP2.lshr(Shift)) {
      return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
    }
  }

  // Shifting const2 will never be equal to const1.
  // FIXME: This should always be handled by InstSimplify?
  auto *TorF = ConstantInt::get(I.getType(), I.getPredicate() == I.ICMP_NE);
  return replaceInstUsesWith(I, TorF);
}

/// Handle "(icmp eq/ne (shl AP2, A), AP1)" ->
/// (icmp eq/ne A, TrailingZeros(AP1) - TrailingZeros(AP2)).
Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A,
                                                     const APInt &AP1,
                                                     const APInt &AP2) {
  assert(I.isEquality() && "Cannot fold icmp gt/lt");

  auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
    if (I.getPredicate() == I.ICMP_NE)
      Pred = CmpInst::getInversePredicate(Pred);
    return new ICmpInst(Pred, LHS, RHS);
  };

  // Don't bother doing any work for cases which InstSimplify handles.
  if (AP2.isZero())
    return nullptr;

  unsigned AP2TrailingZeros = AP2.countr_zero();

  if (!AP1 && AP2TrailingZeros != 0)
    return getICmp(
        I.ICMP_UGE, A,
        ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros));

  if (AP1 == AP2)
    return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));

  // Get the distance between the lowest bits that are set.
  int Shift = AP1.countr_zero() - AP2TrailingZeros;

  if (Shift > 0 && AP2.shl(Shift) == AP1)
    return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));

  // Shifting const2 will never be equal to const1.
  // FIXME: This should always be handled by InstSimplify?
  auto *TorF = ConstantInt::get(I.getType(), I.getPredicate() == I.ICMP_NE);
  return replaceInstUsesWith(I, TorF);
}

/// The caller has matched a pattern of the form:
///   I = icmp ugt (add (add A, B), CI2), CI1
/// If this is of the form:
///   sum = a + b
///   if (sum+128 >u 255)
/// Then replace it with llvm.sadd.with.overflow.i8.
///
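/// (The range check works because sum+128 u> 255 holds exactly when sum lies
/// outside [-128, 127], i.e. exactly when the same addition performed in i8
/// would overflow.)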
static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
                                          ConstantInt *CI2, ConstantInt *CI1,
                                          InstCombinerImpl &IC) {
  // The transformation we're trying to do here is to transform this into an
  // llvm.sadd.with.overflow. To do this, we have to replace the original add
  // with a narrower add, and discard the add-with-constant that is part of the
  // range check (if we can't eliminate it, this isn't profitable).

  // In order to eliminate the add-with-constant, the compare must be its only
  // use.
  Instruction *AddWithCst = cast<Instruction>(I.getOperand(0));
  if (!AddWithCst->hasOneUse())
    return nullptr;

  // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow.
  if (!CI2->getValue().isPowerOf2())
    return nullptr;
  unsigned NewWidth = CI2->getValue().countr_zero();
  if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31)
    return nullptr;

  // The width of the new add formed is 1 more than the bias.
  ++NewWidth;

  // Check to see that CI1 is an all-ones value with NewWidth bits.
  if (CI1->getBitWidth() == NewWidth ||
      CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth))
    return nullptr;

  // This is only really a signed overflow check if the inputs have been
  // sign-extended; check for that condition. For example, if CI2 is 2^31 and
  // the operands of the add are 64 bits wide, we need at least 33 sign bits.
  if (IC.ComputeMaxSignificantBits(A, 0, &I) > NewWidth ||
      IC.ComputeMaxSignificantBits(B, 0, &I) > NewWidth)
    return nullptr;

  // In order to replace the original add with a narrower
  // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant
  // and truncates that discard the high bits of the add. Verify that this is
  // the case.
  Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0));
  for (User *U : OrigAdd->users()) {
    if (U == AddWithCst)
      continue;

    // Only accept truncates for now. We would really like a nice recursive
    // predicate like SimplifyDemandedBits, but which goes downwards the use-def
    // chain to see which bits of a value are actually demanded. If the
    // original add had another add which was then immediately truncated, we
    // could still do the transformation.
    TruncInst *TI = dyn_cast<TruncInst>(U);
    if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth)
      return nullptr;
  }

  // If the pattern matches, truncate the inputs to the narrower type and
  // use the sadd_with_overflow intrinsic to efficiently compute both the
  // result and the overflow bit.
  Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
  Function *F = Intrinsic::getDeclaration(
      I.getModule(), Intrinsic::sadd_with_overflow, NewType);

  InstCombiner::BuilderTy &Builder = IC.Builder;

  // Put the new code above the original add, in case there are any uses of the
  // add between the add and the compare.
  Builder.SetInsertPoint(OrigAdd);

  Value *TruncA = Builder.CreateTrunc(A, NewType, A->getName() + ".trunc");
  Value *TruncB = Builder.CreateTrunc(B, NewType, B->getName() + ".trunc");
  CallInst *Call = Builder.CreateCall(F, {TruncA, TruncB}, "sadd");
  Value *Add = Builder.CreateExtractValue(Call, 0, "sadd.result");
  Value *ZExt = Builder.CreateZExt(Add, OrigAdd->getType());

  // The inner add was the result of the narrow add, zero extended to the
  // wider type. Replace it with the result computed by the intrinsic.
  IC.replaceInstUsesWith(*OrigAdd, ZExt);
  IC.eraseInstFromFunction(*OrigAdd);

  // The original icmp gets replaced with the overflow value.
  return ExtractValueInst::Create(Call, 1, "sadd.overflow");
}

/// If we have:
///   icmp eq/ne (urem/srem %x, %y), 0
/// where %y is a power-of-two, we can replace this with a bit test:
///   icmp eq/ne (and %x, (add %y, -1)), 0
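///
/// For example, "icmp eq (urem %x, 8), 0" becomes "icmp eq (and %x, 7), 0".
/// This also applies when %y is not a constant, as long as it is known to be
/// a power of two.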
Instruction *InstCombinerImpl::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
  // This fold is only valid for equality predicates.
  if (!I.isEquality())
    return nullptr;
  ICmpInst::Predicate Pred;
  Value *X, *Y, *Zero;
  if (!match(&I, m_ICmp(Pred, m_OneUse(m_IRem(m_Value(X), m_Value(Y))),
                        m_CombineAnd(m_Zero(), m_Value(Zero)))))
    return nullptr;
  if (!isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, 0, &I))
    return nullptr;
  // This may increase instruction count since we don't enforce that Y is a
  // constant.
  Value *Mask = Builder.CreateAdd(Y, Constant::getAllOnesValue(Y->getType()));
  Value *Masked = Builder.CreateAnd(X, Mask);
  return ICmpInst::Create(Instruction::ICmp, Pred, Masked, Zero);
}

/// Fold equality-comparison between zero and any (maybe truncated) right-shift
/// by one-less-than-bitwidth into a sign test on the original value.
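///
/// For example, "icmp eq (lshr i32 %x, 31), 0" becomes "icmp sge i32 %x, 0",
/// and the ne form becomes "icmp slt i32 %x, 0".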
Instruction *InstCombinerImpl::foldSignBitTest(ICmpInst &I) {
  Instruction *Val;
  ICmpInst::Predicate Pred;
  if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero())))
    return nullptr;

  Value *X;
  Type *XTy;

  Constant *C;
  if (match(Val, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))))) {
    XTy = X->getType();
    unsigned XBitWidth = XTy->getScalarSizeInBits();
    if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
                                     APInt(XBitWidth, XBitWidth - 1))))
      return nullptr;
  } else if (isa<BinaryOperator>(Val) &&
             (X = reassociateShiftAmtsOfTwoSameDirectionShifts(
                  cast<BinaryOperator>(Val), SQ.getWithInstruction(Val),
                  /*AnalyzeForSignBitExtraction=*/true))) {
    XTy = X->getType();
  } else
    return nullptr;

  return ICmpInst::Create(Instruction::ICmp,
                          Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_SGE
                                                    : ICmpInst::ICMP_SLT,
                          X, ConstantInt::getNullValue(XTy));
}

// Handle icmp pred X, 0
Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) {
  CmpInst::Predicate Pred = Cmp.getPredicate();
  if (!match(Cmp.getOperand(1), m_Zero()))
    return nullptr;

  // (icmp sgt smin(PosA, B), 0) -> (icmp sgt B, 0)
  if (Pred == ICmpInst::ICMP_SGT) {
    Value *A, *B;
    if (match(Cmp.getOperand(0), m_SMin(m_Value(A), m_Value(B)))) {
      if (isKnownPositive(A, SQ.getWithInstruction(&Cmp)))
        return new ICmpInst(Pred, B, Cmp.getOperand(1));
      if (isKnownPositive(B, SQ.getWithInstruction(&Cmp)))
        return new ICmpInst(Pred, A, Cmp.getOperand(1));
    }
  }

  if (Instruction *New = foldIRemByPowerOfTwoToBitTest(Cmp))
    return New;

  // Given:
  //   icmp eq/ne (urem %x, %y), 0
  // If %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem':
  //   icmp eq/ne %x, 0
  Value *X, *Y;
  if (match(Cmp.getOperand(0), m_URem(m_Value(X), m_Value(Y))) &&
      ICmpInst::isEquality(Pred)) {
    KnownBits XKnown = computeKnownBits(X, 0, &Cmp);
    KnownBits YKnown = computeKnownBits(Y, 0, &Cmp);
    if (XKnown.countMaxPopulation() == 1 && YKnown.countMinPopulation() >= 2)
      return new ICmpInst(Pred, X, Cmp.getOperand(1));
  }

  // (icmp eq/ne (mul X, Y), 0) -> (icmp eq/ne X, 0) or (icmp eq/ne Y, 0),
  // depending on what we know about X/Y being odd or non-zero and about the
  // multiply not overflowing.
  if (match(Cmp.getOperand(0), m_Mul(m_Value(X), m_Value(Y))) &&
      ICmpInst::isEquality(Pred)) {

    KnownBits XKnown = computeKnownBits(X, 0, &Cmp);
    // if X % 2 != 0
    //    (icmp eq/ne Y)
    if (XKnown.countMaxTrailingZeros() == 0)
      return new ICmpInst(Pred, Y, Cmp.getOperand(1));

    KnownBits YKnown = computeKnownBits(Y, 0, &Cmp);
    // if Y % 2 != 0
    //    (icmp eq/ne X)
    if (YKnown.countMaxTrailingZeros() == 0)
      return new ICmpInst(Pred, X, Cmp.getOperand(1));

    auto *BO0 = cast<OverflowingBinaryOperator>(Cmp.getOperand(0));
    if (BO0->hasNoUnsignedWrap() || BO0->hasNoSignedWrap()) {
      const SimplifyQuery Q = SQ.getWithInstruction(&Cmp);
      // `isKnownNonZero` does more analysis than just `!KnownBits.One.isZero()`
      // but to avoid unnecessary work, first just check if this is an obvious
      // case.

      // if X non-zero and NoOverflow(X * Y)
      //    (icmp eq/ne Y)
      if (!XKnown.One.isZero() || isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT))
        return new ICmpInst(Pred, Y, Cmp.getOperand(1));

      // if Y non-zero and NoOverflow(X * Y)
      //    (icmp eq/ne X)
      if (!YKnown.One.isZero() || isKnownNonZero(Y, DL, 0, Q.AC, Q.CxtI, Q.DT))
        return new ICmpInst(Pred, X, Cmp.getOperand(1));
    }
    // Note, we are skipping cases:
    //   if Y % 2 != 0 AND X % 2 != 0
    //       (false/true)
    //   if X non-zero and Y non-zero and NoOverflow(X * Y)
    //       (false/true)
    // Those can be simplified later as we would have already replaced the (icmp
    // eq/ne (mul X, Y)) with (icmp eq/ne X/Y) and if X/Y is known non-zero that
    // will fold to a constant elsewhere.
  }
  return nullptr;
}

/// Fold icmp Pred X, C.
/// TODO: This code structure does not make sense. The saturating add fold
/// should be moved to some other helper and extended as noted below (it is also
/// possible that code has been made unnecessary - do we canonicalize IR to
/// overflow/saturating intrinsics or not?).
Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
  // Match the following pattern, which is a common idiom when writing
  // overflow-safe integer arithmetic functions. The source performs an addition
  // in wider type and explicitly checks for overflow using comparisons against
  // INT_MIN and INT_MAX. Simplify by using the sadd_with_overflow intrinsic.
  //
  // TODO: This could probably be generalized to handle other overflow-safe
  // operations if we worked out the formulas to compute the appropriate magic
  // constants.
  //
  // sum = a + b
  // if (sum+128 >u 255)  ...  -> llvm.sadd.with.overflow.i8
  CmpInst::Predicate Pred = Cmp.getPredicate();
  Value *Op0 = Cmp.getOperand(0), *Op1 = Cmp.getOperand(1);
  Value *A, *B;
  ConstantInt *CI, *CI2; // I = icmp ugt (add (add A, B), CI2), CI
  if (Pred == ICmpInst::ICMP_UGT && match(Op1, m_ConstantInt(CI)) &&
      match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2))))
    if (Instruction *Res = processUGT_ADDCST_ADD(Cmp, A, B, CI2, CI, *this))
      return Res;

  // icmp(phi(C1, C2, ...), C) -> phi(icmp(C1, C), icmp(C2, C), ...).
  Constant *C = dyn_cast<Constant>(Op1);
  if (!C)
    return nullptr;

  if (auto *Phi = dyn_cast<PHINode>(Op0))
    if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) {
      SmallVector<Constant *> Ops;
      for (Value *V : Phi->incoming_values()) {
        Constant *Res =
            ConstantFoldCompareInstOperands(Pred, cast<Constant>(V), C, DL);
        if (!Res)
          return nullptr;
        Ops.push_back(Res);
      }
      Builder.SetInsertPoint(Phi);
      PHINode *NewPhi = Builder.CreatePHI(Cmp.getType(), Phi->getNumOperands());
      for (auto [V, Pred] : zip(Ops, Phi->blocks()))
        NewPhi->addIncoming(V, Pred);
      return replaceInstUsesWith(Cmp, NewPhi);
    }

  if (Instruction *R = tryFoldInstWithCtpopWithNot(&Cmp))
    return R;

  return nullptr;
}

/// Canonicalize icmp instructions based on dominating conditions.
Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
  // We already checked simple implication in InstSimplify, only handle complex
  // cases here.
  Value *X = Cmp.getOperand(0), *Y = Cmp.getOperand(1);
  ICmpInst::Predicate DomPred;
  const APInt *C;
  if (!match(Y, m_APInt(C)))
    return nullptr;

  CmpInst::Predicate Pred = Cmp.getPredicate();
  ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *C);

  auto handleDomCond = [&](Value *DomCond, bool CondIsTrue) -> Instruction * {
    const APInt *DomC;
    if (!match(DomCond, m_ICmp(DomPred, m_Specific(X), m_APInt(DomC))))
      return nullptr;
    // We have 2 compares of a variable with constants. Calculate the constant
    // ranges of those compares to see if we can transform the 2nd compare:
    // DomBB:
    //   DomCond = icmp DomPred X, DomC
    //   br DomCond, CmpBB, FalseBB
    // CmpBB:
    //   Cmp = icmp Pred X, C
    if (!CondIsTrue)
      DomPred = CmpInst::getInversePredicate(DomPred);
    ConstantRange DominatingCR =
        ConstantRange::makeExactICmpRegion(DomPred, *DomC);
    ConstantRange Intersection = DominatingCR.intersectWith(CR);
    ConstantRange Difference = DominatingCR.difference(CR);
    if (Intersection.isEmptySet())
      return replaceInstUsesWith(Cmp, Builder.getFalse());
    if (Difference.isEmptySet())
      return replaceInstUsesWith(Cmp, Builder.getTrue());

    // Canonicalizing a sign bit comparison that gets used in a branch
    // pessimizes codegen by generating a branch-on-zero instruction instead
    // of a test and branch. So we avoid canonicalizing in such situations
    // because the test-and-branch instruction has better branch displacement
    // than the compare-and-branch instruction.
    bool UnusedBit;
    bool IsSignBit = isSignBitCheck(Pred, *C, UnusedBit);
    if (Cmp.isEquality() || (IsSignBit && hasBranchUse(Cmp)))
      return nullptr;

    // Avoid an infinite loop with min/max canonicalization.
    // TODO: This will be unnecessary if we canonicalize to min/max intrinsics.
    if (Cmp.hasOneUse() &&
        match(Cmp.user_back(), m_MaxOrMin(m_Value(), m_Value())))
      return nullptr;

    if (const APInt *EqC = Intersection.getSingleElement())
      return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder.getInt(*EqC));
    if (const APInt *NeC = Difference.getSingleElement())
      return new ICmpInst(ICmpInst::ICMP_NE, X, Builder.getInt(*NeC));
    return nullptr;
  };

  for (BranchInst *BI : DC.conditionsFor(X)) {
    auto *Cond = BI->getCondition();
    BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
    if (DT.dominates(Edge0, Cmp.getParent())) {
      if (auto *V = handleDomCond(Cond, true))
        return V;
    } else {
      BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
      if (DT.dominates(Edge1, Cmp.getParent()))
        if (auto *V = handleDomCond(Cond, false))
          return V;
    }
  }

  return nullptr;
}

/// Fold icmp (trunc X), C.
Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
                                                     TruncInst *Trunc,
                                                     const APInt &C) {
  ICmpInst::Predicate Pred = Cmp.getPredicate();
  Value *X = Trunc->getOperand(0);
  if (C.isOne() && C.getBitWidth() > 1) {
    // icmp slt trunc(signum(V)), 1 --> icmp slt V, 1
    Value *V = nullptr;
    if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V))))
      return new ICmpInst(ICmpInst::ICMP_SLT, V,
                          ConstantInt::get(V->getType(), 1));
  }

  Type *SrcTy = X->getType();
  unsigned DstBits = Trunc->getType()->getScalarSizeInBits(),
           SrcBits = SrcTy->getScalarSizeInBits();

  // TODO: Handle any shifted constant by subtracting trailing zeros.
  // TODO: Handle non-equality predicates.
  Value *Y;
  if (Cmp.isEquality() && match(X, m_Shl(m_One(), m_Value(Y)))) {
    // (trunc (1 << Y) to iN) == 0 --> Y u>= N
    // (trunc (1 << Y) to iN) != 0 --> Y u<  N
    if (C.isZero()) {
      auto NewPred = (Pred == Cmp.ICMP_EQ) ? Cmp.ICMP_UGE : Cmp.ICMP_ULT;
      return new ICmpInst(NewPred, Y, ConstantInt::get(SrcTy, DstBits));
    }
    // (trunc (1 << Y) to iN) == 2**C --> Y == C
    // (trunc (1 << Y) to iN) != 2**C --> Y != C
    if (C.isPowerOf2())
      return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2()));
  }

  if (Cmp.isEquality() && Trunc->hasOneUse()) {
    // Canonicalize to a mask and wider compare if the wide type is suitable:
    // (trunc X to i8) == C --> (X & 0xff) == (zext C)
    if (!SrcTy->isVectorTy() && shouldChangeType(DstBits, SrcBits)) {
      Constant *Mask =
          ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcBits, DstBits));
      Value *And = Builder.CreateAnd(X, Mask);
      Constant *WideC = ConstantInt::get(SrcTy, C.zext(SrcBits));
      return new ICmpInst(Pred, And, WideC);
    }

    // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all
    // of the high bits truncated out of x are known.
    KnownBits Known = computeKnownBits(X, 0, &Cmp);

    // If all the high bits are known, we can do this xform.
    if ((Known.Zero | Known.One).countl_one() >= SrcBits - DstBits) {
      // Pull in the high bits from known-ones set.
      APInt NewRHS = C.zext(SrcBits);
      NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits);
      return new ICmpInst(Pred, X, ConstantInt::get(SrcTy, NewRHS));
    }
  }

  // Look through truncated right-shift of the sign-bit for a sign-bit check:
  // trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] < 0  --> ShOp <  0
  // trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] > -1 --> ShOp > -1
  Value *ShOp;
  const APInt *ShAmtC;
  bool TrueIfSigned;
  if (isSignBitCheck(Pred, C, TrueIfSigned) &&
      match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) &&
      DstBits == SrcBits - ShAmtC->getZExtValue()) {
    return TrueIfSigned ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp,
                                       ConstantInt::getNullValue(SrcTy))
                        : new ICmpInst(ICmpInst::ICMP_SGT, ShOp,
                                       ConstantInt::getAllOnesValue(SrcTy));
  }

  return nullptr;
}

/// Fold icmp (trunc X), (trunc Y).
/// Fold icmp (trunc X), (zext Y).
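///
/// For example, "icmp ult (trunc i64 %x to i32), (zext i16 %y to i32)" becomes
/// "icmp ult i64 %x, (zext i16 %y to i64)" when the known bits of %x show that
/// the truncation is lossless.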
1523 KnownBits KnownX = llvm::computeKnownBits(X, /*Depth*/ 0, Q); 1524 if (KnownX.countMaxActiveBits() > TruncBits) 1525 return nullptr; 1526 1527 if (!YIsZext) { 1528 // If Y is also a trunc, make sure it is unneeded. 1529 KnownBits KnownY = llvm::computeKnownBits(Y, /*Depth*/ 0, Q); 1530 if (KnownY.countMaxActiveBits() > TruncBits) 1531 return nullptr; 1532 } 1533 1534 Value *NewY = Builder.CreateZExtOrTrunc(Y, X->getType()); 1535 return new ICmpInst(Pred, X, NewY); 1536 } 1537 1538 /// Fold icmp (xor X, Y), C. 1539 Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp, 1540 BinaryOperator *Xor, 1541 const APInt &C) { 1542 if (Instruction *I = foldICmpXorShiftConst(Cmp, Xor, C)) 1543 return I; 1544 1545 Value *X = Xor->getOperand(0); 1546 Value *Y = Xor->getOperand(1); 1547 const APInt *XorC; 1548 if (!match(Y, m_APInt(XorC))) 1549 return nullptr; 1550 1551 // If this is a comparison that tests the signbit (X < 0) or (x > -1), 1552 // fold the xor. 1553 ICmpInst::Predicate Pred = Cmp.getPredicate(); 1554 bool TrueIfSigned = false; 1555 if (isSignBitCheck(Cmp.getPredicate(), C, TrueIfSigned)) { 1556 1557 // If the sign bit of the XorCst is not set, there is no change to 1558 // the operation, just stop using the Xor. 1559 if (!XorC->isNegative()) 1560 return replaceOperand(Cmp, 0, X); 1561 1562 // Emit the opposite comparison. 1563 if (TrueIfSigned) 1564 return new ICmpInst(ICmpInst::ICMP_SGT, X, 1565 ConstantInt::getAllOnesValue(X->getType())); 1566 else 1567 return new ICmpInst(ICmpInst::ICMP_SLT, X, 1568 ConstantInt::getNullValue(X->getType())); 1569 } 1570 1571 if (Xor->hasOneUse()) { 1572 // (icmp u/s (xor X SignMask), C) -> (icmp s/u X, (xor C SignMask)) 1573 if (!Cmp.isEquality() && XorC->isSignMask()) { 1574 Pred = Cmp.getFlippedSignednessPredicate(); 1575 return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC)); 1576 } 1577 1578 // (icmp u/s (xor X ~SignMask), C) -> (icmp s/u X, (xor C ~SignMask)) 1579 if (!Cmp.isEquality() && XorC->isMaxSignedValue()) { 1580 Pred = Cmp.getFlippedSignednessPredicate(); 1581 Pred = Cmp.getSwappedPredicate(Pred); 1582 return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC)); 1583 } 1584 } 1585 1586 // Mask constant magic can eliminate an 'xor' with unsigned compares. 
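// An illustrative i8 instance of the first fold below:
// (xor X, 248) u> 7 --> X u< 248, since ~7 = 248 and 7+1 is a power of 2.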
1587 if (Pred == ICmpInst::ICMP_UGT) { 1588 // (xor X, ~C) >u C --> X <u ~C (when C+1 is a power of 2) 1589 if (*XorC == ~C && (C + 1).isPowerOf2()) 1590 return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); 1591 // (xor X, C) >u C --> X >u C (when C+1 is a power of 2) 1592 if (*XorC == C && (C + 1).isPowerOf2()) 1593 return new ICmpInst(ICmpInst::ICMP_UGT, X, Y); 1594 } 1595 if (Pred == ICmpInst::ICMP_ULT) { 1596 // (xor X, -C) <u C --> X >u ~C (when C is a power of 2) 1597 if (*XorC == -C && C.isPowerOf2()) 1598 return new ICmpInst(ICmpInst::ICMP_UGT, X, 1599 ConstantInt::get(X->getType(), ~C)); 1600 // (xor X, C) <u C --> X >u ~C (when -C is a power of 2) 1601 if (*XorC == C && (-C).isPowerOf2()) 1602 return new ICmpInst(ICmpInst::ICMP_UGT, X, 1603 ConstantInt::get(X->getType(), ~C)); 1604 } 1605 return nullptr; 1606 } 1607 1608 /// For power-of-2 C: 1609 /// ((X s>> ShiftC) ^ X) u< C --> (X + C) u< (C << 1) 1610 /// ((X s>> ShiftC) ^ X) u> (C - 1) --> (X + C) u> ((C << 1) - 1) 1611 Instruction *InstCombinerImpl::foldICmpXorShiftConst(ICmpInst &Cmp, 1612 BinaryOperator *Xor, 1613 const APInt &C) { 1614 CmpInst::Predicate Pred = Cmp.getPredicate(); 1615 APInt PowerOf2; 1616 if (Pred == ICmpInst::ICMP_ULT) 1617 PowerOf2 = C; 1618 else if (Pred == ICmpInst::ICMP_UGT && !C.isMaxValue()) 1619 PowerOf2 = C + 1; 1620 else 1621 return nullptr; 1622 if (!PowerOf2.isPowerOf2()) 1623 return nullptr; 1624 Value *X; 1625 const APInt *ShiftC; 1626 if (!match(Xor, m_OneUse(m_c_Xor(m_Value(X), 1627 m_AShr(m_Deferred(X), m_APInt(ShiftC)))))) 1628 return nullptr; 1629 uint64_t Shift = ShiftC->getLimitedValue(); 1630 Type *XType = X->getType(); 1631 if (Shift == 0 || PowerOf2.isMinSignedValue()) 1632 return nullptr; 1633 Value *Add = Builder.CreateAdd(X, ConstantInt::get(XType, PowerOf2)); 1634 APInt Bound = 1635 Pred == ICmpInst::ICMP_ULT ? PowerOf2 << 1 : ((PowerOf2 << 1) - 1); 1636 return new ICmpInst(Pred, Add, ConstantInt::get(XType, Bound)); 1637 } 1638 1639 /// Fold icmp (and (sh X, Y), C2), C1. 1640 Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp, 1641 BinaryOperator *And, 1642 const APInt &C1, 1643 const APInt &C2) { 1644 BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0)); 1645 if (!Shift || !Shift->isShift()) 1646 return nullptr; 1647 1648 // If this is: (X >> C3) & C2 != C1 (where any shift and any compare could 1649 // exist), turn it into (X & (C2 << C3)) != (C1 << C3). This happens a LOT in 1650 // code produced by the clang front-end, for bitfield access. 1651 // This seemingly simple opportunity to fold away a shift turns out to be 1652 // rather complicated. See PR17827 for details. 1653 unsigned ShiftOpcode = Shift->getOpcode(); 1654 bool IsShl = ShiftOpcode == Instruction::Shl; 1655 const APInt *C3; 1656 if (match(Shift->getOperand(1), m_APInt(C3))) { 1657 APInt NewAndCst, NewCmpCst; 1658 bool AnyCmpCstBitsShiftedOut; 1659 if (ShiftOpcode == Instruction::Shl) { 1660 // For a left shift, we can fold if the comparison is not signed. We can 1661 // also fold a signed comparison if the mask value and comparison value 1662 // are not negative. These constraints may not be obvious, but we can 1663 // prove that they are correct using an SMT solver. 
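// Illustrative i8 instance: ((X << 2) & 0x3C) == 0x28 --> (X & 0x0F) == 0x0A;
// both constants shift right with no compare bits lost.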
1664 if (Cmp.isSigned() && (C2.isNegative() || C1.isNegative())) 1665 return nullptr; 1666 1667 NewCmpCst = C1.lshr(*C3); 1668 NewAndCst = C2.lshr(*C3); 1669 AnyCmpCstBitsShiftedOut = NewCmpCst.shl(*C3) != C1; 1670 } else if (ShiftOpcode == Instruction::LShr) { 1671 // For a logical right shift, we can fold if the comparison is not signed. 1672 // We can also fold a signed comparison if the shifted mask value and the 1673 // shifted comparison value are not negative. These constraints may not be 1674 // obvious, but we can prove that they are correct using an SMT solver. 1675 NewCmpCst = C1.shl(*C3); 1676 NewAndCst = C2.shl(*C3); 1677 AnyCmpCstBitsShiftedOut = NewCmpCst.lshr(*C3) != C1; 1678 if (Cmp.isSigned() && (NewAndCst.isNegative() || NewCmpCst.isNegative())) 1679 return nullptr; 1680 } else { 1681 // For an arithmetic shift, check that both constants don't use (in a 1682 // signed sense) the top bits being shifted out. 1683 assert(ShiftOpcode == Instruction::AShr && "Unknown shift opcode"); 1684 NewCmpCst = C1.shl(*C3); 1685 NewAndCst = C2.shl(*C3); 1686 AnyCmpCstBitsShiftedOut = NewCmpCst.ashr(*C3) != C1; 1687 if (NewAndCst.ashr(*C3) != C2) 1688 return nullptr; 1689 } 1690 1691 if (AnyCmpCstBitsShiftedOut) { 1692 // If we shifted bits out, the fold is not going to work out. As a 1693 // special case, check to see if this means that the result is always 1694 // true or false now. 1695 if (Cmp.getPredicate() == ICmpInst::ICMP_EQ) 1696 return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType())); 1697 if (Cmp.getPredicate() == ICmpInst::ICMP_NE) 1698 return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType())); 1699 } else { 1700 Value *NewAnd = Builder.CreateAnd( 1701 Shift->getOperand(0), ConstantInt::get(And->getType(), NewAndCst)); 1702 return new ICmpInst(Cmp.getPredicate(), 1703 NewAnd, ConstantInt::get(And->getType(), NewCmpCst)); 1704 } 1705 } 1706 1707 // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is 1708 // preferable because it allows the C2 << Y expression to be hoisted out of a 1709 // loop if Y is invariant and X is not. 1710 if (Shift->hasOneUse() && C1.isZero() && Cmp.isEquality() && 1711 !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { 1712 // Compute C2 << Y. 1713 Value *NewShift = 1714 IsShl ? Builder.CreateLShr(And->getOperand(1), Shift->getOperand(1)) 1715 : Builder.CreateShl(And->getOperand(1), Shift->getOperand(1)); 1716 1717 // Compute X & (C2 << Y). 1718 Value *NewAnd = Builder.CreateAnd(Shift->getOperand(0), NewShift); 1719 return replaceOperand(Cmp, 0, NewAnd); 1720 } 1721 1722 return nullptr; 1723 } 1724 1725 /// Fold icmp (and X, C2), C1. 1726 Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, 1727 BinaryOperator *And, 1728 const APInt &C1) { 1729 bool isICMP_NE = Cmp.getPredicate() == ICmpInst::ICMP_NE; 1730 1731 // For vectors: icmp ne (and X, 1), 0 --> trunc X to N x i1 1732 // TODO: We canonicalize to the longer form for scalars because we have 1733 // better analysis/folds for icmp, and codegen may be better with icmp. 
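// e.g. <4 x i32>: icmp ne (and X, 1), 0 --> trunc X to <4 x i1>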
1734   if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isZero() &&
1735       match(And->getOperand(1), m_One()))
1736     return new TruncInst(And->getOperand(0), Cmp.getType());
1737
1738   const APInt *C2;
1739   Value *X;
1740   if (!match(And, m_And(m_Value(X), m_APInt(C2))))
1741     return nullptr;
1742
1743   // Don't perform the following transforms if the AND has multiple uses.
1744   if (!And->hasOneUse())
1745     return nullptr;
1746
1747   if (Cmp.isEquality() && C1.isZero()) {
1748     // Restrict this fold to single-use 'and' (PR10267).
1749     // Replace ((X & (1 << (size(X)-1))) != 0) with (X s< 0).
1750     if (C2->isSignMask()) {
1751       Constant *Zero = Constant::getNullValue(X->getType());
1752       auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE;
1753       return new ICmpInst(NewPred, X, Zero);
1754     }
1755
1756     APInt NewC2 = *C2;
1757     KnownBits Know = computeKnownBits(And->getOperand(0), 0, And);
1758     // Set high zeros of C2 to allow matching negated power-of-2.
1759     NewC2 = *C2 | APInt::getHighBitsSet(C2->getBitWidth(),
1760                                         Know.countMinLeadingZeros());
1761
1762     // Restrict this fold to single-use 'and' (PR10267).
1763     // ((%x & C) == 0) --> %x u< (-C) iff (-C) is power of two.
1764     if (NewC2.isNegatedPowerOf2()) {
1765       Constant *NegBOC = ConstantInt::get(And->getType(), -NewC2);
1766       auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT;
1767       return new ICmpInst(NewPred, X, NegBOC);
1768     }
1769   }
1770
1771   // If the LHS is an 'and' of a truncate and we can widen the and/compare to
1772   // the input width without changing the value produced, eliminate the cast:
1773   //
1774   // icmp (and (trunc W), C2), C1 -> icmp (and W, C2'), C1'
1775   //
1776   // We can do this transformation if the constants do not have their sign bits
1777   // set or if it is an equality comparison. Extending a relational comparison
1778   // when we're checking the sign bit would not work.
1779   Value *W;
1780   if (match(And->getOperand(0), m_OneUse(m_Trunc(m_Value(W)))) &&
1781       (Cmp.isEquality() || (!C1.isNegative() && !C2->isNegative()))) {
1782     // TODO: Is this a good transform for vectors? Wider types may reduce
1783     // throughput. Should this transform be limited (even for scalars) by using
1784     // shouldChangeType()?
1785     if (!Cmp.getType()->isVectorTy()) {
1786       Type *WideType = W->getType();
1787       unsigned WideScalarBits = WideType->getScalarSizeInBits();
1788       Constant *ZextC1 = ConstantInt::get(WideType, C1.zext(WideScalarBits));
1789       Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits));
1790       Value *NewAnd = Builder.CreateAnd(W, ZextC2, And->getName());
1791       return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1);
1792     }
1793   }
1794
1795   if (Instruction *I = foldICmpAndShift(Cmp, And, C1, *C2))
1796     return I;
1797
1798   // (icmp pred (and (or (lshr A, B), A), 1), 0) -->
1799   // (icmp pred (and A, (or (shl 1, B), 1)), 0)
1800   //
1801   // iff pred isn't signed
1802   if (!Cmp.isSigned() && C1.isZero() && And->getOperand(0)->hasOneUse() &&
1803       match(And->getOperand(1), m_One())) {
1804     Constant *One = cast<Constant>(And->getOperand(1));
1805     Value *Or = And->getOperand(0);
1806     Value *A, *B, *LShr;
1807     if (match(Or, m_Or(m_Value(LShr), m_Value(A))) &&
1808         match(LShr, m_LShr(m_Specific(A), m_Value(B)))) {
1809       unsigned UsesRemoved = 0;
1810       if (And->hasOneUse())
1811         ++UsesRemoved;
1812       if (Or->hasOneUse())
1813         ++UsesRemoved;
1814       if (LShr->hasOneUse())
1815         ++UsesRemoved;
1816
1817       // Compute A & ((1 << B) | 1)
1818       unsigned RequireUsesRemoved = match(B, m_ImmConstant()) ?
1 : 3; 1819 if (UsesRemoved >= RequireUsesRemoved) { 1820 Value *NewOr = 1821 Builder.CreateOr(Builder.CreateShl(One, B, LShr->getName(), 1822 /*HasNUW=*/true), 1823 One, Or->getName()); 1824 Value *NewAnd = Builder.CreateAnd(A, NewOr, And->getName()); 1825 return replaceOperand(Cmp, 0, NewAnd); 1826 } 1827 } 1828 } 1829 1830 return nullptr; 1831 } 1832 1833 /// Fold icmp (and X, Y), C. 1834 Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, 1835 BinaryOperator *And, 1836 const APInt &C) { 1837 if (Instruction *I = foldICmpAndConstConst(Cmp, And, C)) 1838 return I; 1839 1840 const ICmpInst::Predicate Pred = Cmp.getPredicate(); 1841 bool TrueIfNeg; 1842 if (isSignBitCheck(Pred, C, TrueIfNeg)) { 1843 // ((X - 1) & ~X) < 0 --> X == 0 1844 // ((X - 1) & ~X) >= 0 --> X != 0 1845 Value *X; 1846 if (match(And->getOperand(0), m_Add(m_Value(X), m_AllOnes())) && 1847 match(And->getOperand(1), m_Not(m_Specific(X)))) { 1848 auto NewPred = TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE; 1849 return new ICmpInst(NewPred, X, ConstantInt::getNullValue(X->getType())); 1850 } 1851 // (X & X) < 0 --> X == MinSignedC 1852 // (X & X) > -1 --> X != MinSignedC 1853 if (match(And, m_c_And(m_Neg(m_Value(X)), m_Deferred(X)))) { 1854 Constant *MinSignedC = ConstantInt::get( 1855 X->getType(), 1856 APInt::getSignedMinValue(X->getType()->getScalarSizeInBits())); 1857 auto NewPred = TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE; 1858 return new ICmpInst(NewPred, X, MinSignedC); 1859 } 1860 } 1861 1862 // TODO: These all require that Y is constant too, so refactor with the above. 1863 1864 // Try to optimize things like "A[i] & 42 == 0" to index computations. 1865 Value *X = And->getOperand(0); 1866 Value *Y = And->getOperand(1); 1867 if (auto *C2 = dyn_cast<ConstantInt>(Y)) 1868 if (auto *LI = dyn_cast<LoadInst>(X)) 1869 if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0))) 1870 if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) 1871 if (Instruction *Res = 1872 foldCmpLoadFromIndexedGlobal(LI, GEP, GV, Cmp, C2)) 1873 return Res; 1874 1875 if (!Cmp.isEquality()) 1876 return nullptr; 1877 1878 // X & -C == -C -> X > u ~C 1879 // X & -C != -C -> X <= u ~C 1880 // iff C is a power of 2 1881 if (Cmp.getOperand(1) == Y && C.isNegatedPowerOf2()) { 1882 auto NewPred = 1883 Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT : CmpInst::ICMP_ULE; 1884 return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); 1885 } 1886 1887 // If we are testing the intersection of 2 select-of-nonzero-constants with no 1888 // common bits set, it's the same as checking if exactly one select condition 1889 // is set: 1890 // ((A ? TC : FC) & (B ? TC : FC)) == 0 --> xor A, B 1891 // ((A ? TC : FC) & (B ? TC : FC)) != 0 --> not(xor A, B) 1892 // TODO: Generalize for non-constant values. 1893 // TODO: Handle signed/unsigned predicates. 1894 // TODO: Handle other bitwise logic connectors. 1895 // TODO: Extend to handle a non-zero compare constant. 
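// Illustrative instance: ((A ? 4 : 2) & (B ? 4 : 2)) == 0 --> xor A, B. The
// two selects produce equal (nonzero) values iff A == B, so the 'and' is zero
// exactly when the conditions differ.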
1896 if (C.isZero() && (Pred == CmpInst::ICMP_EQ || And->hasOneUse())) { 1897 assert(Cmp.isEquality() && "Not expecting non-equality predicates"); 1898 Value *A, *B; 1899 const APInt *TC, *FC; 1900 if (match(X, m_Select(m_Value(A), m_APInt(TC), m_APInt(FC))) && 1901 match(Y, 1902 m_Select(m_Value(B), m_SpecificInt(*TC), m_SpecificInt(*FC))) && 1903 !TC->isZero() && !FC->isZero() && !TC->intersects(*FC)) { 1904 Value *R = Builder.CreateXor(A, B); 1905 if (Pred == CmpInst::ICMP_NE) 1906 R = Builder.CreateNot(R); 1907 return replaceInstUsesWith(Cmp, R); 1908 } 1909 } 1910 1911 // ((zext i1 X) & Y) == 0 --> !((trunc Y) & X) 1912 // ((zext i1 X) & Y) != 0 --> ((trunc Y) & X) 1913 // ((zext i1 X) & Y) == 1 --> ((trunc Y) & X) 1914 // ((zext i1 X) & Y) != 1 --> !((trunc Y) & X) 1915 if (match(And, m_OneUse(m_c_And(m_OneUse(m_ZExt(m_Value(X))), m_Value(Y)))) && 1916 X->getType()->isIntOrIntVectorTy(1) && (C.isZero() || C.isOne())) { 1917 Value *TruncY = Builder.CreateTrunc(Y, X->getType()); 1918 if (C.isZero() ^ (Pred == CmpInst::ICMP_NE)) { 1919 Value *And = Builder.CreateAnd(TruncY, X); 1920 return BinaryOperator::CreateNot(And); 1921 } 1922 return BinaryOperator::CreateAnd(TruncY, X); 1923 } 1924 1925 return nullptr; 1926 } 1927 1928 /// Fold icmp eq/ne (or (xor/sub (X1, X2), xor/sub (X3, X4))), 0. 1929 static Value *foldICmpOrXorSubChain(ICmpInst &Cmp, BinaryOperator *Or, 1930 InstCombiner::BuilderTy &Builder) { 1931 // Are we using xors or subs to bitwise check for a pair or pairs of 1932 // (in)equalities? Convert to a shorter form that has more potential to be 1933 // folded even further. 1934 // ((X1 ^/- X2) || (X3 ^/- X4)) == 0 --> (X1 == X2) && (X3 == X4) 1935 // ((X1 ^/- X2) || (X3 ^/- X4)) != 0 --> (X1 != X2) || (X3 != X4) 1936 // ((X1 ^/- X2) || (X3 ^/- X4) || (X5 ^/- X6)) == 0 --> 1937 // (X1 == X2) && (X3 == X4) && (X5 == X6) 1938 // ((X1 ^/- X2) || (X3 ^/- X4) || (X5 ^/- X6)) != 0 --> 1939 // (X1 != X2) || (X3 != X4) || (X5 != X6) 1940 SmallVector<std::pair<Value *, Value *>, 2> CmpValues; 1941 SmallVector<Value *, 16> WorkList(1, Or); 1942 1943 while (!WorkList.empty()) { 1944 auto MatchOrOperatorArgument = [&](Value *OrOperatorArgument) { 1945 Value *Lhs, *Rhs; 1946 1947 if (match(OrOperatorArgument, 1948 m_OneUse(m_Xor(m_Value(Lhs), m_Value(Rhs))))) { 1949 CmpValues.emplace_back(Lhs, Rhs); 1950 return; 1951 } 1952 1953 if (match(OrOperatorArgument, 1954 m_OneUse(m_Sub(m_Value(Lhs), m_Value(Rhs))))) { 1955 CmpValues.emplace_back(Lhs, Rhs); 1956 return; 1957 } 1958 1959 WorkList.push_back(OrOperatorArgument); 1960 }; 1961 1962 Value *CurrentValue = WorkList.pop_back_val(); 1963 Value *OrOperatorLhs, *OrOperatorRhs; 1964 1965 if (!match(CurrentValue, 1966 m_Or(m_Value(OrOperatorLhs), m_Value(OrOperatorRhs)))) { 1967 return nullptr; 1968 } 1969 1970 MatchOrOperatorArgument(OrOperatorRhs); 1971 MatchOrOperatorArgument(OrOperatorLhs); 1972 } 1973 1974 ICmpInst::Predicate Pred = Cmp.getPredicate(); 1975 auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; 1976 Value *LhsCmp = Builder.CreateICmp(Pred, CmpValues.rbegin()->first, 1977 CmpValues.rbegin()->second); 1978 1979 for (auto It = CmpValues.rbegin() + 1; It != CmpValues.rend(); ++It) { 1980 Value *RhsCmp = Builder.CreateICmp(Pred, It->first, It->second); 1981 LhsCmp = Builder.CreateBinOp(BOpc, LhsCmp, RhsCmp); 1982 } 1983 1984 return LhsCmp; 1985 } 1986 1987 /// Fold icmp (or X, Y), C. 
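/// e.g. the low-bit-mask fold below turns (X | 7) == 7 into X u<= 7.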
1988 Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, 1989 BinaryOperator *Or, 1990 const APInt &C) { 1991 ICmpInst::Predicate Pred = Cmp.getPredicate(); 1992 if (C.isOne()) { 1993 // icmp slt signum(V) 1 --> icmp slt V, 1 1994 Value *V = nullptr; 1995 if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V)))) 1996 return new ICmpInst(ICmpInst::ICMP_SLT, V, 1997 ConstantInt::get(V->getType(), 1)); 1998 } 1999 2000 Value *OrOp0 = Or->getOperand(0), *OrOp1 = Or->getOperand(1); 2001 const APInt *MaskC; 2002 if (match(OrOp1, m_APInt(MaskC)) && Cmp.isEquality()) { 2003 if (*MaskC == C && (C + 1).isPowerOf2()) { 2004 // X | C == C --> X <=u C 2005 // X | C != C --> X >u C 2006 // iff C+1 is a power of 2 (C is a bitmask of the low bits) 2007 Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT; 2008 return new ICmpInst(Pred, OrOp0, OrOp1); 2009 } 2010 2011 // More general: canonicalize 'equality with set bits mask' to 2012 // 'equality with clear bits mask'. 2013 // (X | MaskC) == C --> (X & ~MaskC) == C ^ MaskC 2014 // (X | MaskC) != C --> (X & ~MaskC) != C ^ MaskC 2015 if (Or->hasOneUse()) { 2016 Value *And = Builder.CreateAnd(OrOp0, ~(*MaskC)); 2017 Constant *NewC = ConstantInt::get(Or->getType(), C ^ (*MaskC)); 2018 return new ICmpInst(Pred, And, NewC); 2019 } 2020 } 2021 2022 // (X | (X-1)) s< 0 --> X s< 1 2023 // (X | (X-1)) s> -1 --> X s> 0 2024 Value *X; 2025 bool TrueIfSigned; 2026 if (isSignBitCheck(Pred, C, TrueIfSigned) && 2027 match(Or, m_c_Or(m_Add(m_Value(X), m_AllOnes()), m_Deferred(X)))) { 2028 auto NewPred = TrueIfSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGT; 2029 Constant *NewC = ConstantInt::get(X->getType(), TrueIfSigned ? 1 : 0); 2030 return new ICmpInst(NewPred, X, NewC); 2031 } 2032 2033 const APInt *OrC; 2034 // icmp(X | OrC, C) --> icmp(X, 0) 2035 if (C.isNonNegative() && match(Or, m_Or(m_Value(X), m_APInt(OrC)))) { 2036 switch (Pred) { 2037 // X | OrC s< C --> X s< 0 iff OrC s>= C s>= 0 2038 case ICmpInst::ICMP_SLT: 2039 // X | OrC s>= C --> X s>= 0 iff OrC s>= C s>= 0 2040 case ICmpInst::ICMP_SGE: 2041 if (OrC->sge(C)) 2042 return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); 2043 break; 2044 // X | OrC s<= C --> X s< 0 iff OrC s> C s>= 0 2045 case ICmpInst::ICMP_SLE: 2046 // X | OrC s> C --> X s>= 0 iff OrC s> C s>= 0 2047 case ICmpInst::ICMP_SGT: 2048 if (OrC->sgt(C)) 2049 return new ICmpInst(ICmpInst::getFlippedStrictnessPredicate(Pred), X, 2050 ConstantInt::getNullValue(X->getType())); 2051 break; 2052 default: 2053 break; 2054 } 2055 } 2056 2057 if (!Cmp.isEquality() || !C.isZero() || !Or->hasOneUse()) 2058 return nullptr; 2059 2060 Value *P, *Q; 2061 if (match(Or, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { 2062 // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 2063 // -> and (icmp eq P, null), (icmp eq Q, null). 2064 Value *CmpP = 2065 Builder.CreateICmp(Pred, P, ConstantInt::getNullValue(P->getType())); 2066 Value *CmpQ = 2067 Builder.CreateICmp(Pred, Q, ConstantInt::getNullValue(Q->getType())); 2068 auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; 2069 return BinaryOperator::Create(BOpc, CmpP, CmpQ); 2070 } 2071 2072 if (Value *V = foldICmpOrXorSubChain(Cmp, Or, Builder)) 2073 return replaceInstUsesWith(Cmp, V); 2074 2075 return nullptr; 2076 } 2077 2078 /// Fold icmp (mul X, Y), C. 
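/// e.g. with nsw and an exact multiple: (mul nsw X, 5) == 35 --> X == 7.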
2079 Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
2080                                                     BinaryOperator *Mul,
2081                                                     const APInt &C) {
2082   ICmpInst::Predicate Pred = Cmp.getPredicate();
2083   Type *MulTy = Mul->getType();
2084   Value *X = Mul->getOperand(0);
2085
2086   // If there's no overflow:
2087   // X * X == 0 --> X == 0
2088   // X * X != 0 --> X != 0
2089   if (Cmp.isEquality() && C.isZero() && X == Mul->getOperand(1) &&
2090       (Mul->hasNoUnsignedWrap() || Mul->hasNoSignedWrap()))
2091     return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy));
2092
2093   const APInt *MulC;
2094   if (!match(Mul->getOperand(1), m_APInt(MulC)))
2095     return nullptr;
2096
2097   // If this is a test of the sign bit and the multiply is sign-preserving with
2098   // a constant operand, use the multiply LHS operand instead:
2099   // (X * +MulC) < 0 --> X < 0
2100   // (X * -MulC) < 0 --> X > 0
2101   if (isSignTest(Pred, C) && Mul->hasNoSignedWrap()) {
2102     if (MulC->isNegative())
2103       Pred = ICmpInst::getSwappedPredicate(Pred);
2104     return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy));
2105   }
2106
2107   if (MulC->isZero())
2108     return nullptr;
2109
2110   // If the multiply does not wrap or the constant is odd, try to divide the
2111   // compare constant by the multiplication factor.
2112   if (Cmp.isEquality()) {
2113     // (mul nsw X, MulC) eq/ne C --> X eq/ne C /s MulC
2114     if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) {
2115       Constant *NewC = ConstantInt::get(MulTy, C.sdiv(*MulC));
2116       return new ICmpInst(Pred, X, NewC);
2117     }
2118
2119     // The C % MulC == 0 check is weaker than necessary when MulC is odd,
2120     // because it is then correct to transform whenever MulC * N == C, even
2121     // with overflow. E.g., with i8, (icmp eq (mul X, 5), 101) -> (icmp eq
2122     // X, 225), but since 101 % 5 != 0, we miss that case.
2123     if (C.urem(*MulC).isZero()) {
2124       // (mul nuw X, MulC) eq/ne C --> X eq/ne C /u MulC
2125       // (mul X, OddC) eq/ne N * C --> X eq/ne N
2126       if ((*MulC & 1).isOne() || Mul->hasNoUnsignedWrap()) {
2127         Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC));
2128         return new ICmpInst(Pred, X, NewC);
2129       }
2130     }
2131   }
2132
2133   // With a matching no-overflow guarantee, fold the constants:
2134   // (X * MulC) < C --> X < (C / MulC)
2135   // (X * MulC) > C --> X > (C / MulC)
2136   // TODO: Assert that Pred is not equal to SGE, SLE, UGE, ULE?
2137   Constant *NewC = nullptr;
2138   if (Mul->hasNoSignedWrap() && ICmpInst::isSigned(Pred)) {
2139     // MININT / -1 --> overflow.
2140     if (C.isMinSignedValue() && MulC->isAllOnes())
2141       return nullptr;
2142     if (MulC->isNegative())
2143       Pred = ICmpInst::getSwappedPredicate(Pred);
2144
2145     if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
2146       NewC = ConstantInt::get(
2147           MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP));
2148     } else {
2149       assert((Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT) &&
2150              "Unexpected predicate");
2151       NewC = ConstantInt::get(
2152           MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN));
2153     }
2154   } else if (Mul->hasNoUnsignedWrap() && ICmpInst::isUnsigned(Pred)) {
2155     if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) {
2156       NewC = ConstantInt::get(
2157           MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP));
2158     } else {
2159       assert((Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) &&
2160              "Unexpected predicate");
2161       NewC = ConstantInt::get(
2162           MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN));
2163     }
2164   }
2165
2166   return NewC ? new ICmpInst(Pred, X, NewC) : nullptr;
2167 }
2168
2169 /// Fold icmp (shl 1, Y), C.
2170 static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
2171                                    const APInt &C) {
2172   Value *Y;
2173   if (!match(Shl, m_Shl(m_One(), m_Value(Y))))
2174     return nullptr;
2175
2176   Type *ShiftType = Shl->getType();
2177   unsigned TypeBits = C.getBitWidth();
2178   bool CIsPowerOf2 = C.isPowerOf2();
2179   ICmpInst::Predicate Pred = Cmp.getPredicate();
2180   if (Cmp.isUnsigned()) {
2181     // (1 << Y) pred C -> Y pred Log2(C)
2182     if (!CIsPowerOf2) {
2183       // (1 << Y) <  30 -> Y <= 4
2184       // (1 << Y) <= 30 -> Y <= 4
2185       // (1 << Y) >= 30 -> Y > 4
2186       // (1 << Y) >  30 -> Y > 4
2187       if (Pred == ICmpInst::ICMP_ULT)
2188         Pred = ICmpInst::ICMP_ULE;
2189       else if (Pred == ICmpInst::ICMP_UGE)
2190         Pred = ICmpInst::ICMP_UGT;
2191     }
2192
2193     unsigned CLog2 = C.logBase2();
2194     return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2));
2195   } else if (Cmp.isSigned()) {
2196     Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1);
2197     // (1 << Y) > 0 -> Y != 31
2198     // (1 << Y) > C -> Y != 31 if C is negative.
2199     if (Pred == ICmpInst::ICMP_SGT && C.sle(0))
2200       return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne);
2201
2202     // (1 << Y) < 0 -> Y == 31
2203     // (1 << Y) < 1 -> Y == 31
2204     // (1 << Y) < C -> Y == 31 if C is negative and not signed min.
2205     // Exclude signed min by subtracting 1 and lowering the upper bound to 0.
2206     if (Pred == ICmpInst::ICMP_SLT && (C-1).sle(0))
2207       return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne);
2208   }
2209
2210   return nullptr;
2211 }
2212
2213 /// Fold icmp (shl X, Y), C.
2214 Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
2215                                                    BinaryOperator *Shl,
2216                                                    const APInt &C) {
2217   const APInt *ShiftVal;
2218   if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal)))
2219     return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal);
2220
2221   ICmpInst::Predicate Pred = Cmp.getPredicate();
2222   // (icmp pred (shl nuw&nsw X, Y), Csle0)
2223   //    -> (icmp pred X, Csle0)
2224   //
2225   // The idea is that the nuw/nsw flags essentially freeze the sign bit of the
2226   // shift result, so the compare sees the same sign information in X itself.
2227   if (C.sle(0) && Shl->hasNoUnsignedWrap() && Shl->hasNoSignedWrap())
2228     return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1));
2229
2230   // (icmp eq/ne (shl nuw|nsw X, Y), 0)
2231   //    -> (icmp eq/ne X, 0)
2232   if (ICmpInst::isEquality(Pred) && C.isZero() &&
2233       (Shl->hasNoUnsignedWrap() || Shl->hasNoSignedWrap()))
2234     return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1));
2235
2236   // (icmp slt (shl nsw X, Y), 0/1)
2237   //    -> (icmp slt X, 0/1)
2238   // (icmp sgt (shl nsw X, Y), 0/-1)
2239   //    -> (icmp sgt X, 0/-1)
2240   //
2241   // NB: sge/sle with a constant will canonicalize to sgt/slt.
2242   if (Shl->hasNoSignedWrap() &&
2243       (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT))
2244     if (C.isZero() || (Pred == ICmpInst::ICMP_SGT ? C.isAllOnes() : C.isOne()))
2245       return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1));
2246
2247   const APInt *ShiftAmt;
2248   if (!match(Shl->getOperand(1), m_APInt(ShiftAmt)))
2249     return foldICmpShlOne(Cmp, Shl, C);
2250
2251   // Check that the shift amount is in range. If not, don't perform undefined
2252   // shifts. When the shift is visited, it will be simplified.
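// (A later NSW fold, illustrated on i8: (shl nsw X, 1) s> 6 --> X s> 3,
// because the compare constant shifts right arithmetically without loss.)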
2253 unsigned TypeBits = C.getBitWidth(); 2254 if (ShiftAmt->uge(TypeBits)) 2255 return nullptr; 2256 2257 Value *X = Shl->getOperand(0); 2258 Type *ShType = Shl->getType(); 2259 2260 // NSW guarantees that we are only shifting out sign bits from the high bits, 2261 // so we can ASHR the compare constant without needing a mask and eliminate 2262 // the shift. 2263 if (Shl->hasNoSignedWrap()) { 2264 if (Pred == ICmpInst::ICMP_SGT) { 2265 // icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt) 2266 APInt ShiftedC = C.ashr(*ShiftAmt); 2267 return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); 2268 } 2269 if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && 2270 C.ashr(*ShiftAmt).shl(*ShiftAmt) == C) { 2271 APInt ShiftedC = C.ashr(*ShiftAmt); 2272 return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); 2273 } 2274 if (Pred == ICmpInst::ICMP_SLT) { 2275 // SLE is the same as above, but SLE is canonicalized to SLT, so convert: 2276 // (X << S) <=s C is equiv to X <=s (C >> S) for all C 2277 // (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX 2278 // (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN 2279 assert(!C.isMinSignedValue() && "Unexpected icmp slt"); 2280 APInt ShiftedC = (C - 1).ashr(*ShiftAmt) + 1; 2281 return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); 2282 } 2283 } 2284 2285 // NUW guarantees that we are only shifting out zero bits from the high bits, 2286 // so we can LSHR the compare constant without needing a mask and eliminate 2287 // the shift. 2288 if (Shl->hasNoUnsignedWrap()) { 2289 if (Pred == ICmpInst::ICMP_UGT) { 2290 // icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt) 2291 APInt ShiftedC = C.lshr(*ShiftAmt); 2292 return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); 2293 } 2294 if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && 2295 C.lshr(*ShiftAmt).shl(*ShiftAmt) == C) { 2296 APInt ShiftedC = C.lshr(*ShiftAmt); 2297 return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); 2298 } 2299 if (Pred == ICmpInst::ICMP_ULT) { 2300 // ULE is the same as above, but ULE is canonicalized to ULT, so convert: 2301 // (X << S) <=u C is equiv to X <=u (C >> S) for all C 2302 // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u 2303 // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 2304 assert(C.ugt(0) && "ult 0 should have been eliminated"); 2305 APInt ShiftedC = (C - 1).lshr(*ShiftAmt) + 1; 2306 return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); 2307 } 2308 } 2309 2310 if (Cmp.isEquality() && Shl->hasOneUse()) { 2311 // Strength-reduce the shift into an 'and'. 2312 Constant *Mask = ConstantInt::get( 2313 ShType, 2314 APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); 2315 Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask"); 2316 Constant *LShrC = ConstantInt::get(ShType, C.lshr(*ShiftAmt)); 2317 return new ICmpInst(Pred, And, LShrC); 2318 } 2319 2320 // Otherwise, if this is a comparison of the sign bit, simplify to and/test. 2321 bool TrueIfSigned = false; 2322 if (Shl->hasOneUse() && isSignBitCheck(Pred, C, TrueIfSigned)) { 2323 // (X << 31) <s 0 --> (X & 1) != 0 2324 Constant *Mask = ConstantInt::get( 2325 ShType, 2326 APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1)); 2327 Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask"); 2328 return new ICmpInst(TrueIfSigned ? 
ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ,
2329                         And, Constant::getNullValue(ShType));
2330   }
2331
2332   // Simplify 'shl' inequality test into 'and' equality test.
2333   if (Cmp.isUnsigned() && Shl->hasOneUse()) {
2334     // (X l<< C2) u<=/u> C1 iff C1+1 is power of two -> X & (~C1 l>> C2) ==/!= 0
2335     if ((C + 1).isPowerOf2() &&
2336         (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT)) {
2337       Value *And = Builder.CreateAnd(X, (~C).lshr(ShiftAmt->getZExtValue()));
2338       return new ICmpInst(Pred == ICmpInst::ICMP_ULE ? ICmpInst::ICMP_EQ
2339                                                      : ICmpInst::ICMP_NE,
2340                           And, Constant::getNullValue(ShType));
2341     }
2342     // (X l<< C2) u</u>= C1 iff C1 is power of two -> X & (-C1 l>> C2) ==/!= 0
2343     if (C.isPowerOf2() &&
2344         (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) {
2345       Value *And =
2346           Builder.CreateAnd(X, (~(C - 1)).lshr(ShiftAmt->getZExtValue()));
2347       return new ICmpInst(Pred == ICmpInst::ICMP_ULT ? ICmpInst::ICMP_EQ
2348                                                      : ICmpInst::ICMP_NE,
2349                           And, Constant::getNullValue(ShType));
2350     }
2351   }
2352
2353   // Transform (icmp pred iM (shl iM %v, N), C)
2354   // -> (icmp pred i(M-N) (trunc %v to i(M-N)), (trunc (C>>N) to i(M-N)))
2355   // Transform the shl to a trunc if (trunc (C>>N)) has no loss and M-N is
2356   // a legal integer width. This enables us to get rid of the shift in favor
2357   // of a trunc that may be free on the target. It has the additional benefit
2358   // of comparing to a smaller constant that may be more target-friendly.
2359   unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1);
2360   if (Shl->hasOneUse() && Amt != 0 && C.countr_zero() >= Amt &&
2361       DL.isLegalInteger(TypeBits - Amt)) {
2362     Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
2363     if (auto *ShVTy = dyn_cast<VectorType>(ShType))
2364       TruncTy = VectorType::get(TruncTy, ShVTy->getElementCount());
2365     Constant *NewC =
2366         ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
2367     return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC);
2368   }
2369
2370   return nullptr;
2371 }
2372
2373 /// Fold icmp ({al}shr X, Y), C.
2374 Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
2375                                                    BinaryOperator *Shr,
2376                                                    const APInt &C) {
2377   // An exact shr only shifts out zero bits, so:
2378   // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0
2379   Value *X = Shr->getOperand(0);
2380   CmpInst::Predicate Pred = Cmp.getPredicate();
2381   if (Cmp.isEquality() && Shr->isExact() && C.isZero())
2382     return new ICmpInst(Pred, X, Cmp.getOperand(1));
2383
2384   bool IsAShr = Shr->getOpcode() == Instruction::AShr;
2385   const APInt *ShiftValC;
2386   if (match(X, m_APInt(ShiftValC))) {
2387     if (Cmp.isEquality())
2388       return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftValC);
2389
2390     // (ShiftValC >> Y) >s -1 --> Y != 0 with ShiftValC < 0
2391     // (ShiftValC >> Y) <s  0 --> Y == 0 with ShiftValC < 0
2392     bool TrueIfSigned;
2393     if (!IsAShr && ShiftValC->isNegative() &&
2394         isSignBitCheck(Pred, C, TrueIfSigned))
2395       return new ICmpInst(TrueIfSigned ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE,
2396                           Shr->getOperand(1),
2397                           ConstantInt::getNullValue(X->getType()));
2398
2399     // If the shifted constant is a power-of-2, test the shift amount directly:
2400     // (ShiftValC >> Y) >u C --> Y <u (LZ(C) - LZ(ShiftValC))
2401     // (ShiftValC >> Y) <u C --> Y >=u (LZ(C-1) - LZ(ShiftValC))
2402     if (!IsAShr && ShiftValC->isPowerOf2() &&
2403         (Pred == CmpInst::ICMP_UGT || Pred == CmpInst::ICMP_ULT)) {
2404       bool IsUGT = Pred == CmpInst::ICMP_UGT;
2405       assert(ShiftValC->uge(C) && "Expected simplify of compare");
2406       assert((IsUGT || !C.isZero()) && "Expected X u< 0 to simplify");
2407
2408       unsigned CmpLZ = IsUGT ? C.countl_zero() : (C - 1).countl_zero();
2409       unsigned ShiftLZ = ShiftValC->countl_zero();
2410       Constant *NewC = ConstantInt::get(Shr->getType(), CmpLZ - ShiftLZ);
2411       auto NewPred = IsUGT ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE;
2412       return new ICmpInst(NewPred, Shr->getOperand(1), NewC);
2413     }
2414   }
2415
2416   const APInt *ShiftAmtC;
2417   if (!match(Shr->getOperand(1), m_APInt(ShiftAmtC)))
2418     return nullptr;
2419
2420   // Check that the shift amount is in range. If not, don't perform undefined
2421   // shifts. When the shift is visited it will be simplified.
2422   unsigned TypeBits = C.getBitWidth();
2423   unsigned ShAmtVal = ShiftAmtC->getLimitedValue(TypeBits);
2424   if (ShAmtVal >= TypeBits || ShAmtVal == 0)
2425     return nullptr;
2426
2427   bool IsExact = Shr->isExact();
2428   Type *ShrTy = Shr->getType();
2429   // TODO: If we could guarantee that InstSimplify would handle all of the
2430   // constant-value-based preconditions in the folds below, then we could
2431   // assert those conditions rather than checking them. This is difficult
2432   // because of undef/poison (PR34838).
2433   if (IsAShr && Shr->hasOneUse()) {
2434     if (IsExact || Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_ULT) {
2435       // When ShAmtC can be shifted losslessly:
2436       // icmp PRED (ashr exact X, ShAmtC), C --> icmp PRED X, (C << ShAmtC)
2437       // icmp slt/ult (ashr X, ShAmtC), C --> icmp slt/ult X, (C << ShAmtC)
2438       APInt ShiftedC = C.shl(ShAmtVal);
2439       if (ShiftedC.ashr(ShAmtVal) == C)
2440         return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
2441     }
2442     if (Pred == CmpInst::ICMP_SGT) {
2443       // icmp sgt (ashr X, ShAmtC), C --> icmp sgt X, ((C + 1) << ShAmtC) - 1
2444       APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1;
2445       if (!C.isMaxSignedValue() && !(C + 1).shl(ShAmtVal).isMinSignedValue() &&
2446           (ShiftedC + 1).ashr(ShAmtVal) == (C + 1))
2447         return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
2448     }
2449     if (Pred == CmpInst::ICMP_UGT) {
2450       // icmp ugt (ashr X, ShAmtC), C --> icmp ugt X, ((C + 1) << ShAmtC) - 1
2451       // '(C + 1) << ShAmtC' can overflow as a signed number, so the 2nd
2452       // clause accounts for that pattern.
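// Illustrative i8 instance: (ashr X, 2) u> 1 --> X u> 7, since
// ((1 + 1) << 2) - 1 = 7 and (7 + 1) >>a 2 == 2 confirms no bits are lost.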
2453 APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; 2454 if ((ShiftedC + 1).ashr(ShAmtVal) == (C + 1) || 2455 (C + 1).shl(ShAmtVal).isMinSignedValue()) 2456 return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); 2457 } 2458 2459 // If the compare constant has significant bits above the lowest sign-bit, 2460 // then convert an unsigned cmp to a test of the sign-bit: 2461 // (ashr X, ShiftC) u> C --> X s< 0 2462 // (ashr X, ShiftC) u< C --> X s> -1 2463 if (C.getBitWidth() > 2 && C.getNumSignBits() <= ShAmtVal) { 2464 if (Pred == CmpInst::ICMP_UGT) { 2465 return new ICmpInst(CmpInst::ICMP_SLT, X, 2466 ConstantInt::getNullValue(ShrTy)); 2467 } 2468 if (Pred == CmpInst::ICMP_ULT) { 2469 return new ICmpInst(CmpInst::ICMP_SGT, X, 2470 ConstantInt::getAllOnesValue(ShrTy)); 2471 } 2472 } 2473 } else if (!IsAShr) { 2474 if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) { 2475 // icmp ult (lshr X, ShAmtC), C --> icmp ult X, (C << ShAmtC) 2476 // icmp ugt (lshr exact X, ShAmtC), C --> icmp ugt X, (C << ShAmtC) 2477 APInt ShiftedC = C.shl(ShAmtVal); 2478 if (ShiftedC.lshr(ShAmtVal) == C) 2479 return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); 2480 } 2481 if (Pred == CmpInst::ICMP_UGT) { 2482 // icmp ugt (lshr X, ShAmtC), C --> icmp ugt X, ((C + 1) << ShAmtC) - 1 2483 APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; 2484 if ((ShiftedC + 1).lshr(ShAmtVal) == (C + 1)) 2485 return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); 2486 } 2487 } 2488 2489 if (!Cmp.isEquality()) 2490 return nullptr; 2491 2492 // Handle equality comparisons of shift-by-constant. 2493 2494 // If the comparison constant changes with the shift, the comparison cannot 2495 // succeed (bits of the comparison constant cannot match the shifted value). 2496 // This should be known by InstSimplify and already be folded to true/false. 2497 assert(((IsAShr && C.shl(ShAmtVal).ashr(ShAmtVal) == C) || 2498 (!IsAShr && C.shl(ShAmtVal).lshr(ShAmtVal) == C)) && 2499 "Expected icmp+shr simplify did not occur."); 2500 2501 // If the bits shifted out are known zero, compare the unshifted value: 2502 // (X & 4) >> 1 == 2 --> (X & 4) == 4. 2503 if (Shr->isExact()) 2504 return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal)); 2505 2506 if (C.isZero()) { 2507 // == 0 is u< 1. 
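// ... and != 0 is u> 0. Shifting those bounds back up gives, e.g.:
// (X u>> 3) == 0 --> X u< 8 and (X u>> 3) != 0 --> X u> 7.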
2508 if (Pred == CmpInst::ICMP_EQ) 2509 return new ICmpInst(CmpInst::ICMP_ULT, X, 2510 ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal))); 2511 else 2512 return new ICmpInst(CmpInst::ICMP_UGT, X, 2513 ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal) - 1)); 2514 } 2515 2516 if (Shr->hasOneUse()) { 2517 // Canonicalize the shift into an 'and': 2518 // icmp eq/ne (shr X, ShAmt), C --> icmp eq/ne (and X, HiMask), (C << ShAmt) 2519 APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); 2520 Constant *Mask = ConstantInt::get(ShrTy, Val); 2521 Value *And = Builder.CreateAnd(X, Mask, Shr->getName() + ".mask"); 2522 return new ICmpInst(Pred, And, ConstantInt::get(ShrTy, C << ShAmtVal)); 2523 } 2524 2525 return nullptr; 2526 } 2527 2528 Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp, 2529 BinaryOperator *SRem, 2530 const APInt &C) { 2531 // Match an 'is positive' or 'is negative' comparison of remainder by a 2532 // constant power-of-2 value: 2533 // (X % pow2C) sgt/slt 0 2534 const ICmpInst::Predicate Pred = Cmp.getPredicate(); 2535 if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_SLT && 2536 Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE) 2537 return nullptr; 2538 2539 // TODO: The one-use check is standard because we do not typically want to 2540 // create longer instruction sequences, but this might be a special-case 2541 // because srem is not good for analysis or codegen. 2542 if (!SRem->hasOneUse()) 2543 return nullptr; 2544 2545 const APInt *DivisorC; 2546 if (!match(SRem->getOperand(1), m_Power2(DivisorC))) 2547 return nullptr; 2548 2549 // For cmp_sgt/cmp_slt only zero valued C is handled. 2550 // For cmp_eq/cmp_ne only positive valued C is handled. 2551 if (((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT) && 2552 !C.isZero()) || 2553 ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && 2554 !C.isStrictlyPositive())) 2555 return nullptr; 2556 2557 // Mask off the sign bit and the modulo bits (low-bits). 2558 Type *Ty = SRem->getType(); 2559 APInt SignMask = APInt::getSignMask(Ty->getScalarSizeInBits()); 2560 Constant *MaskC = ConstantInt::get(Ty, SignMask | (*DivisorC - 1)); 2561 Value *And = Builder.CreateAnd(SRem->getOperand(0), MaskC); 2562 2563 if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) 2564 return new ICmpInst(Pred, And, ConstantInt::get(Ty, C)); 2565 2566 // For 'is positive?' check that the sign-bit is clear and at least 1 masked 2567 // bit is set. Example: 2568 // (i8 X % 32) s> 0 --> (X & 159) s> 0 2569 if (Pred == ICmpInst::ICMP_SGT) 2570 return new ICmpInst(ICmpInst::ICMP_SGT, And, ConstantInt::getNullValue(Ty)); 2571 2572 // For 'is negative?' check that the sign-bit is set and at least 1 masked 2573 // bit is set. Example: 2574 // (i16 X % 4) s< 0 --> (X & 32771) u> 32768 2575 return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask)); 2576 } 2577 2578 /// Fold icmp (udiv X, Y), C. 
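/// e.g. (icmp ugt (udiv 64, Y), 15) --> (icmp ule Y, 4): 64 u/ Y can only
/// exceed 15 when Y u<= 64/(15+1).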
2579 Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp, 2580 BinaryOperator *UDiv, 2581 const APInt &C) { 2582 ICmpInst::Predicate Pred = Cmp.getPredicate(); 2583 Value *X = UDiv->getOperand(0); 2584 Value *Y = UDiv->getOperand(1); 2585 Type *Ty = UDiv->getType(); 2586 2587 const APInt *C2; 2588 if (!match(X, m_APInt(C2))) 2589 return nullptr; 2590 2591 assert(*C2 != 0 && "udiv 0, X should have been simplified already."); 2592 2593 // (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1)) 2594 if (Pred == ICmpInst::ICMP_UGT) { 2595 assert(!C.isMaxValue() && 2596 "icmp ugt X, UINT_MAX should have been simplified already."); 2597 return new ICmpInst(ICmpInst::ICMP_ULE, Y, 2598 ConstantInt::get(Ty, C2->udiv(C + 1))); 2599 } 2600 2601 // (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C) 2602 if (Pred == ICmpInst::ICMP_ULT) { 2603 assert(C != 0 && "icmp ult X, 0 should have been simplified already."); 2604 return new ICmpInst(ICmpInst::ICMP_UGT, Y, 2605 ConstantInt::get(Ty, C2->udiv(C))); 2606 } 2607 2608 return nullptr; 2609 } 2610 2611 /// Fold icmp ({su}div X, Y), C. 2612 Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, 2613 BinaryOperator *Div, 2614 const APInt &C) { 2615 ICmpInst::Predicate Pred = Cmp.getPredicate(); 2616 Value *X = Div->getOperand(0); 2617 Value *Y = Div->getOperand(1); 2618 Type *Ty = Div->getType(); 2619 bool DivIsSigned = Div->getOpcode() == Instruction::SDiv; 2620 2621 // If unsigned division and the compare constant is bigger than 2622 // UMAX/2 (negative), there's only one pair of values that satisfies an 2623 // equality check, so eliminate the division: 2624 // (X u/ Y) == C --> (X == C) && (Y == 1) 2625 // (X u/ Y) != C --> (X != C) || (Y != 1) 2626 // Similarly, if signed division and the compare constant is exactly SMIN: 2627 // (X s/ Y) == SMIN --> (X == SMIN) && (Y == 1) 2628 // (X s/ Y) != SMIN --> (X != SMIN) || (Y != 1) 2629 if (Cmp.isEquality() && Div->hasOneUse() && C.isSignBitSet() && 2630 (!DivIsSigned || C.isMinSignedValue())) { 2631 Value *XBig = Builder.CreateICmp(Pred, X, ConstantInt::get(Ty, C)); 2632 Value *YOne = Builder.CreateICmp(Pred, Y, ConstantInt::get(Ty, 1)); 2633 auto Logic = Pred == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; 2634 return BinaryOperator::Create(Logic, XBig, YOne); 2635 } 2636 2637 // Fold: icmp pred ([us]div X, C2), C -> range test 2638 // Fold this div into the comparison, producing a range check. 2639 // Determine, based on the divide type, what the range is being 2640 // checked. If there is an overflow on the low or high side, remember 2641 // it, otherwise compute the range [low, hi) bounding the new value. 2642 // See: InsertRangeTest above for the kinds of replacements possible. 2643 const APInt *C2; 2644 if (!match(Y, m_APInt(C2))) 2645 return nullptr; 2646 2647 // FIXME: If the operand types don't match the type of the divide 2648 // then don't attempt this transform. The code below doesn't have the 2649 // logic to deal with a signed divide and an unsigned compare (and 2650 // vice versa). This is because (x /s C2) <s C produces different 2651 // results than (x /s C2) <u C or (x /u C2) <s C or even 2652 // (x /u C2) <u C. Simply casting the operands and result won't 2653 // work. :( The if statement below tests that condition and bails 2654 // if it finds it. 2655 if (!Cmp.isEquality() && DivIsSigned != Cmp.isSigned()) 2656 return nullptr; 2657 2658 // The ProdOV computation fails on divide by 0 and divide by -1. 
Cases with
2659   // INT_MIN will also fail if the divisor is 1. Although folds of all these
2660   // division-by-constant cases should be present, we cannot assert that they
2661   // have happened before we reach this icmp instruction.
2662   if (C2->isZero() || C2->isOne() || (DivIsSigned && C2->isAllOnes()))
2663     return nullptr;
2664
2665   // Compute Prod = C * C2. We are essentially solving an equation of
2666   // form X / C2 = C. We solve for X by multiplying C2 and C.
2667   // By solving for X, we can turn this into a range check instead of computing
2668   // a divide.
2669   APInt Prod = C * *C2;
2670
2671   // Determine if the product overflows by seeing if dividing the product by
2672   // C2 does not give back C. Make sure we do the same kind of divide as in
2673   // the LHS instruction that we're folding.
2674   bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != C;
2675
2676   // If the division is known to be exact, then there is no remainder from the
2677   // divide, so the covered range size is one; otherwise it is the divisor.
2678   APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2;
2679
2680   // Figure out the interval that is being checked. For example, a comparison
2681   // like "X /u 5 == 0" is really checking that X is in the interval [0, 5).
2682   // Compute this interval based on the constants involved and the signedness
2683   // of the compare/divide. This computes a half-open interval, keeping track
2684   // of whether either value in the interval overflows. After the analysis,
2685   // each overflow variable is set to 0 if its corresponding bound variable is
2686   // valid, -1 if it overflowed off the bottom end, or +1 off the top end.
2687   int LoOverflow = 0, HiOverflow = 0;
2688   APInt LoBound, HiBound;
2689
2690   if (!DivIsSigned) { // udiv
2691     // e.g. X/5 op 3  --> [15, 20)
2692     LoBound = Prod;
2693     HiOverflow = LoOverflow = ProdOV;
2694     if (!HiOverflow) {
2695       // If this is not an exact divide, then many values in the range collapse
2696       // to the same result value.
2697       HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false);
2698     }
2699   } else if (C2->isStrictlyPositive()) { // Divisor is > 0.
2700     if (C.isZero()) { // (X / pos) op 0
2701       // Can't overflow.  e.g.  X/2 op 0 --> [-1, 2)
2702       LoBound = -(RangeSize - 1);
2703       HiBound = RangeSize;
2704     } else if (C.isStrictlyPositive()) { // (X / pos) op pos
2705       LoBound = Prod; // e.g.   X/5 op 3 --> [15, 20)
2706       HiOverflow = LoOverflow = ProdOV;
2707       if (!HiOverflow)
2708         HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true);
2709     } else { // (X / pos) op neg
2710       // e.g. X/5 op -3  --> [-15-4, -15+1) --> [-19, -14)
2711       HiBound = Prod + 1;
2712       LoOverflow = HiOverflow = ProdOV ? -1 : 0;
2713       if (!LoOverflow) {
2714         APInt DivNeg = -RangeSize;
2715         LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0;
2716       }
2717     }
2718   } else if (C2->isNegative()) { // Divisor is < 0.
2719     if (Div->isExact())
2720       RangeSize.negate();
2721     if (C.isZero()) { // (X / neg) op 0
2722       // e.g. X/-5 op 0  --> [-4, 5)
2723       LoBound = RangeSize + 1;
2724       HiBound = -RangeSize;
2725       if (HiBound == *C2) { // -INTMIN = INTMIN
2726         HiOverflow = 1;     // [INTMIN+1, overflow)
2727         HiBound = APInt();  // e.g. X/INTMIN = 0 --> X > INTMIN
2728       }
2729     } else if (C.isStrictlyPositive()) { // (X / neg) op pos
2730       // e.g. X/-5 op 3  --> [-19, -14)
2731       HiBound = Prod + 1;
2732       HiOverflow = LoOverflow = ProdOV ? -1 : 0;
2733       if (!LoOverflow)
2734         LoOverflow =
2735             addWithOverflow(LoBound, HiBound, RangeSize, true) ?
-1 : 0; 2736 } else { // (X / neg) op neg 2737 LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) 2738 LoOverflow = HiOverflow = ProdOV; 2739 if (!HiOverflow) 2740 HiOverflow = subWithOverflow(HiBound, Prod, RangeSize, true); 2741 } 2742 2743 // Dividing by a negative swaps the condition. LT <-> GT 2744 Pred = ICmpInst::getSwappedPredicate(Pred); 2745 } 2746 2747 switch (Pred) { 2748 default: 2749 llvm_unreachable("Unhandled icmp predicate!"); 2750 case ICmpInst::ICMP_EQ: 2751 if (LoOverflow && HiOverflow) 2752 return replaceInstUsesWith(Cmp, Builder.getFalse()); 2753 if (HiOverflow) 2754 return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, 2755 X, ConstantInt::get(Ty, LoBound)); 2756 if (LoOverflow) 2757 return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, 2758 X, ConstantInt::get(Ty, HiBound)); 2759 return replaceInstUsesWith( 2760 Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true)); 2761 case ICmpInst::ICMP_NE: 2762 if (LoOverflow && HiOverflow) 2763 return replaceInstUsesWith(Cmp, Builder.getTrue()); 2764 if (HiOverflow) 2765 return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, 2766 X, ConstantInt::get(Ty, LoBound)); 2767 if (LoOverflow) 2768 return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, 2769 X, ConstantInt::get(Ty, HiBound)); 2770 return replaceInstUsesWith( 2771 Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, false)); 2772 case ICmpInst::ICMP_ULT: 2773 case ICmpInst::ICMP_SLT: 2774 if (LoOverflow == +1) // Low bound is greater than input range. 2775 return replaceInstUsesWith(Cmp, Builder.getTrue()); 2776 if (LoOverflow == -1) // Low bound is less than input range. 2777 return replaceInstUsesWith(Cmp, Builder.getFalse()); 2778 return new ICmpInst(Pred, X, ConstantInt::get(Ty, LoBound)); 2779 case ICmpInst::ICMP_UGT: 2780 case ICmpInst::ICMP_SGT: 2781 if (HiOverflow == +1) // High bound greater than input range. 2782 return replaceInstUsesWith(Cmp, Builder.getFalse()); 2783 if (HiOverflow == -1) // High bound less than input range. 2784 return replaceInstUsesWith(Cmp, Builder.getTrue()); 2785 if (Pred == ICmpInst::ICMP_UGT) 2786 return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, HiBound)); 2787 return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, HiBound)); 2788 } 2789 2790 return nullptr; 2791 } 2792 2793 /// Fold icmp (sub X, Y), C. 2794 Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp, 2795 BinaryOperator *Sub, 2796 const APInt &C) { 2797 Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1); 2798 ICmpInst::Predicate Pred = Cmp.getPredicate(); 2799 Type *Ty = Sub->getType(); 2800 2801 // (SubC - Y) == C) --> Y == (SubC - C) 2802 // (SubC - Y) != C) --> Y != (SubC - C) 2803 Constant *SubC; 2804 if (Cmp.isEquality() && match(X, m_ImmConstant(SubC))) { 2805 return new ICmpInst(Pred, Y, 2806 ConstantExpr::getSub(SubC, ConstantInt::get(Ty, C))); 2807 } 2808 2809 // (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C) 2810 const APInt *C2; 2811 APInt SubResult; 2812 ICmpInst::Predicate SwappedPred = Cmp.getSwappedPredicate(); 2813 bool HasNSW = Sub->hasNoSignedWrap(); 2814 bool HasNUW = Sub->hasNoUnsignedWrap(); 2815 if (match(X, m_APInt(C2)) && 2816 ((Cmp.isUnsigned() && HasNUW) || (Cmp.isSigned() && HasNSW)) && 2817 !subWithOverflow(SubResult, *C2, C, Cmp.isSigned())) 2818 return new ICmpInst(SwappedPred, Y, ConstantInt::get(Ty, SubResult)); 2819 2820 // X - Y == 0 --> X == Y. 2821 // X - Y != 0 --> X != Y. 
2822 // TODO: We allow this with multiple uses as long as the other uses are not 2823 // in phis. The phi use check is guarding against a codegen regression 2824 // for a loop test. If the backend could undo this (and possibly 2825 // subsequent transforms), we would not need this hack. 2826 if (Cmp.isEquality() && C.isZero() && 2827 none_of((Sub->users()), [](const User *U) { return isa<PHINode>(U); })) 2828 return new ICmpInst(Pred, X, Y); 2829 2830 // The following transforms are only worth it if the only user of the subtract 2831 // is the icmp. 2832 // TODO: This is an artificial restriction for all of the transforms below 2833 // that only need a single replacement icmp. Can these use the phi test 2834 // like the transform above here? 2835 if (!Sub->hasOneUse()) 2836 return nullptr; 2837 2838 if (Sub->hasNoSignedWrap()) { 2839 // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y) 2840 if (Pred == ICmpInst::ICMP_SGT && C.isAllOnes()) 2841 return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); 2842 2843 // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y) 2844 if (Pred == ICmpInst::ICMP_SGT && C.isZero()) 2845 return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); 2846 2847 // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y) 2848 if (Pred == ICmpInst::ICMP_SLT && C.isZero()) 2849 return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); 2850 2851 // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y) 2852 if (Pred == ICmpInst::ICMP_SLT && C.isOne()) 2853 return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); 2854 } 2855 2856 if (!match(X, m_APInt(C2))) 2857 return nullptr; 2858 2859 // C2 - Y <u C -> (Y | (C - 1)) == C2 2860 // iff (C2 & (C - 1)) == C - 1 and C is a power of 2 2861 if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && 2862 (*C2 & (C - 1)) == (C - 1)) 2863 return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, C - 1), X); 2864 2865 // C2 - Y >u C -> (Y | C) != C2 2866 // iff C2 & C == C and C + 1 is a power of 2 2867 if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == C) 2868 return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, C), X); 2869 2870 // We have handled special cases that reduce. 2871 // Canonicalize any remaining sub to add as: 2872 // (C2 - Y) > C --> (Y + ~C2) < ~C 2873 Value *Add = Builder.CreateAdd(Y, ConstantInt::get(Ty, ~(*C2)), "notsub", 2874 HasNUW, HasNSW); 2875 return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C)); 2876 } 2877 2878 static Value *createLogicFromTable(const std::bitset<4> &Table, Value *Op0, 2879 Value *Op1, IRBuilderBase &Builder, 2880 bool HasOneUse) { 2881 auto FoldConstant = [&](bool Val) { 2882 Constant *Res = Val ? Builder.getTrue() : Builder.getFalse(); 2883 if (Op0->getType()->isVectorTy()) 2884 Res = ConstantVector::getSplat( 2885 cast<VectorType>(Op0->getType())->getElementCount(), Res); 2886 return Res; 2887 }; 2888 2889 switch (Table.to_ulong()) { 2890 case 0: // 0 0 0 0 2891 return FoldConstant(false); 2892 case 1: // 0 0 0 1 2893 return HasOneUse ? Builder.CreateNot(Builder.CreateOr(Op0, Op1)) : nullptr; 2894 case 2: // 0 0 1 0 2895 return HasOneUse ? Builder.CreateAnd(Builder.CreateNot(Op0), Op1) : nullptr; 2896 case 3: // 0 0 1 1 2897 return Builder.CreateNot(Op0); 2898 case 4: // 0 1 0 0 2899 return HasOneUse ? Builder.CreateAnd(Op0, Builder.CreateNot(Op1)) : nullptr; 2900 case 5: // 0 1 0 1 2901 return Builder.CreateNot(Op1); 2902 case 6: // 0 1 1 0 2903 return Builder.CreateXor(Op0, Op1); 2904 case 7: // 0 1 1 1 2905 return HasOneUse ? 
Builder.CreateNot(Builder.CreateAnd(Op0, Op1)) : nullptr; 2906 case 8: // 1 0 0 0 2907 return Builder.CreateAnd(Op0, Op1); 2908 case 9: // 1 0 0 1 2909 return HasOneUse ? Builder.CreateNot(Builder.CreateXor(Op0, Op1)) : nullptr; 2910 case 10: // 1 0 1 0 2911 return Op1; 2912 case 11: // 1 0 1 1 2913 return HasOneUse ? Builder.CreateOr(Builder.CreateNot(Op0), Op1) : nullptr; 2914 case 12: // 1 1 0 0 2915 return Op0; 2916 case 13: // 1 1 0 1 2917 return HasOneUse ? Builder.CreateOr(Op0, Builder.CreateNot(Op1)) : nullptr; 2918 case 14: // 1 1 1 0 2919 return Builder.CreateOr(Op0, Op1); 2920 case 15: // 1 1 1 1 2921 return FoldConstant(true); 2922 default: 2923 llvm_unreachable("Invalid Operation"); 2924 } 2925 return nullptr; 2926 } 2927 2928 /// Fold icmp (add X, Y), C. 2929 Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, 2930 BinaryOperator *Add, 2931 const APInt &C) { 2932 Value *Y = Add->getOperand(1); 2933 Value *X = Add->getOperand(0); 2934 2935 Value *Op0, *Op1; 2936 Instruction *Ext0, *Ext1; 2937 const CmpInst::Predicate Pred = Cmp.getPredicate(); 2938 if (match(Add, 2939 m_Add(m_CombineAnd(m_Instruction(Ext0), m_ZExtOrSExt(m_Value(Op0))), 2940 m_CombineAnd(m_Instruction(Ext1), 2941 m_ZExtOrSExt(m_Value(Op1))))) && 2942 Op0->getType()->isIntOrIntVectorTy(1) && 2943 Op1->getType()->isIntOrIntVectorTy(1)) { 2944 unsigned BW = C.getBitWidth(); 2945 std::bitset<4> Table; 2946 auto ComputeTable = [&](bool Op0Val, bool Op1Val) { 2947 int Res = 0; 2948 if (Op0Val) 2949 Res += isa<ZExtInst>(Ext0) ? 1 : -1; 2950 if (Op1Val) 2951 Res += isa<ZExtInst>(Ext1) ? 1 : -1; 2952 return ICmpInst::compare(APInt(BW, Res, true), C, Pred); 2953 }; 2954 2955 Table[0] = ComputeTable(false, false); 2956 Table[1] = ComputeTable(false, true); 2957 Table[2] = ComputeTable(true, false); 2958 Table[3] = ComputeTable(true, true); 2959 if (auto *Cond = 2960 createLogicFromTable(Table, Op0, Op1, Builder, Add->hasOneUse())) 2961 return replaceInstUsesWith(Cmp, Cond); 2962 } 2963 const APInt *C2; 2964 if (Cmp.isEquality() || !match(Y, m_APInt(C2))) 2965 return nullptr; 2966 2967 // Fold icmp pred (add X, C2), C. 2968 Type *Ty = Add->getType(); 2969 2970 // If the add does not wrap, we can always adjust the compare by subtracting 2971 // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE 2972 // are canonicalized to SGT/SLT/UGT/ULT. 2973 if ((Add->hasNoSignedWrap() && 2974 (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) || 2975 (Add->hasNoUnsignedWrap() && 2976 (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT))) { 2977 bool Overflow; 2978 APInt NewC = 2979 Cmp.isSigned() ? C.ssub_ov(*C2, Overflow) : C.usub_ov(*C2, Overflow); 2980 // If there is overflow, the result must be true or false. 2981 // TODO: Can we assert there is no overflow because InstSimplify always 2982 // handles those cases? 
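// For example, in i8 (an illustrative instance of the fold below):
//   icmp sgt (add nsw %x, 5), 20 --> icmp sgt %x, 15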
2983 if (!Overflow)
2984 // icmp Pred (add nsw X, C2), C --> icmp Pred X, (C - C2)
2985 return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC));
2986 }
2987 
2988 auto CR = ConstantRange::makeExactICmpRegion(Pred, C).subtract(*C2);
2989 const APInt &Upper = CR.getUpper();
2990 const APInt &Lower = CR.getLower();
2991 if (Cmp.isSigned()) {
2992 if (Lower.isSignMask())
2993 return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, Upper));
2994 if (Upper.isSignMask())
2995 return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, Lower));
2996 } else {
2997 if (Lower.isMinValue())
2998 return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, Upper));
2999 if (Upper.isMinValue())
3000 return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower));
3001 }
3002 
3003 // This set of folds is intentionally placed after folds that use no-wrapping
3004 // flags because those folds are likely better for later analysis/codegen.
3005 const APInt SMax = APInt::getSignedMaxValue(Ty->getScalarSizeInBits());
3006 const APInt SMin = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
3007 
3008 // Fold compare with offset to opposite sign compare if it eliminates offset:
3009 // (X + C2) >u C --> X <s -C2 (if C == C2 + SMAX)
3010 if (Pred == CmpInst::ICMP_UGT && C == *C2 + SMax)
3011 return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, -(*C2)));
3012 
3013 // (X + C2) <u C --> X >s ~C2 (if C == C2 + SMIN)
3014 if (Pred == CmpInst::ICMP_ULT && C == *C2 + SMin)
3015 return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantInt::get(Ty, ~(*C2)));
3016 
3017 // (X + C2) >s C --> X <u (SMAX - C) (if C == C2 - 1)
3018 if (Pred == CmpInst::ICMP_SGT && C == *C2 - 1)
3019 return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, SMax - C));
3020 
3021 // (X + C2) <s C --> X >u (C ^ SMAX) (if C == C2)
3022 if (Pred == CmpInst::ICMP_SLT && C == *C2)
3023 return new ICmpInst(ICmpInst::ICMP_UGT, X, ConstantInt::get(Ty, C ^ SMax));
3024 
3025 // (X + -1) <u C --> X <=u C (if X is known non-zero)
3026 if (Pred == CmpInst::ICMP_ULT && C2->isAllOnes()) {
3027 const SimplifyQuery Q = SQ.getWithInstruction(&Cmp);
3028 if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT))
3029 return new ICmpInst(ICmpInst::ICMP_ULE, X, ConstantInt::get(Ty, C));
3030 }
3031 
3032 if (!Add->hasOneUse())
3033 return nullptr;
3034 
3035 // (X + C2) <u C --> (X & -C) == -C2
3036 // iff C2 & (C - 1) == 0
3037 // and C is a power of 2
3038 if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && (*C2 & (C - 1)) == 0)
3039 return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -C),
3040 ConstantExpr::getNeg(cast<Constant>(Y)));
3041 
3042 // (X + C2) >u C --> (X & ~C) != -C2
3043 // iff C2 & C == 0
3044 // and C + 1 is a power of 2
3045 if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == 0)
3046 return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~C),
3047 ConstantExpr::getNeg(cast<Constant>(Y)));
3048 
3049 // The range test idiom can use either ult or ugt. Arbitrarily canonicalize
3050 // to the ult form.
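// A worked i8 instance of the rewrite below (illustrative):
//   (X + 10) >u 13 --> (X + 252) <u 242, where 252 == 10 - 13 - 1 and
//   242 == ~13, both taken modulo 256.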
3051 // X+C2 >u C -> X+(C2-C-1) <u ~C
3052 if (Pred == ICmpInst::ICMP_UGT)
3053 return new ICmpInst(ICmpInst::ICMP_ULT,
3054 Builder.CreateAdd(X, ConstantInt::get(Ty, *C2 - C - 1)),
3055 ConstantInt::get(Ty, ~C));
3056 
3057 return nullptr;
3058 }
3059 
3060 bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
3061 Value *&RHS, ConstantInt *&Less,
3062 ConstantInt *&Equal,
3063 ConstantInt *&Greater) {
3064 // TODO: Generalize this to work with other comparison idioms or ensure
3065 // they get canonicalized into this form.
3066 
3067 // select i1 (a == b),
3068 // i32 Equal,
3069 // i32 (select i1 (a < b), i32 Less, i32 Greater)
3070 // where Equal, Less and Greater are placeholders for any three constants.
3071 ICmpInst::Predicate PredA;
3072 if (!match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) ||
3073 !ICmpInst::isEquality(PredA))
3074 return false;
3075 Value *EqualVal = SI->getTrueValue();
3076 Value *UnequalVal = SI->getFalseValue();
3077 // We can still get a non-canonical predicate here, so canonicalize.
3078 if (PredA == ICmpInst::ICMP_NE)
3079 std::swap(EqualVal, UnequalVal);
3080 if (!match(EqualVal, m_ConstantInt(Equal)))
3081 return false;
3082 ICmpInst::Predicate PredB;
3083 Value *LHS2, *RHS2;
3084 if (!match(UnequalVal, m_Select(m_ICmp(PredB, m_Value(LHS2), m_Value(RHS2)),
3085 m_ConstantInt(Less), m_ConstantInt(Greater))))
3086 return false;
3087 // We can get a predicate mismatch here, so canonicalize if possible:
3088 // First, ensure that 'LHS' matches.
3089 if (LHS2 != LHS) {
3090 // x sgt y <--> y slt x
3091 std::swap(LHS2, RHS2);
3092 PredB = ICmpInst::getSwappedPredicate(PredB);
3093 }
3094 if (LHS2 != LHS)
3095 return false;
3096 // We also need to canonicalize 'RHS'.
3097 if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) {
3098 // x sgt C-1 <--> x sge C <--> not(x slt C)
3099 auto FlippedStrictness =
3100 InstCombiner::getFlippedStrictnessPredicateAndConstant(
3101 PredB, cast<Constant>(RHS2));
3102 if (!FlippedStrictness)
3103 return false;
3104 assert(FlippedStrictness->first == ICmpInst::ICMP_SGE &&
3105 "basic correctness failure");
3106 RHS2 = FlippedStrictness->second;
3107 // And kind-of perform the result swap.
3108 std::swap(Less, Greater);
3109 PredB = ICmpInst::ICMP_SLT;
3110 }
3111 return PredB == ICmpInst::ICMP_SLT && RHS == RHS2;
3112 }
3113 
3114 Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp,
3115 SelectInst *Select,
3116 ConstantInt *C) {
3117 
3118 assert(C && "Cmp RHS should be a constant int!");
3119 // If we're testing a constant value against the result of a three-way
3120 // comparison, the result can be expressed directly in terms of the
3121 // original values being compared. Note: We could possibly be more
3122 // aggressive here and remove the hasOneUse test. The original select is
3123 // really likely to simplify or sink when we remove a test of the result.
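// Illustrative example: for
//   %s = select (icmp eq %a, %b), i32 0,
//               (select (icmp slt %a, %b), i32 -1, i32 1)
// the test 'icmp sgt %s, -1' is true in the equal and greater cases, so it
// becomes (icmp eq %a, %b) || (icmp sgt %a, %b), which later passes can
// simplify to 'icmp sge %a, %b'.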
3124 Value *OrigLHS, *OrigRHS; 3125 ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan; 3126 if (Cmp.hasOneUse() && 3127 matchThreeWayIntCompare(Select, OrigLHS, OrigRHS, C1LessThan, C2Equal, 3128 C3GreaterThan)) { 3129 assert(C1LessThan && C2Equal && C3GreaterThan); 3130 3131 bool TrueWhenLessThan = 3132 ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C) 3133 ->isAllOnesValue(); 3134 bool TrueWhenEqual = 3135 ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C) 3136 ->isAllOnesValue(); 3137 bool TrueWhenGreaterThan = 3138 ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C) 3139 ->isAllOnesValue(); 3140 3141 // This generates the new instruction that will replace the original Cmp 3142 // Instruction. Instead of enumerating the various combinations when 3143 // TrueWhenLessThan, TrueWhenEqual and TrueWhenGreaterThan are true versus 3144 // false, we rely on chaining of ORs and future passes of InstCombine to 3145 // simplify the OR further (i.e. a s< b || a == b becomes a s<= b). 3146 3147 // When none of the three constants satisfy the predicate for the RHS (C), 3148 // the entire original Cmp can be simplified to a false. 3149 Value *Cond = Builder.getFalse(); 3150 if (TrueWhenLessThan) 3151 Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SLT, 3152 OrigLHS, OrigRHS)); 3153 if (TrueWhenEqual) 3154 Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_EQ, 3155 OrigLHS, OrigRHS)); 3156 if (TrueWhenGreaterThan) 3157 Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SGT, 3158 OrigLHS, OrigRHS)); 3159 3160 return replaceInstUsesWith(Cmp, Cond); 3161 } 3162 return nullptr; 3163 } 3164 3165 Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { 3166 auto *Bitcast = dyn_cast<BitCastInst>(Cmp.getOperand(0)); 3167 if (!Bitcast) 3168 return nullptr; 3169 3170 ICmpInst::Predicate Pred = Cmp.getPredicate(); 3171 Value *Op1 = Cmp.getOperand(1); 3172 Value *BCSrcOp = Bitcast->getOperand(0); 3173 Type *SrcType = Bitcast->getSrcTy(); 3174 Type *DstType = Bitcast->getType(); 3175 3176 // Make sure the bitcast doesn't change between scalar and vector and 3177 // doesn't change the number of vector elements. 3178 if (SrcType->isVectorTy() == DstType->isVectorTy() && 3179 SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits()) { 3180 // Zero-equality and sign-bit checks are preserved through sitofp + bitcast. 
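// (Rationale: the bit pattern produced by sitofp is all-zero iff the input
// is zero, and its sign bit is set iff the input is negative.)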
3181 Value *X;
3182 if (match(BCSrcOp, m_SIToFP(m_Value(X)))) {
3183 // icmp eq (bitcast (sitofp X)), 0 --> icmp eq X, 0
3184 // icmp ne (bitcast (sitofp X)), 0 --> icmp ne X, 0
3185 // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0
3186 // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0
3187 if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT ||
3188 Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) &&
3189 match(Op1, m_Zero()))
3190 return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
3191 
3192 // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1
3193 if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One()))
3194 return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1));
3195 
3196 // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1
3197 if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))
3198 return new ICmpInst(Pred, X,
3199 ConstantInt::getAllOnesValue(X->getType()));
3200 }
3201 
3202 // Zero-equality checks are preserved through unsigned floating-point casts:
3203 // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0
3204 // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0
3205 if (match(BCSrcOp, m_UIToFP(m_Value(X))))
3206 if (Cmp.isEquality() && match(Op1, m_Zero()))
3207 return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
3208 
3209 // If this is a sign-bit test of a bitcast of a casted FP value, eliminate
3210 // the FP extend/truncate because that cast does not change the sign-bit.
3211 // This is true for all standard IEEE-754 types and the X86 80-bit type.
3212 // The sign-bit is always the most significant bit in those types.
3213 const APInt *C;
3214 bool TrueIfSigned;
3215 if (match(Op1, m_APInt(C)) && Bitcast->hasOneUse() &&
3216 isSignBitCheck(Pred, *C, TrueIfSigned)) {
3217 if (match(BCSrcOp, m_FPExt(m_Value(X))) ||
3218 match(BCSrcOp, m_FPTrunc(m_Value(X)))) {
3219 // (bitcast (fpext/fptrunc X) to iX) < 0 --> (bitcast X to iY) < 0
3220 // (bitcast (fpext/fptrunc X) to iX) > -1 --> (bitcast X to iY) > -1
3221 Type *XType = X->getType();
3222 
3223 // We can't currently handle PowerPC-style ppc_fp128 operations here.
3224 if (!(XType->isPPC_FP128Ty() || SrcType->isPPC_FP128Ty())) {
3225 Type *NewType = Builder.getIntNTy(XType->getScalarSizeInBits());
3226 if (auto *XVTy = dyn_cast<VectorType>(XType))
3227 NewType = VectorType::get(NewType, XVTy->getElementCount());
3228 Value *NewBitcast = Builder.CreateBitCast(X, NewType);
3229 if (TrueIfSigned)
3230 return new ICmpInst(ICmpInst::ICMP_SLT, NewBitcast,
3231 ConstantInt::getNullValue(NewType));
3232 else
3233 return new ICmpInst(ICmpInst::ICMP_SGT, NewBitcast,
3234 ConstantInt::getAllOnesValue(NewType));
3235 }
3236 }
3237 }
3238 }
3239 
3240 const APInt *C;
3241 if (!match(Cmp.getOperand(1), m_APInt(C)) || !DstType->isIntegerTy() ||
3242 !SrcType->isIntOrIntVectorTy())
3243 return nullptr;
3244 
3245 // If this is checking if all elements of a vector compare are set or not,
3246 // invert the casted vector equality compare and test if all compare
3247 // elements are clear or not. Compare against zero is generally easier for
3248 // analysis and codegen.
3249 // icmp eq/ne (bitcast (not X) to iN), -1 --> icmp eq/ne (bitcast X to iN), 0
3250 // Example: are all elements equal? --> are zero elements not equal?
3251 // TODO: Try harder to reduce compare of 2 freely invertible operands?
3252 if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse()) {
3253 if (Value *NotBCSrcOp =
3254 getFreelyInverted(BCSrcOp, BCSrcOp->hasOneUse(), &Builder)) {
3255 Value *Cast = Builder.CreateBitCast(NotBCSrcOp, DstType);
3256 return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(DstType));
3257 }
3258 }
3259 
3260 // If this is checking if all elements of an extended vector are clear or not,
3261 // compare in a narrow type to eliminate the extend:
3262 // icmp eq/ne (bitcast (ext X) to iN), 0 --> icmp eq/ne (bitcast X to iM), 0
3263 Value *X;
3264 if (Cmp.isEquality() && C->isZero() && Bitcast->hasOneUse() &&
3265 match(BCSrcOp, m_ZExtOrSExt(m_Value(X)))) {
3266 if (auto *VecTy = dyn_cast<FixedVectorType>(X->getType())) {
3267 Type *NewType = Builder.getIntNTy(VecTy->getPrimitiveSizeInBits());
3268 Value *NewCast = Builder.CreateBitCast(X, NewType);
3269 return new ICmpInst(Pred, NewCast, ConstantInt::getNullValue(NewType));
3270 }
3271 }
3272 
3273 // Folding: icmp <pred> iN X, C
3274 // where X = bitcast <M x iK> (shufflevector <M x iK> %vec, undef, SC) to iN
3275 // and C is a splat of a K-bit pattern
3276 // and SC is a constant vector = <C', C', C', ..., C'>
3277 // Into:
3278 // %E = extractelement <M x iK> %vec, i32 C'
3279 // icmp <pred> iK %E, trunc(C)
3280 Value *Vec;
3281 ArrayRef<int> Mask;
3282 if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) {
3283 // Check whether every element of Mask is the same constant
3284 if (all_equal(Mask)) {
3285 auto *VecTy = cast<VectorType>(SrcType);
3286 auto *EltTy = cast<IntegerType>(VecTy->getElementType());
3287 if (C->isSplat(EltTy->getBitWidth())) {
3288 // Fold the icmp based on the value of C.
3289 // If C is M copies of an iK-sized bit pattern,
3290 // then:
3291 //   %E = extractelement <M x iK> %vec, i32 Elem
3292 //   icmp <pred> iK %E, trunc(C)
3293 Value *Elem = Builder.getInt32(Mask[0]);
3294 Value *Extract = Builder.CreateExtractElement(Vec, Elem);
3295 Value *NewC = ConstantInt::get(EltTy, C->trunc(EltTy->getBitWidth()));
3296 return new ICmpInst(Pred, Extract, NewC);
3297 }
3298 }
3299 }
3300 return nullptr;
3301 }
3302 
3303 /// Try to fold integer comparisons with a constant operand: icmp Pred X, C
3304 /// where X is some kind of instruction.
3305 Instruction *InstCombinerImpl::foldICmpInstWithConstant(ICmpInst &Cmp) {
3306 const APInt *C;
3307 
3308 if (match(Cmp.getOperand(1), m_APInt(C))) {
3309 if (auto *BO = dyn_cast<BinaryOperator>(Cmp.getOperand(0)))
3310 if (Instruction *I = foldICmpBinOpWithConstant(Cmp, BO, *C))
3311 return I;
3312 
3313 if (auto *SI = dyn_cast<SelectInst>(Cmp.getOperand(0)))
3314 // For now, we only support constant integers while folding the
3315 // icmp(select) pattern. We can extend this to support vectors of integers
3316 // similar to the cases handled by binary ops above.
3317 if (auto *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1)))
3318 if (Instruction *I = foldICmpSelectConstant(Cmp, SI, ConstRHS))
3319 return I;
3320 
3321 if (auto *TI = dyn_cast<TruncInst>(Cmp.getOperand(0)))
3322 if (Instruction *I = foldICmpTruncConstant(Cmp, TI, *C))
3323 return I;
3324 
3325 if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)))
3326 if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, II, *C))
3327 return I;
3328 
3329 // (extractval ([s/u]subo X, Y), 0) == 0 --> X == Y
3330 // (extractval ([s/u]subo X, Y), 0) != 0 --> X != Y
3331 // TODO: This checks one-use, but that is not strictly necessary.
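// For example (with a placeholder name %agg):
//   %agg = call { i32, i1 } @llvm.usub.with.overflow.i32(i32 %x, i32 %y)
//   icmp eq (extractvalue %agg, 0), 0 --> icmp eq %x, %y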
3332 Value *Cmp0 = Cmp.getOperand(0);
3333 Value *X, *Y;
3334 if (C->isZero() && Cmp.isEquality() && Cmp0->hasOneUse() &&
3335 (match(Cmp0,
3336 m_ExtractValue<0>(m_Intrinsic<Intrinsic::ssub_with_overflow>(
3337 m_Value(X), m_Value(Y)))) ||
3338 match(Cmp0,
3339 m_ExtractValue<0>(m_Intrinsic<Intrinsic::usub_with_overflow>(
3340 m_Value(X), m_Value(Y))))))
3341 return new ICmpInst(Cmp.getPredicate(), X, Y);
3342 }
3343 
3344 if (match(Cmp.getOperand(1), m_APIntAllowUndef(C)))
3345 return foldICmpInstWithConstantAllowUndef(Cmp, *C);
3346 
3347 return nullptr;
3348 }
3349 
3350 /// Fold an icmp equality instruction with binary operator LHS and constant RHS:
3351 /// icmp eq/ne BO, C.
3352 Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant(
3353 ICmpInst &Cmp, BinaryOperator *BO, const APInt &C) {
3354 // TODO: Some of these folds could work with arbitrary constants, but this
3355 // function is limited to scalar and vector splat constants.
3356 if (!Cmp.isEquality())
3357 return nullptr;
3358 
3359 ICmpInst::Predicate Pred = Cmp.getPredicate();
3360 bool isICMP_NE = Pred == ICmpInst::ICMP_NE;
3361 Constant *RHS = cast<Constant>(Cmp.getOperand(1));
3362 Value *BOp0 = BO->getOperand(0), *BOp1 = BO->getOperand(1);
3363 
3364 switch (BO->getOpcode()) {
3365 case Instruction::SRem:
3366 // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one.
3367 if (C.isZero() && BO->hasOneUse()) {
3368 const APInt *BOC;
3369 if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) {
3370 Value *NewRem = Builder.CreateURem(BOp0, BOp1, BO->getName());
3371 return new ICmpInst(Pred, NewRem,
3372 Constant::getNullValue(BO->getType()));
3373 }
3374 }
3375 break;
3376 case Instruction::Add: {
3377 // (A + C2) == C --> A == (C - C2)
3378 // (A + C2) != C --> A != (C - C2)
3379 // TODO: Remove the one-use limitation? See discussion in D58633.
3380 if (Constant *C2 = dyn_cast<Constant>(BOp1)) {
3381 if (BO->hasOneUse())
3382 return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, C2));
3383 } else if (C.isZero()) {
3384 // Replace ((add A, B) != 0) with (A != -B) if A or B is
3385 // efficiently negatable, or if the add has just this one use.
3386 if (Value *NegVal = dyn_castNegVal(BOp1))
3387 return new ICmpInst(Pred, BOp0, NegVal);
3388 if (Value *NegVal = dyn_castNegVal(BOp0))
3389 return new ICmpInst(Pred, NegVal, BOp1);
3390 if (BO->hasOneUse()) {
3391 Value *Neg = Builder.CreateNeg(BOp1);
3392 Neg->takeName(BO);
3393 return new ICmpInst(Pred, BOp0, Neg);
3394 }
3395 }
3396 break;
3397 }
3398 case Instruction::Xor:
3399 if (BO->hasOneUse()) {
3400 if (Constant *BOC = dyn_cast<Constant>(BOp1)) {
3401 // For the xor case, we can xor two constants together, eliminating
3402 // the explicit xor.
3403 return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC));
3404 } else if (C.isZero()) {
3405 // Replace ((xor A, B) != 0) with (A != B)
3406 return new ICmpInst(Pred, BOp0, BOp1);
3407 }
3408 }
3409 break;
3410 case Instruction::Or: {
3411 const APInt *BOC;
3412 if (match(BOp1, m_APInt(BOC)) && BO->hasOneUse() && RHS->isAllOnesValue()) {
3413 // Comparing if all bits outside of a constant mask are set?
3414 // Replace (X | C) == -1 with (X & ~C) == ~C.
3415 // This removes the -1 constant.
3416 Constant *NotBOC = ConstantExpr::getNot(cast<Constant>(BOp1));
3417 Value *And = Builder.CreateAnd(BOp0, NotBOC);
3418 return new ICmpInst(Pred, And, NotBOC);
3419 }
3420 break;
3421 }
3422 case Instruction::UDiv:
3423 case Instruction::SDiv:
3424 if (BO->isExact()) {
3425 // div exact X, Y eq/ne 0 -> X eq/ne 0
3426 // div exact X, Y eq/ne 1 -> X eq/ne Y
3427 // div exact X, Y eq/ne C ->
3428 // if Y * C never-overflow && OneUse:
3429 // -> Y * C eq/ne X
3430 if (C.isZero())
3431 return new ICmpInst(Pred, BOp0, Constant::getNullValue(BO->getType()));
3432 else if (C.isOne())
3433 return new ICmpInst(Pred, BOp0, BOp1);
3434 else if (BO->hasOneUse()) {
3435 OverflowResult OR = computeOverflow(
3436 Instruction::Mul, BO->getOpcode() == Instruction::SDiv, BOp1,
3437 Cmp.getOperand(1), BO);
3438 if (OR == OverflowResult::NeverOverflows) {
3439 Value *YC =
3440 Builder.CreateMul(BOp1, ConstantInt::get(BO->getType(), C));
3441 return new ICmpInst(Pred, YC, BOp0);
3442 }
3443 }
3444 }
3445 if (BO->getOpcode() == Instruction::UDiv && C.isZero()) {
3446 // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule B, A)
3447 auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT;
3448 return new ICmpInst(NewPred, BOp1, BOp0);
3449 }
3450 break;
3451 default:
3452 break;
3453 }
3454 return nullptr;
3455 }
3456 
3457 static Instruction *foldCtpopPow2Test(ICmpInst &I, IntrinsicInst *CtpopLhs,
3458 const APInt &CRhs,
3459 InstCombiner::BuilderTy &Builder,
3460 const SimplifyQuery &Q) {
3461 assert(CtpopLhs->getIntrinsicID() == Intrinsic::ctpop &&
3462 "Non-ctpop intrin in ctpop fold");
3463 if (!CtpopLhs->hasOneUse())
3464 return nullptr;
3465 
3466 // Power of 2 test:
3467 // isPow2OrZero : ctpop(X) u< 2
3468 // isPow2 : ctpop(X) == 1
3469 // NotPow2OrZero: ctpop(X) u> 1
3470 // NotPow2 : ctpop(X) != 1
3471 // If we know a bit of X is set, these can be folded to:
3472 // IsPow2 : X & (~Bit) == 0
3473 // NotPow2 : X & (~Bit) != 0
3474 const ICmpInst::Predicate Pred = I.getPredicate();
3475 if (((I.isEquality() || Pred == ICmpInst::ICMP_UGT) && CRhs == 1) ||
3476 (Pred == ICmpInst::ICMP_ULT && CRhs == 2)) {
3477 Value *Op = CtpopLhs->getArgOperand(0);
3478 KnownBits OpKnown = computeKnownBits(Op, Q.DL,
3479 /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
3480 // No need to check for count > 1, that should be already constant folded.
3481 if (OpKnown.countMinPopulation() == 1) {
3482 Value *And = Builder.CreateAnd(
3483 Op, Constant::getIntegerValue(Op->getType(), ~(OpKnown.One)));
3484 return new ICmpInst(
3485 (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_ULT)
3486 ? ICmpInst::ICMP_EQ
3487 : ICmpInst::ICMP_NE,
3488 And, Constant::getNullValue(Op->getType()));
3489 }
3490 }
3491 
3492 return nullptr;
3493 }
3494 
3495 /// Fold an equality icmp with LLVM intrinsic and constant operand.
3496 Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
3497 ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) {
3498 Type *Ty = II->getType();
3499 unsigned BitWidth = C.getBitWidth();
3500 const ICmpInst::Predicate Pred = Cmp.getPredicate();
3501 
3502 switch (II->getIntrinsicID()) {
3503 case Intrinsic::abs:
3504 // abs(A) == 0 -> A == 0
3505 // abs(A) == INT_MIN -> A == INT_MIN
3506 if (C.isZero() || C.isMinSignedValue())
3507 return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C));
3508 break;
3509 
3510 case Intrinsic::bswap:
3511 // bswap(A) == C -> A == bswap(C)
3512 return new ICmpInst(Pred, II->getArgOperand(0),
3513 ConstantInt::get(Ty, C.byteSwap()));
3514 
3515 case Intrinsic::bitreverse:
3516 // bitreverse(A) == C -> A == bitreverse(C)
3517 return new ICmpInst(Pred, II->getArgOperand(0),
3518 ConstantInt::get(Ty, C.reverseBits()));
3519 
3520 case Intrinsic::ctlz:
3521 case Intrinsic::cttz: {
3522 // ctlz/cttz(A) == bitwidth(A) -> A == 0, and likewise for !=
3523 if (C == BitWidth)
3524 return new ICmpInst(Pred, II->getArgOperand(0),
3525 ConstantInt::getNullValue(Ty));
3526 
3527 // cttz(A) == C -> A & Mask1 == Mask2, where Mask2 only has bit C set
3528 // and Mask1 has bits 0..C set. Similarly for ctlz, but with high bits.
3529 // Limit to one use to ensure we don't increase instruction count.
3530 unsigned Num = C.getLimitedValue(BitWidth);
3531 if (Num != BitWidth && II->hasOneUse()) {
3532 bool IsTrailing = II->getIntrinsicID() == Intrinsic::cttz;
3533 APInt Mask1 = IsTrailing ? APInt::getLowBitsSet(BitWidth, Num + 1)
3534 : APInt::getHighBitsSet(BitWidth, Num + 1);
3535 APInt Mask2 = IsTrailing
3536 ? APInt::getOneBitSet(BitWidth, Num)
3537 : APInt::getOneBitSet(BitWidth, BitWidth - Num - 1);
3538 return new ICmpInst(Pred, Builder.CreateAnd(II->getArgOperand(0), Mask1),
3539 ConstantInt::get(Ty, Mask2));
3540 }
3541 break;
3542 }
3543 
3544 case Intrinsic::ctpop: {
3545 // popcount(A) == 0 -> A == 0 and likewise for !=
3546 // popcount(A) == bitwidth(A) -> A == -1 and likewise for !=
3547 bool IsZero = C.isZero();
3548 if (IsZero || C == BitWidth)
3549 return new ICmpInst(Pred, II->getArgOperand(0),
3550 IsZero ? Constant::getNullValue(Ty)
3551 : Constant::getAllOnesValue(Ty));
3552 
3553 break;
3554 }
3555 
3556 case Intrinsic::fshl:
3557 case Intrinsic::fshr:
3558 if (II->getArgOperand(0) == II->getArgOperand(1)) {
3559 const APInt *RotAmtC;
3560 // ror(X, RotAmtC) == C --> X == rol(C, RotAmtC)
3561 // rol(X, RotAmtC) == C --> X == ror(C, RotAmtC)
3562 if (match(II->getArgOperand(2), m_APInt(RotAmtC)))
3563 return new ICmpInst(Pred, II->getArgOperand(0),
3564 II->getIntrinsicID() == Intrinsic::fshl
3565 ? ConstantInt::get(Ty, C.rotr(*RotAmtC))
3566 : ConstantInt::get(Ty, C.rotl(*RotAmtC)));
3567 }
3568 break;
3569 
3570 case Intrinsic::umax:
3571 case Intrinsic::uadd_sat: {
3572 // uadd.sat(a, b) == 0 -> (a | b) == 0
3573 // umax(a, b) == 0 -> (a | b) == 0
3574 if (C.isZero() && II->hasOneUse()) {
3575 Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1));
3576 return new ICmpInst(Pred, Or, Constant::getNullValue(Ty));
3577 }
3578 break;
3579 }
3580 
3581 case Intrinsic::ssub_sat:
3582 // ssub.sat(a, b) == 0 -> a == b
3583 if (C.isZero())
3584 return new ICmpInst(Pred, II->getArgOperand(0), II->getArgOperand(1));
3585 break;
3586 case Intrinsic::usub_sat: {
3587 // usub.sat(a, b) == 0 -> a <= b
3588 if (C.isZero()) {
3589 ICmpInst::Predicate NewPred =
3590 Pred == ICmpInst::ICMP_EQ ?
ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT;
3591 return new ICmpInst(NewPred, II->getArgOperand(0), II->getArgOperand(1));
3592 }
3593 break;
3594 }
3595 default:
3596 break;
3597 }
3598 
3599 return nullptr;
3600 }
3601 
3602 /// Fold an equality icmp between two calls to the same LLVM intrinsic.
3603 static Instruction *
3604 foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp,
3605 InstCombiner::BuilderTy &Builder) {
3606 assert(Cmp.isEquality());
3607 
3608 ICmpInst::Predicate Pred = Cmp.getPredicate();
3609 Value *Op0 = Cmp.getOperand(0);
3610 Value *Op1 = Cmp.getOperand(1);
3611 const auto *IIOp0 = dyn_cast<IntrinsicInst>(Op0);
3612 const auto *IIOp1 = dyn_cast<IntrinsicInst>(Op1);
3613 if (!IIOp0 || !IIOp1 || IIOp0->getIntrinsicID() != IIOp1->getIntrinsicID())
3614 return nullptr;
3615 
3616 switch (IIOp0->getIntrinsicID()) {
3617 case Intrinsic::bswap:
3618 case Intrinsic::bitreverse:
3619 // If both operands are byte-swapped or bit-reversed, just compare the
3620 // original values.
3621 return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0));
3622 case Intrinsic::fshl:
3623 case Intrinsic::fshr: {
3624 // If both operands are rotated by the same amount, just compare the
3625 // original values.
3626 if (IIOp0->getOperand(0) != IIOp0->getOperand(1))
3627 break;
3628 if (IIOp1->getOperand(0) != IIOp1->getOperand(1))
3629 break;
3630 if (IIOp0->getOperand(2) == IIOp1->getOperand(2))
3631 return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0));
3632 
3633 // rotate(X, AmtX) == rotate(Y, AmtY)
3634 // -> rotate(X, AmtX - AmtY) == Y
3635 // Do this if either both rotates have one use or if only one has one use
3636 // and AmtX/AmtY are constants.
3637 unsigned OneUses = IIOp0->hasOneUse() + IIOp1->hasOneUse();
3638 if (OneUses == 2 ||
3639 (OneUses == 1 && match(IIOp0->getOperand(2), m_ImmConstant()) &&
3640 match(IIOp1->getOperand(2), m_ImmConstant()))) {
3641 Value *SubAmt =
3642 Builder.CreateSub(IIOp0->getOperand(2), IIOp1->getOperand(2));
3643 Value *CombinedRotate = Builder.CreateIntrinsic(
3644 Op0->getType(), IIOp0->getIntrinsicID(),
3645 {IIOp0->getOperand(0), IIOp0->getOperand(0), SubAmt});
3646 return new ICmpInst(Pred, IIOp1->getOperand(0), CombinedRotate);
3647 }
3648 } break;
3649 default:
3650 break;
3651 }
3652 
3653 return nullptr;
3654 }
3655 
3656 /// Try to fold integer comparisons with a constant operand: icmp Pred X, C
3657 /// where X is some kind of instruction and C may contain undef elements.
3658 /// TODO: Move more folds which allow undef to this function.
3659 Instruction *
3660 InstCombinerImpl::foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp,
3661 const APInt &C) {
3662 const ICmpInst::Predicate Pred = Cmp.getPredicate();
3663 if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0))) {
3664 switch (II->getIntrinsicID()) {
3665 default:
3666 break;
3667 case Intrinsic::fshl:
3668 case Intrinsic::fshr:
3669 if (Cmp.isEquality() && II->getArgOperand(0) == II->getArgOperand(1)) {
3670 // (rot X, ?) == 0/-1 --> X == 0/-1
3671 if (C.isZero() || C.isAllOnes())
3672 return new ICmpInst(Pred, II->getArgOperand(0), Cmp.getOperand(1));
3673 }
3674 break;
3675 }
3676 }
3677 
3678 return nullptr;
3679 }
3680 
3681 /// Fold an icmp with BinaryOp and constant operand: icmp Pred BO, C.
3682 Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, 3683 BinaryOperator *BO, 3684 const APInt &C) { 3685 switch (BO->getOpcode()) { 3686 case Instruction::Xor: 3687 if (Instruction *I = foldICmpXorConstant(Cmp, BO, C)) 3688 return I; 3689 break; 3690 case Instruction::And: 3691 if (Instruction *I = foldICmpAndConstant(Cmp, BO, C)) 3692 return I; 3693 break; 3694 case Instruction::Or: 3695 if (Instruction *I = foldICmpOrConstant(Cmp, BO, C)) 3696 return I; 3697 break; 3698 case Instruction::Mul: 3699 if (Instruction *I = foldICmpMulConstant(Cmp, BO, C)) 3700 return I; 3701 break; 3702 case Instruction::Shl: 3703 if (Instruction *I = foldICmpShlConstant(Cmp, BO, C)) 3704 return I; 3705 break; 3706 case Instruction::LShr: 3707 case Instruction::AShr: 3708 if (Instruction *I = foldICmpShrConstant(Cmp, BO, C)) 3709 return I; 3710 break; 3711 case Instruction::SRem: 3712 if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C)) 3713 return I; 3714 break; 3715 case Instruction::UDiv: 3716 if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C)) 3717 return I; 3718 [[fallthrough]]; 3719 case Instruction::SDiv: 3720 if (Instruction *I = foldICmpDivConstant(Cmp, BO, C)) 3721 return I; 3722 break; 3723 case Instruction::Sub: 3724 if (Instruction *I = foldICmpSubConstant(Cmp, BO, C)) 3725 return I; 3726 break; 3727 case Instruction::Add: 3728 if (Instruction *I = foldICmpAddConstant(Cmp, BO, C)) 3729 return I; 3730 break; 3731 default: 3732 break; 3733 } 3734 3735 // TODO: These folds could be refactored to be part of the above calls. 3736 return foldICmpBinOpEqualityWithConstant(Cmp, BO, C); 3737 } 3738 3739 static Instruction * 3740 foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred, 3741 SaturatingInst *II, const APInt &C, 3742 InstCombiner::BuilderTy &Builder) { 3743 // This transform may end up producing more than one instruction for the 3744 // intrinsic, so limit it to one user of the intrinsic. 3745 if (!II->hasOneUse()) 3746 return nullptr; 3747 3748 // Let Y = [add/sub]_sat(X, C) pred C2 3749 // SatVal = The saturating value for the operation 3750 // WillWrap = Whether or not the operation will underflow / overflow 3751 // => Y = (WillWrap ? SatVal : (X binop C)) pred C2 3752 // => Y = WillWrap ? (SatVal pred C2) : ((X binop C) pred C2) 3753 // 3754 // When (SatVal pred C2) is true, then 3755 // Y = WillWrap ? true : ((X binop C) pred C2) 3756 // => Y = WillWrap || ((X binop C) pred C2) 3757 // else 3758 // Y = WillWrap ? false : ((X binop C) pred C2) 3759 // => Y = !WillWrap ? ((X binop C) pred C2) : false 3760 // => Y = !WillWrap && ((X binop C) pred C2) 3761 Value *Op0 = II->getOperand(0); 3762 Value *Op1 = II->getOperand(1); 3763 3764 const APInt *COp1; 3765 // This transform only works when the intrinsic has an integral constant or 3766 // splat vector as the second operand. 3767 if (!match(Op1, m_APInt(COp1))) 3768 return nullptr; 3769 3770 APInt SatVal; 3771 switch (II->getIntrinsicID()) { 3772 default: 3773 llvm_unreachable( 3774 "This function only works with usub_sat and uadd_sat for now!"); 3775 case Intrinsic::uadd_sat: 3776 SatVal = APInt::getAllOnes(C.getBitWidth()); 3777 break; 3778 case Intrinsic::usub_sat: 3779 SatVal = APInt::getZero(C.getBitWidth()); 3780 break; 3781 } 3782 3783 // Check (SatVal pred C2) 3784 bool SatValCheck = ICmpInst::compare(SatVal, C, Pred); 3785 3786 // !WillWrap. 3787 ConstantRange C1 = ConstantRange::makeExactNoWrapRegion( 3788 II->getBinaryOp(), *COp1, II->getNoWrapKind()); 3789 3790 // WillWrap. 
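// (C1 currently holds the inputs for which the operation does not wrap; when
// SatVal already satisfies the predicate we need the wrapping inputs instead,
// so invert the range below.)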
3791 if (SatValCheck) 3792 C1 = C1.inverse(); 3793 3794 ConstantRange C2 = ConstantRange::makeExactICmpRegion(Pred, C); 3795 if (II->getBinaryOp() == Instruction::Add) 3796 C2 = C2.sub(*COp1); 3797 else 3798 C2 = C2.add(*COp1); 3799 3800 Instruction::BinaryOps CombiningOp = 3801 SatValCheck ? Instruction::BinaryOps::Or : Instruction::BinaryOps::And; 3802 3803 std::optional<ConstantRange> Combination; 3804 if (CombiningOp == Instruction::BinaryOps::Or) 3805 Combination = C1.exactUnionWith(C2); 3806 else /* CombiningOp == Instruction::BinaryOps::And */ 3807 Combination = C1.exactIntersectWith(C2); 3808 3809 if (!Combination) 3810 return nullptr; 3811 3812 CmpInst::Predicate EquivPred; 3813 APInt EquivInt; 3814 APInt EquivOffset; 3815 3816 Combination->getEquivalentICmp(EquivPred, EquivInt, EquivOffset); 3817 3818 return new ICmpInst( 3819 EquivPred, 3820 Builder.CreateAdd(Op0, ConstantInt::get(Op1->getType(), EquivOffset)), 3821 ConstantInt::get(Op1->getType(), EquivInt)); 3822 } 3823 3824 /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. 3825 Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, 3826 IntrinsicInst *II, 3827 const APInt &C) { 3828 ICmpInst::Predicate Pred = Cmp.getPredicate(); 3829 3830 // Handle folds that apply for any kind of icmp. 3831 switch (II->getIntrinsicID()) { 3832 default: 3833 break; 3834 case Intrinsic::uadd_sat: 3835 case Intrinsic::usub_sat: 3836 if (auto *Folded = foldICmpUSubSatOrUAddSatWithConstant( 3837 Pred, cast<SaturatingInst>(II), C, Builder)) 3838 return Folded; 3839 break; 3840 case Intrinsic::ctpop: { 3841 const SimplifyQuery Q = SQ.getWithInstruction(&Cmp); 3842 if (Instruction *R = foldCtpopPow2Test(Cmp, II, C, Builder, Q)) 3843 return R; 3844 } break; 3845 } 3846 3847 if (Cmp.isEquality()) 3848 return foldICmpEqIntrinsicWithConstant(Cmp, II, C); 3849 3850 Type *Ty = II->getType(); 3851 unsigned BitWidth = C.getBitWidth(); 3852 switch (II->getIntrinsicID()) { 3853 case Intrinsic::ctpop: { 3854 // (ctpop X > BitWidth - 1) --> X == -1 3855 Value *X = II->getArgOperand(0); 3856 if (C == BitWidth - 1 && Pred == ICmpInst::ICMP_UGT) 3857 return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, X, 3858 ConstantInt::getAllOnesValue(Ty)); 3859 // (ctpop X < BitWidth) --> X != -1 3860 if (C == BitWidth && Pred == ICmpInst::ICMP_ULT) 3861 return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, X, 3862 ConstantInt::getAllOnesValue(Ty)); 3863 break; 3864 } 3865 case Intrinsic::ctlz: { 3866 // ctlz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX < 0b00010000 3867 if (Pred == ICmpInst::ICMP_UGT && C.ult(BitWidth)) { 3868 unsigned Num = C.getLimitedValue(); 3869 APInt Limit = APInt::getOneBitSet(BitWidth, BitWidth - Num - 1); 3870 return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, 3871 II->getArgOperand(0), ConstantInt::get(Ty, Limit)); 3872 } 3873 3874 // ctlz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX > 0b00011111 3875 if (Pred == ICmpInst::ICMP_ULT && C.uge(1) && C.ule(BitWidth)) { 3876 unsigned Num = C.getLimitedValue(); 3877 APInt Limit = APInt::getLowBitsSet(BitWidth, BitWidth - Num); 3878 return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, 3879 II->getArgOperand(0), ConstantInt::get(Ty, Limit)); 3880 } 3881 break; 3882 } 3883 case Intrinsic::cttz: { 3884 // Limit to one use to ensure we don't increase instruction count. 
3885 if (!II->hasOneUse())
3886 return nullptr;
3887 
3888 // cttz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX & 0b00001111 == 0
3889 if (Pred == ICmpInst::ICMP_UGT && C.ult(BitWidth)) {
3890 APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue() + 1);
3891 return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ,
3892 Builder.CreateAnd(II->getArgOperand(0), Mask),
3893 ConstantInt::getNullValue(Ty));
3894 }
3895 
3896 // cttz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX & 0b00000111 != 0
3897 if (Pred == ICmpInst::ICMP_ULT && C.uge(1) && C.ule(BitWidth)) {
3898 APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue());
3899 return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE,
3900 Builder.CreateAnd(II->getArgOperand(0), Mask),
3901 ConstantInt::getNullValue(Ty));
3902 }
3903 break;
3904 }
3905 case Intrinsic::ssub_sat:
3906 // ssub.sat(a, b) spred 0 -> a spred b
3907 if (ICmpInst::isSigned(Pred)) {
3908 if (C.isZero())
3909 return new ICmpInst(Pred, II->getArgOperand(0), II->getArgOperand(1));
3910 // X s<= 0 is canonicalized to X s< 1
3911 if (Pred == ICmpInst::ICMP_SLT && C.isOne())
3912 return new ICmpInst(ICmpInst::ICMP_SLE, II->getArgOperand(0),
3913 II->getArgOperand(1));
3914 // X s>= 0 is canonicalized to X s> -1
3915 if (Pred == ICmpInst::ICMP_SGT && C.isAllOnes())
3916 return new ICmpInst(ICmpInst::ICMP_SGE, II->getArgOperand(0),
3917 II->getArgOperand(1));
3918 }
3919 break;
3920 default:
3921 break;
3922 }
3923 
3924 return nullptr;
3925 }
3926 
3927 /// Handle icmp with constant (but not simple integer constant) RHS.
3928 Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
3929 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
3930 Constant *RHSC = dyn_cast<Constant>(Op1);
3931 Instruction *LHSI = dyn_cast<Instruction>(Op0);
3932 if (!RHSC || !LHSI)
3933 return nullptr;
3934 
3935 switch (LHSI->getOpcode()) {
3936 case Instruction::PHI:
3937 if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
3938 return NV;
3939 break;
3940 case Instruction::IntToPtr:
3941 // icmp pred inttoptr(X), null -> icmp pred X, 0
3942 if (RHSC->isNullValue() &&
3943 DL.getIntPtrType(RHSC->getType()) == LHSI->getOperand(0)->getType())
3944 return new ICmpInst(
3945 I.getPredicate(), LHSI->getOperand(0),
3946 Constant::getNullValue(LHSI->getOperand(0)->getType()));
3947 break;
3948 
3949 case Instruction::Load:
3950 // Try to optimize things like "A[i] > 4" to index computations.
3951 if (GetElementPtrInst *GEP =
3952 dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
3953 if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
3954 if (Instruction *Res =
3955 foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, GV, I))
3956 return Res;
3957 break;
3958 }
3959 
3960 return nullptr;
3961 }
3962 
3963 Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
3964 SelectInst *SI, Value *RHS,
3965 const ICmpInst &I) {
3966 // Try to fold the comparison into the select arms, which will cause the
3967 // select to be converted into a logical and/or.
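// E.g. for 'icmp ult (select i1 %c, i8 4, i8 %x), 8' the true arm simplifies
// to 'true', so the compare becomes 'select %c, true, (icmp ult %x, 8)',
// i.e. a logical 'or' (illustrative).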
3968 auto SimplifyOp = [&](Value *Op, bool SelectCondIsTrue) -> Value * { 3969 if (Value *Res = simplifyICmpInst(Pred, Op, RHS, SQ)) 3970 return Res; 3971 if (std::optional<bool> Impl = isImpliedCondition( 3972 SI->getCondition(), Pred, Op, RHS, DL, SelectCondIsTrue)) 3973 return ConstantInt::get(I.getType(), *Impl); 3974 return nullptr; 3975 }; 3976 3977 ConstantInt *CI = nullptr; 3978 Value *Op1 = SimplifyOp(SI->getOperand(1), true); 3979 if (Op1) 3980 CI = dyn_cast<ConstantInt>(Op1); 3981 3982 Value *Op2 = SimplifyOp(SI->getOperand(2), false); 3983 if (Op2) 3984 CI = dyn_cast<ConstantInt>(Op2); 3985 3986 // We only want to perform this transformation if it will not lead to 3987 // additional code. This is true if either both sides of the select 3988 // fold to a constant (in which case the icmp is replaced with a select 3989 // which will usually simplify) or this is the only user of the 3990 // select (in which case we are trading a select+icmp for a simpler 3991 // select+icmp) or all uses of the select can be replaced based on 3992 // dominance information ("Global cases"). 3993 bool Transform = false; 3994 if (Op1 && Op2) 3995 Transform = true; 3996 else if (Op1 || Op2) { 3997 // Local case 3998 if (SI->hasOneUse()) 3999 Transform = true; 4000 // Global cases 4001 else if (CI && !CI->isZero()) 4002 // When Op1 is constant try replacing select with second operand. 4003 // Otherwise Op2 is constant and try replacing select with first 4004 // operand. 4005 Transform = replacedSelectWithOperand(SI, &I, Op1 ? 2 : 1); 4006 } 4007 if (Transform) { 4008 if (!Op1) 4009 Op1 = Builder.CreateICmp(Pred, SI->getOperand(1), RHS, I.getName()); 4010 if (!Op2) 4011 Op2 = Builder.CreateICmp(Pred, SI->getOperand(2), RHS, I.getName()); 4012 return SelectInst::Create(SI->getOperand(0), Op1, Op2); 4013 } 4014 4015 return nullptr; 4016 } 4017 4018 /// Some comparisons can be simplified. 4019 /// In this case, we are looking for comparisons that look like 4020 /// a check for a lossy truncation. 4021 /// Folds: 4022 /// icmp SrcPred (x & Mask), x to icmp DstPred x, Mask 4023 /// Where Mask is some pattern that produces all-ones in low bits: 4024 /// (-1 >> y) 4025 /// ((-1 << y) >> y) <- non-canonical, has extra uses 4026 /// ~(-1 << y) 4027 /// ((1 << y) + (-1)) <- non-canonical, has extra uses 4028 /// The Mask can be a constant, too. 4029 /// For some predicates, the operands are commutative. 4030 /// For others, x can only be on a specific side. 
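/// For example, (x & 15) u>= x can only be true when x has no bits set above
/// the mask, so it folds to x u<= 15.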
4031 static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
4032 InstCombiner::BuilderTy &Builder) {
4033 ICmpInst::Predicate SrcPred;
4034 Value *X, *M, *Y;
4035 auto m_VariableMask = m_CombineOr(
4036 m_CombineOr(m_Not(m_Shl(m_AllOnes(), m_Value())),
4037 m_Add(m_Shl(m_One(), m_Value()), m_AllOnes())),
4038 m_CombineOr(m_LShr(m_AllOnes(), m_Value()),
4039 m_LShr(m_Shl(m_AllOnes(), m_Value(Y)), m_Deferred(Y))));
4040 auto m_Mask = m_CombineOr(m_VariableMask, m_LowBitMask());
4041 if (!match(&I, m_c_ICmp(SrcPred,
4042 m_c_And(m_CombineAnd(m_Mask, m_Value(M)), m_Value(X)),
4043 m_Deferred(X))))
4044 return nullptr;
4045 
4046 ICmpInst::Predicate DstPred;
4047 switch (SrcPred) {
4048 case ICmpInst::Predicate::ICMP_EQ:
4049 // x & (-1 >> y) == x -> x u<= (-1 >> y)
4050 DstPred = ICmpInst::Predicate::ICMP_ULE;
4051 break;
4052 case ICmpInst::Predicate::ICMP_NE:
4053 // x & (-1 >> y) != x -> x u> (-1 >> y)
4054 DstPred = ICmpInst::Predicate::ICMP_UGT;
4055 break;
4056 case ICmpInst::Predicate::ICMP_ULT:
4057 // x & (-1 >> y) u< x -> x u> (-1 >> y)
4058 // x u> x & (-1 >> y) -> x u> (-1 >> y)
4059 DstPred = ICmpInst::Predicate::ICMP_UGT;
4060 break;
4061 case ICmpInst::Predicate::ICMP_UGE:
4062 // x & (-1 >> y) u>= x -> x u<= (-1 >> y)
4063 // x u<= x & (-1 >> y) -> x u<= (-1 >> y)
4064 DstPred = ICmpInst::Predicate::ICMP_ULE;
4065 break;
4066 case ICmpInst::Predicate::ICMP_SLT:
4067 // x & (-1 >> y) s< x -> x s> (-1 >> y)
4068 // x s> x & (-1 >> y) -> x s> (-1 >> y)
4069 if (!match(M, m_Constant())) // Cannot do this fold with a non-constant.
4070 return nullptr;
4071 if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
4072 return nullptr;
4073 DstPred = ICmpInst::Predicate::ICMP_SGT;
4074 break;
4075 case ICmpInst::Predicate::ICMP_SGE:
4076 // x & (-1 >> y) s>= x -> x s<= (-1 >> y)
4077 // x s<= x & (-1 >> y) -> x s<= (-1 >> y)
4078 if (!match(M, m_Constant())) // Cannot do this fold with a non-constant.
4079 return nullptr;
4080 if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
4081 return nullptr;
4082 DstPred = ICmpInst::Predicate::ICMP_SLE;
4083 break;
4084 case ICmpInst::Predicate::ICMP_SGT:
4085 case ICmpInst::Predicate::ICMP_SLE:
4086 return nullptr;
4087 case ICmpInst::Predicate::ICMP_UGT:
4088 case ICmpInst::Predicate::ICMP_ULE:
4089 llvm_unreachable("Instsimplify took care of commut. variant");
4090 break;
4091 default:
4092 llvm_unreachable("All possible folds are handled.");
4093 }
4094 
4095 // The mask value may be a vector constant that has undefined elements. But it
4096 // may not be safe to propagate those undefs into the new compare, so replace
4097 // those elements by copying an existing, defined, and safe scalar constant.
4098 Type *OpTy = M->getType();
4099 auto *VecC = dyn_cast<Constant>(M);
4100 auto *OpVTy = dyn_cast<FixedVectorType>(OpTy);
4101 if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) {
4102 Constant *SafeReplacementConstant = nullptr;
4103 for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) {
4104 if (!isa<UndefValue>(VecC->getAggregateElement(i))) {
4105 SafeReplacementConstant = VecC->getAggregateElement(i);
4106 break;
4107 }
4108 }
4109 assert(SafeReplacementConstant && "Failed to find undef replacement");
4110 M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant);
4111 }
4112 
4113 return Builder.CreateICmp(DstPred, X, M);
4114 }
4115 
4116 /// Some comparisons can be simplified.
4117 /// In this case, we are looking for comparisons that look like 4118 /// a check for a lossy signed truncation. 4119 /// Folds: (MaskedBits is a constant.) 4120 /// ((%x << MaskedBits) a>> MaskedBits) SrcPred %x 4121 /// Into: 4122 /// (add %x, (1 << (KeptBits-1))) DstPred (1 << KeptBits) 4123 /// Where KeptBits = bitwidth(%x) - MaskedBits 4124 static Value * 4125 foldICmpWithTruncSignExtendedVal(ICmpInst &I, 4126 InstCombiner::BuilderTy &Builder) { 4127 ICmpInst::Predicate SrcPred; 4128 Value *X; 4129 const APInt *C0, *C1; // FIXME: non-splats, potentially with undef. 4130 // We are ok with 'shl' having multiple uses, but 'ashr' must be one-use. 4131 if (!match(&I, m_c_ICmp(SrcPred, 4132 m_OneUse(m_AShr(m_Shl(m_Value(X), m_APInt(C0)), 4133 m_APInt(C1))), 4134 m_Deferred(X)))) 4135 return nullptr; 4136 4137 // Potential handling of non-splats: for each element: 4138 // * if both are undef, replace with constant 0. 4139 // Because (1<<0) is OK and is 1, and ((1<<0)>>1) is also OK and is 0. 4140 // * if both are not undef, and are different, bailout. 4141 // * else, only one is undef, then pick the non-undef one. 4142 4143 // The shift amount must be equal. 4144 if (*C0 != *C1) 4145 return nullptr; 4146 const APInt &MaskedBits = *C0; 4147 assert(MaskedBits != 0 && "shift by zero should be folded away already."); 4148 4149 ICmpInst::Predicate DstPred; 4150 switch (SrcPred) { 4151 case ICmpInst::Predicate::ICMP_EQ: 4152 // ((%x << MaskedBits) a>> MaskedBits) == %x 4153 // => 4154 // (add %x, (1 << (KeptBits-1))) u< (1 << KeptBits) 4155 DstPred = ICmpInst::Predicate::ICMP_ULT; 4156 break; 4157 case ICmpInst::Predicate::ICMP_NE: 4158 // ((%x << MaskedBits) a>> MaskedBits) != %x 4159 // => 4160 // (add %x, (1 << (KeptBits-1))) u>= (1 << KeptBits) 4161 DstPred = ICmpInst::Predicate::ICMP_UGE; 4162 break; 4163 // FIXME: are more folds possible? 4164 default: 4165 return nullptr; 4166 } 4167 4168 auto *XType = X->getType(); 4169 const unsigned XBitWidth = XType->getScalarSizeInBits(); 4170 const APInt BitWidth = APInt(XBitWidth, XBitWidth); 4171 assert(BitWidth.ugt(MaskedBits) && "shifts should leave some bits untouched"); 4172 4173 // KeptBits = bitwidth(%x) - MaskedBits 4174 const APInt KeptBits = BitWidth - MaskedBits; 4175 assert(KeptBits.ugt(0) && KeptBits.ult(BitWidth) && "unreachable"); 4176 // ICmpCst = (1 << KeptBits) 4177 const APInt ICmpCst = APInt(XBitWidth, 1).shl(KeptBits); 4178 assert(ICmpCst.isPowerOf2()); 4179 // AddCst = (1 << (KeptBits-1)) 4180 const APInt AddCst = ICmpCst.lshr(1); 4181 assert(AddCst.ult(ICmpCst) && AddCst.isPowerOf2()); 4182 4183 // T0 = add %x, AddCst 4184 Value *T0 = Builder.CreateAdd(X, ConstantInt::get(XType, AddCst)); 4185 // T1 = T0 DstPred ICmpCst 4186 Value *T1 = Builder.CreateICmp(DstPred, T0, ConstantInt::get(XType, ICmpCst)); 4187 4188 return T1; 4189 } 4190 4191 // Given pattern: 4192 // icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0 4193 // we should move shifts to the same hand of 'and', i.e. rewrite as 4194 // icmp eq/ne (and (x shift (Q+K)), y), 0 iff (Q+K) u< bitwidth(x) 4195 // We are only interested in opposite logical shifts here. 4196 // One of the shifts can be truncated. 4197 // If we can, we want to end up creating 'lshr' shift. 
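// E.g. in i8 (illustrative):
//   icmp eq (and (shl %x, 1), (lshr %y, 2)), 0
//     --> icmp eq (and (lshr %y, 3), %x), 0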
4198 static Value *
4199 foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
4200 InstCombiner::BuilderTy &Builder) {
4201 if (!I.isEquality() || !match(I.getOperand(1), m_Zero()) ||
4202 !I.getOperand(0)->hasOneUse())
4203 return nullptr;
4204 
4205 auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value());
4206 
4207 // Look for an 'and' of two logical shifts, one of which may be truncated.
4208 // We use m_TruncOrSelf() on the RHS to correctly handle the commutative case.
4209 Instruction *XShift, *MaybeTruncation, *YShift;
4210 if (!match(
4211 I.getOperand(0),
4212 m_c_And(m_CombineAnd(m_AnyLogicalShift, m_Instruction(XShift)),
4213 m_CombineAnd(m_TruncOrSelf(m_CombineAnd(
4214 m_AnyLogicalShift, m_Instruction(YShift))),
4215 m_Instruction(MaybeTruncation)))))
4216 return nullptr;
4217 
4218 // We potentially looked past 'trunc', but only when matching YShift,
4219 // therefore YShift must have the widest type.
4220 Instruction *WidestShift = YShift;
4221 // Therefore XShift must have the narrowest type.
4222 // Or they both have identical types if there was no truncation.
4223 Instruction *NarrowestShift = XShift;
4224 
4225 Type *WidestTy = WidestShift->getType();
4226 Type *NarrowestTy = NarrowestShift->getType();
4227 assert(NarrowestTy == I.getOperand(0)->getType() &&
4228 "We did not look past any shifts while matching XShift though.");
4229 bool HadTrunc = WidestTy != I.getOperand(0)->getType();
4230 
4231 // If YShift is a 'lshr', swap the shifts around.
4232 if (match(YShift, m_LShr(m_Value(), m_Value())))
4233 std::swap(XShift, YShift);
4234 
4235 // The shifts must be in opposite directions.
4236 auto XShiftOpcode = XShift->getOpcode();
4237 if (XShiftOpcode == YShift->getOpcode())
4238 return nullptr; // Do not care about same-direction shifts here.
4239 
4240 Value *X, *XShAmt, *Y, *YShAmt;
4241 match(XShift, m_BinOp(m_Value(X), m_ZExtOrSelf(m_Value(XShAmt))));
4242 match(YShift, m_BinOp(m_Value(Y), m_ZExtOrSelf(m_Value(YShAmt))));
4243 
4244 // If one of the values being shifted is a constant, then we will end with
4245 // and+icmp, and [zext+]shift instrs will be constant-folded. If they are not,
4246 // however, we will need to ensure that we won't increase instruction count.
4247 if (!isa<Constant>(X) && !isa<Constant>(Y)) {
4248 // At least one of the hands of the 'and' should be one-use shift.
4249 if (!match(I.getOperand(0),
4250 m_c_And(m_OneUse(m_AnyLogicalShift), m_Value())))
4251 return nullptr;
4252 if (HadTrunc) {
4253 // Due to the 'trunc', we will need to widen X. For that either the old
4254 // 'trunc' or the shift amt in the non-truncated shift should be one-use.
4255 if (!MaybeTruncation->hasOneUse() &&
4256 !NarrowestShift->getOperand(1)->hasOneUse())
4257 return nullptr;
4258 }
4259 }
4260 
4261 // We have two shift amounts from two different shifts. The types of those
4262 // shift amounts may not match. If that's the case, bail out now.
4263 if (XShAmt->getType() != YShAmt->getType())
4264 return nullptr;
4265 
4266 // As input, we have the following pattern:
4267 // icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0
4268 // We want to rewrite that as:
4269 // icmp eq/ne (and (x shift (Q+K)), y), 0 iff (Q+K) u< bitwidth(x)
4270 // While we know that originally (Q+K) would not overflow
4271 // (because 2 * (N-1) u<= iN -1), we have looked past extensions of
4272 // shift amounts, so it may now overflow in a smaller bitwidth.
4273 // To ensure that does not happen, we need to check that the total maximal
4274 // shift amount is still representable in that smaller bit width.
4275 unsigned MaximalPossibleTotalShiftAmount =
4276 (WidestTy->getScalarSizeInBits() - 1) +
4277 (NarrowestTy->getScalarSizeInBits() - 1);
4278 APInt MaximalRepresentableShiftAmount =
4279 APInt::getAllOnes(XShAmt->getType()->getScalarSizeInBits());
4280 if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
4281 return nullptr;
4282 
4283 // Can we fold (XShAmt+YShAmt) ?
4284 auto *NewShAmt = dyn_cast_or_null<Constant>(
4285 simplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false,
4286 /*isNUW=*/false, SQ.getWithInstruction(&I)));
4287 if (!NewShAmt)
4288 return nullptr;
4289 if (NewShAmt->getType() != WidestTy) {
4290 NewShAmt =
4291 ConstantFoldCastOperand(Instruction::ZExt, NewShAmt, WidestTy, SQ.DL);
4292 if (!NewShAmt)
4293 return nullptr;
4294 }
4295 unsigned WidestBitWidth = WidestTy->getScalarSizeInBits();
4296 
4297 // Is the new shift amount smaller than the bit width?
4298 // FIXME: could also rely on ConstantRange.
4299 if (!match(NewShAmt,
4300 m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
4301 APInt(WidestBitWidth, WidestBitWidth))))
4302 return nullptr;
4303 
4304 // An extra legality check is needed if we had trunc-of-lshr.
4305 if (HadTrunc && match(WidestShift, m_LShr(m_Value(), m_Value()))) {
4306 auto CanFold = [NewShAmt, WidestBitWidth, NarrowestShift, SQ,
4307 WidestShift]() {
4308 // It isn't obvious whether it's worth it to analyze non-constants here.
4309 // Also, let's basically give up on non-splat cases, pessimizing vectors.
4310 // If *any* of these preconditions matches we can perform the fold.
4311 Constant *NewShAmtSplat = NewShAmt->getType()->isVectorTy()
4312 ? NewShAmt->getSplatValue()
4313 : NewShAmt;
4314 // If it's edge-case shift (by 0 or by WidestBitWidth-1) we can fold.
4315 if (NewShAmtSplat &&
4316 (NewShAmtSplat->isNullValue() ||
4317 NewShAmtSplat->getUniqueInteger() == WidestBitWidth - 1))
4318 return true;
4319 // We consider *min* leading zeros so a single outlier
4320 // blocks the transform as opposed to allowing it.
4321 if (auto *C = dyn_cast<Constant>(NarrowestShift->getOperand(0))) {
4322 KnownBits Known = computeKnownBits(C, SQ.DL);
4323 unsigned MinLeadZero = Known.countMinLeadingZeros();
4324 // If the value being shifted has at most the lowest bit set, we can fold.
4325 unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero;
4326 if (MaxActiveBits <= 1)
4327 return true;
4328 // Precondition: NewShAmt u<= countLeadingZeros(C)
4329 if (NewShAmtSplat && NewShAmtSplat->getUniqueInteger().ule(MinLeadZero))
4330 return true;
4331 }
4332 if (auto *C = dyn_cast<Constant>(WidestShift->getOperand(0))) {
4333 KnownBits Known = computeKnownBits(C, SQ.DL);
4334 unsigned MinLeadZero = Known.countMinLeadingZeros();
4335 // If the value being shifted has at most the lowest bit set, we can fold.
4336 unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero;
4337 if (MaxActiveBits <= 1)
4338 return true;
4339 // Precondition: ((WidestBitWidth-1)-NewShAmt) u<= countLeadingZeros(C)
4340 if (NewShAmtSplat) {
4341 APInt AdjNewShAmt =
4342 (WidestBitWidth - 1) - NewShAmtSplat->getUniqueInteger();
4343 if (AdjNewShAmt.ule(MinLeadZero))
4344 return true;
4345 }
4346 }
4347 return false; // Can't tell if it's ok.
4348 };
4349 if (!CanFold())
4350 return nullptr;
4351 }
4352 
4353 // All good, we can do this fold.
  X = Builder.CreateZExt(X, WidestTy);
  Y = Builder.CreateZExt(Y, WidestTy);
  // The shift kind is the same as the one that was used for X.
  Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr
                  ? Builder.CreateLShr(X, NewShAmt)
                  : Builder.CreateShl(X, NewShAmt);
  Value *T1 = Builder.CreateAnd(T0, Y);
  return Builder.CreateICmp(I.getPredicate(), T1,
                            Constant::getNullValue(WidestTy));
}

/// Fold
///   (-1 u/ x) u< y
///   ((x * y) ?/ x) != y
/// to
///   @llvm.?mul.with.overflow(x, y) plus extraction of overflow bit
/// Note that the comparison is commutative, and an inverted predicate
/// (u>=, ==) means that we are looking for the opposite answer.
Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
  ICmpInst::Predicate Pred;
  Value *X, *Y;
  Instruction *Mul;
  Instruction *Div;
  bool NeedNegation;
  // Look for: (-1 u/ x) u</u>= y
  if (!I.isEquality() &&
      match(&I, m_c_ICmp(Pred,
                         m_CombineAnd(m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
                                      m_Instruction(Div)),
                         m_Value(Y)))) {
    Mul = nullptr;

    // Are we checking that overflow does not happen, or does happen?
    switch (Pred) {
    case ICmpInst::Predicate::ICMP_ULT:
      NeedNegation = false;
      break; // OK
    case ICmpInst::Predicate::ICMP_UGE:
      NeedNegation = true;
      break; // OK
    default:
      return nullptr; // Wrong predicate.
    }
  } else // Look for: ((x * y) / x) !=/== y
    if (I.isEquality() &&
        match(&I,
              m_c_ICmp(Pred, m_Value(Y),
                       m_CombineAnd(
                           m_OneUse(m_IDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y),
                                                                m_Value(X)),
                                                        m_Instruction(Mul)),
                                           m_Deferred(X))),
                           m_Instruction(Div))))) {
      NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ;
    } else
      return nullptr;

  BuilderTy::InsertPointGuard Guard(Builder);
  // If the pattern included (x * y), we'll want to insert new instructions
  // right before that original multiplication so that we can replace it.
  bool MulHadOtherUses = Mul && !Mul->hasOneUse();
  if (MulHadOtherUses)
    Builder.SetInsertPoint(Mul);

  Function *F = Intrinsic::getDeclaration(I.getModule(),
                                          Div->getOpcode() == Instruction::UDiv
                                              ? Intrinsic::umul_with_overflow
                                              : Intrinsic::smul_with_overflow,
                                          X->getType());
  CallInst *Call = Builder.CreateCall(F, {X, Y}, "mul");

  // If the multiplication was used elsewhere, to ensure that we don't leave
  // "duplicate" instructions, replace uses of that original multiplication
  // with the multiplication result from the with.overflow intrinsic.
  if (MulHadOtherUses)
    replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "mul.val"));

  Value *Res = Builder.CreateExtractValue(Call, 1, "mul.ov");
  if (NeedNegation) // This technically increases instruction count.
    Res = Builder.CreateNot(Res, "mul.not.ov");

  // If we replaced the mul, erase it. Do this after all uses of Builder,
  // as the mul is used as the insertion point.
  if (MulHadOtherUses)
    eraseInstFromFunction(*Mul);

  return Res;
}

static Instruction *foldICmpXNegX(ICmpInst &I,
                                  InstCombiner::BuilderTy &Builder) {
  CmpInst::Predicate Pred;
  Value *X;
  if (match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X)))) {
    if (ICmpInst::isSigned(Pred))
      Pred = ICmpInst::getSwappedPredicate(Pred);
    else if (ICmpInst::isUnsigned(Pred))
      Pred = ICmpInst::getSignedPredicate(Pred);
    // else for equality-comparisons just keep the predicate.

    return ICmpInst::Create(Instruction::ICmp, Pred, X,
                            Constant::getNullValue(X->getType()), I.getName());
  }

  // A value is not equal to its negation unless that value is 0 or
  // MinSignedValue, i.e.: a != -a --> (a & MaxSignedVal) != 0
  if (match(&I, m_c_ICmp(Pred, m_OneUse(m_Neg(m_Value(X))), m_Deferred(X))) &&
      ICmpInst::isEquality(Pred)) {
    Type *Ty = X->getType();
    uint32_t BitWidth = Ty->getScalarSizeInBits();
    Constant *MaxSignedVal =
        ConstantInt::get(Ty, APInt::getSignedMaxValue(BitWidth));
    Value *And = Builder.CreateAnd(X, MaxSignedVal);
    Constant *Zero = Constant::getNullValue(Ty);
    return CmpInst::Create(Instruction::ICmp, Pred, And, Zero);
  }

  return nullptr;
}

static Instruction *foldICmpAndXX(ICmpInst &I, const SimplifyQuery &Q,
                                  InstCombinerImpl &IC) {
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A;
  // Normalize the 'and' to be operand 0.
  CmpInst::Predicate Pred = I.getPredicate();
  if (match(Op1, m_c_And(m_Specific(Op0), m_Value()))) {
    std::swap(Op0, Op1);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }

  if (!match(Op0, m_c_And(m_Specific(Op1), m_Value(A))))
    return nullptr;

  // icmp (X & Y) u< X --> (X & Y) != X
  if (Pred == ICmpInst::ICMP_ULT)
    return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);

  // icmp (X & Y) u>= X --> (X & Y) == X
  if (Pred == ICmpInst::ICMP_UGE)
    return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);

  return nullptr;
}

static Instruction *foldICmpOrXX(ICmpInst &I, const SimplifyQuery &Q,
                                 InstCombinerImpl &IC) {
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A;

  // Normalize the 'or' to be operand 0.
  CmpInst::Predicate Pred = I.getPredicate();
  if (match(Op1, m_c_Or(m_Specific(Op0), m_Value(A)))) {
    std::swap(Op0, Op1);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  } else if (!match(Op0, m_c_Or(m_Specific(Op1), m_Value(A)))) {
    return nullptr;
  }

  // icmp (X | Y) u<= X --> (X | Y) == X
  if (Pred == ICmpInst::ICMP_ULE)
    return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);

  // icmp (X | Y) u> X --> (X | Y) != X
  if (Pred == ICmpInst::ICMP_UGT)
    return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);

  if (ICmpInst::isEquality(Pred) && Op0->hasOneUse()) {
    // icmp (X | Y) eq/ne Y --> (X & ~Y) eq/ne 0 if Y is freely invertible
    if (Value *NotOp1 =
            IC.getFreelyInverted(Op1, Op1->hasOneUse(), &IC.Builder))
      return new ICmpInst(Pred, IC.Builder.CreateAnd(A, NotOp1),
                          Constant::getNullValue(Op1->getType()));
    // icmp (X | Y) eq/ne Y --> (~X | Y) eq/ne -1 if X is freely invertible.
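    // (Illustrative: if X is itself "xor %a, -1", then ~X is just %a, so the
    // fold reuses %a instead of materializing a new 'not'.)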
    if (Value *NotA = IC.getFreelyInverted(A, A->hasOneUse(), &IC.Builder))
      return new ICmpInst(Pred, IC.Builder.CreateOr(Op1, NotA),
                          Constant::getAllOnesValue(Op1->getType()));
  }
  return nullptr;
}

static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q,
                                  InstCombinerImpl &IC) {
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A;
  // Normalize the 'xor' to be operand 0.
  CmpInst::Predicate Pred = I.getPredicate();
  if (match(Op1, m_c_Xor(m_Specific(Op0), m_Value()))) {
    std::swap(Op0, Op1);
    Pred = ICmpInst::getSwappedPredicate(Pred);
  }
  if (!match(Op0, m_c_Xor(m_Specific(Op1), m_Value(A))))
    return nullptr;

  // icmp (X ^ Y_NonZero) u>= X --> icmp (X ^ Y_NonZero) u> X
  // icmp (X ^ Y_NonZero) u<= X --> icmp (X ^ Y_NonZero) u< X
  // icmp (X ^ Y_NonZero) s>= X --> icmp (X ^ Y_NonZero) s> X
  // icmp (X ^ Y_NonZero) s<= X --> icmp (X ^ Y_NonZero) s< X
  CmpInst::Predicate PredOut = CmpInst::getStrictPredicate(Pred);
  if (PredOut != Pred &&
      isKnownNonZero(A, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
    return new ICmpInst(PredOut, Op0, Op1);

  return nullptr;
}

/// Try to fold icmp (binop), X or icmp X, (binop).
/// TODO: A large part of this logic is duplicated in InstSimplify's
/// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
/// duplication.
Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
                                             const SimplifyQuery &SQ) {
  const SimplifyQuery Q = SQ.getWithInstruction(&I);
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);

  // Special logic for binary operators.
  BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
  BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
  if (!BO0 && !BO1)
    return nullptr;

  if (Instruction *NewICmp = foldICmpXNegX(I, Builder))
    return NewICmp;

  const CmpInst::Predicate Pred = I.getPredicate();
  Value *X;

  // Convert add-with-unsigned-overflow comparisons into a 'not' with compare.
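  // (Illustrative: "a + b u< a" asks whether the addition wrapped, and
  // "~a u< b" asks whether b exceeds the headroom left in a; both are the
  // same question, but the latter form needs no add.)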
4580 // (Op1 + X) u</u>= Op1 --> ~Op1 u</u>= X 4581 if (match(Op0, m_OneUse(m_c_Add(m_Specific(Op1), m_Value(X)))) && 4582 (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) 4583 return new ICmpInst(Pred, Builder.CreateNot(Op1), X); 4584 // Op0 u>/u<= (Op0 + X) --> X u>/u<= ~Op0 4585 if (match(Op1, m_OneUse(m_c_Add(m_Specific(Op0), m_Value(X)))) && 4586 (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE)) 4587 return new ICmpInst(Pred, X, Builder.CreateNot(Op0)); 4588 4589 { 4590 // (Op1 + X) + C u</u>= Op1 --> ~C - X u</u>= Op1 4591 Constant *C; 4592 if (match(Op0, m_OneUse(m_Add(m_c_Add(m_Specific(Op1), m_Value(X)), 4593 m_ImmConstant(C)))) && 4594 (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) { 4595 Constant *C2 = ConstantExpr::getNot(C); 4596 return new ICmpInst(Pred, Builder.CreateSub(C2, X), Op1); 4597 } 4598 // Op0 u>/u<= (Op0 + X) + C --> Op0 u>/u<= ~C - X 4599 if (match(Op1, m_OneUse(m_Add(m_c_Add(m_Specific(Op0), m_Value(X)), 4600 m_ImmConstant(C)))) && 4601 (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE)) { 4602 Constant *C2 = ConstantExpr::getNot(C); 4603 return new ICmpInst(Pred, Op0, Builder.CreateSub(C2, X)); 4604 } 4605 } 4606 4607 { 4608 // Similar to above: an unsigned overflow comparison may use offset + mask: 4609 // ((Op1 + C) & C) u< Op1 --> Op1 != 0 4610 // ((Op1 + C) & C) u>= Op1 --> Op1 == 0 4611 // Op0 u> ((Op0 + C) & C) --> Op0 != 0 4612 // Op0 u<= ((Op0 + C) & C) --> Op0 == 0 4613 BinaryOperator *BO; 4614 const APInt *C; 4615 if ((Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) && 4616 match(Op0, m_And(m_BinOp(BO), m_LowBitMask(C))) && 4617 match(BO, m_Add(m_Specific(Op1), m_SpecificIntAllowUndef(*C)))) { 4618 CmpInst::Predicate NewPred = 4619 Pred == ICmpInst::ICMP_ULT ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; 4620 Constant *Zero = ConstantInt::getNullValue(Op1->getType()); 4621 return new ICmpInst(NewPred, Op1, Zero); 4622 } 4623 4624 if ((Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE) && 4625 match(Op1, m_And(m_BinOp(BO), m_LowBitMask(C))) && 4626 match(BO, m_Add(m_Specific(Op0), m_SpecificIntAllowUndef(*C)))) { 4627 CmpInst::Predicate NewPred = 4628 Pred == ICmpInst::ICMP_UGT ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; 4629 Constant *Zero = ConstantInt::getNullValue(Op1->getType()); 4630 return new ICmpInst(NewPred, Op0, Zero); 4631 } 4632 } 4633 4634 bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; 4635 bool Op0HasNUW = false, Op1HasNUW = false; 4636 bool Op0HasNSW = false, Op1HasNSW = false; 4637 // Analyze the case when either Op0 or Op1 is an add instruction. 4638 // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). 
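  // (For example, "icmp ult (add %p, %q), %r" binds A=%p and B=%q and leaves
  // C and D null; per the lambda below, an 'or' is treated as an add that can
  // never wrap. Illustrative only.)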
4639 auto hasNoWrapProblem = [](const BinaryOperator &BO, CmpInst::Predicate Pred, 4640 bool &HasNSW, bool &HasNUW) -> bool { 4641 if (isa<OverflowingBinaryOperator>(BO)) { 4642 HasNUW = BO.hasNoUnsignedWrap(); 4643 HasNSW = BO.hasNoSignedWrap(); 4644 return ICmpInst::isEquality(Pred) || 4645 (CmpInst::isUnsigned(Pred) && HasNUW) || 4646 (CmpInst::isSigned(Pred) && HasNSW); 4647 } else if (BO.getOpcode() == Instruction::Or) { 4648 HasNUW = true; 4649 HasNSW = true; 4650 return true; 4651 } else { 4652 return false; 4653 } 4654 }; 4655 Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; 4656 4657 if (BO0) { 4658 match(BO0, m_AddLike(m_Value(A), m_Value(B))); 4659 NoOp0WrapProblem = hasNoWrapProblem(*BO0, Pred, Op0HasNSW, Op0HasNUW); 4660 } 4661 if (BO1) { 4662 match(BO1, m_AddLike(m_Value(C), m_Value(D))); 4663 NoOp1WrapProblem = hasNoWrapProblem(*BO1, Pred, Op1HasNSW, Op1HasNUW); 4664 } 4665 4666 // icmp (A+B), A -> icmp B, 0 for equalities or if there is no overflow. 4667 // icmp (A+B), B -> icmp A, 0 for equalities or if there is no overflow. 4668 if ((A == Op1 || B == Op1) && NoOp0WrapProblem) 4669 return new ICmpInst(Pred, A == Op1 ? B : A, 4670 Constant::getNullValue(Op1->getType())); 4671 4672 // icmp C, (C+D) -> icmp 0, D for equalities or if there is no overflow. 4673 // icmp D, (C+D) -> icmp 0, C for equalities or if there is no overflow. 4674 if ((C == Op0 || D == Op0) && NoOp1WrapProblem) 4675 return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), 4676 C == Op0 ? D : C); 4677 4678 // icmp (A+B), (A+D) -> icmp B, D for equalities or if there is no overflow. 4679 if (A && C && (A == C || A == D || B == C || B == D) && NoOp0WrapProblem && 4680 NoOp1WrapProblem) { 4681 // Determine Y and Z in the form icmp (X+Y), (X+Z). 4682 Value *Y, *Z; 4683 if (A == C) { 4684 // C + B == C + D -> B == D 4685 Y = B; 4686 Z = D; 4687 } else if (A == D) { 4688 // D + B == C + D -> B == C 4689 Y = B; 4690 Z = C; 4691 } else if (B == C) { 4692 // A + C == C + D -> A == D 4693 Y = A; 4694 Z = D; 4695 } else { 4696 assert(B == D); 4697 // A + D == C + D -> A == C 4698 Y = A; 4699 Z = C; 4700 } 4701 return new ICmpInst(Pred, Y, Z); 4702 } 4703 4704 // icmp slt (A + -1), Op1 -> icmp sle A, Op1 4705 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT && 4706 match(B, m_AllOnes())) 4707 return new ICmpInst(CmpInst::ICMP_SLE, A, Op1); 4708 4709 // icmp sge (A + -1), Op1 -> icmp sgt A, Op1 4710 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE && 4711 match(B, m_AllOnes())) 4712 return new ICmpInst(CmpInst::ICMP_SGT, A, Op1); 4713 4714 // icmp sle (A + 1), Op1 -> icmp slt A, Op1 4715 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && match(B, m_One())) 4716 return new ICmpInst(CmpInst::ICMP_SLT, A, Op1); 4717 4718 // icmp sgt (A + 1), Op1 -> icmp sge A, Op1 4719 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && match(B, m_One())) 4720 return new ICmpInst(CmpInst::ICMP_SGE, A, Op1); 4721 4722 // icmp sgt Op0, (C + -1) -> icmp sge Op0, C 4723 if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT && 4724 match(D, m_AllOnes())) 4725 return new ICmpInst(CmpInst::ICMP_SGE, Op0, C); 4726 4727 // icmp sle Op0, (C + -1) -> icmp slt Op0, C 4728 if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE && 4729 match(D, m_AllOnes())) 4730 return new ICmpInst(CmpInst::ICMP_SLT, Op0, C); 4731 4732 // icmp sge Op0, (C + 1) -> icmp sgt Op0, C 4733 if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && match(D, m_One())) 4734 return new ICmpInst(CmpInst::ICMP_SGT, Op0, C); 4735 
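  // (These +1/-1 rewrites rely on identities such as "x + 1 s<= y iff
  // x s< y", which hold only because the no-wrap checks rule out wrapping;
  // e.g. for i8, 127 + 1 would otherwise wrap to -128 and break the
  // equivalence. Illustrative note.)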
  // icmp slt Op0, (C + 1) -> icmp sle Op0, C
  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One()))
    return new ICmpInst(CmpInst::ICMP_SLE, Op0, C);

  // TODO: The subtraction-related identities shown below also hold, but
  // canonicalization from (X -nuw 1) to (X + -1) means that the combinations
  // wouldn't happen even if they were implemented.
  //
  // icmp ult (A - 1), Op1 -> icmp ule A, Op1
  // icmp uge (A - 1), Op1 -> icmp ugt A, Op1
  // icmp ugt Op0, (C - 1) -> icmp uge Op0, C
  // icmp ule Op0, (C - 1) -> icmp ult Op0, C

  // icmp ule (A + 1), Op1 -> icmp ult A, Op1
  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One()))
    return new ICmpInst(CmpInst::ICMP_ULT, A, Op1);

  // icmp ugt (A + 1), Op1 -> icmp uge A, Op1
  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One()))
    return new ICmpInst(CmpInst::ICMP_UGE, A, Op1);

  // icmp uge Op0, (C + 1) -> icmp ugt Op0, C
  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One()))
    return new ICmpInst(CmpInst::ICMP_UGT, Op0, C);

  // icmp ult Op0, (C + 1) -> icmp ule Op0, C
  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One()))
    return new ICmpInst(CmpInst::ICMP_ULE, Op0, C);

  // If C1 has greater magnitude than C2:
  //   icmp (A + C1), (C + C2) -> icmp (A + C3), C
  //   s.t. C3 = C1 - C2
  //
  // If C2 has greater magnitude than C1:
  //   icmp (A + C1), (C + C2) -> icmp A, (C + C3)
  //   s.t. C3 = C2 - C1
  if (A && C && NoOp0WrapProblem && NoOp1WrapProblem &&
      (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) {
    const APInt *AP1, *AP2;
    // TODO: Support non-uniform vectors.
    // TODO: Allow undef passthrough if B and D's elements are undef.
    if (match(B, m_APIntAllowUndef(AP1)) && match(D, m_APIntAllowUndef(AP2)) &&
        AP1->isNegative() == AP2->isNegative()) {
      APInt AP1Abs = AP1->abs();
      APInt AP2Abs = AP2->abs();
      if (AP1Abs.uge(AP2Abs)) {
        APInt Diff = *AP1 - *AP2;
        Constant *C3 = Constant::getIntegerValue(BO0->getType(), Diff);
        Value *NewAdd = Builder.CreateAdd(
            A, C3, "", Op0HasNUW && Diff.ule(*AP1), Op0HasNSW);
        return new ICmpInst(Pred, NewAdd, C);
      } else {
        APInt Diff = *AP2 - *AP1;
        Constant *C3 = Constant::getIntegerValue(BO0->getType(), Diff);
        Value *NewAdd = Builder.CreateAdd(
            C, C3, "", Op1HasNUW && Diff.ule(*AP2), Op1HasNSW);
        return new ICmpInst(Pred, A, NewAdd);
      }
    }
    Constant *Cst1, *Cst2;
    if (match(B, m_ImmConstant(Cst1)) && match(D, m_ImmConstant(Cst2)) &&
        ICmpInst::isEquality(Pred)) {
      Constant *Diff = ConstantExpr::getSub(Cst2, Cst1);
      Value *NewAdd = Builder.CreateAdd(C, Diff);
      return new ICmpInst(Pred, A, NewAdd);
    }
  }

  // Analyze the case when either Op0 or Op1 is a sub instruction.
  // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null).
  A = nullptr;
  B = nullptr;
  C = nullptr;
  D = nullptr;
  if (BO0 && BO0->getOpcode() == Instruction::Sub) {
    A = BO0->getOperand(0);
    B = BO0->getOperand(1);
  }
  if (BO1 && BO1->getOpcode() == Instruction::Sub) {
    C = BO1->getOperand(0);
    D = BO1->getOperand(1);
  }

  // icmp (A-B), A -> icmp 0, B for equalities or if there is no overflow.
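  // (E.g. "icmp eq (sub %a, %b), %a" holds exactly when %b == 0, so the
  // whole subtract disappears. Illustrative.)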
  if (A == Op1 && NoOp0WrapProblem)
    return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B);
  // icmp C, (C-D) -> icmp D, 0 for equalities or if there is no overflow.
  if (C == Op0 && NoOp1WrapProblem)
    return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType()));

  // Convert sub-with-unsigned-overflow comparisons into a comparison of args.
  // (A - B) u>/u<= A --> B u>/u<= A
  if (A == Op1 && (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE))
    return new ICmpInst(Pred, B, A);
  // C u</u>= (C - D) --> C u</u>= D
  if (C == Op0 && (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE))
    return new ICmpInst(Pred, C, D);
  // (A - B) u>=/u< A --> B u>/u<= A  iff B != 0
  if (A == Op1 && (Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) &&
      isKnownNonZero(B, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
    return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), B, A);
  // C u<=/u> (C - D) --> C u</u>= D  iff D != 0
  if (C == Op0 && (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) &&
      isKnownNonZero(D, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
    return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), C, D);

  // icmp (A-B), (C-B) -> icmp A, C for equalities or if there is no overflow.
  if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem)
    return new ICmpInst(Pred, A, C);

  // icmp (A-B), (A-D) -> icmp D, B for equalities or if there is no overflow.
  if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem)
    return new ICmpInst(Pred, D, B);

  // icmp (0-X) s< cst --> X s> -cst
  if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) {
    Value *X;
    if (match(BO0, m_Neg(m_Value(X))))
      if (Constant *RHSC = dyn_cast<Constant>(Op1))
        if (RHSC->isNotMinSignedValue())
          return new ICmpInst(I.getSwappedPredicate(), X,
                              ConstantExpr::getNeg(RHSC));
  }

  if (Instruction *R = foldICmpXorXX(I, Q, *this))
    return R;
  if (Instruction *R = foldICmpOrXX(I, Q, *this))
    return R;

  {
    // Try to remove a shared multiplier from the comparison:
    // X * Z u{lt/le/gt/ge}/eq/ne Y * Z
    Value *X, *Y, *Z;
    if (Pred == ICmpInst::getUnsignedPredicate(Pred) &&
        ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) &&
          match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) ||
         (match(Op0, m_Mul(m_Value(Z), m_Value(X))) &&
          match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))))) {
      bool NonZero;
      if (ICmpInst::isEquality(Pred)) {
        KnownBits ZKnown = computeKnownBits(Z, 0, &I);
        // If Z % 2 != 0:
        //   X * Z eq/ne Y * Z -> X eq/ne Y
        if (ZKnown.countMaxTrailingZeros() == 0)
          return new ICmpInst(Pred, X, Y);
        NonZero = !ZKnown.One.isZero() ||
                  isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
        // If Z != 0 and nsw(X * Z) and nsw(Y * Z):
        //   X * Z eq/ne Y * Z -> X eq/ne Y
        if (NonZero && BO0 && BO1 && Op0HasNSW && Op1HasNSW)
          return new ICmpInst(Pred, X, Y);
      } else
        NonZero = isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);

      // If Z != 0 and nuw(X * Z) and nuw(Y * Z):
      //   X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y
      if (NonZero && BO0 && BO1 && Op0HasNUW && Op1HasNUW)
        return new ICmpInst(Pred, X, Y);
    }
  }

  BinaryOperator *SRem = nullptr;
  // icmp (srem X, Y), Y
  if (BO0 && BO0->getOpcode() == Instruction::SRem && Op1 == BO0->getOperand(1))
    SRem =
BO0; 4901 // icmp Y, (srem X, Y) 4902 else if (BO1 && BO1->getOpcode() == Instruction::SRem && 4903 Op0 == BO1->getOperand(1)) 4904 SRem = BO1; 4905 if (SRem) { 4906 // We don't check hasOneUse to avoid increasing register pressure because 4907 // the value we use is the same value this instruction was already using. 4908 switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) { 4909 default: 4910 break; 4911 case ICmpInst::ICMP_EQ: 4912 return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); 4913 case ICmpInst::ICMP_NE: 4914 return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); 4915 case ICmpInst::ICMP_SGT: 4916 case ICmpInst::ICMP_SGE: 4917 return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1), 4918 Constant::getAllOnesValue(SRem->getType())); 4919 case ICmpInst::ICMP_SLT: 4920 case ICmpInst::ICMP_SLE: 4921 return new ICmpInst(ICmpInst::ICMP_SLT, SRem->getOperand(1), 4922 Constant::getNullValue(SRem->getType())); 4923 } 4924 } 4925 4926 if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() && 4927 (BO0->hasOneUse() || BO1->hasOneUse()) && 4928 BO0->getOperand(1) == BO1->getOperand(1)) { 4929 switch (BO0->getOpcode()) { 4930 default: 4931 break; 4932 case Instruction::Add: 4933 case Instruction::Sub: 4934 case Instruction::Xor: { 4935 if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b 4936 return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); 4937 4938 const APInt *C; 4939 if (match(BO0->getOperand(1), m_APInt(C))) { 4940 // icmp u/s (a ^ signmask), (b ^ signmask) --> icmp s/u a, b 4941 if (C->isSignMask()) { 4942 ICmpInst::Predicate NewPred = I.getFlippedSignednessPredicate(); 4943 return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0)); 4944 } 4945 4946 // icmp u/s (a ^ maxsignval), (b ^ maxsignval) --> icmp s/u' a, b 4947 if (BO0->getOpcode() == Instruction::Xor && C->isMaxSignedValue()) { 4948 ICmpInst::Predicate NewPred = I.getFlippedSignednessPredicate(); 4949 NewPred = I.getSwappedPredicate(NewPred); 4950 return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0)); 4951 } 4952 } 4953 break; 4954 } 4955 case Instruction::Mul: { 4956 if (!I.isEquality()) 4957 break; 4958 4959 const APInt *C; 4960 if (match(BO0->getOperand(1), m_APInt(C)) && !C->isZero() && 4961 !C->isOne()) { 4962 // icmp eq/ne (X * C), (Y * C) --> icmp (X & Mask), (Y & Mask) 4963 // Mask = -1 >> count-trailing-zeros(C). 
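        // (Illustrative: for i8 and C = 12 = 0b1100, the odd factor 3 is
        // invertible mod 2^8, so X*12 == Y*12 iff X*4 == Y*4, i.e. iff the
        // low 6 bits of X and Y agree; Mask is 0x3f.)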
4964 if (unsigned TZs = C->countr_zero()) { 4965 Constant *Mask = ConstantInt::get( 4966 BO0->getType(), 4967 APInt::getLowBitsSet(C->getBitWidth(), C->getBitWidth() - TZs)); 4968 Value *And1 = Builder.CreateAnd(BO0->getOperand(0), Mask); 4969 Value *And2 = Builder.CreateAnd(BO1->getOperand(0), Mask); 4970 return new ICmpInst(Pred, And1, And2); 4971 } 4972 } 4973 break; 4974 } 4975 case Instruction::UDiv: 4976 case Instruction::LShr: 4977 if (I.isSigned() || !BO0->isExact() || !BO1->isExact()) 4978 break; 4979 return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); 4980 4981 case Instruction::SDiv: 4982 if (!(I.isEquality() || match(BO0->getOperand(1), m_NonNegative())) || 4983 !BO0->isExact() || !BO1->isExact()) 4984 break; 4985 return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); 4986 4987 case Instruction::AShr: 4988 if (!BO0->isExact() || !BO1->isExact()) 4989 break; 4990 return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); 4991 4992 case Instruction::Shl: { 4993 bool NUW = Op0HasNUW && Op1HasNUW; 4994 bool NSW = Op0HasNSW && Op1HasNSW; 4995 if (!NUW && !NSW) 4996 break; 4997 if (!NSW && I.isSigned()) 4998 break; 4999 return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); 5000 } 5001 } 5002 } 5003 5004 if (BO0) { 5005 // Transform A & (L - 1) `ult` L --> L != 0 5006 auto LSubOne = m_Add(m_Specific(Op1), m_AllOnes()); 5007 auto BitwiseAnd = m_c_And(m_Value(), LSubOne); 5008 5009 if (match(BO0, BitwiseAnd) && Pred == ICmpInst::ICMP_ULT) { 5010 auto *Zero = Constant::getNullValue(BO0->getType()); 5011 return new ICmpInst(ICmpInst::ICMP_NE, Op1, Zero); 5012 } 5013 } 5014 5015 // For unsigned predicates / eq / ne: 5016 // icmp pred (x << 1), x --> icmp getSignedPredicate(pred) x, 0 5017 // icmp pred x, (x << 1) --> icmp getSignedPredicate(pred) 0, x 5018 if (!ICmpInst::isSigned(Pred)) { 5019 if (match(Op0, m_Shl(m_Specific(Op1), m_One()))) 5020 return new ICmpInst(ICmpInst::getSignedPredicate(Pred), Op1, 5021 Constant::getNullValue(Op1->getType())); 5022 else if (match(Op1, m_Shl(m_Specific(Op0), m_One()))) 5023 return new ICmpInst(ICmpInst::getSignedPredicate(Pred), 5024 Constant::getNullValue(Op0->getType()), Op0); 5025 } 5026 5027 if (Value *V = foldMultiplicationOverflowCheck(I)) 5028 return replaceInstUsesWith(I, V); 5029 5030 if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder)) 5031 return replaceInstUsesWith(I, V); 5032 5033 if (Instruction *R = foldICmpAndXX(I, Q, *this)) 5034 return R; 5035 5036 if (Value *V = foldICmpWithTruncSignExtendedVal(I, Builder)) 5037 return replaceInstUsesWith(I, V); 5038 5039 if (Value *V = foldShiftIntoShiftInAnotherHandOfAndInICmp(I, SQ, Builder)) 5040 return replaceInstUsesWith(I, V); 5041 5042 return nullptr; 5043 } 5044 5045 /// Fold icmp Pred min|max(X, Y), Z. 5046 Instruction *InstCombinerImpl::foldICmpWithMinMax(Instruction &I, 5047 MinMaxIntrinsic *MinMax, 5048 Value *Z, 5049 ICmpInst::Predicate Pred) { 5050 Value *X = MinMax->getLHS(); 5051 Value *Y = MinMax->getRHS(); 5052 if (ICmpInst::isSigned(Pred) && !MinMax->isSigned()) 5053 return nullptr; 5054 if (ICmpInst::isUnsigned(Pred) && MinMax->isSigned()) { 5055 // Revert the transform signed pred -> unsigned pred 5056 // TODO: We can flip the signedness of predicate if both operands of icmp 5057 // are negative. 
5058 if (isKnownNonNegative(Z, SQ.getWithInstruction(&I)) && 5059 isKnownNonNegative(MinMax, SQ.getWithInstruction(&I))) { 5060 Pred = ICmpInst::getFlippedSignednessPredicate(Pred); 5061 } else 5062 return nullptr; 5063 } 5064 SimplifyQuery Q = SQ.getWithInstruction(&I); 5065 auto IsCondKnownTrue = [](Value *Val) -> std::optional<bool> { 5066 if (!Val) 5067 return std::nullopt; 5068 if (match(Val, m_One())) 5069 return true; 5070 if (match(Val, m_Zero())) 5071 return false; 5072 return std::nullopt; 5073 }; 5074 auto CmpXZ = IsCondKnownTrue(simplifyICmpInst(Pred, X, Z, Q)); 5075 auto CmpYZ = IsCondKnownTrue(simplifyICmpInst(Pred, Y, Z, Q)); 5076 if (!CmpXZ.has_value() && !CmpYZ.has_value()) 5077 return nullptr; 5078 if (!CmpXZ.has_value()) { 5079 std::swap(X, Y); 5080 std::swap(CmpXZ, CmpYZ); 5081 } 5082 5083 auto FoldIntoCmpYZ = [&]() -> Instruction * { 5084 if (CmpYZ.has_value()) 5085 return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *CmpYZ)); 5086 return ICmpInst::Create(Instruction::ICmp, Pred, Y, Z); 5087 }; 5088 5089 switch (Pred) { 5090 case ICmpInst::ICMP_EQ: 5091 case ICmpInst::ICMP_NE: { 5092 // If X == Z: 5093 // Expr Result 5094 // min(X, Y) == Z X <= Y 5095 // max(X, Y) == Z X >= Y 5096 // min(X, Y) != Z X > Y 5097 // max(X, Y) != Z X < Y 5098 if ((Pred == ICmpInst::ICMP_EQ) == *CmpXZ) { 5099 ICmpInst::Predicate NewPred = 5100 ICmpInst::getNonStrictPredicate(MinMax->getPredicate()); 5101 if (Pred == ICmpInst::ICMP_NE) 5102 NewPred = ICmpInst::getInversePredicate(NewPred); 5103 return ICmpInst::Create(Instruction::ICmp, NewPred, X, Y); 5104 } 5105 // Otherwise (X != Z): 5106 ICmpInst::Predicate NewPred = MinMax->getPredicate(); 5107 auto MinMaxCmpXZ = IsCondKnownTrue(simplifyICmpInst(NewPred, X, Z, Q)); 5108 if (!MinMaxCmpXZ.has_value()) { 5109 std::swap(X, Y); 5110 std::swap(CmpXZ, CmpYZ); 5111 // Re-check pre-condition X != Z 5112 if (!CmpXZ.has_value() || (Pred == ICmpInst::ICMP_EQ) == *CmpXZ) 5113 break; 5114 MinMaxCmpXZ = IsCondKnownTrue(simplifyICmpInst(NewPred, X, Z, Q)); 5115 } 5116 if (!MinMaxCmpXZ.has_value()) 5117 break; 5118 if (*MinMaxCmpXZ) { 5119 // Expr Fact Result 5120 // min(X, Y) == Z X < Z false 5121 // max(X, Y) == Z X > Z false 5122 // min(X, Y) != Z X < Z true 5123 // max(X, Y) != Z X > Z true 5124 return replaceInstUsesWith( 5125 I, ConstantInt::getBool(I.getType(), Pred == ICmpInst::ICMP_NE)); 5126 } else { 5127 // Expr Fact Result 5128 // min(X, Y) == Z X > Z Y == Z 5129 // max(X, Y) == Z X < Z Y == Z 5130 // min(X, Y) != Z X > Z Y != Z 5131 // max(X, Y) != Z X < Z Y != Z 5132 return FoldIntoCmpYZ(); 5133 } 5134 break; 5135 } 5136 case ICmpInst::ICMP_SLT: 5137 case ICmpInst::ICMP_ULT: 5138 case ICmpInst::ICMP_SLE: 5139 case ICmpInst::ICMP_ULE: 5140 case ICmpInst::ICMP_SGT: 5141 case ICmpInst::ICMP_UGT: 5142 case ICmpInst::ICMP_SGE: 5143 case ICmpInst::ICMP_UGE: { 5144 bool IsSame = MinMax->getPredicate() == ICmpInst::getStrictPredicate(Pred); 5145 if (*CmpXZ) { 5146 if (IsSame) { 5147 // Expr Fact Result 5148 // min(X, Y) < Z X < Z true 5149 // min(X, Y) <= Z X <= Z true 5150 // max(X, Y) > Z X > Z true 5151 // max(X, Y) >= Z X >= Z true 5152 return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); 5153 } else { 5154 // Expr Fact Result 5155 // max(X, Y) < Z X < Z Y < Z 5156 // max(X, Y) <= Z X <= Z Y <= Z 5157 // min(X, Y) > Z X > Z Y > Z 5158 // min(X, Y) >= Z X >= Z Y >= Z 5159 return FoldIntoCmpYZ(); 5160 } 5161 } else { 5162 if (IsSame) { 5163 // Expr Fact Result 5164 // min(X, Y) < Z X >= Z Y < Z 5165 // min(X, Y) <= Z X > 
Z Y <= Z 5166 // max(X, Y) > Z X <= Z Y > Z 5167 // max(X, Y) >= Z X < Z Y >= Z 5168 return FoldIntoCmpYZ(); 5169 } else { 5170 // Expr Fact Result 5171 // max(X, Y) < Z X >= Z false 5172 // max(X, Y) <= Z X > Z false 5173 // min(X, Y) > Z X <= Z false 5174 // min(X, Y) >= Z X < Z false 5175 return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); 5176 } 5177 } 5178 break; 5179 } 5180 default: 5181 break; 5182 } 5183 5184 return nullptr; 5185 } 5186 5187 // Canonicalize checking for a power-of-2-or-zero value: 5188 static Instruction *foldICmpPow2Test(ICmpInst &I, 5189 InstCombiner::BuilderTy &Builder) { 5190 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); 5191 const CmpInst::Predicate Pred = I.getPredicate(); 5192 Value *A = nullptr; 5193 bool CheckIs; 5194 if (I.isEquality()) { 5195 // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants) 5196 // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants) 5197 if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()), 5198 m_Deferred(A)))) || 5199 !match(Op1, m_ZeroInt())) 5200 A = nullptr; 5201 5202 // (A & -A) == A --> ctpop(A) < 2 (four commuted variants) 5203 // (-A & A) != A --> ctpop(A) > 1 (four commuted variants) 5204 if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1))))) 5205 A = Op1; 5206 else if (match(Op1, 5207 m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0))))) 5208 A = Op0; 5209 5210 CheckIs = Pred == ICmpInst::ICMP_EQ; 5211 } else if (ICmpInst::isUnsigned(Pred)) { 5212 // (A ^ (A-1)) u>= A --> ctpop(A) < 2 (two commuted variants) 5213 // ((A-1) ^ A) u< A --> ctpop(A) > 1 (two commuted variants) 5214 5215 if ((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) && 5216 match(Op0, m_OneUse(m_c_Xor(m_Add(m_Specific(Op1), m_AllOnes()), 5217 m_Specific(Op1))))) { 5218 A = Op1; 5219 CheckIs = Pred == ICmpInst::ICMP_UGE; 5220 } else if ((Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE) && 5221 match(Op1, m_OneUse(m_c_Xor(m_Add(m_Specific(Op0), m_AllOnes()), 5222 m_Specific(Op0))))) { 5223 A = Op0; 5224 CheckIs = Pred == ICmpInst::ICMP_ULE; 5225 } 5226 } 5227 5228 if (A) { 5229 Type *Ty = A->getType(); 5230 CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A); 5231 return CheckIs ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, 5232 ConstantInt::get(Ty, 2)) 5233 : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, 5234 ConstantInt::get(Ty, 1)); 5235 } 5236 5237 return nullptr; 5238 } 5239 5240 Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { 5241 if (!I.isEquality()) 5242 return nullptr; 5243 5244 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); 5245 const CmpInst::Predicate Pred = I.getPredicate(); 5246 Value *A, *B, *C, *D; 5247 if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { 5248 if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 5249 Value *OtherVal = A == Op1 ? 
                                         B : A;
    return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType()));
  }

  if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) {
    // A^c1 == C^c2 --> A == C^(c1^c2)
    ConstantInt *C1, *C2;
    if (match(B, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2)) &&
        Op1->hasOneUse()) {
      Constant *NC = Builder.getInt(C1->getValue() ^ C2->getValue());
      Value *Xor = Builder.CreateXor(C, NC);
      return new ICmpInst(Pred, A, Xor);
    }

    // A^B == A^D -> B == D
    if (A == C)
      return new ICmpInst(Pred, B, D);
    if (A == D)
      return new ICmpInst(Pred, B, C);
    if (B == C)
      return new ICmpInst(Pred, A, D);
    if (B == D)
      return new ICmpInst(Pred, A, C);
  }
  }

  // Canonicalize:
  // (icmp eq/ne (and X, C), X)
  //   -> (icmp eq/ne (and X, ~C), 0)
  {
    Constant *CMask;
    A = nullptr;
    if (match(Op0, m_OneUse(m_And(m_Specific(Op1), m_ImmConstant(CMask)))))
      A = Op1;
    else if (match(Op1, m_OneUse(m_And(m_Specific(Op0), m_ImmConstant(CMask)))))
      A = Op0;
    if (A)
      return new ICmpInst(Pred, Builder.CreateAnd(A, Builder.CreateNot(CMask)),
                          Constant::getNullValue(A->getType()));
  }

  if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) {
    // A == (A^B) -> B == 0
    Value *OtherVal = A == Op0 ? B : A;
    return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType()));
  }

  // (X&Z) == (Y&Z) -> (X^Y) & Z == 0
  if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) &&
      match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) {
    Value *X = nullptr, *Y = nullptr, *Z = nullptr;

    if (A == C) {
      X = B;
      Y = D;
      Z = A;
    } else if (A == D) {
      X = B;
      Y = C;
      Z = A;
    } else if (B == C) {
      X = A;
      Y = D;
      Z = B;
    } else if (B == D) {
      X = A;
      Y = C;
      Z = B;
    }

    if (X) { // Build (X^Y) & Z
      Op1 = Builder.CreateXor(X, Y);
      Op1 = Builder.CreateAnd(Op1, Z);
      return new ICmpInst(Pred, Op1, Constant::getNullValue(Op1->getType()));
    }
  }

  {
    // Similar to above, but specialized for constant because invert is needed:
    // (X | C) == (Y | C) --> (X ^ Y) & ~C == 0
    Value *X, *Y;
    Constant *C;
    if (match(Op0, m_OneUse(m_Or(m_Value(X), m_Constant(C)))) &&
        match(Op1, m_OneUse(m_Or(m_Value(Y), m_Specific(C))))) {
      Value *Xor = Builder.CreateXor(X, Y);
      Value *And = Builder.CreateAnd(Xor, ConstantExpr::getNot(C));
      return new ICmpInst(Pred, And, Constant::getNullValue(And->getType()));
    }
  }

  if (match(Op1, m_ZExt(m_Value(A))) &&
      (Op0->hasOneUse() || Op1->hasOneUse())) {
    // (B & (Pow2C-1)) == zext A --> A == trunc B
    // (B & (Pow2C-1)) != zext A --> A != trunc B
    const APInt *MaskC;
    if (match(Op0, m_And(m_Value(B), m_LowBitMask(MaskC))) &&
        MaskC->countr_one() == A->getType()->getScalarSizeInBits())
      return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType()));
  }

  // (A >> C) == (B >> C) --> (A^B) u< (1 << C)
  // For lshr and ashr pairs.
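  // (Illustrative: the shifted values agree iff A and B match on every bit
  // at or above position C, i.e. iff A^B has set bits only below bit C,
  // which is exactly "(A^B) u< (1 << C)".)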
  const APInt *AP1, *AP2;
  if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_APIntAllowUndef(AP1)))) &&
       match(Op1, m_OneUse(m_LShr(m_Value(B), m_APIntAllowUndef(AP2))))) ||
      (match(Op0, m_OneUse(m_AShr(m_Value(A), m_APIntAllowUndef(AP1)))) &&
       match(Op1, m_OneUse(m_AShr(m_Value(B), m_APIntAllowUndef(AP2)))))) {
    if (AP1 != AP2)
      return nullptr;
    unsigned TypeBits = AP1->getBitWidth();
    unsigned ShAmt = AP1->getLimitedValue(TypeBits);
    if (ShAmt < TypeBits && ShAmt != 0) {
      ICmpInst::Predicate NewPred =
          Pred == ICmpInst::ICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT;
      Value *Xor = Builder.CreateXor(A, B, I.getName() + ".unshifted");
      APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt);
      return new ICmpInst(NewPred, Xor, ConstantInt::get(A->getType(), CmpVal));
    }
  }

  // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0
  ConstantInt *Cst1;
  if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) &&
      match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) {
    unsigned TypeBits = Cst1->getBitWidth();
    unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits);
    if (ShAmt < TypeBits && ShAmt != 0) {
      Value *Xor = Builder.CreateXor(A, B, I.getName() + ".unshifted");
      APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt);
      Value *And = Builder.CreateAnd(Xor, Builder.getInt(AndVal),
                                     I.getName() + ".mask");
      return new ICmpInst(Pred, And, Constant::getNullValue(Cst1->getType()));
    }
  }

  // Transform "icmp eq (trunc (lshr X, cst1)), cst2" into
  // "icmp eq (and X, mask), (cst2 << cst1)".
  uint64_t ShAmt = 0;
  if (Op0->hasOneUse() &&
      match(Op0, m_Trunc(m_OneUse(m_LShr(m_Value(A), m_ConstantInt(ShAmt))))) &&
      match(Op1, m_ConstantInt(Cst1)) &&
      // Only do this when A has multiple uses. This is most important to do
      // when it exposes other optimizations.
      !A->hasOneUse()) {
    unsigned ASize = cast<IntegerType>(A->getType())->getPrimitiveSizeInBits();

    if (ShAmt < ASize) {
      APInt MaskV =
          APInt::getLowBitsSet(ASize, Op0->getType()->getPrimitiveSizeInBits());
      MaskV <<= ShAmt;

      APInt CmpV = Cst1->getValue().zext(ASize);
      CmpV <<= ShAmt;

      Value *Mask = Builder.CreateAnd(A, Builder.getInt(MaskV));
      return new ICmpInst(Pred, Mask, Builder.getInt(CmpV));
    }
  }

  if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I, Builder))
    return ICmp;

  // Match icmp eq (trunc (lshr A, BW)), (ashr (trunc A), BW-1), which checks
  // that the top BW/2 + 1 bits are all the same. Create "A >=s INT_MIN &&
  // A <=s INT_MAX", which we generate as "icmp ult (add A, 2^(BW-1)), 2^BW"
  // to skip a few steps of instcombine.
  unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
  if (match(Op0, m_AShr(m_Trunc(m_Value(A)), m_SpecificInt(BitWidth - 1))) &&
      match(Op1, m_Trunc(m_LShr(m_Specific(A), m_SpecificInt(BitWidth)))) &&
      A->getType()->getScalarSizeInBits() == BitWidth * 2 &&
      (I.getOperand(0)->hasOneUse() || I.getOperand(1)->hasOneUse())) {
    APInt C = APInt::getOneBitSet(BitWidth * 2, BitWidth - 1);
    Value *Add = Builder.CreateAdd(A, ConstantInt::get(A->getType(), C));
    return new ICmpInst(Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULT
                                                  : ICmpInst::ICMP_UGE,
                        Add, ConstantInt::get(A->getType(), C.shl(1)));
  }

  // Canonicalize:
  // Assume B_Pow2 != 0
  // 1.
A & B_Pow2 != B_Pow2 -> A & B_Pow2 == 0 5430 // 2. A & B_Pow2 == B_Pow2 -> A & B_Pow2 != 0 5431 if (match(Op0, m_c_And(m_Specific(Op1), m_Value())) && 5432 isKnownToBeAPowerOfTwo(Op1, /* OrZero */ false, 0, &I)) 5433 return new ICmpInst(CmpInst::getInversePredicate(Pred), Op0, 5434 ConstantInt::getNullValue(Op0->getType())); 5435 5436 if (match(Op1, m_c_And(m_Specific(Op0), m_Value())) && 5437 isKnownToBeAPowerOfTwo(Op0, /* OrZero */ false, 0, &I)) 5438 return new ICmpInst(CmpInst::getInversePredicate(Pred), Op1, 5439 ConstantInt::getNullValue(Op1->getType())); 5440 5441 // Canonicalize: 5442 // icmp eq/ne X, OneUse(rotate-right(X)) 5443 // -> icmp eq/ne X, rotate-left(X) 5444 // We generally try to convert rotate-right -> rotate-left, this just 5445 // canonicalizes another case. 5446 CmpInst::Predicate PredUnused = Pred; 5447 if (match(&I, m_c_ICmp(PredUnused, m_Value(A), 5448 m_OneUse(m_Intrinsic<Intrinsic::fshr>( 5449 m_Deferred(A), m_Deferred(A), m_Value(B)))))) 5450 return new ICmpInst( 5451 Pred, A, 5452 Builder.CreateIntrinsic(Op0->getType(), Intrinsic::fshl, {A, A, B})); 5453 5454 // Canonicalize: 5455 // icmp eq/ne OneUse(A ^ Cst), B --> icmp eq/ne (A ^ B), Cst 5456 Constant *Cst; 5457 if (match(&I, m_c_ICmp(PredUnused, 5458 m_OneUse(m_Xor(m_Value(A), m_ImmConstant(Cst))), 5459 m_CombineAnd(m_Value(B), m_Unless(m_ImmConstant()))))) 5460 return new ICmpInst(Pred, Builder.CreateXor(A, B), Cst); 5461 5462 { 5463 // (icmp eq/ne (and (add/sub/xor X, P2), P2), P2) 5464 auto m_Matcher = 5465 m_CombineOr(m_CombineOr(m_c_Add(m_Value(B), m_Deferred(A)), 5466 m_c_Xor(m_Value(B), m_Deferred(A))), 5467 m_Sub(m_Value(B), m_Deferred(A))); 5468 std::optional<bool> IsZero = std::nullopt; 5469 if (match(&I, m_c_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)), 5470 m_Deferred(A)))) 5471 IsZero = false; 5472 // (icmp eq/ne (and (add/sub/xor X, P2), P2), 0) 5473 else if (match(&I, 5474 m_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)), 5475 m_Zero()))) 5476 IsZero = true; 5477 5478 if (IsZero && isKnownToBeAPowerOfTwo(A, /* OrZero */ true, /*Depth*/ 0, &I)) 5479 // (icmp eq/ne (and (add/sub/xor X, P2), P2), P2) 5480 // -> (icmp eq/ne (and X, P2), 0) 5481 // (icmp eq/ne (and (add/sub/xor X, P2), P2), 0) 5482 // -> (icmp eq/ne (and X, P2), P2) 5483 return new ICmpInst(Pred, Builder.CreateAnd(B, A), 5484 *IsZero ? A 5485 : ConstantInt::getNullValue(A->getType())); 5486 } 5487 5488 return nullptr; 5489 } 5490 5491 Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) { 5492 ICmpInst::Predicate Pred = ICmp.getPredicate(); 5493 Value *Op0 = ICmp.getOperand(0), *Op1 = ICmp.getOperand(1); 5494 5495 // Try to canonicalize trunc + compare-to-constant into a mask + cmp. 5496 // The trunc masks high bits while the compare may effectively mask low bits. 5497 Value *X; 5498 const APInt *C; 5499 if (!match(Op0, m_OneUse(m_Trunc(m_Value(X)))) || !match(Op1, m_APInt(C))) 5500 return nullptr; 5501 5502 // This matches patterns corresponding to tests of the signbit as well as: 5503 // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?) 5504 // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?) 
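  // (Illustrative: truncating i32 %x to i8 and testing "u< 16" only asks
  // whether bits 4..7 of the low byte are clear, i.e. "(%x & 0xf0) == 0".)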
  APInt Mask;
  if (decomposeBitTestICmp(Op0, Op1, Pred, X, Mask, true /* WithTrunc */)) {
    Value *And = Builder.CreateAnd(X, Mask);
    Constant *Zero = ConstantInt::getNullValue(X->getType());
    return new ICmpInst(Pred, And, Zero);
  }

  unsigned SrcBits = X->getType()->getScalarSizeInBits();
  if (Pred == ICmpInst::ICMP_ULT && C->isNegatedPowerOf2()) {
    // If C is a negative power-of-2 (high-bit mask):
    // (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?)
    Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits));
    Value *And = Builder.CreateAnd(X, MaskC);
    return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC);
  }

  if (Pred == ICmpInst::ICMP_UGT && (~*C).isPowerOf2()) {
    // If ~C is a power-of-2 (C has exactly one clear bit):
    // (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?)
    Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits));
    Value *And = Builder.CreateAnd(X, MaskC);
    return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC);
  }

  if (auto *II = dyn_cast<IntrinsicInst>(X)) {
    if (II->getIntrinsicID() == Intrinsic::cttz ||
        II->getIntrinsicID() == Intrinsic::ctlz) {
      unsigned MaxRet = SrcBits;
      // If the "is_zero_poison" argument is set, then we know at least
      // one bit is set in the input, so the result is always at least one
      // less than the full bitwidth of that input.
      if (match(II->getArgOperand(1), m_One()))
        MaxRet--;

      // Make sure the destination is wide enough to hold the largest output of
      // the intrinsic.
      if (llvm::Log2_32(MaxRet) + 1 <= Op0->getType()->getScalarSizeInBits())
        if (Instruction *I =
                foldICmpIntrinsicWithConstant(ICmp, II, C->zext(SrcBits)))
          return I;
    }
  }

  return nullptr;
}

Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
  assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0");
  auto *CastOp0 = cast<CastInst>(ICmp.getOperand(0));
  Value *X;
  if (!match(CastOp0, m_ZExtOrSExt(m_Value(X))))
    return nullptr;

  bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt;
  bool IsSignedCmp = ICmp.isSigned();

  // icmp Pred (ext X), (ext Y)
  Value *Y;
  if (match(ICmp.getOperand(1), m_ZExtOrSExt(m_Value(Y)))) {
    bool IsZext0 = isa<ZExtInst>(ICmp.getOperand(0));
    bool IsZext1 = isa<ZExtInst>(ICmp.getOperand(1));

    if (IsZext0 != IsZext1) {
      // If X and Y are both i1:
      // (icmp eq/ne (zext X) (sext Y))
      //   eq -> (icmp eq (or X, Y), 0)
      //   ne -> (icmp ne (or X, Y), 0)
      if (ICmp.isEquality() && X->getType()->isIntOrIntVectorTy(1) &&
          Y->getType()->isIntOrIntVectorTy(1))
        return new ICmpInst(ICmp.getPredicate(), Builder.CreateOr(X, Y),
                            Constant::getNullValue(X->getType()));

      // If we have mismatched casts and zext has the nneg flag, we can
      // treat the "zext nneg" as "sext". Otherwise, we cannot fold and must
      // bail out.
      auto *NonNegInst0 = dyn_cast<PossiblyNonNegInst>(ICmp.getOperand(0));
      auto *NonNegInst1 = dyn_cast<PossiblyNonNegInst>(ICmp.getOperand(1));

      bool IsNonNeg0 = NonNegInst0 && NonNegInst0->hasNonNeg();
      bool IsNonNeg1 = NonNegInst1 && NonNegInst1->hasNonNeg();

      if ((IsZext0 && IsNonNeg0) || (IsZext1 && IsNonNeg1))
        IsSignedExt = true;
      else
        return nullptr;
    }

    // Not an extension from the same type?
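    // (E.g. comparing "zext i8 %a to i32" against "zext i16 %b to i32":
    // widen %a to i16 and compare at i16 instead. Illustrative.)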
    Type *XTy = X->getType(), *YTy = Y->getType();
    if (XTy != YTy) {
      // One of the casts must have one use because we are creating a new cast.
      if (!ICmp.getOperand(0)->hasOneUse() && !ICmp.getOperand(1)->hasOneUse())
        return nullptr;
      // Extend the narrower operand to the type of the wider operand.
      CastInst::CastOps CastOpcode =
          IsSignedExt ? Instruction::SExt : Instruction::ZExt;
      if (XTy->getScalarSizeInBits() < YTy->getScalarSizeInBits())
        X = Builder.CreateCast(CastOpcode, X, YTy);
      else if (YTy->getScalarSizeInBits() < XTy->getScalarSizeInBits())
        Y = Builder.CreateCast(CastOpcode, Y, XTy);
      else
        return nullptr;
    }

    // (zext X) == (zext Y) --> X == Y
    // (sext X) == (sext Y) --> X == Y
    if (ICmp.isEquality())
      return new ICmpInst(ICmp.getPredicate(), X, Y);

    // A signed comparison of sign extended values simplifies into a
    // signed comparison.
    if (IsSignedCmp && IsSignedExt)
      return new ICmpInst(ICmp.getPredicate(), X, Y);

    // The other three cases all fold into an unsigned comparison.
    return new ICmpInst(ICmp.getUnsignedPredicate(), X, Y);
  }

  // Below here, we are only folding a compare with constant.
  auto *C = dyn_cast<Constant>(ICmp.getOperand(1));
  if (!C)
    return nullptr;

  // If a lossless truncate is possible...
  Type *SrcTy = CastOp0->getSrcTy();
  Constant *Res = getLosslessTrunc(C, SrcTy, CastOp0->getOpcode());
  if (Res) {
    if (ICmp.isEquality())
      return new ICmpInst(ICmp.getPredicate(), X, Res);

    // A signed comparison of sign extended values simplifies into a
    // signed comparison.
    if (IsSignedExt && IsSignedCmp)
      return new ICmpInst(ICmp.getPredicate(), X, Res);

    // The other three cases all fold into an unsigned comparison.
    return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res);
  }

  // The re-extended constant changed, partly changed (in the case of a
  // vector), or could not be determined to be equal (in the case of a
  // constant expression), so the constant cannot be represented in the
  // shorter type. All the cases that fold to true or false will have
  // already been handled by simplifyICmpInst, so only deal with the
  // tricky case.
  if (IsSignedCmp || !IsSignedExt || !isa<ConstantInt>(C))
    return nullptr;

  // Is the source op positive?
  // icmp ult (sext X), C --> icmp sgt X, -1
  if (ICmp.getPredicate() == ICmpInst::ICMP_ULT)
    return new ICmpInst(CmpInst::ICMP_SGT, X, Constant::getAllOnesValue(SrcTy));

  // Is the source op negative?
  // icmp ugt (sext X), C --> icmp slt X, 0
  assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!");
  return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy));
}

/// Handle icmp (cast x), (cast or constant).
Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
  // If any operand of ICmp is an inttoptr roundtrip cast, then remove it,
  // as icmp compares only the pointer's value:
  //   icmp (inttoptr (ptrtoint p1)), p2 --> icmp p1, p2
  Value *SimplifiedOp0 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(0));
  Value *SimplifiedOp1 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(1));
  if (SimplifiedOp0 || SimplifiedOp1)
    return new ICmpInst(ICmp.getPredicate(),
                        SimplifiedOp0 ? SimplifiedOp0 : ICmp.getOperand(0),
                        SimplifiedOp1 ?
SimplifiedOp1 : ICmp.getOperand(1)); 5674 5675 auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0)); 5676 if (!CastOp0) 5677 return nullptr; 5678 if (!isa<Constant>(ICmp.getOperand(1)) && !isa<CastInst>(ICmp.getOperand(1))) 5679 return nullptr; 5680 5681 Value *Op0Src = CastOp0->getOperand(0); 5682 Type *SrcTy = CastOp0->getSrcTy(); 5683 Type *DestTy = CastOp0->getDestTy(); 5684 5685 // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the 5686 // integer type is the same size as the pointer type. 5687 auto CompatibleSizes = [&](Type *SrcTy, Type *DestTy) { 5688 if (isa<VectorType>(SrcTy)) { 5689 SrcTy = cast<VectorType>(SrcTy)->getElementType(); 5690 DestTy = cast<VectorType>(DestTy)->getElementType(); 5691 } 5692 return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth(); 5693 }; 5694 if (CastOp0->getOpcode() == Instruction::PtrToInt && 5695 CompatibleSizes(SrcTy, DestTy)) { 5696 Value *NewOp1 = nullptr; 5697 if (auto *PtrToIntOp1 = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { 5698 Value *PtrSrc = PtrToIntOp1->getOperand(0); 5699 if (PtrSrc->getType() == Op0Src->getType()) 5700 NewOp1 = PtrToIntOp1->getOperand(0); 5701 } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { 5702 NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy); 5703 } 5704 5705 if (NewOp1) 5706 return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1); 5707 } 5708 5709 if (Instruction *R = foldICmpWithTrunc(ICmp)) 5710 return R; 5711 5712 return foldICmpWithZextOrSext(ICmp); 5713 } 5714 5715 static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS, bool IsSigned) { 5716 switch (BinaryOp) { 5717 default: 5718 llvm_unreachable("Unsupported binary op"); 5719 case Instruction::Add: 5720 case Instruction::Sub: 5721 return match(RHS, m_Zero()); 5722 case Instruction::Mul: 5723 return !(RHS->getType()->isIntOrIntVectorTy(1) && IsSigned) && 5724 match(RHS, m_One()); 5725 } 5726 } 5727 5728 OverflowResult 5729 InstCombinerImpl::computeOverflow(Instruction::BinaryOps BinaryOp, 5730 bool IsSigned, Value *LHS, Value *RHS, 5731 Instruction *CxtI) const { 5732 switch (BinaryOp) { 5733 default: 5734 llvm_unreachable("Unsupported binary op"); 5735 case Instruction::Add: 5736 if (IsSigned) 5737 return computeOverflowForSignedAdd(LHS, RHS, CxtI); 5738 else 5739 return computeOverflowForUnsignedAdd(LHS, RHS, CxtI); 5740 case Instruction::Sub: 5741 if (IsSigned) 5742 return computeOverflowForSignedSub(LHS, RHS, CxtI); 5743 else 5744 return computeOverflowForUnsignedSub(LHS, RHS, CxtI); 5745 case Instruction::Mul: 5746 if (IsSigned) 5747 return computeOverflowForSignedMul(LHS, RHS, CxtI); 5748 else 5749 return computeOverflowForUnsignedMul(LHS, RHS, CxtI); 5750 } 5751 } 5752 5753 bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp, 5754 bool IsSigned, Value *LHS, 5755 Value *RHS, Instruction &OrigI, 5756 Value *&Result, 5757 Constant *&Overflow) { 5758 if (OrigI.isCommutative() && isa<Constant>(LHS) && !isa<Constant>(RHS)) 5759 std::swap(LHS, RHS); 5760 5761 // If the overflow check was an add followed by a compare, the insertion point 5762 // may be pointing to the compare. We want to insert the new instructions 5763 // before the add in case there are uses of the add between the add and the 5764 // compare. 
  Builder.SetInsertPoint(&OrigI);

  Type *OverflowTy = Type::getInt1Ty(LHS->getContext());
  if (auto *LHSTy = dyn_cast<VectorType>(LHS->getType()))
    OverflowTy = VectorType::get(OverflowTy, LHSTy->getElementCount());

  if (isNeutralValue(BinaryOp, RHS, IsSigned)) {
    Result = LHS;
    Overflow = ConstantInt::getFalse(OverflowTy);
    return true;
  }

  switch (computeOverflow(BinaryOp, IsSigned, LHS, RHS, &OrigI)) {
  case OverflowResult::MayOverflow:
    return false;
  case OverflowResult::AlwaysOverflowsLow:
  case OverflowResult::AlwaysOverflowsHigh:
    Result = Builder.CreateBinOp(BinaryOp, LHS, RHS);
    Result->takeName(&OrigI);
    Overflow = ConstantInt::getTrue(OverflowTy);
    return true;
  case OverflowResult::NeverOverflows:
    Result = Builder.CreateBinOp(BinaryOp, LHS, RHS);
    Result->takeName(&OrigI);
    Overflow = ConstantInt::getFalse(OverflowTy);
    if (auto *Inst = dyn_cast<Instruction>(Result)) {
      if (IsSigned)
        Inst->setHasNoSignedWrap();
      else
        Inst->setHasNoUnsignedWrap();
    }
    return true;
  }

  llvm_unreachable("Unexpected overflow result");
}

/// Recognize and process an idiom involving a test for multiplication
/// overflow.
///
/// The caller has matched a pattern of the form:
///   I = cmp u (mul(zext A, zext B)), V
/// The function checks if this is a test for overflow and, if so, replaces
/// the multiplication with a call to the 'mul.with.overflow' intrinsic.
///
/// \param I Compare instruction.
/// \param MulVal Result of the 'mul' instruction. It is one of the arguments
///        of the compare instruction. Must be of integer type.
/// \param OtherVal The other argument of the compare instruction.
/// \returns Instruction which must replace the compare instruction, nullptr
///          if no replacement is required.
static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
                                         const APInt *OtherVal,
                                         InstCombinerImpl &IC) {
  // Don't bother doing this transformation for pointers, and don't do it for
  // vectors.
  if (!isa<IntegerType>(MulVal->getType()))
    return nullptr;

  auto *MulInstr = dyn_cast<Instruction>(MulVal);
  if (!MulInstr)
    return nullptr;
  assert(MulInstr->getOpcode() == Instruction::Mul);

  auto *LHS = cast<ZExtInst>(MulInstr->getOperand(0)),
       *RHS = cast<ZExtInst>(MulInstr->getOperand(1));
  assert(LHS->getOpcode() == Instruction::ZExt);
  assert(RHS->getOpcode() == Instruction::ZExt);
  Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);

  // Calculate the type and width of the result produced by mul.with.overflow.
  Type *TyA = A->getType(), *TyB = B->getType();
  unsigned WidthA = TyA->getPrimitiveSizeInBits(),
           WidthB = TyB->getPrimitiveSizeInBits();
  unsigned MulWidth;
  Type *MulType;
  if (WidthB > WidthA) {
    MulWidth = WidthB;
    MulType = TyB;
  } else {
    MulWidth = WidthA;
    MulType = TyA;
  }

  // In order to replace the original mul with a narrower mul.with.overflow,
  // all uses must ignore the upper bits of the product. The number of used
  // low bits must not be greater than the width of mul.with.overflow.
  if (MulVal->hasNUsesOrMore(2))
    for (User *U : MulVal->users()) {
      if (U == &I)
        continue;
      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
        // Check if the truncation ignores bits above MulWidth.
5858         unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
5859         if (TruncWidth > MulWidth)
5860           return nullptr;
5861       } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
5862         // Check if AND ignores bits above MulWidth.
5863         if (BO->getOpcode() != Instruction::And)
5864           return nullptr;
5865         if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
5866           const APInt &CVal = CI->getValue();
5867           if (CVal.getBitWidth() - CVal.countl_zero() > MulWidth)
5868             return nullptr;
5869         } else {
5870           // In this case we could have the operand of the binary operation
5871           // being defined in another block, and performing the replacement
5872           // could break the dominance relation.
5873           return nullptr;
5874         }
5875       } else {
5876         // Other uses prohibit this transformation.
5877         return nullptr;
5878       }
5879     }
5880
5881   // Recognize patterns
5882   switch (I.getPredicate()) {
5883   case ICmpInst::ICMP_UGT: {
5884     // Recognize pattern:
5885     //   mulval = mul(zext A, zext B)
5886     //   cmp ugt mulval, max
5887     APInt MaxVal = APInt::getMaxValue(MulWidth);
5888     MaxVal = MaxVal.zext(OtherVal->getBitWidth());
5889     if (MaxVal.eq(*OtherVal))
5890       break; // Recognized
5891     return nullptr;
5892   }
5893
5894   case ICmpInst::ICMP_ULT: {
5895     // Recognize pattern:
5896     //   mulval = mul(zext A, zext B)
5897     //   cmp ult mulval, max + 1
5898     APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), MulWidth);
5899     if (MaxVal.eq(*OtherVal))
5900       break; // Recognized
5901     return nullptr;
5902   }
5903
5904   default:
5905     return nullptr;
5906   }
5907
5908   InstCombiner::BuilderTy &Builder = IC.Builder;
5909   Builder.SetInsertPoint(MulInstr);
5910
5911   // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
5912   Value *MulA = A, *MulB = B;
5913   if (WidthA < MulWidth)
5914     MulA = Builder.CreateZExt(A, MulType);
5915   if (WidthB < MulWidth)
5916     MulB = Builder.CreateZExt(B, MulType);
5917   Function *F = Intrinsic::getDeclaration(
5918       I.getModule(), Intrinsic::umul_with_overflow, MulType);
5919   CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul");
5920   IC.addToWorklist(MulInstr);
5921
5922   // If there are uses of the mul result other than the comparison, we know
5923   // that they are truncation or binary AND. Change them to use the result of
5924   // mul.with.overflow and adjust the mask/size appropriately.
5925   if (MulVal->hasNUsesOrMore(2)) {
5926     Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
5927     for (User *U : make_early_inc_range(MulVal->users())) {
5928       if (U == &I)
5929         continue;
5930       if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
5931         if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
5932           IC.replaceInstUsesWith(*TI, Mul);
5933         else
5934           TI->setOperand(0, Mul);
5935       } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
5936         assert(BO->getOpcode() == Instruction::And);
5937         // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
5938         ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
5939         APInt ShortMask = CI->getValue().trunc(MulWidth);
5940         Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
5941         Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
5942         IC.replaceInstUsesWith(*BO, Zext);
5943       } else {
5944         llvm_unreachable("Unexpected binary operation");
5945       }
5946       IC.addToWorklist(cast<Instruction>(U));
5947     }
5948   }
5949
5950   // The original icmp gets replaced with the overflow value, maybe inverted
5951   // depending on the predicate.
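  // As a sketch of the two recognized shapes: 'icmp ugt %mulval, max' is true
  // exactly when the narrow multiply overflows, so it maps directly to the
  // extracted i1 overflow flag, while 'icmp ult %mulval, max + 1' is the
  // negated test and needs a 'not' of the flag.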
5952 if (I.getPredicate() == ICmpInst::ICMP_ULT) { 5953 Value *Res = Builder.CreateExtractValue(Call, 1); 5954 return BinaryOperator::CreateNot(Res); 5955 } 5956 5957 return ExtractValueInst::Create(Call, 1); 5958 } 5959 5960 /// When performing a comparison against a constant, it is possible that not all 5961 /// the bits in the LHS are demanded. This helper method computes the mask that 5962 /// IS demanded. 5963 static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { 5964 const APInt *RHS; 5965 if (!match(I.getOperand(1), m_APInt(RHS))) 5966 return APInt::getAllOnes(BitWidth); 5967 5968 // If this is a normal comparison, it demands all bits. If it is a sign bit 5969 // comparison, it only demands the sign bit. 5970 bool UnusedBit; 5971 if (InstCombiner::isSignBitCheck(I.getPredicate(), *RHS, UnusedBit)) 5972 return APInt::getSignMask(BitWidth); 5973 5974 switch (I.getPredicate()) { 5975 // For a UGT comparison, we don't care about any bits that 5976 // correspond to the trailing ones of the comparand. The value of these 5977 // bits doesn't impact the outcome of the comparison, because any value 5978 // greater than the RHS must differ in a bit higher than these due to carry. 5979 case ICmpInst::ICMP_UGT: 5980 return APInt::getBitsSetFrom(BitWidth, RHS->countr_one()); 5981 5982 // Similarly, for a ULT comparison, we don't care about the trailing zeros. 5983 // Any value less than the RHS must differ in a higher bit because of carries. 5984 case ICmpInst::ICMP_ULT: 5985 return APInt::getBitsSetFrom(BitWidth, RHS->countr_zero()); 5986 5987 default: 5988 return APInt::getAllOnes(BitWidth); 5989 } 5990 } 5991 5992 /// Check that one use is in the same block as the definition and all 5993 /// other uses are in blocks dominated by a given block. 5994 /// 5995 /// \param DI Definition 5996 /// \param UI Use 5997 /// \param DB Block that must dominate all uses of \p DI outside 5998 /// the parent block 5999 /// \return true when \p UI is the only use of \p DI in the parent block 6000 /// and all other uses of \p DI are in blocks dominated by \p DB. 6001 /// 6002 bool InstCombinerImpl::dominatesAllUses(const Instruction *DI, 6003 const Instruction *UI, 6004 const BasicBlock *DB) const { 6005 assert(DI && UI && "Instruction not defined\n"); 6006 // Ignore incomplete definitions. 6007 if (!DI->getParent()) 6008 return false; 6009 // DI and UI must be in the same block. 6010 if (DI->getParent() != UI->getParent()) 6011 return false; 6012 // Protect from self-referencing blocks. 6013 if (DI->getParent() == DB) 6014 return false; 6015 for (const User *U : DI->users()) { 6016 auto *Usr = cast<Instruction>(U); 6017 if (Usr != UI && !DT.dominates(DB, Usr->getParent())) 6018 return false; 6019 } 6020 return true; 6021 } 6022 6023 /// Return true when the instruction sequence within a block is select-cmp-br. 6024 static bool isChainSelectCmpBranch(const SelectInst *SI) { 6025 const BasicBlock *BB = SI->getParent(); 6026 if (!BB) 6027 return false; 6028 auto *BI = dyn_cast_or_null<BranchInst>(BB->getTerminator()); 6029 if (!BI || BI->getNumSuccessors() != 2) 6030 return false; 6031 auto *IC = dyn_cast<ICmpInst>(BI->getCondition()); 6032 if (!IC || (IC->getOperand(0) != SI && IC->getOperand(1) != SI)) 6033 return false; 6034 return true; 6035 } 6036 6037 /// True when a select result is replaced by one of its operands 6038 /// in select-icmp sequence. This will eventually result in the elimination 6039 /// of the select. 
6040 ///
6041 /// \param SI Select instruction
6042 /// \param Icmp Compare instruction
6043 /// \param SIOpd Operand that replaces the select
6044 ///
6045 /// Notes:
6046 /// - The replacement is global and requires dominator information
6047 /// - The caller is responsible for the actual replacement
6048 ///
6049 /// Example:
6050 ///
6051 /// entry:
6052 ///  %4 = select i1 %3, %C* %0, %C* null
6053 ///  %5 = icmp eq %C* %4, null
6054 /// br i1 %5, label %9, label %7
6055 /// ...
6056 /// ; <label>:7 ; preds = %entry
6057 ///  %8 = getelementptr inbounds %C* %4, i64 0, i32 0
6058 /// ...
6059 ///
6060 /// can be transformed to
6061 ///
6062 ///  %5 = icmp eq %C* %0, null
6063 ///  %6 = select i1 %3, i1 %5, i1 true
6064 /// br i1 %6, label %9, label %7
6065 /// ...
6066 /// ; <label>:7 ; preds = %entry
6067 ///  %8 = getelementptr inbounds %C* %0, i64 0, i32 0  // replace by %0!
6068 ///
6069 /// Similarly when the first operand of the select is a constant and/or
6070 /// the compare is for not-equal rather than equal.
6071 ///
6072 /// NOTE: The function is only called when the select and compare constants
6073 /// are equal; thus, the optimization works only for EQ predicates. This is
6074 /// not a major restriction since a NE compare should be 'normalized' to an
6075 /// equal compare, which usually happens in the combiner; the test case
6076 /// select-cmp-br.ll checks for it.
6077 bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
6078                                                  const ICmpInst *Icmp,
6079                                                  const unsigned SIOpd) {
6080   assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!");
6081   if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) {
6082     BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1);
6083     // The check for the single predecessor is not the best that can be
6084     // done. But it protects efficiently against cases like when SI's
6085     // home block has two successors, Succ and Succ1, and Succ1 is a
6086     // predecessor of Succ. Then SI can't be replaced by SIOpd because the
6087     // use that gets replaced can be reached on either path. So the
6088     // uniqueness check guarantees that the path all uses of SI (outside
6089     // SI's parent) are on is disjoint from all other paths out of SI. But
6090     // that information is more expensive to compute, and the trade-off
6091     // here is in favor of compile time. It should also be noted that we
6092     // check for a single predecessor and not just uniqueness; this handles
6093     // the situation when Succ and Succ1 point to the same basic block.
6094     if (Succ->getSinglePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {
6095       NumSel++;
6096       SI->replaceUsesOutsideBlock(SI->getOperand(SIOpd), SI->getParent());
6097       return true;
6098     }
6099   }
6100   return false;
6101 }
6102
6103 /// Try to fold the comparison based on range information we can get by
6104 /// checking whether bits are known to be zero or one in the inputs.
6105 Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
6106   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
6107   Type *Ty = Op0->getType();
6108   ICmpInst::Predicate Pred = I.getPredicate();
6109
6110   // Get scalar or pointer size.
6111   unsigned BitWidth = Ty->isIntOrIntVectorTy()
6112                           ? Ty->getScalarSizeInBits()
6113                           : DL.getPointerTypeSizeInBits(Ty->getScalarType());
6114
6115   if (!BitWidth)
6116     return nullptr;
6117
6118   KnownBits Op0Known(BitWidth);
6119   KnownBits Op1Known(BitWidth);
6120
6121   {
6122     // Don't use dominating conditions when folding icmp using known bits. This
6123     // may convert signed into unsigned predicates in ways that other passes
6124     // (especially IndVarSimplify) may not be able to reliably undo.
6125     SQ.DC = nullptr;
6126     auto _ = make_scope_exit([&]() { SQ.DC = &DC; });
6127     if (SimplifyDemandedBits(&I, 0, getDemandedBitsLHSMask(I, BitWidth),
6128                              Op0Known, 0))
6129       return &I;
6130
6131     if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0))
6132       return &I;
6133   }
6134
6135   // Given the known and unknown bits, compute a range that the LHS could be
6136   // in. Compute the Min, Max and RHS values based on the known bits. For the
6137   // EQ and NE we use unsigned values.
6138   APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0);
6139   APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0);
6140   if (I.isSigned()) {
6141     Op0Min = Op0Known.getSignedMinValue();
6142     Op0Max = Op0Known.getSignedMaxValue();
6143     Op1Min = Op1Known.getSignedMinValue();
6144     Op1Max = Op1Known.getSignedMaxValue();
6145   } else {
6146     Op0Min = Op0Known.getMinValue();
6147     Op0Max = Op0Known.getMaxValue();
6148     Op1Min = Op1Known.getMinValue();
6149     Op1Max = Op1Known.getMaxValue();
6150   }
6151
6152   // If Min and Max are known to be the same, then SimplifyDemandedBits figured
6153   // out that the LHS or RHS is a constant. Constant fold this now, so that
6154   // code below can assume that Min != Max.
6155   if (!isa<Constant>(Op0) && Op0Min == Op0Max)
6156     return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1);
6157   if (!isa<Constant>(Op1) && Op1Min == Op1Max)
6158     return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
6159
6160   // Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a
6161   // min/max canonical compare with some other compare. That could lead to
6162   // conflict with select canonicalization and infinite looping.
6163   // FIXME: This constraint may go away if min/max intrinsics are canonical.
6164   auto isMinMaxCmp = [&](Instruction &Cmp) {
6165     if (!Cmp.hasOneUse())
6166       return false;
6167     Value *A, *B;
6168     SelectPatternFlavor SPF = matchSelectPattern(Cmp.user_back(), A, B).Flavor;
6169     if (!SelectPatternResult::isMinOrMax(SPF))
6170       return false;
6171     return match(Op0, m_MaxOrMin(m_Value(), m_Value())) ||
6172            match(Op1, m_MaxOrMin(m_Value(), m_Value()));
6173   };
6174   if (!isMinMaxCmp(I)) {
6175     switch (Pred) {
6176     default:
6177       break;
6178     case ICmpInst::ICMP_ULT: {
6179       if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
6180         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
6181       const APInt *CmpC;
6182       if (match(Op1, m_APInt(CmpC))) {
6183         // A <u C -> A == C-1 if min(A)+1 == C
6184         if (*CmpC == Op0Min + 1)
6185           return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
6186                               ConstantInt::get(Op1->getType(), *CmpC - 1));
6187         // X <u C --> X == 0, if the number of zero bits in the bottom of X
6188         // is at least the ceiling of log2(C).
6189         if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
6190           return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
6191                               Constant::getNullValue(Op1->getType()));
6192       }
6193       break;
6194     }
6195     case ICmpInst::ICMP_UGT: {
6196       if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
6197         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
6198       const APInt *CmpC;
6199       if (match(Op1, m_APInt(CmpC))) {
6200         // A >u C -> A == C+1 if max(A)-1 == C
6201         if (*CmpC == Op0Max - 1)
6202           return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
6203                               ConstantInt::get(Op1->getType(), *CmpC + 1));
6204         // X >u C --> X != 0, if the number of zero bits in the bottom of X
6205         // is at least the number of active bits of C.
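        // Worked example: if the low 3 bits of X are known zero (X is a
        // multiple of 8) and C is 7, then X is either 0 or at least 8, so
        // 'X >u 7' holds exactly when 'X != 0'.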
6206         if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
6207           return new ICmpInst(ICmpInst::ICMP_NE, Op0,
6208                               Constant::getNullValue(Op1->getType()));
6209       }
6210       break;
6211     }
6212     case ICmpInst::ICMP_SLT: {
6213       if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
6214         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
6215       const APInt *CmpC;
6216       if (match(Op1, m_APInt(CmpC))) {
6217         if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
6218           return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
6219                               ConstantInt::get(Op1->getType(), *CmpC - 1));
6220       }
6221       break;
6222     }
6223     case ICmpInst::ICMP_SGT: {
6224       if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
6225         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
6226       const APInt *CmpC;
6227       if (match(Op1, m_APInt(CmpC))) {
6228         if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
6229           return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
6230                               ConstantInt::get(Op1->getType(), *CmpC + 1));
6231       }
6232       break;
6233     }
6234     }
6235   }
6236
6237   // Based on the range information we know about the LHS, see if we can
6238   // simplify this comparison. For example, (x&4) < 8 is always true.
6239   switch (Pred) {
6240   default:
6241     llvm_unreachable("Unknown icmp opcode!");
6242   case ICmpInst::ICMP_EQ:
6243   case ICmpInst::ICMP_NE: {
6244     if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
6245       return replaceInstUsesWith(
6246           I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE));
6247
6248     // If all bits are known zero except for one, then we know at most one bit
6249     // is set. If the comparison is against zero, then this is a check to see
6250     // if *that* bit is set.
6251     APInt Op0KnownZeroInverted = ~Op0Known.Zero;
6252     if (Op1Known.isZero()) {
6253       // If the LHS is an AND with the same constant, look through it.
6254       Value *LHS = nullptr;
6255       const APInt *LHSC;
6256       if (!match(Op0, m_And(m_Value(LHS), m_APInt(LHSC))) ||
6257           *LHSC != Op0KnownZeroInverted)
6258         LHS = Op0;
6259
6260       Value *X;
6261       const APInt *C1;
6262       if (match(LHS, m_Shl(m_Power2(C1), m_Value(X)))) {
6263         Type *XTy = X->getType();
6264         unsigned Log2C1 = C1->countr_zero();
6265         APInt C2 = Op0KnownZeroInverted;
6266         APInt C2Pow2 = (C2 & ~(*C1 - 1)) + *C1;
6267         if (C2Pow2.isPowerOf2()) {
6268           // iff C1 is a power of 2 and ((C2 & ~(C1-1)) + C1) is a power of 2:
6269           // ((C1 << X) & C2) == 0 -> X >= (Log2(C2+C1) - Log2(C1))
6270           // ((C1 << X) & C2) != 0 -> X <  (Log2(C2+C1) - Log2(C1))
6271           unsigned Log2C2 = C2Pow2.countr_zero();
6272           auto *CmpC = ConstantInt::get(XTy, Log2C2 - Log2C1);
6273           auto NewPred =
6274               Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGE : CmpInst::ICMP_ULT;
6275           return new ICmpInst(NewPred, X, CmpC);
6276         }
6277       }
6278     }
6279
6280     // Op0 eq C_Pow2 -> Op0 ne 0 if Op0 is known to be C_Pow2 or zero.
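    // E.g., if the known bits say Op0 is either 0 or 8 and Op1 is the
    // constant 8, then 'icmp eq Op0, 8' becomes 'icmp ne Op0, 0' (and 'ne'
    // becomes 'eq').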
6281     if (Op1Known.isConstant() && Op1Known.getConstant().isPowerOf2() &&
6282         (Op0Known & Op1Known) == Op0Known)
6283       return new ICmpInst(CmpInst::getInversePredicate(Pred), Op0,
6284                           ConstantInt::getNullValue(Op1->getType()));
6285     break;
6286   }
6287   case ICmpInst::ICMP_ULT: {
6288     if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
6289       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6290     if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
6291       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6292     break;
6293   }
6294   case ICmpInst::ICMP_UGT: {
6295     if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
6296       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6297     if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= min(B)
6298       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6299     break;
6300   }
6301   case ICmpInst::ICMP_SLT: {
6302     if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(B)
6303       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6304     if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(B)
6305       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6306     break;
6307   }
6308   case ICmpInst::ICMP_SGT: {
6309     if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
6310       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6311     if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
6312       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6313     break;
6314   }
6315   case ICmpInst::ICMP_SGE:
6316     assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
6317     if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
6318       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6319     if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
6320       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6321     if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B)
6322       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
6323     break;
6324   case ICmpInst::ICMP_SLE:
6325     assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
6326     if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
6327       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6328     if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
6329       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6330     if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B)
6331       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
6332     break;
6333   case ICmpInst::ICMP_UGE:
6334     assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
6335     if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
6336       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6337     if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
6338       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6339     if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B)
6340       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
6341     break;
6342   case ICmpInst::ICMP_ULE:
6343     assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
6344     if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
6345       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6346     if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
6347       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6348     if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B)
6349       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
6350     break;
6351   }
6352
6353   // Turn a signed comparison into an unsigned one if both operands are known
6354   // to have the same sign.
6355   if (I.isSigned() &&
6356       ((Op0Known.Zero.isNegative() && Op1Known.Zero.isNegative()) ||
6357        (Op0Known.One.isNegative() && Op1Known.One.isNegative())))
6358     return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
6359
6360   return nullptr;
6361 }
6362
6363 /// If one operand of an icmp is effectively a bool (value range of {0,1}),
6364 /// then try to reduce patterns based on that limit.
6365 Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
6366   Value *X, *Y;
6367   ICmpInst::Predicate Pred;
6368
6369   // X must be 0 and bool must be true for "ULT":
6370   // X <u (zext i1 Y) --> (X == 0) & Y
6371   if (match(&I, m_c_ICmp(Pred, m_Value(X), m_OneUse(m_ZExt(m_Value(Y))))) &&
6372       Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULT)
6373     return BinaryOperator::CreateAnd(Builder.CreateIsNull(X), Y);
6374
6375   // X must be 0 or bool must be true for "ULE":
6376   // X <=u (sext i1 Y) --> (X == 0) | Y
6377   if (match(&I, m_c_ICmp(Pred, m_Value(X), m_OneUse(m_SExt(m_Value(Y))))) &&
6378       Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE)
6379     return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y);
6380
6381   // icmp eq/ne X, (zext/sext (icmp eq/ne X, C))
6382   ICmpInst::Predicate Pred1, Pred2;
6383   const APInt *C;
6384   Instruction *ExtI;
6385   if (match(&I, m_c_ICmp(Pred1, m_Value(X),
6386                          m_CombineAnd(m_Instruction(ExtI),
6387                                       m_ZExtOrSExt(m_ICmp(Pred2, m_Deferred(X),
6388                                                           m_APInt(C)))))) &&
6389       ICmpInst::isEquality(Pred1) && ICmpInst::isEquality(Pred2)) {
6390     bool IsSExt = ExtI->getOpcode() == Instruction::SExt;
6391     bool HasOneUse = ExtI->hasOneUse() && ExtI->getOperand(0)->hasOneUse();
6392     auto CreateRangeCheck = [&] {
6393       Value *CmpV1 =
6394           Builder.CreateICmp(Pred1, X, Constant::getNullValue(X->getType()));
6395       Value *CmpV2 = Builder.CreateICmp(
6396           Pred1, X, ConstantInt::getSigned(X->getType(), IsSExt ? -1 : 1));
6397       return BinaryOperator::Create(
6398           Pred1 == ICmpInst::ICMP_EQ ? Instruction::Or : Instruction::And,
6399           CmpV1, CmpV2);
6400     };
6401     if (C->isZero()) {
6402       if (Pred2 == ICmpInst::ICMP_EQ) {
6403         // icmp eq X, (zext/sext (icmp eq X, 0)) --> false
6404         // icmp ne X, (zext/sext (icmp eq X, 0)) --> true
6405         return replaceInstUsesWith(
6406             I, ConstantInt::getBool(I.getType(), Pred1 == ICmpInst::ICMP_NE));
6407       } else if (!IsSExt || HasOneUse) {
6408         // icmp eq X, (zext (icmp ne X, 0)) --> X == 0 || X == 1
6409         // icmp ne X, (zext (icmp ne X, 0)) --> X != 0 && X != 1
6410         // icmp eq X, (sext (icmp ne X, 0)) --> X == 0 || X == -1
6411         // icmp ne X, (sext (icmp ne X, 0)) --> X != 0 && X != -1
6412         return CreateRangeCheck();
6413       }
6414     } else if (IsSExt ? C->isAllOnes() : C->isOne()) {
6415       if (Pred2 == ICmpInst::ICMP_NE) {
6416         // icmp eq X, (zext (icmp ne X, 1)) --> false
6417         // icmp ne X, (zext (icmp ne X, 1)) --> true
6418         // icmp eq X, (sext (icmp ne X, -1)) --> false
6419         // icmp ne X, (sext (icmp ne X, -1)) --> true
6420         return replaceInstUsesWith(
6421             I, ConstantInt::getBool(I.getType(), Pred1 == ICmpInst::ICMP_NE));
6422       } else if (!IsSExt || HasOneUse) {
6423         // icmp eq X, (zext (icmp eq X, 1)) --> X == 0 || X == 1
6424         // icmp ne X, (zext (icmp eq X, 1)) --> X != 0 && X != 1
6425         // icmp eq X, (sext (icmp eq X, -1)) --> X == 0 || X == -1
6426         // icmp ne X, (sext (icmp eq X, -1)) --> X != 0 && X != -1
6427         return CreateRangeCheck();
6428       }
6429     } else {
6430       // when C != 0 && C != 1:
6431       //   icmp eq X, (zext (icmp eq X, C)) --> icmp eq X, 0
6432       //   icmp eq X, (zext (icmp ne X, C)) --> icmp eq X, 1
6433       //   icmp ne X, (zext (icmp eq X, C)) --> icmp ne X, 0
6434       //   icmp ne X, (zext (icmp ne X, C)) --> icmp ne X, 1
6435       // when C != 0 && C != -1:
6436       //   icmp eq X, (sext (icmp eq X, C)) --> icmp eq X, 0
6437       //   icmp eq X, (sext (icmp ne X, C)) --> icmp eq X, -1
6438       //   icmp ne X, (sext (icmp eq X, C)) --> icmp ne X, 0
6439       //   icmp ne X, (sext (icmp ne X, C)) --> icmp ne X, -1
6440       return ICmpInst::Create(
6441           Instruction::ICmp, Pred1, X,
6442           ConstantInt::getSigned(X->getType(), Pred2 == ICmpInst::ICMP_NE
6443                                                    ? (IsSExt ? -1 : 1)
6444                                                    : 0));
6445     }
6446   }
6447
6448   return nullptr;
6449 }
6450
6451 std::optional<std::pair<CmpInst::Predicate, Constant *>>
6452 InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
6453                                                        Constant *C) {
6454   assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
6455          "Only for relational integer predicates.");
6456
6457   Type *Type = C->getType();
6458   bool IsSigned = ICmpInst::isSigned(Pred);
6459
6460   CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
6461   bool WillIncrement =
6462       UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
6463
6464   // Check if the constant operand can be safely incremented/decremented
6465   // without overflowing/underflowing.
6466   auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
6467     return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
6468   };
6469
6470   Constant *SafeReplacementConstant = nullptr;
6471   if (auto *CI = dyn_cast<ConstantInt>(C)) {
6472     // Bail out if the constant can't be safely incremented/decremented.
6473     if (!ConstantIsOk(CI))
6474       return std::nullopt;
6475   } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
6476     unsigned NumElts = FVTy->getNumElements();
6477     for (unsigned i = 0; i != NumElts; ++i) {
6478       Constant *Elt = C->getAggregateElement(i);
6479       if (!Elt)
6480         return std::nullopt;
6481
6482       if (isa<UndefValue>(Elt))
6483         continue;
6484
6485       // Bail out if we can't determine if this constant is min/max or if we
6486       // know that this constant is min/max.
6487       auto *CI = dyn_cast<ConstantInt>(Elt);
6488       if (!CI || !ConstantIsOk(CI))
6489         return std::nullopt;
6490
6491       if (!SafeReplacementConstant)
6492         SafeReplacementConstant = CI;
6493     }
6494   } else if (isa<VectorType>(C->getType())) {
6495     // Handle scalable splat
6496     Value *SplatC = C->getSplatValue();
6497     auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
6498     // Bail out if the constant can't be safely incremented/decremented.
6499     if (!CI || !ConstantIsOk(CI))
6500       return std::nullopt;
6501   } else {
6502     // ConstantExpr?
6503     return std::nullopt;
6504   }
6505
6506   // It may not be safe to change a compare predicate in the presence of
6507   // undefined elements, so replace those elements with the first safe constant
6508   // that we found.
6509   // TODO: in case of poison, it is safe; let's replace undefs only.
6510   if (C->containsUndefOrPoisonElement()) {
6511     assert(SafeReplacementConstant && "Replacement constant not set");
6512     C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
6513   }
6514
6515   CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
6516
6517   // Increment or decrement the constant.
6518   Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
6519   Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
6520
6521   return std::make_pair(NewPred, NewC);
6522 }
6523
6524 /// If we have an icmp le or icmp ge instruction with a constant operand, turn
6525 /// it into the appropriate icmp lt or icmp gt instruction. This transform
6526 /// allows them to be folded in visitICmpInst.
6527 static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
6528   ICmpInst::Predicate Pred = I.getPredicate();
6529   if (ICmpInst::isEquality(Pred) || !ICmpInst::isIntPredicate(Pred) ||
6530       InstCombiner::isCanonicalPredicate(Pred))
6531     return nullptr;
6532
6533   Value *Op0 = I.getOperand(0);
6534   Value *Op1 = I.getOperand(1);
6535   auto *Op1C = dyn_cast<Constant>(Op1);
6536   if (!Op1C)
6537     return nullptr;
6538
6539   auto FlippedStrictness =
6540       InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
6541   if (!FlippedStrictness)
6542     return nullptr;
6543
6544   return new ICmpInst(FlippedStrictness->first, Op0, FlippedStrictness->second);
6545 }
6546
6547 /// If we have a comparison with a non-canonical predicate and we can update
6548 /// all of its users, invert the predicate and adjust the users accordingly.
6549 CmpInst *InstCombinerImpl::canonicalizeICmpPredicate(CmpInst &I) {
6550   // Is the predicate already canonical?
6551   CmpInst::Predicate Pred = I.getPredicate();
6552   if (InstCombiner::isCanonicalPredicate(Pred))
6553     return nullptr;
6554
6555   // Can all users be adjusted to predicate inversion?
6556   if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr))
6557     return nullptr;
6558
6559   // Ok, we can canonicalize the comparison!
6560   // Let's first invert the comparison's predicate.
6561   I.setPredicate(CmpInst::getInversePredicate(Pred));
6562   I.setName(I.getName() + ".not");
6563
6564   // And, adapt users.
6565   freelyInvertAllUsersOf(&I);
6566
6567   return &I;
6568 }
6569
6570 /// Integer compare with boolean values can always be turned into bitwise ops.
6571 static Instruction *canonicalizeICmpBool(ICmpInst &I,
6572                                          InstCombiner::BuilderTy &Builder) {
6573   Value *A = I.getOperand(0), *B = I.getOperand(1);
6574   assert(A->getType()->isIntOrIntVectorTy(1) && "Bools only");
6575
6576   // A boolean compared to true/false can be simplified to Op0/true/false in
6577   // 14 out of the 20 (10 predicates * 2 constants) possible combinations.
6578   // Cases not handled by InstSimplify are always 'not' of Op0.
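  // A sketch of one case that reaches here: 'icmp ult i1 %a, true' is
  // satisfied only by %a == false, so it becomes a 'not' of %a; trivially
  // constant cases such as 'icmp ult i1 %a, false' are expected to have been
  // folded by InstSimplify already.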
6579   if (match(B, m_Zero())) {
6580     switch (I.getPredicate()) {
6581     case CmpInst::ICMP_EQ:  // A ==   0 -> !A
6582     case CmpInst::ICMP_ULE: // A <=u  0 -> !A
6583     case CmpInst::ICMP_SGE: // A >=s  0 -> !A
6584       return BinaryOperator::CreateNot(A);
6585     default:
6586       llvm_unreachable("ICmp i1 X, C not simplified as expected.");
6587     }
6588   } else if (match(B, m_One())) {
6589     switch (I.getPredicate()) {
6590     case CmpInst::ICMP_NE:  // A !=  1 -> !A
6591     case CmpInst::ICMP_ULT: // A <u  1 -> !A
6592     case CmpInst::ICMP_SGT: // A >s -1 -> !A
6593       return BinaryOperator::CreateNot(A);
6594     default:
6595       llvm_unreachable("ICmp i1 X, C not simplified as expected.");
6596     }
6597   }
6598
6599   switch (I.getPredicate()) {
6600   default:
6601     llvm_unreachable("Invalid icmp instruction!");
6602   case ICmpInst::ICMP_EQ:
6603     // icmp eq i1 A, B -> ~(A ^ B)
6604     return BinaryOperator::CreateNot(Builder.CreateXor(A, B));
6605
6606   case ICmpInst::ICMP_NE:
6607     // icmp ne i1 A, B -> A ^ B
6608     return BinaryOperator::CreateXor(A, B);
6609
6610   case ICmpInst::ICMP_UGT:
6611     // icmp ugt -> icmp ult
6612     std::swap(A, B);
6613     [[fallthrough]];
6614   case ICmpInst::ICMP_ULT:
6615     // icmp ult i1 A, B -> ~A & B
6616     return BinaryOperator::CreateAnd(Builder.CreateNot(A), B);
6617
6618   case ICmpInst::ICMP_SGT:
6619     // icmp sgt -> icmp slt
6620     std::swap(A, B);
6621     [[fallthrough]];
6622   case ICmpInst::ICMP_SLT:
6623     // icmp slt i1 A, B -> A & ~B
6624     return BinaryOperator::CreateAnd(Builder.CreateNot(B), A);
6625
6626   case ICmpInst::ICMP_UGE:
6627     // icmp uge -> icmp ule
6628     std::swap(A, B);
6629     [[fallthrough]];
6630   case ICmpInst::ICMP_ULE:
6631     // icmp ule i1 A, B -> ~A | B
6632     return BinaryOperator::CreateOr(Builder.CreateNot(A), B);
6633
6634   case ICmpInst::ICMP_SGE:
6635     // icmp sge -> icmp sle
6636     std::swap(A, B);
6637     [[fallthrough]];
6638   case ICmpInst::ICMP_SLE:
6639     // icmp sle i1 A, B -> A | ~B
6640     return BinaryOperator::CreateOr(Builder.CreateNot(B), A);
6641   }
6642 }
6643
6644 // Transform pattern like:
6645 //   (1 << Y) u<= X  or  ~(-1 << Y) u<  X  or  ((1 << Y)+(-1)) u<  X
6646 //   (1 << Y) u>  X  or  ~(-1 << Y) u>= X  or  ((1 << Y)+(-1)) u>= X
6647 // Into:
6648 //   (X l>> Y) != 0
6649 //   (X l>> Y) == 0
6650 static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp,
6651                                             InstCombiner::BuilderTy &Builder) {
6652   ICmpInst::Predicate Pred, NewPred;
6653   Value *X, *Y;
6654   if (match(&Cmp,
6655             m_c_ICmp(Pred, m_OneUse(m_Shl(m_One(), m_Value(Y))), m_Value(X)))) {
6656     switch (Pred) {
6657     case ICmpInst::ICMP_ULE:
6658       NewPred = ICmpInst::ICMP_NE;
6659       break;
6660     case ICmpInst::ICMP_UGT:
6661       NewPred = ICmpInst::ICMP_EQ;
6662       break;
6663     default:
6664       return nullptr;
6665     }
6666   } else if (match(&Cmp, m_c_ICmp(Pred,
6667                                   m_OneUse(m_CombineOr(
6668                                       m_Not(m_Shl(m_AllOnes(), m_Value(Y))),
6669                                       m_Add(m_Shl(m_One(), m_Value(Y)),
6670                                             m_AllOnes()))),
6671                                   m_Value(X)))) {
6672     // The variant with 'add' is not canonical (the variant with 'not' is);
6673     // we only get it here because it has extra uses and can't be canonicalized.
6674
6675     switch (Pred) {
6676     case ICmpInst::ICMP_ULT:
6677       NewPred = ICmpInst::ICMP_NE;
6678       break;
6679     case ICmpInst::ICMP_UGE:
6680       NewPred = ICmpInst::ICMP_EQ;
6681       break;
6682     default:
6683       return nullptr;
6684     }
6685   } else
6686     return nullptr;
6687
6688   Value *NewX = Builder.CreateLShr(X, Y, X->getName() + ".highbits");
6689   Constant *Zero = Constant::getNullValue(NewX->getType());
6690   return CmpInst::Create(Instruction::ICmp, NewPred, NewX, Zero);
6691 }
6692
6693 static Instruction *foldVectorCmp(CmpInst
&Cmp, 6694 InstCombiner::BuilderTy &Builder) { 6695 const CmpInst::Predicate Pred = Cmp.getPredicate(); 6696 Value *LHS = Cmp.getOperand(0), *RHS = Cmp.getOperand(1); 6697 Value *V1, *V2; 6698 6699 auto createCmpReverse = [&](CmpInst::Predicate Pred, Value *X, Value *Y) { 6700 Value *V = Builder.CreateCmp(Pred, X, Y, Cmp.getName()); 6701 if (auto *I = dyn_cast<Instruction>(V)) 6702 I->copyIRFlags(&Cmp); 6703 Module *M = Cmp.getModule(); 6704 Function *F = Intrinsic::getDeclaration( 6705 M, Intrinsic::experimental_vector_reverse, V->getType()); 6706 return CallInst::Create(F, V); 6707 }; 6708 6709 if (match(LHS, m_VecReverse(m_Value(V1)))) { 6710 // cmp Pred, rev(V1), rev(V2) --> rev(cmp Pred, V1, V2) 6711 if (match(RHS, m_VecReverse(m_Value(V2))) && 6712 (LHS->hasOneUse() || RHS->hasOneUse())) 6713 return createCmpReverse(Pred, V1, V2); 6714 6715 // cmp Pred, rev(V1), RHSSplat --> rev(cmp Pred, V1, RHSSplat) 6716 if (LHS->hasOneUse() && isSplatValue(RHS)) 6717 return createCmpReverse(Pred, V1, RHS); 6718 } 6719 // cmp Pred, LHSSplat, rev(V2) --> rev(cmp Pred, LHSSplat, V2) 6720 else if (isSplatValue(LHS) && match(RHS, m_OneUse(m_VecReverse(m_Value(V2))))) 6721 return createCmpReverse(Pred, LHS, V2); 6722 6723 ArrayRef<int> M; 6724 if (!match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(M)))) 6725 return nullptr; 6726 6727 // If both arguments of the cmp are shuffles that use the same mask and 6728 // shuffle within a single vector, move the shuffle after the cmp: 6729 // cmp (shuffle V1, M), (shuffle V2, M) --> shuffle (cmp V1, V2), M 6730 Type *V1Ty = V1->getType(); 6731 if (match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(M))) && 6732 V1Ty == V2->getType() && (LHS->hasOneUse() || RHS->hasOneUse())) { 6733 Value *NewCmp = Builder.CreateCmp(Pred, V1, V2); 6734 return new ShuffleVectorInst(NewCmp, M); 6735 } 6736 6737 // Try to canonicalize compare with splatted operand and splat constant. 6738 // TODO: We could generalize this for more than splats. See/use the code in 6739 // InstCombiner::foldVectorBinop(). 6740 Constant *C; 6741 if (!LHS->hasOneUse() || !match(RHS, m_Constant(C))) 6742 return nullptr; 6743 6744 // Length-changing splats are ok, so adjust the constants as needed: 6745 // cmp (shuffle V1, M), C --> shuffle (cmp V1, C'), M 6746 Constant *ScalarC = C->getSplatValue(/* AllowUndefs */ true); 6747 int MaskSplatIndex; 6748 if (ScalarC && match(M, m_SplatOrUndefMask(MaskSplatIndex))) { 6749 // We allow undefs in matching, but this transform removes those for safety. 6750 // Demanded elements analysis should be able to recover some/all of that. 
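    // E.g. (a sketch): comparing 'shufflevector <4 x i32> %v, undef,
    // <2 x i32> zeroinitializer' against a <2 x i32> splat constant can
    // instead compare %v against the constant re-splatted to <4 x i32> and
    // then apply the same splat shuffle to the <4 x i1> result.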
6751     C = ConstantVector::getSplat(cast<VectorType>(V1Ty)->getElementCount(),
6752                                  ScalarC);
6753     SmallVector<int, 8> NewM(M.size(), MaskSplatIndex);
6754     Value *NewCmp = Builder.CreateCmp(Pred, V1, C);
6755     return new ShuffleVectorInst(NewCmp, NewM);
6756   }
6757
6758   return nullptr;
6759 }
6760
6761 // extract(uadd.with.overflow(A, B), 0) ult A
6762 //  -> extract(uadd.with.overflow(A, B), 1)
6763 static Instruction *foldICmpOfUAddOv(ICmpInst &I) {
6764   CmpInst::Predicate Pred = I.getPredicate();
6765   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
6766
6767   Value *UAddOv;
6768   Value *A, *B;
6769   auto UAddOvResultPat = m_ExtractValue<0>(
6770       m_Intrinsic<Intrinsic::uadd_with_overflow>(m_Value(A), m_Value(B)));
6771   if (match(Op0, UAddOvResultPat) &&
6772       ((Pred == ICmpInst::ICMP_ULT && (Op1 == A || Op1 == B)) ||
6773        (Pred == ICmpInst::ICMP_EQ && match(Op1, m_ZeroInt()) &&
6774         (match(A, m_One()) || match(B, m_One()))) ||
6775        (Pred == ICmpInst::ICMP_NE && match(Op1, m_AllOnes()) &&
6776         (match(A, m_AllOnes()) || match(B, m_AllOnes())))))
6777     // extract(uadd.with.overflow(A, B), 0) < A
6778     // extract(uadd.with.overflow(A, 1), 0) == 0
6779     // extract(uadd.with.overflow(A, -1), 0) != -1
6780     UAddOv = cast<ExtractValueInst>(Op0)->getAggregateOperand();
6781   else if (match(Op1, UAddOvResultPat) &&
6782            Pred == ICmpInst::ICMP_UGT && (Op0 == A || Op0 == B))
6783     // A > extract(uadd.with.overflow(A, B), 0)
6784     UAddOv = cast<ExtractValueInst>(Op1)->getAggregateOperand();
6785   else
6786     return nullptr;
6787
6788   return ExtractValueInst::Create(UAddOv, 1);
6789 }
6790
6791 static Instruction *foldICmpInvariantGroup(ICmpInst &I) {
6792   if (!I.getOperand(0)->getType()->isPointerTy() ||
6793       NullPointerIsDefined(
6794           I.getParent()->getParent(),
6795           I.getOperand(0)->getType()->getPointerAddressSpace())) {
6796     return nullptr;
6797   }
6798   Instruction *Op;
6799   if (match(I.getOperand(0), m_Instruction(Op)) &&
6800       match(I.getOperand(1), m_Zero()) &&
6801       Op->isLaunderOrStripInvariantGroup()) {
6802     return ICmpInst::Create(Instruction::ICmp, I.getPredicate(),
6803                             Op->getOperand(0), I.getOperand(1));
6804   }
6805   return nullptr;
6806 }
6807
6808 /// This function folds patterns produced by lowering of reduce idioms, such as
6809 /// llvm.vector.reduce.and which are lowered into instruction chains. This code
6810 /// attempts to generate a smaller number of scalar comparisons instead of
6811 /// vector comparisons when possible.
6812 static Instruction *foldReductionIdiom(ICmpInst &I,
6813                                        InstCombiner::BuilderTy &Builder,
6814                                        const DataLayout &DL) {
6815   if (I.getType()->isVectorTy())
6816     return nullptr;
6817   ICmpInst::Predicate OuterPred, InnerPred;
6818   Value *LHS, *RHS;
6819
6820   // Match lowering of @llvm.vector.reduce.and. Turn
6821   //   %vec_ne = icmp ne <8 x i8> %lhs, %rhs
6822   //   %scalar_ne = bitcast <8 x i1> %vec_ne to i8
6823   //   %res = icmp <pred> i8 %scalar_ne, 0
6824   //
6825   // into
6826   //
6827   //   %lhs.scalar = bitcast <8 x i8> %lhs to i64
6828   //   %rhs.scalar = bitcast <8 x i8> %rhs to i64
6829   //   %res = icmp <pred> i64 %lhs.scalar, %rhs.scalar
6830   //
6831   // for <pred> in {ne, eq}.
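  // This is sound because the scalarized mask is zero exactly when every lane
  // compares equal, i.e. exactly when the concatenated bit patterns of %lhs
  // and %rhs are identical.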
6832   if (!match(&I, m_ICmp(OuterPred,
6833                         m_OneUse(m_BitCast(m_OneUse(
6834                             m_ICmp(InnerPred, m_Value(LHS), m_Value(RHS))))),
6835                         m_Zero())))
6836     return nullptr;
6837   auto *LHSTy = dyn_cast<FixedVectorType>(LHS->getType());
6838   if (!LHSTy || !LHSTy->getElementType()->isIntegerTy())
6839     return nullptr;
6840   unsigned NumBits =
6841       LHSTy->getNumElements() * LHSTy->getElementType()->getIntegerBitWidth();
6842   // TODO: Relax this to "not wider than max legal integer type"?
6843   if (!DL.isLegalInteger(NumBits))
6844     return nullptr;
6845
6846   if (ICmpInst::isEquality(OuterPred) && InnerPred == ICmpInst::ICMP_NE) {
6847     auto *ScalarTy = Builder.getIntNTy(NumBits);
6848     LHS = Builder.CreateBitCast(LHS, ScalarTy, LHS->getName() + ".scalar");
6849     RHS = Builder.CreateBitCast(RHS, ScalarTy, RHS->getName() + ".scalar");
6850     return ICmpInst::Create(Instruction::ICmp, OuterPred, LHS, RHS,
6851                             I.getName());
6852   }
6853
6854   return nullptr;
6855 }
6856
6857 // This helper will be called with icmp operands in both orders.
6858 Instruction *InstCombinerImpl::foldICmpCommutative(ICmpInst::Predicate Pred,
6859                                                    Value *Op0, Value *Op1,
6860                                                    ICmpInst &CxtI) {
6861   // Try to optimize 'icmp GEP, P' or 'icmp P, GEP'.
6862   if (auto *GEP = dyn_cast<GEPOperator>(Op0))
6863     if (Instruction *NI = foldGEPICmp(GEP, Op1, Pred, CxtI))
6864       return NI;
6865
6866   if (auto *SI = dyn_cast<SelectInst>(Op0))
6867     if (Instruction *NI = foldSelectICmp(Pred, SI, Op1, CxtI))
6868       return NI;
6869
6870   if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op0))
6871     if (Instruction *Res = foldICmpWithMinMax(CxtI, MinMax, Op1, Pred))
6872       return Res;
6873
6874   {
6875     Value *X;
6876     const APInt *C;
6877     // icmp X+Cst, X
6878     if (match(Op0, m_Add(m_Value(X), m_APInt(C))) && Op1 == X)
6879       return foldICmpAddOpConst(X, *C, Pred);
6880   }
6881
6882   // abs(X) >=  X --> true
6883   // abs(X) u<= X --> true
6884   // abs(X) <   X --> false
6885   // abs(X) u>  X --> false
6886   // abs(X) u>= X --> IsIntMinPoison ? `X > -1` : `X u<= INTMIN`
6887   // abs(X) <=  X --> IsIntMinPoison ? `X > -1` : `X u<= INTMIN`
6888   // abs(X) ==  X --> IsIntMinPoison ? `X > -1` : `X u<= INTMIN`
6889   // abs(X) u<  X --> IsIntMinPoison ? `X < 0`  : `X > INTMIN`
6890   // abs(X) >   X --> IsIntMinPoison ? `X < 0`  : `X > INTMIN`
6891   // abs(X) !=  X --> IsIntMinPoison ? `X < 0`  : `X > INTMIN`
6892   {
6893     Value *X;
6894     Constant *C;
6895     if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X), m_Constant(C))) &&
6896         match(Op1, m_Specific(X))) {
6897       Value *NullValue = Constant::getNullValue(X->getType());
6898       Value *AllOnesValue = Constant::getAllOnesValue(X->getType());
6899       const APInt SMin =
6900           APInt::getSignedMinValue(X->getType()->getScalarSizeInBits());
6901       bool IsIntMinPoison = C->isAllOnesValue();
6902       switch (Pred) {
6903       case CmpInst::ICMP_ULE:
6904       case CmpInst::ICMP_SGE:
6905         return replaceInstUsesWith(CxtI, ConstantInt::getTrue(CxtI.getType()));
6906       case CmpInst::ICMP_UGT:
6907       case CmpInst::ICMP_SLT:
6908         return replaceInstUsesWith(CxtI, ConstantInt::getFalse(CxtI.getType()));
6909       case CmpInst::ICMP_UGE:
6910       case CmpInst::ICMP_SLE:
6911       case CmpInst::ICMP_EQ: {
6912         return replaceInstUsesWith(
6913             CxtI, IsIntMinPoison
6914                       ? Builder.CreateICmpSGT(X, AllOnesValue)
6915                       : Builder.CreateICmpULT(
6916                             X, ConstantInt::get(X->getType(), SMin + 1)));
6917       }
6918       case CmpInst::ICMP_ULT:
6919       case CmpInst::ICMP_SGT:
6920       case CmpInst::ICMP_NE: {
6921         return replaceInstUsesWith(
6922             CxtI, IsIntMinPoison
6923                       ? Builder.CreateICmpSLT(X, NullValue)
6924                       : Builder.CreateICmpUGT(
6925                             X, ConstantInt::get(X->getType(), SMin)));
6926       }
6927       default:
6928         llvm_unreachable("Invalid predicate!");
6929       }
6930     }
6931   }
6932
6933   return nullptr;
6934 }
6935
6936 Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
6937   bool Changed = false;
6938   const SimplifyQuery Q = SQ.getWithInstruction(&I);
6939   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
6940   unsigned Op0Cplxity = getComplexity(Op0);
6941   unsigned Op1Cplxity = getComplexity(Op1);
6942
6943   // Order the operands of the compare so that they are listed from most
6944   // complex to least complex. This puts binary operators before unary
6945   // operators, and those before constants.
6946   if (Op0Cplxity < Op1Cplxity) {
6947     I.swapOperands();
6948     std::swap(Op0, Op1);
6949     Changed = true;
6950   }
6951
6952   if (Value *V = simplifyICmpInst(I.getPredicate(), Op0, Op1, Q))
6953     return replaceInstUsesWith(I, V);
6954
6955   // Comparing -val or val against zero for inequality is the same as just
6956   // comparing val, i.e. abs(val) != 0 -> val != 0
6957   if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero())) {
6958     Value *Cond, *SelectTrue, *SelectFalse;
6959     if (match(Op0, m_Select(m_Value(Cond), m_Value(SelectTrue),
6960                             m_Value(SelectFalse)))) {
6961       if (Value *V = dyn_castNegVal(SelectTrue)) {
6962         if (V == SelectFalse)
6963           return CmpInst::Create(Instruction::ICmp, I.getPredicate(), V, Op1);
6964       }
6965       else if (Value *V = dyn_castNegVal(SelectFalse)) {
6966         if (V == SelectTrue)
6967           return CmpInst::Create(Instruction::ICmp, I.getPredicate(), V, Op1);
6968       }
6969     }
6970   }
6971
6972   if (Op0->getType()->isIntOrIntVectorTy(1))
6973     if (Instruction *Res = canonicalizeICmpBool(I, Builder))
6974       return Res;
6975
6976   if (Instruction *Res = canonicalizeCmpWithConstant(I))
6977     return Res;
6978
6979   if (Instruction *Res = canonicalizeICmpPredicate(I))
6980     return Res;
6981
6982   if (Instruction *Res = foldICmpWithConstant(I))
6983     return Res;
6984
6985   if (Instruction *Res = foldICmpWithDominatingICmp(I))
6986     return Res;
6987
6988   if (Instruction *Res = foldICmpUsingBoolRange(I))
6989     return Res;
6990
6991   if (Instruction *Res = foldICmpUsingKnownBits(I))
6992     return Res;
6993
6994   if (Instruction *Res = foldICmpTruncWithTruncOrExt(I, Q))
6995     return Res;
6996
6997   // Test if the ICmpInst instruction is used exclusively by a select as
6998   // part of a minimum or maximum operation. If so, refrain from doing
6999   // any other folding. This helps out other analyses which understand
7000   // non-obfuscated minimum and maximum idioms, such as ScalarEvolution
7001   // and CodeGen. And in this case, at least one of the comparison
7002   // operands has at least one user besides the compare (the select),
7003   // which would often largely negate the benefit of folding anyway.
7004   //
7005   // Do the same for the other patterns recognized by matchSelectPattern.
7006   if (I.hasOneUse())
7007     if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) {
7008       Value *A, *B;
7009       SelectPatternResult SPR = matchSelectPattern(SI, A, B);
7010       if (SPR.Flavor != SPF_UNKNOWN)
7011         return nullptr;
7012     }
7013
7014   // Do this after checking for min/max to prevent infinite looping.
7015   if (Instruction *Res = foldICmpWithZero(I))
7016     return Res;
7017
7018   // FIXME: We only do this after checking for min/max to prevent infinite
7019   // looping caused by a reverse canonicalization of these patterns for min/max.
7020   // FIXME: The organization of folds is a mess. These would naturally go into
7021   // canonicalizeCmpWithConstant(), but we can't move all of the above folds
7022   // down here after the min/max restriction.
7023   ICmpInst::Predicate Pred = I.getPredicate();
7024   const APInt *C;
7025   if (match(Op1, m_APInt(C))) {
7026     // For i32: x >u 2147483647 -> x <s 0 -> true if sign bit set
7027     if (Pred == ICmpInst::ICMP_UGT && C->isMaxSignedValue()) {
7028       Constant *Zero = Constant::getNullValue(Op0->getType());
7029       return new ICmpInst(ICmpInst::ICMP_SLT, Op0, Zero);
7030     }
7031
7032     // For i32: x <u 2147483648 -> x >s -1 -> true if sign bit clear
7033     if (Pred == ICmpInst::ICMP_ULT && C->isMinSignedValue()) {
7034       Constant *AllOnes = Constant::getAllOnesValue(Op0->getType());
7035       return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes);
7036     }
7037   }
7038
7039   // The folds in here may rely on wrapping flags and special constants, so
7040   // they can break up min/max idioms in some cases but not seemingly similar
7041   // patterns.
7042   // FIXME: It may be possible to enhance select folding to make this
7043   // unnecessary. It may also be moot if we canonicalize to min/max
7044   // intrinsics.
7045   if (Instruction *Res = foldICmpBinOp(I, Q))
7046     return Res;
7047
7048   if (Instruction *Res = foldICmpInstWithConstant(I))
7049     return Res;
7050
7051   // Try to match the comparison as a sign bit test. Intentionally do this
7052   // after foldICmpInstWithConstant() to let other folds happen first.
7053   if (Instruction *New = foldSignBitTest(I))
7054     return New;
7055
7056   if (Instruction *Res = foldICmpInstWithConstantNotInt(I))
7057     return Res;
7058
7059   if (Instruction *Res = foldICmpCommutative(I.getPredicate(), Op0, Op1, I))
7060     return Res;
7061   if (Instruction *Res =
7062           foldICmpCommutative(I.getSwappedPredicate(), Op1, Op0, I))
7063     return Res;
7064
7065   // In case of a comparison with two select instructions having the same
7066   // condition, check whether one of the resulting branches can be simplified.
7067   // If so, just compare the other branch and select the appropriate result.
7068   // For example:
7069   //   %tmp1 = select i1 %cmp, i32 %y, i32 %x
7070   //   %tmp2 = select i1 %cmp, i32 %z, i32 %x
7071   //   %cmp2 = icmp slt i32 %tmp2, %tmp1
7072   // The icmp evaluates to false for the false values of the selects, and the
7073   // result depends on the comparison of the true values of the selects when
7074   // %cmp is true. Thus, transform this into:
7075   //   %cmp2 = icmp slt i32 %z, %y
7076   //   %sel = select i1 %cmp, i1 %cmp2, i1 false
7077   // The case where the false values simplify is handled symmetrically.
7078   {
7079     Value *Cond, *A, *B, *C, *D;
7080     if (match(Op0, m_Select(m_Value(Cond), m_Value(A), m_Value(B))) &&
7081         match(Op1, m_Select(m_Specific(Cond), m_Value(C), m_Value(D))) &&
7082         (Op0->hasOneUse() || Op1->hasOneUse())) {
7083       // Check whether comparison of TrueValues can be simplified
7084       if (Value *Res = simplifyICmpInst(Pred, A, C, SQ)) {
7085         Value *NewICMP = Builder.CreateICmp(Pred, B, D);
7086         return SelectInst::Create(Cond, Res, NewICMP);
7087       }
7088       // Check whether comparison of FalseValues can be simplified
7089       if (Value *Res = simplifyICmpInst(Pred, B, D, SQ)) {
7090         Value *NewICMP = Builder.CreateICmp(Pred, A, C);
7091         return SelectInst::Create(Cond, NewICMP, Res);
7092       }
7093     }
7094   }
7095
7096   // Try to optimize equality comparisons against alloca-based pointers.
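  // E.g., the address of a non-escaping alloca cannot be observed elsewhere,
  // so a comparison such as 'icmp eq ptr %a, %p' against an unrelated pointer
  // can often be folded away entirely (a sketch; foldAllocaCmp performs the
  // detailed legality checks).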
7097   if (Op0->getType()->isPointerTy() && I.isEquality()) {
7098     assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?");
7099     if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op0)))
7100       if (foldAllocaCmp(Alloca))
7101         return nullptr;
7102     if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op1)))
7103       if (foldAllocaCmp(Alloca))
7104         return nullptr;
7105   }
7106
7107   if (Instruction *Res = foldICmpBitCast(I))
7108     return Res;
7109
7110   // TODO: Hoist this above the min/max bailout.
7111   if (Instruction *R = foldICmpWithCastOp(I))
7112     return R;
7113
7114   {
7115     Value *X, *Y;
7116     // Transform (X & ~Y) == 0 --> (X & Y) != 0
7117     // and       (X & ~Y) != 0 --> (X & Y) == 0
7118     // if X is a power of 2.
7119     if (match(Op0, m_And(m_Value(X), m_Not(m_Value(Y)))) &&
7120         match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(X, false, 0, &I) &&
7121         I.isEquality())
7122       return new ICmpInst(I.getInversePredicate(), Builder.CreateAnd(X, Y),
7123                           Op1);
7124
7125     // Op0 pred Op1 -> ~Op1 pred ~Op0, if this allows us to drop an instruction.
7126     if (Op0->getType()->isIntOrIntVectorTy()) {
7127       bool ConsumesOp0, ConsumesOp1;
7128       if (isFreeToInvert(Op0, Op0->hasOneUse(), ConsumesOp0) &&
7129           isFreeToInvert(Op1, Op1->hasOneUse(), ConsumesOp1) &&
7130           (ConsumesOp0 || ConsumesOp1)) {
7131         Value *InvOp0 = getFreelyInverted(Op0, Op0->hasOneUse(), &Builder);
7132         Value *InvOp1 = getFreelyInverted(Op1, Op1->hasOneUse(), &Builder);
7133         assert(InvOp0 && InvOp1 &&
7134                "Mismatch between isFreeToInvert and getFreelyInverted");
7135         return new ICmpInst(I.getSwappedPredicate(), InvOp0, InvOp1);
7136       }
7137     }
7138
7139     Instruction *AddI = nullptr;
7140     if (match(&I, m_UAddWithOverflow(m_Value(X), m_Value(Y),
7141                                      m_Instruction(AddI))) &&
7142         isa<IntegerType>(X->getType())) {
7143       Value *Result;
7144       Constant *Overflow;
7145       // m_UAddWithOverflow can match patterns that do not include an explicit
7146       // "add" instruction, so check the opcode of the matched op.
7147       if (AddI->getOpcode() == Instruction::Add &&
7148           OptimizeOverflowCheck(Instruction::Add, /*Signed*/ false, X, Y, *AddI,
7149                                 Result, Overflow)) {
7150         replaceInstUsesWith(*AddI, Result);
7151         eraseInstFromFunction(*AddI);
7152         return replaceInstUsesWith(I, Overflow);
7153       }
7154     }
7155
7156     // (zext X) * (zext Y) --> llvm.umul.with.overflow.
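    // A sketch of the idiom with i32 operands extended to i64:
    //   %mul = mul i64 (zext %x), (zext %y)
    //   %ov  = icmp ugt i64 %mul, 4294967295 ; UINT32_MAX
    // The compare is true exactly when the 32-bit product overflows, i.e. it
    // is the i1 flag of llvm.umul.with.overflow.i32(%x, %y).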
7157 if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) && 7158 match(Op1, m_APInt(C))) { 7159 if (Instruction *R = processUMulZExtIdiom(I, Op0, C, *this)) 7160 return R; 7161 } 7162 7163 // Signbit test folds 7164 // Fold (X u>> BitWidth - 1 Pred ZExt(i1)) --> X s< 0 Pred i1 7165 // Fold (X s>> BitWidth - 1 Pred SExt(i1)) --> X s< 0 Pred i1 7166 Instruction *ExtI; 7167 if ((I.isUnsigned() || I.isEquality()) && 7168 match(Op1, 7169 m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(Y)))) && 7170 Y->getType()->getScalarSizeInBits() == 1 && 7171 (Op0->hasOneUse() || Op1->hasOneUse())) { 7172 unsigned OpWidth = Op0->getType()->getScalarSizeInBits(); 7173 Instruction *ShiftI; 7174 if (match(Op0, m_CombineAnd(m_Instruction(ShiftI), 7175 m_Shr(m_Value(X), m_SpecificIntAllowUndef( 7176 OpWidth - 1))))) { 7177 unsigned ExtOpc = ExtI->getOpcode(); 7178 unsigned ShiftOpc = ShiftI->getOpcode(); 7179 if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) || 7180 (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) { 7181 Value *SLTZero = 7182 Builder.CreateICmpSLT(X, Constant::getNullValue(X->getType())); 7183 Value *Cmp = Builder.CreateICmp(Pred, SLTZero, Y, I.getName()); 7184 return replaceInstUsesWith(I, Cmp); 7185 } 7186 } 7187 } 7188 } 7189 7190 if (Instruction *Res = foldICmpEquality(I)) 7191 return Res; 7192 7193 if (Instruction *Res = foldICmpPow2Test(I, Builder)) 7194 return Res; 7195 7196 if (Instruction *Res = foldICmpOfUAddOv(I)) 7197 return Res; 7198 7199 // The 'cmpxchg' instruction returns an aggregate containing the old value and 7200 // an i1 which indicates whether or not we successfully did the swap. 7201 // 7202 // Replace comparisons between the old value and the expected value with the 7203 // indicator that 'cmpxchg' returns. 7204 // 7205 // N.B. This transform is only valid when the 'cmpxchg' is not permitted to 7206 // spuriously fail. In those cases, the old value may equal the expected 7207 // value but it is possible for the swap to not occur. 7208 if (I.getPredicate() == ICmpInst::ICMP_EQ) 7209 if (auto *EVI = dyn_cast<ExtractValueInst>(Op0)) 7210 if (auto *ACXI = dyn_cast<AtomicCmpXchgInst>(EVI->getAggregateOperand())) 7211 if (EVI->getIndices()[0] == 0 && ACXI->getCompareOperand() == Op1 && 7212 !ACXI->isWeak()) 7213 return ExtractValueInst::Create(ACXI, 1); 7214 7215 if (Instruction *Res = foldICmpWithHighBitMask(I, Builder)) 7216 return Res; 7217 7218 if (I.getType()->isVectorTy()) 7219 if (Instruction *Res = foldVectorCmp(I, Builder)) 7220 return Res; 7221 7222 if (Instruction *Res = foldICmpInvariantGroup(I)) 7223 return Res; 7224 7225 if (Instruction *Res = foldReductionIdiom(I, Builder, DL)) 7226 return Res; 7227 7228 return Changed ? &I : nullptr; 7229 } 7230 7231 /// Fold fcmp ([us]itofp x, cst) if possible. 7232 Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, 7233 Instruction *LHSI, 7234 Constant *RHSC) { 7235 if (!isa<ConstantFP>(RHSC)) return nullptr; 7236 const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF(); 7237 7238 // Get the width of the mantissa. We don't want to hack on conversions that 7239 // might lose information from the integer, e.g. "i64 -> float" 7240 int MantissaWidth = LHSI->getType()->getFPMantissaWidth(); 7241 if (MantissaWidth == -1) return nullptr; // Unknown. 
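  // E.g., IEEE 'float' has an effective mantissa of 24 bits, so i32 -> float
  // can be lossy (16777217 is not representable), while i16 -> float is
  // always exact.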
7242
7243   IntegerType *IntTy = cast<IntegerType>(LHSI->getOperand(0)->getType());
7244
7245   bool LHSUnsigned = isa<UIToFPInst>(LHSI);
7246
7247   if (I.isEquality()) {
7248     FCmpInst::Predicate P = I.getPredicate();
7249     bool IsExact = false;
7250     APSInt RHSCvt(IntTy->getBitWidth(), LHSUnsigned);
7251     RHS.convertToInteger(RHSCvt, APFloat::rmNearestTiesToEven, &IsExact);
7252
7253     // If the floating-point constant isn't an integer value, we know whether
7254     // it can ever compare equal / not equal to the converted value.
7255     if (!IsExact) {
7256       // TODO: Can never be -0.0 and other non-representable values
7257       APFloat RHSRoundInt(RHS);
7258       RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven);
7259       if (RHS != RHSRoundInt) {
7260         if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ)
7261           return replaceInstUsesWith(I, Builder.getFalse());
7262
7263         assert(P == FCmpInst::FCMP_ONE || P == FCmpInst::FCMP_UNE);
7264         return replaceInstUsesWith(I, Builder.getTrue());
7265       }
7266     }
7267
7268     // TODO: If the constant is exactly representable, is it always OK to do
7269     // equality compares as integer?
7270   }
7271
7272   // Check to see that the input is converted from an integer type that is
7273   // small enough that the conversion preserves all bits. TODO: check for
7274   // "known" sign bits here; e.g. (fptosi (x >>s 62) to float) with i64 x.
7275   unsigned InputSize = IntTy->getScalarSizeInBits();
7276
7277   // The following test does NOT adjust InputSize downwards for signed inputs,
7278   // because the most negative value still requires all the mantissa bits
7279   // to distinguish it from one less than that value.
7280   if ((int)InputSize > MantissaWidth) {
7281     // Conversion would lose accuracy. Check if loss can impact comparison.
7282     int Exp = ilogb(RHS);
7283     if (Exp == APFloat::IEK_Inf) {
7284       int MaxExponent = ilogb(APFloat::getLargest(RHS.getSemantics()));
7285       if (MaxExponent < (int)InputSize - !LHSUnsigned)
7286         // Conversion could create infinity.
7287         return nullptr;
7288     } else {
7289       // Note that if RHS is zero or NaN, then Exp is negative
7290       // and the first condition is trivially false.
7291       if (MantissaWidth <= Exp && Exp <= (int)InputSize - !LHSUnsigned)
7292         // Conversion could affect comparison.
7293         return nullptr;
7294     }
7295   }
7296
7297   // Otherwise, we can potentially simplify the comparison. We know that it
7298   // will always come through as an integer value and we know the constant is
7299   // not a NaN (it would have been previously simplified).
7300   assert(!RHS.isNaN() && "NaN comparison not already folded!");
7301
7302   ICmpInst::Predicate Pred;
7303   switch (I.getPredicate()) {
7304   default: llvm_unreachable("Unexpected predicate!");
7305   case FCmpInst::FCMP_UEQ:
7306   case FCmpInst::FCMP_OEQ:
7307     Pred = ICmpInst::ICMP_EQ;
7308     break;
7309   case FCmpInst::FCMP_UGT:
7310   case FCmpInst::FCMP_OGT:
7311     Pred = LHSUnsigned ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_SGT;
7312     break;
7313   case FCmpInst::FCMP_UGE:
7314   case FCmpInst::FCMP_OGE:
7315     Pred = LHSUnsigned ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_SGE;
7316     break;
7317   case FCmpInst::FCMP_ULT:
7318   case FCmpInst::FCMP_OLT:
7319     Pred = LHSUnsigned ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_SLT;
7320     break;
7321   case FCmpInst::FCMP_ULE:
7322   case FCmpInst::FCMP_OLE:
7323     Pred = LHSUnsigned ?
  ICmpInst::Predicate Pred;
  switch (I.getPredicate()) {
  default:
    llvm_unreachable("Unexpected predicate!");
  case FCmpInst::FCMP_UEQ:
  case FCmpInst::FCMP_OEQ:
    Pred = ICmpInst::ICMP_EQ;
    break;
  case FCmpInst::FCMP_UGT:
  case FCmpInst::FCMP_OGT:
    Pred = LHSUnsigned ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_SGT;
    break;
  case FCmpInst::FCMP_UGE:
  case FCmpInst::FCMP_OGE:
    Pred = LHSUnsigned ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_SGE;
    break;
  case FCmpInst::FCMP_ULT:
  case FCmpInst::FCMP_OLT:
    Pred = LHSUnsigned ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_SLT;
    break;
  case FCmpInst::FCMP_ULE:
  case FCmpInst::FCMP_OLE:
    Pred = LHSUnsigned ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
    break;
  case FCmpInst::FCMP_UNE:
  case FCmpInst::FCMP_ONE:
    Pred = ICmpInst::ICMP_NE;
    break;
  case FCmpInst::FCMP_ORD:
    return replaceInstUsesWith(I, Builder.getTrue());
  case FCmpInst::FCMP_UNO:
    return replaceInstUsesWith(I, Builder.getFalse());
  }

  // Now we know that the APFloat is a normal number, zero or inf.

  // See if the FP constant is too large for the integer. For example,
  // comparing an i8 to 300.0.
  unsigned IntWidth = IntTy->getScalarSizeInBits();

  if (!LHSUnsigned) {
    // If the RHS value is > SignedMax, fold the comparison. This handles +INF
    // and large values.
    APFloat SMax(RHS.getSemantics());
    SMax.convertFromAPInt(APInt::getSignedMaxValue(IntWidth), true,
                          APFloat::rmNearestTiesToEven);
    if (SMax < RHS) { // smax < 13123.0
      if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SLT ||
          Pred == ICmpInst::ICMP_SLE)
        return replaceInstUsesWith(I, Builder.getTrue());
      return replaceInstUsesWith(I, Builder.getFalse());
    }
  } else {
    // If the RHS value is > UnsignedMax, fold the comparison. This handles
    // +INF and large values.
    APFloat UMax(RHS.getSemantics());
    UMax.convertFromAPInt(APInt::getMaxValue(IntWidth), false,
                          APFloat::rmNearestTiesToEven);
    if (UMax < RHS) { // umax < 13123.0
      if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_ULT ||
          Pred == ICmpInst::ICMP_ULE)
        return replaceInstUsesWith(I, Builder.getTrue());
      return replaceInstUsesWith(I, Builder.getFalse());
    }
  }

  if (!LHSUnsigned) {
    // See if the RHS value is < SignedMin.
    APFloat SMin(RHS.getSemantics());
    SMin.convertFromAPInt(APInt::getSignedMinValue(IntWidth), true,
                          APFloat::rmNearestTiesToEven);
    if (SMin > RHS) { // smin > 12312.0
      if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT ||
          Pred == ICmpInst::ICMP_SGE)
        return replaceInstUsesWith(I, Builder.getTrue());
      return replaceInstUsesWith(I, Builder.getFalse());
    }
  } else {
    // See if the RHS value is < UnsignedMin.
    APFloat UMin(RHS.getSemantics());
    UMin.convertFromAPInt(APInt::getMinValue(IntWidth), false,
                          APFloat::rmNearestTiesToEven);
    if (UMin > RHS) { // umin > 12312.0
      if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT ||
          Pred == ICmpInst::ICMP_UGE)
        return replaceInstUsesWith(I, Builder.getTrue());
      return replaceInstUsesWith(I, Builder.getFalse());
    }
  }

  // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
  // [0, UMAX], but it may still be fractional. Check whether this is the case
  // using the IsExact flag.
  // Don't do this for zero, because -0.0 is not fractional.
  APSInt RHSInt(IntWidth, LHSUnsigned);
  bool IsExact;
  RHS.convertToInteger(RHSInt, APFloat::rmTowardZero, &IsExact);
  if (!RHS.isZero()) {
    if (!IsExact) {
      // If we had a comparison against a fractional value, we have to adjust
      // the compare predicate and sometimes the value. RHSC is rounded towards
      // zero at this point.
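      // Worked example: with RHS = 4.4, rmTowardZero gives RHSInt = 4, and
      // 'int <= 4.4' is equivalent to 'int <= 4', so ULE/SLE are unchanged.
      // With RHS = -4.4, RHSInt = -4, but 'int <= -4.4' means 'int <= -5',
      // so SLE must tighten to SLT; the unsigned form is simply false
      // because an unsigned value is never negative.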
      switch (Pred) {
      default:
        llvm_unreachable("Unexpected integer comparison!");
      case ICmpInst::ICMP_NE: // (float)int != 4.4 --> true
        return replaceInstUsesWith(I, Builder.getTrue());
      case ICmpInst::ICMP_EQ: // (float)int == 4.4 --> false
        return replaceInstUsesWith(I, Builder.getFalse());
      case ICmpInst::ICMP_ULE:
        // (float)int <= 4.4 --> int <= 4
        // (float)int <= -4.4 --> false
        if (RHS.isNegative())
          return replaceInstUsesWith(I, Builder.getFalse());
        break;
      case ICmpInst::ICMP_SLE:
        // (float)int <= 4.4 --> int <= 4
        // (float)int <= -4.4 --> int < -4
        if (RHS.isNegative())
          Pred = ICmpInst::ICMP_SLT;
        break;
      case ICmpInst::ICMP_ULT:
        // (float)int < -4.4 --> false
        // (float)int < 4.4 --> int <= 4
        if (RHS.isNegative())
          return replaceInstUsesWith(I, Builder.getFalse());
        Pred = ICmpInst::ICMP_ULE;
        break;
      case ICmpInst::ICMP_SLT:
        // (float)int < -4.4 --> int < -4
        // (float)int < 4.4 --> int <= 4
        if (!RHS.isNegative())
          Pred = ICmpInst::ICMP_SLE;
        break;
      case ICmpInst::ICMP_UGT:
        // (float)int > 4.4 --> int > 4
        // (float)int > -4.4 --> true
        if (RHS.isNegative())
          return replaceInstUsesWith(I, Builder.getTrue());
        break;
      case ICmpInst::ICMP_SGT:
        // (float)int > 4.4 --> int > 4
        // (float)int > -4.4 --> int >= -4
        if (RHS.isNegative())
          Pred = ICmpInst::ICMP_SGE;
        break;
      case ICmpInst::ICMP_UGE:
        // (float)int >= -4.4 --> true
        // (float)int >= 4.4 --> int > 4
        if (RHS.isNegative())
          return replaceInstUsesWith(I, Builder.getTrue());
        Pred = ICmpInst::ICMP_UGT;
        break;
      case ICmpInst::ICMP_SGE:
        // (float)int >= -4.4 --> int >= -4
        // (float)int >= 4.4 --> int > 4
        if (!RHS.isNegative())
          Pred = ICmpInst::ICMP_SGT;
        break;
      }
    }
  }

  // Lower this FP comparison into an appropriate integer version of the
  // comparison.
  return new ICmpInst(Pred, LHSI->getOperand(0), Builder.getInt(RHSInt));
}

/// Fold (C / X) < 0.0 --> X < 0.0 if possible. Swap predicate if necessary.
static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI,
                                              Constant *RHSC) {
  // When C is not 0.0 and infinities are not allowed:
  // (C / X) < 0.0 is a sign-bit test of X
  // (C / X) < 0.0 --> X < 0.0 (if C is positive)
  // (C / X) < 0.0 --> X > 0.0 (if C is negative, swap the predicate)
  //
  // Proof:
  // Multiply (C / X) < 0.0 by X * X / C.
  // - X is non-zero; if it were zero, the 'ninf' flag would be violated.
  // - C determines the sign of X * X / C, and thus whether the predicate
  //   needs to be swapped. C is also non-zero by definition.
  //
  // Thus X * X / C is non-zero and the transformation is valid. [qed]

  FCmpInst::Predicate Pred = I.getPredicate();

  // Check that predicates are valid.
  if ((Pred != FCmpInst::FCMP_OGT) && (Pred != FCmpInst::FCMP_OLT) &&
      (Pred != FCmpInst::FCMP_OGE) && (Pred != FCmpInst::FCMP_OLE))
    return nullptr;

  // Check that RHS operand is zero.
  if (!match(RHSC, m_AnyZeroFP()))
    return nullptr;

  // Check fastmath flags ('ninf').
  if (!LHSI->hasNoInfs() || !I.hasNoInfs())
    return nullptr;

  // Check the properties of the dividend. It must not be zero to avoid a
  // division by zero (see Proof).
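  // For illustration: if C were 0.0, then (0.0 / X) is +/-0.0 for every
  // finite non-zero X, so '(0.0 / X) < 0.0' is always false even when
  // 'X < 0.0' holds, and the fold would be incorrect.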
  const APFloat *C;
  if (!match(LHSI->getOperand(0), m_APFloat(C)))
    return nullptr;

  if (C->isZero())
    return nullptr;

  // Get swapped predicate if necessary.
  if (C->isNegative())
    Pred = I.getSwappedPredicate();

  return new FCmpInst(Pred, LHSI->getOperand(1), RHSC, "", &I);
}

/// Optimize fabs(X) compared with zero.
static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
  Value *X;
  if (!match(I.getOperand(0), m_FAbs(m_Value(X))))
    return nullptr;

  const APFloat *C;
  if (!match(I.getOperand(1), m_APFloat(C)))
    return nullptr;

  if (!C->isPosZero()) {
    if (!C->isSmallestNormalized())
      return nullptr;

    // When denormal inputs are treated as zero (PreserveSign/PositiveZero),
    // fabs(X) is below the smallest normalized number exactly when X compares
    // equal to zero, so these range checks become equality tests against 0.0.
    const Function *F = I.getFunction();
    DenormalMode Mode = F->getDenormalMode(C->getSemantics());
    if (Mode.Input == DenormalMode::PreserveSign ||
        Mode.Input == DenormalMode::PositiveZero) {

      auto replaceFCmp = [](FCmpInst *I, FCmpInst::Predicate P, Value *X) {
        Constant *Zero = ConstantFP::getZero(X->getType());
        return new FCmpInst(P, X, Zero, "", I);
      };

      switch (I.getPredicate()) {
      case FCmpInst::FCMP_OLT:
        // fcmp olt fabs(x), smallest_normalized_number -> fcmp oeq x, 0.0
        return replaceFCmp(&I, FCmpInst::FCMP_OEQ, X);
      case FCmpInst::FCMP_UGE:
        // fcmp uge fabs(x), smallest_normalized_number -> fcmp une x, 0.0
        return replaceFCmp(&I, FCmpInst::FCMP_UNE, X);
      case FCmpInst::FCMP_OGE:
        // fcmp oge fabs(x), smallest_normalized_number -> fcmp one x, 0.0
        return replaceFCmp(&I, FCmpInst::FCMP_ONE, X);
      case FCmpInst::FCMP_ULT:
        // fcmp ult fabs(x), smallest_normalized_number -> fcmp ueq x, 0.0
        return replaceFCmp(&I, FCmpInst::FCMP_UEQ, X);
      default:
        break;
      }
    }

    return nullptr;
  }

  auto replacePredAndOp0 = [&IC](FCmpInst *I, FCmpInst::Predicate P, Value *X) {
    I->setPredicate(P);
    return IC.replaceOperand(*I, 0, X);
  };

  switch (I.getPredicate()) {
  case FCmpInst::FCMP_UGE:
  case FCmpInst::FCMP_OLT:
    // fabs(X) >= 0.0 --> true
    // fabs(X) < 0.0 --> false
    llvm_unreachable("fcmp should have simplified");

  case FCmpInst::FCMP_OGT:
    // fabs(X) > 0.0 --> X != 0.0
    return replacePredAndOp0(&I, FCmpInst::FCMP_ONE, X);

  case FCmpInst::FCMP_UGT:
    // fabs(X) u> 0.0 --> X u!= 0.0
    return replacePredAndOp0(&I, FCmpInst::FCMP_UNE, X);

  case FCmpInst::FCMP_OLE:
    // fabs(X) <= 0.0 --> X == 0.0
    return replacePredAndOp0(&I, FCmpInst::FCMP_OEQ, X);

  case FCmpInst::FCMP_ULE:
    // fabs(X) u<= 0.0 --> X u== 0.0
    return replacePredAndOp0(&I, FCmpInst::FCMP_UEQ, X);

  case FCmpInst::FCMP_OGE:
    // fabs(X) >= 0.0 --> !isnan(X)
    assert(!I.hasNoNaNs() && "fcmp should have simplified");
    return replacePredAndOp0(&I, FCmpInst::FCMP_ORD, X);

  case FCmpInst::FCMP_ULT:
    // fabs(X) u< 0.0 --> isnan(X)
    assert(!I.hasNoNaNs() && "fcmp should have simplified");
    return replacePredAndOp0(&I, FCmpInst::FCMP_UNO, X);

  case FCmpInst::FCMP_OEQ:
  case FCmpInst::FCMP_UEQ:
  case FCmpInst::FCMP_ONE:
  case FCmpInst::FCMP_UNE:
  case FCmpInst::FCMP_ORD:
  case FCmpInst::FCMP_UNO:
    // Look through the fabs() because it doesn't change anything but the sign:
    // fabs(X) == 0.0 --> X == 0.0
    // fabs(X) != 0.0 --> X != 0.0
    // isnan(fabs(X)) --> isnan(X)
    // !isnan(fabs(X)) --> !isnan(X)
    return replacePredAndOp0(&I, I.getPredicate(), X);

  default:
    return nullptr;
  }
}

static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) {
  CmpInst::Predicate Pred = I.getPredicate();
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);

  // Canonicalize fneg as Op1.
  if (match(Op0, m_FNeg(m_Value())) && !match(Op1, m_FNeg(m_Value()))) {
    std::swap(Op0, Op1);
    Pred = I.getSwappedPredicate();
  }

  if (!match(Op1, m_FNeg(m_Specific(Op0))))
    return nullptr;

  // Replace the negated operand with 0.0:
  // fcmp Pred Op0, -Op0 --> fcmp Pred Op0, 0.0
  Constant *Zero = ConstantFP::getZero(Op0->getType());
  return new FCmpInst(Pred, Op0, Zero, "", &I);
}

Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
  bool Changed = false;

  /// Orders the operands of the compare so that they are listed from most
  /// complex to least complex. This puts binary operators before unary
  /// operators, and those before constants.
  if (getComplexity(I.getOperand(0)) < getComplexity(I.getOperand(1))) {
    I.swapOperands();
    Changed = true;
  }

  const CmpInst::Predicate Pred = I.getPredicate();
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  if (Value *V = simplifyFCmpInst(Pred, Op0, Op1, I.getFastMathFlags(),
                                  SQ.getWithInstruction(&I)))
    return replaceInstUsesWith(I, V);

  // Simplify 'fcmp pred X, X'
  Type *OpType = Op0->getType();
  assert(OpType == Op1->getType() && "fcmp with different-typed operands?");
  if (Op0 == Op1) {
    switch (Pred) {
    default:
      break;
    case FCmpInst::FCMP_UNO: // True if unordered: isnan(X) | isnan(Y)
    case FCmpInst::FCMP_ULT: // True if unordered or less than
    case FCmpInst::FCMP_UGT: // True if unordered or greater than
    case FCmpInst::FCMP_UNE: // True if unordered or not equal
      // Canonicalize these to be 'fcmp uno %X, 0.0'.
      I.setPredicate(FCmpInst::FCMP_UNO);
      I.setOperand(1, Constant::getNullValue(OpType));
      return &I;

    case FCmpInst::FCMP_ORD: // True if ordered (no nans)
    case FCmpInst::FCMP_OEQ: // True if ordered and equal
    case FCmpInst::FCMP_OGE: // True if ordered and greater than or equal
    case FCmpInst::FCMP_OLE: // True if ordered and less than or equal
      // Canonicalize these to be 'fcmp ord %X, 0.0'.
      I.setPredicate(FCmpInst::FCMP_ORD);
      I.setOperand(1, Constant::getNullValue(OpType));
      return &I;
    }
  }

  // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand,
  // then canonicalize the operand to 0.0.
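  // For example, 'fcmp uno %x, %y' where %y is known never to be NaN only
  // tests %x, so %y is replaced by +0.0; all NaN checks then share the
  // canonical 'fcmp uno %x, 0.0' form, which helps CSE.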
  if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) {
    if (!match(Op0, m_PosZeroFP()) &&
        isKnownNeverNaN(Op0, DL, &TLI, 0, &AC, &I, &DT))
      return replaceOperand(I, 0, ConstantFP::getZero(OpType));

    if (!match(Op1, m_PosZeroFP()) &&
        isKnownNeverNaN(Op1, DL, &TLI, 0, &AC, &I, &DT))
      return replaceOperand(I, 1, ConstantFP::getZero(OpType));
  }

  // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y
  Value *X, *Y;
  if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y))))
    return new FCmpInst(I.getSwappedPredicate(), X, Y, "", &I);

  if (Instruction *R = foldFCmpFNegCommonOp(I))
    return R;

  // Test if the FCmpInst instruction is used exclusively by a select as
  // part of a minimum or maximum operation. If so, refrain from doing
  // any other folding. This helps out other analyses which understand
  // non-obfuscated minimum and maximum idioms, such as ScalarEvolution
  // and CodeGen. And in this case, at least one of the comparison
  // operands has at least one user besides the compare (the select),
  // which would often largely negate the benefit of folding anyway.
  if (I.hasOneUse())
    if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) {
      Value *A, *B;
      SelectPatternResult SPR = matchSelectPattern(SI, A, B);
      if (SPR.Flavor != SPF_UNKNOWN)
        return nullptr;
    }

  // The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0:
  // fcmp Pred X, -0.0 --> fcmp Pred X, 0.0
  if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP()))
    return replaceOperand(I, 1, ConstantFP::getZero(OpType));

  // Ignore signbit of bitcasted int when comparing equality to FP 0.0:
  // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0
  if (match(Op1, m_PosZeroFP()) &&
      match(Op0, m_OneUse(m_BitCast(m_Value(X)))) &&
      X->getType()->isVectorTy() == OpType->isVectorTy() &&
      X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) {
    ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE;
    if (Pred == FCmpInst::FCMP_OEQ)
      IntPred = ICmpInst::ICMP_EQ;
    else if (Pred == FCmpInst::FCMP_UNE)
      IntPred = ICmpInst::ICMP_NE;

    if (IntPred != ICmpInst::BAD_ICMP_PREDICATE) {
      Type *IntTy = X->getType();
      const APInt &SignMask =
          ~APInt::getSignMask(IntTy->getScalarSizeInBits());
      Value *MaskX = Builder.CreateAnd(X, ConstantInt::get(IntTy, SignMask));
      return new ICmpInst(IntPred, MaskX, ConstantInt::getNullValue(IntTy));
    }
  }

  // Handle fcmp with instruction LHS and constant RHS.
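  // For example, 'fcmp olt (sitofp i32 %x to float), 4.5' dispatches to
  // foldFCmpIntToFPConst below and becomes roughly 'icmp sle i32 %x, 4'.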
  Instruction *LHSI;
  Constant *RHSC;
  if (match(Op0, m_Instruction(LHSI)) && match(Op1, m_Constant(RHSC))) {
    switch (LHSI->getOpcode()) {
    case Instruction::PHI:
      if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
        return NV;
      break;
    case Instruction::SIToFP:
    case Instruction::UIToFP:
      if (Instruction *NV = foldFCmpIntToFPConst(I, LHSI, RHSC))
        return NV;
      break;
    case Instruction::FDiv:
      if (Instruction *NV = foldFCmpReciprocalAndZero(I, LHSI, RHSC))
        return NV;
      break;
    case Instruction::Load:
      if (auto *GEP = dyn_cast<GetElementPtrInst>(LHSI->getOperand(0)))
        if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)))
          if (Instruction *Res = foldCmpLoadFromIndexedGlobal(
                  cast<LoadInst>(LHSI), GEP, GV, I))
            return Res;
      break;
    }
  }

  if (Instruction *R = foldFabsWithFcmpZero(I, *this))
    return R;

  if (match(Op0, m_FNeg(m_Value(X)))) {
    // fcmp pred (fneg X), C --> fcmp swap(pred) X, -C
    Constant *C;
    if (match(Op1, m_Constant(C)))
      if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
        return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I);
  }

  if (match(Op0, m_FPExt(m_Value(X)))) {
    // fcmp (fpext X), (fpext Y) -> fcmp X, Y
    if (match(Op1, m_FPExt(m_Value(Y))) && X->getType() == Y->getType())
      return new FCmpInst(Pred, X, Y, "", &I);

    const APFloat *C;
    if (match(Op1, m_APFloat(C))) {
      const fltSemantics &FPSem =
          X->getType()->getScalarType()->getFltSemantics();
      bool Lossy;
      APFloat TruncC = *C;
      TruncC.convert(FPSem, APFloat::rmNearestTiesToEven, &Lossy);

      if (Lossy) {
        // X can't possibly equal the higher-precision constant, so reduce any
        // equality comparison.
        // TODO: Other predicates can be handled via getFCmpCode().
        switch (Pred) {
        case FCmpInst::FCMP_OEQ:
          // X is ordered and equal to an impossible constant --> false
          return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
        case FCmpInst::FCMP_ONE:
          // X is ordered and not equal to an impossible constant --> ordered
          return new FCmpInst(FCmpInst::FCMP_ORD, X,
                              ConstantFP::getZero(X->getType()));
        case FCmpInst::FCMP_UEQ:
          // X is unordered or equal to an impossible constant --> unordered
          return new FCmpInst(FCmpInst::FCMP_UNO, X,
                              ConstantFP::getZero(X->getType()));
        case FCmpInst::FCMP_UNE:
          // X is unordered or not equal to an impossible constant --> true
          return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
        default:
          break;
        }
      }

      // fcmp (fpext X), C -> fcmp X, (fptrunc C) if fptrunc is lossless
      // Avoid lossy conversions and denormals.
      // Zero is a special case that's OK to convert.
      APFloat Fabs = TruncC;
      Fabs.clearSign();
      if (!Lossy &&
          (Fabs.isZero() || !(Fabs < APFloat::getSmallestNormalized(FPSem)))) {
        Constant *NewC = ConstantFP::get(X->getType(), TruncC);
        return new FCmpInst(Pred, X, NewC, "", &I);
      }
    }
  }

  // Convert a sign-bit test of an FP value into a cast and integer compare.
  // TODO: Simplify if the copysign constant is 0.0 or NaN.
  // TODO: Handle non-zero compare constants.
  // TODO: Handle other predicates.
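  // e.g. 'fcmp olt (copysign 2.0, %x), 0.0' depends only on the sign bit of
  // %x, since the magnitude is a non-zero, non-NaN constant; it becomes
  // 'icmp slt (bitcast %x to iN), 0'.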
  const APFloat *C;
  if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::copysign>(m_APFloat(C),
                                                           m_Value(X)))) &&
      match(Op1, m_AnyZeroFP()) && !C->isZero() && !C->isNaN()) {
    Type *IntType = Builder.getIntNTy(X->getType()->getScalarSizeInBits());
    if (auto *VecTy = dyn_cast<VectorType>(OpType))
      IntType = VectorType::get(IntType, VecTy->getElementCount());

    // copysign(non-zero constant, X) < 0.0 --> (bitcast X) < 0
    if (Pred == FCmpInst::FCMP_OLT) {
      Value *IntX = Builder.CreateBitCast(X, IntType);
      return new ICmpInst(ICmpInst::ICMP_SLT, IntX,
                          ConstantInt::getNullValue(IntType));
    }
  }

  {
    Value *CanonLHS = nullptr, *CanonRHS = nullptr;
    match(Op0, m_Intrinsic<Intrinsic::canonicalize>(m_Value(CanonLHS)));
    match(Op1, m_Intrinsic<Intrinsic::canonicalize>(m_Value(CanonRHS)));

    // (canonicalize(x) == x) => (x == x)
    if (CanonLHS == Op1)
      return new FCmpInst(Pred, Op1, Op1, "", &I);

    // (x == canonicalize(x)) => (x == x)
    if (CanonRHS == Op0)
      return new FCmpInst(Pred, Op0, Op0, "", &I);

    // (canonicalize(x) == canonicalize(y)) => (x == y)
    if (CanonLHS && CanonRHS)
      return new FCmpInst(Pred, CanonLHS, CanonRHS, "", &I);
  }

  if (I.getType()->isVectorTy())
    if (Instruction *Res = foldVectorCmp(I, Builder))
      return Res;

  return Changed ? &I : nullptr;
}