diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index e2ae440517..cf18c58589 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -89,16 +89,18 @@ def __getattr__(name): if name in submodules: return importlib.import_module(f"{package_name}.{name}") elif name in attr_to_modules: - submod = importlib.import_module(f"{package_name}.{attr_to_modules[name]}") + submod_path = f"{package_name}.{attr_to_modules[name]}" + submod = importlib.import_module(submod_path) + attr = getattr(submod, name) + + # If the attribute lives in a file (module) with the same + # name as the attribute, ensure that the attribute and *not* + # the module is accessible on the package. if name == attr_to_modules[name]: - warnings.warn( - _LazyImportWarning( - "Module attribute and module have same " - f"name: `{name}`; will likely cause conflicts " - "when accessing attribute." - ) - ) - return getattr(submod, name) + pkg = sys.modules[package_name] + pkg.__dict__[name] = attr + + return attr else: raise AttributeError(f"No {package_name} attribute {name}") diff --git a/src/huggingface_hub/snapshot_download.py b/src/huggingface_hub/snapshot_download.py new file mode 100644 index 0000000000..888fcc1d8c --- /dev/null +++ b/src/huggingface_hub/snapshot_download.py @@ -0,0 +1,14 @@ +# TODO: remove in 0.11 + +import warnings + + +warnings.warn( + "snapshot_download.py has been made private and will no longer be available from" + " version 0.11. Please use `from huggingface_hub import snapshot_download` to" + " import the only public function in this module. Other members of the file may be" + " changed without a deprecation notice.", + FutureWarning, +) + +from ._snapshot_download import * # noqa diff --git a/tests/test_snapshot_download.py b/tests/test_snapshot_download.py index b41b438a3c..d8220b3983 100644 --- a/tests/test_snapshot_download.py +++ b/tests/test_snapshot_download.py @@ -4,6 +4,8 @@ import time import unittest +import pytest + import requests from huggingface_hub import HfApi, Repository, snapshot_download from huggingface_hub.hf_api import HfFolder @@ -357,3 +359,10 @@ def test_download_model_with_ignore_regex(self): def test_download_model_with_ignore_regex_list(self): self.check_download_model_with_regex(["*.git*", "*.pt"], allow=False) + + +def test_snapshot_download_import(): + with pytest.warns(FutureWarning, match="has been made private"): + from huggingface_hub.snapshot_download import snapshot_download as x # noqa + + assert x is snapshot_download