diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 9731f20c2..26f6876f5 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ jobs: - windows-latest arch: - x64 - - x86 + # - x86 # Uncomment after https://github.com/JuliaTesting/ReTest.jl/pull/52 is merged exclude: - os: macOS-latest arch: x86 @@ -61,3 +61,30 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-lcov: lcov.info + docs: + name: Documentation + runs-on: ubuntu-latest + permissions: + contents: write + statuses: write + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: '1' + - name: Configure doc environment + run: | + julia --project=docs/ -e ' + using Pkg + Pkg.develop(PackageSpec(path=pwd())) + Pkg.instantiate()' + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-docdeploy@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - run: | + julia --project=docs -e ' + using Documenter: DocMeta, doctest + using AdvancedVI + DocMeta.setdocmeta!(AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true) + doctest(AdvancedVI)' diff --git a/Project.toml b/Project.toml index 800fa34f3..075ae92f3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,37 +1,67 @@ name = "AdvancedVI" uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.2.4" +version = "0.3.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[extensions] +AdvancedVIEnzymeExt = "Enzyme" +AdvancedVIForwardDiffExt = "ForwardDiff" +AdvancedVIReverseDiffExt = "ReverseDiff" +AdvancedVIZygoteExt = "Zygote" [compat] -Bijectors = "0.11, 0.12, 0.13" -Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" -DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6" +ADTypes = "0.1, 0.2" +Accessors = "0.1" +Bijectors = "0.12, 0.13" +ChainRulesCore = "1.16" +DiffResults = "1" +Distributions = "0.25.87" DocStringExtensions = "0.8, 0.9" -ForwardDiff = "0.10.3" -ProgressMeter = "1.0.0" -Requires = "0.5, 1.0" +Enzyme = "0.11.7" +FillArrays = "1.3" +ForwardDiff = "0.10.36" +Functors = "0.4" +LogDensityProblems = "2" +Optimisers = "0.2.16" +ProgressMeter = "1.6" +Requires = "1.0" +ReverseDiff = "1.15.1" +SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" -StatsFuns = "0.8, 0.9, 1" -Tracker 
= "0.2.3" +Zygote = "0.6.63" julia = "1.6" [extras] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Pkg", "Test"] diff --git a/README.md b/README.md index 18ba63e50..695e9ed98 100644 --- a/README.md +++ b/README.md @@ -1,250 +1,108 @@ -# AdvancedVI.jl -A library for variational Bayesian inference in Julia. - -At the time of writing (05/02/2020), implementations of the variational inference (VI) interface and some algorithms are implemented in [Turing.jl](https://github.com/TuringLang/Turing.jl). The idea is to soon separate the VI functionality in Turing.jl out and into this package. - -The purpose of this package will then be to provide a common interface together with implementations of standard algorithms and utilities with the goal of ease of use and the ability for other packages, e.g. Turing.jl, to write a light wrapper around AdvancedVI.jl for integration. -As an example, in Turing.jl we support automatic differentiation variational inference (ADVI) but really the only piece of code tied into the Turing.jl is the conversion of a `Turing.Model` to a `logjoint(z)` function which computes `z ↦ log p(x, z)`, with `x` denoting the observations embedded in the `Turing.Model`. As long as this `logjoint(z)` method is compatible with some AD framework, e.g. `ForwardDiff.jl` or `Zygote.jl`, this is all we need from Turing.jl to be able to perform ADVI! - -## [WIP] Interface -- `vi`: the main interface to the functionality in this package - - `vi(model, alg)`: only used when `alg` has a default variational posterior which it will provide. - - `vi(model, alg, q::VariationalPosterior, θ)`: `q` represents the family of variational distributions and `θ` is the initial parameters "indexing" the starting distribution. This assumes that there exists an implementation `Variational.update(q, θ)` which returns the variational posterior corresponding to parameters `θ`. - - `vi(model, alg, getq::Function, θ)`: here `getq(θ)` is a function returning a `VariationalPosterior` corresponding to `θ`. -- `optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())` -- `grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)` - - Different combinations of variational objectives (`vo`), VI methods (`alg`), and variational posteriors (`q`) might use different gradient estimators. `grad!` allows us to specify these different behaviors. +# AdvancedVI.jl +[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms. +VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness. +`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem. +The purpose of this package is to provide a common accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration. 
+For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) by simply converting a `Turing.Model` into a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) and extracting a corresponding `Bijectors.bijector`. ## Examples -### Variational Inference -A very simple generative model is the following - μ ~ 𝒩(0, 1) - xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n +`AdvancedVI` expects a `LogDensityProblem`. +For example, for the normal-log-normal model: -where μ and xᵢ are some ℝᵈ vectors and 𝒩 denotes a d-dimensional multivariate Normal distribution. +$$ +\begin{aligned} +x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ +y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right), +\end{aligned} +$$ -Given a set of `n` observations `[x₁, …, xₙ]` we're interested in finding the distribution `p(μ∣x₁, …, xₙ)` over the mean `μ`. We can obtain (an approximation to) this distribution that using AdvancedVI.jl! - -First we generate some observations and set up the problem: +a `LogDensityProblem` can be implemented as ```julia -julia> using Distributions - -julia> d = 2; n = 100; - -julia> observations = randn((d, n)); # 100 observations from 2D 𝒩(0, 1) - -julia> # Define generative model - # μ ~ 𝒩(0, 1) - # xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n - prior(μ) = logpdf(MvNormal(ones(d)), μ) -prior (generic function with 1 method) +using LogDensityProblems -julia> likelihood(x, μ) = sum(logpdf(MvNormal(μ, ones(d)), x)) -likelihood (generic function with 1 method) - -julia> logπ(μ) = likelihood(observations, μ) + prior(μ) -logπ (generic function with 1 method) - -julia> logπ(randn(2)) # <= just checking that it works --311.74132761437653 -``` -Now there are mainly two different ways of specifying the approximate posterior (and its family). The first is by providing a mapping from distribution parameters to the distribution `θ ↦ q(⋅∣θ)`: -```julia -julia> using DistributionsAD, AdvancedVI +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end -julia> # Using a function z ↦ q(⋅∣z) - getq(θ) = TuringDiagMvNormal(θ[1:d], exp.(θ[d + 1:4])) -getq (generic function with 1 method) -``` -Then we make the choice of algorithm, a subtype of `VariationalInference`, -```julia -julia> # Perform VI - advi = ADVI(10, 10_000) -ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 10000) -``` -And finally we can perform VI! The usual inferface is to call `vi` which behind the scenes takes care of the optimization and returns the resulting variational posterior: -```julia -julia> q = vi(logπ, advi, getq, randn(4)) -[ADVI] Optimizing...100% Time: 0:00:01 -TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[0.16282745378074515, 0.15789310089462574], σ=[0.09519377533754399, 0.09273176907111745]) -``` -Let's have a look at the resulting ELBO: -```julia -julia> AdvancedVI.elbo(advi, q, logπ, 1000) --287.7866366886285 -``` -Unfortunately, the *final* value of the ELBO is not always a very good diagnostic, though the ELBO is an important metric to keep an eye on during training since an *increase* in the ELBO means we're going in the right direction. Luckily, this is such a simple problem that we can indeed obtain a closed form solution! 
Because we're lazy (at least I am), we'll let [ConjugatePriors.jl](https://github.com/JuliaStats/ConjugatePriors.jl) do this for us: -```julia -julia> # True posterior - using ConjugatePriors +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end -julia> pri = MvNormal(zeros(2), ones(2)); +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end -julia> true_posterior = posterior((pri, pri.Σ), MvNormal, observations) -DiagNormal( -dim: 2 -μ: [0.1746546592601148, 0.16457110079543008] -Σ: [0.009900990099009901 0.0; 0.0 0.009900990099009901] -) +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end ``` -Comparing to our variational approximation, this looks pretty good! Worth noting that in this particular case the variational posterior seems to overestimate the variance. -To conclude, let's make a somewhat pretty picture: +Since the support of `x` is constrained to be $$\mathbb{R}_+$$, and inference is best done in the unconstrained space $$\mathbb{R}_+$$, we need to use a *bijector* to match support. +This corresponds to the automatic differentiation VI (ADVI; Kucukelbir *et al.*, 2015). ```julia -julia> using Plots - -julia> p_samples = rand(true_posterior, 10_000); q_samples = rand(q, 10_000); +using Bijectors -julia> p1 = histogram(p_samples[1, :], label="p"); histogram!(q_samples[1, :], alpha=0.7, label="q") - -julia> title!(raw"$\mu_1$") - -julia> p2 = histogram(p_samples[2, :], label="p"); histogram!(q_samples[2, :], alpha=0.7, label="q") - -julia> title!(raw"$\mu_2$") - -julia> plot(p1, p2) +function Bijectors.bijector(model::NormalLogNormal) + (; μ_x, σ_x, μ_y, Σ_y) = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end ``` -![Histogram](hist.png?raw=true) - -### Simple example: using Advanced.jl to directly minimize the KL-divergence between two distributions `p(z)` and `q(z)` -In VI we aim to approximate the true posterior `p(z ∣ x)` by some approximate variational posterior `q(z)` by maximizing the ELBO: - - ELBO(q) = 𝔼_q[log p(x, z) - log q(z)] - -Observe that we can express the ELBO as the negative KL-divergence between `p(x, ⋅)` and `q(⋅)`: - - ELBO(q) = - 𝔼_q[log (q(z) / p(x, z))] - = - KL(q(⋅) || p(x, ⋅)) -So if we apply VI to something that isn't an actual posterior, i.e. there's no data involved and we write `p(z ∣ x) = p(z)`, we're really just minimizing the KL-divergence between the distributions. +A simpler approach is to use `Turing`, where a `Turing.Model` can be automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated. 
-Therefore, we can try out `AdvancedVI.jl` real quick by applying using the interface to minimize the KL-divergence between two distributions: - -```julia -julia> using Distributions, DistributionsAD, AdvancedVI - -julia> # Target distribution - p = MvNormal(ones(2)) -ZeroMeanDiagNormal( -dim: 2 -μ: [0.0, 0.0] -Σ: [1.0 0.0; 0.0 1.0] -) - -julia> logπ(z) = logpdf(p, z) -logπ (generic function with 1 method) - -julia> # Make a choice of VI algorithm - advi = ADVI(10, 1000) -ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 1000) -``` -Now there are two different ways of specifying the approximate posterior (and its family); the first is by providing a mapping from parameters to distribution `θ ↦ q(⋅∣θ)`: -```julia -julia> # Using a function z ↦ q(⋅∣z) - getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4])) -getq (generic function with 1 method) - -julia> # Perform VI - q = vi(logπ, advi, getq, randn(4)) -┌ Info: [ADVI] Should only be seen once: optimizer created for θ -└ objectid(θ) = 0x5ddb564423896704 -[ADVI] Optimizing...100% Time: 0:00:01 -TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[-0.012691337868985757, -0.0004442434543332919], σ=[1.0334797673569802, 0.9957355128767893]) -``` -Or we can check the ELBO (which in this case since, as mentioned, doesn't involve data, is the negative KL-divergence): -```julia -julia> AdvancedVI.elbo(advi, q, logπ, 1000) # empirical estimate -0.08031049170093245 -``` -It's worth noting that the actual value of the ELBO doesn't really tell us too much about the quality of fit. In this particular case, because we're *directly* minimizing the KL-divergence, we can only say something useful if we reach 0, in which case we have obtained the true distribution. - -Let's just quickly check the mean-squared error between the `log p(z)` and `log q(z)` for a random set of samples from the target `p`: +Let us instantiate a random normal-log-normal model. ```julia -julia> zs = rand(p, 100); - -julia> mean(abs2, logpdf(q, zs) - logpdf(p, zs)) -0.0014889109427524852 +using LinearAlgebra + +n_dims = 10 +μ_x = randn() +σ_x = exp.(randn()) +μ_y = randn(n_dims) +σ_y = exp.(randn(n_dims)) +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) ``` -That doesn't look too bad! - -## Implementing your own training loop -Sometimes it might be convenient to roll your own training loop rather than using `vi(...)`. Here's some psuedo-code for how one would do that when used together with Turing.jl: +ADVI can be used as follows: ```julia -using Turing, AdvancedVI, DiffResults -using Turing: Variational - -using ProgressMeter - -# Assuming you have an instance of a Turing model (`model`) - -# 1. Create log-joint needed for ELBO evaluation -logπ = Variational.make_logjoint(model) - -# 2. Define objective -variational_objective = Variational.ELBO() - -# 3. Optimizer -optimizer = Variational.DecayedADAGrad() - -# 4. VI-algorithm -alg = ADVI(10, 1000) - -# 5. Variational distribution -function getq(θ) - # ... -end - -# 6. [OPTIONAL] Implement convergence criterion -function hasconverged(args...) - # ... -end - -# 7. [OPTIONAL] Implement a callback for tracking stats -function callback(args...) - # ... -end - -# 8. Train -converged = false -step = 1 - -prog = ProgressMeter.Progress(num_steps, 1) - -diff_results = DiffResults.GradientResult(θ_init) - -while (step ≤ num_steps) && !converged - # 1. Compute gradient and objective value; results are stored in `diff_results` - AdvancedVI.grad!(variational_objective, alg, getq, model, diff_results) - - # 2. 
Extract gradient from `diff_result` - ∇ = DiffResults.gradient(diff_result) - - # 3. Apply optimizer, e.g. multiplying by step-size - Δ = apply!(optimizer, θ, ∇) - - # 4. Update parameters - @. θ = θ - Δ - - # 5. Do whatever analysis you want - callback(args...) - - # 6. Update - converged = hasconverged(...) # or something user-defined - step += 1 +using Optimisers +using ADTypes, ForwardDiff +import AdvancedVI as AVI + +b = Bijectors.bijector(model) +b⁻¹ = inverse(b) + +# ADVI objective +objective = AVI.ADVI(model, 10; invbij=b⁻¹) + +# Mean-field Gaussian variational family +d = LogDensityProblems.dimension(model) +μ = randn(d) +L = Diagonal(ones(d)) +q = AVI.VIMeanFieldGaussian(μ, L) + +# Run inference +n_max_iter = 10^4 +q, stats, _ = AVI.optimize( + objective, + q, + n_max_iter; + adbackend = ADTypes.AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +) - ProgressMeter.next!(prog) -end +# Evaluate final ELBO with 10^3 Monte Carlo samples +objective(q; n_samples=10^3) ``` ## References -- Jordan, Michael I., Zoubin Ghahramani, Tommi S. Jaakkola, and Lawrence K. Saul. "An introduction to variational methods for graphical models." Machine learning 37, no. 2 (1999): 183-233. -- Blei, David M., Alp Kucukelbir, and Jon D. McAuliffe. "Variational inference: A review for statisticians." Journal of the American statistical Association 112, no. 518 (2017): 859-877. - Kucukelbir, Alp, Rajesh Ranganath, Andrew Gelman, and David Blei. "Automatic variational inference in Stan." In Advances in Neural Information Processing Systems, pp. 568-576. 2015. -- Salimans, Tim, and David A. Knowles. "Fixed-form variational posterior approximation through stochastic linear regression." Bayesian Analysis 8, no. 4 (2013): 837-882. -- Beal, Matthew James. Variational algorithms for approximate Bayesian inference. 2003. 
diff --git a/docs/Project.toml b/docs/Project.toml
new file mode 100644
index 000000000..568be1b61
--- /dev/null
+++ b/docs/Project.toml
@@ -0,0 +1,17 @@
+[deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
+
+[compat]
+ADTypes = "0.1.6"
+Bijectors = "0.13.6"
+Documenter = "0.26, 0.27"
+LogDensityProblems = "2.1.1"
diff --git a/docs/make.jl b/docs/make.jl
new file mode 100644
index 000000000..5d3716089
--- /dev/null
+++ b/docs/make.jl
@@ -0,0 +1,22 @@
+
+using AdvancedVI
+using Documenter
+
+DocMeta.setdocmeta!(
+    AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true
+)
+
+makedocs(;
+    modules = [AdvancedVI],
+    sitename = "AdvancedVI.jl",
+    repo = "https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}",
+    format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"),
+    pages = ["AdvancedVI" => "index.md",
+             "Getting Started" => "started.md",
+             "ELBO Maximization" => [
+                 "Automatic Differentiation VI" => "advi.md",
+                 "Location Scale Family" => "locscale.md",
+             ]],
+)
+
+deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true)
diff --git a/docs/src/advi.md b/docs/src/advi.md
new file mode 100644
index 000000000..f4fe3715b
--- /dev/null
+++ b/docs/src/advi.md
@@ -0,0 +1,227 @@
+
+# [Automatic Differentiation Variational Inference](@id advi)
+
+## Introduction
+
+The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound (ELBO) between a target posterior distribution ``\pi`` and a variational approximation ``q_{\phi,\lambda}``.
+Maximizing the ADVI objective is equivalent to solving the problem
+
+```math
+    \mathrm{minimize}_{\lambda \in \Lambda}\quad \mathrm{KL}\left(q_{\phi,\lambda}, \pi\right).
+```
+
+The key aspects of the ADVI objective are the following:
+1. The use of the reparameterization gradient estimator.
+2. Automatic matching of the support of the target posterior through "bijectors."
+
+Thanks to Item 2, the user is free to choose any unconstrained variational family, for which
+bijectors will automatically match the potentially constrained support of the target.
+
+In particular, ADVI implicitly forms a variational approximation ``q_{\phi,\lambda}``
+from a reparameterizable distribution ``q_{\lambda}`` and a bijector ``\phi`` such that
+```math
+z \sim q_{\phi,\lambda} \qquad\Leftrightarrow\qquad
+z \stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda}
+```
+ADVI provides a principled way to compute the evidence lower bound for ``q_{\phi,\lambda}``.
+ +That is, + +```math +\begin{aligned} +\mathrm{ADVI}\left(\lambda\right) +&\triangleq +\mathbb{E}_{\eta \sim q_{\lambda}}\left[ + \log \pi\left( \phi^{-1}\left( \eta \right) \right) +\right] ++ \mathbb{H}\left(q_{\lambda}\right) ++ \log \lvert J_{\phi^{-1}}\left(\eta\right) \rvert \\ +&= +\mathbb{E}_{\eta \sim q_{\lambda}}\left[ + \log \pi\left( \phi^{-1}\left( \eta \right) \right) +\right] ++ +\mathbb{E}_{\eta \sim q_{\lambda}}\left[ + - \log q_{\lambda}\left( \eta \right) \lvert J_{\phi}\left(\eta\right) \rvert +\right] \\ +&= +\mathbb{E}_{z \sim q_{\phi,\lambda}}\left[ \log \pi\left(z\right) \right] ++ +\mathbb{H}\left(q_{\phi,\lambda}\right) +\end{aligned} +``` + +The idea of using the reparameterization gradient estimator for variational inference was first +coined by Titsias and Lázaro-Gredilla (2014). +Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by +Fjelde *et al.* (2017). + +## The `ADVI` Objective + +```@docs +ADVI +``` + +## The `StickingTheLanding` Control Variate + +The STL control variate was proposed by Roeder *et al.* (2017). +By slightly modifying the differentiation path, it implicitly forms a control variate of the form of +```math +\begin{aligned} + \mathrm{CV}_{\mathrm{STL}}\left(z\right) + &\triangleq + \nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) \\ + &= + -\nabla_{\lambda} \mathbb{E}_{z \sim q_{\nu}} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) +\end{aligned} +``` +where ``\nu = \lambda`` is set to avoid differentiating through the density of ``q_{\lambda}``. +We can see that this vector-valued function has a mean of zero and is therefore a valid control variate. + +Adding this to the closed-form entropy ELBO estimator yields the STL estimator: +```math +\begin{aligned} + \widehat{\nabla \mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right) + &\triangleq \mathbb{E}_{u \sim \varphi}\left[ + \nabla_{\lambda} \log \pi \left(z_{\lambda}\left(u\right)\right) + - + \nabla_{\lambda} \log q_{\nu} \left(z_{\lambda}\left(u\right)\right) + \right] + \\ + &= + \mathbb{E}\left[ \nabla_{\lambda} \log \pi\left(z_{\lambda}\left(u\right)\right) \right] + + + \nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) + - + \mathrm{CV}_{\mathrm{STL}}\left(z\right) + \\ + &= + \widehat{\nabla \mathrm{ELBO}}\left(\lambda\right) + - + \mathrm{CV}_{\mathrm{STL}}\left(z\right), +\end{aligned} +``` +which has the same expectation as the original ADVI estimator, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``. +The conditions for which the STL estimator results in lower variance is still an active subject for research. + +The main downside of the STL estimator is that it needs to evaluate and differentiate the log density of ``q_{\lambda}`` in every iteration. +Depending on the variational family, this might be computationally inefficient or even numerically unstable. +For example, if ``q_{\lambda}`` is a Gaussian with a full-rank covariance, a back-substitution must be performed at every step, making the per-iteration complexity ``\mathcal{O}(d^3)`` and reducing numerical stability. 
+ + +The STL control variate can be used by changing the entropy estimator using the following object: +```@docs +StickingTheLandingEntropy +``` + +```@setup stl +using LogDensityProblems +using SimpleUnPack +using Bijectors +using LinearAlgebra +using Plots + +using Optimisers +using ADTypes, ForwardDiff +import AdvancedVI as AVI + +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +n_dims = 10 +μ_x = randn() +σ_x = exp.(randn()) +μ_y = randn(n_dims) +σ_y = exp.(randn(n_dims)) +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); + +d = LogDensityProblems.dimension(model); +μ = randn(d); +L = Diagonal(ones(d)); +q0 = AVI.VIMeanFieldGaussian(μ, L) + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end +``` + +Let us come back to the example in [Getting Started](@ref getting_started), where a `LogDensityProblem` is given as `model`. +In this example, the true posterior is contained within the variational family. +This setting is known as "perfect variational family specification." +In this case, the STL estimator is able to converge exponentially fast to the true solution. + +Recall that the original ADVI objective with a closed-form entropy (CFE) is given as follows: +```@example stl +n_montecarlo = 1; +b = Bijectors.bijector(model); +b⁻¹ = inverse(b) + +cfe = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹) +``` +The STL estimator can instead be created as follows: +```@example stl +stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), invbij = b⁻¹); +``` + +```@setup stl +n_max_iter = 10^4 + +_, stats_cfe, _ = AVI.optimize( + cfe, + q0, + n_max_iter; + show_progress = false, + adbackend = AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +); + +_, stats_stl, _ = AVI.optimize( + stl, + q0, + n_max_iter; + show_progress = false, + adbackend = AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +); + +t = [stat.iteration for stat ∈ stats_cfe] +y_cfe = [stat.elbo for stat ∈ stats_cfe] +y_stl = [stat.elbo for stat ∈ stats_stl] +plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10)) +plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10)) +savefig("advi_stl_elbo.svg") +nothing +``` +![](advi_stl_elbo.svg) + +We can see that the noise of the STL estimator becomes smaller as VI converges. +However, the speed of convergence may not always be significantly different. + +## References +1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. +2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. +3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604. +4. Fjelde, T. 
E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020, February). Bijectors.jl: Flexible transformations for probability distributions. In Symposium on Advances in Approximate Bayesian Inference (pp. 1-17). PMLR.
+5. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30.
+
+
diff --git a/docs/src/index.md b/docs/src/index.md
new file mode 100644
index 000000000..dea6d405d
--- /dev/null
+++ b/docs/src/index.md
@@ -0,0 +1,14 @@
+```@meta
+CurrentModule = AdvancedVI
+```
+
+# AdvancedVI
+
+## Introduction
+[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms.
+VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness.
+`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem.
+
+## Provided Algorithms
+`AdvancedVI` currently provides the following algorithm for evidence lower bound maximization:
+- [Automatic Differentiation Variational Inference](@ref advi)
diff --git a/docs/src/locscale.md b/docs/src/locscale.md
new file mode 100644
index 000000000..a5966f44b
--- /dev/null
+++ b/docs/src/locscale.md
@@ -0,0 +1,85 @@
+
+# [Location-Scale Variational Family](@id locscale)
+
+## Introduction
+The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions whose sampling process can be represented as
+```math
+z \sim q_{\lambda} \qquad\Leftrightarrow\qquad
+z \stackrel{d}{=} C u + m;\quad u \sim \varphi
+```
+where ``C`` is the *scale*, ``m`` is the *location*, and ``\varphi`` is the *base distribution*.
+``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``.
+The location-scale family encompasses many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``.
+
+The probability density is given by
+```math
+    q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m))
+```
+and the entropy is given by
+```math
+    \mathcal{H}(q_{\lambda}) = \mathcal{H}(\varphi) + \log |C|,
+```
+where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution.
+Notice that ``\mathcal{H}(\varphi)`` does not depend on the variational parameters ``\lambda``.
+The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution.
+
+## Constructors
+
+!!! note
+    For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned.
+    Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities.
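+As a quick sanity check of the entropy identity above, the following standalone snippet compares it against the closed-form entropy of a multivariate Gaussian. It is a minimal sketch using only Distributions.jl and LinearAlgebra (not the `AdvancedVI` constructors), assuming a standard normal base distribution; the scale values are arbitrary.
+```julia
+using Distributions, LinearAlgebra
+
+# Check H(q) = H(φ) + log|C| for a Gaussian location-scale family,
+# where φ is a d-dimensional standard normal base.
+d = 3
+m = randn(d)
+C = LowerTriangular(diagm(0 => [0.5, 1.0, 1.5]))  # well-conditioned diagonal scale
+
+H_locscale = d*entropy(Normal()) + first(logabsdet(C))
+H_gaussian = entropy(MvNormal(m, C*C'))
+H_locscale ≈ H_gaussian  # holds up to floating-point error
+```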
+
+```@docs
+VILocationScale
+```
+
+```@docs
+VIFullRankGaussian
+VIMeanFieldGaussian
+```
+
+## Gaussian Variational Families
+
+Gaussian variational family:
+```julia
+using AdvancedVI, LinearAlgebra, Distributions;
+μ = zeros(2);
+
+L = diagm(ones(2)) |> LowerTriangular;
+q = VIFullRankGaussian(μ, L)
+
+L = ones(2) |> Diagonal;
+q = VIMeanFieldGaussian(μ, L)
+```
+
+## Non-Gaussian Variational Families
+Student-t Variational Family:
+
+```julia
+using AdvancedVI, LinearAlgebra, Distributions;
+μ = zeros(2);
+ν = 3;
+
+# Full-Rank
+L = diagm(ones(2)) |> LowerTriangular;
+q = VILocationScale(μ, L, TDist(ν))
+
+# Mean-Field
+L = ones(2) |> Diagonal;
+q = VILocationScale(μ, L, TDist(ν))
+```
+
+Multivariate Laplace family:
+```julia
+using AdvancedVI, LinearAlgebra, Distributions;
+μ = zeros(2);
+
+# Full-Rank
+L = diagm(ones(2)) |> LowerTriangular;
+q = VILocationScale(μ, L, Laplace())
+
+# Mean-Field
+L = ones(2) |> Diagonal;
+q = VILocationScale(μ, L, Laplace())
+```
+
diff --git a/docs/src/started.md b/docs/src/started.md
new file mode 100644
index 000000000..e3e78c359
--- /dev/null
+++ b/docs/src/started.md
@@ -0,0 +1,132 @@
+
+# [Getting Started with `AdvancedVI`](@id getting_started)
+
+## General Usage
+Each VI algorithm provides the following:
+1. The variational families it supports.
+2. A variational objective corresponding to the algorithm.
+Note that each variational family is subject to its own constraints.
+Thus, please refer to the documentation of the variational inference algorithm of interest.
+
+To use `AdvancedVI`, a user needs to select a variational family and a variational objective, and feed them into `optimize`.
+
+```@docs
+optimize
+```
+
+## `ADVI` Example
+In this tutorial, we will work with a `normal-log-normal` model.
+```math
+\begin{aligned}
+x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\
+y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right)
+\end{aligned}
+```
+ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly.
+
+Using the `LogDensityProblems` interface, the model can be defined as follows:
+```@example advi
+using LogDensityProblems
+using SimpleUnPack
+
+struct NormalLogNormal{MX,SX,MY,SY}
+    μ_x::MX
+    σ_x::SX
+    μ_y::MY
+    Σ_y::SY
+end
+
+function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
+end
+
+function LogDensityProblems.dimension(model::NormalLogNormal)
+    length(model.μ_y) + 1
+end
+
+function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
+    LogDensityProblems.LogDensityOrder{0}()
+end
+```
+Let's now instantiate the model:
+```@example advi
+using LinearAlgebra
+
+n_dims = 10
+μ_x = randn()
+σ_x = exp.(randn())
+μ_y = randn(n_dims)
+σ_y = exp.(randn(n_dims))
+model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
+```
+
+Since `x` follows a log-normal prior, its support is bounded to the positive half-line ``\mathbb{R}_+``.
+Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation.
+```@example advi
+using Bijectors
+
+function Bijectors.bijector(model::NormalLogNormal)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    Bijectors.Stacked(
+        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
+        [1:1, 2:1+length(μ_y)])
+end
+
+b = Bijectors.bijector(model);
+b⁻¹ = inverse(b)
+```
+
+Let's now load `AdvancedVI`.
+ADVI relies on automatic differentiation (AD), hence the "AD" in "ADVI". We therefore need to load an AD library *before* loading `AdvancedVI`.
+Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
+Here, we will use `ForwardDiff`, which is selected later by passing `ADTypes.AutoForwardDiff()`.
+```@example advi
+using Optimisers
+using ADTypes, ForwardDiff
+import AdvancedVI as AVI
+```
+We now need to select 1. a variational objective, and 2. a variational family.
+Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.
+```@example advi
+n_montecarlo = 10;
+objective = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹)
+```
+For the variational family, we will use the classic mean-field Gaussian family.
+```@example advi
+d = LogDensityProblems.dimension(model);
+μ = randn(d);
+L = Diagonal(ones(d));
+q = AVI.VIMeanFieldGaussian(μ, L)
+```
+Passing `objective` and the initial variational approximation `q` to `optimize` performs inference.
+```@example advi
+n_max_iter = 10^4
+q, stats, _ = AVI.optimize(
+    objective,
+    q,
+    n_max_iter;
+    adbackend = AutoForwardDiff(),
+    optimizer = Optimisers.Adam(1e-3)
+);
+```
+
+The selected inference procedure stores per-iteration statistics into `stats`.
+For instance, the ELBO can be plotted as follows:
+```@example advi
+using Plots
+
+t = [stat.iteration for stat ∈ stats]
+y = [stat.elbo for stat ∈ stats]
+plot(t, y, label="ADVI", xlabel="Iteration", ylabel="ELBO")
+savefig("advi_example_elbo.svg")
+nothing
+```
+![](advi_example_elbo.svg)
+
+Further information can be gathered by defining your own `callback!`.
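+For illustration only, a sketch of such a callback is given below. The keyword arguments shown (`stat`, `restructure`, `λ`) and the convention that the returned `NamedTuple` is merged into the per-iteration statistics are assumptions made for this example; consult the `optimize` docstring above for the exact interface.
+```julia
+using Statistics: mean
+using LinearAlgebra: norm
+
+# Hypothetical callback sketch: the keyword arguments and the return
+# convention below are assumptions, not a confirmed interface.
+function callback!(; stat, restructure, λ, kwargs...)
+    q_current = restructure(λ)              # rebuild the current variational approximation
+    (mean_norm = norm(mean(q_current)),)    # extra statistics to record for this iteration
+end
+```
+It would then be passed to `optimize` through its callback keyword argument (assumed here to be named `callback!`).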
+ +The final ELBO can be estimated by calling the objective directly with a different number of Monte Carlo samples as follows: +```@example advi +ELBO = objective(q; n_samples=10^4) +``` diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl new file mode 100644 index 000000000..8333299f0 --- /dev/null +++ b/ext/AdvancedVIEnzymeExt.jl @@ -0,0 +1,26 @@ + +module AdvancedVIEnzymeExt + +if isdefined(Base, :get_extension) + using Enzyme + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults +else + using ..Enzyme + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults +end + +# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916) +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + y = f(θ) + DiffResults.value!(out, y) + ∇θ = DiffResults.gradient(out) + fill!(∇θ, zero(T)) + Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)) + return out +end + +end diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl new file mode 100644 index 000000000..5949bdf81 --- /dev/null +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -0,0 +1,29 @@ + +module AdvancedVIForwardDiffExt + +if isdefined(Base, :get_extension) + using ForwardDiff + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults +else + using ..ForwardDiff + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults +end + +getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize + +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + chunk_size = getchunksize(ad) + config = if isnothing(chunk_size) + ForwardDiff.GradientConfig(f, θ) + else + ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) + end + ForwardDiff.gradient!(out, f, θ, config) + return out +end + +end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl new file mode 100644 index 000000000..520cd9ff1 --- /dev/null +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -0,0 +1,23 @@ + +module AdvancedVIReverseDiffExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults + using ReverseDiff +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..ReverseDiff +end + +# ReverseDiff without compiled tape +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) + tp = ReverseDiff.GradientTape(f, θ) + ReverseDiff.gradient!(out, tp, θ) + return out +end + +end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl new file mode 100644 index 000000000..7b8f8817a --- /dev/null +++ b/ext/AdvancedVIZygoteExt.jl @@ -0,0 +1,24 @@ + +module AdvancedVIZygoteExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults + using Zygote +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..Zygote +end + +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) + y, back = Zygote.pullback(f, θ) + ∇θ = back(one(y)) + DiffResults.value!(out, y) + DiffResults.gradient!(out, only(∇θ)) + return out +end + +end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index e203a13ca..7272303a8 100644 --- a/src/AdvancedVI.jl +++ 
b/src/AdvancedVI.jl @@ -1,270 +1,113 @@ -module AdvancedVI - -using Random: Random - -using Distributions, DistributionsAD, Bijectors -using DocStringExtensions - -using ProgressMeter, LinearAlgebra - -using ForwardDiff -using Tracker - -const PROGRESS = Ref(true) -function turnprogress(switch::Bool) - @info("[AdvancedVI]: global PROGRESS is set as $switch") - PROGRESS[] = switch -end - -const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) - -include("ad.jl") -include("utils.jl") - -using Requires -function __init__() - @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin - apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ) - Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ) - Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ) - end - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("compat/zygote.jl") - export ZygoteAD - function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ZygoteAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... - ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - y, back = Zygote.pullback(f, θ) - dy = first(back(1.0)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, dy) - return out - end - end - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("compat/reversediff.jl") - export ReverseDiffAD +module AdvancedVI - function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... - ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - tp = AdvancedVI.tape(f, θ) - ReverseDiff.gradient!(out, tp, θ) - return out - end - end - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("compat/enzyme.jl") - export EnzymeAD +using SimpleUnPack: @unpack, @pack! +using Accessors - function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.EnzymeAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... - ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) 
- end - # Use `Enzyme.ReverseWithPrimal` once it is released: - # https://github.com/EnzymeAD/Enzyme.jl/pull/598 - y = f(θ) - DiffResults.value!(out, y) - dy = DiffResults.gradient(out) - fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy)) - return out - end - end -end - -export - vi, - ADVI, - ELBO, - elbo, - TruncatedADAGrad, - DecayedADAGrad, - VariationalInference +using Random: AbstractRNG, default_rng +using Distributions +import Distributions: + logpdf, _logpdf, rand, rand!, _rand!, + ContinuousMultivariateDistribution -abstract type VariationalInference{AD} end +using Functors +using Optimisers -getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD) -getADtype(::VariationalInference{AD}) where AD = AD +using DocStringExtensions -abstract type VariationalObjective end +using ProgressMeter +using LinearAlgebra +using LinearAlgebra: AbstractTriangular -const VariationalPosterior = Distribution{Multivariate, Continuous} +using LogDensityProblems +using ADTypes, DiffResults +using ADTypes: AbstractADType +using ChainRulesCore: @ignore_derivatives -""" - grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...) +using FillArrays +using Bijectors -Computes the gradients used in `optimize!`. Default implementation is provided for -`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. -This implicitly also gives a default implementation of `optimize!`. +using StatsBase +using StatsBase: entropy -Variance reduction techniques, e.g. control variates, should be implemented in this function. +# derivatives """ -function grad! end - + value_and_gradient!( + ad::ADTypes.AbstractADType, + f, + θ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult + ) + +Compute the value and gradient of a function `f` at `θ` using the automatic +differentiation backend `ad`. The result is stored in `out`. +The function `f` must return a scalar value. The gradient is stored in `out` as a +vector of the same length as `θ`. """ - vi(model, alg::VariationalInference) - vi(model, alg::VariationalInference, q::VariationalPosterior) - vi(model, alg::VariationalInference, getq::Function, θ::AbstractArray) +function value_and_gradient! end -Constructs the variational posterior from the `model` and performs the optimization -following the configuration of the given `VariationalInference` instance. +# estimators +abstract type AbstractVariationalObjective end -# Arguments -- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations -- `alg`: the VI algorithm used -- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. -- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` -- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior -""" -function vi end +function init end +function estimate_gradient end -function update end +# ADVI-specific interfaces +abstract type AbstractEntropyEstimator end -# default implementations -function grad!( - vo, - alg::VariationalInference{<:ForwardDiffAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... -) - f(θ_) = if (q isa Distribution) - - vo(alg, update(q, θ_), model, args...) - else - - vo(alg, q(θ_), model, args...) 
- end +# entropy.jl must preceed advi.jl +include("objectives/elbo/entropy.jl") +include("objectives/elbo/advi.jl") - # Set chunk size and do ForwardMode. - chunk_size = getchunksize(typeof(alg)) - config = if chunk_size == 0 - ForwardDiff.GradientConfig(f, θ) - else - ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) - end - ForwardDiff.gradient!(out, f, θ, config) -end +export + ELBO, + ADVI, + ClosedFormEntropy, + StickingTheLandingEntropy, + MonteCarloEntropy -function grad!( - vo, - alg::VariationalInference{<:TrackerAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... -) - θ_tracked = Tracker.param(θ) - y = if (q isa Distribution) - - vo(alg, update(q, θ_tracked), model, args...) - else - - vo(alg, q(θ_tracked), model, args...) - end - Tracker.back!(y, 1.0) +# Variational Families - DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(θ_tracked)) -end +include("distributions/location_scale.jl") +export + VILocationScale, + VIFullRankGaussian, + VIMeanFieldGaussian -""" - optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) +# Optimization Routine -Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute -the steps. -""" -function optimize!( - vo, - alg::VariationalInference, - q, - model, - θ::AbstractVector{<:Real}; - optimizer = TruncatedADAGrad() -) - # TODO: should we always assume `samples_per_step` and `max_iters` for all algos? - alg_name = alg_str(alg) - samples_per_step = alg.samples_per_step - max_iters = alg.max_iters - - num_params = length(θ) +function optimize end - # TODO: really need a better way to warn the user about potentially - # not using the correct accumulator - if (optimizer isa TruncatedADAGrad) && (θ ∉ keys(optimizer.acc)) - # this message should only occurr once in the optimization process - @info "[$alg_name] Should only be seen once: optimizer created for θ" objectid(θ) - end +include("optimize.jl") - diff_result = DiffResults.GradientResult(θ) +export optimize - i = 0 - prog = if PROGRESS[] - ProgressMeter.Progress(max_iters, 1, "[$alg_name] Optimizing...", 0) - else - 0 - end +include("utils.jl") - # add criterion? A running mean maybe? - time_elapsed = @elapsed while (i < max_iters) # & converged - grad!(vo, alg, q, model, θ, diff_result, samples_per_step) - # apply update rule - Δ = DiffResults.gradient(diff_result) - Δ = apply!(optimizer, θ, Δ) - @. 
θ = θ - Δ - - AdvancedVI.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result) - PROGRESS[] && (ProgressMeter.next!(prog)) +# optional dependencies +if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base + using Requires +end - i += 1 +@static if !isdefined(Base, :get_extension) + function __init__() + @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/AdvancedVIEnzymeExt.jl") + end + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin + include("../ext/AdvancedVIForwardDiffExt.jl") + end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + include("../ext/AdvancedVIReverseDiffExt.jl") + end + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/AdvancedVIZygoteExt.jl") + end end - - return θ end -# objectives -include("objectives.jl") - -# optimisers -include("optimisers.jl") - -# VI algorithms -include("advi.jl") +end -end # module diff --git a/src/ad.jl b/src/ad.jl deleted file mode 100644 index 62e785e1b..000000000 --- a/src/ad.jl +++ /dev/null @@ -1,46 +0,0 @@ -############################## -# Global variables/constants # -############################## -const ADBACKEND = Ref(:forwarddiff) -setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) -function setadbackend(::Val{:forward_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:forward_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend) - setadbackend(Val(:forwarddiff)) -end -function setadbackend(::Val{:forwarddiff}) - ADBACKEND[] = :forwarddiff -end - -function setadbackend(::Val{:reverse_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:reverse_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:tracker)` to use `Tracker` or `AdvancedVI.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend) - setadbackend(Val(:tracker)) -end -function setadbackend(::Val{:tracker}) - ADBACKEND[] = :tracker -end - -const ADSAFE = Ref(false) -function setadsafe(switch::Bool) - @info("[AdvancedVI]: global ADSAFE is set as $switch") - ADSAFE[] = switch -end - -const CHUNKSIZE = Ref(0) # 0 means letting ForwardDiff set it automatically - -function setchunksize(chunk_size::Int) - @info("[AdvancedVI]: AD chunk size is set as $chunk_size") - CHUNKSIZE[] = chunk_size -end - -abstract type ADBackend end -struct ForwardDiffAD{chunk} <: ADBackend end -getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk - -struct TrackerAD <: ADBackend end - -ADBackend() = ADBackend(ADBACKEND[]) -ADBackend(T::Symbol) = ADBackend(Val(T)) - -ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} -ADBackend(::Val{:tracker}) = TrackerAD -ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") diff --git a/src/advi.jl b/src/advi.jl deleted file mode 100644 index 7f9e73460..000000000 --- a/src/advi.jl +++ /dev/null @@ -1,99 +0,0 @@ -using StatsFuns -using DistributionsAD -using Bijectors -using Bijectors: TransformedDistribution - - -""" -$(TYPEDEF) - -Automatic Differentiation Variational Inference (ADVI) with automatic differentiation -backend `AD`. - -# Fields - -$(TYPEDFIELDS) -""" -struct ADVI{AD} <: VariationalInference{AD} - "Number of samples used to estimate the ELBO in each optimization step." - samples_per_step::Int - "Maximum number of gradient steps." 
- max_iters::Int -end - -function ADVI(samples_per_step::Int=1, max_iters::Int=1000) - return ADVI{ADBackend()}(samples_per_step, max_iters) -end - -alg_str(::ADVI) = "ADVI" - -function vi(model, alg::ADVI, q, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - # If `q` is a mean-field approx we use the specialized `update` function - if q isa Distribution - return update(q, θ) - else - # Otherwise we assume it's a mapping θ → q - return q(θ) - end -end - - -function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - - # `model` assumed to be callable z ↦ p(x, z) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - return θ -end - -# WITHOUT updating parameters inside ELBO -function (elbo::ELBO)( - rng::Random.AbstractRNG, - alg::ADVI, - q::VariationalPosterior, - logπ::Function, - num_samples -) - # 𝔼_q(z)[log p(xᵢ, z)] - # = ∫ log p(xᵢ, z) q(z) dz - # = ∫ log p(xᵢ, f(ϕ)) q(f(ϕ)) |det J_f(ϕ)| dϕ (since change of variables) - # = ∫ log p(xᵢ, f(ϕ)) q̃(ϕ) dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] - - # 𝔼_q(z)[log q(z)] - # = ∫ q(f(ϕ)) log (q(f(ϕ))) |det J_f(ϕ)| dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ) [log q(f(ϕ))] - # = 𝔼_q̃(ϕ) [log q̃(ϕ) - log |det J_f(ϕ)|] - # = 𝔼_q̃(ϕ) [log q̃(ϕ)] - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - # = - ℍ(q̃(ϕ)) - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - - # Finally, the ELBO is given by - # ELBO = 𝔼_q(z)[log p(xᵢ, z)] - 𝔼_q(z)[log q(z)] - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] + 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + ℍ(q̃(ϕ)) - - # If f: supp(p(z | x)) → ℝ then - # ELBO = 𝔼[log p(x, z) - log q(z)] - # = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃)) - # = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃)) - - # But our `rand_and_logjac(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac` - z, logjac = rand_and_logjac(rng, q) - res = (logπ(z) + logjac) / num_samples - - if q isa TransformedDistribution - res += entropy(q.dist) - else - res += entropy(q) - end - - for i = 2:num_samples - z, logjac = rand_and_logjac(rng, q) - res += (logπ(z) + logjac) / num_samples - end - - return res -end diff --git a/src/compat/enzyme.jl b/src/compat/enzyme.jl deleted file mode 100644 index c6bb9ac39..000000000 --- a/src/compat/enzyme.jl +++ /dev/null @@ -1,5 +0,0 @@ -struct EnzymeAD <: ADBackend end -ADBackend(::Val{:enzyme}) = EnzymeAD -function setadbackend(::Val{:enzyme}) - ADBACKEND[] = :enzyme -end diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl deleted file mode 100644 index 721d03618..000000000 --- a/src/compat/reversediff.jl +++ /dev/null @@ -1,16 +0,0 @@ -using .ReverseDiff: compile, GradientTape -using .ReverseDiff.DiffResults: GradientResult - -struct ReverseDiffAD{cache} <: ADBackend end -const RDCache = Ref(false) -setcache(b::Bool) = RDCache[] = b -getcache() = RDCache[] -ADBackend(::Val{:reversediff}) = ReverseDiffAD{getcache()} -function setadbackend(::Val{:reversediff}) - ADBACKEND[] = :reversediff -end - -tape(f, x) = GradientTape(f, x) -function taperesult(f, x) - return tape(f, x), GradientResult(x) -end diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl deleted file mode 100644 index 40022e215..000000000 --- a/src/compat/zygote.jl +++ /dev/null @@ -1,5 +0,0 @@ -struct ZygoteAD <: ADBackend end -ADBackend(::Val{:zygote}) = ZygoteAD -function setadbackend(::Val{:zygote}) - ADBACKEND[] = :zygote -end diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl new 
file mode 100644 index 000000000..c290b81a8 --- /dev/null +++ b/src/distributions/location_scale.jl @@ -0,0 +1,163 @@ + +""" + VILocationScale(location, scale, dist) <: ContinuousMultivariateDistribution + +The location scale variational family broadly represents various variational +families using `location` and `scale` variational parameters. + +It generally represents any distribution for which the sampling path can be +represented as follows: +```julia + d = length(location) + u = rand(dist, d) + z = scale*u + location +``` +""" +struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution + location::L + scale ::S + dist ::D + + function VILocationScale(location::AbstractVector{<:Real}, + scale ::Union{<:AbstractTriangular{<:Real}, <:Diagonal{<:Real}}, + dist ::ContinuousUnivariateDistribution) + # Restricting all the arguments to have the same types creates problems + # with dual-variable-based AD frameworks. + @assert (length(location) == size(scale,1)) && (length(location) == size(scale,2)) + new{typeof(location), typeof(scale), typeof(dist)}(location, scale, dist) + end +end + +Functors.@functor VILocationScale (location, scale) + +# Specialization of `Optimisers.destructure` for mean-field location-scale families. +# These are necessary because we only want to extract the diagonal elements of +# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD +# is very inefficient. +# begin +struct RestructureMeanField{L, S<:Diagonal, D} + q::VILocationScale{L, S, D} +end + +function (re::RestructureMeanField)(flat::AbstractVector) + n_dims = div(length(flat), 2) + location = first(flat, n_dims) + scale = Diagonal(last(flat, n_dims)) + VILocationScale(location, scale, re.q.dist) +end + +function Optimisers.destructure( + q::VILocationScale{L, <:Diagonal, D} +) where {L, D} + @unpack location, scale, dist = q + flat = vcat(location, diag(scale)) + n_dims = length(location) + flat, RestructureMeanField(q) +end +# end + +Base.length(q::VILocationScale) = length(q.location) + +Base.size(q::VILocationScale) = size(q.location) + +Base.eltype(::Type{<:VILocationScale{L, S, D}}) where {L, S, D} = eltype(D) + +function StatsBase.entropy(q::VILocationScale) + @unpack location, scale, dist = q + n_dims = length(location) + n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale)) +end + +function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) + @unpack location, scale, dist = q + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) +end + +function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) + @unpack location, scale, dist = q + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) +end + +function rand(q::VILocationScale) + @unpack location, scale, dist = q + n_dims = length(location) + scale*rand(dist, n_dims) + location +end + +function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) + @unpack location, scale, dist = q + n_dims = length(location) + scale*rand(rng, dist, n_dims, num_samples) .+ location +end + +# This specialization improves AD performance of the sampling path +function rand( + rng::AbstractRNG, q::VILocationScale{L, <:Diagonal, D}, num_samples::Int +) where {L, D} + @unpack location, scale, dist = q + n_dims = length(location) + scale_diag = diag(scale) + scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location +end + +function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) + @unpack location, scale, dist = q + 
rand!(rng, dist, x) + x .= scale*x + return x .+= location +end + +function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) + @unpack location, scale, dist = q + rand!(rng, dist, x) + x[:] = scale*x + return x .+= location +end + +Distributions.mean(q::VILocationScale) = q.location + +function Distributions.var(q::VILocationScale) + C = q.scale + Diagonal(C*C') +end + +function Distributions.cov(q::VILocationScale) + C = q.scale + Hermitian(C*C') +end + +""" + VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}; check_args = true) + +This constructs a multivariate Gaussian distribution with a full-rank covariance matrix. +""" +function VIFullRankGaussian( + μ::AbstractVector{T}, + L::AbstractTriangular{T}; + check_args::Bool = true +) where {T <: Real} + @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite" + if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) + @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." + end + q_base = Normal{T}(zero(T), one(T)) + VILocationScale(μ, L, q_base) +end + +""" + VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}; check_args = true) + +This constructs a multivariate Gaussian distribution with a diagonal covariance matrix. +""" +function VIMeanFieldGaussian( + μ::AbstractVector{T}, + L::Diagonal{T}; + check_args::Bool = true +) where {T <: Real} + @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite" + if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) + @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." + end + q_base = Normal{T}(zero(T), one(T)) + VILocationScale(μ, L, q_base) +end diff --git a/src/objectives.jl b/src/objectives.jl deleted file mode 100644 index 5a6b61b0c..000000000 --- a/src/objectives.jl +++ /dev/null @@ -1,7 +0,0 @@ -struct ELBO <: VariationalObjective end - -function (elbo::ELBO)(alg, q, logπ, num_samples; kwargs...) - return elbo(Random.default_rng(), alg, q, logπ, num_samples; kwargs...) -end - -const elbo = ELBO() diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl new file mode 100644 index 000000000..354c822c7 --- /dev/null +++ b/src/objectives/elbo/advi.jl @@ -0,0 +1,107 @@ + +""" + ADVI(prob, n_samples; kwargs...) + +Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective. + +# Arguments +- `prob`: An object that implements the order `K == 0` `LogDensityProblems` interface. +- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. (Type `<: Int`.) + +# Keyword Arguments +- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`.) +- `invbij`: A bijective transformation mapping the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.) + +# Requirements +- ``q_{\\lambda}`` implements `rand`. +- `logdensity(prob)` must be differentiable by the selected AD backend. + +Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
+""" +struct ADVI{P, B, EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective + prob ::P + invbij ::B + entropy ::EntropyEst + n_samples::Int + + function ADVI(prob, + n_samples::Int; + entropy ::AbstractEntropyEstimator = ClosedFormEntropy(), + invbij = Bijectors.identity) + cap = LogDensityProblems.capabilities(prob) + if cap === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + new{typeof(prob), typeof(invbij), typeof(entropy)}( + prob, invbij, entropy, n_samples + ) + end +end + +Base.show(io::IO, advi::ADVI) = + print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))") + +init(rng::AbstractRNG, advi::ADVI, λ::AbstractVector, restructure) = nothing + +function (advi::ADVI)( + rng::AbstractRNG, + q_η::ContinuousMultivariateDistribution, + ηs ::AbstractMatrix +) + 𝔼ℓ = mean(eachcol(ηs)) do ηᵢ + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ) + LogDensityProblems.logdensity(advi.prob, zᵢ) + logdetjacᵢ + end + ℍ = advi.entropy(q_η, ηs) + 𝔼ℓ + ℍ +end + +""" + (advi::ADVI)( + q_η::ContinuousMultivariateDistribution; + rng::AbstractRNG = Random.default_rng(), + n_samples::Int = advi.n_samples + ) + +Evaluate the ELBO using the ADVI formulation. + +# Arguments +- `q_η`: Variational approximation before applying a bijector (unconstrained support). +- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. + +""" +function (advi::ADVI)( + q_η ::ContinuousMultivariateDistribution; + rng ::AbstractRNG = default_rng(), + n_samples::Int = advi.n_samples +) + ηs = rand(rng, q_η, n_samples) + advi(rng, q_η, ηs) +end + +function estimate_gradient( + rng ::AbstractRNG, + adbackend ::AbstractADType, + advi ::ADVI, + obj_state, + λ ::AbstractVector{<:Real}, + restructure, + out ::DiffResults.MutableDiffResult +) + f(λ′) = begin + q_η = restructure(λ′) + ηs = rand(rng, q_η, advi.n_samples) + -advi(rng, q_η, ηs) + end + value_and_gradient!(adbackend, f, λ, out) + + nelbo = DiffResults.value(out) + stat = (elbo=-nelbo,) + + out, nothing, stat +end diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl new file mode 100644 index 000000000..97ccda299 --- /dev/null +++ b/src/objectives/elbo/entropy.jl @@ -0,0 +1,31 @@ + +struct ClosedFormEntropy <: AbstractEntropyEstimator end + +function (::ClosedFormEntropy)(q, ::AbstractMatrix) + entropy(q) +end + +struct MonteCarloEntropy <: AbstractEntropyEstimator end + +function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) + mean(eachcol(ηs)) do ηᵢ + -logpdf(q, ηᵢ) + end +end + +""" + StickingTheLandingEntropy() + +The "sticking the landing" entropy estimator. + +# Requirements +- `q` implements `logpdf`. +- `logpdf(q, η)` must be differentiable by the selected AD framework. +""" +struct StickingTheLandingEntropy <: AbstractEntropyEstimator end + +function (::StickingTheLandingEntropy)(q, ηs::AbstractMatrix) + @ignore_derivatives mean(eachcol(ηs)) do ηᵢ + -logpdf(q, ηᵢ) + end +end diff --git a/src/optimisers.jl b/src/optimisers.jl deleted file mode 100644 index 8077f98cb..000000000 --- a/src/optimisers.jl +++ /dev/null @@ -1,94 +0,0 @@ -const ϵ = 1e-8 - -""" - TruncatedADAGrad(η=0.1, τ=1.0, n=100) - -Implements a truncated version of AdaGrad in the sense that only the `n` previous gradient norms are used to compute the scaling rather than *all* previous. It has parameter specific learning rates based on how frequently it is updated. 
- -## Parameters - - η: learning rate - - τ: constant scale factor - - n: number of previous gradient norms to use in the scaling. -``` -## References -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -Parameters don't need tuning. - -[TruncatedADAGrad](https://arxiv.org/abs/1506.03431v2) (Appendix E). -""" -mutable struct TruncatedADAGrad - eta::Float64 - tau::Float64 - n::Int - - iters::IdDict - acc::IdDict -end - -function TruncatedADAGrad(η = 0.1, τ = 1.0, n = 100) - TruncatedADAGrad(η, τ, n, IdDict(), IdDict()) -end - -function apply!(o::TruncatedADAGrad, x, Δ) - T = eltype(Tracker.data(Δ)) - - η = o.eta - τ = o.tau - - g² = get!( - o.acc, - x, - [zeros(T, size(x)) for j = 1:o.n] - )::Array{typeof(Tracker.data(Δ)), 1} - i = get!(o.iters, x, 1)::Int - - # Example: suppose i = 12 and o.n = 10 - idx = mod(i - 1, o.n) + 1 # => idx = 2 - - # set the current - @inbounds @. g²[idx] = Δ^2 # => g²[2] = Δ^2 where Δ is the (o.n + 2)-th Δ - - # TODO: make more efficient and stable - s = sum(g²) - - # increment - o.iters[x] += 1 - - # TODO: increment (but "truncate") - # o.iters[x] = i > o.n ? o.n + mod(i, o.n) : i + 1 - - @. Δ *= η / (τ + sqrt(s) + ϵ) -end - -""" - DecayedADAGrad(η=0.1, pre=1.0, post=0.9) - -Implements a decayed version of AdaGrad. It has parameter specific learning rates based on how frequently it is updated. - -## Parameters - - η: learning rate - - pre: weight of new gradient norm - - post: weight of histroy of gradient norms -``` -## References -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -Parameters don't need tuning. -""" -mutable struct DecayedADAGrad - eta::Float64 - pre::Float64 - post::Float64 - - acc::IdDict -end - -DecayedADAGrad(η = 0.1, pre = 1.0, post = 0.9) = DecayedADAGrad(η, pre, post, IdDict()) - -function apply!(o::DecayedADAGrad, x, Δ) - T = eltype(Tracker.data(Δ)) - - η = o.eta - acc = get!(o.acc, x, fill(T(ϵ), size(x)))::typeof(Tracker.data(x)) - @. acc = o.post * acc + o.pre * Δ^2 - @. Δ *= η / (√acc + ϵ) -end diff --git a/src/optimize.jl b/src/optimize.jl new file mode 100644 index 000000000..ea2fd5a12 --- /dev/null +++ b/src/optimize.jl @@ -0,0 +1,113 @@ + +function pm_next!(pm, stats::NamedTuple) + ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) +end + +""" + optimize( + objective ::AbstractVariationalObjective, + restructure, + λ₀ ::AbstractVector{<:Real}, + n_max_iter ::Int, + objargs...; + kwargs... + ) + +Optimize the variational objective `objective` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `λ₀` to the function `restructure`. + + optimize( + objective ::AbstractVariationalObjective, + q, + n_max_iter::Int, + objargs...; + kwargs... + ) + +Optimize the variational objective `objective` by estimating (stochastic) gradients, where the initial variational approximation `q` supports the `Optimisers.destructure` interface. + +# Arguments +- `objective`: Variational objective. +- `λ₀`: Initial value of the variational parameters. +- `restructure`: Function that reconstructs the variational approximation from the flattened parameters. +- `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`. +- `n_max_iter`: Maximum number of iterations. +- `objargs...`: Arguments to be passed to `objective`. +- `kwargs...`: Additional keyword arguments. (See below.)
+ +# Keyword Arguments +- `adbackend`: Automatic differentiation backend. (Type: `<: ADTypes.AbstractADType`.) +- `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Optimisers.Adam()`.) +- `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) +- `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) +- `callback!`: Callback function called after every iteration. The signature is `callback!(; stat, restructure, λ, g)`; it may return a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`, and `g` is the stochastic estimate of the gradient. (Default: `nothing`.) +- `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=show_progress)`.) +- `state`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) (Type: `<: NamedTuple`.) + +# Returns +- `λ`: Variational parameters optimizing the variational objective. +- `logstats`: Statistics and logs gathered during optimization. +- `state`: Collection of the final internal states of optimization. This can be used later to warm-start from the last iteration of the corresponding run. +""" +function optimize( + objective ::AbstractVariationalObjective, + restructure, + λ₀ ::AbstractVector{<:Real}, + n_max_iter ::Int, + objargs...; + adbackend ::AbstractADType, + optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), + rng ::AbstractRNG = default_rng(), + show_progress::Bool = true, + state ::NamedTuple = NamedTuple(), + callback! = nothing, + prog = ProgressMeter.Progress( + n_max_iter; + desc = "Optimizing", + barlen = 31, + showspeed = true, + enabled = show_progress + ) +) + λ = copy(λ₀) + opt_st = haskey(state, :opt) ? state.opt : Optimisers.setup(optimizer, λ) + obj_st = haskey(state, :obj) ? state.obj : init(rng, objective, λ, restructure) + grad_buf = DiffResults.DiffResult(zero(eltype(λ)), similar(λ)) + logstats = NamedTuple[] + + for t = 1:n_max_iter + stat = (iteration=t,) + + grad_buf, obj_st, stat′ = estimate_gradient( + rng, adbackend, objective, obj_st, + λ, restructure, grad_buf; objargs... + ) + stat = merge(stat, stat′) + + g = DiffResults.gradient(grad_buf) + opt_st, λ = Optimisers.update!(opt_st, λ, g) + + if !isnothing(callback!) + stat′ = callback!(; stat, restructure, λ, g) + stat = !isnothing(stat′) ? merge(stat′, stat) : stat + end + + @debug "Iteration $t" stat... + + pm_next!(prog, stat) + push!(logstats, stat) + end + state = (opt=opt_st, obj=obj_st) + logstats = map(identity, logstats) + λ, logstats, state +end + +function optimize(objective ::AbstractVariationalObjective, + q₀, + n_max_iter::Int; + kwargs...) + λ, restructure = Optimisers.destructure(q₀) + λ, logstats, state = optimize( + objective, restructure, λ, n_max_iter; kwargs...
+ ) + restructure(λ), logstats, state +end diff --git a/src/utils.jl b/src/utils.jl index bb4c1f18f..e69de29bb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,15 +0,0 @@ -using Distributions - -using Bijectors: Bijectors - - -function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution) - x = rand(rng, dist) - return x, zero(eltype(x)) -end - -function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution) - x = rand(rng, dist.dist) - y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) - return y, logjac -end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 000000000..663d671dc --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,21 @@ +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Random123 = "74087812-796a-5b5d-8853-05524746bad3" +ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/ad.jl b/test/ad.jl new file mode 100644 index 000000000..f575b485b --- /dev/null +++ b/test/ad.jl @@ -0,0 +1,22 @@ + +using ReTest + +@testset "ad" begin + @testset "$(adname)" for (adname, adsymbol) ∈ Dict( + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + # :Enzyme => AutoEnzyme(), # Currently not tested against. + ) + D = 10 + A = randn(D, D) + λ = randn(D) + grad_buf = DiffResults.GradientResult(λ) + f(λ′) = λ′'*A*λ′ / 2 + AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf) + ∇ = DiffResults.gradient(grad_buf) + f = DiffResults.value(grad_buf) + @test ∇ ≈ (A + A')*λ/2 + @test f ≈ λ'*A*λ / 2 + end +end diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl new file mode 100644 index 000000000..8d8df1e93 --- /dev/null +++ b/test/advi_locscale.jl @@ -0,0 +1,97 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? 
true : false + +using ReTest + +@testset "advi" begin + @testset "locscale" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float64], # Currently only tested against Float64 + (modelname, modelconstr) ∈ Dict( + :NormalLogNormalMeanField => normallognormal_meanfield, + :NormalLogNormalFullRank => normallognormal_fullrank, + ), + (objname, objective) ∈ Dict( + :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹), + :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹, entropy = StickingTheLandingEntropy()), + ), + (adbackname, adbackend) ∈ Dict( + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + # :Zygote => AutoZygote(), + # :Enzyme => AutoEnzyme(), + ) + + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + modelstats = modelconstr(realtype; rng) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + + μ₀ = zeros(realtype, n_dims) + L₀ = if is_meanfield + FillArrays.Eye(n_dims) |> Diagonal + else + FillArrays.Eye(n_dims) |> Matrix |> LowerTriangular + end + + q₀ = if is_meanfield + VIMeanFieldGaussian(μ₀, L₀) + else + VIFullRankGaussian(μ₀, L₀) + end + + obj = objective(model, b⁻¹, 10) + + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats, _ = optimize( + obj, q₀, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + rng = rng, + adbackend = adbackend, + ) + + μ = q.location + L = q.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/T^(1/4) + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = Philox4x(UInt64, seed, 8) + q, stats, _ = optimize( + obj, q₀, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + rng = rng, + adbackend = adbackend, + ) + μ = q.location + L = q.scale + + rng_repl = Philox4x(UInt64, seed, 8) + q, stats, _ = optimize( + obj, q₀, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + rng = rng_repl, + adbackend = adbackend, + ) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end + end + end +end diff --git a/test/distributions.jl b/test/distributions.jl new file mode 100644 index 000000000..9cb158c1e --- /dev/null +++ b/test/distributions.jl @@ -0,0 +1,111 @@ + +using ReTest +using Distributions: _logpdf + +@testset "distributions" begin + @testset "$(string(covtype)) $(basedist) $(realtype)" for + basedist = [:gaussian], + covtype = [:meanfield, :fullrank], + realtype = [Float32, Float64] + + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + n_dims = 10 + n_montecarlo = 1_000_000 + + μ = randn(rng, realtype, n_dims) + L = if covtype == :fullrank + tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular + else + Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1)) + end + Σ = L*L' + + q = if covtype == :fullrank && basedist == :gaussian + VIFullRankGaussian(μ, L) + elseif covtype == :meanfield && basedist == :gaussian + VIMeanFieldGaussian(μ, L) + end + q_true = if basedist == :gaussian + MvNormal(μ, Σ) + end + + @testset "logpdf" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + z = rand(rng, q) + @test eltype(z) == realtype + @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2)
+ @test _logpdf(q, z) ≈ _logpdf(q_true, z) rtol=realtype(1e-2) + @test eltype(logpdf(q, z)) == realtype + @test eltype(_logpdf(q, z)) == realtype + end + + @testset "entropy" begin + @test eltype(entropy(q)) == realtype + @test entropy(q) ≈ entropy(q_true) + end + + @testset "statistics" begin + @testset "mean" begin + @test eltype(mean(q)) == realtype + @test mean(q) == μ + end + @testset "var" begin + @test eltype(var(q)) == realtype + @test var(q) ≈ Diagonal(Σ) + end + @testset "cov" begin + @test eltype(cov(q)) == realtype + @test cov(q) ≈ Σ + end + end + + @testset "sampling" begin + @testset "rand" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + z_samples = mapreduce(x -> rand(rng, q), hcat, 1:n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end + + @testset "rand batch" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + z_samples = rand(rng, q, n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end + + @testset "rand!" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(rng, q, z_samples) + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end + end + end + + @testset "Diagonal destructure" for + n_dims = 10 + μ = zeros(n_dims) + L = ones(n_dims) + q = VIMeanFieldGaussian(μ, L |> Diagonal) + λ, re = Optimisers.destructure(q) + + @test length(λ) == 2*n_dims + @test q == re(λ) + end +end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl new file mode 100644 index 000000000..b8d72cc0f --- /dev/null +++ b/test/models/normallognormal.jl @@ -0,0 +1,62 @@ + +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end + +function normallognormal_fullrank(realtype; rng = default_rng()) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + L_y = tril(I + ones(realtype, n_dims, n_dims))/2 + Σ_y = L_y*L_y' |> Hermitian + + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0))) + + L = Matrix{realtype}(undef, n_dims+1, n_dims+1) |> LowerTriangular + L[1,1] = σ_x + L[2:end,2:end] = L_y + + μ = vcat(μ_x, μ_y) + TestModel(model, μ, L, n_dims+1, false) +end + +function normallognormal_meanfield(realtype; rng = default_rng()) + n_dims = 
5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + + model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) + + μ = vcat(μ_x, μ_y) + L = vcat(σ_x, σ_y) |> Diagonal + + TestModel(model, μ, L, n_dims+1, true) +end diff --git a/test/optimisers.jl b/test/optimisers.jl deleted file mode 100644 index fae652ed0..000000000 --- a/test/optimisers.jl +++ /dev/null @@ -1,17 +0,0 @@ -using Random, Test, LinearAlgebra, ForwardDiff -using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply! - -θ = randn(10, 10) -@testset for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)] - θ_fit = randn(10, 10) - loss(x, θ_) = mean(sum(abs2, θ*x - θ_*x; dims = 1)) - for t = 1:10^4 - x = rand(10) - Δ = ForwardDiff.gradient(θ_ -> loss(x, θ_), θ_fit) - Δ = apply!(opt, θ_fit, Δ) - @. θ_fit = θ_fit - Δ - end - @test loss(rand(10, 100), θ_fit) < 0.01 - @test length(opt.acc) == 1 -end - diff --git a/test/optimize.jl b/test/optimize.jl new file mode 100644 index 000000000..a8013be21 --- /dev/null +++ b/test/optimize.jl @@ -0,0 +1,92 @@ + +using ReTest + +@testset "optimize" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + T = 1000 + modelstats = normallognormal_meanfield(Float64; rng) + + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + # Global Test Configurations + b⁻¹ = Bijectors.bijector(model) |> inverse + μ₀ = zeros(Float64, n_dims) + L₀ = ones(Float64, n_dims) |> Diagonal + q₀ = VIMeanFieldGaussian(μ₀, L₀) + obj = ADVI(model, 10; invbij=b⁻¹) + + adbackend = AutoForwardDiff() + optimizer = Optimisers.Adam(1e-2) + + rng = Philox4x(UInt64, seed, 8) + q_ref, stats_ref, _ = optimize( + obj, q₀, T; + optimizer, + show_progress = false, + rng, + adbackend, + ) + λ_ref, _ = Optimisers.destructure(q_ref) + + @testset "restructure" begin + λ₀, re = Optimisers.destructure(q₀) + + rng = Philox4x(UInt64, seed, 8) + λ, stats, _ = optimize( + obj, re, λ₀, T; + optimizer, + show_progress = false, + rng, + adbackend, + ) + @test λ == λ_ref + @test stats == stats_ref + end + + @testset "callback" begin + rng = Philox4x(UInt64, seed, 8) + test_values = rand(rng, T) + + callback!(; stat, restructure, λ, g) = begin + (test_value = test_values[stat.iteration],) + end + + rng = Philox4x(UInt64, seed, 8) + _, stats, _ = optimize( + obj, q₀, T; + optimizer, + show_progress = false, + rng, + adbackend, + callback! 
+ ) + @test [stat.test_value for stat ∈ stats] == test_values + end + + @testset "warm start" begin + rng = Philox4x(UInt64, seed, 8) + + T_first = div(T,2) + T_last = T - T_first + + q_first, _, state = optimize( + obj, q₀, T_first; + optimizer, + show_progress = false, + rng, + adbackend + ) + + q, stats, _ = optimize( + obj, q_first, T_last; + optimizer, + show_progress = false, + state, + rng, + adbackend + ) + @test q == q_ref + end +end diff --git a/test/runtests.jl b/test/runtests.jl index a305c25e5..127503be2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,28 +1,43 @@ -using Test -using Distributions, DistributionsAD -using AdvancedVI - -include("optimisers.jl") - -target = MvNormal(ones(2)) -logπ(z) = logpdf(target, z) -advi = ADVI(10, 1000) -# Using a function z ↦ q(⋅∣z) -getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4])) -q = vi(logπ, advi, getq, randn(4)) +using ReTest +using ReTest: @testset, @test + +using Comonicon +using Random +using Random123 +using Statistics +using Distributions +using LinearAlgebra +using SimpleUnPack: @unpack +using FillArrays +using PDMats + +using Bijectors +using LogDensityProblems +using Optimisers +using ADTypes +using ForwardDiff, ReverseDiff, Zygote -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 +using AdvancedVI -# OR: implement `update` and pass a `Distribution` -function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real}) - return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[length(q) + 1:end])) +# Models for Inference Tests +struct TestModel{M,L,S} + model::M + μ_true::L + L_true::S + n_dims::Int + is_meanfield::Bool end -q0 = TuringDiagMvNormal(zeros(2), ones(2)) -q = vi(logπ, advi, q0, randn(4)) +include("models/normallognormal.jl") + +# Tests +include("ad.jl") +include("distributions.jl") +include("advi_locscale.jl") +include("optimize.jl") -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 +@main function runtests(patterns...; dry::Bool = false) + retest(patterns...; dry = dry, verbose = Inf) +end
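For context, here is a minimal end-to-end sketch of how the reworked interface introduced above fits together, assuming (as in test/runtests.jl) that `ADVI`, `VIMeanFieldGaussian`, and `optimize` are available after `using AdvancedVI`. The `QuadraticTarget` model below is a hypothetical stand-in for a user-defined `LogDensityProblems` target; only the AdvancedVI calls and keyword arguments are taken from this diff.

```julia
# Minimal usage sketch of the 0.3.0 interface; `QuadraticTarget` is hypothetical.
using AdvancedVI, ADTypes, Optimisers, LinearAlgebra, LogDensityProblems
using ForwardDiff  # loads the AdvancedVIForwardDiffExt extension

struct QuadraticTarget{V}
    μ::V
end
LogDensityProblems.logdensity(p::QuadraticTarget, z) = -sum(abs2, z - p.μ) / 2
LogDensityProblems.dimension(p::QuadraticTarget) = length(p.μ)
LogDensityProblems.capabilities(::Type{<:QuadraticTarget}) =
    LogDensityProblems.LogDensityOrder{0}()

prob = QuadraticTarget(randn(5))
q₀   = VIMeanFieldGaussian(zeros(5), Diagonal(ones(5)))
obj  = ADVI(prob, 10)  # 10 Monte Carlo samples per ELBO/gradient estimate

q, logstats, state = optimize(
    obj, q₀, 1_000;
    adbackend     = AutoForwardDiff(),
    optimizer     = Optimisers.Adam(1e-2),
    show_progress = false,
)
```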
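The returned `state` can warm-start a follow-up run, and `callback!` can attach extra per-iteration statistics, as exercised in test/optimize.jl. A hypothetical continuation of the sketch above:

```julia
# Resume optimization from a previous run and log the squared gradient norm.
callback!(; stat, restructure, λ, g) = (grad_sqnorm = sum(abs2, g),)

q_half, _, state = optimize(
    obj, q₀, 500;
    adbackend = AutoForwardDiff(), optimizer = Optimisers.Adam(1e-2),
    show_progress = false,
)

q_final, logstats, _ = optimize(
    obj, q_half, 500;
    adbackend = AutoForwardDiff(), optimizer = Optimisers.Adam(1e-2),
    show_progress = false,
    state,      # warm-start the optimizer and objective state
    callback!,  # merged into the per-iteration statistics
)
```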
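Likewise, a small hypothetical check of the mean-field `Optimisers.destructure` specialization (`RestructureMeanField`) added in src/distributions/location_scale.jl: for a `Diagonal` scale, the flat parameter vector stacks the location with the scale's diagonal rather than the full matrix, which is what keeps forward-mode AD cheap.

```julia
using AdvancedVI, Optimisers, LinearAlgebra

q = VIMeanFieldGaussian(zeros(3), Diagonal(fill(0.5, 3)))
λ, re = Optimisers.destructure(q)  # λ == vcat(q.location, diag(q.scale))
length(λ)  # 6: location (3) + diagonal of scale (3), not the full 3×3 matrix
re(λ)      # RestructureMeanField rebuilds the location-scale family from λ
```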