     PartiallyMaterializedTensor,
 )
 from torch import nn
+from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard
 from torchrec.distributed.comm import get_local_rank, get_node_group_size
 from torchrec.distributed.composable.table_batched_embedding_slice import (
     TableBatchedEmbeddingSlice,
 )
 from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict
 from torchrec.distributed.embedding_types import (
     compute_kernel_to_embedding_location,
+    DTensorMetadata,
     GroupedEmbeddingConfig,
 )
+from torchrec.distributed.shards_wrapper import LocalShardsWrapper
 from torchrec.distributed.types import (
     Shard,
     ShardedTensor,
@@ -213,6 +216,7 @@ class ShardParams:
     optimizer_states: List[Optional[Tuple[torch.Tensor]]]
     local_metadata: List[ShardMetadata]
     embedding_weights: List[torch.Tensor]
+    dtensor_metadata: List[DTensorMetadata]
 
 def get_optimizer_single_value_shard_metadata_and_global_metadata(
     table_global_metadata: ShardedTensorMetadata,
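The new field relies on DTensorMetadata from torchrec.distributed.embedding_types, whose definition is not part of this diff. Judging only from how the later hunks read it (.mesh, .placements, .stride, plus a global size), a rough sketch of the shape of that per-table metadata might look like the following; the names and defaults here are inferred, not the actual torchrec definition:

# Rough sketch (assumption) of the per-table DTensor metadata this diff consumes;
# the real DTensorMetadata dataclass lives in torchrec.distributed.embedding_types
# and may differ. Field names are inferred from the attribute reads in this diff.
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class DTensorMetadataSketch:
    mesh: Optional[Any] = None                    # torch.distributed DeviceMesh the table is laid out on
    placements: Optional[Tuple[Any, ...]] = None  # DTensor placements, e.g. (Shard(0),)
    size: Optional[Tuple[int, ...]] = None        # global (num_embeddings, embedding_dim)
    stride: Optional[Tuple[int, ...]] = None      # global stride matching `size`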
@@ -389,7 +393,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
                 continue
             if table_config.name not in table_to_shard_params:
                 table_to_shard_params[table_config.name] = ShardParams(
-                    optimizer_states=[], local_metadata=[], embedding_weights=[]
+                    optimizer_states=[],
+                    local_metadata=[],
+                    embedding_weights=[],
+                    dtensor_metadata=[],
                 )
             optimizer_state_values = None
             if optimizer_states:
@@ -410,6 +417,9 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             table_to_shard_params[table_config.name].local_metadata.append(
                 local_metadata
             )
+            table_to_shard_params[table_config.name].dtensor_metadata.append(
+                table_config.dtensor_metadata
+            )
             table_to_shard_params[table_config.name].embedding_weights.append(weight)
 
         seen_tables = set()
@@ -474,7 +484,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             # pyre-ignore
             def get_sharded_optim_state(
                 momentum_idx: int, state_key: str
-            ) -> ShardedTensor:
+            ) -> Union[ShardedTensor, DTensor]:
                 assert momentum_idx > 0
                 momentum_local_shards: List[Shard] = []
                 optimizer_sharded_tensor_metadata: ShardedTensorMetadata
@@ -528,12 +538,42 @@ def get_sharded_optim_state(
                         )
                     )
 
-                # TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
-                return ShardedTensor._init_from_local_shards_and_global_metadata(
-                    local_shards=momentum_local_shards,
-                    sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
-                    process_group=self._pg,
-                )
+                # Convert optimizer state to DTensor if enabled
+                if table_config.dtensor_metadata:
+                    # Depending on the optim state we determine the Shard Dim:
+                    # if rowwise state we do Shard(0), regardless of how the table is sharded
+                    if optim_state.dim() == 1:
+                        stride = (1,)
+                        placements = (
+                            (Replicate(), DTensorShard(0))
+                            if table_config.dtensor_metadata.mesh.ndim == 2
+                            else (DTensorShard(0),)
+                        )
+                    else:
+                        stride = table_config.dtensor_metadata.stride
+                        placements = table_config.dtensor_metadata.placements
+
+                    return DTensor.from_local(
+                        local_tensor=LocalShardsWrapper(
+                            local_shards=[x.tensor for x in momentum_local_shards],
+                            local_offsets=[
+                                x.metadata.shard_offsets
+                                for x in momentum_local_shards
+                            ],
+                        ),
+                        device_mesh=table_config.dtensor_metadata.mesh,
+                        placements=placements,
+                        shape=optimizer_sharded_tensor_metadata.size,
+                        stride=stride,
+                        run_check=False,
+                    )
+                else:
+                    # TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
+                    return ShardedTensor._init_from_local_shards_and_global_metadata(
+                        local_shards=momentum_local_shards,
+                        sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
+                        process_group=self._pg,
+                    )
 
             num_states: int = min(
                 # pyre-ignore
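Taken together, the new branch picks the DTensor layout for the optimizer state as follows: a 1-D (row-wise) state is always sharded on dim 0, replicated across the first mesh dimension when a 2-D device mesh is used, while a 2-D (point-wise) state simply reuses the table's own placements and stride. Below is a minimal, self-contained sketch of that selection rule only; the helper name and arguments are illustrative, not torchrec APIs, and building the actual DTensor additionally needs the device mesh, LocalShardsWrapper, and an initialized process group.

# Minimal sketch of the placement-selection rule applied in the diff above.
from typing import Tuple

from torch.distributed._tensor import Replicate, Shard  # aliased as DTensorShard in the diff


def select_optim_state_layout(
    state_ndim: int,                # dim() of the optimizer state tensor
    mesh_ndim: int,                 # table_config.dtensor_metadata.mesh.ndim
    table_placements: Tuple,        # table_config.dtensor_metadata.placements
    table_stride: Tuple[int, ...],  # table_config.dtensor_metadata.stride
) -> Tuple[Tuple, Tuple[int, ...]]:
    """Return (placements, stride) for the optimizer-state DTensor."""
    if state_ndim == 1:
        # Row-wise state (one value per embedding row): shard on dim 0 no matter
        # how the table itself is sharded; replicate across the extra mesh dim
        # when a 2-D device mesh is in use.
        placements = (Replicate(), Shard(0)) if mesh_ndim == 2 else (Shard(0),)
        return placements, (1,)
    # Point-wise state has the same shape as the table, so reuse its layout.
    return table_placements, table_stride


# A row-wise (1-D) momentum on a 2-D mesh always lands on (Replicate(), Shard(0)):
print(select_optim_state_layout(1, 2, (Shard(1),), (128, 1)))
# A point-wise (2-D) momentum keeps the table's own (here column-wise) sharding:
print(select_optim_state_layout(2, 1, (Shard(1),), (128, 1)))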