Commit e914d7d

Merge remote-tracking branch 'origin/master' into dw/enzyme
2 parents 6b2b65f + 24d5556 commit e914d7d

11 files changed: +187 -429 lines changed

HISTORY.md

Lines changed: 2 additions & 4 deletions
@@ -4,13 +4,11 @@
 
 0.36.0 introduces a new Gibbs sampler. It's been included in several previous releases as `Turing.Experimental.Gibbs`, but now takes over the old Gibbs sampler, which gets removed completely.
 
-The new Gibbs sampler supports the same user-facing interface as the old one. However, given
-that the internals of it having been completely rewritten in a very different manner, there
-may be accidental breakage that we haven't anticipated. Please report any you find.
+The new Gibbs sampler currently supports the same user-facing interface as the old one, but the old constructors have been deprecated, and will be removed in the future. Also, given that the internals have been completely rewritten in a very different manner, there may be accidental breakage that we haven't anticipated. Please report any you find.
 
 `GibbsConditional` has also been removed. It was never very user-facing, but it was exported, so technically this is breaking.
 
-The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable.
+The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by mapping symbols, `VarName`s, or iterables thereof to samplers, e.g. `Gibbs(:x => HMC(), :y => MH())`, `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`, `Gibbs((:x, :y) => NUTS(), :z => MH())`. This allows more granular specification of which sampler to use for which variable.
 
 Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(0.01, 4, :x), 2), (MH(:y), 1))` has been deprecated. The new way to do this is to use `RepeatSampler`, also introduced at this version: `Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())`.
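
As a rough illustration of the key handling described in these release notes, the self-contained Julia sketch below turns each pair's key (a single name or an iterable of names) into a vector of names. `Name`, `to_name_list`, and `normalize_pairs` are hypothetical stand-ins chosen for this example, not Turing.jl or DynamicPPL API, and the strings merely stand in for sampler objects.

```julia
# Hypothetical stand-in for a variable-name type (NOT DynamicPPL.VarName).
struct Name
    sym::Symbol
end
Name(n::Name) = n  # wrapping an existing Name is a no-op

# A single name (Symbol or Name) becomes a one-element vector.
to_name_list(x::Union{Name,Symbol}) = [Name(x)]
# Anything else is assumed to be an iterable of Symbols and Names.
to_name_list(xs) = collect(map(Name, xs))

# Split `key => sampler` pairs into a tuple of name vectors and a tuple of samplers.
function normalize_pairs(pairs::Pair...)
    return map(to_name_list ∘ first, pairs), map(last, pairs)
end

# Strings are placeholders for sampler objects in this sketch.
name_vecs, samplers = normalize_pairs((:x, Name(:y)) => "NUTS", :z => "MH")
```

After this call, `name_vecs` holds one vector of names per sampler, so a component that handles several variables and one that handles a single variable have the same shape.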

docs/src/api.md

Lines changed: 0 additions & 1 deletion
@@ -58,7 +58,6 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
 | `Emcee` | [`Turing.Inference.Emcee`](@ref) | Affine-invariant ensemble sampler |
 | `ESS` | [`Turing.Inference.ESS`](@ref) | Elliptical slice sampling |
 | `Gibbs` | [`Turing.Inference.Gibbs`](@ref) | Gibbs sampling |
-| `GibbsConditional` | [`Turing.Inference.GibbsConditional`](@ref) | A "pseudo-sampler" to provide analytical conditionals to `Gibbs` |
 | `HMC` | [`Turing.Inference.HMC`](@ref) | Hamiltonian Monte Carlo |
 | `SGLD` | [`Turing.Inference.SGLD`](@ref) | Stochastic gradient Langevin dynamics |
 | `SGHMC` | [`Turing.Inference.SGHMC`](@ref) | Stochastic gradient Hamiltonian Monte Carlo |

src/mcmc/gibbs.jl

Lines changed: 137 additions & 91 deletions
@@ -292,15 +292,40 @@ function set_selector(x::RepeatSampler)
 end
 set_selector(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0))
 
+to_varname_list(x::Union{VarName,Symbol}) = [VarName(x)]
+# Any other value is assumed to be an iterable of VarNames and Symbols.
+to_varname_list(t) = collect(map(VarName, t))
+
 """
     Gibbs
 
 A type representing a Gibbs sampler.
 
+# Constructors
+
+`Gibbs` needs to be given a set of pairs of variable names and samplers. Instead of a single
+variable name per sampler, one can also give an iterable of variables, all of which are
+sampled by the same component sampler.
+
+Each variable name can be given as either a `Symbol` or a `VarName`.
+
+Some examples of valid constructors are:
+```julia
+Gibbs(:x => NUTS(), :y => MH())
+Gibbs(@varname(x) => NUTS(), @varname(y) => MH())
+Gibbs((@varname(x), :y) => NUTS(), :z => MH())
+```
+
+Currently only variable names without indexing are supported, so for instance
+`Gibbs(@varname(x[1]) => NUTS())` does not work. This will hopefully change in the future.
+
 # Fields
 $(TYPEDFIELDS)
 """
-struct Gibbs{V,A} <: InferenceAlgorithm
+struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <:
+       InferenceAlgorithm
+    # TODO(mhauru) Revisit whether A should have a fixed element type once
+    # InferenceAlgorithm/Sampler types have been cleaned up.
     "varnames representing variables for each sampler"
     varnames::V
     "samplers for each entry in `varnames`"
@@ -310,40 +335,30 @@ struct Gibbs{V,A} <: InferenceAlgorithm
         if length(varnames) != length(samplers)
             throw(ArgumentError("Number of varnames and samplers must match."))
         end
+
         for spl in samplers
             if !isgibbscomponent(spl)
                 msg = "All samplers must be valid Gibbs components, $(spl) is not."
                 throw(ArgumentError(msg))
             end
         end
-        return new{typeof(varnames),typeof(samplers)}(varnames, samplers)
-    end
-end
 
-to_varname(vn::VarName) = vn
-to_varname(s::Symbol) = VarName{s}()
-# Any other value is assumed to be an iterable.
-to_varname(t) = map(to_varname, collect(t))
-
-# NamedTuple
-Gibbs(; algs...) = Gibbs(NamedTuple(algs))
-function Gibbs(algs::NamedTuple)
-    return Gibbs(map(to_varname, keys(algs)), map(set_selector ∘ drop_space, values(algs)))
+        # Ensure that samplers have the same selector, and that varnames are lists of
+        # VarNames.
+        samplers = tuple(map(set_selector ∘ drop_space, samplers)...)
+        varnames = tuple(map(to_varname_list, varnames)...)
+        return new{length(samplers),typeof(varnames),typeof(samplers)}(varnames, samplers)
+    end
 end
 
-# AbstractDict
-function Gibbs(algs::AbstractDict)
-    return Gibbs(
-        map(to_varname, collect(keys(algs))), map(set_selector ∘ drop_space, values(algs))
-    )
-end
 function Gibbs(algs::Pair...)
-    return Gibbs(map(to_varname ∘ first, algs), map(set_selector ∘ drop_space ∘ last, algs))
+    return Gibbs(map(first, algs), map(last, algs))
 end
 
 # The below two constructors only provide backwards compatibility with the constructor of
 # the old Gibbs sampler. They are deprecated and will be removed in the future.
-function Gibbs(algs::InferenceAlgorithm...)
+function Gibbs(alg1::InferenceAlgorithm, other_algs::InferenceAlgorithm...)
+    algs = [alg1, other_algs...]
     varnames = map(algs) do alg
         space = getspace(alg)
         if (space isa VarName)
@@ -365,7 +380,11 @@ function Gibbs(algs::InferenceAlgorithm...)
     return Gibbs(varnames, map(set_selector ∘ drop_space, algs))
 end
 
-function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...)
+function Gibbs(
+    alg_with_iters1::Tuple{<:InferenceAlgorithm,Int},
+    other_algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...,
+)
+    algs_with_iters = [alg_with_iters1, other_algs_with_iters...]
     algs = Iterators.map(first, algs_with_iters)
     iters = Iterators.map(last, algs_with_iters)
     algs_duplicated = Iterators.flatten((
@@ -384,64 +403,80 @@ struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
     states::S
 end
 
-_maybevec(x) = vec(x)  # assume it's iterable
-_maybevec(x::Tuple) = [x...]
-_maybevec(x::VarName) = [x]
-_maybevec(x::Symbol) = [x]
-
 varinfo(state::GibbsState) = state.vi
 
 function DynamicPPL.initialstep(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
     spl::DynamicPPL.Sampler{<:Gibbs},
-    vi_base::DynamicPPL.AbstractVarInfo;
+    vi::DynamicPPL.AbstractVarInfo;
     initial_params=nothing,
     kwargs...,
 )
     alg = spl.alg
     varnames = alg.varnames
     samplers = alg.samplers
 
-    # Run the model once to get the varnames present + initial values to condition on.
-    vi = DynamicPPL.VarInfo(rng, model)
-    if initial_params !== nothing
-        vi = DynamicPPL.unflatten(vi, initial_params)
-    end
+    vi, states = gibbs_initialstep_recursive(
+        rng, model, varnames, samplers, vi; initial_params=initial_params, kwargs...
+    )
+    return Transition(model, vi), GibbsState(vi, states)
+end
 
-    # Initialise each component sampler in turn, collect all their states.
-    states = []
-    for (varnames_local, sampler_local) in zip(varnames, samplers)
-        varnames_local = _maybevec(varnames_local)
-        # Get the initial values for this component sampler.
-        initial_params_local = if initial_params === nothing
-            nothing
-        else
-            DynamicPPL.subset(vi, varnames_local)[:]
-        end
+"""
+Take the first step of MCMC for the first component sampler, and call the same function
+recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
+and a tuple of initial states for all component samplers.
+"""
+function gibbs_initialstep_recursive(
+    rng, model, varname_vecs, samplers, vi, states=(); initial_params=nothing, kwargs...
+)
+    # End recursion
+    if isempty(varname_vecs) && isempty(samplers)
+        return vi, states
+    end
 
-        # Construct the conditioned model.
-        model_local, context_local = make_conditional(model, varnames_local, vi)
+    varnames, varname_vecs_tail... = varname_vecs
+    sampler, samplers_tail... = samplers
 
-        # Take initial step.
-        _, new_state_local = AbstractMCMC.step(
-            rng,
-            model_local,
-            sampler_local;
-            # FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
-            # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
-            initial_params=initial_params_local,
-            kwargs...,
-        )
-        new_vi_local = varinfo(new_state_local)
-        # Merge in any new variables that were introduced during the step, but that
-        # were not in the domain of the current sampler.
-        vi = merge(vi, get_global_varinfo(context_local))
-        # Merge the new values for all the variables sampled by the current sampler.
-        vi = merge(vi, new_vi_local)
-        push!(states, new_state_local)
+    # Get the initial values for this component sampler.
+    initial_params_local = if initial_params === nothing
+        nothing
+    else
+        DynamicPPL.subset(vi, varnames)[:]
     end
-    return Transition(model, vi), GibbsState(vi, states)
+
+    # Construct the conditioned model.
+    conditioned_model, context = make_conditional(model, varnames, vi)
+
+    # Take initial step with the current sampler.
+    _, new_state = AbstractMCMC.step(
+        rng,
+        conditioned_model,
+        sampler;
+        # FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
+        # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
+        initial_params=initial_params_local,
+        kwargs...,
+    )
+    new_vi_local = varinfo(new_state)
+    # Merge in any new variables that were introduced during the step, but that
+    # were not in the domain of the current sampler.
+    vi = merge(vi, get_global_varinfo(context))
+    # Merge the new values for all the variables sampled by the current sampler.
+    vi = merge(vi, new_vi_local)
+
+    states = (states..., new_state)
+    return gibbs_initialstep_recursive(
+        rng,
+        model,
+        varname_vecs_tail,
+        samplers_tail,
+        vi,
+        states;
+        initial_params=initial_params,
+        kwargs...,
+    )
 end
 
 function AbstractMCMC.step(
@@ -458,17 +493,7 @@ function AbstractMCMC.step(
     states = state.states
     @assert length(samplers) == length(state.states)
 
-    # TODO: move this into a recursive function so we can unroll when reasonable?
-    for index in 1:length(samplers)
-        # Take the inner step.
-        sampler_local = samplers[index]
-        state_local = states[index]
-        varnames_local = _maybevec(varnames[index])
-        vi, new_state_local = gibbs_step_inner(
-            rng, model, varnames_local, sampler_local, state_local, vi; kwargs...
-        )
-        states = Accessors.setindex(states, new_state_local, index)
-    end
+    vi, states = gibbs_step_recursive(rng, model, varnames, samplers, states, vi; kwargs...)
     return Transition(model, vi), GibbsState(vi, states)
 end
 
@@ -592,19 +617,33 @@ function match_linking!!(varinfo_local, prev_state_local, model)
     return varinfo_local
 end
 
-function gibbs_step_inner(
+"""
+Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
+function on the tail, until there are no more samplers left.
+"""
+function gibbs_step_recursive(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    varnames_local,
-    sampler_local,
-    state_local,
-    global_vi;
+    varname_vecs,
+    samplers,
+    states,
+    global_vi,
+    new_states=();
     kwargs...,
 )
+    # End recursion.
+    if isempty(varname_vecs) && isempty(samplers) && isempty(states)
+        return global_vi, new_states
+    end
+
+    varnames, varname_vecs_tail... = varname_vecs
+    sampler, samplers_tail... = samplers
+    state, states_tail... = states
+
     # Construct the conditional model and the varinfo that this sampler should use.
-    model_local, context_local = make_conditional(model, varnames_local, global_vi)
-    varinfo_local = subset(global_vi, varnames_local)
-    varinfo_local = match_linking!!(varinfo_local, state_local, model)
+    conditioned_model, context = make_conditional(model, varnames, global_vi)
+    vi = subset(global_vi, varnames)
+    vi = match_linking!!(vi, state, model)
 
     # TODO(mhauru) The below may be overkill. If the varnames for this sampler are not
     # sampled by other samplers, we don't need to `setparams`, but could rather simply
@@ -615,18 +654,25 @@ function gibbs_step_inner(
     # going to be a significant expense anyway.
     # Set the state of the current sampler, accounting for any changes made by other
     # samplers.
-    state_local = setparams_varinfo!!(
-        model_local, sampler_local, state_local, varinfo_local
-    )
+    state = setparams_varinfo!!(conditioned_model, sampler, state, vi)
 
     # Take a step with the local sampler.
-    new_state_local = last(
-        AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...)
-    )
+    new_state = last(AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...))
 
-    new_vi_local = varinfo(new_state_local)
+    new_vi_local = varinfo(new_state)
     # Merge the latest values for all the variables in the current sampler.
-    new_global_vi = merge(get_global_varinfo(context_local), new_vi_local)
+    new_global_vi = merge(get_global_varinfo(context), new_vi_local)
     new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))
-    return new_global_vi, new_state_local
+
+    new_states = (new_states..., new_state)
+    return gibbs_step_recursive(
+        rng,
+        model,
+        varname_vecs_tail,
+        samplers_tail,
+        states_tail,
+        new_global_vi,
+        new_states;
+        kwargs...,
+    )
 end
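
The recursive functions in this file peel the first element off each tuple, process it, and recurse on the tails while accumulating results in a new tuple. The sketch below shows that pattern in isolation, assuming plain functions as stand-ins for component samplers. `step_recursive`, `add`, and `double` are hypothetical names used only for illustration, and nothing here calls Turing.jl. One plausible motivation, in line with the removed TODO about unrolling, is that recursion over tuples lets the compiler track each component's state type separately; that reading is an assumption, not something the commit states.

```julia
# Process (component, state) pairs one at a time, threading a shared value
# through and collecting the new per-component states in a tuple.
function step_recursive(components::Tuple, states::Tuple, shared, new_states=())
    # End recursion once every component has been processed.
    if isempty(components) && isempty(states)
        return shared, new_states
    end

    # Split off the head of each tuple; the tails are again tuples.
    component, components_tail... = components
    state, states_tail... = states

    # Each hypothetical component returns (updated shared value, new state).
    shared, new_state = component(state, shared)

    return step_recursive(
        components_tail, states_tail, shared, (new_states..., new_state)
    )
end

# Two toy components whose states have different types (Int and Float64).
add(state, shared) = (shared + state, state + 1)
double(state, shared) = (2 * shared, 2 * state)

shared, new_states = step_recursive((add, double), (1, 0.5), 10)
# shared == 22 and new_states == (2, 1.0)
```

Because the accumulated states form a tuple instead of a `Vector{Any}`, the result keeps the concrete type of each component's state, here `Tuple{Int64,Float64}` on a 64-bit machine.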

test/dynamicppl/compiler.jl

Lines changed: 3 additions & 3 deletions
@@ -54,7 +54,7 @@ const gdemo_default = gdemo_d()
 
         smc = SMC()
         pg = PG(10)
-        gibbs = Gibbs(; p=HMC(0.2, 3), x=PG(10))
+        gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10))
 
         chn_s = sample(testbb(obs), smc, 1000)
         chn_p = sample(testbb(obs), pg, 2000)
@@ -81,7 +81,7 @@ const gdemo_default = gdemo_d()
             return s, m
         end
 
-        gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8))
+        gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8))
         chain = sample(fggibbstest(xs), gibbs, 2)
     end
     @testset "new grammar" begin
@@ -177,7 +177,7 @@ const gdemo_default = gdemo_d()
     end
 
     @testset "sample" begin
-        alg = Gibbs(; m=HMC(0.2, 3), s=PG(10))
+        alg = Gibbs(:m => HMC(0.2, 3), :s => PG(10))
         chn = sample(gdemo_default, alg, 1000)
     end
 
