Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,24 +532,27 @@
return filename

def to_datatree(self):
"""Convert InferenceData object to a :class:`~datatree.DataTree`."""
"""Convert InferenceData object to a :class:`~xarray.DataTree`."""
try:
from datatree import DataTree
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
"datatree must be installed in order to use InferenceData.to_datatree"
from xarray import DataTree
except ImportError as err:
raise ImportError(

Check warning on line 539 in arviz/data/inference_data.py

View check run for this annotation

Codecov / codecov/patch

arviz/data/inference_data.py#L538-L539

Added lines #L538 - L539 were not covered by tests
"xarray must be have DataTree in order to use InferenceData.to_datatree. "
"Update to xarray>=2024.11.0"
) from err
return DataTree.from_dict({group: ds for group, ds in self.items()})

@staticmethod
def from_datatree(datatree):
"""Create an InferenceData object from a :class:`~datatree.DataTree`.
"""Create an InferenceData object from a :class:`~xarray.DataTree`.

Parameters
----------
datatree : DataTree
"""
return InferenceData(**{group: sub_dt.to_dataset() for group, sub_dt in datatree.items()})
return InferenceData(
**{group: child.to_dataset() for group, child in datatree.children.items()}
)

def to_dict(self, groups=None, filter_groups=None):
"""Convert InferenceData to a dictionary following xarray naming conventions.
Expand Down Expand Up @@ -1531,9 +1534,8 @@
import xarray as xr
from xarray_einstats.stats import XrDiscreteRV
from scipy.stats import poisson
dist = XrDiscreteRV(poisson)
log_lik = xr.Dataset()
log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
dist = XrDiscreteRV(poisson, np.exp(post["atts"]))
log_lik = dist.logpmf(obs["home_points"]).to_dataset(name="home_points")
idata2.add_groups({"log_likelihood": log_lik})
idata2

Expand Down
4 changes: 2 additions & 2 deletions arviz/data/io_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def to_datatree(data):
"""Convert InferenceData object to a :class:`~datatree.DataTree`.
"""Convert InferenceData object to a :class:`~xarray.DataTree`.

Parameters
----------
Expand All @@ -14,7 +14,7 @@ def to_datatree(data):


def from_datatree(datatree):
"""Create an InferenceData object from a :class:`~datatree.DataTree`.
"""Create an InferenceData object from a :class:`~xarray.DataTree`.

Parameters
----------
Expand Down
4 changes: 0 additions & 4 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,10 +1501,6 @@ def test_json_converters(self, models):
assert not os.path.exists(filepath)


@pytest.mark.skipif(
not (importlib.util.find_spec("datatree") or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
reason="test requires xarray-datatree library",
)
class TestDataTree:
def test_datatree(self):
idata = load_arviz_data("centered_eight")
Expand Down
1 change: 1 addition & 0 deletions doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
numpydoc_xref_aliases = {
"DataArray": ":class:`~xarray.DataArray`",
"Dataset": ":class:`~xarray.Dataset`",
"DataTree": ":class:`~xarray.DataTree`",
"Labeller": ":ref:`Labeller <labeller_api>`",
"ndarray": ":class:`~numpy.ndarray`",
"InferenceData": ":class:`~arviz.InferenceData`",
Expand Down
2 changes: 1 addition & 1 deletion requirements-optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ contourpy
ujson
dask[distributed]
zarr>=2.5.0,<3
xarray-datatree
xarray>=2024.11.0
dm-tree>=0.1.8