Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import numpy as np

from pandas._libs.missing import NAType # noqa: F401

# To prevent import cycles place any internal imports in the branch below
# and use a string literal forward reference to it in subsequent types
# https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
Expand Down
17 changes: 15 additions & 2 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Optional, Sequence, Type, TypeVar, Union
from typing import Any, Optional, Sequence, Type, TypeVar, Union, overload

import numpy as np

Expand All @@ -25,6 +25,7 @@
NDArrayBackedExtensionArrayT = TypeVar(
"NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
)
EAScalarOrMissing = object # both scalar value and na_value can be any type


class NDArrayBackedExtensionArray(ExtensionArray):
Expand Down Expand Up @@ -214,9 +215,21 @@ def __setitem__(self, key, value):
def _validate_setitem_value(self, value):
return value

@overload
# error: Overloaded function signatures 1 and 2 overlap with incompatible
# return types [misc]
def __getitem__(self, key: int) -> EAScalarOrMissing: # type: ignore[misc]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can also return NDArrayBackedExtensionArrayT

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate? Is this for 2D EAs?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NDArrayBackedExtensionArray supports 2D, exactly

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so it also supports a tuple indexer? and I guess from that NDArrayBackedExtensionArray can't support nested data as it uses is_scalar checks.

So we would also need to change `EAScalarOrMissing = object  # both scalar value and na_value can be any type` to `ScalarOrScalarMissing = Scalar  # both values and na_value must be scalars`?

Copy link
Member

@jbrockmendel jbrockmendel Nov 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so it also supports a tuple indexer?

Yes. For that matter, I think even 1D will support 1-tuples.

and I guess from that NDArrayBackedExtensionArray can't support nested data as it uses is_scalar checks.

PandasArray can have object dtype, and Categorical can hold tuples. We never see 2D versions of those in practice though (yet).

...

@overload
def __getitem__(
self: NDArrayBackedExtensionArrayT, key: Union[slice, np.ndarray]
) -> NDArrayBackedExtensionArrayT:
...

def __getitem__(
self: NDArrayBackedExtensionArrayT, key: Union[int, slice, np.ndarray]
) -> Union[NDArrayBackedExtensionArrayT, Any]:
) -> Union[NDArrayBackedExtensionArrayT, EAScalarOrMissing]:
if lib.is_integer(key):
# fast-path
result = self._ndarray[key]
Expand Down
14 changes: 13 additions & 1 deletion pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TypeVar,
Union,
cast,
overload,
)

import numpy as np
Expand Down Expand Up @@ -51,6 +52,7 @@
_extension_array_shared_docs: Dict[str, str] = dict()

ExtensionArrayT = TypeVar("ExtensionArrayT", bound="ExtensionArray")
EAScalarOrMissing = object # both scalar value and na_value can be any type


class ExtensionArray:
Expand Down Expand Up @@ -256,9 +258,19 @@ def _from_factorized(cls, values, original):
# Must be a Sequence
# ------------------------------------------------------------------------

@overload
# error: Overloaded function signatures 1 and 2 overlap with incompatible
# return types [misc]
def __getitem__(self, item: int) -> EAScalarOrMissing: # type: ignore[misc]
...

@overload
def __getitem__(self, item: Union[slice, np.ndarray]) -> ExtensionArray:
...

def __getitem__(
self, item: Union[int, slice, np.ndarray]
) -> Union[ExtensionArray, Any]:
) -> Union[ExtensionArray, EAScalarOrMissing]:
"""
Select a subset of self.

Expand Down
15 changes: 13 additions & 2 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TypeVar,
Union,
cast,
overload,
)
import warnings

Expand Down Expand Up @@ -266,9 +267,19 @@ def __array__(self, dtype=None) -> np.ndarray:
return np.array(list(self), dtype=object)
return self._ndarray

@overload
def __getitem__(self, key: int) -> DTScalarOrNaT:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: this can still return DatetimeLikeArrayT.

...

@overload
def __getitem__(
self: DatetimeLikeArrayT, key: Union[slice, np.ndarray]
) -> DatetimeLikeArrayT:
...

def __getitem__(
self, key: Union[int, slice, np.ndarray]
) -> Union[DatetimeLikeArrayMixin, DTScalarOrNaT]:
self: DatetimeLikeArrayT, key: Union[int, slice, np.ndarray]
) -> Union[DatetimeLikeArrayT, DTScalarOrNaT]:
"""
This getitem defers to the underlying array, which by-definition can
only handle list-likes, slices, and integer scalars
Expand Down
8 changes: 3 additions & 5 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, time, timedelta, tzinfo
from typing import Optional, Union, cast
from typing import Optional, Union
import warnings

