Skip to content

Commit 2a34bfb

Browse files
authored
Compatibility with dask 2021.02.0 (#4884)
* Compatibility with dask 2021.02.0
* Rework postpersist and postcompute
1 parent 10f0227 commit 2a34bfb

File tree

3 files changed

+26 -36 lines changed

3 files changed

+26 -36 lines changed

ci/requirements/environment-windows.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ dependencies:
88
# - cdms2 # Not available on Windows
99
# - cfgrib # Causes Python interpreter crash on Windows: https://github.com/pydata/xarray/pull/3340
1010
- cftime
11-
- dask<2021.02.0
11+
- dask
1212
- distributed
1313
- h5netcdf
1414
- h5py=2

ci/requirements/environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ dependencies:
99
- cdms2
1010
- cfgrib
1111
- cftime
12-
- dask<2021.02.0
12+
- dask
1313
- distributed
1414
- h5netcdf
1515
- h5py=2

xarray/core/dataset.py

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -866,79 +866,69 @@ def __dask_postcompute__(self):
866866
import dask
867867

868868
info = [
869-
(True, k, v.__dask_postcompute__())
869+
(k, None) + v.__dask_postcompute__()
870870
if dask.is_dask_collection(v)
871-
else (False, k, v)
871+
else (k, v, None, None)
872872
for k, v in self._variables.items()
873873
]
874-
args = (
875-
info,
874+
construct_direct_args = (
876875
self._coord_names,
877876
self._dims,
878877
self._attrs,
879878
self._indexes,
880879
self._encoding,
881880
self._close,
882881
)
883-
return self._dask_postcompute, args
882+
return self._dask_postcompute, (info, construct_direct_args)
884883

885884
def __dask_postpersist__(self):
886885
import dask
887886

888887
info = [
889-
(True, k, v.__dask_postpersist__())
888+
(k, None, v.__dask_keys__()) + v.__dask_postpersist__()
890889
if dask.is_dask_collection(v)
891-
else (False, k, v)
890+
else (k, v, None, None, None)
892891
for k, v in self._variables.items()
893892
]
894-
args = (
895-
info,
893+
construct_direct_args = (
896894
self._coord_names,
897895
self._dims,
898896
self._attrs,
899897
self._indexes,
900898
self._encoding,
901899
self._close,
902900
)
903-
return self._dask_postpersist, args
901+
return self._dask_postpersist, (info, construct_direct_args)
904902

905903
@staticmethod
906-
def _dask_postcompute(results, info, *args):
904+
def _dask_postcompute(results, info, construct_direct_args):
907905
variables = {}
908-
results2 = list(results[::-1])
909-
for is_dask, k, v in info:
910-
if is_dask:
911-
func, args2 = v
912-
r = results2.pop()
913-
result = func(r, *args2)
906+
results_iter = iter(results)
907+
for k, v, rebuild, rebuild_args in info:
908+
if v is None:
909+
variables[k] = rebuild(next(results_iter), *rebuild_args)
914910
else:
915-
result = v
916-
variables[k] = result
911+
variables[k] = v
917912

918-
final = Dataset._construct_direct(variables, *args)
913+
final = Dataset._construct_direct(variables, *construct_direct_args)
919914
return final
920915

921916
@staticmethod
922-
def _dask_postpersist(dsk, info, *args):
917+
def _dask_postpersist(dsk, info, construct_direct_args):
918+
from dask.optimization import cull
919+
923920
variables = {}
924921
# postpersist is called in both dask.optimize and dask.persist
925922
# When persisting, we want to filter out unrelated keys for
926923
# each Variable's task graph.
927-
is_persist = len(dsk) == len(info)
928-
for is_dask, k, v in info:
929-
if is_dask:
930-
func, args2 = v
931-
if is_persist:
932-
name = args2[1][0]
933-
dsk2 = {k: v for k, v in dsk.items() if k[0] == name}
934-
else:
935-
dsk2 = dsk
936-
result = func(dsk2, *args2)
924+
for k, v, dask_keys, rebuild, rebuild_args in info:
925+
if v is None:
926+
dsk2, _ = cull(dsk, dask_keys)
927+
variables[k] = rebuild(dsk2, *rebuild_args)
937928
else:
938-
result = v
939-
variables[k] = result
929+
variables[k] = v
940930

941-
return Dataset._construct_direct(variables, *args)
931+
return Dataset._construct_direct(variables, *construct_direct_args)
942932

943933
def compute(self, **kwargs) -> "Dataset":
944934
"""Manually trigger loading and/or computation of this dataset's data

0 commit comments

Comments
 (0)