diff --git a/flox/core.py b/flox/core.py
index 10a9197a9..15577d35c 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -6,17 +6,7 @@ import operator
 from collections import namedtuple
 from functools import partial, reduce
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    Literal,
-    Mapping,
-    Sequence,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Mapping, Sequence, Union
 
 import numpy as np
 import numpy_groupies as npg
@@ -37,8 +27,11 @@ if TYPE_CHECKING:
     import dask.array.Array as DaskArray
 
+T_ExpectedGroups = Union[Sequence, np.ndarray, pd.Index]
+T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
 T_Func = Union[str, Callable]
 T_Funcs = Union[T_Func, Sequence[T_Func]]
+T_Agg = Union[str, Aggregation]
 T_Axis = int
 T_Axes = tuple[T_Axis, ...]
 T_AxesOpt = Union[T_Axis, T_Axes, None]
@@ -60,7 +53,7 @@ DUMMY_AXIS = -2
 
 
-def _is_arg_reduction(func: str | Aggregation) -> bool:
+def _is_arg_reduction(func: T_Agg) -> bool:
     if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
         return True
     if isinstance(func, Aggregation) and func.reduction_type == "argreduce":
@@ -68,6 +61,12 @@ def _is_arg_reduction(func: str | Aggregation) -> bool:
     return False
 
 
+def _is_minmax_reduction(func: T_Agg) -> bool:
+    return not _is_arg_reduction(func) and (
+        isinstance(func, str) and ("max" in func or "min" in func)
+    )
+
+
 def _get_expected_groups(by, sort: bool) -> pd.Index:
     if is_duck_dask_array(by):
         raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
@@ -1027,7 +1026,16 @@ def split_blocks(applied, split_out, expected_groups, split_name):
 
 
 def _reduce_blockwise(
-    array, by, agg, *, axis: T_Axes, expected_groups, fill_value, engine: T_Engine, sort, reindex
+    array,
+    by,
+    agg: Aggregation,
+    *,
+    axis: T_Axes,
+    expected_groups,
+    fill_value,
+    engine: T_Engine,
+    sort,
+    reindex,
 ) -> FinalResultsDict:
     """
     Blockwise groupby reduction that produces the final result. This code path is
@@ -1335,7 +1343,7 @@ def _assert_by_is_aligned(shape, by):
 
 
 def _convert_expected_groups_to_index(
-    expected_groups: Iterable, isbin: Sequence[bool], sort: bool
+    expected_groups: T_ExpectedGroups, isbin: Sequence[bool], sort: bool
 ) -> tuple[pd.Index | None, ...]:
     out: list[pd.Index | None] = []
     for ex, isbin_ in zip(expected_groups, isbin):
@@ -1397,8 +1405,8 @@ def _factorize_multiple(by, expected_groups, by_is_dask, reindex):
 def groupby_reduce(
     array: np.ndarray | DaskArray,
     *by: np.ndarray | DaskArray,
-    func: str | Aggregation,
-    expected_groups: Sequence | np.ndarray | None = None,
+    func: T_Agg,
+    expected_groups: T_ExpectedGroupsOpt = None,
     sort: bool = True,
     isbin: T_IsBins = False,
     axis: T_AxesOpt = None,
@@ -1520,7 +1528,8 @@ def groupby_reduce(
 
     if not is_duck_array(array):
         array = np.asarray(array)
-    array = array.astype(int) if np.issubdtype(array.dtype, bool) else array
+    is_bool_array = np.issubdtype(array.dtype, bool)
+    array = array.astype(int) if is_bool_array else array
 
     if isinstance(isbin, Sequence):
         isbins = isbin
@@ -1709,4 +1718,7 @@ def groupby_reduce(
             result, from_=groups[0], to=expected_groups, fill_value=fill_value
         ).reshape(result.shape[:-1] + grp_shape)
         groups = final_groups
+
+    if _is_minmax_reduction(func) and is_bool_array:
+        result = result.astype(bool)
     return (result, *groups)
diff --git a/tests/__init__.py b/tests/__init__.py
index fef0a778e..9917b41fc 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -14,7 +14,7 @@
 
     dask_array_type = da.Array
 except ImportError:
-    dask_array_type = ()
+    dask_array_type = ()  # type: ignore
 
 
 try:
@@ -22,7 +22,7 @@
 
     xr_types = (xr.DataArray, xr.Dataset)
 except ImportError:
-    xr_types = ()
+    xr_types = ()  # type: ignore
 
 
 def _importorskip(modname, minversion=None):
@@ -98,6 +98,9 @@ def assert_equal(a, b):
         # does some validation of the dask graph
         da.utils.assert_eq(a, b, equal_nan=True)
     else:
+        if a.dtype != b.dtype:
+            raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")
+
         np.testing.assert_allclose(a, b, equal_nan=True)
diff --git a/tests/test_core.py b/tests/test_core.py
index c70ae83c6..25660e734 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,4 +1,7 @@
+from __future__ import annotations
+
 from functools import reduce
+from typing import TYPE_CHECKING
 
 import numpy as np
 import pandas as pd
@@ -63,6 +66,9 @@ def dask_array_ones(*args):
     pytest.param("nanmedian", marks=(pytest.mark.skip,)),
 )
 
+if TYPE_CHECKING:
+    from flox.core import T_Engine, T_ExpectedGroupsOpt, T_Func2
+
 
 def test_alignment_error():
     da = np.ones((12,))
@@ -101,8 +107,16 @@ def test_alignment_error():
     ],
 )
 def test_groupby_reduce(
-    array, by, expected, func, expected_groups, chunk, split_out, dtype, engine
-):
+    engine: T_Engine,
+    func: T_Func2,
+    array: np.ndarray,
+    by: np.ndarray,
+    expected: list[float],
+    expected_groups: T_ExpectedGroupsOpt,
+    chunk: bool,
+    split_out: int,
+    dtype: np.typing.DTypeLike,
+) -> None:
     array = array.astype(dtype)
     if chunk:
         if not has_dask or expected_groups is None:
@@ -110,12 +124,12 @@ def test_groupby_reduce(
         array = da.from_array(array, chunks=(3,) if array.ndim == 1 else (1, 3))
         by = da.from_array(by, chunks=(3,) if by.ndim == 1 else (1, 3))
 
-    if "mean" in func:
-        expected = np.array(expected, dtype=float)
+    if func == "mean" or func == "nanmean":
+        expected_result = np.array(expected, dtype=float)
     elif func == "sum":
-        expected = np.array(expected, dtype=dtype)
+        expected_result = np.array(expected, dtype=dtype)
     elif func == "count":
-        expected = np.array(expected, dtype=int)
+        expected_result = np.array(expected, dtype=int)
 
     result, groups, = groupby_reduce(
         array,
@@ -126,8 +140,10 @@ def test_groupby_reduce(
         split_out=split_out,
         engine=engine,
     )
-    assert_equal(groups, [0, 1, 2])
-    assert_equal(expected, result)
+    g_dtype = by.dtype if expected_groups is None else np.asarray(expected_groups).dtype
+
+    assert_equal(groups, np.array([0, 1, 2], g_dtype))
+    assert_equal(expected_result, result)
 
 
 def gen_array_by(size, func):
@@ -843,16 +859,16 @@ def test_bool_reductions(func, engine):
 
 
 @requires_dask
-def test_map_reduce_blockwise_mixed():
+def test_map_reduce_blockwise_mixed() -> None:
     t = pd.date_range("2000-01-01", "2000-12-31", freq="D").to_series()
     data = t.dt.dayofyear
-    actual = groupby_reduce(
+    actual, _ = groupby_reduce(
         dask.array.from_array(data.values, chunks=365),
         t.dt.month,
         func="mean",
         method="split-reduce",
    )
-    expected = groupby_reduce(data, t.dt.month, func="mean")
+    expected, _ = groupby_reduce(data, t.dt.month, func="mean")
     assert_equal(expected, actual)
 
 
@@ -908,7 +924,7 @@ def test_factorize_values_outside_bins():
     assert_equal(expected, actual)
 
 
-def test_multiple_groupers():
+def test_multiple_groupers() -> None:
     actual, *_ = groupby_reduce(
         np.ones((5, 2)),
         np.arange(10).reshape(5, 2),
@@ -921,7 +937,7 @@ def test_multiple_groupers() -> None:
         reindex=True,
         func="count",
     )
-    expected = np.eye(5, 5)
+    expected = np.eye(5, 5, dtype=int)
 
     assert_equal(expected, actual)
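
Quick end-to-end sketch (not part of the patch) of what the new `_is_minmax_reduction` check and the `is_bool_array` flag are meant to achieve in `groupby_reduce`: boolean input is cast to int for the reduction and the min/max result is cast back to bool before it is returned. The array values, group labels, and asserted outputs below are illustrative only.

# Illustration only: with the change above, a min/max reduction on boolean
# input should come back with bool dtype instead of int.
import numpy as np

from flox.core import groupby_reduce

array = np.array([True, False, True, True, False])
by = np.array([0, 0, 1, 1, 1])

# The bool array is cast to int internally, reduced, then cast back to bool
# because "max" is a min/max reduction and the input was boolean.
result, groups = groupby_reduce(array, by, func="max")

assert result.dtype == np.dtype(bool)  # previously int
assert result.tolist() == [True, True]
assert groups.tolist() == [0, 1]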