
Accumulator miscellanea: Subset, merge, acclogp, and LogProbAccumulator #999


Merged
merged 18 commits from mhauru/logprobacc into breaking on Jul 28, 2025

Conversation

@mhauru (Member) commented Jul 24, 2025

Three improvements to accumulators, independent of each other, in three commits:

  • Introduce an abstract LogProbAccumulator to reduce code duplication.
  • Replace Base.:+ with acclogp for LogProbAccumulators. The + felt a bit too general, and overloading it like in DPPL 0.37 compat for particle MCMC Turing.jl#2625 (comment) felt wrong. We only ever use it for accumulating the logp anyway; there isn't really a need for a legitimate algebra of LogProbAccs.
  • Introduce subset and Base.merge for accumulators. By default they just make a copy of the (second) argument, but accs like VariableOrderAccumulator can overload them to do something non-trivial. See the sketch below.

Still missing merge and subset for the VAIM acc and PointwiseLogDensitiesAcc; will add those.
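
A minimal sketch of the resulting interface, with simplified, illustrative definitions (not the actual DynamicPPL source; the real fallbacks make a copy, and the concrete details here are assumptions):

abstract type AbstractAccumulator end

# New abstract supertype shared by the log-prob accumulators, to deduplicate code.
abstract type LogProbAccumulator <: AbstractAccumulator end

struct LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator
    logp::T
end

# `acclogp` replaces `Base.:+`: it is only ever used to accumulate the logp.
acclogp(acc::LogLikelihoodAccumulator, logp::Real) =
    LogLikelihoodAccumulator(acc.logp + logp)

# Default fallbacks: keep the (second) argument. In DynamicPPL these copy it;
# returning the argument keeps this sketch self-contained. Accumulators like
# VariableOrderAccumulator can overload these to do something non-trivial.
Base.merge(acc1::T, acc2::T) where {T<:AbstractAccumulator} = acc2
subset(acc::AbstractAccumulator, vns) = acc

acclogp(LogLikelihoodAccumulator(0.0), -1.2)  # LogLikelihoodAccumulator(-1.2)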

github-actions bot (Contributor) commented Jul 24, 2025

Benchmark Report for Commit f0519df

Computer Information

Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                 12.7 |                 1.4 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                730.7 |                37.1 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                520.9 |                44.1 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1179.2 |                29.5 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               7295.9 |                27.1 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1098.8 |                36.8 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |               1070.2 |                12.2 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               6799.9 |                51.9 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |                990.4 |                 8.7 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              81128.3 |                47.0 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8414.0 |                 9.8 |
|               Dynamic |        10 |    mooncake |             typed |   true |                140.7 |                31.3 |
|              Submodel |         1 |    mooncake |             typed |   true |                 17.0 |                25.9 |
|                   LDA |        12 | reversediff |             typed |   true |               1035.5 |                 2.2 |

codecov bot commented Jul 24, 2025

Codecov Report

❌ Patch coverage is 93.93939% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.91%. Comparing base (c6c0cbc) to head (f0519df).
⚠️ Report is 1 commit behind head on breaking.

| Files with missing lines    | Patch % | Lines        |
|-----------------------------|---------|--------------|
| src/default_accumulators.jl | 90.47%  | 4 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff            @@
##           breaking     #999   +/-   ##
=========================================
  Coverage     81.91%   81.91%           
=========================================
  Files            38       38           
  Lines          4041     4025   -16     
=========================================
- Hits           3310     3297   -13     
+ Misses          731      728    -3     


@mhauru (Member, Author) commented Jul 24, 2025

I decided against implementing merge and subset for PointwiseLogDensityAccumulator and VAIMAccumulator (and also PriorDistributionAccumulator), since they are strictly confined to the functions that use them, which never call subset/merge, and are not exported. Instead, I added a bunch of tests that I had forgotten to add earlier.


DynamicPPL.jl documentation for PR #999 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR999/

Comment on lines +203 to +228

@generated function _joint_keys(
    nt1::NamedTuple{names1}, nt2::NamedTuple{names2}
) where {names1,names2}
    # Partition the names at compile time: those only in nt1, those only in
    # nt2, and those present in both.
    only_in_nt1 = tuple(setdiff(names1, names2)...)
    only_in_nt2 = tuple(setdiff(names2, names1)...)
    in_both = tuple(intersect(names1, names2)...)
    return :($only_in_nt1, $only_in_nt2, $in_both)
end

"""
    merge(at1::AccumulatorTuple, at2::AccumulatorTuple)

Merge two `AccumulatorTuple`s.

For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two
accumulators themselves. Other accumulators are copied.
"""
function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
    keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt)
    accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1)
    accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2)
    accs_in_both = (
        merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both
    )
    return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...)
end
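
For illustration, a hypothetical call (accumulator names chosen arbitrarily) showing the partition that _joint_keys computes at compile time:

nt1 = (LogPrior = 0.0, LogLikelihood = 0.0)
nt2 = (LogLikelihood = 0.0, VariableOrder = 0)
_joint_keys(nt1, nt2)  # ((:LogPrior,), (:VariableOrder,), (:LogLikelihood,))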
@mhauru (Member, Author):
I checked, and this implementation causes only one allocation, of 32 bytes, due to a new AccumulatorTuple being created.

@mhauru marked this pull request as ready for review July 24, 2025 12:36
@mhauru requested a review from penelopeysm July 24, 2025 12:36
@mhauru (Member, Author) commented Jul 24, 2025

@penelopeysm up to you when/where you want this merged. I made your py/logjac the base branch just to have the diff include only the actually novel bits.

Comment on lines +33 to +34
If two accumulators of the same type should be merged in some non-trivial way, other than
always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined.
@penelopeysm (Member) commented Jul 24, 2025

  1. What's the difference between combine and merge?

  2. I guess this follows on from our conversation on Slack, but I wonder if there is a way to restrict the call on subset / merge to only the accumulators we care about? Basically, is it possible for us to circumvent the need to call subset on the entire AccumulatorTuple, and thus avoid including slightly weird implementations of subset on the logp accumulators? Mainly inspired by how you avoided implementing these for e.g. PLDAcc.

@mhauru (Member, Author):

combine and merge probably do the same thing for most accumulators, but I'm not sure they have to for all accumulators. combine has the restriction that combine(acc, split(acc)) == acc, whereas merge could do anything that is desirable on a call to merge.

I don't know how to specify a subset of accumulators to call merge/subset on in a way that is any simpler than this implementation. Even before this PR, we used to call copy on accs in merge and subset. Now that has just been shifted to be a method of AbstractAccumulator that is the fallback for what to do when a subtype doesn't specify anything else.
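
To make the distinction concrete, here is a hedged sketch with a toy counter accumulator (hypothetical names, not DynamicPPL's actual types): split hands a fresh accumulator to a sub-computation and combine folds the result back in, subject to the contract, whereas merge is free to do whatever is desirable on a merge.

struct CounterAcc  # toy accumulator, for illustration only
    n::Int
end

split(acc::CounterAcc) = CounterAcc(0)                         # fresh sub-accumulator
combine(acc::CounterAcc, acc2::CounterAcc) = CounterAcc(acc.n + acc2.n)
Base.merge(acc1::CounterAcc, acc2::CounterAcc) = acc2          # default-style merge: keep the second

acc = CounterAcc(3)
combine(acc, split(acc)) == acc  # true: the combine/split contract holds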

Base automatically changed from py/logjac to breaking July 24, 2025 14:05
@mhauru requested a review from penelopeysm July 24, 2025 14:26
@penelopeysm (Member) left a comment

I'm actually super duper sorry to be annoying about this, but I still can't get over my hangup with merge. I took it upon myself to at least try to find a constructive way to solve this, and I think I have a good proposal.

(Even if we don't go with this, it was quite a fruitful exercise and along the way I found two new bugs, so that makes me feel quite happy!)

My starting point was: why can't we say that merge(vi1, vi2) just takes the accumulators from vi2? Indeed, that's what the current implementation (prior to this PR) does.

The reason why we're defining merge is solely because of VariableOrderAccumulator, and specifically, because of this line:

https://github.com/TuringLang/Turing.jl/blob/bdcd72f1e4f60c5542578f8bf4fc57e8307eef60/src/mcmc/particle_mcmc.jl#L342-L349

It is the call to set_retained_vns_del! (line 349) that fails.

But the fact that this is called immediately after reset_num_produce!! (line 342) makes me think that surely one can just replace the call to set_retained_vns_del! with a manual loop:

  # For all other particles, do not retain the variables but resample them.
- # DynamicPPL.set_retained_vns_del!(vi)
+ for vn in keys(vi)
+     DynamicPPL.set_flag!(vi, vn, "del")
+ end

Indeed, this also fixes the "key x not found" bug that is currently causing Turing CI to fail, arguably in a more direct way.

One might think (based on conversations we've had before, etc.) that we should only set the del flag for variables which have order strictly larger than 0. Actually it turns out this is not true: if you look at the implementation of set_retained_vns_del! on main:

DynamicPPL.jl/src/varinfo.jl, lines 1905 to 1956, at ce7c8b1:

"""
set_retained_vns_del!(vi::VarInfo)
Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`.
"""
function set_retained_vns_del!(vi::UntypedVarInfo)
idcs = _getidcs(vi)
if get_num_produce(vi) == 0
for i in length(idcs):-1:1
vi.metadata.flags["del"][idcs[i]] = true
end
else
for i in 1:length(vi.orders)
if i in idcs && vi.orders[i] > get_num_produce(vi)
vi.metadata.flags["del"][i] = true
end
end
end
return nothing
end
function set_retained_vns_del!(vi::NTVarInfo)
idcs = _getidcs(vi)
return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi))
end
@generated function _set_retained_vns_del!(
metadata, idcs::NamedTuple{names}, num_produce
) where {names}
expr = Expr(:block)
for f in names
f_idcs = :(idcs.$f)
f_orders = :(metadata.$f.orders)
f_flags = :(metadata.$f.flags)
push!(
expr.args,
quote
# Set the flag for variables with symbol `f`
if num_produce == 0
for i in length($f_idcs):-1:1
$f_flags["del"][$f_idcs[i]] = true
end
else
for i in 1:length($f_orders)
if i in $f_idcs && $f_orders[i] > num_produce
$f_flags["del"][i] = true
end
end
end
end,
)
end
return expr
end

If num_produce == 0, it actually sets the del flag for every variable (this was something I just found out too). Logically this also makes sense to me, because otherwise you'd never resample the first set of variables.

So the docstring on main (and indeed the current implementation with accumulators) is wrong. I've put in #1000 to hotfix this (and hopefully we can eventually get rid of that function).
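
For reference, a sketch of what the corrected docstring might say (my wording, not necessarily what #1000 does):

"""
    set_retained_vns_del!(vi::VarInfo)

Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`.
If `vi.num_produce[] == 0`, set the flag for all variables instead.
"""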

Results

So, there's good news and bad news.

The good news is that I tried it with my proposed fix instead (without this PR) and sampling this model gives almost the same results as this PR.

using Turing, Random
@model function f()
    x ~ Normal()
    y ~ Normal(x)
    1.0 ~ Normal(y)
end
chn = sample(Xoshiro(468), f(), Gibbs(:x => PG(20), :y => ESS()), 1000)
mean(chn)
# My proposed fix:
# Turing     @ py/no-set-retained-vns-del (I pushed this, but no PR yet; will wait for your thoughts)
# DynamicPPL @ breaking

julia> mean(chn)
[...]
           x    0.0187
           y    0.8660
# This PR:
# Turing     @ mhauru/dppl-0.37-pmcmc
# DynamicPPL @ mhauru/logprobacc

julia> mean(chn)
[...]
           x    0.0121
           y    0.9058

The results should not be exactly the same, because of the behaviour of set_retained_vns_del! when num_produce == 0 that I mentioned above; but they are broadly similar. Thus, I'm reasonably convinced that my proposal is no worse than the current PR.

The bad news is that these numbers are clearly not right. x should really be 1/3 and y should be 2/3. If you run this on Turing/DPPL main, you do get the right result:

# Turing     @ main
# DynamicPPL @ main

julia> mean(chn)
[...]
           x    0.3128
           y    0.6249

So it seems that somewhere along the way (but not as part of this PR) we have messed something up.

(many minutes later)

I found out why, and it's because ProduceLogLikelihoodAccumulator isn't being added to the varinfo on subsequent steps. Adding that in (TuringLang/Turing.jl#2625 (comment)) makes this behave much more smoothly. In fact, (to my surprise!!) it gives us exactly the same results as on main!

julia> mean(chn)
[...]
           x    0.3128
           y    0.6249

@mhauru (Member, Author) commented Jul 25, 2025

@penelopeysm (Member) commented Jul 25, 2025

I checked with a more complicated model and it seems to not error 🤷‍♀️

using Turing, Random
@model function f()
    x ~ Normal()
    y ~ Normal(x)
    1.0 ~ Normal(y)
    x2 ~ Normal(x)
    2.0 ~ Normal(x2)
end

chn = sample(Xoshiro(468), f(), Gibbs((:x, :x2) => PG(20), (:y) => ESS()), 1000)
mean(chn)

The results seem reasonable (and broadly similar to what I get from NUTS):

julia> mean(chn)
[...]
           x    0.7037
           y    0.8105
          x2    1.3517

I guess maybe the real question is: will Turing's CI pass? Let me open a PR and we can see.

@penelopeysm (Member) commented:

> I guess maybe the real question is: will Turing's CI pass? Let me open a PR and we can see.

There are indeed still a couple of places where that line gets hit (and errors). Don't know why the new model wasn't enough to catch it. Will look into it :/

Also, a huge bunch of failures from HMC using getlogjoint instead of getlogjoint_internal. Good fun.

@penelopeysm (Member) commented:

> There are indeed still a couple of places where that line gets hit (and errors). Don't know why the new model wasn't enough to catch it.

I figured it all out in TuringLang/Turing.jl#2629, but I think it's probably better for us to catch up about it some time next week.

@penelopeysm (Member) commented:

In fact, I am pretty sure that the proposed changes there allow us to get rid of VariableOrderAccumulator entirely 👀

@penelopeysm mentioned this pull request Jul 26, 2025
@penelopeysm (Member) left a comment

Approving with the understanding that we'll eventually remove VariableOrderAccumulator.

@mhauru merged commit 5d9e934 into breaking Jul 28, 2025
20 of 21 checks passed
@mhauru deleted the mhauru/logprobacc branch July 28, 2025 14:34