diff --git a/.gitignore b/.gitignore index f7a16479..a7dae251 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ /docs/build/ test/Manifest.toml example/Manifest.toml +example/LocalPreferences.toml # Files generated by invoking Julia with --code-coverage *.jl.cov diff --git a/Project.toml b/Project.toml index db11ea42..ad666126 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,9 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/example/Project.toml b/example/Project.toml index d462c5ee..12198160 100644 --- a/example/Project.toml +++ b/example/Project.toml @@ -2,16 +2,22 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" + +[extras] +CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" diff --git a/example/README.md b/example/README.md index 27a59bb0..738fff98 100644 --- a/example/README.md +++ b/example/README.md @@ -12,7 +12,7 @@ normalizing flow to approximate the target distribution using `NormalizingFlows. Currently, all examples share the same [Julia project](https://pkgdocs.julialang.org/v1/environments/#Using-someone-else's-project). To run the examples, first activate the project environment: ```julia -# pwd() = "NormalizingFlows.jl/" -using Pkg; Pkg.activate("example"); Pkg.instantiate() +# pwd() = "NormalizingFlows.jl/example" +using Pkg; Pkg.activate("."); Pkg.instantiate() ``` -This will install all needed packages, at the exact versions when the model was last updated. Then you can run the model code with include(".jl"), or by running the example script line-by-line. +This will install all needed packages, at the exact versions when the model was last updated. Then you can run the model code with `include(".jl")`, or by running the example script line-by-line. 
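For reference, the end-to-end workflow described by the updated README looks like the sketch below (assuming the working directory is the `example/` folder of a local checkout; any of the `demo_*.jl` scripts added in this PR can be substituted for `demo_RealNVP.jl`):

```julia
# Activate and instantiate the pinned example environment, then run one demo.
# Assumes pwd() == ".../NormalizingFlows.jl/example".
using Pkg
Pkg.activate(".")
Pkg.instantiate()

include("demo_RealNVP.jl")   # or demo_planar_flow.jl, demo_radial_flow.jl, ...
```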
diff --git a/example/SyntheticTargets.jl b/example/SyntheticTargets.jl new file mode 100644 index 00000000..d35408cb --- /dev/null +++ b/example/SyntheticTargets.jl @@ -0,0 +1,34 @@ +using DocStringExtensions +using Distributions, Random, LinearAlgebra +using IrrationalConstants +using Plots + + +include("targets/banana.jl") +include("targets/cross.jl") +include("targets/neal_funnel.jl") +include("targets/warped_gaussian.jl") + + +function load_model(name::String) + if name == "Banana" + return Banana(2, 1.0, 10.0) + elseif name == "Cross" + return Cross() + elseif name == "Funnel" + return Funnel(2) + elseif name == "WarpedGaussian" + return WarpedGauss() + else + error("Model not defined") + end +end + +function visualize(p::ContinuousMultivariateDistribution, samples=rand(p, 1000)) + xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100) + yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100) + z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange] + fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2) + scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright) + return fig +end diff --git a/example/common.jl b/example/common.jl deleted file mode 100644 index 114094fc..00000000 --- a/example/common.jl +++ /dev/null @@ -1,53 +0,0 @@ -using Random, Distributions, LinearAlgebra, Bijectors - -# accessing the trained flow by looking at the first 2 dimensions -function compare_trained_and_untrained_flow( - flow_trained::Bijectors.MultivariateTransformed, - flow_untrained::Bijectors.MultivariateTransformed, - true_dist::ContinuousMultivariateDistribution, - n_samples::Int; - kwargs..., -) - samples_trained = rand(flow_trained, n_samples) - samples_untrained = rand(flow_untrained, n_samples) - samples_true = rand(true_dist, n_samples) - - p = scatter( - samples_true[1, :], - samples_true[2, :]; - label="True Distribution", - color=:blue, - markersize=2, - alpha=0.5, - ) - scatter!( - p, - samples_untrained[1, :], - samples_untrained[2, :]; - label="Untrained Flow", - color=:red, - markersize=2, - alpha=0.5, - ) - scatter!( - p, - samples_trained[1, :], - samples_trained[2, :]; - label="Trained Flow", - color=:green, - markersize=2, - alpha=0.5, - ) - plot!(; kwargs...) 
- - xlabel!(p, "X") - ylabel!(p, "Y") - title!(p, "Comparison of Trained and Untrained Flow") - - return p -end - -function create_flow(Ls, q₀) - ts = fchain(Ls) - return transformed(q₀, ts) -end \ No newline at end of file diff --git a/example/demo_RealNVP.jl b/example/demo_RealNVP.jl new file mode 100644 index 00000000..e81ea83b --- /dev/null +++ b/example/demo_RealNVP.jl @@ -0,0 +1,163 @@ +using Flux +using Bijectors +using Bijectors: partition, combine, PartitionMask + +using Random, Distributions, LinearAlgebra +using Functors +using Optimisers, ADTypes +using Mooncake +using NormalizingFlows + +include("SyntheticTargets.jl") +include("utils.jl") + +################################## +# define affine coupling layer using Bijectors.jl interface +################################# +struct AffineCoupling <: Bijectors.Bijector + dim::Int + mask::Bijectors.PartitionMask + s::Flux.Chain + t::Flux.Chain +end + +# let params track field s and t +@functor AffineCoupling (s, t) + +function AffineCoupling( + dim::Int, # dimension of input + hdims::Int, # dimension of hidden units for s and t + mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on +) + cdims = length(mask_idx) # dimension of parts used to construct coupling law + s = mlp3(cdims, hdims, cdims) + t = mlp3(cdims, hdims, cdims) + mask = PartitionMask(dim, mask_idx) + return AffineCoupling(dim, mask, s, t) +end + +function Bijectors.transform(af::AffineCoupling, x::AbstractVector) + # partition vector using 'af.mask::PartitionMask` + x₁, x₂, x₃ = partition(af.mask, x) + y₁ = x₁ .* af.s(x₂) .+ af.t(x₂) + return combine(af.mask, y₁, x₂, x₃) +end + +function (af::AffineCoupling)(x::AbstractArray) + return transform(af, x) +end + +function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector) + x_1, x_2, x_3 = Bijectors.partition(af.mask, x) + y_1 = af.s(x_2) .* x_1 .+ af.t(x_2) + logjac = sum(log ∘ abs, af.s(x_2)) + return combine(af.mask, y_1, x_2, x_3), logjac +end + +function Bijectors.with_logabsdet_jacobian( + iaf::Inverse{<:AffineCoupling}, y::AbstractVector +) + af = iaf.orig + # partition vector using `af.mask::PartitionMask` + y_1, y_2, y_3 = partition(af.mask, y) + # inverse transformation + x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2) + logjac = -sum(log ∘ abs, af.s(y_2)) + return combine(af.mask, x_1, y_2, y_3), logjac +end + +function Bijectors.logabsdetjac(af::AffineCoupling, x::AbstractVector) + _, x_2, _ = partition(af.mask, x) + logjac = sum(log ∘ abs, af.s(x_2)) + return logjac +end + +################### +# an equivalent definition of AffineCoupling using Bijectors.Coupling +# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1) +################### + +# struct AffineCoupling <: Bijectors.Bijector +# dim::Int +# mask::Bijectors.PartitionMask +# s::Flux.Chain +# t::Flux.Chain +# end + +# # let params track field s and t +# @functor AffineCoupling (s, t) + +# function AffineCoupling(dim, mask, s, t) +# return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask) +# end + +# function AffineCoupling( +# dim::Int, # dimension of input +# hdims::Int, # dimension of hidden units for s and t +# mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on +# ) +# cdims = length(mask_idx) # dimension of parts used to construct coupling law +# s = mlp3(cdims, hdims, cdims) +# t = mlp3(cdims, hdims, cdims) +# mask = PartitionMask(dim, mask_idx) +# 
return AffineCoupling(dim, mask, s, t) +# end + + + +################################## +# start demo +################################# +Random.seed!(123) +rng = Random.default_rng() +T = Float32 + +###################################### +# a difficult banana target +###################################### +target = Banana(2, 1.0f0, 100.0f0) +logp = Base.Fix1(logpdf, target) + +###################################### +# learn the target using Affine coupling flow +###################################### +@leaf MvNormal +q0 = MvNormal(zeros(T, 2), ones(T, 2)) + +d = 2 +hdims = 32 +Ls = [AffineCoupling(d, hdims, [1]) ∘ AffineCoupling(d, hdims, [2]) for i in 1:3] + +flow = create_flow(Ls, q0) +flow_untrained = deepcopy(flow) + + +###################################### +# start training +###################################### +sample_per_iter = 64 + +# callback function to log training progress +cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) +adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) +checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 +flow_trained, stats, _ = train_flow( + elbo, + flow, + logp, + sample_per_iter; + max_iters=50_000, + optimiser=Optimisers.Adam(5e-4), + ADbackend=adtype, + show_progress=true, + callback=cb, + hasconverged=checkconv, +) +θ, re = Optimisers.destructure(flow_trained) +losses = map(x -> x.loss, stats) + +###################################### +# evaluate trained flow +###################################### +plot(losses; label="Loss", linewidth=2) # plot the loss +compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000) diff --git a/example/demo_hamiltonian_flow.jl b/example/demo_hamiltonian_flow.jl new file mode 100644 index 00000000..1fea3cbb --- /dev/null +++ b/example/demo_hamiltonian_flow.jl @@ -0,0 +1,177 @@ +using Random, Distributions, LinearAlgebra +using Functors +using Optimisers, ADTypes +using Mooncake +using Bijectors +using Bijectors: partition, combine, PartitionMask +using SimpleUnPack: @unpack + +using NormalizingFlows + +include("SyntheticTargets.jl") +include("utils.jl") + + +""" +Hamiltonian flow interleaves betweem Leapfrog layer and an affine transformation to the momentum variable +It targets the joint distribution π(x, ρ) = π(x) * N(ρ; 0, I) +where x is the target variable and ρ is the momentum variable + +- The optimizable parameters for the leapfrog layers are the step size logϵ +- and we will also optimize the affine transformation parameters (shift and scale) + +Wrap the leapfrog transformation into a Bijectors.jl interface + +# References +[1] Naitong Chen, Zuheng Xu, Trevor Campbell, Bayesian inference via sparse Hamiltonian flows, NeurIPS 2022. 
+""" +struct LeapFrog{T<:Real} <: Bijectors.Bijector + "dimention of the target space" + dim::Int + "leapfrog step size" + logϵ::AbstractVector{T} + "number of leapfrog steps" + L::Int + "score function of the target distribution" + ∇logp + "mask function to split the input into position and momentum" + mask::PartitionMask +end +@functor LeapFrog (logϵ,) + +function LeapFrog(dim::Int, logϵ::T, L::Int, ∇logp) where {T<:Real} + return LeapFrog(dim, logϵ .* ones(T, dim), L, ∇logp, PartitionMask(2dim, 1:dim)) +end + +_get_stepsize(lf::LeapFrog) = exp.(lf.logϵ) + +""" +run L leapfrog steps with std Gaussian momentum distribution with vector stepsizes +""" +function _leapfrog( + ∇ℓπ, ϵ::AbstractVector{T}, L::Int, x::AbstractVecOrMat{T}, v::AbstractVecOrMat{T} +) where {T<:Real} + v += ϵ/2 .* ∇ℓπ(x) + for _ in 1:L - 1 + x += ϵ .* v + v += ϵ .* ∇ℓπ(x) + end + x += ϵ .* v + v += ϵ/2 .* ∇ℓπ(x) + return x, v +end + +function Bijectors.transform(lf::LeapFrog{T}, z::AbstractVector{T}) where {T<:Real} + @unpack dim, logϵ, L, ∇logp = lf + @assert length(z) == 2dim "dimension of input must be even, z = [x, ρ]" + + ϵ = _get_stepsize(lf) + x, ρ, e = partition(lf.mask, z) # split into position and momentum + x_, ρ_ = _leapfrog(∇logp, ϵ, L, x, ρ) # run L learpfrog steps + return combine(lf.mask, x_, ρ_, e) +end + +function Bijectors.transform(ilf::Inverse{<:LeapFrog{T}}, z::AbstractVector{T}) where {T<:Real} + lf = ilf.orig + @unpack dim, logϵ, L, ∇logp = lf + @assert length(z) == 2dim "dimension of input must be even, z = [x, ρ]" + + ϵ = _get_stepsize(lf) + x, ρ, e = partition(lf.mask, z) # split into position and momentum + x_, ρ_ = _leapfrog(∇logp, -ϵ, L, x, ρ) # run L learpfrog steps + return combine(lf.mask, x_, ρ_, e) +end + +function Bijectors.with_logabsdet_jacobian(lf::LeapFrog{T}, z::AbstractVector{T}) where {T<:Real} + # leapfrog is symplectic, so the logabsdetjacobian is 0 + return Bijectors.transform(lf, z), zero(eltype(z)) +end +function Bijectors.with_logabsdet_jacobian(ilf::Inverse{<:LeapFrog{T}}, z::AbstractVector{T}) where {T<:Real} + # leapfrog is symplectic, so the logabsdetjacobian is 0 + return Bijectors.transform(ilf, z), zero(eltype(z)) +end + +# shift and scale transformation that only applies to the momentum variable ρ = z[(dim + 1):end] +function momentum_normalization_layer(dims::Int, T::Type{<:Real}) + bx = identity # leave position variable x = z[1:dim] unchanged + bρ = Bijectors.Shift(zeros(T, dims)) ∘ Bijectors.Scale(ones(T, dims)) + b = Bijectors.Stacked((bx, bρ), [1:dims, (dims + 1):(2*dims)]) + return b +end + + +################################## +# start demo +################################# +Random.seed!(123) +rng = Random.default_rng() +T = Float64 # for Hamiltonian VI, its recommended to use Float64 as the dynamic is chaotic + +###################################### +# a Funnel target +###################################### +dims = 2 +target = Funnel(dims, -8.0, 5.0) +# visualize(target) + +logp = Base.Fix1(logpdf, target) +function logp_joint(z::AbstractVector{T}) where {T<:Real} + dims = div(length(z), 2) + x = @view z[1:dims] + ρ = @view z[(dims + 1):end] + logp_x = logp(x) + logp_ρ = sum(logpdf(Normal(), ρ)) + return logp_x + logp_ρ +end +∇logp = Base.Fix1(score, target) + +###################################### +# build the flow in the joint space +###################################### +# mean field Gaussian reference +@leaf MvNormal +q0 = transformed( + MvNormal(zeros(T, 2dims), ones(T, 2dims)), Bijectors.Shift(zeros(T, 2dims)) ∘ Bijectors.Scale(ones(T, 2dims)) 
+) + +nlfg = 3 +logϵ0 = log(0.05) # initial step size + +# Hamiltonian flow interleaves betweem Leapfrog layer and an affine transformation to the momentum variable +Ls = [ + momentum_normalization_layer(dims, T) ∘ LeapFrog(dims, logϵ0, nlfg, ∇logp) for _ in 1:15 +] + +flow = create_flow(Ls, q0) +flow_untrained = deepcopy(flow) + + +###################################### +# start training +###################################### +sample_per_iter = 16 + +# callback function to log training progress +cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) +adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) +checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 +flow_trained, stats, _ = train_flow( + elbo, + flow, + logp_joint, + sample_per_iter; + max_iters=50_000, + optimiser=Optimisers.Adam(3e-4), + ADbackend=adtype, + show_progress=true, + callback=cb, + hasconverged=checkconv, +) +θ, re = Optimisers.destructure(flow_trained) +losses = map(x -> x.loss, stats) + +###################################### +# evaluate trained flow +###################################### +plot(losses; label="Loss", linewidth=2) # plot the loss +compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000) diff --git a/example/demo_neural_spline_flow.jl b/example/demo_neural_spline_flow.jl new file mode 100644 index 00000000..02a7b321 --- /dev/null +++ b/example/demo_neural_spline_flow.jl @@ -0,0 +1,172 @@ +using Flux +using Bijectors +using Bijectors: partition, combine, PartitionMask + +using Random, Distributions, LinearAlgebra +using Functors +using Optimisers, ADTypes +using Mooncake +using NormalizingFlows + +include("SyntheticTargets.jl") +include("utils.jl") + +################################## +# define neural spline layer using Bijectors.jl interface +################################# +""" +Neural Rational quadratic Spline layer + +# References +[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019). 
+""" +struct NeuralSplineLayer{T,A<:Flux.Chain} <: Bijectors.Bijector + dim::Int # dimension of input + K::Int # number of knots + n_dims_transferred::Int # number of dimensions that are transformed + nn::A # networks that parmaterize the knots and derivatives + B::T # bound of the knots + mask::Bijectors.PartitionMask +end + +function NeuralSplineLayer( + dim::T1, # dimension of input + hdims::T1, # dimension of hidden units for s and t + K::T1, # number of knots + B::T2, # bound of the knots + mask_idx::AbstractVector{<:Int}, # index of dimensione that one wants to apply transformations on +) where {T1<:Int,T2<:Real} + num_of_transformed_dims = length(mask_idx) + input_dims = dim - num_of_transformed_dims + + # output dim of the NN + output_dims = (3K - 1)*num_of_transformed_dims + # one big mlp that outputs all the knots and derivatives for all the transformed dimensions + nn = mlp3(input_dims, hdims, output_dims) + + mask = Bijectors.PartitionMask(dim, mask_idx) + return NeuralSplineLayer(dim, K, num_of_transformed_dims, nn, B, mask) +end + +@functor NeuralSplineLayer (nn,) + +# define forward and inverse transformation +""" +Build a rational quadratic spline from the nn output +Bijectors.jl has implemented the inverse and logabsdetjac for rational quadratic spline + +we just need to map the nn output to the knots and derivatives of the RQS +""" +function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector) + K, B = nsl.K, nsl.B + nnoutput = reshape(nsl.nn(x), nsl.n_dims_transferred, :) + ws = @view nnoutput[:, 1:K] + hs = @view nnoutput[:, (K + 1):(2K)] + ds = @view nnoutput[:, (2K + 1):(3K - 1)] + return Bijectors.RationalQuadraticSpline(ws, hs, ds, B) +end + +function Bijectors.transform(nsl::NeuralSplineLayer, x::AbstractVector) + x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x) + # instantiate rqs knots and derivatives + rqs = instantiate_rqs(nsl, x_2) + y_1 = Bijectors.transform(rqs, x_1) + return Bijectors.combine(nsl.mask, y_1, x_2, x_3) +end + +function Bijectors.transform(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector) + nsl = insl.orig + y1, y2, y3 = partition(nsl.mask, y) + rqs = instantiate_rqs(nsl, y2) + x1 = Bijectors.transform(Inverse(rqs), y1) + return Bijectors.combine(nsl.mask, x1, y2, y3) +end + +function (nsl::NeuralSplineLayer)(x::AbstractVector) + return Bijectors.transform(nsl, x) +end + +# define logabsdetjac +function Bijectors.logabsdetjac(nsl::NeuralSplineLayer, x::AbstractVector) + x_1, x_2, _ = Bijectors.partition(nsl.mask, x) + rqs = instantiate_rqs(nsl, x_2) + logjac = logabsdetjac(rqs, x_1) + return logjac +end + +function Bijectors.logabsdetjac(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector) + nsl = insl.orig + y1, y2, _ = partition(nsl.mask, y) + rqs = instantiate_rqs(nsl, y2) + logjac = logabsdetjac(Inverse(rqs), y1) + return logjac +end + +function Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineLayer, x::AbstractVector) + x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x) + rqs = instantiate_rqs(nsl, x_2) + y_1, logjac = with_logabsdet_jacobian(rqs, x_1) + return Bijectors.combine(nsl.mask, y_1, x_2, x_3), logjac +end + +################################## +# start demo +################################# +Random.seed!(123) +rng = Random.default_rng() +T = Float32 + +###################################### +# neals funnel target +###################################### +target = Funnel(2, 0.0f0, 9.0f0) +logp = Base.Fix1(logpdf, target) + +###################################### +# learn the target using Affine coupling flow 
+###################################### +@leaf MvNormal +q0 = MvNormal(zeros(T, 2), ones(T, 2)) + +d = 2 +hdims = 64 +K = 10 +B = 30 +Ls = [ + NeuralSplineLayer(d, hdims, K, B, [1]) ∘ NeuralSplineLayer(d, hdims, K, B, [2]) for + i in 1:3 +] + +flow = create_flow(Ls, q0) +flow_untrained = deepcopy(flow) + + +###################################### +# start training +###################################### +sample_per_iter = 64 + +# callback function to log training progress +cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) +adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) +checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 +flow_trained, stats, _ = train_flow( + elbo, + flow, + logp, + sample_per_iter; + max_iters=50_000, + optimiser=Optimisers.Adam(5e-5), + ADbackend=adtype, + show_progress=true, + callback=cb, + hasconverged=checkconv, +) +θ, re = Optimisers.destructure(flow_trained) +losses = map(x -> x.loss, stats) + +###################################### +# evaluate trained flow +###################################### +plot(losses; label="Loss", linewidth=2) # plot the loss +compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000) diff --git a/example/demo_planar_flow.jl b/example/demo_planar_flow.jl new file mode 100644 index 00000000..a37e2679 --- /dev/null +++ b/example/demo_planar_flow.jl @@ -0,0 +1,63 @@ +using Random, Distributions, LinearAlgebra, Bijectors +using Functors +using Optimisers, ADTypes, Mooncake +using NormalizingFlows + +include("SyntheticTargets.jl") +include("utils.jl") + +Random.seed!(123) +rng = Random.default_rng() +T = Float64 + +###################################### +# 2d Banana as the target distribution +###################################### +target = load_model("Banana") +logp = Base.Fix1(logpdf, target) + + +###################################### +# setup planar flow +###################################### +function create_planar_flow(n_layers::Int, q₀) + d = length(q₀) + Ls = [PlanarLayer(d) for _ in 1:n_layers] + ts = reduce(∘, Ls) + return transformed(q₀, ts) +end + +@leaf MvNormal +q0 = MvNormal(zeros(T, 2), ones(T, 2)) +flow = create_planar_flow(10, q0) +flow_untrained = deepcopy(flow) + +###################################### +# start training +###################################### +sample_per_iter = 30 + +# callback function to log training progress +cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) +adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) +checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 +flow_trained, stats, _ = train_flow( + elbo, + flow, + logp, + sample_per_iter; + max_iters=10_000, + optimiser=Optimisers.Adam(one(T)/100), + ADbackend=adtype, + show_progress=true, + callback=cb, + hasconverged=checkconv, +) +θ, re = Optimisers.destructure(flow_trained) +losses = map(x -> x.loss, stats) + +###################################### +# evaluate trained flow +###################################### +plot(losses; label="Loss", linewidth=2) # plot the loss +compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000) diff --git a/example/demo_radial_flow.jl b/example/demo_radial_flow.jl new file mode 100644 index 00000000..282af5c3 --- /dev/null +++ b/example/demo_radial_flow.jl @@ -0,0 +1,64 @@ +using Random, Distributions, LinearAlgebra, Bijectors +using Functors +using Optimisers, ADTypes, Mooncake +using NormalizingFlows + +include("SyntheticTargets.jl") +include("utils.jl") + +Random.seed!(123) 
+rng = Random.default_rng() +T = Float64 + +###################################### +# get target logp +###################################### +target = load_model("WarpedGaussian") +logp = Base.Fix1(logpdf, target) + +###################################### +# setup radial flow +###################################### +function create_radial_flow(n_layers::Int, q₀) + d = length(q₀) + Ls = [RadialLayer(d) for _ in 1:n_layers] + ts = reduce(∘, Ls) + return transformed(q₀, ts) +end + +# create a 10-layer radial flow +@leaf MvNormal +q0 = MvNormal(zeros(T, 2), ones(T, 2)) +flow = create_radial_flow(10, q0) + +flow_untrained = deepcopy(flow) + +###################################### +# start training +###################################### +sample_per_iter = 30 + +# callback function to log training progress +cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) +adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) +checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 +flow_trained, stats, _ = train_flow( + elbo, + flow, + logp, + sample_per_iter; + max_iters=10_000, + optimiser=Optimisers.Adam(one(T)/100), + ADbackend=adtype, + show_progress=true, + callback=cb, + hasconverged=checkconv, +) +θ, re = Optimisers.destructure(flow_trained) +losses = map(x -> x.loss, stats) + +###################################### +# evaluate trained flow +###################################### +plot(losses; label="Loss", linewidth=2) # plot the loss +compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000) diff --git a/example/neural_spline_flow/nsf_layer.jl b/example/neural_spline_flow/nsf_layer.jl deleted file mode 100644 index 6f3b2a7f..00000000 --- a/example/neural_spline_flow/nsf_layer.jl +++ /dev/null @@ -1,107 +0,0 @@ -using Flux -using Functors -using Bijectors -using Bijectors: partition, PartitionMask - -""" -Neural Rational quadratic Spline layer - -# References -[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019). 
-""" -struct NeuralSplineLayer{T1,T2,A<:AbstractVecOrMat{T1}} <: Bijectors.Bijector - dim::Int - mask::Bijectors.PartitionMask - w::A # width - h::A # height - d::A # derivative of the knots - B::T2 # bound of the knots -end - -function MLP_3layer(input_dim::Int, hdims::Int, output_dim::Int; activation=Flux.leakyrelu) - return Chain( - Flux.Dense(input_dim, hdims, activation), - Flux.Dense(hdims, hdims, activation), - Flux.Dense(hdims, output_dim), - ) -end - -function NeuralSplineLayer( - dim::Int, # dimension of input - hdims::Int, # dimension of hidden units for s and t - K::Int, # number of knots - mask_idx::AbstractVector{<:Int}, # index of dimensione that one wants to apply transformations on - B::T2, # bound of the knots -) where {T2<:Real} - num_of_transformed_dims = length(mask_idx) - input_dims = dim - num_of_transformed_dims - w = [MLP_3layer(input_dims, hdims, K) for i in 1:num_of_transformed_dims] - h = [MLP_3layer(input_dims, hdims, K) for i in 1:num_of_transformed_dims] - d = [MLP_3layer(input_dims, hdims, K - 1) for i in 1:num_of_transformed_dims] - mask = Bijectors.PartitionMask(D, mask_idx) - return NeuralSplineLayer(D, mask, w, h, d, B) -end - -@functor NeuralSplineLayer (w, h, d) - -# define forward and inverse transformation -function instantiate_rqs(nsl::NeuralSplineLayer{<:Vector{<:Flux.Chain}}, x::AbstractVector) - # instantiate rqs knots and derivatives - ws = permutedims(reduce(hcat, [w(x) for w in nsl.w])) - hs = permutedims(reduce(hcat, [h(x) for h in nsl.h])) - ds = permutedims(reduce(hcat, [d(x) for d in nsl.d])) - return Bijectors.RationalQuadraticSpline(ws, hs, ds, nsl.B) -end - -function Bijectors.transform( - nsl::NeuralSplineLayer{<:Vector{<:Flux.Chain}}, x::AbstractVector -) - x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x) - # instantiate rqs knots and derivatives - rqs = instantiate_rqs(nsl, x_2) - y_1 = transform(rqs, x_1) - return Bijectors.combine(nsl.mask, y_1, x_2, x_3) -end - -function Bijectors.transform( - insl::Inverse{<:NeuralSplineLayer{<:Vector{<:Flux.Chain}}}, y::AbstractVector -) - nsl = insl.orig - y1, y2, y3 = partition(nsl.mask, y) - rqs = instantiate_rqs(nsl, y2) - x1 = transform(Inverse(rqs), y1) - return combine(nsl.mask, x1, y2, y3) -end - -function (nsl::NeuralSplineLayer{<:Vector{<:Flux.Chain}})(x::AbstractVector) - return Bijectors.transform(nsl, x) -end - -# define logabsdetjac -function Bijectors.logabsdetjac( - nsl::NeuralSplineLayer{<:Vector{<:Flux.Chain}}, x::AbstractVector -) - x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x) - Rqs = instantiate_rqs(nsl, x_2) - logjac = logabsdetjac(Rqs, x_1) - return logjac -end - -function Bijectors.logabsdetjac( - insl::Inverse{<:NeuralSplineLayer{<:Vector{<:Flux.Chain}}}, y::AbstractVector -) - nsl = insl.orig - y1, y2, y3 = partition(nsl.mask, y) - rqs = instantiate_rqs(nsl, y2) - logjac = logabsdetjac(Inverse(rqs), y1) - return logjac -end - -function Bijectors.with_logabsdet_jacobian( - nsl::NeuralSplineLayer{<:Vector{<:Flux.Chain}}, x::AbstractVector -) - x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x) - rqs = instantiate_rqs(nsl, x_2) - y_1, logjac = with_logabsdet_jacobian(rqs, x_1) - return Bijectors.combine(nsl.mask, y_1, x_2, x_3), logjac -end diff --git a/example/planar_radial_flow/planar_flow.jl b/example/planar_radial_flow/planar_flow.jl deleted file mode 100644 index d1ccff81..00000000 --- a/example/planar_radial_flow/planar_flow.jl +++ /dev/null @@ -1,56 +0,0 @@ -using Random, Distributions, LinearAlgebra, Bijectors -using ADTypes -using Optimisers -using FunctionChains 
-using NormalizingFlows -using Zygote -using Flux: f32 -using Plots -include("../common.jl") - -Random.seed!(123) -rng = Random.default_rng() - -###################################### -# 2d Banana as the target distribution -###################################### -include("../targets/banana.jl") - -# create target p -p = Banana(2, 1.0f-1, 100.0f0) -logp = Base.Fix1(logpdf, p) - -###################################### -# learn the target using planar flow -###################################### -function create_planar_flow(n_layers::Int, q₀) - d = length(q₀) - Ls = [f32(PlanarLayer(d)) for _ in 1:n_layers] - ts = fchain(Ls) - return transformed(q₀, ts) -end - -# create a 10-layer planar flow -flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I)) -flow_untrained = deepcopy(flow) - -# train the flow -sample_per_iter = 10 -cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,) -flow_trained, stats, _ = train_flow( - elbo, - flow, - logp, - sample_per_iter; - max_iters=200_00, - optimiser=Optimisers.ADAM(), - callback=cb, - ADbackend=AutoZygote(), -) -losses = map(x -> x.loss, stats) - -###################################### -# evaluate trained flow -###################################### -plot(losses; label="Loss", linewidth=2) # plot the loss -compare_trained_and_untrained_flow(flow_trained, flow_untrained, p, 1000) diff --git a/example/planar_radial_flow/radial_flow.jl b/example/planar_radial_flow/radial_flow.jl deleted file mode 100644 index 4d51cd24..00000000 --- a/example/planar_radial_flow/radial_flow.jl +++ /dev/null @@ -1,55 +0,0 @@ -using Random, Distributions, LinearAlgebra, Bijectors -using ADTypes -using Optimisers -using FunctionChains -using NormalizingFlows -using Zygote -using Flux: f32 -using Plots -include("../common.jl") - -Random.seed!(123) -rng = Random.default_rng() - -###################################### -# 2d Banana as the target distribution -###################################### -include("../targets/banana.jl") - -# create target p -p = Banana(2, 1.0f-1, 100.0f0) -logp = Base.Fix1(logpdf, p) - -###################################### -# learn the target using radial flow -###################################### -function create_radial_flow(n_layers::Int, q₀) - d = length(q₀) - Ls = [f32(RadialLayer(d)) for _ in 1:n_layers] - ts = fchain(Ls) - return transformed(q₀, ts) -end - -# create a 20-layer radial flow -flow = create_radial_flow(10, MvNormal(zeros(Float32, 2), I)) -flow_untrained = deepcopy(flow) - -# train the flow -sample_per_iter = 10 -cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,) -flow_trained, stats, _ = train_flow( - elbo, - flow, - logp, - sample_per_iter; - max_iters=200_00, - optimiser=Optimisers.ADAM(), - callback=cb, -) -losses = map(x -> x.loss, stats) - -###################################### -# evaluate trained flow -###################################### -plot(losses; label="Loss", linewidth=2) # plot the loss -compare_trained_and_untrained_flow(flow_trained, flow_untrained, p, 1000; legend=:bottom) diff --git a/example/targets/banana.jl b/example/targets/banana.jl index c3d1d1ba..4cc9b25e 100644 --- a/example/targets/banana.jl +++ b/example/targets/banana.jl @@ -1,7 +1,3 @@ -using Distributions, Random -using Plots -using IrrationalConstants - """ Banana{T<:Real} @@ -85,12 +81,3 @@ function Distributions._logpdf(p::Banana, x::AbstractVector) logz = (log(s) / d + IrrationalConstants.log2π) * d / 2 return -logz - sum(ϕ⁻¹_x .^ 2 ./ vcat(s, ones(T, d - 1))) / 2 end - -function visualize(p::Banana, samples=rand(p, 
1000)) - xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100) - yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100) - z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange] - fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2) - scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright) - return fig -end diff --git a/example/targets/cross.jl b/example/targets/cross.jl index f905b86e..1c0fcdb2 100644 --- a/example/targets/cross.jl +++ b/example/targets/cross.jl @@ -1,4 +1,3 @@ -using Distributions, Random """ Cross(μ::Real=2.0, σ::Real=0.15) diff --git a/example/targets/neal_funnel.jl b/example/targets/neal_funnel.jl index 6cdf84a2..877c5aff 100644 --- a/example/targets/neal_funnel.jl +++ b/example/targets/neal_funnel.jl @@ -1,5 +1,3 @@ -using Distributions, Random - """ Funnel{T<:Real} @@ -45,8 +43,7 @@ Funnel(dim::Int) = Funnel(dim, 0.0, 9.0) Base.length(p::Funnel) = p.dim Base.eltype(p::Funnel{T}) where {T<:Real} = T -function Distributions._rand!(rng::AbstractRNG, p::Funnel, x::AbstractVecOrMat) - T = eltype(x) +function Distributions._rand!(rng::AbstractRNG, p::Funnel{T}, x::AbstractVecOrMat{T}) where {T<:Real} d, μ, σ = p.dim, p.μ, p.σ d == size(x, 1) || error("Dimension mismatch") x[1, :] .= randn(rng, T, size(x, 2)) .* σ .+ μ @@ -54,9 +51,22 @@ function Distributions._rand!(rng::AbstractRNG, p::Funnel, x::AbstractVecOrMat) return x end -function Distributions._logpdf(p::Funnel, x::AbstractVector) +function Distributions._logpdf(p::Funnel{T}, x::AbstractVector{T}) where {T<:Real} d, μ, σ = p.dim, p.μ, p.σ - lpdf1 = logpdf(Normal(μ, σ), x[1]) - lpdfs = logpdf.(Normal.(zeros(T, d - 1), exp(x[1] / 2)), @view(x[2:end])) - return lpdf1 + sum(lpdfs) + x1 = x[1] + x2 = x[2:end] + lpdf_x1 = logpdf(Normal(μ, σ), x1) + lpdf_x2_given_1 = logpdf(MvNormal(zeros(T, d-1), exp(x1)I), x2) + return lpdf_x1 + lpdf_x2_given_1 +end + +function score(p::Funnel{T}, x::AbstractVector{T}) where {T<:Real} + d, μ, σ = p.dim, p.μ, p.σ + x1 = x[1] + x_2_d = x[2:end] + a = expm1(-x1) + 1 + + ∇lpdf1 = (μ - x1)/σ^2 - (d-1)/2 + a*sum(abs2, x_2_d)/2 + ∇lpdfs = -a*x_2_d + return vcat(∇lpdf1, ∇lpdfs) end diff --git a/example/targets/warped_gaussian.jl b/example/targets/warped_gaussian.jl index a63012ed..90ed5c12 100644 --- a/example/targets/warped_gaussian.jl +++ b/example/targets/warped_gaussian.jl @@ -1,5 +1,3 @@ -using Distributions, Random, LinearAlgebra, IrrationalConstants - """ WarpedGauss{T<:Real} @@ -39,11 +37,11 @@ WarpedGauss(σ1::T, σ2::T) where {T<:Real} = WarpedGauss{T}(σ1, σ2) WarpedGauss() = WarpedGauss(1.0, 0.12) Base.length(p::WarpedGauss) = 2 -Base.eltype(p::WarpedGauss{T}) where {T<:Real} = T +Base.eltype(::WarpedGauss{T}) where {T<:Real} = T Distributions.sampler(p::WarpedGauss) = p # Define the transformation function φ and the inverse ϕ⁻¹ for the warped Gaussian distribution -function ϕ!(p::WarpedGauss, z::AbstractVector) +function ϕ!(::WarpedGauss, z::AbstractVector) length(z) == 2 || error("Dimension mismatch") x, y = z r = norm(z) @@ -53,7 +51,7 @@ function ϕ!(p::WarpedGauss, z::AbstractVector) return z end -function ϕ⁻¹(p::WarpedGauss, z::AbstractVector) +function ϕ⁻¹(::WarpedGauss, z::AbstractVector) length(z) == 2 || error("Dimension mismatch") x, y = z r = norm(z) @@ -71,7 +69,7 @@ end function Distributions._rand!(rng::AbstractRNG, p::WarpedGauss, x::AbstractVecOrMat) size(x, 1) == 2 || error("Dimension mismatch") - σ₁, σ₂ = p.σ₁, p.σ₂ + σ₁, σ₂ = 
p.σ1, p.σ2 randn!(rng, x) x .*= [σ₁, σ₂] for y in eachcol(x) @@ -82,7 +80,7 @@ end function Distributions._logpdf(p::WarpedGauss, x::AbstractVector) size(x, 1) == 2 || error("Dimension mismatch") - σ₁, σ₂ = p.σ₁, p.σ₂ + σ₁, σ₂ = p.σ1, p.σ2 S = [σ₁, σ₂] .^ 2 z, logJ = ϕ⁻¹(p, x) return -sum(z .^ 2 ./ S) / 2 - IrrationalConstants.log2π - log(σ₁) - log(σ₂) + logJ diff --git a/example/utils.jl b/example/utils.jl new file mode 100644 index 00000000..490f1e55 --- /dev/null +++ b/example/utils.jl @@ -0,0 +1,108 @@ +using Random, Distributions, LinearAlgebra +using Bijectors: transformed +using Flux + +""" +A simple wrapper for a 3 layer dense MLP +""" +function mlp3(input_dim::Int, hidden_dims::Int, output_dim::Int; activation=Flux.leakyrelu) + return Chain( + Flux.Dense(input_dim, hidden_dims, activation), + Flux.Dense(hidden_dims, hidden_dims, activation), + Flux.Dense(hidden_dims, output_dim), + ) +end + +function create_flow(Ls, q₀) + ts = reduce(∘, Ls) + return transformed(q₀, ts) +end + +function compare_trained_and_untrained_flow( + flow_trained::Bijectors.MultivariateTransformed, + flow_untrained::Bijectors.MultivariateTransformed, + true_dist::ContinuousMultivariateDistribution, + n_samples::Int; + kwargs..., +) + samples_trained = rand(flow_trained, n_samples) + samples_untrained = rand(flow_untrained, n_samples) + samples_true = rand(true_dist, n_samples) + + p = scatter( + samples_true[1, :], + samples_true[2, :]; + label="True Distribution", + color=:blue, + markersize=2, + alpha=0.5, + ) + scatter!( + p, + samples_untrained[1, :], + samples_untrained[2, :]; + label="Untrained Flow", + color=:red, + markersize=2, + alpha=0.5, + ) + scatter!( + p, + samples_trained[1, :], + samples_trained[2, :]; + label="Trained Flow", + color=:green, + markersize=2, + alpha=0.5, + ) + plot!(; kwargs...) + + xlabel!(p, "X") + ylabel!(p, "Y") + title!(p, "Comparison of Trained and Untrained Flow") + + return p +end + +# function check_trained_flow( +# flow_trained::Bijectors.MultivariateTransformed, +# true_dist::ContinuousMultivariateDistribution, +# n_samples::Int; +# kwargs..., +# ) +# samples_trained = rand_batch(flow_trained, n_samples) +# samples_true = rand(true_dist, n_samples) + +# p = Plots.scatter( +# samples_true[1, :], +# samples_true[2, :]; +# label="True Distribution", +# color=:green, +# markersize=2, +# alpha=0.5, +# ) +# Plots.scatter!( +# p, +# samples_trained[1, :], +# samples_trained[2, :]; +# label="Trained Flow", +# color=:red, +# markersize=2, +# alpha=0.5, +# ) +# Plots.plot!(; kwargs...) + +# Plots.title!(p, "Trained HamFlow") + +# return p +# end + + +function visualize(p::Bijectors.MultivariateTransformed, samples=rand(p, 1000)) + xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100) + yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100) + z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange] + fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2) + scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright) + return fig +end
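The new `example/SyntheticTargets.jl` and `example/utils.jl` helpers compose the same way the demo scripts use them. A minimal sketch (not part of the diff; the target name, the `PlanarLayer` layers, and the layer count are illustrative assumptions, and it should be run from the `example/` directory):

```julia
using Random, Distributions, Bijectors, Functors

include("SyntheticTargets.jl")   # load_model and visualize for the synthetic targets
include("utils.jl")              # mlp3, create_flow, compare_trained_and_untrained_flow

target = load_model("Banana")        # 2-d banana target from targets/banana.jl
visualize(target)                    # density contour overlaid with samples

@leaf MvNormal                       # as in the demos: keep the reference's parameters out of training
q0 = MvNormal(zeros(2), ones(2))     # standard Gaussian reference
Ls = [PlanarLayer(2) for _ in 1:5]   # any vector of Bijectors layers can be composed
flow = create_flow(Ls, q0)           # reduce(∘, Ls), then transformed(q0, ts)
samples = rand(flow, 1000)           # 2×1000 matrix of (untrained) flow samples
```

After training with `train_flow` (as each demo does), `compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)` produces the scatter comparison shown at the end of the demo scripts.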