Commit 8568fc2

EddyLXJ authored and facebook-github-bot committed
Adding eviction metadata tensor fqn (#3247)
Summary:
Pull Request resolved: #3247
X-link: pytorch/FBGEMM#4611
X-link: facebookresearch/FBGEMM#1646

Adding a new metadata fqn in the kvzch ckpt, which is needed by the eviction filter in publishing.

Reviewed By: emlin

Differential Revision: D78768842

fbshipit-source-id: eb5b648dc9395c2e02b0b851fb6ce2b2304d6113
1 parent 6ecb3ee commit 8568fc2
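
For orientation: the new eviction metadata tensor shows up in the checkpoint as an extra per-table entry next to the existing weight, weight_id, and bucket tensors. A minimal Python sketch of the key layout (the prefix and table name below are illustrative placeholders, not values taken from this diff):

    # Per-table checkpoint FQNs touched by this commit; prefix/table name are made up.
    prefix = "sharded_ec."
    table_name = "t1"

    weight_key = f"{prefix}embeddings.{table_name}.weight"
    weight_id_key = f"{prefix}embeddings.{table_name}.weight_id"
    bucket_key = f"{prefix}embeddings.{table_name}.bucket"
    metadata_key = f"{prefix}embeddings.{table_name}.metadata"  # new: eviction metadata FQN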

File tree

5 files changed: +105 -23 lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 65 additions & 17 deletions
@@ -958,7 +958,7 @@ def _gen_named_parameters_by_table_ssd_pmt(
     name as well as the parameter itself. The embedding table is in the form of
     PartiallyMaterializedTensor to support windowed access.
     """
-    pmts, _, _ = emb_module.split_embedding_weights()
+    pmts, _, _, _ = emb_module.split_embedding_weights()
     for table_config, pmt in zip(config.embedding_tables, pmts):
         table_name = table_config.name
         emb_table = pmt
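
With this change, split_embedding_weights on the SSD/KV embedding module returns a 4-tuple rather than a 3-tuple, so each call site unpacks one extra value. A minimal sketch of the new calling convention (emb_module stands in for the FBGEMM SSD embedding module and is not constructed here):

    from typing import Any, List

    def collect_pmts(emb_module: Any) -> List[Any]:
        # New return shape: (pmts, weight_ids, bucket_cnts, metadata).
        # Callers that previously unpacked three values now add a fourth slot.
        pmts, _weight_ids, _bucket_cnts, _metadata = emb_module.split_embedding_weights()
        return list(pmts)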
@@ -1272,7 +1272,7 @@ def state_dict(
         # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
         # ShardedEmbeddingBagCollection._pre_state_dict_hook()
 
-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
             emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -1322,6 +1322,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
@@ -1330,13 +1331,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         RocksDB snapshot to support windowed access.
         optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch
         optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch
+        optional ShardedTensor for metadata, this won't be used here as this is non-kvzch
         """
         for config, tensor in zip(
             self._config.embedding_tables,
             self.split_embedding_weights(no_snapshot=False)[0],
         ):
             key = append_prefix(prefix, f"{config.name}")
-            yield key, tensor, None, None
+            yield key, tensor, None, None, None
 
     def flush(self) -> None:
         """
@@ -1364,6 +1366,7 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
         List[PartiallyMaterializedTensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
     ]:
         # pyre-fixme[7]: Expected `Tuple[List[PartiallyMaterializedTensor],
         #  Optional[List[Tensor]], Optional[List[Tensor]]]` but got
@@ -1415,6 +1418,7 @@ def __init__(
                 List[ShardedTensor],
                 List[ShardedTensor],
                 List[ShardedTensor],
+                List[ShardedTensor],
             ]
         ] = None
 
@@ -1490,7 +1494,7 @@ def state_dict(
         # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
         # ShardedEmbeddingBagCollection._pre_state_dict_hook()
 
-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
             emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -1546,8 +1550,10 @@ def _init_sharded_split_embedding_weights(
         if not force_regenerate and self._split_weights_res is not None:
             return
 
-        pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
-            no_snapshot=False,
+        pmt_list, weight_ids_list, bucket_cnt_list, metadata_list = (
+            self.split_embedding_weights(
+                no_snapshot=False,
+            )
         )
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
@@ -1581,17 +1587,31 @@ def _init_sharded_split_embedding_weights(
             self._table_name_to_weight_count_per_rank,
             use_param_size_as_rows=True,
         )
-        # pyre-ignore
-        assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
+        metadata_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            metadata_list,  # pyre-ignore [6]
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+        )
+
+        assert (
+            len(pmt_list)
+            == len(weight_ids_list)  # pyre-ignore
+            == len(bucket_cnt_list)  # pyre-ignore
+            == len(metadata_list)  # pyre-ignore
+        )
         assert (
             len(pmt_sharded_t_list)
             == len(weight_id_sharded_t_list)
             == len(bucket_cnt_sharded_t_list)
+            == len(metadata_sharded_t_list)
         )
         self._split_weights_res = (
             pmt_sharded_t_list,
             weight_id_sharded_t_list,
             bucket_cnt_sharded_t_list,
+            metadata_sharded_t_list,
         )
 
     def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
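
After regeneration, the cached _split_weights_res holds four parallel lists of sharded tensors instead of three; the metadata list sits at index 3 and is read back in the snapshot iterator below. A rough sketch of the widened cache layout (Any is used in place of ShardedTensor purely for brevity):

    from typing import Any, List, Tuple

    # Illustrative alias only; the real cache stores ShardedTensor lists.
    SplitWeightsRes = Tuple[
        List[Any],  # pmt_sharded_t_list
        List[Any],  # weight_id_sharded_t_list
        List[Any],  # bucket_cnt_sharded_t_list
        List[Any],  # metadata_sharded_t_list (new)
    ]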
@@ -1600,6 +1620,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
@@ -1608,6 +1629,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         PMT for embedding table with a valid RocksDB snapshot to support tensor IO
         optional ShardedTensor for weight_id
         optional ShardedTensor for bucket_cnt
+        optional ShardedTensor for metadata
         """
         self._init_sharded_split_embedding_weights()
         # pyre-ignore[16]
@@ -1616,13 +1638,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         pmt_sharded_t_list = self._split_weights_res[0]
         weight_id_sharded_t_list = self._split_weights_res[1]
         bucket_cnt_sharded_t_list = self._split_weights_res[2]
+        metadata_sharded_t_list = self._split_weights_res[3]
         for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
             table_config = self._config.embedding_tables[table_idx]
             key = append_prefix(prefix, f"{table_config.name}")
 
             yield key, pmt_sharded_t, weight_id_sharded_t_list[
                 table_idx
-            ], bucket_cnt_sharded_t_list[table_idx]
+            ], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t_list[table_idx]
 
     def flush(self) -> None:
         """
@@ -1651,6 +1674,7 @@ def split_embedding_weights(
         Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
     ]:
         return self.emb_module.split_embedding_weights(no_snapshot, should_flush)
 
@@ -2079,7 +2103,7 @@ def state_dict(
         # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
         # ShardedEmbeddingBagCollection._pre_state_dict_hook()
 
-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
             emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -2129,6 +2153,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
@@ -2137,13 +2162,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         RocksDB snapshot to support windowed access.
         optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch
         optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch
+        optional ShardedTensor for metadata, this won't be used here as this is non-kvzch
         """
         for config, tensor in zip(
             self._config.embedding_tables,
             self.split_embedding_weights(no_snapshot=False)[0],
         ):
             key = append_prefix(prefix, f"{config.name}")
-            yield key, tensor, None, None
+            yield key, tensor, None, None, None
 
     def flush(self) -> None:
         """
@@ -2170,6 +2196,7 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
         List[PartiallyMaterializedTensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
     ]:
         # pyre-fixme[7]: Expected `Tuple[List[PartiallyMaterializedTensor],
         #  Optional[List[Tensor]], Optional[List[Tensor]]]` but got
@@ -2223,6 +2250,7 @@ def __init__(
                 List[ShardedTensor],
                 List[ShardedTensor],
                 List[ShardedTensor],
+                List[ShardedTensor],
             ]
         ] = None
 
@@ -2298,7 +2326,7 @@ def state_dict(
         # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
         # ShardedEmbeddingBagCollection._pre_state_dict_hook()
 
-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
             emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -2354,8 +2382,10 @@ def _init_sharded_split_embedding_weights(
         if not force_regenerate and self._split_weights_res is not None:
             return
 
-        pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
-            no_snapshot=False,
+        pmt_list, weight_ids_list, bucket_cnt_list, metadata_list = (
+            self.split_embedding_weights(
+                no_snapshot=False,
+            )
         )
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
@@ -2389,17 +2419,31 @@ def _init_sharded_split_embedding_weights(
             self._table_name_to_weight_count_per_rank,
             use_param_size_as_rows=True,
         )
-        # pyre-ignore
-        assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
+        metadata_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            metadata_list,  # pyre-ignore [6]
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+        )
+
+        assert (
+            len(pmt_list)
+            == len(weight_ids_list)  # pyre-ignore
+            == len(bucket_cnt_list)  # pyre-ignore
+            == len(metadata_list)  # pyre-ignore
+        )
         assert (
             len(pmt_sharded_t_list)
             == len(weight_id_sharded_t_list)
            == len(bucket_cnt_sharded_t_list)
+            == len(metadata_sharded_t_list)
         )
         self._split_weights_res = (
             pmt_sharded_t_list,
             weight_id_sharded_t_list,
             bucket_cnt_sharded_t_list,
+            metadata_sharded_t_list,
         )
 
     def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
@@ -2408,6 +2452,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
@@ -2416,6 +2461,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         PMT for embedding table with a valid RocksDB snapshot to support tensor IO
         optional ShardedTensor for weight_id
         optional ShardedTensor for bucket_cnt
+        optional ShardedTensor for metadata
         """
         self._init_sharded_split_embedding_weights()
         # pyre-ignore[16]
@@ -2424,13 +2470,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         pmt_sharded_t_list = self._split_weights_res[0]
         weight_id_sharded_t_list = self._split_weights_res[1]
         bucket_cnt_sharded_t_list = self._split_weights_res[2]
+        metadata_sharded_t_list = self._split_weights_res[3]
        for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
             table_config = self._config.embedding_tables[table_idx]
             key = append_prefix(prefix, f"{table_config.name}")
 
             yield key, pmt_sharded_t, weight_id_sharded_t_list[
                 table_idx
-            ], bucket_cnt_sharded_t_list[table_idx]
+            ], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t_list[table_idx]
 
     def flush(self) -> None:
         """
@@ -2459,6 +2506,7 @@ def split_embedding_weights(
         Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
     ]:
         return self.emb_module.split_embedding_weights(no_snapshot, should_flush)
 
torchrec/distributed/embedding.py

Lines changed: 13 additions & 0 deletions
@@ -698,10 +698,13 @@ def _pre_load_state_dict_hook(
             weight_key = f"{prefix}embeddings.{table_name}.weight"
             weight_id_key = f"{prefix}embeddings.{table_name}.weight_id"
             bucket_key = f"{prefix}embeddings.{table_name}.bucket"
+            metadata_key = f"{prefix}embeddings.{table_name}.metadata"
             if weight_id_key in state_dict:
                 del state_dict[weight_id_key]
             if bucket_key in state_dict:
                 del state_dict[bucket_key]
+            if metadata_key in state_dict:
+                del state_dict[metadata_key]
             assert weight_key in state_dict
             assert (
                 len(self._model_parallel_name_to_local_shards[table_name]) == 1
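
The pre-load hook treats the new metadata entry the same way as weight_id and bucket: it is dropped from the incoming state dict before the weight itself is loaded. A small self-contained sketch of that behavior (keys, shapes, and dtypes are illustrative only):

    import torch

    state_dict = {
        "embeddings.t1.weight": torch.zeros(4, 8),
        "embeddings.t1.weight_id": torch.zeros(4, dtype=torch.int64),
        "embeddings.t1.bucket": torch.zeros(2, dtype=torch.int64),
        "embeddings.t1.metadata": torch.zeros(4, dtype=torch.int64),
    }

    # Drop the auxiliary tensors (weight_id, bucket, and now metadata) before loading;
    # only the weight entry remains.
    for suffix in ("weight_id", "bucket", "metadata"):
        state_dict.pop(f"embeddings.t1.{suffix}", None)

    assert set(state_dict) == {"embeddings.t1.weight"}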
@@ -1037,6 +1040,7 @@ def post_state_dict_hook(
                     weights_t,
                     weight_ids_sharded_t,
                     id_cnt_per_bucket_sharded_t,
+                    metadata_sharded_t,
                 ) in (
                     lookup.get_named_split_embedding_weights_snapshot()  # pyre-ignore
                 ):
@@ -1048,19 +1052,22 @@ def post_state_dict_hook(
                         assert (
                             weight_ids_sharded_t is not None
                             and id_cnt_per_bucket_sharded_t is not None
+                            and metadata_sharded_t is not None
                         )
                         # The logic here assumes there is only one shard per table on any particular rank
                         # if there are cases each rank has >1 shards, we need to update here accordingly
                         sharded_kvtensors_copy[table_name] = weights_t
                         virtual_table_sharded_t_map[table_name] = (
                             weight_ids_sharded_t,
                             id_cnt_per_bucket_sharded_t,
+                            metadata_sharded_t,
                         )
                     else:
                         assert isinstance(weights_t, PartiallyMaterializedTensor)
                         assert (
                             weight_ids_sharded_t is None
                             and id_cnt_per_bucket_sharded_t is None
+                            and metadata_sharded_t is None
                         )
                         # The logic here assumes there is only one shard per table on any particular rank
                         # if there are cases each rank has >1 shards, we need to update here accordingly
@@ -1099,6 +1106,12 @@ def update_destination(
                     destination,
                     virtual_table_sharded_t_map[table_name][1],
                 )
+                update_destination(
+                    table_name,
+                    "metadata",
+                    destination,
+                    virtual_table_sharded_t_map[table_name][2],
+                )
 
     def _post_load_state_dict_hook(
         module: "ShardedEmbeddingCollection",

torchrec/distributed/embedding_lookup.py

Lines changed: 2 additions & 0 deletions
@@ -381,6 +381,7 @@ def get_named_split_embedding_weights_snapshot(
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
@@ -732,6 +733,7 @@ def get_named_split_embedding_weights_snapshot(
             Union[ShardedTensor, PartiallyMaterializedTensor],
             Optional[ShardedTensor],
             Optional[ShardedTensor],
+            Optional[ShardedTensor],
         ]
     ]:
         """
