Skip to content

Commit d61efb6

Browse files
bcbnzdcherian
andauthored
Fix behaviour of min_count in reducing functions (#4911)
* Add more tests for reducing functions with min_count * Make sure Dask-backed arrays are not computed. * Check some specific examples give the correct output. * Run membership tests on xarray.core.dtypes.NAT_TYPES * Fix behaviour of min_count in reducing functions. * Fix mask checks in xarray.core.nanops._maybe_null_out to run lazily for Dask-backed arrays. * Change xarray.core.dtypes.NAT_TYPES to a set (it is only used for membership checks). * Add dtypes to NAT_TYPES rather than instances. Previously np.float64 was returning true from `dtype in NAT_TYPES` which resulted in min_count being ignored when reducing over all axes. * Add whatsnew entry. * Improvements from review. * use duck_array_ops.where instead of np.where * add docstring and whatsnew messages about sum/prod on integer arrays with skipna=True and min_count != None now returning a float array. Co-authored-by: Deepak Cherian <[email protected]>
1 parent ae0a71b commit d61efb6

File tree

6 files changed

+94
-17
lines changed

6 files changed

+94
-17
lines changed

doc/whats-new.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ Breaking changes
4848
:ref:`weather-climate` (:pull:`2844`, :issue:`3689`)
4949
- remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull:`4725`).
5050
By `Aureliana Barghini <https://github.com/aurghs>`_.
51+
- As a result of :pull:`4911` the output from calling :py:meth:`DataArray.sum`
52+
or :py:meth:`DataArray.prod` on an integer array with ``skipna=True`` and a
53+
non-None value for ``min_count`` will now be a float array rather than an
54+
integer array.
5155

5256
Deprecations
5357
~~~~~~~~~~~~
@@ -129,6 +133,12 @@ Bug fixes
129133
By `Leif Denby <https://github.com/leifdenby>`_.
130134
- Fix time encoding bug associated with using cftime versions greater than
131135
1.4.0 with xarray (:issue:`4870`, :pull:`4871`). By `Spencer Clark <https://github.com/spencerkclark>`_.
136+
- Stop :py:meth:`DataArray.sum` and :py:meth:`DataArray.prod` computing lazy
137+
arrays when called with a ``min_count`` parameter (:issue:`4898`, :pull:`4911`).
138+
By `Blair Bonnett <https://github.com/bcbnz>`_.
139+
- Fix bug preventing the ``min_count`` parameter to :py:meth:`DataArray.sum` and
140+
:py:meth:`DataArray.prod` working correctly when calculating over all axes of
141+
a float64 array (:issue:`4898`, :pull:`4911`). By `Blair Bonnett <https://github.com/bcbnz>`_.
132142
- Fix decoding of vlen strings using h5py versions greater than 3.0.0 with h5netcdf backend (:issue:`4570`, :pull:`4893`).
133143
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
134144

xarray/core/dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def maybe_promote(dtype):
7878
return np.dtype(dtype), fill_value
7979

8080

81-
NAT_TYPES = (np.datetime64("NaT"), np.timedelta64("NaT"))
81+
NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype}
8282

8383

8484
def get_fill_value(dtype):

xarray/core/nanops.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33
import numpy as np
44

55
from . import dtypes, nputils, utils
6-
from .duck_array_ops import _dask_or_eager_func, count, fillna, isnull, where_method
6+
from .duck_array_ops import (
7+
_dask_or_eager_func,
8+
count,
9+
fillna,
10+
isnull,
11+
where,
12+
where_method,
13+
)
714
from .pycompat import dask_array_type
815

916
try:
@@ -28,18 +35,14 @@ def _maybe_null_out(result, axis, mask, min_count=1):
2835
"""
2936
xarray version of pandas.core.nanops._maybe_null_out
3037
"""
31-
3238
if axis is not None and getattr(result, "ndim", False):
3339
null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0
34-
if null_mask.any():
35-
dtype, fill_value = dtypes.maybe_promote(result.dtype)
36-
result = result.astype(dtype)
37-
result[null_mask] = fill_value
40+
dtype, fill_value = dtypes.maybe_promote(result.dtype)
41+
result = where(null_mask, fill_value, result.astype(dtype))
3842

3943
elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
4044
null_mask = mask.size - mask.sum()
41-
if null_mask < min_count:
42-
result = np.nan
45+
result = where(null_mask < min_count, np.nan, result)
4346

4447
return result
4548

xarray/core/ops.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,12 @@
114114

115115
_MINCOUNT_DOCSTRING = """
116116
min_count : int, default: None
117-
The required number of valid values to perform the operation.
118-
If fewer than min_count non-NA values are present the result will
119-
be NA. New in version 0.10.8: Added with the default being None."""
117+
The required number of valid values to perform the operation. If
118+
fewer than min_count non-NA values are present the result will be
119+
NA. Only used if skipna is set to True or defaults to True for the
120+
array's dtype. New in version 0.10.8: Added with the default being
121+
None. Changed in version 0.17.0: if specified on an integer array
122+
and skipna=True, the result will be a float array."""
120123

121124
_COARSEN_REDUCE_DOCSTRING_TEMPLATE = """\
122125
Coarsen this object by applying `{name}` along its dimensions.

