diff --git a/docs/source/implementation.md b/docs/source/implementation.md
index e911386c5..ae0db3539 100644
--- a/docs/source/implementation.md
+++ b/docs/source/implementation.md
@@ -13,7 +13,7 @@ or `xarray_reduce`.

 First we describe xarray's current strategy

-## `method="split-reduce"`: Xarray's current GroupBy strategy
+## Background: Xarray's current GroupBy strategy

 Xarray's current strategy is to find all unique group labels,
 index out each group, and then apply the reduction operation.
 Note that this only works if we know the group
diff --git a/flox/core.py b/flox/core.py
index c854ce953..344e2b822 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -137,7 +137,7 @@ def _get_optimal_chunks_for_groups(chunks, labels):

 @memoize
-def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "cohorts"):
+def find_group_cohorts(labels, chunks, merge: bool = True):
     """
     Finds groups labels that occur together aka "cohorts"
@@ -167,9 +167,6 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
     # To do this, we must have values in memory so casting to numpy should be safe
     labels = np.asarray(labels)

-    if method == "split-reduce":
-        return list(_get_expected_groups(labels, sort=False).to_numpy().reshape(-1, 1))
-
     # Build an array with the shape of labels, but where every element is the "chunk number"
     # 1. First subset the array appropriately
     axis = range(-labels.ndim, 0)
@@ -195,7 +192,7 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
     if merge:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
-            reversed(sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0])))
+            sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
         )
         items = tuple(sorted_chunks_cohorts.items())
@@ -218,9 +215,15 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
                     merged_cohorts[k1].extend(v2)
                     merged_keys.append(k2)

-    return merged_cohorts.values()
+        # make sure each cohort is sorted after merging
+        sorted_merged_cohorts = {k: sorted(v) for k, v in merged_cohorts.items()}
+        # sort by first label in cohort
+        # This will help when sort=True (default)
+        # and we have to resort the dask array
+        return dict(sorted(sorted_merged_cohorts.items(), key=lambda kv: kv[1][0]))
+
     else:
-        return chunks_cohorts.values()
+        return chunks_cohorts


 def rechunk_for_cohorts(
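With the `method` parameter gone, `find_group_cohorts` always detects cohorts from the label/chunk structure, and it now returns a dict keyed by tuples of flat chunk indices instead of a bare `.values()` view. A minimal sketch of the new return value, using made-up 1D labels (the printed dict is illustrative of this logic, not copied from the test suite):

```python
import numpy as np

from flox.core import find_group_cohorts

# labels [1, 2, 3, 1, 2, 3, 4] split into chunks of (3, 4):
# chunk 0 holds {1, 2, 3}, chunk 1 holds {1, 2, 3, 4}. Group 4 occupies a
# subset of the chunks occupied by groups 1, 2, 3, so merging collapses
# all four groups into a single cohort.
labels = np.array([1, 2, 3, 1, 2, 3, 4])
chunks_cohorts = find_group_cohorts(labels, ((3, 4),), merge=True)

print(chunks_cohorts)  # {(0, 1): [1, 2, 3, 4]}
```

The keys record which blocks each cohort touches; `dask_groupby_agg` uses them below to subset the intermediate array, and the updated `test_find_group_cohorts` at the bottom of this diff compares against exactly these `.values()`.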
@@ -1079,6 +1082,63 @@ def _reduce_blockwise(
     return result


+def subset_to_blocks(
+    array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
+) -> DaskArray:
+    """
+    Advanced indexing of .blocks such that we always get a regular array back.
+
+    Parameters
+    ----------
+    array : dask.array
+    flatblocks : flat indices of blocks to extract
+    blkshape : shape of blocks with which to unravel flatblocks
+
+    Returns
+    -------
+    dask.array
+    """
+    if blkshape is None:
+        blkshape = array.blocks.shape
+
+    unraveled = np.unravel_index(flatblocks, blkshape)
+    normalized: list[Union[int, np.ndarray, slice]] = []
+    for ax, idx in enumerate(unraveled):
+        i = np.unique(idx).squeeze()
+        if i.ndim == 0:
+            normalized.append(i.item())
+        else:
+            if np.array_equal(i, np.arange(blkshape[ax])):
+                normalized.append(slice(None))
+            elif np.array_equal(i, np.arange(i[0], i[-1] + 1)):
+                normalized.append(slice(i[0], i[-1] + 1))
+            else:
+                normalized.append(i)
+    full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized)
+
+    # has no iterables
+    noiter = tuple(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
+    # has all iterables
+    alliter = {
+        ax: i if hasattr(i, "__len__") else slice(None) for ax, i in enumerate(full_normalized)
+    }
+
+    # apply everything but the iterables
+    if all(i == slice(None) for i in noiter):
+        return array
+
+    subset = array.blocks[noiter]
+
+    for ax, inds in alliter.items():
+        if isinstance(inds, slice):
+            continue
+        idxr = [slice(None, None)] * array.ndim
+        idxr[ax] = inds
+        subset = subset.blocks[tuple(idxr)]
+
+    return subset
+
+
 def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
     import dask.array
     from dask.highlevelgraph import HighLevelGraph
@@ -1115,6 +1175,7 @@ def dask_groupby_agg(
     reindex: bool = False,
     engine: T_Engine = "numpy",
     sort: bool = True,
+    chunks_cohorts=None,
 ) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
     import dask.array
@@ -1194,7 +1255,7 @@ def dask_groupby_agg(
         partial(
             blockwise_method,
             axis=axis,
-            expected_groups=expected_groups,
+            expected_groups=None if method in ["split-reduce", "cohorts"] else expected_groups,
             engine=engine,
             sort=sort,
         ),
@@ -1223,43 +1284,77 @@ def dask_groupby_agg(
             expected_groups = _get_expected_groups(by_input, sort=sort)
         group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)

-    if method == "map-reduce":
+    if method in ["map-reduce", "cohorts", "split-reduce"]:
         combine: Callable[..., IntermediateDict]
         if do_simple_combine:
             combine = _simple_combine
         else:
             combine = partial(_grouped_combine, engine=engine, sort=sort)

-        # reduced is really a dict mapping reduction name to array
-        # and "groups" to an array of group labels
+        # Each chunk of `reduced` is really a dict mapping
+        # 1. reduction name to array
+        # 2. "groups" to an array of group labels
         # Note: it does not make sense to interpret axis relative to
         # shape of intermediate results after the blockwise call
-        reduced = dask.array.reductions._tree_reduce(
-            intermediate,
-            aggregate=partial(
-                _aggregate,
-                combine=combine,
-                agg=agg,
-                expected_groups=None if split_out > 1 else expected_groups,
-                fill_value=fill_value,
-                reindex=reindex,
-            ),
+        tree_reduce = partial(
+            dask.array.reductions._tree_reduce,
             combine=partial(combine, agg=agg),
-            name=f"{name}-reduce",
+            name=f"{name}-reduce-{method}",
             dtype=array.dtype,
             axis=axis,
             keepdims=True,
             concatenate=False,
         )
-
-        if is_duck_dask_array(by_input) and expected_groups is None:
-            groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
-        else:
-            if expected_groups is None:
-                expected_groups_ = _get_expected_groups(by_input, sort=sort)
+        aggregate = partial(
+            _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
+        )
+        if method == "map-reduce":
+            reduced = tree_reduce(
+                intermediate,
+                aggregate=partial(
+                    aggregate, expected_groups=None if split_out > 1 else expected_groups
+                ),
+            )
+            if is_duck_dask_array(by_input) and expected_groups is None:
+                groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
             else:
-                expected_groups_ = expected_groups
-            groups = (expected_groups_.to_numpy(),)
+                if expected_groups is None:
+                    expected_groups_ = _get_expected_groups(by_input, sort=sort)
+                else:
+                    expected_groups_ = expected_groups
+                groups = (expected_groups_.to_numpy(),)
+
+        elif method in ["cohorts", "split-reduce"]:
+            chunks_cohorts = find_group_cohorts(
+                by_input, [array.chunks[ax] for ax in axis], merge=True
+            )
+            reduced_ = []
+            groups_ = []
+            for blks, cohort in chunks_cohorts.items():
+                subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
+                if do_simple_combine:
+                    # reindex so that reindex can be set to True later
+                    reindexed = dask.array.map_blocks(
+                        reindex_intermediates,
+                        subset,
+                        agg=agg,
+                        unique_groups=cohort,
+                        meta=subset._meta,
+                    )
+                else:
+                    reindexed = subset
+
+                reduced_.append(
+                    tree_reduce(
+                        reindexed,
+                        aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
+                    )
+                )
+                groups_.append(cohort)
+
+            reduced = dask.array.concatenate(reduced_, axis=-1)
+            groups = (np.concatenate(groups_),)
+            group_chunks = (tuple(len(cohort) for cohort in groups_),)

     elif method == "blockwise":
         reduced = intermediate
@@ -1297,7 +1392,11 @@ def dask_groupby_agg(
             nblocks = tuple(len(array.chunks[ax]) for ax in axis)
             inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
         else:
-            inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1)
+            inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)
+            if split_out > 1:
+                inchunk = inchunk + (0,)
+            inchunk = inchunk + (ochunk[-1],)
+
         layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

     result = dask.array.Array(
@@ -1326,6 +1425,9 @@ def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_gro
     if method in ["split-reduce", "cohorts"] and reindex is False:
         raise NotImplementedError

+    if method in ["split-reduce", "cohorts"] and reindex is None:
+        reindex = True
+
     # TODO: Should reindex be a bool-only at this point? Would've been nice but
     # None's are relied on after this function as well.
     return reindex
@@ -1480,9 +1582,7 @@ def groupby_reduce(
          method by first rechunking using ``rechunk_for_cohorts``
          (for 1D ``by`` only).
        * ``"split-reduce"``:
-          Break out each group into its own array and then ``"map-reduce"``.
-          This is implemented by having each group be its own cohort,
-          and is identical to xarray's default strategy.
+          Same as ``"cohorts"``; this alias will be removed soon.
     engine : {"flox", "numpy", "numba"}, optional
        Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
        * ``"numpy"``:
@@ -1652,67 +1752,26 @@ def groupby_reduce(

         partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs)

-        if method in ["split-reduce", "cohorts"]:
-            cohorts = find_group_cohorts(
-                by_, [array.chunks[ax] for ax in axis_], merge=True, method=method
-            )
-
-            results_ = []
-            groups_ = []
-            for cohort in cohorts:
-                cohort = sorted(cohort)
-                # equivalent of xarray.DataArray.where(mask, drop=True)
-                mask = np.isin(by_, cohort)
-                indexer = [np.unique(v) for v in np.nonzero(mask)]
-                array_subset = array
-                for ax, idxr in zip(range(-by_.ndim, 0), indexer):
-                    array_subset = np.take(array_subset, idxr, axis=ax)
-                numblocks = math.prod([len(array_subset.chunks[ax]) for ax in axis_])
-
-                # get final result for these groups
-                r, *g = partial_agg(
-                    array_subset,
-                    by_[np.ix_(*indexer)],
-                    expected_groups=pd.Index(cohort),
-                    # First deep copy becasue we might be doping blockwise,
-                    # which sets agg.finalize=None, then map-reduce (GH102)
-                    agg=copy.deepcopy(agg),
-                    # reindex to expected_groups at the blockwise step.
-                    # this approach avoids replacing non-cohort members with
-                    # np.nan or some other sentinel value, and preserves dtypes
-                    reindex=True,
-                    # sort controls the final output order so apply that at the end
-                    sort=False,
-                    # if only a single block along axis, we can just work blockwise
-                    # inspired by https://github.com/dask/dask/issues/8361
-                    method="blockwise" if numblocks == 1 and nax == by_.ndim else "map-reduce",
-                )
-                results_.append(r)
-                groups_.append(cohort)
+        if method == "blockwise" and by_.ndim == 1:
+            array = rechunk_for_blockwise(array, axis=-1, labels=by_)

-            # concatenate results together,
-            # sort to make sure we match expected output
-            groups = (np.hstack(groups_),)
-            result = np.concatenate(results_, axis=-1)
-        else:
-            if method == "blockwise" and by_.ndim == 1:
-                array = rechunk_for_blockwise(array, axis=-1, labels=by_)
-
-            result, groups = partial_agg(
-                array,
-                by_,
-                expected_groups=None if method == "blockwise" else expected_groups,
-                agg=agg,
-                reindex=reindex,
-                method=method,
-                sort=sort,
-            )
+        result, groups = partial_agg(
+            array,
+            by_,
+            expected_groups=None if method == "blockwise" else expected_groups,
+            agg=agg,
+            reindex=reindex,
+            method=method,
+            sort=sort,
+        )

         if sort and method != "map-reduce":
             assert len(groups) == 1
             sorted_idx = np.argsort(groups[0])
-            result = result[..., sorted_idx]
-            groups = (groups[0][sorted_idx],)
+            # This optimization helps specifically with resampling
+            if not (sorted_idx[:-1] <= sorted_idx[1:]).all():
+                result = result[..., sorted_idx]
+                groups = (groups[0][sorted_idx],)

     if factorize_early:
         # nan group labels are factorized to -1, and preserved
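The already-sorted check above, in isolation: `np.argsort` of a sorted `groups` array is the identity permutation, so the out-of-order fancy-indexing of the dask result (and the extra graph layer it implies) can be skipped. That is the common case for resampling, where group labels come out of the cohorts in increasing order. A quick numpy illustration:

```python
import numpy as np

groups = np.array([1, 30, 2, 31])  # out of order: reindexing is needed
sorted_idx = np.argsort(groups)    # [0, 2, 1, 3]
print((sorted_idx[:-1] <= sorted_idx[1:]).all())  # False -> apply sorted_idx

groups = np.array([1, 2, 30, 31])  # already sorted, e.g. resampled time bins
sorted_idx = np.argsort(groups)    # identity: [0, 1, 2, 3]
print((sorted_idx[:-1] <= sorted_idx[1:]).all())  # True -> skip the getitem
```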
-    )
+    ).values()
     print("finished cohorts...")

     xticks = np.cumsum(array.chunks[-1])
diff --git a/flox/xarray.py b/flox/xarray.py
index 3b8ec96e8..faebc468e 100644
--- a/flox/xarray.py
+++ b/flox/xarray.py
@@ -126,9 +126,7 @@ def xarray_reduce(
          method by first rechunking using ``rechunk_for_cohorts``
          (for 1D ``by`` only).
        * ``"split-reduce"``:
-          Break out each group into its own array and then ``"map-reduce"``.
-          This is implemented by having each group be its own cohort,
-          and is identical to xarray's default strategy.
+          Same as ``"cohorts"``; this alias will be removed soon.
     engine : {"flox", "numpy", "numba"}, optional
        Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
        * ``"numpy"``:
diff --git a/tests/test_core.py b/tests/test_core.py
index 05deae2fd..53a71f808 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -669,11 +669,11 @@ def test_rechunk_for_blockwise(inchunks, expected):
         [[[1, 2, 3, 4]], [1, 2, 3, 1, 2, 3, 4], (3, 4), True],
         [[[1, 2, 3], [4]], [1, 2, 3, 1, 2, 3, 4], (3, 4), False],
         [[[1], [2], [3], [4]], [1, 2, 3, 1, 2, 3, 4], (2, 2, 2, 1), False],
-        [[[3], [2], [1], [4]], [1, 2, 3, 1, 2, 3, 4], (2, 2, 2, 1), True],
+        [[[1], [2], [3], [4]], [1, 2, 3, 1, 2, 3, 4], (2, 2, 2, 1), True],
         [[[1, 2, 3], [4]], [1, 2, 3, 1, 2, 3, 4], (3, 3, 1), True],
         [[[1, 2, 3], [4]], [1, 2, 3, 1, 2, 3, 4], (3, 3, 1), False],
         [
-            [[2, 3, 4, 1], [5], [0]],
+            [[0], [1, 2, 3, 4], [5]],
             np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]),
             (4, 8, 4, 9, 4),
             True,
         ],
     ],
 )
 def test_find_group_cohorts(expected, labels, chunks, merge):
-    actual = list(find_group_cohorts(labels, (chunks,), merge, method="cohorts"))
-    assert actual == expected, (actual, expected)
-
-    actual = find_group_cohorts(labels, (chunks,), merge, method="split-reduce")
-    expected = [[label] for label in np.unique(labels)]
+    actual = list(find_group_cohorts(labels, (chunks,), merge).values())
     assert actual == expected, (actual, expected)
@@ -799,11 +795,9 @@ def test_cohorts_nd_by(func, method, axis, engine):
     assert_equal(actual, expected)

     actual, groups = groupby_reduce(array, by, sort=False, **kwargs)
-    if method == "cohorts":
-        assert_equal(groups, [4, 3, 40, 2, 31, 1, 30])
-    elif method in ("split-reduce", "map-reduce"):
+    if method == "map-reduce":
         assert_equal(groups, [1, 30, 2, 31, 3, 4, 40])
-    elif method == "blockwise":
+    else:
         assert_equal(groups, [1, 30, 2, 31, 3, 40, 4])
     reindexed = reindex_(actual, groups, pd.Index(sorted_groups))
     assert_equal(reindexed, expected)
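Finally, an end-to-end sketch of the new cohorts path through the public API, with made-up data (the commented outputs are what the logic above should produce, shown here as an illustration):

```python
import dask.array as da
import numpy as np

from flox import groupby_reduce

by = np.array([0, 0, 1, 1, 2, 2, 3, 3])
array = da.arange(8, chunks=4)

# One tree reduction per cohort, over only the blocks that cohort occupies;
# the per-cohort results are concatenated along the last axis. Because
# merged cohorts are returned sorted by their first label, `groups` usually
# comes out already sorted and the final argsort-based reindex is skipped.
result, groups = groupby_reduce(array, by, func="sum", method="cohorts")
print(groups)            # [0 1 2 3]
print(result.compute())  # [ 1  5  9 13]
```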