
Commit c41dbe2

respect max_seq_len setting for pos embeddings
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 0c224e4 commit c41dbe2

File tree

  • tensorrt_llm/_torch/auto_deploy/models

1 file changed: +20 −1 lines changed

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 20 additions & 1 deletion
@@ -76,14 +76,24 @@ class AutoModelForCausalLMFactory(ModelFactory):
         "max_position_embeddings": 1024,
     }

+    def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
+        """Get the max position embeddings config for the model."""
+        return {
+            "max_position_embeddings": self.max_seq_len,
+        }
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

         self._quant_config: Optional[Dict] = None

         # Ingest defaults for tokenizer and model kwargs
         self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
-        self.model_kwargs = deep_merge_dicts(self._model_defaults, self.model_kwargs)
+        self.model_kwargs = deep_merge_dicts(
+            self._model_defaults,
+            self.model_kwargs,
+            self._get_max_position_embeddings_config(),
+        )

         # special handling for torch_dtype in model_kwargs since HF does not correctly update
         # torch_dtype string to an actual torch.dtype object (only with default)
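
Note on merge order (not part of the diff): deep_merge_dicts now receives the factory defaults first, the user-supplied model_kwargs second, and the max_position_embeddings override last, so the max_seq_len-derived value takes precedence over the 1024 default — assuming, as the commit intent suggests, that later arguments win in the merge. The snippet below is a minimal illustrative Python sketch only; merge_dicts_sketch is a hypothetical stand-in for the repo's deep_merge_dicts, and the example values (max_seq_len=4096, the torch_dtype kwarg) are assumptions.

from typing import Any, Dict


def merge_dicts_sketch(*dicts: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge dicts, with later arguments overriding earlier ones.

    Hypothetical stand-in for deep_merge_dicts, used only to illustrate the
    precedence that lets the max_seq_len override win over the 1024 default.
    """
    merged: Dict[str, Any] = {}
    for d in dicts:
        for key, value in d.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = merge_dicts_sketch(merged[key], value)
            else:
                merged[key] = value
    return merged


# Assumed values: a factory configured with max_seq_len=4096.
model_defaults = {"max_position_embeddings": 1024}
user_model_kwargs = {"torch_dtype": "bfloat16"}        # hypothetical user kwargs
max_pos_override = {"max_position_embeddings": 4096}   # what _get_max_position_embeddings_config() would return

merged = merge_dicts_sketch(model_defaults, user_model_kwargs, max_pos_override)
print(merged["max_position_embeddings"])  # 4096 -- max_seq_len is respected over the default
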
@@ -344,6 +354,15 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
         },
     }

+    def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
+        """Get the max position embeddings config for the model."""
+        return {
+            "max_position_embeddings": self.max_seq_len,
+            "text_config": {
+                "max_position_embeddings": self.max_seq_len,
+            },
+        }
+
     @property
     def automodel_from_config(self):
         return AutoModelForImageTextToText.from_config
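
For the image-text-to-text factory the override also has to reach the nested text_config, since HF multimodal configs keep the language model's settings there; setting only the top-level key would leave the nested cap untouched. A small illustrative snippet, with all values hypothetical:

# Hypothetical VLM kwargs where the language model's cap sits under text_config.
vlm_kwargs = {
    "max_position_embeddings": 1024,
    "text_config": {"max_position_embeddings": 1024},
}

# A top-level-only override would miss the nested key ...
shallow = {**vlm_kwargs, "max_position_embeddings": 4096}
print(shallow["text_config"]["max_position_embeddings"])  # 1024 -- still the old default

# ... whereas the factory's override dict carries both levels, so a deep merge
# (or, in this toy case, a plain dict union) updates the nested cap as well.
override = {
    "max_position_embeddings": 4096,
    "text_config": {"max_position_embeddings": 4096},
}
deep = {**vlm_kwargs, **override}
print(deep["text_config"]["max_position_embeddings"])  # 4096 -- matches max_seq_len
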
