Skip to content

Commit 4f45089

Browse files
authored
Merge pull request #34 from TuringLang/csp/parnames
Fix parameter names when Transitions hold NamedTuples
2 parents 947357e + dc62fc7 commit 4f45089

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.5.2"
3+
version = "0.5.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/AdvancedMH.jl

+19
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,25 @@ function AbstractMCMC.bundle_samples(
7676
return nts
7777
end
7878

79+
function AbstractMCMC.bundle_samples(
80+
ts::Vector{<:Transition{<:NamedTuple}},
81+
model::DensityModel,
82+
sampler::MHSampler,
83+
state,
84+
chain_type::Type{Vector{NamedTuple}};
85+
param_names=missing,
86+
kwargs...
87+
)
88+
# If the element type of ts is NamedTuples, just use the names in the
89+
# struct.
90+
91+
# Extract NamedTuples
92+
nts = map(x -> merge(x.params, (lp=x.lp,)), ts)
93+
94+
# Return em'
95+
return nts
96+
end
97+
7998
function __init__()
8099
@require MCMCChains="c7f686f2-ff18-58e9-bc7b-31028e88f75d" include("mcmcchains-connect.jl")
81100
@require StructArrays="09ab397b-f2b6-538f-b94a-2f83cf4a842a" include("structarray-connect.jl")

test/runtests.jl

+5
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ using Test
8484
c2 = sample(m2, MetropolisHastings(p2), 100; chain_type=Vector{NamedTuple})
8585
c3 = sample(m3, MetropolisHastings(p3), 100; chain_type=Vector{NamedTuple})
8686
c4 = sample(m4, MetropolisHastings(p4), 100; chain_type=Vector{NamedTuple})
87+
88+
@test keys(c1[1]) == (:param_1, :lp)
89+
@test keys(c2[1]) == (:param_1, :param_2, :lp)
90+
@test keys(c3[1]) == (:a, :b, :lp)
91+
@test keys(c4[1]) == (:param_1, :lp)
8792
end
8893

8994
@testset "Initial parameters" begin

0 commit comments

Comments
 (0)