Skip to content

Commit 80c67f4

Browse files
committed
Fix
1 parent 93800aa commit 80c67f4

File tree

5 files changed

+84
-71
lines changed

5 files changed

+84
-71
lines changed

flox/aggregate_flox.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44

5+
from . import xrdtypes as dtypes
56
from .xrutils import is_scalar, isnull, notnull
67

78

@@ -60,6 +61,7 @@ def quantile_or_topk(
6061
fill_value=None,
6162
):
6263
assert q or k
64+
assert axis == -1
6365

6466
inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))
6567

@@ -84,7 +86,7 @@ def quantile_or_topk(
8486
nanmask = full_sizes != actual_sizes
8587
# TODO: Don't know if this array has been copied in _prepare_for_flox.
8688
# This is potentially wasteful
87-
array = np.where(array_nanmask, -np.inf, array)
89+
array = np.where(array_nanmask, dtypes.get_neg_infinity(array.dtype, min_for_int=True), array)
8890
maxes = np.maximum.reduceat(array, inv_idx[:-1], axis=axis)
8991
replacement = np.repeat(maxes, np.diff(inv_idx), axis=axis)
9092
array[array_nanmask] = replacement[array_nanmask]
@@ -128,14 +130,17 @@ def quantile_or_topk(
128130
# partition the complex array in-place
129131
labels_broadcast = np.broadcast_to(group_idx, array.shape)
130132
with np.errstate(invalid="ignore"):
131-
cmplx = labels_broadcast + 1j * array
133+
cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array)
132134
cmplx.partition(kth=kth, axis=-1)
133135

134136
if is_scalar_param:
135137
a_ = cmplx.imag
136138
else:
137139
a_ = np.broadcast_to(cmplx.imag, (param.shape[0],) + array.shape)
138140

141+
if array.dtype.kind in "Mm":
142+
a_ = a_.astype(array.dtype)
143+
139144
loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis)
140145
if q is not None:
141146
# get bounds, Broadcast to (num quantiles, ..., num labels)
@@ -204,6 +209,8 @@ def _np_grouped_op(
204209

205210

206211
def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
212+
if fillna in [dtypes.INF, dtypes.NINF]:
213+
fillna = dtypes._get_fill_value(kwargs.get("dtype", array.dtype), fillna)
207214
result = func(group_idx, np.where(isnull(array), fillna, array), *args, **kwargs)
208215
# np.nanmax([np.nan, np.nan]) = np.nan
209216
# To recover this behaviour, we need to search for the fillna value
@@ -221,9 +228,9 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
221228
prod = partial(_np_grouped_op, op=np.multiply.reduceat)
222229
nanprod = partial(_nan_grouped_op, func=prod, fillna=1)
223230
max = partial(_np_grouped_op, op=np.maximum.reduceat)
224-
nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf)
231+
nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF)
225232
min = partial(_np_grouped_op, op=np.minimum.reduceat)
226-
nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf)
233+
nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF)
227234
quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False))
228235
topk = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
229236
nanquantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))

flox/aggregations.py

