encode vi and update to work with multiple RVs #235

Merged 2 commits on Jul 18, 2025
env-dev.yml: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@ channels:
- conda-forge
- defaults
dependencies:
- - pymc==5.24.0
+ - pymc>=5.24.0
- numba
- matplotlib
- numpy
env.yml: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@ channels:
- conda-forge
- defaults
dependencies:
- - pymc==5.24.0
+ - pymc>=5.24.0
- numba
- matplotlib
- numpy
pymc_bart/pgbart.py: 5 changes (4 additions, 1 deletion)
@@ -37,6 +37,7 @@
get_idx_left_child,
get_idx_right_child,
)
+ from pymc_bart.utils import _encode_vi


class ParticleTree:
@@ -118,7 +119,7 @@ class PGBART(ArrayStepShared):
default_blocked = False
generates_stats = True
stats_dtypes_shapes: dict[str, tuple[type, list]] = {
"variable_inclusion": (object, []),
"variable_inclusion": (int, []),
"tune": (bool, []),
}

@@ -335,6 +336,8 @@ def astep(self, _):
if not self.tune:
self.bart.all_trees.append(self.all_trees)

+ variable_inclusion = _encode_vi(variable_inclusion)

stats = {"variable_inclusion": variable_inclusion, "tune": self.tune}
return self.sum_trees, [stats]

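Editor's note on the change above: the per-draw "variable_inclusion" stat is now a single integer (the dtype entry changes from object to int), produced by packing the per-covariate inclusion counts with _encode_vi. A minimal sketch of recovering the counts after sampling, assuming a hypothetical fitted idata from a model with three covariates (it mirrors the updated test in tests/test_bart.py):

import numpy as np
from pymc_bart.utils import _decode_vi

vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel()  # one encoded int per draw
counts = np.array([_decode_vi(val, 3) for val in vi_vals]).sum(axis=0)  # per-covariate totals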
pymc_bart/utils.py: 99 changes (85 additions, 14 deletions)
@@ -7,6 +7,7 @@
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
+ import pymc as pm
import pytensor.tensor as pt
from arviz_base import rcParams
from arviz_stats.base import array_stats
@@ -674,48 +675,66 @@ def _smooth_mean(
return x_data, y_data


- def get_variable_inclusion(idata, X, labels=None, to_kulprit=False):
+ def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None, to_kulprit=False):
"""
Get the normalized variable inclusion from a BART model.

Parameters
----------
idata : InferenceData
- InferenceData containing a collection of BART_trees in sample_stats group
+ InferenceData with a "variable_inclusion" variable in the ``sample_stats`` group
X : npt.NDArray
The covariate matrix.
+ model : Optional[pm.Model]
+ The PyMC model that contains the BART variable. Only needed if the model contains multiple
+ BART variables.
+ bart_var_name : Optional[str]
+ The name of the BART variable in the model. Only needed if the model contains multiple
+ BART variables.
labels : Optional[list[str]]
List of the names of the covariates. If X is a DataFrame, the names of the covariates will
be taken from it and this argument will be ignored.
to_kulprit : bool
If True, the function will return a list of lists with the variable names.
This list can be passed as a path to Kulprit's project method. Defaults to False.

Returns
-------
VI_norm : npt.NDArray
Normalized variable inclusion.
labels : list[str]
List of the names of the covariates.
"""
VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
n_vars = X.shape[1]
vi_xarray = idata["sample_stats"]["variable_inclusion"]
if "variable_inclusion_dim_0" in vi_xarray.coords:
if model is None or bart_var_name is None:
raise ValueError(
"The InfereceData was generated from a model with multiple BART variables, \n"
"please provide the model and also the name of the BART variable \n"
"for which you want to compute the variable inclusion."
)
index = [var.name for var in model.free_RVs].index(bart_var_name)
vi_vals = vi_xarray.sel({"variable_inclusion_dim_0": index}).values.ravel()
else:
vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel()
VIs = np.array([_decode_vi(val, n_vars) for val in vi_vals]).sum(axis=0)
VI_norm = VIs / VIs.sum()
idxs = np.argsort(VI_norm)

indices = idxs[::-1]
- n_vars = len(indices)

if hasattr(X, "columns") and hasattr(X, "to_numpy"):
- labels = X.columns
+ labels = list(X.columns)

if labels is None:
- labels = np.arange(n_vars).astype(str)
+ labels = [str(i) for i in range(n_vars)]

- label_list = labels.to_list()

if to_kulprit:
- return [label_list[:idx] for idx in range(n_vars)]
+ return [labels[:idx] for idx in range(n_vars)]
else:
- return VI_norm[indices], label_list
+ return VI_norm[indices], labels
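
A minimal usage sketch of the updated API (editor's addition; the data, model, and sampling arguments here are hypothetical and kept small):

import numpy as np
import pymc as pm
import pymc_bart as pmb

X = np.random.normal(size=(100, 3))
Y = X[:, 0] + np.random.normal(scale=0.1, size=100)

with pm.Model() as single_bart_model:
    mu = pmb.BART("mu", X, Y, m=20)
    sigma = pm.HalfNormal("sigma", 1)
    pm.Normal("y", mu, sigma, observed=Y)
    idata = pm.sample(tune=100, draws=100)

# One BART variable: model and bart_var_name can be omitted.
vi_norm, labels = pmb.get_variable_inclusion(idata, X)

# Several BART variables: disambiguate explicitly.
# vi_norm, labels = pmb.get_variable_inclusion(idata, X, model=single_bart_model, bart_var_name="mu")

With to_kulprit=True the function instead returns nested lists of variable names that can be passed as a path to Kulprit's project method.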


def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None):
@@ -781,22 +800,26 @@ def compute_variable_importance(  # noqa: PLR0915 PLR0912
idata: Any,
bartrv: Variable,
X: npt.NDArray,
model: "pm.Model | None" = None,
method: str = "VI",
fixed: int = 0,
samples: int = 50,
- random_seed: Optional[int] = None,
+ random_seed: int | None = None,
) -> dict[str, object]:
"""
Estimates variable importance from the BART posterior.

Parameters
----------
idata : InferenceData
- InferenceData containing a collection of BART_trees in sample_stats group
+ InferenceData containing a "variable_inclusion" variable in the sample_stats group.
bartrv : BART Random Variable
The BART variable, once the model that includes it has been fitted.
X : npt.NDArray
The covariate matrix.
+ model : Optional[pm.Model]
+ The PyMC model that contains the BART variable. Only needed if the model contains multiple
+ BART variables.
method : str
Method used to rank variables. Available options are "VI" (default), "backward"
and "backward_VI".
@@ -825,6 +848,7 @@ def compute_variable_importance(  # noqa: PLR0915 PLR0912
rng = np.random.default_rng(random_seed)

all_trees = bartrv.owner.op.all_trees
+ bart_var_name = bartrv.name

if bartrv.ndim == 1: # type: ignore
shape = 1
@@ -858,9 +882,20 @@ def compute_variable_importance(  # noqa: PLR0915 PLR0912
)

if method in ["VI", "backward_VI"]:
- idxs = np.argsort(
- idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
- )
+ vi_xarray = idata["sample_stats"]["variable_inclusion"]
+ if "variable_inclusion_dim_0" in vi_xarray.coords:
+ if model is None:
+ raise ValueError(
+ "The InferenceData was generated from a model with multiple BART variables, \n"
+ "please provide the model and also the name of the BART variable \n"
+ "for which you want to compute the variable inclusion."
+ )

+ index = [var.name for var in model.free_RVs].index(bart_var_name)
+ vi_vals = vi_xarray.sel({"variable_inclusion_dim_0": index}).values.ravel()
+ else:
+ vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel()
+ idxs = np.argsort(np.array([_decode_vi(val, n_vars) for val in vi_vals]).sum(axis=0))
subsets: list[list[int]] = [list(idxs[:-i]) for i in range(1, len(idxs))]
subsets.append(None) # type: ignore
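
For intuition about the ranking step above, a worked example with made-up inclusion counts (editor's sketch, three covariates assumed):

import numpy as np

counts = np.array([40, 10, 25])  # summed decoded inclusion counts per covariate
idxs = np.argsort(counts)        # array([1, 2, 0]): least to most included
subsets = [list(idxs[:-i]) for i in range(1, len(idxs))]
subsets.append(None)
# subsets == [[1, 2], [1], None]: successive slices drop indices from the
# high-inclusion end; None appears to stand in for using all covariates.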

@@ -1226,3 +1261,39 @@ def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax):

ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], color=color, alpha=alpha)
return ax


+ def _decode_vi(n: int, length: int) -> list[int]:
+ """
+ Decode a variable inclusion vector from its single-integer encoding.
+ """
+ bits = bin(n)[2:]
+ vi_list: list[int] = []
+ i = 0
+ while len(vi_list) < length:
+ # Count prefix ones
+ prefix_len = 0
+ while bits[i] == "1":
+ prefix_len += 1
+ i += 1
+ i += 1  # skip the '0'
+ b = bits[i : i + prefix_len]
+ vi_list.append(int(b, 2))
+ i += prefix_len
+ return vi_list


+ def _encode_vi(vec: npt.NDArray) -> int:
+ """
+ Encode a variable inclusion vector into a single integer.
+
+ Each element contributes a prefix of '1's (one per binary digit), a separating '0',
+ and then its binary representation. The final result is the integer value of the
+ concatenated bit string.
+ """
+ bits = ""
+ for x in vec:
+ b = bin(x)[2:]
+ prefix = "1" * len(b) + "0"
+ bits += prefix + b
+ return int(bits, 2)
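
To make the prefix code concrete, a round-trip check (editor's sketch with a made-up counts vector). Each count x contributes as many '1's as it has binary digits, a separating '0', then the digits themselves, so the bit stream is self-delimiting:

import numpy as np

from pymc_bart.utils import _decode_vi, _encode_vi

vec = np.array([2, 0, 3])
code = _encode_vi(vec)
# 2 -> "110" + "10", 0 -> "10" + "0", 3 -> "110" + "11"
# concatenated: "1101010011011" in binary, i.e. 6811
assert code == 6811
assert _decode_vi(code, 3) == [2, 0, 3]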
requirements.txt: 2 changes (1 addition, 1 deletion)
@@ -1,4 +1,4 @@
- pymc==5.24.0
+ pymc>=5.24.0
arviz-stats[xarray]>=0.6.0
numba
matplotlib
tests/test_bart.py: 113 changes (18 additions, 95 deletions)
@@ -1,11 +1,12 @@
import numpy as np
import pymc as pm
import pytest
- from numpy.testing import assert_almost_equal, assert_array_equal
+ from numpy.testing import assert_almost_equal
from pymc.initial_point import make_initial_point_fn
from pymc.logprob.basic import transformed_conditional_logp

import pymc_bart as pmb
+ from pymc_bart.utils import _decode_vi


def assert_moment_is_expected(model, expected, check_finite_logp=True):
@@ -52,14 +53,12 @@ def test_bart_vi(response):
with pm.Model() as model:
mu = pmb.BART("mu", X, Y, m=10, response=response)
sigma = pm.HalfNormal("sigma", 1)
y = pm.Normal("y", mu, sigma, observed=Y)
pm.Normal("y", mu, sigma, observed=Y)
idata = pm.sample(tune=200, draws=200, random_seed=3415)
var_imp = (
idata.sample_stats["variable_inclusion"]
.stack(samples=("chain", "draw"))
.mean("samples")
)
var_imp /= var_imp.sum()
vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel()
var_imp = np.array([_decode_vi(val, 3) for val in vi_vals]).sum(axis=0)

var_imp = var_imp / var_imp.sum()
assert var_imp[0] > var_imp[1:].sum()
assert_almost_equal(var_imp.sum(), 1)

@@ -123,92 +122,6 @@ def test_shape(response):
assert idata.posterior.coords["w_dim_1"].data.size == 250


- class TestUtils:
- X_norm = np.random.normal(0, 1, size=(50, 2))
- X_binom = np.random.binomial(1, 0.5, size=(50, 1))
- X = np.hstack([X_norm, X_binom])
- Y = np.random.normal(0, 1, size=50)

- with pm.Model() as model:
- mu = pmb.BART("mu", X, Y, m=10)
- sigma = pm.HalfNormal("sigma", 1)
- y = pm.Normal("y", mu, sigma, observed=Y)
- idata = pm.sample(tune=200, draws=200, random_seed=3415)

- def test_sample_posterior(self):
- all_trees = self.mu.owner.op.all_trees
- rng = np.random.default_rng(3)
- pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2)
- rng = np.random.default_rng(3)
- pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng)

- assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4)
- assert pred_all.shape == (2, 50, 1)
- assert pred_first.shape == (1, 10, 1)

- @pytest.mark.parametrize(
- "kwargs",
- [
- {},
- {
- "samples": 2,
- "var_discrete": [3],
- },
- {"instances": 2},
- {"var_idx": [0], "smooth": False, "color": "k"},
- {"grid": (1, 2), "sharey": "none", "alpha": 1},
- {"var_discrete": [0]},
- ],
- )
- def test_ice(self, kwargs):
- pmb.plot_ice(self.mu, X=self.X, Y=self.Y, **kwargs)

- @pytest.mark.parametrize(
- "kwargs",
- [
- {},
- {
- "samples": 2,
- "xs_interval": "quantiles",
- "xs_values": [0.25, 0.5, 0.75],
- "var_discrete": [3],
- },
- {"var_idx": [0], "smooth": False, "color": "k"},
- {"grid": (1, 2), "sharey": "none", "alpha": 1},
- {"var_discrete": [0]},
- ],
- )
- def test_pdp(self, kwargs):
- pmb.plot_pdp(self.mu, X=self.X, Y=self.Y, **kwargs)

- @pytest.mark.parametrize(
- "kwargs",
- [
- {"samples": 50},
- {"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)},
- ],
- )
- def test_vi(self, kwargs):
- samples = kwargs.pop("samples")
- vi_results = pmb.compute_variable_importance(
- self.idata, bartrv=self.mu, X=self.X, samples=samples
- )
- pmb.plot_variable_importance(vi_results, **kwargs)
- pmb.plot_scatter_submodels(vi_results, **kwargs)

- def test_pdp_pandas_labels(self):
- pd = pytest.importorskip("pandas")

- X_names = ["norm1", "norm2", "binom"]
- X_pd = pd.DataFrame(self.X, columns=X_names)
- Y_pd = pd.Series(self.Y, name="response")
- axes = pmb.plot_pdp(self.mu, X=X_pd, Y=Y_pd)

- figure = axes[0].figure
- assert figure.texts[0].get_text() == "Partial response"
- assert_array_equal([ax.get_xlabel() for ax in axes], X_names)

@pytest.mark.parametrize(
"size, expected",
[
@@ -275,7 +188,7 @@ def test_multiple_bart_variables():

# Combined model
sigma = pm.HalfNormal("sigma", 1)
y = pm.Normal("y", mu1 + mu2, sigma, observed=Y)
pm.Normal("y", mu1 + mu2, sigma, observed=Y)

# Sample with automatic assignment of BART samplers
idata = pm.sample(tune=50, draws=50, chains=1, random_seed=3415)
@@ -291,6 +204,16 @@ def test_multiple_bart_variables():
assert idata.posterior["mu1"].shape == (1, 50, 50)
assert idata.posterior["mu2"].shape == (1, 50, 50)

+ vi_results = pmb.compute_variable_importance(idata, mu1, X1, model=model)
+ assert vi_results["labels"].shape == (2,)
+ assert vi_results["preds"].shape == (2, 50, 50)
+ assert vi_results["preds_all"].shape == (50, 50)

+ vi_tuple = pmb.get_variable_inclusion(idata, X1, model=model, bart_var_name="mu1")
+ assert vi_tuple[0].shape == (2,)
+ assert len(vi_tuple[1]) == 2
+ assert isinstance(vi_tuple[1][0], str)


def test_multiple_bart_variables_manual_step():
"""Test that multiple BART variables work with manually assigned PGBART samplers."""