|
11 | 11 | using ..Random
|
12 | 12 | end
|
13 | 13 |
|
14 |
| -function AdvancedVI.reparam_with_entropy( |
15 |
| - rng ::Random.AbstractRNG, |
16 |
| - q ::Bijectors.TransformedDistribution, |
17 |
| - q_stop ::Bijectors.TransformedDistribution, |
18 |
| - n_samples::Int, |
19 |
| - ent_est ::AdvancedVI.AbstractEntropyEstimator |
20 |
| -) |
21 |
| - transform = q.transform |
22 |
| - q_base = q.dist |
23 |
| - q_base_stop = q_stop.dist |
24 |
| - base_samples = rand(rng, q_base, n_samples) |
25 |
| - it = AdvancedVI.eachsample(base_samples) |
26 |
| - sample_init = first(it) |
| 14 | +function transform_samples_with_jacobian(unconst_samples, transform, n_samples) |
| 15 | + unconst_iter = AdvancedVI.eachsample(unconst_samples) |
| 16 | + unconst_init = first(unconst_iter) |
| 17 | + |
| 18 | + samples_init, logjac_init = with_logabsdet_jacobian(transform, unconst_init) |
27 | 19 |
|
28 | 20 | samples_and_logjac = mapreduce(
|
29 | 21 | AdvancedVI.catsamples_and_acc,
|
30 |
| - Iterators.drop(it, 1); |
31 |
| - init=with_logabsdet_jacobian(transform, sample_init) |
| 22 | + Iterators.drop(unconst_iter, 1); |
| 23 | + init=(AdvancedVI.samples_expand_dim(samples_init), logjac_init) |
32 | 24 | ) do sample
|
33 | 25 | with_logabsdet_jacobian(transform, sample)
|
34 | 26 | end
|
35 | 27 | samples = first(samples_and_logjac)
|
36 |
| - logjac = last(samples_and_logjac) |
| 28 | + logjac = last(samples_and_logjac)/n_samples |
| 29 | + samples, logjac |
| 30 | +end |
37 | 31 |
|
38 |
| - entropy_base = AdvancedVI.estimate_entropy_maybe_stl( |
39 |
| - ent_est, base_samples, q_base, q_base_stop |
| 32 | +function AdvancedVI.reparam_with_entropy( |
| 33 | + rng ::Random.AbstractRNG, |
| 34 | + q ::Bijectors.TransformedDistribution, |
| 35 | + q_stop ::Bijectors.TransformedDistribution, |
| 36 | + n_samples::Int, |
| 37 | + ent_est ::AdvancedVI.AbstractEntropyEstimator |
| 38 | +) |
| 39 | + transform = q.transform |
| 40 | + q_unconst = q.dist |
| 41 | + q_unconst_stop = q_stop.dist |
| 42 | + |
| 43 | + # Draw samples and compute entropy of the uncontrained distribution |
| 44 | + unconst_samples, unconst_entropy = AdvancedVI.reparam_with_entropy( |
| 45 | + rng, q_unconst, q_unconst_stop, n_samples, ent_est |
40 | 46 | )
|
41 | 47 |
|
42 |
| - entropy = entropy_base + logjac/n_samples |
| 48 | + # Apply bijector to samples while estimating its jacobian |
| 49 | + samples, logjac = transform_samples_with_jacobian( |
| 50 | + unconst_samples, transform, n_samples |
| 51 | + ) |
| 52 | + entropy = unconst_entropy + logjac |
43 | 53 | samples, entropy
|
44 | 54 | end
|
45 | 55 | end
|
0 commit comments