Lines changed: 8 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -115,60 +115,6 @@ def generic_aggregate(
115115
return result
116116

117117

118-
def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
119-
if dtype is None:
120-
dtype = array_dtype
121-
if dtype is np.floating:
122-
# mean, std, var always result in floating
123-
# but we preserve the array's dtype if it is floating
124-
if array_dtype.kind in "fcmM":
125-
dtype = array_dtype
126-
else:
127-
dtype = np.dtype("float64")
128-
elif not isinstance(dtype, np.dtype):
129-
dtype = np.dtype(dtype)
130-
if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
131-
dtype = np.result_type(dtype, fill_value)
132-
return dtype
133-
134-
135-
def _maybe_promote_int(dtype) -> np.dtype:
136-
# https://numpy.org/doc/stable/reference/generated/numpy.prod.html
137-
# The dtype of a is used by default unless a has an integer dtype of less precision
138-
# than the default platform integer.
139-
if not isinstance(dtype, np.dtype):
140-
dtype = np.dtype(dtype)
141-
if dtype.kind == "i":
142-
dtype = np.result_type(dtype, np.intp)
143-
elif dtype.kind == "u":
144-
dtype = np.result_type(dtype, np.uintp)
145-
return dtype
146-
147-
148-
def _get_fill_value(dtype, fill_value):
149-
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
150-
if fill_value in [None, dtypes.NA] and dtype.kind in "US":
151-
return ""
152-
if fill_value == dtypes.INF or fill_value is None:
153-
return dtypes.get_pos_infinity(dtype, max_for_int=True)
154-
if fill_value == dtypes.NINF:
155-
return dtypes.get_neg_infinity(dtype, min_for_int=True)
156-
if fill_value == dtypes.NA:
157-
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
158-
return np.nan
159-
# This is madness, but npg checks that fill_value is compatible
160-
# with array dtype even if the fill_value is never used.
161-
elif np.issubdtype(dtype, np.integer):
162-
return dtypes.get_neg_infinity(dtype, min_for_int=True)
163-
elif np.issubdtype(dtype, np.timedelta64):
164-
return np.timedelta64("NaT")
165-
elif np.issubdtype(dtype, np.datetime64):
166-
return np.datetime64("NaT")
167-
else:
168-
return None
169-
return fill_value
170-
171-
172118
def _atleast_1d(inp, min_length: int = 1):
173119
if xrutils.is_scalar(inp):
174120
inp = (inp,) * min_length
@@ -646,7 +592,7 @@ def last(self) -> AlignedArrays:
646592
# TODO: automate?
647593
engine="flox",
648594
dtype=self.array.dtype,
649-
fill_value=_get_fill_value(self.array.dtype, dtypes.NA),
595+
fill_value=dtypes._get_fill_value(self.array.dtype, dtypes.NA),
650596
expected_groups=None,
651597
)
652598
return AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"])
@@ -829,7 +775,9 @@ def _initialize_aggregation(
829775
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
830776
)
831777

