@@ -76,14 +76,24 @@ class AutoModelForCausalLMFactory(ModelFactory):
         "max_position_embeddings": 1024,
     }
 
+    def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
+        """Get the max position embeddings config for the model."""
+        return {
+            "max_position_embeddings": self.max_seq_len,
+        }
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         self._quant_config: Optional[Dict] = None
 
         # Ingest defaults for tokenizer and model kwargs
         self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
-        self.model_kwargs = deep_merge_dicts(self._model_defaults, self.model_kwargs)
+        self.model_kwargs = deep_merge_dicts(
+            self._model_defaults,
+            self.model_kwargs,
+            self._get_max_position_embeddings_config(),
+        )
 
         # special handling for torch_dtype in model_kwargs since HF does not correctly update
         # torch_dtype string to an actual torch.dtype object (only with default)
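Argument order matters here: because _get_max_position_embeddings_config() is passed last, its value wins over both the class default of 1024 and anything in the user-supplied model_kwargs. Below is a minimal sketch of the merge semantics this relies on; the deep_merge_dicts shown is a hypothetical stand-in for the repo's helper, assumed to merge recursively with later arguments taking precedence.

from typing import Any, Dict


# Hypothetical stand-in for the repo's deep_merge_dicts helper: recursive
# merge over any number of dicts, later arguments override earlier ones.
def deep_merge_dicts(*dicts: Dict[str, Any]) -> Dict[str, Any]:
    merged: Dict[str, Any] = {}
    for d in dicts:
        for key, value in d.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = deep_merge_dicts(merged[key], value)
            else:
                merged[key] = value
    return merged


model_defaults = {"max_position_embeddings": 1024}
user_kwargs = {"torch_dtype": "bfloat16"}
seq_len_override = {"max_position_embeddings": 8192}  # what _get_max_position_embeddings_config() would return for max_seq_len=8192

print(deep_merge_dicts(model_defaults, user_kwargs, seq_len_override))
# {'max_position_embeddings': 8192, 'torch_dtype': 'bfloat16'}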
@@ -344,6 +354,15 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
         },
     }
 
+    def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
+        """Get the max position embeddings config for the model."""
+        return {
+            "max_position_embeddings": self.max_seq_len,
+            "text_config": {
+                "max_position_embeddings": self.max_seq_len,
+            },
+        }
+
     @property
     def automodel_from_config(self):
        return AutoModelForImageTextToText.from_config
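For image-text-to-text models the language-model settings typically sit under a nested text_config in the Hugging Face config, so the override is applied at both the top level and inside text_config. Continuing the sketch above (same hypothetical deep_merge_dicts, assumed recursive), the nested block is merged into any existing text_config defaults rather than replacing them:

# Continuation of the earlier sketch; deep_merge_dicts as defined above.
vlm_defaults = {"text_config": {"max_position_embeddings": 1024}}
vlm_override = {
    "max_position_embeddings": 8192,
    "text_config": {"max_position_embeddings": 8192},
}

print(deep_merge_dicts(vlm_defaults, vlm_override))
# {'text_config': {'max_position_embeddings': 8192}, 'max_position_embeddings': 8192}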