Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/huggingface_hub/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,27 @@ def __init__(self, **kwargs: Any) -> None:
# Call the original __init__ with standard fields
original_init(self, **standard_kwargs)

# Add any additional kwargs as attributes
# Pass any additional kwargs to `__post_init__` and let the object
# decide whether to set the attr or use for different purposes (e.g. BC checks)
additional_kwargs = {}
for name, value in kwargs.items():
if name not in dataclass_fields:
self.__setattr__(name, value)
additional_kwargs[name] = value

self.__post_init__(**additional_kwargs)

cls.__init__ = __init__ # type: ignore[method-assign]

# Define a default __post_init__ if not defined
if not hasattr(cls, "__post_init__"):

def __post_init__(self, **kwargs: Any) -> None:
"""Default __post_init__ to accept additional kwargs."""
for name, value in kwargs.items():
setattr(self, name, value)

cls.__post_init__ = __post_init__ # type: ignore[method-assign]

# (optional) Override __repr__ to include additional kwargs
original_repr = cls.__repr__

Expand Down
19 changes: 19 additions & 0 deletions tests/test_utils_strict_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ class ConfigWithKwargs:
vocab_size: int = validated_field(validator=positive_int, default=16)


@strict(accept_kwargs=True)
@dataclass
class ConfigWithKwargsAndPostInit:
model_type: str
vocab_size: int = validated_field(validator=positive_int, default=16)

def __post_init__(self, **kwargs: Any) -> None:
"""Custom __post_init__ that also accepts additional kwargs."""
for name, value in kwargs.items():
setattr(self, name.upper(), value) # store additional kwargs in uppercase (just for testing)


class DummyClass:
pass

Expand Down Expand Up @@ -376,6 +388,13 @@ class Config:
Config(model_type="bert", vocab_size=30000)


def test_post_init_with_kwargs():
config = ConfigWithKwargsAndPostInit(model_type="bert", vocab_size=30000, extra_param="extra_value")
assert config.model_type == "bert"
assert config.vocab_size == 30000
assert config.EXTRA_PARAM == "extra_value" # stored in uppercase by custom __post_init__


def test_is_recognized_as_dataclass():
# Check that dataclasses module recognizes it as a dataclass
assert is_dataclass(Config)
Expand Down
Loading