832-
final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
778+
final_dtype = dtypes._normalize_dtype(
779+
dtype_ or agg.dtype_init["final"], array_dtype, fill_value
780+
)
833781
if agg.name not in [
834782
"first",
835783
"last",
@@ -841,14 +789,14 @@ def _initialize_aggregation(
841789
"nanmax",
842790
"topk",
843791
]:
844-
final_dtype = _maybe_promote_int(final_dtype)
792+
final_dtype = dtypes._maybe_promote_int(final_dtype)
845793
agg.dtype = {
846794
"user": dtype, # Save to automatically choose an engine
847795
"final": final_dtype,
848796
"numpy": (final_dtype,),
849797
"intermediate": tuple(
850798
(
851-
_normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
799+
dtypes._normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
852800
if int_dtype is None
853801
else np.dtype(int_dtype)
854802
)
@@ -863,10 +811,10 @@ def _initialize_aggregation(
863811
# Replace sentinel fill values according to dtype
864812
agg.fill_value["user"] = fill_value
865813
agg.fill_value["intermediate"] = tuple(
866-
_get_fill_value(dt, fv)
814+
dtypes._get_fill_value(dt, fv)
867815
for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
868816
)
869-
agg.fill_value[func] = _get_fill_value(agg.dtype["final"], agg.fill_value[func])
817+
agg.fill_value[func] = dtypes._get_fill_value(agg.dtype["final"], agg.fill_value[func])
870818

871819
fv = fill_value if fill_value is not None else agg.fill_value[agg.name]
872820
if _is_arg_reduction(agg):

flox/xrdtypes.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22

33
import numpy as np
4+
from numpy.typing import DTypeLike
45

56
from . import xrutils as utils
67

@@ -147,3 +148,57 @@ def get_neg_infinity(dtype, min_for_int=False):
147148
def is_datetime_like(dtype):
148149
"""Check if a dtype is a subclass of the numpy datetime types"""
149150
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
151+
152+
153+
def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
154+
if dtype is None:
155+
dtype = array_dtype
156+
if dtype is np.floating:
157+
# mean, std, var always result in floating
158+
# but we preserve the array's dtype if it is floating
159+
if array_dtype.kind in "fcmM":
160+
dtype = array_dtype
161+
else:
162+
dtype = np.dtype("float64")
163+
elif not isinstance(dtype, np.dtype):
164+
dtype = np.dtype(dtype)
165+
if fill_value not in [None, INF, NINF, NA]:
166+
dtype = np.result_type(dtype, fill_value)
167+
return dtype
168+
169+
170+
def _maybe_promote_int(dtype) -> np.dtype:
171+
# https://numpy.org/doc/stable/reference/generated/numpy.prod.html
172+
# The dtype of a is used by default unless a has an integer dtype of less precision
173+
# than the default platform integer.
174+
if not isinstance(dtype, np.dtype):
175+
dtype = np.dtype(dtype)
176+
if dtype.kind == "i":
177+
dtype = np.result_type(dtype, np.intp)
178+
elif dtype.kind == "u":
179+
dtype = np.result_type(dtype, np.uintp)
180+
return dtype
181+
182+
183+
def _get_fill_value(dtype, fill_value):
184+
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
185+
if fill_value in [None, NA] and dtype.kind in "US":
186+
return ""
187+
if fill_value == INF or fill_value is None:
188+
return get_pos_infinity(dtype, max_for_int=True)
189+
if fill_value == NINF:
190+
return get_neg_infinity(dtype, min_for_int=True)
191+
if fill_value == NA:
192+
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
193+
return np.nan
194+
# This is madness, but npg checks that fill_value is compatible
195+
# with array dtype even if the fill_value is never used.
196+
elif np.issubdtype(dtype, np.integer):
197+
return get_neg_infinity(dtype, min_for_int=True)
198+
elif np.issubdtype(dtype, np.timedelta64):
199+
return np.timedelta64("NaT")
200+
elif np.issubdtype(dtype, np.datetime64):
201+
return np.datetime64("NaT")
202+
else:
203+
return None
204+
return fill_value

tests/test_core.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
from numpy_groupies.aggregate_numpy import aggregate
1414

1515
import flox
16+
from flox import xrdtypes as dtypes
1617
from flox import xrutils
17-
from flox.aggregations import Aggregation, _initialize_aggregation, _maybe_promote_int
18+
from flox.aggregations import Aggregation, _initialize_aggregation
1819
from flox.core import (
1920
HAS_NUMBAGG,
2021
_choose_engine,
@@ -161,7 +162,7 @@ def test_groupby_reduce(
161162
if func == "mean" or func == "nanmean":
162163
expected_result = np.array(expected, dtype=np.float64)
163164
elif func == "sum":
164-
expected_result = np.array(expected, dtype=_maybe_promote_int(array.dtype))
165+
expected_result = np.array(expected, dtype=dtypes._maybe_promote_int(array.dtype))
165166
elif func == "count":
166167
expected_result = np.array(expected, dtype=np.intp)
167168

@@ -389,7 +390,7 @@ def test_groupby_reduce_preserves_dtype(dtype, func):
389390
array = np.ones((2, 12), dtype=dtype)
390391
by = np.array([labels] * 2)
391392
result, _ = groupby_reduce(from_array(array, chunks=(-1, 4)), by, func=func)
392-
expect_dtype = _maybe_promote_int(array.dtype)
393+
expect_dtype = dtypes._maybe_promote_int(array.dtype)
393394
assert result.dtype == expect_dtype
394395

395396

@@ -1027,7 +1028,7 @@ def test_dtype_preservation(dtype, func, engine):
10271028
# https://github.com/numbagg/numbagg/issues/121
10281029
pytest.skip()
10291030
if func == "sum":
1030-
expected = _maybe_promote_int(dtype)
1031+
expected = dtypes._maybe_promote_int(dtype)
10311032
elif func == "mean" and "int" in dtype:
10321033
expected = np.float64
10331034
else:
@@ -1058,7 +1059,7 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
10581059
actual, actual_groups = groupby_reduce(array, labels, func="sum", method=method)
10591060
assert_equal(actual_groups, np.arange(6, dtype=labels.dtype))
10601061

1061-
expect_dtype = _maybe_promote_int(dtype)
1062+
expect_dtype = dtypes._maybe_promote_int(dtype)
10621063
assert_equal(actual, np.array([0, 4, 24, 6, 12, 20], dtype=expect_dtype))
10631064

10641065

tests/test_properties.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,9 @@ def test_first_last(data, array: dask.array.Array, func: str) -> None:
215215
from hypothesis import settings
216216

217217

218+
# TODO: do all_arrays instead of numeric_arrays
218219
@settings(report_multiple_bugs=False)
219-
@given(data=st.data(), array=chunked_arrays())
220+
@given(data=st.data(), array=chunked_arrays(arrays=numeric_arrays))
220221
def test_topk_max_min(data, array):
221222
"top 1 == nanmax; top -1 == nanmin"
222223
size = array.shape[-1]
@@ -226,5 +227,6 @@ def test_topk_max_min(data, array):
226227

227228
for a in (array, array.compute()):
228229
actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k})
229-
expected, _ = groupby_reduce(a, by, func=npfunc)
230+
# TODO: do numbagg, flox
231+
expected, _ = groupby_reduce(a, by, func=npfunc, engine="numpy")
230232
assert_equal(actual, expected[np.newaxis, :])

0 commit comments

Comments
 (0)