
Commit e70138b

Recursive tokenization (#3515)

* recursive tokenize
* black
* What's New
* Also test Dataset
* Also test IndexVariable
* Cleanup
* tokenize sparse objects

1 parent b74f80c commit e70138b

6 files changed: +45 -5 lines changed
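Context for the change: dask.base.tokenize only recurses into values it has been taught to normalize; whatever `__dask_tokenize__` returns is otherwise reduced to `str()` before hashing. Each xarray `__dask_tokenize__` previously returned a raw tuple containing numpy-backed variables, so two objects whose arrays merely print identically could collide. Wrapping the tuple in `dask.base.normalize_token` makes tokenization recurse into the nested data. A minimal sketch of the failure mode this guards against, using the same values as the new test below:

import numpy as np
import dask.base

a = np.ones(10000)
b = np.ones(10000)
b[5000] = 2

# numpy truncates the repr of large arrays, so the two print identically
assert str(a) == str(b)

# normalize_token hashes the underlying buffer, so the tokens still differ
assert dask.base.tokenize(a) != dask.base.tokenize(b)

The same pattern is applied in all four hunks that follow: return normalize_token(tuple) instead of the bare tuple.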

doc/whats-new.rst

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ New Features
   for xarray objects. Note that xarray objects with a dask.array backend already used
   deterministic hashing in previous releases; this change implements it when whole
   xarray objects are embedded in a dask graph, e.g. when :meth:`DataArray.map` is
-  invoked. (:issue:`3378`, :pull:`3446`)
+  invoked. (:issue:`3378`, :pull:`3446`, :pull:`3515`)
   By `Deepak Cherian <https://github.com/dcherian>`_ and
   `Guido Imperiale <https://github.com/crusaderky>`_.

xarray/core/dataarray.py

Lines changed: 3 additions & 1 deletion
@@ -755,7 +755,9 @@ def reset_coords(
         return dataset
 
     def __dask_tokenize__(self):
-        return (type(self), self._variable, self._coords, self._name)
+        from dask.base import normalize_token
+
+        return normalize_token((type(self), self._variable, self._coords, self._name))
 
     def __dask_graph__(self):
         return self._to_temp_dataset().__dask_graph__()

xarray/core/dataset.py

Lines changed: 5 additions & 1 deletion
@@ -652,7 +652,11 @@ def load(self, **kwargs) -> "Dataset":
         return self
 
     def __dask_tokenize__(self):
-        return (type(self), self._variables, self._coord_names, self._attrs)
+        from dask.base import normalize_token
+
+        return normalize_token(
+            (type(self), self._variables, self._coord_names, self._attrs)
+        )
 
     def __dask_graph__(self):
         graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}

xarray/core/variable.py

Lines changed: 6 additions & 2 deletions
@@ -393,7 +393,9 @@ def compute(self, **kwargs):
     def __dask_tokenize__(self):
         # Use v.data, instead of v._data, in order to cope with the wrappers
         # around NetCDF and the like
-        return type(self), self._dims, self.data, self._attrs
+        from dask.base import normalize_token
+
+        return normalize_token((type(self), self._dims, self.data, self._attrs))
 
     def __dask_graph__(self):
         if isinstance(self._data, dask_array_type):
@@ -1973,8 +1975,10 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
             self._data = PandasIndexAdapter(self._data)
 
     def __dask_tokenize__(self):
+        from dask.base import normalize_token
+
         # Don't waste time converting pd.Index to np.ndarray
-        return (type(self), self._dims, self._data.array, self._attrs)
+        return normalize_token((type(self), self._dims, self._data.array, self._attrs))
 
     def load(self):
         # data is already loaded into memory for IndexVariable
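The IndexVariable hunk keeps the fast path: it tokenizes the underlying pd.Index (`self._data.array`) directly instead of converting it to an ndarray first. This relies on normalize_token knowing how to hash a pandas Index; a minimal sketch of that behaviour, assuming a dask build with its pandas normalizers registered:

import pandas as pd
import dask.base

# dask registers pandas-aware normalizers, so an Index hashes deterministically
# without an intermediate ndarray copy.
idx = pd.Index([1, 2, 3])
assert dask.base.tokenize(idx) == dask.base.tokenize(pd.Index([1, 2, 3]))
assert dask.base.tokenize(idx) != dask.base.tokenize(pd.Index([1, 2, 4]))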

xarray/tests/test_dask.py

Lines changed: 26 additions & 0 deletions
@@ -1283,6 +1283,32 @@ def test_token_identical(obj, transform):
     )
 
 
+def test_recursive_token():
+    """Test that tokenization is invoked recursively, and doesn't just rely on the
+    output of str()
+    """
+    a = np.ones(10000)
+    b = np.ones(10000)
+    b[5000] = 2
+    assert str(a) == str(b)
+    assert dask.base.tokenize(a) != dask.base.tokenize(b)
+
+    # Test DataArray and Variable
+    da_a = DataArray(a)
+    da_b = DataArray(b)
+    assert dask.base.tokenize(da_a) != dask.base.tokenize(da_b)
+
+    # Test Dataset
+    ds_a = da_a.to_dataset(name="x")
+    ds_b = da_b.to_dataset(name="x")
+    assert dask.base.tokenize(ds_a) != dask.base.tokenize(ds_b)
+
+    # Test IndexVariable
+    da_a = DataArray(a, dims=["x"], coords={"x": a})
+    da_b = DataArray(a, dims=["x"], coords={"x": b})
+    assert dask.base.tokenize(da_a) != dask.base.tokenize(da_b)
+
+
 @requires_scipy_or_netCDF4
 def test_normalize_token_with_backend(map_ds):
     with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp_file:

xarray/tests/test_sparse.py

Lines changed: 4 additions & 0 deletions
@@ -856,6 +856,10 @@ def test_dask_token():
     import dask
 
     s = sparse.COO.from_numpy(np.array([0, 0, 1, 2]))
+
+    # https://github.com/pydata/sparse/issues/300
+    s.__dask_tokenize__ = lambda: dask.base.normalize_token(s.__dict__)
+
     a = DataArray(s)
     t1 = dask.base.tokenize(a)
     t2 = dask.base.tokenize(a)
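The lambda above patches a single COO instance to work around pydata/sparse#300. An alternative, shown here only as a sketch and not part of this commit, would be to register a normalizer for sparse.COO globally through dask's dispatch mechanism:

import dask.base
import sparse

# Hypothetical global registration (sketch): hash every COO by its attributes
# instead of falling back to str().
@dask.base.normalize_token.register(sparse.COO)
def tokenize_coo(x):
    return type(x).__name__, dask.base.normalize_token(x.__dict__)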
