xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Support/SMTAPI.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- SMTAPI.h -------------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
//  This file defines a generic SMT solver API, which serves as the base class
//  for every solver-specific SMT class.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_SUPPORT_SMTAPI_H
15 #define LLVM_SUPPORT_SMTAPI_H
16 
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <memory>
#include <optional>
23 
24 namespace llvm {
25 
26 /// Generic base class for SMT sorts
27 class SMTSort {
28 public:
29   SMTSort() = default;
30   virtual ~SMTSort() = default;
31 
32   /// Returns true if the sort is a bitvector, calls isBitvectorSortImpl().
isBitvectorSort()33   virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); }
34 
35   /// Returns true if the sort is a floating-point, calls isFloatSortImpl().
isFloatSort()36   virtual bool isFloatSort() const { return isFloatSortImpl(); }
37 
38   /// Returns true if the sort is a boolean, calls isBooleanSortImpl().
isBooleanSort()39   virtual bool isBooleanSort() const { return isBooleanSortImpl(); }
40 
41   /// Returns the bitvector size, fails if the sort is not a bitvector
42   /// Calls getBitvectorSortSizeImpl().
getBitvectorSortSize()43   virtual unsigned getBitvectorSortSize() const {
44     assert(isBitvectorSort() && "Not a bitvector sort!");
45     unsigned Size = getBitvectorSortSizeImpl();
46     assert(Size && "Size is zero!");
47     return Size;
48   };
49 
50   /// Returns the floating-point size, fails if the sort is not a floating-point
51   /// Calls getFloatSortSizeImpl().
getFloatSortSize()52   virtual unsigned getFloatSortSize() const {
53     assert(isFloatSort() && "Not a floating-point sort!");
54     unsigned Size = getFloatSortSizeImpl();
55     assert(Size && "Size is zero!");
56     return Size;
57   };
58 
59   virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
60 
61   bool operator<(const SMTSort &Other) const {
62     llvm::FoldingSetNodeID ID1, ID2;
63     Profile(ID1);
64     Other.Profile(ID2);
65     return ID1 < ID2;
66   }
67 
68   friend bool operator==(SMTSort const &LHS, SMTSort const &RHS) {
69     return LHS.equal_to(RHS);
70   }
71 
72   virtual void print(raw_ostream &OS) const = 0;
73 
74 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
75   LLVM_DUMP_METHOD void dump() const;
76 #endif
77 
78 protected:
79   /// Query the SMT solver and returns true if two sorts are equal (same kind
80   /// and bit width). This does not check if the two sorts are the same objects.
81   virtual bool equal_to(SMTSort const &other) const = 0;
82 
83   /// Query the SMT solver and checks if a sort is bitvector.
84   virtual bool isBitvectorSortImpl() const = 0;
85 
86   /// Query the SMT solver and checks if a sort is floating-point.
87   virtual bool isFloatSortImpl() const = 0;
88 
89   /// Query the SMT solver and checks if a sort is boolean.
90   virtual bool isBooleanSortImpl() const = 0;
91 
92   /// Query the SMT solver and returns the sort bit width.
93   virtual unsigned getBitvectorSortSizeImpl() const = 0;
94 
95   /// Query the SMT solver and returns the sort bit width.
96   virtual unsigned getFloatSortSizeImpl() const = 0;
97 };
98 
99 /// Shared pointer for SMTSorts, used by SMTSolver API.
100 using SMTSortRef = const SMTSort *;
101 
102 /// Generic base class for SMT exprs
103 class SMTExpr {
104 public:
105   SMTExpr() = default;
106   virtual ~SMTExpr() = default;
107 
108   bool operator<(const SMTExpr &Other) const {
109     llvm::FoldingSetNodeID ID1, ID2;
110     Profile(ID1);
111     Other.Profile(ID2);
112     return ID1 < ID2;
113   }
114 
115   virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
116 
117   friend bool operator==(SMTExpr const &LHS, SMTExpr const &RHS) {
118     return LHS.equal_to(RHS);
119   }
120 
121   virtual void print(raw_ostream &OS) const = 0;
122 
123 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
124   LLVM_DUMP_METHOD void dump() const;
125 #endif
126 
127 protected:
128   /// Query the SMT solver and returns true if two sorts are equal (same kind
129   /// and bit width). This does not check if the two sorts are the same objects.
130   virtual bool equal_to(SMTExpr const &other) const = 0;
131 };
132 
133 class SMTSolverStatistics {
134 public:
135   SMTSolverStatistics() = default;
136   virtual ~SMTSolverStatistics() = default;
137 
138   virtual double getDouble(llvm::StringRef) const = 0;
139   virtual unsigned getUnsigned(llvm::StringRef) const = 0;
140 
141   virtual void print(raw_ostream &OS) const = 0;
142 
143 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
144   LLVM_DUMP_METHOD void dump() const;
145 #endif
146 };
147 
148 /// Shared pointer for SMTExprs, used by SMTSolver API.
149 using SMTExprRef = const SMTExpr *;
150 
151 /// Generic base class for SMT Solvers
152 ///
153 /// This class is responsible for wrapping all sorts and expression generation,
154 /// through the mk* methods. It also provides methods to create SMT expressions
155 /// straight from clang's AST, through the from* methods.
156 class SMTSolver {
157 public:
158   SMTSolver() = default;
159   virtual ~SMTSolver() = default;
160 
161 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
162   LLVM_DUMP_METHOD void dump() const;
163 #endif
164 
165   // Returns an appropriate floating-point sort for the given bitwidth.
getFloatSort(unsigned BitWidth)166   SMTSortRef getFloatSort(unsigned BitWidth) {
167     switch (BitWidth) {
168     case 16:
169       return getFloat16Sort();
170     case 32:
171       return getFloat32Sort();
172     case 64:
173       return getFloat64Sort();
174     case 128:
175       return getFloat128Sort();
176     default:;
177     }
178     llvm_unreachable("Unsupported floating-point bitwidth!");
179   }
180 
181   // Returns a boolean sort.
182   virtual SMTSortRef getBoolSort() = 0;
183 
184   // Returns an appropriate bitvector sort for the given bitwidth.
185   virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0;
186 
187   // Returns a floating-point sort of width 16
188   virtual SMTSortRef getFloat16Sort() = 0;
189 
190   // Returns a floating-point sort of width 32
191   virtual SMTSortRef getFloat32Sort() = 0;
192 
193   // Returns a floating-point sort of width 64
194   virtual SMTSortRef getFloat64Sort() = 0;
195 
196   // Returns a floating-point sort of width 128
197   virtual SMTSortRef getFloat128Sort() = 0;
198 
199   // Returns an appropriate sort for the given AST.
200   virtual SMTSortRef getSort(const SMTExprRef &AST) = 0;
201 
202   /// Given a constraint, adds it to the solver
203   virtual void addConstraint(const SMTExprRef &Exp) const = 0;
204 
205   /// Creates a bitvector addition operation
206   virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
207 
208   /// Creates a bitvector subtraction operation
209   virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
210 
211   /// Creates a bitvector multiplication operation
212   virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
213 
214   /// Creates a bitvector signed modulus operation
215   virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
216 
217   /// Creates a bitvector unsigned modulus operation
218   virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
219 
220   /// Creates a bitvector signed division operation
221   virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
222 
223   /// Creates a bitvector unsigned division operation
224   virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
225 
226   /// Creates a bitvector logical shift left operation
227   virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
228 
229   /// Creates a bitvector arithmetic shift right operation
230   virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
231 
232   /// Creates a bitvector logical shift right operation
233   virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
234 
235   /// Creates a bitvector negation operation
236   virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0;
237 
238   /// Creates a bitvector not operation
239   virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0;
240 
241   /// Creates a bitvector xor operation
242   virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
243 
244   /// Creates a bitvector or operation
245   virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
246 
247   /// Creates a bitvector and operation
248   virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
249 
250   /// Creates a bitvector unsigned less-than operation
251   virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
252 
253   /// Creates a bitvector signed less-than operation
254   virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
255 
256   /// Creates a bitvector unsigned greater-than operation
257   virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
258 
259   /// Creates a bitvector signed greater-than operation
260   virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
261 
262   /// Creates a bitvector unsigned less-equal-than operation
263   virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
264 
265   /// Creates a bitvector signed less-equal-than operation
266   virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
267 
268   /// Creates a bitvector unsigned greater-equal-than operation
269   virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
270 
271   /// Creates a bitvector signed greater-equal-than operation
272   virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
273 
274   /// Creates a boolean not operation
275   virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0;
276 
277   /// Creates a boolean equality operation
278   virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
279 
280   /// Creates a boolean and operation
281   virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
282 
283   /// Creates a boolean or operation
284   virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
285 
286   /// Creates a boolean ite operation
287   virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
288                            const SMTExprRef &F) = 0;
289 
290   /// Creates a bitvector sign extension operation
291   virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0;
292 
293   /// Creates a bitvector zero extension operation
294   virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0;
295 
296   /// Creates a bitvector extract operation
297   virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low,
298                                  const SMTExprRef &Exp) = 0;
299 
300   /// Creates a bitvector concat operation
301   virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS,
302                                 const SMTExprRef &RHS) = 0;
303 
304   /// Creates a predicate that checks for overflow in a bitvector addition
305   /// operation
306   virtual SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS,
307                                        const SMTExprRef &RHS,
308                                        bool isSigned) = 0;
309 
310   /// Creates a predicate that checks for underflow in a signed bitvector
311   /// addition operation
312   virtual SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS,
313                                         const SMTExprRef &RHS) = 0;
314 
315   /// Creates a predicate that checks for overflow in a signed bitvector
316   /// subtraction operation
317   virtual SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS,
318                                        const SMTExprRef &RHS) = 0;
319 
320   /// Creates a predicate that checks for underflow in a bitvector subtraction
321   /// operation
322   virtual SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS,
323                                         const SMTExprRef &RHS,
324                                         bool isSigned) = 0;
325 
326   /// Creates a predicate that checks for overflow in a signed bitvector
327   /// division/modulus operation
328   virtual SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS,
329                                         const SMTExprRef &RHS) = 0;
330 
331   /// Creates a predicate that checks for overflow in a bitvector negation
332   /// operation
333   virtual SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) = 0;
334 
335   /// Creates a predicate that checks for overflow in a bitvector multiplication
336   /// operation
337   virtual SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS,
338                                        const SMTExprRef &RHS,
339                                        bool isSigned) = 0;
340 
341   /// Creates a predicate that checks for underflow in a signed bitvector
342   /// multiplication operation
343   virtual SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS,
344                                         const SMTExprRef &RHS) = 0;
345 
346   /// Creates a floating-point negation operation
347   virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0;
348 
349   /// Creates a floating-point isInfinite operation
350   virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0;
351 
352   /// Creates a floating-point isNaN operation
353   virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0;
354 
355   /// Creates a floating-point isNormal operation
356   virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0;
357 
358   /// Creates a floating-point isZero operation
359   virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0;
360 
361   /// Creates a floating-point multiplication operation
362   virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
363 
364   /// Creates a floating-point division operation
365   virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
366 
367   /// Creates a floating-point remainder operation
368   virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
369 
370   /// Creates a floating-point addition operation
371   virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
372 
373   /// Creates a floating-point subtraction operation
374   virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
375 
376   /// Creates a floating-point less-than operation
377   virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
378 
379   /// Creates a floating-point greater-than operation
380   virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
381 
382   /// Creates a floating-point less-than-or-equal operation
383   virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
384 
385   /// Creates a floating-point greater-than-or-equal operation
386   virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
387 
388   /// Creates a floating-point equality operation
389   virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS,
390                                const SMTExprRef &RHS) = 0;
391 
392   /// Creates a floating-point conversion from floatint-point to floating-point
393   /// operation
394   virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0;
395 
396   /// Creates a floating-point conversion from signed bitvector to
397   /// floatint-point operation
398   virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From,
399                                const SMTSortRef &To) = 0;
400 
401   /// Creates a floating-point conversion from unsigned bitvector to
402   /// floatint-point operation
403   virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From,
404                                const SMTSortRef &To) = 0;
405 
406   /// Creates a floating-point conversion from floatint-point to signed
407   /// bitvector operation
408   virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) = 0;
409 
410   /// Creates a floating-point conversion from floatint-point to unsigned
411   /// bitvector operation
412   virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) = 0;
413 
414   /// Creates a new symbol, given a name and a sort
415   virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0;
416 
417   // Returns an appropriate floating-point rounding mode.
418   virtual SMTExprRef getFloatRoundingMode() = 0;
419 
420   // If the a model is available, returns the value of a given bitvector symbol
421   virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth,
422                                     bool isUnsigned) = 0;
423 
424   // If the a model is available, returns the value of a given boolean symbol
425   virtual bool getBoolean(const SMTExprRef &Exp) = 0;
426 
427   /// Constructs an SMTExprRef from a boolean.
428   virtual SMTExprRef mkBoolean(const bool b) = 0;
429 
430   /// Constructs an SMTExprRef from a finite APFloat.
431   virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0;
432 
433   /// Constructs an SMTExprRef from an APSInt and its bit width
434   virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0;
435 
436   /// Given an expression, extract the value of this operand in the model.
437   virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0;
438 
439   /// Given an expression extract the value of this operand in the model.
440   virtual bool getInterpretation(const SMTExprRef &Exp,
441                                  llvm::APFloat &Float) = 0;
442 
443   /// Check if the constraints are satisfiable
444   virtual std::optional<bool> check() const = 0;
445 
446   /// Push the current solver state
447   virtual void push() = 0;
448 
449   /// Pop the previous solver state
450   virtual void pop(unsigned NumStates = 1) = 0;
451 
452   /// Reset the solver and remove all constraints.
453   virtual void reset() = 0;
454 
455   /// Checks if the solver supports floating-points.
456   virtual bool isFPSupported() = 0;
457 
458   virtual void print(raw_ostream &OS) const = 0;
459 
460   /// Sets the requested option.
461   virtual void setBoolParam(StringRef Key, bool Value) = 0;
462   virtual void setUnsignedParam(StringRef Key, unsigned Value) = 0;
463 
464   virtual std::unique_ptr<SMTSolverStatistics> getStatistics() const = 0;
465 };
466 
467 /// Shared pointer for SMTSolvers.
468 using SMTSolverRef = std::shared_ptr<SMTSolver>;
469 
470 /// Convenience method to create and Z3Solver object
471 LLVM_ABI SMTSolverRef CreateZ3Solver();
472 
473 } // namespace llvm
474 
475 #endif
476