From 6c65fa0f4c49d6b057626092572f311ed1e616e0 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Nov 2023 16:39:26 -0700 Subject: [PATCH 01/23] Fix numbagg version check Closes #281 --- flox/xrutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/xrutils.py b/flox/xrutils.py index 497cd7b24..0ced6fbb6 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -339,6 +339,6 @@ def module_available(module: str, minversion: Optional[str] = None) -> bool: has = importlib.util.find_spec(module) is not None if has: mod = importlib.import_module(module) - return Version(mod.__version__) < Version(minversion) if minversion is not None else True + return Version(mod.__version__) >= Version(minversion) if minversion is not None else True else: return False From 75a7a3d622ade5dce31bcfe73c9f3f8b31eae2fb Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Nov 2023 16:46:36 -0700 Subject: [PATCH 02/23] Enable numbagg for count --- flox/core.py | 2 +- tests/test_core.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index 66e39b6f1..fbf420fa4 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1779,7 +1779,7 @@ def _choose_engine(by, agg: Aggregation): # numbagg only supports nan-skipping reductions # without dtype specified - if HAS_NUMBAGG and "nan" in agg.name: + if HAS_NUMBAGG and ("nan" in agg.name or agg.name == "count"): if not_arg_reduce and dtype is None: return "numbagg" diff --git a/tests/test_core.py b/tests/test_core.py index 99f181255..984baee2c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1565,6 +1565,18 @@ def test_choose_engine(dtype): min_count=0, finalize_kwargs=None, ) + count = _initialize_aggregation( + "count", + dtype=None, + array_dtype=dtype, + fill_value=0, + min_count=0, + finalize_kwargs=None, + ) + + # count_engine + count_engine = _choose_engine(np.array([1, 1, 2, 2]), agg=count) + assert count_engine == ("numbagg" if numbagg_possible else "flox") # sorted by -> flox sorted_engine = _choose_engine(np.array([1, 1, 2, 2]), agg=mean) From a5dc5743cb309e8264da680beffbc8a93c0fe9eb Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Nov 2023 16:58:17 -0700 Subject: [PATCH 03/23] Better numbagg special-casing --- flox/core.py | 9 +++++++-- tests/test_core.py | 21 +++++++++++---------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/flox/core.py b/flox/core.py index fbf420fa4..6fd17e99c 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1779,8 +1779,13 @@ def _choose_engine(by, agg: Aggregation): # numbagg only supports nan-skipping reductions # without dtype specified - if HAS_NUMBAGG and ("nan" in agg.name or agg.name == "count"): - if not_arg_reduce and dtype is None: + has_blockwise_nan_skipping = (agg.chunk is None and "nan" in agg.name) or any( + "nan" in func for func in agg.chunk + ) + if HAS_NUMBAGG: + if agg.name in ["all", "any"] or ( + not_arg_reduce and has_blockwise_nan_skipping and dtype is None + ): return "numbagg" if not_arg_reduce and (not is_duck_dask_array(by) and _issorted(by)): diff --git a/tests/test_core.py b/tests/test_core.py index 984baee2c..50fa55462 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1565,18 +1565,19 @@ def test_choose_engine(dtype): min_count=0, finalize_kwargs=None, ) - count = _initialize_aggregation( - "count", - dtype=None, - array_dtype=dtype, - fill_value=0, - min_count=0, - finalize_kwargs=None, - ) # count_engine - count_engine = _choose_engine(np.array([1, 1, 2, 2]), agg=count) - assert count_engine == ("numbagg" if numbagg_possible else "flox") + for method in ["all", "any", "count"]: + agg = _initialize_aggregation( + method, + dtype=None, + array_dtype=dtype, + fill_value=0, + min_count=0, + finalize_kwargs=None, + ) + engine = _choose_engine(np.array([1, 1, 2, 2]), agg=agg) + assert engine == ("numbagg" if HAS_NUMBAGG else "flox") # sorted by -> flox sorted_engine = _choose_engine(np.array([1, 1, 2, 2]), agg=mean) From ee62e376063b20e2a08bdb5f545d2ff5a8ad592a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Nov 2023 17:13:53 -0700 Subject: [PATCH 04/23] Fixes. --- flox/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index 6fd17e99c..c2e428281 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1779,8 +1779,8 @@ def _choose_engine(by, agg: Aggregation): # numbagg only supports nan-skipping reductions # without dtype specified - has_blockwise_nan_skipping = (agg.chunk is None and "nan" in agg.name) or any( - "nan" in func for func in agg.chunk + has_blockwise_nan_skipping = (agg.chunk[0] is None and "nan" in agg.name) or any( + (isinstance(func, str) and "nan" in func) for func in agg.chunk ) if HAS_NUMBAGG: if agg.name in ["all", "any"] or ( From 55c06d6e291ce15efce56086c0499f00a47735f6 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Nov 2023 20:15:56 -0700 Subject: [PATCH 05/23] A bunch of typing --- flox/aggregations.py | 15 ++++++++------- flox/core.py | 15 +++++++++------ pyproject.toml | 1 + 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 3a65f9396..52ffb3396 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -3,7 +3,7 @@ import copy import warnings from functools import partial -from typing import TYPE_CHECKING, Any, Callable, TypedDict +from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict import numpy as np from numpy.typing import DTypeLike @@ -13,6 +13,7 @@ if TYPE_CHECKING: FuncTuple = tuple[Callable | str, ...] + OptionalFuncTuple = tuple[Callable | str | None, ...] def _is_arg_reduction(func: str | Aggregation) -> bool: @@ -153,7 +154,7 @@ def __init__( final_fill_value=dtypes.NA, dtypes=None, final_dtype: DTypeLike | None = None, - reduction_type="reduce", + reduction_type: Literal["reduce", "argreduce"] = "reduce", ): """ Blueprint for computing grouped aggregations. @@ -204,13 +205,13 @@ def __init__( self.reduction_type = reduction_type self.numpy: FuncTuple = (numpy,) if numpy else (self.name,) # initialize blockwise reduction - self.chunk: FuncTuple = _atleast_1d(chunk) + self.chunk: OptionalFuncTuple = _atleast_1d(chunk) # how to aggregate results after first round of reduction - self.combine: FuncTuple = _atleast_1d(combine) + self.combine: OptionalFuncTuple = _atleast_1d(combine) # simpler reductions used with the "simple combine" algorithm - self.simple_combine: tuple[Callable, ...] = () + self.simple_combine: tuple[Callable | None, ...] = () # final aggregation - self.aggregate: Callable | str = aggregate if aggregate else self.combine[0] + self.aggregate: Callable | str | None = aggregate if aggregate else self.combine[0] # finalize results (see mean) self.finalize: Callable | None = finalize @@ -618,7 +619,7 @@ def _initialize_aggregation( else: agg.min_count = 0 - simple_combine: list[Callable] = [] + simple_combine: list[Callable | None] = [] for combine in agg.combine: if isinstance(combine, str): if combine in ["nanfirst", "nanlast"]: diff --git a/flox/core.py b/flox/core.py index c2e428281..5f5829420 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1044,6 +1044,8 @@ def _grouped_combine( """Combine intermediates step of tree reduction.""" from dask.utils import deepmap + combine = agg.combine + if isinstance(x_chunk, dict): # Only one block at final step; skip one extra groupby return x_chunk @@ -1084,7 +1086,8 @@ def _grouped_combine( results = chunk_argreduce( array_idx, groups, - func=agg.combine[slicer], # count gets treated specially next + # count gets treated specially next + func=combine[slicer], # type: ignore[arg-type] axis=axis, expected_groups=None, fill_value=agg.fill_value["intermediate"][slicer], @@ -1118,9 +1121,10 @@ def _grouped_combine( elif agg.reduction_type == "reduce": # Here we reduce the intermediates individually results = {"groups": None, "intermediates": []} - for idx, (combine, fv, dtype) in enumerate( - zip(agg.combine, agg.fill_value["intermediate"], agg.dtype["intermediate"]) + for idx, (combine_, fv, dtype) in enumerate( + zip(combine, agg.fill_value["intermediate"], agg.dtype["intermediate"]) ): + assert combine_ is not None array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) if array.shape[-1] == 0: # all empty when combined @@ -1134,7 +1138,7 @@ def _grouped_combine( _results = chunk_reduce( array, groups, - func=combine, + func=combine_, axis=axis, expected_groups=None, fill_value=(fv,), @@ -2085,8 +2089,7 @@ def groupby_reduce( # TODO: How else to narrow that array.chunks is there? assert isinstance(array, DaskArray) - # TODO: fix typing of FuncTuple in Aggregation - if agg.chunk[0] is None and method != "blockwise": # type: ignore[unreachable] + if agg.chunk[0] is None and method != "blockwise": raise NotImplementedError( f"Aggregation {agg.name!r} is only implemented for dask arrays when method='blockwise'." f"Received method={method!r}" diff --git a/pyproject.toml b/pyproject.toml index 9cf0b4ca4..c507a5222 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ show_error_codes = true warn_unused_ignores = true warn_unreachable = true enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +exclude=["asv_bench/pkgs"] [[tool.mypy.overrides]] module=[ From 46242babf3b7da8ff66d64bcda22242756449762 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Nov 2023 21:06:36 -0700 Subject: [PATCH 06/23] Handle fill_value in core numbagg reduction. --- flox/aggregate_numbagg.py | 98 ++++++++++++++++++++++----------------- flox/core.py | 2 +- 2 files changed, 57 insertions(+), 43 deletions(-) diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index b0b06d86e..55a6bb85c 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -4,6 +4,31 @@ import numbagg.grouped import numpy as np +DEFAULT_FILL_VALUE = { + "nansum": 0, + "nanmean": np.nan, + "nanvar": np.nan, + "nanstd": np.nan, + "nanmin": np.nan, + "nanmax": np.nan, + "nanany": False, + "nanall": False, + "nansum_of_squares": 0, + "nanprod": 1, + "nancount": 0, + "nanargmax": np.nan, + "nanargmin": np.nan, + "nanfirst": np.nan, + "nanlast": np.nan, +} + +CAST_TO = { + "nansum": {np.bool_: np.int64}, + "nanmean": {np.int_: np.float64}, + "nanvar": {np.int_: np.float64}, + "nanstd": {np.int_: np.float64}, +} + def _numbagg_wrapper( group_idx, @@ -16,52 +41,39 @@ def _numbagg_wrapper( dtype=None, numbagg_func=None, ): - return numbagg_func( - array, - group_idx, - axis=axis, - num_labels=size, - # The following are unsupported - # fill_value=fill_value, - # dtype=dtype, - ) - + cast_to = CAST_TO.get(numbagg_func, None) + if cast_to: + for from_, to_ in cast_to.items(): + if isinstance(array, from_): + array = array.astype(to_) -def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): - if np.issubdtype(array.dtype, np.bool_): - array = array.astype(np.in64) - return numbagg.grouped.group_nansum( + func_ = getattr(numbagg.grouped, f"group_{numbagg_func}") + result = func_( array, group_idx, axis=axis, num_labels=size, + # The following are unsupported # fill_value=fill_value, # dtype=dtype, ) - -def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): - if np.issubdtype(array.dtype, np.int_): - array = array.astype(np.float64) - return numbagg.grouped.group_nanmean( - array, - group_idx, - axis=axis, - num_labels=size, - # fill_value=fill_value, - # dtype=dtype, - ) + default_fv = DEFAULT_FILL_VALUE[numbagg_func] + if fill_value is not None and fill_value != default_fv: + count = numbagg.grouped.group_nancount(array, group_idx, axis=axis, num_labels=size) + result[count == 0] = fill_value + return result def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0): assert ddof != 0 - if np.issubdtype(array.dtype, np.int_): - array = array.astype(np.float64) - return numbagg.grouped.group_nanvar( + + return _numbagg_wrapper( array, group_idx, axis=axis, num_labels=size, + numbagg_func="nanvar" # ddof=0, # fill_value=fill_value, # dtype=dtype, @@ -70,8 +82,7 @@ def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, def nanstd(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0): assert ddof != 0 - if np.issubdtype(array.dtype, np.int_): - array = array.astype(np.float64) + return numbagg.grouped.group_nanstd( array, group_idx, @@ -83,17 +94,20 @@ def nanstd(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ) -nansum_of_squares = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nansum_of_squares) -nanlen = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nancount) -nanprod = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanprod) -nanfirst = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanfirst) -nanlast = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanlast) -# nanargmax = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmax) -# nanargmin = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmin) -nanmax = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanmax) -nanmin = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanmin) -any = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanany) -all = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanall) +nansum = partial(_numbagg_wrapper, numbagg_func="nansum") +nanmean = partial(_numbagg_wrapper, numbagg_func="nanmean") +nanprod = partial(_numbagg_wrapper, numbagg_func="nanprod") +nansum_of_squares = partial(_numbagg_wrapper, numbagg_func="nansum_of_squares") +nanlen = partial(_numbagg_wrapper, numbagg_func="nancount") +nanprod = partial(_numbagg_wrapper, numbagg_func="nanprod") +nanfirst = partial(_numbagg_wrapper, numbagg_func="nanfirst") +nanlast = partial(_numbagg_wrapper, numbagg_func="nanlast") +# nanargmax = partial(_numbagg_wrapper, numbagg_func="nanargmax) +# nanargmin = partial(_numbagg_wrapper, numbagg_func="nanargmin) +nanmax = partial(_numbagg_wrapper, numbagg_func="nanmax") +nanmin = partial(_numbagg_wrapper, numbagg_func="nanmin") +any = partial(_numbagg_wrapper, numbagg_func="nanany") +all = partial(_numbagg_wrapper, numbagg_func="nanall") # sum = nansum # mean = nanmean diff --git a/flox/core.py b/flox/core.py index 5f5829420..1e1940559 100644 --- a/flox/core.py +++ b/flox/core.py @@ -2049,7 +2049,7 @@ def groupby_reduce( nax = len(axis_) # When axis is a subset of possible values; then npg will - # apply it to groups that don't exist along a particular axis (for e.g.) + # apply the fill_value to groups that don't exist along a particular axis (for e.g.) # since these count as a group that is absent. thoo! # fill_value applies to all-NaN groups as well as labels in expected_groups that are not found. # The only way to do this consistently is mask out using min_count From 83853c0ec2aa70243e48bf3b2ac7a8a8797503aa Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Nov 2023 21:10:30 -0700 Subject: [PATCH 07/23] Update flox/aggregate_numbagg.py --- flox/aggregate_numbagg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index 55a6bb85c..9244b44b7 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -96,7 +96,6 @@ def nanstd(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, nansum = partial(_numbagg_wrapper, numbagg_func="nansum") nanmean = partial(_numbagg_wrapper, numbagg_func="nanmean") -nanprod = partial(_numbagg_wrapper, numbagg_func="nanprod") nansum_of_squares = partial(_numbagg_wrapper, numbagg_func="nansum_of_squares") nanlen = partial(_numbagg_wrapper, numbagg_func="nancount") nanprod = partial(_numbagg_wrapper, numbagg_func="nanprod") From 36e9359c89735746fb982551ec24585e4c5ef278 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Nov 2023 21:33:23 -0700 Subject: [PATCH 08/23] cleanup --- flox/aggregations.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 52ffb3396..8cd3d00b0 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -285,13 +285,7 @@ def __repr__(self) -> str: sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0) nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0) prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1) -nanprod = Aggregation( - "nanprod", - chunk="nanprod", - combine="prod", - fill_value=1, - final_fill_value=dtypes.NA, -) +nanprod = Aggregation("nanprod", chunk="nanprod", combine="prod", fill_value=1) def _mean_finalize(sum_, count): From 94cb700c5682e7749a4507c204427c4c789870ac Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Nov 2023 21:33:29 -0700 Subject: [PATCH 09/23] [WIP] test hacky fix --- flox/core.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/flox/core.py b/flox/core.py index 1e1940559..d7d1a3472 100644 --- a/flox/core.py +++ b/flox/core.py @@ -2071,6 +2071,11 @@ def groupby_reduce( kwargs = dict(axis=axis_, fill_value=fill_value) agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs) + # HACK? + if min_count_ == 0: + agg.fill_value["numpy"] = None + agg.fill_value[agg.name] = None + # Need to set this early using `agg` # It cannot be done in the core loop of chunk_reduce # since we "prepare" the data for flox. From 5fd2acdbb615896fb69cace67d718471acad96d6 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 6 Nov 2023 22:16:12 -0700 Subject: [PATCH 10/23] [wip] --- flox/aggregate_numbagg.py | 13 +++++++++++-- flox/aggregations.py | 4 ++++ flox/core.py | 8 ++++---- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index 9244b44b7..54bebfc9b 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -30,6 +30,8 @@ } +FILLNA = {"nansum": 0, "nanprod": 1} + def _numbagg_wrapper( group_idx, array, @@ -48,6 +50,11 @@ def _numbagg_wrapper( array = array.astype(to_) func_ = getattr(numbagg.grouped, f"group_{numbagg_func}") + default_fv = DEFAULT_FILL_VALUE[numbagg_func] + + fillna = FILLNA.get(numbagg_func, None) + if fillna: + array = np.where(np.isnan(array), fillna, array) result = func_( array, group_idx, @@ -58,8 +65,10 @@ def _numbagg_wrapper( # dtype=dtype, ) - default_fv = DEFAULT_FILL_VALUE[numbagg_func] - if fill_value is not None and fill_value != default_fv: + # The condition needs to be + # is len(found_groups) < size; if so we mask with fill_value (?) + needs_masking = fill_value is not None and fill_value != default_fv + if needs_masking: count = numbagg.grouped.group_nancount(array, group_idx, axis=axis, num_labels=size) result[count == 0] = fill_value return result diff --git a/flox/aggregations.py b/flox/aggregations.py index 8cd3d00b0..29ded1cd3 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -579,6 +579,7 @@ def _initialize_aggregation( } # Replace sentinel fill values according to dtype + agg.fill_value["user"] = fill_value agg.fill_value["intermediate"] = tuple( _get_fill_value(dt, fv) for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"]) @@ -589,6 +590,9 @@ def _initialize_aggregation( if _is_arg_reduction(agg): # this allows us to unravel_index easily. we have to do that nearly every time. agg.fill_value["numpy"] = (0,) + # elif min_count == 0 and agg.fill_value["user"] is None: + # # disable filling completely + # agg.fill_value["numpy"] = (None,) else: agg.fill_value["numpy"] = (fv,) diff --git a/flox/core.py b/flox/core.py index d7d1a3472..f8114161d 100644 --- a/flox/core.py +++ b/flox/core.py @@ -2071,10 +2071,10 @@ def groupby_reduce( kwargs = dict(axis=axis_, fill_value=fill_value) agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs) - # HACK? - if min_count_ == 0: - agg.fill_value["numpy"] = None - agg.fill_value[agg.name] = None + # # HACK? + # if min_count_ == 0: + # agg.fill_value["numpy"] = None + # agg.fill_value[agg.name] = None # Need to set this early using `agg` # It cannot be done in the core loop of chunk_reduce From e27b04bf7e6404e789426846a2162401c1753a63 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Nov 2023 20:53:49 -0700 Subject: [PATCH 11/23] Cleanup functions --- flox/aggregate_numbagg.py | 52 ++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index 54bebfc9b..76fb57294 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -32,27 +32,27 @@ FILLNA = {"nansum": 0, "nanprod": 1} + def _numbagg_wrapper( group_idx, array, *, axis=-1, - func="sum", + func=None, size=None, fill_value=None, dtype=None, - numbagg_func=None, ): - cast_to = CAST_TO.get(numbagg_func, None) + cast_to = CAST_TO.get(func, None) if cast_to: for from_, to_ in cast_to.items(): if isinstance(array, from_): array = array.astype(to_) - func_ = getattr(numbagg.grouped, f"group_{numbagg_func}") - default_fv = DEFAULT_FILL_VALUE[numbagg_func] + func_ = getattr(numbagg.grouped, f"group_{func}") + default_fv = DEFAULT_FILL_VALUE[func] - fillna = FILLNA.get(numbagg_func, None) + fillna = FILLNA.get(func, None) if fillna: array = np.where(np.isnan(array), fillna, array) result = func_( @@ -78,11 +78,11 @@ def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, assert ddof != 0 return _numbagg_wrapper( - array, group_idx, + array, axis=axis, - num_labels=size, - numbagg_func="nanvar" + size=size, + func="nanvar", # ddof=0, # fill_value=fill_value, # dtype=dtype, @@ -92,30 +92,32 @@ def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, def nanstd(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0): assert ddof != 0 - return numbagg.grouped.group_nanstd( - array, + return _numbagg_wrapper( group_idx, + array, axis=axis, - num_labels=size, + size=size, + func="nanstd" # ddof=0, # fill_value=fill_value, # dtype=dtype, ) -nansum = partial(_numbagg_wrapper, numbagg_func="nansum") -nanmean = partial(_numbagg_wrapper, numbagg_func="nanmean") -nansum_of_squares = partial(_numbagg_wrapper, numbagg_func="nansum_of_squares") -nanlen = partial(_numbagg_wrapper, numbagg_func="nancount") -nanprod = partial(_numbagg_wrapper, numbagg_func="nanprod") -nanfirst = partial(_numbagg_wrapper, numbagg_func="nanfirst") -nanlast = partial(_numbagg_wrapper, numbagg_func="nanlast") -# nanargmax = partial(_numbagg_wrapper, numbagg_func="nanargmax) -# nanargmin = partial(_numbagg_wrapper, numbagg_func="nanargmin) -nanmax = partial(_numbagg_wrapper, numbagg_func="nanmax") -nanmin = partial(_numbagg_wrapper, numbagg_func="nanmin") -any = partial(_numbagg_wrapper, numbagg_func="nanany") -all = partial(_numbagg_wrapper, numbagg_func="nanall") +nansum = partial(_numbagg_wrapper, func="nansum") +nanmean = partial(_numbagg_wrapper, func="nanmean") +nanprod = partial(_numbagg_wrapper, func="nanprod") +nansum_of_squares = partial(_numbagg_wrapper, func="nansum_of_squares") +nanlen = partial(_numbagg_wrapper, func="nancount") +nanprod = partial(_numbagg_wrapper, func="nanprod") +nanfirst = partial(_numbagg_wrapper, func="nanfirst") +nanlast = partial(_numbagg_wrapper, func="nanlast") +# nanargmax = partial(_numbagg_wrapper, func="nanargmax) +# nanargmin = partial(_numbagg_wrapper, func="nanargmin) +nanmax = partial(_numbagg_wrapper, func="nanmax") +nanmin = partial(_numbagg_wrapper, func="nanmin") +any = partial(_numbagg_wrapper, func="nanany") +all = partial(_numbagg_wrapper, func="nanall") # sum = nansum # mean = nanmean From 8a2ac0e5433cb80092878e0ac6151904c04a5327 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Nov 2023 20:54:01 -0700 Subject: [PATCH 12/23] Fix casting --- flox/aggregate_numbagg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index 76fb57294..b5ba4b4b8 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -23,7 +23,7 @@ } CAST_TO = { - "nansum": {np.bool_: np.int64}, + # "nansum": {np.bool_: np.int64}, "nanmean": {np.int_: np.float64}, "nanvar": {np.int_: np.float64}, "nanstd": {np.int_: np.float64}, @@ -46,7 +46,7 @@ def _numbagg_wrapper( cast_to = CAST_TO.get(func, None) if cast_to: for from_, to_ in cast_to.items(): - if isinstance(array, from_): + if np.issubdtype(array.dtype, from_): array = array.astype(to_) func_ = getattr(numbagg.grouped, f"group_{func}") From c0d7347821eda7365e9c6a7ce75318a35d44c49c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Nov 2023 20:54:15 -0700 Subject: [PATCH 13/23] Fix fill_value masking --- flox/aggregate_numbagg.py | 15 ++++++++++----- flox/core.py | 5 ----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index b5ba4b4b8..382537878 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -4,6 +4,8 @@ import numbagg.grouped import numpy as np +from .core import _unique + DEFAULT_FILL_VALUE = { "nansum": 0, "nanmean": np.nan, @@ -37,8 +39,8 @@ def _numbagg_wrapper( group_idx, array, *, + func, axis=-1, - func=None, size=None, fill_value=None, dtype=None, @@ -63,14 +65,17 @@ def _numbagg_wrapper( # The following are unsupported # fill_value=fill_value, # dtype=dtype, - ) + ).astype(dtype, copy=False) # The condition needs to be # is len(found_groups) < size; if so we mask with fill_value (?) - needs_masking = fill_value is not None and fill_value != default_fv + needs_masking = fill_value is not None and not np.array_equal( + fill_value, default_fv, equal_nan=True + ) if needs_masking: - count = numbagg.grouped.group_nancount(array, group_idx, axis=axis, num_labels=size) - result[count == 0] = fill_value + uniques = _unique(group_idx) + mask = np.isin(uniques, np.arange(size), assume_unique=True, invert=True) + result[..., uniques[mask]] = fill_value return result diff --git a/flox/core.py b/flox/core.py index f8114161d..1e1940559 100644 --- a/flox/core.py +++ b/flox/core.py @@ -2071,11 +2071,6 @@ def groupby_reduce( kwargs = dict(axis=axis_, fill_value=fill_value) agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs) - # # HACK? - # if min_count_ == 0: - # agg.fill_value["numpy"] = None - # agg.fill_value[agg.name] = None - # Need to set this early using `agg` # It cannot be done in the core loop of chunk_reduce # since we "prepare" the data for flox. From b32794bdd663b5b98736ad4d7b39459806653705 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Nov 2023 21:14:58 -0700 Subject: [PATCH 14/23] optimize --- flox/aggregate_numbagg.py | 15 --------------- flox/core.py | 24 +++++++++++++++++++++++- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/flox/aggregate_numbagg.py b/flox/aggregate_numbagg.py index 382537878..a5f12d7e0 100644 --- a/flox/aggregate_numbagg.py +++ b/flox/aggregate_numbagg.py @@ -4,8 +4,6 @@ import numbagg.grouped import numpy as np -from .core import _unique - DEFAULT_FILL_VALUE = { "nansum": 0, "nanmean": np.nan, @@ -52,11 +50,7 @@ def _numbagg_wrapper( array = array.astype(to_) func_ = getattr(numbagg.grouped, f"group_{func}") - default_fv = DEFAULT_FILL_VALUE[func] - fillna = FILLNA.get(func, None) - if fillna: - array = np.where(np.isnan(array), fillna, array) result = func_( array, group_idx, @@ -67,15 +61,6 @@ def _numbagg_wrapper( # dtype=dtype, ).astype(dtype, copy=False) - # The condition needs to be - # is len(found_groups) < size; if so we mask with fill_value (?) - needs_masking = fill_value is not None and not np.array_equal( - fill_value, default_fv, equal_nan=True - ) - if needs_masking: - uniques = _unique(group_idx) - mask = np.isin(uniques, np.arange(size), assume_unique=True, invert=True) - result[..., uniques[mask]] = fill_value return result diff --git a/flox/core.py b/flox/core.py index 1e1940559..d0cbab17b 100644 --- a/flox/core.py +++ b/flox/core.py @@ -86,6 +86,24 @@ DUMMY_AXIS = -2 +def _postprocess_numbagg(result, *, func, fill_value, size, found_groups): + """Account for numbagg not providing a fill_value kwarg.""" + from .aggregate_numbagg import DEFAULT_FILL_VALUE + + if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE: + return result + # The condition needs to be + # len(found_groups) < size; if so we mask with fill_value (?) + default_fv = DEFAULT_FILL_VALUE[func] + needs_masking = fill_value is not None and not np.array_equal( + fill_value, default_fv, equal_nan=True + ) + if needs_masking: + mask = np.isin(found_groups, np.arange(size), assume_unique=True, invert=True) + result[..., found_groups[mask]] = fill_value + return result + + def _issorted(arr: np.ndarray) -> bool: return bool((arr[:-1] <= arr[1:]).all()) @@ -780,7 +798,7 @@ def chunk_reduce( group_idx, grps, found_groups_shape, _, size, props = factorize_( (by,), axes, expected_groups=(expected_groups,), reindex=reindex, sort=sort ) - groups = grps[0] + (groups,) = grps if nax > 1: needs_broadcast = any( @@ -846,6 +864,10 @@ def chunk_reduce( # remove NaN group label which should be last result = result[..., :-1] result = result.reshape(final_array_shape[:-1] + found_groups_shape) + if engine == "numbagg": + result = _postprocess_numbagg( + result, func=func, size=size, fill_value=fill_value, found_groups=groups + ) results["intermediates"].append(result) results["groups"] = np.broadcast_to(results["groups"], final_groups_shape) From 48e03d93a050e2573111b69056f6e3ff9d0c87cb Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Nov 2023 21:33:25 -0700 Subject: [PATCH 15/23] Update flox/aggregations.py --- flox/aggregations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index d7722fa5a..178c761b0 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -612,7 +612,7 @@ def _initialize_aggregation( else: agg.min_count = 0 - simple_combine: list[Callable | None] = [] + simple_combine: list[Callable] = [] for combine in agg.combine: if isinstance(combine, str): if combine in ["nanfirst", "nanlast"]: From c164f380839aea1ad86955876111585b836e14fc Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Nov 2023 21:34:22 -0700 Subject: [PATCH 16/23] Small cleanup --- flox/aggregations.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 178c761b0..2be1c7b1c 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -585,9 +585,6 @@ def _initialize_aggregation( if _is_arg_reduction(agg): # this allows us to unravel_index easily. we have to do that nearly every time. agg.fill_value["numpy"] = (0,) - # elif min_count == 0 and agg.fill_value["user"] is None: - # # disable filling completely - # agg.fill_value["numpy"] = (None,) else: agg.fill_value["numpy"] = (fv,) From 6c7489d8ee2e35cd77b2ec24dfbac734c1307590 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Nov 2023 21:49:23 -0700 Subject: [PATCH 17/23] Fix. --- flox/core.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/flox/core.py b/flox/core.py index 2302f02c9..e09eac999 100644 --- a/flox/core.py +++ b/flox/core.py @@ -98,9 +98,11 @@ def _postprocess_numbagg(result, *, func, fill_value, size, found_groups): needs_masking = fill_value is not None and not np.array_equal( fill_value, default_fv, equal_nan=True ) + groups = np.arange(size) if needs_masking: - mask = np.isin(found_groups, np.arange(size), assume_unique=True, invert=True) - result[..., found_groups[mask]] = fill_value + mask = np.isin(groups, found_groups, assume_unique=True, invert=True) + if mask.any(): + result[..., groups[mask]] = fill_value return result @@ -865,7 +867,11 @@ def chunk_reduce( result = result.reshape(final_array_shape[:-1] + found_groups_shape) if engine == "numbagg": result = _postprocess_numbagg( - result, func=func, size=size, fill_value=fill_value, found_groups=groups + result, + func=reduction, + size=size, + fill_value=fv, + found_groups=_unique(group_idx), ) results["intermediates"].append(result) previous_reduction = reduction From 3454db982280ab8dc691c58f9339e0d358a86722 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Nov 2023 21:52:57 -0700 Subject: [PATCH 18/23] Fix typing --- flox/aggregations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 2be1c7b1c..b91d191b2 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -208,7 +208,7 @@ def __init__( # how to aggregate results after first round of reduction self.combine: OptionalFuncTuple = _atleast_1d(combine) # simpler reductions used with the "simple combine" algorithm - self.simple_combine: tuple[Callable, ...] = () + self.simple_combine: OptionalFuncTuple = () # finalize results (see mean) self.finalize: Callable | None = finalize @@ -609,7 +609,7 @@ def _initialize_aggregation( else: agg.min_count = 0 - simple_combine: list[Callable] = [] + simple_combine: list[Callable | None] = [] for combine in agg.combine: if isinstance(combine, str): if combine in ["nanfirst", "nanlast"]: From 0395899ef925e0743d267892dbf5566ea440bdbe Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Nov 2023 22:12:23 -0700 Subject: [PATCH 19/23] Another bugfix --- flox/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flox/core.py b/flox/core.py index e09eac999..fe89c457b 100644 --- a/flox/core.py +++ b/flox/core.py @@ -861,10 +861,6 @@ def chunk_reduce( result = generic_aggregate( group_idx, array, axis=-1, engine=engine, func=reduction, **kw_func ).astype(dt, copy=False) - if np.any(props.nanmask): - # remove NaN group label which should be last - result = result[..., :-1] - result = result.reshape(final_array_shape[:-1] + found_groups_shape) if engine == "numbagg": result = _postprocess_numbagg( result, @@ -873,6 +869,10 @@ def chunk_reduce( fill_value=fv, found_groups=_unique(group_idx), ) + if np.any(props.nanmask): + # remove NaN group label which should be last + result = result[..., :-1] + result = result.reshape(final_array_shape[:-1] + found_groups_shape) results["intermediates"].append(result) previous_reduction = reduction From 783264dd8cef0277abf03e5e5a89f2873afac85a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 8 Nov 2023 08:01:55 -0700 Subject: [PATCH 20/23] Optimize seen_groups --- flox/core.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/flox/core.py b/flox/core.py index fe89c457b..42c69ade0 100644 --- a/flox/core.py +++ b/flox/core.py @@ -86,7 +86,7 @@ DUMMY_AXIS = -2 -def _postprocess_numbagg(result, *, func, fill_value, size, found_groups): +def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups): """Account for numbagg not providing a fill_value kwarg.""" from .aggregate_numbagg import DEFAULT_FILL_VALUE @@ -100,7 +100,7 @@ def _postprocess_numbagg(result, *, func, fill_value, size, found_groups): ) groups = np.arange(size) if needs_masking: - mask = np.isin(groups, found_groups, assume_unique=True, invert=True) + mask = np.isin(groups, seen_groups, assume_unique=True, invert=True) if mask.any(): result[..., groups[mask]] = fill_value return result @@ -802,6 +802,9 @@ def chunk_reduce( ) (groups,) = grps + # do this *before* possible broadcasting below. + # factorize_ has already taken care of offsetting + seen_groups = _unique(group_idx) if nax > 1: needs_broadcast = any( group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1 @@ -867,7 +870,9 @@ def chunk_reduce( func=reduction, size=size, fill_value=fv, - found_groups=_unique(group_idx), + # Unfortunately, we cannot reuse found_groups, it has not + # been "offset" and is really expected_groups in nearly all cases + seen_groups=seen_groups, ) if np.any(props.nanmask): # remove NaN group label which should be last From 9366566b49bc8c84316a892319d8c3ef15f14f15 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 8 Nov 2023 08:02:56 -0700 Subject: [PATCH 21/23] Be careful about raveling --- flox/core.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index 42c69ade0..463adec89 100644 --- a/flox/core.py +++ b/flox/core.py @@ -805,16 +805,26 @@ def chunk_reduce( # do this *before* possible broadcasting below. # factorize_ has already taken care of offsetting seen_groups = _unique(group_idx) + + order = "C" if nax > 1: needs_broadcast = any( group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1 for ax in range(-nax, 0) ) if needs_broadcast: + # This is the dim=... case, it's a lot faster to ravel group_idx + # in fortran order since group_idx is then sorted + # I'm seeing 400ms -> 23ms for engine="flox" + # Of course we are slower to ravel `array` but we avoid argsorting + # both `array` *and* `group_idx` in _prepare_for_flox group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :]) + # if engine == "flox": + group_idx = group_idx.reshape(-1, order="F") + order = "F" # always reshape to 1D along group dimensions newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),) - array = array.reshape(newshape) + array = array.reshape(newshape, order=order) group_idx = group_idx.reshape(-1) assert group_idx.ndim == 1 From a0d932528e15c6e8a3de602e954dd6813b10037a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 8 Nov 2023 08:03:22 -0700 Subject: [PATCH 22/23] Fix benchmark skipping for numbagg --- asv_bench/benchmarks/reduce.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asv_bench/benchmarks/reduce.py b/asv_bench/benchmarks/reduce.py index c475a8bc7..bbd70a487 100644 --- a/asv_bench/benchmarks/reduce.py +++ b/asv_bench/benchmarks/reduce.py @@ -18,7 +18,7 @@ numbagg_skip = [] for name in expected_names: numbagg_skip.extend( - list((func, expected_names[0], "numbagg") for func in funcs if func not in NUMBAGG_FUNCS) + list((func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS) ) From fd91510dc0bddadc5e4e6741ac7bea5d4292f97a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 8 Nov 2023 20:03:57 -0700 Subject: [PATCH 23/23] add test --- tests/test_core.py | 14 ++++++++++++++ tests/test_xarray.py | 21 +++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index 8778815c7..0d953b3c4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1586,3 +1586,17 @@ def test_choose_engine(dtype): assert _choose_engine(np.array([3, 1, 1]), agg=mean) == default # argmax does not give engine="flox" assert _choose_engine(np.array([1, 1, 2, 2]), agg=argmax) == "numpy" + + +def test_xarray_fill_value_behaviour(): + bar = np.array([1, 2, 3, np.nan, np.nan, np.nan, 4, 5, np.nan, np.nan]) + times = np.arange(0, 20, 2) + actual, _ = groupby_reduce(bar, times, func="nansum", expected_groups=(np.arange(19),)) + nan = np.nan + # fmt: off + expected = np.array( + [ 1., nan, 2., nan, 3., nan, 0., nan, 0., + nan, 0., nan, 4., nan, 5., nan, 0., nan, 0.] + ) + # fmt: on + assert_equal(expected, actual) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 116937052..6028a1139 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -561,3 +561,24 @@ def test_preserve_multiindex(): ) assert "region" in hist.coords + + +def test_fill_value_xarray_behaviour(): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = xr.Dataset( + { + "bar": ( + "time", + [1, 2, 3, np.nan, np.nan, np.nan, 4, 5, np.nan, np.nan], + {"meta": "data"}, + ), + "time": times, + } + ) + + expected_time = pd.date_range("2000-01-01", freq="3H", periods=19) + expected = ds.reindex(time=expected_time) + expected = ds.resample(time="3H").sum() + with xr.set_options(use_flox=True): + actual = ds.resample(time="3H").sum() + xr.testing.assert_identical(expected, actual)