Skip to content

Commit 56f0e48

Browse files
authored
Add inherit=False option to DataTree.copy() (#9628)
* Add inherit=False option to DataTree.copy() This PR adds a inherit=False option to DataTree.copy, so users can decide if they want to inherit coordinates from parents or not when creating a subtree. The default behavior is `inherit=True`, which is a breaking change from the current behavior where parent coordinates are dropped (which I believe should be considered a bug). * fix typing * add migration guide note * ignore typing error
1 parent c3dabe1 commit 56f0e48

File tree

4 files changed

+40
-43
lines changed

4 files changed

+40
-43
lines changed

DATATREE_MIGRATION_GUIDE.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ A number of other API changes have been made, which should only require minor mo
3636
- The top-level import has changed, from `from datatree import DataTree, open_datatree` to `from xarray import DataTree, open_datatree`. Alternatively you can now just use the `import xarray as xr` namespace convention for everything datatree-related.
3737
- The `DataTree.ds` property has been changed to `DataTree.dataset`, though `DataTree.ds` remains as an alias for `DataTree.dataset`.
3838
- Similarly the `ds` kwarg in the `DataTree.__init__` constructor has been replaced by `dataset`, i.e. use `DataTree(dataset=)` instead of `DataTree(ds=...)`.
39-
- The method `DataTree.to_dataset()` still exists but now has different options for controlling which variables are present on the resulting `Dataset`, e.g. `inherited=True/False`.
39+
- The method `DataTree.to_dataset()` still exists but now has different options for controlling which variables are present on the resulting `Dataset`, e.g. `inherit=True/False`.
40+
- `DataTree.copy()` also has a new `inherit` keyword argument for controlling whether or not coordinates defined on parents are copied (only relevant when copying a non-root node).
4041
- The `DataTree.parent` property is now read-only. To assign a ancestral relationships directly you must instead use the `.children` property on the parent node, which remains settable.
4142
- Similarly the `parent` kwarg has been removed from the `DataTree.__init__` constuctor.
4243
- DataTree objects passed to the `children` kwarg in `DataTree.__init__` are now shallow-copied.

xarray/core/datatree.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -826,19 +826,13 @@ def _replace_node(
826826

827827
self.children = children
828828

829-
def _copy_node(
830-
self: DataTree,
831-
deep: bool = False,
832-
) -> DataTree:
833-
"""Copy just one node of a tree"""
834-
835-
new_node = super()._copy_node()
836-
837-
data = self._to_dataset_view(rebuild_dims=False, inherit=False)
829+
def _copy_node(self, inherit: bool, deep: bool = False) -> Self:
830+
"""Copy just one node of a tree."""
831+
new_node = super()._copy_node(inherit=inherit, deep=deep)
832+
data = self._to_dataset_view(rebuild_dims=False, inherit=inherit)
838833
if deep:
839834
data = data.copy(deep=True)
840835
new_node._set_node_data(data)
841-
842836
return new_node
843837

844838
def get( # type: ignore[override]
@@ -1159,7 +1153,9 @@ def depth(item) -> int:
11591153
new_nodes_along_path=True,
11601154
)
11611155

1162-
return obj
1156+
# TODO: figure out why mypy is raising an error here, likely something
1157+
# to do with the return type of Dataset.copy()
1158+
return obj # type: ignore[return-value]
11631159

11641160
def to_dict(self) -> dict[str, Dataset]:
11651161
"""

xarray/core/treenode.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
TypeVar,
1111
)
1212

13+
from xarray.core.types import Self
1314
from xarray.core.utils import Frozen, is_dict_like
1415

1516
if TYPE_CHECKING:
@@ -238,10 +239,7 @@ def _post_attach_children(self: Tree, children: Mapping[str, Tree]) -> None:
238239
"""Method call after attaching `children`."""
239240
pass
240241

241-
def copy(
242-
self: Tree,
243-
deep: bool = False,
244-
) -> Tree:
242+
def copy(self, *, inherit: bool = True, deep: bool = False) -> Self:
245243
"""
246244
Returns a copy of this subtree.
247245
@@ -254,7 +252,12 @@ def copy(
254252
255253
Parameters
256254
----------
257-
deep : bool, default: False
255+
inherit : bool
256+
Whether inherited coordinates defined on parents of this node should
257+
also be copied onto the new tree. Only relevant if the `parent` of
258+
this node is not yet, and "Inherited coordinates" appear in its
259+
repr.
260+
deep : bool
258261
Whether each component variable is loaded into memory and copied onto
259262
the new object. Default is False.
260263
@@ -269,35 +272,27 @@ def copy(
269272
xarray.Dataset.copy
270273
pandas.DataFrame.copy
271274
"""
272-
return self._copy_subtree(deep=deep)
275+
return self._copy_subtree(inherit=inherit, deep=deep)
273276

274-
def _copy_subtree(
275-
self: Tree,
276-
deep: bool = False,
277-
memo: dict[int, Any] | None = None,
278-
) -> Tree:
277+
def _copy_subtree(self, inherit: bool, deep: bool = False) -> Self:
279278
"""Copy entire subtree recursively."""
280-
281-
new_tree = self._copy_node(deep=deep)
279+
new_tree = self._copy_node(inherit=inherit, deep=deep)
282280
for name, child in self.children.items():
283281
# TODO use `.children[name] = ...` once #9477 is implemented
284-
new_tree._set(name, child._copy_subtree(deep=deep))
285-
282+
new_tree._set(name, child._copy_subtree(inherit=False, deep=deep))
286283
return new_tree
287284

288-
def _copy_node(
289-
self: Tree,
290-
deep: bool = False,
291-
) -> Tree:
285+
def _copy_node(self, inherit: bool, deep: bool = False) -> Self:
292286
"""Copy just one node of a tree"""
293287
new_empty_node = type(self)()
294288
return new_empty_node
295289

296-
def __copy__(self: Tree) -> Tree:
297-
return self._copy_subtree(deep=False)
290+
def __copy__(self) -> Self:
291+
return self._copy_subtree(inherit=True, deep=False)
298292

299-
def __deepcopy__(self: Tree, memo: dict[int, Any] | None = None) -> Tree:
300-
return self._copy_subtree(deep=True, memo=memo)
293+
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
294+
del memo # nodes cannot be reused in a DataTree
295+
return self._copy_subtree(inherit=True, deep=True)
301296

302297
def _iter_parents(self: Tree) -> Iterator[Tree]:
303298
"""Iterate up the tree, starting from the current node's parent."""
@@ -693,17 +688,14 @@ def __str__(self) -> str:
693688
name_repr = repr(self.name) if self.name is not None else ""
694689
return f"NamedNode({name_repr})"
695690

696-
def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None:
691+
def _post_attach(self, parent: Self, name: str) -> None:
697692
"""Ensures child has name attribute corresponding to key under which it has been stored."""
698693
_validate_name(name) # is this check redundant?
699694
self._name = name
700695

701-
def _copy_node(
702-
self: AnyNamedNode,
703-
deep: bool = False,
704-
) -> AnyNamedNode:
696+
def _copy_node(self, inherit: bool, deep: bool = False) -> Self:
705697
"""Copy just one node of a tree"""
706-
new_node = super()._copy_node()
698+
new_node = super()._copy_node(inherit=inherit, deep=deep)
707699
new_node._name = self.name
708700
return new_node
709701

xarray/tests/test_datatree.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,10 +414,18 @@ def test_copy_coord_inheritance(self) -> None:
414414
tree = DataTree.from_dict(
415415
{"/": xr.Dataset(coords={"x": [0, 1]}), "/c": DataTree()}
416416
)
417-
tree2 = tree.copy()
418-
node_ds = tree2.children["c"].to_dataset(inherit=False)
417+
actual = tree.copy()
418+
node_ds = actual.children["c"].to_dataset(inherit=False)
419419
assert_identical(node_ds, xr.Dataset())
420420

421+
actual = tree.children["c"].copy()
422+
expected = DataTree(Dataset(coords={"x": [0, 1]}), name="c")
423+
assert_identical(expected, actual)
424+
425+
actual = tree.children["c"].copy(inherit=False)
426+
expected = DataTree(name="c")
427+
assert_identical(expected, actual)
428+
421429
def test_deepcopy(self, create_test_datatree):
422430
dt = create_test_datatree()
423431

0 commit comments

Comments
 (0)