
Commit b87e654

EddyLXJ authored and facebook-github-bot committed
Adding eviction metadata tensor fqn (#4611)
Summary:
X-link: pytorch/torchrec#3247
Pull Request resolved: #4611
X-link: facebookresearch/FBGEMM#1646

Adding a new metadata fqn in the kvzch checkpoint, which is needed for the eviction filter in publishing.

Reviewed By: emlin

Differential Revision: D78768842

fbshipit-source-id: eb5b648dc9395c2e02b0b851fb6ce2b2304d6113
1 parent c531bc5 commit b87e654
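
For orientation, the net effect of the change is that `SSDTableBatchedEmbeddingBags.split_embedding_weights` now returns a four-element tuple, with the new kvzch eviction metadata as the last element. A minimal caller-side sketch follows; the variable names are illustrative, and `emb` is assumed to be an already-constructed `SSDTableBatchedEmbeddingBags`, as in the updated tests. Only the tuple shape and keyword arguments come from the diff below.

# Illustrative sketch of the extended return value (names are not from the source).
(
    weights_per_table,            # 1st: per-table weights (PartiallyMaterializedTensor or Tensor)
    bucket_sorted_ids_per_table,  # 2nd: input ids sorted in bucket id ascending order
    active_id_cnt_per_bucket,     # 3rd: active id count per bucket id
    eviction_metadata_per_table,  # 4th (new): kvzch eviction metadata per sorted input id
) = emb.split_embedding_weights(no_snapshot=False, should_flush=True)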

File tree: 2 files changed (+37, −10 lines)

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 23 additions & 1 deletion
@@ -2890,6 +2890,7 @@ def split_embedding_weights(
         Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
     ]:
         """
         This method is intended to be used by the checkpointing engine
@@ -2909,6 +2910,7 @@ def split_embedding_weights(
         2nd arg: input id sorted in bucket id ascending order
         3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
             where for the i th element, we have i + bucket_id_start = global bucket id
+        4th arg: kvzch eviction metadata for each input id sorted in bucket id ascending order
         """
         snapshot_handle, checkpoint_handle = self._may_create_snapshot_for_state_dict(
             no_snapshot=no_snapshot,
@@ -2925,16 +2927,19 @@ def split_embedding_weights(
                 self._cached_kvzch_data.cached_weight_tensor_per_table,
                 self._cached_kvzch_data.cached_id_tensor_per_table,
                 self._cached_kvzch_data.cached_bucket_splits,
+                [],  # metadata tensor is not needed for checkpointing loading
             )
         start_time = time.time()
         pmt_splits = []
         bucket_sorted_id_splits = [] if self.kv_zch_params else None
         active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
+        metadata_splits = [] if self.kv_zch_params else None

         table_offset = 0
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
             bucket_ascending_id_tensor = None
             bucket_t = None
+            metadata_tensor = None
             row_offset = table_offset
             metaheader_dim = 0
             if self.kv_zch_params:
@@ -2966,6 +2971,12 @@ def split_embedding_weights(
                         bucket_size,
                     )
                 )
+                metadata_tensor = self._ssd_db.get_kv_zch_eviction_metadata_by_snapshot(
+                    bucket_ascending_id_tensor,
+                    torch.as_tensor(bucket_ascending_id_tensor.size(0)),
+                    snapshot_handle,
+                )
+
                 # 3. convert local id back to global id
                 bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)

@@ -2981,11 +2992,17 @@ def split_embedding_weights(
                     device=torch.device("cpu"),
                     dtype=torch.int64,
                 )
+                metadata_tensor = torch.zeros(
+                    (self.local_weight_counts[i], 1),
+                    device=torch.device("cpu"),
+                    dtype=torch.int64,
+                )
                 # self.local_weight_counts[i] = 0 # Reset the count

             # pyre-ignore [16] bucket_sorted_id_splits is not None
             bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
             active_id_cnt_per_bucket_split.append(bucket_t)
+            metadata_splits.append(metadata_tensor)

             # for KV ZCH tbe, the sorted_indices is global id for checkpointing and publishing
             # but in backend, local id is used during training, so the KVTensorWrapper need to convert global id to local id
@@ -3041,7 +3058,12 @@ def split_embedding_weights(
                 f"num ids list: {[ids.numel() for ids in bucket_sorted_id_splits]}"
             )

-        return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)
+        return (
+            pmt_splits,
+            bucket_sorted_id_splits,
+            active_id_cnt_per_bucket_split,
+            metadata_splits,
+        )

     @torch.jit.ignore
     def _apply_state_dict_w_offloading(self) -> None:
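
The commit summary states the metadata is consumed by an eviction filter in publishing. That consumer, and the encoding of the metadata values, are not shown in this diff, so the following is only a hypothetical sketch: it assumes a per-row int64 metadata column (matching the placeholder tensor of shape `(num_ids, 1)` above) and a made-up rule that a zero value marks a row to keep. The function name and the predicate are illustrative, not part of this commit.

# Hypothetical publishing-side eviction filter; the meaning of the metadata
# values is an assumption, not something this commit defines.
import torch

def filter_evicted_rows(
    ids: torch.Tensor,       # global ids, sorted in bucket id ascending order
    weights: torch.Tensor,   # one row of weights per id
    metadata: torch.Tensor,  # int64 eviction metadata, shape (num_ids, 1)
) -> tuple[torch.Tensor, torch.Tensor]:
    # Assumption: nonzero metadata marks a row flagged by the eviction policy;
    # keep only rows whose metadata is zero.
    keep_mask = metadata.view(-1) == 0
    return ids[keep_mask], weights[keep_mask]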

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 14 additions & 9 deletions
@@ -795,7 +795,7 @@ def execute_ssd_backward_(
     def split_optimizer_states_(
         self, emb: SSDTableBatchedEmbeddingBags
     ) -> List[List[torch.Tensor]]:
-        _, bucket_asc_ids_list, _ = emb.split_embedding_weights(
+        _, bucket_asc_ids_list, _, _ = emb.split_embedding_weights(
             no_snapshot=False, should_flush=True
         )

@@ -1289,7 +1289,7 @@ def test_ssd_emb_state_dict(
         split_optimizer_states = self.split_optimizer_states_(emb)

         # Compare emb state dict with expected values from nn.EmbeddingBag
-        emb_state_dict, _, _ = emb.split_embedding_weights(no_snapshot=False)
+        emb_state_dict, _, _, _ = emb.split_embedding_weights(no_snapshot=False)
         for feature_index, table_index in self.get_physical_table_arg_indices_(
             emb.feature_table_map
         ):
@@ -1904,9 +1904,12 @@ def test_kv_emb_state_dict(
         split_optimizer_states = []

         # Compare emb state dict with expected values from nn.EmbeddingBag
-        emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
-            emb.split_embedding_weights(no_snapshot=False, should_flush=True)
-        )
+        (
+            emb_state_dict_list,
+            bucket_asc_ids_list,
+            num_active_id_per_bucket_list,
+            metadata_list,
+        ) = emb.split_embedding_weights(no_snapshot=False, should_flush=True)

         for s in emb.split_optimizer_states(
             bucket_asc_ids_list, no_snapshot=False, should_flush=True
@@ -1973,6 +1976,7 @@ def test_kv_emb_state_dict(
             )
             self.assertLess(table_index, len(emb_state_dict_list))
             assert len(split_optimizer_states[table_index][0]) == num_ids
+            assert len(metadata_list[table_index]) == num_ids
             # NOTE: The [0] index is a hack since the test is fixed to use
             # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
             # be upgraded in the future to support multiple optimizers
@@ -2119,7 +2123,7 @@ def test_kv_opt_state_w_offloading(
         )

         # Compare emb state dict with expected values from nn.EmbeddingBag
-        emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
+        emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list, _ = (
             emb.split_embedding_weights(no_snapshot=False, should_flush=True)
         )
         split_optimizer_states = emb.split_optimizer_states(
@@ -2348,7 +2352,7 @@ def test_kv_state_dict_w_backend_return_whole_row(
         )

         # Compare emb state dict with expected values from nn.EmbeddingBag
-        emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
+        emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list, _ = (
             emb.split_embedding_weights(no_snapshot=False, should_flush=True)
         )
         split_optimizer_states = emb.split_optimizer_states(
@@ -2616,7 +2620,7 @@ def test_apply_kv_state_dict(
         )

         # Compare emb state dict with expected values from nn.EmbeddingBag
-        emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
+        emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list, _ = (
             emb.split_embedding_weights(no_snapshot=False, should_flush=True)
         )
         split_optimizer_states = emb.split_optimizer_states(
@@ -2684,6 +2688,7 @@ def test_apply_kv_state_dict(
             emb_state_dict_list2,
             bucket_asc_ids_list2,
             num_active_id_per_bucket_list2,
+            _,
         ) = emb2.split_embedding_weights(no_snapshot=False, should_flush=True)
         split_optimizer_states2 = emb2.split_optimizer_states(
             bucket_asc_ids_list2, no_snapshot=False, should_flush=True
@@ -3139,7 +3144,7 @@ def copy_opt_states_hook(
         emb.flush()

         # Compare emb state dict with expected values from nn.EmbeddingBag
-        _emb_state_dict_list, bucket_asc_ids_list, _num_active_id_per_bucket_list = (
+        _emb_state_dict_list, bucket_asc_ids_list, _num_active_id_per_bucket_list, _ = (
             emb.split_embedding_weights(no_snapshot=False, should_flush=True)
         )
         assert bucket_asc_ids_list is not None
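
Mirroring the new assertion in `test_kv_emb_state_dict`, a small per-table consistency check over the returned splits could look like the sketch below. The "one metadata row per sorted id" expectation comes from the test diff; the helper itself is illustrative and not part of this commit.

# Sketch: verify one metadata row per sorted input id for each table.
from typing import List
import torch

def check_metadata_matches_ids(
    bucket_asc_ids_list: List[torch.Tensor],
    metadata_list: List[torch.Tensor],
) -> None:
    # Both lists are assumed to come from
    # emb.split_embedding_weights(no_snapshot=False, should_flush=True).
    for table_index, (ids, metadata) in enumerate(zip(bucket_asc_ids_list, metadata_list)):
        assert len(metadata) == ids.numel(), f"table {table_index}: row count mismatch"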
