diff --git a/examples/grisette-examples.cabal b/examples/grisette-examples.cabal index bad6f13f1..886f6a2df 100644 --- a/examples/grisette-examples.cabal +++ b/examples/grisette-examples.cabal @@ -1,11 +1,11 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.38.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. -- -- see: https://github.com/sol/hpack name: grisette-examples -version: 0.13.0.0 +version: 0.13.0.1 synopsis: Examples for Grisette description: More examples are available in the [tutorials](https://github.com/lsrcz/grisette/tree/main/tutorials) of diff --git a/grisette.cabal b/grisette.cabal index 5ce70c779..fe74e6404 100644 --- a/grisette.cabal +++ b/grisette.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.37.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. -- -- see: https://github.com/sol/hpack @@ -116,6 +116,7 @@ library Grisette.Internal.Core.Data.UnionBase Grisette.Internal.SymPrim.AlgReal Grisette.Internal.SymPrim.AllSyms + Grisette.Internal.SymPrim.Array Grisette.Internal.SymPrim.BV Grisette.Internal.SymPrim.FP Grisette.Internal.SymPrim.FunInstanceGen @@ -147,6 +148,7 @@ library Grisette.Internal.SymPrim.Quantifier Grisette.Internal.SymPrim.SomeBV Grisette.Internal.SymPrim.SymAlgReal + Grisette.Internal.SymPrim.SymArray Grisette.Internal.SymPrim.SymBool Grisette.Internal.SymPrim.SymBV Grisette.Internal.SymPrim.SymFP @@ -329,7 +331,7 @@ library , mtl >=2.2.2 && <2.4 , parallel >=3.2.2 && <3.3 , prettyprinter >=1.5.0 && <1.8 - , sbv >=8.17 && <13 + , sbv >=13.4 , stm ==2.5.* , template-haskell >=2.16 && <2.24 , text >=1.2.4.1 && <2.2 @@ -378,7 +380,7 @@ test-suite doctest , mtl >=2.2.2 && <2.4 , parallel >=3.2.2 && <3.3 , prettyprinter >=1.5.0 && <1.8 - , sbv >=8.17 && <13 + , sbv >=13.4 , stm ==2.5.* , template-haskell >=2.16 && <2.24 , text >=1.2.4.1 && <2.2 @@ -499,7 +501,7 @@ test-suite spec , mtl >=2.2.2 && <2.4 , parallel >=3.2.2 && <3.3 , prettyprinter >=1.5.0 && <1.8 - , sbv >=8.17 && <13 + , sbv >=13.4 , stm ==2.5.* , template-haskell >=2.16 && <2.24 , test-framework >=0.8.2 && <0.9 diff --git a/package.yaml b/package.yaml index f3f65e630..ea9d6e98d 100644 --- a/package.yaml +++ b/package.yaml @@ -36,7 +36,7 @@ dependencies: - th-compat >= 0.1.2 && < 0.2 - th-abstraction >= 0.4 && < 0.8 - array >= 0.5.4 && < 0.6 - - sbv >= 8.17 && < 13 + - sbv >= 13.4 - parallel >= 3.2.2 && < 3.3 - text >= 1.2.4.1 && < 2.2 - QuickCheck >= 2.14 && < 2.17 diff --git a/src/Grisette/Internal/Backend/Solving.hs b/src/Grisette/Internal/Backend/Solving.hs index 8916a55d9..2d1db9b48 100644 --- a/src/Grisette/Internal/Backend/Solving.hs +++ b/src/Grisette/Internal/Backend/Solving.hs @@ -264,6 +264,9 @@ import Grisette.Internal.SymPrim.Prim.Term pattern SymTerm, pattern ToFPTerm, pattern XorBitsTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, ) import Grisette.Internal.SymPrim.SymBool (SymBool (SymBool)) @@ -796,6 +799,18 @@ lowerSinglePrimCached t' m' = do mode <- goCached qs mode arg <- goCached qs arg return $ \qst -> sbvToFPTerm @b (mode qst) (arg qst) + goCachedIntermediate qs (SelectTerm (arr :: Term arr) key) = withPrim @arr $ do + arr' <- goCached qs arr + key' <- goCached qs key + pure $ \qst -> SBV.readArray (arr' qst) (key' qst) + goCachedIntermediate qs (StoreTerm arr key val) = withPrim @a $ do + arr' <- goCached qs arr + key' <- goCached qs key + val' <- goCached qs val + pure $ \qst -> SBV.writeArray (arr' qst) (key' qst) (val' qst) + goCachedIntermediate qs (ConstArrayTerm _ val) = withPrim @a $ do + val' <- goCached qs val + pure $ \qst -> SBV.constArray $ val' qst goCachedIntermediate _ ConTerm {} = error "Should not happen" goCachedIntermediate _ SymTerm {} = error "Should not happen" goCachedIntermediate _ ForallTerm {} = error "Should not happen" diff --git a/src/Grisette/Internal/Core/Data/Class/ITEOp.hs b/src/Grisette/Internal/Core/Data/Class/ITEOp.hs index 51842ebc7..48cb37019 100644 --- a/src/Grisette/Internal/Core/Data/Class/ITEOp.hs +++ b/src/Grisette/Internal/Core/Data/Class/ITEOp.hs @@ -33,10 +33,13 @@ import Grisette.Internal.SymPrim.GeneralFun import Grisette.Internal.SymPrim.Prim.SomeTerm (SomeTerm (SomeTerm)) import Grisette.Internal.SymPrim.Prim.Term ( SupportedPrim (pevalITETerm), + SupportedNonFuncPrim, + LinkedRep, TypedConstantSymbol, symTerm, ) import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal)) +import Grisette.Internal.SymPrim.SymArray (SymArray (SymArray)) import Grisette.Internal.SymPrim.SymBV ( SymIntN (SymIntN), SymWordN (SymWordN), @@ -93,6 +96,15 @@ ITEOP_FUN((=->), (=~>), SymTabularFun) ITEOP_FUN((-->), (-~>), SymGeneralFun) #endif +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + ITEOp (SymArray sk sv) where + symIte (SymBool c) (SymArray t) (SymArray f) = SymArray $ pevalITETerm c t f + instance ITEOp (a --> b) where symIte (SymBool c) diff --git a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/EvalSym.hs b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/EvalSym.hs index 8d92bb4ee..0c9eab183 100644 --- a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/EvalSym.hs +++ b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/EvalSym.hs @@ -75,6 +75,7 @@ import Grisette.Internal.Internal.Decl.Core.Data.Class.EvalSym evalSym1, ) import Grisette.Internal.SymPrim.AlgReal (AlgReal) +import Grisette.Internal.SymPrim.Array (Array) import Grisette.Internal.SymPrim.BV (IntN, WordN) import Grisette.Internal.SymPrim.FP ( FP, @@ -86,9 +87,12 @@ import Grisette.Internal.SymPrim.GeneralFun (type (-->) (GeneralFun)) import Grisette.Internal.SymPrim.Prim.Model (evalTerm) import Grisette.Internal.SymPrim.Prim.Term ( SymRep (SymType), + SupportedNonFuncPrim, + LinkedRep, someTypedSymbol, ) import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal)) +import Grisette.Internal.SymPrim.SymArray (SymArray (SymArray)) import Grisette.Internal.SymPrim.SymBV ( SymIntN (SymIntN), SymWordN (SymWordN), @@ -137,6 +141,7 @@ CONCRETE_EVALUATESYM(Ordering) CONCRETE_EVALUATESYM_BV(IntN) CONCRETE_EVALUATESYM_BV(WordN) CONCRETE_EVALUATESYM(AlgReal) +CONCRETE_EVALUATESYM((Array k v)) #endif instance EvalSym (Proxy a) where @@ -186,6 +191,15 @@ instance (ValidFP eb sb) => EvalSym (SymFP eb sb) where evalSym fillDefault model (SymFP t) = SymFP $ evalTerm fillDefault model HS.empty t +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + EvalSym (SymArray sk sv) where + evalSym fill model (SymArray t) = SymArray $ evalTerm fill model HS.empty t + derive [ ''(), ''AssertionError, diff --git a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/Mergeable.hs b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/Mergeable.hs index 63c26b5c0..6937aef4c 100644 --- a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/Mergeable.hs +++ b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/Mergeable.hs @@ -90,6 +90,7 @@ import Grisette.Internal.Internal.Decl.Core.Data.Class.Mergeable wrapStrategy, ) import Grisette.Internal.SymPrim.AlgReal (AlgReal, AlgRealPoly, RealPoint) +import Grisette.Internal.SymPrim.Array (Array) import Grisette.Internal.SymPrim.BV ( IntN, WordN, @@ -103,6 +104,7 @@ import Grisette.Internal.SymPrim.FP ) import Grisette.Internal.SymPrim.GeneralFun (type (-->)) import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal) +import Grisette.Internal.SymPrim.SymArray (SymArray) import Grisette.Internal.SymPrim.SymBV (SymIntN, SymWordN) import Grisette.Internal.SymPrim.SymFP (SymFP, SymFPRoundingMode) import Grisette.Internal.SymPrim.SymGeneralFun (type (-~>)) @@ -111,6 +113,7 @@ import Grisette.Internal.SymPrim.SymTabularFun (type (=~>)) import Grisette.Internal.SymPrim.TabularFun (type (=->)) import Grisette.Internal.TH.Derivation.Derive (derive) import Unsafe.Coerce (unsafeCoerce) +import Grisette.Internal.SymPrim.Prim.Internal.Term (SupportedNonFuncPrim, LinkedRep) #define CONCRETE_ORD_MERGEABLE(type) \ instance Mergeable type where \ @@ -175,6 +178,9 @@ instance Mergeable (a =-> b) where instance Mergeable (a --> b) where rootStrategy = SimpleStrategy symIte +instance Mergeable (Array k v) where + rootStrategy = NoStrategy + #define MERGEABLE_SIMPLE(symtype) \ instance Mergeable symtype where \ rootStrategy = SimpleStrategy symIte @@ -197,6 +203,15 @@ MERGEABLE_FUN((=->), (=~>), SymTabularFun) MERGEABLE_FUN((-->), (-~>), SymGeneralFun) #endif +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + Mergeable (SymArray sk sv) where + rootStrategy = SimpleStrategy $ symIte + instance (ValidFP eb sb) => Mergeable (SymFP eb sb) where rootStrategy = SimpleStrategy symIte diff --git a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SimpleMergeable.hs b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SimpleMergeable.hs index 2cc3b47dc..d3d8247b1 100644 --- a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SimpleMergeable.hs +++ b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SimpleMergeable.hs @@ -78,10 +78,11 @@ import Grisette.Internal.Internal.Decl.Core.Data.Class.SimpleMergeable import Grisette.Internal.Internal.Impl.Core.Data.Class.TryMerge () import Grisette.Internal.SymPrim.FP (ValidFP) import Grisette.Internal.SymPrim.GeneralFun (freshArgSymbol, substTerm, type (-->) (GeneralFun)) -import Grisette.Internal.SymPrim.Prim.Internal.Term (SupportedPrim (pevalITETerm), symTerm) +import Grisette.Internal.SymPrim.Prim.Internal.Term (SupportedPrim (pevalITETerm), LinkedRep, symTerm) import Grisette.Internal.SymPrim.Prim.SomeTerm (SomeTerm (SomeTerm)) -import Grisette.Internal.SymPrim.Prim.Term (TypedConstantSymbol) +import Grisette.Internal.SymPrim.Prim.Term (TypedConstantSymbol, SupportedNonFuncPrim) import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal)) +import Grisette.Internal.SymPrim.SymArray (SymArray (SymArray)) import Grisette.Internal.SymPrim.SymBV ( SymIntN (SymIntN), SymWordN (SymWordN), @@ -559,6 +560,15 @@ SIMPLE_MERGEABLE_FUN((=->), (=~>), SymTabularFun) SIMPLE_MERGEABLE_FUN((-->), (-~>), SymGeneralFun) #endif +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + SimpleMergeable (SymArray sk sv) where + mrgIte (SymBool c) (SymArray t) (SymArray f) = SymArray $ pevalITETerm c t f + instance SimpleMergeable (a --> b) where mrgIte (SymBool c) diff --git a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymEq.hs b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymEq.hs index b8fe9c838..44f2be6f5 100644 --- a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymEq.hs +++ b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymEq.hs @@ -80,8 +80,9 @@ import Grisette.Internal.SymPrim.FP ) import Grisette.Internal.SymPrim.Prim.Term ( SupportedPrim (pevalDistinctTerm), + LinkedRep (underlyingTerm, wrapTerm), + SupportedNonFuncPrim, pevalEqTerm, - underlyingTerm, ) import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal)) import Grisette.Internal.SymPrim.SymBV @@ -95,6 +96,7 @@ import Grisette.Internal.SymPrim.SymFP ) import Grisette.Internal.SymPrim.SymInteger (SymInteger (SymInteger)) import Grisette.Internal.TH.Derivation.Derive (derive) +import Grisette.Internal.SymPrim.SymArray (SymArray) #define CONCRETE_SEQ(type) \ instance SymEq type where \ @@ -185,6 +187,15 @@ instance (ValidFP eb sb) => SymEq (SymFP eb sb) where (SymFP l) .== (SymFP r) = SymBool $ pevalEqTerm l r {-# INLINE (.==) #-} +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + SymEq (SymArray sk sv) where + lhs .== rhs = wrapTerm $ pevalEqTerm (underlyingTerm lhs) (underlyingTerm rhs) + derive [ ''(), ''AssertionError, diff --git a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymOrd.hs b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymOrd.hs index 9a6cb8500..2c8f95973 100644 --- a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymOrd.hs +++ b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymOrd.hs @@ -100,10 +100,7 @@ import Grisette.Internal.SymPrim.SymBV SymWordN (SymWordN), ) import Grisette.Internal.SymPrim.SymBool (SymBool (SymBool)) -import Grisette.Internal.SymPrim.SymFP - ( SymFP (SymFP), - SymFPRoundingMode (SymFPRoundingMode), - ) +import Grisette.Internal.SymPrim.SymFP (SymFP (SymFP)) import Grisette.Internal.SymPrim.SymInteger (SymInteger (SymInteger)) import Grisette.Internal.TH.Derivation.Derive (derive) @@ -254,7 +251,6 @@ instance SymOrd SymBool where #if 1 SORD_SIMPLE(SymInteger) SORD_SIMPLE(SymAlgReal) -SORD_SIMPLE(SymFPRoundingMode) SORD_BV(SymIntN) SORD_BV(SymWordN) #endif diff --git a/src/Grisette/Internal/SymPrim/Array.hs b/src/Grisette/Internal/SymPrim/Array.hs new file mode 100644 index 000000000..ae6797f6f --- /dev/null +++ b/src/Grisette/Internal/SymPrim/Array.hs @@ -0,0 +1,73 @@ +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveLift #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE ExplicitForAll #-} +{-# LANGUAGE ImportQualifiedPost #-} + +-- | +-- Module : Grisette.Internal.SymPrim.Array +-- Copyright : (c) Sirui Lu 2021-2023 +-- License : BSD-3-Clause (see the LICENSE file) +-- +-- Maintainer : siruilu@cs.washington.edu +-- Stability : Experimental +-- Portability : GHC only +module Grisette.Internal.SymPrim.Array + ( Array (..) + , const + , select + , store + ) where + +import Control.DeepSeq (NFData) +import Data.Binary qualified as Binary +import Data.Bytes.Serial (Serial (serialize, deserialize)) +import Data.Hashable (Hashable) +import Data.HashMap.Strict qualified as HM +import Data.Serialize qualified as Cereal +import GHC.Generics (Generic) +import Language.Haskell.TH.Syntax (Lift) +import Prelude (Show, Eq, Ord) + +-- TODO: The equality of this array model is incorrect. The easy solution is +-- to disallow it entirely. Alternatively, I already have a version with a +-- working equality check. It works by canonicalising the array. +-- +-- Canonicalisation will not happen for keys with an infinite domain and +-- realistically also not for keys with a sufficiently large domain. In fact, +-- we avoid tracking information for canonicalisation in these cases altogether! +-- The main gripe with this is that at that point is that insertions do require +-- keys for which we know both their cardinality and can enumerate their domain. +-- The latter we could restrict to only enumerable domains given a finite +-- cardinality with type-level shenanigans, but still. +-- +-- Yet another alternative would be to simply accept that we cannot conclude +-- inequality if one of the arrays would require canonicalisation? Then we don't +-- need the additional typeclass constraints. This way, we could still perform +-- normalisation of most terms. +data Array k v = Array (HM.HashMap k v) v + deriving (Show, Eq, Ord, Generic, Lift, Hashable, NFData) + +instance (Hashable k, Serial k, Serial v) => Serial (Array k v) + +instance (Hashable k, Serial k, Serial v) => Cereal.Serialize (Array k v) where + put = serialize + get = deserialize + +instance (Hashable k, Serial k, Serial v) => Binary.Binary (Array k v) where + put = serialize + get = deserialize + +-- TODO: Perhaps it is nice to make this a typeclass and give it names that do +-- not require qualified imports? I don't necessarily mind the qualified import, +-- but we'll see what the library author thinks. +const :: forall k v. v -> Array k v +const = Array HM.empty + +select :: forall k v. Hashable k => Array k v -> k -> v +select (Array entries root) key = HM.lookupDefault root key entries + +store :: forall k v. Hashable k => Array k v -> k -> v -> Array k v +store (Array entries root) key value = do + let entries' = HM.insert key value entries + Array entries' root diff --git a/src/Grisette/Internal/SymPrim/GeneralFun.hs b/src/Grisette/Internal/SymPrim/GeneralFun.hs index d42782657..011b19fa8 100644 --- a/src/Grisette/Internal/SymPrim/GeneralFun.hs +++ b/src/Grisette/Internal/SymPrim/GeneralFun.hs @@ -100,6 +100,9 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term PEvalOrdTerm (pevalLeOrdTerm, pevalLtOrdTerm), PEvalRotateTerm (pevalRotateRightTerm), PEvalShiftTerm (pevalShiftLeftTerm, pevalShiftRightTerm), + pevalSelectTerm, + pevalStoreTerm, + pevalConstArrayTerm, SBVRep (SBVType), SomeTypedAnySymbol, SomeTypedConstantSymbol, @@ -189,6 +192,9 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term pattern SymTerm, pattern ToFPTerm, pattern XorBitsTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, ) import Grisette.Internal.SymPrim.Prim.Pattern (pattern SubTerms) import Grisette.Internal.SymPrim.Prim.SomeTerm (SomeTerm (SomeTerm), someTerm) @@ -542,6 +548,12 @@ generalSubstSomeTerm subst initialBoundedSymbols = go initialMemo _ (SomeTerm (ToFPTerm mode (arg :: Term a) (_ :: p eb) (_ :: q sb))) = goBinary memo (pevalToFPTerm @a @eb @sb) mode arg + goSome memo _ (SomeTerm (SelectTerm arr key)) = + goBinary memo pevalSelectTerm arr key + goSome memo _ (SomeTerm (StoreTerm arr key val)) = + goTernary memo pevalStoreTerm arr key val + goSome memo _ (SomeTerm (ConstArrayTerm pkey val)) = + goUnary memo (pevalConstArrayTerm pkey) val goUnary memo f a = SomeTerm $ f (go memo a) goBinary memo f a b = SomeTerm $ f (go memo a) (go memo b) goTernary memo f a b c = diff --git a/src/Grisette/Internal/SymPrim/Prim/Internal/Instances/PEvalOrdTerm.hs b/src/Grisette/Internal/SymPrim/Prim/Internal/Instances/PEvalOrdTerm.hs index 9712c5f62..7414ffe63 100644 --- a/src/Grisette/Internal/SymPrim/Prim/Internal/Instances/PEvalOrdTerm.hs +++ b/src/Grisette/Internal/SymPrim/Prim/Internal/Instances/PEvalOrdTerm.hs @@ -30,9 +30,7 @@ import Grisette.Internal.SymPrim.AlgReal (AlgReal) import Grisette.Internal.SymPrim.BV (IntN, WordN) import Grisette.Internal.SymPrim.FP ( FP, - FPRoundingMode, ValidFP, - allFPRoundingMode, ) import Grisette.Internal.SymPrim.Prim.Internal.Instances.PEvalNumTerm () import Grisette.Internal.SymPrim.Prim.Internal.Term @@ -41,10 +39,9 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term ( pevalLeOrdTerm, pevalLtOrdTerm, sbvLeOrdTerm, - sbvLtOrdTerm, withSbvOrdTermConstraint ), - SupportedPrim (conSBVTerm, withPrim), + SupportedPrim (withPrim), Term, conTerm, leOrdTerm, @@ -133,45 +130,6 @@ instance (ValidFP eb sb) => PEvalOrdTerm (FP eb sb) where (SBV.sNot (SBV.fpIsNaN x) SBV..&& SBV.sNot (SBV.fpIsNaN y)) SBV..&& (x SBV..<= y) --- Use this table to avoid accidental breakage introduced by sbv. -fpRoundingModeLtTable :: [(SBV.SRoundingMode, SBV.SRoundingMode)] -fpRoundingModeLtTable = - [ ( conSBVTerm @FPRoundingMode a, - conSBVTerm @FPRoundingMode b - ) - | a <- allFPRoundingMode, - b <- allFPRoundingMode, - a < b - ] - -fpRoundingModeLeTable :: [(SBV.SRoundingMode, SBV.SRoundingMode)] -fpRoundingModeLeTable = - [ ( conSBVTerm @FPRoundingMode a, - conSBVTerm @FPRoundingMode b - ) - | a <- allFPRoundingMode, - b <- allFPRoundingMode, - a <= b - ] - -sbvTableLookup :: - [(SBV.SRoundingMode, SBV.SRoundingMode)] -> - SBV.SRoundingMode -> - SBV.SRoundingMode -> - SBV.SBV Bool -sbvTableLookup tbl lhs rhs = - foldl - (\acc (a, b) -> acc SBV..|| ((lhs SBV..== a) SBV..&& (rhs SBV..== b))) - SBV.sFalse - tbl - -instance PEvalOrdTerm FPRoundingMode where - pevalLtOrdTerm = pevalGeneralLtOrdTerm - pevalLeOrdTerm = pevalGeneralLeOrdTerm - withSbvOrdTermConstraint r = withPrim @FPRoundingMode r - sbvLtOrdTerm = sbvTableLookup fpRoundingModeLtTable - sbvLeOrdTerm = sbvTableLookup fpRoundingModeLeTable - instance PEvalOrdTerm AlgReal where pevalLtOrdTerm = pevalGeneralLtOrdTerm pevalLeOrdTerm = pevalGeneralLeOrdTerm diff --git a/src/Grisette/Internal/SymPrim/Prim/Internal/Serialize.hs b/src/Grisette/Internal/SymPrim/Prim/Internal/Serialize.hs index cfa76b17b..ae81aee42 100644 --- a/src/Grisette/Internal/SymPrim/Prim/Internal/Serialize.hs +++ b/src/Grisette/Internal/SymPrim/Prim/Internal/Serialize.hs @@ -26,6 +26,7 @@ module Grisette.Internal.SymPrim.Prim.Internal.Serialize () where import Control.Monad (replicateM, unless, when) +import Control.Monad.Identity (Identity (runIdentity)) import Control.Monad.State (StateT, evalStateT) import qualified Control.Monad.State as State import qualified Data.Binary as Binary @@ -38,14 +39,17 @@ import qualified Data.HashSet as HS import Data.Hashable (Hashable (hashWithSalt)) import Data.List (intercalate) import Data.List.NonEmpty (NonEmpty ((:|))) +import Data.Maybe (isJust, fromMaybe) import Data.Proxy (Proxy (Proxy)) import qualified Data.Serialize as Cereal +import Data.Typeable (heqT) import Data.Word (Word8) import GHC.Generics (Generic) import GHC.Natural (Natural) import GHC.Stack (HasCallStack) import GHC.TypeNats (KnownNat, natVal, type (+), type (<=)) import Grisette.Internal.SymPrim.AlgReal (AlgReal) +import Grisette.Internal.SymPrim.Array (Array) import Grisette.Internal.SymPrim.BV (IntN, WordN) import Grisette.Internal.SymPrim.FP ( FP, @@ -144,6 +148,9 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term termId, toFPTerm, xorBitsTerm, + selectTerm, + storeTerm, + constArrayTerm, pattern AbsNumTerm, pattern AddNumTerm, pattern AndBitsTerm, @@ -193,6 +200,9 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term pattern SymTerm, pattern ToFPTerm, pattern XorBitsTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, ) import Grisette.Internal.SymPrim.Prim.SomeTerm ( SomeTerm (SomeTerm), @@ -213,11 +223,9 @@ import Grisette.Internal.Utils.Parameterized unsafeLeqProof, ) import Type.Reflection - ( SomeTypeRep (SomeTypeRep), - TypeRep, + ( TypeRep, Typeable, eqTypeRep, - someTypeRep, typeRep, pattern App, pattern Con, @@ -233,6 +241,7 @@ data KnownNonFuncType where FPType :: (ValidFP eb sb) => Proxy eb -> Proxy sb -> KnownNonFuncType FPRoundingModeType :: KnownNonFuncType AlgRealType :: KnownNonFuncType + ArrayType :: KnownNonFuncType -> KnownNonFuncType -> KnownNonFuncType instance Eq KnownNonFuncType where BoolType == BoolType = True @@ -255,6 +264,8 @@ instance Hashable KnownNonFuncType where s `hashWithSalt` (4 :: Int) `hashWithSalt` natVal p `hashWithSalt` natVal q hashWithSalt s FPRoundingModeType = s `hashWithSalt` (5 :: Int) hashWithSalt s AlgRealType = s `hashWithSalt` (6 :: Int) + hashWithSalt s (ArrayType k v) = + s `hashWithSalt` k `hashWithSalt` v `hashWithSalt` (7 :: Int) data KnownNonFuncTypeWitness where KnownNonFuncTypeWitness :: @@ -280,6 +291,10 @@ witnessKnownNonFuncType (FPType (Proxy :: Proxy eb) (Proxy :: Proxy sb)) = witnessKnownNonFuncType FPRoundingModeType = KnownNonFuncTypeWitness (Proxy @FPRoundingMode) witnessKnownNonFuncType AlgRealType = KnownNonFuncTypeWitness (Proxy @AlgReal) +witnessKnownNonFuncType (ArrayType k v) = runIdentity $ do + KnownNonFuncTypeWitness (_ :: Proxy k) <- pure $ witnessKnownNonFuncType k + KnownNonFuncTypeWitness (_ :: Proxy v) <- pure $ witnessKnownNonFuncType v + pure $ KnownNonFuncTypeWitness @(Array k v) Proxy data KnownType where NonFuncType :: KnownNonFuncType -> KnownType @@ -496,91 +511,90 @@ instance Show KnownNonFuncType where <> show (natVal (Proxy @sb)) show FPRoundingModeType = "FPRoundingMode" show AlgRealType = "AlgReal" + show (ArrayType key val) = "Array (" ++ show key ++ ") (" ++ show val ++ ")" instance Show KnownType where show (NonFuncType t) = show t show (TabularFunType ts) = intercalate " =-> " $ show <$> ts show (GeneralFunType ts) = intercalate " --> " $ show <$> ts -knownNonFuncType :: - forall a p. (SupportedNonFuncPrim a) => p a -> KnownNonFuncType -knownNonFuncType _ = - case tr of - _ | SomeTypeRep tr == someTypeRep (Proxy @Bool) -> BoolType - _ | SomeTypeRep tr == someTypeRep (Proxy @Integer) -> IntegerType - _ - | SomeTypeRep tr == someTypeRep (Proxy @FPRoundingMode) -> - FPRoundingModeType - _ | SomeTypeRep tr == someTypeRep (Proxy @AlgReal) -> AlgRealType - App (ta@(Con _) :: TypeRep w) (_ :: TypeRep n) -> - case ( eqTypeRep ta (typeRep @WordN), - eqTypeRep ta (typeRep @IntN) - ) of - (Just HRefl, _) -> withPrim @a $ WordNType (Proxy @n) - (_, Just HRefl) -> withPrim @a $ IntNType (Proxy @n) - _ -> err - App (App (tf :: TypeRep f) (_ :: TypeRep a0)) (_ :: TypeRep a1) -> - case eqTypeRep tf (typeRep @FP) of - Just HRefl -> withPrim @a $ FPType (Proxy @a0) (Proxy @a1) - _ -> err - _ -> err +knownNonFuncTypeMaybe :: + forall a p. SupportedPrim a => p a -> Maybe KnownNonFuncType +knownNonFuncTypeMaybe _ = withPrim @a $ case tr of + _ | isTy @Bool Proxy -> pure BoolType + | isTy @Integer Proxy -> pure IntegerType + | isTy @FPRoundingMode Proxy -> pure FPRoundingModeType + | isTy @AlgReal Proxy -> pure AlgRealType + App (ta@(Con _) :: TypeRep w) (_ :: TypeRep n) + | Just HRefl <- eqTypeRep ta $ typeRep @WordN -> pure $ WordNType @n Proxy + | Just HRefl <- eqTypeRep ta $ typeRep @IntN -> pure $ IntNType @n Proxy + App (App (tf :: TypeRep f) (_ :: TypeRep eb)) (_ :: TypeRep es) + | Just HRefl <- eqTypeRep tf $ typeRep @FP -> do + pure $ FPType (Proxy @eb) (Proxy @es) + App (App arrR (_ :: TypeRep k)) (_ :: TypeRep v) + | Just HRefl <- eqTypeRep arrR $ typeRep @Array -> do + keyTy <- knownNonFuncTypeMaybe @k Proxy + valTy <- knownNonFuncTypeMaybe @v Proxy + pure $ ArrayType keyTy valTy + _ -> Nothing where - tr = primTypeRep @a - err = error $ "knownNonFuncType: unsupported type: " <> show tr - -knownType :: - forall a p. (SupportedPrim a) => p a -> KnownType -knownType _ = - case tr of - _ | SomeTypeRep tr == someTypeRep (Proxy @Bool) -> NonFuncType BoolType - _ - | SomeTypeRep tr == someTypeRep (Proxy @Integer) -> - NonFuncType IntegerType - _ - | SomeTypeRep tr == someTypeRep (Proxy @FPRoundingMode) -> - NonFuncType FPRoundingModeType - _ - | SomeTypeRep tr == someTypeRep (Proxy @AlgReal) -> - NonFuncType AlgRealType - App (ta@(Con _) :: TypeRep w) (_ :: TypeRep n) -> - case ( eqTypeRep ta (typeRep @WordN), - eqTypeRep ta (typeRep @IntN) - ) of - (Just HRefl, _) -> withPrim @a $ NonFuncType $ WordNType (Proxy @n) - (_, Just HRefl) -> withPrim @a $ NonFuncType $ IntNType (Proxy @n) - _ -> err - App (App (tf :: TypeRep f) (_ :: TypeRep a0)) (_ :: TypeRep a1) -> - case ( eqTypeRep tf (typeRep @FP), - eqTypeRep tf (typeRep @(=->)), - eqTypeRep tf (typeRep @(-->)) - ) of - (Just HRefl, _, _) -> - withPrim @a $ NonFuncType $ FPType (Proxy @a0) (Proxy @a1) - (_, Just HRefl, _) -> - withPrim @a $ - let arg = knownType (Proxy @a0) - ret = knownType (Proxy @a1) - in case arg of - NonFuncType n -> case ret of - NonFuncType m -> TabularFunType [n, m] - TabularFunType ns -> TabularFunType (n : ns) - _ -> err - _ -> err - (_, _, Just HRefl) -> - withPrim @a $ - let arg = knownType (Proxy @a0) - ret = knownType (Proxy @a1) - in case arg of - NonFuncType n -> case ret of - NonFuncType m -> GeneralFunType [n, m] - GeneralFunType ns -> GeneralFunType (n : ns) - _ -> err - _ -> err - _ -> err - _ -> err + tr = typeRep @a + + isTy :: forall b. Typeable b => Proxy b -> Bool + isTy _ = isJust . eqTypeRep tr $ typeRep @b + +knownNonFuncType :: + forall a p. SupportedPrim a => p a -> KnownNonFuncType +knownNonFuncType proxy = do + let err = error $ "knownNonFuncType: unsupported type: " <> show (typeRep @a) + fromMaybe err $ knownNonFuncTypeMaybe proxy + +knownTypeMaybe :: forall a p. SupportedPrim a => p a -> Maybe KnownType +knownTypeMaybe proxy = withPrim @a $ case tr of + _ | Just result <- knownNonFuncTypeMaybe proxy -> pure $ NonFuncType result + App (App (funR :: TypeRep f) (_ :: TypeRep arg)) (_ :: TypeRep res) + | Just HRefl <- eqTypeRep funR $ typeRep @(=->) -> do + -- Gather the argument type. + arg <- knownTypeMaybe @arg Proxy + n <- case arg of + NonFuncType n -> pure n + _ -> Nothing + + -- Gather the result type. + ret <- knownTypeMaybe @res Proxy + ns <- case ret of + NonFuncType m -> pure [m] + TabularFunType ns -> pure ns + _ -> Nothing + + -- Create the tabular function type. + pure $ TabularFunType (n : ns) + + | Just HRefl <- eqTypeRep funR $ typeRep @(-->) -> do + -- Gather the argument type. + arg <- knownTypeMaybe @arg Proxy + n <- case arg of + NonFuncType n -> pure n + _ -> Nothing + + -- Gather the result type. + ret <- knownTypeMaybe @res Proxy + ns <- case ret of + NonFuncType m -> pure [m] + GeneralFunType ns -> pure ns + _ -> Nothing + + -- Create the general function type. + pure $ GeneralFunType (n : ns) + + _ -> Nothing where - tr = primTypeRep @a - err = error $ "knownType: unsupported type: " <> show tr + tr = typeRep @a + +knownType :: forall a p. SupportedPrim a => p a -> KnownType +knownType proxy = do + let err = error $ "knownType: unsupported type: " <> show (typeRep @a) + fromMaybe err $ knownTypeMaybe proxy -- Bool: 0 -- Integer: 1 @@ -589,6 +603,7 @@ knownType _ = -- FP: 4 -- FPRoundingMode: 5 -- AlgReal: 6 +-- Array: 7 serializeKnownNonFuncType :: (MonadPut m) => KnownNonFuncType -> m () serializeKnownNonFuncType BoolType = putWord8 0 serializeKnownNonFuncType IntegerType = putWord8 1 @@ -600,6 +615,10 @@ serializeKnownNonFuncType (FPType (Proxy :: Proxy eb) (Proxy :: Proxy sb)) = putWord8 4 >> serialize (natVal (Proxy @eb)) >> serialize (natVal (Proxy @sb)) serializeKnownNonFuncType FPRoundingModeType = putWord8 5 serializeKnownNonFuncType AlgRealType = putWord8 6 +serializeKnownNonFuncType (ArrayType key val) = do + putWord8 7 + serializeKnownNonFuncType key + serializeKnownNonFuncType val serializeKnownType :: (MonadPut m) => KnownType -> m () serializeKnownType (NonFuncType t) = putWord8 0 >> serializeKnownNonFuncType t @@ -639,6 +658,10 @@ deserializeKnownNonFuncType = do withUnsafeValidFP @eb @sb $ return $ FPType (Proxy @eb) (Proxy @sb) 5 -> return FPRoundingModeType 6 -> return AlgRealType + 7 -> do + keyT <- deserializeKnownNonFuncType + valT <- deserializeKnownNonFuncType + pure $ ArrayType keyT valT _ -> fail "deserializeKnownNonFuncType: Unknown type tag" deserializeKnownType :: (MonadGet m) => m KnownType @@ -877,6 +900,15 @@ fromFPOrTermTag = 46 toFPTermTag :: Word8 toFPTermTag = 47 +selectTermTag :: Word8 +selectTermTag = 48 + +storeTermTag :: Word8 +storeTermTag = 49 + +constArrayTermTag :: Word8 +constArrayTermTag = 50 + terminalTag :: Word8 terminalTag = 255 @@ -931,33 +963,18 @@ asNumTypeTerm (SomeTerm (t1 :: Term a)) f = err = error $ "asNumTypeTerm: unsupported type: " <> show ta asOrdTypeTerm :: - (HasCallStack) => SomeTerm -> (forall n. (PEvalOrdTerm n) => Term n -> r) -> r -asOrdTypeTerm (SomeTerm (t1 :: Term a)) f = - case ( eqTypeRep ta (typeRep @Integer), - eqTypeRep ta (typeRep @AlgReal), - eqTypeRep ta (typeRep @FPRoundingMode) - ) of - (Just HRefl, _, _) -> f t1 - (_, Just HRefl, _) -> f t1 - (_, _, Just HRefl) -> f t1 - _ -> - case ta of - App (ta@(Con _) :: TypeRep w) (_ :: TypeRep n) -> - case ( eqTypeRep ta (typeRep @WordN), - eqTypeRep ta (typeRep @IntN) - ) of - (Just HRefl, _) -> withPrim @a $ f t1 - (_, Just HRefl) -> withPrim @a $ f t1 - _ -> err - App (App (tf :: TypeRep f) (_ :: TypeRep a0)) (_ :: TypeRep a1) -> - case eqTypeRep tf (typeRep @FP) of - Just HRefl -> - withPrim @a $ withPrim @a $ f t1 - _ -> err - _ -> err + HasCallStack => SomeTerm -> (forall n. PEvalOrdTerm n => Term n -> r) -> r +asOrdTypeTerm (SomeTerm (t1 :: Term a)) f = case ta of + _ | Just HRefl <- eqTypeRep ta $ typeRep @Integer -> f t1 + _ | Just HRefl <- eqTypeRep ta $ typeRep @AlgReal -> f t1 + App bvR _nR + | Just HRefl <- eqTypeRep bvR $ typeRep @WordN -> withPrim @a $ f t1 + | Just HRefl <- eqTypeRep bvR $ typeRep @IntN -> withPrim @a $ f t1 + App (App fpR _ebR) _esR + | Just HRefl <- eqTypeRep fpR $ typeRep @FP -> withPrim @a $ f t1 + _ -> error $ "asNumTypeTerm: unsupported type: " <> show ta where ta = primTypeRep @a - err = error $ "asOrdTypeTerm: unsupported type: " <> show ta asBitsTypeTerm :: (HasCallStack) => @@ -1527,6 +1544,67 @@ statefulDeserializeSomeTerm = do ktTmId ) else error "statefulDeserializeSomeTerm: invalid FP type" + | tag == selectTermTag -> do + -- Deserialize the array and key. + SomeTerm (arr :: Term a) <- deserializeTerm + SomeTerm (key :: Term k) <- deserializeTerm + + -- Deserialize the resulting value type and get the required + -- dictionaries for this type. + -- TODO: How do I know I get the correct known type here? + valType <- deserializeKnownType + KnownTypeWitness (_ :: Proxy v) <- pure $ witnessKnownType valType + + -- Ensure that the key is indeed a valid key for the array and that + -- the resulting value matches the value type such that we can provide + -- the dictionary. Using this, we construct the final term. + let term = case typeRep @a of + App (App aRep kRep) vRep + | Just HRefl <- eqTypeRep aRep $ typeRep @Array + , Just HRefl <- eqTypeRep kRep $ typeRep @k + , Just HRefl <- eqTypeRep vRep $ typeRep @v -> do + someTerm $ selectTerm @k @v arr key + _ -> error "statefulDeserializeSomeTerm: invalid Array type" + + pure $ Just (term, ktTmId) + | tag == storeTermTag -> do + -- Deserialize the array, key and value. + SomeTerm (arr :: Term a) <- deserializeTerm + SomeTerm (key :: Term k) <- deserializeTerm + SomeTerm (val :: Term v) <- deserializeTerm + + -- Ensure the types match up such that we can construct the term. + let term = case typeRep @a of + App (App aRep kRep) vRep + | Just HRefl <- eqTypeRep aRep $ typeRep @Array + , Just HRefl <- eqTypeRep kRep $ typeRep @k + , Just HRefl <- eqTypeRep vRep $ typeRep @v -> do + someTerm $ storeTerm @k @v arr key val + _ -> error "statefulDeserializeSomeTerm: invalid Array type" + + pure $ Just (term, ktTmId) + | tag == constArrayTermTag -> do + -- Get the value term and non-function primitive dictionary. + SomeTerm (val :: Term v) <- deserializeTerm + let valType = knownNonFuncType @v Proxy + KnownNonFuncTypeWitness (_ :: p v') <- do + pure $ witnessKnownNonFuncType valType + + -- Gather the key type and its non-function primitive dictionary + -- TODO: How do I know I get the correct known type for the key? + keyType <- deserializeKnownNonFuncType + KnownNonFuncTypeWitness (_ :: p k) <- do + pure $ witnessKnownNonFuncType keyType + + -- Really, this should never fail but I guess we can check instead of + -- coercing unsafely. + HRefl <- case heqT @v @v' of + Just refl -> pure refl + Nothing -> error "statefulDeserializeSomeTerm: non-injective type translation" + + let term = someTerm $ constArrayTerm @k Proxy val + + pure $ Just (term, ktTmId) | otherwise -> error $ "statefulDeserializeSomeTerm: unknown tag: " <> show tag case r of @@ -1811,6 +1889,14 @@ serializeSingleSomeTerm (SomeTerm (tm :: Term t)) = do serialize $ natVal sb serialize $ knownTypeTermId rd serialize $ knownTypeTermId t + SelectTerm arr key -> do + serializeBinary ktTmId selectTermTag arr key + serializeKnownType $ knownType @t Proxy + StoreTerm arr key val -> do + serializeTernary ktTmId storeTermTag arr key val + ConstArrayTerm pkey val -> withPrim @t $ do + serializeUnary ktTmId constArrayTermTag val + serializeKnownType $ knownType pkey State.put $ HS.insert ktTmId st where serializeQuantified :: diff --git a/src/Grisette/Internal/SymPrim/Prim/Internal/Term.hs b/src/Grisette/Internal/SymPrim/Prim/Internal/Term.hs index 112a48f4c..8f4c1fbf0 100644 --- a/src/Grisette/Internal/SymPrim/Prim/Internal/Term.hs +++ b/src/Grisette/Internal/SymPrim/Prim/Internal/Term.hs @@ -69,6 +69,9 @@ module Grisette.Internal.SymPrim.Prim.Internal.Term PEvalFloatingTerm (..), PEvalFromIntegralTerm (..), PEvalIEEEFPConvertibleTerm (..), + pevalSelectTerm, + pevalStoreTerm, + pevalConstArrayTerm, -- * Typed symbols SymbolKind (..), @@ -165,6 +168,9 @@ module Grisette.Internal.SymPrim.Prim.Internal.Term fromIntegralTerm, fromFPOrTerm, toFPTerm, + selectTerm, + storeTerm, + constArrayTerm, -- * Patterns pattern SupportedTerm, @@ -220,6 +226,9 @@ module Grisette.Internal.SymPrim.Prim.Internal.Term pattern FromIntegralTerm, pattern FromFPOrTerm, pattern ToFPTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, -- * Support for boolean type trueTerm, @@ -329,6 +338,7 @@ import qualified Control.Monad.Writer.Lazy as Lazy import qualified Control.Monad.Writer.Strict as Strict import Data.Atomics (atomicModifyIORefCAS_) import qualified Data.Binary as Binary +import Data.Bifunctor (Bifunctor(bimap)) import Data.Bits ( Bits (complement, isSigned, xor, zeroBits, (.&.), (.|.)), FiniteBits (countLeadingZeros), @@ -377,6 +387,7 @@ import Grisette.Internal.Core.Data.Symbol Symbol (IndexedSymbol, SimpleSymbol), ) import Grisette.Internal.SymPrim.AlgReal (AlgReal, fromSBVAlgReal, toSBVAlgReal) +import Grisette.Internal.SymPrim.Array (Array (Array)) import Grisette.Internal.SymPrim.BV (IntN, WordN) import Grisette.Internal.SymPrim.FP ( FP (FP), @@ -495,7 +506,8 @@ class Eq a, Show a, Hashable a, - Typeable a + Typeable a, + SBVType a ~ SBV.SBV (NonFuncSBVBaseType a) ) => NonFuncSBVRep a where @@ -509,7 +521,7 @@ type NonFuncPrimConstraint a = SBV.Mergeable (SBVType a), SBV.SMTDefinable (SBVType a), SBV.Mergeable (SBVType a), - SBVType a ~ SBV.SBV (NonFuncSBVBaseType a), + SBVT.SatModel (NonFuncSBVBaseType a), PrimConstraint a ) @@ -520,6 +532,7 @@ class (NonFuncSBVRep a) => SupportedNonFuncPrim a where symNonFuncSBVTerm :: (SBVFreshMonad m) => String -> m (SBV.SBV (NonFuncSBVBaseType a)) withNonFuncPrim :: ((NonFuncPrimConstraint a) => r) -> r + sbvToCon :: NonFuncSBVBaseType a -> a -- | Partition the list of CVs for models for functions. partitionCVArg :: @@ -644,6 +657,13 @@ class (SBVT.EqSymbolic (SBVType t)) => NonEmpty (SBVType t) -> SBV.SBV Bool sbvDistinct = SBV.distinct . toList parseSMTModelResult :: Int -> ([([SBVD.CV], SBVD.CV)], SBVD.CV) -> t + default parseSMTModelResult :: + SupportedNonFuncPrim t => + Int -> + ([([SBVD.CV], SBVD.CV)], SBVD.CV) -> + t + parseSMTModelResult _ = withNonFuncPrim @t $ do + parseScalarSMTModelResult sbvToCon castTypedSymbol :: (IsSymbolKind knd') => TypedSymbol knd t -> Maybe (TypedSymbol knd' t) funcDummyConstraint :: SBVType t -> SBV.SBV Bool @@ -1734,6 +1754,25 @@ data Term t where Proxy eb -> Proxy sb -> Term (FP eb sb) + SelectTerm' :: + SupportedPrim (Array k v) => + {-# UNPACK #-} !CachedInfo -> + !(Term (Array k v)) -> + !(Term k) -> + Term v + StoreTerm' :: + SupportedPrim (Array k v) => + {-# UNPACK #-} !CachedInfo -> + !(Term (Array k v)) -> + !(Term k) -> + !(Term v) -> + Term (Array k v) + ConstArrayTerm' :: + SupportedPrim (Array k v) => + {-# UNPACK #-} !CachedInfo -> + Proxy k -> + !(Term v) -> + Term (Array k v) data SupportedPrimEvidence t where SupportedPrimEvidence :: (SupportedPrim t) => SupportedPrimEvidence t @@ -2699,6 +2738,67 @@ pattern ToFPTerm rm t eb sb <- (ToFPTerm' _ rm t@SupportedTerm eb sb) {-# INLINE ToFPTerm #-} #endif +-- | Pattern synonym for 'SelectTerm''. Note that using this pattern to +-- construct a 'Term' will do term simplification. +pattern SelectTerm :: + forall ret. + () => + forall k v. + ( SupportedPrim (Array k v), + ret ~ v + ) => + Term (Array k v) -> + Term k -> + Term ret +pattern SelectTerm arr key <- SelectTerm' _ arr key + where + SelectTerm arr key = pevalSelectTerm arr key + +#if MIN_VERSION_base(4, 16, 4) +{-# INLINE SelectTerm #-} +#endif + +-- | Pattern synonym for 'StoreTerm''. Note that using this pattern to +-- construct a 'Term' will do term simplification. +pattern StoreTerm :: + forall ret. + () => + forall k v. + ( SupportedPrim (Array k v), + ret ~ Array k v + ) => + Term (Array k v) -> + Term k -> + Term v -> + Term ret +pattern StoreTerm arr key val <- StoreTerm' _ arr key val + where + StoreTerm arr key = pevalStoreTerm arr key + +#if MIN_VERSION_base(4, 16, 4) +{-# INLINE StoreTerm #-} +#endif + +-- | Pattern synonym for 'StoreTerm''. Note that using this pattern to +-- construct a 'Term' will do term simplification. +pattern ConstArrayTerm :: + forall ret. + () => + forall k v. + ( SupportedPrim (Array k v), + ret ~ Array k v + ) => + Proxy k -> + Term v -> + Term ret +pattern ConstArrayTerm pkey val <- ConstArrayTerm' _ pkey val + where + ConstArrayTerm pkey val = pevalConstArrayTerm pkey val + +#if MIN_VERSION_base(4, 16, 4) +{-# INLINE ConstArrayTerm #-} +#endif + #if MIN_VERSION_base(4, 16, 4) {-# COMPLETE ConTerm, @@ -2748,7 +2848,10 @@ pattern ToFPTerm rm t eb sb <- (ToFPTerm' _ rm t@SupportedTerm eb sb) FPFMATerm, FromIntegralTerm, FromFPOrTerm, - ToFPTerm + ToFPTerm, + SelectTerm, + StoreTerm, + ConstArrayTerm #-} #endif @@ -2802,6 +2905,9 @@ termInfo (FPFMATerm' i _ _ _ _) = i termInfo (FromIntegralTerm' i _) = i termInfo (FromFPOrTerm' i _ _ _) = i termInfo (ToFPTerm' i _ _ _ _) = i +termInfo (SelectTerm' i _ _) = i +termInfo (StoreTerm' i _ _ _) = i +termInfo (ConstArrayTerm' i _ _) = i -- | Get the thread ID for a term. {-# INLINE termThreadId #-} @@ -2923,6 +3029,10 @@ introSupportedPrimConstraint0 FPFMATerm' {} x = x introSupportedPrimConstraint0 FromIntegralTerm' {} x = x introSupportedPrimConstraint0 FromFPOrTerm' {} x = x introSupportedPrimConstraint0 ToFPTerm' {} x = x +introSupportedPrimConstraint0 (SelectTerm' _ (_ :: Term arr) _) x = do + withPrim @arr x +introSupportedPrimConstraint0 StoreTerm' {} x = x +introSupportedPrimConstraint0 ConstArrayTerm' {} x = x -- | Introduce the 'SupportedPrim' constraint from a term. introSupportedPrimConstraint :: @@ -2984,6 +3094,9 @@ pformatTerm (FPFMATerm mode arg1 arg2 arg3) = pformatTerm (FromIntegralTerm arg) = "(from_integral " ++ pformatTerm arg ++ ")" pformatTerm (FromFPOrTerm d r arg) = "(from_fp_or " ++ pformatTerm d ++ " " ++ pformatTerm r ++ " " ++ pformatTerm arg ++ ")" pformatTerm (ToFPTerm r arg _ _) = "(to_fp " ++ pformatTerm r ++ " " ++ pformatTerm arg ++ ")" +pformatTerm (SelectTerm arr key) = "(select " ++ pformatTerm arr ++ " " ++ pformatTerm key ++ ")" +pformatTerm (StoreTerm arr key val) = "(store " ++ pformatTerm arr ++ " " ++ pformatTerm key ++ " " ++ pformatTerm val ++ ")" +pformatTerm (ConstArrayTerm _ val) = "(const_array " ++ pformatTerm val ++ ")" -- {-# INLINE pformatTerm #-} @@ -3054,6 +3167,11 @@ instance Lift (Term t) where liftTyped (FromFPOrTerm t1 t2 t3) = [||fromFPOrTerm t1 t2 t3||] liftTyped (ToFPTerm t1 t2 _ _) = [||toFPTerm t1 t2||] + liftTyped (SelectTerm t1 t2) = [||selectTerm t1 t2||] + liftTyped (StoreTerm t1 t2 t3) = [||storeTerm t1 t2 t3||] + liftTyped (ConstArrayTerm (_ :: p k) t2) = do + let pkey = [||Proxy||] :: CODE (Proxy k) + [||constArrayTerm $$pkey t2||] instance Show (Term ty) where show t@(ConTerm v) = @@ -3530,6 +3648,38 @@ instance Show (Term ty) where ++ ", arg=" ++ show arg ++ "}" + show t@(SelectTerm arr key) = + "SelectTerm{tid=" + ++ show (termThreadId t) + ++ ", id=" + ++ show (termId t) + ++ ", array=" + ++ show arr + ++ ", key=" + ++ show key + ++ "}" + show t@(StoreTerm arr key val) = + "StoreTerm{tid=" + ++ show (termThreadId t) + ++ ", id=" + ++ show (termId t) + ++ ", array=" + ++ show arr + ++ ", key=" + ++ show key + ++ ", val=" + ++ show val + ++ "}" + show t@(ConstArrayTerm (_ :: p k) val) = + "ConstArrayTerm{tid=" + ++ show (termThreadId t) + ++ ", id=" + ++ show (termId t) + ++ ", key=" + ++ withPrim @ty (show $ typeRep @k) + ++ ", val=" + ++ show val + ++ "}" -- {-# INLINE show #-} @@ -3751,6 +3901,22 @@ data UTerm t where Proxy eb -> Proxy sb -> UTerm (FP eb sb) + USelectTerm :: + SupportedPrim (Array k v) => + !(Term (Array k v)) -> + !(Term k) -> + UTerm v + UStoreTerm :: + SupportedPrim (Array k v) => + !(Term (Array k v)) -> + !(Term k) -> + !(Term v) -> + UTerm (Array k v) + UConstArrayTerm :: + SupportedPrim (Array k v) => + Proxy k -> + !(Term v) -> + UTerm (Array k v) -- | Compare two t'TypedSymbol's for equality. eqHeteroSymbol :: forall ta a tb b. TypedSymbol ta a -> TypedSymbol tb b -> Bool @@ -4017,6 +4183,20 @@ preHashToFPTermDescription h1 h2 = fromIntegral (50 `hashWithSalt` h1 `hashWithSalt` h2) {-# INLINE preHashToFPTermDescription #-} +preHashSelectDescription :: HashId -> HashId -> Digest +preHashSelectDescription h1 h2 = + fromIntegral (51 `hashWithSalt` h1 `hashWithSalt` h2) +{-# INLINE preHashSelectDescription #-} + +preHashStoreDescription :: HashId -> HashId -> HashId -> Digest +preHashStoreDescription h1 h2 h3 = + fromIntegral (52 `hashWithSalt` h1 `hashWithSalt` h2 `hashWithSalt` h3) +{-# INLINE preHashStoreDescription #-} + +preHashConstArrayDescription :: HashId -> Digest +preHashConstArrayDescription h1 = fromIntegral (53 `hashWithSalt` h1) +{-# INLINE preHashConstArrayDescription #-} + instance Interned (Term t) where type Uninterned (Term t) = UTerm t data Description (Term t) where @@ -4265,6 +4445,22 @@ instance Interned (Term t) where {-# UNPACK #-} !HashId -> {-# UNPACK #-} !TypeHashId -> Description (Term (FP eb sb)) + DSelectTerm :: + {-# UNPACK #-} !Digest -> + {-# UNPACK #-} !HashId -> + {-# UNPACK #-} !HashId -> + Description (Term v) + DStoreTerm :: + {-# UNPACK #-} !Digest -> + {-# UNPACK #-} !HashId -> + {-# UNPACK #-} !HashId -> + {-# UNPACK #-} !HashId -> + Description (Term v) + DConstArrayTerm :: + {-# UNPACK #-} !Digest -> + {-# UNPACK #-} !Fingerprint -> + {-# UNPACK #-} !HashId -> + Description (Term v) describe (UConTerm v) = DConTerm sameCon (preHashConDescription v) v describe ((USymTerm name) :: UTerm t) = @@ -4576,6 +4772,22 @@ instance Interned (Term t) where (preHashToFPTermDescription modeHashId argHashId) modeHashId argHashId + describe (USelectTerm arr key) = do + let arrHashId = termHashId arr + let keyHashId = termHashId key + let digest = preHashSelectDescription arrHashId keyHashId + DSelectTerm digest arrHashId keyHashId + describe (UStoreTerm arr key val) = do + let arrHashId = termHashId arr + let keyHashId = termHashId key + let valHashId = termHashId val + let digest = preHashStoreDescription arrHashId keyHashId valHashId + DStoreTerm digest arrHashId keyHashId valHashId + describe (UConstArrayTerm pkey val) = withPrim @t $ do + let keyFingerprint = typeRepFingerprint $ someTypeRep pkey + let valHashId = termHashId val + let digest = preHashConstArrayDescription valHashId + DConstArrayTerm digest keyFingerprint valHashId -- {-# INLINE describe #-} @@ -4640,6 +4852,9 @@ instance Interned (Term t) where go (UFromFPOrTerm d mode arg) = FromFPOrTerm' info d mode arg go (UToFPTerm mode (arg :: Term a) _ _) = goPhantomToFP info getPhantomDict mode arg + go (USelectTerm arr key) = SelectTerm' info arr key + go (UStoreTerm arr key val) = StoreTerm' info arr key val + go (UConstArrayTerm proxy val) = ConstArrayTerm' info proxy val {-# INLINE go #-} -- {-# INLINE identify #-} @@ -4694,6 +4909,9 @@ instance Interned (Term t) where descriptionDigest (DFromIntegralTerm h _) = h descriptionDigest (DFromFPOrTerm h _ _ _) = h descriptionDigest (DToFPTerm h _ _) = h + descriptionDigest (DSelectTerm h _ _) = h + descriptionDigest (DStoreTerm h _ _ _) = h + descriptionDigest (DConstArrayTerm h _ _) = h -- {-# INLINE descriptionDigest #-} {-# NOINLINE goPhantomCon #-} @@ -5087,6 +5305,18 @@ fullReconstructTerm (FromFPOrTerm d r arg) = fullReconstructTerm3 curThreadFromFPOrTerm d r arg fullReconstructTerm (ToFPTerm r arg _ _) = fullReconstructTerm2 curThreadToFPTerm r arg +fullReconstructTerm (SelectTerm (arr :: Term arr) key) = withPrim @arr $ do + arr' <- fullReconstructTerm arr + key' <- fullReconstructTerm key + intern $ USelectTerm arr' key' +fullReconstructTerm (StoreTerm arr key val) = do + arr' <- fullReconstructTerm arr + key' <- fullReconstructTerm key + val' <- fullReconstructTerm val + intern $ UStoreTerm arr' key' val' +fullReconstructTerm (ConstArrayTerm pkey val) = do + val' <- fullReconstructTerm val + intern $ UConstArrayTerm pkey val' toCurThreadImpl :: forall t. WeakThreadId -> Term t -> IO (Term t) toCurThreadImpl tid t | termThreadId t == tid = return t @@ -5505,6 +5735,36 @@ curThreadToFPTerm :: curThreadToFPTerm r f = intern $ UToFPTerm r f (Proxy @eb) (Proxy @sb) {-# INLINE curThreadToFPTerm #-} +-- | Construct and internalizing a 'SelectTerm'. +curThreadSelectTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + IO (Term v) +curThreadSelectTerm arr key = withPrim @(Array k v) $ do + intern $ USelectTerm arr key +{-# INLINE curThreadSelectTerm #-} + +curThreadStoreTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + Term v -> + IO (Term (Array k v)) +curThreadStoreTerm arr key val = intern $ UStoreTerm arr key val +{-# INLINE curThreadStoreTerm #-} + +curThreadConstArrayTerm :: + forall k v. + SupportedPrim (Array k v) => + Proxy k -> + Term v -> + IO (Term (Array k v)) +curThreadConstArrayTerm pkey val = intern $ UConstArrayTerm pkey val +{-# INLINE curThreadConstArrayTerm #-} + inCurThread1 :: forall a b. (Term a -> IO (Term b)) -> @@ -6045,6 +6305,37 @@ toFPTerm :: toFPTerm = unsafeInCurThread2 curThreadToFPTerm {-# NOINLINE toFPTerm #-} +-- | Construct and internalizing a 'SelectTerm'. +selectTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + Term v +selectTerm = unsafeInCurThread2 curThreadSelectTerm +{-# NOINLINE selectTerm #-} + +-- | Construct and internalizing a 'StoreTerm'. +storeTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + Term v -> + Term (Array k v) +storeTerm = unsafeInCurThread3 curThreadStoreTerm +{-# NOINLINE storeTerm #-} + +-- | Construct and internalizing a 'ConstArrayTerm'. +constArrayTerm :: + forall k v. + SupportedPrim (Array k v) => + Proxy k -> + Term v -> + Term (Array k v) +constArrayTerm pkey = unsafeInCurThread1 $ curThreadConstArrayTerm pkey +{-# NOINLINE constArrayTerm #-} + -- Support for boolean type defaultValueForBool :: Bool defaultValueForBool = False @@ -6692,7 +6983,6 @@ instance SupportedPrim Bool where symSBVName symbol _ = show symbol symSBVTerm = sbvFresh withPrim r = r - parseSMTModelResult _ = parseScalarSMTModelResult id castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -6711,6 +7001,7 @@ instance SupportedNonFuncPrim Bool where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @Bool withNonFuncPrim r = r + sbvToCon = id data PhantomDict a where PhantomDict :: (SupportedPrim a) => PhantomDict a @@ -6807,7 +7098,6 @@ instance SupportedPrim Integer where conSBVTerm n = fromInteger n symSBVName symbol _ = show symbol symSBVTerm name = sbvFresh name - parseSMTModelResult _ = parseScalarSMTModelResult id castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -6826,6 +7116,7 @@ instance SupportedNonFuncPrim Integer where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @Integer withNonFuncPrim r = r + sbvToCon = id pevalITEBVTerm :: forall bv n. @@ -6955,9 +7246,6 @@ instance (KnownNat w, 1 <= w) => SupportedPrim (IntN w) where symSBVTerm name = bvIsNonZeroFromGEq1 (Proxy @w) $ sbvFresh name withPrim r = bvIsNonZeroFromGEq1 (Proxy @w) r {-# INLINE withPrim #-} - parseSMTModelResult _ cv = - withPrim @(IntN w) $ - parseScalarSMTModelResult (\(x :: SBV.IntN w) -> fromIntegral x) cv castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -6988,6 +7276,7 @@ instance (KnownNat w, 1 <= w) => SupportedNonFuncPrim (IntN w) where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @(IntN w) withNonFuncPrim r = bvIsNonZeroFromGEq1 (Proxy @w) r + sbvToCon = withPrim @(IntN w) fromIntegral -- Unsigned BV instance (KnownNat w, 1 <= w) => SupportedPrimConstraint (WordN w) where @@ -7014,9 +7303,6 @@ instance (KnownNat w, 1 <= w) => SupportedPrim (WordN w) where symSBVTerm name = bvIsNonZeroFromGEq1 (Proxy @w) $ sbvFresh name withPrim r = bvIsNonZeroFromGEq1 (Proxy @w) r {-# INLINE withPrim #-} - parseSMTModelResult _ cv = - withPrim @(WordN w) $ - parseScalarSMTModelResult (\(x :: SBV.WordN w) -> fromIntegral x) cv castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -7035,6 +7321,7 @@ instance (KnownNat w, 1 <= w) => SupportedNonFuncPrim (WordN w) where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @(WordN w) withNonFuncPrim r = bvIsNonZeroFromGEq1 (Proxy @w) r + sbvToCon = withPrim @(WordN w) fromIntegral -- FP instance (ValidFP eb sb) => SupportedPrimConstraint (FP eb sb) where @@ -7067,9 +7354,6 @@ instance (ValidFP eb sb) => SupportedPrim (FP eb sb) where conSBVTerm (FP fp) = SBV.literal fp symSBVName symbol _ = show symbol symSBVTerm name = sbvFresh name - parseSMTModelResult _ cv = - withPrim @(FP eb sb) $ - parseScalarSMTModelResult (\(x :: SBV.FloatingPoint eb sb) -> coerce x) cv funcDummyConstraint _ = SBV.sTrue -- Workaround for sbv#702. @@ -7101,6 +7385,7 @@ instance (ValidFP eb sb) => SupportedNonFuncPrim (FP eb sb) where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @(FP eb sb) withNonFuncPrim r = r + sbvToCon = coerce -- FPRoundingMode instance SupportedPrimConstraint FPRoundingMode @@ -7122,17 +7407,6 @@ instance SupportedPrim FPRoundingMode where conSBVTerm RTZ = SBV.sRTZ symSBVName symbol _ = show symbol symSBVTerm name = sbvFresh name - parseSMTModelResult _ cv = - withPrim @(FPRoundingMode) $ - parseScalarSMTModelResult - ( \(x :: SBV.RoundingMode) -> case x of - SBV.RoundNearestTiesToEven -> RNE - SBV.RoundNearestTiesToAway -> RNA - SBV.RoundTowardPositive -> RTP - SBV.RoundTowardNegative -> RTN - SBV.RoundTowardZero -> RTZ - ) - cv castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -7151,6 +7425,12 @@ instance SupportedNonFuncPrim FPRoundingMode where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @FPRoundingMode withNonFuncPrim r = r + sbvToCon mode = case mode of + SBV.RoundNearestTiesToEven -> RNE + SBV.RoundNearestTiesToAway -> RNA + SBV.RoundTowardPositive -> RTP + SBV.RoundTowardNegative -> RTN + SBV.RoundTowardZero -> RTZ -- AlgReal @@ -7169,9 +7449,6 @@ instance SupportedPrim AlgReal where conSBVTerm = SBV.literal . toSBVAlgReal symSBVName symbol _ = show symbol symSBVTerm name = sbvFresh name - parseSMTModelResult _ cv = - withPrim @AlgReal $ - parseScalarSMTModelResult fromSBVAlgReal cv castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -7190,6 +7467,92 @@ instance SupportedNonFuncPrim AlgReal where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @AlgReal withNonFuncPrim r = r + sbvToCon = fromSBVAlgReal + +-- Array + +pevalSelectTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + Term v +pevalSelectTerm = selectTerm -- TODO: perform optimisation + +pevalStoreTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + Term v -> + Term (Array k v) +pevalStoreTerm = storeTerm -- TODO: perform optimisation + +pevalConstArrayTerm :: + forall k v. + SupportedPrim (Array k v) => + Proxy k -> + Term v -> + Term (Array k v) +pevalConstArrayTerm = constArrayTerm -- TODO: perform optimisation + +instance SupportedPrimConstraint (Array k v) where + type PrimConstraint (Array k v) = + ( SupportedNonFuncPrim k + , SupportedNonFuncPrim v + , SBVT.SymVal (NonFuncSBVBaseType k) + , SBVT.SymVal (NonFuncSBVBaseType v) + ) + +instance SBVRep (Array k v) where + type SBVType (Array k v) = SBV.SArray (NonFuncSBVBaseType k) (NonFuncSBVBaseType v) + +instance + ( SupportedNonFuncPrim k + , SupportedNonFuncPrim v + ) => SupportedPrim (Array k v) where + defaultValue = Array mempty defaultValue + pevalITETerm = pevalITEBasicTerm + pevalEqTerm = pevalDefaultEqTerm + pevalDistinctTerm = pevalGeneralDistinct + conSBVTerm (Array entries def) = withNonFuncPrim @(Array k v) $ do + let root = SBV.constArray $ conSBVTerm def + let foldlWithKeyBy acc xs f = HM.foldlWithKey' f acc xs + foldlWithKeyBy root entries $ \acc key val -> do + SBV.writeArray acc (conSBVTerm key) (conSBVTerm val) + symSBVName x _ = show x + symSBVTerm = withNonFuncPrim @(Array k v) $ sbvFresh + withPrim = withNonFuncPrim @(Array k v) + sbvEq = withPrim @(Array k v) (SBV..==) + sbvDistinct = withPrim @(Array k v) $ SBV.distinct . toList + castTypedSymbol :: + forall knd' knd. + IsSymbolKind knd' => + TypedSymbol knd (Array k v) -> + Maybe (TypedSymbol knd' (Array k v)) + castTypedSymbol = pure . case decideSymbolKind @knd' of + Left HRefl -> TypedSymbol . unTypedSymbol + Right HRefl -> TypedSymbol . unTypedSymbol + funcDummyConstraint _ = SBV.sTrue + +instance + ( SupportedNonFuncPrim k, Ord k, Typeable k, Hashable k, Show k + , SupportedNonFuncPrim v, Ord v, Typeable v, Hashable v, Show v + ) => NonFuncSBVRep (Array k v) where + type NonFuncSBVBaseType (Array k v) = SBV.ArrayModel (NonFuncSBVBaseType k) (NonFuncSBVBaseType v) + +instance + ( SupportedNonFuncPrim k + , SupportedNonFuncPrim v + ) => SupportedNonFuncPrim (Array k v) where + conNonFuncSBVTerm = conSBVTerm + symNonFuncSBVTerm = withNonFuncPrim @(Array k v) sbvFresh + withNonFuncPrim = withNonFuncPrim @k $ withNonFuncPrim @v $ id + sbvToCon (SBV.ArrayModel tbl def) = do + -- NOTE: We reverse the list as later elements should take precedence. + let tbl' = HM.fromList . reverse . fmap (bimap sbvToCon sbvToCon) $ tbl + let def' = sbvToCon def + Array tbl' def' -- Bitwise diff --git a/src/Grisette/Internal/SymPrim/Prim/Pattern.hs b/src/Grisette/Internal/SymPrim/Prim/Pattern.hs index cc0c2f6e1..e5cbd6e2c 100644 --- a/src/Grisette/Internal/SymPrim/Prim/Pattern.hs +++ b/src/Grisette/Internal/SymPrim/Prim/Pattern.hs @@ -1,6 +1,8 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} -- | @@ -19,6 +21,7 @@ where import Data.Foldable (Foldable (toList)) import Grisette.Internal.SymPrim.Prim.Internal.Term ( Term, + SupportedPrim (withPrim), pattern AbsNumTerm, pattern AddNumTerm, pattern AndBitsTerm, @@ -67,10 +70,13 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term pattern SymTerm, pattern ToFPTerm, pattern XorBitsTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, ) import Grisette.Internal.SymPrim.Prim.SomeTerm (SomeTerm (SomeTerm)) -subTermsViewPattern :: Term a -> Maybe [SomeTerm] +subTermsViewPattern :: forall a. Term a -> Maybe [SomeTerm] subTermsViewPattern (ConTerm _) = return [] subTermsViewPattern (SymTerm _) = return [] subTermsViewPattern (ForallTerm _ t) = return [SomeTerm t] @@ -121,6 +127,9 @@ subTermsViewPattern (FPFMATerm rd t1 t2 t3) = subTermsViewPattern (FromIntegralTerm t) = return [SomeTerm t] subTermsViewPattern (FromFPOrTerm t1 rd t2) = return [SomeTerm t1, SomeTerm rd, SomeTerm t2] subTermsViewPattern (ToFPTerm rd t1 _ _) = return [SomeTerm rd, SomeTerm t1] +subTermsViewPattern (SelectTerm (t1 :: Term arr) t2) = withPrim @arr $ return [SomeTerm t1, SomeTerm t2] +subTermsViewPattern (StoreTerm t1 t2 t3) = withPrim @a $ return [SomeTerm t1, SomeTerm t2, SomeTerm t3] +subTermsViewPattern (ConstArrayTerm _ t1) = withPrim @a $ return [SomeTerm t1] -- | Extract all the subterms of a term. pattern SubTerms :: [SomeTerm] -> Term a diff --git a/src/Grisette/Internal/SymPrim/SymArray.hs b/src/Grisette/Internal/SymPrim/SymArray.hs new file mode 100644 index 000000000..c5391142b --- /dev/null +++ b/src/Grisette/Internal/SymPrim/SymArray.hs @@ -0,0 +1,163 @@ +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveLift #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | +-- Module : Grisette.Internal.SymPrim.Array +-- Copyright : (c) Sirui Lu 2021-2023 +-- License : BSD-3-Clause (see the LICENSE file) +-- +-- Maintainer : siruilu@cs.washington.edu +-- Stability : Experimental +-- Portability : GHC only +module Grisette.Internal.SymPrim.SymArray + ( SymArray (..) + , const + , select + , store + ) where + +import Control.DeepSeq (NFData) +import Data.Binary qualified as Binary +import Data.Bytes.Serial (Serial (deserialize, serialize)) +import Data.Data (Proxy(Proxy)) +import Data.Serialize qualified as Cereal +import Data.String (IsString (fromString)) +import Grisette.Internal.SymPrim.Array (Array) +import Grisette.Internal.SymPrim.Prim.Internal.Term + ( Term + , SupportedNonFuncPrim + , ConRep (ConType) + , SymRep (SymType) + , LinkedRep (underlyingTerm, wrapTerm) + , conTerm + , typedConstantSymbol + , symTerm + , pformatTerm + , pattern ConTerm + , pattern SelectTerm + , pattern StoreTerm + , pattern ConstArrayTerm + ) +import Grisette.Internal.SymPrim.Prim.Internal.Serialize () +import Grisette.Internal.Core.Data.Class.Solvable (Solvable (con, sym, conView), ssym) +import GHC.Generics (Generic) +import Language.Haskell.TH.Syntax (Lift) +import Prelude (Show (show), Maybe (Just, Nothing), (<$>), ($), (.)) + +newtype SymArray k v = SymArray { underlyingArrayTerm :: Term (Array (ConType k) (ConType v)) } + deriving (Lift, NFData, Generic) + +instance ConRep (SymArray k v) where + type ConType (SymArray k v) = Array (ConType k) (ConType v) + +instance (SupportedNonFuncPrim k, SupportedNonFuncPrim v) => SymRep (Array k v) where + type SymType (Array k v) = SymArray (SymType k) (SymType v) + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + LinkedRep (Array ck cv) (SymArray sk sv) where + underlyingTerm = underlyingArrayTerm + wrapTerm = SymArray + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + Solvable (Array ck cv) (SymArray sk sv) where + con = wrapTerm . conTerm + sym = wrapTerm . symTerm . typedConstantSymbol + conView v = case underlyingTerm v of + ConTerm t -> Just t + _ -> Nothing + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + IsString (SymArray sk sv) where + fromString = ssym . fromString + +instance Show (SymArray sk sv) where + show = pformatTerm . underlyingArrayTerm + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + Serial (SymArray sk sv) where + serialize = serialize . underlyingTerm + deserialize = wrapTerm <$> deserialize + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + Cereal.Serialize (SymArray sk sv) where + put = serialize + get = deserialize + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + Binary.Binary (SymArray sk sv) where + put = serialize + get = deserialize + +const + :: forall k v + . SupportedNonFuncPrim (ConType k) + => SupportedNonFuncPrim (ConType v) + => LinkedRep (ConType k) k + => LinkedRep (ConType v) v + => v + -> SymArray k v +const val = wrapTerm $ ConstArrayTerm Proxy (underlyingTerm val) + +select + :: forall k v + . SupportedNonFuncPrim (ConType k) + => SupportedNonFuncPrim (ConType v) + => LinkedRep (ConType k) k + => LinkedRep (ConType v) v + => SymArray k v + -> k + -> v +select arr key = wrapTerm $ SelectTerm (underlyingTerm arr) (underlyingTerm key) + +store + :: forall k v + . SupportedNonFuncPrim (ConType k) + => SupportedNonFuncPrim (ConType v) + => LinkedRep (ConType k) k + => LinkedRep (ConType v) v + => SymArray k v + -> k + -> v + -> SymArray k v +store arr key val = do + wrapTerm $ StoreTerm (underlyingTerm arr) (underlyingTerm key) (underlyingTerm val) diff --git a/src/Grisette/SymPrim.hs b/src/Grisette/SymPrim.hs index 6ca8637d5..97684f6a1 100644 --- a/src/Grisette/SymPrim.hs +++ b/src/Grisette/SymPrim.hs @@ -290,6 +290,9 @@ module Grisette.SymPrim pattern FromIntegralTerm, pattern FromFPOrTerm, pattern ToFPTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, pattern SubTerms, ) where @@ -416,6 +419,9 @@ import Grisette.Internal.SymPrim.Prim.Term pattern SupportedTypedSymbol, pattern SymTerm, pattern ToFPTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, pattern XorBitsTerm, ) import Grisette.Internal.SymPrim.Prim.TermUtils diff --git a/stack-9.10.yaml b/stack-9.10.yaml index bb503c62e..e64a01b13 100644 --- a/stack-9.10.yaml +++ b/stack-9.10.yaml @@ -33,13 +33,11 @@ packages: # These entries can reference officially published versions as well as # forks / in-progress versions pinned to a git hash. For example: # -# extra-deps: -# - acme-missiles-0.3 -# - git: https://github.com/commercialhaskell/stack.git -# commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a # # Override default flag values for local packages and extra-deps # flags: {} +extra-deps: +- sbv-13.5 # Extra package databases containing global packages # extra-package-dbs: [] diff --git a/stack.yaml.lock b/stack.yaml.lock index a1b57cf9c..9d1442db9 100644 --- a/stack.yaml.lock +++ b/stack.yaml.lock @@ -3,7 +3,14 @@ # For more information, please see the documentation at: # https://docs.haskellstack.org/en/stable/topics/lock_files -packages: [] +packages: +- completed: + hackage: sbv-13.5@sha256:00aea86ad09dcefb5d5286dd68bfb894709d2874052bcdadf80af3add9596af7,26690 + pantry-tree: + sha256: 65d23f68088ce549e8aa8ae77a35f2ec7c6491facbc4aeae8c2317101d9ddcda + size: 93933 + original: + hackage: sbv-13.5 snapshots: - completed: sha256: 7a26eba54b469fc72b1e37b881dfec480a2c1cb0636136f96aec7d81be6c762f