@@ -68,7 +68,7 @@ def _get_expected_groups(by, sort, raise_if_dask=True) -> pd.Index | None:
     expected = pd.unique(flatby[~isnull(flatby)])
     if sort:
         expected = np.sort(expected)
-    return _convert_expected_groups_to_index(expected, isbin=False)
+    return _convert_expected_groups_to_index((expected,), isbin=(False,))[0]


 def _get_chunk_reduction(reduction_type: str) -> Callable:
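Context for the one-line change above, as a standalone sketch (toy code, not part of this diff): the converter now works element-wise over tuples, so the lone `expected` array is wrapped in a 1-tuple and the single converted index is taken back out with `[0]`. Per element, plain labels become a `pd.Index` and bin edges become a `pd.IntervalIndex`:

```python
import numpy as np
import pandas as pd

# Toy illustration of the per-element conversion, not flox code.
labels = np.array(["a", "b", "c"])          # isbin=False -> plain Index
edges = np.array([0.0, 1.0, 2.0, 3.0])      # isbin=True  -> IntervalIndex

as_index = pd.Index(labels)
as_bins = pd.IntervalIndex.from_arrays(edges[:-1], edges[1:])

print(as_index)  # Index(['a', 'b', 'c'], dtype='object')
print(as_bins)   # IntervalIndex([(0.0, 1.0], (1.0, 2.0], (2.0, 3.0]], ...)
```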
@@ -388,7 +388,12 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:


 def factorize_(
-    by: tuple, axis, expected_groups: tuple[pd.Index, ...] = None, reindex=False, sort=True
+    by: tuple,
+    axis,
+    expected_groups: tuple[pd.Index, ...] = None,
+    reindex=False,
+    sort=True,
+    fastpath=False,
 ):
     """
     Returns an array of integer codes for groups (and associated data)
@@ -440,10 +445,13 @@ def factorize_(
     grp_shape = tuple(len(grp) for grp in found_groups)
     ngroups = np.prod(grp_shape)
     if len(by) > 1:
-        group_idx = np.ravel_multi_index(factorized, grp_shape).reshape(by[0].shape)
+        group_idx = np.ravel_multi_index(factorized, grp_shape)
     else:
         group_idx = factorized[0]

+    if fastpath:
+        return group_idx, found_groups, grp_shape
+
     if np.isscalar(axis) and groupvar.ndim > 1:
         # Not reducing along all dimensions of by
         # this is OK because for 3D by and axis=(1,2),
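The fastpath added above returns the flat integer codes without reshaping them back to `by[0].shape`. Collapsing several per-variable codes into one flat code is plain `np.ravel_multi_index`; a minimal standalone sketch (not flox code) of that step:

```python
import numpy as np

# Per-variable integer codes for 6 elements, e.g. from factorizing two label
# arrays with 2 and 3 unique values respectively.
codes_a = np.array([0, 0, 1, 1, 0, 1])
codes_b = np.array([0, 1, 2, 0, 1, 2])
grp_shape = (2, 3)

# One flat group code per element, in range(np.prod(grp_shape)).
group_idx = np.ravel_multi_index((codes_a, codes_b), grp_shape)
print(group_idx)  # [0 1 5 3 1 5]

# The mapping is invertible, which is what allows reshaping reduced results
# back to grp_shape at the very end of groupby_reduce.
print(np.unravel_index(group_idx, grp_shape))
```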
@@ -1272,23 +1280,60 @@ def _assert_by_is_aligned(shape, by):
             )


-def _convert_expected_groups_to_index(expected_groups, isbin: bool) -> pd.Index | None:
-    if isinstance(expected_groups, pd.IntervalIndex) or (
-        isinstance(expected_groups, pd.Index) and not isbin
-    ):
-        return expected_groups
-    if isbin:
-        return pd.IntervalIndex.from_arrays(expected_groups[:-1], expected_groups[1:])
-    elif expected_groups is not None:
-        return pd.Index(expected_groups)
-    return expected_groups
+def _convert_expected_groups_to_index(expected_groups: tuple, isbin: bool) -> pd.Index | None:
+    out = []
+    for ex, isbin_ in zip(expected_groups, isbin):
+        if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin_):
+            out.append(ex)
+        elif ex is not None:
+            if isbin_:
+                out.append(pd.IntervalIndex.from_arrays(ex[:-1], ex[1:]))
+            else:
+                out.append(pd.Index(ex))
+        else:
+            assert ex is None
+            out.append(None)
+    return tuple(out)
+
+
+def _lazy_factorize_wrapper(*by, **kwargs):
+    group_idx, *_ = factorize_(by, **kwargs)
+    return group_idx
+
+
+def _factorize_multiple(by, expected_groups, by_is_dask):
+    kwargs = dict(
+        expected_groups=expected_groups,
+        axis=None,  # always None, we offset later if necessary.
+        fastpath=True,
+    )
+    if by_is_dask:
+        import dask.array
+
+        group_idx = dask.array.map_blocks(
+            _lazy_factorize_wrapper,
+            *np.broadcast_arrays(*by),
+            meta=np.array((), dtype=np.int64),
+            **kwargs,
+        )
+        found_groups = tuple(None if is_duck_dask_array(b) else np.unique(b) for b in by)
+    else:
+        group_idx, found_groups, grp_shape = factorize_(by, **kwargs)
+
+    final_groups = tuple(
+        pd.Index(found) if expect is None else expect
+        for found, expect in zip(found_groups, expected_groups)
+    )
+
+    if any(grp is None for grp in final_groups):
+        raise ValueError("Please provide expected_groups when grouping by dask arrays.")
+    return (group_idx,), final_groups, grp_shape


 def groupby_reduce(
     array: np.ndarray | DaskArray,
-    by: np.ndarray | DaskArray,
+    *by: np.ndarray | DaskArray,
     func: str | Aggregation,
-    *,
     expected_groups: Sequence | np.ndarray | None = None,
     sort: bool = True,
     isbin: bool = False,
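For dask inputs, the new `_factorize_multiple` above defers the factorization to `dask.array.map_blocks`. A minimal sketch of that pattern (assumes dask is installed; `factorize_block` and the example arrays are hypothetical stand-ins, not flox code):

```python
import dask.array as da
import numpy as np
import pandas as pd

def factorize_block(*labels):
    # Factorize each label array within a block, then collapse the codes with
    # ravel_multi_index. Codes computed per block are only consistent across
    # blocks when the unique values are known up front, which is why the diff
    # threads expected_groups through to factorize_ in the dask case.
    codes, sizes = [], []
    for lab in labels:
        c, uniques = pd.factorize(lab.reshape(-1))
        codes.append(c)
        sizes.append(len(uniques))
    return np.ravel_multi_index(codes, sizes).reshape(labels[0].shape)

lat = da.from_array(np.repeat([10, 20], 6).reshape(3, 4), chunks=(3, 2))
land = da.from_array(np.tile([0, 1], 6).reshape(3, 4), chunks=(3, 2))

# meta tells dask the blockwise result is an integer array without computing it.
group_idx = da.map_blocks(factorize_block, lat, land, meta=np.array((), dtype=np.int64))
print(group_idx.compute())
```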
@@ -1402,6 +1447,7 @@ def groupby_reduce(
     reindex = _validate_reindex(reindex, func, method, expected_groups)

     by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
+    nby = len(by)
     by_is_dask = any(is_duck_dask_array(b) for b in by)
     if not is_duck_array(array):
         array = np.asarray(array)
@@ -1423,6 +1469,20 @@ def groupby_reduce(
     if expected_groups is not None and sort:
         expected_groups = expected_groups.sort_values()

+    # when grouping by multiple variables, we factorize early.
+    # TODO: could restrict this to dask-only
+    if nby > 1:
+        by, final_groups, grp_shape = _factorize_multiple(
+            by, expected_groups, by_is_dask=by_is_dask
+        )
+        expected_groups = (pd.RangeIndex(np.prod(grp_shape)),)
+    else:
+        final_groups = expected_groups
+
+    assert len(by) == 1
+    by = by[0]
+    expected_groups = expected_groups[0]
+
     if axis is None:
         axis = tuple(array.ndim + np.arange(-by.ndim, 0))
     else:
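Why `pd.RangeIndex(np.prod(grp_shape))` works as the expected groups of the combined grouper: once every combination of the original groupers has been collapsed to a single flat code, the full set of possible codes is simply `0..prod(grp_shape)-1`. A tiny illustration (the per-variable expected groups here are hypothetical):

```python
import numpy as np
import pandas as pd

# Hypothetical per-variable expected groups, for illustration only.
expected_lat = pd.Index([10, 20])
expected_land = pd.IntervalIndex.from_breaks([0, 1, 2, 3])

grp_shape = (len(expected_lat), len(expected_land))
flat_expected = pd.RangeIndex(np.prod(grp_shape))
print(flat_expected)  # RangeIndex(start=0, stop=6, step=1)
```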
@@ -1434,7 +1494,7 @@ def groupby_reduce(
         )

     # TODO: make sure expected_groups is unique
-    if len(axis) == 1 and by_ndim > 1 and expected_groups[0] is None:
+    if len(axis) == 1 and by.ndim > 1 and expected_groups is None:
         # TODO: hack
         if not by_is_dask:
             expected_groups = _get_expected_groups(by, sort)
@@ -1540,7 +1600,7 @@ def groupby_reduce(
         result, *groups = partial_agg(
             array,
             by,
-            expected_groups=expected_groups,
+            expected_groups=None if method == "blockwise" else expected_groups,
             reindex=reindex,
             method=method,
             sort=sort,
@@ -1552,4 +1612,7 @@ def groupby_reduce(
         result = result[..., sorted_idx]
         groups = (groups[0][sorted_idx],)

+    if nby > 1:
+        groups = final_groups
+        result = result.reshape(result.shape[:-1] + grp_shape)
     return (result, *groups)
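Putting the pieces together, a numpy-only sketch (illustrative, not flox code) of the overall multiple-`by` strategy in this diff: factorize each grouping variable, collapse the codes into one flat group index, reduce along that single grouper, then reshape the trailing axis of the result to `grp_shape`, mirroring `result.reshape(result.shape[:-1] + grp_shape)` in the final hunk:

```python
import numpy as np
import pandas as pd

array = np.arange(12.0)
lat = np.repeat([10, 20], 6)    # first grouping variable
land = np.tile([0, 1, 2], 4)    # second grouping variable

# Factorize each grouper and collapse to one flat code per element.
codes, groups, sizes = [], [], []
for by in (lat, land):
    c, uniques = pd.factorize(by)
    codes.append(c)
    groups.append(pd.Index(uniques))
    sizes.append(len(uniques))
grp_shape = tuple(sizes)
group_idx = np.ravel_multi_index(codes, grp_shape)

# Reduce over the single flattened grouper (here: a sum via bincount), then
# restore one result axis per grouping variable.
flat_sum = np.bincount(group_idx, weights=array, minlength=int(np.prod(grp_shape)))
result = flat_sum.reshape(grp_shape)
print(groups)   # [Index([10, 20], ...), Index([0, 1, 2], ...)]
print(result)   # shape (2, 3): one sum per (lat, land) combination
```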