Skip to content

Commit 05dbb51

Browse files
committed
fix bug for bijector with 1 MC sample with tests
1 parent 9ebfc3f commit 05dbb51

File tree

5 files changed

+57
-40
lines changed

5 files changed

+57
-40
lines changed

ext/AdvancedVIBijectorsExt.jl

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,35 +11,45 @@ else
1111
using ..Random
1212
end
1313

14-
function AdvancedVI.reparam_with_entropy(
15-
rng ::Random.AbstractRNG,
16-
q ::Bijectors.TransformedDistribution,
17-
q_stop ::Bijectors.TransformedDistribution,
18-
n_samples::Int,
19-
ent_est ::AdvancedVI.AbstractEntropyEstimator
20-
)
21-
transform = q.transform
22-
q_base = q.dist
23-
q_base_stop = q_stop.dist
24-
base_samples = rand(rng, q_base, n_samples)
25-
it = AdvancedVI.eachsample(base_samples)
26-
sample_init = first(it)
14+
function transform_samples_with_jacobian(unconst_samples, transform, n_samples)
15+
unconst_iter = AdvancedVI.eachsample(unconst_samples)
16+
unconst_init = first(unconst_iter)
17+
18+
samples_init, logjac_init = with_logabsdet_jacobian(transform, unconst_init)
2719

2820
samples_and_logjac = mapreduce(
2921
AdvancedVI.catsamples_and_acc,
30-
Iterators.drop(it, 1);
31-
init=with_logabsdet_jacobian(transform, sample_init)
22+
Iterators.drop(unconst_iter, 1);
23+
init=(AdvancedVI.samples_expand_dim(samples_init), logjac_init)
3224
) do sample
3325
with_logabsdet_jacobian(transform, sample)
3426
end
3527
samples = first(samples_and_logjac)
36-
logjac = last(samples_and_logjac)
28+
logjac = last(samples_and_logjac)/n_samples
29+
samples, logjac
30+
end
3731

38-
entropy_base = AdvancedVI.estimate_entropy_maybe_stl(
39-
ent_est, base_samples, q_base, q_base_stop
32+
function AdvancedVI.reparam_with_entropy(
33+
rng ::Random.AbstractRNG,
34+
q ::Bijectors.TransformedDistribution,
35+
q_stop ::Bijectors.TransformedDistribution,
36+
n_samples::Int,
37+
ent_est ::AdvancedVI.AbstractEntropyEstimator
38+
)
39+
transform = q.transform
40+
q_unconst = q.dist
41+
q_unconst_stop = q_stop.dist
42+
43+
# Draw samples and compute entropy of the unconstrained distribution
44+
unconst_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
45+
rng, q_unconst, q_unconst_stop, n_samples, ent_est
4046
)
4147

42-
entropy = entropy_base + logjac/n_samples
48+
# Apply bijector to samples while estimating its jacobian
49+
samples, logjac = transform_samples_with_jacobian(
50+
unconst_samples, transform, n_samples
51+
)
52+
entropy = unconst_entropy + logjac
4353
samples, entropy
4454
end
4555
end

