Skip to content

Commit 1bc5769

Browse files
authored
Declare Array API 2023.12 support (#651)
* Declare Array API 2023.12 support * Add (unused) device parameter to astype * Fix unstack edge cases * Bumpy array API tests commit to test against * Change dtype=None behavior in sum/prod following data-apis/array-api#744 * Fix unstack edge cases * Update array api test skips file * Only support integral values for `repeats` in `repeat`
1 parent fc2201e commit 1bc5769

File tree

10 files changed

+34
-32
lines changed

10 files changed

+34
-32
lines changed

.github/workflows/array-api-tests.yml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
uses: actions/checkout@v3
3333
with:
3434
repository: data-apis/array-api-tests
35-
ref: 'db95e67b29235249e5776ca2b6bb4e77117e0690' # Latest commit as of 2024-08-08
35+
ref: '1572b129c6682211abfe139e112592226c361a6c' # Latest commit as of 2024-12-04
3636
path: array-api-tests
3737
submodules: "true"
3838
- name: Set up Python ${{ matrix.python-version }}
@@ -90,8 +90,7 @@ jobs:
9090
array_api_tests/test_has_names.py
9191
9292
# signatures of items not implemented
93-
array_api_tests/test_signatures.py::test_func_signature[std]
94-
array_api_tests/test_signatures.py::test_func_signature[var]
93+
array_api_tests/test_signatures.py::test_func_signature[cumulative_sum]
9594
array_api_tests/test_signatures.py::test_func_signature[unique_all]
9695
array_api_tests/test_signatures.py::test_func_signature[unique_counts]
9796
array_api_tests/test_signatures.py::test_func_signature[unique_inverse]
@@ -110,13 +109,15 @@ jobs:
110109
array_api_tests/test_linalg.py::test_vecdot
111110
# (getitem with negative step size is not implemented)
112111
array_api_tests/test_array_object.py::test_getitem
112+
# test_searchsorted depends on sort which is not implemented
113+
array_api_tests/test_searching_functions.py::test_searchsorted
113114
114115
# not implemented
115116
array_api_tests/test_array_object.py::test_setitem
116117
array_api_tests/test_array_object.py::test_setitem_masking
118+
array_api_tests/test_manipulation_functions.py::test_repeat
117119
array_api_tests/test_sorting_functions.py
118-
array_api_tests/test_statistical_functions.py::test_std
119-
array_api_tests/test_statistical_functions.py::test_var
120+
array_api_tests/test_statistical_functions.py::test_cumulative_sum
120121
121122
# finfo(float32).eps returns float32 but should return float
122123
array_api_tests/test_data_type_functions.py::test_finfo[float32]
@@ -126,6 +127,9 @@ jobs:
126127
# https://github.com/numpy/numpy/issues/18881
127128
array_api_tests/test_creation_functions.py::test_linspace
128129
130+
# https://github.com/numpy/numpy/issues/20870
131+
#array_api_tests/test_data_type_functions.py::test_can_cast
132+
129133
EOF
130134
131135
pytest -v -rxXfEA --hypothesis-max-examples=2 --disable-data-dependent-shapes --disable-extension linalg --hypothesis-disable-deadline

api_status.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
## Array API Coverage Implementation Status
22

3-
Cubed supports version [2022.12](https://data-apis.org/array-api/2022.12/index.html) of the Python array API standard, with a few exceptions noted below. The [Fourier transform functions](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported.
4-
5-
Support for version [2023.12](https://data-apis.org/array-api/2023.12/index.html) is tracked in Cubed issue [#438](https://github.com/cubed-dev/cubed/issues/438).
3+
Cubed supports version [2023.12](https://data-apis.org/array-api/2023.12/index.html) of the Python array API standard, with a few exceptions noted below. The [Fourier transform functions](https://data-apis.org/array-api/2023.12/extensions/fourier_transform_functions.html) are *not* supported.
64

75
This table shows which parts of the the [Array API](https://data-apis.org/array-api/latest/API_specification/index.html) have been implemented in Cubed, and which ones are missing. The version column shows the version when the feature was added to the standard, for version 2022.12 or later.
86

@@ -61,7 +59,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
6159
| | `expand_dims` | :white_check_mark: | | |
6260
| | `flip` | :white_check_mark: | | |
6361
| | `permute_dims` | :white_check_mark: | | |
64-
| | `repeat` | :white_check_mark: | | |
62+
| | `repeat` | :white_check_mark: | 2023.12 | |
6563
| | `reshape` | :white_check_mark: | | Partial implementation |
6664
| | `roll` | :white_check_mark: | | |
6765
| | `squeeze` | :white_check_mark: | | |

cubed/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
# Array API
4848

49-
__array_api_version__ = "2022.12"
49+
__array_api_version__ = "2023.12"
5050

5151
from .array_api.inspection import __array_namespace_info__
5252

cubed/array_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__all__ = []
22

3-
__array_api_version__ = "2022.12"
3+
__array_api_version__ = "2023.12"
44

55
from .inspection import __array_namespace_info__
66

cubed/array_api/array_object.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,11 @@ def __abs__(self, /):
361361
return elemwise(nxp.abs, self, dtype=dtype)
362362

363363
def __array_namespace__(self, /, *, api_version=None):
364-
if api_version is not None and api_version not in ("2021.12", "2022.12"):
364+
if api_version is not None and api_version not in (
365+
"2021.12",
366+
"2022.12",
367+
"2023.12",
368+
):
365369
raise ValueError(f"Unrecognized array API version: {api_version!r}")
366370
import cubed.array_api as array_api
367371

cubed/array_api/data_type_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from cubed.core import CoreArray, map_blocks
33

44

5-
def astype(x, dtype, /, *, copy=True):
5+
def astype(x, dtype, /, *, copy=True, device=None):
66
if not copy and dtype == x.dtype:
77
return x
88
return map_blocks(_astype, x, dtype=dtype, astype_dtype=dtype)

cubed/array_api/manipulation_functions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,9 @@ def permute_dims(x, /, axes):
385385

386386

387387
def repeat(x, repeats, /, *, axis=0):
388+
if not isinstance(repeats, int):
389+
raise ValueError("repeat only supports integral values for `repeats`")
390+
388391
if axis is None:
389392
x = flatten(x)
390393
axis = 0
@@ -599,8 +602,10 @@ def unstack(x, /, *, axis=0):
599602

600603
n_arrays = x.shape[axis]
601604

602-
if n_arrays == 1:
603-
return (x,)
605+
if n_arrays == 0:
606+
return ()
607+
elif n_arrays == 1:
608+
return (squeeze(x, axis=axis),)
604609

605610
shape = x.shape[:axis] + x.shape[axis + 1 :]
606611
dtype = x.dtype

cubed/array_api/statistical_functions.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77
_real_numeric_dtypes,
88
_signed_integer_dtypes,
99
_unsigned_integer_dtypes,
10-
complex64,
11-
complex128,
12-
float32,
13-
float64,
1410
int64,
1511
uint64,
1612
)
@@ -128,10 +124,6 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
128124
dtype = int64
129125
elif x.dtype in _unsigned_integer_dtypes:
130126
dtype = uint64
131-
elif x.dtype == float32:
132-
dtype = float64
133-
elif x.dtype == complex64:
134-
dtype = complex128
135127
else:
136128
dtype = x.dtype
137129
extra_func_kwargs = dict(dtype=dtype)
@@ -169,10 +161,6 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
169161
dtype = int64
170162
elif x.dtype in _unsigned_integer_dtypes:
171163
dtype = uint64
172-
elif x.dtype == float32:
173-
dtype = float64
174-
elif x.dtype == complex64:
175-
dtype = complex128
176164
else:
177165
dtype = x.dtype
178166
extra_func_kwargs = dict(dtype=dtype)

cubed/tests/test_array_api.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,10 +722,15 @@ def test_unstack(spec, executor, chunks):
722722
assert_array_equal(cu, np.full((4, 6), 3))
723723

724724

725-
def test_unstack_noop(spec):
725+
def test_unstack_zero_arrays(spec):
726+
a = xp.full((0, 4, 6), 1, chunks=(1, 2, 3), spec=spec)
727+
assert xp.unstack(a) == ()
728+
729+
730+
def test_unstack_single_array(spec):
726731
a = xp.full((1, 4, 6), 1, chunks=(1, 2, 3), spec=spec)
727732
(b,) = xp.unstack(a)
728-
assert a is b
733+
assert_array_equal(b.compute(), np.full((4, 6), 1))
729734

730735

731736
# Searching functions

docs/array-api.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# Python Array API
22

3-
Cubed implements version 2022.12 of the [Python Array API standard](https://data-apis.org/array-api/2022.12/index.html) in `cubed.array_api`, with a few exceptions listed on the [coverage status](https://github.com/cubed-dev/cubed/blob/main/api_status.md) page. The [Fourier transform functions](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported.
4-
5-
Support for version [2023.12](https://data-apis.org/array-api/2023.12/index.html) is tracked in Cubed issue [#438](https://github.com/cubed-dev/cubed/issues/438).
3+
Cubed implements version 2023.12 of the [Python Array API standard](https://data-apis.org/array-api/2023.12/index.html) in `cubed.array_api`, with a few exceptions listed on the [coverage status](https://github.com/cubed-dev/cubed/blob/main/api_status.md) page. The [Fourier transform functions](https://data-apis.org/array-api/2023.12/extensions/fourier_transform_functions.html) are *not* supported.
64

75
## Differences between Cubed and the standard
86

0 commit comments

Comments
 (0)