xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Support/SMTAPI.h (revision da5432eda807c4b7232d030d5157d5b417ea4f52)
1 //===- SMTAPI.h -------------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  This file defines a SMT generic Solver API, which will be the base class
10 //  for every SMT solver specific class.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_SUPPORT_SMTAPI_H
15 #define LLVM_SUPPORT_SMTAPI_H
16 
17 #include "llvm/ADT/APFloat.h"
18 #include "llvm/ADT/APSInt.h"
19 #include "llvm/ADT/FoldingSet.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include <memory>
22 
23 namespace llvm {
24 
25 /// Generic base class for SMT sorts
26 class SMTSort {
27 public:
28   SMTSort() = default;
29   virtual ~SMTSort() = default;
30 
31   /// Returns true if the sort is a bitvector, calls isBitvectorSortImpl().
32   virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); }
33 
34   /// Returns true if the sort is a floating-point, calls isFloatSortImpl().
35   virtual bool isFloatSort() const { return isFloatSortImpl(); }
36 
37   /// Returns true if the sort is a boolean, calls isBooleanSortImpl().
38   virtual bool isBooleanSort() const { return isBooleanSortImpl(); }
39 
40   /// Returns the bitvector size, fails if the sort is not a bitvector
41   /// Calls getBitvectorSortSizeImpl().
42   virtual unsigned getBitvectorSortSize() const {
43     assert(isBitvectorSort() && "Not a bitvector sort!");
44     unsigned Size = getBitvectorSortSizeImpl();
45     assert(Size && "Size is zero!");
46     return Size;
47   };
48 
49   /// Returns the floating-point size, fails if the sort is not a floating-point
50   /// Calls getFloatSortSizeImpl().
51   virtual unsigned getFloatSortSize() const {
52     assert(isFloatSort() && "Not a floating-point sort!");
53     unsigned Size = getFloatSortSizeImpl();
54     assert(Size && "Size is zero!");
55     return Size;
56   };
57 
58   virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
59 
60   bool operator<(const SMTSort &Other) const {
61     llvm::FoldingSetNodeID ID1, ID2;
62     Profile(ID1);
63     Other.Profile(ID2);
64     return ID1 < ID2;
65   }
66 
67   friend bool operator==(SMTSort const &LHS, SMTSort const &RHS) {
68     return LHS.equal_to(RHS);
69   }
70 
71   virtual void print(raw_ostream &OS) const = 0;
72 
73   LLVM_DUMP_METHOD void dump() const;
74 
75 protected:
76   /// Query the SMT solver and returns true if two sorts are equal (same kind
77   /// and bit width). This does not check if the two sorts are the same objects.
78   virtual bool equal_to(SMTSort const &other) const = 0;
79 
80   /// Query the SMT solver and checks if a sort is bitvector.
81   virtual bool isBitvectorSortImpl() const = 0;
82 
83   /// Query the SMT solver and checks if a sort is floating-point.
84   virtual bool isFloatSortImpl() const = 0;
85 
86   /// Query the SMT solver and checks if a sort is boolean.
87   virtual bool isBooleanSortImpl() const = 0;
88 
89   /// Query the SMT solver and returns the sort bit width.
90   virtual unsigned getBitvectorSortSizeImpl() const = 0;
91 
92   /// Query the SMT solver and returns the sort bit width.
93   virtual unsigned getFloatSortSizeImpl() const = 0;
94 };
95 
96 /// Shared pointer for SMTSorts, used by SMTSolver API.
97 using SMTSortRef = const SMTSort *;
98 
99 /// Generic base class for SMT exprs
100 class SMTExpr {
101 public:
102   SMTExpr() = default;
103   virtual ~SMTExpr() = default;
104 
105   bool operator<(const SMTExpr &Other) const {
106     llvm::FoldingSetNodeID ID1, ID2;
107     Profile(ID1);
108     Other.Profile(ID2);
109     return ID1 < ID2;
110   }
111 
112   virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
113 
114   friend bool operator==(SMTExpr const &LHS, SMTExpr const &RHS) {
115     return LHS.equal_to(RHS);
116   }
117 
118   virtual void print(raw_ostream &OS) const = 0;
119 
120   LLVM_DUMP_METHOD void dump() const;
121 
122 protected:
123   /// Query the SMT solver and returns true if two sorts are equal (same kind
124   /// and bit width). This does not check if the two sorts are the same objects.
125   virtual bool equal_to(SMTExpr const &other) const = 0;
126 };
127 
128 /// Shared pointer for SMTExprs, used by SMTSolver API.
129 using SMTExprRef = const SMTExpr *;
130 
131 /// Generic base class for SMT Solvers
132 ///
133 /// This class is responsible for wrapping all sorts and expression generation,
134 /// through the mk* methods. It also provides methods to create SMT expressions
135 /// straight from clang's AST, through the from* methods.
136 class SMTSolver {
137 public:
138   SMTSolver() = default;
139   virtual ~SMTSolver() = default;
140 
141   LLVM_DUMP_METHOD void dump() const;
142 
143   // Returns an appropriate floating-point sort for the given bitwidth.
144   SMTSortRef getFloatSort(unsigned BitWidth) {
145     switch (BitWidth) {
146     case 16:
147       return getFloat16Sort();
148     case 32:
149       return getFloat32Sort();
150     case 64:
151       return getFloat64Sort();
152     case 128:
153       return getFloat128Sort();
154     default:;
155     }
156     llvm_unreachable("Unsupported floating-point bitwidth!");
157   }
158 
159   // Returns a boolean sort.
160   virtual SMTSortRef getBoolSort() = 0;
161 
162   // Returns an appropriate bitvector sort for the given bitwidth.
163   virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0;
164 
165   // Returns a floating-point sort of width 16
166   virtual SMTSortRef getFloat16Sort() = 0;
167 
168   // Returns a floating-point sort of width 32
169   virtual SMTSortRef getFloat32Sort() = 0;
170 
171   // Returns a floating-point sort of width 64
172   virtual SMTSortRef getFloat64Sort() = 0;
173 
174   // Returns a floating-point sort of width 128
175   virtual SMTSortRef getFloat128Sort() = 0;
176 
177   // Returns an appropriate sort for the given AST.
178   virtual SMTSortRef getSort(const SMTExprRef &AST) = 0;
179 
180   /// Given a constraint, adds it to the solver
181   virtual void addConstraint(const SMTExprRef &Exp) const = 0;
182 
183   /// Creates a bitvector addition operation
184   virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
185 
186   /// Creates a bitvector subtraction operation
187   virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
188 
189   /// Creates a bitvector multiplication operation
190   virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
191 
192   /// Creates a bitvector signed modulus operation
193   virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
194 
195   /// Creates a bitvector unsigned modulus operation
196   virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
197 
198   /// Creates a bitvector signed division operation
199   virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
200 
201   /// Creates a bitvector unsigned division operation
202   virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
203 
204   /// Creates a bitvector logical shift left operation
205   virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
206 
207   /// Creates a bitvector arithmetic shift right operation
208   virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
209 
210   /// Creates a bitvector logical shift right operation
211   virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
212 
213   /// Creates a bitvector negation operation
214   virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0;
215 
216   /// Creates a bitvector not operation
217   virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0;
218 
219   /// Creates a bitvector xor operation
220   virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
221 
222   /// Creates a bitvector or operation
223   virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
224 
225   /// Creates a bitvector and operation
226   virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
227 
228   /// Creates a bitvector unsigned less-than operation
229   virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
230 
231   /// Creates a bitvector signed less-than operation
232   virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
233 
234   /// Creates a bitvector unsigned greater-than operation
235   virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
236 
237   /// Creates a bitvector signed greater-than operation
238   virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
239 
240   /// Creates a bitvector unsigned less-equal-than operation
241   virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
242 
243   /// Creates a bitvector signed less-equal-than operation
244   virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
245 
246   /// Creates a bitvector unsigned greater-equal-than operation
247   virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
248 
249   /// Creates a bitvector signed greater-equal-than operation
250   virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
251 
252   /// Creates a boolean not operation
253   virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0;
254 
255   /// Creates a boolean equality operation
256   virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
257 
258   /// Creates a boolean and operation
259   virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
260 
261   /// Creates a boolean or operation
262   virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
263 
264   /// Creates a boolean ite operation
265   virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
266                            const SMTExprRef &F) = 0;
267 
268   /// Creates a bitvector sign extension operation
269   virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0;
270 
271   /// Creates a bitvector zero extension operation
272   virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0;
273 
274   /// Creates a bitvector extract operation
275   virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low,
276                                  const SMTExprRef &Exp) = 0;
277 
278   /// Creates a bitvector concat operation
279   virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS,
280                                 const SMTExprRef &RHS) = 0;
281 
282   /// Creates a predicate that checks for overflow in a bitvector addition
283   /// operation
284   virtual SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS,
285                                        const SMTExprRef &RHS,
286                                        bool isSigned) = 0;
287 
288   /// Creates a predicate that checks for underflow in a signed bitvector
289   /// addition operation
290   virtual SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS,
291                                         const SMTExprRef &RHS) = 0;
292 
293   /// Creates a predicate that checks for overflow in a signed bitvector
294   /// subtraction operation
295   virtual SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS,
296                                        const SMTExprRef &RHS) = 0;
297 
298   /// Creates a predicate that checks for underflow in a bitvector subtraction
299   /// operation
300   virtual SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS,
301                                         const SMTExprRef &RHS,
302                                         bool isSigned) = 0;
303 
304   /// Creates a predicate that checks for overflow in a signed bitvector
305   /// division/modulus operation
306   virtual SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS,
307                                         const SMTExprRef &RHS) = 0;
308 
309   /// Creates a predicate that checks for overflow in a bitvector negation
310   /// operation
311   virtual SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) = 0;
312 
313   /// Creates a predicate that checks for overflow in a bitvector multiplication
314   /// operation
315   virtual SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS,
316                                        const SMTExprRef &RHS,
317                                        bool isSigned) = 0;
318 
319   /// Creates a predicate that checks for underflow in a signed bitvector
320   /// multiplication operation
321   virtual SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS,
322                                         const SMTExprRef &RHS) = 0;
323 
324   /// Creates a floating-point negation operation
325   virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0;
326 
327   /// Creates a floating-point isInfinite operation
328   virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0;
329 
330   /// Creates a floating-point isNaN operation
331   virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0;
332 
333   /// Creates a floating-point isNormal operation
334   virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0;
335 
336   /// Creates a floating-point isZero operation
337   virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0;
338 
339   /// Creates a floating-point multiplication operation
340   virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
341 
342   /// Creates a floating-point division operation
343   virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
344 
345   /// Creates a floating-point remainder operation
346   virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
347 
348   /// Creates a floating-point addition operation
349   virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
350 
351   /// Creates a floating-point subtraction operation
352   virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
353 
354   /// Creates a floating-point less-than operation
355   virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
356 
357   /// Creates a floating-point greater-than operation
358   virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
359 
360   /// Creates a floating-point less-than-or-equal operation
361   virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
362 
363   /// Creates a floating-point greater-than-or-equal operation
364   virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
365 
366   /// Creates a floating-point equality operation
367   virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS,
368                                const SMTExprRef &RHS) = 0;
369 
370   /// Creates a floating-point conversion from floatint-point to floating-point
371   /// operation
372   virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0;
373 
374   /// Creates a floating-point conversion from signed bitvector to
375   /// floatint-point operation
376   virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From,
377                                const SMTSortRef &To) = 0;
378 
379   /// Creates a floating-point conversion from unsigned bitvector to
380   /// floatint-point operation
381   virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From,
382                                const SMTSortRef &To) = 0;
383 
384   /// Creates a floating-point conversion from floatint-point to signed
385   /// bitvector operation
386   virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) = 0;
387 
388   /// Creates a floating-point conversion from floatint-point to unsigned
389   /// bitvector operation
390   virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) = 0;
391 
392   /// Creates a new symbol, given a name and a sort
393   virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0;
394 
395   // Returns an appropriate floating-point rounding mode.
396   virtual SMTExprRef getFloatRoundingMode() = 0;
397 
398   // If the a model is available, returns the value of a given bitvector symbol
399   virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth,
400                                     bool isUnsigned) = 0;
401 
402   // If the a model is available, returns the value of a given boolean symbol
403   virtual bool getBoolean(const SMTExprRef &Exp) = 0;
404 
405   /// Constructs an SMTExprRef from a boolean.
406   virtual SMTExprRef mkBoolean(const bool b) = 0;
407 
408   /// Constructs an SMTExprRef from a finite APFloat.
409   virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0;
410 
411   /// Constructs an SMTExprRef from an APSInt and its bit width
412   virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0;
413 
414   /// Given an expression, extract the value of this operand in the model.
415   virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0;
416 
417   /// Given an expression extract the value of this operand in the model.
418   virtual bool getInterpretation(const SMTExprRef &Exp,
419                                  llvm::APFloat &Float) = 0;
420 
421   /// Check if the constraints are satisfiable
422   virtual std::optional<bool> check() const = 0;
423 
424   /// Push the current solver state
425   virtual void push() = 0;
426 
427   /// Pop the previous solver state
428   virtual void pop(unsigned NumStates = 1) = 0;
429 
430   /// Reset the solver and remove all constraints.
431   virtual void reset() = 0;
432 
433   /// Checks if the solver supports floating-points.
434   virtual bool isFPSupported() = 0;
435 
436   virtual void print(raw_ostream &OS) const = 0;
437 };
438 
439 /// Shared pointer for SMTSolvers.
440 using SMTSolverRef = std::shared_ptr<SMTSolver>;
441 
442 /// Convenience method to create and Z3Solver object
443 SMTSolverRef CreateZ3Solver();
444 
445 } // namespace llvm
446 
447 #endif
448