Commit 35dd38d
More consistent fill_value handling.
fill_value now applies to groups with no observations and groups with all-NaN observations. This seems to be the only way to keep the dask and numpy pathways consistent.
1 parent 92d8442 commit 35dd38d
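
The inconsistency comes from the identity elements of the nan-skipping reductions: the numpy pathway can reduce an all-NaN group directly, while the dask pathway must reduce chunkwise with a fill value and so cannot tell an all-NaN group from a group of zeros. A small numpy illustration (not part of the commit):

import numpy as np

# nansum's identity element erases the "all-NaN" information:
print(np.sum([np.nan]))     # nan
print(np.nansum([np.nan]))  # 0.0

Masking with a count (min_count) and then writing fill_value is what keeps the two pathways in agreement.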

File tree: 4 files changed, +146 −75 lines changed


flox/aggregate_npg.py

Lines changed: 22 additions & 0 deletions
@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy as np
 import numpy_groupies as npg
 
@@ -64,3 +66,23 @@ def nansum_of_squares(group_idx, array, engine, *, axis=-1, size=None, fill_valu
         axis=axis,
         dtype=dtype,
     )
+
+
+def _len(group_idx, array, engine, *, func, axis=-1, size=None, fill_value=None, dtype=None):
+    result = _get_aggregate(engine).aggregate(
+        group_idx,
+        array,
+        axis=axis,
+        func=func,
+        size=size,
+        fill_value=0,
+        dtype=np.int64,
+    )
+    if fill_value is not None:
+        result = result.astype(np.array([fill_value]).dtype)
+        result[result == 0] = fill_value
+    return result
+
+
+len = partial(_len, func="len")
+nanlen = partial(_len, func="nanlen")
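
The astype-then-mask step in _len is needed because the counts are accumulated as int64, while fill_value may be a float such as NaN. A standalone sketch of just that idiom, with made-up counts:

import numpy as np

counts = np.array([3, 0, 2, 0])  # per-group counts; 0 marks an empty group
fill_value = np.nan

# np.array([fill_value]).dtype infers a dtype wide enough for fill_value
# (float64 for NaN), so the integer counts are promoted before masking.
result = counts.astype(np.array([fill_value]).dtype)
result[result == 0] = fill_value
print(result)  # [ 3. nan  2. nan]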

flox/aggregations.py

Lines changed: 35 additions & 4 deletions
@@ -107,7 +107,7 @@ def __init__(
         self.preprocess = preprocess
         # Use "chunk_reduce" or "chunk_argreduce"
         self.reduction_type = reduction_type
-        self.numpy = numpy if numpy else self.name
+        self.numpy = (numpy,) if numpy else (self.name,)
         # initialize blockwise reduction
         self.chunk = _atleast_1d(chunk)
         # how to aggregate results after first round of reduction
@@ -163,6 +163,7 @@ def __repr__(self):
                 f"combine: {self.combine}",
                 f"aggregate: {self.aggregate}",
                 f"finalize: {self.finalize}",
+                f"min_count: {self.min_count}",
             )
         )
 
@@ -265,9 +266,9 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
 
 
 min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
-nanmin = Aggregation("nanmin", chunk="nanmin", combine="min", fill_value=dtypes.INF)
+nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
 max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
-nanmax = Aggregation("nanmax", chunk="nanmax", combine="max", fill_value=dtypes.NINF)
+nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
 
 
 def argreduce_preprocess(array, axis):
@@ -409,7 +410,13 @@ def _zip_index(array_, idx_):
 }
 
 
-def _initialize_aggregation(func: str | Aggregation, array_dtype, fill_value) -> Aggregation:
+def _initialize_aggregation(
+    func: str | Aggregation,
+    array_dtype,
+    fill_value,
+    min_count: int,
+    finalize_kwargs,
+) -> Aggregation:
     if not isinstance(func, Aggregation):
         try:
             # TODO: need better interface
@@ -425,6 +432,7 @@ def _initialize_aggregation(func: str | Aggregation, array_dtype, fill_value) ->
         raise ValueError("Bad type for func. Expected str or Aggregation")
 
     agg.dtype[func] = _normalize_dtype(agg.dtype[func], array_dtype, fill_value)
+    agg.dtype["numpy"] = (agg.dtype[func],)
     agg.dtype["intermediate"] = [
         _normalize_dtype(dtype, array_dtype) for dtype in agg.dtype["intermediate"]
     ]
@@ -435,4 +443,27 @@ def _initialize_aggregation(func: str | Aggregation, array_dtype, fill_value) ->
         for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
     )
     agg.fill_value[func] = _get_fill_value(agg.dtype[func], agg.fill_value[func])
+
+    fv = fill_value if fill_value is not None else agg.fill_value[agg.name]
+    agg.fill_value["numpy"] = (fv,)
+
+    if finalize_kwargs is not None:
+        assert isinstance(finalize_kwargs, dict)
+        agg.finalize_kwargs = finalize_kwargs
+
+    # This is needed for the dask pathway.
+    # Because we use intermediate fill_value since a group could be
+    # absent in one block, but present in another block
+    # We set it for numpy to get nansum, nanprod tests to pass
+    # where the identity element is 0, 1
+    if min_count is not None:
+        agg.min_count = min_count
+        agg.chunk += ("nanlen",)
+        agg.numpy += ("nanlen",)
+        agg.combine += ("sum",)
+        agg.fill_value["intermediate"] += (0,)
+        agg.fill_value["numpy"] += (0,)
+        agg.dtype["intermediate"] += (np.intp,)
+        agg.dtype["numpy"] += (np.intp,)
+
     return agg
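
_initialize_aggregation now centralizes the pairing of every reduction with a trailing "nanlen" count whenever min_count is set; the count is later used to mask groups below the threshold. A minimal plain-numpy sketch of that mask-by-count technique (no flox or numpy_groupies; the labels and values are invented for illustration):

import numpy as np

values = np.array([1.0, 2.0, np.nan, np.nan])
labels = np.array([0, 0, 1, 1])  # group 2 never appears at all
ngroups, min_count, fill_value = 3, 1, np.nan

sums = np.zeros(ngroups)
counts = np.zeros(ngroups, dtype=np.intp)
np.add.at(sums, labels, np.nan_to_num(values))                  # per-group "nansum"
np.add.at(counts, labels, (~np.isnan(values)).astype(np.intp))  # per-group "nanlen"

# Groups with fewer than min_count non-NaN observations (the all-NaN
# group 1 and the absent group 2) receive fill_value.
result = sums.astype(np.float64)
result[counts < min_count] = fill_value
print(result)  # [ 3. nan nan]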

flox/core.py

Lines changed: 29 additions & 44 deletions
@@ -164,7 +164,7 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
     labels = np.asarray(labels)
 
     if method == "split-reduce":
-        return pd.unique(labels.ravel()).reshape(-1, 1).tolist()
+        return _get_expected_groups(labels, sort=False).values.reshape(-1, 1).tolist()
 
     # Build an array with the shape of labels, but where every element is the "chunk number"
     # 1. First subset the array appropriately
@@ -630,6 +630,8 @@ def chunk_reduce(
     # counts are needed for the final result as well as for masking
     # optimize that out.
     previous_reduction = None
+    for param in (fill_value, kwargs, dtype):
+        assert len(param) >= len(func)
     for reduction, fv, kw, dt in zip(func, fill_value, kwargs, dtype):
         if empty:
             result = np.full(shape=final_array_shape, fill_value=fv)
@@ -953,13 +955,10 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
     Blockwise groupby reduction that produces the final result. This code path is
     also used for non-dask array aggregations.
     """
-
     # for pure numpy grouping, we just use npg directly and avoid "finalizing"
     # (agg.finalize = None). We still need to do the reindexing step in finalize
     # so that everything matches the dask version.
     agg.finalize = None
-    # xarray's count is npg's nanlen
-    func: tuple[str] = (agg.numpy, "nanlen")
 
     assert agg.finalize_kwargs is not None
     finalize_kwargs = agg.finalize_kwargs
@@ -970,14 +969,14 @@
     results = chunk_reduce(
         array,
         by,
-        func=func,
+        func=agg.numpy,
         axis=axis,
         expected_groups=expected_groups,
         # This fill_value should only apply to groups that only contain NaN observations
         # BUT there is funkiness when axis is a subset of all possible values
         # (see below)
-        fill_value=(agg.fill_value[agg.name], 0),
-        dtype=(agg.dtype[agg.name], np.intp),
+        fill_value=agg.fill_value["numpy"],
+        dtype=agg.dtype["numpy"],
         kwargs=finalize_kwargs,
         engine=engine,
         sort=sort,
@@ -989,36 +988,20 @@
         # so replace -1 with 0; unravel; then replace 0 with -1
         # UGH!
         idx = results["intermediates"][0]
-        mask = idx == -1
+        mask = idx == agg.fill_value["numpy"][0]
         idx[mask] = 0
         # Fix npg bug where argmax with nD array, 1D group_idx, axis=-1
         # will return wrong indices
         idx = np.unravel_index(idx, array.shape)[-1]
-        idx[mask] = -1
+        idx[mask] = agg.fill_value["numpy"][0]
         results["intermediates"][0] = idx
     elif agg.name in ["nanvar", "nanstd"]:
-        # Fix npg bug where all-NaN rows are 0 instead of NaN
+        # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
        value, counts = results["intermediates"]
        mask = counts <= 0
        value[mask] = np.nan
        results["intermediates"][0] = value
 
-    # When axis is a subset of possible values; then npg will
-    # apply it to groups that don't exist along a particular axis (for e.g.)
-    # since these count as a group that is absent. thoo!
-    # TODO: the "count" bit is a hack to make tests pass.
-    if len(axis) < by.ndim and agg.min_count is None and agg.name != "count":
-        agg.min_count = 1
-
-    # This fill_value applies to members of expected_groups not seen in groups
-    # or when the min_count threshold is not satisfied
-    # Use xarray's dtypes.NA to match type promotion rules
-    if fill_value is None:
-        if agg.name in ["any", "all"]:
-            fill_value = False
-        elif not _is_arg_reduction(agg):
-            fill_value = xrdtypes.NA
-
     result = _finalize_results(results, agg, axis, expected_groups, fill_value=fill_value)
     return result
 
@@ -1444,20 +1427,33 @@ def groupby_reduce(
     array = _move_reduce_dims_to_end(array, axis)
     axis = tuple(array.ndim + np.arange(-len(axis), 0))
 
+    has_dask = is_duck_dask_array(array) or is_duck_dask_array(by)
+
+    # When axis is a subset of possible values; then npg will
+    # apply it to groups that don't exist along a particular axis (for e.g.)
+    # since these count as a group that is absent. thoo!
+    # fill_value applies to all-NaN groups as well as labels in expected_groups that are not found.
+    # The only way to do this consistently is mask out using min_count
+    # Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
+    if min_count is None:
+        if (
+            len(axis) < by.ndim
+            or fill_value is not None
+            # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
+            or (not has_dask and isinstance(func, str) and func in ["nanvar", "nanstd"])
+        ):
+            min_count = 1
+
+    # TODO: set in xarray?
     if min_count is not None and func in ["nansum", "nanprod"] and fill_value is None:
         # nansum, nanprod have fill_value=0, 1
         # overwrite than when min_count is set
         fill_value = np.nan
 
-    agg = _initialize_aggregation(func, array.dtype, fill_value)
-    agg.min_count = min_count
-    if finalize_kwargs is not None:
-        assert isinstance(finalize_kwargs, dict)
-        agg.finalize_kwargs = finalize_kwargs
-
     kwargs = dict(axis=axis, fill_value=fill_value, engine=engine, sort=sort)
+    agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)
 
-    if not is_duck_dask_array(array) and not is_duck_dask_array(by):
+    if not has_dask:
         results = _reduce_blockwise(array, by, agg, expected_groups=expected_groups, **kwargs)
         groups = (results["groups"],)
         result = results[agg.name]
@@ -1466,21 +1462,10 @@
         if agg.chunk is None:
             raise NotImplementedError(f"{func} not implemented for dask arrays")
 
-        if agg.min_count is None:
-            # This is needed for the dask pathway.
-            # Because we use intermediate fill_value since a group could be
-            # absent in one block, but present in another block
-            agg.min_count = 1
-
         # we always need some fill_value (see above) so choose the default if needed
         if kwargs["fill_value"] is None:
             kwargs["fill_value"] = agg.fill_value[agg.name]
 
-        agg.chunk += ("nanlen",)
-        agg.combine += ("sum",)
-        agg.fill_value["intermediate"] += (0,)
-        agg.dtype["intermediate"] += (np.intp,)
-
         partial_agg = partial(dask_groupby_agg, agg=agg, split_out=split_out, **kwargs)
 
         if method in ["split-reduce", "cohorts"]:
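
With the numpy pathway now auto-enabling min_count whenever fill_value is passed, the user-visible behavior the commit message describes looks roughly like this (a hedged sketch: the call follows groupby_reduce as defined in this commit, but treat it as illustrative rather than an exact API reference):

import numpy as np
# from flox.core import groupby_reduce

array = np.array([1.0, np.nan, np.nan])
by = np.array([0, 1, 1])  # label 2 appears only in expected_groups

# result, groups = groupby_reduce(
#     array, by, func="nansum",
#     expected_groups=[0, 1, 2],
#     fill_value=np.nan,  # now applied to all-NaN group 1 AND absent group 2
# )
# expected: result == [1.0, nan, nan]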
