From 96cca1bd96703106189d8759aeaeb9c9de365bda Mon Sep 17 00:00:00 2001 From: makaimann Date: Wed, 30 May 2018 14:54:41 -0700 Subject: [PATCH 1/5] Working on Floating Point. --- smt_switch/src/api.py | 31 ++++++++++++++++++++--- smt_switch/src/functions.py | 20 +++++++++++++-- smt_switch/src/solvers/CVC4Solver.py | 38 ++++++++++++++++++++++++---- smt_switch/src/sorts.py | 20 +++++++++++++++ 4 files changed, 99 insertions(+), 10 deletions(-) diff --git a/smt_switch/src/api.py b/smt_switch/src/api.py index c5730a2..81ceddd 100644 --- a/smt_switch/src/api.py +++ b/smt_switch/src/api.py @@ -1,7 +1,8 @@ # This file is part of the smt-switch project. # See the file LICENSE in the top-level source directory for licensing information. -from collections import Sequence +from collections import Sequence, namedtuple + from . import sorts from . import functions from . import terms @@ -126,8 +127,9 @@ def DeclareConst(self, name, sort): return self.__term_map[self.solver.__class__](self, sconst) - def TheoryConst(self, sort, value): - stconst = self.solver.TheoryConst(sort, value) + def TheoryConst(self, sort, *values): + values = [v.solver_term if hasattr(v, 'solver_term') else v for v in values] + stconst = self.solver.TheoryConst(sort, *values) return self.__term_map[self.solver.__class__](self, stconst) @@ -230,3 +232,26 @@ def Push(self): def Pop(self): self.solver.Pop() + + @property + def Round(self): + ''' + Returns a namedtuple containing integers encoding the type of Floating Point Rounding + + Intended for use with a solver supporting floating point queries. + ''' + return Round + + +# duplicate fenv.h values +# make available through Round +fenv = namedtuple("fenv", "FE_TONEAREST FE_DOWNWARD FE_UPWARD FE_TOWARDZERO RNE RTN RTP RTZ RNA") +FE_TONEAREST = 0 +FE_DOWNWARD = 0x400 +FE_UPWARD = 0x800 +FE_TOWARDZERO = 0xc00 + +Round = fenv(FE_TONEAREST, FE_DOWNWARD, FE_UPWARD, FE_TOWARDZERO, + FE_TONEAREST, FE_DOWNWARD, FE_UPWARD, FE_TOWARDZERO, + (((~FE_TONEAREST) & 0x1) | ((~FE_UPWARD) & 0x2) | + ((~FE_DOWNWARD) & 0x4) | ((~FE_TOWARDZERO) & 0x8))) diff --git a/smt_switch/src/functions.py b/smt_switch/src/functions.py index 98c9f93..04eae86 100644 --- a/smt_switch/src/functions.py +++ b/smt_switch/src/functions.py @@ -85,8 +85,24 @@ def _Or(*args): ('Store', fdata(0, 3, 3)), ('No_op', fdata(0, 0, 0)), ('_ApplyUF', fdata(0, 1, sys.maxsize)), - ('Distinct', fdata(0, 2, sys.maxsize))]) - + ('Distinct', fdata(0, 2, sys.maxsize)), + ('FPEq', fdata(0, 2, 2)), + ('FPAbs', fdata(0, 1, 1)), + ('FPAdd', fdata(0, 2, 2)), + ('FPSub', fdata(0, 2, 2)), + ('FPMul', fdata(0, 2, 2)), + ('FPDiv', fdata(0, 2, 2)), + ('FPRem', fdata(0, 2, 2)), + ('FPFma', fdata(0, 4, 4)), + ('FPSqrt', fdata(0, 2, 2)), + ('FPRti', fdata(0, 2, 2)), + ('FPMin', fdata(0, 2, 2)), + ('FPMax', fdata(0, 2, 2)), + ('FPLt', fdata(0, 2, 2)), + ('FPLe', fdata(0, 2, 2)), + ('FPGt', fdata(0, 2, 2)), + ('FPGe', fdata(0, 2, 2)), + ('FPNeg', fdata(0, 1, 1))]) # generate enums for each of these function symbols func_d = dict() diff --git a/smt_switch/src/solvers/CVC4Solver.py b/smt_switch/src/solvers/CVC4Solver.py index 13591ac..0d428f3 100644 --- a/smt_switch/src/solvers/CVC4Solver.py +++ b/smt_switch/src/solvers/CVC4Solver.py @@ -33,7 +33,8 @@ def __init__(self, strict): sorts.Int: self._em.integerType, sorts.Real: self._em.realType, sorts.Bool: self._em.booleanType, - sorts.Array: self._em.mkArrayType} + sorts.Array: self._em.mkArrayType, + sorts.FP: self._em.mkFloatingPointType} # def create_array_sort(idxsort, dsort): # # get parameterized sorts @@ -85,7 +86,24 @@ def __init__(self, strict): func_enum._ApplyUF: self.CVC4.APPLY_UF, func_enum.Select: self.CVC4.SELECT, func_enum.Store: self.CVC4.STORE, - func_enum.Distinct: self.CVC4.DISTINCT + func_enum.Distinct: self.CVC4.DISTINCT, + func_enum.FPEq: self.CVC4.FLOATINGPOINT_EQ, + func_enum.FPAbs: self.CVC4.FLOATINGPOINT_ABS, + func_enum.FPNeg: self.CVC4.FLOATINGPOINT_NEG, + func_enum.FPAdd: self.CVC4.FLOATINGPOINT_PLUS, + func_enum.FPSub: self.CVC4.FLOATINGPOINT_SUB, + func_enum.FPMul: self.CVC4.FLOATINGPOINT_MULT, + func_enum.FPDiv: self.CVC4.FLOATINGPOINT_DIV, + func_enum.FPRem: self.CVC4.FLOATINGPOINT_REM, + func_enum.FPFma: self.CVC4.FLOATINGPOINT_FMA, + func_enum.FPSqrt: self.CVC4.FLOATINGPOINT_SQRT, + func_enum.FPRti: self.CVC4.FLOATINGPOINT_RTI, + func_enum.FPMin: self.CVC4.FLOATINGPOINT_MIN, + func_enum.FPMax: self.CVC4.FLOATINGPOINT_MAX, + func_enum.FPLe: self.CVC4.FLOATINGPOINT_LEQ, + func_enum.FPLt: self.CVC4.FLOATINGPOINT_LT, + func_enum.FPGe: self.CVC4.FLOATINGPOINT_GEQ, + func_enum.FPGt: self.CVC4.FLOATINGPOINT_GT }) # all constants are No_op @@ -114,10 +132,20 @@ def create_real(value): def create_bool(value): return self._em.mkBoolConst(value) + def create_fp(expbits, sigbits, *args): + if len(args) == 3: + return self._em.mkExpr(self.CVC4.FLOATINGPOINT_FP, *args) + elif len(args) == 1: + assert isinstance(value, Fraction) + return self._em.mkExpr(self.CVC4.FLOATINGPOINT_FP, self.CVC4.Rational(value.numerator, value.denominator)) + else: + raise UnimplementedError("Don't have support for other FP instantiation techniques") + self._CVC4Consts = {sorts.BitVec: create_bv, sorts.Int: create_int, sorts.Real: create_real, - sorts.Bool: create_bool} + sorts.Bool: create_bool, + sorts.FP: create_fp} def Reset(self): self._smt.reset() @@ -155,8 +183,8 @@ def DeclareConst(self, name, sort): cvc4const = self._em.mkVar(name, cvc4sort) return cvc4const - def TheoryConst(self, sort, value): - cvc4tconst = self._CVC4Consts[sort.__class__](*(sort.params + (value,))) + def TheoryConst(self, sort, *values): + cvc4tconst = self._CVC4Consts[sort.__class__](*(sort.params + values)) return cvc4tconst def ApplyFun(self, f_enum, indices, *args): diff --git a/smt_switch/src/sorts.py b/smt_switch/src/sorts.py index 5c0418d..284e0a0 100644 --- a/smt_switch/src/sorts.py +++ b/smt_switch/src/sorts.py @@ -85,3 +85,23 @@ def dsort(self): @property def params(self): return (self._idxsort, self._dsort) + + +class FP(SortBase): + ''' Floating Point sort ''' + def __init__(self, expbits, sigbits): + super().__init__('(FP {} {})'.format(expbits, sigbits)) + self._expbits = expbits + self._sigbits = sigbits + + @property + def expbits(self): + return self._expbits + + @property + def sigbits(self): + return self._sigbits + + @property + def params(self): + return (self._expbits, self._sigbits) From cb5953ee5952664bcdfa20d8967f482023c5a10b Mon Sep 17 00:00:00 2001 From: makaimann Date: Thu, 31 May 2018 14:14:58 -0700 Subject: [PATCH 2/5] FP working with strict setting --- smt_switch/src/api.py | 36 ++++++++++---------------- smt_switch/src/solvers/CVC4Solver.py | 36 ++++++++++++++++++++++++-- smt_switch/src/sorts.py | 4 +-- smt_switch/src/terms.py | 38 +++++++++++++++++++++++++--- smt_switch/tests/__init__.py | 2 ++ smt_switch/tests/test_fp_strict.py | 32 +++++++++++++++++++++++ 6 files changed, 118 insertions(+), 30 deletions(-) create mode 100644 smt_switch/tests/test_fp_strict.py diff --git a/smt_switch/src/api.py b/smt_switch/src/api.py index 81ceddd..48edf49 100644 --- a/smt_switch/src/api.py +++ b/smt_switch/src/api.py @@ -68,6 +68,19 @@ def __init__(self, solver_name, strict=False): self._strict = strict + # create attributes and namedtuples for special solver constants + # these are used for particular tasks, such as rounding in Floating + # Point solving but do not fit into the general term structure + if hasattr(self._solver, '_special_consts'): + for k, v in self._solver._special_consts.items(): + assert isinstance(v, dict), 'Expecting special consts to be >' + + wrapped_consts = {s: terms.WrapperTerm(self, t) for s, t in v.items()} + NT = namedtuple(k, wrapped_consts.keys())(**wrapped_consts) + assert not hasattr(self, k), "Special Const name {} is already an api function".format(k) + + setattr(self, k, NT) + def ConstructFun(self, fun, *args): # partial function evaluation all handled internally return fun(*args) @@ -232,26 +245,3 @@ def Push(self): def Pop(self): self.solver.Pop() - - @property - def Round(self): - ''' - Returns a namedtuple containing integers encoding the type of Floating Point Rounding - - Intended for use with a solver supporting floating point queries. - ''' - return Round - - -# duplicate fenv.h values -# make available through Round -fenv = namedtuple("fenv", "FE_TONEAREST FE_DOWNWARD FE_UPWARD FE_TOWARDZERO RNE RTN RTP RTZ RNA") -FE_TONEAREST = 0 -FE_DOWNWARD = 0x400 -FE_UPWARD = 0x800 -FE_TOWARDZERO = 0xc00 - -Round = fenv(FE_TONEAREST, FE_DOWNWARD, FE_UPWARD, FE_TOWARDZERO, - FE_TONEAREST, FE_DOWNWARD, FE_UPWARD, FE_TOWARDZERO, - (((~FE_TONEAREST) & 0x1) | ((~FE_UPWARD) & 0x2) | - ((~FE_DOWNWARD) & 0x4) | ((~FE_TOWARDZERO) & 0x8))) diff --git a/smt_switch/src/solvers/CVC4Solver.py b/smt_switch/src/solvers/CVC4Solver.py index 0d428f3..5a2bae1 100644 --- a/smt_switch/src/solvers/CVC4Solver.py +++ b/smt_switch/src/solvers/CVC4Solver.py @@ -6,7 +6,7 @@ from .solverbase import SolverBase from fractions import Fraction from smt_switch.util import reversabledict -from collections import Sequence +from collections import Sequence, namedtuple import os @@ -115,7 +115,30 @@ def __init__(self, strict): # Note: losing info about op of applied function # TODO: see if can extract function definition self.CVC4.APPLY: func_enum.No_op, - self.CVC4.BITVECTOR_EXTRACT: func_enum.Extract} + self.CVC4.BITVECTOR_EXTRACT: func_enum.Extract, + self.CVC4.FLOATINGPOINT_FP: func_enum.No_op, + self.CVC4.NULL_EXPR: func_enum.No_op} + + # special constants for floating point solver + # duplicate fenv.h values + # make available through self._round + FE_TONEAREST = 0 + FE_DOWNWARD = 0x400 + FE_UPWARD = 0x800 + FE_TOWARDZERO = 0xc00 + + self._round = { + 'RNE': self._em.mkConst(FE_TONEAREST), + 'RTN': self._em.mkConst(FE_DOWNWARD), + 'RTP': self._em.mkConst(FE_UPWARD), + 'RTZ': self._em.mkConst(FE_TOWARDZERO), + 'RNA': self._em.mkConst(((~FE_TONEAREST) & 0x1) | ((~FE_UPWARD) & 0x2) | + ((~FE_DOWNWARD) & 0x4) | ((~FE_TOWARDZERO) & 0x8)) + } + + # The api creates an attribute for each entry in this dictionary, + # and creates a namedtuple out of each value + self._special_consts = {'Round': self._round} # Theory constant functions def create_bv(width, value): @@ -261,3 +284,12 @@ def Push(self): def Pop(self): self._smt.pop() + + @property + def Round(self): + ''' + Returns a namedtuple containing integers encoding the type of Floating Point Rounding + + Intended for use with a solver supporting floating point queries. + ''' + return self._round diff --git a/smt_switch/src/sorts.py b/smt_switch/src/sorts.py index 284e0a0..9d13445 100644 --- a/smt_switch/src/sorts.py +++ b/smt_switch/src/sorts.py @@ -5,7 +5,7 @@ import inspect -__all__ = ['BitVec', 'Int', 'Real', 'Bool', 'Array'] +__all__ = ['BitVec', 'Int', 'Real', 'Bool', 'Array', 'FP'] class SortBase(metaclass=ABCMeta): @abstractmethod @@ -90,7 +90,7 @@ def params(self): class FP(SortBase): ''' Floating Point sort ''' def __init__(self, expbits, sigbits): - super().__init__('(FP {} {})'.format(expbits, sigbits)) + super().__init__('(FP {} {})'.format(expbits, sigbits), []) self._expbits = expbits self._sigbits = sigbits diff --git a/smt_switch/src/terms.py b/smt_switch/src/terms.py index d622b79..a75bbb1 100644 --- a/smt_switch/src/terms.py +++ b/smt_switch/src/terms.py @@ -205,12 +205,21 @@ def __init__(self, smt, solver_term): 'bitvec': lambda p: sorts.BitVec(p), 'bool': lambda p: sorts.Bool(), 'boolean': lambda p: sorts.Bool(), - 'array': lambda ids, ds: sorts.Array(ids, ds) + 'array': lambda ids, ds: sorts.Array(ids, ds), + 'floatingpoint': lambda exp, sig: sorts.FP(exp, sig) } - p = re.compile('\(?(_ )?(?Pint|real|bitvector|bitvec|bool|array)\s?\(?(?P\d+)?\)?') + p = re.compile('\(?(_ )?(?Pfloatingpoint|int|real|bitvector|bitvec|bool|array)\s?\(?(?P\d+)?\)?') + + if solver_term.toString() != 'null': + cvc4sortstr = solver_term.getType().toString().lower() + else: + # HACK + # special-casing for FloatingPoint ops + children = solver_term.getChildren() + assert len(children) > 0, "Expecting FP Op Node with children" + cvc4sortstr = solver_term.getChildren()[0].getType().toString().lower() - cvc4sortstr = solver_term.getType().toString().lower() match = p.search(cvc4sortstr) if not match: @@ -229,6 +238,12 @@ def __init__(self, smt, solver_term): dsort = self._str2sort[dmatch.group('sort')](dmatch.group('param')) params = (idxsort, dsort) + elif 'floatingpoint' in cvc4sortstr: + # regex not quite right for floatingpoint + # TODO: Fix regex without breaking other sorts + sig, exp = cvc4sortstr[cvc4sortstr.find('floatingpoint')+len('floatingpoint '):].replace(")", "").split() + params = (int(sig), int(exp)) + elif 'bitvec' in match.group('sort'): assert match.group('param'), 'BitVecs must have a width' params = (int(match.group('param')),) @@ -434,5 +449,22 @@ def children(self): raise NotImplementedError('Boolector does not support querying children.') +class WrapperTerm: + ''' + Holds an arbitrary solver object. Used for special solver constants that can't be combined into arbitrary expressions or just don't fit into the general term structure for some reason + ''' + def __init__(self, smt, solver_term): + self._smt = smt + self._solver_term = solver_term + + @property + def smt(self): + return self._smt + + @property + def solver_term(self): + return self._solver_term + + def __bool_fun(*args): return sorts.Bool() diff --git a/smt_switch/tests/__init__.py b/smt_switch/tests/__init__.py index 1d4a569..15210a3 100644 --- a/smt_switch/tests/__init__.py +++ b/smt_switch/tests/__init__.py @@ -3,3 +3,5 @@ all_logic_solvers = {'Z3', 'CVC4'} bv_solvers = {'Z3', 'CVC4', 'Boolector'} + +fp_solvers = {'CVC4'} # haven't added support for fp in other solvers diff --git a/smt_switch/tests/test_fp_strict.py b/smt_switch/tests/test_fp_strict.py new file mode 100644 index 0000000..7e06c8b --- /dev/null +++ b/smt_switch/tests/test_fp_strict.py @@ -0,0 +1,32 @@ +import pytest +from smt_switch import smt +from smt_switch.tests import fp_solvers + +def test_basic(): + ''' + Very basic floating point test + ''' + + for name in fp_solvers: + s = smt(name, strict=True) + s.SetLogic("QF_FP") + + bvsort1 = s.ConstructSort(s.BitVec, 1) + bvsort8 = s.ConstructSort(s.BitVec, 8) + bvsort24 = s.ConstructSort(s.BitVec, 24) + + fpsort8_24 = s.ConstructSort(s.FP, 8, 24) + + b0 = s.TheoryConst(bvsort1, 0) + b200 = s.TheoryConst(bvsort8, 200) + b468 = s.TheoryConst(bvsort24, 468) + + f = s.TheoryConst(fpsort8_24, b0, b200, b468) + f2 = s.TheoryConst(fpsort8_24, b0, b200, b468) + + fc = s.DeclareConst("fc", fpsort8_24) + + fpf2 = s.ApplyFun(s.FPAdd, s.Round.RNE, f, f2) + +if __name__ == "__main__": + test_basic() From 010695c1a6c093b0d031b6082f7ac1fdec8a7d05 Mon Sep 17 00:00:00 2001 From: makaimann Date: Tue, 5 Jun 2018 12:00:35 -0700 Subject: [PATCH 3/5] Finish up floating point --- smt_switch/src/api.py | 7 ++-- smt_switch/src/solvers/BoolectorSolver.py | 6 ++- smt_switch/src/solvers/CVC4Solver.py | 15 ++++--- smt_switch/src/sorts.py | 21 +++++++++- smt_switch/src/terms.py | 48 +++++++++++++++++------ smt_switch/tests/test_fp_relaxed.py | 43 ++++++++++++++++++++ smt_switch/tests/test_fp_strict.py | 20 +++++++++- 7 files changed, 136 insertions(+), 24 deletions(-) create mode 100644 smt_switch/tests/test_fp_relaxed.py diff --git a/smt_switch/src/api.py b/smt_switch/src/api.py index 48edf49..81d6b96 100644 --- a/smt_switch/src/api.py +++ b/smt_switch/src/api.py @@ -1,7 +1,8 @@ # This file is part of the smt-switch project. # See the file LICENSE in the top-level source directory for licensing information. -from collections import Sequence, namedtuple +from collections import Sequence +from types import SimpleNamespace from . import sorts from . import functions @@ -68,7 +69,7 @@ def __init__(self, solver_name, strict=False): self._strict = strict - # create attributes and namedtuples for special solver constants + # create attributes and SimpleNamespaces for special solver constants # these are used for particular tasks, such as rounding in Floating # Point solving but do not fit into the general term structure if hasattr(self._solver, '_special_consts'): @@ -76,7 +77,7 @@ def __init__(self, solver_name, strict=False): assert isinstance(v, dict), 'Expecting special consts to be >' wrapped_consts = {s: terms.WrapperTerm(self, t) for s, t in v.items()} - NT = namedtuple(k, wrapped_consts.keys())(**wrapped_consts) + NT = SimpleNamespace(**wrapped_consts) assert not hasattr(self, k), "Special Const name {} is already an api function".format(k) setattr(self, k, NT) diff --git a/smt_switch/src/solvers/BoolectorSolver.py b/smt_switch/src/solvers/BoolectorSolver.py index cc61455..01fe959 100644 --- a/smt_switch/src/solvers/BoolectorSolver.py +++ b/smt_switch/src/solvers/BoolectorSolver.py @@ -64,7 +64,9 @@ def __init__(self, strict): # sorts.Bool: results.BoolectorBitVecResult} self._BoolectorOptions = {'produce-models': self.boolector.BTOR_OPT_MODEL_GEN, 'random-seed': self.boolector.BTOR_OPT_SEED, - 'incremental': self.boolector.BTOR_OPT_INCREMENTAL} + 'incremental': self.boolector.BTOR_OPT_INCREMENTAL, + 'fun:preprop': self.boolector.BTOR_OPT_FUN_PREPROP, + 'prop:nprops': self.boolector.BTOR_OPT_PROP_NPROPS} # am I missing any? self._BoolectorLogics = ['QF_BV', 'QF_ABV', 'QF_UFBV', 'QF_AUFBV'] @@ -88,6 +90,8 @@ def SetLogic(self, logicstr): def SetOption(self, optionstr, value): if optionstr in self._BoolectorOptions: self._btor.Set_opt(self._BoolectorOptions[optionstr], bool(value)) + else: + raise RuntimeError("Unrecognized option: {}".format(optionstr)) def DeclareFun(self, name, inputsorts, outputsort): assert isinstance(inputsorts, Sequence), \ diff --git a/smt_switch/src/solvers/CVC4Solver.py b/smt_switch/src/solvers/CVC4Solver.py index 5a2bae1..4f133a6 100644 --- a/smt_switch/src/solvers/CVC4Solver.py +++ b/smt_switch/src/solvers/CVC4Solver.py @@ -128,14 +128,17 @@ def __init__(self, strict): FE_TOWARDZERO = 0xc00 self._round = { - 'RNE': self._em.mkConst(FE_TONEAREST), - 'RTN': self._em.mkConst(FE_DOWNWARD), - 'RTP': self._em.mkConst(FE_UPWARD), - 'RTZ': self._em.mkConst(FE_TOWARDZERO), - 'RNA': self._em.mkConst(((~FE_TONEAREST) & 0x1) | ((~FE_UPWARD) & 0x2) | - ((~FE_DOWNWARD) & 0x4) | ((~FE_TOWARDZERO) & 0x8)) + 'RNE': self._em.mkVar(self._em.roundingModeType(), FE_TONEAREST), + 'RTN': self._em.mkVar(self._em.roundingModeType(), FE_DOWNWARD), + 'RTP': self._em.mkVar(self._em.roundingModeType(), FE_UPWARD), + 'RTZ': self._em.mkVar(self._em.roundingModeType(), FE_TOWARDZERO), + 'RNA': self._em.mkVar(self._em.roundingModeType(), ((~FE_TONEAREST) & 0x1) | ((~FE_UPWARD) & 0x2) | + ((~FE_DOWNWARD) & 0x4) | ((~FE_TOWARDZERO) & 0x8)) } + # set the default + self._round['default'] = self._round['RNE'] + # The api creates an attribute for each entry in this dictionary, # and creates a namedtuple out of each value self._special_consts = {'Round': self._round} diff --git a/smt_switch/src/sorts.py b/smt_switch/src/sorts.py index 9d13445..cb12aeb 100644 --- a/smt_switch/src/sorts.py +++ b/smt_switch/src/sorts.py @@ -5,7 +5,7 @@ import inspect -__all__ = ['BitVec', 'Int', 'Real', 'Bool', 'Array', 'FP'] +__all__ = ['BitVec', 'Int', 'Real', 'Bool', 'Array', 'FP', '_RoundingMode'] class SortBase(metaclass=ABCMeta): @abstractmethod @@ -105,3 +105,22 @@ def sigbits(self): @property def params(self): return (self._expbits, self._sigbits) + + def __eq__(self, other): + # Need None parameters to match with anything + return isinstance(other, type(self)) and \ + ((self.params[0] == other.params[0]) or (None in (self.params[0], other.params[0]))) and \ + ((self.params[1] == other.params[1]) or (None in (self.params[1], other.params[1]))) + + def __ne__(self, other): + # Need None parameters to match with anything + return not self.__eq__(other) + + +class _RoundingMode(SortBase): + ''' + RoundingMode for FloatingPoint + Should never need to instantiate this as a user + ''' + def __init__(self): + super().__init__('RoundingMode', []) diff --git a/smt_switch/src/terms.py b/smt_switch/src/terms.py index a75bbb1..2c326ca 100644 --- a/smt_switch/src/terms.py +++ b/smt_switch/src/terms.py @@ -43,6 +43,8 @@ def __ne__(self, other): def __add__(self, other): if self.sort.__class__ == sorts.BitVec: return self._smt.ApplyFun(self._smt.BVAdd, self, other) + elif self.sort.__class__ == sorts.FP: + return self._smt.ApplyFun(self._smt.FPAdd, self._smt.Round.default, self, other) else: return self._smt.ApplyFun(self._smt.Add, self, other) @@ -53,6 +55,8 @@ def __sub__(self, other): # override for bitvectors if self.sort.__class__ == sorts.BitVec: return self._smt.ApplyFun(self._smt.BVSub, self, other) + elif self.sort.__class__ == sorts.FP: + return self._smt.ApplyFun(self._smt.FPSub, self._smt.Round.default, self, other) else: return self._smt.ApplyFun(self._smt.Sub, self, other) @@ -66,6 +70,8 @@ def __rsub__(self, other): def __neg__(self): if self.sort.__class__ == sorts.BitVec: return self._smt.ApplyFun(self._smt.BVNeg, self) + elif self.sort.__class__ == sorts.FP: + return self._smt.ApplyFun(self._smt.FPNeg, self._smt.Round.default, self) else: zero = self._smt.TheoryConst(self.sort, 0) return self._smt.ApplyFun(self._smt.Sub, zero, self) @@ -73,6 +79,8 @@ def __neg__(self): def __mul__(self, other): if self.sort.__class__ == sorts.BitVec: return self._smt.ApplyFun(self._smt.BVMul, self, other) + elif self.sort.__class__ == sorts.FP: + return self._smt.ApplyFun(self._smt.FPMul, self._smt.Round.default, self, other) else: raise NotImplementedError("Haven't added nonlinear arithmetic operators yet.") @@ -88,6 +96,8 @@ def __mod__(self, other): def __truediv__(self, other): if self.sort.__class__ == sorts.BitVec: return self._smt.ApplyFun(self._smt.BVUdiv, self, other) + elif self.sort.__class__ == sorts.FP: + return self._smt.ApplyFun(self._smt.FPDiv, self._smt.Round.default, self, other) else: raise NotImplementedError("Haven't added nonlinear arithmetic operators yet.") @@ -102,6 +112,8 @@ def __lt__(self, other): "Operator expects 2 arguments of same sort" if self.sort.__class__ == sorts.BitVec: return self._smt.ApplyFun(self._smt.BVSlt, self, other) + elif self.sort.__class__ == sorts.FP: + return self._smt.ApplyFun(self._smt.FPLt, self, other) return self._smt.ApplyFun(self._smt.LT, self, other) @@ -110,6 +122,8 @@ def __le__(self, other): "Operator expects 2 arguments of same sort" if self.sort.__class__ == sorts.BitVec: return self._smt.ApplyFun(self._smt.BVSle, self, other) + elif self.sort.__class__ == sorts.FP: + return self._smt.ApplyFun(self._smt.FPLe, self, other) return self._smt.ApplyFun(self._smt.LEQ, self, other) @@ -118,6 +132,8 @@ def __gt__(self, other): "Operator expects 2 arguments of same sort" if self.sort.__class__ == sorts.BitVec: return self._smt.ApplyFun(self._smt.BVSgt, self, other) + elif self.sort.__class__ == sorts.FP: + return self._smt.ApplyFun(self._smt.FPGt, self, other) return self._smt.ApplyFun(self._smt.GT, self, other) @@ -126,6 +142,8 @@ def __ge__(self, other): "Operator expects 2 arguments of same sort" if self.sort.__class__ == sorts.BitVec: return self._smt.ApplyFun(self._smt.BVSge, self, other) + elif self.sort.__class__ == sorts.FP: + return self._smt.ApplyFun(self._smt.FPGe, self._smt.Round.default, self, other) return self._smt.ApplyFun(self._smt.GEQ, self, other) @@ -206,19 +224,13 @@ def __init__(self, smt, solver_term): 'bool': lambda p: sorts.Bool(), 'boolean': lambda p: sorts.Bool(), 'array': lambda ids, ds: sorts.Array(ids, ds), - 'floatingpoint': lambda exp, sig: sorts.FP(exp, sig) + 'floatingpoint': lambda exp, sig: sorts.FP(exp, sig), + 'roundingmode': lambda p: sorts._RoundingMode() } - p = re.compile('\(?(_ )?(?Pfloatingpoint|int|real|bitvector|bitvec|bool|array)\s?\(?(?P\d+)?\)?') + p = re.compile('\(?(_ )?(?Pfloatingpoint|int|real|bitvector|bitvec|bool|array|roundingmode)\s?\(?(?P\d+)?\)?') - if solver_term.toString() != 'null': - cvc4sortstr = solver_term.getType().toString().lower() - else: - # HACK - # special-casing for FloatingPoint ops - children = solver_term.getChildren() - assert len(children) > 0, "Expecting FP Op Node with children" - cvc4sortstr = solver_term.getChildren()[0].getType().toString().lower() + cvc4sortstr = solver_term.getType().toString().lower() match = p.search(cvc4sortstr) @@ -242,7 +254,15 @@ def __init__(self, smt, solver_term): # regex not quite right for floatingpoint # TODO: Fix regex without breaking other sorts sig, exp = cvc4sortstr[cvc4sortstr.find('floatingpoint')+len('floatingpoint '):].replace(")", "").split() - params = (int(sig), int(exp)) + sig, exp = int(sig), int(exp) + if sig == -1: + assert exp == -1 + + # this is the result of an operation -- unknown parameters + # don't give them values + sig, exp = None, None + + params = (sig, exp) elif 'bitvec' in match.group('sort'): assert match.group('param'), 'BitVecs must have a width' @@ -465,6 +485,12 @@ def smt(self): def solver_term(self): return self._solver_term + def __eq__(self, other): + return self.solver_term == other.solver_term + + def __ne__(self, other): + return self.solver_term != other.solver_term + def __bool_fun(*args): return sorts.Bool() diff --git a/smt_switch/tests/test_fp_relaxed.py b/smt_switch/tests/test_fp_relaxed.py new file mode 100644 index 0000000..28478eb --- /dev/null +++ b/smt_switch/tests/test_fp_relaxed.py @@ -0,0 +1,43 @@ +import pytest +from smt_switch import smt +from smt_switch.tests import fp_solvers + +def test_basic(): + ''' + Very basic floating point test + ''' + + for name in fp_solvers: + s = smt(name, strict=False) + s.SetLogic("QF_FP") + + b0 = s.TheoryConst(s.BitVec(1), 0) + b200 = s.TheoryConst(s.BitVec(8), 200) + b468 = s.TheoryConst(s.BitVec(24), 468) + + b152 = s.TheoryConst(s.BitVec(8), 152) + b42 = s.TheoryConst(s.BitVec(24), 42) + + # TODO: Handle python integers correctly + f = s.TheoryConst(s.FP(8, 24), b0, b200, b468) + f2 = s.TheoryConst(s.FP(8, 24), b0, b152, b42) + + fc = s.DeclareConst("fc", s.FP(8, 25)) + fc2 = s.DeclareConst("fc2", s.FP(8, 25)) + + # default rounding mode is RNE + # used for syntactic sugar like f + f2 + assert s.Round.default == s.Round.RNE + + # can change default with + # s.Round.default = s.Round. + + s.Assert(s.FPGt(fc, f + f2)) + s.Assert(s.FPLt(fc2, f - f2)) + + s.Assert(s.FPLe(fc, fc2)) + + assert not s.CheckSat() + +if __name__ == "__main__": + test_basic() diff --git a/smt_switch/tests/test_fp_strict.py b/smt_switch/tests/test_fp_strict.py index 7e06c8b..699057d 100644 --- a/smt_switch/tests/test_fp_strict.py +++ b/smt_switch/tests/test_fp_strict.py @@ -16,17 +16,33 @@ def test_basic(): bvsort24 = s.ConstructSort(s.BitVec, 24) fpsort8_24 = s.ConstructSort(s.FP, 8, 24) + fpsort8_25 = s.ConstructSort(s.FP, 8, 25) b0 = s.TheoryConst(bvsort1, 0) b200 = s.TheoryConst(bvsort8, 200) b468 = s.TheoryConst(bvsort24, 468) + b152 = s.TheoryConst(s.BitVec(8), 152) + b42 = s.TheoryConst(s.BitVec(24), 42) + f = s.TheoryConst(fpsort8_24, b0, b200, b468) - f2 = s.TheoryConst(fpsort8_24, b0, b200, b468) + f2 = s.TheoryConst(fpsort8_24, b0, b152, b42) - fc = s.DeclareConst("fc", fpsort8_24) + fc = s.DeclareConst("fc", fpsort8_25) + fc2 = s.DeclareConst("fc2", fpsort8_25) fpf2 = s.ApplyFun(s.FPAdd, s.Round.RNE, f, f2) + fmf2 = s.ApplyFun(s.FPSub, s.Round.RNE, f, f2) + + fcgtfpf2 = s.ApplyFun(s.FPGt, fc, fpf2) + fc2ltfmf2 = s.ApplyFun(s.FPLt, fc2, fmf2) + + s.Assert(fcgtfpf2) + s.Assert(fc2ltfmf2) + + lt = s.ApplyFun(s.FPLe, fc, fc2) + s.Assert(lt) + assert not s.CheckSat() if __name__ == "__main__": test_basic() From 65cbec3e8662a555f37b7bc0bb9b4a8711f08651 Mon Sep 17 00:00:00 2001 From: makaimann Date: Tue, 5 Jun 2018 12:09:51 -0700 Subject: [PATCH 4/5] Add clarification on FP sort --- smt_switch/tests/test_fp_relaxed.py | 5 +++-- smt_switch/tests/test_fp_strict.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/smt_switch/tests/test_fp_relaxed.py b/smt_switch/tests/test_fp_relaxed.py index 28478eb..487b0b4 100644 --- a/smt_switch/tests/test_fp_relaxed.py +++ b/smt_switch/tests/test_fp_relaxed.py @@ -19,8 +19,9 @@ def test_basic(): b42 = s.TheoryConst(s.BitVec(24), 42) # TODO: Handle python integers correctly - f = s.TheoryConst(s.FP(8, 24), b0, b200, b468) - f2 = s.TheoryConst(s.FP(8, 24), b0, b152, b42) + # Note: There's a "hidden" bit which makes it 25 instead of 24 + f = s.TheoryConst(s.FP(8, 25), b0, b200, b468) + f2 = s.TheoryConst(s.FP(8, 25), b0, b152, b42) fc = s.DeclareConst("fc", s.FP(8, 25)) fc2 = s.DeclareConst("fc2", s.FP(8, 25)) diff --git a/smt_switch/tests/test_fp_strict.py b/smt_switch/tests/test_fp_strict.py index 699057d..cc47b2c 100644 --- a/smt_switch/tests/test_fp_strict.py +++ b/smt_switch/tests/test_fp_strict.py @@ -15,7 +15,7 @@ def test_basic(): bvsort8 = s.ConstructSort(s.BitVec, 8) bvsort24 = s.ConstructSort(s.BitVec, 24) - fpsort8_24 = s.ConstructSort(s.FP, 8, 24) + # Note: There's a "hidden" bit which makes it 25 instead of 24 fpsort8_25 = s.ConstructSort(s.FP, 8, 25) b0 = s.TheoryConst(bvsort1, 0) @@ -25,8 +25,8 @@ def test_basic(): b152 = s.TheoryConst(s.BitVec(8), 152) b42 = s.TheoryConst(s.BitVec(24), 42) - f = s.TheoryConst(fpsort8_24, b0, b200, b468) - f2 = s.TheoryConst(fpsort8_24, b0, b152, b42) + f = s.TheoryConst(fpsort8_25, b0, b200, b468) + f2 = s.TheoryConst(fpsort8_25, b0, b152, b42) fc = s.DeclareConst("fc", fpsort8_25) fc2 = s.DeclareConst("fc2", fpsort8_25) From 6e562b17ea9f33bba1bec9982deb8ea2a840bd34 Mon Sep 17 00:00:00 2001 From: makaimann Date: Tue, 5 Jun 2018 12:20:52 -0700 Subject: [PATCH 5/5] Get solvers with FP support --- util/get_solver_binaries.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/util/get_solver_binaries.sh b/util/get_solver_binaries.sh index de131c3..394eec0 100755 --- a/util/get_solver_binaries.sh +++ b/util/get_solver_binaries.sh @@ -15,8 +15,8 @@ do echo "Missing at least $solver_dir and possibly more" echo "Retrieving solvers" rm -rf ./smt_solvers - wget http://web.stanford.edu/~makaim/files/smt_solvers.tar.gz - tar -xzvf ./smt_solvers.tar.gz + wget http://web.stanford.edu/~makaim/files/smt_solvers_fp.tar.gz + tar -xzvf ./smt_solvers_fp.tar.gz all_logic_solvers=false break fi