Commit 4ab6a66

Implement __dask_tokenize__
Parent: 652dd3c

File tree: 6 files changed, +81 −0 lines


doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ Internal Changes

 - Use Python 3.6 idioms throughout the codebase. (:pull:`3419`)
   By `Maximilian Roos <https://github.com/max-sixty>`_
+- Implement :py:func:`__dask_tokenize__` for xarray objects.
+  By `Deepak Cherian <https://github.com/dcherian>`_

 .. _whats-new.0.14.0:
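For context: dask.base.tokenize builds a deterministic hash for arbitrary Python objects, and a class can opt into the protocol by defining __dask_tokenize__; dask then normalizes that method's return value instead of falling back to slower, less reliable strategies such as pickling. A minimal sketch of the protocol, independent of xarray (the Point class is purely illustrative):

import dask.base


class Point:
    """Illustrative class opting in to dask's tokenization protocol."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __dask_tokenize__(self):
        # dask hashes whatever this returns, so include everything
        # that should influence the token: the type and both fields.
        return (type(self).__name__, self.x, self.y)


assert dask.base.tokenize(Point(1, 2)) == dask.base.tokenize(Point(1, 2))
assert dask.base.tokenize(Point(1, 2)) != dask.base.tokenize(Point(1, 3))

This is exactly the shape of the implementations below: each xarray object returns a tuple of its type and the parts that define it.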

xarray/core/dataarray.py

Lines changed: 3 additions & 0 deletions
@@ -754,6 +754,9 @@ def reset_coords(
         dataset[self.name] = self.variable
         return dataset

+    def __dask_tokenize__(self):
+        return (DataArray, self._variable, self._coords, self._name)
+
     def __dask_graph__(self):
         return self._to_temp_dataset().__dask_graph__()

xarray/core/dataset.py

Lines changed: 3 additions & 0 deletions
@@ -648,6 +648,9 @@ def load(self, **kwargs) -> "Dataset":

         return self

+    def __dask_tokenize__(self):
+        return (Dataset, self._variables, self._coord_names, self._attrs)
+
     def __dask_graph__(self):
         graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}
         graphs = {k: v for k, v in graphs.items() if v is not None}
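Both containers tokenize from their constituent parts, so equal objects hash equally while any change to variables, coordinates, or the name produces a new token. A quick check mirroring the tests added below (assumes numpy, dask, and an xarray that includes this commit):

import numpy as np
import xarray as xr
import dask.base

da = xr.DataArray(np.arange(4), dims="x", name="a")
# a deep copy carries an identical variable, coords, and name ...
assert dask.base.tokenize(da) == dask.base.tokenize(da.copy(deep=True))
# ... while renaming changes _name and therefore the token
assert dask.base.tokenize(da) != dask.base.tokenize(da.rename("b"))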

xarray/core/variable.py

Lines changed: 6 additions & 0 deletions
@@ -389,6 +389,9 @@ def compute(self, **kwargs):
         new = self.copy(deep=False)
         return new.load(**kwargs)

+    def __dask_tokenize__(self):
+        return Variable, self._dims, self.data, self._attrs
+
     def __dask_graph__(self):
         if isinstance(self._data, dask_array_type):
             return self._data.__dask_graph__()

@@ -1961,6 +1964,9 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
         if not isinstance(self._data, PandasIndexAdapter):
             self._data = PandasIndexAdapter(self._data)

+    def __dask_tokenize__(self):
+        return (IndexVariable, self._dims, self._data.array, self._attrs)
+
     def load(self):
         # data is already loaded into memory for IndexVariable
         return self
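Note the asymmetry: Variable feeds self.data into its token, while IndexVariable uses self._data.array, the wrapped pandas index that always lives in memory. Either way a dask-backed array contributes its graph name rather than its values, so tokenization stays cheap and never triggers a compute. A small sketch of that property (assumes numpy and dask are installed):

import numpy as np
import xarray as xr
import dask.base

v = xr.Variable("x", np.arange(1000)).chunk(100)  # dask-backed Variable
t = dask.base.tokenize(v)
assert t == dask.base.tokenize(v)      # deterministic, nothing computed
assert t != dask.base.tokenize(v * 2)  # new dask graph, hence a new token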

xarray/tests/test_dask.py

Lines changed: 55 additions & 0 deletions
@@ -22,6 +22,7 @@
     assert_identical,
     raises_regex,
 )
+from .test_backends import create_tmp_file

 dask = pytest.importorskip("dask")
 da = pytest.importorskip("dask.array")

@@ -1135,3 +1136,57 @@ def test_make_meta(map_ds):
     for variable in map_ds.data_vars:
         assert variable in meta.data_vars
         assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim
+
+
+@pytest.mark.parametrize("obj", [make_da(), make_ds()])
+@pytest.mark.parametrize(
+    "transform",
+    [
+        lambda x: x.reset_coords(),
+        lambda x: x.reset_coords(drop=True),
+        lambda x: x.isel(x=1),
+        lambda x: x.attrs.update(new_attrs=1),
+        lambda x: x.assign_coords(cxy=1),
+        lambda x: x.rename({"x": "xnew"}),
+        lambda x: x.rename({"cxy": "cxynew"}),
+    ],
+)
+def test_normalize_token_not_identical(obj, transform):
+    with raise_if_dask_computes():
+        assert not dask.base.tokenize(obj) == dask.base.tokenize(transform(obj))
+        assert not dask.base.tokenize(obj.compute()) == dask.base.tokenize(
+            transform(obj.compute())
+        )
+
+
+@pytest.mark.parametrize("transform", [lambda x: x, lambda x: x.compute()])
+def test_normalize_differently_when_data_changes(transform):
+    obj = transform(make_ds())
+    new = obj.copy(deep=True)
+    new["a"] *= 2
+    with raise_if_dask_computes():
+        assert not dask.base.tokenize(obj) == dask.base.tokenize(new)
+
+    obj = transform(make_da())
+    new = obj.copy(deep=True)
+    new *= 2
+    with raise_if_dask_computes():
+        assert not dask.base.tokenize(obj) == dask.base.tokenize(new)
+
+
+@pytest.mark.parametrize(
+    "transform", [lambda x: x, lambda x: x.copy(), lambda x: x.copy(deep=True)]
+)
+@pytest.mark.parametrize(
+    "obj", [make_da(), make_ds(), make_da().indexes["x"], make_ds().variables["a"]]
+)
+def test_normalize_token_identical(obj, transform):
+    with raise_if_dask_computes():
+        assert dask.base.tokenize(obj) == dask.base.tokenize(transform(obj))
+
+
+def test_normalize_token_netcdf_backend(map_ds):
+    with create_tmp_file() as tmp_file:
+        map_ds.to_netcdf(tmp_file)
+        read = xr.open_dataset(tmp_file)
+        assert not dask.base.tokenize(map_ds) == dask.base.tokenize(read)
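Two details worth noting: raise_if_dask_computes, a helper defined earlier in this test module, turns any accidental computation during tokenization into a test failure, pinning down the guarantee that tokens come from metadata and graph structure rather than evaluated values. The final test documents that a dataset read back through the netCDF backend does not tokenize identically to the in-memory original it was written from, presumably because the backend wraps its data in lazy arrays that cannot be content-hashed without reading them.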

xarray/tests/test_sparse.py

Lines changed: 12 additions & 0 deletions
@@ -22,6 +22,7 @@
 )

 sparse = pytest.importorskip("sparse")
+dask = pytest.importorskip("dask")


 def assert_sparse_equal(a, b):

@@ -849,3 +850,14 @@ def test_chunk():
     dsc = ds.chunk(2)
     assert dsc.chunks == {"dim_0": (2, 2)}
     assert_identical(dsc, ds)
+
+
+def test_normalize_token():
+    s = sparse.COO.from_numpy(np.array([0, 0, 1, 2]))
+    a = DataArray(s)
+    dask.base.tokenize(a)
+    assert isinstance(a.data, sparse.COO)
+
+    ac = a.chunk(2)
+    dask.base.tokenize(ac)
+    assert isinstance(ac.data._meta, sparse.COO)
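The sparse checks guard against a subtle failure mode: a naive fallback could densify the array just to hash it. Asserting that a.data is still a sparse.COO after tokenizing, and that the chunked variant's _meta stays sparse, confirms tokenization leaves the underlying representation untouched.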
