diff --git a/.github/workflows/Examples.yml b/.github/workflows/Examples.yml new file mode 100644 index 00000000..f9da45f3 --- /dev/null +++ b/.github/workflows/Examples.yml @@ -0,0 +1,42 @@ +name: NF Examples + +on: + push: + branches: + - main + tags: ['*'] + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + run-examples: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: '1' + arch: x64 + - uses: julia-actions/cache@v2 + - name: Run NF examples + run: | + cd example + julia --project=. --color=yes -e ' + using Pkg; + Pkg.develop(PackageSpec(path=joinpath(pwd(), ".."))); + Pkg.instantiate(); + @info "Running planar flow demo"; + include("demo_planar_flow.jl"); + @info "Running radial flow demo"; + include("demo_radial_flow.jl"); + @info "Running Real NVP demo"; + include("demo_RealNVP.jl"); + @info "Running neural spline flow demo"; + include("demo_neural_spline_flow.jl"); + @info "Running Hamiltonian flow demo"; + include("demo_hamiltonian_flow.jl");' diff --git a/.gitignore b/.gitignore index f7a16479..a7dae251 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ /docs/build/ test/Manifest.toml example/Manifest.toml +example/LocalPreferences.toml # Files generated by invoking Julia with --code-coverage *.jl.cov diff --git a/Project.toml b/Project.toml index db11ea42..f03ebb69 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NormalizingFlows" uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256" -version = "0.2.0" +version = "0.2.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -24,7 +24,7 @@ NormalizingFlowsCUDAExt = "CUDA" ADTypes = "1" Bijectors = "0.12.6, 0.13, 0.14, 0.15" CUDA = "5" -DifferentiationInterface = "0.6.42" +DifferentiationInterface = "0.6, 0.7" Distributions = "0.25" DocStringExtensions = "0.9" Optimisers = "0.2.16, 0.3, 0.4" diff --git a/docs/src/api.md b/docs/src/api.md index eb128863..13fe0c26 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -61,6 +61,11 @@ and hope to generate approximate samples from it. 
```@docs NormalizingFlows.elbo ``` + +```@docs +NormalizingFlows.elbo_batch +``` + #### Log-likelihood By maximizing the log-likelihood, it is equivalent to minimizing the forward KL divergence between $q_\theta$ and $p$, i.e., diff --git a/docs/src/index.md b/docs/src/index.md index dc840761..f685b8b9 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -21,7 +21,7 @@ See the [documentation](https://turinglang.org/NormalizingFlows.jl/dev/) for mor To install the package, run the following command in the Julia REPL: ``` ] # enter Pkg mode -(@v1.9) pkg> add git@github.com:TuringLang/NormalizingFlows.jl.git +(@v1.11) pkg> add git@github.com:TuringLang/NormalizingFlows.jl.git ``` Then simply run the following command to use the package: ```julia diff --git a/example/Project.toml b/example/Project.toml index d462c5ee..0b9b0214 100644 --- a/example/Project.toml +++ b/example/Project.toml @@ -2,16 +2,21 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" + +[extras] +CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" diff --git a/example/README.md b/example/README.md index 27a59bb0..738fff98 100644 --- a/example/README.md +++ b/example/README.md @@ -12,7 +12,7 @@ normalizing flow to approximate the target distribution using `NormalizingFlows. Currently, all examples share the same [Julia project](https://pkgdocs.julialang.org/v1/environments/#Using-someone-else's-project). To run the examples, first activate the project environment: ```julia -# pwd() = "NormalizingFlows.jl/" -using Pkg; Pkg.activate("example"); Pkg.instantiate() +# pwd() = "NormalizingFlows.jl/example" +using Pkg; Pkg.activate("."); Pkg.instantiate() ``` -This will install all needed packages, at the exact versions when the model was last updated. Then you can run the model code with include(".jl"), or by running the example script line-by-line. +This will install all needed packages, at the exact versions when the model was last updated. Then you can run the model code with `include(".jl")`, or by running the example script line-by-line. 
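Usage note (an illustration, not part of the diff itself): with the example environment activated as the updated README describes, any of the demo scripts added later in this change set can be run directly. A minimal session might look like this:

```julia
# pwd() = "NormalizingFlows.jl/example"
using Pkg
Pkg.activate(".")      # activate the shared example project
Pkg.instantiate()      # install the pinned dependencies

# run one of the demos introduced by this change, e.g. the planar flow demo
include("demo_planar_flow.jl")
# include("demo_RealNVP.jl")              # RealNVP with affine coupling layers
# include("demo_neural_spline_flow.jl")   # neural spline flow
```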
diff --git a/example/SyntheticTargets.jl b/example/SyntheticTargets.jl
new file mode 100644
index 00000000..1b5d097a
--- /dev/null
+++ b/example/SyntheticTargets.jl
@@ -0,0 +1,19 @@
+using DocStringExtensions
+using Distributions, Random, LinearAlgebra
+using IrrationalConstants
+using Plots
+
+
+include("targets/banana.jl")
+include("targets/cross.jl")
+include("targets/neal_funnel.jl")
+include("targets/warped_gaussian.jl")
+
+function visualize(p::ContinuousMultivariateDistribution, samples=rand(p, 1000))
+    xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100)
+    yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100)
+    z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange]
+    fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2)
+    scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright)
+    return fig
+end
diff --git a/example/demo_RealNVP.jl b/example/demo_RealNVP.jl
new file mode 100644
index 00000000..516e4c90
--- /dev/null
+++ b/example/demo_RealNVP.jl
@@ -0,0 +1,180 @@
+using Flux
+using Bijectors
+using Bijectors: partition, combine, PartitionMask
+
+using Random, Distributions, LinearAlgebra
+using Functors
+using Optimisers, ADTypes
+using Mooncake
+using NormalizingFlows
+
+include("SyntheticTargets.jl")
+include("utils.jl")
+
+##################################
+# define affine coupling layer using Bijectors.jl interface
+#################################
+struct AffineCoupling <: Bijectors.Bijector
+    dim::Int
+    mask::Bijectors.PartitionMask
+    s::Flux.Chain
+    t::Flux.Chain
+end
+
+# let params track fields s and t
+@functor AffineCoupling (s, t)
+
+function AffineCoupling(
+    dim::Int, # dimension of input
+    hdims::Int, # dimension of hidden units for s and t
+    mask_idx::AbstractVector, # indices of the dimensions to apply the transformation to
+)
+    cdims = length(mask_idx) # dimension of parts used to construct coupling law
+    s = mlp3(cdims, hdims, cdims)
+    t = mlp3(cdims, hdims, cdims)
+    mask = PartitionMask(dim, mask_idx)
+    return AffineCoupling(dim, mask, s, t)
+end
+
+function Bijectors.transform(af::AffineCoupling, x::AbstractVecOrMat)
+    # partition vector using `af.mask::PartitionMask`
+    x₁, x₂, x₃ = partition(af.mask, x)
+    y₁ = x₁ .* af.s(x₂) .+ af.t(x₂)
+    return combine(af.mask, y₁, x₂, x₃)
+end
+
+function (af::AffineCoupling)(x::AbstractArray)
+    return transform(af, x)
+end
+
+function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)
+    x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
+    y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
+    logjac = sum(log ∘ abs, af.s(x_2)) # this is a scalar
+    return combine(af.mask, y_1, x_2, x_3), logjac
+end
+
+function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractMatrix)
+    x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
+    y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
+    logjac = sum(log ∘ abs, af.s(x_2); dims = 1) # 1 × size(x, 2)
+    return combine(af.mask, y_1, x_2, x_3), vec(logjac)
+end
+
+
+function Bijectors.with_logabsdet_jacobian(
+    iaf::Inverse{<:AffineCoupling}, y::AbstractVector
+)
+    af = iaf.orig
+    # partition vector using `af.mask::PartitionMask`
+    y_1, y_2, y_3 = partition(af.mask, y)
+    # inverse transformation
+    x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
+    logjac = -sum(log ∘ abs, af.s(y_2))
+    return combine(af.mask, x_1, y_2, y_3), logjac
+end
+
+function Bijectors.with_logabsdet_jacobian(
+    iaf::Inverse{<:AffineCoupling}, y::AbstractMatrix
+)
+    af = iaf.orig
+    # partition vector using `af.mask::PartitionMask`
+    y_1, y_2, y_3 = partition(af.mask, y)
+    # inverse transformation
+    x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
+    logjac = -sum(log ∘ abs, af.s(y_2); dims = 1)
+    return combine(af.mask, x_1, y_2, y_3), vec(logjac)
+end
+
+###################
+# an equivalent definition of AffineCoupling using Bijectors.Coupling
+# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1)
+###################
+
+# struct AffineCoupling <: Bijectors.Bijector
+#     dim::Int
+#     mask::Bijectors.PartitionMask
+#     s::Flux.Chain
+#     t::Flux.Chain
+# end
+
+# # let params track fields s and t
+# @functor AffineCoupling (s, t)
+
+# function AffineCoupling(dim, mask, s, t)
+#     return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask)
+# end
+
+# function AffineCoupling(
+#     dim::Int, # dimension of input
+#     hdims::Int, # dimension of hidden units for s and t
+#     mask_idx::AbstractVector, # indices of the dimensions to apply the transformation to
+# )
+#     cdims = length(mask_idx) # dimension of parts used to construct coupling law
+#     s = mlp3(cdims, hdims, cdims)
+#     t = mlp3(cdims, hdims, cdims)
+#     mask = PartitionMask(dim, mask_idx)
+#     return AffineCoupling(dim, mask, s, t)
+# end
+
+
+
+##################################
+# start demo
+#################################
+Random.seed!(123)
+rng = Random.default_rng()
+T = Float32
+
+######################################
+# a difficult banana target
+######################################
+target = Banana(2, 1.0f0, 100.0f0)
+logp = Base.Fix1(logpdf, target)
+
+######################################
+# learn the target using Affine coupling flow
+######################################
+@leaf MvNormal
+q0 = MvNormal(zeros(T, 2), ones(T, 2))
+
+d = 2
+hdims = 32
+
+# alternating the coupling layers
+Ls = [AffineCoupling(d, hdims, [1]) ∘ AffineCoupling(d, hdims, [2]) for i in 1:3]
+
+flow = create_flow(Ls, q0)
+flow_untrained = deepcopy(flow)
+
+
+######################################
+# start training
+######################################
+sample_per_iter = 64
+
+# callback function to log training progress
+cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
+adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
+checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
+flow_trained, stats, _ = train_flow(
+    rng,
+    elbo_batch, # using elbo_batch instead of elbo achieves 4-5 times speedup
+    flow,
+    logp,
+    sample_per_iter;
+    max_iters=100, # change to larger number of iterations (e.g., 50_000) for better results
+    optimiser=Optimisers.Adam(5e-4),
+    ADbackend=adtype,
+    show_progress=true,
+    callback=cb,
+    hasconverged=checkconv,
+)
+θ, re = Optimisers.destructure(flow_trained)
+losses = map(x -> x.loss, stats)
+
+######################################
+# evaluate trained flow
+######################################
+plot(losses; label="Loss", linewidth=2) # plot the loss
+compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)
diff --git a/example/demo_hamiltonian_flow.jl b/example/demo_hamiltonian_flow.jl
new file mode 100644
index 00000000..054c581a
--- /dev/null
+++ b/example/demo_hamiltonian_flow.jl
@@ -0,0 +1,179 @@
+using Random, Distributions, LinearAlgebra
+using Functors
+using Optimisers, ADTypes
+using Mooncake
+using Bijectors
+using Bijectors: partition, combine, PartitionMask
+
+using NormalizingFlows
+
+include("SyntheticTargets.jl")
+include("utils.jl")
+ + +""" +Hamiltonian flow interleaves betweem Leapfrog layer and an affine transformation to the momentum variable +It targets the joint distribution π(x, ρ) = π(x) * N(ρ; 0, I) +where x is the target variable and ρ is the momentum variable + +- The optimizable parameters for the leapfrog layers are the step size logϵ +- and we will also optimize the affine transformation parameters (shift and scale) + +Wrap the leapfrog transformation into a Bijectors.jl interface + +# References +[1] Naitong Chen, Zuheng Xu, Trevor Campbell, Bayesian inference via sparse Hamiltonian flows, NeurIPS 2022. +""" +struct LeapFrog{T<:Real} <: Bijectors.Bijector + "dimention of the target space" + dim::Int + "leapfrog step size" + logϵ::AbstractVector{T} + "number of leapfrog steps" + L::Int + "score function of the target distribution" + ∇logp + "mask function to split the input into position and momentum" + mask::PartitionMask +end +@functor LeapFrog (logϵ,) + +function LeapFrog(dim::Int, logϵ::T, L::Int, ∇logp) where {T<:Real} + return LeapFrog(dim, logϵ .* ones(T, dim), L, ∇logp, PartitionMask(2dim, 1:dim)) +end + +_get_stepsize(lf::LeapFrog) = exp.(lf.logϵ) + +""" +run L leapfrog steps with std Gaussian momentum distribution with vector stepsizes +""" +function _leapfrog( + ∇ℓπ, ϵ::AbstractVector{T}, L::Int, x::AbstractVecOrMat{T}, v::AbstractVecOrMat{T} +) where {T<:Real} + v += ϵ/2 .* ∇ℓπ(x) + for _ in 1:L - 1 + x += ϵ .* v + v += ϵ .* ∇ℓπ(x) + end + x += ϵ .* v + v += ϵ/2 .* ∇ℓπ(x) + return x, v +end + +function Bijectors.transform(lf::LeapFrog{T}, z::AbstractVector{T}) where {T<:Real} + (; dim, logϵ, L, ∇logp) = lf + @assert length(z) == 2dim "dimension of input must be even, z = [x, ρ]" + + ϵ = _get_stepsize(lf) + x, ρ, e = partition(lf.mask, z) # split into position and momentum + x_, ρ_ = _leapfrog(∇logp, ϵ, L, x, ρ) # run L learpfrog steps + return combine(lf.mask, x_, ρ_, e) +end + +function Bijectors.transform(ilf::Inverse{<:LeapFrog{T}}, z::AbstractVector{T}) where {T<:Real} + lf = ilf.orig + (; dim, logϵ, L, ∇logp) = lf + @assert length(z) == 2dim "dimension of input must be even, z = [x, ρ]" + + ϵ = _get_stepsize(lf) + x, ρ, e = partition(lf.mask, z) # split into position and momentum + x_, ρ_ = _leapfrog(∇logp, -ϵ, L, x, ρ) # run L learpfrog steps + return combine(lf.mask, x_, ρ_, e) +end + +function Bijectors.with_logabsdet_jacobian(lf::LeapFrog{T}, z::AbstractVector{T}) where {T<:Real} + # leapfrog is symplectic, so the logabsdetjacobian is 0 + return Bijectors.transform(lf, z), zero(eltype(z)) +end +function Bijectors.with_logabsdet_jacobian(ilf::Inverse{<:LeapFrog{T}}, z::AbstractVector{T}) where {T<:Real} + # leapfrog is symplectic, so the logabsdetjacobian is 0 + return Bijectors.transform(ilf, z), zero(eltype(z)) +end + +# shift and scale transformation that only applies to the momentum variable ρ = z[(dim + 1):end] +function momentum_normalization_layer(dims::Int, T::Type{<:Real}) + bx = identity # leave position variable x = z[1:dim] unchanged + bρ = Bijectors.Shift(zeros(T, dims)) ∘ Bijectors.Scale(ones(T, dims)) + b = Bijectors.Stacked((bx, bρ), [1:dims, (dims + 1):(2*dims)]) + return b +end + + +################################## +# start demo +################################# +Random.seed!(123) +rng = Random.default_rng() +T = Float64 # for Hamiltonian VI, its recommended to use Float64 as the dynamic is chaotic + +###################################### +# a Funnel target +###################################### +dims = 2 +target = Funnel(dims, -8.0, 5.0) +# visualize(target) + 
+logp = Base.Fix1(logpdf, target)
+function logp_joint(z::AbstractVector{T}) where {T<:Real}
+    dims = div(length(z), 2)
+    x = @view z[1:dims]
+    ρ = @view z[(dims + 1):end]
+    logp_x = logp(x)
+    logp_ρ = sum(logpdf(Normal(), ρ))
+    return logp_x + logp_ρ
+end
+
+# the score function is the gradient of the logpdf.
+# Among the synthetic targets, the score function is currently only implemented for the Funnel target
+∇logp = Base.Fix1(score, target)
+
+######################################
+# build the flow in the joint space
+######################################
+# mean field Gaussian reference
+@leaf MvNormal
+q0 = transformed(
+    MvNormal(zeros(T, 2dims), ones(T, 2dims)), Bijectors.Shift(zeros(T, 2dims)) ∘ Bijectors.Scale(ones(T, 2dims))
+)
+
+nlfg = 3
+logϵ0 = log(0.05) # initial step size
+
+# the Hamiltonian flow interleaves leapfrog layers with an affine transformation of the momentum variable
+Ls = [
+    momentum_normalization_layer(dims, T) ∘ LeapFrog(dims, logϵ0, nlfg, ∇logp) for _ in 1:15
+]
+
+flow = create_flow(Ls, q0)
+flow_untrained = deepcopy(flow)
+
+
+######################################
+# start training
+######################################
+sample_per_iter = 16
+
+# callback function to log training progress
+cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
+adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
+checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
+flow_trained, stats, _ = train_flow(
+    elbo,
+    flow,
+    logp_joint,
+    sample_per_iter;
+    max_iters=100, # change to larger number of iterations (e.g., 50_000) for better results
+    optimiser=Optimisers.Adam(3e-4),
+    ADbackend=adtype,
+    show_progress=true,
+    callback=cb,
+    hasconverged=checkconv,
)
+θ, re = Optimisers.destructure(flow_trained)
+losses = map(x -> x.loss, stats)
+
+######################################
+# evaluate trained flow
+######################################
+plot(losses; label="Loss", linewidth=2) # plot the loss
+compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)
diff --git a/example/demo_neural_spline_flow.jl b/example/demo_neural_spline_flow.jl
new file mode 100644
index 00000000..ffeba09f
--- /dev/null
+++ b/example/demo_neural_spline_flow.jl
@@ -0,0 +1,172 @@
+using Flux
+using Bijectors
+using Bijectors: partition, combine, PartitionMask
+
+using Random, Distributions, LinearAlgebra
+using Functors
+using Optimisers, ADTypes
+using Mooncake
+using NormalizingFlows
+
+include("SyntheticTargets.jl")
+include("utils.jl")
+
+##################################
+# define neural spline layer using Bijectors.jl interface
+#################################
+"""
+Neural rational quadratic spline layer
+
+# References
+[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
+""" +struct NeuralSplineLayer{T,A<:Flux.Chain} <: Bijectors.Bijector + dim::Int # dimension of input + K::Int # number of knots + n_dims_transferred::Int # number of dimensions that are transformed + nn::A # networks that parmaterize the knots and derivatives + B::T # bound of the knots + mask::Bijectors.PartitionMask +end + +function NeuralSplineLayer( + dim::T1, # dimension of input + hdims::T1, # dimension of hidden units for s and t + K::T1, # number of knots + B::T2, # bound of the knots + mask_idx::AbstractVector{<:Int}, # index of dimensione that one wants to apply transformations on +) where {T1<:Int,T2<:Real} + num_of_transformed_dims = length(mask_idx) + input_dims = dim - num_of_transformed_dims + + # output dim of the NN + output_dims = (3K - 1)*num_of_transformed_dims + # one big mlp that outputs all the knots and derivatives for all the transformed dimensions + nn = mlp3(input_dims, hdims, output_dims) + + mask = Bijectors.PartitionMask(dim, mask_idx) + return NeuralSplineLayer(dim, K, num_of_transformed_dims, nn, B, mask) +end + +@functor NeuralSplineLayer (nn,) + +# define forward and inverse transformation +""" +Build a rational quadratic spline from the nn output +Bijectors.jl has implemented the inverse and logabsdetjac for rational quadratic spline + +we just need to map the nn output to the knots and derivatives of the RQS +""" +function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector) + K, B = nsl.K, nsl.B + nnoutput = reshape(nsl.nn(x), nsl.n_dims_transferred, :) + ws = @view nnoutput[:, 1:K] + hs = @view nnoutput[:, (K + 1):(2K)] + ds = @view nnoutput[:, (2K + 1):(3K - 1)] + return Bijectors.RationalQuadraticSpline(ws, hs, ds, B) +end + +function Bijectors.transform(nsl::NeuralSplineLayer, x::AbstractVector) + x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x) + # instantiate rqs knots and derivatives + rqs = instantiate_rqs(nsl, x_2) + y_1 = Bijectors.transform(rqs, x_1) + return Bijectors.combine(nsl.mask, y_1, x_2, x_3) +end + +function Bijectors.transform(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector) + nsl = insl.orig + y1, y2, y3 = partition(nsl.mask, y) + rqs = instantiate_rqs(nsl, y2) + x1 = Bijectors.transform(Inverse(rqs), y1) + return Bijectors.combine(nsl.mask, x1, y2, y3) +end + +function (nsl::NeuralSplineLayer)(x::AbstractVector) + return Bijectors.transform(nsl, x) +end + +# define logabsdetjac +function Bijectors.logabsdetjac(nsl::NeuralSplineLayer, x::AbstractVector) + x_1, x_2, _ = Bijectors.partition(nsl.mask, x) + rqs = instantiate_rqs(nsl, x_2) + logjac = logabsdetjac(rqs, x_1) + return logjac +end + +function Bijectors.logabsdetjac(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector) + nsl = insl.orig + y1, y2, _ = partition(nsl.mask, y) + rqs = instantiate_rqs(nsl, y2) + logjac = logabsdetjac(Inverse(rqs), y1) + return logjac +end + +function Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineLayer, x::AbstractVector) + x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x) + rqs = instantiate_rqs(nsl, x_2) + y_1, logjac = with_logabsdet_jacobian(rqs, x_1) + return Bijectors.combine(nsl.mask, y_1, x_2, x_3), logjac +end + +################################## +# start demo +################################# +Random.seed!(123) +rng = Random.default_rng() +T = Float32 + +###################################### +# neals funnel target +###################################### +target = Funnel(2, 0.0f0, 9.0f0) +logp = Base.Fix1(logpdf, target) + +###################################### +# learn the target using Affine coupling flow 
+###################################### +@leaf MvNormal +q0 = MvNormal(zeros(T, 2), ones(T, 2)) + +d = 2 +hdims = 64 +K = 10 +B = 30 +Ls = [ + NeuralSplineLayer(d, hdims, K, B, [1]) ∘ NeuralSplineLayer(d, hdims, K, B, [2]) for + i in 1:3 +] + +flow = create_flow(Ls, q0) +flow_untrained = deepcopy(flow) + + +###################################### +# start training +###################################### +sample_per_iter = 64 + +# callback function to log training progress +cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) +adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) +checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 +flow_trained, stats, _ = train_flow( + elbo, + flow, + logp, + sample_per_iter; + max_iters=100, # change to larger number of iterations (e.g., 50_000) for better results + optimiser=Optimisers.Adam(5e-5), + ADbackend=adtype, + show_progress=true, + callback=cb, + hasconverged=checkconv, +) +θ, re = Optimisers.destructure(flow_trained) +losses = map(x -> x.loss, stats) + +###################################### +# evaluate trained flow +###################################### +plot(losses; label="Loss", linewidth=2) # plot the loss +compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000) diff --git a/example/demo_planar_flow.jl b/example/demo_planar_flow.jl new file mode 100644 index 00000000..25a694e5 --- /dev/null +++ b/example/demo_planar_flow.jl @@ -0,0 +1,63 @@ +using Random, Distributions, LinearAlgebra, Bijectors +using Functors +using Optimisers, ADTypes, Mooncake +using NormalizingFlows + +include("SyntheticTargets.jl") +include("utils.jl") + +Random.seed!(123) +rng = Random.default_rng() +T = Float64 + +###################################### +# 2d Banana as the target distribution +###################################### +target = Banana(2, 1.0, 10.0) +logp = Base.Fix1(logpdf, target) + + +###################################### +# setup planar flow +###################################### +function create_planar_flow(n_layers::Int, q₀) + d = length(q₀) + Ls = [PlanarLayer(d) for _ in 1:n_layers] + ts = reduce(∘, Ls) + return transformed(q₀, ts) +end + +@leaf MvNormal +q0 = MvNormal(zeros(T, 2), ones(T, 2)) +flow = create_planar_flow(10, q0) +flow_untrained = deepcopy(flow) + +###################################### +# start training +###################################### +sample_per_iter = 32 + +# callback function to log training progress +cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) +adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) +checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 +flow_trained, stats, _ = train_flow( + elbo_batch, + flow, + logp, + sample_per_iter; + max_iters=10_000, + optimiser=Optimisers.Adam(one(T)/100), + ADbackend=adtype, + show_progress=true, + callback=cb, + hasconverged=checkconv, +) +θ, re = Optimisers.destructure(flow_trained) +losses = map(x -> x.loss, stats) + +###################################### +# evaluate trained flow +###################################### +plot(losses; label="Loss", linewidth=2) # plot the loss +compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000) diff --git a/example/demo_radial_flow.jl b/example/demo_radial_flow.jl new file mode 100644 index 00000000..89b4c561 --- /dev/null +++ b/example/demo_radial_flow.jl @@ -0,0 +1,64 @@ +using Random, Distributions, LinearAlgebra, Bijectors +using Functors +using Optimisers, ADTypes, Mooncake +using 
NormalizingFlows + +include("SyntheticTargets.jl") +include("utils.jl") + +Random.seed!(123) +rng = Random.default_rng() +T = Float64 + +###################################### +# get target logp +###################################### +target = WarpedGauss() +logp = Base.Fix1(logpdf, target) + +###################################### +# setup radial flow +###################################### +function create_radial_flow(n_layers::Int, q₀) + d = length(q₀) + Ls = [RadialLayer(d) for _ in 1:n_layers] + ts = reduce(∘, Ls) + return transformed(q₀, ts) +end + +# create a 10-layer radial flow +@leaf MvNormal +q0 = MvNormal(zeros(T, 2), ones(T, 2)) +flow = create_radial_flow(10, q0) + +flow_untrained = deepcopy(flow) + +###################################### +# start training +###################################### +sample_per_iter = 32 + +# callback function to log training progress +cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) +adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) +checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 +flow_trained, stats, _ = train_flow( + elbo_batch, + flow, + logp, + sample_per_iter; + max_iters=10_000, + optimiser=Optimisers.Adam(one(T)/100), + ADbackend=adtype, + show_progress=true, + callback=cb, + hasconverged=checkconv, +) +θ, re = Optimisers.destructure(flow_trained) +losses = map(x -> x.loss, stats) + +###################################### +# evaluate trained flow +###################################### +plot(losses; label="Loss", linewidth=2) # plot the loss +compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000) diff --git a/example/neural_spline_flow/nsf_layer.jl b/example/neural_spline_flow/nsf_layer.jl deleted file mode 100644 index 6f3b2a7f..00000000 --- a/example/neural_spline_flow/nsf_layer.jl +++ /dev/null @@ -1,107 +0,0 @@ -using Flux -using Functors -using Bijectors -using Bijectors: partition, PartitionMask - -""" -Neural Rational quadratic Spline layer - -# References -[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019). 
-""" -struct NeuralSplineLayer{T1,T2,A<:AbstractVecOrMat{T1}} <: Bijectors.Bijector - dim::Int - mask::Bijectors.PartitionMask - w::A # width - h::A # height - d::A # derivative of the knots - B::T2 # bound of the knots -end - -function MLP_3layer(input_dim::Int, hdims::Int, output_dim::Int; activation=Flux.leakyrelu) - return Chain( - Flux.Dense(input_dim, hdims, activation), - Flux.Dense(hdims, hdims, activation), - Flux.Dense(hdims, output_dim), - ) -end - -function NeuralSplineLayer( - dim::Int, # dimension of input - hdims::Int, # dimension of hidden units for s and t - K::Int, # number of knots - mask_idx::AbstractVector{<:Int}, # index of dimensione that one wants to apply transformations on - B::T2, # bound of the knots -) where {T2<:Real} - num_of_transformed_dims = length(mask_idx) - input_dims = dim - num_of_transformed_dims - w = [MLP_3layer(input_dims, hdims, K) for i in 1:num_of_transformed_dims] - h = [MLP_3layer(input_dims, hdims, K) for i in 1:num_of_transformed_dims] - d = [MLP_3layer(input_dims, hdims, K - 1) for i in 1:num_of_transformed_dims] - mask = Bijectors.PartitionMask(D, mask_idx) - return NeuralSplineLayer(D, mask, w, h, d, B) -end - -@functor NeuralSplineLayer (w, h, d) - -# define forward and inverse transformation -function instantiate_rqs(nsl::NeuralSplineLayer{<:Vector{<:Flux.Chain}}, x::AbstractVector) - # instantiate rqs knots and derivatives - ws = permutedims(reduce(hcat, [w(x) for w in nsl.w])) - hs = permutedims(reduce(hcat, [h(x) for h in nsl.h])) - ds = permutedims(reduce(hcat, [d(x) for d in nsl.d])) - return Bijectors.RationalQuadraticSpline(ws, hs, ds, nsl.B) -end - -function Bijectors.transform( - nsl::NeuralSplineLayer{<:Vector{<:Flux.Chain}}, x::AbstractVector -) - x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x) - # instantiate rqs knots and derivatives - rqs = instantiate_rqs(nsl, x_2) - y_1 = transform(rqs, x_1) - return Bijectors.combine(nsl.mask, y_1, x_2, x_3) -end - -function Bijectors.transform( - insl::Inverse{<:NeuralSplineLayer{<:Vector{<:Flux.Chain}}}, y::AbstractVector -) - nsl = insl.orig - y1, y2, y3 = partition(nsl.mask, y) - rqs = instantiate_rqs(nsl, y2) - x1 = transform(Inverse(rqs), y1) - return combine(nsl.mask, x1, y2, y3) -end - -function (nsl::NeuralSplineLayer{<:Vector{<:Flux.Chain}})(x::AbstractVector) - return Bijectors.transform(nsl, x) -end - -# define logabsdetjac -function Bijectors.logabsdetjac( - nsl::NeuralSplineLayer{<:Vector{<:Flux.Chain}}, x::AbstractVector -) - x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x) - Rqs = instantiate_rqs(nsl, x_2) - logjac = logabsdetjac(Rqs, x_1) - return logjac -end - -function Bijectors.logabsdetjac( - insl::Inverse{<:NeuralSplineLayer{<:Vector{<:Flux.Chain}}}, y::AbstractVector -) - nsl = insl.orig - y1, y2, y3 = partition(nsl.mask, y) - rqs = instantiate_rqs(nsl, y2) - logjac = logabsdetjac(Inverse(rqs), y1) - return logjac -end - -function Bijectors.with_logabsdet_jacobian( - nsl::NeuralSplineLayer{<:Vector{<:Flux.Chain}}, x::AbstractVector -) - x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x) - rqs = instantiate_rqs(nsl, x_2) - y_1, logjac = with_logabsdet_jacobian(rqs, x_1) - return Bijectors.combine(nsl.mask, y_1, x_2, x_3), logjac -end diff --git a/example/planar_radial_flow/planar_flow.jl b/example/planar_radial_flow/planar_flow.jl deleted file mode 100644 index d1ccff81..00000000 --- a/example/planar_radial_flow/planar_flow.jl +++ /dev/null @@ -1,56 +0,0 @@ -using Random, Distributions, LinearAlgebra, Bijectors -using ADTypes -using Optimisers -using FunctionChains 
-using NormalizingFlows -using Zygote -using Flux: f32 -using Plots -include("../common.jl") - -Random.seed!(123) -rng = Random.default_rng() - -###################################### -# 2d Banana as the target distribution -###################################### -include("../targets/banana.jl") - -# create target p -p = Banana(2, 1.0f-1, 100.0f0) -logp = Base.Fix1(logpdf, p) - -###################################### -# learn the target using planar flow -###################################### -function create_planar_flow(n_layers::Int, q₀) - d = length(q₀) - Ls = [f32(PlanarLayer(d)) for _ in 1:n_layers] - ts = fchain(Ls) - return transformed(q₀, ts) -end - -# create a 10-layer planar flow -flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I)) -flow_untrained = deepcopy(flow) - -# train the flow -sample_per_iter = 10 -cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,) -flow_trained, stats, _ = train_flow( - elbo, - flow, - logp, - sample_per_iter; - max_iters=200_00, - optimiser=Optimisers.ADAM(), - callback=cb, - ADbackend=AutoZygote(), -) -losses = map(x -> x.loss, stats) - -###################################### -# evaluate trained flow -###################################### -plot(losses; label="Loss", linewidth=2) # plot the loss -compare_trained_and_untrained_flow(flow_trained, flow_untrained, p, 1000) diff --git a/example/planar_radial_flow/radial_flow.jl b/example/planar_radial_flow/radial_flow.jl deleted file mode 100644 index 4d51cd24..00000000 --- a/example/planar_radial_flow/radial_flow.jl +++ /dev/null @@ -1,55 +0,0 @@ -using Random, Distributions, LinearAlgebra, Bijectors -using ADTypes -using Optimisers -using FunctionChains -using NormalizingFlows -using Zygote -using Flux: f32 -using Plots -include("../common.jl") - -Random.seed!(123) -rng = Random.default_rng() - -###################################### -# 2d Banana as the target distribution -###################################### -include("../targets/banana.jl") - -# create target p -p = Banana(2, 1.0f-1, 100.0f0) -logp = Base.Fix1(logpdf, p) - -###################################### -# learn the target using radial flow -###################################### -function create_radial_flow(n_layers::Int, q₀) - d = length(q₀) - Ls = [f32(RadialLayer(d)) for _ in 1:n_layers] - ts = fchain(Ls) - return transformed(q₀, ts) -end - -# create a 20-layer radial flow -flow = create_radial_flow(10, MvNormal(zeros(Float32, 2), I)) -flow_untrained = deepcopy(flow) - -# train the flow -sample_per_iter = 10 -cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,) -flow_trained, stats, _ = train_flow( - elbo, - flow, - logp, - sample_per_iter; - max_iters=200_00, - optimiser=Optimisers.ADAM(), - callback=cb, -) -losses = map(x -> x.loss, stats) - -###################################### -# evaluate trained flow -###################################### -plot(losses; label="Loss", linewidth=2) # plot the loss -compare_trained_and_untrained_flow(flow_trained, flow_untrained, p, 1000; legend=:bottom) diff --git a/example/targets/banana.jl b/example/targets/banana.jl index c3d1d1ba..4cc9b25e 100644 --- a/example/targets/banana.jl +++ b/example/targets/banana.jl @@ -1,7 +1,3 @@ -using Distributions, Random -using Plots -using IrrationalConstants - """ Banana{T<:Real} @@ -85,12 +81,3 @@ function Distributions._logpdf(p::Banana, x::AbstractVector) logz = (log(s) / d + IrrationalConstants.log2π) * d / 2 return -logz - sum(ϕ⁻¹_x .^ 2 ./ vcat(s, ones(T, d - 1))) / 2 end - -function visualize(p::Banana, samples=rand(p, 
1000)) - xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100) - yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100) - z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange] - fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2) - scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright) - return fig -end diff --git a/example/targets/cross.jl b/example/targets/cross.jl index f905b86e..1c0fcdb2 100644 --- a/example/targets/cross.jl +++ b/example/targets/cross.jl @@ -1,4 +1,3 @@ -using Distributions, Random """ Cross(μ::Real=2.0, σ::Real=0.15) diff --git a/example/targets/neal_funnel.jl b/example/targets/neal_funnel.jl index 6cdf84a2..877c5aff 100644 --- a/example/targets/neal_funnel.jl +++ b/example/targets/neal_funnel.jl @@ -1,5 +1,3 @@ -using Distributions, Random - """ Funnel{T<:Real} @@ -45,8 +43,7 @@ Funnel(dim::Int) = Funnel(dim, 0.0, 9.0) Base.length(p::Funnel) = p.dim Base.eltype(p::Funnel{T}) where {T<:Real} = T -function Distributions._rand!(rng::AbstractRNG, p::Funnel, x::AbstractVecOrMat) - T = eltype(x) +function Distributions._rand!(rng::AbstractRNG, p::Funnel{T}, x::AbstractVecOrMat{T}) where {T<:Real} d, μ, σ = p.dim, p.μ, p.σ d == size(x, 1) || error("Dimension mismatch") x[1, :] .= randn(rng, T, size(x, 2)) .* σ .+ μ @@ -54,9 +51,22 @@ function Distributions._rand!(rng::AbstractRNG, p::Funnel, x::AbstractVecOrMat) return x end -function Distributions._logpdf(p::Funnel, x::AbstractVector) +function Distributions._logpdf(p::Funnel{T}, x::AbstractVector{T}) where {T<:Real} d, μ, σ = p.dim, p.μ, p.σ - lpdf1 = logpdf(Normal(μ, σ), x[1]) - lpdfs = logpdf.(Normal.(zeros(T, d - 1), exp(x[1] / 2)), @view(x[2:end])) - return lpdf1 + sum(lpdfs) + x1 = x[1] + x2 = x[2:end] + lpdf_x1 = logpdf(Normal(μ, σ), x1) + lpdf_x2_given_1 = logpdf(MvNormal(zeros(T, d-1), exp(x1)I), x2) + return lpdf_x1 + lpdf_x2_given_1 +end + +function score(p::Funnel{T}, x::AbstractVector{T}) where {T<:Real} + d, μ, σ = p.dim, p.μ, p.σ + x1 = x[1] + x_2_d = x[2:end] + a = expm1(-x1) + 1 + + ∇lpdf1 = (μ - x1)/σ^2 - (d-1)/2 + a*sum(abs2, x_2_d)/2 + ∇lpdfs = -a*x_2_d + return vcat(∇lpdf1, ∇lpdfs) end diff --git a/example/targets/warped_gaussian.jl b/example/targets/warped_gaussian.jl index a63012ed..90ed5c12 100644 --- a/example/targets/warped_gaussian.jl +++ b/example/targets/warped_gaussian.jl @@ -1,5 +1,3 @@ -using Distributions, Random, LinearAlgebra, IrrationalConstants - """ WarpedGauss{T<:Real} @@ -39,11 +37,11 @@ WarpedGauss(σ1::T, σ2::T) where {T<:Real} = WarpedGauss{T}(σ1, σ2) WarpedGauss() = WarpedGauss(1.0, 0.12) Base.length(p::WarpedGauss) = 2 -Base.eltype(p::WarpedGauss{T}) where {T<:Real} = T +Base.eltype(::WarpedGauss{T}) where {T<:Real} = T Distributions.sampler(p::WarpedGauss) = p # Define the transformation function φ and the inverse ϕ⁻¹ for the warped Gaussian distribution -function ϕ!(p::WarpedGauss, z::AbstractVector) +function ϕ!(::WarpedGauss, z::AbstractVector) length(z) == 2 || error("Dimension mismatch") x, y = z r = norm(z) @@ -53,7 +51,7 @@ function ϕ!(p::WarpedGauss, z::AbstractVector) return z end -function ϕ⁻¹(p::WarpedGauss, z::AbstractVector) +function ϕ⁻¹(::WarpedGauss, z::AbstractVector) length(z) == 2 || error("Dimension mismatch") x, y = z r = norm(z) @@ -71,7 +69,7 @@ end function Distributions._rand!(rng::AbstractRNG, p::WarpedGauss, x::AbstractVecOrMat) size(x, 1) == 2 || error("Dimension mismatch") - σ₁, σ₂ = p.σ₁, p.σ₂ + σ₁, σ₂ = 
p.σ1, p.σ2 randn!(rng, x) x .*= [σ₁, σ₂] for y in eachcol(x) @@ -82,7 +80,7 @@ end function Distributions._logpdf(p::WarpedGauss, x::AbstractVector) size(x, 1) == 2 || error("Dimension mismatch") - σ₁, σ₂ = p.σ₁, p.σ₂ + σ₁, σ₂ = p.σ1, p.σ2 S = [σ₁, σ₂] .^ 2 z, logJ = ϕ⁻¹(p, x) return -sum(z .^ 2 ./ S) / 2 - IrrationalConstants.log2π - log(σ₁) - log(σ₂) + logJ diff --git a/example/common.jl b/example/utils.jl similarity index 54% rename from example/common.jl rename to example/utils.jl index 114094fc..b5d19caa 100644 --- a/example/common.jl +++ b/example/utils.jl @@ -1,6 +1,23 @@ -using Random, Distributions, LinearAlgebra, Bijectors +using Random, Distributions, LinearAlgebra +using Bijectors: transformed +using Flux + +""" +A simple wrapper for a 3 layer dense MLP +""" +function mlp3(input_dim::Int, hidden_dims::Int, output_dim::Int; activation=Flux.leakyrelu) + return Chain( + Flux.Dense(input_dim, hidden_dims, activation), + Flux.Dense(hidden_dims, hidden_dims, activation), + Flux.Dense(hidden_dims, output_dim), + ) +end + +function create_flow(Ls, q₀) + ts = reduce(∘, Ls) + return transformed(q₀, ts) +end -# accessing the trained flow by looking at the first 2 dimensions function compare_trained_and_untrained_flow( flow_trained::Bijectors.MultivariateTransformed, flow_untrained::Bijectors.MultivariateTransformed, @@ -47,7 +64,11 @@ function compare_trained_and_untrained_flow( return p end -function create_flow(Ls, q₀) - ts = fchain(Ls) - return transformed(q₀, ts) -end \ No newline at end of file +function visualize(p::Bijectors.MultivariateTransformed, samples=rand(p, 1000)) + xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100) + yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100) + z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange] + fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2) + scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright) + return fig +end diff --git a/src/NormalizingFlows.jl b/src/NormalizingFlows.jl index 7c4deccb..72318df2 100644 --- a/src/NormalizingFlows.jl +++ b/src/NormalizingFlows.jl @@ -12,7 +12,7 @@ import DifferentiationInterface as DI using DocStringExtensions -export train_flow, elbo, loglikelihood +export train_flow, elbo, elbo_batch, loglikelihood """ train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...) diff --git a/src/objectives/elbo.jl b/src/objectives/elbo.jl index 14425f4c..d282e1e1 100644 --- a/src/objectives/elbo.jl +++ b/src/objectives/elbo.jl @@ -43,3 +43,39 @@ end function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples) return elbo(Random.default_rng(), flow, logp, n_samples) end + + +""" + elbo_batch(flow, logp, xs) + elbo_batch([rng, ]flow, logp, n_samples) + +Instead of broadcasting over elbo_single_sample, this function directly +computes the ELBO in a batched manner, which requires the flow.transform to be able to +handle batched transformation directly. + +This will be more efficient than `elbo` for invertible neural networks such as RealNVP, +Neural Spline Flow, etc. + +# Arguments +- `rng`: random number generator +- `flow`: variational distribution to be trained. 
In particular,
+  `flow = transformed(q₀, T::Bijectors.Bijector)`,
+  where q₀ is a reference distribution that is easy to sample from and whose log-density can be evaluated
+- `logp`: log-pdf of the target distribution (not necessarily normalized)
+- `xs`: samples from the reference distribution q₀
+- `n_samples`: number of samples from the reference distribution q₀
+
+"""
+function elbo_batch(flow::Bijectors.MultivariateTransformed, logp, xs::AbstractMatrix)
+    # requires the flow transformation to be able to handle batched inputs
+    ys, logabsdetjac = with_logabsdet_jacobian(flow.transform, xs)
+    elbos = logp(ys) .- logpdf(flow.dist, xs) .+ logabsdetjac
+    return elbos
+end
+function elbo_batch(rng::AbstractRNG, flow::Bijectors.MultivariateTransformed, logp, n_samples)
+    xs = _device_specific_rand(rng, flow.dist, n_samples)
+    elbos = elbo_batch(flow, logp, xs)
+    return mean(elbos)
+end
+elbo_batch(flow::Bijectors.MultivariateTransformed, logp, n_samples) =
+    elbo_batch(Random.default_rng(), flow, logp, n_samples)
diff --git a/test/ext/CUDA/cuda.jl b/test/ext/CUDA/cuda.jl
index e8fc7596..1a21302e 100644
--- a/test/ext/CUDA/cuda.jl
+++ b/test/ext/CUDA/cuda.jl
@@ -29,13 +29,13 @@ using Bijectors, CUDA, Distributions, Flux, LinearAlgebra, Test
         return (transformed=transformed, wT_û=wT_û, wT_z=wT_z)
     end
 
+    CUDA.allowscalar(true)
     dists = [
         MvNormal(CUDA.zeros(2), cu(Matrix{Float64}(I, 2, 2))),
         MvNormal(CUDA.zeros(2), cu([1.0 0.5; 0.5 1.0])),
     ]
 
     @testset "$dist" for dist in dists
-        CUDA.allowscalar(true)
         x = NormalizingFlows._device_specific_rand(CUDA.default_rng(), dist)
         xs = NormalizingFlows._device_specific_rand(CUDA.default_rng(), dist, 100)
         @test_nowarn logpdf(dist, x)
@@ -44,7 +44,6 @@ using Bijectors, CUDA, Distributions, Flux, LinearAlgebra, Test
     end
 
     @testset "$dist" for dist in dists
-        CUDA.allowscalar(true)
         pl1 = PlanarLayer(
             identity(CUDA.rand(2)), identity(CUDA.rand(2)), identity(CUDA.rand(1))
         )
diff --git a/test/objectives.jl b/test/objectives.jl
index 4641b3cd..fa977e83 100644
--- a/test/objectives.jl
+++ b/test/objectives.jl
@@ -18,6 +18,13 @@
         @test logpdf(flow, x) + el ≈ logp(x)
     end
 
+    @testset "elbo_batch" begin
+        el = elbo_batch(rng, flow, logp, 10)
+
+        @test abs(el) ≤ 1e-5
+        @test logpdf(flow, x) + el ≈ logp(x)
+    end
+
     @testset "likelihood" begin
         sample_trained = rand(flow, 1000)
         sample_untrained = rand(q₀, 1000)
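For reference, `elbo` and `elbo_batch` estimate the same quantity (this is the standard flow ELBO, restated here; it is not taken from the diff itself). Writing the flow as `transformed(q₀, T_θ)`, both compute the Monte Carlo estimator

```math
\widehat{\mathrm{ELBO}} = \frac{1}{N}\sum_{i=1}^{N}\Big[\log p\big(T_\theta(x_i)\big) + \log\big|\det J_{T_\theta}(x_i)\big| - \log q_0(x_i)\Big],
\qquad x_i \sim q_0 .
```

`elbo` evaluates the bracketed term one sample at a time (broadcasting over `elbo_single_sample`), whereas `elbo_batch` evaluates it for all columns of `xs` at once, which is why `flow.transform` must accept matrix inputs. The new testset above mirrors the existing `elbo` test: when the flow matches the normalized target, the ELBO is approximately zero and `logpdf(flow, x) + elbo ≈ logp(x)`.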