From 32c61cc90c033b389fbed3ade57f55aefc832fd5 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Mon, 25 Aug 2025 16:49:28 +1000 Subject: [PATCH 1/3] Get MLD extension working --- Project.toml | 5 ++- ext/TuringMarginalLogDensitiesExt.jl | 46 ++++++++++++++++++++++++++++ src/Turing.jl | 6 +++- src/extensions.jl | 3 ++ test/ext/marginallogdensities.jl | 19 ++++++++++++ 5 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 ext/TuringMarginalLogDensitiesExt.jl create mode 100644 src/extensions.jl create mode 100644 test/ext/marginallogdensities.jl diff --git a/Project.toml b/Project.toml index 422284095..34d529c7f 100644 --- a/Project.toml +++ b/Project.toml @@ -41,10 +41,12 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" Optim = "429524aa-4258-5aef-a3af-852621145aeb" [extensions] TuringDynamicHMCExt = "DynamicHMC" +TuringMarginalLogDensitiesExt = "MarginalLogDensities" TuringOptimExt = "Optim" [compat] @@ -71,6 +73,7 @@ Libtask = "0.9.3" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7" +MarginalLogDensities = "0.4.1" NamedArrays = "0.9, 0.10" Optim = "1" Optimization = "3, 4" @@ -89,4 +92,4 @@ julia = "1.10.8" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" -Optim = "429524aa-4258-5aef-a3af-852621145aeb" +Optim = "429524aa-4258-5aef-a3af-852621145aeb" \ No newline at end of file diff --git a/ext/TuringMarginalLogDensitiesExt.jl b/ext/TuringMarginalLogDensitiesExt.jl new file mode 100644 index 000000000..f470a29a7 --- /dev/null +++ b/ext/TuringMarginalLogDensitiesExt.jl @@ -0,0 +1,46 @@ +module TuringMarginalLogDensitiesExt + +using Turing: Turing, DynamicPPL +using Turing.Inference: LogDensityProblems +using MarginalLogDensities: MarginalLogDensities + + +# Use a struct for this to avoid closure overhead. +struct Drop2ndArgAndFlipSign{F} + f::F +end + +(f::Drop2ndArgAndFlipSign)(x, _) = -f.f(x) + +_to_varname(n::Symbol) = DynamicPPL.@varname($n) +_to_varname(n::DynamicPPL.AbstractPPL.VarName) = n + +function Turing.marginalize( + model::DynamicPPL.Model, + varnames::Vector, + method::MarginalLogDensities.AbstractMarginalizer = MarginalLogDensities.LaplaceApprox(), +) + # Determine the indices for the variables to marginalise out. + varinfo = DynamicPPL.typed_varinfo(model) + vns = _to_varname.(varnames) + varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns)) + # Construct the marginal log-density model. + # Use linked `varinfo` to that we're working in unconstrained space + varinfo_linked = DynamicPPL.link(varinfo, model) + + f = Turing.Optimisation.OptimLogDensity( + model, + Turing.DynamicPPL.getlogjoint, + # Turing.DynamicPPL.typed_varinfo(model) + varinfo_linked + ) + + # HACK: need the sign-flip here because `OptimizationContext` is a hacky impl which + # represent the _negative_ log-density. + mdl = MarginalLogDensities.MarginalLogDensity( + Drop2ndArgAndFlipSign(f), varinfo_linked[:], varindices, (), method + ) + return mdl +end + +end diff --git a/src/Turing.jl b/src/Turing.jl index 0cdbe2458..c08155d3c 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -53,6 +53,8 @@ using .Variational include("optimisation/Optimisation.jl") using .Optimisation +include("extensions.jl") + ########### # Exports # ########### @@ -153,6 +155,8 @@ export maximum_a_posteriori, maximum_likelihood, MAP, - MLE + MLE, + # MarginalLogDensities extension + marginalize end diff --git a/src/extensions.jl b/src/extensions.jl new file mode 100644 index 000000000..bd768c04f --- /dev/null +++ b/src/extensions.jl @@ -0,0 +1,3 @@ +function marginalize(model, varnames, method) + error("This function is available after importing MarginalLogDensities.") +end \ No newline at end of file diff --git a/test/ext/marginallogdensities.jl b/test/ext/marginallogdensities.jl new file mode 100644 index 000000000..c24b54e92 --- /dev/null +++ b/test/ext/marginallogdensities.jl @@ -0,0 +1,19 @@ +module TuringMarginalLogDensitiesTest + +using Turing, MarginalLogDensities, Test + +@testset "MarginalLogDensities" begin + # Simple test case. + @model function demo() + x ~ MvNormal(zeros(2), [1, 1]) + y ~ Normal(0, 1) + end + model = demo(); + # Marginalize out `x`. + marginalized = marginalize(model, [@varname(x)]); + marginalized = marginalize(model, [:x]); + # Compute the marginal log-density of `y = 0.0`. + @test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol=2e-1 +end + +end \ No newline at end of file From 88623589272459060b0adc15ea99289ea6a7eb95 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Mon, 25 Aug 2025 18:45:19 +1000 Subject: [PATCH 2/3] add marginalization tests to runtests.jl --- test/runtests.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 5fb6b2141..b59b91ee0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -69,6 +69,11 @@ end @timeit_include("optimisation/Optimisation.jl") @timeit_include("ext/OptimInterface.jl") end + + @testset "marginalization" verbose = true begin + @timeit_include("ext/marginallogdensities.jl") + end + end @testset "stdlib" verbose = true begin From edb0418db0c583bd8b9e567c057bbdc934cca754 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Wed, 27 Aug 2025 14:29:30 +1000 Subject: [PATCH 3/3] Add arguments for log-probability and kwargs passed to MLD --- ext/TuringMarginalLogDensitiesExt.jl | 9 +++++---- src/extensions.jl | 2 +- test/ext/marginallogdensities.jl | 15 ++++++++++----- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/ext/TuringMarginalLogDensitiesExt.jl b/ext/TuringMarginalLogDensitiesExt.jl index f470a29a7..848d6fcaa 100644 --- a/ext/TuringMarginalLogDensitiesExt.jl +++ b/ext/TuringMarginalLogDensitiesExt.jl @@ -18,7 +18,9 @@ _to_varname(n::DynamicPPL.AbstractPPL.VarName) = n function Turing.marginalize( model::DynamicPPL.Model, varnames::Vector, - method::MarginalLogDensities.AbstractMarginalizer = MarginalLogDensities.LaplaceApprox(), + getlogprob = DynamicPPL.getlogjoint, + method::MarginalLogDensities.AbstractMarginalizer = MarginalLogDensities.LaplaceApprox(); + kwargs... ) # Determine the indices for the variables to marginalise out. varinfo = DynamicPPL.typed_varinfo(model) @@ -30,15 +32,14 @@ function Turing.marginalize( f = Turing.Optimisation.OptimLogDensity( model, - Turing.DynamicPPL.getlogjoint, - # Turing.DynamicPPL.typed_varinfo(model) + getlogprob, varinfo_linked ) # HACK: need the sign-flip here because `OptimizationContext` is a hacky impl which # represent the _negative_ log-density. mdl = MarginalLogDensities.MarginalLogDensity( - Drop2ndArgAndFlipSign(f), varinfo_linked[:], varindices, (), method + Drop2ndArgAndFlipSign(f), varinfo_linked[:], varindices, (), method; kwargs... ) return mdl end diff --git a/src/extensions.jl b/src/extensions.jl index bd768c04f..118db1b15 100644 --- a/src/extensions.jl +++ b/src/extensions.jl @@ -1,3 +1,3 @@ -function marginalize(model, varnames, method) +function marginalize(model, varnames, getlogprob, method; kwargs...) error("This function is available after importing MarginalLogDensities.") end \ No newline at end of file diff --git a/test/ext/marginallogdensities.jl b/test/ext/marginallogdensities.jl index c24b54e92..76cf47a50 100644 --- a/test/ext/marginallogdensities.jl +++ b/test/ext/marginallogdensities.jl @@ -1,6 +1,7 @@ module TuringMarginalLogDensitiesTest using Turing, MarginalLogDensities, Test +using Turing.DynamicPPL: getlogprior, getlogjoint @testset "MarginalLogDensities" begin # Simple test case. @@ -10,10 +11,14 @@ using Turing, MarginalLogDensities, Test end model = demo(); # Marginalize out `x`. - marginalized = marginalize(model, [@varname(x)]); - marginalized = marginalize(model, [:x]); - # Compute the marginal log-density of `y = 0.0`. - @test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol=2e-1 + + for vn in [@varname(x), :x] + for getlogprob in [getlogprior, getlogjoint] + marginalized = marginalize(model, [vn], getlogprob, hess_adtype=AutoForwardDiff()); + # Compute the marginal log-density of `y = 0.0`. + @test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol=1e-5 + end + end end -end \ No newline at end of file +end