Support cohorts for nD by arrays #55


Merged: 11 commits, Dec 28, 2021
176 changes: 105 additions & 71 deletions flox/core.py
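
For context, a minimal usage sketch of what this PR is meant to enable: reducing a dask array over a multidimensional ``by`` with ``method="cohorts"``. The array shape, chunking, and labels below are illustrative assumptions, not taken from the PR.

import dask.array as da
import numpy as np
from flox.core import groupby_reduce

# hypothetical data: reduce a 3D dask array over its trailing two axes
array = da.random.random((10, 4, 6), chunks=(5, 2, 3))
# 2D group labels aligned with array.shape[-2:]
by = np.arange(4 * 6).reshape(4, 6) % 3

# "cohorts" finds labels that tend to share chunks and reduces each cohort separately
result, groups = groupby_reduce(array, by, "sum", method="cohorts")
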
@@ -8,7 +8,6 @@
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
@@ -131,9 +130,9 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
Parameters
----------
labels: np.ndarray
1D Array of group labels
chunks: tuple
chunks along grouping dimension for array that is being reduced
mD Array of group labels
array: tuple
nD array that is being reduced
merge: bool, optional
Attempt to merge cohorts when one cohort's chunks are a subset
of another cohort's chunks.
@@ -147,19 +146,35 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
"""
import copy

import dask
import toolz as tlz

if method == "split-reduce":
return np.unique(labels).reshape(-1, 1).tolist()

which_chunk = np.repeat(np.arange(len(chunks)), chunks)
# To do this, we must have values in memory so casting to numpy should be safe
labels = np.asarray(labels)

# Build an array with the shape of labels, but where every element is the "chunk number"
# 1. First subset the array appropriately
axis = range(-labels.ndim, 0)
# Easier to create a dask array and use the .blocks property
array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)

# Iterate over each block and create a new block of same shape with "chunk number"
shape = tuple(array.blocks.shape[ax] for ax in axis)
blocks = np.empty(np.prod(shape), dtype=object)
for idx, block in enumerate(array.blocks.ravel()):
blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
which_chunk = np.block(blocks.reshape(shape).tolist()).ravel()
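# Illustration (assumed shapes, not from this PR): for a (4, 4) ``labels`` array
# with chunks=((2, 2), (2, 2)), ``array.blocks`` has shape (2, 2), the four blocks
# are numbered 0..3 in C order, and ``which_chunk`` ravels to
# [0 0 1 1 0 0 1 1 2 2 3 3 2 2 3 3], i.e. the block number of every element.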

# these are chunks where a label is present
label_chunks = {lab: tuple(np.unique(which_chunk[labels == lab])) for lab in np.unique(labels)}
label_chunks = {
lab: tuple(np.unique(which_chunk[labels.ravel() == lab])) for lab in np.unique(labels)
}
# This inverts the label_chunks mapping so we know which labels occur together.
chunks_cohorts = tlz.groupby(label_chunks.get, label_chunks.keys())
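# e.g. (hypothetical values) if labels 1 and 2 each occur only in blocks 0 and 3,
# they end up in the same cohort: chunks_cohorts == {(0, 3): [1, 2], ...}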

# TODO: sort by length of values (i.e. cohort);
# then loop in reverse and merge when keys are subsets of initial keys?
if merge:
# First sort by number of chunks occupied by cohort
sorted_chunks_cohorts = dict(
@@ -299,7 +314,6 @@ def reindex_(array: np.ndarray, from_, to, fill_value=None, axis: int = -1) -> n
reindexed = np.full(array.shape[:-1] + (len(to),), fill_value, dtype=array.dtype)
return reindexed

from_ = np.atleast_1d(from_)
if from_.dtype.kind == "O" and isinstance(from_[0], tuple):
raise NotImplementedError(
"Currently does not support reindexing with object arrays of tuples. "
@@ -706,14 +720,14 @@ def _npg_aggregate(
expected_groups: Union[Sequence, np.ndarray, None],
axis: Sequence,
keepdims,
group_ndim: int,
neg_axis: Sequence,
fill_value: Any = None,
min_count: Optional[int] = None,
engine: str = "numpy",
finalize_kwargs: Optional[Mapping] = None,
) -> FinalResultsDict:
"""Final aggregation step of tree reduction"""
results = _npg_combine(x_chunk, agg, axis, keepdims, group_ndim, engine)
results = _npg_combine(x_chunk, agg, axis, keepdims, neg_axis, engine)
return _finalize_results(
results, agg, axis, expected_groups, fill_value, min_count, finalize_kwargs
)
@@ -742,7 +756,7 @@ def _npg_combine(
agg: Aggregation,
axis: Sequence,
keepdims: bool,
group_ndim: int,
neg_axis: Sequence,
engine: str,
) -> IntermediateDict:
"""Combine intermediates step of tree reduction."""
@@ -771,12 +785,7 @@ def reindex_intermediates(x):

x_chunk = deepmap(reindex_intermediates, x_chunk)

group_conc_axis: Iterable[int]
if group_ndim == 1:
group_conc_axis = (0,)
else:
group_conc_axis = sorted(group_ndim - ax - 1 for ax in axis)
groups = _conc2(x_chunk, "groups", axis=group_conc_axis)
groups = _conc2(x_chunk, "groups", axis=neg_axis)

if agg.reduction_type == "argreduce":
# If "nanlen" was added for masking later, we need to account for that
@@ -830,7 +839,7 @@ def reindex_intermediates(x):
np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=agg.dtype)
)
results["groups"] = np.empty(
shape=(1,) * (len(group_conc_axis) - 1) + (0,), dtype=groups.dtype
shape=(1,) * (len(neg_axis) - 1) + (0,), dtype=groups.dtype
)
else:
_results = chunk_reduce(
@@ -891,6 +900,7 @@ def groupby_agg(
method: str = "map-reduce",
min_count: Optional[int] = None,
isbin: bool = False,
reindex: bool = False,
engine: str = "numpy",
finalize_kwargs: Optional[Mapping] = None,
) -> Tuple["DaskArray", Union[np.ndarray, "DaskArray"]]:
@@ -902,6 +912,9 @@
assert isinstance(axis, Sequence)
assert all(ax >= 0 for ax in axis)

# these are negative axis indices useful for concatenating the intermediates
neg_axis = range(-len(axis), 0)
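# e.g. (illustrative) for axis=(1, 2), neg_axis == range(-2, 0), i.e. axes -2 and -1;
# the per-block "groups" intermediates are concatenated along these trailing axes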

inds = tuple(range(array.ndim))
name = f"groupby_{agg.name}"
token = dask.base.tokenize(array, by, agg, expected_groups, axis, split_out)
@@ -926,11 +939,11 @@
axis=axis,
# with the current implementation we want reindexing at the blockwise step
# only reindex to groups present at combine stage
expected_groups=expected_groups if split_out > 1 or isbin else None,
expected_groups=expected_groups if reindex or split_out > 1 or isbin else None,
fill_value=agg.fill_value["intermediate"],
dtype=agg.dtype["intermediate"],
isbin=isbin,
reindex=split_out > 1,
reindex=reindex or (split_out > 1),
engine=engine,
),
inds,
@@ -964,7 +977,7 @@
expected_agg = expected_groups

agg_kwargs = dict(
group_ndim=by.ndim,
neg_axis=neg_axis,
fill_value=fill_value,
min_count=min_count,
engine=engine,
@@ -984,7 +997,7 @@
expected_groups=expected_agg,
**agg_kwargs,
),
combine=partial(_npg_combine, agg=agg, group_ndim=by.ndim, engine=engine),
combine=partial(_npg_combine, agg=agg, neg_axis=neg_axis, engine=engine),
name=f"{name}-reduce",
dtype=array.dtype,
axis=axis,
@@ -996,12 +1009,7 @@
# Blockwise apply the aggregation step so that one input chunk → one output chunk
# TODO: We could combine this with the chunk reduction and do everything in one task.
# This would also optimize the single block along reduced-axis case.
if (
expected_groups is None
or split_out > 1
or len(axis) > 1
or not isinstance(by_maybe_numpy, np.ndarray)
):
if expected_groups is None or split_out > 1 or not isinstance(by_maybe_numpy, np.ndarray):
raise NotImplementedError

reduced = dask.array.blockwise(
@@ -1020,17 +1028,25 @@
dtype=array.dtype,
meta=array._meta,
align_arrays=False,
name=f"{name}-blockwise-agg-{token}",
name=f"{name}-blockwise-{token}",
)
chunks = array.chunks[axis[0]]

# find number of groups in each chunk, this is needed for output chunks
# along the reduced axis
bnds = np.insert(np.cumsum(chunks), 0, 0)
groups_per_chunk = tuple(
len(np.unique(by_maybe_numpy[i0:i1])) for i0, i1 in zip(bnds[:-1], bnds[1:])
)
output_chunks = reduced.chunks[: -(len(axis))] + (groups_per_chunk,)
from dask.array.core import slices_from_chunks

slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
if expected_groups is None:
groups_in_block = tuple(np.unique(by_maybe_numpy[slc]) for slc in slices)
else:
# For cohorts, we could be indexing a block with groups that
# are not in the cohort (usually for nD `by`)
# Only keep the expected groups.
groups_in_block = tuple(
np.intersect1d(by_maybe_numpy[slc], expected_groups) for slc in slices
)
ngroups_per_block = tuple(len(groups) for groups in groups_in_block)
output_chunks = reduced.chunks[: -(len(axis))] + (ngroups_per_block,)
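# Illustration (assumed values): with by_maybe_numpy == [[0, 0, 1, 1]] chunked as
# ((1,), (2, 2)), ``slices`` has one entry per block, groups_in_block is
# (array([0]), array([1])) and ngroups_per_block == (1, 1): the reduced axes are
# replaced by one output chunk per block, sized by that block's group count.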
else:
raise ValueError(f"Unknown method={method}.")

@@ -1059,15 +1075,22 @@ def _getitem(d, key1, key2):
),
)
else:
groups = (expected_groups,)
if method == "map-reduce":
groups = (expected_groups,)
else:
groups = (np.concatenate(groups_in_block),)

layer: Dict[Tuple, Tuple] = {} # type: ignore
agg_name = f"{name}-{token}"
for ochunk in itertools.product(*ochunks):
if method == "blockwise":
inchunk = ochunk
if len(axis) == 1:
inchunk = ochunk
else:
nblocks = tuple(len(array.chunks[ax]) for ax in axis)
inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
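# e.g. (illustrative) with axis=(1, 2) and nblocks == (2, 3), output block index
# ochunk[-1] == 4 maps back to input block indices (1, 1) via np.unravel_index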
else:
inchunk = ochunk[:-1] + (0,) * (len(axis)) + (ochunk[-1],) * int(split_out > 1)
inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1)
layer[(agg_name, *ochunk)] = (
operator.getitem,
(reduced.name, *inchunk),
@@ -1089,6 +1112,7 @@ def groupby_reduce(
func: Union[str, Aggregation],
*,
expected_groups: Union[Sequence, np.ndarray] = None,
sort: bool = True,
isbin: bool = False,
axis=None,
fill_value=None,
@@ -1114,6 +1138,10 @@
Expected unique labels.
isbin : bool, optional
Are ``expected_groups`` bin edges?
sort : bool, optional
Whether groups should be returned in sorted order. Only applies for dask
reductions when ``method`` is not ``"map-reduce"``. For ``"map-reduce"``, the
groups are always sorted.
axis : (optional) None or int or Sequence[int]
If None, reduce across all dimensions of by
Else, reduce across corresponding axes of array
@@ -1138,17 +1166,19 @@
* ``"blockwise"``:
Only reduce using blockwise and avoid aggregating blocks
together. Useful for resampling-style reductions where group
members are always together. The array is rechunked so that
chunk boundaries line up with group boundaries
members are always together. If `by` is 1D, `array` is automatically
rechunked so that chunk boundaries line up with group boundaries
i.e. each block contains all members of any group present
in that block.
in that block. For nD `by`, you must make sure that all members of a group
are present in a single block.
* ``"cohorts"``:
Finds group labels that tend to occur together ("cohorts"),
indexes out cohorts and reduces that subset using "map-reduce",
repeat for all cohorts. This works well for many time groupings
where the group labels repeat at regular intervals like 'hour',
'month', 'dayofyear' etc. Optimize chunking ``array`` for this
method by first rechunking using ``rechunk_for_cohorts``.
method by first rechunking using ``rechunk_for_cohorts``
(for 1D ``by`` only).
* ``"split-reduce"``:
Break out each group into its own array and then ``"map-reduce"``.
This is implemented by having each group be its own cohort,
@@ -1208,6 +1238,11 @@ def groupby_reduce(
else:
axis = np.core.numeric.normalize_axis_tuple(axis, array.ndim) # type: ignore

if method in ["blockwise", "cohorts", "split-reduce"] and len(axis) != by.ndim:
raise NotImplementedError(
"Must reduce along all dimensions of `by` when method != 'map-reduce'."
)

if expected_groups is None and isinstance(by, np.ndarray):
flatby = by.ravel()
expected_groups = np.unique(flatby[~isnull(flatby)])
@@ -1366,59 +1401,58 @@
)

if method in ["split-reduce", "cohorts"]:
if by.ndim > 1:
raise ValueError(
"`by` must be 1D when method='split-reduce' and method='cohorts'. "
f"Received {by.ndim}D array. Please use method='map-reduce' instead."
)
assert axis == (array.ndim - 1,)

cohorts = find_group_cohorts(by, array.chunks[axis[0]], merge=True, method=method)
idx = np.arange(len(by))
cohorts = find_group_cohorts(
by, [array.chunks[ax] for ax in axis], merge=True, method=method
)

results = []
groups_ = []
for cohort in cohorts:
cohort = sorted(cohort)
# indexes for a subset of groups
subset_idx = idx[np.isin(by, cohort)]
array_subset = array[..., subset_idx]
numblocks = len(array_subset.chunks[-1])
# equivalent of xarray.DataArray.where(mask, drop=True)
mask = np.isin(by, cohort)
indexer = [np.unique(v) for v in np.nonzero(mask)]
array_subset = array
for ax, idxr in zip(range(-by.ndim, 0), indexer):
array_subset = np.take(array_subset, idxr, axis=ax)
numblocks = np.prod([len(array_subset.chunks[ax]) for ax in axis])
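# Illustration (assumed values): if mask is True only in rows {0, 2} and columns
# {1, 3}, then indexer == [array([0, 2]), array([1, 3])] and the np.take calls
# pull out that 2x2 sub-block, mimicking xarray's .where(mask, drop=True)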

# get final result for these groups
r, *g = partial_agg(
array_subset,
by[subset_idx],
by[np.ix_(*indexer)],
expected_groups=cohort,
# reindex to expected_groups at the blockwise step.
# this approach avoids replacing non-cohort members with
# np.nan or some other sentinel value, and preserves dtypes
reindex=True,
# if only a single block along axis, we can just work blockwise
# inspired by https://github.com/dask/dask/issues/8361
method="blockwise" if numblocks == 1 else "map-reduce",
method="blockwise" if numblocks == 1 and len(axis) == by.ndim else "map-reduce",
)
results.append(r)
groups_.append(cohort)

# concatenate results together,
# sort to make sure we match expected output
allgroups = np.hstack(groups_)
sorted_idx = np.argsort(allgroups)
result = np.concatenate(results, axis=-1)[..., sorted_idx]
groups = (allgroups[sorted_idx],)

groups = (np.hstack(groups_),)
result = np.concatenate(results, axis=-1)
else:
if method == "blockwise":
if by.ndim > 1:
raise ValueError(
"For method='blockwise', ``by`` must be 1D. "
f"Received {by.ndim} dimensions instead."
)
array = rechunk_for_blockwise(array, axis=-1, labels=by)

# TODO: test with mixed array kinds (numpy + dask; dask + numpy)
if by.ndim == 1:
array = rechunk_for_blockwise(array, axis=-1, labels=by)

# TODO: test with mixed array kinds (numpy array + dask by)
result, *groups = partial_agg(
array,
by,
expected_groups=expected_groups,
method=method,
)
if sort and method != "map-reduce":
assert len(groups) == 1
sorted_idx = np.argsort(groups[0])
result = result[..., sorted_idx]
groups = (groups[0][sorted_idx],)
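# e.g. (illustrative) "cohorts" may return groups in the order [3, 1, 2]; with
# sort=True the last axis of ``result`` is reordered to match groups == [1, 2, 3]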

return (result, *groups)