Accumulator miscellanea: Subset, merge, acclogp, and LogProbAccumulator #999
Conversation
Benchmark Report for Commit f0519df
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@           Coverage Diff            @@
##           breaking    #999   +/- ##
=========================================
  Coverage     81.91%   81.91%
=========================================
  Files            38       38
  Lines          4041     4025      -16
=========================================
- Hits           3310     3297      -13
+ Misses          731      728       -3
```
I decided against implementing

DynamicPPL.jl documentation for PR #999 is available at:
```julia
@generated function _joint_keys(
    nt1::NamedTuple{names1}, nt2::NamedTuple{names2}
) where {names1,names2}
    only_in_nt1 = tuple(setdiff(names1, names2)...)
    only_in_nt2 = tuple(setdiff(names2, names1)...)
    in_both = tuple(intersect(names1, names2)...)
    return :($only_in_nt1, $only_in_nt2, $in_both)
end
```
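To illustrate the partition that `_joint_keys` computes at compile time, here is a hedged runtime sketch (the function name `joint_keys_sketch` is illustrative, not part of the PR; it mirrors the generated function's logic with plain tuples):

```julia
# Runtime equivalent of the partition `_joint_keys` builds at compile time:
# keys only in the first tuple, keys only in the second, and keys in both.
function joint_keys_sketch(names1::Tuple, names2::Tuple)
    only_in_1 = tuple(setdiff(names1, names2)...)
    only_in_2 = tuple(setdiff(names2, names1)...)
    in_both = tuple(intersect(names1, names2)...)
    return only_in_1, only_in_2, in_both
end

joint_keys_sketch((:LogPrior, :NumProduce), (:LogPrior, :LogLikelihood))
# → ((:NumProduce,), (:LogLikelihood,), (:LogPrior,))
```

The `@generated` version returns the same three tuples, but as compile-time constants, which is what lets the `merge` below stay allocation-light.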
```julia
"""
    merge(at1::AccumulatorTuple, at2::AccumulatorTuple)

Merge two `AccumulatorTuple`s.

For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two
accumulators themselves. Other accumulators are copied.
"""
function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
    keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt)
    accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1)
    accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2)
    accs_in_both = (
        merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both
    )
    return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...)
end
```
I checked, and this implementation causes only one allocation of 32 bytes, due to a new `AccumulatorTuple` being created.
@penelopeysm up to you when/where you want this merged. I made your
If two accumulators of the same type should be merged in some non-trivial way, other than always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined.
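As a sketch of what such a non-trivial overload might look like (the accumulator type and its field here are hypothetical, not from the PR):

```julia
# Hypothetical accumulator that counts events; purely illustrative,
# not part of DynamicPPL's actual accumulator set.
struct CallCountAccumulator
    count::Int
end

# Non-trivial merge: sum the counts instead of keeping only the second one.
Base.merge(a::CallCountAccumulator, b::CallCountAccumulator) =
    CallCountAccumulator(a.count + b.count)

merge(CallCountAccumulator(2), CallCountAccumulator(3))  # CallCountAccumulator(5)
```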
- What's the difference between `combine` and `merge`?
- I guess this follows on from our conversation on Slack, but I wonder if there is a way to restrict the call on subset/merge to only the accumulators we care about? Basically, is it possible for us to circumvent the need to call `subset` on the entire AccumulatorTuple and thus avoid including slightly weird implementations of `subset` on the logp accumulators? Mainly inspired by how you avoided implementing these for e.g. PLDAcc.
`combine` and `merge` probably do the same thing for most accumulators, but I'm not sure they have to for all accumulators. `combine` has the restriction that `combine(acc, split(acc)) == acc`, whereas `merge` could do anything that is desirable on a call to `merge`.

I don't know how to specify a subset of accumulators to call merge/subset on in a way that is any simpler than this implementation. Even before this PR, we used to call `copy` on accs in merge and subset. Now that has just been shifted to be a method of `AbstractAccumulator` that is the fallback for what to do when a subtype doesn't specify anything else.
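For concreteness, a minimal sketch of that fallback pattern, assuming illustrative signatures (the actual DynamicPPL method signatures, and what `copy` means for each accumulator, may differ):

```julia
abstract type AbstractAccumulator end

# Fallback for all accumulators: `subset` and `merge` just copy.
# (Assumes each concrete accumulator defines a sensible `copy`.)
subset(acc::AbstractAccumulator, vns) = copy(acc)
Base.merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) = copy(acc2)

# A subtype that needs non-trivial behaviour overrides these methods;
# everything else inherits the copying fallback for free.
```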
Co-authored-by: Penelope Yong <[email protected]>
I'm actually super duper sorry to be annoying about this, but I still can't get over my hangup with `merge`. I took it upon myself to at least try to find a constructive way to solve this, and I think I have a good proposal.

(Even if we don't go with this, it was quite a fruitful exercise and along the way I found two new bugs, so that makes me feel quite happy!)

My starting point was: why can't we say that `merge(vi1, vi2)` just takes the accumulators from `vi2`? Indeed, that's what the current implementation (prior to this PR) does.

The reason why we're defining `merge` is solely because of VariableOrderAccumulator, and specifically, because of this line:

It is the call to `set_retained_vns_del!` (line 349) that fails.

But the fact that this is called immediately after `reset_num_produce!!` (line 342) makes me think that surely, one can just replace the call to `set_retained_vns_del!` with a manual

```diff
  # For all other particles, do not retain the variables but resample them.
- # DynamicPPL.set_retained_vns_del!(vi)
+ for vn in keys(vi)
+     DynamicPPL.set_flag!(vi, vn, "del")
+ end
```

Indeed, this also fixes the `key x not found` bug that is currently causing Turing CI to fail, arguably in a more direct way.

One might think (based on conversations we've had before, etc.) that we should only set the `del` flag for variables which have `order` strictly larger than 0. Actually, it turns out this is not true: if you look at the implementation of `set_retained_vns_del!` on main (lines 1905 to 1956 in ce7c8b1):
""" | |
set_retained_vns_del!(vi::VarInfo) | |
Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`. | |
""" | |
function set_retained_vns_del!(vi::UntypedVarInfo) | |
idcs = _getidcs(vi) | |
if get_num_produce(vi) == 0 | |
for i in length(idcs):-1:1 | |
vi.metadata.flags["del"][idcs[i]] = true | |
end | |
else | |
for i in 1:length(vi.orders) | |
if i in idcs && vi.orders[i] > get_num_produce(vi) | |
vi.metadata.flags["del"][i] = true | |
end | |
end | |
end | |
return nothing | |
end | |
function set_retained_vns_del!(vi::NTVarInfo) | |
idcs = _getidcs(vi) | |
return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) | |
end | |
@generated function _set_retained_vns_del!( | |
metadata, idcs::NamedTuple{names}, num_produce | |
) where {names} | |
expr = Expr(:block) | |
for f in names | |
f_idcs = :(idcs.$f) | |
f_orders = :(metadata.$f.orders) | |
f_flags = :(metadata.$f.flags) | |
push!( | |
expr.args, | |
quote | |
# Set the flag for variables with symbol `f` | |
if num_produce == 0 | |
for i in length($f_idcs):-1:1 | |
$f_flags["del"][$f_idcs[i]] = true | |
end | |
else | |
for i in 1:length($f_orders) | |
if i in $f_idcs && $f_orders[i] > num_produce | |
$f_flags["del"][i] = true | |
end | |
end | |
end | |
end, | |
) | |
end | |
return expr | |
end |
If `num_produce == 0`, it actually sets the `del` flag for every variable (this was something I just found out too). Logically this also makes sense to me, because otherwise you'd never resample the first set of variables.

So the docstring on main (and indeed the current implementation with accumulators) is wrong. I've put in #1000 to hotfix this (and hopefully we can eventually get rid of that function).
Results

So, there's good news and bad news.

The good news is that I tried it with my proposed fix instead (without this PR), and sampling this model gives almost the same results as this PR:

```julia
using Turing, Random

@model function f()
    x ~ Normal()
    y ~ Normal(x)
    1.0 ~ Normal(y)
end

chn = sample(Xoshiro(468), f(), Gibbs(:x => PG(20), :y => ESS()), 1000)
mean(chn)
```

```
# My proposed fix:
# Turing @ py/no-set-retained-vns-del (I pushed this, but no PR yet; will wait for your thoughts)
# DynamicPPL @ breaking
julia> mean(chn)
[...]
x 0.0187
y 0.8660

# This PR:
# Turing @ mhauru/dppl-0.37-pmcmc
# DynamicPPL @ mhauru/logprobacc
julia> mean(chn)
[...]
x 0.0121
y 0.9058
```

The results should not be exactly the same because of the behaviour of `set_retained_vns_del!` when `num_produce == 0` that I mentioned above. However, generally they're the same. Thus, I'm reasonably convinced that my proposal is no worse than the current PR.

The bad news is that these numbers are clearly not right. `x` should really be 1/3 and `y` should be 2/3. If you run this on Turing/DPPL main, you do get the right result:

```
# Turing @ main
# DynamicPPL @ main
julia> mean(chn)
[...]
x 0.3128
y 0.6249
```

So it seems that somewhere along the way (but not as part of this PR) we have messed something up.

(many minutes later)

I found out why, and it's because `ProduceLogLikelihoodAccumulator` isn't being added to the varinfo on subsequent steps. Adding that in (TuringLang/Turing.jl#2625 (comment)) makes this behave much more smoothly. In fact, (to my surprise!!) it gives us exactly the same results as on main!

```
julia> mean(chn)
[...]
x 0.3128
y 0.6249
```
What about this call? https://github.com/TuringLang/Turing.jl/blob/23b92ebaadb754f8ca6823bd5344d8d477813bd5/src/mcmc/particle_mcmc.jl#L47
I checked with a more complicated model and it seems to not error 🤷‍♀️

```julia
using Turing, Random

@model function f()
    x ~ Normal()
    y ~ Normal(x)
    1.0 ~ Normal(y)
    x2 ~ Normal(x)
    2.0 ~ Normal(x2)
end

chn = sample(Xoshiro(468), f(), Gibbs((:x, :x2) => PG(20), (:y) => ESS()), 1000)
mean(chn)
```

The results seem reasonable (and broadly similar to what I get from NUTS).

I guess maybe the real question is: will Turing's CI pass? Let me open a PR and we can see.
There are indeed still a couple of places where that line gets hit (and errors). Don't know why the new model wasn't enough to catch it. Will look into it :/ Also a huge bunch of failures from HMC using `getlogjoint` instead of `getlogjoint_internal`, good fun.
I figured it all out in TuringLang/Turing.jl#2629, but I think it's probably better for us to catch up about it some time next week.

In fact, I am pretty sure that the proposed changes there allow us to get rid of VariableOrderAccumulator entirely 👀
Approving with the understanding that we'll eventually remove VariableOrderAccumulator.
Three improvements to accumulators, independent of each other, in three commits:

- Replace `Base.:+` with `acclogp` for LogProbAccumulators. The `+` felt a bit too general, and overloading like in DPPL 0.37 compat for particle MCMC Turing.jl#2625 (comment) felt wrong. We only ever use it for accumulating the logp anyway; there isn't really a need for a legitimate algebra of LogProbAccs.
- `subset` and `Base.merge` for accumulators. By default they just make a copy of the (second) argument, but accs like `VariableOrderAccumulator` can overload them to do something non-trivial.

Still missing `merge` and `subset` for VAIM acc and PointwiseLogDensitiesAcc, will add those.