Skip to content

Commit 31db7bc

Browse files
committed
fix remove redundant helpers for reparam_with_entropy for bijector
1 parent 05dbb51 commit 31db7bc

File tree

2 files changed

+14
-28
lines changed

2 files changed

+14
-28
lines changed

ext/AdvancedVIBijectorsExt.jl

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,6 @@ else
1111
using ..Random
1212
end
1313

14-
function transform_samples_with_jacobian(unconst_samples, transform, n_samples)
15-
unconst_iter = AdvancedVI.eachsample(unconst_samples)
16-
unconst_init = first(unconst_iter)
17-
18-
samples_init, logjac_init = with_logabsdet_jacobian(transform, unconst_init)
19-
20-
samples_and_logjac = mapreduce(
21-
AdvancedVI.catsamples_and_acc,
22-
Iterators.drop(unconst_iter, 1);
23-
init=(AdvancedVI.samples_expand_dim(samples_init), logjac_init)
24-
) do sample
25-
with_logabsdet_jacobian(transform, sample)
26-
end
27-
samples = first(samples_and_logjac)
28-
logjac = last(samples_and_logjac)/n_samples
29-
samples, logjac
30-
end
31-
3214
function AdvancedVI.reparam_with_entropy(
3315
rng ::Random.AbstractRNG,
3416
q ::Bijectors.TransformedDistribution,
@@ -41,14 +23,24 @@ function AdvancedVI.reparam_with_entropy(
4123
q_unconst_stop = q_stop.dist
4224

4325
# Draw samples and compute entropy of the uncontrained distribution
44-
unconst_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
26+
unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
4527
rng, q_unconst, q_unconst_stop, n_samples, ent_est
4628
)
4729

4830
# Apply bijector to samples while estimating its jacobian
49-
samples, logjac = transform_samples_with_jacobian(
50-
unconst_samples, transform, n_samples
51-
)
31+
unconstr_iter = AdvancedVI.eachsample(unconstr_samples)
32+
unconstr_init = first(unconstr_iter)
33+
samples_init, logjac_init = with_logabsdet_jacobian(transform, unconstr_init)
34+
samples_and_logjac = mapreduce(
35+
AdvancedVI.catsamples_and_acc,
36+
Iterators.drop(unconstr_iter, 1);
37+
init=(reshape(samples_init, (:,1)), logjac_init)
38+
) do sample
39+
with_logabsdet_jacobian(transform, sample)
40+
end
41+
samples = first(samples_and_logjac)
42+
logjac = last(samples_and_logjac)/n_samples
43+
5244
entropy = unconst_entropy + logjac
5345
samples, entropy
5446
end

src/utils.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ end
2323

2424
eachsample(samples::AbstractMatrix) = eachcol(samples)
2525

26-
eachsample(samples::AbstractVector) = samples
27-
2826
function catsamples_and_acc(
2927
state_curr::Tuple{<:AbstractArray, <:Real},
3028
state_new ::Tuple{<:AbstractVector, <:Real}
@@ -34,7 +32,3 @@ function catsamples_and_acc(
3432
return (x, ∑y)
3533
end
3634

37-
function samples_expand_dim(x::AbstractVector)
38-
reshape(x, (:,1))
39-
end
40-

0 commit comments

Comments
 (0)