
.reduce() on a DataArray with Dask distributed immediately executes the preceding portions of the computational graph #3161

Closed
@zbarry

Description


MCVE Code Sample

Calling .mean() on a DataArray backed by a Dask array returns a DataArray that still wraps a lazy Dask array, as expected:

import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((500, 500, 500))).chunk((100, 100, 100)).mean('dim_0')
da
<xarray.DataArray (dim_1: 500, dim_2: 500)>
dask.array<shape=(500, 500), dtype=float64, chunksize=(100, 100)>
Dimensions without coordinates: dim_1, dim_2

Calling .compute() on this result evaluates the graph and returns an in-memory result, as expected:

da = xr.DataArray(np.zeros((500, 500, 500))).chunk((100, 100, 100)).mean('dim_0').compute()
da
<xarray.DataArray (dim_1: 500, dim_2: 500)>
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])
Dimensions without coordinates: dim_1, dim_2

By contrast, .reduce() immediately executes all of the previously queued computations in the graph, before the supplied function is even called. For example, with an identity function:

# Identity function: returns its input unchanged.
def func(x, axis=None):
    return x

da = xr.DataArray(np.zeros((500, 500, 500))).chunk((100, 100, 100)).mean('dim_0').reduce(func)
da
<xarray.DataArray (dim_1: 500, dim_2: 500)>
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])
Dimensions without coordinates: dim_1, dim_2

Expected Output

A DataArray that still wraps a lazy Dask array, since .reduce(func) was not followed by .compute().

Problem Description

When using Dask distributed, calling .reduce() immediately executes the computational graph built so far, instead of adding the supplied function as another node in the DAG. This execution happens before the function passed to .reduce() is called.
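To make the difference between the expected and the reported behavior concrete, here is a toy model of a deferred task graph in plain Python. It is only an illustration under stated assumptions: `Node`, `reduce_lazy`, `reduce_eager`, and `source` are hypothetical names, not xarray or Dask APIs, and the three-element list merely stands in for the chunked `mean('dim_0')` result. The shared `log` list records which steps have actually executed.

```python
from typing import Any, Callable, List


class Node:
    """Toy deferred-computation node (hypothetical; not Dask's or xarray's API)."""

    def __init__(self, thunk: Callable[[], Any], log: List[str]) -> None:
        self._thunk = thunk  # zero-argument function that produces this node's value
        self.log = log       # shared record of the steps that have actually executed

    def compute(self) -> Any:
        """Execute this node and everything upstream of it."""
        return self._thunk()

    def reduce_lazy(self, func: Callable[..., Any]) -> "Node":
        """Expected behavior: add func as a new node; nothing runs yet."""
        def thunk() -> Any:
            value = self._thunk()       # upstream work runs only at compute time
            self.log.append("func")
            return func(value)
        return Node(thunk, self.log)

    def reduce_eager(self, func: Callable[..., Any]) -> Any:
        """Reported behavior: run the upstream graph now, then call func."""
        value = self._thunk()           # upstream work runs immediately
        self.log.append("func")
        return func(value)


def source(log: List[str]) -> Node:
    """Hypothetical stand-in for the chunked mean('dim_0') result."""
    def thunk() -> List[float]:
        log.append("upstream")
        return [0.0, 0.0, 0.0]
    return Node(thunk, log)


def func(x: Any, axis: Any = None) -> Any:
    # Identity function with the same signature as in the report.
    return x


lazy_log: List[str] = []
lazy = source(lazy_log).reduce_lazy(func)
print(lazy_log)        # [] : nothing has executed yet
print(lazy.compute())  # [0.0, 0.0, 0.0]
print(lazy_log)        # ['upstream', 'func']

eager_log: List[str] = []
source(eager_log).reduce_eager(func)
print(eager_log)       # ['upstream', 'func'] : executed at call time
```

In the lazy variant nothing runs until compute() is called; in the eager variant the upstream work runs as soon as the reduction is requested, which mirrors what this report describes for .reduce().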

Output of xr.show_versions()

INSTALLED VERSIONS
------------------
commit: None
python: 3.7.3 | packaged by conda-forge | (default, Jul 1 2019, 21:52:21) [GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 3.10.0-693.21.1.el7.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.4
libnetcdf: 4.6.2

xarray: 0.12.3
pandas: 0.25.0
numpy: 1.16.4
scipy: 1.3.0
netCDF4: 1.5.1.2
pydap: None
h5netcdf: None
h5py: 2.9.0
Nio: None
zarr: None
cftime: 1.0.3.4
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: None
dask: 2.1.0
distributed: 2.1.0
matplotlib: 3.1.1
cartopy: None
seaborn: None
numbagg: None
setuptools: 41.0.1
pip: 19.2.1
conda: None
pytest: 5.0.1
IPython: 7.6.1
sphinx: 2.1.2
