Skip to content

Commit 227b622

Browse files
Using map_blocks to lazily mask input arrays, following pydata#4559
1 parent b8d6daf commit 227b622

File tree

1 file changed

+14
-19
lines changed

1 file changed

+14
-19
lines changed

xarray/core/computation.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,28 +1326,23 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None):
13261326

13271327
# 2. Ignore the nans
13281328
valid_values = da_a.notnull() & da_b.notnull()
1329+
valid_count = valid_values.sum(dim) - ddof
13291330

1330-
def _nan_check(d):
1331-
if d.all():
1332-
return True
1331+
def _get_valid_values(da, other):
1332+
"""
1333+
Function to lazily mask da_a and da_b
1334+
following a similar approach to
1335+
https://github.com/pydata/xarray/pull/4559
1336+
"""
1337+
missing_vals = np.logical_or(da.isnull(), other.isnull())
1338+
if missing_vals.any():
1339+
da = da.where(~missing_vals)
1340+
return da
13331341
else:
1334-
return False
1335-
1336-
if is_duck_dask_array(valid_values.data):
1337-
# assign to copy - else the check is not triggered
1338-
_are_there_nans = valid_values.copy(
1339-
data=valid_values.data.map_blocks(_nan_check, dtype=valid_values.dtype),
1340-
deep=False,
1341-
)
1342-
1343-
else:
1344-
_are_there_nans = _nan_check(valid_values.data)
1342+
return da
13451343

1346-
if not _are_there_nans:
1347-
da_a = da_a.where(valid_values)
1348-
da_b = da_b.where(valid_values)
1349-
1350-
valid_count = valid_values.sum(dim) - ddof
1344+
da_a = da_a.map_blocks(_get_valid_values, args=[da_b])
1345+
da_b = da_b.map_blocks(_get_valid_values, args=[da_a])
13511346

13521347
# 3. Detrend along the given dim
13531348
demeaned_da_a = da_a - da_a.mean(dim=dim)

0 commit comments

Comments
 (0)