From 4b820f56a660b296078c5a8594e8454d05377d8e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 5 Sep 2025 17:40:25 +0100 Subject: [PATCH 1/5] Transfer MarginalLogDensities extension from Turing See: https://github.com/TuringLang/Turing.jl/pull/2664 Co-authored-by: Sam Urmy --- Project.toml | 2 ++ ext/DynamicPPLMarginalLogDensitiesExt.jl | 36 +++++++++++++++++++ src/DynamicPPL.jl | 8 ++--- test/Project.toml | 1 + test/ext/DynamicPPLMarginalLogDensitiesExt.jl | 27 ++++++++++++++ test/runtests.jl | 1 + 6 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 ext/DynamicPPLMarginalLogDensitiesExt.jl create mode 100644 test/ext/DynamicPPLMarginalLogDensitiesExt.jl diff --git a/Project.toml b/Project.toml index 5f11cba3f..56a681375 100644 --- a/Project.toml +++ b/Project.toml @@ -34,6 +34,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] @@ -41,6 +42,7 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] +DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMooncakeExt = ["Mooncake"] diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl new file mode 100644 index 000000000..8153d1522 --- /dev/null +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -0,0 +1,36 @@ +module DynamicPPLMarginalLogDensitiesExt + +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName +using MarginalLogDensities: MarginalLogDensities + +_to_varname(n::Symbol) = VarName{n}() +_to_varname(n::VarName) = n + +function DynamicPPL.marginalize( + model::DynamicPPL.Model, + varnames::AbstractVector{<:Union{Symbol,<:VarName}}, + getlogprob=DynamicPPL.getlogjoint, + method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); + kwargs..., +) + # Determine the indices for the variables to marginalise out. + varinfo = DynamicPPL.typed_varinfo(model) + vns = map(_to_varname, varnames) + varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns)) + # Construct the marginal log-density model. + # Use linked `varinfo` to that we're working in unconstrained space + varinfo_linked = DynamicPPL.link(varinfo, model) + + f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo_linked) + mdl = MarginalLogDensities.MarginalLogDensity( + (x, _) -> LogDensityProblems.logdensity(f, x), + varinfo_linked[:], + varindices, + (), + method; + kwargs..., + ) + return mdl +end + +end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b400e83dd..f67cec6b7 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -122,6 +122,7 @@ export AbstractVarInfo, fix, unfix, predict, + marginalize, prefix, returned, to_submodel, @@ -199,10 +200,6 @@ include("test_utils.jl") include("experimental.jl") include("deprecated.jl") -if !isdefined(Base, :get_extension) - using Requires -end - # Better error message if users forget to load JET if isdefined(Base.Experimental, :register_error_hint) function __init__() @@ -247,4 +244,7 @@ end # Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ struct DynamicPPLTag end +# Extended in MarginalLogDensitiesExt +function marginalize end + end # module diff --git a/test/Project.toml b/test/Project.toml index 91a885e96..168f360b8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl new file mode 100644 index 000000000..381f67e97 --- /dev/null +++ b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -0,0 +1,27 @@ +module MarginalLogDensitiesExtTests + +using DynamicPPL, Distributions, Test +using MarginalLogDensities +using ADTypes: AutoForwardDiff + +@testset "MarginalLogDensities" begin + # Simple test case. + @model function demo() + x ~ MvNormal(zeros(2), [1, 1]) + return y ~ Normal(0, 1) + end + model = demo() + # Marginalize out `x`. + + for vn in [@varname(x), :x] + for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint] + marginalized = marginalize( + model, [vn], getlogprob; hess_adtype=AutoForwardDiff() + ) + # Compute the marginal log-density of `y = 0.0`. + @test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol = 1e-5 + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index c60c06786..40960884e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,6 +80,7 @@ include("test_util.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") include("ext/DynamicPPLJETExt.jl") + include("ext/DynamicPPLMarginalLogDensitiesExt.jl") end @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") From 8fab001e737c32bb1600e34af8013a09aae915f7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 5 Sep 2025 17:48:14 +0100 Subject: [PATCH 2/5] Add documentation --- docs/Project.toml | 1 + docs/make.jl | 1 + docs/src/api.md | 9 +++++ ext/DynamicPPLMarginalLogDensitiesExt.jl | 45 ++++++++++++++++++++++++ 4 files changed, 56 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index 1f01b11ef..cc0be339d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -10,6 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] diff --git a/docs/make.jl b/docs/make.jl index 9c59cb06b..4185672e1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,6 +11,7 @@ using Distributions using DocumenterMermaid # load MCMCChains package extension to make `predict` available using MCMCChains +using MarginalLogDensities: MarginalLogDensities # Doctest setup DocMeta.setdocmeta!( diff --git a/docs/src/api.md b/docs/src/api.md index 9a1923b53..d93c58863 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -123,6 +123,15 @@ The `predict` function has two main methods: predict ``` +## Marginalization + +DynamicPPL provides the `marginalize` function to marginalize out variables from a model. +This requires `MarginalLogDensities.jl` to be loaded in your environment. + +```@docs +marginalize +``` + ### Basic Usage The typical workflow for posterior prediction involves: diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 8153d1522..f79e7a027 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -6,6 +6,51 @@ using MarginalLogDensities: MarginalLogDensities _to_varname(n::Symbol) = VarName{n}() _to_varname(n::VarName) = n +""" + marginalize( + model::DynamicPPL.Model, + varnames::AbstractVector{<:Union{Symbol,<:VarName}}, + getlogprob=DynamicPPL.getlogjoint, + method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); + kwargs..., + ) + +Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal +log-density of the given `model`, after marginalizing out the variables specified in +`varnames`. + +The resulting object can be called with a vector of parameter values to compute the marginal +log-density. + +The `getlogprob` argument can be used to specify which kind of marginal log-density to +compute. Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint +probability. + +By default the marginalization is performed with a Laplace approximation. Please see [the +MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/) +for other options. + +## Example + +```jldoctest +julia> using DynamicPPL, Distributions, MarginalLogDensities + +julia> @model function demo() + x ~ Normal(1.0) + y ~ Normal(2.0) + end +demo (generic function with 2 methods) + +julia> marginalized = marginalize(demo(), [:x]); + +julia> # The resulting callable computes the marginal log-density of `y`. + marginalized([1.0]) +-1.4189385332046727 + +julia> logpdf(Normal(2.0), 1.0) +-1.4189385332046727 +``` +""" function DynamicPPL.marginalize( model::DynamicPPL.Model, varnames::AbstractVector{<:Union{Symbol,<:VarName}}, From 105013a63b9d2ff95b636ab79749fd7a151d5cdf Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 5 Sep 2025 17:54:28 +0100 Subject: [PATCH 3/5] Bump patch, add changelog --- HISTORY.md | 5 +++++ Project.toml | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 0f22721d9..b2983a426 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,10 @@ # DynamicPPL Changelog +## 0.37.3 + +An extension for MarginalLogDensities.jl has been added. +Loading DynamicPPL and MarginalLogDensities now provides the `DynamicPPL.marginalize` function to marginalize out variables from a model; please see the documentation for further information. + ## 0.37.2 Make the `resume_from` keyword work for multiple-chain (parallel) sampling as well. diff --git a/Project.toml b/Project.toml index 56a681375..df764ffe2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.37.2" +version = "0.37.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 1b3e76b1c4e951223dcbc4f57372155352820fa2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 5 Sep 2025 18:19:34 +0100 Subject: [PATCH 4/5] Add compat entry for MLD --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index df764ffe2..11a37316f 100644 --- a/Project.toml +++ b/Project.toml @@ -67,6 +67,7 @@ JET = "0.9, 0.10" KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" LogDensityProblems = "2" +MarginalLogDensities = "0.4.1" MCMCChains = "6, 7" MacroTools = "0.5.6" Mooncake = "0.4.147" From aaca138c84dce3ac15acbd0653d5ebb5df5c153e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 5 Sep 2025 18:25:06 +0100 Subject: [PATCH 5/5] Fix docs --- docs/make.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/make.jl b/docs/make.jl index 4185672e1..de64a21f5 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -25,7 +25,11 @@ makedocs(; format=Documenter.HTML(; size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3() ), - modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], + modules=[ + DynamicPPL, + Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt), + Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt), + ], pages=[ "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] ],