From a09db2ed8ae94dd8ddc67d7ee718c545ec99964b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 6 Jun 2025 13:25:06 -0400 Subject: [PATCH 001/108] refactor move files to algorithms/ --- src/AdvancedVI.jl | 6 +++--- .../elbo => algorithms/paramspacesgd}/entropy.jl | 0 .../elbo => algorithms/paramspacesgd}/repgradelbo.jl | 0 .../elbo => algorithms/paramspacesgd}/scoregradelbo.jl | 0 4 files changed, 3 insertions(+), 3 deletions(-) rename src/{objectives/elbo => algorithms/paramspacesgd}/entropy.jl (100%) rename src/{objectives/elbo => algorithms/paramspacesgd}/repgradelbo.jl (100%) rename src/{objectives/elbo => algorithms/paramspacesgd}/scoregradelbo.jl (100%) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 85bbf2baf..2690ffee4 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -200,9 +200,9 @@ export RepGradELBO, ClosedFormEntropyZeroGradient, StickingTheLandingEntropyZeroGradient -include("objectives/elbo/entropy.jl") -include("objectives/elbo/repgradelbo.jl") -include("objectives/elbo/scoregradelbo.jl") +include("algorithms/paramspacesgd/entropy.jl") +include("algorithms/paramspacesgd/repgradelbo.jl") +include("algorithms/paramspacesgd/scoregradelbo.jl") # Variational Families export MvLocationScale, MeanFieldGaussian, FullRankGaussian diff --git a/src/objectives/elbo/entropy.jl b/src/algorithms/paramspacesgd/entropy.jl similarity index 100% rename from src/objectives/elbo/entropy.jl rename to src/algorithms/paramspacesgd/entropy.jl diff --git a/src/objectives/elbo/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl similarity index 100% rename from src/objectives/elbo/repgradelbo.jl rename to src/algorithms/paramspacesgd/repgradelbo.jl diff --git a/src/objectives/elbo/scoregradelbo.jl b/src/algorithms/paramspacesgd/scoregradelbo.jl similarity index 100% rename from src/objectives/elbo/scoregradelbo.jl rename to src/algorithms/paramspacesgd/scoregradelbo.jl From 08577cbdb094c5ae30210ae0423d977f0f7bf3fc Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 6 Jun 2025 14:29:07 -0400 Subject: [PATCH 002/108] refactor move elbo-specific files into paramspacesgd/elbo/ --- src/AdvancedVI.jl | 79 +------------------ .../paramspacesgd/{ => elbo}/entropy.jl | 0 .../paramspacesgd/{ => elbo}/repgradelbo.jl | 0 .../paramspacesgd/{ => elbo}/scoregradelbo.jl | 0 src/algorithms/paramspacesgd/interface.jl | 79 +++++++++++++++++++ 5 files changed, 80 insertions(+), 78 deletions(-) rename src/algorithms/paramspacesgd/{ => elbo}/entropy.jl (100%) rename src/algorithms/paramspacesgd/{ => elbo}/repgradelbo.jl (100%) rename src/algorithms/paramspacesgd/{ => elbo}/scoregradelbo.jl (100%) create mode 100644 src/algorithms/paramspacesgd/interface.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 2690ffee4..c5dcf4994 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -94,84 +94,6 @@ This is an indirection for handling the type stability of `restructure`, as some """ restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params) -# estimators -""" - AbstractVariationalObjective - -Abstract type for the VI algorithms supported by `AdvancedVI`. - -# Implementations -To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective` and `estimate_objective`. -Also, it should provide gradients by implementing the function `estimate_gradient!`. -If the estimator is stateful, it can implement `init` to initialize the state. -""" -abstract type AbstractVariationalObjective end - -""" - init(rng, obj, adtype, prob, params, restructure) - -Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. -This function needs to be implemented only if `obj` is stateful. - -# Arguments -- `rng::Random.AbstractRNG`: Random number generator. -- `obj::AbstractVariationalObjective`: Variational objective. -` `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. -- `params`: Initial variational parameters. -- `restructure`: Function that reconstructs the variational approximation from `λ`. -""" -init( - ::Random.AbstractRNG, - ::AbstractVariationalObjective, - ::ADTypes.AbstractADType, - ::Any, - ::Any, - ::Any, -) = nothing - -""" - estimate_objective([rng,] obj, q, prob; kwargs...) - -Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`. - -# Arguments -- `rng::Random.AbstractRNG`: Random number generator. -- `obj::AbstractVariationalObjective`: Variational objective. -- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. -- `q`: Variational approximation. - -# Keyword Arguments -Depending on the objective, additional keyword arguments may apply. -Please refer to the respective documentation of each variational objective for more info. - -# Returns -- `obj_est`: Estimate of the objective value. -""" -function estimate_objective end - -export estimate_objective - -""" - estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state) - -Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` - -# Arguments -- `rng::Random.AbstractRNG`: Random number generator. -- `obj::AbstractVariationalObjective`: Variational objective. -- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. -- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. -- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. -- `params`: Variational parameters to evaluate the gradient on. -- `restructure`: Function that reconstructs the variational approximation from `params`. -- `obj_state`: Previous state of the objective. - -# Returns -- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates. -- `obj_state`: The updated state of the objective. -- `stat::NamedTuple`: Statistics and logs generated during estimation. -""" -function estimate_gradient! end # ELBO-specific interfaces abstract type AbstractEntropyEstimator end @@ -200,6 +122,7 @@ export RepGradELBO, ClosedFormEntropyZeroGradient, StickingTheLandingEntropyZeroGradient +include("algorithms/paramspacesgd/interface.jl") include("algorithms/paramspacesgd/entropy.jl") include("algorithms/paramspacesgd/repgradelbo.jl") include("algorithms/paramspacesgd/scoregradelbo.jl") diff --git a/src/algorithms/paramspacesgd/entropy.jl b/src/algorithms/paramspacesgd/elbo/entropy.jl similarity index 100% rename from src/algorithms/paramspacesgd/entropy.jl rename to src/algorithms/paramspacesgd/elbo/entropy.jl diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/elbo/repgradelbo.jl similarity index 100% rename from src/algorithms/paramspacesgd/repgradelbo.jl rename to src/algorithms/paramspacesgd/elbo/repgradelbo.jl diff --git a/src/algorithms/paramspacesgd/scoregradelbo.jl b/src/algorithms/paramspacesgd/elbo/scoregradelbo.jl similarity index 100% rename from src/algorithms/paramspacesgd/scoregradelbo.jl rename to src/algorithms/paramspacesgd/elbo/scoregradelbo.jl diff --git a/src/algorithms/paramspacesgd/interface.jl b/src/algorithms/paramspacesgd/interface.jl new file mode 100644 index 000000000..aee0e86b3 --- /dev/null +++ b/src/algorithms/paramspacesgd/interface.jl @@ -0,0 +1,79 @@ + +# estimators +""" + AbstractVariationalObjective + +Abstract type for the VI algorithms supported by `AdvancedVI`. + +# Implementations +To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective` and `estimate_objective`. +Also, it should provide gradients by implementing the function `estimate_gradient!`. +If the estimator is stateful, it can implement `init` to initialize the state. +""" +abstract type AbstractVariationalObjective end + +""" + init(rng, obj, adtype, prob, params, restructure) + +Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. +This function needs to be implemented only if `obj` is stateful. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +` `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. +- `params`: Initial variational parameters. +- `restructure`: Function that reconstructs the variational approximation from `λ`. +""" +init( + ::Random.AbstractRNG, + ::AbstractVariationalObjective, + ::ADTypes.AbstractADType, + ::Any, + ::Any, + ::Any, +) = nothing + +""" + estimate_objective([rng,] obj, q, prob; kwargs...) + +Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `q`: Variational approximation. + +# Keyword Arguments +Depending on the objective, additional keyword arguments may apply. +Please refer to the respective documentation of each variational objective for more info. + +# Returns +- `obj_est`: Estimate of the objective value. +""" +function estimate_objective end + +export estimate_objective + +""" + estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state) + +Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. +- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `params`: Variational parameters to evaluate the gradient on. +- `restructure`: Function that reconstructs the variational approximation from `params`. +- `obj_state`: Previous state of the objective. + +# Returns +- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `obj_state`: The updated state of the objective. +- `stat::NamedTuple`: Statistics and logs generated during estimation. +""" +function estimate_gradient! end From 7c53e63dfe9fe0729095d451c266ddd9be54ae61 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 6 Jun 2025 14:31:29 -0400 Subject: [PATCH 003/108] refactor move elbo-specific exports and interfaces to elbo.jl --- src/AdvancedVI.jl | 32 +---------------------- src/algorithms/paramspacesgd/elbo/elbo.jl | 31 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 31 deletions(-) create mode 100644 src/algorithms/paramspacesgd/elbo/elbo.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index c5dcf4994..29d304d36 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -94,38 +94,8 @@ This is an indirection for handling the type stability of `restructure`, as some """ restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params) - -# ELBO-specific interfaces -abstract type AbstractEntropyEstimator end - -""" - estimate_entropy(entropy_estimator, mc_samples, q, q_stop) - -Estimate the entropy of `q`. - -# Arguments -- `entropy_estimator`: Entropy estimation strategy. -- `q`: Variational approximation. -- `q_stop`: Variational approximation with detached from the automatic differentiation graph. -- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) - -# Returns -- `obj_est`: Estimate of the objective value. -""" -function estimate_entropy end - -export RepGradELBO, - ScoreGradELBO, - ClosedFormEntropy, - StickingTheLandingEntropy, - MonteCarloEntropy, - ClosedFormEntropyZeroGradient, - StickingTheLandingEntropyZeroGradient - include("algorithms/paramspacesgd/interface.jl") -include("algorithms/paramspacesgd/entropy.jl") -include("algorithms/paramspacesgd/repgradelbo.jl") -include("algorithms/paramspacesgd/scoregradelbo.jl") +include("algorithms/paramspacesgd/elbo/elbo.jl") # Variational Families export MvLocationScale, MeanFieldGaussian, FullRankGaussian diff --git a/src/algorithms/paramspacesgd/elbo/elbo.jl b/src/algorithms/paramspacesgd/elbo/elbo.jl new file mode 100644 index 000000000..cca96e657 --- /dev/null +++ b/src/algorithms/paramspacesgd/elbo/elbo.jl @@ -0,0 +1,31 @@ + +# ELBO-specific interfaces +abstract type AbstractEntropyEstimator end + +""" + estimate_entropy(entropy_estimator, mc_samples, q, q_stop) + +Estimate the entropy of `q`. + +# Arguments +- `entropy_estimator`: Entropy estimation strategy. +- `q`: Variational approximation. +- `q_stop`: Variational approximation with detached from the automatic differentiation graph. +- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) + +# Returns +- `obj_est`: Estimate of the objective value. +""" +function estimate_entropy end + +include("repgradelbo.jl") +include("scoregradelbo.jl") +include("entropy.jl") + +export RepGradELBO, + ScoreGradELBO, + ClosedFormEntropy, + StickingTheLandingEntropy, + MonteCarloEntropy, + ClosedFormEntropyZeroGradient, + StickingTheLandingEntropyZeroGradient From 5c71bf4d5d6cae09e34b4a6f9f68431dbbf6b7c8 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 10 Jun 2025 18:07:36 -0400 Subject: [PATCH 004/108] add `step` for `ParamSpaceSGD` --- src/AdvancedVI.jl | 27 +++- src/algorithms/paramspacesgd/paramspacesgd.jl | 115 ++++++++++++++++ src/optimize.jl | 129 ++++++------------ src/utils.jl | 34 ----- 4 files changed, 182 insertions(+), 123 deletions(-) create mode 100644 src/algorithms/paramspacesgd/paramspacesgd.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 29d304d36..a0b24ae77 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -94,9 +94,6 @@ This is an indirection for handling the type stability of `restructure`, as some """ restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params) -include("algorithms/paramspacesgd/interface.jl") -include("algorithms/paramspacesgd/elbo/elbo.jl") - # Variational Families export MvLocationScale, MeanFieldGaussian, FullRankGaussian @@ -188,6 +185,30 @@ include("optimization/proximal_location_scale_entropy.jl") export IdentityOperator, ClipScale, ProximalLocationScaleEntropy +# Algorithms + +abstract type AbstractAlgorithm end + +""" + init(rng, alg, q_init) + +Initialize `alg` given the initial variational approximation `q_init`. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `alg::AbstractAlgorithm`: Variational objective. +` `q_init`: Initial variational approximation. +""" +init(::Random.AbstractRNG, ::AbstractAlgorithm, ::Any) = nothing + +include("algorithms/paramspacesgd/interface.jl") +include("algorithms/paramspacesgd/paramspacesgd.jl") + +export ParamSpaceSGD + +# Parameter Space SGD Implementations +include("algorithms/paramspacesgd/elbo/elbo.jl") + # Main optimization routine function optimize end diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl new file mode 100644 index 000000000..b316e546c --- /dev/null +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -0,0 +1,115 @@ + +""" + ParamSpaceSGD( + problem, + objective::AbstractVariationalObjective, + adtype::ADTypes.AbstractADType, + optimizer::Optimisers.AbstractRule, + averager::AbstractAverager, + operator::AbstractOperator, + ) + +This algorithm applies stochastic gradient descent (SGD) to the variational `objective` over the (Euclidean) space of variational parameters. + +The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. +This requires the variational approximation to be marked as a functor through `Functors.@functor`. + +# Arguments +- `adtype`: Automatic differentiation backend. +- `objective`: Variational Objective. +- `optimizer`: Optimizer used for inference. +- `averager` : Parameter averaging strategy. +- `operator` : Operator applied to the parameters after each optimization step. + +# Output +- `q_averaged`: The variational approximation formed from the averaged SGD iterates. + +# Callback +The callback function `callback` has a signature of + + callback(; rng, state, params, averaged_params, restructure, gradient) + +The arguments are as follows: +- `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`. +- `state`: Collection of the internal states used for optimization. +- `params`: Variational parameters. +- `averaged_params`: Variational parameters averaged according to the averaging strategy. +- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation. +- `gradient`: The estimated (possibly stochastic) gradient. + +""" +struct ParamSpaceSGD{Prob, + Obj<:AbstractVariationalObjective, + AD<:ADTypes.AbstractADType, + Opt<:Optimisers.AbstractRule, + Avg<:AbstractAverager, + Op <:AbstractOperator, +} <: AbstractAlgorithm + problem::Prob + objective::Obj + adtype::AD + optimizer::Opt + averager::Avg + operator::Op +end + +struct ParamSpaceSGDState{Q, GradBuf, OptSt, ObjSt, AvgSt} + q::Q + iteration::Int + grad_buf::GradBuf + opt_st::OptSt + obj_st::ObjSt + avg_st::AvgSt +end + +function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init) + (; problem, adtype, optimizer, averager, objective) = alg + params, re = Optimisers.destructure(q_init) + opt_st = Optimisers.setup(optimizer, params) + obj_st = init(rng, objective, adtype, problem, params, re) + avg_st = init(averager, params) + grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) + return ParamSpaceSGDState(q_init, 0, grad_buf, opt_st, obj_st, avg_st) +end + +function output(alg::ParamSpaceSGD, state) + params_avg = value(alg.averager, state.avg_st) + _, re = Optimisers.destructure(state.q) + return re(params_avg) +end + +function step( + rng::Random.AbstractRNG, alg::ParamSpaceSGD, state, callback, objargs...; kwargs... +) + (; adtype, problem, objective, operator, averager) = alg + (; q, iteration, grad_buf, opt_st, obj_st, avg_st) = state + + iteration += 1 + + params, re = Optimisers.destructure(q) + + grad_buf, obj_st, info = estimate_gradient!( + rng, objective, adtype, grad_buf, problem, params, re, obj_st, objargs... + ) + + grad = DiffResults.gradient(grad_buf) + opt_st, params = Optimisers.update!(opt_st, params, grad) + params = apply(operator, typeof(q), opt_st, params, re) + avg_st = apply(averager, avg_st, params) + + state = ParamSpaceSGDState(re(params), iteration, grad_buf, opt_st, obj_st, avg_st) + + if !isnothing(callback) + averaged_params = value(averager, avg_st) + info′ = callback(; + stat, + re, + params=params, + averaged_params=averaged_params, + gradient=grad, + state=state, + ) + info = !isnothing(info′) ? merge(info′, info) : info + end + state, false, info +end diff --git a/src/optimize.jl b/src/optimize.jl index 1be3181ae..fbf5aadc0 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -1,6 +1,14 @@ """ - optimize(problem, objective, q_init, max_iter, objargs...; kwargs...) + optimize( + [rng::Random.AbstractRNG = Random.default_rng(),] + problem + algorithm::AbstractAlgorithm, + q_init, + max_iter::Int, + objargs...; + kwargs... + ) Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients. @@ -8,126 +16,75 @@ The trainable parameters in the variational approximation are expected to be ext This requires the variational approximation to be marked as a functor through `Functors.@functor`. # Arguments -- `objective::AbstractVariationalObjective`: Variational Objective. -- `q_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`. +- `rng`: Random number generator. +- `algorithm`: Variational inference algorithm. +- `q_init`: Initial variational distribution. - `max_iter::Int`: Maximum number of iterations. - `objargs...`: Arguments to be passed to `objective`. # Keyword Arguments -- `adtype::ADtypes.AbstractADType`: Automatic differentiation backend. -- `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.) -- `averager::AbstractAverager` : Parameter averaging strategy. (Default: `NoAveraging()`) -- `operator::AbstractOperator` : Operator applied to the parameters after each optimization step. (Default: `IdentityOperator()`) -- `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.) - `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.) - `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.) -- `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) -- `state::NamedTuple`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) +- `progress::ProgressMeter.AbstractProgress`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) +- `state::Union{<:Any,Nothing}`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) # Returns -- `averaged_params`: Variational parameters generated by the algorithm averaged according to `averager`. -- `params`: Last variational parameters generated by the algorithm. -- `stats`: Statistics gathered during optimization. +- `output`: The output of the variational inference algorithm. +- `info`: Array of `NamedTuple`s, where each `NamedTuple` contains information generated at each iteration. - `state`: Collection of the final internal states of optimization. This can used later to warm-start from the last iteration of the corresponding run. # Callback -The callback function `callback` has a signature of - - callback(; stat, state, params, averaged_params, restructure, gradient) - -The arguments are as follows: -- `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`. -- `state`: Collection of the internal states used for optimization. -- `params`: Variational parameters. -- `averaged_params`: Variational parameters averaged according to the averaging strategy. -- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation. -- `gradient`: The estimated (possibly stochastic) gradient. - -`callback` can return a `NamedTuple` containing some additional information computed within `cb`. -This will be appended to the statistic of the current corresponding iteration. -Otherwise, just return `nothing`. - +The signature of the callback function depends on the `algorithm` in use. +Thus, see the documentation for each `algorithm`. +However, a callback should return either a `nothing` or a `NamedTuple` containing information generated during the current iteration. +The content of the `NamedTuple` will be concatenated into the corresponding entry in the `info` array returns in the end of the call to `optimize` and will be displayed on the progress meter. """ function optimize( rng::Random.AbstractRNG, - problem, - objective::AbstractVariationalObjective, - q_init, + algorithm::AbstractAlgorithm, max_iter::Int, + q_init, objargs...; - adtype::ADTypes.AbstractADType, - optimizer::Optimisers.AbstractRule=Optimisers.Adam(), - averager::AbstractAverager=NoAveraging(), - operator=IdentityOperator(), show_progress::Bool=true, - state_init::NamedTuple=NamedTuple(), + state::Union{<:Any,Nothing}=nothing, callback=nothing, - prog=ProgressMeter.Progress( + progress::ProgressMeter.AbstractProgress=ProgressMeter.Progress( max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=show_progress ), + kwargs... ) - params, restructure = Optimisers.destructure(deepcopy(q_init)) - opt_st = maybe_init_optimizer(state_init, optimizer, params) - obj_st = maybe_init_objective( - state_init, rng, objective, adtype, problem, params, restructure - ) - avg_st = maybe_init_averager(state_init, averager, params) - grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - stats = NamedTuple[] + info_total = NamedTuple[] + state = if isnothing(state) + init(rng, algorithm, q_init) + else + state + end for t in 1:max_iter - stat = (iteration=t,) - grad_buf, obj_st, stat′ = estimate_gradient!( - rng, - objective, - adtype, - grad_buf, - problem, - params, - restructure, - obj_st, - objargs..., - ) - stat = merge(stat, stat′) + info = (iteration=t,) - grad = DiffResults.gradient(grad_buf) - opt_st, params = Optimisers.update!(opt_st, params, grad) - params = apply(operator, typeof(q_init), opt_st, params, restructure) - avg_st = apply(averager, avg_st, params) + state, terminate, info′ = step(rng, algorithm, state, callback, objargs...; kwargs...) + info = merge(info′, info) - if !isnothing(callback) - averaged_params = value(averager, avg_st) - stat′ = callback(; - stat, - restructure, - params=params, - averaged_params=averaged_params, - gradient=grad, - state=(optimizer=opt_st, averager=avg_st, objective=obj_st), - ) - stat = !isnothing(stat′) ? merge(stat′, stat) : stat + if terminate + break end - @debug "Iteration $t" stat... - - pm_next!(prog, stat) - push!(stats, stat) + pm_next!(progress, info) + push!(info_total, info) end - state = (optimizer=opt_st, averager=avg_st, objective=obj_st) - stats = map(identity, stats) - averaged_params = value(averager, avg_st) - return restructure(averaged_params), restructure(params), stats, state + out = output(algorithm, state) + return out, map(identity, info_total), state end function optimize( - problem, - objective::AbstractVariationalObjective, - q_init, + algorithm::AbstractAlgorithm, max_iter::Int, + q_init, objargs...; kwargs..., ) return optimize( - Random.default_rng(), problem, objective, q_init, max_iter, objargs...; kwargs... + Random.default_rng(), algorithm, max_iter, q_init, objargs...; kwargs... ) end diff --git a/src/utils.jl b/src/utils.jl index 7fbe4424c..6481d68d5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,40 +3,6 @@ function pm_next!(pm, stats::NamedTuple) return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) end -function maybe_init_optimizer( - state_init::NamedTuple, optimizer::Optimisers.AbstractRule, params -) - if haskey(state_init, :optimizer) - state_init.optimizer - else - Optimisers.setup(optimizer, params) - end -end - -function maybe_init_averager(state_init::NamedTuple, averager::AbstractAverager, params) - if haskey(state_init, :averager) - state_init.averager - else - init(averager, params) - end -end - -function maybe_init_objective( - state_init::NamedTuple, - rng::Random.AbstractRNG, - objective::AbstractVariationalObjective, - adtype::ADTypes.AbstractADType, - problem, - params, - restructure, -) - if haskey(state_init, :objective) - state_init.objective - else - init(rng, objective, adtype, problem, params, restructure) - end -end - eachsample(samples::AbstractMatrix) = eachcol(samples) function catsamples_and_acc( From f9560e74726f0dcb8521f88f00c2d703de0a3c31 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 10 Jun 2025 18:07:48 -0400 Subject: [PATCH 005/108] increment version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8190b03ec..6ef7a08c6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedVI" uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.4.1" +version = "0.5.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From b887cd2bc6d9ecf47511127fd766a26b3a239d95 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 18 Jun 2025 15:33:59 -0400 Subject: [PATCH 006/108] fix subtyping error on ClipScale --- src/optimization/clip_scale.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/optimization/clip_scale.jl b/src/optimization/clip_scale.jl index 5a4c52f60..e8a34e79e 100644 --- a/src/optimization/clip_scale.jl +++ b/src/optimization/clip_scale.jl @@ -5,10 +5,12 @@ Projection operator ensuring that an `MvLocationScale` or `MvLocationScaleLowRank` has a scale with eigenvalues larger than `ϵ`. `ClipScale` also supports by operating on `MvLocationScale` and `MvLocationScaleLowRank` wrapped by a `Bijectors.TransformedDistribution` object. """ -Optimisers.@def struct ClipScale <: AbstractOperator - epsilon = 1e-5 +struct ClipScale <: AbstractOperator + epsilon end +ClipScale() = ClipScale(1e-5) + function apply(::ClipScale, family::Type, state, params, restructure) return error("`ClipScale` is not defined for the variational family of type $(family).") end From 64848fabd3e61cc01251a042ff7db140921a7949 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 18 Jun 2025 15:34:16 -0400 Subject: [PATCH 007/108] fix signature of `callback` in `ParamSpaceSGD` --- src/algorithms/paramspacesgd/paramspacesgd.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index b316e546c..d0540459c 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -27,7 +27,7 @@ This requires the variational approximation to be marked as a functor through `F # Callback The callback function `callback` has a signature of - callback(; rng, state, params, averaged_params, restructure, gradient) + callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient) The arguments are as follows: - `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`. @@ -102,8 +102,9 @@ function step( if !isnothing(callback) averaged_params = value(averager, avg_st) info′ = callback(; - stat, - re, + rng, + iteration, + restructure=re, params=params, averaged_params=averaged_params, gradient=grad, From 116e22c33a3905c280957ca6fb707de5828defa7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 18 Jun 2025 15:51:53 -0400 Subject: [PATCH 008/108] fix tests to match new interface --- test/inference/repgradelbo_distributionsad.jl | 45 ++++----------- test/inference/repgradelbo_locationscale.jl | 49 ++++------------ .../repgradelbo_locationscale_bijectors.jl | 39 ++----------- .../repgradelbo_proximal_locationscale.jl | 56 +++++-------------- ...adelbo_proximal_locationscale_bijectors.jl | 42 ++------------ .../scoregradelbo_distributionsad.jl | 36 ++---------- test/inference/scoregradelbo_locationscale.jl | 49 ++++------------ .../scoregradelbo_locationscale_bijectors.jl | 39 ++----------- test/interface/clip_scale.jl | 8 +-- test/interface/optimize.jl | 42 ++++---------- .../proximal_location_scale_entropy.jl | 2 +- test/interface/repgradelbo.jl | 18 +++--- test/interface/scoregradelbo.jl | 18 +++--- 13 files changed, 99 insertions(+), 344 deletions(-) diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index eb367b696..a5b8f884b 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -32,31 +32,26 @@ end modelstats = modelconstr(rng, realtype) (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + μ0 = zeros(realtype, n_dims) + L0 = Diagonal(ones(realtype, n_dims)) + q0 = TuringDiagMvNormal(μ0, diag(L0)) + T = 1000 η = 1e-3 opt = Optimisers.Descent(η) + avg = PolynomialAveraging() + op = IdentityOperator() + alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), # where ρ = 1 - ημ, μ is the strong convexity constant. contraction_rate = 1 - η * strong_convexity - μ0 = zeros(realtype, n_dims) - L0 = Diagonal(ones(realtype, n_dims)) - q0 = TuringDiagMvNormal(μ0, diag(L0)) - @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0, - T; - optimizer=opt, - show_progress=PROGRESS, - adtype=adtype, - ) + + q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) μ = mean(q_avg) L = sqrt(cov(q_avg)) @@ -69,30 +64,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0, - T; - optimizer=opt, - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0;show_progress=PROGRESS) μ = mean(q_avg) L = sqrt(cov(q_avg)) rng_repl = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng_repl, - model, - objective, - q0, - T; - optimizer=opt, - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng_repl, alg, T, q0; show_progress=PROGRESS) μ_repl = mean(q_avg) L_repl = sqrt(cov(q_avg)) @test μ == μ_repl diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index e0e7476b5..fb3b9059a 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -35,11 +35,9 @@ end T = 1000 η = 1e-3 opt = Optimisers.Descent(η) - - # For small enough η, the error of SGD, Δλ, is bounded as - # Δλ ≤ ρ^T Δλ0 + O(η), - # where ρ = 1 - ημ, μ is the strong convexity constant. - contraction_rate = 1 - η * strong_convexity + avg = PolynomialAveraging() + op = ClipScale() + alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) @@ -48,19 +46,14 @@ end FullRankGaussian(zeros(realtype, n_dims), L0) end + # For small enough η, the error of SGD, Δλ, is bounded as + # Δλ ≤ ρ^T Δλ0 + O(η), + # where ρ = 1 - ημ, μ is the strong convexity constant. + contraction_rate = 1 - η * strong_convexity + @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale @@ -73,32 +66,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng_repl, - model, - objective, - q0, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng_repl, alg, T, q0; show_progress=PROGRESS) μ_repl = q_avg.location L_repl = q_avg.scale @test μ == μ_repl diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 24edf5678..a6a676e1f 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -35,6 +35,9 @@ end T = 1000 η = 1e-3 opt = Optimisers.Descent(η) + op = ClipScale() + avg = PolynomialAveraging() + alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) b = Bijectors.bijector(model) b⁻¹ = inverse(b) @@ -56,17 +59,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0_z, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale @@ -79,32 +72,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0_z, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale rng_repl = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng_repl, - model, - objective, - q0_z, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng_repl, alg, T, q0_z; show_progress=PROGRESS) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale @test μ == μ_repl diff --git a/test/inference/repgradelbo_proximal_locationscale.jl b/test/inference/repgradelbo_proximal_locationscale.jl index 620cceca7..8a0693288 100644 --- a/test/inference/repgradelbo_proximal_locationscale.jl +++ b/test/inference/repgradelbo_proximal_locationscale.jl @@ -34,36 +34,28 @@ end modelstats = modelconstr(rng, realtype) (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + q0 = if is_meanfield + MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) + else + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) + FullRankGaussian(zeros(realtype, n_dims), L0) + end + T = 1000 η = 1e-3 opt = DoWG(1.0) + avg = PolynomialAveraging() + op = ProximalLocationScaleEntropy() + alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), # where ρ = 1 - ημ, μ is the strong convexity constant. contraction_rate = 1 - η * strong_convexity - q0 = if is_meanfield - MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) - else - L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) - FullRankGaussian(zeros(realtype, n_dims), L0) - end - @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0, - T; - optimizer=opt, - averager=PolynomialAveraging(), - operator=ProximalLocationScaleEntropy(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale @@ -76,34 +68,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0, - T; - optimizer=opt, - averager=PolynomialAveraging(), - operator=ProximalLocationScaleEntropy(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng_repl, - model, - objective, - q0, - T; - optimizer=opt, - averager=PolynomialAveraging(), - operator=ProximalLocationScaleEntropy(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng_repl, alg, T, q0; show_progress=PROGRESS) μ_repl = q_avg.location L_repl = q_avg.scale @test μ == μ_repl diff --git a/test/inference/repgradelbo_proximal_locationscale_bijectors.jl b/test/inference/repgradelbo_proximal_locationscale_bijectors.jl index 169113571..ebea55c11 100644 --- a/test/inference/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_proximal_locationscale_bijectors.jl @@ -36,6 +36,9 @@ end T = 1000 η = 1e-3 opt = DoWG(1.0) + op = ProximalLocationScaleEntropy() + avg = PolynomialAveraging() + alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) b = Bijectors.bijector(model) b⁻¹ = inverse(b) @@ -57,18 +60,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0_z, - T; - optimizer=opt, - averager=PolynomialAveraging(), - operator=ProximalLocationScaleEntropy(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale @@ -81,34 +73,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0_z, - T; - optimizer=opt, - averager=PolynomialAveraging(), - operator=ProximalLocationScaleEntropy(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale rng_repl = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng_repl, - model, - objective, - q0_z, - T; - optimizer=opt, - averager=PolynomialAveraging(), - operator=ProximalLocationScaleEntropy(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng_repl, alg, T, q0_z; show_progress=PROGRESS) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale @test μ == μ_repl diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/inference/scoregradelbo_distributionsad.jl index e58d479f6..81cdc13d8 100644 --- a/test/inference/scoregradelbo_distributionsad.jl +++ b/test/inference/scoregradelbo_distributionsad.jl @@ -30,6 +30,9 @@ end T = 1000 η = 1e-4 opt = Optimisers.Descent(η) + op = IdentityOperator() + avg = PolynomialAveraging() + alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), @@ -42,16 +45,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0, - T; - optimizer=opt, - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) μ = mean(q_avg) L = sqrt(cov(q_avg)) @@ -64,30 +58,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0, - T; - optimizer=opt, - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) μ = mean(q_avg) L = sqrt(cov(q_avg)) rng_repl = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng_repl, - model, - objective, - q0, - T; - optimizer=opt, - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng_repl, alg, T, q0; show_progress=PROGRESS) μ_repl = mean(q_avg) L_repl = sqrt(cov(q_avg)) @test μ ≈ μ_repl rtol = 1e-5 diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/inference/scoregradelbo_locationscale.jl index 12f8756d7..6ec738905 100644 --- a/test/inference/scoregradelbo_locationscale.jl +++ b/test/inference/scoregradelbo_locationscale.jl @@ -32,11 +32,9 @@ end T = 1000 η = 1e-4 opt = Optimisers.Descent(η) - - # For small enough η, the error of SGD, Δλ, is bounded as - # Δλ ≤ ρ^T Δλ0 + O(η), - # where ρ = 1 - ημ, μ is the strong convexity constant. - contraction_rate = 1 - η * strong_convexity + op = ClipScale() + avg = PolynomialAveraging() + alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) @@ -45,19 +43,14 @@ end FullRankGaussian(zeros(realtype, n_dims), L0) end + # For small enough η, the error of SGD, Δλ, is bounded as + # Δλ ≤ ρ^T Δλ0 + O(η), + # where ρ = 1 - ημ, μ is the strong convexity constant. + contraction_rate = 1 - η * strong_convexity + @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale @@ -70,32 +63,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng_repl, - model, - objective, - q0, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng_repl, alg, T, q0; show_progress=PROGRESS) μ_repl = q_avg.location L_repl = q_avg.scale @test μ ≈ μ_repl rtol = 1e-3 diff --git a/test/inference/scoregradelbo_locationscale_bijectors.jl b/test/inference/scoregradelbo_locationscale_bijectors.jl index 734a7dbe7..5d31e6e92 100644 --- a/test/inference/scoregradelbo_locationscale_bijectors.jl +++ b/test/inference/scoregradelbo_locationscale_bijectors.jl @@ -31,6 +31,9 @@ end T = 1000 η = 1e-4 opt = Optimisers.Descent(η) + op = ClipScale() + avg = PolynomialAveraging() + alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) b = Bijectors.bijector(model) b⁻¹ = inverse(b) @@ -52,17 +55,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0_z, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale @@ -75,32 +68,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng, - model, - objective, - q0_z, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale rng_repl = StableRNG(seed) - q_avg, _, stats, _ = optimize( - rng_repl, - model, - objective, - q0_z, - T; - optimizer=opt, - operator=ClipScale(), - show_progress=PROGRESS, - adtype=adtype, - ) + q_avg, stats, _ = optimize(rng_repl, alg, T, q0_z; show_progress=PROGRESS) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale @test μ ≈ μ_repl rtol = 1e-3 diff --git a/test/interface/clip_scale.jl b/test/interface/clip_scale.jl index e2679f5fa..42582e428 100644 --- a/test/interface/clip_scale.jl +++ b/test/interface/clip_scale.jl @@ -23,9 +23,7 @@ end params, re = Optimisers.destructure(q) - opt_st = AdvancedVI.maybe_init_optimizer( - NamedTuple(), Optimisers.Descent(1e-2), params - ) + opt_st = Optimisers.setup(Descent(1e-2), params) params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), opt_st, params, re) q′ = re(params′) @@ -57,9 +55,7 @@ end params, re = Optimisers.destructure(q) - opt_st = AdvancedVI.maybe_init_optimizer( - NamedTuple(), Optimisers.Descent(1e-2), params - ) + opt_st = Optimisers.setup(Descent(1e-2), params) params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), opt_st, params, re) q′ = re(params′) diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl index 70944004c..a82c3cde4 100644 --- a/test/interface/optimize.jl +++ b/test/interface/optimize.jl @@ -1,6 +1,4 @@ -using Test - @testset "interface optimize" begin seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -10,35 +8,30 @@ using Test (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats - # Global Test Configurations - q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims))) obj = RepGradELBO(10) adtype = AutoForwardDiff() optimizer = Optimisers.Adam(1e-2) averager = PolynomialAveraging() + alg = ParamSpaceSGD(model, obj, adtype, optimizer, averager, IdentityOperator()) + @testset "default_rng" begin - optimize(model, obj, q0, T; optimizer, averager, show_progress=false, adtype) + optimize(alg, T, q0; show_progress=false) end @testset "callback" begin - rng = StableRNG(seed) test_values = rand(rng, T) - callback(; stat, args...) = (test_value=test_values[stat.iteration],) + callback(; iteration, args...) = (test_value=test_values[iteration],) - rng = StableRNG(seed) - _, _, stats, _ = optimize( - rng, model, obj, q0, T; show_progress=false, adtype, callback - ) - @test [stat.test_value for stat in stats] == test_values + _, info, _ = optimize(alg, T, q0; show_progress=false, callback) + @test [i.test_value for i in info] == test_values end rng = StableRNG(seed) - q_avg_ref, q_ref, _, _ = optimize( - rng, model, obj, q0, T; optimizer, averager, show_progress=false, adtype - ) + q_avg_ref, _, _ = optimize(rng, alg, T, q0; show_progress=false) @testset "warm start" begin rng = StableRNG(seed) @@ -46,24 +39,9 @@ using Test T_first = div(T, 2) T_last = T - T_first - _, q_first, _, state = optimize( - rng, model, obj, q0, T_first; optimizer, averager, show_progress=false, adtype - ) - - q_avg, q, _, _ = optimize( - rng, - model, - obj, - q_first, - T_last; - optimizer, - averager, - show_progress=false, - state_init=state, - adtype, - ) + _, _, state = optimize(rng, alg, T_first, q0; show_progress=false) + q_avg, _, _ = optimize(rng, alg, T_last, q0; show_progress=false, state) - @test q == q_ref @test q_avg == q_avg_ref end end diff --git a/test/interface/proximal_location_scale_entropy.jl b/test/interface/proximal_location_scale_entropy.jl index 2f53ec858..ddbbf7300 100644 --- a/test/interface/proximal_location_scale_entropy.jl +++ b/test/interface/proximal_location_scale_entropy.jl @@ -43,7 +43,7 @@ # This unit test will check that this equation is satisfied. params, re = Optimisers.destructure(q) - opt_st = AdvancedVI.maybe_init_optimizer(NamedTuple(), optimizer, params) + opt_st = Optimisers.setup(optimizer, params) params′ = AdvancedVI.apply( ProximalLocationScaleEntropy(), typeof(q), opt_st, params, re ) diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 9a5cc4607..976b389de 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -27,18 +27,16 @@ end @testset "basic" begin @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] - obj = RepGradELBO(n_montecarlo) - _, _, stats, _ = optimize( - rng, + alg = ParamSpaceSGD( model, - obj, - q0, - 10; - optimizer=Descent(1e-5), - show_progress=false, - adtype=adtype, + RepGradELBO(n_montecarlo), + adtype, + Descent(1e-5), + PolynomialAveraging(), + IdentityOperator() ) - @assert isfinite(last(stats).elbo) + _, info, _ = optimize(rng, alg, 10, q0; show_progress=false) + @assert isfinite(last(info).elbo) end end diff --git a/test/interface/scoregradelbo.jl b/test/interface/scoregradelbo.jl index f368626e7..6867bcfe8 100644 --- a/test/interface/scoregradelbo.jl +++ b/test/interface/scoregradelbo.jl @@ -22,18 +22,16 @@ end @testset "basic" begin @testset for adtype in AD_scoregradelbo_interface, n_montecarlo in [1, 10] - obj = ScoreGradELBO(n_montecarlo) - _, _, stats, _ = optimize( - rng, + alg = ParamSpaceSGD( model, - obj, - q0, - 10; - optimizer=Descent(1e-5), - show_progress=false, - adtype=adtype, + ScoreGradELBO(n_montecarlo), + adtype, + Descent(1e-5), + PolynomialAveraging(), + IdentityOperator() ) - @assert isfinite(last(stats).elbo) + _, info, _ = optimize(rng, alg, 10, q0; show_progress=false) + @assert isfinite(last(info).elbo) end end From 705748bb6c08568fd02f190c0f1e4668ea17398c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 18 Jun 2025 15:58:31 -0400 Subject: [PATCH 009/108] refactor restructure `test/` to match new structure in `src/` --- .../paramspacesgd}/repgradelbo.jl | 0 .../repgradelbo_distributionsad.jl | 0 .../repgradelbo_locationscale.jl | 0 .../repgradelbo_locationscale_bijectors.jl | 0 .../repgradelbo_proximal_locationscale.jl | 0 ...adelbo_proximal_locationscale_bijectors.jl | 0 .../paramspacesgd}/scoregradelbo.jl | 0 .../scoregradelbo_distributionsad.jl | 0 .../scoregradelbo_locationscale.jl | 0 .../scoregradelbo_locationscale_bijectors.jl | 0 test/{interface => general}/ad.jl | 0 test/{interface => general}/averaging.jl | 0 test/{interface => general}/clip_scale.jl | 0 test/{interface => general}/optimize.jl | 0 .../proximal_location_scale_entropy.jl | 0 test/{interface => general}/rules.jl | 0 test/runtests.jl | 39 +++++++++---------- 17 files changed, 19 insertions(+), 20 deletions(-) rename test/{interface => algorithms/paramspacesgd}/repgradelbo.jl (100%) rename test/{inference => algorithms/paramspacesgd}/repgradelbo_distributionsad.jl (100%) rename test/{inference => algorithms/paramspacesgd}/repgradelbo_locationscale.jl (100%) rename test/{inference => algorithms/paramspacesgd}/repgradelbo_locationscale_bijectors.jl (100%) rename test/{inference => algorithms/paramspacesgd}/repgradelbo_proximal_locationscale.jl (100%) rename test/{inference => algorithms/paramspacesgd}/repgradelbo_proximal_locationscale_bijectors.jl (100%) rename test/{interface => algorithms/paramspacesgd}/scoregradelbo.jl (100%) rename test/{inference => algorithms/paramspacesgd}/scoregradelbo_distributionsad.jl (100%) rename test/{inference => algorithms/paramspacesgd}/scoregradelbo_locationscale.jl (100%) rename test/{inference => algorithms/paramspacesgd}/scoregradelbo_locationscale_bijectors.jl (100%) rename test/{interface => general}/ad.jl (100%) rename test/{interface => general}/averaging.jl (100%) rename test/{interface => general}/clip_scale.jl (100%) rename test/{interface => general}/optimize.jl (100%) rename test/{interface => general}/proximal_location_scale_entropy.jl (100%) rename test/{interface => general}/rules.jl (100%) diff --git a/test/interface/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl similarity index 100% rename from test/interface/repgradelbo.jl rename to test/algorithms/paramspacesgd/repgradelbo.jl diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl similarity index 100% rename from test/inference/repgradelbo_distributionsad.jl rename to test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl diff --git a/test/inference/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl similarity index 100% rename from test/inference/repgradelbo_locationscale.jl rename to test/algorithms/paramspacesgd/repgradelbo_locationscale.jl diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl similarity index 100% rename from test/inference/repgradelbo_locationscale_bijectors.jl rename to test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl diff --git a/test/inference/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl similarity index 100% rename from test/inference/repgradelbo_proximal_locationscale.jl rename to test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl diff --git a/test/inference/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl similarity index 100% rename from test/inference/repgradelbo_proximal_locationscale_bijectors.jl rename to test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl diff --git a/test/interface/scoregradelbo.jl b/test/algorithms/paramspacesgd/scoregradelbo.jl similarity index 100% rename from test/interface/scoregradelbo.jl rename to test/algorithms/paramspacesgd/scoregradelbo.jl diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl similarity index 100% rename from test/inference/scoregradelbo_distributionsad.jl rename to test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl similarity index 100% rename from test/inference/scoregradelbo_locationscale.jl rename to test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl diff --git a/test/inference/scoregradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl similarity index 100% rename from test/inference/scoregradelbo_locationscale_bijectors.jl rename to test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl diff --git a/test/interface/ad.jl b/test/general/ad.jl similarity index 100% rename from test/interface/ad.jl rename to test/general/ad.jl diff --git a/test/interface/averaging.jl b/test/general/averaging.jl similarity index 100% rename from test/interface/averaging.jl rename to test/general/averaging.jl diff --git a/test/interface/clip_scale.jl b/test/general/clip_scale.jl similarity index 100% rename from test/interface/clip_scale.jl rename to test/general/clip_scale.jl diff --git a/test/interface/optimize.jl b/test/general/optimize.jl similarity index 100% rename from test/interface/optimize.jl rename to test/general/optimize.jl diff --git a/test/interface/proximal_location_scale_entropy.jl b/test/general/proximal_location_scale_entropy.jl similarity index 100% rename from test/interface/proximal_location_scale_entropy.jl rename to test/general/proximal_location_scale_entropy.jl diff --git a/test/interface/rules.jl b/test/general/rules.jl similarity index 100% rename from test/interface/rules.jl rename to test/general/rules.jl diff --git a/test/runtests.jl b/test/runtests.jl index 495af19c1..5778ac7c2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,7 @@ using ForwardDiff, ReverseDiff, Zygote, Mooncake using AdvancedVI +const PROGRESS = haskey(ENV, "PROGRESS") const TEST_GROUP = get(ENV, "TEST_GROUP", "All") if TEST_GROUP == "Enzyme" @@ -43,20 +44,18 @@ end include("models/normal.jl") include("models/normallognormal.jl") -if TEST_GROUP == "All" || TEST_GROUP == "Interface" +if TEST_GROUP == "All" || TEST_GROUP == "General" # Interface tests that do not involve testing on Enzyme - include("interface/optimize.jl") - include("interface/rules.jl") - include("interface/averaging.jl") - include("interface/scoregradelbo.jl") - include("interface/clip_scale.jl") - include("interface/proximal_location_scale_entropy.jl") + include("general/optimize.jl") + include("general/rules.jl") + include("general/averaging.jl") + include("general/clip_scale.jl") + include("general/proximal_location_scale_entropy.jl") end -if TEST_GROUP == "All" || TEST_GROUP == "Interface" || TEST_GROUP == "Enzyme" +if TEST_GROUP == "All" || TEST_GROUP == "General" || TEST_GROUP == "Enzyme" # Interface tests that involve testing on Enzyme include("interface/ad.jl") - include("interface/repgradelbo.jl") end if TEST_GROUP == "All" || TEST_GROUP == "Families" @@ -64,15 +63,15 @@ if TEST_GROUP == "All" || TEST_GROUP == "Families" include("families/location_scale_low_rank.jl") end -const PROGRESS = haskey(ENV, "PROGRESS") - -if TEST_GROUP == "All" || TEST_GROUP == "Inference" || TEST_GROUP == "Enzyme" - include("inference/repgradelbo_distributionsad.jl") - include("inference/repgradelbo_locationscale.jl") - include("inference/repgradelbo_locationscale_bijectors.jl") - include("inference/repgradelbo_proximal_locationscale.jl") - include("inference/repgradelbo_proximal_locationscale_bijectors.jl") - include("inference/scoregradelbo_distributionsad.jl") - include("inference/scoregradelbo_locationscale.jl") - include("inference/scoregradelbo_locationscale_bijectors.jl") +if TEST_GROUP == "All" || TEST_GROUP == "ParamSpaceSGD" || TEST_GROUP == "Enzyme" + include("algorithms/paramspacesgd/repgradelbo.jl") + include("algorithms/paramspacesgd/scoregradelbo.jl") + include("algorithms/paramspacesgd/repgradelbo_distributionsad.jl") + include("algorithms/paramspacesgd/repgradelbo_locationscale.jl") + include("algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl") + include("algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl") + include("algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl") + include("algorithms/paramspacesgd/scoregradelbo_distributionsad.jl") + include("algorithms/paramspacesgd/scoregradelbo_locationscale.jl") + include("algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl") end From 29f9109127ac5a34cd838db6dfa9b3e200cd2cf0 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 18 Jun 2025 15:59:28 -0400 Subject: [PATCH 010/108] run formatter --- .../paramspacesgd/elbo/repgradelbo.jl | 2 +- src/algorithms/paramspacesgd/interface.jl | 6 ++++-- src/algorithms/paramspacesgd/paramspacesgd.jl | 21 ++++++++++--------- src/optimize.jl | 12 +++++------ test/algorithms/paramspacesgd/repgradelbo.jl | 2 +- .../repgradelbo_distributionsad.jl | 4 ++-- .../repgradelbo_locationscale_bijectors.jl | 2 +- ...adelbo_proximal_locationscale_bijectors.jl | 2 +- .../algorithms/paramspacesgd/scoregradelbo.jl | 2 +- .../scoregradelbo_distributionsad.jl | 2 +- test/general/optimize.jl | 2 +- 11 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/algorithms/paramspacesgd/elbo/repgradelbo.jl b/src/algorithms/paramspacesgd/elbo/repgradelbo.jl index c6aa9135c..0ff3b24b7 100644 --- a/src/algorithms/paramspacesgd/elbo/repgradelbo.jl +++ b/src/algorithms/paramspacesgd/elbo/repgradelbo.jl @@ -158,6 +158,6 @@ function estimate_gradient!( estimate_repgradelbo_ad_forward, out, prep, adtype, params, aux ) nelbo = DiffResults.value(out) - stat = (elbo=-nelbo,) + stat = (elbo=(-nelbo),) return out, state, stat end diff --git a/src/algorithms/paramspacesgd/interface.jl b/src/algorithms/paramspacesgd/interface.jl index aee0e86b3..b67355c8a 100644 --- a/src/algorithms/paramspacesgd/interface.jl +++ b/src/algorithms/paramspacesgd/interface.jl @@ -25,14 +25,16 @@ This function needs to be implemented only if `obj` is stateful. - `params`: Initial variational parameters. - `restructure`: Function that reconstructs the variational approximation from `λ`. """ -init( +function init( ::Random.AbstractRNG, ::AbstractVariationalObjective, ::ADTypes.AbstractADType, ::Any, ::Any, ::Any, -) = nothing +) + nothing +end """ estimate_objective([rng,] obj, q, prob; kwargs...) diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index d0540459c..25aeb48d4 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -38,12 +38,13 @@ The arguments are as follows: - `gradient`: The estimated (possibly stochastic) gradient. """ -struct ParamSpaceSGD{Prob, - Obj<:AbstractVariationalObjective, - AD<:ADTypes.AbstractADType, - Opt<:Optimisers.AbstractRule, - Avg<:AbstractAverager, - Op <:AbstractOperator, +struct ParamSpaceSGD{ + Prob, + Obj<:AbstractVariationalObjective, + AD<:ADTypes.AbstractADType, + Opt<:Optimisers.AbstractRule, + Avg<:AbstractAverager, + Op<:AbstractOperator, } <: AbstractAlgorithm problem::Prob objective::Obj @@ -53,7 +54,7 @@ struct ParamSpaceSGD{Prob, operator::Op end -struct ParamSpaceSGDState{Q, GradBuf, OptSt, ObjSt, AvgSt} +struct ParamSpaceSGDState{Q,GradBuf,OptSt,ObjSt,AvgSt} q::Q iteration::Int grad_buf::GradBuf @@ -65,9 +66,9 @@ end function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init) (; problem, adtype, optimizer, averager, objective) = alg params, re = Optimisers.destructure(q_init) - opt_st = Optimisers.setup(optimizer, params) - obj_st = init(rng, objective, adtype, problem, params, re) - avg_st = init(averager, params) + opt_st = Optimisers.setup(optimizer, params) + obj_st = init(rng, objective, adtype, problem, params, re) + avg_st = init(averager, params) grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) return ParamSpaceSGDState(q_init, 0, grad_buf, opt_st, obj_st, avg_st) end diff --git a/src/optimize.jl b/src/optimize.jl index fbf5aadc0..e5eb30282 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -51,7 +51,7 @@ function optimize( progress::ProgressMeter.AbstractProgress=ProgressMeter.Progress( max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=show_progress ), - kwargs... + kwargs..., ) info_total = NamedTuple[] state = if isnothing(state) @@ -63,7 +63,9 @@ function optimize( for t in 1:max_iter info = (iteration=t,) - state, terminate, info′ = step(rng, algorithm, state, callback, objargs...; kwargs...) + state, terminate, info′ = step( + rng, algorithm, state, callback, objargs...; kwargs... + ) info = merge(info′, info) if terminate @@ -78,11 +80,7 @@ function optimize( end function optimize( - algorithm::AbstractAlgorithm, - max_iter::Int, - q_init, - objargs...; - kwargs..., + algorithm::AbstractAlgorithm, max_iter::Int, q_init, objargs...; kwargs... ) return optimize( Random.default_rng(), algorithm, max_iter, q_init, objargs...; kwargs... diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index 976b389de..919e0b0f0 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -33,7 +33,7 @@ end adtype, Descent(1e-5), PolynomialAveraging(), - IdentityOperator() + IdentityOperator(), ) _, info, _ = optimize(rng, alg, 10, q0; show_progress=false) @assert isfinite(last(info).elbo) diff --git a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl index a5b8f884b..050a46d41 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl @@ -40,7 +40,7 @@ end η = 1e-3 opt = Optimisers.Descent(η) avg = PolynomialAveraging() - op = IdentityOperator() + op = IdentityOperator() alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) # For small enough η, the error of SGD, Δλ, is bounded as @@ -64,7 +64,7 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, q0;show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) μ = mean(q_avg) L = sqrt(cov(q_avg)) diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index a6a676e1f..346c78286 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -35,7 +35,7 @@ end T = 1000 η = 1e-3 opt = Optimisers.Descent(η) - op = ClipScale() + op = ClipScale() avg = PolynomialAveraging() alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index ebea55c11..c3d886713 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -36,7 +36,7 @@ end T = 1000 η = 1e-3 opt = DoWG(1.0) - op = ProximalLocationScaleEntropy() + op = ProximalLocationScaleEntropy() avg = PolynomialAveraging() alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) diff --git a/test/algorithms/paramspacesgd/scoregradelbo.jl b/test/algorithms/paramspacesgd/scoregradelbo.jl index 6867bcfe8..366455876 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo.jl @@ -28,7 +28,7 @@ end adtype, Descent(1e-5), PolynomialAveraging(), - IdentityOperator() + IdentityOperator(), ) _, info, _ = optimize(rng, alg, 10, q0; show_progress=false) @assert isfinite(last(info).elbo) diff --git a/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl index 81cdc13d8..805c65a83 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl @@ -30,7 +30,7 @@ end T = 1000 η = 1e-4 opt = Optimisers.Descent(η) - op = IdentityOperator() + op = IdentityOperator() avg = PolynomialAveraging() alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) diff --git a/test/general/optimize.jl b/test/general/optimize.jl index a82c3cde4..1f98edf5d 100644 --- a/test/general/optimize.jl +++ b/test/general/optimize.jl @@ -8,7 +8,7 @@ (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats - q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims))) + q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims))) obj = RepGradELBO(10) adtype = AutoForwardDiff() From db6b6949e233853de6c0488a6ac4cba2a0280062 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Jun 2025 16:55:39 -0400 Subject: [PATCH 011/108] fix wrong path for test file --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 5778ac7c2..04bc841c1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,7 +55,7 @@ end if TEST_GROUP == "All" || TEST_GROUP == "General" || TEST_GROUP == "Enzyme" # Interface tests that involve testing on Enzyme - include("interface/ad.jl") + include("general/ad.jl") end if TEST_GROUP == "All" || TEST_GROUP == "Families" From 3609bc677b5ecbab1bc98c9cdc1dcffc3b21ff66 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Jun 2025 17:24:28 -0400 Subject: [PATCH 012/108] re-organized project --- src/AdvancedVI.jl | 77 ++++++++++++++++-- src/algorithms/paramspacesgd/constructors.jl | 15 ++++ src/algorithms/paramspacesgd/elbo/elbo.jl | 31 ------- .../paramspacesgd/{elbo => }/entropy.jl | 0 src/algorithms/paramspacesgd/interface.jl | 81 ------------------- src/algorithms/paramspacesgd/paramspacesgd.jl | 2 +- .../paramspacesgd/{elbo => }/repgradelbo.jl | 0 .../paramspacesgd/{elbo => }/scoregradelbo.jl | 0 8 files changed, 87 insertions(+), 119 deletions(-) create mode 100644 src/algorithms/paramspacesgd/constructors.jl delete mode 100644 src/algorithms/paramspacesgd/elbo/elbo.jl rename src/algorithms/paramspacesgd/{elbo => }/entropy.jl (100%) delete mode 100644 src/algorithms/paramspacesgd/interface.jl rename src/algorithms/paramspacesgd/{elbo => }/repgradelbo.jl (100%) rename src/algorithms/paramspacesgd/{elbo => }/scoregradelbo.jl (100%) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index a0b24ae77..b1292b948 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -196,18 +196,42 @@ Initialize `alg` given the initial variational approximation `q_init`. # Arguments - `rng::Random.AbstractRNG`: Random number generator. -- `alg::AbstractAlgorithm`: Variational objective. +- `alg::AbstractAlgorithm`: Variational inference algorithm. ` `q_init`: Initial variational approximation. """ init(::Random.AbstractRNG, ::AbstractAlgorithm, ::Any) = nothing -include("algorithms/paramspacesgd/interface.jl") -include("algorithms/paramspacesgd/paramspacesgd.jl") +""" + step(rng, alg, state, callback, objargs...; kwargs...) -export ParamSpaceSGD +Perform a single step of `alg` given the previous `stat`. -# Parameter Space SGD Implementations -include("algorithms/paramspacesgd/elbo/elbo.jl") +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `alg::AbstractAlgorithm`: Variational inference algorithm. +- `state`: Previous state of the algorithm. +- `callback`: Callback function to be called during the step. + +# Returns +- state: New state generated by performing the step. +- terminate::Bool: Whether to terminate the algorithm after the step. +- info::NamedTuple: Information generated during the step. +""" +step(::Random.AbstractRNG, ::AbstractAlgorithm, ::Any, callback, objargs...; kwargs...) = nothing + +""" + output(alg, state) + +Generate an output variational approximation using the last `state` of `alg`. + +# Arguments +- `alg::AbstractAlgorithm`: Variational inference algorithm used to compute the state. +- `state`: The last state generated by the algorithm. + +# Returns +- `out`: The output of the algorithm. +""" +output(::AbstractAlgorithm, ::Any) = nothing # Main optimization routine function optimize end @@ -217,4 +241,45 @@ export optimize include("utils.jl") include("optimize.jl") +## Parameter Space SGD +include("algorithms/paramspacesgd/paramspacesgd.jl") +include("algorithms/paramspacesgd/abstractobjective.jl") + +export ParamSpaceSGD + +## Parameter Space SGD Implementations +### ELBO Maximization + +abstract type AbstractEntropyEstimator end + +""" + estimate_entropy(entropy_estimator, mc_samples, q, q_stop) + +Estimate the entropy of `q`. + +# Arguments +- `entropy_estimator`: Entropy estimation strategy. +- `q`: Variational approximation. +- `q_stop`: Variational approximation with detached from the automatic differentiation graph. +- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) + +# Returns +- `obj_est`: Estimate of the objective value. +""" +function estimate_entropy end + +include("algorithms/paramspacesgd/repgradelbo.jl") +include("algorithms/paramspacesgd/scoregradelbo.jl") +include("algorithms/paramspacesgd/entropy.jl") + +export RepGradELBO, + ScoreGradELBO, + ClosedFormEntropy, + StickingTheLandingEntropy, + MonteCarloEntropy, + ClosedFormEntropyZeroGradient, + StickingTheLandingEntropyZeroGradient + +include("algorithms/paramspacesgd/constructors.jl") + end diff --git a/src/algorithms/paramspacesgd/constructors.jl b/src/algorithms/paramspacesgd/constructors.jl new file mode 100644 index 000000000..551a07457 --- /dev/null +++ b/src/algorithms/paramspacesgd/constructors.jl @@ -0,0 +1,15 @@ + +function ADVI(problem, adtype, optimizer, averager) + objective = RepGradELBO() + return ParamSpaceSGD(problem, objective, adtype, optimizer, averager) +end + +function BBVIRepGradProx(problem, adtype, optimizer, averager) + objective = RepGradELBO() + return ParamSpaceSGD(problem, objective, adtype, optimizer, averager) +end + +function BBVIRepGradProxSTL(problem, adtype, optimizer, averager) + objective = RepGradELBO() + return ParamSpaceSGD(problem, objective, adtype, optimizer, averager) +end diff --git a/src/algorithms/paramspacesgd/elbo/elbo.jl b/src/algorithms/paramspacesgd/elbo/elbo.jl deleted file mode 100644 index cca96e657..000000000 --- a/src/algorithms/paramspacesgd/elbo/elbo.jl +++ /dev/null @@ -1,31 +0,0 @@ - -# ELBO-specific interfaces -abstract type AbstractEntropyEstimator end - -""" - estimate_entropy(entropy_estimator, mc_samples, q, q_stop) - -Estimate the entropy of `q`. - -# Arguments -- `entropy_estimator`: Entropy estimation strategy. -- `q`: Variational approximation. -- `q_stop`: Variational approximation with detached from the automatic differentiation graph. -- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) - -# Returns -- `obj_est`: Estimate of the objective value. -""" -function estimate_entropy end - -include("repgradelbo.jl") -include("scoregradelbo.jl") -include("entropy.jl") - -export RepGradELBO, - ScoreGradELBO, - ClosedFormEntropy, - StickingTheLandingEntropy, - MonteCarloEntropy, - ClosedFormEntropyZeroGradient, - StickingTheLandingEntropyZeroGradient diff --git a/src/algorithms/paramspacesgd/elbo/entropy.jl b/src/algorithms/paramspacesgd/entropy.jl similarity index 100% rename from src/algorithms/paramspacesgd/elbo/entropy.jl rename to src/algorithms/paramspacesgd/entropy.jl diff --git a/src/algorithms/paramspacesgd/interface.jl b/src/algorithms/paramspacesgd/interface.jl deleted file mode 100644 index b67355c8a..000000000 --- a/src/algorithms/paramspacesgd/interface.jl +++ /dev/null @@ -1,81 +0,0 @@ - -# estimators -""" - AbstractVariationalObjective - -Abstract type for the VI algorithms supported by `AdvancedVI`. - -# Implementations -To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective` and `estimate_objective`. -Also, it should provide gradients by implementing the function `estimate_gradient!`. -If the estimator is stateful, it can implement `init` to initialize the state. -""" -abstract type AbstractVariationalObjective end - -""" - init(rng, obj, adtype, prob, params, restructure) - -Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. -This function needs to be implemented only if `obj` is stateful. - -# Arguments -- `rng::Random.AbstractRNG`: Random number generator. -- `obj::AbstractVariationalObjective`: Variational objective. -` `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. -- `params`: Initial variational parameters. -- `restructure`: Function that reconstructs the variational approximation from `λ`. -""" -function init( - ::Random.AbstractRNG, - ::AbstractVariationalObjective, - ::ADTypes.AbstractADType, - ::Any, - ::Any, - ::Any, -) - nothing -end - -""" - estimate_objective([rng,] obj, q, prob; kwargs...) - -Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`. - -# Arguments -- `rng::Random.AbstractRNG`: Random number generator. -- `obj::AbstractVariationalObjective`: Variational objective. -- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. -- `q`: Variational approximation. - -# Keyword Arguments -Depending on the objective, additional keyword arguments may apply. -Please refer to the respective documentation of each variational objective for more info. - -# Returns -- `obj_est`: Estimate of the objective value. -""" -function estimate_objective end - -export estimate_objective - -""" - estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state) - -Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` - -# Arguments -- `rng::Random.AbstractRNG`: Random number generator. -- `obj::AbstractVariationalObjective`: Variational objective. -- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. -- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. -- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. -- `params`: Variational parameters to evaluate the gradient on. -- `restructure`: Function that reconstructs the variational approximation from `params`. -- `obj_state`: Previous state of the objective. - -# Returns -- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates. -- `obj_state`: The updated state of the objective. -- `stat::NamedTuple`: Statistics and logs generated during estimation. -""" -function estimate_gradient! end diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index 25aeb48d4..83679849d 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -15,8 +15,8 @@ The trainable parameters in the variational approximation are expected to be ext This requires the variational approximation to be marked as a functor through `Functors.@functor`. # Arguments -- `adtype`: Automatic differentiation backend. - `objective`: Variational Objective. +- `adtype`: Automatic differentiation backend. - `optimizer`: Optimizer used for inference. - `averager` : Parameter averaging strategy. - `operator` : Operator applied to the parameters after each optimization step. diff --git a/src/algorithms/paramspacesgd/elbo/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl similarity index 100% rename from src/algorithms/paramspacesgd/elbo/repgradelbo.jl rename to src/algorithms/paramspacesgd/repgradelbo.jl diff --git a/src/algorithms/paramspacesgd/elbo/scoregradelbo.jl b/src/algorithms/paramspacesgd/scoregradelbo.jl similarity index 100% rename from src/algorithms/paramspacesgd/elbo/scoregradelbo.jl rename to src/algorithms/paramspacesgd/scoregradelbo.jl From 8751e31bfe2fc697467670c6ececcccfa42b105e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Jun 2025 17:42:36 -0400 Subject: [PATCH 013/108] run formatter --- src/AdvancedVI.jl | 10 ++- src/algorithms/paramspacesgd/constructors.jl | 70 ++++++++++++++++--- test/algorithms/paramspacesgd/repgradelbo.jl | 11 ++- .../repgradelbo_distributionsad.jl | 5 +- .../repgradelbo_locationscale.jl | 5 +- .../repgradelbo_locationscale_bijectors.jl | 5 +- .../repgradelbo_proximal_locationscale.jl | 5 +- ...adelbo_proximal_locationscale_bijectors.jl | 5 +- 8 files changed, 78 insertions(+), 38 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index b1292b948..42346e5f1 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -217,7 +217,11 @@ Perform a single step of `alg` given the previous `stat`. - terminate::Bool: Whether to terminate the algorithm after the step. - info::NamedTuple: Information generated during the step. """ -step(::Random.AbstractRNG, ::AbstractAlgorithm, ::Any, callback, objargs...; kwargs...) = nothing +function step( + ::Random.AbstractRNG, ::AbstractAlgorithm, ::Any, callback, objargs...; kwargs... +) + nothing +end """ output(alg, state) @@ -242,8 +246,8 @@ include("utils.jl") include("optimize.jl") ## Parameter Space SGD -include("algorithms/paramspacesgd/paramspacesgd.jl") include("algorithms/paramspacesgd/abstractobjective.jl") +include("algorithms/paramspacesgd/paramspacesgd.jl") export ParamSpaceSGD @@ -282,4 +286,6 @@ export RepGradELBO, include("algorithms/paramspacesgd/constructors.jl") +export BBVIRepGrad, BBVIRepGradProxLocScale + end diff --git a/src/algorithms/paramspacesgd/constructors.jl b/src/algorithms/paramspacesgd/constructors.jl index 551a07457..3cc76042b 100644 --- a/src/algorithms/paramspacesgd/constructors.jl +++ b/src/algorithms/paramspacesgd/constructors.jl @@ -1,15 +1,65 @@ -function ADVI(problem, adtype, optimizer, averager) - objective = RepGradELBO() - return ParamSpaceSGD(problem, objective, adtype, optimizer, averager) -end +""" + BBVIRepGrad(problem, adtype; entropy, optimizer, n_samples, operator) + +Black-box variational inference with the reparameterization gradient with stochastic gradient descent. + +# Arguments +- `problem`: Target problem. +- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. -function BBVIRepGradProx(problem, adtype, optimizer, averager) - objective = RepGradELBO() - return ParamSpaceSGD(problem, objective, adtype, optimizer, averager) +# Keyword Arguments +- `entropy`: Entropy gradient estimator to be used. Must be one of `ClosedFormEntropy`, `StickingTheLandingEntropy`, `MonteCarloEntropy`. (default: `ClosedFormEntropy()`) +- `optimizer::Optimisers.AbstractRule`: Optimization algorithm to be used. (default: `DoWG()`) +- `n_samples::Int`: Number of Monte Carlo samples to be used for estimating each gradient. (default: `1`) +- `operator::Union{<:IdentityOperator, <:ClipScale}`: Operator to be applied after each gradient descent step. (default: `ClipScale()`) +- `averager::AbstractAverager`: Parameter averaging strategy. + +""" +function BBVIRepGrad( + problem, + adtype::ADTypes.AbstractADType; + entropy::Union{<:ClosedFormEntropy,<:StickingTheLandingEntropy,<:MonteCarloEntropy}=ClosedFormEntropy(), + optimizer::Optimisers.AbstractRule=DoWG(), + n_samples::Int=1, + averager::AbstractAverager=PolynomialAveraging(), + operator::Union{<:IdentityOperator,<:ClipScale}=ClipScale(), +) + objective = RepGradELBO(n_samples; entropy=entropy) + return ParamSpaceSGD(problem, objective, adtype, optimizer, averager, operator) end -function BBVIRepGradProxSTL(problem, adtype, optimizer, averager) - objective = RepGradELBO() - return ParamSpaceSGD(problem, objective, adtype, optimizer, averager) +""" + BBVIRepGradProxLocScale(problem, adtype; entropy, optimizer, n_samples, operator) + +Black-box variational inference with the reparameterization gradient and stochastic proximal gradient descent on location scale families. + +This algorithm only supports subtypes of `MvLocationScale`. +Also, since the stochastic proximal gradient descent does not use the entropy of the gradient, the entropy estimator to be used must have a zero-mean gradient. +Thus, only the entropy estimators with a "ZeroGradient" suffix are allowed. + +# Arguments +- `problem`: Target problem. +- `adtype`: Automatic differentiation backend. + +# Keyword Arguments +- `entropy_zerograd`: Estimator of the entropy with a zero-mean gradient to be used. Must be one of `ClosedFormEntropyZeroGrad`, `StickingTheLandingEntropyZeroGrad`. (default: `ClosedFormEntropyZeroGrad()`) +- `optimizer::Optimisers.AbstractRule`: Optimization algorithm to be used. Only `DoG`, `DoWG` and `Optimisers.Descent` are supported. (default: `DoWG()`) +- `n_samples::Int`: Number of Monte Carlo samples to be used for estimating each gradient. +- `averager::AbstractAverager`: Parameter averaging strategy. (default: `PolynomialAveraging()`) + +""" +function BBVIRepGradProxLocScale( + problem, + adtype::ADTypes.AbstractADType; + entropy_zerograd::Union{ + <:ClosedFormEntropyZeroGradient,<:StickingTheLandingEntropyZeroGradient + }=ClosedFormEntropyZeroGradient(), + optimizer::Union{<:Descent,<:DoG,<:DoWG}=DoWG(), + n_samples::Int=1, + averager::AbstractAverager=PolynomialAveraging(), +) + objective = RepGradELBO(n_samples; entropy=entropy_zerograd) + operator = ProximalLocationScaleEntropy() + return ParamSpaceSGD(problem, objective, adtype, optimizer, averager, operator) end diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index 919e0b0f0..9d863487d 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -27,13 +27,12 @@ end @testset "basic" begin @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] - alg = ParamSpaceSGD( + alg = BBVIRepGrad( model, - RepGradELBO(n_montecarlo), - adtype, - Descent(1e-5), - PolynomialAveraging(), - IdentityOperator(), + adtype; + n_samples=n_montecarlo, + operator=IdentityOperator(), + averager=PolynomialAveraging(), ) _, info, _ = optimize(rng, alg, 10, q0; show_progress=false) @assert isfinite(last(info).elbo) diff --git a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl index 050a46d41..5577271f8 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl @@ -38,10 +38,7 @@ end T = 1000 η = 1e-3 - opt = Optimisers.Descent(η) - avg = PolynomialAveraging() - op = IdentityOperator() - alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) + alg = BBVIRepGrad(model, adtype, opt, avg, operator=IdentityOperator()) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index fb3b9059a..cfcd55113 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -34,10 +34,7 @@ end T = 1000 η = 1e-3 - opt = Optimisers.Descent(η) - avg = PolynomialAveraging() - op = ClipScale() - alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) + alg = BBVIRepGrad(model, adtype) q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index 346c78286..1f25f8f74 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -34,10 +34,7 @@ end T = 1000 η = 1e-3 - opt = Optimisers.Descent(η) - op = ClipScale() - avg = PolynomialAveraging() - alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) + alg = BBVIRepGrad(model, adtype) b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index 8a0693288..a75db1579 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -43,10 +43,7 @@ end T = 1000 η = 1e-3 - opt = DoWG(1.0) - avg = PolynomialAveraging() - op = ProximalLocationScaleEntropy() - alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) + alg = BBVIRepGradProxLocScale(model, adtype) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index c3d886713..a912cd9ab 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -35,10 +35,7 @@ end T = 1000 η = 1e-3 - opt = DoWG(1.0) - op = ProximalLocationScaleEntropy() - avg = PolynomialAveraging() - alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) + alg = BBVIRepGradProxLocScale(model, adtype) b = Bijectors.bijector(model) b⁻¹ = inverse(b) From c53048a0dc375e1e7c80ec8e3b25996b35775e6d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Jun 2025 18:06:33 -0400 Subject: [PATCH 014/108] fix tests to use update interface --- test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl | 2 +- test/algorithms/paramspacesgd/repgradelbo_locationscale.jl | 2 +- .../paramspacesgd/repgradelbo_locationscale_bijectors.jl | 2 +- .../paramspacesgd/repgradelbo_proximal_locationscale.jl | 2 +- .../repgradelbo_proximal_locationscale_bijectors.jl | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl index 5577271f8..e998705b5 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl @@ -38,7 +38,7 @@ end T = 1000 η = 1e-3 - alg = BBVIRepGrad(model, adtype, opt, avg, operator=IdentityOperator()) + alg = BBVIRepGrad(model, adtype; operator=IdentityOperator(), optimizer=Descent(η)) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index cfcd55113..2c0d2e5e4 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -34,7 +34,7 @@ end T = 1000 η = 1e-3 - alg = BBVIRepGrad(model, adtype) + alg = BBVIRepGrad(model, adtype; optimizer=Descent(η)) q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index 1f25f8f74..686ef3fd3 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -34,7 +34,7 @@ end T = 1000 η = 1e-3 - alg = BBVIRepGrad(model, adtype) + alg = BBVIRepGrad(model, adtype; optimizer=Descent(η)) b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index a75db1579..d2faa4d7f 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -43,7 +43,7 @@ end T = 1000 η = 1e-3 - alg = BBVIRepGradProxLocScale(model, adtype) + alg = BBVIRepGradProxLocScale(model, adtype; optimizer=Descent(η)) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index a912cd9ab..e08d76291 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -35,7 +35,7 @@ end T = 1000 η = 1e-3 - alg = BBVIRepGradProxLocScale(model, adtype) + alg = BBVIRepGradProxLocScale(model, adtype; optimizer=Descent(η)) b = Bijectors.bijector(model) b⁻¹ = inverse(b) From 95cbbda43381e1621685d43bb99d5355eefbee7c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 26 Jun 2025 18:07:40 -0400 Subject: [PATCH 015/108] bump AdvancedVI version in git submodules --- bench/Project.toml | 2 +- docs/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bench/Project.toml b/bench/Project.toml index 703c546bd..c3ff783ab 100644 --- a/bench/Project.toml +++ b/bench/Project.toml @@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" -AdvancedVI = "0.3, 0.4" +AdvancedVI = "0.5" BenchmarkTools = "1" Bijectors = "0.13, 0.14, 0.15" Distributions = "0.25.111" diff --git a/docs/Project.toml b/docs/Project.toml index 4466b49b7..81d9f5fc5 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -15,7 +15,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] ADTypes = "1" -AdvancedVI = "0.4" +AdvancedVI = "0.5" Bijectors = "0.13.6, 0.14, 0.15" Distributions = "0.25" Documenter = "1" From 8d728c4beb793474581e66b49751b66320f55f19 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 27 Jun 2025 13:42:39 -0400 Subject: [PATCH 016/108] add missing file --- .../paramspacesgd/abstractobjective.jl | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 src/algorithms/paramspacesgd/abstractobjective.jl diff --git a/src/algorithms/paramspacesgd/abstractobjective.jl b/src/algorithms/paramspacesgd/abstractobjective.jl new file mode 100644 index 000000000..b67355c8a --- /dev/null +++ b/src/algorithms/paramspacesgd/abstractobjective.jl @@ -0,0 +1,81 @@ + +# estimators +""" + AbstractVariationalObjective + +Abstract type for the VI algorithms supported by `AdvancedVI`. + +# Implementations +To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective` and `estimate_objective`. +Also, it should provide gradients by implementing the function `estimate_gradient!`. +If the estimator is stateful, it can implement `init` to initialize the state. +""" +abstract type AbstractVariationalObjective end + +""" + init(rng, obj, adtype, prob, params, restructure) + +Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. +This function needs to be implemented only if `obj` is stateful. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +` `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. +- `params`: Initial variational parameters. +- `restructure`: Function that reconstructs the variational approximation from `λ`. +""" +function init( + ::Random.AbstractRNG, + ::AbstractVariationalObjective, + ::ADTypes.AbstractADType, + ::Any, + ::Any, + ::Any, +) + nothing +end + +""" + estimate_objective([rng,] obj, q, prob; kwargs...) + +Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `q`: Variational approximation. + +# Keyword Arguments +Depending on the objective, additional keyword arguments may apply. +Please refer to the respective documentation of each variational objective for more info. + +# Returns +- `obj_est`: Estimate of the objective value. +""" +function estimate_objective end + +export estimate_objective + +""" + estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state) + +Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. +- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `params`: Variational parameters to evaluate the gradient on. +- `restructure`: Function that reconstructs the variational approximation from `params`. +- `obj_state`: Previous state of the objective. + +# Returns +- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `obj_state`: The updated state of the objective. +- `stat::NamedTuple`: Statistics and logs generated during estimation. +""" +function estimate_gradient! end From 40648a4486d25253ff698c406237b71cf4660e42 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 27 Jun 2025 14:00:35 -0400 Subject: [PATCH 017/108] update benchmarks to new interface --- bench/benchmarks.jl | 15 +++------------ bench/normallognormal.jl | 2 +- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index f86b35788..25c0adb6e 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -39,8 +39,7 @@ begin ] max_iter = 10^4 d = LogDensityProblems.dimension(prob) - optimizer = Optimisers.Adam(T(1e-3)) - operator = ClipScale() + opt = Optimisers.Adam(T(1e-3)) for (objname, obj) in [ ("RepGradELBO", RepGradELBO(10)), @@ -64,18 +63,10 @@ begin b = Bijectors.bijector(prob) binv = inverse(b) q = Bijectors.TransformedDistribution(family, binv) + alg = BBVIRepGrad($prob, $adtype; optimizer=opt, objective=obj) SUITES[probname][objname][familyname][adname] = begin - @benchmarkable AdvancedVI.optimize( - $prob, - $obj, - $q, - $max_iter; - adtype=$adtype, - optimizer=$optimizer, - operator=$operator, - show_progress=false, - ) + @benchmarkable AdvancedVI.optimize($alg, $max_iter, $q; show_progress=false) end end end diff --git a/bench/normallognormal.jl b/bench/normallognormal.jl index d4da8019c..cb6592b71 100644 --- a/bench/normallognormal.jl +++ b/bench/normallognormal.jl @@ -33,5 +33,5 @@ function normallognormal(; n_dims=10, realtype=Float64) σ_x = realtype(0.3) μ_y = Fill(realtype(5.0), n_dims) σ_y = Fill(realtype(0.3), n_dims) - return model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)) + return NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)) end From afd421e90cd27da77d58c2c7283e7fcc7d9e459d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 27 Jun 2025 14:06:17 -0400 Subject: [PATCH 018/108] fix wrong dollar sign usage --- bench/benchmarks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 25c0adb6e..fe958d0df 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -63,7 +63,7 @@ begin b = Bijectors.bijector(prob) binv = inverse(b) q = Bijectors.TransformedDistribution(family, binv) - alg = BBVIRepGrad($prob, $adtype; optimizer=opt, objective=obj) + alg = BBVIRepGrad(prob, adtype; optimizer=opt, objective=obj) SUITES[probname][objname][familyname][adname] = begin @benchmarkable AdvancedVI.optimize($alg, $max_iter, $q; show_progress=false) From 34a64baf5debd5ba3c8712a970b6303f0cf1566f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 27 Jun 2025 14:19:19 -0400 Subject: [PATCH 019/108] fix wrong interface --- bench/benchmarks.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index fe958d0df..81b4a0594 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -41,9 +41,9 @@ begin d = LogDensityProblems.dimension(prob) opt = Optimisers.Adam(T(1e-3)) - for (objname, obj) in [ - ("RepGradELBO", RepGradELBO(10)), - ("RepGradELBO + STL", RepGradELBO(10; entropy=StickingTheLandingEntropy())), + for (objname, entropy) in [ + ("RepGradELBO", ClosedFormEntropy()), + ("RepGradELBO + STL", StickingTheLandingEntropy()), ], (adname, adtype) in [ ("Zygote", AutoZygote()), @@ -63,7 +63,7 @@ begin b = Bijectors.bijector(prob) binv = inverse(b) q = Bijectors.TransformedDistribution(family, binv) - alg = BBVIRepGrad(prob, adtype; optimizer=opt, objective=obj) + alg = BBVIRepGrad(prob, adtype; optimizer=opt, entropy) SUITES[probname][objname][familyname][adname] = begin @benchmarkable AdvancedVI.optimize($alg, $max_iter, $q; show_progress=false) From f03f6af4446e07c69c54ebc2958ec26fe2766e03 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 27 Jun 2025 16:27:11 -0400 Subject: [PATCH 020/108] fix to new interface --- docs/src/elbo/repgradelbo.md | 58 +++++++++++++----------------------- docs/src/examples.md | 25 +++++++--------- docs/src/families.md | 42 ++++++++------------------ docs/src/general.md | 14 +++------ 4 files changed, 47 insertions(+), 92 deletions(-) diff --git a/docs/src/elbo/repgradelbo.md b/docs/src/elbo/repgradelbo.md index ccc69a97f..72eb7e69f 100644 --- a/docs/src/elbo/repgradelbo.md +++ b/docs/src/elbo/repgradelbo.md @@ -188,16 +188,14 @@ binv = inverse(b) q0_trans = Bijectors.TransformedDistribution(q0, binv) -cfe = AdvancedVI.RepGradELBO(n_montecarlo) +cfe = BBVIRepGrad(model, AutoForwardDiff(); entropy=ClosedFormEntropy()) nothing ``` The repgradelbo estimator can instead be created as follows: ```@example repgradelbo -repgradelbo = AdvancedVI.RepGradELBO( - n_montecarlo; entropy=AdvancedVI.StickingTheLandingEntropy() -); +stl = BBVIRepGrad(model, AutoForwardDiff(); entropy=StickingTheLandingEntropy()) nothing ``` @@ -211,35 +209,27 @@ function callback(; params, restructure, kwargs...) (dist = sqrt(dist2),) end -_, _, stats_cfe, _ = AdvancedVI.optimize( - model, +_, info_cfe, _ = AdvancedVI.optimize( cfe, - q0_trans, - max_iter; + max_iter, + q0_trans; show_progress = false, - adtype = AutoForwardDiff(), - optimizer = Optimisers.Adam(3e-3), - operator = ClipScale(), callback = callback, ); -_, _, stats_stl, _ = AdvancedVI.optimize( - model, - repgradelbo, - q0_trans, - max_iter; +_, info_stl, _ = AdvancedVI.optimize( + stl, + max_iter, + q0_trans; show_progress = false, - adtype = AutoForwardDiff(), - optimizer = Optimisers.Adam(3e-3), - operator = ClipScale(), callback = callback, ); -t = [stat.iteration for stat in stats_cfe] -elbo_cfe = [stat.elbo for stat in stats_cfe] -elbo_stl = [stat.elbo for stat in stats_stl] -dist_cfe = [stat.dist for stat in stats_cfe] -dist_stl = [stat.dist for stat in stats_stl] +t = [i.iteration for i in info_cfe] +elbo_cfe = [i.elbo for i in info_cfe] +elbo_stl = [i.elbo for i in info_stl] +dist_cfe = [i.dist for i in info_cfe] +dist_stl = [i.dist for i in info_stl] plot( t, elbo_cfe, label="BBVI CFE", xlabel="Iteration", ylabel="ELBO") plot!(t, elbo_stl, label="BBVI STL", xlabel="Iteration", ylabel="ELBO") savefig("advi_stl_elbo.svg") @@ -309,23 +299,17 @@ nothing (Note that this is a quick-and-dirty example, and there are more sophisticated ways to implement this.) ```@setup repgradelbo -repgradelbo = AdvancedVI.RepGradELBO(n_montecarlo); - -_, _, stats_qmc, _ = AdvancedVI.optimize( - model, - repgradelbo, - q0_trans, - max_iter; +_, info_qmc, _ = AdvancedVI.optimize( + BBVIRepGrad(model, AutoForwardDiff(); n_samples=n_montecarlo), + max_iter, + q0_trans; show_progress = false, - adtype = AutoForwardDiff(), - optimizer = Optimisers.Adam(3e-3), - operator = ClipScale(), callback = callback, ); -t = [stat.iteration for stat in stats_qmc] -elbo_qmc = [stat.elbo for stat in stats_qmc] -dist_qmc = [stat.dist for stat in stats_qmc] +t = [i.iteration for i in info_qmc] +elbo_qmc = [i.elbo for i in info_qmc] +dist_qmc = [i.dist for i in info_qmc] plot( t, elbo_cfe, label="BBVI CFE", xlabel="Iteration", ylabel="ELBO") plot!(t, elbo_qmc, label="BBVI CFE QMC", xlabel="Iteration", ylabel="ELBO") savefig("advi_qmc_elbo.svg") diff --git a/docs/src/examples.md b/docs/src/examples.md index 9c0292b81..14e9f01d0 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -85,8 +85,7 @@ We now need to select 1. a variational objective, and 2. a variational family. Here, we will use the [`RepGradELBO` objective](@ref repgradelbo), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. ```@example elboexample -n_montecaro = 10; -objective = RepGradELBO(n_montecaro) +alg = BBVIRepGrad(model, AutoForwardDiff()) ``` For the variational family, we will use the classic mean-field Gaussian family. @@ -110,15 +109,11 @@ Passing `objective` and the initial variational approximation `q` to `optimize` ```@example elboexample n_max_iter = 10^4 -q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize( - model, - objective, - q0_trans, - n_max_iter; +q_out, info, _ = AdvancedVI.optimize( + alg, + n_max_iter, + q0_trans; show_progress=false, - adtype=AutoForwardDiff(), - optimizer=Optimisers.Adam(1e-3), - operator=ClipScale(), ); nothing ``` @@ -126,8 +121,8 @@ nothing `ClipScale` is a projection operator, which ensures that the variational approximation stays within a stable region of the variational family. For more information see [this section](@ref clipscale). -`q_avg_trans` is the final output of the optimization procedure. -If a parameter averaging strategy is used through the keyword argument `averager`, `q_avg_trans` is be the output of the averaging strategy, while `q_trans` is the last iterate. +`q_out` is the final output of the optimization procedure. +If a parameter averaging strategy is used through the keyword argument `averager`, `q_out` is be the output of the averaging strategy. The selected inference procedure stores per-iteration statistics into `stats`. For instance, the ELBO can be ploted as follows: @@ -135,8 +130,8 @@ For instance, the ELBO can be ploted as follows: ```@example elboexample using Plots -t = [stat.iteration for stat in stats] -y = [stat.elbo for stat in stats] +t = [i.iteration for i in info] +y = [i.elbo for i in info] plot(t, y; label="BBVI", xlabel="Iteration", ylabel="ELBO") savefig("bbvi_example_elbo.svg") nothing @@ -149,5 +144,5 @@ Further information can be gathered by defining your own `callback!`. The final ELBO can be estimated by calling the objective directly with a different number of Monte Carlo samples as follows: ```@example elboexample -estimate_objective(objective, q_avg_trans, model; n_samples=10^4) +estimate_objective(RepGradELBO(10^4), q_out, model) ``` diff --git a/docs/src/families.md b/docs/src/families.md index e270acad8..66c2c171d 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -177,56 +177,38 @@ D = ones(n_dims) U = zeros(n_dims, 3) q0_lr = LowRankGaussian(μ, D, U) -obj = RepGradELBO(1); +alg = BBVIRepGrad(model, AutoReverseDiff(); optimizer=Adam(0.01)) max_iter = 10^4 -function callback(; params, averaged_params, restructure, stat, kwargs...) +function callback(; params, averaged_params, restructure, kwargs...) q = restructure(averaged_params) μ, Σ = mean(q), cov(q) (dist2 = sum(abs2, μ - μ_true) + tr(Σ + Σ_true - 2*sqrt(Σsqrt_true*Σ*Σsqrt_true)),) end -_, _, stats_fr, _ = AdvancedVI.optimize( - model, - obj, - q0_fr, - max_iter; +_, info_fr, _ = AdvancedVI.optimize( + alg, max_iter, q0_fr; show_progress = false, - adtype = AutoReverseDiff(), - optimizer = Adam(0.01), - averager = PolynomialAveraging(), callback = callback, ); -_, _, stats_mf, _ = AdvancedVI.optimize( - model, - obj, - q0_mf, - max_iter; +_, info_mf, _ = AdvancedVI.optimize( + alg, max_iter, q0_mf; show_progress = false, - adtype = AutoReverseDiff(), - optimizer = Adam(0.01), - averager = PolynomialAveraging(), callback = callback, ); -_, _, stats_lr, _ = AdvancedVI.optimize( - model, - obj, - q0_lr, - max_iter; +_, info_lr, _ = AdvancedVI.optimize( + alg, max_iter, q0_lr; show_progress = false, - adtype = AutoReverseDiff(), - optimizer = Adam(0.01), - averager = PolynomialAveraging(), callback = callback, ); -t = [stat.iteration for stat in stats_fr] -dist_fr = [sqrt(stat.dist2) for stat in stats_fr] -dist_mf = [sqrt(stat.dist2) for stat in stats_mf] -dist_lr = [sqrt(stat.dist2) for stat in stats_lr] +t = [i.iteration for i in info_fr] +dist_fr = [sqrt(i.dist2) for i in info_fr] +dist_mf = [sqrt(i.dist2) for i in info_mf] +dist_lr = [sqrt(i.dist2) for i in info_lr] plot( t, dist_mf , label="Mean-Field Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance") plot!(t, dist_fr, label="Full-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance") plot!(t, dist_lr, label="Low-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance") diff --git a/docs/src/general.md b/docs/src/general.md index f4a8281d1..a59fc9e62 100644 --- a/docs/src/general.md +++ b/docs/src/general.md @@ -1,20 +1,14 @@ -# [General Usage](@id general) - -Each VI algorithm provides the followings: - - 1. Variational families supported by each VI algorithm. - 2. A variational objective corresponding to the VI algorithm. - Note that each variational family is subject to its own constraints. - Thus, please refer to the documentation of the variational inference algorithm of interest. -## Optimizing a Variational Objective +# [General Usage](@id general) -After constructing a *variational objective* `objective` and initializing a *variational approximation*, one can optimize `objective` by calling `optimize`: +AdvancedVI provides multiple variational inference (VI) algorithms. +Given a `<: AbstractAlgorithm` object associated with each algorithm, it suffices to call the function `optimize`: ```@docs optimize ``` + ## Estimating the Objective In some cases, it is useful to directly estimate the objective value. From 2dc2e69ed23d4c1e72c9d20da930cf52d251314c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 27 Jun 2025 16:27:52 -0400 Subject: [PATCH 021/108] fix missing square in docstring of `ProximaLocationScaleEntropy` --- src/optimization/proximal_location_scale_entropy.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimization/proximal_location_scale_entropy.jl b/src/optimization/proximal_location_scale_entropy.jl index 33946527b..1cb36d9d8 100644 --- a/src/optimization/proximal_location_scale_entropy.jl +++ b/src/optimization/proximal_location_scale_entropy.jl @@ -4,7 +4,7 @@ Proximal operator for the entropy of a location-scale distribution, which is defined as ```math - \\mathrm{prox}(\\lambda) = \\argmin_{\\lambda^{\\prime}} - \\mathbb{H}(q_{\\lambda^{\\prime}}) + \\frac{1}{2 \\gamma_t} \\left\\lVert \\lambda - \\lambda^{\\prime} \\right\\rVert , + \\mathrm{prox}(\\lambda) = \\argmin_{\\lambda^{\\prime}} - \\mathbb{H}(q_{\\lambda^{\\prime}}) + \\frac{1}{2 \\gamma_t} {\\left\\lVert \\lambda - \\lambda^{\\prime} \\right\\rVert}_2^2 , ``` where \$\\gamma_t\$ is the stepsize the optimizer used with the proximal operator. This assumes the variational family is `<:VILocationScale` and the optimizer is one of the following: From 218819b4a257f7a8b13edcd15ca846c5cff76d3b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 27 Jun 2025 16:32:21 -0400 Subject: [PATCH 022/108] fix typo --- src/optimization/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimization/rules.jl b/src/optimization/rules.jl index 6bdd82777..abe0d5bcb 100644 --- a/src/optimization/rules.jl +++ b/src/optimization/rules.jl @@ -57,7 +57,7 @@ end Continuous Coin Betting (COCOB[^OT2017]) optimizer. We use the "COCOB-Backprop" variant, which is closer to the Adam optimizer. -It's only parameter is the maximum change per parameter α, which shouldn't need much tuning. +Its only parameter is the maximum change per parameter α, which shouldn't need much tuning. # Parameters - `alpha`: Scaling parameter. (default value: `100`) From 504f16edfadfa04214c948824ab257967ac51229 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 27 Jun 2025 16:58:57 -0400 Subject: [PATCH 023/108] add docstring for `AbstractAlgorithm` --- src/AdvancedVI.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 42346e5f1..debb9c460 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,3 +1,4 @@ + module AdvancedVI using Accessors @@ -187,6 +188,11 @@ export IdentityOperator, ClipScale, ProximalLocationScaleEntropy # Algorithms +""" + AbstractAlgorithm + +Abstract type for a variational inference algorithm. +""" abstract type AbstractAlgorithm end """ From 557fd3d64cf1f1c99b7389bfb93bdcde3ee66418 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 27 Jun 2025 17:16:27 -0400 Subject: [PATCH 024/108] move files in docs --- docs/src/{elbo => paramspacesgd}/overview.md | 0 docs/src/{elbo => paramspacesgd}/repgradelbo.md | 17 ++++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) rename docs/src/{elbo => paramspacesgd}/overview.md (100%) rename docs/src/{elbo => paramspacesgd}/repgradelbo.md (98%) diff --git a/docs/src/elbo/overview.md b/docs/src/paramspacesgd/overview.md similarity index 100% rename from docs/src/elbo/overview.md rename to docs/src/paramspacesgd/overview.md diff --git a/docs/src/elbo/repgradelbo.md b/docs/src/paramspacesgd/repgradelbo.md similarity index 98% rename from docs/src/elbo/repgradelbo.md rename to docs/src/paramspacesgd/repgradelbo.md index 72eb7e69f..c10696b84 100644 --- a/docs/src/elbo/repgradelbo.md +++ b/docs/src/paramspacesgd/repgradelbo.md @@ -1,3 +1,4 @@ + # [Reparameterization Gradient Estimator](@id repgradelbo) ## Overview @@ -188,14 +189,24 @@ binv = inverse(b) q0_trans = Bijectors.TransformedDistribution(q0, binv) -cfe = BBVIRepGrad(model, AutoForwardDiff(); entropy=ClosedFormEntropy()) +cfe = BBVIRepGrad( + model, + AutoForwardDiff(); + entropy=ClosedFormEntropy(), + optimizer=Adam(1e-2) +) nothing ``` The repgradelbo estimator can instead be created as follows: ```@example repgradelbo -stl = BBVIRepGrad(model, AutoForwardDiff(); entropy=StickingTheLandingEntropy()) +stl = BBVIRepGrad( + model, + AutoForwardDiff(); + entropy=StickingTheLandingEntropy(), + optimizer=Adam(1e-2) +) nothing ``` @@ -300,7 +311,7 @@ nothing ```@setup repgradelbo _, info_qmc, _ = AdvancedVI.optimize( - BBVIRepGrad(model, AutoForwardDiff(); n_samples=n_montecarlo), + BBVIRepGrad(model, AutoForwardDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)), max_iter, q0_trans; show_progress = false, From 056818ebd3df6b1ac4847fc2143e626ce334aaf0 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 27 Jun 2025 17:16:37 -0400 Subject: [PATCH 025/108] fix docstring --- src/optimize.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/optimize.jl b/src/optimize.jl index e5eb30282..c37837799 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -2,7 +2,6 @@ """ optimize( [rng::Random.AbstractRNG = Random.default_rng(),] - problem algorithm::AbstractAlgorithm, q_init, max_iter::Int, From 0f3329f265cd13b4192948e8f7a1ace2282a9957 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Jun 2025 15:39:39 -0400 Subject: [PATCH 026/108] fix docstrings --- src/algorithms/paramspacesgd/abstractobjective.jl | 2 +- src/algorithms/paramspacesgd/paramspacesgd.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/algorithms/paramspacesgd/abstractobjective.jl b/src/algorithms/paramspacesgd/abstractobjective.jl index b67355c8a..29b146d38 100644 --- a/src/algorithms/paramspacesgd/abstractobjective.jl +++ b/src/algorithms/paramspacesgd/abstractobjective.jl @@ -7,7 +7,7 @@ Abstract type for the VI algorithms supported by `AdvancedVI`. # Implementations To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective` and `estimate_objective`. -Also, it should provide gradients by implementing the function `estimate_gradient!`. +Also, it should provide gradients by implementing the function `estimate_gradient`. If the estimator is stateful, it can implement `init` to initialize the state. """ abstract type AbstractVariationalObjective end diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index 83679849d..56db803d6 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -30,11 +30,11 @@ The callback function `callback` has a signature of callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient) The arguments are as follows: -- `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`. -- `state`: Collection of the internal states used for optimization. -- `params`: Variational parameters. +- `rng`: Random number generator internally used by the algorithm. +- `iteration`: The index of the current iteration. +- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation. +- `params`: Current variational parameters. - `averaged_params`: Variational parameters averaged according to the averaging strategy. -- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation. - `gradient`: The estimated (possibly stochastic) gradient. """ From 8670e5dceadab522b4bb8cbc5279791e10d615b4 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Jun 2025 16:42:03 -0400 Subject: [PATCH 027/108] update documentation --- docs/make.jl | 14 +++- docs/src/general.md | 60 +++++++++------ docs/src/index.md | 5 -- docs/src/paramspacesgd/general.md | 97 +++++++++++++++++++++++++ docs/src/paramspacesgd/objectives.md | 37 ++++++++++ docs/src/paramspacesgd/overview.md | 43 ----------- docs/src/paramspacesgd/repgradelbo.md | 23 ++---- docs/src/paramspacesgd/scoregradelbo.md | 11 +++ 8 files changed, 201 insertions(+), 89 deletions(-) create mode 100644 docs/src/paramspacesgd/general.md create mode 100644 docs/src/paramspacesgd/objectives.md delete mode 100644 docs/src/paramspacesgd/overview.md create mode 100644 docs/src/paramspacesgd/scoregradelbo.md diff --git a/docs/make.jl b/docs/make.jl index fe3453eaf..b103a71cc 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,6 +2,10 @@ using AdvancedVI using Documenter +# Necessary for invoking the docstring specializations +using Random +using ADTypes + DocMeta.setdocmeta!(AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true) makedocs(; @@ -13,9 +17,13 @@ makedocs(; "AdvancedVI" => "index.md", "General Usage" => "general.md", "Examples" => "examples.md", - "ELBO Maximization" => [ - "Overview" => "elbo/overview.md", - "Reparameterization Gradient Estimator" => "elbo/repgradelbo.md", + "Parameter Space SGD" => [ + "General" => "paramspacesgd/general.md", + "Objectives" => [ + "Overview" => "paramspacesgd/objectives.md", + "RepGradELBO" => "paramspacesgd/repgradelbo.md", + "ScoreGradELBO" => "paramspacesgd/scoregradelbo.md", + ], ], "Variational Families" => "families.md", "Optimization" => "optimization.md", diff --git a/docs/src/general.md b/docs/src/general.md index a59fc9e62..8c194a5eb 100644 --- a/docs/src/general.md +++ b/docs/src/general.md @@ -2,43 +2,57 @@ # [General Usage](@id general) AdvancedVI provides multiple variational inference (VI) algorithms. -Given a `<: AbstractAlgorithm` object associated with each algorithm, it suffices to call the function `optimize`: - -```@docs -optimize -``` +Each algorithm defines its subtype of [`AdvancedVI.AbstractAlgorithm`](@ref) with some corresponding methods (see [this section](@ref algorithm)). +Then the algorithm can be executed by invoking `optimize`. (See [this section](@ref optimize)). +## [Optimize](@id optimize) -## Estimating the Objective - -In some cases, it is useful to directly estimate the objective value. -This can be done by the following funciton: +Given a subtype of `AbstractAlgorithm` associated with each algorithm, it suffices to call the function `optimize`: ```@docs -estimate_objective +optimize ``` -!!! info - - Note that `estimate_objective` is not expected to be differentiated through, and may not result in optimal statistical performance. +Each algorithm may interact differently with the arguments of `optimize`. +Therefore, please refer to the documentation of each different algorithm for a detailed description on their behavior and their requirements. -## Advanced Usage +## [Algorithm Interface](@id algorithm) -Each variational objective is a subtype of the following abstract type: +A variational inference algorithm supported by `AdvancedVI` should define its own subtype of `AbstractAlgorithm`: ```@docs -AdvancedVI.AbstractVariationalObjective +AdvancedVI.AbstractAlgorithm ``` -Furthermore, `AdvancedVI` only interacts with each variational objective by querying gradient estimates. -Therefore, to create a new custom objective to be optimized through `AdvancedVI`, it suffices to implement the following function: +The functionality of each algorithm is then implemented through the following methods: ```@docs -AdvancedVI.estimate_gradient! +AdvancedVI.init(::Random.AbstractRNG, ::AdvancedVI.AbstractAlgorithm, ::Any) +AdvancedVI.step +AdvancedVI.output ``` -If an objective needs to be stateful, one can implement the following function to inialize the state. - -```@docs -AdvancedVI.init +The role of each method should be self-explanatory and should be clear once we take a look at how `optimize` interacts with each algorithm. +The operation of `optimize` can be simplified as follows: + +```julia +function optimize([rng,] algorithm, max_iter, q_init, objargs; kwargs...) + info_total = NamedTuple[] + state = init(rng, algorithm, q_init) + for t in 1:max_iter + info = (iteration=t,) + state, terminate, info′ = step( + rng, algorithm, state, callback, objargs...; kwargs... + ) + info = merge(info′, info) + + if terminate + break + end + + push!(info_total, info) + end + out = output(algorithm, state) + return out, info_total, state +end ``` diff --git a/docs/src/index.md b/docs/src/index.md index f177cd72b..090b5ae3a 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -10,8 +10,3 @@ CurrentModule = AdvancedVI VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness. `AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem. -## Provided Algorithms - -`AdvancedVI` currently provides the following algorithm for evidence lower bound maximization: - - - [Evidence Lower-Bound Maximization](@ref elbomax) diff --git a/docs/src/paramspacesgd/general.md b/docs/src/paramspacesgd/general.md new file mode 100644 index 000000000..614a73822 --- /dev/null +++ b/docs/src/paramspacesgd/general.md @@ -0,0 +1,97 @@ + +# [General](@id paramspacesgd) + +`ParamSpaceSGD` SGD is a general algorithm for leveraging automatic differentiation and SGD. +Furthermore, it operates in the space of *variational parameters*. +Consider the case where each member $q_{\lambda} \in \mathcal{Q}$ of the variational family $\mathcal{Q}$ is uniquely represented through a collection of parameters $\lambda \in \Lambda \subseteq \mathbb{R}^p$. +That is, + +```math +\mathcal{Q} = \{q_{\lambda} \mid \lambda \in \Lambda \}, +``` +Then, as implied by the name, `ParamSpaceSGD` runs SGD on $\Lambda$, the (Euclidean) space of parameters. + +Any algorithm that operates by iterating the following steps can easily be implemented via `ParamSpaceSGD`: + +1. Obtain an unbiased estimate of the target objective. +2. Obtain an estimate of the gradient of the objective by differentiating the objective estimate with respect to the parameters. +3. Perform gradient descent with the stochastic gradient estimate. + +After some simplifications, each `step` of `ParamSpaceSGD` can be described as follows: + +```julia +function step(rng, alg::ParamSpaceSGD, state, callback, objargs...; kwargs...) + (; adtype, problem, objective, operator, averager) = alg + (; q, iteration, grad_buf, opt_st, obj_st, avg_st) = state + iteration += 1 + + # Extract variational parameters of `q` + params, re = Optimisers.destructure(q) + + # Estimate gradient and update the `DiffResults` buffer `grad_buf`. + grad_buf, obj_st, info = estimate_gradient!(...) + + # Gradient descent step. + grad = DiffResults.gradient(grad_buf) + opt_st, params = Optimisers.update!(opt_st, params, grad) + + # Apply operator + params = apply(operator, typeof(q), opt_st, params, re) + + # Apply parameter averaging + avg_st = apply(averager, avg_st, params) + + # Updated state + state = ParamSpaceSGDState(re(params), iteration, grad_buf, opt_st, obj_st, avg_st) + state, false, info +end +``` +The output of `ParamSpaceSGD` is the final state of `averager`. +Furthermore, `operator` can be anything from an identity mapping, a projection operator, a proximal operator, and so on. + +## `ParamSpaceSGD` +The constructor for `ParamSpaceSGD` is as follows: + +```@docs +ParamSpaceSGD +``` + +## Objective Interface + +To define an instance of a `ParamSpaceSGD` algorithm, it suffices to implement the `AbstractVariationalObjective` interface. +First, we need to define a subtype of `AbstractVariationalObjective`: + +```@docs +AdvancedVI.AbstractVariationalObjective +``` + +In addition, we need to implement some methods associated with the objective. +First, each objective may maintain a state such as buffers, online estimates of control variates, batch iterators for subsampling, and so on. +Such things should be initialized by implementing the following: + +```@docs +AdvancedVI.init( + ::Random.AbstractRNG, + ::AdvancedVI.AbstractVariationalObjective, + ::ADTypes.AbstractADType, + ::Any, + ::Any, + ::Any, +) +``` +If this method is not implemented, the state will be automatically be `nothing`. + +Next, the key functionality of estimating stochastic gradients should be implemented through the following: + +```@docs +AdvancedVI.estimate_gradient! +``` + +`AdvancedVI` only interacts with each variational objective by querying gradient estimates. +In a lot of cases, however, it is convinient to be able to estimate the current value of the objective. +For example, for monitoring convergence. +This should be done through the following: + +```@docs +AdvancedVI.estimate_objective +``` diff --git a/docs/src/paramspacesgd/objectives.md b/docs/src/paramspacesgd/objectives.md new file mode 100644 index 000000000..192f06c75 --- /dev/null +++ b/docs/src/paramspacesgd/objectives.md @@ -0,0 +1,37 @@ + +# Overview of Algorithms + +This section will provide an overview of the algorithm form by each objectives provided by `AdvancedVI`. + +## Evidence Lower Bound Maximization + +Evidence lower bound (ELBO) maximization[^JGJS1999] is a general family of algorithms that minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence between the target distribution ``\pi`` and a variational approximation ``q_{\lambda}``. +More generally, it aims to solve the problem + +```math + \mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{KL}\left(q, \pi\right) \; , +``` + +where $\mathcal{Q}$ is some family of distributions, often called the variational family. +Since we usually only have access to the unnormalized densities of the target distribution $\pi$, we don't have direct access to the KL divergence. +Instead, the ELBO maximization strategy maximizes a surrogate objective, the *ELBO*: + +```math + \mathrm{ELBO}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} \log \pi\left(\theta\right) + \mathbb{H}\left(q\right), +``` + +which is equivalent to the KL up to an additive constant (the evidence). +The ELBO and its gradient can be readily estimated through various strategies. +Overall, ELBO maximization algorithms aim to solve the problem: + +```math + \mathrm{minimize}_{q \in \mathcal{Q}}\quad -\mathrm{ELBO}\left(q\right). +``` + +Multiple ways to solve this problem exist, each leading to a different variational inference algorithm. `AdvancedVI` provides the following objectives: + + - [RepGradELBO](@ref repgradelbo): Implements the reparameterization gradient estimator of the ELBO gradient. + - ScoreGradELBO: Implements the score gradient estimator of the ELBO gradient. + +[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233. + diff --git a/docs/src/paramspacesgd/overview.md b/docs/src/paramspacesgd/overview.md deleted file mode 100644 index db9b598ef..000000000 --- a/docs/src/paramspacesgd/overview.md +++ /dev/null @@ -1,43 +0,0 @@ -# [Evidence Lower Bound Maximization](@id elbomax) - -## Introduction - -Evidence lower bound (ELBO) maximization[^JGJS1999] is a general family of algorithms that minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence between the target distribution ``\pi`` and a variational approximation ``q_{\lambda}``. -More generally, they aim to solve the following problem: - -```math - \mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{KL}\left(q, \pi\right), -``` - -where $$\mathcal{Q}$$ is some family of distributions, often called the variational family. -Since the target distribution ``\pi`` is intractable in general, the KL divergence is also intractable. -Instead, the ELBO maximization strategy maximizes a surrogate objective, the *ELBO*: - -```math - \mathrm{ELBO}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} \log \pi\left(\theta\right) + \mathbb{H}\left(q\right), -``` - -which serves as a lower bound to the KL. -The ELBO and its gradient can be readily estimated through various strategies. -Overall, ELBO maximization algorithms aim to solve the problem: - -```math - \mathrm{maximize}_{q \in \mathcal{Q}}\quad \mathrm{ELBO}\left(q\right). -``` - -Multiple ways to solve this problem exist, each leading to a different variational inference algorithm. - -## Algorithms - -Currently, `AdvancedVI` only provides the approach known as black-box variational inference (also known as Monte Carlo VI, Stochastic Gradient VI). -(Introduced independently by two groups [^RGB2014][^TL2014] in 2014.) -In particular, `AdvancedVI` focuses on the reparameterization gradient estimator[^TL2014][^RMW2014][^KW2014], which is generally superior compared to alternative strategies[^XQKS2019], discussed in the following section: - - - [RepGradELBO](@ref repgradelbo) - -[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233. -[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. -[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*. -[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In *International Conference on Learning Representations*. -[^XQKS2019]: Xu, M., Quiroz, M., Kohn, R., & Sisson, S. A. (2019). Variance reduction properties of the reparameterization trick. In *The International Conference on Artificial Intelligence and Statistics. -[^RGB2014]: Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. In *Artificial Intelligence and Statistics*. diff --git a/docs/src/paramspacesgd/repgradelbo.md b/docs/src/paramspacesgd/repgradelbo.md index c10696b84..b0bd84541 100644 --- a/docs/src/paramspacesgd/repgradelbo.md +++ b/docs/src/paramspacesgd/repgradelbo.md @@ -3,29 +3,21 @@ ## Overview -The reparameterization gradient[^TL2014][^RMW2014][^KW2014] is an unbiased gradient estimator of the ELBO. -Consider some variational family - -```math -\mathcal{Q} = \{q_{\lambda} \mid \lambda \in \Lambda \}, -``` - -where $$\lambda$$ is the *variational parameters* of $$q_{\lambda}$$. -If its sampling process can be described by some differentiable reparameterization function $$\mathcal{T}_{\lambda}$$ and a *base distribution* $$\varphi$$ independent of $$\lambda$$ such that +The `RepGradELBO` objective implements the reparameterization gradient[^TL2014][^RMW2014][^KW2014] estimator of the ELBO gradient. +For the variational family $\mathcal{Q} = \{q_{\lambda} \mid \lambda \in \Lambda\}$, suppose the process of sampling from $q_{\lambda}$ can be described by some differentiable reparameterization function $$T_{\lambda}$$ and a *base distribution* $$\varphi$$ independent of $$\lambda$$ such that ```math z \sim q_{\lambda} \qquad\Leftrightarrow\qquad -z \stackrel{d}{=} \mathcal{T}_{\lambda}\left(\epsilon\right);\quad \epsilon \sim \varphi +z \stackrel{d}{=} T_{\lambda}\left(\epsilon\right);\quad \epsilon \sim \varphi \; . ``` - -we can effectively estimate the gradient of the ELBO by directly differentiating +In these cases, we can effectively estimate the gradient of the ELBO by directly differentiating the stochastic estimate of the ELBO objective ```math \widehat{\mathrm{ELBO}}\left(\lambda\right) = \frac{1}{M}\sum^M_{m=1} \log \pi\left(\mathcal{T}_{\lambda}\left(\epsilon_m\right)\right) + \mathbb{H}\left(q_{\lambda}\right), ``` -where $$\epsilon_m \sim \varphi$$ are Monte Carlo samples, with respect to $$\lambda$$. -This estimator is called the reparameterization gradient estimator. +where $$\epsilon_m \sim \varphi$$ are Monte Carlo samples. +The resulting gradient estimate is called the reparameterization gradient estimator. In addition to the reparameterization gradient, `AdvancedVI` provides the following features: @@ -35,7 +27,8 @@ In addition to the reparameterization gradient, `AdvancedVI` provides the follow [^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. [^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*. [^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In *International Conference on Learning Representations*. -## The `RepGradELBO` Objective + +## `RepGradELBO` To use the reparameterization gradient, `AdvancedVI` provides the following variational objective: diff --git a/docs/src/paramspacesgd/scoregradelbo.md b/docs/src/paramspacesgd/scoregradelbo.md new file mode 100644 index 000000000..df559077e --- /dev/null +++ b/docs/src/paramspacesgd/scoregradelbo.md @@ -0,0 +1,11 @@ + +# [Score Gradient Estimator](@id scoregradelbo) + +## Overview + + +## `ScoreGradELBO` + +```@docs +ScoreGradELBO +``` From e7f1885fd7d9bc51f609f72e025009128b26f753 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Jun 2025 16:42:24 -0400 Subject: [PATCH 028/108] add note to docstring of `ParamSpaceSGD` --- src/algorithms/paramspacesgd/paramspacesgd.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index 56db803d6..98a574723 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -14,6 +14,9 @@ This algorithm applies stochastic gradient descent (SGD) to the variational `obj The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. This requires the variational approximation to be marked as a functor through `Functors.@functor`. +!!! note + Different objective may impose different requirements on `adtype`, variational family, `optimizer`, and `operator`. It is therefore important to check the documentation corresponding to each specific objective. Essentially, each objective should be thought as forming its own unique algorithm. + # Arguments - `objective`: Variational Objective. - `adtype`: Automatic differentiation backend. From 183fb12eaa64f9cbc53c70b08634b19a2b32a330 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Jun 2025 16:43:34 -0400 Subject: [PATCH 029/108] update docs for `RepGradELBO` --- src/algorithms/paramspacesgd/repgradelbo.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl index 0ff3b24b7..758ec95a2 100644 --- a/src/algorithms/paramspacesgd/repgradelbo.jl +++ b/src/algorithms/paramspacesgd/repgradelbo.jl @@ -3,17 +3,6 @@ RepGradELBO(n_samples; kwargs...) Evidence lower-bound objective with the reparameterization gradient formulation[^TL2014][^RMW2014][^KW2014]. -This computes the evidence lower-bound (ELBO) through the formulation: -```math -\\begin{aligned} -\\mathrm{ELBO}\\left(\\lambda\\right) -&\\triangleq -\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[ - \\log \\pi\\left(z\\right) -\\right] -+ \\mathbb{H}\\left(q_{\\lambda}\\right), -\\end{aligned} -``` # Arguments - `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO. From 25be8d0a6dded707a9410b9dd4775a1da0fe9cff Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Jun 2025 17:39:27 -0400 Subject: [PATCH 030/108] apply formatter --- docs/src/examples.md | 7 +---- docs/src/index.md | 1 - docs/src/paramspacesgd/objectives.md | 2 -- docs/src/paramspacesgd/repgradelbo.md | 32 +++++++++------------ docs/src/paramspacesgd/scoregradelbo.md | 38 ++++++++++++++++++++++++- 5 files changed, 52 insertions(+), 28 deletions(-) diff --git a/docs/src/examples.md b/docs/src/examples.md index 14e9f01d0..1d934b15f 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -109,12 +109,7 @@ Passing `objective` and the initial variational approximation `q` to `optimize` ```@example elboexample n_max_iter = 10^4 -q_out, info, _ = AdvancedVI.optimize( - alg, - n_max_iter, - q0_trans; - show_progress=false, -); +q_out, info, _ = AdvancedVI.optimize(alg, n_max_iter, q0_trans; show_progress=false); nothing ``` diff --git a/docs/src/index.md b/docs/src/index.md index 090b5ae3a..3fca2793d 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -9,4 +9,3 @@ CurrentModule = AdvancedVI [AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms. VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness. `AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem. - diff --git a/docs/src/paramspacesgd/objectives.md b/docs/src/paramspacesgd/objectives.md index 192f06c75..82b7014af 100644 --- a/docs/src/paramspacesgd/objectives.md +++ b/docs/src/paramspacesgd/objectives.md @@ -1,4 +1,3 @@ - # Overview of Algorithms This section will provide an overview of the algorithm form by each objectives provided by `AdvancedVI`. @@ -34,4 +33,3 @@ Multiple ways to solve this problem exist, each leading to a different variation - ScoreGradELBO: Implements the score gradient estimator of the ELBO gradient. [^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233. - diff --git a/docs/src/paramspacesgd/repgradelbo.md b/docs/src/paramspacesgd/repgradelbo.md index b0bd84541..b6074bfd6 100644 --- a/docs/src/paramspacesgd/repgradelbo.md +++ b/docs/src/paramspacesgd/repgradelbo.md @@ -1,19 +1,27 @@ - # [Reparameterization Gradient Estimator](@id repgradelbo) ## Overview -The `RepGradELBO` objective implements the reparameterization gradient[^TL2014][^RMW2014][^KW2014] estimator of the ELBO gradient. +The `RepGradELBO` objective implements the reparameterization gradient estimator[^HC1983][^G1991][^R1992][^P1996] of the ELBO gradient. +The reparameterization gradient, also known as the push-in gradient or the pathwise gradient, was introduced to VI in [^TL2014][^RMW2014][^KW2014]. For the variational family $\mathcal{Q} = \{q_{\lambda} \mid \lambda \in \Lambda\}$, suppose the process of sampling from $q_{\lambda}$ can be described by some differentiable reparameterization function $$T_{\lambda}$$ and a *base distribution* $$\varphi$$ independent of $$\lambda$$ such that +[^HC1983]: Ho, Y. C., & Cao, X. (1983). Perturbation analysis and optimization of queueing networks. Journal of optimization theory and Applications, 40(4), 559-582. +[^G1991]: Glasserman, P. (1991). Gradient estimation via perturbation analysis (Vol. 116). Springer Science & Business Media. +[^R1992]: Rubinstein, R. Y. (1992). Sensitivity analysis of discrete event systems by the “push out” method. Annals of Operations Research, 39(1), 229-250. +[^P1996]: Pflug, G. C. (1996). Optimization of stochastic models: the interface between simulation and optimization (Vol. 373). Springer Science & Business Media. +[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. +[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*. +[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In *International Conference on Learning Representations*. ```math z \sim q_{\lambda} \qquad\Leftrightarrow\qquad z \stackrel{d}{=} T_{\lambda}\left(\epsilon\right);\quad \epsilon \sim \varphi \; . ``` -In these cases, we can effectively estimate the gradient of the ELBO by directly differentiating the stochastic estimate of the ELBO objective + +In these cases, denoting the target log denstiy as $\log \pi$, we can effectively estimate the gradient of the ELBO by directly differentiating the stochastic estimate of the ELBO objective ```math - \widehat{\mathrm{ELBO}}\left(\lambda\right) = \frac{1}{M}\sum^M_{m=1} \log \pi\left(\mathcal{T}_{\lambda}\left(\epsilon_m\right)\right) + \mathbb{H}\left(q_{\lambda}\right), + \widehat{\mathrm{ELBO}}\left(\lambda\right) = \frac{1}{M}\sum^M_{m=1} \log \pi\left(T_{\lambda}\left(\epsilon_m\right)\right) + \mathbb{H}\left(q_{\lambda}\right), ``` where $$\epsilon_m \sim \varphi$$ are Monte Carlo samples. @@ -24,10 +32,6 @@ In addition to the reparameterization gradient, `AdvancedVI` provides the follow 1. **Posteriors with constrained supports** are handled through [`Bijectors`](https://github.com/TuringLang/Bijectors.jl), which is known as the automatic differentiation VI (ADVI; [^KTRGB2017]) formulation. (See [this section](@ref bijectors).) 2. **The gradient of the entropy** can be estimated through various strategies depending on the capabilities of the variational family. (See [this section](@ref entropygrad).) -[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. -[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*. -[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In *International Conference on Learning Representations*. - ## `RepGradELBO` To use the reparameterization gradient, `AdvancedVI` provides the following variational objective: @@ -114,8 +118,6 @@ The main downside of the STL estimator is that it needs to evaluate and differen Depending on the variational family, this might be computationally inefficient or even numerically unstable. For example, if ``q_{\lambda}`` is a Gaussian with a full-rank covariance, a back-substitution must be performed at every step, making the per-iteration complexity ``\mathcal{O}(d^3)`` and reducing numerical stability. -The STL control variate can be used by changing the entropy estimator using the following object: - ```@setup repgradelbo using Bijectors using FillArrays @@ -183,10 +185,7 @@ binv = inverse(b) q0_trans = Bijectors.TransformedDistribution(q0, binv) cfe = BBVIRepGrad( - model, - AutoForwardDiff(); - entropy=ClosedFormEntropy(), - optimizer=Adam(1e-2) + model, AutoForwardDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2) ) nothing ``` @@ -195,10 +194,7 @@ The repgradelbo estimator can instead be created as follows: ```@example repgradelbo stl = BBVIRepGrad( - model, - AutoForwardDiff(); - entropy=StickingTheLandingEntropy(), - optimizer=Adam(1e-2) + model, AutoForwardDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2) ) nothing ``` diff --git a/docs/src/paramspacesgd/scoregradelbo.md b/docs/src/paramspacesgd/scoregradelbo.md index df559077e..f957b4487 100644 --- a/docs/src/paramspacesgd/scoregradelbo.md +++ b/docs/src/paramspacesgd/scoregradelbo.md @@ -1,8 +1,44 @@ - # [Score Gradient Estimator](@id scoregradelbo) ## Overview +The `ScoreGradELBO` implements the score gradient estimator[^G1990][^KR1996][^RSU1996][^W1992] of the ELBO gradient, also known as the score function method and the REINFORCE gradient. +For variational inference, the use of the score gradient was proposed in [^WW2013][^RGB2014]. +Unlike the [reparameterization gradient](@ref repgradelbo), the score gradient does not require the target log density to be differentiable, and does not differentiate through the sampling process of the variational approximation $q$. +Instead, it only requires gradients of the log density $\log q$. +For this reason, the score gradient is the standard method to deal with discrete variables and targets with discrete support. + +[^G1990]: Glynn, P. W. (1990). Likelihood ratio gradient estimation for stochastic systems. Communications of the ACM, 33(10), 75-84. +[^KR1996]: Kleijnen, J. P., & Rubinstein, R. Y. (1996). Optimization and sensitivity analysis of computer simulation models by the score function method. European Journal of Operational Research, 88(3), 413-427. +[^RSU1996]: Rubinstein, R. Y., Shapiro, A., & Uryasev, S. (1996). The score function method. Encyclopedia of Management Sciences, 1363-1366. +[^W1992]: Williams, R. J. (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8, 229-256. +[^WW2013]: Wingate, D., & Weber, T. (2013). Automated variational inference in probabilistic programming. arXiv preprint arXiv:1301.1299. +[^RGB2014]: Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. In Artificial intelligence and statistics (pp. 814-822). PMLR. +In more detail, the score gradient uses the Fisher log-derivative identity: For any regular $f$, + +```math +\nabla_{\lambda} \mathbb{E}_{z \sim q_{\lambda}} f += +\mathbb{E}_{z \sim q_{\lambda}}\left[ f(z) \log q_{\lambda}(z) \right] \; . +``` + +The ELBO corresponds to the case where $f = \log \pi / \log q$, where $\log \pi$ is the target log density. + +Instead of implementing the canonical score gradient, `ScoreGradELBO` uses the "VarGrad" objective[^RBNRA2020]: + +```math +\widehat{\mathrm{VarGrad}}(\lambda) += +\mathrm{Var}\left( \log q_{\lambda}(z_i) - \log \pi\left(z_i\right) \right) \; , +``` + +where the variance is computed over the samples $z_1, \ldots, z_m \sim q_{\lambda}$. +Differentiating the VarGrad objective corresponds to the canonical score gradient combined with the "leave-one-out" control variate[^SK2014][^KvHW2019]. + +[^RBNRA2020]: Richter, L., Boustati, A., Nüsken, N., Ruiz, F., & Akyildiz, O. D. (2020). Vargrad: a low-variance gradient estimator for variational inference. Advances in Neural Information Processing Systems, 33, 13481-13492. +[^SK2014]: Salimans, T., & Knowles, D. A. (2014). On using control variates with stochastic approximation for variational bayes and its connection to stochastic linear regression. arXiv preprint arXiv:1401.1022. +[^KvHW2019]: Kool, W., van Hoof, H., & Welling, M. (2019). Buy 4 reinforce samples, get a baseline for free!. +Since the expectation of the `VarGrad` objective (not its gradient) is not exactly the ELBO, we separately obtain an unbiased estimate of the ELBO to be returned by [`estimate_objective`](@ref). ## `ScoreGradELBO` From 27f3634e3300f55f598ef2a22f421dd62b16ca66 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Jun 2025 17:41:11 -0400 Subject: [PATCH 031/108] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/paramspacesgd/scoregradelbo.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/src/paramspacesgd/scoregradelbo.md b/docs/src/paramspacesgd/scoregradelbo.md index f957b4487..c100d73e7 100644 --- a/docs/src/paramspacesgd/scoregradelbo.md +++ b/docs/src/paramspacesgd/scoregradelbo.md @@ -38,8 +38,7 @@ Differentiating the VarGrad objective corresponds to the canonical score gradien [^RBNRA2020]: Richter, L., Boustati, A., Nüsken, N., Ruiz, F., & Akyildiz, O. D. (2020). Vargrad: a low-variance gradient estimator for variational inference. Advances in Neural Information Processing Systems, 33, 13481-13492. [^SK2014]: Salimans, T., & Knowles, D. A. (2014). On using control variates with stochastic approximation for variational bayes and its connection to stochastic linear regression. arXiv preprint arXiv:1401.1022. [^KvHW2019]: Kool, W., van Hoof, H., & Welling, M. (2019). Buy 4 reinforce samples, get a baseline for free!. -Since the expectation of the `VarGrad` objective (not its gradient) is not exactly the ELBO, we separately obtain an unbiased estimate of the ELBO to be returned by [`estimate_objective`](@ref). - + Since the expectation of the `VarGrad` objective (not its gradient) is not exactly the ELBO, we separately obtain an unbiased estimate of the ELBO to be returned by [`estimate_objective`](@ref). ## `ScoreGradELBO` ```@docs From 5be9cca1eaba7f7af0e45a810ad89acfe38365d5 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Jun 2025 17:41:18 -0400 Subject: [PATCH 032/108] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/paramspacesgd/scoregradelbo.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/src/paramspacesgd/scoregradelbo.md b/docs/src/paramspacesgd/scoregradelbo.md index c100d73e7..c86ceca10 100644 --- a/docs/src/paramspacesgd/scoregradelbo.md +++ b/docs/src/paramspacesgd/scoregradelbo.md @@ -14,8 +14,7 @@ For this reason, the score gradient is the standard method to deal with discrete [^W1992]: Williams, R. J. (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8, 229-256. [^WW2013]: Wingate, D., & Weber, T. (2013). Automated variational inference in probabilistic programming. arXiv preprint arXiv:1301.1299. [^RGB2014]: Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. In Artificial intelligence and statistics (pp. 814-822). PMLR. -In more detail, the score gradient uses the Fisher log-derivative identity: For any regular $f$, - + In more detail, the score gradient uses the Fisher log-derivative identity: For any regular $f$, ```math \nabla_{\lambda} \mathbb{E}_{z \sim q_{\lambda}} f = From 5309ec7e4888aca2ebd5af1298eb67eb9a861e09 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 29 Jun 2025 17:42:09 -0400 Subject: [PATCH 033/108] add scoregradelbo to the list of objectives --- docs/src/paramspacesgd/objectives.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/paramspacesgd/objectives.md b/docs/src/paramspacesgd/objectives.md index 82b7014af..d3a822cf7 100644 --- a/docs/src/paramspacesgd/objectives.md +++ b/docs/src/paramspacesgd/objectives.md @@ -30,6 +30,6 @@ Overall, ELBO maximization algorithms aim to solve the problem: Multiple ways to solve this problem exist, each leading to a different variational inference algorithm. `AdvancedVI` provides the following objectives: - [RepGradELBO](@ref repgradelbo): Implements the reparameterization gradient estimator of the ELBO gradient. - - ScoreGradELBO: Implements the score gradient estimator of the ELBO gradient. + - [ScoreGradELBO](@ref scoregradelbo): Implements the score gradient estimator of the ELBO gradient. [^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233. From f0eef5ffeac6adf788ddeab957b5fca85b2cbdba Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Jul 2025 16:13:53 -0400 Subject: [PATCH 034/108] move `prob` argument to `optimize` from the constructor of `alg` --- docs/src/examples.md | 4 +-- docs/src/paramspacesgd/repgradelbo.md | 13 +++++----- src/AdvancedVI.jl | 8 +++--- src/algorithms/paramspacesgd/constructors.jl | 12 +++------ src/algorithms/paramspacesgd/paramspacesgd.jl | 22 +++++++--------- src/optimize.jl | 26 ++++++++++--------- test/algorithms/paramspacesgd/repgradelbo.jl | 3 +-- .../repgradelbo_distributionsad.jl | 8 +++--- .../repgradelbo_locationscale.jl | 8 +++--- .../repgradelbo_locationscale_bijectors.jl | 8 +++--- .../repgradelbo_proximal_locationscale.jl | 8 +++--- ...adelbo_proximal_locationscale_bijectors.jl | 8 +++--- .../algorithms/paramspacesgd/scoregradelbo.jl | 3 +-- .../scoregradelbo_distributionsad.jl | 8 +++--- .../scoregradelbo_locationscale.jl | 8 +++--- .../scoregradelbo_locationscale_bijectors.jl | 8 +++--- test/general/optimize.jl | 12 ++++----- 17 files changed, 81 insertions(+), 86 deletions(-) diff --git a/docs/src/examples.md b/docs/src/examples.md index 1d934b15f..83f86e00c 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -85,7 +85,7 @@ We now need to select 1. a variational objective, and 2. a variational family. Here, we will use the [`RepGradELBO` objective](@ref repgradelbo), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. ```@example elboexample -alg = BBVIRepGrad(model, AutoForwardDiff()) +alg = BBVIRepGrad(AutoForwardDiff()) ``` For the variational family, we will use the classic mean-field Gaussian family. @@ -109,7 +109,7 @@ Passing `objective` and the initial variational approximation `q` to `optimize` ```@example elboexample n_max_iter = 10^4 -q_out, info, _ = AdvancedVI.optimize(alg, n_max_iter, q0_trans; show_progress=false); +q_out, info, _ = AdvancedVI.optimize(alg, n_max_iter, model, q0_trans; show_progress=false); nothing ``` diff --git a/docs/src/paramspacesgd/repgradelbo.md b/docs/src/paramspacesgd/repgradelbo.md index b6074bfd6..39dd167d2 100644 --- a/docs/src/paramspacesgd/repgradelbo.md +++ b/docs/src/paramspacesgd/repgradelbo.md @@ -184,18 +184,14 @@ binv = inverse(b) q0_trans = Bijectors.TransformedDistribution(q0, binv) -cfe = BBVIRepGrad( - model, AutoForwardDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2) -) +cfe = BBVIRepGrad(AutoForwardDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2)) nothing ``` The repgradelbo estimator can instead be created as follows: ```@example repgradelbo -stl = BBVIRepGrad( - model, AutoForwardDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2) -) +stl = BBVIRepGrad(AutoForwardDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)) nothing ``` @@ -212,6 +208,7 @@ end _, info_cfe, _ = AdvancedVI.optimize( cfe, max_iter, + model, q0_trans; show_progress = false, callback = callback, @@ -220,6 +217,7 @@ _, info_cfe, _ = AdvancedVI.optimize( _, info_stl, _ = AdvancedVI.optimize( stl, max_iter, + model, q0_trans; show_progress = false, callback = callback, @@ -300,8 +298,9 @@ nothing ```@setup repgradelbo _, info_qmc, _ = AdvancedVI.optimize( - BBVIRepGrad(model, AutoForwardDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)), + BBVIRepGrad(AutoForwardDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)), max_iter, + model, q0_trans; show_progress = false, callback = callback, diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index debb9c460..24e2a3f7f 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -18,6 +18,7 @@ using LogDensityProblems using ADTypes using DiffResults using DifferentiationInterface +using ChainRulesCore using FillArrays @@ -196,16 +197,17 @@ Abstract type for a variational inference algorithm. abstract type AbstractAlgorithm end """ - init(rng, alg, q_init) + init(rng, alg, prob, q_init) -Initialize `alg` given the initial variational approximation `q_init`. +Initialize `alg` given the initial variational approximation `q_init` and the target `prob`. # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `alg::AbstractAlgorithm`: Variational inference algorithm. +- `prob`: Target problem. ` `q_init`: Initial variational approximation. """ -init(::Random.AbstractRNG, ::AbstractAlgorithm, ::Any) = nothing +init(::Random.AbstractRNG, ::AbstractAlgorithm, ::Any, ::Any) = nothing """ step(rng, alg, state, callback, objargs...; kwargs...) diff --git a/src/algorithms/paramspacesgd/constructors.jl b/src/algorithms/paramspacesgd/constructors.jl index 3cc76042b..5072ef042 100644 --- a/src/algorithms/paramspacesgd/constructors.jl +++ b/src/algorithms/paramspacesgd/constructors.jl @@ -1,11 +1,10 @@ """ - BBVIRepGrad(problem, adtype; entropy, optimizer, n_samples, operator) + BBVIRepGrad(adtype; entropy, optimizer, n_samples, operator) Black-box variational inference with the reparameterization gradient with stochastic gradient descent. # Arguments -- `problem`: Target problem. - `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. # Keyword Arguments @@ -17,7 +16,6 @@ Black-box variational inference with the reparameterization gradient with stocha """ function BBVIRepGrad( - problem, adtype::ADTypes.AbstractADType; entropy::Union{<:ClosedFormEntropy,<:StickingTheLandingEntropy,<:MonteCarloEntropy}=ClosedFormEntropy(), optimizer::Optimisers.AbstractRule=DoWG(), @@ -26,11 +24,11 @@ function BBVIRepGrad( operator::Union{<:IdentityOperator,<:ClipScale}=ClipScale(), ) objective = RepGradELBO(n_samples; entropy=entropy) - return ParamSpaceSGD(problem, objective, adtype, optimizer, averager, operator) + return ParamSpaceSGD(objective, adtype, optimizer, averager, operator) end """ - BBVIRepGradProxLocScale(problem, adtype; entropy, optimizer, n_samples, operator) + BBVIRepGradProxLocScale(adtype; entropy, optimizer, n_samples, operator) Black-box variational inference with the reparameterization gradient and stochastic proximal gradient descent on location scale families. @@ -39,7 +37,6 @@ Also, since the stochastic proximal gradient descent does not use the entropy of Thus, only the entropy estimators with a "ZeroGradient" suffix are allowed. # Arguments -- `problem`: Target problem. - `adtype`: Automatic differentiation backend. # Keyword Arguments @@ -50,7 +47,6 @@ Thus, only the entropy estimators with a "ZeroGradient" suffix are allowed. """ function BBVIRepGradProxLocScale( - problem, adtype::ADTypes.AbstractADType; entropy_zerograd::Union{ <:ClosedFormEntropyZeroGradient,<:StickingTheLandingEntropyZeroGradient @@ -61,5 +57,5 @@ function BBVIRepGradProxLocScale( ) objective = RepGradELBO(n_samples; entropy=entropy_zerograd) operator = ProximalLocationScaleEntropy() - return ParamSpaceSGD(problem, objective, adtype, optimizer, averager, operator) + return ParamSpaceSGD(objective, adtype, optimizer, averager, operator) end diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index 98a574723..b633dac7d 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -1,7 +1,6 @@ """ ParamSpaceSGD( - problem, objective::AbstractVariationalObjective, adtype::ADTypes.AbstractADType, optimizer::Optimisers.AbstractRule, @@ -42,14 +41,12 @@ The arguments are as follows: """ struct ParamSpaceSGD{ - Prob, Obj<:AbstractVariationalObjective, AD<:ADTypes.AbstractADType, Opt<:Optimisers.AbstractRule, Avg<:AbstractAverager, Op<:AbstractOperator, } <: AbstractAlgorithm - problem::Prob objective::Obj adtype::AD optimizer::Opt @@ -57,7 +54,8 @@ struct ParamSpaceSGD{ operator::Op end -struct ParamSpaceSGDState{Q,GradBuf,OptSt,ObjSt,AvgSt} +struct ParamSpaceSGDState{P,Q,GradBuf,OptSt,ObjSt,AvgSt} + prob::P q::Q iteration::Int grad_buf::GradBuf @@ -66,14 +64,14 @@ struct ParamSpaceSGDState{Q,GradBuf,OptSt,ObjSt,AvgSt} avg_st::AvgSt end -function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init) - (; problem, adtype, optimizer, averager, objective) = alg +function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, prob, q_init) + (; adtype, optimizer, averager, objective) = alg params, re = Optimisers.destructure(q_init) opt_st = Optimisers.setup(optimizer, params) - obj_st = init(rng, objective, adtype, problem, params, re) + obj_st = init(rng, objective, adtype, prob, params, re) avg_st = init(averager, params) grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - return ParamSpaceSGDState(q_init, 0, grad_buf, opt_st, obj_st, avg_st) + return ParamSpaceSGDState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st) end function output(alg::ParamSpaceSGD, state) @@ -85,15 +83,15 @@ end function step( rng::Random.AbstractRNG, alg::ParamSpaceSGD, state, callback, objargs...; kwargs... ) - (; adtype, problem, objective, operator, averager) = alg - (; q, iteration, grad_buf, opt_st, obj_st, avg_st) = state + (; adtype, objective, operator, averager) = alg + (; prob, q, iteration, grad_buf, opt_st, obj_st, avg_st) = state iteration += 1 params, re = Optimisers.destructure(q) grad_buf, obj_st, info = estimate_gradient!( - rng, objective, adtype, grad_buf, problem, params, re, obj_st, objargs... + rng, objective, adtype, grad_buf, prob, params, re, obj_st, objargs... ) grad = DiffResults.gradient(grad_buf) @@ -101,7 +99,7 @@ function step( params = apply(operator, typeof(q), opt_st, params, re) avg_st = apply(averager, avg_st, params) - state = ParamSpaceSGDState(re(params), iteration, grad_buf, opt_st, obj_st, avg_st) + state = ParamSpaceSGDState(prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st) if !isnothing(callback) averaged_params = value(averager, avg_st) diff --git a/src/optimize.jl b/src/optimize.jl index c37837799..e538a0838 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -3,29 +3,30 @@ optimize( [rng::Random.AbstractRNG = Random.default_rng(),] algorithm::AbstractAlgorithm, - q_init, max_iter::Int, - objargs...; + prob, + q_init, + args...; kwargs... ) -Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients. - -The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. -This requires the variational approximation to be marked as a functor through `Functors.@functor`. +Run variational inference `algorithm` on the `problem` implementing the `LogDensityProblems` interface. +For more details on the usage, refer to the documentation corresponding to `algorithm`. # Arguments - `rng`: Random number generator. - `algorithm`: Variational inference algorithm. -- `q_init`: Initial variational distribution. - `max_iter::Int`: Maximum number of iterations. -- `objargs...`: Arguments to be passed to `objective`. +- `prob`: Target `LogDensityProblem` +- `q_init`: Initial variational distribution. +- `args...`: Arguments to be passed to `algorithm`. # Keyword Arguments - `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.) +- `state::Union{<:Any,Nothing}`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) - `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.) - `progress::ProgressMeter.AbstractProgress`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) -- `state::Union{<:Any,Nothing}`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) +- `kwargs...`: Keyword arguments to be passed to `algorithm`. # Returns - `output`: The output of the variational inference algorithm. @@ -42,6 +43,7 @@ function optimize( rng::Random.AbstractRNG, algorithm::AbstractAlgorithm, max_iter::Int, + prob, q_init, objargs...; show_progress::Bool=true, @@ -54,7 +56,7 @@ function optimize( ) info_total = NamedTuple[] state = if isnothing(state) - init(rng, algorithm, q_init) + init(rng, algorithm, prob, q_init) else state end @@ -79,9 +81,9 @@ function optimize( end function optimize( - algorithm::AbstractAlgorithm, max_iter::Int, q_init, objargs...; kwargs... + algorithm::AbstractAlgorithm, max_iter::Int, prob, q_init, objargs...; kwargs... ) return optimize( - Random.default_rng(), algorithm, max_iter, q_init, objargs...; kwargs... + Random.default_rng(), algorithm, max_iter, prob, q_init, objargs...; kwargs... ) end diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index 9d863487d..ea6cb4637 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -28,13 +28,12 @@ end @testset "basic" begin @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] alg = BBVIRepGrad( - model, adtype; n_samples=n_montecarlo, operator=IdentityOperator(), averager=PolynomialAveraging(), ) - _, info, _ = optimize(rng, alg, 10, q0; show_progress=false) + _, info, _ = optimize(rng, alg, 10, model, q0; show_progress=false) @assert isfinite(last(info).elbo) end end diff --git a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl index e998705b5..b2393e594 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl @@ -38,7 +38,7 @@ end T = 1000 η = 1e-3 - alg = BBVIRepGrad(model, adtype; operator=IdentityOperator(), optimizer=Descent(η)) + alg = BBVIRepGrad(adtype; operator=IdentityOperator(), optimizer=Descent(η)) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), @@ -48,7 +48,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = mean(q_avg) L = sqrt(cov(q_avg)) @@ -61,12 +61,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = mean(q_avg) L = sqrt(cov(q_avg)) rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) μ_repl = mean(q_avg) L_repl = sqrt(cov(q_avg)) @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index 2c0d2e5e4..3622a7d73 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -34,7 +34,7 @@ end T = 1000 η = 1e-3 - alg = BBVIRepGrad(model, adtype; optimizer=Descent(η)) + alg = BBVIRepGrad(adtype; optimizer=Descent(η)) q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) @@ -50,7 +50,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) - q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale @@ -63,12 +63,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) μ_repl = q_avg.location L_repl = q_avg.scale @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index 686ef3fd3..2df7c1e2e 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -34,7 +34,7 @@ end T = 1000 η = 1e-3 - alg = BBVIRepGrad(model, adtype; optimizer=Descent(η)) + alg = BBVIRepGrad(adtype; optimizer=Descent(η)) b = Bijectors.bijector(model) b⁻¹ = inverse(b) @@ -56,7 +56,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale @@ -69,12 +69,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0_z; show_progress=PROGRESS) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index d2faa4d7f..eadbe6c94 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -43,7 +43,7 @@ end T = 1000 η = 1e-3 - alg = BBVIRepGradProxLocScale(model, adtype; optimizer=Descent(η)) + alg = BBVIRepGradProxLocScale(adtype; optimizer=Descent(η)) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), @@ -52,7 +52,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) - q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale @@ -65,12 +65,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) μ_repl = q_avg.location L_repl = q_avg.scale @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index e08d76291..b4fa430ee 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -35,7 +35,7 @@ end T = 1000 η = 1e-3 - alg = BBVIRepGradProxLocScale(model, adtype; optimizer=Descent(η)) + alg = BBVIRepGradProxLocScale(adtype; optimizer=Descent(η)) b = Bijectors.bijector(model) b⁻¹ = inverse(b) @@ -57,7 +57,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale @@ -70,12 +70,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0_z; show_progress=PROGRESS) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/scoregradelbo.jl b/test/algorithms/paramspacesgd/scoregradelbo.jl index 366455876..6592149f2 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo.jl @@ -23,14 +23,13 @@ end @testset "basic" begin @testset for adtype in AD_scoregradelbo_interface, n_montecarlo in [1, 10] alg = ParamSpaceSGD( - model, ScoreGradELBO(n_montecarlo), adtype, Descent(1e-5), PolynomialAveraging(), IdentityOperator(), ) - _, info, _ = optimize(rng, alg, 10, q0; show_progress=false) + _, info, _ = optimize(rng, alg, 10, model, q0; show_progress=false) @assert isfinite(last(info).elbo) end end diff --git a/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl index 805c65a83..1bcd0a243 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl @@ -32,7 +32,7 @@ end opt = Optimisers.Descent(η) op = IdentityOperator() avg = PolynomialAveraging() - alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) + alg = ParamSpaceSGD(objective, adtype, opt, avg, op) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), @@ -45,7 +45,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = mean(q_avg) L = sqrt(cov(q_avg)) @@ -58,12 +58,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = mean(q_avg) L = sqrt(cov(q_avg)) rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) μ_repl = mean(q_avg) L_repl = sqrt(cov(q_avg)) @test μ ≈ μ_repl rtol = 1e-5 diff --git a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl index 6ec738905..c3af62eb9 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl @@ -34,7 +34,7 @@ end opt = Optimisers.Descent(η) op = ClipScale() avg = PolynomialAveraging() - alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) + alg = ParamSpaceSGD(objective, adtype, opt, avg, op) q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) @@ -50,7 +50,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) - q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale @@ -63,12 +63,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) μ_repl = q_avg.location L_repl = q_avg.scale @test μ ≈ μ_repl rtol = 1e-3 diff --git a/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl index 5d31e6e92..7662cccd1 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl @@ -33,7 +33,7 @@ end opt = Optimisers.Descent(η) op = ClipScale() avg = PolynomialAveraging() - alg = ParamSpaceSGD(model, objective, adtype, opt, avg, op) + alg = ParamSpaceSGD(objective, adtype, opt, avg, op) b = Bijectors.bijector(model) b⁻¹ = inverse(b) @@ -55,7 +55,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale @@ -68,12 +68,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0_z; show_progress=PROGRESS) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale @test μ ≈ μ_repl rtol = 1e-3 diff --git a/test/general/optimize.jl b/test/general/optimize.jl index 1f98edf5d..6882b74da 100644 --- a/test/general/optimize.jl +++ b/test/general/optimize.jl @@ -15,10 +15,10 @@ optimizer = Optimisers.Adam(1e-2) averager = PolynomialAveraging() - alg = ParamSpaceSGD(model, obj, adtype, optimizer, averager, IdentityOperator()) + alg = ParamSpaceSGD(obj, adtype, optimizer, averager, IdentityOperator()) @testset "default_rng" begin - optimize(alg, T, q0; show_progress=false) + optimize(alg, T, model, q0; show_progress=false) end @testset "callback" begin @@ -26,12 +26,12 @@ callback(; iteration, args...) = (test_value=test_values[iteration],) - _, info, _ = optimize(alg, T, q0; show_progress=false, callback) + _, info, _ = optimize(alg, T, model, q0; show_progress=false, callback) @test [i.test_value for i in info] == test_values end rng = StableRNG(seed) - q_avg_ref, _, _ = optimize(rng, alg, T, q0; show_progress=false) + q_avg_ref, _, _ = optimize(rng, alg, T, model, q0; show_progress=false) @testset "warm start" begin rng = StableRNG(seed) @@ -39,8 +39,8 @@ T_first = div(T, 2) T_last = T - T_first - _, _, state = optimize(rng, alg, T_first, q0; show_progress=false) - q_avg, _, _ = optimize(rng, alg, T_last, q0; show_progress=false, state) + _, _, state = optimize(rng, alg, T_first, model, q0; show_progress=false) + q_avg, _, _ = optimize(rng, alg, T_last, model, q0; show_progress=false, state) @test q_avg == q_avg_ref end From 35f2d3a364b3a6de63fc39fa85eb0639af69bbcd Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Jul 2025 16:17:46 -0400 Subject: [PATCH 035/108] run formatter --- docs/src/paramspacesgd/repgradelbo.md | 4 +++- src/algorithms/paramspacesgd/paramspacesgd.jl | 4 +++- .../paramspacesgd/repgradelbo_locationscale_bijectors.jl | 4 +++- .../repgradelbo_proximal_locationscale_bijectors.jl | 4 +++- .../paramspacesgd/scoregradelbo_locationscale_bijectors.jl | 4 +++- 5 files changed, 15 insertions(+), 5 deletions(-) diff --git a/docs/src/paramspacesgd/repgradelbo.md b/docs/src/paramspacesgd/repgradelbo.md index 39dd167d2..e5b8f48ed 100644 --- a/docs/src/paramspacesgd/repgradelbo.md +++ b/docs/src/paramspacesgd/repgradelbo.md @@ -191,7 +191,9 @@ nothing The repgradelbo estimator can instead be created as follows: ```@example repgradelbo -stl = BBVIRepGrad(AutoForwardDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)) +stl = BBVIRepGrad( + AutoForwardDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2) +) nothing ``` diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index b633dac7d..f70ac574b 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -99,7 +99,9 @@ function step( params = apply(operator, typeof(q), opt_st, params, re) avg_st = apply(averager, avg_st, params) - state = ParamSpaceSGDState(prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st) + state = ParamSpaceSGDState( + prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st + ) if !isnothing(callback) averaged_params = value(averager, avg_st) diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index 2df7c1e2e..f04ea8715 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -74,7 +74,9 @@ end L = q_avg.dist.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize( + rng_repl, alg, T, model, q0_z; show_progress=PROGRESS + ) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index b4fa430ee..0fa4862ca 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -75,7 +75,9 @@ end L = q_avg.dist.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize( + rng_repl, alg, T, model, q0_z; show_progress=PROGRESS + ) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl index 7662cccd1..be4bab86f 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl @@ -73,7 +73,9 @@ end L = q_avg.dist.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize( + rng_repl, alg, T, model, q0_z; show_progress=PROGRESS + ) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale @test μ ≈ μ_repl rtol = 1e-3 From 9e36ee3a579456991a889a6194dbe6f295982def Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Jul 2025 16:21:37 -0400 Subject: [PATCH 036/108] fix remove unused import --- src/AdvancedVI.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 24e2a3f7f..c01a972b8 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -18,7 +18,6 @@ using LogDensityProblems using ADTypes using DiffResults using DifferentiationInterface -using ChainRulesCore using FillArrays From a58c921a37d63a5477fdf96852b07e238314f4f9 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Jul 2025 16:26:15 -0400 Subject: [PATCH 037/108] fix benchmark --- bench/benchmarks.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 81b4a0594..25f694cbf 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -63,10 +63,10 @@ begin b = Bijectors.bijector(prob) binv = inverse(b) q = Bijectors.TransformedDistribution(family, binv) - alg = BBVIRepGrad(prob, adtype; optimizer=opt, entropy) + alg = BBVIRepGrad(adtype; optimizer=opt, entropy) SUITES[probname][objname][familyname][adname] = begin - @benchmarkable AdvancedVI.optimize($alg, $max_iter, $q; show_progress=false) + @benchmarkable AdvancedVI.optimize($alg, $max_iter, $prob, $q; show_progress=false) end end end From 9b5a8937d90d309358c830923727f01e5aad7a1a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Jul 2025 16:34:12 -0400 Subject: [PATCH 038/108] fix docs --- bench/benchmarks.jl | 4 +++- docs/src/families.md | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 25f694cbf..6220bbee3 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -66,7 +66,9 @@ begin alg = BBVIRepGrad(adtype; optimizer=opt, entropy) SUITES[probname][objname][familyname][adname] = begin - @benchmarkable AdvancedVI.optimize($alg, $max_iter, $prob, $q; show_progress=false) + @benchmarkable AdvancedVI.optimize( + $alg, $max_iter, $prob, $q; show_progress=false + ) end end end diff --git a/docs/src/families.md b/docs/src/families.md index 66c2c171d..e6691570c 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -177,7 +177,7 @@ D = ones(n_dims) U = zeros(n_dims, 3) q0_lr = LowRankGaussian(μ, D, U) -alg = BBVIRepGrad(model, AutoReverseDiff(); optimizer=Adam(0.01)) +alg = BBVIRepGrad(AutoReverseDiff(); optimizer=Adam(0.01)) max_iter = 10^4 @@ -188,19 +188,19 @@ function callback(; params, averaged_params, restructure, kwargs...) end _, info_fr, _ = AdvancedVI.optimize( - alg, max_iter, q0_fr; + alg, max_iter, model, q0_fr; show_progress = false, callback = callback, ); _, info_mf, _ = AdvancedVI.optimize( - alg, max_iter, q0_mf; + alg, max_iter, model, q0_mf; show_progress = false, callback = callback, ); _, info_lr, _ = AdvancedVI.optimize( - alg, max_iter, q0_lr; + alg, max_iter, model, q0_lr; show_progress = false, callback = callback, ); From 23ae1f8207db9fe9d152891d0a98683353d6d89e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Jul 2025 16:44:39 -0400 Subject: [PATCH 039/108] fix docs --- docs/src/general.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/general.md b/docs/src/general.md index 8c194a5eb..c10c5b746 100644 --- a/docs/src/general.md +++ b/docs/src/general.md @@ -27,7 +27,7 @@ AdvancedVI.AbstractAlgorithm The functionality of each algorithm is then implemented through the following methods: ```@docs -AdvancedVI.init(::Random.AbstractRNG, ::AdvancedVI.AbstractAlgorithm, ::Any) +AdvancedVI.init(::Random.AbstractRNG, ::AdvancedVI.AbstractAlgorithm, ::Any, ::Any) AdvancedVI.step AdvancedVI.output ``` From f0fd86b19ea591996c8e0a00b1c3cc13d23a703a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 7 Jul 2025 17:04:56 -0400 Subject: [PATCH 040/108] add dependencies --- Project.toml | 2 ++ docs/Project.toml | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/Project.toml b/Project.toml index 6ef7a08c6..fe6fdb2bb 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.5.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -28,6 +29,7 @@ AdvancedVIBijectorsExt = "Bijectors" ADTypes = "1" Accessors = "0.1" Bijectors = "0.13, 0.14, 0.15" +ChainRulesCore = "1" DiffResults = "1" DifferentiationInterface = "0.6, 0.7" Distributions = "0.25.111" diff --git a/docs/Project.toml b/docs/Project.toml index 81d9f5fc5..fdc7ed68a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,11 +2,16 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" @@ -17,11 +22,16 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" ADTypes = "1" AdvancedVI = "0.5" Bijectors = "0.13.6, 0.14, 0.15" +DataFrames = "1" +DelimitedFiles = "1" Distributions = "0.25" Documenter = "1" FillArrays = "1" ForwardDiff = "0.10, 1" +Functors = "0.5" +GZip = "0.6" LogDensityProblems = "2.1.1" +NormalizingFlows = "0.2" Optimisers = "0.3, 0.4" Plots = "1" QuasiMonteCarlo = "0.3" From 744c9c535cd5dd992b835b29f9c3c6b43d7cf533 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 7 Jul 2025 18:20:44 -0400 Subject: [PATCH 041/108] add mixed ad log-density problem wrapper --- Project.toml | 7 +++ ext/AdvancedVIMooncakeExt.jl | 9 ++++ ext/AdvancedVIReverseDiffExt.jl | 11 ++++ src/AdvancedVI.jl | 3 ++ .../paramspacesgd/abstractobjective.jl | 3 +- src/algorithms/paramspacesgd/paramspacesgd.jl | 4 +- src/algorithms/paramspacesgd/repgradelbo.jl | 29 +++++++---- src/algorithms/paramspacesgd/scoregradelbo.jl | 11 ++-- src/mixed_ad_logdensity.jl | 51 +++++++++++++++++++ src/optimize.jl | 4 +- 10 files changed, 112 insertions(+), 20 deletions(-) create mode 100644 ext/AdvancedVIMooncakeExt.jl create mode 100644 ext/AdvancedVIReverseDiffExt.jl create mode 100644 src/mixed_ad_logdensity.jl diff --git a/Project.toml b/Project.toml index fe6fdb2bb..9bcebbe45 100644 --- a/Project.toml +++ b/Project.toml @@ -21,9 +21,14 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [extensions] AdvancedVIBijectorsExt = "Bijectors" +AdvancedVIMooncakeExt = "Mooncake" +AdvancedVIReverseDiffExt = "ReverseDiff" + [compat] ADTypes = "1" @@ -46,6 +51,8 @@ julia = "1.10, 1.11.2" [extras] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/ext/AdvancedVIMooncakeExt.jl b/ext/AdvancedVIMooncakeExt.jl new file mode 100644 index 000000000..b6d4a4683 --- /dev/null +++ b/ext/AdvancedVIMooncakeExt.jl @@ -0,0 +1,9 @@ +module AdvancedVIMooncakeExt + +using AdvancedVI +using LogDensityProblems +using Mooncake + +Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(LogDensityProblems.logdensity), AdvancedVI.MixedADLogDensityProblem, AbstractVector} + +end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl new file mode 100644 index 000000000..185b722eb --- /dev/null +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -0,0 +1,11 @@ +module AdvancedVIReverseDiffExt + +using AdvancedVI +using LogDensityProblems +using ReverseDiff + +ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity(prob::AdvancedVI.MixedADLogDensityProblem, x::ReverseDiff.TrackedArray) + +ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity(prob::AdvancedVI.MixedADLogDensityProblem, x::AbstractVector) + +end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index c01a972b8..1993080a3 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -18,6 +18,7 @@ using LogDensityProblems using ADTypes using DiffResults using DifferentiationInterface +using ChainRulesCore using FillArrays @@ -95,6 +96,8 @@ This is an indirection for handling the type stability of `restructure`, as some """ restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params) +include("mixed_ad_logdensity.jl") + # Variational Families export MvLocationScale, MeanFieldGaussian, FullRankGaussian diff --git a/src/algorithms/paramspacesgd/abstractobjective.jl b/src/algorithms/paramspacesgd/abstractobjective.jl index 29b146d38..18759bfca 100644 --- a/src/algorithms/paramspacesgd/abstractobjective.jl +++ b/src/algorithms/paramspacesgd/abstractobjective.jl @@ -59,7 +59,7 @@ function estimate_objective end export estimate_objective """ - estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state) + estimate_gradient!(rng, obj, adtype, out, params, restructure, obj_state, args...) Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` @@ -68,7 +68,6 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ - `obj::AbstractVariationalObjective`: Variational objective. - `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. - `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. -- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. - `params`: Variational parameters to evaluate the gradient on. - `restructure`: Function that reconstructs the variational approximation from `params`. - `obj_state`: Previous state of the objective. diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index f70ac574b..10819cf1d 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -81,7 +81,7 @@ function output(alg::ParamSpaceSGD, state) end function step( - rng::Random.AbstractRNG, alg::ParamSpaceSGD, state, callback, objargs...; kwargs... + rng::Random.AbstractRNG, alg::ParamSpaceSGD, state, callback, args...; kwargs... ) (; adtype, objective, operator, averager) = alg (; prob, q, iteration, grad_buf, opt_st, obj_st, avg_st) = state @@ -91,7 +91,7 @@ function step( params, re = Optimisers.destructure(q) grad_buf, obj_st, info = estimate_gradient!( - rng, objective, adtype, grad_buf, prob, params, re, obj_st, objargs... + rng, objective, adtype, grad_buf, params, re, obj_st, args... ) grad = DiffResults.gradient(grad_buf) diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl index 758ec95a2..d0039219b 100644 --- a/src/algorithms/paramspacesgd/repgradelbo.jl +++ b/src/algorithms/paramspacesgd/repgradelbo.jl @@ -13,7 +13,8 @@ Evidence lower-bound objective with the reparameterization gradient formulation[ # Requirements - The variational approximation ``q_{\\lambda}`` implements `rand`. - The target distribution and the variational approximation have the same support. -- The target `logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend. +- The target `LogDensityProblem` must have a capability at least `LogDensityProblems.LogDensityOrder{1}()`. +- The sampling process `rand(q)` must be differentiable by the selected AD backend. Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. """ @@ -26,22 +27,32 @@ function init( rng::Random.AbstractRNG, obj::RepGradELBO, adtype::ADTypes.AbstractADType, - prob, + prob::Prob, params, restructure, -) +) where {Prob} q_stop = restructure(params) + capability = LogDensityProblems.capabilities(Prob) + @assert adtype in [AutoReverseDiff(), AutoZygote(), AutoMooncake(), AutoEnzyme()] + ad_prob = if capability < LogDensityProblems.LogDensityOrder{1}() + @warn "The capability of the provided log-density problem $(capability) is less than $(LogDensityProblems.LogDensityOrder{1}())" * + "Will attempt to directly differentiate through `LogDensityProblems.logdensity`." * "If this is not intended, please supply a log-density problem with cabality at least $(LogDensityProblems.LogDensityOrder{1}())" + prob + else + MixedADLogDensityProblem(prob) + end aux = ( rng=rng, adtype=adtype, obj=obj, - problem=prob, + problem=ad_prob, restructure=restructure, q_stop=q_stop, ) - return AdvancedVI._prepare_gradient( + obj_ad_prep = AdvancedVI._prepare_gradient( estimate_repgradelbo_ad_forward, adtype, params, aux ) + return (obj_ad_prep=obj_ad_prep, problem=ad_prob) end function RepGradELBO(n_samples::Int; entropy::AbstractEntropyEstimator=ClosedFormEntropy()) @@ -128,23 +139,23 @@ function estimate_gradient!( obj::RepGradELBO, adtype::ADTypes.AbstractADType, out::DiffResults.MutableDiffResult, - prob, params, restructure, state, + args... ) - prep = state + (; obj_ad_prep, problem) = state q_stop = restructure(params) aux = ( rng=rng, adtype=adtype, obj=obj, - problem=prob, + problem=problem, restructure=restructure, q_stop=q_stop, ) AdvancedVI._value_and_gradient!( - estimate_repgradelbo_ad_forward, out, prep, adtype, params, aux + estimate_repgradelbo_ad_forward, out, obj_ad_prep, adtype, params, aux ) nelbo = DiffResults.value(out) stat = (elbo=(-nelbo),) diff --git a/src/algorithms/paramspacesgd/scoregradelbo.jl b/src/algorithms/paramspacesgd/scoregradelbo.jl index 24caf934b..f61e654b9 100644 --- a/src/algorithms/paramspacesgd/scoregradelbo.jl +++ b/src/algorithms/paramspacesgd/scoregradelbo.jl @@ -28,9 +28,10 @@ function init( samples = rand(rng, q, obj.n_samples) ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure) - return AdvancedVI._prepare_gradient( + obj_ad_prep = AdvancedVI._prepare_gradient( estimate_scoregradelbo_ad_forward, adtype, params, aux ) + return (obj_ad_prep=obj_ad_prep, problem=prob,) end function Base.show(io::IO, obj::ScoreGradELBO) @@ -82,18 +83,18 @@ function AdvancedVI.estimate_gradient!( obj::ScoreGradELBO, adtype::ADTypes.AbstractADType, out::DiffResults.MutableDiffResult, - prob, params, restructure, state, + args... ) q = restructure(params) - prep = state + (; obj_ad_prep, problem) = state samples = rand(rng, q, obj.n_samples) - ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) + ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, problem), eachsample(samples)) aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure) AdvancedVI._value_and_gradient!( - estimate_scoregradelbo_ad_forward, out, prep, adtype, params, aux + estimate_scoregradelbo_ad_forward, out, obj_ad_prep, adtype, params, aux ) ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples)) elbo = mean(ℓπ - ℓq) diff --git a/src/mixed_ad_logdensity.jl b/src/mixed_ad_logdensity.jl new file mode 100644 index 000000000..b6b30d0ba --- /dev/null +++ b/src/mixed_ad_logdensity.jl @@ -0,0 +1,51 @@ + +""" + MixedADLogDensityProblem(problem) + +A `LogDensityProblem` wrapper for mixing AD frameworks. + Whenever the outer AD framework attempts to differentiate through `logdensity(problem)` +the pullback calls `logdensity_and_gradient`, which invokes the inner AD framework. +""" +struct MixedADLogDensityProblem{Problem} + problem::Problem +end + +function LogDensityProblems.dimension(mixedad_prob::MixedADLogDensityProblem) + return LogDensityProblems.dimension(mixedad_prob.problem) +end + +function LogDensityProblems.logdensity( + mixedad_prob::MixedADLogDensityProblem, x::AbstractArray +) + return LogDensityProblems.logdensity(mixedad_prob.problem, x) +end + +function LogDensityProblems.logdensity_and_gradient( + mixedad_prob::MixedADLogDensityProblem, x::AbstractArray +) + return LogDensityProblems.logdensity_and_gradient(mixedad_prob.problem, x) +end + +function LogDensityProblems.logdensity_gradient_and_hessian( + mixedad_prob::MixedADLogDensityProblem, x::AbstractArray +) + return LogDensityProblems.logdensity_and_gradient(mixedad_prob.problem, x) +end + +function LogDensityProblems.capabilities(mixedad_prob::MixedADLogDensityProblem) + return LogDensityProblems.capabilities(mixedad_prob.problem) +end + +function ChainRulesCore.rrule( + ::typeof(LogDensityProblems.logdensity), mixedad_prob::MixedADLogDensityProblem, x::AbstractArray +) + ℓπ, ∇ℓπ = LogDensityProblems.logdensity_and_gradient(mixedad_prob.problem, x) + function logdensity_pullback(∂y) + ∂x = @thunk(∂y' * ∇ℓπ) + return NoTangent(), NoTangent(), ∂x + end + return ℓπ, logdensity_pullback +end + + + diff --git a/src/optimize.jl b/src/optimize.jl index e538a0838..a32493275 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -45,7 +45,7 @@ function optimize( max_iter::Int, prob, q_init, - objargs...; + args...; show_progress::Bool=true, state::Union{<:Any,Nothing}=nothing, callback=nothing, @@ -65,7 +65,7 @@ function optimize( info = (iteration=t,) state, terminate, info′ = step( - rng, algorithm, state, callback, objargs...; kwargs... + rng, algorithm, state, callback, args...; kwargs... ) info = merge(info′, info) From 12246523447d7693efcc90a0b17337c1e2f765b4 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 7 Jul 2025 18:21:34 -0400 Subject: [PATCH 042/108] update benchmarks --- bench/benchmarks.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 6220bbee3..5bdfc0625 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -8,7 +8,7 @@ using Enzyme, ForwardDiff, ReverseDiff, Zygote, Mooncake using FillArrays using InteractiveUtils using LinearAlgebra -using LogDensityProblems +using LogDensityProblems, LogDensityProblemsAD using Optimisers using Random @@ -40,6 +40,7 @@ begin max_iter = 10^4 d = LogDensityProblems.dimension(prob) opt = Optimisers.Adam(T(1e-3)) + prob_ad = ADgradient(AutoForwardDiff(), prob) for (objname, entropy) in [ ("RepGradELBO", ClosedFormEntropy()), @@ -47,7 +48,6 @@ begin ], (adname, adtype) in [ ("Zygote", AutoZygote()), - ("ForwardDiff", AutoForwardDiff()), ("ReverseDiff", AutoReverseDiff()), ("Mooncake", AutoMooncake(; config=Mooncake.Config())), # ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)), @@ -67,7 +67,7 @@ begin SUITES[probname][objname][familyname][adname] = begin @benchmarkable AdvancedVI.optimize( - $alg, $max_iter, $prob, $q; show_progress=false + $alg, $max_iter, $prob_ad, $q; show_progress=false ) end end From c6380cb60ee5568a86162d79548b7af4f9309358 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 7 Jul 2025 18:35:52 -0400 Subject: [PATCH 043/108] add Enzyme extension --- Project.toml | 3 +++ ext/AdvancedVIEnzymeExt.jl | 9 +++++++++ 2 files changed, 12 insertions(+) create mode 100644 ext/AdvancedVIEnzymeExt.jl diff --git a/Project.toml b/Project.toml index 9bcebbe45..a88980f3d 100644 --- a/Project.toml +++ b/Project.toml @@ -21,11 +21,13 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [extensions] AdvancedVIBijectorsExt = "Bijectors" +AdvancedVIEnzymeExt = "Enzyme" AdvancedVIMooncakeExt = "Mooncake" AdvancedVIReverseDiffExt = "ReverseDiff" @@ -51,6 +53,7 @@ julia = "1.10, 1.11.2" [extras] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl new file mode 100644 index 000000000..af684bc8a --- /dev/null +++ b/ext/AdvancedVIEnzymeExt.jl @@ -0,0 +1,9 @@ +module AdvancedVIEnzymeExt + +using AdvancedVI +using LogDensityProblems +using Enzyme + +Enzyme.@import_rrule(typeof(LogDensityProblems.logdensity), AdvancedVI.MixedADLogDensityProblem, AbstractVector) + +end From 54b1ffffe0a608c23df0a38554008a7bee1d0dbe Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 7 Jul 2025 18:36:05 -0400 Subject: [PATCH 044/108] fix type constraints in `RepGradELBO` --- src/algorithms/paramspacesgd/repgradelbo.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl index d0039219b..a102ad1f7 100644 --- a/src/algorithms/paramspacesgd/repgradelbo.jl +++ b/src/algorithms/paramspacesgd/repgradelbo.jl @@ -14,6 +14,7 @@ Evidence lower-bound objective with the reparameterization gradient formulation[ - The variational approximation ``q_{\\lambda}`` implements `rand`. - The target distribution and the variational approximation have the same support. - The target `LogDensityProblem` must have a capability at least `LogDensityProblems.LogDensityOrder{1}()`. +- Only the AD backend `ReverseDiff`, `Zygote`, `Mooncake` are supported. - The sampling process `rand(q)` must be differentiable by the selected AD backend. Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. @@ -33,7 +34,7 @@ function init( ) where {Prob} q_stop = restructure(params) capability = LogDensityProblems.capabilities(Prob) - @assert adtype in [AutoReverseDiff(), AutoZygote(), AutoMooncake(), AutoEnzyme()] + @assert adtype isa Union{<:AutoReverseDiff, <:AutoZygote, <:AutoMooncake, <:AutoEnzyme} ad_prob = if capability < LogDensityProblems.LogDensityOrder{1}() @warn "The capability of the provided log-density problem $(capability) is less than $(LogDensityProblems.LogDensityOrder{1}())" * "Will attempt to directly differentiate through `LogDensityProblems.logdensity`." * "If this is not intended, please supply a log-density problem with cabality at least $(LogDensityProblems.LogDensityOrder{1}())" From e8a672e2afefe42defc608a78145489804beb9c3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 7 Jul 2025 18:37:03 -0400 Subject: [PATCH 045/108] update tests, remove `DistributionsAD` --- test/Project.toml | 4 +- test/algorithms/paramspacesgd/repgradelbo.jl | 11 ++- .../repgradelbo_distributionsad.jl | 7 +- .../repgradelbo_locationscale.jl | 9 ++- .../repgradelbo_locationscale_bijectors.jl | 9 ++- .../repgradelbo_proximal_locationscale.jl | 9 ++- ...adelbo_proximal_locationscale_bijectors.jl | 9 ++- .../scoregradelbo_distributionsad.jl | 73 ------------------- test/general/optimize.jl | 13 ++-- test/runtests.jl | 7 +- 10 files changed, 42 insertions(+), 109 deletions(-) delete mode 100644 test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl diff --git a/test/Project.toml b/test/Project.toml index b626b1218..90a0c6815 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,13 +4,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" @@ -30,13 +30,13 @@ Bijectors = "0.13, 0.14, 0.15" DiffResults = "1" DifferentiationInterface = "0.6, 0.7" Distributions = "0.25.111" -DistributionsAD = "0.6.45" Enzyme = "0.13, 0.14, 0.15" FillArrays = "1.6.1" ForwardDiff = "0.10.36, 1" Functors = "0.4.5, 0.5" LinearAlgebra = "1" LogDensityProblems = "2.1.1" +LogDensityProblemsAD = "1" Mooncake = "0.4" Optimisers = "0.2.16, 0.3, 0.4" PDMats = "0.11.7" diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index ea6cb4637..9f92b1a8d 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -8,7 +8,6 @@ AD_repgradelbo_interface = if TEST_GROUP == "Enzyme" ] else [ - AutoForwardDiff(), AutoReverseDiff(), AutoZygote(), AutoMooncake(; config=Mooncake.Config()), @@ -27,13 +26,14 @@ end @testset "basic" begin @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] + model_ad = ADgradient(AutoForwardDiff(), model) alg = BBVIRepGrad( adtype; n_samples=n_montecarlo, operator=IdentityOperator(), averager=PolynomialAveraging(), ) - _, info, _ = optimize(rng, alg, 10, model, q0; show_progress=false) + _, info, _ = optimize(rng, alg, 10, model_ad, q0; show_progress=false) @assert isfinite(last(info).elbo) end end @@ -61,7 +61,12 @@ end modelstats = normal_meanfield(rng, Float64) (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats + model_ad = ADgradient(AutoForwardDiff(), model) + mixed_ad = MixedADLogDensityProblem(model_ad) + @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] + model_ad = ADgradient(AutoForwardDiff(), model) + q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) @@ -70,7 +75,7 @@ end out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) aux = ( - rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=adtype + rng=rng, obj=obj, problem=mixed_ad, restructure=re, q_stop=q_true, adtype=adtype ) AdvancedVI._value_and_gradient!( AdvancedVI.estimate_repgradelbo_ad_forward, out, adtype, params, aux diff --git a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl index b2393e594..52e8796a6 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl @@ -8,7 +8,6 @@ AD_repgradelbo_distributionsad = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), @@ -36,6 +35,8 @@ end L0 = Diagonal(ones(realtype, n_dims)) q0 = TuringDiagMvNormal(μ0, diag(L0)) + model_ad = ADgradient(AutoForwardDiff(), model) + T = 1000 η = 1e-3 alg = BBVIRepGrad(adtype; operator=IdentityOperator(), optimizer=Descent(η)) @@ -48,7 +49,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0; show_progress=PROGRESS) μ = mean(q_avg) L = sqrt(cov(q_avg)) @@ -61,7 +62,7 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0; show_progress=PROGRESS) μ = mean(q_avg) L = sqrt(cov(q_avg)) diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index 3622a7d73..7bf10a30f 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -7,7 +7,6 @@ AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), @@ -32,6 +31,8 @@ end modelstats = modelconstr(rng, realtype) (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + model_ad = ADgradient(AutoForwardDiff(), model) + T = 1000 η = 1e-3 alg = BBVIRepGrad(adtype; optimizer=Descent(η)) @@ -50,7 +51,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale @@ -63,12 +64,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng_repl, alg, T, model_ad, q0; show_progress=PROGRESS) μ_repl = q_avg.location L_repl = q_avg.scale @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index f04ea8715..361c854c5 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -7,7 +7,6 @@ AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), @@ -36,6 +35,8 @@ end η = 1e-3 alg = BBVIRepGrad(adtype; optimizer=Descent(η)) + model_ad = ADgradient(AutoForwardDiff(), model) + b = Bijectors.bijector(model) b⁻¹ = inverse(b) μ0 = Zeros(realtype, n_dims) @@ -56,7 +57,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale @@ -69,13 +70,13 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale rng_repl = StableRNG(seed) q_avg, stats, _ = optimize( - rng_repl, alg, T, model, q0_z; show_progress=PROGRESS + rng_repl, alg, T, model_ad, q0_z; show_progress=PROGRESS ) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index eadbe6c94..4a5e9d45d 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -8,7 +8,6 @@ AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), @@ -34,6 +33,8 @@ end modelstats = modelconstr(rng, realtype) (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + model_ad = ADgradient(AutoForwardDiff(), model) + q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) else @@ -52,7 +53,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale @@ -65,12 +66,12 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng_repl, alg, T, model_ad, q0; show_progress=PROGRESS) μ_repl = q_avg.location L_repl = q_avg.scale @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index 0fa4862ca..6a76a5659 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -7,7 +7,6 @@ AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), @@ -33,6 +32,8 @@ end modelstats = modelconstr(rng, realtype) (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + model_ad = ADgradient(AutoForwardDiff(), model) + T = 1000 η = 1e-3 alg = BBVIRepGradProxLocScale(adtype; optimizer=Descent(η)) @@ -57,7 +58,7 @@ end @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale @@ -70,13 +71,13 @@ end @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale rng_repl = StableRNG(seed) q_avg, stats, _ = optimize( - rng_repl, alg, T, model, q0_z; show_progress=PROGRESS + rng_repl, alg, T, model_ad, q0_z; show_progress=PROGRESS ) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale diff --git a/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl deleted file mode 100644 index 1bcd0a243..000000000 --- a/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl +++ /dev/null @@ -1,73 +0,0 @@ -AD_scoregradelbo_distributionsad = if TEST_GROUP == "Enzyme" - Dict( - :Enzyme => AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ) -else - Dict( - :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - #:Mooncake => AutoMooncake(; config=Mooncake.Config()), - ) -end - -@testset "inference ScoreGradELBO DistributionsAD" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in - [Float64, Float32], - (modelname, modelconstr) in Dict(:Normal => normal_meanfield), - (objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(10)), - (adbackname, adtype) in AD_scoregradelbo_distributionsad - - seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) - - modelstats = modelconstr(rng, realtype) - (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats - - T = 1000 - η = 1e-4 - opt = Optimisers.Descent(η) - op = IdentityOperator() - avg = PolynomialAveraging() - alg = ParamSpaceSGD(objective, adtype, opt, avg, op) - - # For small enough η, the error of SGD, Δλ, is bounded as - # Δλ ≤ ρ^T Δλ0 + O(η), - # where ρ = 1 - ημ, μ is the strong convexity constant. - contraction_rate = 1 - η * strong_convexity - - μ0 = zeros(realtype, n_dims) - L0 = Diagonal(ones(realtype, n_dims)) - q0 = TuringDiagMvNormal(μ0, diag(L0)) - - @testset "convergence" begin - Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) - - μ = mean(q_avg) - L = sqrt(cov(q_avg)) - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - - @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 - @test eltype(μ) == eltype(μ_true) - @test eltype(L) == eltype(L_true) - end - - @testset "determinism" begin - rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) - μ = mean(q_avg) - L = sqrt(cov(q_avg)) - - rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) - μ_repl = mean(q_avg) - L_repl = sqrt(cov(q_avg)) - @test μ ≈ μ_repl rtol = 1e-5 - @test L ≈ L_repl rtol = 1e-5 - end - end -end diff --git a/test/general/optimize.jl b/test/general/optimize.jl index 6882b74da..2de2432fa 100644 --- a/test/general/optimize.jl +++ b/test/general/optimize.jl @@ -10,15 +10,16 @@ q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims))) obj = RepGradELBO(10) + model_ad = ADgradient(AutoForwardDiff(), model) - adtype = AutoForwardDiff() + adtype = AutoReverseDiff() optimizer = Optimisers.Adam(1e-2) averager = PolynomialAveraging() alg = ParamSpaceSGD(obj, adtype, optimizer, averager, IdentityOperator()) @testset "default_rng" begin - optimize(alg, T, model, q0; show_progress=false) + optimize(alg, T, model_ad, q0; show_progress=false) end @testset "callback" begin @@ -26,12 +27,12 @@ callback(; iteration, args...) = (test_value=test_values[iteration],) - _, info, _ = optimize(alg, T, model, q0; show_progress=false, callback) + _, info, _ = optimize(alg, T, model_ad, q0; show_progress=false, callback) @test [i.test_value for i in info] == test_values end rng = StableRNG(seed) - q_avg_ref, _, _ = optimize(rng, alg, T, model, q0; show_progress=false) + q_avg_ref, _, _ = optimize(rng, alg, T, model_ad, q0; show_progress=false) @testset "warm start" begin rng = StableRNG(seed) @@ -39,8 +40,8 @@ T_first = div(T, 2) T_last = T - T_first - _, _, state = optimize(rng, alg, T_first, model, q0; show_progress=false) - q_avg, _, _ = optimize(rng, alg, T_last, model, q0; show_progress=false, state) + _, _, state = optimize(rng, alg, T_first, model_ad, q0; show_progress=false) + q_avg, _, _ = optimize(rng, alg, T_last, model_ad, q0; show_progress=false, state) @test q_avg == q_avg_ref end diff --git a/test/runtests.jl b/test/runtests.jl index 04bc841c1..f7180bfba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,7 @@ using DiffResults using Distributions using FillArrays using LinearAlgebra -using LogDensityProblems +using LogDensityProblems, LogDensityProblemsAD using Optimisers using PDMats using Pkg @@ -16,10 +16,6 @@ using Random, StableRNGs using Statistics using StatsBase -using Functors -using DistributionsAD -@functor TuringDiagMvNormal - using ADTypes using ForwardDiff, ReverseDiff, Zygote, Mooncake @@ -71,7 +67,6 @@ if TEST_GROUP == "All" || TEST_GROUP == "ParamSpaceSGD" || TEST_GROUP == "Enzyme include("algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl") include("algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl") include("algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl") - include("algorithms/paramspacesgd/scoregradelbo_distributionsad.jl") include("algorithms/paramspacesgd/scoregradelbo_locationscale.jl") include("algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl") end From cd9d778d5061db6df18f255227b4d648694f9916 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 7 Jul 2025 18:53:08 -0400 Subject: [PATCH 046/108] fix docs --- docs/src/examples.md | 14 +++++++++++--- docs/src/families.md | 11 ++++++----- docs/src/paramspacesgd/repgradelbo.md | 25 +++++++++++++------------ 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/docs/src/examples.md b/docs/src/examples.md index 83f86e00c..d534ed02b 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -48,6 +48,14 @@ n_dims = 10 μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)); + +``` + +The algorithm we will use requires the target problem to have at least first-order differentiation capability. +```@example elboexample +using LogDensityProblemsAD + +model_ad = ADgradient(AutoForwardDiff(), model) nothing ``` @@ -77,7 +85,7 @@ Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes ```@example elboexample using Optimisers -using ADTypes, ForwardDiff +using ADTypes, ForwardDiff, ReverseDiff using AdvancedVI ``` @@ -85,7 +93,7 @@ We now need to select 1. a variational objective, and 2. a variational family. Here, we will use the [`RepGradELBO` objective](@ref repgradelbo), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. ```@example elboexample -alg = BBVIRepGrad(AutoForwardDiff()) +alg = BBVIRepGrad(AutoReverseDiff()) ``` For the variational family, we will use the classic mean-field Gaussian family. @@ -109,7 +117,7 @@ Passing `objective` and the initial variational approximation `q` to `optimize` ```@example elboexample n_max_iter = 10^4 -q_out, info, _ = AdvancedVI.optimize(alg, n_max_iter, model, q0_trans; show_progress=false); +q_out, info, _ = AdvancedVI.optimize(alg, n_max_iter, model_ad, q0_trans; show_progress=false); nothing ``` diff --git a/docs/src/families.md b/docs/src/families.md index e6691570c..99c906c1b 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -135,10 +135,10 @@ using ADTypes using AdvancedVI using Distributions using LinearAlgebra -using LogDensityProblems +using LogDensityProblems, LogDensityProblemsAD using Optimisers using Plots -using ReverseDiff +using ForwardDiff, ReverseDiff struct Target{D} dist::D @@ -163,6 +163,7 @@ D_true = Diagonal(log.(1 .+ exp.(randn(n_dims)))) Σsqrt_true = sqrt(Σ_true) μ_true = randn(n_dims) model = Target(MvNormal(μ_true, Σ_true)); +model_ad = ADgradient(AutoForwardDiff(), model) d = LogDensityProblems.dimension(model); μ = zeros(d); @@ -188,19 +189,19 @@ function callback(; params, averaged_params, restructure, kwargs...) end _, info_fr, _ = AdvancedVI.optimize( - alg, max_iter, model, q0_fr; + alg, max_iter, model_ad, q0_fr; show_progress = false, callback = callback, ); _, info_mf, _ = AdvancedVI.optimize( - alg, max_iter, model, q0_mf; + alg, max_iter, model_ad, q0_mf; show_progress = false, callback = callback, ); _, info_lr, _ = AdvancedVI.optimize( - alg, max_iter, model, q0_lr; + alg, max_iter, model_ad, q0_lr; show_progress = false, callback = callback, ); diff --git a/docs/src/paramspacesgd/repgradelbo.md b/docs/src/paramspacesgd/repgradelbo.md index e5b8f48ed..52e1a125d 100644 --- a/docs/src/paramspacesgd/repgradelbo.md +++ b/docs/src/paramspacesgd/repgradelbo.md @@ -122,12 +122,12 @@ For example, if ``q_{\lambda}`` is a Gaussian with a full-rank covariance, a bac using Bijectors using FillArrays using LinearAlgebra -using LogDensityProblems +using LogDensityProblems, LogDensityProblemsAD using Plots using Random using Optimisers -using ADTypes, ForwardDiff +using ADTypes, ForwardDiff, ReverseDiff using AdvancedVI struct NormalLogNormal{MX,SX,MY,SY} @@ -150,12 +150,13 @@ function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) LogDensityProblems.LogDensityOrder{0}() end -n_dims = 10 -μ_x = 2.0 -σ_x = 0.3 -μ_y = Fill(2.0, n_dims) -σ_y = Fill(1.0, n_dims) -model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); +n_dims = 10 +μ_x = 2.0 +σ_x = 0.3 +μ_y = Fill(2.0, n_dims) +σ_y = Fill(1.0, n_dims) +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); +model_ad = ADgradient(AutoForwardDiff(), model) d = LogDensityProblems.dimension(model); μ = zeros(d); @@ -192,7 +193,7 @@ The repgradelbo estimator can instead be created as follows: ```@example repgradelbo stl = BBVIRepGrad( - AutoForwardDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2) + AutoReverseDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2) ) nothing ``` @@ -210,7 +211,7 @@ end _, info_cfe, _ = AdvancedVI.optimize( cfe, max_iter, - model, + model_ad, q0_trans; show_progress = false, callback = callback, @@ -219,7 +220,7 @@ _, info_cfe, _ = AdvancedVI.optimize( _, info_stl, _ = AdvancedVI.optimize( stl, max_iter, - model, + model_ad, q0_trans; show_progress = false, callback = callback, @@ -302,7 +303,7 @@ nothing _, info_qmc, _ = AdvancedVI.optimize( BBVIRepGrad(AutoForwardDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)), max_iter, - model, + model_ad, q0_trans; show_progress = false, callback = callback, From 5c02cb2a498fd81648547fa59486a889a62659ac Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 7 Jul 2025 18:53:56 -0400 Subject: [PATCH 047/108] run formatter --- docs/src/examples.md | 5 ++++- ext/AdvancedVIEnzymeExt.jl | 8 ++++++-- ext/AdvancedVIMooncakeExt.jl | 6 ++++-- ext/AdvancedVIReverseDiffExt.jl | 10 +++++++--- src/algorithms/paramspacesgd/repgradelbo.jl | 7 ++++--- src/algorithms/paramspacesgd/scoregradelbo.jl | 4 ++-- src/mixed_ad_logdensity.jl | 7 +++---- src/optimize.jl | 4 +--- test/algorithms/paramspacesgd/repgradelbo.jl | 6 +----- .../paramspacesgd/repgradelbo_locationscale.jl | 4 +++- .../repgradelbo_proximal_locationscale.jl | 4 +++- 11 files changed, 38 insertions(+), 27 deletions(-) diff --git a/docs/src/examples.md b/docs/src/examples.md index d534ed02b..4f4baba12 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -52,6 +52,7 @@ model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)); ``` The algorithm we will use requires the target problem to have at least first-order differentiation capability. + ```@example elboexample using LogDensityProblemsAD @@ -117,7 +118,9 @@ Passing `objective` and the initial variational approximation `q` to `optimize` ```@example elboexample n_max_iter = 10^4 -q_out, info, _ = AdvancedVI.optimize(alg, n_max_iter, model_ad, q0_trans; show_progress=false); +q_out, info, _ = AdvancedVI.optimize( + alg, n_max_iter, model_ad, q0_trans; show_progress=false +); nothing ``` diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index af684bc8a..5b195ed9c 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -4,6 +4,10 @@ using AdvancedVI using LogDensityProblems using Enzyme -Enzyme.@import_rrule(typeof(LogDensityProblems.logdensity), AdvancedVI.MixedADLogDensityProblem, AbstractVector) +Enzyme.@import_rrule( + typeof(LogDensityProblems.logdensity), + AdvancedVI.MixedADLogDensityProblem, + AbstractVector +) -end +end diff --git a/ext/AdvancedVIMooncakeExt.jl b/ext/AdvancedVIMooncakeExt.jl index b6d4a4683..9229ec64e 100644 --- a/ext/AdvancedVIMooncakeExt.jl +++ b/ext/AdvancedVIMooncakeExt.jl @@ -4,6 +4,8 @@ using AdvancedVI using LogDensityProblems using Mooncake -Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(LogDensityProblems.logdensity), AdvancedVI.MixedADLogDensityProblem, AbstractVector} +Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{ + typeof(LogDensityProblems.logdensity),AdvancedVI.MixedADLogDensityProblem,AbstractVector +} -end +end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index 185b722eb..f29a72c29 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -4,8 +4,12 @@ using AdvancedVI using LogDensityProblems using ReverseDiff -ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity(prob::AdvancedVI.MixedADLogDensityProblem, x::ReverseDiff.TrackedArray) +ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity( + prob::AdvancedVI.MixedADLogDensityProblem, x::ReverseDiff.TrackedArray +) -ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity(prob::AdvancedVI.MixedADLogDensityProblem, x::AbstractVector) +ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity( + prob::AdvancedVI.MixedADLogDensityProblem, x::AbstractVector +) -end +end diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl index a102ad1f7..96f08e360 100644 --- a/src/algorithms/paramspacesgd/repgradelbo.jl +++ b/src/algorithms/paramspacesgd/repgradelbo.jl @@ -34,10 +34,11 @@ function init( ) where {Prob} q_stop = restructure(params) capability = LogDensityProblems.capabilities(Prob) - @assert adtype isa Union{<:AutoReverseDiff, <:AutoZygote, <:AutoMooncake, <:AutoEnzyme} + @assert adtype isa Union{<:AutoReverseDiff,<:AutoZygote,<:AutoMooncake,<:AutoEnzyme} ad_prob = if capability < LogDensityProblems.LogDensityOrder{1}() @warn "The capability of the provided log-density problem $(capability) is less than $(LogDensityProblems.LogDensityOrder{1}())" * - "Will attempt to directly differentiate through `LogDensityProblems.logdensity`." * "If this is not intended, please supply a log-density problem with cabality at least $(LogDensityProblems.LogDensityOrder{1}())" + "Will attempt to directly differentiate through `LogDensityProblems.logdensity`." * + "If this is not intended, please supply a log-density problem with cabality at least $(LogDensityProblems.LogDensityOrder{1}())" prob else MixedADLogDensityProblem(prob) @@ -143,7 +144,7 @@ function estimate_gradient!( params, restructure, state, - args... + args..., ) (; obj_ad_prep, problem) = state q_stop = restructure(params) diff --git a/src/algorithms/paramspacesgd/scoregradelbo.jl b/src/algorithms/paramspacesgd/scoregradelbo.jl index f61e654b9..88764c100 100644 --- a/src/algorithms/paramspacesgd/scoregradelbo.jl +++ b/src/algorithms/paramspacesgd/scoregradelbo.jl @@ -31,7 +31,7 @@ function init( obj_ad_prep = AdvancedVI._prepare_gradient( estimate_scoregradelbo_ad_forward, adtype, params, aux ) - return (obj_ad_prep=obj_ad_prep, problem=prob,) + return (obj_ad_prep=obj_ad_prep, problem=prob) end function Base.show(io::IO, obj::ScoreGradELBO) @@ -86,7 +86,7 @@ function AdvancedVI.estimate_gradient!( params, restructure, state, - args... + args..., ) q = restructure(params) (; obj_ad_prep, problem) = state diff --git a/src/mixed_ad_logdensity.jl b/src/mixed_ad_logdensity.jl index b6b30d0ba..0c0fed37b 100644 --- a/src/mixed_ad_logdensity.jl +++ b/src/mixed_ad_logdensity.jl @@ -37,7 +37,9 @@ function LogDensityProblems.capabilities(mixedad_prob::MixedADLogDensityProblem) end function ChainRulesCore.rrule( - ::typeof(LogDensityProblems.logdensity), mixedad_prob::MixedADLogDensityProblem, x::AbstractArray + ::typeof(LogDensityProblems.logdensity), + mixedad_prob::MixedADLogDensityProblem, + x::AbstractArray, ) ℓπ, ∇ℓπ = LogDensityProblems.logdensity_and_gradient(mixedad_prob.problem, x) function logdensity_pullback(∂y) @@ -46,6 +48,3 @@ function ChainRulesCore.rrule( end return ℓπ, logdensity_pullback end - - - diff --git a/src/optimize.jl b/src/optimize.jl index a32493275..657157b75 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -64,9 +64,7 @@ function optimize( for t in 1:max_iter info = (iteration=t,) - state, terminate, info′ = step( - rng, algorithm, state, callback, args...; kwargs... - ) + state, terminate, info′ = step(rng, algorithm, state, callback, args...; kwargs...) info = merge(info′, info) if terminate diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index 9f92b1a8d..5b6debfb7 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -7,11 +7,7 @@ AD_repgradelbo_interface = if TEST_GROUP == "Enzyme" ), ] else - [ - AutoReverseDiff(), - AutoZygote(), - AutoMooncake(; config=Mooncake.Config()), - ] + [AutoReverseDiff(), AutoZygote(), AutoMooncake(; config=Mooncake.Config())] end @testset "interface RepGradELBO" begin diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index 7bf10a30f..e32c462b1 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -69,7 +69,9 @@ end L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, model_ad, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize( + rng_repl, alg, T, model_ad, q0; show_progress=PROGRESS + ) μ_repl = q_avg.location L_repl = q_avg.scale @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index 4a5e9d45d..691636020 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -71,7 +71,9 @@ end L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, model_ad, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize( + rng_repl, alg, T, model_ad, q0; show_progress=PROGRESS + ) μ_repl = q_avg.location L_repl = q_avg.scale @test μ == μ_repl From cbd4ed86ab4db6111c7a012fe3285451978bd082 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 9 Jul 2025 13:47:15 -0400 Subject: [PATCH 048/108] revert docs dependencies, add missing dep --- docs/Project.toml | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index fdc7ed68a..81d9f5fc5 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,16 +2,11 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" @@ -22,16 +17,11 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" ADTypes = "1" AdvancedVI = "0.5" Bijectors = "0.13.6, 0.14, 0.15" -DataFrames = "1" -DelimitedFiles = "1" Distributions = "0.25" Documenter = "1" FillArrays = "1" ForwardDiff = "0.10, 1" -Functors = "0.5" -GZip = "0.6" LogDensityProblems = "2.1.1" -NormalizingFlows = "0.2" Optimisers = "0.3, 0.4" Plots = "1" QuasiMonteCarlo = "0.3" From 00f5fc4e5f50e558593610a82f11412d68af5cdf Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 9 Jul 2025 19:19:55 -0400 Subject: [PATCH 049/108] add deps to extensions --- Project.toml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index a88980f3d..cafca1212 100644 --- a/Project.toml +++ b/Project.toml @@ -26,11 +26,10 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [extensions] -AdvancedVIBijectorsExt = "Bijectors" -AdvancedVIEnzymeExt = "Enzyme" -AdvancedVIMooncakeExt = "Mooncake" -AdvancedVIReverseDiffExt = "ReverseDiff" - +AdvancedVIBijectorsExt = ["Bijectors", "Optimisers"] +AdvancedVIEnzymeExt = ["Enzyme", "ChainRulesCore"] +AdvancedVIMooncakeExt = ["Mooncake", "ChainRulesCore"] +AdvancedVIReverseDiffExt = ["ReverseDiff", "ChainRulesCore"] [compat] ADTypes = "1" @@ -55,8 +54,8 @@ julia = "1.10, 1.11.2" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] From 8612b1e475be22bbad2e50a5329f06aba3745875 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 9 Jul 2025 19:22:20 -0400 Subject: [PATCH 050/108] add missing deps in docs --- docs/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/Project.toml b/docs/Project.toml index 81d9f5fc5..3322daae5 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,6 +7,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" From 070dc202ca55b1b182b659500887309671df0f9b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 01:57:48 -0400 Subject: [PATCH 051/108] fix MooncakeExt --- ext/AdvancedVIMooncakeExt.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ext/AdvancedVIMooncakeExt.jl b/ext/AdvancedVIMooncakeExt.jl index 9229ec64e..bc817ff21 100644 --- a/ext/AdvancedVIMooncakeExt.jl +++ b/ext/AdvancedVIMooncakeExt.jl @@ -1,11 +1,17 @@ module AdvancedVIMooncakeExt using AdvancedVI +using Base: IEEEFloat using LogDensityProblems using Mooncake +# Array types explicitly tested by Mooncake +const SupportedArray{P,N} = Union{Array{P,N},AbstractGPUArray{P,N}} + Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{ - typeof(LogDensityProblems.logdensity),AdvancedVI.MixedADLogDensityProblem,AbstractVector + typeof(LogDensityProblems.logdensity), + AdvancedVI.MixedADLogDensityProblem, + SupportedArray{<:IEEEFloat,1}, } end From fd473ea8b63780d180c3583638b4fe8537df2fdd Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 22:47:58 -0400 Subject: [PATCH 052/108] restructure CI for AD integration tests --- .github/workflows/CI.yml | 8 +++++ .github/workflows/Enzyme.yml | 3 +- ext/AdvancedVIMooncakeExt.jl | 5 +-- test/algorithms/paramspacesgd/repgradelbo.jl | 25 ++++---------- .../repgradelbo_locationscale.jl | 22 ++----------- .../repgradelbo_locationscale_bijectors.jl | 22 ++----------- .../repgradelbo_proximal_locationscale.jl | 23 ++----------- ...adelbo_proximal_locationscale_bijectors.jl | 22 ++----------- .../algorithms/paramspacesgd/scoregradelbo.jl | 15 ++------- .../scoregradelbo_locationscale.jl | 23 ++----------- .../scoregradelbo_locationscale_bijectors.jl | 22 ++----------- test/general/ad.jl | 28 +++------------- .../proximal_location_scale_entropy.jl | 8 ++--- test/runtests.jl | 33 +++++++++++-------- 14 files changed, 65 insertions(+), 194 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 06219ec9f..7e359ded3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,6 +27,11 @@ jobs: - windows-latest arch: - x64 + AD: + - Enzyme + - ReverseDiff + - Zygote + - Mooncake steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -36,6 +41,9 @@ jobs: - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + GROUP: AD + AD: ${{ matrix.AD }} - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v3 with: diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/Enzyme.yml index de16a19d0..ba6f8d876 100644 --- a/.github/workflows/Enzyme.yml +++ b/.github/workflows/Enzyme.yml @@ -15,7 +15,8 @@ jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} env: - TEST_GROUP: Enzyme + GROUP: AD + AD: Enzyme runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/ext/AdvancedVIMooncakeExt.jl b/ext/AdvancedVIMooncakeExt.jl index bc817ff21..9f9ad47ff 100644 --- a/ext/AdvancedVIMooncakeExt.jl +++ b/ext/AdvancedVIMooncakeExt.jl @@ -5,13 +5,10 @@ using Base: IEEEFloat using LogDensityProblems using Mooncake -# Array types explicitly tested by Mooncake -const SupportedArray{P,N} = Union{Array{P,N},AbstractGPUArray{P,N}} - Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{ typeof(LogDensityProblems.logdensity), AdvancedVI.MixedADLogDensityProblem, - SupportedArray{<:IEEEFloat,1}, + Array{<:IEEEFloat,1}, } end diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index bc52f05d1..1cc4d332c 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -1,15 +1,4 @@ -AD_repgradelbo_interface = if TEST_GROUP == "Enzyme" - [ - AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ] -else - [AutoReverseDiff(), AutoZygote(), AutoMooncake(; config=Mooncake.Config())] -end - @testset "interface RepGradELBO" begin seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -21,10 +10,10 @@ end q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) @testset "basic" begin - @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] + @testset for n_montecarlo in [1, 10] model_ad = ADgradient(AutoForwardDiff(), model) alg = KLMinRepGradDescent( - adtype; + AD; n_samples=n_montecarlo, operator=IdentityOperator(), averager=PolynomialAveraging(), @@ -58,11 +47,9 @@ end (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats model_ad = ADgradient(AutoForwardDiff(), model) - mixed_ad = MixedADLogDensityProblem(model_ad) - - @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] - model_ad = ADgradient(AutoForwardDiff(), model) + mixed_ad = AdvancedVI.MixedADLogDensityProblem(model_ad) + @testset for n_montecarlo in [1, 10] q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) @@ -71,10 +58,10 @@ end out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) aux = ( - rng=rng, obj=obj, problem=mixed_ad, restructure=re, q_stop=q_true, adtype=adtype + rng=rng, obj=obj, problem=mixed_ad, restructure=re, q_stop=q_true, adtype=AD ) AdvancedVI._value_and_gradient!( - AdvancedVI.estimate_repgradelbo_ad_forward, out, adtype, params, aux + AdvancedVI.estimate_repgradelbo_ad_forward, out, AD, params, aux ) grad = DiffResults.gradient(out) @test norm(grad) ≈ 0 atol = 1e-5 diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index 05295d386..c0d7796f2 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -1,29 +1,13 @@ -AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" - Dict( - :Enzyme => AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ) -else - Dict( - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), - ) -end @testset "inference RepGradELBO VILocationScale" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in - [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), (objname, objective) in Dict( :RepGradELBOClosedFormEntropy => RepGradELBO(10), :RepGradELBOStickingTheLanding => RepGradELBO(10; entropy=StickingTheLandingEntropy()), - ), - (adbackname, adtype) in AD_repgradelbo_locationscale + ) seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -36,7 +20,7 @@ end T = 1000 η = 1e-3 - alg = KLMinRepGradDescent(adtype; optimizer=Descent(η)) + alg = KLMinRepGradDescent(AD; optimizer=Descent(η)) q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index 200bdd397..d185e644f 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -1,29 +1,13 @@ -AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" - Dict( - :Enzyme => AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ) -else - Dict( - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), - ) -end @testset "inference RepGradELBO VILocationScale Bijectors" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in - [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:NormalLogNormalMeanField => normallognormal_meanfield), (objname, objective) in Dict( :RepGradELBOClosedFormEntropy => RepGradELBO(10), :RepGradELBOStickingTheLanding => RepGradELBO(10; entropy=StickingTheLandingEntropy()), - ), - (adbackname, adtype) in AD_repgradelbo_locationscale_bijectors + ) seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -33,7 +17,7 @@ end T = 1000 η = 1e-3 - alg = KLMinRepGradDescent(adtype; optimizer=Descent(η)) + alg = KLMinRepGradDescent(AD; optimizer=Descent(η)) model_ad = ADgradient(AutoForwardDiff(), model) diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index 4da3e7546..22fa11906 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -1,22 +1,6 @@ -AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" - Dict( - :Enzyme => AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ) -else - Dict( - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), - ) -end - @testset "inference RepGradELBO Proximal VILocationScale" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in - [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), (objname, objective) in Dict( @@ -24,8 +8,7 @@ end RepGradELBO(10; entropy=ClosedFormEntropyZeroGradient()), :RepGradELBOStickingTheLanding => RepGradELBO(10; entropy=StickingTheLandingEntropyZeroGradient()), - ), - (adbackname, adtype) in AD_repgradelbo_locationscale + ) seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -44,7 +27,7 @@ end T = 1000 η = 1e-3 - alg = KLMinRepGradProxDescent(adtype; optimizer=Descent(η)) + alg = KLMinRepGradProxDescent(AD; optimizer=Descent(η)) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index 62e6fda09..b26e21065 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -1,21 +1,6 @@ -AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" - Dict( - :Enzyme => AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ) -else - Dict( - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), - ) -end @testset "inference RepGradELBO Proximal VILocationScale Bijectors" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in - [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:NormalLogNormalMeanField => normallognormal_meanfield), (objname, objective) in Dict( @@ -23,8 +8,7 @@ end RepGradELBO(10; entropy=ClosedFormEntropyZeroGradient()), :RepGradELBOStickingTheLanding => RepGradELBO(10; entropy=StickingTheLandingEntropyZeroGradient()), - ), - (adbackname, adtype) in AD_repgradelbo_locationscale_bijectors + ) seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -36,7 +20,7 @@ end T = 1000 η = 1e-3 - alg = KLMinRepGradProxDescent(adtype; optimizer=Descent(η)) + alg = KLMinRepGradProxDescent(AD; optimizer=Descent(η)) b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/algorithms/paramspacesgd/scoregradelbo.jl b/test/algorithms/paramspacesgd/scoregradelbo.jl index 011aea6a1..40084ec15 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo.jl @@ -1,15 +1,4 @@ -AD_scoregradelbo_interface = if TEST_GROUP == "Enzyme" - [AutoEnzyme()] -else - [ - AutoForwardDiff(), - AutoReverseDiff(), - AutoZygote(), - AutoMooncake(; config=Mooncake.Config()), - ] -end - @testset "interface ScoreGradELBO" begin seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -21,9 +10,9 @@ end q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) @testset "basic" begin - @testset for adtype in AD_scoregradelbo_interface, n_montecarlo in [1, 10] + @testset for n_montecarlo in [1, 10] alg = KLMinScoreGradDescent( - adtype; n_samples=n_montecarlo, optimizer=Descent(1e-5) + AD; n_samples=n_montecarlo, optimizer=Descent(1e-5) ) _, info, _ = optimize(rng, alg, 10, model, q0; show_progress=false) @assert isfinite(last(info).elbo) diff --git a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl index 2505a3669..cbde12318 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl @@ -1,25 +1,8 @@ -AD_scoregradelbo_locationscale = if TEST_GROUP == "Enzyme" - Dict( - :Enzyme => AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ) -else - Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), - ) -end - @testset "inference ScoreGradELBO VILocationScale" begin - @testset "$(modelname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], + @testset "$(modelname) $(realtype)" for realtype in [Float64, Float32], (modelname, modelconstr) in - Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), - (adbackname, adtype) in AD_scoregradelbo_locationscale + Dict(:Normal => normal_meanfield, :Normal => normal_fullrank) seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -30,7 +13,7 @@ end T = 1000 η = 1e-4 opt = Optimisers.Descent(η) - alg = KLMinScoreGradDescent(adtype; n_samples=10, optimizer=opt) + alg = KLMinScoreGradDescent(AD; n_samples=10, optimizer=opt) q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) diff --git a/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl index 63add9269..2e5d2691c 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl @@ -1,24 +1,8 @@ -AD_scoregradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" - Dict( - :Enzyme => AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ) -else - Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - #:Zygote => AutoZygote(), - #:Mooncake => AutoMooncake(; config=Mooncake.Config()), - ) -end @testset "inference ScoreGradELBO VILocationScale Bijectors" begin - @testset "$(modelname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], + @testset "$(modelname) $(realtype)" for realtype in [Float64, Float32], (modelname, modelconstr) in - Dict(:NormalLogNormalMeanField => normallognormal_meanfield), - (adbackname, adtype) in AD_scoregradelbo_locationscale_bijectors + Dict(:NormalLogNormalMeanField => normallognormal_meanfield) seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -29,7 +13,7 @@ end T = 1000 η = 1e-4 opt = Optimisers.Descent(η) - alg = KLMinScoreGradDescent(adtype; n_samples=10, optimizer=opt) + alg = KLMinScoreGradDescent(AD; n_samples=10, optimizer=opt) b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/general/ad.jl b/test/general/ad.jl index 080104e4b..251eead5e 100644 --- a/test/general/ad.jl +++ b/test/general/ad.jl @@ -1,38 +1,20 @@ -using Test - -AD_interface = if TEST_GROUP == "Enzyme" - Dict( - :Enzyme => AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ) -else - Dict( - :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), - ) -end - @testset "ad" begin - @testset "$(adname)" for (adname, adtype) in AD_interface + @testset "value_and_gradient!" begin D = 10 A = randn(D, D) λ = randn(D) b = randn(D) grad_buf = DiffResults.GradientResult(λ) f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′) - AdvancedVI._value_and_gradient!(f, grad_buf, adtype, λ, (b=b,)) + AdvancedVI._value_and_gradient!(f, grad_buf, AD, λ, (b=b,)) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) @test ∇ ≈ (A + A') * λ / 2 + b @test f ≈ λ' * A * λ / 2 + dot(b, λ) end - @testset "$(adname) with prep" for (adname, adtype) in AD_interface + @testset "value_and_gradient! with prep" begin D = 10 λ = randn(D) A = randn(D, D) @@ -40,10 +22,10 @@ end b_prep = randn(D) f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′) - prep = AdvancedVI._prepare_gradient(f, adtype, λ, (b=b_prep,)) + prep = AdvancedVI._prepare_gradient(f, AD, λ, (b=b_prep,)) b = randn(D) - AdvancedVI._value_and_gradient!(f, grad_buf, prep, adtype, λ, (b=b,)) + AdvancedVI._value_and_gradient!(f, grad_buf, prep, AD, λ, (b=b,)) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) diff --git a/test/general/proximal_location_scale_entropy.jl b/test/general/proximal_location_scale_entropy.jl index ddbbf7300..84e475abd 100644 --- a/test/general/proximal_location_scale_entropy.jl +++ b/test/general/proximal_location_scale_entropy.jl @@ -1,4 +1,6 @@ +using Zygote + @testset "interface ProximalLocationScaleEntropy" begin @testset "MvLocationScale" begin @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in @@ -51,10 +53,8 @@ q′ = re(params′) scale′ = isnothing(bijector) ? q′.scale : q′.dist.scale - grad_left = only(Zygote.gradient(L_ -> first(logabsdet(L_)), scale′)) - grad_right = only( - Zygote.gradient(L_ -> sum(abs2, L_ - L) / (2 * stepsize), scale′) - ) + grad_left = Zygote.gradient(L_ -> first(logabsdet(L_)), scale′ ) |> first + grad_right = Zygote.gradient(L_ -> sum(abs2, L_ - L) / (2 * stepsize), scale′) |> first @test grad_left ≈ grad_right end diff --git a/test/runtests.jl b/test/runtests.jl index f7180bfba..ae58dca34 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,11 +2,13 @@ using Test using Test: @testset, @test +using ADTypes using Base.Iterators using Bijectors using DiffResults using Distributions using FillArrays +using ForwardDiff, ReverseDiff using LinearAlgebra using LogDensityProblems, LogDensityProblemsAD using Optimisers @@ -16,18 +18,24 @@ using Random, StableRNGs using Statistics using StatsBase -using ADTypes -using ForwardDiff, ReverseDiff, Zygote, Mooncake - using AdvancedVI -const PROGRESS = haskey(ENV, "PROGRESS") -const TEST_GROUP = get(ENV, "TEST_GROUP", "All") - -if TEST_GROUP == "Enzyme" +const AD = if !haskey(ENV, "AD") || get(ENV, "AD") == "ReverseDiff" + AutoReverseDiff() +elseif get(ENV, "AD") == "Mooncake" + using Mooncake + AutoMooncake(; config=Mooncake.Config()) +elseif get(ENV, "AD") == "Zygote" + using Zygote + AutoZygote() +elseif get(ENV, "AD") == "Enzyme" using Enzyme + AutoEnzyme() end +const PROGRESS = haskey(ENV, "PROGRESS") +const GROUP = get(ENV, "GROUP", "All") + # Models for Inference Tests struct TestModel{M,L,S,SC} model::M @@ -40,8 +48,7 @@ end include("models/normal.jl") include("models/normallognormal.jl") -if TEST_GROUP == "All" || TEST_GROUP == "General" - # Interface tests that do not involve testing on Enzyme +if GROUP == "All" || GROUP == "General" include("general/optimize.jl") include("general/rules.jl") include("general/averaging.jl") @@ -49,20 +56,18 @@ if TEST_GROUP == "All" || TEST_GROUP == "General" include("general/proximal_location_scale_entropy.jl") end -if TEST_GROUP == "All" || TEST_GROUP == "General" || TEST_GROUP == "Enzyme" - # Interface tests that involve testing on Enzyme +if GROUP == "All" || GROUP == "General" || GROUP == "AD" include("general/ad.jl") end -if TEST_GROUP == "All" || TEST_GROUP == "Families" +if GROUP == "All" || GROUP == "Families" include("families/location_scale.jl") include("families/location_scale_low_rank.jl") end -if TEST_GROUP == "All" || TEST_GROUP == "ParamSpaceSGD" || TEST_GROUP == "Enzyme" +if GROUP == "All" || GROUP == "ParamSpaceSGD" || GROUP == "AD" include("algorithms/paramspacesgd/repgradelbo.jl") include("algorithms/paramspacesgd/scoregradelbo.jl") - include("algorithms/paramspacesgd/repgradelbo_distributionsad.jl") include("algorithms/paramspacesgd/repgradelbo_locationscale.jl") include("algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl") include("algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl") From d81ba5e803138f901cc2c3b020ba12d6a6e8bb0f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 22:49:06 -0400 Subject: [PATCH 053/108] run formatter --- test/algorithms/paramspacesgd/repgradelbo.jl | 4 +--- test/algorithms/paramspacesgd/scoregradelbo.jl | 4 +--- test/general/proximal_location_scale_entropy.jl | 5 +++-- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index 1cc4d332c..5704a3acc 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -57,9 +57,7 @@ end obj = RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()) out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - aux = ( - rng=rng, obj=obj, problem=mixed_ad, restructure=re, q_stop=q_true, adtype=AD - ) + aux = (rng=rng, obj=obj, problem=mixed_ad, restructure=re, q_stop=q_true, adtype=AD) AdvancedVI._value_and_gradient!( AdvancedVI.estimate_repgradelbo_ad_forward, out, AD, params, aux ) diff --git a/test/algorithms/paramspacesgd/scoregradelbo.jl b/test/algorithms/paramspacesgd/scoregradelbo.jl index 40084ec15..b9462b685 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo.jl @@ -11,9 +11,7 @@ @testset "basic" begin @testset for n_montecarlo in [1, 10] - alg = KLMinScoreGradDescent( - AD; n_samples=n_montecarlo, optimizer=Descent(1e-5) - ) + alg = KLMinScoreGradDescent(AD; n_samples=n_montecarlo, optimizer=Descent(1e-5)) _, info, _ = optimize(rng, alg, 10, model, q0; show_progress=false) @assert isfinite(last(info).elbo) end diff --git a/test/general/proximal_location_scale_entropy.jl b/test/general/proximal_location_scale_entropy.jl index 84e475abd..a4ed8c2e6 100644 --- a/test/general/proximal_location_scale_entropy.jl +++ b/test/general/proximal_location_scale_entropy.jl @@ -53,8 +53,9 @@ using Zygote q′ = re(params′) scale′ = isnothing(bijector) ? q′.scale : q′.dist.scale - grad_left = Zygote.gradient(L_ -> first(logabsdet(L_)), scale′ ) |> first - grad_right = Zygote.gradient(L_ -> sum(abs2, L_ - L) / (2 * stepsize), scale′) |> first + grad_left = first(Zygote.gradient(L_ -> first(logabsdet(L_)), scale′)) + grad_right = + first(Zygote.gradient(L_ -> sum(abs2, L_ - L) / (2 * stepsize), scale′)) @test grad_left ≈ grad_right end From 29a24b2f1d60ea8b07fcc80a6d1af1abed27ff7e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 22:57:46 -0400 Subject: [PATCH 054/108] modify CI --- .github/workflows/CI.yml | 11 +++-------- .github/workflows/Enzyme.yml | 17 +++++++++++++---- test/runtests.jl | 8 ++++---- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 7e359ded3..10a279dd1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,4 +1,4 @@ -name: CI +name: Tests on: push: branches: @@ -27,11 +27,6 @@ jobs: - windows-latest arch: - x64 - AD: - - Enzyme - - ReverseDiff - - Zygote - - Mooncake steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -42,8 +37,8 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - GROUP: AD - AD: ${{ matrix.AD }} + GROUP: All + AD: ReverseDiff - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v3 with: diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/Enzyme.yml index ba6f8d876..93f5fa0bc 100644 --- a/.github/workflows/Enzyme.yml +++ b/.github/workflows/Enzyme.yml @@ -1,4 +1,4 @@ -name: Enzyme +name: AD on: push: branches: @@ -14,9 +14,6 @@ concurrency: jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} - env: - GROUP: AD - AD: Enzyme runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -30,6 +27,11 @@ jobs: - windows-latest arch: - x64 + AD: + - Enzyme + - ReverseDiff + - Zygote + - Mooncake steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -39,3 +41,10 @@ jobs: - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + GROUP: AD + AD: ${{ matrix.AD }} + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v3 + with: + files: lcov.info diff --git a/test/runtests.jl b/test/runtests.jl index ae58dca34..f52ac8ecc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,15 +56,15 @@ if GROUP == "All" || GROUP == "General" include("general/proximal_location_scale_entropy.jl") end -if GROUP == "All" || GROUP == "General" || GROUP == "AD" - include("general/ad.jl") -end - if GROUP == "All" || GROUP == "Families" include("families/location_scale.jl") include("families/location_scale_low_rank.jl") end +if GROUP == "All" || GROUP == "General" || GROUP == "AD" + include("general/ad.jl") +end + if GROUP == "All" || GROUP == "ParamSpaceSGD" || GROUP == "AD" include("algorithms/paramspacesgd/repgradelbo.jl") include("algorithms/paramspacesgd/scoregradelbo.jl") From 8f74ebf579b59bddcac451d987d2f8582aa60306 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 23:04:22 -0400 Subject: [PATCH 055/108] fix tests --- .github/workflows/{Enzyme.yml => AD.yml} | 2 +- test/runtests.jl | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) rename .github/workflows/{Enzyme.yml => AD.yml} (90%) diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/AD.yml similarity index 90% rename from .github/workflows/Enzyme.yml rename to .github/workflows/AD.yml index 93f5fa0bc..5aade045d 100644 --- a/.github/workflows/Enzyme.yml +++ b/.github/workflows/AD.yml @@ -13,7 +13,7 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + name: Julia ${{ matrix.AD }} - ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/test/runtests.jl b/test/runtests.jl index f52ac8ecc..a7b694d31 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,15 +20,17 @@ using StatsBase using AdvancedVI -const AD = if !haskey(ENV, "AD") || get(ENV, "AD") == "ReverseDiff" +AD_str = get(ENV, "AD", "ReverseDiff") + +const AD = if AD_str == "ReverseDiff" AutoReverseDiff() -elseif get(ENV, "AD") == "Mooncake" +elseif AD_str using Mooncake AutoMooncake(; config=Mooncake.Config()) -elseif get(ENV, "AD") == "Zygote" +elseif AD_str using Zygote AutoZygote() -elseif get(ENV, "AD") == "Enzyme" +elseif AD_str using Enzyme AutoEnzyme() end From acf6c2229e796e3c23a9ad946b0ebe71169263ee Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 23:09:57 -0400 Subject: [PATCH 056/108] fix tests --- test/general/proximal_location_scale_entropy.jl | 5 +++-- test/runtests.jl | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/test/general/proximal_location_scale_entropy.jl b/test/general/proximal_location_scale_entropy.jl index a4ed8c2e6..ed74a6251 100644 --- a/test/general/proximal_location_scale_entropy.jl +++ b/test/general/proximal_location_scale_entropy.jl @@ -54,8 +54,9 @@ using Zygote scale′ = isnothing(bijector) ? q′.scale : q′.dist.scale grad_left = first(Zygote.gradient(L_ -> first(logabsdet(L_)), scale′)) - grad_right = - first(Zygote.gradient(L_ -> sum(abs2, L_ - L) / (2 * stepsize), scale′)) + grad_right = first( + Zygote.gradient(L_ -> sum(abs2, L_ - L) / (2 * stepsize), scale′) + ) @test grad_left ≈ grad_right end diff --git a/test/runtests.jl b/test/runtests.jl index a7b694d31..9e1eda1cd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,17 +20,17 @@ using StatsBase using AdvancedVI -AD_str = get(ENV, "AD", "ReverseDiff") +AD_str = get(ENV, "AD", "ReverseDiff") const AD = if AD_str == "ReverseDiff" AutoReverseDiff() -elseif AD_str +elseif AD_str == "Mooncake" using Mooncake AutoMooncake(; config=Mooncake.Config()) -elseif AD_str +elseif AD_str == "Zygote" using Zygote AutoZygote() -elseif AD_str +elseif AD_str == "Enzyme" using Enzyme AutoEnzyme() end From 050f36383aed9e987bdc2f05cc6ac53cb48e5318 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 23:13:28 -0400 Subject: [PATCH 057/108] fix name --- .github/workflows/{CI.yml => Tests.yml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename .github/workflows/{CI.yml => Tests.yml} (100%) diff --git a/.github/workflows/CI.yml b/.github/workflows/Tests.yml similarity index 100% rename from .github/workflows/CI.yml rename to .github/workflows/Tests.yml From a9c7decfde42be471a8adf53f7dcc22376ff7f2c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 23:19:31 -0400 Subject: [PATCH 058/108] fix Enzyme --- test/runtests.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 9e1eda1cd..41ee18fc2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,7 +32,10 @@ elseif AD_str == "Zygote" AutoZygote() elseif AD_str == "Enzyme" using Enzyme - AutoEnzyme() + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ) end const PROGRESS = haskey(ENV, "PROGRESS") From bab5e444b19d91cf9917664924b7da76acb1f615 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 23:19:38 -0400 Subject: [PATCH 059/108] add missing dep for benchmarks --- bench/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/bench/Project.toml b/bench/Project.toml index c3ff783ab..ab87b4881 100644 --- a/bench/Project.toml +++ b/bench/Project.toml @@ -10,6 +10,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" From 88379950cd89f5d4470137e950c8bdaff842142c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 23:21:01 -0400 Subject: [PATCH 060/108] fix only optionally load ReverseDiff --- test/runtests.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 41ee18fc2..ba04d2315 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,7 @@ using Bijectors using DiffResults using Distributions using FillArrays -using ForwardDiff, ReverseDiff +using ForwardDiff using LinearAlgebra using LogDensityProblems, LogDensityProblemsAD using Optimisers @@ -23,6 +23,7 @@ using AdvancedVI AD_str = get(ENV, "AD", "ReverseDiff") const AD = if AD_str == "ReverseDiff" + using ReverseDiff AutoReverseDiff() elseif AD_str == "Mooncake" using Mooncake From 839762dcb237ed02992352858d1d32bbefbb6a7d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 23:29:58 -0400 Subject: [PATCH 061/108] try fixing Enzyme --- bench/benchmarks.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index b8cc288bf..6a67d520c 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -48,9 +48,9 @@ begin ], (adname, adtype) in [ ("Zygote", AutoZygote()), - ("ReverseDiff", AutoReverseDiff()), ("Mooncake", AutoMooncake(; config=Mooncake.Config())), - # ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)), + ("ReverseDiff", AutoReverseDiff()), + ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)), ], (familyname, family) in [ ("meanfield", MeanFieldGaussian(zeros(T, d), Diagonal(ones(T, d)))), From 86f0af787a949c12834143acb94220c877b5ddfe Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 23:31:48 -0400 Subject: [PATCH 062/108] run formatter --- bench/benchmarks.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 6a67d520c..260f6db8f 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -50,7 +50,13 @@ begin ("Zygote", AutoZygote()), ("Mooncake", AutoMooncake(; config=Mooncake.Config())), ("ReverseDiff", AutoReverseDiff()), - ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)), + ( + "Enzyme", + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ), ], (familyname, family) in [ ("meanfield", MeanFieldGaussian(zeros(T, d), Diagonal(ones(T, d)))), From 82f930719e3e935fee5c472e126a15f0e4d9230c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 23:43:31 -0400 Subject: [PATCH 063/108] restructure move AD integration tests into separate workflow --- .../{AD.yml => AutoDiffIntegration.yml} | 23 ++++++++----------- .github/workflows/Enzyme.yml | 16 +++++++++++++ .github/workflows/Mooncake.yml | 16 +++++++++++++ .github/workflows/Zygote.yml | 16 +++++++++++++ 4 files changed, 58 insertions(+), 13 deletions(-) rename .github/workflows/{AD.yml => AutoDiffIntegration.yml} (83%) create mode 100644 .github/workflows/Enzyme.yml create mode 100644 .github/workflows/Mooncake.yml create mode 100644 .github/workflows/Zygote.yml diff --git a/.github/workflows/AD.yml b/.github/workflows/AutoDiffIntegration.yml similarity index 83% rename from .github/workflows/AD.yml rename to .github/workflows/AutoDiffIntegration.yml index 5aade045d..339ed2a0e 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AutoDiffIntegration.yml @@ -1,16 +1,17 @@ -name: AD +name: Autodiff Integration on: - push: - branches: - - main - tags: ['*'] - pull_request: - workflow_dispatch: + workflow_call: + inputs: + AD: + required: true + type: string + concurrency: # Skip intermediate builds: always. # Cancel intermediate builds: only if it is a pull request build. group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: test: name: Julia ${{ matrix.AD }} - ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} @@ -27,11 +28,7 @@ jobs: - windows-latest arch: - x64 - AD: - - Enzyme - - ReverseDiff - - Zygote - - Mooncake + steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -43,7 +40,7 @@ jobs: - uses: julia-actions/julia-runtest@v1 env: GROUP: AD - AD: ${{ matrix.AD }} + AD: ${{ inputs.AD }} - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v3 with: diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/Enzyme.yml new file mode 100644 index 000000000..d931b9be4 --- /dev/null +++ b/.github/workflows/Enzyme.yml @@ -0,0 +1,16 @@ +name: Enzyme Integration +on: + push: + branches: + - main + tags: ['*'] + pull_request: + workflow_dispatch: + +jobs: + test: + uses: ./.github/workflows/AutoDiffIntegration.yml + with: + AD: Enzyme + secrets: + personal_access_token: ${{ secrets.token }} diff --git a/.github/workflows/Mooncake.yml b/.github/workflows/Mooncake.yml new file mode 100644 index 000000000..67348c9d0 --- /dev/null +++ b/.github/workflows/Mooncake.yml @@ -0,0 +1,16 @@ +name: Mooncake Integration +on: + push: + branches: + - main + tags: ['*'] + pull_request: + workflow_dispatch: + +jobs: + test: + uses: ./.github/workflows/AutoDiffIntegration.yml + with: + AD: Mooncake + secrets: + personal_access_token: ${{ secrets.token }} diff --git a/.github/workflows/Zygote.yml b/.github/workflows/Zygote.yml new file mode 100644 index 000000000..a18975797 --- /dev/null +++ b/.github/workflows/Zygote.yml @@ -0,0 +1,16 @@ +name: Zygote Integration +on: + push: + branches: + - main + tags: ['*'] + pull_request: + workflow_dispatch: + +jobs: + test: + uses: ./.github/workflows/AutoDiffIntegration.yml + with: + AD: Zygote + secrets: + personal_access_token: ${{ secrets.token }} From 7757a9b694221648a914b87a2c9f7038538e2b50 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 23:47:26 -0400 Subject: [PATCH 064/108] fix try fixing AD integration with Mooncake --- .github/workflows/Mooncake.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/Mooncake.yml b/.github/workflows/Mooncake.yml index 67348c9d0..7f95f0350 100644 --- a/.github/workflows/Mooncake.yml +++ b/.github/workflows/Mooncake.yml @@ -7,6 +7,12 @@ on: pull_request: workflow_dispatch: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: test: uses: ./.github/workflows/AutoDiffIntegration.yml From 75633f6e7353cc1f681fd189df4fb09c5b304833 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 12 Jul 2025 23:49:33 -0400 Subject: [PATCH 065/108] fix remove unused code --- .github/workflows/Enzyme.yml | 2 -- .github/workflows/Mooncake.yml | 8 -------- .github/workflows/Zygote.yml | 2 -- 3 files changed, 12 deletions(-) diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/Enzyme.yml index d931b9be4..cf6555f30 100644 --- a/.github/workflows/Enzyme.yml +++ b/.github/workflows/Enzyme.yml @@ -12,5 +12,3 @@ jobs: uses: ./.github/workflows/AutoDiffIntegration.yml with: AD: Enzyme - secrets: - personal_access_token: ${{ secrets.token }} diff --git a/.github/workflows/Mooncake.yml b/.github/workflows/Mooncake.yml index 7f95f0350..9de098606 100644 --- a/.github/workflows/Mooncake.yml +++ b/.github/workflows/Mooncake.yml @@ -7,16 +7,8 @@ on: pull_request: workflow_dispatch: -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} - jobs: test: uses: ./.github/workflows/AutoDiffIntegration.yml with: AD: Mooncake - secrets: - personal_access_token: ${{ secrets.token }} diff --git a/.github/workflows/Zygote.yml b/.github/workflows/Zygote.yml index a18975797..59cfc9d3e 100644 --- a/.github/workflows/Zygote.yml +++ b/.github/workflows/Zygote.yml @@ -12,5 +12,3 @@ jobs: uses: ./.github/workflows/AutoDiffIntegration.yml with: AD: Zygote - secrets: - personal_access_token: ${{ secrets.token }} From c56e6ad0cb902e36d0b82c8c84b0321723834d55 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 13 Jul 2025 00:57:45 -0400 Subject: [PATCH 066/108] change name for source of MixedADLogDensity --- src/AdvancedVI.jl | 2 +- src/{mixed_ad_logdensity.jl => mixedad_logdensity.jl} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/{mixed_ad_logdensity.jl => mixedad_logdensity.jl} (100%) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 718153584..9ad4120d8 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -96,7 +96,7 @@ This is an indirection for handling the type stability of `restructure`, as some """ restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params) -include("mixed_ad_logdensity.jl") +include("mixedad_logdensity.jl") # Variational Families export MvLocationScale, MeanFieldGaussian, FullRankGaussian diff --git a/src/mixed_ad_logdensity.jl b/src/mixedad_logdensity.jl similarity index 100% rename from src/mixed_ad_logdensity.jl rename to src/mixedad_logdensity.jl From 4fecea5985e7cbb9e24dc2f6c802e1799b79d160 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 13 Jul 2025 01:21:35 -0400 Subject: [PATCH 067/108] add tests for MixedADLogDensityProblem --- src/mixedad_logdensity.jl | 2 +- test/general/mixedad_logdensity.jl | 44 ++++++++++++++++++++++++++++++ test/runtests.jl | 9 +++--- 3 files changed, 50 insertions(+), 5 deletions(-) create mode 100644 test/general/mixedad_logdensity.jl diff --git a/src/mixedad_logdensity.jl b/src/mixedad_logdensity.jl index 0c0fed37b..38ed131bf 100644 --- a/src/mixedad_logdensity.jl +++ b/src/mixedad_logdensity.jl @@ -29,7 +29,7 @@ end function LogDensityProblems.logdensity_gradient_and_hessian( mixedad_prob::MixedADLogDensityProblem, x::AbstractArray ) - return LogDensityProblems.logdensity_and_gradient(mixedad_prob.problem, x) + return LogDensityProblems.logdensity_gradient_and_hessian(mixedad_prob.problem, x) end function LogDensityProblems.capabilities(mixedad_prob::MixedADLogDensityProblem) diff --git a/test/general/mixedad_logdensity.jl b/test/general/mixedad_logdensity.jl new file mode 100644 index 000000000..c26658075 --- /dev/null +++ b/test/general/mixedad_logdensity.jl @@ -0,0 +1,44 @@ + +struct MixedADTestModel end + +function LogDensityProblems.logdensity(::MixedADTestModel, θ) + return Float64(ℯ) +end + +function LogDensityProblems.dimension(::MixedADTestModel) + return 3 +end + +function LogDensityProblems.capabilities(::Type{<:MixedADTestModel}) + return LogDensityProblems.LogDensityOrder{2}() +end + +function LogDensityProblems.logdensity_and_gradient(::MixedADTestModel, θ) + return (Float64(ℯ), [1.0, 2.0, 3.0]) +end + +function LogDensityProblems.logdensity_gradient_and_hessian(::MixedADTestModel, θ) + return (Float64(ℯ), [1.0, 2.0, 3.0], [1.0 1.0 1.0; 2.0 2.0 2.0; 3.0 3.0 3.0]) +end + +@testset "interface MixedADLogDensityProblem" begin + model = MixedADTestModel() + model_ad = AdvancedVI.MixedADLogDensityProblem(model) + + d = 3 + x = randn(d) + + @test LogDensityProblems.dimension(model) == LogDensityProblems.dimension(model_ad) + @test last(LogDensityProblems.logdensity(model, x)) ≈ + last(LogDensityProblems.logdensity(model_ad, x)) + @test last(LogDensityProblems.logdensity_and_gradient(model, x)) ≈ + last(LogDensityProblems.logdensity_and_gradient(model_ad, x)) + @test last(LogDensityProblems.logdensity_gradient_and_hessian(model, x)) ≈ + last(LogDensityProblems.logdensity_gradient_and_hessian(model_ad, x)) + + f(x, aux) = LogDensityProblems.logdensity(model_ad, x) + + buf = DiffResults.DiffResult(zero(Float64), zeros(d)) + AdvancedVI._value_and_gradient!(f, buf, AD, x, nothing) + @test DiffResults.gradient(buf) ≈ [1, 2, 3] +end diff --git a/test/runtests.jl b/test/runtests.jl index ba04d2315..7d1f74636 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,15 +62,16 @@ if GROUP == "All" || GROUP == "General" include("general/proximal_location_scale_entropy.jl") end +if GROUP == "All" || GROUP == "General" || GROUP == "AD" + include("general/ad.jl") + include("general/mixedad_logdensity.jl") +end + if GROUP == "All" || GROUP == "Families" include("families/location_scale.jl") include("families/location_scale_low_rank.jl") end -if GROUP == "All" || GROUP == "General" || GROUP == "AD" - include("general/ad.jl") -end - if GROUP == "All" || GROUP == "ParamSpaceSGD" || GROUP == "AD" include("algorithms/paramspacesgd/repgradelbo.jl") include("algorithms/paramspacesgd/scoregradelbo.jl") From 499203c4875612562d4ca2ec0cee80e86bc3abf3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 13 Jul 2025 01:35:56 -0400 Subject: [PATCH 068/108] fix renamed jobs in integration tests --- .github/workflows/Enzyme.yml | 2 +- .github/workflows/Mooncake.yml | 2 +- .github/workflows/Zygote.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/Enzyme.yml index cf6555f30..a0c5b0a84 100644 --- a/.github/workflows/Enzyme.yml +++ b/.github/workflows/Enzyme.yml @@ -8,7 +8,7 @@ on: workflow_dispatch: jobs: - test: + Enzyme: uses: ./.github/workflows/AutoDiffIntegration.yml with: AD: Enzyme diff --git a/.github/workflows/Mooncake.yml b/.github/workflows/Mooncake.yml index 9de098606..134169cfe 100644 --- a/.github/workflows/Mooncake.yml +++ b/.github/workflows/Mooncake.yml @@ -8,7 +8,7 @@ on: workflow_dispatch: jobs: - test: + Mooncake: uses: ./.github/workflows/AutoDiffIntegration.yml with: AD: Mooncake diff --git a/.github/workflows/Zygote.yml b/.github/workflows/Zygote.yml index 59cfc9d3e..fb09c7c25 100644 --- a/.github/workflows/Zygote.yml +++ b/.github/workflows/Zygote.yml @@ -8,7 +8,7 @@ on: workflow_dispatch: jobs: - test: + Zygote: uses: ./.github/workflows/AutoDiffIntegration.yml with: AD: Zygote From 9608db9b7b8d44db405bc3588498dccb3d66b4a9 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 13 Jul 2025 01:36:02 -0400 Subject: [PATCH 069/108] fix --- test/general/mixedad_logdensity.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/general/mixedad_logdensity.jl b/test/general/mixedad_logdensity.jl index c26658075..2316952a1 100644 --- a/test/general/mixedad_logdensity.jl +++ b/test/general/mixedad_logdensity.jl @@ -21,12 +21,16 @@ function LogDensityProblems.logdensity_gradient_and_hessian(::MixedADTestModel, return (Float64(ℯ), [1.0, 2.0, 3.0], [1.0 1.0 1.0; 2.0 2.0 2.0; 3.0 3.0 3.0]) end +function mixedad_prob_test_ad_fwd(x, aux) + return LogDensityProblems.logdensity(aux.model, x) +end + @testset "interface MixedADLogDensityProblem" begin model = MixedADTestModel() model_ad = AdvancedVI.MixedADLogDensityProblem(model) d = 3 - x = randn(d) + x = ones(Float64, d) @test LogDensityProblems.dimension(model) == LogDensityProblems.dimension(model_ad) @test last(LogDensityProblems.logdensity(model, x)) ≈ @@ -36,9 +40,7 @@ end @test last(LogDensityProblems.logdensity_gradient_and_hessian(model, x)) ≈ last(LogDensityProblems.logdensity_gradient_and_hessian(model_ad, x)) - f(x, aux) = LogDensityProblems.logdensity(model_ad, x) - buf = DiffResults.DiffResult(zero(Float64), zeros(d)) - AdvancedVI._value_and_gradient!(f, buf, AD, x, nothing) + AdvancedVI._value_and_gradient!(mixedad_prob_test_ad_fwd, buf, AD, x, (model=model_ad,)) @test DiffResults.gradient(buf) ≈ [1, 2, 3] end From 43a662509f4ac754f77b5606e7eea9343cbfb45e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 13 Jul 2025 02:13:48 -0400 Subject: [PATCH 070/108] add test for without mixed ad --- test/algorithms/paramspacesgd/repgradelbo.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index 5704a3acc..26ff3436e 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -23,6 +23,19 @@ end end + @testset "without mixed ad" begin + @testset for n_montecarlo in [1, 10] + alg = KLMinRepGradDescent( + AD; + n_samples=n_montecarlo, + operator=IdentityOperator(), + averager=PolynomialAveraging(), + ) + _, info, _ = optimize(rng, alg, 10, model, q0; show_progress=false) + @assert isfinite(last(info).elbo) + end + end + obj = RepGradELBO(10) rng = StableRNG(seed) elbo_ref = estimate_objective(rng, obj, q0, model; n_samples=10^4) From 8cc6058d35e421f3eee7f300dbae01a6ddaf4aa3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 13 Jul 2025 02:18:46 -0400 Subject: [PATCH 071/108] fix test for MixedADLogDensityProblem --- test/general/mixedad_logdensity.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/general/mixedad_logdensity.jl b/test/general/mixedad_logdensity.jl index 2316952a1..1c69db6e9 100644 --- a/test/general/mixedad_logdensity.jl +++ b/test/general/mixedad_logdensity.jl @@ -21,8 +21,12 @@ function LogDensityProblems.logdensity_gradient_and_hessian(::MixedADTestModel, return (Float64(ℯ), [1.0, 2.0, 3.0], [1.0 1.0 1.0; 2.0 2.0 2.0; 3.0 3.0 3.0]) end -function mixedad_prob_test_ad_fwd(x, aux) - return LogDensityProblems.logdensity(aux.model, x) +@eval AdvancedVI begin + using LogDensityProblems + + function mixedad_prob_test_ad_fwd(x, aux) + return LogDensityProblems.logdensity(aux.model, x) + end end @testset "interface MixedADLogDensityProblem" begin @@ -41,6 +45,8 @@ end last(LogDensityProblems.logdensity_gradient_and_hessian(model_ad, x)) buf = DiffResults.DiffResult(zero(Float64), zeros(d)) - AdvancedVI._value_and_gradient!(mixedad_prob_test_ad_fwd, buf, AD, x, (model=model_ad,)) + AdvancedVI._value_and_gradient!( + AdvancedVI.mixedad_prob_test_ad_fwd, buf, AD, x, (model=model_ad,) + ) @test DiffResults.gradient(buf) ≈ [1, 2, 3] end From f841eed7ee4dd5aaf6ea7aece34086c9c1908fee Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 13 Jul 2025 02:27:48 -0400 Subject: [PATCH 072/108] remove test --- test/general/mixedad_logdensity.jl | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/test/general/mixedad_logdensity.jl b/test/general/mixedad_logdensity.jl index 1c69db6e9..308cbfa7e 100644 --- a/test/general/mixedad_logdensity.jl +++ b/test/general/mixedad_logdensity.jl @@ -21,14 +21,6 @@ function LogDensityProblems.logdensity_gradient_and_hessian(::MixedADTestModel, return (Float64(ℯ), [1.0, 2.0, 3.0], [1.0 1.0 1.0; 2.0 2.0 2.0; 3.0 3.0 3.0]) end -@eval AdvancedVI begin - using LogDensityProblems - - function mixedad_prob_test_ad_fwd(x, aux) - return LogDensityProblems.logdensity(aux.model, x) - end -end - @testset "interface MixedADLogDensityProblem" begin model = MixedADTestModel() model_ad = AdvancedVI.MixedADLogDensityProblem(model) @@ -43,10 +35,4 @@ end last(LogDensityProblems.logdensity_and_gradient(model_ad, x)) @test last(LogDensityProblems.logdensity_gradient_and_hessian(model, x)) ≈ last(LogDensityProblems.logdensity_gradient_and_hessian(model_ad, x)) - - buf = DiffResults.DiffResult(zero(Float64), zeros(d)) - AdvancedVI._value_and_gradient!( - AdvancedVI.mixedad_prob_test_ad_fwd, buf, AD, x, (model=model_ad,) - ) - @test DiffResults.gradient(buf) ≈ [1, 2, 3] end From fd6a0aca7e6be4dc269a8b4213ace9791d280d1a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 13 Jul 2025 02:48:56 -0400 Subject: [PATCH 073/108] update docs --- docs/src/examples.md | 41 ++++++++++++++------------- docs/src/paramspacesgd/repgradelbo.md | 2 +- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/docs/src/examples.md b/docs/src/examples.md index 07c04a462..7c9350645 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -48,10 +48,12 @@ n_dims = 10 μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)); - +nothing ``` -The algorithm we will use requires the target problem to have at least first-order differentiation capability. +Some of the VI algorithms require gradients of the target log-density. +In this example, we will use `KLMinRepGradDescent`, which requires first-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/stable/#LogDensityProblems.capabilities). +For this, we can rely on `LogDensityProblemsAD`: ```@example elboexample using LogDensityProblemsAD @@ -61,7 +63,23 @@ model_ad = ADgradient(AutoReverseDiff(), model) nothing ``` -Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. +Let's now load `AdvancedVI`. +In addition to gradients of the target log-density, `KLMinRepGradDescent` internally uses automatic differentiation. +Therefore, we have to select an AD framework to be used within `KLMinRepGradDescent`. +(This does not need to be the same as the backend used by `model_ad`.) +The selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface. +Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`. + +```@example elboexample +using Optimisers +using AdvancedVI + +alg = KLMinRepGradDescent(AutoReverseDiff()); +nothing +``` + +Now, `KLMinRepGradDescent` requires the variational approximation and the target log-density to have the same support. +Since `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation. ```@example elboexample @@ -80,23 +98,6 @@ binv = inverse(b) nothing ``` -Let's now load `AdvancedVI`. -Since BBVI relies on automatic differentiation (AD), we need to load an AD library, *before* loading `AdvancedVI`. -Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface. -Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`. - -```@example elboexample -using Optimisers -using AdvancedVI -``` - -We now need to select 1. a variational objective, and 2. a variational family. -Here, we will use the [`RepGradELBO` objective](@ref repgradelbo), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. - -```@example elboexample -alg = KLMinRepGradDescent(AutoReverseDiff()) -``` - For the variational family, we will use the classic mean-field Gaussian family. ```@example elboexample diff --git a/docs/src/paramspacesgd/repgradelbo.md b/docs/src/paramspacesgd/repgradelbo.md index e59535dfb..97ee29444 100644 --- a/docs/src/paramspacesgd/repgradelbo.md +++ b/docs/src/paramspacesgd/repgradelbo.md @@ -231,7 +231,7 @@ _, info_stl, _ = AdvancedVI.optimize( _, info_stl, _ = AdvancedVI.optimize( stl, max_iter, - model, + model_ad, q0_trans; show_progress = false, callback = callback, From 1accbf97166aaef3521b170efcfa998a66df84c2 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 13 Jul 2025 02:49:01 -0400 Subject: [PATCH 074/108] refactor test --- test/algorithms/paramspacesgd/repgradelbo.jl | 6 +++--- test/runtests.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index 26ff3436e..e8f690c5e 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -8,10 +8,10 @@ (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) + model_ad = ADgradient(AutoForwardDiff(), model) @testset "basic" begin @testset for n_montecarlo in [1, 10] - model_ad = ADgradient(AutoForwardDiff(), model) alg = KLMinRepGradDescent( AD; n_samples=n_montecarlo, @@ -19,7 +19,7 @@ averager=PolynomialAveraging(), ) _, info, _ = optimize(rng, alg, 10, model_ad, q0; show_progress=false) - @assert isfinite(last(info).elbo) + @test isfinite(last(info).elbo) end end @@ -32,7 +32,7 @@ averager=PolynomialAveraging(), ) _, info, _ = optimize(rng, alg, 10, model, q0; show_progress=false) - @assert isfinite(last(info).elbo) + @test isfinite(last(info).elbo) end end diff --git a/test/runtests.jl b/test/runtests.jl index 7d1f74636..1d9babef6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,11 +60,11 @@ if GROUP == "All" || GROUP == "General" include("general/averaging.jl") include("general/clip_scale.jl") include("general/proximal_location_scale_entropy.jl") + include("general/mixedad_logdensity.jl") end if GROUP == "All" || GROUP == "General" || GROUP == "AD" include("general/ad.jl") - include("general/mixedad_logdensity.jl") end if GROUP == "All" || GROUP == "Families" From 30c8e1574d07dca5dab964246892200996f7d84f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 13 Jul 2025 23:46:08 -0400 Subject: [PATCH 075/108] add missing test --- test/general/mixedad_logdensity.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/general/mixedad_logdensity.jl b/test/general/mixedad_logdensity.jl index 308cbfa7e..ace2c6b59 100644 --- a/test/general/mixedad_logdensity.jl +++ b/test/general/mixedad_logdensity.jl @@ -29,6 +29,8 @@ end x = ones(Float64, d) @test LogDensityProblems.dimension(model) == LogDensityProblems.dimension(model_ad) + @test LogDensityProblems.capabilities(typeof(model)) ≈ + LogDensityProblems.capabilities(typeof(model_ad)) @test last(LogDensityProblems.logdensity(model, x)) ≈ last(LogDensityProblems.logdensity(model_ad, x)) @test last(LogDensityProblems.logdensity_and_gradient(model, x)) ≈ From b17d66179087787be43f309a4abc033c394055ad Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:09:23 -0400 Subject: [PATCH 076/108] revert interface changes to paramspacesgd --- src/algorithms/paramspacesgd/paramspacesgd.jl | 4 ++-- src/algorithms/paramspacesgd/repgradelbo.jl | 7 ++++--- src/algorithms/paramspacesgd/scoregradelbo.jl | 10 +++++----- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index 10819cf1d..f70ac574b 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -81,7 +81,7 @@ function output(alg::ParamSpaceSGD, state) end function step( - rng::Random.AbstractRNG, alg::ParamSpaceSGD, state, callback, args...; kwargs... + rng::Random.AbstractRNG, alg::ParamSpaceSGD, state, callback, objargs...; kwargs... ) (; adtype, objective, operator, averager) = alg (; prob, q, iteration, grad_buf, opt_st, obj_st, avg_st) = state @@ -91,7 +91,7 @@ function step( params, re = Optimisers.destructure(q) grad_buf, obj_st, info = estimate_gradient!( - rng, objective, adtype, grad_buf, params, re, obj_st, args... + rng, objective, adtype, grad_buf, prob, params, re, obj_st, objargs... ) grad = DiffResults.gradient(grad_buf) diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl index 96f08e360..9f6c084b0 100644 --- a/src/algorithms/paramspacesgd/repgradelbo.jl +++ b/src/algorithms/paramspacesgd/repgradelbo.jl @@ -141,23 +141,24 @@ function estimate_gradient!( obj::RepGradELBO, adtype::ADTypes.AbstractADType, out::DiffResults.MutableDiffResult, + prob, params, restructure, state, args..., ) - (; obj_ad_prep, problem) = state + prep = state q_stop = restructure(params) aux = ( rng=rng, adtype=adtype, obj=obj, - problem=problem, + problem=prob, restructure=restructure, q_stop=q_stop, ) AdvancedVI._value_and_gradient!( - estimate_repgradelbo_ad_forward, out, obj_ad_prep, adtype, params, aux + estimate_repgradelbo_ad_forward, out, prep, adtype, params, aux ) nelbo = DiffResults.value(out) stat = (elbo=(-nelbo),) diff --git a/src/algorithms/paramspacesgd/scoregradelbo.jl b/src/algorithms/paramspacesgd/scoregradelbo.jl index 88764c100..7a2f0b42f 100644 --- a/src/algorithms/paramspacesgd/scoregradelbo.jl +++ b/src/algorithms/paramspacesgd/scoregradelbo.jl @@ -28,10 +28,9 @@ function init( samples = rand(rng, q, obj.n_samples) ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure) - obj_ad_prep = AdvancedVI._prepare_gradient( + return AdvancedVI._prepare_gradient( estimate_scoregradelbo_ad_forward, adtype, params, aux ) - return (obj_ad_prep=obj_ad_prep, problem=prob) end function Base.show(io::IO, obj::ScoreGradELBO) @@ -83,18 +82,19 @@ function AdvancedVI.estimate_gradient!( obj::ScoreGradELBO, adtype::ADTypes.AbstractADType, out::DiffResults.MutableDiffResult, + prob, params, restructure, state, args..., ) q = restructure(params) - (; obj_ad_prep, problem) = state + prep = state samples = rand(rng, q, obj.n_samples) - ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, problem), eachsample(samples)) + ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure) AdvancedVI._value_and_gradient!( - estimate_scoregradelbo_ad_forward, out, obj_ad_prep, adtype, params, aux + estimate_scoregradelbo_ad_forward, out, prep, adtype, params, aux ) ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples)) elbo = mean(ℓπ - ℓq) From 408f41db23a6f4a3fdd62fb0fd9e93ef1dcc8569 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:16:53 -0400 Subject: [PATCH 077/108] remove dependency on LogDensityProblemsAD --- test/Project.toml | 2 -- test/models/normal.jl | 7 +++++++ test/models/normallognormal.jl | 9 ++++++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 90a0c6815..db7d9068a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,7 +10,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" @@ -36,7 +35,6 @@ ForwardDiff = "0.10.36, 1" Functors = "0.4.5, 0.5" LinearAlgebra = "1" LogDensityProblems = "2.1.1" -LogDensityProblemsAD = "1" Mooncake = "0.4" Optimisers = "0.2.16, 0.3, 0.4" PDMats = "0.11.7" diff --git a/test/models/normal.jl b/test/models/normal.jl index 5826547df..a966d467a 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -9,6 +9,13 @@ function LogDensityProblems.logdensity(model::TestNormal, θ) return logpdf(MvNormal(μ, Σ), θ) end +function LogDensityProblems.logdensity_and_gradient(model::TestNormal, θ) + return ( + LogDensityProblems.logdensity(model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity), model, θ) + ) +end + function LogDensityProblems.dimension(model::TestNormal) return length(model.μ) end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 00949bc1b..7069e019e 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -11,12 +11,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ) return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end +function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ) + return ( + LogDensityProblems.logdensity(model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity), model, θ) + ) +end + function LogDensityProblems.dimension(model::NormalLogNormal) return length(model.μ_y) + 1 end function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - return LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{1}() end function Bijectors.bijector(model::NormalLogNormal) From 5f9cac9e1f81b31d31bc75f809c5b0210ed15373 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:17:05 -0400 Subject: [PATCH 078/108] revert calls to `ADgradient` in tests --- test/algorithms/paramspacesgd/repgradelbo.jl | 7 +++---- .../algorithms/paramspacesgd/repgradelbo_locationscale.jl | 8 +++----- .../paramspacesgd/repgradelbo_locationscale_bijectors.jl | 8 +++----- .../paramspacesgd/repgradelbo_proximal_locationscale.jl | 8 +++----- .../repgradelbo_proximal_locationscale_bijectors.jl | 8 +++----- 5 files changed, 15 insertions(+), 24 deletions(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index e8f690c5e..3a935e50c 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -8,7 +8,6 @@ (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) - model_ad = ADgradient(AutoForwardDiff(), model) @testset "basic" begin @testset for n_montecarlo in [1, 10] @@ -18,7 +17,7 @@ operator=IdentityOperator(), averager=PolynomialAveraging(), ) - _, info, _ = optimize(rng, alg, 10, model_ad, q0; show_progress=false) + _, info, _ = optimize(rng, alg, 10, model, q0; show_progress=false) @test isfinite(last(info).elbo) end end @@ -59,8 +58,8 @@ end modelstats = normal_meanfield(rng, Float64) (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats - model_ad = ADgradient(AutoForwardDiff(), model) - mixed_ad = AdvancedVI.MixedADLogDensityProblem(model_ad) + model = ADgradient(AutoForwardDiff(), model) + mixed_ad = AdvancedVI.MixedADLogDensityProblem(model) @testset for n_montecarlo in [1, 10] q_true = MeanFieldGaussian( diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index c0d7796f2..bdf96a789 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -15,8 +15,6 @@ modelstats = modelconstr(rng, realtype) (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats - model_ad = ADgradient(AutoForwardDiff(), model) - T = 1000 η = 1e-3 @@ -36,7 +34,7 @@ @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) - q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale @@ -49,14 +47,14 @@ @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale rng_repl = StableRNG(seed) q_avg, stats, _ = optimize( - rng_repl, alg, T, model_ad, q0; show_progress=PROGRESS + rng_repl, alg, T, model, q0; show_progress=PROGRESS ) μ_repl = q_avg.location L_repl = q_avg.scale diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index d185e644f..99589ad93 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -19,8 +19,6 @@ η = 1e-3 alg = KLMinRepGradDescent(AD; optimizer=Descent(η)) - model_ad = ADgradient(AutoForwardDiff(), model) - b = Bijectors.bijector(model) b⁻¹ = inverse(b) μ0 = Zeros(realtype, n_dims) @@ -41,7 +39,7 @@ @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale @@ -54,14 +52,14 @@ @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale rng_repl = StableRNG(seed) q_avg, stats, _ = optimize( - rng_repl, alg, T, model_ad, q0_z; show_progress=PROGRESS + rng_repl, alg, T, model, q0_z; show_progress=PROGRESS ) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index 22fa11906..2ab089f5e 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -16,8 +16,6 @@ modelstats = modelconstr(rng, realtype) (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats - model_ad = ADgradient(AutoForwardDiff(), model) - q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) else @@ -36,7 +34,7 @@ @testset "convergence" begin Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) - q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale @@ -49,13 +47,13 @@ @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) μ = q_avg.location L = q_avg.scale rng_repl = StableRNG(seed) q_avg, stats, _ = optimize( - rng_repl, alg, T, model_ad, q0; show_progress=PROGRESS + rng_repl, alg, T, model, q0; show_progress=PROGRESS ) μ_repl = q_avg.location L_repl = q_avg.scale diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index b26e21065..76856d2b8 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -16,8 +16,6 @@ modelstats = modelconstr(rng, realtype) (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats - model_ad = ADgradient(AutoForwardDiff(), model) - T = 1000 η = 1e-3 alg = KLMinRepGradProxDescent(AD; optimizer=Descent(η)) @@ -42,7 +40,7 @@ @testset "convergence" begin Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale @@ -55,13 +53,13 @@ @testset "determinism" begin rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model_ad, q0_z; show_progress=PROGRESS) + q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) μ = q_avg.dist.location L = q_avg.dist.scale rng_repl = StableRNG(seed) q_avg, stats, _ = optimize( - rng_repl, alg, T, model_ad, q0_z; show_progress=PROGRESS + rng_repl, alg, T, model, q0_z; show_progress=PROGRESS ) μ_repl = q_avg.dist.location L_repl = q_avg.dist.scale From 46b1f917df9e05bd88c9f7ec4a1d1bc202f2f5fa Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:17:23 -0400 Subject: [PATCH 079/108] move Zygote import --- test/general/optimize.jl | 11 +++++------ test/general/proximal_location_scale_entropy.jl | 6 ++---- test/runtests.jl | 3 +-- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/test/general/optimize.jl b/test/general/optimize.jl index 1ca924d93..ca5ef4506 100644 --- a/test/general/optimize.jl +++ b/test/general/optimize.jl @@ -10,7 +10,6 @@ q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims))) obj = RepGradELBO(10) - model_ad = ADgradient(AutoForwardDiff(), model) adtype = AutoReverseDiff() optimizer = Optimisers.Adam(1e-2) averager = PolynomialAveraging() @@ -18,7 +17,7 @@ alg = ParamSpaceSGD(obj, adtype, optimizer, averager, IdentityOperator()) @testset "default_rng" begin - optimize(alg, T, model_ad, q0; show_progress=false) + optimize(alg, T, model, q0; show_progress=false) end @testset "callback" begin @@ -26,12 +25,12 @@ callback(; iteration, args...) = (test_value=test_values[iteration],) - _, info, _ = optimize(alg, T, model_ad, q0; show_progress=false, callback) + _, info, _ = optimize(alg, T, model, q0; show_progress=false, callback) @test [i.test_value for i in info] == test_values end rng = StableRNG(seed) - q_avg_ref, _, _ = optimize(rng, alg, T, model_ad, q0; show_progress=false) + q_avg_ref, _, _ = optimize(rng, alg, T, model, q0; show_progress=false) @testset "warm start" begin rng = StableRNG(seed) @@ -39,8 +38,8 @@ T_first = div(T, 2) T_last = T - T_first - _, _, state = optimize(rng, alg, T_first, model_ad, q0; show_progress=false) - q_avg, _, _ = optimize(rng, alg, T_last, model_ad, q0; show_progress=false, state) + _, _, state = optimize(rng, alg, T_first, model, q0; show_progress=false) + q_avg, _, _ = optimize(rng, alg, T_last, model, q0; show_progress=false, state) @test q_avg == q_avg_ref end diff --git a/test/general/proximal_location_scale_entropy.jl b/test/general/proximal_location_scale_entropy.jl index ed74a6251..ddbbf7300 100644 --- a/test/general/proximal_location_scale_entropy.jl +++ b/test/general/proximal_location_scale_entropy.jl @@ -1,6 +1,4 @@ -using Zygote - @testset "interface ProximalLocationScaleEntropy" begin @testset "MvLocationScale" begin @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in @@ -53,8 +51,8 @@ using Zygote q′ = re(params′) scale′ = isnothing(bijector) ? q′.scale : q′.dist.scale - grad_left = first(Zygote.gradient(L_ -> first(logabsdet(L_)), scale′)) - grad_right = first( + grad_left = only(Zygote.gradient(L_ -> first(logabsdet(L_)), scale′)) + grad_right = only( Zygote.gradient(L_ -> sum(abs2, L_ - L) / (2 * stepsize), scale′) ) diff --git a/test/runtests.jl b/test/runtests.jl index 1d9babef6..28854e5c4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,7 @@ using Bijectors using DiffResults using Distributions using FillArrays -using ForwardDiff +using ForwardDiff, Zygote using LinearAlgebra using LogDensityProblems, LogDensityProblemsAD using Optimisers @@ -29,7 +29,6 @@ elseif AD_str == "Mooncake" using Mooncake AutoMooncake(; config=Mooncake.Config()) elseif AD_str == "Zygote" - using Zygote AutoZygote() elseif AD_str == "Enzyme" using Enzyme From 907a8d49bd96ea2fee75185fca60d103beffbc52 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:24:05 -0400 Subject: [PATCH 080/108] revert remaining changes in tests --- test/algorithms/paramspacesgd/repgradelbo_locationscale.jl | 6 +----- .../paramspacesgd/repgradelbo_locationscale_bijectors.jl | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index bdf96a789..b10430dac 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -17,7 +17,6 @@ T = 1000 η = 1e-3 - alg = KLMinRepGradDescent(AD; optimizer=Descent(η)) q0 = if is_meanfield @@ -48,14 +47,11 @@ @testset "determinism" begin rng = StableRNG(seed) q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) - μ = q_avg.location L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize( - rng_repl, alg, T, model, q0; show_progress=PROGRESS - ) + q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) μ_repl = q_avg.location L_repl = q_avg.scale @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index 99589ad93..4657ea2f5 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -53,7 +53,6 @@ @testset "determinism" begin rng = StableRNG(seed) q_avg, stats, _ = optimize(rng, alg, T, model, q0_z; show_progress=PROGRESS) - μ = q_avg.dist.location L = q_avg.dist.scale From 68c35bbb997834719ff1c3f27d93632e244e84c9 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:26:28 -0400 Subject: [PATCH 081/108] Revert "restructure move AD integration tests into separate workflow" This reverts commit 82f930719e3e935fee5c472e126a15f0e4d9230c. --- .../{AutoDiffIntegration.yml => AD.yml} | 23 +++++++++++-------- .github/workflows/Enzyme.yml | 14 ----------- .github/workflows/Mooncake.yml | 14 ----------- .github/workflows/Zygote.yml | 14 ----------- 4 files changed, 13 insertions(+), 52 deletions(-) rename .github/workflows/{AutoDiffIntegration.yml => AD.yml} (83%) delete mode 100644 .github/workflows/Enzyme.yml delete mode 100644 .github/workflows/Mooncake.yml delete mode 100644 .github/workflows/Zygote.yml diff --git a/.github/workflows/AutoDiffIntegration.yml b/.github/workflows/AD.yml similarity index 83% rename from .github/workflows/AutoDiffIntegration.yml rename to .github/workflows/AD.yml index 339ed2a0e..5aade045d 100644 --- a/.github/workflows/AutoDiffIntegration.yml +++ b/.github/workflows/AD.yml @@ -1,17 +1,16 @@ -name: Autodiff Integration +name: AD on: - workflow_call: - inputs: - AD: - required: true - type: string - + push: + branches: + - main + tags: ['*'] + pull_request: + workflow_dispatch: concurrency: # Skip intermediate builds: always. # Cancel intermediate builds: only if it is a pull request build. group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} - jobs: test: name: Julia ${{ matrix.AD }} - ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} @@ -28,7 +27,11 @@ jobs: - windows-latest arch: - x64 - + AD: + - Enzyme + - ReverseDiff + - Zygote + - Mooncake steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -40,7 +43,7 @@ jobs: - uses: julia-actions/julia-runtest@v1 env: GROUP: AD - AD: ${{ inputs.AD }} + AD: ${{ matrix.AD }} - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v3 with: diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/Enzyme.yml deleted file mode 100644 index a0c5b0a84..000000000 --- a/.github/workflows/Enzyme.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: Enzyme Integration -on: - push: - branches: - - main - tags: ['*'] - pull_request: - workflow_dispatch: - -jobs: - Enzyme: - uses: ./.github/workflows/AutoDiffIntegration.yml - with: - AD: Enzyme diff --git a/.github/workflows/Mooncake.yml b/.github/workflows/Mooncake.yml deleted file mode 100644 index 134169cfe..000000000 --- a/.github/workflows/Mooncake.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: Mooncake Integration -on: - push: - branches: - - main - tags: ['*'] - pull_request: - workflow_dispatch: - -jobs: - Mooncake: - uses: ./.github/workflows/AutoDiffIntegration.yml - with: - AD: Mooncake diff --git a/.github/workflows/Zygote.yml b/.github/workflows/Zygote.yml deleted file mode 100644 index fb09c7c25..000000000 --- a/.github/workflows/Zygote.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: Zygote Integration -on: - push: - branches: - - main - tags: ['*'] - pull_request: - workflow_dispatch: - -jobs: - Zygote: - uses: ./.github/workflows/AutoDiffIntegration.yml - with: - AD: Zygote From d84b273ac3c3075cc0a683b38afe7721a2b50a1b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:28:30 -0400 Subject: [PATCH 082/108] fix revert renaming of CI.yml --- .github/workflows/{AD.yml => CI.yml} | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) rename .github/workflows/{AD.yml => CI.yml} (77%) diff --git a/.github/workflows/AD.yml b/.github/workflows/CI.yml similarity index 77% rename from .github/workflows/AD.yml rename to .github/workflows/CI.yml index 5aade045d..06219ec9f 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/CI.yml @@ -1,4 +1,4 @@ -name: AD +name: CI on: push: branches: @@ -13,7 +13,7 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: test: - name: Julia ${{ matrix.AD }} - ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -27,11 +27,6 @@ jobs: - windows-latest arch: - x64 - AD: - - Enzyme - - ReverseDiff - - Zygote - - Mooncake steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -41,9 +36,6 @@ jobs: - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - env: - GROUP: AD - AD: ${{ matrix.AD }} - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v3 with: From 52dd8058cfc740027f964f90759d720ae16cf795 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:29:47 -0400 Subject: [PATCH 083/108] Revert "fix name" This reverts commit 050f36383aed9e987bdc2f05cc6ac53cb48e5318. --- .github/workflows/Tests.yml | 45 ------------------------------------- 1 file changed, 45 deletions(-) delete mode 100644 .github/workflows/Tests.yml diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml deleted file mode 100644 index 10a279dd1..000000000 --- a/.github/workflows/Tests.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: Tests -on: - push: - branches: - - main - tags: ['*'] - pull_request: - workflow_dispatch: -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - version: - - 'lts' - - '1' - os: - - ubuntu-latest - - macOS-latest - - windows-latest - arch: - - x64 - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - - uses: julia-actions/cache@v1 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - env: - GROUP: All - AD: ReverseDiff - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v3 - with: - files: lcov.info From 8b911f4ebca1d404bc0e4397b924bd91dd9e7967 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:31:02 -0400 Subject: [PATCH 084/108] fix revert Enzyme.yml --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a04819f16..4a51ba183 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ The purpose of this package is to provide a common accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration. For example, integrating `Turing` with `AdvancedVI.ADVI` only involves converting a `Turing.Model` into a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) and extracting a corresponding `Bijectors.bijector`. -## Examples +## Example `AdvancedVI` works with differentiable models specified as a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl). For example, for the normal-log-normal model: From a2177ea6d9d5cdc6b9f12fc65af59131ca5a4cde Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:31:30 -0400 Subject: [PATCH 085/108] fix add back Enzyme.yml --- .github/workflows/Enzyme.yml | 40 ++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 .github/workflows/Enzyme.yml diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/Enzyme.yml new file mode 100644 index 000000000..de16a19d0 --- /dev/null +++ b/.github/workflows/Enzyme.yml @@ -0,0 +1,40 @@ +name: Enzyme +on: + push: + branches: + - main + tags: ['*'] + pull_request: + workflow_dispatch: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + env: + TEST_GROUP: Enzyme + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - 'lts' + - '1' + os: + - ubuntu-latest + - macOS-latest + - windows-latest + arch: + - x64 + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: julia-actions/cache@v1 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 From 019aa47308e2ac9278e2ab08e6b77ba306fdbf31 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:33:50 -0400 Subject: [PATCH 086/108] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/models/normal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/models/normal.jl b/test/models/normal.jl index a966d467a..74abf29c8 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -12,7 +12,7 @@ end function LogDensityProblems.logdensity_and_gradient(model::TestNormal, θ) return ( LogDensityProblems.logdensity(model, θ), - ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity), model, θ) + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity), model, θ), ) end From 3a0b093d0d2ba22b709536f4f47c08689b2ce222 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:34:13 -0400 Subject: [PATCH 087/108] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/models/normallognormal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 7069e019e..209febb3f 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -14,7 +14,7 @@ end function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ) return ( LogDensityProblems.logdensity(model, θ), - ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity), model, θ) + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity), model, θ), ) end From 6776d55e761aae38d3fabfcda55bed1cd3040a2b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 14:42:56 -0400 Subject: [PATCH 088/108] revert non-essential changes to tests --- .../repgradelbo_locationscale.jl | 23 ++++++++++-- .../repgradelbo_locationscale_bijectors.jl | 23 ++++++++++-- .../repgradelbo_proximal_locationscale.jl | 28 +++++++++++--- ...adelbo_proximal_locationscale_bijectors.jl | 23 ++++++++++-- .../algorithms/paramspacesgd/scoregradelbo.jl | 17 ++++++++- .../scoregradelbo_locationscale.jl | 22 +++++++++-- .../scoregradelbo_locationscale_bijectors.jl | 22 +++++++++-- test/general/ad.jl | 26 ++++++++++--- test/general/optimize.jl | 3 +- test/runtests.jl | 37 +++++++------------ 10 files changed, 171 insertions(+), 53 deletions(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index b10430dac..64345cd8d 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -1,13 +1,30 @@ +AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) +end @testset "inference RepGradELBO VILocationScale" begin - @testset "$(modelname) $(objname) $(realtype)" for realtype in [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), (objname, objective) in Dict( :RepGradELBOClosedFormEntropy => RepGradELBO(10), :RepGradELBOStickingTheLanding => RepGradELBO(10; entropy=StickingTheLandingEntropy()), - ) + ), + (adbackname, adtype) in AD_repgradelbo_locationscale seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -17,7 +34,7 @@ T = 1000 η = 1e-3 - alg = KLMinRepGradDescent(AD; optimizer=Descent(η)) + alg = KLMinRepGradDescent(adtype; optimizer=Descent(η)) q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index 4657ea2f5..fe2c131b5 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -1,13 +1,30 @@ +AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) +end @testset "inference RepGradELBO VILocationScale Bijectors" begin - @testset "$(modelname) $(objname) $(realtype)" for realtype in [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], (modelname, modelconstr) in Dict(:NormalLogNormalMeanField => normallognormal_meanfield), (objname, objective) in Dict( :RepGradELBOClosedFormEntropy => RepGradELBO(10), :RepGradELBOStickingTheLanding => RepGradELBO(10; entropy=StickingTheLandingEntropy()), - ) + ), + (adbackname, adtype) in AD_repgradelbo_locationscale_bijectors seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -17,7 +34,7 @@ T = 1000 η = 1e-3 - alg = KLMinRepGradDescent(AD; optimizer=Descent(η)) + alg = KLMinRepGradDescent(adtype; optimizer=Descent(η)) b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index 2ab089f5e..c04445faa 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -1,6 +1,23 @@ +AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) +end + @testset "inference RepGradELBO Proximal VILocationScale" begin - @testset "$(modelname) $(objname) $(realtype)" for realtype in [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), (objname, objective) in Dict( @@ -8,7 +25,8 @@ RepGradELBO(10; entropy=ClosedFormEntropyZeroGradient()), :RepGradELBOStickingTheLanding => RepGradELBO(10; entropy=StickingTheLandingEntropyZeroGradient()), - ) + ), + (adbackname, adtype) in AD_repgradelbo_locationscale seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -25,7 +43,7 @@ T = 1000 η = 1e-3 - alg = KLMinRepGradProxDescent(AD; optimizer=Descent(η)) + alg = KLMinRepGradProxDescent(adtype; optimizer=Descent(η)) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), @@ -52,9 +70,7 @@ L = q_avg.scale rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize( - rng_repl, alg, T, model, q0; show_progress=PROGRESS - ) + q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) μ_repl = q_avg.location L_repl = q_avg.scale @test μ == μ_repl diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index 76856d2b8..2d69d3af5 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -1,6 +1,22 @@ +AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) +end @testset "inference RepGradELBO Proximal VILocationScale Bijectors" begin - @testset "$(modelname) $(objname) $(realtype)" for realtype in [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], (modelname, modelconstr) in Dict(:NormalLogNormalMeanField => normallognormal_meanfield), (objname, objective) in Dict( @@ -8,7 +24,8 @@ RepGradELBO(10; entropy=ClosedFormEntropyZeroGradient()), :RepGradELBOStickingTheLanding => RepGradELBO(10; entropy=StickingTheLandingEntropyZeroGradient()), - ) + ), + (adbackname, adtype) in AD_repgradelbo_locationscale_bijectors seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -18,7 +35,7 @@ T = 1000 η = 1e-3 - alg = KLMinRepGradProxDescent(AD; optimizer=Descent(η)) + alg = KLMinRepGradProxDescent(adtype; optimizer=Descent(η)) b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/algorithms/paramspacesgd/scoregradelbo.jl b/test/algorithms/paramspacesgd/scoregradelbo.jl index b9462b685..011aea6a1 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo.jl @@ -1,4 +1,15 @@ +AD_scoregradelbo_interface = if TEST_GROUP == "Enzyme" + [AutoEnzyme()] +else + [ + AutoForwardDiff(), + AutoReverseDiff(), + AutoZygote(), + AutoMooncake(; config=Mooncake.Config()), + ] +end + @testset "interface ScoreGradELBO" begin seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -10,8 +21,10 @@ q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) @testset "basic" begin - @testset for n_montecarlo in [1, 10] - alg = KLMinScoreGradDescent(AD; n_samples=n_montecarlo, optimizer=Descent(1e-5)) + @testset for adtype in AD_scoregradelbo_interface, n_montecarlo in [1, 10] + alg = KLMinScoreGradDescent( + adtype; n_samples=n_montecarlo, optimizer=Descent(1e-5) + ) _, info, _ = optimize(rng, alg, 10, model, q0; show_progress=false) @assert isfinite(last(info).elbo) end diff --git a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl index cbde12318..d4941587c 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl @@ -1,8 +1,24 @@ +AD_scoregradelbo_locationscale = if TEST_GROUP == "Enzyme" + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) +end @testset "inference ScoreGradELBO VILocationScale" begin - @testset "$(modelname) $(realtype)" for realtype in [Float64, Float32], + @testset "$(modelname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in - Dict(:Normal => normal_meanfield, :Normal => normal_fullrank) + Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), + (adbackname, adtype) in AD_scoregradelbo_locationscale seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -13,7 +29,7 @@ T = 1000 η = 1e-4 opt = Optimisers.Descent(η) - alg = KLMinScoreGradDescent(AD; n_samples=10, optimizer=opt) + alg = KLMinScoreGradDescent(adtype; n_samples=10, optimizer=opt) q0 = if is_meanfield MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) diff --git a/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl index 2e5d2691c..63add9269 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl @@ -1,8 +1,24 @@ +AD_scoregradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + #:Zygote => AutoZygote(), + #:Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) +end @testset "inference ScoreGradELBO VILocationScale Bijectors" begin - @testset "$(modelname) $(realtype)" for realtype in [Float64, Float32], + @testset "$(modelname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in - Dict(:NormalLogNormalMeanField => normallognormal_meanfield) + Dict(:NormalLogNormalMeanField => normallognormal_meanfield), + (adbackname, adtype) in AD_scoregradelbo_locationscale_bijectors seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -13,7 +29,7 @@ T = 1000 η = 1e-4 opt = Optimisers.Descent(η) - alg = KLMinScoreGradDescent(AD; n_samples=10, optimizer=opt) + alg = KLMinScoreGradDescent(adtype; n_samples=10, optimizer=opt) b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/general/ad.jl b/test/general/ad.jl index 251eead5e..ac9a82bfd 100644 --- a/test/general/ad.jl +++ b/test/general/ad.jl @@ -1,20 +1,36 @@ +AD_interface = if TEST_GROUP == "Enzyme" + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) +end + @testset "ad" begin - @testset "value_and_gradient!" begin + @testset "$(adname)" for (adname, adtype) in AD_interface D = 10 A = randn(D, D) λ = randn(D) b = randn(D) grad_buf = DiffResults.GradientResult(λ) f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′) - AdvancedVI._value_and_gradient!(f, grad_buf, AD, λ, (b=b,)) + AdvancedVI._value_and_gradient!(f, grad_buf, adtype, λ, (b=b,)) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) @test ∇ ≈ (A + A') * λ / 2 + b @test f ≈ λ' * A * λ / 2 + dot(b, λ) end - @testset "value_and_gradient! with prep" begin + @testset "$(adname) with prep" for (adname, adtype) in AD_interface D = 10 λ = randn(D) A = randn(D, D) @@ -22,10 +38,10 @@ b_prep = randn(D) f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′) - prep = AdvancedVI._prepare_gradient(f, AD, λ, (b=b_prep,)) + prep = AdvancedVI._prepare_gradient(f, adtype, λ, (b=b_prep,)) b = randn(D) - AdvancedVI._value_and_gradient!(f, grad_buf, prep, AD, λ, (b=b,)) + AdvancedVI._value_and_gradient!(f, grad_buf, prep, adtype, λ, (b=b,)) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) diff --git a/test/general/optimize.jl b/test/general/optimize.jl index ca5ef4506..6882b74da 100644 --- a/test/general/optimize.jl +++ b/test/general/optimize.jl @@ -10,7 +10,8 @@ q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims))) obj = RepGradELBO(10) - adtype = AutoReverseDiff() + + adtype = AutoForwardDiff() optimizer = Optimisers.Adam(1e-2) averager = PolynomialAveraging() diff --git a/test/runtests.jl b/test/runtests.jl index 28854e5c4..e72e00af8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,15 +2,13 @@ using Test using Test: @testset, @test -using ADTypes using Base.Iterators using Bijectors using DiffResults using Distributions using FillArrays -using ForwardDiff, Zygote using LinearAlgebra -using LogDensityProblems, LogDensityProblemsAD +using LogDensityProblems using Optimisers using PDMats using Pkg @@ -18,29 +16,18 @@ using Random, StableRNGs using Statistics using StatsBase +using ADTypes +using ForwardDiff, ReverseDiff, Zygote, Mooncake + using AdvancedVI -AD_str = get(ENV, "AD", "ReverseDiff") +const PROGRESS = haskey(ENV, "PROGRESS") +const TEST_GROUP = get(ENV, "TEST_GROUP", "All") -const AD = if AD_str == "ReverseDiff" - using ReverseDiff - AutoReverseDiff() -elseif AD_str == "Mooncake" - using Mooncake - AutoMooncake(; config=Mooncake.Config()) -elseif AD_str == "Zygote" - AutoZygote() -elseif AD_str == "Enzyme" +if TEST_GROUP == "Enzyme" using Enzyme - AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ) end -const PROGRESS = haskey(ENV, "PROGRESS") -const GROUP = get(ENV, "GROUP", "All") - # Models for Inference Tests struct TestModel{M,L,S,SC} model::M @@ -53,7 +40,8 @@ end include("models/normal.jl") include("models/normallognormal.jl") -if GROUP == "All" || GROUP == "General" +if TEST_GROUP == "All" || TEST_GROUP == "General" + # Interface tests that do not involve testing on Enzyme include("general/optimize.jl") include("general/rules.jl") include("general/averaging.jl") @@ -62,16 +50,17 @@ if GROUP == "All" || GROUP == "General" include("general/mixedad_logdensity.jl") end -if GROUP == "All" || GROUP == "General" || GROUP == "AD" +if TEST_GROUP == "All" || TEST_GROUP == "General" || TEST_GROUP == "Enzyme" + # Interface tests that involve testing on Enzyme include("general/ad.jl") end -if GROUP == "All" || GROUP == "Families" +if TEST_GROUP == "All" || TEST_GROUP == "Families" include("families/location_scale.jl") include("families/location_scale_low_rank.jl") end -if GROUP == "All" || GROUP == "ParamSpaceSGD" || GROUP == "AD" +if TEST_GROUP == "All" || TEST_GROUP == "ParamSpaceSGD" || TEST_GROUP == "Enzyme" include("algorithms/paramspacesgd/repgradelbo.jl") include("algorithms/paramspacesgd/scoregradelbo.jl") include("algorithms/paramspacesgd/repgradelbo_locationscale.jl") From 2adf75c84d9c89f1ee09c2947d636a1c7203817f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 15:08:26 -0400 Subject: [PATCH 089/108] fix remaining changes in tests --- test/algorithms/paramspacesgd/repgradelbo.jl | 29 +++++++++++++++---- .../scoregradelbo_locationscale.jl | 1 + test/general/mixedad_logdensity.jl | 2 +- test/general/optimize.jl | 2 +- 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index 3a935e50c..61449741f 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -1,4 +1,19 @@ +AD_repgradelbo_interface = if TEST_GROUP == "Enzyme" + [ + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ] +else + [ + AutoReverseDiff(), + AutoZygote(), + AutoMooncake(; config=Mooncake.Config()), + ] +end + @testset "interface RepGradELBO" begin seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) @@ -10,9 +25,9 @@ q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) @testset "basic" begin - @testset for n_montecarlo in [1, 10] + @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] alg = KLMinRepGradDescent( - AD; + adtype; n_samples=n_montecarlo, operator=IdentityOperator(), averager=PolynomialAveraging(), @@ -25,7 +40,7 @@ @testset "without mixed ad" begin @testset for n_montecarlo in [1, 10] alg = KLMinRepGradDescent( - AD; + adtype; n_samples=n_montecarlo, operator=IdentityOperator(), averager=PolynomialAveraging(), @@ -61,7 +76,7 @@ end model = ADgradient(AutoForwardDiff(), model) mixed_ad = AdvancedVI.MixedADLogDensityProblem(model) - @testset for n_montecarlo in [1, 10] + @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) @@ -69,9 +84,11 @@ end obj = RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()) out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - aux = (rng=rng, obj=obj, problem=mixed_ad, restructure=re, q_stop=q_true, adtype=AD) + aux = ( + rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=adtype + ) AdvancedVI._value_and_gradient!( - AdvancedVI.estimate_repgradelbo_ad_forward, out, AD, params, aux + AdvancedVI.estimate_repgradelbo_ad_forward, out, adtype, params, aux ) grad = DiffResults.gradient(out) @test norm(grad) ≈ 0 atol = 1e-5 diff --git a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl index d4941587c..2505a3669 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl @@ -1,3 +1,4 @@ + AD_scoregradelbo_locationscale = if TEST_GROUP == "Enzyme" Dict( :Enzyme => AutoEnzyme(; diff --git a/test/general/mixedad_logdensity.jl b/test/general/mixedad_logdensity.jl index ace2c6b59..1fb65697d 100644 --- a/test/general/mixedad_logdensity.jl +++ b/test/general/mixedad_logdensity.jl @@ -29,7 +29,7 @@ end x = ones(Float64, d) @test LogDensityProblems.dimension(model) == LogDensityProblems.dimension(model_ad) - @test LogDensityProblems.capabilities(typeof(model)) ≈ + @test LogDensityProblems.capabilities(typeof(model)) == LogDensityProblems.capabilities(typeof(model_ad)) @test last(LogDensityProblems.logdensity(model, x)) ≈ last(LogDensityProblems.logdensity(model_ad, x)) diff --git a/test/general/optimize.jl b/test/general/optimize.jl index 6882b74da..5849e2bc2 100644 --- a/test/general/optimize.jl +++ b/test/general/optimize.jl @@ -11,7 +11,7 @@ q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims))) obj = RepGradELBO(10) - adtype = AutoForwardDiff() + adtype = AutoReverseDiff() optimizer = Optimisers.Adam(1e-2) averager = PolynomialAveraging() From 19547abb78c77019f3e7c4bdfdf33cebaf24eff2 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 15:08:47 -0400 Subject: [PATCH 090/108] fix revert necessary change to paramspacesgd interface --- src/algorithms/paramspacesgd/abstractobjective.jl | 2 +- src/algorithms/paramspacesgd/paramspacesgd.jl | 2 +- src/algorithms/paramspacesgd/repgradelbo.jl | 7 +++---- src/algorithms/paramspacesgd/scoregradelbo.jl | 11 +++++------ src/mixedad_logdensity.jl | 4 ++-- src/optimize.jl | 6 ++++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/algorithms/paramspacesgd/abstractobjective.jl b/src/algorithms/paramspacesgd/abstractobjective.jl index 18759bfca..e395319bd 100644 --- a/src/algorithms/paramspacesgd/abstractobjective.jl +++ b/src/algorithms/paramspacesgd/abstractobjective.jl @@ -59,7 +59,7 @@ function estimate_objective end export estimate_objective """ - estimate_gradient!(rng, obj, adtype, out, params, restructure, obj_state, args...) + estimate_gradient!(rng, obj, adtype, out, params, restructure, obj_state) Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index f70ac574b..ef4116e78 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -91,7 +91,7 @@ function step( params, re = Optimisers.destructure(q) grad_buf, obj_st, info = estimate_gradient!( - rng, objective, adtype, grad_buf, prob, params, re, obj_st, objargs... + rng, objective, adtype, grad_buf, params, re, obj_st, objargs... ) grad = DiffResults.gradient(grad_buf) diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl index 9f6c084b0..96f08e360 100644 --- a/src/algorithms/paramspacesgd/repgradelbo.jl +++ b/src/algorithms/paramspacesgd/repgradelbo.jl @@ -141,24 +141,23 @@ function estimate_gradient!( obj::RepGradELBO, adtype::ADTypes.AbstractADType, out::DiffResults.MutableDiffResult, - prob, params, restructure, state, args..., ) - prep = state + (; obj_ad_prep, problem) = state q_stop = restructure(params) aux = ( rng=rng, adtype=adtype, obj=obj, - problem=prob, + problem=problem, restructure=restructure, q_stop=q_stop, ) AdvancedVI._value_and_gradient!( - estimate_repgradelbo_ad_forward, out, prep, adtype, params, aux + estimate_repgradelbo_ad_forward, out, obj_ad_prep, adtype, params, aux ) nelbo = DiffResults.value(out) stat = (elbo=(-nelbo),) diff --git a/src/algorithms/paramspacesgd/scoregradelbo.jl b/src/algorithms/paramspacesgd/scoregradelbo.jl index 7a2f0b42f..3ce787ba3 100644 --- a/src/algorithms/paramspacesgd/scoregradelbo.jl +++ b/src/algorithms/paramspacesgd/scoregradelbo.jl @@ -28,9 +28,10 @@ function init( samples = rand(rng, q, obj.n_samples) ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure) - return AdvancedVI._prepare_gradient( + obj_ad_prep = AdvancedVI._prepare_gradient( estimate_scoregradelbo_ad_forward, adtype, params, aux ) + return (obj_ad_prep=obj_ad_prep, problem=prob) end function Base.show(io::IO, obj::ScoreGradELBO) @@ -82,19 +83,17 @@ function AdvancedVI.estimate_gradient!( obj::ScoreGradELBO, adtype::ADTypes.AbstractADType, out::DiffResults.MutableDiffResult, - prob, params, restructure, state, - args..., ) q = restructure(params) - prep = state + (; obj_ad_prep, problem) = state samples = rand(rng, q, obj.n_samples) - ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) + ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, problem), eachsample(samples)) aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure) AdvancedVI._value_and_gradient!( - estimate_scoregradelbo_ad_forward, out, prep, adtype, params, aux + estimate_scoregradelbo_ad_forward, out, obj_ad_prep, adtype, params, aux ) ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples)) elbo = mean(ℓπ - ℓq) diff --git a/src/mixedad_logdensity.jl b/src/mixedad_logdensity.jl index 38ed131bf..722229ebe 100644 --- a/src/mixedad_logdensity.jl +++ b/src/mixedad_logdensity.jl @@ -32,8 +32,8 @@ function LogDensityProblems.logdensity_gradient_and_hessian( return LogDensityProblems.logdensity_gradient_and_hessian(mixedad_prob.problem, x) end -function LogDensityProblems.capabilities(mixedad_prob::MixedADLogDensityProblem) - return LogDensityProblems.capabilities(mixedad_prob.problem) +function LogDensityProblems.capabilities(::Type{MixedADLogDensityProblem{Prob}}) where {Prob} + return LogDensityProblems.capabilities(Prob) end function ChainRulesCore.rrule( diff --git a/src/optimize.jl b/src/optimize.jl index 657157b75..e538a0838 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -45,7 +45,7 @@ function optimize( max_iter::Int, prob, q_init, - args...; + objargs...; show_progress::Bool=true, state::Union{<:Any,Nothing}=nothing, callback=nothing, @@ -64,7 +64,9 @@ function optimize( for t in 1:max_iter info = (iteration=t,) - state, terminate, info′ = step(rng, algorithm, state, callback, args...; kwargs...) + state, terminate, info′ = step( + rng, algorithm, state, callback, objargs...; kwargs... + ) info = merge(info′, info) if terminate From 2e4396d55e172adcd8615daf0d3ac8d914424ace Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 15:26:37 -0400 Subject: [PATCH 091/108] fix errors in tests --- test/algorithms/paramspacesgd/repgradelbo.jl | 2 +- test/algorithms/paramspacesgd/repgradelbo_locationscale.jl | 1 - .../paramspacesgd/repgradelbo_locationscale_bijectors.jl | 1 - .../paramspacesgd/repgradelbo_proximal_locationscale.jl | 1 - .../repgradelbo_proximal_locationscale_bijectors.jl | 1 - test/models/normal.jl | 4 ++-- test/models/normallognormal.jl | 2 +- 7 files changed, 4 insertions(+), 8 deletions(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index 61449741f..dee53c2b6 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -38,7 +38,7 @@ end end @testset "without mixed ad" begin - @testset for n_montecarlo in [1, 10] + @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] alg = KLMinRepGradDescent( adtype; n_samples=n_montecarlo, diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index 64345cd8d..3f7d4f114 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -7,7 +7,6 @@ AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index fe2c131b5..33995ee84 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -7,7 +7,6 @@ AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index c04445faa..624a292f7 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -8,7 +8,6 @@ AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index 2d69d3af5..dcd9722a8 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -7,7 +7,6 @@ AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), diff --git a/test/models/normal.jl b/test/models/normal.jl index 74abf29c8..3db1af08d 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -12,7 +12,7 @@ end function LogDensityProblems.logdensity_and_gradient(model::TestNormal, θ) return ( LogDensityProblems.logdensity(model, θ), - ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity), model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), ) end @@ -21,7 +21,7 @@ function LogDensityProblems.dimension(model::TestNormal) end function LogDensityProblems.capabilities(::Type{<:TestNormal}) - return LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{1}() end function normal_fullrank(rng::Random.AbstractRNG, realtype::Type) diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 209febb3f..7c9f29c10 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -14,7 +14,7 @@ end function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ) return ( LogDensityProblems.logdensity(model, θ), - ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity), model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), ) end From b8298a3d771356d6080483eb9b52ffe7ef14d462 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 15:26:47 -0400 Subject: [PATCH 092/108] remove LogDensityProblemsAD dependency in benchmark --- bench/Project.toml | 1 - bench/benchmarks.jl | 3 +-- bench/normallognormal.jl | 9 ++++++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/bench/Project.toml b/bench/Project.toml index ab87b4881..c3ff783ab 100644 --- a/bench/Project.toml +++ b/bench/Project.toml @@ -10,7 +10,6 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 260f6db8f..f5f42b9b0 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -40,7 +40,6 @@ begin max_iter = 10^4 d = LogDensityProblems.dimension(prob) opt = Optimisers.Adam(T(1e-3)) - prob_ad = ADgradient(AutoForwardDiff(), prob) for (objname, entropy) in [ ("RepGradELBO", ClosedFormEntropy()), @@ -73,7 +72,7 @@ begin SUITES[probname][objname][familyname][adname] = begin @benchmarkable AdvancedVI.optimize( - $alg, $max_iter, $prob_ad, $q; show_progress=false + $alg, $max_iter, $prob, $q; show_progress=false ) end end diff --git a/bench/normallognormal.jl b/bench/normallognormal.jl index cb6592b71..4cfc9af1a 100644 --- a/bench/normallognormal.jl +++ b/bench/normallognormal.jl @@ -12,12 +12,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ) return log_density_x + log_density_y end +function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ) + return ( + LogDensityProblems.logdensity(model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), + ) +end + function LogDensityProblems.dimension(model::NormalLogNormal) return length(model.μ_y) + 1 end function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - return LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{1}() end function Bijectors.bijector(model::NormalLogNormal) From b7d98229404655074b754da57e7ab9381973ac1b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 15:28:00 -0400 Subject: [PATCH 093/108] fix remove unused import in benchmark --- bench/benchmarks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index f5f42b9b0..8ac139518 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -8,7 +8,7 @@ using Enzyme, ForwardDiff, ReverseDiff, Zygote, Mooncake using FillArrays using InteractiveUtils using LinearAlgebra -using LogDensityProblems, LogDensityProblemsAD +using LogDensityProblems using Optimisers using Random From 27a2b835e3a89738fd53173b6f18aa33565bfd56 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 15:28:42 -0400 Subject: [PATCH 094/108] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/mixedad_logdensity.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mixedad_logdensity.jl b/src/mixedad_logdensity.jl index 722229ebe..451541ffc 100644 --- a/src/mixedad_logdensity.jl +++ b/src/mixedad_logdensity.jl @@ -32,7 +32,9 @@ function LogDensityProblems.logdensity_gradient_and_hessian( return LogDensityProblems.logdensity_gradient_and_hessian(mixedad_prob.problem, x) end -function LogDensityProblems.capabilities(::Type{MixedADLogDensityProblem{Prob}}) where {Prob} +function LogDensityProblems.capabilities( + ::Type{MixedADLogDensityProblem{Prob}} +) where {Prob} return LogDensityProblems.capabilities(Prob) end From e73b678a2f2bfdac058fb1c28e121280d4a7a962 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 15:31:15 -0400 Subject: [PATCH 095/108] add missing compats --- Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index cafca1212..5d31be04b 100644 --- a/Project.toml +++ b/Project.toml @@ -40,13 +40,16 @@ DiffResults = "1" DifferentiationInterface = "0.6, 0.7" Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" +Enzyme = "0.13" FillArrays = "1.3" Functors = "0.4, 0.5" LinearAlgebra = "1" LogDensityProblems = "2" +Mooncake = "0.4" Optimisers = "0.2.16, 0.3, 0.4" ProgressMeter = "1.6" Random = "1" +ReverseDiff = "1" StatsBase = "0.32, 0.33, 0.34" julia = "1.10, 1.11.2" From 45408b119b8d9273424acf7b83062b3d815e39d3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Jul 2025 15:32:12 -0400 Subject: [PATCH 096/108] fix revert changes to README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4a51ba183..a04819f16 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ The purpose of this package is to provide a common accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration. For example, integrating `Turing` with `AdvancedVI.ADVI` only involves converting a `Turing.Model` into a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) and extracting a corresponding `Bijectors.bijector`. -## Example +## Examples `AdvancedVI` works with differentiable models specified as a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl). For example, for the normal-log-normal model: From 1fc1c49db7d949444d02f80918d72ff1a192c238 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 15:08:34 -0400 Subject: [PATCH 097/108] revert change of the test order --- test/algorithms/paramspacesgd/repgradelbo.jl | 2 +- test/algorithms/paramspacesgd/repgradelbo_locationscale.jl | 2 +- .../paramspacesgd/repgradelbo_locationscale_bijectors.jl | 2 +- .../paramspacesgd/repgradelbo_proximal_locationscale.jl | 2 +- .../repgradelbo_proximal_locationscale_bijectors.jl | 2 +- test/algorithms/paramspacesgd/scoregradelbo.jl | 2 +- test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl | 2 +- test/general/optimize.jl | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index dee53c2b6..9c76f14e8 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -8,9 +8,9 @@ AD_repgradelbo_interface = if TEST_GROUP == "Enzyme" ] else [ + AutoMooncake(; config=Mooncake.Config()), AutoReverseDiff(), AutoZygote(), - AutoMooncake(; config=Mooncake.Config()), ] end diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index 3f7d4f114..e2e69b1cd 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -7,9 +7,9 @@ AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( + :Mooncake => AutoMooncake(; config=Mooncake.Config()), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index 33995ee84..4b8475691 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -7,9 +7,9 @@ AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" ) else Dict( + :Mooncake => AutoMooncake(; config=Mooncake.Config()), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index 624a292f7..0cc5c0990 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -8,9 +8,9 @@ AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( + :Mooncake => AutoMooncake(; config=Mooncake.Config()), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index dcd9722a8..3843ded42 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -7,9 +7,9 @@ AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" ) else Dict( + :Mooncake => AutoMooncake(; config=Mooncake.Config()), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end diff --git a/test/algorithms/paramspacesgd/scoregradelbo.jl b/test/algorithms/paramspacesgd/scoregradelbo.jl index 011aea6a1..1f5563daf 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo.jl @@ -3,10 +3,10 @@ AD_scoregradelbo_interface = if TEST_GROUP == "Enzyme" [AutoEnzyme()] else [ + AutoMooncake(; config=Mooncake.Config()), AutoForwardDiff(), AutoReverseDiff(), AutoZygote(), - AutoMooncake(; config=Mooncake.Config()), ] end diff --git a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl index 2505a3669..d2eb43bd6 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl @@ -8,10 +8,10 @@ AD_scoregradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( + :Mooncake => AutoMooncake(; config=Mooncake.Config()), :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end diff --git a/test/general/optimize.jl b/test/general/optimize.jl index 5849e2bc2..10d932fe0 100644 --- a/test/general/optimize.jl +++ b/test/general/optimize.jl @@ -11,7 +11,7 @@ q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims))) obj = RepGradELBO(10) - adtype = AutoReverseDiff() + adtype = AutoZygote() optimizer = Optimisers.Adam(1e-2) averager = PolynomialAveraging() From 71295320b164dc160bc5b05fed177aa64eb82c7d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 15:09:05 -0400 Subject: [PATCH 098/108] use ReverseDiff for general/optimize.jl tests --- test/general/optimize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/general/optimize.jl b/test/general/optimize.jl index 10d932fe0..5849e2bc2 100644 --- a/test/general/optimize.jl +++ b/test/general/optimize.jl @@ -11,7 +11,7 @@ q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims))) obj = RepGradELBO(10) - adtype = AutoZygote() + adtype = AutoReverseDiff() optimizer = Optimisers.Adam(1e-2) averager = PolynomialAveraging() From ddb703c8441f6f6f265acbb4a9489b610318709a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 15:09:18 -0400 Subject: [PATCH 099/108] fix Mooncake errors by removing wrong ReverseDiff specialization --- ext/AdvancedVIReverseDiffExt.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index f29a72c29..8c3eccba2 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -8,8 +8,4 @@ ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity( prob::AdvancedVI.MixedADLogDensityProblem, x::ReverseDiff.TrackedArray ) -ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity( - prob::AdvancedVI.MixedADLogDensityProblem, x::AbstractVector -) - end From 2c7031b56a5f018411cb4f3c77b0466498a725c4 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 15:10:36 -0400 Subject: [PATCH 100/108] fix remove call to `ADgradient` --- test/algorithms/paramspacesgd/repgradelbo.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index 9c76f14e8..d8e791b8c 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -73,7 +73,6 @@ end modelstats = normal_meanfield(rng, Float64) (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats - model = ADgradient(AutoForwardDiff(), model) mixed_ad = AdvancedVI.MixedADLogDensityProblem(model) @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] From 623657c560ebfd12953edffb4f2b8197931b7327 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 15:12:20 -0400 Subject: [PATCH 101/108] fix AD test order --- test/algorithms/paramspacesgd/repgradelbo_locationscale.jl | 2 +- .../paramspacesgd/repgradelbo_locationscale_bijectors.jl | 2 +- .../paramspacesgd/repgradelbo_proximal_locationscale.jl | 2 +- .../repgradelbo_proximal_locationscale_bijectors.jl | 2 +- test/algorithms/paramspacesgd/scoregradelbo.jl | 2 +- test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index e2e69b1cd..3f7d4f114 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -7,9 +7,9 @@ AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( - :Mooncake => AutoMooncake(; config=Mooncake.Config()), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index 4b8475691..33995ee84 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -7,9 +7,9 @@ AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" ) else Dict( - :Mooncake => AutoMooncake(; config=Mooncake.Config()), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index 0cc5c0990..624a292f7 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -8,9 +8,9 @@ AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( - :Mooncake => AutoMooncake(; config=Mooncake.Config()), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index 3843ded42..dcd9722a8 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -7,9 +7,9 @@ AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" ) else Dict( - :Mooncake => AutoMooncake(; config=Mooncake.Config()), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end diff --git a/test/algorithms/paramspacesgd/scoregradelbo.jl b/test/algorithms/paramspacesgd/scoregradelbo.jl index 1f5563daf..011aea6a1 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo.jl @@ -3,10 +3,10 @@ AD_scoregradelbo_interface = if TEST_GROUP == "Enzyme" [AutoEnzyme()] else [ - AutoMooncake(; config=Mooncake.Config()), AutoForwardDiff(), AutoReverseDiff(), AutoZygote(), + AutoMooncake(; config=Mooncake.Config()), ] end diff --git a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl index d2eb43bd6..2505a3669 100644 --- a/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/scoregradelbo_locationscale.jl @@ -8,10 +8,10 @@ AD_scoregradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( - :Mooncake => AutoMooncake(; config=Mooncake.Config()), :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), ) end From 862f5930a554d2162869ba2d85964b5fb04a1fe3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 15:13:34 -0400 Subject: [PATCH 102/108] fix test order in `paramspacesgd/repgradelbo.jl` --- test/algorithms/paramspacesgd/repgradelbo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index d8e791b8c..636e0a764 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -8,9 +8,9 @@ AD_repgradelbo_interface = if TEST_GROUP == "Enzyme" ] else [ - AutoMooncake(; config=Mooncake.Config()), AutoReverseDiff(), AutoZygote(), + AutoMooncake(; config=Mooncake.Config()), ] end From fe1fa8465bfe2fbc52a2d42003f5fbf14f797eb5 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 15:26:14 -0400 Subject: [PATCH 103/108] fix typo in warning message of capability check in repgradelbo --- src/algorithms/paramspacesgd/repgradelbo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl index 96f08e360..cd81fd561 100644 --- a/src/algorithms/paramspacesgd/repgradelbo.jl +++ b/src/algorithms/paramspacesgd/repgradelbo.jl @@ -36,8 +36,8 @@ function init( capability = LogDensityProblems.capabilities(Prob) @assert adtype isa Union{<:AutoReverseDiff,<:AutoZygote,<:AutoMooncake,<:AutoEnzyme} ad_prob = if capability < LogDensityProblems.LogDensityOrder{1}() - @warn "The capability of the provided log-density problem $(capability) is less than $(LogDensityProblems.LogDensityOrder{1}())" * - "Will attempt to directly differentiate through `LogDensityProblems.logdensity`." * + @warn "The capability of the provided log-density problem $(capability) is less than $(LogDensityProblems.LogDensityOrder{1}()) " * + "Will attempt to directly differentiate through `LogDensityProblems.logdensity`. " * "If this is not intended, please supply a log-density problem with cabality at least $(LogDensityProblems.LogDensityOrder{1}())" prob else From aaa39f9d8e574a99b8834c563e04e83210a82590 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 15:26:28 -0400 Subject: [PATCH 104/108] fix capability of unconstrdist in benchmark --- bench/unconstrdist.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bench/unconstrdist.jl b/bench/unconstrdist.jl index 04223757e..164199f0b 100644 --- a/bench/unconstrdist.jl +++ b/bench/unconstrdist.jl @@ -7,6 +7,13 @@ function LogDensityProblems.logdensity(model::UnconstrDist, x) return logpdf(model.dist, x) end +function LogDensityProblems.logdensity_and_gradient(model::UnconstrDist, θ) + return ( + LogDensityProblems.logdensity(model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), + ) +end + function LogDensityProblems.dimension(model::UnconstrDist) return length(model.dist) end From 6cb9d7ea18557a1ef09ea10ae3528cd29633aa11 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 15:29:03 -0400 Subject: [PATCH 105/108] fix don't run benchmark on Enzyme --- bench/benchmarks.jl | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 8ac139518..30ea0c093 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -49,13 +49,7 @@ begin ("Zygote", AutoZygote()), ("Mooncake", AutoMooncake(; config=Mooncake.Config())), ("ReverseDiff", AutoReverseDiff()), - ( - "Enzyme", - AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ), + # ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)), ], (familyname, family) in [ ("meanfield", MeanFieldGaussian(zeros(T, d), Diagonal(ones(T, d)))), From 27b9f224304d3027a97569c64f0d512311849440 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 15:29:15 -0400 Subject: [PATCH 106/108] fix docs don't use LogDensityProblemsAD --- docs/Project.toml | 1 - docs/src/examples.md | 33 ++++++++++++--------------- docs/src/families.md | 18 ++++++++++----- docs/src/paramspacesgd/repgradelbo.md | 32 +++++++++++++++----------- 4 files changed, 46 insertions(+), 38 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 3322daae5..81d9f5fc5 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,7 +7,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" diff --git a/docs/src/examples.md b/docs/src/examples.md index 7c9350645..07626dab7 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -15,6 +15,7 @@ Using the `LogDensityProblems` interface, we the model can be defined as follows ```@example elboexample using LogDensityProblems +using ForwardDiff struct NormalLogNormal{MX,SX,MY,SY} μ_x::MX @@ -28,14 +29,24 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ) return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end +function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ) + return ( + LogDensityProblems.logdensity(model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), + ) +end + function LogDensityProblems.dimension(model::NormalLogNormal) return length(model.μ_y) + 1 end function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - return LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{1}() end ``` +Notice that the model supports first-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/stable/#LogDensityProblems.capabilities). +The required order of differentiation capability will vary depending on the VI algorithm. +In this example, we will use `KLMinRepGradDescent`, which requires first-order capability. Let's now instantiate the model @@ -51,27 +62,15 @@ model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)); nothing ``` -Some of the VI algorithms require gradients of the target log-density. -In this example, we will use `KLMinRepGradDescent`, which requires first-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/stable/#LogDensityProblems.capabilities). -For this, we can rely on `LogDensityProblemsAD`: - -```@example elboexample -using LogDensityProblemsAD -using ADTypes, ReverseDiff - -model_ad = ADgradient(AutoReverseDiff(), model) -nothing -``` - Let's now load `AdvancedVI`. In addition to gradients of the target log-density, `KLMinRepGradDescent` internally uses automatic differentiation. Therefore, we have to select an AD framework to be used within `KLMinRepGradDescent`. -(This does not need to be the same as the backend used by `model_ad`.) +(This does not need to be the same as the AD backend used for the first-order capability of `model`.) The selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface. Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`. ```@example elboexample -using Optimisers +using ADTypes, ReverseDiff using AdvancedVI alg = KLMinRepGradDescent(AutoReverseDiff()); @@ -119,9 +118,7 @@ Passing `objective` and the initial variational approximation `q` to `optimize` ```@example elboexample n_max_iter = 10^4 -q_out, info, _ = AdvancedVI.optimize( - alg, n_max_iter, model_ad, q0_trans; show_progress=false -); +q_out, info, _ = AdvancedVI.optimize(alg, n_max_iter, model, q0_trans; show_progress=false); nothing ``` diff --git a/docs/src/families.md b/docs/src/families.md index df3342554..dc7fab768 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -135,7 +135,7 @@ using ADTypes using AdvancedVI using Distributions using LinearAlgebra -using LogDensityProblems, LogDensityProblemsAD +using LogDensityProblems using Optimisers using Plots using ForwardDiff, ReverseDiff @@ -148,12 +148,19 @@ function LogDensityProblems.logdensity(model::Target, θ) logpdf(model.dist, θ) end +function LogDensityProblems.logdensity_and_gradient(model::Target, θ) + return ( + LogDensityProblems.logdensity(model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), + ) +end + function LogDensityProblems.dimension(model::Target) return length(model.dist) end function LogDensityProblems.capabilities(::Type{<:Target}) - return LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{1}() end n_dims = 30 @@ -163,7 +170,6 @@ D_true = Diagonal(log.(1 .+ exp.(randn(n_dims)))) Σsqrt_true = sqrt(Σ_true) μ_true = randn(n_dims) model = Target(MvNormal(μ_true, Σ_true)); -model_ad = ADgradient(AutoForwardDiff(), model) d = LogDensityProblems.dimension(model); μ = zeros(d); @@ -189,19 +195,19 @@ function callback(; params, averaged_params, restructure, kwargs...) end _, info_fr, _ = AdvancedVI.optimize( - alg, max_iter, model_ad, q0_fr; + alg, max_iter, model, q0_fr; show_progress = false, callback = callback, ); _, info_mf, _ = AdvancedVI.optimize( - alg, max_iter, model_ad, q0_mf; + alg, max_iter, model, q0_mf; show_progress = false, callback = callback, ); _, info_lr, _ = AdvancedVI.optimize( - alg, max_iter, model_ad, q0_lr; + alg, max_iter, model, q0_lr; show_progress = false, callback = callback, ); diff --git a/docs/src/paramspacesgd/repgradelbo.md b/docs/src/paramspacesgd/repgradelbo.md index 97ee29444..ea9bee849 100644 --- a/docs/src/paramspacesgd/repgradelbo.md +++ b/docs/src/paramspacesgd/repgradelbo.md @@ -122,7 +122,7 @@ For example, if ``q_{\lambda}`` is a Gaussian with a full-rank covariance, a bac using Bijectors using FillArrays using LinearAlgebra -using LogDensityProblems, LogDensityProblemsAD +using LogDensityProblems using Plots using Random @@ -142,21 +142,27 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ) logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end +function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ) + return ( + LogDensityProblems.logdensity(model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), + ) +end + function LogDensityProblems.dimension(model::NormalLogNormal) length(model.μ_y) + 1 end function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - LogDensityProblems.LogDensityOrder{0}() + LogDensityProblems.LogDensityOrder{1}() end -n_dims = 10 -μ_x = 2.0 -σ_x = 0.3 -μ_y = Fill(2.0, n_dims) -σ_y = Fill(1.0, n_dims) -model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); -model_ad = ADgradient(AutoForwardDiff(), model) +n_dims = 10 +μ_x = 2.0 +σ_x = 0.3 +μ_y = Fill(2.0, n_dims) +σ_y = Fill(1.0, n_dims) +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); d = LogDensityProblems.dimension(model); μ = zeros(d); @@ -213,7 +219,7 @@ end _, info_cfe, _ = AdvancedVI.optimize( cfe, max_iter, - model_ad, + model, q0_trans; show_progress = false, callback = callback, @@ -222,7 +228,7 @@ _, info_cfe, _ = AdvancedVI.optimize( _, info_stl, _ = AdvancedVI.optimize( stl, max_iter, - model_ad, + model, q0_trans; show_progress = false, callback = callback, @@ -231,7 +237,7 @@ _, info_stl, _ = AdvancedVI.optimize( _, info_stl, _ = AdvancedVI.optimize( stl, max_iter, - model_ad, + model, q0_trans; show_progress = false, callback = callback, @@ -314,7 +320,7 @@ nothing _, info_qmc, _ = AdvancedVI.optimize( KLMinRepGradDescent(AutoReverseDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)), max_iter, - model_ad, + model, q0_trans; show_progress = false, callback = callback, From d16006085b9b59e0fe2a2a83d362270d8dbba13f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 15:29:50 -0400 Subject: [PATCH 107/108] fix order of AD benchmarks --- bench/benchmarks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 30ea0c093..6808a4742 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -47,8 +47,8 @@ begin ], (adname, adtype) in [ ("Zygote", AutoZygote()), - ("Mooncake", AutoMooncake(; config=Mooncake.Config())), ("ReverseDiff", AutoReverseDiff()), + ("Mooncake", AutoMooncake(; config=Mooncake.Config())), # ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)), ], (familyname, family) in [ From 3bb03f9f1a319c7345e51a2df74429b202b0f6f7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 30 Jul 2025 17:59:04 -0400 Subject: [PATCH 108/108] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/examples.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/examples.md b/docs/src/examples.md index 07626dab7..2fc0c8a73 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -44,6 +44,7 @@ function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) return LogDensityProblems.LogDensityOrder{1}() end ``` + Notice that the model supports first-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/stable/#LogDensityProblems.capabilities). The required order of differentiation capability will vary depending on the VI algorithm. In this example, we will use `KLMinRepGradDescent`, which requires first-order capability.