Skip to content

Commit b9df825

Browse files
committed
add
1 parent ad5bf6e commit b9df825

20 files changed

+145
-48
lines changed

Project.toml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.5.0"
55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
910
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1011
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -20,31 +21,44 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2021

2122
[weakdeps]
2223
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
24+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
25+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
26+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2327

2428
[extensions]
25-
AdvancedVIBijectorsExt = "Bijectors"
29+
AdvancedVIBijectorsExt = ["Bijectors", "Optimisers"]
30+
AdvancedVIEnzymeExt = ["Enzyme", "ChainRulesCore"]
31+
AdvancedVIMooncakeExt = ["Mooncake", "ChainRulesCore"]
32+
AdvancedVIReverseDiffExt = ["ReverseDiff", "ChainRulesCore"]
2633

2734
[compat]
2835
ADTypes = "1"
2936
Accessors = "0.1"
3037
Bijectors = "0.13, 0.14, 0.15"
38+
ChainRulesCore = "1"
3139
DiffResults = "1"
3240
DifferentiationInterface = "0.6, 0.7"
3341
Distributions = "0.25.111"
3442
DocStringExtensions = "0.8, 0.9"
43+
Enzyme = "0.13"
3544
FillArrays = "1.3"
3645
Functors = "0.4, 0.5"
3746
LinearAlgebra = "1"
3847
LogDensityProblems = "2"
48+
Mooncake = "0.4"
3949
Optimisers = "0.2.16, 0.3, 0.4"
4050
ProgressMeter = "1.6"
4151
Random = "1"
52+
ReverseDiff = "1"
4253
StatsBase = "0.32, 0.33, 0.34"
4354
julia = "1.10, 1.11.2"
4455

4556
[extras]
4657
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
58+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
59+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
4760
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
61+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
4862
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4963

5064
[targets]

