diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 000000000..1e72b507e --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style="blue" diff --git a/.github/workflows/Format.yml b/.github/workflows/Format.yml new file mode 100644 index 000000000..e24839b43 --- /dev/null +++ b/.github/workflows/Format.yml @@ -0,0 +1,33 @@ +name: Format + +on: + push: + branches: + # This is where pull requests from "bors r+" are built. + - staging + # This is where pull requests from "bors try" are built. + - trying + # Build the master branch. + - master + pull_request: + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: 1 + - name: Format code + run: | + using Pkg + Pkg.add(; name="JuliaFormatter", uuid="98e50ef6-434e-11e9-1051-2b60c6c9e899") + using JuliaFormatter + format("."; verbose=true) + shell: julia --color=yes {0} + - uses: reviewdog/action-suggester@v1 + if: github.event_name == 'pull_request' + with: + tool_name: JuliaFormatter + fail_on_error: true diff --git a/README.md b/README.md index 399342024..0398ed6e0 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ [![IntegrationTest](https://github.com/TuringLang/DynamicPPL.jl/workflows/IntegrationTest/badge.svg?branch=master)](https://github.com/TuringLang/DynamicPPL.jl/actions?query=workflow%3AIntegrationTest+branch%3Amaster) [![Coverage Status](https://coveralls.io/repos/github/TuringLang/DynamicPPL.jl/badge.svg?branch=master)](https://coveralls.io/github/TuringLang/DynamicPPL.jl?branch=master) [![Codecov](https://codecov.io/gh/TuringLang/DynamicPPL.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/TuringLang/DynamicPPL.jl) +[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://colprac.sciml.ai/) [![Bors enabled](https://bors.tech/images/badge_small.svg)](https://app.bors.tech/repositories/24589) diff --git a/bors.toml b/bors.toml index fac48228c..d60e29865 100644 --- a/bors.toml +++ b/bors.toml @@ -15,7 +15,8 @@ status = [ "test (1.%, windows-latest, x64, 2)", "test (1, windows-latest, x64, 1)", "test (1, windows-latest, x64, 2)", - "Turing.jl" + "Turing.jl", + "format" ] delete_merged_branches = true # Require at least on approval of a project member. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 2d6a61f33..acdb98183 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -5,101 +5,101 @@ using AbstractPPL using Distributions using Bijectors -import AbstractMCMC -import ChainRulesCore -import NaturalSort -import MacroTools -import ZygoteRules +using AbstractMCMC: AbstractMCMC +using ChainRulesCore: ChainRulesCore +using NaturalSort: NaturalSort +using MacroTools: MacroTools +using ZygoteRules: ZygoteRules -import Random +using Random: Random -import Base: Symbol, - ==, - hash, - getindex, - setindex!, - push!, - show, - isempty, - empty!, - getproperty, - setproperty!, - keys, - haskey +import Base: + Symbol, + ==, + hash, + getindex, + setindex!, + push!, + show, + isempty, + empty!, + getproperty, + setproperty!, + keys, + haskey # VarInfo -export AbstractVarInfo, - VarInfo, - UntypedVarInfo, - TypedVarInfo, - getlogp, - setlogp!, - acclogp!, - resetlogp!, - get_num_produce, - set_num_produce!, - reset_num_produce!, - increment_num_produce!, - set_retained_vns_del_by_spl!, - is_flagged, - set_flag!, - unset_flag!, - setgid!, - updategid!, - setorder!, - istrans, - link!, - invlink!, - tonamedtuple, -# VarName (reexport from AbstractPPL) - VarName, - inspace, - subsumes, - @varname, -# Compiler - @model, -# Utilities - vectorize, - reconstruct, - reconstruct!, - Sample, - init, - vectorize, -# Model - Model, - getmissings, - getargnames, - generated_quantities, -# Samplers - Sampler, - SampleFromPrior, - SampleFromUniform, -# Contexts - DefaultContext, - LikelihoodContext, - PriorContext, - MiniBatchContext, - PrefixContext, - assume, - dot_assume, - observer, - dot_observe, - tilde, - dot_tilde, -# Pseudo distributions - NamedDist, - NoDist, -# Prob macros - @prob_str, - @logprob_str, -# Convenience functions - logprior, - logjoint, - pointwise_loglikelihoods, -# Convenience macros - @addlogprob!, - @submodel - +export AbstractVarInfo, + VarInfo, + UntypedVarInfo, + TypedVarInfo, + getlogp, + setlogp!, + acclogp!, + resetlogp!, + get_num_produce, + set_num_produce!, + reset_num_produce!, + increment_num_produce!, + set_retained_vns_del_by_spl!, + is_flagged, + set_flag!, + unset_flag!, + setgid!, + updategid!, + setorder!, + istrans, + link!, + invlink!, + tonamedtuple, + # VarName (reexport from AbstractPPL) + VarName, + inspace, + subsumes, + @varname, + # Compiler + @model, + # Utilities + vectorize, + reconstruct, + reconstruct!, + Sample, + init, + vectorize, + # Model + Model, + getmissings, + getargnames, + generated_quantities, + # Samplers + Sampler, + SampleFromPrior, + SampleFromUniform, + # Contexts + DefaultContext, + LikelihoodContext, + PriorContext, + MiniBatchContext, + PrefixContext, + assume, + dot_assume, + observer, + dot_observe, + tilde, + dot_tilde, + # Pseudo distributions + NamedDist, + NoDist, + # Prob macros + @prob_str, + @logprob_str, + # Convenience functions + logprior, + logjoint, + pointwise_loglikelihoods, + # Convenience macros + @addlogprob!, + @submodel # Reexport using Distributions: loglikelihood @@ -112,7 +112,6 @@ function getspace end abstract type AbstractVarInfo <: AbstractModelTrace end abstract type AbstractContext end - include("utils.jl") include("selector.jl") include("model.jl") diff --git a/src/compat/ad.jl b/src/compat/ad.jl index 874d0462e..47a627506 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -1,21 +1,15 @@ # See https://github.com/TuringLang/Turing.jl/issues/1199 ChainRulesCore.@non_differentiable push!( - vi::VarInfo, - vn::VarName, - r, - dist::Distribution, - gidset::Set{Selector} + vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) ChainRulesCore.@non_differentiable updategid!( - vi::AbstractVarInfo, - vn::VarName, - spl::Sampler, + vi::AbstractVarInfo, vn::VarName, spl::Sampler ) # https://github.com/TuringLang/Turing.jl/issues/1595 ZygoteRules.@adjoint function dot_observe( - spl::Union{SampleFromPrior, SampleFromUniform}, + spl::Union{SampleFromPrior,SampleFromUniform}, dists::AbstractArray{<:Distribution}, value::AbstractArray, vi, diff --git a/src/compiler.jl b/src/compiler.jl index eb6804476..bef7d11c2 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -15,14 +15,15 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases: When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. """ -function isassumption(expr::Union{Symbol, Expr}) +function isassumption(expr::Union{Symbol,Expr}) vn = gensym(:vn) return quote let $vn = $(varname(expr)) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || $(DynamicPPL.inmissings)($vn, __model__) + if !$(DynamicPPL.inargnames)($vn, __model__) || + $(DynamicPPL.inmissings)($vn, __model__) true else # Evaluate the LHS @@ -42,9 +43,11 @@ Check if the right-hand side `x` of a `~` is a `Distribution` or an array of `Distributions`, then return `x`. """ function check_tilde_rhs(@nospecialize(x)) - return throw(ArgumentError( - "the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s" - )) + return throw( + ArgumentError( + "the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s", + ), + ) end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x @@ -76,16 +79,14 @@ To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`. macro model(expr, warn=false) # include `LineNumberNode` with information about the call site in the # generated function for easier debugging and interpretation of error messages - esc(model(__module__, __source__, expr, warn)) + return esc(model(__module__, __source__, expr, warn)) end function model(mod, linenumbernode, expr, warn) modelinfo = build_model_info(expr) # Generate main body - modelinfo[:body] = generate_mainbody( - mod, modelinfo[:modeldef][:body], warn - ) + modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) return build_output(modelinfo, linenumbernode) end @@ -208,27 +209,29 @@ function generate_mainbody!(mod, found, expr::Expr, warn) args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde - return generate_dot_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), - ) |> Base.remove_linenums! + return Base.remove_linenums!( + generate_dot_tilde( + generate_mainbody!(mod, found, L, warn), + generate_mainbody!(mod, found, R, warn), + ), + ) end # Modify tilde operators. args_tilde = getargs_tilde(expr) if args_tilde !== nothing L, R = args_tilde - return generate_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), - ) |> Base.remove_linenums! + return Base.remove_linenums!( + generate_tilde( + generate_mainbody!(mod, found, L, warn), + generate_mainbody!(mod, found, R, warn), + ), + ) end return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) end - - """ generate_tilde(left, right) @@ -331,8 +334,8 @@ function generate_dot_tilde(left, right) end end -const FloatOrArrayType = Type{<:Union{AbstractFloat, AbstractArray}} -hasmissing(T::Type{<:AbstractArray{TA}}) where {TA <: AbstractArray} = hasmissing(TA) +const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} +hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true hasmissing(T::Type) = false @@ -393,12 +396,11 @@ function build_output(modelinfo, linenumbernode) return :($(Base).@__doc__ $(MacroTools.combinedef(modeldef))) end - function warn_empty(body) if all(l -> isa(l, LineNumberNode), body.args) @warn("Model definition seems empty, still continue.") end - return + return nothing end """ @@ -430,26 +432,18 @@ For example, if `T === Float64` and `spl::Hamiltonian`, the matching type is `eltype(vi[spl])`. """ get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} = T -function get_matching_type( - spl::AbstractSampler, - vi, - ::Type{<:Union{Missing, AbstractFloat}}, -) - return Union{Missing, floatof(eltype(vi, spl))} +function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,floatof(eltype(vi, spl))} end -function get_matching_type( - spl::AbstractSampler, - vi, - ::Type{<:AbstractFloat}, -) +function get_matching_type(spl::AbstractSampler, vi, ::Type{<:AbstractFloat}) return floatof(eltype(vi, spl)) end function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(spl, vi, T), N} + return Array{get_matching_type(spl, vi, T),N} end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where T +function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where {T} return Array{get_matching_type(spl, vi, T)} end -floatof(::Type{T}) where {T <: Real} = typeof(one(T)/one(T)) +floatof(::Type{T}) where {T<:Real} = typeof(one(T) / one(T)) floatof(::Type) = Real # fallback if type inference failed diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 5000b8bfe..afc5e4da3 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,9 +1,8 @@ -using Distributions: UnivariateDistribution, - MultivariateDistribution, - MatrixDistribution, - Distribution +using Distributions: + UnivariateDistribution, MultivariateDistribution, MatrixDistribution, Distribution -const AMBIGUITY_MSG = "Ambiguous `LHS .~ RHS` or `@. LHS ~ RHS` syntax. The broadcasting " * +const AMBIGUITY_MSG = + "Ambiguous `LHS .~ RHS` or `@. LHS ~ RHS` syntax. The broadcasting " * "can either be column-wise following the convention of Distributions.jl or " * "element-wise following Julia's general broadcasting semantics. Please make sure " * "that the element type of `LHS` is not a supertype of the support type of " * @@ -57,7 +56,6 @@ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) return value end - function _tilde(rng, sampler, right, vn::VarName, vi) return assume(rng, sampler, right, vn, vi) end @@ -111,23 +109,18 @@ function tilde_observe(ctx, sampler, right, left, vi) return left end - _tilde(sampler, right, left, vi) = observe(sampler, right, left, vi) function assume(rng, spl::Sampler, dist) - error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") + return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end function observe(spl::Sampler, weight) - error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") + return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") end function assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi, + rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. @@ -149,10 +142,7 @@ function assume( end function observe( - spl::Union{SampleFromPrior, SampleFromUniform}, - dist::Distribution, - value, - vi, + spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi ) increment_num_produce!(vi) return Distributions.loglikelihood(dist, value) @@ -165,16 +155,7 @@ function dot_tilde(rng, ctx::DefaultContext, sampler, right, left, vn::VarName, vns, dist = get_vns_and_dist(right, left, vn) return _dot_tilde(rng, sampler, dist, left, vns, vi) end -function dot_tilde( - rng, - ctx::LikelihoodContext, - sampler, - right, - left, - vn::VarName, - inds, - vi, -) +function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) var = _getindex(getfield(ctx.vars, getsym(vn)), inds) vns, dist = get_vns_and_dist(right, var, vn) @@ -188,16 +169,7 @@ end function dot_tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) return dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) end -function dot_tilde( - rng, - ctx::PriorContext, - sampler, - right, - left, - vn::VarName, - inds, - vi, -) +function dot_tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) if ctx.vars !== nothing var = _getindex(getfield(ctx.vars, getsym(vn)), inds) vns, dist = get_vns_and_dist(right, var, vn) @@ -223,19 +195,15 @@ function dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) return value end - function get_vns_and_dist(dist::NamedDist, var, vn::VarName) return get_vns_and_dist(dist.dist, var, dist.name) end function get_vns_and_dist(dist::MultivariateDistribution, var::AbstractMatrix, vn::VarName) getvn = i -> VarName(vn, (vn.indexing..., (Colon(), i))) return getvn.(1:size(var, 2)), dist - end function get_vns_and_dist( - dist::Union{Distribution, AbstractArray{<:Distribution}}, - var::AbstractArray, - vn::VarName + dist::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vn::VarName ) getvn = ind -> VarName(vn, (vn.indexing..., Tuple(ind))) return getvn.(CartesianIndices(var)), dist @@ -249,17 +217,17 @@ end function _dot_tilde( rng, sampler::AbstractSampler, - right::Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}}, + right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, left::AbstractMatrix{>:AbstractVector}, vn::AbstractVector{<:VarName}, vi, ) - throw(DimensionMismatch(AMBIGUITY_MSG)) + return throw(DimensionMismatch(AMBIGUITY_MSG)) end function dot_assume( rng, - spl::Union{SampleFromPrior, SampleFromUniform}, + spl::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, vns::AbstractVector{<:VarName}, var::AbstractMatrix, @@ -273,8 +241,8 @@ function dot_assume( end function dot_assume( rng, - spl::Union{SampleFromPrior, SampleFromUniform}, - dists::Union{Distribution, AbstractArray{<:Distribution}}, + spl::Union{SampleFromPrior,SampleFromUniform}, + dists::Union{Distribution,AbstractArray{<:Distribution}}, vns::AbstractArray{<:VarName}, var::AbstractArray, vi, @@ -285,15 +253,10 @@ function dot_assume( var .= r return var, lp end -function dot_assume( - rng, - spl::Sampler, - ::Any, - ::AbstractArray{<:VarName}, - ::Any, - ::Any, -) - error("[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement") +function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) + return error( + "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement" + ) end function get_and_set_val!( @@ -322,7 +285,7 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - push!(vi, vn, r[:,i], dist, spl) + push!(vi, vn, r[:, i], dist, spl) settrans!(vi, false, vn) end end @@ -333,7 +296,7 @@ function get_and_set_val!( rng, vi, vns::AbstractArray{<:VarName}, - dists::Union{Distribution, AbstractArray{<:Distribution}}, + dists::Union{Distribution,AbstractArray{<:Distribution}}, spl::Union{SampleFromPrior,SampleFromUniform}, ) if haskey(vi, vns[1]) @@ -362,21 +325,18 @@ function get_and_set_val!( end function set_val!( - vi, - vns::AbstractVector{<:VarName}, - dist::MultivariateDistribution, - val::AbstractMatrix, + vi, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, val::AbstractMatrix ) @assert size(val, 2) == length(vns) foreach(enumerate(vns)) do (i, vn) - vi[vn] = val[:,i] + vi[vn] = val[:, i] end return val end function set_val!( vi, vns::AbstractArray{<:VarName}, - dists::Union{Distribution, AbstractArray{<:Distribution}}, + dists::Union{Distribution,AbstractArray{<:Distribution}}, val::AbstractArray, ) @assert size(val) == size(vns) @@ -436,15 +396,15 @@ end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics function _dot_tilde( sampler::AbstractSampler, - right::Union{MultivariateDistribution, AbstractVector{<:MultivariateDistribution}}, + right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, left::AbstractMatrix{>:AbstractVector}, vi, ) - throw(DimensionMismatch(AMBIGUITY_MSG)) + return throw(DimensionMismatch(AMBIGUITY_MSG)) end function dot_observe( - spl::Union{SampleFromPrior, SampleFromUniform}, + spl::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, value::AbstractMatrix, vi, @@ -455,7 +415,7 @@ function dot_observe( return Distributions.loglikelihood(dist, value) end function dot_observe( - spl::Union{SampleFromPrior, SampleFromUniform}, + spl::Union{SampleFromPrior,SampleFromUniform}, dists::Distribution, value::AbstractArray, vi, @@ -466,7 +426,7 @@ function dot_observe( return Distributions.loglikelihood(dists, value) end function dot_observe( - spl::Union{SampleFromPrior, SampleFromUniform}, + spl::Union{SampleFromPrior,SampleFromUniform}, dists::AbstractArray{<:Distribution}, value::AbstractArray, vi, @@ -476,11 +436,8 @@ function dot_observe( @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end -function dot_observe( - spl::Sampler, - ::Any, - ::Any, - ::Any, -) - error("[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement") +function dot_observe(spl::Sampler, ::Any, ::Any, ::Any) + return error( + "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement" + ) end diff --git a/src/contexts.jl b/src/contexts.jl index 0d7006bf0..4d4f30bdc 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -45,33 +45,36 @@ The `MiniBatchContext` enables the computation of This is useful in batch-based stochastic gradient descent algorithms to be optimizing `log(prior) + log(likelihood of all the data points)` in the expectation. """ -struct MiniBatchContext{Tctx, T} <: AbstractContext +struct MiniBatchContext{Tctx,T} <: AbstractContext ctx::Tctx loglike_scalar::T end -function MiniBatchContext(ctx = DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints/batch_size) +function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) + return MiniBatchContext(ctx, npoints / batch_size) end - -struct PrefixContext{Prefix, C} <: AbstractContext +struct PrefixContext{Prefix,C} <: AbstractContext ctx::C end -PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} = PrefixContext{Prefix, typeof(ctx)}(ctx) +function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} + return PrefixContext{Prefix,typeof(ctx)}(ctx) +end const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( ctx::PrefixContext{PrefixOuter} -) where {PrefixInner, PrefixOuter} +) where {PrefixInner,PrefixOuter} if @generated - :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}(ctx.ctx)) + :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}( + ctx.ctx + )) else PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) end end -function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix, Sym} +function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} if @generated return :(VarName{$(QuoteNode(Symbol(Prefix, _prefix_seperator, Sym)))}(vn.indexing)) else diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 3bcd6bd76..0cea2c225 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -1,31 +1,20 @@ -import Distributions -import Bijectors -using Distributions: Univariate, - Multivariate, - Matrixvariate - +using Distributions: Distributions +using Bijectors: Bijectors +using Distributions: Univariate, Multivariate, Matrixvariate """ A named distribution that carries the name of the random variable with it. """ -struct NamedDist{ - variate, - support, - Td <: Distribution{variate, support}, - Tv <: VarName -} <: Distribution{variate, support} +struct NamedDist{variate,support,Td<:Distribution{variate,support},Tv<:VarName} <: + Distribution{variate,support} dist::Td name::Tv end NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName(name)) - -struct NoDist{ - variate, - support, - Td <: Distribution{variate, support} -} <: Distribution{variate, support} +struct NoDist{variate,support,Td<:Distribution{variate,support}} <: + Distribution{variate,support} dist::Td end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index e74aa31ca..89672127a 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -1,83 +1,72 @@ # Context version -struct PointwiseLikelihoodContext{A, Ctx} <: AbstractContext +struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext loglikelihoods::A ctx::Ctx end function PointwiseLikelihoodContext( - likelihoods = Dict{VarName, Vector{Float64}}(), - ctx::AbstractContext = LikelihoodContext() + likelihoods=Dict{VarName,Vector{Float64}}(), ctx::AbstractContext=LikelihoodContext() ) return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName, Vector{Float64}}}, - vn::VarName, - logp::Real + ctx::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, logp::Real ) lookup = ctx.loglikelihoods ℓ = get!(lookup, vn, Float64[]) - push!(ℓ, logp) + return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName, Float64}}, - vn::VarName, - logp::Real + ctx::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real ) - ctx.loglikelihoods[vn] = logp + return ctx.loglikelihoods[vn] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String, Vector{Float64}}}, - vn::VarName, - logp::Real + ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::VarName, logp::Real ) lookup = ctx.loglikelihoods ℓ = get!(lookup, string(vn), Float64[]) - push!(ℓ, logp) + return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String, Float64}}, - vn::VarName, - logp::Real + ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real ) - ctx.loglikelihoods[string(vn)] = logp + return ctx.loglikelihoods[string(vn)] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String, Vector{Float64}}}, - vn::String, - logp::Real + ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::String, logp::Real ) lookup = ctx.loglikelihoods ℓ = get!(lookup, vn, Float64[]) - push!(ℓ, logp) + return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String, Float64}}, - vn::String, - logp::Real + ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real ) - ctx.loglikelihoods[vn] = logp + return ctx.loglikelihoods[vn] = logp end - function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) end -function dot_tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi) +function dot_tilde_assume( + rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi +) value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) acclogp!(vi, logp) return value end - -function tilde_observe(ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi) +function tilde_observe( + ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi +) # This is slightly unfortunate since it is not completely generic... # Ideally we would call `tilde_observe` recursively but then we don't get the # loglikelihood value. @@ -90,7 +79,6 @@ function tilde_observe(ctx::PointwiseLikelihoodContext, sampler, right, left, vn return left end - """ pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String) @@ -160,15 +148,11 @@ Dict{VarName,Array{Float64,2}} with 4 entries: xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] ``` """ -function pointwise_loglikelihoods( - model::Model, - chain, - keytype::Type{T} = String -) where {T} +function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} # Get the data by executing the model once spl = SampleFromPrior() vi = VarInfo(model) - ctx = PointwiseLikelihoodContext(Dict{T, Vector{Float64}}()) + ctx = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters @@ -182,14 +166,14 @@ function pointwise_loglikelihoods( niters = size(chain, 1) nchains = size(chain, 3) loglikelihoods = Dict( - varname => reshape(logliks, niters, nchains) - for (varname, logliks) in ctx.loglikelihoods + varname => reshape(logliks, niters, nchains) for + (varname, logliks) in ctx.loglikelihoods ) return loglikelihoods end function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - ctx = PointwiseLikelihoodContext(Dict{VarName, Float64}()) + ctx = PointwiseLikelihoodContext(Dict{VarName,Float64}()) model(varinfo, SampleFromPrior(), ctx) return ctx.loglikelihoods end diff --git a/src/model.jl b/src/model.jl index b0b78f71f..7189b590e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -32,7 +32,8 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractProbabilisticProgram +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: + AbstractProbabilisticProgram name::Symbol f::F args::NamedTuple{argnames,Targs} @@ -50,7 +51,9 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractProbab args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults}(name, f, args, defaults) + return new{F,argnames,defaultnames,missings,Targs,Tdefaults}( + name, f, args, defaults + ) end end @@ -64,10 +67,7 @@ Default arguments `defaults` are used internally when constructing instances of model with different arguments. """ @generated function Model( - name::Symbol, - f::F, - args::NamedTuple{argnames,Targs}, - defaults::NamedTuple = NamedTuple(), + name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple() ) where {F,argnames,Targs} missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) return :(Model{$missings}(name, f, args, defaults)) @@ -84,9 +84,9 @@ number of `sampler`. """ function (model::Model)( rng::Random.AbstractRNG, - varinfo::AbstractVarInfo = VarInfo(), - sampler::AbstractSampler = SampleFromPrior(), - context::AbstractContext = DefaultContext(), + varinfo::AbstractVarInfo=VarInfo(), + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), ) if Threads.nthreads() == 1 return evaluate_threadunsafe(rng, model, varinfo, sampler, context) @@ -99,11 +99,7 @@ function (model::Model)(args...) end # without VarInfo -function (model::Model)( - rng::Random.AbstractRNG, - sampler::AbstractSampler, - args..., -) +function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) return model(rng, VarInfo(), sampler, args...) end @@ -151,7 +147,9 @@ end Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. """ -@generated function _evaluate(rng, model::Model{_F,argnames}, varinfo, sampler, context) where {_F,argnames} +@generated function _evaluate( + rng, model::Model{_F,argnames}, varinfo, sampler, context +) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) end @@ -163,7 +161,6 @@ Get a tuple of the argument names of the `model`. """ getargnames(model::Model{_F,argnames}) where {argnames,_F} = argnames - """ getmissings(model::Model) diff --git a/src/prob_macro.jl b/src/prob_macro.jl index 52d1c7466..d761e9fdc 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -8,7 +8,7 @@ macro prob_str(str) end function get_exprs(str::String) - substrings = split(str, '|'; limit = 2) + substrings = split(str, '|'; limit=2) length(substrings) == 2 || error("Invalid expression.") str1, str2 = substrings @@ -30,14 +30,16 @@ function logprob(ex1, ex2) end end -function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {namesl, namesr} +function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {namesl,namesr} if :chain in namesr if isdefined(ntr.chain.info, :model) model = ntr.chain.info.model elseif isdefined(ntr, :model) model = ntr.model else - throw("The model is not defined. Please make sure the model is either saved in the chain or passed on the RHS of |.") + throw( + "The model is not defined. Please make sure the model is either saved in the chain or passed on the RHS of |.", + ) end @assert model isa Model if isdefined(ntr.chain.info, :vi) @@ -53,7 +55,8 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names end defaults = model.defaults @assert all(getargnames(model)) do arg - isdefined(ntl, arg) || isdefined(ntr, arg) || + isdefined(ntl, arg) || + isdefined(ntr, arg) || isdefined(defaults, arg) && getfield(defaults, arg) !== missing end return Val(:likelihood), model, vi @@ -75,11 +78,13 @@ end function probtype( left::NamedTuple{leftnames}, right::NamedTuple{rightnames}, - model::Model{_F,argnames,defaultnames} + model::Model{_F,argnames,defaultnames}, ) where {leftnames,rightnames,argnames,defaultnames,_F} defaults = model.defaults - prior_rhs = all(n -> n in (:model, :varinfo) || - n in argnames && getfield(right, n) !== missing, rightnames) + prior_rhs = all( + n -> n in (:model, :varinfo) || n in argnames && getfield(right, n) !== missing, + rightnames, + ) function get_arg(arg) if arg in leftnames return getfield(left, arg) @@ -101,8 +106,8 @@ function probtype( # If no default value exists, use `nothing`. if prior_rhs return Val(:prior) - # Uses the default values for model arguments not provided. - # If no default value exists or the default value is missing, then error. + # Uses the default values for model arguments not provided. + # If no default value exists or the default value is missing, then error. elseif valid_args return Val(:likelihood) else @@ -114,14 +119,15 @@ function probtype( end end -missing_arg_error_msg(arg, ::Missing) = """Variable $arg has a value of `missing`, or is not defined and its default value is `missing`. Please make sure all the variables are either defined with a value other than `missing` or have a default value other than `missing`.""" -missing_arg_error_msg(arg, ::Nothing) = """Variable $arg is not defined and has no default value. Please make sure all the variables are either defined with a value other than `missing` or have a default value other than `missing`.""" +function missing_arg_error_msg(arg, ::Missing) + return """Variable $arg has a value of `missing`, or is not defined and its default value is `missing`. Please make sure all the variables are either defined with a value other than `missing` or have a default value other than `missing`.""" +end +function missing_arg_error_msg(arg, ::Nothing) + return """Variable $arg is not defined and has no default value. Please make sure all the variables are either defined with a value other than `missing` or have a default value other than `missing`.""" +end function logprior( - left::NamedTuple, - right::NamedTuple, - _model::Model, - _vi::Union{Nothing, VarInfo} + left::NamedTuple, right::NamedTuple, _model::Model, _vi::Union{Nothing,VarInfo} ) # For model args on the LHS of |, use their passed value but add the symbol to # model.missings. This will lead to an `assume`/`dot_assume` call for those variables. @@ -147,7 +153,7 @@ end @generated function make_prior_model( left::NamedTuple{leftnames}, right::NamedTuple{rightnames}, - model::Model{_F,argnames,defaultnames} + model::Model{_F,argnames,defaultnames}, ) where {leftnames,rightnames,argnames,defaultnames,_F} argvals = [] missings = [] @@ -172,7 +178,7 @@ end return quote $(warnings...) Model{$(Tuple(missings))}( - model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults, + model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults ) end end @@ -180,10 +186,7 @@ end warn_msg(arg) = "Argument $arg is not defined. A value of `nothing` is used." function Distributions.loglikelihood( - left::NamedTuple, - right::NamedTuple, - _model::Model, - _vi::Union{Nothing, VarInfo}, + left::NamedTuple, right::NamedTuple, _model::Model, _vi::Union{Nothing,VarInfo} ) model = make_likelihood_model(left, right, _model) vi = _vi === nothing ? VarInfo(deepcopy(model)) : _vi @@ -224,13 +227,15 @@ end elseif argname in defaultnames push!(argvals, :(model.defaults.$argname)) else - throw("This point should not be reached. Please open an issue in the DynamicPPL.jl repository.") + throw( + "This point should not be reached. Please open an issue in the DynamicPPL.jl repository.", + ) end end # `args` is inserted as properly typed NamedTuple expression; # `missings` is splatted into a tuple at compile time and inserted as literal return :(Model{$(Tuple(missings))}( - model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults, + model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults )) end diff --git a/src/sampler.jl b/src/sampler.jl index 8b32bef9d..5e97a64e3 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -6,7 +6,7 @@ Robust initialization method for model parameters in Hamiltonian samplers. struct SampleFromUniform <: AbstractSampler end struct SampleFromPrior <: AbstractSampler end -getspace(::Union{SampleFromPrior, SampleFromUniform}) = () +getspace(::Union{SampleFromPrior,SampleFromUniform}) = () # Initializations. init(rng, dist, ::SampleFromPrior) = rand(rng, dist) @@ -47,8 +47,8 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, model::Model, sampler::Union{SampleFromUniform,SampleFromPrior}, - state = nothing; - kwargs... + state=nothing; + kwargs..., ) vi = VarInfo() model(rng, vi, sampler) @@ -57,11 +57,7 @@ end # initial step: general interface for resuming and function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - spl::Sampler; - resume_from = nothing, - kwargs... + rng::Random.AbstractRNG, model::Model, spl::Sampler; resume_from=nothing, kwargs... ) if resume_from !== nothing state = loadstate(resume_from) @@ -135,7 +131,7 @@ function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler) vi[spl] = theta linked && link!(vi, spl) - return + return nothing end """ diff --git a/src/selector.jl b/src/selector.jl index d974ca189..fd4aa6d1c 100644 --- a/src/selector.jl +++ b/src/selector.jl @@ -1,12 +1,12 @@ struct Selector - gid :: UInt64 - tag :: Symbol # :default, :invalid, :Gibbs, :HMC, etc. - rerun :: Bool + gid::UInt64 + tag::Symbol # :default, :invalid, :Gibbs, :HMC, etc. + rerun::Bool end -function Selector(tag::Symbol = :default, rerun = tag != :default) +function Selector(tag::Symbol=:default, rerun=tag != :default) return Selector(time_ns(), tag, rerun) end -function Selector(gid::Integer, tag::Symbol = :default) +function Selector(gid::Integer, tag::Symbol=:default) return Selector(gid, tag, tag != :default) end hash(s::Selector) = hash(s.gid) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 267ec8933..92584ae8b 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -5,7 +5,7 @@ macro submodel(expr) $(esc(expr)), $(esc(:__varinfo__)), $(esc(:__sampler__)), - $(esc(:__context__)) + $(esc(:__context__)), ) end end @@ -17,7 +17,7 @@ macro submodel(prefix, expr) $(esc(expr)), $(esc(:__varinfo__)), $(esc(:__sampler__)), - PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))) + PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))), ) end end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 996934c53..c940f9e3f 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -47,7 +47,7 @@ set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) - setgid!(vi.varinfo, gid, vn) + return setgid!(vi.varinfo, gid, vn) end setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) @@ -66,13 +66,13 @@ getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns) function setindex!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) - setindex!(vi.varinfo, val, spl) + return setindex!(vi.varinfo, val, spl) end function setindex!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior) - setindex!(vi.varinfo, val, spl) + return setindex!(vi.varinfo, val, spl) end function setindex!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform) - setindex!(vi.varinfo, val, spl) + return setindex!(vi.varinfo, val, spl) end function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) @@ -87,13 +87,9 @@ function empty!(vi::ThreadSafeVarInfo) end function push!( - vi::ThreadSafeVarInfo, - vn::VarName, - r, - dist::Distribution, - gidset::Set{Selector} + vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) - push!(vi.varinfo, vn, r, dist, gidset) + return push!(vi.varinfo, vn, r, dist, gidset) end function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) diff --git a/src/utils.jl b/src/utils.jl index 9b083fe35..e77a4ecdd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -51,28 +51,31 @@ function to_namedtuple_expr(syms, vals=syms) if length(syms) == 0 nt = :(NamedTuple()) else - nt_type = Expr(:curly, :NamedTuple, - Expr(:tuple, QuoteNode.(syms)...), - Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in vals]...) + nt_type = Expr( + :curly, + :NamedTuple, + Expr(:tuple, QuoteNode.(syms)...), + Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in vals]...), ) nt = Expr(:call, :($(DynamicPPL.namedtuple)), nt_type, Expr(:tuple, vals...)) end return nt end - if VERSION == v"1.2" - @eval function namedtuple(::Type{NamedTuple{names, T}}, args::Tuple) where {names, T <: Tuple} + @eval function namedtuple( + ::Type{NamedTuple{names,T}}, args::Tuple + ) where {names,T<:Tuple} if length(args) != length(names) throw(ArgumentError("Wrong number of arguments to named tuple constructor.")) end # Note T(args) might not return something of type T; e.g. # Tuple{Type{Float64}}((Float64,)) returns a Tuple{DataType} - $(Expr(:splatnew, :(NamedTuple{names,T}), :(T(args)))) + return $(Expr(:splatnew, :(NamedTuple{names,T}), :(T(args)))) end else - function namedtuple(::Type{NamedTuple{names, T}}, args::Tuple) where {names, T <: Tuple} - return NamedTuple{names, T}(args) + function namedtuple(::Type{NamedTuple{names,T}}, args::Tuple) where {names,T<:Tuple} + return NamedTuple{names,T}(args) end end @@ -129,8 +132,13 @@ end randrealuni(rng::Random.AbstractRNG) = 4 * rand(rng) - 2 randrealuni(rng::Random.AbstractRNG, args...) = 4 .* rand(rng, args...) .- 2 -const Transformable = Union{PositiveDistribution,UnitDistribution,TransformDistribution, - SimplexDistribution,PDMatDistribution} +const Transformable = Union{ + PositiveDistribution, + UnitDistribution, + TransformDistribution, + SimplexDistribution, + PDMatDistribution, +} istransformable(dist) = false istransformable(::Transformable) = true @@ -139,7 +147,9 @@ istransformable(::Transformable) = true ################################# inittrans(rng, dist::UnivariateDistribution) = invlink(dist, randrealuni(rng)) -inittrans(rng, dist::MultivariateDistribution) = invlink(dist, randrealuni(rng, size(dist)[1])) +function inittrans(rng, dist::MultivariateDistribution) + return invlink(dist, randrealuni(rng, size(dist)[1])) +end inittrans(rng, dist::MatrixDistribution) = invlink(dist, randrealuni(rng, size(dist)...)) ################################ @@ -154,7 +164,6 @@ function inittrans(rng, dist::MatrixDistribution, n::Int) return invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) end - ####################### # Convenience methods # ####################### diff --git a/src/varinfo.jl b/src/varinfo.jl index de369cceb..e5e71eed1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1,14 +1,12 @@ # Constants for caching -const CACHERESET = 0b00 -const CACHEIDCS = 0b10 +const CACHERESET = 0b00 +const CACHEIDCS = 0b10 const CACHERANGES = 0b01 #### #### Types for typed and untyped VarInfo #### - - #################### # VarInfo metadata # #################### @@ -43,33 +41,39 @@ When sampling, the first iteration uses a type unstable `Metadata` for all the variables then a specialized `Metadata` is used for each symbol along with a function barrier to make the rest of the sampling type stable. """ -struct Metadata{TIdcs <: Dict{<:VarName,Int}, TDists <: AbstractVector{<:Distribution}, TVN <: AbstractVector{<:VarName}, TVal <: AbstractVector{<:Real}, TGIds <: AbstractVector{Set{Selector}}} +struct Metadata{ + TIdcs<:Dict{<:VarName,Int}, + TDists<:AbstractVector{<:Distribution}, + TVN<:AbstractVector{<:VarName}, + TVal<:AbstractVector{<:Real}, + TGIds<:AbstractVector{Set{Selector}}, +} # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists` - idcs :: TIdcs # Dict{<:VarName,Int} + idcs::TIdcs # Dict{<:VarName,Int} # Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn` - vns :: TVN # AbstractVector{<:VarName} + vns::TVN # AbstractVector{<:VarName} # Vector of index ranges in `vals` corresponding to `vns` # Each `VarName` `vn` has a single index or a set of contiguous indices in `vals` - ranges :: Vector{UnitRange{Int}} + ranges::Vector{UnitRange{Int}} # Vector of values of all the univariate, multivariate and matrix variables # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]` - vals :: TVal # AbstractVector{<:Real} + vals::TVal # AbstractVector{<:Real} # Vector of distributions correpsonding to `vns` - dists :: TDists # AbstractVector{<:Distribution} + dists::TDists # AbstractVector{<:Distribution} # Vector of sampler ids corresponding to `vns` # Each random variable can be sampled using multiple samplers, e.g. in Gibbs, hence the `Set` - gids :: TGIds # AbstractVector{Set{Selector}} + gids::TGIds # AbstractVector{Set{Selector}} # Number of `observe` statements before each random variable is sampled - orders :: Vector{Int} + orders::Vector{Int} # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresonding to `vns[i]` - flags :: Dict{String, BitVector} + flags::Dict{String,BitVector} end ########### @@ -97,7 +101,7 @@ Note: It is the user's responsibility to ensure that each "symbol" is visited at once whenever the model is called, regardless of any stochastic branching. Each symbol refers to a Julia variable and can be a hierarchical array of many random variables, e.g. `x[1] ~ ...` and `x[2] ~ ...` both have the same symbol `x`. """ -struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo +struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo metadata::Tmeta logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} @@ -113,14 +117,16 @@ end function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector) md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) - VarInfo(md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi))) + return VarInfo( + md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)) + ) end function VarInfo( rng::Random.AbstractRNG, model::Model, - sampler::AbstractSampler = SampleFromPrior(), - context::AbstractContext = DefaultContext(), + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), ) varinfo = VarInfo() model(rng, varinfo, sampler, context) @@ -129,31 +135,33 @@ end VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) # without AbstractSampler -function VarInfo( - rng::Random.AbstractRNG, - model::Model, - context::AbstractContext, -) +function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end -@generated function newmetadata(metadata::NamedTuple{names}, ::Val{space}, x) where {names, space} +@generated function newmetadata( + metadata::NamedTuple{names}, ::Val{space}, x +) where {names,space} exprs = [] offset = :(0) for f in names mdf = :(metadata.$f) if inspace(f, space) || length(space) == 0 len = :(length($mdf.vals)) - push!(exprs, :($f = Metadata($mdf.idcs, - $mdf.vns, - $mdf.ranges, - x[($offset + 1):($offset + $len)], - $mdf.dists, - $mdf.gids, - $mdf.orders, - $mdf.flags - ) - ) + push!( + exprs, + :( + $f = Metadata( + $mdf.idcs, + $mdf.vns, + $mdf.ranges, + x[($offset + 1):($offset + $len)], + $mdf.dists, + $mdf.gids, + $mdf.orders, + $mdf.flags, + ) + ), ) offset = :($offset + $len) else @@ -174,20 +182,20 @@ end Construct an empty type unstable instance of `Metadata`. """ function Metadata() - vals = Vector{Real}() - flags = Dict{String, BitVector}() + vals = Vector{Real}() + flags = Dict{String,BitVector}() flags["del"] = BitVector() flags["trans"] = BitVector() return Metadata( - Dict{VarName, Int}(), + Dict{VarName,Int}(), Vector{VarName}(), Vector{UnitRange{Int}}(), vals, Vector{Distribution}(), Vector{Set{Selector}}(), Vector{Int}(), - flags + flags, ) end @@ -215,12 +223,12 @@ end # Removes the first element of a NamedTuple. The pairs in a NamedTuple are ordered, so this is well-defined. if VERSION < v"1.1" - _tail(nt::NamedTuple{names}) where names = NamedTuple{Base.tail(names)}(nt) + _tail(nt::NamedTuple{names}) where {names} = NamedTuple{Base.tail(names)}(nt) else _tail(nt::NamedTuple) = Base.tail(nt) end -const VarView = Union{Int, UnitRange, Vector{Int}} +const VarView = Union{Int,UnitRange,Vector{Int}} """ getval(vi::UntypedVarInfo, vview::Union{Int, UnitRange, Vector{Int}}) @@ -270,7 +278,7 @@ getrange(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).ranges[getidx(vi, vn)] Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. """ function getranges(vi::AbstractVarInfo, vns::Vector{<:VarName}) - return mapreduce(vn -> getrange(vi, vn), vcat, vns, init=Int[]) + return mapreduce(vn -> getrange(vi, vn), vcat, vns; init=Int[]) end """ @@ -335,13 +343,13 @@ The values may or may not be transformed to Euclidean space. """ setall!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) -@generated function _setall!(metadata::NamedTuple{names}, val, start = 0) where {names} +@generated function _setall!(metadata::NamedTuple{names}, val, start=0) where {names} expr = Expr(:block) start = :(1) for f in names length = :(length(metadata.$f.vals)) finish = :($start + $length - 1) - push!(expr.args, :(metadata.$f.vals .= val[$start:$finish])) + push!(expr.args, :(metadata.$f.vals .= val[($start):($finish)])) start = :($start + $length) end return expr @@ -360,7 +368,7 @@ getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] Set the `trans` flag value of `vn` in `vi`. """ function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) - trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans") + return trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans") end """ @@ -375,7 +383,7 @@ syms(vi::TypedVarInfo) = keys(vi.metadata) # if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to # the SampleFromPrior sampler @inline function _getidcs(vi::UntypedVarInfo, ::SampleFromPrior) - return filter(i -> isempty(vi.metadata.gids[i]) , 1:length(vi.metadata.gids)) + return filter(i -> isempty(vi.metadata.gids[i]), 1:length(vi.metadata.gids)) end # Get a NamedTuple of all the indices belonging to SampleFromPrior, one for each symbol @inline function _getidcs(vi::TypedVarInfo, ::SampleFromPrior) @@ -401,16 +409,18 @@ end #if haskey(spl.info, :idcs) && (spl.info[:cache_updated] & CACHEIDCS) > 0 # spl.info[:idcs] #else - #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS - idcs = _getidcs(vi, spl.selector, Val(getspace(spl))) - #spl.info[:idcs] = idcs + #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS + idcs = _getidcs(vi, spl.selector, Val(getspace(spl))) + #spl.info[:idcs] = idcs #end return idcs end @inline _getidcs(vi::UntypedVarInfo, s::Selector, space) = findinds(vi.metadata, s, space) @inline _getidcs(vi::TypedVarInfo, s::Selector, space) = _getidcs(vi.metadata, s, space) # Get a NamedTuple for all the indices belonging to a given selector for each symbol -@generated function _getidcs(metadata::NamedTuple{names}, s::Selector, ::Val{space}) where {names, space} +@generated function _getidcs( + metadata::NamedTuple{names}, s::Selector, ::Val{space} +) where {names,space} exprs = [] # Iterate through each varname in metadata. for f in names @@ -426,9 +436,12 @@ end end @inline function findinds(f_meta, s, ::Val{space}) where {space} # Get all the idcs of the vns in `space` and that belong to the selector `s` - return filter((i) -> - (s in f_meta.gids[i] || isempty(f_meta.gids[i])) && - (isempty(space) || inspace(f_meta.vns[i], space)), 1:length(f_meta.gids)) + return filter( + (i) -> + (s in f_meta.gids[i] || isempty(f_meta.gids[i])) && + (isempty(space) || inspace(f_meta.vns[i], space)), + 1:length(f_meta.gids), + ) end @inline function findinds(f_meta) # Get all the idcs of the vns @@ -437,8 +450,12 @@ end # Get all vns of variables belonging to spl _getvns(vi::AbstractVarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) -_getvns(vi::AbstractVarInfo, spl::Union{SampleFromPrior, SampleFromUniform}) = _getvns(vi, Selector(), Val(())) -_getvns(vi::UntypedVarInfo, s::Selector, space) = view(vi.metadata.vns, _getidcs(vi, s, space)) +function _getvns(vi::AbstractVarInfo, spl::Union{SampleFromPrior,SampleFromUniform}) + return _getvns(vi, Selector(), Val(())) +end +function _getvns(vi::UntypedVarInfo, s::Selector, space) + return view(vi.metadata.vns, _getidcs(vi, s, space)) +end function _getvns(vi::TypedVarInfo, s::Selector, space) return _getvns(vi.metadata, _getidcs(vi, s, space)) end @@ -459,10 +476,10 @@ end #if haskey(spl.info, :ranges) && (spl.info[:cache_updated] & CACHERANGES) > 0 # spl.info[:ranges] #else - #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHERANGES - ranges = _getranges(vi, spl.selector, Val(getspace(spl))) - #spl.info[:ranges] = ranges - return ranges + #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHERANGES + ranges = _getranges(vi, spl.selector, Val(getspace(spl))) + #spl.info[:ranges] = ranges + return ranges #end end # Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space` @@ -470,7 +487,7 @@ end return _getranges(vi, _getidcs(vi, s, space)) end @inline function _getranges(vi::UntypedVarInfo, idcs::Vector{Int}) - mapreduce(i -> vi.metadata.ranges[i], vcat, idcs, init=Int[]) + return mapreduce(i -> vi.metadata.ranges[i], vcat, idcs; init=Int[]) end @inline _getranges(vi::TypedVarInfo, idcs::NamedTuple) = _getranges(vi.metadata, idcs) @@ -483,7 +500,7 @@ end return :($(exprs...),) end @inline function findranges(f_ranges, f_idcs) - return mapreduce(i -> f_ranges[i], vcat, f_idcs, init=Int[]) + return mapreduce(i -> f_ranges[i], vcat, f_idcs; init=Int[]) end """ @@ -499,7 +516,6 @@ end #### APIs for typed and untyped VarInfo #### - # VarInfo VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0)) @@ -530,8 +546,7 @@ function TypedVarInfo(vi::UntypedVarInfo) sym_dists = getindex.((meta.dists,), inds) # New gids, can make a resizeable FillArray sym_gids = getindex.((meta.gids,), inds) - @assert length(sym_gids) <= 1 || - all(x -> x == sym_gids[1], @view sym_gids[2:end]) + @assert length(sym_gids) <= 1 || all(x -> x == sym_gids[1], @view sym_gids[2:end]) # New orders sym_orders = getindex.((meta.orders,), inds) # New flags @@ -544,13 +559,24 @@ function TypedVarInfo(vi::UntypedVarInfo) sym_ranges = Vector{eltype(_ranges)}(undef, n) start = 0 for i in 1:n - sym_ranges[i] = start + 1 : start + length(_vals[i]) + sym_ranges[i] = (start + 1):(start + length(_vals[i])) start += length(_vals[i]) end sym_vals = foldl(vcat, _vals) - push!(new_metas, Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, - sym_dists, sym_gids, sym_orders, sym_flags)) + push!( + new_metas, + Metadata( + sym_idcs, + sym_vns, + sym_ranges, + sym_vals, + sym_dists, + sym_gids, + sym_orders, + sym_flags, + ), + ) end logp = getlogp(vi) num_produce = get_num_produce(vi) @@ -606,7 +632,9 @@ end Add `gid` to the set of sampler selectors associated with `vn` in `vi`. """ -setgid!(vi::VarInfo, gid::Selector, vn::VarName) = push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) +function setgid!(vi::VarInfo, gid::Selector, vn::VarName) + return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) +end """ istrans(vi::VarInfo, vn::VarName) @@ -714,7 +742,11 @@ function link!(vi::UntypedVarInfo, spl::Sampler) @debug "X -> ℝ for $(vn)..." dist = getdist(vi, vn) # TODO: Use inplace versions to avoid allocations - setval!(vi, vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn) + setval!( + vi, + vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), + vn, + ) settrans!(vi, true, vn) end else @@ -728,24 +760,36 @@ function link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _link!(vi.metadata, vi, vns, spaceval) end -@generated function _link!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} +@generated function _link!( + metadata::NamedTuple{names}, vi, vns, ::Val{space} +) where {names,space} expr = Expr(:block) for f in names if inspace(f, space) || length(space) == 0 - push!(expr.args, quote - f_vns = vi.metadata.$f.vns - if ~istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - @debug "X -> R for $(vn)..." - dist = getdist(vi, vn) - setval!(vi, vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn) - settrans!(vi, true, vn) + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + if ~istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + @debug "X -> R for $(vn)..." + dist = getdist(vi, vn) + setval!( + vi, + vectorize( + dist, + Bijectors.link(dist, reconstruct(dist, getval(vi, vn))), + ), + vn, + ) + settrans!(vi, true, vn) + end + else + @warn("[DynamicPPL] attempt to link a linked vi") end - else - @warn("[DynamicPPL] attempt to link a linked vi") - end - end) + end, + ) end end return expr @@ -765,7 +809,11 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) for vn in vns @debug "ℝ -> X for $(vn)..." dist = getdist(vi, vn) - setval!(vi, vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn) + setval!( + vi, + vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), + vn, + ) settrans!(vi, false, vn) end else @@ -779,30 +827,43 @@ function invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _invlink!(vi.metadata, vi, vns, spaceval) end -@generated function _invlink!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} +@generated function _invlink!( + metadata::NamedTuple{names}, vi, vns, ::Val{space} +) where {names,space} expr = Expr(:block) for f in names if inspace(f, space) || length(space) == 0 - push!(expr.args, quote - f_vns = vi.metadata.$f.vns - if istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - @debug "ℝ -> X for $(vn)..." - dist = getdist(vi, vn) - setval!(vi, vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn) - settrans!(vi, false, vn) + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + if istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + @debug "ℝ -> X for $(vn)..." + dist = getdist(vi, vn) + setval!( + vi, + vectorize( + dist, + Bijectors.invlink( + dist, reconstruct(dist, getval(vi, vn)) + ), + ), + vn, + ) + settrans!(vi, false, vn) + end + else + @warn("[DynamicPPL] attempt to invlink an invlinked vi") end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") - end - end) + end, + ) end end return expr end - """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) @@ -813,11 +874,11 @@ Turing's Hamiltonian samplers use the `link` and `invlink` functions from (for example, one bounded to the space `[0, 1]`) from its constrained space to the set of real numbers. `islinked` checks if the number is in the constrained space or the real space. """ -function islinked(vi::UntypedVarInfo, spl::Union{Sampler, SampleFromPrior}) +function islinked(vi::UntypedVarInfo, spl::Union{Sampler,SampleFromPrior}) vns = _getvns(vi, spl) return istrans(vi, vns[1]) end -function islinked(vi::TypedVarInfo, spl::Union{Sampler, SampleFromPrior}) +function islinked(vi::TypedVarInfo, spl::Union{Sampler,SampleFromPrior}) vns = _getvns(vi, spl) return _islinked(vi, vns) end @@ -844,16 +905,20 @@ If the value(s) is (are) transformed to the Euclidean space, it is function getindex(vi::AbstractVarInfo, vn::VarName) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" dist = getdist(vi, vn) - return istrans(vi, vn) ? - Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn))) : + return if istrans(vi, vn) + Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn))) + else reconstruct(dist, getval(vi, vn)) + end end function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" dist = getdist(vi, vns[1]) - return istrans(vi, vns[1]) ? - Bijectors.invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) : + return if istrans(vi, vns[1]) + Bijectors.invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) + else reconstruct(dist, getval(vi, vns), length(vns)) + end end """ @@ -915,7 +980,7 @@ end start = :($offset + 1) len = :(length($f_range)) finish = :($offset + $len) - push!(expr.args, :(@views $f_vals[$f_range] .= val[$start:$finish])) + push!(expr.args, :(@views $f_vals[$f_range] .= val[($start):($finish)])) offset = :($offset + $len) end return expr @@ -941,7 +1006,10 @@ end length(names) === 0 && return :(NamedTuple()) expr = Expr(:tuple) map(names) do f - push!(expr.args, Expr(:(=), f, :(getindex.(Ref(vi), metadata.$f.vns), string.(metadata.$f.vns)))) + push!( + expr.args, + Expr(:(=), f, :(getindex.(Ref(vi), metadata.$f.vns), string.(metadata.$f.vns))), + ) end return expr end @@ -953,8 +1021,8 @@ end return map(vn -> vi[vn], f_vns) end -function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler, SampleFromPrior}) - return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi), typeof(spl)})) +function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) + return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)})) end """ @@ -984,17 +1052,16 @@ function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) | flags : $(vi.metadata.flags) \\======================================================================= """ - print(io, vi_str) + return print(io, vi_str) end - const _MAX_VARS_SHOWN = 4 function _show_varnames(io::IO, vi) md = vi.metadata vns = md.vns - vns_by_name = Dict{Symbol, Vector{VarName}}() + vns_by_name = Dict{Symbol,Vector{VarName}}() for vn in vns group = get!(() -> Vector{VarName}(), vns_by_name, getsym(vn)) push!(group, vn) @@ -1014,11 +1081,10 @@ end function Base.show(io::IO, vi::UntypedVarInfo) print(io, "VarInfo (") _show_varnames(io, vi) - print(io, "; logp: ", round(getlogp(vi), digits=3)) - print(io, ")") + print(io, "; logp: ", round(getlogp(vi); digits=3)) + return print(io, ")") end - """ push!(vi::VarInfo, vn::VarName, r, dist::Distribution) @@ -1040,7 +1106,9 @@ The sampler is passed here to invalidate its cache where defined. function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) return push!(vi, vn, r, dist, spl.selector) end -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) +function push!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler +) return push!(vi, vn, r, dist) end @@ -1053,14 +1121,7 @@ selector `gid` from a distribution `dist` to `VarInfo` `vi`. function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) return push!(vi, vn, r, dist, Set([gid])) end -function push!( - vi::VarInfo, - vn::VarName, - r, - dist::Distribution, - gidset::Set{Selector} - ) - +function push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo @@ -1072,8 +1133,9 @@ function push!( meta = getmetadata(vi, vn) meta.idcs[vn] = length(meta.idcs) + 1 push!(meta.vns, vn) - l = length(meta.vals); n = length(val) - push!(meta.ranges, l+1:l+n) + l = length(meta.vals) + n = length(val) + push!(meta.ranges, (l + 1):(l + n)) append!(meta.vals, val) push!(meta.dists, dist) push!(meta.gids, gidset) @@ -1129,8 +1191,8 @@ function set_retained_vns_del_by_spl!(vi::UntypedVarInfo, spl::Sampler) # Get the indices of `vns` that belong to `spl` as a vector gidcs = _getidcs(vi, spl) if get_num_produce(vi) == 0 - for i = length(gidcs):-1:1 - vi.metadata.flags["del"][gidcs[i]] = true + for i in length(gidcs):-1:1 + vi.metadata.flags["del"][gidcs[i]] = true end else for i in 1:length(vi.orders) @@ -1146,26 +1208,31 @@ function set_retained_vns_del_by_spl!(vi::TypedVarInfo, spl::Sampler) gidcs = _getidcs(vi, spl) return _set_retained_vns_del_by_spl!(vi.metadata, gidcs, get_num_produce(vi)) end -@generated function _set_retained_vns_del_by_spl!(metadata, gidcs::NamedTuple{names}, num_produce) where {names} +@generated function _set_retained_vns_del_by_spl!( + metadata, gidcs::NamedTuple{names}, num_produce +) where {names} expr = Expr(:block) for f in names f_gidcs = :(gidcs.$f) f_orders = :(metadata.$f.orders) f_flags = :(metadata.$f.flags) - push!(expr.args, quote - # Set the flag for variables with symbol `f` - if num_produce == 0 - for i = length($f_gidcs):-1:1 - $f_flags["del"][$f_gidcs[i]] = true - end - else - for i in 1:length($f_orders) - if i in $f_gidcs && $f_orders[i] > num_produce - $f_flags["del"][i] = true + push!( + expr.args, + quote + # Set the flag for variables with symbol `f` + if num_produce == 0 + for i in length($f_gidcs):-1:1 + $f_flags["del"][$f_gidcs[i]] = true + end + else + for i in 1:length($f_orders) + if i in $f_gidcs && $f_orders[i] > num_produce + $f_flags["del"][i] = true + end end end - end - end) + end, + ) end return expr end @@ -1210,15 +1277,12 @@ function _apply!(kernel!, vi::AbstractVarInfo, values, keys) return vi end -_apply!(kernel!, vi::TypedVarInfo, values, keys) = _typed_apply!( - kernel!, vi, vi.metadata, values, collectmaybe(keys)) +function _apply!(kernel!, vi::TypedVarInfo, values, keys) + return _typed_apply!(kernel!, vi, vi.metadata, values, collectmaybe(keys)) +end @generated function _typed_apply!( - kernel!, - vi::TypedVarInfo, - metadata::NamedTuple{names}, - values, - keys + kernel!, vi::TypedVarInfo, metadata::NamedTuple{names}, values, keys ) where {names} updates = map(names) do n quote @@ -1316,8 +1380,12 @@ julia> var_info[@varname(x[1])] # [✓] unchanged ``` """ setval!(vi::AbstractVarInfo, x) = _apply!(_setval_kernel!, vi, values(x), keys(x)) -function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) - return _apply!(_setval_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains)) +function setval!( + vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int +) + return _apply!( + _setval_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains) + ) end function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) @@ -1389,9 +1457,18 @@ julia> var_info[@varname(x[1])] # [✓] changed ## See also - [`setval!`](@ref) """ -setval_and_resample!(vi::AbstractVarInfo, x) = _apply!(_setval_and_resample_kernel!, vi, values(x), keys(x)) -function setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) - return _apply!(_setval_and_resample_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains)) +function setval_and_resample!(vi::AbstractVarInfo, x) + return _apply!(_setval_and_resample_kernel!, vi, values(x), keys(x)) +end +function setval_and_resample!( + vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int +) + return _apply!( + _setval_and_resample_kernel!, + vi, + chains.value[sample_idx, :, chain_idx], + keys(chains), + ) end function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) diff --git a/src/varname.jl b/src/varname.jl index ca5823a8a..bb936a4ce 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -22,11 +22,10 @@ Statically check whether the variable of name `varname` is an argument of the `m Possibly existing indices of `varname` are neglected. """ -@generated function inargnames(::VarName{s}, ::Model{_F, argnames}) where {s, argnames, _F} +@generated function inargnames(::VarName{s}, ::Model{_F,argnames}) where {s,argnames,_F} return s in argnames end - """ inmissings(varname::VarName, model::Model) @@ -35,6 +34,8 @@ of the `model`. Possibly existing indices of `varname` are neglected. """ -@generated function inmissings(::VarName{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T} +@generated function inmissings( + ::VarName{s}, ::Model{_F,_a,_T,missings} +) where {s,missings,_F,_a,_T} return s in missings end diff --git a/test/compat/ad.jl b/test/compat/ad.jl index 784be8983..1d8b02c55 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -6,14 +6,15 @@ m = x[2] dist = Normal(m, sqrt(s)) - return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) + - logpdf(dist, 1.5) + logpdf(dist, 2.0) + return logpdf(InverseGamma(2, 3), s) + + logpdf(Normal(0, sqrt(s)), m) + + logpdf(dist, 1.5) + logpdf(dist, 2.0) end test_model_ad(gdemo_default, logp_gdemo_default) @model function wishart_ad() - v ~ Wishart(7, [1 0.5; 0.5 1]) + return v ~ Wishart(7, [1 0.5; 0.5 1]) end # Hand-written log probabilities for `x = [v]`. @@ -28,7 +29,9 @@ # https://github.com/TuringLang/Turing.jl/issues/1595 @testset "dot_observe" begin function f_dot_observe(x) - return DynamicPPL.dot_observe(SampleFromPrior(), [Normal(), Normal(-1.0, 2.0)], x, VarInfo()) + return DynamicPPL.dot_observe( + SampleFromPrior(), [Normal(), Normal(-1.0, 2.0)], x, VarInfo() + ) end function f_dot_observe_manual(x) return logpdf(Normal(), x[1]) + logpdf(Normal(-1.0, 2.0), x[2]) diff --git a/test/compiler.jl b/test/compiler.jl index 244738c3d..78b472563 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -1,6 +1,5 @@ macro custom(expr) - (Meta.isexpr(expr, :call, 3) && expr.args[1] === :~) || - error("incorrect macro usage") + (Meta.isexpr(expr, :call, 3) && expr.args[1] === :~) || error("incorrect macro usage") quote $(esc(expr.args[2])) = 0.0 end @@ -32,92 +31,92 @@ end @testset "compiler.jl" begin @testset "model macro" begin @model function testmodel_comp(x, y) - s ~ InverseGamma(2,3) - m ~ Normal(0,sqrt(s)) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) - x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) + x ~ Normal(m, sqrt(s)) + y ~ Normal(m, sqrt(s)) return x, y end testmodel_comp(1.0, 1.2) # check if drawing from the prior works - @model function testmodel01(x = missing) - x ~ Normal() + @model function testmodel01(x=missing) + x ~ Normal() return x end f0_mm = testmodel01() - @test mean(f0_mm() for _ in 1:1000) ≈ 0. atol=0.1 + @test mean(f0_mm() for _ in 1:1000) ≈ 0.0 atol = 0.1 # Test #544 - @model function testmodel02(x = missing) + @model function testmodel02(x=missing) if x === missing x = Vector{Float64}(undef, 2) end - x[1] ~ Normal() - x[2] ~ Normal() + x[1] ~ Normal() + x[2] ~ Normal() return x end f0_mm = testmodel02() - @test all(x -> isapprox(x, 0; atol = 0.1), mean(f0_mm() for _ in 1:1000)) + @test all(x -> isapprox(x, 0; atol=0.1), mean(f0_mm() for _ in 1:1000)) - @model function testmodel03(x = missing) - x ~ Bernoulli(0.5) + @model function testmodel03(x=missing) + x ~ Bernoulli(0.5) return x end f01_mm = testmodel03() - @test mean(f01_mm() for _ in 1:1000) ≈ 0.5 atol=0.1 + @test mean(f01_mm() for _ in 1:1000) ≈ 0.5 atol = 0.1 # test if we get the correct return values @model function testmodel1(x1, x2) - s ~ InverseGamma(2,3) - m ~ Normal(0,sqrt(s)) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) - x1 ~ Normal(m, sqrt(s)) - x2 ~ Normal(m, sqrt(s)) + x1 ~ Normal(m, sqrt(s)) + x2 ~ Normal(m, sqrt(s)) return x1, x2 end - f1_mm = testmodel1(1., 10.) + f1_mm = testmodel1(1.0, 10.0) @test f1_mm() == (1, 10) # alternatives with keyword arguments testmodel1kw(; x1, x2) = testmodel1(x1, x2) - f1_mm = testmodel1kw(x1 = 1., x2 = 10.) + f1_mm = testmodel1kw(; x1=1.0, x2=10.0) @test f1_mm() == (1, 10) @model function testmodel2(; x1, x2) - s ~ InverseGamma(2,3) - m ~ Normal(0,sqrt(s)) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) - x1 ~ Normal(m, sqrt(s)) - x2 ~ Normal(m, sqrt(s)) + x1 ~ Normal(m, sqrt(s)) + x2 ~ Normal(m, sqrt(s)) return x1, x2 end - f1_mm = testmodel2(x1=1., x2=10.) + f1_mm = testmodel2(; x1=1.0, x2=10.0) @test f1_mm() == (1, 10) @info "Testing the compiler's ability to catch bad models..." # Test for assertions in observe statements. @model function brokentestmodel_observe1(x1, x2) - s ~ InverseGamma(2,3) - m ~ Normal(0,sqrt(s)) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) - x1 ~ Normal(m, sqrt(s)) - x2 ~ x1 + 2 + x1 ~ Normal(m, sqrt(s)) + x2 ~ x1 + 2 return x1, x2 end - btest = brokentestmodel_observe1(1., 2.) + btest = brokentestmodel_observe1(1.0, 2.0) @test_throws ArgumentError btest() @model function brokentestmodel_observe2(x) - s ~ InverseGamma(2,3) - m ~ Normal(0,sqrt(s)) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) x = Vector{Float64}(undef, 2) x ~ [Normal(m, sqrt(s)), 2.0] @@ -125,16 +124,16 @@ end return x end - btest = brokentestmodel_observe2([1., 2.]) + btest = brokentestmodel_observe2([1.0, 2.0]) @test_throws ArgumentError btest() # Test for assertions in assume statements. @model function brokentestmodel_assume1() - s ~ InverseGamma(2,3) - m ~ Normal(0,sqrt(s)) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) - x1 ~ Normal(m, sqrt(s)) - x2 ~ x1 + 2 + x1 ~ Normal(m, sqrt(s)) + x2 ~ x1 + 2 return x1, x2 end @@ -143,8 +142,8 @@ end @test_throws ArgumentError btest() @model function brokentestmodel_assume2() - s ~ InverseGamma(2,3) - m ~ Normal(0,sqrt(s)) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) x = Vector{Float64}(undef, 2) x ~ [Normal(m, sqrt(s)), 2.0] @@ -212,7 +211,7 @@ end return m end model = testmodel_missing5(rand(10)) - @test all(z -> isapprox(z, 0; atol = 0.2), mean(model() for _ in 1:1000)) + @test all(z -> isapprox(z, 0; atol=0.2), mean(model() for _ in 1:1000)) # test Turing#1464 @model function gdemo(x) @@ -235,7 +234,7 @@ end @testset "nested model" begin function makemodel(p) @model function testmodel(x) - x[1] ~ Bernoulli(p) + x[1] ~ Bernoulli(p) global lp = getlogp(__varinfo__) return x end @@ -247,7 +246,7 @@ end end @testset "user-defined variable name" begin @model f1() = x ~ NamedDist(Normal(), :y) - @model f2() = x ~ NamedDist(Normal(), @varname(y[2][:,1])) + @model f2() = x ~ NamedDist(Normal(), @varname(y[2][:, 1])) @model f3() = x ~ NamedDist(Normal(), @varname(y[1])) vi1 = VarInfo(f1()) vi2 = VarInfo(f2()) @@ -271,24 +270,24 @@ end "This is a test" @model function demo(x) m ~ Normal() - x ~ Normal(m, 1) + return x ~ Normal(m, 1) end s = @doc(demo) @test string(s) == "This is a test\n" # Verify that adding docstring didn't completely break execution of model - m = demo(0.) + m = demo(0.0) @test m() isa Float64 end @testset "type annotations" begin @model function demo_without(x) - x ~ Normal() + return x ~ Normal() end @test isempty(VarInfo(demo_without(0.0))) @model function demo_with(x::Real) - x ~ Normal() + return x ~ Normal() end @test isempty(VarInfo(demo_with(0.0))) end @@ -309,7 +308,7 @@ end # expanded => errors since `MyModelStruct` is not a distribution, # and hence `tilde_observe` errors. @model function demo2() - $(@mymodel2(y ~ Uniform())) + return $(@mymodel2(y ~ Uniform())) end @test demo2()() == 42 end @@ -317,76 +316,76 @@ end @testset "submodel" begin # No prefix, 1 level. @model function demo1(x) - x ~ Normal() - end; + return x ~ Normal() + end @model function demo2(x, y) @submodel demo1(x) - y ~ Uniform() - end; + return y ~ Uniform() + end # No observation. - m = demo2(missing, missing); - vi = VarInfo(m); + m = demo2(missing, missing) + vi = VarInfo(m) ks = keys(vi) @test VarName(:x) ∈ ks @test VarName(:y) ∈ ks # Observation in top-level. - m = demo2(missing, 1.0); - vi = VarInfo(m); + m = demo2(missing, 1.0) + vi = VarInfo(m) ks = keys(vi) @test VarName(:x) ∈ ks @test VarName(:y) ∉ ks # Observation in nested model. - m = demo2(1000.0, missing); - vi = VarInfo(m); + m = demo2(1000.0, missing) + vi = VarInfo(m) ks = keys(vi) @test VarName(:x) ∉ ks @test VarName(:y) ∈ ks # Observe all. - m = demo2(1000.0, 0.5); - vi = VarInfo(m); + m = demo2(1000.0, 0.5) + vi = VarInfo(m) ks = keys(vi) @test isempty(ks) # Check values makes sense. @model function demo2(x, y) @submodel demo1(x) - y ~ Normal(x) - end; - m = demo2(1000.0, missing); + return y ~ Normal(x) + end + m = demo2(1000.0, missing) # Mean of `y` should be close to 1000. - @test abs(mean([VarInfo(m)[VarName(:y)] for i = 1:10]) - 1000) ≤ 10; + @test abs(mean([VarInfo(m)[VarName(:y)] for i in 1:10]) - 1000) ≤ 10 # Prefixed submodels and usage of submodel return values. @model function demo_return(x) x ~ Normal() return x - end; + end @model function demo_useval(x, y) x1 = @submodel sub1 demo_return(x) x2 = @submodel sub2 demo_return(y) - z ~ Normal(x1 + x2 + 100, 1.0) - end; + return z ~ Normal(x1 + x2 + 100, 1.0) + end m = demo_useval(missing, missing) - vi = VarInfo(m); + vi = VarInfo(m) ks = keys(vi) @test VarName(Symbol("sub1.x")) ∈ ks @test VarName(Symbol("sub2.x")) ∈ ks @test VarName(:z) ∈ ks - @test abs(mean([VarInfo(m)[VarName(:z)] for i = 1:10]) - 100) ≤ 10 + @test abs(mean([VarInfo(m)[VarName(:z)] for i in 1:10]) - 100) ≤ 10 # AR1 model. Dynamic prefixing. - @model function AR1(num_steps, α, μ, σ, ::Type{TV} = Vector{Float64}) where {TV} + @model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV} η ~ MvNormal(num_steps, 1.0) δ = sqrt(1 - α^2) x = TV(undef, num_steps) x[1] = η[1] - @inbounds for t = 2:num_steps + @inbounds for t in 2:num_steps x[t] = @. α * x[t - 1] + δ * η[t] end @@ -400,15 +399,15 @@ end num_steps = length(y[1]) num_obs = length(y) - @inbounds for i = 1:num_obs + @inbounds for i in 1:num_obs x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ) y[i] ~ MvNormal(x, 0.1) end - end; + end - ys = [randn(10), randn(10)]; - m = demo(ys); - vi = VarInfo(m); + ys = [randn(10), randn(10)] + m = demo(ys) + vi = VarInfo(m) for k in [:α, :μ, :σ, Symbol("ar1_1.η"), Symbol("ar1_2.η")] @test VarName(k) ∈ keys(vi) diff --git a/test/context_implementations.jl b/test/context_implementations.jl index 54843b243..dfdc5dd9e 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -31,7 +31,7 @@ ntuple(i -> randn(ysize[1:i]), length(ysize))..., # singleton dimensions ntuple( - i -> randn(ysize[1:(i-1)]..., 1, ysize[(i+1):end]...), + i -> randn(ysize[1:(i - 1)]..., 1, ysize[(i + 1):end]...), length(ysize), )..., ) @@ -48,7 +48,7 @@ @testset "observe" begin @model function test(x, y) - y .~ Normal.(x) + return y .~ Normal.(x) end for ysize in ((2,), (2, 3), (2, 3, 4)) @@ -59,7 +59,7 @@ ntuple(i -> randn(ysize[1:i]), length(ysize))..., # singleton dimensions ntuple( - i -> randn(ysize[1:(i-1)]..., 1, ysize[(i+1):end]...), + i -> randn(ysize[1:(i - 1)]..., 1, ysize[(i + 1):end]...), length(ysize), )..., ) diff --git a/test/independence.jl b/test/independence.jl index 91d8de2c5..a4a834a61 100644 --- a/test/independence.jl +++ b/test/independence.jl @@ -2,10 +2,10 @@ @model coinflip(y) = begin p ~ Beta(1, 1) N = length(y) - for i = 1:N + for i in 1:N y[i] ~ Bernoulli(p) end end - model = coinflip([1,1,0]) + model = coinflip([1, 1, 0]) model(SampleFromPrior(), LikelihoodContext()) end diff --git a/test/model.jl b/test/model.jl index 8309e4059..2cdeae5fa 100644 --- a/test/model.jl +++ b/test/model.jl @@ -55,7 +55,7 @@ @testset "nameof" begin @model function test1(x) m ~ Normal(0, 1) - x ~ Normal(m, 1) + return x ~ Normal(m, 1) end @model test2(x) = begin m ~ Normal(0, 1) diff --git a/test/prob_macro.jl b/test/prob_macro.jl index 4c990e625..774b6e4c4 100644 --- a/test/prob_macro.jl +++ b/test/prob_macro.jl @@ -2,7 +2,7 @@ @testset "scalar" begin @model function demo(x) m ~ Normal() - x ~ Normal(m, 1) + return x ~ Normal(m, 1) end mval = 3 @@ -29,9 +29,9 @@ end @testset "vector" begin n = 5 - @model function demo(x, n = n) + @model function demo(x, n=n) m ~ MvNormal(n, 1.0) - x ~ MvNormal(m, 1.0) + return x ~ MvNormal(m, 1.0) end mval = rand(n) xval = rand(n) diff --git a/test/runtests.jl b/test/runtests.jl index 2dbbdb47f..809fe8100 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,10 +52,7 @@ include("test_util.jl") @testset "doctests" begin DocMeta.setdocmeta!( - DynamicPPL, - :DocTestSetup, - :(using DynamicPPL); - recursive=true, + DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true ) doctest(DynamicPPL; manual=false) end @@ -65,7 +62,7 @@ include("test_util.jl") @testset "turing" begin # activate separate test environment Pkg.activate(DIRECTORY_Turing_tests) - Pkg.develop(PackageSpec(path=DIRECTORY_DynamicPPL)) + Pkg.develop(PackageSpec(; path=DIRECTORY_DynamicPPL)) Pkg.instantiate() # make sure that the new environment is considered `using` and `import` statements diff --git a/test/sampler.jl b/test/sampler.jl index 4959bf845..b32ba29d6 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -4,13 +4,13 @@ s ~ InverseGamma(2, 3) m ~ Normal(2.0, sqrt(s)) x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) + return y ~ Normal(m, sqrt(s)) end model = gdemo(1.0, 2.0) N = 1_000 - chains = sample(model, SampleFromPrior(), N; progress = false) + chains = sample(model, SampleFromPrior(), N; progress=false) @test chains isa Vector{<:VarInfo} @test length(chains) == N @@ -20,7 +20,7 @@ # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.1 - chains = sample(model, SampleFromUniform(), N; progress = false) + chains = sample(model, SampleFromUniform(), N; progress=false) @test chains isa Vector{<:VarInfo} @test length(chains) == N @@ -41,7 +41,7 @@ ::Sampler{<:OnlyInitAlg}, vi::AbstractVarInfo; kwargs..., - ) + ) return vi, nothing end DynamicPPL.getspace(::Sampler{<:OnlyInitAlg}) = () @@ -54,19 +54,18 @@ # model with one variable: initialization p = 0.2 @model function coinflip() p ~ Beta(1, 1) - 10 ~ Binomial(25, p) + return 10 ~ Binomial(25, p) end model = coinflip() sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) - chain = sample(model, sampler, 1; init_params = 0.2, progress = false) + chain = sample(model, sampler, 1; init_params=0.2, progress=false) @test chain[1].metadata.p.vals == [0.2] @test getlogp(chain[1]) == lptrue # parallel sampling chains = sample( - model, sampler, MCMCThreads(), 1, 10; - init_params = 0.2, progress = false, + model, sampler, MCMCThreads(), 1, 10; init_params=0.2, progress=false ) for c in chains @test c[1].metadata.p.vals == [0.2] @@ -76,19 +75,18 @@ # model with two variables: initialization s = 4, m = -1 @model function twovars() s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) + return m ~ Normal(0, sqrt(s)) end model = twovars() lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - chain = sample(model, sampler, 1; init_params = [4, -1], progress = false) + chain = sample(model, sampler, 1; init_params=[4, -1], progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] @test getlogp(chain[1]) == lptrue # parallel sampling chains = sample( - model, sampler, MCMCThreads(), 1, 10; - init_params = [4, -1], progress = false, + model, sampler, MCMCThreads(), 1, 10; init_params=[4, -1], progress=false ) for c in chains @test c[1].metadata.s.vals == [4] @@ -97,14 +95,19 @@ end # set only m = -1 - chain = sample(model, sampler, 1; init_params = [missing, -1], progress = false) + chain = sample(model, sampler, 1; init_params=[missing, -1], progress=false) @test !ismissing(chain[1].metadata.s.vals[1]) @test chain[1].metadata.m.vals == [-1] # parallel sampling chains = sample( - model, sampler, MCMCThreads(), 1, 10; - init_params = [missing, -1], progress = false, + model, + sampler, + MCMCThreads(), + 1, + 10; + init_params=[missing, -1], + progress=false, ) for c in chains @test !ismissing(c[1].metadata.s.vals[1]) diff --git a/test/serialization.jl b/test/serialization.jl index bf5adf01f..2f2bf2a2b 100644 --- a/test/serialization.jl +++ b/test/serialization.jl @@ -10,8 +10,8 @@ samples_s = first.(samples) samples_m = last.(samples) - @test mean(samples_s) ≈ 3 atol=0.1 - @test mean(samples_m) ≈ 0 atol=0.1 + @test mean(samples_s) ≈ 3 atol = 0.1 + @test mean(samples_m) ≈ 0 atol = 0.1 end @testset "pmap" begin # Add worker processes. @@ -26,7 +26,7 @@ # Define model on all proceses. @everywhere @model function model() - m ~ Normal(0, 1) + return m ~ Normal(0, 1) end # Generate `Model` objects on all processes. @@ -43,8 +43,8 @@ for samples in (samples1, samples2) @test samples isa Vector{Float64} @test length(samples) == n - @test mean(samples) ≈ 0 atol=0.15 - @test std(samples) ≈ 1 atol=0.1 + @test mean(samples) ≈ 0 atol = 0.15 + @test std(samples) ≈ 1 atol = 0.1 end # Remove processes diff --git a/test/test_util.jl b/test/test_util.jl index 651f63ff8..f5e4dc40f 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -46,7 +46,6 @@ function test_model_ad(model, logp_manual) @test back(1)[1] ≈ grad end - """ test_setval!(model, chain; sample_idx = 1, chain_idx = 1) @@ -55,7 +54,7 @@ Test `setval!` on `model` and `chain`. Worth noting that this only supports models containing symbols of the forms `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. """ -function test_setval!(model, chain; sample_idx = 1, chain_idx = 1) +function test_setval!(model, chain; sample_idx=1, chain_idx=1) var_info = VarInfo(model) spl = SampleFromPrior() θ_old = var_info[spl] diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 629046a05..746d6a5f8 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -42,7 +42,7 @@ global vi_ = __varinfo__ x[1] ~ Normal(0, 1) Threads.@threads for i in 2:length(x) - x[i] ~ Normal(x[i-1], 1) + x[i] ~ Normal(x[i - 1], 1) end end @@ -60,20 +60,22 @@ @time wthreads(x)(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - DynamicPPL.evaluate_threadsafe(Random.GLOBAL_RNG, wthreads(x), vi, - SampleFromPrior(), DefaultContext()) + DynamicPPL.evaluate_threadsafe( + Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe:") - @time DynamicPPL.evaluate_threadsafe(Random.GLOBAL_RNG, wthreads(x), vi, - SampleFromPrior(), DefaultContext()) + @time DynamicPPL.evaluate_threadsafe( + Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + ) @model function wothreads(x) global vi_ = __varinfo__ x[1] ~ Normal(0, 1) for i in 2:length(x) - x[i] ~ Normal(x[i-1], 1) + x[i] ~ Normal(x[i - 1], 1) end end @@ -93,13 +95,15 @@ @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. - DynamicPPL.evaluate_threadunsafe(Random.GLOBAL_RNG, wothreads(x), vi, - SampleFromPrior(), DefaultContext()) + DynamicPPL.evaluate_threadunsafe( + Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe:") - @time DynamicPPL.evaluate_threadunsafe(Random.GLOBAL_RNG, wothreads(x), vi, - SampleFromPrior(), DefaultContext()) + @time DynamicPPL.evaluate_threadunsafe( + Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + ) end end diff --git a/test/turing/compiler.jl b/test/turing/compiler.jl index 753627790..f12e50633 100644 --- a/test/turing/compiler.jl +++ b/test/turing/compiler.jl @@ -3,7 +3,7 @@ @model function test_assume() x ~ Bernoulli(1) y ~ Bernoulli(x / 2) - x, y + return x, y end smc = SMC() @@ -12,26 +12,26 @@ res1 = sample(test_assume(), smc, 1000) res2 = sample(test_assume(), pg, 1000) - check_numerical(res1, [:y], [0.5], atol=0.1) - check_numerical(res2, [:y], [0.5], atol=0.1) + check_numerical(res1, [:y], [0.5]; atol=0.1) + check_numerical(res2, [:y], [0.5]; atol=0.1) # Check that all xs are 1. @test all(isone, res1[:x]) @test all(isone, res2[:x]) end @testset "beta binomial" begin - prior = Beta(2,2) - obs = [0,1,0,1,1,1,1,1,1,1] + prior = Beta(2, 2) + obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] exact = Beta(prior.α + sum(obs), prior.β + length(obs) - sum(obs)) meanp = exact.α / (exact.α + exact.β) @model function testbb(obs) - p ~ Beta(2,2) + p ~ Beta(2, 2) x ~ Bernoulli(p) for i in 1:length(obs) obs[i] ~ Bernoulli(p) end - p, x + return p, x end smc = SMC() @@ -42,25 +42,25 @@ chn_p = sample(testbb(obs), pg, 2000) chn_g = sample(testbb(obs), gibbs, 1500) - check_numerical(chn_s, [:p], [meanp], atol=0.05) - check_numerical(chn_p, [:x], [meanp], atol=0.1) - check_numerical(chn_g, [:x], [meanp], atol=0.1) + check_numerical(chn_s, [:p], [meanp]; atol=0.05) + check_numerical(chn_p, [:x], [meanp]; atol=0.1) + check_numerical(chn_g, [:x], [meanp]; atol=0.1) end @testset "forbid global" begin xs = [1.5 2.0] # xx = 1 @model function fggibbstest(xs) - s ~ InverseGamma(2,3) - m ~ Normal(0,sqrt(s)) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) # xx ~ Normal(m, sqrt(s)) # this is illegal - for i = 1:length(xs) + for i in 1:length(xs) xs[i] ~ Normal(m, sqrt(s)) # for xx in xs # xx ~ Normal(m, sqrt(s)) end - s, m + return s, m end gibbs = Gibbs(PG(10, :s), HMC(0.4, 8, :m)) @@ -71,53 +71,51 @@ @model function gauss(x) priors = TArray{Float64}(2) - priors[1] ~ InverseGamma(2,3) # s + priors[1] ~ InverseGamma(2, 3) # s priors[2] ~ Normal(0, sqrt(priors[1])) # m for i in 1:length(x) x[i] ~ Normal(priors[2], sqrt(priors[1])) end - priors + return priors end chain = sample(gauss(x), PG(10), 10) chain = sample(gauss(x), SMC(), 10) - @model function gauss2(::Type{TV} = Vector{Float64}; x) where {TV} + @model function gauss2(::Type{TV}=Vector{Float64}; x) where {TV} priors = TV(undef, 2) - priors[1] ~ InverseGamma(2,3) # s + priors[1] ~ InverseGamma(2, 3) # s priors[2] ~ Normal(0, sqrt(priors[1])) # m for i in 1:length(x) x[i] ~ Normal(priors[2], sqrt(priors[1])) end - priors + return priors end - chain = sample(gauss2(x = x), PG(10), 10) - chain = sample(gauss2(x = x), SMC(), 10) + chain = sample(gauss2(; x=x), PG(10), 10) + chain = sample(gauss2(; x=x), SMC(), 10) - chain = sample(gauss2(Vector{Float64}; x = x), PG(10), 10) - chain = sample(gauss2(Vector{Float64}; x = x), SMC(), 10) + chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10) + chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10) end @testset "new interface" begin obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] @model function newinterface(obs) - p ~ Beta(2,2) - for i = 1:length(obs) + p ~ Beta(2, 2) + for i in 1:length(obs) obs[i] ~ Bernoulli(p) end - p + return p end chain = sample( - newinterface(obs), - HMC{Turing.ForwardDiffAD{2}}(0.75, 3, :p, :x), - 100, + newinterface(obs), HMC{Turing.ForwardDiffAD{2}}(0.75, 3, :p, :x), 100 ) end @testset "no return" begin @model function noreturn(x) - s ~ InverseGamma(2,3) + s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) for i in 1:length(x) x[i] ~ Normal(m, sqrt(s)) @@ -125,20 +123,20 @@ end chain = sample(noreturn([1.5 2.0]), HMC(0.15, 6), 1000) - check_numerical(chain, [:s, :m], [49/24, 7/6]) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]) end @testset "observe" begin @model function test() - z ~ Normal(0,1) - x ~ Bernoulli(1) - 1 ~ Bernoulli(x / 2) - 0 ~ Bernoulli(x / 2) - x + z ~ Normal(0, 1) + x ~ Bernoulli(1) + 1 ~ Bernoulli(x / 2) + 0 ~ Bernoulli(x / 2) + return x end - is = IS() + is = IS() smc = SMC() - pg = PG(10) + pg = PG(10) res_is = sample(test(), is, 10000) res_smc = sample(test(), smc, 1000) @@ -154,11 +152,11 @@ end @testset "sample" begin alg = Gibbs(HMC(0.2, 3, :m), PG(10, :s)) - chn = sample(gdemo_default, alg, 1000); + chn = sample(gdemo_default, alg, 1000) end @testset "vectorization @." begin @model function vdemo1(x) - s ~ InverseGamma(2,3) + s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) @. x ~ Normal(m, sqrt(s)) return s, m @@ -169,7 +167,7 @@ res = sample(vdemo1(x), alg, 250) @model function vdemo1b(x) - s ~ InverseGamma(2,3) + s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) @. x ~ Normal(m, $(sqrt(s))) return s, m @@ -193,7 +191,7 @@ @model function vdemo3() x = Vector{Real}(undef, N) - for i = 1:N + for i in 1:N x[i] ~ Normal(0, sqrt(4)) end end @@ -202,8 +200,8 @@ # Test for vectorize UnivariateDistribution @model function vdemo4() - x = Vector{Real}(undef, N) - @. x ~ Normal(0, 2) + x = Vector{Real}(undef, N) + @. x ~ Normal(0, 2) end t_vec = @elapsed res = sample(vdemo4(), alg, 1000) @@ -235,7 +233,7 @@ end @testset "vectorization .~" begin @model function vdemo1(x) - s ~ InverseGamma(2,3) + s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) x .~ Normal(m, sqrt(s)) return s, m @@ -248,11 +246,11 @@ D = 2 @model function vdemo2(x) μ ~ MvNormal(zeros(D), ones(D)) - x .~ MvNormal(μ, ones(D)) + return x .~ MvNormal(μ, ones(D)) end alg = HMC(0.01, 5) - res = sample(vdemo2(randn(D,100)), alg, 250) + res = sample(vdemo2(randn(D, 100)), alg, 250) # Vector assumptions N = 10 @@ -261,7 +259,7 @@ @model function vdemo3() x = Vector{Real}(undef, N) - for i = 1:N + for i in 1:N x[i] ~ Normal(0, sqrt(4)) end end @@ -271,7 +269,7 @@ # Test for vectorize UnivariateDistribution @model function vdemo4() x = Vector{Real}(undef, N) - x .~ Normal(0, 2) + return x .~ Normal(0, 2) end t_vec = @elapsed res = sample(vdemo4(), alg, 1000) @@ -288,14 +286,14 @@ # Transformed test @model function vdemo6() x = Vector{Real}(undef, N) - x .~ InverseGamma(2, 3) + return x .~ InverseGamma(2, 3) end sample(vdemo6(), alg, 1000) @model function vdemo7() x = Array{Real}(undef, N, N) - x .~ [InverseGamma(2, 3) for i in 1:N] + return x .~ [InverseGamma(2, 3) for i in 1:N] end sample(vdemo7(), alg, 1000) @@ -307,7 +305,7 @@ x = randn(1000) @model function vdemo1(::Type{T}=Float64) where {T} x = Vector{T}(undef, N) - for i = 1:N + for i in 1:N x[i] ~ Normal(0, sqrt(4)) end end @@ -316,9 +314,9 @@ t_loop = @elapsed res = sample(vdemo1(Float64), alg, 250) vdemo1kw(; T) = vdemo1(T) - t_loop = @elapsed res = sample(vdemo1kw(T = Float64), alg, 250) + t_loop = @elapsed res = sample(vdemo1kw(; T=Float64), alg, 250) - @model function vdemo2(::Type{T}=Float64) where {T <: Real} + @model function vdemo2(::Type{T}=Float64) where {T<:Real} x = Vector{T}(undef, N) @. x ~ Normal(0, 2) end @@ -327,9 +325,9 @@ t_vec = @elapsed res = sample(vdemo2(Float64), alg, 250) vdemo2kw(; T) = vdemo2(T) - t_vec = @elapsed res = sample(vdemo2kw(T = Float64), alg, 250) + t_vec = @elapsed res = sample(vdemo2kw(; T=Float64), alg, 250) - @model function vdemo3(::Type{TV}=Vector{Float64}) where {TV <: AbstractVector} + @model function vdemo3(::Type{TV}=Vector{Float64}) where {TV<:AbstractVector} x = TV(undef, N) @. x ~ InverseGamma(2, 3) end @@ -338,6 +336,6 @@ sample(vdemo3(Vector{Float64}), alg, 250) vdemo3kw(; T) = vdemo3(T) - sample(vdemo3kw(T = Vector{Float64}), alg, 250) + sample(vdemo3kw(; T=Vector{Float64}), alg, 250) end -end \ No newline at end of file +end diff --git a/test/turing/loglikelihoods.jl b/test/turing/loglikelihoods.jl index dcb30f6f7..2fb991089 100644 --- a/test/turing/loglikelihoods.jl +++ b/test/turing/loglikelihoods.jl @@ -6,7 +6,7 @@ xs[i] ~ Normal(m, √s) end - y ~ Normal(m, √s) + return y ~ Normal(m, √s) end xs = randn(3) diff --git a/test/turing/model.jl b/test/turing/model.jl index ee488b64c..c41b2a5be 100644 --- a/test/turing/model.jl +++ b/test/turing/model.jl @@ -1,28 +1,28 @@ @testset "model.jl" begin @testset "setval! & generated_quantities" begin - @model function demo1(xs, ::Type{TV} = Vector{Float64}) where {TV} + @model function demo1(xs, ::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, 2) for i in 1:2 m[i] ~ Normal(0, 1) end for i in eachindex(xs) - xs[i] ~ Normal(m[1], 1.) + xs[i] ~ Normal(m[1], 1.0) end - return (m, ) + return (m,) end @model function demo2(xs) - m ~ MvNormal(2, 1.) - + m ~ MvNormal(2, 1.0) + for i in eachindex(xs) - xs[i] ~ Normal(m[1], 1.) + xs[i] ~ Normal(m[1], 1.0) end - return (m, ) + return (m,) end - + xs = randn(3) model1 = demo1(xs) model2 = demo2(xs) @@ -41,47 +41,47 @@ @test all(res11 .== res21) @test all(res12 .== res22) # Ensure that they're not all the same (some can be, because rejected samples) - @test any(res12[1:end - 1] .!= res12[2:end]) + @test any(res12[1:(end - 1)] .!= res12[2:end]) test_setval!(model1, chain1) test_setval!(model2, chain2) # Next level - @model function demo3(xs, ::Type{TV} = Vector{Float64}) where {TV} + @model function demo3(xs, ::Type{TV}=Vector{Float64}) where {TV} m = Vector{TV}(undef, 2) - for i = 1:length(m) - m[i] ~ MvNormal(2, 1.) + for i in 1:length(m) + m[i] ~ MvNormal(2, 1.0) end - + for i in eachindex(xs) - xs[i] ~ Normal(m[1][1], 1.) + xs[i] ~ Normal(m[1][1], 1.0) end - return (m, ) + return (m,) end - @model function demo4(xs, ::Type{TV} = Vector{Vector{Float64}}) where {TV} + @model function demo4(xs, ::Type{TV}=Vector{Vector{Float64}}) where {TV} m = TV(undef, 2) - for i = 1:length(m) - m[i] ~ MvNormal(2, 1.) + for i in 1:length(m) + m[i] ~ MvNormal(2, 1.0) end - + for i in eachindex(xs) - xs[i] ~ Normal(m[1][1], 1.) + xs[i] ~ Normal(m[1][1], 1.0) end - return (m, ) + return (m,) end model3 = demo3(xs) model4 = demo4(xs) - + chain3 = sample(model3, MH(), 100) chain4 = sample(model4, MH(), 100) - + res33 = generated_quantities(model3, chain3) res43 = generated_quantities(model4, chain3) - + res34 = generated_quantities(model3, chain4) res44 = generated_quantities(model4, chain4) @@ -90,6 +90,6 @@ @test all(res33 .== res43) @test all(res34 .== res44) # Ensure that they're not all the same (some can be, because rejected samples) - @test any(res34[1:end - 1] .!= res34[2:end]) + @test any(res34[1:(end - 1)] .!= res34[2:end]) end -end \ No newline at end of file +end diff --git a/test/turing/prob_macro.jl b/test/turing/prob_macro.jl index c1aa5ba8d..0eb2a1290 100644 --- a/test/turing/prob_macro.jl +++ b/test/turing/prob_macro.jl @@ -2,7 +2,7 @@ @testset "scalar" begin @model function demo(x) m ~ Normal() - x ~ Normal(m, 1) + return x ~ Normal(m, 1) end mval = 3 @@ -11,7 +11,7 @@ model = demo(xval) varinfo = VarInfo(model) - chain = sample(model, IS(), iters; save_state = true) + chain = sample(model, IS(), iters; save_state=true) chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple()) lps = logpdf.(Normal.(chain["m"], 1), xval) @test logprob"x = xval | chain = chain" == lps @@ -30,9 +30,9 @@ end @testset "vector" begin n = 5 - @model function demo(x, n = n) + @model function demo(x, n=n) m ~ MvNormal(n, 1.0) - x ~ MvNormal(m, 1.0) + return x ~ MvNormal(m, 1.0) end mval = rand(n) xval = rand(n) @@ -40,13 +40,13 @@ model = demo(xval) varinfo = VarInfo(model) - chain = sample(model, HMC(0.5, 1), iters; save_state = true) + chain = sample(model, HMC(0.5, 1), iters; save_state=true) chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple()) names = namesingroup(chain, "m") lps = [ - logpdf(MvNormal(chain.value[i, names, j], 1.0), xval) - for i in 1:size(chain, 1), j in 1:size(chain, 3) + logpdf(MvNormal(chain.value[i, names, j], 1.0), xval) for i in 1:size(chain, 1), + j in 1:size(chain, 3) ] @test logprob"x = xval | chain = chain" == lps @test logprob"x = xval | chain = chain2, model = model" == lps @@ -67,7 +67,7 @@ σ ~ truncated(Cauchy(0, 1), 0, Inf) α ~ filldist(Normal(0, 10), n_groups) μ = α[group] - y ~ MvNormal(μ, σ) + return y ~ MvNormal(μ, σ) end y = randn(100) diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index 3f932a5b9..b72832b78 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -1,13 +1,11 @@ @testset "varinfo.jl" begin # Declare empty model to make the Sampler constructor work. - @model empty_model() = begin x = 1; end + @model empty_model() = begin + x = 1 + end function randr( - vi::VarInfo, - vn::VarName, - dist::Distribution, - spl::Sampler, - count::Bool = false, + vi::VarInfo, vn::VarName, dist::Distribution, spl::Sampler, count::Bool=false ) if !haskey(vi, vn) r = rand(dist) @@ -30,7 +28,7 @@ # Test linking spl and vi: # link!, invlink!, istrans @model gdemo(x, y) = begin - s ~ InverseGamma(2,3) + s ~ InverseGamma(2, 3) m ~ Uniform(0, 2) x ~ Normal(m, sqrt(s)) y ~ Normal(m, sqrt(s)) @@ -69,10 +67,10 @@ @test meta.m.vals == v_m # Transforming only a subset of the variables - link!(vi, spl, Val((:m, ))) + link!(vi, spl, Val((:m,))) @test all(x -> !istrans(vi, x), meta.s.vns) @test all(x -> istrans(vi, x), meta.m.vns) - invlink!(vi, spl, Val((:m, ))) + invlink!(vi, spl, Val((:m,))) @test all(x -> !istrans(vi, x), meta.s.vns) @test all(x -> !istrans(vi, x), meta.m.vns) @test meta.s.vals == v_s @@ -170,14 +168,16 @@ xs = rand(Normal(0.5, 1), 100) # Define model - @model priorsinarray(xs, ::Type{T}=Float64) where {T} = begin - priors = Vector{T}(undef, 2) - priors[1] ~ InverseGamma(2, 3) - priors[2] ~ Normal(0, sqrt(priors[1])) - for i = 1:length(xs) - xs[i] ~ Normal(priors[2], sqrt(priors[1])) + @model function priorsinarray(xs, ::Type{T}=Float64) where {T} + begin + priors = Vector{T}(undef, 2) + priors[1] ~ InverseGamma(2, 3) + priors[2] ~ Normal(0, sqrt(priors[1])) + for i in 1:length(xs) + xs[i] ~ Normal(priors[2], sqrt(priors[1])) + end + priors end - priors end # Sampling @@ -198,24 +198,24 @@ @test v_arr.indexing == ((1,),) # Matrix - v_mat = @varname x[i,j] + v_mat = @varname x[i, j] @test v_mat.indexing == ((1, 2),) - v_mat = @varname x[i,j,k] - @test v_mat.indexing == ((1,2,3),) + v_mat = @varname x[i, j, k] + @test v_mat.indexing == ((1, 2, 3),) - v_mat = @varname x[1,2][1+5][45][3][i] - @test v_mat.indexing == ((1,2), (6,), (45,), (3,), (1,)) + v_mat = @varname x[1, 2][1 + 5][45][3][i] + @test v_mat.indexing == ((1, 2), (6,), (45,), (3,), (1,)) @model function mat_name_test() p = Array{Any}(undef, 2, 2) for i in 1:2, j in 1:2 - p[i,j] ~ Normal(0, 1) + p[i, j] ~ Normal(0, 1) end - p + return p end chain = sample(mat_name_test(), HMC(0.2, 4), 1000) - check_numerical(chain, ["p[1,1]"], [0], atol = 0.25) + check_numerical(chain, ["p[1,1]"], [0]; atol=0.25) # Multi array v_arrarr = @varname x[i][j] @@ -228,11 +228,11 @@ for i in 1:2, j in 1:2 p[i][j] ~ Normal(0, 1) end - p + return p end chain = sample(marr_name_test(), HMC(0.2, 4), 1000) - check_numerical(chain, ["p[1][1]"], [0], atol = 0.25) + check_numerical(chain, ["p[1][1]"], [0]; atol=0.25) end @testset "varinfo" begin dists = [Normal(0, 1), MvNormal([0; 0], [1.0 0; 0 1.0]), Wishart(7, [1 0.5; 0.5 1])] @@ -255,7 +255,7 @@ vns = [vn_x, vn_y, vn_z] spl1 = Sampler(PG(5, :x, :y, :z), empty_model()) - for i = 1:3 + for i in 1:3 r = randr(vi, vns[i], dists[i], spl1, false) val = vi[vns[i]] @test sum(val - r) <= 1e-9 @@ -293,17 +293,16 @@ test_varinfo!(empty!(TypedVarInfo(vi))) @model igtest() = begin - x ~ InverseGamma(2,3) - y ~ InverseGamma(2,3) - z ~ InverseGamma(2,3) - w ~ InverseGamma(2,3) - u ~ InverseGamma(2,3) + x ~ InverseGamma(2, 3) + y ~ InverseGamma(2, 3) + z ~ InverseGamma(2, 3) + w ~ InverseGamma(2, 3) + u ~ InverseGamma(2, 3) end # Test the update of group IDs g_demo_f = igtest() - # This test section no longer seems as applicable, considering the # user will never end up using an UntypedVarInfo. The `VarInfo` # Varible is also not passed around in the same way as it used to be. diff --git a/test/utils.jl b/test/utils.jl index 5243e5a79..37f1aaa86 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -3,7 +3,7 @@ @model function testmodel() global lp_before = getlogp(__varinfo__) @addlogprob!(42) - global lp_after = getlogp(__varinfo__) + return global lp_after = getlogp(__varinfo__) end model = testmodel() diff --git a/test/varinfo.jl b/test/varinfo.jl index bee6e781e..c936ad67c 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,7 +1,7 @@ @testset "varinfo.jl" begin @testset "TypedVarInfo" begin @model gdemo(x, y) = begin - s ~ InverseGamma(2,3) + s ~ InverseGamma(2, 3) m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0) x ~ Normal(m, sqrt(s)) y ~ Normal(m, sqrt(s)) @@ -68,12 +68,12 @@ @test vi[vn] == r @test vi[SampleFromPrior()][1] == r - vi[vn] = [2*r] - @test vi[vn] == 2*r - @test vi[SampleFromPrior()][1] == 2*r - vi[SampleFromPrior()] = [3*r] - @test vi[vn] == 3*r - @test vi[SampleFromPrior()][1] == 3*r + vi[vn] = [2 * r] + @test vi[vn] == 2 * r + @test vi[SampleFromPrior()][1] == 2 * r + vi[SampleFromPrior()] = [3 * r] + @test vi[vn] == 3 * r + @test vi[SampleFromPrior()][1] == 3 * r empty!(vi) @test isempty(vi) @@ -88,13 +88,13 @@ @test inspace(@varname(z[1][1]), space) @test inspace(@varname(z[1][:]), space) @test inspace(@varname(z[1][2:3:10]), space) - @test inspace(@varname(M[[2,3], 1]), space) + @test inspace(@varname(M[[2, 3], 1]), space) @test inspace(@varname(M[:, 1:4]), space) @test inspace(@varname(M[1, [2, 4, 6]]), space) @test !inspace(@varname(z[2]), space) @test !inspace(@varname(z), space) end - test_inspace() + return test_inspace() end vi = VarInfo() test_base!(vi) @@ -150,10 +150,10 @@ n = length(x) s ~ truncated(Normal(), 0, Inf) m ~ MvNormal(n, 1.0) - x ~ MvNormal(m, s) + return x ~ MvNormal(m, s) end - @model function testmodel_univariate(x, ::Type{TV} = Vector{Float64}) where {TV} + @model function testmodel_univariate(x, ::Type{TV}=Vector{Float64}) where {TV} n = length(x) s ~ truncated(Normal(), 0, Inf) @@ -172,7 +172,7 @@ model_uv = testmodel_univariate(x) for model in [model_uv, model_mv] - m_vns = model == model_uv ? [@varname(m[i]) for i = 1:5] : @varname(m) + m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) s_vns = @varname(s) vi_typed = VarInfo(model) @@ -183,7 +183,7 @@ vicopy = deepcopy(vi) ### `setval` ### - DynamicPPL.setval!(vicopy, (m = zeros(5),)) + DynamicPPL.setval!(vicopy, (m=zeros(5),)) # Setting `m` fails for univariate due to limitations of `setval!` # and `setval_and_resample!`. See docstring of `setval!` for more info. if model == model_uv @@ -193,11 +193,13 @@ end @test vicopy[s_vns] == vi[s_vns] - DynamicPPL.setval!(vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...)) + DynamicPPL.setval!( + vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) + ) @test vicopy[m_vns] == 1:5 @test vicopy[s_vns] == vi[s_vns] - DynamicPPL.setval!(vicopy, (s = 42,)) + DynamicPPL.setval!(vicopy, (s=42,)) @test vicopy[m_vns] == 1:5 @test vicopy[s_vns] == 42 @@ -210,7 +212,7 @@ end vicopy = deepcopy(vi) - DynamicPPL.setval_and_resample!(vicopy, (m = zeros(5),)) + DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) model(vicopy) # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` if model == model_uv @@ -221,14 +223,13 @@ @test vicopy[s_vns] != vi[s_vns] DynamicPPL.setval_and_resample!( - vicopy, - (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) + vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) ) model(vicopy) @test vicopy[m_vns] == 1:5 @test vicopy[s_vns] != vi[s_vns] - DynamicPPL.setval_and_resample!(vicopy, (s = 42,)) + DynamicPPL.setval_and_resample!(vicopy, (s=42,)) model(vicopy) @test vicopy[m_vns] != 1:5 @test vicopy[s_vns] == 42