Skip to content

Commit 9ebfc3f

Browse files
authored
Basic rewrite of the package 2023 edition Part II: Location-scale variational families (#51)
* add location scale family * refactor switch bijector tests to use locscale, enable ReverseDiff * fix test file name for location-scale plus bijector inference test * fix wrong testset names, add interface test for VILocationScale * fix test parameters for `LocationScale` * fix test for LocationScale with Bijectors * add tests to improve coverage, fix bug for `rand!` with vectors * rename location scale, fix type ambiguity for `rand` * remove duplicate type tests for `LocationScale`
1 parent 576259a commit 9ebfc3f

File tree

6 files changed

+420
-17
lines changed

6 files changed

+420
-17
lines changed

src/AdvancedVI.jl

+10
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,16 @@ export
141141
include("objectives/elbo/entropy.jl")
142142
include("objectives/elbo/repgradelbo.jl")
143143

144+
145+
# Variational Families
146+
export
147+
VILocationScale,
148+
MeanFieldGaussian,
149+
FullRankGaussian
150+
151+
include("families/location_scale.jl")
152+
153+
144154
# Optimization Routine
145155

146156
function optimize end

src/families/location_scale.jl

+164
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
"""
    MvLocationScale(location, scale, dist) <: ContinuousMultivariateDistribution

Variational family of multivariate location-scale distributions.

A member of this family is any distribution whose samples can be drawn through
the reparameterized path
```julia
d = length(location)
u = rand(dist, d)
z = scale*u + location
```
where `dist` is the base distribution being shifted and rescaled.
"""
struct MvLocationScale{
    S, D <: ContinuousDistribution, L
} <: ContinuousMultivariateDistribution
    location::L  # shift ("mean-like") variational parameter
    scale::S     # linear transform applied to the base draw
    dist::D      # base distribution of the sampling path
end
Functors.@functor MvLocationScale (location, scale)

# Specialization of `Optimisers.destructure` for mean-field location-scale
# families. It exists because we only want to flatten the diagonal entries of
# `scale <: Diagonal`, which is not the default behavior; flattening the full
# matrix makes forward-mode AD very inefficient.
# begin
struct RestructureMeanField{S <: Diagonal, D, L}
    q::MvLocationScale{S, D, L}  # template carrying the base distribution
end
# Rebuild a mean-field `MvLocationScale` from a flat parameter vector laid out
# as `[location; diag(scale)]`.
function (re::RestructureMeanField)(flat::AbstractVector)
    d = div(length(flat), 2)
    μ = first(flat, d)
    σ = Diagonal(last(flat, d))
    return MvLocationScale(μ, σ, re.q.dist)
end
# Flatten a mean-field family into `[location; diag(scale)]` and return the
# matching restructuring closure. Only the diagonal of the scale is packed so
# forward-mode AD sees 2d parameters instead of d + d^2.
function Optimisers.destructure(
    q::MvLocationScale{<:Diagonal, D, L}
) where {D, L}
    flat = vcat(q.location, diag(q.scale))
    return flat, RestructureMeanField(q)
end
# end
# Dimensionality of the family is that of its location parameter.
Base.length(q::MvLocationScale) = length(q.location)

Base.size(q::MvLocationScale) = size(q.location)

# Element type is inherited from the base distribution's type parameter.
Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)
# Differential entropy of the affinely transformed distribution:
# H[scale*u + location] = d*H[dist] + log|det(scale)|.
function StatsBase.entropy(q::MvLocationScale)
    d = length(q.location)
    base = convert(eltype(q.location), entropy(q.dist))
    return d*base + first(logabsdet(q.scale))
end
"""
    Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})

Log-density of `q` at `z` via the change of variables `z = scale*u + location`:
`log q(z) = Σᵢ log p_dist(uᵢ) - log|det(scale)|` where `u = scale \\ (z - location)`.
"""
function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
    @unpack location, scale, dist = q
    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
end

# `logpdf` delegates to `_logpdf` so the density formula lives in exactly one
# place; the original defined both methods with an identical, duplicated body.
function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
    Distributions._logpdf(q, z)
end
# Draw a single sample using the implicit global RNG: z = scale*u + location.
function Distributions.rand(q::MvLocationScale)
    u = rand(q.dist, length(q.location))
    return q.scale*u + q.location
end
# Draw `num_samples` samples at once; each column of the returned matrix is one
# sample of the reparameterized path.
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
) where {S, D, L}
    d = length(q.location)
    u = rand(rng, q.dist, d, num_samples)
    return q.scale*u .+ q.location
end
# Mean-field specialization: broadcasting against the diagonal entries instead
# of multiplying by a `Diagonal` matrix improves AD performance of the
# sampling path.
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
) where {L, D}
    d = length(q.location)
    σ_diag = diag(q.scale)
    return σ_diag .* rand(rng, q.dist, d, num_samples) .+ q.location
