Skip to content

Commit 7f7bf90

Browse files
committed
Add and use applicative population transformer
1 parent b582f8d commit 7f7bf90

File tree

4 files changed

+42
-14
lines changed

4 files changed

+42
-14
lines changed

monad-bayes.cabal

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,21 +63,21 @@ common deps
6363
, scientific ^>=0.3
6464
, statistics >=0.14.0 && <0.17
6565
, text >=1.2 && <2.1
66+
, transformers ^>=0.5.6
6667
, vector >=0.12.0 && <0.14
6768
, vty ^>=5.38
6869

6970
common test-deps
7071
build-depends:
7172
, abstract-par ^>=0.3
72-
, criterion >=1.5 && <1.7
73+
, criterion >=1.5 && <1.7
7374
, directory ^>=1.3
7475
, hspec ^>=2.11
7576
, monad-bayes
76-
, optparse-applicative >=0.17 && <0.19
77+
, optparse-applicative >=0.17 && <0.19
7778
, process ^>=1.6
7879
, QuickCheck ^>=2.14
79-
, time >=1.9 && <1.13
80-
, transformers ^>=0.5.6
80+
, time >=1.9 && <1.13
8181
, typed-process ^>=0.2
8282

8383
autogen-modules: Paths_monad_bayes

src/Control/Applicative/List.hs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,36 @@
1+
{-# LANGUAGE StandaloneDeriving #-}
2+
13
module Control.Applicative.List where
24

35
-- base
46

57
import Control.Applicative
6-
import Data.Functor.Compose
8+
import Control.Monad.Trans.Writer.Strict
9+
import Numeric.Log (Log)
10+
11+
-- * Applicative ListT
712

813
-- | _Applicative_ transformer adding a list/nondeterminism/choice effect.
914
-- It is not a valid monad transformer, but it is a valid 'Applicative'.
10-
newtype ListT m a = ListT {getListT :: Compose [] m a}
15+
newtype ListT m a = ListT {getListT :: Compose m [] a}
1116
deriving newtype (Functor, Applicative, Alternative)
1217

13-
lift :: m a -> ListT m a
14-
lift = ListT . Compose . pure
18+
lift :: (Functor m) => m a -> ListT m a
19+
lift = ListT . Compose . fmap pure
1520

16-
runListT :: ListT m a -> [m a]
21+
runListT :: ListT m a -> m [a]
1722
runListT = getCompose . getListT
23+
24+
-- * Applicative Population transformer
25+
26+
-- WriterT has to be used instead of WeightedT,
27+
-- since WeightedT uses StateT under the hood,
28+
-- which requires a Monad (ListT m) constraint.
29+
newtype PopulationT m a = PopulationT {getPopulationT :: WriterT (Log Double) (ListT m) a}
30+
deriving newtype (Functor, Applicative, Alternative)
31+
32+
runPopulationT :: PopulationT m a -> m [(a, Log Double)]
33+
runPopulationT = runListT . runWriterT . getPopulationT
34+
35+
fromWeightedList :: m [(a, Log Double)] -> PopulationT m a
36+
fromWeightedList = PopulationT . WriterT . ListT . Compose

src/Control/Monad/Bayes/Inference/RMSMC.hs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import Control.Monad.Bayes.Inference.SMC
2626
import Control.Monad.Bayes.Population
2727
( PopulationT,
2828
flatten,
29+
single,
2930
withParticles,
3031
)
3132
import Control.Monad.Bayes.Sequential.Coroutine as Seq
@@ -50,7 +51,7 @@ rmsmc ::
5051
PopulationT m a
5152
rmsmc (MCMCConfig {..}) (SMCConfig {..}) =
5253
marginal
53-
. S.sequentially (composeCopies numMCMCSteps (TrStat.hoist flatten . mhStep) . TrStat.hoist resampler) numSteps
54+
. S.sequentially (composeCopies numMCMCSteps (TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps
5455
. S.hoistFirst (TrStat.hoist (withParticles numParticles))
5556

5657
-- | Resample-move Sequential Monte Carlo with a more efficient
@@ -64,7 +65,7 @@ rmsmcBasic ::
6465
PopulationT m a
6566
rmsmcBasic (MCMCConfig {..}) (SMCConfig {..}) =
6667
TrBas.marginal
67-
. S.sequentially (TrBas.hoist flatten . composeCopies numMCMCSteps (TrBas.hoist flatten . TrBas.mhStep) . TrBas.hoist resampler) numSteps
68+
. S.sequentially (TrBas.hoist (single . flatten) . composeCopies numMCMCSteps (TrBas.hoist (single . flatten) . TrBas.mhStep) . TrBas.hoist resampler) numSteps
6869
. S.hoistFirst (TrBas.hoist (withParticles numParticles))
6970

7071
-- | A variant of resample-move Sequential Monte Carlo
@@ -79,7 +80,7 @@ rmsmcDynamic ::
7980
PopulationT m a
8081
rmsmcDynamic (MCMCConfig {..}) (SMCConfig {..}) =
8182
TrDyn.marginal
82-
. S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps (TrDyn.hoist flatten . TrDyn.mhStep) . TrDyn.hoist resampler) numSteps
83+
. S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps (TrDyn.hoist (single . flatten) . TrDyn.mhStep) . TrDyn.hoist resampler) numSteps
8384
. S.hoistFirst (TrDyn.hoist (withParticles numParticles))
8485

8586
-- | Apply a function a given number of times.

src/Control/Monad/Bayes/Population.hs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ module Control.Monad.Bayes.Population
3535
popAvg,
3636
withParticles,
3737
flatten,
38+
single,
3839
)
3940
where
4041

4142
import Control.Applicative (Alternative)
43+
import Control.Applicative.List qualified as ApplicativeListT
4244
import Control.Arrow (second)
4345
import Control.Monad (MonadPlus, replicateM)
4446
import Control.Monad.Bayes.Class
@@ -277,5 +279,11 @@ hoist ::
277279
hoist f = PopulationT . Weighted.hoist (hoistFreeT f) . getPopulationT
278280

279281
-- | Flatten all layers of the free structure
280-
flatten :: (Monad m) => PopulationT m a -> PopulationT m a
281-
flatten = fromWeightedList . runPopulationT
282+
flatten :: (Monad m) => PopulationT m a -> ApplicativeListT.PopulationT m a
283+
flatten = ApplicativeListT.fromWeightedList . runPopulationT
284+
285+
-- | Create a population from a single layer of branching computations.
286+
--
287+
-- Similar to 'fromWeightedListT'.
288+
single :: (Monad m) => ApplicativeListT.PopulationT m a -> PopulationT m a
289+
single = fromWeightedList . ApplicativeListT.runPopulationT

0 commit comments

Comments
 (0)