@@ -299,48 +299,6 @@ def get_bindings_model_config(self,
         num_heads = self.pretrained_config.num_attention_heads // (
             self.mapping.tp_size * self.mapping.cp_size)
 
-        # Handle both uniform and per-layer KV heads
-        num_kv_heads_per_layer = getattr(self.pretrained_config,
-                                         'num_kv_heads_per_layer', None)
-        if num_kv_heads_per_layer is not None:
-            # For models with per-layer KV heads, like nemotron-nas
-            kv_heads_per_layer_raw = num_kv_heads_per_layer
-            use_per_layer_kv_heads = True
-        else:
-            # Check if num_key_value_heads is a list (per-layer) or scalar (uniform)
-            num_kv_heads_raw = getattr(self.pretrained_config,
-                                       'num_key_value_heads', None)
-
-            if num_kv_heads_raw is not None and isinstance(
-                    num_kv_heads_raw, list):
-                # num_key_value_heads is a list - treat as per-layer KV heads
-                kv_heads_per_layer_raw = num_kv_heads_raw
-                use_per_layer_kv_heads = True
-            else:
-                # num_key_value_heads is scalar or None - treat as uniform KV heads
-                if num_kv_heads_raw is None:
-                    # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads
-                    num_kv_heads_raw = getattr(
-                        self.pretrained_config, 'num_query_groups',
-                        self.pretrained_config.num_attention_heads)
-
-                num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size *
-                                                    self.mapping.cp_size)
-                use_per_layer_kv_heads = False
-
-        if use_per_layer_kv_heads:
-            # TRT-LLM LoRA requires uniform KV heads across layers
-            if self.lora_config is not None and len(
-                    set(kv_heads_per_layer_raw)) > 1:
-                raise ValueError(
-                    f"TRT-LLM LoRA requires uniform KV heads across layers, "
-                    f"got: {kv_heads_per_layer_raw}")
-            # Apply TP/CP scaling to each layer
-            num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
-                for kv_heads in kv_heads_per_layer_raw
-            ]
-
         hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size
 
         model_config_cpp = ModelConfigCpp(
@@ -361,9 +319,18 @@ def get_bindings_model_config(self,
         else:
             model_config_cpp.tokens_per_block = tokens_per_block
 
-        if use_per_layer_kv_heads:
+        num_key_value_heads = getattr(self.pretrained_config,
+                                      "num_key_value_heads", num_heads)
+        if isinstance(num_key_value_heads, (list, tuple)):
+            # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
+            num_kv_heads_per_layer = [
+                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                for kv_heads in num_key_value_heads
+            ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
+            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
+                                                   self.mapping.cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)
 
         mlp_hidden_size = None
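After this change, the KV-head handling reduces to a single getattr lookup followed by an isinstance check on num_key_value_heads. Below is a minimal, self-contained sketch of that behavior, not part of the patch: SimpleNamespace stands in for the real pretrained_config and mapping objects, and resolve_kv_heads is a hypothetical helper name introduced only for illustration, not a TensorRT-LLM API.

from types import SimpleNamespace


def resolve_kv_heads(pretrained_config, mapping, num_heads):
    # Mirrors the added diff lines: fall back to num_heads when the config
    # exposes no num_key_value_heads attribute.
    num_key_value_heads = getattr(pretrained_config, "num_key_value_heads",
                                  num_heads)
    tp_cp = mapping.tp_size * mapping.cp_size
    if isinstance(num_key_value_heads, (list, tuple)):
        # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
        return [kv_heads // tp_cp for kv_heads in num_key_value_heads]
    # Uniform KV heads across all layers
    return num_key_value_heads // tp_cp


mapping = SimpleNamespace(tp_size=2, cp_size=1)
print(resolve_kv_heads(SimpleNamespace(num_key_value_heads=8), mapping, 32))
# -> 4 (uniform GQA: 8 KV heads split across tp_size * cp_size = 2 ranks)
print(resolve_kv_heads(SimpleNamespace(num_key_value_heads=[8, 8, 4]), mapping, 32))
# -> [4, 4, 2] (per-layer KV heads, each scaled by tp_size * cp_size)

The list branch feeds model_config_cpp.num_kv_heads_per_layer, while the scalar branch feeds model_config_cpp.set_num_kv_heads, matching the two paths in the diff above.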