Skip to content

Commit 3c6d2d5

Browse files
committed
Grouper, Resampler as public api
1 parent c9d3084 commit 3c6d2d5

File tree

5 files changed

+77
-29
lines changed

5 files changed

+77
-29
lines changed

xarray/core/common.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,7 +1049,7 @@ def _resample(
10491049
# TODO support non-string indexer after removing the old API.
10501050

10511051
from xarray.core.dataarray import DataArray
1052-
from xarray.core.groupby import ResolvedGrouper, TimeResampler
1052+
from xarray.core.groupby import Resampler, ResolvedGrouper, TimeResampler
10531053
from xarray.core.resample import RESAMPLE_DIM
10541054

10551055
# note: the second argument (now 'skipna') use to be 'dim'
@@ -1079,15 +1079,19 @@ def _resample(
10791079
name=RESAMPLE_DIM,
10801080
)
10811081

1082-
grouper = TimeResampler(
1083-
freq=freq,
1084-
closed=closed,
1085-
label=label,
1086-
origin=origin,
1087-
offset=offset,
1088-
loffset=loffset,
1089-
base=base,
1090-
)
1082+
if isinstance(freq, str):
1083+
grouper = TimeResampler(
1084+
freq=freq,
1085+
closed=closed,
1086+
label=label,
1087+
origin=origin,
1088+
offset=offset,
1089+
loffset=loffset,
1090+
base=base,
1091+
)
1092+
else:
1093+
assert isinstance(freq, Resampler)
1094+
grouper = freq
10911095

10921096
rgrouper = ResolvedGrouper(grouper, group, self)
10931097

xarray/core/dataarray.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6636,9 +6636,10 @@ def interp_calendar(
66366636

66376637
def groupby(
66386638
self,
6639-
group: Hashable | DataArray | IndexVariable,
6639+
group: Hashable | DataArray | IndexVariable = None,
66406640
squeeze: bool | None = None,
66416641
restore_coord_dims: bool = False,
6642+
**groupers,
66426643
) -> DataArrayGroupBy:
66436644
"""Returns a DataArrayGroupBy object for performing grouped operations.
66446645
@@ -6710,7 +6711,19 @@ def groupby(
67106711
)
67116712

67126713
_validate_groupby_squeeze(squeeze)
6713-
rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
6714+
6715+
if group is not None:
6716+
assert not groupers
6717+
grouper = UniqueGrouper()
6718+
else:
6719+
if len(groupers) > 1:
6720+
raise ValueError("grouping by multiple variables is not supported yet.")
6721+
if not groupers:
6722+
raise ValueError
6723+
group, grouper = next(iter(groupers.items()))
6724+
6725+
rgrouper = ResolvedGrouper(grouper, group, self)
6726+
67146727
return DataArrayGroupBy(
67156728
self,
67166729
(rgrouper,),

xarray/core/dataset.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10134,9 +10134,10 @@ def interp_calendar(
1013410134

1013510135
def groupby(
1013610136
self,
10137-
group: Hashable | DataArray | IndexVariable,
10137+
group: Hashable | DataArray | IndexVariable | None = None,
1013810138
squeeze: bool | None = None,
1013910139
restore_coord_dims: bool = False,
10140+
**groupers,
1014010141
) -> DatasetGroupBy:
1014110142
"""Returns a DatasetGroupBy object for performing grouped operations.
1014210143
@@ -10186,7 +10187,16 @@ def groupby(
1018610187
)
1018710188

1018810189
_validate_groupby_squeeze(squeeze)
10189-
rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
10190+
if group is not None:
10191+
assert not groupers
10192+
rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
10193+
else:
10194+
if len(groupers) > 1:
10195+
raise ValueError("grouping by multiple variables is not supported yet.")
10196+
if not groupers:
10197+
raise ValueError
10198+
for group, grouper in groupers.items():
10199+
rgrouper = ResolvedGrouper(grouper, group, self)
1019010200

1019110201
return DatasetGroupBy(
1019210202
self,

xarray/core/groupby.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def attrs(self) -> dict:
254254

255255
def __getitem__(self, key):
256256
if isinstance(key, tuple):
257-
key = key[0]
257+
(key,) = key
258258
return self.values[key]
259259

260260
def to_index(self) -> pd.Index:

xarray/tests/test_groupby.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212

1313
import xarray as xr
1414
from xarray import DataArray, Dataset, Variable
15-
from xarray.core.groupby import _consolidate_slices
15+
from xarray.core.groupby import (
16+
BinGrouper,
17+
UniqueGrouper,
18+
_consolidate_slices,
19+
)
1620
from xarray.tests import (
1721
InaccessibleArray,
1822
assert_allclose,
@@ -112,8 +116,9 @@ def test_multi_index_groupby_map(dataset) -> None:
112116
assert_equal(expected, actual)
113117

114118

115-
def test_reduce_numeric_only(dataset) -> None:
116-
gb = dataset.groupby("x", squeeze=False)
119+
@pytest.mark.parametrize("grouper", [dict(group="x"), dict(x=UniqueGrouper())])
120+
def test_reduce_numeric_only(dataset, grouper) -> None:
121+
gb = dataset.groupby(**grouper, squeeze=False)
117122
with xr.set_options(use_flox=False):
118123
expected = gb.sum()
119124
with xr.set_options(use_flox=True):
@@ -830,11 +835,12 @@ def test_groupby_dataset_reduce() -> None:
830835

831836
expected = data.mean("y")
832837
expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3})
833-
actual = data.groupby("x").mean(...)
834-
assert_allclose(expected, actual)
838+
for gb in [data.groupby("x"), data.groupby(x=UniqueGrouper())]:
839+
actual = gb.mean(...)
840+
assert_allclose(expected, actual)
835841

836-
actual = data.groupby("x").mean("y")
837-
assert_allclose(expected, actual)
842+
actual = gb.mean("y")
843+
assert_allclose(expected, actual)
838844

839845
letters = data["letters"]
840846
expected = Dataset(
@@ -844,8 +850,9 @@ def test_groupby_dataset_reduce() -> None:
844850
"yonly": data["yonly"].groupby(letters).mean(),
845851
}
846852
)
847-
actual = data.groupby("letters").mean(...)
848-
assert_allclose(expected, actual)
853+
for gb in [data.groupby("letters"), data.groupby(letters=UniqueGrouper())]:
854+
actual = gb.mean(...)
855+
assert_allclose(expected, actual)
849856

850857

851858
@pytest.mark.parametrize("squeeze", [True, False])
@@ -975,6 +982,14 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:
975982
)
976983
assert_identical(expected, actual)
977984

