Closed
Labels: bug
Description
`jax.numpy.__array_namespace_info__().default_device()` returns `None`. It should return a `Device` object (see data-apis/array-api#835 for the debate over which device should be returned).
`jax.numpy.__array_namespace_info__().devices()` is an alias for `jax.devices()`. While this is not explicitly spelled out by the Array API (https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.devices.html#devices), I believe it is misleading: it lacks the optional `backend` parameter that `jax.devices` offers, so it only reports the devices of the default backend. I think it should return a list of all devices available at the time of the call; e.g. when CUDA is available, it should return `[CpuDevice(id=0), CudaDevice(id=0)]`.
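A rough sketch of the proposed behaviour, built only on the public `jax.devices(backend)` call. The helper name `all_available_devices` and the fixed list of backend names are hypothetical; the sketch relies on `jax.devices` raising `RuntimeError` for a backend that is not present, and it is not how JAX implements `devices()`.

```python
import jax


def all_available_devices():
    """Hypothetical helper: collect devices from every backend that initializes."""
    found = []
    # CPU first, matching the ordering suggested in the issue.
    for backend in ("cpu", "gpu", "tpu"):
        try:
            found.extend(jax.devices(backend))
        except RuntimeError:
            # Backend not installed or failed to initialize: skip it.
            continue
    return found


print(all_available_devices())
```

On a machine with one CUDA GPU this would be expected to list the CPU device followed by the CUDA device, whereas plain `jax.devices()` lists only the GPU.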
System info (python version, jaxlib version, accelerator, etc.)
jax-0.5.2
jaxlib-0.5.2-cuda126