-
Notifications
You must be signed in to change notification settings - Fork 36
Accumulator miscellanea: Subset, merge, acclogp, and LogProbAccumulator #999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Benchmark Report for Commit f0519dfComputer Information
Benchmark Results
|
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## breaking #999 +/- ##
=========================================
Coverage 81.91% 81.91%
=========================================
Files 38 38
Lines 4041 4025 -16
=========================================
- Hits 3310 3297 -13
+ Misses 731 728 -3 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
I decided against implementing |
DynamicPPL.jl documentation for PR #999 is available at: |
@generated function _joint_keys( | ||
nt1::NamedTuple{names1}, nt2::NamedTuple{names2} | ||
) where {names1,names2} | ||
only_in_nt1 = tuple(setdiff(names1, names2)...) | ||
only_in_nt2 = tuple(setdiff(names2, names1)...) | ||
in_both = tuple(intersect(names1, names2)...) | ||
return :($only_in_nt1, $only_in_nt2, $in_both) | ||
end | ||
|
||
""" | ||
merge(at1::AccumulatorTuple, at2::AccumulatorTuple) | ||
|
||
Merge two `AccumulatorTuple`s. | ||
|
||
For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two | ||
accumulators themselves. Other accumulators are copied. | ||
""" | ||
function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple) | ||
keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt) | ||
accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1) | ||
accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2) | ||
accs_in_both = ( | ||
merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both | ||
) | ||
return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked, and this implementation causes only one allocation of 32 bits, due to a new AccumulatorTuple being created.
@penelopeysm up to you when/where you want this merged. I made your |
If two accumulators of the same type should be merged in some non-trivial way, other than | ||
always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-
What's the difference between
combine
andmerge
? -
I guess this follows on from our conversation on Slack. but I wonder if there a way to restrict the call on subset / merge to only the accumulators we care about? basically is it possible for us to circumvent the need to call
subset
on the entire AccumulatorTuple and thus avoid including slightly weird implementations ofsubset
on the logp accumulators? Mainly inspired by how you avoided implementing these for e.g. PLDAcc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
combine
and merge
probably do the same thing for most accumulators, but I'm not sure they have to for all accumulators. combine
has the restriction that combine(acc, split(acc)) == acc
, whereas merge
could do anything that is desirable on a call to merge
.
I don't know how to specify a subset of accumulators to call merge/subset on in a way that is any simpler than this implementation. Even before this PR, we used to call copy
on accs in merge and subset. Now that has just been shifted to be a method of AbstractAccumulator
that is the fallback for what to do when a subtype doesn't specify anything else.
Co-authored-by: Penelope Yong <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm actually super duper sorry to be annoying about this, but I still can't get over my hangup with merge
. I took it upon myself to at least try to find a constructive way to solve this, and I think I have a good proposal.
(Even if we don't go with this, it was quite a fruitful exercise and along the way I found two new bugs, so that makes me feel quite happy!)
My starting point was: why can't we say that merge(vi1, vi2)
just takes the accumulators from vi2
? Indeed, that's what the current implementation (prior to this PR) does.
The reason why we're defining merge
is solely because of VariableOrderAccumulator, and specifically, because of this line:
It is the call to set_retained_vns_del!
(line 349) that fails.
But the fact that this is called immediately after reset_num_produce!!
(line 342) makes me think that surely, one can just replace the call to set_retained_vns_del!
with a manual
# For all other particles, do not retain the variables but resample them.
- # DynamicPPL.set_retained_vns_del!(vi)
+ for vn in keys(vi)
+ DynamicPPL.set_flag!(vi, vn, "del")
+ end
Indeed, this also fixes the key x not found
bug that is currently causing Turing CI to fail, arguably in a more direct way.
One might think (based on conversations we've had before, etc.) that we should only set the del
flag for variables which have order
strictly larger than 0. Actually it turns out this is not true: if you look at the implementation of set_retained_vns_del!
on main:
Lines 1905 to 1956 in ce7c8b1
""" | |
set_retained_vns_del!(vi::VarInfo) | |
Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`. | |
""" | |
function set_retained_vns_del!(vi::UntypedVarInfo) | |
idcs = _getidcs(vi) | |
if get_num_produce(vi) == 0 | |
for i in length(idcs):-1:1 | |
vi.metadata.flags["del"][idcs[i]] = true | |
end | |
else | |
for i in 1:length(vi.orders) | |
if i in idcs && vi.orders[i] > get_num_produce(vi) | |
vi.metadata.flags["del"][i] = true | |
end | |
end | |
end | |
return nothing | |
end | |
function set_retained_vns_del!(vi::NTVarInfo) | |
idcs = _getidcs(vi) | |
return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) | |
end | |
@generated function _set_retained_vns_del!( | |
metadata, idcs::NamedTuple{names}, num_produce | |
) where {names} | |
expr = Expr(:block) | |
for f in names | |
f_idcs = :(idcs.$f) | |
f_orders = :(metadata.$f.orders) | |
f_flags = :(metadata.$f.flags) | |
push!( | |
expr.args, | |
quote | |
# Set the flag for variables with symbol `f` | |
if num_produce == 0 | |
for i in length($f_idcs):-1:1 | |
$f_flags["del"][$f_idcs[i]] = true | |
end | |
else | |
for i in 1:length($f_orders) | |
if i in $f_idcs && $f_orders[i] > num_produce | |
$f_flags["del"][i] = true | |
end | |
end | |
end | |
end, | |
) | |
end | |
return expr | |
end |
if num_produce == 0
it actually sets the del
flag for every variable (this was something I just found out too). Logically this also makes sense to me because otherwise you'd never resample the first set of variables.
So the docstring on main (and indeed the current implementation with accumulators) is wrong. I've put in #1000 to hotfix this (and hopefully we can eventually get rid of that function).
Results
So, there's good news and bad news.
The good news is that I tried it with my proposed fix instead (without this PR) and sampling this model gives almost the same results as this PR.
using Turing, Random
@model function f()
x ~ Normal()
y ~ Normal(x)
1.0 ~ Normal(y)
end
chn = sample(Xoshiro(468), f(), Gibbs(:x => PG(20), :y => ESS()), 1000)
mean(chn)
# My proposed fix:
# Turing @ py/no-set-retained-vns-del (I pushed this, but no PR yet; will wait for your thoughts)
# DynamicPPL @ breaking
julia> mean(chn)
[...]
x 0.0187
y 0.8660
# This PR:
# Turing @ mhauru/dppl-0.37-pmcmc
# DynamicPPL @ mhauru/logprobacc
julia> mean(chn)
[...]
x 0.0121
y 0.9058
The results should not be exactly the same because of the behaviour of set_retained_vns_del!
when num_produce == 0
that I mentioned above. However, generally they're the same. Thus, I'm reasonably convinced that my proposal is no worse than the current PR.
The bad news is that these numbers are clearly not right. x
should really be 1/3 and y
should be 2/3. If you run this on Turing/DPPL main, you do get the right result:
# Turing @ main
# DynamicPPL @ main
julia> mean(chn)
[...]
x 0.3128
y 0.6249
So it seems that somewhere along the way (but not as part of this PR) we have messed something up.
(many minutes later)
I found out why, and it's because ProduceLogLikelihoodAccumulator
isn't being added to the varinfo on subsequent steps. Adding that in (TuringLang/Turing.jl#2625 (comment)) makes this behave much more smoothly. In fact, (to my surprise!!) it gives us exactly the same results as on main!
julia> mean(chn)
[...]
x 0.3128
y 0.6249
What about this call? https://github.com/TuringLang/Turing.jl/blob/23b92ebaadb754f8ca6823bd5344d8d477813bd5/src/mcmc/particle_mcmc.jl#L47 |
I checked with a more complicated model and it seems to not error 🤷♀️ using Turing, Random
@model function f()
x ~ Normal()
y ~ Normal(x)
1.0 ~ Normal(y)
x2 ~ Normal(x)
2.0 ~ Normal(x2)
end
chn = sample(Xoshiro(468), f(), Gibbs((:x, :x2) => PG(20), (:y) => ESS()), 1000)
mean(chn) The results seem reasonable (and broadly similar to what I get from NUTS):
I guess maybe the real question is will Turing's CI pass? Let me open a PR and we can see. |
There are indeed still a couple of places where that line gets hit (and errors). Don't know why the new model wasn't enough to catch it. Will look into it :/ Also a huge bunch of failures from HMC using getlogjoint instead of getlogjoint_internal, good fun |
I figured it all out in TuringLang/Turing.jl#2629, but I think it's probably better for us to catch up about it some time next week. |
In fact, I am pretty sure that the proposed changes there allow us to get rid of VariableOrderAccumulator entirely 👀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving with the understanding that we'll eventually remove VariableOrderAccumulator.
* Bump minor version to 0.37.0 * Accumulators, stage 1 (#885) * Release 0.36 * AbstractPPL 0.11 + change prefixing behaviour (#830) * AbstractPPL 0.11; change prefixing behaviour * Use DynamicPPL.prefix rather than overloading * Remove VarInfo(VarInfo, params) (#870) * Unify `{untyped,typed}_{vector_,}varinfo` constructor functions (#879) * Unify {Untyped,Typed}{Vector,}VarInfo constructors * Update invocations * NTVarInfo * Fix tests * More fixes * Fixes * Fixes * Fixes * Use lowercase functions, don't deprecate VarInfo * Rewrite VarInfo docstring * Fix methods * Fix methods (really) * Draft of accumulators * Fix some variable names * Fix pointwise_logdensities, gut tilde_observe, remove resetlogp!! * Map rather than broadcast Co-authored-by: Tor Erlend Fjelde <[email protected]> * Start documenting accumulators * Use Val{symbols} instead of AccTypes to index * More documentation for accumulators * Link varinfo by default in AD testing utilities; make test suite run on linked varinfos (#890) * Link VarInfo by default * Tweak interface * Fix tests * Fix interface so that callers can inspect results * Document * Fix tests * Fix changelog * Test linked varinfos Closes #891 * Fix docstring + use AbstractFloat * Fix resetlogp!! and type stability for accumulators * Fix type rigidity of LogProbs and NumProduce * Fix uses of getlogp and other assorted issues * setaccs!! nicer interface and logdensity function fixes * Revert back to calling the macro @addlogprob! * Remove a dead test * Clarify a comment * Implement split/combine for PointwiseLogdensityAccumulator * Switch ThreadSafeVarInfo.accs_by_thread to be a tuple * Fix `condition` and `fix` in submodels (#892) * Fix conditioning in submodels * Simplify contextual_isassumption * Add documentation * Fix some tests * Add tests; fix a bunch of nested submodel issues * Fix fix as well * Fix doctests * Add unit tests for new functions * Add changelog entry * Update changelog Co-authored-by: Hong Ge <[email protected]> * Finish docs * Add a test for conditioning submodel via arguments * Clean new tests up a bit * Fix for VarNames with non-identity lenses * Apply suggestions from code review Co-authored-by: Markus Hauru <[email protected]> * Apply suggestions from code review * Make PrefixContext contain a varname rather than symbol (#896) --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Markus Hauru <[email protected]> * Revert ThreadSafeVarInfo back to Vectors and fix some AD type casting in (Simple)VarInfo * Improve accumulator docs * Add test/accumulators.jl * Docs fixes * Various small fixes * Make DynamicTransformation not use accumulators other than LogPrior * Fix variable order and name of map_accumulator!! * Typo fixing * Small improvement to ThreadSafeVarInfo * Fix demo_dot_assume_observe_submodel prefixing * Typo fixing * Miscellaneous small fixes * HISTORY entry and more miscellanea * Add more tests for accumulators * Improve accumulators docstrings * Fix a typo * Expand HISTORY entry * Add accumulators to API docs * Remove unexported functions from API docs * Add NamedTuple methods for get/set/acclogp * Fix setlogp!! with single scalar to error * Export AbstractAccumulator, fix a docs typo * Apply suggestions from code review Co-authored-by: Penelope Yong <[email protected]> * Rename LogPrior -> LogPriorAccumulator, and Likelihood and NumProduce * Type bound log prob accumulators with T<:Real * Add @addlogprior! and @addloglikelihood! * Apply suggestions from code review Co-authored-by: Penelope Yong <[email protected]> * Move default accumulators to default_accumulators.jl * Fix some tests * Introduce default_accumulators() * Go back to only having @addlogprob! * Fix tilde_observe!! prefixing * Fix default_accumulators internal type * Make unflatten more type stable, and add a test for it * Always print all benchmark results * Move NumProduce VI functions to abstract_varinfo.jl --------- Co-authored-by: Penelope Yong <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: Hong Ge <[email protected]> * Replace PriorExtractorContext with PriorDistributionAccumulator (#907) * Implement values_as_in_model using an accumulator (#908) * Implement values_as_in_model using an accumulator * Make make_varname_expression a function * Refuse to combine ValuesAsInModelAccumulators with different include_colon_eqs * Fix nested context test * Bump DynamicPPL versions * Fix merge (1) * Add benchmark Pkg source * [no ci] Don't need to dev again * Disable use_closure for ReverseDiff * Revert "Disable use_closure for ReverseDiff" This reverts commit 3cb47cd. * Fix LogDensityAt struct * Try not duplicating * Update comment pointing to closure benchmarks * Remove `context` from model evaluation (use `model.context` instead) (#952) * Change `evaluate!!` API, add `sample!!` * Fix literally everything else that I broke * Fix some docstrings * fix ForwardDiffExt (look, multiple dispatch bad...) * Changelog * fix a test * Fix docstrings * use `sample!!` * Fix a couple more cases * Globally rename `sample!!` -> `evaluate_and_sample!!`, add changelog warning * Mark function as Const for Enzyme tests (#957) * Move submodel code to submodel.jl; remove `@submodel` (#959) * Move submodel code to submodel.jl * Remove `@submodel` * Fix missing field tests for 1.12 (#961) * Remove 3-argument `{_,}evaluate!!`; clean up submodel code (#960) * Clean up submodel code, remove 3-arg `_evaluate!!` * Remove 3-argument `evaluate!!` as well * Update changelog * Improve submodel error message * Fix doctest * Add error hint for three-argument evaluate!! * Improve API for AD testing (#964) * Rework API for AD testing * Fix test * Add `rng` keyword argument * Use atol and rtol * remove unbound type parameter (?) * Don't need to do elementwise check * Update changelog * Fix typo * DebugAccumulator (plus tiny bits and pieces) (#976) * DebugContext -> DebugAccumulator * Changelog * Force `conditioned` to return a dict * fix conditioned implementation * revert `conditioned` bugfix (will merge this to main instead) * fix show * Fix doctests * fix doctests 2 * Make VarInfo actually mandatory in check_model * Re-implement `missing` check * Revert `combine` signature in docstring * Revert changes to `Base.show` on AccumulatorTuple * Add TODO comment about VariableOrderAccumulator Co-authored-by: Markus Hauru <[email protected]> * Fix doctests --------- Co-authored-by: Markus Hauru <[email protected]> * VariableOrderAccumulator (#940) * Turn NumProduceAccumulator into VariableOrderAccumulator * Add comparison methods * Make VariableOrderAccumulator use regular Dict * Use copy rather than deepcopy for accumulators * Minor docstring touchup * Remove unnecessary use of NumProduceAccumulator * Fix split(VariableOrderAccumulator) * Remove NumProduceAcc from Debug * Fix set_retained_vns_del! --------- Co-authored-by: Penelope Yong <[email protected]> * Accumulators stage 2 (#925) * Give LogDensityFunction the getlogdensity field * Allow missing LogPriorAccumulator when linking * Trim whitespace * Run formatter * Fix a few typos * Fix comma -> semicolon * Fix `LogDensityAt` invocation * Fix one last test * Fix tests --------- Co-authored-by: Penelope Yong <[email protected]> * Implement more consistent tracking of logp components via `LogJacobianAccumulator` (#998) * logjac accumulator * Fix tests * Fix a whole bunch of stuff * Fix final tests * Fix docs * Fix docs/doctests * Fix maths in LogJacobianAccumulator docstring * Twiddle with a comment * Add changelog * Fix accumulator docstring * logJ -> logjac * Fix logjac accumulation for StaticTransformation * Fix behaviour of `set_retained_vns_del!` for `num_produce == 0` (#1000) * `InitContext`, part 2 - Move `hasvalue` and `getvalue` to AbstractPPL; enforce key type of `AbstractDict` (#980) * point to unmerged AbstractPPL branch * Remove code that was moved to AbstractPPL * Remove Dictionaries with Any key type * Fix bad merge conflict resolution * Fix doctests * Point to [email protected] This reverts commit 709dc9e. * Fix doctests * Fix docs AbstractPPL bound * Remove stray `Pkg.update()` * Accumulator miscellanea: Subset, merge, acclogp, and LogProbAccumulator (#999) * logjac accumulator * Fix tests * Fix a whole bunch of stuff * Fix final tests * Fix docs * Fix docs/doctests * Fix maths in LogJacobianAccumulator docstring * Twiddle with a comment * Add changelog * Simplify accs with LogProbAccumulator * Replace + with accumulate for LogProbAccs * Introduce merge and subset for accs * Improve acc tests * Fix docstring typo. Co-authored-by: Penelope Yong <[email protected]> * Fix merge --------- Co-authored-by: Penelope Yong <[email protected]> * Minor tweak to changelog wording --------- Co-authored-by: Penelope Yong <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: Hong Ge <[email protected]>
Three improvements to accumulators, independent of each other, in three commits:
Base.:+
withacclogp
for LogProbAccumulators. The+
felt a bit too general, and overloading like in DPPL 0.37 compat for particle MCMC Turing.jl#2625 (comment) felt wrong. We only ever use it for accumulating the logp anyway, there isn't really a need for legitimate algebra of LogProbAccs.subset
andBase.merge
for accumulators. By default they just make a copy of the (second) argument, but accs likeVariableOrderAccumulator
can overload them to do something non-trivial.Still missing
merge
andsubset
for VAIM acc and PointwiseLogDensitiesAcc, will add those.