Skip to content

Commit ce10c06

Browse files
committed
add ModelManager
- ModelManager features - import a spacy model created by DataManager - set metadata such as author, license, etc - publish the model to hugging face - add `test_model_manager` notebook with example of use - add `spacy-huggingface-hub` and `wheel` to dependencies also - refactor tests - simplify fixtures code in conftest.py - add `model_path` fixture to conftest.py - pin ipywidgets<8.0.5 for now to avoid test failures in CI - looks due to this change: jupyter-widgets/ipywidgets#3533
1 parent d092864 commit ce10c06

File tree

6 files changed

+382
-22
lines changed

6 files changed

+382
-22
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@ on:
1010

1111
jobs:
1212
test:
13+
name: "${{ matrix.os }} :: ${{ matrix.python-version }}"
1314
runs-on: ${{ matrix.os }}
1415
strategy:
1516
matrix:
16-
os: [ubuntu-20.04]
17+
os: [ubuntu-latest]
1718
python-version: [3.9]
1819
steps:
1920
- name: Checkout repository
@@ -30,7 +31,7 @@ jobs:
3031
- name: Run pytest
3132
run: |
3233
cd moralization
33-
python -m pytest -s --cov=. --cov-report=xml
34+
python -m pytest -v -s --cov=. --cov-report=xml
3435
- name: Upload coverage
3536
uses: codecov/codecov-action@v3
3637
with:

moralization/model_manager.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import huggingface_hub
2+
import spacy_huggingface_hub
3+
import os
4+
import spacy
5+
from pathlib import Path
6+
from typing import Union, Optional, Dict, Any
7+
import tempfile
8+
import re
9+
import logging
10+
11+
12+
def _construct_wheel_path(model_path: Path, meta: Dict[str, Any]) -> Path:
13+
full_name = f"{meta['lang']}_{meta['name']}-{meta['version']}"
14+
return model_path / full_name / "dist" / f"{full_name}-py3-none-any.whl"
15+
16+
17+
def _make_valid_package_name(name: str) -> str:
18+
# attempt to make name valid, throw exception if we fail
19+
# https://packaging.python.org/en/latest/specifications/name-normalization
20+
valid_name = re.sub(r"[-_.,<>!@#$%^&*()+ /?]+", "_", name).lower().strip("_")
21+
if name != valid_name:
22+
logging.warning(
23+
f"'{name}' not a valid package name, using '{valid_name}' instead"
24+
)
25+
if (
26+
re.match("^([A-Z0-9]|[A-Z0-9][A-Z0-9._-]*[A-Z0-9])$", valid_name, re.IGNORECASE)
27+
is None
28+
):
29+
raise ValueError(
30+
"Invalid package name: Can only contain ASCII letters, numbers and underscore."
31+
)
32+
return valid_name
33+
34+
35+
class ModelManager:
36+
"""
37+
Import, modify and publish models to hugging face.
38+
"""
39+
40+
_meta_keys_to_expose_to_user = [
41+
"name",
42+
"version",
43+
"description",
44+
"author",
45+
"email",
46+
"url",
47+
"license",
48+
]
49+
50+
def __init__(self, model_path: Union[str, Path] = None):
51+
self.load(model_path)
52+
53+
def load(self, model_path: Union[str, Path]):
54+
"""Load a spacy model from `model_path`."""
55+
self.model_path = Path(model_path)
56+
self.spacy_model = spacy.load(model_path)
57+
self.metadata = {
58+
k: self.spacy_model.meta.get(k, "")
59+
for k in self._meta_keys_to_expose_to_user
60+
}
61+
62+
def save(self):
63+
"""Save any changes made to the model metadata."""
64+
self._update_metadata()
65+
self.spacy_model.to_disk(self.model_path)
66+
67+
def publish(self, hugging_face_token: Optional[str] = None) -> Dict[str, str]:
68+
"""Publish the model to Hugging Face.
69+
70+
This requires a User Access Token from https://huggingface.co/
71+
72+
The token can either be passed via the `hugging_face_token` argument,
73+
or it can be set via the `HUGGING_FACE_TOKEN` environment variable.
74+
75+
Args:
76+
hugging_face_token (str, optional): Hugging Face User Access Token
77+
Returns:
78+
dict: URLs of the published model and the pip-installable wheel
79+
"""
80+
self.save()
81+
if hugging_face_token is None:
82+
hugging_face_token = os.environ.get("HUGGING_FACE_TOKEN")
83+
if hugging_face_token is None:
84+
raise ValueError(
85+
"API TOKEN required: pass as string or set the HUGGING_FACE_TOKEN environment variable."
86+
)
87+
huggingface_hub.login(token=hugging_face_token)
88+
with tempfile.TemporaryDirectory() as tmpdir:
89+
# convert model to a python package incl binary wheel
90+
output_path = Path(tmpdir)
91+
spacy.cli.package(self.model_path, output_path, create_wheel=True)
92+
# push the package to hugging face
93+
return spacy_huggingface_hub.push(
94+
_construct_wheel_path(output_path, self.spacy_model.meta)
95+
)
96+
97+
def _update_metadata(self):
98+
self.metadata["name"] = _make_valid_package_name(self.metadata.get("name"))
99+
for k, v in self.metadata.items():
100+
if k in self.spacy_model.meta:
101+
self.spacy_model.meta[k] = v

