Skip to content

Commit 24be2d1

Browse files
committed
Preserve dtype better when specified.
1 parent 4dbadae commit 24be2d1

File tree

5 files changed

+50
-12
lines changed

5 files changed

+50
-12
lines changed

flox/aggregations.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -292,8 +292,8 @@ def __repr__(self) -> str:
292292
combine="sum",
293293
fill_value=0,
294294
final_fill_value=0,
295-
dtypes=np.intp,
296-
final_dtype=np.intp,
295+
dtypes=np.integer,
296+
final_dtype=np.integer,
297297
)
298298

299299
# note that the fill values are the result of np.func([np.nan, np.nan])
@@ -521,20 +521,23 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
521521
return (Dim(name="quantile", values=q),)
522522

523523

524+
# if the input contains integers or floats smaller than float64,
525+
# the output data-type is float64. Otherwise, the output data-type is the same as that
526+
# of the input.
524527
quantile = Aggregation(
525528
name="quantile",
526529
fill_value=dtypes.NA,
527530
chunk=None,
528531
combine=None,
529-
final_dtype=np.floating,
532+
final_dtype=np.float64,
530533
new_dims_func=quantile_new_dims_func,
531534
)
532535
nanquantile = Aggregation(
533536
name="nanquantile",
534537
fill_value=dtypes.NA,
535538
chunk=None,
536539
combine=None,
537-
final_dtype=np.floating,
540+
final_dtype=np.float64,
538541
new_dims_func=quantile_new_dims_func,
539542
)
540543
mode = Aggregation(
@@ -780,10 +783,8 @@ def _initialize_aggregation(
780783
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
781784
)
782785
final_dtype = dtypes._normalize_dtype(
783-
dtype_ or agg.dtype_init["final"], array_dtype, fill_value
786+
dtype_ or agg.dtype_init["final"], array_dtype, agg.preserves_dtype, fill_value
784787
)
785-
if not agg.preserves_dtype:
786-
final_dtype = dtypes._maybe_promote_int(final_dtype)
787788
agg.dtype = {
788789
"user": dtype, # Save to automatically choose an engine
789790
"final": final_dtype,

flox/xrdtypes.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -150,9 +150,14 @@ def is_datetime_like(dtype):
150150
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
151151

152152

153-
def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
153+
def _normalize_dtype(
154+
dtype: DTypeLike, array_dtype: np.dtype, preserves_dtype: bool, fill_value=None
155+
) -> np.dtype:
154156
if dtype is None:
155-
dtype = array_dtype
157+
if not preserves_dtype:
158+
dtype = _maybe_promote_int(array_dtype)
159+
else:
160+
dtype = array_dtype
156161
if dtype is np.floating:
157162
# mean, std, var always result in floating
158163
# but we preserve the array's dtype if it is floating

tests/strategies.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
2626

2727

2828
# TODO: stop excluding everything but U
29-
array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
29+
array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
3030
by_dtype_st = supported_dtypes()
3131

3232
NON_NUMPY_FUNCS = ["first", "last", "nanfirst", "nanlast", "count", "any", "all"] + list(
@@ -38,7 +38,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
3838
[f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]
3939
)
4040
numeric_arrays = npst.arrays(
41-
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
41+
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
4242
)
4343
all_arrays = npst.arrays(
4444
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes()

tests/test_core.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1929,3 +1929,13 @@ def test_ffill_bfill(chunks, size, add_nan_by, func):
19291929
expected = flox.groupby_scan(array.compute(), by, func=func)
19301930
actual = flox.groupby_scan(array, by, func=func)
19311931
assert_equal(expected, actual)
1932+
1933+
1934+
def test_agg_dtypes():
1935+
# regression test for GH388
1936+
counts = np.array([0, 2, 1, 0, 1])
1937+
group = np.array([1, 1, 1, 2, 2])
1938+
actual, _ = groupby_reduce(
1939+
counts, group, expected_groups=(np.array([1, 2]),), func="sum", dtype="uint8"
1940+
)
1941+
assert actual.dtype == np.uint8

tests/test_properties.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
from flox.xrutils import notnull
2020

2121
from . import assert_equal
22-
from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
22+
from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays
2323
from .strategies import chunks as chunks_strategy
2424

2525
dask.config.set(scheduler="sync")
@@ -223,3 +223,25 @@ def test_first_last_useless(data, func):
223223
actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
224224
expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype)
225225
assert_equal(actual, expected)
226+
227+
228+
@given(
229+
func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]),
230+
engine=st.sampled_from(["numpy", "flox"]),
231+
array_dtype=st.none() | array_dtypes,
232+
dtype=st.none() | array_dtypes,
233+
)
234+
def test_agg_dtype_specified(func, array_dtype, dtype, engine):
235+
# regression test for GH388
236+
counts = np.array([0, 2, 1, 0, 1], dtype=array_dtype)
237+
group = np.array([1, 1, 1, 2, 2])
238+
actual, _ = groupby_reduce(
239+
counts,
240+
group,
241+
expected_groups=(np.array([1, 2]),),
242+
func=func,
243+
dtype=dtype,
244+
engine=engine,
245+
)
246+
expected = getattr(np, func)(counts, keepdims=True, dtype=dtype)
247+
assert actual.dtype == expected.dtype

0 commit comments

Comments
 (0)