diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a0b9e300a94..20bbdc7ec69 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,9 @@ New Features - Improved compatibility with OPeNDAP DAP4 data model for backend engine ``pydap``. This includes ``datatree`` support, and removing slashes from dimension names. By `Miguel Jimenez-Urias `_. +- Improved support pandas Extension Arrays. (:issue:`9661`, :pull:`9671`) + By `Ilan Gold `_. + Breaking changes ~~~~~~~~~~~~~~~~ @@ -41,6 +44,12 @@ Breaking changes pydap 3.4 3.5.0 ===================== ========= ======= + +- Reductions with ``groupby_bins`` or those that involve :py:class:`xarray.groupers.BinGrouper` + now return objects indexed by :py:meth:`pandas.IntervalArray` objects, + instead of numpy object arrays containing tuples. This change enables interval-aware indexing of + such Xarray objects. (:pull:`9671`). By `Ilan Gold `_. + Deprecations ~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 9b80e154b95..f523f971725 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6917,7 +6917,7 @@ def groupby( [[nan, nan, nan], [ 3., 4., 5.]]]) Coordinates: - * x_bins (x_bins) object 16B (5, 15] (15, 25] + * x_bins (x_bins) interval[int64, right] 32B (5, 15] (15, 25] * letters (letters) object 16B 'a' 'b' Dimensions without coordinates: y diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9bfad11994e..9d52f2e0776 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7059,6 +7059,8 @@ def to_pandas(self) -> pd.Series | pd.DataFrame: ) def _to_dataframe(self, ordered_dims: Mapping[Any, int]): + from xarray.core.extension_array import PandasExtensionArray + columns_in_order = [k for k in self.variables if k not in self.dims] non_extension_array_columns = [ k @@ -7070,20 +7072,41 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]): for k in columns_in_order if is_extension_array_dtype(self.variables[k].data) ] + extension_array_columns_different_index = [ + k + for k in extension_array_columns + if set(self.variables[k].dims) != set(ordered_dims.keys()) + ] + extension_array_columns_same_index = [ + k + for k in extension_array_columns + if k not in extension_array_columns_different_index + ] data = [ self._variables[k].set_dims(ordered_dims).values.reshape(-1) for k in non_extension_array_columns ] index = self.coords.to_index([*ordered_dims]) broadcasted_df = pd.DataFrame( - dict(zip(non_extension_array_columns, data, strict=True)), index=index + { + **dict(zip(non_extension_array_columns, data, strict=True)), + **{ + c: self.variables[c].data.array + for c in extension_array_columns_same_index + }, + }, + index=index, ) - for extension_array_column in extension_array_columns: + for extension_array_column in extension_array_columns_different_index: extension_array = self.variables[extension_array_column].data.array - index = self[self.variables[extension_array_column].dims[0]].data + index = self[ + self.variables[extension_array_column].dims[0] + ].coords.to_index() extension_array_df = pd.DataFrame( {extension_array_column: extension_array}, - index=self[self.variables[extension_array_column].dims[0]].data, + index=pd.Index(index.array) + if isinstance(index, PandasExtensionArray) + else index, ) extension_array_df.index.name = self.variables[extension_array_column].dims[ 0 @@ -9892,10 +9915,10 @@ def groupby( >>> from xarray.groupers import BinGrouper, UniqueGrouper >>> >>> ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum() - Size: 128B + Size: 144B Dimensions: (y: 3, x_bins: 2, letters: 2) Coordinates: - * x_bins (x_bins) object 16B (5, 15] (15, 25] + * x_bins (x_bins) interval[int64, right] 32B (5, 15] (15, 25] * letters (letters) object 16B 'a' 'b' Dimensions without coordinates: y Data variables: diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 43829b4029f..e8006a4c8c3 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -102,7 +102,7 @@ def replace_duck_with_extension_array(args) -> list: return type(self)[type(res)](res) return res - def __array_ufunc__(ufunc, method, *inputs, **kwargs): + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return ufunc(*inputs, **kwargs) def __repr__(self): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 9dc1a26b1f0..bc934132f1c 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -11,6 +11,7 @@ from xarray.core import formatting, nputils, utils from xarray.core.coordinate_transform import CoordinateTransform +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ( CoordinateTransformIndexingAdapter, IndexSelResult, @@ -444,6 +445,8 @@ def safe_cast_to_index(array: Any) -> pd.Index: from xarray.core.variable import Variable from xarray.namedarray.pycompat import to_numpy + if isinstance(array, PandasExtensionArray): + array = pd.Index(array.array) if isinstance(array, pd.Index): index = array elif isinstance(array, DataArray | Variable): @@ -602,7 +605,11 @@ def __init__( self.dim = dim if coord_dtype is None: - coord_dtype = get_valid_numpy_dtype(index) + if pd.api.types.is_extension_array_dtype(index.dtype): + cast(pd.api.extensions.ExtensionDtype, index.dtype) + coord_dtype = index.dtype + else: + coord_dtype = get_valid_numpy_dtype(index) self.coord_dtype = coord_dtype def _replace(self, index, dim=None, coord_dtype=None): @@ -698,6 +705,8 @@ def concat( if not indexes: coord_dtype = None + elif len(set(idx.coord_dtype for idx in indexes)) == 1: + coord_dtype = indexes[0].coord_dtype else: coord_dtype = np.result_type(*[idx.coord_dtype for idx in indexes]) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 2999506d1de..aa56006eff3 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -10,14 +10,16 @@ from dataclasses import dataclass, field from datetime import timedelta from html import escape -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, cast, overload import numpy as np import pandas as pd +from numpy.typing import DTypeLike from packaging.version import Version from xarray.core import duck_array_ops from xarray.core.coordinate_transform import CoordinateTransform +from xarray.core.extension_array import PandasExtensionArray from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS from xarray.core.types import T_Xarray @@ -28,14 +30,13 @@ is_duck_array, is_duck_dask_array, is_scalar, + is_valid_numpy_dtype, to_0d_array, ) from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import array_type, integer_types, is_chunked_array if TYPE_CHECKING: - from numpy.typing import DTypeLike - from xarray.core.indexes import Index from xarray.core.types import Self from xarray.core.variable import Variable @@ -1744,27 +1745,43 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): __slots__ = ("_dtype", "array") array: pd.Index - _dtype: np.dtype + _dtype: np.dtype | pd.api.extensions.ExtensionDtype - def __init__(self, array: pd.Index, dtype: DTypeLike = None): + def __init__( + self, + array: pd.Index, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, + ): from xarray.core.indexes import safe_cast_to_index self.array = safe_cast_to_index(array) if dtype is None: - self._dtype = get_valid_numpy_dtype(array) + if pd.api.types.is_extension_array_dtype(array.dtype): + cast(pd.api.extensions.ExtensionDtype, array.dtype) + self._dtype = array.dtype + else: + self._dtype = get_valid_numpy_dtype(array) + elif pd.api.types.is_extension_array_dtype(dtype): + self._dtype = cast(pd.api.extensions.ExtensionDtype, dtype) else: - self._dtype = np.dtype(dtype) + self._dtype = np.dtype(cast(DTypeLike, dtype)) @property - def dtype(self) -> np.dtype: + def dtype(self) -> np.dtype | pd.api.extensions.ExtensionDtype: # type: ignore[override] return self._dtype def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: np.typing.DTypeLike | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: - if dtype is None: - dtype = self.dtype + if dtype is None and is_valid_numpy_dtype(self.dtype): + dtype = cast(np.dtype, self.dtype) + else: + dtype = get_valid_numpy_dtype(self.array) array = self.array if isinstance(array, pd.PeriodIndex): with suppress(AttributeError): @@ -1776,14 +1793,18 @@ def __array__( else: return np.asarray(array.values, dtype=dtype) - def get_duck_array(self) -> np.ndarray: + def get_duck_array(self) -> np.ndarray | PandasExtensionArray: + # We return an PandasExtensionArray wrapper type that satisfies + # duck array protocols. This is what's needed for tests to pass. + if pd.api.types.is_extension_array_dtype(self.array): + return PandasExtensionArray(self.array.array) return np.asarray(self) @property def shape(self) -> _Shape: return (len(self.array),) - def _convert_scalar(self, item): + def _convert_scalar(self, item) -> np.ndarray: if item is pd.NaT: # work around the impossibility of casting NaT with asarray # note: it probably would be better in general to return @@ -1799,7 +1820,10 @@ def _convert_scalar(self, item): # numpy fails to convert pd.Timestamp to np.datetime64[ns] item = np.asarray(item.to_datetime64()) elif self.dtype != object: - item = np.asarray(item, dtype=self.dtype) + dtype = self.dtype + if pd.api.types.is_extension_array_dtype(dtype): + dtype = get_valid_numpy_dtype(self.array) + item = np.asarray(item, dtype=cast(np.dtype, dtype)) # as for numpy.ndarray indexing, we always want the result to be # a NumPy array. @@ -1902,6 +1926,12 @@ def copy(self, deep: bool = True) -> Self: array = self.array.copy(deep=True) if deep else self.array return type(self)(array, self._dtype) + @property + def nbytes(self) -> int: + if pd.api.types.is_extension_array_dtype(self.dtype): + return self.array.nbytes + return cast(np.dtype, self.dtype).itemsize * len(self.array) + class PandasMultiIndexingAdapter(PandasIndexingAdapter): """Handles explicit indexing for a pandas.MultiIndex. @@ -1914,23 +1944,27 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter): __slots__ = ("_dtype", "adapter", "array", "level") array: pd.MultiIndex - _dtype: np.dtype + _dtype: np.dtype | pd.api.extensions.ExtensionDtype level: str | None def __init__( self, array: pd.MultiIndex, - dtype: DTypeLike = None, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, level: str | None = None, ): super().__init__(array, dtype) self.level = level def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: DTypeLike | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: if dtype is None: - dtype = self.dtype + dtype = cast(np.dtype, self.dtype) if self.level is not None: return np.asarray( self.array.get_level_values(self.level).values, dtype=dtype diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f59680dd7df..6d769842a69 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -13,7 +13,6 @@ import numpy as np import pandas as pd from numpy.typing import ArrayLike -from pandas.api.types import is_extension_array_dtype import xarray as xr # only for Dataset and DataArray from xarray.compat.array_api_compat import to_like_array @@ -60,6 +59,7 @@ indexing.ExplicitlyIndexed, pd.Index, pd.api.extensions.ExtensionArray, + PandasExtensionArray, ) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) @@ -192,7 +192,7 @@ def _maybe_wrap_data(data): if isinstance(data, pd.Index): return PandasIndexingAdapter(data) if isinstance(data, pd.api.extensions.ExtensionArray): - return PandasExtensionArray[type(data)](data) + return PandasExtensionArray(data) return data @@ -2593,11 +2593,6 @@ def chunk( # type: ignore[override] dask.array.from_array """ - if is_extension_array_dtype(self): - raise ValueError( - f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first." - ) - if from_array_kwargs is None: from_array_kwargs = {} diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index cdf9eab5c8d..f9c1919201f 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -834,6 +834,7 @@ def chunk( if chunkmanager.is_chunked_array(data_old): data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type] else: + ndata: duckarray[Any, Any] if not isinstance(data_old, ExplicitlyIndexed): ndata = data_old else: diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index a4ed7eba1d0..ee49928aa01 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -524,7 +524,7 @@ def line( assert hueplt is not None ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) - if np.issubdtype(xplt.dtype, np.datetime64): + if isinstance(xplt.dtype, np.dtype) and np.issubdtype(xplt.dtype, np.datetime64): # type: ignore[redundant-expr] _set_concise_date(ax, axis="x") _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c1310bc7e1d..ed8c4178ed0 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4381,7 +4381,6 @@ def test_setitem_pandas(self) -> None: ds["x"] = np.arange(3) ds_copy = ds.copy() ds_copy["bar"] = ds["bar"].to_pandas() - assert_equal(ds, ds_copy) def test_setitem_auto_align(self) -> None: @@ -4972,6 +4971,16 @@ def test_to_and_from_dataframe(self) -> None: expected = pd.DataFrame([[]], index=idx) assert expected.equals(actual), (expected, actual) + def test_from_dataframe_categorical_dtype_index(self) -> None: + cat = pd.CategoricalIndex(list("abcd")) + df = pd.DataFrame({"f": [0, 1, 2, 3]}, index=cat) + ds = df.to_xarray() + restored = ds.to_dataframe() + df.index.name = ( + "index" # restored gets the name because it has the coord with the name + ) + pd.testing.assert_frame_equal(df, restored) + def test_from_dataframe_categorical_index(self) -> None: cat = pd.CategoricalDtype( categories=["foo", "bar", "baz", "qux", "quux", "corge"] @@ -4996,7 +5005,7 @@ def test_from_dataframe_categorical_index_string_categories(self) -> None: ) ser = pd.Series(1, index=cat) ds = ser.to_xarray() - assert ds.coords.dtypes["index"] == np.dtype("O") + assert ds.coords.dtypes["index"] == ser.index.dtype @requires_sparse def test_from_dataframe_sparse(self) -> None: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 1c351f0ee62..52ab8c4d232 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -4,7 +4,7 @@ import operator import warnings from itertools import pairwise -from typing import Literal +from typing import Literal, cast from unittest import mock import numpy as np @@ -1118,7 +1118,8 @@ def test_groupby_math_nD_group() -> None: expected = da.isel(x=slice(30)) - expanded_mean expected["labels"] = expected.labels.broadcast_like(expected.labels2d) expected["num"] = expected.num.broadcast_like(expected.num2d) - expected["num2d_bins"] = (("x", "y"), mean.num2d_bins.data[idxr]) + # mean.num2d_bins.data is a pandas IntervalArray so needs to be put in `numpy` to allow indexing + expected["num2d_bins"] = (("x", "y"), mean.num2d_bins.data.to_numpy()[idxr]) actual = g - mean assert_identical(expected, actual) @@ -1680,13 +1681,9 @@ def test_groupby_bins( df["dim_0_bins"] = pd.cut(array["dim_0"], bins, **cut_kwargs) # type: ignore[call-overload] expected_df = df.groupby("dim_0_bins", observed=True).sum() - # TODO: can't convert df with IntervalIndex to Xarray - expected = ( - expected_df.reset_index(drop=True) - .to_xarray() - .assign_coords(index=np.array(expected_df.index)) - .rename({"index": "dim_0_bins"})["a"] - ) + expected = expected_df.to_xarray().assign_coords( + dim_0_bins=cast(pd.CategoricalIndex, expected_df.index).categories + )["a"] with xr.set_options(use_flox=use_flox): gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs) diff --git a/xarray/tests/test_pandas_to_xarray.py b/xarray/tests/test_pandas_to_xarray.py index 590749bf548..111866541eb 100644 --- a/xarray/tests/test_pandas_to_xarray.py +++ b/xarray/tests/test_pandas_to_xarray.py @@ -107,14 +107,6 @@ def index_flat(request): return indices_dict[key].copy() -@pytest.fixture -def using_infer_string() -> bool: - """ - Fixture to check if infer string option is enabled. - """ - return pd.options.future.infer_string is True # type: ignore[union-attr] - - class TestDataFrameToXArray: @pytest.fixture def df(self): @@ -131,8 +123,7 @@ def df(self): } ) - @pytest.mark.xfail(reason="needs some work") - def test_to_xarray_index_types(self, index_flat, df, using_infer_string): + def test_to_xarray_index_types(self, index_flat, df): index = index_flat # MultiIndex is tested in test_to_xarray_with_multiindex if len(index) == 0: @@ -154,9 +145,6 @@ def test_to_xarray_index_types(self, index_flat, df, using_infer_string): # datetimes w/tz are preserved # column names are lost expected = df.copy() - expected["f"] = expected["f"].astype( - object if not using_infer_string else "str" - ) expected.columns.name = None tm.assert_frame_equal(result.to_dataframe(), expected) @@ -168,7 +156,7 @@ def test_to_xarray_empty(self, df): assert result.sizes["foo"] == 0 assert isinstance(result, Dataset) - def test_to_xarray_with_multiindex(self, df, using_infer_string): + def test_to_xarray_with_multiindex(self, df): from xarray import Dataset # MultiIndex @@ -183,9 +171,7 @@ def test_to_xarray_with_multiindex(self, df, using_infer_string): result = result.to_dataframe() expected = df.copy() - expected["f"] = expected["f"].astype( - object if not using_infer_string else "str" - ) + expected["f"] = expected["f"].astype(object) expected.columns.name = None tm.assert_frame_equal(result, expected) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 8569cb093e7..619dc1561ef 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -333,7 +333,7 @@ def test_pandas_period_index(self): v = self.cls(["x"], pd.period_range(start="2000", periods=20, freq="D")) v = v.load() # for dask-based Variable assert v[0] == pd.Period("2000", freq="D") - assert "Period('2000-01-01', 'D')" in repr(v) + assert "PeriodArray" in repr(v) @pytest.mark.parametrize("dtype", [float, int]) def test_1d_math(self, dtype: np.typing.DTypeLike) -> None: @@ -656,7 +656,7 @@ def test_pandas_categorical_dtype(self): data = pd.Categorical(np.arange(10, dtype="int64")) v = self.cls("x", data) print(v) # should not error - assert v.dtype == "int64" + assert v.dtype == data.dtype def test_pandas_datetime64_with_tz(self): data = pd.date_range( @@ -667,9 +667,12 @@ def test_pandas_datetime64_with_tz(self): ) v = self.cls("x", data) print(v) # should not error - if "America/New_York" in str(data.dtype): - # pandas is new enough that it has datetime64 with timezone dtype - assert v.dtype == "object" + if v.dtype == np.dtype("O"): + import dask.array as da + + assert isinstance(v.data, da.Array) + else: + assert v.dtype == data.dtype def test_multiindex(self): idx = pd.MultiIndex.from_product([list("abc"), [0, 1]]) @@ -1592,14 +1595,6 @@ def test_pandas_categorical_dtype(self): print(v) # should not error assert pd.api.types.is_extension_array_dtype(v.dtype) - def test_pandas_categorical_no_chunk(self): - data = pd.Categorical(np.arange(10, dtype="int64")) - v = self.cls("x", data) - with pytest.raises( - ValueError, match=r".*was found to be a Pandas ExtensionArray.*" - ): - v.chunk((5,)) - def test_squeeze(self): v = Variable(["x", "y"], [[1]]) assert_identical(Variable([], 1), v.squeeze()) @@ -2412,10 +2407,17 @@ def test_multiindex(self): def test_pad(self, mode, xr_arg, np_arg): super().test_pad(mode, xr_arg, np_arg) + @pytest.mark.skip(reason="dask doesn't support extension arrays") + def test_pandas_period_index(self): + super().test_pandas_period_index() + + @pytest.mark.skip(reason="dask doesn't support extension arrays") + def test_pandas_datetime64_with_tz(self): + super().test_pandas_datetime64_with_tz() + + @pytest.mark.skip(reason="dask doesn't support extension arrays") def test_pandas_categorical_dtype(self): - data = pd.Categorical(np.arange(10, dtype="int64")) - with pytest.raises(ValueError, match="was found to be a Pandas ExtensionArray"): - self.cls("x", data) + super().test_pandas_categorical_dtype() @requires_sparse @@ -3021,7 +3023,7 @@ def test_datetime_conversion(values, unit) -> None: # todo: check for redundancy (suggested per review) dims = ["time"] if isinstance(values, np.ndarray | pd.Index | pd.Series) else [] var = Variable(dims, values) - if var.dtype.kind == "M": + if var.dtype.kind == "M" and isinstance(var.dtype, np.dtype): assert var.dtype == np.dtype(f"datetime64[{unit}]") else: # The only case where a non-datetime64 dtype can occur currently is in @@ -3063,8 +3065,12 @@ def test_pandas_two_only_datetime_conversion_warnings( # todo: check for redundancy (suggested per review) var = Variable(["time"], data.astype(dtype)) # type: ignore[arg-type] - if var.dtype.kind == "M": + # we internally convert series to numpy representations to avoid too much nastiness with extension arrays + # when calling data.array e.g., with NumpyExtensionArrays + if isinstance(data, pd.Series): assert var.dtype == np.dtype("datetime64[s]") + elif var.dtype.kind == "M": + assert var.dtype == dtype else: # The only case where a non-datetime64 dtype can occur currently is in # the case that the variable is backed by a timezone-aware