Skip to content

[Merged by Bors] - Support for submodels #233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.10.19"
version = "0.10.20"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
6 changes: 5 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export AbstractVarInfo,
LikelihoodContext,
PriorContext,
MiniBatchContext,
PrefixContext,
assume,
dot_assume,
observer,
Expand All @@ -96,7 +97,9 @@ export AbstractVarInfo,
logjoint,
pointwise_loglikelihoods,
# Convenience macros
@addlogprob!
@addlogprob!,
@submodel


# Reexport
using Distributions: loglikelihood
Expand Down Expand Up @@ -124,5 +127,6 @@ include("compiler.jl")
include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")

end # module
4 changes: 2 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
#################

"""
@model(expr[, warn = true])
@model(expr[, warn = false])

Macro to specify a probabilistic model.

Expand All @@ -73,7 +73,7 @@ end

To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
"""
macro model(expr, warn=true)
macro model(expr, warn=false)
Copy link
Member

@devmotion devmotion Apr 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess, if the default value is false we can also just remove it since I doubt that anyone will enable the warnings explicitly. I am not completely sure if the warnings are useful anymore, in particular with the new variable names __varinfo__ etc. it seems unlikely thats someone would use the same name in their model definition. On the other hand, if we could ensure that official macros such as @addlogprob! and @submodel do not cause these warnings, I don't think there is any harm in keeping them.

So if possible, I think it would be better to check in the macro expansion step of the compiler if it is one of the official macros and disable warnings for only the expression generated by them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@torfjelde What's your opinion?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I weakly lean towards keeping this feature for developers.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry! But yes, I left it there because of the same reason as Hong said. I'm pro leaving it as is, and then if no one uses it for a long time, we might as well just drop it then. No need to rush completely removing it IMO.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought one should not only keep it but also show warnings if not explicitly requested otherwise - i.e., I suggested reverting it back to

Suggested change
macro model(expr, warn=false)
macro model(expr, warn=true)

However, to avoid printing warnings if users use @submodel or @addlogprob! I think one should disable warnings for the expanded code of these macros. It seems a simple if statement in the macro expansion in

return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
should be sufficient to achieve this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I have to admit I don't like it either 😄 So I think I changed my mind and I would be fine with changing it to warn=false. Even though this changes the behaviour of @model this won't break anyone's code. And in the next breaking release we might even consider removing the warn argument completely.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, lovely 👍 True, plus I don't think I've ever come across anyone actually using these warnings...

I'll make default false and push 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nvm, it's already this way, haha. I think this is good to go then!:)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know that it's used in DiffEqBayes since I added it there to avoid the warnings: https://github.com/SciML/DiffEqBayes.jl/blob/1749bc7ade1511d62a858eec4359705901126c92/src/turing_inference.jl#L53 😄 So as long as we do not suddenly remove it completely in a supposedly non-breaking release, it's fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, nice:) Good! I just merged with master and checking that tests run locally. Once that's done I'll bump version and it should be ready for bors!

# include `LineNumberNode` with information about the call site in the
# generated function for easier debugging and interpretation of error messages
esc(model(__module__, __source__, expr, warn))
Expand Down
6 changes: 6 additions & 0 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ end
function tilde(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
return tilde(rng, ctx.ctx, sampler, right, left, inds, vi)
end
function tilde(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi)
return tilde(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi)
end

"""
tilde_assume(rng, ctx, sampler, right, vn, inds, vi)
Expand Down Expand Up @@ -75,6 +78,9 @@ end
function tilde(ctx::MiniBatchContext, sampler, right, left, vi)
return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi)
end
function tilde(ctx::PrefixContext, sampler, right, left, vi)
return tilde(ctx.ctx, sampler, right, left, vi)
end

"""
tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
Expand Down
26 changes: 26 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,29 @@ end
function MiniBatchContext(ctx = DefaultContext(); batch_size, npoints)
return MiniBatchContext(ctx, npoints/batch_size)
end


struct PrefixContext{Prefix, C} <: AbstractContext
ctx::C
end
PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} = PrefixContext{Prefix, typeof(ctx)}(ctx)