end
# In-place sampling: fill `x` with base draws, then apply the affine map.
# For a matrix `x`, each column is transformed as one sample.
function Distributions._rand!(
    rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real}
)
    rand!(rng, q.dist, x)
    x[:] = q.scale*x
    return x .+= q.location
end
Distributions.mean(q::MvLocationScale) = q.location

# NOTE(review): these moments treat `scale*scale'` as the covariance, which
# assumes the base `dist` has unit variance — TODO confirm. Also, `var`
# returns a `Diagonal` matrix rather than the `Vector` that
# `Distributions.var` conventionally returns for multivariate distributions;
# verify callers expect this.
function Distributions.var(q::MvLocationScale)
    S = q.scale
    return Diagonal(S*S')
end

function Distributions.cov(q::MvLocationScale)
    S = q.scale
    return Hermitian(S*S')
end
"""
    FullRankGaussian(location, scale; check_args = true)

Construct a Gaussian variational approximation with a dense covariance matrix.

# Arguments
- `location::AbstractVector{T}`: Mean of the Gaussian.
- `scale::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.

# Keyword Arguments
- `check_args`: Check the conditioning of the initial scale (default: `true`).

# Throws
- `ArgumentError`: if the diagonal of `scale` is not strictly positive.
"""
function FullRankGaussian(
    μ::AbstractVector{T},
    L::LinearAlgebra.AbstractTriangular{T};
    check_args::Bool = true
) where {T <: Real}
    # Compute the minimum diagonal entry once instead of up to three times.
    min_diag = minimum(diag(L))
    # Input validation must not rely on `@assert`, which may be compiled out
    # at higher optimization levels.
    min_diag > eps(eltype(L)) ||
        throw(ArgumentError("Scale must be positive definite"))
    if check_args && (min_diag < sqrt(eps(eltype(L))))
        @warn "Initial scale is too small (minimum eigenvalue is $(min_diag)). This might result in unstable optimization behavior."
    end
    q_base = Normal{T}(zero(T), one(T))
    MvLocationScale(μ, L, q_base)
end
"""
    MeanFieldGaussian(location, scale; check_args = true)

Construct a Gaussian variational approximation with a diagonal covariance matrix.

# Arguments
- `location::AbstractVector{T}`: Mean of the Gaussian.
- `scale::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.

# Keyword Arguments
- `check_args`: Check the conditioning of the initial scale (default: `true`).

# Throws
- `ArgumentError`: if the diagonal of `scale` is not strictly positive.
"""
function MeanFieldGaussian(
    μ::AbstractVector{T},
    L::Diagonal{T};
    check_args::Bool = true
) where {T <: Real}
    # Compute the minimum diagonal entry once instead of up to three times.
    min_diag = minimum(diag(L))
    # Input validation must not rely on `@assert`, which may be compiled out;
    # the message now matches the actual check (positivity) and the wording
    # used by `FullRankGaussian`.
    min_diag > eps(eltype(L)) ||
        throw(ArgumentError("Scale must be positive definite"))
    if check_args && (min_diag < sqrt(eps(eltype(L))))
        @warn "Initial scale is too small (minimum eigenvalue is $(min_diag)). This might result in unstable optimization behavior."
    end
    q_base = Normal{T}(zero(T), one(T))
    MvLocationScale(μ, L, q_base)
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
# Inference test for location-scale families optimized with RepGradELBO.
# Pass `--progress` as the first script argument to show progress bars.
const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress"

using Test

@testset "inference RepGradELBO VILocationScale" begin
    @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
        realtype ∈ [Float64, Float32],
        (modelname, modelconstr) ∈ Dict(
            # Distinct keys: the original used `:Normal` twice, so the
            # second entry overwrote the first and the mean-field model was
            # silently dropped from the test matrix.
            :NormalMeanField => normal_meanfield,
            :NormalFullRank  => normal_fullrank,
        ),
        (objname, objective) ∈ Dict(
            :RepGradELBOClosedFormEntropy  => RepGradELBO(10),
            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
        ),
        (adbackname, adbackend) ∈ Dict(
            :ForwardDiff => AutoForwardDiff(),  # fixed typo (:ForwarDiff)
            :ReverseDiff => AutoReverseDiff(),
            :Zygote      => AutoZygote(),
            #:Enzyme     => AutoEnzyme(),
        )

        seed = 0x38bef07cf9cc549d
        rng  = StableRNG(seed)

        modelstats = modelconstr(rng, realtype)
        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

        # Mean-field problems converge faster, so fewer steps / larger rate.
        T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)

        q0 = if is_meanfield
            MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
        else
            L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular
            FullRankGaussian(zeros(realtype, n_dims), L0)
        end

        @testset "convergence" begin
            Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
            q, stats, _ = optimize(
                rng, model, objective, q0, T;
                optimizer     = Optimisers.Adam(realtype(η)),
                show_progress = PROGRESS,
                adbackend     = adbackend,
            )

            μ  = q.location
            L  = q.scale
            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

            # The parameter error should have shrunk relative to the start.
            @test Δλ ≤ Δλ₀/T^(1/4)
            @test eltype(μ) == eltype(μ_true)
            @test eltype(L) == eltype(L_true)
        end

        @testset "determinism" begin
            rng = StableRNG(seed)
            q, stats, _ = optimize(
                rng, model, objective, q0, T;
                optimizer     = Optimisers.Adam(realtype(η)),
                show_progress = PROGRESS,
                adbackend     = adbackend,
            )
            μ = q.location
            L = q.scale

            # Re-running with an identically seeded RNG must reproduce the
            # same optimum exactly.
            rng_repl = StableRNG(seed)
            q, stats, _ = optimize(
                rng_repl, model, objective, q0, T;
                optimizer     = Optimisers.Adam(realtype(η)),
                show_progress = PROGRESS,
                adbackend     = adbackend,
            )
            μ_repl = q.location
            L_repl = q.scale
            @test μ == μ_repl
            @test L == L_repl
        end
    end
end

test/inference/repgradelbo_distributionsad_bijectors.jl renamed to test/inference/repgradelbo_locationscale_bijectors.jl

+21-16
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
33

44
using Test
55

6-
@testset "inference RepGradELBO DistributionsAD Bijectors" begin
6+
@testset "inference RepGradELBO VILocationScale Bijectors" begin
77
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
88
realtype ∈ [Float64, Float32],
99
(modelname, modelconstr) ∈ Dict(
@@ -15,7 +15,7 @@ using Test
1515
),
1616
(adbackname, adbackend) ∈ Dict(
1717
:ForwarDiff => AutoForwardDiff(),
18-
#:ReverseDiff => AutoReverseDiff(),
18+
:ReverseDiff => AutoReverseDiff(),
1919
#:Zygote => AutoZygote(),
2020
#:Enzyme => AutoEnzyme(),
2121
)
@@ -30,23 +30,28 @@ using Test
3030

3131
b = Bijectors.bijector(model)
3232
b⁻¹ = inverse(b)
33-
μ₀ = Zeros(realtype, n_dims)
34-
L₀ = Diagonal(Ones(realtype, n_dims))
33+
μ0 = Zeros(realtype, n_dims)
34+
L0 = Diagonal(Ones(realtype, n_dims))
3535

36-
q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
37-
q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
36+
q0_η = if is_meanfield
37+
MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
38+
else
39+
L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular
40+
FullRankGaussian(zeros(realtype, n_dims), L0)
41+
end
42+
q0_z = Bijectors.transformed(q0_η, b⁻¹)
3843

3944
@testset "convergence" begin
40-
Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
45+
Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
4146
q, stats, _ = optimize(
42-
rng, model, objective, q₀_z, T;
47+
rng, model, objective, q0_z, T;
4348
optimizer = Optimisers.Adam(realtype(η)),
4449
show_progress = PROGRESS,
4550
adbackend = adbackend,
4651
)
4752

48-
μ = mean(q.dist)
49-
L = sqrt(cov(q.dist))
53+
μ = q.dist.location
54+
L = q.dist.scale
5055
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
5156

5257
@test Δλ ≤ Δλ₀/T^(1/4)
@@ -57,23 +62,23 @@ using Test
5762
@testset "determinism" begin
5863
rng = StableRNG(seed)
5964
q, stats, _ = optimize(
60-
rng, model, objective, q₀_z, T;
65+
rng, model, objective, q0_z, T;
6166
optimizer = Optimisers.Adam(realtype(η)),
6267
show_progress = PROGRESS,
6368
adbackend = adbackend,
6469
)
65-
μ = mean(q.dist)
66-
L = sqrt(cov(q.dist))
70+
μ = q.dist.location
71+
L = q.dist.scale
6772

6873
rng_repl = StableRNG(seed)
6974
q, stats, _ = optimize(
70-
rng_repl, model, objective, q₀_z, T;
75+
rng_repl, model, objective, q0_z, T;
7176
optimizer = Optimisers.Adam(realtype(η)),
7277
show_progress = PROGRESS,
7378
adbackend = adbackend,
7479
)
75-
μ_repl = mean(q.dist)
76-
L_repl = sqrt(cov(q.dist))
80+
μ_repl = q.dist.location
81+
L_repl = q.dist.scale
7782
@test μ == μ_repl
7883
@test L == L_repl
7984
end

0 commit comments

Comments
 (0)