Skip to content

Commit a04d857

Browse files
authored
Adds copy parameter to __array__ for numpy 2.0 (#9393)
1 parent 93e410b commit a04d857

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ Bug fixes
5252
- Fix issue with passing parameters to ZarrStore.open_store when opening
5353
datatree in zarr format (:issue:`9376`, :pull:`9377`).
5454
By `Alfonso Ladino <https://github.com/aladinor>`_
55+
- Fix deprecation warning that was raised when calling ``np.array`` on an ``xr.DataArray``
56+
in NumPy 2.0 (:issue:`9312`, :pull:`9393`)
57+
By `Andrew Scherer <https://github.com/andrew-s28>`_.
5558

5659
Documentation
5760
~~~~~~~~~~~~~

xarray/core/common.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,22 @@ def __int__(self: Any) -> int:
162162
def __complex__(self: Any) -> complex:
163163
return complex(self.values)
164164

165-
def __array__(self: Any, dtype: DTypeLike | None = None) -> np.ndarray:
166-
return np.asarray(self.values, dtype=dtype)
165+
def __array__(
166+
self: Any, dtype: DTypeLike | None = None, copy: bool | None = None
167+
) -> np.ndarray:
168+
if not copy:
169+
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
170+
copy = None
171+
elif np.lib.NumpyVersion(np.__version__) <= "1.28.0":
172+
copy = False
173+
else:
174+
# 2.0.0 dev versions, handle cases where copy may or may not exist
175+
try:
176+
np.array([1]).__array__(copy=None)
177+
copy = None
178+
except TypeError:
179+
copy = False
180+
return np.array(self.values, dtype=dtype, copy=copy)
167181

168182
def __repr__(self) -> str:
169183
return formatting.array_repr(self)

xarray/tests/test_dataarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7065,6 +7065,14 @@ def test_from_numpy(self) -> None:
70657065
np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3]))
70667066
np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6]))
70677067

7068+
def test_to_numpy(self) -> None:
7069+
arr = np.array([1, 2, 3])
7070+
da = xr.DataArray(arr, dims="x", coords={"lat": ("x", [4, 5, 6])})
7071+
7072+
with assert_no_warnings():
7073+
np.testing.assert_equal(np.asarray(da), arr)
7074+
np.testing.assert_equal(np.array(da), arr)
7075+
70687076
@requires_dask
70697077
def test_from_dask(self) -> None:
70707078
da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])})

0 commit comments

Comments
 (0)