Skip to content

Commit 80f496f

Browse files
committed
clean-up PandasIndexingAdapter dtype handling
Prevent numpy.dtype conversions or castings implemented in various places, gather the logic into one method.
1 parent 06a3b92 commit 80f496f

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

xarray/core/indexing.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,18 +1779,25 @@ def __init__(
17791779
def dtype(self) -> np.dtype | pd.api.extensions.ExtensionDtype: # type: ignore[override]
17801780
return self._dtype
17811781

1782+
def _get_numpy_dtype(self, dtype: np.typing.DTypeLike | None = None) -> np.dtype:
1783+
if dtype is None:
1784+
if is_valid_numpy_dtype(self.dtype):
1785+
return cast(np.dtype, self.dtype)
1786+
else:
1787+
return get_valid_numpy_dtype(self.array)
1788+
else:
1789+
return np.dtype(dtype)
1790+
17821791
def __array__(
17831792
self,
17841793
dtype: np.typing.DTypeLike | None = None,
17851794
/,
17861795
*,
17871796
copy: bool | None = None,
17881797
) -> np.ndarray:
1789-
if dtype is None and is_valid_numpy_dtype(self.dtype):
1790-
dtype = cast(np.dtype, self.dtype)
1791-
else:
1792-
dtype = get_valid_numpy_dtype(self.array)
1798+
dtype = self._get_numpy_dtype(dtype)
17931799
array = self.array
1800+
17941801
if isinstance(array, pd.PeriodIndex):
17951802
with suppress(AttributeError):
17961803
# this might not be public API
@@ -1830,10 +1837,8 @@ def _convert_scalar(self, item) -> np.ndarray:
18301837
# numpy fails to convert pd.Timestamp to np.datetime64[ns]
18311838
item = np.asarray(item.to_datetime64())
18321839
elif self.dtype != object:
1833-
dtype = self.dtype
1834-
if pd.api.types.is_extension_array_dtype(dtype):
1835-
dtype = get_valid_numpy_dtype(self.array)
1836-
item = np.asarray(item, dtype=cast(np.dtype, dtype))
1840+
dtype = self._get_numpy_dtype()
1841+
item = np.asarray(item, dtype=dtype)
18371842

18381843
# as for numpy.ndarray indexing, we always want the result to be
18391844
# a NumPy array.
@@ -1897,7 +1902,9 @@ def copy(self, deep: bool = True) -> Self:
18971902
def nbytes(self) -> int:
18981903
if pd.api.types.is_extension_array_dtype(self.dtype):
18991904
return self.array.nbytes
1900-
return cast(np.dtype, self.dtype).itemsize * len(self.array)
1905+
1906+
dtype = self._get_numpy_dtype()
1907+
return dtype.itemsize * len(self.array)
19011908

19021909

19031910
class PandasMultiIndexingAdapter(PandasIndexingAdapter):
@@ -2073,7 +2080,7 @@ def _index_get(
20732080
if isinstance(result, pd.IntervalIndex):
20742081
return type(self)(result, dtype=self.dtype)
20752082
elif isinstance(result, pd.Interval):
2076-
return np.array([result.left, result.right])
2083+
return np.array([result.left, result.right], dtype=self._get_numpy_dtype())
20772084
else:
20782085
return self._convert_scalar(result)
20792086

0 commit comments

Comments
 (0)