Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pymc3/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,12 @@ def dict_to_dataset(
"""
if default_dims is None:
return _dict_to_dataset(
data, library=library, coords=coords, dims=dims, skip_event_dims=skip_event_dims
data,
attrs=attrs,
library=library,
coords=coords,
dims=dims,
skip_event_dims=skip_event_dims,
)
else:
out_data = {}
Expand All @@ -129,7 +134,7 @@ def dict_to_dataset(
val_dims, coords = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
out_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=out_data, attrs=make_attrs(library=library))
return xr.Dataset(data_vars=out_data, attrs=make_attrs(attrs=attrs, library=library))


class InferenceDataConverter: # pylint: disable=too-many-instance-attributes
Expand Down
90 changes: 79 additions & 11 deletions pymc3/smc/sample_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@

import numpy as np

from arviz import InferenceData

import pymc3

from pymc3.backends.arviz import dict_to_dataset, to_inference_data
from pymc3.backends.base import MultiTrace
from pymc3.model import modelcontext
from pymc3.parallel_sampling import _cpu_count
Expand All @@ -31,6 +36,7 @@ def sample_smc(
draws=2000,
kernel="metropolis",
n_steps=25,
*,
start=None,
tune_steps=True,
p_acc_rate=0.85,
Expand All @@ -42,6 +48,9 @@ def sample_smc(
parallel=False,
chains=None,
cores=None,
compute_convergence_checks=True,
return_inferencedata=True,
idata_kwargs=None,
):
r"""
Sequential Monte Carlo based sampling.
Expand Down Expand Up @@ -91,7 +100,14 @@ def sample_smc(
The number of chains to sample. Running independent chains is important for some
convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
is larger.

compute_convergence_checks : bool
Whether to compute sampler statistics like Gelman-Rubin and ``effective_n``.
Defaults to ``True``.
return_inferencedata : bool, default=True
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
Defaults to ``True``.
idata_kwargs : dict, optional
Keyword arguments for :func:`pymc3.to_inference_data`
Notes
-----
SMC works by moving through successive stages. At each stage the inverse temperature
Expand Down Expand Up @@ -213,20 +229,72 @@ def sample_smc(
accept_ratios,
nsteps,
) = zip(*results)

trace = MultiTrace(traces)
trace.report._n_draws = draws
trace.report._n_tune = 0
trace.report.log_marginal_likelihood = np.array(log_marginal_likelihoods)
trace.report.log_pseudolikelihood = log_pseudolikelihood
trace.report.betas = betas
trace.report.accept_ratios = accept_ratios
trace.report.nsteps = nsteps
trace.report._t_sampling = time.time() - t1
idata = None

# Save sample_stats
_n_tune = 0
_t_sampling = time.time() - t1
if not return_inferencedata:
trace.report._n_draws = draws
trace.report._n_tune = _n_tune
trace.report.log_marginal_likelihood = log_marginal_likelihoods
trace.report.log_pseudolikelihood = log_pseudolikelihood
trace.report.betas = betas
trace.report.accept_ratios = accept_ratios
trace.report.nsteps = nsteps
trace.report._t_sampling = _t_sampling
else:
# There is only one log_marginal_likelihood per chain, here we broadcast
# it to the number of draws in each chain (to avoid InferenceData
# warning) and fill the non-final draws with nans
_log_marginal_likelihoods = []
for chain in range(chains):
row = np.full(len(np.atleast_1d(betas)[chain]), np.nan)
row[-1] = np.atleast_1d(log_marginal_likelihoods)[chain]
_log_marginal_likelihoods.append(row)

# Different chains might have more iteration steps, leading to a
# non-square `sample_stats` dataset, we cast as `object` to avoid
# numpy ragged array deprecation warning
sample_stats = dict_to_dataset(
dict(
accept_ratios=np.array(accept_ratios, dtype=object),
betas=np.array(betas, dtype=object),
log_marginal_likelihoods=np.array(_log_marginal_likelihoods, dtype=object),
nsteps=np.array(nsteps, dtype=object),
),
attrs=dict(
_n_tune=_n_tune,
_t_sampling=_t_sampling,
),
library=pymc3,
)

ikwargs = dict(model=model)
if idata_kwargs is not None:
ikwargs.update(idata_kwargs)
idata = to_inference_data(trace, **ikwargs)
idata = InferenceData(**idata, sample_stats=sample_stats)

if compute_convergence_checks:
if draws < 100:
warnings.warn(
"The number of samples is too small to check convergence reliably.",
stacklevel=2,
)
else:
if idata is None:
idata = to_inference_data(trace, log_likelihood=False)
trace.report._run_convergence_checks(idata, model)
trace.report._log_summary()

posterior = idata if return_inferencedata else trace
if save_sim_data:
return trace, {modelcontext(model).observed_RVs[0].name: np.array(sim_data)}
return posterior, {modelcontext(model).observed_RVs[0].name: np.array(sim_data)}
else:
return trace
return posterior


def sample_smc_int(
Expand Down
44 changes: 39 additions & 5 deletions pymc3/tests/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
import numpy as np
import pytest

from arviz.data.inference_data import InferenceData

import pymc3 as pm

from pymc3.backends.base import MultiTrace
from pymc3.tests.helpers import SeededTest


Expand Down Expand Up @@ -59,7 +62,7 @@ def two_gaussians(x):

def test_sample(self):
with self.SMC_test:
mtrace = pm.sample_smc(draws=self.samples)
mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False)

x = mtrace["X"]
mu1d = np.abs(x).mean(axis=0)
Expand All @@ -70,7 +73,7 @@ def test_discrete_continuous(self):
a = pm.Poisson("a", 5)
b = pm.HalfNormal("b", 10)
y = pm.Normal("y", a, b, observed=[1, 2, 3, 4])
trace = pm.sample_smc()
trace = pm.sample_smc(draws=10)

def test_ml(self):
data = np.repeat([1, 0], [50, 50])
Expand All @@ -82,7 +85,7 @@ def test_ml(self):
with pm.Model() as model:
a = pm.Beta("a", alpha, beta)
y = pm.Bernoulli("y", a, observed=data)
trace = pm.sample_smc(2000)
trace = pm.sample_smc(2000, return_inferencedata=False)
marginals.append(trace.report.log_marginal_likelihood)
# compare to the analytical result
assert abs(np.exp(np.mean(marginals[1]) - np.mean(marginals[0])) - 4.0) <= 1
Expand All @@ -96,15 +99,46 @@ def test_start(self):
"a": np.random.poisson(5, size=500),
"b_log__": np.abs(np.random.normal(0, 10, size=500)),
}
trace = pm.sample_smc(500, start=start)
trace = pm.sample_smc(500, chains=1, start=start)

def test_slowdown_warning(self):
with aesara.config.change_flags(floatX="float32"):
with pytest.warns(UserWarning, match="SMC sampling may run slower due to"):
with pm.Model() as model:
a = pm.Poisson("a", 5)
y = pm.Normal("y", a, 5, observed=[1, 2, 3, 4])
trace = pm.sample_smc()
trace = pm.sample_smc(draws=100, chains=2)

@pytest.mark.parametrize("chains", (1, 2))
def test_return_datatype(self, chains):
draws = 10

with pm.Model() as m:
x = pm.Normal("x", 0, 1)
y = pm.Normal("y", x, 1, observed=5)

idata = pm.sample_smc(chains=chains, draws=draws)
mt = pm.sample_smc(chains=chains, draws=draws, return_inferencedata=False)

assert isinstance(idata, InferenceData)
assert "sample_stats" in idata
assert len(idata.posterior.chain) == chains
assert len(idata.posterior.draw) == draws

assert isinstance(mt, MultiTrace)
assert mt.nchains == chains
assert mt["x"].size == chains * draws

def test_convergence_checks(self):
with pm.Model() as m:
x = pm.Normal("x", 0, 1)
y = pm.Normal("y", x, 1, observed=5)

with pytest.warns(
UserWarning,
match="The number of samples is too small",
):
pm.sample_smc(draws=99)


@pytest.mark.xfail(reason="SMC-ABC not refactored yet")
Expand Down