diff --git a/Project.toml b/Project.toml index 422284095..34d529c7f 100644 --- a/Project.toml +++ b/Project.toml @@ -41,10 +41,12 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" Optim = "429524aa-4258-5aef-a3af-852621145aeb" [extensions] TuringDynamicHMCExt = "DynamicHMC" +TuringMarginalLogDensitiesExt = "MarginalLogDensities" TuringOptimExt = "Optim" [compat] @@ -71,6 +73,7 @@ Libtask = "0.9.3" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7" +MarginalLogDensities = "0.4.1" NamedArrays = "0.9, 0.10" Optim = "1" Optimization = "3, 4" @@ -89,4 +92,4 @@ julia = "1.10.8" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" -Optim = "429524aa-4258-5aef-a3af-852621145aeb" +Optim = "429524aa-4258-5aef-a3af-852621145aeb" \ No newline at end of file diff --git a/ext/TuringMarginalLogDensitiesExt.jl b/ext/TuringMarginalLogDensitiesExt.jl new file mode 100644 index 000000000..848d6fcaa --- /dev/null +++ b/ext/TuringMarginalLogDensitiesExt.jl @@ -0,0 +1,47 @@ +module TuringMarginalLogDensitiesExt + +using Turing: Turing, DynamicPPL +using Turing.Inference: LogDensityProblems +using MarginalLogDensities: MarginalLogDensities + + +# Use a struct for this to avoid closure overhead. +struct Drop2ndArgAndFlipSign{F} + f::F +end + +(f::Drop2ndArgAndFlipSign)(x, _) = -f.f(x) + +_to_varname(n::Symbol) = DynamicPPL.@varname($n) +_to_varname(n::DynamicPPL.AbstractPPL.VarName) = n + +function Turing.marginalize( + model::DynamicPPL.Model, + varnames::Vector, + getlogprob = DynamicPPL.getlogjoint, + method::MarginalLogDensities.AbstractMarginalizer = MarginalLogDensities.LaplaceApprox(); + kwargs... +) + # Determine the indices for the variables to marginalise out. + varinfo = DynamicPPL.typed_varinfo(model) + vns = _to_varname.(varnames) + varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns)) + # Construct the marginal log-density model. + # Use linked `varinfo` to that we're working in unconstrained space + varinfo_linked = DynamicPPL.link(varinfo, model) + + f = Turing.Optimisation.OptimLogDensity( + model, + getlogprob, + varinfo_linked + ) + + # HACK: need the sign-flip here because `OptimizationContext` is a hacky impl which + # represent the _negative_ log-density. + mdl = MarginalLogDensities.MarginalLogDensity( + Drop2ndArgAndFlipSign(f), varinfo_linked[:], varindices, (), method; kwargs... + ) + return mdl +end + +end diff --git a/src/Turing.jl b/src/Turing.jl index 0cdbe2458..c08155d3c 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -53,6 +53,8 @@ using .Variational include("optimisation/Optimisation.jl") using .Optimisation +include("extensions.jl") + ########### # Exports # ########### @@ -153,6 +155,8 @@ export maximum_a_posteriori, maximum_likelihood, MAP, - MLE + MLE, + # MarginalLogDensities extension + marginalize end diff --git a/src/extensions.jl b/src/extensions.jl new file mode 100644 index 000000000..118db1b15 --- /dev/null +++ b/src/extensions.jl @@ -0,0 +1,3 @@ +function marginalize(model, varnames, getlogprob, method; kwargs...) + error("This function is available after importing MarginalLogDensities.") +end \ No newline at end of file diff --git a/test/ext/marginallogdensities.jl b/test/ext/marginallogdensities.jl new file mode 100644 index 000000000..76cf47a50 --- /dev/null +++ b/test/ext/marginallogdensities.jl @@ -0,0 +1,24 @@ +module TuringMarginalLogDensitiesTest + +using Turing, MarginalLogDensities, Test +using Turing.DynamicPPL: getlogprior, getlogjoint + +@testset "MarginalLogDensities" begin + # Simple test case. + @model function demo() + x ~ MvNormal(zeros(2), [1, 1]) + y ~ Normal(0, 1) + end + model = demo(); + # Marginalize out `x`. + + for vn in [@varname(x), :x] + for getlogprob in [getlogprior, getlogjoint] + marginalized = marginalize(model, [vn], getlogprob, hess_adtype=AutoForwardDiff()); + # Compute the marginal log-density of `y = 0.0`. + @test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol=1e-5 + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 5fb6b2141..b59b91ee0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -69,6 +69,11 @@ end @timeit_include("optimisation/Optimisation.jl") @timeit_include("ext/OptimInterface.jl") end + + @testset "marginalization" verbose = true begin + @timeit_include("ext/marginallogdensities.jl") + end + end @testset "stdlib" verbose = true begin