diff --git a/env-dev.yml b/env-dev.yml index 014979c..375558b 100644 --- a/env-dev.yml +++ b/env-dev.yml @@ -3,7 +3,7 @@ channels: - conda-forge - defaults dependencies: - - pymc==5.24.0 + - pymc>=5.24.0 - numba - matplotlib - numpy diff --git a/env.yml b/env.yml index f5ebf01..3afdd9f 100644 --- a/env.yml +++ b/env.yml @@ -3,7 +3,7 @@ channels: - conda-forge - defaults dependencies: - - pymc==5.24.0 + - pymc>=5.24.0 - numba - matplotlib - numpy diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 92f0e21..87bd36a 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -37,6 +37,7 @@ get_idx_left_child, get_idx_right_child, ) +from pymc_bart.utils import _encode_vi class ParticleTree: @@ -118,7 +119,7 @@ class PGBART(ArrayStepShared): default_blocked = False generates_stats = True stats_dtypes_shapes: dict[str, tuple[type, list]] = { - "variable_inclusion": (object, []), + "variable_inclusion": (int, []), "tune": (bool, []), } @@ -335,6 +336,8 @@ def astep(self, _): if not self.tune: self.bart.all_trees.append(self.all_trees) + variable_inclusion = _encode_vi(variable_inclusion) + stats = {"variable_inclusion": variable_inclusion, "tune": self.tune} return self.sum_trees, [stats] diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index ab10467..78ce920 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt +import pymc as pm import pytensor.tensor as pt from arviz_base import rcParams from arviz_stats.base import array_stats @@ -674,22 +675,29 @@ def _smooth_mean( return x_data, y_data -def get_variable_inclusion(idata, X, labels=None, to_kulprit=False): +def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None, to_kulprit=False): """ Get the normalized variable inclusion from BART model. Parameters ---------- idata : InferenceData - InferenceData containing a collection of BART_trees in sample_stats group + InferenceData with a variable "variable_inclusion" in ``sample_stats`` group X : npt.NDArray The covariate matrix. + model : Optional[pm.Model] + The PyMC model that contains the BART variable. Only needed if the model contains multiple + BART variables. + bart_var_name : Optional[str] + The name of the BART variable in the model. Only needed if the model contains multiple + BART variables. labels : Optional[list[str]] List of the names of the covariates. If X is a DataFrame the names of the covariables will be taken from it and this argument will be ignored. to_kulprit : bool If True, the function will return a list of list with the variables names. This list can be passed as a path to Kulprit's project method. Defaults to False. + Returns ------- VI_norm : npt.NDArray @@ -697,7 +705,20 @@ def get_variable_inclusion(idata, X, labels=None, to_kulprit=False): labels : list[str] List of the names of the covariates. """ - VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values + n_vars = X.shape[1] + vi_xarray = idata["sample_stats"]["variable_inclusion"] + if "variable_inclusion_dim_0" in vi_xarray.coords: + if model is None or bart_var_name is None: + raise ValueError( + "The InfereceData was generated from a model with multiple BART variables, \n" + "please provide the model and also the name of the BART variable \n" + "for which you want to compute the variable inclusion." + ) + index = [var.name for var in model.free_RVs].index(bart_var_name) + vi_vals = vi_xarray.sel({"variable_inclusion_dim_0": index}).values.ravel() + else: + vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel() + VIs = np.array([_decode_vi(val, n_vars) for val in vi_vals]).sum(axis=0) VI_norm = VIs / VIs.sum() idxs = np.argsort(VI_norm) @@ -705,17 +726,15 @@ def get_variable_inclusion(idata, X, labels=None, to_kulprit=False): n_vars = len(indices) if hasattr(X, "columns") and hasattr(X, "to_numpy"): - labels = X.columns + labels = list(X.columns) if labels is None: - labels = np.arange(n_vars).astype(str) - - label_list = labels.to_list() + labels = [str(i) for i in range(n_vars)] if to_kulprit: - return [label_list[:idx] for idx in range(n_vars)] + return [labels[:idx] for idx in range(n_vars)] else: - return VI_norm[indices], label_list + return VI_norm[indices], labels def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None): @@ -781,10 +800,11 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 idata: Any, bartrv: Variable, X: npt.NDArray, + model: "pm.Model | None" = None, method: str = "VI", fixed: int = 0, samples: int = 50, - random_seed: Optional[int] = None, + random_seed: int | None = None, ) -> dict[str, object]: """ Estimates variable importance from the BART-posterior. @@ -792,11 +812,14 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 Parameters ---------- idata : InferenceData - InferenceData containing a collection of BART_trees in sample_stats group + InferenceData containing a "variable_inclusion" variable in the sample_stats group. bartrv : BART Random Variable BART variable once the model that include it has been fitted. X : npt.NDArray The covariate matrix. + model : Optional[pm.Model] + The PyMC model that contains the BART variable. Only needed if the model contains multiple + BART variables. method : str Method used to rank variables. Available options are "VI" (default), "backward" and "backward_VI". @@ -825,6 +848,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 rng = np.random.default_rng(random_seed) all_trees = bartrv.owner.op.all_trees + bart_var_name = bartrv.name if bartrv.ndim == 1: # type: ignore shape = 1 @@ -858,9 +882,20 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 ) if method in ["VI", "backward_VI"]: - idxs = np.argsort( - idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values - ) + vi_xarray = idata["sample_stats"]["variable_inclusion"] + if "variable_inclusion_dim_0" in vi_xarray.coords: + if model is None: + raise ValueError( + "The InfereceData was generated from a model with multiple BART variables, \n" + "please provide the model and also the name of the BART variable \n" + "for which you want to compute the variable inclusion." + ) + + index = [var.name for var in model.free_RVs].index(bart_var_name) + vi_vals = vi_xarray.sel({"variable_inclusion_dim_0": index}).values.ravel() + else: + vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel() + idxs = np.argsort(np.array([_decode_vi(val, n_vars) for val in vi_vals]).sum(axis=0)) subsets: list[list[int]] = [list(idxs[:-i]) for i in range(1, len(idxs))] subsets.append(None) # type: ignore @@ -1226,3 +1261,39 @@ def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax): ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], color=color, alpha=alpha) return ax + + +def _decode_vi(n: int, length: int) -> list[int]: + """ + Decode the variable inclusion from the BART model. + """ + bits = bin(n)[2:] + vi_list: list[int] = [] + i = 0 + while len(vi_list) < length: + # Count prefix ones + prefix_len = 0 + while bits[i] == "1": + prefix_len += 1 + i += 1 + i += 1 # skip the '0' + b = bits[i : i + prefix_len] + vi_list.append(int(b, 2)) + i += prefix_len + return vi_list + + +def _encode_vi(vec: npt.NDArray) -> int: + """ + Encode variable inclusion vector into a single integer. + + The encoding is done by converting each element of the vector into a binary string, + where each element contributes a prefix of '1's followed by a '0' and its binary representation. + The final result is the integer representation of the concatenated binary string. + """ + bits = "" + for x in vec: + b = bin(x)[2:] + prefix = "1" * len(b) + "0" + bits += prefix + b + return int(bits, 2) diff --git a/requirements.txt b/requirements.txt index 2a053a7..24d156b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc==5.24.0 +pymc>=5.24.0 arviz-stats[xarray]>=0.6.0 numba matplotlib diff --git a/tests/test_bart.py b/tests/test_bart.py index 8311c2a..f446cd4 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -1,11 +1,12 @@ import numpy as np import pymc as pm import pytest -from numpy.testing import assert_almost_equal, assert_array_equal +from numpy.testing import assert_almost_equal from pymc.initial_point import make_initial_point_fn from pymc.logprob.basic import transformed_conditional_logp import pymc_bart as pmb +from pymc_bart.utils import _decode_vi def assert_moment_is_expected(model, expected, check_finite_logp=True): @@ -52,14 +53,12 @@ def test_bart_vi(response): with pm.Model() as model: mu = pmb.BART("mu", X, Y, m=10, response=response) sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu, sigma, observed=Y) + pm.Normal("y", mu, sigma, observed=Y) idata = pm.sample(tune=200, draws=200, random_seed=3415) - var_imp = ( - idata.sample_stats["variable_inclusion"] - .stack(samples=("chain", "draw")) - .mean("samples") - ) - var_imp /= var_imp.sum() + vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel() + var_imp = np.array([_decode_vi(val, 3) for val in vi_vals]).sum(axis=0) + + var_imp = var_imp / var_imp.sum() assert var_imp[0] > var_imp[1:].sum() assert_almost_equal(var_imp.sum(), 1) @@ -123,92 +122,6 @@ def test_shape(response): assert idata.posterior.coords["w_dim_1"].data.size == 250 -class TestUtils: - X_norm = np.random.normal(0, 1, size=(50, 2)) - X_binom = np.random.binomial(1, 0.5, size=(50, 1)) - X = np.hstack([X_norm, X_binom]) - Y = np.random.normal(0, 1, size=50) - - with pm.Model() as model: - mu = pmb.BART("mu", X, Y, m=10) - sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(tune=200, draws=200, random_seed=3415) - - def test_sample_posterior(self): - all_trees = self.mu.owner.op.all_trees - rng = np.random.default_rng(3) - pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2) - rng = np.random.default_rng(3) - pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng) - - assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4) - assert pred_all.shape == (2, 50, 1) - assert pred_first.shape == (1, 10, 1) - - @pytest.mark.parametrize( - "kwargs", - [ - {}, - { - "samples": 2, - "var_discrete": [3], - }, - {"instances": 2}, - {"var_idx": [0], "smooth": False, "color": "k"}, - {"grid": (1, 2), "sharey": "none", "alpha": 1}, - {"var_discrete": [0]}, - ], - ) - def test_ice(self, kwargs): - pmb.plot_ice(self.mu, X=self.X, Y=self.Y, **kwargs) - - @pytest.mark.parametrize( - "kwargs", - [ - {}, - { - "samples": 2, - "xs_interval": "quantiles", - "xs_values": [0.25, 0.5, 0.75], - "var_discrete": [3], - }, - {"var_idx": [0], "smooth": False, "color": "k"}, - {"grid": (1, 2), "sharey": "none", "alpha": 1}, - {"var_discrete": [0]}, - ], - ) - def test_pdp(self, kwargs): - pmb.plot_pdp(self.mu, X=self.X, Y=self.Y, **kwargs) - - @pytest.mark.parametrize( - "kwargs", - [ - {"samples": 50}, - {"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)}, - ], - ) - def test_vi(self, kwargs): - samples = kwargs.pop("samples") - vi_results = pmb.compute_variable_importance( - self.idata, bartrv=self.mu, X=self.X, samples=samples - ) - pmb.plot_variable_importance(vi_results, **kwargs) - pmb.plot_scatter_submodels(vi_results, **kwargs) - - def test_pdp_pandas_labels(self): - pd = pytest.importorskip("pandas") - - X_names = ["norm1", "norm2", "binom"] - X_pd = pd.DataFrame(self.X, columns=X_names) - Y_pd = pd.Series(self.Y, name="response") - axes = pmb.plot_pdp(self.mu, X=X_pd, Y=Y_pd) - - figure = axes[0].figure - assert figure.texts[0].get_text() == "Partial response" - assert_array_equal([ax.get_xlabel() for ax in axes], X_names) - - @pytest.mark.parametrize( "size, expected", [ @@ -275,7 +188,7 @@ def test_multiple_bart_variables(): # Combined model sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu1 + mu2, sigma, observed=Y) + pm.Normal("y", mu1 + mu2, sigma, observed=Y) # Sample with automatic assignment of BART samplers idata = pm.sample(tune=50, draws=50, chains=1, random_seed=3415) @@ -291,6 +204,16 @@ def test_multiple_bart_variables(): assert idata.posterior["mu1"].shape == (1, 50, 50) assert idata.posterior["mu2"].shape == (1, 50, 50) + vi_results = pmb.compute_variable_importance(idata, mu1, X1, model=model) + assert vi_results["labels"].shape == (2,) + assert vi_results["preds"].shape == (2, 50, 50) + assert vi_results["preds_all"].shape == (50, 50) + + vi_tuple = pmb.get_variable_inclusion(idata, X1, model=model, bart_var_name="mu1") + assert vi_tuple[0].shape == (2,) + assert len(vi_tuple[1]) == 2 + assert isinstance(vi_tuple[1][0], str) + def test_multiple_bart_variables_manual_step(): """Test that multiple BART variables work with manually assigned PGBART samplers.""" diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..dbf3aca --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,107 @@ +import numpy as np +import pymc as pm +import pytest +from numpy.testing import assert_almost_equal, assert_array_equal + +import pymc_bart as pmb + + +class TestUtils: + X_norm = np.random.normal(0, 1, size=(50, 2)) + X_binom = np.random.binomial(1, 0.5, size=(50, 1)) + X = np.hstack([X_norm, X_binom]) + Y = np.random.normal(0, 1, size=50) + + with pm.Model() as model: + mu = pmb.BART("mu", X, Y, m=10) + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu, sigma, observed=Y) + idata = pm.sample(tune=200, draws=200, random_seed=3415) + + def test_sample_posterior(self): + all_trees = self.mu.owner.op.all_trees + rng = np.random.default_rng(3) + pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2) + rng = np.random.default_rng(3) + pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng) + + assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4) + assert pred_all.shape == (2, 50, 1) + assert pred_first.shape == (1, 10, 1) + + @pytest.mark.parametrize( + "kwargs", + [ + {}, + { + "samples": 2, + "var_discrete": [3], + }, + {"instances": 2}, + {"var_idx": [0], "smooth": False, "color": "k"}, + {"grid": (1, 2), "sharey": "none", "alpha": 1}, + {"var_discrete": [0]}, + ], + ) + def test_ice(self, kwargs): + pmb.plot_ice(self.mu, X=self.X, Y=self.Y, **kwargs) + + @pytest.mark.parametrize( + "kwargs", + [ + {}, + { + "samples": 2, + "xs_interval": "quantiles", + "xs_values": [0.25, 0.5, 0.75], + "var_discrete": [3], + }, + {"var_idx": [0], "smooth": False, "color": "k"}, + {"grid": (1, 2), "sharey": "none", "alpha": 1}, + {"var_discrete": [0]}, + ], + ) + def test_pdp(self, kwargs): + pmb.plot_pdp(self.mu, X=self.X, Y=self.Y, **kwargs) + + @pytest.mark.parametrize( + "kwargs", + [ + {"samples": 50}, + {"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)}, + ], + ) + def test_vi(self, kwargs): + samples = kwargs.pop("samples") + vi_results = pmb.compute_variable_importance( + self.idata, bartrv=self.mu, X=self.X, samples=samples + ) + pmb.plot_variable_importance(vi_results, **kwargs) + pmb.plot_scatter_submodels(vi_results, **kwargs) + + def test_pdp_pandas_labels(self): + pd = pytest.importorskip("pandas") + + X_names = ["norm1", "norm2", "binom"] + X_pd = pd.DataFrame(self.X, columns=X_names) + Y_pd = pd.Series(self.Y, name="response") + axes = pmb.plot_pdp(self.mu, X=X_pd, Y=Y_pd) + + figure = axes[0].figure + assert figure.texts[0].get_text() == "Partial response" + assert_array_equal([ax.get_xlabel() for ax in axes], X_names) + + +def test_encoder_decoder(): + """Test that the encoder-decoder works correctly.""" + test_cases = [ + np.zeros(3, dtype=int), + np.ones(10, dtype=int), + np.array([4, 0, 1, 0, 2, 0, 3, 0, 0, 0]), + np.array([100, 50, 0, 1]), + np.array([1, 2, 4, 8, 16]), + ] + for case in test_cases: + encoded = pmb.utils._encode_vi(case) + decoded = pmb.utils._decode_vi(encoded, len(case)) + assert np.array_equal(decoded, case)