xref: /freebsd/contrib/llvm-project/llvm/lib/Support/Z3Solver.cpp (revision 8ddb146abcdf061be9f2c0db7e391697dafad85c)
1 //== Z3Solver.cpp -----------------------------------------------*- C++ -*--==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ADT/SmallString.h"
10 #include "llvm/ADT/Twine.h"
11 #include "llvm/Config/config.h"
12 #include "llvm/Support/SMTAPI.h"
13 #include <set>
14 
15 using namespace llvm;
16 
17 #if LLVM_WITH_Z3
18 
19 #include <z3.h>
20 
21 namespace {
22 
23 /// Configuration class for Z3
24 class Z3Config {
25   friend class Z3Context;
26 
27   Z3_config Config;
28 
29 public:
30   Z3Config() : Config(Z3_mk_config()) {
31     // Enable model finding
32     Z3_set_param_value(Config, "model", "true");
33     // Disable proof generation
34     Z3_set_param_value(Config, "proof", "false");
35     // Set timeout to 15000ms = 15s
36     Z3_set_param_value(Config, "timeout", "15000");
37   }
38 
39   ~Z3Config() { Z3_del_config(Config); }
40 }; // end class Z3Config
41 
42 // Function used to report errors
43 void Z3ErrorHandler(Z3_context Context, Z3_error_code Error) {
44   llvm::report_fatal_error("Z3 error: " +
45                            llvm::Twine(Z3_get_error_msg(Context, Error)));
46 }
47 
48 /// Wrapper for Z3 context
49 class Z3Context {
50 public:
51   Z3_context Context;
52 
53   Z3Context() {
54     Context = Z3_mk_context_rc(Z3Config().Config);
55     // The error function is set here because the context is the first object
56     // created by the backend
57     Z3_set_error_handler(Context, Z3ErrorHandler);
58   }
59 
60   virtual ~Z3Context() {
61     Z3_del_context(Context);
62     Context = nullptr;
63   }
64 }; // end class Z3Context
65 
66 /// Wrapper for Z3 Sort
67 class Z3Sort : public SMTSort {
68   friend class Z3Solver;
69 
70   Z3Context &Context;
71 
72   Z3_sort Sort;
73 
74 public:
75   /// Default constructor, mainly used by make_shared
76   Z3Sort(Z3Context &C, Z3_sort ZS) : Context(C), Sort(ZS) {
77     Z3_inc_ref(Context.Context, reinterpret_cast<Z3_ast>(Sort));
78   }
79 
80   /// Override implicit copy constructor for correct reference counting.
81   Z3Sort(const Z3Sort &Other) : Context(Other.Context), Sort(Other.Sort) {
82     Z3_inc_ref(Context.Context, reinterpret_cast<Z3_ast>(Sort));
83   }
84 
85   /// Override implicit copy assignment constructor for correct reference
86   /// counting.
87   Z3Sort &operator=(const Z3Sort &Other) {
88     Z3_inc_ref(Context.Context, reinterpret_cast<Z3_ast>(Other.Sort));
89     Z3_dec_ref(Context.Context, reinterpret_cast<Z3_ast>(Sort));
90     Sort = Other.Sort;
91     return *this;
92   }
93 
94   Z3Sort(Z3Sort &&Other) = delete;
95   Z3Sort &operator=(Z3Sort &&Other) = delete;
96 
97   ~Z3Sort() {
98     if (Sort)
99       Z3_dec_ref(Context.Context, reinterpret_cast<Z3_ast>(Sort));
100   }
101 
102   void Profile(llvm::FoldingSetNodeID &ID) const override {
103     ID.AddInteger(
104         Z3_get_ast_id(Context.Context, reinterpret_cast<Z3_ast>(Sort)));
105   }
106 
107   bool isBitvectorSortImpl() const override {
108     return (Z3_get_sort_kind(Context.Context, Sort) == Z3_BV_SORT);
109   }
110 
111   bool isFloatSortImpl() const override {
112     return (Z3_get_sort_kind(Context.Context, Sort) == Z3_FLOATING_POINT_SORT);
113   }
114 
115   bool isBooleanSortImpl() const override {
116     return (Z3_get_sort_kind(Context.Context, Sort) == Z3_BOOL_SORT);
117   }
118 
119   unsigned getBitvectorSortSizeImpl() const override {
120     return Z3_get_bv_sort_size(Context.Context, Sort);
121   }
122 
123   unsigned getFloatSortSizeImpl() const override {
124     return Z3_fpa_get_ebits(Context.Context, Sort) +
125            Z3_fpa_get_sbits(Context.Context, Sort);
126   }
127 
128   bool equal_to(SMTSort const &Other) const override {
129     return Z3_is_eq_sort(Context.Context, Sort,
130                          static_cast<const Z3Sort &>(Other).Sort);
131   }
132 
133   void print(raw_ostream &OS) const override {
134     OS << Z3_sort_to_string(Context.Context, Sort);
135   }
136 }; // end class Z3Sort
137 
138 static const Z3Sort &toZ3Sort(const SMTSort &S) {
139   return static_cast<const Z3Sort &>(S);
140 }
141 
142 class Z3Expr : public SMTExpr {
143   friend class Z3Solver;
144 
145   Z3Context &Context;
146 
147   Z3_ast AST;
148 
149 public:
150   Z3Expr(Z3Context &C, Z3_ast ZA) : SMTExpr(), Context(C), AST(ZA) {
151     Z3_inc_ref(Context.Context, AST);
152   }
153 
154   /// Override implicit copy constructor for correct reference counting.
155   Z3Expr(const Z3Expr &Copy) : SMTExpr(), Context(Copy.Context), AST(Copy.AST) {
156     Z3_inc_ref(Context.Context, AST);
157   }
158 
159   /// Override implicit copy assignment constructor for correct reference
160   /// counting.
161   Z3Expr &operator=(const Z3Expr &Other) {
162     Z3_inc_ref(Context.Context, Other.AST);
163     Z3_dec_ref(Context.Context, AST);
164     AST = Other.AST;
165     return *this;
166   }
167 
168   Z3Expr(Z3Expr &&Other) = delete;
169   Z3Expr &operator=(Z3Expr &&Other) = delete;
170 
171   ~Z3Expr() {
172     if (AST)
173       Z3_dec_ref(Context.Context, AST);
174   }
175 
176   void Profile(llvm::FoldingSetNodeID &ID) const override {
177     ID.AddInteger(Z3_get_ast_id(Context.Context, AST));
178   }
179 
180   /// Comparison of AST equality, not model equivalence.
181   bool equal_to(SMTExpr const &Other) const override {
182     assert(Z3_is_eq_sort(Context.Context, Z3_get_sort(Context.Context, AST),
183                          Z3_get_sort(Context.Context,
184                                      static_cast<const Z3Expr &>(Other).AST)) &&
185            "AST's must have the same sort");
186     return Z3_is_eq_ast(Context.Context, AST,
187                         static_cast<const Z3Expr &>(Other).AST);
188   }
189 
190   void print(raw_ostream &OS) const override {
191     OS << Z3_ast_to_string(Context.Context, AST);
192   }
193 }; // end class Z3Expr
194 
195 static const Z3Expr &toZ3Expr(const SMTExpr &E) {
196   return static_cast<const Z3Expr &>(E);
197 }
198 
199 class Z3Model {
200   friend class Z3Solver;
201 
202   Z3Context &Context;
203 
204   Z3_model Model;
205 
206 public:
207   Z3Model(Z3Context &C, Z3_model ZM) : Context(C), Model(ZM) {
208     Z3_model_inc_ref(Context.Context, Model);
209   }
210 
211   Z3Model(const Z3Model &Other) = delete;
212   Z3Model(Z3Model &&Other) = delete;
213   Z3Model &operator=(Z3Model &Other) = delete;
214   Z3Model &operator=(Z3Model &&Other) = delete;
215 
216   ~Z3Model() {
217     if (Model)
218       Z3_model_dec_ref(Context.Context, Model);
219   }
220 
221   void print(raw_ostream &OS) const {
222     OS << Z3_model_to_string(Context.Context, Model);
223   }
224 
225   LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); }
226 }; // end class Z3Model
227 
228 /// Get the corresponding IEEE floating-point type for a given bitwidth.
229 static const llvm::fltSemantics &getFloatSemantics(unsigned BitWidth) {
230   switch (BitWidth) {
231   default:
232     llvm_unreachable("Unsupported floating-point semantics!");
233     break;
234   case 16:
235     return llvm::APFloat::IEEEhalf();
236   case 32:
237     return llvm::APFloat::IEEEsingle();
238   case 64:
239     return llvm::APFloat::IEEEdouble();
240   case 128:
241     return llvm::APFloat::IEEEquad();
242   }
243 }
244 
245 // Determine whether two float semantics are equivalent
246 static bool areEquivalent(const llvm::fltSemantics &LHS,
247                           const llvm::fltSemantics &RHS) {
248   return (llvm::APFloat::semanticsPrecision(LHS) ==
249           llvm::APFloat::semanticsPrecision(RHS)) &&
250          (llvm::APFloat::semanticsMinExponent(LHS) ==
251           llvm::APFloat::semanticsMinExponent(RHS)) &&
252          (llvm::APFloat::semanticsMaxExponent(LHS) ==
253           llvm::APFloat::semanticsMaxExponent(RHS)) &&
254          (llvm::APFloat::semanticsSizeInBits(LHS) ==
255           llvm::APFloat::semanticsSizeInBits(RHS));
256 }
257 
258 class Z3Solver : public SMTSolver {
259   friend class Z3ConstraintManager;
260 
261   Z3Context Context;
262 
263   Z3_solver Solver;
264 
265   // Cache Sorts
266   std::set<Z3Sort> CachedSorts;
267 
268   // Cache Exprs
269   std::set<Z3Expr> CachedExprs;
270 
271 public:
272   Z3Solver() : Solver(Z3_mk_simple_solver(Context.Context)) {
273     Z3_solver_inc_ref(Context.Context, Solver);
274   }
275 
276   Z3Solver(const Z3Solver &Other) = delete;
277   Z3Solver(Z3Solver &&Other) = delete;
278   Z3Solver &operator=(Z3Solver &Other) = delete;
279   Z3Solver &operator=(Z3Solver &&Other) = delete;
280 
281   ~Z3Solver() {
282     if (Solver)
283       Z3_solver_dec_ref(Context.Context, Solver);
284   }
285 
286   void addConstraint(const SMTExprRef &Exp) const override {
287     Z3_solver_assert(Context.Context, Solver, toZ3Expr(*Exp).AST);
288   }
289 
290   // Given an SMTSort, adds/retrives it from the cache and returns
291   // an SMTSortRef to the SMTSort in the cache
292   SMTSortRef newSortRef(const SMTSort &Sort) {
293     auto It = CachedSorts.insert(toZ3Sort(Sort));
294     return &(*It.first);
295   }
296 
297   // Given an SMTExpr, adds/retrives it from the cache and returns
298   // an SMTExprRef to the SMTExpr in the cache
299   SMTExprRef newExprRef(const SMTExpr &Exp) {
300     auto It = CachedExprs.insert(toZ3Expr(Exp));
301     return &(*It.first);
302   }
303 
304   SMTSortRef getBoolSort() override {
305     return newSortRef(Z3Sort(Context, Z3_mk_bool_sort(Context.Context)));
306   }
307 
308   SMTSortRef getBitvectorSort(unsigned BitWidth) override {
309     return newSortRef(
310         Z3Sort(Context, Z3_mk_bv_sort(Context.Context, BitWidth)));
311   }
312 
313   SMTSortRef getSort(const SMTExprRef &Exp) override {
314     return newSortRef(
315         Z3Sort(Context, Z3_get_sort(Context.Context, toZ3Expr(*Exp).AST)));
316   }
317 
318   SMTSortRef getFloat16Sort() override {
319     return newSortRef(Z3Sort(Context, Z3_mk_fpa_sort_16(Context.Context)));
320   }
321 
322   SMTSortRef getFloat32Sort() override {
323     return newSortRef(Z3Sort(Context, Z3_mk_fpa_sort_32(Context.Context)));
324   }
325 
326   SMTSortRef getFloat64Sort() override {
327     return newSortRef(Z3Sort(Context, Z3_mk_fpa_sort_64(Context.Context)));
328   }
329 
330   SMTSortRef getFloat128Sort() override {
331     return newSortRef(Z3Sort(Context, Z3_mk_fpa_sort_128(Context.Context)));
332   }
333 
334   SMTExprRef mkBVNeg(const SMTExprRef &Exp) override {
335     return newExprRef(
336         Z3Expr(Context, Z3_mk_bvneg(Context.Context, toZ3Expr(*Exp).AST)));
337   }
338 
339   SMTExprRef mkBVNot(const SMTExprRef &Exp) override {
340     return newExprRef(
341         Z3Expr(Context, Z3_mk_bvnot(Context.Context, toZ3Expr(*Exp).AST)));
342   }
343 
344   SMTExprRef mkNot(const SMTExprRef &Exp) override {
345     return newExprRef(
346         Z3Expr(Context, Z3_mk_not(Context.Context, toZ3Expr(*Exp).AST)));
347   }
348 
349   SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
350     return newExprRef(
351         Z3Expr(Context, Z3_mk_bvadd(Context.Context, toZ3Expr(*LHS).AST,
352                                     toZ3Expr(*RHS).AST)));
353   }
354 
355   SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
356     return newExprRef(
357         Z3Expr(Context, Z3_mk_bvsub(Context.Context, toZ3Expr(*LHS).AST,
358                                     toZ3Expr(*RHS).AST)));
359   }
360 
361   SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
362     return newExprRef(
363         Z3Expr(Context, Z3_mk_bvmul(Context.Context, toZ3Expr(*LHS).AST,
364                                     toZ3Expr(*RHS).AST)));
365   }
366 
367   SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
368     return newExprRef(
369         Z3Expr(Context, Z3_mk_bvsrem(Context.Context, toZ3Expr(*LHS).AST,
370                                      toZ3Expr(*RHS).AST)));
371   }
372 
373   SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
374     return newExprRef(
375         Z3Expr(Context, Z3_mk_bvurem(Context.Context, toZ3Expr(*LHS).AST,
376                                      toZ3Expr(*RHS).AST)));
377   }
378 
379   SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
380     return newExprRef(
381         Z3Expr(Context, Z3_mk_bvsdiv(Context.Context, toZ3Expr(*LHS).AST,
382                                      toZ3Expr(*RHS).AST)));
383   }
384 
385   SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
386     return newExprRef(
387         Z3Expr(Context, Z3_mk_bvudiv(Context.Context, toZ3Expr(*LHS).AST,
388                                      toZ3Expr(*RHS).AST)));
389   }
390 
391   SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
392     return newExprRef(
393         Z3Expr(Context, Z3_mk_bvshl(Context.Context, toZ3Expr(*LHS).AST,
394                                     toZ3Expr(*RHS).AST)));
395   }
396 
397   SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
398     return newExprRef(
399         Z3Expr(Context, Z3_mk_bvashr(Context.Context, toZ3Expr(*LHS).AST,
400                                      toZ3Expr(*RHS).AST)));
401   }
402 
403   SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
404     return newExprRef(
405         Z3Expr(Context, Z3_mk_bvlshr(Context.Context, toZ3Expr(*LHS).AST,
406                                      toZ3Expr(*RHS).AST)));
407   }
408 
409   SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
410     return newExprRef(
411         Z3Expr(Context, Z3_mk_bvxor(Context.Context, toZ3Expr(*LHS).AST,
412                                     toZ3Expr(*RHS).AST)));
413   }
414 
415   SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
416     return newExprRef(
417         Z3Expr(Context, Z3_mk_bvor(Context.Context, toZ3Expr(*LHS).AST,
418                                    toZ3Expr(*RHS).AST)));
419   }
420 
421   SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
422     return newExprRef(
423         Z3Expr(Context, Z3_mk_bvand(Context.Context, toZ3Expr(*LHS).AST,
424                                     toZ3Expr(*RHS).AST)));
425   }
426 
427   SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
428     return newExprRef(
429         Z3Expr(Context, Z3_mk_bvult(Context.Context, toZ3Expr(*LHS).AST,
430                                     toZ3Expr(*RHS).AST)));
431   }
432 
433   SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
434     return newExprRef(
435         Z3Expr(Context, Z3_mk_bvslt(Context.Context, toZ3Expr(*LHS).AST,
436                                     toZ3Expr(*RHS).AST)));
437   }
438 
439   SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
440     return newExprRef(
441         Z3Expr(Context, Z3_mk_bvugt(Context.Context, toZ3Expr(*LHS).AST,
442                                     toZ3Expr(*RHS).AST)));
443   }
444 
445   SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
446     return newExprRef(
447         Z3Expr(Context, Z3_mk_bvsgt(Context.Context, toZ3Expr(*LHS).AST,
448                                     toZ3Expr(*RHS).AST)));
449   }
450 
451   SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
452     return newExprRef(
453         Z3Expr(Context, Z3_mk_bvule(Context.Context, toZ3Expr(*LHS).AST,
454                                     toZ3Expr(*RHS).AST)));
455   }
456 
457   SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
458     return newExprRef(
459         Z3Expr(Context, Z3_mk_bvsle(Context.Context, toZ3Expr(*LHS).AST,
460                                     toZ3Expr(*RHS).AST)));
461   }
462 
463   SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
464     return newExprRef(
465         Z3Expr(Context, Z3_mk_bvuge(Context.Context, toZ3Expr(*LHS).AST,
466                                     toZ3Expr(*RHS).AST)));
467   }
468 
469   SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
470     return newExprRef(
471         Z3Expr(Context, Z3_mk_bvsge(Context.Context, toZ3Expr(*LHS).AST,
472                                     toZ3Expr(*RHS).AST)));
473   }
474 
475   SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
476     Z3_ast Args[2] = {toZ3Expr(*LHS).AST, toZ3Expr(*RHS).AST};
477     return newExprRef(Z3Expr(Context, Z3_mk_and(Context.Context, 2, Args)));
478   }
479 
480   SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
481     Z3_ast Args[2] = {toZ3Expr(*LHS).AST, toZ3Expr(*RHS).AST};
482     return newExprRef(Z3Expr(Context, Z3_mk_or(Context.Context, 2, Args)));
483   }
484 
485   SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
486     return newExprRef(
487         Z3Expr(Context, Z3_mk_eq(Context.Context, toZ3Expr(*LHS).AST,
488                                  toZ3Expr(*RHS).AST)));
489   }
490 
491   SMTExprRef mkFPNeg(const SMTExprRef &Exp) override {
492     return newExprRef(
493         Z3Expr(Context, Z3_mk_fpa_neg(Context.Context, toZ3Expr(*Exp).AST)));
494   }
495 
496   SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) override {
497     return newExprRef(Z3Expr(
498         Context, Z3_mk_fpa_is_infinite(Context.Context, toZ3Expr(*Exp).AST)));
499   }
500 
501   SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) override {
502     return newExprRef(
503         Z3Expr(Context, Z3_mk_fpa_is_nan(Context.Context, toZ3Expr(*Exp).AST)));
504   }
505 
506   SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) override {
507     return newExprRef(Z3Expr(
508         Context, Z3_mk_fpa_is_normal(Context.Context, toZ3Expr(*Exp).AST)));
509   }
510 
511   SMTExprRef mkFPIsZero(const SMTExprRef &Exp) override {
512     return newExprRef(Z3Expr(
513         Context, Z3_mk_fpa_is_zero(Context.Context, toZ3Expr(*Exp).AST)));
514   }
515 
516   SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
517     SMTExprRef RoundingMode = getFloatRoundingMode();
518     return newExprRef(
519         Z3Expr(Context,
520                Z3_mk_fpa_mul(Context.Context, toZ3Expr(*RoundingMode).AST,
521                              toZ3Expr(*LHS).AST, toZ3Expr(*RHS).AST)));
522   }
523 
524   SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
525     SMTExprRef RoundingMode = getFloatRoundingMode();
526     return newExprRef(
527         Z3Expr(Context,
528                Z3_mk_fpa_div(Context.Context, toZ3Expr(*RoundingMode).AST,
529                              toZ3Expr(*LHS).AST, toZ3Expr(*RHS).AST)));
530   }
531 
532   SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
533     return newExprRef(
534         Z3Expr(Context, Z3_mk_fpa_rem(Context.Context, toZ3Expr(*LHS).AST,
535                                       toZ3Expr(*RHS).AST)));
536   }
537 
538   SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
539     SMTExprRef RoundingMode = getFloatRoundingMode();
540     return newExprRef(
541         Z3Expr(Context,
542                Z3_mk_fpa_add(Context.Context, toZ3Expr(*RoundingMode).AST,
543                              toZ3Expr(*LHS).AST, toZ3Expr(*RHS).AST)));
544   }
545 
546   SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
547     SMTExprRef RoundingMode = getFloatRoundingMode();
548     return newExprRef(
549         Z3Expr(Context,
550                Z3_mk_fpa_sub(Context.Context, toZ3Expr(*RoundingMode).AST,
551                              toZ3Expr(*LHS).AST, toZ3Expr(*RHS).AST)));
552   }
553 
554   SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
555     return newExprRef(
556         Z3Expr(Context, Z3_mk_fpa_lt(Context.Context, toZ3Expr(*LHS).AST,
557                                      toZ3Expr(*RHS).AST)));
558   }
559 
560   SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
561     return newExprRef(
562         Z3Expr(Context, Z3_mk_fpa_gt(Context.Context, toZ3Expr(*LHS).AST,
563                                      toZ3Expr(*RHS).AST)));
564   }
565 
566   SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
567     return newExprRef(
568         Z3Expr(Context, Z3_mk_fpa_leq(Context.Context, toZ3Expr(*LHS).AST,
569                                       toZ3Expr(*RHS).AST)));
570   }
571 
572   SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
573     return newExprRef(
574         Z3Expr(Context, Z3_mk_fpa_geq(Context.Context, toZ3Expr(*LHS).AST,
575                                       toZ3Expr(*RHS).AST)));
576   }
577 
578   SMTExprRef mkFPEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
579     return newExprRef(
580         Z3Expr(Context, Z3_mk_fpa_eq(Context.Context, toZ3Expr(*LHS).AST,
581                                      toZ3Expr(*RHS).AST)));
582   }
583 
584   SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
585                    const SMTExprRef &F) override {
586     return newExprRef(
587         Z3Expr(Context, Z3_mk_ite(Context.Context, toZ3Expr(*Cond).AST,
588                                   toZ3Expr(*T).AST, toZ3Expr(*F).AST)));
589   }
590 
591   SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) override {
592     return newExprRef(Z3Expr(
593         Context, Z3_mk_sign_ext(Context.Context, i, toZ3Expr(*Exp).AST)));
594   }
595 
596   SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) override {
597     return newExprRef(Z3Expr(
598         Context, Z3_mk_zero_ext(Context.Context, i, toZ3Expr(*Exp).AST)));
599   }
600 
601   SMTExprRef mkBVExtract(unsigned High, unsigned Low,
602                          const SMTExprRef &Exp) override {
603     return newExprRef(Z3Expr(Context, Z3_mk_extract(Context.Context, High, Low,
604                                                     toZ3Expr(*Exp).AST)));
605   }
606 
607   /// Creates a predicate that checks for overflow in a bitvector addition
608   /// operation
609   SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS, const SMTExprRef &RHS,
610                                bool isSigned) override {
611     return newExprRef(Z3Expr(
612         Context, Z3_mk_bvadd_no_overflow(Context.Context, toZ3Expr(*LHS).AST,
613                                          toZ3Expr(*RHS).AST, isSigned)));
614   }
615 
616   /// Creates a predicate that checks for underflow in a signed bitvector
617   /// addition operation
618   SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS,
619                                 const SMTExprRef &RHS) override {
620     return newExprRef(Z3Expr(
621         Context, Z3_mk_bvadd_no_underflow(Context.Context, toZ3Expr(*LHS).AST,
622                                           toZ3Expr(*RHS).AST)));
623   }
624 
625   /// Creates a predicate that checks for overflow in a signed bitvector
626   /// subtraction operation
627   SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS,
628                                const SMTExprRef &RHS) override {
629     return newExprRef(Z3Expr(
630         Context, Z3_mk_bvsub_no_overflow(Context.Context, toZ3Expr(*LHS).AST,
631                                          toZ3Expr(*RHS).AST)));
632   }
633 
634   /// Creates a predicate that checks for underflow in a bitvector subtraction
635   /// operation
636   SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS, const SMTExprRef &RHS,
637                                 bool isSigned) override {
638     return newExprRef(Z3Expr(
639         Context, Z3_mk_bvsub_no_underflow(Context.Context, toZ3Expr(*LHS).AST,
640                                           toZ3Expr(*RHS).AST, isSigned)));
641   }
642 
643   /// Creates a predicate that checks for overflow in a signed bitvector
644   /// division/modulus operation
645   SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS,
646                                 const SMTExprRef &RHS) override {
647     return newExprRef(Z3Expr(
648         Context, Z3_mk_bvsdiv_no_overflow(Context.Context, toZ3Expr(*LHS).AST,
649                                           toZ3Expr(*RHS).AST)));
650   }
651 
652   /// Creates a predicate that checks for overflow in a bitvector negation
653   /// operation
654   SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) override {
655     return newExprRef(Z3Expr(
656         Context, Z3_mk_bvneg_no_overflow(Context.Context, toZ3Expr(*Exp).AST)));
657   }
658 
659   /// Creates a predicate that checks for overflow in a bitvector multiplication
660   /// operation
661   SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS, const SMTExprRef &RHS,
662                                bool isSigned) override {
663     return newExprRef(Z3Expr(
664         Context, Z3_mk_bvmul_no_overflow(Context.Context, toZ3Expr(*LHS).AST,
665                                          toZ3Expr(*RHS).AST, isSigned)));
666   }
667 
668   /// Creates a predicate that checks for underflow in a signed bitvector
669   /// multiplication operation
670   SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS,
671                                 const SMTExprRef &RHS) override {
672     return newExprRef(Z3Expr(
673         Context, Z3_mk_bvmul_no_underflow(Context.Context, toZ3Expr(*LHS).AST,
674                                           toZ3Expr(*RHS).AST)));
675   }
676 
677   SMTExprRef mkBVConcat(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
678     return newExprRef(
679         Z3Expr(Context, Z3_mk_concat(Context.Context, toZ3Expr(*LHS).AST,
680                                      toZ3Expr(*RHS).AST)));
681   }
682 
683   SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) override {
684     SMTExprRef RoundingMode = getFloatRoundingMode();
685     return newExprRef(Z3Expr(
686         Context,
687         Z3_mk_fpa_to_fp_float(Context.Context, toZ3Expr(*RoundingMode).AST,
688                               toZ3Expr(*From).AST, toZ3Sort(*To).Sort)));
689   }
690 
691   SMTExprRef mkSBVtoFP(const SMTExprRef &From, const SMTSortRef &To) override {
692     SMTExprRef RoundingMode = getFloatRoundingMode();
693     return newExprRef(Z3Expr(
694         Context,
695         Z3_mk_fpa_to_fp_signed(Context.Context, toZ3Expr(*RoundingMode).AST,
696                                toZ3Expr(*From).AST, toZ3Sort(*To).Sort)));
697   }
698 
699   SMTExprRef mkUBVtoFP(const SMTExprRef &From, const SMTSortRef &To) override {
700     SMTExprRef RoundingMode = getFloatRoundingMode();
701     return newExprRef(Z3Expr(
702         Context,
703         Z3_mk_fpa_to_fp_unsigned(Context.Context, toZ3Expr(*RoundingMode).AST,
704                                  toZ3Expr(*From).AST, toZ3Sort(*To).Sort)));
705   }
706 
707   SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) override {
708     SMTExprRef RoundingMode = getFloatRoundingMode();
709     return newExprRef(Z3Expr(
710         Context, Z3_mk_fpa_to_sbv(Context.Context, toZ3Expr(*RoundingMode).AST,
711                                   toZ3Expr(*From).AST, ToWidth)));
712   }
713 
714   SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) override {
715     SMTExprRef RoundingMode = getFloatRoundingMode();
716     return newExprRef(Z3Expr(
717         Context, Z3_mk_fpa_to_ubv(Context.Context, toZ3Expr(*RoundingMode).AST,
718                                   toZ3Expr(*From).AST, ToWidth)));
719   }
720 
721   SMTExprRef mkBoolean(const bool b) override {
722     return newExprRef(Z3Expr(Context, b ? Z3_mk_true(Context.Context)
723                                         : Z3_mk_false(Context.Context)));
724   }
725 
726   SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) override {
727     const Z3_sort Z3Sort = toZ3Sort(*getBitvectorSort(BitWidth)).Sort;
728 
729     // Slow path, when 64 bits are not enough.
730     if (LLVM_UNLIKELY(Int.getBitWidth() > 64u)) {
731       SmallString<40> Buffer;
732       Int.toString(Buffer, 10);
733       return newExprRef(Z3Expr(
734           Context, Z3_mk_numeral(Context.Context, Buffer.c_str(), Z3Sort)));
735     }
736 
737     const int64_t BitReprAsSigned = Int.getExtValue();
738     const uint64_t BitReprAsUnsigned =
739         reinterpret_cast<const uint64_t &>(BitReprAsSigned);
740 
741     Z3_ast Literal =
742         Int.isSigned()
743             ? Z3_mk_int64(Context.Context, BitReprAsSigned, Z3Sort)
744             : Z3_mk_unsigned_int64(Context.Context, BitReprAsUnsigned, Z3Sort);
745     return newExprRef(Z3Expr(Context, Literal));
746   }
747 
748   SMTExprRef mkFloat(const llvm::APFloat Float) override {
749     SMTSortRef Sort =
750         getFloatSort(llvm::APFloat::semanticsSizeInBits(Float.getSemantics()));
751 
752     llvm::APSInt Int = llvm::APSInt(Float.bitcastToAPInt(), false);
753     SMTExprRef Z3Int = mkBitvector(Int, Int.getBitWidth());
754     return newExprRef(Z3Expr(
755         Context, Z3_mk_fpa_to_fp_bv(Context.Context, toZ3Expr(*Z3Int).AST,
756                                     toZ3Sort(*Sort).Sort)));
757   }
758 
759   SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) override {
760     return newExprRef(
761         Z3Expr(Context, Z3_mk_const(Context.Context,
762                                     Z3_mk_string_symbol(Context.Context, Name),
763                                     toZ3Sort(*Sort).Sort)));
764   }
765 
766   llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth,
767                             bool isUnsigned) override {
768     return llvm::APSInt(
769         llvm::APInt(BitWidth,
770                     Z3_get_numeral_string(Context.Context, toZ3Expr(*Exp).AST),
771                     10),
772         isUnsigned);
773   }
774 
775   bool getBoolean(const SMTExprRef &Exp) override {
776     return Z3_get_bool_value(Context.Context, toZ3Expr(*Exp).AST) == Z3_L_TRUE;
777   }
778 
779   SMTExprRef getFloatRoundingMode() override {
780     // TODO: Don't assume nearest ties to even rounding mode
781     return newExprRef(Z3Expr(Context, Z3_mk_fpa_rne(Context.Context)));
782   }
783 
784   bool toAPFloat(const SMTSortRef &Sort, const SMTExprRef &AST,
785                  llvm::APFloat &Float, bool useSemantics) {
786     assert(Sort->isFloatSort() && "Unsupported sort to floating-point!");
787 
788     llvm::APSInt Int(Sort->getFloatSortSize(), true);
789     const llvm::fltSemantics &Semantics =
790         getFloatSemantics(Sort->getFloatSortSize());
791     SMTSortRef BVSort = getBitvectorSort(Sort->getFloatSortSize());
792     if (!toAPSInt(BVSort, AST, Int, true)) {
793       return false;
794     }
795 
796     if (useSemantics && !areEquivalent(Float.getSemantics(), Semantics)) {
797       assert(false && "Floating-point types don't match!");
798       return false;
799     }
800 
801     Float = llvm::APFloat(Semantics, Int);
802     return true;
803   }
804 
805   bool toAPSInt(const SMTSortRef &Sort, const SMTExprRef &AST,
806                 llvm::APSInt &Int, bool useSemantics) {
807     if (Sort->isBitvectorSort()) {
808       if (useSemantics && Int.getBitWidth() != Sort->getBitvectorSortSize()) {
809         assert(false && "Bitvector types don't match!");
810         return false;
811       }
812 
813       // FIXME: This function is also used to retrieve floating-point values,
814       // which can be 16, 32, 64 or 128 bits long. Bitvectors can be anything
815       // between 1 and 64 bits long, which is the reason we have this weird
816       // guard. In the future, we need proper calls in the backend to retrieve
817       // floating-points and its special values (NaN, +/-infinity, +/-zero),
818       // then we can drop this weird condition.
819       if (Sort->getBitvectorSortSize() <= 64 ||
820           Sort->getBitvectorSortSize() == 128) {
821         Int = getBitvector(AST, Int.getBitWidth(), Int.isUnsigned());
822         return true;
823       }
824 
825       assert(false && "Bitwidth not supported!");
826       return false;
827     }
828 
829     if (Sort->isBooleanSort()) {
830       if (useSemantics && Int.getBitWidth() < 1) {
831         assert(false && "Boolean type doesn't match!");
832         return false;
833       }
834 
835       Int = llvm::APSInt(llvm::APInt(Int.getBitWidth(), getBoolean(AST)),
836                          Int.isUnsigned());
837       return true;
838     }
839 
840     llvm_unreachable("Unsupported sort to integer!");
841   }
842 
843   bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) override {
844     Z3Model Model(Context, Z3_solver_get_model(Context.Context, Solver));
845     Z3_func_decl Func = Z3_get_app_decl(
846         Context.Context, Z3_to_app(Context.Context, toZ3Expr(*Exp).AST));
847     if (Z3_model_has_interp(Context.Context, Model.Model, Func) != Z3_L_TRUE)
848       return false;
849 
850     SMTExprRef Assign = newExprRef(
851         Z3Expr(Context,
852                Z3_model_get_const_interp(Context.Context, Model.Model, Func)));
853     SMTSortRef Sort = getSort(Assign);
854     return toAPSInt(Sort, Assign, Int, true);
855   }
856 
857   bool getInterpretation(const SMTExprRef &Exp, llvm::APFloat &Float) override {
858     Z3Model Model(Context, Z3_solver_get_model(Context.Context, Solver));
859     Z3_func_decl Func = Z3_get_app_decl(
860         Context.Context, Z3_to_app(Context.Context, toZ3Expr(*Exp).AST));
861     if (Z3_model_has_interp(Context.Context, Model.Model, Func) != Z3_L_TRUE)
862       return false;
863 
864     SMTExprRef Assign = newExprRef(
865         Z3Expr(Context,
866                Z3_model_get_const_interp(Context.Context, Model.Model, Func)));
867     SMTSortRef Sort = getSort(Assign);
868     return toAPFloat(Sort, Assign, Float, true);
869   }
870 
871   Optional<bool> check() const override {
872     Z3_lbool res = Z3_solver_check(Context.Context, Solver);
873     if (res == Z3_L_TRUE)
874       return true;
875 
876     if (res == Z3_L_FALSE)
877       return false;
878 
879     return Optional<bool>();
880   }
881 
882   void push() override { return Z3_solver_push(Context.Context, Solver); }
883 
884   void pop(unsigned NumStates = 1) override {
885     assert(Z3_solver_get_num_scopes(Context.Context, Solver) >= NumStates);
886     return Z3_solver_pop(Context.Context, Solver, NumStates);
887   }
888 
889   bool isFPSupported() override { return true; }
890 
891   /// Reset the solver and remove all constraints.
892   void reset() override { Z3_solver_reset(Context.Context, Solver); }
893 
894   void print(raw_ostream &OS) const override {
895     OS << Z3_solver_to_string(Context.Context, Solver);
896   }
897 }; // end class Z3Solver
898 
899 } // end anonymous namespace
900 
901 #endif
902 
903 llvm::SMTSolverRef llvm::CreateZ3Solver() {
904 #if LLVM_WITH_Z3
905   return std::make_unique<Z3Solver>();
906 #else
907   llvm::report_fatal_error("LLVM was not compiled with Z3 support, rebuild "
908                            "with -DLLVM_ENABLE_Z3_SOLVER=ON",
909                            false);
910   return nullptr;
911 #endif
912 }
913 
914 LLVM_DUMP_METHOD void SMTSort::dump() const { print(llvm::errs()); }
915 LLVM_DUMP_METHOD void SMTExpr::dump() const { print(llvm::errs()); }
916 LLVM_DUMP_METHOD void SMTSolver::dump() const { print(llvm::errs()); }
917