From 7d6f8ed53b59047c86fb1c09c9d593ca30250a60 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 8 Jan 2025 16:43:59 +0000 Subject: [PATCH] Rework Gibbs constructors (#2456) * Rework Gibbs constructors, and remove the dead test/experimental/gibbs.jl * Update HISTORY.md * Clarify docstring * Remove unnecessary _maybecollect in gibbs.jl * Fix a bug * Fix more Gibbs constructors in tests * Improve HISTORY.md note Co-authored-by: Penelope Yong * Apply proposals from code review * Add type bounds to Gibbs type parameters * Style improvements to gibbs.jl * Fix method ambiguity * Modify type signature of Gibbs --------- Co-authored-by: Penelope Yong --- HISTORY.md | 6 +- src/mcmc/gibbs.jl | 71 +++++---- test/dynamicppl/compiler.jl | 6 +- test/experimental/gibbs.jl | 271 ----------------------------------- test/mcmc/Inference.jl | 12 +- test/mcmc/ess.jl | 4 +- test/mcmc/gibbs.jl | 59 ++++---- test/mcmc/hmc.jl | 10 +- test/mcmc/mh.jl | 6 +- test/skipped/explicit_ret.jl | 2 +- 10 files changed, 92 insertions(+), 355 deletions(-) delete mode 100644 test/experimental/gibbs.jl diff --git a/HISTORY.md b/HISTORY.md index ff50fb779..64be69106 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,13 +4,11 @@ 0.36.0 introduces a new Gibbs sampler. It's been included in several previous releases as `Turing.Experimental.Gibbs`, but now takes over the old Gibbs sampler, which gets removed completely. -The new Gibbs sampler supports the same user-facing interface as the old one. However, given -that the internals of it having been completely rewritten in a very different manner, there -may be accidental breakage that we haven't anticipated. Please report any you find. +The new Gibbs sampler currently supports the same user-facing interface as the old one, but the old constructors have been deprecated and will be removed in the future. Also, given that the internals have been completely rewritten in a very different manner, there may be accidental breakage that we haven't anticipated. Please report any you find. `GibbsConditional` has also been removed. It was never very user-facing, but it was exported, so technically this is breaking. -The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable. +The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated and will be removed in the future. The new constructor works by mapping symbols, `VarName`s, or iterables thereof to samplers, e.g. `Gibbs(:x => HMC(), :y => MH())`, `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`, or `Gibbs((:x, :y) => NUTS(), :z => MH())`. This allows more granular specification of which sampler to use for which variable. Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(0.01, 4, :x), 2), (MH(:y), 1))`, has been deprecated. The new way to do this is to use `RepeatSampler`, also introduced in this version: `Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())`.
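To make the migration concrete, here is a minimal end-to-end sketch; the two-variable `demo` model is invented purely for illustration, while the constructor calls mirror the examples above.

```julia
using Turing

# A hypothetical model with one variable per component sampler.
@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1.0)
end

# Old, deprecated: each subsampler named its own targets.
#     Gibbs(HMC(0.01, 4, :x), MH(:y))
# New: map Symbols or VarNames to samplers.
spl = Gibbs(:x => HMC(0.01, 4), :y => MH())

# Old, deprecated: repeat counts via (sampler, n) tuples.
#     Gibbs((HMC(0.01, 4, :x), 2), (MH(:y), 1))
# New: wrap the component sampler in RepeatSampler.
spl_repeat = Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())

chain = sample(demo(), spl, 1_000)
```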
diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 0f2c78ebe..ada5f611b 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -292,15 +292,40 @@ function set_selector(x::RepeatSampler) end set_selector(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0)) +to_varname_list(x::Union{VarName,Symbol}) = [VarName(x)] +# Any other value is assumed to be an iterable of VarNames and Symbols. +to_varname_list(t) = collect(map(VarName, t)) + """ Gibbs A type representing a Gibbs sampler. +# Constructors + +`Gibbs` needs to be given a set of pairs of variable names and samplers. Instead of a single +variable name per sampler, one can also give an iterable of variables, all of which are +sampled by the same component sampler. + +Each variable name can be given as either a `Symbol` or a `VarName`. + +Some examples of valid constructors are: +```julia +Gibbs(:x => NUTS(), :y => MH()) +Gibbs(@varname(x) => NUTS(), @varname(y) => MH()) +Gibbs((@varname(x), :y) => NUTS(), :z => MH()) +``` + +Currently only variable names without indexing are supported, so for instance +`Gibbs(@varname(x[1]) => NUTS())` does not work. This will hopefully change in the future. + # Fields $(TYPEDFIELDS) """ -struct Gibbs{V,A} <: InferenceAlgorithm +struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <: + InferenceAlgorithm + # TODO(mhauru) Revisit whether A should have a fixed element type once + # InferenceAlgorithm/Sampler types have been cleaned up. "varnames representing variables for each sampler" varnames::V "samplers for each entry in `varnames`" @@ -310,40 +335,30 @@ struct Gibbs{V,A} <: InferenceAlgorithm if length(varnames) != length(samplers) throw(ArgumentError("Number of varnames and samplers must match.")) end + for spl in samplers if !isgibbscomponent(spl) msg = "All samplers must be valid Gibbs components, $(spl) is not." throw(ArgumentError(msg)) end end - return new{typeof(varnames),typeof(samplers)}(varnames, samplers) - end -end - -to_varname(vn::VarName) = vn -to_varname(s::Symbol) = VarName{s}() -# Any other value is assumed to be an iterable. -to_varname(t) = map(to_varname, collect(t)) -# NamedTuple -Gibbs(; algs...) = Gibbs(NamedTuple(algs)) -function Gibbs(algs::NamedTuple) - return Gibbs(map(to_varname, keys(algs)), map(set_selector ∘ drop_space, values(algs))) + # Ensure that samplers have the same selector, and that varnames are lists of + # VarNames. + samplers = tuple(map(set_selector ∘ drop_space, samplers)...) + varnames = tuple(map(to_varname_list, varnames)...) + return new{length(samplers),typeof(varnames),typeof(samplers)}(varnames, samplers) + end end -# AbstractDict -function Gibbs(algs::AbstractDict) - return Gibbs( - map(to_varname, collect(keys(algs))), map(set_selector ∘ drop_space, values(algs)) - ) -end function Gibbs(algs::Pair...) - return Gibbs(map(to_varname ∘ first, algs), map(set_selector ∘ drop_space ∘ last, algs)) + return Gibbs(map(first, algs), map(last, algs)) end # The below two constructors only provide backwards compatibility with the constructor of # the old Gibbs sampler. They are deprecated and will be removed in the future. -function Gibbs(algs::InferenceAlgorithm...) +function Gibbs(alg1::InferenceAlgorithm, other_algs::InferenceAlgorithm...) + algs = [alg1, other_algs...]
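    # Illustration, not part of the original patch: this deprecated path recovers
    # each sampler's targets from its legacy "space" (via getspace below), so a
    # call such as Gibbs(HMC(0.01, 4, :x), MH(:y)) behaves like the new form
    # Gibbs(:x => HMC(0.01, 4), :y => MH()).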
varnames = map(algs) do alg space = getspace(alg) if (space isa VarName) @@ -365,7 +380,11 @@ function Gibbs(algs::InferenceAlgorithm...) return Gibbs(varnames, map(set_selector ∘ drop_space, algs)) end -function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...) +function Gibbs( + alg_with_iters1::Tuple{<:InferenceAlgorithm,Int}, + other_algs_with_iters::Tuple{<:InferenceAlgorithm,Int}..., +) + algs_with_iters = [alg_with_iters1, other_algs_with_iters...] algs = Iterators.map(first, algs_with_iters) iters = Iterators.map(last, algs_with_iters) algs_duplicated = Iterators.flatten(( @@ -384,11 +403,6 @@ struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} states::S end -_maybevec(x) = vec(x) # assume it's iterable -_maybevec(x::Tuple) = [x...] -_maybevec(x::VarName) = [x] -_maybevec(x::Symbol) = [x] - varinfo(state::GibbsState) = state.vi function DynamicPPL.initialstep( @@ -412,7 +426,6 @@ function DynamicPPL.initialstep( # Initialise each component sampler in turn, collect all their states. states = [] for (varnames_local, sampler_local) in zip(varnames, samplers) - varnames_local = _maybevec(varnames_local) # Get the initial values for this component sampler. initial_params_local = if initial_params === nothing nothing @@ -463,7 +476,7 @@ function AbstractMCMC.step( # Take the inner step. sampler_local = samplers[index] state_local = states[index] - varnames_local = _maybevec(varnames[index]) + varnames_local = varnames[index] vi, new_state_local = gibbs_step_inner( rng, model, varnames_local, sampler_local, state_local, vi; kwargs... ) diff --git a/test/dynamicppl/compiler.jl b/test/dynamicppl/compiler.jl index 7939c7beb..7f5726614 100644 --- a/test/dynamicppl/compiler.jl +++ b/test/dynamicppl/compiler.jl @@ -54,7 +54,7 @@ const gdemo_default = gdemo_d() smc = SMC() pg = PG(10) - gibbs = Gibbs(; p=HMC(0.2, 3), x=PG(10)) + gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10)) chn_s = sample(testbb(obs), smc, 1000) chn_p = sample(testbb(obs), pg, 2000) @@ -81,7 +81,7 @@ const gdemo_default = gdemo_d() return s, m end - gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8)) + gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8)) chain = sample(fggibbstest(xs), gibbs, 2) end @testset "new grammar" begin @@ -177,7 +177,7 @@ const gdemo_default = gdemo_d() end @testset "sample" begin - alg = Gibbs(; m=HMC(0.2, 3), s=PG(10)) + alg = Gibbs(:m => HMC(0.2, 3), :s => PG(10)) chn = sample(gdemo_default, alg, 1000) end diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl deleted file mode 100644 index 70546350d..000000000 --- a/test/experimental/gibbs.jl +++ /dev/null @@ -1,271 +0,0 @@ -module ExperimentalGibbsTests - -using ..Models: MoGtest_default, MoGtest_default_z_vector, gdemo -using ..NumericalTests: check_MoGtest_default, check_MoGtest_default_z_vector, check_gdemo, - check_numerical, two_sample_test -using DynamicPPL -using Random -using Test -using Turing -using Turing.Inference: AdvancedHMC, AdvancedMH -using ForwardDiff: ForwardDiff -using ReverseDiff: ReverseDiff - -function check_transition_varnames( - transition::Turing.Inference.Transition, - parent_varnames -) - transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val - [first(vn_and_val)] - end - # Varnames in `transition` should be subsumed by those in `vns`. 
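# (Context note, not part of the deleted file: DynamicPPL.subsumes(u, v) holds when
# u is a prefix of v, e.g. subsumes(@varname(x), @varname(x[1])) is true, while the
# converse is false.)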
- for vn in transition_varnames - @test any(Base.Fix2(DynamicPPL.subsumes, vn), parent_varnames) - end -end - -const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ - Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe_literal)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe_literal)}, - Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix)}, -} -has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false -has_dot_assume(::Model) = true - -@testset "Gibbs using `condition`" begin - @testset "Demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - # Run one sampler on variables starting with `s` and another on variables starting with `m`. - vns_s = filter(vns) do vn - DynamicPPL.getsym(vn) == :s - end - vns_m = filter(vns) do vn - DynamicPPL.getsym(vn) == :m - end - - samplers = [ - Turing.Experimental.Gibbs( - vns_s => NUTS(), - vns_m => NUTS(), - ), - Turing.Experimental.Gibbs( - vns_s => NUTS(), - vns_m => HMC(0.01, 4), - ) - ] - - if !has_dot_assume(model) - # Add in some MH samplers, which are not compatible with `.~`. - append!( - samplers, - [ - Turing.Experimental.Gibbs( - vns_s => HMC(0.01, 4), - vns_m => MH(), - ), - Turing.Experimental.Gibbs( - vns_s => MH(), - vns_m => HMC(0.01, 4), - ) - ] - ) - end - - @testset "$sampler" for sampler in samplers - # Check that taking steps performs as expected. - rng = Random.default_rng() - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) - check_transition_varnames(transition, vns) - end - end - - @testset "comparison with 'gold-standard' samples" begin - num_iterations = 1_000 - thinning = 10 - num_chains = 4 - - # Determine initial parameters to make comparison as fair as possible. - posterior_mean = DynamicPPL.TestUtils.posterior_mean(model) - initial_params = DynamicPPL.TestUtils.update_values!!( - DynamicPPL.VarInfo(model), - posterior_mean, - DynamicPPL.TestUtils.varnames(model), - )[:] - initial_params = fill(initial_params, num_chains) - - # Sampler to use for Gibbs components. - sampler_inner = HMC(0.1, 32) - sampler = Turing.Experimental.Gibbs( - vns_s => sampler_inner, - vns_m => sampler_inner, - ) - Random.seed!(42) - chain = sample( - model, - sampler, - MCMCThreads(), - num_iterations, - num_chains; - progress=false, - initial_params=initial_params, - discard_initial=1_000, - thinning=thinning - ) - - # "Ground truth" samples. - # TODO: Replace with closed-form sampling once that is implemented in DynamicPPL. - Random.seed!(42) - chain_true = sample( - model, - NUTS(), - MCMCThreads(), - num_iterations, - num_chains; - progress=false, - initial_params=initial_params, - thinning=thinning, - ) - - # Perform KS test to ensure that the chains are similar. - xs = Array(chain) - xs_true = Array(chain_true) - for i = 1:size(xs, 2) - @test two_sample_test(xs[:, i], xs_true[:, i]; warn_on_fail=true) - # Let's make sure that the significance level is not too low by - # checking that the KS test fails for some simple transformations. 
- # TODO: Replace the heuristic below with closed-form implementations - # of the targets, once they are implemented in DynamicPPL. - @test !two_sample_test(0.9 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1.1 .* xs_true[:, i], xs_true[:, i]) - @test !two_sample_test(1e-1 .+ xs_true[:, i], xs_true[:, i]) - end - end - end - end - - @testset "multiple varnames" begin - rng = Random.default_rng() - - @testset "with both `s` and `m` as random" begin - model = gdemo(1.5, 2.0) - vns = (@varname(s), @varname(m)) - alg = Turing.Experimental.Gibbs(vns => MH()) - - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) - check_transition_varnames(transition, vns) - end - - # `sample` - Random.seed!(42) - chain = sample(model, alg, 10_000; progress=false) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) - end - - @testset "without `m` as random" begin - model = gdemo(1.5, 2.0) | (m=7 / 6,) - vns = (@varname(s),) - alg = Turing.Experimental.Gibbs(vns => MH()) - - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) - check_transition_varnames(transition, vns) - end - end - end - - @testset "CSMC + ESS" begin - rng = Random.default_rng() - model = MoGtest_default - alg = Turing.Experimental.Gibbs( - (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), - @varname(mu1) => ESS(), - @varname(mu2) => ESS(), - ) - vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - end - - # Sample! - Random.seed!(42) - chain = sample(MoGtest_default, alg, 1000; progress=false) - check_MoGtest_default(chain, atol = 0.2) - end - - @testset "CSMC + ESS (usage of implicit varname)" begin - rng = Random.default_rng() - model = MoGtest_default_z_vector - alg = Turing.Experimental.Gibbs( - @varname(z) => CSMC(15), - @varname(mu1) => ESS(), - @varname(mu2) => ESS(), - ) - vns = (@varname(z[1]), @varname(z[2]), @varname(z[3]), @varname(z[4]), @varname(mu1), @varname(mu2)) - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) - check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - end - - # Sample! 
- Random.seed!(42) - chain = sample(model, alg, 1000; progress=false) - check_MoGtest_default_z_vector(chain, atol = 0.2) - end - - @testset "externsalsampler" begin - @model function demo_gibbs_external() - m1 ~ Normal() - m2 ~ Normal() - - -1 ~ Normal(m1, 1) - +1 ~ Normal(m1 + m2, 1) - - return (; m1, m2) - end - - model = demo_gibbs_external() - samplers_inner = [ - externalsampler(AdvancedMH.RWMH(1)), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoForwardDiff()), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff()), - externalsampler(AdvancedHMC.HMC(1e-1, 32), adtype=AutoReverseDiff(compile=true)), - ] - @testset "$(sampler_inner)" for sampler_inner in samplers_inner - sampler = Turing.Experimental.Gibbs( - @varname(m1) => sampler_inner, - @varname(m2) => sampler_inner, - ) - Random.seed!(42) - chain = sample(model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0) - check_numerical(chain, [:m1, :m2], [-0.2, 0.6], atol=0.1) - end - end -end - -end diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 9356fbcc1..da29e7708 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -31,8 +31,8 @@ using Turing PG(10), IS(), MH(), - Gibbs(; s=PG(3), m=HMC(0.4, 8; adtype=adbackend)), - Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()), + Gibbs(:s => PG(3), :m => HMC(0.4, 8; adtype=adbackend)), + Gibbs(:s => HMC(0.1, 5; adtype=adbackend), :m => ESS()), ) for sampler in samplers Random.seed!(5) @@ -81,7 +81,7 @@ using Turing @testset "chain save/resume" begin alg1 = HMCDA(1000, 0.65, 0.15; adtype=adbackend) alg2 = PG(20) - alg3 = Gibbs(; s=PG(30), m=HMC(0.2, 4; adtype=adbackend)) + alg3 = Gibbs(:s => PG(30), :m => HMC(0.2, 4; adtype=adbackend)) chn1 = sample(StableRNG(seed), gdemo_default, alg1, 2_000; save_state=true) check_gdemo(chn1) @@ -260,7 +260,7 @@ using Turing smc = SMC() pg = PG(10) - gibbs = Gibbs(; p=HMC(0.2, 3; adtype=adbackend), x=PG(10)) + gibbs = Gibbs(:p => HMC(0.2, 3; adtype=adbackend), :x => PG(10)) chn_s = sample(StableRNG(seed), testbb(obs), smc, 200) chn_p = sample(StableRNG(seed), testbb(obs), pg, 200) @@ -288,7 +288,7 @@ using Turing return s, m end - gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8; adtype=adbackend)) + gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8; adtype=adbackend)) chain = sample(StableRNG(seed), fggibbstest(xs), gibbs, 2) end @@ -415,7 +415,7 @@ using Turing end @testset "sample" begin - alg = Gibbs(; m=HMC(0.2, 3; adtype=adbackend), s=PG(10)) + alg = Gibbs(:m => HMC(0.2, 3; adtype=adbackend), :s => PG(10)) chn = sample(StableRNG(seed), gdemo_default, alg, 10) end diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 6db469b76..5533d11d7 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -40,7 +40,7 @@ using Turing c3 = sample(demodot_default, s1, N) c4 = sample(demodot_default, s2, N) - s3 = Gibbs(; m=ESS(), s=MH()) + s3 = Gibbs(:m => ESS(), :s => MH()) c5 = sample(gdemo_default, s3, N) end @@ -59,7 +59,7 @@ using Turing end @testset "gdemo with CSMC + ESS" begin - alg = Gibbs(; s=CSMC(15), m=ESS()) + alg = Gibbs(:s => CSMC(15), :m => ESS()) chain = sample(StableRNG(seed), gdemo(1.5, 2.0), alg, 2000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 43bdcdbb8..1d7208b43 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -130,6 +130,9 @@ end @test_throws ArgumentError Gibbs( @varname(s) => SGLD(; stepsize=PolynomialStepsize(0.25)) ) + # Values that we don't know how to convert to VarNames. 
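    # (Illustration: conversion goes through the VarName constructor, which has
    # methods for Symbol and VarName but none for values like 1 or "x", so the
    # pairs below throw a MethodError.)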
+ @test_throws MethodError Gibbs(1 => NUTS()) + @test_throws MethodError Gibbs("x" => NUTS()) end # Test that the samplers are being called in the correct order, on the correct target @@ -227,14 +230,14 @@ end nuts = NUTS() # Sample with all sorts of combinations of samplers and targets. sampler = Gibbs( - (@varname(s),) => AlgWrapper(mh), + @varname(s) => AlgWrapper(mh), (@varname(s), @varname(m)) => AlgWrapper(mh), - (@varname(m),) => AlgWrapper(pg), - (@varname(xs),) => AlgWrapper(hmc), - (@varname(ys),) => AlgWrapper(nuts), - (@varname(ys),) => AlgWrapper(nuts), + @varname(m) => AlgWrapper(pg), + @varname(xs) => AlgWrapper(hmc), + @varname(ys) => AlgWrapper(nuts), + @varname(ys) => AlgWrapper(nuts), (@varname(xs), @varname(ys)) => AlgWrapper(hmc), - (@varname(s),) => AlgWrapper(mh), + @varname(s) => AlgWrapper(mh), ) chain = sample(test_model(-1), sampler, 2) @@ -300,17 +303,12 @@ end # Two variables being sampled by one sampler. s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend)) s2 = Gibbs((@varname(s), :m) => PG(10)) - # One variable per sampler, using the keyword arg interface. - s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend))) - # As above but using a Dict of VarNames. - s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend))) # As above but with different samplers. - s5 = Gibbs(; s=CSMC(3), m=HMCDA(200, 0.65, 0.15; adtype=adbackend)) - s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS()) - s7 = Gibbs(Dict((:s, @varname(m)) => PG(10))) + s3 = Gibbs(:s => CSMC(3), :m => HMCDA(200, 0.65, 0.15; adtype=adbackend)) + s4 = Gibbs(@varname(s) => HMC(0.1, 5; adtype=adbackend), @varname(m) => ESS()) # Multiple instances of the same sampler. This implements running, in this case, # 3 steps of HMC on s and 2 steps of PG on m in every iteration of Gibbs. - s8 = begin + s5 = begin hmc = HMC(0.1, 5; adtype=adbackend) pg = PG(10) vns = @varname(s) @@ -318,23 +316,20 @@ end Gibbs(vns => hmc, vns => hmc, vns => hmc, vnm => pg, vnm => pg) end # Same thing but using RepeatSampler. - s9 = Gibbs( + s6 = Gibbs( @varname(s) => RepeatSampler(HMC(0.1, 5; adtype=adbackend), 3), @varname(m) => RepeatSampler(PG(10), 2), ) - for s in (s1, s2, s3, s4, s5, s6, s7, s8, s9) + for s in (s1, s2, s3, s4, s5, s6) @test DynamicPPL.alg_str(Turing.Sampler(s, gdemo_default)) == "Gibbs" end - sample(gdemo_default, s1, N) - sample(gdemo_default, s2, N) - sample(gdemo_default, s3, N) - sample(gdemo_default, s4, N) - sample(gdemo_default, s5, N) - sample(gdemo_default, s6, N) - sample(gdemo_default, s7, N) - sample(gdemo_default, s8, N) - sample(gdemo_default, s9, N) + @test sample(gdemo_default, s1, N) isa MCMCChains.Chains + @test sample(gdemo_default, s2, N) isa MCMCChains.Chains + @test sample(gdemo_default, s3, N) isa MCMCChains.Chains + @test sample(gdemo_default, s4, N) isa MCMCChains.Chains + @test sample(gdemo_default, s5, N) isa MCMCChains.Chains + @test sample(gdemo_default, s6, N) isa MCMCChains.Chains g = Turing.Sampler(s3, gdemo_default) @test sample(gdemo_default, g, N) isa MCMCChains.Chains @@ -344,7 +339,7 @@ end # posterior mean. @testset "Gibbs inference" begin @testset "CSMC and HMC on gdemo" begin - alg = Gibbs(; s=CSMC(15), m=HMC(0.2, 4; adtype=adbackend)) + alg = Gibbs(:s => CSMC(15), :m => HMC(0.2, 4; adtype=adbackend)) chain = sample(gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:m], [7 / 6]; atol=0.15) # Be more relaxed with the tolerance of the variance.
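As an aside on where the reference values come from: assuming the standard `gdemo` model, `s ~ InverseGamma(2, 3)` and `m ~ Normal(0, sqrt(s))` with observations 1.5 and 2.0 (the model definition is not shown in this patch), the checks follow from conjugacy:

```julia
# Normal-inverse-gamma conjugate update for gdemo(1.5, 2.0), a sketch:
#   s | data ~ InverseGamma(2 + 1, 3 + 0.125 / 2 + 2 * 1.75^2 / (2 * 3))
#            = InverseGamma(3, 49 / 12), so E[s] = (49 / 12) / (3 - 1) = 49 / 24.
#   m | s, data has precision 3 / s and mean (0 + 1.5 + 2.0) / 3 = 7 / 6 for any s,
#   so E[m] = 7 / 6.
```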
@@ -352,13 +347,13 @@ end end @testset "MH and HMCDA on gdemo" begin - alg = Gibbs(; s=MH(), m=HMCDA(200, 0.65, 0.3; adtype=adbackend)) + alg = Gibbs(:s => MH(), :m => HMCDA(200, 0.65, 0.3; adtype=adbackend)) chain = sample(gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end @testset "CSMC and ESS on gdemo" begin - alg = Gibbs(; s=CSMC(15), m=ESS()) + alg = Gibbs(:s => CSMC(15), :m => ESS()) chain = sample(gdemo(1.5, 2.0), alg, 3_000) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) end @@ -435,7 +430,7 @@ end return nothing end - alg = Gibbs(; s=MH(), m=HMC(0.2, 4; adtype=adbackend)) + alg = Gibbs(:s => MH(), :m => HMC(0.2, 4; adtype=adbackend)) sample(model, alg, 100; callback=callback) end @@ -463,11 +458,11 @@ end num_samples = 10_000 model = imm(Random.randn(num_zs), 1.0) # https://github.com/TuringLang/Turing.jl/issues/1725 - # sample(model, Gibbs(; z=MH(), m=HMC(0.01, 4)), 100); + # sample(model, Gibbs(:z => MH(), :m => HMC(0.01, 4)), 100); chn = sample( StableRNG(23), model, - Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), + Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), num_samples, ) # The number of m variables that have a non-zero value in a sample. @@ -527,7 +522,7 @@ end # TODO(mhauru) This is broken because of # https://github.com/TuringLang/DynamicPPL.jl/issues/700. @test_broken ( - sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100); + sample(model, Gibbs(:z => PG(10), :m => HMC(0.01, 4; adtype=adbackend)), 100); true ) end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 47ff73b1c..d45846f3d 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -146,7 +146,9 @@ using Turing # explicitly specifying the seeds here. @testset "hmcda+gibbs inference" begin Random.seed!(12345) - alg = Gibbs(; s=PG(20), m=HMCDA(500, 0.8, 0.25; init_ϵ=0.05, adtype=adbackend)) + alg = Gibbs( + :s => PG(20), :m => HMCDA(500, 0.8, 0.25; init_ϵ=0.05, adtype=adbackend) + ) res = sample(StableRNG(123), gdemo_default, alg, 3000; discard_initial=1000) check_gdemo(res) end @@ -199,9 +201,9 @@ using Turing end @testset "AHMC resize" begin - alg1 = Gibbs(; m=PG(10), s=NUTS(100, 0.65; adtype=adbackend)) - alg2 = Gibbs(; m=PG(10), s=HMC(0.1, 3; adtype=adbackend)) - alg3 = Gibbs(; m=PG(10), s=HMCDA(100, 0.65, 0.3; adtype=adbackend)) + alg1 = Gibbs(:m => PG(10), :s => NUTS(100, 0.65; adtype=adbackend)) + alg2 = Gibbs(:m => PG(10), :s => HMC(0.1, 3; adtype=adbackend)) + alg3 = Gibbs(:m => PG(10), :s => HMCDA(100, 0.65, 0.3; adtype=adbackend)) @test sample(StableRNG(seed), gdemo_default, alg1, 10) isa Chains @test sample(StableRNG(seed), gdemo_default, alg2, 10) isa Chains @test sample(StableRNG(seed), gdemo_default, alg3, 10) isa Chains diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index 3823c2986..7d5a841d4 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -34,7 +34,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) c2 = sample(gdemo_default, s2, N) c3 = sample(gdemo_default, s3, N) - s4 = Gibbs(; m=MH(), s=MH()) + s4 = Gibbs(:m => MH(), :s => MH()) c4 = sample(gdemo_default, s4, N) # s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) @@ -69,7 +69,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) end @testset "gdemo_default with MH-within-Gibbs" begin - alg = Gibbs(; m=MH(), s=MH()) + alg = Gibbs(:m => MH(), :s => MH()) chain = sample( StableRNG(seed), gdemo_default, alg, 10_000; discard_initial, initial_params ) @@ -177,7 +177,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) # 
with small-valued VC matrix to check if we only see very small steps vc_μ = convert(Array, 1e-4 * I(2)) vc_σ = convert(Array, 1e-4 * I(2)) - alg_small = Gibbs(; μ=MH((:μ, vc_μ)), σ=MH((:σ, vc_σ))) + alg_small = Gibbs(:μ => MH((:μ, vc_μ)), :σ => MH((:σ, vc_σ))) alg_big = MH() chn_small = sample(StableRNG(seed), mod, alg_small, 1_000) chn_big = sample(StableRNG(seed), mod, alg_big, 1_000) diff --git a/test/skipped/explicit_ret.jl b/test/skipped/explicit_ret.jl index 2dabc09bd..7472e8c31 100644 --- a/test/skipped/explicit_ret.jl +++ b/test/skipped/explicit_ret.jl @@ -12,7 +12,7 @@ end mf = test_ex_rt() for alg in - [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(; x=PG(20, 1), y=HMC(0.2, 3))] + [HMC(0.2, 3), PG(20, 2000), SMC(), IS(10000), Gibbs(:x => PG(20, 1), :y => HMC(0.2, 3))] chn = sample(mf, alg) @test mean(chn[:x]) ≈ 10.0 atol = 0.2 @test mean(chn[:y]) ≈ 5.0 atol = 0.2