moralization/tests/conftest.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,44 @@
11
import pytest
22
from moralization import input_data
3+
from moralization.data_manager import DataManager
34
import pathlib
45

56

6-
def _data_path_fixture(dir_path):
7-
@pytest.fixture
8-
def _fixture():
9-
return dir_path
7+
@pytest.fixture(scope="session")
8+
def data_dir():
9+
return pathlib.Path(__file__).parents[1].resolve() / "data"
1010

11-
return _fixture
1211

12+
@pytest.fixture(scope="session")
13+
def ts_file(data_dir):
14+
return data_dir / "TypeSystem.xml"
1315

14-
def _doc_dict_fixture(dir_path):
15-
@pytest.fixture
16-
def _fixture():
17-
return input_data.InputOutput.read_data(dir_path)
1816

19-
return _fixture
17+
@pytest.fixture(scope="session")
18+
def data_file(data_dir):
19+
return (
20+
data_dir / "test_data-trimmed_version_of-Interviews-pos-SH-neu-optimiert-AW.xmi"
21+
)
2022

2123

22-
dir_path = pathlib.Path(__file__).parents[1].resolve() / "data"
23-
data_dir = _data_path_fixture(dir_path)
24-
doc_dicts = _doc_dict_fixture(dir_path)
24+
@pytest.fixture(scope="session")
25+
def config_file(data_dir):
26+
return data_dir / "config.cfg"
2527

2628

27-
ts_file = _data_path_fixture(dir_path / "TypeSystem.xml")
28-
data_file = _data_path_fixture(
29-
dir_path / "test_data-trimmed_version_of-Interviews-pos-SH-neu-optimiert-AW.xmi"
30-
)
31-
config_file = _data_path_fixture(dir_path / "config.cfg")
29+
@pytest.fixture(scope="session")
30+
def model_path(data_dir, config_file, tmp_path_factory):
31+
"""
32+
Returns a temporary path containing a trained model.
33+
This is only created once and re-used for the entire pytest session.
34+
"""
35+
dm = DataManager(data_dir)
36+
dm.export_data_DocBin()
37+
tmp_path = tmp_path_factory.mktemp("model")
38+
dm.spacy_train(working_dir=tmp_path, config=config_file, n_epochs=1)
39+
yield tmp_path / "output" / "model-best"
40+
41+
42+
@pytest.fixture
43+
def doc_dicts(data_dir):
44+
return input_data.InputOutput.read_data(str(data_dir))
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from moralization.model_manager import ModelManager
2+
import spacy
3+
import pytest
4+
5+
6+
def test_model_manager_valid_path(model_path):
7+
model = ModelManager(model_path)
8+
assert model.spacy_model is not None
9+
assert model.spacy_model.lang == "de"
10+
assert model.spacy_model.path == model_path
11+
12+
13+
def test_model_manager_modify_metadata(model_path):
14+
model = ModelManager(model_path)
15+
# update metadata values and save model
16+
keys = ["name", "version", "description", "author", "email", "url", "license"]
17+
for key in keys:
18+
model.metadata[key] = f"{key}"
19+
model.save()
20+
for key in keys:
21+
assert model.metadata[key] == f"{key}"
22+
# re-load model
23+
model.load(model_path)
24+
for key in keys:
25+
assert model.metadata[key] == f"{key}"
26+
# load model directly in spacy and check its meta has also been updated
27+
nlp = spacy.load(model_path)
28+
for key in keys:
29+
assert nlp.meta[key] == f"{key}"
30+
31+
32+
def test_model_manager_modify_metadata_fixable_invalid_names(model_path):
33+
model = ModelManager(model_path)
34+
for invalid_name, valid_name in [("!hm & __OK?,...", "hm_ok"), ("Im - S", "im_s")]:
35+
model.metadata["name"] = invalid_name
36+
assert model.metadata["name"] == invalid_name
37+
# name is made valid on call to save()
38+
model.save()
39+
assert model.metadata["name"] == valid_name
40+
nlp = spacy.load(model_path)
41+
assert nlp.meta["name"] == valid_name
42+
43+
44+
def test_model_manager_modify_metadata_unfixable_invalid_names(model_path):
45+
model = ModelManager(model_path)
46+
for unfixable_invalid_name in ["", "_", "ü"]:
47+
model.metadata["name"] = unfixable_invalid_name
48+
with pytest.raises(ValueError) as e:
49+
model.save()
50+
assert "invalid" in str(e.value).lower()
51+
52+
53+
def test_model_manager_publish_no_token(model_path, monkeypatch):
54+
monkeypatch.delenv("HUGGING_FACE_TOKEN", raising=False)
55+
model = ModelManager(model_path)
56+
with pytest.raises(ValueError) as e:
57+
model.publish()
58+
assert "token" in str(e.value).lower()
59+
60+
61+
def test_model_manager_publish_invalid_token_env(model_path, monkeypatch):
62+
monkeypatch.setenv("HUGGING_FACE_TOKEN", "invalid")
63+
model = ModelManager(model_path)
64+
with pytest.raises(ValueError) as e:
65+
model.publish()
66+
assert "token" in str(e.value).lower()
67+
68+
69+
def test_model_manager_publish_invalid_token_arg(model_path):
70+
model = ModelManager(model_path)
71+
with pytest.raises(ValueError) as e:
72+
model.publish(hugging_face_token="invalid")
73+
assert "token" in str(e.value).lower()

0 commit comments

Comments
 (0)