Skip to content

Commit 613a000

Browse files
committed
[array API] add device property & to_device method
1 parent 13e42ad commit 613a000

File tree

9 files changed

+69
-9
lines changed

9 files changed

+69
-9
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ Remember to align the itemized text with the first line of an item within a list
2525
will be removed in a future release.
2626
* Updated the repr of gpu devices to be more consistent
2727
with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
28+
* Added the `device` property and `to_device` method to {class}`jax.Array`, as
29+
part of JAX's [Array API](https://data-apis.org/array-api) support.
2830
* Deprecations
2931
* Removed a number of previously-deprecated internal APIs related to
3032
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,

jax/_src/array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,13 @@ def size(self):
254254
def sharding(self):
255255
return self._sharding
256256

257+
@property
258+
def device(self):
259+
self._check_if_deleted()
260+
if isinstance(self.sharding, SingleDeviceSharding):
261+
return list(self.sharding.device_set)[0]
262+
return self.sharding
263+
257264
@property
258265
def weak_type(self):
259266
return self.aval.weak_type

jax/_src/basearray.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from collections.abc import Sequence
2323

2424
# TODO(jakevdp): fix import cycles and define these.
25+
Device = Any
2526
Shard = Any
2627
Sharding = Any
2728

@@ -112,6 +113,15 @@ def is_fully_replicated(self) -> bool:
112113
def sharding(self) -> Sharding:
113114
"""The sharding for the array."""
114115

116+
@property
117+
@abc.abstractmethod
118+
def device(self) -> Device | Sharding:
119+
"""Array API-compatible device attribute.
120+
121+
For single-device arrays, this returns a Device. For sharded arrays, this
122+
returns a Sharding.
123+
"""
124+
115125

116126
Array.__module__ = "jax"
117127

jax/_src/basearray.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ class Array(abc.ABC):
204204
@property
205205
def sharding(self) -> Sharding: ...
206206
@property
207+
def device(self) -> Device | Sharding: ...
208+
@property
207209
def addressable_shards(self) -> Sequence[Shard]: ...
208210
@property
209211
def global_shards(self) -> Sequence[Shard]: ...
@@ -216,6 +218,7 @@ class Array(abc.ABC):
216218
@property
217219
def traceback(self) -> Traceback: ...
218220
def unsafe_buffer_pointer(self) -> int: ...
221+
def to_device(self, device: Device | Sharding, *, stream: int | Any | None) -> Array: ...
219222

220223

221224
StaticScalar = Union[

jax/_src/core.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,15 @@ def sharding(self):
738738
f"The 'sharding' attribute is not available on {self._error_repr()}."
739739
f"{self._origin_msg()}")
740740

741+
@property
742+
def device(self):
743+
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
744+
# Raising a ConcretizationTypeError would make sense, but for backward compatibility
745+
# we raise an AttributeError so that hasattr() and getattr() work as expected.
746+
raise AttributeError(self,
747+
f"The 'device' attribute is not available on {self._error_repr()}."
748+
f"{self._origin_msg()}")
749+
741750
@property
742751
def addressable_shards(self):
743752
raise ConcretizationTypeError(self,

jax/_src/earray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ def sharding(self):
8383
phys_sharding = self._data.sharding
8484
return sharding_impls.logical_sharding(self.aval, phys_sharding)
8585

86+
@property
87+
def device(self):
88+
if isinstance(self._data.sharding, sharding_impls.SingleDeviceSharding):
89+
return self._data.device
90+
return self.sharding
91+
8692
# TODO(mattjj): not implemented below here, need more methods from ArrayImpl
8793

8894
def addressable_data(self, index: int) -> EArray:

jax/_src/numpy/array_methods.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import jax
3333
from jax import lax
3434
from jax.sharding import Sharding
35+
from jax._src import api
3536
from jax._src import core
3637
from jax._src import dtypes
3738
from jax._src.api_util import _ensure_index_tuple
@@ -67,6 +68,12 @@ def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Dev
6768
"""
6869
return lax_numpy.astype(arr, dtype, copy=copy, device=device)
6970

71+
def _to_device(arr: ArrayLike, device: xc.Device | Sharding, *,
72+
stream: int | Any | None = None):
73+
if stream is not None:
74+
raise NotImplementedError("stream argument of array.to_device()")
75+
return api.device_put(arr, device)
76+
7077

7178
def _nbytes(arr: ArrayLike) -> int:
7279
"""Total bytes consumed by the elements of the array."""
@@ -694,6 +701,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
694701
"sum": reductions.sum,
695702
"swapaxes": lax_numpy.swapaxes,
696703
"take": lax_numpy.take,
704+
"to_device": _to_device,
697705
"trace": lax_numpy.trace,
698706
"transpose": _transpose,
699707
"var": reductions.var,

jax/experimental/array_api/_array_methods.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,6 @@ def _array_namespace(self, /, *, api_version: None | str = None):
3131
return jax.experimental.array_api
3232

3333

34-
def _to_device(self, device: xe.Device | Sharding | None, *,
35-
stream: int | Any | None = None):
36-
if stream is not None:
37-
raise NotImplementedError("stream argument of array.to_device()")
38-
return jax.device_put(self, device)
39-
40-
4134
def add_array_object_methods():
4235
# TODO(jakevdp): set on tracers as well?
4336
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
44-
setattr(ArrayImpl, "to_device", _to_device)
45-
setattr(ArrayImpl, "device", property(lambda self: self.sharding))

tests/array_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,30 @@ def test_gspmd_sharding_hash_eq(self):
12761276
self.assertEqual(x1, x2)
12771277
self.assertEqual(hash(x1), hash(x2))
12781278

1279+
def test_device_attr(self):
1280+
# For single-device arrays, x.device returns the device
1281+
x = jnp.ones((2, 10))
1282+
self.assertEqual(x.device, list(x.devices())[0])
1283+
1284+
# For sharded arrays, x.device returns the sharding
1285+
mesh = jtu.create_global_mesh((2,), ('x',))
1286+
sharding = jax.sharding.NamedSharding(mesh, P('x'))
1287+
x = jax.device_put(x, sharding)
1288+
self.assertEqual(x.device, sharding)
1289+
1290+
def test_to_device(self):
1291+
device = jax.devices()[-1]
1292+
mesh = jtu.create_global_mesh((2,), ('x',))
1293+
sharding = jax.sharding.NamedSharding(mesh, P('x'))
1294+
1295+
x = jnp.ones((2, 10))
1296+
1297+
x_device = x.to_device(device)
1298+
x_sharding = x.to_device(sharding)
1299+
1300+
self.assertEqual(x_device.device, device)
1301+
self.assertEqual(x_sharding.device, sharding)
1302+
12791303

12801304
class RngShardingTest(jtu.JaxTestCase):
12811305
# tests that the PRNGs are automatically sharded as expected

0 commit comments

Comments
 (0)