985+
with xr.set_options(use_flox=use_flox):
986+
actual = da.groupby(
987+
x=BinGrouper(
988+
bins=x_bins, cut_kwargs=dict(include_lowest=True, right=False)
989+
),
990+
).mean()
991+
assert_identical(expected, actual)
992+
978993

979994
@pytest.mark.parametrize("indexed_coord", [True, False])
980995
def test_groupby_bins_math(indexed_coord) -> None:
@@ -983,11 +998,17 @@ def test_groupby_bins_math(indexed_coord) -> None:
983998
if indexed_coord:
984999
da["x"] = np.arange(N)
9851000
da["y"] = np.arange(N)
986-
g = da.groupby_bins("x", np.arange(0, N + 1, 3))
987-
mean = g.mean()
988-
expected = da.isel(x=slice(1, None)) - mean.isel(x_bins=("x", [0, 0, 0, 1, 1, 1]))
989-
actual = g - mean
990-
assert_identical(expected, actual)
1001+
1002+
for g in [
1003+
da.groupby_bins("x", np.arange(0, N + 1, 3)),
1004+
da.groupby(x=BinGrouper(bins=np.arange(0, N + 1, 3))),
1005+
]:
1006+
mean = g.mean()
1007+
expected = da.isel(x=slice(1, None)) - mean.isel(
1008+
x_bins=("x", [0, 0, 0, 1, 1, 1])
1009+
)
1010+
actual = g - mean
1011+
assert_identical(expected, actual)
9911012

9921013

9931014
def test_groupby_math_nD_group() -> None:

0 commit comments

Comments
 (0)