Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/nominatim_api/search/db_search_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def get_country_tokens(self, trange: qmod.TokenRange) -> List[qmod.Token]:
"""
tokens = self.query.get_tokens(trange, qmod.TOKEN_COUNTRY)
if self.details.countries:
tokens = [t for t in tokens if t.lookup_word in self.details.countries]
tokens = [t for t in tokens if t.get_country() in self.details.countries]

return tokens

Expand Down
46 changes: 33 additions & 13 deletions test/python/api/search/test_db_search_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# Copyright (C) 2025 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Tests for creating abstract searches from token assignments.
"""
from typing import Optional
import pytest
import dataclasses

from nominatim_api.search.query import Token, TokenRange, QueryStruct, Phrase
import nominatim_api.search.query as qmod
Expand All @@ -17,12 +19,15 @@
import nominatim_api.search.db_searches as dbs


@dataclasses.dataclass
class MyToken(Token):
cc: Optional[str] = None

def get_category(self):
return 'this', 'that'

def get_country(self):
return self.lookup_word
return self.cc


def make_query(*args):
Expand All @@ -33,18 +38,24 @@ def make_query(*args):
q.add_node(qmod.BREAK_END, qmod.PHRASE_ANY)

for start, tlist in enumerate(args):
for end, ttype, tinfo in tlist:
for tid, word in tinfo:
q.add_token(TokenRange(start, end), ttype,
MyToken(penalty=0.5 if ttype == qmod.TOKEN_PARTIAL else 0.0,
token=tid, count=1, addr_count=1,
lookup_word=word))
for end, ttype, tinfos in tlist:
for tinfo in tinfos:
if isinstance(tinfo, tuple):
q.add_token(TokenRange(start, end), ttype,
MyToken(penalty=0.5 if ttype == qmod.TOKEN_PARTIAL else 0.0,
token=tinfo[0], count=1, addr_count=1,
lookup_word=tinfo[1]))
else:
q.add_token(TokenRange(start, end), ttype, tinfo)

return q


def test_country_search():
q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])])
q = make_query([(1, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'),
MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'),
])])
builder = SearchBuilder(q, SearchDetails())

searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
Expand All @@ -58,7 +69,10 @@ def test_country_search():


def test_country_search_with_country_restriction():
q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])])
q = make_query([(1, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'),
MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'),
])])
builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'en,fr'}))

searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
Expand All @@ -72,7 +86,10 @@ def test_country_search_with_country_restriction():


def test_country_search_with_conflicting_country_restriction():
q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])])
q = make_query([(1, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'),
MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'),
])])
builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'fr'}))

searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
Expand All @@ -97,8 +114,11 @@ def test_postcode_search_simple():


def test_postcode_with_country():
q = make_query([(1, qmod.TOKEN_POSTCODE, [(34, '2367')])],
[(2, qmod.TOKEN_COUNTRY, [(1, 'xx')])])
q = make_query(
[(1, qmod.TOKEN_POSTCODE, [(34, '2367')])],
[(2, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=1, count=1, addr_count=1, lookup_word='none', cc='xx'),
])])
builder = SearchBuilder(q, SearchDetails())

searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1),
Expand Down