xarray/tests/test_dtypes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,9 @@ def test_maybe_promote(kind, expected):
9090
actual = dtypes.maybe_promote(np.dtype(kind))
9191
assert actual[0] == expected[0]
9292
assert str(actual[1]) == expected[1]
93+
94+
95+
def test_nat_types_membership():
96+
assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES
97+
assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES
98+
assert np.float64 not in dtypes.NAT_TYPES

xarray/tests/test_duck_array_ops.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
assert_array_equal,
3535
has_dask,
3636
has_scipy,
37+
raise_if_dask_computes,
3738
raises_regex,
3839
requires_cftime,
3940
requires_dask,
@@ -587,7 +588,10 @@ def test_min_count(dim_num, dtype, dask, func, aggdim, contains_nan, skipna):
587588
da = construct_dataarray(dim_num, dtype, contains_nan=contains_nan, dask=dask)
588589
min_count = 3
589590

590-
actual = getattr(da, func)(dim=aggdim, skipna=skipna, min_count=min_count)
591+
# If using Dask, the function call should be lazy.
592+
with raise_if_dask_computes():
593+
actual = getattr(da, func)(dim=aggdim, skipna=skipna, min_count=min_count)
594+
591595
expected = series_reduce(da, func, skipna=skipna, dim=aggdim, min_count=min_count)
592596
assert_allclose(actual, expected)
593597
assert_dask_array(actual, dask)
@@ -603,14 +607,62 @@ def test_min_count_nd(dtype, dask, func):
603607
min_count = 3
604608
dim_num = 3
605609
da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask)
606-
actual = getattr(da, func)(dim=["x", "y", "z"], skipna=True, min_count=min_count)
610+
611+
# If using Dask, the function call should be lazy.
612+
with raise_if_dask_computes():
613+
actual = getattr(da, func)(
614+
dim=["x", "y", "z"], skipna=True, min_count=min_count
615+
)
616+
607617
# Supplying all dims is equivalent to supplying `...` or `None`
608618
expected = getattr(da, func)(dim=..., skipna=True, min_count=min_count)
609619

610620
assert_allclose(actual, expected)
611621
assert_dask_array(actual, dask)
612622

613623

624+
@pytest.mark.parametrize("dask", [False, True])
625+
@pytest.mark.parametrize("func", ["sum", "prod"])
626+
@pytest.mark.parametrize("dim", [None, "a", "b"])
627+
def test_min_count_specific(dask, func, dim):
628+
if dask and not has_dask:
629+
pytest.skip("requires dask")
630+
631+
# Simple array with four non-NaN values.
632+
da = DataArray(np.ones((6, 6), dtype=np.float64) * np.nan, dims=("a", "b"))
633+
da[0][0] = 2
634+
da[0][3] = 2
635+
da[3][0] = 2
636+
da[3][3] = 2
637+
if dask:
638+
da = da.chunk({"a": 3, "b": 3})
639+
640+
# Expected result if we set min_count to the number of non-NaNs in a
641+
# row/column/the entire array.
642+
if dim:
643+
min_count = 2
644+
expected = DataArray(
645+
[4.0, np.nan, np.nan] * 2, dims=("a" if dim == "b" else "b",)
646+
)
647+
else:
648+
min_count = 4
649+
expected = DataArray(8.0 if func == "sum" else 16.0)
650+
651+
# Check for that min_count.
652+
with raise_if_dask_computes():
653+
actual = getattr(da, func)(dim, skipna=True, min_count=min_count)
654+
assert_dask_array(actual, dask)
655+
assert_allclose(actual, expected)
656+
657+
# With min_count being one higher, should get all NaN.
658+
min_count += 1
659+
expected *= np.nan
660+
with raise_if_dask_computes():
661+
actual = getattr(da, func)(dim, skipna=True, min_count=min_count)
662+
assert_dask_array(actual, dask)
663+
assert_allclose(actual, expected)
664+
665+
614666
@pytest.mark.parametrize("func", ["sum", "prod"])
615667
def test_min_count_dataset(func):
616668
da = construct_dataarray(2, dtype=float, contains_nan=True, dask=False)
@@ -655,9 +707,12 @@ def test_docs():
655707
have a sentinel missing value (int) or skipna=True has not been
656708
implemented (object, datetime64 or timedelta64).
657709
min_count : int, default: None
658-
The required number of valid values to perform the operation.
659-
If fewer than min_count non-NA values are present the result will
660-
be NA. New in version 0.10.8: Added with the default being None.
710+
The required number of valid values to perform the operation. If
711+
fewer than min_count non-NA values are present the result will be
712+
NA. Only used if skipna is set to True or defaults to True for the
713+
array's dtype. New in version 0.10.8: Added with the default being
714+
None. Changed in version 0.17.0: if specified on an integer array
715+
and skipna=True, the result will be a float array.
661716
keep_attrs : bool, optional
662717
If True, the attributes (`attrs`) will be copied from the original
663718
object to the new one. If False (default), the new object will be

0 commit comments

Comments
 (0)