11
11
using .. Random
12
12
end
13
13
14
- function transform_samples_with_jacobian (unconst_samples, transform, n_samples)
15
- unconst_iter = AdvancedVI. eachsample (unconst_samples)
16
- unconst_init = first (unconst_iter)
17
-
18
- samples_init, logjac_init = with_logabsdet_jacobian (transform, unconst_init)
19
-
20
- samples_and_logjac = mapreduce (
21
- AdvancedVI. catsamples_and_acc,
22
- Iterators. drop (unconst_iter, 1 );
23
- init= (AdvancedVI. samples_expand_dim (samples_init), logjac_init)
24
- ) do sample
25
- with_logabsdet_jacobian (transform, sample)
26
- end
27
- samples = first (samples_and_logjac)
28
- logjac = last (samples_and_logjac)/ n_samples
29
- samples, logjac
30
- end
31
-
32
14
function AdvancedVI. reparam_with_entropy (
33
15
rng :: Random.AbstractRNG ,
34
16
q :: Bijectors.TransformedDistribution ,
@@ -41,14 +23,24 @@ function AdvancedVI.reparam_with_entropy(
41
23
q_unconst_stop = q_stop. dist
42
24
43
25
# Draw samples and compute entropy of the uncontrained distribution
44
- unconst_samples , unconst_entropy = AdvancedVI. reparam_with_entropy (
26
+ unconstr_samples , unconst_entropy = AdvancedVI. reparam_with_entropy (
45
27
rng, q_unconst, q_unconst_stop, n_samples, ent_est
46
28
)
47
29
48
30
# Apply bijector to samples while estimating its jacobian
49
- samples, logjac = transform_samples_with_jacobian (
50
- unconst_samples, transform, n_samples
51
- )
31
+ unconstr_iter = AdvancedVI. eachsample (unconstr_samples)
32
+ unconstr_init = first (unconstr_iter)
33
+ samples_init, logjac_init = with_logabsdet_jacobian (transform, unconstr_init)
34
+ samples_and_logjac = mapreduce (
35
+ AdvancedVI. catsamples_and_acc,
36
+ Iterators. drop (unconstr_iter, 1 );
37
+ init= (reshape (samples_init, (:,1 )), logjac_init)
38
+ ) do sample
39
+ with_logabsdet_jacobian (transform, sample)
40
+ end
41
+ samples = first (samples_and_logjac)
42
+ logjac = last (samples_and_logjac)/ n_samples
43
+
52
44
entropy = unconst_entropy + logjac
53
45
samples, entropy
54
46
end
0 commit comments