From 9b1a90b661dced1846d31848c357dc641e108cd2 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 2 Nov 2024 23:51:08 -0600 Subject: [PATCH 1/5] Optimize grouped first, last. 1. Use flox where possible. 2. Use simple indexing where possible. Closes #9647 --- xarray/core/groupby.py | 91 ++++++++++++++++++++++++++++++++++-- xarray/core/resample.py | 16 +++++++ xarray/tests/test_groupby.py | 2 + 3 files changed, 104 insertions(+), 5 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index ceae79031f8..e3345e74309 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -20,6 +20,7 @@ from xarray.core.alignment import align, broadcast from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce +from xarray.core.computation import apply_ufunc from xarray.core.concat import concat from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.duck_array_ops import where @@ -1357,7 +1358,9 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray: """ return ops.where_method(self, cond, other) - def _first_or_last(self, op, skipna, keep_attrs): + def _first_or_last(self, op: str, skipna: bool | None, keep_attrs: bool | None): + from xarray.core.dataarray import DataArray + if all( isinstance(maybe_slice, slice) and (maybe_slice.stop == maybe_slice.start + 1) @@ -1368,17 +1371,95 @@ def _first_or_last(self, op, skipna, keep_attrs): return self._obj if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - return self.reduce( - op, dim=[self._group_dim], skipna=skipna, keep_attrs=keep_attrs + + def _groupby_first_last_wrapper( + values, + by, + *, + op: Literal["first", "last"], + skipna: bool | None, + group_indices, + ): + no_nans = dtypes.isdtype( + values.dtype, "signed integer" + ) or dtypes.is_string(values.dtype) + if (skipna or skipna is None) and not no_nans: + skipna = True + else: + 
skipna = False + + if TYPE_CHECKING: + assert isinstance(skipna, bool) + + if skipna is False or (skipna and no_nans): + # this is an optimization: when skipna=False, we can simply index + # the whole object after picking the first/last member of each group + # in self.encoded.group_indices + if op == "first": + indices = [ + (idx.start if isinstance(idx, slice) else idx[0]) + for idx in group_indices + if idx + ] + else: + indices = [ + (idx.stop - 1 if isinstance(idx, slice) else idx[-1]) + for idx in self.encoded.group_indices + if idx + ] + return self._obj.isel({self._group_dim: indices}) + + elif ( + skipna + and module_available("flox", minversion="0.9.14") + and OPTIONS["use_flox"] + and contains_only_chunked_or_numpy(self._obj) + ): + import flox + + result, *_ = flox.groupby_reduce( + values, self.group1d.data, axis=-1, func=f"nan{op}" + ) + return result + + else: + return self.reduce( + getattr(duck_array_ops, op), + dim=[self._group_dim], + skipna=skipna, + keep_attrs=keep_attrs, + ) + + result = apply_ufunc( + _groupby_first_last_wrapper, + self._obj, + self.group1d, + input_core_dims=[[self._group_dim], [self._group_dim]], + output_core_dims=[[self.group1d.name]], + dask="allowed", + output_sizes={self.group1d.name: len(self)}, + exclude_dims={self._group_dim}, + keep_attrs=keep_attrs, + kwargs={ + "op": op, + "skipna": skipna, + "group_indices": self.encoded.group_indices, + }, ) + result = result.assign_coords(self.encoded.coords) + result = self._maybe_unstack(result) + result = self._maybe_restore_empty_groups(result) + if isinstance(result, DataArray): + result = self._restore_dim_order(result) + return result def first(self, skipna: bool | None = None, keep_attrs: bool | None = None): """Return the first element of each group along the group dimension""" - return self._first_or_last(duck_array_ops.first, skipna, keep_attrs) + return self._first_or_last("first", skipna, keep_attrs) def last(self, skipna: bool | None = None, keep_attrs: bool | 
None = None): """Return the last element of each group along the group dimension""" - return self._first_or_last(duck_array_ops.last, skipna, keep_attrs) + return self._first_or_last("last", skipna, keep_attrs) def assign_coords(self, coords=None, **coords_kwargs): """Assign coordinates by group. diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 5cc98c9651c..40e3d26bd76 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -103,6 +103,22 @@ def shuffle_to_chunks(self, chunks: T_Chunks = None): (grouper,) = self.groupers return self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM) + def _first_or_last( + self, op: str, skipna: bool | None, keep_attrs: bool | None + ) -> T_Xarray: + from xarray.core.dataset import Dataset + + result = super()._first_or_last(op=op, skipna=skipna, keep_attrs=keep_attrs) + result = result.rename({RESAMPLE_DIM: self._group_dim}) + if isinstance(result, Dataset): + # Can't do this in the base class because group_dim is RESAMPLE_DIM + # which is not present in the original object + for var in result.data_vars: + result._variables[var] = result._variables[var].transpose( + *self._obj._variables[var].dims + ) + return result + def _drop_coords(self) -> T_Xarray: """Drop non-dimension coordinates along the resampled dimension.""" obj = self._obj diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 3c7f46f5a02..87cfa8e16f7 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -1618,6 +1618,8 @@ def test_groupby_first_and_last(self) -> None: expected = array # should be a no-op assert_identical(expected, actual) + # TODO: groupby_bins too + def make_groupby_multidim_example_array(self) -> DataArray: return DataArray( [[[0, 1], [2, 3]], [[5, 10], [15, 20]]], From d86bec1379c231808089844d4e8046db9149286f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 12 Jan 2025 16:01:43 -0600 Subject: [PATCH 2/5] simplify --- xarray/core/groupby.py | 97 
+++++++---------------------------------- xarray/core/resample.py | 1 - 2 files changed, 15 insertions(+), 83 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index e3345e74309..dc3bce2c71b 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -20,7 +20,6 @@ from xarray.core.alignment import align, broadcast from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce -from xarray.core.computation import apply_ufunc from xarray.core.concat import concat from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.duck_array_ops import where @@ -1359,8 +1358,6 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray: return ops.where_method(self, cond, other) def _first_or_last(self, op: str, skipna: bool | None, keep_attrs: bool | None): - from xarray.core.dataarray import DataArray - if all( isinstance(maybe_slice, slice) and (maybe_slice.stop == maybe_slice.start + 1) @@ -1371,86 +1368,22 @@ def _first_or_last(self, op: str, skipna: bool | None, keep_attrs: bool | None): return self._obj if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - - def _groupby_first_last_wrapper( - values, - by, - *, - op: Literal["first", "last"], - skipna: bool | None, - group_indices, + if ( + skipna + and module_available("flox", minversion="0.9.16") + and OPTIONS["use_flox"] + and contains_only_chunked_or_numpy(self._obj) ): - no_nans = dtypes.isdtype( - values.dtype, "signed integer" - ) or dtypes.is_string(values.dtype) - if (skipna or skipna is None) and not no_nans: - skipna = True - else: - skipna = False - - if TYPE_CHECKING: - assert isinstance(skipna, bool) - - if skipna is False or (skipna and no_nans): - # this is an optimization: when skipna=False, we can simply index - # the whole object after picking the first/last member of each group - # in self.encoded.group_indices - if op == 
"first": - indices = [ - (idx.start if isinstance(idx, slice) else idx[0]) - for idx in group_indices - if idx - ] - else: - indices = [ - (idx.stop - 1 if isinstance(idx, slice) else idx[-1]) - for idx in self.encoded.group_indices - if idx - ] - return self._obj.isel({self._group_dim: indices}) - - elif ( - skipna - and module_available("flox", minversion="0.9.14") - and OPTIONS["use_flox"] - and contains_only_chunked_or_numpy(self._obj) - ): - import flox - - result, *_ = flox.groupby_reduce( - values, self.group1d.data, axis=-1, func=f"nan{op}" - ) - return result - - else: - return self.reduce( - getattr(duck_array_ops, op), - dim=[self._group_dim], - skipna=skipna, - keep_attrs=keep_attrs, - ) - - result = apply_ufunc( - _groupby_first_last_wrapper, - self._obj, - self.group1d, - input_core_dims=[[self._group_dim], [self._group_dim]], - output_core_dims=[[self.group1d.name]], - dask="allowed", - output_sizes={self.group1d.name: len(self)}, - exclude_dims={self._group_dim}, - keep_attrs=keep_attrs, - kwargs={ - "op": op, - "skipna": skipna, - "group_indices": self.encoded.group_indices, - }, - ) - result = result.assign_coords(self.encoded.coords) - result = self._maybe_unstack(result) - result = self._maybe_restore_empty_groups(result) - if isinstance(result, DataArray): - result = self._restore_dim_order(result) + result, *_ = self._flox_reduce( + dim=None, func=f"nan{op}" if skipna else op, keep_attrs=keep_attrs + ) + else: + result = self.reduce( + getattr(duck_array_ops, op), + dim=[self._group_dim], + skipna=skipna, + keep_attrs=keep_attrs, + ) return result def first(self, skipna: bool | None = None, keep_attrs: bool | None = None): diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 40e3d26bd76..da8334a5b52 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -109,7 +109,6 @@ def _first_or_last( from xarray.core.dataset import Dataset result = super()._first_or_last(op=op, skipna=skipna, keep_attrs=keep_attrs) - result 
= result.rename({RESAMPLE_DIM: self._group_dim}) if isinstance(result, Dataset): # Can't do this in the base class because group_dim is RESAMPLE_DIM # which is not present in the original object From 8d2d7ad47d96dd00ba8bd8d2a0d40db1ca9fb957 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 25 Jan 2025 14:05:26 -0700 Subject: [PATCH 3/5] add whats-new --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9b40a323f39..5c1126fbc4d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -54,6 +54,8 @@ New Features By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_ and `Spencer Clark <https://github.com/spencerkclark>`_. - Improve the error message raised when no key is matching the available variables in a dataset. (:pull:`9943`) By `Jimmy Westling <https://github.com/illviljan>`_. +- :py:meth:`DatasetGroupBy.first` and :py:meth:`DatasetGroupBy.last` can now use ``flox`` if available. (:issue:`9647`) + By `Deepak Cherian <https://github.com/dcherian>`_. Breaking changes ~~~~~~~~~~~~~~~~ From 1ad2e5154617d2f351a323284c64369d7cc59c06 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sun, 26 Jan 2025 19:35:45 -0700 Subject: [PATCH 4/5] typing --- xarray/core/groupby.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index dc3bce2c71b..291632f91c9 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1357,7 +1357,12 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray: """ return ops.where_method(self, cond, other) - def _first_or_last(self, op: str, skipna: bool | None, keep_attrs: bool | None): + def _first_or_last( + self, + op: Literal["first" | "last"], + skipna: bool | None, + keep_attrs: bool | None, + ): if all( isinstance(maybe_slice, slice) and (maybe_slice.stop == maybe_slice.start + 1) @@ -1369,13 +1374,12 @@ def _first_or_last(self, op: str, skipna: bool | None, keep_attrs: bool | None): if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) if ( - skipna - and module_available("flox", 
minversion="0.9.16") + module_available("flox", minversion="0.9.16") and OPTIONS["use_flox"] and contains_only_chunked_or_numpy(self._obj) ): result, *_ = self._flox_reduce( - dim=None, func=f"nan{op}" if skipna else op, keep_attrs=keep_attrs + dim=None, func=op, skipna=skipna, keep_attrs=keep_attrs ) else: result = self.reduce( From 37eff5ab8fd84d27e893026cfce163835b35938b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 27 Jan 2025 11:27:04 -0700 Subject: [PATCH 5/5] more typing --- xarray/core/resample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/resample.py b/xarray/core/resample.py index da8334a5b52..ebd3d46eb61 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Callable, Hashable, Iterable, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from xarray.core._aggregations import ( DataArrayResampleAggregations, @@ -104,7 +104,7 @@ def shuffle_to_chunks(self, chunks: T_Chunks = None): return self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM) def _first_or_last( - self, op: str, skipna: bool | None, keep_attrs: bool | None + self, op: Literal["first", "last"], skipna: bool | None, keep_attrs: bool | None ) -> T_Xarray: from xarray.core.dataset import Dataset