Status: Closed
Labels: enhancement (New feature or request)

Description
Goal: make JAX support the Python Array API standard (https://data-apis.org/array-api/latest/).
Related to #19246
TODO

Initial Implementation
- Add initial implementation in `jax.experimental.array_api`: Initial implementation of the Python Array API standard #16099
- Add CI test based on https://github.com/data-apis/array-api-tests: Initial implementation of the Python Array API standard #16099
- Add smoketest for normal CI runs: [array-api] add simple smoketest target for standard CI testing #18685
- Enable `fft_tests` (requires waiting on upstream test fixes)

JAX API fixes
- Add JAX support for scalar boolean indexing: jax.numpy: implement scalar boolean indexing #19722; Support scalar boolean indices in arr.at[idx].set(vals) #21305
- Fix NaN identity issue within `unique`: jnp.unique: add support for the equal_nan keyword #19090
- Add `descending` argument to `sort` and `argsort`: [array api] add stable & descending params to jnp.sort & jnp.argsort #19201
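A minimal sketch of the fixed behaviors listed above, assuming a recent JAX release that already includes #19722, #19090 and #19201 (the exact release numbers are not checked here); the input arrays are arbitrary examples:

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 2])

# `descending` (and `stable`) keywords on sort/argsort.
desc = jnp.sort(x, descending=True)
order = jnp.argsort(x, stable=True, descending=True)

# `equal_nan` on unique: True (the default) collapses all NaNs into one
# entry; False keeps every NaN, because NaN != NaN.
f = jnp.array([jnp.nan, jnp.nan, 1.0])
merged = jnp.unique(f)
separate = jnp.unique(f, equal_nan=False)

# Scalar boolean index, following NumPy: True inserts a length-1 axis,
# False inserts a length-0 axis.
v = jnp.arange(3)
print(desc, order, merged.shape, separate.shape, v[True].shape, v[False].shape)
```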

Make `jax.Array` conform to the API spec
- Deprecate the `device()` method: Deprecate the device() method of JAX arrays #18730
- Add `device` property (after the `device()` method is removed; ~March 2024): [array API] add device property & to_device method #22597
- Add `to_device()` method: [array API] add device property & to_device method #22597
- Add `device` keyword to `zeros`, `ones`, `arange`, etc. (lax.full: add sharding argument #19445, lax.full_like: add sharding argument #19466, jax.numpy: support device argument for full, empty, zeros, ones #19470, jnp.full_like & co: support device parameter #19504)
- Add `__array_namespace__` property: [array API] move api metadata into jax.numpy namespace #22734

Add Array API functions to the standard `jax.numpy` namespace
- `jnp.bool`: [array API] add jnp.bool #19403
- `jnp.isdtype`: Add jnp.isdtype function, following np.isdtype in NumPy 2.0 #19400
- `jnp.astype`: Add jax.numpy.astype function #18757
- `unique_all`, `unique_counts`, `unique_inverse`, `unique_values`: array api: add unique_* interfaces #19088
- `concat`: [array api] add jax.numpy.concat #19323
- `permute_dims`: [array api] add jax.numpy.permute_dims function #19244
- `acos`, `acosh`, `asin`, `asinh`, `atan`, `atanh`, `atan2`: jax.numpy: add trig aliases acos(h), asin(h), atan(h), atan2 #19054
- `bitwise_left_shift`, `bitwise_right_shift`, `bitwise_invert`: array api: add jnp.bitwise_* aliases #19278
- `copy` keyword argument for `jnp.asarray`: [array API] support copy argument to jnp.asarray #19186
- `jnp.linalg`:
  - `diagonal`: [array API] add jnp.linalg.diagonal #19321
  - `cross`: array api: add jnp.linalg.cross & jnp.linalg.outer #18928
  - `matmul`: jnp.linalg: add matmul, tensordot, & svdvals #19042
  - `matrix_norm`: jnp.linalg: add matrix_norm, matrix_transpose, vector_norm, vector_transpose #19005
  - `matrix_transpose`: jnp.linalg: add matrix_norm, matrix_transpose, vector_norm, vector_transpose #19005
  - `outer`: array api: add jnp.linalg.cross & jnp.linalg.outer #18928
  - `svdvals`: jnp.linalg: add matmul, tensordot, & svdvals #19042
  - `tensordot`: jnp.linalg: add matmul, tensordot, & svdvals #19042
  - `vecdot`: jnp.linalg: add matrix_norm, matrix_transpose, vector_norm, vector_transpose #19005
  - `vector_norm`: jnp.linalg: add matrix_norm, matrix_transpose, vector_norm, vector_transpose #19005
  - `eigh` returns `NamedTuple`: [array api] return NamedTuple from np.linalg APIs #19347
  - `qr` returns `NamedTuple`: [array api] return NamedTuple from np.linalg APIs #19347
  - `slogdet` returns `NamedTuple`: [array api] return NamedTuple from np.linalg APIs #19347
  - `svd` returns `NamedTuple`: [array api] return NamedTuple from np.linalg APIs #19347
  - `cholesky` `upper` argument: jnp.linalg.cholesky: add upper argument #19606
  - `solve` vectorization update: jnp.linalg.solve: deprecate batched 1D solves when b.ndim > 1 #19674

Update to v2023.12 APIs and behavior (see changelog)

Consider removing `jax.experimental.array_api` and making `jax.numpy` itself fully compliant with the array API.