@@ -795,7 +795,7 @@ def execute_ssd_backward_(
795
795
def split_optimizer_states_ (
796
796
self , emb : SSDTableBatchedEmbeddingBags
797
797
) -> List [List [torch .Tensor ]]:
798
- _ , bucket_asc_ids_list , _ = emb .split_embedding_weights (
798
+ _ , bucket_asc_ids_list , _ , _ = emb .split_embedding_weights (
799
799
no_snapshot = False , should_flush = True
800
800
)
801
801
@@ -1289,7 +1289,7 @@ def test_ssd_emb_state_dict(
1289
1289
split_optimizer_states = self .split_optimizer_states_ (emb )
1290
1290
1291
1291
# Compare emb state dict with expected values from nn.EmbeddingBag
1292
- emb_state_dict , _ , _ = emb .split_embedding_weights (no_snapshot = False )
1292
+ emb_state_dict , _ , _ , _ = emb .split_embedding_weights (no_snapshot = False )
1293
1293
for feature_index , table_index in self .get_physical_table_arg_indices_ (
1294
1294
emb .feature_table_map
1295
1295
):
@@ -1904,9 +1904,12 @@ def test_kv_emb_state_dict(
1904
1904
split_optimizer_states = []
1905
1905
1906
1906
# Compare emb state dict with expected values from nn.EmbeddingBag
1907
- emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list = (
1908
- emb .split_embedding_weights (no_snapshot = False , should_flush = True )
1909
- )
1907
+ (
1908
+ emb_state_dict_list ,
1909
+ bucket_asc_ids_list ,
1910
+ num_active_id_per_bucket_list ,
1911
+ metadata_list ,
1912
+ ) = emb .split_embedding_weights (no_snapshot = False , should_flush = True )
1910
1913
1911
1914
for s in emb .split_optimizer_states (
1912
1915
bucket_asc_ids_list , no_snapshot = False , should_flush = True
@@ -1973,6 +1976,7 @@ def test_kv_emb_state_dict(
1973
1976
)
1974
1977
self .assertLess (table_index , len (emb_state_dict_list ))
1975
1978
assert len (split_optimizer_states [table_index ][0 ]) == num_ids
1979
+ assert len (metadata_list [table_index ]) == num_ids
1976
1980
# NOTE: The [0] index is a hack since the test is fixed to use
1977
1981
# EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
1978
1982
# be upgraded in the future to support multiple optimizers
@@ -2119,7 +2123,7 @@ def test_kv_opt_state_w_offloading(
2119
2123
)
2120
2124
2121
2125
# Compare emb state dict with expected values from nn.EmbeddingBag
2122
- emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list = (
2126
+ emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list , _ = (
2123
2127
emb .split_embedding_weights (no_snapshot = False , should_flush = True )
2124
2128
)
2125
2129
split_optimizer_states = emb .split_optimizer_states (
@@ -2348,7 +2352,7 @@ def test_kv_state_dict_w_backend_return_whole_row(
2348
2352
)
2349
2353
2350
2354
# Compare emb state dict with expected values from nn.EmbeddingBag
2351
- emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list = (
2355
+ emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list , _ = (
2352
2356
emb .split_embedding_weights (no_snapshot = False , should_flush = True )
2353
2357
)
2354
2358
split_optimizer_states = emb .split_optimizer_states (
@@ -2616,7 +2620,7 @@ def test_apply_kv_state_dict(
2616
2620
)
2617
2621
2618
2622
# Compare emb state dict with expected values from nn.EmbeddingBag
2619
- emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list = (
2623
+ emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list , _ = (
2620
2624
emb .split_embedding_weights (no_snapshot = False , should_flush = True )
2621
2625
)
2622
2626
split_optimizer_states = emb .split_optimizer_states (
@@ -2684,6 +2688,7 @@ def test_apply_kv_state_dict(
2684
2688
emb_state_dict_list2 ,
2685
2689
bucket_asc_ids_list2 ,
2686
2690
num_active_id_per_bucket_list2 ,
2691
+ _ ,
2687
2692
) = emb2 .split_embedding_weights (no_snapshot = False , should_flush = True )
2688
2693
split_optimizer_states2 = emb2 .split_optimizer_states (
2689
2694
bucket_asc_ids_list2 , no_snapshot = False , should_flush = True
@@ -3139,7 +3144,7 @@ def copy_opt_states_hook(
3139
3144
emb .flush ()
3140
3145
3141
3146
# Compare emb state dict with expected values from nn.EmbeddingBag
3142
- _emb_state_dict_list , bucket_asc_ids_list , _num_active_id_per_bucket_list = (
3147
+ _emb_state_dict_list , bucket_asc_ids_list , _num_active_id_per_bucket_list , _ = (
3143
3148
emb .split_embedding_weights (no_snapshot = False , should_flush = True )
3144
3149
)
3145
3150
assert bucket_asc_ids_list is not None
0 commit comments