@@ -880,7 +880,6 @@ def _grouped_combine(
     agg: Aggregation,
     axis: T_Axes,
     keepdims: bool,
-    neg_axis: T_Axes,
     engine: T_Engine,
     is_aggregate: bool = False,
     sort: bool = True,
@@ -906,6 +905,9 @@ def _grouped_combine(
             partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
         )

+    # these are negative axis indices useful for concatenating the intermediates
+    neg_axis = tuple(range(-len(axis), 0))
+
     groups = _conc2(x_chunk, "groups", axis=neg_axis)

     if agg.reduction_type == "argreduce":
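An aside on the relocated computation (illustration, not part of the diff): because the intermediates always carry the reduced dimensions last, negative indices address them no matter the input's rank. A minimal sketch with a hypothetical two-axis reduction:

    # Reducing over two axes: the last two dimensions of the intermediates
    # are the ones to concatenate along, whatever the input's ndim.
    axis = (1, 2)
    neg_axis = tuple(range(-len(axis), 0))
    assert neg_axis == (-2, -1)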
@@ -1068,6 +1070,30 @@ def _reduce_blockwise(
     return result


+def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
+    import dask.array
+    from dask.highlevelgraph import HighLevelGraph
+
+    layer: dict[tuple, tuple] = {}
+    groups_token = f"group-{reduced.name}"
+    first_block = reduced.ndim * (0,)
+    layer[(groups_token, *first_block)] = (
+        operator.getitem,
+        (reduced.name, *first_block),
+        "groups",
+    )
+    groups: tuple[DaskArray] = (
+        dask.array.Array(
+            HighLevelGraph.from_collections(groups_token, layer, dependencies=[reduced]),
+            groups_token,
+            chunks=group_chunks,
+            meta=np.array([], dtype=dtype),
+        ),
+    )
+
+    return groups
+
+
 def dask_groupby_agg(
     array: DaskArray,
     by: DaskArray | np.ndarray,
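The new helper adds a single getitem task that plucks the "groups" entry out of the final reduced block. A standalone sketch of the same pattern (the names "groups-demo" and the dict-returning block are made up for illustration; operator.getitem, dask.array.Array, and HighLevelGraph.from_collections are the real APIs):

    import operator

    import dask.array as da
    import numpy as np
    from dask.highlevelgraph import HighLevelGraph

    # Hypothetical upstream collection whose single block is a dict holding "groups".
    reduced = da.from_array(np.array([2, 1, 1]), chunks=3).map_blocks(
        lambda x: {"groups": np.unique(x)}, meta=np.array([], dtype=object)
    )
    # One getitem task: fetch the "groups" key from the first (and only) block.
    layer = {("groups-demo", 0): (operator.getitem, (reduced.name, 0), "groups")}
    groups = da.Array(
        HighLevelGraph.from_collections("groups-demo", layer, dependencies=[reduced]),
        "groups-demo",
        chunks=((np.nan,),),  # number of groups is unknown until compute time
        meta=np.array([], dtype=np.int64),
    )
    print(groups.compute())  # -> array([1, 2])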
@@ -1189,14 +1215,11 @@ def dask_groupby_agg(
     group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)

     if method == "map-reduce":
-        # these are negative axis indices useful for concatenating the intermediates
-        neg_axis = tuple(range(-len(axis), 0))
-
         combine: Callable[..., IntermediateDict]
         if do_simple_combine:
             combine = _simple_combine
         else:
-            combine = partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort)
+            combine = partial(_grouped_combine, engine=engine, sort=sort)

         # reduced is really a dict mapping reduction name to array
         # and "groups" to an array of group labels
@@ -1219,10 +1242,19 @@ def dask_groupby_agg(
             keepdims=True,
             concatenate=False,
         )
-        output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
+
+        if is_duck_dask_array(by_input) and expected_groups is None:
+            groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
+        else:
+            if expected_groups is None:
+                expected_groups_ = _get_expected_groups(by_input, sort=sort)
+            else:
+                expected_groups_ = expected_groups
+            groups = (expected_groups_.to_numpy(),)
+
     elif method == "blockwise":
         reduced = intermediate
-        # Here one input chunk → one output chunka
+        # Here one input chunk → one output chunk
         # find number of groups in each chunk, this is needed for output chunks
         # along the reduced axis
         slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
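In the non-dask fallback the group labels are materialized eagerly. Roughly, and assuming _get_expected_groups boils down to sorted unique labels as a pandas Index (a simplification of the real helper, for illustration only):

    import numpy as np
    import pandas as pd

    by_input = np.array(["b", "a", "c", "a"])
    # stand-in for _get_expected_groups(by_input, sort=True)
    expected_groups_ = pd.Index(np.unique(by_input))  # sorted unique labels
    groups = (expected_groups_.to_numpy(),)  # -> (array(['a', 'b', 'c'], ...),)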
@@ -1235,41 +1267,17 @@ def dask_groupby_agg(
         groups_in_block = tuple(
             np.intersect1d(by_input[slc], expected_groups) for slc in slices
         )
+        groups = (np.concatenate(groups_in_block),)
+
         ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
-        output_chunks = reduced.chunks[: -(len(axis))] + (ngroups_per_block,)
+        group_chunks = (ngroups_per_block,)
+
     else:
         raise ValueError(f"Unknown method={method}.")

     # extract results from the dict
-    layer: dict[tuple, tuple] = {}
+    output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
     ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
-    if is_duck_dask_array(by_input) and expected_groups is None:
-        groups_name = f"groups-{name}-{token}"
-        # we've used keepdims=True, so _tree_reduce preserves some dummy dimensions
-        first_block = len(ochunks) * (0,)
-        layer[(groups_name, *first_block)] = (
-            operator.getitem,
-            (reduced.name, *first_block),
-            "groups",
-        )
-        groups: tuple[np.ndarray | DaskArray] = (
-            dask.array.Array(
-                HighLevelGraph.from_collections(groups_name, layer, dependencies=[reduced]),
-                groups_name,
-                chunks=group_chunks,
-                dtype=by.dtype,
-            ),
-        )
-    else:
-        if method == "map-reduce":
-            if expected_groups is None:
-                expected_groups_ = _get_expected_groups(by_input, sort=sort)
-            else:
-                expected_groups_ = expected_groups
-            groups = (expected_groups_.to_numpy(),)
-        else:
-            groups = (np.concatenate(groups_in_block),)
-
     layer2: dict[tuple, tuple] = {}
     agg_name = f"{name}-{token}"
     for ochunk in itertools.product(*ochunks):
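A worked example of the blockwise bookkeeping and the now-unified output_chunks expression, with made-up chunking (slices_from_chunks is dask's real helper; all values are hypothetical):

    import numpy as np
    from dask.array.core import slices_from_chunks

    by_input = np.array([0, 0, 1, 1, 2, 2])  # 1-d labels, chunked (3, 3)
    expected_groups = np.array([0, 1, 2])

    slices = slices_from_chunks(((3, 3),))
    groups_in_block = tuple(np.intersect1d(by_input[slc], expected_groups) for slc in slices)
    # -> (array([0, 1]), array([1, 2]))
    groups = (np.concatenate(groups_in_block),)
    group_chunks = (tuple(len(grp) for grp in groups_in_block),)  # ((2, 2),)

    # Both methods now feed the same expression: the reduced axis (or axes,
    # plus a split_out axis if split_out > 1) is replaced by the group axis.
    reduced_chunks = ((3, 3),)  # hypothetical chunks of `reduced`
    axis, split_out = (0,), 1
    output_chunks = reduced_chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
    assert output_chunks == ((2, 2),)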
@@ -1624,6 +1632,7 @@ def groupby_reduce(
            f"\n\nReceived: {func}"
        )

+    # TODO: just do this in dask_groupby_agg
     # we always need some fill_value (see above) so choose the default if needed
     if kwargs["fill_value"] is None:
         kwargs["fill_value"] = agg.fill_value[agg.name]