Remove context from model evaluation (use model.context instead) #952

Merged
penelopeysm merged 10 commits into breaking from py/no-context-eval
Jun 19, 2025

Conversation

penelopeysm
Member

@penelopeysm penelopeysm commented Jun 13, 2025

Summary

This PR modifies the model evaluation function, model.f for some model::DynamicPPL.Model, to not take a context as its third argument. Thus, its signature looks like f(model, varinfo, args...; kwargs...) where args and kwargs are forwarded from the function that defines the model.

Previously, during model evaluation, the tilde-pipeline always dispatched on the __context__ argument and completely ignored __model__.context. It now dispatches on __model__.context instead.

As a result of this, we can remove the context argument from most model evaluation code as well as LogDensityFunction. This simplifies a lot of code.
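The change in signature and dispatch can be illustrated with a toy sketch. Everything here (ToyModel, AccumulateContext, IgnoreContext, toy_tilde, old_f, new_f) is a hypothetical stand-in rather than DynamicPPL code; the only point is that the evaluator stops taking a context argument and reads it from the model instead.

```julia
# Toy stand-ins only; none of these names are DynamicPPL API.
abstract type ToyContext end
struct AccumulateContext <: ToyContext end   # adds each value to a running total
struct IgnoreContext <: ToyContext end       # leaves the running total untouched

struct ToyModel{F,C<:ToyContext}
    f::F          # the evaluation function
    context::C    # the context lives on the model
end

# A "tilde" step whose behaviour is chosen by dispatch on the context type.
toy_tilde(::AccumulateContext, total, x) = total + x
toy_tilde(::IgnoreContext, total, x) = total

# Old shape: the context is a separate third argument; model.context is never consulted.
function old_f(model, total, context, xs...)
    return foldl((t, x) -> toy_tilde(context, t, x), xs; init=total)
end

# New shape: no context argument; the evaluator dispatches on model.context.
function new_f(model, total, xs...)
    return foldl((t, x) -> toy_tilde(model.context, t, x), xs; init=total)
end

model = ToyModel(new_f, AccumulateContext())
result_new = model.f(model, 0, 1, 2, 3)                  # 6

ignoring = ToyModel(new_f, IgnoreContext())
result_ignored = ignoring.f(ignoring, 0, 1, 2, 3)        # 0

# With the old shape, the explicit argument wins and model.context is ignored.
result_old = old_f(model, 0, IgnoreContext(), 1, 2, 3)   # 0
```

In the toy, changing behaviour means constructing a model with a different context, which mirrors the "contextualize the model" advice in this PR.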

There are a handful of minor follow-up points:

Combination of __context__ and __model__.context

Inside the model evaluation function, the model's context was previously ignored. However, a call to evaluate!!(model, varinfo, context) does not ignore the model's context: it is combined with context to form a larger context stack, which then becomes the actual evaluation context.

DynamicPPL.jl/src/model.jl

Lines 942 to 944 in 3e54c2d

    context_new = setleafcontext(
        context, setleafcontext(model.context, leafcontext(context))
    )
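To make the nesting in that snippet concrete, here is a minimal sketch with toy context types. Leaf, Parent, toy_leafcontext, toy_setleafcontext and names_of are hypothetical names written for this illustration; they only imitate the shape of the real functions and are not their implementation.

```julia
# Toy context stack: a chain of parents ending in a leaf. Not DynamicPPL code.
abstract type Ctx end
struct Leaf <: Ctx
    name::Symbol
end
struct Parent{C<:Ctx} <: Ctx
    name::Symbol
    child::C
end

# Walk down to the leaf of a stack.
toy_leafcontext(c::Leaf) = c
toy_leafcontext(c::Parent) = toy_leafcontext(c.child)

# Rebuild a stack with its leaf replaced by `new` (which may itself be a stack).
toy_setleafcontext(::Leaf, new::Ctx) = new
toy_setleafcontext(c::Parent, new::Ctx) = Parent(c.name, toy_setleafcontext(c.child, new))

# Helper for inspection: names from the outermost context to the leaf.
names_of(c::Leaf) = [c.name]
names_of(c::Parent) = pushfirst!(names_of(c.child), c.name)

model_context = Parent(:condition, Leaf(:model_leaf))
context = Parent(:sampling, Leaf(:caller_leaf))

# Same nesting as the snippet above: the caller's parents go outermost,
# the model's parents sit in the middle, and the caller's leaf ends up at the bottom.
combined = toy_setleafcontext(
    context, toy_setleafcontext(model_context, toy_leafcontext(context))
)
combined_names = names_of(combined)   # [:sampling, :condition, :caller_leaf]
```

Note that in this sketch the model's own leaf (:model_leaf) is discarded in favour of the caller's leaf.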

After this PR, the intention is that if you want to keep that combining behaviour, you should do it manually. However, several 'outs' are provided:

  1. sample!!(rng, model, varinfo, sampler) is provided. This wraps model.context in a SamplingContext(rng, sampler) before calling evaluate!!. This takes care of most invocations of evaluate!! where this combining behaviour was being used.

    (As a bonus, this also helps us remove the rng and sampler arguments from evaluate!!, which greatly simplifies that function; see "evaluate!! shenanigans", #720.)

  2. evaluate!!(model, varinfo, context) still exists, but is dep-warned.

  3. _evaluate!!(model, varinfo, context) still exists with the same behaviour without a depwarn. This is only retained because submodels depend on this and I would like to leave this change for a subsequent PR, as that needs to be reasoned about more deeply.
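The first of these outs can be sketched roughly as follows. This is a toy under stated assumptions: ToyModel, ToyDefault, ToySampling, toy_contextualize, toy_evaluate and toy_evaluate_and_sample are hypothetical names, the sampler argument is left out for brevity, and "evaluation" just returns the context it saw so that the wrapping is visible.

```julia
using Random

# Toy stand-ins only; not DynamicPPL types.
abstract type ToyContext end
struct ToyDefault <: ToyContext end
struct ToySampling{R<:AbstractRNG,C<:ToyContext} <: ToyContext
    rng::R
    child::C   # the context that was previously on the model
end

struct ToyModel{C<:ToyContext}
    context::C
end

# Return a copy of the model carrying a different context.
toy_contextualize(::ToyModel, ctx::ToyContext) = ToyModel(ctx)

# Two-argument evaluation: the only context consulted is model.context.
# Here it simply reports that context alongside the untouched state.
toy_evaluate(m::ToyModel, state) = (m.context, state)

# Wrap model.context in a sampling context, then evaluate as usual.
function toy_evaluate_and_sample(rng::AbstractRNG, m::ToyModel, state)
    wrapped = toy_contextualize(m, ToySampling(rng, m.context))
    return toy_evaluate(wrapped, state)
end

model = ToyModel(ToyDefault())
ctx_used, _ = toy_evaluate_and_sample(Random.default_rng(), model, nothing)
# ctx_used is a ToySampling whose child is the model's original ToyDefault context;
# the original model is left unchanged.
```

The caller never passes a context to the evaluation call itself; any combining happens by building a new model first.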

Models as callables

Previously, (model::Model)(args...) would forward to evaluate!!(args...). I think this behaviour is quite dangerous (#629) and even in that PR it was mentioned that we should be more explicit about what args we take, which I fully agree with.

Along with cleaning up evaluate!!, this PR also cleans up models as callables such that the only allowed signatures are model(), model(rng), model(varinfo), and model(rng, varinfo). Thus you are no longer allowed to specify a context (you should contextualize the model instead) or a sampler. The default sampler was SampleFromPrior(), and given that model(...) was always being used to sample from the prior, it seems hugely unlikely that anybody was really passing a different sampler; if they were, they can jolly well call first(DynamicPPL.sample!!(rng, model, varinfo, sampler)).
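As a rough illustration of restricting a callable struct to exactly these four signatures, consider the sketch below. ToyModel and ToyVarInfo are hypothetical, and the "model" just draws one standard normal value; the real methods do considerably more.

```julia
using Random

# Toy stand-ins only; not DynamicPPL types.
struct ToyModel end
struct ToyVarInfo
    draws::Vector{Float64}
end
ToyVarInfo() = ToyVarInfo(Float64[])

# The fully explicit method does the work: draw a value and record it.
function (m::ToyModel)(rng::AbstractRNG, vi::ToyVarInfo)
    x = randn(rng)
    push!(vi.draws, x)
    return x
end

# The three shorter signatures fill in defaults and forward.
(m::ToyModel)() = m(Random.default_rng(), ToyVarInfo())
(m::ToyModel)(rng::AbstractRNG) = m(rng, ToyVarInfo())
(m::ToyModel)(vi::ToyVarInfo) = m(Random.default_rng(), vi)

m = ToyModel()
vi = ToyVarInfo()
m(vi)                                    # records one draw in vi
same_seed_a = m(MersenneTwister(1))
same_seed_b = m(MersenneTwister(1))      # equal to same_seed_a: same seed, same draw

# Anything else, such as a context-like extra argument, has no method.
accepts_string = applicable(m, "some context")   # false
```

Because there is no catch-all (model::Model)(args...) method in this sketch, an unexpected argument raises a MethodError instead of being silently forwarded.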

Remaining test failure

The remaining test failure (Julia-pre) is unrelated to this PR.

Closes #951

This is a step towards fixing #720 but it's not complete; that will have to wait for #960

@penelopeysm penelopeysm changed the title remove context from model evaluation Remove context from model evaluation Jun 13, 2025
@penelopeysm penelopeysm changed the title Remove context from model evaluation Remove context from model evaluation (use model.context instead) Jun 13, 2025
Contributor

github-actions bot commented Jun 13, 2025

Benchmark Report for Commit fa90d1c

Computer Information

Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  8.5 |                 1.6 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                649.7 |                41.2 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                418.8 |                49.4 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |                986.3 |                35.9 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               6572.7 |                26.5 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1467.9 |                27.8 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |                989.9 |                 4.3 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5863.9 |                 3.8 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |                971.3 |                 9.0 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              66043.0 |                 3.5 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8883.8 |                 9.6 |
|               Dynamic |        10 |    mooncake |             typed |   true |                128.4 |                12.3 |
|              Submodel |         1 |    mooncake |             typed |   true |                 12.5 |                 6.2 |
|                   LDA |        12 | reversediff |             typed |   true |               1203.8 |                 2.6 |

Contributor

DynamicPPL.jl documentation for PR #952 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR952/


codecov bot commented Jun 13, 2025

Codecov Report

Attention: Patch coverage is 89.31298% with 14 lines in your changes missing coverage. Please review.

Project coverage is 82.94%. Comparing base (bec523a) to head (fa90d1c).
Report is 1 commit behind head on breaking.

| Files with missing lines | Patch % |        Lines |
|--------------------------|---------|--------------|
| src/simple_varinfo.jl    |  20.00% | 8 Missing ⚠️ |
| src/submodel_macro.jl    |  73.33% | 4 Missing ⚠️ |
| src/compiler.jl          |  66.66% | 1 Missing ⚠️ |
| src/test_utils/ad.jl     |  50.00% | 1 Missing ⚠️ |

Additional details and impacted files
@@             Coverage Diff              @@
##           breaking     #952      +/-   ##
============================================
+ Coverage     82.55%   82.94%   +0.38%     
============================================
  Files            38       38              
  Lines          4075     4068       -7     
============================================
+ Hits           3364     3374      +10     
+ Misses          711      694      -17     

@yebai
Member

yebai commented Jun 18, 2025

That's an excellent direction. When duplicate contexts were first introduced, I was always bothered by them. We merged the relevant PRs to avoid blocking incremental improvements.

@yebai
Member

yebai commented Jun 18, 2025

One minor comment on naming: sample!! will likely confuse people since it is ubiquitously used for sampling (posterior) distributions using MCMC. Perhaps consider an alternative, e.g., evaluate_and_sample!!?

@penelopeysm
Member Author

I'd be happy to rename. Will wait for @mhauru to review before doing it -- though I also realise that about 99% of use cases for sample!! exist because we're trying to construct a new varinfo using SampleFromPrior, and if we go through with #955 I think those would effectively turn into initialise!! rather than sample!! -- although that's probably for another time. (And either way, IMO we should keep it unexported.)

@mhauru
Member

mhauru commented Jun 19, 2025

Comparing to this, this runtime seems to have ~doubled:

|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               6961.0 |                25.6 |

Not the end of the world, but any idea why? LDA has also gotten slower but I care less because it's so bad anyway.

@penelopeysm
Member Author

> this runtime seems to have ~doubled

The time on this branch is the same as on the breaking branch, though: so maybe it came from an earlier merge into there? I'm just in the process of updating all those branches so can take a look in a while

@penelopeysm
Member Author

Also, while you're on leave, maybe this is a nice time for me to try to hack together something to monitor benchmark patterns over time haha

@mhauru
Member

mhauru commented Jun 19, 2025

My bad, I wasn't thinking when picking the thing to compare to. Nothing to do there then.

I'm still mid-review, but wondering what to do about accumulators. I would like to say that I was hasty in merging the first accumulator stuff into breaking, because it's taking a while now to finish that stuff. It's releasable in that tests pass, but I'm not sure it makes sense to release the accumulator stuff in parts. Thus we should move it to a different branch and clear the path for this to make it to a minor release. On the other hand though, this has been developed on top of accumulators, and I wouldn't want anyone to spend time disentangling the two with cherry-pick or rebase.

Member

@mhauru mhauru left a comment

I really like this PR. Just a few tiny typos and a couple of questions about code style.

Comment on lines +815 to +816
# ^ Weird Documenter.jl bug means that we have to write the two above separately
# as it can only detect the `function`-less syntax.
Member

If there's an issue tracking this, could you link it? No worries if not.

Member Author

There are a few issues about callable structs but I couldn't find one that was explicitly about this. I'll try to make an MWE

@penelopeysm
Member Author

Okay, I fixed all the stuff. I renamed it to evaluate_and_sample!! too, but with a big warning in the changelog that this is liable to change (depending on what we do with SamplingContext). I think that should address the technical bits. The remaining question is whether we split this onto a separate branch. I'm inclined to not bother, partly because this isn't ready for release yet anyway (I need to fix it for submodels), and partly because of the hassle of cherry-picking (the simplification of leaf contexts helped a lot here, and I'm scared I might overlook correctness issues if I were to replay this on main).

I think the main argument for releasing it now is that I could use it upstream in some of the Turing sampler + LDF work, but I'm actually happy to hold off on that until we have a resolution to #955, and #955 in turn certainly needs to wait for accs to be done. So I think the happiest course of action is just for me to find other things to do until we are ready to release accs.

@mhauru
Member

mhauru commented Jun 19, 2025

Your argument for why it's okay to build this on top of accumulators is convincing and relieving.

@penelopeysm penelopeysm merged commit 3af63d5 into breaking Jun 19, 2025
19 of 21 checks passed
@penelopeysm penelopeysm deleted the py/no-context-eval branch June 19, 2025 16:45
This was referenced Jul 4, 2025
github-merge-queue bot pushed a commit that referenced this pull request Aug 7, 2025
* Bump minor version to 0.37.0

* Accumulators, stage 1 (#885)

* Release 0.36

* AbstractPPL 0.11 + change prefixing behaviour (#830)

* AbstractPPL 0.11; change prefixing behaviour

* Use DynamicPPL.prefix rather than overloading

* Remove VarInfo(VarInfo, params) (#870)

* Unify `{untyped,typed}_{vector_,}varinfo` constructor functions (#879)

* Unify {Untyped,Typed}{Vector,}VarInfo constructors

* Update invocations

* NTVarInfo

* Fix tests

* More fixes

* Fixes

* Fixes

* Fixes

* Use lowercase functions, don't deprecate VarInfo

* Rewrite VarInfo docstring

* Fix methods

* Fix methods (really)

* Draft of accumulators

* Fix some variable names

* Fix pointwise_logdensities, gut tilde_observe, remove resetlogp!!

* Map rather than broadcast

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Start documenting accumulators

* Use Val{symbols} instead of AccTypes to index

* More documentation for accumulators

* Link varinfo by default in AD testing utilities; make test suite run on linked varinfos (#890)

* Link VarInfo by default

* Tweak interface

* Fix tests

* Fix interface so that callers can inspect results

* Document

* Fix tests

* Fix changelog

* Test linked varinfos

Closes #891

* Fix docstring + use AbstractFloat

* Fix resetlogp!! and type stability for accumulators

* Fix type rigidity of LogProbs and NumProduce

* Fix uses of getlogp and other assorted issues

* setaccs!! nicer interface and logdensity function fixes

* Revert back to calling the macro @addlogprob!

* Remove a dead test

* Clarify a comment

* Implement split/combine for PointwiseLogdensityAccumulator

* Switch ThreadSafeVarInfo.accs_by_thread to be a tuple

* Fix `condition` and `fix` in submodels (#892)

* Fix conditioning in submodels

* Simplify contextual_isassumption

* Add documentation

* Fix some tests

* Add tests; fix a bunch of nested submodel issues

* Fix fix as well

* Fix doctests

* Add unit tests for new functions

* Add changelog entry

* Update changelog

Co-authored-by: Hong Ge <[email protected]>

* Finish docs

* Add a test for conditioning submodel via arguments

* Clean new tests up a bit

* Fix for VarNames with non-identity lenses

* Apply suggestions from code review

Co-authored-by: Markus Hauru <[email protected]>

* Apply suggestions from code review

* Make PrefixContext contain a varname rather than symbol (#896)

---------

Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Markus Hauru <[email protected]>

* Revert ThreadSafeVarInfo back to Vectors and fix some AD type casting in (Simple)VarInfo

* Improve accumulator docs

* Add test/accumulators.jl

* Docs fixes

* Various small fixes

* Make DynamicTransformation not use accumulators other than LogPrior

* Fix variable order and name of map_accumulator!!

* Typo fixing

* Small improvement to ThreadSafeVarInfo

* Fix demo_dot_assume_observe_submodel prefixing

* Typo fixing

* Miscellaneous small fixes

* HISTORY entry and more miscellanea

* Add more tests for accumulators

* Improve accumulators docstrings

* Fix a typo

* Expand HISTORY entry

* Add accumulators to API docs

* Remove unexported functions from API docs

* Add NamedTuple methods for get/set/acclogp

* Fix setlogp!! with single scalar to error

* Export AbstractAccumulator, fix a docs typo

* Apply suggestions from code review

Co-authored-by: Penelope Yong <[email protected]>

* Rename LogPrior -> LogPriorAccumulator, and Likelihood and NumProduce

* Type bound log prob accumulators with T<:Real

* Add @addlogprior! and @addloglikelihood!

* Apply suggestions from code review

Co-authored-by: Penelope Yong <[email protected]>

* Move default accumulators to default_accumulators.jl

* Fix some tests

* Introduce default_accumulators()

* Go back to only having @addlogprob!

* Fix tilde_observe!! prefixing

* Fix default_accumulators internal type

* Make unflatten more type stable, and add a test for it

* Always print all benchmark results

* Move NumProduce VI functions to abstract_varinfo.jl

---------

Co-authored-by: Penelope Yong <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: Hong Ge <[email protected]>

* Replace PriorExtractorContext with PriorDistributionAccumulator (#907)

* Implement values_as_in_model using an accumulator (#908)

* Implement values_as_in_model using an accumulator

* Make make_varname_expression a function

* Refuse to combine ValuesAsInModelAccumulators with different include_colon_eqs

* Fix nested context test

* Bump DynamicPPL versions

* Fix merge (1)

* Add benchmark Pkg source

* [no ci] Don't need to dev again

* Disable use_closure for ReverseDiff

* Revert "Disable use_closure for ReverseDiff"

This reverts commit 3cb47cd.

* Fix LogDensityAt struct

* Try not duplicating

* Update comment pointing to closure benchmarks

* Remove `context` from model evaluation (use `model.context` instead) (#952)

* Change `evaluate!!` API, add `sample!!`

* Fix literally everything else that I broke

* Fix some docstrings

* fix ForwardDiffExt (look, multiple dispatch bad...)

* Changelog

* fix a test

* Fix docstrings

* use `sample!!`

* Fix a couple more cases

* Globally rename `sample!!` -> `evaluate_and_sample!!`, add changelog warning

* Mark function as Const for Enzyme tests (#957)

* Move submodel code to submodel.jl; remove `@submodel` (#959)

* Move submodel code to submodel.jl

* Remove `@submodel`

* Fix missing field tests for 1.12 (#961)

* Remove 3-argument `{_,}evaluate!!`; clean up submodel code (#960)

* Clean up submodel code, remove 3-arg `_evaluate!!`

* Remove 3-argument `evaluate!!` as well

* Update changelog

* Improve submodel error message

* Fix doctest

* Add error hint for three-argument evaluate!!

* Improve API for AD testing (#964)

* Rework API for AD testing

* Fix test

* Add `rng` keyword argument

* Use atol and rtol

* remove unbound type parameter (?)

* Don't need to do elementwise check

* Update changelog

* Fix typo

* DebugAccumulator (plus tiny bits and pieces) (#976)

* DebugContext -> DebugAccumulator

* Changelog

* Force `conditioned` to return a dict

* fix conditioned implementation

* revert `conditioned` bugfix (will merge this to main instead)

* fix show

* Fix doctests

* fix doctests 2

* Make VarInfo actually mandatory in check_model

* Re-implement `missing` check

* Revert `combine` signature in docstring

* Revert changes to `Base.show` on AccumulatorTuple

* Add TODO comment about VariableOrderAccumulator

Co-authored-by: Markus Hauru <[email protected]>

* Fix doctests

---------

Co-authored-by: Markus Hauru <[email protected]>

* VariableOrderAccumulator (#940)

* Turn NumProduceAccumulator into VariableOrderAccumulator

* Add comparison methods

* Make VariableOrderAccumulator use regular Dict

* Use copy rather than deepcopy for accumulators

* Minor docstring touchup

* Remove unnecessary use of NumProduceAccumulator

* Fix split(VariableOrderAccumulator)

* Remove NumProduceAcc from Debug

* Fix set_retained_vns_del!

---------

Co-authored-by: Penelope Yong <[email protected]>

* Accumulators stage 2 (#925)

* Give LogDensityFunction the getlogdensity field

* Allow missing LogPriorAccumulator when linking

* Trim whitespace

* Run formatter

* Fix a few typos

* Fix comma -> semicolon

* Fix `LogDensityAt` invocation

* Fix one last test

* Fix tests

---------

Co-authored-by: Penelope Yong <[email protected]>

* Implement more consistent tracking of logp components via `LogJacobianAccumulator` (#998)

* logjac accumulator

* Fix tests

* Fix a whole bunch of stuff

* Fix final tests

* Fix docs

* Fix docs/doctests

* Fix maths in LogJacobianAccumulator docstring

* Twiddle with a comment

* Add changelog

* Fix accumulator docstring

* logJ -> logjac

* Fix logjac accumulation for StaticTransformation

* Fix behaviour of `set_retained_vns_del!` for `num_produce == 0` (#1000)

* `InitContext`, part 2 - Move `hasvalue` and `getvalue` to AbstractPPL; enforce key type of `AbstractDict` (#980)

* point to unmerged AbstractPPL branch

* Remove code that was moved to AbstractPPL

* Remove Dictionaries with Any key type

* Fix bad merge conflict resolution

* Fix doctests

* Point to [email protected]

This reverts commit 709dc9e.

* Fix doctests

* Fix docs AbstractPPL bound

* Remove stray `Pkg.update()`

* Accumulator miscellanea: Subset, merge, acclogp, and LogProbAccumulator (#999)

* logjac accumulator

* Fix tests

* Fix a whole bunch of stuff

* Fix final tests

* Fix docs

* Fix docs/doctests

* Fix maths in LogJacobianAccumulator docstring

* Twiddle with a comment

* Add changelog

* Simplify accs with LogProbAccumulator

* Replace + with accumulate for LogProbAccs

* Introduce merge and subset for accs

* Improve acc tests

* Fix docstring typo.

Co-authored-by: Penelope Yong <[email protected]>

* Fix merge

---------

Co-authored-by: Penelope Yong <[email protected]>

* Minor tweak to changelog wording

---------

Co-authored-by: Penelope Yong <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: Hong Ge <[email protected]>

3 participants