Commit bd99d4f

Implement more consistent tracking of logp components via LogJacobianAccumulator (#998)
* logjac accumulator
* Fix tests
* Fix a whole bunch of stuff
* Fix final tests
* Fix docs
* Fix docs/doctests
* Fix maths in LogJacobianAccumulator docstring
* Twiddle with a comment
* Add changelog
* Fix accumulator docstring
* logJ -> logjac
* Fix logjac accumulation for StaticTransformation
1 parent e60eab0 commit bd99d4f

24 files changed: +458 -191 lines changed

HISTORY.md

Lines changed: 26 additions & 6 deletions
@@ -32,20 +32,40 @@ Their semantics are the same as in Julia's `isapprox`; two values are equal if t
 You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`.
 Previously, these functions would generate a new VarInfo for you (using an optionally provided `rng`).
 
-### Removal of `PriorContext` and `LikelihoodContext`
-
-A number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`.
-Although these are not the only _exported_ contexts, we consider it unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below.
+### Evaluating model log-probabilities in more detail
 
 Previously, during evaluation of a model, DynamicPPL only had the capability to store a _single_ log probability (`logp`) field.
 `DefaultContext`, `PriorContext`, and `LikelihoodContext` were used to control what this field represented: they would accumulate the log joint, log prior, or log likelihood, respectively.
 
-Now, we have reworked DynamicPPL's `VarInfo` object such that it can track multiple log probabilities at once (see the 'Accumulators' section below).
+In this version, we have overhauled this quite substantially.
+The technical details of exactly _how_ this is done are covered in the 'Accumulators' section below, but the upshot is that the log prior, log likelihood, and log Jacobian terms (for any linked variables) are separately tracked.
+
+Specifically, you will want to use the following functions to access these log probabilities:
+
+- `getlogprior(varinfo)` to get the log prior. **Note:** This version introduces new, more consistent behaviour for this function, in that it always returns the log-prior of the values in the original, untransformed space, even if the `varinfo` has been linked.
+- `getloglikelihood(varinfo)` to get the log likelihood.
+- `getlogjoint(varinfo)` to get the log joint probability. **Note:** Similar to `getlogprior`, this function now always returns the log joint of the values in the original, untransformed space, even if the `varinfo` has been linked.
+
+If you are using linked VarInfos (e.g. if you are writing a sampler), you may find that you need to obtain the log probability of the variables in the transformed space.
+To this end, you can use:
+
+- `getlogjac(varinfo)` to get the log Jacobian of the link transforms for any linked variables.
+- `getlogprior_internal(varinfo)` to get the log prior of the variables in the transformed space.
+- `getlogjoint_internal(varinfo)` to get the log joint probability of the variables in the transformed space.
+
+Since transformations only apply to random variables, the likelihood is unaffected by linking.
+
+### Removal of `PriorContext` and `LikelihoodContext`
+
+Following on from the above, a number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`.
+Although these are not the only _exported_ contexts, we consider it unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below.
+
 If you were evaluating a model with `PriorContext`, you can now just evaluate it with `DefaultContext`, and instead of calling `getlogp(varinfo)`, you can call `getlogprior(varinfo)` (and similarly for the likelihood).
 
 If you were constructing a `LogDensityFunction` with `PriorContext`, you can now stick to `DefaultContext`.
 `LogDensityFunction` now has an extra field, called `getlogdensity`, which represents a function that takes a `VarInfo` and returns the log density you want.
-Thus, if you pass `getlogprior` as the value of this parameter, you will get the same behaviour as with `PriorContext`.
+Thus, if you pass `getlogprior_internal` as the value of this parameter, you will get the same behaviour as with `PriorContext`.
+(You should consider whether your use case needs the log prior in the transformed space, or the original space, and use (respectively) `getlogprior_internal` or `getlogprior` as needed.)
 
 The other case where one might use `PriorContext` was to use `@addlogprob!` to add to the log prior.
 Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`.
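
As a sanity check on the original-space versus transformed-space quantities described in the changelog entry above, here is a minimal, self-contained sketch that does not use DynamicPPL at all. It assumes a single variable `x ~ Exponential(1)` linked through `y = log(x)`. The names `logprior`, `logjac`, `logprior_internal`, and `logdensity_y` are illustrative stand-ins for the quantities that `getlogprior`, `getlogjac`, and `getlogprior_internal` report; they are not library calls.

```julia
# Toy setting (an assumption for illustration): x ~ Exponential(1), link y = log(x).
# None of these functions are DynamicPPL functions.

logprior(x) = -x        # log density of Exponential(1) at x > 0
logjac(x) = -log(x)     # log|d(log x)/dx| = log(1/x), i.e. the forward (link) transform

# Mirrors the documented identity: logprior_internal == logprior - logjac.
logprior_internal(x) = logprior(x) - logjac(x)

# Independent derivation by change of variables:
# p_Y(y) = p_X(exp(y)) * exp(y), so log p_Y(y) = -exp(y) + y.
logdensity_y(y) = -exp(y) + y

x = 2.5
y = log(x)
@assert isapprox(logprior_internal(x), logdensity_y(y))
```

The two routes agree, which is the sense in which the `_internal` functions return the density of the values as stored in a linked `VarInfo`.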

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 3 additions & 1 deletion
@@ -86,7 +86,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
         vi = DynamicPPL.link(vi, model)
     end
 
-    f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend)
+    f = DynamicPPL.LogDensityFunction(
+        model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend
+    )
     # The parameters at which we evaluate f.
     θ = vi[:]
 

docs/make.jl

Lines changed: 3 additions & 1 deletion
@@ -21,7 +21,9 @@ makedocs(;
     sitename="DynamicPPL",
     # The API index.html page is fairly large, and violates the default HTML page size
     # threshold of 200KiB, so we double that.
-    format=Documenter.HTML(; size_threshold=2^10 * 400),
+    format=Documenter.HTML(;
+        size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3()
+    ),
     modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)],
     pages=[
         "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"]

docs/src/api.md

Lines changed: 6 additions & 0 deletions
@@ -367,6 +367,7 @@ DynamicPPL provides the following default accumulators.
 
 ```@docs
 LogPriorAccumulator
+LogJacobianAccumulator
 LogLikelihoodAccumulator
 VariableOrderAccumulator
 ```
@@ -380,7 +381,12 @@ getlogp
 setlogp!!
 acclogp!!
 getlogjoint
+getlogjoint_internal
+getlogjac
+setlogjac!!
+acclogjac!!
 getlogprior
+getlogprior_internal
 setlogprior!!
 acclogprior!!
 getloglikelihood

src/DynamicPPL.jl

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,7 @@ export AbstractVarInfo,
     AbstractAccumulator,
     LogLikelihoodAccumulator,
     LogPriorAccumulator,
+    LogJacobianAccumulator,
     VariableOrderAccumulator,
     push!!,
     empty!!,
@@ -58,10 +59,15 @@ export AbstractVarInfo,
     getlogjoint,
     getlogprior,
     getloglikelihood,
+    getlogjac,
+    getlogjoint_internal,
+    getlogprior_internal,
     setlogp!!,
     setlogprior!!,
+    setlogjac!!,
     setloglikelihood!!,
     acclogp!!,
+    acclogjac!!,
     acclogprior!!,
     accloglikelihood!!,
     resetlogp!!,

src/abstract_varinfo.jl

Lines changed: 94 additions & 15 deletions
@@ -99,16 +99,34 @@ See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref).
 """
 getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi)
 
+"""
+    getlogjoint_internal(vi::AbstractVarInfo)
+
+Return the log of the joint probability of the observed data and parameters as
+they are stored internally in `vi`, including the log-Jacobian for any linked
+parameters.
+
+In general, we have that:
+
+```julia
+getlogjoint_internal(vi) == getlogjoint(vi) - getlogjac(vi)
+```
+"""
+getlogjoint_internal(vi::AbstractVarInfo) =
+    getlogprior(vi) + getloglikelihood(vi) - getlogjac(vi)
+
 """
     getlogp(vi::AbstractVarInfo)
 
-Return a NamedTuple of the log prior and log likelihood probabilities.
+Return a NamedTuple of the log prior, log Jacobian, and log likelihood probabilities.
 
-The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an
-error will be thrown.
+The keys are called `logprior`, `logjac`, and `loglikelihood`. If any of them
+are not present in `vi` an error will be thrown.
 """
 function getlogp(vi::AbstractVarInfo)
-    return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi))
+    return (;
+        logprior=getlogprior(vi), logjac=getlogjac(vi), loglikelihood=getloglikelihood(vi)
+    )
 end
 
 """
@@ -164,6 +182,30 @@ See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@
 """
 getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp
 
+"""
+    getlogprior_internal(vi::AbstractVarInfo)
+
+Return the log of the prior probability of the parameters as stored internally
+in `vi`. This includes the log-Jacobian for any linked parameters.
+
+In general, we have that:
+
+```julia
+getlogprior_internal(vi) == getlogprior(vi) - getlogjac(vi)
+```
+"""
+getlogprior_internal(vi::AbstractVarInfo) = getlogprior(vi) - getlogjac(vi)
+
+"""
+    getlogjac(vi::AbstractVarInfo)
+
+Return the accumulated log-Jacobian term for any linked parameters in `vi`. The
+Jacobian here is taken with respect to the forward (link) transform.
+
+See also: [`setlogjac!!`](@ref).
+"""
+getlogjac(vi::AbstractVarInfo) = getacc(vi, Val(:LogJacobian)).logjac
+
 """
     getloglikelihood(vi::AbstractVarInfo)
 
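
The minus sign in the `_internal` identities follows from the change-of-variables formula. A short derivation, assuming a link transform $y = f(x)$ with Jacobian $J_f$ and writing $D$ for the observed data (these symbols are introduced here for illustration and do not appear in the source):

```latex
\begin{aligned}
p_Y(y) &= p_X(x)\,\left|\det J_{f^{-1}}(y)\right|
        = \frac{p_X(x)}{\left|\det J_f(x)\right|}, \\
\log p_Y(y) &= \underbrace{\log p_X(x)}_{\texttt{getlogprior}}
        - \underbrace{\log\left|\det J_f(x)\right|}_{\texttt{getlogjac}}, \\
\log p_Y(y) + \log p(D \mid x) &= \underbrace{\log p_X(x) + \log p(D \mid x)}_{\texttt{getlogjoint}}
        - \underbrace{\log\left|\det J_f(x)\right|}_{\texttt{getlogjac}}.
\end{aligned}
```

The second line corresponds to `getlogprior_internal` and the third to `getlogjoint_internal`, with the Jacobian taken with respect to the forward (link) transform as stated in the `getlogjac` docstring.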
@@ -196,6 +238,16 @@ See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@re
 """
 setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp))
 
+"""
+    setlogjac!!(vi::AbstractVarInfo, logjac)
+
+Set the accumulated log-Jacobian term for any linked parameters in `vi`. The
+Jacobian here is taken with respect to the forward (link) transform.
+
+See also: [`getlogjac`](@ref), [`acclogjac!!`](@ref).
+"""
+setlogjac!!(vi::AbstractVarInfo, logjac) = setacc!!(vi, LogJacobianAccumulator(logjac))
+
 """
     setloglikelihood!!(vi::AbstractVarInfo, logp)
 
@@ -215,18 +267,21 @@ Set both the log prior and the log likelihood probabilities in `vi`.
 See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref).
 """
 function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names}
-    if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior))
-        error("logp must have the fields logprior and loglikelihood and no other fields.")
+    if Set(names) != Set([:logprior, :logjac, :loglikelihood])
+        error(
+            "The second argument to `setlogp!!` must be a NamedTuple with the fields logprior, logjac, and loglikelihood.",
+        )
     end
     vi = setlogprior!!(vi, logp.logprior)
+    vi = setlogjac!!(vi, logp.logjac)
     vi = setloglikelihood!!(vi, logp.loglikelihood)
     return vi
 end
 
 function setlogp!!(vi::AbstractVarInfo, logp::Number)
     return error("""
         `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use
-        `setloglikelihood!!` and/or `setlogprior!!` instead.
+        `setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead.
         """)
 end
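
The new field check in `setlogp!!` compares sets rather than tuples, so any ordering of the three keys is accepted and any other combination is rejected. A minimal standalone sketch of that check, assuming a hypothetical helper name `has_required_fields` (it is not part of DynamicPPL):

```julia
# Hypothetical helper that mirrors the field check in `setlogp!!`.
const REQUIRED_FIELDS = Set([:logprior, :logjac, :loglikelihood])

function has_required_fields(logp::NamedTuple{names}) where {names}
    # `names` is a tuple of Symbols; converting it to a Set ignores ordering.
    return Set(names) == REQUIRED_FIELDS
end

# Any ordering of the three fields passes.
@assert has_required_fields((; logprior=0.0, logjac=0.0, loglikelihood=0.0))
@assert has_required_fields((; loglikelihood=0.0, logprior=0.0, logjac=0.0))

# The pre-change two-field NamedTuple no longer passes.
@assert !has_required_fields((; logprior=0.0, loglikelihood=0.0))
```

Compared with enumerating every permutation, as the removed two-field check did, the set comparison stays the same length as more fields are added.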

@@ -306,6 +361,19 @@ function acclogprior!!(vi::AbstractVarInfo, logp)
     return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior))
 end
 
+"""
+    acclogjac!!(vi::AbstractVarInfo, logjac)
+
+Add `logjac` to the value of the log Jacobian in `vi`.
+
+See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref).
+"""
+function acclogjac!!(vi::AbstractVarInfo, logjac)
+    return map_accumulator!!(
+        acc -> acc + LogJacobianAccumulator(logjac), vi, Val(:LogJacobian)
+    )
+end
+
 """
     accloglikelihood!!(vi::AbstractVarInfo, logp)
 
@@ -368,6 +436,9 @@ function resetlogp!!(vi::AbstractVarInfo)
     if hasacc(vi, Val(:LogPrior))
         vi = map_accumulator!!(zero, vi, Val(:LogPrior))
     end
+    if hasacc(vi, Val(:LogJacobian))
+        vi = map_accumulator!!(zero, vi, Val(:LogJacobian))
+    end
     if hasacc(vi, Val(:LogLikelihood))
         vi = map_accumulator!!(zero, vi, Val(:LogLikelihood))
     end
@@ -836,21 +907,29 @@ function link!!(
     x = vi[:]
     y, logjac = with_logabsdet_jacobian(b, x)
 
-    lp_new = getlogprior(vi) - logjac
-    vi_new = setlogprior!!(unflatten(vi, y), lp_new)
-    return settrans!!(vi_new, t)
+    # Set parameters and add the logjac term.
+    vi = unflatten(vi, y)
+    if hasacc(vi, Val(:LogJacobian))
+        vi = acclogjac!!(vi, logjac)
+    end
+    return settrans!!(vi, t)
 end
 
 function invlink!!(
     t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
 )
     b = t.bijector
     y = vi[:]
-    x, logjac = with_logabsdet_jacobian(b, y)
-
-    lp_new = getlogprior(vi) + logjac
-    vi_new = setlogprior!!(unflatten(vi, x), lp_new)
-    return settrans!!(vi_new, NoTransformation())
+    x, inv_logjac = with_logabsdet_jacobian(b, y)
+
+    # Mildly confusing: we need to _add_ the logjac of the inverse transform,
+    # because we are trying to remove the logjac of the forward transform
+    # that was previously accumulated when linking.
+    vi = unflatten(vi, x)
+    if hasacc(vi, Val(:LogJacobian))
+        vi = acclogjac!!(vi, inv_logjac)
+    end
+    return settrans!!(vi, NoTransformation())
 end
 
 """

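To see why adding the inverse transform's log-Jacobian in `invlink!!` undoes what `link!!` accumulated, here is a small self-contained sketch. It assumes a scalar log transform and uses hypothetical helpers `forward_with_logjac` and `inverse_with_logjac` in place of `with_logabsdet_jacobian` from Bijectors.jl; a plain `Float64` stands in for the `LogJacobian` accumulator.

```julia
# Toy stand-ins for a bijector and its inverse (assumptions, not Bijectors.jl functions).
# Each returns (transformed value, log|det Jacobian| of that transform at its input).
forward_with_logjac(x) = (log(x), -log(x))  # y = log(x), dy/dx = 1/x
inverse_with_logjac(y) = (exp(y), y)        # x = exp(y), dx/dy = exp(y)

function roundtrip_logjac(x0)
    acc = 0.0                               # plays the role of the LogJacobian accumulator

    y, logjac = forward_with_logjac(x0)
    acc += logjac                           # what linking accumulates

    x1, inv_logjac = inverse_with_logjac(y)
    acc += inv_logjac                       # what invlinking adds back

    return x1, acc
end

x1, acc = roundtrip_logjac(3.0)
@assert isapprox(x1, 3.0)
@assert isapprox(acc, 0.0; atol=1e-12)
```

After the round trip the accumulated value is back at zero, which matches the comment in `invlink!!` about removing the forward log-Jacobian.
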
src/accumulators.jl

Lines changed: 13 additions & 2 deletions
@@ -11,10 +11,21 @@ seen so far.
 
 An accumulator type `T <: AbstractAccumulator` must implement the following methods:
 - `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
-- `accumulate_observe!!(acc::T, right, left, vn)`
-- `accumulate_assume!!(acc::T, val, logjac, vn, right)`
+- `accumulate_observe!!(acc::T, dist, val, vn)`
+- `accumulate_assume!!(acc::T, val, logjac, vn, dist)`
 - `Base.copy(acc::T)`
 
+In these functions:
+- `val` is the new value of the random variable sampled from a distribution (always in
+  the original unlinked space), or the value on the left-hand side of an observe
+  statement.
+- `dist` is the distribution on the RHS of the tilde statement.
+- `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the
+  tilde-statement is a literal observation like `0.0 ~ Normal()`, then `vn` is `nothing`.
+- `logjac` is the log determinant of the Jacobian of the link transformation, _if_ the
+  variable is stored as a linked value in the VarInfo. If the variable is stored in its
+  original, unlinked form, then `logjac` is zero.
+
 To be able to work with multi-threading, it should also implement:
 - `split(acc::T)`
 - `combine(acc::T, acc2::T)`
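
The argument order described in the docstring above can be illustrated with a standalone toy. This is only a sketch of the shape of the interface: `ToyLogJacAccumulator` and the `toy_*` functions are hypothetical names defined here, they do not subtype `AbstractAccumulator`, and they are not how DynamicPPL's `LogJacobianAccumulator` is necessarily implemented.

```julia
# Toy accumulator loosely modelled on the documented method signatures.
# All `Toy`/`toy_` names are hypothetical and independent of DynamicPPL.
struct ToyLogJacAccumulator
    logjac::Float64
end

toy_accumulator_name(::Type{ToyLogJacAccumulator}) = :ToyLogJac

# assume: add the log-Jacobian reported for this variable (zero if it is unlinked).
toy_accumulate_assume(acc::ToyLogJacAccumulator, val, logjac, vn, dist) =
    ToyLogJacAccumulator(acc.logjac + logjac)

# observe: observations are not transformed, so there is nothing to add.
toy_accumulate_observe(acc::ToyLogJacAccumulator, dist, val, vn) = acc

# The struct is immutable, so returning it unchanged is a valid copy.
Base.copy(acc::ToyLogJacAccumulator) = acc

acc = ToyLogJacAccumulator(0.0)
acc = toy_accumulate_assume(acc, 2.0, -log(2.0), :x, nothing)   # linked variable
acc = toy_accumulate_assume(acc, 0.3, 0.0, :z, nothing)         # unlinked variable
acc = toy_accumulate_observe(acc, nothing, 1.5, nothing)        # literal observation
@assert isapprox(acc.logjac, -log(2.0))
```

Here `nothing` is passed where a distribution would normally go, purely to keep the sketch free of dependencies.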

src/context_implementations.jl

Lines changed: 3 additions & 3 deletions
@@ -123,8 +123,8 @@ end
 function assume(dist::Distribution, vn::VarName, vi)
     y = getindex_internal(vi, vn)
     f = from_maybe_linked_internal_transform(vi, vn, dist)
-    x, logjac = with_logabsdet_jacobian(f, y)
-    vi = accumulate_assume!!(vi, x, logjac, vn, dist)
+    x, inv_logjac = with_logabsdet_jacobian(f, y)
+    vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist)
     return x, vi
 end
 
@@ -166,6 +166,6 @@ function assume(
 
     # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
     logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
-    vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
+    vi = accumulate_assume!!(vi, r, logjac, vn, dist)
     return r, vi
 end
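
The sign flip in the first hunk and the dropped negation in the second rest on the same identity. Assuming $f$ denotes the forward link transform and $x = f^{-1}(y)$ (notation introduced here for illustration), the inverse function theorem gives:

```latex
J_{f^{-1}}(y)\,J_f(x) = I
\quad\Longrightarrow\quad
\log\left|\det J_{f^{-1}}(y)\right| = -\log\left|\det J_f(x)\right|.
```

In the first hunk `f` maps the stored value back to the original space, so negating `inv_logjac` yields the forward log-Jacobian. In the second hunk `logjac` is already computed from `link_transform(dist)`, so it is passed unchanged. Both call sites then agree with the accumulator docstring, which describes the `logjac` argument as the log determinant of the Jacobian of the link transformation.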
