diff --git a/src/huggingface_hub/dataclasses.py b/src/huggingface_hub/dataclasses.py index ed08f246f5..677cb91484 100644 --- a/src/huggingface_hub/dataclasses.py +++ b/src/huggingface_hub/dataclasses.py @@ -166,13 +166,27 @@ def __init__(self, **kwargs: Any) -> None: # Call the original __init__ with standard fields original_init(self, **standard_kwargs) - # Add any additional kwargs as attributes + # Pass any additional kwargs to `__post_init__` and let the object + # decide whether to set the attr or use for different purposes (e.g. BC checks) + additional_kwargs = {} for name, value in kwargs.items(): if name not in dataclass_fields: - self.__setattr__(name, value) + additional_kwargs[name] = value + + self.__post_init__(**additional_kwargs) cls.__init__ = __init__ # type: ignore[method-assign] + # Define a default __post_init__ if not defined + if not hasattr(cls, "__post_init__"): + + def __post_init__(self, **kwargs: Any) -> None: + """Default __post_init__ to accept additional kwargs.""" + for name, value in kwargs.items(): + setattr(self, name, value) + + cls.__post_init__ = __post_init__ # type: ignore[method-assign] + # (optional) Override __repr__ to include additional kwargs original_repr = cls.__repr__ diff --git a/tests/test_utils_strict_dataclass.py b/tests/test_utils_strict_dataclass.py index 2b1b147b23..fb8cf1094b 100644 --- a/tests/test_utils_strict_dataclass.py +++ b/tests/test_utils_strict_dataclass.py @@ -84,6 +84,18 @@ class ConfigWithKwargs: vocab_size: int = validated_field(validator=positive_int, default=16) +@strict(accept_kwargs=True) +@dataclass +class ConfigWithKwargsAndPostInit: + model_type: str + vocab_size: int = validated_field(validator=positive_int, default=16) + + def __post_init__(self, **kwargs: Any) -> None: + """Custom __post_init__ that also accepts additional kwargs.""" + for name, value in kwargs.items(): + setattr(self, name.upper(), value) # store additional kwargs in uppercase (just for testing) + + class DummyClass: pass @@ -376,6 +388,13 @@ class Config: Config(model_type="bert", vocab_size=30000) +def test_post_init_with_kwargs(): + config = ConfigWithKwargsAndPostInit(model_type="bert", vocab_size=30000, extra_param="extra_value") + assert config.model_type == "bert" + assert config.vocab_size == 30000 + assert config.EXTRA_PARAM == "extra_value" # stored in uppercase by custom __post_init__ + + def test_is_recognized_as_dataclass(): # Check that dataclasses module recognizes it as a dataclass assert is_dataclass(Config)