diff --git a/CHANGELOG.md b/CHANGELOG.md index c70376c3..32b692e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,9 +14,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Improved the partial evaluation for bit vectors. ([#176](https://github.com/lsrcz/grisette/pull/176)) - Added `symRotateNegated` and `symShiftNegated`. ([#181](https://github.com/lsrcz/grisette/pull/181)) - Added `mrg` and `sym` variants for all reasonable operations from - `Control.Monad`, `Control.Applicative`, `Data.Foldable`, `Data.List`, and `Data.Traversable`. - ([#182](https://github.com/lsrcz/grisette/pull/182)) + `Control.Monad`, `Control.Applicative`, `Data.Foldable`, `Data.List`, `Data.Traversable`, + and `Data.Function`. + ([#182](https://github.com/lsrcz/grisette/pull/182), + [#186](https://github.com/lsrcz/grisette/pull/186)) - Added `mrgIfPropagatedStrategy`. ([#184](https://github.com/lsrcz/grisette/pull/184)) +- Added instances for `Const`. ([#186](https://github.com/lsrcz/grisette/pull/186)) +- Added support for [microlens](https://hackage.haskell.org/package/microlens). ([#186](https://github.com/lsrcz/grisette/pull/186)) - Added `freshString`. ([#188](https://github.com/lsrcz/grisette/pull/188)) ### Fixed diff --git a/flake.nix b/flake.nix index 37913076..a2905241 100644 --- a/flake.nix +++ b/flake.nix @@ -8,14 +8,14 @@ let pkgs = nixpkgs.legacyPackages.${system}; - hPkgs = pkgs.haskell.packages."ghc982"; + hPkgs = pkgs.haskell.packages."ghc964"; myDevTools = [ hPkgs.ghc # GHC compiler in the desired version (will be available on PATH) # hPkgs.ghcid # Continuous terminal Haskell compile checker # hPkgs.ormolu # Haskell formatter - # hPkgs.hlint # Haskell codestyle checker - # hPkgs.haskell-language-server # LSP server for editor + hPkgs.hlint # Haskell codestyle checker + hPkgs.haskell-language-server # LSP server for editor hPkgs.cabal-install stack-wrapped # (pkgs.ihaskell.override { diff --git a/grisette.cabal b/grisette.cabal index 877cb7f7..32cecfa7 100644 --- a/grisette.cabal +++ b/grisette.cabal @@ -135,12 +135,15 @@ library Grisette.Lib.Data.Bool Grisette.Lib.Data.Either Grisette.Lib.Data.Foldable + Grisette.Lib.Data.Function Grisette.Lib.Data.Functor + Grisette.Lib.Data.Functor.Const Grisette.Lib.Data.Functor.Sum Grisette.Lib.Data.List Grisette.Lib.Data.Maybe Grisette.Lib.Data.Traversable Grisette.Lib.Data.Tuple + Grisette.Lib.Lens.Micro Grisette.Qualified.ParallelUnionDo Grisette.Utils Grisette.Utils.Parameterized @@ -161,6 +164,8 @@ library , hashtables >=1.2.3.4 && <1.4 , intern >=0.9.2 && <0.10 , loch-th >=0.2.2 && <0.3 + , microlens + , microlens-th , mtl >=2.2.2 && <2.4 , parallel >=3.2.2.0 && <3.3 , prettyprinter >=1.5.0 && <1.8 @@ -200,6 +205,8 @@ test-suite doctest , hashtables >=1.2.3.4 && <1.4 , intern >=0.9.2 && <0.10 , loch-th >=0.2.2 && <0.3 + , microlens + , microlens-th , mtl >=2.2.2 && <2.4 , parallel >=3.2.2.0 && <3.3 , prettyprinter >=1.5.0 && <1.8 @@ -269,9 +276,11 @@ test-suite spec Grisette.Lib.Control.Monad.Trans.State.StrictTests Grisette.Lib.Control.MonadTests Grisette.Lib.Data.FoldableTests + Grisette.Lib.Data.FunctionTests Grisette.Lib.Data.FunctorTests Grisette.Lib.Data.ListTests Grisette.Lib.Data.TraversableTests + Grisette.Lib.Lens.MicroTests Grisette.TestUtil.NoMerge Grisette.TestUtil.PrettyPrint Grisette.TestUtil.SymbolicAssertion @@ -293,6 +302,8 @@ test-suite spec , hashtables >=1.2.3.4 && <1.4 , intern >=0.9.2 && <0.10 , loch-th >=0.2.2 && <0.3 + , microlens + , microlens-th , mtl >=2.2.2 && <2.4 , parallel >=3.2.2.0 && <3.3 , prettyprinter >=1.5.0 && <1.8 diff --git a/package.yaml b/package.yaml index 8b316476..698ec0a4 100644 --- a/package.yaml +++ b/package.yaml @@ -50,6 +50,8 @@ dependencies: - prettyprinter >= 1.5.0 && < 1.8 - async >= 2.2.2 && < 2.3 - stm >= 2.5 && < 2.6 + - microlens + - microlens-th flags: { diff --git a/src/Grisette/Core/Control/Monad/UnionM.hs b/src/Grisette/Core/Control/Monad/UnionM.hs index cf9076eb..c91c4afa 100644 --- a/src/Grisette/Core/Control/Monad/UnionM.hs +++ b/src/Grisette/Core/Control/Monad/UnionM.hs @@ -534,6 +534,12 @@ instance (LogicalOp a, Mergeable a) => LogicalOp (UnionM a) where symXor = unionMBinOp symXor symImplies = unionMBinOp symImplies +instance (Monoid a, Mergeable a) => Monoid (UnionM a) where + mempty = mrgSingle mempty + +instance (Monoid a, Mergeable a) => Semigroup (UnionM a) where + (<>) = unionMBinOp (<>) + instance (Solvable c t, Mergeable t) => Solvable c (UnionM t) where con = mrgSingle . con {-# INLINE con #-} diff --git a/src/Grisette/Core/Data/Class/BitVector.hs b/src/Grisette/Core/Data/Class/BitVector.hs index 0601c000..24636770 100644 --- a/src/Grisette/Core/Data/Class/BitVector.hs +++ b/src/Grisette/Core/Data/Class/BitVector.hs @@ -27,6 +27,7 @@ module Grisette.Core.Data.Class.BitVector ) where +import Data.Functor.Const (Const (Const)) import Data.Proxy (Proxy (Proxy)) import GHC.TypeNats (KnownNat, type (+), type (-), type (<=)) import Grisette.Utils.Parameterized @@ -129,6 +130,25 @@ class BV bv where a -> bv +instance (BV a) => BV (Const a b) where + bvConcat (Const a) (Const b) = Const (bvConcat a b) + {-# INLINE bvConcat #-} + + bvZext n (Const a) = Const (bvZext n a) + {-# INLINE bvZext #-} + + bvSext n (Const a) = Const (bvSext n a) + {-# INLINE bvSext #-} + + bvExt n (Const a) = Const (bvExt n a) + {-# INLINE bvExt #-} + + bvSelect i j (Const a) = Const (bvSelect i j a) + {-# INLINE bvSelect #-} + + bv i w = Const (bv i w) + {-# INLINE bv #-} + -- | Slicing out a smaller bit vector from a larger one, extract a slice from -- bit @i@ down to @j@. -- diff --git a/src/Grisette/Core/Data/Class/EvaluateSym.hs b/src/Grisette/Core/Data/Class/EvaluateSym.hs index 67d1cf0b..9e45457a 100644 --- a/src/Grisette/Core/Data/Class/EvaluateSym.hs +++ b/src/Grisette/Core/Data/Class/EvaluateSym.hs @@ -34,6 +34,7 @@ import Control.Monad.Trans.Maybe (MaybeT (MaybeT)) import qualified Control.Monad.Writer.Lazy as WriterLazy import qualified Control.Monad.Writer.Strict as WriterStrict import qualified Data.ByteString as B +import Data.Functor.Const (Const (Const)) import Data.Functor.Sum (Sum) import Data.Int (Int16, Int32, Int64, Int8) import Data.Maybe (fromJust) @@ -274,3 +275,8 @@ instance (EvaluateSym' a, EvaluateSym' b) => EvaluateSym' (a :+: b) where instance (EvaluateSym' a, EvaluateSym' b) => EvaluateSym' (a :*: b) where evaluateSym' fillDefault model (a :*: b) = evaluateSym' fillDefault model a :*: evaluateSym' fillDefault model b + +-- Const +instance (EvaluateSym a) => EvaluateSym (Const a b) where + evaluateSym fillDefault model (Const a) = + Const $ evaluateSym fillDefault model a diff --git a/src/Grisette/Core/Data/Class/ExtractSymbolics.hs b/src/Grisette/Core/Data/Class/ExtractSymbolics.hs index b19aa1cd..bfafb513 100644 --- a/src/Grisette/Core/Data/Class/ExtractSymbolics.hs +++ b/src/Grisette/Core/Data/Class/ExtractSymbolics.hs @@ -33,6 +33,7 @@ import Control.Monad.Trans.Maybe (MaybeT (MaybeT)) import qualified Control.Monad.Writer.Lazy as WriterLazy import qualified Control.Monad.Writer.Strict as WriterStrict import qualified Data.ByteString as B +import Data.Functor.Const (Const (Const)) import Data.Functor.Sum (Sum) import Data.Int (Int16, Int32, Int64, Int8) import qualified Data.Text as T @@ -311,3 +312,7 @@ instance ExtractSymbolics' (a :*: b) where extractSymbolics' (l :*: r) = extractSymbolics' l <> extractSymbolics' r + +-- Const +instance (ExtractSymbolics a) => ExtractSymbolics (Const a b) where + extractSymbolics (Const v) = extractSymbolics v diff --git a/src/Grisette/Core/Data/Class/GPretty.hs b/src/Grisette/Core/Data/Class/GPretty.hs index dab16eba..e9b9156f 100644 --- a/src/Grisette/Core/Data/Class/GPretty.hs +++ b/src/Grisette/Core/Data/Class/GPretty.hs @@ -28,6 +28,7 @@ import qualified Control.Monad.Writer.Lazy as WriterLazy import qualified Control.Monad.Writer.Strict as WriterStrict import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as C +import Data.Functor.Const (Const) import Data.Functor.Sum (Sum) import Data.Int (Int16, Int32, Int64, Int8) import Data.String (IsString (fromString)) @@ -322,6 +323,12 @@ instance (GPretty (m a)) => GPretty (IdentityT m a) where gprettyPrec 11 a ] +-- Const +deriving via + (Default (Const a b)) + instance + (GPretty a) => GPretty (Const a b) + -- Prettyprint #define GPRETTY_SYM_SIMPLE(symtype) \ instance GPretty symtype where \ diff --git a/src/Grisette/Core/Data/Class/GenSym.hs b/src/Grisette/Core/Data/Class/GenSym.hs index c673ffa4..ae13c8b0 100644 --- a/src/Grisette/Core/Data/Class/GenSym.hs +++ b/src/Grisette/Core/Data/Class/GenSym.hs @@ -97,6 +97,7 @@ import qualified Control.Monad.Writer.Lazy as WriterLazy import qualified Control.Monad.Writer.Strict as WriterStrict import Data.Bifunctor (Bifunctor (first)) import qualified Data.ByteString as B +import Data.Functor.Const (Const (Const)) import Data.Hashable (Hashable) import Data.Int (Int16, Int32, Int64, Int8) import Data.String (IsString (fromString)) @@ -1720,3 +1721,15 @@ instance where go (UnionSingle x) = fresh x go (UnionIf _ _ _ t f) = mrgIf <$> simpleFresh () <*> go t <*> go f + +instance (GenSym spec a) => GenSym spec (Const a b) where + fresh spec = do + u <- fresh spec + return $ do + a <- u + mrgSingle $ Const a + {-# INLINE fresh #-} + +instance (GenSymSimple spec a) => GenSymSimple spec (Const a b) where + simpleFresh spec = Const <$> simpleFresh spec + {-# INLINE simpleFresh #-} diff --git a/src/Grisette/Core/Data/Class/ITEOp.hs b/src/Grisette/Core/Data/Class/ITEOp.hs index 5add227d..757ab2b8 100644 --- a/src/Grisette/Core/Data/Class/ITEOp.hs +++ b/src/Grisette/Core/Data/Class/ITEOp.hs @@ -17,6 +17,7 @@ module Grisette.Core.Data.Class.ITEOp ) where +import Data.Functor.Const (Const (Const)) import GHC.TypeNats (KnownNat, type (<=)) import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term ( LinkedRep, @@ -75,3 +76,7 @@ ITEOP_BV(SymWordN) ITEOP_FUN(=~>, SymTabularFun) ITEOP_FUN(-~>, SymGeneralFun) #endif + +instance (ITEOp a) => ITEOp (Const a b) where + symIte c (Const t) (Const f) = Const $ symIte c t f + {-# INLINE symIte #-} diff --git a/src/Grisette/Core/Data/Class/LogicalOp.hs b/src/Grisette/Core/Data/Class/LogicalOp.hs index 22b298f8..b42c1afc 100644 --- a/src/Grisette/Core/Data/Class/LogicalOp.hs +++ b/src/Grisette/Core/Data/Class/LogicalOp.hs @@ -3,6 +3,7 @@ module Grisette.Core.Data.Class.LogicalOp ) where +import Data.Functor.Const (Const (Const)) import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool ( pevalAndTerm, pevalImplyTerm, @@ -104,3 +105,15 @@ instance LogicalOp SymBool where symNot (SymBool v) = SymBool $ pevalNotTerm v (SymBool l) `symXor` (SymBool r) = SymBool $ pevalXorTerm l r (SymBool l) `symImplies` (SymBool r) = SymBool $ pevalImplyTerm l r + +instance (LogicalOp a) => LogicalOp (Const a b) where + (.||) (Const a) (Const b) = Const $ a .|| b + {-# INLINE (.||) #-} + (.&&) (Const a) (Const b) = Const $ a .&& b + {-# INLINE (.&&) #-} + symNot (Const a) = Const $ symNot a + {-# INLINE symNot #-} + symXor (Const a) (Const b) = Const $ symXor a b + {-# INLINE symXor #-} + symImplies (Const a) (Const b) = Const $ symImplies a b + {-# INLINE symImplies #-} diff --git a/src/Grisette/Core/Data/Class/Mergeable.hs b/src/Grisette/Core/Data/Class/Mergeable.hs index 73a1bcfc..182ff1d1 100644 --- a/src/Grisette/Core/Data/Class/Mergeable.hs +++ b/src/Grisette/Core/Data/Class/Mergeable.hs @@ -83,6 +83,7 @@ import Data.Functor.Classes eq1, showsPrec1, ) +import Data.Functor.Const (Const (Const, getConst)) import Data.Functor.Sum (Sum (InL, InR)) import Data.Int (Int16, Int32, Int64, Int8) import Data.Kind (Type) @@ -1038,3 +1039,14 @@ instance (Mergeable' a, Mergeable' b) => Mergeable' (a :+: b) where else wrapStrategy rootStrategy' R1 (\case (R1 v) -> v; _ -> undefined) ) {-# INLINE rootStrategy' #-} + +deriving via + (Default (Const a b)) + instance + (Mergeable a) => Mergeable (Const a b) + +deriving via (Default1 (Const a)) instance (Mergeable a) => Mergeable1 (Const a) + +instance Mergeable2 Const where + liftRootStrategy2 sa _ = wrapStrategy sa Const getConst + {-# INLINE liftRootStrategy2 #-} diff --git a/src/Grisette/Core/Data/Class/SEq.hs b/src/Grisette/Core/Data/Class/SEq.hs index 5aabb991..6adac873 100644 --- a/src/Grisette/Core/Data/Class/SEq.hs +++ b/src/Grisette/Core/Data/Class/SEq.hs @@ -35,6 +35,7 @@ import Control.Monad.Trans.Maybe (MaybeT (MaybeT)) import qualified Control.Monad.Writer.Lazy as WriterLazy import qualified Control.Monad.Writer.Strict as WriterStrict import qualified Data.ByteString as B +import Data.Functor.Const (Const) import Data.Functor.Sum (Sum) import Data.Int (Int16, Int32, Int64, Int8) import qualified Data.Text as T @@ -282,3 +283,8 @@ instance (SEq' a, SEq' b) => SEq' (a :*: b) where instance (Generic a, SEq' (Rep a)) => SEq (Default a) where Default l .== Default r = from l ..== from r {-# INLINE (.==) #-} + +deriving via + (Default (Const a b)) + instance + (SEq a) => SEq (Const a b) diff --git a/src/Grisette/Core/Data/Class/SOrd.hs b/src/Grisette/Core/Data/Class/SOrd.hs index c52010dc..76fe3726 100644 --- a/src/Grisette/Core/Data/Class/SOrd.hs +++ b/src/Grisette/Core/Data/Class/SOrd.hs @@ -39,6 +39,7 @@ import Control.Monad.Trans.Maybe (MaybeT (MaybeT)) import qualified Control.Monad.Writer.Lazy as WriterLazy import qualified Control.Monad.Writer.Strict as WriterStrict import qualified Data.ByteString as B +import Data.Functor.Const (Const (Const)) import Data.Functor.Sum (Sum) import Data.Int (Int16, Int32, Int64, Int8) import qualified Data.Text as T @@ -390,6 +391,19 @@ instance (SOrd a, Mergeable a) => SOrd (UnionM a) where y1 <- tryMerge y x1 `symCompare` y1 +-- | Const +instance (SOrd a) => SOrd (Const a b) where + (Const l) .<= (Const r) = l .<= r + {-# INLINE (.<=) #-} + (Const l) .< (Const r) = l .< r + {-# INLINE (.<) #-} + (Const l) .>= (Const r) = l .>= r + {-# INLINE (.>=) #-} + (Const l) .> (Const r) = l .> r + {-# INLINE (.>) #-} + (Const l) `symCompare` (Const r) = l `symCompare` r + {-# INLINE symCompare #-} + -- | Auxiliary class for 'SOrd' instance derivation class (SEq' f) => SOrd' f where -- | Auxiliary function for '(..<) derivation diff --git a/src/Grisette/Core/Data/Class/SafeDivision.hs b/src/Grisette/Core/Data/Class/SafeDivision.hs index 5173dc8e..0d0296b0 100644 --- a/src/Grisette/Core/Data/Class/SafeDivision.hs +++ b/src/Grisette/Core/Data/Class/SafeDivision.hs @@ -27,6 +27,7 @@ where import Control.Exception (ArithException (DivideByZero, Overflow, Underflow)) import Control.Monad.Except (MonadError (throwError)) +import Data.Functor.Const (Const (Const)) import Data.Int (Int16, Int32, Int64, Int8) import Data.Word (Word16, Word32, Word64, Word8) import GHC.TypeNats (KnownNat, type (<=)) @@ -290,3 +291,21 @@ instance SAFE_DIVISION_SYMBOLIC_FUNC2(safeDivMod, SymWordN, pevalDivIntegralTerm, pevalModIntegralTerm) SAFE_DIVISION_SYMBOLIC_FUNC2(safeQuotRem, SymWordN, pevalQuotIntegralTerm, pevalRemIntegralTerm) #endif + +instance (SafeDivision e a m) => SafeDivision e (Const a b) m where + safeDiv (Const c) (Const r) = mrgFmap Const $ safeDiv c r + {-# INLINE safeDiv #-} + safeMod (Const c) (Const r) = mrgFmap Const $ safeMod c r + {-# INLINE safeMod #-} + safeDivMod (Const c) (Const r) = do + (d, m) <- safeDivMod c r + mrgReturn (Const d, Const m) + {-# INLINE safeDivMod #-} + safeQuot (Const c) (Const r) = mrgFmap Const $ safeQuot c r + {-# INLINE safeQuot #-} + safeRem (Const c) (Const r) = mrgFmap Const $ safeRem c r + {-# INLINE safeRem #-} + safeQuotRem (Const c) (Const r) = do + (q, m) <- safeQuotRem c r + mrgReturn (Const q, Const m) + {-# INLINE safeQuotRem #-} diff --git a/src/Grisette/Core/Data/Class/SafeLinearArith.hs b/src/Grisette/Core/Data/Class/SafeLinearArith.hs index 06764d9f..192d14e3 100644 --- a/src/Grisette/Core/Data/Class/SafeLinearArith.hs +++ b/src/Grisette/Core/Data/Class/SafeLinearArith.hs @@ -26,6 +26,7 @@ where import Control.Exception (ArithException (DivideByZero, Overflow, Underflow)) import Control.Monad.Except (MonadError (throwError)) +import Data.Functor.Const (Const (Const)) import Data.Int (Int16, Int32, Int64, Int8) import Data.Word (Word16, Word32, Word64, Word8) import GHC.TypeNats (KnownNat, type (<=)) @@ -53,7 +54,7 @@ import Grisette.IR.SymPrim.Data.SymPrim SymInteger, SymWordN, ) -import Grisette.Lib.Control.Monad (mrgReturn) +import Grisette.Lib.Control.Monad (mrgFmap, mrgReturn) import Grisette.Lib.Control.Monad.Except (mrgThrowError) -- $setup @@ -224,3 +225,11 @@ instance (mrgSingle res) where res = ls - rs + +instance (SafeLinearArith e a m) => SafeLinearArith e (Const a b) m where + safeAdd (Const a) (Const b) = mrgFmap Const $ safeAdd a b + {-# INLINE safeAdd #-} + safeNeg (Const a) = mrgFmap Const $ safeNeg a + {-# INLINE safeNeg #-} + safeSub (Const a) (Const b) = mrgFmap Const $ safeSub a b + {-# INLINE safeSub #-} diff --git a/src/Grisette/Core/Data/Class/SafeSymRotate.hs b/src/Grisette/Core/Data/Class/SafeSymRotate.hs index 30df07be..f9dec465 100644 --- a/src/Grisette/Core/Data/Class/SafeSymRotate.hs +++ b/src/Grisette/Core/Data/Class/SafeSymRotate.hs @@ -21,6 +21,7 @@ module Grisette.Core.Data.Class.SafeSymRotate (SafeSymRotate (..)) where import Control.Exception (ArithException (Overflow)) import Control.Monad.Error.Class (MonadError) import Data.Bits (Bits (rotateL, rotateR), FiniteBits (finiteBitSize)) +import Data.Functor.Const (Const (Const)) import Data.Int (Int16, Int32, Int64, Int8) import Data.Word (Word16, Word32, Word64, Word8) import GHC.TypeLits (KnownNat, type (<=)) @@ -38,7 +39,7 @@ import Grisette.IR.SymPrim.Data.SymPrim ( SymIntN (SymIntN), SymWordN (SymWordN), ) -import Grisette.Lib.Control.Monad (mrgReturn) +import Grisette.Lib.Control.Monad (mrgFmap, mrgReturn) import Grisette.Lib.Control.Monad.Except (mrgThrowError) -- | Safe rotation operations. The operators will reject negative shift amounts. @@ -134,3 +135,9 @@ instance (r .< 0) (mrgThrowError Overflow) (mrgReturn $ SymIntN $ pevalRotateRightTerm ta tr) + +instance (SafeSymRotate e a m) => SafeSymRotate e (Const a b) m where + safeSymRotateL (Const a) (Const b) = mrgFmap Const $ safeSymRotateL a b + {-# INLINE safeSymRotateL #-} + safeSymRotateR (Const a) (Const b) = mrgFmap Const $ safeSymRotateR a b + {-# INLINE safeSymRotateR #-} diff --git a/src/Grisette/Core/Data/Class/SafeSymShift.hs b/src/Grisette/Core/Data/Class/SafeSymShift.hs index e597e7e2..2af24812 100644 --- a/src/Grisette/Core/Data/Class/SafeSymShift.hs +++ b/src/Grisette/Core/Data/Class/SafeSymShift.hs @@ -19,6 +19,7 @@ where import Control.Exception (ArithException (Overflow)) import Control.Monad.Error.Class (MonadError) import Data.Bits (Bits (shiftL, shiftR), FiniteBits (finiteBitSize)) +import Data.Functor.Const (Const (Const)) import Data.Int (Int16, Int32, Int64, Int8) import Data.Word (Word16, Word32, Word64, Word8) import GHC.TypeLits (KnownNat, type (<=)) @@ -40,7 +41,7 @@ import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bits pevalShiftRightTerm, ) import Grisette.IR.SymPrim.Data.SymPrim (SymIntN (SymIntN), SymWordN (SymWordN)) -import Grisette.Lib.Control.Monad (mrgReturn) +import Grisette.Lib.Control.Monad (mrgFmap, mrgReturn) import Grisette.Lib.Control.Monad.Except (mrgThrowError) -- | Safe version for `shiftL` or `shiftR`. @@ -185,3 +186,15 @@ instance (return $ SymIntN $ pevalShiftRightTerm ta ts) where bs = fromIntegral (finiteBitSize a) + +instance (SafeSymShift e a m) => SafeSymShift e (Const a b) m where + safeSymShiftL (Const a) (Const b) = mrgFmap Const $ safeSymShiftL a b + {-# INLINE safeSymShiftL #-} + safeSymShiftR (Const a) (Const b) = mrgFmap Const $ safeSymShiftR a b + {-# INLINE safeSymShiftR #-} + safeSymStrictShiftL (Const a) (Const b) = + mrgFmap Const $ safeSymStrictShiftL a b + {-# INLINE safeSymStrictShiftL #-} + safeSymStrictShiftR (Const a) (Const b) = + mrgFmap Const $ safeSymStrictShiftR a b + {-# INLINE safeSymStrictShiftR #-} diff --git a/src/Grisette/Core/Data/Class/SignConversion.hs b/src/Grisette/Core/Data/Class/SignConversion.hs index 58693060..46f7bf2c 100644 --- a/src/Grisette/Core/Data/Class/SignConversion.hs +++ b/src/Grisette/Core/Data/Class/SignConversion.hs @@ -1,10 +1,12 @@ {-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE UndecidableInstances #-} module Grisette.Core.Data.Class.SignConversion ( SignConversion (..), ) where +import Data.Functor.Const (Const (Const)) import Data.Int (Int16, Int32, Int64, Int8) import Data.Word (Word16, Word32, Word64, Word8) @@ -35,3 +37,10 @@ instance SignConversion Word64 Int64 where instance SignConversion Word Int where toSigned = fromIntegral toUnsigned = fromIntegral + +instance + (SignConversion au ai) => + SignConversion (Const au b) (Const ai b) + where + toSigned (Const a) = Const $ toSigned a + toUnsigned (Const a) = Const $ toUnsigned a diff --git a/src/Grisette/Core/Data/Class/SimpleMergeable.hs b/src/Grisette/Core/Data/Class/SimpleMergeable.hs index 9e680eb5..c9d59858 100644 --- a/src/Grisette/Core/Data/Class/SimpleMergeable.hs +++ b/src/Grisette/Core/Data/Class/SimpleMergeable.hs @@ -49,6 +49,7 @@ import Control.Monad.Trans.Cont (ContT (ContT)) import Control.Monad.Trans.Maybe (MaybeT (MaybeT)) import qualified Control.Monad.Writer.Lazy as WriterLazy import qualified Control.Monad.Writer.Strict as WriterStrict +import Data.Functor.Const (Const (Const)) import Data.Kind (Type) import GHC.Generics ( Generic (Rep, from, to), @@ -627,3 +628,17 @@ SIMPLE_MERGEABLE_FUN(-~>) -- Exception deriving via (Default AssertionError) instance SimpleMergeable AssertionError + +-- Const +deriving via + (Default (Const a b)) + instance + (SimpleMergeable a) => SimpleMergeable (Const a b) + +instance (SimpleMergeable a) => SimpleMergeable1 (Const a) where + liftMrgIte _ cond (Const l) (Const r) = Const $ mrgIte cond l r + {-# INLINE liftMrgIte #-} + +instance SimpleMergeable2 Const where + liftMrgIte2 ma _ cond (Const a1) (Const a2) = Const $ ma cond a1 a2 + {-# INLINE liftMrgIte2 #-} diff --git a/src/Grisette/Core/Data/Class/Solvable.hs b/src/Grisette/Core/Data/Class/Solvable.hs index 2f664db8..d02a18da 100644 --- a/src/Grisette/Core/Data/Class/Solvable.hs +++ b/src/Grisette/Core/Data/Class/Solvable.hs @@ -21,6 +21,7 @@ module Grisette.Core.Data.Class.Solvable ) where +import Data.Functor.Const (Const (Const)) import Data.String (IsString) import qualified Data.Text as T @@ -80,3 +81,13 @@ pattern Con c <- (conView -> Just c) where Con c = con c + +instance (Solvable c t) => Solvable (Const c b) (Const t b) where + con (Const c) = Const $ con c + {-# INLINE con #-} + conView (Const t) = Const <$> conView t + {-# INLINE conView #-} + ssym = Const . ssym + {-# INLINE ssym #-} + isym symbol = Const . isym symbol + {-# INLINE isym #-} diff --git a/src/Grisette/Core/Data/Class/SubstituteSym.hs b/src/Grisette/Core/Data/Class/SubstituteSym.hs index 9bd57eb7..2eaa0edd 100644 --- a/src/Grisette/Core/Data/Class/SubstituteSym.hs +++ b/src/Grisette/Core/Data/Class/SubstituteSym.hs @@ -34,6 +34,7 @@ import Control.Monad.Trans.Maybe (MaybeT (MaybeT)) import qualified Control.Monad.Writer.Lazy as WriterLazy import qualified Control.Monad.Writer.Strict as WriterStrict import qualified Data.ByteString as B +import Data.Functor.Const (Const (Const)) import Data.Functor.Sum (Sum) import Data.Int (Int16, Int32, Int64, Int8) import qualified Data.Text as T @@ -257,6 +258,11 @@ instance (SubstituteSym a) => SubstituteSym (Identity a) where instance (SubstituteSym (m a)) => SubstituteSym (IdentityT m a) where substituteSym sym val (IdentityT a) = IdentityT $ substituteSym sym val a +-- Const +instance (SubstituteSym a) => SubstituteSym (Const a b) where + substituteSym sym val (Const a) = Const $ substituteSym sym val a + {-# INLINE substituteSym #-} + #define SUBSTITUTE_SYM_SIMPLE(symtype) \ instance SubstituteSym symtype where \ substituteSym sym v (symtype t) = symtype $ substTerm sym (underlyingTerm v) t diff --git a/src/Grisette/Core/Data/Class/SymRotate.hs b/src/Grisette/Core/Data/Class/SymRotate.hs index 90a26d67..4b587f8f 100644 --- a/src/Grisette/Core/Data/Class/SymRotate.hs +++ b/src/Grisette/Core/Data/Class/SymRotate.hs @@ -13,6 +13,7 @@ module Grisette.Core.Data.Class.SymRotate where import Data.Bits (Bits (isSigned, rotate), FiniteBits (finiteBitSize)) +import Data.Functor.Const (Const (Const)) import Data.Int (Int16, Int32, Int64, Int8) import Data.Word (Word16, Word32, Word64, Word8) @@ -90,3 +91,9 @@ deriving via (DefaultFiniteBitsSymRotate Word32) instance SymRotate Word32 deriving via (DefaultFiniteBitsSymRotate Word64) instance SymRotate Word64 deriving via (DefaultFiniteBitsSymRotate Word) instance SymRotate Word + +instance (SymRotate a) => SymRotate (Const a b) where + symRotate (Const a) (Const b) = Const $ symRotate a b + {-# INLINE symRotate #-} + symRotateNegated (Const a) (Const b) = Const $ symRotateNegated a b + {-# INLINE symRotateNegated #-} diff --git a/src/Grisette/Core/Data/Class/SymShift.hs b/src/Grisette/Core/Data/Class/SymShift.hs index 06275d4c..634e6766 100644 --- a/src/Grisette/Core/Data/Class/SymShift.hs +++ b/src/Grisette/Core/Data/Class/SymShift.hs @@ -13,6 +13,7 @@ module Grisette.Core.Data.Class.SymShift where import Data.Bits (Bits (isSigned, shift, shiftR), FiniteBits (finiteBitSize)) +import Data.Functor.Const (Const (Const)) import Data.Int (Int16, Int32, Int64, Int8) import Data.Word (Word16, Word32, Word64, Word8) @@ -105,3 +106,10 @@ deriving via (DefaultFiniteBitsSymShift Word32) instance SymShift Word32 deriving via (DefaultFiniteBitsSymShift Word64) instance SymShift Word64 deriving via (DefaultFiniteBitsSymShift Word) instance SymShift Word + +-- Const +instance (SymShift a) => SymShift (Const a b) where + symShift (Const a) (Const b) = Const $ symShift a b + {-# INLINE symShift #-} + symShiftNegated (Const a) (Const b) = Const $ symShiftNegated a b + {-# INLINE symShiftNegated #-} diff --git a/src/Grisette/Core/Data/Class/ToCon.hs b/src/Grisette/Core/Data/Class/ToCon.hs index da97bbc8..ba91fb63 100644 --- a/src/Grisette/Core/Data/Class/ToCon.hs +++ b/src/Grisette/Core/Data/Class/ToCon.hs @@ -36,6 +36,7 @@ import Control.Monad.Trans.Maybe (MaybeT (MaybeT)) import qualified Control.Monad.Writer.Lazy as WriterLazy import qualified Control.Monad.Writer.Strict as WriterStrict import qualified Data.ByteString as B +import Data.Functor.Const (Const (Const)) import Data.Functor.Sum (Sum) import Data.Int (Int16, Int32, Int64, Int8) import qualified Data.Text as T @@ -302,6 +303,10 @@ deriving via instance ToCon VerificationConditions VerificationConditions +instance (ToCon as a) => ToCon (Const as bs) (Const a b) where + toCon (Const a) = Const <$> toCon a + {-# INLINE toCon #-} + -- Derivation of ToCon for generic types instance (Generic a, Generic b, ToCon' (Rep a) (Rep b)) => ToCon a (Default b) where toCon v = fmap (Default . to) $ toCon' $ from v diff --git a/src/Grisette/Core/Data/Class/ToSym.hs b/src/Grisette/Core/Data/Class/ToSym.hs index 877a0163..3c62acfa 100644 --- a/src/Grisette/Core/Data/Class/ToSym.hs +++ b/src/Grisette/Core/Data/Class/ToSym.hs @@ -37,6 +37,7 @@ import Control.Monad.Trans.Maybe (MaybeT (MaybeT)) import qualified Control.Monad.Writer.Lazy as WriterLazy import qualified Control.Monad.Writer.Strict as WriterStrict import qualified Data.ByteString as B +import Data.Functor.Const (Const (Const)) import Data.Functor.Sum (Sum) import Data.Int (Int16, Int32, Int64, Int8) import qualified Data.Text as T @@ -218,6 +219,11 @@ instance (ToSym a b) => ToSym (Identity a) (Identity b) where instance (ToSym (m a) (m1 b)) => ToSym (IdentityT m a) (IdentityT m1 b) where toSym (IdentityT v) = IdentityT $ toSym v +-- Const +instance (ToSym a as) => ToSym (Const a b) (Const as bs) where + toSym (Const a) = Const $ toSym a + {-# INLINE toSym #-} + #define TO_SYM_SYMID_SIMPLE(symtype) \ instance ToSym symtype symtype where \ toSym = id diff --git a/src/Grisette/Core/Data/Class/TryMerge.hs b/src/Grisette/Core/Data/Class/TryMerge.hs index 0f21641a..d8bec10d 100644 --- a/src/Grisette/Core/Data/Class/TryMerge.hs +++ b/src/Grisette/Core/Data/Class/TryMerge.hs @@ -32,6 +32,7 @@ import qualified Control.Monad.State.Strict as StateStrict import Control.Monad.Trans.Maybe (MaybeT (MaybeT)) import qualified Control.Monad.Writer.Lazy as WriterLazy import qualified Control.Monad.Writer.Strict as WriterStrict +import Data.Functor.Const (Const) import Data.Functor.Sum (Sum (InL, InR)) import qualified Data.Monoid as Monoid import Grisette.Core.Data.Class.Mergeable @@ -208,3 +209,7 @@ instance (TryMerge f, TryMerge g) => TryMerge (Sum f g) where instance TryMerge Monoid.Sum where tryMergeWithStrategy _ = id {-# INLINE tryMergeWithStrategy #-} + +instance TryMerge (Const a) where + tryMergeWithStrategy _ = id + {-# INLINE tryMergeWithStrategy #-} diff --git a/src/Grisette/Lib/Base.hs b/src/Grisette/Lib/Base.hs index 86dc7a7c..d9dce0e5 100644 --- a/src/Grisette/Lib/Base.hs +++ b/src/Grisette/Lib/Base.hs @@ -15,7 +15,9 @@ module Grisette.Lib.Base module Grisette.Lib.Control.Monad, module Grisette.Lib.Data.Either, module Grisette.Lib.Data.Foldable, + module Grisette.Lib.Data.Function, module Grisette.Lib.Data.Functor, + module Grisette.Lib.Data.Functor.Const, module Grisette.Lib.Data.Functor.Sum, module Grisette.Lib.Data.List, module Grisette.Lib.Data.Maybe, @@ -28,7 +30,9 @@ import Grisette.Lib.Control.Applicative import Grisette.Lib.Control.Monad import Grisette.Lib.Data.Either import Grisette.Lib.Data.Foldable +import Grisette.Lib.Data.Function import Grisette.Lib.Data.Functor +import Grisette.Lib.Data.Functor.Const import Grisette.Lib.Data.Functor.Sum import Grisette.Lib.Data.List import Grisette.Lib.Data.Maybe diff --git a/src/Grisette/Lib/Data/Function.hs b/src/Grisette/Lib/Data/Function.hs new file mode 100644 index 00000000..fb119f05 --- /dev/null +++ b/src/Grisette/Lib/Data/Function.hs @@ -0,0 +1,29 @@ +module Grisette.Lib.Data.Function ((.$), (.&), mrgOn) where + +import Grisette.Core.Control.Monad.UnionM (UnionM) +import Grisette.Core.Data.Class.Mergeable (Mergeable) +import Grisette.Core.Data.Class.PlainUnion (simpleMerge) +import Grisette.Core.Data.Class.SimpleMergeable (SimpleMergeable) +import Grisette.Lib.Control.Applicative ((.<*>)) +import Grisette.Lib.Data.Functor (mrgFmap, (.<$>)) + +(.$) :: (Mergeable a, SimpleMergeable b) => (a -> b) -> UnionM a -> b +(.$) f u = simpleMerge $ mrgFmap f u +{-# INLINE (.$) #-} + +infixr 0 .$ + +(.&) :: (Mergeable a, SimpleMergeable b) => UnionM a -> (a -> b) -> b +(.&) = flip (.$) +{-# INLINE (.&) #-} + +infixl 1 .& + +mrgOn :: + (Mergeable a, Mergeable b, SimpleMergeable c) => + (b -> b -> c) -> + (a -> b) -> + UnionM a -> + UnionM a -> + c +mrgOn f u l r = simpleMerge $ f .<$> mrgFmap u l .<*> mrgFmap u r diff --git a/src/Grisette/Lib/Data/Functor/Const.hs b/src/Grisette/Lib/Data/Functor/Const.hs new file mode 100644 index 00000000..fe4f332e --- /dev/null +++ b/src/Grisette/Lib/Data/Functor/Const.hs @@ -0,0 +1,12 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TemplateHaskell #-} + +module Grisette.Lib.Data.Functor.Const (mrgConst) where + +import Data.Functor.Const (Const) +import Grisette.Core.Data.Class.TryMerge (mrgSingle) +import Grisette.Core.TH.MergeConstructor (mkMergeConstructor) + +mkMergeConstructor "mrg" ''Const diff --git a/src/Grisette/Lib/Data/Functor/Sum.hs b/src/Grisette/Lib/Data/Functor/Sum.hs index 2d7cfe11..58232604 100644 --- a/src/Grisette/Lib/Data/Functor/Sum.hs +++ b/src/Grisette/Lib/Data/Functor/Sum.hs @@ -7,8 +7,6 @@ module Grisette.Lib.Data.Functor.Sum (mrgInR, mrgInL) where import Data.Functor.Sum (Sum) import Grisette.Core.Data.Class.TryMerge (mrgSingle) -import Grisette.Core.TH.MergeConstructor - ( mkMergeConstructor, - ) +import Grisette.Core.TH.MergeConstructor (mkMergeConstructor) mkMergeConstructor "mrg" ''Sum diff --git a/src/Grisette/Lib/Lens/Micro.hs b/src/Grisette/Lib/Lens/Micro.hs new file mode 100644 index 00000000..782357fc --- /dev/null +++ b/src/Grisette/Lib/Lens/Micro.hs @@ -0,0 +1,210 @@ +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Grisette.Lib.Lens.Micro + ( (.%~), + mrgOver, + (.+~), + (.-~), + (.<>~), + (..~), + mrgSet, + (.?~), + (..?~), + (.<%~), + (.<<%~), + (.<<.~), + mrgRewriteOf, + mrgTransformOf, + (.^.), + (.^..), + (.^?), + mrgTraverseOf_, + mrgForOf_, + ) +where + +import Data.Monoid (Endo, First) +import Grisette.Core.Control.Monad.UnionM (UnionM) +import Grisette.Core.Data.Class.Mergeable (Mergeable) +import Grisette.Core.Data.Class.PlainUnion (simpleMerge) +import Grisette.Core.Data.Class.SimpleMergeable + ( SimpleMergeable, + ) +import Grisette.Core.Data.Class.TryMerge (TryMerge) +import Grisette.Lib.Control.Applicative (mrgPure, (.*>)) +import Grisette.Lib.Control.Monad (mrgFmap, mrgReturn, mrgSequence, mrgVoid) +import Lens.Micro + ( ASetter, + Getting, + LensLike, + over, + (%~), + (.~), + (<%~), + (<<%~), + (<<.~), + (?~), + (^.), + (^..), + (^?), + ) +import Lens.Micro.Internal (foldMapOf, (#.)) + +(.%~) :: + (Mergeable a, Mergeable b) => + ASetter s t (UnionM a) (UnionM b) -> + (a -> b) -> + s -> + t +setter .%~ f = setter %~ (mrgFmap f) +{-# INLINE (.%~) #-} + +mrgOver :: + (Mergeable a, Mergeable b) => + ASetter s t (UnionM a) (UnionM b) -> + (a -> b) -> + s -> + t +mrgOver = (.%~) +{-# INLINE mrgOver #-} + +(.+~) :: + (Num a, Mergeable a) => ASetter s t (UnionM a) (UnionM a) -> a -> s -> t +setter .+~ a = setter .%~ (+ a) +{-# INLINE (.+~) #-} + +(.-~) :: + (Num a, Mergeable a) => ASetter s t (UnionM a) (UnionM a) -> a -> s -> t +setter .-~ a = setter .%~ (\x -> x - a) +{-# INLINE (.-~) #-} + +(.<>~) :: + (Monoid a, Mergeable a) => + ASetter s t (UnionM a) (UnionM a) -> + a -> + s -> + t +setter .<>~ a = setter .%~ (<> a) +{-# INLINE (.<>~) #-} + +(..~) :: (Mergeable b) => ASetter s t a (UnionM b) -> b -> s -> t +setter ..~ x = setter .~ (mrgReturn x) +{-# INLINE (..~) #-} + +mrgSet :: (Mergeable b) => ASetter s t a (UnionM b) -> b -> s -> t +mrgSet = (..~) +{-# INLINE mrgSet #-} + +(.?~) :: (Mergeable b) => ASetter s t a (UnionM (Maybe b)) -> b -> s -> t +setter .?~ b = setter .~ mrgReturn (Just b) +{-# INLINE (.?~) #-} + +(..?~) :: (Mergeable b) => ASetter s t a (Maybe (UnionM b)) -> b -> s -> t +setter ..?~ b = setter ?~ mrgReturn b +{-# INLINE (..?~) #-} + +(.<%~) :: + (Mergeable a, Mergeable b) => + LensLike ((,) (UnionM b)) s t (UnionM a) (UnionM b) -> + (a -> b) -> + s -> + (UnionM b, t) +setter .<%~ f = setter <%~ (mrgFmap f) +{-# INLINE (.<%~) #-} + +(.<<%~) :: + (Mergeable a, Mergeable b) => + LensLike ((,) (UnionM a)) s t (UnionM a) (UnionM b) -> + (a -> b) -> + s -> + (UnionM a, t) +setter .<<%~ f = setter <<%~ (mrgFmap f) +{-# INLINE (.<<%~) #-} + +(.<<.~) :: + (Mergeable b) => + LensLike ((,) (UnionM a)) s t (UnionM a) (UnionM b) -> + b -> + s -> + (UnionM a, t) +setter .<<.~ b = setter <<.~ mrgReturn b +{-# INLINE (.<<.~) #-} + +mrgRewriteOf :: + (Mergeable b) => + ASetter a b (UnionM a) (UnionM b) -> + (b -> UnionM (Maybe a)) -> + a -> + UnionM b +mrgRewriteOf l f = go + where + go = mrgTransformOf l $ \x -> do + res <- f x + case res of + Nothing -> mrgReturn x + Just y -> go y +{-# INLINE mrgRewriteOf #-} + +mrgTransformOf :: + (Mergeable b) => + ASetter a b (UnionM a) (UnionM b) -> + (b -> UnionM b) -> + a -> + UnionM b +mrgTransformOf l f = go + where + go = f . over l (>>= go) +{-# INLINE mrgTransformOf #-} + +(.^.) :: (SimpleMergeable a) => UnionM s -> Getting a s a -> a +s .^. getter = simpleMerge $ (^. getter) <$> s +{-# INLINE (.^.) #-} + +(.^..) :: + (Mergeable s, Mergeable a) => + s -> + Getting (Endo [UnionM a]) s (UnionM a) -> + UnionM [a] +s .^.. getter = mrgSequence (s ^.. getter) +{-# INLINE (.^..) #-} + +(.^?) :: + (Mergeable s, Mergeable a) => + s -> + Getting (First (UnionM a)) s (UnionM a) -> + UnionM (Maybe a) +s .^? getter = mrgSequence (s ^? getter) +{-# INLINE (.^?) #-} + +newtype MrgTraversed_ f = MrgTraversed_ {getMrgTraversed_ :: f ()} + +instance (Applicative f, TryMerge f) => Monoid (MrgTraversed_ f) where + mempty = MrgTraversed_ (mrgPure ()) + {-# INLINE mempty #-} + +instance (Applicative f, TryMerge f) => Semigroup (MrgTraversed_ f) where + MrgTraversed_ ma <> MrgTraversed_ mb = MrgTraversed_ (ma .*> mb) + {-# INLINE (<>) #-} + +mrgTraverseOf_ :: + (Functor f, TryMerge f) => + Getting (MrgTraversed_ f) s a -> + (a -> f r) -> + s -> + f () +mrgTraverseOf_ l f = + mrgVoid . getMrgTraversed_ #. foldMapOf l (MrgTraversed_ #. mrgVoid . f) +{-# INLINE mrgTraverseOf_ #-} + +mrgForOf_ :: + (Functor f, TryMerge f) => + Getting (MrgTraversed_ f) s a -> + s -> + (a -> f r) -> + f () +mrgForOf_ l s f = mrgTraverseOf_ l f s +{-# INLINE mrgForOf_ #-} diff --git a/test/Grisette/Lib/Data/FunctionTests.hs b/test/Grisette/Lib/Data/FunctionTests.hs new file mode 100644 index 00000000..ee12e0cb --- /dev/null +++ b/test/Grisette/Lib/Data/FunctionTests.hs @@ -0,0 +1,44 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Grisette.Lib.Data.FunctionTests (functionTests) where + +import Grisette (SymInteger) +import Grisette.Core.Data.Class.ITEOp (ITEOp (symIte)) +import Grisette.Core.Data.Class.SimpleMergeable (mrgIf) +import Grisette.Lib.Data.Either (mrgLeft, mrgRight) +import Grisette.Lib.Data.Function (mrgOn, (.$), (.&)) +import Test.Framework (Test, testGroup) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit.Base ((@?=)) + +functionTests :: Test +functionTests = + testGroup + "Function" + [ testCase ".$" $ do + let actual = + (either (+ 1) (\x -> x - 2)) + .$ (mrgIf "cond" (mrgLeft "a") (mrgRight "b")) + let expected = symIte "cond" ("a" + 1) ("b" - 2) :: SymInteger + actual @?= expected, + testCase ".&" $ do + let actual = + (mrgIf "cond" (mrgLeft "a") (mrgRight "b")) + .& (either (+ 1) (\x -> x - 2)) + let expected = symIte "cond" ("a" + 1) ("b" - 2) :: SymInteger + actual @?= expected, + testCase "mrgOn" $ do + let f = (+) + let u (Left x) = x + u (Right x) = x + let actual = + mrgOn + f + u + (mrgIf "cond1" (mrgLeft "a1") (mrgRight "b1")) + (mrgIf "cond2" (mrgLeft "a2") (mrgRight "b2")) + let expected = + symIte "cond1" "a1" "b1" + symIte "cond2" "a2" "b2" :: + SymInteger + actual @?= expected + ] diff --git a/test/Grisette/Lib/Lens/MicroTests.hs b/test/Grisette/Lib/Lens/MicroTests.hs new file mode 100644 index 00000000..68f7d78e --- /dev/null +++ b/test/Grisette/Lib/Lens/MicroTests.hs @@ -0,0 +1,361 @@ +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# OPTIONS_GHC -Wno-unused-top-binds #-} + +module Grisette.Lib.Lens.MicroTests + ( microTests, + optimise, + Expr (..), + mrgVar, + mrgLit, + mrgAdd, + ) +where + +import Control.Applicative (Alternative ((<|>))) +import GHC.Generics (Generic) +import Grisette + ( Default (Default), + EvaluateSym, + Mergeable, + SymInteger, + UnionM, + mkMergeConstructor, + mrgIf, + ) +import Grisette.Core.Data.Class.ITEOp (ITEOp (symIte)) +import Grisette.Core.Data.Class.LogicalOp (LogicalOp (symNot)) +import Grisette.Core.Data.Class.SEq (SEq) +import Grisette.Core.Data.Class.TryMerge (mrgSingle) +import Grisette.Lib.Control.Monad (mrgReturn) +import Grisette.Lib.Data.Function ((.&)) +import Grisette.Lib.Data.Functor ((.<&>)) +import Grisette.Lib.Data.Maybe (mrgJust, mrgNothing) +import Grisette.Lib.Lens.Micro + ( mrgForOf_, + mrgOver, + mrgRewriteOf, + mrgSet, + mrgTransformOf, + mrgTraverseOf_, + (.%~), + (.+~), + (.-~), + (..?~), + (..~), + (.<%~), + (.<<%~), + (.<<.~), + (.<>~), + (.?~), + (.^.), + (.^..), + (.^?), + ) +import Grisette.TestUtil.NoMerge (noMergeNotMerged) +import Grisette.TestUtil.SymbolicAssertion ((.@?=)) +import Lens.Micro (Traversal', traversed, (&)) +import Lens.Micro.TH (makeLenses) +import Test.Framework + ( Test, + TestOptions' (topt_timeout), + plusTestOptions, + testGroup, + ) +import Test.Framework.Providers.HUnit (testCase) +import Test.HUnit ((@?=)) + +data A + = A1 {_a1 :: UnionM [SymInteger], _a2 :: SymInteger} + | A2 {_a1 :: UnionM [SymInteger]} + | A3 {_a2 :: SymInteger} + | A4 {_a4 :: UnionM Int} + | A5 {_a5 :: UnionM (Maybe SymInteger), _a6 :: Maybe (UnionM SymInteger)} + deriving (Eq, Show, Generic) + deriving (Mergeable) via (Default A) + +makeLenses ''A + +value1 :: A +value1 = A1 (mrgIf "c1" (return ["v1"]) (return ["v2", "v3"])) "v4" + +value2 :: A +value2 = A2 (mrgIf "c2" (return ["u1"]) (return ["u2", "u3"])) + +value3 :: A +value3 = A3 "w" + +value4 :: A +value4 = A4 (mrgIf "c4" (return 1) (return 2)) + +value5 :: A +value5 = + A5 + (mrgIf "c5" (mrgJust "x1") mrgNothing) + (Just (mrgReturn "x2")) + +value12 :: UnionM A +value12 = mrgIf "c12" (return value1) (return value2) + +value13 :: UnionM A +value13 = mrgIf "c13" (return value1) (return value3) + +data Expr = Var String | Lit Int | Add (UnionM Expr) (UnionM Expr) + deriving (Show, Eq, Generic) + deriving (Mergeable, SEq, EvaluateSym) via (Default Expr) + +mkMergeConstructor "mrg" ''Expr + +subExprs :: Traversal' Expr (UnionM Expr) +subExprs f = \case + Var name -> pure (Var name) + Lit n -> pure (Lit n) + Add a b -> Add <$> f a <*> f b + +constantFold :: Expr -> UnionM (Maybe Expr) +constantFold (Add l r) = do + l1 <- l + r1 <- r + case (l1, r1) of + (Lit lv, Lit rv) -> mrgJust (Lit (lv + rv)) + _ -> mrgNothing +constantFold _ = mrgNothing + +zeroAdditionIdentity :: Expr -> UnionM (Maybe Expr) +zeroAdditionIdentity (Add l r) = do + l1 <- l + r1 <- r + case (l1, r1) of + (Lit 0, x) -> mrgJust x + (x, Lit 0) -> mrgJust x + _ -> mrgNothing +zeroAdditionIdentity _ = mrgNothing + +optimise :: Expr -> UnionM Expr +optimise = + mrgRewriteOf + subExprs + ( \expr -> do + constantFoldRes <- constantFold expr + zeroAdditionIdentityRes <- zeroAdditionIdentity expr + mrgReturn $ constantFoldRes <|> zeroAdditionIdentityRes + ) + +constantFoldSingle :: Expr -> UnionM Expr +constantFoldSingle (Add l r) = do + l1 <- l + r1 <- r + case (l1, r1) of + (Lit lv, Lit rv) -> mrgLit (lv + rv) + _ -> mrgAdd (mrgReturn l1) (mrgReturn r1) +constantFoldSingle v = mrgReturn v + +microTests :: Test +microTests = + testGroup "Micro" $ + concat + [ do + (name, op) <- [("(.%~)", (.%~)), ("mrgOver", mrgOver)] + return . testGroup name $ + [ testCase "value13" $ do + let actual = value13 .<&> a1 `op` (++ ["a"]) + let expected = + mrgIf + "c13" + ( return $ + A1 + ( mrgIf + "c1" + (return ["v1", "a"]) + (return ["v2", "v3", "a"]) + ) + "v4" + ) + (return value3) + actual @?= expected, + testCase "value12" $ do + let actual = value12 .<&> a1 `op` (++ ["a"]) + let expected = + mrgIf + "c12" + ( return $ + A1 + ( mrgIf + "c1" + (return ["v1", "a"]) + (return ["v2", "v3", "a"]) + ) + "v4" + ) + ( return $ + A2 + ( mrgIf + "c2" + (return ["u1", "a"]) + (return ["u2", "u3", "a"]) + ) + ) + actual @?= expected + ], + do + (value, name, op, v, vop) <- + [ (value4, ".+~", (.+~), 1, (+ 1)), + (value4, ".-~", (.-~), 1, (\x -> x - 1)) + ] + return . testCase name $ do + let actual = value & a4 `op` v + let expected = + A4 (mrgIf "c4" (return $ vop 1) (return $ vop 2)) + actual @?= expected, + [ testCase ".<>~" $ do + let actual = value1 & a1 .<>~ ["b"] + let expected = + A1 + ( mrgIf "c1" (return ["v1", "b"]) (return ["v2", "v3", "b"]) + ) + "v4" + actual @?= expected + ], + do + (name, op) <- [("(..~)", (..~)), ("mrgSet", mrgSet)] + return $ testCase name $ do + let actual = value1 & a1 `op` ["b"] + let expected = A1 (mrgReturn ["b"]) "v4" + actual @?= expected, + [ testCase ".?~" $ do + let actual = value5 & a5 .?~ "y" + let expected = + A5 + (mrgJust "y") + (Just (mrgReturn "x2")) + actual @?= expected, + testCase "..?~" $ do + let actual = value5 & a6 ..?~ "y" + let expected = + A5 + (mrgIf "c5" (mrgJust "x1") mrgNothing) + (Just (mrgReturn "y")) + actual @?= expected, + testCase ".<%~" $ do + let actual = value1 & a1 .<%~ (++ ["b"]) + let u = mrgIf "c1" (return ["v1", "b"]) (return ["v2", "v3", "b"]) + let expected = (u, A1 u "v4") + actual @?= expected, + testCase ".<<%~" $ do + let actual = value1 & a1 .<<%~ (++ ["b"]) + let uorig = mrgIf "c1" (return ["v1"]) (return ["v2", "v3"]) + let u = mrgIf "c1" (return ["v1", "b"]) (return ["v2", "v3", "b"]) + let expected = (uorig, A1 u "v4") + actual @?= expected, + testCase ".<<.~" $ do + let actual = value1 & a1 .<<.~ ["b"] + let uorig = mrgIf "c1" (return ["v1"]) (return ["v2", "v3"]) + let expected = (uorig, A1 (mrgReturn ["b"]) "v4") + actual @?= expected, + testGroup + "mrgRewriteOf" + [ testCase "simple" $ do + let expr = Add (mrgLit 2) (mrgAdd (mrgLit (-1)) (mrgLit (-1))) + let actual = optimise expr + let expected = mrgLit 0 + actual @?= expected, + testCase "complex" $ do + let expr = + Add + (mrgVar "a") + ( mrgAdd + (mrgIf "z" (mrgLit (-1)) (mrgVar "b")) + (mrgIf "x" (mrgLit 1) (mrgLit 10)) + ) + let actual = optimise expr + let expected = + mrgIf + "z" + ( mrgIf + "x" + (mrgVar "a") + (mrgAdd (mrgVar "a") (mrgLit 9)) + ) + ( mrgAdd + (mrgVar "a") + ( mrgAdd + (mrgVar "b") + (mrgIf "x" (mrgLit 1) (mrgLit 10)) + ) + ) + actual .@?= expected + ], + testGroup + "mrgTransformOf" + [ testCase "simple" $ do + let expr = Add (mrgLit 2) (mrgAdd (mrgLit (-1)) (mrgLit (-1))) + let actual = mrgTransformOf subExprs constantFoldSingle expr + let expected = mrgLit 0 + actual @?= expected, + testCase "complex" $ do + let expr = + Add + (mrgVar "a") + ( mrgAdd + (mrgIf "z" (mrgLit (-1)) (mrgVar "b")) + (mrgIf "x" (mrgLit 1) (mrgLit 10)) + ) + let actual = mrgTransformOf subExprs constantFoldSingle expr + let expected = + mrgAdd + (mrgVar "a") + ( mrgIf + "z" + (mrgIf "x" (mrgLit 0) (mrgLit 9)) + ( mrgAdd + (mrgVar "b") + (mrgIf "x" (mrgLit 1) (mrgLit 10)) + ) + ) + actual .@?= expected + ], + testCase ".^." $ do + let actual = value12 .^. a1 + let expected = + mrgIf + (symIte "c12" "c1" "c2") + (return [symIte "c12" "v1" "u1"]) + ( return + [symIte "c12" "v2" "u2", symIte "c12" "v3" "u3"] + ) + actual @?= expected, + testCase ".^.." $ do + let expr = + Add (mrgLit 2) (mrgAdd (mrgLit (-1)) (mrgLit (-1))) + let actual = expr .^.. subExprs + let expected = + mrgReturn [Lit 2, Add (mrgLit (-1)) (mrgLit (-1))] + actual @?= expected, + testCase ".^?" $ do + let actual = value13 .& (.^? a1) + let expected = + mrgIf + (symNot "c13") + mrgNothing + (mrgIf "c1" (mrgJust ["v1"]) (mrgJust ["v2", "v3"])) + actual @?= expected, + plusTestOptions (mempty {topt_timeout = Just (Just 1000000)}) $ + testCase "mrgTraverseOf_" $ do + let expr = [1 .. 1000] + let actual = mrgTraverseOf_ traversed (const noMergeNotMerged) expr + let expected = mrgReturn () + actual @?= expected, + plusTestOptions (mempty {topt_timeout = Just (Just 1000000)}) $ + testCase "mrgForOf_" $ do + let expr = [1 .. 1000] + let actual = mrgForOf_ traversed expr (const noMergeNotMerged) + let expected = mrgReturn () + actual @?= expected + ] + ] diff --git a/test/Main.hs b/test/Main.hs index c113f444..887d3596 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -62,9 +62,11 @@ import Grisette.Lib.Control.Monad.Trans.State.StrictTests ) import Grisette.Lib.Control.MonadTests (monadFunctionTests) import Grisette.Lib.Data.FoldableTests (foldableFunctionTests) +import Grisette.Lib.Data.FunctionTests (functionTests) import Grisette.Lib.Data.FunctorTests (functorFunctionTests) import Grisette.Lib.Data.ListTests (listTests) import Grisette.Lib.Data.TraversableTests (traversableFunctionTests) +import Grisette.Lib.Lens.MicroTests (microTests) import Test.Framework (Test, defaultMain, testGroup) main :: IO () @@ -151,8 +153,10 @@ libTests = [ foldableFunctionTests, traversableFunctionTests, functorFunctionTests, - listTests - ] + listTests, + functionTests + ], + testGroup "Micro" [microTests] ] irTests :: Test