
Commit d66f673

Disable automatic cache with dask (#1024)
* Disabled auto-caching of dask arrays when pickling and when invoking the .values property. Added new method .compute().
* Minor tweaks
* Simplified Dataset.copy() and Dataset.compute()
* Minor cleanup
* Cleaned up dask test
* Integrate no_dask_resolve with dask_broadcast branches
* Don't chunk coords
* Added performance warning to release notes
* Fix bug that caused a dask array to be computed and then discarded when pickling
* Eagerly cache IndexVariables only: coords that are used to index a dim are eagerly cached; coords that are not in dims are chunked and not cached.
* Load IndexVariable.data into memory in __init__: IndexVariables eagerly load their data into memory (from disk or dask) as soon as they're created.
2 parents 0ed1e2c + 376200a commit d66f673

File tree

8 files changed: +278 / -95 lines changed


doc/whats-new.rst

Lines changed: 11 additions & 5 deletions
@@ -25,6 +25,13 @@ Breaking changes
   merges will now succeed in cases that previously raised
   ``xarray.MergeError``. Set ``compat='broadcast_equals'`` to restore the
   previous default.
+- Pickling an xarray object based on the dask backend, or reading its
+  :py:meth:`values` property, won't automatically convert the array from dask
+  to numpy in the original object anymore.
+  If a dask object is used as a coord of a :py:class:`~xarray.DataArray` or
+  :py:class:`~xarray.Dataset`, its values are eagerly computed and cached,
+  but only if it's used to index a dim (e.g. it's used for alignment).
+  By `Guido Imperiale <https://github.com/crusaderky>`_.
 
 Deprecations
 ~~~~~~~~~~~~
@@ -52,32 +59,31 @@ Enhancements
 - Add checking of ``attr`` names and values when saving to netCDF, raising useful
   error messages if they are invalid. (:issue:`911`).
   By `Robin Wilson <https://github.com/robintw>`_.
-
 - Added ability to save ``DataArray`` objects directly to netCDF files using
   :py:meth:`~xarray.DataArray.to_netcdf`, and to load directly from netCDF files
   using :py:func:`~xarray.open_dataarray` (:issue:`915`). These remove the need
   to convert a ``DataArray`` to a ``Dataset`` before saving as a netCDF file,
   and deal with names to ensure a perfect 'roundtrip' capability.
   By `Robin Wilson <https://github.com/robintw>`_.
-
 - Multi-index levels are now accessible as "virtual" coordinate variables,
   e.g., ``ds['time']`` can pull out the ``'time'`` level of a multi-index
   (see :ref:`coordinates`). ``sel`` also accepts providing multi-index levels
   as keyword arguments, e.g., ``ds.sel(time='2000-01')``
   (see :ref:`multi-level indexing`).
   By `Benoit Bovy <https://github.com/benbovy>`_.
-
 - Added the ``compat`` option ``'no_conflicts'`` to ``merge``, allowing the
   combination of xarray objects with disjoint (:issue:`742`) or
   overlapping (:issue:`835`) coordinates as long as all present data agrees.
   By `Johnnie Gray <https://github.com/jcmgray>`_. See
   :ref:`combining.no_conflicts` for more details.
-
 - It is now possible to set ``concat_dim=None`` explicitly in
   :py:func:`~xarray.open_mfdataset` to disable inferring a dimension along
   which to concatenate.
   By `Stephan Hoyer <https://github.com/shoyer>`_.
-
+- Added methods :py:meth:`DataArray.compute`, :py:meth:`Dataset.compute`, and
+  :py:meth:`Variable.compute` as a non-mutating alternative to
+  :py:meth:`~DataArray.load`.
+  By `Guido Imperiale <https://github.com/crusaderky>`_.
 - Adds DataArray and Dataset methods :py:meth:`~xarray.DataArray.cumsum` and
   :py:meth:`~xarray.DataArray.cumprod`. By `Phillip J. Wolfram
   <https://github.com/pwolfram>`_.

xarray/core/dataarray.py

Lines changed: 13 additions & 0 deletions
@@ -570,6 +570,19 @@ def load(self):
         self._coords = new._coords
         return self
 
+    def compute(self):
+        """Manually trigger loading of this array's data from disk or a
+        remote source into memory and return a new array. The original is
+        left unaltered.
+
+        Normally, it should not be necessary to call this method in user code,
+        because all xarray functions should either work on deferred data or
+        load data automatically. However, this method can be necessary when
+        working with many file objects on disk.
+        """
+        new = self.copy(deep=False)
+        return new.load()
+
     def copy(self, deep=True):
         """Returns a copy of this array.
 
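As a sketch of how the new method differs from load() (illustrative example, assuming a dask-backed array):

    import xarray as xr

    da = xr.DataArray([1, 2, 3], dims='x').chunk()

    computed = da.compute()  # new in-memory array; da stays dask-backed
    assert computed._in_memory
    assert not da._in_memory

    da.load()                # load() computes in place and returns self
    assert da._in_memory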
xarray/core/dataset.py

Lines changed: 27 additions & 14 deletions
@@ -260,8 +260,11 @@ def load_store(cls, store, decoder=None):
         return obj
 
     def __getstate__(self):
-        """Always load data in-memory before pickling"""
-        self.load()
+        """Load data in-memory before pickling (except for Dask data)"""
+        for v in self.variables.values():
+            if not isinstance(v.data, dask_array_type):
+                v.load()
+
         # self.__dict__ is the default pickle object, we don't need to
         # implement our own __setstate__ method to make pickle work
         state = self.__dict__.copy()
@@ -342,6 +345,19 @@ def load(self):
 
         return self
 
+    def compute(self):
+        """Manually trigger loading of this dataset's data from disk or a
+        remote source into memory and return a new dataset. The original is
+        left unaltered.
+
+        Normally, it should not be necessary to call this method in user code,
+        because all xarray functions should either work on deferred data or
+        load data automatically. However, this method can be necessary when
+        working with many file objects on disk.
+        """
+        new = self.copy(deep=False)
+        return new.load()
+
     @classmethod
     def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,
                           file_obj=None):
@@ -424,14 +440,12 @@ def copy(self, deep=False):
         """Returns a copy of this dataset.
 
         If `deep=True`, a deep copy is made of each of the component variables.
-        Otherwise, a shallow copy is made, so each variable in the new dataset
-        is also a variable in the original dataset.
+        Otherwise, a shallow copy of each of the component variables is made,
+        so that the underlying memory region of the new dataset is the same as
+        in the original dataset.
         """
-        if deep:
-            variables = OrderedDict((k, v.copy(deep=True))
-                                    for k, v in iteritems(self._variables))
-        else:
-            variables = self._variables.copy()
+        variables = OrderedDict((k, v.copy(deep=deep))
+                                for k, v in iteritems(self._variables))
         # skip __init__ to avoid costly validation
         return self._construct_direct(variables, self._coord_names.copy(),
                                       self._dims.copy(), self._attrs_copy())
@@ -817,11 +831,10 @@ def chunks(self):
         chunks = {}
         for v in self.variables.values():
             if v.chunks is not None:
-                new_chunks = list(zip(v.dims, v.chunks))
-                if any(chunk != chunks[d] for d, chunk in new_chunks
-                       if d in chunks):
-                    raise ValueError('inconsistent chunks')
-                chunks.update(new_chunks)
+                for dim, c in zip(v.dims, v.chunks):
+                    if dim in chunks and c != chunks[dim]:
+                        raise ValueError('inconsistent chunks')
+                    chunks[dim] = c
         return Frozen(SortedKeysDict(chunks))
 
     def chunk(self, chunks=None, name_prefix='xarray-', token=None,
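
A small sketch of the contract the refactored chunks property preserves: chunk sizes must agree across every variable that shares a dim (variable names here are hypothetical):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({'a': ('x', np.arange(10)),
                     'b': ('x', np.arange(10))})
    ds['a'] = ds['a'].chunk(5)
    ds['b'] = ds['b'].chunk(2)

    try:
        # dim 'x' has chunks (5, 5) in 'a' but (2, 2, 2, 2, 2) in 'b'
        ds.chunks
    except ValueError as err:
        print(err)  # inconsistent chunks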

xarray/core/variable.py

Lines changed: 42 additions & 8 deletions
@@ -277,10 +277,21 @@ def data(self, data):
                 "replacement data must match the Variable's shape")
         self._data = data
 
+    def _data_cast(self):
+        if isinstance(self._data, (np.ndarray, PandasIndexAdapter)):
+            return self._data
+        else:
+            return np.asarray(self._data)
+
     def _data_cached(self):
-        if not isinstance(self._data, (np.ndarray, PandasIndexAdapter)):
-            self._data = np.asarray(self._data)
-        return self._data
+        """Load data into memory and return it.
+        Do not cache dask arrays automatically; that should
+        require an explicit load() call.
+        """
+        new_data = self._data_cast()
+        if not isinstance(self._data, dask_array_type):
+            self._data = new_data
+        return new_data
 
     @property
     def _indexable_data(self):
@@ -294,12 +305,26 @@ def load(self):
         because all xarray functions should either work on deferred data or
         load data automatically.
         """
-        self._data_cached()
+        self._data = self._data_cast()
         return self
 
+    def compute(self):
+        """Manually trigger loading of this variable's data from disk or a
+        remote source into memory and return a new variable. The original is
+        left unaltered.
+
+        Normally, it should not be necessary to call this method in user code,
+        because all xarray functions should either work on deferred data or
+        load data automatically.
+        """
+        new = self.copy(deep=False)
+        return new.load()
+
     def __getstate__(self):
-        """Always cache data as an in-memory array before pickling"""
-        self._data_cached()
+        """Always cache data as an in-memory array before pickling
+        (with the exception of dask backend)"""
+        if not isinstance(self._data, dask_array_type):
+            self._data_cached()
         # self.__dict__ is the default pickle object, we don't need to
         # implement our own __setstate__ method to make pickle work
         return self.__dict__
@@ -1075,10 +1100,19 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
             raise ValueError('%s objects must be 1-dimensional' %
                              type(self).__name__)
 
-    def _data_cached(self):
+        # Unlike in Variable, always eagerly load values into memory
         if not isinstance(self._data, PandasIndexAdapter):
             self._data = PandasIndexAdapter(self._data)
-        return self._data
+
+    @Variable.data.setter
+    def data(self, data):
+        Variable.data.fset(self, data)
+        if not isinstance(self._data, PandasIndexAdapter):
+            self._data = PandasIndexAdapter(self._data)
+
+    def chunk(self, chunks=None, name=None, lock=False):
+        # Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk()
+        return self.copy(deep=False)
 
     def __getitem__(self, key):
         key = self._item_key_to_tuple(key)
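
A sketch of the IndexVariable behavior introduced here (illustrative names; the exact chunk layout depends on the data):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({'foo': ('x', np.arange(6))},
                    coords={'x': np.arange(6), 'y': ('x', np.arange(6.0))})
    chunked = ds.chunk({'x': 2})

    # 'x' indexes a dim, so it is an IndexVariable: chunk() is a no-op
    # for it and its values stay eagerly loaded in memory
    print(chunked['x'].chunks)  # None
    # the non-index coord 'y' and the data variable are chunked as usual
    print(chunked['y'].chunks)  # ((2, 2, 2),)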

xarray/test/test_backends.py

Lines changed: 60 additions & 15 deletions
@@ -128,8 +128,12 @@ def assert_loads(vars=None):
             if vars is None:
                 vars = expected
             with self.roundtrip(expected) as actual:
-                for v in actual.variables.values():
-                    self.assertFalse(v._in_memory)
+                for k, v in actual.variables.items():
+                    # IndexVariables are eagerly loaded into memory
+                    if k in actual.dims:
+                        self.assertTrue(v._in_memory)
+                    else:
+                        self.assertFalse(v._in_memory)
                 yield actual
                 for k, v in actual.variables.items():
                     if k in vars:
@@ -152,6 +156,31 @@ def assert_loads(vars=None):
         actual = ds.load()
         self.assertDatasetAllClose(expected, actual)
 
+    def test_dataset_compute(self):
+        expected = create_test_data()
+
+        with self.roundtrip(expected) as actual:
+            # Test Dataset.compute()
+            for k, v in actual.variables.items():
+                # IndexVariables are eagerly cached
+                if k in actual.dims:
+                    self.assertTrue(v._in_memory)
+                else:
+                    self.assertFalse(v._in_memory)
+
+            computed = actual.compute()
+
+            for k, v in actual.variables.items():
+                if k in actual.dims:
+                    self.assertTrue(v._in_memory)
+                else:
+                    self.assertFalse(v._in_memory)
+            for v in computed.variables.values():
+                self.assertTrue(v._in_memory)
+
+            self.assertDatasetAllClose(expected, actual)
+            self.assertDatasetAllClose(expected, computed)
+
     def test_roundtrip_None_variable(self):
         expected = Dataset({None: (('x', 'y'), [[0, 1], [2, 3]])})
         with self.roundtrip(expected) as actual:
@@ -233,18 +262,6 @@ def test_roundtrip_coordinates(self):
         with self.roundtrip(expected) as actual:
             self.assertDatasetIdentical(expected, actual)
 
-        expected = original.copy()
-        expected.attrs['coordinates'] = 'something random'
-        with self.assertRaisesRegexp(ValueError, 'cannot serialize'):
-            with self.roundtrip(expected):
-                pass
-
-        expected = original.copy(deep=True)
-        expected['foo'].attrs['coordinates'] = 'something random'
-        with self.assertRaisesRegexp(ValueError, 'cannot serialize'):
-            with self.roundtrip(expected):
-                pass
-
     def test_roundtrip_boolean_dtype(self):
         original = create_boolean_data()
         self.assertEqual(original['x'].dtype, 'bool')
@@ -875,7 +892,26 @@ def test_read_byte_attrs_as_unicode(self):
 @requires_dask
 @requires_scipy
 @requires_netCDF4
-class DaskTest(TestCase):
+class DaskTest(TestCase, DatasetIOTestCases):
+    @contextlib.contextmanager
+    def create_store(self):
+        yield Dataset()
+
+    @contextlib.contextmanager
+    def roundtrip(self, data, save_kwargs={}, open_kwargs={}):
+        yield data.chunk()
+
+    def test_roundtrip_datetime_data(self):
+        # Override method in DatasetIOTestCases - remove not applicable save_kwds
+        times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT'])
+        expected = Dataset({'t': ('t', times), 't0': times[0]})
+        with self.roundtrip(expected) as actual:
+            self.assertDatasetIdentical(expected, actual)
+
+    def test_write_store(self):
+        # Override method in DatasetIOTestCases - not applicable to dask
+        pass
+
     def test_open_mfdataset(self):
         original = Dataset({'foo': ('x', np.random.randn(10))})
         with create_tmp_file() as tmp1:
@@ -995,6 +1031,15 @@ def test_deterministic_names(self):
                 self.assertIn(tmp, dask_name)
             self.assertEqual(original_names, repeat_names)
 
+    def test_dataarray_compute(self):
+        # Test DataArray.compute() on dask backend.
+        # The test for Dataset.compute() is already in DatasetIOTestCases;
+        # however dask is the only tested backend which supports DataArrays
+        actual = DataArray([1, 2]).chunk()
+        computed = actual.compute()
+        self.assertFalse(actual._in_memory)
+        self.assertTrue(computed._in_memory)
+        self.assertDataArrayAllClose(actual, computed)
 
 @requires_scipy_or_netCDF4
 @requires_pydap
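
For completeness, a sketch of the pickling behavior the new __getstate__ methods produce (not part of the diff; assumes the dask graph itself is picklable, as it is for arrays built from in-memory data):

    import pickle

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({'foo': ('x', np.arange(100))}).chunk({'x': 10})
    restored = pickle.loads(pickle.dumps(ds))

    # Neither the original nor the roundtripped dataset was computed:
    # both stay lazily backed by dask
    print(ds['foo'].chunks)        # ((10, 10, ..., 10),) - still chunked
    print(restored['foo'].chunks)  # still chunked too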
