Skip to content

Commit f8b5212

Browse files
committed
ENH: enable H5NetCDFStore to work with already open h5netcdf.File and h5netcdf.Group objects, add test
1 parent 23d76b4 commit f8b5212

File tree

3 files changed

+64
-16
lines changed

3 files changed

+64
-16
lines changed

xarray/backends/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def maybe_decode_store(store, lock=False):
503503
elif engine == "pydap":
504504
store = backends.PydapDataStore.open(filename_or_obj, **backend_kwargs)
505505
elif engine == "h5netcdf":
506-
store = backends.H5NetCDFStore(
506+
store = backends.H5NetCDFStore.open(
507507
filename_or_obj, group=group, lock=lock, **backend_kwargs
508508
)
509509
elif engine == "pynio":
@@ -968,7 +968,7 @@ def open_mfdataset(
968968
WRITEABLE_STORES: Dict[str, Callable] = {
969969
"netcdf4": backends.NetCDF4DataStore.open,
970970
"scipy": backends.ScipyDataStore,
971-
"h5netcdf": backends.H5NetCDFStore,
971+
"h5netcdf": backends.H5NetCDFStore.open,
972972
}
973973

974974

xarray/backends/h5netcdf_.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
from .. import Variable
66
from ..core import indexing
7-
from ..core.utils import FrozenDict
7+
from ..core.utils import FrozenDict, is_remote_uri
88
from .common import WritableCFDataStore
9-
from .file_manager import CachingFileManager
9+
from .file_manager import CachingFileManager, DummyFileManager
1010
from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock
1111
from .netCDF4_ import (
1212
BaseNetCDF4Array,
@@ -69,8 +69,42 @@ class H5NetCDFStore(WritableCFDataStore):
6969
"""Store for reading and writing data via h5netcdf
7070
"""
7171

72+
__slots__ = (
73+
"autoclose",
74+
"format",
75+
"is_remote",
76+
"lock",
77+
"_filename",
78+
"_group",
79+
"_manager",
80+
"_mode",
81+
)
82+
7283
def __init__(
73-
self,
84+
self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False
85+
):
86+
87+
import h5netcdf
88+
89+
if isinstance(manager, (h5netcdf.File, h5netcdf.Group)):
90+
if group is None:
91+
root, group = manager._root, manager.name
92+
else:
93+
root = manager
94+
manager = DummyFileManager(root)
95+
96+
self._manager = manager
97+
self._group = group
98+
self._mode = mode
99+
self.format = None
100+
self._filename = self.ds._root.filename
101+
self.is_remote = is_remote_uri(self._filename)
102+
self.lock = ensure_lock(lock)
103+
self.autoclose = autoclose
104+
105+
@classmethod
106+
def open(
107+
cls,
74108
filename,
75109
mode="r",
76110
format=None,
@@ -86,22 +120,16 @@ def __init__(
86120

87121
kwargs = {"invalid_netcdf": invalid_netcdf}
88122

89-
self._manager = CachingFileManager(
90-
h5netcdf.File, filename, mode=mode, kwargs=kwargs
91-
)
92-
93123
if lock is None:
94124
if mode == "r":
95125
lock = HDF5_LOCK
96126
else:
97127
lock = combine_locks([HDF5_LOCK, get_write_lock(filename)])
98128

99-
self._group = group
100-
self.format = format
101-
self._filename = filename
102-
self._mode = mode
103-
self.lock = ensure_lock(lock)
104-
self.autoclose = autoclose
129+
manager = CachingFileManager(
130+
h5netcdf.File, filename, mode=mode, kwargs=kwargs
131+
)
132+
return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose)
105133

106134
def _acquire(self, needs_lock=True):
107135
with self._manager.acquire_context(needs_lock) as root:

xarray/tests/test_backends.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2182,7 +2182,7 @@ class TestH5NetCDFData(NetCDF4Base):
21822182
@contextlib.contextmanager
21832183
def create_store(self):
21842184
with create_tmp_file() as tmp_file:
2185-
yield backends.H5NetCDFStore(tmp_file, "w")
2185+
yield backends.H5NetCDFStore.open(tmp_file, "w")
21862186

21872187
@pytest.mark.filterwarnings("ignore:complex dtypes are supported by h5py")
21882188
@pytest.mark.parametrize(
@@ -2345,6 +2345,26 @@ def test_dump_encodings_h5py(self):
23452345
assert actual.x.encoding["compression"] == "lzf"
23462346
assert actual.x.encoding["compression_opts"] is None
23472347

2348+
def test_already_open_dataset_group(self):
2349+
import h5netcdf
2350+
with create_tmp_file() as tmp_file:
2351+
with nc4.Dataset(tmp_file, mode="w") as nc:
2352+
group = nc.createGroup("g")
2353+
v = group.createVariable("x", "int")
2354+
v[...] = 42
2355+
2356+
h5 = h5netcdf.File(tmp_file, mode="r")
2357+
store = backends.H5NetCDFStore(h5["g"])
2358+
with open_dataset(store) as ds:
2359+
expected = Dataset({"x": ((), 42)})
2360+
assert_identical(expected, ds)
2361+
2362+
h5 = h5netcdf.File(tmp_file, mode="r")
2363+
store = backends.H5NetCDFStore(h5, group="g")
2364+
with open_dataset(store) as ds:
2365+
expected = Dataset({"x": ((), 42)})
2366+
assert_identical(expected, ds)
2367+
23482368

23492369
@requires_h5netcdf
23502370
class TestH5NetCDFFileObject(TestH5NetCDFData):

0 commit comments

Comments
 (0)