Skip to content

Commit b080349

Browse files
malmans2dcherian
andauthored
Use zarr to validate attrs when writing to zarr (#6636)
Co-authored-by: Deepak Cherian <[email protected]>
1 parent 7bdb0e4 commit b080349

File tree

4 files changed

+30
-4
lines changed

4 files changed

+30
-4
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ Deprecations
9696
Bug fixes
9797
~~~~~~~~~
9898

99+
- :py:meth:`Dataset.to_zarr` now allows to write all attribute types supported by `zarr-python`.
100+
By `Mattia Almansi <https://github.com/malmans2>`_.
99101
- Set ``skipna=None`` for all ``quantile`` methods (e.g. :py:meth:`Dataset.quantile`) and
100102
ensure it skips missing values for float dtypes (consistent with other methods). This should
101103
not change the behavior (:pull:`6303`).

xarray/backends/api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,9 +1561,8 @@ def to_zarr(
15611561
f"'w-', 'a' and 'r+', but mode={mode!r}"
15621562
)
15631563

1564-
# validate Dataset keys, DataArray names, and attr keys/values
1564+
# validate Dataset keys, DataArray names
15651565
_validate_dataset_names(dataset)
1566-
_validate_attrs(dataset)
15671566

15681567
if region is not None:
15691568
_validate_region(dataset, region)

xarray/backends/zarr.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,15 @@ def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim)
320320
)
321321

322322

323+
def _put_attrs(zarr_obj, attrs):
324+
"""Raise a more informative error message for invalid attrs."""
325+
try:
326+
zarr_obj.attrs.put(attrs)
327+
except TypeError as e:
328+
raise TypeError("Invalid attribute in Dataset.attrs.") from e
329+
return zarr_obj
330+
331+
323332
class ZarrStore(AbstractWritableDataStore):
324333
"""Store for reading and writing data via zarr"""
325334

@@ -479,7 +488,7 @@ def set_dimensions(self, variables, unlimited_dims=None):
479488
)
480489

481490
def set_attributes(self, attributes):
482-
self.zarr_group.attrs.put(attributes)
491+
_put_attrs(self.zarr_group, attributes)
483492

484493
def encode_variable(self, variable):
485494
variable = encode_zarr_variable(variable)
@@ -618,7 +627,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
618627
zarr_array = self.zarr_group.create(
619628
name, shape=shape, dtype=dtype, fill_value=fill_value, **encoding
620629
)
621-
zarr_array.attrs.put(encoded_attrs)
630+
zarr_array = _put_attrs(zarr_array, encoded_attrs)
622631

623632
write_region = self._write_region if self._write_region is not None else {}
624633
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}

xarray/tests/test_backends.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2443,6 +2443,22 @@ def test_write_read_select_write(self):
24432443
with self.create_zarr_target() as final_store:
24442444
ds_sel.to_zarr(final_store, mode="w")
24452445

2446+
@pytest.mark.parametrize("obj", [Dataset(), DataArray(name="foo")])
2447+
def test_attributes(self, obj):
2448+
obj = obj.copy()
2449+
2450+
obj.attrs["good"] = {"key": "value"}
2451+
ds = obj if isinstance(obj, Dataset) else obj.to_dataset()
2452+
with self.create_zarr_target() as store_target:
2453+
ds.to_zarr(store_target)
2454+
assert_identical(ds, xr.open_zarr(store_target))
2455+
2456+
obj.attrs["bad"] = DataArray()
2457+
ds = obj if isinstance(obj, Dataset) else obj.to_dataset()
2458+
with self.create_zarr_target() as store_target:
2459+
with pytest.raises(TypeError, match=r"Invalid attribute in Dataset.attrs."):
2460+
ds.to_zarr(store_target)
2461+
24462462

24472463
@requires_zarr
24482464
class TestZarrDictStore(ZarrBase):

0 commit comments

Comments
 (0)