Skip to content

Commit 9f8d47c

Browse files
tomwhiteIllviljandcherianTomNicholas
authored
Support NumPy array API (experimental) (#6804)
* Support NumPy array API (experimental) * Address feedback * Update xarray/core/indexing.py Co-authored-by: Illviljan <[email protected]> * Update xarray/core/indexing.py Co-authored-by: Illviljan <[email protected]> * Update xarray/tests/test_array_api.py Co-authored-by: Illviljan <[email protected]> * Update xarray/tests/test_array_api.py Co-authored-by: Illviljan <[email protected]> * Update xarray/tests/test_array_api.py Co-authored-by: Illviljan <[email protected]> * Update xarray/tests/test_array_api.py Co-authored-by: Illviljan <[email protected]> * Update xarray/tests/test_array_api.py Co-authored-by: Illviljan <[email protected]> * Fix import order * Fix import order * update whatsnew Co-authored-by: Illviljan <[email protected]> Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: Thomas Nicholas <[email protected]>
1 parent dabd977 commit 9f8d47c

File tree

6 files changed

+112
-4
lines changed

6 files changed

+112
-4
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ New Features
3030
:py:meth:`coarsen`, :py:meth:`weighted`, :py:meth:`resample`,
3131
(:pull:`6702`)
3232
By `Michael Niklas <https://github.com/headtr1ck>`_.
33+
- Experimental support for wrapping any array type that conforms to the python array api standard.
34+
(:pull:`6804`)
35+
By `Tom White <https://github.com/tomwhite>`_.
3336

3437
Deprecations
3538
~~~~~~~~~~~~

xarray/core/duck_array_ops.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,11 @@ def f(values, axis=None, skipna=None, **kwargs):
329329
if name in ["sum", "prod"]:
330330
kwargs.pop("min_count", None)
331331

332-
func = getattr(np, name)
332+
if hasattr(values, "__array_namespace__"):
333+
xp = values.__array_namespace__()
334+
func = getattr(xp, name)
335+
else:
336+
func = getattr(np, name)
333337

334338
try:
335339
with warnings.catch_warnings():

xarray/core/indexing.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,8 @@ def as_indexable(array):
679679
return DaskIndexingAdapter(array)
680680
if hasattr(array, "__array_function__"):
681681
return NdArrayLikeIndexingAdapter(array)
682+
if hasattr(array, "__array_namespace__"):
683+
return ArrayApiIndexingAdapter(array)
682684

683685
raise TypeError(f"Invalid array type: {type(array)}")
684686

@@ -1288,6 +1290,49 @@ def __init__(self, array):
12881290
self.array = array
12891291

12901292

1293+
class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
1294+
"""Wrap an array API array to use explicit indexing."""
1295+
1296+
__slots__ = ("array",)
1297+
1298+
def __init__(self, array):
1299+
if not hasattr(array, "__array_namespace__"):
1300+
raise TypeError(
1301+
"ArrayApiIndexingAdapter must wrap an object that "
1302+
"implements the __array_namespace__ protocol"
1303+
)
1304+
self.array = array
1305+
1306+
def __getitem__(self, key):
1307+
if isinstance(key, BasicIndexer):
1308+
return self.array[key.tuple]
1309+
elif isinstance(key, OuterIndexer):
1310+
# manual orthogonal indexing (implemented like DaskIndexingAdapter)
1311+
key = key.tuple
1312+
value = self.array
1313+
for axis, subkey in reversed(list(enumerate(key))):
1314+
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
1315+
return value
1316+
else:
1317+
if isinstance(key, VectorizedIndexer):
1318+
raise TypeError("Vectorized indexing is not supported")
1319+
else:
1320+
raise TypeError(f"Unrecognized indexer: {key}")
1321+
1322+
def __setitem__(self, key, value):
1323+
if isinstance(key, (BasicIndexer, OuterIndexer)):
1324+
self.array[key.tuple] = value
1325+
else:
1326+
if isinstance(key, VectorizedIndexer):
1327+
raise TypeError("Vectorized indexing is not supported")
1328+
else:
1329+
raise TypeError(f"Unrecognized indexer: {key}")
1330+
1331+
def transpose(self, order):
1332+
xp = self.array.__array_namespace__()
1333+
return xp.permute_dims(self.array, order)
1334+
1335+
12911336
class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
12921337
"""Wrap a dask array to support explicit indexing."""
12931338

xarray/core/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,10 @@ def is_duck_array(value: Any) -> bool:
263263
hasattr(value, "ndim")
264264
and hasattr(value, "shape")
265265
and hasattr(value, "dtype")
266-
and hasattr(value, "__array_function__")
267-
and hasattr(value, "__array_ufunc__")
266+
and (
267+
(hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
268+
or hasattr(value, "__array_namespace__")
269+
)
268270
)
269271

270272

@@ -298,6 +300,7 @@ def _is_scalar(value, include_0d):
298300
or not (
299301
isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
300302
or hasattr(value, "__array_function__")
303+
or hasattr(value, "__array_namespace__")
301304
)
302305
)
303306

xarray/core/variable.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,9 @@ def as_compatible_data(data, fastpath=False):
237237
else:
238238
data = np.asarray(data)
239239

240-
if not isinstance(data, np.ndarray) and hasattr(data, "__array_function__"):
240+
if not isinstance(data, np.ndarray) and (
241+
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
242+
):
241243
return data
242244

243245
# validate whether the data is valid data types.

xarray/tests/test_array_api.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from typing import Tuple
2+
3+
import pytest
4+
5+
import xarray as xr
6+
from xarray.testing import assert_equal
7+
8+
np = pytest.importorskip("numpy", minversion="1.22")
9+
10+
import numpy.array_api as xp # isort:skip
11+
from numpy.array_api._array_object import Array # isort:skip
12+
13+
14+
@pytest.fixture
15+
def arrays() -> Tuple[xr.DataArray, xr.DataArray]:
16+
np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
17+
xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
18+
assert isinstance(xp_arr.data, Array)
19+
return np_arr, xp_arr
20+
21+
22+
def test_arithmetic(arrays) -> None:
23+
np_arr, xp_arr = arrays
24+
expected = np_arr + 7
25+
actual = xp_arr + 7
26+
assert isinstance(actual.data, Array)
27+
assert_equal(actual, expected)
28+
29+
30+
def test_aggregation(arrays) -> None:
31+
np_arr, xp_arr = arrays
32+
expected = np_arr.sum(skipna=False)
33+
actual = xp_arr.sum(skipna=False)
34+
assert isinstance(actual.data, Array)
35+
assert_equal(actual, expected)
36+
37+
38+
def test_indexing(arrays) -> None:
39+
np_arr, xp_arr = arrays
40+
expected = np_arr[:, 0]
41+
actual = xp_arr[:, 0]
42+
assert isinstance(actual.data, Array)
43+
assert_equal(actual, expected)
44+
45+
46+
def test_reorganizing_operation(arrays) -> None:
47+
np_arr, xp_arr = arrays
48+
expected = np_arr.transpose()
49+
actual = xp_arr.transpose()
50+
assert isinstance(actual.data, Array)
51+
assert_equal(actual, expected)

0 commit comments

Comments
 (0)