Skip to content

Commit b7567de

Browse files
authored
encode vi and update to work with multiple RVs (#235)
* encode vi and update to work with multiple RVs * add missing tests
1 parent d9095d9 commit b7567de

File tree

7 files changed

+217
-113
lines changed

7 files changed

+217
-113
lines changed

env-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ channels:
33
- conda-forge
44
- defaults
55
dependencies:
6-
- pymc==5.24.0
6+
- pymc>=5.24.0
77
- numba
88
- matplotlib
99
- numpy

env.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ channels:
33
- conda-forge
44
- defaults
55
dependencies:
6-
- pymc==5.24.0
6+
- pymc>=5.24.0
77
- numba
88
- matplotlib
99
- numpy

pymc_bart/pgbart.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
get_idx_left_child,
3838
get_idx_right_child,
3939
)
40+
from pymc_bart.utils import _encode_vi
4041

4142

4243
class ParticleTree:
@@ -118,7 +119,7 @@ class PGBART(ArrayStepShared):
118119
default_blocked = False
119120
generates_stats = True
120121
stats_dtypes_shapes: dict[str, tuple[type, list]] = {
121-
"variable_inclusion": (object, []),
122+
"variable_inclusion": (int, []),
122123
"tune": (bool, []),
123124
}
124125

@@ -335,6 +336,8 @@ def astep(self, _):
335336
if not self.tune:
336337
self.bart.all_trees.append(self.all_trees)
337338

339+
variable_inclusion = _encode_vi(variable_inclusion)
340+
338341
stats = {"variable_inclusion": variable_inclusion, "tune": self.tune}
339342
return self.sum_trees, [stats]
340343

pymc_bart/utils.py

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import matplotlib.pyplot as plt
88
import numpy as np
99
import numpy.typing as npt
10+
import pymc as pm
1011
import pytensor.tensor as pt
1112
from arviz_base import rcParams
1213
from arviz_stats.base import array_stats
@@ -674,48 +675,66 @@ def _smooth_mean(
674675
return x_data, y_data
675676

676677

677-
def get_variable_inclusion(idata, X, labels=None, to_kulprit=False):
678+
def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None, to_kulprit=False):
678679
"""
679680
Get the normalized variable inclusion from BART model.
680681
681682
Parameters
682683
----------
683684
idata : InferenceData
684-
InferenceData containing a collection of BART_trees in sample_stats group
685+
InferenceData with a variable "variable_inclusion" in ``sample_stats`` group
685686
X : npt.NDArray
686687
The covariate matrix.
688+
model : Optional[pm.Model]
689+
The PyMC model that contains the BART variable. Only needed if the model contains multiple
690+
BART variables.
691+
bart_var_name : Optional[str]
692+
The name of the BART variable in the model. Only needed if the model contains multiple
693+
BART variables.
687694
labels : Optional[list[str]]
688695
List of the names of the covariates. If X is a DataFrame the names of the covariables will
689696
be taken from it and this argument will be ignored.
690697
to_kulprit : bool
691698
If True, the function will return a list of list with the variables names.
692699
This list can be passed as a path to Kulprit's project method. Defaults to False.
700+
693701
Returns
694702
-------
695703
VI_norm : npt.NDArray
696704
Normalized variable inclusion.
697705
labels : list[str]
698706
List of the names of the covariates.
699707
"""
700-
VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
708+
n_vars = X.shape[1]
709+
vi_xarray = idata["sample_stats"]["variable_inclusion"]
710+
if "variable_inclusion_dim_0" in vi_xarray.coords:
711+
if model is None or bart_var_name is None:
712+
raise ValueError(
713+
"The InfereceData was generated from a model with multiple BART variables, \n"
714+
"please provide the model and also the name of the BART variable \n"
715+
"for which you want to compute the variable inclusion."
716+
)
717+
index = [var.name for var in model.free_RVs].index(bart_var_name)
718+
vi_vals = vi_xarray.sel({"variable_inclusion_dim_0": index}).values.ravel()
719+
else:
720+
vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel()
721+
VIs = np.array([_decode_vi(val, n_vars) for val in vi_vals]).sum(axis=0)
701722
VI_norm = VIs / VIs.sum()
702723
idxs = np.argsort(VI_norm)
703724

