diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py
index cd8f5f89d..c94380869 100644
--- a/src/lighteval/models/transformers/transformers_model.py
+++ b/src/lighteval/models/transformers/transformers_model.py
@@ -27,7 +27,7 @@
 import torch
 import torch.nn.functional as F
 import transformers
-from pydantic import PositiveInt
+from pydantic import Field, PositiveInt
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from tqdm import tqdm
@@ -137,8 +137,8 @@ class TransformersModelConfig(ModelConfig):
     subfolder: str | None = None
     revision: str = "main"
     batch_size: PositiveInt | None = None
-    generation_size: PositiveInt = 256
     max_length: PositiveInt | None = None
+    model_loading_kwargs: dict = Field(default_factory=dict)
     add_special_tokens: bool = True
     model_parallel: bool | None = None
     dtype: str | None = None
@@ -384,7 +384,7 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
         pretrained_config = self.transformers_config

-        kwargs = {}
+        kwargs = self.config.model_loading_kwargs.copy()

         if "quantization_config" not in pretrained_config.to_dict():
             kwargs["quantization_config"] = quantization_config
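
A minimal usage sketch of the new field: `model_loading_kwargs` seeds the kwargs dict built in `_create_auto_model`, so extra loading options are forwarded alongside `quantization_config`. The model name and the specific kwargs below are illustrative assumptions, not values taken from this diff:

```python
from lighteval.models.transformers.transformers_model import TransformersModelConfig

# Hypothetical example: only `model_loading_kwargs` itself comes from this diff;
# the model name and the kwarg values shown here are illustrative.
config = TransformersModelConfig(
    model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct",
    model_loading_kwargs={
        "attn_implementation": "sdpa",   # any from_pretrained()-style option
        "low_cpu_mem_usage": True,
    },
)

# In TransformersModel._create_auto_model these entries now initialize the
# loading kwargs (kwargs = self.config.model_loading_kwargs.copy()) before
# quantization_config is added.
```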