Skip to content

Commit 7e5dbe9

Browse files
committed
Handle dtypes.NA properly for datetime/timedelta
1 parent f0ce343 commit 7e5dbe9

File tree

4 files changed

+89
-73
lines changed

4 files changed

+89
-73
lines changed

flox/aggregate_flox.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44

5+
from . import xrdtypes as dtypes
56
from .xrutils import is_scalar, isnull, notnull
67

78

@@ -98,7 +99,7 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
9899
# partition the complex array in-place
99100
labels_broadcast = np.broadcast_to(group_idx, array.shape)
100101
with np.errstate(invalid="ignore"):
101-
cmplx = labels_broadcast + 1j * array
102+
cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array)
102103
cmplx.partition(kth=kth, axis=-1)
103104
if is_scalar_q:
104105
a_ = cmplx.imag
@@ -158,6 +159,8 @@ def _np_grouped_op(
158159

159160

160161
def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
162+
if fillna in [dtypes.INF, dtypes.NINF]:
163+
fillna = dtypes._get_fill_value(kwargs.get("dtype", array.dtype), fillna)
161164
result = func(group_idx, np.where(isnull(array), fillna, array), *args, **kwargs)
162165
# np.nanmax([np.nan, np.nan]) = np.nan
163166
# To recover this behaviour, we need to search for the fillna value
@@ -175,9 +178,9 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
175178
prod = partial(_np_grouped_op, op=np.multiply.reduceat)
176179
nanprod = partial(_nan_grouped_op, func=prod, fillna=1)
177180
max = partial(_np_grouped_op, op=np.maximum.reduceat)
178-
nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf)
181+
nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF)
179182
min = partial(_np_grouped_op, op=np.minimum.reduceat)
180-
nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf)
183+
nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF)
181184
quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
182185
nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
183186
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False))

flox/aggregations.py

Lines changed: 22 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -115,60 +115,6 @@ def generic_aggregate(
115115
return result
116116

117117

118-
def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
119-
if dtype is None:
120-
dtype = array_dtype
121-
if dtype is np.floating:
122-
# mean, std, var always result in floating
123-
# but we preserve the array's dtype if it is floating
124-
if array_dtype.kind in "fcmM":
125-
dtype = array_dtype
126-
else:
127-
dtype = np.dtype("float64")
128-
elif not isinstance(dtype, np.dtype):
129-
dtype = np.dtype(dtype)
130-
if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
131-
dtype = np.result_type(dtype, fill_value)
132-
return dtype
133-
134-
135-
def _maybe_promote_int(dtype) -> np.dtype:
136-
# https://numpy.org/doc/stable/reference/generated/numpy.prod.html
137-
# The dtype of a is used by default unless a has an integer dtype of less precision
138-
# than the default platform integer.
139-
if not isinstance(dtype, np.dtype):
140-
dtype = np.dtype(dtype)
141-
if dtype.kind == "i":
142-
dtype = np.result_type(dtype, np.intp)
143-
elif dtype.kind == "u":
144-
dtype = np.result_type(dtype, np.uintp)
145-
return dtype
146-
147-
148-
def _get_fill_value(dtype, fill_value):
149-
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
150-
if fill_value in [None, dtypes.NA] and dtype.kind in "US":
151-
return ""
152-
if fill_value == dtypes.INF or fill_value is None:
153-
return dtypes.get_pos_infinity(dtype, max_for_int=True)
154-
if fill_value == dtypes.NINF:
155-
return dtypes.get_neg_infinity(dtype, min_for_int=True)
156-
if fill_value == dtypes.NA:
157-
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
158-
return np.nan
159-
# This is madness, but npg checks that fill_value is compatible
160-
# with array dtype even if the fill_value is never used.
161-
elif (
162-
np.issubdtype(dtype, np.integer)
163-
or np.issubdtype(dtype, np.timedelta64)
164-
or np.issubdtype(dtype, np.datetime64)
165-
):
166-
return dtypes.get_neg_infinity(dtype, min_for_int=True)
167-
else:
168-
return None
169-
return fill_value
170-
171-
172118
def _atleast_1d(inp, min_length: int = 1):
173119
if xrutils.is_scalar(inp):
174120
inp = (inp,) * min_length
@@ -435,9 +381,9 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
435381

436382

437383
min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
438-
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
384+
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA)
439385
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
440-
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
386+
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA)
441387

