From 75a3cf9afa5db86c734396609c46e5a309f1a105 Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Mon, 28 Jun 2021 11:03:57 +0200
Subject: [PATCH 1/3] Return InferenceData and run convergence checks in
 sample_smc by default

---
 pymc3/smc/sample_smc.py | 29 ++++++++++++++++++++++++++---
 pymc3/tests/test_smc.py | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/pymc3/smc/sample_smc.py b/pymc3/smc/sample_smc.py
index 90b953b066..ce5d586f08 100644
--- a/pymc3/smc/sample_smc.py
+++ b/pymc3/smc/sample_smc.py
@@ -21,6 +21,7 @@

 import numpy as np

+from pymc3.backends.arviz import to_inference_data
 from pymc3.backends.base import MultiTrace
 from pymc3.model import modelcontext
 from pymc3.parallel_sampling import _cpu_count
@@ -42,6 +43,8 @@ def sample_smc(
     parallel=False,
     chains=None,
     cores=None,
+    compute_convergence_checks=True,
+    return_inferencedata=True,
 ):
     r"""
     Sequential Monte Carlo based sampling.
@@ -91,7 +94,12 @@
         The number of chains to sample. Running independent chains is important for some
         convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
         is larger.
-
+    compute_convergence_checks : bool
+        Whether to compute sampler statistics like Gelman-Rubin and ``effective_n``.
+        Defaults to ``True``.
+    return_inferencedata : bool, default=True
+        Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False).
+        Defaults to ``True``.
     Notes
     -----
     SMC works by moving through successive stages. At each stage the inverse temperature
@@ -223,10 +231,25 @@
     trace.report.nsteps = nsteps
     trace.report._t_sampling = time.time() - t1

+    if compute_convergence_checks or return_inferencedata:
+        ikwargs = dict(model=model)
+        idata = to_inference_data(trace, **ikwargs)
+
+    if compute_convergence_checks:
+        if draws < 100:
+            warnings.warn(
+                "The number of samples is too small to check convergence reliably.",
+                stacklevel=2,
+            )
+        else:
+            trace.report._run_convergence_checks(idata, model)
     trace.report._log_summary()

+    posterior = idata if return_inferencedata else trace
     if save_sim_data:
-        return trace, {modelcontext(model).observed_RVs[0].name: np.array(sim_data)}
+        return posterior, {modelcontext(model).observed_RVs[0].name: np.array(sim_data)}
     else:
-        return trace
+        return posterior


 def sample_smc_int(
diff --git a/pymc3/tests/test_smc.py b/pymc3/tests/test_smc.py
index f6f01e2e48..d3d303c2a6 100644
--- a/pymc3/tests/test_smc.py
+++ b/pymc3/tests/test_smc.py
@@ -17,8 +17,11 @@
 import numpy as np
 import pytest

+from arviz.data.inference_data import InferenceData
+
 import pymc3 as pm

+from pymc3.backends.base import MultiTrace
 from pymc3.tests.helpers import SeededTest


@@ -106,6 +109,36 @@ def test_slowdown_warning(self):
                 y = pm.Normal("y", a, 5, observed=[1, 2, 3, 4])
                 trace = pm.sample_smc()

+    def test_return_datatype(self):
+        chains = 2
+        draws = 10
+
+        with pm.Model() as m:
+            x = pm.Normal("x", 0, 1)
+            y = pm.Normal("y", x, 1, observed=5)
+
+            idata = pm.sample_smc(chains=chains, draws=draws)
+            mt = pm.sample_smc(chains=chains, draws=draws, return_inferencedata=False)
+
+        assert isinstance(idata, InferenceData)
+        assert len(idata.posterior.chain) == chains
+        assert len(idata.posterior.draw) == draws
+
+        assert isinstance(mt, MultiTrace)
+        assert mt.nchains == chains
+        assert mt["x"].size == chains * draws
+
+    def test_convergence_checks(self):
+        with pm.Model() as m:
+            x = pm.Normal("x", 0, 1)
+            y = pm.Normal("y", x, 1, observed=5)
+
+            with pytest.warns(
+                UserWarning,
+                match="The number of samples is too small",
+            ):
+                pm.sample_smc(draws=99)
+

 @pytest.mark.xfail(reason="SMC-ABC not refactored yet")
 class TestSMCABC(SeededTest):
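With the first patch applied, pm.sample_smc returns an arviz.InferenceData object by default and only produces the old MultiTrace on request. A minimal usage sketch of the new behaviour; the toy model and variable names below are illustrative, not taken from the patch:

import numpy as np
import pymc3 as pm

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu, 1, observed=np.random.randn(20))

    # New default: an InferenceData object whose posterior group holds
    # one (chain, draw) array per variable.
    idata = pm.sample_smc(draws=500, chains=2)
    mu_post = idata.posterior["mu"].values  # shape (2, 500)

    # Old behaviour, for code that still indexes the result by name.
    mtrace = pm.sample_smc(draws=500, chains=2, return_inferencedata=False)
    mu_flat = mtrace["mu"]  # flat array of length chains * draws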
From fdfe3392a7b06279bf9f228e70658319bd5fd205 Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Mon, 28 Jun 2021 14:49:38 +0200
Subject: [PATCH 2/3] Fix failing tests

---
 pymc3/tests/test_smc.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pymc3/tests/test_smc.py b/pymc3/tests/test_smc.py
index d3d303c2a6..8fc98dec3f 100644
--- a/pymc3/tests/test_smc.py
+++ b/pymc3/tests/test_smc.py
@@ -62,7 +62,7 @@ def two_gaussians(x):

     def test_sample(self):
         with self.SMC_test:
-            mtrace = pm.sample_smc(draws=self.samples)
+            mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False)

         x = mtrace["X"]
         mu1d = np.abs(x).mean(axis=0)
@@ -73,7 +73,7 @@ def test_discrete_continuous(self):
             a = pm.Poisson("a", 5)
             b = pm.HalfNormal("b", 10)
             y = pm.Normal("y", a, b, observed=[1, 2, 3, 4])
-            trace = pm.sample_smc()
+            trace = pm.sample_smc(draws=10)

     def test_ml(self):
         data = np.repeat([1, 0], [50, 50])
@@ -85,7 +85,7 @@ def test_ml(self):
             with pm.Model() as model:
                 a = pm.Beta("a", alpha, beta)
                 y = pm.Bernoulli("y", a, observed=data)
-                trace = pm.sample_smc(2000)
+                trace = pm.sample_smc(2000, return_inferencedata=False)
                 marginals.append(trace.report.log_marginal_likelihood)
         # compare to the analytical result
         assert abs(np.exp(np.mean(marginals[1]) - np.mean(marginals[0])) - 4.0) <= 1
@@ -99,14 +99,14 @@ def test_start(self):
                 "a": np.random.poisson(5, size=500),
                 "b_log__": np.abs(np.random.normal(0, 10, size=500)),
             }
-            trace = pm.sample_smc(500, start=start)
+            trace = pm.sample_smc(500, chains=1, start=start)

     def test_slowdown_warning(self):
         with aesara.config.change_flags(floatX="float32"):
             with pm.Model() as model:
                 a = pm.Poisson("a", 5)
                 y = pm.Normal("y", a, 5, observed=[1, 2, 3, 4])
-                trace = pm.sample_smc()
+                trace = pm.sample_smc(draws=100, chains=2)
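The fixes above mostly pass return_inferencedata=False (or shrink draws) where a test indexes the result by variable name or reads trace.report attributes, both of which only exist on a MultiTrace. A short sketch of the distinction, again with an illustrative toy model that is not part of the patch:

import pymc3 as pm

with pm.Model():
    a = pm.Normal("a", 0, 1)
    pm.Normal("obs", a, 1, observed=[1.0, 2.0, 3.0])

    # MultiTrace: name-based indexing and the SMC report are available,
    # e.g. the log marginal likelihood compared in test_ml.
    mtrace = pm.sample_smc(draws=200, return_inferencedata=False)
    lml = mtrace.report.log_marginal_likelihood
    a_flat = mtrace["a"]

    # InferenceData: samples are reached through the posterior xarray Dataset.
    idata = pm.sample_smc(draws=200)
    a_post = idata.posterior["a"]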
From a4537760d24a4b5413d7e106d4ce707b5697aa89 Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Mon, 28 Jun 2021 17:48:40 +0200
Subject: [PATCH 3/3] Add SMC sample_stats to InferenceData

---
 pymc3/backends/arviz.py |  9 ++++--
 pymc3/smc/sample_smc.py | 67 ++++++++++++++++++++++++++++++++++-------
 pymc3/tests/test_smc.py |  5 +--
 3 files changed, 66 insertions(+), 15 deletions(-)

diff --git a/pymc3/backends/arviz.py b/pymc3/backends/arviz.py
index 8a3f7b46cc..c144083ddf 100644
--- a/pymc3/backends/arviz.py
+++ b/pymc3/backends/arviz.py
@@ -119,7 +119,12 @@ def dict_to_dataset(
     """
     if default_dims is None:
         return _dict_to_dataset(
-            data, library=library, coords=coords, dims=dims, skip_event_dims=skip_event_dims
+            data,
+            attrs=attrs,
+            library=library,
+            coords=coords,
+            dims=dims,
+            skip_event_dims=skip_event_dims,
         )
     else:
         out_data = {}
@@ -129,7 +134,7 @@ def dict_to_dataset(
             val_dims, coords = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
             coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
             out_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
-    return xr.Dataset(data_vars=out_data, attrs=make_attrs(library=library))
+    return xr.Dataset(data_vars=out_data, attrs=make_attrs(attrs=attrs, library=library))


 class InferenceDataConverter:  # pylint: disable=too-many-instance-attributes
diff --git a/pymc3/smc/sample_smc.py b/pymc3/smc/sample_smc.py
index ce5d586f08..f5fed474a8 100644
--- a/pymc3/smc/sample_smc.py
+++ b/pymc3/smc/sample_smc.py
@@ -21,7 +21,11 @@

 import numpy as np

-from pymc3.backends.arviz import to_inference_data
+from arviz import InferenceData
+
+import pymc3
+
+from pymc3.backends.arviz import dict_to_dataset, to_inference_data
 from pymc3.backends.base import MultiTrace
 from pymc3.model import modelcontext
 from pymc3.parallel_sampling import _cpu_count
@@ -32,6 +36,7 @@ def sample_smc(
     draws=2000,
     kernel="metropolis",
     n_steps=25,
+    *,
     start=None,
     tune_steps=True,
     p_acc_rate=0.85,
@@ -45,6 +50,7 @@ def sample_smc(
     cores=None,
     compute_convergence_checks=True,
     return_inferencedata=True,
+    idata_kwargs=None,
 ):
     r"""
     Sequential Monte Carlo based sampling.
@@ -100,6 +106,8 @@ def sample_smc(
     return_inferencedata : bool, default=True
         Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False).
         Defaults to ``True``.
+    idata_kwargs : dict, optional
+        Keyword arguments for :func:`pymc3.to_inference_data`
     Notes
     -----
     SMC works by moving through successive stages. At each stage the inverse temperature
@@ -221,19 +229,54 @@
         accept_ratios,
         nsteps,
     ) = zip(*results)
+
     trace = MultiTrace(traces)
-    trace.report._n_draws = draws
-    trace.report._n_tune = 0
-    trace.report.log_marginal_likelihood = np.array(log_marginal_likelihoods)
-    trace.report.log_pseudolikelihood = log_pseudolikelihood
-    trace.report.betas = betas
-    trace.report.accept_ratios = accept_ratios
-    trace.report.nsteps = nsteps
-    trace.report._t_sampling = time.time() - t1
-
-    if compute_convergence_checks or return_inferencedata:
+    idata = None
+
+    # Save sample_stats
+    _n_tune = 0
+    _t_sampling = time.time() - t1
+    if not return_inferencedata:
+        trace.report._n_draws = draws
+        trace.report._n_tune = _n_tune
+        trace.report.log_marginal_likelihood = log_marginal_likelihoods
+        trace.report.log_pseudolikelihood = log_pseudolikelihood
+        trace.report.betas = betas
+        trace.report.accept_ratios = accept_ratios
+        trace.report.nsteps = nsteps
+        trace.report._t_sampling = _t_sampling
+    else:
+        # There is only one log_marginal_likelihood per chain, here we broadcast
+        # it to the number of draws in each chain (to avoid InferenceData
+        # warning) and fill the non-final draws with nans
+        _log_marginal_likelihoods = []
+        for chain in range(chains):
+            row = np.full(len(np.atleast_1d(betas)[chain]), np.nan)
+            row[-1] = np.atleast_1d(log_marginal_likelihoods)[chain]
+            _log_marginal_likelihoods.append(row)
+
+        # Different chains might have more iteration steps, leading to a
+        # non-square `sample_stats` dataset, we cast as `object` to avoid
+        # numpy ragged array deprecation warning
+        sample_stats = dict_to_dataset(
+            dict(
+                accept_ratios=np.array(accept_ratios, dtype=object),
+                betas=np.array(betas, dtype=object),
+                log_marginal_likelihoods=np.array(_log_marginal_likelihoods, dtype=object),
+                nsteps=np.array(nsteps, dtype=object),
+            ),
+            attrs=dict(
+                _n_tune=_n_tune,
+                _t_sampling=_t_sampling,
+            ),
+            library=pymc3,
+        )
+
         ikwargs = dict(model=model)
+        if idata_kwargs is not None:
+            ikwargs.update(idata_kwargs)
         idata = to_inference_data(trace, **ikwargs)
+        idata = InferenceData(**idata, sample_stats=sample_stats)

     if compute_convergence_checks:
         if draws < 100:
@@ -242,6 +285,8 @@
                 stacklevel=2,
             )
         else:
+            if idata is None:
+                idata = to_inference_data(trace, log_likelihood=False)
             trace.report._run_convergence_checks(idata, model)

     trace.report._log_summary()
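The broadcasting above can be read in isolation: each chain produces a single log marginal likelihood, which is padded with NaNs to the length of that chain's tempering schedule so it can sit next to the per-stage statistics in sample_stats. The numbers below are made up purely for illustration:

import numpy as np

# Two chains with 3 and 4 tempering stages, one log marginal likelihood each.
betas = [np.array([0.1, 0.6, 1.0]), np.array([0.2, 0.5, 0.9, 1.0])]
log_marginal_likelihoods = [-12.3, -11.8]

padded = []
for chain in range(2):
    row = np.full(len(betas[chain]), np.nan)   # one slot per stage
    row[-1] = log_marginal_likelihoods[chain]  # value only at the final stage
    padded.append(row)

# padded -> [array([ nan,  nan, -12.3]), array([ nan,  nan,  nan, -11.8])]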
diff --git a/pymc3/tests/test_smc.py b/pymc3/tests/test_smc.py
index 8fc98dec3f..baeb302b57 100644
--- a/pymc3/tests/test_smc.py
+++ b/pymc3/tests/test_smc.py
@@ -109,8 +109,8 @@ def test_slowdown_warning(self):
                 y = pm.Normal("y", a, 5, observed=[1, 2, 3, 4])
                 trace = pm.sample_smc(draws=100, chains=2)

-    def test_return_datatype(self):
-        chains = 2
+    @pytest.mark.parametrize("chains", (1, 2))
+    def test_return_datatype(self, chains):
         draws = 10

         with pm.Model() as m:
@@ -121,6 +121,7 @@ def test_return_datatype(self):
             mt = pm.sample_smc(chains=chains, draws=draws, return_inferencedata=False)

         assert isinstance(idata, InferenceData)
+        assert "sample_stats" in idata
         assert len(idata.posterior.chain) == chains
         assert len(idata.posterior.draw) == draws
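With the third patch, the per-chain SMC diagnostics end up in the sample_stats group of the returned InferenceData. A rough sketch of reading them back; the model is illustrative, while the variable names (betas, accept_ratios, nsteps, log_marginal_likelihoods) come from the patch and the access pattern is plain xarray:

import numpy as np
import pymc3 as pm

with pm.Model():
    x = pm.Normal("x", 0, 1)
    pm.Normal("y", x, 1, observed=5.0)
    idata = pm.sample_smc(draws=200, chains=2)

stats = idata.sample_stats
betas = stats["betas"].values           # tempering schedule per chain
accept = stats["accept_ratios"].values  # acceptance ratio per stage and chain
nsteps = stats["nsteps"].values         # Metropolis steps per stage and chain

# Only the final stage of each chain carries a log marginal likelihood;
# earlier entries are the NaN padding described above.
lml = [np.asarray(row)[-1] for row in stats["log_marginal_likelihoods"].values]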