Commit 718234b
kvzch use new operator in model publish (#3108)

Chenyu Zhang authored and facebook-github-bot committed

Summary: Publish change to enable KVEmbeddingInference when use_virtual_table is set to true.

Reviewed By: emlin

Differential Revision: D75321284
1 parent a5d2d12 commit 718234b
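
To put the summary in concrete terms, the sketch below shows what "use_virtual_table is set to true" means on the model side. The diffs that follow only read a per-table flag (`table.use_virtual_table` / `config.use_virtual_table`); the exact `EmbeddingBagConfig` arguments here are illustrative assumptions rather than part of this commit.

```python
# Minimal sketch (assumptions, not part of this commit): a table marked with
# use_virtual_table=True is what the diffs below route to KVEmbeddingInference;
# a table without the flag keeps the IntNBitTableBatchedEmbeddingBagsCodegen path.
from torchrec.modules.embedding_configs import EmbeddingBagConfig

virtual_table = EmbeddingBagConfig(
    name="t_kv_zch",
    embedding_dim=16,
    num_embeddings=1_000_000,
    feature_names=["f_kv"],
    use_virtual_table=True,  # the flag read as table.use_virtual_table in the diffs below
)
regular_table = EmbeddingBagConfig(
    name="t_dense",
    embedding_dim=16,
    num_embeddings=10_000,
    feature_names=["f_dense"],
)

# Per the torchrec/quant/embedding_modules.py diff, the quantized modules group
# tables by this flag, so the two tables above end up in different TBE groups.
assert virtual_table.use_virtual_table and not regular_table.use_virtual_table
```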

File tree: 3 files changed (+107, -40 lines)


torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -292,6 +292,7 @@ def create_sharding_infos_by_sharding_device_group(
                     getattr(config, "num_embeddings_post_pruning", None)
                     # TODO: Need to check if attribute exists for BC
                 ),
+                use_virtual_table=config.use_virtual_table,
             ),
             param_sharding=parameter_sharding,
             param=param,
```

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 90 additions & 30 deletions
```diff
@@ -20,6 +20,7 @@
     PoolingMode,
     rounded_row_size_in_bytes,
 )
+from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference
 from torchrec.distributed.batched_embedding_kernel import (
     BaseBatchedEmbedding,
     BaseBatchedEmbeddingBag,
@@ -119,6 +120,32 @@ def _quantize_weight(
     return quant_weight_list


+def _get_shard_offsets_for_kv_zch(
+    config: GroupedEmbeddingConfig,
+    shard_index: int,
+) -> List[int]:
+    """
+    Given kv zch tables are rw sharded, getting the row offsets for each shard
+    at level to be used witin kv zch look up kernel
+    """
+    shard_row_offsets = []
+    for table in config.embedding_tables:
+        assert (
+            table.global_metadata is not None
+        ), f"Expected global_metadata to be populated for table {table.name} to get shard offsets for kv zch look up kernel"
+        assert (
+            len(table.global_metadata.shards_metadata) > shard_index
+        ), f"Expected table {table.name} to have more shards than shard index {shard_index}. Found {len(table.global_metadata.shards_metadata)} shards"
+        shard_row_offsets.append(
+            # pyre-ignore: Undefined attribute [16]
+            table.global_metadata.shards_metadata[shard_index].shard_offsets[0]
+        )
+    logger.info(
+        f"Shard row offsets for kv zch look up table {config.embedding_names=}: {shard_row_offsets=}"
+    )
+    return shard_row_offsets
+
+
 def _get_runtime_device(
     device: Optional[torch.device],
     config: GroupedEmbeddingConfig,
```
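
The new helper boils down to one fact about row-wise sharding: the row offset of shard `i` is the first component of that shard's `shard_offsets` in the sharded-tensor metadata. The toy snippet below illustrates just that lookup with `torch.distributed`'s `ShardMetadata` and made-up shard sizes; it is not TorchRec code.

```python
# Toy illustration of the row-offset lookup performed by _get_shard_offsets_for_kv_zch:
# for row-wise sharding, shards_metadata[shard_index].shard_offsets[0] is the row
# offset of that shard within the full (virtual) table.
from torch.distributed._shard.metadata import ShardMetadata

shards_metadata = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[500, 16], placement="rank:0/cpu"),
    ShardMetadata(shard_offsets=[500, 0], shard_sizes=[500, 16], placement="rank:1/cpu"),
]

shard_index = 1
row_offset = shards_metadata[shard_index].shard_offsets[0]
assert row_offset == 500  # rows [500, 1000) live on rank 1
```

The remaining hunks in this file wire the flag and these offsets into the two quantized kernel constructors: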
```diff
@@ -237,13 +264,16 @@ def __init__(
         super().__init__(config, pg, device)

         managed: List[EmbeddingLocation] = []
+        is_virtual_table: bool = False
         for table in config.embedding_tables:
             if device is not None and device.type == "cuda":
                 managed.append(
                     compute_kernel_to_embedding_location(table.compute_kernel)
                 )
             else:
                 managed.append(EmbeddingLocation.HOST)
+            if table.use_virtual_table:
+                is_virtual_table = True
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
         self._is_weighted: Optional[bool] = config.is_weighted
@@ -284,9 +314,21 @@ def __init__(

         if self.lengths_to_tbe:
             tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
+        elif is_virtual_table:
+            tbe_clazz = KVEmbeddingInference
         else:
             tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen

+        if is_virtual_table:
+            assert (
+                shard_index is not None and shard_index >= 0
+            ), "valid shard_index must be provided for kv zch batch embedding to compute shard offsets"
+            shard_offsets_for_kv_zch = _get_shard_offsets_for_kv_zch(
+                config, shard_index
+            )
+        else:
+            shard_offsets_for_kv_zch = None
+
         self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
             embedding_specs=embedding_specs,
             device=device,
@@ -448,13 +490,16 @@ def __init__(
         super().__init__(config, pg, device)

         managed: List[EmbeddingLocation] = []
+        is_virtual_table = False
         for table in config.embedding_tables:
             if device is not None and device.type == "cuda":
                 managed.append(
                     compute_kernel_to_embedding_location(table.compute_kernel)
                 )
             else:
                 managed.append(EmbeddingLocation.HOST)
+            if table.use_virtual_table:
+                is_virtual_table = True
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
         self._quant_state_dict_split_scale_bias: bool = (
@@ -465,37 +510,52 @@ def __init__(
         )
         # 16 for CUDA, 1 for others like CPU and MTIA.
         self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
-            IntNBitTableBatchedEmbeddingBagsCodegen(
-                embedding_specs=[
-                    (
-                        table.name,
-                        local_rows,
-                        (
-                            local_cols
-                            if self._quant_state_dict_split_scale_bias
-                            else table.embedding_dim
-                        ),
-                        data_type_to_sparse_type(table.data_type),
-                        location,
-                    )
-                    for local_rows, local_cols, table, location in zip(
-                        self._local_rows,
-                        self._local_cols,
-                        config.embedding_tables,
-                        managed,
-                    )
-                ],
-                device=device,
-                pooling_mode=PoolingMode.NONE,
-                feature_table_map=self._feature_table_map,
-                row_alignment=self._tbe_row_alignment,
-                uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-                feature_names_per_table=[
-                    table.feature_names for table in config.embedding_tables
-                ],
-                **(tbe_fused_params(fused_params) or {}),
+        embedding_clazz = (
+            KVEmbeddingInference
+            if is_virtual_table
+            else IntNBitTableBatchedEmbeddingBagsCodegen
+        )
+        if is_virtual_table:
+            assert (
+                shard_index is not None and shard_index >= 0
+            ), "valid shard_index must be provided for kv zch batch embedding to compute shard offsets"
+            shard_offsets_for_kv_zch = _get_shard_offsets_for_kv_zch(
+                config, shard_index
             )
+        else:
+            shard_offsets_for_kv_zch = None
+
+        self._emb_module: (
+            IntNBitTableBatchedEmbeddingBagsCodegen | KVEmbeddingInference
+        ) = embedding_clazz(
+            embedding_specs=[
+                (
+                    table.name,
+                    local_rows,
+                    (
+                        local_cols
+                        if self._quant_state_dict_split_scale_bias
+                        else table.embedding_dim
+                    ),
+                    data_type_to_sparse_type(table.data_type),
+                    location,
+                )
+                for local_rows, local_cols, table, location in zip(
+                    self._local_rows,
+                    self._local_cols,
+                    config.embedding_tables,
+                    managed,
+                )
+            ],
+            device=device,
+            pooling_mode=PoolingMode.NONE,
+            feature_table_map=self._feature_table_map,
+            row_alignment=self._tbe_row_alignment,
+            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            feature_names_per_table=[
+                table.feature_names for table in config.embedding_tables
+            ],
+            **(tbe_fused_params(fused_params) or {}),
         )
         if device is not None:
             self._emb_module.initialize_weights()
```
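
Both constructor diffs above follow the same shape: derive `is_virtual_table` from the grouped tables, pick the TBE class, and only for virtual tables require a valid `shard_index` so shard offsets can be computed. One detail worth noting is the precedence in the kernel that supports `lengths_to_tbe`: the lengths-based TBE still wins because the new virtual-table branch is an `elif`. The plain-Python sketch below restates only that ordering, with class names as strings purely for illustration.

```python
# Plain-Python sketch of the selection order added above (not TorchRec code);
# class names are returned as strings for illustration only.
def pick_tbe_class(lengths_to_tbe: bool, is_virtual_table: bool) -> str:
    if lengths_to_tbe:
        # the lengths-based TBE takes precedence, even for virtual tables
        return "IntNBitTableBatchedEmbeddingBagsCodegenWithLength"
    elif is_virtual_table:
        return "KVEmbeddingInference"
    else:
        return "IntNBitTableBatchedEmbeddingBagsCodegen"


assert pick_tbe_class(False, True) == "KVEmbeddingInference"
assert pick_tbe_class(True, True) == "IntNBitTableBatchedEmbeddingBagsCodegenWithLength"
assert pick_tbe_class(False, False) == "IntNBitTableBatchedEmbeddingBagsCodegen"
```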

torchrec/quant/embedding_modules.py

Lines changed: 16 additions & 10 deletions
```diff
@@ -30,6 +30,7 @@
     IntNBitTableBatchedEmbeddingBagsCodegen,
     PoolingMode,
 )
+from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference
 from torch import Tensor
 from torchrec.distributed.utils import none_throws
 from torchrec.modules.embedding_configs import (
@@ -357,7 +358,7 @@ def __init__(
         self._is_weighted = is_weighted
         self._embedding_bag_configs: List[EmbeddingBagConfig] = tables
         self._key_to_tables: Dict[
-            Tuple[PoolingType, DataType, bool], List[EmbeddingBagConfig]
+            Tuple[PoolingType, bool], List[EmbeddingBagConfig]
         ] = defaultdict(list)
         self._feature_names: List[str] = []
         self._feature_splits: List[int] = []
@@ -383,15 +384,13 @@ def __init__(
                 key = (table.pooling, table.use_virtual_table)
             else:
                 key = (table.pooling, False)
-            # pyre-ignore
             self._key_to_tables[key].append(table)

         location = (
             EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
         )

-        for key, emb_configs in self._key_to_tables.items():
-            pooling = key[0]
+        for (pooling, use_virtual_table), emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -420,7 +419,12 @@ def __init__(
                 )
                 feature_table_map.extend([idx] * table.num_features())

-            emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
+            embedding_clazz = (
+                KVEmbeddingInference
+                if use_virtual_table
+                else IntNBitTableBatchedEmbeddingBagsCodegen
+            )
+            emb_module = embedding_clazz(
                 embedding_specs=embedding_specs,
                 pooling_mode=pooling_type_to_pooling_mode(pooling),
                 weight_lists=weight_lists,
@@ -790,8 +794,7 @@ def __init__( # noqa C901
                 key = (table.data_type, False)
             self._key_to_tables[key].append(table)
         self._feature_splits: List[int] = []
-        for key, emb_configs in self._key_to_tables.items():
-            data_type = key[0]
+        for (data_type, use_virtual_table), emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -816,10 +819,13 @@ def __init__( # noqa C901
                     table_name_to_quantized_weights[table.name]
                 )
                 feature_table_map.extend([idx] * table.num_features())
-                # move to here to make sure feature_names order is consistent with the embedding groups
                 self._feature_names.extend(table.feature_names)
-
-            emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
+            embedding_clazz = (
+                KVEmbeddingInference
+                if use_virtual_table
+                else IntNBitTableBatchedEmbeddingBagsCodegen
+            )
+            emb_module = embedding_clazz(
                 embedding_specs=embedding_specs,
                 pooling_mode=PoolingMode.NONE,
                 weight_lists=weight_lists,
```
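
The structural change in this file is that the grouping key now carries `use_virtual_table` (alongside pooling in one constructor and data type in the other), so virtual and non-virtual tables never share a TBE group and each group instantiates its own kernel class. Below is a self-contained sketch of that group-by-key-then-dispatch pattern using plain-Python stand-ins, not TorchRec types.

```python
# Self-contained sketch of the group-by-key-then-dispatch pattern used above;
# Table and the string "class names" are stand-ins, not TorchRec APIs.
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Table:
    name: str
    pooling: str  # stand-in for PoolingType
    use_virtual_table: bool = False


def group_tables(tables: List[Table]) -> Dict[Tuple[str, bool], List[Table]]:
    key_to_tables: Dict[Tuple[str, bool], List[Table]] = defaultdict(list)
    for table in tables:
        key_to_tables[(table.pooling, table.use_virtual_table)].append(table)
    return key_to_tables


tables = [Table("a", "SUM"), Table("b", "SUM", use_virtual_table=True), Table("c", "SUM")]
for (pooling, use_virtual_table), group in group_tables(tables).items():
    # Mirrors the diff: each group picks its kernel class from the key's flag.
    clazz = "KVEmbeddingInference" if use_virtual_table else "IntNBitTableBatchedEmbeddingBagsCodegen"
    print(pooling, use_virtual_table, clazz, [t.name for t in group])
```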
