@@ -292,15 +292,40 @@ function set_selector(x::RepeatSampler)
292
292
end
293
293
set_selector (x:: InferenceAlgorithm ) = DynamicPPL. Sampler (x, DynamicPPL. Selector (0 ))
294
294
295
+ to_varname_list (x:: Union{VarName,Symbol} ) = [VarName (x)]
296
+ # Any other value is assumed to be an iterable of VarNames and Symbols.
297
+ to_varname_list (t) = collect (map (VarName, t))
298
+
295
299
"""
296
300
Gibbs
297
301
298
302
A type representing a Gibbs sampler.
299
303
304
+ # Constructors
305
+
306
+ `Gibbs` needs to be given a set of pairs of variable names and samplers. Instead of a single
307
+ variable name per sampler, one can also give an iterable of variables, all of which are
308
+ sampled by the same component sampler.
309
+
310
+ Each variable name can be given as either a `Symbol` or a `VarName`.
311
+
312
+ Some examples of valid constructors are:
313
+ ```julia
314
+ Gibbs(:x => NUTS(), :y => MH())
315
+ Gibbs(@varname(x) => NUTS(), @varname(y) => MH())
316
+ Gibbs((@varname(x), :y) => NUTS(), :z => MH())
317
+ ```
318
+
319
+ Currently only variable names without indexing are supported, so for instance
320
+ `Gibbs(@varname(x[1]) => NUTS())` does not work. This will hopefully change in the future.
321
+
300
322
# Fields
301
323
$(TYPEDFIELDS)
302
324
"""
303
- struct Gibbs{V,A} <: InferenceAlgorithm
325
+ struct Gibbs{N,V<: NTuple{N,AbstractVector{<:VarName}} ,A<: NTuple{N,Any} } < :
326
+ InferenceAlgorithm
327
+ # TODO (mhauru) Revisit whether A should have a fixed element type once
328
+ # InferenceAlgorithm/Sampler types have been cleaned up.
304
329
" varnames representing variables for each sampler"
305
330
varnames:: V
306
331
" samplers for each entry in `varnames`"
@@ -310,40 +335,30 @@ struct Gibbs{V,A} <: InferenceAlgorithm
310
335
if length (varnames) != length (samplers)
311
336
throw (ArgumentError (" Number of varnames and samplers must match." ))
312
337
end
338
+
313
339
for spl in samplers
314
340
if ! isgibbscomponent (spl)
315
341
msg = " All samplers must be valid Gibbs components, $(spl) is not."
316
342
throw (ArgumentError (msg))
317
343
end
318
344
end
319
- return new {typeof(varnames),typeof(samplers)} (varnames, samplers)
320
- end
321
- end
322
345
323
- to_varname (vn:: VarName ) = vn
324
- to_varname (s:: Symbol ) = VarName {s} ()
325
- # Any other value is assumed to be an iterable.
326
- to_varname (t) = map (to_varname, collect (t))
327
-
328
- # NamedTuple
329
- Gibbs (; algs... ) = Gibbs (NamedTuple (algs))
330
- function Gibbs (algs:: NamedTuple )
331
- return Gibbs (map (to_varname, keys (algs)), map (set_selector ∘ drop_space, values (algs)))
346
+ # Ensure that samplers have the same selector, and that varnames are lists of
347
+ # VarNames.
348
+ samplers = tuple (map (set_selector ∘ drop_space, samplers)... )
349
+ varnames = tuple (map (to_varname_list, varnames)... )
350
+ return new {length(samplers),typeof(varnames),typeof(samplers)} (varnames, samplers)
351
+ end
332
352
end
333
353
334
- # AbstractDict
335
- function Gibbs (algs:: AbstractDict )
336
- return Gibbs (
337
- map (to_varname, collect (keys (algs))), map (set_selector ∘ drop_space, values (algs))
338
- )
339
- end
340
354
function Gibbs (algs:: Pair... )
341
- return Gibbs (map (to_varname ∘ first, algs), map (set_selector ∘ drop_space ∘ last, algs))
355
+ return Gibbs (map (first, algs), map (last, algs))
342
356
end
343
357
344
358
# The below two constructors only provide backwards compatibility with the constructor of
345
359
# the old Gibbs sampler. They are deprecated and will be removed in the future.
346
- function Gibbs (algs:: InferenceAlgorithm... )
360
+ function Gibbs (alg1:: InferenceAlgorithm , other_algs:: InferenceAlgorithm... )
361
+ algs = [alg1, other_algs... ]
347
362
varnames = map (algs) do alg
348
363
space = getspace (alg)
349
364
if (space isa VarName)
@@ -365,7 +380,11 @@ function Gibbs(algs::InferenceAlgorithm...)
365
380
return Gibbs (varnames, map (set_selector ∘ drop_space, algs))
366
381
end
367
382
368
- function Gibbs (algs_with_iters:: Tuple{<:InferenceAlgorithm,Int} ...)
383
+ function Gibbs (
384
+ alg_with_iters1:: Tuple{<:InferenceAlgorithm,Int} ,
385
+ other_algs_with_iters:: Tuple{<:InferenceAlgorithm,Int} ...,
386
+ )
387
+ algs_with_iters = [alg_with_iters1, other_algs_with_iters... ]
369
388
algs = Iterators. map (first, algs_with_iters)
370
389
iters = Iterators. map (last, algs_with_iters)
371
390
algs_duplicated = Iterators. flatten ((
@@ -384,64 +403,80 @@ struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
384
403
states:: S
385
404
end
386
405
387
- _maybevec (x) = vec (x) # assume it's iterable
388
- _maybevec (x:: Tuple ) = [x... ]
389
- _maybevec (x:: VarName ) = [x]
390
- _maybevec (x:: Symbol ) = [x]
391
-
392
406
varinfo (state:: GibbsState ) = state. vi
393
407
394
408
function DynamicPPL. initialstep (
395
409
rng:: Random.AbstractRNG ,
396
410
model:: DynamicPPL.Model ,
397
411
spl:: DynamicPPL.Sampler{<:Gibbs} ,
398
- vi_base :: DynamicPPL.AbstractVarInfo ;
412
+ vi :: DynamicPPL.AbstractVarInfo ;
399
413
initial_params= nothing ,
400
414
kwargs... ,
401
415
)
402
416
alg = spl. alg
403
417
varnames = alg. varnames
404
418
samplers = alg. samplers
405
419
406
- # Run the model once to get the varnames present + initial values to condition on.
407
- vi = DynamicPPL . VarInfo ( rng, model)
408
- if initial_params != = nothing
409
- vi = DynamicPPL . unflatten (vi, initial_params )
410
- end
420
+ vi, states = gibbs_initialstep_recursive (
421
+ rng, model, varnames, samplers, vi; initial_params = initial_params, kwargs ...
422
+ )
423
+ return Transition (model, vi), GibbsState (vi, states )
424
+ end
411
425
412
- # Initialise each component sampler in turn, collect all their states.
413
- states = []
414
- for (varnames_local, sampler_local) in zip (varnames, samplers)
415
- varnames_local = _maybevec (varnames_local)
416
- # Get the initial values for this component sampler.
417
- initial_params_local = if initial_params === nothing
418
- nothing
419
- else
420
- DynamicPPL. subset (vi, varnames_local)[:]
421
- end
426
+ """
427
+ Take the first step of MCMC for the first component sampler, and call the same function
428
+ recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
429
+ and a tuple of initial states for all component samplers.
430
+ """
431
+ function gibbs_initialstep_recursive (
432
+ rng, model, varname_vecs, samplers, vi, states= (); initial_params= nothing , kwargs...
433
+ )
434
+ # End recursion
435
+ if isempty (varname_vecs) && isempty (samplers)
436
+ return vi, states
437
+ end
422
438
423
- # Construct the conditioned model.
424
- model_local, context_local = make_conditional (model, varnames_local, vi)
439
+ varnames, varname_vecs_tail ... = varname_vecs
440
+ sampler, samplers_tail ... = samplers
425
441
426
- # Take initial step.
427
- _, new_state_local = AbstractMCMC. step (
428
- rng,
429
- model_local,
430
- sampler_local;
431
- # FIXME : This will cause issues if the sampler expects initial params in unconstrained space.
432
- # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
433
- initial_params= initial_params_local,
434
- kwargs... ,
435
- )
436
- new_vi_local = varinfo (new_state_local)
437
- # Merge in any new variables that were introduced during the step, but that
438
- # were not in the domain of the current sampler.
439
- vi = merge (vi, get_global_varinfo (context_local))
440
- # Merge the new values for all the variables sampled by the current sampler.
441
- vi = merge (vi, new_vi_local)
442
- push! (states, new_state_local)
442
+ # Get the initial values for this component sampler.
443
+ initial_params_local = if initial_params === nothing
444
+ nothing
445
+ else
446
+ DynamicPPL. subset (vi, varnames)[:]
443
447
end
444
- return Transition (model, vi), GibbsState (vi, states)
448
+
449
+ # Construct the conditioned model.
450
+ conditioned_model, context = make_conditional (model, varnames, vi)
451
+
452
+ # Take initial step with the current sampler.
453
+ _, new_state = AbstractMCMC. step (
454
+ rng,
455
+ conditioned_model,
456
+ sampler;
457
+ # FIXME : This will cause issues if the sampler expects initial params in unconstrained space.
458
+ # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
459
+ initial_params= initial_params_local,
460
+ kwargs... ,
461
+ )
462
+ new_vi_local = varinfo (new_state)
463
+ # Merge in any new variables that were introduced during the step, but that
464
+ # were not in the domain of the current sampler.
465
+ vi = merge (vi, get_global_varinfo (context))
466
+ # Merge the new values for all the variables sampled by the current sampler.
467
+ vi = merge (vi, new_vi_local)
468
+
469
+ states = (states... , new_state)
470
+ return gibbs_initialstep_recursive (
471
+ rng,
472
+ model,
473
+ varname_vecs_tail,
474
+ samplers_tail,
475
+ vi,
476
+ states;
477
+ initial_params= initial_params,
478
+ kwargs... ,
479
+ )
445
480
end
446
481
447
482
function AbstractMCMC. step (
@@ -458,17 +493,7 @@ function AbstractMCMC.step(
458
493
states = state. states
459
494
@assert length (samplers) == length (state. states)
460
495
461
- # TODO : move this into a recursive function so we can unroll when reasonable?
462
- for index in 1 : length (samplers)
463
- # Take the inner step.
464
- sampler_local = samplers[index]
465
- state_local = states[index]
466
- varnames_local = _maybevec (varnames[index])
467
- vi, new_state_local = gibbs_step_inner (
468
- rng, model, varnames_local, sampler_local, state_local, vi; kwargs...
469
- )
470
- states = Accessors. setindex (states, new_state_local, index)
471
- end
496
+ vi, states = gibbs_step_recursive (rng, model, varnames, samplers, states, vi; kwargs... )
472
497
return Transition (model, vi), GibbsState (vi, states)
473
498
end
474
499
@@ -592,19 +617,33 @@ function match_linking!!(varinfo_local, prev_state_local, model)
592
617
return varinfo_local
593
618
end
594
619
595
- function gibbs_step_inner (
620
+ """
621
+ Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
622
+ function on the tail, until there are no more samplers left.
623
+ """
624
+ function gibbs_step_recursive (
596
625
rng:: Random.AbstractRNG ,
597
626
model:: DynamicPPL.Model ,
598
- varnames_local,
599
- sampler_local,
600
- state_local,
601
- global_vi;
627
+ varname_vecs,
628
+ samplers,
629
+ states,
630
+ global_vi,
631
+ new_states= ();
602
632
kwargs... ,
603
633
)
634
+ # End recursion.
635
+ if isempty (varname_vecs) && isempty (samplers) && isempty (states)
636
+ return global_vi, new_states
637
+ end
638
+
639
+ varnames, varname_vecs_tail... = varname_vecs
640
+ sampler, samplers_tail... = samplers
641
+ state, states_tail... = states
642
+
604
643
# Construct the conditional model and the varinfo that this sampler should use.
605
- model_local, context_local = make_conditional (model, varnames_local , global_vi)
606
- varinfo_local = subset (global_vi, varnames_local )
607
- varinfo_local = match_linking!! (varinfo_local, state_local , model)
644
+ conditioned_model, context = make_conditional (model, varnames , global_vi)
645
+ vi = subset (global_vi, varnames )
646
+ vi = match_linking!! (vi, state , model)
608
647
609
648
# TODO (mhauru) The below may be overkill. If the varnames for this sampler are not
610
649
# sampled by other samplers, we don't need to `setparams`, but could rather simply
@@ -615,18 +654,25 @@ function gibbs_step_inner(
615
654
# going to be a significant expense anyway.
616
655
# Set the state of the current sampler, accounting for any changes made by other
617
656
# samplers.
618
- state_local = setparams_varinfo!! (
619
- model_local, sampler_local, state_local, varinfo_local
620
- )
657
+ state = setparams_varinfo!! (conditioned_model, sampler, state, vi)
621
658
622
659
# Take a step with the local sampler.
623
- new_state_local = last (
624
- AbstractMCMC. step (rng, model_local, sampler_local, state_local; kwargs... )
625
- )
660
+ new_state = last (AbstractMCMC. step (rng, conditioned_model, sampler, state; kwargs... ))
626
661
627
- new_vi_local = varinfo (new_state_local )
662
+ new_vi_local = varinfo (new_state )
628
663
# Merge the latest values for all the variables in the current sampler.
629
- new_global_vi = merge (get_global_varinfo (context_local ), new_vi_local)
664
+ new_global_vi = merge (get_global_varinfo (context ), new_vi_local)
630
665
new_global_vi = setlogp!! (new_global_vi, getlogp (new_vi_local))
631
- return new_global_vi, new_state_local
666
+
667
+ new_states = (new_states... , new_state)
668
+ return gibbs_step_recursive (
669
+ rng,
670
+ model,
671
+ varname_vecs_tail,
672
+ samplers_tail,
673
+ states_tail,
674
+ new_global_vi,
675
+ new_states;
676
+ kwargs... ,
677
+ )
632
678
end
0 commit comments