Skip to content

Commit 3205b82

Browse files
authored
Use datatree from xarray and update example to latest einstats (#2458)
* Use datatree from xarray and update example to latest einstats * update test * black
1 parent 896da8d commit 3205b82

File tree

5 files changed

+16
-17
lines changed

5 files changed

+16
-17
lines changed

arviz/data/inference_data.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -532,24 +532,27 @@ def to_netcdf(
532532
return filename
533533

534534
def to_datatree(self):
535-
"""Convert InferenceData object to a :class:`~datatree.DataTree`."""
535+
"""Convert InferenceData object to a :class:`~xarray.DataTree`."""
536536
try:
537-
from datatree import DataTree
538-
except ModuleNotFoundError as err:
539-
raise ModuleNotFoundError(
540-
"datatree must be installed in order to use InferenceData.to_datatree"
537+
from xarray import DataTree
538+
except ImportError as err:
539+
raise ImportError(
540+
"xarray must be have DataTree in order to use InferenceData.to_datatree. "
541+
"Update to xarray>=2024.11.0"
541542
) from err
542543
return DataTree.from_dict({group: ds for group, ds in self.items()})
543544

544545
@staticmethod
545546
def from_datatree(datatree):
546-
"""Create an InferenceData object from a :class:`~datatree.DataTree`.
547+
"""Create an InferenceData object from a :class:`~xarray.DataTree`.
547548
548549
Parameters
549550
----------
550551
datatree : DataTree
551552
"""
552-
return InferenceData(**{group: sub_dt.to_dataset() for group, sub_dt in datatree.items()})
553+
return InferenceData(
554+
**{group: child.to_dataset() for group, child in datatree.children.items()}
555+
)
553556

554557
def to_dict(self, groups=None, filter_groups=None):
555558
"""Convert InferenceData to a dictionary following xarray naming conventions.
@@ -1531,9 +1534,8 @@ def add_groups(
15311534
import xarray as xr
15321535
from xarray_einstats.stats import XrDiscreteRV
15331536
from scipy.stats import poisson
1534-
dist = XrDiscreteRV(poisson)
1535-
log_lik = xr.Dataset()
1536-
log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
1537+
dist = XrDiscreteRV(poisson, np.exp(post["atts"]))
1538+
log_lik = dist.logpmf(obs["home_points"]).to_dataset(name="home_points")
15371539
idata2.add_groups({"log_likelihood": log_lik})
15381540
idata2
15391541

arviz/data/io_datatree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
def to_datatree(data):
7-
"""Convert InferenceData object to a :class:`~datatree.DataTree`.
7+
"""Convert InferenceData object to a :class:`~xarray.DataTree`.
88
99
Parameters
1010
----------
@@ -14,7 +14,7 @@ def to_datatree(data):
1414

1515

1616
def from_datatree(datatree):
17-
"""Create an InferenceData object from a :class:`~datatree.DataTree`.
17+
"""Create an InferenceData object from a :class:`~xarray.DataTree`.
1818
1919
Parameters
2020
----------

arviz/tests/base_tests/test_data.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,10 +1501,6 @@ def test_json_converters(self, models):
15011501
assert not os.path.exists(filepath)
15021502

15031503

1504-
@pytest.mark.skipif(
1505-
not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
1506-
reason="test requires xarray-datatree library",
1507-
)
15081504
class TestDataTree:
15091505
def test_datatree(self):
15101506
idata = load_arviz_data("centered_eight")

doc/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
numpydoc_xref_aliases = {
116116
"DataArray": ":class:`~xarray.DataArray`",
117117
"Dataset": ":class:`~xarray.Dataset`",
118+
"DataTree": ":class:`~xarray.DataTree`",
118119
"Labeller": ":ref:`Labeller <labeller_api>`",
119120
"ndarray": ":class:`~numpy.ndarray`",
120121
"InferenceData": ":class:`~arviz.InferenceData`",

requirements-optional.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ contourpy
55
ujson
66
dask[distributed]
77
zarr>=2.5.0,<3
8-
xarray-datatree
8+
xarray>=2024.11.0
99
dm-tree>=0.1.8

0 commit comments

Comments
 (0)