Commit 8fc1977

Refactor before redoing cohorts (#164)

1 parent 91b6e19 · commit 8fc1977

File tree: 2 files changed, +46 −37 lines

asv_bench/benchmarks/combine.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -58,4 +58,4 @@ def construct_member(groups):
         ]
 
         self.x_chunk_cohorts = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4]
-        self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,), "neg_axis": (-1,)}
+        self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,)}
```

flox/core.py

Lines changed: 45 additions & 36 deletions
```diff
@@ -880,7 +880,6 @@ def _grouped_combine(
     agg: Aggregation,
     axis: T_Axes,
     keepdims: bool,
-    neg_axis: T_Axes,
     engine: T_Engine,
     is_aggregate: bool = False,
     sort: bool = True,
```
```diff
@@ -906,6 +905,9 @@ def _grouped_combine(
         partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
     )
 
+    # these are negative axis indices useful for concatenating the intermediates
+    neg_axis = tuple(range(-len(axis), 0))
+
     groups = _conc2(x_chunk, "groups", axis=neg_axis)
 
     if agg.reduction_type == "argreduce":
```
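The moved snippet is small enough to check in isolation. A minimal sketch (with a made-up `axis` value) of what the negative indices buy when concatenating intermediates:

```python
import numpy as np

# Reduction over the trailing two axes of a 4-d intermediate: the
# negative indices are (-2, -1), valid for an array of any rank.
axis = (2, 3)
neg_axis = tuple(range(-len(axis), 0))
assert neg_axis == (-2, -1)

# np.concatenate accepts a negative axis, so the same indices work
# no matter how many leading (non-reduced) dimensions remain.
blocks = [np.ones((1, 1, 2, 2)), np.ones((1, 1, 2, 2))]
assert np.concatenate(blocks, axis=neg_axis[-1]).shape == (1, 1, 2, 4)
```

Computing this inside `_grouped_combine`, rather than threading it through as a keyword argument, is what lets both the benchmark and `dask_groupby_agg` drop `neg_axis` entirely.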
```diff
@@ -1068,6 +1070,30 @@ def _reduce_blockwise(
     return result
 
 
+def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
+    import dask.array
+    from dask.highlevelgraph import HighLevelGraph
+
+    layer: dict[tuple, tuple] = {}
+    groups_token = f"group-{reduced.name}"
+    first_block = reduced.ndim * (0,)
+    layer[(groups_token, *first_block)] = (
+        operator.getitem,
+        (reduced.name, *first_block),
+        "groups",
+    )
+    groups: tuple[DaskArray] = (
+        dask.array.Array(
+            HighLevelGraph.from_collections(groups_token, layer, dependencies=[reduced]),
+            groups_token,
+            chunks=group_chunks,
+            meta=np.array([], dtype=dtype),
+        ),
+    )
+
+    return groups
+
+
 def dask_groupby_agg(
     array: DaskArray,
     by: DaskArray | np.ndarray,
```
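The trick the helper factors out is that each block of `reduced` is a dict, so a single `operator.getitem` task can pull the `"groups"` entry out of the first block. A standalone sketch of the same graph construction, using a toy stand-in for `reduced` (the names and values here are illustrative, not flox's):

```python
import operator

import dask.array
import numpy as np
from dask.highlevelgraph import HighLevelGraph

# Toy one-block "reduced": the block payload is a dict, mimicking the
# intermediates that flox's tree reduction produces.
def make_block():
    return {"mean": np.array([1.0, 2.0]), "groups": np.array([10, 20])}

reduced = dask.array.Array(
    {("reduced-demo", 0): (make_block,)},
    "reduced-demo",
    chunks=((2,),),
    meta=np.array([], dtype=object),
)

# One new task that extracts the "groups" entry from block (0,).
token = f"group-{reduced.name}"
layer = {(token, 0): (operator.getitem, (reduced.name, 0), "groups")}
groups = dask.array.Array(
    HighLevelGraph.from_collections(token, layer, dependencies=[reduced]),
    token,
    chunks=((np.nan,),),  # number of groups is unknown until compute time
    meta=np.array([], dtype=np.int64),
)
print(groups.compute())  # [10 20]
```

Note the helper also switches from `dtype=by.dtype` in the inlined version to `meta=np.array([], dtype=dtype)`, which tells dask the block's array type as well as its dtype.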
```diff
@@ -1189,14 +1215,11 @@ def dask_groupby_agg(
     group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
 
     if method == "map-reduce":
-        # these are negative axis indices useful for concatenating the intermediates
-        neg_axis = tuple(range(-len(axis), 0))
-
         combine: Callable[..., IntermediateDict]
         if do_simple_combine:
             combine = _simple_combine
         else:
-            combine = partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort)
+            combine = partial(_grouped_combine, engine=engine, sort=sort)
 
         # reduced is really a dict mapping reduction name to array
         # and "groups" to an array of group labels
```
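For readers unfamiliar with the pattern, `partial` pre-binds the keyword-only options so the tree reduction can later call `combine` with just the data arguments. A toy illustration (the signature below is abbreviated, not `_grouped_combine`'s full one):

```python
from functools import partial

def grouped_combine(x_chunk, agg, axis, keepdims, engine="numpy", sort=True):
    # Stand-in for the real combine step: just report what it was given.
    return f"combined {len(x_chunk)} intermediates (engine={engine}, sort={sort})"

combine = partial(grouped_combine, engine="numbagg", sort=False)

# A _tree_reduce-style call site supplies only the remaining arguments.
print(combine([{}, {}], agg=None, axis=(-1,), keepdims=True))
# combined 2 intermediates (engine=numbagg, sort=False)
```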
```diff
@@ -1219,10 +1242,19 @@ def dask_groupby_agg(
             keepdims=True,
             concatenate=False,
         )
-        output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
+
+        if is_duck_dask_array(by_input) and expected_groups is None:
+            groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
+        else:
+            if expected_groups is None:
+                expected_groups_ = _get_expected_groups(by_input, sort=sort)
+            else:
+                expected_groups_ = expected_groups
+            groups = (expected_groups_.to_numpy(),)
+
     elif method == "blockwise":
         reduced = intermediate
-        # Here one input chunk → one output chunka
+        # Here one input chunk → one output chunks
         # find number of groups in each chunk, this is needed for output chunks
         # along the reduced axis
         slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
```
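The non-dask branch above resolves the output group labels eagerly. A sketch of what that resolution amounts to when `expected_groups is None`, assuming `_get_expected_groups` boils down to unique-then-sort over the label array (an assumption about its internals):

```python
import numpy as np
import pandas as pd

by_input = np.array([3, 1, 3, 2, 1])

# Hypothetical equivalent of _get_expected_groups(by_input, sort=True):
# the unique labels as a pandas Index, sorted.
expected_groups_ = pd.Index(pd.unique(by_input.ravel())).sort_values()

# The final `groups` is a 1-tuple of a NumPy array, mirroring the dask
# branch's 1-tuple of a dask array from _extract_unknown_groups.
groups = (expected_groups_.to_numpy(),)
print(groups)  # (array([1, 2, 3]),)
```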
```diff
@@ -1235,41 +1267,17 @@ def dask_groupby_agg(
             groups_in_block = tuple(
                 np.intersect1d(by_input[slc], expected_groups) for slc in slices
             )
+        groups = (np.concatenate(groups_in_block),)
+
         ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
-        output_chunks = reduced.chunks[: -(len(axis))] + (ngroups_per_block,)
+        group_chunks = (ngroups_per_block,)
+
     else:
         raise ValueError(f"Unknown method={method}.")
 
     # extract results from the dict
-    layer: dict[tuple, tuple] = {}
+    output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
     ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
-    if is_duck_dask_array(by_input) and expected_groups is None:
-        groups_name = f"groups-{name}-{token}"
-        # we've used keepdims=True, so _tree_reduce preserves some dummy dimensions
-        first_block = len(ochunks) * (0,)
-        layer[(groups_name, *first_block)] = (
-            operator.getitem,
-            (reduced.name, *first_block),
-            "groups",
-        )
-        groups: tuple[np.ndarray | DaskArray] = (
-            dask.array.Array(
-                HighLevelGraph.from_collections(groups_name, layer, dependencies=[reduced]),
-                groups_name,
-                chunks=group_chunks,
-                dtype=by.dtype,
-            ),
-        )
-    else:
-        if method == "map-reduce":
-            if expected_groups is None:
-                expected_groups_ = _get_expected_groups(by_input, sort=sort)
-            else:
-                expected_groups_ = expected_groups
-            groups = (expected_groups_.to_numpy(),)
-        else:
-            groups = (np.concatenate(groups_in_block),)
-
     layer2: dict[tuple, tuple] = {}
     agg_name = f"{name}-{token}"
     for ochunk in itertools.product(*ochunks):
```
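To make the blockwise bookkeeping concrete, here is a toy run of the per-chunk group discovery followed by the now-shared `output_chunks` computation (the data and chunk sizes are invented; `slices_from_chunks` is dask's helper from `dask.array.slicing`):

```python
import numpy as np
from dask.array.slicing import slices_from_chunks

by_input = np.array([0, 0, 1, 1, 2, 2])
expected_groups = np.array([0, 1, 2])
axis = (0,)
chunks = ((3, 3),)  # two chunks along the (single) reduced axis

# Which labels appear in each chunk determines the output chunk sizes.
slices = slices_from_chunks(chunks)
groups_in_block = tuple(np.intersect1d(by_input[slc], expected_groups) for slc in slices)
groups = (np.concatenate(groups_in_block),)
ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
group_chunks = (ngroups_per_block,)
print(group_chunks)  # ((2, 2),)

# The shared epilogue: drop the reduced-axis chunks, append group chunks
# (split_out == 1 here, so only len(axis) trailing axes are dropped).
reduced_chunks = ((3, 3),)  # stand-in for reduced.chunks
split_out = 1
output_chunks = reduced_chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
print(output_chunks)  # ((2, 2),)
```

Hoisting `output_chunks` below the `if`/`elif` is the point of the refactor: both methods now produce `group_chunks` and share one epilogue instead of computing output chunks separately.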
```diff
@@ -1624,6 +1632,7 @@ def groupby_reduce(
             f"\n\n Received: {func}"
         )
 
+    # TODO: just do this in dask_groupby_agg
     # we always need some fill_value (see above) so choose the default if needed
     if kwargs["fill_value"] is None:
         kwargs["fill_value"] = agg.fill_value[agg.name]
```

0 commit comments