
Commit d95c71b

Sasha committed: finish draft
1 parent 6330901 commit d95c71b

File tree: 1 file changed, +44 -22 lines


examples/probability.dx

Lines changed: 44 additions & 22 deletions
@@ -51,18 +51,15 @@ def support (AsDist x: Dist m) : List (m & Float) =
   concat $ for i. select (x.i > 0.0) (AsList 1 [(i, x.i)]) mempty
 
 instance Arbitrary (Dist m)
-  arb = \key. normalize $ arb key
+  arb = \key.
+    a = arb key
+    normalize $ for i. abs a.i
 
 ' We can define some combinators for taking expectations.
 
 def expect [VSpace out] (AsDist x: Dist m) (y : m => out) : out =
   sum for m'. x.m' .* y.m'
 
-' And for independent distributions.
-
-def (,,) (x: Dist m) (y: Dist n) : Dist (m & n) =
-  AsDist for (m',n'). (m' ?? x) * (n' ?? y)
-
 ' To represent conditional probabilities such as $ Pr(B \ |\ A)$ we define a type alias.
 
 def Pr (b:Type) (a:Type): Type = a => Dist b
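
The `Arbitrary` fix in this hunk matters because `arb` can produce negative weights, and normalizing a signed vector does not yield a valid distribution. A minimal numpy sketch of the same idea (the helper names here are illustrative, not the Dex API), which also shows that `expect` is just a probability-weighted sum:

```python
import numpy as np

def normalize(w):
    # Scale nonnegative weights so they sum to one.
    return w / w.sum()

rng = np.random.default_rng(0)
raw = rng.standard_normal(5)       # arbitrary draws, possibly negative
dist = normalize(np.abs(raw))      # abs first, then normalize, as in the fix
assert (dist >= 0).all() and np.isclose(dist.sum(), 1.0)

# expect: a probability-weighted sum of values under the distribution.
values = np.arange(5.0)
expectation = dist @ values
```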
@@ -166,7 +163,6 @@ indicator variables to represent data observations.
 
 ' ## Differential Posterior Inference
 
-
 ' The network polynomial is a convenient method for computing probabilities,
 but what makes it particularly useful is that it allows us to compute
 posterior probabilities simply using derivatives.
@@ -193,6 +189,9 @@ yields posterior terms.
 
 def posterior (f : (Var a) -> Float) : Dist a =
   AsDist $ (grad (\ x. log $ f x)) one
+def posteriorTab (f : m => (Var a) -> Float) : m => Dist a =
+  out = (grad (\ x. log $ f x)) one
+  for i. AsDist $ out.i
 
 ' And this yields exactly the term above! This is really neat because it
 doesn't require any model-specific inference.
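
Both `posterior` and the new `posteriorTab` rest on the network-polynomial identity: for $f(\lambda) = \sum_x p(x, e) \lambda_x$, the gradient of $\log f$ at the all-ones vector is exactly $\Pr(x \mid e)$. A sketch of that identity in jax, with a made-up 2x3 joint table standing in for a real model:

```python
import jax
import jax.numpy as jnp

# Hypothetical joint distribution over (a, b), a in {0,1}, b in {0,1,2}.
joint = jnp.array([[0.10, 0.25, 0.05],
                   [0.30, 0.10, 0.20]])

def network_poly(lam):
    # f(lambda) with evidence b = 1: sum_a p(a, b=1) * lambda_a.
    return lam @ joint[:, 1]

# grad of log f at the all-ones vector yields Pr(a | b = 1) ...
post = jax.grad(lambda lam: jnp.log(network_poly(lam)))(jnp.ones(2))

# ... which matches normalizing the evidence column directly.
direct = joint[:, 1] / joint[:, 1].sum()
assert jnp.allclose(post, direct)
```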
@@ -258,7 +257,7 @@ posterior (\m. two_dice latent m)
 
 support $ posterior (\m. two_dice m (observed (roll_sum 4)))
 
-' ## Conditional Independence
+' ## Discussion - Conditional Independence
 
 ' One tricky problem for discrete PPLs is modeling conditional independence.
 Models can be very slow to compute if we are not careful to exploit
@@ -323,7 +322,7 @@ def yesno (x:Bool) : Dist YesNo = delta $ select x yes no
 1. Finally we will see if we won.
 
 
-def monte_hall (change': Var YesNo) (win': Var YesNo) : Float =
+def monty_hall (change': Var YesNo) (win': Var YesNo) : Float =
   (one ~ uniform) (for (pick, correct): (Doors & Doors).
     (change' ~ uniform) (for change.
      (win' ~ (select (change == yes)
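
A sanity check for the renamed `monty_hall`: because the host always opens a goat door, switching wins exactly when the first pick was wrong, so enumerating the nine `(pick, correct)` pairs is enough to confirm the posteriors the next hunk queries. A brute-force Python check, independent of the Dex model:

```python
from fractions import Fraction
from itertools import product

def win_probability(switch):
    # Enumerate (pick, correct) uniformly over the 3 x 3 door pairs.
    # The host always reveals a goat, so switching wins iff pick != correct.
    wins = sum((pick != correct) == switch
               for pick, correct in product(range(3), repeat=2))
    return Fraction(wins, 9)

assert win_probability(True) == Fraction(2, 3)   # change
assert win_probability(False) == Fraction(1, 3)  # stay
```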
@@ -334,30 +333,53 @@ def monte_hall (change': Var YesNo) (win': Var YesNo) : Float =
 ' To check the odds we will compute the probability of winning conditioned
 on changing.
 
-yes ?? (posterior $ monte_hall (observed yes))
+yes ?? (posterior $ monty_hall (observed yes))
 
 
 ' And compare to the probability of winning with no change.
 
-yes ?? (posterior $ monte_hall (observed no))
+yes ?? (posterior $ monty_hall (observed no))
 
 
 ' Finally a neat trick is that we can get both these terms by taking a second derivative. (TODO: show this in Dex)
 
 
 ' ## Example 5: Hidden Markov Models
 
+' Finally we conclude with a more complex example. A hidden Markov model is
+one of the most widely used discrete time-series models. It models the relationship between discrete hidden states $Z$ and emissions $X$.
+
+Z = Fin 5
+X = Fin 10
+
+' It consists of three distributions: initial, transition, and emission.
+
+initial : Pr Z nil = arb $ newKey 1
+emission : Pr X Z = arb $ newKey 2
+transition : Pr Z Z = arb $ newKey 3
+
+' The model itself takes the following form for $m$ steps.
+'
+$$ z_0 \sim initial$$
+$$ z_1 \sim transition(z_0)$$
+$$ x_1 \sim emission(z_1)$$
+$$ \ldots $$
+
+' This is implemented in reverse order for clarity (backward algorithm).
+
+def hmm (init': Var Z) (x': m => Var X) (z' : m => Var Z) : Float =
+  (init' ~ initial.nil) $ yieldState one ( \future .
+    for i:m.
+      j = ((size m) - (ordinal i) - 1)@_
+      future := for z.
+        (x'.j ~ emission.z) (for _.
+          (z'.j ~ transition.z) (get future)))
+
+
+' We can marginalize out the latent states.
 
+hmm (observed (1@_)) (for i:(Fin 2). observed (1@_)) (for i. latent)
 
-def hmm (hidden_vars : m => Var Z) (init_var: Var Z) (x_vars: m => Var X)
-        (transition : CDist Z Z) (emission: CDist X Z)
-        : Float =
 
-  -- Sample an initial state
-  initial = sample init_var uniform []
-  sum $ yieldState ( \zref .
-    for i.
-      -- Sample next state
-      z' = markov $ sample hidden_vars.i transition (get zref)
+' Or we can compute the posterior probabilities of specific values.
 
-      -- Factor in evidence
-      zref := sample x_vars.i emission z'')
+posteriorTab $ \z . hmm (observed (1@_)) (for i:(Fin 2). observed (1@_)) z
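
The new `hmm` runs the chain back to front, keeping a per-state table `future` of the probability of the remaining observations; this is the classic backward recursion. A numpy sketch under the same model shape (Z of size 5, X of size 10, with randomly initialized tables standing in for the `arb`-drawn ones):

```python
import numpy as np

rng = np.random.default_rng(0)
nZ, nX = 5, 10

def random_dist(shape):
    # Mirror the Arbitrary instance: abs, then normalize the last axis.
    w = np.abs(rng.standard_normal(shape))
    return w / w.sum(axis=-1, keepdims=True)

initial = random_dist(nZ)            # Pr(z_0)
transition = random_dist((nZ, nZ))   # Pr(z_t | z_{t-1})
emission = random_dist((nZ, nX))     # Pr(x_t | z_t)

def likelihood(xs):
    # Backward recursion: future[z] = Pr(remaining observations | state z).
    future = np.ones(nZ)
    for x in reversed(xs):
        future = transition @ (emission[:, x] * future)
    return initial @ future

# Analogue of the marginalization call above: two observations of x = 1.
p = likelihood([1, 1])
```

Each step folds one observation into `future`, so the whole pass is linear in the sequence length, matching the single `for i:m` loop in the Dex version.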
