From 6a2642ff51e7d0a587c547f1909407ce685e98e1 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 10 Jan 2024 02:27:05 +0100 Subject: [PATCH] ENH: Implement interpolation for arrow and masked dtypes (#56757) * ENH: Implement interpolation for arrow and masked dtypes * Fixup * Fix typing * Update (cherry picked from commit 5fc2ed2703a1370207f4ebad834e665b6c2ad42f) --- doc/source/whatsnew/v2.2.0.rst | 1 + pandas/core/arrays/arrow/array.py | 40 ++++++++++++++ pandas/core/arrays/masked.py | 54 +++++++++++++++++++ pandas/core/missing.py | 14 +++-- .../tests/frame/methods/test_interpolate.py | 41 ++++++++++++-- 5 files changed, 143 insertions(+), 7 deletions(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index e244794664b34..51f3dc025864a 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -343,6 +343,7 @@ Other enhancements - :meth:`ExtensionArray.duplicated` added to allow extension type implementations of the ``duplicated`` method (:issue:`55255`) - :meth:`Series.ffill`, :meth:`Series.bfill`, :meth:`DataFrame.ffill`, and :meth:`DataFrame.bfill` have gained the argument ``limit_area``; 3rd party :class:`.ExtensionArray` authors need to add this argument to the method ``_pad_or_backfill`` (:issue:`56492`) - Allow passing ``read_only``, ``data_only`` and ``keep_links`` arguments to openpyxl using ``engine_kwargs`` of :func:`read_excel` (:issue:`55027`) +- Implement :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` for :class:`ArrowDtype` and masked dtypes (:issue:`56267`) - Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`) - Implemented :meth:`Series.dt` methods and attributes for :class:`ArrowDtype` with ``pyarrow.duration`` type (:issue:`52284`) - Implemented :meth:`Series.str.extract` for :class:`ArrowDtype` (:issue:`56268`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 3858ce4cf0ea1..a5ce46ed612f3 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -182,6 +182,7 @@ def floordiv_compat( AxisInt, Dtype, FillnaOptions, + InterpolateOptions, Iterator, NpDtype, NumpySorter, @@ -2048,6 +2049,45 @@ def _maybe_convert_setitem_value(self, value): raise TypeError(msg) from err return value + def interpolate( + self, + *, + method: InterpolateOptions, + axis: int, + index, + limit, + limit_direction, + limit_area, + copy: bool, + **kwargs, + ) -> Self: + """ + See NDFrame.interpolate.__doc__. + """ + # NB: we return type(self) even if copy=False + mask = self.isna() + if self.dtype.kind == "f": + data = self._pa_array.to_numpy() + elif self.dtype.kind in "iu": + data = self.to_numpy(dtype="f8", na_value=0.0) + else: + raise NotImplementedError( + f"interpolate is not implemented for dtype={self.dtype}" + ) + + missing.interpolate_2d_inplace( + data, + method=method, + axis=0, + index=index, + limit=limit, + limit_direction=limit_direction, + limit_area=limit_area, + mask=mask, + **kwargs, + ) + return type(self)(self._box_pa_array(pa.array(data, mask=mask))) + @classmethod def _if_else( cls, diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index 545d45e450f3f..234d96e53a67c 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -22,6 +22,7 @@ AxisInt, DtypeObj, FillnaOptions, + InterpolateOptions, NpDtype, PositionalIndexer, Scalar, @@ -98,6 +99,7 @@ NumpySorter, NumpyValueArrayLike, ) + from pandas.core.arrays import FloatingArray from pandas.compat.numpy import function as nv @@ -1491,6 +1493,58 @@ def all(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs): else: return self.dtype.na_value + def interpolate( + self, + *, + method: InterpolateOptions, + axis: int, + index, + limit, + limit_direction, + limit_area, + copy: bool, + **kwargs, + ) -> FloatingArray: + """ + See NDFrame.interpolate.__doc__. + """ + # NB: we return type(self) even if copy=False + if self.dtype.kind == "f": + if copy: + data = self._data.copy() + mask = self._mask.copy() + else: + data = self._data + mask = self._mask + elif self.dtype.kind in "iu": + copy = True + data = self._data.astype("f8") + mask = self._mask.copy() + else: + raise NotImplementedError( + f"interpolate is not implemented for dtype={self.dtype}" + ) + + missing.interpolate_2d_inplace( + data, + method=method, + axis=0, + index=index, + limit=limit, + limit_direction=limit_direction, + limit_area=limit_area, + mask=mask, + **kwargs, + ) + if not copy: + return self # type: ignore[return-value] + if self.dtype.kind == "f": + return type(self)._simple_new(data, mask) # type: ignore[return-value] + else: + from pandas.core.arrays import FloatingArray + + return FloatingArray._simple_new(data, mask) + def _accumulate( self, name: str, *, skipna: bool = True, **kwargs ) -> BaseMaskedArray: diff --git a/pandas/core/missing.py b/pandas/core/missing.py index ff45662d0bdc8..c016aab8ad074 100644 --- a/pandas/core/missing.py +++ b/pandas/core/missing.py @@ -349,6 +349,7 @@ def interpolate_2d_inplace( limit_direction: str = "forward", limit_area: str | None = None, fill_value: Any | None = None, + mask=None, **kwargs, ) -> None: """ @@ -396,6 +397,7 @@ def func(yvalues: np.ndarray) -> None: limit_area=limit_area_validated, fill_value=fill_value, bounds_error=False, + mask=mask, **kwargs, ) @@ -440,6 +442,7 @@ def _interpolate_1d( fill_value: Any | None = None, bounds_error: bool = False, order: int | None = None, + mask=None, **kwargs, ) -> None: """ @@ -453,8 +456,10 @@ def _interpolate_1d( ----- Fills 'yvalues' in-place. """ - - invalid = isna(yvalues) + if mask is not None: + invalid = mask + else: + invalid = isna(yvalues) valid = ~invalid if not valid.any(): @@ -531,7 +536,10 @@ def _interpolate_1d( **kwargs, ) - if is_datetimelike: + if mask is not None: + mask[:] = False + mask[preserve_nans] = True + elif is_datetimelike: yvalues[preserve_nans] = NaT.value else: yvalues[preserve_nans] = np.nan diff --git a/pandas/tests/frame/methods/test_interpolate.py b/pandas/tests/frame/methods/test_interpolate.py index e0641fcb65bd3..252b950004bea 100644 --- a/pandas/tests/frame/methods/test_interpolate.py +++ b/pandas/tests/frame/methods/test_interpolate.py @@ -508,8 +508,41 @@ def test_interpolate_empty_df(self): assert result is None tm.assert_frame_equal(df, expected) - def test_interpolate_ea_raise(self): + def test_interpolate_ea(self, any_int_ea_dtype): # GH#55347 - df = DataFrame({"a": [1, None, 2]}, dtype="Int64") - with pytest.raises(NotImplementedError, match="does not implement"): - df.interpolate() + df = DataFrame({"a": [1, None, None, None, 3]}, dtype=any_int_ea_dtype) + orig = df.copy() + result = df.interpolate(limit=2) + expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype="Float64") + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(df, orig) + + @pytest.mark.parametrize( + "dtype", + [ + "Float64", + "Float32", + pytest.param("float32[pyarrow]", marks=td.skip_if_no("pyarrow")), + pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow")), + ], + ) + def test_interpolate_ea_float(self, dtype): + # GH#55347 + df = DataFrame({"a": [1, None, None, None, 3]}, dtype=dtype) + orig = df.copy() + result = df.interpolate(limit=2) + expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype=dtype) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(df, orig) + + @pytest.mark.parametrize( + "dtype", + ["int64", "uint64", "int32", "int16", "int8", "uint32", "uint16", "uint8"], + ) + def test_interpolate_arrow(self, dtype): + # GH#55347 + pytest.importorskip("pyarrow") + df = DataFrame({"a": [1, None, None, None, 3]}, dtype=dtype + "[pyarrow]") + result = df.interpolate(limit=2) + expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype="float64[pyarrow]") + tm.assert_frame_equal(result, expected)