
Support Mixing AD Frameworks for LogDensityProblems and the objective (cleaned-up) #187

Open · wants to merge 13 commits into base `main`
16 changes: 15 additions & 1 deletion Project.toml
@@ -5,6 +5,7 @@ version = "0.5.0"
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -20,31 +21,44 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[extensions]
AdvancedVIBijectorsExt = "Bijectors"
AdvancedVIBijectorsExt = ["Bijectors", "Optimisers"]
AdvancedVIEnzymeExt = ["Enzyme", "ChainRulesCore"]
AdvancedVIMooncakeExt = ["Mooncake", "ChainRulesCore"]
AdvancedVIReverseDiffExt = ["ReverseDiff", "ChainRulesCore"]

[compat]
ADTypes = "1"
Accessors = "0.1"
Bijectors = "0.13, 0.14, 0.15"
ChainRulesCore = "1"
DiffResults = "1"
DifferentiationInterface = "0.6, 0.7"
Distributions = "0.25.111"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.13"
FillArrays = "1.3"
Functors = "0.4, 0.5"
LinearAlgebra = "1"
LogDensityProblems = "2"
Mooncake = "0.4"
Optimisers = "0.2.16, 0.3, 0.4"
ProgressMeter = "1.6"
Random = "1"
ReverseDiff = "1"
StatsBase = "0.32, 0.33, 0.34"
julia = "1.10, 1.11.2"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
1 change: 0 additions & 1 deletion bench/benchmarks.jl
@@ -47,7 +47,6 @@ begin
],
(adname, adtype) in [
("Zygote", AutoZygote()),
("ForwardDiff", AutoForwardDiff()),
("ReverseDiff", AutoReverseDiff()),
("Mooncake", AutoMooncake(; config=Mooncake.Config())),
# ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)),
9 changes: 8 additions & 1 deletion bench/normallognormal.jl
@@ -12,12 +12,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
return log_density_x + log_density_y
end

function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
return (
LogDensityProblems.logdensity(model, θ),
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
)
end

function LogDensityProblems.dimension(model::NormalLogNormal)
return length(model.μ_y) + 1
end

function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
return LogDensityProblems.LogDensityOrder{0}()
return LogDensityProblems.LogDensityOrder{1}()
end

function Bijectors.bijector(model::NormalLogNormal)
7 changes: 7 additions & 0 deletions bench/unconstrdist.jl
@@ -7,6 +7,13 @@ function LogDensityProblems.logdensity(model::UnconstrDist, x)
return logpdf(model.dist, x)
end

function LogDensityProblems.logdensity_and_gradient(model::UnconstrDist, θ)
return (
LogDensityProblems.logdensity(model, θ),
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
)
end

function LogDensityProblems.dimension(model::UnconstrDist)
return length(model.dist)
end
50 changes: 30 additions & 20 deletions docs/src/examples.md
@@ -15,6 +15,7 @@ Using the `LogDensityProblems` interface, the model can be defined as follows:

```@example elboexample
using LogDensityProblems
using ForwardDiff

struct NormalLogNormal{MX,SX,MY,SY}
μ_x::MX
@@ -28,15 +29,26 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
end

function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
return (
LogDensityProblems.logdensity(model, θ),
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
)
end

function LogDensityProblems.dimension(model::NormalLogNormal)
return length(model.μ_y) + 1
end

function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
return LogDensityProblems.LogDensityOrder{0}()
return LogDensityProblems.LogDensityOrder{1}()
end
```

Notice that the model now declares first-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/stable/#LogDensityProblems.capabilities).
The required order of differentiation capability will vary depending on the VI algorithm.
In this example, we will use `KLMinRepGradDescent`, which requires first-order capability.

Let's now instantiate the model

```@example elboexample
@@ -51,7 +63,23 @@ model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2));
nothing
```
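With the model instantiated, the first-order interface defined above can be exercised directly. A minimal sketch (not part of this diff; the evaluation point is arbitrary, chosen so that `θ[1] > 0` stays inside the log-normal's support):

```julia
# Sketch only: assumes the `NormalLogNormal` definitions and the `model`
# instance from the blocks above. `ones(...)` keeps θ[1] > 0.
θ = ones(LogDensityProblems.dimension(model))
ℓ, ∇ℓ = LogDensityProblems.logdensity_and_gradient(model, θ)
# ℓ is the log-density at θ; ∇ℓ is its gradient, computed here via ForwardDiff.
```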

Since `y` follows a log-normal prior, its support is the positive half-space ``\mathbb{R}_+``.
Let's now load `AdvancedVI`.
In addition to evaluating gradients of the target log-density, `KLMinRepGradDescent` internally uses automatic differentiation, so we must select an AD framework for it to use.
(This does not need to be the same AD backend used for the first-order capability of `model`.)
The selected AD framework is communicated to `AdvancedVI` through the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
Here, we will use `ReverseDiff`, selected by passing `ADTypes.AutoReverseDiff()`.

```@example elboexample
using ADTypes, ReverseDiff
using AdvancedVI

alg = KLMinRepGradDescent(AutoReverseDiff());
nothing
```

Now, `KLMinRepGradDescent` requires the variational approximation and the target log-density to have the same support.
Since `y` follows a log-normal prior, its support is the positive half-space ``\mathbb{R}_+``.
Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation.

```@example elboexample
@@ -70,24 +98,6 @@ binv = inverse(b)
nothing
```

Let's now load `AdvancedVI`.
Since BBVI relies on automatic differentiation (AD), we need to load an AD library, *before* loading `AdvancedVI`.
Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`.

```@example elboexample
using Optimisers
using ADTypes, ForwardDiff
using AdvancedVI
```

We now need to select 1. a variational objective, and 2. a variational family.
Here, we will use the [`RepGradELBO` objective](@ref repgradelbo), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.

```@example elboexample
alg = KLMinRepGradDescent(AutoForwardDiff())
```

For the variational family, we will use the classic mean-field Gaussian family.

```@example elboexample
11 changes: 9 additions & 2 deletions docs/src/families.md
@@ -138,7 +138,7 @@ using LinearAlgebra
using LogDensityProblems
using Optimisers
using Plots
using ReverseDiff
using ForwardDiff, ReverseDiff

struct Target{D}
dist::D
@@ -148,12 +148,19 @@ function LogDensityProblems.logdensity(model::Target, θ)
logpdf(model.dist, θ)
end

function LogDensityProblems.logdensity_and_gradient(model::Target, θ)
return (
LogDensityProblems.logdensity(model, θ),
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
)
end

function LogDensityProblems.dimension(model::Target)
return length(model.dist)
end

function LogDensityProblems.capabilities(::Type{<:Target})
return LogDensityProblems.LogDensityOrder{0}()
return LogDensityProblems.LogDensityOrder{1}()
end

n_dims = 30
26 changes: 21 additions & 5 deletions docs/src/paramspacesgd/repgradelbo.md
@@ -127,7 +127,7 @@ using Plots
using Random

using Optimisers
using ADTypes, ForwardDiff
using ADTypes, ForwardDiff, ReverseDiff
using AdvancedVI

struct NormalLogNormal{MX,SX,MY,SY}
@@ -142,12 +142,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
end

function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
return (
LogDensityProblems.logdensity(model, θ),
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
)
end

function LogDensityProblems.dimension(model::NormalLogNormal)
length(model.μ_y) + 1
end

function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
LogDensityProblems.LogDensityOrder{0}()
LogDensityProblems.LogDensityOrder{1}()
end

n_dims = 10
@@ -185,7 +192,7 @@ binv = inverse(b)
q0_trans = Bijectors.TransformedDistribution(q0, binv)

cfe = KLMinRepGradDescent(
AutoForwardDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2)
AutoReverseDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2)
)
nothing
```
@@ -194,7 +201,7 @@ The repgradelbo estimator can instead be created as follows:

```@example repgradelbo
stl = KLMinRepGradDescent(
AutoForwardDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)
AutoReverseDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)
)
nothing
```
@@ -227,6 +234,15 @@ _, info_stl, _ = AdvancedVI.optimize(
callback = callback,
);

_, info_stl, _ = AdvancedVI.optimize(
stl,
max_iter,
model,
q0_trans;
show_progress = false,
callback = callback,
);

t = [i.iteration for i in info_cfe]
elbo_cfe = [i.elbo for i in info_cfe]
elbo_stl = [i.elbo for i in info_stl]
@@ -302,7 +318,7 @@ nothing

```@setup repgradelbo
_, info_qmc, _ = AdvancedVI.optimize(
KLMinRepGradDescent(AutoForwardDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)),
KLMinRepGradDescent(AutoReverseDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)),
max_iter,
model,
q0_trans;
13 changes: 13 additions & 0 deletions ext/AdvancedVIEnzymeExt.jl
@@ -0,0 +1,13 @@
module AdvancedVIEnzymeExt

using AdvancedVI
using LogDensityProblems
using Enzyme

Enzyme.@import_rrule(
typeof(LogDensityProblems.logdensity),
AdvancedVI.MixedADLogDensityProblem,
AbstractVector
)

end
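`Enzyme.@import_rrule` registers the `ChainRulesCore` rule attached to `AdvancedVI.MixedADLogDensityProblem` with Enzyme, so Enzyme's reverse pass reuses the wrapped problem's own gradient instead of tracing through `logdensity`. A hypothetical usage sketch (mirroring the commented-out benchmark configuration; the exact `AutoEnzyme` options are an assumption):

```julia
# Hypothetical sketch: pair an Enzyme-driven objective with a model whose
# log-density gradient is supplied by its own `logdensity_and_gradient`.
using ADTypes, Enzyme, AdvancedVI

alg = KLMinRepGradDescent(
    AutoEnzyme(;
        mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
        function_annotation=Enzyme.Const,
    )
)
```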
20 changes: 20 additions & 0 deletions ext/AdvancedVIMooncakeExt.jl
@@ -0,0 +1,20 @@
module AdvancedVIMooncakeExt

using AdvancedVI
using Base: IEEEFloat
using LogDensityProblems
using Mooncake

Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{
typeof(LogDensityProblems.logdensity),
AdvancedVI.MixedADLogDensityProblem,
Array{<:IEEEFloat,1},
}

Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{
typeof(LogDensityProblems.logdensity),
AdvancedVI.MixedADLogDensityProblem,
SubArray{<:IEEEFloat,1},
Member:

@Red-Portal, this looks suspicious to me. Two possible issues:

  • `@from_rrule` might not work well with `SubArray`
  • `SubArray` might require extra new rules.

cc @willtebbutt who knows this better.

Member:

I think it's mostly that @from_rrule doesn't work well with SubArrays at this point. Would it be possible to produce an MWE and open an issue on Mooncake? Since ChainRules doesn't provide strong promises around what type is used to represent the result of an rrule, the translation between ChainRules and Mooncake is always going to be a bit flaky (there's simply no way around that), but we ought to be able to

  1. provide a better error message, and
  2. handle this particular case, which really doesn't seem like it ought to be too challenging.

Member Author:

Thank you all for chiming in. Filed an issue here

Member (@yebai, Aug 9, 2025):

@Red-Portal I suggest you directly implement this rule for Mooncake here. Mooncake’s support of SubArrays might be several weeks away, because we are a tiny team and forward mode has higher priority.

}

end
15 changes: 15 additions & 0 deletions ext/AdvancedVIReverseDiffExt.jl
@@ -0,0 +1,15 @@
module AdvancedVIReverseDiffExt

using AdvancedVI
using LogDensityProblems
using ReverseDiff

ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity(
prob::AdvancedVI.MixedADLogDensityProblem, x::ReverseDiff.TrackedArray
)

ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity(
prob::AdvancedVI.MixedADLogDensityProblem, x::AbstractArray{<:ReverseDiff.TrackedReal}
)

end
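`ReverseDiff.@grad_from_chainrules` generates `logdensity` methods for tracked inputs that forward to the `ChainRulesCore.rrule`, so ReverseDiff tapes record the custom rule instead of differentiating through the model. A hypothetical sketch of the effect (`prob` and its construction are assumptions; this diff does not show how the wrapper is built):

```julia
# Hypothetical sketch: gradients taken through a wrapped problem hit the
# imported rule, so ∇ℓ comes from the model's own `logdensity_and_gradient`.
using ReverseDiff, LogDensityProblems

# `prob` is assumed to be an `AdvancedVI.MixedADLogDensityProblem`.
f(x) = LogDensityProblems.logdensity(prob, x)
∇ℓ = ReverseDiff.gradient(f, ones(LogDensityProblems.dimension(prob)))
```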
3 changes: 3 additions & 0 deletions src/AdvancedVI.jl
@@ -18,6 +18,7 @@ using LogDensityProblems
using ADTypes
using DiffResults
using DifferentiationInterface
using ChainRulesCore

using FillArrays

@@ -95,6 +96,8 @@ This is an indirection for handling the type stability of `restructure`, as some
"""
restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params)

include("mixedad_logdensity.jl")

# Variational Families
export MvLocationScale, MeanFieldGaussian, FullRankGaussian

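The new `src/mixedad_logdensity.jl` file is collapsed in this diff view. Based on the extension modules above, a minimal sketch of what such a wrapper could look like (every name and field here is an assumption, not the PR's actual code):

```julia
# Hypothetical sketch: a wrapper that forwards `logdensity` to an inner problem
# while exposing a ChainRulesCore rrule backed by the problem's own
# `logdensity_and_gradient` (which may use a different AD backend, e.g. ForwardDiff).
using LogDensityProblems
using ChainRulesCore

struct MixedADLogDensityProblem{P}
    prob::P
end

function LogDensityProblems.logdensity(m::MixedADLogDensityProblem, x)
    return LogDensityProblems.logdensity(m.prob, x)
end

function ChainRulesCore.rrule(
    ::typeof(LogDensityProblems.logdensity), m::MixedADLogDensityProblem, x::AbstractVector
)
    ℓ, ∇ℓ = LogDensityProblems.logdensity_and_gradient(m.prob, x)
    logdensity_pullback(ȳ) = (NoTangent(), NoTangent(), ȳ .* ∇ℓ)
    return ℓ, logdensity_pullback
end
```

With a ChainRules rule like this in place, each extension only needs to translate it into its backend's native rule format, which is what the three extension modules above do.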