Skip to content

Commit d8fbe78

Browse files
authored
pass initial_state through for NUTS sampling (#2680)
* pass initial_state through for NUTS sampling * Add a test * add test * bump patch
1 parent 48808be commit d8fbe78

File tree

4 files changed

+27
-1
lines changed

4 files changed

+27
-1
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# 0.40.4
2+
3+
Fixes a bug where `initial_state` was not respected for NUTS if `resume_from` was not also specified.
4+
15
# 0.40.3
26

37
This patch makes the `resume_from` keyword argument work correctly when sampling multiple chains.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.40.3"
3+
version = "0.40.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/hmc.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ function AbstractMCMC.sample(
120120
sampler,
121121
N;
122122
chain_type=chain_type,
123+
initial_state=initial_state,
123124
progress=progress,
124125
nadapts=_nadapts,
125126
discard_initial=_discard_initial,

test/mcmc/hmc.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,27 @@ using Turing
197197
@test_throws ErrorException sample(demo_impossible(), NUTS(), 5)
198198
end
199199

200+
@testset "NUTS initial parameters" begin
201+
@model function f()
202+
x ~ Normal()
203+
return 10 ~ Normal(x)
204+
end
205+
chn1 = sample(StableRNG(468), f(), NUTS(), 100; save_state=true)
206+
# chn1 should end up around x = 5.
207+
chn2 = sample(
208+
StableRNG(468),
209+
f(),
210+
NUTS(),
211+
10;
212+
nadapts=0,
213+
discard_adapt=false,
214+
initial_state=chn1.info.samplerstate,
215+
)
216+
# if chn2 uses initial_state, its first sample should be somewhere around 5. if
217+
# initial_state isn't used, it will be sampled from [-2, 2] so this test should fail
218+
@test isapprox(chn2[:x][1], 5.0; atol=2.0)
219+
end
220+
200221
@testset "(partially) issue: #2095" begin
201222
@model function vector_of_dirichlet((::Type{TV})=Vector{Float64}) where {TV}
202223
xs = Vector{TV}(undef, 2)

0 commit comments

Comments
 (0)