Commit 6789c2a

feat(ml): better multilingual search with nllb models (#13567)

1 parent 838a8dd

16 files changed: +301 −18 lines

docs/docs/features/searching.md

Lines changed: 15 additions & 2 deletions

@@ -45,7 +45,7 @@ Some search examples:
 </TabItem>
 <TabItem value="Mobile" label="Mobile">
-  <img src={require('./img/moblie-smart-serach.webp').default} width="30%" title='Smart search on mobile' />
+  <img src={require('./img/mobile-smart-search.webp').default} width="30%" title='Smart search on mobile' />
 </TabItem>
 </Tabs>
@@ -56,7 +56,20 @@ Navigating to `Administration > Settings > Machine Learning Settings > Smart Sea

 ### CLIP models

-More powerful models can be used for more accurate search results, but are slower and can require more server resources. Check the dropdowns below to see how they compare in memory usage, speed and quality by language.
+The default search model is fast, but many other options can provide better search results. The tradeoff is that these models are slower and/or use more memory (both when indexing images with background Smart Search jobs and when searching).
+
+The first step in choosing the right model is knowing which languages your users will search in.
+
+If your users will only search in English, then the [CLIP][huggingface-clip] section is the first place to look. This is a curated list of the models that generally perform the best for their size class. The models here are ordered from higher to lower quality, meaning the top models will generally rank the most relevant results higher and have a greater capacity to understand descriptive, detailed, and/or niche queries. The models are also generally ordered from larger to smaller, so consider the impact on memory usage, job processing and search speed when deciding on one. The smaller models in this list are not far behind in quality and are many times faster.
+
+[Multilingual models][huggingface-multilingual-clip] are also available so users can search in their native language. Use these models if you expect non-English searches to be common. They can be separated into two search patterns:
+
+- `nllb` models expect the search query to be in the language specified in the user settings
+- `xlm` and `siglip2` models understand search text regardless of the current language setting
+
+`nllb` models tend to perform the best and are recommended when users primarily search in their native, non-English language. `xlm` and `siglip2` models are more flexible and are recommended for mixed-language search, where the same user might search in different languages at different times.
+
+For more details, check the tables below to see how they compare in memory usage, speed and quality by language.

 Once you've chosen a model, follow these steps:
machine-learning/immich_ml/models/clip/textual.py

Lines changed: 17 additions & 5 deletions

@@ -10,6 +10,7 @@
 from immich_ml.config import log
 from immich_ml.models.base import InferenceModel
+from immich_ml.models.constants import WEBLATE_TO_FLORES200
 from immich_ml.models.transforms import clean_text, serialize_np_array
 from immich_ml.schemas import ModelSession, ModelTask, ModelType

@@ -18,8 +19,9 @@ class BaseCLIPTextualEncoder(InferenceModel):
     depends = []
     identity = (ModelType.TEXTUAL, ModelTask.SEARCH)

-    def _predict(self, inputs: str, **kwargs: Any) -> str:
-        res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
+    def _predict(self, inputs: str, language: str | None = None, **kwargs: Any) -> str:
+        tokens = self.tokenize(inputs, language=language)
+        res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
         return serialize_np_array(res)

     def _load(self) -> ModelSession:
@@ -28,6 +30,7 @@ def _load(self) -> ModelSession:
         self.tokenizer = self._load_tokenizer()
         tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
         self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
+        self.is_nllb = self.model_name.startswith("nllb")
         log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")

         return session
@@ -37,7 +40,7 @@ def _load_tokenizer(self) -> Tokenizer:
         pass

     @abstractmethod
-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         pass

     @property
@@ -92,14 +95,23 @@ def _load_tokenizer(self) -> Tokenizer:

         return tokenizer

-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         text = clean_text(text, canonicalize=self.canonicalize)
+        if self.is_nllb and language is not None:
+            flores_code = WEBLATE_TO_FLORES200.get(language)
+            if flores_code is None:
+                no_country = language.split("-")[0]
+                flores_code = WEBLATE_TO_FLORES200.get(no_country)
+                if flores_code is None:
+                    log.warning(f"Language '{language}' not found, defaulting to 'en'")
+                    flores_code = "eng_Latn"
+            text = f"{flores_code}{text}"
         tokens: Encoding = self.tokenizer.encode(text)
         return {"text": np.array([tokens.ids], dtype=np.int32)}


 class MClipTextualEncoder(OpenClipTextualEncoder):
-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         text = clean_text(text, canonicalize=self.canonicalize)
         tokens: Encoding = self.tokenizer.encode(text)
         return {
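The fallback chain in the new `tokenize` branch above can be exercised on its own. A minimal sketch, assuming a two-entry stand-in for the full `WEBLATE_TO_FLORES200` table and a hypothetical helper name `resolve_flores_code` (the real code inlines this logic and logs via `log.warning`):

```python
# Sketch of the nllb language-resolution fallback: 1. exact Weblate-code
# lookup, 2. retry without the country code, 3. default to English.
# The two-entry mapping is a stand-in for the full WEBLATE_TO_FLORES200 table.
WEBLATE_TO_FLORES200 = {"de": "deu_Latn", "pt": "por_Latn"}


def resolve_flores_code(language: str) -> str:
    flores_code = WEBLATE_TO_FLORES200.get(language)
    if flores_code is None:
        no_country = language.split("-")[0]  # "de-CH" -> "de"
        flores_code = WEBLATE_TO_FLORES200.get(no_country)
        if flores_code is None:
            print(f"Language '{language}' not found, defaulting to 'en'")
            flores_code = "eng_Latn"
    return flores_code


print(resolve_flores_code("de"))       # deu_Latn
print(resolve_flores_code("de-CH"))    # deu_Latn (country code stripped)
print(resolve_flores_code("unknown"))  # eng_Latn (fallback, with warning)
```

Note that the resolved code is prepended to the query text only when the loaded model is an `nllb` variant and a language was supplied; all other models tokenize the query unchanged.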

machine-learning/immich_ml/models/constants.py

Lines changed: 60 additions & 0 deletions

@@ -86,6 +86,66 @@
 RKNN_COREMASK_SUPPORTED_SOCS = ["rk3576", "rk3588"]


+WEBLATE_TO_FLORES200 = {
+    "af": "afr_Latn",
+    "ar": "arb_Arab",
+    "az": "azj_Latn",
+    "be": "bel_Cyrl",
+    "bg": "bul_Cyrl",
+    "ca": "cat_Latn",
+    "cs": "ces_Latn",
+    "da": "dan_Latn",
+    "de": "deu_Latn",
+    "el": "ell_Grek",
+    "en": "eng_Latn",
+    "es": "spa_Latn",
+    "et": "est_Latn",
+    "fa": "pes_Arab",
+    "fi": "fin_Latn",
+    "fr": "fra_Latn",
+    "he": "heb_Hebr",
+    "hi": "hin_Deva",
+    "hr": "hrv_Latn",
+    "hu": "hun_Latn",
+    "hy": "hye_Armn",
+    "id": "ind_Latn",
+    "it": "ita_Latn",
+    "ja": "jpn_Hira",
+    "kmr": "kmr_Latn",
+    "ko": "kor_Hang",
+    "lb": "ltz_Latn",
+    "lt": "lit_Latn",
+    "lv": "lav_Latn",
+    "mfa": "zsm_Latn",
+    "mk": "mkd_Cyrl",
+    "mn": "khk_Cyrl",
+    "mr": "mar_Deva",
+    "ms": "zsm_Latn",
+    "nb-NO": "nob_Latn",
+    "nn": "nno_Latn",
+    "nl": "nld_Latn",
+    "pl": "pol_Latn",
+    "pt-BR": "por_Latn",
+    "pt": "por_Latn",
+    "ro": "ron_Latn",
+    "ru": "rus_Cyrl",
+    "sk": "slk_Latn",
+    "sl": "slv_Latn",
+    "sr-Cyrl": "srp_Cyrl",
+    "sv": "swe_Latn",
+    "ta": "tam_Taml",
+    "te": "tel_Telu",
+    "th": "tha_Thai",
+    "tr": "tur_Latn",
+    "uk": "ukr_Cyrl",
+    "ur": "urd_Arab",
+    "vi": "vie_Latn",
+    "zh-CN": "zho_Hans",
+    "zh-Hans": "zho_Hans",
+    "zh-TW": "zho_Hant",
+}
+
+
 def get_model_source(model_name: str) -> ModelSource | None:
     cleaned_name = clean_name(model_name)
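On the consumer side (see `tokenize` in `clip/textual.py`), the mapped FLORES-200 code is prepended directly to the query string with no separator; the NLLB tokenizer is expected to recognize the code as a dedicated language token. A small illustration using an excerpt of the table above (query text is made up):

```python
# Excerpt of the mapping above; note that regional variants can share one
# FLORES-200 code ("pt-BR" and "pt" both map to por_Latn).
WEBLATE_TO_FLORES200 = {"pt-BR": "por_Latn", "pt": "por_Latn", "zh-TW": "zho_Hant"}

query = "beach sunset"
# The code is concatenated directly onto the query before tokenization.
prefixed = f"{WEBLATE_TO_FLORES200['pt-BR']}{query}"
print(prefixed)  # por_Latnbeach sunset
```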

machine-learning/test_main.py

Lines changed: 82 additions & 0 deletions

@@ -494,6 +494,88 @@ def test_openclip_tokenizer_canonicalizes_text(
         assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
         mock_tokenizer.encode.assert_called_once_with("test search query")

+    def test_openclip_tokenizer_adds_flores_token_for_nllb(
+        self,
+        mocker: MockerFixture,
+        clip_model_cfg: dict[str, Any],
+        clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
+    ) -> None:
+        mocker.patch.object(OpenClipTextualEncoder, "download")
+        mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
+        mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+        mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+        mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
+        mock_ids = [randint(0, 50000) for _ in range(77)]
+        mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
+
+        clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
+        clip_encoder._load()
+        clip_encoder.tokenize("test search query", language="de")
+
+        mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
+
+    def test_openclip_tokenizer_removes_country_code_from_language_for_nllb_if_not_found(
+        self,
+        mocker: MockerFixture,
+        clip_model_cfg: dict[str, Any],
+        clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
+    ) -> None:
+        mocker.patch.object(OpenClipTextualEncoder, "download")
+        mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
+        mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+        mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+        mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
+        mock_ids = [randint(0, 50000) for _ in range(77)]
+        mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
+
+        clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
+        clip_encoder._load()
+        clip_encoder.tokenize("test search query", language="de-CH")
+
+        mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
+
+    def test_openclip_tokenizer_falls_back_to_english_for_nllb_if_language_code_not_found(
+        self,
+        mocker: MockerFixture,
+        clip_model_cfg: dict[str, Any],
+        clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
+        warning: mock.Mock,
+    ) -> None:
+        mocker.patch.object(OpenClipTextualEncoder, "download")
+        mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
+        mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+        mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+        mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
+        mock_ids = [randint(0, 50000) for _ in range(77)]
+        mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
+
+        clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
+        clip_encoder._load()
+        clip_encoder.tokenize("test search query", language="unknown")
+
+        mock_tokenizer.encode.assert_called_once_with("eng_Latntest search query")
+        warning.assert_called_once_with("Language 'unknown' not found, defaulting to 'en'")
+
+    def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model(
+        self,
+        mocker: MockerFixture,
+        clip_model_cfg: dict[str, Any],
+        clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
+    ) -> None:
+        mocker.patch.object(OpenClipTextualEncoder, "download")
+        mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
+        mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+        mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+        mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
+        mock_ids = [randint(0, 50000) for _ in range(77)]
+        mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
+
+        clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
+        clip_encoder._load()
+        clip_encoder.tokenize("test search query", language="de")
+
+        mock_tokenizer.encode.assert_called_once_with("test search query")
+
     def test_mclip_tokenizer(
         self,
         mocker: MockerFixture,

mobile/lib/models/search/search_filter.model.dart

Lines changed: 7 additions & 1 deletion

@@ -236,6 +236,7 @@ class SearchFilter {
   String? context;
   String? filename;
   String? description;
+  String? language;
   Set<Person> people;
   SearchLocationFilter location;
   SearchCameraFilter camera;
@@ -249,6 +250,7 @@ class SearchFilter {
     this.context,
     this.filename,
    this.description,
+    this.language,
    required this.people,
    required this.location,
    required this.camera,
@@ -279,6 +281,7 @@ class SearchFilter {
     String? context,
     String? filename,
     String? description,
+    String? language,
     Set<Person>? people,
     SearchLocationFilter? location,
     SearchCameraFilter? camera,
@@ -290,6 +293,7 @@ class SearchFilter {
       context: context ?? this.context,
       filename: filename ?? this.filename,
       description: description ?? this.description,
+      language: language ?? this.language,
       people: people ?? this.people,
       location: location ?? this.location,
       camera: camera ?? this.camera,
@@ -301,7 +305,7 @@ class SearchFilter {

   @override
   String toString() {
-    return 'SearchFilter(context: $context, filename: $filename, description: $description, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)';
+    return 'SearchFilter(context: $context, filename: $filename, description: $description, language: $language, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)';
   }

   @override
@@ -311,6 +315,7 @@ class SearchFilter {
     return other.context == context &&
         other.filename == filename &&
         other.description == description &&
+        other.language == language &&
         other.people == people &&
         other.location == location &&
         other.camera == camera &&
@@ -324,6 +329,7 @@ class SearchFilter {
     return context.hashCode ^
         filename.hashCode ^
         description.hashCode ^
+        language.hashCode ^
         people.hashCode ^
         location.hashCode ^
         camera.hashCode ^

mobile/lib/pages/search/search.page.dart

Lines changed: 2 additions & 0 deletions

@@ -48,6 +48,8 @@ class SearchPage extends HookConsumerWidget {
           isFavorite: false,
         ),
         mediaType: prefilter?.mediaType ?? AssetType.other,
+        language:
+            "${context.locale.languageCode}-${context.locale.countryCode}",
       ),
     );

mobile/lib/services/search.service.dart

Lines changed: 1 addition & 0 deletions

@@ -60,6 +60,7 @@ class SearchService {
       response = await _apiService.searchApi.searchSmart(
         SmartSearchDto(
           query: filter.context!,
+          language: filter.language,
           country: filter.location.country,
           state: filter.location.state,
           city: filter.location.city,
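Putting the mobile changes together: the device locale is serialized as `languageCode-countryCode` and travels with the query in the `SmartSearchDto`, where the ML service later resolves it to a FLORES-200 code for `nllb` models. An illustrative sketch of the resulting payload (field names from the DTO above; values are made up):

```python
# Illustrative smart-search request body as assembled by the mobile client.
# "language" is the device locale in Weblate-style form ("de-CH" here).
language_code, country_code = "de", "CH"
smart_search_dto = {
    "query": "sunset at the lake",
    "language": f"{language_code}-{country_code}",
    "country": None,  # optional location filters from SearchFilter
    "state": None,
    "city": None,
}
print(smart_search_dto["language"])  # de-CH
```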

mobile/openapi/lib/model/smart_search_dto.dart

Lines changed: 18 additions & 1 deletion
(Generated file; diff not rendered by default.)
