diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 9ba9a7e9c0f..b957ed4bab2 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -21,11 +21,14 @@ v0.18.3 (unreleased)
 
 New Features
 ~~~~~~~~~~~~
-
 - Allow assigning values to a subset of a dataset using positional or label-based
   indexing (:issue:`3015`, :pull:`5362`).
   By `Matthias Göbel <https://github.com/matzegoebel>`_.
 - Attempting to reduce a weighted object over missing dimensions now raises an error (:pull:`5362`).
   By `Mattia Almansi <https://github.com/malmans2>`_.
+- :py:func:`xarray.cov` and :py:func:`xarray.corr` now lazily check for missing
+  values if inputs are dask arrays (:issue:`4804`, :pull:`5284`).
+  By `Andrew Williams <https://github.com/AndrewWilliams3142>`_.
+
 Breaking changes
 ~~~~~~~~~~~~~~~~
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index 6010d502c23..07b7bca27c5 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -1352,12 +1352,23 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None):
 
     # 2. Ignore the nans
     valid_values = da_a.notnull() & da_b.notnull()
+    valid_count = valid_values.sum(dim) - ddof
 
-    if not valid_values.all():
-        da_a = da_a.where(valid_values)
-        da_b = da_b.where(valid_values)
+    def _get_valid_values(da, other):
+        """
+        Function to lazily mask da_a and da_b
+        following a similar approach to
+        https://github.com/pydata/xarray/pull/4559
+        """
+        missing_vals = np.logical_or(da.isnull(), other.isnull())
+        if missing_vals.any():
+            da = da.where(~missing_vals)
+            return da
+        else:
+            return da
 
-    valid_count = valid_values.sum(dim) - ddof
+    da_a = da_a.map_blocks(_get_valid_values, args=[da_b])
+    da_b = da_b.map_blocks(_get_valid_values, args=[da_a])
 
     # 3. Detrend along the given dim
     demeaned_da_a = da_a - da_a.mean(dim=dim)
diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index cbfa61a4482..b7ae1ca9828 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -22,7 +22,7 @@
     unified_dim_sizes,
 )
 
-from . import has_dask, requires_dask
+from . import has_dask, raise_if_dask_computes, requires_dask
 
 dask = pytest.importorskip("dask")
 
@@ -1386,6 +1386,7 @@ def arrays_w_tuples():
         da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(),
         xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]),
         xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]),
+        xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]),
     ]
 
     array_tuples = [
@@ -1394,12 +1395,40 @@ def arrays_w_tuples():
         (arrays[1], arrays[1]),
         (arrays[2], arrays[2]),
         (arrays[2], arrays[3]),
+        (arrays[2], arrays[4]),
+        (arrays[4], arrays[2]),
         (arrays[3], arrays[3]),
+        (arrays[4], arrays[4]),
     ]
 
     return arrays, array_tuples
 
 
+@pytest.mark.parametrize("ddof", [0, 1])
+@pytest.mark.parametrize(
+    "da_a, da_b",
+    [
+        arrays_w_tuples()[1][3],
+        arrays_w_tuples()[1][4],
+        arrays_w_tuples()[1][5],
+        arrays_w_tuples()[1][6],
+        arrays_w_tuples()[1][7],
+        arrays_w_tuples()[1][8],
+    ],
+)
+@pytest.mark.parametrize("dim", [None, "x", "time"])
+def test_lazy_corrcov(da_a, da_b, dim, ddof):
+    # GH 5284
+    from dask import is_dask_collection
+
+    with raise_if_dask_computes():
+        cov = xr.cov(da_a.chunk(), da_b.chunk(), dim=dim, ddof=ddof)
+        assert is_dask_collection(cov)
+
+        corr = xr.corr(da_a.chunk(), da_b.chunk(), dim=dim)
+        assert is_dask_collection(corr)
+
+
 @pytest.mark.parametrize("ddof", [0, 1])
 @pytest.mark.parametrize(
     "da_a, da_b",
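
Note: below is a minimal usage sketch of the behavior this patch enables (not part of the patch itself; the toy array values and default chunking are illustrative). With dask-backed inputs, the NaN masking in `_cov_corr` now runs per-block via `map_blocks`, so constructing `xr.cov`/`xr.corr` results triggers no eager computation:

    import numpy as np
    import xarray as xr
    from dask import is_dask_collection

    # Toy inputs containing a missing value; .chunk() makes them dask-backed.
    da_a = xr.DataArray([[1.0, 2.0], [1.0, np.nan]], dims=["x", "time"]).chunk()
    da_b = xr.DataArray([[1.0, 2.0], [2.0, 1.0]], dims=["x", "time"]).chunk()

    # Building the results only assembles a dask graph; the missing-value
    # check is deferred along with the rest of the computation.
    cov = xr.cov(da_a, da_b, dim="time", ddof=1)
    corr = xr.corr(da_a, da_b, dim="time")
    assert is_dask_collection(cov)
    assert is_dask_collection(corr)

    # Values are materialized only on an explicit compute.
    print(cov.compute())
    print(corr.compute())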