
Basic rewrite of the package 2023 edition Part II: Location-scale variational families #51


Merged
merged 9 commits into from
Dec 20, 2023
10 changes: 10 additions & 0 deletions src/AdvancedVI.jl
@@ -141,6 +141,16 @@ export
include("objectives/elbo/entropy.jl")
include("objectives/elbo/repgradelbo.jl")


# Variational Families
export
MvLocationScale,
MeanFieldGaussian,
FullRankGaussian

include("families/location_scale.jl")


# Optimization Routine

function optimize end
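A minimal sketch of the exported constructors in use, assuming `AdvancedVI` and `LinearAlgebra` are loaded; `n_dims` is a hypothetical problem dimension:

```julia
using LinearAlgebra

n_dims = 10
# Diagonal (mean-field) Gaussian: the scale is a Diagonal Cholesky factor.
q_mf = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))
# Dense (full-rank) Gaussian: the scale is a lower-triangular Cholesky factor.
q_fr = FullRankGaussian(zeros(n_dims), LowerTriangular(Matrix(1.0I, n_dims, n_dims)))
```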
164 changes: 164 additions & 0 deletions src/families/location_scale.jl
@@ -0,0 +1,164 @@

"""
MvLocationScale(location, scale, dist) <: ContinuousMultivariateDistribution

The location-scale variational family covers a broad class of distributions
parameterized by `location` and `scale` variational parameters.

It represents any distribution whose sampling path can be written as:
```julia
d = length(location)
u = rand(dist, d)
z = scale*u + location
```
"""
struct MvLocationScale{
S, D <: ContinuousDistribution, L
} <: ContinuousMultivariateDistribution
location::L
scale ::S
dist ::D
end

Functors.@functor MvLocationScale (location, scale)

# Specialization of `Optimisers.destructure` for mean-field location-scale families.
# These are necessary because we only want to extract the diagonal elements of
# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD
# is very inefficient.
# begin
struct RestructureMeanField{S <: Diagonal, D, L}
q::MvLocationScale{S, D, L}
end

function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
MvLocationScale(location, scale, re.q.dist)
end

function Optimisers.destructure(
q::MvLocationScale{<:Diagonal, D, L}
) where {D, L}
@unpack location, scale, dist = q
flat = vcat(location, diag(scale))
flat, RestructureMeanField(q)
end
# end
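A minimal sketch of the round-trip this specialization enables, assuming `Optimisers` and `LinearAlgebra` are loaded:

```julia
q = MeanFieldGaussian(zeros(3), Diagonal(ones(3)))
flat, re = Optimisers.destructure(q)  # flat == [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
q_new = re(flat)                      # rebuilds an equivalent mean-field family
```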

Base.length(q::MvLocationScale) = length(q.location)

Base.size(q::MvLocationScale) = size(q.location)

Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)

function StatsBase.entropy(q::MvLocationScale)
@unpack location, scale, dist = q
n_dims = length(location)
n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale))
end
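The entropy above is the change-of-variables identity H(q) = d·H(dist) + log|det(scale)|. A minimal numerical check, assuming `Distributions` and `LinearAlgebra` are loaded and using the `FullRankGaussian` constructor defined later in this file:

```julia
using Distributions, LinearAlgebra

L = LowerTriangular([1.0 0.0; 0.5 2.0])
q = FullRankGaussian(zeros(2), L)
# 2*entropy(Normal()) + log|det L| matches the dense-Gaussian entropy.
entropy(q) ≈ entropy(MvNormal(zeros(2), L*L'))  # true
```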

# `logpdf` is overloaded directly, in addition to the `_logpdf` hook, so that
# calls skip the dimension check performed by the generic `Distributions.logpdf`
# fallback.
function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
end

function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
end

function Distributions.rand(q::MvLocationScale)
@unpack location, scale, dist = q
n_dims = length(location)
scale*rand(dist, n_dims) + location
end

function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
) where {S, D, L}
@unpack location, scale, dist = q
n_dims = length(location)
scale*rand(rng, dist, n_dims, num_samples) .+ location
end

# This specialization improves AD performance of the sampling path
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
) where {D, L}
@unpack location, scale, dist = q
n_dims = length(location)
scale_diag = diag(scale)
scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location
end

# In-place sampling: fill `x` with base-distribution draws, then apply the scale and shift.
function Distributions._rand!(rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real})
@unpack location, scale, dist = q
rand!(rng, dist, x)
x[:] = scale*x
return x .+= location
end

Distributions.mean(q::MvLocationScale) = q.location

function Distributions.var(q::MvLocationScale)
C = q.scale
Diagonal(C*C')
end

function Distributions.cov(q::MvLocationScale)
C = q.scale
Hermitian(C*C')
end
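A quick sanity check of the moment accessors, under the same assumptions as the entropy sketch above:

```julia
L = LowerTriangular([1.0 0.0; 0.5 2.0])
q = FullRankGaussian(zeros(2), L)
cov(q) ≈ L*L'            # true; returned wrapped in Hermitian
var(q) ≈ Diagonal(L*L')  # true; only the diagonal, as a Diagonal matrix
```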

"""
FullRankGaussian(location, scale; check_args = true)

Construct a Gaussian variational approximation with a dense covariance matrix.

# Arguments
- `location::AbstractVector{T}`: Mean of the Gaussian.
- `scale::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.

# Keyword Arguments
- `check_args`: Check the conditioning of the initial scale (default: `true`).
"""
function FullRankGaussian(
μ::AbstractVector{T},
L::LinearAlgebra.AbstractTriangular{T};
check_args::Bool = true
) where {T <: Real}
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
end
q_base = Normal{T}(zero(T), one(T))
MvLocationScale(μ, L, q_base)
end
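For illustration, an initial scale below the `sqrt(eps)` threshold trips the warning path (values are hypothetical):

```julia
using LinearAlgebra

L_tiny = LowerTriangular(Matrix(1e-10I, 2, 2))
FullRankGaussian(zeros(2), L_tiny)
# ┌ Warning: Initial scale is too small (minimum eigenvalue is 1.0e-10). ...
```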

"""
MeanFieldGaussian(location, scale; check_args = true)

Construct a Gaussian variational approximation with a diagonal covariance matrix.

# Arguments
- `location::AbstractVector{T}`: Mean of the Gaussian.
- `scale::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.

# Keyword Arguments
- `check_args`: Check the conditioning of the initial scale (default: `true`).
"""
function MeanFieldGaussian(
μ::AbstractVector{T},
L::Diagonal{T};
check_args::Bool = true
) where {T <: Real}
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
end
q_base = Normal{T}(zero(T), one(T))
MvLocationScale(μ, L, q_base)
end
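Putting the pieces together, a minimal end-to-end sketch of the family defined in this file, assuming `Random`, `Distributions`, and `LinearAlgebra` are loaded:

```julia
using Random, Distributions, LinearAlgebra

rng = Random.default_rng()
q = MeanFieldGaussian(zeros(2), Diagonal([1.0, 2.0]))
zs = rand(rng, q, 100)  # 2×100 sample matrix; hits the Diagonal fast path
logpdf(q, zs[:, 1])     # log-density via the change of variables
entropy(q)              # closed-form entropy
mean(q), cov(q)         # location and scale*scale'
```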
82 changes: 82 additions & 0 deletions test/inference/repgradelbo_locationscale.jl
@@ -0,0 +1,82 @@

const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress"

using Test

@testset "inference RepGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype ∈ [Float64, Float32],
(modelname, modelconstr) ∈ Dict(
:NormalMeanField => normal_meanfield,
:NormalFullRank => normal_fullrank,
),
(objname, objective) ∈ Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) ∈ Dict(
:ForwardDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
#:Enzyme => AutoEnzyme(),
)

seed = 0x38bef07cf9cc549d
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)

q0 = if is_meanfield
MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
else
L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular
FullRankGaussian(zeros(realtype, n_dims), L0)
end

@testset "convergence" begin
Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)

μ = q.location
L = q.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ ≤ Δλ₀/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end

@testset "determinism" begin
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ = q.location
L = q.scale

rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ_repl = q.location
L_repl = q.scale
@test μ == μ_repl
@test L == L_repl
end
end
end

@@ -3,7 +3,7 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false

using Test

-@testset "inference RepGradELBO DistributionsAD Bijectors" begin
+@testset "inference RepGradELBO VILocationScale Bijectors" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype ∈ [Float64, Float32],
(modelname, modelconstr) ∈ Dict(
@@ -15,7 +15,7 @@ using Test
),
(adbackname, adbackend) ∈ Dict(
:ForwardDiff => AutoForwardDiff(),
-#:ReverseDiff => AutoReverseDiff(),
+:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
#:Enzyme => AutoEnzyme(),
)
@@ -30,23 +30,28 @@ using Test

b = Bijectors.bijector(model)
b⁻¹ = inverse(b)
-μ₀ = Zeros(realtype, n_dims)
-L₀ = Diagonal(Ones(realtype, n_dims))
+μ0 = Zeros(realtype, n_dims)
+L0 = Diagonal(Ones(realtype, n_dims))

-q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
-q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
+q0_η = if is_meanfield
+MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
+else
+L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular
+FullRankGaussian(zeros(realtype, n_dims), L0)
+end
+q0_z = Bijectors.transformed(q0_η, b⁻¹)

@testset "convergence" begin
-Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
q, stats, _ = optimize(
-rng, model, objective, q₀_z, T;
+rng, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)

-μ = mean(q.dist)
-L = sqrt(cov(q.dist))
+μ = q.dist.location
+L = q.dist.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ ≤ Δλ₀/T^(1/4)
@@ -57,23 +62,23 @@ using Test
@testset "determinism" begin
rng = StableRNG(seed)
q, stats, _ = optimize(
-rng, model, objective, q₀_z, T;
+rng, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
-μ = mean(q.dist)
-L = sqrt(cov(q.dist))
+μ = q.dist.location
+L = q.dist.scale

rng_repl = StableRNG(seed)
q, stats, _ = optimize(
-rng_repl, model, objective, q₀_z, T;
+rng_repl, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
-μ_repl = mean(q.dist)
-L_repl = sqrt(cov(q.dist))
+μ_repl = q.dist.location
+L_repl = q.dist.scale
@test μ == μ_repl
@test L == L_repl
end