Skip to content

Commit

Permalink
Replace Gibbs inner loop with recursion (#2464)
Browse files Browse the repository at this point in the history
* Use recursion in gibbs_inner_step

* Remove unnecessary initialisation in Gibbs

We were doing work that was already done by the caller of initialstep.

* variable naming / destructuring (#2465)

* Variable naming, destructuring

* Tuple -> Vec

* Reviewdog code style suggestions

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: Penelope Yong <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 15, 2025
1 parent 7d6f8ed commit 11944c9
Showing 1 changed file with 96 additions and 63 deletions.
159 changes: 96 additions & 63 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,52 +409,74 @@ function DynamicPPL.initialstep(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:Gibbs},
vi_base::DynamicPPL.AbstractVarInfo;
vi::DynamicPPL.AbstractVarInfo;
initial_params=nothing,
kwargs...,
)
alg = spl.alg
varnames = alg.varnames
samplers = alg.samplers

# Run the model once to get the varnames present + initial values to condition on.
vi = DynamicPPL.VarInfo(rng, model)
if initial_params !== nothing
vi = DynamicPPL.unflatten(vi, initial_params)
end
vi, states = gibbs_initialstep_recursive(
rng, model, varnames, samplers, vi; initial_params=initial_params, kwargs...
)
return Transition(model, vi), GibbsState(vi, states)
end

# Initialise each component sampler in turn, collect all their states.
states = []
for (varnames_local, sampler_local) in zip(varnames, samplers)
# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
nothing
else
DynamicPPL.subset(vi, varnames_local)[:]
end
"""
Take the first step of MCMC for the first component sampler, and call the same function
recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
and a tuple of initial states for all component samplers.
"""
function gibbs_initialstep_recursive(
rng, model, varname_vecs, samplers, vi, states=(); initial_params=nothing, kwargs...
)
# End recursion
if isempty(varname_vecs) && isempty(samplers)
return vi, states
end

# Construct the conditioned model.
model_local, context_local = make_conditional(model, varnames_local, vi)
varnames, varname_vecs_tail... = varname_vecs
sampler, samplers_tail... = samplers

# Take initial step.
_, new_state_local = AbstractMCMC.step(
rng,
model_local,
sampler_local;
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
initial_params=initial_params_local,
kwargs...,
)
new_vi_local = varinfo(new_state_local)
# Merge in any new variables that were introduced during the step, but that
# were not in the domain of the current sampler.
vi = merge(vi, get_global_varinfo(context_local))
# Merge the new values for all the variables sampled by the current sampler.
vi = merge(vi, new_vi_local)
push!(states, new_state_local)
# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
nothing
else
DynamicPPL.subset(vi, varnames)[:]
end
return Transition(model, vi), GibbsState(vi, states)

# Construct the conditioned model.
conditioned_model, context = make_conditional(model, varnames, vi)

# Take initial step with the current sampler.
_, new_state = AbstractMCMC.step(
rng,
conditioned_model,
sampler;
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
initial_params=initial_params_local,
kwargs...,
)
new_vi_local = varinfo(new_state)
# Merge in any new variables that were introduced during the step, but that
# were not in the domain of the current sampler.
vi = merge(vi, get_global_varinfo(context))
# Merge the new values for all the variables sampled by the current sampler.
vi = merge(vi, new_vi_local)

states = (states..., new_state)
return gibbs_initialstep_recursive(
rng,
model,
varname_vecs_tail,
samplers_tail,
vi,
states;
initial_params=initial_params,
kwargs...,
)
end

function AbstractMCMC.step(
Expand All @@ -471,17 +493,7 @@ function AbstractMCMC.step(
states = state.states
@assert length(samplers) == length(state.states)

# TODO: move this into a recursive function so we can unroll when reasonable?
for index in 1:length(samplers)
# Take the inner step.
sampler_local = samplers[index]
state_local = states[index]
varnames_local = varnames[index]
vi, new_state_local = gibbs_step_inner(
rng, model, varnames_local, sampler_local, state_local, vi; kwargs...
)
states = Accessors.setindex(states, new_state_local, index)
end
vi, states = gibbs_step_recursive(rng, model, varnames, samplers, states, vi; kwargs...)
return Transition(model, vi), GibbsState(vi, states)
end

Expand Down Expand Up @@ -605,19 +617,33 @@ function match_linking!!(varinfo_local, prev_state_local, model)
return varinfo_local
end

function gibbs_step_inner(
"""
Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
function on the tail, until there are no more samplers left.
"""
function gibbs_step_recursive(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
varnames_local,
sampler_local,
state_local,
global_vi;
varname_vecs,
samplers,
states,
global_vi,
new_states=();
kwargs...,
)
# End recursion.
if isempty(varname_vecs) && isempty(samplers) && isempty(states)
return global_vi, new_states
end

varnames, varname_vecs_tail... = varname_vecs
sampler, samplers_tail... = samplers
state, states_tail... = states

# Construct the conditional model and the varinfo that this sampler should use.
model_local, context_local = make_conditional(model, varnames_local, global_vi)
varinfo_local = subset(global_vi, varnames_local)
varinfo_local = match_linking!!(varinfo_local, state_local, model)
conditioned_model, context = make_conditional(model, varnames, global_vi)
vi = subset(global_vi, varnames)
vi = match_linking!!(vi, state, model)

# TODO(mhauru) The below may be overkill. If the varnames for this sampler are not
# sampled by other samplers, we don't need to `setparams`, but could rather simply
Expand All @@ -628,18 +654,25 @@ function gibbs_step_inner(
# going to be a significant expense anyway.
# Set the state of the current sampler, accounting for any changes made by other
# samplers.
state_local = setparams_varinfo!!(
model_local, sampler_local, state_local, varinfo_local
)
state = setparams_varinfo!!(conditioned_model, sampler, state, vi)

# Take a step with the local sampler.
new_state_local = last(
AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...)
)
new_state = last(AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...))

new_vi_local = varinfo(new_state_local)
new_vi_local = varinfo(new_state)
# Merge the latest values for all the variables in the current sampler.
new_global_vi = merge(get_global_varinfo(context_local), new_vi_local)
new_global_vi = merge(get_global_varinfo(context), new_vi_local)
new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))
return new_global_vi, new_state_local

new_states = (new_states..., new_state)
return gibbs_step_recursive(
rng,
model,
varname_vecs_tail,
samplers_tail,
states_tail,
new_global_vi,
new_states;
kwargs...,
)
end

0 comments on commit 11944c9

Please sign in to comment.