diff --git a/ci/doc.yml b/ci/doc.yml index 0a20f516..ff303a98 100644 --- a/ci/doc.yml +++ b/ci/doc.yml @@ -4,7 +4,6 @@ channels: dependencies: - pip - python>=3.8 - - xarray>=0.20.2 - netcdf4 - scipy - sphinx @@ -16,3 +15,4 @@ dependencies: - zarr - pip: - git+https://github.com/xarray-contrib/datatree + - xarray>=2022.05.0.dev0 diff --git a/ci/environment.yml b/ci/environment.yml index c5d58977..1aa9af93 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -4,7 +4,6 @@ channels: - nodefaults dependencies: - python>=3.8 - - xarray>=0.20.2 - netcdf4 - pytest - flake8 @@ -13,3 +12,5 @@ dependencies: - pytest-cov - h5netcdf - zarr + - pip: + - xarray>=2022.05.0.dev0 diff --git a/datatree/__init__.py b/datatree/__init__.py index d799dc02..58b65aec 100644 --- a/datatree/__init__.py +++ b/datatree/__init__.py @@ -6,7 +6,7 @@ # import public API from .datatree import DataTree from .io import open_datatree -from .mapping import map_over_subtree +from .mapping import TreeIsomorphismError, map_over_subtree try: __version__ = get_distribution(__name__).version diff --git a/datatree/datatree.py b/datatree/datatree.py index 708a4599..05dfb850 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1,5 +1,7 @@ from __future__ import annotations +import copy +import itertools from collections import OrderedDict from html import escape from typing import ( @@ -8,18 +10,27 @@ Callable, Dict, Generic, + Hashable, Iterable, + Iterator, Mapping, MutableMapping, Optional, + Set, Tuple, Union, ) -from xarray import DataArray, Dataset +import pandas as pd from xarray.core import utils +from xarray.core.coordinates import DatasetCoordinates +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset, DataVariables +from xarray.core.indexes import Index, Indexes +from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS -from xarray.core.variable import Variable +from xarray.core.utils import Default, 
Frozen, _default +from xarray.core.variable import Variable, calculate_dimensions from . import formatting, formatting_html from .mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree @@ -41,7 +52,7 @@ # the entire API of `xarray.Dataset`, but with certain methods decorated to instead map the dataset function over every # node in the tree. As this API is copied without directly subclassing `xarray.Dataset` we instead create various Mixin # classes (in ops.py) which each define part of `xarray.Dataset`'s extensive API. - +# # Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine to inherit unaltered # (normally because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new # tree) and some will get overridden by the class definition of DataTree. @@ -51,12 +62,37 @@ T_Path = Union[str, NodePath] +def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: + if isinstance(data, DataArray): + ds = data.to_dataset() + elif isinstance(data, Dataset): + ds = data + elif data is None: + ds = Dataset() + else: + raise TypeError( + f"data object is not an xarray Dataset, DataArray, or None, it is of type {type(data)}" + ) + return ds + + +def _check_for_name_collisions( + children: Iterable[str], variables: Iterable[Hashable] +) -> None: + colliding_names = set(children).intersection(set(variables)) + if colliding_names: + raise KeyError( + f"Some names would collide between variables and children: {list(colliding_names)}" + ) + + class DataTree( TreeNode, MappedDatasetMethodsMixin, MappedDataWithCoords, DataTreeArithmeticMixin, Generic[Tree], + Mapping, ): """ A tree-like hierarchical collection of xarray objects. 
@@ -80,10 +116,23 @@ class DataTree( # TODO .loc, __contains__, __iter__, __array__, __len__ + # TODO a lot of properties like .variables could be defined in a DataMapping class which both Dataset and DataTree inherit from + + # TODO __slots__ + + # TODO all groupby classes + _name: Optional[str] - _parent: Optional[Tree] - _children: OrderedDict[str, Tree] - _ds: Dataset + _parent: Optional[DataTree] + _children: OrderedDict[str, DataTree] + _attrs: Optional[Dict[Hashable, Any]] + _cache: Dict[str, Any] + _coord_names: Set[Hashable] + _dims: Dict[Hashable, int] + _encoding: Optional[Dict[Hashable, Any]] + _close: Optional[Callable[[], None]] + _indexes: Dict[Hashable, Index] + _variables: Dict[Hashable, Variable] def __init__( self, @@ -93,33 +142,54 @@ def __init__( name: str = None, ): """ - Create a single node of a DataTree, which optionally contains data in the form of an xarray.Dataset. + Create a single node of a DataTree. + + The node may optionally contain data in the form of data and coordinate variables, stored in the same way as + data is stored in an xarray.Dataset. Parameters ---------- - data : Dataset, DataArray, Variable or None, optional - Data to store under the .ds attribute of this node. DataArrays and Variables will be promoted to Datasets. + data : Dataset, DataArray, or None, optional + Data to store under the .ds attribute of this node. DataArrays will be promoted to Datasets. Default is None. parent : DataTree, optional Parent node to this node. Default is None. children : Mapping[str, DataTree], optional Any child nodes of this node. Default is None. name : str, optional - Name for the root node of the tree. + Name for this node of the tree. Default is None. 
Returns ------- - node : DataTree + DataTree See Also -------- DataTree.from_dict """ + # validate input + if children is None: + children = {} + ds = _coerce_to_dataset(data) + _check_for_name_collisions(children, ds.variables) + + # set tree attributes super().__init__(children=children) self.name = name self.parent = parent - self.ds = data + + # set data attributes + self._replace( + inplace=True, + variables=ds._variables, + coord_names=ds._coord_names, + dims=ds._dims, + indexes=ds._indexes, + attrs=ds._attrs, + encoding=ds._encoding, + ) + self._close = ds._close @property def name(self) -> str | None: @@ -149,53 +219,136 @@ def parent(self: DataTree, new_parent: DataTree) -> None: @property def ds(self) -> Dataset: """The data in this node, returned as a Dataset.""" - return self._ds + # TODO change this to return only an immutable view onto this node's data (see GH #80) + return self.to_dataset() @ds.setter def ds(self, data: Union[Dataset, DataArray] = None) -> None: - if not isinstance(data, (Dataset, DataArray)) and data is not None: - raise TypeError( - f"{type(data)} object is not an xarray Dataset, DataArray, or None" - ) - if isinstance(data, DataArray): - data = data.to_dataset() - elif data is None: - data = Dataset() + ds = _coerce_to_dataset(data) - for var in list(data.variables): - if var in self.children: - raise KeyError( - f"Cannot add variable named {var}: node already has a child named {var}" - ) + _check_for_name_collisions(self.children, ds.variables) + + self._replace( + inplace=True, + variables=ds._variables, + coord_names=ds._coord_names, + dims=ds._dims, + indexes=ds._indexes, + attrs=ds._attrs, + encoding=ds._encoding, + ) + self._close = ds._close + + def _pre_attach(self: DataTree, parent: DataTree) -> None: + """ + Method which superclass calls before setting parent, here used to prevent having two + children with duplicate names (or a data variable with the same name as a child). 
+ """ + super()._pre_attach(parent) + if self.name in list(parent.ds.variables): + raise KeyError( + f"parent {parent.name} already contains a data variable named {self.name}" + ) - self._ds = data + def to_dataset(self) -> Dataset: + """Return the data in this node as a new xarray.Dataset object.""" + return Dataset._construct_direct( + self._variables, + self._coord_names, + self._dims, + self._attrs, + self._indexes, + self._encoding, + self._close, + ) @property - def has_data(self) -> bool: + def has_data(self): """Whether or not there are any data variables in this node.""" - return len(self.ds.variables) > 0 + return len(self._variables) > 0 @property def has_attrs(self) -> bool: """Whether or not there are any metadata attributes in this node.""" - return len(self.ds.attrs.keys()) > 0 + return len(self.attrs.keys()) > 0 @property def is_empty(self) -> bool: """False if node contains any data or attrs. Does not look at children.""" return not (self.has_data or self.has_attrs) - def _pre_attach(self: DataTree, parent: DataTree) -> None: + @property + def variables(self) -> Mapping[Hashable, Variable]: + """Low level interface to node contents as dict of Variable objects. + + This ordered dictionary is frozen to prevent mutation that could + violate Dataset invariants. It contains all variable objects + constituting this DataTree node, including both data variables and + coordinates. """ - Method which superclass calls before setting parent, here used to prevent having two - children with duplicate names (or a data variable with the same name as a child). 
+ return Frozen(self._variables) + + @property + def attrs(self) -> Dict[Hashable, Any]: + """Dictionary of global attributes on this node""" + if self._attrs is None: + self._attrs = {} + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + self._attrs = dict(value) + + @property + def encoding(self) -> Dict: + """Dictionary of global encoding attributes on this node""" + if self._encoding is None: + self._encoding = {} + return self._encoding + + @encoding.setter + def encoding(self, value: Mapping) -> None: + self._encoding = dict(value) + + @property + def dims(self) -> Mapping[Hashable, int]: + """Mapping from dimension names to lengths. + + Cannot be modified directly, but is updated when adding new variables. + + Note that type of this object differs from `DataArray.dims`. + See `DataTree.sizes`, `Dataset.sizes`, and `DataArray.sizes` for consistently named + properties. """ - super()._pre_attach(parent) - if parent.has_data and self.name in list(parent.ds.variables): - raise KeyError( - f"parent {parent.name} already contains a data variable named {self.name}" - ) + return Frozen(self._dims) + + @property + def sizes(self) -> Mapping[Hashable, int]: + """Mapping from dimension names to lengths. + + Cannot be modified directly, but is updated when adding new variables. + + This is an alias for `DataTree.dims` provided for the benefit of + consistency with `DataArray.sizes`. + + See Also + -------- + DataArray.sizes + """ + return self.dims + + def __contains__(self, key: object) -> bool: + """The 'in' operator will return true or false depending on whether + 'key' is either an array stored in the datatree or a child node, or neither. 
+ """ + return key in self.variables or key in self.children + + def __bool__(self) -> bool: + return bool(self.ds.data_vars) or bool(self.children) + + def __iter__(self) -> Iterator[Hashable]: + return itertools.chain(self.ds.data_vars, self.children) def __repr__(self) -> str: return formatting.datatree_repr(self) @@ -209,20 +362,135 @@ def _repr_html_(self): return f"
<pre>{escape(repr(self))}</pre>
" return formatting_html.datatree_repr(self) + @classmethod + def _construct_direct( + cls, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: dict[Any, int] = None, + attrs: dict = None, + indexes: dict[Any, Index] = None, + encoding: dict = None, + name: str | None = None, + parent: DataTree | None = None, + children: OrderedDict[str, DataTree] = None, + close: Callable[[], None] = None, + ) -> DataTree: + """Shortcut around __init__ for internal use when we want to skip costly validation.""" + + # data attributes + if dims is None: + dims = calculate_dimensions(variables) + if indexes is None: + indexes = {} + if children is None: + children = OrderedDict() + + obj: DataTree = object.__new__(cls) + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding + + # tree attributes + obj._name = name + obj._children = children + obj._parent = parent + + return obj + + def _replace( + self: DataTree, + variables: dict[Hashable, Variable] = None, + coord_names: set[Hashable] = None, + dims: dict[Any, int] = None, + attrs: dict[Hashable, Any] | None | Default = _default, + indexes: dict[Hashable, Index] = None, + encoding: dict | None | Default = _default, + name: str | None | Default = _default, + parent: DataTree | None = _default, + children: OrderedDict[str, DataTree] = None, + inplace: bool = False, + ) -> DataTree: + """ + Fastpath constructor for internal use. + + Returns an object with optionally replaced attributes. + + Explicitly passed arguments are *not* copied when placed on the new + datatree. It is up to the caller to ensure that they have the right type + and are not used elsewhere. 
+ """ + if inplace: + if variables is not None: + self._variables = variables + if coord_names is not None: + self._coord_names = coord_names + if dims is not None: + self._dims = dims + if attrs is not _default: + self._attrs = attrs + if indexes is not None: + self._indexes = indexes + if encoding is not _default: + self._encoding = encoding + if name is not _default: + self._name = name + if parent is not _default: + self._parent = parent + if children is not None: + self._children = children + obj = self + else: + if variables is None: + variables = self._variables.copy() + if coord_names is None: + coord_names = self._coord_names.copy() + if dims is None: + dims = self._dims.copy() + if attrs is _default: + attrs = copy.copy(self._attrs) + if indexes is None: + indexes = self._indexes.copy() + if encoding is _default: + encoding = copy.copy(self._encoding) + if name is _default: + name = self._name # no need to copy str objects or None + if parent is _default: + parent = copy.copy(self._parent) + if children is _default: + children = copy.copy(self._children) + obj = self._construct_direct( + variables, + coord_names, + dims, + attrs, + indexes, + encoding, + name, + parent, + children, + ) + return obj + def get( self: DataTree, key: str, default: Optional[DataTree | DataArray] = None ) -> Optional[DataTree | DataArray]: """ - Access child nodes stored in this node as a DataTree or variables or coordinates stored in this node as a - DataArray. + Access child nodes, variables, or coordinates stored in this node. + + Returned object will be either a DataTree or DataArray object depending on whether the key given points to a + child or variable. Parameters ---------- key : str - Name of variable / node item, which must lie in this immediate node (not elsewhere in the tree). + Name of variable / child within this node. Must lie in this immediate node (not elsewhere in the tree). 
default : DataTree | DataArray, optional - A value to return if the specified key does not exist. - Default value is None. + A value to return if the specified key does not exist. Default return value is None. """ if key in self.children: return self.children[key] @@ -233,13 +501,19 @@ def get( def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: """ - Access child nodes stored in this tree as a DataTree or variables or coordinates stored in this tree as a - DataArray. + Access child nodes, variables, or coordinates stored anywhere in this tree. + + Returned object will be either a DataTree or DataArray object depending on whether the key given points to a + child or variable. Parameters ---------- key : str - Name of variable / node, or unix-like path to variable / node. + Name of variable / child within this node, or unix-like path to variable / child within another node. + + Returns + ------- + Union[DataTree, DataArray] """ # Either: @@ -272,7 +546,7 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None: val.parent = self elif isinstance(val, (DataArray, Variable)): # TODO this should also accomodate other types that can be coerced into Variables - self.ds[key] = val + self.update({key: val}) else: raise TypeError(f"Type {type(val)} cannot be assigned to a DataTree") @@ -316,8 +590,12 @@ def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None: else: raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree") - super().update(new_children) - self.ds.update(new_variables) + vars_merge_result = dataset_update_method(self.to_dataset(), new_variables) + # TODO are there any subtleties with preserving order of children like this? 
+ merged_children = OrderedDict(**self.children, **new_children) + self._replace( + inplace=True, children=merged_children, **vars_merge_result._asdict() + ) @classmethod def from_dict( @@ -326,7 +604,7 @@ def from_dict( name: str = None, ) -> DataTree: """ - Create a datatree from a dictionary of data objects, labelled by paths into the tree. + Create a datatree from a dictionary of data objects, organised by paths into the tree. Parameters ---------- @@ -365,28 +643,54 @@ def from_dict( allow_overwrite=False, new_nodes_along_path=True, ) + return obj - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> Dict[str, Dataset]: """ Create a dictionary mapping of absolute node paths to the data contained in those nodes. Returns ------- - Dict + Dict[str, Dataset] """ - return {node.path: node.ds for node in self.subtree} + return {node.path: node.to_dataset() for node in self.subtree} @property def nbytes(self) -> int: - return sum(node.ds.nbytes if node.has_data else 0 for node in self.subtree) + return sum(node.to_dataset().nbytes for node in self.subtree) def __len__(self) -> int: - if self.children: - n_children = len(self.children) - else: - n_children = 0 - return n_children + len(self.ds) + return len(self.children) + len(self.data_vars) + + @property + def indexes(self) -> Indexes[pd.Index]: + """Mapping of pandas.Index objects used for label based indexing. + Raises an error if this DataTree node has indexes that cannot be coerced + to pandas.Index objects. 
+ + See Also + -------- + DataTree.xindexes + """ + return self.xindexes.to_pandas_indexes() + + @property + def xindexes(self) -> Indexes[Index]: + """Mapping of xarray Index objects used for label based indexing.""" + return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) + + @property + def coords(self) -> DatasetCoordinates: + """Dictionary of xarray.DataArray objects corresponding to coordinate + variables + """ + return DatasetCoordinates(self.to_dataset()) + + @property + def data_vars(self) -> DataVariables: + """Dictionary of DataArray objects corresponding to data variables""" + return DataVariables(self.to_dataset()) def isomorphic( self, @@ -400,7 +704,7 @@ def isomorphic( Nothing about the data in each node is checked. Isomorphism is a necessary condition for two trees to be used in a nodewise binary operation, - such as tree1 + tree2. + such as ``tree1 + tree2``. By default this method does not check any part of the tree above the given node. Therefore this method can be used as default to check that two subtrees are isomorphic. @@ -408,12 +712,13 @@ def isomorphic( Parameters ---------- other : DataTree - The tree object to compare to. + The other tree object to compare to. from_root : bool, optional, default is False - Whether or not to first traverse to the root of the trees before checking for isomorphism. - If a & b have no parents then this has no effect. + Whether or not to first traverse to the root of the two trees before checking for isomorphism. + If neither tree has a parent then this has no effect. strict_names : bool, optional, default is False - Whether or not to also check that each node has the same name as its counterpart. + Whether or not to also check that every node in the tree has the same name as its counterpart in the other + tree. 
See Also -------- @@ -441,10 +746,10 @@ def equals(self, other: DataTree, from_root: bool = True) -> bool: Parameters ---------- other : DataTree - The tree object to compare to. + The other tree object to compare to. from_root : bool, optional, default is True - Whether or not to first traverse to the root of the trees before checking. - If a & b have no parents then this has no effect. + Whether or not to first traverse to the root of the two trees before checking for isomorphism. + If neither tree has a parent then this has no effect. See Also -------- @@ -472,10 +777,10 @@ def identical(self, other: DataTree, from_root=True) -> bool: Parameters ---------- other : DataTree - The tree object to compare to. + The other tree object to compare to. from_root : bool, optional, default is True - Whether or not to first traverse to the root of the trees before checking. - If a & b have no parents then this has no effect. + Whether or not to first traverse to the root of the two trees before checking for isomorphism. + If neither tree has a parent then this has no effect. 
See Also -------- diff --git a/datatree/ops.py b/datatree/ops.py index ee55ccfe..bdc931c9 100644 --- a/datatree/ops.py +++ b/datatree/ops.py @@ -30,8 +30,8 @@ "map_blocks", ] _DATASET_METHODS_TO_MAP = [ - "copy", "as_numpy", + "copy", "__copy__", "__deepcopy__", "set_coords", @@ -57,7 +57,6 @@ "reorder_levels", "stack", "unstack", - "update", "merge", "drop_vars", "drop_sel", @@ -245,7 +244,6 @@ class MappedDataWithCoords: """ # TODO add mapped versions of groupby, weighted, rolling, rolling_exp, coarsen, resample - # TODO re-implement AttrsAccessMixin stuff so that it includes access to child nodes _wrap_then_attach_to_cls( target_cls_dict=vars(), source_cls=Dataset, diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index 87611db3..b69a54b8 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -1,7 +1,12 @@ +from copy import copy, deepcopy + +import numpy as np import pytest import xarray as xr import xarray.testing as xrt +from xarray.tests import source_ndarray +import datatree.testing as dtt from datatree import DataTree @@ -31,12 +36,37 @@ def test_setparent_unnamed_child_node_fails(self): with pytest.raises(ValueError, match="unnamed"): DataTree(parent=john) + def test_create_two_children(self): + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": 0, "b": 1}) + + root = DataTree(data=root_data) + set1 = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=root) + DataTree(name="set2", parent=set1) + + def test_create_full_tree(self, simple_datatree): + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": 0, "b": 1}) + set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) + + root = DataTree(data=root_data) + set1 = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=set1) + DataTree(name="set2", parent=set1) + set2 = 
DataTree(name="set2", parent=root, data=set2_data) + DataTree(name="set1", parent=set2) + DataTree(name="set3", parent=root) + + expected = simple_datatree + assert root.identical(expected) + class TestStoreDatasets: def test_create_with_data(self): dat = xr.Dataset({"a": 0}) john = DataTree(name="john", data=dat) - assert john.ds is dat + xrt.assert_identical(john.ds, dat) with pytest.raises(TypeError): DataTree(name="mary", parent=john, data="junk") # noqa @@ -45,7 +75,7 @@ def test_set_data(self): john = DataTree(name="john") dat = xr.Dataset({"a": 0}) john.ds = dat - assert john.ds is dat + xrt.assert_identical(john.ds, dat) with pytest.raises(TypeError): john.ds = "junk" @@ -66,11 +96,11 @@ def test_parent_already_has_variable_with_childs_name(self): def test_assign_when_already_child_with_variables_name(self): dt = DataTree(data=None) DataTree(name="a", data=None, parent=dt) - with pytest.raises(KeyError, match="already has a child named a"): + with pytest.raises(KeyError, match="names would collide"): dt.ds = xr.Dataset({"a": 0}) dt.ds = xr.Dataset() - with pytest.raises(KeyError, match="already has a child named a"): + with pytest.raises(KeyError, match="names would collide"): dt.ds = dt.ds.assign(a=xr.DataArray(0)) @pytest.mark.xfail @@ -78,7 +108,7 @@ def test_update_when_already_child_with_variables_name(self): # See issue #38 dt = DataTree(name="root", data=None) DataTree(name="a", data=None, parent=dt) - with pytest.raises(KeyError, match="already has a child named a"): + with pytest.raises(KeyError, match="names would collide"): dt.ds["a"] = xr.DataArray(0) @@ -136,7 +166,82 @@ def test_getitem_dict_like_selection_access_to_dataset(self): class TestUpdate: - ... 
+ def test_update_new_named_dataarray(self): + da = xr.DataArray(name="temp", data=[0, 50]) + folder1 = DataTree(name="folder1") + folder1.update({"results": da}) + expected = da.rename("results") + xrt.assert_equal(folder1["results"], expected) + + +class TestCopy: + def test_copy(self, create_test_datatree): + dt = create_test_datatree() + + for node in dt.root.subtree: + node.attrs["Test"] = [1, 2, 3] + + for copied in [dt.copy(deep=False), copy(dt)]: + dtt.assert_identical(dt, copied) + + for node, copied_node in zip(dt.root.subtree, copied.root.subtree): + + assert node.encoding == copied_node.encoding + # Note: IndexVariable objects with string dtype are always + # copied because of xarray.core.util.safe_cast_to_index. + # Limiting the test to data variables. + for k in node.data_vars: + v0 = node.variables[k] + v1 = copied_node.variables[k] + assert source_ndarray(v0.data) is source_ndarray(v1.data) + copied_node["foo"] = xr.DataArray(data=np.arange(5), dims="z") + assert "foo" not in node + + copied_node.attrs["foo"] = "bar" + assert "foo" not in node.attrs + assert node.attrs["Test"] is copied_node.attrs["Test"] + + def test_deepcopy(self, create_test_datatree): + dt = create_test_datatree() + + for node in dt.root.subtree: + node.attrs["Test"] = [1, 2, 3] + + for copied in [dt.copy(deep=True), deepcopy(dt)]: + dtt.assert_identical(dt, copied) + + for node, copied_node in zip(dt.root.subtree, copied.root.subtree): + assert node.encoding == copied_node.encoding + # Note: IndexVariable objects with string dtype are always + # copied because of xarray.core.util.safe_cast_to_index. + # Limiting the test to data variables. 
+ for k in node.data_vars: + v0 = node.variables[k] + v1 = copied_node.variables[k] + assert source_ndarray(v0.data) is not source_ndarray(v1.data) + copied_node["foo"] = xr.DataArray(data=np.arange(5), dims="z") + assert "foo" not in node + + copied_node.attrs["foo"] = "bar" + assert "foo" not in node.attrs + assert node.attrs["Test"] is not copied_node.attrs["Test"] + + @pytest.mark.xfail(reason="data argument not yet implemented") + def test_copy_with_data(self, create_test_datatree): + orig = create_test_datatree() + # TODO use .data_vars once that property is available + data_vars = { + k: v for k, v in orig.variables.items() if k not in orig._coord_names + } + new_data = {k: np.random.randn(*v.shape) for k, v in data_vars.items()} + actual = orig.copy(data=new_data) + + expected = orig.copy() + for k, v in new_data.items(): + expected[k].data = v + dtt.assert_identical(expected, actual) + + # TODO test parents and children? class TestSetItem: @@ -187,27 +292,27 @@ def test_setitem_dataset_on_this_node(self): data = xr.Dataset({"temp": [0, 50]}) results = DataTree(name="results") results["."] = data - assert results.ds is data + xrt.assert_identical(results.ds, data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node(self): data = xr.Dataset({"temp": [0, 50]}) folder1 = DataTree(name="folder1") folder1["results"] = data - assert folder1["results"].ds is data + xrt.assert_identical(folder1["results"].ds, data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self): data = xr.Dataset({"temp": [0, 50]}) folder1 = DataTree(name="folder1") folder1["results/highres"] = data - assert folder1["results/highres"].ds is data + xrt.assert_identical(folder1["results/highres"].ds, data) def test_setitem_named_dataarray(self): - data = xr.DataArray(name="temp", data=[0, 50]) + da = xr.DataArray(name="temp", data=[0, 50]) 
folder1 = DataTree(name="folder1") - folder1["results"] = data - expected = data.rename("results") + folder1["results"] = da + expected = da.rename("results") xrt.assert_equal(folder1["results"], expected) def test_setitem_unnamed_dataarray(self): @@ -250,16 +355,16 @@ def test_data_in_root(self): assert dt.name is None assert dt.parent is None assert dt.children == {} - assert dt.ds is dat + xrt.assert_identical(dt.ds, dat) def test_one_layer(self): dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"b": 2}) dt = DataTree.from_dict({"run1": dat1, "run2": dat2}) xrt.assert_identical(dt.ds, xr.Dataset()) assert dt.name is None - assert dt["run1"].ds is dat1 + xrt.assert_identical(dt["run1"].ds, dat1) assert dt["run1"].children == {} - assert dt["run2"].ds is dat2 + xrt.assert_identical(dt["run2"].ds, dat2) assert dt["run2"].children == {} def test_two_layers(self): @@ -268,13 +373,13 @@ def test_two_layers(self): assert "highres" in dt.children assert "lowres" in dt.children highres_run = dt["highres/run"] - assert highres_run.ds is dat1 + xrt.assert_identical(highres_run.ds, dat1) def test_nones(self): dt = DataTree.from_dict({"d": None, "d/e": None}) assert [node.name for node in dt.subtree] == [None, "d", "e"] assert [node.path for node in dt.subtree] == ["/", "/d", "/d/e"] - xrt.assert_equal(dt["d/e"].ds, xr.Dataset()) + xrt.assert_identical(dt["d/e"].ds, xr.Dataset()) def test_full(self, simple_datatree): dt = simple_datatree diff --git a/datatree/treenode.py b/datatree/treenode.py index e29bfd66..a2e87675 100644 --- a/datatree/treenode.py +++ b/datatree/treenode.py @@ -137,7 +137,9 @@ def _detach(self, parent: Tree | None) -> None: def _attach(self, parent: Tree | None, child_name: str = None) -> None: if parent is not None: if child_name is None: - raise ValueError("Cannot directly assign a parent to an unnamed node") + raise ValueError( + "To directly set parent, child needs a name, but child is unnamed" + ) self._pre_attach(parent) parentchildren = 
parent._children diff --git a/docs/source/api.rst b/docs/source/api.rst index 5cd16466..9ad74190 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -18,6 +18,8 @@ Creating a DataTree Tree Attributes --------------- +Attributes relating to the recursive tree-like structure of a ``DataTree``. + .. autosummary:: :toctree: generated/ @@ -34,34 +36,40 @@ Tree Attributes DataTree.ancestors DataTree.groups -Data Attributes ---------------- +Data Contents +------------- + +Interface to the data objects (optionally) stored inside a single ``DataTree`` node. +This interface echoes that of ``xarray.Dataset``. .. autosummary:: :toctree: generated/ DataTree.dims - DataTree.variables - DataTree.encoding DataTree.sizes + DataTree.data_vars + DataTree.coords DataTree.attrs + DataTree.encoding DataTree.indexes - DataTree.xindexes - DataTree.coords DataTree.chunks + DataTree.nbytes DataTree.ds + DataTree.to_dataset DataTree.has_data DataTree.has_attrs DataTree.is_empty .. - Missing - DataTree.chunksizes + Missing: + ``DataTree.chunksizes`` Dictionary interface -------------------- +``DataTree`` objects also have a dict-like interface mapping keys to either ``xarray.DataArray``s or to child ``DataTree`` nodes. + .. autosummary:: :toctree: generated/ @@ -70,16 +78,14 @@ Dictionary interface DataTree.__delitem__ DataTree.update DataTree.get - -.. - - Missing DataTree.items DataTree.keys DataTree.values -Tree Manipulation Methods -------------------------- +Tree Manipulation +----------------- + +For manipulating, traversing, navigating, or mapping over the tree structure. .. autosummary:: :toctree: generated/ @@ -89,127 +95,181 @@ Tree Manipulation Methods DataTree.relative_to DataTree.iter_lineage DataTree.find_common_ancestor + map_over_subtree + +DataTree Contents +----------------- -Tree Manipulation Utilities ---------------------------- +Manipulate the contents of all nodes in a tree simultaneously. .. 
autosummary:: :toctree: generated/ - map_over_subtree + DataTree.copy + DataTree.assign + DataTree.assign_coords + DataTree.merge + DataTree.rename + DataTree.rename_vars + DataTree.rename_dims + DataTree.swap_dims + DataTree.expand_dims + DataTree.drop_vars + DataTree.drop_dims + DataTree.set_coords + DataTree.reset_coords -Methods -------- -.. +DataTree Node Contents +---------------------- - TODO divide these up into "Dataset contents", "Indexing", "Computation" etc. +Manipulate the contents of a single DataTree node. + +Comparisons +=========== + +Compare one ``DataTree`` object to another. + +.. autosummary:: + :toctree: generated/ + + DataTree.isomorphic + DataTree.equals + DataTree.identical + +Indexing +======== + +Index into all nodes in the subtree simultaneously. .. autosummary:: :toctree: generated/ - DataTree.load - DataTree.compute - DataTree.persist - DataTree.unify_chunks - DataTree.chunk - DataTree.map_blocks - DataTree.copy - DataTree.as_numpy - DataTree.__copy__ - DataTree.__deepcopy__ - DataTree.set_coords - DataTree.reset_coords - DataTree.info DataTree.isel DataTree.sel + DataTree.drop_sel + DataTree.drop_isel DataTree.head DataTree.tail DataTree.thin - DataTree.broadcast_like - DataTree.reindex_like - DataTree.reindex + DataTree.squeeze DataTree.interp DataTree.interp_like - DataTree.rename - DataTree.rename_dims - DataTree.rename_vars - DataTree.swap_dims - DataTree.expand_dims + DataTree.reindex + DataTree.reindex_like DataTree.set_index DataTree.reset_index DataTree.reorder_levels - DataTree.stack - DataTree.unstack - DataTree.update - DataTree.merge - DataTree.drop_vars - DataTree.drop_sel - DataTree.drop_isel - DataTree.drop_dims - DataTree.isomorphic - DataTree.equals - DataTree.identical - DataTree.transpose + DataTree.query + +.. + + Missing: + ``DataTree.loc`` + + +Missing Value Handling +====================== + +.. 
autosummary:: + :toctree: generated/ + + DataTree.isnull + DataTree.notnull + DataTree.combine_first DataTree.dropna DataTree.fillna - DataTree.interpolate_na DataTree.ffill DataTree.bfill - DataTree.combine_first - DataTree.reduce + DataTree.interpolate_na + DataTree.where + DataTree.isin + +Computation +=========== + +Apply a computation to the data in all nodes in the subtree simultaneously. + +.. autosummary:: + :toctree: generated/ + DataTree.map - DataTree.assign + DataTree.reduce DataTree.diff - DataTree.shift - DataTree.roll - DataTree.sortby DataTree.quantile - DataTree.rank DataTree.differentiate DataTree.integrate - DataTree.cumulative_integrate - DataTree.filter_by_attrs + DataTree.map_blocks DataTree.polyfit - DataTree.pad - DataTree.idxmin - DataTree.idxmax - DataTree.argmin - DataTree.argmax - DataTree.query DataTree.curvefit - DataTree.squeeze - DataTree.clip - DataTree.assign_coords - DataTree.where - DataTree.close - DataTree.isnull - DataTree.notnull - DataTree.isin - DataTree.astype -Comparisons +Aggregation =========== +Aggregate data in all nodes in the subtree simultaneously. + .. autosummary:: :toctree: generated/ - testing.assert_isomorphic - testing.assert_equal - testing.assert_identical + DataTree.all + DataTree.any + DataTree.argmax + DataTree.argmin + DataTree.idxmax + DataTree.idxmin + DataTree.max + DataTree.min + DataTree.mean + DataTree.median + DataTree.prod + DataTree.sum + DataTree.std + DataTree.var + DataTree.cumsum + DataTree.cumprod ndarray methods ---------------- +=============== + +Methods copied from ``np.ndarray`` objects, here applying to the data in all nodes in the subtree. .. autosummary:: :toctree: generated/ - DataTree.nbytes - DataTree.real + DataTree.argsort + DataTree.astype + DataTree.clip + DataTree.conj + DataTree.conjugate DataTree.imag + DataTree.round + DataTree.real + DataTree.rank + +Reshaping and reorganising +========================== + +Reshape or reorganise the data in all nodes in the subtree. + +..
autosummary:: + :toctree: generated/ + + DataTree.transpose + DataTree.stack + DataTree.unstack + DataTree.shift + DataTree.roll + DataTree.pad + DataTree.sortby + DataTree.broadcast_like + +Plotting +======== I/O === +Create ``DataTree`` objects from files, or write them to disk. + .. autosummary:: :toctree: generated/ @@ -221,14 +281,46 @@ I/O .. - Missing - open_mfdatatree + Missing: + ``open_mfdatatree`` + +Tutorial +======== + +Testing +======= + +Test that two DataTree objects are similar. + +.. autosummary:: + :toctree: generated/ + + testing.assert_isomorphic + testing.assert_equal + testing.assert_identical Exceptions ========== +Exceptions raised when manipulating trees. + .. autosummary:: :toctree: generated/ - TreeError TreeIsomorphismError + +Advanced API +============ + +Relatively advanced API for users or developers looking to understand the internals, or extend functionality. + +.. autosummary:: + :toctree: generated/ + + DataTree.variables + +.. + + Missing: + ``DataTree.set_close`` + ``register_datatree_accessor`` diff --git a/docs/source/whats-new.rst b/docs/source/whats-new.rst index d46d5b87..e64ff549 100644 --- a/docs/source/whats-new.rst +++ b/docs/source/whats-new.rst @@ -38,9 +38,16 @@ Bug fixes Documentation ~~~~~~~~~~~~~ +- API page updated with all the methods that are copied from ``xarray.Dataset``. (:pull:`41`) + By `Tom Nicholas <https://github.com/TomNicholas>`_. + Internal Changes ~~~~~~~~~~~~~~~~ +- Refactored ``DataTree`` class to store a set of ``xarray.Variable`` objects instead of a single ``xarray.Dataset``. + This approach means that the ``DataTree`` class now effectively copies and extends the internal structure of + ``xarray.Dataset``. (:pull:`41`) + By `Tom Nicholas <https://github.com/TomNicholas>`_. - Made ``testing.test_datatree.create_test_datatree`` into a pytest fixture (:pull:`107`). By `Benjamin Woods <https://github.com/benjaminwoods>`_. diff --git a/requirements.txt b/requirements.txt index cf84c87e..4eb031ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -xarray>=0.20.2 +xarray>=2022.05.0.dev0