Fixes and improvements to experimental Gibbs #2231

Merged 40 commits on Jul 16, 2024 (the diff below reflects 34 of the 40 commits).

Commits
0b2279f
moved new Gibbs tests all into a single block
torfjelde Apr 23, 2024
dcad548
initial work on making Gibbs work with `externalsampler`
torfjelde Apr 23, 2024
8e2d7be
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde May 18, 2024
4a609cb
removed references to Setfield.jl
torfjelde May 18, 2024
fc21894
fixed crucial bug in experimental Gibbs sampler
torfjelde May 18, 2024
9910962
added ground-truth comparison for Gibbs sampler on demo models
torfjelde May 18, 2024
b3a4692
added convenience method for performing two sample KS test
torfjelde May 18, 2024
3e17efc
use thinning to avoid OOM issues
torfjelde May 20, 2024
429fc8f
removed incredibly slow testset that didn't really add much
torfjelde May 20, 2024
f6af20e
removed now-redundant testset
torfjelde May 20, 2024
065eef6
use Anderson-Darling test instead of Kolmogorov-Smirnov to better
torfjelde May 20, 2024
c2d23e5
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde Jun 4, 2024
99f28f9
more work on testing
torfjelde Jun 6, 2024
6df1ccc
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde Jun 6, 2024
b6a907e
fixed tests
torfjelde Jun 6, 2024
a4e223e
Merge remote-tracking branch 'origin/torfjelde/gibbs-new-improv' into…
torfjelde Jun 6, 2024
e1e7386
make failures of `two_sample_ad_tests` a bit more informative
torfjelde Jun 6, 2024
be1ec7f
make failures of `two_sample_ad_test` produce more informative logs
torfjelde Jun 6, 2024
5f36446
additional information upon `two_sample_ad_test` failure
torfjelde Jun 6, 2024
3be8f8b
rename `two_sample_ad_test` to `two_sample_test` and use KS test instead
torfjelde Jun 6, 2024
dbaf447
added minor test for externalsampler usage
torfjelde Jun 16, 2024
f44c407
also test AdvancedHMC samplers with Gibbs
torfjelde Jun 16, 2024
dd86cfa
forgot to add updates to src/mcmc/abstractmcmc.jl in previous commits
torfjelde Jun 17, 2024
4160577
Merge remote-tracking branch 'origin/torfjelde/gibbs-new-improv' into…
torfjelde Jun 17, 2024
2a7d85b
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde Jun 17, 2024
bdc61fe
removed usage of `timeit_testset` macro
torfjelde Jun 17, 2024
d76243e
added temporary fix for externalsampler that needs to be removed once
torfjelde Jun 17, 2024
14f5c89
minor reorg of two testsets
torfjelde Jun 17, 2024
5893d54
set random seeds more aggressively in an attempt to make tests more r…
torfjelde Jun 18, 2024
4a2cea2
Merge branch 'master' into torfjelde/gibbs-new-improv
yebai Jun 18, 2024
4f30ea5
removed hack, awaiting PR to DynamicPPL
torfjelde Jun 18, 2024
414a077
Merge branch 'master' into torfjelde/gibbs-new-improv
yebai Jun 26, 2024
89bc2e1
renamed `_getmodel` to `getmodel`, `_setmodel` to `setmodel`, and
torfjelde Jun 26, 2024
3d3c944
missed some instances during renaming
torfjelde Jun 26, 2024
e1f1a0e
fixed missing merge in initial step for experimental `Gibbs`
torfjelde Jul 10, 2024
7c4368e
Always reconstruct `ADGradientWrapper` using the `adtype` available in…
torfjelde Jul 15, 2024
06357c6
Test Gibbs with different adtype in externalsampler to ensure that works
torfjelde Jul 15, 2024
02f9fad
Update Project.toml
yebai Jul 15, 2024
30ab9e0
Update Project.toml
yebai Jul 15, 2024
d40d82b
Merge branch 'master' into torfjelde/gibbs-new-improv
yebai Jul 15, 2024
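
Several of the test-suite commits above swap exact-value checks for two-sample comparisons against ground-truth posterior draws (first a Kolmogorov-Smirnov test, then Anderson-Darling, then back to KS as `two_sample_test`). A minimal sketch of such a helper using HypothesisTests.jl; the name, signature, and threshold here are illustrative, not the PR's actual code:

using HypothesisTests

# Passes iff the two samples are plausibly drawn from the same distribution.
function two_sample_test(xs_left, xs_right; α=1e-3)
    t = ApproximateTwoSampleKSTest(vec(xs_left), vec(xs_right))
    return pvalue(t) > α
end

# e.g. compare Gibbs draws of `m` with i.i.d. draws from the known posterior:
# two_sample_test(vec(chain[:m]), rand(posterior_m, 1_000))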
51 changes: 35 additions & 16 deletions src/experimental/gibbs.jl
@@ -78,7 +78,7 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi
     # Short-circuits the tilde assume if `vn` is present in `context`.
     if has_conditioned_gibbs(context, vns)
         value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
-        return value, broadcast_logpdf(right, values), vi
+        return value, broadcast_logpdf(right, value), vi
     end

     # Otherwise, falls back to the default behavior.
@@ -90,8 +90,8 @@ function DynamicPPL.dot_tilde_assume(
 )
     # Short-circuits the tilde assume if `vn` is present in `context`.
     if has_conditioned_gibbs(context, vns)
-        values = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
-        return values, broadcast_logpdf(right, values), vi
+        value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
+        return value, broadcast_logpdf(right, value), vi
     end

     # Otherwise, falls back to the default behavior.
@@ -144,14 +144,14 @@ end
 Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned.
 """
 function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo)
-    return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))
+    return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))
 end
-function DynamicPPL.condition(
+function condition_gibbs(
     context::DynamicPPL.AbstractContext,
     varinfo::DynamicPPL.AbstractVarInfo,
     varinfos::DynamicPPL.AbstractVarInfo...
 )
-    return DynamicPPL.condition(DynamicPPL.condition(context, varinfo), varinfos...)
+    return condition_gibbs(condition_gibbs(context, varinfo), varinfos...)
 end
 # Allow calling this on a `DynamicPPL.Model` directly.
 function condition_gibbs(model::DynamicPPL.Model, values...)
@@ -238,6 +238,9 @@ function Gibbs(algs::Pair...)
     return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs)))
 end

+# TODO: Remove when no longer needed.
+DynamicPPL.getspace(::Gibbs) = ()
+
 struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
     vi::V
     states::S
@@ -252,6 +255,7 @@ function DynamicPPL.initialstep(
     model::DynamicPPL.Model,
     spl::DynamicPPL.Sampler{<:Gibbs},
     vi_base::DynamicPPL.AbstractVarInfo;
+    initial_params=nothing,
     kwargs...,
 )
     alg = spl.alg
@@ -260,15 +264,35 @@

     # 1. Run the model once to get the varnames present + initial values to condition on.
     vi_base = DynamicPPL.VarInfo(model)
+
+    # Simple way of setting the initial parameters: set them in the `vi_base`
+    # if they are given so they propagate to the subset varinfos used by each sampler.
+    if initial_params !== nothing
+        vi_base = DynamicPPL.unflatten(vi_base, initial_params)
+    end
+
+    # Create the varinfos for each sampler.
     varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames)
+    initial_params_all = if initial_params === nothing
+        fill(nothing, length(varnames))
+    else
+        # Extract from the `vi_base`, which should have the values set correctly from above.
+        map(vi -> vi[:], varinfos)
+    end

     # 2. Construct a varinfo for every vn + sampler combo.
-    states_and_varinfos = map(samplers, varinfos) do sampler_local, varinfo_local
+    states_and_varinfos = map(samplers, varinfos, initial_params_all) do sampler_local, varinfo_local, initial_params_local
         # Construct the conditional model.
         model_local = make_conditional(model, varinfo_local, varinfos)

         # Take initial step.
-        new_state_local = last(AbstractMCMC.step(rng, model_local, sampler_local; kwargs...))
+        new_state_local = last(AbstractMCMC.step(
+            rng, model_local, sampler_local;
+            # FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
+            # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
+            initial_params=initial_params_local,
+            kwargs...
+        ))

         # Return the new state and the invlinked `varinfo`.
         vi_local_state = Turing.Inference.varinfo(new_state_local)
@@ -365,12 +389,7 @@ function gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, s
 end

 # TODO: Remove `rng`?
-"""
-    recompute_logprob!!(rng, model, sampler, state)
-
-Recompute the log-probability of the `model` based on the given `state` and return the resulting state.
-"""
-function recompute_logprob!!(
+function Turing.Inference.recompute_logprob!!(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
     sampler::DynamicPPL.Sampler,
@@ -436,7 +455,7 @@ function gibbs_step_inner(
         state_local,
         state_previous
     )
-    current_state_local = recompute_logprob!!(
+    state_local = Turing.Inference.recompute_logprob!!(
         rng,
         model_local,
         sampler_local,
@@ -450,7 +469,7 @@
             rng,
             model_local,
             sampler_local,
-            current_state_local;
+            state_local;
             kwargs...,
         ),
     )
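
Taken together, these changes let the experimental Gibbs constructor accept per-variable samplers via `Pair` syntax (see `Gibbs(algs::Pair...)` above) and thread `initial_params` through to each conditional sampler. A minimal usage sketch; the module path `Turing.Experimental`, the toy model, and the initial values are assumptions for illustration:

using Turing

@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    1.5 ~ Normal(m, sqrt(s))
end

# Pair syntax maps variable names to component samplers; `initial_params`
# is the flattened vector that `initialstep` unflattens into `vi_base`
# and then splits across the per-sampler varinfos.
spl = Turing.Experimental.Gibbs(@varname(s) => MH(), @varname(m) => HMC(0.1, 10))
chain = sample(demo(), spl, 1_000; initial_params=[1.0, 0.0])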
2 changes: 2 additions & 0 deletions src/mcmc/Inference.jl
@@ -114,6 +114,8 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain
     end
 end

+DynamicPPL.getspace(::ExternalSampler) = ()
+
 """
     requires_unconstrained_space(sampler::ExternalSampler)

99 changes: 96 additions & 3 deletions src/mcmc/abstractmcmc.jl
@@ -17,10 +17,52 @@ function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transit
     return transition_to_turing(parent(f), transition)
 end

+"""
+    getmodel(f)
+
+Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
+"""
+getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f))
+getmodel(f::DynamicPPL.LogDensityFunction) = f.model
+
+# FIXME: We'll have to overload this for every AD backend since some of the AD backends
+# will cache certain parts of a given model, e.g. the tape, which results in a discrepancy
+# between the primal (forward) and dual (backward).
+"""
+    setmodel(f, model)
+
+Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
+
+!!! warning
+    Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
+    `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
+    might require recompilation of the gradient tape, depending on the AD backend.
+"""
+function setmodel(f::LogDensityProblemsAD.ADGradientWrapper, model::DynamicPPL.Model)
+    return Accessors.@set f.ℓ = setmodel(f.ℓ, model)
+end
+function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
+    return Accessors.@set f.model = model
+end
+
+function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
+    return varinfo_from_logdensityfn(parent(f))
+end
+varinfo_from_logdensityfn(f::DynamicPPL.LogDensityFunction) = f.varinfo
+
+function varinfo(state::TuringState)
+    θ = getparams(getmodel(state.logdensity), state.state)
+    # TODO: Do we need to link here first?
+    return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ)
+end
+
 # NOTE: Only thing that depends on the underlying sampler.
 # Something similar should be part of AbstractMCMC at some point:
 # https://github.com/TuringLang/AbstractMCMC.jl/pull/86
 getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
+function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
+    return getparams(model, state.transition)
+end
 getstats(transition::AdvancedHMC.Transition) = transition.stat

 getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
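
The `getmodel`/`setmodel` pair exists so the Gibbs sweep can swap a freshly conditioned model into an existing log-density wrapper instead of rebuilding it. A hedged sketch of the intended pattern; it assumes `ADgradient` accepts an `ADTypes` backend, that these helpers live in `Turing.Inference`, and the toy model is illustrative:

using Turing, DynamicPPL, LogDensityProblemsAD

@model function demo()
    m ~ Normal()
    x ~ Normal(m)
end
model     = demo() | (x = 1.0,)  # conditioned on x = 1
new_model = demo() | (x = 2.0,)  # same structure, new conditioning values

f    = DynamicPPL.LogDensityFunction(model)
ad_f = LogDensityProblemsAD.ADgradient(AutoForwardDiff(), f)

Turing.Inference.getmodel(ad_f) === model  # unwraps through the AD wrapper
# Swap the model instead of rebuilding `ad_f`; for tape-based AD backends
# this may force tape recompilation, as the warning above notes.
ad_f_new = Turing.Inference.setmodel(ad_f, new_model)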
@@ -33,13 +75,59 @@ function setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo)
     return Accessors.@set f.ℓ = setvarinfo(f.ℓ, varinfo)
 end

+"""
+    recompute_logprob!!(rng, model, sampler, state)
+
+Recompute the log-probability of the `model` based on the given `state` and return the resulting state.
+"""
+function recompute_logprob!!(
+    rng::Random.AbstractRNG, # TODO: Do we need the `rng` here?
+    model::DynamicPPL.Model,
+    sampler::DynamicPPL.Sampler{<:ExternalSampler},
+    state,

[Inline review comments on `state,`]
Member: should `state` have type `TuringState`, given the function body assumes `state` has the fields `logdensity` and `state`?
Member: @torfjelde any comment on this?

+)
+    # Re-using the log-density function from the `state` and updating only the `model` field,
+    # since the `model` might now contain different conditioning values.
+    f = setmodel(state.logdensity, model)
+    # Recompute the log-probability with the new `model`.
+    state_inner = recompute_logprob!!(
+        rng, AbstractMCMC.LogDensityModel(f), sampler.alg.sampler, state.state
+    )
+    return state_to_turing(f, state_inner)
+end
+
+function recompute_logprob!!(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    sampler::AdvancedHMC.AbstractHMCSampler,
+    state::AdvancedHMC.HMCState,
+)
+    # Construct hamiltonian.
+    hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
+    # Re-compute the log-probability and gradient.
+    return Accessors.@set state.transition.z = AdvancedHMC.phasepoint(
+        hamiltonian, state.transition.z.θ, state.transition.z.r
+    )
+end
+
+function recompute_logprob!!(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    sampler::AdvancedMH.MetropolisHastings,
+    state::AdvancedMH.Transition,
+)
+    logdensity = model.logdensity
+    return Accessors.@set state.lp = LogDensityProblems.logdensity(logdensity, state.params)
+end
+
 # TODO: Do we also support `resume`, etc?
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
     sampler_wrapper::Sampler{<:ExternalSampler};
     initial_state=nothing,
     initial_params=nothing,
-    kwargs...
+    kwargs...,
 )
     alg = sampler_wrapper.alg
     sampler = alg.sampler
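
The AdvancedMH branch above is the simplest to see in isolation: the cached `lp` in a `Transition` goes stale as soon as the Gibbs sweep changes the conditioning values, so it must be recomputed against the new target. A self-contained toy illustration of that single update, assuming `AdvancedMH.Transition` keeps its `(params, lp, accepted)` fields:

using AdvancedMH, LogDensityProblems, Accessors, Distributions

# Toy target implementing the LogDensityProblems interface.
struct Gauss
    μ::Float64
end
LogDensityProblems.logdensity(g::Gauss, x) = logpdf(Normal(g.μ, 1.0), only(x))
LogDensityProblems.dimension(::Gauss) = 1
LogDensityProblems.capabilities(::Type{Gauss}) = LogDensityProblems.LogDensityOrder{0}()

# The cached `lp` below was computed under the old conditioning (μ = 0).
state = AdvancedMH.Transition([0.5], LogDensityProblems.logdensity(Gauss(0.0), [0.5]), true)
# Refresh it against the new conditional target (μ = 1), mirroring the method above.
state = Accessors.@set state.lp = LogDensityProblems.logdensity(Gauss(1.0), state.params)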
@@ -69,7 +157,12 @@
         )
     else
         transition_inner, state_inner = AbstractMCMC.step(
-            rng, AbstractMCMC.LogDensityModel(f), sampler, initial_state; initial_params, kwargs...
+            rng,
+            AbstractMCMC.LogDensityModel(f),
+            sampler,
+            initial_state;
+            initial_params,
+            kwargs...,
         )
     end
     # Update the `state`
@@ -81,7 +174,7 @@
     model::DynamicPPL.Model,
     sampler_wrapper::Sampler{<:ExternalSampler},
     state::TuringState;
-    kwargs...
+    kwargs...,
 )
     sampler = sampler_wrapper.alg.sampler
     f = state.logdensity
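
These `step` methods are what the experimental Gibbs drives when a component sampler is wrapped in `externalsampler`, which the commits above also exercise with different `adtype`s (the `ExternalSampler` struct in src/mcmc/Inference.jl carries the AD backend as a type parameter). A hedged sketch of that combination, reusing the toy `demo` model from the earlier sketch; the proposal choice is illustrative:

using Turing, AdvancedMH, Distributions, LinearAlgebra

# Wrap an AdvancedMH random-walk sampler for use inside Turing.
rwmh = AdvancedMH.RWMH(MvNormal(zeros(1), I))
spl = Turing.Experimental.Gibbs(
    @varname(m) => externalsampler(rwmh),
    @varname(s) => HMC(0.1, 10),
)
chain = sample(demo(), spl, 500)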