xref: /freebsd/contrib/llvm-project/clang/include/clang/StaticAnalyzer/Core/PathSensitive/SMTConv.h (revision 1ac55f4cb0001fed92329746c730aa9a947c09a5)
1 //== SMTConv.h --------------------------------------------------*- C++ -*--==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  This file defines a set of functions to create SMT expressions
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTCONV_H
14 #define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTCONV_H
15 
16 #include "clang/AST/Expr.h"
17 #include "clang/StaticAnalyzer/Core/PathSensitive/APSIntType.h"
18 #include "clang/StaticAnalyzer/Core/PathSensitive/SymbolManager.h"
19 #include "llvm/Support/SMTAPI.h"
20 
21 namespace clang {
22 namespace ento {
23 
24 class SMTConv {
25 public:
26   // Returns an appropriate sort, given a QualType and it's bit width.
mkSort(llvm::SMTSolverRef & Solver,const QualType & Ty,unsigned BitWidth)27   static inline llvm::SMTSortRef mkSort(llvm::SMTSolverRef &Solver,
28                                         const QualType &Ty, unsigned BitWidth) {
29     if (Ty->isBooleanType())
30       return Solver->getBoolSort();
31 
32     if (Ty->isRealFloatingType())
33       return Solver->getFloatSort(BitWidth);
34 
35     return Solver->getBitvectorSort(BitWidth);
36   }
37 
38   /// Constructs an SMTSolverRef from an unary operator.
fromUnOp(llvm::SMTSolverRef & Solver,const UnaryOperator::Opcode Op,const llvm::SMTExprRef & Exp)39   static inline llvm::SMTExprRef fromUnOp(llvm::SMTSolverRef &Solver,
40                                           const UnaryOperator::Opcode Op,
41                                           const llvm::SMTExprRef &Exp) {
42     switch (Op) {
43     case UO_Minus:
44       return Solver->mkBVNeg(Exp);
45 
46     case UO_Not:
47       return Solver->mkBVNot(Exp);
48 
49     case UO_LNot:
50       return Solver->mkNot(Exp);
51 
52     default:;
53     }
54     llvm_unreachable("Unimplemented opcode");
55   }
56 
57   /// Constructs an SMTSolverRef from a floating-point unary operator.
fromFloatUnOp(llvm::SMTSolverRef & Solver,const UnaryOperator::Opcode Op,const llvm::SMTExprRef & Exp)58   static inline llvm::SMTExprRef fromFloatUnOp(llvm::SMTSolverRef &Solver,
59                                                const UnaryOperator::Opcode Op,
60                                                const llvm::SMTExprRef &Exp) {
61     switch (Op) {
62     case UO_Minus:
63       return Solver->mkFPNeg(Exp);
64 
65     case UO_LNot:
66       return fromUnOp(Solver, Op, Exp);
67 
68     default:;
69     }
70     llvm_unreachable("Unimplemented opcode");
71   }
72 
73   /// Construct an SMTSolverRef from a n-ary binary operator.
74   static inline llvm::SMTExprRef
fromNBinOp(llvm::SMTSolverRef & Solver,const BinaryOperator::Opcode Op,const std::vector<llvm::SMTExprRef> & ASTs)75   fromNBinOp(llvm::SMTSolverRef &Solver, const BinaryOperator::Opcode Op,
76              const std::vector<llvm::SMTExprRef> &ASTs) {
77     assert(!ASTs.empty());
78 
79     if (Op != BO_LAnd && Op != BO_LOr)
80       llvm_unreachable("Unimplemented opcode");
81 
82     llvm::SMTExprRef res = ASTs.front();
83     for (std::size_t i = 1; i < ASTs.size(); ++i)
84       res = (Op == BO_LAnd) ? Solver->mkAnd(res, ASTs[i])
85                             : Solver->mkOr(res, ASTs[i]);
86     return res;
87   }
88 
89   /// Construct an SMTSolverRef from a binary operator.
fromBinOp(llvm::SMTSolverRef & Solver,const llvm::SMTExprRef & LHS,const BinaryOperator::Opcode Op,const llvm::SMTExprRef & RHS,bool isSigned)90   static inline llvm::SMTExprRef fromBinOp(llvm::SMTSolverRef &Solver,
91                                            const llvm::SMTExprRef &LHS,
92                                            const BinaryOperator::Opcode Op,
93                                            const llvm::SMTExprRef &RHS,
94                                            bool isSigned) {
95     assert(*Solver->getSort(LHS) == *Solver->getSort(RHS) &&
96            "AST's must have the same sort!");
97 
98     switch (Op) {
99     // Multiplicative operators
100     case BO_Mul:
101       return Solver->mkBVMul(LHS, RHS);
102 
103     case BO_Div:
104       return isSigned ? Solver->mkBVSDiv(LHS, RHS) : Solver->mkBVUDiv(LHS, RHS);
105 
106     case BO_Rem:
107       return isSigned ? Solver->mkBVSRem(LHS, RHS) : Solver->mkBVURem(LHS, RHS);
108 
109       // Additive operators
110     case BO_Add:
111       return Solver->mkBVAdd(LHS, RHS);
112 
113     case BO_Sub:
114       return Solver->mkBVSub(LHS, RHS);
115 
116       // Bitwise shift operators
117     case BO_Shl:
118       return Solver->mkBVShl(LHS, RHS);
119 
120     case BO_Shr:
121       return isSigned ? Solver->mkBVAshr(LHS, RHS) : Solver->mkBVLshr(LHS, RHS);
122 
123       // Relational operators
124     case BO_LT:
125       return isSigned ? Solver->mkBVSlt(LHS, RHS) : Solver->mkBVUlt(LHS, RHS);
126 
127     case BO_GT:
128       return isSigned ? Solver->mkBVSgt(LHS, RHS) : Solver->mkBVUgt(LHS, RHS);
129 
130     case BO_LE:
131       return isSigned ? Solver->mkBVSle(LHS, RHS) : Solver->mkBVUle(LHS, RHS);
132 
133     case BO_GE:
134       return isSigned ? Solver->mkBVSge(LHS, RHS) : Solver->mkBVUge(LHS, RHS);
135 
136       // Equality operators
137     case BO_EQ:
138       return Solver->mkEqual(LHS, RHS);
139 
140     case BO_NE:
141       return fromUnOp(Solver, UO_LNot,
142                       fromBinOp(Solver, LHS, BO_EQ, RHS, isSigned));
143 
144       // Bitwise operators
145     case BO_And:
146       return Solver->mkBVAnd(LHS, RHS);
147 
148     case BO_Xor:
149       return Solver->mkBVXor(LHS, RHS);
150 
151     case BO_Or:
152       return Solver->mkBVOr(LHS, RHS);
153 
154       // Logical operators
155     case BO_LAnd:
156       return Solver->mkAnd(LHS, RHS);
157 
158     case BO_LOr:
159       return Solver->mkOr(LHS, RHS);
160 
161     default:;
162     }
163     llvm_unreachable("Unimplemented opcode");
164   }
165 
166   /// Construct an SMTSolverRef from a special floating-point binary
167   /// operator.
168   static inline llvm::SMTExprRef
fromFloatSpecialBinOp(llvm::SMTSolverRef & Solver,const llvm::SMTExprRef & LHS,const BinaryOperator::Opcode Op,const llvm::APFloat::fltCategory & RHS)169   fromFloatSpecialBinOp(llvm::SMTSolverRef &Solver, const llvm::SMTExprRef &LHS,
170                         const BinaryOperator::Opcode Op,
171                         const llvm::APFloat::fltCategory &RHS) {
172     switch (Op) {
173     // Equality operators
174     case BO_EQ:
175       switch (RHS) {
176       case llvm::APFloat::fcInfinity:
177         return Solver->mkFPIsInfinite(LHS);
178 
179       case llvm::APFloat::fcNaN:
180         return Solver->mkFPIsNaN(LHS);
181 
182       case llvm::APFloat::fcNormal:
183         return Solver->mkFPIsNormal(LHS);
184 
185       case llvm::APFloat::fcZero:
186         return Solver->mkFPIsZero(LHS);
187       }
188       break;
189 
190     case BO_NE:
191       return fromFloatUnOp(Solver, UO_LNot,
192                            fromFloatSpecialBinOp(Solver, LHS, BO_EQ, RHS));
193 
194     default:;
195     }
196 
197     llvm_unreachable("Unimplemented opcode");
198   }
199 
200   /// Construct an SMTSolverRef from a floating-point binary operator.
fromFloatBinOp(llvm::SMTSolverRef & Solver,const llvm::SMTExprRef & LHS,const BinaryOperator::Opcode Op,const llvm::SMTExprRef & RHS)201   static inline llvm::SMTExprRef fromFloatBinOp(llvm::SMTSolverRef &Solver,
202                                                 const llvm::SMTExprRef &LHS,
203                                                 const BinaryOperator::Opcode Op,
204                                                 const llvm::SMTExprRef &RHS) {
205     assert(*Solver->getSort(LHS) == *Solver->getSort(RHS) &&
206            "AST's must have the same sort!");
207 
208     switch (Op) {
209     // Multiplicative operators
210     case BO_Mul:
211       return Solver->mkFPMul(LHS, RHS);
212 
213     case BO_Div:
214       return Solver->mkFPDiv(LHS, RHS);
215 
216     case BO_Rem:
217       return Solver->mkFPRem(LHS, RHS);
218 
219       // Additive operators
220     case BO_Add:
221       return Solver->mkFPAdd(LHS, RHS);
222 
223     case BO_Sub:
224       return Solver->mkFPSub(LHS, RHS);
225 
226       // Relational operators
227     case BO_LT:
228       return Solver->mkFPLt(LHS, RHS);
229 
230     case BO_GT:
231       return Solver->mkFPGt(LHS, RHS);
232 
233     case BO_LE:
234       return Solver->mkFPLe(LHS, RHS);
235 
236     case BO_GE:
237       return Solver->mkFPGe(LHS, RHS);
238 
239       // Equality operators
240     case BO_EQ:
241       return Solver->mkFPEqual(LHS, RHS);
242 
243     case BO_NE:
244       return fromFloatUnOp(Solver, UO_LNot,
245                            fromFloatBinOp(Solver, LHS, BO_EQ, RHS));
246 
247       // Logical operators
248     case BO_LAnd:
249     case BO_LOr:
250       return fromBinOp(Solver, LHS, Op, RHS, /*isSigned=*/false);
251 
252     default:;
253     }
254 
255     llvm_unreachable("Unimplemented opcode");
256   }
257 
258   /// Construct an SMTSolverRef from a QualType FromTy to a QualType ToTy,
259   /// and their bit widths.
fromCast(llvm::SMTSolverRef & Solver,const llvm::SMTExprRef & Exp,QualType ToTy,uint64_t ToBitWidth,QualType FromTy,uint64_t FromBitWidth)260   static inline llvm::SMTExprRef fromCast(llvm::SMTSolverRef &Solver,
261                                           const llvm::SMTExprRef &Exp,
262                                           QualType ToTy, uint64_t ToBitWidth,
263                                           QualType FromTy,
264                                           uint64_t FromBitWidth) {
265     if ((FromTy->isIntegralOrEnumerationType() &&
266          ToTy->isIntegralOrEnumerationType()) ||
267         (FromTy->isAnyPointerType() ^ ToTy->isAnyPointerType()) ||
268         (FromTy->isBlockPointerType() ^ ToTy->isBlockPointerType()) ||
269         (FromTy->isReferenceType() ^ ToTy->isReferenceType())) {
270 
271       if (FromTy->isBooleanType()) {
272         assert(ToBitWidth > 0 && "BitWidth must be positive!");
273         return Solver->mkIte(
274             Exp, Solver->mkBitvector(llvm::APSInt("1"), ToBitWidth),
275             Solver->mkBitvector(llvm::APSInt("0"), ToBitWidth));
276       }
277 
278       if (ToBitWidth > FromBitWidth)
279         return FromTy->isSignedIntegerOrEnumerationType()
280                    ? Solver->mkBVSignExt(ToBitWidth - FromBitWidth, Exp)
281                    : Solver->mkBVZeroExt(ToBitWidth - FromBitWidth, Exp);
282 
283       if (ToBitWidth < FromBitWidth)
284         return Solver->mkBVExtract(ToBitWidth - 1, 0, Exp);
285 
286       // Both are bitvectors with the same width, ignore the type cast
287       return Exp;
288     }
289 
290     if (FromTy->isRealFloatingType() && ToTy->isRealFloatingType()) {
291       if (ToBitWidth != FromBitWidth)
292         return Solver->mkFPtoFP(Exp, Solver->getFloatSort(ToBitWidth));
293 
294       return Exp;
295     }
296 
297     if (FromTy->isIntegralOrEnumerationType() && ToTy->isRealFloatingType()) {
298       llvm::SMTSortRef Sort = Solver->getFloatSort(ToBitWidth);
299       return FromTy->isSignedIntegerOrEnumerationType()
300                  ? Solver->mkSBVtoFP(Exp, Sort)
301                  : Solver->mkUBVtoFP(Exp, Sort);
302     }
303 
304     if (FromTy->isRealFloatingType() && ToTy->isIntegralOrEnumerationType())
305       return ToTy->isSignedIntegerOrEnumerationType()
306                  ? Solver->mkFPtoSBV(Exp, ToBitWidth)
307                  : Solver->mkFPtoUBV(Exp, ToBitWidth);
308 
309     llvm_unreachable("Unsupported explicit type cast!");
310   }
311 
312   // Callback function for doCast parameter on APSInt type.
castAPSInt(llvm::SMTSolverRef & Solver,const llvm::APSInt & V,QualType ToTy,uint64_t ToWidth,QualType FromTy,uint64_t FromWidth)313   static inline llvm::APSInt castAPSInt(llvm::SMTSolverRef &Solver,
314                                         const llvm::APSInt &V, QualType ToTy,
315                                         uint64_t ToWidth, QualType FromTy,
316                                         uint64_t FromWidth) {
317     APSIntType TargetType(ToWidth, !ToTy->isSignedIntegerOrEnumerationType());
318     return TargetType.convert(V);
319   }
320 
321   /// Construct an SMTSolverRef from a SymbolData.
322   static inline llvm::SMTExprRef
fromData(llvm::SMTSolverRef & Solver,ASTContext & Ctx,const SymbolData * Sym)323   fromData(llvm::SMTSolverRef &Solver, ASTContext &Ctx, const SymbolData *Sym) {
324     const SymbolID ID = Sym->getSymbolID();
325     const QualType Ty = Sym->getType();
326     const uint64_t BitWidth = Ctx.getTypeSize(Ty);
327 
328     llvm::SmallString<16> Str;
329     llvm::raw_svector_ostream OS(Str);
330     OS << Sym->getKindStr() << ID;
331     return Solver->mkSymbol(Str.c_str(), mkSort(Solver, Ty, BitWidth));
332   }
333 
334   // Wrapper to generate SMTSolverRef from SymbolCast data.
getCastExpr(llvm::SMTSolverRef & Solver,ASTContext & Ctx,const llvm::SMTExprRef & Exp,QualType FromTy,QualType ToTy)335   static inline llvm::SMTExprRef getCastExpr(llvm::SMTSolverRef &Solver,
336                                              ASTContext &Ctx,
337                                              const llvm::SMTExprRef &Exp,
338                                              QualType FromTy, QualType ToTy) {
339     return fromCast(Solver, Exp, ToTy, Ctx.getTypeSize(ToTy), FromTy,
340                     Ctx.getTypeSize(FromTy));
341   }
342 
343   // Wrapper to generate SMTSolverRef from unpacked binary symbolic
344   // expression. Sets the RetTy parameter. See getSMTSolverRef().
345   static inline llvm::SMTExprRef
getBinExpr(llvm::SMTSolverRef & Solver,ASTContext & Ctx,const llvm::SMTExprRef & LHS,QualType LTy,BinaryOperator::Opcode Op,const llvm::SMTExprRef & RHS,QualType RTy,QualType * RetTy)346   getBinExpr(llvm::SMTSolverRef &Solver, ASTContext &Ctx,
347              const llvm::SMTExprRef &LHS, QualType LTy,
348              BinaryOperator::Opcode Op, const llvm::SMTExprRef &RHS,
349              QualType RTy, QualType *RetTy) {
350     llvm::SMTExprRef NewLHS = LHS;
351     llvm::SMTExprRef NewRHS = RHS;
352     doTypeConversion(Solver, Ctx, NewLHS, NewRHS, LTy, RTy);
353 
354     // Update the return type parameter if the output type has changed.
355     if (RetTy) {
356       // A boolean result can be represented as an integer type in C/C++, but at
357       // this point we only care about the SMT sorts. Set it as a boolean type
358       // to avoid subsequent SMT errors.
359       if (BinaryOperator::isComparisonOp(Op) ||
360           BinaryOperator::isLogicalOp(Op)) {
361         *RetTy = Ctx.BoolTy;
362       } else {
363         *RetTy = LTy;
364       }
365 
366       // If the two operands are pointers and the operation is a subtraction,
367       // the result is of type ptrdiff_t, which is signed
368       if (LTy->isAnyPointerType() && RTy->isAnyPointerType() && Op == BO_Sub) {
369         *RetTy = Ctx.getPointerDiffType();
370       }
371     }
372 
373     return LTy->isRealFloatingType()
374                ? fromFloatBinOp(Solver, NewLHS, Op, NewRHS)
375                : fromBinOp(Solver, NewLHS, Op, NewRHS,
376                            LTy->isSignedIntegerOrEnumerationType());
377   }
378 
379   // Wrapper to generate SMTSolverRef from BinarySymExpr.
380   // Sets the hasComparison and RetTy parameters. See getSMTSolverRef().
getSymBinExpr(llvm::SMTSolverRef & Solver,ASTContext & Ctx,const BinarySymExpr * BSE,bool * hasComparison,QualType * RetTy)381   static inline llvm::SMTExprRef getSymBinExpr(llvm::SMTSolverRef &Solver,
382                                                ASTContext &Ctx,
383                                                const BinarySymExpr *BSE,
384                                                bool *hasComparison,
385                                                QualType *RetTy) {
386     QualType LTy, RTy;
387     BinaryOperator::Opcode Op = BSE->getOpcode();
388 
389     if (const SymIntExpr *SIE = dyn_cast<SymIntExpr>(BSE)) {
390       llvm::SMTExprRef LHS =
391           getSymExpr(Solver, Ctx, SIE->getLHS(), &LTy, hasComparison);
392       llvm::APSInt NewRInt;
393       std::tie(NewRInt, RTy) = fixAPSInt(Ctx, SIE->getRHS());
394       llvm::SMTExprRef RHS =
395           Solver->mkBitvector(NewRInt, NewRInt.getBitWidth());
396       return getBinExpr(Solver, Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
397     }
398 
399     if (const IntSymExpr *ISE = dyn_cast<IntSymExpr>(BSE)) {
400       llvm::APSInt NewLInt;
401       std::tie(NewLInt, LTy) = fixAPSInt(Ctx, ISE->getLHS());
402       llvm::SMTExprRef LHS =
403           Solver->mkBitvector(NewLInt, NewLInt.getBitWidth());
404       llvm::SMTExprRef RHS =
405           getSymExpr(Solver, Ctx, ISE->getRHS(), &RTy, hasComparison);
406       return getBinExpr(Solver, Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
407     }
408 
409     if (const SymSymExpr *SSM = dyn_cast<SymSymExpr>(BSE)) {
410       llvm::SMTExprRef LHS =
411           getSymExpr(Solver, Ctx, SSM->getLHS(), &LTy, hasComparison);
412       llvm::SMTExprRef RHS =
413           getSymExpr(Solver, Ctx, SSM->getRHS(), &RTy, hasComparison);
414       return getBinExpr(Solver, Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
415     }
416 
417     llvm_unreachable("Unsupported BinarySymExpr type!");
418   }
419 
420   // Recursive implementation to unpack and generate symbolic expression.
421   // Sets the hasComparison and RetTy parameters. See getExpr().
getSymExpr(llvm::SMTSolverRef & Solver,ASTContext & Ctx,SymbolRef Sym,QualType * RetTy,bool * hasComparison)422   static inline llvm::SMTExprRef getSymExpr(llvm::SMTSolverRef &Solver,
423                                             ASTContext &Ctx, SymbolRef Sym,
424                                             QualType *RetTy,
425                                             bool *hasComparison) {
426     if (const SymbolData *SD = dyn_cast<SymbolData>(Sym)) {
427       if (RetTy)
428         *RetTy = Sym->getType();
429 
430       return fromData(Solver, Ctx, SD);
431     }
432 
433     if (const SymbolCast *SC = dyn_cast<SymbolCast>(Sym)) {
434       if (RetTy)
435         *RetTy = Sym->getType();
436 
437       QualType FromTy;
438       llvm::SMTExprRef Exp =
439           getSymExpr(Solver, Ctx, SC->getOperand(), &FromTy, hasComparison);
440 
441       // Casting an expression with a comparison invalidates it. Note that this
442       // must occur after the recursive call above.
443       // e.g. (signed char) (x > 0)
444       if (hasComparison)
445         *hasComparison = false;
446       return getCastExpr(Solver, Ctx, Exp, FromTy, Sym->getType());
447     }
448 
449     if (const UnarySymExpr *USE = dyn_cast<UnarySymExpr>(Sym)) {
450       if (RetTy)
451         *RetTy = Sym->getType();
452 
453       QualType OperandTy;
454       llvm::SMTExprRef OperandExp =
455           getSymExpr(Solver, Ctx, USE->getOperand(), &OperandTy, hasComparison);
456       llvm::SMTExprRef UnaryExp =
457           OperandTy->isRealFloatingType()
458               ? fromFloatUnOp(Solver, USE->getOpcode(), OperandExp)
459               : fromUnOp(Solver, USE->getOpcode(), OperandExp);
460 
461       // Currently, without the `support-symbolic-integer-casts=true` option,
462       // we do not emit `SymbolCast`s for implicit casts.
463       // One such implicit cast is missing if the operand of the unary operator
464       // has a different type than the unary itself.
465       if (Ctx.getTypeSize(OperandTy) != Ctx.getTypeSize(Sym->getType())) {
466         if (hasComparison)
467           *hasComparison = false;
468         return getCastExpr(Solver, Ctx, UnaryExp, OperandTy, Sym->getType());
469       }
470       return UnaryExp;
471     }
472 
473     if (const BinarySymExpr *BSE = dyn_cast<BinarySymExpr>(Sym)) {
474       llvm::SMTExprRef Exp =
475           getSymBinExpr(Solver, Ctx, BSE, hasComparison, RetTy);
476       // Set the hasComparison parameter, in post-order traversal order.
477       if (hasComparison)
478         *hasComparison = BinaryOperator::isComparisonOp(BSE->getOpcode());
479       return Exp;
480     }
481 
482     llvm_unreachable("Unsupported SymbolRef type!");
483   }
484 
485   // Generate an SMTSolverRef that represents the given symbolic expression.
486   // Sets the hasComparison parameter if the expression has a comparison
487   // operator. Sets the RetTy parameter to the final return type after
488   // promotions and casts.
489   static inline llvm::SMTExprRef getExpr(llvm::SMTSolverRef &Solver,
490                                          ASTContext &Ctx, SymbolRef Sym,
491                                          QualType *RetTy = nullptr,
492                                          bool *hasComparison = nullptr) {
493     if (hasComparison) {
494       *hasComparison = false;
495     }
496 
497     return getSymExpr(Solver, Ctx, Sym, RetTy, hasComparison);
498   }
499 
500   // Generate an SMTSolverRef that compares the expression to zero.
getZeroExpr(llvm::SMTSolverRef & Solver,ASTContext & Ctx,const llvm::SMTExprRef & Exp,QualType Ty,bool Assumption)501   static inline llvm::SMTExprRef getZeroExpr(llvm::SMTSolverRef &Solver,
502                                              ASTContext &Ctx,
503                                              const llvm::SMTExprRef &Exp,
504                                              QualType Ty, bool Assumption) {
505     if (Ty->isRealFloatingType()) {
506       llvm::APFloat Zero =
507           llvm::APFloat::getZero(Ctx.getFloatTypeSemantics(Ty));
508       return fromFloatBinOp(Solver, Exp, Assumption ? BO_EQ : BO_NE,
509                             Solver->mkFloat(Zero));
510     }
511 
512     if (Ty->isIntegralOrEnumerationType() || Ty->isAnyPointerType() ||
513         Ty->isBlockPointerType() || Ty->isReferenceType()) {
514 
515       // Skip explicit comparison for boolean types
516       bool isSigned = Ty->isSignedIntegerOrEnumerationType();
517       if (Ty->isBooleanType())
518         return Assumption ? fromUnOp(Solver, UO_LNot, Exp) : Exp;
519 
520       return fromBinOp(
521           Solver, Exp, Assumption ? BO_EQ : BO_NE,
522           Solver->mkBitvector(llvm::APSInt("0"), Ctx.getTypeSize(Ty)),
523           isSigned);
524     }
525 
526     llvm_unreachable("Unsupported type for zero value!");
527   }
528 
529   // Wrapper to generate SMTSolverRef from a range. If From == To, an
530   // equality will be created instead.
531   static inline llvm::SMTExprRef
getRangeExpr(llvm::SMTSolverRef & Solver,ASTContext & Ctx,SymbolRef Sym,const llvm::APSInt & From,const llvm::APSInt & To,bool InRange)532   getRangeExpr(llvm::SMTSolverRef &Solver, ASTContext &Ctx, SymbolRef Sym,
533                const llvm::APSInt &From, const llvm::APSInt &To, bool InRange) {
534     // Convert lower bound
535     QualType FromTy;
536     llvm::APSInt NewFromInt;
537     std::tie(NewFromInt, FromTy) = fixAPSInt(Ctx, From);
538     llvm::SMTExprRef FromExp =
539         Solver->mkBitvector(NewFromInt, NewFromInt.getBitWidth());
540 
541     // Convert symbol
542     QualType SymTy;
543     llvm::SMTExprRef Exp = getExpr(Solver, Ctx, Sym, &SymTy);
544 
545     // Construct single (in)equality
546     if (From == To)
547       return getBinExpr(Solver, Ctx, Exp, SymTy, InRange ? BO_EQ : BO_NE,
548                         FromExp, FromTy, /*RetTy=*/nullptr);
549 
550     QualType ToTy;
551     llvm::APSInt NewToInt;
552     std::tie(NewToInt, ToTy) = fixAPSInt(Ctx, To);
553     llvm::SMTExprRef ToExp =
554         Solver->mkBitvector(NewToInt, NewToInt.getBitWidth());
555     assert(FromTy == ToTy && "Range values have different types!");
556 
557     // Construct two (in)equalities, and a logical and/or
558     llvm::SMTExprRef LHS =
559         getBinExpr(Solver, Ctx, Exp, SymTy, InRange ? BO_GE : BO_LT, FromExp,
560                    FromTy, /*RetTy=*/nullptr);
561     llvm::SMTExprRef RHS = getBinExpr(Solver, Ctx, Exp, SymTy,
562                                       InRange ? BO_LE : BO_GT, ToExp, ToTy,
563                                       /*RetTy=*/nullptr);
564 
565     return fromBinOp(Solver, LHS, InRange ? BO_LAnd : BO_LOr, RHS,
566                      SymTy->isSignedIntegerOrEnumerationType());
567   }
568 
569   // Recover the QualType of an APSInt.
570   // TODO: Refactor to put elsewhere
getAPSIntType(ASTContext & Ctx,const llvm::APSInt & Int)571   static inline QualType getAPSIntType(ASTContext &Ctx,
572                                        const llvm::APSInt &Int) {
573     return Ctx.getIntTypeForBitwidth(Int.getBitWidth(), Int.isSigned());
574   }
575 
576   // Get the QualTy for the input APSInt, and fix it if it has a bitwidth of 1.
577   static inline std::pair<llvm::APSInt, QualType>
fixAPSInt(ASTContext & Ctx,const llvm::APSInt & Int)578   fixAPSInt(ASTContext &Ctx, const llvm::APSInt &Int) {
579     llvm::APSInt NewInt;
580 
581     // FIXME: This should be a cast from a 1-bit integer type to a boolean type,
582     // but the former is not available in Clang. Instead, extend the APSInt
583     // directly.
584     if (Int.getBitWidth() == 1 && getAPSIntType(Ctx, Int).isNull()) {
585       NewInt = Int.extend(Ctx.getTypeSize(Ctx.BoolTy));
586     } else
587       NewInt = Int;
588 
589     return std::make_pair(NewInt, getAPSIntType(Ctx, NewInt));
590   }
591 
592   // Perform implicit type conversion on binary symbolic expressions.
593   // May modify all input parameters.
594   // TODO: Refactor to use built-in conversion functions
doTypeConversion(llvm::SMTSolverRef & Solver,ASTContext & Ctx,llvm::SMTExprRef & LHS,llvm::SMTExprRef & RHS,QualType & LTy,QualType & RTy)595   static inline void doTypeConversion(llvm::SMTSolverRef &Solver,
596                                       ASTContext &Ctx, llvm::SMTExprRef &LHS,
597                                       llvm::SMTExprRef &RHS, QualType &LTy,
598                                       QualType &RTy) {
599     assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!");
600 
601     // Perform type conversion
602     if ((LTy->isIntegralOrEnumerationType() &&
603          RTy->isIntegralOrEnumerationType()) &&
604         (LTy->isArithmeticType() && RTy->isArithmeticType())) {
605       SMTConv::doIntTypeConversion<llvm::SMTExprRef, &fromCast>(
606           Solver, Ctx, LHS, LTy, RHS, RTy);
607       return;
608     }
609 
610     if (LTy->isRealFloatingType() || RTy->isRealFloatingType()) {
611       SMTConv::doFloatTypeConversion<llvm::SMTExprRef, &fromCast>(
612           Solver, Ctx, LHS, LTy, RHS, RTy);
613       return;
614     }
615 
616     if ((LTy->isAnyPointerType() || RTy->isAnyPointerType()) ||
617         (LTy->isBlockPointerType() || RTy->isBlockPointerType()) ||
618         (LTy->isReferenceType() || RTy->isReferenceType())) {
619       // TODO: Refactor to Sema::FindCompositePointerType(), and
620       // Sema::CheckCompareOperands().
621 
622       uint64_t LBitWidth = Ctx.getTypeSize(LTy);
623       uint64_t RBitWidth = Ctx.getTypeSize(RTy);
624 
625       // Cast the non-pointer type to the pointer type.
626       // TODO: Be more strict about this.
627       if ((LTy->isAnyPointerType() ^ RTy->isAnyPointerType()) ||
628           (LTy->isBlockPointerType() ^ RTy->isBlockPointerType()) ||
629           (LTy->isReferenceType() ^ RTy->isReferenceType())) {
630         if (LTy->isNullPtrType() || LTy->isBlockPointerType() ||
631             LTy->isReferenceType()) {
632           LHS = fromCast(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
633           LTy = RTy;
634         } else {
635           RHS = fromCast(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
636           RTy = LTy;
637         }
638       }
639 
640       // Cast the void pointer type to the non-void pointer type.
641       // For void types, this assumes that the casted value is equal to the
642       // value of the original pointer, and does not account for alignment
643       // requirements.
644       if (LTy->isVoidPointerType() ^ RTy->isVoidPointerType()) {
645         assert((Ctx.getTypeSize(LTy) == Ctx.getTypeSize(RTy)) &&
646                "Pointer types have different bitwidths!");
647         if (RTy->isVoidPointerType())
648           RTy = LTy;
649         else
650           LTy = RTy;
651       }
652 
653       if (LTy == RTy)
654         return;
655     }
656 
657     // Fallback: for the solver, assume that these types don't really matter
658     if ((LTy.getCanonicalType() == RTy.getCanonicalType()) ||
659         (LTy->isObjCObjectPointerType() && RTy->isObjCObjectPointerType())) {
660       LTy = RTy;
661       return;
662     }
663 
664     // TODO: Refine behavior for invalid type casts
665   }
666 
667   // Perform implicit integer type conversion.
668   // May modify all input parameters.
669   // TODO: Refactor to use Sema::handleIntegerConversion()
670   template <typename T, T (*doCast)(llvm::SMTSolverRef &Solver, const T &,
671                                     QualType, uint64_t, QualType, uint64_t)>
doIntTypeConversion(llvm::SMTSolverRef & Solver,ASTContext & Ctx,T & LHS,QualType & LTy,T & RHS,QualType & RTy)672   static inline void doIntTypeConversion(llvm::SMTSolverRef &Solver,
673                                          ASTContext &Ctx, T &LHS, QualType &LTy,
674                                          T &RHS, QualType &RTy) {
675     uint64_t LBitWidth = Ctx.getTypeSize(LTy);
676     uint64_t RBitWidth = Ctx.getTypeSize(RTy);
677 
678     assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!");
679     // Always perform integer promotion before checking type equality.
680     // Otherwise, e.g. (bool) a + (bool) b could trigger a backend assertion
681     if (Ctx.isPromotableIntegerType(LTy)) {
682       QualType NewTy = Ctx.getPromotedIntegerType(LTy);
683       uint64_t NewBitWidth = Ctx.getTypeSize(NewTy);
684       LHS = (*doCast)(Solver, LHS, NewTy, NewBitWidth, LTy, LBitWidth);
685       LTy = NewTy;
686       LBitWidth = NewBitWidth;
687     }
688     if (Ctx.isPromotableIntegerType(RTy)) {
689       QualType NewTy = Ctx.getPromotedIntegerType(RTy);
690       uint64_t NewBitWidth = Ctx.getTypeSize(NewTy);
691       RHS = (*doCast)(Solver, RHS, NewTy, NewBitWidth, RTy, RBitWidth);
692       RTy = NewTy;
693       RBitWidth = NewBitWidth;
694     }
695 
696     if (LTy == RTy)
697       return;
698 
699     // Perform integer type conversion
700     // Note: Safe to skip updating bitwidth because this must terminate
701     bool isLSignedTy = LTy->isSignedIntegerOrEnumerationType();
702     bool isRSignedTy = RTy->isSignedIntegerOrEnumerationType();
703 
704     int order = Ctx.getIntegerTypeOrder(LTy, RTy);
705     if (isLSignedTy == isRSignedTy) {
706       // Same signedness; use the higher-ranked type
707       if (order == 1) {
708         RHS = (*doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
709         RTy = LTy;
710       } else {
711         LHS = (*doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
712         LTy = RTy;
713       }
714     } else if (order != (isLSignedTy ? 1 : -1)) {
715       // The unsigned type has greater than or equal rank to the
716       // signed type, so use the unsigned type
717       if (isRSignedTy) {
718         RHS = (*doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
719         RTy = LTy;
720       } else {
721         LHS = (*doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
722         LTy = RTy;
723       }
724     } else if (LBitWidth != RBitWidth) {
725       // The two types are different widths; if we are here, that
726       // means the signed type is larger than the unsigned type, so
727       // use the signed type.
728       if (isLSignedTy) {
729         RHS = (doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
730         RTy = LTy;
731       } else {
732         LHS = (*doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
733         LTy = RTy;
734       }
735     } else {
736       // The signed type is higher-ranked than the unsigned type,
737       // but isn't actually any bigger (like unsigned int and long
738       // on most 32-bit systems).  Use the unsigned type corresponding
739       // to the signed type.
740       QualType NewTy =
741           Ctx.getCorrespondingUnsignedType(isLSignedTy ? LTy : RTy);
742       RHS = (*doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
743       RTy = NewTy;
744       LHS = (doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
745       LTy = NewTy;
746     }
747   }
748 
749   // Perform implicit floating-point type conversion.
750   // May modify all input parameters.
751   // TODO: Refactor to use Sema::handleFloatConversion()
752   template <typename T, T (*doCast)(llvm::SMTSolverRef &Solver, const T &,
753                                     QualType, uint64_t, QualType, uint64_t)>
754   static inline void
doFloatTypeConversion(llvm::SMTSolverRef & Solver,ASTContext & Ctx,T & LHS,QualType & LTy,T & RHS,QualType & RTy)755   doFloatTypeConversion(llvm::SMTSolverRef &Solver, ASTContext &Ctx, T &LHS,
756                         QualType &LTy, T &RHS, QualType &RTy) {
757     uint64_t LBitWidth = Ctx.getTypeSize(LTy);
758     uint64_t RBitWidth = Ctx.getTypeSize(RTy);
759 
760     // Perform float-point type promotion
761     if (!LTy->isRealFloatingType()) {
762       LHS = (*doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
763       LTy = RTy;
764       LBitWidth = RBitWidth;
765     }
766     if (!RTy->isRealFloatingType()) {
767       RHS = (*doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
768       RTy = LTy;
769       RBitWidth = LBitWidth;
770     }
771 
772     if (LTy == RTy)
773       return;
774 
775     // If we have two real floating types, convert the smaller operand to the
776     // bigger result
777     // Note: Safe to skip updating bitwidth because this must terminate
778     int order = Ctx.getFloatingTypeOrder(LTy, RTy);
779     if (order > 0) {
780       RHS = (*doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
781       RTy = LTy;
782     } else if (order == 0) {
783       LHS = (*doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
784       LTy = RTy;
785     } else {
786       llvm_unreachable("Unsupported floating-point type cast!");
787     }
788   }
789 };
790 } // namespace ento
791 } // namespace clang
792 
793 #endif
794