1 //===- SMTAPI.h -------------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines a SMT generic Solver API, which will be the base class 10 // for every SMT solver specific class. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_SUPPORT_SMTAPI_H 15 #define LLVM_SUPPORT_SMTAPI_H 16 17 #include "llvm/ADT/APFloat.h" 18 #include "llvm/ADT/APSInt.h" 19 #include "llvm/ADT/FoldingSet.h" 20 #include "llvm/Support/Compiler.h" 21 #include "llvm/Support/raw_ostream.h" 22 #include <memory> 23 24 namespace llvm { 25 26 /// Generic base class for SMT sorts 27 class SMTSort { 28 public: 29 SMTSort() = default; 30 virtual ~SMTSort() = default; 31 32 /// Returns true if the sort is a bitvector, calls isBitvectorSortImpl(). isBitvectorSort()33 virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); } 34 35 /// Returns true if the sort is a floating-point, calls isFloatSortImpl(). isFloatSort()36 virtual bool isFloatSort() const { return isFloatSortImpl(); } 37 38 /// Returns true if the sort is a boolean, calls isBooleanSortImpl(). isBooleanSort()39 virtual bool isBooleanSort() const { return isBooleanSortImpl(); } 40 41 /// Returns the bitvector size, fails if the sort is not a bitvector 42 /// Calls getBitvectorSortSizeImpl(). getBitvectorSortSize()43 virtual unsigned getBitvectorSortSize() const { 44 assert(isBitvectorSort() && "Not a bitvector sort!"); 45 unsigned Size = getBitvectorSortSizeImpl(); 46 assert(Size && "Size is zero!"); 47 return Size; 48 }; 49 50 /// Returns the floating-point size, fails if the sort is not a floating-point 51 /// Calls getFloatSortSizeImpl(). getFloatSortSize()52 virtual unsigned getFloatSortSize() const { 53 assert(isFloatSort() && "Not a floating-point sort!"); 54 unsigned Size = getFloatSortSizeImpl(); 55 assert(Size && "Size is zero!"); 56 return Size; 57 }; 58 59 virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0; 60 61 bool operator<(const SMTSort &Other) const { 62 llvm::FoldingSetNodeID ID1, ID2; 63 Profile(ID1); 64 Other.Profile(ID2); 65 return ID1 < ID2; 66 } 67 68 friend bool operator==(SMTSort const &LHS, SMTSort const &RHS) { 69 return LHS.equal_to(RHS); 70 } 71 72 virtual void print(raw_ostream &OS) const = 0; 73 74 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 75 LLVM_DUMP_METHOD void dump() const; 76 #endif 77 78 protected: 79 /// Query the SMT solver and returns true if two sorts are equal (same kind 80 /// and bit width). This does not check if the two sorts are the same objects. 81 virtual bool equal_to(SMTSort const &other) const = 0; 82 83 /// Query the SMT solver and checks if a sort is bitvector. 84 virtual bool isBitvectorSortImpl() const = 0; 85 86 /// Query the SMT solver and checks if a sort is floating-point. 87 virtual bool isFloatSortImpl() const = 0; 88 89 /// Query the SMT solver and checks if a sort is boolean. 90 virtual bool isBooleanSortImpl() const = 0; 91 92 /// Query the SMT solver and returns the sort bit width. 93 virtual unsigned getBitvectorSortSizeImpl() const = 0; 94 95 /// Query the SMT solver and returns the sort bit width. 96 virtual unsigned getFloatSortSizeImpl() const = 0; 97 }; 98 99 /// Shared pointer for SMTSorts, used by SMTSolver API. 100 using SMTSortRef = const SMTSort *; 101 102 /// Generic base class for SMT exprs 103 class SMTExpr { 104 public: 105 SMTExpr() = default; 106 virtual ~SMTExpr() = default; 107 108 bool operator<(const SMTExpr &Other) const { 109 llvm::FoldingSetNodeID ID1, ID2; 110 Profile(ID1); 111 Other.Profile(ID2); 112 return ID1 < ID2; 113 } 114 115 virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0; 116 117 friend bool operator==(SMTExpr const &LHS, SMTExpr const &RHS) { 118 return LHS.equal_to(RHS); 119 } 120 121 virtual void print(raw_ostream &OS) const = 0; 122 123 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 124 LLVM_DUMP_METHOD void dump() const; 125 #endif 126 127 protected: 128 /// Query the SMT solver and returns true if two sorts are equal (same kind 129 /// and bit width). This does not check if the two sorts are the same objects. 130 virtual bool equal_to(SMTExpr const &other) const = 0; 131 }; 132 133 class SMTSolverStatistics { 134 public: 135 SMTSolverStatistics() = default; 136 virtual ~SMTSolverStatistics() = default; 137 138 virtual double getDouble(llvm::StringRef) const = 0; 139 virtual unsigned getUnsigned(llvm::StringRef) const = 0; 140 141 virtual void print(raw_ostream &OS) const = 0; 142 143 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 144 LLVM_DUMP_METHOD void dump() const; 145 #endif 146 }; 147 148 /// Shared pointer for SMTExprs, used by SMTSolver API. 149 using SMTExprRef = const SMTExpr *; 150 151 /// Generic base class for SMT Solvers 152 /// 153 /// This class is responsible for wrapping all sorts and expression generation, 154 /// through the mk* methods. It also provides methods to create SMT expressions 155 /// straight from clang's AST, through the from* methods. 156 class SMTSolver { 157 public: 158 SMTSolver() = default; 159 virtual ~SMTSolver() = default; 160 161 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 162 LLVM_DUMP_METHOD void dump() const; 163 #endif 164 165 // Returns an appropriate floating-point sort for the given bitwidth. getFloatSort(unsigned BitWidth)166 SMTSortRef getFloatSort(unsigned BitWidth) { 167 switch (BitWidth) { 168 case 16: 169 return getFloat16Sort(); 170 case 32: 171 return getFloat32Sort(); 172 case 64: 173 return getFloat64Sort(); 174 case 128: 175 return getFloat128Sort(); 176 default:; 177 } 178 llvm_unreachable("Unsupported floating-point bitwidth!"); 179 } 180 181 // Returns a boolean sort. 182 virtual SMTSortRef getBoolSort() = 0; 183 184 // Returns an appropriate bitvector sort for the given bitwidth. 185 virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0; 186 187 // Returns a floating-point sort of width 16 188 virtual SMTSortRef getFloat16Sort() = 0; 189 190 // Returns a floating-point sort of width 32 191 virtual SMTSortRef getFloat32Sort() = 0; 192 193 // Returns a floating-point sort of width 64 194 virtual SMTSortRef getFloat64Sort() = 0; 195 196 // Returns a floating-point sort of width 128 197 virtual SMTSortRef getFloat128Sort() = 0; 198 199 // Returns an appropriate sort for the given AST. 200 virtual SMTSortRef getSort(const SMTExprRef &AST) = 0; 201 202 /// Given a constraint, adds it to the solver 203 virtual void addConstraint(const SMTExprRef &Exp) const = 0; 204 205 /// Creates a bitvector addition operation 206 virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 207 208 /// Creates a bitvector subtraction operation 209 virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 210 211 /// Creates a bitvector multiplication operation 212 virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 213 214 /// Creates a bitvector signed modulus operation 215 virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 216 217 /// Creates a bitvector unsigned modulus operation 218 virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 219 220 /// Creates a bitvector signed division operation 221 virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 222 223 /// Creates a bitvector unsigned division operation 224 virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 225 226 /// Creates a bitvector logical shift left operation 227 virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 228 229 /// Creates a bitvector arithmetic shift right operation 230 virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 231 232 /// Creates a bitvector logical shift right operation 233 virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 234 235 /// Creates a bitvector negation operation 236 virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0; 237 238 /// Creates a bitvector not operation 239 virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0; 240 241 /// Creates a bitvector xor operation 242 virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 243 244 /// Creates a bitvector or operation 245 virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 246 247 /// Creates a bitvector and operation 248 virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 249 250 /// Creates a bitvector unsigned less-than operation 251 virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 252 253 /// Creates a bitvector signed less-than operation 254 virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 255 256 /// Creates a bitvector unsigned greater-than operation 257 virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 258 259 /// Creates a bitvector signed greater-than operation 260 virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 261 262 /// Creates a bitvector unsigned less-equal-than operation 263 virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 264 265 /// Creates a bitvector signed less-equal-than operation 266 virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 267 268 /// Creates a bitvector unsigned greater-equal-than operation 269 virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 270 271 /// Creates a bitvector signed greater-equal-than operation 272 virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 273 274 /// Creates a boolean not operation 275 virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0; 276 277 /// Creates a boolean equality operation 278 virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 279 280 /// Creates a boolean and operation 281 virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 282 283 /// Creates a boolean or operation 284 virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 285 286 /// Creates a boolean ite operation 287 virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T, 288 const SMTExprRef &F) = 0; 289 290 /// Creates a bitvector sign extension operation 291 virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0; 292 293 /// Creates a bitvector zero extension operation 294 virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0; 295 296 /// Creates a bitvector extract operation 297 virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low, 298 const SMTExprRef &Exp) = 0; 299 300 /// Creates a bitvector concat operation 301 virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS, 302 const SMTExprRef &RHS) = 0; 303 304 /// Creates a predicate that checks for overflow in a bitvector addition 305 /// operation 306 virtual SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS, 307 const SMTExprRef &RHS, 308 bool isSigned) = 0; 309 310 /// Creates a predicate that checks for underflow in a signed bitvector 311 /// addition operation 312 virtual SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS, 313 const SMTExprRef &RHS) = 0; 314 315 /// Creates a predicate that checks for overflow in a signed bitvector 316 /// subtraction operation 317 virtual SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS, 318 const SMTExprRef &RHS) = 0; 319 320 /// Creates a predicate that checks for underflow in a bitvector subtraction 321 /// operation 322 virtual SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS, 323 const SMTExprRef &RHS, 324 bool isSigned) = 0; 325 326 /// Creates a predicate that checks for overflow in a signed bitvector 327 /// division/modulus operation 328 virtual SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS, 329 const SMTExprRef &RHS) = 0; 330 331 /// Creates a predicate that checks for overflow in a bitvector negation 332 /// operation 333 virtual SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) = 0; 334 335 /// Creates a predicate that checks for overflow in a bitvector multiplication 336 /// operation 337 virtual SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS, 338 const SMTExprRef &RHS, 339 bool isSigned) = 0; 340 341 /// Creates a predicate that checks for underflow in a signed bitvector 342 /// multiplication operation 343 virtual SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS, 344 const SMTExprRef &RHS) = 0; 345 346 /// Creates a floating-point negation operation 347 virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0; 348 349 /// Creates a floating-point isInfinite operation 350 virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0; 351 352 /// Creates a floating-point isNaN operation 353 virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0; 354 355 /// Creates a floating-point isNormal operation 356 virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0; 357 358 /// Creates a floating-point isZero operation 359 virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0; 360 361 /// Creates a floating-point multiplication operation 362 virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 363 364 /// Creates a floating-point division operation 365 virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 366 367 /// Creates a floating-point remainder operation 368 virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 369 370 /// Creates a floating-point addition operation 371 virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 372 373 /// Creates a floating-point subtraction operation 374 virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 375 376 /// Creates a floating-point less-than operation 377 virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 378 379 /// Creates a floating-point greater-than operation 380 virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 381 382 /// Creates a floating-point less-than-or-equal operation 383 virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 384 385 /// Creates a floating-point greater-than-or-equal operation 386 virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 387 388 /// Creates a floating-point equality operation 389 virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS, 390 const SMTExprRef &RHS) = 0; 391 392 /// Creates a floating-point conversion from floatint-point to floating-point 393 /// operation 394 virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0; 395 396 /// Creates a floating-point conversion from signed bitvector to 397 /// floatint-point operation 398 virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From, 399 const SMTSortRef &To) = 0; 400 401 /// Creates a floating-point conversion from unsigned bitvector to 402 /// floatint-point operation 403 virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From, 404 const SMTSortRef &To) = 0; 405 406 /// Creates a floating-point conversion from floatint-point to signed 407 /// bitvector operation 408 virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) = 0; 409 410 /// Creates a floating-point conversion from floatint-point to unsigned 411 /// bitvector operation 412 virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) = 0; 413 414 /// Creates a new symbol, given a name and a sort 415 virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0; 416 417 // Returns an appropriate floating-point rounding mode. 418 virtual SMTExprRef getFloatRoundingMode() = 0; 419 420 // If the a model is available, returns the value of a given bitvector symbol 421 virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth, 422 bool isUnsigned) = 0; 423 424 // If the a model is available, returns the value of a given boolean symbol 425 virtual bool getBoolean(const SMTExprRef &Exp) = 0; 426 427 /// Constructs an SMTExprRef from a boolean. 428 virtual SMTExprRef mkBoolean(const bool b) = 0; 429 430 /// Constructs an SMTExprRef from a finite APFloat. 431 virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0; 432 433 /// Constructs an SMTExprRef from an APSInt and its bit width 434 virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0; 435 436 /// Given an expression, extract the value of this operand in the model. 437 virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0; 438 439 /// Given an expression extract the value of this operand in the model. 440 virtual bool getInterpretation(const SMTExprRef &Exp, 441 llvm::APFloat &Float) = 0; 442 443 /// Check if the constraints are satisfiable 444 virtual std::optional<bool> check() const = 0; 445 446 /// Push the current solver state 447 virtual void push() = 0; 448 449 /// Pop the previous solver state 450 virtual void pop(unsigned NumStates = 1) = 0; 451 452 /// Reset the solver and remove all constraints. 453 virtual void reset() = 0; 454 455 /// Checks if the solver supports floating-points. 456 virtual bool isFPSupported() = 0; 457 458 virtual void print(raw_ostream &OS) const = 0; 459 460 /// Sets the requested option. 461 virtual void setBoolParam(StringRef Key, bool Value) = 0; 462 virtual void setUnsignedParam(StringRef Key, unsigned Value) = 0; 463 464 virtual std::unique_ptr<SMTSolverStatistics> getStatistics() const = 0; 465 }; 466 467 /// Shared pointer for SMTSolvers. 468 using SMTSolverRef = std::shared_ptr<SMTSolver>; 469 470 /// Convenience method to create and Z3Solver object 471 LLVM_ABI SMTSolverRef CreateZ3Solver(); 472 473 } // namespace llvm 474 475 #endif 476