src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,7 @@ function catsamples_and_acc(
3434
return (x, ∑y)
3535
end
3636

37+
function samples_expand_dim(x::AbstractVector)
38+
reshape(x, (:,1))
39+
end
40+

test/inference/repgradelbo_distributionsad.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ using Test
99
(modelname, modelconstr) ∈ Dict(
1010
:Normal=> normal_meanfield,
1111
),
12-
(objname, objective) ∈ Dict(
13-
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
14-
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
12+
n_montecarlo in [1, 10],
13+
(objname, objective) in Dict(
14+
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
15+
:RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
1516
),
1617
(adbackname, adbackend) ∈ Dict(
1718
:ForwarDiff => AutoForwardDiff(),
@@ -33,7 +34,7 @@ using Test
3334
q0 = TuringDiagMvNormal(μ0, diag(L0))
3435

3536
@testset "convergence" begin
36-
Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
37+
Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
3738
q, stats, _ = optimize(
3839
rng, model, objective, q0, T;
3940
optimizer = Optimisers.Adam(realtype(η)),
@@ -45,7 +46,7 @@ using Test
4546
L = sqrt(cov(q))
4647
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
4748

48-
@test Δλ ≤ Δλ₀/T^(1/4)
49+
@test Δλ ≤ Δλ0/T^(1/4)
4950
@test eltype(μ) == eltype(μ_true)
5051
@test eltype(L) == eltype(L_true)
5152
end

test/inference/repgradelbo_locationscale.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,17 @@ using Test
55

66
@testset "inference RepGradELBO VILocationScale" begin
77
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
8-
realtype ∈ [Float64, Float32],
9-
(modelname, modelconstr) ∈ Dict(
8+
realtype in [Float64, Float32],
9+
(modelname, modelconstr) in Dict(
1010
:Normal=> normal_meanfield,
1111
:Normal=> normal_fullrank,
1212
),
13-
(objname, objective) ∈ Dict(
14-
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
15-
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
13+
n_montecarlo in [1, 10],
14+
(objname, objective) in Dict(
15+
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
16+
:RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
1617
),
17-
(adbackname, adbackend) Dict(
18+
(adbackname, adbackend) in Dict(
1819
:ForwarDiff => AutoForwardDiff(),
1920
:ReverseDiff => AutoReverseDiff(),
2021
:Zygote => AutoZygote(),
@@ -37,7 +38,7 @@ using Test
3738
end
3839

3940
@testset "convergence" begin
40-
Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
41+
Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
4142
q, stats, _ = optimize(
4243
rng, model, objective, q0, T;
4344
optimizer = Optimisers.Adam(realtype(η)),
@@ -49,7 +50,7 @@ using Test
4950
L = q.scale
5051
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
5152

52-
@test Δλ ≤ Δλ₀/T^(1/4)
53+
@test Δλ ≤ Δλ0/T^(1/4)
5354
@test eltype(μ) == eltype(μ_true)
5455
@test eltype(L) == eltype(L_true)
5556
end

test/inference/repgradelbo_locationscale_bijectors.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@ using Test
55

66
@testset "inference RepGradELBO VILocationScale Bijectors" begin
77
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
8-
realtype ∈ [Float64, Float32],
9-
(modelname, modelconstr) ∈ Dict(
8+
realtype in [Float64, Float32],
9+
(modelname, modelconstr) in Dict(
1010
:NormalLogNormalMeanField => normallognormal_meanfield,
1111
),
12-
(objname, objective) ∈ Dict(
13-
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
14-
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
12+
n_montecarlo in [1, 10],
13+
(objname, objective) in Dict(
14+
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
15+
:RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
1516
),
16-
(adbackname, adbackend) ∈ Dict(
17+
(adbackname, adbackend) in Dict(
1718
:ForwarDiff => AutoForwardDiff(),
1819
:ReverseDiff => AutoReverseDiff(),
1920
#:Zygote => AutoZygote(),
@@ -42,7 +43,7 @@ using Test
4243
q0_z = Bijectors.transformed(q0_η, b⁻¹)
4344

4445
@testset "convergence" begin
45-
Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
46+
Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
4647
q, stats, _ = optimize(
4748
rng, model, objective, q0_z, T;
4849
optimizer = Optimisers.Adam(realtype(η)),
@@ -54,7 +55,7 @@ using Test
5455
L = q.dist.scale
5556
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
5657

57-
@test Δλ ≤ Δλ₀/T^(1/4)
58+
@test Δλ ≤ Δλ0/T^(1/4)
5859
@test eltype(μ) == eltype(μ_true)
5960
@test eltype(L) == eltype(L_true)
6061
end

0 commit comments

Comments
 (0)