1 //===- CmpInstAnalysis.cpp - Utils to help fold compares ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file holds routines to help analyse compare instructions 10 // and fold them into constants or other compare instructions 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Analysis/CmpInstAnalysis.h" 15 #include "llvm/IR/Constants.h" 16 #include "llvm/IR/Instructions.h" 17 #include "llvm/IR/PatternMatch.h" 18 19 using namespace llvm; 20 21 unsigned llvm::getICmpCode(CmpInst::Predicate Pred) { 22 switch (Pred) { 23 // False -> 0 24 case ICmpInst::ICMP_UGT: return 1; // 001 25 case ICmpInst::ICMP_SGT: return 1; // 001 26 case ICmpInst::ICMP_EQ: return 2; // 010 27 case ICmpInst::ICMP_UGE: return 3; // 011 28 case ICmpInst::ICMP_SGE: return 3; // 011 29 case ICmpInst::ICMP_ULT: return 4; // 100 30 case ICmpInst::ICMP_SLT: return 4; // 100 31 case ICmpInst::ICMP_NE: return 5; // 101 32 case ICmpInst::ICMP_ULE: return 6; // 110 33 case ICmpInst::ICMP_SLE: return 6; // 110 34 // True -> 7 35 default: 36 llvm_unreachable("Invalid ICmp predicate!"); 37 } 38 } 39 40 Constant *llvm::getPredForICmpCode(unsigned Code, bool Sign, Type *OpTy, 41 CmpInst::Predicate &Pred) { 42 switch (Code) { 43 default: llvm_unreachable("Illegal ICmp code!"); 44 case 0: // False. 45 return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 0); 46 case 1: Pred = Sign ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; break; 47 case 2: Pred = ICmpInst::ICMP_EQ; break; 48 case 3: Pred = Sign ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; break; 49 case 4: Pred = Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; break; 50 case 5: Pred = ICmpInst::ICMP_NE; break; 51 case 6: Pred = Sign ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; break; 52 case 7: // True. 53 return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 1); 54 } 55 return nullptr; 56 } 57 58 bool llvm::predicatesFoldable(ICmpInst::Predicate P1, ICmpInst::Predicate P2) { 59 return (CmpInst::isSigned(P1) == CmpInst::isSigned(P2)) || 60 (CmpInst::isSigned(P1) && ICmpInst::isEquality(P2)) || 61 (CmpInst::isSigned(P2) && ICmpInst::isEquality(P1)); 62 } 63 64 Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy, 65 CmpInst::Predicate &Pred) { 66 Pred = static_cast<FCmpInst::Predicate>(Code); 67 assert(FCmpInst::FCMP_FALSE <= Pred && Pred <= FCmpInst::FCMP_TRUE && 68 "Unexpected FCmp predicate!"); 69 if (Pred == FCmpInst::FCMP_FALSE) 70 return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 0); 71 if (Pred == FCmpInst::FCMP_TRUE) 72 return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 1); 73 return nullptr; 74 } 75 76 std::optional<DecomposedBitTest> 77 llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, 78 bool LookThruTrunc, bool AllowNonZeroC, 79 bool DecomposeAnd) { 80 using namespace PatternMatch; 81 82 const APInt *OrigC; 83 if ((ICmpInst::isEquality(Pred) && !DecomposeAnd) || 84 !match(RHS, m_APIntAllowPoison(OrigC))) 85 return std::nullopt; 86 87 bool Inverted = false; 88 if (ICmpInst::isGT(Pred) || ICmpInst::isGE(Pred)) { 89 Inverted = true; 90 Pred = ICmpInst::getInversePredicate(Pred); 91 } 92 93 APInt C = *OrigC; 94 if (ICmpInst::isLE(Pred)) { 95 if (ICmpInst::isSigned(Pred) ? C.isMaxSignedValue() : C.isMaxValue()) 96 return std::nullopt; 97 ++C; 98 Pred = ICmpInst::getStrictPredicate(Pred); 99 } 100 101 DecomposedBitTest Result; 102 switch (Pred) { 103 default: 104 llvm_unreachable("Unexpected predicate"); 105 case ICmpInst::ICMP_SLT: { 106 // X < 0 is equivalent to (X & SignMask) != 0. 107 if (C.isZero()) { 108 Result.Mask = APInt::getSignMask(C.getBitWidth()); 109 Result.C = APInt::getZero(C.getBitWidth()); 110 Result.Pred = ICmpInst::ICMP_NE; 111 break; 112 } 113 114 APInt FlippedSign = C ^ APInt::getSignMask(C.getBitWidth()); 115 if (FlippedSign.isPowerOf2()) { 116 // X s< 10000100 is equivalent to (X & 11111100 == 10000000) 117 Result.Mask = -FlippedSign; 118 Result.C = APInt::getSignMask(C.getBitWidth()); 119 Result.Pred = ICmpInst::ICMP_EQ; 120 break; 121 } 122 123 if (FlippedSign.isNegatedPowerOf2()) { 124 // X s< 01111100 is equivalent to (X & 11111100 != 01111100) 125 Result.Mask = FlippedSign; 126 Result.C = C; 127 Result.Pred = ICmpInst::ICMP_NE; 128 break; 129 } 130 131 return std::nullopt; 132 } 133 case ICmpInst::ICMP_ULT: { 134 // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0. 135 if (C.isPowerOf2()) { 136 Result.Mask = -C; 137 Result.C = APInt::getZero(C.getBitWidth()); 138 Result.Pred = ICmpInst::ICMP_EQ; 139 break; 140 } 141 142 // X u< 11111100 is equivalent to (X & 11111100 != 11111100) 143 if (C.isNegatedPowerOf2()) { 144 Result.Mask = C; 145 Result.C = C; 146 Result.Pred = ICmpInst::ICMP_NE; 147 break; 148 } 149 150 return std::nullopt; 151 } 152 case ICmpInst::ICMP_EQ: 153 case ICmpInst::ICMP_NE: { 154 assert(DecomposeAnd); 155 const APInt *AndC; 156 Value *AndVal; 157 if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) { 158 LHS = AndVal; 159 Result.Mask = *AndC; 160 Result.C = C; 161 Result.Pred = Pred; 162 break; 163 } 164 165 return std::nullopt; 166 } 167 } 168 169 if (!AllowNonZeroC && !Result.C.isZero()) 170 return std::nullopt; 171 172 if (Inverted) 173 Result.Pred = ICmpInst::getInversePredicate(Result.Pred); 174 175 Value *X; 176 if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) { 177 Result.X = X; 178 Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits()); 179 Result.C = Result.C.zext(X->getType()->getScalarSizeInBits()); 180 } else { 181 Result.X = LHS; 182 } 183 184 return Result; 185 } 186 187 std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond, 188 bool LookThruTrunc, 189 bool AllowNonZeroC, 190 bool DecomposeAnd) { 191 using namespace PatternMatch; 192 if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) { 193 // Don't allow pointers. Splat vectors are fine. 194 if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy()) 195 return std::nullopt; 196 return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1), 197 ICmp->getPredicate(), LookThruTrunc, 198 AllowNonZeroC, DecomposeAnd); 199 } 200 Value *X; 201 if (Cond->getType()->isIntOrIntVectorTy(1) && 202 (match(Cond, m_Trunc(m_Value(X))) || 203 match(Cond, m_Not(m_Trunc(m_Value(X)))))) { 204 DecomposedBitTest Result; 205 Result.X = X; 206 unsigned BitWidth = X->getType()->getScalarSizeInBits(); 207 Result.Mask = APInt(BitWidth, 1); 208 Result.C = APInt::getZero(BitWidth); 209 Result.Pred = isa<TruncInst>(Cond) ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; 210 211 return Result; 212 } 213 214 return std::nullopt; 215 } 216