xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/CmpInstAnalysis.cpp (revision 0b37c1590418417c894529d371800dfac71ef887)
1 //===- CmpInstAnalysis.cpp - Utils to help fold compares ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file holds routines to help analyse compare instructions
10 // and fold them into constants or other compare instructions
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Analysis/CmpInstAnalysis.h"
15 #include "llvm/IR/Constants.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/IR/PatternMatch.h"
18 
19 using namespace llvm;
20 
21 unsigned llvm::getICmpCode(const ICmpInst *ICI, bool InvertPred) {
22   ICmpInst::Predicate Pred = InvertPred ? ICI->getInversePredicate()
23                                         : ICI->getPredicate();
24   switch (Pred) {
25       // False -> 0
26     case ICmpInst::ICMP_UGT: return 1;  // 001
27     case ICmpInst::ICMP_SGT: return 1;  // 001
28     case ICmpInst::ICMP_EQ:  return 2;  // 010
29     case ICmpInst::ICMP_UGE: return 3;  // 011
30     case ICmpInst::ICMP_SGE: return 3;  // 011
31     case ICmpInst::ICMP_ULT: return 4;  // 100
32     case ICmpInst::ICMP_SLT: return 4;  // 100
33     case ICmpInst::ICMP_NE:  return 5;  // 101
34     case ICmpInst::ICMP_ULE: return 6;  // 110
35     case ICmpInst::ICMP_SLE: return 6;  // 110
36       // True -> 7
37     default:
38       llvm_unreachable("Invalid ICmp predicate!");
39   }
40 }
41 
42 Constant *llvm::getPredForICmpCode(unsigned Code, bool Sign, Type *OpTy,
43                                    CmpInst::Predicate &Pred) {
44   switch (Code) {
45     default: llvm_unreachable("Illegal ICmp code!");
46     case 0: // False.
47       return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 0);
48     case 1: Pred = Sign ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; break;
49     case 2: Pred = ICmpInst::ICMP_EQ; break;
50     case 3: Pred = Sign ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; break;
51     case 4: Pred = Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; break;
52     case 5: Pred = ICmpInst::ICMP_NE; break;
53     case 6: Pred = Sign ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; break;
54     case 7: // True.
55       return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 1);
56   }
57   return nullptr;
58 }
59 
60 bool llvm::predicatesFoldable(ICmpInst::Predicate P1, ICmpInst::Predicate P2) {
61   return (CmpInst::isSigned(P1) == CmpInst::isSigned(P2)) ||
62          (CmpInst::isSigned(P1) && ICmpInst::isEquality(P2)) ||
63          (CmpInst::isSigned(P2) && ICmpInst::isEquality(P1));
64 }
65 
66 bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS,
67                                 CmpInst::Predicate &Pred,
68                                 Value *&X, APInt &Mask, bool LookThruTrunc) {
69   using namespace PatternMatch;
70 
71   const APInt *C;
72   if (!match(RHS, m_APInt(C)))
73     return false;
74 
75   switch (Pred) {
76   default:
77     return false;
78   case ICmpInst::ICMP_SLT:
79     // X < 0 is equivalent to (X & SignMask) != 0.
80     if (!C->isNullValue())
81       return false;
82     Mask = APInt::getSignMask(C->getBitWidth());
83     Pred = ICmpInst::ICMP_NE;
84     break;
85   case ICmpInst::ICMP_SLE:
86     // X <= -1 is equivalent to (X & SignMask) != 0.
87     if (!C->isAllOnesValue())
88       return false;
89     Mask = APInt::getSignMask(C->getBitWidth());
90     Pred = ICmpInst::ICMP_NE;
91     break;
92   case ICmpInst::ICMP_SGT:
93     // X > -1 is equivalent to (X & SignMask) == 0.
94     if (!C->isAllOnesValue())
95       return false;
96     Mask = APInt::getSignMask(C->getBitWidth());
97     Pred = ICmpInst::ICMP_EQ;
98     break;
99   case ICmpInst::ICMP_SGE:
100     // X >= 0 is equivalent to (X & SignMask) == 0.
101     if (!C->isNullValue())
102       return false;
103     Mask = APInt::getSignMask(C->getBitWidth());
104     Pred = ICmpInst::ICMP_EQ;
105     break;
106   case ICmpInst::ICMP_ULT:
107     // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
108     if (!C->isPowerOf2())
109       return false;
110     Mask = -*C;
111     Pred = ICmpInst::ICMP_EQ;
112     break;
113   case ICmpInst::ICMP_ULE:
114     // X <=u 2^n-1 is equivalent to (X & ~(2^n-1)) == 0.
115     if (!(*C + 1).isPowerOf2())
116       return false;
117     Mask = ~*C;
118     Pred = ICmpInst::ICMP_EQ;
119     break;
120   case ICmpInst::ICMP_UGT:
121     // X >u 2^n-1 is equivalent to (X & ~(2^n-1)) != 0.
122     if (!(*C + 1).isPowerOf2())
123       return false;
124     Mask = ~*C;
125     Pred = ICmpInst::ICMP_NE;
126     break;
127   case ICmpInst::ICMP_UGE:
128     // X >=u 2^n is equivalent to (X & ~(2^n-1)) != 0.
129     if (!C->isPowerOf2())
130       return false;
131     Mask = -*C;
132     Pred = ICmpInst::ICMP_NE;
133     break;
134   }
135 
136   if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) {
137     Mask = Mask.zext(X->getType()->getScalarSizeInBits());
138   } else {
139     X = LHS;
140   }
141 
142   return true;
143 }
144