Commit b2be395

WIP
1 parent 0db44f7 commit b2be395

File tree

1 file changed: +80, -17 lines


flox/core.py

Lines changed: 80 additions & 17 deletions
@@ -68,7 +68,7 @@ def _get_expected_groups(by, sort, raise_if_dask=True) -> pd.Index | None:
     expected = pd.unique(flatby[~isnull(flatby)])
     if sort:
         expected = np.sort(expected)
-    return _convert_expected_groups_to_index(expected, isbin=False)
+    return _convert_expected_groups_to_index((expected,), isbin=(False,))[0]


 def _get_chunk_reduction(reduction_type: str) -> Callable:
@@ -388,7 +388,12 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:


 def factorize_(
-    by: tuple, axis, expected_groups: tuple[pd.Index, ...] = None, reindex=False, sort=True
+    by: tuple,
+    axis,
+    expected_groups: tuple[pd.Index, ...] = None,
+    reindex=False,
+    sort=True,
+    fastpath=False,
 ):
     """
     Returns an array of integer codes for groups (and associated data)
@@ -440,10 +445,13 @@ def factorize_(
     grp_shape = tuple(len(grp) for grp in found_groups)
     ngroups = np.prod(grp_shape)
     if len(by) > 1:
-        group_idx = np.ravel_multi_index(factorized, grp_shape).reshape(by[0].shape)
+        group_idx = np.ravel_multi_index(factorized, grp_shape)
     else:
         group_idx = factorized[0]

+    if fastpath:
+        return group_idx, found_groups, grp_shape
+
     if np.isscalar(axis) and groupvar.ndim > 1:
         # Not reducing along all dimensions of by
         # this is OK because for 3D by and axis=(1,2),
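
For illustration only (not part of this diff): a minimal standalone sketch of the flattening step that factorize_ now returns via the fastpath, using hypothetical label arrays. np.ravel_multi_index collapses the per-variable integer codes into a single flat group index over the product of group counts.

import numpy as np
import pandas as pd

# two hypothetical grouping variables, factorized independently
codes1, uniques1 = pd.factorize(np.array(["a", "b", "a", "b"]))  # codes1 == [0, 1, 0, 1]
codes2, uniques2 = pd.factorize(np.array([10, 10, 20, 20]))      # codes2 == [0, 0, 1, 1]
grp_shape = (len(uniques1), len(uniques2))                       # (2, 2)

# collapse the pair of codes into one flat group index, as in the hunk above
group_idx = np.ravel_multi_index((codes1, codes2), grp_shape)
# group_idx == [0, 2, 1, 3]; each value identifies one (uniques1, uniques2) pair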
@@ -1272,23 +1280,60 @@ def _assert_by_is_aligned(shape, by):
        )


-def _convert_expected_groups_to_index(expected_groups, isbin: bool) -> pd.Index | None:
-    if isinstance(expected_groups, pd.IntervalIndex) or (
-        isinstance(expected_groups, pd.Index) and not isbin
-    ):
-        return expected_groups
-    if isbin:
-        return pd.IntervalIndex.from_arrays(expected_groups[:-1], expected_groups[1:])
-    elif expected_groups is not None:
-        return pd.Index(expected_groups)
-    return expected_groups
+def _convert_expected_groups_to_index(expected_groups: tuple, isbin: bool) -> pd.Index | None:
+    out = []
+    for ex, isbin_ in zip(expected_groups, isbin):
+        if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin):
+            out.append(expected_groups)
+        elif ex is not None:
+            if isbin_:
+                out.append(pd.IntervalIndex.from_arrays(ex[:-1], ex[1:]))
+            else:
+                out.append(pd.Index(ex))
+        else:
+            assert ex is None
+            out.append(None)
+    return tuple(out)
+
+
+def _lazy_factorize_wrapper(*by, **kwargs):
+    group_idx, _ = factorize_(by, **kwargs)
+    return group_idx
+
+
+def _factorize_multiple(by, expected_groups, by_is_dask):
+    kwargs = dict(
+        expected_groups=expected_groups,
+        axis=None,  # always None, we offset later if necessary.
+        fastpath=True,
+    )
+    if by_is_dask:
+        import dask.array
+
+        group_idx = dask.array.map_blocks(
+            _lazy_factorize_wrapper,
+            *np.broadcast_arrays(*by),
+            meta=np.array((), dtype=np.int64),
+            **kwargs,
+        )
+        found_groups = tuple(None if is_duck_dask_array(b) else np.unique(b) for b in by)
+    else:
+        group_idx, found_groups, grp_shape = factorize_(by, **kwargs)
+
+    final_groups = tuple(
+        pd.Index(found) if expect is None else expect
+        for found, expect in zip(found_groups, expected_groups)
+    )
+
+    if any(grp is None for grp in final_groups):
+        raise
+    return (group_idx,), final_groups, grp_shape


 def groupby_reduce(
     array: np.ndarray | DaskArray,
-    by: np.ndarray | DaskArray,
+    *by: np.ndarray | DaskArray,
     func: str | Aggregation,
-    *,
     expected_groups: Sequence | np.ndarray | None = None,
     sort: bool = True,
     isbin: bool = False,
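
Again for illustration, not from the diff: the two pandas conversions the tuple-aware _convert_expected_groups_to_index applies per element, shown standalone with made-up inputs. Plain expected values become a pd.Index; with isbin=True, an array of bin edges becomes a pd.IntervalIndex.

import numpy as np
import pandas as pd

expected_labels = np.array([10, 20, 30])    # element with isbin=False
bin_edges = np.array([0.0, 1.0, 2.0, 3.0])  # element with isbin=True

labels_index = pd.Index(expected_labels)
# intervals (0, 1], (1, 2], (2, 3] built from consecutive edge pairs
bins_index = pd.IntervalIndex.from_arrays(bin_edges[:-1], bin_edges[1:])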
@@ -1402,6 +1447,7 @@ def groupby_reduce(
     reindex = _validate_reindex(reindex, func, method, expected_groups)

     by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
+    nby = len(by)
     by_is_dask = any(is_duck_dask_array(b) for b in by)
     if not is_duck_array(array):
         array = np.asarray(array)
@@ -1423,6 +1469,20 @@ def groupby_reduce(
     if expected_groups is not None and sort:
         expected_groups = expected_groups.sort_values()

+    # when grouping by multiple variables, we factorize early.
+    # TODO: could restrict this to dask-only
+    if nby > 1:
+        by, final_groups, grp_shape = _factorize_multiple(
+            by, expected_groups, by_is_dask=by_is_dask
+        )
+        expected_groups = (pd.RangeIndex(np.prod(grp_shape)),)
+    else:
+        final_groups = expected_groups
+
+    assert len(by) == 1
+    by = by[0]
+    expected_groups = expected_groups[0]
+
     if axis is None:
         axis = tuple(array.ndim + np.arange(-by.ndim, 0))
     else:
@@ -1434,7 +1494,7 @@ def groupby_reduce(
        )

    # TODO: make sure expected_groups is unique
-    if len(axis) == 1 and by_ndim > 1 and expected_groups[0] is None:
+    if len(axis) == 1 and by.ndim > 1 and expected_groups is None:
        # TODO: hack
        if not by_is_dask:
            expected_groups = _get_expected_groups(by, sort)
@@ -1540,7 +1600,7 @@ def groupby_reduce(
     result, *groups = partial_agg(
         array,
         by,
-        expected_groups=expected_groups,
+        expected_groups=None if method == "blockwise" else expected_groups,
         reindex=reindex,
         method=method,
         sort=sort,
@@ -1552,4 +1612,7 @@ def groupby_reduce(
         result = result[..., sorted_idx]
         groups = (groups[0][sorted_idx],)

+    if nby > 1:
+        groups = final_groups
+        result = result.reshape(result.shape[:-1] + grp_shape)
     return (result, *groups)
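
Taken together, the new *by signature and the final reshape point toward the calling convention sketched below. This is a speculative usage example based only on this WIP diff; the array and label values are made up, and the exact behaviour may still change in later commits.

import numpy as np
from flox.core import groupby_reduce

array = np.arange(12.0).reshape(3, 4)     # reduce over the last axis
labels1 = np.array(["a", "b", "a", "b"])  # first grouping variable
labels2 = np.array([10, 10, 20, 20])      # second grouping variable

# multiple `by` arrays are now passed positionally; func stays keyword-only
result, groups1, groups2 = groupby_reduce(array, labels1, labels2, func="sum")

# per the reshape at the end of groupby_reduce, the grouped axis is expanded
# to one trailing axis per grouping variable:
# result.shape == (3, len(groups1), len(groups2)) == (3, 2, 2)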
