
Commit b21b040

Support duck arrays by default (#132)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: b6c1e1a

11 files changed: 40 additions & 30 deletions

ci/docs.yml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ dependencies:
 - dask-core
 - pip
 - xarray
+- numpy>=1.20
 - numpydoc
 - numpy_groupies
 - toolz

ci/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ dependencies:
 - dask-core
 - netcdf4
 - pandas
+- numpy>=1.20
 - pip
 - pytest
 - pytest-cov

ci/minimal-requirements.yml

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@ dependencies:
 - pytest
 - pytest-cov
 - pytest-xdist
-- numpy_groupies>=0.9.15
+- numpy==1.20
+- numpy_groupies==0.9.15
 - pandas
 - pooch
 - toolz

ci/no-dask.yml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ dependencies:
 - codecov
 - netcdf4
 - pandas
+- numpy>=1.20
 - pip
 - pytest
 - pytest-cov

ci/no-xarray.yml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ dependencies:
 - codecov
 - netcdf4
 - pandas
+- numpy>=1.20
 - pip
 - pytest
 - pytest-cov

flox/aggregate_flox.py

Lines changed: 20 additions & 7 deletions
@@ -5,6 +5,21 @@
 from .xrutils import isnull


+def _prepare_for_flox(group_idx, array):
+    """
+    Sort the input array once to save time.
+    """
+    assert array.shape[-1] == group_idx.shape[0]
+    issorted = (group_idx[:-1] <= group_idx[1:]).all()
+    if issorted:
+        ordered_array = array
+    else:
+        perm = group_idx.argsort(kind="stable")
+        group_idx = group_idx[..., perm]
+        ordered_array = array[..., perm]
+    return group_idx, ordered_array
+
+
 def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dtype=None, out=None):
     """
     most of this code is from shoyer's gist

@@ -13,7 +28,7 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
     # assumes input is sorted, which I do in core._prepare_for_flox
     aux = group_idx

-    flag = np.concatenate(([True], aux[1:] != aux[:-1]))
+    flag = np.concatenate((np.array([True], like=array), aux[1:] != aux[:-1]))
     uniques = aux[flag]
     (inv_idx,) = flag.nonzero()

@@ -25,11 +40,11 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
     if out is None:
         out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)

-    if (len(uniques) == size) and (uniques == np.arange(size)).all():
+    if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
         # The previous version of this if condition
         # ((uniques[1:] - uniques[:-1]) == 1).all():
         # does not work when group_idx is [1, 2] for e.g.
-        # This happens during binning
+        # This happens during binning
         op.reduceat(array, inv_idx, axis=axis, dtype=dtype, out=out)
     else:
         out[..., uniques] = op.reduceat(array, inv_idx, axis=axis, dtype=dtype)

@@ -91,16 +106,14 @@ def nanlen(group_idx, array, *args, **kwargs):
 def mean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
     if fill_value is None:
         fill_value = 0
-    out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
-    sum(group_idx, array, axis=axis, size=size, dtype=dtype, out=out)
+    out = sum(group_idx, array, axis=axis, size=size, dtype=dtype, fill_value=fill_value)
     out /= nanlen(group_idx, array, size=size, axis=axis, fill_value=0)
     return out


 def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
     if fill_value is None:
         fill_value = 0
-    out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
-    nansum(group_idx, array, size=size, axis=axis, dtype=dtype, out=out)
+    out = nansum(group_idx, array, size=size, axis=axis, dtype=dtype, fill_value=fill_value)
     out /= nanlen(group_idx, array, size=size, axis=axis, fill_value=0)
     return out
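
For context, a minimal standalone sketch (not flox's exact code) of the sort-then-reduceat pattern that ``_prepare_for_flox`` sets up for ``_np_grouped_op``: sorting by group label makes each group contiguous, so a single ``np.ufunc.reduceat`` call reduces every group. The array values below are made up for illustration.

    import numpy as np

    group_idx = np.array([2, 0, 1, 0, 2])
    array = np.array([10.0, 1.0, 5.0, 3.0, 20.0])

    # sort so members of each group are contiguous (what _prepare_for_flox does)
    perm = group_idx.argsort(kind="stable")
    group_idx, array = group_idx[perm], array[perm]

    # segment starts: positions where the group label changes
    flag = np.concatenate(([True], group_idx[1:] != group_idx[:-1]))
    (inv_idx,) = flag.nonzero()

    print(np.add.reduceat(array, inv_idx))  # [ 4.  5. 30.]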

flox/aggregations.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ def generic_aggregate(
            f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
         )

+    group_idx = np.asarray(group_idx, like=array)
+
     return method(
         group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs
     )
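
The ``like=`` argument used here and in ``aggregate_flox.py`` is NumPy's NEP 35 mechanism (added in numpy 1.20, hence the new CI pins above): array-creation functions dispatch on the reference array's type, so duck arrays are preserved instead of being coerced to ``numpy.ndarray``. A small illustration; the duck-array case is only described in comments since it needs an optional backend installed.

    import numpy as np

    reference = np.ones(3)  # stand-in for ``array``
    out = np.asarray([0, 0, 1], like=reference)
    print(type(out))  # numpy.ndarray here; if ``reference`` were e.g. a cupy
                      # array, ``out`` would be a cupy array as well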

flox/core.py

Lines changed: 5 additions & 18 deletions
@@ -13,6 +13,7 @@
 import toolz as tlz

 from . import xrdtypes
+from .aggregate_flox import _prepare_for_flox
 from .aggregations import (
     Aggregation,
     _atleast_1d,

@@ -44,21 +45,6 @@ def _is_arg_reduction(func: str | Aggregation) -> bool:
     return False


-def _prepare_for_flox(group_idx, array):
-    """
-    Sort the input array once to save time.
-    """
-    assert array.shape[-1] == group_idx.shape[0]
-    issorted = (group_idx[:-1] <= group_idx[1:]).all()
-    if issorted:
-        ordered_array = array
-    else:
-        perm = group_idx.argsort(kind="stable")
-        group_idx = group_idx[..., perm]
-        ordered_array = array[..., perm]
-    return group_idx, ordered_array
-
-
 def _get_expected_groups(by, sort, *, raise_if_dask=True) -> pd.Index | None:
     if is_duck_dask_array(by):
         if raise_if_dask:

@@ -1367,7 +1353,7 @@ def groupby_reduce(
     min_count: int | None = None,
     split_out: int = 1,
     method: str = "map-reduce",
-    engine: str = "flox",
+    engine: str = "numpy",
     reindex: bool | None = None,
     finalize_kwargs: Mapping | None = None,
 ) -> tuple[DaskArray, np.ndarray | DaskArray]:

@@ -1434,13 +1420,14 @@ def groupby_reduce(
         and is identical to xarray's default strategy.
     engine : {"flox", "numpy", "numba"}, optional
         Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
+        * ``"numpy"``:
+          Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``.
+          This is the default choice because it works for most array types.
         * ``"flox"``:
          Use an internal implementation where the data is sorted so that
          all members of a group occur sequentially, and then numpy.ufunc.reduceat
          is to used for the reduction. This will fall back to ``numpy_groupies.aggregate_numpy``
          for a reduction that is not yet implemented.
-        * ``"numpy"``:
-          Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``.
         * ``"numba"``:
          Use the implementations in ``numpy_groupies.aggregate_numba``.
     reindex : bool, optional
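
A short usage sketch of the new default; the input data and the output shown in the comment are illustrative only.

    import numpy as np
    from flox.core import groupby_reduce

    array = np.array([1.0, 2.0, 3.0, 4.0])
    by = np.array([0, 0, 1, 1])

    # engine now defaults to "numpy"; pass engine="flox" for the old behaviour
    result, groups = groupby_reduce(array, by, func="sum")
    print(result, groups)  # expect roughly [3. 7.] for groups [0 1]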

flox/xarray.py

Lines changed: 4 additions & 3 deletions
@@ -61,7 +61,7 @@ def xarray_reduce(
     split_out: int = 1,
     fill_value=None,
     method: str = "map-reduce",
-    engine: str = "flox",
+    engine: str = "numpy",
     keep_attrs: bool | None = True,
     skipna: bool | None = None,
     min_count: int | None = None,

@@ -125,13 +125,14 @@ def xarray_reduce(
         and is identical to xarray's default strategy.
     engine : {"flox", "numpy", "numba"}, optional
         Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
+        * ``"numpy"``:
+          Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``.
+          This is the default choice because it works for other array types.
         * ``"flox"``:
          Use an internal implementation where the data is sorted so that
          all members of a group occur sequentially, and then numpy.ufunc.reduceat
          is to used for the reduction. This will fall back to ``numpy_groupies.aggregate_numpy``
          for a reduction that is not yet implemented.
-        * ``"numpy"``:
-          Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``.
         * ``"numba"``:
          Use the implementations in ``numpy_groupies.aggregate_numba``.
     keep_attrs : bool, optional
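
And the equivalent call through the xarray wrapper, as a hypothetical minimal example (the variable names and data are invented, and exact argument requirements may differ between flox versions):

    import numpy as np
    import xarray as xr
    from flox.xarray import xarray_reduce

    da = xr.DataArray(np.arange(4.0), dims="x", name="var")
    labels = xr.DataArray([0, 0, 1, 1], dims="x", name="labels")

    out = xarray_reduce(da, labels, func="mean")  # engine defaults to "numpy"
    print(out)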

flox/xrutils.py

Lines changed: 2 additions & 1 deletion
@@ -98,7 +98,8 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool:


 def isnull(data):
-    data = np.asarray(data)
+    if not is_duck_array(data):
+        data = np.asarray(data)
     scalar_type = data.dtype.type
     if issubclass(scalar_type, (np.datetime64, np.timedelta64)):
         # datetime types use NaT for null
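
``is_duck_array`` is available in this module's namespace (these helpers are adapted from xarray). A sketch of the kind of check such a helper typically performs; the name ``is_duck_array_sketch`` is hypothetical and the real implementation may differ in details.

    import numpy as np

    def is_duck_array_sketch(value) -> bool:
        # treat anything exposing the core ndarray protocol attributes as an
        # array, so cupy/sparse/pint arrays are not coerced via np.asarray
        if isinstance(value, np.ndarray):
            return True
        return (
            hasattr(value, "ndim")
            and hasattr(value, "shape")
            and hasattr(value, "dtype")
            and hasattr(value, "__array_function__")
            and hasattr(value, "__array_ufunc__")
        )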
