
Fixes and improvements to experimental Gibbs #2231

Merged · 40 commits · Jul 16, 2024
Changes from 1 commit
Commits (40)
0b2279f
moved new Gibbs tests all into a single block
torfjelde Apr 23, 2024
dcad548
initial work on making Gibbs work with `externalsampler`
torfjelde Apr 23, 2024
8e2d7be
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde May 18, 2024
4a609cb
removed references to Setfield.jl
torfjelde May 18, 2024
fc21894
fixed crucial bug in experimental Gibbs sampler
torfjelde May 18, 2024
9910962
added ground-truth comparison for Gibbs sampler on demo models
torfjelde May 18, 2024
b3a4692
added convenience method for performing two sample KS test
torfjelde May 18, 2024
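(The two-sample helper added here compares sampler draws against ground-truth draws. A minimal sketch of such a check using HypothesisTests.jl; the toy data below is hypothetical, not the PR's actual `two_sample_ad_test`/`two_sample_test` helper:)

```julia
using HypothesisTests
using Random

rng = Random.MersenneTwister(42)
x = randn(rng, 1_000)  # stand-in for draws from the Gibbs sampler
y = randn(rng, 1_000)  # stand-in for ground-truth draws

# Two-sample Kolmogorov-Smirnov test: rejects if the empirical
# distributions of `x` and `y` differ significantly.
t = ApproximateTwoSampleKSTest(x, y)
p = pvalue(t)
```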
3e17efc
use thinning to avoid OOM issues
torfjelde May 20, 2024
429fc8f
removed incredibly slow testset that didn't really add much
torfjelde May 20, 2024
f6af20e
removed now-redundant testset
torfjelde May 20, 2024
065eef6
use Anderson-Darling test instead of Kolmogorov-Smirnov to better
torfjelde May 20, 2024
c2d23e5
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde Jun 4, 2024
99f28f9
more work on testing
torfjelde Jun 6, 2024
6df1ccc
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde Jun 6, 2024
b6a907e
fixed tests
torfjelde Jun 6, 2024
a4e223e
Merge remote-tracking branch 'origin/torfjelde/gibbs-new-improv' into…
torfjelde Jun 6, 2024
e1e7386
make failures of `two_sample_ad_tests` a bit more informative
torfjelde Jun 6, 2024
be1ec7f
make failures of `two_sample_ad_test` produce more informative logs
torfjelde Jun 6, 2024
5f36446
additional information upon `two_sample_ad_test` failure
torfjelde Jun 6, 2024
3be8f8b
rename `two_sample_ad_test` to `two_sample_test` and use KS test instead
torfjelde Jun 6, 2024
dbaf447
added minor test for externalsampler usage
torfjelde Jun 16, 2024
f44c407
also test AdvancedHMC samplers with Gibbs
torfjelde Jun 16, 2024
dd86cfa
forgot to add updates to src/mcmc/abstractmcmc.jl in previous commits
torfjelde Jun 17, 2024
4160577
Merge remote-tracking branch 'origin/torfjelde/gibbs-new-improv' into…
torfjelde Jun 17, 2024
2a7d85b
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde Jun 17, 2024
bdc61fe
removed usage of `timeit_testset` macro
torfjelde Jun 17, 2024
d76243e
added temporary fix for externalsampler that needs to be removed once
torfjelde Jun 17, 2024
14f5c89
minor reorg of two testsets
torfjelde Jun 17, 2024
5893d54
set random seeds more aggressively in an attempt to make tests more r…
torfjelde Jun 18, 2024
4a2cea2
Merge branch 'master' into torfjelde/gibbs-new-improv
yebai Jun 18, 2024
4f30ea5
removed hack, awaiting PR to DynamicPPL
torfjelde Jun 18, 2024
414a077
Merge branch 'master' into torfjelde/gibbs-new-improv
yebai Jun 26, 2024
89bc2e1
renamed `_getmodel` to `getmodel`, `_setmodel` to `setmodel`, and
torfjelde Jun 26, 2024
3d3c944
missed some instances during renaming
torfjelde Jun 26, 2024
e1f1a0e
fixed missing merge in initial step for experimental `Gibbs`
torfjelde Jul 10, 2024
7c4368e
Always reconstruct `ADGradientWrapper` using the `adtype` available in…
torfjelde Jul 15, 2024
06357c6
Test Gibbs with different adtype in externalsampler to ensure that works
torfjelde Jul 15, 2024
02f9fad
Update Project.toml
yebai Jul 15, 2024
30ab9e0
Update Project.toml
yebai Jul 15, 2024
d40d82b
Merge branch 'master' into torfjelde/gibbs-new-improv
yebai Jul 15, 2024
Test Gibbs with different adtype in externalsampler to ensure that works
torfjelde committed Jul 15, 2024
commit 06357c6e8cf7c3bf3d29f7a3fc58ff919878b6e4
6 changes: 5 additions & 1 deletion test/experimental/gibbs.jl

@@ -8,6 +8,8 @@ using Random
 using Test
 using Turing
 using Turing.Inference: AdvancedHMC, AdvancedMH
+using ForwardDiff: ForwardDiff
+using ReverseDiff: ReverseDiff

 function check_transition_varnames(
     transition::Turing.Inference.Transition,
@@ -249,7 +251,9 @@ has_dot_assume(::Model) = true
     model = demo_gibbs_external()
     samplers_inner = [
         externalsampler(AdvancedMH.RWMH(1)),
-        externalsampler(AdvancedHMC.HMC(1e-1, 32)),
+        externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoForwardDiff()),
+        externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff()),
+        externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff(compile=true)),
     ]
     @testset "$(sampler_inner)" for sampler_inner in samplers_inner
         sampler = Turing.Experimental.Gibbs(
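The `externalsampler` wrappers exercised by these new test cases can also be used on their own. A minimal sketch of sampling with an explicit AD backend, under the assumption that the toy model below (not the PR's `demo_gibbs_external`) is defined by the user:

```julia
using Turing
using Turing.Inference: AdvancedHMC
using ReverseDiff: ReverseDiff

# Hypothetical toy model, conditioned on an observation of `y`.
@model function toy()
    x ~ Normal()
    y ~ Normal(x, 1)
end
model = toy() | (y = 0.5,)

# Wrap an AdvancedHMC sampler; `adtype` selects the AD backend used for
# gradients (here compiled ReverseDiff, as in the new test cases).
spl = externalsampler(AdvancedHMC.HMC(1e-1, 32); adtype=AutoReverseDiff(compile=true))
chain = sample(model, spl, 100; progress=false)
```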
Unchanged files with check annotations

# Short-circuits the tilde assume if `vn` is present in `context`.
if has_conditioned_gibbs(context, vns)
value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
return value, broadcast_logpdf(right, value), vi

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L81 added, not covered by tests
end
# Otherwise, falls back to the default behavior.
)
# Short-circuits the tilde assume if `vn` is present in `context`.
if has_conditioned_gibbs(context, vns)
value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
return value, broadcast_logpdf(right, value), vi

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L93-L94 added, not covered by tests
end
# Otherwise, falls back to the default behavior.
Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned.
"""
function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo)
return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L147 added, not covered by tests
end
function condition_gibbs(

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L149 added, not covered by tests
context::DynamicPPL.AbstractContext,
varinfo::DynamicPPL.AbstractVarInfo,
varinfos::DynamicPPL.AbstractVarInfo...
)
return condition_gibbs(condition_gibbs(context, varinfo), varinfos...)

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L154 added, not covered by tests
end
# Allow calling this on a `DynamicPPL.Model` directly.
function condition_gibbs(model::DynamicPPL.Model, values...)
end
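(`condition_gibbs` treats the values held in the other blocks' `varinfo`s as conditioned, so each conditional model only has to sample its own block. The effect is analogous to DynamicPPL's public `condition`; a sketch with a hypothetical toy model:)

```julia
using DynamicPPL
using Distributions

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

# Conditioning on `x` makes it behave as observed data, so a sampler run
# on `conditioned` only needs to handle `y` -- this per-block conditioning
# is what the Gibbs context implements internally.
conditioned = DynamicPPL.condition(demo(), (x = 0.5,))
```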
# TODO: Remove when no longer needed.
DynamicPPL.getspace(::Gibbs) = ()

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L242 added, not covered by tests
struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
vi::V
# Simple way of setting the initial parameters: set them in the `vi_base`
# if they are given so they propagate to the subset varinfos used by each sampler.
if initial_params !== nothing
vi_base = DynamicPPL.unflatten(vi_base, initial_params)

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L270-L271 added, not covered by tests
end
# Create the varinfos for each sampler.
varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames)
initial_params_all = if initial_params === nothing
fill(nothing, length(varnames))

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L276-L277 added, not covered by tests
else
# Extract from the `vi_base`, which should have the values set correctly from above.
map(vi -> vi[:], varinfos)

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L280 added, not covered by tests
end
# 2. Construct a varinfo for every vn + sampler combo.
states_and_varinfos = map(samplers, varinfos, initial_params_all) do sampler_local, varinfo_local, initial_params_local

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L284 added, not covered by tests
# Construct the conditional model.
model_local = make_conditional(model, varinfo_local, varinfos)
# Take initial step.
new_state_local = last(AbstractMCMC.step(

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L289 added, not covered by tests
rng, model_local, sampler_local;
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
varinfos = map(last, states_and_varinfos)
# Update the base varinfo from the first varinfo and replace it.
varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1)

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L311 added, not covered by tests
# Merge the updated initial varinfo with the rest of the varinfos + update the logp.
vi = DynamicPPL.setlogp!!(
reduce(merge, varinfos_new),
end
# TODO: Remove `rng`?
function Turing.Inference.recompute_logprob!!(

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L392 added, not covered by tests
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler,
state_local,
state_previous
)
state_local = Turing.Inference.recompute_logprob!!(

⚠ Codecov (codecov/patch): src/experimental/gibbs.jl#L458 added, not covered by tests
rng,
model_local,
sampler_local,
end
end
DynamicPPL.getspace(::ExternalSampler) = ()

⚠ Codecov (codecov/patch): src/mcmc/Inference.jl#L117 added, not covered by tests
"""
requires_unconstrained_space(sampler::ExternalSampler)
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
"""
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f))
getmodel(f::DynamicPPL.LogDensityFunction) = f.model

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L25-L26 added, not covered by tests
"""
setmodel(f, model[, adtype])
`DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
might require recompilation of the gradient tape, depending on the AD backend.
"""
function setmodel(

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L38 added, not covered by tests
f::LogDensityProblemsAD.ADGradientWrapper,
model::DynamicPPL.Model,
adtype::ADTypes.AbstractADType
# ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
# replacing the corresponding field with the new model won't be sufficient to obtain
# the correct gradients.
return LogDensityProblemsAD.ADgradient(adtype, setmodel(parent(f), model))

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L51 added, not covered by tests
end
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return Accessors.@set f.model = model

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L53-L54 added, not covered by tests
end
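(`setmodel` swaps the `model` field functionally rather than by mutation. A minimal sketch of the `Accessors.@set` pattern it relies on, using a toy struct rather than Turing's actual types:)

```julia
using Accessors

struct Wrapper          # stand-in for DynamicPPL.LogDensityFunction
    model::String       # stand-in for a DynamicPPL.Model
    varinfo::Int
end

f = Wrapper("old", 1)

# `@set` returns a new immutable value with one field replaced;
# the original `f` is left untouched.
f2 = Accessors.@set f.model = "new"
```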
function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
return varinfo_from_logdensityfn(parent(f))

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L57-L58 added, not covered by tests
end
varinfo_from_logdensityfn(f::DynamicPPL.LogDensityFunction) = f.varinfo

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L60 added, not covered by tests
function varinfo(state::TuringState)
θ = getparams(getmodel(state.logdensity), state.state)

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L62-L63 added, not covered by tests
# TODO: Do we need to link here first?
return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ)

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L65 added, not covered by tests
end
# NOTE: Only thing that depends on the underlying sampler.
# Something similar should be part of AbstractMCMC at some point:
# https://github.com/TuringLang/AbstractMCMC.jl/pull/86
getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
return getparams(model, state.transition)

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L72-L73 added, not covered by tests
end
getstats(transition::AdvancedHMC.Transition) = transition.stat
Recompute the log-probability of the `model` based on the given `state` and return the resulting state.
"""
function recompute_logprob!!(

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L92 added, not covered by tests
rng::Random.AbstractRNG, # TODO: Do we need the `rng` here?
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:ExternalSampler},
)
# Re-using the log-density function from the `state` and updating only the `model` field,
# since the `model` might now contain different conditioning values.
f = setmodel(state.logdensity, model, sampler.alg.adtype)

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L100 added, not covered by tests
# Recompute the log-probability with the new `model`.
state_inner = recompute_logprob!!(

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L102 added, not covered by tests
rng, AbstractMCMC.LogDensityModel(f), sampler.alg.sampler, state.state
)
return state_to_turing(f, state_inner)

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L105 added, not covered by tests
end
function recompute_logprob!!(

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L108 added, not covered by tests
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
sampler::AdvancedHMC.AbstractHMCSampler,
state::AdvancedHMC.HMCState,
)
# Construct the Hamiltonian.
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L115 added, not covered by tests
# Re-compute the log-probability and gradient.
return Accessors.@set state.transition.z = AdvancedHMC.phasepoint(

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L117 added, not covered by tests
hamiltonian, state.transition.z.θ, state.transition.z.r
)
end
function recompute_logprob!!(

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L122 added, not covered by tests
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
sampler::AdvancedMH.MetropolisHastings,
state::AdvancedMH.Transition,
)
logdensity = model.logdensity
return Accessors.@set state.lp = LogDensityProblems.logdensity(logdensity, state.params)

⚠ Codecov (codecov/patch): src/mcmc/abstractmcmc.jl#L128-L129 added, not covered by tests
end
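(For the Metropolis-Hastings case above, recomputation amounts to re-evaluating the log density at the stored parameters. A sketch of defining and evaluating such a target through the LogDensityProblems interface; the Gaussian target and its name are hypothetical:)

```julia
using LogDensityProblems

struct ToyGaussian end  # hypothetical unnormalized standard-normal target

LogDensityProblems.logdensity(::ToyGaussian, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyGaussian) = 2
function LogDensityProblems.capabilities(::Type{ToyGaussian})
    return LogDensityProblems.LogDensityOrder{0}()
end

# Re-evaluate the log density at stored parameters, as
# `recompute_logprob!!` does for AdvancedMH transitions.
lp = LogDensityProblems.logdensity(ToyGaussian(), [1.0, 1.0])
```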
# TODO: Do we also support `resume`, etc?