Array API default_device() and devices() are incorrect #27606

@crusaderky

Description

jax.numpy.__array_namespace_info__().default_device() returns None.
It should return a Device object (see data-apis/array-api#835 for the ongoing discussion of which device should be returned).

jax.numpy.__array_namespace_info__().devices() is an alias for jax.devices().
While the Array API does not spell this out explicitly (https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.devices.html#devices), I believe the alias is misleading: called without its optional backend parameter, jax.devices() returns only the devices of the default backend, and the Array API version exposes no such parameter. I think it should return a list of all devices available at the moment of calling it; e.g. when CUDA is available, it should return [CpuDevice(id=0), CudaDevice(id=0)].
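To make the difference concrete, here is a sketch contrasting jax.devices() (default backend only) with the behavior proposed above. The all_devices() helper and the list of platform names it probes are hypothetical, not part of JAX. It relies on jax.devices(backend) raising RuntimeError for a backend that is unknown or failed to initialize.

```python
import jax

# Assumed platform names to probe; "gpu" resolves to CUDA or ROCm when present.
_CANDIDATE_PLATFORMS = ("cpu", "gpu", "tpu")


def all_devices() -> list:
    """Hypothetical helper: devices from every backend that initializes."""
    found = []
    for platform in _CANDIDATE_PLATFORMS:
        try:
            devs = jax.devices(platform)
        except RuntimeError:
            # Backend not installed or failed to initialize on this machine.
            continue
        for d in devs:
            if d not in found:  # avoid duplicates if two names map to one backend
                found.append(d)
    return found


print("jax.devices():", jax.devices())  # default backend only
print("all_devices():", all_devices())  # CPU devices plus any accelerator devices
```

On a CUDA machine, the first line would list only the CUDA devices, while the second would also include the CPU device, matching the list proposed above.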

System info (python version, jaxlib version, accelerator, etc.)

jax-0.5.2
jaxlib-0.5.2-cuda126

Metadata

Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions