Skip to content
6 changes: 5 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,11 @@ def f(values, axis=None, skipna=None, **kwargs):
if name in ["sum", "prod"]:
kwargs.pop("min_count", None)

func = getattr(np, name)
if hasattr(values, "__array_namespace__"):
xp = values.__array_namespace__()
func = getattr(xp, name)
else:
func = getattr(np, name)

try:
with warnings.catch_warnings():
Expand Down
43 changes: 43 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,8 @@ def as_indexable(array):
return DaskIndexingAdapter(array)
if hasattr(array, "__array_function__"):
return NdArrayLikeIndexingAdapter(array)
if hasattr(array, "__array_namespace__"):
return ArrayApiIndexingAdapter(array)

raise TypeError(f"Invalid array type: {type(array)}")

Expand Down Expand Up @@ -1288,6 +1290,47 @@ def __init__(self, array):
self.array = array


class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap an array API array to use explicit indexing."""

__slots__ = ("array",)

def __init__(self, array):
if not hasattr(array, "__array_namespace__"):
raise TypeError(
"ArrayApiIndexingAdapter must wrap an object that "
"implements the __array_namespace__ protocol"
)
self.array = array

def __getitem__(self, key):
if isinstance(key, BasicIndexer):
return self.array[key.tuple]
elif isinstance(key, OuterIndexer):
# manual orthogonal indexing (implemented like DaskIndexingAdapter)
key = key.tuple
value = self.array
for axis, subkey in reversed(list(enumerate(key))):
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
return value
else:
if not isinstance(key, VectorizedIndexer):
raise TypeError(f"Unrecognized indexer: {key}")
raise TypeError("Vectorized indexing is not supported")

def __setitem__(self, key, value):
if isinstance(key, (BasicIndexer, OuterIndexer)):
self.array[key.tuple] = value
else:
if not isinstance(key, VectorizedIndexer):
raise TypeError(f"Unrecognized indexer: {key}")
raise TypeError("Vectorized indexing is not supported")

def transpose(self, order):
xp = self.array.__array_namespace__()
return xp.permute_dims(self.array, order)


class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a dask array to support explicit indexing."""

Expand Down
7 changes: 5 additions & 2 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,10 @@ def is_duck_array(value: Any) -> bool:
hasattr(value, "ndim")
and hasattr(value, "shape")
and hasattr(value, "dtype")
and hasattr(value, "__array_function__")
and hasattr(value, "__array_ufunc__")
and (
(hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
or hasattr(value, "__array_namespace__")
)
)


Expand Down Expand Up @@ -298,6 +300,7 @@ def _is_scalar(value, include_0d):
or not (
isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
or hasattr(value, "__array_function__")
or hasattr(value, "__array_namespace__")
)
)

Expand Down
4 changes: 3 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,9 @@ def as_compatible_data(data, fastpath=False):
else:
data = np.asarray(data)

if not isinstance(data, np.ndarray) and hasattr(data, "__array_function__"):
if not isinstance(data, np.ndarray) and (
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
):
return data

# validate whether the data is valid data types.
Expand Down
48 changes: 48 additions & 0 deletions xarray/tests/test_array_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy.array_api as xp
import pytest
from numpy.array_api._array_object import Array

import xarray as xr
from xarray.testing import assert_equal

np = pytest.importorskip("numpy", minversion="1.22")


@pytest.fixture
def arrays():
np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
assert isinstance(xp_arr.data, Array)
return np_arr, xp_arr


def test_arithmetic(arrays):
np_arr, xp_arr = arrays
expected = np_arr + 7
actual = xp_arr + 7
assert isinstance(actual.data, Array)
assert_equal(actual, expected)


def test_aggregation(arrays):
np_arr, xp_arr = arrays
expected = np_arr.sum(skipna=False)
actual = xp_arr.sum(skipna=False)
assert isinstance(actual.data, Array)
assert_equal(actual, expected)


def test_indexing(arrays):
np_arr, xp_arr = arrays
expected = np_arr[:, 0]
actual = xp_arr[:, 0]
assert isinstance(actual.data, Array)
assert_equal(actual, expected)


def test_reorganizing_operation(arrays):
np_arr, xp_arr = arrays
expected = np_arr.transpose()
actual = xp_arr.transpose()
assert isinstance(actual.data, Array)
assert_equal(actual, expected)