Fix numbagg aggregations #282

Merged · 25 commits · Nov 9, 2023
2 changes: 1 addition & 1 deletion asv_bench/benchmarks/reduce.py
@@ -18,7 +18,7 @@
 numbagg_skip = []
 for name in expected_names:
     numbagg_skip.extend(
-        list((func, expected_names[0], "numbagg") for func in funcs if func not in NUMBAGG_FUNCS)
+        list((func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS)
     )


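The one-line fix above swaps `expected_names[0]` for the loop variable `name`, so a skip entry is generated for every benchmark variant instead of only the first. A minimal sketch of the corrected pattern; the names and values here are illustrative, not the benchmark's real ones:

    funcs = ["sum", "nansum", "mean", "nanmean"]          # hypothetical reductions
    NUMBAGG_FUNCS = {"nansum", "nanmean"}                 # hypothetical supported set
    expected_names = ["time_reduce", "time_reduce_bare"]  # hypothetical variants

    numbagg_skip = []
    for name in expected_names:
        # use `name`, not `expected_names[0]`, so every variant gets its own entries
        numbagg_skip.extend(
            (func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS
        )

    # numbagg_skip now covers both variants:
    # [('sum', 'time_reduce', 'numbagg'), ('mean', 'time_reduce', 'numbagg'),
    #  ('sum', 'time_reduce_bare', 'numbagg'), ('mean', 'time_reduce_bare', 'numbagg')]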
112 changes: 63 additions & 49 deletions flox/aggregate_numbagg.py
@@ -4,64 +4,75 @@
 import numbagg.grouped
 import numpy as np
 
+DEFAULT_FILL_VALUE = {
+    "nansum": 0,
+    "nanmean": np.nan,
+    "nanvar": np.nan,
+    "nanstd": np.nan,
+    "nanmin": np.nan,
+    "nanmax": np.nan,
+    "nanany": False,
+    "nanall": False,
+    "nansum_of_squares": 0,
+    "nanprod": 1,
+    "nancount": 0,
+    "nanargmax": np.nan,
+    "nanargmin": np.nan,
+    "nanfirst": np.nan,
+    "nanlast": np.nan,
+}
+
+CAST_TO = {
+    # "nansum": {np.bool_: np.int64},
+    "nanmean": {np.int_: np.float64},
+    "nanvar": {np.int_: np.float64},
+    "nanstd": {np.int_: np.float64},
+}
+
+
+FILLNA = {"nansum": 0, "nanprod": 1}
 
 
 def _numbagg_wrapper(
     group_idx,
     array,
     *,
+    func,
     axis=-1,
-    func="sum",
     size=None,
     fill_value=None,
     dtype=None,
-    numbagg_func=None,
 ):
-    return numbagg_func(
-        array,
-        group_idx,
-        axis=axis,
-        num_labels=size,
-        # The following are unsupported
-        # fill_value=fill_value,
-        # dtype=dtype,
-    )
-
-
-def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
-    if np.issubdtype(array.dtype, np.bool_):
-        array = array.astype(np.in64)
-    return numbagg.grouped.group_nansum(
+    cast_to = CAST_TO.get(func, None)
+    if cast_to:
+        for from_, to_ in cast_to.items():
+            if np.issubdtype(array.dtype, from_):
+                array = array.astype(to_)
+
+    func_ = getattr(numbagg.grouped, f"group_{func}")
+
+    result = func_(
         array,
         group_idx,
         axis=axis,
         num_labels=size,
+        # The following are unsupported
         # fill_value=fill_value,
         # dtype=dtype,
-    )
+    ).astype(dtype, copy=False)
 
-
-def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
-    if np.issubdtype(array.dtype, np.int_):
-        array = array.astype(np.float64)
-    return numbagg.grouped.group_nanmean(
-        array,
-        group_idx,
-        axis=axis,
-        num_labels=size,
-        # fill_value=fill_value,
-        # dtype=dtype,
-    )
+    return result
 
 
 def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0):
     assert ddof != 0
-    if np.issubdtype(array.dtype, np.int_):
-        array = array.astype(np.float64)
-    return numbagg.grouped.group_nanvar(
-        array,
+
+    return _numbagg_wrapper(
         group_idx,
+        array,
         axis=axis,
-        num_labels=size,
+        size=size,
+        func="nanvar",
         # ddof=0,
         # fill_value=fill_value,
         # dtype=dtype,
@@ -70,30 +81,33 @@ def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None,
 
 def nanstd(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0):
     assert ddof != 0
-    if np.issubdtype(array.dtype, np.int_):
-        array = array.astype(np.float64)
-    return numbagg.grouped.group_nanstd(
-        array,
+
+    return _numbagg_wrapper(
         group_idx,
+        array,
         axis=axis,
-        num_labels=size,
+        size=size,
+        func="nanstd",
         # ddof=0,
         # fill_value=fill_value,
         # dtype=dtype,
     )
 
 
-nansum_of_squares = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nansum_of_squares)
-nanlen = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nancount)
-nanprod = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanprod)
-nanfirst = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanfirst)
-nanlast = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanlast)
-# nanargmax = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmax)
-# nanargmin = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanargmin)
-nanmax = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanmax)
-nanmin = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanmin)
-any = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanany)
-all = partial(_numbagg_wrapper, numbagg_func=numbagg.grouped.group_nanall)
+nansum = partial(_numbagg_wrapper, func="nansum")
+nanmean = partial(_numbagg_wrapper, func="nanmean")
+nanprod = partial(_numbagg_wrapper, func="nanprod")
+nansum_of_squares = partial(_numbagg_wrapper, func="nansum_of_squares")
+nanlen = partial(_numbagg_wrapper, func="nancount")
+nanfirst = partial(_numbagg_wrapper, func="nanfirst")
+nanlast = partial(_numbagg_wrapper, func="nanlast")
+# nanargmax = partial(_numbagg_wrapper, func="nanargmax")
+# nanargmin = partial(_numbagg_wrapper, func="nanargmin")
+nanmax = partial(_numbagg_wrapper, func="nanmax")
+nanmin = partial(_numbagg_wrapper, func="nanmin")
+any = partial(_numbagg_wrapper, func="nanany")
+all = partial(_numbagg_wrapper, func="nanall")
 
 # sum = nansum
 # mean = nanmean
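The per-function wrappers collapse into a single `_numbagg_wrapper` that looks up `numbagg.grouped.group_{func}` by name and promotes input dtypes via the `CAST_TO` table (integers become float64 for mean/var/std). A sketch of one call, assuming numbagg is installed; the toy data is invented:

    import numpy as np

    from flox.aggregate_numbagg import nanmean

    group_idx = np.array([0, 0, 1, 1, 2])  # three groups
    array = np.array([1, 2, 3, 4, 5])      # int input; CAST_TO promotes to float64

    # partial(_numbagg_wrapper, func="nanmean") dispatches to
    # numbagg.grouped.group_nanmean(array, group_idx, axis=-1, num_labels=3)
    result = nanmean(group_idx, array, size=3, dtype=np.float64)
    print(result)  # [1.5 3.5 5. ]

`nanvar` and `nanstd` keep explicit defs only to `assert ddof != 0` before delegating: the underlying numbagg kernels take no `ddof` argument (hence the commented-out `# ddof=0,`).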
22 changes: 9 additions & 13 deletions flox/aggregations.py
@@ -3,7 +3,7 @@
 import copy
 import warnings
 from functools import partial
-from typing import TYPE_CHECKING, Any, Callable, TypedDict
+from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
 
 import numpy as np
 from numpy.typing import DTypeLike
@@ -13,6 +13,7 @@
 
 if TYPE_CHECKING:
     FuncTuple = tuple[Callable | str, ...]
+    OptionalFuncTuple = tuple[Callable | str | None, ...]
 
 
 def _is_arg_reduction(func: str | Aggregation) -> bool:
@@ -152,7 +153,7 @@ def __init__(
         final_fill_value=dtypes.NA,
         dtypes=None,
         final_dtype: DTypeLike | None = None,
-        reduction_type="reduce",
+        reduction_type: Literal["reduce", "argreduce"] = "reduce",
     ):
         """
         Blueprint for computing grouped aggregations.
@@ -203,11 +204,11 @@ def __init__(
         self.reduction_type = reduction_type
         self.numpy: FuncTuple = (numpy,) if numpy else (self.name,)
         # initialize blockwise reduction
-        self.chunk: FuncTuple = _atleast_1d(chunk)
+        self.chunk: OptionalFuncTuple = _atleast_1d(chunk)
         # how to aggregate results after first round of reduction
-        self.combine: FuncTuple = _atleast_1d(combine)
+        self.combine: OptionalFuncTuple = _atleast_1d(combine)
         # simpler reductions used with the "simple combine" algorithm
-        self.simple_combine: tuple[Callable, ...] = ()
+        self.simple_combine: OptionalFuncTuple = ()
         # finalize results (see mean)
         self.finalize: Callable | None = finalize
@@ -279,13 +280,7 @@ def __repr__(self) -> str:
 sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0)
 nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0)
 prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1)
-nanprod = Aggregation(
-    "nanprod",
-    chunk="nanprod",
-    combine="prod",
-    fill_value=1,
-    final_fill_value=dtypes.NA,
-)
+nanprod = Aggregation("nanprod", chunk="nanprod", combine="prod", fill_value=1)
 
 
 def _mean_finalize(sum_, count):
@@ -579,6 +574,7 @@ def _initialize_aggregation(
     }
 
     # Replace sentinel fill values according to dtype
+    agg.fill_value["user"] = fill_value
    agg.fill_value["intermediate"] = tuple(
        _get_fill_value(dt, fv)
        for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
@@ -613,7 +609,7 @@ def _initialize_aggregation(
     else:
         agg.min_count = 0
 
-    simple_combine: list[Callable] = []
+    simple_combine: list[Callable | None] = []
     for combine in agg.combine:
         if isinstance(combine, str):
             if combine in ["nanfirst", "nanlast"]:
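The new `OptionalFuncTuple` alias lets a `None` entry in `Aggregation.chunk`/`combine` mark a step that has no tree-reduction implementation and must run blockwise; the `agg.chunk[0] is None` guard in `flox/core.py` further down relies on this. A toy illustration of the convention (the aggregation here is hypothetical):

    from typing import Callable

    OptionalFuncTuple = tuple[Callable | str | None, ...]  # Python >= 3.10 syntax

    chunk: OptionalFuncTuple = (None,)  # hypothetical blockwise-only aggregation
    method = "map-reduce"

    if chunk[0] is None and method != "blockwise":
        # mirrors the NotImplementedError guard in groupby_reduce below
        print("unsupported: this aggregation only works with method='blockwise'")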
62 changes: 52 additions & 10 deletions flox/core.py
@@ -86,6 +86,26 @@
 DUMMY_AXIS = -2
 
 
+def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
+    """Account for numbagg not providing a fill_value kwarg."""
+    from .aggregate_numbagg import DEFAULT_FILL_VALUE
+
+    if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE:
+        return result
+    # The condition needs to be
+    # len(found_groups) < size; if so we mask with fill_value (?)
+    default_fv = DEFAULT_FILL_VALUE[func]
+    needs_masking = fill_value is not None and not np.array_equal(
+        fill_value, default_fv, equal_nan=True
+    )
+    groups = np.arange(size)
+    if needs_masking:
+        mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
+        if mask.any():
+            result[..., groups[mask]] = fill_value
+    return result
+
+
 def _issorted(arr: np.ndarray) -> bool:
     return bool((arr[:-1] <= arr[1:]).all())
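numbagg exposes no `fill_value` argument, so groups absent from a chunk come back holding numbagg's defaults (0 for `nansum`, NaN for `nanmean`, and so on, per `DEFAULT_FILL_VALUE`). When the requested fill value differs from that default, `_postprocess_numbagg` overwrites exactly the never-seen group slots. A standalone toy run of the same masking logic; the data is invented:

    import numpy as np

    size = 4
    seen_groups = np.array([0, 2])            # labels present in this chunk
    result = np.array([3.0, 0.0, 7.0, 0.0])   # nansum output; absent groups got 0

    fill_value, default_fv = np.nan, 0        # user wants NaN; numbagg default is 0
    needs_masking = fill_value is not None and not np.array_equal(
        fill_value, default_fv, equal_nan=True
    )
    if needs_masking:
        groups = np.arange(size)
        mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
        result[..., groups[mask]] = fill_value
    print(result)  # [ 3. nan  7. nan]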

@@ -780,7 +800,11 @@ def chunk_reduce(
     group_idx, grps, found_groups_shape, _, size, props = factorize_(
         (by,), axes, expected_groups=(expected_groups,), reindex=reindex, sort=sort
     )
-    groups = grps[0]
+    (groups,) = grps
+
+    # do this *before* possible broadcasting below.
+    # factorize_ has already taken care of offsetting
+    seen_groups = _unique(group_idx)
 
     order = "C"
     if nax > 1:
@@ -850,6 +874,16 @@ def chunk_reduce(
         result = generic_aggregate(
             group_idx, array, axis=-1, engine=engine, func=reduction, **kw_func
         ).astype(dt, copy=False)
+        if engine == "numbagg":
+            result = _postprocess_numbagg(
+                result,
+                func=reduction,
+                size=size,
+                fill_value=fv,
+                # Unfortunately, we cannot reuse found_groups, it has not
+                # been "offset" and is really expected_groups in nearly all cases
+                seen_groups=seen_groups,
+            )
         if np.any(props.nanmask):
             # remove NaN group label which should be last
             result = result[..., :-1]
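As the comments note, `seen_groups` must come from the factorized integer codes before any broadcasting, and `found_groups` cannot stand in for it because it is effectively `expected_groups`. A toy contrast, with `np.unique` standing in for flox's `_unique` and invented values:

    import numpy as np

    expected_groups = np.array([10, 20, 30])  # every label the caller asked for
    group_idx = np.array([0, 0, 2])           # factorized codes in this chunk
    seen_groups = np.unique(group_idx)        # [0, 2]; label 20 never occurs here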
@@ -1053,6 +1087,8 @@ def _grouped_combine(
     """Combine intermediates step of tree reduction."""
     from dask.utils import deepmap
 
+    combine = agg.combine
+
     if isinstance(x_chunk, dict):
         # Only one block at final step; skip one extra groupby
         return x_chunk
@@ -1093,7 +1129,8 @@ def _grouped_combine(
             results = chunk_argreduce(
                 array_idx,
                 groups,
-                func=agg.combine[slicer],  # count gets treated specially next
+                # count gets treated specially next
+                func=combine[slicer],  # type: ignore[arg-type]
                 axis=axis,
                 expected_groups=None,
                 fill_value=agg.fill_value["intermediate"][slicer],
@@ -1127,9 +1164,10 @@ def _grouped_combine(
     elif agg.reduction_type == "reduce":
         # Here we reduce the intermediates individually
         results = {"groups": None, "intermediates": []}
-        for idx, (combine, fv, dtype) in enumerate(
-            zip(agg.combine, agg.fill_value["intermediate"], agg.dtype["intermediate"])
+        for idx, (combine_, fv, dtype) in enumerate(
+            zip(combine, agg.fill_value["intermediate"], agg.dtype["intermediate"])
         ):
+            assert combine_ is not None
             array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis)
             if array.shape[-1] == 0:
                 # all empty when combined
@@ -1143,7 +1181,7 @@ def _grouped_combine(
             _results = chunk_reduce(
                 array,
                 groups,
-                func=combine,
+                func=combine_,
                 axis=axis,
                 expected_groups=None,
                 fill_value=(fv,),
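Hoisting `combine = agg.combine` and renaming the loop variable to `combine_` keeps the loop from rebinding the hoisted name; a two-line illustration of the hazard being avoided:

    combine = ("sum", "max")   # hoisted tuple, still needed after the loop
    for combine_ in combine:   # naming this `combine` would clobber the tuple
        pass
    assert combine == ("sum", "max")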
@@ -1788,8 +1826,13 @@ def _choose_engine(by, agg: Aggregation):
 
     # numbagg only supports nan-skipping reductions
     # without dtype specified
-    if HAS_NUMBAGG and "nan" in agg.name:
-        if not_arg_reduce and dtype is None:
+    has_blockwise_nan_skipping = (agg.chunk[0] is None and "nan" in agg.name) or any(
+        (isinstance(func, str) and "nan" in func) for func in agg.chunk
+    )
+    if HAS_NUMBAGG:
+        if agg.name in ["all", "any"] or (
+            not_arg_reduce and has_blockwise_nan_skipping and dtype is None
+        ):
             return "numbagg"
 
     if not_arg_reduce and (not is_duck_dask_array(by) and _issorted(by)):
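Engine selection now admits `all`/`any` outright (flox maps them to numbagg's `group_nanall`/`group_nanany`) and otherwise requires a nan-skipping chunk function and no pinned dtype. A sketch of the predicate with the surrounding state stubbed out; the wrapper function itself is hypothetical, but the expression follows the hunk above:

    def choose_numbagg(agg_name, chunk, not_arg_reduce, dtype, HAS_NUMBAGG=True):
        # stand-alone version of _choose_engine's numbagg branch
        has_blockwise_nan_skipping = (chunk[0] is None and "nan" in agg_name) or any(
            isinstance(func, str) and "nan" in func for func in chunk
        )
        return HAS_NUMBAGG and (
            agg_name in ["all", "any"]
            or (not_arg_reduce and has_blockwise_nan_skipping and dtype is None)
        )

    print(choose_numbagg("any", ("nanany",), True, None))          # True
    print(choose_numbagg("nansum", ("nansum",), True, None))       # True
    print(choose_numbagg("nansum", ("nansum",), True, "float32"))  # False: dtype pinned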
@@ -2050,7 +2093,7 @@ def groupby_reduce(
     nax = len(axis_)
 
     # When axis is a subset of possible values; then npg will
-    # apply it to groups that don't exist along a particular axis (for e.g.)
+    # apply the fill_value to groups that don't exist along a particular axis (for e.g.)
     # since these count as a group that is absent. thoo!
     # fill_value applies to all-NaN groups as well as labels in expected_groups that are not found.
     # The only way to do this consistently is mask out using min_count
@@ -2090,8 +2133,7 @@ def groupby_reduce(
         # TODO: How else to narrow that array.chunks is there?
         assert isinstance(array, DaskArray)
 
-        # TODO: fix typing of FuncTuple in Aggregation
-        if agg.chunk[0] is None and method != "blockwise":  # type: ignore[unreachable]
+        if agg.chunk[0] is None and method != "blockwise":
             raise NotImplementedError(
                 f"Aggregation {agg.name!r} is only implemented for dask arrays when method='blockwise'."
                 f"Received method={method!r}"
2 changes: 1 addition & 1 deletion flox/xrutils.py
@@ -339,6 +339,6 @@ def module_available(module: str, minversion: Optional[str] = None) -> bool:
     has = importlib.util.find_spec(module) is not None
     if has:
         mod = importlib.import_module(module)
-        return Version(mod.__version__) < Version(minversion) if minversion is not None else True
+        return Version(mod.__version__) >= Version(minversion) if minversion is not None else True
     else:
         return False
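The version comparison was inverted: with `<`, `module_available` returned True only for installs older than `minversion`, the opposite of its intent. Usage after the fix; the version number is illustrative:

    from flox.xrutils import module_available

    # True only when numbagg is importable and numbagg.__version__ >= 0.6.0
    if module_available("numbagg", minversion="0.6.0"):
        import numbagg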
1 change: 1 addition & 0 deletions pyproject.toml
@@ -111,6 +111,7 @@ show_error_codes = true
 warn_unused_ignores = true
 warn_unreachable = true
 enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
+exclude=["asv_bench/pkgs"]
 
 [[tool.mypy.overrides]]
 module=[