import numpy as np
Expand Down Expand Up @@ -444,11 +444,9 @@ def _generate_range(
)

if not left_closed and len(index) and index[0] == start:
# TODO: overload DatetimeLikeArrayMixin.__getitem__
index = cast(DatetimeArray, index[1:])
index = index[1:]
if not right_closed and len(index) and index[-1] == end:
# TODO: overload DatetimeLikeArrayMixin.__getitem__
index = cast(DatetimeArray, index[:-1])
index = index[:-1]

dtype = tz_to_dtype(tz)
return cls._simple_new(index.asi8, freq=freq, dtype=dtype)
Expand Down
27 changes: 24 additions & 3 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)

import numpy as np

from pandas._libs import lib, missing as libmissing
from pandas._typing import Scalar
from pandas._typing import NAType, Scalar
from pandas.errors import AbstractMethodError
from pandas.util._decorators import cache_readonly, doc

Expand All @@ -30,6 +39,8 @@


BaseMaskedArrayT = TypeVar("BaseMaskedArrayT", bound="BaseMaskedArray")
# scalar value is a Python scalar, missing value is pd.NA
ScalarOrNAType = Union[Scalar, NAType]


class BaseMaskedDtype(ExtensionDtype):
Expand Down Expand Up @@ -102,9 +113,19 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
def dtype(self) -> BaseMaskedDtype:
raise AbstractMethodError(self)

@overload
# error: Overloaded function signatures 1 and 2 overlap with incompatible return
# types [misc]
def __getitem__(self, item: int) -> ScalarOrNAType: # type: ignore[misc]
...

@overload
def __getitem__(self, item: Union[slice, np.ndarray]) -> BaseMaskedArray:
...

def __getitem__(
self, item: Union[int, slice, np.ndarray]
) -> Union[BaseMaskedArray, Any]:
) -> Union[BaseMaskedArray, ScalarOrNAType]:
if is_integer(item):
if self._mask[item]:
return self.dtype.na_value
Expand Down
18 changes: 5 additions & 13 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3222,14 +3222,8 @@ def _get_nearest_indexer(self, target: "Index", limit, tolerance) -> np.ndarray:
right_indexer = self.get_indexer(target, "backfill", limit=limit)

target_values = target._values
# error: Unsupported left operand type for - ("ExtensionArray")
left_distances = np.abs(
self._values[left_indexer] - target_values # type: ignore[operator]
)
# error: Unsupported left operand type for - ("ExtensionArray")
right_distances = np.abs(
self._values[right_indexer] - target_values # type: ignore[operator]
)
left_distances = np.abs(self._values[left_indexer] - target_values)
right_distances = np.abs(self._values[right_indexer] - target_values)

op = operator.lt if self.is_monotonic_increasing else operator.le
indexer = np.where(
Expand All @@ -3248,8 +3242,7 @@ def _filter_indexer_tolerance(
indexer: np.ndarray,
tolerance,
) -> np.ndarray:
# error: Unsupported left operand type for - ("ExtensionArray")
distance = abs(self._values[indexer] - target) # type: ignore[operator]
distance = abs(self._values[indexer] - target)
indexer = np.where(distance <= tolerance, indexer, -1)
return indexer

Expand Down Expand Up @@ -4546,9 +4539,8 @@ def asof_locs(self, where: "Index", mask) -> np.ndarray:

result = np.arange(len(self))[mask].take(locs)

# TODO: overload return type of ExtensionArray.__getitem__
first_value = cast(Any, self._values[mask.argmax()])
result[(locs == 0) & (where._values < first_value)] = -1
first = mask.argmax()
result[(locs == 0) & (where._values < self._values[first])] = -1
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverting change from previous PR


return result

Expand Down
7 changes: 2 additions & 5 deletions pandas/core/indexes/period.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta
from typing import Any, cast
from typing import Any

import numpy as np

Expand Down Expand Up @@ -673,10 +673,7 @@ def difference(self, other, sort=None):

if self.equals(other):
# pass an empty PeriodArray with the appropriate dtype

# TODO: overload DatetimeLikeArrayMixin.__getitem__
values = cast(PeriodArray, self._data[:0])
return type(self)._simple_new(values, name=self.name)
return type(self)._simple_new(self._data[:0], name=self.name)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverting change from previous PR


if is_object_dtype(other):
return self.astype(object).difference(other).astype(self.dtype)
Expand Down
1 change: 1 addition & 0 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ def split_and_operate(
-------
list of blocks
"""
assert isinstance(self.values, np.ndarray)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to ensure split_and_operate is not called from EA block (or base method not overridden)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move the assert after the ndim check.

if mask is None:
mask = np.broadcast_to(True, shape=self.shape)

Expand Down