     PoolingMode,
     rounded_row_size_in_bytes,
 )
+from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference
 from torchrec.distributed.batched_embedding_kernel import (
     BaseBatchedEmbedding,
     BaseBatchedEmbeddingBag,
@@ -119,6 +120,32 @@ def _quantize_weight(
     return quant_weight_list


+def _get_shard_offsets_for_kv_zch(
+    config: GroupedEmbeddingConfig,
+    shard_index: int,
+) -> List[int]:
+    """
+    Given that kv zch tables are rw-sharded, get the row offset of each shard
+    to be used within the kv zch lookup kernel.
+    """
+    shard_row_offsets = []
+    for table in config.embedding_tables:
+        assert (
+            table.global_metadata is not None
+        ), f"Expected global_metadata to be populated for table {table.name} to get shard offsets for kv zch look up kernel"
+        assert (
+            len(table.global_metadata.shards_metadata) > shard_index
+        ), f"Expected table {table.name} to have more shards than shard index {shard_index}. Found {len(table.global_metadata.shards_metadata)} shards"
+        shard_row_offsets.append(
+            # pyre-ignore: Undefined attribute [16]
+            table.global_metadata.shards_metadata[shard_index].shard_offsets[0]
+        )
+    logger.info(
+        f"Shard row offsets for kv zch look up table {config.embedding_names=}: {shard_row_offsets=}"
+    )
+    return shard_row_offsets
+
+
 def _get_runtime_device(
     device: Optional[torch.device],
     config: GroupedEmbeddingConfig,
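
As a quick illustration of what the new helper returns: for a given shard index it walks every table in the group and picks up that shard's row offset (the first entry of its `shard_offsets`) from the table's global sharding metadata. The sketch below uses hypothetical stand-in dataclasses instead of torchrec's real `GroupedEmbeddingConfig`/`ShardMetadata` types, purely to show the offset lookup.

```python
# Minimal sketch of the offset lookup performed by _get_shard_offsets_for_kv_zch.
# FakeShardMetadata / FakeTable are hypothetical stand-ins for torchrec's real
# ShardMetadata and embedding-table config objects, used only for illustration.
from dataclasses import dataclass
from typing import List


@dataclass
class FakeShardMetadata:
    shard_offsets: List[int]  # [row_offset, col_offset] of this shard


@dataclass
class FakeTable:
    name: str
    shards_metadata: List[FakeShardMetadata]  # one entry per rw shard


def shard_row_offsets_for(tables: List[FakeTable], shard_index: int) -> List[int]:
    """Return the row offset of shard `shard_index` for every table in the group."""
    return [t.shards_metadata[shard_index].shard_offsets[0] for t in tables]


# Two tables, each row-wise sharded across two ranks (100 and 50 rows on rank 0).
tables = [
    FakeTable("t0", [FakeShardMetadata([0, 0]), FakeShardMetadata([100, 0])]),
    FakeTable("t1", [FakeShardMetadata([0, 0]), FakeShardMetadata([50, 0])]),
]
assert shard_row_offsets_for(tables, shard_index=1) == [100, 50]
```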
@@ -237,13 +264,16 @@ def __init__(
         super().__init__(config, pg, device)

         managed: List[EmbeddingLocation] = []
+        is_virtual_table: bool = False
         for table in config.embedding_tables:
             if device is not None and device.type == "cuda":
                 managed.append(
                     compute_kernel_to_embedding_location(table.compute_kernel)
                 )
             else:
                 managed.append(EmbeddingLocation.HOST)
+            if table.use_virtual_table:
+                is_virtual_table = True
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
         self._is_weighted: Optional[bool] = config.is_weighted
@@ -284,9 +314,21 @@ def __init__(

         if self.lengths_to_tbe:
             tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
+        elif is_virtual_table:
+            tbe_clazz = KVEmbeddingInference
         else:
             tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen

+        if is_virtual_table:
+            assert (
+                shard_index is not None and shard_index >= 0
+            ), "valid shard_index must be provided for kv zch batch embedding to compute shard offsets"
+            shard_offsets_for_kv_zch = _get_shard_offsets_for_kv_zch(
+                config, shard_index
+            )
+        else:
+            shard_offsets_for_kv_zch = None
+
         self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
             embedding_specs=embedding_specs,
             device=device,
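
For readers skimming the hunk above, the selection order is: the lengths-based TBE variant wins first, otherwise any group containing a virtual (kv zch) table gets `KVEmbeddingInference`, and everything else falls back to `IntNBitTableBatchedEmbeddingBagsCodegen`; shard offsets are only derived on the virtual-table path. Below is a minimal sketch of that dispatch, with hypothetical stand-in classes in place of the real fbgemm_gpu kernels.

```python
# Hedged sketch of the kernel-class dispatch added above; the classes here are
# hypothetical stand-ins for the real fbgemm_gpu TBE inference modules.
from typing import Optional, Tuple


class IntNBitTBE:  # stand-in for IntNBitTableBatchedEmbeddingBagsCodegen
    pass


class IntNBitTBEWithLength(IntNBitTBE):  # stand-in for ...CodegenWithLength
    pass


class KVEmbedding:  # stand-in for KVEmbeddingInference
    pass


def pick_tbe_clazz(
    lengths_to_tbe: bool,
    is_virtual_table: bool,
    shard_index: Optional[int],
) -> Tuple[type, bool]:
    """Mirror the branch order: lengths path first, then kv zch, then default.

    Returns the chosen class and whether shard offsets must be computed
    (via _get_shard_offsets_for_kv_zch in the actual code)."""
    if lengths_to_tbe:
        clazz: type = IntNBitTBEWithLength
    elif is_virtual_table:
        clazz = KVEmbedding
    else:
        clazz = IntNBitTBE
    if is_virtual_table:
        # kv zch needs a valid shard index to derive per-table row offsets.
        assert shard_index is not None and shard_index >= 0
        return clazz, True
    return clazz, False


assert pick_tbe_clazz(False, True, shard_index=0) == (KVEmbedding, True)
assert pick_tbe_clazz(True, False, shard_index=None) == (IntNBitTBEWithLength, False)
```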
@@ -448,13 +490,16 @@ def __init__(
         super().__init__(config, pg, device)

         managed: List[EmbeddingLocation] = []
+        is_virtual_table = False
         for table in config.embedding_tables:
             if device is not None and device.type == "cuda":
                 managed.append(
                     compute_kernel_to_embedding_location(table.compute_kernel)
                 )
             else:
                 managed.append(EmbeddingLocation.HOST)
+            if table.use_virtual_table:
+                is_virtual_table = True
         self._config: GroupedEmbeddingConfig = config
         self._emb_module_registered: bool = is_fused_param_register_tbe(fused_params)
         self._quant_state_dict_split_scale_bias: bool = (
@@ -465,37 +510,52 @@ def __init__(
         )
         # 16 for CUDA, 1 for others like CPU and MTIA.
         self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
-        self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
-            IntNBitTableBatchedEmbeddingBagsCodegen(
-                embedding_specs=[
-                    (
-                        table.name,
-                        local_rows,
-                        (
-                            local_cols
-                            if self._quant_state_dict_split_scale_bias
-                            else table.embedding_dim
-                        ),
-                        data_type_to_sparse_type(table.data_type),
-                        location,
-                    )
-                    for local_rows, local_cols, table, location in zip(
-                        self._local_rows,
-                        self._local_cols,
-                        config.embedding_tables,
-                        managed,
-                    )
-                ],
-                device=device,
-                pooling_mode=PoolingMode.NONE,
-                feature_table_map=self._feature_table_map,
-                row_alignment=self._tbe_row_alignment,
-                uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
-                feature_names_per_table=[
-                    table.feature_names for table in config.embedding_tables
-                ],
-                **(tbe_fused_params(fused_params) or {}),
+        embedding_clazz = (
+            KVEmbeddingInference
+            if is_virtual_table
+            else IntNBitTableBatchedEmbeddingBagsCodegen
+        )
+        if is_virtual_table:
+            assert (
+                shard_index is not None and shard_index >= 0
+            ), "valid shard_index must be provided for kv zch batch embedding to compute shard offsets"
+            shard_offsets_for_kv_zch = _get_shard_offsets_for_kv_zch(
+                config, shard_index
             )
+        else:
+            shard_offsets_for_kv_zch = None
+
+        self._emb_module: (
+            IntNBitTableBatchedEmbeddingBagsCodegen | KVEmbeddingInference
+        ) = embedding_clazz(
+            embedding_specs=[
+                (
+                    table.name,
+                    local_rows,
+                    (
+                        local_cols
+                        if self._quant_state_dict_split_scale_bias
+                        else table.embedding_dim
+                    ),
+                    data_type_to_sparse_type(table.data_type),
+                    location,
+                )
+                for local_rows, local_cols, table, location in zip(
+                    self._local_rows,
+                    self._local_cols,
+                    config.embedding_tables,
+                    managed,
+                )
+            ],
+            device=device,
+            pooling_mode=PoolingMode.NONE,
+            feature_table_map=self._feature_table_map,
+            row_alignment=self._tbe_row_alignment,
+            uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+            feature_names_per_table=[
+                table.feature_names for table in config.embedding_tables
+            ],
+            **(tbe_fused_params(fused_params) or {}),
         )
         if device is not None:
             self._emb_module.initialize_weights()