1 //===- SMTAPI.h -------------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines a SMT generic Solver API, which will be the base class 10 // for every SMT solver specific class. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_SUPPORT_SMTAPI_H 15 #define LLVM_SUPPORT_SMTAPI_H 16 17 #include "llvm/ADT/APFloat.h" 18 #include "llvm/ADT/APSInt.h" 19 #include "llvm/ADT/FoldingSet.h" 20 #include "llvm/Support/raw_ostream.h" 21 #include <memory> 22 23 namespace llvm { 24 25 /// Generic base class for SMT sorts 26 class SMTSort { 27 public: 28 SMTSort() = default; 29 virtual ~SMTSort() = default; 30 31 /// Returns true if the sort is a bitvector, calls isBitvectorSortImpl(). 32 virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); } 33 34 /// Returns true if the sort is a floating-point, calls isFloatSortImpl(). 35 virtual bool isFloatSort() const { return isFloatSortImpl(); } 36 37 /// Returns true if the sort is a boolean, calls isBooleanSortImpl(). 38 virtual bool isBooleanSort() const { return isBooleanSortImpl(); } 39 40 /// Returns the bitvector size, fails if the sort is not a bitvector 41 /// Calls getBitvectorSortSizeImpl(). 42 virtual unsigned getBitvectorSortSize() const { 43 assert(isBitvectorSort() && "Not a bitvector sort!"); 44 unsigned Size = getBitvectorSortSizeImpl(); 45 assert(Size && "Size is zero!"); 46 return Size; 47 }; 48 49 /// Returns the floating-point size, fails if the sort is not a floating-point 50 /// Calls getFloatSortSizeImpl(). 51 virtual unsigned getFloatSortSize() const { 52 assert(isFloatSort() && "Not a floating-point sort!"); 53 unsigned Size = getFloatSortSizeImpl(); 54 assert(Size && "Size is zero!"); 55 return Size; 56 }; 57 58 virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0; 59 60 bool operator<(const SMTSort &Other) const { 61 llvm::FoldingSetNodeID ID1, ID2; 62 Profile(ID1); 63 Other.Profile(ID2); 64 return ID1 < ID2; 65 } 66 67 friend bool operator==(SMTSort const &LHS, SMTSort const &RHS) { 68 return LHS.equal_to(RHS); 69 } 70 71 virtual void print(raw_ostream &OS) const = 0; 72 73 LLVM_DUMP_METHOD void dump() const; 74 75 protected: 76 /// Query the SMT solver and returns true if two sorts are equal (same kind 77 /// and bit width). This does not check if the two sorts are the same objects. 78 virtual bool equal_to(SMTSort const &other) const = 0; 79 80 /// Query the SMT solver and checks if a sort is bitvector. 81 virtual bool isBitvectorSortImpl() const = 0; 82 83 /// Query the SMT solver and checks if a sort is floating-point. 84 virtual bool isFloatSortImpl() const = 0; 85 86 /// Query the SMT solver and checks if a sort is boolean. 87 virtual bool isBooleanSortImpl() const = 0; 88 89 /// Query the SMT solver and returns the sort bit width. 90 virtual unsigned getBitvectorSortSizeImpl() const = 0; 91 92 /// Query the SMT solver and returns the sort bit width. 93 virtual unsigned getFloatSortSizeImpl() const = 0; 94 }; 95 96 /// Shared pointer for SMTSorts, used by SMTSolver API. 97 using SMTSortRef = const SMTSort *; 98 99 /// Generic base class for SMT exprs 100 class SMTExpr { 101 public: 102 SMTExpr() = default; 103 virtual ~SMTExpr() = default; 104 105 bool operator<(const SMTExpr &Other) const { 106 llvm::FoldingSetNodeID ID1, ID2; 107 Profile(ID1); 108 Other.Profile(ID2); 109 return ID1 < ID2; 110 } 111 112 virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0; 113 114 friend bool operator==(SMTExpr const &LHS, SMTExpr const &RHS) { 115 return LHS.equal_to(RHS); 116 } 117 118 virtual void print(raw_ostream &OS) const = 0; 119 120 LLVM_DUMP_METHOD void dump() const; 121 122 protected: 123 /// Query the SMT solver and returns true if two sorts are equal (same kind 124 /// and bit width). This does not check if the two sorts are the same objects. 125 virtual bool equal_to(SMTExpr const &other) const = 0; 126 }; 127 128 /// Shared pointer for SMTExprs, used by SMTSolver API. 129 using SMTExprRef = const SMTExpr *; 130 131 /// Generic base class for SMT Solvers 132 /// 133 /// This class is responsible for wrapping all sorts and expression generation, 134 /// through the mk* methods. It also provides methods to create SMT expressions 135 /// straight from clang's AST, through the from* methods. 136 class SMTSolver { 137 public: 138 SMTSolver() = default; 139 virtual ~SMTSolver() = default; 140 141 LLVM_DUMP_METHOD void dump() const; 142 143 // Returns an appropriate floating-point sort for the given bitwidth. 144 SMTSortRef getFloatSort(unsigned BitWidth) { 145 switch (BitWidth) { 146 case 16: 147 return getFloat16Sort(); 148 case 32: 149 return getFloat32Sort(); 150 case 64: 151 return getFloat64Sort(); 152 case 128: 153 return getFloat128Sort(); 154 default:; 155 } 156 llvm_unreachable("Unsupported floating-point bitwidth!"); 157 } 158 159 // Returns a boolean sort. 160 virtual SMTSortRef getBoolSort() = 0; 161 162 // Returns an appropriate bitvector sort for the given bitwidth. 163 virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0; 164 165 // Returns a floating-point sort of width 16 166 virtual SMTSortRef getFloat16Sort() = 0; 167 168 // Returns a floating-point sort of width 32 169 virtual SMTSortRef getFloat32Sort() = 0; 170 171 // Returns a floating-point sort of width 64 172 virtual SMTSortRef getFloat64Sort() = 0; 173 174 // Returns a floating-point sort of width 128 175 virtual SMTSortRef getFloat128Sort() = 0; 176 177 // Returns an appropriate sort for the given AST. 178 virtual SMTSortRef getSort(const SMTExprRef &AST) = 0; 179 180 /// Given a constraint, adds it to the solver 181 virtual void addConstraint(const SMTExprRef &Exp) const = 0; 182 183 /// Creates a bitvector addition operation 184 virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 185 186 /// Creates a bitvector subtraction operation 187 virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 188 189 /// Creates a bitvector multiplication operation 190 virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 191 192 /// Creates a bitvector signed modulus operation 193 virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 194 195 /// Creates a bitvector unsigned modulus operation 196 virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 197 198 /// Creates a bitvector signed division operation 199 virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 200 201 /// Creates a bitvector unsigned division operation 202 virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 203 204 /// Creates a bitvector logical shift left operation 205 virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 206 207 /// Creates a bitvector arithmetic shift right operation 208 virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 209 210 /// Creates a bitvector logical shift right operation 211 virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 212 213 /// Creates a bitvector negation operation 214 virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0; 215 216 /// Creates a bitvector not operation 217 virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0; 218 219 /// Creates a bitvector xor operation 220 virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 221 222 /// Creates a bitvector or operation 223 virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 224 225 /// Creates a bitvector and operation 226 virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 227 228 /// Creates a bitvector unsigned less-than operation 229 virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 230 231 /// Creates a bitvector signed less-than operation 232 virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 233 234 /// Creates a bitvector unsigned greater-than operation 235 virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 236 237 /// Creates a bitvector signed greater-than operation 238 virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 239 240 /// Creates a bitvector unsigned less-equal-than operation 241 virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 242 243 /// Creates a bitvector signed less-equal-than operation 244 virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 245 246 /// Creates a bitvector unsigned greater-equal-than operation 247 virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 248 249 /// Creates a bitvector signed greater-equal-than operation 250 virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 251 252 /// Creates a boolean not operation 253 virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0; 254 255 /// Creates a boolean equality operation 256 virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 257 258 /// Creates a boolean and operation 259 virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 260 261 /// Creates a boolean or operation 262 virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 263 264 /// Creates a boolean ite operation 265 virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T, 266 const SMTExprRef &F) = 0; 267 268 /// Creates a bitvector sign extension operation 269 virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0; 270 271 /// Creates a bitvector zero extension operation 272 virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0; 273 274 /// Creates a bitvector extract operation 275 virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low, 276 const SMTExprRef &Exp) = 0; 277 278 /// Creates a bitvector concat operation 279 virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS, 280 const SMTExprRef &RHS) = 0; 281 282 /// Creates a predicate that checks for overflow in a bitvector addition 283 /// operation 284 virtual SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS, 285 const SMTExprRef &RHS, 286 bool isSigned) = 0; 287 288 /// Creates a predicate that checks for underflow in a signed bitvector 289 /// addition operation 290 virtual SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS, 291 const SMTExprRef &RHS) = 0; 292 293 /// Creates a predicate that checks for overflow in a signed bitvector 294 /// subtraction operation 295 virtual SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS, 296 const SMTExprRef &RHS) = 0; 297 298 /// Creates a predicate that checks for underflow in a bitvector subtraction 299 /// operation 300 virtual SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS, 301 const SMTExprRef &RHS, 302 bool isSigned) = 0; 303 304 /// Creates a predicate that checks for overflow in a signed bitvector 305 /// division/modulus operation 306 virtual SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS, 307 const SMTExprRef &RHS) = 0; 308 309 /// Creates a predicate that checks for overflow in a bitvector negation 310 /// operation 311 virtual SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) = 0; 312 313 /// Creates a predicate that checks for overflow in a bitvector multiplication 314 /// operation 315 virtual SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS, 316 const SMTExprRef &RHS, 317 bool isSigned) = 0; 318 319 /// Creates a predicate that checks for underflow in a signed bitvector 320 /// multiplication operation 321 virtual SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS, 322 const SMTExprRef &RHS) = 0; 323 324 /// Creates a floating-point negation operation 325 virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0; 326 327 /// Creates a floating-point isInfinite operation 328 virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0; 329 330 /// Creates a floating-point isNaN operation 331 virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0; 332 333 /// Creates a floating-point isNormal operation 334 virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0; 335 336 /// Creates a floating-point isZero operation 337 virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0; 338 339 /// Creates a floating-point multiplication operation 340 virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 341 342 /// Creates a floating-point division operation 343 virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 344 345 /// Creates a floating-point remainder operation 346 virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 347 348 /// Creates a floating-point addition operation 349 virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 350 351 /// Creates a floating-point subtraction operation 352 virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 353 354 /// Creates a floating-point less-than operation 355 virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 356 357 /// Creates a floating-point greater-than operation 358 virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 359 360 /// Creates a floating-point less-than-or-equal operation 361 virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 362 363 /// Creates a floating-point greater-than-or-equal operation 364 virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0; 365 366 /// Creates a floating-point equality operation 367 virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS, 368 const SMTExprRef &RHS) = 0; 369 370 /// Creates a floating-point conversion from floatint-point to floating-point 371 /// operation 372 virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0; 373 374 /// Creates a floating-point conversion from signed bitvector to 375 /// floatint-point operation 376 virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From, 377 const SMTSortRef &To) = 0; 378 379 /// Creates a floating-point conversion from unsigned bitvector to 380 /// floatint-point operation 381 virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From, 382 const SMTSortRef &To) = 0; 383 384 /// Creates a floating-point conversion from floatint-point to signed 385 /// bitvector operation 386 virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) = 0; 387 388 /// Creates a floating-point conversion from floatint-point to unsigned 389 /// bitvector operation 390 virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) = 0; 391 392 /// Creates a new symbol, given a name and a sort 393 virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0; 394 395 // Returns an appropriate floating-point rounding mode. 396 virtual SMTExprRef getFloatRoundingMode() = 0; 397 398 // If the a model is available, returns the value of a given bitvector symbol 399 virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth, 400 bool isUnsigned) = 0; 401 402 // If the a model is available, returns the value of a given boolean symbol 403 virtual bool getBoolean(const SMTExprRef &Exp) = 0; 404 405 /// Constructs an SMTExprRef from a boolean. 406 virtual SMTExprRef mkBoolean(const bool b) = 0; 407 408 /// Constructs an SMTExprRef from a finite APFloat. 409 virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0; 410 411 /// Constructs an SMTExprRef from an APSInt and its bit width 412 virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0; 413 414 /// Given an expression, extract the value of this operand in the model. 415 virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0; 416 417 /// Given an expression extract the value of this operand in the model. 418 virtual bool getInterpretation(const SMTExprRef &Exp, 419 llvm::APFloat &Float) = 0; 420 421 /// Check if the constraints are satisfiable 422 virtual std::optional<bool> check() const = 0; 423 424 /// Push the current solver state 425 virtual void push() = 0; 426 427 /// Pop the previous solver state 428 virtual void pop(unsigned NumStates = 1) = 0; 429 430 /// Reset the solver and remove all constraints. 431 virtual void reset() = 0; 432 433 /// Checks if the solver supports floating-points. 434 virtual bool isFPSupported() = 0; 435 436 virtual void print(raw_ostream &OS) const = 0; 437 }; 438 439 /// Shared pointer for SMTSolvers. 440 using SMTSolverRef = std::shared_ptr<SMTSolver>; 441 442 /// Convenience method to create and Z3Solver object 443 SMTSolverRef CreateZ3Solver(); 444 445 } // namespace llvm 446 447 #endif 448