bench/benchmarks.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ begin
4747
],
4848
(adname, adtype) in [
4949
("Zygote", AutoZygote()),
50-
("ForwardDiff", AutoForwardDiff()),
5150
("ReverseDiff", AutoReverseDiff()),
5251
("Mooncake", AutoMooncake(; config=Mooncake.Config())),
5352
# ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)),

bench/normallognormal.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
1212
return log_density_x + log_density_y
1313
end
1414

15+
function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
16+
return (
17+
LogDensityProblems.logdensity(model, θ),
18+
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
19+
)
20+
end
21+
1522
function LogDensityProblems.dimension(model::NormalLogNormal)
1623
return length(model.μ_y) + 1
1724
end
1825

1926
function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
20-
return LogDensityProblems.LogDensityOrder{0}()
27+
return LogDensityProblems.LogDensityOrder{1}()
2128
end
2229

2330
function Bijectors.bijector(model::NormalLogNormal)

bench/unconstrdist.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ function LogDensityProblems.logdensity(model::UnconstrDist, x)
77
return logpdf(model.dist, x)
88
end
99

10+
function LogDensityProblems.logdensity_and_gradient(model::UnconstrDist, θ)
11+
return (
12+
LogDensityProblems.logdensity(model, θ),
13+
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
14+
)
15+
end
16+
1017
function LogDensityProblems.dimension(model::UnconstrDist)
1118
return length(model.dist)
1219
end

docs/src/examples.md

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Using the `LogDensityProblems` interface, the model can be defined as follows
1515

1616
```@example elboexample
1717
using LogDensityProblems
18+
using ForwardDiff
1819
1920
struct NormalLogNormal{MX,SX,MY,SY}
2021
μ_x::MX
@@ -28,15 +29,26 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
2829
return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
2930
end
3031
32+
function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
33+
return (
34+
LogDensityProblems.logdensity(model, θ),
35+
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
36+
)
37+
end
38+
3139
function LogDensityProblems.dimension(model::NormalLogNormal)
3240
return length(model.μ_y) + 1
3341
end
3442
3543
function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
36-
return LogDensityProblems.LogDensityOrder{0}()
44+
return LogDensityProblems.LogDensityOrder{1}()
3745
end
3846
```
3947

48+
Notice that the model supports first-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/stable/#LogDensityProblems.capabilities).
49+
The required order of differentiation capability will vary depending on the VI algorithm.
50+
In this example, we will use `KLMinRepGradDescent`, which requires first-order capability.
51+
4052
Let's now instantiate the model
4153

4254
```@example elboexample
@@ -51,7 +63,23 @@ model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2));
5163
nothing
5264
```
5365

54-
Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``.
66+
Let's now load `AdvancedVI`.
67+
In addition to gradients of the target log-density, `KLMinRepGradDescent` internally uses automatic differentiation.
68+
Therefore, we have to select an AD framework to be used within `KLMinRepGradDescent`.
69+
(This does not need to be the same as the AD backend used for the first-order capability of `model`.)
70+
The selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
71+
Here, we will use `ReverseDiff`, which can be selected by later passing `ADTypes.AutoReverseDiff()`.
72+
73+
```@example elboexample
74+
using ADTypes, ReverseDiff
75+
using AdvancedVI
76+
77+
alg = KLMinRepGradDescent(AutoReverseDiff());
78+
nothing
79+
```
80+
81+
Now, `KLMinRepGradDescent` requires the variational approximation and the target log-density to have the same support.
82+
Since `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``.
5583
Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation.
5684

5785
```@example elboexample
@@ -70,24 +98,6 @@ binv = inverse(b)
7098
nothing
7199
```
72100

73-
Let's now load `AdvancedVI`.
74-
Since BBVI relies on automatic differentiation (AD), we need to load an AD library, *before* loading `AdvancedVI`.
75-
Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
76-
Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`.
77-
78-
```@example elboexample
79-
using Optimisers
80-
using ADTypes, ForwardDiff
81-
using AdvancedVI
82-
```
83-
84-
We now need to select 1. a variational objective, and 2. a variational family.
85-
Here, we will use the [`RepGradELBO` objective](@ref repgradelbo), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.
86-
87-
```@example elboexample
88-
alg = KLMinRepGradDescent(AutoForwardDiff())
89-
```
90-
91101
For the variational family, we will use the classic mean-field Gaussian family.
92102

93103
```@example elboexample

docs/src/families.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ using LinearAlgebra
138138
using LogDensityProblems
139139
using Optimisers
140140
using Plots
141-
using ReverseDiff
141+
using ForwardDiff, ReverseDiff
142142
143143
struct Target{D}
144144
dist::D
@@ -148,12 +148,19 @@ function LogDensityProblems.logdensity(model::Target, θ)
148148
logpdf(model.dist, θ)
149149
end
150150
151+
function LogDensityProblems.logdensity_and_gradient(model::Target, θ)
152+
return (
153+
LogDensityProblems.logdensity(model, θ),
154+
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
155+
)
156+
end
157+
151158
function LogDensityProblems.dimension(model::Target)
152159
return length(model.dist)
153160
end
154161
155162
function LogDensityProblems.capabilities(::Type{<:Target})
156-
return LogDensityProblems.LogDensityOrder{0}()
163+
return LogDensityProblems.LogDensityOrder{1}()
157164
end
158165
159166
n_dims = 30

docs/src/paramspacesgd/repgradelbo.md

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ using Plots
127127
using Random
128128
129129
using Optimisers
130-
using ADTypes, ForwardDiff
130+
using ADTypes, ForwardDiff, ReverseDiff
131131
using AdvancedVI
132132
133133
struct NormalLogNormal{MX,SX,MY,SY}
@@ -142,12 +142,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
142142
logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
143143
end
144144
145+
function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
146+
return (
147+
LogDensityProblems.logdensity(model, θ),
148+
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
149+
)
150+
end
151+
145152
function LogDensityProblems.dimension(model::NormalLogNormal)
146153
length(model.μ_y) + 1
147154
end
148155
149156
function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
150-
LogDensityProblems.LogDensityOrder{0}()
157+
LogDensityProblems.LogDensityOrder{1}()
151158
end
152159
153160
n_dims = 10
@@ -185,7 +192,7 @@ binv = inverse(b)
185192
q0_trans = Bijectors.TransformedDistribution(q0, binv)
186193
187194
cfe = KLMinRepGradDescent(
188-
AutoForwardDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2)
195+
AutoReverseDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2)
189196
)
190197
nothing
191198
```
@@ -194,7 +201,7 @@ The repgradelbo estimator can instead be created as follows:
194201

195202
```@example repgradelbo
196203
stl = KLMinRepGradDescent(
197-
AutoForwardDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)
204+
AutoReverseDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)
198205
)
199206
nothing
200207
```
@@ -227,6 +234,15 @@ _, info_stl, _ = AdvancedVI.optimize(
227234
callback = callback,
228235
);
229236
237+
_, info_stl, _ = AdvancedVI.optimize(
238+
stl,
239+
max_iter,
240+
model,
241+
q0_trans;
242+
show_progress = false,
243+
callback = callback,
244+
);
245+
230246
t = [i.iteration for i in info_cfe]
231247
elbo_cfe = [i.elbo for i in info_cfe]
232248
elbo_stl = [i.elbo for i in info_stl]
@@ -302,7 +318,7 @@ nothing
302318

303319
```@setup repgradelbo
304320
_, info_qmc, _ = AdvancedVI.optimize(
305-
KLMinRepGradDescent(AutoForwardDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)),
321+
KLMinRepGradDescent(AutoReverseDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)),
306322
max_iter,
307323
model,
308324
q0_trans;

src/AdvancedVI.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ using LogDensityProblems
1818
using ADTypes
1919
using DiffResults
2020
using DifferentiationInterface
21+
using ChainRulesCore
2122

2223
using FillArrays
2324

@@ -95,6 +96,8 @@ This is an indirection for handling the type stability of `restructure`, as some
9596
"""
9697
restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params)
9798

99+
include("mixedad_logdensity.jl")
100+
98101
# Variational Families
99102
export MvLocationScale, MeanFieldGaussian, FullRankGaussian
100103

src/algorithms/paramspacesgd/repgradelbo.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ Evidence lower-bound objective with the reparameterization gradient formulation[
1313
# Requirements
1414
- The variational approximation ``q_{\\lambda}`` implements `rand`.
1515
- The target distribution and the variational approximation have the same support.
16-
- The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
16+
- The target `LogDensityProblem` must have a capability at least `LogDensityProblems.LogDensityOrder{1}()`.
17+
- Only the AD backends `ReverseDiff`, `Zygote`, `Mooncake`, and `Enzyme` are supported.
18+
- The sampling process `rand(q)` must be differentiable by the selected AD backend.
1719
1820
Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
1921
"""
@@ -26,23 +28,33 @@ function init(
2628
rng::Random.AbstractRNG,
2729
obj::RepGradELBO,
2830
adtype::ADTypes.AbstractADType,
29-
prob,
31+
prob::Prob,
3032
params,
3133
restructure,
32-
)
34+
) where {Prob}
3335
q_stop = restructure(params)
36+
capability = LogDensityProblems.capabilities(Prob)
37+
@assert adtype isa Union{<:AutoReverseDiff,<:AutoZygote,<:AutoMooncake,<:AutoEnzyme}
38+
ad_prob = if capability < LogDensityProblems.LogDensityOrder{1}()
39+
@warn "The capability of the provided log-density problem $(capability) is less than $(LogDensityProblems.LogDensityOrder{1}()). " *
40+
"Will attempt to directly differentiate through `LogDensityProblems.logdensity`. " *
41+
"If this is not intended, please supply a log-density problem with capability at least $(LogDensityProblems.LogDensityOrder{1}())"
42+
prob
43+
else
44+
MixedADLogDensityProblem(prob)
45+
end
3446
aux = (
3547
rng=rng,
3648
adtype=adtype,
3749
obj=obj,
38-
problem=prob,
50+
problem=ad_prob,
3951
restructure=restructure,
4052
q_stop=q_stop,
4153
)
4254
obj_ad_prep = AdvancedVI._prepare_gradient(
4355
estimate_repgradelbo_ad_forward, adtype, params, aux
4456
)
45-
return (obj_ad_prep=obj_ad_prep, problem=prob)
57+
return (obj_ad_prep=obj_ad_prep, problem=ad_prob)
4658
end
4759

4860
function RepGradELBO(n_samples::Int; entropy::AbstractEntropyEstimator=ClosedFormEntropy())
@@ -132,6 +144,7 @@ function estimate_gradient!(
132144
params,
133145
restructure,
134146
state,
147+
args...,
135148
)
136149
(; obj_ad_prep, problem) = state
137150
q_stop = restructure(params)

test/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
44
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
55
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
66
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
7-
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
87
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
98
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
109
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -30,7 +29,6 @@ Bijectors = "0.13, 0.14, 0.15"
3029
DiffResults = "1"
3130
DifferentiationInterface = "0.6, 0.7"
3231
Distributions = "0.25.111"
33-
DistributionsAD = "0.6.45"
3432
Enzyme = "0.13, 0.14, 0.15"
3533
FillArrays = "1.6.1"
3634
ForwardDiff = "0.10.36, 1"

0 commit comments

Comments
 (0)