@@ -356,16 +356,6 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics):
356
356
return api .jit (_identity_fn , out_shardings = target_sharding ,
357
357
donate_argnums = donate_argnums )(x )
358
358
359
- if inp_sharding .device_set != target_sharding .device_set :
360
- inp_ids = [d .id for d in inp_sharding ._device_assignment ]
361
- inp_plat = inp_sharding ._device_assignment [0 ].platform .upper ()
362
- target_ids = [d .id for d in target_sharding ._device_assignment ]
363
- target_plat = target_sharding ._device_assignment [0 ].platform .upper ()
364
- raise ValueError ("Input and target sharding should have the same set of "
365
- f"devices. Got input's device set ids: { inp_ids } on "
366
- f"platform { inp_plat } and target sharding's device set "
367
- f"ids: { target_ids } on platform { target_plat } " )
368
-
369
359
if inp_sharding .is_fully_replicated :
370
360
permute_order = None
371
361
else :
@@ -389,6 +379,25 @@ def _reorder_shards(x, new_s, copy_semantics: CopySemantics):
389
379
return xc .reorder_shards (x , new_s , xc_copy_semantics ) # type: ignore
390
380
391
381
382
+ @util .cache ()
383
+ def _is_supported_cross_host_transfer (ndim , src_sharding , dst_sharding ):
384
+ """Returns True if src->dst is a supported cross-host transfer."""
385
+ backend = xla_bridge .get_backend ()
386
+ # There is experimental support for cross-host device transfers on TFRT TPU
387
+ # backends only.
388
+ if (xla_bridge .process_count () == 1 or backend .platform != "tpu" or
389
+ "TFRT TPU" not in backend .platform_version ):
390
+ return False
391
+ if (src_sharding ._to_xla_hlo_sharding (ndim ) !=
392
+ dst_sharding ._to_xla_hlo_sharding (ndim )):
393
+ return False
394
+ # This check excludes the case where the source and destination shardings
395
+ # have the same process index sets but there are shards that require
396
+ # cross-host transfers. This case is supportable but expensive to check for.
397
+ return (src_sharding ._internal_device_list .process_indices !=
398
+ dst_sharding ._internal_device_list .process_indices )
399
+
400
+
392
401
@dataclasses .dataclass (frozen = True )
393
402
class _DeferredShardArg :
394
403
"""Deferred call to `pxla.shard_args`.
@@ -419,7 +428,8 @@ def _device_put_sharding_impl(x, aval, device, copy):
419
428
return x
420
429
421
430
if (not s .is_fully_addressable and
422
- isinstance (x , array .ArrayImpl ) and not x .is_fully_addressable ):
431
+ isinstance (x , array .ArrayImpl ) and not x .is_fully_addressable and
432
+ s .device_set == x .sharding .device_set ):
423
433
assert isinstance (s , Sharding )
424
434
return _different_device_order_reshard (x , s , copy )
425
435
@@ -430,7 +440,32 @@ def _device_put_sharding_impl(x, aval, device, copy):
430
440
assert isinstance (s , Sharding )
431
441
return _different_device_order_reshard (x , s , copy )
432
442
443
+ # There is experimental support for cross-host device transfers on TFRT TPU.
444
+ if (isinstance (x , array .ArrayImpl ) and x ._committed
445
+ and _is_supported_cross_host_transfer (x .ndim , x .sharding , s )):
446
+ return xc .batched_copy_array_to_devices_with_sharding (
447
+ [x ], [s ._internal_device_list ], [s ], # pytype: disable=attribute-error
448
+ pxla .to_xc_copy_semantics ([copy ]))[0 ]
449
+
433
450
if not s .is_fully_addressable :
451
+ # If both the source and target shardings are not fully addressable and
452
+ # one of the above conditions has not been met, then assume that the user
453
+ # is attempting a different device order reshard.
454
+ if (isinstance (x , array .ArrayImpl ) and not x .is_fully_addressable
455
+ and s .device_set != x .sharding .device_set ):
456
+ inp_ids = [d .id for d in x .sharding ._device_assignment ]
457
+ inp_plat = x .sharding ._device_assignment [0 ].platform .upper ()
458
+ target_ids = [d .id for d in s ._device_assignment ]
459
+ target_plat = s ._device_assignment [0 ].platform .upper ()
460
+ raise ValueError (
461
+ "For a cross-host reshard in multi-controller JAX, input and target"
462
+ " sharding should have the same set of devices. Got input's device"
463
+ f" set ids: { inp_ids } on platform { inp_plat } and target sharding's"
464
+ f" device set ids: { target_ids } on platform { target_plat } .\n \n "
465
+ "There is experimental support for cross-host transfers with "
466
+ "different device sets, when input/output shardings have the same "
467
+ "indices and layouts, in the TFRT TPU runtime only." )
468
+
434
469
if ((isinstance (x , array .ArrayImpl ) and not x ._committed ) or
435
470
type (x ) in array_types or type (x ) in dtypes .python_scalar_dtypes ):
436
471
# If all hosts participate in the sharding, assert that the input is the
0 commit comments