@@ -958,7 +958,7 @@ def _gen_named_parameters_by_table_ssd_pmt(
    name as well as the parameter itself. The embedding table is in the form of
    PartiallyMaterializedTensor to support windowed access.
    """
-    pmts, _, _ = emb_module.split_embedding_weights()
+    pmts, _, _, _ = emb_module.split_embedding_weights()
    for table_config, pmt in zip(config.embedding_tables, pmts):
        table_name = table_config.name
        emb_table = pmt
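# Illustrative sketch (not part of the diff): the calling-convention change above, shown on
# hypothetical stand-in objects. split_embedding_weights() now returns a 4-tuple
# (weights, weight_ids, bucket_cnts, metadata), so three-way unpacking no longer works.
# `FakeEmbModule` and `FakeTableConfig` are assumptions used only to make the snippet
# runnable; they are not the real fbgemm/torchrec classes.
from typing import List, NamedTuple, Optional, Tuple

import torch


class FakeTableConfig(NamedTuple):
    name: str


class FakeEmbModule:
    def split_embedding_weights(
        self,
    ) -> Tuple[
        List[torch.Tensor],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
    ]:
        # weights, weight_ids, bucket_cnts, metadata -- the fourth slot is the new one.
        return [torch.zeros(4, 8)], None, None, None


emb_module = FakeEmbModule()
tables = [FakeTableConfig(name="t1")]

pmts, _, _, _ = emb_module.split_embedding_weights()  # four values, not three
for table_config, pmt in zip(tables, pmts):
    print(table_config.name, pmt.shape)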
@@ -1272,7 +1272,7 @@ def state_dict(
        # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
        # ShardedEmbeddingBagCollection._pre_state_dict_hook()

-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
        for emb_table in emb_table_config_copy:
            emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -1322,6 +1322,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
            Union[ShardedTensor, PartiallyMaterializedTensor],
            Optional[ShardedTensor],
            Optional[ShardedTensor],
+            Optional[ShardedTensor],
        ]
    ]:
        """
@@ -1330,13 +1331,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
        RocksDB snapshot to support windowed access.
        optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch
        optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch
+        optional ShardedTensor for metadata, this won't be used here as this is non-kvzch
        """
        for config, tensor in zip(
            self._config.embedding_tables,
            self.split_embedding_weights(no_snapshot=False)[0],
        ):
            key = append_prefix(prefix, f"{config.name}")
-            yield key, tensor, None, None
+            yield key, tensor, None, None, None

    def flush(self) -> None:
        """
@@ -1364,6 +1366,7 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
        List[PartiallyMaterializedTensor],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
    ]:
        # pyre-fixme[7]: Expected `Tuple[List[PartiallyMaterializedTensor],
        #  Optional[List[Tensor]], Optional[List[Tensor]]]` but got
@@ -1415,6 +1418,7 @@ def __init__(
                List[ShardedTensor],
                List[ShardedTensor],
                List[ShardedTensor],
+                List[ShardedTensor],
            ]
        ] = None

@@ -1490,7 +1494,7 @@ def state_dict(
        # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
        # ShardedEmbeddingBagCollection._pre_state_dict_hook()

-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
        for emb_table in emb_table_config_copy:
            emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -1546,8 +1550,10 @@ def _init_sharded_split_embedding_weights(
        if not force_regenerate and self._split_weights_res is not None:
            return

-        pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
-            no_snapshot=False,
+        pmt_list, weight_ids_list, bucket_cnt_list, metadata_list = (
+            self.split_embedding_weights(
+                no_snapshot=False,
+            )
        )
        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
        for emb_table in emb_table_config_copy:
@@ -1581,17 +1587,31 @@ def _init_sharded_split_embedding_weights(
            self._table_name_to_weight_count_per_rank,
            use_param_size_as_rows=True,
        )
-        # pyre-ignore
-        assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
+        metadata_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            metadata_list,  # pyre-ignore [6]
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+        )
+
+        assert (
+            len(pmt_list)
+            == len(weight_ids_list)  # pyre-ignore
+            == len(bucket_cnt_list)  # pyre-ignore
+            == len(metadata_list)  # pyre-ignore
+        )
        assert (
            len(pmt_sharded_t_list)
            == len(weight_id_sharded_t_list)
            == len(bucket_cnt_sharded_t_list)
+            == len(metadata_sharded_t_list)
        )
        self._split_weights_res = (
            pmt_sharded_t_list,
            weight_id_sharded_t_list,
            bucket_cnt_sharded_t_list,
+            metadata_sharded_t_list,
        )

    def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
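# Illustrative sketch (assumption): the alignment invariant enforced above, factored into a
# tiny reusable check. All four per-table lists (PMTs, weight_ids, bucket counts and the new
# metadata) must stay the same length so that index/zip based pairing with the table configs
# remains valid. `check_aligned` is a hypothetical helper, not part of torchrec.
from typing import Sequence


def check_aligned(*lists: Sequence) -> None:
    lengths = {len(lst) for lst in lists}
    assert len(lengths) == 1, f"per-table lists out of sync: {[len(lst) for lst in lists]}"


# e.g. check_aligned(pmt_list, weight_ids_list, bucket_cnt_list, metadata_list)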
@@ -1600,6 +1620,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
            Union[ShardedTensor, PartiallyMaterializedTensor],
            Optional[ShardedTensor],
            Optional[ShardedTensor],
+            Optional[ShardedTensor],
        ]
    ]:
        """
@@ -1608,6 +1629,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
        PMT for embedding table with a valid RocksDB snapshot to support tensor IO
        optional ShardedTensor for weight_id
        optional ShardedTensor for bucket_cnt
+        optional ShardedTensor for metadata
        """
        self._init_sharded_split_embedding_weights()
        # pyre-ignore[16]
@@ -1616,13 +1638,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
        pmt_sharded_t_list = self._split_weights_res[0]
        weight_id_sharded_t_list = self._split_weights_res[1]
        bucket_cnt_sharded_t_list = self._split_weights_res[2]
+        metadata_sharded_t_list = self._split_weights_res[3]
        for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
            table_config = self._config.embedding_tables[table_idx]
            key = append_prefix(prefix, f"{table_config.name}")

            yield key, pmt_sharded_t, weight_id_sharded_t_list[
                table_idx
-            ], bucket_cnt_sharded_t_list[table_idx]
+            ], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t_list[table_idx]

    def flush(self) -> None:
        """
@@ -1651,6 +1674,7 @@ def split_embedding_weights(
        Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
    ]:
        return self.emb_module.split_embedding_weights(no_snapshot, should_flush)

@@ -2079,7 +2103,7 @@ def state_dict(
        # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
        # ShardedEmbeddingBagCollection._pre_state_dict_hook()

-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
        for emb_table in emb_table_config_copy:
            emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -2129,6 +2153,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
            Union[ShardedTensor, PartiallyMaterializedTensor],
            Optional[ShardedTensor],
            Optional[ShardedTensor],
+            Optional[ShardedTensor],
        ]
    ]:
        """
@@ -2137,13 +2162,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
        RocksDB snapshot to support windowed access.
        optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch
        optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch
+        optional ShardedTensor for metadata, this won't be used here as this is non-kvzch
        """
        for config, tensor in zip(
            self._config.embedding_tables,
            self.split_embedding_weights(no_snapshot=False)[0],
        ):
            key = append_prefix(prefix, f"{config.name}")
-            yield key, tensor, None, None
+            yield key, tensor, None, None, None

    def flush(self) -> None:
        """
@@ -2170,6 +2196,7 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
        List[PartiallyMaterializedTensor],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
    ]:
        # pyre-fixme[7]: Expected `Tuple[List[PartiallyMaterializedTensor],
        #  Optional[List[Tensor]], Optional[List[Tensor]]]` but got
@@ -2223,6 +2250,7 @@ def __init__(
                List[ShardedTensor],
                List[ShardedTensor],
                List[ShardedTensor],
+                List[ShardedTensor],
            ]
        ] = None

@@ -2298,7 +2326,7 @@ def state_dict(
        # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
        # ShardedEmbeddingBagCollection._pre_state_dict_hook()

-        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables, _, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
        for emb_table in emb_table_config_copy:
            emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -2354,8 +2382,10 @@ def _init_sharded_split_embedding_weights(
        if not force_regenerate and self._split_weights_res is not None:
            return

-        pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
-            no_snapshot=False,
+        pmt_list, weight_ids_list, bucket_cnt_list, metadata_list = (
+            self.split_embedding_weights(
+                no_snapshot=False,
+            )
        )
        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
        for emb_table in emb_table_config_copy:
@@ -2389,17 +2419,31 @@ def _init_sharded_split_embedding_weights(
            self._table_name_to_weight_count_per_rank,
            use_param_size_as_rows=True,
        )
-        # pyre-ignore
-        assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
+        metadata_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            metadata_list,  # pyre-ignore [6]
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+        )
+
+        assert (
+            len(pmt_list)
+            == len(weight_ids_list)  # pyre-ignore
+            == len(bucket_cnt_list)  # pyre-ignore
+            == len(metadata_list)  # pyre-ignore
+        )
        assert (
            len(pmt_sharded_t_list)
            == len(weight_id_sharded_t_list)
            == len(bucket_cnt_sharded_t_list)
+            == len(metadata_sharded_t_list)
        )
        self._split_weights_res = (
            pmt_sharded_t_list,
            weight_id_sharded_t_list,
            bucket_cnt_sharded_t_list,
+            metadata_sharded_t_list,
        )

    def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
@@ -2408,6 +2452,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
            Union[ShardedTensor, PartiallyMaterializedTensor],
            Optional[ShardedTensor],
            Optional[ShardedTensor],
+            Optional[ShardedTensor],
        ]
    ]:
        """
@@ -2416,6 +2461,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
        PMT for embedding table with a valid RocksDB snapshot to support tensor IO
        optional ShardedTensor for weight_id
        optional ShardedTensor for bucket_cnt
+        optional ShardedTensor for metadata
        """
        self._init_sharded_split_embedding_weights()
        # pyre-ignore[16]
@@ -2424,13 +2470,14 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
        pmt_sharded_t_list = self._split_weights_res[0]
        weight_id_sharded_t_list = self._split_weights_res[1]
        bucket_cnt_sharded_t_list = self._split_weights_res[2]
+        metadata_sharded_t_list = self._split_weights_res[3]
        for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
            table_config = self._config.embedding_tables[table_idx]
            key = append_prefix(prefix, f"{table_config.name}")

            yield key, pmt_sharded_t, weight_id_sharded_t_list[
                table_idx
-            ], bucket_cnt_sharded_t_list[table_idx]
+            ], bucket_cnt_sharded_t_list[table_idx], metadata_sharded_t_list[table_idx]

    def flush(self) -> None:
        """
@@ -2459,6 +2506,7 @@ def split_embedding_weights(
        Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
    ]:
        return self.emb_module.split_embedding_weights(no_snapshot, should_flush)