const PREFIX_SEPARATOR = Symbol(".")

function PrefixContext{PrefixInner}(
ctx::PrefixContext{PrefixOuter}
) where {PrefixInner, PrefixOuter}
if @generated
:(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}(ctx.ctx))
else
PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx)
end
end

function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix, Sym}
if @generated
return :(VarName{$(QuoteNode(Symbol(Prefix, _prefix_seperator, Sym)))}(vn.indexing))
else
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing)
end
end
23 changes: 23 additions & 0 deletions src/submodel_macro.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
macro submodel(expr)
return quote
_evaluate(
$(esc(:__rng__)),
$(esc(expr)),
$(esc(:__varinfo__)),
$(esc(:__sampler__)),
$(esc(:__context__))
)
end
end

macro submodel(prefix, expr)
return quote
_evaluate(
$(esc(:__rng__)),
$(esc(expr)),
$(esc(:__varinfo__)),
$(esc(:__sampler__)),
PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__)))
)
end
end
101 changes: 101 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,107 @@ end
@test demo2()() == 42
end

@testset "submodel" begin
# No prefix, 1 level.
@model function demo1(x)
x ~ Normal()
end;
@model function demo2(x, y)
@submodel demo1(x)
y ~ Uniform()
end;
# No observation.
m = demo2(missing, missing);
vi = VarInfo(m);
ks = keys(vi)
@test VarName(:x) ∈ ks
@test VarName(:y) ∈ ks

# Observation in top-level.
m = demo2(missing, 1.0);
vi = VarInfo(m);
ks = keys(vi)
@test VarName(:x) ∈ ks
@test VarName(:y) ∉ ks

# Observation in nested model.
m = demo2(1000.0, missing);
vi = VarInfo(m);
ks = keys(vi)
@test VarName(:x) ∉ ks
@test VarName(:y) ∈ ks

# Observe all.
m = demo2(1000.0, 0.5);
vi = VarInfo(m);
ks = keys(vi)
@test isempty(ks)

# Check values makes sense.
@model function demo2(x, y)
@submodel demo1(x)
y ~ Normal(x)
end;
m = demo2(1000.0, missing);
# Mean of `y` should be close to 1000.
@test abs(mean([VarInfo(m)[VarName(:y)] for i = 1:10]) - 1000) ≤ 10;

# Prefixed submodels and usage of submodel return values.
@model function demo_return(x)
x ~ Normal()
return x
end;

@model function demo_useval(x, y)
x1 = @submodel sub1 demo_return(x)
x2 = @submodel sub2 demo_return(y)

z ~ Normal(x1 + x2 + 100, 1.0)
end;
m = demo_useval(missing, missing)
vi = VarInfo(m);
ks = keys(vi)
@test VarName(Symbol("sub1.x")) ∈ ks
@test VarName(Symbol("sub2.x")) ∈ ks
@test VarName(:z) ∈ ks
@test abs(mean([VarInfo(m)[VarName(:z)] for i = 1:10]) - 100) ≤ 10

# AR1 model. Dynamic prefixing.
@model function AR1(num_steps, α, μ, σ, ::Type{TV} = Vector{Float64}) where {TV}
η ~ MvNormal(num_steps, 1.0)
δ = sqrt(1 - α^2)

x = TV(undef, num_steps)
x[1] = η[1]
@inbounds for t = 2:num_steps
x[t] = @. α * x[t - 1] + δ * η[t]
end

return @. μ + σ * x
end

@model function demo(y)
α ~ Uniform()
μ ~ Normal()
σ ~ truncated(Normal(), 0, Inf)

num_steps = length(y[1])
num_obs = length(y)
@inbounds for i = 1:num_obs
x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ)
y[i] ~ MvNormal(x, 0.1)
end
end;

ys = [randn(10), randn(10)];
m = demo(ys);
vi = VarInfo(m);

for k in [:α, :μ, :σ, Symbol("ar1_1.η"), Symbol("ar1_2.η")]
@test VarName(k) ∈ keys(vi)
end
end

@testset "check_tilde_rhs" begin
@test_throws ArgumentError DynamicPPL.check_tilde_rhs(randn())

Expand Down