Skip to content

Commit

Permalink
refactor create unified interface for estimating objectives
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Nov 10, 2023
1 parent a03e955 commit ff83c03
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 25 deletions.
24 changes: 23 additions & 1 deletion src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function value_and_gradient! end
Abstract type for the VI algorithms supported by `AdvancedVI`.
# Implementations
To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective`.
To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective` and `estimate_objective`.
Also, it should provide gradients by implementing the function `estimate_gradient!`.
If the estimator is stateful, it can implement `init` to initialize the state.
"""
Expand All @@ -71,6 +71,28 @@ init(
::Any
) = nothing

"""
estimate_objective([rng,] obj, q, prob, kwargs...)
Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`.
# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `obj::AbstractVariationalObjective`: Variational objective.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `q`: Variational approximation.
# Keyword Arguments
For the keywword arguments, refer to the respective documentation for each variational objective.
# Returns
- `obj_est`: Estimate of the objective value.
"""
function estimate_objective end

export estimate_objective


"""
estimate_gradient!(rng, obj, adbackend, out, prob, λ, restructure, obj_state)
Expand Down
40 changes: 22 additions & 18 deletions src/objectives/elbo/advi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,21 @@ Base.show(io::IO, advi::ADVI) =
Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation over the Monte Carlo samples `zs` (each column is a sample).
"""
function (advi::ADVI)(
function estimate_objective_with_samples(
advi::ADVI,
q ::Distributions.ContinuousMultivariateDistribution,
prob,
q ::Distributions.ContinuousMultivariateDistribution,
zs::AbstractMatrix
zs ::AbstractMatrix
)
𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(zs))
= advi.entropy(q, zs)
𝔼ℓ +
end

function (advi::ADVI)(
prob,
function estimate_objective_with_samples(
advi ::ADVI,
q_trans::Bijectors.TransformedDistribution,
prob,
ηs ::AbstractMatrix
)
@unpack dist, transform = q_trans
Expand All @@ -78,35 +80,37 @@ function (advi::ADVI)(
end

"""
(advi::ADVI)(
[rng], prob, q; n_samples::Int = advi.n_samples
estimate_objective(
advi::ADVI, [rng], prob, q; n_samples::Int = advi.n_samples
)
Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation using `n_samples` number of Monte Carlo samples.
"""
function (advi::ADVI)(
function estimate_objective(
rng ::Random.AbstractRNG,
prob,
q ::ContinuousDistribution;
advi ::ADVI,
q ::ContinuousDistribution,
prob;
n_samples::Int = advi.n_samples
)
zs = rand(rng, q, n_samples)
advi(prob, q, zs)
estimate_objective_with_samples(advi, q, prob, zs)
end

function (advi::ADVI)(
function estimate_objective(
rng ::Random.AbstractRNG,
prob,
q_trans ::Bijectors.TransformedDistribution;
advi ::ADVI,
q_trans ::Bijectors.TransformedDistribution,
prob;
n_samples::Int = advi.n_samples
)
q = q_trans.dist
ηs = rand(rng, q, n_samples)
advi(prob, q_trans, ηs)
estimate_objective_with_samples(advi, q_trans, prob, ηs)
end

(advi::ADVI)(prob, q::Distribution; n_samples::Int = advi.n_samples) =
advi(Random.default_rng(), prob, q; n_samples)
estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) =
estimate_objective(Random.default_rng(), advi, q, prob; n_samples)

function estimate_gradient!(
rng ::Random.AbstractRNG,
Expand All @@ -122,7 +126,7 @@ function estimate_gradient!(
q_trans = restructure(λ′)
q = q_trans.dist
ηs = rand(rng, q, advi.n_samples)
-advi(prob, q_trans, ηs)
-estimate_objective_with_samples(advi, q_trans, prob, ηs)
end
value_and_gradient!(adbackend, f, λ, out)

Expand Down
12 changes: 6 additions & 6 deletions test/interface/advi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ using Test
obj = ADVI(10)

rng = StableRNG(seed)
elbo_ref = obj(rng, model, q₀_z; n_samples=10^4)
elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)

@testset "determinism" begin
rng = StableRNG(seed)
elbo = obj(rng, model, q₀_z; n_samples=10^4)
elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
@test elbo == elbo_ref
end

@testset "default_rng" begin
elbo = obj(model, q₀_z; n_samples=10^4)
elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4)
@test elbo elbo_ref rtol=0.1
end
end
Expand All @@ -39,16 +39,16 @@ using Test

obj = ADVI(10)
rng = StableRNG(seed)
elbo_ref = obj(rng, model, q₀_z; n_samples=10^4)
elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)

@testset "determinism" begin
rng = StableRNG(seed)
elbo = obj(rng, model, q₀_z; n_samples=10^4)
elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
@test elbo == elbo_ref
end

@testset "default_rng" begin
elbo = obj(model, q₀_z; n_samples=10^4)
elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4)
@test elbo elbo_ref rtol=0.1
end
end
Expand Down

0 comments on commit ff83c03

Please sign in to comment.