@@ -164,7 +164,7 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
164
164
labels = np .asarray (labels )
165
165
166
166
if method == "split-reduce" :
167
- return pd . unique (labels . ravel ()) .reshape (- 1 , 1 ).tolist ()
167
+ return _get_expected_groups (labels , sort = False ). values .reshape (- 1 , 1 ).tolist ()
168
168
169
169
# Build an array with the shape of labels, but where every element is the "chunk number"
170
170
# 1. First subset the array appropriately
@@ -630,6 +630,8 @@ def chunk_reduce(
630
630
# counts are needed for the final result as well as for masking
631
631
# optimize that out.
632
632
previous_reduction = None
633
+ for param in (fill_value , kwargs , dtype ):
634
+ assert len (param ) >= len (func )
633
635
for reduction , fv , kw , dt in zip (func , fill_value , kwargs , dtype ):
634
636
if empty :
635
637
result = np .full (shape = final_array_shape , fill_value = fv )
@@ -953,13 +955,10 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
953
955
Blockwise groupby reduction that produces the final result. This code path is
954
956
also used for non-dask array aggregations.
955
957
"""
956
-
957
958
# for pure numpy grouping, we just use npg directly and avoid "finalizing"
958
959
# (agg.finalize = None). We still need to do the reindexing step in finalize
959
960
# so that everything matches the dask version.
960
961
agg .finalize = None
961
- # xarray's count is npg's nanlen
962
- func : tuple [str ] = (agg .numpy , "nanlen" )
963
962
964
963
assert agg .finalize_kwargs is not None
965
964
finalize_kwargs = agg .finalize_kwargs
@@ -970,14 +969,14 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
970
969
results = chunk_reduce (
971
970
array ,
972
971
by ,
973
- func = func ,
972
+ func = agg . numpy ,
974
973
axis = axis ,
975
974
expected_groups = expected_groups ,
976
975
# This fill_value should only apply to groups that only contain NaN observations
977
976
# BUT there is funkiness when axis is a subset of all possible values
978
977
# (see below)
979
- fill_value = ( agg .fill_value [agg . name ], 0 ) ,
980
- dtype = ( agg .dtype [agg . name ], np . intp ) ,
978
+ fill_value = agg .fill_value ["numpy" ] ,
979
+ dtype = agg .dtype ["numpy" ] ,
981
980
kwargs = finalize_kwargs ,
982
981
engine = engine ,
983
982
sort = sort ,
@@ -989,36 +988,20 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
989
988
# so replace -1 with 0; unravel; then replace 0 with -1
990
989
# UGH!
991
990
idx = results ["intermediates" ][0 ]
992
- mask = idx == - 1
991
+ mask = idx == agg . fill_value [ "numpy" ][ 0 ]
993
992
idx [mask ] = 0
994
993
# Fix npg bug where argmax with nD array, 1D group_idx, axis=-1
995
994
# will return wrong indices
996
995
idx = np .unravel_index (idx , array .shape )[- 1 ]
997
- idx [mask ] = - 1
996
+ idx [mask ] = agg . fill_value [ "numpy" ][ 0 ]
998
997
results ["intermediates" ][0 ] = idx
999
998
elif agg .name in ["nanvar" , "nanstd" ]:
1000
- # Fix npg bug where all-NaN rows are 0 instead of NaN
999
+ # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
1001
1000
value , counts = results ["intermediates" ]
1002
1001
mask = counts <= 0
1003
1002
value [mask ] = np .nan
1004
1003
results ["intermediates" ][0 ] = value
1005
1004
1006
- # When axis is a subset of possible values; then npg will
1007
- # apply it to groups that don't exist along a particular axis (for e.g.)
1008
- # since these count as a group that is absent. thoo!
1009
- # TODO: the "count" bit is a hack to make tests pass.
1010
- if len (axis ) < by .ndim and agg .min_count is None and agg .name != "count" :
1011
- agg .min_count = 1
1012
-
1013
- # This fill_value applies to members of expected_groups not seen in groups
1014
- # or when the min_count threshold is not satisfied
1015
- # Use xarray's dtypes.NA to match type promotion rules
1016
- if fill_value is None :
1017
- if agg .name in ["any" , "all" ]:
1018
- fill_value = False
1019
- elif not _is_arg_reduction (agg ):
1020
- fill_value = xrdtypes .NA
1021
-
1022
1005
result = _finalize_results (results , agg , axis , expected_groups , fill_value = fill_value )
1023
1006
return result
1024
1007
@@ -1444,20 +1427,33 @@ def groupby_reduce(
1444
1427
array = _move_reduce_dims_to_end (array , axis )
1445
1428
axis = tuple (array .ndim + np .arange (- len (axis ), 0 ))
1446
1429
1430
+ has_dask = is_duck_dask_array (array ) or is_duck_dask_array (by )
1431
+
1432
+ # When axis is a subset of possible values; then npg will
1433
+ # apply it to groups that don't exist along a particular axis (for example)
1434
+ # since these count as a group that is absent. thoo!
1435
+ # fill_value applies to all-NaN groups as well as labels in expected_groups that are not found.
1436
+ # The only way to do this consistently is to mask out using min_count
1437
+ # Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
1438
+ if min_count is None :
1439
+ if (
1440
+ len (axis ) < by .ndim
1441
+ or fill_value is not None
1442
+ # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
1443
+ or (not has_dask and isinstance (func , str ) and func in ["nanvar" , "nanstd" ])
1444
+ ):
1445
+ min_count = 1
1446
+
1447
+ # TODO: set in xarray?
1447
1448
if min_count is not None and func in ["nansum" , "nanprod" ] and fill_value is None :
1448
1449
# nansum, nanprod have fill_value=0, 1
1449
1450
# overwrite that when min_count is set
1450
1451
fill_value = np .nan
1451
1452
1452
- agg = _initialize_aggregation (func , array .dtype , fill_value )
1453
- agg .min_count = min_count
1454
- if finalize_kwargs is not None :
1455
- assert isinstance (finalize_kwargs , dict )
1456
- agg .finalize_kwargs = finalize_kwargs
1457
-
1458
1453
kwargs = dict (axis = axis , fill_value = fill_value , engine = engine , sort = sort )
1454
+ agg = _initialize_aggregation (func , array .dtype , fill_value , min_count , finalize_kwargs )
1459
1455
1460
- if not is_duck_dask_array ( array ) and not is_duck_dask_array ( by ) :
1456
+ if not has_dask :
1461
1457
results = _reduce_blockwise (array , by , agg , expected_groups = expected_groups , ** kwargs )
1462
1458
groups = (results ["groups" ],)
1463
1459
result = results [agg .name ]
@@ -1466,21 +1462,10 @@ def groupby_reduce(
1466
1462
if agg .chunk is None :
1467
1463
raise NotImplementedError (f"{ func } not implemented for dask arrays" )
1468
1464
1469
- if agg .min_count is None :
1470
- # This is needed for the dask pathway.
1471
- # Because we use intermediate fill_value since a group could be
1472
- # absent in one block, but present in another block
1473
- agg .min_count = 1
1474
-
1475
1465
# we always need some fill_value (see above) so choose the default if needed
1476
1466
if kwargs ["fill_value" ] is None :
1477
1467
kwargs ["fill_value" ] = agg .fill_value [agg .name ]
1478
1468
1479
- agg .chunk += ("nanlen" ,)
1480
- agg .combine += ("sum" ,)
1481
- agg .fill_value ["intermediate" ] += (0 ,)
1482
- agg .dtype ["intermediate" ] += (np .intp ,)
1483
-
1484
1469
partial_agg = partial (dask_groupby_agg , agg = agg , split_out = split_out , ** kwargs )
1485
1470
1486
1471
if method in ["split-reduce" , "cohorts" ]:
0 commit comments