704725
indices = idxs[::-1]
705726
n_vars = len(indices)
706727

707728
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
708-
labels = X.columns
729+
labels = list(X.columns)
709730

710731
if labels is None:
711-
labels = np.arange(n_vars).astype(str)
712-
713-
label_list = labels.to_list()
732+
labels = [str(i) for i in range(n_vars)]
714733

715734
if to_kulprit:
716-
return [label_list[:idx] for idx in range(n_vars)]
735+
return [labels[:idx] for idx in range(n_vars)]
717736
else:
718-
return VI_norm[indices], label_list
737+
return VI_norm[indices], labels
719738

720739

721740
def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None):
@@ -781,22 +800,26 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
781800
idata: Any,
782801
bartrv: Variable,
783802
X: npt.NDArray,
803+
model: "pm.Model | None" = None,
784804
method: str = "VI",
785805
fixed: int = 0,
786806
samples: int = 50,
787-
random_seed: Optional[int] = None,
807+
random_seed: int | None = None,
788808
) -> dict[str, object]:
789809
"""
790810
Estimates variable importance from the BART-posterior.
791811
792812
Parameters
793813
----------
794814
idata : InferenceData
795-
InferenceData containing a collection of BART_trees in sample_stats group
815+
InferenceData containing a "variable_inclusion" variable in the sample_stats group.
796816
bartrv : BART Random Variable
797817
BART variable once the model that include it has been fitted.
798818
X : npt.NDArray
799819
The covariate matrix.
820+
model : Optional[pm.Model]
821+
The PyMC model that contains the BART variable. Only needed if the model contains multiple
822+
BART variables.
800823
method : str
801824
Method used to rank variables. Available options are "VI" (default), "backward"
802825
and "backward_VI".
@@ -825,6 +848,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
825848
rng = np.random.default_rng(random_seed)
826849

827850
all_trees = bartrv.owner.op.all_trees
851+
bart_var_name = bartrv.name
828852

829853
if bartrv.ndim == 1: # type: ignore
830854
shape = 1
@@ -858,9 +882,20 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
858882
)
859883

860884
if method in ["VI", "backward_VI"]:
861-
idxs = np.argsort(
862-
idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
863-
)
885+
vi_xarray = idata["sample_stats"]["variable_inclusion"]
886+
if "variable_inclusion_dim_0" in vi_xarray.coords:
887+
if model is None:
888+
raise ValueError(
889+
"The InfereceData was generated from a model with multiple BART variables, \n"
890+
"please provide the model and also the name of the BART variable \n"
891+
"for which you want to compute the variable inclusion."
892+
)
893+
894+
index = [var.name for var in model.free_RVs].index(bart_var_name)
895+
vi_vals = vi_xarray.sel({"variable_inclusion_dim_0": index}).values.ravel()
896+
else:
897+
vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel()
898+
idxs = np.argsort(np.array([_decode_vi(val, n_vars) for val in vi_vals]).sum(axis=0))
864899
subsets: list[list[int]] = [list(idxs[:-i]) for i in range(1, len(idxs))]
865900
subsets.append(None) # type: ignore
866901

@@ -1226,3 +1261,39 @@ def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax):
12261261

12271262
ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], color=color, alpha=alpha)
12281263
return ax
1264+
1265+
1266+
def _decode_vi(n: int, length: int) -> list[int]:
1267+
"""
1268+
Decode the variable inclusion from the BART model.
1269+
"""
1270+
bits = bin(n)[2:]
1271+
vi_list: list[int] = []
1272+
i = 0
1273+
while len(vi_list) < length:
1274+
# Count prefix ones
1275+
prefix_len = 0
1276+
while bits[i] == "1":
1277+
prefix_len += 1
1278+
i += 1
1279+
i += 1 # skip the '0'
1280+
b = bits[i : i + prefix_len]
1281+
vi_list.append(int(b, 2))
1282+
i += prefix_len
1283+
return vi_list
1284+
1285+
1286+
def _encode_vi(vec: npt.NDArray) -> int:
1287+
"""
1288+
Encode variable inclusion vector into a single integer.
1289+
1290+
The encoding is done by converting each element of the vector into a binary string,
1291+
where each element contributes a prefix of '1's followed by a '0' and its binary representation.
1292+
The final result is the integer representation of the concatenated binary string.
1293+
"""
1294+
bits = ""
1295+
for x in vec:
1296+
b = bin(x)[2:]
1297+
prefix = "1" * len(b) + "0"
1298+
bits += prefix + b
1299+
return int(bits, 2)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
pymc==5.24.0
1+
pymc>=5.24.0
22
arviz-stats[xarray]>=0.6.0
33
numba
44
matplotlib

tests/test_bart.py

Lines changed: 18 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import numpy as np
22
import pymc as pm
33
import pytest
4-
from numpy.testing import assert_almost_equal, assert_array_equal
4+
from numpy.testing import assert_almost_equal
55
from pymc.initial_point import make_initial_point_fn
66
from pymc.logprob.basic import transformed_conditional_logp
77

88
import pymc_bart as pmb
9+
from pymc_bart.utils import _decode_vi
910

1011

1112
def assert_moment_is_expected(model, expected, check_finite_logp=True):
@@ -52,14 +53,12 @@ def test_bart_vi(response):
5253
with pm.Model() as model:
5354
mu = pmb.BART("mu", X, Y, m=10, response=response)
5455
sigma = pm.HalfNormal("sigma", 1)
55-
y = pm.Normal("y", mu, sigma, observed=Y)
56+
pm.Normal("y", mu, sigma, observed=Y)
5657
idata = pm.sample(tune=200, draws=200, random_seed=3415)
57-
var_imp = (
58-
idata.sample_stats["variable_inclusion"]
59-
.stack(samples=("chain", "draw"))
60-
.mean("samples")
61-
)
62-
var_imp /= var_imp.sum()
58+
vi_vals = idata["sample_stats"]["variable_inclusion"].values.ravel()
59+
var_imp = np.array([_decode_vi(val, 3) for val in vi_vals]).sum(axis=0)
60+
61+
var_imp = var_imp / var_imp.sum()
6362
assert var_imp[0] > var_imp[1:].sum()
6463
assert_almost_equal(var_imp.sum(), 1)
6564

@@ -123,92 +122,6 @@ def test_shape(response):
123122
assert idata.posterior.coords["w_dim_1"].data.size == 250
124123

125124

126-
class TestUtils:
127-
X_norm = np.random.normal(0, 1, size=(50, 2))
128-
X_binom = np.random.binomial(1, 0.5, size=(50, 1))
129-
X = np.hstack([X_norm, X_binom])
130-
Y = np.random.normal(0, 1, size=50)
131-
132-
with pm.Model() as model:
133-
mu = pmb.BART("mu", X, Y, m=10)
134-
sigma = pm.HalfNormal("sigma", 1)
135-
y = pm.Normal("y", mu, sigma, observed=Y)
136-
idata = pm.sample(tune=200, draws=200, random_seed=3415)
137-
138-
def test_sample_posterior(self):
139-
all_trees = self.mu.owner.op.all_trees
140-
rng = np.random.default_rng(3)
141-
pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2)
142-
rng = np.random.default_rng(3)
143-
pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng)
144-
145-
assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4)
146-
assert pred_all.shape == (2, 50, 1)
147-
assert pred_first.shape == (1, 10, 1)
148-
149-
@pytest.mark.parametrize(
150-
"kwargs",
151-
[
152-
{},
153-
{
154-
"samples": 2,
155-
"var_discrete": [3],
156-
},
157-
{"instances": 2},
158-
{"var_idx": [0], "smooth": False, "color": "k"},
159-
{"grid": (1, 2), "sharey": "none", "alpha": 1},
160-
{"var_discrete": [0]},
161-
],
162-
)
163-
def test_ice(self, kwargs):
164-
pmb.plot_ice(self.mu, X=self.X, Y=self.Y, **kwargs)
165-
166-
@pytest.mark.parametrize(
167-
"kwargs",
168-
[
169-
{},
170-
{
171-
"samples": 2,
172-
"xs_interval": "quantiles",
173-
"xs_values": [0.25, 0.5, 0.75],
174-
"var_discrete": [3],
175-
},
176-
{"var_idx": [0], "smooth": False, "color": "k"},
177-
{"grid": (1, 2), "sharey": "none", "alpha": 1},
178-
{"var_discrete": [0]},
179-
],
180-
)
181-
def test_pdp(self, kwargs):
182-
pmb.plot_pdp(self.mu, X=self.X, Y=self.Y, **kwargs)
183-
184-
@pytest.mark.parametrize(
185-
"kwargs",
186-
[
187-
{"samples": 50},
188-
{"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)},
189-
],
190-
)
191-
def test_vi(self, kwargs):
192-
samples = kwargs.pop("samples")
193-
vi_results = pmb.compute_variable_importance(
194-
self.idata, bartrv=self.mu, X=self.X, samples=samples
195-
)
196-
pmb.plot_variable_importance(vi_results, **kwargs)
197-
pmb.plot_scatter_submodels(vi_results, **kwargs)
198-
199-
def test_pdp_pandas_labels(self):
200-
pd = pytest.importorskip("pandas")
201-
202-
X_names = ["norm1", "norm2", "binom"]
203-
X_pd = pd.DataFrame(self.X, columns=X_names)
204-
Y_pd = pd.Series(self.Y, name="response")
205-
axes = pmb.plot_pdp(self.mu, X=X_pd, Y=Y_pd)
206-
207-
figure = axes[0].figure
208-
assert figure.texts[0].get_text() == "Partial response"
209-
assert_array_equal([ax.get_xlabel() for ax in axes], X_names)
210-
211-
212125
@pytest.mark.parametrize(
213126
"size, expected",
214127
[
@@ -275,7 +188,7 @@ def test_multiple_bart_variables():
275188

276189
# Combined model
277190
sigma = pm.HalfNormal("sigma", 1)
278-
y = pm.Normal("y", mu1 + mu2, sigma, observed=Y)
191+
pm.Normal("y", mu1 + mu2, sigma, observed=Y)
279192

280193
# Sample with automatic assignment of BART samplers
281194
idata = pm.sample(tune=50, draws=50, chains=1, random_seed=3415)
@@ -291,6 +204,16 @@ def test_multiple_bart_variables():
291204
assert idata.posterior["mu1"].shape == (1, 50, 50)
292205
assert idata.posterior["mu2"].shape == (1, 50, 50)
293206

207+
vi_results = pmb.compute_variable_importance(idata, mu1, X1, model=model)
208+
assert vi_results["labels"].shape == (2,)
209+
assert vi_results["preds"].shape == (2, 50, 50)
210+
assert vi_results["preds_all"].shape == (50, 50)
211+
212+
vi_tuple = pmb.get_variable_inclusion(idata, X1, model=model, bart_var_name="mu1")
213+
assert vi_tuple[0].shape == (2,)
214+
assert len(vi_tuple[1]) == 2
215+
assert isinstance(vi_tuple[1][0], str)
216+
294217

295218
def test_multiple_bart_variables_manual_step():
296219
"""Test that multiple BART variables work with manually assigned PGBART samplers."""

0 commit comments

Comments
 (0)