442388

443389
def argreduce_preprocess(array, axis):
@@ -634,7 +580,7 @@ def last(self) -> AlignedArrays:
634580
# TODO: automate?
635581
engine="flox",
636582
dtype=self.array.dtype,
637-
fill_value=_get_fill_value(self.array.dtype, dtypes.NA),
583+
fill_value=dtypes._get_fill_value(self.array.dtype, dtypes.NA),
638584
expected_groups=None,
639585
)
640586
return AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"])
@@ -729,15 +675,15 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
729675
binary_op=None,
730676
reduction="nanlast",
731677
scan="ffill",
732-
identity=np.nan,
678+
identity=dtypes.NA,
733679
mode="concat_then_scan",
734680
)
735681
bfill = Scan(
736682
"bfill",
737683
binary_op=None,
738684
reduction="nanlast",
739685
scan="ffill",
740-
identity=np.nan,
686+
identity=dtypes.NA,
741687
mode="concat_then_scan",
742688
preprocess=reverse,
743689
finalize=reverse,
@@ -816,16 +762,27 @@ def _initialize_aggregation(
816762
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
817763
)
818764

819-
final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
820-
if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax"]:
821-
final_dtype = _maybe_promote_int(final_dtype)
765+
final_dtype = dtypes._normalize_dtype(
766+
dtype_ or agg.dtype_init["final"], array_dtype, fill_value
767+
)
768+
if agg.name not in [
769+
"first",
770+
"last",
771+
"nanfirst",
772+
"nanlast",
773+
"min",
774+
"max",
775+
"nanmin",
776+
"nanmax",
777+
]:
778+
final_dtype = dtypes._maybe_promote_int(final_dtype)
822779
agg.dtype = {
823780
"user": dtype, # Save to automatically choose an engine
824781
"final": final_dtype,
825782
"numpy": (final_dtype,),
826783
"intermediate": tuple(
827784
(
828-
_normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
785+
dtypes._normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
829786
if int_dtype is None
830787
else np.dtype(int_dtype)
831788
)
@@ -838,10 +795,10 @@ def _initialize_aggregation(
838795
# Replace sentinel fill values according to dtype
839796
agg.fill_value["user"] = fill_value
840797
agg.fill_value["intermediate"] = tuple(
841-
_get_fill_value(dt, fv)
798+
dtypes._get_fill_value(dt, fv)
842799
for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
843800
)
844-
agg.fill_value[func] = _get_fill_value(agg.dtype["final"], agg.fill_value[func])
801+
agg.fill_value[func] = dtypes._get_fill_value(agg.dtype["final"], agg.fill_value[func])
845802

846803
fv = fill_value if fill_value is not None else agg.fill_value[agg.name]
847804
if _is_arg_reduction(agg):

flox/xrdtypes.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22

33
import numpy as np
4+
from numpy.typing import DTypeLike
45

56
from . import xrutils as utils
67

@@ -147,3 +148,57 @@ def get_neg_infinity(dtype, min_for_int=False):
147148
def is_datetime_like(dtype):
148149
"""Check if a dtype is a subclass of the numpy datetime types"""
149150
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
151+
152+
153+
def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
154+
if dtype is None:
155+
dtype = array_dtype
156+
if dtype is np.floating:
157+
# mean, std, var always result in floating
158+
# but we preserve the array's dtype if it is floating
159+
if array_dtype.kind in "fcmM":
160+
dtype = array_dtype
161+
else:
162+
dtype = np.dtype("float64")
163+
elif not isinstance(dtype, np.dtype):
164+
dtype = np.dtype(dtype)
165+
if fill_value not in [None, INF, NINF, NA]:
166+
dtype = np.result_type(dtype, fill_value)
167+
return dtype
168+
169+
170+
def _maybe_promote_int(dtype) -> np.dtype:
171+
# https://numpy.org/doc/stable/reference/generated/numpy.prod.html
172+
# The dtype of a is used by default unless a has an integer dtype of less precision
173+
# than the default platform integer.
174+
if not isinstance(dtype, np.dtype):
175+
dtype = np.dtype(dtype)
176+
if dtype.kind == "i":
177+
dtype = np.result_type(dtype, np.intp)
178+
elif dtype.kind == "u":
179+
dtype = np.result_type(dtype, np.uintp)
180+
return dtype
181+
182+
183+
def _get_fill_value(dtype, fill_value):
184+
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
185+
if fill_value in [None, NA] and dtype.kind in "US":
186+
return ""
187+
if fill_value == INF or fill_value is None:
188+
return get_pos_infinity(dtype, max_for_int=True)
189+
if fill_value == NINF:
190+
return get_neg_infinity(dtype, min_for_int=True)
191+
if fill_value == NA:
192+
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
193+
return np.nan
194+
# This is madness, but npg checks that fill_value is compatible
195+
# with array dtype even if the fill_value is never used.
196+
elif np.issubdtype(dtype, np.integer):
197+
return get_neg_infinity(dtype, min_for_int=True)
198+
elif np.issubdtype(dtype, np.timedelta64):
199+
return np.timedelta64("NaT")
200+
elif np.issubdtype(dtype, np.datetime64):
201+
return np.datetime64("NaT")
202+
else:
203+
return None
204+
return fill_value

tests/test_core.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
from numpy_groupies.aggregate_numpy import aggregate
1414

1515
import flox
16+
from flox import xrdtypes as dtypes
1617
from flox import xrutils
17-
from flox.aggregations import Aggregation, _initialize_aggregation, _maybe_promote_int
18+
from flox.aggregations import Aggregation, _initialize_aggregation
1819
from flox.core import (
1920
HAS_NUMBAGG,
2021
_choose_engine,
@@ -161,7 +162,7 @@ def test_groupby_reduce(
161162
if func == "mean" or func == "nanmean":
162163
expected_result = np.array(expected, dtype=np.float64)
163164
elif func == "sum":
164-
expected_result = np.array(expected, dtype=_maybe_promote_int(array.dtype))
165+
expected_result = np.array(expected, dtype=dtypes._maybe_promote_int(array.dtype))
165166
elif func == "count":
166167
expected_result = np.array(expected, dtype=np.intp)
167168

@@ -389,7 +390,7 @@ def test_groupby_reduce_preserves_dtype(dtype, func):
389390
array = np.ones((2, 12), dtype=dtype)
390391
by = np.array([labels] * 2)
391392
result, _ = groupby_reduce(from_array(array, chunks=(-1, 4)), by, func=func)
392-
expect_dtype = _maybe_promote_int(array.dtype)
393+
expect_dtype = dtypes._maybe_promote_int(array.dtype)
393394
assert result.dtype == expect_dtype
394395

395396

@@ -1054,7 +1055,7 @@ def test_dtype_preservation(dtype, func, engine):
10541055
# https://github.com/numbagg/numbagg/issues/121
10551056
pytest.skip()
10561057
if func == "sum":
1057-
expected = _maybe_promote_int(dtype)
1058+
expected = dtypes._maybe_promote_int(dtype)
10581059
elif func == "mean" and "int" in dtype:
10591060
expected = np.float64
10601061
else:
@@ -1085,7 +1086,7 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
10851086
actual, actual_groups = groupby_reduce(array, labels, func="sum", method=method)
10861087
assert_equal(actual_groups, np.arange(6, dtype=labels.dtype))
10871088

1088-
expect_dtype = _maybe_promote_int(dtype)
1089+
expect_dtype = dtypes._maybe_promote_int(dtype)
10891090
assert_equal(actual, np.array([0, 4, 24, 6, 12, 20], dtype=expect_dtype))
10901091

10911092

0 commit comments

Comments
 (0)