Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# DynamicPPL Changelog

## 0.37.3

An extension for MarginalLogDensities.jl has been added.
Loading DynamicPPL and MarginalLogDensities now provides the `DynamicPPL.marginalize` function to marginalize out variables from a model; please see the documentation for further information.

## 0.37.2

Make the `resume_from` keyword work for multiple-chain (parallel) sampling as well.
Expand Down
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.37.2"
version = "0.37.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -34,13 +34,15 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]

Expand All @@ -65,6 +67,7 @@ JET = "0.9, 0.10"
KernelAbstractions = "0.9.33"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
MarginalLogDensities = "0.4.1"
MCMCChains = "6, 7"
MacroTools = "0.5.6"
Mooncake = "0.4.147"
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
Expand Down
7 changes: 6 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using Distributions
using DocumenterMermaid
# load MCMCChains package extension to make `predict` available
using MCMCChains
using MarginalLogDensities: MarginalLogDensities

# Doctest setup
DocMeta.setdocmeta!(
Expand All @@ -24,7 +25,11 @@ makedocs(;
format=Documenter.HTML(;
size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3()
),
modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)],
modules=[
DynamicPPL,
Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt),
Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt),
],
pages=[
"Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"]
],
Expand Down
9 changes: 9 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ The `predict` function has two main methods:
predict
```

## Marginalization

DynamicPPL provides the `marginalize` function to marginalize out variables from a model.
This requires `MarginalLogDensities.jl` to be loaded in your environment.

```@docs
marginalize
```

### Basic Usage

The typical workflow for posterior prediction involves:
Expand Down
81 changes: 81 additions & 0 deletions ext/DynamicPPLMarginalLogDensitiesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
module DynamicPPLMarginalLogDensitiesExt

using DynamicPPL: DynamicPPL, LogDensityProblems, VarName
using MarginalLogDensities: MarginalLogDensities

_to_varname(n::Symbol) = VarName{n}()
_to_varname(n::VarName) = n

"""
marginalize(
model::DynamicPPL.Model,
varnames::AbstractVector{<:Union{Symbol,<:VarName}},
getlogprob=DynamicPPL.getlogjoint,
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
kwargs...,
)

Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal
log-density of the given `model`, after marginalizing out the variables specified in
`varnames`.

The resulting object can be called with a vector of parameter values to compute the marginal
log-density.

The `getlogprob` argument can be used to specify which kind of marginal log-density to
compute. Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint
probability.

By default the marginalization is performed with a Laplace approximation. Please see [the
MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/)
for other options.

## Example

```jldoctest
julia> using DynamicPPL, Distributions, MarginalLogDensities

julia> @model function demo()
x ~ Normal(1.0)
y ~ Normal(2.0)
end
demo (generic function with 2 methods)

julia> marginalized = marginalize(demo(), [:x]);

julia> # The resulting callable computes the marginal log-density of `y`.
marginalized([1.0])
-1.4189385332046727

julia> logpdf(Normal(2.0), 1.0)
-1.4189385332046727
```
"""
function DynamicPPL.marginalize(
model::DynamicPPL.Model,
varnames::AbstractVector{<:Union{Symbol,<:VarName}},
getlogprob=DynamicPPL.getlogjoint,
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
kwargs...,
)
# Determine the indices for the variables to marginalise out.
varinfo = DynamicPPL.typed_varinfo(model)
vns = map(_to_varname, varnames)
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that in principle MLD should be able to handle this case fine:

using DynamicPPL, Distributions, MarginalLogDensities, LinearAlgebra

@model function f()
    x ~ MvNormal([0.0, 1.0, 2.0], I)
end

marginalize(f(), [@varname(x[1])])

Unfortunately, attempting to do this now falls over:

ERROR: KeyError: key VarName{:x, Accessors.IndexLens{Tuple{Int64}}}[x[1]] not found
Stacktrace:
 [1] vector_getranges(varinfo::VarInfo{@NamedTuple{…}, DynamicPPL.AccumulatorTuple{…}}, vns::Vector{VarName{…}})
   @ DynamicPPL ~/ppl/dppl/src/varinfo.jl:732

This is definitely a limitation of using the internal vector_getranges function. I wonder if maybe we should explicitly allow the user to specify the indices that they want to remove?

# Construct the marginal log-density model.
# Use linked `varinfo` to that we're working in unconstrained space
varinfo_linked = DynamicPPL.link(varinfo, model)
Comment on lines +66 to +67
Copy link
Member Author

@penelopeysm penelopeysm Sep 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does MLD require a linked VarInfo, or would an unlinked VarInfo be fine?

The reason I'm thinking about this is because if the VarInfo is linked, then all the parameters that are later supplied must be in linked space, which is potentially a bit confusing (though nothing that can't be fixed by documentation). Example:

julia> using DynamicPPL, Distributions, Bijectors, MarginalLogDensities

julia> @model function f()
           x ~ Normal()
           y ~ Beta(2, 2)
       end
f (generic function with 2 methods)

julia> m = marginalize(f(), [@varname(x)]);

julia> m([0.5]) # this 0.5 is in linked space
0.3436055008678415

julia> logpdf(Beta(2, 2), 0.5) # this 0.5 is unlinked, so logp is wrong
0.4054651081081644

julia> inverse(Bijectors.bijector(Beta(2, 2)))(0.5) # this is the unlinked value corresponding to 0.5
0.6224593312018546

julia> logpdf(Beta(2, 2), 0.6224593312018546) # now logp matches
0.3436055008678416

If an unlinked VarInfo is acceptable, then the choice of varinfo should probably also be added as an argument to marginalize.


f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo_linked)
mdl = MarginalLogDensities.MarginalLogDensity(
(x, _) -> LogDensityProblems.logdensity(f, x),
varinfo_linked[:],
varindices,
(),
method;
kwargs...,
)
return mdl
end

end
8 changes: 4 additions & 4 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ export AbstractVarInfo,
fix,
unfix,
predict,
marginalize,
prefix,
returned,
to_submodel,
Expand Down Expand Up @@ -199,10 +200,6 @@ include("test_utils.jl")
include("experimental.jl")
include("deprecated.jl")

if !isdefined(Base, :get_extension)
using Requires
end

# Better error message if users forget to load JET
if isdefined(Base.Experimental, :register_error_hint)
function __init__()
Expand Down Expand Up @@ -247,4 +244,7 @@ end
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
struct DynamicPPLTag end

# Extended in MarginalLogDensitiesExt
function marginalize end

end # module
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
27 changes: 27 additions & 0 deletions test/ext/DynamicPPLMarginalLogDensitiesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
module MarginalLogDensitiesExtTests

using DynamicPPL, Distributions, Test
using MarginalLogDensities
using ADTypes: AutoForwardDiff

@testset "MarginalLogDensities" begin
# Simple test case.
@model function demo()
x ~ MvNormal(zeros(2), [1, 1])
return y ~ Normal(0, 1)
end
model = demo()
# Marginalize out `x`.

for vn in [@varname(x), :x]
for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint]
marginalized = marginalize(
model, [vn], getlogprob; hess_adtype=AutoForwardDiff()
)
# Compute the marginal log-density of `y = 0.0`.
@test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol = 1e-5
end
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ include("test_util.jl")
@testset "extensions" begin
include("ext/DynamicPPLMCMCChainsExt.jl")
include("ext/DynamicPPLJETExt.jl")
include("ext/DynamicPPLMarginalLogDensitiesExt.jl")
end
@testset "ad" begin
include("ext/DynamicPPLForwardDiffExt.jl")
Expand Down
Loading