Skip to content

Commit 433dcd5

Browse files
authored
ENH/TST: Add BaseUnaryOpsTests tests for ArrowExtensionArray (#47711)
1 parent 96b036c commit 433dcd5

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,20 @@ def __arrow_array__(self, type=None):
307307
"""Convert myself to a pyarrow ChunkedArray."""
308308
return self._data
309309

310+
def __invert__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
311+
if pa_version_under2p0:
312+
raise NotImplementedError("__invert__ not implement for pyarrow < 2.0")
313+
return type(self)(pc.invert(self._data))
314+
315+
def __neg__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
316+
return type(self)(pc.negate_checked(self._data))
317+
318+
def __pos__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
319+
return type(self)(self._data)
320+
321+
def __abs__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
322+
return type(self)(pc.abs_checked(self._data))
323+
310324
def _cmp_method(self, other, op):
311325
from pandas.arrays import BooleanArray
312326

pandas/tests/extension/test_arrow.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,24 @@ def test_EA_types(self, engine, data, request):
12211221
super().test_EA_types(engine, data)
12221222

12231223

1224+
class TestBaseUnaryOps(base.BaseUnaryOpsTests):
1225+
@pytest.mark.xfail(
1226+
pa_version_under2p0,
1227+
raises=NotImplementedError,
1228+
reason="pyarrow.compute.invert not supported in pyarrow<2.0",
1229+
)
1230+
def test_invert(self, data, request):
1231+
pa_dtype = data.dtype.pyarrow_dtype
1232+
if not pa.types.is_boolean(pa_dtype):
1233+
request.node.add_marker(
1234+
pytest.mark.xfail(
1235+
raises=pa.ArrowNotImplementedError,
1236+
reason=f"pyarrow.compute.invert does support {pa_dtype}",
1237+
)
1238+
)
1239+
super().test_invert(data)
1240+
1241+
12241242
class TestBaseMethods(base.BaseMethodsTests):
12251243
@pytest.mark.parametrize("periods", [1, -2])
12261244
def test_diff(self, data, periods, request):

0 commit comments

Comments
 (0)