
Commit b98c1e0

ezhulenev authored and jax authors committed

[xla:cpu] Support for up to 16 sorted inputs
+ enable more jax/lax tests for XLA CPU thunks

FUTURE_COPYBARA_INTEGRATE_REVIEW=#22597 from jakevdp:arr-device 613a000
PiperOrigin-RevId: 654865806

1 parent 7792bde · commit b98c1e0
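Note: the XLA-side sort change named in the commit title does not itself appear in the files below. The FUTURE_COPYBARA_INTEGRATE_REVIEW line indicates this export also rolls in the piggy-backed JAX review #22597 (jakevdp:arr-device), which is what the diff shows: a new jax.Array.device property and to_device method, plus test_cpu_thunks tags in tests/BUILD.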

File tree

10 files changed: +83 -10 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -25,6 +25,8 @@ Remember to align the itemized text with the first line of an item within a list
   will be removed in a future release.
 * Updated the repr of gpu devices to be more consistent
   with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
+* Added the `device` property and `to_device` method to {class}`jax.Array`, as
+  part of JAX's [Array API](https://data-apis.org/array-api) support.
 * Deprecations
   * Removed a number of previously-deprecated internal APIs related to
     polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
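A minimal usage sketch of the API this entry describes (device names and count depend on the local runtime):

import jax
import jax.numpy as jnp

x = jnp.ones((2, 10))

# New property: for an array resident on one device, this is that Device.
single = x.device

# New method: moves the array; equivalent in effect to
# jax.device_put(x, jax.devices()[-1]).
y = x.to_device(jax.devices()[-1])
assert y.device == jax.devices()[-1]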

jax/_src/array.py

Lines changed: 7 additions & 0 deletions

@@ -254,6 +254,13 @@ def size(self):
   def sharding(self):
     return self._sharding
 
+  @property
+  def device(self):
+    self._check_if_deleted()
+    if isinstance(self.sharding, SingleDeviceSharding):
+      return list(self.sharding.device_set)[0]
+    return self.sharding
+
   @property
   def weak_type(self):
     return self.aval.weak_type
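The branch above means .device is deliberately not always a Device: for an array laid out across several devices it returns the Sharding instead. A small sketch of both outcomes, assuming at least two local devices:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

x = jnp.ones((2, 10))
assert x.device == list(x.devices())[0]        # single device -> Device

mesh = Mesh(jax.devices()[:2], ('x',))         # assumes >= 2 devices
sharded = jax.device_put(x, NamedSharding(mesh, P('x')))
assert sharded.device == NamedSharding(mesh, P('x'))   # sharded -> Sharding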

jax/_src/basearray.py

Lines changed: 10 additions & 0 deletions

@@ -22,6 +22,7 @@
 from collections.abc import Sequence
 
 # TODO(jakevdp): fix import cycles and define these.
+Device = Any
 Shard = Any
 Sharding = Any
 
@@ -112,6 +113,15 @@ def is_fully_replicated(self) -> bool:
   def sharding(self) -> Sharding:
     """The sharding for the array."""
 
+  @property
+  @abc.abstractmethod
+  def device(self) -> Device | Sharding:
+    """Array API-compatible device attribute.
+
+    For single-device arrays, this returns a Device. For sharded arrays, this
+    returns a Sharding.
+    """
+
 
 Array.__module__ = "jax"
jax/_src/basearray.pyi

Lines changed: 3 additions & 0 deletions

@@ -204,6 +204,8 @@ class Array(abc.ABC):
   @property
   def sharding(self) -> Sharding: ...
   @property
+  def device(self) -> Device | Sharding: ...
+  @property
   def addressable_shards(self) -> Sequence[Shard]: ...
   @property
   def global_shards(self) -> Sequence[Shard]: ...
@@ -216,6 +218,7 @@ class Array(abc.ABC):
   @property
   def traceback(self) -> Traceback: ...
   def unsafe_buffer_pointer(self) -> int: ...
+  def to_device(self, device: Device | Sharding, *, stream: int | Any | None) -> Array: ...
 
 
 StaticScalar = Union[

jax/_src/core.py

Lines changed: 9 additions & 0 deletions

@@ -738,6 +738,15 @@ def sharding(self):
         f"The 'sharding' attribute is not available on {self._error_repr()}."
         f"{self._origin_msg()}")
 
+  @property
+  def device(self):
+    # This attribute is part of the jax.Array API, but only defined on concrete arrays.
+    # Raising a ConcretizationTypeError would make sense, but for backward compatibility
+    # we raise an AttributeError so that hasattr() and getattr() work as expected.
+    raise AttributeError(self,
+        f"The 'device' attribute is not available on {self._error_repr()}."
+        f"{self._origin_msg()}")
+
   @property
   def addressable_shards(self):
     raise ConcretizationTypeError(self,
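As the inline comment notes, using AttributeError keeps duck typing intact under tracing; a sketch of the observable difference:

import jax
import jax.numpy as jnp

def has_device(x):
  # True for concrete arrays; False for tracers, because Tracer.device
  # raises AttributeError rather than ConcretizationTypeError, so
  # hasattr() degrades gracefully instead of propagating an error.
  return hasattr(x, 'device')

print(has_device(jnp.ones(3)))          # True

@jax.jit
def f(x):
  print('traced:', has_device(x))       # False, printed at trace time
  return x + 1

f(jnp.ones(3))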

jax/_src/earray.py

Lines changed: 6 additions & 0 deletions

@@ -83,6 +83,12 @@ def sharding(self):
     phys_sharding = self._data.sharding
     return sharding_impls.logical_sharding(self.aval, phys_sharding)
 
+  @property
+  def device(self):
+    if isinstance(self._data.sharding, sharding_impls.SingleDeviceSharding):
+      return self._data.device
+    return self.sharding
+
   # TODO(mattjj): not implemented below here, need more methods from ArrayImpl
 
   def addressable_data(self, index: int) -> EArray:

jax/_src/numpy/array_methods.py

Lines changed: 8 additions & 0 deletions

@@ -32,6 +32,7 @@
 import jax
 from jax import lax
 from jax.sharding import Sharding
+from jax._src import api
 from jax._src import core
 from jax._src import dtypes
 from jax._src.api_util import _ensure_index_tuple
@@ -67,6 +68,12 @@ def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Dev
   """
   return lax_numpy.astype(arr, dtype, copy=copy, device=device)
 
+def _to_device(arr: ArrayLike, device: xc.Device | Sharding, *,
+               stream: int | Any | None = None):
+  if stream is not None:
+    raise NotImplementedError("stream argument of array.to_device()")
+  return api.device_put(arr, device)
+
 
 def _nbytes(arr: ArrayLike) -> int:
   """Total bytes consumed by the elements of the array."""
@@ -694,6 +701,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
     "sum": reductions.sum,
     "swapaxes": lax_numpy.swapaxes,
     "take": lax_numpy.take,
+    "to_device": _to_device,
     "trace": lax_numpy.trace,
    "transpose": _transpose,
     "var": reductions.var,

jax/experimental/array_api/_array_methods.py

Lines changed: 0 additions & 9 deletions

@@ -31,15 +31,6 @@ def _array_namespace(self, /, *, api_version: None | str = None):
   return jax.experimental.array_api
 
 
-def _to_device(self, device: xe.Device | Sharding | None, *,
-               stream: int | Any | None = None):
-  if stream is not None:
-    raise NotImplementedError("stream argument of array.to_device()")
-  return jax.device_put(self, device)
-
-
 def add_array_object_methods():
   # TODO(jakevdp): set on tracers as well?
   setattr(ArrayImpl, "__array_namespace__", _array_namespace)
-  setattr(ArrayImpl, "to_device", _to_device)
-  setattr(ArrayImpl, "device", property(lambda self: self.sharding))

tests/BUILD

Lines changed: 14 additions & 1 deletion

@@ -445,7 +445,10 @@ jax_test(
         "gpu": 40,
         "tpu": 50,
     },
-    tags = ["noasan"],  # Test times out on all backends
+    tags = [
+        "noasan",  # Test times out on all backends
+        "test_cpu_thunks",
+    ],
 )
 
 jax_test(
@@ -456,6 +459,7 @@ jax_test(
         "gpu": 30,
         "tpu": 40,
     },
+    tags = ["test_cpu_thunks"],
 )
 
 jax_test(
@@ -466,6 +470,7 @@ jax_test(
         "gpu": 20,
         "tpu": 20,
     },
+    tags = ["test_cpu_thunks"],
 )
 
 jax_test(
@@ -486,16 +491,19 @@ jax_test(
         "gpu": 10,
         "tpu": 10,
     },
+    tags = ["test_cpu_thunks"],
 )
 
 jax_test(
     name = "lax_numpy_ufuncs_test",
     srcs = ["lax_numpy_ufuncs_test.py"],
+    tags = ["test_cpu_thunks"],
 )
 
 jax_test(
     name = "lax_numpy_vectorize_test",
     srcs = ["lax_numpy_vectorize_test.py"],
+    tags = ["test_cpu_thunks"],
 )
 
 jax_test(
@@ -560,6 +568,7 @@ jax_test(
         "gpu": 40,
         "tpu": 40,
     },
+    tags = ["test_cpu_thunks"],
     deps = [
         "//jax:internal_test_util",
         "//jax:lax_reference",
@@ -589,6 +598,7 @@ jax_test(
         "gpu": 40,
         "tpu": 20,
     },
+    tags = ["test_cpu_thunks"],
 )
 
 jax_test(
@@ -599,6 +609,7 @@ jax_test(
         "gpu": 40,
         "tpu": 40,
     },
+    tags = ["test_cpu_thunks"],
     deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
 )
 
@@ -610,6 +621,7 @@ jax_test(
         "gpu": 40,
         "tpu": 40,
     },
+    tags = ["test_cpu_thunks"],
     deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
 )
 
@@ -652,6 +664,7 @@ jax_test(
         "gpu": 40,
         "tpu": 40,
     },
+    tags = ["test_cpu_thunks"],
 )
 
 jax_test(
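The recurring test_cpu_thunks tag presumably lets CI slice out these suites via standard Bazel tag filtering; a hypothetical invocation:

bazel test //tests/... --test_tag_filters=test_cpu_thunks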

tests/array_test.py

Lines changed: 24 additions & 0 deletions

@@ -1276,6 +1276,30 @@ def test_gspmd_sharding_hash_eq(self):
     self.assertEqual(x1, x2)
     self.assertEqual(hash(x1), hash(x2))
 
+  def test_device_attr(self):
+    # For single-device arrays, x.device returns the device
+    x = jnp.ones((2, 10))
+    self.assertEqual(x.device, list(x.devices())[0])
+
+    # For sharded arrays, x.device returns the sharding
+    mesh = jtu.create_global_mesh((2,), ('x',))
+    sharding = jax.sharding.NamedSharding(mesh, P('x'))
+    x = jax.device_put(x, sharding)
+    self.assertEqual(x.device, sharding)
+
+  def test_to_device(self):
+    device = jax.devices()[-1]
+    mesh = jtu.create_global_mesh((2,), ('x',))
+    sharding = jax.sharding.NamedSharding(mesh, P('x'))
+
+    x = jnp.ones((2, 10))
+
+    x_device = x.to_device(device)
+    x_sharding = x.to_device(sharding)
+
+    self.assertEqual(x_device.device, device)
+    self.assertEqual(x_sharding.device, sharding)
+
 
 class RngShardingTest(jtu.JaxTestCase):
   # tests that the PRNGs are automatically sharded as expected
