-
Notifications
You must be signed in to change notification settings - Fork 36
MarginalLogDensities extension #1036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
module DynamicPPLMarginalLogDensitiesExt | ||
|
||
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName | ||
using MarginalLogDensities: MarginalLogDensities | ||
|
||
_to_varname(n::Symbol) = VarName{n}() | ||
_to_varname(n::VarName) = n | ||
|
||
""" | ||
marginalize( | ||
model::DynamicPPL.Model, | ||
varnames::AbstractVector{<:Union{Symbol,<:VarName}}, | ||
getlogprob=DynamicPPL.getlogjoint, | ||
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); | ||
kwargs..., | ||
) | ||
|
||
Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal | ||
log-density of the given `model`, after marginalizing out the variables specified in | ||
`varnames`. | ||
|
||
The resulting object can be called with a vector of parameter values to compute the marginal | ||
log-density. | ||
|
||
The `getlogprob` argument can be used to specify which kind of marginal log-density to | ||
compute. Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint | ||
probability. | ||
|
||
By default the marginalization is performed with a Laplace approximation. Please see [the | ||
MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/) | ||
for other options. | ||
|
||
## Example | ||
|
||
```jldoctest | ||
julia> using DynamicPPL, Distributions, MarginalLogDensities | ||
|
||
julia> @model function demo() | ||
x ~ Normal(1.0) | ||
y ~ Normal(2.0) | ||
end | ||
demo (generic function with 2 methods) | ||
|
||
julia> marginalized = marginalize(demo(), [:x]); | ||
|
||
julia> # The resulting callable computes the marginal log-density of `y`. | ||
marginalized([1.0]) | ||
-1.4189385332046727 | ||
|
||
julia> logpdf(Normal(2.0), 1.0) | ||
-1.4189385332046727 | ||
``` | ||
""" | ||
function DynamicPPL.marginalize( | ||
model::DynamicPPL.Model, | ||
varnames::AbstractVector{<:Union{Symbol,<:VarName}}, | ||
getlogprob=DynamicPPL.getlogjoint, | ||
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); | ||
kwargs..., | ||
) | ||
# Determine the indices for the variables to marginalise out. | ||
varinfo = DynamicPPL.typed_varinfo(model) | ||
vns = map(_to_varname, varnames) | ||
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns)) | ||
# Construct the marginal log-density model. | ||
# Use linked `varinfo` to that we're working in unconstrained space | ||
varinfo_linked = DynamicPPL.link(varinfo, model) | ||
Comment on lines
+66
to
+67
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does MLD require a linked VarInfo, or would an unlinked VarInfo be fine? The reason I'm thinking about this is because if the VarInfo is linked, then all the parameters that are later supplied must be in linked space, which is potentially a bit confusing (though nothing that can't be fixed by documentation). Example: julia> using DynamicPPL, Distributions, Bijectors, MarginalLogDensities
julia> @model function f()
x ~ Normal()
y ~ Beta(2, 2)
end
f (generic function with 2 methods)
julia> m = marginalize(f(), [@varname(x)]);
julia> m([0.5]) # this 0.5 is in linked space
0.3436055008678415
julia> logpdf(Beta(2, 2), 0.5) # this 0.5 is unlinked, so logp is wrong
0.4054651081081644
julia> inverse(Bijectors.bijector(Beta(2, 2)))(0.5) # this is the unlinked value corresponding to 0.5
0.6224593312018546
julia> logpdf(Beta(2, 2), 0.6224593312018546) # now logp matches
0.3436055008678416 If an unlinked VarInfo is acceptable, then the choice of varinfo should probably also be added as an argument to |
||
|
||
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo_linked) | ||
mdl = MarginalLogDensities.MarginalLogDensity( | ||
(x, _) -> LogDensityProblems.logdensity(f, x), | ||
varinfo_linked[:], | ||
varindices, | ||
(), | ||
method; | ||
kwargs..., | ||
) | ||
return mdl | ||
end | ||
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
module MarginalLogDensitiesExtTests | ||
|
||
using DynamicPPL, Distributions, Test | ||
using MarginalLogDensities | ||
using ADTypes: AutoForwardDiff | ||
|
||
@testset "MarginalLogDensities" begin | ||
# Simple test case. | ||
@model function demo() | ||
x ~ MvNormal(zeros(2), [1, 1]) | ||
return y ~ Normal(0, 1) | ||
end | ||
model = demo() | ||
# Marginalize out `x`. | ||
|
||
for vn in [@varname(x), :x] | ||
for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint] | ||
marginalized = marginalize( | ||
model, [vn], getlogprob; hess_adtype=AutoForwardDiff() | ||
) | ||
# Compute the marginal log-density of `y = 0.0`. | ||
@test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol = 1e-5 | ||
end | ||
end | ||
end | ||
|
||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems to me that in principle MLD should be able to handle this case fine:
Unfortunately, attempting to do this now falls over:
This is definitely a limitation of using the internal
vector_getranges
function. I wonder if maybe we should explicitly allow the user to specify the indices that they want to remove?