kvzch use new operator in model publish #3108

Closed · wants to merge 1 commit
72 changes: 42 additions & 30 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -20,6 +20,7 @@
    PoolingMode,
    rounded_row_size_in_bytes,
)
from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference
from torchrec.distributed.batched_embedding_kernel import (
    BaseBatchedEmbedding,
    BaseBatchedEmbeddingBag,
@@ -237,13 +238,16 @@ def __init__(
        super().__init__(config, pg, device)

        managed: List[EmbeddingLocation] = []
        is_virtual_table: bool = False
        for table in config.embedding_tables:
            if device is not None and device.type == "cuda":
                managed.append(
                    compute_kernel_to_embedding_location(table.compute_kernel)
                )
            else:
                managed.append(EmbeddingLocation.HOST)
            if table.use_virtual_table:
                is_virtual_table = True
        self._config: GroupedEmbeddingConfig = config
        self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
        self._is_weighted: Optional[bool] = config.is_weighted
@@ -284,6 +288,8 @@

        if self.lengths_to_tbe:
            tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
        elif is_virtual_table:
            tbe_clazz = KVEmbeddingInference
        else:
            tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen

@@ -448,13 +454,16 @@ def __init__(
        super().__init__(config, pg, device)

        managed: List[EmbeddingLocation] = []
        is_virtual_table = False
        for table in config.embedding_tables:
            if device is not None and device.type == "cuda":
                managed.append(
                    compute_kernel_to_embedding_location(table.compute_kernel)
                )
            else:
                managed.append(EmbeddingLocation.HOST)
            if table.use_virtual_table:
                is_virtual_table = True
        self._config: GroupedEmbeddingConfig = config
        self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
        self._quant_state_dict_split_scale_bias: bool = (
@@ -465,37 +474,40 @@
        )
        # 16 for CUDA, 1 for others like CPU and MTIA.
        self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
        embedding_clazz = (
            KVEmbeddingInference
            if is_virtual_table
            else IntNBitTableBatchedEmbeddingBagsCodegen
        )
        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
            embedding_specs=[
                (
                    table.name,
                    local_rows,
                    (
                        local_cols
                        if self._quant_state_dict_split_scale_bias
                        else table.embedding_dim
                    ),
                    data_type_to_sparse_type(table.data_type),
                    location,
                )
                for local_rows, local_cols, table, location in zip(
                    self._local_rows,
                    self._local_cols,
                    config.embedding_tables,
                    managed,
                )
            ],
            device=device,
            pooling_mode=PoolingMode.NONE,
            feature_table_map=self._feature_table_map,
            row_alignment=self._tbe_row_alignment,
            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
            feature_names_per_table=[
                table.feature_names for table in config.embedding_tables
            ],
            **(tbe_fused_params(fused_params) or {}),
        )
        if device is not None:
            self._emb_module.initialize_weights()
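
For context on the kernel change above: both QuantBatchedEmbeddingBag and QuantBatchedEmbedding now scan their grouped tables and, if any table sets use_virtual_table, construct the TBE with KVEmbeddingInference instead of IntNBitTableBatchedEmbeddingBagsCodegen. A minimal standalone sketch of that selection logic follows; the helper name and the getattr default are illustrative assumptions, not part of the PR.

from typing import Iterable, Type

from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference


def select_inference_tbe_class(embedding_tables: Iterable) -> Type:
    # A single virtual (KV ZCH) table switches the whole group to the
    # key-value inference kernel; otherwise keep the regular int-N-bit TBE.
    is_virtual_table = any(
        getattr(table, "use_virtual_table", False) for table in embedding_tables
    )
    return KVEmbeddingInference if is_virtual_table else IntNBitTableBatchedEmbeddingBagsCodegen
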
26 changes: 16 additions & 10 deletions torchrec/quant/embedding_modules.py
@@ -30,6 +30,7 @@
    IntNBitTableBatchedEmbeddingBagsCodegen,
    PoolingMode,
)
from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference
from torch import Tensor
from torchrec.distributed.utils import none_throws
from torchrec.modules.embedding_configs import (
@@ -357,7 +358,7 @@ def __init__(
        self._is_weighted = is_weighted
        self._embedding_bag_configs: List[EmbeddingBagConfig] = tables
        self._key_to_tables: Dict[
            Tuple[PoolingType, bool], List[EmbeddingBagConfig]
        ] = defaultdict(list)
        self._feature_names: List[str] = []
        self._feature_splits: List[int] = []
@@ -383,15 +384,13 @@
                key = (table.pooling, table.use_virtual_table)
            else:
                key = (table.pooling, False)
            # pyre-ignore
            self._key_to_tables[key].append(table)

        location = (
            EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
        )

        for (pooling, use_virtual_table), emb_configs in self._key_to_tables.items():
            embedding_specs = []
            weight_lists: Optional[
                List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -420,7 +419,12 @@ def __init__(
                )
                feature_table_map.extend([idx] * table.num_features())

            embedding_clazz = (
                KVEmbeddingInference
                if use_virtual_table
                else IntNBitTableBatchedEmbeddingBagsCodegen
            )
            emb_module = embedding_clazz(
                embedding_specs=embedding_specs,
                pooling_mode=pooling_type_to_pooling_mode(pooling),
                weight_lists=weight_lists,
@@ -790,8 +794,7 @@ def __init__( # noqa C901
                key = (table.data_type, False)
            self._key_to_tables[key].append(table)
        self._feature_splits: List[int] = []
        for (data_type, use_virtual_table), emb_configs in self._key_to_tables.items():
            embedding_specs = []
            weight_lists: Optional[
                List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -816,10 +819,13 @@
                    table_name_to_quantized_weights[table.name]
                )
                feature_table_map.extend([idx] * table.num_features())
                # move to here to make sure feature_names order is consistent with the embedding groups
                self._feature_names.extend(table.feature_names)

            embedding_clazz = (
                KVEmbeddingInference
                if use_virtual_table
                else IntNBitTableBatchedEmbeddingBagsCodegen
            )
            emb_module = embedding_clazz(
                embedding_specs=embedding_specs,
                pooling_mode=PoolingMode.NONE,
                weight_lists=weight_lists,
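
For context on the module change above: QuantEmbeddingBagCollection now groups tables by (pooling, use_virtual_table) and QuantEmbeddingCollection by (data_type, use_virtual_table), so virtual and regular tables are built into separate TBE modules, each with the matching class. A minimal sketch of the EBC grouping step; the helper name and the getattr default are illustrative assumptions, not part of the PR.

from collections import defaultdict
from typing import DefaultDict, List, Tuple

from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType


def group_tables_for_quant_ebc(
    tables: List[EmbeddingBagConfig],
) -> DefaultDict[Tuple[PoolingType, bool], List[EmbeddingBagConfig]]:
    key_to_tables: DefaultDict[
        Tuple[PoolingType, bool], List[EmbeddingBagConfig]
    ] = defaultdict(list)
    for table in tables:
        # Configs without the flag are treated as regular (non-virtual) tables.
        key = (table.pooling, getattr(table, "use_virtual_table", False))
        key_to_tables[key].append(table)
    return key_to_tables
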
24 changes: 24 additions & 0 deletions torchrec/quant/tests/test_embedding_modules.py
@@ -7,6 +7,7 @@

# pyre-strict

import logging
import unittest
from dataclasses import replace
from typing import Dict, List, Optional, Type
@@ -44,6 +45,19 @@
    KeyedTensor,
)

logger: logging.Logger = logging.getLogger(__name__)


def load_required_dram_kv_embedding_libraries() -> bool:
    try:
        torch.ops.load_library(
            "//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference"
        )
        return True
    except Exception as e:
        logger.error(f"Failed to load dram_kv_embedding libraries, skipping test: {e}")
        return False


class EmbeddingBagCollectionTest(unittest.TestCase):
    def _asserting_same_embeddings(
@@ -260,6 +274,11 @@ def test_multiple_features(self) -> None:
        )
        self._test_ebc([eb1_config, eb2_config], features)

    # pyre-ignore: Invalid decoration [56]
    @unittest.skipIf(
        not load_required_dram_kv_embedding_libraries(),
        "Skip when required libraries are not available",
    )
    def test_multiple_kernels_per_ebc_table(self) -> None:
        class TestModule(torch.nn.Module):
            def __init__(self, m: torch.nn.Module) -> None:
@@ -780,6 +799,11 @@ def __init__(self, m: torch.nn.Module) -> None:
        self.assertEqual(config.name, "t2")
        self.assertEqual(config.data_type, DataType.INT8)

    # pyre-ignore: Invalid decoration [56]
    @unittest.skipIf(
        not load_required_dram_kv_embedding_libraries(),
        "Skip when required libraries are not available",
    )
    def test_multiple_kernels_per_ec_table(self) -> None:
        class TestModule(torch.nn.Module):
            def __init__(self, m: torch.nn.Module) -> None:
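
For context on the new tests: they exercise one regular table and one virtual table in the same collection, so the grouping above produces both an IntNBitTableBatchedEmbeddingBagsCodegen and a KVEmbeddingInference module. A hedged sketch of such a pair of configs; the exact keyword arguments, in particular use_virtual_table, are assumptions based on the fields this PR reads, not a copy of the test bodies.

from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig

# Regular table: grouped onto IntNBitTableBatchedEmbeddingBagsCodegen.
eb_regular = EmbeddingBagConfig(
    name="t1",
    embedding_dim=16,
    num_embeddings=10,
    feature_names=["f1"],
    data_type=DataType.INT8,
)

# Virtual (KV ZCH) table: assumes EmbeddingBagConfig exposes use_virtual_table,
# the flag read by quant_embedding_kernel.py above; grouped onto KVEmbeddingInference.
eb_virtual = EmbeddingBagConfig(
    name="t2",
    embedding_dim=16,
    num_embeddings=10,
    feature_names=["f2"],
    data_type=DataType.INT8,
    use_virtual_table=True,
)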