Skip to content

Commit 7c718a9

Browse files
committed
add tests for country code and unknown language
1 parent 4207c15 commit 7c718a9

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

machine-learning/app/test_main.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,48 @@ def test_openclip_tokenizer_adds_flores_token_for_nllb(
447447

448448
mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
449449

450+
def test_openclip_tokenizer_removes_country_code_from_language_for_nllb_if_not_found(
451+
self,
452+
mocker: MockerFixture,
453+
clip_model_cfg: dict[str, Any],
454+
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
455+
) -> None:
456+
mocker.patch.object(OpenClipTextualEncoder, "download")
457+
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
458+
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
459+
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
460+
mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
461+
mock_ids = [randint(0, 50000) for _ in range(77)]
462+
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
463+
464+
clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
465+
clip_encoder._load()
466+
clip_encoder.tokenize("test search query", language="de-CH")
467+
468+
mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
469+
470+
def test_openclip_tokenizer_falls_back_to_english_for_nllb_if_language_code_not_found(
471+
self,
472+
mocker: MockerFixture,
473+
clip_model_cfg: dict[str, Any],
474+
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
475+
warning: mock.Mock,
476+
) -> None:
477+
mocker.patch.object(OpenClipTextualEncoder, "download")
478+
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
479+
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
480+
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
481+
mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
482+
mock_ids = [randint(0, 50000) for _ in range(77)]
483+
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
484+
485+
clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
486+
clip_encoder._load()
487+
clip_encoder.tokenize("test search query", language="unknown")
488+
489+
mock_tokenizer.encode.assert_called_once_with("eng_Latntest search query")
490+
warning.assert_called_once_with("Language 'unknown' not found, defaulting to 'en'")
491+
450492
def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model(
451493
self,
452494
mocker: MockerFixture,

0 commit comments

Comments
 (0)