Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Commit 84f472a

Browse files
Joe Hamman and TomNicholas
authored
[WIP] add DataTree.to_netcdf (#26)
* first attempt at to_netcdf * lint * add test for roundtrip and support empty nodes * Apply suggestions from code review Co-authored-by: Tom Nicholas <[email protected]> * update roundtrip test, improves empty node handling in IO Co-authored-by: Tom Nicholas <[email protected]>
1 parent 35fa1e3 commit 84f472a

File tree

3 files changed

+135
-7
lines changed

3 files changed

+135
-7
lines changed

datatree/datatree.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -873,10 +873,44 @@ def groups(self):
873873
"""Return all netCDF4 groups in the tree, given as a tuple of path-like strings."""
874874
return tuple(node.pathstr for node in self.subtree_nodes)
875875

876-
def to_netcdf(self, filename: str):
876+
def to_netcdf(
    self, filepath, mode: str = "w", encoding=None, unlimited_dims=None, **kwargs
):
    """
    Write datatree contents to a netCDF file.

    Parameters
    ----------
    filepath : str or Path
        Path to which to save this datatree.
    mode : {"w", "a"}, default: "w"
        Write ('w') or append ('a') mode. If mode='w', any existing file at
        this location will be overwritten. If mode='a', existing variables
        will be overwritten. Only applies to the root group.
    encoding : dict, optional
        Nested dictionary with variable names as keys and dictionaries of
        variable specific encodings as values, e.g.,
        ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1,
        "zlib": True}, ...}, ...}``. See ``xarray.Dataset.to_netcdf`` for available
        options.
    unlimited_dims : dict, optional
        Mapping of unlimited dimensions per group that should be serialized
        as unlimited dimensions. By default, no dimensions are treated as
        unlimited dimensions. Note that unlimited_dims may also be set via
        ``dataset.encoding["unlimited_dims"]``.
    kwargs :
        Additional keyword arguments to be passed to ``xarray.Dataset.to_netcdf``
    """
    # Imported lazily, as in the original; presumably to avoid a circular
    # import with the io module (which imports from this one).
    from .io import _datatree_to_netcdf

    write_options = dict(
        mode=mode,
        encoding=encoding,
        unlimited_dims=unlimited_dims,
    )
    _datatree_to_netcdf(self, filepath, **write_options, **kwargs)
880914

881915
def plot(self):
882916
raise NotImplementedError

datatree/io.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,20 @@
77
from .datatree import DataNode, DataTree, PathType
88

99

10+
def _ds_or_none(ds):
11+
"""return none if ds is empty"""
12+
if any(ds.coords) or any(ds.variables) or any(ds.attrs):
13+
return ds
14+
return None
15+
16+
1017
def _open_group_children_recursively(filename, node, ncgroup, chunks, **kwargs):
1118
for g in ncgroup.groups.values():
1219

1320
# Open and add this node's dataset to the tree
1421
name = os.path.basename(g.path)
1522
ds = open_dataset(filename, group=g.path, chunks=chunks, **kwargs)
23+
ds = _ds_or_none(ds)
1624
child_node = DataNode(name, ds)
1725
node.add_child(child_node)
1826

@@ -65,5 +73,73 @@ def open_mfdatatree(
6573
return full_tree
6674

6775

68-
def _datatree_to_netcdf(dt: DataTree, filepath: str):
69-
raise NotImplementedError
76+
def _maybe_extract_group_kwargs(enc, group):
77+
try:
78+
return enc[group]
79+
except KeyError:
80+
return None
81+
82+
83+
def _create_empty_group(filename, group, mode):
    """Open ``filename`` with netCDF4 in ``mode`` and create ``group`` in it.

    The file is closed again when the ``with`` block exits.
    """
    with netCDF4.Dataset(filename, mode=mode) as nc_root:
        nc_root.createGroup(group)
86+
87+
88+
def _datatree_to_netcdf(
    dt: DataTree,
    filepath,
    mode: str = "w",
    encoding=None,
    unlimited_dims=None,
    **kwargs
):
    """Write ``dt`` and all of its descendants to one netCDF file, one group per node.

    Parameters
    ----------
    dt : DataTree
        Node to start writing from; it is written first, then ``dt.descendants``.
    filepath : str or Path
        Target file, passed through to ``Dataset.to_netcdf`` / ``netCDF4.Dataset``.
    mode : str, default: "w"
        Mode used for the first node written. Every later node is written
        with ``"a"`` so that it is added to the same file.
    encoding : dict, optional
        Per-group encodings keyed by node ``pathstr``. Groups without an entry
        get ``encoding=None``.
    unlimited_dims : dict, optional
        Per-group unlimited dimensions keyed by node ``pathstr``. Groups
        without an entry get ``unlimited_dims=None``.
    **kwargs
        Forwarded unchanged to ``Dataset.to_netcdf`` for every node that has data.

    Raises
    ------
    ValueError
        If ``format`` is not ``None``/"NETCDF4" or ``engine`` is not
        ``None``/"netcdf4"/"h5netcdf".
    NotImplementedError
        If ``group`` is given or ``compute`` is falsy.
    """
    if kwargs.get("format", None) not in [None, "NETCDF4"]:
        raise ValueError("to_netcdf only supports the NETCDF4 format")

    if kwargs.get("engine", None) not in [None, "netcdf4", "h5netcdf"]:
        raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines")

    if kwargs.get("group", None) is not None:
        raise NotImplementedError(
            "specifying a root group for the tree has not been implemented"
        )

    if not kwargs.get("compute", True):
        raise NotImplementedError("compute=False has not been implemented yet")

    if encoding is None:
        encoding = {}

    if unlimited_dims is None:
        unlimited_dims = {}

    root_path = dt.root.pathstr

    def _relative_group(pathstr):
        # Strip only the leading root path. ``str.replace`` would also remove
        # later occurrences, e.g. for a child group named like the root.
        if pathstr.startswith(root_path):
            return pathstr[len(root_path):]
        return pathstr

    # The starting node and its descendants share the same write logic; only
    # the first write uses the caller's mode, everything after it appends.
    for node in (dt, *dt.descendants):
        ds = node.ds
        group_path = _relative_group(node.pathstr)
        if ds is None:
            _create_empty_group(filepath, group_path, mode)
        else:
            ds.to_netcdf(
                filepath,
                group=group_path,
                mode=mode,
                # Look up settings by this node's own path, not the root's.
                encoding=_maybe_extract_group_kwargs(encoding, node.pathstr),
                unlimited_dims=_maybe_extract_group_kwargs(
                    unlimited_dims, node.pathstr
                ),
                **kwargs
            )
        mode = "a"

datatree/tests/test_datatree.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from xarray.testing import assert_identical
55

66
from datatree import DataNode, DataTree
7+
from datatree.io import open_datatree
78

89

910
def create_test_datatree():
@@ -31,14 +32,14 @@ def create_test_datatree():
3132
| Dimensions: (x: 2, y: 3)
3233
| Data variables:
3334
| a (y) int64 6, 7, 8
34-
| set1 (x) int64 9, 10
35+
| set0 (x) int64 9, 10
3536
3637
The structure has deliberately repeated names of tags, variables, and
3738
dimensions in order to better check for bugs caused by name conflicts.
3839
"""
3940
set1_data = xr.Dataset({"a": 0, "b": 1})
4041
set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})
41-
root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set1": ("x", [9, 10])})
42+
root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
4243

4344
# Avoid using __init__ so we can independently test it
4445
# TODO change so it has a DataTree at the bottom
@@ -297,4 +298,21 @@ def test_repr_of_node_with_data(self):
297298

298299

299300
class TestIO:
    """Round-trip tests: write a DataTree to netCDF and read it back."""

    def test_to_netcdf(self, tmpdir):
        """A tree written with ``to_netcdf`` is recovered by ``open_datatree``."""
        filepath = str(
            tmpdir / "test.nc"
        )  # casting to str avoids a pathlib bug in xarray
        original_dt = create_test_datatree()
        original_dt.to_netcdf(filepath, engine="netcdf4")

        roundtrip_dt = open_datatree(filepath)

        # This comparison was a bare expression before and never checked anything.
        assert original_dt.name == roundtrip_dt.name
        assert original_dt.ds.identical(roundtrip_dt.ds)

        original_nodes = list(original_dt.descendants)
        roundtrip_nodes = list(roundtrip_dt.descendants)
        # zip() stops at the shorter sequence, so check the counts explicitly
        # to catch nodes that were dropped or added by the round trip.
        assert len(original_nodes) == len(roundtrip_nodes)

        for a, b in zip(original_nodes, roundtrip_nodes):
            assert a.name == b.name
            assert a.pathstr == b.pathstr
            if a.has_data:
                assert a.ds.identical(b.ds)
            else:
                assert a.ds is b.ds

0 commit comments

Comments
 (0)