diff --git a/flox/cache.py b/flox/cache.py
index eaac3f360..4f8de8b59 100644
--- a/flox/cache.py
+++ b/flox/cache.py
@@ -8,4 +8,4 @@
     cache = cachey.Cache(1e6)
     memoize = partial(cache.memoize, key=dask.base.tokenize)
 except ImportError:
-    memoize = lambda x: x
+    memoize = lambda x: x  # type: ignore
diff --git a/flox/core.py b/flox/core.py
index 943fd029e..58b89bf17 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -5,7 +5,16 @@
 import operator
 from collections import namedtuple
 from functools import partial, reduce
-from typing import TYPE_CHECKING, Any, Callable, Dict, Mapping, Sequence, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Mapping,
+    Sequence,
+    Union,
+)
 
 import numpy as np
 import numpy_groupies as npg
@@ -1282,8 +1291,8 @@ def _assert_by_is_aligned(shape, by):
 
 
 def _convert_expected_groups_to_index(
-    expected_groups: tuple, isbin: bool, sort: bool
-) -> pd.Index | None:
+    expected_groups: Iterable, isbin: Sequence[bool], sort: bool
+) -> tuple[pd.Index | None]:
     out = []
     for ex, isbin_ in zip(expected_groups, isbin):
         if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin):
diff --git a/flox/xarray.py b/flox/xarray.py
index 29b023a0e..c02959485 100644
--- a/flox/xarray.py
+++ b/flox/xarray.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Hashable, Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Hashable, Iterable, Sequence, Union
 
 import numpy as np
 import pandas as pd
@@ -19,7 +19,10 @@
 from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric
 
 if TYPE_CHECKING:
-    from xarray import DataArray, Dataset, Resample
+    from xarray.core.resample import Resample
+    from xarray.core.types import T_DataArray, T_Dataset
+
+    Dims = Union[str, Iterable[Hashable], None]
 
 
 def _get_input_core_dims(group_names, dim, ds, grouper_dims):
@@ -51,13 +54,13 @@ def lookup_order(dimension):
 
 
 def xarray_reduce(
-    obj: Dataset | DataArray,
-    *by: DataArray | Iterable[str] | Iterable[DataArray],
+    obj: T_Dataset | T_DataArray,
+    *by: T_DataArray | Hashable,
     func: str | Aggregation,
     expected_groups=None,
     isbin: bool | Sequence[bool] = False,
     sort: bool = True,
-    dim: Hashable = None,
+    dim: Dims | ellipsis = None,
     split_out: int = 1,
     fill_value=None,
     method: str = "map-reduce",
@@ -203,8 +206,11 @@ def xarray_reduce(
     if keep_attrs is None:
         keep_attrs = True
 
-    if isinstance(isbin, bool):
-        isbin = (isbin,) * nby
+    if isinstance(isbin, Sequence):
+        isbins = isbin
+    else:
+        isbins = (isbin,) * nby
+
     if expected_groups is None:
         expected_groups = (None,) * nby
     if isinstance(expected_groups, (np.ndarray, list)):  # TODO: test for list
@@ -217,78 +223,86 @@ def xarray_reduce(
         raise NotImplementedError
 
     # eventually drop the variables we are grouping by
-    maybe_drop = [b for b in by if isinstance(b, str)]
+    maybe_drop = [b for b in by if isinstance(b, Hashable)]
     unindexed_dims = tuple(
         b
-        for b, isbin_ in zip(by, isbin)
-        if isinstance(b, str) and not isbin_ and b in obj.dims and b not in obj.indexes
+        for b, isbin_ in zip(by, isbins)
+        if isinstance(b, Hashable) and not isbin_ and b in obj.dims and b not in obj.indexes
     )
 
-    by: tuple[DataArray] = tuple(obj[g] if isinstance(g, str) else g for g in by)  # type: ignore
+    by_da = tuple(obj[g] if isinstance(g, Hashable) else g for g in by)
 
     grouper_dims = []
-    for g in by:
+    for g in by_da:
         for d in g.dims:
             if d not in grouper_dims:
                 grouper_dims.append(d)
 
-    if isinstance(obj, xr.DataArray):
-        ds = obj._to_temp_dataset()
-    else:
+    if isinstance(obj, xr.Dataset):
         ds = obj
+    else:
+        ds = obj._to_temp_dataset()
 
     ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])
 
     if dim is Ellipsis:
         if nby > 1:
             raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.")
-        dim = tuple(obj.dims)
-        if by[0].name in ds.dims and not isbin[0]:
-            dim = tuple(d for d in dim if d != by[0].name)
+        name_ = by_da[0].name
+        if name_ in ds.dims and not isbins[0]:
+            dim_tuple = tuple(d for d in obj.dims if d != name_)
+        else:
+            dim_tuple = tuple(obj.dims)
     elif dim is not None:
-        dim = _atleast_1d(dim)
+        dim_tuple = _atleast_1d(dim)
     else:
-        dim = tuple()
+        dim_tuple = tuple()
 
     # broadcast all variables against each other along all dimensions in `by` variables
     # don't exclude `dim` because it need not be a dimension in any of the `by` variables!
     # in the case where dim is Ellipsis, and by.ndim < obj.ndim
     # then we also broadcast `by` to all `obj.dims`
     # TODO: avoid this broadcasting
-    exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim)
-    ds, *by = xr.broadcast(ds, *by, exclude=exclude_dims)
+    exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)
+    ds_broad, *by_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)
 
-    if not dim:
-        dim = tuple(by[0].dims)
+    # all members of by_broad have the same dimensions
+    # so we just pull by_broad[0].dims if dim is None
+    if not dim_tuple:
+        dim_tuple = tuple(by_broad[0].dims)
 
-    if any(d not in grouper_dims and d not in obj.dims for d in dim):
+    if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
         raise ValueError(f"Cannot reduce over absent dimensions {dim}.")
 
-    dims_not_in_groupers = tuple(d for d in dim if d not in grouper_dims)
-    if dims_not_in_groupers == tuple(dim) and not any(isbin):
+    dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims)
+    if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins):
         # reducing along a dimension along which groups do not vary
         # This is really just a normal reduction.
         # This is not right when binning so we exclude.
-        if skipna and isinstance(func, str):
-            dsfunc = func[3:]
+        if isinstance(func, str):
+            dsfunc = func[3:] if skipna else func
         else:
-            dsfunc = func
+            raise NotImplementedError(
+                "func must be a string when reducing along a dimension not present in `by`"
+            )
         # TODO: skipna needs test
-        result = getattr(ds, dsfunc)(dim=dim, skipna=skipna)
+        result = getattr(ds_broad, dsfunc)(dim=dim_tuple, skipna=skipna)
         if isinstance(obj, xr.DataArray):
             return obj._from_temp_dataset(result)
         else:
             return result
 
-    axis = tuple(range(-len(dim), 0))
-    group_names = tuple(g.name if not binned else f"{g.name}_bins" for g, binned in zip(by, isbin))
-
-    group_shape = [None] * len(by)
-    expected_groups = list(expected_groups)
+    axis = tuple(range(-len(dim_tuple), 0))
 
     # Set expected_groups and convert to index since we need coords, sizes
     # for output xarray objects
-    for idx, (b, expect, isbin_) in enumerate(zip(by, expected_groups, isbin)):
+    expected_groups = list(expected_groups)
+    group_names: tuple[Any, ...] = ()
+    group_sizes: dict[Any, int] = {}
+    for idx, (b_, expect, isbin_) in enumerate(zip(by_broad, expected_groups, isbins)):
+        group_name = b_.name if not isbin_ else f"{b_.name}_bins"
+        group_names += (group_name,)
+
         if isbin_ and isinstance(expect, int):
             raise NotImplementedError(
                 "flox does not support binning into an integer number of bins yet."
@@ -297,13 +311,21 @@ def xarray_reduce(
             if isbin_:
                 raise ValueError(
                     f"Please provided bin edges for group variable {idx} "
-                    f"named {group_names[idx]} in expected_groups."
+                    f"named {group_name} in expected_groups."
                 )
-            expected_groups[idx] = _get_expected_groups(b.data, sort=sort, raise_if_dask=True)
-
-    expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort=sort)
-    group_shape = tuple(len(e) for e in expected_groups)
-    group_sizes = dict(zip(group_names, group_shape))
+            expect_ = _get_expected_groups(b_.data, sort=sort, raise_if_dask=True)
+        else:
+            expect_ = expect
+        expect_index = _convert_expected_groups_to_index((expect_,), (isbin_,), sort=sort)[0]
+
+        # The if-check is for type hinting mainly, it narrows down the return
+        # type of _convert_expected_groups_to_index to pure pd.Index:
+        if expect_index is not None:
+            expected_groups[idx] = expect_index
+            group_sizes[group_name] = len(expect_index)
+        else:
+            # This will never be reached
+            raise ValueError("expect_index cannot be None")
 
     def wrapper(array, *by, func, skipna, **kwargs):
         # Handle skipna here because I need to know dtype to make a good default choice.
@@ -349,20 +371,20 @@ def wrapper(array, *by, func, skipna, **kwargs):
     if isinstance(obj, xr.Dataset):
         # broadcasting means the group dim gets added to ds, so we check the original obj
         for k, v in obj.data_vars.items():
-            is_missing_dim = not (any(d in v.dims for d in dim))
+            is_missing_dim = not (any(d in v.dims for d in dim_tuple))
             if is_missing_dim:
                 missing_dim[k] = v
 
-    input_core_dims = _get_input_core_dims(group_names, dim, ds, grouper_dims)
+    input_core_dims = _get_input_core_dims(group_names, dim_tuple, ds_broad, grouper_dims)
     input_core_dims += [input_core_dims[-1]] * (nby - 1)
 
     actual = xr.apply_ufunc(
         wrapper,
-        ds.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
-        *by,
+        ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
+        *by_broad,
         input_core_dims=input_core_dims,
         # for xarray's test_groupby_duplicate_coordinate_labels
-        exclude_dims=set(dim),
+        exclude_dims=set(dim_tuple),
         output_core_dims=[group_names],
         dask="allowed",
         dask_gufunc_kwargs=dict(output_sizes=group_sizes),
@@ -379,18 +401,18 @@ def wrapper(array, *by, func, skipna, **kwargs):
             "engine": engine,
             "reindex": reindex,
             "expected_groups": tuple(expected_groups),
-            "isbin": isbin,
+            "isbin": isbins,
             "finalize_kwargs": finalize_kwargs,
         },
     )
 
     # restore non-dim coord variables without the core dimension
     # TODO: shouldn't apply_ufunc handle this?
-    for var in set(ds.variables) - set(ds.dims):
-        if all(d not in ds[var].dims for d in dim):
-            actual[var] = ds[var]
+    for var in set(ds_broad.variables) - set(ds_broad.dims):
+        if all(d not in ds_broad[var].dims for d in dim_tuple):
+            actual[var] = ds_broad[var]
 
-    for name, expect, by_ in zip(group_names, expected_groups, by):
+    for name, expect, by_ in zip(group_names, expected_groups, by_broad):
         # Can't remove this till xarray handles IntervalIndex
         if isinstance(expect, pd.IntervalIndex):
             expect = expect.to_numpy()
@@ -398,8 +420,8 @@ def wrapper(array, *by, func, skipna, **kwargs):
             actual = actual.drop_vars(name)
         # When grouping by MultiIndex, expect is an pd.Index wrapping
         # an object array of tuples
-        if name in ds.indexes and isinstance(ds.indexes[name], pd.MultiIndex):
-            levelnames = ds.indexes[name].names
+        if name in ds_broad.indexes and isinstance(ds_broad.indexes[name], pd.MultiIndex):
+            levelnames = ds_broad.indexes[name].names
             expect = pd.MultiIndex.from_tuples(expect.values, names=levelnames)
             actual[name] = expect
             if Version(xr.__version__) > Version("2022.03.0"):
@@ -414,18 +436,17 @@ def wrapper(array, *by, func, skipna, **kwargs):
 
     if nby == 1:
         for var in actual:
-            if isinstance(obj, xr.DataArray):
-                template = obj
-            else:
+            if isinstance(obj, xr.Dataset):
                 template = obj[var]
+            else:
+                template = obj
+
             if actual[var].ndim > 1:
-                actual[var] = _restore_dim_order(actual[var], template, by[0])
+                actual[var] = _restore_dim_order(actual[var], template, by_broad[0])
 
     if missing_dim:
         for k, v in missing_dim.items():
-            missing_group_dims = {
-                dim: size for dim, size in group_sizes.items() if dim not in v.dims
-            }
+            missing_group_dims = {d: size for d, size in group_sizes.items() if d not in v.dims}
             # The expand_dims is for backward compat with xarray's questionable behaviour
             if missing_group_dims:
                 actual[k] = v.expand_dims(missing_group_dims).variable
@@ -439,9 +460,9 @@ def wrapper(array, *by, func, skipna, **kwargs):
 
 
 def rechunk_for_cohorts(
-    obj: DataArray | Dataset,
+    obj: T_DataArray | T_Dataset,
     dim: str,
-    labels: DataArray,
+    labels: T_DataArray,
     force_new_chunk_at,
     chunksize: int | None = None,
     ignore_old_chunks: bool = False,
@@ -486,7 +507,7 @@ def rechunk_for_cohorts(
     )
 
 
-def rechunk_for_blockwise(obj: DataArray | Dataset, dim: str, labels: DataArray):
+def rechunk_for_blockwise(obj: T_DataArray | T_Dataset, dim: str, labels: T_DataArray):
     """
     Rechunks array so that group boundaries line up with chunk boundaries, allowing
     embarassingly parallel group reductions.
diff --git a/flox/xrutils.py b/flox/xrutils.py
index 17ad2d71d..3e6edd89e 100644
--- a/flox/xrutils.py
+++ b/flox/xrutils.py
@@ -19,7 +19,7 @@
 
     dask_array_type = dask.array.Array
 except ImportError:
-    dask_array_type = ()
+    dask_array_type = ()  # type: ignore
 
 
 def asarray(data, xp=np):
@@ -57,3 +57,5 @@ per-file-ignores =
 exclude=
     .eggs
     doc
+builtins =
+    ellipsis
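
For reference, a minimal usage sketch of the retyped xarray_reduce signature after this patch, with `by` passed as a plain Hashable coordinate name and `isbin` as a Sequence[bool] (which the new code normalizes into the internal `isbins` tuple). The array, dimension, and coordinate names ("data", "x", "y", "labels") are illustrative assumptions, not taken from the diff:

    import numpy as np
    import xarray as xr

    from flox.xarray import xarray_reduce

    # A 4x3 DataArray with a non-dim coordinate "labels" along "x".
    da = xr.DataArray(
        np.arange(12.0).reshape(4, 3),
        dims=("x", "y"),
        coords={"labels": ("x", ["a", "b", "a", "b"])},
        name="data",
    )

    # "labels" is a Hashable, matching the new `*by: T_DataArray | Hashable`
    # annotation; isbin is a one-element sequence, one flag per grouper.
    result = xarray_reduce(da, "labels", func="mean", isbin=[False])
    print(result)  # mean over "x" within each label; output dims ("y", "labels")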