
Commit 6cd196a

emilyfertig authored and Google-ML-Automation committed
Prototype of cross-host device transfers in IFRT-PJRT.
For now it only works with the TFRT TPU runtime, because other PjRt plugins don't implement the necessary APIs. The per-shard indices of the source and destination shardings must be the same, and all shards must require cross-host transfers (support for a mixture of cross-host and host-local transfers is forthcoming).

Transfers take place via the xla::ifrt::PjRtClient::CopyArrays API, which copies the buffers from a set of arrays to a new device list. The distributed KV store from the coordination service is used to store metadata for cross-host transfers: the receiving process populates the store with a descriptor, and the sending process reads it and completes the send.

PiperOrigin-RevId: 766765989
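As a concrete illustration of the user-facing path this enables, the sketch below (not part of the commit) shows a cross-host device_put under the constraints described above: multi-controller JAX on the TFRT TPU runtime, source and destination shardings with identical per-shard indices, and disjoint device sets so that every shard requires a cross-host transfer. The two-host / eight-device topology, mesh shape, and array size are illustrative assumptions only.

# Hypothetical usage sketch (not taken from this commit). Assumes a
# multi-controller job with two hosts of four TFRT TPU devices each.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.distributed.initialize()  # starts the coordination service whose KV store
                              # carries the cross-host transfer metadata
                              # (arguments are environment-specific)

devices = jax.devices()                   # all devices, across all processes
host0_devs, host1_devs = devices[:4], devices[4:]

# Disjoint device sets, but identical per-shard indices: same mesh shape and
# same PartitionSpec, so shard k of the source corresponds to shard k of the
# destination.
src_sharding = NamedSharding(Mesh(host0_devs, ("x",)), P("x"))
dst_sharding = NamedSharding(Mesh(host1_devs, ("x",)), P("x"))

x = jax.make_array_from_callback(
    (8,), src_sharding, lambda idx: np.arange(8.0)[idx])

# Before this change, transferring to a sharding with a different device set
# raised a ValueError; with it, the copy goes through the cross-host path
# backed by xla::ifrt::PjRtClient::CopyArrays.
y = jax.device_put(x, dst_sharding)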
1 parent 7dd0344 commit 6cd196a

File tree

1 file changed (+46, -11 lines)


jax/_src/dispatch.py

Lines changed: 46 additions & 11 deletions
@@ -356,16 +356,6 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics):
     return api.jit(_identity_fn, out_shardings=target_sharding,
                    donate_argnums=donate_argnums)(x)
 
-  if inp_sharding.device_set != target_sharding.device_set:
-    inp_ids = [d.id for d in inp_sharding._device_assignment]
-    inp_plat = inp_sharding._device_assignment[0].platform.upper()
-    target_ids = [d.id for d in target_sharding._device_assignment]
-    target_plat = target_sharding._device_assignment[0].platform.upper()
-    raise ValueError("Input and target sharding should have the same set of "
-                     f"devices. Got input's device set ids: {inp_ids} on "
-                     f"platform {inp_plat} and target sharding's device set "
-                     f"ids: {target_ids} on platform {target_plat}")
-
   if inp_sharding.is_fully_replicated:
     permute_order = None
   else:
@@ -389,6 +379,25 @@ def _reorder_shards(x, new_s, copy_semantics: CopySemantics):
   return xc.reorder_shards(x, new_s, xc_copy_semantics)  # type: ignore
 
 
+@util.cache()
+def _is_supported_cross_host_transfer(ndim, src_sharding, dst_sharding):
+  """Returns True if src->dst is a supported cross-host transfer."""
+  backend = xla_bridge.get_backend()
+  # There is experimental support for cross-host device transfers on TFRT TPU
+  # backends only.
+  if (xla_bridge.process_count() == 1 or backend.platform != "tpu" or
+      "TFRT TPU" not in backend.platform_version):
+    return False
+  if (src_sharding._to_xla_hlo_sharding(ndim) !=
+      dst_sharding._to_xla_hlo_sharding(ndim)):
+    return False
+  # This check excludes the case where the source and destination shardings
+  # have the same process index sets but there are shards that require
+  # cross-host transfers. This case is supportable but expensive to check for.
+  return (src_sharding._internal_device_list.process_indices !=
+          dst_sharding._internal_device_list.process_indices)
+
+
 @dataclasses.dataclass(frozen=True)
 class _DeferredShardArg:
   """Deferred call to `pxla.shard_args`.
@@ -419,7 +428,8 @@ def _device_put_sharding_impl(x, aval, device, copy):
       return x
 
     if (not s.is_fully_addressable and
-        isinstance(x, array.ArrayImpl) and not x.is_fully_addressable):
+        isinstance(x, array.ArrayImpl) and not x.is_fully_addressable and
+        s.device_set == x.sharding.device_set):
       assert isinstance(s, Sharding)
       return _different_device_order_reshard(x, s, copy)
 
@@ -430,7 +440,32 @@ def _device_put_sharding_impl(x, aval, device, copy):
       assert isinstance(s, Sharding)
       return _different_device_order_reshard(x, s, copy)
 
+    # There is experimental support for cross-host device transfers on TFRT TPU.
+    if (isinstance(x, array.ArrayImpl) and x._committed
+        and _is_supported_cross_host_transfer(x.ndim, x.sharding, s)):
+      return xc.batched_copy_array_to_devices_with_sharding(
+          [x], [s._internal_device_list], [s],  # pytype: disable=attribute-error
+          pxla.to_xc_copy_semantics([copy]))[0]
+
     if not s.is_fully_addressable:
+      # If both the source and target shardings are not fully addressable and
+      # one of the above conditions has not been met, then assume that the user
+      # is attempting a different device order reshard.
+      if (isinstance(x, array.ArrayImpl) and not x.is_fully_addressable
+          and s.device_set != x.sharding.device_set):
+        inp_ids = [d.id for d in x.sharding._device_assignment]
+        inp_plat = x.sharding._device_assignment[0].platform.upper()
+        target_ids = [d.id for d in s._device_assignment]
+        target_plat = s._device_assignment[0].platform.upper()
+        raise ValueError(
+            "For a cross-host reshard in multi-controller JAX, input and target"
+            " sharding should have the same set of devices. Got input's device"
+            f" set ids: {inp_ids} on platform {inp_plat} and target sharding's"
+            f" device set ids: {target_ids} on platform {target_plat}.\n\n"
+            "There is experimental support for cross-host transfers with "
+            "different device sets, when input/output shardings have the same "
+            "indices and layouts, in the TFRT TPU runtime only.")
+
       if ((isinstance(x, array.ArrayImpl) and not x._committed) or
           type(x) in array_types or type(x) in dtypes.python_scalar_dtypes):
         # If all hosts participate in the sharding, assert that the input is the
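For readers who want to sanity-check whether a pair of shardings falls inside the prototype's constraints, here is a rough standalone sketch (not from the commit) built on mostly public Sharding APIs. The commit's own check compares lowered HLO shardings and process-index lists, so treat these helpers only as an approximation of the "same per-shard indices" and "all shards cross-host" conditions stated in the commit message.

# Hypothetical helpers (not part of this commit) approximating the
# prototype's constraints with mostly public Sharding APIs.
from jax.sharding import Sharding

def has_matching_per_shard_indices(src: Sharding, dst: Sharding,
                                   global_shape: tuple[int, ...]) -> bool:
  """True if shard k of src covers the same slice of the global array as
  shard k of dst. The commit's actual check compares HLO shardings."""
  src_map = src.devices_indices_map(global_shape)
  dst_map = dst.devices_indices_map(global_shape)
  # _device_assignment is the same internal ordering the diff above relies on.
  src_indices = [src_map[d] for d in src._device_assignment]
  dst_indices = [dst_map[d] for d in dst._device_assignment]
  return src_indices == dst_indices

def is_all_cross_host(src: Sharding, dst: Sharding) -> bool:
  """True if no process owns devices in both shardings, i.e. every shard
  would need a cross-host copy -- the only case this prototype handles.
  This is stricter than the process-index inequality check in the diff."""
  src_procs = {d.process_index for d in src.device_set}
  dst_procs = {d.process_index for d in dst.device_set}
  return src_procs.isdisjoint(dst_procs)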
