From 6b93ad6cc4a6b6144c75867bc5c64405bd8575c4 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 12:46:47 -0600 Subject: [PATCH 01/14] More efficient cohorts. Closes #140 We apply the cohort "split" step after the blockwise reduction, then use the tree reduction on each cohort. We also use the `.blocks` accessor to index out blocks. This is still a bit inefficient since we split by indexing out regular arrays, so we could index out blocks that don't contain any cohort members. However, because we are splitting _after_ the blockwise reduction, the amount of work duplication can be a lot less than splitting the bare array. One side-effect is that "split-reduce" is now a synonym for "cohorts". The reason is that find_group_cohorts returns a dict mapping blocks to cohorts. We could invert that behaviour but I don't see any benefit to trying to figure that out. --- flox/core.py | 202 ++++++++++++++++++++++++++------------------- tests/test_core.py | 4 +- 2 files changed, 121 insertions(+), 85 deletions(-) diff --git a/flox/core.py b/flox/core.py index 10a9197a9..fdc33a70a 100644 --- a/flox/core.py +++ b/flox/core.py @@ -138,7 +138,7 @@ def _get_optimal_chunks_for_groups(chunks, labels): @memoize -def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "cohorts"): +def find_group_cohorts(labels, chunks, merge: bool = True, method: T_MethodCohorts = "cohorts"): """ Finds groups labels that occur together aka "cohorts" @@ -168,9 +168,6 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co # To do this, we must have values in memory so casting to numpy should be safe labels = np.asarray(labels) - if method == "split-reduce": - return list(_get_expected_groups(labels, sort=False).to_numpy().reshape(-1, 1)) - # Build an array with the shape of labels, but where every element is the "chunk number" # 1. First subset the array appropriately axis = range(-labels.ndim, 0) @@ -219,9 +216,9 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co merged_cohorts[k1].extend(v2) merged_keys.append(k2) - return merged_cohorts.values() + return merged_cohorts else: - return chunks_cohorts.values() + return chunks_cohorts def rechunk_for_cohorts( @@ -1070,6 +1067,43 @@ def _reduce_blockwise( return result +def subset_to_blocks( + array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None +) -> DaskArray: + """ + Advanced indexing of .blocks such that we always get a regular array back. + + Parameters + ---------- + array : dask.array + flatblocks : flat indices of blocks to extract + blkshape : shape of blocks with which to unravel flatblocks + + Returns + ------- + dask.array + """ + if blkshape is None: + blkshape = array.blocks.shape + + unraveled = np.unravel_index(flatblocks, blkshape) + if len(flatblocks) == 1: + idxr = [slice(None, None)] * (array.ndim - len(blkshape)) + idxr = idxr + [a.item() for a in np.unravel_index(flatblocks, blkshape)] + return array.blocks[tuple(idxr)] + + subset = array + for ax, inds in enumerate(unraveled): + inds = np.unique(inds) + if np.array_equal(inds, np.arange(blkshape[ax])): + continue + + idxr = [slice(None, None)] * array.ndim + idxr[array.ndim - len(blkshape) + ax] = inds + subset = subset.blocks[tuple(idxr)] + return subset + + def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]: import dask.array from dask.highlevelgraph import HighLevelGraph @@ -1106,6 +1140,7 @@ def dask_groupby_agg( reindex: bool = False, engine: T_Engine = "numpy", sort: bool = True, + chunks_cohorts=None, ) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]: import dask.array @@ -1185,7 +1220,7 @@ def dask_groupby_agg( partial( blockwise_method, axis=axis, - expected_groups=expected_groups, + expected_groups=None if method in ["split-reduce", "cohorts"] else expected_groups, engine=engine, sort=sort, ), @@ -1214,44 +1249,80 @@ def dask_groupby_agg( expected_groups = _get_expected_groups(by_input, sort=sort) group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),) - if method == "map-reduce": + if method in ["map-reduce", "cohorts", "split-reduce"]: combine: Callable[..., IntermediateDict] if do_simple_combine: combine = _simple_combine else: combine = partial(_grouped_combine, engine=engine, sort=sort) - # reduced is really a dict mapping reduction name to array - # and "groups" to an array of group labels + # Each chunk of `reduced`` is really a dict mapping + # 1. reduction name to array + # 2. "groups" to an array of group labels # Note: it does not make sense to interpret axis relative to # shape of intermediate results after the blockwise call - reduced = dask.array.reductions._tree_reduce( - intermediate, - aggregate=partial( - _aggregate, - combine=combine, - agg=agg, - expected_groups=None if split_out > 1 else expected_groups, - fill_value=fill_value, - reindex=reindex, - ), + tree_reduce = partial( + dask.array.reductions._tree_reduce, combine=partial(combine, agg=agg), - name=f"{name}-reduce", + name=f"{name}-reduce-{method}", dtype=array.dtype, axis=axis, keepdims=True, concatenate=False, ) - - if is_duck_dask_array(by_input) and expected_groups is None: - groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype) - else: - if expected_groups is None: - expected_groups_ = _get_expected_groups(by_input, sort=sort) + aggregate = partial( + _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex + ) + if method == "map-reduce": + reduced = tree_reduce( + intermediate, + aggregate=partial( + aggregate, expected_groups=None if split_out > 1 else expected_groups + ), + ) + if is_duck_dask_array(by_input) and expected_groups is None: + groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype) else: - expected_groups_ = expected_groups + if expected_groups is None: + expected_groups_ = _get_expected_groups(by_input, sort=sort) + else: + expected_groups_ = expected_groups groups = (expected_groups_.to_numpy(),) + elif method in ["cohorts", "split-reduce"]: + chunks_cohorts = find_group_cohorts( + by, [array.chunks[ax] for ax in axis], merge=True, method=method + ) + reduced_ = [] + groups_ = [] + for blks, cohort in chunks_cohorts.items(): + cohort = sorted(cohort) + subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :]) + if do_simple_combine: + # reindex so that reindex can be set to True later + reindexed = dask.array.map_blocks( + reindex_intermediates, + subset, + agg=agg, + unique_groups=cohort, + meta=subset._meta, + ) + else: + reindexed = subset + + reduced_.append( + tree_reduce( + reindexed, + aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex), + ) + ) + groups_.append(cohort) + + reduced = dask.array.concatenate(reduced_, axis=-1) + groups = (np.concatenate(groups_),) + group_chunks = (tuple(len(cohort) for cohort in groups_),) + # compute_blocks(reduced) + elif method == "blockwise": reduced = intermediate # Here one input chunk → one output chunks @@ -1288,7 +1359,12 @@ def dask_groupby_agg( nblocks = tuple(len(array.chunks[ax]) for ax in axis) inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks) else: - inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1) + inchunk = ( + ochunk[:-1] + + (0,) * (len(axis) - 1) + + (ochunk[-1],) # always 0 for map-reduce, something else for cohorts, split-reduce + + (ochunk[-1],) * int((split_out > 1)) + ) layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name) result = dask.array.Array( @@ -1317,6 +1393,9 @@ def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_gro if method in ["split-reduce", "cohorts"] and reindex is False: raise NotImplementedError + if method in ["split-reduce", "cohorts"] and reindex is None: + reindex = True + # TODO: Should reindex be a bool-only at this point? Would've been nice but # None's are relied on after this function as well. return reindex @@ -1639,61 +1718,18 @@ def groupby_reduce( partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs) - if method in ["split-reduce", "cohorts"]: - cohorts = find_group_cohorts( - by_, [array.chunks[ax] for ax in axis_], merge=True, method=method - ) + if method == "blockwise" and by_.ndim == 1: + array = rechunk_for_blockwise(array, axis=-1, labels=by_) - results_ = [] - groups_ = [] - for cohort in cohorts: - cohort = sorted(cohort) - # equivalent of xarray.DataArray.where(mask, drop=True) - mask = np.isin(by_, cohort) - indexer = [np.unique(v) for v in np.nonzero(mask)] - array_subset = array - for ax, idxr in zip(range(-by_.ndim, 0), indexer): - array_subset = np.take(array_subset, idxr, axis=ax) - numblocks = math.prod([len(array_subset.chunks[ax]) for ax in axis_]) - - # get final result for these groups - r, *g = partial_agg( - array_subset, - by_[np.ix_(*indexer)], - expected_groups=pd.Index(cohort), - # First deep copy becasue we might be doping blockwise, - # which sets agg.finalize=None, then map-reduce (GH102) - agg=copy.deepcopy(agg), - # reindex to expected_groups at the blockwise step. - # this approach avoids replacing non-cohort members with - # np.nan or some other sentinel value, and preserves dtypes - reindex=True, - # sort controls the final output order so apply that at the end - sort=False, - # if only a single block along axis, we can just work blockwise - # inspired by https://github.com/dask/dask/issues/8361 - method="blockwise" if numblocks == 1 and nax == by_.ndim else "map-reduce", - ) - results_.append(r) - groups_.append(cohort) - - # concatenate results together, - # sort to make sure we match expected output - groups = (np.hstack(groups_),) - result = np.concatenate(results_, axis=-1) - else: - if method == "blockwise" and by_.ndim == 1: - array = rechunk_for_blockwise(array, axis=-1, labels=by_) - - result, groups = partial_agg( - array, - by_, - expected_groups=None if method == "blockwise" else expected_groups, - agg=agg, - reindex=reindex, - method=method, - sort=sort, - ) + result, groups = partial_agg( + array, + by_, + expected_groups=None if method == "blockwise" else expected_groups, + agg=agg, + reindex=reindex, + method=method, + sort=sort, + ) if sort and method != "map-reduce": assert len(groups) == 1 diff --git a/tests/test_core.py b/tests/test_core.py index c70ae83c6..b799c572f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -776,9 +776,9 @@ def test_cohorts_nd_by(func, method, axis, engine): assert_equal(actual, expected) actual, groups = groupby_reduce(array, by, sort=False, **kwargs) - if method == "cohorts": + if method in ("split-reduce", "cohorts"): assert_equal(groups, [4, 3, 40, 2, 31, 1, 30]) - elif method in ("split-reduce", "map-reduce"): + elif method == "map-reduce": assert_equal(groups, [1, 30, 2, 31, 3, 4, 40]) elif method == "blockwise": assert_equal(groups, [1, 30, 2, 31, 3, 40, 4]) From 628c64a9af8b193fb864688a23428a17e81678ef Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 16:15:46 -0600 Subject: [PATCH 02/14] Fix for split_out > 1 --- flox/core.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/flox/core.py b/flox/core.py index fdc33a70a..7f2a13607 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1359,12 +1359,11 @@ def dask_groupby_agg( nblocks = tuple(len(array.chunks[ax]) for ax in axis) inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks) else: - inchunk = ( - ochunk[:-1] - + (0,) * (len(axis) - 1) - + (ochunk[-1],) # always 0 for map-reduce, something else for cohorts, split-reduce - + (ochunk[-1],) * int((split_out > 1)) - ) + inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + if split_out > 1: + inchunk = inchunk + (0,) + inchunk = inchunk + (ochunk[-1],) + layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name) result = dask.array.Array( From 23e0ddfe9de71e9c20cfa5d45aab56044f474ac1 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 16:27:20 -0600 Subject: [PATCH 03/14] Fix test_find_group_cohorts --- tests/test_core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index b799c572f..a0b0537d8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -658,12 +658,12 @@ def test_rechunk_for_blockwise(inchunks, expected): ], ) def test_find_group_cohorts(expected, labels, chunks, merge): - actual = list(find_group_cohorts(labels, (chunks,), merge, method="cohorts")) + actual = list(find_group_cohorts(labels, (chunks,), merge, method="cohorts").values()) assert actual == expected, (actual, expected) - actual = find_group_cohorts(labels, (chunks,), merge, method="split-reduce") - expected = [[label] for label in np.unique(labels)] - assert actual == expected, (actual, expected) + # actual = find_group_cohorts(labels, (chunks,), merge, method="split-reduce") + # expected = [[label] for label in np.unique(labels)] + # assert actual == expected, (actual, expected) @requires_dask From bcd430f1acc1f443c29936f6de7337c38d20d1be Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 16:28:54 -0600 Subject: [PATCH 04/14] More test fixes --- flox/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index 7f2a13607..3f7d40b05 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1287,11 +1287,11 @@ def dask_groupby_agg( expected_groups_ = _get_expected_groups(by_input, sort=sort) else: expected_groups_ = expected_groups - groups = (expected_groups_.to_numpy(),) + groups = (expected_groups_.to_numpy(),) elif method in ["cohorts", "split-reduce"]: chunks_cohorts = find_group_cohorts( - by, [array.chunks[ax] for ax in axis], merge=True, method=method + by_input, [array.chunks[ax] for ax in axis], merge=True, method=method ) reduced_ = [] groups_ = [] From fa13992cce5c970584a7e8750fb816cd98f5b8f9 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 16:49:25 -0600 Subject: [PATCH 05/14] small fixes --- flox/core.py | 1 - flox/visualize.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/flox/core.py b/flox/core.py index 3f7d40b05..9970c90be 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1321,7 +1321,6 @@ def dask_groupby_agg( reduced = dask.array.concatenate(reduced_, axis=-1) groups = (np.concatenate(groups_),) group_chunks = (tuple(len(cohort) for cohort in groups_),) - # compute_blocks(reduced) elif method == "blockwise": reduced = intermediate diff --git a/flox/visualize.py b/flox/visualize.py index c3cd6c816..fd712fd4b 100644 --- a/flox/visualize.py +++ b/flox/visualize.py @@ -136,10 +136,10 @@ def visualize_cohorts_2d(by, array, method="cohorts"): print("finding cohorts...") before_merged = find_group_cohorts( by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=False, method=method - ) + ).values() merged = find_group_cohorts( by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=True, method=method - ) + ).values() print("finished cohorts...") xticks = np.cumsum(array.chunks[-1]) From c48041c966e699b6c7a994ee0d6bb239a21f9136 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 19:29:13 -0600 Subject: [PATCH 06/14] Optimize subset_to_blocks --- flox/core.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/flox/core.py b/flox/core.py index 9970c90be..e0104a859 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1087,20 +1087,39 @@ def subset_to_blocks( blkshape = array.blocks.shape unraveled = np.unravel_index(flatblocks, blkshape) - if len(flatblocks) == 1: - idxr = [slice(None, None)] * (array.ndim - len(blkshape)) - idxr = idxr + [a.item() for a in np.unravel_index(flatblocks, blkshape)] - return array.blocks[tuple(idxr)] - - subset = array - for ax, inds in enumerate(unraveled): - inds = np.unique(inds) - if np.array_equal(inds, np.arange(blkshape[ax])): - continue + normalized = [] + for ax, idx in enumerate(unraveled): + i = np.unique(idx).squeeze() + if i.ndim == 0: + i = i.item() + else: + if np.array_equal(i, np.arange(blkshape[ax])): + i = slice(None) + elif np.array_equal(i, np.arange(i[0], i[-1] + 1)): + i = slice(i[0], i[-1] + 1) + normalized.append(i) + full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized) + + # has no iterables + noiter = tuple(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized) + # has all iterables + alliter = { + ax: i if hasattr(i, "__len__") else slice(None) for ax, i in enumerate(full_normalized) + } + + # apply everything but the iterables + if all(i == slice(None) for i in noiter): + return array + subset = array.blocks[noiter] + + for ax, inds in alliter.items(): + if isinstance(inds, slice): + continue idxr = [slice(None, None)] * array.ndim - idxr[array.ndim - len(blkshape) + ax] = inds + idxr[ax] = inds subset = subset.blocks[tuple(idxr)] + return subset From ed7f9350a1375cdc91972537ebc6625990cf0f54 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 19:34:48 -0600 Subject: [PATCH 07/14] Fix mypy --- flox/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flox/core.py b/flox/core.py index e0104a859..aada90def 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1091,13 +1091,13 @@ def subset_to_blocks( for ax, idx in enumerate(unraveled): i = np.unique(idx).squeeze() if i.ndim == 0: - i = i.item() + i_ = i.item() else: if np.array_equal(i, np.arange(blkshape[ax])): - i = slice(None) + i_ = slice(None) elif np.array_equal(i, np.arange(i[0], i[-1] + 1)): - i = slice(i[0], i[-1] + 1) - normalized.append(i) + i_ = slice(i[0], i[-1] + 1) + normalized.append(i_) full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized) # has no iterables From 04c7f413d4cd9c49bf5f69022a3fa6f9582c7337 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 20:32:58 -0600 Subject: [PATCH 08/14] bugfix --- flox/core.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/flox/core.py b/flox/core.py index aada90def..f6050ea8a 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1087,17 +1087,18 @@ def subset_to_blocks( blkshape = array.blocks.shape unraveled = np.unravel_index(flatblocks, blkshape) - normalized = [] + normalized: list[Union[int, np.ndarray, slice]] = [] for ax, idx in enumerate(unraveled): i = np.unique(idx).squeeze() if i.ndim == 0: - i_ = i.item() + normalized.append(i.item()) else: if np.array_equal(i, np.arange(blkshape[ax])): - i_ = slice(None) + normalized.append(slice(None)) elif np.array_equal(i, np.arange(i[0], i[-1] + 1)): - i_ = slice(i[0], i[-1] + 1) - normalized.append(i_) + normalized.append(slice(i[0], i[-1] + 1)) + else: + normalized.append(i) full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized) # has no iterables From 3754cfff470e7f36ef95389124538a87416ea859 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 11 Oct 2022 13:29:10 -0600 Subject: [PATCH 09/14] Bring back split-reduce --- flox/core.py | 60 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 18 deletions(-) diff --git a/flox/core.py b/flox/core.py index f6050ea8a..651052ede 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1124,6 +1124,37 @@ def subset_to_blocks( return subset +def reduce_cohorts(array, cohort, tree_reduce, do_simple_combine, reindex, agg, aggregate): + """Loop over multiple "cohorts" and apply the reduction. The loop is only + used for split-reduce.""" + import dask.array + + reduced = [] + groups = [] + for member in cohort: + member1d = np.atleast_1d(member) + if do_simple_combine: + # reindex so that reindex can be set to True later + reindexed = dask.array.map_blocks( + reindex_intermediates, + array, + agg=agg, + unique_groups=member1d, + meta=array._meta, + ) + else: + reindexed = array + + reduced.append( + tree_reduce( + reindexed, + aggregate=partial(aggregate, expected_groups=member1d, reindex=reindex), + ) + ) + groups.append(member1d) + return reduced, groups + + def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]: import dask.array from dask.highlevelgraph import HighLevelGraph @@ -1318,25 +1349,17 @@ def dask_groupby_agg( for blks, cohort in chunks_cohorts.items(): cohort = sorted(cohort) subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :]) - if do_simple_combine: - # reindex so that reindex can be set to True later - reindexed = dask.array.map_blocks( - reindex_intermediates, - subset, - agg=agg, - unique_groups=cohort, - meta=subset._meta, - ) - else: - reindexed = subset - - reduced_.append( - tree_reduce( - reindexed, - aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex), - ) + r, g = reduce_cohorts( + subset, + (cohort,) if method == "cohorts" else cohort, + tree_reduce, + do_simple_combine, + reindex, + agg, + aggregate, ) - groups_.append(cohort) + reduced_.extend(r) + groups_.extend(g) reduced = dask.array.concatenate(reduced_, axis=-1) groups = (np.concatenate(groups_),) @@ -1751,6 +1774,7 @@ def groupby_reduce( if sort and method != "map-reduce": assert len(groups) == 1 + assert groups[0].ndim == 1 sorted_idx = np.argsort(groups[0]) result = result[..., sorted_idx] groups = (groups[0][sorted_idx],) From f7351d944a94d33ff13e93230d882178a773caad Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 11 Oct 2022 15:07:48 -0600 Subject: [PATCH 10/14] Sort cohorts at detection stage. We sort the members in each cohort (as earlier). And also by first label in each cohort. This means we preserve order as much as possible, which should help when sorting the final result, especially for resampling type operations. --- flox/core.py | 11 ++++++++--- tests/test_core.py | 10 ++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/flox/core.py b/flox/core.py index 651052ede..f49d59c97 100644 --- a/flox/core.py +++ b/flox/core.py @@ -193,7 +193,7 @@ def find_group_cohorts(labels, chunks, merge: bool = True, method: T_MethodCohor if merge: # First sort by number of chunks occupied by cohort sorted_chunks_cohorts = dict( - reversed(sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]))) + sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True) ) items = tuple(sorted_chunks_cohorts.items()) @@ -216,7 +216,13 @@ def find_group_cohorts(labels, chunks, merge: bool = True, method: T_MethodCohor merged_cohorts[k1].extend(v2) merged_keys.append(k2) - return merged_cohorts + # make sure each cohort is sorted after merging + sorted_merged_cohorts = {k: sorted(v) for k, v in merged_cohorts.items()} + # sort by first label in cohort + # This will help when sort=True (default) + # and we have to resort the dask array + return dict(sorted(sorted_merged_cohorts.items(), key=lambda kv: kv[1][0])) + else: return chunks_cohorts @@ -1347,7 +1353,6 @@ def dask_groupby_agg( reduced_ = [] groups_ = [] for blks, cohort in chunks_cohorts.items(): - cohort = sorted(cohort) subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :]) r, g = reduce_cohorts( subset, diff --git a/tests/test_core.py b/tests/test_core.py index a0b0537d8..c48460dc9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -646,11 +646,11 @@ def test_rechunk_for_blockwise(inchunks, expected): [[[1, 2, 3, 4]], [1, 2, 3, 1, 2, 3, 4], (3, 4), True], [[[1, 2, 3], [4]], [1, 2, 3, 1, 2, 3, 4], (3, 4), False], [[[1], [2], [3], [4]], [1, 2, 3, 1, 2, 3, 4], (2, 2, 2, 1), False], - [[[3], [2], [1], [4]], [1, 2, 3, 1, 2, 3, 4], (2, 2, 2, 1), True], + [[[1], [2], [3], [4]], [1, 2, 3, 1, 2, 3, 4], (2, 2, 2, 1), True], [[[1, 2, 3], [4]], [1, 2, 3, 1, 2, 3, 4], (3, 3, 1), True], [[[1, 2, 3], [4]], [1, 2, 3, 1, 2, 3, 4], (3, 3, 1), False], [ - [[2, 3, 4, 1], [5], [0]], + [[0], [1, 2, 3, 4], [5]], np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]), (4, 8, 4, 9, 4), True, @@ -776,11 +776,9 @@ def test_cohorts_nd_by(func, method, axis, engine): assert_equal(actual, expected) actual, groups = groupby_reduce(array, by, sort=False, **kwargs) - if method in ("split-reduce", "cohorts"): - assert_equal(groups, [4, 3, 40, 2, 31, 1, 30]) - elif method == "map-reduce": + if method == "map-reduce": assert_equal(groups, [1, 30, 2, 31, 3, 4, 40]) - elif method == "blockwise": + else: assert_equal(groups, [1, 30, 2, 31, 3, 40, 4]) reindexed = reindex_(actual, groups, pd.Index(sorted_groups)) assert_equal(reindexed, expected) From 67af7a5e154015f4997b336a334832c1398c600b Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 11 Oct 2022 15:09:59 -0600 Subject: [PATCH 11/14] Revert "Bring back split-reduce" This reverts commit 3754cfff470e7f36ef95389124538a87416ea859. Again, I don't see any benefits to this. --- flox/core.py | 60 ++++++++++++++++------------------------------------ 1 file changed, 18 insertions(+), 42 deletions(-) diff --git a/flox/core.py b/flox/core.py index f49d59c97..be759d358 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1130,37 +1130,6 @@ def subset_to_blocks( return subset -def reduce_cohorts(array, cohort, tree_reduce, do_simple_combine, reindex, agg, aggregate): - """Loop over multiple "cohorts" and apply the reduction. The loop is only - used for split-reduce.""" - import dask.array - - reduced = [] - groups = [] - for member in cohort: - member1d = np.atleast_1d(member) - if do_simple_combine: - # reindex so that reindex can be set to True later - reindexed = dask.array.map_blocks( - reindex_intermediates, - array, - agg=agg, - unique_groups=member1d, - meta=array._meta, - ) - else: - reindexed = array - - reduced.append( - tree_reduce( - reindexed, - aggregate=partial(aggregate, expected_groups=member1d, reindex=reindex), - ) - ) - groups.append(member1d) - return reduced, groups - - def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]: import dask.array from dask.highlevelgraph import HighLevelGraph @@ -1354,17 +1323,25 @@ def dask_groupby_agg( groups_ = [] for blks, cohort in chunks_cohorts.items(): subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :]) - r, g = reduce_cohorts( - subset, - (cohort,) if method == "cohorts" else cohort, - tree_reduce, - do_simple_combine, - reindex, - agg, - aggregate, + if do_simple_combine: + # reindex so that reindex can be set to True later + reindexed = dask.array.map_blocks( + reindex_intermediates, + subset, + agg=agg, + unique_groups=cohort, + meta=subset._meta, + ) + else: + reindexed = subset + + reduced_.append( + tree_reduce( + reindexed, + aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex), + ) ) - reduced_.extend(r) - groups_.extend(g) + groups_.append(cohort) reduced = dask.array.concatenate(reduced_, axis=-1) groups = (np.concatenate(groups_),) @@ -1779,7 +1756,6 @@ def groupby_reduce( if sort and method != "map-reduce": assert len(groups) == 1 - assert groups[0].ndim == 1 sorted_idx = np.argsort(groups[0]) result = result[..., sorted_idx] groups = (groups[0][sorted_idx],) From d2772671278de1b9131506d1fe07d0b991bcae1e Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 11 Oct 2022 15:19:08 -0600 Subject: [PATCH 12/14] Don't sort output unless necessary --- flox/core.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index be759d358..156f8b0f9 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1757,8 +1757,10 @@ def groupby_reduce( if sort and method != "map-reduce": assert len(groups) == 1 sorted_idx = np.argsort(groups[0]) - result = result[..., sorted_idx] - groups = (groups[0][sorted_idx],) + # This optimization helps specifically with resampling + if not (sorted_idx[1:] <= sorted_idx[:-1]).all(): + result = result[..., sorted_idx] + groups = (groups[0][sorted_idx],) if factorize_early: # nan group labels are factorized to -1, and preserved From 378cbe4ec484cba4789799da95aba1ef550cd1c4 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 11 Oct 2022 15:27:19 -0600 Subject: [PATCH 13/14] Remove split-reduce. --- docs/source/implementation.md | 2 +- flox/core.py | 6 ++---- flox/xarray.py | 4 +--- tests/test_core.py | 6 +----- 4 files changed, 5 insertions(+), 13 deletions(-) diff --git a/docs/source/implementation.md b/docs/source/implementation.md index e911386c5..ae0db3539 100644 --- a/docs/source/implementation.md +++ b/docs/source/implementation.md @@ -13,7 +13,7 @@ or `xarray_reduce`. First we describe xarray's current strategy -## `method="split-reduce"`: Xarray's current GroupBy strategy +## Background: Xarray's current GroupBy strategy Xarray's current strategy is to find all unique group labels, index out each group, and then apply the reduction operation. Note that this only works if we know the group diff --git a/flox/core.py b/flox/core.py index 156f8b0f9..2fb8a1d58 100644 --- a/flox/core.py +++ b/flox/core.py @@ -138,7 +138,7 @@ def _get_optimal_chunks_for_groups(chunks, labels): @memoize -def find_group_cohorts(labels, chunks, merge: bool = True, method: T_MethodCohorts = "cohorts"): +def find_group_cohorts(labels, chunks, merge: bool = True): """ Finds groups labels that occur together aka "cohorts" @@ -1570,9 +1570,7 @@ def groupby_reduce( method by first rechunking using ``rechunk_for_cohorts`` (for 1D ``by`` only). * ``"split-reduce"``: - Break out each group into its own array and then ``"map-reduce"``. - This is implemented by having each group be its own cohort, - and is identical to xarray's default strategy. + Same as "cohorts" and will be removed soon. engine : {"flox", "numpy", "numba"}, optional Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk: * ``"numpy"``: diff --git a/flox/xarray.py b/flox/xarray.py index 5f87bafe6..58dbec429 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -123,9 +123,7 @@ def xarray_reduce( method by first rechunking using ``rechunk_for_cohorts`` (for 1D ``by`` only). * ``"split-reduce"``: - Break out each group into its own array and then ``"map-reduce"``. - This is implemented by having each group be its own cohort, - and is identical to xarray's default strategy. + Same as "cohorts" and will be removed soon. engine : {"flox", "numpy", "numba"}, optional Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk: * ``"numpy"``: diff --git a/tests/test_core.py b/tests/test_core.py index c48460dc9..75fae09be 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -658,13 +658,9 @@ def test_rechunk_for_blockwise(inchunks, expected): ], ) def test_find_group_cohorts(expected, labels, chunks, merge): - actual = list(find_group_cohorts(labels, (chunks,), merge, method="cohorts").values()) + actual = list(find_group_cohorts(labels, (chunks,), merge).values()) assert actual == expected, (actual, expected) - # actual = find_group_cohorts(labels, (chunks,), merge, method="split-reduce") - # expected = [[label] for label in np.unique(labels)] - # assert actual == expected, (actual, expected) - @requires_dask @pytest.mark.parametrize( From 37240294e57b3944dbb327b98181822bf634f1f3 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 11 Oct 2022 15:36:34 -0600 Subject: [PATCH 14/14] Fix typo --- flox/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index 2fb8a1d58..7471fe968 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1317,7 +1317,7 @@ def dask_groupby_agg( elif method in ["cohorts", "split-reduce"]: chunks_cohorts = find_group_cohorts( - by_input, [array.chunks[ax] for ax in axis], merge=True, method=method + by_input, [array.chunks[ax] for ax in axis], merge=True ) reduced_ = [] groups_ = []